#include "graph.h"
#include "passing.h"
#include<cfloat>
#include<cassert>
#include<cmath>
using namespace std;

// Definitions of the static data members of class passing.
int *passing::q;                // node-id work queue filled by init() and consumed (then freed) by pass()
int passing::n;                 // value of Graph::GetN(); node ids run 0..n
int passing::tail;              // one past the last filled slot of q
long double **passing::result;  // per-node edge scores handed out by collect()
int *passing::adv;              // per-node advisor edge index chosen after pass()'s update
int *passing::adv0;             // per-node advisor edge index chosen before pass()'s update

// Map a linear-domain probability into the working domain. The working
// domain is the natural log, so that product() can combine scores by addition.
// Linear-domain alternative: return x unchanged.
inline long double trans(const long double x)
{
	const long double mapped=std::log(x);
	return mapped;
}

// Inverse of trans(): map a working-domain (log) value back to a linear probability.
// Linear-domain alternative: return x unchanged.
inline long double itrans(const long double x)
{
	const long double linear=std::exp(x);
	return linear;
}

// "Addition" of the semiring in use: max, giving max-product (Viterbi-style) inference.
// Written as std::max's own rule (second argument only if strictly greater), so
// NaN ordering is unchanged. Sum-product alternative: x + y.
inline long double sum(const long double x, const long double y)
{
	return (x<y) ? y : x;
}

// "Multiplication" of the semiring in use: addition of log-domain values.
// Linear-domain alternative: x * y.
inline long double product(const long double x, const long double y)
{
	long double combined=x;
	combined+=y;
	return combined;
}

// Inverse of product(): "divide" x by y, i.e. subtract log-domain values.
// Linear-domain alternative: x / y, returning x when y is below EPS.
inline long double iprod(const long double x, const long double y)
{
	long double quotient=x;
	quotient-=y;
	return quotient;
}

inline long double prod0()
{
	return 0;
}

// Starting value for sum() accumulations: -MAXTIME, used as a stand-in for
// "minus infinity" under max.
inline long double sum0()
{
	const long double lowest=-MAXTIME;
	return lowest;
}

// Shift the first `d` entries of `result` by their arithmetic mean, then map
// each entry through trans().
//
// NOTE(review): the original computed the length as `sizeof(result)`. An array
// parameter is adjusted to a pointer, so that was the pointer size (typically 8),
// not the element count, and the loops over- or under-ran the real array.
// Callers should pass the real length. The default keeps the old one-argument
// call compiling and behaving as before, for source compatibility only.
// NOTE(review): trans() is currently log(). After subtracting the mean, the
// smallest entry is <= 0 in general, so the output contains -inf/NaN. Confirm the
// intended order of operations before relying on this helper.
void normalize(long double result[], int d = static_cast<int>(sizeof(long double *)))
{
	if (d<=0) return;  // nothing to normalise; also avoids dividing by zero
	long double mean=0;
	for (int i=0;i<d;i++)
		mean+=result[i];
	mean/=d;
	for (int i=0;i<d;i++)
		result[i]=trans(result[i]-mean);
}

void pushmessage(int x)
{
	int degree=Graph::GetDegree(x);
	message m;
	m.value.clear();
	for (int i=0;i<degree;i++)
	{
		Edge e=Graph::GetEdge(x,i);
		e.prob=trans(e.prob);
		//list<Edge>::iterator p=m.value.begin();
		//while (p!=m.value.end() && p->st<e.st) p++;
		//if (p==m.value.end() || p->st>e.st)
		//{
		m.value.push_back(e);
		//}
		//else p->prob+=prob;
	}
	Graph::pushSM(x,m);
}

// Prepare for pass():
//  - size the work queue for node ids 0..n;
//  - enqueue every node whose in-degree is 0;
//  - seed every node's stored message via pushmessage().
void passing::init()
{
	n=Graph::GetN();
	q=new int[n+1];
	tail=0;
	for (int node=0;node<=n;++node)
	{
		const bool isSource=(Graph::GetIDegree(node)==0);
		if (isSource)
			q[tail++]=node;
		pushmessage(node);
	}
}

// Collapse runs of adjacent entries in m.value that share the same start year
// (st) into the first entry of the run. Probabilities are combined with sum().
// Only adjacent entries are merged, so equal-st entries separated by others stay
// distinct. Presumably callers keep the list ordered by st; confirm at call sites.
// An empty list is left unchanged. The original advanced past end() in that case,
// which is undefined behaviour.
void merge(message &m)
{
	if (m.value.empty()) return;
	list<Edge>::iterator keep=m.value.begin();
	list<Edge>::iterator cur=keep;
	for (++cur;cur!=m.value.end();)
	{
		if (cur->st==keep->st) {
			keep->prob=sum(keep->prob,cur->prob);
			cur=m.value.erase(cur);  // erase returns the next element
		} else {
			keep=cur;  // start a new run
			++cur;
		}
	}
}

long double findadvisee(int x, int adv, int year)
{
	message m=Graph::GetSM(x);
	list<Edge>::iterator p=m.value.end();
	long double prob=sum0();
	//assert(m.value.begin()!=p);
	do
	{
	//rule 3
		p--;
		if (p->st>=year || p->v!=adv) prob=sum(prob,p->prob); 
		//else 
		//	if (p->v!=adv) prob=sum(prob,p->prob); 
	} while (p!=m.value.begin());
	return prob;
}

//double findadvisor(int x, int year)
//{
//	message m=Graph::GetSM(x);
//	list<Edge>::iterator p=m.value.end();
//	double prob=0;
//	do
//	{
//	//rule 3
//		p--;
//		if (p->st>year) prob+=p->prob;
//	} while (p!=m.value.begin());
//	return prob;
//}

void passing::pass()
{
	int diff=0;
	int head=0;
	int i,j,k;
	//first half: foward
	adv0=new int[n+1];
	for (i=0;i<=n;i++) adv0[i]=0;
	while (head<tail && q[head])
	{
		//assert(head<tail);
		k=q[head++];
		adv0[k]=1;
		message m=Graph::GetSM(k);
		//merge(m);
		Graph::pushSM(k,m);
		int d=Graph::GetDegree(k);
		for (i=0;i<d;i++)
		{
			Edge e=Graph::GetEdge(k,i);
			if (Graph::receiveMessage(e.v)) {
				q[tail++]=e.v;
				int id=Graph::GetIDegree(e.v);
				m=Graph::GetSM(e.v);
				list<Edge>::iterator p=m.value.begin();
				while (p!=m.value.end()){
					int gradyear=p->ed;
					long double prob=prod0();
					for (j=0;j<id;j++)
						prob=product(prob,findadvisee(Graph::GetIEdge(e.v,j).v,e.v,gradyear));
					p->prob=product(p->prob,prob);
					p++;
				}
				Graph::pushSM(e.v,m);
			}

		}
	}
	if (head==tail) head--;
	//assert(!q[head]);
	//printf("%d\n",tail);
	//for (i=1;i<=n;i++) if (!adv0[i]) printf("%d\n",i);
	//printf("OK\n");
	//second half backward
	message *rmo=Graph::GetRMO(0);
	rmo[0].value.clear();
	rmo[0].value.push_back(Edge(0,prod0(),MAXTIME,0));
	//rmo[0].value.push_back(Edge(0,1,MAXTIME,0));
	//Graph::pushRMO(rmo);
	adv=new int[n+1];
	while (head>=0)
	{
		k=q[head--];
		//printf("%d %d\n",head,k);
		rmo=Graph::GetRMO(k);
		int d=Graph::GetDegree(k);
		int id=Graph::GetIDegree(k);
		//if (!k) printf("%d\n",id);
		vector<long double> prob(d,prod0()),prob2(d,prod0());
		//result[k]=new long double[1];
		if (k) {
			list<Edge> le=Graph::GetSM(k).value;
			list<Edge>::iterator p1=le.begin();
			long double prob0=prod0();
			for (i=0;i<d;i++)
			{
				list<Edge>::iterator p=rmo[i].value.end();
				prob0=product(prob0,(--p)->prob);
			}
			for (i=0;i<d;i++) 
			{
				//if (rmo[i].value.begin()==rmo[i].value.end()) {
				//	printf("%d %d\n",k,Graph::GetEdge(k,i).v);
				//	exit(0);
				//}
				list<Edge>::iterator p=rmo[i].value.begin();
				//for (j=0;j<i;j++)
				//	prob[j]*=p->prob;
				//for (j=i+1;j<d;j++)
				//	prob[j]*=p->prob;
				//p++;
				prob[i]=product(prob0,p->prob);
				prob[i]/=d;
				/*prob[i]=product(prob[i],Graph::GetEdge(k,i).prob);
				int gradyear=Graph::GetEdge(k,i).ed;
				for (j=0;j<id;j++)
					prob[i]=product(prob[i],findadvisee(Graph::GetIEdge(k,j).v,k,gradyear));*/
				prob[i]=product(prob[i],p1->prob);
				p1++;
			}
		}
		//normalize(result[k]);
		//printf("after normalize %d\n",k);
		//assert(k||id==n);
		for (i=0;i<id;i++)
			//for each possible advisee
		{
			Edge e=Graph::GetIEdge(k,i);
			int v=e.v;
			e.prob=sum0();
			long double prob0=sum0();
			//for (j=0;j<d;j++)
			//{
			//	int gradyear=Graph::GetEdge(k,j).ed;
			//	double x=findadvisee(v,gradyear);
			//	//prob2[j]/=x;
			//	if (_isnan(prob2[j]/x)) continue;
			//		//prob2[j]/=x;
			//	e.prob=sum(e.prob,prob2[j]/x);
			//	
			//}
			rmo=Graph::GetRMO(v);
			
			int dv=Graph::GetDegree(v);
			for (j=0;j<dv;j++)
				//for the message from each possible advisor
				if (Graph::GetEdge(v,j).v==k) break;


			rmo[j].value.clear();
			//rmo[j].value.push_back(e);
				
			int styear=Graph::GetEdge(v,j).st;
			//e=Edge(k,0,e.st,e.ed);
			for (int j2=0;j2<d;j2++)
			{
				int gradyear=Graph::GetEdge(k,j2).ed;
				double x=findadvisee(v,k,gradyear);
				if (_isnan(iprod(prob[j2],x))) continue;
				prob0=sum(prob0,iprod(prob[j2],x));
				if (gradyear<=styear) 
				{
					//double x=findadvisee(v,k,gradyear);
					//if (_isnan(iprod(prob[j2],x))) continue;
					//e.prob-=prob2[j2]/x;
					e.prob=sum(e.prob,iprod(prob[j2],x));
				}
			}
			//if (_isnan(iprod(e.prob,prob0))) e.prob=sum0(); else
				e.prob=iprod(e.prob,prob0);
			rmo[j].value.push_back(e);
			//if (_isnan(prob0)) e.prob=sum0(); else 
				e.prob=prob0;
			rmo[j].value.push_back(e);
		}
		long double max=0;
		for (i=1;i<d;i++) if (Graph::GetEdge(k,i).prob>Graph::GetEdge(k,max).prob && Graph::GetEdge(k,i).st>=Graph::GetEdge(Graph::GetEdge(k,i).v,adv[Graph::GetEdge(k,i).v]).ed) max=i;
		adv0[k]=max;
		for (i=0;i<d;i++) Graph::updateProb(k,i,prob[i]);
		//for (i=0;i<d;i++) printf("%lg ",prob[i]);
		//printf("\n");
		max=0;
		for (i=1;i<d;i++) if (prob[i]>prob[max] && Graph::GetEdge(k,i).st>=Graph::GetEdge(Graph::GetEdge(k,i).v,adv[Graph::GetEdge(k,i).v]).ed) max=i;
		adv[k]=max;
		diff+=(adv0[k]!=adv[k]);
	}
	//getc(stdin);
	delete[] q;
	printf("diff:%d\n",diff);
}
// Snapshot every node's edge probabilities into `result` and return it.
// Rows are indexed by node id 1..n, and row i holds Graph::GetDegree(i) values
// in edge order. result[0] is a null pointer (node 0 gets no row).
// The table is newly allocated on every call; the caller presumably frees it.
// The original allocated only n row pointers but wrote result[n], which is out of bounds.
long double ** passing::collect()
{
	result=new long double*[n+1];
	result[0]=0;
	for (int i=1;i<=n;i++)
	{
		const int d=Graph::GetDegree(i);
		result[i]=new long double[d];
		for (int j=0;j<d;j++)
			result[i][j]=Graph::GetEdge(i,j).prob;
	}
	return result;
}

// Convert each node's edge scores (nodes 1..n) from the working domain back to
// linear values, relative to that node's best edge. For every node:
//  - compute best = sum() over its edge probabilities (starting from sum0());
//  - replace each probability p with itrans(iprod(p, best)).
// With the current definitions (sum = max, iprod = subtraction, itrans = exp),
// each node's best edge becomes 1.
// The original also computed a global maximum over all nodes first, then
// overwrote it per node before use; that dead O(E) pass has been removed.
void passing::collect0()
{
	for (int i=1;i<=n;i++)
	{
		const int d=Graph::GetDegree(i);
		long double best=sum0();
		for (int j=0;j<d;j++)
			best=sum(best,Graph::GetEdge(i,j).prob);
		for (int j=0;j<d;j++)
			Graph::updateProb(i,j,itrans(iprod(Graph::GetEdge(i,j).prob,best)));
	}
}

// Return node i's edge probabilities sorted in descending order, in a newly
// allocated array of Graph::GetDegree(i) elements. The caller presumably owns
// it and frees it with delete[].
// Fixes from the original:
//  - values are kept as long double (they were narrowed to double before sorting);
//  - an unused sum() accumulation was removed;
//  - the local no longer shadows the static member `result`.
// The insertion sort is kept rather than std::sort, so NaN values cannot break
// a strict-weak-ordering requirement.
long double *passing::collect(int i)
{
	const int d=Graph::GetDegree(i);
	long double *ranked=new long double[d];
	for (int j=0;j<d;j++)
	{
		const long double p=Graph::GetEdge(i,j).prob;
		// Find the first slot holding a value smaller than p, then shift the tail right.
		int k=0;
		while (k<j && ranked[k]>=p) k++;
		for (int l=j;l>k;l--) ranked[l]=ranked[l-1];
		ranked[k]=p;
	}
	return ranked;
}

// Per-node advisor choice from the most recent pass(), computed after that pass
// updated the edge probabilities. For each node id, it is the index of the chosen
// edge among that node's edges. The array is allocated by pass(), is still owned
// here, and is null before the first pass().
int* passing::advisor()
{
	int *const chosen=adv;
	return chosen;
}

// Per-node advisor choice from the most recent pass(), computed from the edge
// probabilities as they stood before that pass updated them. During pass()'s
// forward half the same array temporarily holds visited flags. It is allocated
// by pass() and is null before the first pass().
int* passing::advisor0()
{
	int *const chosenBefore=adv0;
	return chosenBefore;
}