Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ now be installed separately through `watergrid[...]` metapackages. (#54)

### Fixed

- Resolved issue with objects inside a `DataContext` not being copied when the
output mode of a step is set to `SPLIT`. (#6)

### Security

## [1.0.1] - 2022-04-01
Expand Down
60 changes: 60 additions & 0 deletions test/bug_6_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import unittest

from watergrid.context import DataContext, OutputMode
from watergrid.pipelines.pipeline import Pipeline
from watergrid.steps import Step


class TestDTO:
def __init__(self, value):
self.value = value

def set_value(self, value):
self.value = value

def get_value(self):
return self.value


class TestCreateDTOStep(Step):
def run(self, context: DataContext):
context.set("dto", TestDTO(1))
context.set("val", [1, 2])
context.set_output_mode(OutputMode.SPLIT)

def __init__(self):
super().__init__("test_create_dto_step", provides=["val", "dto"])


class TestModifyDTOStep(Step):
def __init__(self):
super().__init__("test_modify_dto_step", requires=["val"])

def run(self, context: DataContext):
if context.get("val") == 1:
context.get("dto").set_value(5)


class TestVerifyCopySafetyStep(Step):
def __init__(self):
super().__init__("test_verify_copy_safety_step", requires=["val"])
self.mod_1_flag = -1
self.keep_2_flag = -1

def run(self, context: DataContext):
if context.get("val") == 1:
self.mod_1_flag = context.get("dto").get_value()
elif context.get("val") == 2:
self.keep_2_flag = context.get("dto").get_value()


class Bug6TestCase(unittest.TestCase):
def test_bug6(self):
pipeline = Pipeline("test_pipeline")
pipeline.add_step(TestCreateDTOStep())
pipeline.add_step(TestModifyDTOStep())
step3 = TestVerifyCopySafetyStep()
pipeline.add_step(step3)
pipeline.run()
self.assertEqual(5, step3.mod_1_flag)
self.assertEqual(1, step3.keep_2_flag)
10 changes: 7 additions & 3 deletions watergrid/pipelines/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import time
from abc import ABC

Expand Down Expand Up @@ -184,7 +185,8 @@ def __split_context(
split_key = step_provides[0]
split_value = context.get(split_key)
for value in split_value:
new_context = DataContext.deep_copy_context(context)
new_context = copy.deepcopy(context)
new_context.set_output_mode(OutputMode.DIRECT)
new_context.set(split_key, value)
next_contexts.append(new_context)

Expand All @@ -199,7 +201,9 @@ def __filter_context(
:return: None
"""
if context.get(step_provides[0]) is not None:
next_contexts.append(DataContext.deep_copy_context(context))
new_context = copy.deepcopy(context)
new_context.set_output_mode(OutputMode.DIRECT)
next_contexts.append(new_context)

def __forward_context(self, context: DataContext, next_contexts: list):
"""
Expand All @@ -208,4 +212,4 @@ def __forward_context(self, context: DataContext, next_contexts: list):
:param next_contexts: List of contexts that will be used by the next step.
:return: None
"""
next_contexts.append(DataContext.deep_copy_context(context))
next_contexts.append(copy.deepcopy(context))