Skip to content
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ test-venv
__pycache__

# egg-info
sdp.egg-info
sdp.egg-info

# build
build
4 changes: 3 additions & 1 deletion docs/src/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,14 @@ def setup(app):

nitpick_ignore = [
('py:class', 'abc.ABC'),
('py:class', 'sdp.processors.base_processor.DataEntry'),
('py:class', 'optional'),
('py:mod', 'sdp.utils.apply_operators'),
]
# nitpick_ignore_regex = [('py:class', '*')]

#adding this especially for coraal, temporary
linkcheck_ignore = [
r'https://lingtools\.uoregon\.edu/coraal/coraal_download_list\.txt',
r'https://ieeexplore\.ieee\.org/document/1326009'
]
# https://lingtools.uoregon.edu/coraal/coraal_download_list.txt
3 changes: 3 additions & 0 deletions docs/src/sdp/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ Data modifications
.. autodata:: sdp.processors.InverseNormalizeText
:annotation:

.. autodata:: sdp.processors.LambdaExpression
:annotation:

Data filtering
''''''''''''''

Expand Down
1 change: 1 addition & 0 deletions sdp/processors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
SubIfASRSubstitution,
SubMakeLowercase,
SubRegex,
LambdaExpression,
)
from sdp.processors.modify_manifest.data_to_dropbool import (
DropASRError,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def restore_pc(orig_words, norm_words):
# separately in normalized form, so just removing the comma here
add_punct = ""
if orig_text[idx_orig][0].isdigit() and not orig_text[idx_orig].isdigit():
number, word = re.split('(\d+)', orig_text[idx_orig])[1:]
number, word = re.split(r'(\d+)', orig_text[idx_orig])[1:]
orig_text[idx_orig] = number
if word in string.punctuation:
add_punct = word
Expand All @@ -87,7 +87,7 @@ def restore_pc(orig_words, norm_words):
# another annoying case is if typo ends with number like here "dell'11"
# same logic, but need to go back to the first check, so doing "continue" below
if orig_text[idx_orig][-1].isdigit() and not orig_text[idx_orig].isdigit():
word, number = re.split('(\d+)', orig_text[idx_orig])[:-1]
word, number = re.split(r'(\d+)', orig_text[idx_orig])[:-1]
orig_text[idx_orig] = word
orig_text.insert(idx_orig + 1, number)
continue
Expand Down
105 changes: 98 additions & 7 deletions sdp/processors/modify_manifest/data_to_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,8 @@
from sdp.utils.common import ffmpeg_convert
from sdp.utils.edit_spaces import add_start_end_spaces, remove_extra_spaces
from sdp.utils.get_diff import get_diff_with_subs_grouped
from sdp.utils.metrics_computation import (
get_cer,
get_charrate,
get_wer,
get_wmr,
get_wordrate,
)
from sdp.utils.metrics_computation import get_wer
from sdp.utils.apply_operators import evaluate_expression


class GetAudioDuration(BaseParallelProcessor):
Expand Down Expand Up @@ -1127,3 +1122,99 @@ def process(self):
if self.failed_files:
logger.warning(f"Failed to process {len(self.failed_files)} files.")
logger.debug(f"Failed files: {self.failed_files}")


class LambdaExpression(BaseParallelProcessor):
"""
A dataset processor that evaluates a Python expression on each data entry and either stores
the result in a new field or uses it as a filtering condition.

This processor is useful for dynamic field computation or conditional filtering of entries based
on configurable expressions. It leverages ``evaluate_expression``, which safely evaluates expressions
using the abstract syntax tree (AST).

Filtering behavior:
If ``filter=True``, the expression is evaluated for each entry. Only entries for which the expression evaluates to ``True`` are kept; all others are filtered out (removed from the output).
If ``filter=False``, the result of the expression is stored in the field specified by ``new_field`` for each entry (no filtering occurs).

Examples::

# Example 1: Filtering entries where the duration is greater than 5.0 seconds
LambdaExpression(
new_field="keep", # This field is ignored when filter=True
expression="entry['duration'] > 5.0",
lambda_param_name="entry",
filter=True
)
# Only entries with duration > 5.0 will be kept in the output manifest.

# Example 2: Adding a new field with the number of words in the text
LambdaExpression(
new_field="num_words",
expression="len(entry['text'].split())",
lambda_param_name="entry",
filter=False
)
# Each entry will have a new field 'num_words' with the word count of the 'text' field.

Supported operations:

The expression supports a safe subset of Python operations, including:

- Arithmetic: ``+``, ``-``, ``*``, ``/``, ``//``, ``%``, ``**``
- Comparisons: ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, ``is``, ``is not``
- Logical: ``and``, ``or``, ``not``
- Bitwise: ``|``, ``&``, ``^``, ``~``, ``<<``, ``>>``
- Indexing and slicing: ``entry['key']``, ``entry[0]``, ``entry[1:3]``
- Conditional (ternary) expressions: ``a if cond else b``
- List and dict literals: ``[a, b]``, ``{k: v}``
- Attribute access: ``entry.attr``
- Function calls (limited): ``max``, ``min``, ``len``, ``sum``, ``abs``, ``sorted``

For the full list, see the ``OPERATORS`` and ``SAFE_FUNCTIONS`` in :mod:`sdp.utils.apply_operators`.
See also: https://docs.python.org/3/library/operator.html

Args:
new_field (str): The name of the field to store the result of the expression (ignored if filter=True).
expression (str): A Python expression to evaluate. It can reference fields of the data entry
using the name specified by ``lambda_param_name`` (default: 'entry').
lambda_param_name (str, optional): The name to refer to the current data entry in the expression.
Default is "entry".
filter (bool, optional): If True, the expression result is treated as a condition.
The entry is kept only if the result is ``True``. Default is ``False``.
**kwargs: Additional keyword arguments passed to the ``BaseParallelProcessor`` class.

Returns:
str: A line-delimited JSON manifest, where each line is a processed entry.
The result may contain fewer entries than the input if ``filter=True``.
"""
def __init__(
self,
new_field: str,
expression: str,
lambda_param_name: str = "entry",
filter: bool = False,
**kwargs,
):
super().__init__(**kwargs)
self.new_field = new_field
self.expression = expression
self.lambda_param_name = lambda_param_name
self.filter = filter

def process_dataset_entry(self, data_entry) -> List[DataEntry]:
"""
Process a single data entry by evaluating the expression.

If `filter` is True, the entry is only retained if the expression evaluates to True.
Otherwise, the result is stored in `new_field`.
"""
value = evaluate_expression(self.expression, data_entry, self.lambda_param_name)
if self.filter:
if value is not True:
return []
data_entry[self.new_field] = value
return [DataEntry(data=data_entry)]

def finalize(self, metrics):
super().finalize(metrics)
171 changes: 171 additions & 0 deletions sdp/utils/apply_operators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
import ast
import re
from typing import Any, Dict

"""
This module provides a safe evaluator for simple Python expressions using the abstract syntax tree (AST).
It restricts execution to a subset of safe operations (arithmetic, logical, comparisons, indexing, etc.)
and selected built-in functions (e.g., max, min, len), while preventing arbitrary code execution.

Useful in cases where dynamic expressions need to be evaluated using a provided variable context,
such as configuration systems, data transformation pipelines, or manifest filtering.

Functions:
- evaluate_expression: Safely evaluates a Python expression string using restricted AST operations.
"""

OPERATORS = {
ast.Add: operator.add,
ast.Sub: operator.sub,
ast.Mult: operator.mul,
ast.Div: operator.truediv,
ast.FloorDiv: operator.floordiv,
ast.Mod: operator.mod,
ast.Pow: operator.pow,
ast.BitOr: operator.or_,
ast.BitAnd: operator.and_,
ast.BitXor: operator.xor,
ast.LShift: operator.lshift,
ast.RShift: operator.rshift,
ast.Invert: operator.invert,
ast.USub: operator.neg,
ast.UAdd: operator.pos,
ast.Eq: operator.eq,
ast.NotEq: operator.ne,
ast.Lt: operator.lt,
ast.LtE: operator.le,
ast.Gt: operator.gt,
ast.GtE: operator.ge,
ast.Is: operator.is_,
ast.IsNot: operator.is_not,
ast.And: operator.and_,
ast.Or: operator.or_,
ast.Not: operator.not_,
}

SAFE_FUNCTIONS = {
'max': max,
'min': min,
'len': len,
'sum': sum,
'abs': abs,
'sorted': sorted,
}


def evaluate_expression(expression: str, variables: Dict[str, Any] = None, var_prefix: str = None) -> Any:
"""
Safely evaluates a Python expression string using a restricted set of AST nodes and operators.

Args:
expression (str): The expression to evaluate.
variables (Dict[str, Any], optional): A dictionary of variable names and values to use in evaluation.
var_prefix (str, optional): If specified, this prefix will be removed from variable names
in the expression before evaluation.

Returns:
any: The result of evaluating the expression.

Raises:
ValueError: If the expression contains unsupported operations or names.
"""
if variables is None:
variables = {}

def _eval(node):
match node:
case ast.Expression():
return _eval(node.body)

case ast.BinOp():
left = _eval(node.left)
right = _eval(node.right)
return OPERATORS[type(node.op)](left, right)

case ast.UnaryOp():
operand = _eval(node.operand)
return OPERATORS[type(node.op)](operand)

case ast.Subscript():
value = _eval(node.value)
match node.slice:
case ast.Slice():
start = _eval(node.slice.lower) if node.slice.lower else None
stop = _eval(node.slice.upper) if node.slice.upper else None
step = _eval(node.slice.step) if node.slice.step else None
return value[start:stop:step]
case _:
key = _eval(node.slice)
return value[key]

case ast.Compare():
left = _eval(node.left)
right = _eval(node.comparators[0])
return OPERATORS[type(node.ops[0])](left, right)

case ast.BoolOp():
values = [_eval(v) for v in node.values]
match node.op:
case ast.And():
return all(values)
case ast.Or():
return any(values)

case ast.IfExp():
test = _eval(node.test)
return _eval(node.body) if test else _eval(node.orelse)

case ast.Constant():
return node.value

case ast.Name():
var_name = node.id
if var_name in variables:
return variables[var_name]
elif var_name in {"True", "False"}:
return eval(var_name)
raise ValueError(f"Unsupported name: {var_name}")

case ast.Call():
func_name = node.func.id if isinstance(node.func, ast.Name) else None
if func_name in SAFE_FUNCTIONS:
func = SAFE_FUNCTIONS[func_name]
args = [_eval(arg) for arg in node.args]
return func(*args)
else:
raise ValueError(f"Function {func_name} is not allowed")

case ast.List():
return [_eval(elt) for elt in node.elts]

case ast.Dict():
return {_eval(k): _eval(v) for k, v in zip(node.keys, node.values)}

case ast.Attribute():
value = _eval(node.value)
return getattr(value, node.attr)

case _:
raise ValueError(f"Unsupported node type: {type(node)}")

if var_prefix:
var_prefix += '.'
expression = re.sub(rf'{re.escape(var_prefix)}(\w+)', r'\1', expression)

tree = ast.parse(expression, mode='eval')
return _eval(tree.body)
Loading
Loading