diff --git a/bittensor/bittensor.proto b/bittensor/bittensor.proto
index c19a79cbe3..ccfcdace46 100644
--- a/bittensor/bittensor.proto
+++ b/bittensor/bittensor.proto
@@ -157,6 +157,7 @@ message TensorMessage {
     string version = 1;
 
     // Neuron key: [REQUIRED] Ed25519 raw hex encoded public key.
+    // Public key of the calling neuron; the call is addressed to the public key of the target synapse.
     // Links message to calling neuron-account.
     // i.e. b'4c598ff31b68eb6c458c2dc51b25367fa213c566088077f46d93156148429d78'
     // SIZE: 256-bits (32-bytes)
diff --git a/bittensor/dendrite.py b/bittensor/dendrite.py
index d6efa8392c..744e9ba710 100644
--- a/bittensor/dendrite.py
+++ b/bittensor/dendrite.py
@@ -2,7 +2,6 @@
 from bittensor import bittensor_pb2
 from bittensor.serializer import PyTorchSerializer
 import bittensor
-import bittensor
 from loguru import logger
 from typing import List, Tuple, Dict, Optional
 
@@ -97,7 +96,6 @@ class _RemoteModuleCall(torch.autograd.Function):
     # TODO (const) should take multiple input tensors and kwargs.
     @staticmethod
     def forward(ctx, caller: RemoteSynapse, dummy: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
-
         # Save for backward call.
         ctx.caller = caller
diff --git a/examples/mnist/main.py b/examples/mnist/main.py
index ebbfee5075..0eeda910a4 100644
--- a/examples/mnist/main.py
+++ b/examples/mnist/main.py
@@ -109,7 +109,6 @@ def train(epoch, global_step):
 
         # Flatten mnist inputs
         inputs = torch.flatten(data, start_dim=1)
-
        # Query the remote network.
         synapses = neuron.synapses() # Returns a list of synapses on the network.
         requests, scores = router.route(inputs, synapses) # routes inputs to network.
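
A note on the `dummy` argument in `_RemoteModuleCall.forward` above: passing a grad-requiring dummy tensor is a common way to keep a torch.autograd.Function attached to the graph even when the real inputs arrive without gradient history, as deserialized remote responses do. A minimal sketch of the pattern, with toy shapes and a stubbed-out remote call; only the `dummy` idea is taken from dendrite.py:

    import torch

    class RemoteCallSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, dummy: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
            ctx.save_for_backward(dummy, inputs)
            # Stand-in for the serialized gRPC round trip to the remote synapse.
            return inputs * 2.0

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            dummy, inputs = ctx.saved_tensors
            # Here dendrite would ship grad_output back over the wire.
            return torch.zeros_like(dummy), grad_output * 2.0

    dummy = torch.tensor(0.0, requires_grad=True)
    x = torch.ones(3)                                  # no grad history, like a decoded response
    RemoteCallSketch.apply(dummy, x).sum().backward()  # backward still reachable via dummy
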
""" - def __init__(self, transformer): - super(Net, self).__init__() + def __init__(self, transformer, ntokens): + super(TransformerSynapse, self).__init__() self.transformer = transformer - + self.ntokens = ntokens + def indef(self): x_def = bittensor_pb2.TensorDef( version = bittensor.__version__, - shape = [-1, 784], - dtype = bittensor_pb2.FLOAT32, + shape = [700, 20], + dtype = bittensor_pb2.INT64, requires_grad = True, ) - return x_def + return [x_def] def outdef(self): y_def = bittensor_pb2.TensorDef( version = bittensor.__version__, - shape = [-1, 10], - dtype = bittensor_pb2.FLOAT32, + shape = [700, 20], + dtype = bittensor_pb2.INT64, requires_grad = True, ) - return y_def + return [y_def] def forward(self, x): - x = x.view(-1, 1, 28, 28) - x = transformer.encode(x) + x = self.transformer.encode(x) + x = torch.flatten(x, start_dim=1) return x def main(hparams): @@ -49,27 +53,22 @@ def main(hparams): batch_size = 20 eval_batch_size = 10 bptt = 35 + log_interval = 10 - dataset = Dataset() + dataset = Dataset(bptt) train_data = dataset.batchify(dataset.train_txt, batch_size) val_data = dataset.batchify(dataset.val_txt, eval_batch_size) test_data = dataset.batchify(dataset.test_txt, eval_batch_size) # Transformer model architecture ntokens = len(dataset.TEXT.vocab.stoi) # the size of vocabulary - emsize = 200 # embedding dimension + emsize = 20 # embedding dimension nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder nhead = 2 # the number of heads in the multiheadattention models dropout = 0.2 # the dropout value transformer = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout) - # Optimizer. - criterion = nn.CrossEntropyLoss() # loss function - lr = 5.0 # learning rate - optimizer = torch.optim.SGD(model.parameters(), lr=lr) - scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95) - # bittensor: # Load bittensor config from hparams. config = bittensor.Config(hparams) @@ -78,119 +77,89 @@ def main(hparams): neuron = bittensor.Neuron(config) # Init a trainable request router. - router = bittensor.Router(x_dim = 784, key_dim = 100, topk = 10) + router = bittensor.Router(x_dim = dataset.bptt * emsize, key_dim = 100, topk = 10) # Build local network. - net = Net() + net = TransformerSynapse(transformer, ntokens) # Subscribe the local network to the network - neuron.subscribe(transformer) + neuron.subscribe(net) # Start the neuron backend. neuron.start() - def train(dataset, transformer): - model.train() # Turn on the train mode + # Optimizer. + criterion = nn.CrossEntropyLoss() # loss function + lr = 0.01 # learning rate + optimizer = torch.optim.SGD(net.parameters(), lr=lr) + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95) + + + def train(dataset, transformer, epoch): + transformer.train() # Turn on the train mode total_loss = 0. + global_step = 0 start_time = time.time() ntokens = len(dataset.TEXT.vocab.stoi) - for batch, i in enumerate( + for batch_idx, i in enumerate( range(0, train_data.size(0) - 1, dataset.bptt)): data, targets = dataset.get_batch(train_data, i) optimizer.zero_grad() - - # data - print (data.shape) - + encodings = net.transformer.encode(data) # Flatten encoder inputs inputs inputs = data.view(-1, bptt, emsize) inputs = torch.flatten(inputs, start_dim=1) # Query the remote network. synapses = neuron.synapses() # Returns a list of synapses on the network. 
@@ -49,27 +53,22 @@ def main(hparams):
     batch_size = 20
     eval_batch_size = 10
     bptt = 35
+    log_interval = 10
 
-    dataset = Dataset()
+    dataset = Dataset(bptt)
     train_data = dataset.batchify(dataset.train_txt, batch_size)
     val_data = dataset.batchify(dataset.val_txt, eval_batch_size)
     test_data = dataset.batchify(dataset.test_txt, eval_batch_size)
 
     # Transformer model architecture
     ntokens = len(dataset.TEXT.vocab.stoi) # the size of vocabulary
-    emsize = 200 # embedding dimension
+    emsize = 20 # embedding dimension
     nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
     nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
     nhead = 2 # the number of heads in the multiheadattention models
     dropout = 0.2 # the dropout value
     transformer = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout)
 
-    # Optimizer.
-    criterion = nn.CrossEntropyLoss() # loss function
-    lr = 5.0 # learning rate
-    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
-
     # bittensor:
     # Load bittensor config from hparams.
     config = bittensor.Config(hparams)
@@ -78,119 +77,89 @@ def main(hparams):
     neuron = bittensor.Neuron(config)
 
     # Init a trainable request router.
-    router = bittensor.Router(x_dim = 784, key_dim = 100, topk = 10)
+    router = bittensor.Router(x_dim = dataset.bptt * emsize, key_dim = 100, topk = 10)
 
     # Build local network.
-    net = Net()
+    net = TransformerSynapse(transformer, ntokens)
 
     # Subscribe the local network to the network
-    neuron.subscribe(transformer)
+    neuron.subscribe(net)
 
     # Start the neuron backend.
     neuron.start()
 
-    def train(dataset, transformer):
-        model.train() # Turn on the train mode
+    # Optimizer.
+    criterion = nn.CrossEntropyLoss() # loss function
+    lr = 0.01 # learning rate
+    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
+
+
+    def train(dataset, transformer, epoch):
+        transformer.train() # Turn on the train mode
         total_loss = 0.
+        global_step = 0
         start_time = time.time()
         ntokens = len(dataset.TEXT.vocab.stoi)
-        for batch, i in enumerate(
+        for batch_idx, i in enumerate(
                 range(0, train_data.size(0) - 1, dataset.bptt)):
             data, targets = dataset.get_batch(train_data, i)
             optimizer.zero_grad()
-
-            # data
-            print (data.shape)
-
+            encodings = net.transformer.encode(data)
             # Flatten encoder inputs
             inputs = data.view(-1, bptt, emsize)
             inputs = torch.flatten(inputs, start_dim=1)
 
             # Query the remote network.
             synapses = neuron.synapses() # Returns a list of synapses on the network.
-            requests, scores = router.route(inputs, synapses) # routes inputs to network.
+            requests, scores = router.route(inputs.float(), synapses) # routes inputs to network.
+
+            # Convert request indices back to type long()
+            request_list = [*requests]
+            request_list[0] = requests[0].type(torch.LongTensor)
+            requests = tuple(request_list)
             responses = neuron(requests, synapses) # Makes network calls.
+            remote = router.join(responses) # Joins responses based on scores.
-
-            # Encode sequence inputs.
-            encodings = transformer.encode(data) # (seq_len, batch_size, embedding_size)
+            local = net(inputs)
+
+            # Train.
+            output = local + remote
-
-
-            # Get nodes from metagraph.
-            # and map nodes to torch keys.
-            axons = neuron.axons() # List[bittensor_pb2.Node]))
-            keys = keymap.toKeys(axons) # (-1, key_dim)
-
-            # Learning a map from the gate_inputs to keys
-            # gates[i, j] = score for the jth key for input i
-            gate_inputs = encodings.view(
-                batch_size, x_dim) # (batch_size, seq_len * embedding_size)
-            gates = gate(gate_inputs, keys, topk=min(len(keys), topk))
-
-            # Dispatch data to inputs for each key.
-            # when gates[i, j] == 0, the key j does not recieve input i
-            dispatch_inputs = data.view(batch_size,
-                                        -1) # (batch_size, sequence_length)
-            dispatch = dispatcher.dispatch(dispatch_inputs,
-                                           gates) # List[(-1, seq_len)]
-
-            # Query the network by mapping from keys to node endpoints.
-            # results = list[torch.Tensor], len(results) = len(keys)
-            axons = keymap.toAxons(keys) # List[bittensor_pb2.Node]
-            query = neuron(dispatch, axons) # List[(-1, embedding_size)]
-
-            # Join results using gates to combine inputs.
-            results = dispatcher.combine(
-                query, gates) # (batch_size, seq_len * embedding_size)
-
-            # Decode responses.
-            results = results.view(
-                -1, batch_size,
-                emsize) # (seq_len, batch_size, embedding_size)
-            to_decode = results + encodings
-            output = model.decode(
-                to_decode) # (target_len, batch_size, embedding_size)
-
-            # Loss and optimizer step
+            output = net.transformer.decode(output.view(-1, batch_size, emsize) + encodings)
+
             loss = criterion(output.view(-1, ntokens), targets)
             loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
             optimizer.step()
-
-            # Update bittensor weights
-            weights = neuron.getweights(axons)
-            weights = (0.95) * weights + (0.05) * torch.mean(gates, dim=0)
-            neuron.setweights(axons, weights)
-
-            total_loss += loss.item()
-            log_interval = 1
-            if batch % log_interval == 0 and batch > 0:
-                cur_loss = total_loss / log_interval
-                elapsed = time.time() - start_time
-                print('| epoch {:3d} | {:5d}/{:5d} batches | '
-                      'lr {:02.2f} | ms/batch {:5.2f} | '
-                      'loss {:5.2f} | ppl {:8.2f}'.format(
-                          epoch, batch,
-                          len(train_data) // dataset.bptt,
-                          scheduler.get_lr()[0], elapsed * 1000 / log_interval,
-                          cur_loss, math.exp(cur_loss)))
-                total_loss = 0
-                start_time = time.time()
-
+            global_step += 1
+
+            # Set network weights.
+            weights = neuron.getweights(synapses)
+            weights = (0.99) * weights + 0.01 * torch.mean(scores, dim=0)
+            neuron.setweights(synapses, weights)
+
+            if batch_idx % log_interval == 0:
+                logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tnP|nS: {}|{}'.format(
+                    epoch, batch_idx * len(data), train_data.size(0),
+                    100. * batch_idx / train_data.size(0), loss.item(),
+                    len(neuron.metagraph.peers), len(neuron.metagraph.synapses)))
+
+
+    epochs = 10
+    global_step = 0
     for epoch in range(1, epochs + 1):
-        epoch_start_time = time.time()
-        train(dataset, model)
+        #epoch_start_time = time.time()
+        train(dataset, net, epoch)
+        #epoch_end_time = time.time() - epoch_start_time
+        #logger.info('Train Epoch: {} finished in: {} seconds'.format(
+        #    epoch, epoch_end_time))
         scheduler.step()
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--bootstrap',
-                        default='',
-                        type=str,
-                        help='ip address of bootstrap metagraph')
+    parser = bittensor.Config.add_args(parser)
     hparams = parser.parse_args()
     main(hparams)
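
The new weight update in `train` is an exponential moving average of the router's scores into the network weights. The same rule in isolation, with invented numbers, to make the arithmetic concrete:

    import torch

    weights = torch.tensor([0.5, 0.5])               # current weights for two synapses
    scores = torch.tensor([[0.9, 0.1],
                           [0.7, 0.3]])              # per-example router scores
    weights = 0.99 * weights + 0.01 * scores.mean(dim=0)
    print(weights)                                   # tensor([0.5030, 0.4970])

Each batch nudges the weights one percent of the way toward the batch-mean scores, so short-lived routing noise is smoothed out before `neuron.setweights` publishes them.
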
diff --git a/examples/rehoboam/requirements.txt b/examples/rehoboam/requirements.txt
index f6f6e574d8..3d62b43057 100644
--- a/examples/rehoboam/requirements.txt
+++ b/examples/rehoboam/requirements.txt
@@ -9,3 +9,4 @@ nltk
 sentencepiece
 timeloop
 pycryptodome
+torchtext
diff --git a/examples/rehoboam/transformer.py b/examples/rehoboam/transformer.py
index 03dc51e861..5afbe585e8 100644
--- a/examples/rehoboam/transformer.py
+++ b/examples/rehoboam/transformer.py
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
+import bittensor
 from torch.nn import TransformerEncoder, TransformerEncoderLayer
 
 class TransformerModel(bittensor.Synapse):
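
Both examples now subclass `bittensor.Synapse`, and the serving code above calls only `indef`, `outdef`, and `forward` on them. A sketch of that implied surface, inferred from this diff rather than copied from the library:

    from typing import List
    import torch
    from bittensor import bittensor_pb2

    class SynapseSurface(torch.nn.Module):
        """What TransformerSynapse and TransformerModel appear to promise."""

        def indef(self) -> List[bittensor_pb2.TensorDef]:
            raise NotImplementedError   # wire format of accepted inputs

        def outdef(self) -> List[bittensor_pb2.TensorDef]:
            raise NotImplementedError   # wire format of produced outputs

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            raise NotImplementedError   # the computation served to remote callers
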