mfinzi/OMGchess

View on GitHub
chess/aMCTS.py

Summary

Maintainability
A
2 hrs
Test Coverage
from agent import ChessBoard, Agent, NNAgent
import threading
import queue
from torch.utils.data.dataloader import default_collate
import torch
import torch.nn.functional as F
import numpy as np
from concurrent import futures
import time
import copy

class NNevalQueue(queue.Queue):
    """Batch board evaluations from many threads into single network calls.

    Callers (e.g. concurrent MCTS simulations) call :meth:`evaluate`, which
    enqueues the board and blocks until a background daemon thread has run it
    through ``network`` as part of a minibatch of up to ``batch_size`` boards.
    """

    def __init__(self, network, batch_size=16):
        super().__init__()
        self.network = network
        # Maps each pending request's threading.Event to its result: a
        # (value, moveProbs) tuple, or the exception raised by its batch.
        self.outputTable = {}
        self.batch_size = batch_size
        self.num_dispatches = 0  # number of minibatches sent through the network
        # All state is initialised before the worker starts, so the worker
        # never sees a half-constructed object.
        self.worker = threading.Thread(target=self.work, daemon=True)
        self.worker.start()

    def evaluate(self, board):
        """Return ``(value, moveProbs)`` for ``board``, blocking until computed.

        Each request is keyed by its own ``threading.Event``. Identical boards
        submitted concurrently therefore never collide, and the caller sleeps
        on the event instead of busy-polling.

        Raises:
            RuntimeError: if evaluating the batch containing ``board`` failed
                (the original exception is chained as ``__cause__``).
        """
        ticket = threading.Event()
        self.put((ticket, board))
        ticket.wait()
        result = self.outputTable.pop(ticket)
        if isinstance(result, Exception):
            raise RuntimeError("batched network evaluation failed") from result
        return result

    def work(self):
        """Worker loop: gather up to ``batch_size`` queued boards and evaluate them."""
        while True:
            batch = [self.get()]  # block until at least one request arrives
            while len(batch) < self.batch_size:
                try:
                    batch.append(self.get_nowait())
                except queue.Empty:
                    break
            tickets = [ticket for ticket, _ in batch]
            try:
                results = self._evaluate_batch([board for _, board in batch])
            except Exception as err:
                # Keep the worker alive and hand the error to every waiter
                # instead of leaving them blocked forever.
                results = [err] * len(batch)
            for ticket, result in zip(tickets, results):
                self.outputTable[ticket] = result  # store before waking the caller
                ticket.set()
            self.num_dispatches += 1

    def _evaluate_batch(self, boards):
        """Run one forward pass over ``boards``; return a list of (value, moveProbs)."""
        encoded_mb = [self.network.encode(board) for board in boards]
        collated_mb = default_collate(encoded_mb)
        with torch.no_grad():  # inference only: don't build an autograd graph
            values, logits = self.network(*collated_mb)
            move_probs = F.softmax(logits, dim=1)
        values = values.detach().cpu().numpy()
        move_probs = move_probs.detach().cpu().numpy()
        return list(zip(values, move_probs))
            
                
            
            
        
class SearchNode(object):
    """A node of the PUCT search tree holding statistics for each candidate move.

    Per-edge statistics live in parallel numpy arrays indexed by a move's
    position in ``mv_ids``. ``Vs`` and ``Qs`` are scored from the point of view
    of the player to move at this node, so :meth:`select` always maximises.
    """
    C_PUCT = 2.5  # weight of the exploration term in PUCT
    EPSILON = 1e-5  # moves whose network probability is not above this are ignored

    def __init__(self, moveProbs):
        considered_mvs_mask = moveProbs > SearchNode.EPSILON
        self.mv_ids = np.arange(len(moveProbs))[considered_mvs_mask]  # move encodings
        # Priors are uniform over the considered moves; the network's moveProbs
        # only decide which moves are considered.
        # NOTE(review): confirm this is intended rather than moveProbs[considered_mvs_mask].
        self.Ps = np.ones(len(self.mv_ids)) / len(self.mv_ids)
        self.Ns = np.zeros(len(self.mv_ids))  # edge visit counts
        self.Vs = np.zeros(len(self.mv_ids))  # summed edge values
        self.Qs = np.zeros(len(self.mv_ids)) - 1  # Vs/Ns; unvisited edges start at -1
        self.children = [None] * len(self.mv_ids)

    @staticmethod
    def newchild_and_value(board, transposition_table, eval_queue):
        """Expand ``board`` into a new node and evaluate it.

        Returns ``(node, value)``. When the game is not over, ``node`` is a new
        SearchNode built from the network's move probabilities. ``value`` is
        the network value multiplied by +1/-1 according to ``board.turn``, so
        that it is scored for the side to move on ``board``. This assumes the
        network's ``value_abs`` is from a fixed (presumably White's) point of
        view, as its name suggests — TODO confirm.

        When the game is over, ``node`` is None and ``value`` is
        ``board.outcome()``.
        NOTE(review): this assumes ``outcome()`` uses the same
        side-to-move-relative scale; confirm against ChessBoard.

        ``transposition_table`` is accepted but not consulted yet.
        """
        if board.is_game_over():
            return None, board.outcome()
        value_abs, moveProbs = eval_queue.evaluate(board)
        child = SearchNode(moveProbs)
        color = board.turn * 2 - 1  # +1 when board.turn is truthy, else -1
        return child, value_abs * color

    def update_path(self, board, table, eval_queue):
        """Run one simulation from this node, mutating ``board`` along the path.

        The simulation descends by PUCT, expands the first unexpanded edge it
        reaches and backs the leaf value up the path. Returns this node's value
        from the perspective of its player to move, which is the same
        convention as the value returned by :meth:`newchild_and_value`.
        """
        i = self.select()
        # Count the visit before the value is backed up. Concurrent simulations
        # then see a smaller exploration bonus on this edge (virtual-loss style).
        self.Ns[i] += 1
        board.make_move(board.nn_decode_move(self.mv_ids[i]))
        child = self.children[i]
        if child is None:
            self.children[i], child_value = self.newchild_and_value(board, table, eval_queue)
        else:
            child_value = child.update_path(board, table, eval_queue)
        # child_value is scored for the player to move after this edge, i.e. the
        # opponent. Negate it so each ply maximises its own outcome (negamax).
        value = -child_value
        self.Vs[i] += value
        self.Qs[i] = self.Vs[i] / self.Ns[i]
        return value

    def select(self):
        """Return the index (into ``mv_ids``) of the edge maximising Q + C_PUCT * U."""
        sqrtSumN = np.sqrt(np.sum(self.Ns))
        Us = self.Ps * sqrtSumN / (1 + self.Ns)
        return np.argmax(self.Qs + SearchNode.C_PUCT * Us)


class MCTSAgent(Agent):
    """Agent that chooses moves by PUCT Monte Carlo tree search.

    Leaf positions are evaluated by ``network`` through a shared NNevalQueue,
    which batches requests coming from concurrent simulations (see ``think``).
    """

    def __init__(self, GameType, network, movetime=1, bs=12, num_threads=30):
        super().__init__(GameType)
        self.num_threads = num_threads
        self.movetime = movetime  # seconds of search per compute_action call
        self.trans_table = {}  # handed to SearchNode, which does not consult it yet
        network.eval()
        self.eval_queue = NNevalQueue(network, batch_size=bs)
        self.searchTree = self._new_tree()

    def _new_tree(self):
        """Return a fresh search node for the current board (None if the game is over)."""
        node, _ = SearchNode.newchild_and_value(
            self.board, self.trans_table, self.eval_queue)
        return node

    def make_action(self, move):
        """Play ``move`` and reuse the matching subtree of the search, if any.

        If ``move`` was pruned from the tree, or the tree is missing, a fresh
        tree is built for the resulting position instead of raising.
        """
        mv_id = self.board.nn_encode_move(move)  # encode before the board changes
        subtree = None
        if self.searchTree is not None:
            (matches,) = np.where(self.searchTree.mv_ids == mv_id)
            if matches.size:
                subtree = self.searchTree.children[matches[0]]
        super().make_action(move)
        self.searchTree = subtree if subtree is not None else self._new_tree()

    def run_simulation(self):
        """Run one simulation from the root on a private copy of the board."""
        simulation_board = copy.deepcopy(self.board)
        self.searchTree.update_path(simulation_board, self.trans_table, self.eval_queue)
        return 0

    def run_k_simulations(self, k=100):
        """Run ``k`` simulations sequentially."""
        for _ in range(k):
            self.run_simulation()

    def think(self, thinktime=1):
        """Keep ``num_threads`` simulations in flight for about ``thinktime`` seconds.

        Exceptions raised inside a simulation are re-raised here rather than
        being silently dropped with their futures.
        """
        start_time = time.time()
        with futures.ThreadPoolExecutor(self.num_threads) as exc:
            active = {exc.submit(self.run_simulation) for _ in range(self.num_threads)}
            while time.time() - start_time < thinktime:
                done, _ = futures.wait(active, timeout=.02,
                                       return_when=futures.FIRST_COMPLETED)
                for future in done:
                    future.result()  # surface errors from the simulation
                    active.remove(future)
                    active.add(exc.submit(self.run_simulation))
            done, _ = futures.wait(active)
            for future in done:
                future.result()

    def compute_action(self):
        """Search single-threaded for ``movetime`` seconds and return the chosen move.

        Picks the root edge with the highest mean value ``Qs`` and prints
        search statistics. ``think`` is the multi-threaded search entry point.
        """
        start_time = time.time()
        while time.time() - start_time < self.movetime:
            self.run_simulation()
        root = self.searchTree
        j = np.argmax(root.Qs)
        action = self.board.nn_decode_move(root.mv_ids[j])
        print("{} takes action {} with evals {}".format(['Black','White'][self.board.turn],action,root.Ns[j]))
        print("# of Nodes evaluated: {}".format(np.sum(root.Ns)))
        print("Yielding score: {}".format(root.Vs[j]/(root.Ns[j]+.00001)))
        print("Best scoring move was: _ with score {}".format(np.max(root.Vs/(root.Ns+.0001))))
        return action

    # def play_through_game()