diff --git a/python/board.py b/python/board.py index ca561a90d..4aa2e21ac 100644 --- a/python/board.py +++ b/python/board.py @@ -2,6 +2,9 @@ import random import numpy as np +class IllegalMoveError(ValueError): + pass + #Implements legal moves without superko class Board: EMPTY = 0 @@ -290,9 +293,9 @@ def is_on_board(self,loc): #Set a given location with error checking. Suicide setting allowed. def set_stone(self,pla,loc): if pla != Board.EMPTY and pla != Board.BLACK and pla != Board.WHITE: - raise ValueError("Invalid pla for board.set") + raise IllegalMoveError("Invalid pla for board.set") if not self.is_on_board(loc): - raise ValueError("Invalid loc for board.set") + raise IllegalMoveError("Invalid loc for board.set") if self.board[loc] == pla: pass @@ -312,17 +315,17 @@ def set_stone(self,pla,loc): #Single stone suicide is disallowed but suicide is allowed, to support rule sets and sgfs that have suicide def play(self,pla,loc): if pla != Board.BLACK and pla != Board.WHITE: - raise ValueError("Invalid pla for board.play") + raise IllegalMoveError("Invalid pla for board.play") if loc != Board.PASS_LOC: if not self.is_on_board(loc): - raise ValueError("Invalid loc for board.set") + raise IllegalMoveError("Invalid loc for board.set") if self.board[loc] != Board.EMPTY: - raise ValueError("Location is nonempty") + raise IllegalMoveError("Location is nonempty") if self.would_be_single_stone_suicide(pla,loc): - raise ValueError("Move would be illegal single stone suicide") + raise IllegalMoveError("Move would be illegal single stone suicide") if loc == self.simple_ko_point: - raise ValueError("Move would be illegal simple ko recapture") + raise IllegalMoveError("Move would be illegal simple ko recapture") self.playUnsafe(pla,loc) diff --git a/python/data.py b/python/data.py index 4db3b92b7..25fd22502 100644 --- a/python/data.py +++ b/python/data.py @@ -5,13 +5,14 @@ from board import Board class Metadata: - def __init__(self, size, bname, wname, brank, wrank, komi): + def __init__(self, size, bname, wname, brank, wrank, komi, handicap): self.size = size self.bname = bname self.wname = wname self.brank = brank self.wrank = wrank self.komi = komi + self.handicap = handicap #Returns (metadata, list of setup stones, list of move stones) #Setup and move stones are both pairs of (pla,loc) @@ -89,6 +90,7 @@ def load_sgf_moves_exn(path): wrank = (root.get("WR") if root.has_property("WR") else None) komi = (root.get("KM") if root.has_property("KM") else None) rulesstr = (root.get("RU") if root.has_property("RU") else None) + handicap = (root.get("HA") if root.has_property("HA") else None) rules = None if rulesstr is not None: @@ -157,5 +159,5 @@ def load_sgf_moves_exn(path): else: raise Exception("Could not parse rules: " + origrulesstr) - metadata = Metadata(size, bname, wname, brank, wrank, komi) + metadata = Metadata(size, bname, wname, brank, wrank, komi, handicap) return metadata, setup, moves, rules diff --git a/python/genboard_common.py b/python/genboard_common.py new file mode 100644 index 000000000..69ac6f35d --- /dev/null +++ b/python/genboard_common.py @@ -0,0 +1,131 @@ +import traceback +import json +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ResBlock(nn.Module): + def __init__(self,num_channels,scale_init): + super(ResBlock, self).__init__() + kernel_size = 3 + self.biasa = nn.Parameter(torch.zeros(num_channels,1,1)) + self.conva = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False) + torch.nn.init.normal_(self.conva.weight,std=math.sqrt(2.0 / num_channels / kernel_size / kernel_size)*scale_init) + self.biasb = nn.Parameter(torch.zeros(num_channels,1,1)) + self.scalb = nn.Parameter(torch.ones(num_channels,1,1)) + self.convb = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False) + torch.nn.init.zeros_(self.convb.weight) + + def forward(self, trunk): + x = F.relu(trunk+self.biasa) + x = self.conva(x) + x = F.relu(x*self.scalb+self.biasb) + x = self.convb(x) + return trunk+x + +class GPoolResBlock(nn.Module): + def __init__(self,num_channels,scale_init): + super(GPoolResBlock, self).__init__() + kernel_size = 3 + self.biasa = nn.Parameter(torch.zeros(num_channels,1,1)) + self.conva = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False) + torch.nn.init.normal_(self.conva.weight,std=math.sqrt(1.0 / num_channels / kernel_size / kernel_size)*scale_init) + self.convg = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False) + torch.nn.init.normal_(self.convg.weight,std=math.sqrt(1.0 / num_channels / kernel_size / kernel_size)*math.sqrt(scale_init)) + self.matg = nn.Parameter(torch.zeros(num_channels,num_channels)) + torch.nn.init.normal_(self.matg,std=math.sqrt(1.0 / num_channels)*math.sqrt(scale_init)) + self.biasb = nn.Parameter(torch.zeros(num_channels,1,1)) + self.scalb = nn.Parameter(torch.ones(num_channels,1,1)) + self.convb = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=kernel_size, padding=1, bias=False) + torch.nn.init.zeros_(self.convb.weight) + + def forward(self, trunk): + x = F.relu(trunk+self.biasa) + x = self.conva(x) + g = self.convg(x) + gsize = g.size() + g = torch.sum(g,(2,3)) / (gsize[2] * gsize[3]) # nchw -> nc + g = torch.matmul(g,self.matg) + g = g.view(gsize[0],gsize[1],1,1) + x = x + g + x = F.relu(x*self.scalb+self.biasb) + x = self.convb(x) + return trunk+x + + +class Model(nn.Module): + def __init__(self, num_channels, num_blocks): + super(Model, self).__init__() + # Channel 0: Next inference point + # Channel 1: On-board + # Channel 2: Black + # Channel 3: White + # Channel 4: Unknown + # Channel 5: Turn number / 100 + # Channel 6: Noise stdev in turn number / 50 + # Channel 7: Source + + self.inference_channel = 0 + self.num_channels = num_channels + self.num_blocks = num_blocks + self.conv0 = nn.Conv2d(in_channels=8, out_channels=self.num_channels, kernel_size=3, padding=1, bias=False) + + self.blocks = nn.ModuleList([]) + self.fixup_scale_init = 1.0 / math.sqrt(self.num_blocks) + self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init)) + self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init)) + + next_is_gpool = True + for b in range(num_blocks-2): + if next_is_gpool: + self.blocks.append(GPoolResBlock(self.num_channels,self.fixup_scale_init)) + else: + self.blocks.append(ResBlock(self.num_channels,self.fixup_scale_init)) + next_is_gpool = not next_is_gpool + + assert(len(self.blocks) == self.num_blocks) + + self.endtrunk_bias_focus = nn.Parameter(torch.zeros(self.num_channels,1,1)) + self.endtrunk_bias_g = nn.Parameter(torch.zeros(self.num_channels,1,1)) + self.convg = nn.Conv2d(in_channels=self.num_channels, out_channels=self.num_channels, kernel_size=1, padding=0, bias=False) + + self.fc1 = nn.Linear(self.num_channels*2, self.num_channels) + self.fc2 = nn.Linear(self.num_channels,3) + self.convaux = nn.Conv2d(in_channels=self.num_channels, out_channels=3, kernel_size=1, padding=0, bias=True) + + def forward(self, inputs): + trunk = self.conv0(inputs) + for i in range(self.num_blocks): + trunk = self.blocks[i](trunk) + + head_focus = F.relu(trunk+self.endtrunk_bias_focus) + head_g = F.relu(trunk+self.endtrunk_bias_g) + aux = self.convaux(head_focus) + gsize = head_g.size() + + x = torch.sum(head_focus * inputs[:,self.inference_channel:self.inference_channel+1,:,:],(2,3)) + g = torch.sum(head_g,(2,3)) / (gsize[2] * gsize[3]) # nchw -> nc + + x = torch.cat((x,g),dim=1) + + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return x,aux + + def save_to_file(self, filename): + state_dict = self.state_dict() + data = {} + data["num_channels"] = self.num_channels + data["num_blocks"] = self.num_blocks + data["state_dict"] = state_dict + torch.save(data, filename) + + @staticmethod + def load_from_file(filename): + data = torch.load(filename) + model = Model(data["num_channels"], data["num_blocks"]) + model.load_state_dict(data["state_dict"]) + return model + + diff --git a/python/genboard_run.py b/python/genboard_run.py new file mode 100755 index 000000000..92529500b --- /dev/null +++ b/python/genboard_run.py @@ -0,0 +1,204 @@ +#!/usr/bin/python3 +import sys +import os +import argparse +import traceback +import logging +import json +import math +import random +import hashlib +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from copy import deepcopy + +import data +from board import Board, IllegalMoveError +from genboard_common import Model + +if __name__ == '__main__': + + description = """ + Generate completions of Go positions + """ + + parser = argparse.ArgumentParser(description=description) + parser.add_argument('-model', help='Model file to load', required=True) + parser.add_argument('-board', help='Board pattern using {.,*},X,O,? for empty, black, white, unknown', required=True) + parser.add_argument('-turn', help='Approx turn number to tell the net to generate for, [0,300]', required=True, type=float) + parser.add_argument('-turnstdev', help='Approx turn number randomness [0,100]', required=True, type=float) + parser.add_argument('-source', help='Tell the net to mimic positions from source, {-1,0,1}', required=True, type=int) + parser.add_argument('-verbose', help='Print various info and debug messages instead of only the board', required=False, action='store_true') + parser.add_argument('-n', help='How many batches to generate, default 1', required=False, type=int, default=1) + parser.add_argument('-batchsize', help='How many positions to generate, default 1', required=False, type=int, default=1) + args = vars(parser.parse_args()) + + modelfile = args["model"] + boardstr = args["board"] + turn = args["turn"] + turnstdev = args["turnstdev"] + source = args["source"] + verbose = args["verbose"] + numbatches = args["n"] + batchsize = args["batchsize"] + + if turn < 0 or turn > 300: + raise Exception("Turn must be in [0,300]") + if turnstdev < 0 or turnstdev > 100: + raise Exception("Turn must be in [0,100]") + if source != -1 and source != 0 and source != 1: + raise Exception("Source must be in {-1,0,1}") + if numbatches < 0: + raise Exception("Num batches must be nonnegative") + if batchsize < 1: + raise Exception("Batchsize must be positive") + + cpudevice = torch.device("cpu") + if torch.cuda.is_available(): + if verbose: + print("CUDA is available, using it",flush=True) + gpudevice = torch.device("cuda:0") + else: + gpudevice = cpudevice + model = Model.load_from_file(modelfile).to(gpudevice) + + size = 19 + boardbase = [["." for x in range(size)] for y in range(size)] + boardbase[3][3] = "," + boardbase[9][3] = "," + boardbase[15][3] = "," + boardbase[3][9] = "," + boardbase[9][9] = "," + boardbase[15][9] = "," + boardbase[3][15] = "," + boardbase[9][15] = "," + boardbase[15][15] = "," + + num_channels = 8 + inputsbase = torch.zeros((1,num_channels,size,size)) + + inference_point_channel = 0 + black_channel = 2 + white_channel = 3 + unknown_channel = 4 + + # Channel 1: On-board + inputsbase[:,1,:,:].fill_(1.0) + + def fail_if_idx_too_large(idx): + if idx >= size * size: + raise Exception("Provided board is larger than 19x19") + + idx = 0 + for c in boardstr: + y = idx // 19 + x = idx % 19 + if c == "." or c == "*" or c == ",": + fail_if_idx_too_large(idx) + elif c == "X" or c == "x" or c == "B" or c == "b": + fail_if_idx_too_large(idx) + boardbase[y][x] = "X" + inputsbase[0,black_channel,y,x] = 1.0 + elif c == "O" or c == "o" or c == "W" or c == "w": + fail_if_idx_too_large(idx) + boardbase[y][x] = "O" + inputsbase[0,white_channel,y,x] = 1.0 + elif c == "?": + fail_if_idx_too_large(idx) + inputsbase[0,unknown_channel,y,x] = 1.0 + else: + # Ignore this char, counteract the += 1 at the end + idx -= 1 + idx += 1 + + # Channel 5: Turn number / 100 + inputsbase[:,5,:,:].fill_(turn / 100.0) + # Channel 6: Noise stdev in turn number / 50 + inputsbase[:,6,:,:].fill_(turnstdev / 50.0) + # Channel 7: Source + inputsbase[:,7,:,:].fill_(float(source)) + + rand = random.Random(os.urandom(32) + hashlib.md5(boardstr.encode()).hexdigest().encode()) + + with torch.no_grad(): + + for i in range(numbatches): + + flipx = rand.random() < 0.5 + flipy = rand.random() < 0.5 + swapxy = rand.random() < 0.5 + + flipx2 = rand.random() < 0.5 + flipy2 = rand.random() < 0.5 + swapxy2 = rand.random() < 0.5 + + def query_model(inputs): + inputstransformed = inputs.detach().clone() + if flipx: + if flipy: + inputstransformed = torch.flip(inputstransformed,[2,3]) + else: + inputstransformed = torch.flip(inputstransformed,[2]) + else: + if flipx: + inputstransformed = torch.flip(inputstransformed,[3]) + else: + pass + if swapxy: + inputstransformed = torch.transpose(inputstransformed,2,3) + + preds, auxpreds = model(inputstransformed.to(gpudevice)) + preds = F.softmax(preds,dim=1) + assert(len(preds.size()) == 2) + assert(preds.size()[0] == batchsize) + assert(preds.size()[1] == 3) + choices = [] + for b in range(batchsize): + weights = [preds[b,0],preds[b,1],preds[b,2]] + choice = rand.choices([0,1,2],weights=weights)[0] + choices.append(choice) + return choices + + inputs = inputsbase.expand([batchsize,-1,-1,-1]).detach().clone() + boards = [ deepcopy(boardbase) for b in range(batchsize) ] + + for y in range(size): + for x in range(size): + sx = x + sy = y + if flipx2: + sx = size - sx - 1 + if flipy2: + sy = size - sy - 1 + if swapxy2: + tmp = sx + sx = sy + sy = tmp + + if inputs[0,unknown_channel,sy,sx] == 1.0: + for b in range(batchsize): + inputs[b,unknown_channel,sy,sx] = 0.0 + inputs[b,inference_point_channel,sy,sx] = 1.0 + choices = query_model(inputs) + for b in range(batchsize): + inputs[b,inference_point_channel,sy,sx] = 0.0 + + choice = choices[b] + if choice == 0: + pass + elif choice == 1: + inputs[b,black_channel,sy,sx] = 1.0 + boards[b][sy][sx] = "X" + elif choice == 2: + inputs[b,white_channel,sy,sx] = 1.0 + boards[b][sy][sx] = "O" + + for b in range(batchsize): + s = "\n".join([" ".join(row) for row in boards[b]]) + s += "\n" + print(s) + sys.stdout.flush() + diff --git a/python/genboard_train.py b/python/genboard_train.py new file mode 100755 index 000000000..65990e105 --- /dev/null +++ b/python/genboard_train.py @@ -0,0 +1,492 @@ +#!/usr/bin/python3 +import sys +import os +import argparse +import traceback +import logging +import json +import math +import random +import hashlib +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +import data +from board import Board, IllegalMoveError +from genboard_common import Model + +class ShuffledDataset(torch.utils.data.IterableDataset): + def __init__(self, dataset, shuffle_buffer_size): + super().__init__() + self.dataset = dataset + self.shuffle_buffer_size = shuffle_buffer_size + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + rand = random.Random(os.urandom(32)) + else: + rand = random.Random(os.urandom(32)+ "#ShuffledDataset#".encode() + str(worker_info.id).encode()) + + shuffle_buffer = [] + try: + it = iter(self.dataset) + while len(shuffle_buffer) < self.shuffle_buffer_size: + item = next(it) + if isinstance(item, Exception): + yield item + else: + shuffle_buffer.append(item) + except StopIteration: + self.shuffle_buffer_size = len(shuffle_buffer) + + print("Initial shuffle buffer filled", flush=True) + rand.shuffle(shuffle_buffer) + try: + while True: + try: + item = next(it) + if isinstance(item, Exception): + yield item + else: + idx = rand.randint(0, self.shuffle_buffer_size-1) + old_item = shuffle_buffer[idx] + shuffle_buffer[idx] = item + yield old_item + except StopIteration: + break + while len(shuffle_buffer) > 0: + yield shuffle_buffer.pop() + except GeneratorExit: + pass + +def rand_triangular(rand,maxvalue): + r = (maxvalue+1) * (1.0 - math.sqrt(rand.random())) + r = int(math.floor(r)) + if r <= 0: + return 0 + if r >= maxvalue: + return maxvalue + return r + +def random_subinterval(rand,size): + # Anchor rectangles near the edge more often + if rand.random() < 0.5: + x0 = rand_triangular(rand,size)-1 + x1 = rand_triangular(rand,size)-1 + else: + x0 = rand.randint(0,size-1) + x1 = rand.randint(0,size-1) + + if rand.random() < 0.5: + x0 = size - x0 - 1 + x1 = size - x1 - 1 + + if x0 > x1: + return (x1,x0) + return (x0,x1) + + +class SgfDataset(torch.utils.data.IterableDataset): + def __init__(self, files, max_turn, break_prob_per_turn, sample_prob, endless): + self.files = files + self.max_turn = max_turn + self.break_prob_per_turn = break_prob_per_turn + self.sample_prob = sample_prob + self.endless = endless + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + rand = random.Random(os.urandom(32)) + else: + rand = random.Random(os.urandom(32)+ "#SgfDataset#".encode() + str(worker_info.id).encode()) + + files = self.files + cpudevice = torch.device("cpu") + + try: + while True: + rand.shuffle(files) + file_count = 0 + error_count = 0 + print("Iterator beginning reading of files %d / %d" % (file_count, len(files)), flush=True) + for filename in files: + try: + (metadata,setup,moves,rules) = data.load_sgf_moves_exn(filename) + except Exception as e: + error_count += 1 + continue + # Only even 19x19 games! + if metadata.size != 19 or len(setup) != 0 or (metadata.handicap is not None and metadata.handicap != 0): + continue + board = Board(size=metadata.size) + turn_number = 0 + for (pla,loc) in moves: + + if rand.random() < self.sample_prob: + inputs = torch.zeros((8,metadata.size,metadata.size),dtype=torch.float32,device=cpudevice) + result = torch.zeros((3,),dtype=torch.float32,device=cpudevice) + aux = torch.zeros((3,metadata.size,metadata.size),dtype=torch.float32,device=cpudevice) + + (alwaysknownxmin,alwaysknownxmax) = random_subinterval(rand,metadata.size) + (alwaysknownymin,alwaysknownymax) = random_subinterval(rand,metadata.size) + + if alwaysknownxmin <= 0 and alwaysknownxmax >= metadata.size-1 and alwaysknownymin <= 0 and alwaysknownymax >= metadata.size-1: + pass + else: + # Channel 1: On-board + inputs[1,:,:].fill_(1.0) + + num_always_known_poses = 0 + if alwaysknownxmax < 0 or alwaysknownxmin >= metadata.size or alwaysknownymax < 0 or alwaysknownymin >= metadata.size: + num_always_known_poses = 0 + else: + num_always_known_poses = ( + ( min(alwaysknownxmax, metadata.size-1) - max(alwaysknownxmin, 0) + 1) * + ( min(alwaysknownymax, metadata.size-1) - max(alwaysknownymin, 0) + 1) + ) + num_not_always_known_poses = metadata.size * metadata.size - num_always_known_poses + inferenceidx = rand.randint(0,num_not_always_known_poses-1) + + flipx = rand.random() < 0.5 + flipy = rand.random() < 0.5 + swapxy = rand.random() < 0.5 + + idx = 0 + for y in range(metadata.size): + for x in range(metadata.size): + pos = y * metadata.size + x + always_known = (x >= alwaysknownxmin and x <= alwaysknownxmax and y >= alwaysknownymin and y <= alwaysknownymax) + + sx = x + sy = y + if flipx: + sx = metadata.size - sx - 1 + if flipy: + sy = metadata.size - sy - 1 + if swapxy: + tmp = sx + sx = sy + sy = tmp + stone = board.board[board.loc(sx,sy)] + + # Channel 4: Unknown + if idx > inferenceidx and not always_known: + inputs[4,y,x] = 1.0 + # Channel 0: Next inference point + elif idx == inferenceidx and not always_known: + inputs[0,y,x] = 1.0 + result + if stone == Board.BLACK: + result[1] = 1.0 + elif stone == Board.WHITE: + result[2] = 1.0 + else: + result[0] = 1.0 + else: + # Channel 2: Black + if stone == Board.BLACK: + inputs[2,y,x] = 1.0 + # Channel 3: White + elif stone == Board.WHITE: + inputs[3,y,x] = 1.0 + + if stone == Board.BLACK: + aux[1,y,x] = 1.0 + elif stone == Board.WHITE: + aux[2,y,x] = 1.0 + else: + aux[0,y,x] = 1.0 + + if not always_known: + idx += 1 + + assert(idx == num_not_always_known_poses) + + if rand.random() < 0.3: + turn_noise_stdev = 0.0 + reported_turn = turn_number + else: + turn_noise_stdev = (rand.random() ** 2.0) * 100 + reported_turn = turn_number + rand.normalvariate(0.0,turn_noise_stdev) + + # Channel 5: Turn number / 100 + inputs[5,:,:].fill_(reported_turn / 100.0) + # Channel 6: Noise stdev in turn number / 50 + inputs[6,:,:].fill_(turn_noise_stdev / 50.0) + # Channel 7: Source + is_kgs = ("/kgs" in filename) or ("\\KGS" in filename) or ("/KGS" in filename) or ("\\KGS" in filename) + is_fox = ("/fox" in filename) or ("\\fox" in filename) or ("/FOX" in filename) or ("\\FOX" in filename) + if is_kgs: + inputs[7,:,:].fill_(1.0) + elif is_fox: + inputs[7,:,:].fill_(-1.0) + + if rand.random() < 0.5: + if rand.random() < 0.5: + inputs = torch.flip(inputs,[1,2]) + aux = torch.flip(aux,[1,2]) + else: + inputs = torch.flip(inputs,[1]) + aux = torch.flip(aux,[1]) + else: + if rand.random() < 0.5: + inputs = torch.flip(inputs,[2]) + aux = torch.flip(aux,[2]) + else: + pass + + if rand.random() < 0.5: + inputs = torch.transpose(inputs,1,2) + aux = torch.transpose(aux,1,2) + + yield (inputs,result,aux) + + try: + board.play(pla,loc) + except IllegalMoveError as e: + # On illegal move in the SGF, don't attempt to recover, just move on to new game + print("Illegal move, skipping file " + filename + ":" + str(e), flush=True) + break + turn_number += 1 + if turn_number > self.max_turn: + break + if rand.random() < self.break_prob_per_turn: + break + + file_count += 1 + if file_count % 200 == 0: + print("Read through file %d / %d (error count %d)" % (file_count, len(files), error_count), flush=True) + + if not self.endless: + break + + except GeneratorExit: + pass + except Exception as e: + print("EXCEPTION IN GENERATOR: " + str(e)) + traceback.print_exc() + print("---",flush=True) + yield e + + +def save_json(data,filename): + with open(filename,"w") as f: + json.dump(data,f) + f.flush() + os.fsync(f.fileno()) + +def load_json(filename): + with open(filename) as f: + data = json.load(f) + return data + + +if __name__ == '__main__': + + description = """ + Train net to predict Go positions one stone at a time + """ + + parser = argparse.ArgumentParser(description=description) + parser.add_argument('-traindir', help='Dir to write to for recording training results', required=True) + parser.add_argument('-datadirs', help='Directory with sgfs', required=True) + parser.add_argument('-testprop', help='Proportion of data for test', type=float, required=True) + parser.add_argument('-lr-scale', help='LR multiplier', type=float, required=False) + parser.add_argument('-channels', help='Channels', type=int, required=True) + parser.add_argument('-blocks', help='Blocks', type=int, required=True) + parser.add_argument('-grad-clip-scale', help='Gradient clip multiplier', type=float, required=False) + parser.add_argument('-num-data-workers', help='Number of processes for data loading', type=int, required=False) + args = vars(parser.parse_args()) + + traindir = args["traindir"] + datadirs = args["datadirs"] + testprop = args["testprop"] + lr_scale = args["lr_scale"] + num_channels = args["channels"] + num_blocks = args["blocks"] + grad_clip_scale = args["grad_clip_scale"] + num_data_workers = args["num_data_workers"] + logfilemode = "a" + + if lr_scale is None: + lr_scale = 1.0 + if grad_clip_scale is None: + grad_clip_scale = 1.0 + + if num_data_workers is None: + num_data_workers = 0 + + if not os.path.exists(traindir): + os.mkdir(traindir) + + bareformatter = logging.Formatter("%(asctime)s %(message)s") + fh = logging.FileHandler(os.path.join(traindir,"train.log"), mode=logfilemode) + fh.setFormatter(bareformatter) + stdouthandler = logging.StreamHandler(sys.stdout) + stdouthandler.setFormatter(bareformatter) + trainlogger = logging.getLogger("trainlogger") + trainlogger.setLevel(logging.INFO) + trainlogger.addHandler(fh) + trainlogger.addHandler(stdouthandler) + trainlogger.propagate=False + np.set_printoptions(linewidth=150) + def trainlog(s): + trainlogger.info(s) + sys.stdout.flush() + + shuffle_buffer_size = 100000 + + files_found = 0 + trainfiles = [] + testfiles = [] + for datadir in datadirs.split(","): + for parent, subdirs, files in os.walk(datadir): + for name in files: + if name.endswith(".sgf"): + files_found += 1 + if files_found % 10000 == 0: + trainlog("Found %d sgfs..." % files_found) + r = float.fromhex("0."+hashlib.md5(os.path.join(parent,name).encode()).hexdigest()[:16]) + if r < testprop: + testfiles.append(os.path.join(parent,name)) + else: + trainfiles.append(os.path.join(parent,name)) + + trainlog("Found %d training sgfs" % len(trainfiles)) + trainlog("Found %d testing sgfs" % len(testfiles)) + + max_turn = 300 + break_prob_per_turn = 0.01 + + traindataset = ShuffledDataset(SgfDataset(trainfiles,max_turn,break_prob_per_turn,sample_prob=0.5,endless=True),shuffle_buffer_size) + testdataset = SgfDataset(testfiles,max_turn,break_prob_per_turn,sample_prob=0.2,endless=True) + + batch_size = 128 + trainloader = torch.utils.data.DataLoader(traindataset, batch_size=batch_size, shuffle=False, num_workers=num_data_workers, drop_last=True) + testloader = torch.utils.data.DataLoader(testdataset, batch_size=batch_size, shuffle=False, num_workers=num_data_workers, drop_last=True) + + trainlog("Made data loaders") + + samples_per_epoch = 400000 + samples_per_test = 25600 + batches_per_epoch = samples_per_epoch // batch_size + batches_per_test = samples_per_test // batch_size + + def lossfunc(inputs, results, preds, aux, auxpreds): + assert(preds.size()[1] == 3) + assert(auxpreds.size()[1] == 3) + main_loss = -torch.sum(results * F.log_softmax(preds,dim=1)) + aux_loss = -torch.sum(aux * F.log_softmax(auxpreds,dim=1) * inputs[:,4:5,:,:] / torch.sum(inputs[:,1:2,:,:], dim=[2,3], keepdim=True)) * 0.3 + return main_loss, aux_loss + + cpudevice = torch.device("cpu") + if torch.cuda.is_available(): + trainlog("CUDA is available, using it") + gpudevice = torch.device("cuda:0") + else: + gpudevice = cpudevice + + modelpath = os.path.join(traindir,"model.data") + optimpath = os.path.join(traindir,"optim.data") + traindatapath = os.path.join(traindir,"traindata.json") + if os.path.exists(modelpath): + trainlog("Loading preexisting model!") + model = Model.load_from_file(modelpath).to(gpudevice) + if model.num_channels != num_channels: + raise Exception("Number of channels in model is %d but command line arg was %d" % (model.num_channels,num_channels)) + if model.num_blocks != num_blocks: + raise Exception("Number of blocks in model is %d but command line arg was %d" % (model.num_blocks,num_blocks)) + optimizer = optim.SGD(model.parameters(), lr=0.00001*lr_scale, momentum=0.9) + optimizer.load_state_dict(torch.load(optimpath)) + traindata = load_json(traindatapath) + else: + model = Model(num_channels=num_channels, num_blocks=num_blocks).to(gpudevice) + optimizer = optim.SGD(model.parameters(), lr=0.00001*lr_scale, momentum=0.9) + traindata = {"samples_so_far":0, "batches_so_far":0} + + trainlog("Saving!") + model.save_to_file(modelpath) + torch.save(optimizer.state_dict(), optimpath) + save_json(traindata,traindatapath) + + grad_clip_max = 400 * grad_clip_scale + #Loosen gradient clipping as we shift to smaller learning rates + grad_clip_max = grad_clip_max / math.sqrt(lr_scale) + + running_batch_count = 0 + running_main_loss = 0.0 + running_aux_loss = 0.0 + running_gnorm = 0.0 + running_ewms_exgnorm = 0.0 + print_every_batches = 100 + trainiter = iter(trainloader) + testiter = iter(testloader) + while True: + for i in range(batches_per_epoch): + inputs, results, auxs = next(trainiter) + inputs = inputs.to(gpudevice) + results = results.to(gpudevice) + auxs = auxs.to(gpudevice) + + optimizer.zero_grad() + + preds, auxpreds = model(inputs) + main_loss,aux_loss = lossfunc(inputs, results, preds, auxs, auxpreds) + loss = main_loss + aux_loss + loss.backward() + gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_max) + optimizer.step() + + traindata["samples_so_far"] += batch_size + traindata["batches_so_far"] += 1 + + running_batch_count += 1 + running_main_loss += main_loss.item() + running_aux_loss += aux_loss.item() + running_gnorm += gnorm + running_ewms_exgnorm += max(0.0, gnorm - grad_clip_max) + if running_batch_count >= print_every_batches: + trainlog("TRAIN samples: %d, batches: %d, main loss: %.5f, aux loss: %.5f, gnorm: %.2f, ewms_exgnorm: %.3g" % ( + traindata["samples_so_far"], + traindata["batches_so_far"], + running_main_loss / (running_batch_count * batch_size), + running_aux_loss / (running_batch_count * batch_size), + running_gnorm / (running_batch_count), + running_ewms_exgnorm / (running_batch_count), + )) + running_batch_count = 0 + running_main_loss = 0.0 + running_aux_loss = 0.0 + running_gnorm = 0.0 + running_ewms_exgnorm *= 0.5 + + trainlog("Saving!") + model.save_to_file(modelpath) + torch.save(optimizer.state_dict(), optimpath) + save_json(traindata,traindatapath) + + trainlog("Testing!") + test_samples = 0 + test_main_loss = 0.0 + test_aux_loss = 0.0 + with torch.no_grad(): + for i in range(batches_per_test): + inputs, results, auxs = next(testiter) + inputs = inputs.to(gpudevice) + results = results.to(gpudevice) + auxs = auxs.to(gpudevice) + + preds, auxpreds = model(inputs) + main_loss, aux_loss = lossfunc(inputs, results, preds, auxs, auxpreds) + test_samples += batch_size + test_main_loss += main_loss.item() + test_aux_loss += aux_loss.item() + trainlog("TEST samples %d, main loss: %.5f, aux loss %.5f" % (test_samples, test_main_loss / test_samples, test_aux_loss / test_samples)) + + +trainlog('Finished Training')