From 70bf5ac5e55c570bd51c2d814ca03b2e79b48143 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Sun, 6 Nov 2022 12:37:49 +0100 Subject: [PATCH 1/6] Format AuthInterceptor using black --- bittensor/_axon/__init__.py | 80 ++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/bittensor/_axon/__init__.py b/bittensor/_axon/__init__.py index ea422d0a0a..5ed4302154 100644 --- a/bittensor/_axon/__init__.py +++ b/bittensor/_axon/__init__.py @@ -339,39 +339,39 @@ def check_forward_callback( forward_callback:Callable, synapses:list = []): forward_callback([sample_input], synapses, hotkey='') class AuthInterceptor(grpc.ServerInterceptor): - """ Creates a new server interceptor that authenticates incoming messages from passed arguments. - """ - def __init__(self, key:str = 'Bittensor',blacklist:List = []): - r""" Creates a new server interceptor that authenticates incoming messages from passed arguments. + """Creates a new server interceptor that authenticates incoming messages from passed arguments.""" + + def __init__(self, key: str = "Bittensor", blacklist: List = []): + r"""Creates a new server interceptor that authenticates incoming messages from passed arguments. Args: key (str, `optional`): key for authentication header in the metadata (default= Bittensor) - black_list (Fucntion, `optional`): + black_list (Fucntion, `optional`): black list function that prevents certain pubkeys from sending messages """ super().__init__() - self._valid_metadata = ('rpc-auth-header', key) + self._valid_metadata = ("rpc-auth-header", key) self.nounce_dic = {} - self.message = 'Invalid key' + self.message = "Invalid key" self.blacklist = blacklist + def deny(_, context): context.abort(grpc.StatusCode.UNAUTHENTICATED, self.message) self._deny = grpc.unary_unary_rpc_method_handler(deny) def intercept_service(self, continuation, handler_call_details): - r""" Authentication between bittensor nodes. Intercepts messages and checks them - """ + r"""Authentication between bittensor nodes. Intercepts messages and checks them""" meta = handler_call_details.invocation_metadata - try: - #version checking + try: + # version checking self.version_checking(meta) - #signature checking + # signature checking self.signature_checking(meta) - #blacklist checking + # blacklist checking self.black_list_checking(meta) return continuation(handler_call_details) @@ -380,10 +380,9 @@ def intercept_service(self, continuation, handler_call_details): self.message = str(e) return self._deny - def vertification(self,meta): - r"""vertification of signature in metadata. Uses the pubkey and nounce - """ - variable_length_messages = meta[1].value.split('bitxx') + def vertification(self, meta): + r"""vertification of signature in metadata. Uses the pubkey and nounce""" + variable_length_messages = meta[1].value.split("bitxx") nounce = int(variable_length_messages[0]) pubkey = variable_length_messages[1] message = variable_length_messages[2] @@ -392,48 +391,49 @@ def vertification(self,meta): # Unique key that specifies the endpoint. endpoint_key = str(pubkey) + str(unique_receptor_uid) - - #checking the time of creation, compared to previous messages + + # checking the time of creation, compared to previous messages if endpoint_key in self.nounce_dic.keys(): - prev_data_time = self.nounce_dic[ endpoint_key ] + prev_data_time = self.nounce_dic[endpoint_key] if nounce - prev_data_time > -10: - self.nounce_dic[ endpoint_key ] = nounce + self.nounce_dic[endpoint_key] = nounce - #decrypting the message and verify that message is correct - verification = _keypair.verify( str(nounce) + str(pubkey) + str(unique_receptor_uid), message) + # decrypting the message and verify that message is correct + verification = _keypair.verify( + str(nounce) + str(pubkey) + str(unique_receptor_uid), message + ) else: verification = False else: - self.nounce_dic[ endpoint_key ] = nounce - verification = _keypair.verify( str( nounce ) + str(pubkey) + str(unique_receptor_uid), message) + self.nounce_dic[endpoint_key] = nounce + verification = _keypair.verify( + str(nounce) + str(pubkey) + str(unique_receptor_uid), message + ) return verification - def signature_checking(self,meta): - r""" Calls the vertification of the signature and raises an error if failed - """ + def signature_checking(self, meta): + r"""Calls the vertification of the signature and raises an error if failed""" if self.vertification(meta): pass else: - raise Exception('Incorrect Signature') + raise Exception("Incorrect Signature") - def version_checking(self,meta): - r""" Checks the header and version in the metadata - """ + def version_checking(self, meta): + r"""Checks the header and version in the metadata""" if meta[0] == self._valid_metadata: pass else: - raise Exception('Incorrect Metadata format') + raise Exception("Incorrect Metadata format") - def black_list_checking(self,meta): - r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey - """ - variable_length_messages = meta[1].value.split('bitxx') + def black_list_checking(self, meta): + r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey""" + variable_length_messages = meta[1].value.split("bitxx") pubkey = variable_length_messages[1] - + if self.blacklist == None: pass - elif self.blacklist(pubkey,int(meta[3].value)): - raise Exception('Black listed') + elif self.blacklist(pubkey, int(meta[3].value)): + raise Exception("Black listed") else: pass From 28072463ec431d2d8531b0bac8c81bb7af6f2092 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Sun, 6 Nov 2022 12:53:25 +0100 Subject: [PATCH 2/6] Parse request metadata as key value pairs --- bittensor/_axon/__init__.py | 60 +++++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/bittensor/_axon/__init__.py b/bittensor/_axon/__init__.py index 5ed4302154..c4b030734d 100644 --- a/bittensor/_axon/__init__.py +++ b/bittensor/_axon/__init__.py @@ -345,12 +345,13 @@ def __init__(self, key: str = "Bittensor", blacklist: List = []): r"""Creates a new server interceptor that authenticates incoming messages from passed arguments. Args: key (str, `optional`): - key for authentication header in the metadata (default= Bittensor) - black_list (Fucntion, `optional`): + key for authentication header in the metadata (default = Bittensor) + black_list (Function, `optional`): black list function that prevents certain pubkeys from sending messages """ super().__init__() - self._valid_metadata = ("rpc-auth-header", key) + self._signature_separator = "bitxx" + self._expected_auth_metadata = ("rpc-auth-header", key) self.nounce_dic = {} self.message = "Invalid key" self.blacklist = blacklist @@ -362,7 +363,7 @@ def deny(_, context): def intercept_service(self, continuation, handler_call_details): r"""Authentication between bittensor nodes. Intercepts messages and checks them""" - meta = handler_call_details.invocation_metadata + meta = dict(handler_call_details.invocation_metadata) try: # version checking @@ -380,9 +381,19 @@ def intercept_service(self, continuation, handler_call_details): self.message = str(e) return self._deny - def vertification(self, meta): - r"""vertification of signature in metadata. Uses the pubkey and nounce""" - variable_length_messages = meta[1].value.split("bitxx") + def get_signature(self, meta): + r"""get_signature from the metadata. Raises exception when the signature is missing""" + signature = meta.get("bittensor-signature") + if signature is None: + raise Exception("Request signature missing") + return signature + + def verification(self, meta): + r"""verification of signature in metadata. Uses the pubkey and nounce""" + variable_length_messages = self.get_signature(meta).split( + self._signature_separator + ) + nounce = int(variable_length_messages[0]) pubkey = variable_length_messages[1] message = variable_length_messages[2] @@ -413,27 +424,32 @@ def vertification(self, meta): return verification def signature_checking(self, meta): - r"""Calls the vertification of the signature and raises an error if failed""" - if self.vertification(meta): - pass - else: - raise Exception("Incorrect Signature") + r"""Calls the verification of the signature and raises an error if failed""" + if not self.verification(meta): + raise Exception("Signature mismatch") def version_checking(self, meta): r"""Checks the header and version in the metadata""" - if meta[0] == self._valid_metadata: - pass - else: - raise Exception("Incorrect Metadata format") + (key, expected_value) = self._expected_auth_metadata + provided_value = meta.get(key) + if provided_value is None or provided_value != expected_value: + raise Exception("Unexpected caller metadata") def black_list_checking(self, meta): r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey""" - variable_length_messages = meta[1].value.split("bitxx") + variable_length_messages = self.get_signature(meta).split( + self._signature_separator + ) pubkey = variable_length_messages[1] if self.blacklist == None: - pass - elif self.blacklist(pubkey, int(meta[3].value)): - raise Exception("Black listed") - else: - pass + return + + request_type = meta.get("request_type") + if request_type is None: + raise Exception("Missing request type") + request_type = int(request_type) + + if self.blacklist(pubkey, request_type): + raise Exception("Request type is blacklisted") + From 49c3abde5e3fe49524d71f9d9434fa4487037bda Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Sun, 6 Nov 2022 13:34:31 +0100 Subject: [PATCH 3/6] Use request method to black list calls --- bittensor/_axon/__init__.py | 41 ++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/bittensor/_axon/__init__.py b/bittensor/_axon/__init__.py index c4b030734d..7ce9b1656d 100644 --- a/bittensor/_axon/__init__.py +++ b/bittensor/_axon/__init__.py @@ -363,17 +363,18 @@ def deny(_, context): def intercept_service(self, continuation, handler_call_details): r"""Authentication between bittensor nodes. Intercepts messages and checks them""" - meta = dict(handler_call_details.invocation_metadata) + method = handler_call_details.method + metadata = dict(handler_call_details.invocation_metadata) try: # version checking - self.version_checking(meta) + self.version_checking(metadata) # signature checking - self.signature_checking(meta) + self.signature_checking(metadata) # blacklist checking - self.black_list_checking(meta) + self.black_list_checking(metadata, method) return continuation(handler_call_details) @@ -388,9 +389,9 @@ def get_signature(self, meta): raise Exception("Request signature missing") return signature - def verification(self, meta): + def verification(self, metadata): r"""verification of signature in metadata. Uses the pubkey and nounce""" - variable_length_messages = self.get_signature(meta).split( + variable_length_messages = self.get_signature(metadata).split( self._signature_separator ) @@ -423,32 +424,34 @@ def verification(self, meta): return verification - def signature_checking(self, meta): + def signature_checking(self, metadata): r"""Calls the verification of the signature and raises an error if failed""" - if not self.verification(meta): + if not self.verification(metadata): raise Exception("Signature mismatch") - def version_checking(self, meta): + def version_checking(self, metadata): r"""Checks the header and version in the metadata""" (key, expected_value) = self._expected_auth_metadata - provided_value = meta.get(key) + provided_value = metadata.get(key) if provided_value is None or provided_value != expected_value: raise Exception("Unexpected caller metadata") - def black_list_checking(self, meta): + def black_list_checking(self, metadata, method): r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey""" - variable_length_messages = self.get_signature(meta).split( - self._signature_separator - ) - pubkey = variable_length_messages[1] - if self.blacklist == None: return - request_type = meta.get("request_type") + request_type = { + "/Bittensor/Forward": bittensor.proto.RequestType.FORWARD, + "/Bittensor/Backward": bittensor.proto.RequestType.BACKWARD, + }.get(method) if request_type is None: - raise Exception("Missing request type") - request_type = int(request_type) + raise Exception("Unknown request type") + + variable_length_messages = self.get_signature(metadata).split( + self._signature_separator + ) + pubkey = variable_length_messages[1] if self.blacklist(pubkey, request_type): raise Exception("Request type is blacklisted") From b69c9447f6f454af34c382c4e0a97acdca853b64 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Sun, 6 Nov 2022 13:36:27 +0100 Subject: [PATCH 4/6] Fix request type provided on backward --- bittensor/_receptor/receptor_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bittensor/_receptor/receptor_impl.py b/bittensor/_receptor/receptor_impl.py index d064fb8ef7..821a3f2bdf 100644 --- a/bittensor/_receptor/receptor_impl.py +++ b/bittensor/_receptor/receptor_impl.py @@ -664,7 +664,7 @@ def finalize_stats_and_logs(): ('rpc-auth-header','Bittensor'), ('bittensor-signature',self.sign()), ('bittensor-version',str(bittensor.__version_as_int__)), - ('request_type', str(bittensor.proto.RequestType.FORWARD)), + ('request_type', str(bittensor.proto.RequestType.BACKWARD)), )) asyncio_future.cancel() From 55bc38ea38b93943480288f48cd3bc28f034fa01 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Tue, 8 Nov 2022 20:40:04 +0100 Subject: [PATCH 5/6] Add type hints --- bittensor/_axon/__init__.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/bittensor/_axon/__init__.py b/bittensor/_axon/__init__.py index 7ce9b1656d..e1aab94d74 100644 --- a/bittensor/_axon/__init__.py +++ b/bittensor/_axon/__init__.py @@ -24,7 +24,7 @@ import inspect import time from concurrent import futures -from typing import List, Callable, Optional +from typing import Dict, List, Callable, Optional from bittensor._threadpool import prioritythreadpool import torch @@ -353,14 +353,8 @@ def __init__(self, key: str = "Bittensor", blacklist: List = []): self._signature_separator = "bitxx" self._expected_auth_metadata = ("rpc-auth-header", key) self.nounce_dic = {} - self.message = "Invalid key" self.blacklist = blacklist - def deny(_, context): - context.abort(grpc.StatusCode.UNAUTHENTICATED, self.message) - - self._deny = grpc.unary_unary_rpc_method_handler(deny) - def intercept_service(self, continuation, handler_call_details): r"""Authentication between bittensor nodes. Intercepts messages and checks them""" method = handler_call_details.method @@ -379,17 +373,20 @@ def intercept_service(self, continuation, handler_call_details): return continuation(handler_call_details) except Exception as e: - self.message = str(e) - return self._deny + return grpc.unary_unary_rpc_method_handler( + lambda _, context: context.abort( + grpc.StatusCode.UNAUTHENTICATED, str(e) + ) + ) - def get_signature(self, meta): + def get_signature(self, metadata: Dict[str, str]) -> str: r"""get_signature from the metadata. Raises exception when the signature is missing""" - signature = meta.get("bittensor-signature") + signature = metadata.get("bittensor-signature") if signature is None: raise Exception("Request signature missing") return signature - def verification(self, metadata): + def verification(self, metadata: Dict[str, str]) -> bool: r"""verification of signature in metadata. Uses the pubkey and nounce""" variable_length_messages = self.get_signature(metadata).split( self._signature_separator @@ -424,19 +421,19 @@ def verification(self, metadata): return verification - def signature_checking(self, metadata): + def signature_checking(self, metadata: Dict[str, str]): r"""Calls the verification of the signature and raises an error if failed""" if not self.verification(metadata): raise Exception("Signature mismatch") - def version_checking(self, metadata): + def version_checking(self, metadata: Dict[str, str]): r"""Checks the header and version in the metadata""" (key, expected_value) = self._expected_auth_metadata provided_value = metadata.get(key) if provided_value is None or provided_value != expected_value: raise Exception("Unexpected caller metadata") - def black_list_checking(self, metadata, method): + def black_list_checking(self, metadata: Dict[str, str], method: str): r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey""" if self.blacklist == None: return @@ -455,4 +452,3 @@ def black_list_checking(self, metadata, method): if self.blacklist(pubkey, request_type): raise Exception("Request type is blacklisted") - From 54a9ab035d85c6a9e54c681dcaf41c0b816b7a18 Mon Sep 17 00:00:00 2001 From: Adrian-Stefan Mares Date: Wed, 9 Nov 2022 00:17:43 +0100 Subject: [PATCH 6/6] Refactor signature parsing --- bittensor/_axon/__init__.py | 152 ++++++++++++++++++------------------ 1 file changed, 75 insertions(+), 77 deletions(-) diff --git a/bittensor/_axon/__init__.py b/bittensor/_axon/__init__.py index e1aab94d74..b2f2d30a22 100644 --- a/bittensor/_axon/__init__.py +++ b/bittensor/_axon/__init__.py @@ -24,7 +24,7 @@ import inspect import time from concurrent import futures -from typing import Dict, List, Callable, Optional +from typing import Dict, List, Callable, Optional, Tuple, Union from bittensor._threadpool import prioritythreadpool import torch @@ -350,90 +350,69 @@ def __init__(self, key: str = "Bittensor", blacklist: List = []): black list function that prevents certain pubkeys from sending messages """ super().__init__() - self._signature_separator = "bitxx" - self._expected_auth_metadata = ("rpc-auth-header", key) - self.nounce_dic = {} + self.auth_header_value = key + self.nonces = {} self.blacklist = blacklist - def intercept_service(self, continuation, handler_call_details): - r"""Authentication between bittensor nodes. Intercepts messages and checks them""" - method = handler_call_details.method - metadata = dict(handler_call_details.invocation_metadata) - + def parse_legacy_signature( + self, signature: str + ) -> Union[Tuple[int, str, str, str], None]: + r"""Attempts to parse a signature using the legacy format, using `bitxx` as a separator""" + parts = signature.split("bitxx") + if len(parts) < 4: + return None try: - # version checking - self.version_checking(metadata) - - # signature checking - self.signature_checking(metadata) - - # blacklist checking - self.black_list_checking(metadata, method) - - return continuation(handler_call_details) - - except Exception as e: - return grpc.unary_unary_rpc_method_handler( - lambda _, context: context.abort( - grpc.StatusCode.UNAUTHENTICATED, str(e) - ) - ) - - def get_signature(self, metadata: Dict[str, str]) -> str: - r"""get_signature from the metadata. Raises exception when the signature is missing""" + nonce = int(parts[0]) + parts = parts[1:] + except ValueError: + return None + receptor_uuid, parts = parts[-1], parts[:-1] + message, parts = parts[-1], parts[:-1] + pubkey = "".join(parts) + return (nonce, pubkey, message, receptor_uuid) + + def parse_signature(self, metadata: Dict[str, str]) -> Tuple[int, str, str, str]: + r"""Attempts to parse a signature from the metadata""" signature = metadata.get("bittensor-signature") if signature is None: raise Exception("Request signature missing") - return signature - - def verification(self, metadata: Dict[str, str]) -> bool: - r"""verification of signature in metadata. Uses the pubkey and nounce""" - variable_length_messages = self.get_signature(metadata).split( - self._signature_separator - ) - - nounce = int(variable_length_messages[0]) - pubkey = variable_length_messages[1] - message = variable_length_messages[2] - unique_receptor_uid = variable_length_messages[3] - _keypair = Keypair(ss58_address=pubkey) - - # Unique key that specifies the endpoint. - endpoint_key = str(pubkey) + str(unique_receptor_uid) - - # checking the time of creation, compared to previous messages - if endpoint_key in self.nounce_dic.keys(): - prev_data_time = self.nounce_dic[endpoint_key] - if nounce - prev_data_time > -10: - self.nounce_dic[endpoint_key] = nounce - - # decrypting the message and verify that message is correct - verification = _keypair.verify( - str(nounce) + str(pubkey) + str(unique_receptor_uid), message - ) - else: - verification = False - else: - self.nounce_dic[endpoint_key] = nounce - verification = _keypair.verify( - str(nounce) + str(pubkey) + str(unique_receptor_uid), message - ) - - return verification + parts = self.parse_legacy_signature(signature) + if parts is not None: + return parts + raise Exception("Unknown signature format") + + def check_signature( + self, nonce: int, pubkey: str, signature: str, receptor_uuid: str + ): + r"""verification of signature in metadata. Uses the pubkey and nonce""" + keypair = Keypair(ss58_address=pubkey) + # Build the expected message which was used to build the signature. + message = f"{nonce}{pubkey}{receptor_uuid}" + # Build the key which uniquely identifies the endpoint that has signed + # the message. + endpoint_key = f"{pubkey}:{receptor_uuid}" + + if endpoint_key in self.nonces.keys(): + previous_nonce = self.nonces[endpoint_key] + # Nonces must be strictly monotonic over time. + if nonce - previous_nonce <= -10: + raise Exception("Nonce is too small") + if not keypair.verify(message, signature): + raise Exception("Signature mismatch") + self.nonces[endpoint_key] = nonce + return - def signature_checking(self, metadata: Dict[str, str]): - r"""Calls the verification of the signature and raises an error if failed""" - if not self.verification(metadata): + if not keypair.verify(message, signature): raise Exception("Signature mismatch") + self.nonces[endpoint_key] = nonce def version_checking(self, metadata: Dict[str, str]): r"""Checks the header and version in the metadata""" - (key, expected_value) = self._expected_auth_metadata - provided_value = metadata.get(key) - if provided_value is None or provided_value != expected_value: + provided_value = metadata.get("rpc-auth-header") + if provided_value is None or provided_value != self.auth_header_value: raise Exception("Unexpected caller metadata") - def black_list_checking(self, metadata: Dict[str, str], method: str): + def black_list_checking(self, pubkey: str, method: str): r"""Tries to call to blacklist function in the miner and checks if it should blacklist the pubkey""" if self.blacklist == None: return @@ -445,10 +424,29 @@ def black_list_checking(self, metadata: Dict[str, str], method: str): if request_type is None: raise Exception("Unknown request type") - variable_length_messages = self.get_signature(metadata).split( - self._signature_separator - ) - pubkey = variable_length_messages[1] - if self.blacklist(pubkey, request_type): raise Exception("Request type is blacklisted") + + def intercept_service(self, continuation, handler_call_details): + r"""Authentication between bittensor nodes. Intercepts messages and checks them""" + method = handler_call_details.method + metadata = dict(handler_call_details.invocation_metadata) + + try: + # version checking + self.version_checking(metadata) + + (nonce, pubkey, signature, receptor_uuid) = self.parse_signature(metadata) + + # signature checking + self.check_signature(nonce, pubkey, signature, receptor_uuid) + + # blacklist checking + self.black_list_checking(pubkey, method) + + return continuation(handler_call_details) + + except Exception as e: + message = str(e) + abort = lambda _, ctx: ctx.abort(grpc.StatusCode.UNAUTHENTICATED, message) + return grpc.unary_unary_rpc_method_handler(abort)