diff --git a/bittensor/axon.py b/bittensor/axon.py
index 63a97c4246..38093c2fde 100644
--- a/bittensor/axon.py
+++ b/bittensor/axon.py
@@ -364,7 +364,6 @@ def __init__(
         self.priority_fns: Dict[str, Optional[Callable]] = {}
         self.forward_fns: Dict[str, Optional[Callable]] = {}
         self.verify_fns: Dict[str, Optional[Callable]] = {}
-        self.required_hash_fields: Dict[str, str] = {}
 
         # Instantiate FastAPI
         self.app = FastAPI()
@@ -566,12 +565,6 @@ def verify_custom(synapse: MyCustomSynapse):
         )  # Use 'default_verify' if 'verify_fn' is None
         self.forward_fns[request_name] = forward_fn
 
-        # Parse required hash fields from the forward function protocol defaults
-        required_hash_fields = request_class.__dict__["model_fields"][
-            "required_hash_fields"
-        ].default
-        self.required_hash_fields[request_name] = required_hash_fields
-
         return self
 
     @classmethod
@@ -696,9 +689,7 @@ async def verify_body_integrity(self, request: Request):
         body = await request.body()
         request_body = body.decode() if isinstance(body, bytes) else body
 
-        # Gather the required field names from the axon's required_hash_fields dict
         request_name = request.url.path.split("/")[1]
-        required_hash_fields = self.required_hash_fields[request_name]
 
         # Load the body dict and check if all required field hashes match
         body_dict = json.loads(request_body)
diff --git a/bittensor/synapse.py b/bittensor/synapse.py
index cfd74e47d3..c2971202c0 100644
--- a/bittensor/synapse.py
+++ b/bittensor/synapse.py
@@ -20,6 +20,8 @@
 import base64
 import json
 import sys
+import typing
+import warnings
 
 from pydantic import (
     BaseModel,
@@ -29,7 +31,7 @@
     model_validator,
 )
 import bittensor
-from typing import Optional, List, Any, Dict
+from typing import Optional, Any, Dict
 
 
 def get_size(obj, seen=None) -> int:
@@ -301,6 +303,8 @@ class Synapse(BaseModel):
     5. Body Hash Computation (``computed_body_hash``, ``required_hash_fields``):
         Ensures data integrity and security by computing hashes of transmitted data.
         Provides users with a mechanism to verify data integrity and detect any tampering during transmission.
+        It is recommended that names of fields in `required_hash_fields` are listed in the order they are
+        defined in the class.
 
     6. Serialization and Deserialization Methods:
         Facilitates the conversion of Synapse objects to and from a format suitable for network transmission.
@@ -478,14 +482,7 @@ def set_name_type(cls, values) -> dict:
         repr=False,
     )
 
-    required_hash_fields: Optional[List[str]] = Field(
-        title="required_hash_fields",
-        description="The list of required fields to compute the body hash.",
-        examples=["roles", "messages"],
-        default=[],
-        frozen=True,
-        repr=False,
-    )
+    required_hash_fields: typing.ClassVar[typing.Tuple[str, ...]] = ()
 
     _extract_total_size = field_validator("total_size", mode="before")(cast_int)
 
@@ -692,21 +689,37 @@ def body_hash(self) -> str:
         Returns:
             str: The SHA3-256 hash as a hexadecimal string, providing a fingerprint of the Synapse instance's data for integrity checks.
         """
-        # Hash the body for verification
         hashes = []
 
-        # Getting the fields of the instance
-        instance_fields = self.model_dump()
+        hash_fields_field = self.model_fields.get("required_hash_fields")
+        instance_fields = None
+        if hash_fields_field:
+            warnings.warn(
+                "The 'required_hash_fields' field handling deprecated and will be removed. "
+                "Please update Synapse class definition to use 'required_hash_fields' class variable instead.",
+                DeprecationWarning,
+            )
+            required_hash_fields = hash_fields_field.default
+
+            if required_hash_fields:
+                instance_fields = self.model_dump()
+                # Backward compatibility: legacy hashes follow model_dump() field order,
+                # not the order in which names appear in `required_hash_fields`.
+                required_hash_fields = [
+                    field for field in instance_fields if field in required_hash_fields
+                ]
+
+                # Cache the names on the class only if every declared name resolved
+                if len(required_hash_fields) == len(hash_fields_field.default):
+                    self.__class__.required_hash_fields = tuple(required_hash_fields)
+        else:
+            required_hash_fields = self.__class__.required_hash_fields
+
+        if required_hash_fields:
+            instance_fields = instance_fields or self.dict()
+            for field in required_hash_fields:
+                hashes.append(bittensor.utils.hash(str(instance_fields[field])))
 
-        for field, value in instance_fields.items():
-            # If the field is required in the subclass schema, hash and add it.
-            if (
-                self.required_hash_fields is not None
-                and field in self.required_hash_fields
-            ):
-                hashes.append(bittensor.utils.hash(str(value)))
-
-        # Hash and return the hashes that have been concatenated
         return bittensor.utils.hash("".join(hashes))
 
     @classmethod
diff --git a/tests/unit_tests/test_synapse.py b/tests/unit_tests/test_synapse.py
index 83063aaf50..6be99520c1 100644
--- a/tests/unit_tests/test_synapse.py
+++ b/tests/unit_tests/test_synapse.py
@@ -16,16 +16,16 @@
 # DEALINGS IN THE SOFTWARE.
 import json
 import base64
-from typing import List, Optional
+import typing
+from typing import Optional
 
-import pydantic_core
 import pytest
 import bittensor
 
 
 def test_parse_headers_to_inputs():
     class Test(bittensor.Synapse):
-        key1: List[int]
+        key1: list[int]
 
     # Define a mock headers dictionary to use for testing
     headers = {
@@ -60,7 +60,7 @@ class Test(bittensor.Synapse):
 
 def test_from_headers():
     class Test(bittensor.Synapse):
-        key1: List[int]
+        key1: list[int]
 
     # Define a mock headers dictionary to use for testing
     headers = {
@@ -131,13 +131,13 @@ class Test(bittensor.Synapse):
         a: int  # Carried through because required.
         b: int = None  # Not carried through headers
         c: Optional[int]  # Required, carried through headers, cannot be None
-        d: Optional[List[int]]  # Required, carried though headers, cannot be None
-        e: List[int]  # Carried through headers
+        d: Optional[list[int]]  # Required, carried though headers, cannot be None
+        e: list[int]  # Carried through headers
         f: Optional[
             int
         ] = None  # Not Required, Not carried through headers, can be None
         g: Optional[
-            List[int]
+            list[int]
         ] = None  # Not Required, Not carried though headers, can be None
 
     # Create an instance of the custom Synapse subclass
@@ -152,12 +152,12 @@ class Test(bittensor.Synapse):
     assert isinstance(synapse, Test)
     assert synapse.name == "Test"
     assert synapse.a == 1
-    assert synapse.b == None
+    assert synapse.b is None
     assert synapse.c == 3
     assert synapse.d == [1, 2, 3, 4]
     assert synapse.e == [1, 2, 3, 4]
-    assert synapse.f == None
-    assert synapse.g == None
+    assert synapse.f is None
+    assert synapse.g is None
 
     # Convert the Test instance to a headers dictionary
     headers = synapse.to_headers()
@@ -169,12 +169,12 @@ class Test(bittensor.Synapse):
     # Create a new Test from the headers and check its properties
     next_synapse = synapse.from_headers(synapse.to_headers())
     assert next_synapse.a == 0  # Default value is 0
-    assert next_synapse.b == None
+    assert next_synapse.b is None
     assert next_synapse.c == 0  # Default is 0
     assert next_synapse.d == []  # Default is []
     assert next_synapse.e == []  # Empty list is default for list types
-    assert next_synapse.f == None
-    assert next_synapse.g == None
+    assert next_synapse.f is None
+    assert next_synapse.g is None
 
 
 def test_body_hash_override():
@@ -189,18 +189,6 @@ def test_body_hash_override():
         synapse_instance.body_hash = []
 
 
-def test_required_fields_override():
-    # Create a Synapse instance
-    synapse_instance = bittensor.Synapse()
-
-    # Try to set the required_hash_fields property and expect a TypeError
-    with pytest.raises(
-        pydantic_core.ValidationError,
-        match="required_hash_fields\n  Field is frozen",
-    ):
-        synapse_instance.required_hash_fields = []
-
-
 def test_default_instance_fields_dict_consistency():
     synapse_instance = bittensor.Synapse()
     assert synapse_instance.dict() == {
@@ -233,5 +221,48 @@ def test_default_instance_fields_dict_consistency():
             "signature": None,
         },
         "computed_body_hash": "",
-        "required_hash_fields": [],
     }
+
+
+class LegacyHashedSynapse(bittensor.Synapse):
+    """Legacy Synapse subclass that serialized `required_hash_fields`."""
+
+    a: int
+    b: int
+    c: Optional[int] = None
+    d: Optional[list[str]] = None
+    required_hash_fields: Optional[list[str]] = ["b", "a", "d"]
+
+
+class HashedSynapse(bittensor.Synapse):
+    a: int
+    b: int
+    c: Optional[int] = None
+    d: Optional[list[str]] = None
+    required_hash_fields: typing.ClassVar[tuple[str, ...]] = ("a", "b", "d")
+
+
+@pytest.mark.parametrize("synapse_cls", [LegacyHashedSynapse, HashedSynapse])
+def test_synapse_body_hash(synapse_cls):
+    synapse_instance = synapse_cls(a=1, b=2, d=["foobar"])
+    assert (
+        synapse_instance.body_hash
+        == "ae06397d08f30f75c91395c59f05c62ac3b62b88250eb78b109213258e6ced0c"
+    )
+
+    # Extra non-hashed values should not influence the body hash
+    synapse_instance_slightly_different = synapse_cls(d=["foobar"], c=3, a=1, b=2)
+    assert synapse_instance.body_hash == synapse_instance_slightly_different.body_hash
+
+    # Even if someone tries to override the required_hash_fields, it should still be the same
+    synapse_instance_try_override_hash_fields = synapse_cls(
+        a=1, b=2, d=["foobar"], required_hash_fields=["a"]
+    )
+    assert (
+        synapse_instance.body_hash
+        == synapse_instance_try_override_hash_fields.body_hash
+    )
+
+    # Different hashed values should result in different body hashes
+    synapse_different = synapse_cls(a=1, b=2)
+    assert synapse_instance.body_hash != synapse_different.body_hash