//
//  matmbMath.cpp
//  fp
//
//  Created by Robert Delaney on 9/9/17.
//
//


#include <vector>

#include "fp.h"
#include "mb.h"
#include "mbMath.h"
#include "bf.h"
#include "matmb.hpp"
#include "matmbMath.hpp"
#include "matmbConv.hpp"
#include "matq.h"
#include "matqMath.h"

// Produce an empty, default-constructed matmb. The value-returning
// matmbInvert() hands this back when the inversion fails.
static matmb NULLMatmb()
{
    matmb   emptyMatrix;
    return emptyMatrix;
}/* NULLMatmb */

// Exchange rows i and j of x in place, column by column.
// No bounds checking: callers must pass valid row indices.
static void matmbSwapRows(matmb& x, INT32 i, INT32 j)
{
    for(INT32 col=0; col<x.nc; ++col)
    {
        mb  held = x.array[i][col];
        x.array[i][col] = x.array[j][col];
        x.array[j][col] = held;
    }
}/* matmbSwapRows */

// One Gauss-Jordan step at pivot position (ii, jj) of x, using only exact
// integer (mb) arithmetic.
//
// If x[ii][jj] is zero, the first row below ii with a non-zero entry in
// column jj is swapped into row ii and swapParity is negated. If no such row
// exists, false is returned and x is unchanged.
//
// Rows below ii are cleared in column jj by a Euclidean process:
//   1. Subtract (x[k][jj] / x[ii][jj]) times the pivot row from row k.
//   2. If a non-zero remainder is left in x[k][jj], swap that row into the
//      pivot position (negating swapParity) and repeat.
// mb division is integer division, so a single subtraction can leave a
// remainder. The loop is what guarantees a zero below the pivot, i.e. an
// echelon result. When the pivot divides x[k][jj] exactly, only one
// subtraction is done.
//
// Every operation is either a row swap or "row -= integer * other row".
// The determinant is therefore preserved up to the sign tracked in
// swapParity.
//
// Rows above ii get one subtraction of (x[k][jj] / x[ii][jj]) times the
// pivot row. Their column-jj entry is therefore reduced by the integer
// quotient, and cleared only when the pivot divides it.
//
// Returns false for an out-of-range (ii, jj) or when column jj has no
// non-zero entry at or below row ii; true otherwise.
static bool doGauss(matmb& x, INT32& swapParity, INT32 ii, INT32 jj)
{
    mb          factor;
    INT32       j, k;
    
    if(ii<0 || jj<0 || ii>=x.nr || jj>=x.nc)
        return false;
    
    // is pivot zero?
    if(!x.array[ii][jj])
    {
        // find the first non-zero element below the pivot in the pivot column
        for(k=ii+1;k<x.nr;++k)
        {
            if(x.array[k][jj]!=0)
                break;
        }
        
        if(k>=x.nr)
            return false;
        
        matmbSwapRows(x, ii, k);
        swapParity = -swapParity;
    }
    
    // clear column jj below the pivot
    for(k=ii+1;k<x.nr;++k)
    {
        while(x.array[k][jj]!=0)
        {
            factor = x.array[k][jj] / x.array[ii][jj];
            for(j=0;j<x.nc;++j)
                x.array[k][j] = x.array[k][j] - x.array[ii][j] * factor;
            
            // A remainder survived the integer quotient. It is smaller in
            // magnitude than the pivot, so make it the new pivot and keep
            // reducing (Euclid's algorithm on column jj).
            if(x.array[k][jj]!=0)
            {
                matmbSwapRows(x, ii, k);
                swapParity = -swapParity;
            }
        }
    }
    
    // reduce column jj above the pivot by the (final) pivot row
    for(k=0;k<ii;++k)
    {
        if(x.array[k][jj]!=0)
        {
            factor = x.array[k][jj] / x.array[ii][jj];
            for(j=0;j<x.nc;++j)
                x.array[k][j] = x.array[k][j] - x.array[ii][j] * factor;
        }
    }
    return true;
    
}/* doGauss */

static bool doReduce(matmb& x, INT32 ii, INT32 jj)
{
    mb              temp, temp1;
    INT32            i, j, k;
    bool            isZero;
    
    if(ii<0 || jj<0 || ii>=x.nr || jj>=x.nc)
        return false;
    
    k = 0; // to avoid incorrect warning
    // is pivot zero?
    if(x.array[ii][jj]==0)
    {
        // must swap rows
        // find non-zero element below pivot in pivot column
        isZero = true;
        for(i=ii+1;i<x.nr;++i)
        {
            if(x.array[i][jj]!=0)
            {
                isZero = false;
                k = i;
                break;
            }
        }
        
        if(isZero)
            return false;
        
        // have non-zero element in row k, so swap rows ii with k
        matmbSwapRows(x, ii, k);
    }
    // now we do the reduce operation
    // pivot element xt.array[ii][jj] is not zero
    // divide each element of row ii by xt.array[ii][jj] making xt.array[ii][jj] = 1
    //temp.Re = x[ii][jj].Re;
    //temp.Im = x[ii][jj].Im;
    temp = x.array[ii][jj];
    for(j=0;j<x.nc;++j)
    {
        //DivCC(x[ii][j], x[ii][j], temp);
        x.array[ii][j] = x.array[ii][j] / temp;
    }
    
    // now for every k not ii subtract row ii multiplied by xt[k][jj] from row k provided x[k][jj] != 0
    for(k=0;k<x.nr;++k)
    {
        if(k!=ii && x.array[k][jj]!=0)
        {
            //temp.Re = x[k][jj].Re;
            //temp.Im = x[k][jj].Im;
            temp = x.array[k][jj];
            for(j=0;j<x.nc;++j)
            {
                //MulCC(temp1, x[ii][j], temp);
                //SubCC(x[k][j], x[k][j], temp1);
                temp1 = x.array[ii][j] * temp;
                x.array[k][j] = x.array[k][j] - temp1;
            }
        }
    }
    return true;
    
}/* doReduce */


// Element-wise equality test: returns 1 when x and y have the same
// dimensions and every entry compares equal, otherwise returns 0.
INT32 matmbCompare(const matmb& x, const matmb& y)
{
    if(x.nr!=y.nr || x.nc!=y.nc)
        return 0;
    
    for(INT32 row=0; row<x.nr; ++row)
    {
        for(INT32 col=0; col<x.nc; ++col)
        {
            if(x.array[row][col]!=y.array[row][col])
                return 0;
        }
    }
    return 1;
}/* matmbCompare */


// z = x + y, element by element. If the dimensions differ, z is left
// untouched. The sum is built in a temporary, so z may be x or y.
void add(matmb& z, const matmb& x, const matmb& y)
{
    if(x.nr!=y.nr || x.nc!=y.nc)
        return;
    
    matmb   sum;
    init(sum, x.nr, x.nc);
    
    for(INT32 r=0; r<sum.nr; ++r)
    {
        for(INT32 c=0; c<sum.nc; ++c)
            sum.array[r][c] = x.array[r][c] + y.array[r][c];
    }
    
    z = sum;
}/* add */


// z = x - y, element by element. If the dimensions differ, z is left
// untouched. The difference is built in a temporary, so z may be x or y.
void sub(matmb& z, const matmb& x, const matmb& y)
{
    if(x.nr!=y.nr || x.nc!=y.nc)
        return;
    
    matmb   difference;
    init(difference, x.nr, x.nc);
    
    for(INT32 r=0; r<difference.nr; ++r)
    {
        for(INT32 c=0; c<difference.nc; ++c)
            difference.array[r][c] = x.array[r][c] - y.array[r][c];
    }
    
    z = difference;
}/* sub */


// Matrix product z = x * y. Requires x.nc == y.nr; otherwise z is left
// untouched. Computed into a temporary, so z may be x or y
// (power() relies on this with mul(zt, zt, zt)).
void mul(matmb& z, const matmb& x, const matmb& y)
{
    if(x.nc!=y.nr)
        return;
    
    matmb   product;
    init(product, x.nr, y.nc);
    
    for(INT32 r=0; r<product.nr; ++r)
    {
        for(INT32 c=0; c<product.nc; ++c)
        {
            mb  dot;
            dot = 0;
            for(INT32 m=0; m<x.nc; ++m)
                dot += x.array[r][m] * y.array[m][c];
            product.array[r][c] = dot;
        }
    }
    
    z = product;
}/* mul */


// z = y * x: every entry of x multiplied by the mb scalar y.
void mul(matmb& z, const matmb& x, const mb& y)
{
    matmb   scaled;
    init(scaled, x.nr, x.nc);
    
    for(INT32 r=0; r<scaled.nr; ++r)
    {
        for(INT32 c=0; c<scaled.nc; ++c)
            scaled.array[r][c] = y * x.array[r][c];
    }
    
    z = scaled;
}/* mul */


// z = y * x: every entry of x multiplied by the double scalar y
// (using the double * mb operator).
void mul(matmb& z, const matmb& x, double y)
{
    matmb   scaled;
    init(scaled, x.nr, x.nc);
    
    for(INT32 r=0; r<scaled.nr; ++r)
    {
        for(INT32 c=0; c<scaled.nc; ++c)
            scaled.array[r][c] = y * x.array[r][c];
    }
    
    z = scaled;
}/* mul */

// z = x / y: every entry of x divided by the mb scalar y using mb division.
// If y is zero, z is left untouched.
void div(matmb& z, const matmb& x, const mb& y)
{
    if(!y)
        return;
    
    matmb   quotient;
    init(quotient, x.nr, x.nc);
    
    for(INT32 r=0; r<quotient.nr; ++r)
    {
        for(INT32 c=0; c<quotient.nc; ++c)
            quotient.array[r][c] = x.array[r][c] / y;
    }
    
    z = quotient;
}/* div */

// Least common multiple of the denominators of every entry of the rational
// matrix x (1 for an empty matrix). Multiplying x by this value yields a
// matrix of integers.
mb lcdMatmb(const matq& x)
{
    mb  common;
    
    common = 1;
    for(INT32 r=0; r<x.nr; ++r)
    {
        for(INT32 c=0; c<x.nc; ++c)
            common = lcm(common, x.array[r][c].den);
    }
    return common;
}/* lcdMatmb */

// Integer matrix "division": z = x * (LCD * y^-1), so that x * y^-1 == z / LCD.
// y must be square. It is inverted exactly over the rationals with
// matqInvert. LCD receives the lcm of the inverse's denominators, which turns
// the scaled inverse into an integer matrix. If y is not square or
// matqInvert fails (presumably y singular), z and LCD are left untouched.
void div(matmb& z, const matmb& x, const matmb& y, mb& LCD)
{
    if(y.nr != y.nc)
        return;
    
    matq    yRational, yInverse;
    bf      determinant;
    
    yRational = y;
    if(!matqInvert(yInverse, determinant, yRational))
        return;
    
    LCD = lcdMatmb(yInverse);
    
    matmb   scaledInverse;
    init(scaledInverse, y.nr, y.nc);
    for(INT32 r=0; r<scaledInverse.nr; ++r)
    {
        for(INT32 c=0; c<scaledInverse.nc; ++c)
            scaledInverse.array[r][c] = (LCD * yInverse.array[r][c]).num;
    }
    
    z = x * scaledInverse;
}/* div */


// Make z the n x n identity matrix.
void matmbUnit(matmb& z, INT32 n)
{
    init(z, n, n);
    
    for(INT32 r=0; r<z.nr; ++r)
    {
        for(INT32 c=0; c<z.nc; ++c)
            z.array[r][c] = (r==c) ? 1 : 0;
    }
}/* matmbUnit */


// Value-returning form of matmbUnit(z, n): the n x n identity matrix.
matmb matmbUnit(INT32 n)
{
    matmb   identity;
    matmbUnit(identity, n);
    return identity;
}/* matmbUnit */


// z = transpose of x (x.nc rows by x.nr columns). Built in a temporary, so
// z may be the same object as x.
void transpose(matmb& z, const matmb& x)
{
    matmb   flipped;
    init(flipped, x.nc, x.nr);
    
    for(INT32 c=0; c<x.nc; ++c)
    {
        for(INT32 r=0; r<x.nr; ++r)
            flipped.array[c][r] = x.array[r][c];
    }
    
    z = flipped;
}/* transpose */


// Value-returning form of transpose(z, x).
matmb transpose(const matmb& x)
{
    matmb   result;
    transpose(result, x);
    return result;
}/* transpose */

// z = sum of the diagonal entries of x; returns true.
// If x is not square, z is set to 0 and false is returned.
bool trace(mb& z, const matmb& x)
{
    if(x.nr != x.nc)
    {
        z = 0;
        return false;
    }
    
    mb  diagonalSum;
    diagonalSum = 0;
    for(INT32 d=0; d<x.nr; ++d)
        diagonalSum += x.array[d][d];
    
    z = diagonalSum;
    return true;
}/* trace */


// Value-returning form of trace(z, x); yields 0 for a non-square x.
mb trace(const matmb& x)
{
    mb  result;
    if(!trace(result, x))
        result = 0;
    return result;
}/* trace */

// Integer representation of the inverse of x.
// x is inverted exactly over the rationals with matqInvert. On success:
//   LCD = lcm of the denominators of x^-1,
//   z   = LCD * x^-1 (an integer matrix), so x^-1 == z / LCD,
//   det = the `num` field of the determinant value matqInvert produced.
// Returns false, leaving z, det and LCD untouched, when x is not square or
// matqInvert fails.
bool matmbInvert(matmb& z, mb& det, const matmb& x, mb& LCD)
{
    if(x.nr != x.nc)
        return false;
    
    matq    rationalX, rationalInverse;
    bf      rationalDet;
    
    rationalX = x;
    if(!matqInvert(rationalInverse, rationalDet, rationalX))
        return false;
    
    LCD = lcdMatmb(rationalInverse);
    
    matmb   scaled;
    init(scaled, x.nr, x.nc);
    for(INT32 r=0; r<scaled.nr; ++r)
    {
        for(INT32 c=0; c<scaled.nc; ++c)
            scaled.array[r][c] = (LCD * rationalInverse.array[r][c]).num;
    }
    
    z = scaled;
    det = rationalDet.num;
    return true;
}/* matmbInvert */


// Value-returning form of matmbInvert(z, det, x, LCD); yields an empty
// matmb (NULLMatmb) on failure.
matmb matmbInvert(mb& det, const matmb& x, mb& LCD)
{
    matmb   inverse;
    if(matmbInvert(inverse, det, x, LCD))
        return inverse;
    return NULLMatmb();
}/* matmbInvert */

// z = [x ; y]: the rows of y appended below the rows of x.
// x and y must have the same number of columns; otherwise z is left untouched.
void matmbRowAugment(matmb& z, const matmb& x, const matmb& y)
{
    if(x.nc!=y.nc)
        return;
    
    matmb   stacked;
    init(stacked, x.nr+y.nr, x.nc);
    
    for(INT32 c=0; c<x.nc; ++c)
    {
        for(INT32 r=0; r<x.nr; ++r)
            stacked.array[r][c] = x.array[r][c];        // top block: x
        for(INT32 r=0; r<y.nr; ++r)
            stacked.array[x.nr+r][c] = y.array[r][c];   // bottom block: y
    }
    
    z = stacked;
}/* matmbRowAugment */


// Value-returning form of matmbRowAugment(z, x, y).
matmb matmbRowAugment(const matmb& x, const matmb& y)
{
    matmb   result;
    matmbRowAugment(result, x, y);
    return result;
}/* matmbRowAugment */


// z = [x | y]: the columns of y appended to the right of x.
// x and y must have the same number of rows; otherwise z is left untouched.
void matmbColAugment(matmb& z, const matmb& x, const matmb& y)
{
    if(x.nr!=y.nr)
        return;
    
    matmb   joined;
    init(joined, x.nr, x.nc+y.nc);
    
    for(INT32 r=0; r<x.nr; ++r)
    {
        for(INT32 c=0; c<x.nc; ++c)
            joined.array[r][c] = x.array[r][c];         // left block: x
        for(INT32 c=0; c<y.nc; ++c)
            joined.array[r][x.nc+c] = y.array[r][c];    // right block: y
    }
    
    z = joined;
}/* matmbColAugment */

// Value-returning form of matmbColAugment(z, x, y).
matmb matmbColAugment(const matmb& x, const matmb& y)
{
    matmb   result;
    matmbColAugment(result, x, y);
    return result;
}/* matmbColAugment */

// Put a copy of x into echelon form and store it in z; x is not modified.
// doGauss is applied at successive pivot positions. A column with no usable
// pivot is skipped, otherwise the position advances diagonally.
// swapParity starts at 1, and doGauss negates it on every row interchange.
void matmbGaussElim(matmb& z, INT32& swapParity, const matmb& x)
{
    matmb   work;
    INT32   pivotRow = 0;
    INT32   pivotCol = 0;
    
    swapParity = 1;
    work = x;
    
    while(pivotRow<work.nr && pivotCol<work.nc)
    {
        if(doGauss(work, swapParity, pivotRow, pivotCol))
            ++pivotRow;
        ++pivotCol;
    }
    
    z = work;
}/* matmbGaussElim */


// Parse xString into a matmb with matmbConvFromStr and put it in echelon form
// with matmbGaussElim(z, swapParity, x). Returns false, leaving z and
// swapParity untouched, when the string cannot be converted.
bool matmbGaussElim(matmb& z, INT32& swapParity, const char *xString)
{
    matmb		xt;
    
    if(!matmbConvFromStr(xt, xString))
        return false;
    
    matmbGaussElim(z, swapParity, xt);
    
    return true;
    
}/* matmbGaussElim */

// Original (misnamed) spelling of the string overload above. Despite the
// "matq" prefix it operates on matmb. Kept so existing callers still link;
// it forwards to matmbGaussElim.
bool matqGaussElim(matmb& z, INT32& swapParity, const char *xString)
{
    return matmbGaussElim(z, swapParity, xString);
    
}/* matqGaussElim */


// Forward elimination of x to row echelon form with entries reduced mod q
// (the mb % double operator); the result goes to z and x is unchanged.
//
// For each pivot row i, taken top to bottom:
//   - Pivot column ii is the first column, at or to the right of the previous
//     pivot column, that has a non-zero entry in rows i and below.
//   - The row with the largest abs() entry in that column is swapped into
//     row i.
//   - Every row k below i is updated fraction-free, for j > ii:
//         x[k][j] = (x[i][ii] * x[k][j] - x[k][ii] * x[i][j]) % q
//     and then x[k][ii] is set to 0.
// Row swaps are not tracked, and the current pivot row is not itself
// reduced mod q.
void matmbGaussElim(matmb& z, const matmb& x, double q)
{
    matmb	xt;
    INT32	i, j, k, i_max, ii;
    bool	inSearch;
    
    xt = x;
    
    // Fewer than two rows: nothing lies below a pivot. No columns: there is
    // no pivot column, and the search below would index column 0 out of range.
    if(xt.nr<2 || xt.nc<1)
    {
        z = xt;
        return;
    }
    
    ii = 0; // i -> pivot row; ii -> pivot column
    for(i=0;i<xt.nr-1;i++)
    {
        inSearch = true;
        while(inSearch)
        {
            // find the row at or below i with the largest entry in column ii
            i_max = i;
            for(k=i+1;k<xt.nr;k++)
            {
                if(abs(xt.array[k][ii])>abs(xt.array[i_max][ii]))
                    i_max = k;
            }
            
            if(!xt.array[i_max][ii].n)
            {
                // column ii is zero from row i down: skip to the next column
                ii++;
                if(ii>=xt.nc)
                {
                    z = xt;
                    return;
                }
            }
            else
                inSearch = false;
        }
        
        if(i!=i_max)
        {
            matmbSwapRows(xt, i, i_max);
        }
        
        // eliminate column ii in the rows below row i
        for(k=i+1;k<xt.nr;k++)
        {
            for(j=ii+1;j<xt.nc;j++)
            {
                xt.array[k][j] = xt.array[i][ii] * xt.array[k][j] - xt.array[k][ii] * xt.array[i][j];
                xt.array[k][j] = (xt.array[k][j] % q);
            }
            
            xt.array[k][ii] = 0;
        }
        ii++;
        if(ii>=xt.nc)
        {
            z = xt;
            return;
        }
    }
    
    z = xt;
    
}/* matmbGaussElim */

// Reduced echelon form of x (which must already be in echelon form), stored
// in z. doReduce is applied at successive pivot positions. A column with no
// usable pivot is skipped, otherwise the position advances diagonally.
void matmbReduce(matmb& z, const matmb& x)
{
    matmb   work;
    INT32   pivotRow = 0;
    INT32   pivotCol = 0;
    
    work = x;
    
    while(pivotRow<work.nr && pivotCol<work.nc)
    {
        if(doReduce(work, pivotRow, pivotCol))
            ++pivotRow;
        ++pivotCol;
    }
    
    z = work;
}/* matmbReduce */

// Integer null space of x.
//
// Construction:
//   - x is stacked above the x.nc x x.nc identity and the result is
//     transposed, giving [x^T | I];
//   - that matrix is put in echelon form (matmbGaussElim) and then reduced
//     (matmbReduce), and the result is returned in `reduced`;
//   - a row of `reduced` whose first x.nr entries are all zero carries a null
//     vector of x in its last x.nc entries.
//
// On return:
//   - numNV is the number of such rows;
//   - if numNV is 0, nullVectors is a 1 x 1 zero matrix;
//   - otherwise nullVectors is x.nc x numNV with one null vector per column,
//     each divided by the gcd of its entries (a single-entry vector is set
//     to 1).
void matmbNullSpace(INT32& numNV, matmb& nullVectors, matmb& reduced, const matmb& x)
{
    matmb		xt, yt, zt;
    INT32		i, j, k;
    INT32		swapParity;
    mb          temp;
    
    xt = x;
    
    matmbRowAugment(yt, xt, matmbUnit(xt.nc));
    matmbGaussElim(zt, swapParity, transpose(yt));
    matmbReduce(reduced, zt);
    
    // hasNV[i]: row i of reduced is zero in columns 0 .. xt.nr-1, i.e. it
    // carries a null vector (owned by the vector, so no leak or unchecked malloc)
    std::vector<bool>   hasNV(reduced.nr>0 ? (size_t)reduced.nr : 0, false);
    numNV = 0;
    for(i=0;i<reduced.nr;i++)
    {
        bool    isNull = true;
        for(j=0;j<xt.nr;j++)
        {
            if(reduced.array[i][j].n)
            {
                isNull = false;
                break;
            }
        }
        hasNV[i] = isNull;
        if(isNull)
            numNV++;
    }
    
    if(!numNV)
    {
        init(nullVectors, 1, 1);
        nullVectors.array[0][0] = 0;
        return;
    }
    
    // copy the null vectors into the columns of nullVectors; iterate over the
    // rows actually flagged so k never exceeds numNV and hasNV is never overrun
    init(nullVectors, xt.nc, numNV);
    k = 0;
    for(j=0;j<reduced.nr;j++)
    {
        if(hasNV[j])
        {
            for(i=0;i<nullVectors.nr;i++)
                nullVectors.array[i][k] = reduced.array[j][xt.nr+i];
            k++;
        }
    }
    
    // scale each null vector down by the gcd of its entries
    for(j=0;j<nullVectors.nc;j++)
    {
        if(nullVectors.nr==1)
            nullVectors.array[0][j] = 1;
        else
        {
            temp = gcd(nullVectors.array[0][j], nullVectors.array[1][j]);
            for(i=2;i<nullVectors.nr;i++)
                temp = gcd(temp, nullVectors.array[i][j]);
            // presumably gcd is 0 only for an all-zero column; never divide by zero
            if(temp!=0)
                for(i=0;i<nullVectors.nr;i++)
                    nullVectors.array[i][j] = nullVectors.array[i][j] / temp;
        }
    }
    
}/* matmbNullSpace */


// String front end for matmbNullSpace: parse xString with matmbConvFromStr
// and compute the null space of the result. Returns false, leaving all
// outputs untouched, when the string cannot be converted.
bool matmbNullSpace(INT32& numNV, matmb& nullVectors, matmb& reduced, const char *xString)
{
    matmb   parsed;
    
    const bool converted = matmbConvFromStr(parsed, xString);
    if(converted)
        matmbNullSpace(numNV, nullVectors, reduced, parsed);
    
    return converted;
}/* matmbNullSpace */

// Null space of x mod q.
//
// Construction:
//   - x is stacked above the x.nc x x.nc identity and the result is
//     transposed, giving [x^T | I];
//   - that matrix goes through the mod-q elimination
//     matmbGaussElim(zt, ., q) and then matmbReduce, and the result is
//     returned in `reduced`;
//   - a row of `reduced` whose first x.nr entries are all zero carries a null
//     vector in its last x.nc entries.
//
// On return:
//   - numNV is the number of such rows;
//   - if numNV is 0, nullVectors is a 1 x 1 zero matrix;
//   - otherwise nullVectors is x.nc x numNV with one null vector per column.
//     Every entry is reduced mod q, and each column is then divided by the
//     gcd of its entries when that gcd exceeds 1 (a single-entry vector is
//     set to 1).
void matmbNullSpace(INT32& numNV, matmb& nullVectors, matmb& reduced, const matmb& x, double q)
{
    matmb		xt, yt, zt;
    INT32		i, j, k;
    mb          temp;
    
    xt = x;
    
    matmbRowAugment(yt, xt, matmbUnit(xt.nc));
    matmbGaussElim(zt, transpose(yt), q);
    matmbReduce(reduced, zt);
    
    // hasNV[i]: row i of reduced is zero in columns 0 .. xt.nr-1, i.e. it
    // carries a null vector (owned by the vector, so no leak or unchecked malloc)
    std::vector<bool>   hasNV(reduced.nr>0 ? (size_t)reduced.nr : 0, false);
    numNV = 0;
    for(i=0;i<reduced.nr;i++)
    {
        bool    isNull = true;
        for(j=0;j<xt.nr;j++)
        {
            if(reduced.array[i][j].n)
            {
                isNull = false;
                break;
            }
        }
        hasNV[i] = isNull;
        if(isNull)
            numNV++;
    }
    
    if(!numNV)
    {
        init(nullVectors, 1, 1);
        nullVectors.array[0][0] = 0;
        return;
    }
    
    // copy the null vectors into the columns of nullVectors; iterate over the
    // rows actually flagged so k never exceeds numNV and hasNV is never overrun
    init(nullVectors, xt.nc, numNV);
    k = 0;
    for(j=0;j<reduced.nr;j++)
    {
        if(hasNV[j])
        {
            for(i=0;i<nullVectors.nr;i++)
                nullVectors.array[i][k] = reduced.array[j][xt.nr+i];
            k++;
        }
    }
    
    for(i=0;i<nullVectors.nr;i++)
        for(j=0;j<nullVectors.nc;j++)
            nullVectors.array[i][j] = (nullVectors.array[i][j] % q);
    
    // scale each null vector down by the gcd of its entries (only when > 1,
    // which also avoids dividing by a zero gcd after the mod-q reduction)
    for(j=0;j<nullVectors.nc;j++)
    {
        if(nullVectors.nr==1)
            nullVectors.array[0][j] = 1;
        else
        {
            temp = gcd(nullVectors.array[0][j], nullVectors.array[1][j]);
            for(i=2;i<nullVectors.nr;i++)
                temp = gcd(temp, nullVectors.array[i][j]);
            if(temp>1)
                for(i=0;i<nullVectors.nr;i++)
                    nullVectors.array[i][j] = nullVectors.array[i][j] / temp;
        }
    }
    
}/* matmbNullSpace */

// z = x with every entry replaced by (entry % q), for an mb modulus q.
void matmbMod(matmb& z, const matmb& x, const mb& q)
{
    z = x;
    
    for(INT32 r=0; r<x.nr; ++r)
    {
        for(INT32 c=0; c<x.nc; ++c)
            z.array[r][c] = z.array[r][c] % q;
    }
}/* matmbMod */

// z = x with every entry replaced by (entry % q), for a double modulus q.
void matmbMod(matmb& z, const matmb& x, double q)
{
    z = x;
    
    for(INT32 r=0; r<x.nr; ++r)
    {
        for(INT32 c=0; c<x.nc; ++c)
            z.array[r][c] = z.array[r][c] % q;
    }
}/* matmbMod */

// z = x^y by left-to-right binary exponentiation. x must be square and y >= 0;
// otherwise z = 0 and false is returned. If x is the identity or y == 0,
// z is the identity.
//
// The exponent bits are held in an unsigned word. Left-shifting a signed
// INT32 into its sign bit, as the earlier signed version did, is undefined
// behaviour before C++20. A compile-time mask also replaces the lazily
// initialised function-local static, which was not thread-safe.
bool power(matmb& z, const matmb& x, INT32 y)
{
    const UINT32    topBit = 0x80000000u;
    matmb           zt;
    INT32           i, n;
    UINT32          bits;
    
    if(x.nr!=x.nc || y<0)
    {
        z = 0;
        return false;
    }
    
    if(matmbCompare(x, matmbUnit(x.nr)) || y==0)
    {
        z = matmbUnit(x.nr);
        return true;
    }
    
    zt = matmbUnit(x.nr);
    n = NumBits(y);             // y > 0 here, so n >= 1 (assumed bit length)
    
    bits = ((UINT32)y) << (32-n);   // move the leading 1 bit to bit 31
    
    for(i=0;i<n;i++)
    {
        mul(zt, zt, zt); // zt^2
        if(bits & topBit)
            mul(zt, zt, x);
        bits <<= 1;
    }
    
    z = zt;
    
    return true;
    
}/* power */


// z = x^y mod q by left-to-right binary exponentiation. The running product is
// reduced mod q after every multiplication. x must be square and y must not
// be negative (y.n < 0 is rejected); otherwise z = 0 and false is returned.
// If x is the identity or y is zero, z is the identity.
bool powerMod(matmb& z, const matmb& x, const mb& y, const mb& q)
{
    if(x.nr!=x.nc || y.n<0)
    {
        z = 0;
        return false;
    }
    
    if(matmbCompare(x, matmbUnit(x.nr)) || y.n==0)
    {
        z = matmbUnit(x.nr);
        return true;
    }
    
    matmb   acc;
    mb      probe;
    INT32   bitCount = NumBits(y);
    
    acc = matmbUnit(x.nr);
    
    // probe starts at the most significant bit of y and walks down
    probe = 1;
    probe = (probe<<(bitCount-1));
    
    for(INT32 step=0; step<bitCount; ++step)
    {
        mul(acc, acc, acc);
        acc = acc % q;
        if((y & probe).n)
        {
            mul(acc, acc, x);
            acc = acc % q;
        }
        probe = (probe>>1);
    }
    
    z = acc;
    return true;
}/* powerMod */
