//
//  mbBit.cpp
//  fp64
//
//  Created by Bob Delaney on 2/11/19.
//  Copyright © 2019 Bob Delaney. All rights reserved.
//

#include "mbBit.hpp"
#include "mbMath.hpp"

using namespace std;

// Debug helper: writes the sign field of x, then every block x.b[0..x.n-1]
// in hexadecimal (one per line), and finally restores decimal output on cout.
static void coutHex(const mb& x)
{
    std::cout << "sign = " << x.s << std::endl;
    
    // std::hex is sticky, so selecting it once covers all the blocks below
    std::cout << std::hex;
    for(long k=0; k<x.n; ++k)
        std::cout << x.b[k] << std::endl;
    
    // switch back to decimal so later output on cout is unaffected
    std::cout << std::dec << std::endl;
    
}/* coutHex(mb) */

// Returns how many of the low blockBits bits of x are binary 1's.
// (blockBits is defined elsewhere; presumably 64, the width of UINT64 - confirm.)
long countOnes(UINT64 x)
{
    long    ones = 0;
    
    // test the lowest bit, then shift the next bit down into its place
    for(long bit=0; bit<blockBits; ++bit)
    {
        if((x & 1) != 0)
            ++ones;
        x >>= 1;
    }
    
    return ones;
    
}/* countOnes */

// Counts the bit positions 0..maxPos (inclusive) at which x and y hold the same bit.
// maxPos is clamped into [0, 63] first, so a negative maxPos still compares bit 0.
long compareBitsUINT64(UINT64 x, UINT64 y, long maxPos)
{
    const long      last = (maxPos<0) ? 0 : ((maxPos>63) ? 63 : maxPos);
    const UINT64    differing = x ^ y;   // a 1 marks a position where x and y disagree
    long            matches = 0;
    
    for(long pos=0; pos<=last; ++pos)
    {
        if(((differing >> pos) & 1) == 0)
            ++matches;
    }
    
    return matches;
    
}/* compareBitsUINT64 */

// Returns the bit of x at bitIndex as 0 or 1.
// A bitIndex outside 0..63 yields 0.
long getBit(UINT64 x, long bitIndex)
{
    const bool  inRange = (bitIndex>=0 && bitIndex<=63);
    
    if(!inRange)
        return 0;
    
    // bring the requested bit down to position 0 and test it
    return (((x >> bitIndex) & 1) != 0) ? 1 : 0;
    
}/* getBit */

// Returns x with the bit at bitIndex forced to 0.
// A bitIndex outside 0..63 returns x unchanged.
UINT64 clearBit(UINT64 x, long bitIndex)
{
    if(bitIndex>=0 && bitIndex<=63)
    {
        const UINT64    bit = (static_cast<UINT64>(1) << bitIndex);
        
        // AND with the complement: leaves every other bit alone
        x = (x & ~bit);
    }
    
    return x;
    
}/* clearBit */

// Returns x with the bit at bitIndex forced to 1.
// A bitIndex outside 0..63 returns x unchanged.
UINT64 setBit(UINT64 x, long bitIndex)
{
    if(bitIndex>=0 && bitIndex<=63)
        x |= (static_cast<UINT64>(1) << bitIndex);
    
    return x;
    
}/* setBit */

// Returns x with the bit at bitIndex inverted (0 becomes 1, 1 becomes 0).
// A bitIndex outside 0..63 returns x unchanged.
UINT64 flipBit(UINT64 x, long bitIndex)
{
    if(bitIndex>=0 && bitIndex<=63)
    {
        // XOR with a single-bit mask toggles exactly that bit
        x ^= (static_cast<UINT64>(1) << bitIndex);
    }
    
    return x;
    
}/* flipBit */


// Returns the bit of x at bitIndex (bit 0 is the lowest bit of block x.b[0]).
// Returns 0 when x equals 0, when bitIndex is negative, or when bitIndex >= NumBits(x).
long mbGetBit(const mb& x, long bitIndex)
{
    // NumBits is only consulted for a nonzero x and a non-negative index
    if(!(x==0) && bitIndex>=0 && bitIndex<NumBits(x))
    {
        const long  block = bitIndex / blockBits;    // which block holds the bit
        const long  offset = bitIndex % blockBits;   // position inside that block
        
        return getBit(x.b[block], offset);
    }
    
    return 0;
    
}/* mbGetBit */

// Counts the positions from lowBitIndex to highBitIndex (inclusive) at which
// x and y hold the same bit. Bits are read with mbGetBit, which returns 0 for
// positions beyond an operand's bits, so the shorter operand is in effect zero extended.
// Returns 0 if lowBitIndex is negative.
long mbCompareBits(const mb& x, const mb& y, long lowBitIndex, long highBitIndex)
{
    if(lowBitIndex<0)
        return 0;
    
    long    same = 0;
    long    pos = lowBitIndex;
    
    while(pos<=highBitIndex)
    {
        if(mbGetBit(x, pos)==mbGetBit(y, pos))
            ++same;
        ++pos;
    }
    
    return same;
    
}/* mbCompareBits */

// Returns a copy of x with the bit at bitIndex set to 1.
// A negative bitIndex returns an unmodified copy of x.
// If bitIndex lies beyond the x.n blocks of x, the result is first grown to
// enough blocks to hold it; the added blocks are zero filled and the sign x.s is kept.
mb mbSetBit(const mb& x, long bitIndex)
{
    mb      result;
    
    result = x;
    
    if(bitIndex<0)
        return result;
    
    const long  block = bitIndex / blockBits;    // which block holds the bit
    const long  offset = bitIndex % blockBits;   // position inside that block
    
    if(block>=result.n)
    {
        // not enough blocks: build a longer mb, copy x into its low blocks
        // and clear the rest
        mb      grown;
        long    k = 0;
        
        init(grown, block+1);
        grown.s = x.s;
        
        while(k<x.n)
        {
            grown.b[k] = x.b[k];
            ++k;
        }
        
        while(k<grown.n)
        {
            grown.b[k] = 0;
            ++k;
        }
        
        result = grown;
    }
    
    result.b[block] = setBit(result.b[block], offset);
    return result;
    
}/* mbSetBit */

// Sets the bit at bitIndex to 1 directly in x (no copy is made, x is never grown).
// Does nothing if bitIndex is negative or not inside the existing x.n blocks.
void mbSetBitFast(mb& x, long bitIndex)
{
    if(bitIndex<0 || bitIndex>=blockBits*x.n)
        return;
    
    const long  block = bitIndex / blockBits;
    
    x.b[block] = setBit(x.b[block], bitIndex % blockBits);
    
}/* mbSetBitFast */

// Returns a copy of x with the bit at bitIndex cleared to 0.
// If bitIndex is negative or not inside the x.n blocks of x, the copy is
// returned untouched; otherwise mbNormalize (defined elsewhere) is applied
// to the result before it is returned.
mb mbClearBit(const mb& x, long bitIndex)
{
    mb      result;
    
    result = x;
    
    if(bitIndex>=0 && bitIndex<blockBits*x.n)
    {
        const long  block = bitIndex / blockBits;
        const long  offset = bitIndex % blockBits;
        
        result.b[block] = clearBit(result.b[block], offset);
        mbNormalize(result);
    }
    
    return result;
    
}/* mbClearBit */

// Clears the bit at bitIndex to 0 directly in x (no copy is made).
// Does nothing if bitIndex is negative or not inside the existing x.n blocks.
// Unlike mbClearBit, mbNormalize is not called here.
void mbClearBitFast(mb& x, long bitIndex)
{
    if(bitIndex<0 || bitIndex>=blockBits*x.n)
        return;
    
    const long  block = bitIndex / blockBits;
    
    x.b[block] = clearBit(x.b[block], bitIndex % blockBits);
    
}/* mbClearBitFast */

// Returns a copy of x with the bit at bitIndex inverted.
// A negative bitIndex returns an unmodified copy of x.
// If the bit lies inside the existing blocks it is flipped and mbNormalize
// (defined elsewhere) is applied. If it lies beyond them, the result is grown
// with zero blocks (sign x.s kept) and the bit is set, since flipping an
// absent (zero) bit gives 1; mbNormalize is not called on that path.
mb mbFlipBit(const mb& x, long bitIndex)
{
    mb      result;
    
    result = x;
    
    if(bitIndex<0)
        return result;
    
    const long  block = bitIndex / blockBits;    // which block holds the bit
    const long  offset = bitIndex % blockBits;   // position inside that block
    
    if(block<result.n)
    {
        result.b[block] = flipBit(result.b[block], offset);
        mbNormalize(result);
        return result;
    }
    
    // bit is past the top block: build a longer mb holding x plus zero blocks
    mb      grown;
    long    k = 0;
    
    init(grown, block+1);
    grown.s = x.s;
    
    while(k<x.n)
    {
        grown.b[k] = x.b[k];
        ++k;
    }
    
    while(k<grown.n)
    {
        grown.b[k] = 0;
        ++k;
    }
    
    grown.b[block] = setBit(grown.b[block], offset);
    result = grown;
    return result;
    
}/* mbFlipBit */

// Inverts the bit at bitIndex directly in x (no copy is made, x is never grown).
// Does nothing if bitIndex is negative or not inside the existing x.n blocks.
// mbNormalize is not called here.
void mbFlipBitFast(mb& x, long bitIndex)
{
    if(bitIndex<0 || bitIndex>=blockBits*x.n)
        return;
    
    const long  block = bitIndex / blockBits;
    
    x.b[block] = flipBit(x.b[block], bitIndex % blockBits);
    
}/* mbFlipBitFast */

// Returns the total number of binary 1's in x, found by summing countOnes
// over the blocks x.b[0..x.n-1]. Returns 0 when x equals 0.
long mbCountOnes(const mb& x)
{
    if(x==0)
        return 0;
    
    long    total = 0;
    
    for(long k=0; k<x.n; ++k)
        total += countOnes(x.b[k]);
    
    return total;
    
}/* mbCountOnes */

// Bit by bit AND of x with y, returned as a new mb.
// The sign fields of x and y are not consulted; the result's sign is whatever
// init() leaves in place (the original author's note says the result is >= 0).
// Only the blocks both operands have can give nonzero bits, so the result is
// built with min(x.n, y.n) blocks and then passed to mbNormalize.
// Returns 0 if either operand equals 0.
mb mbAnd(const mb& x, const mb& y)
{
    mb      result;
    
    if(x==0 || y==0)
    {
        result = 0;
        return result;
    }
    
    // number of blocks common to both operands
    const long  common = (y.n<x.n) ? y.n : x.n;
    
    init(result, common);
    
    for(long k=0; k<common; ++k)
        result.b[k] = (x.b[k] & y.b[k]);
    
    mbNormalize(result);
    
    return result;
    
}/* mbAnd */

// Bit by bit OR of x with y, returned as a new mb with max(x.n, y.n) blocks.
// Blocks both operands have are OR'ed; the extra high blocks of the longer
// operand are copied unchanged.
// If one operand equals 0 the other is returned as a plain copy; if both
// equal 0 the result is 0.
// NOTE(review): those zero shortcuts copy the other operand whole, sign field
// included, while the general path leaves the sign to init() - confirm intended.
mb mbOr(const mb& x, const mb& y)
{
    mb          result;
    const bool  xIsZero = (x==0);
    const bool  yIsZero = (y==0);
    
    if(xIsZero)
    {
        if(yIsZero)
            result = 0;
        else
            result = y;
        return result;
    }
    
    if(yIsZero)
    {
        result = x;
        return result;
    }
    
    // pick out the operand with more blocks; on a tie the choice does not matter
    const mb&   longer = (x.n>y.n) ? x : y;
    const mb&   shorter = (x.n>y.n) ? y : x;
    const long  blocksMax = longer.n;
    const long  blocksMin = shorter.n;
    long        k = 0;
    
    init(result, blocksMax);
    
    // blocks present in both operands
    for(; k<blocksMin; ++k)
        result.b[k] = (x.b[k] | y.b[k]);
    
    // remaining high blocks exist only in the longer operand
    for(; k<blocksMax; ++k)
        result.b[k] = longer.b[k];
    
    return result;
    
}/* mbOr */

// Bit by bit XOR of x with y, returned as a new mb.
// Blocks both operands have are XOR'ed; the extra high blocks of the longer
// operand are copied unchanged (XOR with an absent, i.e. zero, block).
// mbNormalize (defined elsewhere) is applied to the result, since XOR can
// cancel the high blocks when the operands have the same length.
// If one operand equals 0 the other is returned as a plain copy; if both
// equal 0 the result is 0.
mb mbXor(const mb& x, const mb& y)
{
    mb          result;
    const bool  xIsZero = (x==0);
    const bool  yIsZero = (y==0);
    
    if(xIsZero)
    {
        if(yIsZero)
            result = 0;
        else
            result = y;
        return result;
    }
    
    if(yIsZero)
    {
        result = x;
        return result;
    }
    
    // pick out the operand with more blocks; on a tie the choice does not matter
    const mb&   longer = (x.n>y.n) ? x : y;
    const mb&   shorter = (x.n>y.n) ? y : x;
    const long  blocksMax = longer.n;
    const long  blocksMin = shorter.n;
    long        k = 0;
    
    init(result, blocksMax);
    
    // blocks present in both operands
    for(; k<blocksMin; ++k)
        result.b[k] = (x.b[k] ^ y.b[k]);
    
    // remaining high blocks exist only in the longer operand
    for(; k<blocksMax; ++k)
        result.b[k] = longer.b[k];
    
    mbNormalize(result);
    
    return result;
    
}/* mbXor */


