// Circuit class implementation file

#include <stdafx.h>

#include "circuit.h"
#include "hilbert.h"
#include "gate.h"
#include "gates.h"

#include "resource.h"
#include "circuitmap.h"

#include <stdio.h>
#include <iostream.h>
#include <fstream.h>
#include <stdlib.h>
#include <limits.h>

// constructor
// Creates an empty circuit: no qubits, no operations, no storage, and the
// "loaded" flag cleared.
CCircuit::CCircuit()
{
	// no arrays are owned until loadcircuit allocates them
	qbitlist = NULL;
	oplist = NULL;

	// nothing has been read from a circuit file yet
	load_ok = false;

	// counters and input value all start from zero
	input_val = 0;
	n_ops = 0;
	n_qbs = 0;
}

// destructor
// Releases everything loadcircuit allocated: first the per-operation arrays
// of qubit indices hanging off qbitlist, then the two top-level arrays.
CCircuit::~CCircuit()
{
	TRACE("Circuit destructor start\n");

	// Each qbitlist entry is a separate heap array holding the indices of
	// the bits one gate operates on. Only walk the table when it exists, so
	// a non-zero n_ops combined with a missing table cannot dereference NULL.
	if(qbitlist != NULL)
	{
		for(int i = 0; i < n_ops; i++)
		{
			delete[] qbitlist[i];	// delete[] on NULL is a no-op
			qbitlist[i] = NULL;
		}
	}

	TRACE("Circuit destructor mid\n");

	// delete primary arrays
	delete[] oplist;
	delete[] qbitlist;

	// leave the members in a recognisably empty state
	oplist = NULL;
	qbitlist = NULL;
	n_ops = 0;

	TRACE("Circuit destructor end\n");
}

// accessor function
// Returns n_qbs, the qubit count; loadcircuit stores the value of the
// "numqbits" entry of the circuit file there.
int CCircuit::numbits()
{
	return n_qbs;
}
// Returns n_ops, the number of operations currently recorded for the
// circuit (zero until loadcircuit has processed a "numops" entry).
int CCircuit::numops()
{
	return n_ops;
}

// Returns the load_ok flag as an int. The flag starts out false and is set
// to true by loadcircuit once a circuit has been read without error.
int CCircuit::load_okay()
{
	return load_ok;
}

// functions
// Registers every keyword that may appear in a circuit file in circuitmap,
// pairing it with the integer code that loadcircuit switches on.
void CCircuit::setmap()
{
	// The position of a keyword in this table is the code stored for it.
	static const char* const keywords[] =
	{
		// ***** gate codes *****
		"butterfly",	// 0: butterfly gate
		"not",			// 1: not gate
		"2 bit cnot",	// 2: 2 bit controlled not gate
		"3 bit cnot",	// 3: 3 bit controlled not gate

		// ***** control codes *****
		"numqbits",		// 4: number of qubits
		"numops",		// 5: number of operations

		// ***** other controls *****
		"measure"		// 6: measurement
	};

	const int keyword_count = sizeof(keywords) / sizeof(keywords[0]);

	for(int index = 0; index < keyword_count; index++)
	{
		int code = index;
		circuitmap.SetAt(CString(keywords[index]), code);
	}
}

// Reads characters from fin with operator>> until a decimal digit turns up.
//
// Returns the value (0-9) of the first digit found, or -1 when
//   - the stream runs out or fails before a digit is read, or
//   - more than ten characters have been read without meeting a digit.
//
// NOTE(review): fin is received by value; whether the caller's stream
// position moves depends on the copy semantics of the stream class in
// use - confirm before relying on either behaviour.
int CCircuit::getnonedigits(ifstream fin)
{
	char character = '\0';
	int count = 0;

	for(;;)
	{
		// a failed extraction leaves character untouched, so stop here
		// instead of testing a stale or indeterminate value
		if(!(fin >> character))
			return -1;	// end of input or read error

		count++;

		// cast first: isdigit is only defined for unsigned char values and EOF
		if(isdigit((unsigned char)character))
			return (int)(character - '0');
		else if(count > 10)	// bit hit/miss, need to sort out
			return -1;	// error!
	}
}

// Extracts the next unsigned decimal number from a NUL-terminated string.
//
// string2 : text to scan; it is only read, never modified
// pos     : in  - index at which scanning starts
//           out - index just past the number on success, or the index of
//                 the terminating '\0' when no number was found
//
// Everything that is not a digit is skipped before the number starts.
// Returns the number, or -1 when no digit is found before the end of the
// string (or when an argument is unusable). A digit run too large for an
// int yields INT_MAX rather than overflowing.
int CCircuit::getnum(char* string2, int* pos)
{
	if(string2 == NULL || pos == NULL || *pos < 0)
		return -1;	// nothing sensible to scan

	// The text is read in place: copying it into a fixed-size buffer would
	// overrun that buffer for long inputs.

	// skip anything that is not a digit
	while(!isdigit((unsigned char)string2[*pos]))
	{
		if(string2[*pos] == '\0')
			return -1;	// no number!

		(*pos)++;
	}

	// accumulate the digits left to right using integer arithmetic only
	int num = 0;
	while(isdigit((unsigned char)string2[*pos]))
	{
		const int digit = string2[*pos] - '0';

		if(num > (INT_MAX - digit) / 10)
			num = INT_MAX;	// would overflow: saturate
		else
			num = num * 10 + digit;

		(*pos)++;
	}

	return num;
}


	
int CCircuit::loadcircuit(int mode, THREADPARAMS* ptp)
{
	setmap();

	// variables for selective super bit entry
	int sel_array[80];
	int nsbits = 0;		// number of superposition bits
	CString selection_string = ptp->circuit_input_select;
	int length = selection_string.GetLength();
	int stringpos = 0;
	char* charstring = NULL;

	// general variables
	int i;
	char line[80];
	int numbits;
	int startindex = 0;
	int code, num, linepos, opcount = 0;
	bool allgates = false;
	CString codestring;
	CString numstring;

	CString string("Loading circuit from file \"");
	string += ptp->circuit_file_name;
	string += "\"";
	ptp->res->activity_string = string;
	ptp->pWnd->SendMessage(WM_USER_THREAD_UPDATE_BOX);	

	n_ops = 0;		// no operations loaded yet

	if(mode == 1)	// selective super
	{	
		// form integer array of selection bits from the user input string

		linepos = 0;
		int length = selection_string.GetLength();
		if(length == 0)
			return 2;

		charstring = new char[length+1];

		for(int i = 0; i < length; i++)
		{
			charstring[i] = selection_string[i];
			TRACE("charstring: %c", charstring[i]);
		}

		charstring[length] = NULL;
		
		for(;;)
		{
			num = getnum(charstring, &linepos);
			if(num == -1)
			{
				TRACE("getnum = -1\n");
				break;	// end of line or bad data 
			}
			else
			{
				sel_array[nsbits] = num;
				nsbits++;	// update number of selection bits
			}
		}

		delete[] charstring;

		if(nsbits == 0)
			return 2;		// no selection bits set -> bit select error

	}

	for(;;)
	{
		linepos = 0;
		codestring = "";
		numstring = "";

		ptp->fin.getline(line, 80);

		if(line[0] == '!')	// eof mark (should use real eof but it doesn't work!)
		{
			break;	// finished
		}

		if(line[0] == '*')
		{
			TRACE("Start of data line\n");
	
			for(;;)	// get key word
			{
				linepos++;
				if(line[linepos] == '*')
					break;	// got codestring
				else
					codestring += line[linepos];
			}

			TRACE("code string: %s\n", codestring);
			
			num = getnum(line, &linepos);
			if(num == -1)
			{
				TRACE("getnum = -1\n");
				return 1;
			}

			TRACE("num: %d\n", num);
			if(!circuitmap.Lookup(codestring, code))
			{
				TRACE(codestring);
				TRACE(_T ("lookup error"));
				n_ops = opcount;
				return 1;		// keyword doesn't exist
			}
			else
			{
				switch(code)	// use num as appropriate for operation
				{
				case 4:
					n_qbs = num;
					break;
				
				case 5:
					if(mode == 2)	// super all mode, need n_qbs butterfly gates
						nsbits += n_qbs;

					n_ops = num + nsbits;	// total number of ops 
					oplist = new int[n_ops];
					qbitlist = new int*[n_ops];
					
					for(i = 0; i < n_ops; i++)	// in case incorrect data is found
						qbitlist[i] = NULL;			

					break;
				
				case 0: case 1: case 2: case 3: // gate operations

					numbits = num;
					
					if(oplist == NULL || qbitlist == NULL)
						return 1;	// oplist not yet created
					else
					{
						if(opcount >= (n_ops-nsbits))
						{
							TRACE("opcount >= n_ops - nsbits\n");

							return 3;	// return op error message
						}

						oplist[opcount + nsbits] = code;
								
						int* qbits = new int[numbits];

						for(int j = 0; j < numbits; j++)
						{
							num = getnum(line, &linepos);
							if(num == -1 || num > pow(2, n_qbs))
							{
								TRACE("getnum = -1\n");
								qbitlist[opcount + nsbits] = NULL;
								n_ops = opcount;
								return 1;
							}

							TRACE("num: %d\n", num);
							qbits[j] = num;
						}

						qbitlist[opcount+nsbits] = qbits;

						opcount++;
					}
					break;
				
				case 6:	// measurement
					
					numbits = num + 1; // to allow for number of measure bits
					
					if(oplist == NULL || qbitlist == NULL)
						return 1;	// oplist not yet created
					else
					{
						if(opcount >= (n_ops-nsbits))
						{
							TRACE("opcount >= n_ops - nsbits\n");

							return 3;	// return op error message
						}

						oplist[opcount+nsbits] = code;
								
						int* qbits = new int[numbits];

						qbits[0] = num;	// number of bits to measure
						for(int j = 1; j < numbits; j++)
						{
							num = getnum(line, &linepos);
							if(num == -1 || num > pow(2, n_qbs))
							{
								TRACE("getnum = -1\n");
								qbitlist[opcount+nsbits] = NULL;
								n_ops = opcount;
								return 1;
							}

							TRACE("num: %d\n", num);
							qbits[j] = num;
						}

						qbitlist[opcount+nsbits] = qbits;

						opcount++;
					}

					break;

				}
			}
		}
	}

	TRACE("The end! n_ops = %d, opcount = %d\n", n_ops, opcount);

	if((opcount + nsbits) != n_ops)
	{	
		return 3;	// missing op data
	}


	if(oplist == NULL || qbitlist == NULL)
		return 1;	// oplist not yet created


	TRACE("nsbits: %d\n", nsbits);

	// tag a bunch of butterfly gates on the front of the circuit
	// if required (for mode 1)

	if(mode == 1)
	{
		for(int k = 0; k < nsbits; k++)
		{
			oplist[k] = 0;
			int* qbits = new int[1];
			if((sel_array[k]) >= n_qbs)
			{
				// invalid select bit
				load_ok = false;
				n_ops = 0;
				return 2;
			}	

			qbits[0] = sel_array[k];
			qbitlist[k] = qbits;
			opcount++;
		}
	}

	if(mode == 2)
	{
		for(int k = 0; k < n_qbs; k++)
		{
			oplist[k] = 0;
			int* qbits = new int[1];

			qbits[0] = k;
			qbitlist[k] = qbits;
			opcount++;
		}
	}

	TRACE("opcount: %d\n", opcount);

	for(i = 0; i < opcount; i++)
	{
		TRACE("opindex: %d\n", oplist[i]);
		TRACE("first qbit: %d\n", qbitlist[i][0]);
	}

	// all okay if got here
	load_ok = true;

	return 0;
}

int CCircuit::solvecircuit(int mode, int val, THREADPARAMS* ptp, int fileout)
{
	ptp->res->activity_string = "Solving circuit";
	ptp->pWnd->SendMessage(WM_USER_THREAD_UPDATE_BOX);	

	// get activity progress object ptr
	CProgressCtrl *progress_ctrl_ptr = (CProgressCtrl*) ptp->pView->GetDlgItem(IDC_ACTIVITY_PROGRESS);
	progress_ctrl_ptr->SetPos(0);
	
	// initialise state vector for the circuit
	
	CHilbert* state_vec;

	if(mode == 0)
	{
		if(val >= pow(2, n_qbs))
			return 1;
		else
			state_vec = new CHilbert(n_qbs, val);
	}
	else if(mode == 1 || mode == 2)
		state_vec = new CHilbert(n_qbs, 0);
	
	if(fileout)
		ptp->fout << *state_vec << endl;

	int j;

	for(int i = 0; i < n_ops; i++)
	{
		float progress = ((float)i/n_ops)*100;
		progress_ctrl_ptr->SetPos((int)progress);	

		switch(oplist[i])
		{
		case 0:	{
					CButterfly* gate = new CButterfly();
					gate->applysub(state_vec, qbitlist[i]);
					if(fileout)
					{
						ptp->fout << "Applying " << gate->retname() << " gate to qbit(s): ";
						for(j = 0; j < gate->retn_gqbs(); j++)
							ptp->fout << qbitlist[i][j] << " ";
						ptp->fout << endl << endl;
					}
					delete gate;		
					break;
				}

		case 1: {
					CNot* gate = new CNot();
					gate->applysub(state_vec, qbitlist[i]);
					if(fileout)
					{
						ptp->fout << "Applying " << gate->retname() << " gate to qbit(s): ";
						for(j = 0; j < gate->retn_gqbs(); j++)
							ptp->fout << qbitlist[i][j] << " ";
						ptp->fout << endl << endl;
					}
					delete gate;		
					break;
				}

		case 2: {
					CCnot2* gate = new CCnot2();
					gate->applysub(state_vec, qbitlist[i]);
					if(fileout)
					{
						ptp->fout << "Applying " << gate->retname() << " gate to qbit(s): ";
						for(j = 0; j < gate->retn_gqbs(); j++)
							ptp->fout << qbitlist[i][j] << " ";
						ptp->fout << endl << endl;
					}
					delete gate;		
					break;
				}
		
		case 3: {
					
					CCnot3* gate = new CCnot3();
					gate->applysub(state_vec, qbitlist[i]);
					if(fileout)
					{
						ptp->fout << "Applying " << gate->retname() << " gate to qbit(s): ";
						for(j = 0; j < gate->retn_gqbs(); j++)
							ptp->fout << qbitlist[i][j] << " ";
						ptp->fout << endl << endl;
					}
					delete gate;		
					break;
				}

		case 6: {
					if(fileout)
					{
						ptp->fout << "Measurement qubits "; 
						for(j = 1; j < qbitlist[i][0]; j++)
							ptp->fout << qbitlist[i][j] << " ";
						ptp->fout << endl << endl;
					}

					state_vec->measure(qbitlist[i][0], &qbitlist[i][1], ptp);

					break;
				}


		default:{ 
					ptp->res->activity_string = "Operation doesn't exist!";
					ptp->pWnd->SendMessage(WM_USER_THREAD_UPDATE_BOX);	
					
					break;		
				}
		}

		if(fileout)
			ptp->fout << *state_vec << endl;
	}
	
	delete state_vec;

	//all okay
	return 0;
}
