Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 26 additions & 60 deletions bittensor/axon.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,27 +466,19 @@ def verify_custom(synapse: MyCustomSynapse):
offered by this method allows developers to tailor the Axon's behavior to specific requirements and
use cases.
"""

# Assert 'forward_fn' has exactly one argument
forward_sig = signature(forward_fn)
assert (
len(list(forward_sig.parameters)) == 1
), "The passed function must have exactly one argument"

# Obtain the class of the first argument of 'forward_fn'
request_class = forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation
try:
first_param = next(iter(forward_sig.parameters.values()))
except StopIteration:
raise ValueError(
"The forward_fn first argument must be a subclass of bittensor.Synapse, but it has no arguments"
)

# Assert that the first argument of 'forward_fn' is a subclass of 'bittensor.Synapse'
param_class = first_param.annotation
assert issubclass(
request_class, bittensor.Synapse
), "The argument of forward_fn must inherit from bittensor.Synapse"

# Obtain the class name of the first argument of 'forward_fn'
request_name = forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation.__name__
param_class, bittensor.Synapse
), "The first argument of forward_fn must inherit from bittensor.Synapse"
request_name = param_class.__name__

# Add the endpoint to the router, making it available on both GET and POST methods
self.router.add_api_route(
Expand All @@ -497,68 +489,42 @@ def verify_custom(synapse: MyCustomSynapse):
)
self.app.include_router(self.router)

# Expected signatures for 'blacklist_fn', 'priority_fn' and 'verify_fn'
blacklist_sig = Signature(
[
Parameter(
"synapse",
Parameter.POSITIONAL_OR_KEYWORD,
annotation=forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation,
)
],
return_annotation=Tuple[bool, str],
)
priority_sig = Signature(
[
Parameter(
"synapse",
Parameter.POSITIONAL_OR_KEYWORD,
annotation=forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation,
)
],
return_annotation=float,
)
verify_sig = Signature(
[
Parameter(
"synapse",
Parameter.POSITIONAL_OR_KEYWORD,
annotation=forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation,
)
],
return_annotation=None,
)

# Check the signature of blacklist_fn, priority_fn and verify_fn if they are provided
expected_params = [
Parameter(
"synapse",
Parameter.POSITIONAL_OR_KEYWORD,
annotation=forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation,
)
]
if blacklist_fn:
blacklist_sig = Signature(
expected_params, return_annotation=Tuple[bool, str]
)
assert (
signature(blacklist_fn) == blacklist_sig
), "The blacklist_fn function must have the signature: blacklist( synapse: {} ) -> Tuple[bool, str]".format(
request_name
)
if priority_fn:
priority_sig = Signature(expected_params, return_annotation=float)
assert (
signature(priority_fn) == priority_sig
), "The priority_fn function must have the signature: priority( synapse: {} ) -> float".format(
request_name
)
if verify_fn:
verify_sig = Signature(expected_params, return_annotation=None)
assert (
signature(verify_fn) == verify_sig
), "The verify_fn function must have the signature: verify( synapse: {} ) -> None".format(
request_name
)

# Store functions in appropriate attribute dictionaries
self.forward_class_types[request_name] = forward_sig.parameters[
list(forward_sig.parameters)[0]
].annotation
self.forward_class_types[request_name] = param_class
self.blacklist_fns[request_name] = blacklist_fn
self.priority_fns[request_name] = priority_fn
self.verify_fns[request_name] = (
Expand All @@ -567,7 +533,7 @@ def verify_custom(synapse: MyCustomSynapse):
self.forward_fns[request_name] = forward_fn

# Parse required hash fields from the forward function protocol defaults
required_hash_fields = request_class.__dict__["__fields__"][
required_hash_fields = param_class.__dict__["__fields__"][
"required_hash_fields"
].default
self.required_hash_fields[request_name] = required_hash_fields
Expand Down