# connect4.py
import numpy as np
from numba import jit,njit,int32,float32,void,boolean,int64
from numba.experimental import jitclass
import matplotlib.pyplot as plt
import time
import copy
import math
from math import sqrt
import sys
from mcts import MCTS
# Typed field layout handed to numba's jitclass for Connect4Board.
spec = [
    ('array', int32[:, :]),     # 2-D grid of cell values
    ('col_length', int32[:]),   # number of pieces currently in each column
    ('num_moves_made', int32),  # total number of moves played so far
]
@jitclass(spec)
class Connect4Board(object):
    """Array-backed Connect 4 position compiled with numba's ``jitclass``.

    The grid has 6 rows and 7 columns.  Row 0 is the bottom of the board, so
    a column fills from index 0 upwards.  A cell holds -1 or +1 for the two
    players and 0 while it is empty.  None of the methods validate that a
    requested move is legal; callers are expected to use ``get_moves``.
    """

    def __init__(self):
        self.num_moves_made = 0
        self.col_length = np.zeros(7, dtype=int32)
        self.array = np.zeros((6, 7), dtype=int32)

    def copy(self, otherboard):
        """Overwrite this board's state with copies of ``otherboard``'s state."""
        self.num_moves_made = otherboard.num_moves_made
        self.col_length = np.copy(otherboard.col_length)
        self.array = np.copy(otherboard.array)

    def get_moves(self):
        """Return a list of the column indices whose top cell is still empty."""
        open_columns = []
        for col in range(7):
            top_cell = self.array[-1, col]
            if top_cell == 0:
                open_columns.append(col)
        return open_columns

    def color_to_move(self):
        """Return -1 after an even number of moves and +1 after an odd number."""
        parity = self.num_moves_made % 2
        return 2 * parity - 1

    def make_move(self, i):
        """Drop the side-to-move's piece into column ``i``.

        Returns the mover's color (-1 or +1) when the new piece completes a
        line of four, otherwise 0.  The column is not checked for free space.
        """
        mover = self.color_to_move()
        row = self.col_length[i]
        self.array[row, i] = mover
        self.col_length[i] = row + 1
        self.num_moves_made += 1
        return self.move_won(i) * mover

    def unmake_move(self, i):
        """Remove the top piece of column ``i`` and roll both counters back."""
        top = self.col_length[i] - 1
        self.array[top, i] = 0
        self.col_length[i] = top
        self.num_moves_made -= 1

    def inbounds(self, j, i):
        """Return True when row ``j`` and column ``i`` lie on the 6x7 grid."""
        return (0 <= j) and (j < 6) and (0 <= i) and (i < 7)

    def amove_won(self):
        """Return True if the top piece of any column is part of a line of four."""
        for col in range(7):
            if self.move_won(col):
                return True
        return False

    def move_won(self, i):
        """Return True if the top piece of column ``i`` belongs to a line of four.

        For an empty column the row index is -1, which reads the column's top
        cell; that cell is 0 for an empty column, so the result is False.
        """
        row = self.col_length[i] - 1
        color = self.array[row, i]
        if color == 0:
            return False
        # Directions as (row step, column step): horizontal, vertical and
        # the two diagonals.
        for (dj, di) in ((0, 1), (1, 0), (1, 1), (1, -1)):
            run = 1
            # Scan up to three cells each way along the line, stopping at the
            # first cell that is off the board or not the same color.
            for sign in (1, -1):
                for step in range(1, 4):
                    nj = row + sign * step * dj
                    ni = i + sign * step * di
                    if not self.inbounds(nj, ni):
                        break
                    if self.array[nj, ni] != color:
                        break
                    run += 1
            if run >= 4:
                return True
        return False

    def is_draw(self):
        """Return True when 42 moves have been made (6 x 7 cells)."""
        return self.num_moves_made == 42

    def reset(self):
        """Return the board to its freshly constructed state."""
        self.__init__()

    def data(self):
        """Return the grid with the row order reversed, so the top row comes first."""
        return self.array[::-1]

    def show(self):
        """Display the board with ``plt.imshow``.

        NOTE(review): this calls matplotlib from inside a jitclass method;
        confirm numba can actually compile it in your setup.
        """
        flipped = self.data()
        plt.imshow(flipped)
# Typed field layout handed to numba's jitclass for Connect4BitBoard.
spec = [
    ('p1', int64),              # bitboard of the pieces with color -1
    ('p2', int64),              # bitboard of the pieces with color +1
    ('col_lengths', int32),     # seven column heights packed 3 bits each
    ('num_moves_made', int32),  # total number of moves played so far
]
@jitclass(spec)
class Connect4BitBoard(object):
    """Bitboard Connect 4 position compiled with numba's ``jitclass``.

    Each player's pieces are stored in one 64-bit integer.  Column ``c`` owns
    bits ``8*c`` to ``8*c + 7`` and row ``r`` (0 = bottom) of that column is
    bit ``8*c + r``.  Only rows 0-5 are playable, so with legal play the two
    spare bits per column stay clear and keep the shifted patterns in
    ``amove_won`` from matching across column boundaries.

    ``p1`` holds the pieces of color -1 (the side that moves first) and
    ``p2`` the pieces of color +1.  ``col_lengths`` packs the seven column
    heights, three bits per column.  None of the methods validate that a
    requested move is legal; callers are expected to use ``get_moves``.
    """

    def __init__(self):
        self.p1 = 0  # encoded 0b[col7]00[col6]00..00[col1]
        self.p2 = 0
        self.col_lengths = 0  # 3 bits per column, column 0 in the low bits
        self.num_moves_made = 0

    def copy(self, otherboard):
        """Overwrite this board's state with ``otherboard``'s state."""
        self.p1 = otherboard.p1
        self.p2 = otherboard.p2
        self.col_lengths = otherboard.col_lengths
        self.num_moves_made = otherboard.num_moves_made

    def get_moves(self):
        """Return a list of the column indices whose top cell is still empty."""
        # After the shift, bit 0 is the top playable cell (row 5) of column 0.
        filled = (self.p1 | self.p2) >> 5
        moves = []
        for i in range(7):
            if not filled & 0x1:
                moves.append(i)
            filled = filled >> 8  # advance to the next column's top cell
        return moves

    def color_to_move(self):
        """Return -1 after an even number of moves and +1 after an odd number."""
        return 2 * (self.num_moves_made % 2) - 1

    def make_move(self, i):
        """Drop the side-to-move's piece into column ``i``.

        Returns the mover's color (-1 or +1) when the move completes a line
        of four, otherwise 0.  The column is not checked for free space.
        """
        color = self.color_to_move()
        bitboard = self.p1 if color == -1 else self.p2
        col_length = (self.col_lengths >> (3 * i)) & 7
        bitboard |= ((1 << col_length) << (8 * i))
        if color == -1:
            self.p1 = bitboard
        else:
            self.p2 = bitboard
        self.col_lengths += 1 << (3 * i)
        self.num_moves_made += 1
        return self.move_won(i) * color

    def unmake_move(self, i):
        """Remove the top piece of column ``i`` and roll both counters back.

        Mirrors ``Connect4Board.unmake_move`` so the two board classes offer
        the same interface.  It must undo the most recent move: the piece is
        cleared from the bitboard of the side that made that move.
        """
        self.num_moves_made -= 1
        self.col_lengths -= 1 << (3 * i)
        # With the counter rolled back, the side to move is the side whose
        # piece is being removed, and the column height is the removed row.
        color = self.color_to_move()
        col_length = (self.col_lengths >> (3 * i)) & 7
        keep_mask = ~((1 << col_length) << (8 * i))
        if color == -1:
            self.p1 &= keep_mask
        else:
            self.p2 &= keep_mask

    def amove_won(self):
        """Return True if the side that moved last has four in a row.

        Each direction is tested with two shift-and-AND steps: the first
        keeps cells that have a same-colored neighbour one step along the
        direction, the second keeps pairs that have another pair two steps
        further on.
        """
        color = self.color_to_move()
        # The side to move did not make the last move, so test the other side.
        bitboard = self.p2 if color == -1 else self.p1
        # Check \
        temp_bboard = bitboard & (bitboard >> 7)
        if temp_bboard & (temp_bboard >> 2 * 7):
            return True
        # Check -
        temp_bboard = bitboard & (bitboard >> 8)
        if temp_bboard & (temp_bboard >> 2 * 8):
            return True
        # Check /
        temp_bboard = bitboard & (bitboard >> 9)
        if temp_bboard & (temp_bboard >> 2 * 9):
            return True
        # Check |
        temp_bboard = bitboard & (bitboard >> 1)
        if temp_bboard & (temp_bboard >> 2 * 1):
            return True
        return False

    def move_won(self, i):
        """Return True if the last mover has four in a row.

        ``i`` is accepted for parity with ``Connect4Board.move_won`` and is
        not used; the whole bitboard is tested.
        """
        return self.amove_won()

    def is_draw(self):
        """Return True when 42 moves have been made (6 x 7 cells)."""
        return self.num_moves_made == 42

    def reset(self):
        """Return the board to its freshly constructed state."""
        self.__init__()

    def data(self):
        """Return a 6x7 int32 array of the position, top row first.

        Cells are +1 for ``p2`` pieces, -1 for ``p1`` pieces and 0 when empty.
        """
        array_rep = np.zeros((6, 7), dtype=int32)
        for j in range(7):
            for i in range(6):
                array_rep[i, j] = ((self.p2 >> (i + 8 * j)) & 1) - ((self.p1 >> (i + 8 * j)) & 1)
        return array_rep[::-1]

    def show(self):
        """Display the board with ``plt.imshow``.

        NOTE(review): this calls matplotlib from inside a jitclass method;
        confirm numba can actually compile it in your setup.
        """
        plt.imshow(self.data())
# Messages shown above the board after each engine move.  engine_move_update
# indexes this list with np.digitize over the ratio of the search tree's win
# ratio after and before the search, so lower ratios select earlier entries.
text_bank = [
    "You are the reason they\n put instructions on shampoo",
    "A CSGO bot would give\n me a better challenge",
    "With moves like that I could\n beat you running on an arduino",
    "And I thought this was\n gonna be a tough game",
    "You think you've got\n what it takes?",
    "Not bad for a Human",
    "How is this possible?",
    "Nooo! I cannot lose!!",
    "RIP",
]
class Connect4Game(object):
    """Interactive matplotlib front end for playing Connect 4 against the engine.

    The game is driven by mouse clicks on the board image: each accepted
    click plays the user's move in the clicked column and is followed by an
    engine reply.  The class relies on the ``MCTS`` engine exposing
    ``gameBoard``, ``searchTree``, ``make_move`` and ``compute_move``.

    Parameters
    ----------
    move_first : bool
        When False the engine plays the opening move before the window is
        shown; otherwise the user opens.
    think_time : number
        Value passed to ``engine.compute_move`` for every engine move.
    debug : bool
        When True the search statistics line under the board is drawn in
        black; otherwise it is drawn in white.
    """

    def __init__(self, move_first=True, think_time=1, debug=True):
        self.engine = MCTS(Connect4BitBoard)
        self.think_time = think_time
        # Set once a win or draw has been shown; later clicks are ignored.
        self.game_over = False
        self.fig, self.ax = plt.subplots(1, 1, figsize=(4, 4))
        self.ax.grid(which='minor', color='k', linestyle='-', linewidth=2)
        self.ax.set_xticks(np.arange(-.5, 7, 1), minor=True)
        self.ax.set_yticks(np.arange(-.5, 6, 1), minor=True)
        self.ax.set_xticks([])
        self.ax.set_yticks([])
        self.ppt = self.ax.imshow(self.engine.gameBoard.data(), vmin=-1, vmax=1)
        self.text_artist = self.ax.text(3, -.8, "", color='k', fontsize=15, ha='center', va='bottom')
        self.text_artist2 = self.ax.text(3, 6, "", color='k' if debug else 'white', fontsize=15, ha='center', va='center')
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        # Play the engine's opening move before plt.show(): with a blocking
        # backend show() only returns once the window has been closed.
        if not move_first:
            self.engine_move_update()
        plt.show()

    def on_click(self, event):
        """Handle a mouse click: play the user's move, then the engine's reply.

        Clicks outside the board axes, clicks on a full column and clicks
        after the game has ended are ignored.  The engine does not reply
        when the user's move ended the game.
        """
        if self.game_over or not self.ax.in_axes(event):
            return
        if not self.user_move_update(event):
            return
        if not self.game_over:
            self.engine_move_update()

    def user_move_update(self, event):
        """Play the user's move in the clicked column and redraw the board.

        Returns True when the move was played and False when it was rejected
        because the game is over or the column is already full.
        """
        if self.game_over:
            return False
        user_move, _ = self.get_click_coords(event)
        # data() puts the top row first, so a non-zero cell in row 0 means
        # the column has no room left; the boards do not check this.
        if self.engine.gameBoard.data()[0, user_move] != 0:
            return False
        outcome = self.engine.make_move(user_move)
        self._register_outcome(outcome)
        self._redraw()
        return True

    def engine_move_update(self):
        """Let the engine search, play its move, and refresh board and text."""
        if self.game_over:
            return
        pold = self.engine.searchTree.win_ratio()
        engine_move = self.engine.compute_move(self.think_time)
        p = self.engine.searchTree.win_ratio()
        # Choose a message from how the win ratio changed during the search;
        # the small constant avoids dividing by zero.
        i = np.digitize(p / (pold + 1e-2), [.7, .8, .9, 1.0, 1.1, 1.2, 1.3])
        text = text_bank[i]
        self.text_artist.set_text(f"{text}")
        self.text_artist2.set_text(f"(p={p:1.2f},N={self.engine.searchTree.num_visits})")
        outcome = self.engine.make_move(engine_move)
        self._register_outcome(outcome)
        self._redraw()

    def show_victory(self, outcome):
        """Mark the game as finished and show the winner banner.

        ``outcome`` equal to 1 shows "WHITE WINS"; any other value shows
        "BLACK WINS".
        """
        text = "WHITE WINS" if outcome == 1 else "BLACK WINS"
        self.game_over = True
        self._show_banner(text)

    def get_click_coords(self, event):
        """Convert a mouse event into board coordinates.

        Returns ``(column, row)``, where ``row`` counts from the top of the
        displayed image.
        """
        n_rows, n_cols = self.ppt.get_array().shape[:2]
        # Transform the event from display to axes coordinates (0..1 each way).
        x_frac, y_frac = self.ax.transAxes.inverted().transform((event.x, event.y))
        # The axes y fraction grows upwards while image rows grow downwards,
        # hence the negation before taking the modulus.
        row = int(np.floor(-y_frac * n_rows)) % n_rows
        col = int(np.floor(x_frac * n_cols)) % n_cols
        return col, row

    def _register_outcome(self, outcome):
        """Show a win banner for a non-zero outcome, or a draw on a full board."""
        if outcome:
            self.show_victory(outcome)
        elif not (self.engine.gameBoard.data()[0] == 0).any():
            # No empty cell remains in the top row, so no move is possible.
            self.game_over = True
            self._show_banner("DRAW")

    def _show_banner(self, text):
        """Draw ``text`` in a rounded box on the board axes."""
        self.ax.text(5, 1.5, text, size=20,
                     ha="right", va="top",
                     bbox=dict(boxstyle="round", fc="w", ec="0.5", alpha=0.9)
                     )

    def _redraw(self):
        """Push the current position to the image and refresh the canvas."""
        self.ppt.set_data(self.engine.gameBoard.data())
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
        time.sleep(.1)