/* This is the serial push/relabel max-flow algorithm extended to use MPI
 * It uses a random level graph (each node connects to 3 nodes on next level)
 * It splits up the nodes into vertical (up/down) strips
 */
#include "stdio.h"
#include "stdlib.h"
#include "time.h"
#include "mpi.h"

// Build-time configuration.  Every expression macro is fully parenthesized so
// that it expands correctly inside larger expressions (e.g. `x % NODES_W`).
#define RAND_SEED 99                      // seed handed to srand() at the start of each run
#define NODES_H   20                      // number of levels in the level graph
#define NODES_W   (2*(NODES_H))           // nodes per level; must be divisible by numProcs
#define NUM_NODES ((NODES_H)*(NODES_W))   // total number of nodes
#define MAX_D     ((NUM_NODES)*2)         // starting value when searching for the smallest label
#define S         0                       // nodeId of source
#define T         ((NUM_NODES)-1)         // nodeId of target
#define MAX_CAP   10                      // random arc capacities are drawn from 1..MAX_CAP
#define NUM_RUNS  1                       // number of timed runs

/* global variables */
// numProcs: size of MPI_COMM_WORLD; thisProc: rank of this process (both set in main).
// nodes_w_perProc: NODES_W/numProcs, assigned at the start of findMaxFlow.
int numProcs, thisProc, nodes_w_perProc;
// status object handed to the MPI_Recv calls below
MPI_Status status;

class nodeList;

/* One record type used for two purposes:
 *  - a graph node (nodeId, distance label d, excess e, adjacency list), and
 *  - a queued push message (nodeId = sender, toNodeId = receiver, pushAmnt).
 * Both constructors initialize every member so that no field is ever read
 * while indeterminate.
 */
class nodeClass {
  
public:
  int nodeId,d,e;
  nodeList *adjNodeList;
  //added information so we can just put the node into the outPushQueue
  int toNodeId, pushAmnt;
  //added for outRelabelQueue stuff
  int relabeled;

  // Graph-node constructor; the push-message fields get neutral values
  // (toNodeId = -1 means "no destination").
  nodeClass (int nodeId, int d, int e, nodeList *adjNodeList)
    : nodeId(nodeId), d(d), e(e), adjNodeList(adjNodeList),
      toNodeId(-1), pushAmnt(0), relabeled(0) {
  }

  // Push-message constructor; the graph-node fields get neutral values.
  nodeClass (int fromNodeId, int toNodeId, int pushAmnt)
    : nodeId(fromNodeId), d(0), e(0), adjNodeList(NULL),
      toNodeId(toNodeId), pushAmnt(pushAmnt), relabeled(0) {
  }

};

/* A cell of a singly linked list of node pointers.
 * The cell has no destructor and never deletes the node it points at.
 */
class nodeList {

public:
  nodeClass *node;
  nodeList *next;

  nodeList (nodeClass *node, nodeList *next) : node(node), next(next) {
  }

  // Splice a new cell holding newNode in directly after this cell.
  void addNode (nodeClass *newNode) {
    nodeList *inserted = new nodeList(newNode, next);
    next = inserted;
  }
};

class nodeListClass {
  
public:
  nodeList *head, *tail;

  nodeListClass() {
    head = NULL;
    tail = NULL;
  }

  int empty() {
    if (head == NULL) return 1; else return 0;
  }

  nodeClass* getHead() {
    if (head == NULL) return NULL;

    nodeClass *headNode;
    nodeList *newHead;

    headNode = head->node;
    newHead = head->next;
    delete head;
    head = newHead;
    if (head == NULL) {
      tail = NULL;
    }

    return headNode;
  }

  void addToTail (nodeClass *node) {
    if (head == NULL) {
      // this is a new list, nothing in it yet
      head = new nodeList(node, NULL);
      tail = head;
    } else {
      tail->next = new nodeList(node, NULL);
      tail = tail->next;
    }
  }

  void print() {
    nodeList *curNodeList;
    
    curNodeList = head;

    printf ("List: ");
    while (curNodeList != NULL) {
      printf ("%d ", curNodeList->node->nodeId);
      curNodeList = curNodeList->next;
    }
    printf ("\n");
  }

};


/* Queue of active nodes that refuses duplicates and never accepts S or T.
 * inList[id] is 1 while node id is sitting in the queue.
 */
class activeListClass {

public:
  nodeListClass *list;
  int inList[NUM_NODES];

  activeListClass() {
    list = new nodeListClass();
    for (int id = 0; id < NUM_NODES; ++id) {
      inList[id] = 0;
    }
  }

  // Returns 1 when no node is queued.
  int empty() {
    return list->empty();
  }

  // Removes and returns the front node (NULL if none), clearing its flag.
  nodeClass* getHead() {
    nodeClass *front = list->getHead();
    if (front == NULL) return NULL;

    inList[front->nodeId] = 0;
    return front;
  }

  // Queues node unless it is the source, the target, or already queued.
  void addToTail (nodeClass *node) {
    const int id = node->nodeId;

    if (id == S || id == T) return;
    if (inList[id] != 0) return;

    inList[id] = 1;
    list->addToTail(node);
  }

  void print() {
    list->print();
  }
};


// Global variables
int **r; // residual network capacities: r[u][v] is what can still be pushed from u to v
nodeClass *nodes[NUM_NODES]; // all nodes of the graph, indexed by nodeId
activeListClass *activeNodeList; // nodes waiting to be handled by pushRelabel
nodeListClass **outPushQueues; // one queue per process holding pushes not yet sent to it
nodeListClass *outRelabelQueue; // nodes waiting to be relabeled in processOutRelabelQueue

// Returns the smaller of a and b.
int min (int a, int b) {
  return (b < a) ? b : a;
}

// Maps a node to the process that owns it: the node's column within its
// level (nodeId mod NODES_W) is divided by the strip width nodes_w_perProc.
int getProcNum (int nodeId) {
  const int column = nodeId % (NODES_W*1);

  return column / nodes_w_perProc;
}

// Returns 1 if nodeId is owned by the calling process, 0 otherwise.
int inThisProc (int nodeId) {
  const int owner = getProcNum(nodeId);

  return (owner == thisProc) ? 1 : 0;
}

// Allocates every entry of nodes[] with d = 0, e = 0 and no adjacency list.
void generateNodes() {
  for (int id = 0; id < NUM_NODES; ++id) {
    nodes[id] = new nodeClass(id, 0, 0, NULL);
  }
}

/* Records the arc fromNode -> toNode in this process's copy of the graph:
 *  - r[from][to] is set to cap (cap > 0),
 *  - each endpoint is put on the other's adjacency list.
 * Nothing is sent from here.
 * It assumes that an edge between the two nodes (in either direction) has
 * never been added before!!
 */
void addArc(nodeClass *fromNode, nodeClass *toNode, int cap) {

  // so we get the same graph as in mp case: one random number is drawn and
  // thrown away for each endpoint that already has an adjacency list
  if (fromNode->adjNodeList != NULL) {
    (void)rand();
  }
  if (toNode->adjNodeList != NULL) {
    (void)rand();
  }

  r[fromNode->nodeId][toNode->nodeId] = cap;

  if (fromNode->adjNodeList != NULL) {
    fromNode->adjNodeList->addNode (toNode);
  } else {
    fromNode->adjNodeList = new nodeList (toNode, NULL);
  }

  if (toNode->adjNodeList != NULL) {
    toNode->adjNodeList->addNode (fromNode);
  } else {
    toNode->adjNodeList = new nodeList (fromNode, NULL);
  }

}

// Sends one arc (from id, to id, capacity - three separate int messages, in
// that order) to every process other than rank 0.
void bcastArc(nodeClass *fromNode, nodeClass *toNode, int cap) {

  for (int dest = 1; dest < numProcs; dest++) {
    MPI_Send (&(fromNode->nodeId), 1, MPI_INT, dest, 0, MPI_COMM_WORLD);
    MPI_Send (&(toNode->nodeId),   1, MPI_INT, dest, 0, MPI_COMM_WORLD);
    MPI_Send (&cap,                1, MPI_INT, dest, 0, MPI_COMM_WORLD);
  }
}

// Receives arcs from rank 0 and adds each one locally until a from-id of -1
// arrives, which marks the end of the stream.
void recvArcs() {

  int fromNodeId, toNodeId, cap;

  for (;;) {
    MPI_Recv(&fromNodeId, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, &status);
    if (fromNodeId == -1) {
      break;
    }

    MPI_Recv(&toNodeId,   1, MPI_INT, 0, 0, MPI_COMM_WORLD, &status);
    MPI_Recv(&cap,        1, MPI_INT, 0, 0, MPI_COMM_WORLD, &status);

    addArc(nodes[fromNodeId], nodes[toNodeId], cap);
  }

}

// Sends a push (from id, to id, amount) to the process that owns toNode.
void sendPush (nodeClass *fromNode, nodeClass *toNode, int pushAmnt) {

  const int dest = getProcNum(toNode->nodeId);

  MPI_Send (&(fromNode->nodeId), 1, MPI_INT, dest, 0, MPI_COMM_WORLD);
  MPI_Send (&(toNode->nodeId),   1, MPI_INT, dest, 0, MPI_COMM_WORLD);
  MPI_Send (&pushAmnt,           1, MPI_INT, dest, 0, MPI_COMM_WORLD);
}

// Sends a push (from id, to id, amount) to the explicitly given process.
void sendPush (int fromNodeId, int toNodeId, int pushAmnt, int toProc) {
  int *fields[3] = { &fromNodeId, &toNodeId, &pushAmnt };

  // three separate one-int messages, in this order
  for (int k = 0; k < 3; k++) {
    MPI_Send (fields[k], 1, MPI_INT, toProc, 0, MPI_COMM_WORLD);
  }
}

void recvPush(int fromProc) {

  int fromNodeId, toNodeId, pushAmnt;

  do {
    //printf ("proc %d start recv push\n", thisProc);
    MPI_Recv(&fromNodeId, 1, MPI_INT, fromProc, 0, MPI_COMM_WORLD, &status);
    if (fromNodeId != -1) {

      MPI_Recv(&toNodeId, 1, MPI_INT, fromProc, 0, MPI_COMM_WORLD, &status);
      MPI_Recv(&pushAmnt,   1, MPI_INT, fromProc, 0, MPI_COMM_WORLD, &status);

      //printf ("getting push on node %d proc %d\n", toNodeId, thisProc);

      if (!inThisProc(toNodeId)) {
	printf ("getting push for not my node\n");
      }

      activeNodeList->addToTail(nodes[toNodeId]);
      
      //r[fromNodeId][toNodeId] -= pushAmnt;
      r[toNodeId][fromNodeId] += pushAmnt;
      nodes[toNodeId]->e += pushAmnt;
      //nodes[fromNodeId]->e -= pushAmnt;
      nodes[toNodeId]->relabeled = 1; //so it doesn't get relabeled
    }
    //printf ("proc %d done recv push\n", thisProc);
  } while (fromNodeId != -1);

}

void checkD() {
  
  int a, inD;
  
  if (thisProc == 0) {
    for (a=0; a<NUM_NODES; a++) {
      MPI_Send (&(nodes[a]->d), 1, MPI_INT, 1, 0, MPI_COMM_WORLD);
    }
  } else {
    for (a=0; a<NUM_NODES; a++) {
      MPI_Recv(&inD, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, &status);
      if (inD != nodes[a]->d) {
	printf ("BAD d on node %d, got %d, is %d\n",a, inD, nodes[a]->d);
      }
    }
  }

}

// Sends node's id and current distance label (two int messages) to every
// process except the caller.
void bcastDistance(nodeClass *node) {

  for (int dest = 0; dest < numProcs; dest++) {
    if (dest == thisProc) {
      continue;
    }
    MPI_Send (&(node->nodeId), 1, MPI_INT, dest, 0, MPI_COMM_WORLD);
    MPI_Send (&(node->d),      1, MPI_INT, dest, 0, MPI_COMM_WORLD);
  }
}


// Receives (nodeId, d) pairs from fromProc and stores each label locally,
// stopping when a nodeId of -1 arrives.
void recvDistances(int fromProc) {

  int nodeId, d;

  for (;;) {
    MPI_Recv(&nodeId, 1, MPI_INT, fromProc, 0, MPI_COMM_WORLD, &status);
    if (nodeId == -1) {
      break;
    }

    MPI_Recv(&d, 1, MPI_INT, fromProc, 0, MPI_COMM_WORLD, &status);
    nodes[nodeId]->d = d;
  }
}





/* this broadcasts out the relabel changes that have been made
 */

void processOutRelabelQueue() {

  nodeClass *node;
  int temp = -1;
  int a, proc;
  int minAdjD, curD;
  nodeList *curAdjNodeList;

  //let each proc relabel independently

  for (proc=0; proc<numProcs; proc++) {

    if (proc == thisProc) {
      //this proc gets to relabel
      //and send out the new lablels

      node = outRelabelQueue->getHead();
      while (node != NULL) {
	if (!node->relabeled) {
	  node->relabeled = 1;
	  
	  //actually do the relabeling

	  //not needed since we only add to outRelabelQueue if e>0 and
	  //incoming pushes can only result in increase of e
	  //if (node->e > 0) {
	  
	  // this node is still active, need to relabel and reschedule
	  minAdjD = MAX_D;
	  curAdjNodeList = node->adjNodeList;

	  // checking all the adjacents nodes
	  while (curAdjNodeList != NULL) {
	    if (r[node->nodeId][curAdjNodeList->node->nodeId]) {
	      curD = curAdjNodeList->node->d;
	      if (curD < minAdjD) minAdjD = curD;
	    }
	    curAdjNodeList = curAdjNodeList->next;
	  }

	  node->d = minAdjD + 1;

	  activeNodeList->addToTail(node);
	  //now send this result out to everyone
	  bcastDistance(node);
	}
	node = outRelabelQueue->getHead();
      }

      //tell everyone we're done
      for (a=0; a<numProcs; a++) {
	if (a != thisProc) {
	  MPI_Send (&temp, 1, MPI_INT, a, 0, MPI_COMM_WORLD);
	}
      }
    } else {
      //recv the new labels from proc
      recvDistances(proc);
    }

    MPI_Barrier(MPI_COMM_WORLD);
  }
}


  /* this first sends out all the pushes necessary and then recvs the incoming ones
   */
void processOutPushQueue() {
  
  nodeClass *node;
  int temp = -1;
  int a;

  //  printf ("proc %d ", thisProc); 
  for (a=0; a<numProcs; a++) {
    if (a != thisProc) {
      //printf ("out %d ", a);
      //outPushQueues[a]->print();
      node = outPushQueues[a]->getHead();
      while (node != NULL) {
	//printf ("sending stuff out\n");
	sendPush(node->nodeId, node->toNodeId, node->pushAmnt, a);
	delete node;
	node = outPushQueues[a]->getHead();
      }
      //sending out done
      MPI_Send (&temp, 1, MPI_INT, a, 0, MPI_COMM_WORLD);
    }
  }
  for (a=0; a<numProcs; a++) {
    if (a != thisProc) {
      recvPush(a);
    }
  }
}




/* Special case to push/relabel S
 */
void pushRelabelS() {

  nodeClass *node;
  nodeList *curAdjNodeList;
  nodeClass *curAdjNode;
  int pushAmnt, *resArcCap, *revResArcCap, temp;
  int i;

  node = nodes[S];
  curAdjNodeList = node->adjNodeList;

  while (curAdjNodeList != NULL) {
    curAdjNode   = curAdjNodeList->node;
    resArcCap    = &r[node->nodeId][curAdjNode->nodeId];
    revResArcCap = &r[curAdjNode->nodeId][node->nodeId];
    if (*resArcCap > 0) {
      pushAmnt     = *resArcCap;
      node->e -= pushAmnt;
      curAdjNode->e += pushAmnt;
      *resArcCap -= pushAmnt;
      *revResArcCap += pushAmnt;
      // need to schedule curAdjNode so that it will get processed as active
      if (inThisProc(curAdjNode->nodeId)) {
	activeNodeList->addToTail(curAdjNode);
      } else {
	sendPush(node, curAdjNode, pushAmnt);
      }
    }
    curAdjNodeList = curAdjNodeList->next;
  }

  node->d = NUM_NODES;

  temp = -1;
  //send out -1 to tell everyone we're done
  for (i=1; i<numProcs; i++) {
    MPI_Send (&temp, 1, MPI_INT, i, 0, MPI_COMM_WORLD);
  }

  bcastDistance(node);
  //send out -1 to tell everyone we're done
  for (i=1; i<numProcs; i++) {
    MPI_Send (&temp, 1, MPI_INT, i, 0, MPI_COMM_WORLD);
  }
}

/* Discharges one active node (e > 0): pushes as much excess as possible over
 * every admissible arc.  Pushes to locally owned neighbours are applied
 * immediately; pushes to remote neighbours are queued in outPushQueues.
 * If excess is left over, the node is queued for relabeling.
 */
void pushRelabel (nodeClass *node) {

  for (nodeList *adj = node->adjNodeList; adj != NULL; adj = adj->next) {
    nodeClass *neighbor = adj->node;
    int *forward  = &r[node->nodeId][neighbor->nodeId];
    int *backward = &r[neighbor->nodeId][node->nodeId];

    // admissible means: one level down, residual capacity left, excess left
    if (node->e <= 0) continue;
    if (*forward <= 0) continue;
    if (neighbor->d != node->d - 1) continue;

    const int amount = min(node->e, *forward);
    node->e  -= amount;
    *forward -= amount;

    if (inThisProc(neighbor->nodeId)) {
      neighbor->e += amount;
      *backward   += amount;
      // need to schedule the neighbour so that it will get processed as active
      activeNodeList->addToTail(neighbor);
    } else {
      // the owner applies the receiving half when the message arrives
      const int owner = getProcNum(neighbor->nodeId);
      nodeClass *msg = new nodeClass(node->nodeId, neighbor->nodeId, amount);
      outPushQueues[owner]->addToTail(msg);
    }
  }

  if (node->e > 0) {
    node->relabeled = 0;
    outRelabelQueue->addToTail(node);
  }
}

void generateOutPushQueues() {
  outPushQueues = (nodeListClass**)malloc(sizeof(nodeListClass*)*numProcs);
  for (int a=0; a<numProcs; a++) {
    outPushQueues[a] = new nodeListClass();
  }
}


// Sets every residual capacity in r to 0.
void initR() {

  for (int from = 0; from < NUM_NODES; from++) {
    int *row = r[from];
    for (int to = 0; to < NUM_NODES; to++) {
      row[to] = 0;
    }
  }
}

void generateR() {

  r = (int**)malloc(sizeof(int*)*NUM_NODES);
  for (int a=0; a<NUM_NODES; a++) {
    r[a] = (int*)malloc(sizeof(int)*NUM_NODES);
  }
}


void initNodes() {

  nodeList *adjNodeList;

  for (int a=0; a<NUM_NODES; a++) {
    nodes[a]->e = 0;
    nodes[a]->d = 0;
    while (nodes[a]->adjNodeList != NULL) {
      adjNodeList = nodes[a]->adjNodeList;
      nodes[a]->adjNodeList = nodes[a]->adjNodeList->next;
      delete adjNodeList;
    }
  }
}

/* Computes the initial distance labels with a breadth first search that runs
 * backwards from T: a node gets label k+1 when it has a residual arc into a
 * node labelled k.  Every label assigned is broadcast with bcastDistance and
 * the stream is closed with a -1 sent to ranks 1..numProcs-1.
 * Nodes the search never reaches keep whatever label they already had.
 * main calls this on rank 0 only.
 */
void initDistances() {
  
  // the work queue lives on the stack so that nothing leaks between runs
  nodeListClass nodesToTouch;
  nodeList *curAdjNodeList;
  nodeClass *node, *adjNode;
  int nextD, temp;
  int touched[NUM_NODES];

  for (int a=0; a<NUM_NODES; a++) {
    touched[a] =0;
  }

  nodes[T]->d = 0;
  touched[T] = 1;
  nodesToTouch.addToTail(nodes[T]);

  // this basically does breadth first search backwards from T
  while (!nodesToTouch.empty()) {
    node = nodesToTouch.getHead();
    nextD = node->d+1;
    curAdjNodeList = node->adjNodeList;
    while (curAdjNodeList != NULL) {
      adjNode = curAdjNodeList->node;
      // only follow arcs that can carry flow towards node
      if (!touched[adjNode->nodeId] && r[adjNode->nodeId][node->nodeId]) {
        touched[adjNode->nodeId] = 1;
        adjNode->d = nextD;
        bcastDistance(adjNode);
        nodesToTouch.addToTail(adjNode);
      }
      curAdjNodeList = curAdjNodeList->next;
    }
  }

  temp = -1;
  //send out -1 to tell everyone we're done
  for (int i=1; i<numProcs; i++) {
    MPI_Send (&temp, 1, MPI_INT, i, 0, MPI_COMM_WORLD);
  }
}

/* This connects a small test graph used for testing
 */
/*
void connectTestGraph() {
  addArc(nodes[0],nodes[1],3);
  addArc(nodes[0],nodes[3],2);
  addArc(nodes[0],nodes[2],3);
  addArc(nodes[1],nodes[4],4);
  addArc(nodes[2],nodes[3],1);
  addArc(nodes[2],nodes[5],2);
  addArc(nodes[3],nodes[1],1);
  addArc(nodes[3],nodes[5],4);
  addArc(nodes[4],nodes[3],1);
  addArc(nodes[4],nodes[5],1);
  addArc(nodes[1],nodes[2],0);
}
*/
/* This connects a level graph with random caps 
 */
void connectRLGraph() {

  int h,w,node1,node2,node3,cap,temp;

  for (h=0; h<NODES_H-1; h++) {
    for (w=0; w<NODES_W; w++) {
      //pick 3 random nodes in the next level
      node1 = rand()%(NODES_W*1);
      do {node2 = rand()%(NODES_W*1);} while (node2 == node1);
      do {node3 = rand()%(NODES_W*1); } while ((node3 == node1) || (node3 == node2));

      cap = rand()%(MAX_CAP*1)+1;
      addArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node1], cap);
      bcastArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node1], cap);
      cap = rand()%(MAX_CAP*1)+1;
      addArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node2], cap);
      bcastArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node2], cap);
      cap = rand()%(MAX_CAP*1)+1;
      addArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node3], cap);
      bcastArc(nodes[h*NODES_W+w],nodes[(h+1)*NODES_W+node3], cap);
    }
  }


  //connect S and T to first and last level
  for (w=0; w<NODES_W-1; w++) {
    addArc(nodes[S], nodes[w+1], MAX_CAP*NODES_W);
    bcastArc(nodes[S],nodes[w+1], MAX_CAP*NODES_W);
  }
  
  for (w=(NODES_H-1)*NODES_W+1; w<NUM_NODES-1; w++) {
    addArc(nodes[w], nodes[T], MAX_CAP*NODES_W);
    bcastArc(nodes[w], nodes[T], MAX_CAP*NODES_W);

  }

  temp = -1;
  //send out -1 to tell everyone we're done
  for (int i=1; i<numProcs; i++) {
    MPI_Send (&temp, 1, MPI_INT, i, 0, MPI_COMM_WORLD);
    
  }

}


/* Used for debuggin only
 */
void printNodes() {

  nodeClass *n;

  printf ("Nodes:\n");

  for (int a=0; a<NUM_NODES; a++) {
    //if (inThisProc(a)) {
    n = nodes[a];
    printf("Node: %d d=%d e=%d\n", n->nodeId, n->d, n->e);
    //}
  }
}

// Debug helper: prints the residual capacity table, one row per line.
void printR() {

  printf ("R:\n");
  for (int from = 0; from < NUM_NODES; from++) {
    const int *row = r[from];
    for (int to = 0; to < NUM_NODES; to++) {
      printf ("%d ", row[to]);
    }
    printf ("\n");
  }

}

void printArcs() {

  nodeClass *node;
  nodeList *curAdjNodeList;
  
  printf ("Arcs:\n");
  for (int a=0; a<NUM_NODES; a++) {
    node = nodes[a];
    printf ("Node: %d ", a);
    curAdjNodeList = node->adjNodeList;
    while (curAdjNodeList != NULL) {
      printf ("%d ", curAdjNodeList->node->nodeId);
      curAdjNodeList = curAdjNodeList->next;
    }
    printf ("\n");
  }
}

void findMaxFlow() {

  int done = 0;
  int a, oldDone;
  nodeClass *node;

  nodes_w_perProc = NODES_W/numProcs;

  if (thisProc == 0) {
    pushRelabelS();
  } else {
    recvPush(0);
    recvDistances(0);
  }
  
  //if (thisProc == 0) {
  // printR();
    
  //}


  while (!done) {
    node = activeNodeList->getHead();
    
    if (node != NULL) {
      pushRelabel(node);
    } else {
      
      processOutPushQueue();
      processOutRelabelQueue();

      //checkD();

      //printf("proc %d ", thisProc);
      //activeNodeList->print();

      if (activeNodeList->empty()) {
	done = 1;
      }
      //gotta make sure everyone is done
      
      oldDone = done;

      MPI_Reduce (&oldDone, &done, 1, MPI_INT, MPI_MIN, 0, MPI_COMM_WORLD);

      if (thisProc == 0) {
	
	for (a=1; a<numProcs; a++) {
	  MPI_Send(&done, 1, MPI_INT, a, 0, MPI_COMM_WORLD);
	}
      } else {
	MPI_Recv(&done, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, &status);
      }
    }
  }
}


void main(int argc, char **argv) {

  double startTime, endTime, totTime;

  MPI_Init(&argc, &argv);
  
  MPI_Comm_size(MPI_COMM_WORLD, &numProcs);
  MPI_Comm_rank(MPI_COMM_WORLD, &thisProc);

  generateNodes();
  generateR();
  generateOutPushQueues();
  activeNodeList = new activeListClass();
  outRelabelQueue = new nodeListClass();

  totTime = 0;

  for (int run=0; run<NUM_RUNS; run++) {

    srand(RAND_SEED);
  
    initNodes();
    initR();
  
    //connectTestGraph();
    if (thisProc == 0) {
      connectRLGraph();
      initDistances();
    } else {
      recvArcs();
      recvDistances(0);
    }
    
    MPI_Barrier(MPI_COMM_WORLD);
      
    if (thisProc == numProcs-1) {
      printf ("Starting run %d\n", run);
      startTime = MPI_Wtime();
    }
    findMaxFlow();
    
    MPI_Barrier(MPI_COMM_WORLD);
    
    if (thisProc == numProcs-1) {
      endTime = MPI_Wtime();
      totTime += endTime-startTime;
    }
    
  }
  
  printf ("proc %d\n", thisProc);
  //printNodes();

  if (thisProc == numProcs-1) {
    printf ("max flow:     %d\n", nodes[NUM_NODES-1]->e);
    printf ("# of procs:   %d\n", numProcs);
    printf ("# of nodes:   %d\n", NUM_NODES);
    printf ("# of runs:    %d\n", NUM_RUNS);
    printf ("tot time (s): %f\n", totTime);
    printf ("avg time (s): %f\n", totTime/NUM_RUNS);
  }
  MPI_Finalize();

}

