Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 96 additions & 83 deletions bittensor/_axon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import inspect
import time
from concurrent import futures
from typing import List, Callable, Optional
from typing import Dict, List, Callable, Optional, Tuple, Union
from bittensor._threadpool import prioritythreadpool

import torch
Expand Down Expand Up @@ -339,101 +339,114 @@ def check_forward_callback( forward_callback:Callable, synapses:list = []):
forward_callback([sample_input], synapses, hotkey='')

class AuthInterceptor(grpc.ServerInterceptor):
""" Creates a new server interceptor that authenticates incoming messages from passed arguments.
"""
def __init__(self, key:str = 'Bittensor',blacklist:List = []):
r""" Creates a new server interceptor that authenticates incoming messages from passed arguments.
"""Creates a new server interceptor that authenticates incoming messages from passed arguments."""

def __init__(self, key: str = "Bittensor", blacklist: List = []):
r"""Creates a new server interceptor that authenticates incoming messages from passed arguments.
Args:
key (str, `optional`):
key for authentication header in the metadata (default= Bittensor)
black_list (Fucntion, `optional`):
key for authentication header in the metadata (default = Bittensor)
black_list (Function, `optional`):
black list function that prevents certain pubkeys from sending messages
"""
super().__init__()
self._valid_metadata = ('rpc-auth-header', key)
self.nounce_dic = {}
self.message = 'Invalid key'
self.auth_header_value = key
self.nonces = {}
self.blacklist = blacklist
def deny(_, context):
context.abort(grpc.StatusCode.UNAUTHENTICATED, self.message)

self._deny = grpc.unary_unary_rpc_method_handler(deny)

def intercept_service(self, continuation, handler_call_details):
r""" Authentication between bittensor nodes. Intercepts messages and checks them
"""
meta = handler_call_details.invocation_metadata
def parse_legacy_signature(
self, signature: str
) -> Union[Tuple[int, str, str, str], None]:
r"""Attempts to parse a signature using the legacy format, using `bitxx` as a separator"""
parts = signature.split("bitxx")
if len(parts) < 4:
return None
try:
nonce = int(parts[0])
parts = parts[1:]
except ValueError:
return None
receptor_uuid, parts = parts[-1], parts[:-1]
message, parts = parts[-1], parts[:-1]
pubkey = "".join(parts)
return (nonce, pubkey, message, receptor_uuid)

def parse_signature(self, metadata: Dict[str, str]) -> Tuple[int, str, str, str]:
r"""Attempts to parse a signature from the metadata"""
signature = metadata.get("bittensor-signature")
if signature is None:
raise Exception("Request signature missing")
parts = self.parse_legacy_signature(signature)
if parts is not None:
return parts
raise Exception("Unknown signature format")

def check_signature(
self, nonce: int, pubkey: str, signature: str, receptor_uuid: str
):
r"""verification of signature in metadata. Uses the pubkey and nonce"""
keypair = Keypair(ss58_address=pubkey)
# Build the expected message which was used to build the signature.
message = f"{nonce}{pubkey}{receptor_uuid}"
# Build the key which uniquely identifies the endpoint that has signed
# the message.
endpoint_key = f"{pubkey}:{receptor_uuid}"

if endpoint_key in self.nonces.keys():
previous_nonce = self.nonces[endpoint_key]
# Nonces must be strictly monotonic over time.
if nonce - previous_nonce <= -10:
raise Exception("Nonce is too small")
if not keypair.verify(message, signature):
raise Exception("Signature mismatch")
self.nonces[endpoint_key] = nonce
return

if not keypair.verify(message, signature):
raise Exception("Signature mismatch")
self.nonces[endpoint_key] = nonce

def version_checking(self, metadata: Dict[str, str]):
r"""Checks the header and version in the metadata"""
provided_value = metadata.get("rpc-auth-header")
if provided_value is None or provided_value != self.auth_header_value:
raise Exception("Unexpected caller metadata")

def black_list_checking(self, pubkey: str, method: str):
r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey"""
if self.blacklist == None:
return

try:
#version checking
self.version_checking(meta)
request_type = {
"/Bittensor/Forward": bittensor.proto.RequestType.FORWARD,
"/Bittensor/Backward": bittensor.proto.RequestType.BACKWARD,
}.get(method)
if request_type is None:
raise Exception("Unknown request type")

#signature checking
self.signature_checking(meta)
if self.blacklist(pubkey, request_type):
raise Exception("Request type is blacklisted")

#blacklist checking
self.black_list_checking(meta)
def intercept_service(self, continuation, handler_call_details):
r"""Authentication between bittensor nodes. Intercepts messages and checks them"""
method = handler_call_details.method
metadata = dict(handler_call_details.invocation_metadata)

return continuation(handler_call_details)
try:
# version checking
self.version_checking(metadata)

except Exception as e:
self.message = str(e)
return self._deny
(nonce, pubkey, signature, receptor_uuid) = self.parse_signature(metadata)

def vertification(self,meta):
r"""vertification of signature in metadata. Uses the pubkey and nounce
"""
variable_length_messages = meta[1].value.split('bitxx')
nounce = int(variable_length_messages[0])
pubkey = variable_length_messages[1]
message = variable_length_messages[2]
unique_receptor_uid = variable_length_messages[3]
_keypair = Keypair(ss58_address=pubkey)

# Unique key that specifies the endpoint.
endpoint_key = str(pubkey) + str(unique_receptor_uid)

#checking the time of creation, compared to previous messages
if endpoint_key in self.nounce_dic.keys():
prev_data_time = self.nounce_dic[ endpoint_key ]
if nounce - prev_data_time > -10:
self.nounce_dic[ endpoint_key ] = nounce

#decrypting the message and verify that message is correct
verification = _keypair.verify( str(nounce) + str(pubkey) + str(unique_receptor_uid), message)
else:
verification = False
else:
self.nounce_dic[ endpoint_key ] = nounce
verification = _keypair.verify( str( nounce ) + str(pubkey) + str(unique_receptor_uid), message)
# signature checking
self.check_signature(nonce, pubkey, signature, receptor_uuid)

return verification
# blacklist checking
self.black_list_checking(pubkey, method)

def signature_checking(self,meta):
r""" Calls the vertification of the signature and raises an error if failed
"""
if self.vertification(meta):
pass
else:
raise Exception('Incorrect Signature')

def version_checking(self,meta):
r""" Checks the header and version in the metadata
"""
if meta[0] == self._valid_metadata:
pass
else:
raise Exception('Incorrect Metadata format')
return continuation(handler_call_details)

def black_list_checking(self,meta):
r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey
"""
variable_length_messages = meta[1].value.split('bitxx')
pubkey = variable_length_messages[1]

if self.blacklist == None:
pass
elif self.blacklist(pubkey,int(meta[3].value)):
raise Exception('Black listed')
else:
pass
except Exception as e:
message = str(e)
abort = lambda _, ctx: ctx.abort(grpc.StatusCode.UNAUTHENTICATED, message)
return grpc.unary_unary_rpc_method_handler(abort)
2 changes: 1 addition & 1 deletion bittensor/_receptor/receptor_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def finalize_stats_and_logs():
('rpc-auth-header','Bittensor'),
('bittensor-signature',self.sign()),
('bittensor-version',str(bittensor.__version_as_int__)),
('request_type', str(bittensor.proto.RequestType.FORWARD)),
('request_type', str(bittensor.proto.RequestType.BACKWARD)),
))
asyncio_future.cancel()

Expand Down