Skip to content

Commit 13ad2bd

Browse files
committed
leaf array bcast types: better code placement
1 parent 658fadb commit 13ad2bd

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

arraycontext/container/arithmetic.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,15 @@ def {fname}(arg1):
375375
gen(f"return cls({zip_init_args})")
376376

377377
if _bcast_actx_array_type:
378-
all_outer_bcast_type_names = (
379-
outer_bcast_type_names
380-
+ ("*arg1.array_context.array_types",))
378+
ary_types = ("*arg1.array_context.array_types",)
381379
else:
382-
all_outer_bcast_type_names = outer_bcast_type_names
380+
ary_types = ()
383381

384382
gen(f"""
385383
if {bool(outer_bcast_type_names)}: # optimized away
386-
if isinstance(arg2, {tup_str(all_outer_bcast_type_names)}):
384+
if isinstance(arg2,
385+
{tup_str(outer_bcast_type_names
386+
+ ary_types)}):
387387
return cls({bcast_same_cls_init_args})
388388
if {numpy_pred("arg2")}:
389389
result = np.empty_like(arg2, dtype=object)
@@ -400,26 +400,27 @@ def {fname}(arg1):
400400
# {{{ "reverse" binary operators
401401

402402
if reversible:
403-
if _bcast_actx_array_type:
404-
all_outer_bcast_type_names = (
405-
outer_bcast_type_names
406-
+ ("*arg2.array_context.array_types",))
407-
else:
408-
all_outer_bcast_type_names = outer_bcast_type_names
409403
fname = f"_{cls.__name__.lower()}_r{dunder_name}"
410404
bcast_init_args = cls._deserialize_init_arrays_code("arg2", {
411405
key_arg2: _format_binary_op_str(
412406
op_str, "arg1", expr_arg2)
413407
for key_arg2, expr_arg2 in
414408
cls._serialize_init_arrays_code("arg2").items()
415409
})
410+
411+
if _bcast_actx_array_type:
412+
ary_types = ("*arg2.array_context.array_types",)
413+
else:
414+
ary_types = ()
415+
416416
gen(f"""
417417
def {fname}(arg2, arg1):
418418
# assert other.__cls__ is not cls
419419
420420
if {bool(outer_bcast_type_names)}: # optimized away
421421
if isinstance(arg1,
422-
{tup_str(all_outer_bcast_type_names)}):
422+
{tup_str(outer_bcast_type_names
423+
+ ary_types)}):
423424
return cls({bcast_init_args})
424425
if {numpy_pred("arg1")}:
425426
result = np.empty_like(arg1, dtype=object)

0 commit comments

Comments
 (0)