#include <stdio.h>
#include <stdlib.h>
#include <pthread.h>
#include <assert.h>

/* Node and Queue structure */

/* Singly linked queue node. `val` is an opaque payload: the queue code only
   stores and links it (the test casts it to char * for printing). */
typedef struct NODE {
  void *val;            /* stored payload */
  struct NODE *next;    /* successor node, or NULL for the last node */
} Node;

/* Linked-list queue that always contains a dummy node (created by
   NewLockFreeQueue with val == NULL); real values are linked after it.
   NOTE: despite the name, the compare-and-exchange used on it is emulated
   with a mutex (CompareAndExchange / CE_MUTEX), so it is not lock-free. */
typedef struct LOCKFREEQUEUE {
  Node *head, *tail;    /* head: dummy node before the first value;
                           tail: intended to be the last node */
} LockFreeQueue;

/* Argument bundle for an enqueue thread: the two strings that one thread
   enqueues, in order (see EnqTwoStrings_WithError1). */
typedef struct TWOSTRINGS {
  char *first;    /* enqueued first */
  char *second;   /* enqueued second */
} TwoStrings;

/* global variables */

/* Queue shared by all enqueueing threads. It is NULL until it is created
   with NewLockFreeQueue; every enqueue dereferences it. */
LockFreeQueue *this_queue = NULL;
/* Serialises all CompareAndExchange calls; set up with InitCE and torn
   down with DestroyCE. */
pthread_mutex_t CE_MUTEX;		// compare and exchange

/* function prototypes */
Node *NewNode(void *init_val);
LockFreeQueue *NewLockFreeQueue(void);

void InitCE (pthread_mutex_t *ce_lock);
void DestroyCE (pthread_mutex_t *ce_lock);
int CompareAndExchange (void **destination, 
                        void *value, void *comparand);

/* NOTE(review): of the queue and thread-routine prototypes below, only
   LockFreeQueue_enq_WithError1 and EnqTwoStrings_WithError1 have a
   definition in the visible source. */
void LockFreeQueue_enq(void *x);
void *LockFreeQueue_deq(void);

void LockFreeQueue_enq_WithError1 (void *x);
void LockFreeQueue_enq_WithError2 (void *x);

void *EnqTwoStrings (void *contain);
void *EnqTwoStrings_WithError1 (void *contain);
void *EnqTwoStrings_WithError2 (void *contain);

void *DeqThreeStrings (void *lfq);

/*
 * NewNode -- allocate a node holding init_val with no successor.
 *
 * The node is heap-allocated and owned by the caller; nothing in the
 * visible code frees nodes. Every caller links or dereferences the result
 * without checking it. Allocation failure therefore terminates the program
 * with a message, instead of causing a NULL dereference later.
 */
Node *NewNode(void *init_val) {
  Node *new_node = malloc(sizeof *new_node);

  if (new_node == NULL) {
    perror("NewNode: malloc");
    exit(EXIT_FAILURE);
  }
  new_node->val = init_val;
  new_node->next = NULL;

  return new_node;
}

/*
 * NewLockFreeQueue -- create an empty queue.
 *
 * The queue starts with a single dummy node (val == NULL). Both head and
 * tail point to it, and real values are linked after it. The queue header
 * is allocated before the dummy node, so a header allocation failure
 * cannot leak a node. Failure terminates the program with a message.
 */
LockFreeQueue *NewLockFreeQueue(void) {
  LockFreeQueue *new_queue = malloc(sizeof *new_queue);

  if (new_queue == NULL) {
    perror("NewLockFreeQueue: malloc");
    exit(EXIT_FAILURE);
  }
  new_queue->head = new_queue->tail = NewNode(NULL);

  return new_queue;
}

/* compare and exchange functions */

/*
 * InitCE -- initialise the mutex used to serialise compare-and-exchange.
 * Locking a mutex whose initialisation failed is undefined, so a failure
 * here is reported and the program is terminated.
 */
void InitCE (pthread_mutex_t *ce_lock) {
  int rc = pthread_mutex_init(ce_lock, NULL);

  if (rc != 0) {
    fprintf(stderr, "InitCE: pthread_mutex_init failed (error %d)\n", rc);
    exit(EXIT_FAILURE);
  }
}

/*
 * DestroyCE -- release a mutex previously set up with InitCE.
 * The status of pthread_mutex_destroy is intentionally discarded. In this
 * file the function is only called from main, right before it returns.
 */
void DestroyCE (pthread_mutex_t *ce_lock) {
  (void) pthread_mutex_destroy(ce_lock);
}

/*
 * CompareAndExchange -- if *destination equals comparand, store value there.
 * Returns 1 when the store happened and 0 otherwise.
 *
 * The comparison and the store are both done while holding CE_MUTEX, so
 * calls to this function are serialised with one another. Plain reads or
 * writes of the same location made elsewhere without CE_MUTEX are NOT
 * synchronised with it.
 */
int CompareAndExchange (void **destination, 
                        void *value, void *comparand) {
  int swapped;

  pthread_mutex_lock(&CE_MUTEX);
  swapped = (*destination == comparand);
  if (swapped) {
    *destination = value;
  }
  pthread_mutex_unlock(&CE_MUTEX);

  return swapped;
}

/* lock free queue enqueue */

/*
 * LockFreeQueue_enq_WithError1 -- append x to the global queue this_queue.
 *
 * This is an intentionally flawed variant (the "WithError1" suffix) of the
 * snapshot / link / swing-tail enqueue. The flaws are kept on purpose:
 *   1a. The link CAS targets this_queue->tail->next, re-reading the shared
 *       tail instead of using the validated snapshot `last`. The node can
 *       therefore be linked after a node other than `last`.
 *   1b. The tail is advanced with an unsynchronised check-then-store
 *       instead of a CompareAndExchange. Concurrent enqueuers can then leave
 *       tail lagging behind the real last node.
 * Do not use this variant for real concurrent code.
 *
 * The (void **) casts are needed because CompareAndExchange takes void **,
 * and a Node ** does not implicitly convert to it. That conversion is a
 * constraint violation, and an error by default in recent GCC.
 */
void LockFreeQueue_enq_WithError1(void *x) {
  Node *last = NULL;
  Node *next = NULL;
  Node *node = NewNode(x);

  while (1) {
    last = this_queue->tail;     /* snapshot of the current tail */
    next = last->next;

    /* Act only if the tail has not moved since the snapshot. */
    if (last == this_queue->tail) {
      if (next == NULL) {
        /* Flaw 1a: uses this_queue->tail rather than `last`. */
        if ( CompareAndExchange((void **) &(this_queue->tail->next), node, next) ) {
          /* Flaw 1b: non-atomic check-then-store on the shared tail. */
          if (this_queue->tail == last) {
            this_queue->tail = node;
          }
          /* Correct form: CompareAndExchange((void **) &(this_queue->tail), node, last); */
          return;
        }
      } else {
        /* Tail is lagging: try to swing it forward to its successor. */
        CompareAndExchange((void **) &(this_queue->tail), next, last);
      }
    }
  }
}

void *EnqTwoStrings_WithError1 (void *contain) {
  LockFreeQueue_enq_WithError1(((TwoStrings*)contain)->first);
  LockFreeQueue_enq_WithError1(((TwoStrings*)contain)->second);
  return NULL;
}

/*
 * TestLockFreeQueue -- race two threads that each enqueue two strings.
 *
 * A fresh global queue is created. Thread 1 enqueues "a","b" and thread 2
 * enqueues "c","d", both through LockFreeQueue_enq_WithError1. Afterwards
 * the four values following the dummy head node are printed; their
 * interleaving depends on scheduling. The tail is then asserted to be the
 * fourth node. Because the enqueue variant is intentionally racy, that
 * assertion can fail on some interleavings.
 */
void TestLockFreeQueue (void) {
  enum { VALUES_ENQUEUED = 4 };   /* two strings from each of two threads */
  int i;
  pthread_t t1;
  pthread_t t2;
  TwoStrings firstParams, secondParams;

  /* Nothing else creates the queue; without this, the first enqueue
     dereferences a NULL this_queue. */
  this_queue = NewLockFreeQueue();

  firstParams.first = "a";
  firstParams.second = "b";
  secondParams.first = "c";
  secondParams.second = "d";

  if (pthread_create(&t1, NULL, EnqTwoStrings_WithError1, &firstParams) != 0) {
    fprintf(stderr, "TestLockFreeQueue: cannot create thread 1\n");
    exit(EXIT_FAILURE);
  }
  if (pthread_create(&t2, NULL, EnqTwoStrings_WithError1, &secondParams) != 0) {
    fprintf(stderr, "TestLockFreeQueue: cannot create thread 2\n");
    exit(EXIT_FAILURE);
  }
  pthread_join(t1, NULL);
  pthread_join(t2, NULL);

  Node *current = this_queue->head;
  for (i = 0; i < VALUES_ENQUEUED; i++) {
    current = current->next;
    /* Report a short queue instead of dereferencing NULL below. */
    if (current == NULL) {
      fprintf(stderr, "TestLockFreeQueue: only %d of %d values reached the queue\n",
              i, VALUES_ENQUEUED);
      abort();
    }
    printf("%d: %s\n", i, (char*)current->val);
  }
  assert ( this_queue->tail == current );
}

/* Entry point: prepare the CAS mutex, run the two-thread enqueue test,
   then release the mutex. */
int main (void) {
  pthread_mutex_t *ce_lock = &CE_MUTEX;

  InitCE(ce_lock);
  TestLockFreeQueue ();
  DestroyCE(ce_lock);

  return 0;
}
