//
//  mbMath.cpp
//  fp64
//
//  Created by Bob Delaney on 12/4/18.
//  Copyright © 2018 Bob Delaney. All rights reserved.
//


#include "mbMath.hpp"
#include "fpConv.hpp"
#include "fpMath.hpp"
#include "fpFuncs.hpp"
#include <stdio.h>
#include <string.h>

extern long blockPrec, decPrec;

// Debug helper: writes the sign flag of x, then every block of x in
// hexadecimal (block 0 first, one per line), and finally puts cout back
// into decimal mode.
static void coutHex(const mb& x)
{
    cout << "sign = " << x.s << endl;

    long    blk = 0;
    while(blk<x.n)
    {
        cout << hex << x.b[blk] << endl;
        ++blk;
    }

    cout << dec << endl;

}/* coutHex(mb) */

static void cout128(const UINT128 x)
{
    UINT64      hi, lo;
    char        hexBlock[17], temp[17];
    long        i, length;
    
    hi = x>>64;
    lo = x;
    printf("%llx", hi);
    sprintf(temp, "%llx", lo);
    length = strlen(temp);
    hexBlock[0] = 0;
    for(i=0;i<16-length;++i)
        strcat(hexBlock, "0");
    strcat(hexBlock, temp);
    cout << hexBlock << endl;
    
}/* cout128 */

// Counts the consecutive zero bits at the low-order end of x.
// Returns blockBits when x contains no set bit.
long numLowOrderZeroBits(UINT64 x)
{
    long    zeros;

    for(zeros=0;zeros<blockBits;++zeros)
    {
        if(x & 1)
            break;      // reached the lowest set bit
        x >>= 1;
    }

    return zeros;

}/* numLowOrderZeroBits */


// approximate log to base 2 of x, x can be negative; it is treated as positive
long Lg2(const mb& x)
{
    mb                  xt;
    static UINT64       mask;
    static bool         initGood=false;
    long                i, xPowerOfTwo, numZeroBits;
    UINT64              cs;
    
    if(!initGood)
    {
        mask = 1;
        mask = (mask<<(blockBits-1)); // 0x8000000000000000
        initGood = true;
    }
    
    xt = x;
    
    xPowerOfTwo = blockBits*(xt.n-1);
    
    // look at high blocks
    numZeroBits = 0;
    cs = xt.b[xt.n-1];
    for(i=0;i<blockBits;i++)
    {
        if(!(cs & mask))
        {
            numZeroBits++;
            cs = (cs<<1);
        }
        else
            break;
    }
    
    return xPowerOfTwo + blockBits - numZeroBits - 1;
    
}/* Lg2 */

// Shifts x left (in place) by the number of zero bits found above its highest
// set bit, scanning down from the top block, and returns that bit count.
// Whole zero blocks at the top each contribute blockBits to the count.
// x is assumed normalized.
long alignLeft(mb& x)
{
    const UINT64            topBit = ((UINT64)1)<<(blockBits-1);
    mb                      xt;
    UINT64                  cur;
    long                    shift, blk;

    xt = x;
    shift = 0;

    for(blk=xt.n-1;blk>=0;blk--)
    {
        cur = xt.b[blk];
        if(!cur)
        {
            // an all-zero block: count it completely and look at the next one down
            shift = shift + blockBits;
            continue;
        }
        // first non-zero block: count its leading zero bits and stop
        while(!(cur & topBit))
        {
            shift++;
            cur = (cur<<1);
        }
        break;
    }

    mbShiftLeft(xt, xt, shift);
    x = xt;

    return shift;

}/* alignLeft */

// Returns a copy of x whose sign flag is set (s == true means non-negative
// throughout this file).
mb abs(const mb& x)
{
    mb      magnitude;

    magnitude = x;
    if(magnitude.s)
        return magnitude;

    // negative input: flip the sign flag
    magnitude.s = !magnitude.s;
    return magnitude;

}/* abs */

// Compares magnitudes only; the sign flags are ignored.  Returns
//   +1 if |x| > |y|
//    0 if |x| = |y|
//   -1 if |x| < |y|
// Operands with different block counts are ordered by block count alone, so
// both are expected to be normalized.
static long compareAbs(const mb& x, const mb& y)
{
    long          k;

    // an operand with no blocks at all ranks below any operand that has blocks
    if(x.n==0 || y.n==0)
    {
        if(x.n==0 && y.n==0)
            return 0;
        return (y.n==0) ? 1 : -1;
    }

    if(x.n!=y.n)
        return (x.n>y.n) ? 1 : -1;

    // same length: the first differing block from the top decides
    for(k=x.n-1;k>=0;k--)
    {
        if(x.b[k]!=y.b[k])
            return (x.b[k]>y.b[k]) ? 1 : -1;
    }

    return 0;

}/* compareAbs(mb, mb) */

// z = |x| + |y|.  The sign flags of x and y are ignored and the result is
// marked positive; callers set the final sign.  The sum is built in a
// temporary, so z may be the same object as x or y.
static void mbAddAbs(mb& z, const mb& x, const mb& y)
{
    mb           zt;
    long         i, minNumBlocks;
    UINT64       carry;
    UINT128      sLL;

    // the operand with more blocks supplies everything above the common part
    const mb&    longer = (x.n>=y.n) ? x : y;

    minNumBlocks = min(x.n, y.n);

    // One block more than the longer operand in case the top addition carries.
    // init() is the allocation helper every other routine in this file uses
    // for its temporaries, instead of a raw, unchecked malloc here.
    init(zt, 1 + max(x.n, y.n));
    zt.s = true;

    // blocks present in both operands
    carry = 0;
    for(i=0;i<minNumBlocks;i++)
    {
        sLL = x.b[i];
        sLL = sLL + y.b[i];
        sLL = sLL + carry;
        zt.b[i] = sLL;              // keeps the low 64 bits
        carry = sLL>>blockBits;     // 0 or 1
    }

    // remaining blocks of the longer operand: only the carry is added
    for(i=minNumBlocks;i<longer.n;i++)
    {
        sLL = longer.b[i];
        sLL = sLL + carry;
        zt.b[i] = sLL;
        carry = sLL>>blockBits;
    }

    // use the spare top block only if the addition carried out
    if(carry)
        zt.b[zt.n-1] = carry;
    else
        zt.n--;

    z = zt;

}/* mbAddAbs */

// Difference of magnitudes (sign flags of x and y are ignored):
//   |x| = |y| : z = 0
//   |x| > |y| : z = |x|-|y|, z.s left as init() set it
//   |x| < |y| : z = |y|-|x|, z.s = false as a signal to the caller that the
//               true difference is negative
// The result is normalized.  It is built in a temporary before z is written.
static void mbSubAbs(mb& z, const mb& x, const mb& y)
{
    mb          diff;
    UINT64      borrow;
    UINT128     wide;
    long        k, order;

    order = compareAbs(x, y);
    if(order==0)
    {
        z = 0;
        return;
    }

    // always subtract the smaller magnitude from the larger one
    const mb&   hiOp = (order==1) ? x : y;
    const mb&   loOp = (order==1) ? y : x;

    init(diff, hiOp.n);
    borrow = 0;
    for(k=0;k<hiOp.n;++k)
    {
        wide = hiOp.b[k];
        if(k<loOp.n)
            wide = wide - loOp.b[k];
        wide = wide - borrow;
        diff.b[k] = wide;
        // a wrapped 128-bit subtraction leaves bits set above the low block
        borrow = (wide>>blockBits) ? 1 : 0;
    }

    if(order!=1)
        diff.s = false;
    mbNormalize(diff);
    z = diff;

}/* mbSubAbs */

// Drops zero blocks from the high end of x by lowering its block count.
// The count never goes below 1, so a zero value keeps a single block.
void mbNormalize(mb& x)
{
    mb      trimmed;
    long    used;

    trimmed = x;

    used = trimmed.n;
    while(used>0 && trimmed.b[used-1]==0)
        --used;

    trimmed.n = (used==0) ? 1 : used;
    x = trimmed;

}/* mbNormalize */

// Full 64x64 -> 128 bit multiply: hi receives the upper 64 bits of x*y and
// lo the lower 64 bits.
static void mul64(UINT64& hi, UINT64& lo, UINT64 x, UINT64 y)
{
    // The product of two 64-bit values can exceed 2^127, which overflows a
    // signed __int128 (undefined behaviour); the unsigned type holds every
    // possible product exactly.
    unsigned __int128   product = (unsigned __int128)x * y;

    lo = (UINT64)product;
    hi = (UINT64)(product>>64);
}

// Signed comparison:
//   +1 if x > y
//    0 if x = y
//   -1 if x < y
long compare(const mb& x, const mb& y)
{
    const long  magnitude = compareAbs(x, y);

    if(magnitude==0)
    {
        // single-block zeros compare equal whatever their sign flags (+0 vs -0)
        if(x.n==1 && y.n==1 && x.b[0]==0)
            return 0;
        // equal magnitudes: only the signs can tell them apart
        if(x.s)
            return y.s ? 0 : 1;
        return y.s ? -1 : 0;
    }

    // |x| > |y|: x is the larger value exactly when it is non-negative
    if(magnitude==1)
        return x.s ? 1 : -1;

    // |x| < |y|: y is the larger value exactly when it is non-negative
    return y.s ? -1 : 1;
}/* compare(mb, mb) */

long compare(const mb& x, int y)
{
    mb  yy;
    
    yy = y;
    return compare(x, yy);
    
}/* compare(mb, int) */

long compare(const mb& x, long y)
{
    mb  yy;
    
    yy = y;
    return compare(x, yy);
    
}/* compare(mb, long) */

// z = x shifted left by numBits bits (x * 2^numBits), keeping the sign of x.
// numBits <= 0 simply copies x.  The result is built in a temporary first.
// When numBits is a whole number of blocks the result is not re-normalized.
void mbShiftLeft(mb& z, const mb& x, long numBits)
{
    if(numBits<=0)
    {
        mb  copy;

        copy = x;
        z = copy;
        return;
    }

    mb              s;
    const long      wholeBlocks = numBits/blockBits;
    const long      bitsLeft = numBits - wholeBlocks*blockBits; // 0..blockBits-1
    long            k;

    // x.n blocks moved up by wholeBlocks, plus one spill block for a partial shift
    if(bitsLeft)
    {
        init(s, x.n + wholeBlocks + 1);
        s.b[s.n-1] = 0;     // the spill block starts out empty
    }
    else
        init(s, x.n + wholeBlocks);

    // move the blocks of x up and clear the blocks vacated below them
    for(k=x.n-1;k>=0;k--)
        s.b[k+wholeBlocks] = x.b[k];
    for(k=0;k<wholeBlocks;k++)
        s.b[k] = 0;

    if(bitsLeft)
    {
        // shift the upper x.n+1 blocks left by bitsLeft, passing the bits that
        // fall off the top of one block into the bottom of the next
        UINT64  carryIn = 0;

        for(k=wholeBlocks;k<=wholeBlocks+x.n;k++)
        {
            const UINT64    carryOut = (s.b[k]>>(blockBits-bitsLeft));

            s.b[k] = carryIn + (s.b[k]<<bitsLeft);
            carryIn = carryOut;
        }

        mbNormalize(s); // the spill block might have stayed zero
    }

    s.s = x.s;
    z = s;

}/* mbShiftLeft */

// Value-returning form: returns x shifted left by numBits bits.
mb mbShiftLeft(const mb& x, long numBits)
{
    mb  shifted;

    mbShiftLeft(shifted, x, numBits);

    return shifted;

}/* mbShiftLeft */

// z = x shifted right by numBits bits (magnitude divided by 2^numBits, low
// bits discarded), keeping the sign of x.  numBits <= 0 copies x; shifting
// out every block gives 0.  x is assumed normalized.
// The comment in the original notes that fpNormalize uses this routine.
void mbShiftRight(mb& z, const mb& x, long numBits)
{
    if(numBits<=0)
    {
        z = x;
        return;
    }

    const long  wholeBlocks = numBits/blockBits;
    const long  bitsLeft = numBits - wholeBlocks*blockBits; // 0..blockBits-1

    // every block of x is shifted out
    if(x.n<=wholeBlocks)
    {
        z = 0;
        return;
    }

    mb          s;
    long        k;

    // drop the low wholeBlocks blocks while copying
    init(s, x.n - wholeBlocks);
    for(k=0;k<s.n;k++)
        s.b[k] = x.b[k+wholeBlocks];

    if(bitsLeft)
    {
        // work from the top block down, handing the bits that fall off the
        // bottom of one block to the top of the block below it
        UINT64  fromAbove = 0;

        for(k=s.n-1;k>=0;k--)
        {
            const UINT64    toBelow = (s.b[k]<<(blockBits-bitsLeft));

            s.b[k] = fromAbove + (s.b[k]>>bitsLeft);
            fromAbove = toBelow;
        }
    }

    mbNormalize(s);
    s.s = x.s;
    z = s;
    if(z.n==0)
        cout << "In mbShifyRight z.n = 0" << endl;

}/* mbShiftRight */

// Value-returning form: returns x shifted right by numBits bits.
mb mbShiftRight(const mb& x, long numBits)
{
    mb  shifted;

    mbShiftRight(shifted, x, numBits);

    return shifted;

}/* mbShiftRight */

// z = x + y (signed).  z may be the same object as x or y.
void add(mb& z, const mb& x, const mb& y)
{
    // Read both signs before z is written.  When z aliases x or y (for example
    // add(s, s, t)), the magnitude helpers overwrite that operand's sign flag,
    // and reading it afterwards gave the wrong sign for negative operands.
    const bool  xPos = x.s ? true : false;
    const bool  yPos = y.s ? true : false;

    if(xPos==yPos)
    {
        // same sign: add the magnitudes, keep the common sign
        mbAddAbs(z, x, y);
        z.s = xPos;
        return;
    }

    // opposite signs: subtract magnitudes; mbSubAbs reports through z.s which
    // operand was larger (true: |x| >= |y|), and the larger one gives the sign
    mbSubAbs(z, x, y);
    z.s = z.s ? xPos : yPos;

}/* add(mb, mb. mb) */

// z = x - y (signed).  z may be the same object as x or y.
void sub(mb& z, const mb& x, const mb& y)
{
    // Read both signs before z is written.  With z aliasing an operand (for
    // example sub(s, x, s)), the magnitude helpers overwrite that operand's
    // sign flag; reading it afterwards turned a negative difference positive.
    const bool  xPos = x.s ? true : false;
    const bool  yPos = y.s ? true : false;

    if(xPos!=yPos)
    {
        // opposite signs: magnitudes add and the result takes the sign of x
        mbAddAbs(z, x, y);
        z.s = xPos;
        return;
    }

    // same sign: subtract magnitudes; mbSubAbs reports through z.s which
    // operand was larger in absolute value
    mbSubAbs(z, x, y);
    if(z.s)
        z.s = xPos;     // |x| was larger (or equal): sign of x
    else
        z.s = !yPos;    // |y| was larger: opposite of the common sign

}/* sub(mb, mb, mb) */

// z = x * y (signed), schoolbook multiplication block by block.
// Shortcuts: a zero operand gives 0, and an operand of magnitude 1 just
// copies the other operand with the product sign.
// The product is accumulated in a temporary before z is written.
void mul(mb& z, const mb& x, const mb& y)
{
    if(x==0 || y==0)
    {
        z = 0;
        return;
    }

    // product is non-negative exactly when the sign flags agree
    const bool  positive = (x.s==y.s);

    if(abs(x)==1)
    {
        z = y;
        z.s = positive;
        return;
    }

    if(abs(y)==1)
    {
        z = x;
        z.s = positive;
        return;
    }

    // the operand with fewer blocks acts as the multiplier (outer loop)
    const mb&   wide = (x.n>=y.n) ? x : y;
    const mb&   narrow = (x.n>=y.n) ? y : x;
    mb          product;
    UINT64      carry = 0;
    UINT128     acc;
    long        k, row, col;

    init(product, x.n+y.n);
    for(k=0;k<product.n;++k)
        product.b[k] = 0;

    for(row=0;row<narrow.n;++row)
    {
        const UINT64    digit = narrow.b[row];

        carry = 0;
        for(col=0;col<wide.n;++col)
        {
            // digit*block + what is already accumulated there + carry fits in 128 bits
            acc = wide.b[col];
            acc = digit*acc;
            acc = acc + product.b[col+row];
            acc = acc + carry;
            product.b[col+row] = acc;
            carry = acc>>blockBits;
        }
        // fold the last carry of this row into the block just above it
        acc = product.b[row+wide.n];
        acc = acc + carry;
        product.b[row+wide.n] = acc;
        carry = acc>>blockBits;
    }
    if(carry)
        cout << "mul(mb, mb, mb) should not end with non-zero carry" << endl;

    product.s = positive;
    mbNormalize(product);
    z = product;

}/* mul(mb, mb, mb) */



// z = x/y when abs(y.n)=1; returns remainder
/*
static UINT64 shortDiv(mb& z, const mb& x, const mb& y)
{
    mb                  s;
    long                i;
    UINT128             cx, cy, cs, carry;
    UINT128             cr;
    UINT64              remainder;
    
    if(y==0 || y.n!=1)
        return 0;
    
    s.n = x.n;
    init(s,s.n);
    cr = 0;
    cy = y.b[0]; // what if cy = 1?
    carry = 0;
    for(i=s.n-1;i>=0;i--)
    {
        cx = x.b[i];
        cx = cx + carry;
        cs = cx/cy;
        s.b[i] = cs;
        cr = cx - cy*cs;
        carry = (cr<<blockBits);
    }
    
    mbNormalize(s);
    
    z = s;
    if(x.s==y.s)
        z.s = true;
    else
        z.s = false;
    
    remainder = cr;
    
    return remainder;
    
}*//* mbShortDiv */



// http://www.hackersdelight.org/hdcodetxt/divmnu64.c.txt
#define max(x, y) ((x) > (y) ? (x) : (y))

// Counts the zero bits above the highest set bit of a 64-bit value
// (64 for x == 0), by binary search over halving field widths.
long nlz(UINT64 x)
{
    static const int    widths[5] = {32, 16, 8, 4, 2};
    long                n;
    int                 k;

    if (x == 0) return(64);

    n = 0;
    for(k=0;k<5;k++)
    {
        // if the top widths[k] bits are all zero, count them and shift them out
        if((x>>(64-widths[k]))==0)
        {
            n = n + widths[k];
            x = x<<widths[k];
        }
    }
    // one last bit to test
    if((x>>63)==0)
        n = n + 1;

    return n;
}/* nlz */

// Division by a single-block divisor: z = x / y (quotient truncated, sign
// positive when x and y have the same sign flag), returning the remainder of
// the magnitudes.  If y is zero or has more than one block, z is left
// untouched and 0 is returned.  z may be the same object as x or y.
UINT64 shortDiv(mb& z, const mb& x, const mb& y)
{
    mb                  s;
    long                i;
    UINT128             cx, cy, cs, carry;
    UINT128             cr;
    UINT64              remainder;

    if(y==0 || y.n!=1)
        return 0;

    // Decide the quotient sign before z is written: if z aliases x or y the
    // assignment to z below replaces that operand's sign flag.
    const bool          sameSign = (x.s==y.s);

    // dividing by +1 or -1 only affects the sign
    if(y==1 || y==-1)
    {
        z = x;
        z.s = sameSign;
        return 0;
    }

    init(s, x.n);
    cr = 0;
    cy = y.b[0];
    carry = 0;
    // classic short division from the top block down; the running remainder
    // (always < cy) becomes the high half of the next 128-bit dividend
    for(i=s.n-1;i>=0;i--)
    {
        cx = x.b[i];
        cx = cx + carry;
        cs = cx/cy;
        s.b[i] = cs;
        cr = cx - cy*cs;
        carry = (cr<<blockBits);
    }

    mbNormalize(s);

    z = s;
    z.s = sameSign;

    remainder = cr;

    return remainder;

}/* shortDiv(mb, mb, mb) */

// Counts the zero bits above the highest set bit of x; returns blockBits
// when x is zero.
long numLeftZeroBiits(UINT64 x)
{
    const UINT64    topBit = 0x8000000000000000;
    long            zeros;

    if(!x)
        return blockBits;

    // a non-zero value has at most blockBits-1 leading zeros
    zeros = 0;
    while(zeros<blockBits-1 && !(x & topBit))
    {
        zeros++;
        x = x<<1;
    }

    return zeros;

}/* numLeftZeroBiits */

// Steps the trial quotient digit qhat64 downward until vext*qhat64 no longer
// compares greater than sdiv.  On success sdiv is reduced by that product and
// the digit is returned; vqhat and qhatmb are left holding the last product
// and digit tried.  If no digit above 0 works, sdiv is restored and 0 is
// returned.
static UINT64 findqhat(mb& sdiv, mb& vqhat, mb& qhatmb, const mb& vext, UINT64 qhat64)
{
    mb      saved;

    // keep a copy of sdiv so it can be put back after a failed trial
    init(saved, sdiv.n);
    conv(saved, sdiv);

    for(;qhat64>0;qhat64--)
    {
        qhatmb = qhat64;
        mul(vqhat, vext, qhatmb);
        if(compare(sdiv, vqhat)!=-1)
        {
            sub(sdiv, sdiv, vqhat);
            return qhat64;
        }
        conv(sdiv, saved);
    }

    return qhat64;

}/* findqhat */

// Estimates the quotient of a three-block value by a two-block value in
// long double arithmetic: z = floor(b3/b2 + 1), clamped to 0xFFFFFFFFFFFFFFFF.
// b3 must have at least 3 blocks and b2 at least 2; only b3.b[0..2] and
// b2.b[0..1] are read.
static void ThreeByTwo(UINT64& z, const mb& b3, const mb& b2)
{
    // 2^64 as an exact floating literal.  The previous form, the unsuffixed
    // integer literal 18446744073709551615 plus 1, relied on a literal that
    // fits no signed integer type.
    const long double       TwoTo64 = 18446744073709551616.0L;
    static UINT64           max64=0xFFFFFFFFFFFFFFFF;
    long double             x3, x2, zD;
    UINT128                 temp;
    UINT128                 zMax; // 2^64

    zMax = 1;
    zMax = zMax<<blockBits;

    x3 = b3.b[2];
    x3 = TwoTo64*x3 + b3.b[1];
    x3 = TwoTo64*x3 + b3.b[0];

    x2 = b2.b[1];
    x2 = TwoTo64*x2 + b2.b[0];

    // The quotient is compared in 128 bits and clamped explicitly: the
    // original author observed that converting 2^64 straight to a UINT64
    // gave 0xFFFFFFFFFFFFFFFF with one compiler and 0 with another.
    // The estimate is biased upward by 1; divAbs in this file steps an
    // over-estimate down, while an under-estimate was described by the
    // original author as "a disaster".
    // NOTE(review): this presumably needs long double to carry a 64-bit
    // mantissa for the +1 bias to be sufficient - confirm on targets where
    // long double is the same as double.

    zD = x3;
    zD = zD/x2;
    zD = zD + 1;
    temp = zD;
    if(temp>=zMax)
        z = max64;
    else
        z = temp;

}/* ThreeByTwo */

// Value-returning form of the three-block by two-block quotient estimate.
UINT64 ThreeByTwo(const mb& b3, const mb& b2)
{
    UINT64      estimate;

    ThreeByTwo(estimate, b3, b2);

    return estimate;

}/* ThreeByTwo */

// Long division on magnitudes: q = |x| / |y|, r = |x| mod |y|.
//   y == 0        : returns without touching q or r
//   x == 0        : q = 0, r = 0
//   |x| < |y|     : q = 0, r = x
//   |x| == |y|    : q = 1, r = 0
//   y has 1 block : delegated to shortDiv
// Otherwise each quotient "digit" (one block) is estimated from the top three
// blocks of the working dividend and the top two blocks of y, then stepped
// down until digit*y fits under the current dividend window.
// The long-division path never reads the sign flags of x or y.
void divAbs(mb& q, mb& r, const mb& x, const mb& y)
{
    mb              b2, b3;
    mb              qt, d, t; // d is working dividend
    UINT64          qd; // quotient "digit"
    UINT64          carry, borrow;
    UINT128         sLL;
    long            i, result;
    long            qBlock; // index pointer into quotient
    long            dBlock; // index pointer into working dividend
    bool            doSubtract, done;
    
    if(!y)
        return;
    
    if(!x)
    {
        q = 0;
        r = 0;
        return;
    }
    
    result = compareAbs(x, y);
    if(result==-1)
    {
        // copy the dividend out BEFORE clearing q: when q is the same object
        // as x, writing q first would turn the remainder into 0
        r = x;
        q = 0;
        return;
    }
    
    if(result==0)
    {
        q = 1;
        r = 0;
        return;
    }
    
    if(y.n==1)
    {
        r = shortDiv(q, x, y);
        return;
    }
    
    init(t, y.n + 1); // t holds quotient "digit" X Y
    init(b2, 2);
    init(b3, 3);
    
    // load b2 with the two highest blocks of the divisor
    b2.b[1] = y.b[y.n-1];
    b2.b[0] = y.b[y.n-2];
    
    init(d, x.n+1); // working dividend: x with one extra zero block on top
    for(i=0;i<x.n;++i)
        d.b[i] = x.b[i];
    d.b[x.n] = 0;
    d.s = true;
    dBlock = d.n - 1;
    init(qt, x.n - y.n + 1); // working quotient
    qBlock = qt.n - 1; // where next quotient "digit" goes
    
    while(qBlock>=0)
    {
        // get qd: estimate from the top three blocks of the current window
        b3.b[2] = d.b[dBlock];
        b3.b[1] = d.b[dBlock-1];
        b3.b[0] = d.b[dBlock-2];
        qd = ThreeByTwo(b3, b2);
        done = false;
        while(!done)
        {
            // calc t = qd*y
            carry = 0;
            for(i=0;i<y.n;i++)
            {
                sLL = qd;
                sLL = sLL*y.b[i];
                sLL = sLL + carry;
                t.b[i] = sLL;
                carry = sLL>>blockBits;
            }//for
            t.b[t.n-1] = carry;
            
            // check if d>=t, comparing the window of d with t from the top block down
            doSubtract = true;
            for(i=0;i<t.n;++i)
            {
                if(d.b[dBlock-i]>t.b[t.n-1-i])
                    break;
                if(d.b[dBlock-i]<t.b[t.n-1-i])
                {
                    doSubtract = false;
                    break;
                }// if
            }// for
            
            if(doSubtract)
            {
                // subtract t from high part of d, modifying d
                borrow = 0;
                for(i=0;i<t.n;++i)
                {
                    sLL = d.b[dBlock-t.n+1+i];
                    sLL = sLL - t.b[i];
                    sLL = sLL - borrow;
                    d.b[dBlock-t.n+1+i] = sLL;
                    borrow = sLL>>blockBits;
                    if(borrow)
                        borrow = 1;
                    else
                        borrow = 0;
                }// for
                qt.b[qBlock--] = qd;
                dBlock--;
                done = true;
            }// if(doSubtract)
            else
            {
                // estimate was too large: try the next smaller digit; a digit
                // of 0 needs no subtraction
                qd--;
                if(qd==0)
                {
                    qt.b[qBlock--] = qd;
                    dBlock--;
                    done = true;
                }
            }
        }// while(!done)
    }// while(qBlock>=0)
    
    // what is left in the working dividend is the remainder
    q = qt;
    r = d;
    mbNormalize(q);
    mbNormalize(r);
    
}/* divAbs */

// Signed division: q = x / y with the quotient sign positive when x and y
// have the same sign flag; returns the remainder, which takes the sign of
// the dividend x.  q may be the same object as x.
mb div(mb& q, const mb& x, const mb& y)
{
    mb      r;

    // Capture both result signs before q is written.  The remainder sign used
    // to be read from x.s after the division, which is wrong when q aliases x
    // because q.s has been overwritten with the quotient sign by then.
    const bool  quotientPos = (x.s==y.s);
    const bool  dividendPos = x.s ? true : false;

    if(y.n==1)
        r = shortDiv(q, x, y);
    else
        divAbs(q, r, x, y);

    q.s = quotientPos;
    r.s = dividendPos;

    return r;

}/* div(mb, mb, mb) */

// z = x + y, where the double y is first converted to an mb with conv().
void add(mb& z, const mb& x, double y)
{
    mb  rhs;

    conv(rhs, y);

    add(z, x, rhs);

}/* add(mb, mb, double) */

// z = x - y, where the double y is first converted to an mb with conv().
void sub(mb& z, const mb& x, double y)
{
    mb  rhs;

    conv(rhs, y);

    sub(z, x, rhs);

}/* sub(mb, mb, double) */

// z = x * y, where the double y is first converted to an mb with conv().
void mul(mb& z, const mb& x, double y)
{
    mb  rhs;

    conv(rhs, y);

    mul(z, x, rhs);

}/* mul(mb, mb, double) */

// z = x / y, where the double y is first converted to an mb with conv().
// The remainder returned by the mb/mb division is discarded.
void div(mb& z, const mb& x, double y)
{
    mb  rhs;

    conv(rhs, y);

    div(z, x, rhs);

}/* div(mb, mb, double) */

// z = x / y with the double y converted to an mb first; returns the
// remainder produced by the mb/mb division.
mb divRem(mb& z, const mb& x, double y)
{
    mb  rhs;

    conv(rhs, y);

    return div(z, x, rhs);

}/* divRem(mb, mb, double) */

// z = x + y, where the double x is first converted to an mb with conv().
void add(mb& z, double x, const mb& y)
{
    mb  lhs;

    conv(lhs, x);

    add(z, lhs, y);

}/* add(mb, double, mb) */

// z = x - y, where the double x is first converted to an mb with conv().
void sub(mb& z, double x, const mb& y)
{
    mb  lhs;

    conv(lhs, x);

    sub(z, lhs, y);

}/* sub(mb, double, mb) */

// z = x * y, where the double x is first converted to an mb with conv().
void mul(mb& z, double x, const mb& y)
{
    mb  lhs;

    conv(lhs, x);

    mul(z, lhs, y);

}/* mul(mb, double, mb) */

// z = x / y, where the double x is first converted to an mb with conv().
// The remainder returned by the mb/mb division is discarded.
void div(mb& z, double x, const mb& y)
{
    mb  lhs;

    conv(lhs, x);

    div(z, lhs, y);

}/* div(mb, double, mb) */

// z = x / y with the double x converted to an mb first; returns the
// remainder produced by the mb/mb division.
mb divRem(mb& z, double x, const mb& y)
{
    mb  lhs;

    conv(lhs, x);

    return div(z, lhs, y);

}/* divRem(mb, double, mb) */

 // for a long, counts the number of 0 bits to first 1 bit after and including the leading 1 bit
 // so it gives the number of bits actually needed for the number
 long NumBits(long x)
 {
 static UINT64       mask;
 static bool         initGood=false;
 long                theBits;
 
 if(!initGood)
 {
 mask = 1;
 mask = mask<<(blockBits-1); // 0x8000000000000000 (which we can't use directly since it's a negative long!
 initGood = true;
 }
 
 theBits = blockBits - 1;
 x = abs(x);
 x = (x<<1); // gets rid of sign bit
 
 while(!(x & mask) && theBits)
 {
 x = (x<<1);
 theBits--;
 }
 
 return theBits;
 
 }/* NumBits */

// For a UINT64: the number of bits from the leading 1 bit down to bit 0,
// inclusive (0 when x is 0).
long NumBits(UINT64 x)
{
    const UINT64        topBit = ((UINT64)1)<<(blockBits-1); // 0x8000000000000000
    long                width;

    // each leading zero bit shifted out lowers the width by one
    for(width=blockBits;width>0;width--)
    {
        if(x & topBit)
            break;
        x = (x<<1);
    }

    return width;

}/* NumBits for UINT64 */

// Number of bits in the magnitude of x: the bit width of its top block plus
// blockBits for every block below it.  An mb with no blocks reports 1.
long NumBits(const mb& x)
{
    if(x.n==0)
        return 1;

    const long  topWidth = NumBits(x.b[x.n-1]);

    return (x.n>1) ? topWidth + blockBits*(x.n-1) : topWidth;

}/* NumBits */


// Integer square root by Newton's Method, z = (z + x/z)/2.
// A negative x returns without touching z; x == 0 gives 0.
// The iteration stops when an iterate repeats, or when it returns to the
// value from two steps earlier (a two-value cycle); in the latter case the
// iterate is lowered by one if its square exceeds x.
void sqrt(mb& z, const mb& x)
{
    // The value 1 for the final correction.  This used to be a static built
    // with init(one, 1), which allocates a one-block mb and does not assign
    // the value 1, so the correction subtracted an unspecified amount.
    mb               one=1;
    mb               s, s1, t1, ss;
    long             i, iL, bitCount;
    
    iL = 0;
    
    if(!x.s)
        return;
    
    if(x==0)
    {
        z = 0;
        return;
    }
    
    // initial guess: x shifted right by half of its bit length
    s1 = x;

    // now count the bits in the blocks
    bitCount = NumBits(s1.b[s1.n-1]);
    if(s1.n>1)
        bitCount = bitCount + blockBits*(s1.n-1);
    
    bitCount = bitCount/2;
    
    mbShiftRight(s1, s1, bitCount);
    t1 = s1;    // t1 tracks the iterate from two steps back
    
    for(i=0;i<10000;i++)
    {
        iL = i;
        div(s, x, s1);
        add(s, s, s1);
        mbShiftRight(s, s, 1); // s=s/2
        if(!compare(s, s1))
            break;              // converged: the iterate repeated
        if(!compare(s, t1))
        {
            // cycling between two neighbouring values: keep the one whose
            // square does not exceed x
            mul(ss, s, s);
            if(compare(ss, x)==1)
                sub(s, s, one);
            break;
        }
        t1 = s1;
        s1 = s;
    }
    z = s;
    if(iL>=9999)
        cout << "sqrt problem" << endl;
    
}/* sqrt */

// Value-returning form of the integer square root.
mb sqrt(const mb& x)
{
    mb  root;

    sqrt(root, x);

    return root;

}/* sqrt */

// z = x mod y, taken as the remainder returned by div().
// With isSym set, a remainder greater than y/2 is reduced by y.
// y == 0 returns without touching z.
void myModulus(mb& z, const mb& x, const mb& y, bool isSym)
{
    mb  two=2;

    if(y==0)
        return;

    // div() stores the quotient in z first; the assignment then replaces it
    // with the returned remainder
    z = div(z, x, y);

    if(z>y/two && isSym)
    {
        z = z - y;
    }

}/* myModulus(mb, mb, mb, bool) */

// z = x mod y for a double modulus, which is converted to an mb with conv().
// y == 0 returns without touching z.
void myModulus(mb& z, const mb& x, double y, bool isSym)
{
    mb  modulus;

    if(y==0)
        return;

    conv(modulus, y);

    myModulus(z, x, modulus, isSym);

}/* myModulus(mb, mb, double, bool) */

// z = x mod y for a double dividend, which is converted to an mb with conv().
// y == 0 returns without touching z.
void myModulus(mb& z, double x, const mb& y, bool isSym)
{
    mb  dividend;

    if(y==0)
        return;

    conv(dividend, x);

    myModulus(z, dividend, y, isSym);

}/* myModulus(mb, double, mb, bool) */

// Value-returning form of x mod y.
mb myModulus(const mb& x, const mb& y, bool isSym)
{
    mb  result;

    myModulus(result, x, y, isSym);

    return result;

}/* myModulus(mb, mb, bool) */

// Value-returning form of x mod y for a double modulus (converted with conv()).
mb myModulus(const mb& x, double y, bool isSym)
{
    mb    modulus;

    conv(modulus, y);

    return myModulus(x, modulus, isSym);

}/* myModulus(mb, double, bool) */

// Value-returning form of x mod y for a double dividend (converted with conv()).
mb myModulus(double x, const mb& y, bool isSym)
{
    mb    dividend;

    conv(dividend, x);

    return myModulus(dividend, y, isSym);

}/* myModulus(double, mb, bool) */

// Greatest common divisor of |x| and |y| by the Euclidean algorithm.
// If one argument is zero the result is the magnitude of the other.
void gcd(mb& z, const mb& x, const mb& y)
{
    mb          u, v, rem, quot;

    // work on magnitudes
    u = x;
    if(!u.s)
        u.s = !u.s;

    v = y;
    if(!v.s)
        v.s = !v.s;

    if(v==0)
    {
        z = u;
        return;
    }

    if(u==0)
    {
        z = v;
        return;
    }

    // v is non-zero here, so the first division always runs
    while(v!=0)
    {
        divAbs(quot, rem, u, v); // rem = u % v
        u = v;
        v = rem;
    }

    z = u;

}/* gcd */

// Value-returning form of the greatest common divisor.
mb gcd(const mb& x, const mb& y)
{
    mb  divisor;

    gcd(divisor, x, y);

    return divisor;

}/* gcd */

// Extended Euclidean algorithm, iterative form as described in
// http://en.wikipedia.org/wiki/Extended_Euclidean_algorithm
// Solves a*x + b*y = gcd(a, b) for x, y and gcd.  The recurrence runs on
// |a| and |b|; the signs of x and y are flipped afterwards for negative a, b.
// With b == 0 the result is gcd = |a|, x = +/-1, y = 0.
void extendedGCD(mb& x, mb& y, mb& gcd, const mb& a, const mb& b)
{
    mb        r0, r1, r2,  q, s0, s1, s2, t0, t1, t2;
    
    r0 = a;
    r1 = b;
    
    if(!r0.s)
        r0.s = !r0.s;
    
    if(!r1.s)
        r1.s = !r1.s;
    
    s0 = 1;
    s1 = 0;
    t0 = 0;
    t1 = 1;
    
    // Test the divisor BEFORE dividing: the former do/while always performed
    // one division first, which divided by zero when b was 0.
    while(r1!=0)
    {
        div(q, r0, r1);
        r2 = r0 - q*r1;
        s2 = s0 - q*s1;
        t2 = t0 - q*t1;
        r0 = r1;
        r1 = r2;
        s0 = s1;
        s1 = s2;
        t0 = t1;
        t1 = t2;
    }
    
    gcd = r0;
    x = s0;
    y = t0;
    // the coefficients were found for |a| and |b|
    if(!a.s)
        x.s = !x.s;
    if(!b.s)
        y.s = !y.s;
    
    
}/* extendedGCD */


// LCM of two integers x and y is the smallest positive integer that is
// divisible by both x and y; it is also the lowest common denominator for
// two fractions.  If either argument is zero the result is 0.
void lcm(mb& z, const mb& x, const mb& y)
{
    // With both arguments zero gcd(x, y) is 0 and the quotient below would
    // divide by zero; a single zero argument already produced 0.
    if(x==0 || y==0)
    {
        z = 0;
        return;
    }
    
    z = abs(x*y) / gcd(x, y);
    
}/* lcm */

// Value-returning form of the least common multiple.
mb lcm(const mb& x, const mb& y)
{
    mb  multiple;

    lcm(multiple, x, y);

    return multiple;

}/* lcm */

// Integer n-th root: z = the largest integer whose n-th power does not
// exceed |x|, with the sign of x restored for odd n.
//   n <= 0, or even n with negative x : z = 0 (the caller should catch these)
//   n == 1                            : z = x
// Uses Newton's Method, zt = zt + (x - zt^n)/(n*zt^(n-1)), followed by a
// step-by-one correction of the final iterate.
void nthRoot(mb& z, const mb& x, long n)
{
    mb                  one=1, two=2;
    mb                  xt, zt, s, s1, delta, step;
    bool                isNegative;
    long                result;
    
    if(n<=0)
    {
        z = 0;
        return;
    }
    
    if(n==1)
    {
        z = x;
        return;
    }
    
    // work on the magnitude, remembering the sign
    xt = x;
    if(!xt.s)
    {
        xt.s = !xt.s;
        isNegative = true;
    }
    else
        isNegative = false;
    
    if(2*(n/2)==n && isNegative)
    {
        z = 0; // error in input should be caught by caller
        return;
    }
    
    // the root of zero is zero; without this test zero fell into the
    // "smaller than 2^n" case below and came back as 1
    if(xt==0)
    {
        z = 0;
        return;
    }
    
    if(!compare(xt, one))
    {
        z = x;
        return;
    }
    
    // 1 < |x| < 2^n : the root is 1
    power(s, two, n);
    if(compare(xt, s)==-1)
    {
        z = one;
        if(isNegative)
            z.s = !z.s;
        return;
    }
    
    // initial guess
    power(zt, two, Lg2(xt)/n);
    // s1=n*zt^(n-1); s=zt^n
    for(;;)
    {
        power(s1, zt, n-1);
        mul(s, s1, zt);
        mul(s1, s1, n);
        // Separate temporaries for the correction term: sub/div/add read
        // their operands' sign flags after writing the result, so passing
        // the same object as input and output lost the sign of a negative
        // correction and the iteration could move away from the root.
        sub(delta, xt, s);
        div(step, delta, s1);
        add(s, step, zt);
        if(!compare(s, zt))
            break;
        zt = s;
    }
    
    // Newton with truncating division can stop one or more steps away from
    // the exact integer root; walk to it
    power(s, zt, n);
    result = compare(s, xt);
    if(result==1)
        while(result==1)
        {
            sub(zt, zt, one);
            power(s, zt, n);
            result = compare(s, xt);
        }
    else
        if(result==-1)
        {
            while(result==-1)
            {
                add(zt, zt, one);
                power(s, zt, n);
                result = compare(s, xt);
            }
            if(result==1)
                sub(zt, zt, one);
        }
    z = zt;
    if(isNegative)
        z.s = !z.s;
    
}/* nthRoot */

// Value-returning form of the integer n-th root.
mb nthRoot(const mb& x, long n)
{
    mb        root;

    nthRoot(root, x, n);

    return root;

}/* nthRoot */

// z = x^y in integer arithmetic (left-to-right square-and-multiply).
//   x == 1            : z = 1
//   y < 0, x == 0     : z = 0, returns false (undefined)
//   y < 0, |x| > 1    : z = 0 (the true value has magnitude below 1)
//   y < 0, x == -1    : z = -1 for odd y, +1 for even y
//   y == 0            : z = 1, returns false when x == 0
// Returns true in every other case.
bool power(mb& z, const mb& x, long y)
{
    long                i, n;
    UINT64              bits;
    UINT64              mask=0x8000000000000000;
    mb                  one=1;
    mb                  zt;
    
    if(!compare(x, one))
    {
        z = one;
        return true;
    }
    
    if(y<0)
    {
        if(x==0)
        {
            z = 0;
            return false;
        }
        
        // Negative x used to fall through into the y>0 code below with a
        // negative exponent.  Any base of magnitude above 1 truncates to 0.
        if(x>0 || compare(x, -1)!=0)
        {
            z = 0;
            return true;
        }
        
        // x == -1: the sign follows the parity of the exponent
        z = one;
        if(y%2)
            z.s = false;
        return true;
    }
    
    if(y==0)
    {
        z = one;
        if(x==0)
            return false;
        return true;
    }
    
    // now y>0
    // The exponent bits are scanned in a UINT64: shifting the signed long
    // itself moves bits into its sign bit and assumes long is 64 bits wide.
    zt = one;
    n = NumBits((UINT64)y);
    bits = ((UINT64)y)<<(blockBits-n);  // leading 1 bit moved to the top
    
    for(i=0;i<n;i++)
    {
        mul(zt, zt, zt); // zt^2
        if(bits & mask)
            mul(zt, zt, x);
        bits = (bits<<1);
    }
    z = zt;
    return true;
    
}/* power */

// Value-returning form of x^y; the success flag of the three-argument form
// is not reported.
mb power(const mb& x, long y)
{
    mb      result;

    power(result, x, y);

    return result;

}/* power */

// calculates the number of digits in decimal representation of x
long numDigits(const mb& x)
{
    long    length, bp, dp;
    fp      xfp, ten=10, log10;
    
    if(x==0)
    {
        length = 1;
        return length;
    }
    // we don't need high precision
    bp = blockPrec;
    dp = decPrec;
    blockPrec = 2;
    decPrec = 32;
    conv(xfp, x);
    xfp.i.s = true;
    log10 = log(xfp)/log(ten) + 1e-16;
    length = to_long(log10) + 1;
    blockPrec = bp;
    decPrec = dp;
    return length;
    
}/* numDigits */
