mfinzi/OMGchess

View on GitHub
quarto.py

Summary

Maintainability
A
3 hrs
Test Coverage
import numpy as np
from numba import jit,njit,int32,float32,void,boolean,int64
from numba.experimental import jitclass
import matplotlib.pyplot as plt
import time
import copy
import math
from math import sqrt
import sys
from mcts import MCTS

spec = [
    ('array', int32[:,:]),      # 4x4 board: 0 = empty, otherwise piece id + 1
    ('pieces_used', int32[:]),  # nonzero at index p once piece p has been proposed
    ('num_moves_made', int32),  # starts at 1; odd = proposal turn, even = placement turn
    ('proposed_piece', int32),  # piece id waiting to be placed, or -1 when there is none
]
@jitclass(spec)
class QuartoBoard(object):
    """Game state for Quarto on a 4x4 board.

    The 16 pieces are identified by the integers 0..15; each of the four bits
    of an id is one binary attribute of the piece. A turn consists of two
    moves: a *proposal* (choose an unused piece for the opponent) followed by
    a *placement* (the opponent puts that piece on an empty square). Squares
    are addressed by a flat index ``s`` meaning ``array[s//4, s%4]``.

    ``num_moves_made`` starts at 1 so that its parity tells the move type
    (odd: proposal, even: placement) and ``color_to_move`` alternates every
    two moves starting from the placement.
    """

    def __init__(self):
        self.array = np.zeros((4,4),dtype=int32)
        self.pieces_used = np.zeros((16,),dtype=int32)
        self.num_moves_made = 1
        self.proposed_piece = -1

    def copy(self, otherboard):
        """Overwrite this board's state with an independent copy of `otherboard`."""
        self.array = np.copy(otherboard.array)
        # pieces_used must be copied as well, otherwise proposal moves
        # generated on the copy ignore the pieces already handed out.
        self.pieces_used = np.copy(otherboard.pieces_used)
        self.num_moves_made = otherboard.num_moves_made
        self.proposed_piece = otherboard.proposed_piece

    def get_moves(self):
        """Return the list of legal moves for the side to move.

        On a placement turn these are flat indices of empty squares; on a
        proposal turn they are ids of pieces that have not been used yet.
        """
        moves = []
        if not self.num_moves_made%2:
            for i in range(16): # Placement turn
                if not self.array[i//4,i%4]:
                    moves.append(i)
        else:
            # Proposal turn
            for i in range(16):
                if not self.pieces_used[i]:
                    moves.append(i)
        return moves

    def color_to_move(self):
        """Return -1 or +1 for the player making the next move.

        The value flips every two moves: the player who places a piece also
        proposes the next one.
        """
        return 2*((self.num_moves_made%4)//2)-1

    def make_move(self,i):
        """Play move `i` and return the mover's color if it wins, else 0.

        On a proposal turn `i` is a piece id; on a placement turn `i` is the
        flat index of the square receiving the currently proposed piece.
        Only a placement can win.
        """
        color = self.color_to_move()
        won = False
        if self.num_moves_made%2:
            # Proposal: reserve the piece for the opponent to place.
            self.pieces_used[i] = 1
            self.proposed_piece = i
        else:
            # Placement: stored as id + 1 so that 0 keeps meaning "empty".
            self.array[i//4,i%4] = self.proposed_piece+1
            self.proposed_piece = -1
            won = self.move_won(i)
        self.num_moves_made +=1
        return won*color

    def unmake_move(self,i):
        """Undo the most recent move, which must have been `make_move(i)`."""
        self.num_moves_made -=1
        if self.num_moves_made%2:
            # The undone move was a proposal of piece i.
            self.pieces_used[i] = 0
            self.proposed_piece = -1
        else:
            # The undone move was a placement on square i; the piece goes
            # back to being the proposed one.
            self.proposed_piece = self.array[i//4,i%4]-1
            self.array[i//4,i%4] = 0

    def inbounds(self,j,i):
        """Return True if (row j, column i) lies on the 4x4 board."""
        return (j<4) and (j>=0) and (i<4) and (i>=0)

    def amove_won(self):
        """Return True if any square is part of a completed winning line."""
        for i in range(16):
            if self.move_won(i):
                return True
        return False

    def move_won(self,i):
        """Return True if the piece on square `i` completes a winning line.

        A line (row, column or either long diagonal) wins when all four of
        its squares are occupied and the four pieces agree on at least one
        attribute bit.
        """
        j,i = i//4,i%4
        if self.array[j,i]==0:
            return False
        for (dj,di) in ((0,1),(1,0),(1,1),(1,-1)):
            shared_set = 15    # bits that are 1 on every piece seen so far
            shared_clear = 15  # bits that are 0 on every piece seen so far
            count = 0
            # Walk the whole line through (j,i); only a row, a column or a
            # long diagonal yields four in-bounds squares.
            for k in range(-3,4):
                nj,ni = j+k*dj,i+k*di
                if not self.inbounds(nj,ni):
                    continue
                cell = self.array[nj,ni]
                if cell==0:
                    break
                piece = cell-1
                shared_set &= piece
                shared_clear &= 15^piece
                count+=1
            if count==4 and (shared_set!=0 or shared_clear!=0):
                return True
        return False

    def is_draw(self):
        """Return True once all 16 squares are filled.

        That is 32 moves (16 proposals + 16 placements) on top of the
        initial count of 1. A win on the final placement is reported by
        `make_move`, not here.
        """
        return self.num_moves_made==33

    def reset(self):
        """Restore the empty starting position."""
        self.__init__()

    def data(self):
        """Return the board with its rows in reverse order (a view of `array`)."""
        return self.array[::-1]

    def show(self):
        """Display the board with matplotlib.

        NOTE(review): matplotlib calls are probably not compilable inside a
        jitclass method - confirm before relying on this.
        """
        plt.imshow(self.data())

class Connect4Game(object):
    """Interactive matplotlib front end for playing against the MCTS engine.

    A mouse click inside the axes is converted to a column index and played
    as the user's move; the engine then replies. Once either side's move
    returns a non-zero outcome the result is displayed and further clicks
    are ignored.
    """

    def __init__(self,move_first=True,think_time=1):
        """Build the figure, hook up the click handler and show the board.

        Args:
            move_first: if False, the engine makes the opening move.
            think_time: budget handed to ``engine.compute_move`` for each
                engine move (presumably seconds - confirm against MCTS).
        """
        # NOTE(review): Connect4BitBoard is not defined or imported in the
        # visible part of this file - confirm it is in scope at call time.
        self.engine = MCTS(Connect4BitBoard)
        self.think_time = think_time
        self._game_over = False
        self.fig,self.ax = plt.subplots(1,1,figsize=(4,4))
        self.ax.grid(which='minor', color='k', linestyle='-', linewidth=2)
        self.ax.set_xticks(np.arange(-.5, 7, 1), minor=True)
        self.ax.set_yticks(np.arange(-.5, 6, 1), minor=True)
        self.ppt = self.ax.imshow(self.engine.gameBoard.data(),vmin=-1,vmax=1)
        self.text_artist = self.ax.text(2,1,"",color='w')
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        # The opening engine move has to happen before plt.show(): with a
        # non-interactive backend show() blocks until the window is closed.
        if not move_first: self.engine_move_update()
        plt.show()

    def on_click(self,event):
        """Handle a mouse click: play the user's move, then the engine's reply."""
        if self._game_over:
            return
        if self.ax.in_axes(event):
            self.user_move_update(event)
            # Do not let the engine move on a position the user just won.
            if not self._game_over:
                self.engine_move_update()

    def user_move_update(self,event):
        """Play the column under the click as the user's move and redraw."""
        user_move,_ = self.get_click_coords(event)
        outcome = self.engine.make_move(user_move)
        if outcome: self.show_victory(outcome)
        self._redraw()

    def engine_move_update(self):
        """Let the engine pick and play a move, update the stats text and redraw."""
        engine_move =self.engine.compute_move(self.think_time)
        # Shown as "<visit count>;<win ratio>" for the current search tree root.
        self.text_artist.set_text("{};{:1.2f}".format(self.engine.searchTree.num_visits,self.engine.searchTree.win_ratio()))
        outcome = self.engine.make_move(engine_move)
        if outcome: self.show_victory(outcome)
        self._redraw()

    def _redraw(self):
        """Push the current board into the image and refresh the canvas."""
        self.ppt.set_data(self.engine.gameBoard.data())
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
        time.sleep(.1)

    def show_victory(self,outcome):
        """Mark the game as finished and display the winner on the board.

        An outcome of 1 is reported as a white win, anything else as black.
        """
        self._game_over = True
        text = "WHITE WINS" if outcome==1 else "BLACK WINS"
        # Draw on this game's own axes rather than whatever pyplot considers
        # the current axes.
        self.ax.text(5, 1.5, text, size=20,
             ha="right", va="top",
             bbox=dict(boxstyle="round", fc="w", ec="0.5", alpha=0.9)
             )

    def get_click_coords(self,event):
        """Return ``(column, row)`` of the image cell under a mouse event.

        The event position is converted from display to axes-fraction
        coordinates and scaled by the image shape. Rows are counted from the
        top of the axes; both indices are wrapped into range with a modulo.
        """
        n_rows,n_cols = self.ax.get_images()[0]._A.shape[:2]
        x_frac,y_frac = self.ax.transAxes.inverted().transform((event.x, event.y))
        # Axes y grows upward while row indices grow downward, hence the
        # negated y before flooring.
        row = int(np.floor(-y_frac*n_rows))%n_rows
        col = int(np.floor(x_frac*n_cols))%n_cols
        return col,row