diff --git a/CHANGELOG.md b/CHANGELOG.md index 942d88f..c02f704 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,9 @@ now be installed separately through `watergrid[...]` metapackages. (#54) ### Fixed +- Resolved issue with objects inside a `DataContext` not being copied when the + output mode of a step is set to `SPLIT`. (#6) + ### Security ## [1.0.1] - 2022-04-01 diff --git a/test/bug_6_tests.py b/test/bug_6_tests.py new file mode 100644 index 0000000..8a4bdc4 --- /dev/null +++ b/test/bug_6_tests.py @@ -0,0 +1,60 @@ +import unittest + +from watergrid.context import DataContext, OutputMode +from watergrid.pipelines.pipeline import Pipeline +from watergrid.steps import Step + + +class TestDTO: + def __init__(self, value): + self.value = value + + def set_value(self, value): + self.value = value + + def get_value(self): + return self.value + + +class TestCreateDTOStep(Step): + def run(self, context: DataContext): + context.set("dto", TestDTO(1)) + context.set("val", [1, 2]) + context.set_output_mode(OutputMode.SPLIT) + + def __init__(self): + super().__init__("test_create_dto_step", provides=["val", "dto"]) + + +class TestModifyDTOStep(Step): + def __init__(self): + super().__init__("test_modify_dto_step", requires=["val"]) + + def run(self, context: DataContext): + if context.get("val") == 1: + context.get("dto").set_value(5) + + +class TestVerifyCopySafetyStep(Step): + def __init__(self): + super().__init__("test_verify_copy_safety_step", requires=["val"]) + self.mod_1_flag = -1 + self.keep_2_flag = -1 + + def run(self, context: DataContext): + if context.get("val") == 1: + self.mod_1_flag = context.get("dto").get_value() + elif context.get("val") == 2: + self.keep_2_flag = context.get("dto").get_value() + + +class Bug6TestCase(unittest.TestCase): + def test_bug6(self): + pipeline = Pipeline("test_pipeline") + pipeline.add_step(TestCreateDTOStep()) + pipeline.add_step(TestModifyDTOStep()) + step3 = TestVerifyCopySafetyStep() + pipeline.add_step(step3) + pipeline.run() + self.assertEqual(5, step3.mod_1_flag) + self.assertEqual(1, step3.keep_2_flag) diff --git a/watergrid/pipelines/pipeline.py b/watergrid/pipelines/pipeline.py index 6b7c958..ada9b72 100644 --- a/watergrid/pipelines/pipeline.py +++ b/watergrid/pipelines/pipeline.py @@ -1,3 +1,4 @@ +import copy import time from abc import ABC @@ -184,7 +185,8 @@ def __split_context( split_key = step_provides[0] split_value = context.get(split_key) for value in split_value: - new_context = DataContext.deep_copy_context(context) + new_context = copy.deepcopy(context) + new_context.set_output_mode(OutputMode.DIRECT) new_context.set(split_key, value) next_contexts.append(new_context) @@ -199,7 +201,9 @@ def __filter_context( :return: None """ if context.get(step_provides[0]) is not None: - next_contexts.append(DataContext.deep_copy_context(context)) + new_context = copy.deepcopy(context) + new_context.set_output_mode(OutputMode.DIRECT) + next_contexts.append(new_context) def __forward_context(self, context: DataContext, next_contexts: list): """ @@ -208,4 +212,4 @@ def __forward_context(self, context: DataContext, next_contexts: list): :param next_contexts: List of contexts that will be used by the next step. :return: None """ - next_contexts.append(DataContext.deep_copy_context(context)) + next_contexts.append(copy.deepcopy(context))