diff --git a/src/main/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposer.java b/src/main/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposer.java index 3677a02b79..bcdb809df3 100644 --- a/src/main/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposer.java +++ b/src/main/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposer.java @@ -23,6 +23,7 @@ import com.google.api.generator.engine.ast.ConcreteReference; import com.google.api.generator.engine.ast.Expr; import com.google.api.generator.engine.ast.ExprStatement; +import com.google.api.generator.engine.ast.ForStatement; import com.google.api.generator.engine.ast.IfStatement; import com.google.api.generator.engine.ast.MethodDefinition; import com.google.api.generator.engine.ast.MethodInvocationExpr; @@ -39,6 +40,7 @@ import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Objects; @@ -58,6 +60,7 @@ public class BatchingDescriptorComposer { private static final String ADD_ALL_METHOD_PATTERN = "addAll%s"; private static final String GET_LIST_METHOD_PATTERN = "get%sList"; + private static final String GET_COUNT_METHOD_PATTERN = "get%sCount"; public static Expr createBatchingDescriptorFieldDeclExpr( Method method, GapicBatchingSettings batchingSettings, Map messageTypes) { @@ -65,6 +68,10 @@ public static Expr createBatchingDescriptorFieldDeclExpr( javaMethods.add(createGetBatchPartitionKeyMethod(method, batchingSettings, messageTypes)); javaMethods.add(createGetRequestBuilderMethod(method, batchingSettings)); + javaMethods.add(createSplitExceptionMethod(method)); + javaMethods.add(createCountElementsMethod(method, batchingSettings)); + javaMethods.add(createCountByteSMethod(method)); + TypeNode batchingDescriptorType = toType(BATCHING_DESCRIPTOR_REF, method.inputType(), method.outputType()); AnonymousClassExpr batchingDescriptorClassExpr = @@ -229,6 +236,99 @@ private static MethodDefinition createGetRequestBuilderMethod( .build(); } + private static MethodDefinition createSplitExceptionMethod(Method method) { + VariableExpr throwableVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(toType(Throwable.class)).setName("throwable").build()); + + TypeNode batchedRequestIssuerType = toType(BATCHED_REQUEST_ISSUER_REF, method.outputType()); + TypeNode batchVarType = + TypeNode.withReference( + ConcreteReference.builder() + .setClazz(Collection.class) + .setGenerics( + Arrays.asList( + ConcreteReference.wildcardWithUpperBound( + batchedRequestIssuerType.reference()))) + .build()); + VariableExpr batchVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(batchVarType).setName("batch").build()); + VariableExpr responderVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(batchedRequestIssuerType).setName("responder").build()); + + ForStatement forStatement = + ForStatement.builder() + .setLocalVariableExpr(responderVarExpr.toBuilder().setIsDecl(true).build()) + .setCollectionExpr(batchVarExpr) + .setBody( + Arrays.asList( + ExprStatement.withExpr( + MethodInvocationExpr.builder() + .setExprReferenceExpr(responderVarExpr) + .setMethodName("setException") + .setArguments(throwableVarExpr) + .build()))) + .build(); + + return MethodDefinition.builder() + .setIsOverride(true) + .setScope(ScopeNode.PUBLIC) + .setReturnType(TypeNode.VOID) + .setName("splitException") + .setArguments( + Arrays.asList(throwableVarExpr, batchVarExpr).stream() + .map(v -> v.toBuilder().setIsDecl(true).build()) + .collect(Collectors.toList())) + .setBody(Arrays.asList(forStatement)) + .build(); + } + + private static MethodDefinition createCountElementsMethod( + Method method, GapicBatchingSettings batchingSettings) { + String getFooCountMethodName = + String.format( + GET_COUNT_METHOD_PATTERN, + JavaStyle.toUpperCamelCase(batchingSettings.batchedFieldName())); + VariableExpr requestVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(method.inputType()).setName("request").build()); + + return MethodDefinition.builder() + .setIsOverride(true) + .setScope(ScopeNode.PUBLIC) + .setReturnType(TypeNode.LONG) + .setName("countElements") + .setArguments(requestVarExpr.toBuilder().setIsDecl(true).build()) + .setReturnExpr( + MethodInvocationExpr.builder() + .setExprReferenceExpr(requestVarExpr) + .setMethodName(getFooCountMethodName) + .setReturnType(TypeNode.LONG) + .build()) + .build(); + } + + private static MethodDefinition createCountByteSMethod(Method method) { + VariableExpr requestVarExpr = + VariableExpr.withVariable( + Variable.builder().setType(method.inputType()).setName("request").build()); + return MethodDefinition.builder() + .setIsOverride(true) + .setScope(ScopeNode.PUBLIC) + .setReturnType(TypeNode.LONG) + .setName("countBytes") + .setArguments(requestVarExpr.toBuilder().setIsDecl(true).build()) + .setReturnExpr( + MethodInvocationExpr.builder() + .setExprReferenceExpr(requestVarExpr) + .setMethodName("getSerializedSize") + .setReturnType(TypeNode.LONG) + .build()) + .build(); + } + private static TypeNode toType(Class clazz) { return TypeNode.withReference(ConcreteReference.withClazz(clazz)); } diff --git a/src/test/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposerTest.java b/src/test/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposerTest.java index 3b5dd833ac..e412e0a832 100644 --- a/src/test/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposerTest.java +++ b/src/test/java/com/google/api/generator/gapic/composer/BatchingDescriptorComposerTest.java @@ -128,6 +128,22 @@ public void batchingDescriptor_hasSubresponseField() { "}\n", "};\n", "}\n", + "@Override\n", + "public void splitException(", + "Throwable throwable, ", + "Collection> batch) {\n", + "for (BatchedRequestIssuer responder : batch) {\n", + "responder.setException(throwable);\n", + "}\n", + "}\n", + "@Override\n", + "public long countElements(PublishRequest request) {\n", + "return request.getMessagesCount();\n", + "}\n", + "@Override\n", + "public long countBytes(PublishRequest request) {\n", + "return request.getSerializedSize();\n", + "}\n", "}"); assertEquals(expected, writerVisitor.write()); } @@ -209,6 +225,22 @@ public void batchingDescriptor_noSubresponseField() { "}\n", "};\n", "}\n", + "@Override\n", + "public void splitException(", + "Throwable throwable, ", + "Collection> batch) {\n", + "for (BatchedRequestIssuer responder : batch) {\n", + "responder.setException(throwable);\n", + "}\n", + "}\n", + "@Override\n", + "public long countElements(WriteLogEntriesRequest request) {\n", + "return request.getEntriesCount();\n", + "}\n", + "@Override\n", + "public long countBytes(WriteLogEntriesRequest request) {\n", + "return request.getSerializedSize();\n", + "}\n", "}"); assertEquals(expected, writerVisitor.write()); diff --git a/src/test/java/com/google/api/generator/gapic/composer/ServiceStubSettingsClassComposerTest.java b/src/test/java/com/google/api/generator/gapic/composer/ServiceStubSettingsClassComposerTest.java index 72ff6e2e6e..10456faed3 100644 --- a/src/test/java/com/google/api/generator/gapic/composer/ServiceStubSettingsClassComposerTest.java +++ b/src/test/java/com/google/api/generator/gapic/composer/ServiceStubSettingsClassComposerTest.java @@ -695,6 +695,7 @@ private static List parseServices( + "import com.google.api.gax.retrying.RetrySettings;\n" + "import com.google.api.gax.rpc.ApiCallContext;\n" + "import com.google.api.gax.rpc.ApiClientHeaderProvider;\n" + + "import com.google.api.gax.rpc.BatchedRequestIssuer;\n" + "import com.google.api.gax.rpc.BatchingCallSettings;\n" + "import com.google.api.gax.rpc.BatchingDescriptor;\n" + "import com.google.api.gax.rpc.ClientContext;\n" @@ -723,6 +724,7 @@ private static List parseServices( + "import com.google.logging.v2.WriteLogEntriesResponse;\n" + "import com.google.protobuf.Empty;\n" + "import java.io.IOException;\n" + + "import java.util.Collection;\n" + "import java.util.List;\n" + "import java.util.Objects;\n" + "import javax.annotation.Generated;\n" @@ -995,6 +997,27 @@ private static List parseServices( + " }\n" + " };\n" + " }\n" + + "\n" + + " @Override\n" + + " public void splitException(\n" + + " Throwable throwable,\n" + + " Collection>" + + " batch) {\n" + + " for (BatchedRequestIssuer responder : batch)" + + " {\n" + + " responder.setException(throwable);\n" + + " }\n" + + " }\n" + + "\n" + + " @Override\n" + + " public long countElements(WriteLogEntriesRequest request) {\n" + + " return request.getEntriesCount();\n" + + " }\n" + + "\n" + + " @Override\n" + + " public long countBytes(WriteLogEntriesRequest request) {\n" + + " return request.getSerializedSize();\n" + + " }\n" + " };\n" + "\n" + " public UnaryCallSettings deleteLogSettings() {\n" @@ -1341,6 +1364,7 @@ private static List parseServices( + "import com.google.api.gax.retrying.RetrySettings;\n" + "import com.google.api.gax.rpc.ApiCallContext;\n" + "import com.google.api.gax.rpc.ApiClientHeaderProvider;\n" + + "import com.google.api.gax.rpc.BatchedRequestIssuer;\n" + "import com.google.api.gax.rpc.BatchingCallSettings;\n" + "import com.google.api.gax.rpc.BatchingDescriptor;\n" + "import com.google.api.gax.rpc.ClientContext;\n" @@ -1373,6 +1397,7 @@ private static List parseServices( + "import com.google.pubsub.v1.Topic;\n" + "import com.google.pubsub.v1.UpdateTopicRequest;\n" + "import java.io.IOException;\n" + + "import java.util.Collection;\n" + "import java.util.List;\n" + "import java.util.Objects;\n" + "import javax.annotation.Generated;\n" @@ -1635,6 +1660,25 @@ private static List parseServices( + " }\n" + " };\n" + " }\n" + + "\n" + + " @Override\n" + + " public void splitException(\n" + + " Throwable throwable,\n" + + " Collection> batch) {\n" + + " for (BatchedRequestIssuer responder : batch) {\n" + + " responder.setException(throwable);\n" + + " }\n" + + " }\n" + + "\n" + + " @Override\n" + + " public long countElements(PublishRequest request) {\n" + + " return request.getMessagesCount();\n" + + " }\n" + + "\n" + + " @Override\n" + + " public long countBytes(PublishRequest request) {\n" + + " return request.getSerializedSize();\n" + + " }\n" + " };\n" + "\n" + " public UnaryCallSettings createTopicSettings() {\n"