diff --git a/circkit/operation.py b/circkit/operation.py index 8bf2a08..ae00103 100644 --- a/circkit/operation.py +++ b/circkit/operation.py @@ -203,6 +203,7 @@ class Operation(metaclass=OperationMeta): PRECOMPUTABLE: bool Whether the operation is precomputable (given the inputs and the parameters). + Equivalent to whether the operation is deterministic. STR_LIMIT: int = 30 Maximum length of the string describing parameters to keep (for @@ -311,7 +312,7 @@ def __call__(self, *incoming, **kwargs): ) return self._circuit.add_const(value) - if self._circuit.CACHE_NODES: + if self._circuit.CACHE_NODES and self.PRECOMPUTABLE: node_cache_key = tuple(node.id for node in incoming) if self.SYMMETRIC: node_cache_key = tuple(sorted(node_cache_key)) @@ -328,7 +329,7 @@ def __call__(self, *incoming, **kwargs): self.after_create_node(node) - if self._circuit.CACHE_NODES: + if self._circuit.CACHE_NODES and self.PRECOMPUTABLE: self._circuit._nodes_cache[cache_key] = node return node diff --git a/tests/test_random.py b/tests/test_random.py new file mode 100644 index 0000000..4794c6c --- /dev/null +++ b/tests/test_random.py @@ -0,0 +1,35 @@ +import random +import pytest + +from circkit.arithmetic import ArithmeticCircuit +from circkit.arithmetic import OptArithmeticCircuit + + +class IntegerRing: + def __init__(self): + self.counter = 999 + + def random_element(self): + self.counter += 1 + return self.counter + + def __call__(self, value): + return int(value) + + +def test_random(): + C = ArithmeticCircuit(base_ring=IntegerRing(), name="TwoRandoms") + a = C.RND()() + b = C.RND()() + C.add_output(a) + C.add_output(b) + out = C.evaluate([]) + assert out == [1000, 1001] + + C = OptArithmeticCircuit(base_ring=IntegerRing(), name="TwoRandomsOpt") + a = C.RND()() + b = C.RND()() + C.add_output(a) + C.add_output(b) + out = C.evaluate([]) + assert out == [1000, 1001]