1 change: 1 addition & 0 deletions bittensor/bittensor.proto

@@ -157,6 +157,7 @@ message TensorMessage {
     string version = 1;
 
     // Neuron key: [REQUIRED] Ed25519 raw hex encoded public key.
+    // Public key of the caller, used when making calls to the public key of a synapse.
     // Links message to calling neuron-account.
     // i.e. b'4c598ff31b68eb6c458c2dc51b25367fa213c566088077f46d93156148429d78'
     // SIZE: 256-bits (32-bytes)
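The new comment documents the caller-identity field on TensorMessage. A minimal sketch of how a calling neuron might stamp an outgoing message, assuming the generated bittensor_pb2 bindings and a field named public_key (the field name sits outside this hunk, so it is an assumption):

import bittensor
from bittensor import bittensor_pb2

def stamp_message(caller_public_hex: str) -> bittensor_pb2.TensorMessage:
    # Hypothetical helper: attach the caller's hex-encoded Ed25519 public
    # key (32 raw bytes, 64 hex chars) so the synapse can link the request
    # to the calling neuron-account. Field name `public_key` is assumed.
    return bittensor_pb2.TensorMessage(
        version = bittensor.__version__,
        public_key = caller_public_hex,
    )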
2 changes: 0 additions & 2 deletions bittensor/dendrite.py

@@ -2,7 +2,6 @@
 from bittensor import bittensor_pb2
 from bittensor.serializer import PyTorchSerializer
 import bittensor
-import bittensor
 
 from loguru import logger
 from typing import List, Tuple, Dict, Optional
@@ -97,7 +96,6 @@ class _RemoteModuleCall(torch.autograd.Function):
     # TODO (const) should take multiple input tensors and kwargs.
     @staticmethod
     def forward(ctx, caller: RemoteSynapse, dummy: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
-
         # Save for backward call.
         ctx.caller = caller
 
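For context, _RemoteModuleCall follows the standard torch.autograd.Function pattern: state needed by the backward pass is stashed on ctx during forward, here a handle to the RemoteSynapse caller. A minimal self-contained sketch of that pattern (a toy scale op, not bittensor's actual RPC logic):

import torch

class _ScaleByCaller(torch.autograd.Function):
    """Toy autograd.Function: saves non-tensor state on ctx in forward and
    reads it back in backward, the same pattern _RemoteModuleCall uses to
    keep a handle to its caller."""

    @staticmethod
    def forward(ctx, scale: float, inputs: torch.Tensor) -> torch.Tensor:
        ctx.scale = scale              # non-tensor state goes on ctx directly
        ctx.save_for_backward(inputs)  # tensors go through save_for_backward
        return inputs * scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # One gradient per forward argument; `scale` is non-differentiable.
        return None, grad_output * ctx.scale

x = torch.ones(3, requires_grad=True)
_ScaleByCaller.apply(2.0, x).sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))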
1 change: 0 additions & 1 deletion examples/mnist/main.py

@@ -109,7 +109,6 @@ def train(epoch, global_step):
 
         # Flatten mnist inputs
         inputs = torch.flatten(data, start_dim=1)
-
        # Query the remote network.
         synapses = neuron.synapses() # Returns a list of synapses on the network.
         requests, scores = router.route(inputs, synapses) # routes inputs to network.
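The removed blank line sits between the flatten step and the remote query. For reference, the flatten maps an MNIST batch of shape (N, 1, 28, 28) to the (N, 784) matrix that the router in this example appears to be configured for; a quick self-contained shape check:

import torch

data = torch.randn(64, 1, 28, 28)          # a batch of MNIST images
inputs = torch.flatten(data, start_dim=1)  # keep the batch dim, merge the rest
assert inputs.shape == (64, 784)           # 784-wide rows routed to the network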
165 changes: 67 additions & 98 deletions examples/rehoboam/main.py
@@ -8,39 +8,43 @@
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 from transformer import TransformerModel
+from dataset import Dataset
+from loguru import logger
 
 
 class TransformerSynapse(bittensor.Synapse):
-    """ An bittensor endpoint trained on 28, 28 pixel images to detect handwritten characters.
+    """ A bittensor endpoint trained on a wiki corpus.
     """
-    def __init__(self, transformer):
-        super(Net, self).__init__()
+    def __init__(self, transformer, ntokens):
+        super(TransformerSynapse, self).__init__()
         self.transformer = transformer
-
+        self.ntokens = ntokens
+
     def indef(self):
         x_def = bittensor_pb2.TensorDef(
                     version = bittensor.__version__,
-                    shape = [-1, 784],
-                    dtype = bittensor_pb2.FLOAT32,
+                    shape = [700, 20],
+                    dtype = bittensor_pb2.INT64,
                     requires_grad = True,
                 )
-        return x_def
+        return [x_def]
 
     def outdef(self):
         y_def = bittensor_pb2.TensorDef(
                     version = bittensor.__version__,
-                    shape = [-1, 10],
-                    dtype = bittensor_pb2.FLOAT32,
+                    shape = [700, 20],
+                    dtype = bittensor_pb2.INT64,
                     requires_grad = True,
                 )
-        return y_def
+        return [y_def]
 
     def forward(self, x):
-        x = x.view(-1, 1, 28, 28)
-        x = transformer.encode(x)
+        x = self.transformer.encode(x)
         x = torch.flatten(x, start_dim=1)
         return x
 
 def main(hparams):
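Interface note on this hunk: indef() and outdef() now return lists of TensorDef rather than bare objects, and the defs describe fixed-shape (700, 20) INT64 token tensors instead of flat FLOAT32 rows. A rough client-side compatibility check one could build on these defs (the helper is illustrative, not part of the PR):

import torch
from bittensor import bittensor_pb2

_TORCH_DTYPES = {
    bittensor_pb2.FLOAT32: torch.float32,
    bittensor_pb2.INT64: torch.int64,
}

def matches(tensor, tdef):
    # A def dimension of -1 is a wildcard; otherwise sizes must agree.
    if len(tdef.shape) != tensor.dim():
        return False
    shape_ok = all(want in (-1, have) for want, have in zip(tdef.shape, tensor.shape))
    return shape_ok and tensor.dtype == _TORCH_DTYPES[tdef.dtype]

tdef = bittensor_pb2.TensorDef(shape=[700, 20], dtype=bittensor_pb2.INT64, requires_grad=True)
tokens = torch.zeros(700, 20, dtype=torch.int64)
assert matches(tokens, tdef)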
@@ -49,27 +53,22 @@ def main(hparams):
     batch_size = 20
     eval_batch_size = 10
     bptt = 35
+    log_interval = 10
 
-    dataset = Dataset()
+    dataset = Dataset(bptt)
     train_data = dataset.batchify(dataset.train_txt, batch_size)
     val_data = dataset.batchify(dataset.val_txt, eval_batch_size)
     test_data = dataset.batchify(dataset.test_txt, eval_batch_size)
 
     # Transformer model architecture
     ntokens = len(dataset.TEXT.vocab.stoi) # the size of vocabulary
-    emsize = 200 # embedding dimension
+    emsize = 20 # embedding dimension
     nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
     nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
     nhead = 2 # the number of heads in the multiheadattention models
     dropout = 0.2 # the dropout value
     transformer = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout)
 
-    # Optimizer.
-    criterion = nn.CrossEntropyLoss() # loss function
-    lr = 5.0 # learning rate
-    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
-
     # bittensor:
     # Load bittensor config from hparams.
     config = bittensor.Config(hparams)
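Dataset now takes bptt at construction, and batchify() slices each corpus split into batch_size parallel token streams. The dataset module itself is not part of this diff, but the standard batchify it presumably mirrors (from the torch transformer tutorial) looks like this:

import torch

def batchify(data: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Trim the 1-D token stream so it divides evenly, then reshape to
    # (stream_len, batch_size): column j is an independent token stream.
    nbatch = data.size(0) // batch_size
    data = data.narrow(0, 0, nbatch * batch_size)
    return data.view(batch_size, -1).t().contiguous()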
@@ -78,119 +77,89 @@ def main(hparams):
     neuron = bittensor.Neuron(config)
 
     # Init a trainable request router.
-    router = bittensor.Router(x_dim = 784, key_dim = 100, topk = 10)
+    router = bittensor.Router(x_dim = dataset.bptt * emsize, key_dim = 100, topk = 10)
 
     # Build local network.
-    net = Net()
+    net = TransformerSynapse(transformer, ntokens)
 
     # Subscribe the local network to the network
-    neuron.subscribe(transformer)
+    neuron.subscribe(net)
 
     # Start the neuron backend.
     neuron.start()
 
-    def train(dataset, transformer):
-        model.train() # Turn on the train mode
+    # Optimizer.
+    criterion = nn.CrossEntropyLoss() # loss function
+    lr = 0.01 # learning rate
+    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
+
+
+    def train(dataset, transformer, epoch):
+        transformer.train() # Turn on the train mode
         total_loss = 0.
+        global_step = 0
         start_time = time.time()
         ntokens = len(dataset.TEXT.vocab.stoi)
-        for batch, i in enumerate(
+        for batch_idx, i in enumerate(
                 range(0,
                       train_data.size(0) - 1, dataset.bptt)):
             data, targets = dataset.get_batch(train_data, i)
             optimizer.zero_grad()
 
-            # data
-            print (data.shape)
-
+            encodings = net.transformer.encode(data)
+            # Flatten encoder inputs
+            inputs = data.view(-1, bptt, emsize)
+            inputs = torch.flatten(inputs, start_dim=1)
 
             # Query the remote network.
             synapses = neuron.synapses() # Returns a list of synapses on the network.
-            requests, scores = router.route(inputs, synapses) # routes inputs to network.
+            requests, scores = router.route(inputs.float(), synapses) # routes inputs to network.
 
+            # Convert request indices back to type long()
+            request_list = [*requests]
+            request_list[0] = requests[0].type(torch.LongTensor)
+            requests = *request_list,
             responses = neuron(requests, synapses) # Makes network calls.
 
             remote = router.join(responses) # Joins responses based on scores.
 
-            # Encode sequence inputs.
-            encodings = transformer.encode(data) # (seq_len, batch_size, embedding_size)
+            local = net(inputs)
 
+            # Train.
+            output = local + remote
 
-
-            # Get nodes from metagraph.
-            # and map nodes to torch keys.
-            axons = neuron.axons() # List[bittensor_pb2.Node]))
-            keys = keymap.toKeys(axons) # (-1, key_dim)
-
-            # Learning a map from the gate_inputs to keys
-            # gates[i, j] = score for the jth key for input i
-            gate_inputs = encodings.view(
-                batch_size, x_dim) # (batch_size, seq_len * embedding_size)
-            gates = gate(gate_inputs, keys, topk=min(len(keys), topk))
-
-            # Dispatch data to inputs for each key.
-            # when gates[i, j] == 0, the key j does not recieve input i
-            dispatch_inputs = data.view(batch_size,
-                -1) # (batch_size, sequence_length)
-            dispatch = dispatcher.dispatch(dispatch_inputs,
-                gates) # List[(-1, seq_len)]
-
-            # Query the network by mapping from keys to node endpoints.
-            # results = list[torch.Tensor], len(results) = len(keys)
-            axons = keymap.toAxons(keys) # List[bittensor_pb2.Node]
-            query = neuron(dispatch, axons) # List[(-1, embedding_size)]
-
-            # Join results using gates to combine inputs.
-            results = dispatcher.combine(
-                query, gates) # (batch_size, seq_len * embedding_size)
-
-            # Decode responses.
-            results = results.view(
-                -1, batch_size,
-                emsize) # (seq_len, batch_size, embedding_size)
-            to_decode = results + encodings
-            output = model.decode(
-                to_decode) # (target_len, batch_size, embedding_size)
-
+            # Loss and optimizer step
+            output = net.transformer.decode(output.view(-1, batch_size, emsize) + encodings)
 
             loss = criterion(output.view(-1, ntokens), targets)
             loss.backward()
             torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
             optimizer.step()
 
-            # Update bittensor weights
-            weights = neuron.getweights(axons)
-            weights = (0.95) * weights + (0.05) * torch.mean(gates, dim=0)
-            neuron.setweights(axons, weights)
-
             total_loss += loss.item()
-            log_interval = 1
-            if batch % log_interval == 0 and batch > 0:
-                cur_loss = total_loss / log_interval
-                elapsed = time.time() - start_time
-                print('| epoch {:3d} | {:5d}/{:5d} batches | '
-                      'lr {:02.2f} | ms/batch {:5.2f} | '
-                      'loss {:5.2f} | ppl {:8.2f}'.format(
-                          epoch, batch,
-                          len(train_data) // dataset.bptt,
-                          scheduler.get_lr()[0], elapsed * 1000 / log_interval,
-                          cur_loss, math.exp(cur_loss)))
-                total_loss = 0
-                start_time = time.time()
 
+            global_step += 1
+
+            # Set network weights.
+            weights = neuron.getweights(synapses)
+            weights = (0.99) * weights + 0.01 * torch.mean(scores, dim=0)
+            neuron.setweights(synapses, weights)
+
+            if batch_idx % log_interval == 0:
+                logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} \tnP|nS: {}|{}'.format(
+                    epoch, batch_idx * len(data), train_data.size(0),
+                    100. * batch_idx / train_data.size(0), loss.item(), len(neuron.metagraph.peers), len(neuron.metagraph.synapses)))
 
 
     epochs = 10
     global_step = 0
     for epoch in range(1, epochs + 1):
-        epoch_start_time = time.time()
-        train(dataset, model)
+        #epoch_start_time = time.time()
+        train(dataset, net, epoch)
+        #epoch_end_time = time.time() - epoch_start_time
+        #logger.info('Train Epoch: {} finished in: {} seconds'.format(
+        #    epoch, epoch_end_time))
         scheduler.step()
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--bootstrap',
                         default='',
                         type=str,
                         help='ip address of bootstrap metagraph')
     hparams = bittensor.Config.add_args(parser)
     hparams = parser.parse_args()
     main(hparams)
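Two details of the new train loop are easy to miss. router.route() scores float inputs, so the INT64 token batch is cast with .float() for routing and the first routed tensor is cast back to long before the network call; and the link weights are updated as an exponential moving average of the routing scores, so each batch shifts 1% of the mass toward the synapses the router actually chose. Condensed from the loop above, using the same calls and shapes as the diff:

# Routing: float for scoring, long again for the remote call.
inputs = torch.flatten(data.view(-1, bptt, emsize), start_dim=1)
requests, scores = router.route(inputs.float(), synapses)
request_list = [*requests]
request_list[0] = requests[0].type(torch.LongTensor)  # tokens back to int64
requests = (*request_list,)
responses = neuron(requests, synapses)

# Weight update: EMA over mean routing scores. For example, with an old
# weight of 0.5 and a mean score of 0.9, the new weight is
# 0.99 * 0.5 + 0.01 * 0.9 = 0.504.
weights = neuron.getweights(synapses)
weights = 0.99 * weights + 0.01 * torch.mean(scores, dim=0)
neuron.setweights(synapses, weights)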
1 change: 1 addition & 0 deletions examples/rehoboam/requirements.txt

@@ -9,3 +9,4 @@ nltk
 sentencepiece
 timeloop
 pycryptodome
+torchtext
2 changes: 1 addition & 1 deletion examples/rehoboam/transformer.py

@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
+import bittensor
 from torch.nn import TransformerEncoder, TransformerEncoderLayer
 
 class TransformerModel(bittensor.Synapse):
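The import is needed because TransformerModel subclasses bittensor.Synapse, as the context line above shows. By analogy with the call sites in main.py, a skeleton of the model's bittensor-facing surface, with the encode/decode split assumed from those calls rather than taken from this diff:

import torch.nn as nn
import bittensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerModel(bittensor.Synapse):
    # Skeleton only: the real model also carries positional encoding and
    # weight init details from the torch transformer tutorial.
    def __init__(self, ntoken, emsize, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        layer = TransformerEncoderLayer(emsize, nhead, nhid, dropout)
        self.encoder = nn.Embedding(ntoken, emsize)
        self.transformer_encoder = TransformerEncoder(layer, nlayers)
        self.decoder = nn.Linear(emsize, ntoken)

    def encode(self, src):     # called by TransformerSynapse.forward
        return self.transformer_encoder(self.encoder(src))

    def decode(self, memory):  # called by the training loop in main.py
        return self.decoder(memory)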