Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 103 additions & 19 deletions async_substrate_interface/substrate_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import ssl
import time
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from hashlib import blake2b
Expand Down Expand Up @@ -51,29 +52,117 @@


class ScaleObj:
def __new__(cls, value):
if isinstance(value, (dict, str, int)):
return value
return super().__new__(cls)
"""Bittensor representation of Scale Object."""

def __init__(self, value):
self.value = list(value) if isinstance(value, tuple) else value

def __new__(cls, value):
return super().__new__(cls)

def __str__(self):
return f"BittensorScaleType(value={self.value})>"

def __bool__(self):
if self.value:
return True
else:
return False

def __repr__(self):
return repr(self.value)
return repr(f"BittensorScaleType(value={self.value})>")

def __eq__(self, other):
return self.value == other
return self.value == (other.value if isinstance(other, ScaleObj) else other)

def __lt__(self, other):
return self.value < (other.value if isinstance(other, ScaleObj) else other)

def __gt__(self, other):
return self.value > (other.value if isinstance(other, ScaleObj) else other)

def __le__(self, other):
return self.value <= (other.value if isinstance(other, ScaleObj) else other)

def __ge__(self, other):
return self.value >= (other.value if isinstance(other, ScaleObj) else other)

def __add__(self, other):
if isinstance(other, ScaleObj):
return ScaleObj(self.value + other.value)
return ScaleObj(self.value + other)

def __radd__(self, other):
return ScaleObj(other + self.value)

def __sub__(self, other):
if isinstance(other, ScaleObj):
return ScaleObj(self.value - other.value)
return ScaleObj(self.value - other)

def __rsub__(self, other):
return ScaleObj(other - self.value)

def __mul__(self, other):
if isinstance(other, ScaleObj):
return ScaleObj(self.value * other.value)
return ScaleObj(self.value * other)

def __rmul__(self, other):
return ScaleObj(other * self.value)

def __truediv__(self, other):
if isinstance(other, ScaleObj):
return ScaleObj(self.value / other.value)
return ScaleObj(self.value / other)

def __rtruediv__(self, other):
return ScaleObj(other / self.value)

def __floordiv__(self, other):
if isinstance(other, ScaleObj):
return ScaleObj(self.value // other.value)
return ScaleObj(self.value // other)

def __rfloordiv__(self, other):
return ScaleObj(other // self.value)

def __mod__(self, other):
if isinstance(other, ScaleObj):
return ScaleObj(self.value % other.value)
return ScaleObj(self.value % other)

def __rmod__(self, other):
return ScaleObj(other % self.value)

def __pow__(self, other):
if isinstance(other, ScaleObj):
return ScaleObj(self.value**other.value)
return ScaleObj(self.value**other)

def __rpow__(self, other):
return ScaleObj(other**self.value)

def __getitem__(self, key):
if isinstance(self.value, (list, tuple, dict, str)):
return self.value[key]
raise TypeError(
f"Object of type '{type(self.value).__name__}' does not support indexing"
)

def __iter__(self):
for item in self.value:
yield item
if isinstance(self.value, Iterable):
return iter(self.value)
raise TypeError(f"Object of type '{type(self.value).__name__}' is not iterable")

def __getitem__(self, item):
return self.value[item]
def __len__(self):
return len(self.value)

def serialize(self):
return self.value

def decode(self):
return self.value


class AsyncExtrinsicReceipt:
Expand Down Expand Up @@ -998,10 +1087,7 @@ def __init__(
)
if pre_initialize:
if not _mock:
execute_coroutine(
coroutine=self.initialize(),
event_loop=self.event_loop,
)
self.event_loop.create_task(self.initialize())
else:
self.reload_type_registry()

Expand Down Expand Up @@ -3478,9 +3564,9 @@ async def query(
raw_storage_key: Optional[bytes] = None,
subscription_handler=None,
reuse_block_hash: bool = False,
) -> "ScaleType":
) -> Optional[Union["ScaleObj", Any]]:
"""
Queries subtensor. This should only be used when making a single request. For multiple requests,
Queries substrate. This should only be used when making a single request. For multiple requests,
you should use ``self.query_multiple``
"""
block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash)
Expand Down Expand Up @@ -3524,7 +3610,7 @@ async def query_map(
page_size: int = 100,
ignore_decoding_errors: bool = False,
reuse_block_hash: bool = False,
) -> "QueryMapResult":
) -> QueryMapResult:
"""
Iterates over all key-pairs located at the given module and storage_function. The storage
item must be a map.
Expand Down Expand Up @@ -3686,9 +3772,7 @@ def concat_hash_len(key_hasher: str) -> int:
if not ignore_decoding_errors:
raise
item_value = None

result.append([item_key, item_value])

return QueryMapResult(
records=result,
page_size=page_size,
Expand Down
61 changes: 59 additions & 2 deletions tests/unit_tests/test_substrate_interface.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,68 @@
import asyncio
import pytest
from websockets.exceptions import InvalidURI

from async_substrate_interface.substrate_interface import AsyncSubstrateInterface
from async_substrate_interface.substrate_interface import (
AsyncSubstrateInterface,
ScaleObj,
)


@pytest.mark.asyncio
async def test_invalid_url_raises_exception():
"""Test that invalid URI raises an InvalidURI exception."""
with pytest.raises(InvalidURI):
AsyncSubstrateInterface("non_existent_entry_point")
async with AsyncSubstrateInterface("non_existent_entry_point"):
pass


def test_scale_object():
"""Verifies that the instance can be subject to various operations."""
# Preps
inst_int = ScaleObj(100)

# Asserts
assert inst_int + 1 == 101
assert 1 + inst_int == 101
assert inst_int - 1 == 99
assert 101 - inst_int == 1
assert inst_int * 2 == 200
assert 2 * inst_int == 200
assert inst_int / 2 == 50
assert 100 / inst_int == 1
assert inst_int // 2 == 50
assert 1001 // inst_int == 10
assert inst_int % 3 == 1
assert 1002 % inst_int == 2
assert inst_int >= 99
assert inst_int <= 101

# Preps
inst_str = ScaleObj("test")

# Asserts
assert inst_str + "test1" == "testtest1"
assert "test1" + inst_str == "test1test"
assert inst_str * 2 == "testtest"
assert 2 * inst_str == "testtest"
assert inst_str >= "test"
assert inst_str <= "testtest"
assert inst_str[0] == "t"
assert [i for i in inst_str] == ["t", "e", "s", "t"]

# Preps
inst_list = ScaleObj([1, 2, 3])

# Asserts
assert inst_list[0] == 1
assert inst_list[-1] == 3
assert inst_list * 2 == inst_list + inst_list
assert [i for i in inst_list] == [1, 2, 3]
assert inst_list >= [1, 2]
assert inst_list <= [1, 2, 3, 4]
assert len(inst_list) == 3

inst_dict = ScaleObj({"a": 1, "b": 2})
assert inst_dict["a"] == 1
assert inst_dict["b"] == 2
assert [i for i in inst_dict] == ["a", "b"]