// Copyright 1992-2000 by Jon Dart.  All Rights Reserved.
#include "bearing.h"
#include "util.h"
#include "beardata.h"
#include "debug.h"

// Index offset between squares one rank apart: it is the step used for
// same-file squares in Directions[][] and for the file case in is_pinned().
const int RankIncr = 8; // add this to move 1 rank

// Per-square table: entry i holds 7 - (i % 8), so each run of eight
// consecutive square indices is numbered 7 down to 0.
// NOTE(review): presumably a "reversed file" index used by rotated-bitmap
// code elsewhere - confirm against its callers.
const int Bearing::FilesR[64] =
{
   7, 6, 5, 4, 3, 2, 1, 0,
   7, 6, 5, 4, 3, 2, 1, 0,
   7, 6, 5, 4, 3, 2, 1, 0,
   7, 6, 5, 4, 3, 2, 1, 0,
   7, 6, 5, 4, 3, 2, 1, 0,
   7, 6, 5, 4, 3, 2, 1, 0,
   7, 6, 5, 4, 3, 2, 1, 0,
   7, 6, 5, 4, 3, 2, 1, 0
};


// Storage for Bearing's static lookup tables. Every table below is filled
// in by Bearing::init().

// Per-square target sets, copied from KnightSquares / KingSquares.
Bitmap Bearing::knight_attacks[64];
Bitmap Bearing::king_attacks[64];
// [square][color]: squares at +7/+9 (White) or -7/-9 (Black) from the
// square, skipping offsets that would wrap around file 1 / file 8.
Bitmap Bearing::pawn_attacks[64][2];
// Whole-file masks indexed by File(sq)-1; whole-rank masks indexed by
// Rank(sq,Black)-1.
Bitmap  Bearing::file_mask[8];
Bitmap  Bearing::rank_mask[8];
// Per-square rays along the square's file (lower / higher ranks) and rank
// (higher / lower file numbers); the square itself is not included.
Bitmap  Bearing::file_mask_down[64];
Bitmap  Bearing::file_mask_up[64];
Bitmap  Bearing::rank_mask_right[64];
Bitmap  Bearing::rank_mask_left[64];
// Per-square diagonal masks (square included): squares reached by repeated
// -7/+7 steps (a1 mask) and -9/+9 steps (a8 mask).
Bitmap  Bearing::diag_a1_mask[64];
Bitmap  Bearing::diag_a8_mask[64];
Bitmap  Bearing::black_squares;
Bitmap  Bearing::white_squares;
// whichAttack[occupancy byte][bit position] -> index into the pattern
// tables fileAttacks* / rankAttacks*.
byte Bearing::whichAttack[256][8];
// Bitmap (unrotated via Bitmap::unrot90) forms of fileAttacks,
// fileAttacksDown and fileAttacksUp.
Bitmap  Bearing::rankAttacks[64];
Bitmap  Bearing::rankAttacksDown[64];
Bitmap  Bearing::rankAttacksUp[64];
// Distinct 8-bit attack patterns; init() asserts there are fewer than 64.
byte  Bearing::fileAttacks[64];
byte  Bearing::fileAttacksDown[64];
byte  Bearing::fileAttacksUp[64];
// Concatenated per-diagonal pattern lists. Each DiagInfoA1 / DiagInfoA8
// entry points at the start of its diagonal's list.
Bitmap Bearing::diagA1Attacks[DIAG_ATTACKS_LEN];
Bitmap Bearing::diagA8Attacks[DIAG_ATTACKS_LEN];
// [file index][color]: squares on rank 4 (as computed by Square(...,4,color))
// of the neighbouring files.
Bitmap Bearing::ep_mask[8][2];

// Directions[from][to]: signed step that moves from "from" toward "to"
// along a shared rank, file or diagonal; 0 when not aligned.
signed char Bearing::Directions[64][64];
// Squares sharing a rank or file with the index square / squares reached
// from it by ±7 or ±9 steps.
Bitmap Bearing::rank_file_attack[64];
Bitmap Bearing::diag_attack[64];

// Appends "sq" to the list "squares" and increments NumSquares, which holds
// the current list length. The ASSERTs check that sq is on the board and
// that the list has not yet reached Bearing::MaxBearSq entries.
static inline void
SetSquare(const Square sq, Square * squares, unsigned &NumSquares)
{
   ASSERT(OnBoard(sq));
   ASSERT(NumSquares < Bearing::MaxBearSq);
   const unsigned slot = NumSquares;
   squares[slot] = sq;
   NumSquares = slot + 1;
}

int Bearing::is_pinned(const Board & board, ColorType kingColor,
                       Piece p,Square source,Square dest)
{
   if (p == EmptyPiece || (TypeOfPiece(p) == King && PieceColor(p) == kingColor))
      return 0;
   int dir = Bearing::Directions[source][board.KingPos(kingColor)];
   if (dir == 0) return 0;

   int attackerFound = 0;
   const ColorType oside = OppositeColor(kingColor);
   switch (dir)
   {
   case 1:
      // Find attacks on the rank
      {
         attackerFound = !Bitmap::And(RankAttacksLeft(board,source),
            board.rooksqueens[oside]).is_clear();
      }
      break;
   case -1:
      {
         attackerFound = !Bitmap::And(RankAttacksRight(board,source),
            board.rooksqueens[oside]).is_clear();
      }
      break;
   case 8:
      {
         attackerFound = !Bitmap::And(FileAttacksUp(board,source),board.rooksqueens[oside]).is_clear();
      }
      break;
   case -8:
      {
         attackerFound = !Bitmap::And(FileAttacksDown(board,source),board.rooksqueens[oside]).is_clear();
      }
      break;
   case 7:
      {
         Bitmap b(DiagAtcksA1Upper(board,source));
         b.And(board.bishopsqueens[oside]);
         attackerFound = !b.is_clear();
      }
      break;
   case -7:
      {
        Bitmap b(DiagAtcksA1Lower(board,source));
        b.And(board.bishopsqueens[oside]);
        attackerFound = !b.is_clear();
      }
      break;
   case 9:
      {
         Bitmap b(DiagAtcksA8Upper(board,source));
         b.And(board.bishopsqueens[oside]);
         attackerFound = !b.is_clear();
      }
      break;
   case -9:
      {
         Bitmap b(DiagAtcksA8Lower(board,source));
         b.And(board.bishopsqueens[oside]);
         attackerFound = !b.is_clear();
      }
      break;
   default:  
      break;
   }
   if (attackerFound)
   {
      Square ks = board.KingPos(kingColor);
      // see if the path is clear to the king
      Square sq(source);
      do
      {
         sq += dir;
         if (!OnBoard(sq) || ((sq != ks) && board[sq] != EmptyPiece)) {
	   return 0;
         }
      } while (sq != ks);
      // This code checks for moves in the direction of the pin,
      // which are ok:
      int dir1 = Util::Abs((int) source - (int) dest);
      int dir2 = Util::Abs(dir);
      if (dir2 == 1 && Rank(source,White) != Rank(dest,White))
         return 1;
      else if (dir2 == RankIncr && File(source) != File(dest))
	 return 1;
      else if (dir1 % dir2)
	 return 1;
   }
   return 0;
}

// Expands an 8-bit rotated pattern into a board Bitmap. Each bit k (0..7)
// set in "b" turns on square Bitmap::unrot90[k] in the result.
static Bitmap do_unrot90(byte b)
{
   Bitmap rotated((uint64)b);
   Bitmap result;
   int bit = 0;
   while (bit < 8) {
      if (rotated.is_set(bit))
         result.set(Bitmap::unrot90[bit]);
      ++bit;
   }
   return result;
}

// Bearing::init - fills every static lookup table of Bearing.
//
// Pass 1 (per square): file/rank masks, one-directional ray masks, square
// colours, knight/king/pawn attack sets, the two diagonal masks through the
// square, the Directions[][] step table and rank_file_attack.
// Pass 2: diag_attack, derived from Directions.
// Pass 3: the 8-bit sliding-attack patterns (whichAttack, fileAttacks*,
// rankAttacks*).
// Pass 4: the per-diagonal pattern lists diagA1Attacks / diagA8Attacks,
// the DiagInfoA1 / DiagInfoA8 descriptor of every square, and ep_mask.
// NOTE(review): bitmaps are only ever set(), never cleared, so this is
// presumably meant to run once before any table is read - confirm callers.
void Bearing::init()
{
   // Number of distinct 8-bit attack patterns recorded so far.
   int    attacksCount = 0;
   int i,j;
   for (i = 0; i < 64; i++)
   {
      Square sq = i;
      // Whole-file mask, indexed by File(sq)-1.
      file_mask[File(sq)-1].set(i);
      // Ray masks, sq excluded: lower / higher ranks on sq's file (ranks as
      // numbered by Rank(sq,White)) ...
      for (j=Rank(sq,White)-1;j>=1;--j) {
         file_mask_down[sq].set(Square(File(sq),j,White));
      }
      for (j=Rank(sq,White)+1;j<=8;++j) {
         file_mask_up[sq].set(Square(File(sq),j,White));
      }
      // ... and higher / lower file numbers on sq's rank.
      for (j=File(sq)+1;j<=8;++j) {
         rank_mask_right[sq].set(Square(j,Rank(sq,White),White));
      }
      for (j=File(sq)-1;j>=1;--j) {
         rank_mask_left[sq].set(Square(j,Rank(sq,White),White));
      }
      // Whole-rank mask. NOTE(review): indexed via Rank(sq,Black), unlike
      // the White-relative ray masks above - confirm this is intended.
      rank_mask[Rank(sq,Black)-1].set(i);
      if (SquareColor(sq) == White)
         white_squares.set(i);
      else
         black_squares.set(i);
      // Copy the knight and king target lists (at most 8 entries each;
      // a value of 255 ends a list early).
      const int *data = KnightSquares[i];
      j = 0;
      while (j < 8 && *data != 255)
      {
         knight_attacks[i].set(*data++);
         ++j;
      }
      const int *data2= KingSquares[i];
      j = 0;
      while (j < 8 && *data2 != 255)
      {
         king_attacks[i].set(*data2++);
         ++j;
      }
      // Pawn attack squares: White gets +7/+9, Black gets -9/-7. Offsets
      // that would wrap across file 1 or file 8 are skipped, as are
      // indices outside 0..63.
      if (File(sq) != 1)
      {
         if (i-9>=0) pawn_attacks[i][Black].set(i-9);
         if (i+7<=63) pawn_attacks[i][White].set(i+7);
      }
      if (File(sq) != 8)
      {
         if (i-7>=0) pawn_attacks[i][Black].set(i-7);
         if (i+9<=63) pawn_attacks[i][White].set(i+9);
      }
      // Diagonal masks through sq (sq included): repeated -7/+7 steps for
      // diag_a1_mask and -9/+9 steps for diag_a8_mask. The File() tests
      // stop each walk at the board edge.
      Square sq2 = sq;
      diag_a1_mask[sq].set(sq);
      diag_a8_mask[sq].set(sq);
      while (sq2 >= 7 && File(sq2) != 8)
      {
         sq2 = sq2 - 7;
         diag_a1_mask[sq].set(sq2);
      }
      sq2 = sq;
      while (sq2 <= 63-7 && File(sq2) != 1)
      {
         sq2 = sq2 + 7;
         diag_a1_mask[sq].set(sq2);
      }
      sq2 = sq;
      while (sq2 >= 9 && File(sq2) != 1)
      {
         sq2 = sq2 - 9;
         diag_a8_mask[sq].set(sq2);
      }
      sq2 = sq;
      while (sq2 <= 63-9 && File(sq2) != 8)
      {
         sq2 = sq2 + 9;
         diag_a8_mask[sq].set(sq2);
      }
      // Directions[i][j]: signed step that walks from square i toward
      // square j. The step is RankIncr for a shared file, 1 for a shared
      // rank, and RankIncr+1 or RankIncr-1 for squares of the same colour
      // whose index distance is divisible by 9 or 7 (diagonal neighbours).
      // The sign is positive when j > i. The entry is 0 for i == j and for
      // unaligned squares.
      for (j = 0; j < 64; j++)
      {
         if (i == j)
         {
            Directions[i][j] = 0;
            continue;
         }
         Square sq1 = i;
         // Note: this sq2 shadows the sq2 used for the diagonal masks.
         Square sq2 = j;
         int offset;
         int abs = (sq2 > sq1) ? (int)sq2 - (int)sq1 : (int)sq1 - (int)sq2;
         if (File(sq1) == File(sq2))
            offset = RankIncr;
         else if (Rank(sq1,White) == Rank(sq2,White))
            offset = 1;
         else if (SquareColor(sq1) == SquareColor(sq2))
         {
            if (abs % (RankIncr+1) == 0)
               offset = RankIncr+1;
            else if (abs % (RankIncr-1) == 0)
               offset = RankIncr-1;
            else
            {
               Directions[i][j] = 0;
               continue;
            }
         }
         else
         {
            Directions[i][j] = 0;
            continue;
         }
         Directions[i][j] = (sq2 > sq1) ? offset : -offset;
         // rank_file_attack[i]: every other square on i's file or rank.
         if (File(i) == File(j) ||
             Rank(i,White) == Rank(j,White))
            rank_file_attack[i].set(j);
      }
   }
   // diag_attack[i]: squares j that i reaches along a ±7 or ±9 step.
   for (i=0; i<64; i++)
      for (j=0;j<64;j++)
         if (Directions[i][j]==9 ||
             Directions[i][j]==-9 ||
             Directions[i][j]==7 ||
             Directions[i][j]==-7)
            diag_attack[i].set(j);
   // 8-bit sliding attack tables. For every occupancy byte i and position
   // off (0..7), scan toward bit 7 and toward bit 0 and stop at the first
   // occupied bit. rb receives the bit found scanning upward (stored in
   // fileAttacksDown); lb receives the bit found scanning downward (stored
   // in fileAttacksUp); b = lb | rb (stored in fileAttacks). Each distinct
   // (b, lb, rb) triple gets one slot, and whichAttack[i][off] records the
   // slot index.
   // NOTE(review): b/lb/rb receive only the first occupied bit in each
   // direction; the empty bits go into s instead. s and m are never read
   // afterwards - confirm this is the intended table content.
   int m;
   for (i=0; i<256; i++)
   {
      for (int off=0; off <8; off++)
      {
         byte b,lb,rb,s;
         b = lb = rb = s = 0;
         m = 0;
         for (j=off+1;j<8;j++)
         {
            if (i & (1<<j))
            {
               b |= (1<<j);
               rb |= (1<<j);
               break;
            }
            else
	      s |= (1<<j);
            ++m;
         }
         for (j=off-1;j>=0;--j)
         {
            if (i & (1<<j))
            {
               b |=(1<<j);
               lb |= (1<<j);
               break;
            }
            else
	      s |= (1<<j);
            ++m;
         }
         // Reuse an existing slot if this exact triple was seen before.
         int found = 0;
	 int xx;
         for (xx = 0; xx < attacksCount; xx++)
         {
            if (fileAttacks[xx] == b &&
                fileAttacksUp[xx] == lb &&
                fileAttacksDown[xx] == rb)
            {
               whichAttack[i][off] = xx;
               ++found;
               break;
            }
         }
         if (!found)
         {
            // The pattern tables are declared with 64 entries.
            ASSERT(attacksCount < 64);
            whichAttack[i][off] = attacksCount;
             fileAttacks[attacksCount] = b;
            fileAttacksDown[attacksCount] = rb;
            fileAttacksUp[attacksCount] = lb;
            attacksCount++;
         }
      }
   }
   // Board-Bitmap forms of every pattern, unrotated via Bitmap::unrot90.
   for (i = 0; i < attacksCount; i++)
   {
      rankAttacks[i] = do_unrot90(fileAttacks[i]);
      rankAttacksDown[i] = do_unrot90(fileAttacksDown[i]);
      rankAttacksUp[i] = do_unrot90(fileAttacksUp[i]);
   }
   // populate diag attack vectors & related info
   // First group of "a1" diagonals: the one stored in byte i starts at
   // bit 0 and has length i+1. Every known pattern that fits inside the
   // diagonal's mask is appended to diagA1Attacks, mapped back to board
   // squares through Bitmap::unrot_da1.
   // NOTE(review): unlike the next loop, this one does not ASSERT that d
   // stays below diagA1Attacks + DIAG_ATTACKS_LEN.
   Bitmap *d = diagA1Attacks;
   for (i = 0; i < 8; i++)
   {
      int len = i+1;
      Bitmap *dstart = d;
      int mask = (1<<len)-1;
      int xx;
      for (xx=0; xx <attacksCount; xx++) {
         byte b = fileAttacks[xx];
         
         if (((int)b & ~mask) == 0)
         {
            // "b" is a possible value for the attacks for
            // this square
            Bitmap unrot;
            Bitmap bits((uint64)b);
            Square sq;
            while (bits.iterate(sq)) {
               ASSERT(i*8+sq<64);
               unrot.set(Bitmap::unrot_da1[i*8+sq]);
            }
            *d++ = unrot;
         }
      }
      // Record, for each square on this diagonal, where its bits live
      // (byte_shift, bit_shift, mask, bit), the start of its pattern list,
      // and its ray masks: upperMask = repeated -7 steps, lowerMask =
      // repeated +7 steps.
      for (int j = 0; j<len; j++) {
         Square sq = Bitmap::unrot_da1[i*8+j];
         DiagInfo2 *diagInfo =
          &DiagInfoA1[sq];
         diagInfo->byte_shift = i;
         diagInfo->bit_shift = 0;
         diagInfo->mask = mask;
         diagInfo->bit = j;
         diagInfo->attacks = dstart;
         Square sq2 = sq;
         while (File(sq2) != 8 && ((int)sq2 - 7) >= 0) {
           diagInfo->upperMask.set(sq2-7); sq2-=7;
         }
         sq2 = sq;
         while (File(sq2) != 1 && ((int)sq2 + 7 <= 63)) {
           diagInfo->lowerMask.set(sq2+7); sq2+=7;
         }
      }
   }
   // Second group of "a1" diagonals: the one stored in byte i starts at
   // bit i+1 and has length 7-i (i == 7 gives an empty diagonal with a
   // single all-zero pattern).
   for (i = 0; i < 8; i++)
   {
      int len = 7-i;
      Bitmap *dstart = d;
      int mask = (1<<len)-1;
      for (int xx=0; xx< attacksCount; xx++)
      {
         byte b = fileAttacks[xx];
         
         if (((int)b & ~mask) == 0)
         {
            // "b" is a possible value for the attacks for
            // this square
            Bitmap unrot;
            Bitmap bits((uint64)b);
            Square sq;
            while (bits.iterate(sq)) {
               unrot.set(Bitmap::unrot_da1[i*9+1+sq]);
            }
            *d++ = unrot;
            ASSERT(d-diagA1Attacks < DIAG_ATTACKS_LEN);
         }
      }
      for (int j = 0; j<len; j++)
      {
         ASSERT(i*9+j+1<64);
         Square sq = Bitmap::unrot_da1[i*9+j+1];
         DiagInfo2 *diagInfo = &DiagInfoA1[sq];
         diagInfo->byte_shift = i;
         diagInfo->bit_shift = i+1;
         diagInfo->mask = mask;
         diagInfo->bit = j;
         diagInfo->attacks = dstart;
         Square sq2 = sq;
         while (File(sq2) != 8 && ((int)sq2 - 7) >= 0) {
           diagInfo->upperMask.set(sq2-7); sq2-=7;
         }
         sq2 = sq;
         while (File(sq2) != 1 && ((int)sq2 + 7 <= 63)) {
           diagInfo->lowerMask.set(sq2+7); sq2+=7;
         }
      }
      // ep_mask for file r = i+1: the rank-4 squares (as computed by
      // Square(...,4,color)) on files r-1 and r+1, when those files exist.
      // It only shares this loop to reuse the index i.
      int r = i+1; 
      if (r!=1) {
        ep_mask[i][Black].set(Square(r-1,4,Black));
        ep_mask[i][White].set(Square(r-1,4,White));
      }
      if (r!=8) {
        ep_mask[i][Black].set(Square(r+1,4,Black));
        ep_mask[i][White].set(Square(r+1,4,White));
      }
   }
   // Same construction for the "a8" diagonals: Bitmap::unrot_da8 for the
   // square mapping, DiagInfoA8 for the descriptors, and ±9 ray steps
   // (upperMask = repeated -9 steps, lowerMask = repeated +9 steps).
   d = diagA8Attacks;
   for (i = 0; i < 8; i++)
   {
      int len = i+1;
      Bitmap *dstart = d;
      int mask = (1<<len)-1;
      for (int xx=0; xx< attacksCount; xx++)
      {
         byte b = fileAttacks[xx];
         
         if (((int)b & ~mask) == 0)
         {
            // "b" is a possible value for the attacks for
            // this square
            Bitmap unrot;
            Bitmap bits((uint64)b);
            Square sq;
            while (bits.iterate(sq)) {
               ASSERT(i*8+sq<64);
               unrot.set(Bitmap::unrot_da8[i*8+sq]);
            }
            *d++ = unrot;
         }
      }
      for (int j = 0; j<len; j++)
      {
         ASSERT(i*8+j<64);
         Square sq = Bitmap::unrot_da8[i*8+j];
         DiagInfo2 *diagInfo =
          &DiagInfoA8[sq];
         diagInfo->byte_shift = i;
         diagInfo->bit_shift = 0;
         diagInfo->mask = mask;
         diagInfo->bit = j;
         diagInfo->attacks = dstart;
         Square sq2 = sq;
         while (File(sq2) != 1 && ((int)sq2 - 9 >= 0)) {
           diagInfo->upperMask.set(sq2-9); sq2-=9;
         }
         sq2 = sq;
         while (File(sq2) != 8 && ((int)sq2 + 9 <= 63)) {
           diagInfo->lowerMask.set(sq2+9); sq2+=9;
         }
      }
   }
   for (i = 0; i < 8; i++)
   {
      int len = 7-i;
      Bitmap *dstart = d;
      int mask = (1<<len)-1;
      for (int xx=0; xx< attacksCount; xx++)
      {
         byte b = fileAttacks[xx];
         
         if (((int)b & ~mask) == 0)
         {
            // "b" is a possible value for the attacks for
            // this square
            Bitmap unrot;
            Bitmap bits((uint64)b);
            Square sq;
            while (bits.iterate(sq)) {
               unrot.set(Bitmap::unrot_da8[i*9+1+sq]);
            }
            *d++ = unrot;
            ASSERT(d-diagA8Attacks < DIAG_ATTACKS_LEN);
         }
      }
      for (int j = 0; j<len; j++)
      {
         ASSERT(i*9+j+1<64);
         Square sq = Bitmap::unrot_da8[i*9+j+1];
         DiagInfo2 *diagInfo = &DiagInfoA8[sq];
         diagInfo->byte_shift = i;
         diagInfo->bit_shift = i+1;
         diagInfo->mask = mask;
         diagInfo->bit = j;
         diagInfo->attacks = dstart;
         Square sq2 = sq;
         while (File(sq2) != 1 && ((int)sq2 - 9 >= 0)) {
           diagInfo->upperMask.set(sq2-9); sq2-=9;
         }
         sq2 = sq;
         while (File(sq2) != 8 && ((int)sq2 + 9 <= 63)) {
           diagInfo->lowerMask.set(sq2+9); sq2+=9;
         }
      }
   }
}
