/* This is the parallel push/relabel max-flow algorithm.
 * This version lets each processor work on its own small active queue.
 * It also inserts adjacent nodes in random order.
 * It uses a random level graph whose width is 2x its height.
 */
#include "stdio.h"
#include "stdlib.h"
#include "time.h"
#include "omp.h"

// Compile-time configuration of the benchmark graph.
// Every macro that expands to an expression is fully parenthesized so that it
// can be used safely inside larger expressions (e.g. x / NUM_NODES, -T).
#define RAND_SEED 99
#define NODES_H   50
#define NODES_W   (2*NODES_H)
#define NUM_NODES (NODES_H*NODES_W)
#define MAX_D     (NUM_NODES*2)
#define S         0               // nodeId of source
#define T         (NUM_NODES-1)   // nodeId of target
#define MAX_CAP   10
#define NUM_RUNS  1

// Global locks: one lock word per node, indexed by nodeId.
// initLocks() resets every entry to 0; pushRelabel() acquires and releases
// entries through lock() / release().
int nodeLocks[NUM_NODES];

/* Busy-waits until the lock word pointed to by l has been acquired.
 * Keeps calling __lock_test_and_set(l, 1) for as long as that call returns a
 * non-zero value; returns as soon as it returns 0.
 */
void lock(int* l) {
  for (;;) {
    if (__lock_test_and_set (l, 1) == 0) {
      return;  /* lock acquired */
    }
  }
}

/* Gives up the lock word pointed to by l by handing it to __lock_release.
 * Counterpart of lock().
 */
void release(int* l) {
  __lock_release (l);
}

class nodeList;

/* A vertex of the flow network.
 *   nodeId      - index of this vertex
 *   d           - distance label
 *   e           - excess flow currently stored at the vertex
 *   working     - lock word used by pushRelabel(); starts at 0
 *   adjNodeList - head of the linked list of adjacent vertices
 *   numArcs     - number of arcs attached so far (maintained by addArc)
 */
class nodeClass {

public:
  int nodeId;
  int d;
  int e;
  int working;
  nodeList *adjNodeList;
  int numArcs;

  // Stores the given id, label, excess and adjacency head; working and
  // numArcs always start out as 0.
  nodeClass (int nodeId, int d, int e, nodeList *adjNodeList)
    : nodeId(nodeId),
      d(d),
      e(e),
      working(0),
      adjNodeList(adjNodeList),
      numArcs(0) {
  }

};

/* One cell of a singly linked list of node pointers.
 * The cell does not own the node it points at.
 */
class nodeList {

public:
  nodeClass *node;   // payload
  nodeList *next;    // following cell, or NULL at the end of the list

  nodeList (nodeClass *node, nodeList *next) : node(node), next(next) {
  }

  // Inserts a freshly allocated cell holding newNode directly after this one.
  void addNode (nodeClass *newNode) {
    nodeList *inserted = new nodeList(newNode, next);
    next = inserted;
  }
};

class nodeListClass {
  
public:
  nodeList *head, *tail;

  nodeListClass() {
    head = NULL;
    tail = NULL;
  }

  int empty() {
    if (head == NULL) return 1; else return 0;
  }

  nodeClass* getHead() {
    if (head == NULL) return NULL;

    nodeClass *headNode;
    nodeList *newHead;

    headNode = head->node;
    newHead = head->next;
    delete head;
    head = newHead;
    if (head == NULL) {
      tail = NULL;
    }

    return headNode;
  }

  void addToTail (nodeClass *node) {
    if (head == NULL) {
      // this is a new list, nothing in it yet
      head = new nodeList(node, NULL);
      tail = head;
    } else {
      tail->next = new nodeList(node, NULL);
      tail = tail->next;
    }
  }
};


class activeListClass {

public:
  nodeListClass *list;
  int inList[NUM_NODES];

  activeListClass() {
    list = new nodeListClass();
    for (int a=0; a<NUM_NODES; a++) {
      inList[a] = 0;
    }
  }

  int empty() {
    return list->empty();
  }

  nodeClass* getHead() {
    nodeClass *node;

    node = list->getHead();
    if (node != NULL) {
      inList[node->nodeId] = 0;
    }

    return node;
  }

  void addToTail (nodeClass *node) {
    
    if ((node->nodeId != S) && 
	(node->nodeId !=T) &&
	(inList[node->nodeId] == 0)) {
      inList[node->nodeId] = 1;
      list->addToTail(node);
    }
  }

  void print() {
    nodeList *curNodeList;
    
    curNodeList = list->head;

    printf ("Active List: ");
    while (curNodeList != NULL) {
      printf ("%d ", curNodeList->node->nodeId);
      curNodeList = curNodeList->next;
    }
    printf ("\n");
  }
};


// Global variables
// r[u][v] is the residual capacity of the arc from node u to node v.
// The matrix is allocated and zeroed by initR() and filled in by addArc().
int **r; // residual network capacities
// All nodes of the graph, indexed by nodeId; populated by generateNodes().
nodeClass *nodes[NUM_NODES];

// Returns the smaller of a and b (a when they are equal).
int min (int a, int b) {
  return (a <= b) ? a : b;
}

// Returns the larger of a and b (b when they are equal).
int max (int a, int b) {
  return (a > b) ? a : b;
}

/* Used for debuggin only
 */
void printNodes() {

  nodeClass *n;

  printf ("Nodes:\n");

  for (int a=0; a<NUM_NODES; a++) {
    n = nodes[a];
    printf("Node: %d d=%d e=%d\n", n->nodeId, n->d, n->e);
  }
}

/* Debug helper: dumps the whole residual-capacity matrix r,
 * one matrix row per output line.
 */
void printR() {

  printf ("R:\n");

  for (int from=0; from<NUM_NODES; from++) {
    const int *row = r[from];
    for (int to=0; to<NUM_NODES; to++) {
      printf ("%d ", row[to]);
    }
    printf ("\n");
  }

}

void printArcs() {

  nodeClass *node;
  nodeList *curAdjNodeList;
  
  printf ("Arcs:\n");
  for (int a=0; a<NUM_NODES; a++) {
    node = nodes[a];
    printf ("Node: %d ", a);
    curAdjNodeList = node->adjNodeList;
    while (curAdjNodeList != NULL) {
      printf ("%d ", curAdjNodeList->node->nodeId);
      curAdjNodeList = curAdjNodeList->next;
    }
    printf ("\n");
  }
}




/* Allocates all NUM_NODES node objects and stores them in the global nodes
 * array. Each node gets its array index as id, zero distance label, zero
 * excess and an empty adjacency list.
 */
void generateNodes() {

  for (int id=0; id<NUM_NODES; id++) {
    nodes[id] = new nodeClass(id, 0, 0, NULL);
  }
}

/* This adds the necessary arcs to both fromNode and toNode.
 * It also adds the capacity to r
 * It assumes that an edge from fromNode to toNode (or other way around) has
 * never been added before!!
 * cap > 0
 */
void addArc(nodeClass *fromNode, nodeClass *toNode, int cap) {
  int randLoc;
  nodeList *tempList;

  r[fromNode->nodeId][toNode->nodeId] = cap;
  if (fromNode->adjNodeList == NULL) {
    fromNode->adjNodeList = new nodeList (toNode, NULL);
  } else {
    randLoc = (rand()%fromNode->numArcs)+1;
    tempList = fromNode->adjNodeList;
    for (int a=1; a<randLoc; a++) {
      tempList = tempList->next;
    }
    tempList->addNode (toNode);
  }
  if (toNode->adjNodeList == NULL) {
    toNode->adjNodeList = new nodeList (fromNode, NULL);
  } else {
    randLoc = (rand()%toNode->numArcs)+1;
    tempList = toNode->adjNodeList;
    for (int a=1; a<randLoc; a++) {
      tempList = tempList->next;
    }
    tempList->addNode (fromNode);
  }
  fromNode->numArcs++;
  toNode->numArcs++;
}


/* Special case to push/relabel S.
 * Saturates every arc leaving the source that still has residual capacity:
 * the full residual amount is moved to the neighbour, the reverse residual
 * capacity grows by the same amount, and the neighbour is offered to
 * activeNodeList. Finally the source's distance label is set to NUM_NODES.
 */
void pushRelabelS(activeListClass *activeNodeList) {

  nodeClass *source = nodes[S];

  for (nodeList *adj = source->adjNodeList; adj != NULL; adj = adj->next) {
    nodeClass *neighbour = adj->node;
    int *forward  = &r[source->nodeId][neighbour->nodeId];
    int *backward = &r[neighbour->nodeId][source->nodeId];

    if (*forward <= 0) {
      continue;  // nothing left to push along this arc
    }

    const int amount = *forward;
    source->e    -= amount;
    neighbour->e += amount;
    *forward     -= amount;
    *backward    += amount;

    // schedule the neighbour so that it gets processed as an active node
    activeNodeList->addToTail(neighbour);
  }

  source->d = NUM_NODES;
}


/* Discharges one node: pushes its excess along admissible arcs and, if
 * excess remains, relabels it and puts it back on activeNodeList.
 *
 *   node           - node to process (need not be active; S and T are
 *                    locked and released without doing any work)
 *   activeNodeList - the calling thread's queue; receives nodes that have to
 *                    be revisited later
 *
 * If node->working is already set, the node is only queued and the call
 * returns. Otherwise node->working is held for the whole call.
 * After a successful push the function recurses into the receiving node.
 *
 * Every nodeLocks entry taken here is released again inside the same loop
 * iteration. An earlier revision additionally released
 * nodeLocks[node->nodeId] right after taking node->working, although this
 * function never holds that lock at that point; that could unlock a lock
 * word currently owned by another thread, so the call has been removed.
 *
 * NOTE(review): the test of node->working and the following lock() are not
 * one atomic step, so a thread can still end up spinning in lock() while it
 * holds the working flag of a caller further up the recursion - confirm
 * whether a try-lock is needed here.
 * NOTE(review): the relabel step reads r[][] and neighbour labels without
 * holding any nodeLocks entry - confirm this is intended.
 */
void pushRelabel (nodeClass *node, activeListClass *activeNodeList) {
  // node must be an active node (e>0) - not necessary anymore

  nodeList *curAdjNodeList;
  nodeClass *curAdjNode;
  int pushAmnt, *resArcCap, *revResArcCap;

  if (node->working) {
    // somebody is already processing this node; retry it later
    activeNodeList->addToTail(node);
  } else {

    lock(&node->working);
    if ((node->nodeId != (S)) && (node->nodeId != (T))) {

      curAdjNodeList = node->adjNodeList;

      // checking all the adjacent nodes
      while (curAdjNodeList != NULL) {
        curAdjNode   = curAdjNodeList->node;
        resArcCap    = &r[node->nodeId][curAdjNode->nodeId];
        revResArcCap = &r[curAdjNode->nodeId][node->nodeId];

        // always take the lock of the smaller id first so that two threads
        // working on the same pair of nodes agree on the order
        const int loId = min(node->nodeId, curAdjNode->nodeId);
        const int hiId = max(node->nodeId, curAdjNode->nodeId);

        lock(&nodeLocks[loId]);
        lock(&nodeLocks[hiId]);

        if ((curAdjNode->d == node->d - 1) &&
            (*resArcCap > 0) &&
            (node->e > 0)) {
          // this arc is admissible - we can push flow
          pushAmnt     = min(node->e, *resArcCap);
          node->e -= pushAmnt;
          curAdjNode->e += pushAmnt;
          *resArcCap -= pushAmnt;
          *revResArcCap += pushAmnt;

          release(&nodeLocks[hiId]);
          release(&nodeLocks[loId]);

          // process the receiver right away instead of scheduling it
          pushRelabel(curAdjNode, activeNodeList);
        } else {

          release(&nodeLocks[hiId]);
          release(&nodeLocks[loId]);

        }
        curAdjNodeList = curAdjNodeList->next;
      }

      if (node->e > 0) {

        // this node is still active, need to relabel and reschedule

        int curD;
        int minAdjD = MAX_D;
        curAdjNodeList = node->adjNodeList;

        // smallest label among neighbours reachable through a residual arc
        while (curAdjNodeList != NULL) {
          if (r[node->nodeId][curAdjNodeList->node->nodeId]) {
            curD = curAdjNodeList->node->d;
            if (curD < minAdjD) minAdjD = curD;
          }
          curAdjNodeList = curAdjNodeList->next;
        }

        node->d = minAdjD + 1;

        activeNodeList->addToTail(node);
      }
    }
    release(&node->working);
  }

}

/* Resets every entry of the global nodeLocks array to 0. */
void initLocks() {
  int idx = 0;
  while (idx < NUM_NODES) {
    nodeLocks[idx] = 0;
    idx++;
  }
}

/* Prepares the global residual-capacity matrix r as a NUM_NODES x NUM_NODES
 * matrix filled with zeros.
 *
 * The storage is allocated on the first call only and merely cleared on
 * later calls, so calling this once per run no longer leaks the previous
 * matrix. Relies on r being NULL before the first call, which holds because
 * r is a global with static storage. Terminates the program with an error
 * message if an allocation fails.
 */
void initR() {

  if (r == NULL) {
    r = (int**)malloc(sizeof(int*)*(NUM_NODES));
    if (r == NULL) {
      fprintf (stderr, "initR: out of memory\n");
      exit (EXIT_FAILURE);
    }
    // mark all rows as not yet allocated
    for (int a=0; a<NUM_NODES; a++) {
      r[a] = NULL;
    }
  }

  for (int a=0; a<NUM_NODES; a++) {
    if (r[a] == NULL) {
      r[a] = (int*)malloc(sizeof(int)*(NUM_NODES));
      if (r[a] == NULL) {
        fprintf (stderr, "initR: out of memory\n");
        exit (EXIT_FAILURE);
      }
    }
    for (int b=0; b<NUM_NODES; b++) {
      r[a][b] = 0;
    }
  }
}


void initNodes() {

  nodeList *adjNodeList;

  for (int a=0; a<NUM_NODES; a++) {
    nodes[a]->e = 0;
    nodes[a]->d = 0;
    nodes[a]->working = 0;
    nodes[a]->numArcs = 0;
    while (nodes[a]->adjNodeList != NULL) {
      adjNodeList = nodes[a]->adjNodeList;
      nodes[a]->adjNodeList = nodes[a]->adjNodeList->next;
      delete adjNodeList;
    }
  }
}

/* Computes the initial distance labels with a breadth-first search that
 * starts at T and walks arcs backwards: a neighbour is visited from `node`
 * only if r[neighbour][node] is non-zero. T gets label 0 and every newly
 * reached node gets the label of the node it was reached from plus one.
 * Nodes the search never reaches keep whatever label they had before.
 *
 * The work queue is a local object (it used to be heap-allocated and was
 * never deleted); it is empty again when the search finishes, so no list
 * cells are left behind.
 */
void initDistances() {

  nodeListClass nodesToTouch;
  nodeList *curAdjNodeList;
  nodeClass *node, *adjNode;
  int nextD;
  int touched[NUM_NODES];

  for (int a=0; a<NUM_NODES; a++) {
    touched[a] = 0;
  }

  nodes[T]->d = 0;
  touched[T] = 1;
  nodesToTouch.addToTail(nodes[T]);

  // this basically does breadth first search backwards from T
  while (!nodesToTouch.empty()) {
    node = nodesToTouch.getHead();
    nextD = node->d+1;
    curAdjNodeList = node->adjNodeList;
    while (curAdjNodeList != NULL) {
      adjNode = curAdjNodeList->node;
      // only follow neighbours that still have capacity towards `node`
      if (!touched[adjNode->nodeId] && r[adjNode->nodeId][node->nodeId]) {
        touched[adjNode->nodeId] = 1;
        adjNode->d = nextD;
        nodesToTouch.addToTail(adjNode);
      }
      curAdjNodeList = curAdjNodeList->next;
    }
  }
}

/* This connects a small test graph used for testing.
 * The arcs are added in table order; each row is {from, to, capacity}.
 */
void connectTestGraph() {
  static const int arcTable[][3] = {
    {0, 1, 3},
    {0, 3, 2},
    {0, 2, 3},
    {1, 4, 4},
    {2, 3, 1},
    {2, 5, 2},
    {3, 1, 1},
    {3, 5, 4},
    {4, 3, 1},
    {4, 5, 1},
    {1, 2, 0},
  };
  const int arcCount = (int)(sizeof(arcTable) / sizeof(arcTable[0]));

  for (int i=0; i<arcCount; i++) {
    addArc(nodes[arcTable[i][0]], nodes[arcTable[i][1]], arcTable[i][2]);
  }
}

/* This connects a level graph with random caps 
 */
void connectRLGraph() {

  int h,w,node1,node2,node3;

  for (h=0; h<NODES_H-1; h++) {
    for (w=0; w<NODES_W; w++) {
      //pick 3 random nodes in the next level
      node1 = rand()%(NODES_W*1);
      do {node2 = rand()%(NODES_W*1);} while (node2 == node1);
      do {node3 = rand()%(NODES_W*1); } while ((node3 == node1) || (node3 == node2));

      addArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node1], rand()%(MAX_CAP*1)+1);
      addArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node2], rand()%(MAX_CAP*1)+1);
      addArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node3], rand()%(MAX_CAP*1)+1);
    }
  }

  //connect S and T to first and last level
  for (w=0; w<NODES_W-1; w++) {
    addArc(nodes[S], nodes[w+1], MAX_CAP*NODES_W);
  }
  
  for (w=(NODES_H-1)*NODES_W+1; w<NUM_NODES-1; w++) {
    addArc(nodes[w], nodes[T], MAX_CAP*NODES_W);
  }

}

/* This distributes the global active list to the local procs at the beginning.
 * Must be called by every thread of the enclosing parallel region, because it
 * contains barriers. The nodes are handed out round-robin: thread 0 takes the
 * first node, thread 1 the second, and so on.
 *
 *   globalList - shared queue to be emptied
 *   localList  - the calling thread's private queue
 *
 * Each round has two phases separated by barriers. First every thread samples
 * globalList->empty(); only after all threads have done so does the thread
 * whose turn it is remove a node. Previously the emptiness test and the
 * removal were not separated, so a slow thread could observe the list after
 * the final removal, leave the loop one round early and skip a barrier the
 * other threads were waiting in.
 */
void distG2LActiveNodes(activeListClass *globalList,
                        activeListClass *localList) {

  int curProc = 0;
  int thisProc = omp_get_thread_num();
  int numProcs = omp_get_num_threads();

  for (;;) {
    // nobody modifies globalList between the previous barrier and the next
    // one, so all threads read the same value here
    int done = globalList->empty();

    #pragma omp barrier

    if (done) {
      break;
    }

    if (curProc == thisProc) {
      localList->addToTail(globalList->getHead());
    }

    #pragma omp barrier
    curProc = (curProc+1)%numProcs;
  }

}



void findMaxFlow() {

  activeListClass *globalActiveNodeList;

  globalActiveNodeList = new activeListClass();
  pushRelabelS(globalActiveNodeList);
  
  #pragma omp parallel 
  {
    nodeClass *node;
    activeListClass *localActiveNodeList;
    
    localActiveNodeList = new activeListClass();
    
    distG2LActiveNodes(globalActiveNodeList, localActiveNodeList);

    node = localActiveNodeList->getHead();
    while (node != NULL) {
   
      pushRelabel(node, localActiveNodeList);
      node = localActiveNodeList->getHead();
    }
  }
}


/* Benchmark driver: builds the random level graph NUM_RUNS times with a fixed
 * seed, runs findMaxFlow() on it and reports the resulting flow and timings.
 * Returns 0.
 *
 * main is declared to return int, as the language requires.
 *
 * NOTE(review): the timing uses clock(), which the C standard defines as
 * processor time used by the program; on many systems that adds up the time
 * of all threads, so it may not reflect wall-clock speed-up - confirm
 * whether omp_get_wtime() should be used instead.
 */
int main() {

  clock_t startTime, endTime, totTime;

  generateNodes();
  totTime = 0;

  for (int run=0; run<NUM_RUNS; run++) {

    //srand(time(NULL));
    srand(RAND_SEED); //so that we can compare runs

    initNodes();
    initR();
    initLocks();

    //connectTestGraph();
    //connectRandomGraph();
    connectRLGraph();

    initDistances();

    //printR();

    printf ("Starting run %d\n", run);
    startTime = clock();
    findMaxFlow();
    endTime = clock();
    //printf ("Done Max-Flow\n");

    totTime += endTime-startTime;

  }

  // the excess collected at the target node is the value of the flow
  printf ("max flow:     %d\n", nodes[(NUM_NODES)-1]->e);
  printf ("# of procs:   %d\n", omp_get_max_threads());
  printf ("# of nodes:   %d\n", NUM_NODES);
  printf ("# of runs:    %d\n", NUM_RUNS);
  printf ("tot time (s): %f\n", ((float)totTime)/CLOCKS_PER_SEC);
  printf ("avg time (s): %f\n", ((float)totTime)/CLOCKS_PER_SEC/(NUM_RUNS));

  return 0;
}

