#include "LockFreeQueue.h"

#include <stdio.h>
#include <stdlib.h>

/*
 * Allocate a queue node holding init_val, with a NULL next pointer.
 *
 * Out-of-memory policy: print a message and terminate the process, so
 * callers never receive a NULL node (a NULL node linked into the queue
 * would later be dereferenced by the enqueue/dequeue loops).
 */
Node *NewNode(void *init_val) {
	Node *new_node = malloc(sizeof *new_node);

	if (new_node == NULL) {
		fprintf(stderr, "NewNode: out of memory\n");
		exit(EXIT_FAILURE);
	}

	new_node->val = init_val;
	new_node->next = NULL;

	return new_node;
}

/*
 * Allocate an empty queue.
 *
 * The queue starts with a single dummy node (value NULL) that both head and
 * tail point to. The enqueue and dequeue loops dereference head and tail
 * unconditionally, so neither may ever be NULL; if either allocation fails
 * the process prints a message and terminates instead of returning a
 * half-built queue.
 */
LockFreeQueue *NewLockFreeQueue() {
	Node *dummy = NewNode(NULL);
	LockFreeQueue *new_queue = malloc(sizeof *new_queue);

	if (dummy == NULL || new_queue == NULL) {
		fprintf(stderr, "NewLockFreeQueue: out of memory\n");
		exit(EXIT_FAILURE);
	}

	new_queue->head = new_queue->tail = dummy;

	return new_queue;
}

/* compare and exchange functions */
/*
 * Initialize the mutex used to serialize compare-and-exchange, with default
 * attributes. Presumably called with &CE_MUTEX before any queue operation;
 * confirm at the call site. The pthread return code is not propagated.
 */
void InitCE (pthread_mutex_t *ce_lock) {
	(void) pthread_mutex_init(ce_lock, NULL);
}

/*
 * Tear down a mutex previously set up with InitCE. Per POSIX, destroying a
 * mutex that is still locked is undefined, so call this only once no thread
 * can be inside CompareAndExchange. The pthread return code is discarded.
 */
void DestroyCE (pthread_mutex_t *ce_lock) {
	(void) pthread_mutex_destroy(ce_lock);
}

/*
 * Emulated compare-and-swap, serialized by the global CE_MUTEX.
 *
 * If *destination equals comparand, store value into *destination.
 * Returns 1 if the store happened and 0 otherwise.
 *
 * The operation is atomic only with respect to other CompareAndExchange
 * calls: plain loads and stores of the same locations (for example the
 * unlocked reads of this_queue->tail in the enqueue loops) bypass CE_MUTEX.
 */
int CompareAndExchange (void **destination, void *value, void *comparand) {
	int swapped;

	pthread_mutex_lock(&CE_MUTEX);
	swapped = (*destination == comparand);
	if (swapped)
		*destination = value;
	pthread_mutex_unlock(&CE_MUTEX);

	return swapped;
}

/*
 * Lock-free enqueue (Michael & Scott) of x onto the global queue this_queue.
 *
 * Each attempt snapshots the tail and its successor:
 *  - if the tail really is the last node, try to link the new node after it
 *    and then make one attempt to swing tail forward; if that second CAS
 *    fails, some other thread has already advanced tail, which is fine;
 *  - if tail is lagging (its next is non-NULL), help advance it and retry.
 */
void LockFreeQueue_enq(void *x) {
	Node *node = NewNode(x);

	if (node == NULL) {
		/* Never link a NULL node: the tail CAS below would store it into tail. */
		fprintf(stderr, "Enqueue failed: out of memory\n");
		return;
	}

	while (1) {
		Node *last = this_queue->tail;
		Node *next = last->next;

		/* tail moved while we were reading: the snapshot is stale, retry */
		if (last != this_queue->tail) {
			continue;
		}

		if (next == NULL) {
			/*
			 * CAS on the next field of the node we just validated, not on a
			 * fresh (unsynchronized) re-read of this_queue->tail.
			 */
			if (CompareAndExchange(&(last->next), node, next)) {
				CompareAndExchange(&(this_queue->tail), node, last);
				return;
			}
		} else {
			CompareAndExchange(&(this_queue->tail), next, last);
		}
	}
}

/*
 * Lock-free dequeue (Michael & Scott) from the global queue this_queue.
 *
 * Returns the value stored in the node following the dummy head, which then
 * becomes the new dummy. When the queue is empty it prints "Nothing to deq"
 * to stderr and returns NULL, so a NULL return is ambiguous if NULL values
 * were ever enqueued.
 *
 * The old dummy node is never freed here. NOTE(review): presumably that is
 * deliberate, since other threads may still hold it from their own head
 * snapshot; it does leak one node per successful dequeue.
 */
void *LockFreeQueue_deq() {
	for (;;) {
		Node *head_snap = this_queue->head;
		Node *tail_snap = this_queue->tail;
		Node *succ = head_snap->next;

		/* head moved while we were reading: the snapshot is inconsistent */
		if (head_snap != this_queue->head)
			continue;

		if (head_snap != tail_snap) {
			/* read the value before the head CAS, as the algorithm requires */
			void *value = succ->val;

			if (CompareAndExchange(&(this_queue->head), succ, head_snap))
				return value;
			continue;
		}

		/* head == tail: either truly empty, or tail lags a just-linked node */
		if (succ == NULL) {
			fprintf(stderr, "Nothing to deq\n");
			return NULL;
		}
		CompareAndExchange(&(this_queue->tail), succ, tail_snap);
	}
}

/*
 * Deliberately incorrect enqueue (error #1), kept for comparison with
 * LockFreeQueue_enq.
 *
 * After a successful link, tail is swung forward with a plain
 * test-then-store (`if (tail == last) tail = node;`) instead of a
 * CompareAndExchange. The test and the store are not atomic. Between them,
 * another thread can first advance tail to `node` via the helping CAS in the
 * else branch and then complete its own enqueue, moving tail further. This
 * thread then writes `node` back into tail, so tail moves backwards.
 * NOTE(review): a tail left behind head presumably breaks the dequeue's
 * assumption that `first != last` implies a non-NULL successor; confirm with
 * the harness that exercises this variant.
 */
void LockFreeQueue_enq_WithError1 (void *x) {
	Node *last = NULL;
	Node *next = NULL;
	Node *node = NewNode(x);

	while (1) {
		last = this_queue->tail;
		next = last->next;
		
		if (last == this_queue->tail) {
			if (next == NULL) {
				if ( CompareAndExchange(&(this_queue->tail->next), node, next) ) {
					/* INTENTIONAL BUG: non-atomic check-then-set on tail */
					if (this_queue->tail == last) 
						this_queue->tail = node;
					return;
				}
			} else {
				/* help a lagging tail forward (atomic, as in the correct version) */
				CompareAndExchange(&(this_queue->tail), next, last);
			}
		}
	}
	
	return ;
}

/*
 * Deliberately incorrect enqueue (error #2), kept for comparison with
 * LockFreeQueue_enq.
 *
 * The post-link tail swing uses CompareAndExchange, but the "help a lagging
 * tail" step in the else branch is a plain test-then-store
 * (`if (tail == last) tail = next;`). Between the test and the store,
 * another thread can advance tail to `next` and then past it by finishing
 * its own enqueue. This thread then writes `next` back into tail, so tail
 * moves backwards.
 * NOTE(review): presumably this lets tail fall behind head and break the
 * dequeue's `first != last` assumption; confirm with the harness that
 * exercises this variant.
 */
void LockFreeQueue_enq_WithError2 (void *x) {
	Node *last = NULL;
	Node *next = NULL;
	Node *node = NewNode(x);

	while (1) {
		last = this_queue->tail;
		next = last->next;
		
		if (last == this_queue->tail) {
			if (next == NULL) {
				if ( CompareAndExchange(&(this_queue->tail->next), node, next) ) {
					CompareAndExchange(&(this_queue->tail), node, last);
					return;
				}
			} else {
				/* INTENTIONAL BUG: non-atomic check-then-set when helping tail */
				if (this_queue->tail == last)
					this_queue->tail = next;
			}
		}
	}
	return ;
}

void *EnqTwoStrings (void *contain) {
	LockFreeQueue_enq(((TwoStrings*)contain)->first);
	LockFreeQueue_enq(((TwoStrings*)contain)->second);
	return NULL;
}

void *EnqTwoStrings_WithError1 (void *contain) {
	LockFreeQueue_enq_WithError1(((TwoStrings*)contain)->first);
	LockFreeQueue_enq_WithError1(((TwoStrings*)contain)->second);
	return NULL;
}

void *EnqTwoStrings_WithError2 (void *contain) {
	LockFreeQueue_enq_WithError2(((TwoStrings*)contain)->first);
	LockFreeQueue_enq_WithError2(((TwoStrings*)contain)->second);
	return NULL;
}

/*
 * Call LockFreeQueue_deq three times, discarding the dequeued values.
 * The lfq argument is ignored: LockFreeQueue_deq always works on the
 * global this_queue. Always returns NULL.
 */
void *DeqThreeStrings (void *lfq) {
	(void) lfq;

	for (int i = 0; i < 3; i++)
		LockFreeQueue_deq();

	return NULL;
}