# include <Game/Game.h>
# include <ProofNumberNode.h>
# include <Moves/Move.h>
# include <Moves/SanNotation.h>
# include <iostream>
# include <fstream>
# include <stack>
# include <cassert>
# include <CpuTimer.h>
# include <TestData/TestPositionSuite.h>
# include <iterator>
# include <cstdlib>
# include <string>
# include <vector>
using namespace Alice;

// Colour that is trying to deliver mate; main() sets it to the side to move
// of each test position.  setDraw() scores a draw against this colour.
Color* winner = 0;
// Depth limit in plies: evaluate() treats nodes deeper than this as draws.
// main() derives it per position from the EPD "ce" operand (default 200);
// ProofNumberSquaredSearch::evaluate lowers it temporarily for nested searches.
int bound = 200;

/**
 * Proof-number search over a chess game tree, used to prove or disprove a
 * forced mate for the global `winner` colour within `bound` plies.
 *
 * The tree hangs off `root`, which this object allocates in its constructor
 * and deletes in its destructor.  Each expand() call grows the tree by the
 * children of one leaf; isDone() reports when the root is solved or the node
 * budget is exhausted.
 */
class ProofNumberSearch
{
public:
  enum Result{PROVEN, DISPROVEN, UNKNOWN};

  /// @param game   position to search.
  /// @param mNodes node budget, compared in isDone() against
  ///               ProofNumberNode::getNumberOfInstances().
  ProofNumberSearch(Game& game, long mNodes = 10000000);
  virtual ~ProofNumberSearch();

  /// Number of pseudo-legal moves that do not leave the mover in check.
  int countLegalMoves( Game& g );
  /// Descends from currentNode via getChildToExpand(), making each move on g.
  void findCurrentNode( Game& g );
  /// Scores node as a draw: proven if the side to move is not `winner`,
  /// disproven otherwise.
  void setDraw(Game& g, ProofNumberNode* node);
  /// Takes back currentNode's move, pops to the parent and re-adjusts it.
  void backup(Game& g);
  /// Sets the initial proof/disproof numbers of a freshly created child.
  virtual void evaluate(Game& g, ProofNumberNode* child);
  /// Appends the legal moves of the position in g to moves.
  void legalMoves(Game& g, MoveList& moves);
  /// One iteration: select a leaf, expand it, evaluate children, back up.
  void expand(Game& g);
  /// Prints the solution line(s) below node in SAN to std::cout.
  void printSolution(Game& g, bool pvOnly, ProofNumberNode* node, 
		     int moveNumber = 2);
  /// Prints the solution line(s) starting at the root.
  void printSolution(Game& g, bool pvOnly)
  {
    printSolution(g, pvOnly, root, 2);
  }
  bool isDone() const;
  Result getResult() const;
  int lengthToMate() const;
public:
  /// Root of the search tree; owned (deleted) by this object.
  ProofNumberNode* root;
private:
  // Non-copyable: an implicit copy would share `root`, and both destructors
  // would delete the same tree.  Declared but never defined (pre-C++11 idiom).
  ProofNumberSearch(const ProofNumberSearch&);
  ProofNumberSearch& operator=(const ProofNumberSearch&);

  // NOTE(review): stored but not read by any method in this file; every
  // method takes its own `Game& g` parameter, which shadows this member.
  Game& g;
  /// Node the search is currently positioned at (root between iterations).
  ProofNumberNode* currentNode;
  /// Ancestors of currentNode, innermost on top.
  std::stack<ProofNumberNode*> nodeStack;
  long maxNodes;
};

bool
ProofNumberSearch::isDone() const
{
  if (root->getProofNumber() == 0)
    return true;
  if (root->getDisproofNumber() == 0)
    return true;
  return ProofNumberNode::getNumberOfInstances() > maxNodes;
};

/**
 * Creates a search whose tree is a single, freshly allocated root node.
 * @param game   position to search (kept by reference).
 * @param mNodes node budget used by isDone().
 */
ProofNumberSearch::ProofNumberSearch( Game& game, long mNodes )
  : root( new ProofNumberNode ),
    g( game ),
    currentNode( root ),
    maxNodes( mNodes )
{
}

/**
 * Releases the search tree.  Freeing of the children is presumably handled
 * by ProofNumberNode's own destructor (TODO confirm).
 */
ProofNumberSearch::~ProofNumberSearch()
{
  ProofNumberNode* tree = root;
  delete tree;
}

/**
 * Counts the pseudo-legal moves in g that do not leave the moving side in
 * check.  Each candidate is made and taken back, so g is unchanged on return.
 */
int
ProofNumberSearch::countLegalMoves(Game& g)
{
  MoveList pseudo;
  g.pseudoLegalMoves(pseudo);

  int count = 0;
  MoveList::iterator it = pseudo.begin();
  while (it != pseudo.end())
    {
      (*it)->makeOn(g);
      // After the move the opponent is to move, so otherColor() is the mover.
      const bool leavesKingInCheck =
        g.isInCheck(g.colorToMove()->otherColor());
      (*it)->takeBackOn(g);
      if (!leavesKingInCheck)
        ++count;
      ++it;
    }
  return count;
}

/**
 * Walks down from currentNode to an unexpanded node, always following
 * getChildToExpand().  Each ancestor is pushed on nodeStack and each incoming
 * move is made on g, so g ends up in the position of the selected leaf.
 */
void 
ProofNumberSearch::findCurrentNode(Game& g)
{
  for (; currentNode->isExpanded(); )
    {
      nodeStack.push(currentNode);
      const int next = currentNode->getChildToExpand();
      currentNode = currentNode->getChild(next);
      SmartPointer<Move> incoming = currentNode->getIncomingMove();
      incoming->makeOn(g);
    }
}

/**
 * Scores node as a draw.  A node's proof number appears to track whether the
 * side to move there wins: a draw is a success for the defender and a failure
 * for `winner`.  So the node is disproven when `winner` is to move and proven
 * otherwise.
 */
void
ProofNumberSearch::setDraw(Game& g, ProofNumberNode* node)
{
  if (g.colorToMove() == winner)
    node->setDisproven();
  else
    node->setProven();
}

/**
 * Moves one level up the tree.  It takes back the move that led into
 * currentNode, makes the node on top of nodeStack current, and re-derives
 * that node's proof/disproof numbers from its children.  The caller must
 * ensure nodeStack is not empty.
 */
void
ProofNumberSearch::backup(Game& g)
{
  SmartPointer<Move> incoming = currentNode->getIncomingMove();
  incoming->takeBackOn(g);

  ProofNumberNode* parent = nodeStack.top();
  nodeStack.pop();
  currentNode = parent;
  currentNode->adjustNumbers();
}

/**
 * Sets the initial numbers of a new child whose move has been made on g.
 *
 * - A repeated position, or a node deeper than `bound`, is scored as a draw.
 * - Otherwise the disproof number becomes the number of legal moves.
 * - With no legal moves, the node is disproven if the side to move is in
 *   check (checkmate); if not, it is scored as a draw (stalemate).
 */
void
ProofNumberSearch::evaluate(Game& g, ProofNumberNode* child)
{
  assert(child == currentNode);

  const bool cutOff = g.isRepeated() || child->getHeight() > bound;
  if (cutOff)
    {
      setDraw(g, child);
      return;
    }

  const int legalCount = countLegalMoves(g);
  child->setDisproofNumber(legalCount);
  if (legalCount == 0)
    {
      const bool checkmated = g.isInCheck(g.colorToMove());
      if (checkmated)
        child->setDisproven();
      else
        setDraw(g, child);
    }
}

/**
 * Appends to moves every pseudo-legal move in g that does not leave the
 * moving side in check.  Existing entries of moves are kept.  g is unchanged
 * on return because each candidate is made and taken back.
 */
void
ProofNumberSearch::legalMoves(Game& g, MoveList& moves)
{
  MoveList candidates;
  g.pseudoLegalMoves(candidates);
  for (MoveList::iterator candidate = candidates.begin();
       candidate != candidates.end(); ++candidate)
    {
      (*candidate)->makeOn(g);
      const bool exposesKing = g.isInCheck(g.colorToMove()->otherColor());
      (*candidate)->takeBackOn(g);
      if (!exposesKing)
        moves.push_back(*candidate);
    }
}

/**
 * Performs one iteration of the search.
 *
 * 1. Descends from the root to the leaf selected by findCurrentNode(),
 *    making the moves on g along the way.
 * 2. Creates one child per legal move of that leaf and lets evaluate()
 *    initialise each child.
 * 3. Re-adjusts the leaf, then backs up to the root, taking every move back
 *    so g is in its entry position again.
 */
void 
ProofNumberSearch::expand(Game& g)
{
  findCurrentNode(g);

  MoveList moves;
  legalMoves(g, moves);
  currentNode->expand(moves.size());

  // First pass: attach each legal move to the child created for it.
  int childCount = 0;
  for (MoveList::iterator it = moves.begin(); it != moves.end(); ++it)
    currentNode->getChild(childCount++)->setIncomingMove(*it);

  // Second pass: visit each child (make its move, evaluate, take it back).
  // currentNode is temporarily moved onto the child because evaluate()
  // expects (and the base version asserts) child == currentNode.
  for (int i = 0; i < childCount; ++i)
    {
      nodeStack.push(currentNode);
      currentNode = currentNode->getChild(i);
      SmartPointer<Move> move = currentNode->getIncomingMove();
      move->makeOn(g);
      evaluate(g, currentNode);
      move->takeBackOn(g);
      currentNode = nodeStack.top();
      nodeStack.pop();
    }

  currentNode->adjustNumbers();
  // Walk back to the root, undoing the moves made by findCurrentNode().
  while (!nodeStack.empty())
    backup(g);
}

/**
 * Prints the solution below node to std::cout in SAN.
 *
 * After node->sort(), the children whose disproof number equals node's proof
 * number are collected.  The first one is printed and followed recursively.
 * Unless pvOnly is set, the others are printed first as parenthesised
 * variations.  Moves are made and taken back on g, so g is unchanged on
 * return.  moveNumber is passed down the recursion but not printed.
 */
void
ProofNumberSearch::printSolution(Game& g,  bool pvOnly, ProofNumberNode* node,
	      int moveNumber)
{
  if (!node->isExpanded())
    return;

  node->sort();
  std::vector<ProofNumberNode*> bestChildren;
  for (int idx = 0; idx < node->getNumberOfChildren(); ++idx)
    {
      ProofNumberNode* candidate = node->getChild(idx);
      if (candidate->getDisproofNumber() == node->getProofNumber())
        bestChildren.push_back(candidate);
    }
  if (bestChildren.empty())
    return;

  ProofNumberNode* principal = bestChildren.front();
  SmartPointer<Move> principalMove = principal->getIncomingMove();
  SanNotation principalSan(principalMove, g);
  std::cout<<principalSan<<" ";

  if (!pvOnly)
    {
      // Every further best child becomes a parenthesised side variation.
      for (std::vector<ProofNumberNode*>::size_type k = 1;
           k < bestChildren.size(); ++k)
        {
          ProofNumberNode* alternative = bestChildren[k];
          std::cout<<"( ";
          SmartPointer<Move> altMove = alternative->getIncomingMove();
          SanNotation altSan(altMove, g);
          std::cout<<altSan<<" ";
          altMove->makeOn(g);
          printSolution(g, pvOnly, alternative, moveNumber+1);
          altMove->takeBackOn(g);
          std::cout<<") ";
        }
    }

  principalMove->makeOn(g);
  printSolution(g, pvOnly, principal, moveNumber+1);
  principalMove->takeBackOn(g);
}

/**
 * Outcome so far: PROVEN when the root's proof number is 0, DISPROVEN when
 * its disproof number is 0, UNKNOWN otherwise.
 */
ProofNumberSearch::Result
ProofNumberSearch::getResult() const
{
  return root->getProofNumber() == 0    ? PROVEN
       : root->getDisproofNumber() == 0 ? DISPROVEN
       :                                  UNKNOWN;
}

/**
 * Mate length reported by the root node.  main() converts it to moves with
 * (n+1)/2, which suggests the value is in plies (TODO confirm).
 */
int
ProofNumberSearch::lengthToMate() const
{
  const int plies = root->lengthToMate();
  return plies;
}

/**
 * Two-level proof-number search.  evaluate() is overridden so that, when
 * `winner` is to move at a new child, a small nested ProofNumberSearch runs
 * from that child to give it better initial numbers.
 */
class ProofNumberSquaredSearch
  : public ProofNumberSearch
{
public:
  /// @param mNodes node budget of this (outer) search.
  ProofNumberSquaredSearch(Game& g, long mNodes);
  virtual ~ProofNumberSquaredSearch();

  /// Overrides the virtual ProofNumberSearch::evaluate.
  virtual void evaluate(Game& g, ProofNumberNode* child);
};

/// Forwards the game and the outer node budget to the base search.
ProofNumberSquaredSearch::ProofNumberSquaredSearch(Game& g, long mNodes)
  : ProofNumberSearch(g, mNodes) {}

/// Owns nothing beyond the base class; the base destructor frees the tree.
ProofNumberSquaredSearch::~ProofNumberSquaredSearch() {}

/**
 * Two-level child evaluation.
 *
 * With the defender to move, this defers to ProofNumberSearch::evaluate.
 * With `winner` to move, it runs a nested ProofNumberSearch from the child's
 * position:
 * - The nested search's node budget is the current global node count plus 50.
 * - For its duration, the global `bound` is lowered by the child's height,
 *   then restored.
 *
 * A solved nested tree is grafted into child via swap() and re-adjusted.
 * Otherwise only the nested root's proof/disproof numbers are copied, and the
 * nested tree is discarded with `search`.
 */
void
ProofNumberSquaredSearch::evaluate(Game& g, ProofNumberNode* child)
{
  if (g.colorToMove() == winner)
    {
      ProofNumberSearch search(g, ProofNumberNode::getNumberOfInstances() + 50 );
      const int savedBound = bound;
      bound -= child->getHeight();
      while (!search.isDone())
        search.expand(g);
      bound = savedBound;

      if (search.getResult() != ProofNumberSearch::UNKNOWN)
        {
          // NOTE(review): relies on swap() exchanging the subtree/numbers
          // while child keeps its own incoming move and height — confirm.
          child->swap(*search.root);
          child->adjustNumbers();
        }
      else
        {
          child->setProofNumber(search.root->getProofNumber());
          child->setDisproofNumber(search.root->getDisproofNumber());
        }
    }
  else
    ProofNumberSearch::evaluate(g, child);
}

/**
 * Runs the two-level mate solver on every position of an EPD test suite.
 *
 * Usage: <program> <epd-file>
 *
 * For each position it does the following:
 * - Sets the global depth bound from the "ce" operand (32767 - ce, or 200 if
 *   the operand is absent).
 * - Sets the global `winner` to the side to move.
 * - Searches until the position is solved or the node budget is spent.
 *
 * Positions are copied into three side files next to the input:
 * - "<file>.notSolved": the budget ran out.
 * - "<file>.incorrect": no mate was found within the bound (disproven).
 * - "<file>.shorter": a mate shorter than the bound was found.
 *
 * Progress, solutions and timing estimates are printed to std::cout.
 *
 * @return 0 after processing the suite; 1 if no EPD file was given.
 */
int main(int argc, char* argv[])
{
  // argv[1] is a null pointer when no argument is given; building a
  // std::string from it would be undefined behaviour.
  if (argc < 2)
    {
      std::cerr<<"usage: "<<(argc > 0 ? argv[0] : "solver")
               <<" <epd-file>"<<std::endl;
      return 1;
    }

  // Node budget for the top-level search of each position.
  const long nodesPerPosition = 200000;

  int solved = 0;
  TestPositionSuite suite(argv[1]);
  std::string baseName(argv[1]);
  std::string incorrectName = baseName + ".incorrect";
  std::string shorterName =  baseName + ".shorter";
  std::string notSolvedName =  baseName + ".notSolved";
  std::ofstream incorrectFile(incorrectName.c_str());
  std::ofstream shorterFile(shorterName.c_str());
  std::ofstream notSolvedFile(notSolvedName.c_str());
  CpuTimer totalTime;
  for (int number = 0; number < suite.numberOfPositions(); number++)
    {
      std::cout<<number+1<<" of "<<suite.numberOfPositions()<<std::endl;

      Game g;
      // The "ce" operand presumably holds a mate score of the form
      // 32767 - plies; the difference becomes the search depth bound.
      std::string boundString = suite.getPosition(number).getOperand("ce");
      if (boundString != "")
        bound = 32767 - std::atoi(boundString.c_str());
      else
        bound = 200;
      std::cout<<"bound: "<<bound<<std::endl;
      g.forsytheString(suite.getPosition(number).forsythe());
      if (suite.getPosition(number).getColorToMove() == "b")
        g.doNullMove();
      winner = g.colorToMove();
      ProofNumberSquaredSearch algorithm( g, nodesPerPosition );
      CpuTimer timer;
      while (! algorithm.isDone())
        algorithm.expand(g);
      if ( algorithm.getResult() == ProofNumberSearch::UNKNOWN )
        {
          std::cout<<timer<<" no solution found"<<std::endl;
          suite.getPosition(number).writeEPD(notSolvedFile);
        }
      else
        {
          solved++;
          std::cout<<std::endl<<timer;
          if (algorithm.getResult() == ProofNumberSearch::DISPROVEN)
            {
              std::cout<<" No mate in ";
              suite.getPosition(number).writeEPD(incorrectFile);
              std::cout<<(bound+1)/2<<std::endl;
            }
          else
            {
              std::cout<<" found mate in "<<(algorithm.lengthToMate() +1)/2
                       << std::endl;
              if (algorithm.lengthToMate() < bound)
                suite.getPosition(number).writeEPD(shorterFile);
            }
          algorithm.printSolution(g, true);
          std::cout<<std::endl;
        }
      std::cout<<"solved "<<solved<<" of "
               <<number+1<<", "<<100*solved/(number+1)<<"%"<<std::endl;
      std::cout<<"average time per problem:    ";
      double averageTime = totalTime.age()/(number+1);
      Timer::printTime(std::cout, averageTime);
      std::cout<<std::endl;
      std::cout<<"expected time to completion: ";
      double remainingTime = averageTime *(suite.numberOfPositions()-1-number);
      Timer::printTime(std::cout, remainingTime);
      std::cout<<std::endl;
      std::cout<<"expected total time:         ";
      double total = totalTime.age() + remainingTime;
      Timer::printTime(std::cout, total);
      std::cout<<std::endl;
    }
  std::cout<<"total time: "<<totalTime<<std::endl;
  return 0;
}
