#include "stdafx.h"

#include "qlist.h"
#include "matrix.h"
#include "pin.h"
#include "circuit.h"
#include "gate.h"
#include "token.h"
#include "qexception.h"
#include "GateGroup.h"

// CCircuit::CCircuit
//
// Construct a circuit that holds no gates. The two strings are copied
// into buffers owned by the circuit; a NULL pointer for either one is
// treated as the empty string.
CCircuit::CCircuit( const char *szName, const char *szCircuit )
: m_iInputBits( 0 )
, m_iOutputBits( 0 )
, m_bOrdered( TRUE )                    // nothing to order yet
, m_szName( NULL )
, m_szCircuit( NULL )
{
    // With no gates there is no pin producing any bit, so clear every
    // "last output pin" slot
    int iBit = 0;
    while( iBit < MAX_INPUT_BITS ){
        m_pLastBit[ iBit ] = NULL;
        ++iBit;
    }

    // Missing strings fall back to ""
    const char *pszNameSrc    = szName    ? szName    : "";
    const char *pszCircuitSrc = szCircuit ? szCircuit : "";

    // Take private copies (+1 for the terminating NUL)
    m_szName = new char[ strlen( pszNameSrc ) + 1 ];
    strcpy( m_szName, pszNameSrc );

    m_szCircuit = new char[ strlen( pszCircuitSrc ) + 1 ];
    strcpy( m_szCircuit, pszCircuitSrc );
}

// CCircuit::~CCircuit
//
// Tear the circuit down: delete every gate held in m_listGates and
// release the two string buffers allocated by the constructor.
CCircuit::~CCircuit( )
{
    int iGate = 0;
    while( iGate < NumberOfGates() ){
        delete m_listGates[ iGate ];
        ++iGate;
    }

    // delete [] on a NULL pointer does nothing, so no guard is required
    delete [] m_szName;
    delete [] m_szCircuit;
}

// CCircuit::AddGate
//
// Add a gate to the circuit and return it. The circuit deletes every
// gate in its list when it is destroyed, so the gate must not belong to
// any other circuit.
//
// Adding a gate always marks the circuit as unordered.
//  - A source gate (no input pins) must have exactly one output pin. It
//    is given the next free bit number and its output pin becomes the
//    last output pin recorded for that bit.
//  - A sink gate (no output pins) must have exactly one input pin and
//    increases the output bit count.
CGate *CCircuit::AddGate( CGate *pGate )
{
    // A gate is required; everything below dereferences it
    ASSERT( pGate );

    // Check that the gate inputs or outputs something
    ASSERT( pGate->InputPins() > 0 || pGate->OutputPins() > 0 );

    // We are geared towards quantum circuit(s) where each gate
    // must have equal numbers of input and output pins, or must
    // be a source or a sink gate.
    ASSERT( pGate->InputPins() == 0 
          || pGate->OutputPins() == 0 
          || pGate->InputPins() == pGate->OutputPins() 
          );

    // Point the gate to this circuit. Add it to our list
    pGate->SetParentCircuit( this );
    m_listGates += pGate;

    // Assume that we have lost ordering
    m_bOrdered = FALSE;

    // If the gate is a source gate...
    if( pGate->InputPins() == 0 )
    {
        // We can cope with single bit sources only
        ASSERT( pGate->OutputPin( 0 ) != NULL && pGate->OutputPins() == 1 );

        // m_pLastBit is only initialised for MAX_INPUT_BITS entries, so
        // a further source would write outside that range
        ASSERT( m_iInputBits < MAX_INPUT_BITS );

        // Inform the output pin of its bit number. Set the pointer
        // to the last output pin to use that number
        pGate->OutputPin( 0 )->BitNumber() = m_iInputBits;
        m_pLastBit[ m_iInputBits ] = pGate->OutputPin( 0 );

        ++m_iInputBits;
    } else if( pGate->OutputPins() == 0 ) {
        // otherwise if we are a sink gate...

        // again, single bit sinks only.
        ASSERT( pGate->InputPins() == 1 );

        ++m_iOutputBits;
    }

    return pGate;
}

// CCircuit::OrderGates
//
// Order the gates in the circuit in to processing order
// That is, if a depends on the output of b then make sure that
// a occurs before b in m_listGates.
void CCircuit::OrderGates( )
{
    CQList<CGateGroup> listGates;
    CGateGroup gg( m_listGates );

    gg.ExpandGatesToSeparateGroups( listGates );

    m_listGates.Empty();

    for( int i = 0; i < listGates.Length(); i++ ){
        m_listGates += const_cast<CGate *>(listGates[i].Gate( 0 ));
    }
    /*

    CQList<CGateGroup> listGates;
    int iGates = NumberOfGates();

    for( int i = 0; i < iGates; i++ ){
        CGateGroup gg;

        gg.AddGate( Gate( i ) );
//        gg.Sort();

        listGates += gg;
    }

    listGates.PartialOrderSort();
    m_listGates.Empty();

    for( i = 0; i < iGates; i++ ){
        m_listGates += const_cast<CGate *>(listGates[i].Gate( 0 ));
    }


//    DBG_OUT( listGates );

*/
    m_bOrdered = TRUE;
}

// CCircuit::PrepareForProcessing
//
// Ensure that each gate in the circuit is in a ready state
// for processing.
void CCircuit::PrepareForProcessing()
{
    for( int i = 0; i < NumberOfGates(); i++ ){
        Gate( i )->PrepareForProcessing();
    }
}
 
// CCircuit::AllPinsConnected
//
// Sanity check used to detect invalid circuits. Returns FALSE as soon
// as any gate is found with an unconnected input or output pin, and
// TRUE when every pin of every gate reports IsConnected().
BOOL CCircuit::AllPinsConnected( ) const
{
    // Let's make sure that all the pins are connected.
    for( int i = 0; i < NumberOfGates(); i++ ){
        // Declared once for both loops below: relying on a variable
        // declared in one for-statement being visible in the next is not
        // valid under standard for-scope rules.
        int j;

        for( j = 0; j < m_listGates[i]->InputPins(); j++ )
            if( !m_listGates[i]->InputPin( j )->IsConnected() )
                return FALSE;

        for( j = 0; j < m_listGates[i]->OutputPins(); j++ )
            if( !m_listGates[i]->OutputPin( j )->IsConnected() )
                return FALSE;
    }

    return TRUE;
}

// CCircuit::NotifyConnect
//
// Called while connections are being made, to tell the circuit that an
// output pin has been connected to an input pin. The circuit uses this
// to keep m_pLastBit[bit] pointing at the pin that currently carries
// each bit, which is what allows gates to be wired up "bitwise" through
// Connect().
//
// The bit is taken from the input pin. If the gate owning the input pin
// has output pins, the output pin with the same pin number becomes the
// new last pin for that bit; if it is a sink, the bit is closed off.
// pOutputPin is not used here.
void CCircuit::NotifyConnect( COutputPin *pOutputPin, CInputPin *pInputPin )
{
    const int iBit = pInputPin->BitNumber();

    // The input pin belongs to a sink gate: nothing further can be
    // connected on this bit
    if( pInputPin->ParentGate()->OutputPins() == 0 ){
        m_pLastBit[ iBit ] = NULL;
        return;
    }

    // Otherwise the bit continues through the gate and leaves on the
    // output pin that has the same index as the input pin
    const int iPin = pInputPin->ParentGate()->GetPinNumber( pInputPin );
    m_pLastBit[ iBit ] = pInputPin->ParentGate()->OutputPin( iPin );
}

// CCircuit::Connect
//
// Connect a pin to the last output pin to output a certain bit number
void CCircuit::Connect( int iBitNumber, CInputPin *pInputPin )
{
    ASSERT( pInputPin );
    ASSERT( m_pLastBit[ iBitNumber ] );
    m_pLastBit[ iBitNumber ]->ConnectTo( pInputPin );
}

// CCircuit::FindGate
//
// Find the index of a gate in m_listGates
int CCircuit::FindGate( const CGate *pGate ) const
{
    return m_listGates.Find( const_cast<CGate *>(pGate) );
}

// CCircuit::LoadFrom
//
// Called to load in a circuit's details from an input stream.
// Circuit is in the form
//   ( size, initial, gate array )
//   where
//      size        - number of input bits to the circuit
//      initial     - Initial value on the input bits
//      gate array  - array of gates in the circuit
//                    in the form [ ( [bits], [matrix data] ) ]
//      bits        - array of bits that the gate operates on
//      matrix data - complex numbers for the matrix in the form
//                    real :+ imag

// We use two defines to help us catch errors. Both QTHROW an error if the
// type of the local variable 'token' is not CToken::a. EXPECT_TOKEN_GATE
// additionally names the gate (g) being processed in the message.
//
// Each body is wrapped in do { ... } while( 0 ) so that the macro behaves
// as a single statement: "EXPECT_TOKEN( X );" is then safe inside an
// unbraced if/else.
#define EXPECT_TOKEN( a ) do { if( token.GetTokenType() != CToken::a ){ QTHROW( "Expected input token " << CToken::GetTokenString((CToken::a)) << ", got " << token ); } } while( 0 )
#define EXPECT_TOKEN_GATE( a, g ) do { if( token.GetTokenType() != CToken::a ){ QTHROW( "Expected input token " << CToken::GetTokenString((CToken::a)) << " while processing gate " << g << ", got " << token ); } } while( 0 )

// Read a circuit description from 'is' and build the corresponding
// gates in this circuit: one static source per input bit, one matrix
// gate per gate entry, and one sink per bit. Any syntax error, an
// unsupported size, an invalid bit number or a non-unitary gate matrix
// is reported with QTHROW. 'os' is not used by this function.
void CCircuit::LoadFrom( istream &is, ostream &os )
{
    CToken token;                       // Next token to process
    int iSize, iInitial;                // size and initial values
    const int iMaxGateInputs = 20;      // maximum gate size
    int i;                              // loop index shared by the loops below

    is >> token;    // (
    EXPECT_TOKEN( TOKEN_OPEN_BRACKET );
 
    is >> token;    // Size of circuit
    EXPECT_TOKEN( TOKEN_INT );
    iSize = token.GetInt();

    // Every input bit needs a slot in m_pLastBit, which only holds
    // MAX_INPUT_BITS entries. Reject sizes that do not fit.
    if( iSize < 0 || iSize > MAX_INPUT_BITS - m_iInputBits ){
        QTHROW( "Circuit size " << iSize << " is not supported." );
    }

    is >> token;    // ,
    EXPECT_TOKEN( TOKEN_COMMA );

    is >> token;    // Initial Value
    EXPECT_TOKEN( TOKEN_INT );
    iInitial = token.GetInt();

    // Now need to create source gates for the circuit (these are
    // not specified in the input file but are implicit). 
    for( i = 0; i < iSize; i++ ){
        AddGate( new CStaticSource( "Static Source", iInitial & pow2(i) ) );
    }

    is >> token;    // ,
    EXPECT_TOKEN( TOKEN_COMMA );

    is >> token;    // [ 
    EXPECT_TOKEN( TOKEN_OPEN_SQUARE_BRACKET );

    int iGateNo = 0;
    // For each processing gate...
    do {
        char strGateName[ 30 ];         // Name for this gate
        int iBit[ iMaxGateInputs ];     // Bits that it works on
        int iGateInputs = 0;            // Number of inputs

        // Create a gate name in the form "Gate #xxx"
        sprintf( strGateName, "Gate #%i", ++iGateNo );

        is >> token;    // (
        EXPECT_TOKEN_GATE( TOKEN_OPEN_BRACKET, strGateName );

        is >> token;    // [ 
        EXPECT_TOKEN_GATE( TOKEN_OPEN_SQUARE_BRACKET, strGateName );

        // Read in the list of bits to operate on
        do {
            is >> token;
            // Each entry must be a bit number; anything else is an error
            // rather than something to skip over silently
            EXPECT_TOKEN_GATE( TOKEN_INT, strGateName );

            // iBit[] is a fixed size stack array: refuse to overrun it
            if( iGateInputs >= iMaxGateInputs ){
                QTHROW( strGateName << " operates on more than " << iMaxGateInputs << " bits." );
            }

            // The bit must be one that exists in this circuit and that
            // still has an output pin to connect from
            const int iBitNo = token.GetInt();
            if( iBitNo < 0 || iBitNo >= m_iInputBits || !m_pLastBit[ iBitNo ] ){
                QTHROW( strGateName << " refers to invalid bit " << iBitNo << "." );
            }

            iBit[ iGateInputs++ ] = iBitNo;
            is >> token;
        } while( token.GetTokenType() == CToken::TOKEN_COMMA );

        EXPECT_TOKEN_GATE( TOKEN_CLOSE_SQUARE_BRACKET, strGateName );

        is >> token;    // ,
        EXPECT_TOKEN_GATE( TOKEN_COMMA, strGateName );

        is >> token;    // [ 
        EXPECT_TOKEN_GATE( TOKEN_OPEN_SQUARE_BRACKET, strGateName );

        // The size of the matrix data should be 2**Number of input bits
        int iMatrixRowCols = pow2( iGateInputs );
        CComplexMatrix matrix( iMatrixRowCols, iMatrixRowCols );

        // Read in the matrix data
        for( i = 0; i < iMatrixRowCols; i++ ){
            for( int j = 0; j < iMatrixRowCols; j++ ){
                is >> token;
                EXPECT_TOKEN_GATE( TOKEN_COMPLEX, strGateName );
                matrix[i][j] = CComplex( token.GetReal(), token.GetImag() );
                // Every element except the very last is followed by a comma
                if( i != (iMatrixRowCols-1) || j != (iMatrixRowCols-1) ){
                    is >> token;
                    EXPECT_TOKEN_GATE( TOKEN_COMMA, strGateName );
                }
            }
        }

        is >> token;    // ] 
        EXPECT_TOKEN_GATE( TOKEN_CLOSE_SQUARE_BRACKET, strGateName );

        is >> token;    // )
        EXPECT_TOKEN_GATE( TOKEN_CLOSE_BRACKET, strGateName );

        // Gates must be reversible
        if( !matrix.IsUnitary() ){
            QTHROW( strGateName << " is not reversible." );
        }

        // Create the gate and add it to the circuit
        CGate *pGate = AddGate( new CMatrixGate( strGateName, iGateInputs, matrix ) );

        // Connect its input bits
        for( i = 0; i < iGateInputs; i++ )
            Connect( iBit[ i ], pGate->InputPin( i ) );

        is >> token;
    } while( token.GetTokenType() == CToken::TOKEN_COMMA );

    EXPECT_TOKEN( TOKEN_CLOSE_SQUARE_BRACKET );
    is >> token;
    EXPECT_TOKEN( TOKEN_CLOSE_BRACKET );

        
    // Create output gates. Again these are implicit
    for( i = 0; i < iSize; i++ ){
        CGate *pGate = AddGate( new CSinkGate( "Output" ) );
        Connect( i, pGate->InputPin( 0 ));
    }
}

// Write a human readable summary of a circuit to 'os': the input bit,
// output bit and gate counts, whether the circuit is currently ordered,
// and then every gate, one per line. Returns 'os'.
ostream& operator<<( ostream &os, CCircuit &Circuit )
{
    // Text form of the ordered flag
    const char *pszOrdered = "FALSE";
    if( Circuit.m_bOrdered )
        pszOrdered = "TRUE";

    os << "Input Bits: ";
    os << Circuit.NumberOfInputBits() << endl;

    os << "Output Bits: ";
    os << Circuit.NumberOfOutputBits() << endl;

    os << "Gates: ";
    os << Circuit.NumberOfGates() << endl;

    os << "Ordered: ";
    os << pszOrdered << endl;

    os << "Gate List: \n";

    int iGate = 0;
    while( iGate < Circuit.NumberOfGates() ){
        os << *Circuit.Gate( iGate ) << "\n";
        ++iGate;
    }

    return os;
}


// CCircuit::CreateTestCircuit
//
// Build one of the built-in test circuits and return it. The circuit is
// allocated with new and handed to the caller. 'os' is not currently
// used (the trace output is commented out).
//
//   iCircuit  - which test circuit to build, 0 .. NumberOfCircuitTests()-1
//   szName,
//   szCircuit - passed straight to the CCircuit constructor
//   iSize     - number of bits (source/sink pairs) in the circuit
//
// Circuits built here:
//   0 - per bit: source -> sqrtNOT -> sqrtNOT -> sink, wired pin to pin
//   1 - as 0, but wired by bit number through Connect()
//   2 - per pair of bits: two flip gates in series, wired pin to pin
//   3 - as 2, but wired by bit number through Connect()
//   4 - two flip gates on each adjacent pair (i, i+1), lowest pair first
//   5 - as 4, but highest pair first
//   6 - per pair of bits: two sqrtNOTs on each bit, two flip gates on
//       the pair, then two more sqrtNOTs on each bit
CCircuit *CCircuit::CreateTestCircuit( int iCircuit, const char *szName, const char *szCircuit, int iSize, ostream &os )
{
    ASSERT( iCircuit >= 0 && iCircuit < NumberOfCircuitTests() );

    // iSize is used as an array length below
    ASSERT( iSize >= 0 );

    CCircuit *pCircuit = new CCircuit( szName, szCircuit );

    ASSERT( pCircuit );

    // Scratch arrays that remember the gates while they are wired up.
    // They hold pointers only; the gates themselves belong to pCircuit.
    CGate **pSource = new CGate*[ iSize ];
    CGate **pInter1 = new CGate*[ iSize ];
    CGate **pInter2 = new CGate*[ iSize ];
    CGate **pInter3 = new CGate*[ iSize ];
    CGate **pInter4 = new CGate*[ iSize ];
    CGate **pInter5 = new CGate*[ iSize ];
    CGate **pInter6 = new CGate*[ iSize ];
    CGate **pSink   = new CGate*[ iSize ];

    // One loop index for the whole function. Declaring it once here
    // avoids depending on a for-declared variable outliving its loop,
    // which standard for-scope rules do not allow.
    int i;

//    os << "Creating test circuit " << iCircuit << ", size " << iSize << endl;

    if( iCircuit == 0 ){
        const int n = iSize;

        for( i = 0; i < n; i++ ){
            pSource[i] = pCircuit->AddGate( new CRandomisedSource( "Source X" ) );
            pInter1[i] = pCircuit->AddGate( new CSquareRootNOT( "sqrtNOT1 X" ) );
            pInter2[i] = pCircuit->AddGate( new CSquareRootNOT( "sqrtNOT2 X" ) );
            pSink[i] = pCircuit->AddGate( new CSinkGate( "Sink X" ) );
            pSource[i]->OutputPin(0)->ConnectTo( pInter1[i]->InputPin(0) );
            pInter1[i] ->OutputPin(0)->ConnectTo( pInter2[i]->InputPin(0) );
            pInter2[i] ->OutputPin(0)->ConnectTo( pSink[i] ->InputPin(0) );
        }
    } else if( iCircuit == 1 ){
        const int n = iSize;

        for( i = 0; i < n; i++ ){
            pSource[i] = pCircuit->AddGate( new CRandomisedSource( "Source X" ) );
            pInter1[i] = pCircuit->AddGate( new CSquareRootNOT( "sqrtNOT1 X" ) );
            pInter2[i] = pCircuit->AddGate( new CSquareRootNOT( "sqrtNOT2 X" ) );
            pSink[i] = pCircuit->AddGate( new CSinkGate( "Sink X" ) );
            pCircuit->Connect( i, pInter1[i]->InputPin(0) );
            pCircuit->Connect( i, pInter2[i]->InputPin(0) );
            pCircuit->Connect( i, pSink[i]->InputPin(0) );
        }
    } else if( iCircuit == 2 ){
        const int n = iSize / 2;        // number of bit pairs

        for( i = 0; i < 2*n; i++ ){
            pSource[i]   = pCircuit->AddGate( new CRandomisedSource( "Source X" ) );
            pSink[i] = pCircuit->AddGate( new CSinkGate( "Sink X" ) );
        }

        for( i = 0; i < n; i++ ){
            pInter1[i] = pCircuit->AddGate( new CFlipGate( "Flip1 Gate" ) );
            pInter2[i] = pCircuit->AddGate( new CFlipGate( "Flip2 Gate" ) );
        }

        for( i = 0; i < n; i++ ){
            pSource[2*i]  ->OutputPin(0)->ConnectTo( pInter1[i]->InputPin(0) );
            pSource[2*i+1]->OutputPin(0)->ConnectTo( pInter1[i]->InputPin(1) );

            pInter1[i] ->OutputPin(0)->ConnectTo( pInter2[i] ->InputPin(0) );
            pInter1[i] ->OutputPin(1)->ConnectTo( pInter2[i] ->InputPin(1) );

            pInter2[i] ->OutputPin(0)->ConnectTo( pSink[2*i] ->InputPin(0) );
            pInter2[i] ->OutputPin(1)->ConnectTo( pSink[2*i+1] ->InputPin(0) );

        }
    } else if (iCircuit == 3) {
        const int n = iSize / 2;        // number of bit pairs

        for( i = 0; i < 2*n; i++ ){
            pSource[i]   = pCircuit->AddGate( new CRandomisedSource( "Source X" ) );
            pSink[i] = pCircuit->AddGate( new CSinkGate( "Sink X" ) );
        }

        for( i = 0; i < n; i++ ){
            pInter1[i] = pCircuit->AddGate( new CFlipGate( "Flip1 Gate" ) );
            pInter2[i] = pCircuit->AddGate( new CFlipGate( "Flip2 Gate" ) );
        }

        for( i = 0; i < n; i++ ){

            pCircuit->Connect( 2*i, pInter1[i]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pInter1[i]->InputPin(1) );

            pCircuit->Connect( 2*i, pInter2[i]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pInter2[i]->InputPin(1) );

            pCircuit->Connect( 2*i, pSink[2*i]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pSink[2*i+1]->InputPin(0) );

        }
    } else if (iCircuit == 4) {
        for( i = 0; i < iSize; i++ ){
            pSource[i]   = pCircuit->AddGate( new CRandomisedSource( "Source X" ) );
            pSink[i] = pCircuit->AddGate( new CSinkGate( "Sink X" ) );
        }

        // One pair of flip gates per adjacent pair of bits
        for( i = 0; i < iSize-1; i++ ){
            pInter1[i] = pCircuit->AddGate( new CFlipGate( "Flip1 Gate" ) );
            pInter2[i] = pCircuit->AddGate( new CFlipGate( "Flip2 Gate" ) );
        }


        for( i = 0; i < iSize-1; i++ ){
            pCircuit->Connect( i, pInter1[i]->InputPin(0) );
            pCircuit->Connect( i+1, pInter1[i]->InputPin(1) );

            pCircuit->Connect( i, pInter2[i]->InputPin(0) );
            pCircuit->Connect( i+1, pInter2[i]->InputPin(1) );

        }

        for( i = 0; i < iSize; i++ )
            pCircuit->Connect( i, pSink[i]->InputPin(0) );

    } else if (iCircuit == 5) {
        for( i = 0; i < iSize; i++ ){
            pSource[i]   = pCircuit->AddGate( new CRandomisedSource( "Source X" ) );
            pSink[i] = pCircuit->AddGate( new CSinkGate( "Sink X" ) );
        }

        for( i = 0; i < iSize-1; i++ ){
            pInter1[i] = pCircuit->AddGate( new CFlipGate( "Flip1 Gate" ) );
            pInter2[i] = pCircuit->AddGate( new CFlipGate( "Flip2 Gate" ) );
        }

        // Wire the pairs from the highest bit downwards; gate index
        // iSize-i-1 therefore counts up from 0 as i counts down
        for( i = iSize-1; i > 0; i-- ){
            pCircuit->Connect( i, pInter1[iSize-i-1]->InputPin(1) );
            pCircuit->Connect( i-1, pInter1[iSize-i-1]->InputPin(0) );

            pCircuit->Connect( i, pInter2[iSize-i-1]->InputPin(1) );
            pCircuit->Connect( i-1, pInter2[iSize-i-1]->InputPin(0) );

        }

        for( i = 0; i < iSize; i++ )
            pCircuit->Connect( i, pSink[i]->InputPin(0) );


    } else if (iCircuit == 6) {
        const int n = iSize / 2;        // number of bit pairs

        for( i = 0; i < 2*n; i++ ){
//            pSource[i]   = pCircuit->AddGate( new CStaticSource( "Source X", !i ) );
            pSource[i]   = pCircuit->AddGate( new CRandomisedSource( "Source X" ) );
            pInter2[i] = pCircuit->AddGate( new CSquareRootNOT( "sqrtNOT1 X" ) );
            pInter3[i] = pCircuit->AddGate( new CSquareRootNOT( "sqrtNOT2 X" ) );
            pInter5[i] = pCircuit->AddGate( new CSquareRootNOT( "sqrtNOT3 X" ) );
            pInter6[i] = pCircuit->AddGate( new CSquareRootNOT( "sqrtNOT4 X" ) );
            pSink[i] = pCircuit->AddGate( new CSinkGate( "Sink X" ) );
        }

        for( i = 0; i < n; i++ ){
            pInter1[i] = pCircuit->AddGate( new CFlipGate( "Flip1 Gate" ) );
            pInter4[i] = pCircuit->AddGate( new CFlipGate( "Flip2 Gate" ) );
        }

        for( i = 0; i < n; i++ ){

            // not
            pCircuit->Connect( 2*i, pInter2[2*i]->InputPin(0) );
            pCircuit->Connect( 2*i, pInter2[2*i+1]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pInter3[2*i]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pInter3[2*i+1]->InputPin(0) );

            // flip
            pCircuit->Connect( 2*i, pInter1[i]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pInter1[i]->InputPin(1) );
            pCircuit->Connect( 2*i, pInter4[i]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pInter4[i]->InputPin(1) );

            // not
            pCircuit->Connect( 2*i, pInter5[2*i]->InputPin(0) );
            pCircuit->Connect( 2*i, pInter5[2*i+1]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pInter6[2*i]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pInter6[2*i+1]->InputPin(0) );

            pCircuit->Connect( 2*i, pSink[2*i]->InputPin(0) );
            pCircuit->Connect( 2*i+1, pSink[2*i+1]->InputPin(0) );
        }
    } 


//    os << endl;

    // Only the scratch arrays are released here, not the gates
    delete [] pSource;
    delete [] pInter1;
    delete [] pInter2;
    delete [] pInter3;
    delete [] pInter4;
    delete [] pInter5;
    delete [] pInter6;
    delete [] pSink;

    return pCircuit;
}

// CCircuit::SimpleTestCorrect
//
// Work out the output value expected from test circuit iCircuit (as
// built by CreateTestCircuit) for a given size and initial input value,
// and compare it with the output actually observed. Returns TRUE when
// they match; a mismatch also trips an ASSERT.
//
//   circuits 0, 1    - every one of the iSize bits is inverted
//   circuits 2, 3    - within each pair (i, i+1), i even, bit i is
//                      toggled when bit i+1 is set
//   circuit 6        - as 2/3, but applied to the inverted input and
//                      with the result inverted again (masked to iSize bits)
//   circuit 4        - bit i is toggled when bit i+1 is set, for
//                      i = 0 .. iSize-2 in ascending order
//   circuit 5        - the same rule, applied from the top pair downwards
BOOL CCircuit::SimpleTestCorrect( int iCircuit, int iSize, int iInitial, int iOutput )
{
    ASSERT( iCircuit >= 0 && iCircuit < NumberOfCircuitTests() );
    int iExpected = 0;

    switch( iCircuit ){
    case 0:
    case 1:
        iExpected = ~iInitial & (pow2( iSize ) -1);
        break;

    case 2:
    case 3:
    case 6:
        {
            const BOOL bInverted = (iCircuit == 6);

            iExpected = bInverted ? ~iInitial : iInitial;

            // Step through the pairs (0,1), (2,3), ...
            for( int iLow = 0; iLow < iSize-1; iLow += 2 ){
                const int iLowMask  = pow2( iLow );
                const int iHighMask = pow2( iLow+1 );

                if( iExpected & iHighMask ) 
                    iExpected ^= iLowMask;
            }

            if( bInverted ) 
                iExpected = ~iExpected & (pow2( iSize ) -1);
        }
        break;

    case 4:
        {
            iExpected = iInitial;

            // Overlapping pairs, lowest first
            for( int iLow = 0; iLow < iSize -1; ++iLow ){
                const int iLowMask  = pow2( iLow );
                const int iHighMask = pow2( iLow+1 );

                if( iExpected & iHighMask )
                    iExpected ^= iLowMask;
            }
        }
        break;

    case 5:
        {
            iExpected = iInitial;

            // Overlapping pairs, highest first
            for( int iHigh = iSize-1; iHigh > 0; --iHigh ){
                const int iLowMask  = pow2( iHigh-1 );
                const int iHighMask = pow2( iHigh );

                if( iExpected & iHighMask )
                    iExpected ^= iLowMask;
            }
        }
        break;

    default:
        ASSERT( !"Shouldn't have reached here!" );
        break;
    }

    ASSERT( iExpected == iOutput );
    return (iExpected == iOutput);
}

