// Matrix.h
//
// General maths stuff: complex numbers, dense and sparse complex vectors,
// and complex matrices.

#ifndef _MATRIX_H_
#define _MATRIX_H_

// Complex numbers
// Cartesian complex value Real + i*Imag. Only the complex product, the
// inequality test and complex += are defined inline in this header; every
// other operation declared below is implemented out of line.
class CComplex {
public:
    // Real and imaginary parts. They are public, so callers may read and
    // write them directly.
    double Real, Imag;

public:
    // Constructs Re + i*Im. Both parameters have defaults and the constructor
    // is not explicit, so a plain double converts implicitly to a CComplex
    // with a zero imaginary part.
    CComplex( const double Re = 0.0, const double Im = 0.0 )
        : Real( Re ), Imag( Im )
    {
    }

    // Declared only; presumably Real*Real + Imag*Imag (TODO confirm in the
    // implementation file).
    double MagSquared() const;

    // Complex product: (a + ib)(c + id) = (ac - bd) + i(ad + bc).
    friend CComplex operator*( const CComplex& c1, const CComplex &c2 )
    {
        return CComplex( c1.Real * c2.Real - c1.Imag * c2.Imag,
                         c1.Real * c2.Imag + c1.Imag * c2.Real );
    }

    // Complex +/- complex.
    friend CComplex operator+( const CComplex& c1, const CComplex &c2 );
    friend CComplex operator-( const CComplex& c1, const CComplex &c2 );

    // Complex (op) real, with the real operand on the right.
    friend CComplex operator*( const CComplex& c, const double &d );
    friend CComplex operator+( const CComplex& c, const double &d );
    friend CComplex operator-( const CComplex& c, const double &d );

    // Real (op) complex, with the real operand on the left.
    friend CComplex operator*( const double &d, const CComplex& c );
    friend CComplex operator+( const double &d, const CComplex& c );
    friend CComplex operator-( const double &d, const CComplex& c );

    // NOTE(review): these overloads return a CComplex, and their meaning is
    // defined out of line. Because && and || are overloaded, both operands
    // are always evaluated (no short-circuiting).
    friend CComplex operator&&( const CComplex &c1, const CComplex &c2 );
    friend CComplex operator||( const CComplex &c1, const CComplex &c2 );

    // Equality is defined out of line; inequality is its logical negation.
    friend BOOL operator==( const CComplex &c1, const CComplex &c2 );
    friend BOOL operator!=( const CComplex &c1, const CComplex &c2 )
    {
        const BOOL bEqual = ( c1 == c2 );
        return !bEqual;
    }

    // Assignment from a real value (double and int overloads).
    CComplex& operator=( const double &d );
    CComplex& operator=( int i );

    // Compound assignment with a real right-hand side.
    CComplex& operator*=( const double &d );
    CComplex& operator+=( const double &d );
    CComplex& operator-=( const double &d );

    // Compound assignment with a complex right-hand side.
    CComplex& operator*=( const CComplex &c );

    // Component-wise in-place addition.
    CComplex& operator+=( const CComplex &c )
    {
        Real += c.Real;
        Imag += c.Imag;
        return *this;
    }

    CComplex& operator-=( const CComplex &c );

    // Stream output (format defined in the implementation file).
    friend ostream& operator<<( ostream &os, const CComplex &c );
};

// Abstract complex-vector interface. The base class only tracks the logical
// length. Element storage and access (GetElement / SetElement / MemUsage)
// must be supplied by a derived class.
class CBaseComplexVector {
private:
    int m_iLength;   // logical element count, reported by Length()

public:
    CBaseComplexVector() : m_iLength( 0 ) { }
    virtual ~CBaseComplexVector() { }

    virtual void Reset();

    virtual void Zero();
    // Sets the logical length. bZero presumably requests zero-filled
    // elements (TODO confirm in the implementation file).
    virtual void SetLength( int iLength, BOOL bZero = TRUE );

    int Length() const { return m_iLength; }

    // Element access that every concrete vector must provide.
    virtual const CComplex& GetElement( int i ) const PURE;
    virtual void SetElement( int i, const CComplex &cv ) PURE;

    // In-place single-element updates; derived classes may override them.
    virtual void ScaleElement( int i, const CComplex &cv );
    virtual void AddToElement( int i, const CComplex &cv );

    // Whole-vector operations against another vector of any concrete kind.
    virtual CComplex DotProduct( const CBaseComplexVector &cv1 ) const;
    virtual void Add( const CBaseComplexVector &cv1 );

    virtual CComplex MagSquared() const;

    // Read-only indexing, dispatched through the virtual GetElement().
    const CComplex& operator[]( const int n ) const { return GetElement( n ); }

    // Storage footprint as reported by the derived class.
    virtual int MemUsage() const PURE;
    // Hook for releasing spare memory; the base version does nothing.
    virtual void ShrinkMem() { }

    // NOTE(review): the parameter is a non-const reference, so a const vector
    // cannot be the source. Because this is a user-declared copy assignment,
    // the implicit const& version is not generated either. Changing it
    // requires updating the out-of-line definition too.
    CBaseComplexVector& operator=( CBaseComplexVector &cv );
};

// Dense complex vector: elements are stored contiguously in the array
// m_pComplex (presumably heap-allocated and owned by this object, given the
// user-declared copy constructor, assignment and destructor).
class CComplexVector : public CBaseComplexVector {
    CComplex *m_pComplex;   // contiguous element storage
    int m_iRealLength;      // slot count used by MemUsage(); presumably the
                            // allocated capacity, which may exceed Length()

public:
    CComplexVector();
    // NOTE(review): this constructor is not explicit. An int, or a double via
    // floating-integral conversion, therefore converts implicitly to a vector
    // of that length. For example `cv * 2` and `cv * 2.0` are ambiguous
    // between the vector*vector and vector*CComplex operators below, because
    // both need a user-defined conversion. Consider making it explicit once
    // callers have been audited.
    CComplexVector( int iLength, BOOL bZero = TRUE );
    CComplexVector( const CComplexVector &cv );
    ~CComplexVector();

public:
    virtual void Reset();
protected:
    void Copy( const CComplexVector &cv );

public:
    CComplexVector& operator=( const CComplexVector &cv );
    CComplexVector& operator=( const CBaseComplexVector &cv );

    virtual void SetLength( int iLength, BOOL bZero = TRUE );
    virtual const CComplex& GetElement( int i ) const;
    virtual void SetElement( int i, const CComplex &cv );

    // Direct element access; the index is range-checked against Length()
    // by ASSERT (debug builds only).
    CComplex& operator[]( const int n )
    {
        ASSERT( n >= 0 && n < Length() );
        return m_pComplex[ n ];
    }
    const CComplex& operator[]( const int n ) const
    {
        ASSERT( n >= 0 && n < Length() );
        return m_pComplex[ n ];
    }

    // vector * vector yields a single CComplex (presumably a dot product);
    // + and - are presumably element-wise.
    friend CComplex operator*( const CComplexVector &cv1, const CComplexVector &cv2 );
    friend CComplexVector operator+( const CComplexVector &cv1, const CComplexVector &cv2 );
    friend CComplexVector operator-( const CComplexVector &cv1, const CComplexVector &cv2 );

    // Scaling by a complex factor, on either side.
    friend CComplexVector operator*( const CComplexVector &cv, const CComplex &c );
    friend CComplexVector operator*( const CComplex &c, const CComplexVector &cv );

    CComplexVector& operator*=( const CComplex &c );
    CComplexVector& operator+=( const CComplexVector &cv );
    CComplexVector& operator-=( const CComplexVector &cv );

    friend ostream& operator<<( ostream &os, const CComplexVector &cv );

    // Bytes of element storage: m_iRealLength slots of CComplex.
    // `sizeof` on a type name must be parenthesised in standard C++; the
    // previous `sizeof CComplex` only compiled as a compiler extension.
    // The size_t product is narrowed explicitly to the int return type.
    virtual int MemUsage() const
    {
        return static_cast<int>( m_iRealLength * sizeof( CComplex ) );
    }
};

class CSparseComplexVector : public CBaseComplexVector {
    CComplexVector **m_ppComplexVector;
    const CComplex m_complexZero;
    BOOL m_bZero;

    enum {ELEMENTS_PER_SECTION = 128};
public:
    CSparseComplexVector();
    CSparseComplexVector( const CSparseComplexVector &scv ){ ASSERT( FALSE ); }
    ~CSparseComplexVector();
    CSparseComplexVector& operator=( const CSparseComplexVector &cv );
    CSparseComplexVector& operator=( const CBaseComplexVector &cv );

protected:
    int Section( int n ) const { return n / ELEMENTS_PER_SECTION; }
    int Element( int n ) const { return n % ELEMENTS_PER_SECTION; }
    void LoadSection( int i );
    CComplex &Value( int n ) const { return (*m_ppComplexVector[ Section(n) ])[ Element(n) ]; }
    int Sections( ) const { return Length() == 0 ? 0 : 1 + Section(Length()-1); }
    
    void Copy( const CSparseComplexVector &cv );
public:
    virtual void Reset();
    virtual void Zero( );
    virtual void ScaleElement( int i, const CComplex &cv );
    virtual void AddToElement( int i, const CComplex &cv );

    virtual void SetLength( int iLength, BOOL bZero = TRUE );

    virtual void SetElement( int i, const CComplex &cv );
    virtual const CComplex& GetElement( int i ) const { return (*this)[i]; }

    const CComplex& operator[]( const int n ) const
    {
        ASSERT( m_ppComplexVector );
        ASSERT( n >= 0 && n < Length() );

        if( m_ppComplexVector[Section(n)] )
            return Value(n);

        return m_complexZero;
    }

    virtual int MemUsage() const;
    virtual void ShrinkMem();
};

// Dense complex matrix stored as an array of m_iRows row vectors
// (CComplexVector), presumably each of length m_iColumns.
class CComplexMatrix {
    int m_iColumns;
    int m_iRows;

    CComplexVector *m_pRow;   // m_iRows row vectors, indexed by operator[]
public:
    CComplexMatrix();
    CComplexMatrix( int iRows, int iColumns, BOOL bZero = TRUE );
    CComplexMatrix( const CComplexMatrix &cm );
    ~CComplexMatrix();

public:
    void Reset();
protected:
    void Copy( const CComplexMatrix &cm );

public:
    void Zero();
    // NOTE(review): the parameters here are named (columns, rows), but the
    // constructor takes (rows, columns). Check the out-of-line definition
    // for the order SetSize really expects before relying on it.
    void SetSize( int iColumns, int iRows, BOOL bZero = TRUE );

    CComplexMatrix& operator=( const CComplexMatrix &cm );

    int Rows() const { return m_iRows; }
    int Columns() const { return m_iColumns; }

    // Presumably stores (*this) * cv into ret (TODO confirm).
    void Multiply( CComplexVector& ret, const CComplexVector &cv ) const;

    // Row access; the row index is range-checked by ASSERT (debug builds only).
    CComplexVector& operator[]( int iRow )
    {
        ASSERT( iRow >= 0 && iRow < m_iRows );
        return m_pRow[ iRow ];
    }

    const CComplexVector& operator[]( int iRow ) const
    {
        ASSERT( iRow >= 0 && iRow < m_iRows );
        return m_pRow[ iRow ];
    }

    // Mutable access to the element at (iRow, iColumn). Both indices are
    // ASSERT-checked by the row and element operator[].
    CComplex& Element( int iRow, int iColumn ){
        return (*this)[iRow][iColumn];
    }

    // Read-only counterpart for const matrices, mirroring the const operator[].
    const CComplex& Element( int iRow, int iColumn ) const
    {
        return (*this)[iRow][iColumn];
    }

    // Copies of a single row or column.
    CComplexVector Row( int iRow ) const;
    CComplexVector Column( int iColumn ) const;

    BOOL IsUnitary( ) const;

    friend CComplexMatrix operator-( const CComplexMatrix &cm1, const CComplexMatrix &cm2 );
    friend CComplexMatrix operator+( const CComplexMatrix &cm1, const CComplexMatrix &cm2 );

    friend CComplexMatrix operator*( const CComplexMatrix &cm1, const CComplexMatrix &cm2 );
    friend CComplexVector operator*( const CComplexMatrix &cm, const CComplexVector &cv );

    CComplexMatrix& operator+=( const CComplexMatrix &cm );
    CComplexMatrix& operator-=( const CComplexMatrix &cm );

    // not tested
    CComplexMatrix& operator*=( const CComplexMatrix &cm );

    // Scaling by a complex factor, on either side.
    friend CComplexMatrix operator*( const CComplexMatrix &cm, const CComplex &c );
    friend CComplexMatrix operator*( const CComplex &c, const CComplexMatrix &cm );

    CComplexMatrix& operator*=( const CComplex &c );

    friend ostream& operator<<( ostream &os, const CComplexMatrix &cm );
};

#endif