Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions sdks/python/apache_beam/options/pipeline_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -1716,6 +1716,21 @@ def _add_argparse_args(cls, parser):
help=(
'Docker registry url to use for tagging and pushing the prebuilt '
'sdk worker container image.'))
parser.add_argument(
'--gbek',
default=None,
help=(
'When set, will replace all GroupByKey transforms in the pipeline '
'with EncryptedGroupByKey transforms using the secret passed in '
'the option. Beam will infer the secret type and value based on '
'secret itself. This guarantees that any data at rest during the '
'GBK will be encrypted. Many runners only store data at rest when '
'performing a GBK, so this can be used to guarantee that data is '
'not unencrypted. Runners with this behavior include the '
'Dataflow, Flink, and Spark runners. The option should be '
'structured like: '
'--gbek=type:<secret_type>;<secret_param>:<value>, for example '
'--gbek=type:GcpSecret;version_name:my_secret/versions/latest'))
parser.add_argument(
'--user_agent',
default=None,
Expand Down
21 changes: 20 additions & 1 deletion sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from apache_beam.coders import typecoders
from apache_beam.internal import pickler
from apache_beam.internal import util
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.portability import python_urns
Expand Down Expand Up @@ -3324,6 +3325,10 @@ class GroupByKey(PTransform):

The implementation here is used only when run on the local direct runner.
"""
def __init__(self):
self._replaced_by_gbek = False
self._inside_gbek = False

class ReifyWindows(DoFn):
def process(
self, element, window=DoFn.WindowParam, timestamp=DoFn.TimestampParam):
Expand Down Expand Up @@ -3354,6 +3359,16 @@ def get_windowing(self, inputs):
environment_id=windowing.environment_id)

def expand(self, pcoll):
replace_with_gbek_secret = (
pcoll.pipeline._options.view_as(SetupOptions).gbek)
if replace_with_gbek_secret is not None and not self._inside_gbek:
self._replaced_by_gbek = True
from apache_beam.transforms.util import GroupByEncryptedKey
from apache_beam.transforms.util import Secret

secret = Secret.parse_secret_option(replace_with_gbek_secret)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we somehow retry get_secret_bytes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This gets called during setup, so it should get automatic retries from the runner -

self._hmac_key = self.hmac_key_secret.get_secret_bytes()

return (pcoll | "Group by encrypted key" >> GroupByEncryptedKey(secret))

from apache_beam.transforms.trigger import DataLossReason
from apache_beam.transforms.trigger import DefaultTrigger
windowing = pcoll.windowing
Expand Down Expand Up @@ -3400,7 +3415,11 @@ def infer_output_type(self, input_type):
return typehints.KV[key_type, typehints.Iterable[value_type]]

def to_runner_api_parameter(self, unused_context):
# type: (PipelineContext) -> typing.Tuple[str, None]
# type: (PipelineContext) -> tuple[str, typing.Optional[typing.Union[message.Message, bytes, str]]]
# if we're containing a GroupByEncryptedKey, don't allow runners to
# recognize this transform as a GBEK so that it doesn't get replaced.
if self._replaced_by_gbek:
return super().to_runner_api_parameter(unused_context)
return common_urns.primitives.GROUP_BY_KEY.urn, None

@staticmethod
Expand Down
54 changes: 51 additions & 3 deletions sdks/python/apache_beam/transforms/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,44 @@ def generate_secret_bytes() -> bytes:
"""Generates a new secret key."""
return Fernet.generate_key()

@staticmethod
def parse_secret_option(secret) -> 'Secret':
"""Parses a secret string and returns the appropriate secret type.

The secret string should be formatted like:
'type:<secret_type>;<secret_param>:<value>'

For example, 'type:GcpSecret;version_name:my_secret/versions/latest'
would return a GcpSecret initialized with 'my_secret/versions/latest'.
"""
param_map = {}
for param in secret.split(';'):
parts = param.split(':')
param_map[parts[0]] = parts[1]

if 'type' not in param_map:
raise ValueError('Secret string must contain a valid type parameter')

secret_type = param_map['type'].lower()
del param_map['type']
secret_class = None
secret_params = None
if secret_type == 'gcpsecret':
secret_class = GcpSecret
secret_params = ['version_name']
else:
raise ValueError(
f'Invalid secret type {secret_type}, currently only '
'GcpSecret is supported')

for param_name in param_map.keys():
if param_name not in secret_params:
raise ValueError(
f'Invalid secret parameter {param_name}, '
f'{secret_type} only supports the following '
f'parameters: {secret_params}')
return secret_class(**param_map)


class GcpSecret(Secret):
"""A secret manager implementation that retrieves secrets from Google Cloud
Expand All @@ -367,7 +405,12 @@ def get_secret_bytes(self) -> bytes:
secret = response.payload.data
return secret
except Exception as e:
raise RuntimeError(f'Failed to retrieve secret bytes with excetion {e}')
raise RuntimeError(
'Failed to retrieve secret bytes for secret '
f'{self._version_name} with exception {e}')

def __eq__(self, secret):
return self._version_name == getattr(secret, '_version_name', None)


class _EncryptMessage(DoFn):
Expand Down Expand Up @@ -499,7 +542,9 @@ def __init__(self, hmac_key: Secret):
self._hmac_key = hmac_key

def expand(self, pcoll):
kv_type_hint = pcoll.element_type
key_type, value_type = (typehints.typehints.coerce_to_kv_type(
pcoll.element_type).tuple_types)
kv_type_hint = typehints.KV[key_type, value_type]
if kv_type_hint and kv_type_hint != typehints.Any:
coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder(
f'GroupByEncryptedKey {self.label}'
Expand All @@ -518,10 +563,13 @@ def expand(self, pcoll):
key_coder = coders.registry.get_coder(typehints.Any)
value_coder = key_coder

gbk = beam.GroupByKey()
gbk._inside_gbek = True

return (
pcoll
| beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
| beam.GroupByKey()
| gbk
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this breaks the update compatibility given the line number changes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true - that's why I think we should do #36251 (comment) to permanently avoid this problem.

I can follow up with a PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#36381 will fix this broadly.

| beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))


Expand Down
77 changes: 73 additions & 4 deletions sdks/python/apache_beam/transforms/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from apache_beam.coders import coders
from apache_beam.metrics import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
Expand Down Expand Up @@ -252,7 +253,7 @@ def test_co_group_by_key_on_unpickled(self):


class FakeSecret(beam.Secret):
def __init__(self, should_throw=False):
def __init__(self, version_name=None, should_throw=False):
self._secret = b'aKwI2PmqYFt2p5tNKCyBS5qYmHhHsGZcyZrnZQiQ-uE='
self._should_throw = should_throw

Expand All @@ -273,6 +274,12 @@ def __init__(self, hmac_key_secret, key_coder, value_coder):
super().__init__(hmac_key_secret, key_coder, value_coder)

def process(self, element):
final_elements = list(super().process(element))
# Check if we're looking at the actual elements being encoded/decoded
# There is also a gbk on assertEqual, which uses None as the key type.
final_element_keys = [e for e in final_elements if e[0] in ['a', 'b', 'c']]
if len(final_element_keys) == 0:
return final_elements
hmac_key, actual_elements = element
if hmac_key not in self.known_hmacs:
raise ValueError(f'GBK produced unencrypted value {hmac_key}')
Expand All @@ -286,7 +293,38 @@ def process(self, element):
except InvalidToken:
raise ValueError(f'GBK produced unencrypted value {e[1]}')

return super().process(element)
return final_elements


class SecretTest(unittest.TestCase):
@parameterized.expand([
param(
secret_string='type:GcpSecret;version_name:my_secret/versions/latest',
secret=GcpSecret('my_secret/versions/latest')),
param(
secret_string='type:GcpSecret;version_name:foo',
secret=GcpSecret('foo')),
param(
secret_string='type:gcpsecreT;version_name:my_secret/versions/latest',
secret=GcpSecret('my_secret/versions/latest')),
])
def test_secret_manager_parses_correctly(self, secret_string, secret):
self.assertEqual(secret, Secret.parse_secret_option(secret_string))

@parameterized.expand([
param(
secret_string='version_name:foo',
exception_str='must contain a valid type parameter'),
param(
secret_string='type:gcpsecreT',
exception_str='missing 1 required positional argument'),
param(
secret_string='type:gcpsecreT;version_name:foo;extra:val',
exception_str='Invalid secret parameter extra'),
])
def test_secret_manager_throws_on_invalid(self, secret_string, exception_str):
with self.assertRaisesRegex(Exception, exception_str):
Secret.parse_secret_option(secret_string)


class GroupByEncryptedKeyTest(unittest.TestCase):
Expand Down Expand Up @@ -318,7 +356,9 @@ def setUp(self):
'data': Secret.generate_secret_bytes()
}
})
self.gcp_secret = GcpSecret(f'{self.secret_path}/versions/latest')
version_name = f'{self.secret_path}/versions/latest'
self.gcp_secret = GcpSecret(version_name)
self.secret_option = f'type:GcpSecret;version_name:{version_name}'

def tearDown(self):
if secretmanager is not None:
Expand All @@ -334,6 +374,20 @@ def test_gbek_fake_secret_manager_roundtrips(self):
assert_that(
result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))

@unittest.skipIf(secretmanager is None, 'GCP dependencies are not installed')
def test_gbk_with_gbek_option_fake_secret_manager_roundtrips(self):
options = PipelineOptions()
options.view_as(SetupOptions).gbek = self.secret_option

with beam.Pipeline(options=options) as pipeline:
pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
('b', 3), ('c', 4)])
result = (pcoll_1) | beam.GroupByKey()
sorted_result = result | beam.Map(lambda x: (x[0], sorted(x[1])))
assert_that(
sorted_result,
equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))

@mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
def test_gbek_fake_secret_manager_actually_does_encryption(self):
fakeSecret = FakeSecret()
Expand All @@ -345,8 +399,23 @@ def test_gbek_fake_secret_manager_actually_does_encryption(self):
assert_that(
result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))

@mock.patch('apache_beam.transforms.util._DecryptMessage', MockNoOpDecrypt)
@mock.patch('apache_beam.transforms.util.GcpSecret', FakeSecret)
def test_gbk_actually_does_encryption(self):
options = PipelineOptions()
# Version of GcpSecret doesn't matter since it is replaced by FakeSecret
options.view_as(SetupOptions).gbek = 'type:GcpSecret;version_name:Foo'

with TestPipeline('FnApiRunner', options=options) as pipeline:
pcoll_1 = pipeline | 'Start 1' >> beam.Create([('a', 1), ('a', 2),
('b', 3), ('c', 4)],
reshuffle=False)
result = pcoll_1 | beam.GroupByKey()
assert_that(
result, equal_to([('a', ([1, 2])), ('b', ([3])), ('c', ([4]))]))

def test_gbek_fake_secret_manager_throws(self):
fakeSecret = FakeSecret(True)
fakeSecret = FakeSecret(None, True)

with self.assertRaisesRegex(RuntimeError, r'Exception retrieving secret'):
with TestPipeline() as pipeline:
Expand Down
Loading