diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py index c6435c8a6f4b..2d3b8b49d8d7 100644 --- a/sdks/python/apache_beam/options/pipeline_options.py +++ b/sdks/python/apache_beam/options/pipeline_options.py @@ -1716,6 +1716,21 @@ def _add_argparse_args(cls, parser): help=( 'Docker registry url to use for tagging and pushing the prebuilt ' 'sdk worker container image.')) + parser.add_argument( + '--gbek', + default=None, + help=( + 'When set, will replace all GroupByKey transforms in the pipeline ' + 'with EncryptedGroupByKey transforms using the secret passed in ' + 'the option. Beam will infer the secret type and value based on ' + 'secret itself. This guarantees that any data at rest during the ' + 'GBK will be encrypted. Many runners only store data at rest when ' + 'performing a GBK, so this can be used to guarantee that data is ' + 'not unencrypted. Runners with this behavior include the ' + 'Dataflow, Flink, and Spark runners. The option should be ' + 'structured like: ' + '--gbek=type:<secret_type>;<secret_param>:<value>, for example ' + '--gbek=type:GcpSecret;version_name:my_secret/versions/latest')) parser.add_argument( '--user_agent', default=None, diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index cbd78d8222e8..db4a652cf97e 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -39,6 +39,7 @@ from apache_beam.coders import typecoders from apache_beam.internal import pickler from apache_beam.internal import util +from apache_beam.options.pipeline_options import SetupOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.portability import common_urns from apache_beam.portability import python_urns @@ -3324,6 +3325,10 @@ class GroupByKey(PTransform): The implementation here is used only when run on the local direct runner. 
""" + def __init__(self): + self._replaced_by_gbek = False + self._inside_gbek = False + class ReifyWindows(DoFn): def process( self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam): @@ -3354,6 +3359,16 @@ def get_windowing(self, inputs): environment_id=windowing.environment_id) def expand(self, pcoll): + replace_with_gbek_secret = ( + pcoll.pipeline._options.view_as(SetupOptions).gbek) + if replace_with_gbek_secret is not None and not self._inside_gbek: + self._replaced_by_gbek = True + from apache_beam.transforms.util import GroupByEncryptedKey + from apache_beam.transforms.util import Secret + + secret = Secret.parse_secret_option(replace_with_gbek_secret) + return (pcoll | "Group by encrypted key" >> GroupByEncryptedKey(secret)) + from apache_beam.transforms.trigger import DataLossReason from apache_beam.transforms.trigger import DefaultTrigger windowing = pcoll.windowing @@ -3400,7 +3415,11 @@ def infer_output_type(self, input_type): return typehints.KV[key_type, typehints.Iterable[value_type]] def to_runner_api_parameter(self, unused_context): - # type: (PipelineContext) -> typing.Tuple[str, None] + # type: (PipelineContext) -> tuple[str, typing.Optional[typing.Union[message.Message, bytes, str]]] + # if we're containing a GroupByEncryptedKey, don't allow runners to + # recognize this transform as a GBEK so that it doesn't get replaced. 
+ if self._replaced_by_gbek: + return super().to_runner_api_parameter(unused_context) return common_urns.primitives.GROUP_BY_KEY.urn, None @staticmethod diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index c63478dc0cfc..79421ff957b4 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -341,6 +341,44 @@ def generate_secret_bytes() -> bytes: """Generates a new secret key.""" return Fernet.generate_key() + @staticmethod + def parse_secret_option(secret) -> 'Secret': + """Parses a secret string and returns the appropriate secret type. + + The secret string should be formatted like: + 'type:<secret_type>;<secret_param>:<value>' + + For example, 'type:GcpSecret;version_name:my_secret/versions/latest' + would return a GcpSecret initialized with 'my_secret/versions/latest'. + """ + param_map = {} + for param in secret.split(';'): + parts = param.split(':') + param_map[parts[0]] = parts[1] + + if 'type' not in param_map: + raise ValueError('Secret string must contain a valid type parameter') + + secret_type = param_map['type'].lower() + del param_map['type'] + secret_class = None + secret_params = None + if secret_type == 'gcpsecret': + secret_class = GcpSecret + secret_params = ['version_name'] + else: + raise ValueError( + f'Invalid secret type {secret_type}, currently only ' + 'GcpSecret is supported') + + for param_name in param_map.keys(): + if param_name not in secret_params: + raise ValueError( + f'Invalid secret parameter {param_name}, ' + f'{secret_type} only supports the following ' + f'parameters: {secret_params}') + return secret_class(**param_map) + class GcpSecret(Secret): """A secret manager implementation that retrieves secrets from Google Cloud @@ -367,7 +405,12 @@ def get_secret_bytes(self) -> bytes: secret = response.payload.data return secret except Exception as e: - raise RuntimeError(f'Failed to retrieve secret bytes with excetion {e}') + raise RuntimeError( + 'Failed to 
retrieve secret bytes for secret ' + f'{self._version_name} with exception {e}') + + def __eq__(self, secret): + return self._version_name == getattr(secret, '_version_name', None) class _EncryptMessage(DoFn): @@ -499,7 +542,9 @@ def __init__(self, hmac_key: Secret): self._hmac_key = hmac_key def expand(self, pcoll): - kv_type_hint = pcoll.element_type + key_type, value_type = (typehints.typehints.coerce_to_kv_type( + pcoll.element_type).tuple_types) + kv_type_hint = typehints.KV[key_type, value_type] if kv_type_hint and kv_type_hint != typehints.Any: coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder( f'GroupByEncryptedKey {self.label}' @@ -518,10 +563,13 @@ def expand(self, pcoll): key_coder = coders.registry.get_coder(typehints.Any) value_coder = key_coder + gbk = beam.GroupByKey() + gbk._inside_gbek = True + return ( pcoll | beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder)) - | beam.GroupByKey() + | gbk | beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder))) diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 6cd8d5fcba76..d892534b69af 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -50,6 +50,7 @@ from apache_beam.coders import coders from apache_beam.metrics import MetricsFilter from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.portability import common_urns @@ -252,7 +253,7 @@ def test_co_group_by_key_on_unpickled(self): class FakeSecret(beam.Secret): - def __init__(self, should_throw=False): + def __init__(self, version_name=None, should_throw=False): self._secret = b'aKwI2PmqYFt2p5tNKCyBS5qYmHhHsGZcyZrnZQiQ-uE=' self._should_throw = should_throw @@ -273,6 
+274,12 @@ def __init__(self, hmac_key_secret, key_coder, value_coder): super().__init__(hmac_key_secret, key_coder, value_coder) def process(self, element): + final_elements = list(super().process(element)) + # Check if we're looking at the actual elements being encoded/decoded + # There is also a gbk on assertEqual, which uses None as the key type. + final_element_keys = [e for e in final_elements if e[0] in ['a', 'b', 'c']] + if len(final_element_keys) == 0: + return final_elements hmac_key, actual_elements = element if hmac_key not in self.known_hmacs: raise ValueError(f'GBK produced unencrypted value {hmac_key}') @@ -286,7 +293,38 @@ def process(self, element): except InvalidToken: raise ValueError(f'GBK produced unencrypted value {e[1]}') - return super().process(element) + return final_elements + + +class SecretTest(unittest.TestCase): + @parameterized.expand([ + param( + secret_string='type:GcpSecret;version_name:my_secret/versions/latest', + secret=GcpSecret('my_secret/versions/latest')), + param( + secret_string='type:GcpSecret;version_name:foo', + secret=GcpSecret('foo')), + param( + secret_string='type:gcpsecreT;version_name:my_secret/versions/latest', + secret=GcpSecret('my_secret/versions/latest')), + ]) + def test_secret_manager_parses_correctly(self, secret_string, secret): + self.assertEqual(secret, Secret.parse_secret_option(secret_string)) + + @parameterized.expand([ + param( + secret_string='version_name:foo', + exception_str='must contain a valid type parameter'), + param( + secret_string='type:gcpsecreT', + exception_str='missing 1 required positional argument'), + param( + secret_string='type:gcpsecreT;version_name:foo;extra:val', + exception_str='Invalid secret parameter extra'), + ]) + def test_secret_manager_throws_on_invalid(self, secret_string, exception_str): + with self.assertRaisesRegex(Exception, exception_str): + Secret.parse_secret_option(secret_string) class GroupByEncryptedKeyTest(unittest.TestCase): @@ -318,7 +356,9 @@ def 
setUp(self): 'data': Secret.generate_secret_bytes() } }) - self.gcp_secret = GcpSecret(f'{self.secret_path}/versions/latest') + version_name = f'{self.secret_path}/versions/latest' + self.gcp_secret = GcpSecret(version_name) + self.secret_option = f'type:GcpSecret;version_name:{version_name}' def tearDown(self): if secretmanager is not None: @@ -334,6 +374,20 @@ def test_gbek_fake_secret_manager_roundtrips(self): assert_that( result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + @unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed') + def test_gbk_with_gbek_option_fake_secret_manager_roundtrips(self): + options = PipelineOptions() + options.view_as(SetupOptions).gbek = self.secret_option + + with beam.Pipeline(options=options) as pipeline: + pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), + ('b', 3), ('c', 4)]) + result = (pcoll_1) | beam.GroupByKey() + sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1]))) + assert_that( + sorted_result, + equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt) def test_gbek_fake_secret_manager_actually_does_encryption(self): fakeSecret = FakeSecret() @@ -345,8 +399,23 @@ def test_gbek_fake_secret_manager_actually_does_encryption(self): assert_that( result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + @mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt) + @mock.patch('apache_beam.transforms.util.GcpSecret', FakeSecret) + def test_gbk_actually_does_encryption(self): + options = PipelineOptions() + # Version of GcpSecret doesn't matter since it is replaced by FakeSecret + options.view_as(SetupOptions).gbek = 'type:GcpSecret;version_name:Foo' + + with TestPipeline('FnApiRunner', options=options) as pipeline: + pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2), + ('b', 3), ('c', 4)], + reshuffle=False) + result = pcoll_1 | beam.GroupByKey() 
+ assert_that( + result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))])) + def test_gbek_fake_secret_manager_throws(self): - fakeSecret = FakeSecret(True) + fakeSecret = FakeSecret(None, True) with self.assertRaisesRegex(RuntimeError, r'Exception retrieving secret'): with TestPipeline() as pipeline: