Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ SingleOutputExpandableTransform<InputT, OutputT> of(
Endpoints.ApiServiceDescriptor apiDesc =
Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build();
return new SingleOutputExpandableTransform<>(
urn, payload, apiDesc, DEFAULT, getFreshNamespaceIndex(), ImmutableMap.of());
urn,
payload,
apiDesc,
DEFAULT,
getFreshNamespaceIndex(),
ImmutableMap.of(),
ImmutableMap.of());
}

@VisibleForTesting
Expand All @@ -103,7 +109,13 @@ static <InputT extends PInput, OutputT> SingleOutputExpandableTransform<InputT,
Endpoints.ApiServiceDescriptor apiDesc =
Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build();
return new SingleOutputExpandableTransform<>(
urn, payload, apiDesc, clientFactory, getFreshNamespaceIndex(), ImmutableMap.of());
urn,
payload,
apiDesc,
clientFactory,
getFreshNamespaceIndex(),
ImmutableMap.of(),
ImmutableMap.of());
}

/** Expandable transform for output type of PCollection. */
Expand All @@ -115,8 +127,9 @@ public static class SingleOutputExpandableTransform<InputT extends PInput, Outpu
Endpoints.ApiServiceDescriptor endpoint,
ExpansionServiceClientFactory clientFactory,
Integer namespaceIndex,
Map<String, Coder<?>> outputCoders) {
super(urn, payload, endpoint, clientFactory, namespaceIndex, outputCoders);
Map<String, Coder<?>> outputCoders,
Map<String, String> resources) {
super(urn, payload, endpoint, clientFactory, namespaceIndex, outputCoders, resources);
}

@Override
Expand All @@ -125,14 +138,27 @@ PCollection<OutputT> toOutputCollection(Map<TupleTag<?>, PCollection> output) {
return Iterables.getOnlyElement(output.values());
}

public SingleOutputExpandableTransform<InputT, OutputT> withResources(
Map<String, String> resources) {
return new SingleOutputExpandableTransform<>(
getUrn(),
getPayload(),
getEndpoint(),
getClientFactory(),
getNamespaceIndex(),
getOutputCoders(),
resources);
}

public MultiOutputExpandableTransform<InputT> withMultiOutputs() {
return new MultiOutputExpandableTransform<>(
getUrn(),
getPayload(),
getEndpoint(),
getClientFactory(),
getNamespaceIndex(),
getOutputCoders());
getOutputCoders(),
getResources());
}

public SingleOutputExpandableTransform<InputT, OutputT> withOutputCoder(Coder<?> outputCoder) {
Expand All @@ -142,7 +168,8 @@ public SingleOutputExpandableTransform<InputT, OutputT> withOutputCoder(Coder<?>
getEndpoint(),
getClientFactory(),
getNamespaceIndex(),
ImmutableMap.of("0", outputCoder));
ImmutableMap.of("0", outputCoder),
getResources());
}
}

Expand All @@ -155,8 +182,9 @@ public static class MultiOutputExpandableTransform<InputT extends PInput>
Endpoints.ApiServiceDescriptor endpoint,
ExpansionServiceClientFactory clientFactory,
Integer namespaceIndex,
Map<String, Coder<?>> outputCoders) {
super(urn, payload, endpoint, clientFactory, namespaceIndex, outputCoders);
Map<String, Coder<?>> outputCoders,
Map<String, String> resources) {
super(urn, payload, endpoint, clientFactory, namespaceIndex, outputCoders, resources);
}

@Override
Expand All @@ -178,7 +206,8 @@ public MultiOutputExpandableTransform<InputT> withOutputCoder(
getEndpoint(),
getClientFactory(),
getNamespaceIndex(),
outputCoders);
outputCoders,
getResources());
}
}

Expand All @@ -191,6 +220,7 @@ public abstract static class ExpandableTransform<InputT extends PInput, OutputT
private final ExpansionServiceClientFactory clientFactory;
private final Integer namespaceIndex;
private final Map<String, Coder<?>> outputCoders;
private final Map<String, String> resources;

private transient RunnerApi.@Nullable Components expandedComponents;
private transient RunnerApi.@Nullable PTransform expandedTransform;
Expand All @@ -204,13 +234,15 @@ public abstract static class ExpandableTransform<InputT extends PInput, OutputT
Endpoints.ApiServiceDescriptor endpoint,
ExpansionServiceClientFactory clientFactory,
Integer namespaceIndex,
Map<String, Coder<?>> outputCoders) {
Map<String, Coder<?>> outputCoders,
Map<String, String> resources) {
this.urn = urn;
this.payload = payload;
this.endpoint = endpoint;
this.clientFactory = clientFactory;
this.namespaceIndex = namespaceIndex;
this.outputCoders = outputCoders;
this.resources = resources;
}

@Override
Expand Down Expand Up @@ -281,20 +313,29 @@ public OutputT expand(InputT input) {
String.format("expansion service error: %s", response.getError()));
}

Map<String, RunnerApi.Environment> newEnvironmentsWithDependencies =
response.getComponents().getEnvironmentsMap().entrySet().stream()
.filter(
kv ->
!originalComponents.getEnvironmentsMap().containsKey(kv.getKey())
&& kv.getValue().getDependenciesCount() != 0)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

expandedComponents =
response
.getComponents()
.toBuilder()
.putAllEnvironments(resolveArtifacts(newEnvironmentsWithDependencies))
.build();
RunnerApi.Components.Builder componentsBuilder = response.getComponents().toBuilder();
componentsBuilder.putAllEnvironments(
resolveArtifacts(
componentsBuilder.getEnvironmentsMap().entrySet().stream()
.filter(
kv ->
!originalComponents.getEnvironmentsMap().containsKey(kv.getKey())
&& kv.getValue().getDependenciesCount() != 0)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))));
List<RunnerApi.ArtifactInformation> artifacts =
Environments.getArtifacts(
resources.entrySet().stream()
.map(e -> String.format("%s=%s", e.getValue(), e.getKey()))
.collect(Collectors.toList()));
componentsBuilder.putAllEnvironments(
componentsBuilder.getEnvironmentsMap().entrySet().stream()
.filter(kv -> !originalComponents.getEnvironmentsMap().containsKey(kv.getKey()))
.collect(
Collectors.toMap(
Map.Entry::getKey,
e -> e.getValue().toBuilder().addAllDependencies(artifacts).build())));

expandedComponents = componentsBuilder.build();
expandedTransform = response.getTransform();
expandedRequirements = response.getRequirementsList();

Expand Down Expand Up @@ -478,5 +519,9 @@ Integer getNamespaceIndex() {
Map<String, Coder<?>> getOutputCoders() {
return outputCoders;
}

Map<String, String> getResources() {
return resources;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public class PythonExternalTransform<InputT extends PInput, OutputT extends POut
private @Nullable Row providedKwargsRow;

Map<String, Coder<?>> outputCoders;
Map<String, String> resources;

private PythonExternalTransform(String fullyQualifiedName, String expansionService) {
this.fullyQualifiedName = fullyQualifiedName;
Expand All @@ -86,6 +87,7 @@ private PythonExternalTransform(String fullyQualifiedName, String expansionServi
PythonCallableSource.class, Schema.FieldType.logicalType(new PythonCallable()));
argsArray = new Object[] {};
this.outputCoders = new HashMap<>();
this.resources = new HashMap<>();
}

/**
Expand Down Expand Up @@ -228,6 +230,14 @@ public PythonExternalTransform<InputT, OutputT> withOutputCoder(Coder<?> outputC
return this;
}

public PythonExternalTransform<InputT, OutputT> withResources(Map<String, String> resources) {
if (this.resources.size() > 0) {
throw new IllegalArgumentException("resources were already specified");
}
this.resources = resources;
return this;
}

@VisibleForTesting
Row buildOrGetKwargsRow() {
if (providedKwargsRow != null) {
Expand Down Expand Up @@ -406,6 +416,7 @@ private OutputT apply(
"beam:transforms:python:fully_qualified_named",
payload.toByteArray(),
expansionService)
.withResources(this.resources)
.withMultiOutputs()
.withOutputCoder(this.outputCoders);
PCollectionTuple outputs;
Expand Down