Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions sdks/python/apache_beam/transforms/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from collections import namedtuple

import grpc
import yaml

from apache_beam import pvalue
from apache_beam.coders import RowCoder
Expand All @@ -42,10 +43,12 @@
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.portability.api import external_transforms_pb2
from apache_beam.portability.api import schema_pb2
from apache_beam.portability.common_urns import ManagedTransforms
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import artifact_service
from apache_beam.transforms import environments
from apache_beam.transforms import ptransform
from apache_beam.transforms.util import is_compat_version_prior_to
from apache_beam.typehints import WithTypeHints
from apache_beam.typehints import native_type_compatibility
from apache_beam.typehints import row_type
Expand All @@ -61,6 +64,25 @@

DEFAULT_EXPANSION_SERVICE = 'localhost:8097'

# Identifier of the generic "managed" SchemaTransform. It is the identifier
# used both to discover the managed transform's configuration schema and to
# build the managed payload.
MANAGED_SCHEMA_TRANSFORM_IDENTIFIER = "beam:transform:managed:v1"

# Gradle target of the expansion service jar mapped to the Iceberg and Kafka
# managed transforms.
_IO_EXPANSION_SERVICE_JAR_TARGET = "sdks:java:io:expansion-service:shadowJar"

# Gradle target of the expansion service jar mapped to the BigQuery managed
# transforms.
_GCP_EXPANSION_SERVICE_JAR_TARGET = (
"sdks:java:io:google-cloud-platform:expansion-service:shadowJar")
Comment on lines +69 to +72
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe combine these in a list of "supported expansion jar targets" and replace this:

services_and_names = managed._EXPANSION_SERVICE_JAR_TARGETS

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


# Maps the URN of each supported managed transform to the Gradle target of the
# expansion service jar that includes that transform. A URN that is missing
# from this mapping has no default expansion service jar, and lookups for it
# are reported to the caller as a ValueError.
MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING = {
ManagedTransforms.Urns.ICEBERG_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
ManagedTransforms.Urns.ICEBERG_WRITE.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
ManagedTransforms.Urns.ICEBERG_CDC_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET, # pylint: disable=line-too-long
ManagedTransforms.Urns.KAFKA_READ.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
ManagedTransforms.Urns.KAFKA_WRITE.urn: _IO_EXPANSION_SERVICE_JAR_TARGET,
ManagedTransforms.Urns.BIGQUERY_READ.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET,
ManagedTransforms.Urns.BIGQUERY_WRITE.urn: _GCP_EXPANSION_SERVICE_JAR_TARGET
}


def convert_to_typing_type(type_):
if isinstance(type_, row_type.RowTypeConstraint):
Expand Down Expand Up @@ -378,6 +400,10 @@ def _has_constructor(self):
'SchemaTransformsConfig',
['identifier', 'configuration_schema', 'inputs', 'outputs', 'description'])

# Describes how a schema-aware external transform may be replaced with an
# equivalent managed transform.
#   underlying_transform_identifier: URN of the managed transform to use; it
#     is looked up in the URN-to-jar-target mapping to pick an expansion
#     service jar.
#   update_compatibility_version: version handed to
#     `is_compat_version_prior_to` together with the pipeline options to
#     decide whether the replacement is applied at expansion time.
ManagedReplacement = namedtuple(
'ManagedReplacement',
['underlying_transform_identifier', 'update_compatibility_version'])


class SchemaAwareExternalTransform(ptransform.PTransform):
"""A proxy transform for SchemaTransforms implemented in external SDKs.
Expand All @@ -396,6 +422,12 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
the configuration.
:param classpath: (Optional) A list paths to additional jars to place on the
expansion service classpath.
:param managed_replacement: (Optional) a 'ManagedReplacement' namedtuple that
defines information needed to replace the transform with an equivalent
managed transform during the expansion. If an
'updateCompatibilityBeamVersion' pipeline option is provided, we will
only replace if the managed transform is update compatible with the
provided version.
:kwargs: field name to value mapping for configuring the schema transform.
keys map to the field names of the schema of the SchemaTransform
(in-order).
Expand All @@ -406,10 +438,14 @@ def __init__(
expansion_service,
rearrange_based_on_discovery=False,
classpath=None,
managed_replacement=None,
**kwargs):
self._expansion_service = expansion_service
self._kwargs = kwargs
self._classpath = classpath
if managed_replacement:
assert isinstance(managed_replacement, ManagedReplacement)
self._managed_replacement = managed_replacement

_kwargs = kwargs
if rearrange_based_on_discovery:
Expand All @@ -420,16 +456,55 @@ def __init__(
named_tuple_to_schema(config.configuration_schema),
**_kwargs)

if self._managed_replacement:
# We have to do the replacement at the expansion instead of at
# construction
# since we don't have access to the PipelineOptions object at the
# construction.
underlying_transform_id = (
self._managed_replacement.underlying_transform_identifier)
if not (underlying_transform_id in
MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING):
raise ValueError(
'Could not find an expansion service jar for the managed ' +
'transform ' + underlying_transform_id)
managed_expansion_service_jar = (
MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING
)[underlying_transform_id]
self._managed_expansion_service = BeamJarExpansionService(
managed_expansion_service_jar)
managed_config = SchemaAwareExternalTransform.discover_config(
self._managed_expansion_service,
MANAGED_SCHEMA_TRANSFORM_IDENTIFIER)

yaml_config = yaml.dump(kwargs)
self._managed_payload_builder = (
ExplicitSchemaTransformPayloadBuilder(
MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
named_tuple_to_schema(managed_config.configuration_schema),
transform_identifier=underlying_transform_id,
config=yaml_config))
else:
self._payload_builder = SchemaTransformPayloadBuilder(
identifier, **_kwargs)

def expand(self, pcolls):
  """Expands this transform through an expansion service.

  By default the regular payload builder and expansion service are used.
  When a managed replacement was configured, the managed payload builder and
  managed expansion service are used instead, unless
  ``is_compat_version_prior_to`` reports that the pipeline's compatibility
  version is prior to the replacement's ``update_compatibility_version``.

  Args:
    pcolls: the input this transform is applied to. When a managed
      replacement is configured, ``pcolls.pipeline._options`` is consulted
      for the compatibility check.

  Returns:
    The result of applying the ``ExternalTransform`` to ``pcolls``.
  """
  # Expand the transform using the expansion service.
  payload_builder = self._payload_builder
  expansion_service = self._expansion_service

  if self._managed_replacement:
    compat_version_prior_to_current = is_compat_version_prior_to(
        pcolls.pipeline._options,
        self._managed_replacement.update_compatibility_version)
    if not compat_version_prior_to_current:
      payload_builder = self._managed_payload_builder
      expansion_service = self._managed_expansion_service

  # The label always comes from the original payload builder, so the applied
  # transform's name does not depend on whether the replacement happened.
  # Only the locally selected builder and service are passed on; passing
  # `self._payload_builder` / `self._expansion_service` here would silently
  # ignore the managed replacement.
  return pcolls | self._payload_builder.identifier() >> ExternalTransform(
      common_urns.schematransform_based_expand.urn,
      payload_builder,
      expansion_service)

@classmethod
@functools.lru_cache
Expand Down
66 changes: 65 additions & 1 deletion sdks/python/apache_beam/transforms/external_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,21 @@
import mock

import apache_beam as beam
from apache_beam import ManagedReplacement
from apache_beam import Pipeline
from apache_beam.coders import RowCoder
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.portability.api import beam_expansion_api_pb2
from apache_beam.portability.api import external_transforms_pb2
from apache_beam.portability.api import schema_pb2
from apache_beam.portability.common_urns import ManagedTransforms
from apache_beam.runners import pipeline_context
from apache_beam.runners.portability import expansion_service
from apache_beam.runners.portability.expansion_service_test import FibTransform
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import external
from apache_beam.transforms.external import MANAGED_SCHEMA_TRANSFORM_IDENTIFIER
from apache_beam.transforms.external import AnnotationBasedPayloadBuilder
from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
from apache_beam.transforms.external import JavaClassLookupPayloadBuilder
Expand Down Expand Up @@ -530,8 +533,28 @@ def DiscoverSchemaTransform(self, unused_request=None):
id="test-id"),
input_pcollection_names=["input"],
output_pcollection_names=["output"])

test_managed_config = beam_expansion_api_pb2.SchemaTransformConfig(
config_schema=schema_pb2.Schema(
fields=[
schema_pb2.Field(
name="transform_identifier",
type=schema_pb2.FieldType(atomic_type="STRING")),
schema_pb2.Field(
name="config_url",
type=schema_pb2.FieldType(atomic_type="STRING")),
schema_pb2.Field(
name="config",
type=schema_pb2.FieldType(atomic_type="STRING"))
],
id="test-id1"),
input_pcollection_names=["input"],
output_pcollection_names=["output"])
return beam_expansion_api_pb2.DiscoverSchemaTransformResponse(
schema_transform_configs={"test_schematransform": test_config})
schema_transform_configs={
"test_schematransform": test_config,
MANAGED_SCHEMA_TRANSFORM_IDENTIFIER: test_managed_config
})

@mock.patch("apache_beam.transforms.external.ExternalTransform.service")
def test_discover_one_config(self, mock_service):
Expand Down Expand Up @@ -573,6 +596,47 @@ def test_rearrange_kwargs_based_on_discovery(self, mock_service):
self.assertNotEqual(tuple(kwargs.keys()), external_config_fields)
self.assertEqual(tuple(ordered_fields), external_config_fields)

@mock.patch("apache_beam.transforms.external.ExternalTransform.service")
def test_managed_replacement_unknown_id(self, mock_service):
  """A managed replacement with an unrecognized URN raises ``ValueError``."""
  mock_service.return_value = self.MockDiscoveryService()

  config_kwargs = {"int_field": 0, "str_field": "str"}
  unknown_replacement = ManagedReplacement(
      underlying_transform_identifier="unknown_id",
      update_compatibility_version="2.50.0")

  # "unknown_id" is expected to have no expansion service jar mapped to it,
  # so constructing the transform should fail.
  self.assertRaises(
      ValueError,
      beam.SchemaAwareExternalTransform,
      identifier="test_schematransform",
      expansion_service=expansion_service,
      rearrange_based_on_discovery=True,
      managed_replacement=unknown_replacement,
      **config_kwargs)

@mock.patch("apache_beam.transforms.external.ExternalTransform.service")
@mock.patch("apache_beam.transforms.external.BeamJarExpansionService")
def test_managed_replacement_known_id(
    self, mock_beam_jar_service, mock_service):
  """A managed replacement with a known URN builds a managed payload."""
  # `mock.patch` decorators are applied bottom-up, so the first mock argument
  # belongs to the `BeamJarExpansionService` patch and the second one to the
  # `ExternalTransform.service` patch.
  mock_service.return_value = self.MockDiscoveryService()
  mock_beam_jar_service.return_value = self.MockDiscoveryService()

  identifier = "test_schematransform"
  kwargs = {"int_field": 0, "str_field": "str"}

  managed_replacement = ManagedReplacement(
      underlying_transform_identifier=ManagedTransforms.Urns.ICEBERG_READ.urn,
      update_compatibility_version="2.50.0")

  external_transform = beam.SchemaAwareExternalTransform(
      identifier=identifier,
      expansion_service=expansion_service,
      rearrange_based_on_discovery=True,
      managed_replacement=managed_replacement,
      **kwargs)
  self.assertIsNotNone(external_transform._managed_payload_builder)


class JavaClassLookupPayloadBuilderTest(unittest.TestCase):
def _verify_row(self, schema, row_payload, expected_values):
Expand Down
25 changes: 9 additions & 16 deletions sdks/python/apache_beam/transforms/managed.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
import yaml

from apache_beam.portability.common_urns import ManagedTransforms
from apache_beam.transforms.external import MANAGED_SCHEMA_TRANSFORM_IDENTIFIER
from apache_beam.transforms.external import MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING
from apache_beam.transforms.external import BeamJarExpansionService
from apache_beam.transforms.external import SchemaAwareExternalTransform
from apache_beam.transforms.ptransform import PTransform
Expand All @@ -87,13 +89,6 @@
_ICEBERG_CDC = "iceberg_cdc"
KAFKA = "kafka"
BIGQUERY = "bigquery"
_MANAGED_IDENTIFIER = "beam:transform:managed:v1"
_EXPANSION_SERVICE_JAR_TARGETS = {
"sdks:java:io:expansion-service:shadowJar": [KAFKA, ICEBERG, _ICEBERG_CDC],
"sdks:java:io:google-cloud-platform:expansion-service:shadowJar": [
BIGQUERY
]
}

__all__ = ["ICEBERG", "KAFKA", "BIGQUERY", "Read", "Write"]

Expand Down Expand Up @@ -131,7 +126,7 @@ def __init__(

def expand(self, input):
return input | SchemaAwareExternalTransform(
identifier=_MANAGED_IDENTIFIER,
identifier=MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
expansion_service=self._expansion_service,
rearrange_based_on_discovery=True,
transform_identifier=self._underlying_identifier,
Expand Down Expand Up @@ -175,7 +170,7 @@ def __init__(

def expand(self, input):
return input | SchemaAwareExternalTransform(
identifier=_MANAGED_IDENTIFIER,
identifier=MANAGED_SCHEMA_TRANSFORM_IDENTIFIER,
expansion_service=self._expansion_service,
rearrange_based_on_discovery=True,
transform_identifier=self._underlying_identifier,
Expand All @@ -192,13 +187,11 @@ def _resolve_expansion_service(
if expansion_service:
return expansion_service

default_target = None
for gradle_target, transforms in _EXPANSION_SERVICE_JAR_TARGETS.items():
if transform_name.lower() in transforms:
default_target = gradle_target
break
if not default_target:
gradle_target = None
if identifier in MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING:
gradle_target = MANAGED_TRANSFORM_URN_TO_JAR_TARGET_MAPPING.get(identifier)
if not gradle_target:
raise ValueError(
"No expansion service was specified and could not find a "
f"default expansion service for {transform_name}: '{identifier}'.")
return BeamJarExpansionService(default_target)
return BeamJarExpansionService(gradle_target)
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/yaml/standard_io.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@
'WriteToKafka': 'beam:schematransform:org.apache.beam:kafka_write:v1'
config:
gradle_target: 'sdks:java:io:expansion-service:shadowJar'
managed_replacement:
# The following transforms may be replaced with equivalent managed
# transforms, if the pipeline's 'updateCompatibilityBeamVersion' is
# compatible with the provided version.
'ReadFromKafka': '2.66.0'
'WriteToKafka': '2.66.0'

# PubSub
- type: renaming
Expand Down
Loading
Loading