Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions python/board.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import random
import numpy as np

class IllegalMoveError(ValueError):
pass

#Implements legal moves without superko
class Board:
EMPTY = 0
Expand Down Expand Up @@ -290,9 +293,9 @@ def is_on_board(self,loc):
#Set a given location with error checking. Suicide setting allowed.
def set_stone(self,pla,loc):
if pla != Board.EMPTY and pla != Board.BLACK and pla != Board.WHITE:
raise ValueError("Invalid pla for board.set")
raise IllegalMoveError("Invalid pla for board.set")
if not self.is_on_board(loc):
raise ValueError("Invalid loc for board.set")
raise IllegalMoveError("Invalid loc for board.set")

if self.board[loc] == pla:
pass
Expand All @@ -312,17 +315,17 @@ def set_stone(self,pla,loc):
#Single stone suicide is disallowed but suicide is allowed, to support rule sets and sgfs that have suicide
def play(self,pla,loc):
if pla != Board.BLACK and pla != Board.WHITE:
raise ValueError("Invalid pla for board.play")
raise IllegalMoveError("Invalid pla for board.play")

if loc != Board.PASS_LOC:
if not self.is_on_board(loc):
raise ValueError("Invalid loc for board.set")
raise IllegalMoveError("Invalid loc for board.set")
if self.board[loc] != Board.EMPTY:
raise ValueError("Location is nonempty")
raise IllegalMoveError("Location is nonempty")
if self.would_be_single_stone_suicide(pla,loc):
raise ValueError("Move would be illegal single stone suicide")
raise IllegalMoveError("Move would be illegal single stone suicide")
if loc == self.simple_ko_point:
raise ValueError("Move would be illegal simple ko recapture")
raise IllegalMoveError("Move would be illegal simple ko recapture")

self.playUnsafe(pla,loc)

Expand Down
6 changes: 4 additions & 2 deletions python/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from board import Board

class Metadata:
def __init__(self, size, bname, wname, brank, wrank, komi):
def __init__(self, size, bname, wname, brank, wrank, komi, handicap):
self.size = size
self.bname = bname
self.wname = wname
self.brank = brank
self.wrank = wrank
self.komi = komi
self.handicap = handicap

#Returns (metadata, list of setup stones, list of move stones)
#Setup and move stones are both pairs of (pla,loc)
Expand Down Expand Up @@ -89,6 +90,7 @@ def load_sgf_moves_exn(path):
wrank = (root.get("WR") if root.has_property("WR") else None)
komi = (root.get("KM") if root.has_property("KM") else None)
rulesstr = (root.get("RU") if root.has_property("RU") else None)
handicap = (root.get("HA") if root.has_property("HA") else None)

rules = None
if rulesstr is not None:
Expand Down Expand Up @@ -157,5 +159,5 @@ def load_sgf_moves_exn(path):
else:
raise Exception("Could not parse rules: " + origrulesstr)

metadata = Metadata(size, bname, wname, brank, wrank, komi)
metadata = Metadata(size, bname, wname, brank, wrank, komi, handicap)
return metadata, setup, moves, rules
131 changes: 131 additions & 0 deletions python/genboard_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import traceback
import json
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
def __init__(self,num_channels,scale_init):
super(ResBlock, self).__init__()
kernel_size = 3
self.biasa = nn.Parameter(torch.zeros(num_channels,1,1))
self.conva = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.normal_(self.conva.weight,std=math.sqrt(2.0 / num_channels / kernel_size / kernel_size)*scale_init)
self.biasb = nn.Parameter(torch.zeros(num_channels,1,1))
self.scalb = nn.Parameter(torch.ones(num_channels,1,1))
self.convb = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.zeros_(self.convb.weight)

def forward(self, trunk):
x = F.relu(trunk+self.biasa)
x = self.conva(x)
x = F.relu(x*self.scalb+self.biasb)
x = self.convb(x)
return trunk+x

class GPoolResBlock(nn.Module):
def __init__(self,num_channels,scale_init):
super(GPoolResBlock, self).__init__()
kernel_size = 3
self.biasa = nn.Parameter(torch.zeros(num_channels,1,1))
self.conva = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.normal_(self.conva.weight,std=math.sqrt(1.0 / num_channels / kernel_size / kernel_size)*scale_init)
self.convg = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.normal_(self.convg.weight,std=math.sqrt(1.0 / num_channels / kernel_size / kernel_size)*math.sqrt(scale_init))
self.matg = nn.Parameter(torch.zeros(num_channels,num_channels))
torch.nn.init.normal_(self.matg,std=math.sqrt(1.0 / num_channels)*math.sqrt(scale_init))
self.biasb = nn.Parameter(torch.zeros(num_channels,1,1))
self.scalb = nn.Parameter(torch.ones(num_channels,1,1))
self.convb = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False)
torch.nn.init.zeros_(self.convb.weight)

def forward(self, trunk):
x = F.relu(trunk+self.biasa)
x = self.conva(x)
g = self.convg(x)
gsize = g.size()
g = torch.sum(g,(2,3)) / (gsize[2] * gsize[3]) # nchw -> nc
g = torch.matmul(g,self.matg)
g = g.view(gsize[0],gsize[1],1,1)
x = x + g
x = F.relu(x*self.scalb+self.biasb)
x = self.convb(x)
return trunk+x


class Model(nn.Module):
def __init__(self, num_channels, num_blocks):
super(Model, self).__init__()
# Channel 0: Next inference point
# Channel 1: On-board
# Channel 2: Black
# Channel 3: White
# Channel 4: Unknown
# Channel 5: Turn number / 100
# Channel 6: Noise stdev in turn number / 50
# Channel 7: Source

self.inference_channel = 0
self.num_channels = num_channels
self.num_blocks = num_blocks
self.conv0 = nn.Conv2d(in_channels=8, out_channels=self.num_channels, kernel_size=3, padding=1, bias=False)

self.blocks = nn.ModuleList([])
self.fixup_scale_init = 1.0 / math.sqrt(self.num_blocks)
self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init))
self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init))

next_is_gpool = True
for b in range(num_blocks-2):
if next_is_gpool:
self.blocks.append(GPoolResBlock(self.num_channels,self.fixup_scale_init))
else:
self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init))
next_is_gpool = not next_is_gpool

assert(len(self.blocks) == self.num_blocks)

self.endtrunk_bias_focus = nn.Parameter(torch.zeros(self.num_channels,1,1))
self.endtrunk_bias_g = nn.Parameter(torch.zeros(self.num_channels,1,1))
self.convg = nn.Conv2d(in_channels=self.num_channels, out_channels=self.num_channels, kernel_size=1, padding=0, bias=False)

self.fc1 = nn.Linear(self.num_channels*2, self.num_channels)
self.fc2 = nn.Linear(self.num_channels,3)
self.convaux = nn.Conv2d(in_channels=self.num_channels, out_channels=3, kernel_size=1, padding=0, bias=True)

def forward(self, inputs):
trunk = self.conv0(inputs)
for i in range(self.num_blocks):
trunk = self.blocks[i](trunk)

head_focus = F.relu(trunk+self.endtrunk_bias_focus)
head_g = F.relu(trunk+self.endtrunk_bias_g)
aux = self.convaux(head_focus)
gsize = head_g.size()

x = torch.sum(head_focus * inputs[:,self.inference_channel:self.inference_channel+1,:,:],(2,3))
g = torch.sum(head_g,(2,3)) / (gsize[2] * gsize[3]) # nchw -> nc

x = torch.cat((x,g),dim=1)

x = F.relu(self.fc1(x))
x = self.fc2(x)
return x,aux

def save_to_file(self, filename):
state_dict = self.state_dict()
data = {}
data["num_channels"] = self.num_channels
data["num_blocks"] = self.num_blocks
data["state_dict"] = state_dict
torch.save(data, filename)

@staticmethod
def load_from_file(filename):
data = torch.load(filename)
model = Model(data["num_channels"], data["num_blocks"])
model.load_state_dict(data["state_dict"])
return model


Loading