@@ -193,6 +193,8 @@ class OutlinedCall:
193193 def __call__ (self , * args : Any , ** kwargs : Any ) -> ArrayOrContainer :
194194 arg_id_to_arg = _get_arg_id_to_arg (args , kwargs )
195195
196+ from .utils import _verify_is_dag
197+
196198 if __debug__ :
197199 # Add a prefix to the names to distinguish them from any existing
198200 # placeholders
@@ -201,9 +203,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
201203
202204 prefixed_output = _call_with_placeholders (
203205 self .f , args , kwargs , arg_id_to_prefixed_placeholder )
204- unpacked_prefixed_output = pt .transform .Deduplicator ()(
205- pt .make_dict_of_named_arrays (
206- _unpack_output (prefixed_output )))
206+ unpacked_prefixed_output = _verify_is_dag (
207+ pt .transform .Deduplicator ()(
208+ pt .make_dict_of_named_arrays (
209+ _unpack_output (prefixed_output ))))
207210
208211 prefixed_placeholders = frozenset (
209212 arg_id_to_prefixed_placeholder .values ())
@@ -220,9 +223,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayOrContainer:
220223 arg_id_to_placeholder = _get_arg_id_to_placeholder (arg_id_to_arg )
221224
222225 output = _call_with_placeholders (self .f , args , kwargs , arg_id_to_placeholder )
223- unpacked_output = pt .transform .Deduplicator ()(
224- pt .make_dict_of_named_arrays (
225- _unpack_output (output )))
226+ unpacked_output = _verify_is_dag (
227+ pt .transform .Deduplicator ()(
228+ pt .make_dict_of_named_arrays (
229+ _unpack_output (output ))))
226230 if len (unpacked_output ) == 1 and "_" in unpacked_output :
227231 ret_type = pt .function .ReturnType .ARRAY
228232 else :
0 commit comments