Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions conda/dali_python_bindings/recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ requirements:
- future
- astunparse >=1.6.0
- gast >=0.3.3
- dm-tree >=0.1.8
- optree
- packaging
- nvtx
- makefun
Expand All @@ -100,7 +100,7 @@ requirements:
- future
- astunparse >=1.6.0
- gast >=0.3.3
- dm-tree >=0.1.8
- optree
- packaging
- nvtx
- makefun
Expand Down
24 changes: 12 additions & 12 deletions dali/python/nvidia/dali/_conditionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,19 @@

from enum import Enum

import tree
import optree


def _data_node_repr(data_node):
return f"DataNode(name={data_node.name}, device={data_node.device}, source={data_node.source})"


def _map_structure(func, *structures, **kwargs):
"""Custom wrapper over tree.map_structure that filters it out from the user-visible stack trace
"""Custom wrapper over optree.tree_map that filters it out from the user-visible stack trace
for error reporting purposes.
"""
with _autograph.CustomModuleFilter(tree):
return tree.map_structure(func, *structures, **kwargs)
with _autograph.CustomModuleFilter(optree):
return optree.tree_map(func, *structures, **kwargs)


class _Branch(Enum):
Expand Down Expand Up @@ -610,14 +610,14 @@ def if_stmt(self, cond, body, orelse, get_state, set_state, symbol_names, nouts)
" same set of keys, the values may be different.\n"
)

try:
tree.assert_same_structure(body_outputs, orelse_outputs, check_types=True)
except ValueError as e:
# Suppress the original exception, add DALI explanation at the beginning,
# raise the full error message.
raise ValueError(err_msg + str(e)) from None
except TypeError as e:
raise TypeError(err_msg + str(e)) from None
body_structure = optree.tree_structure(body_outputs)
orelse_structure = optree.tree_structure(orelse_outputs)
if body_structure != orelse_structure:
raise ValueError(
f"{err_msg}\n"
f"'If' output structure: {optree.tree_map(lambda _: '*', body_outputs)}\n"
f"'Else' output structure: {optree.tree_map(lambda _: '*', orelse_outputs)}"
)

def merge_branches(new_body_val, new_orelse_val):
logging.log(
Expand Down
4 changes: 2 additions & 2 deletions dali/python/nvidia/dali/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# pylint: disable=no-member
import sys
import threading
import tree
import optree
import warnings
import weakref
from itertools import count
Expand Down Expand Up @@ -853,7 +853,7 @@ def _promote_scalar_constant(value, input_device):
dev = get_input_device(schema, idx)
# Process the single ScalarConstant or list possibly containing ScalarConstants
# and promote each of them into a DataNode
inp = tree.map_structure(lambda val: _promote_scalar_constant(val, dev), inp)
inp = optree.tree_map(lambda val: _promote_scalar_constant(val, dev), inp)

inputs[idx] = inp
return inputs
Expand Down
2 changes: 1 addition & 1 deletion dali/python/setup.py.in
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ For more details please check the
# the latest astunparse (1.6.3) doesn't work with any other six than
# 1.16 or later on python 3.12 due to import six.moves
'six >= 1.16, <= 1.17',
'dm-tree <= 0.1.9; python_version>="3.10"',
'optree',
'packaging <= 25.0',
'numpy',
'nvtx',
Expand Down
9 changes: 3 additions & 6 deletions dali/test/python/conditionals/test_nests.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,9 @@ def pipeline():
"*Divergent data found in different branches of `if/else` control"
" flow statement. Variables in all code paths are merged into common"
" output batches. The values assigned to a given variable need to"
" have the same nesting structure in every code path"
" (both `if` branches).*"
"*The two structures don't have the same nested structure*"
"*The two dictionaries don't have the same set of keys."
" First structure has keys type=list str=*'out', 'mismatched'*,"
" while second structure has keys type=list str=*'out'*"
" have the same nesting structure in every code path (both `if` branches).*"
"'If' output structure:*"
"'Else' output structure:*"
),
):
_ = pipeline()
4 changes: 2 additions & 2 deletions dali/test/python/operator_2/test_enum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nvidia.dali import fn, pipeline_def, types

import numpy as np
import tree
import optree

from nose_utils import assert_raises
from nose2.tools import params
Expand All @@ -31,7 +31,7 @@
lambda value, dtype: types.Constant(value=value),
# Explicit type when passed the underlying numeric value of the enum
lambda value, dtype: types.Constant(
value=tree.map_structure(lambda v: v.value, value), dtype=dtype
value=optree.tree_map(lambda v: v.value, value), dtype=dtype
),
]
)
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ ${LIBRARY_PATH}
RUN ln -s /opt/python/cp${PYV}* /opt/python/v

# install Python bindings and patch it to use the clang we have here
RUN pip install future setuptools wheel clang==14.0 flake8 bandit astunparse gast dm-tree "black[jupyter]"==25.12.0 nvtx makefun && \
RUN pip install future setuptools wheel clang==14.0 flake8 bandit astunparse gast optree numpy "black[jupyter]"==25.12.0 nvtx makefun && \
PY_CLANG_PATH=$(echo $(pip show clang) | sed 's/.*Location: \(.*\) Requires.*/\1/')/clang/cindex.py && \
LIBCLANG_PATH=/usr/lib64/libclang.so && \
sed -i "s|library_file = None|library_file = \"${LIBCLANG_PATH}\"|" ${PY_CLANG_PATH} && \
Expand Down
2 changes: 1 addition & 1 deletion qa/TL1_custom_src_pattern_build/test.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash -e

pip_packages='astunparse gast dm-tree black nvtx makefun'
pip_packages='astunparse gast optree black nvtx makefun'

build_and_check() {
make -j
Expand Down