diff --git a/.gitignore b/.gitignore index 559fd1bd..9b51c176 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,7 @@ test-venv __pycache__ # egg-info -sdp.egg-info \ No newline at end of file +sdp.egg-info + +# build +build \ No newline at end of file diff --git a/docs/src/conf.py b/docs/src/conf.py index 975a5dbf..11964c06 100644 --- a/docs/src/conf.py +++ b/docs/src/conf.py @@ -184,12 +184,14 @@ def setup(app): nitpick_ignore = [ ('py:class', 'abc.ABC'), - ('py:class', 'sdp.processors.base_processor.DataEntry'), + ('py:class', 'optional'), + ('py:mod', 'sdp.utils.apply_operators'), ] # nitpick_ignore_regex = [('py:class', '*')] #adding this especially for coraal, temporary linkcheck_ignore = [ r'https://lingtools\.uoregon\.edu/coraal/coraal_download_list\.txt', + r'https://ieeexplore\.ieee\.org/document/1326009' ] # https://lingtools.uoregon.edu/coraal/coraal_download_list.txt \ No newline at end of file diff --git a/docs/src/sdp/api.rst b/docs/src/sdp/api.rst index dcdd13bc..5ee21e2a 100644 --- a/docs/src/sdp/api.rst +++ b/docs/src/sdp/api.rst @@ -219,6 +219,9 @@ Data modifications .. autodata:: sdp.processors.InverseNormalizeText :annotation: +.. autodata:: sdp.processors.LambdaExpression + :annotation: + Data filtering '''''''''''''' diff --git a/sdp/processors/__init__.py b/sdp/processors/__init__.py index c3ff70b6..9e88b4f6 100644 --- a/sdp/processors/__init__.py +++ b/sdp/processors/__init__.py @@ -106,6 +106,7 @@ SubIfASRSubstitution, SubMakeLowercase, SubRegex, + LambdaExpression, ) from sdp.processors.modify_manifest.data_to_dropbool import ( DropASRError, diff --git a/sdp/processors/datasets/voxpopuli/normalize_from_non_pc_text.py b/sdp/processors/datasets/voxpopuli/normalize_from_non_pc_text.py index ebe86083..02ade228 100644 --- a/sdp/processors/datasets/voxpopuli/normalize_from_non_pc_text.py +++ b/sdp/processors/datasets/voxpopuli/normalize_from_non_pc_text.py @@ -77,7 +77,7 @@ def restore_pc(orig_words, norm_words): # separately in normalized form, so just removing the comma here add_punct = "" if orig_text[idx_orig][0].isdigit() and not orig_text[idx_orig].isdigit(): - number, word = re.split('(\d+)', orig_text[idx_orig])[1:] + number, word = re.split(r'(\d+)', orig_text[idx_orig])[1:] orig_text[idx_orig] = number if word in string.punctuation: add_punct = word @@ -87,7 +87,7 @@ def restore_pc(orig_words, norm_words): # another annoying case is if typo ends with number like here "dell'11" # same logic, but need to go back to the first check, so doing "continue" below if orig_text[idx_orig][-1].isdigit() and not orig_text[idx_orig].isdigit(): - word, number = re.split('(\d+)', orig_text[idx_orig])[:-1] + word, number = re.split(r'(\d+)', orig_text[idx_orig])[:-1] orig_text[idx_orig] = word orig_text.insert(idx_orig + 1, number) continue diff --git a/sdp/processors/modify_manifest/data_to_data.py b/sdp/processors/modify_manifest/data_to_data.py index 16e1de6d..9f3d496c 100644 --- a/sdp/processors/modify_manifest/data_to_data.py +++ b/sdp/processors/modify_manifest/data_to_data.py @@ -33,13 +33,8 @@ from sdp.utils.common import ffmpeg_convert from sdp.utils.edit_spaces import add_start_end_spaces, remove_extra_spaces from sdp.utils.get_diff import get_diff_with_subs_grouped -from sdp.utils.metrics_computation import ( - get_cer, - get_charrate, - get_wer, - get_wmr, - get_wordrate, -) +from sdp.utils.metrics_computation import get_wer +from sdp.utils.apply_operators import evaluate_expression class GetAudioDuration(BaseParallelProcessor): @@ -1127,3 +1122,99 @@ def 
process(self):
         if self.failed_files:
             logger.warning(f"Failed to process {len(self.failed_files)} files.")
             logger.debug(f"Failed files: {self.failed_files}")
+
+
+class LambdaExpression(BaseParallelProcessor):
+    """
+    A dataset processor that evaluates a Python expression on each data entry and either stores
+    the result in a new field or uses it as a filtering condition.
+
+    This processor is useful for dynamic field computation or conditional filtering of entries based
+    on configurable expressions. It relies on ``evaluate_expression``, which safely evaluates expressions
+    using the abstract syntax tree (AST).
+
+    Filtering behavior:
+        If ``filter=True``, the expression is evaluated for each entry and only entries for which it
+        evaluates to ``True`` are kept; all others are removed from the output.
+        If ``filter=False``, the result of the expression is stored in the field specified by
+        ``new_field`` for each entry (no filtering occurs).
+
+    Examples::
+
+        # Example 1: keep only entries whose duration is greater than 5.0 seconds
+        LambdaExpression(
+            new_field="keep",  # entries that pass the filter also get keep=True
+            expression="entry.duration > 5.0",
+            lambda_param_name="entry",
+            filter=True
+        )
+        # Only entries with duration > 5.0 are kept in the output manifest.
+
+        # Example 2: add a new field with the number of characters in the text
+        LambdaExpression(
+            new_field="num_chars",
+            expression="len(entry.text)",
+            lambda_param_name="entry",
+            filter=False
+        )
+        # Each entry gets a new field 'num_chars' with the length of its 'text' field.
+
+    Supported operations:
+
+    The expression supports a safe subset of Python operations, including:
+
+    - Arithmetic: ``+``, ``-``, ``*``, ``/``, ``//``, ``%``, ``**``
+    - Comparisons: ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, ``is``, ``is not``
+    - Logical: ``and``, ``or``, ``not``
+    - Bitwise: ``|``, ``&``, ``^``, ``~``, ``<<``, ``>>``
+    - Indexing and slicing: ``entry.text[0]``, ``entry.text[1:3]``
+    - Conditional (ternary) expressions: ``a if cond else b``
+    - List and dict literals: ``[a, b]``, ``{k: v}``
+    - Attribute access: ``entry.attr``
+    - Function calls (limited): ``max``, ``min``, ``len``, ``sum``, ``abs``, ``sorted``
+
+    For the full list, see ``OPERATORS`` and ``SAFE_FUNCTIONS`` in :mod:`sdp.utils.apply_operators`.
+    See also: https://docs.python.org/3/library/operator.html
+
+    Args:
+        new_field (str): The name of the field to store the result of the expression. When
+            ``filter=True``, entries that pass the filter still receive this field (set to ``True``).
+        expression (str): A Python expression to evaluate. Fields of the data entry are referenced
+            in attribute style through the name specified by ``lambda_param_name``, e.g. ``entry.duration``.
+        lambda_param_name (str, optional): The name used to refer to the current data entry in the
+            expression. Default is "entry".
+        filter (bool, optional): If True, the expression result is treated as a condition and the
+            entry is kept only if it evaluates to ``True``. Default is ``False``.
+        **kwargs: Additional keyword arguments passed to the ``BaseParallelProcessor`` class.
+
+    Returns:
+        str: A line-delimited JSON manifest, where each line is a processed entry.
+        The result may contain fewer entries than the input if ``filter=True``.
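+
+    Illustrative manifest effect (a sketch reusing Example 2 above; the manifest line itself is
+    hypothetical)::
+
+        # input manifest line:  {"text": "hello world"}
+        # output manifest line: {"text": "hello world", "num_chars": 11}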
+ """ + def __init__( + self, + new_field: str, + expression: str, + lambda_param_name: str = "entry", + filter: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.new_field = new_field + self.expression = expression + self.lambda_param_name = lambda_param_name + self.filter = filter + + def process_dataset_entry(self, data_entry) -> List[DataEntry]: + """ + Process a single data entry by evaluating the expression. + + If `filter` is True, the entry is only retained if the expression evaluates to True. + Otherwise, the result is stored in `new_field`. + """ + value = evaluate_expression(self.expression, data_entry, self.lambda_param_name) + if self.filter: + if value is not True: + return [] + data_entry[self.new_field] = value + return [DataEntry(data=data_entry)] + + def finalize(self, metrics): + super().finalize(metrics) \ No newline at end of file diff --git a/sdp/utils/apply_operators.py b/sdp/utils/apply_operators.py new file mode 100644 index 00000000..80db65b8 --- /dev/null +++ b/sdp/utils/apply_operators.py @@ -0,0 +1,171 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import operator +import ast +import re +from typing import Any, Dict + +""" +This module provides a safe evaluator for simple Python expressions using the abstract syntax tree (AST). +It restricts execution to a subset of safe operations (arithmetic, logical, comparisons, indexing, etc.) +and selected built-in functions (e.g., max, min, len), while preventing arbitrary code execution. + +Useful in cases where dynamic expressions need to be evaluated using a provided variable context, +such as configuration systems, data transformation pipelines, or manifest filtering. + +Functions: + - evaluate_expression: Safely evaluates a Python expression string using restricted AST operations. +""" + +OPERATORS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.BitOr: operator.or_, + ast.BitAnd: operator.and_, + ast.BitXor: operator.xor, + ast.LShift: operator.lshift, + ast.RShift: operator.rshift, + ast.Invert: operator.invert, + ast.USub: operator.neg, + ast.UAdd: operator.pos, + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, + ast.Is: operator.is_, + ast.IsNot: operator.is_not, + ast.And: operator.and_, + ast.Or: operator.or_, + ast.Not: operator.not_, +} + +SAFE_FUNCTIONS = { + 'max': max, + 'min': min, + 'len': len, + 'sum': sum, + 'abs': abs, + 'sorted': sorted, +} + + +def evaluate_expression(expression: str, variables: Dict[str, Any] = None, var_prefix: str = None) -> Any: + """ + Safely evaluates a Python expression string using a restricted set of AST nodes and operators. + + Args: + expression (str): The expression to evaluate. 
+ variables (Dict[str, Any], optional): A dictionary of variable names and values to use in evaluation. + var_prefix (str, optional): If specified, this prefix will be removed from variable names + in the expression before evaluation. + + Returns: + any: The result of evaluating the expression. + + Raises: + ValueError: If the expression contains unsupported operations or names. + """ + if variables is None: + variables = {} + + def _eval(node): + match node: + case ast.Expression(): + return _eval(node.body) + + case ast.BinOp(): + left = _eval(node.left) + right = _eval(node.right) + return OPERATORS[type(node.op)](left, right) + + case ast.UnaryOp(): + operand = _eval(node.operand) + return OPERATORS[type(node.op)](operand) + + case ast.Subscript(): + value = _eval(node.value) + match node.slice: + case ast.Slice(): + start = _eval(node.slice.lower) if node.slice.lower else None + stop = _eval(node.slice.upper) if node.slice.upper else None + step = _eval(node.slice.step) if node.slice.step else None + return value[start:stop:step] + case _: + key = _eval(node.slice) + return value[key] + + case ast.Compare(): + left = _eval(node.left) + right = _eval(node.comparators[0]) + return OPERATORS[type(node.ops[0])](left, right) + + case ast.BoolOp(): + values = [_eval(v) for v in node.values] + match node.op: + case ast.And(): + return all(values) + case ast.Or(): + return any(values) + + case ast.IfExp(): + test = _eval(node.test) + return _eval(node.body) if test else _eval(node.orelse) + + case ast.Constant(): + return node.value + + case ast.Name(): + var_name = node.id + if var_name in variables: + return variables[var_name] + elif var_name in {"True", "False"}: + return eval(var_name) + raise ValueError(f"Unsupported name: {var_name}") + + case ast.Call(): + func_name = node.func.id if isinstance(node.func, ast.Name) else None + if func_name in SAFE_FUNCTIONS: + func = SAFE_FUNCTIONS[func_name] + args = [_eval(arg) for arg in node.args] + return func(*args) + else: + raise ValueError(f"Function {func_name} is not allowed") + + case ast.List(): + return [_eval(elt) for elt in node.elts] + + case ast.Dict(): + return {_eval(k): _eval(v) for k, v in zip(node.keys, node.values)} + + case ast.Attribute(): + value = _eval(node.value) + return getattr(value, node.attr) + + case _: + raise ValueError(f"Unsupported node type: {type(node)}") + + if var_prefix: + var_prefix += '.' 
+ expression = re.sub(rf'{re.escape(var_prefix)}(\w+)', r'\1', expression) + + tree = ast.parse(expression, mode='eval') + return _eval(tree.body) \ No newline at end of file diff --git a/tests/test_data_to_data.py b/tests/test_data_to_data.py index 5bd75f47..9f62f730 100644 --- a/tests/test_data_to_data.py +++ b/tests/test_data_to_data.py @@ -19,6 +19,7 @@ SubIfASRSubstitution, SubMakeLowercase, SubRegex, + LambdaExpression, ) test_params_list = [] @@ -29,13 +30,13 @@ InsIfASRInsertion, {"insert_words": [" nemo", "nemo ", " nemo "]}, {"text": "i love the toolkit", "pred_text": "i love the nemo toolkit"}, - {"text": "i love the nemo toolkit", "pred_text": "i love the nemo toolkit"}, + [{"text": "i love the nemo toolkit", "pred_text": "i love the nemo toolkit"}], ), ( InsIfASRInsertion, {"insert_words": [" nemo", "nemo ", " nemo "]}, {"text": "i love the toolkit", "pred_text": "i love the new nemo toolkit"}, - {"text": "i love the toolkit", "pred_text": "i love the new nemo toolkit"}, + [{"text": "i love the toolkit", "pred_text": "i love the new nemo toolkit"}], ), ] ) @@ -46,7 +47,7 @@ SubIfASRSubstitution, {"sub_words": {"nmo ": "nemo "}}, {"text": "i love the nmo toolkit", "pred_text": "i love the nemo toolkit"}, - {"text": "i love the nemo toolkit", "pred_text": "i love the nemo toolkit"}, + [{"text": "i love the nemo toolkit", "pred_text": "i love the nemo toolkit"}], ), ] ) @@ -57,7 +58,7 @@ SubIfASRSubstitution, {"sub_words": {"nmo ": "nemo "}}, {"text": "i love the nmo toolkit", "pred_text": "i love the nemo toolkit"}, - {"text": "i love the nemo toolkit", "pred_text": "i love the nemo toolkit"}, + [{"text": "i love the nemo toolkit", "pred_text": "i love the nemo toolkit"}], ), ] ) @@ -68,13 +69,13 @@ SubMakeLowercase, {}, {"text": "Hello Привет 123"}, - {"text": "hello привет 123"}, + [{"text": "hello привет 123"}], ), ( SubMakeLowercase, {"text_key": "text_new"}, {"text_new": "Hello Привет 123"}, - {"text_new": "hello привет 123"}, + [{"text_new": "hello привет 123"}], ), ] ) @@ -83,9 +84,96 @@ [ ( SubRegex, - {"regex_params_list": [{"pattern": "\s<.*>\s", "repl": " "}]}, + {"regex_params_list": [{"pattern": r"\s<.*>\s", "repl": " "}]}, {"text": "hello world"}, + [{"text": "hello world"}], + ), + ] +) + + +test_params_list.extend( + [ + # Simple arithmetic expression + ( + LambdaExpression, + {"new_field": "duration_x2", "expression": "entry.duration * 2"}, + {"duration": 3.5}, + [{"duration": 3.5, "duration_x2": 7.0}], + ), + + # Ternary expression + ( + LambdaExpression, + {"new_field": "label", "expression": "'long' if entry.duration > 10 else 'short'"}, + {"duration": 12.0}, + [{"duration": 12.0, "label": "long"}], + ), + + # Filtering: entry should be dropped (condition is False) + ( + LambdaExpression, + {"new_field": "valid", "expression": "entry.duration > 10", "filter": True}, + {"duration": 5.0}, + [], + ), + + # Filtering: entry should be kept (condition is True) + ( + LambdaExpression, + {"new_field": "valid", "expression": "entry.duration > 10", "filter": True}, + {"duration": 12.0}, + [{"duration": 12.0, "valid": True}], + ), + + # Using built-in function len() + ( + LambdaExpression, + {"new_field": "num_chars", "expression": "len(entry.text)"}, {"text": "hello world"}, + [{"text": "hello world", "num_chars": 11}], + ), + + # Using built-in max() with sub-expressions + ( + LambdaExpression, + {"new_field": "score", "expression": "max(entry.a, entry.b * 2)"}, + {"a": 4, "b": 3}, + [{"a": 4, "b": 3, "score": 6}], + ), + + # Expression using variable prefix (e.g., 
entry.a + entry.b) + ( + LambdaExpression, + { + "new_field": "sum", + "expression": "entry.a + entry.b", + "lambda_param_name": "entry", + }, + {"a": 1, "b": 2}, + [{"a": 1, "b": 2, "sum": 3}], + ), + + # Logical expression using `and` + ( + LambdaExpression, + { + "new_field": "check", + "expression": "entry.a > 0 and entry.b < 5", + }, + {"a": 1, "b": 4}, + [{"a": 1, "b": 4, "check": True}], + ), + + # Boolean expression without filtering (entry is always returned) + ( + LambdaExpression, + { + "new_field": "is_zero", + "expression": "entry.value == 0", + }, + {"value": 5}, + [{"value": 5, "is_zero": False}], ), ] ) @@ -94,7 +182,6 @@ @pytest.mark.parametrize("test_class,class_kwargs,test_input,expected_output", test_params_list, ids=str) def test_data_to_data(test_class, class_kwargs, test_input, expected_output): processor = test_class(**class_kwargs, output_manifest_file=None) + result = [entry.data for entry in processor.process_dataset_entry(test_input)] - output = processor.process_dataset_entry(test_input)[0].data - - assert output == expected_output + assert result == expected_output \ No newline at end of file
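
Usage sketch (illustrative, not part of the diff): the new evaluate_expression helper can also be
called directly. The entry dict and expressions below are hypothetical examples that mirror the
test cases above.

    from sdp.utils.apply_operators import evaluate_expression

    entry = {"text": "hello world", "duration": 12.0}

    # Without a prefix, manifest fields act as plain variable names.
    assert evaluate_expression("len(text)", entry) == 11

    # With var_prefix, "entry.<field>" is rewritten to "<field>" before evaluation;
    # this is how LambdaExpression invokes it (var_prefix=lambda_param_name).
    assert evaluate_expression("'long' if entry.duration > 10 else 'short'", entry, var_prefix="entry") == "long"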