diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index b827468d719..da1e2fed114 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -9,5 +9,5 @@ jobs: name: "Gradle wrapper validation" runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - uses: gradle/wrapper-validation-action@v1 + - uses: actions/checkout@v4 + - uses: gradle/actions/wrapper-validation@v4 diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml index 907a9dad2b5..3070a1a2f7c 100644 --- a/.github/workflows/lock.yml +++ b/.github/workflows/lock.yml @@ -13,7 +13,7 @@ jobs: lock: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v4 + - uses: dessant/lock-threads@v5 with: github-token: ${{ github.token }} issue-inactive-days: 90 diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index bc5a175906f..8c639cf14ed 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -21,14 +21,14 @@ jobs: fail-fast: false # Should swap to true if we grow a large matrix steps: - - uses: actions/checkout@v3 - - uses: actions/setup-java@v3 + - uses: actions/checkout@v4 + - uses: actions/setup-java@v4 with: java-version: ${{ matrix.jre }} distribution: 'temurin' - name: Gradle cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | ~/.gradle/caches @@ -37,7 +37,7 @@ jobs: restore-keys: | ${{ runner.os }}-gradle- - name: Maven cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | ~/.m2/repository @@ -46,7 +46,7 @@ jobs: restore-keys: | ${{ runner.os }}-maven- - name: Protobuf cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: /tmp/protobuf-cache key: ${{ runner.os }}-maven-${{ hashFiles('buildscripts/make_dependencies.sh') }} @@ -55,7 +55,7 @@ jobs: run: buildscripts/kokoro/unix.sh - name: Post Failure Upload Test Reports to Artifacts if: ${{ failure() }} - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: Test Reports (JRE ${{ matrix.jre }}) path: | @@ -71,7 +71,9 @@ jobs: COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} run: ./gradlew :grpc-all:coveralls -PskipAndroid=true -x compileJava - name: Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} bazel: runs-on: ubuntu-latest @@ -79,7 +81,7 @@ jobs: USE_BAZEL_VERSION: 6.0.0 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Check versions match in MODULE.bazel and repositories.bzl run: | @@ -87,7 +89,7 @@ jobs: <(sed -n '/GRPC_DEPS_START/,/GRPC_DEPS_END/ {/GRPC_DEPS_/! p}' repositories.bzl) - name: Bazel cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: | ~/.cache/bazel/*/cache diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ce40827e748..646a7d986fd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -30,43 +30,36 @@ style configurations are commonly useful. For IntelliJ 14, copy the style to `~/.IdeaIC14/config/codestyles/`, start IntelliJ, go to File > Settings > Code Style, and set the Scheme to `GoogleStyle`. -## Maintaining clean commit history - -We have few conventions for keeping history clean and making code reviews easier -for reviewers: - -* First line of commit messages should be in format of - - `package-name: summary of change` - - where the summary finishes the sentence: `This commit improves gRPC to ____________.` - - for example: - - `core,netty,interop-testing: add capacitive duractance to turbo encabulators` - -* Every time you receive a feedback on your pull request, push changes that - address it as a separate one or multiple commits with a descriptive commit - message (try avoid using vauge `addressed pr feedback` type of messages). - - Project maintainers are obligated to squash those commits into one when - merging. - ## Guidelines for Pull Requests How to get your contributions merged smoothly and quickly. - Create **small PRs** that are narrowly focused on **addressing a single concern**. We often times receive PRs that are trying to fix several things at a time, but only one fix is considered acceptable, nothing gets merged and both author's & review's time is wasted. Create more PRs to address different concerns and everyone will be happy. -- For speculative changes, consider opening an issue and discussing it first. If you are suggesting a behavioral or API change, consider starting with a [gRFC proposal](https://github.com/grpc/proposal). - -- Provide a good **PR description** as a record of **what** change is being made and **why** it was made. Link to a github issue if it exists. - -- Don't fix code style and formatting unless you are already changing that line to address an issue. PRs with irrelevant changes won't be merged. If you do want to fix formatting or style, do that in a separate PR. - -- Unless your PR is trivial, you should expect there will be reviewer comments that you'll need to address before merging. We expect you to be reasonably responsive to those comments, otherwise the PR will be closed after 2-3 weeks of inactivity. - -- Maintain **clean commit history** and use **meaningful commit messages**. See [maintaining clean commit history](#maintaining-clean-commit-history) for details. - +- For speculative changes, consider opening an issue and discussing it to avoid + wasting time on an inappropriate approach. If you are suggesting a behavioral + or API change, consider starting with a [gRFC + proposal](https://github.com/grpc/proposal). + +- Follow [typical Git commit message](https://cbea.ms/git-commit/#seven-rules) + structure. Have a good **commit description** as a record of **what** and + **why** the change is being made. Link to a GitHub issue if it exists. The + commit description makes a good PR description and is auto-copied by GitHub if + you have a single commit when creating the PR. + + If your change is mostly for a single module (e.g., other module changes are + trivial), prefix your commit summary with the module name changed. Instead of + "Add HTTP/2 faster-than-light support to gRPC Netty" it is more terse as + "netty: Add faster-than-light support". + +- Don't fix code style and formatting unless you are already changing that line + to address an issue. If you do want to fix formatting or style, do that in a + separate PR. + +- Unless your PR is trivial, you should expect there will be reviewer comments + that you'll need to address before merging. Address comments with additional + commits so the reviewer can review just the changes; do not squash reviewed + commits unless the reviewer agrees. PRs are squashed when merging. + - Keep your PR up to date with upstream/master (if there are merge conflicts, we can't really merge your change). - **All tests need to be passing** before your change can be merged. We recommend you **run tests locally** before creating your PR to catch breakages early on. Also, `./gradlew build` (`gradlew build` on Windows) **must not introduce any new warnings**. diff --git a/MODULE.bazel b/MODULE.bazel index 81c3249f47a..b60ea565073 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -2,7 +2,7 @@ module( name = "grpc-java", compatibility_level = 0, repo_name = "io_grpc_grpc_java", - version = "1.67.0-SNAPSHOT", # CURRENT_GRPC_VERSION + version = "1.68.0-SNAPSHOT", # CURRENT_GRPC_VERSION ) # GRPC_DEPS_START diff --git a/README.md b/README.md index fef37c1c3bb..cb38ad66394 100644 --- a/README.md +++ b/README.md @@ -44,8 +44,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.65.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.65.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.66.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.66.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -56,18 +56,18 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.65.0 + 1.66.0 runtime io.grpc grpc-protobuf - 1.65.0 + 1.66.0 io.grpc grpc-stub - 1.65.0 + 1.66.0 org.apache.tomcat @@ -79,18 +79,18 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -runtimeOnly 'io.grpc:grpc-netty-shaded:1.65.0' -implementation 'io.grpc:grpc-protobuf:1.65.0' -implementation 'io.grpc:grpc-stub:1.65.0' +runtimeOnly 'io.grpc:grpc-netty-shaded:1.66.0' +implementation 'io.grpc:grpc-protobuf:1.66.0' +implementation 'io.grpc:grpc-stub:1.66.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.65.0' -implementation 'io.grpc:grpc-protobuf-lite:1.65.0' -implementation 'io.grpc:grpc-stub:1.65.0' +implementation 'io.grpc:grpc-okhttp:1.66.0' +implementation 'io.grpc:grpc-protobuf-lite:1.66.0' +implementation 'io.grpc:grpc-stub:1.66.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` @@ -99,7 +99,7 @@ For [Bazel](https://bazel.build), you can either (with the GAVs from above), or use `@io_grpc_grpc_java//api` et al (see below). [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.65.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.66.0 Development snapshots are available in [Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -129,9 +129,9 @@ For protobuf-based codegen integrated with the Maven build system, you can use protobuf-maven-plugin 0.6.1 - com.google.protobuf:protoc:3.25.1:exe:${os.detected.classifier} + com.google.protobuf:protoc:3.25.3:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.65.0:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.66.0:exe:${os.detected.classifier} @@ -157,11 +157,11 @@ plugins { protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.25.1" + artifact = "com.google.protobuf:protoc:3.25.3" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.65.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.66.0' } } generateProtoTasks { @@ -190,11 +190,11 @@ plugins { protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.25.1" + artifact = "com.google.protobuf:protoc:3.25.3" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.65.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.66.0' } } generateProtoTasks { diff --git a/alts/BUILD.bazel b/alts/BUILD.bazel index 819daedcc82..73420e11053 100644 --- a/alts/BUILD.bazel +++ b/alts/BUILD.bazel @@ -19,7 +19,6 @@ java_library( "@com_google_protobuf//:protobuf_java_util", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), artifact("io.netty:netty-buffer"), artifact("io.netty:netty-codec"), artifact("io.netty:netty-common"), @@ -45,7 +44,6 @@ java_library( artifact("com.google.auth:google-auth-library-oauth2-http"), artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), artifact("io.netty:netty-common"), artifact("io.netty:netty-handler"), artifact("io.netty:netty-transport"), diff --git a/android-interop-testing/build.gradle b/android-interop-testing/build.gradle index 22aa5f2288d..9b3b021afce 100644 --- a/android-interop-testing/build.gradle +++ b/android-interop-testing/build.gradle @@ -69,7 +69,6 @@ dependencies { implementation project(':grpc-android'), project(':grpc-core'), - project(':grpc-auth'), project(':grpc-census'), project(':grpc-okhttp'), project(':grpc-protobuf-lite'), @@ -81,10 +80,6 @@ dependencies { libraries.androidx.test.rules, libraries.opencensus.contrib.grpc.metrics - implementation (libraries.google.auth.oauth2Http) { - exclude group: 'org.apache.httpcomponents' - } - implementation (project(':grpc-services')) { exclude group: 'com.google.protobuf' exclude group: 'com.google.guava' diff --git a/api/BUILD.bazel b/api/BUILD.bazel index 07be1d58dc7..6bf3375e9f0 100644 --- a/api/BUILD.bazel +++ b/api/BUILD.bazel @@ -13,6 +13,5 @@ java_library( artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:failureaccess"), # future transitive dep of Guava. See #5214 artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), ], ) diff --git a/api/src/main/java/io/grpc/InternalSubchannelAddressAttributes.java b/api/src/main/java/io/grpc/InternalSubchannelAddressAttributes.java new file mode 100644 index 00000000000..cfc2f7c5137 --- /dev/null +++ b/api/src/main/java/io/grpc/InternalSubchannelAddressAttributes.java @@ -0,0 +1,31 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +/** + * An internal class. Do not use. + * + *

An interface to provide the attributes for address connected by subchannel. + */ +@Internal +public interface InternalSubchannelAddressAttributes { + + /** + * Return attributes of the server address connected by sub channel. + */ + public Attributes getConnectedAddressAttributes(); +} diff --git a/api/src/main/java/io/grpc/LoadBalancer.java b/api/src/main/java/io/grpc/LoadBalancer.java index 15106a5ffc6..0fbce5fa5be 100644 --- a/api/src/main/java/io/grpc/LoadBalancer.java +++ b/api/src/main/java/io/grpc/LoadBalancer.java @@ -1428,6 +1428,18 @@ public void updateAddresses(List addrs) { public Object getInternalSubchannel() { throw new UnsupportedOperationException(); } + + /** + * (Internal use only) returns attributes of the address subchannel is connected to. + * + *

Warning: this is INTERNAL API, is not supposed to be used by external users, and may + * change without notice. If you think you must use it, please file an issue and we can consider + * removing its "internal" status. + */ + @Internal + public Attributes getConnectedAddressAttributes() { + throw new UnsupportedOperationException(); + } } /** diff --git a/auth/BUILD.bazel b/auth/BUILD.bazel index 095fae5af8b..a19562fa7f7 100644 --- a/auth/BUILD.bazel +++ b/auth/BUILD.bazel @@ -11,6 +11,5 @@ java_library( artifact("com.google.auth:google-auth-library-credentials"), artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), ], ) diff --git a/build.gradle b/build.gradle index 74cfacb800a..740f534e136 100644 --- a/build.gradle +++ b/build.gradle @@ -21,7 +21,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.67.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.68.0-SNAPSHOT" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() diff --git a/buildscripts/grpc-java-artifacts/Dockerfile b/buildscripts/grpc-java-artifacts/Dockerfile index 97c152780a3..736babe9d8e 100644 --- a/buildscripts/grpc-java-artifacts/Dockerfile +++ b/buildscripts/grpc-java-artifacts/Dockerfile @@ -28,6 +28,6 @@ RUN mkdir -p "$ANDROID_HOME/cmdline-tools" && \ yes | "$ANDROID_HOME/cmdline-tools/latest/bin/sdkmanager" --licenses # Install Maven -RUN curl -Ls https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.3.9/apache-maven-3.3.9-bin.tar.gz | \ +RUN curl -Ls https://dlcdn.apache.org/maven/maven-3/3.8.8/binaries/apache-maven-3.8.8-bin.tar.gz | \ tar xz -C /var/local -ENV PATH /var/local/apache-maven-3.3.9/bin:$PATH +ENV PATH /var/local/apache-maven-3.8.8/bin:$PATH diff --git a/buildscripts/kokoro/macos.sh b/buildscripts/kokoro/macos.sh index 97259231ee8..018d15dd2f9 100755 --- a/buildscripts/kokoro/macos.sh +++ b/buildscripts/kokoro/macos.sh @@ -15,4 +15,7 @@ export GRADLE_FLAGS="${GRADLE_FLAGS:-} --max-workers=2" . "$GRPC_JAVA_DIR"/buildscripts/kokoro/kokoro.sh trap spongify_logs EXIT +export -n JAVA_HOME +export PATH="$(/usr/libexec/java_home -v"1.8.0")/bin:${PATH}" + "$GRPC_JAVA_DIR"/buildscripts/kokoro/unix.sh diff --git a/buildscripts/kokoro/unix.sh b/buildscripts/kokoro/unix.sh index 9b1a4054c7e..1b88b56ab40 100755 --- a/buildscripts/kokoro/unix.sh +++ b/buildscripts/kokoro/unix.sh @@ -23,11 +23,6 @@ readonly GRPC_JAVA_DIR="$(cd "$(dirname "$0")"/../.. && pwd)" # cd to the root dir of grpc-java cd $(dirname $0)/../.. -# TODO(zpencer): always make sure we are using Oracle jdk8 -if [[ -f /usr/libexec/java_home ]]; then - JAVA_HOME=$(/usr/libexec/java_home -v"1.8.0") -fi - # ARCH is x86_64 unless otherwise specified. ARCH="${ARCH:-x86_64}" diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index 75e9e0b47e0..04a7f2406b3 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.67.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.68.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 3852b6ee547..d69abad7cbb 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.67.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.68.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/core/BUILD.bazel b/core/BUILD.bazel index a1d3d19e828..35c20628d0b 100644 --- a/core/BUILD.bazel +++ b/core/BUILD.bazel @@ -30,7 +30,6 @@ java_library( artifact("com.google.code.findbugs:jsr305"), artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), artifact("io.perfmark:perfmark-api"), artifact("org.codehaus.mojo:animal-sniffer-annotations"), ], diff --git a/core/src/main/java/io/grpc/internal/AbstractClientStream.java b/core/src/main/java/io/grpc/internal/AbstractClientStream.java index 51c31993f46..bb346657d53 100644 --- a/core/src/main/java/io/grpc/internal/AbstractClientStream.java +++ b/core/src/main/java/io/grpc/internal/AbstractClientStream.java @@ -92,8 +92,8 @@ void writeFrame( private final TransportTracer transportTracer; private final Framer framer; - private boolean shouldBeCountedForInUse; - private boolean useGet; + private final boolean shouldBeCountedForInUse; + private final boolean useGet; private Metadata headers; /** * Whether cancel() has been called. This is not strictly necessary, but removes the delay between diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 593bdbce13f..a1fe34c2edc 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -219,7 +219,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - public static final String IMPLEMENTATION_VERSION = "1.67.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + public static final String IMPLEMENTATION_VERSION = "1.68.0-SNAPSHOT"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. diff --git a/core/src/main/java/io/grpc/internal/InternalSubchannel.java b/core/src/main/java/io/grpc/internal/InternalSubchannel.java index a986cb2deff..70e42e2f5f1 100644 --- a/core/src/main/java/io/grpc/internal/InternalSubchannel.java +++ b/core/src/main/java/io/grpc/internal/InternalSubchannel.java @@ -157,6 +157,8 @@ protected void handleNotInUse() { private Status shutdownReason; + private volatile Attributes connectedAddressAttributes; + InternalSubchannel(List addressGroups, String authority, String userAgent, BackoffPolicy.Provider backoffPolicyProvider, ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor, @@ -525,6 +527,13 @@ public void run() { return channelStatsFuture; } + /** + * Return attributes for server address connected by sub channel. + */ + public Attributes getConnectedAddressAttributes() { + return connectedAddressAttributes; + } + ConnectivityState getState() { return state.getState(); } @@ -568,6 +577,7 @@ public void run() { } else if (pendingTransport == transport) { activeTransport = transport; pendingTransport = null; + connectedAddressAttributes = addressIndex.getCurrentEagAttributes(); gotoNonErrorState(READY); } } diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index 7f45ca967ea..07dcf9ee7bb 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -2044,6 +2044,11 @@ public void updateAddresses(List addrs) { subchannel.updateAddresses(addrs); } + @Override + public Attributes getConnectedAddressAttributes() { + return subchannel.getConnectedAddressAttributes(); + } + private List stripOverrideAuthorityAttributes( List eags) { List eagsWithoutOverrideAttr = new ArrayList<>(); diff --git a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java index 253422d3dbd..bfa462e16e1 100644 --- a/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java +++ b/core/src/main/java/io/grpc/internal/PickFirstLeafLoadBalancer.java @@ -33,7 +33,6 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; -import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; import java.net.SocketAddress; import java.util.ArrayList; @@ -61,7 +60,7 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer { static final int CONNECTION_DELAY_INTERVAL_MS = 250; private final Helper helper; private final Map subchannels = new HashMap<>(); - private Index addressIndex; + private final Index addressIndex = new Index(ImmutableList.of()); private int numTf = 0; private boolean firstPass = true; @Nullable @@ -70,6 +69,7 @@ final class PickFirstLeafLoadBalancer extends LoadBalancer { private ConnectivityState concludedState = IDLE; private final boolean enableHappyEyeballs = PickFirstLoadBalancerProvider.isEnabledHappyEyeballs(); + private boolean notAPetiolePolicy = true; // means not under a petiole policy PickFirstLeafLoadBalancer(Helper helper) { this.helper = checkNotNull(helper, "helper"); @@ -81,6 +81,10 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { return Status.FAILED_PRECONDITION.withDescription("Already shut down"); } + // Cache whether or not this is a petiole policy, which is based off of an address attribute + Boolean isPetiolePolicy = resolvedAddresses.getAttributes().get(IS_PETIOLE_POLICY); + this.notAPetiolePolicy = isPetiolePolicy == null || !isPetiolePolicy; + List servers = resolvedAddresses.getAddresses(); // Validate the address list @@ -122,9 +126,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { final ImmutableList newImmutableAddressGroups = ImmutableList.builder().addAll(cleanServers).build(); - if (addressIndex == null) { - addressIndex = new Index(newImmutableAddressGroups); - } else if (rawConnectivityState == READY) { + if (rawConnectivityState == READY) { // If the previous ready subchannel exists in new address list, // keep this connection and don't create new subchannels SocketAddress previousAddress = addressIndex.getCurrentAddress(); @@ -133,9 +135,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { SubchannelData subchannelData = subchannels.get(previousAddress); subchannelData.getSubchannel().updateAddresses(addressIndex.getCurrentEagAsList()); return Status.OK; - } else { - addressIndex.reset(); // Previous ready subchannel not in the new list of addresses } + // Previous ready subchannel not in the new list of addresses } else { addressIndex.updateGroups(newImmutableAddressGroups); } @@ -156,20 +157,18 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { } } - if (oldAddrs.size() == 0 || rawConnectivityState == CONNECTING - || rawConnectivityState == READY) { - // start connection attempt at first address + if (oldAddrs.size() == 0) { + // Make tests happy; they don't properly assume starting in CONNECTING rawConnectivityState = CONNECTING; updateBalancingState(CONNECTING, new Picker(PickResult.withNoResult())); - cancelScheduleTask(); - requestConnection(); + } - } else if (rawConnectivityState == IDLE) { - // start connection attempt at first address when requested - SubchannelPicker picker = new RequestConnectionPicker(this); - updateBalancingState(IDLE, picker); + if (rawConnectivityState == READY) { + // connect from beginning when prompted + rawConnectivityState = IDLE; + updateBalancingState(IDLE, new RequestConnectionPicker(this)); - } else if (rawConnectivityState == TRANSIENT_FAILURE) { + } else if (rawConnectivityState == CONNECTING || rawConnectivityState == TRANSIENT_FAILURE) { // start connection attempt at first address cancelScheduleTask(); requestConnection(); @@ -207,9 +206,7 @@ public void handleNameResolutionError(Status error) { subchannelData.getSubchannel().shutdown(); } subchannels.clear(); - if (addressIndex != null) { - addressIndex.updateGroups(null); - } + addressIndex.updateGroups(ImmutableList.of()); rawConnectivityState = TRANSIENT_FAILURE; updateBalancingState(TRANSIENT_FAILURE, new Picker(PickResult.withError(error))); } @@ -311,7 +308,8 @@ private void updateHealthCheckedState(SubchannelData subchannelData) { if (subchannelData.state != READY) { return; } - if (subchannelData.getHealthState() == READY) { + + if (notAPetiolePolicy || subchannelData.getHealthState() == READY) { updateBalancingState(READY, new FixedResultPicker(PickResult.withSubchannel(subchannelData.subchannel))); } else if (subchannelData.getHealthState() == TRANSIENT_FAILURE) { @@ -372,7 +370,7 @@ private void shutdownRemaining(SubchannelData activeSubchannelData) { */ @Override public void requestConnection() { - if (addressIndex == null || !addressIndex.isValid() || rawConnectivityState == SHUTDOWN ) { + if (!addressIndex.isValid() || rawConnectivityState == SHUTDOWN) { return; } @@ -390,22 +388,14 @@ public void requestConnection() { scheduleNextConnection(); break; case CONNECTING: - if (enableHappyEyeballs) { - scheduleNextConnection(); - } else { - subchannelData.subchannel.requestConnection(); - } + scheduleNextConnection(); break; case TRANSIENT_FAILURE: addressIndex.increment(); requestConnection(); break; - case READY: // Shouldn't ever happen - log.warning("Requesting a connection even though we have a READY subchannel"); - break; - case SHUTDOWN: default: - // Makes checkstyle happy + // Wait for current subchannel to change state } } @@ -430,16 +420,7 @@ public void run() { } } - SynchronizationContext synchronizationContext = null; - try { - synchronizationContext = helper.getSynchronizationContext(); - } catch (NullPointerException e) { - // All helpers should have a sync context, but if one doesn't (ex. user had a custom test) - // we don't want to break previously working functionality. - return; - } - - scheduleConnectionTask = synchronizationContext.schedule( + scheduleConnectionTask = helper.getSynchronizationContext().schedule( new StartNextConnection(), CONNECTION_DELAY_INTERVAL_MS, TimeUnit.MILLISECONDS, @@ -469,7 +450,7 @@ private SubchannelData createNewSubchannel(SocketAddress addr, Attributes attrs) hcListener.subchannelData = subchannelData; subchannels.put(addr, subchannelData); Attributes scAttrs = subchannel.getAttributes(); - if (scAttrs.get(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY) == null) { + if (notAPetiolePolicy || scAttrs.get(LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY) == null) { subchannelData.healthStateInfo = ConnectivityStateInfo.forNonError(READY); } subchannel.start(stateInfo -> processSubchannelState(subchannelData, stateInfo)); @@ -477,8 +458,7 @@ private SubchannelData createNewSubchannel(SocketAddress addr, Attributes attrs) } private boolean isPassComplete() { - if (addressIndex == null || addressIndex.isValid() - || subchannels.size() < addressIndex.size()) { + if (addressIndex.isValid() || subchannels.size() < addressIndex.size()) { return false; } for (SubchannelData sc : subchannels.values()) { @@ -494,15 +474,19 @@ private final class HealthListener implements SubchannelStateListener { @Override public void onSubchannelState(ConnectivityStateInfo newState) { + if (notAPetiolePolicy) { + log.log(Level.WARNING, + "Ignoring health status {0} for subchannel {1} as this is not under a petiole policy", + new Object[]{newState, subchannelData.subchannel}); + return; + } + log.log(Level.FINE, "Received health status {0} for subchannel {1}", new Object[]{newState, subchannelData.subchannel}); subchannelData.healthStateInfo = newState; - try { - if (subchannelData == subchannels.get(addressIndex.getCurrentAddress())) { - updateHealthCheckedState(subchannelData); - } - } catch (IllegalStateException e) { - log.fine("Health listener received state change after subchannel was removed"); + if (addressIndex.isValid() + && subchannelData == subchannels.get(addressIndex.getCurrentAddress())) { + updateHealthCheckedState(subchannelData); } } } @@ -566,11 +550,12 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { @VisibleForTesting static final class Index { private List addressGroups; + private int size; private int groupIndex; private int addressIndex; public Index(List groups) { - this.addressGroups = groups != null ? groups : Collections.emptyList(); + updateGroups(groups); } public boolean isValid() { @@ -629,9 +614,14 @@ public List getCurrentEagAsList() { /** * Update to new groups, resetting the current index. */ - public void updateGroups(ImmutableList newGroups) { - addressGroups = newGroups != null ? newGroups : Collections.emptyList(); + public void updateGroups(List newGroups) { + addressGroups = checkNotNull(newGroups, "newGroups"); reset(); + int size = 0; + for (EquivalentAddressGroup eag : newGroups) { + size += eag.getAddresses().size(); + } + this.size = size; } /** @@ -652,7 +642,7 @@ public boolean seekTo(SocketAddress needle) { } public int size() { - return (addressGroups != null) ? addressGroups.size() : 0; + return size; } } diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffers.java b/core/src/main/java/io/grpc/internal/ReadableBuffers.java index 1435be138de..e512c810f84 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffers.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffers.java @@ -415,6 +415,7 @@ public ByteBuffer getByteBuffer() { public InputStream detach() { ReadableBuffer detachedBuffer = buffer; buffer = buffer.readBytes(0); + detachedBuffer.touch(); return new BufferInputStream(detachedBuffer); } diff --git a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java index f7631c34c0d..e4d9f27ed46 100644 --- a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java +++ b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java @@ -1339,6 +1339,32 @@ public void channelzStatContainsTransport() throws Exception { assertThat(index.getCurrentAddress()).isSameInstanceAs(addr2); } + @Test + public void connectedAddressAttributes_ready() { + SocketAddress addr = new SocketAddress() {}; + Attributes attr = Attributes.newBuilder().set(Attributes.Key.create("some-key"), "1").build(); + createInternalSubchannel(new EquivalentAddressGroup(Arrays.asList(addr), attr)); + + assertEquals(IDLE, internalSubchannel.getState()); + assertNoCallbackInvoke(); + assertNull(internalSubchannel.obtainActiveTransport()); + assertNull(internalSubchannel.getConnectedAddressAttributes()); + + assertExactCallbackInvokes("onStateChange:CONNECTING"); + assertEquals(CONNECTING, internalSubchannel.getState()); + verify(mockTransportFactory).newClientTransport( + eq(addr), + eq(createClientTransportOptions().setEagAttributes(attr)), + isA(TransportLogger.class)); + assertNull(internalSubchannel.getConnectedAddressAttributes()); + + internalSubchannel.obtainActiveTransport(); + transports.peek().listener.transportReady(); + assertExactCallbackInvokes("onStateChange:READY"); + assertEquals(READY, internalSubchannel.getState()); + assertEquals(attr, internalSubchannel.getConnectedAddressAttributes()); + } + /** Create ClientTransportOptions. Should not be reused if it may be mutated. */ private ClientTransportFactory.ClientTransportOptions createClientTransportOptions() { return new ClientTransportFactory.ClientTransportOptions() diff --git a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java index 335d199d8b1..63915bddc99 100644 --- a/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java +++ b/core/src/test/java/io/grpc/internal/PickFirstLeafLoadBalancerTest.java @@ -25,6 +25,7 @@ import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.LoadBalancer.HAS_HEALTH_PRODUCER_LISTENER_KEY; import static io.grpc.LoadBalancer.HEALTH_CONSUMER_LISTENER_ARG_KEY; +import static io.grpc.LoadBalancer.IS_PETIOLE_POLICY; import static io.grpc.internal.PickFirstLeafLoadBalancer.CONNECTION_DELAY_INTERVAL_MS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -361,11 +362,7 @@ public void pickAfterResolvedAndUnchanged() { // Second acceptResolvedAddresses shouldn't do anything loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(servers).setAttributes(affinity).build()); - if (enableHappyEyeballs) { - inOrder.verify(mockSubchannel1, never()).requestConnection(); - } else { - inOrder.verify(mockSubchannel1, times(1)).requestConnection(); - } + inOrder.verify(mockSubchannel1, never()).requestConnection(); inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); } @@ -393,15 +390,44 @@ public void pickAfterResolvedAndChanged() { verify(mockSubchannel2).requestConnection(); } + @Test + public void healthCheck_nonPetiolePolicy() { + when(mockSubchannel1.getAttributes()).thenReturn( + Attributes.newBuilder().set(HAS_HEALTH_PRODUCER_LISTENER_KEY, true).build()); + + // Initialize with one server loadbalancer and both health and state listeners + List oneServer = Lists.newArrayList(servers.get(0)); + loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(oneServer) + .setAttributes(Attributes.EMPTY).build()); + InOrder inOrder = inOrder(mockHelper, mockSubchannel1); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); + inOrder.verify(mockHelper).createSubchannel(createArgsCaptor.capture()); + SubchannelStateListener healthListener = createArgsCaptor.getValue() + .getOption(HEALTH_CONSUMER_LISTENER_ARG_KEY); + inOrder.verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener = stateListenerCaptor.getValue(); + + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + healthListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); + + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); + inOrder.verify(mockHelper).updateBalancingState(eq(READY), any()); // health listener ignored + + healthListener.onSubchannelState(ConnectivityStateInfo.forTransientFailure(Status.INTERNAL)); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any(SubchannelPicker.class)); + } + @Test public void healthCheckFlow() { when(mockSubchannel1.getAttributes()).thenReturn( Attributes.newBuilder().set(HAS_HEALTH_PRODUCER_LISTENER_KEY, true).build()); when(mockSubchannel2.getAttributes()).thenReturn( Attributes.newBuilder().set(HAS_HEALTH_PRODUCER_LISTENER_KEY, true).build()); + List oneServer = Lists.newArrayList(servers.get(0), servers.get(1)); loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder().setAddresses(oneServer) - .setAttributes(Attributes.EMPTY).build()); + .setAttributes(Attributes.newBuilder().set(IS_PETIOLE_POLICY, true).build()).build()); InOrder inOrder = inOrder(mockHelper, mockSubchannel1, mockSubchannel2); inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), any(SubchannelPicker.class)); @@ -417,13 +443,13 @@ public void healthCheckFlow() { // subchannel2 | IDLE | IDLE stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); healthListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); - inOrder.verify(mockHelper, times(0)).updateBalancingState(any(), any()); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); // subchannel | state | health // subchannel1 | READY | CONNECTING // subchannel2 | IDLE | IDLE stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(READY)); - inOrder.verify(mockHelper, times(0)).updateBalancingState(any(), any()); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); // subchannel | state | health // subchannel1 | READY | READY @@ -662,7 +688,7 @@ public void nameResolutionError_emptyAddressList() { } @Test - public void nameResolutionAfterSufficientTFs() { + public void nameResolutionAfterSufficientTFs_multipleEags() { InOrder inOrder = inOrder(mockHelper); acceptXSubchannels(3); Status error = Status.UNAVAILABLE.withDescription("boom!"); @@ -707,6 +733,57 @@ public void nameResolutionAfterSufficientTFs() { inOrder.verify(mockHelper).refreshNameResolution(); } + @Test + public void nameResolutionAfterSufficientTFs_singleEag() { + InOrder inOrder = inOrder(mockHelper); + EquivalentAddressGroup eag = new EquivalentAddressGroup(Arrays.asList( + new FakeSocketAddress("server1"), + new FakeSocketAddress("server2"), + new FakeSocketAddress("server3"))); + loadBalancer.acceptResolvedAddresses( + ResolvedAddresses.newBuilder().setAddresses(Arrays.asList(eag)).build()); + Status error = Status.UNAVAILABLE.withDescription("boom!"); + + // Initial subchannel gets TF, LB is still in CONNECTING + verify(mockSubchannel1).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener1 = stateListenerCaptor.getValue(); + stateListener1.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + assertEquals(Status.OK, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + + // Second subchannel gets TF, no UpdateBalancingState called + verify(mockSubchannel2).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener2 = stateListenerCaptor.getValue(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); + + // Third subchannel gets TF, LB goes into TRANSIENT_FAILURE and does a refreshNameResolution + verify(mockSubchannel3).start(stateListenerCaptor.capture()); + SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); + stateListener3.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + inOrder.verify(mockHelper).refreshNameResolution(); + assertEquals(error, pickerCaptor.getValue().pickSubchannel(mockArgs).getStatus()); + + // Only after we have TFs reported for # of subchannels do we call refreshNameResolution + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).refreshNameResolution(); + + // Now that we have refreshed, the count should have been reset + // Only after we have TFs reported for # of subchannels do we call refreshNameResolution + stateListener1.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + stateListener2.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper, never()).refreshNameResolution(); + stateListener3.onSubchannelState(ConnectivityStateInfo.forTransientFailure(error)); + inOrder.verify(mockHelper).refreshNameResolution(); + } + @Test public void nameResolutionSuccessAfterError() { loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError")); @@ -811,8 +888,7 @@ public void requestConnection() { loadBalancer.requestConnection(); inOrder.verify(mockSubchannel2).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener2 = stateListenerCaptor.getValue(); - int expectedRequests = enableHappyEyeballs ? 1 : 2; - inOrder.verify(mockSubchannel2, times(expectedRequests)).requestConnection(); + inOrder.verify(mockSubchannel2).requestConnection(); stateListener2.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); @@ -820,11 +896,7 @@ public void requestConnection() { loadBalancer.requestConnection(); inOrder.verify(mockHelper, never()).updateBalancingState(any(), any()); inOrder.verify(mockSubchannel1, never()).requestConnection(); - if (enableHappyEyeballs) { - inOrder.verify(mockSubchannel2, never()).requestConnection(); - } else { - inOrder.verify(mockSubchannel2).requestConnection(); - } + inOrder.verify(mockSubchannel2, never()).requestConnection(); } @Test @@ -1079,10 +1151,17 @@ public void updateAddresses_disjoint_ready_twice() { loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newServers).setAttributes(affinity).build()); inOrder.verify(mockSubchannel1).shutdown(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + inOrder.verify(mockSubchannel3, never()).start(stateListenerCaptor.capture()); + + // Trigger connection creation + picker = pickerCaptor.getValue(); + assertEquals(PickResult.withNoResult(), picker.pickSubchannel(mockArgs)); inOrder.verify(mockSubchannel3).start(stateListenerCaptor.capture()); SubchannelStateListener stateListener3 = stateListenerCaptor.getValue(); inOrder.verify(mockSubchannel3).requestConnection(); + stateListener3.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); if (enableHappyEyeballs) { forwardTimeByConnectionDelay(); @@ -1129,17 +1208,19 @@ public void updateAddresses_disjoint_ready_twice() { loadBalancer.acceptResolvedAddresses( ResolvedAddresses.newBuilder().setAddresses(newestServers).setAttributes(affinity).build()); inOrder.verify(mockSubchannel3).shutdown(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); - inOrder.verify(mockSubchannel1n2).start(stateListenerCaptor.capture()); - stateListener = stateListenerCaptor.getValue(); - assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); + inOrder.verify(mockHelper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); + assertEquals(IDLE, loadBalancer.getConcludedConnectivityState()); picker = pickerCaptor.getValue(); // Calling pickSubchannel() twice gave the same result assertEquals(picker.pickSubchannel(mockArgs), picker.pickSubchannel(mockArgs)); // But the picker calls requestConnection() only once + inOrder.verify(mockSubchannel1n2).start(stateListenerCaptor.capture()); + stateListener = stateListenerCaptor.getValue(); inOrder.verify(mockSubchannel1n2).requestConnection(); + stateListener.onSubchannelState(ConnectivityStateInfo.forNonError(CONNECTING)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs)); assertEquals(CONNECTING, loadBalancer.getConcludedConnectivityState()); diff --git a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java index 1c60f82846d..f42dabdd55a 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java @@ -20,9 +20,12 @@ import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; +import android.net.Network; +import android.os.Build; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.DoNotCall; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; @@ -105,6 +108,7 @@ public static CronetChannelBuilder forAddress(String name, int port) { private int trafficStatsTag; private boolean trafficStatsUidSet; private int trafficStatsUid; + private Network network; private CronetChannelBuilder(String host, int port, CronetEngine cronetEngine) { final class CronetChannelTransportFactoryBuilder implements ClientTransportFactoryBuilder { @@ -190,6 +194,13 @@ CronetChannelBuilder setTrafficStatsUid(int uid) { return this; } + /** Sets the network ID to use for this channel traffic. */ + @CanIgnoreReturnValue + CronetChannelBuilder bindToNetwork(@Nullable Network network) { + this.network = network; + return this; + } + /** * Provides a custom scheduled executor service. * @@ -210,7 +221,12 @@ public CronetChannelBuilder scheduledExecutorService( ClientTransportFactory buildTransportFactory() { return new CronetTransportFactory( new TaggingStreamFactory( - cronetEngine, trafficStatsTagSet, trafficStatsTag, trafficStatsUidSet, trafficStatsUid), + cronetEngine, + trafficStatsTagSet, + trafficStatsTag, + trafficStatsUidSet, + trafficStatsUid, + network), MoreExecutors.directExecutor(), scheduledExecutorService, maxMessageSize, @@ -294,18 +310,21 @@ private static class TaggingStreamFactory extends StreamBuilderFactory { private final int trafficStatsTag; private final boolean trafficStatsUidSet; private final int trafficStatsUid; + private final Network network; TaggingStreamFactory( CronetEngine cronetEngine, boolean trafficStatsTagSet, int trafficStatsTag, boolean trafficStatsUidSet, - int trafficStatsUid) { + int trafficStatsUid, + Network network) { this.cronetEngine = cronetEngine; this.trafficStatsTagSet = trafficStatsTagSet; this.trafficStatsTag = trafficStatsTag; this.trafficStatsUidSet = trafficStatsUidSet; this.trafficStatsUid = trafficStatsUid; + this.network = network; } @Override @@ -320,6 +339,11 @@ public BidirectionalStream.Builder newBidirectionalStreamBuilder( if (trafficStatsUidSet) { builder.setTrafficStatsUid(trafficStatsUid); } + if (network != null) { + if (Build.VERSION.SDK_INT >= 23) { + builder.bindToNetwork(network.getNetworkHandle()); + } + } return builder; } } diff --git a/cronet/src/main/java/io/grpc/cronet/InternalCronetChannelBuilder.java b/cronet/src/main/java/io/grpc/cronet/InternalCronetChannelBuilder.java index 2954f1eee81..7e5e610ca67 100644 --- a/cronet/src/main/java/io/grpc/cronet/InternalCronetChannelBuilder.java +++ b/cronet/src/main/java/io/grpc/cronet/InternalCronetChannelBuilder.java @@ -16,7 +16,9 @@ package io.grpc.cronet; +import android.net.Network; import io.grpc.Internal; +import org.checkerframework.checker.nullness.qual.Nullable; /** * Internal {@link CronetChannelBuilder} accessor. This is intended for usage internal to the gRPC @@ -58,4 +60,9 @@ public static void setTrafficStatsTag(CronetChannelBuilder builder, int tag) { public static void setTrafficStatsUid(CronetChannelBuilder builder, int uid) { builder.setTrafficStatsUid(uid); } + + /** Sets the network {@link android.net.Network} to use when relying traffic by this channel. */ + public static void bindToNetwork(CronetChannelBuilder builder, @Nullable Network network) { + builder.bindToNetwork(network); + } } diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 0ca032fb0e4..6b5b966e7f6 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -34,7 +34,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +54,12 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' testImplementation 'junit:junit:4.13.2' testImplementation 'com.google.truth:truth:1.1.5' - testImplementation 'io.grpc:grpc-testing:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 0f1e8b4047b..4edbcb14612 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index c33135233ea..4a08f40e4ee 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index e8e2e8cac29..9f41994e3c2 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,8 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/build.gradle b/examples/build.gradle index 076e0c4a25b..c10b4eef46a 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.3' def protocVersion = protobufVersion diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index 3c998586bb6..0d7d959de93 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -24,7 +24,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' dependencies { diff --git a/examples/example-debug/build.gradle b/examples/example-debug/build.gradle index ca151a13c1a..5565747cb19 100644 --- a/examples/example-debug/build.gradle +++ b/examples/example-debug/build.gradle @@ -25,7 +25,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.3' dependencies { diff --git a/examples/example-debug/pom.xml b/examples/example-debug/pom.xml index 10ccf834d86..064d989c04c 100644 --- a/examples/example-debug/pom.xml +++ b/examples/example-debug/pom.xml @@ -6,13 +6,13 @@ jar - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT example-debug https://github.com/grpc/grpc-java UTF-8 - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT 3.25.3 1.8 @@ -98,7 +98,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-dualstack/README.md b/examples/example-dualstack/README.md new file mode 100644 index 00000000000..6c191661d1b --- /dev/null +++ b/examples/example-dualstack/README.md @@ -0,0 +1,54 @@ +# gRPC Dualstack Example + +The dualstack example uses a custom name resolver that provides both IPv4 and IPv6 localhost +endpoints for each of 3 server instances. The client will first use the default name resolver and +load balancers which will only connect tot he first server. It will then use the +custom name resolver with round robin to connect to each of the servers in turn. The 3 instances +of the server will bind respectively to: both IPv4 and IPv6, IPv4 only, and IPv6 only. + +The example requires grpc-java to already be built. You are strongly encouraged +to check out a git release tag, since there will already be a build of grpc +available. Otherwise, you must follow [COMPILING](../../COMPILING.md). + +### Build the example + +To build the dualstack example server and client. From the + `grpc-java/examples/example-dualstack` directory run: + +```bash +$ ../gradlew installDist +``` + +This creates the scripts +`build/install/example-dualstack/bin/dual-stack-server` + and `build/install/example-dualstack/bin/dual-stack-client`. + +To run the dualstack example, run the server with: + +```bash +$ ./build/install/example-dualstack/bin/dual-stack-server +``` + +And in a different terminal window run the client. + +```bash +$ ./build/install/example-dualstack/bin/dual-stack-client +``` + +### Maven + +If you prefer to use Maven: + +Run in the example-debug directory: + +```bash +$ mvn verify +$ # Run the server in one terminal +$ mvn exec:java -Dexec.mainClass=io.grpc.examples.dualstack.DualStackServer +``` + +```bash +$ # In another terminal run the client +$ mvn exec:java -Dexec.mainClass=io.grpc.examples.dualstack.DualStackClient +``` + diff --git a/examples/example-dualstack/build.gradle b/examples/example-dualstack/build.gradle new file mode 100644 index 00000000000..554b5f758d9 --- /dev/null +++ b/examples/example-dualstack/build.gradle @@ -0,0 +1,79 @@ +plugins { + id 'application' // Provide convenience executables for trying out the examples. + id 'java' + + id "com.google.protobuf" version "0.9.4" + + // Generate IntelliJ IDEA's .idea & .iml project files + id 'idea' +} + +repositories { + maven { // The google mirror is less flaky than mavenCentral() + url "https://maven-central.storage-download.googleapis.com/maven2/" } + mavenCentral() + mavenLocal() +} + +java { + sourceCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_1_8 +} + +// IMPORTANT: You probably want the non-SNAPSHOT version of gRPC. Make sure you +// are looking at a tagged version of the example and not "master"! + +// Feel free to delete the comment at the next line. It is just for safely +// updating the version in our release process. +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def protobufVersion = '3.25.3' + +dependencies { + implementation "io.grpc:grpc-protobuf:${grpcVersion}" + implementation "io.grpc:grpc-netty:${grpcVersion}" + implementation "io.grpc:grpc-stub:${grpcVersion}" + implementation "io.grpc:grpc-services:${grpcVersion}" + compileOnly "org.apache.tomcat:annotations-api:6.0.53" +} + +protobuf { + protoc { + artifact = "com.google.protobuf:protoc:${protobufVersion}" + } + plugins { + grpc { + artifact = "io.grpc:protoc-gen-grpc-java:${grpcVersion}" + } + } + generateProtoTasks { + all()*.plugins { + grpc {} + } + } +} + +startScripts.enabled = false + +task DualStackClient(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.dualstack.DualStackClient' + applicationName = 'dual-stack-client' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +task DualStackServer(type: CreateStartScripts) { + mainClass = 'io.grpc.examples.dualstack.DualStackServer' + applicationName = 'dual-stack-server' + outputDir = new File(project.buildDir, 'tmp/scripts/' + name) + classpath = startScripts.classpath +} + +application { + applicationDistribution.into('bin') { + from(DualStackClient) + from(DualStackServer) + filePermissions { + unix(0755) + } + } +} diff --git a/examples/example-dualstack/pom.xml b/examples/example-dualstack/pom.xml new file mode 100644 index 00000000000..dfd650cdfa4 --- /dev/null +++ b/examples/example-dualstack/pom.xml @@ -0,0 +1,122 @@ + + 4.0.0 + io.grpc + example-dualstack + jar + + 1.68.0-SNAPSHOT + example-dualstack + https://github.com/grpc/grpc-java + + + UTF-8 + 1.68.0-SNAPSHOT + 3.25.3 + + 1.8 + 1.8 + + + + + + io.grpc + grpc-bom + ${grpc.version} + pom + import + + + + + + + io.grpc + grpc-services + + + io.grpc + grpc-protobuf + + + io.grpc + grpc-stub + + + io.grpc + grpc-netty + + + org.apache.tomcat + annotations-api + 6.0.53 + provided + + + io.grpc + grpc-netty-shaded + runtime + + + junit + junit + 4.13.2 + test + + + io.grpc + grpc-testing + test + + + + + + + kr.motd.maven + os-maven-plugin + 1.7.1 + + + + + org.xolstice.maven.plugins + protobuf-maven-plugin + 0.6.1 + + com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} + grpc-java + io.grpc:protoc-gen-grpc-java:${grpc.version}:exe:${os.detected.classifier} + + + + + compile + compile-custom + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + 3.5.0 + + + enforce + + enforce + + + + + + + + + + + + diff --git a/examples/example-dualstack/settings.gradle b/examples/example-dualstack/settings.gradle new file mode 100644 index 00000000000..0aae8f7304e --- /dev/null +++ b/examples/example-dualstack/settings.gradle @@ -0,0 +1,10 @@ +pluginManagement { + repositories { + maven { // The google mirror is less flaky than mavenCentral() + url "https://maven-central.storage-download.googleapis.com/maven2/" + } + gradlePluginPortal() + } +} + +rootProject.name = 'example-dualstack' diff --git a/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackClient.java b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackClient.java new file mode 100644 index 00000000000..b9993a524d6 --- /dev/null +++ b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackClient.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.dualstack; + +import io.grpc.Channel; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.NameResolverRegistry; +import io.grpc.StatusRuntimeException; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * A client that requests greetings from the {@link DualStackServer}. + * First it sends 5 requests using the default nameresolver and load balancer. + * Then it sends 10 requests using the example nameresolver and round robin load balancer. These + * requests are evenly distributed among the 3 servers rather than favoring the server listening + * on both addresses because the ExampleDualStackNameResolver groups the 3 servers as 3 endpoints + * each with 2 addresses. + */ +public class DualStackClient { + public static final String channelTarget = "example:///lb.example.grpc.io"; + private static final Logger logger = Logger.getLogger(DualStackClient.class.getName()); + private final GreeterGrpc.GreeterBlockingStub blockingStub; + + public DualStackClient(Channel channel) { + blockingStub = GreeterGrpc.newBlockingStub(channel); + } + + public static void main(String[] args) throws Exception { + NameResolverRegistry.getDefaultRegistry() + .register(new ExampleDualStackNameResolverProvider()); + + logger.info("\n **** Use default DNS resolver ****"); + ManagedChannel channel = ManagedChannelBuilder.forTarget("localhost:50051") + .usePlaintext() + .build(); + try { + DualStackClient client = new DualStackClient(channel); + for (int i = 0; i < 5; i++) { + client.greet("request:" + i); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + + logger.info("\n **** Change to use example name resolver ****"); + /* + Dial to "example:///resolver.example.grpc.io", use {@link ExampleNameResolver} to create connection + "resolver.example.grpc.io" is converted to {@link java.net.URI.path} + */ + channel = ManagedChannelBuilder.forTarget(channelTarget) + .defaultLoadBalancingPolicy("round_robin") + .usePlaintext() + .build(); + try { + DualStackClient client = new DualStackClient(channel); + for (int i = 0; i < 10; i++) { + client.greet("request:" + i); + } + } finally { + channel.shutdownNow().awaitTermination(5, TimeUnit.SECONDS); + } + } + + public void greet(String name) { + HelloRequest request = HelloRequest.newBuilder().setName(name).build(); + HelloReply response; + try { + response = blockingStub.sayHello(request); + } catch (StatusRuntimeException e) { + logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus()); + return; + } + logger.info("Greeting: " + response.getMessage()); + } +} diff --git a/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackServer.java b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackServer.java new file mode 100644 index 00000000000..43b21e963f8 --- /dev/null +++ b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/DualStackServer.java @@ -0,0 +1,126 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.dualstack; + +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.examples.helloworld.GreeterGrpc; +import io.grpc.examples.helloworld.HelloReply; +import io.grpc.examples.helloworld.HelloRequest; +import io.grpc.netty.NettyServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; + +/** + * Starts 3 different greeter services each on its own port, but all for localhost. + * The first service listens on both IPv4 and IPv6, + * the second on just IPv4, and the third on just IPv6. + */ +public class DualStackServer { + private static final Logger logger = Logger.getLogger(DualStackServer.class.getName()); + private List servers; + + public static void main(String[] args) throws IOException, InterruptedException { + final DualStackServer server = new DualStackServer(); + server.start(); + server.blockUntilShutdown(); + } + + private void start() throws IOException { + InetSocketAddress inetSocketAddress; + + servers = new ArrayList<>(); + int[] serverPorts = ExampleDualStackNameResolver.SERVER_PORTS; + for (int i = 0; i < serverPorts.length; i++ ) { + String addressType; + int port = serverPorts[i]; + ServerBuilder serverBuilder; + switch (i) { + case 0: + serverBuilder = ServerBuilder.forPort(port); // bind to both IPv4 and IPv6 + addressType = "both IPv4 and IPv6"; + break; + case 1: + // bind to IPv4 only + inetSocketAddress = new InetSocketAddress("127.0.0.1", port); + serverBuilder = NettyServerBuilder.forAddress(inetSocketAddress); + addressType = "IPv4 only"; + break; + case 2: + // bind to IPv6 only + inetSocketAddress = new InetSocketAddress("::1", port); + serverBuilder = NettyServerBuilder.forAddress(inetSocketAddress); + addressType = "IPv6 only"; + break; + default: + throw new IllegalStateException("Unexpected value: " + i); + } + + servers.add(serverBuilder + .addService(new GreeterImpl(port, addressType)) + .build() + .start()); + logger.info("Server started, listening on " + port); + } + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + System.err.println("*** shutting down gRPC server since JVM is shutting down"); + try { + DualStackServer.this.stop(); + } catch (InterruptedException e) { + e.printStackTrace(System.err); + } + System.err.println("*** server shut down"); + })); + } + + private void stop() throws InterruptedException { + for (Server server : servers) { + server.shutdown().awaitTermination(30, TimeUnit.SECONDS); + } + } + + private void blockUntilShutdown() throws InterruptedException { + for (Server server : servers) { + server.awaitTermination(); + } + } + + static class GreeterImpl extends GreeterGrpc.GreeterImplBase { + + int port; + String addressType; + + public GreeterImpl(int port, String addressType) { + this.port = port; + this.addressType = addressType; + } + + @Override + public void sayHello(HelloRequest req, StreamObserver responseObserver) { + String msg = String.format("Hello %s from server<%d> type: %s", + req.getName(), this.port, addressType); + HelloReply reply = HelloReply.newBuilder().setMessage(msg).build(); + responseObserver.onNext(reply); + responseObserver.onCompleted(); + } + } +} diff --git a/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolver.java b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolver.java new file mode 100644 index 00000000000..70675b3de3d --- /dev/null +++ b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolver.java @@ -0,0 +1,98 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.dualstack; + +import com.google.common.collect.ImmutableMap; +import io.grpc.EquivalentAddressGroup; +import io.grpc.NameResolver; +import io.grpc.Status; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** + * A fake name resolver that resolves to a hard-coded list of 3 endpoints (EquivalentAddressGropu) + * each with 2 addresses (one IPv4 and one IPv6). + */ +public class ExampleDualStackNameResolver extends NameResolver { + static public final int[] SERVER_PORTS = {50051, 50052, 50053}; + + // This is a fake name resolver, so we just hard code the address here. + private static final ImmutableMap>> addrStore = + ImmutableMap.>>builder() + .put("lb.example.grpc.io", + Arrays.stream(SERVER_PORTS) + .mapToObj(port -> getLocalAddrs(port)) + .collect(Collectors.toList()) + ) + .build(); + + private Listener2 listener; + + private final URI uri; + + public ExampleDualStackNameResolver(URI targetUri) { + this.uri = targetUri; + } + + private static List getLocalAddrs(int port) { + return Arrays.asList( + new InetSocketAddress("127.0.0.1", port), + new InetSocketAddress("::1", port)); + } + + @Override + public String getServiceAuthority() { + return uri.getPath().substring(1); + } + + @Override + public void shutdown() { + } + + @Override + public void start(Listener2 listener) { + this.listener = listener; + this.resolve(); + } + + @Override + public void refresh() { + this.resolve(); + } + + private void resolve() { + List> addresses = addrStore.get(uri.getPath().substring(1)); + try { + List eagList = new ArrayList<>(); + for (List endpoint : addresses) { + // every server is an EquivalentAddressGroup, so they can be accessed randomly + eagList.add(new EquivalentAddressGroup(endpoint)); + } + + this.listener.onResult(ResolutionResult.newBuilder().setAddresses(eagList).build()); + } catch (Exception e){ + // when error occurs, notify listener + this.listener.onError(Status.UNAVAILABLE.withDescription("Unable to resolve host ").withCause(e)); + } + } + +} diff --git a/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolverProvider.java b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolverProvider.java new file mode 100644 index 00000000000..a01d68aca3e --- /dev/null +++ b/examples/example-dualstack/src/main/java/io/grpc/examples/dualstack/ExampleDualStackNameResolverProvider.java @@ -0,0 +1,47 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.examples.dualstack; + +import io.grpc.NameResolver; +import io.grpc.NameResolverProvider; + +import java.net.URI; + +public class ExampleDualStackNameResolverProvider extends NameResolverProvider { + public static final String exampleScheme = "example"; + + @Override + public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { + return new ExampleDualStackNameResolver(targetUri); + } + + @Override + protected boolean isAvailable() { + return true; + } + + @Override + protected int priority() { + return 5; + } + + @Override + // gRPC choose the first NameResolverProvider that supports the target URI scheme. + public String getDefaultScheme() { + return exampleScheme; + } +} diff --git a/examples/example-dualstack/src/main/proto/helloworld/helloworld.proto b/examples/example-dualstack/src/main/proto/helloworld/helloworld.proto new file mode 100644 index 00000000000..c60d9416f1f --- /dev/null +++ b/examples/example-dualstack/src/main/proto/helloworld/helloworld.proto @@ -0,0 +1,37 @@ +// Copyright 2015 The gRPC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +option java_multiple_files = true; +option java_package = "io.grpc.examples.helloworld"; +option java_outer_classname = "HelloWorldProto"; +option objc_class_prefix = "HLW"; + +package helloworld; + +// The greeting service definition. +service Greeter { + // Sends a greeting + rpc SayHello (HelloRequest) returns (HelloReply) {} +} + +// The request message containing the user's name. +message HelloRequest { + string name = 1; +} + +// The response message containing the greetings +message HelloReply { + string message = 1; +} diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 40e72afad82..47e812fde15 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -24,7 +24,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.3' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index 1e58e21e975..d2cba1a7959 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,13 +6,13 @@ jar - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT 3.25.3 1.8 @@ -96,7 +96,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-gcp-csm-observability/build.gradle b/examples/example-gcp-csm-observability/build.gradle index 5de2b1995e2..a392018ba25 100644 --- a/examples/example-gcp-csm-observability/build.gradle +++ b/examples/example-gcp-csm-observability/build.gradle @@ -25,7 +25,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' def openTelemetryVersion = '1.40.0' def openTelemetryPrometheusVersion = '1.40.0-alpha' diff --git a/examples/example-gcp-observability/build.gradle b/examples/example-gcp-observability/build.gradle index 0462c987f52..dcb8d420020 100644 --- a/examples/example-gcp-observability/build.gradle +++ b/examples/example-gcp-observability/build.gradle @@ -25,7 +25,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' dependencies { diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index ab45ee2dc5b..df8b0fde121 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.3' dependencies { diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index 19b5f8b3c20..c6d39887bac 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,13 +6,13 @@ jar - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT 3.25.3 1.8 @@ -98,7 +98,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index 6fdd4498c7d..f996282bbb0 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.3' def protocVersion = protobufVersion diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index ad530e33aa7..c84f9893980 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,13 +7,13 @@ jar - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT 3.25.3 3.25.3 @@ -94,7 +94,7 @@ org.xolstice.maven.plugins protobuf-maven-plugin - 0.5.1 + 0.6.1 com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} @@ -116,7 +116,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-oauth/build.gradle b/examples/example-oauth/build.gradle index 255633b4f9f..7f600c2bc53 100644 --- a/examples/example-oauth/build.gradle +++ b/examples/example-oauth/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.3' def protocVersion = protobufVersion diff --git a/examples/example-oauth/pom.xml b/examples/example-oauth/pom.xml index 2c38a05b3e4..fa2eaa41e36 100644 --- a/examples/example-oauth/pom.xml +++ b/examples/example-oauth/pom.xml @@ -7,13 +7,13 @@ jar - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT example-oauth https://github.com/grpc/grpc-java UTF-8 - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT 3.25.3 3.25.3 @@ -99,7 +99,7 @@ org.xolstice.maven.plugins protobuf-maven-plugin - 0.5.1 + 0.6.1 com.google.protobuf:protoc:${protoc.version}:exe:${os.detected.classifier} @@ -121,7 +121,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-opentelemetry/build.gradle b/examples/example-opentelemetry/build.gradle index 00f7dc101bf..21264ffcc17 100644 --- a/examples/example-opentelemetry/build.gradle +++ b/examples/example-opentelemetry/build.gradle @@ -24,7 +24,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' def openTelemetryVersion = '1.40.0' def openTelemetryPrometheusVersion = '1.40.0-alpha' diff --git a/examples/example-orca/build.gradle b/examples/example-orca/build.gradle index 22feb8cae42..d087a532aff 100644 --- a/examples/example-orca/build.gradle +++ b/examples/example-orca/build.gradle @@ -18,7 +18,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' dependencies { diff --git a/examples/example-reflection/build.gradle b/examples/example-reflection/build.gradle index 78821391911..d7d5c50b7e6 100644 --- a/examples/example-reflection/build.gradle +++ b/examples/example-reflection/build.gradle @@ -18,7 +18,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' dependencies { diff --git a/examples/example-servlet/build.gradle b/examples/example-servlet/build.gradle index 9542ba0277f..995e2d0979b 100644 --- a/examples/example-servlet/build.gradle +++ b/examples/example-servlet/build.gradle @@ -16,7 +16,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' dependencies { diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index 94257af4758..8aad6b62bcb 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -24,7 +24,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' dependencies { diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index bc9c0a7a8ee..e1d569a628c 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,13 +6,13 @@ jar - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT example-tls https://github.com/grpc/grpc-java UTF-8 - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT 3.25.3 1.8 @@ -82,7 +82,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index 2554adb0033..8339db77e0c 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.67.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.68.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.3' dependencies { diff --git a/examples/pom.xml b/examples/pom.xml index 2b25d13b50c..247df4a73ce 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,13 +6,13 @@ jar - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT examples https://github.com/grpc/grpc-java UTF-8 - 1.67.0-SNAPSHOT + 1.68.0-SNAPSHOT 3.25.3 3.25.3 @@ -55,6 +55,11 @@ protobuf-java-util ${protobuf.version} + + com.google.j2objc + j2objc-annotations + 3.0.0 + org.apache.tomcat annotations-api @@ -110,7 +115,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.4.1 + 3.5.0 enforce diff --git a/examples/src/main/java/io/grpc/examples/keepalive/KeepAliveClient.java b/examples/src/main/java/io/grpc/examples/keepalive/KeepAliveClient.java index a7c59c3952f..414d92dea4c 100644 --- a/examples/src/main/java/io/grpc/examples/keepalive/KeepAliveClient.java +++ b/examples/src/main/java/io/grpc/examples/keepalive/KeepAliveClient.java @@ -78,7 +78,6 @@ public static void main(String[] args) throws Exception { // frames. // More details see: https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md ManagedChannel channel = Grpc.newChannelBuilder(target, InsecureChannelCredentials.create()) - .keepAliveTime(5, TimeUnit.MINUTES) .keepAliveTime(10, TimeUnit.SECONDS) // Change to a larger value, e.g. 5min. .keepAliveTimeout(1, TimeUnit.SECONDS) // Change to a larger value, e.g. 10s. .keepAliveWithoutCalls(true)// You should normally avoid enabling this. diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java b/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java index f562f0ac107..6ef327ade84 100644 --- a/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java +++ b/examples/src/main/java/io/grpc/examples/loadbalance/ExampleNameResolver.java @@ -28,12 +28,12 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import java.util.stream.Stream; import static io.grpc.examples.loadbalance.LoadBalanceClient.exampleServiceName; public class ExampleNameResolver extends NameResolver { + static private final int[] SERVER_PORTS = {50051, 50052, 50053}; private Listener2 listener; private final URI uri; @@ -44,12 +44,11 @@ public ExampleNameResolver(URI targetUri) { this.uri = targetUri; // This is a fake name resolver, so we just hard code the address here. addrStore = ImmutableMap.>builder() - .put(exampleServiceName, - Stream.iterate(LoadBalanceServer.startPort,p->p+1) - .limit(LoadBalanceServer.serverCount) - .map(port->new InetSocketAddress("localhost",port)) - .collect(Collectors.toList()) - ) + .put(exampleServiceName, + Arrays.stream(SERVER_PORTS) + .mapToObj(port->new InetSocketAddress("localhost",port)) + .collect(Collectors.toList()) + ) .build(); } diff --git a/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java b/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java index c97d209497a..85ae92a537a 100644 --- a/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java +++ b/examples/src/main/java/io/grpc/examples/loadbalance/LoadBalanceServer.java @@ -24,23 +24,24 @@ import io.grpc.stub.StreamObserver; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; public class LoadBalanceServer { private static final Logger logger = Logger.getLogger(LoadBalanceServer.class.getName()); - static public final int serverCount = 3; - static public final int startPort = 50051; - private Server[] servers; + static final int[] SERVER_PORTS = {50051, 50052, 50053}; + private List servers; private void start() throws IOException { - servers = new Server[serverCount]; - for (int i = 0; i < serverCount; i++) { - int port = startPort + i; - servers[i] = ServerBuilder.forPort(port) + servers = new ArrayList<>(); + for (int port : SERVER_PORTS) { + servers.add( + ServerBuilder.forPort(port) .addService(new GreeterImpl(port)) .build() - .start(); + .start()); logger.info("Server started, listening on " + port); } Runtime.getRuntime().addShutdownHook(new Thread(() -> { @@ -55,18 +56,14 @@ private void start() throws IOException { } private void stop() throws InterruptedException { - for (int i = 0; i < serverCount; i++) { - if (servers[i] != null) { - servers[i].shutdown().awaitTermination(30, TimeUnit.SECONDS); - } + for (Server server : servers) { + server.shutdown().awaitTermination(30, TimeUnit.SECONDS); } } private void blockUntilShutdown() throws InterruptedException { - for (int i = 0; i < serverCount; i++) { - if (servers[i] != null) { - servers[i].awaitTermination(); - } + for (Server server : servers) { + server.awaitTermination(); } } @@ -86,7 +83,8 @@ public GreeterImpl(int port) { @Override public void sayHello(HelloRequest req, StreamObserver responseObserver) { - HelloReply reply = HelloReply.newBuilder().setMessage("Hello " + req.getName() + " from server<" + this.port + ">").build(); + HelloReply reply = HelloReply.newBuilder() + .setMessage("Hello " + req.getName() + " from server<" + this.port + ">").build(); responseObserver.onNext(reply); responseObserver.onCompleted(); } diff --git a/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java b/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java index ac6fdd32549..9aaccbe1096 100644 --- a/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java +++ b/examples/src/main/java/io/grpc/examples/nameresolve/NameResolveClient.java @@ -26,8 +26,7 @@ import java.util.logging.Logger; public class NameResolveClient { - public static final String exampleScheme = "example"; - public static final String exampleServiceName = "lb.example.grpc.io"; + public static final String channelTarget = "example:///lb.example.grpc.io"; private static final Logger logger = Logger.getLogger(NameResolveClient.class.getName()); private final GreeterGrpc.GreeterBlockingStub blockingStub; @@ -56,11 +55,10 @@ public static void main(String[] args) throws Exception { Dial to "example:///resolver.example.grpc.io", use {@link ExampleNameResolver} to create connection "resolver.example.grpc.io" is converted to {@link java.net.URI.path} */ - channel = ManagedChannelBuilder.forTarget( - String.format("%s:///%s", exampleScheme, exampleServiceName)) - .defaultLoadBalancingPolicy("round_robin") - .usePlaintext() - .build(); + channel = ManagedChannelBuilder.forTarget(channelTarget) + .defaultLoadBalancingPolicy("round_robin") + .usePlaintext() + .build(); try { NameResolveClient client = new NameResolveClient(channel); for (int i = 0; i < 5; i++) { diff --git a/gcp-observability/build.gradle b/gcp-observability/build.gradle index 0de7f8363bc..f869bd61a76 100644 --- a/gcp-observability/build.gradle +++ b/gcp-observability/build.gradle @@ -65,8 +65,7 @@ dependencies { libraries.auto.value.annotations, // Use our newer version libraries.guava.jre, // Use our newer version libraries.protobuf.java.util, // Use our newer version - libraries.re2j, // Use our newer version - libraries.j2objc.annotations // Explicit dependency to keep in step with version used by guava + libraries.re2j // Use our newer version testImplementation testFixtures(project(':grpc-api')), project(':grpc-testing'), diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 299ca60ab4b..488ead9ad86 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -42,7 +42,6 @@ guava-testlib = "com.google.guava:guava-testlib:33.2.1-android" # May be different from the -android version. guava-jre = "com.google.guava:guava:33.2.1-jre" hdrhistogram = "org.hdrhistogram:HdrHistogram:2.2.2" -j2objc-annotations = " com.google.j2objc:j2objc-annotations:3.0.0" jakarta-servlet-api = "jakarta.servlet:jakarta.servlet-api:5.0.0" javax-annotation = "org.apache.tomcat:annotations-api:6.0.53" javax-servlet-api = "javax.servlet:javax.servlet-api:4.0.1" diff --git a/grpclb/BUILD.bazel b/grpclb/BUILD.bazel index 517155bbfc1..2dd24bb52a2 100644 --- a/grpclb/BUILD.bazel +++ b/grpclb/BUILD.bazel @@ -21,7 +21,6 @@ java_library( "@io_grpc_grpc_proto//:grpclb_load_balancer_java_proto", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), ], ) diff --git a/grpclb/build.gradle b/grpclb/build.gradle index cea599828f5..93331053b09 100644 --- a/grpclb/build.gradle +++ b/grpclb/build.gradle @@ -19,9 +19,9 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-stub'), + libraries.guava, libraries.protobuf.java, - libraries.protobuf.java.util, - libraries.guava + libraries.protobuf.java.util runtimeOnly libraries.errorprone.annotations compileOnly libraries.javax.annotation testImplementation libraries.truth, diff --git a/inprocess/BUILD.bazel b/inprocess/BUILD.bazel index aa614df654c..bef38612713 100644 --- a/inprocess/BUILD.bazel +++ b/inprocess/BUILD.bazel @@ -13,6 +13,5 @@ java_library( artifact("com.google.code.findbugs:jsr305"), artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), ], ) diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java index 3efd576abe6..a581c750028 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java @@ -21,18 +21,12 @@ import static io.grpc.stub.ClientCalls.blockingServerStreamingCall; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import com.google.auth.oauth2.AccessToken; -import com.google.auth.oauth2.ComputeEngineCredentials; -import com.google.auth.oauth2.GoogleCredentials; -import com.google.auth.oauth2.OAuth2Credentials; -import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableList; @@ -45,7 +39,6 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; -import io.grpc.ClientInterceptors; import io.grpc.ClientStreamTracer; import io.grpc.Context; import io.grpc.Grpc; @@ -62,7 +55,6 @@ import io.grpc.ServerStreamTracer; import io.grpc.Status; import io.grpc.StatusRuntimeException; -import io.grpc.auth.MoreCallCredentials; import io.grpc.census.InternalCensusStatsAccessor; import io.grpc.census.internal.DeprecatedCensusConstants; import io.grpc.internal.GrpcUtil; @@ -77,7 +69,6 @@ import io.grpc.internal.testing.TestServerStreamTracer; import io.grpc.internal.testing.TestStreamTracer; import io.grpc.stub.ClientCallStreamObserver; -import io.grpc.stub.ClientCalls; import io.grpc.stub.MetadataUtils; import io.grpc.stub.StreamObserver; import io.grpc.testing.TestUtils; @@ -92,7 +83,6 @@ import io.grpc.testing.integration.Messages.StreamingInputCallResponse; import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; -import io.grpc.testing.integration.Messages.TestOrcaReport; import io.opencensus.contrib.grpc.metrics.RpcMeasureConstants; import io.opencensus.stats.Measure; import io.opencensus.stats.Measure.MeasureDouble; @@ -118,7 +108,6 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.Executors; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledExecutorService; @@ -191,11 +180,6 @@ public abstract class AbstractInteropTest { private final LinkedBlockingQueue serverStreamTracers = new LinkedBlockingQueue<>(); - static final CallOptions.Key> - ORCA_RPC_REPORT_KEY = CallOptions.Key.create("orca-rpc-report"); - static final CallOptions.Key> - ORCA_OOB_REPORT_KEY = CallOptions.Key.create("orca-oob-report"); - private static final class ServerStreamTracerInfo { final String fullMethodName; final InteropServerStreamTracer tracer; @@ -451,47 +435,6 @@ public void emptyUnaryWithRetriableStream() throws Exception { assertEquals(EMPTY, TestServiceGrpc.newBlockingStub(channel).emptyCall(EMPTY)); } - /** Sends a cacheable unary rpc using GET. Requires that the server is behind a caching proxy. */ - public void cacheableUnary() { - // THIS TEST IS BROKEN. Enabling safe just on the MethodDescriptor does nothing by itself. This - // test would need to enable GET on the channel. - // Set safe to true. - MethodDescriptor safeCacheableUnaryCallMethod = - TestServiceGrpc.getCacheableUnaryCallMethod().toBuilder().setSafe(true).build(); - // Set fake user IP since some proxies (GFE) won't cache requests from localhost. - Metadata.Key userIpKey = Metadata.Key.of("x-user-ip", Metadata.ASCII_STRING_MARSHALLER); - Metadata metadata = new Metadata(); - metadata.put(userIpKey, "1.2.3.4"); - Channel channelWithUserIpKey = - ClientInterceptors.intercept(channel, MetadataUtils.newAttachHeadersInterceptor(metadata)); - SimpleRequest requests1And2 = - SimpleRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFromUtf8(String.valueOf(System.nanoTime())))) - .build(); - SimpleRequest request3 = - SimpleRequest.newBuilder() - .setPayload( - Payload.newBuilder() - .setBody(ByteString.copyFromUtf8(String.valueOf(System.nanoTime())))) - .build(); - - SimpleResponse response1 = - ClientCalls.blockingUnaryCall( - channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, requests1And2); - SimpleResponse response2 = - ClientCalls.blockingUnaryCall( - channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, requests1And2); - SimpleResponse response3 = - ClientCalls.blockingUnaryCall( - channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, request3); - - assertEquals(response1, response2); - assertNotEquals(response1, response3); - // THIS TEST IS BROKEN. See comment at start of method. - } - @Test public void largeUnary() throws Exception { assumeEnoughMemory(); @@ -603,26 +546,6 @@ public void serverCompressedUnary() throws Exception { Collections.singleton(goldenResponse)); } - /** - * Assuming "pick_first" policy is used, tests that all requests are sent to the same server. - */ - public void pickFirstUnary() throws Exception { - SimpleRequest request = SimpleRequest.newBuilder() - .setResponseSize(1) - .setFillServerId(true) - .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[1]))) - .build(); - - SimpleResponse firstResponse = blockingStub.unaryCall(request); - // Increase the chance of all servers are connected, in case the channel should be doing - // round_robin instead. - Thread.sleep(5000); - for (int i = 0; i < 100; i++) { - SimpleResponse response = blockingStub.unaryCall(request); - assertThat(response.getServerId()).isEqualTo(firstResponse.getServerId()); - } - } - @Test public void serverStreaming() throws Exception { final StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder() @@ -1757,247 +1680,6 @@ public void getServerAddressAndLocalAddressFromClient() { assertNotNull(obtainLocalClientAddr()); } - /** - * Test backend metrics per query reporting: expect the test client LB policy to receive load - * reports. - */ - public void testOrcaPerRpc() throws Exception { - AtomicReference reportHolder = new AtomicReference<>(); - TestOrcaReport answer = TestOrcaReport.newBuilder() - .setCpuUtilization(0.8210) - .setMemoryUtilization(0.5847) - .putRequestCost("cost", 3456.32) - .putUtilization("util", 0.30499) - .build(); - blockingStub.withOption(ORCA_RPC_REPORT_KEY, reportHolder).unaryCall( - SimpleRequest.newBuilder().setOrcaPerQueryReport(answer).build()); - assertThat(reportHolder.get()).isEqualTo(answer); - } - - /** - * Test backend metrics OOB reporting: expect the test client LB policy to receive load reports. - */ - public void testOrcaOob() throws Exception { - AtomicReference reportHolder = new AtomicReference<>(); - final TestOrcaReport answer = TestOrcaReport.newBuilder() - .setCpuUtilization(0.8210) - .setMemoryUtilization(0.5847) - .putUtilization("util", 0.30499) - .build(); - final TestOrcaReport answer2 = TestOrcaReport.newBuilder() - .setCpuUtilization(0.29309) - .setMemoryUtilization(0.2) - .putUtilization("util", 0.2039) - .build(); - - final int retryLimit = 5; - BlockingQueue queue = new LinkedBlockingQueue<>(); - final Object lastItem = new Object(); - StreamObserver streamObserver = - asyncStub.fullDuplexCall(new StreamObserver() { - - @Override - public void onNext(StreamingOutputCallResponse value) { - queue.add(value); - } - - @Override - public void onError(Throwable t) { - queue.add(t); - } - - @Override - public void onCompleted() { - queue.add(lastItem); - } - }); - - streamObserver.onNext(StreamingOutputCallRequest.newBuilder() - .setOrcaOobReport(answer) - .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); - assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); - int i = 0; - for (; i < retryLimit; i++) { - Thread.sleep(1000); - blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); - if (answer.equals(reportHolder.get())) { - break; - } - } - assertThat(i).isLessThan(retryLimit); - streamObserver.onNext(StreamingOutputCallRequest.newBuilder() - .setOrcaOobReport(answer2) - .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); - assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); - - for (i = 0; i < retryLimit; i++) { - Thread.sleep(1000); - blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); - if (reportHolder.get().equals(answer2)) { - break; - } - } - assertThat(i).isLessThan(retryLimit); - streamObserver.onCompleted(); - assertThat(queue.take()).isSameInstanceAs(lastItem); - } - - /** Sends a large unary rpc with service account credentials. */ - public void serviceAccountCreds(String jsonKey, InputStream credentialsStream, String authScope) - throws Exception { - // cast to ServiceAccountCredentials to double-check the right type of object was created. - GoogleCredentials credentials = - ServiceAccountCredentials.class.cast(GoogleCredentials.fromStream(credentialsStream)); - credentials = credentials.createScoped(Arrays.asList(authScope)); - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub - .withCallCredentials(MoreCallCredentials.from(credentials)); - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .build(); - - final SimpleResponse response = stub.unaryCall(request); - assertFalse(response.getUsername().isEmpty()); - assertTrue("Received username: " + response.getUsername(), - jsonKey.contains(response.getUsername())); - assertFalse(response.getOauthScope().isEmpty()); - assertTrue("Received oauth scope: " + response.getOauthScope(), - authScope.contains(response.getOauthScope())); - - final SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setOauthScope(response.getOauthScope()) - .setUsername(response.getUsername()) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[314159]))) - .build(); - assertResponse(goldenResponse, response); - } - - /** Sends a large unary rpc with compute engine credentials. */ - public void computeEngineCreds(String serviceAccount, String oauthScope) throws Exception { - ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub - .withCallCredentials(MoreCallCredentials.from(credentials)); - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .build(); - - final SimpleResponse response = stub.unaryCall(request); - assertEquals(serviceAccount, response.getUsername()); - assertFalse(response.getOauthScope().isEmpty()); - assertTrue("Received oauth scope: " + response.getOauthScope(), - oauthScope.contains(response.getOauthScope())); - - final SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setOauthScope(response.getOauthScope()) - .setUsername(response.getUsername()) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[314159]))) - .build(); - assertResponse(goldenResponse, response); - } - - /** Sends an unary rpc with ComputeEngineChannelBuilder. */ - public void computeEngineChannelCredentials( - String defaultServiceAccount, - TestServiceGrpc.TestServiceBlockingStub computeEngineStub) throws Exception { - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .build(); - final SimpleResponse response = computeEngineStub.unaryCall(request); - assertEquals(defaultServiceAccount, response.getUsername()); - final SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setUsername(defaultServiceAccount) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[314159]))) - .build(); - assertResponse(goldenResponse, response); - } - - /** Test JWT-based auth. */ - public void jwtTokenCreds(InputStream serviceAccountJson) throws Exception { - final SimpleRequest request = SimpleRequest.newBuilder() - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .setFillUsername(true) - .build(); - - ServiceAccountCredentials credentials = (ServiceAccountCredentials) - GoogleCredentials.fromStream(serviceAccountJson); - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub - .withCallCredentials(MoreCallCredentials.from(credentials)); - SimpleResponse response = stub.unaryCall(request); - assertEquals(credentials.getClientEmail(), response.getUsername()); - assertEquals(314159, response.getPayload().getBody().size()); - } - - /** Sends a unary rpc with raw oauth2 access token credentials. */ - public void oauth2AuthToken(String jsonKey, InputStream credentialsStream, String authScope) - throws Exception { - GoogleCredentials utilCredentials = - GoogleCredentials.fromStream(credentialsStream); - utilCredentials = utilCredentials.createScoped(Arrays.asList(authScope)); - AccessToken accessToken = utilCredentials.refreshAccessToken(); - - OAuth2Credentials credentials = OAuth2Credentials.create(accessToken); - - TestServiceGrpc.TestServiceBlockingStub stub = blockingStub - .withCallCredentials(MoreCallCredentials.from(credentials)); - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setFillOauthScope(true) - .build(); - - final SimpleResponse response = stub.unaryCall(request); - assertFalse(response.getUsername().isEmpty()); - assertTrue("Received username: " + response.getUsername(), - jsonKey.contains(response.getUsername())); - assertFalse(response.getOauthScope().isEmpty()); - assertTrue("Received oauth scope: " + response.getOauthScope(), - authScope.contains(response.getOauthScope())); - } - - /** Sends a unary rpc with "per rpc" raw oauth2 access token credentials. */ - public void perRpcCreds(String jsonKey, InputStream credentialsStream, String oauthScope) - throws Exception { - // In gRpc Java, we don't have per Rpc credentials, user can use an intercepted stub only once - // for that purpose. - // So, this test is identical to oauth2_auth_token test. - oauth2AuthToken(jsonKey, credentialsStream, oauthScope); - } - - /** Sends an unary rpc with "google default credentials". */ - public void googleDefaultCredentials( - String defaultServiceAccount, - TestServiceGrpc.TestServiceBlockingStub googleDefaultStub) throws Exception { - final SimpleRequest request = SimpleRequest.newBuilder() - .setFillUsername(true) - .setResponseSize(314159) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[271828]))) - .build(); - final SimpleResponse response = googleDefaultStub.unaryCall(request); - assertEquals(defaultServiceAccount, response.getUsername()); - - final SimpleResponse goldenResponse = SimpleResponse.newBuilder() - .setUsername(defaultServiceAccount) - .setPayload(Payload.newBuilder() - .setBody(ByteString.copyFrom(new byte[314159]))) - .build(); - assertResponse(goldenResponse, response); - } - private static class SoakIterationResult { public SoakIterationResult(long latencyMs, Status status) { this.latencyMs = latencyMs; @@ -2481,7 +2163,7 @@ private void assertResponse( } } - private void assertResponse(SimpleResponse expected, SimpleResponse actual) { + public void assertResponse(SimpleResponse expected, SimpleResponse actual) { assertPayload(expected.getPayload(), actual.getPayload()); assertEquals(expected.getUsername(), actual.getUsername()); assertEquals(expected.getOauthScope(), actual.getOauthScope()); diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java b/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java index 87ecf308674..b9a89a01e3a 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/CustomBackendMetricsLoadBalancerProvider.java @@ -16,8 +16,8 @@ package io.grpc.testing.integration; -import static io.grpc.testing.integration.AbstractInteropTest.ORCA_OOB_REPORT_KEY; -import static io.grpc.testing.integration.AbstractInteropTest.ORCA_RPC_REPORT_KEY; +import static io.grpc.testing.integration.TestServiceClient.ORCA_OOB_REPORT_KEY; +import static io.grpc.testing.integration.TestServiceClient.ORCA_RPC_REPORT_KEY; import io.grpc.ConnectivityState; import io.grpc.LoadBalancer; diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java index 0c8f697ada5..e6829be11cb 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java @@ -16,10 +16,25 @@ package io.grpc.testing.integration; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.ComputeEngineCredentials; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.auth.oauth2.ServiceAccountCredentials; import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Files; +import com.google.protobuf.ByteString; +import io.grpc.CallOptions; +import io.grpc.Channel; import io.grpc.ChannelCredentials; import io.grpc.ClientInterceptor; +import io.grpc.ClientInterceptors; import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; import io.grpc.InsecureServerCredentials; @@ -28,11 +43,13 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.Metadata; +import io.grpc.MethodDescriptor; import io.grpc.ServerBuilder; import io.grpc.TlsChannelCredentials; import io.grpc.alts.AltsChannelCredentials; import io.grpc.alts.ComputeEngineChannelCredentials; import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.GrpcUtil; import io.grpc.internal.JsonParser; import io.grpc.netty.InsecureFromHttp1ChannelCredentials; @@ -40,13 +57,27 @@ import io.grpc.netty.NettyChannelBuilder; import io.grpc.okhttp.InternalOkHttpChannelBuilder; import io.grpc.okhttp.OkHttpChannelBuilder; +import io.grpc.stub.ClientCalls; import io.grpc.stub.MetadataUtils; +import io.grpc.stub.StreamObserver; import io.grpc.testing.TlsTesting; +import io.grpc.testing.integration.Messages.Payload; +import io.grpc.testing.integration.Messages.ResponseParameters; +import io.grpc.testing.integration.Messages.SimpleRequest; +import io.grpc.testing.integration.Messages.SimpleResponse; +import io.grpc.testing.integration.Messages.StreamingOutputCallRequest; +import io.grpc.testing.integration.Messages.StreamingOutputCallResponse; +import io.grpc.testing.integration.Messages.TestOrcaReport; import java.io.File; import java.io.FileInputStream; +import java.io.InputStream; import java.nio.charset.Charset; +import java.util.Arrays; import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; /** @@ -57,6 +88,11 @@ public class TestServiceClient { private static final Charset UTF_8 = Charset.forName("UTF-8"); + static final CallOptions.Key> + ORCA_RPC_REPORT_KEY = CallOptions.Key.create("orca-rpc-report"); + static final CallOptions.Key> + ORCA_OOB_REPORT_KEY = CallOptions.Key.create("orca-oob-report"); + /** * The main application allowing this client to be launched from the command line. */ @@ -668,6 +704,313 @@ protected ManagedChannelBuilder createChannelBuilder() { return okBuilder.intercept(createCensusStatsClientInterceptor()); } + /** + * Assuming "pick_first" policy is used, tests that all requests are sent to the same server. + */ + public void pickFirstUnary() throws Exception { + SimpleRequest request = SimpleRequest.newBuilder() + .setResponseSize(1) + .setFillServerId(true) + .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[1]))) + .build(); + + SimpleResponse firstResponse = blockingStub.unaryCall(request); + // Increase the chance of all servers are connected, in case the channel should be doing + // round_robin instead. + Thread.sleep(5000); + for (int i = 0; i < 100; i++) { + SimpleResponse response = blockingStub.unaryCall(request); + assertThat(response.getServerId()).isEqualTo(firstResponse.getServerId()); + } + } + + /** + * Sends a cacheable unary rpc using GET. Requires that the server is behind a caching proxy. + */ + public void cacheableUnary() { + // THIS TEST IS BROKEN. Enabling safe just on the MethodDescriptor does nothing by itself. + // This test would need to enable GET on the channel. + // Set safe to true. + MethodDescriptor safeCacheableUnaryCallMethod = + TestServiceGrpc.getCacheableUnaryCallMethod().toBuilder().setSafe(true).build(); + // Set fake user IP since some proxies (GFE) won't cache requests from localhost. + Metadata.Key userIpKey = + Metadata.Key.of("x-user-ip", Metadata.ASCII_STRING_MARSHALLER); + Metadata metadata = new Metadata(); + metadata.put(userIpKey, "1.2.3.4"); + Channel channelWithUserIpKey = ClientInterceptors.intercept( + channel, MetadataUtils.newAttachHeadersInterceptor(metadata)); + SimpleRequest requests1And2 = + SimpleRequest.newBuilder() + .setPayload( + Payload.newBuilder() + .setBody(ByteString.copyFromUtf8(String.valueOf(System.nanoTime())))) + .build(); + SimpleRequest request3 = + SimpleRequest.newBuilder() + .setPayload( + Payload.newBuilder() + .setBody(ByteString.copyFromUtf8(String.valueOf(System.nanoTime())))) + .build(); + + SimpleResponse response1 = + ClientCalls.blockingUnaryCall( + channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, + requests1And2); + SimpleResponse response2 = + ClientCalls.blockingUnaryCall( + channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, + requests1And2); + SimpleResponse response3 = + ClientCalls.blockingUnaryCall( + channelWithUserIpKey, safeCacheableUnaryCallMethod, CallOptions.DEFAULT, request3); + + assertEquals(response1, response2); + assertNotEquals(response1, response3); + // THIS TEST IS BROKEN. See comment at start of method. + } + + /** Sends a large unary rpc with service account credentials. */ + public void serviceAccountCreds(String jsonKey, InputStream credentialsStream, String authScope) + throws Exception { + // cast to ServiceAccountCredentials to double-check the right type of object was created. + GoogleCredentials credentials = + ServiceAccountCredentials.class.cast(GoogleCredentials.fromStream(credentialsStream)); + credentials = credentials.createScoped(Arrays.asList(authScope)); + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub + .withCallCredentials(MoreCallCredentials.from(credentials)); + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setFillOauthScope(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + + final SimpleResponse response = stub.unaryCall(request); + assertFalse(response.getUsername().isEmpty()); + assertTrue("Received username: " + response.getUsername(), + jsonKey.contains(response.getUsername())); + assertFalse(response.getOauthScope().isEmpty()); + assertTrue("Received oauth scope: " + response.getOauthScope(), + authScope.contains(response.getOauthScope())); + + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setOauthScope(response.getOauthScope()) + .setUsername(response.getUsername()) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, response); + } + + /** Sends a large unary rpc with compute engine credentials. */ + public void computeEngineCreds(String serviceAccount, String oauthScope) throws Exception { + ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub + .withCallCredentials(MoreCallCredentials.from(credentials)); + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setFillOauthScope(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + + final SimpleResponse response = stub.unaryCall(request); + assertEquals(serviceAccount, response.getUsername()); + assertFalse(response.getOauthScope().isEmpty()); + assertTrue("Received oauth scope: " + response.getOauthScope(), + oauthScope.contains(response.getOauthScope())); + + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setOauthScope(response.getOauthScope()) + .setUsername(response.getUsername()) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, response); + } + + /** Sends an unary rpc with ComputeEngineChannelBuilder. */ + public void computeEngineChannelCredentials( + String defaultServiceAccount, + TestServiceGrpc.TestServiceBlockingStub computeEngineStub) throws Exception { + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + final SimpleResponse response = computeEngineStub.unaryCall(request); + assertEquals(defaultServiceAccount, response.getUsername()); + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setUsername(defaultServiceAccount) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, response); + } + + /** Test JWT-based auth. */ + public void jwtTokenCreds(InputStream serviceAccountJson) throws Exception { + final SimpleRequest request = SimpleRequest.newBuilder() + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .setFillUsername(true) + .build(); + + ServiceAccountCredentials credentials = (ServiceAccountCredentials) + GoogleCredentials.fromStream(serviceAccountJson); + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub + .withCallCredentials(MoreCallCredentials.from(credentials)); + SimpleResponse response = stub.unaryCall(request); + assertEquals(credentials.getClientEmail(), response.getUsername()); + assertEquals(314159, response.getPayload().getBody().size()); + } + + /** Sends a unary rpc with raw oauth2 access token credentials. */ + public void oauth2AuthToken(String jsonKey, InputStream credentialsStream, String authScope) + throws Exception { + GoogleCredentials utilCredentials = + GoogleCredentials.fromStream(credentialsStream); + utilCredentials = utilCredentials.createScoped(Arrays.asList(authScope)); + AccessToken accessToken = utilCredentials.refreshAccessToken(); + + OAuth2Credentials credentials = OAuth2Credentials.create(accessToken); + + TestServiceGrpc.TestServiceBlockingStub stub = blockingStub + .withCallCredentials(MoreCallCredentials.from(credentials)); + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setFillOauthScope(true) + .build(); + + final SimpleResponse response = stub.unaryCall(request); + assertFalse(response.getUsername().isEmpty()); + assertTrue("Received username: " + response.getUsername(), + jsonKey.contains(response.getUsername())); + assertFalse(response.getOauthScope().isEmpty()); + assertTrue("Received oauth scope: " + response.getOauthScope(), + authScope.contains(response.getOauthScope())); + } + + /** Sends a unary rpc with "per rpc" raw oauth2 access token credentials. */ + public void perRpcCreds(String jsonKey, InputStream credentialsStream, String oauthScope) + throws Exception { + // In gRpc Java, we don't have per Rpc credentials, user can use an intercepted stub only once + // for that purpose. + // So, this test is identical to oauth2_auth_token test. + oauth2AuthToken(jsonKey, credentialsStream, oauthScope); + } + + /** Sends an unary rpc with "google default credentials". */ + public void googleDefaultCredentials( + String defaultServiceAccount, + TestServiceGrpc.TestServiceBlockingStub googleDefaultStub) throws Exception { + final SimpleRequest request = SimpleRequest.newBuilder() + .setFillUsername(true) + .setResponseSize(314159) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[271828]))) + .build(); + final SimpleResponse response = googleDefaultStub.unaryCall(request); + assertEquals(defaultServiceAccount, response.getUsername()); + + final SimpleResponse goldenResponse = SimpleResponse.newBuilder() + .setUsername(defaultServiceAccount) + .setPayload(Payload.newBuilder() + .setBody(ByteString.copyFrom(new byte[314159]))) + .build(); + assertResponse(goldenResponse, response); + } + + /** + * Test backend metrics per query reporting: expect the test client LB policy to receive load + * reports. + */ + public void testOrcaPerRpc() throws Exception { + AtomicReference reportHolder = new AtomicReference<>(); + TestOrcaReport answer = TestOrcaReport.newBuilder() + .setCpuUtilization(0.8210) + .setMemoryUtilization(0.5847) + .putRequestCost("cost", 3456.32) + .putUtilization("util", 0.30499) + .build(); + blockingStub.withOption(ORCA_RPC_REPORT_KEY, reportHolder).unaryCall( + SimpleRequest.newBuilder().setOrcaPerQueryReport(answer).build()); + assertThat(reportHolder.get()).isEqualTo(answer); + } + + /** + * Test backend metrics OOB reporting: expect the test client LB policy to receive load reports. + */ + public void testOrcaOob() throws Exception { + AtomicReference reportHolder = new AtomicReference<>(); + final TestOrcaReport answer = TestOrcaReport.newBuilder() + .setCpuUtilization(0.8210) + .setMemoryUtilization(0.5847) + .putUtilization("util", 0.30499) + .build(); + final TestOrcaReport answer2 = TestOrcaReport.newBuilder() + .setCpuUtilization(0.29309) + .setMemoryUtilization(0.2) + .putUtilization("util", 0.2039) + .build(); + + final int retryLimit = 5; + BlockingQueue queue = new LinkedBlockingQueue<>(); + final Object lastItem = new Object(); + StreamObserver streamObserver = + asyncStub.fullDuplexCall(new StreamObserver() { + + @Override + public void onNext(StreamingOutputCallResponse value) { + queue.add(value); + } + + @Override + public void onError(Throwable t) { + queue.add(t); + } + + @Override + public void onCompleted() { + queue.add(lastItem); + } + }); + + streamObserver.onNext(StreamingOutputCallRequest.newBuilder() + .setOrcaOobReport(answer) + .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); + assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); + int i = 0; + for (; i < retryLimit; i++) { + Thread.sleep(1000); + blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); + if (answer.equals(reportHolder.get())) { + break; + } + } + assertThat(i).isLessThan(retryLimit); + streamObserver.onNext(StreamingOutputCallRequest.newBuilder() + .setOrcaOobReport(answer2) + .addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build()); + assertThat(queue.take()).isInstanceOf(StreamingOutputCallResponse.class); + + for (i = 0; i < retryLimit; i++) { + Thread.sleep(1000); + blockingStub.withOption(ORCA_OOB_REPORT_KEY, reportHolder).emptyCall(EMPTY); + if (reportHolder.get().equals(answer2)) { + break; + } + } + assertThat(i).isLessThan(retryLimit); + streamObserver.onCompleted(); + assertThat(queue.take()).isSameInstanceAs(lastItem); + } + @Override protected boolean metricsExpected() { // Exact message size doesn't match when testing with Go servers: diff --git a/netty/BUILD.bazel b/netty/BUILD.bazel index daf2e83e59a..9fe52ea5868 100644 --- a/netty/BUILD.bazel +++ b/netty/BUILD.bazel @@ -15,7 +15,6 @@ java_library( artifact("com.google.code.findbugs:jsr305"), artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), artifact("io.netty:netty-buffer"), artifact("io.netty:netty-codec"), artifact("io.netty:netty-codec-http"), diff --git a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java index bbe9fab9233..73988f773cb 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientHandlerTest.java @@ -16,7 +16,6 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; import static io.grpc.internal.ClientStreamListener.RpcProgress.MISCARRIED; import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; @@ -30,6 +29,7 @@ import static io.grpc.netty.Utils.TE_HEADER; import static io.grpc.netty.Utils.TE_TRAILERS; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index f94960cbab3..9777bb0926c 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -16,7 +16,6 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.TruthJUnit.assume; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; @@ -29,6 +28,7 @@ import static io.grpc.netty.NettyServerBuilder.MAX_CONNECTION_IDLE_NANOS_DISABLED; import static io.grpc.netty.NettyServerBuilder.MAX_RST_COUNT_DISABLED; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -82,6 +82,7 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelOption; import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; import io.netty.channel.EventLoopGroup; import io.netty.channel.ReflectiveChannelFactory; import io.netty.channel.local.LocalChannel; @@ -519,15 +520,20 @@ public void channelFactoryShouldSetSocketOptionKeepAlive() throws Exception { @Test public void channelFactoryShouldNNotSetSocketOptionKeepAlive() throws Exception { startServer(); - NettyClientTransport transport = newTransport(newNegotiator(), - DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true, - TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L), - new ReflectiveChannelFactory<>(LocalChannel.class), group); + DefaultEventLoopGroup group = new DefaultEventLoopGroup(1); + try { + NettyClientTransport transport = newTransport(newNegotiator(), + DEFAULT_MAX_MESSAGE_SIZE, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE, "testUserAgent", true, + TimeUnit.SECONDS.toNanos(10L), TimeUnit.SECONDS.toNanos(1L), + new ReflectiveChannelFactory<>(LocalChannel.class), group); - callMeMaybe(transport.start(clientTransportListener)); + callMeMaybe(transport.start(clientTransportListener)); - assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE)) - .isNull(); + assertThat(transport.channel().config().getOption(ChannelOption.SO_KEEPALIVE)) + .isNull(); + } finally { + group.shutdownGracefully(0, 10, TimeUnit.SECONDS); + } } @Test diff --git a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java index fbab1ca5fae..eef8d30e05a 100644 --- a/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyHandlerTestBase.java @@ -16,8 +16,8 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; diff --git a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java index 1a0ac229a89..4f10504c07d 100644 --- a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java @@ -16,7 +16,7 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; diff --git a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java index ce902a9620b..541490847c0 100644 --- a/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyServerHandlerTest.java @@ -16,7 +16,6 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.DEFAULT_SERVER_KEEPALIVE_TIME_NANOS; @@ -29,6 +28,7 @@ import static io.grpc.netty.Utils.HTTP_METHOD; import static io.grpc.netty.Utils.TE_HEADER; import static io.grpc.netty.Utils.TE_TRAILERS; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; diff --git a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java index f073fb6b2e4..ce42e3d25df 100644 --- a/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java +++ b/netty/src/test/java/io/grpc/netty/NettyStreamTestBase.java @@ -16,8 +16,8 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.US_ASCII; import static io.grpc.netty.NettyTestUtil.messageFrame; +import static java.nio.charset.StandardCharsets.US_ASCII; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 1852213da52..6939d835892 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -16,10 +16,10 @@ package io.grpc.netty; -import static com.google.common.base.Charsets.UTF_8; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; diff --git a/okhttp/BUILD.bazel b/okhttp/BUILD.bazel index 7cf1775da2c..80068c9bb5b 100644 --- a/okhttp/BUILD.bazel +++ b/okhttp/BUILD.bazel @@ -17,7 +17,6 @@ java_library( artifact("com.google.code.findbugs:jsr305"), artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), artifact("com.squareup.okhttp:okhttp"), artifact("com.squareup.okio:okio"), artifact("io.perfmark:perfmark-api"), diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/BinaryFormat.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/BinaryFormat.java new file mode 100644 index 00000000000..cdf27875903 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/BinaryFormat.java @@ -0,0 +1,143 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + + +import static com.google.common.base.Preconditions.checkNotNull; + +import io.grpc.Metadata; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.SpanId; +import io.opentelemetry.api.trace.TraceFlags; +import io.opentelemetry.api.trace.TraceId; +import io.opentelemetry.api.trace.TraceState; +import java.util.Arrays; + +/** + * Binary encoded {@link SpanContext} for context propagation. This is adapted from OpenCensus + * binary format. + * + *

BinaryFormat format: + * + *

    + *
  • Binary value: <version_id><version_format> + *
  • version_id: 1-byte representing the version id. + *
  • For version_id = 0: + *
      + *
    • version_format: <field><field> + *
    • field_format: <field_id><field_format> + *
    • Fields: + *
        + *
      • TraceId: (field_id = 0, len = 16, default = "0000000000000000") - + * 16-byte array representing the trace_id. + *
      • SpanId: (field_id = 1, len = 8, default = "00000000") - 8-byte array + * representing the span_id. + *
      • TraceFlags: (field_id = 2, len = 1, default = "0") - 1-byte array + * representing the trace_flags. + *
      + *
    • Fields MUST be encoded using the field id order (smaller to higher). + *
    • Valid value example: + *
        + *
      • {0, 0, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 97, + * 98, 99, 100, 101, 102, 103, 104, 2, 1} + *
      • version_id = 0; + *
      • trace_id = {64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79} + *
      • span_id = {97, 98, 99, 100, 101, 102, 103, 104}; + *
      • trace_flags = {1}; + *
      + *
    + *
+ */ +final class BinaryFormat implements Metadata.BinaryMarshaller { + private static final byte VERSION_ID = 0; + private static final int VERSION_ID_OFFSET = 0; + private static final byte ID_SIZE = 1; + private static final byte TRACE_ID_FIELD_ID = 0; + + private static final int TRACE_ID_FIELD_ID_OFFSET = VERSION_ID_OFFSET + ID_SIZE; + private static final int TRACE_ID_OFFSET = TRACE_ID_FIELD_ID_OFFSET + ID_SIZE; + private static final int TRACE_ID_SIZE = TraceId.getLength() / 2; + + private static final byte SPAN_ID_FIELD_ID = 1; + private static final int SPAN_ID_FIELD_ID_OFFSET = TRACE_ID_OFFSET + TRACE_ID_SIZE; + private static final int SPAN_ID_OFFSET = SPAN_ID_FIELD_ID_OFFSET + ID_SIZE; + private static final int SPAN_ID_SIZE = SpanId.getLength() / 2; + + private static final byte TRACE_FLAG_FIELD_ID = 2; + private static final int TRACE_FLAG_FIELD_ID_OFFSET = SPAN_ID_OFFSET + SPAN_ID_SIZE; + private static final int TRACE_FLAG_OFFSET = TRACE_FLAG_FIELD_ID_OFFSET + ID_SIZE; + private static final int REQUIRED_FORMAT_LENGTH = 3 * ID_SIZE + TRACE_ID_SIZE + SPAN_ID_SIZE; + private static final int TRACE_FLAG_SIZE = TraceFlags.getLength() / 2; + private static final int ALL_FORMAT_LENGTH = REQUIRED_FORMAT_LENGTH + ID_SIZE + TRACE_FLAG_SIZE; + + private static final BinaryFormat INSTANCE = new BinaryFormat(); + + public static BinaryFormat getInstance() { + return INSTANCE; + } + + @Override + public byte[] toBytes(SpanContext spanContext) { + checkNotNull(spanContext, "spanContext"); + byte[] bytes = new byte[ALL_FORMAT_LENGTH]; + bytes[VERSION_ID_OFFSET] = VERSION_ID; + bytes[TRACE_ID_FIELD_ID_OFFSET] = TRACE_ID_FIELD_ID; + System.arraycopy(spanContext.getTraceIdBytes(), 0, bytes, TRACE_ID_OFFSET, TRACE_ID_SIZE); + bytes[SPAN_ID_FIELD_ID_OFFSET] = SPAN_ID_FIELD_ID; + System.arraycopy(spanContext.getSpanIdBytes(), 0, bytes, SPAN_ID_OFFSET, SPAN_ID_SIZE); + bytes[TRACE_FLAG_FIELD_ID_OFFSET] = TRACE_FLAG_FIELD_ID; + bytes[TRACE_FLAG_OFFSET] = spanContext.getTraceFlags().asByte(); + return bytes; + } + + + @Override + public SpanContext parseBytes(byte[] serialized) { + checkNotNull(serialized, "bytes"); + if (serialized.length == 0 || serialized[0] != VERSION_ID) { + throw new IllegalArgumentException("Unsupported version."); + } + if (serialized.length < REQUIRED_FORMAT_LENGTH) { + throw new IllegalArgumentException("Invalid input: truncated"); + } + String traceId; + String spanId; + TraceFlags traceFlags = TraceFlags.getDefault(); + int pos = 1; + if (serialized[pos] == TRACE_ID_FIELD_ID) { + traceId = TraceId.fromBytes( + Arrays.copyOfRange(serialized, pos + ID_SIZE, pos + ID_SIZE + TRACE_ID_SIZE)); + pos += ID_SIZE + TRACE_ID_SIZE; + } else { + throw new IllegalArgumentException("Invalid input: expected trace ID at offset " + pos); + } + if (serialized[pos] == SPAN_ID_FIELD_ID) { + spanId = SpanId.fromBytes( + Arrays.copyOfRange(serialized, pos + ID_SIZE, pos + ID_SIZE + SPAN_ID_SIZE)); + pos += ID_SIZE + SPAN_ID_SIZE; + } else { + throw new IllegalArgumentException("Invalid input: expected span ID at offset " + pos); + } + if (serialized.length > pos && serialized[pos] == TRACE_FLAG_FIELD_ID) { + if (serialized.length < ALL_FORMAT_LENGTH) { + throw new IllegalArgumentException("Invalid input: truncated"); + } + traceFlags = TraceFlags.fromByte(serialized[pos + ID_SIZE]); + } + return SpanContext.create(traceId, spanId, traceFlags, TraceState.getDefault()); + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagator.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagator.java new file mode 100644 index 00000000000..4825b203529 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagator.java @@ -0,0 +1,147 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.io.BaseEncoding; +import io.grpc.ExperimentalApi; +import io.grpc.Metadata; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.propagation.TextMapGetter; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.context.propagation.TextMapSetter; +import java.util.Collection; +import java.util.Collections; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * A {@link TextMapPropagator} for transmitting "grpc-trace-bin" span context. + * + *

This propagator can transmit the "grpc-trace-bin" context in either binary or Base64-encoded + * text format, depending on the capabilities of the provided {@link TextMapGetter} and + * {@link TextMapSetter}. + * + *

If the {@code TextMapGetter} and {@code TextMapSetter} only support text format, Base64 + * encoding and decoding will be used when communicating with the carrier API. But gRPC uses + * it with gRPC's metadata-based getter/setter, and the propagator can directly transmit the binary + * header, avoiding the need for Base64 encoding. + */ + +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/11400") +public final class GrpcTraceBinContextPropagator implements TextMapPropagator { + private static final Logger log = Logger.getLogger(GrpcTraceBinContextPropagator.class.getName()); + public static final String GRPC_TRACE_BIN_HEADER = "grpc-trace-bin"; + private final Metadata.BinaryMarshaller binaryFormat; + private static final GrpcTraceBinContextPropagator INSTANCE = + new GrpcTraceBinContextPropagator(BinaryFormat.getInstance()); + + public static GrpcTraceBinContextPropagator defaultInstance() { + return INSTANCE; + } + + @VisibleForTesting + GrpcTraceBinContextPropagator(Metadata.BinaryMarshaller binaryFormat) { + this.binaryFormat = checkNotNull(binaryFormat, "binaryFormat"); + } + + @Override + public Collection fields() { + return Collections.singleton(GRPC_TRACE_BIN_HEADER); + } + + @Override + public void inject(Context context, @Nullable C carrier, TextMapSetter setter) { + if (context == null || setter == null) { + return; + } + SpanContext spanContext = Span.fromContext(context).getSpanContext(); + if (!spanContext.isValid()) { + return; + } + try { + byte[] b = binaryFormat.toBytes(spanContext); + if (setter instanceof MetadataSetter) { + ((MetadataSetter) setter).set((Metadata) carrier, GRPC_TRACE_BIN_HEADER, b); + } else { + setter.set(carrier, GRPC_TRACE_BIN_HEADER, BASE64_ENCODING_OMIT_PADDING.encode(b)); + } + } catch (Exception e) { + log.log(Level.FINE, "Set grpc-trace-bin spanContext failed", e); + } + } + + @Override + public Context extract(Context context, @Nullable C carrier, TextMapGetter getter) { + if (context == null) { + return Context.root(); + } + if (getter == null) { + return context; + } + byte[] b; + if (getter instanceof MetadataGetter) { + try { + b = ((MetadataGetter) getter).getBinary((Metadata) carrier, GRPC_TRACE_BIN_HEADER); + if (b == null) { + log.log(Level.FINE, "No grpc-trace-bin present in carrier"); + return context; + } + } catch (Exception e) { + log.log(Level.FINE, "Get 'grpc-trace-bin' from MetadataGetter failed", e); + return context; + } + } else { + String value; + try { + value = getter.get(carrier, GRPC_TRACE_BIN_HEADER); + if (value == null) { + log.log(Level.FINE, "No grpc-trace-bin present in carrier"); + return context; + } + } catch (Exception e) { + log.log(Level.FINE, "Get 'grpc-trace-bin' from getter failed", e); + return context; + } + try { + b = BaseEncoding.base64().decode(value); + } catch (Exception e) { + log.log(Level.FINE, "Base64-decode spanContext bytes failed", e); + return context; + } + } + + SpanContext spanContext; + try { + spanContext = binaryFormat.parseBytes(b); + } catch (Exception e) { + log.log(Level.FINE, "Failed to parse tracing header", e); + return context; + } + if (!spanContext.isValid()) { + return context; + } + return context.with(Span.wrap(spanContext)); + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataGetter.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataGetter.java new file mode 100644 index 00000000000..f49c029f2fb --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataGetter.java @@ -0,0 +1,87 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + + +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; + +import io.grpc.Metadata; +import io.opentelemetry.context.propagation.TextMapGetter; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * A TextMapGetter that reads value from gRPC {@link Metadata}. Supports both text and binary + * headers. Supporting binary header is an optimization path for GrpcTraceBinContextPropagator + * to work around the lack of binary propagator API and thus avoid + * base64 (de)encoding when passing data between propagator API interfaces. + */ +final class MetadataGetter implements TextMapGetter { + private static final Logger logger = Logger.getLogger(MetadataGetter.class.getName()); + private static final MetadataGetter INSTANCE = new MetadataGetter(); + + public static MetadataGetter getInstance() { + return INSTANCE; + } + + @Override + public Iterable keys(Metadata carrier) { + return carrier.keys(); + } + + @Nullable + @Override + public String get(@Nullable Metadata carrier, String key) { + if (carrier == null) { + logger.log(Level.FINE, "Carrier is null, getting no data"); + return null; + } + try { + if (key.equals("grpc-trace-bin")) { + byte[] value = carrier.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + if (value == null) { + return null; + } + return BASE64_ENCODING_OMIT_PADDING.encode(value); + } else { + return carrier.get(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + } + } catch (Exception e) { + logger.log(Level.FINE, String.format("Failed to get metadata key %s", key), e); + return null; + } + } + + @Nullable + public byte[] getBinary(@Nullable Metadata carrier, String key) { + if (carrier == null) { + logger.log(Level.FINE, "Carrier is null, getting no data"); + return null; + } + if (!key.equals("grpc-trace-bin")) { + logger.log(Level.FINE, "Only support 'grpc-trace-bin' binary header. Get no data"); + return null; + } + try { + return carrier.get(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + } catch (Exception e) { + logger.log(Level.FINE, String.format("Failed to get metadata key %s", key), e); + return null; + } + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataSetter.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataSetter.java new file mode 100644 index 00000000000..5892c7accfe --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/MetadataSetter.java @@ -0,0 +1,74 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + + +import com.google.common.io.BaseEncoding; +import io.grpc.Metadata; +import io.opentelemetry.context.propagation.TextMapSetter; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * A {@link TextMapSetter} that sets value to gRPC {@link Metadata}. Supports both text and binary + * headers. Supporting binary header is an optimization path for GrpcTraceBinContextPropagator + * to work around the lack of binary propagator API and thus avoid + * base64 (de)encoding when passing data between propagator API interfaces. + */ +final class MetadataSetter implements TextMapSetter { + private static final Logger logger = Logger.getLogger(MetadataSetter.class.getName()); + private static final MetadataSetter INSTANCE = new MetadataSetter(); + + public static MetadataSetter getInstance() { + return INSTANCE; + } + + @Override + public void set(@Nullable Metadata carrier, String key, String value) { + if (carrier == null) { + logger.log(Level.FINE, "Carrier is null, setting no data"); + return; + } + try { + if (key.equals("grpc-trace-bin")) { + carrier.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), + BaseEncoding.base64().decode(value)); + } else { + carrier.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value); + } + } catch (Exception e) { + logger.log(Level.INFO, String.format("Failed to set metadata, key=%s", key), e); + } + } + + void set(@Nullable Metadata carrier, String key, byte[] value) { + if (carrier == null) { + logger.log(Level.FINE, "Carrier is null, setting no data"); + return; + } + if (!key.equals("grpc-trace-bin")) { + logger.log(Level.INFO, "Only support 'grpc-trace-bin' binary header. Set no data"); + return; + } + try { + carrier.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), value); + } catch (Exception e) { + logger.log(Level.INFO, String.format("Failed to set metadata key=%s", key), e); + } + } +} diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java new file mode 100644 index 00000000000..11659c87708 --- /dev/null +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java @@ -0,0 +1,408 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ClientStreamTracer; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerStreamTracer; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.propagation.ContextPropagators; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** + * Provides factories for {@link io.grpc.StreamTracer} that records tracing to OpenTelemetry. + */ +final class OpenTelemetryTracingModule { + private static final Logger logger = Logger.getLogger(OpenTelemetryTracingModule.class.getName()); + + @VisibleForTesting + static final String OTEL_TRACING_SCOPE_NAME = "grpc-java"; + @Nullable + private static final AtomicIntegerFieldUpdater callEndedUpdater; + @Nullable + private static final AtomicIntegerFieldUpdater streamClosedUpdater; + + /* + * When using Atomic*FieldUpdater, some Samsung Android 5.0.x devices encounter a bug in their JDK + * reflection API that triggers a NoSuchFieldException. When this occurs, we fallback to + * (potentially racy) direct updates of the volatile variables. + */ + static { + AtomicIntegerFieldUpdater tmpCallEndedUpdater; + AtomicIntegerFieldUpdater tmpStreamClosedUpdater; + try { + tmpCallEndedUpdater = + AtomicIntegerFieldUpdater.newUpdater(CallAttemptsTracerFactory.class, "callEnded"); + tmpStreamClosedUpdater = + AtomicIntegerFieldUpdater.newUpdater(ServerTracer.class, "streamClosed"); + } catch (Throwable t) { + logger.log(Level.SEVERE, "Creating atomic field updaters failed", t); + tmpCallEndedUpdater = null; + tmpStreamClosedUpdater = null; + } + callEndedUpdater = tmpCallEndedUpdater; + streamClosedUpdater = tmpStreamClosedUpdater; + } + + private final Tracer otelTracer; + private final ContextPropagators contextPropagators; + private final MetadataGetter metadataGetter = MetadataGetter.getInstance(); + private final MetadataSetter metadataSetter = MetadataSetter.getInstance(); + private final TracingClientInterceptor clientInterceptor = new TracingClientInterceptor(); + private final ServerTracerFactory serverTracerFactory = new ServerTracerFactory(); + + OpenTelemetryTracingModule(OpenTelemetry openTelemetry) { + this.otelTracer = checkNotNull(openTelemetry.getTracer(OTEL_TRACING_SCOPE_NAME), "otelTracer"); + this.contextPropagators = checkNotNull(openTelemetry.getPropagators(), "contextPropagators"); + } + + /** + * Creates a {@link CallAttemptsTracerFactory} for a new call. + */ + @VisibleForTesting + CallAttemptsTracerFactory newClientCallTracer(Span clientSpan, MethodDescriptor method) { + return new CallAttemptsTracerFactory(clientSpan, method); + } + + /** + * Returns the server tracer factory. + */ + ServerStreamTracer.Factory getServerTracerFactory() { + return serverTracerFactory; + } + + /** + * Returns the client interceptor that facilitates otel tracing reporting. + */ + ClientInterceptor getClientInterceptor() { + return clientInterceptor; + } + + @VisibleForTesting + final class CallAttemptsTracerFactory extends ClientStreamTracer.Factory { + volatile int callEnded; + private final Span clientSpan; + private final String fullMethodName; + + CallAttemptsTracerFactory(Span clientSpan, MethodDescriptor method) { + checkNotNull(method, "method"); + this.fullMethodName = checkNotNull(method.getFullMethodName(), "fullMethodName"); + this.clientSpan = checkNotNull(clientSpan, "clientSpan"); + } + + @Override + public ClientStreamTracer newClientStreamTracer( + ClientStreamTracer.StreamInfo info, Metadata headers) { + Span attemptSpan = otelTracer.spanBuilder( + "Attempt." + fullMethodName.replace('/', '.')) + .setParent(Context.current().with(clientSpan)) + .startSpan(); + attemptSpan.setAttribute( + "previous-rpc-attempts", info.getPreviousAttempts()); + attemptSpan.setAttribute( + "transparent-retry",info.isTransparentRetry()); + if (info.getCallOptions().getOption(NAME_RESOLUTION_DELAYED) != null) { + clientSpan.addEvent("Delayed name resolution complete"); + } + return new ClientTracer(attemptSpan, clientSpan); + } + + /** + * Record a finished call and mark the current time as the end time. + * + *

Can be called from any thread without synchronization. Calling it the second time or more + * is a no-op. + */ + void callEnded(io.grpc.Status status) { + if (callEndedUpdater != null) { + if (callEndedUpdater.getAndSet(this, 1) != 0) { + return; + } + } else { + if (callEnded != 0) { + return; + } + callEnded = 1; + } + endSpanWithStatus(clientSpan, status); + } + } + + private final class ClientTracer extends ClientStreamTracer { + private final Span span; + private final Span parentSpan; + volatile int seqNo; + boolean isPendingStream; + + ClientTracer(Span span, Span parentSpan) { + this.span = checkNotNull(span, "span"); + this.parentSpan = checkNotNull(parentSpan, "parent span"); + } + + @Override + public void streamCreated(Attributes transportAtts, Metadata headers) { + contextPropagators.getTextMapPropagator().inject(Context.current().with(span), headers, + metadataSetter); + if (isPendingStream) { + span.addEvent("Delayed LB pick complete"); + } + } + + @Override + public void createPendingStream() { + isPendingStream = true; + } + + @Override + public void outboundMessageSent( + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + recordOutboundMessageSentEvent(span, seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void inboundMessageRead( + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + //TODO(yifeizhuang): needs support from message deframer. + if (optionalWireSize != optionalUncompressedSize) { + recordInboundCompressedMessage(span, seqNo, optionalWireSize); + } + } + + @Override + public void inboundMessage(int seqNo) { + this.seqNo = seqNo; + } + + @Override + public void inboundUncompressedSize(long bytes) { + recordInboundMessageSize(parentSpan, seqNo, bytes); + } + + @Override + public void streamClosed(io.grpc.Status status) { + endSpanWithStatus(span, status); + } + } + + private final class ServerTracer extends ServerStreamTracer { + private final Span span; + volatile int streamClosed; + private int seqNo; + + ServerTracer(String fullMethodName, @Nullable Span remoteSpan) { + checkNotNull(fullMethodName, "fullMethodName"); + this.span = + otelTracer.spanBuilder(generateTraceSpanName(true, fullMethodName)) + .setParent(remoteSpan == null ? null : Context.current().with(remoteSpan)) + .startSpan(); + } + + /** + * Record a finished stream and mark the current time as the end time. + * + *

Can be called from any thread without synchronization. Calling it the second time or more + * is a no-op. + */ + @Override + public void streamClosed(io.grpc.Status status) { + if (streamClosedUpdater != null) { + if (streamClosedUpdater.getAndSet(this, 1) != 0) { + return; + } + } else { + if (streamClosed != 0) { + return; + } + streamClosed = 1; + } + endSpanWithStatus(span, status); + } + + @Override + public void outboundMessageSent( + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + recordOutboundMessageSentEvent(span, seqNo, optionalWireSize, optionalUncompressedSize); + } + + @Override + public void inboundMessageRead( + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + if (optionalWireSize != optionalUncompressedSize) { + recordInboundCompressedMessage(span, seqNo, optionalWireSize); + } + } + + @Override + public void inboundMessage(int seqNo) { + this.seqNo = seqNo; + } + + @Override + public void inboundUncompressedSize(long bytes) { + recordInboundMessageSize(span, seqNo, bytes); + } + } + + @VisibleForTesting + final class ServerTracerFactory extends ServerStreamTracer.Factory { + @SuppressWarnings("ReferenceEquality") + @Override + public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata headers) { + Context context = contextPropagators.getTextMapPropagator().extract( + Context.current(), headers, metadataGetter + ); + Span remoteSpan = Span.fromContext(context); + if (remoteSpan == Span.getInvalid()) { + remoteSpan = null; + } + return new ServerTracer(fullMethodName, remoteSpan); + } + } + + @VisibleForTesting + final class TracingClientInterceptor implements ClientInterceptor { + + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + Span clientSpan = otelTracer.spanBuilder( + generateTraceSpanName(false, method.getFullMethodName())) + .startSpan(); + + final CallAttemptsTracerFactory tracerFactory = newClientCallTracer(clientSpan, method); + ClientCall call = + next.newCall( + method, + callOptions.withStreamTracerFactory(tracerFactory)); + return new SimpleForwardingClientCall(call) { + @Override + public void start(Listener responseListener, Metadata headers) { + delegate().start( + new SimpleForwardingClientCallListener(responseListener) { + @Override + public void onClose(io.grpc.Status status, Metadata trailers) { + tracerFactory.callEnded(status); + super.onClose(status, trailers); + } + }, + headers); + } + }; + } + } + + // Attribute named "message-size" always means the message size the application sees. + // If there was compression, additional event reports "message-size-compressed". + // + // An example trace with message compression: + // + // Sending: + // |-- Event 'Outbound message sent', attributes('sequence-numer' = 0, 'message-size' = 7854, + // 'message-size-compressed' = 5493) ----| + // + // Receiving: + // |-- Event 'Inbound compressed message', attributes('sequence-numer' = 0, + // 'message-size-compressed' = 5493 ) ----| + // |-- Event 'Inbound message received', attributes('sequence-numer' = 0, + // 'message-size' = 7854) ----| + // + // An example trace with no message compression: + // + // Sending: + // |-- Event 'Outbound message sent', attributes('sequence-numer' = 0, 'message-size' = 7854) ---| + // + // Receiving: + // |-- Event 'Inbound message received', attributes('sequence-numer' = 0, + // 'message-size' = 7854) ----| + private void recordOutboundMessageSentEvent(Span span, + int seqNo, long optionalWireSize, long optionalUncompressedSize) { + AttributesBuilder attributesBuilder = io.opentelemetry.api.common.Attributes.builder(); + attributesBuilder.put("sequence-number", seqNo); + if (optionalUncompressedSize != -1) { + attributesBuilder.put("message-size", optionalUncompressedSize); + } + if (optionalWireSize != -1 && optionalWireSize != optionalUncompressedSize) { + attributesBuilder.put("message-size-compressed", optionalWireSize); + } + span.addEvent("Outbound message sent", attributesBuilder.build()); + } + + private void recordInboundCompressedMessage(Span span, int seqNo, long optionalWireSize) { + AttributesBuilder attributesBuilder = io.opentelemetry.api.common.Attributes.builder(); + attributesBuilder.put("sequence-number", seqNo); + attributesBuilder.put("message-size-compressed", optionalWireSize); + span.addEvent("Inbound compressed message", attributesBuilder.build()); + } + + private void recordInboundMessageSize(Span span, int seqNo, long bytes) { + AttributesBuilder attributesBuilder = io.opentelemetry.api.common.Attributes.builder(); + attributesBuilder.put("sequence-number", seqNo); + attributesBuilder.put("message-size", bytes); + span.addEvent("Inbound message received", attributesBuilder.build()); + } + + private String generateErrorStatusDescription(io.grpc.Status status) { + if (status.getDescription() != null) { + return status.getCode() + ": " + status.getDescription(); + } else { + return status.getCode().toString(); + } + } + + private void endSpanWithStatus(Span span, io.grpc.Status status) { + if (status.isOk()) { + span.setStatus(StatusCode.OK); + } else { + span.setStatus(StatusCode.ERROR, generateErrorStatusDescription(status)); + } + span.end(); + } + + /** + * Convert a full method name to a tracing span name. + * + * @param isServer {@code false} if the span is on the client-side, {@code true} if on the + * server-side + * @param fullMethodName the method name as returned by + * {@link MethodDescriptor#getFullMethodName}. + */ + @VisibleForTesting + static String generateTraceSpanName(boolean isServer, String fullMethodName) { + String prefix = isServer ? "Recv" : "Sent"; + return prefix + "." + fullMethodName.replace('/', '.'); + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagatorTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagatorTest.java new file mode 100644 index 00000000000..f85b8067c26 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/GrpcTraceBinContextPropagatorTest.java @@ -0,0 +1,313 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; +import static org.junit.Assert.assertTrue; + +import com.google.common.collect.ImmutableMap; +import io.grpc.Metadata; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanContext; +import io.opentelemetry.api.trace.TraceFlags; +import io.opentelemetry.api.trace.TraceState; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.propagation.TextMapGetter; +import io.opentelemetry.context.propagation.TextMapSetter; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import javax.annotation.Nullable; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcTraceBinContextPropagatorTest { + private static final String TRACE_ID_BASE16 = "e384981d65129fa3e384981d65129fa3"; + private static final String SPAN_ID_BASE16 = "e384981d65129fa3"; + private static final String TRACE_HEADER_SAMPLED = + "0000" + TRACE_ID_BASE16 + "01" + SPAN_ID_BASE16 + "0201"; + private static final String TRACE_HEADER_NOT_SAMPLED = + "0000" + TRACE_ID_BASE16 + "01" + SPAN_ID_BASE16 + "0200"; + private final String goldenHeaderEncodedSampled = encode(TRACE_HEADER_SAMPLED); + private final String goldenHeaderEncodedNotSampled = encode(TRACE_HEADER_NOT_SAMPLED); + private static final TextMapSetter> setter = Map::put; + private static final TextMapGetter> getter = + new TextMapGetter>() { + @Override + public Iterable keys(Map carrier) { + return carrier.keySet(); + } + + @Nullable + @Override + public String get(Map carrier, String key) { + return carrier.get(key); + } + }; + private final GrpcTraceBinContextPropagator grpcTraceBinContextPropagator = + GrpcTraceBinContextPropagator.defaultInstance(); + + private static Context withSpanContext(SpanContext spanContext, Context context) { + return context.with(Span.wrap(spanContext)); + } + + private static SpanContext getSpanContext(Context context) { + return Span.fromContext(context).getSpanContext(); + } + + @Test + public void inject_map_Nothing() { + Map carrier = new HashMap<>(); + grpcTraceBinContextPropagator.inject(Context.current(), carrier, setter); + assertThat(carrier).hasSize(0); + } + + @Test + public void inject_map_invalidSpan() { + Map carrier = new HashMap<>(); + Context context = withSpanContext(SpanContext.getInvalid(), Context.current()); + grpcTraceBinContextPropagator.inject(context, carrier, setter); + assertThat(carrier).isEmpty(); + } + + @Test + public void inject_map_nullCarrier() { + Map carrier = new HashMap<>(); + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getSampled(), TraceState.getDefault()), + Context.current()); + grpcTraceBinContextPropagator.inject(context, null, + (TextMapSetter>) (ignored, key, value) -> carrier.put(key, value)); + assertThat(carrier) + .containsExactly( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, goldenHeaderEncodedSampled); + } + + @Test + public void inject_map_nullContext() { + Map carrier = new HashMap<>(); + grpcTraceBinContextPropagator.inject(null, carrier, setter); + assertThat(carrier).isEmpty(); + } + + @Test + public void inject_map_invalidBinaryFormat() { + GrpcTraceBinContextPropagator propagator = new GrpcTraceBinContextPropagator( + new Metadata.BinaryMarshaller() { + @Override + public byte[] toBytes(SpanContext value) { + throw new IllegalArgumentException("failed to byte"); + } + + @Override + public SpanContext parseBytes(byte[] serialized) { + return null; + } + }); + Map carrier = new HashMap<>(); + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getSampled(), TraceState.getDefault()), + Context.current()); + propagator.inject(context, carrier, setter); + assertThat(carrier).hasSize(0); + } + + @Test + public void inject_map_SampledContext() { + verify_inject_map(TraceFlags.getSampled(), goldenHeaderEncodedSampled); + } + + @Test + public void inject_map_NotSampledContext() { + verify_inject_map(TraceFlags.getDefault(), goldenHeaderEncodedNotSampled); + } + + private void verify_inject_map(TraceFlags traceFlags, String goldenHeader) { + Map carrier = new HashMap<>(); + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, traceFlags, TraceState.getDefault()), + Context.current()); + grpcTraceBinContextPropagator.inject(context, carrier, setter); + assertThat(carrier) + .containsExactly( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, goldenHeader); + } + + @Test + public void extract_map_nothing() { + Map carrier = new HashMap<>(); + assertThat(grpcTraceBinContextPropagator.extract(Context.current(), carrier, getter)) + .isSameInstanceAs(Context.current()); + } + + @Test + public void extract_map_SampledContext() { + verify_extract_map(TraceFlags.getSampled(), goldenHeaderEncodedSampled); + } + + @Test + public void extract_map_NotSampledContext() { + verify_extract_map(TraceFlags.getDefault(), goldenHeaderEncodedNotSampled); + } + + private void verify_extract_map(TraceFlags traceFlags, String goldenHeader) { + Map carrier = ImmutableMap.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, goldenHeader); + Context result = grpcTraceBinContextPropagator.extract(Context.current(), carrier, getter); + assertThat(getSpanContext(result)).isEqualTo(SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, traceFlags, TraceState.getDefault())); + } + + @Test + public void inject_metadata_Nothing() { + Metadata carrier = new Metadata(); + grpcTraceBinContextPropagator.inject(Context.current(), carrier, MetadataSetter.getInstance()); + assertThat(carrier.keys()).isEmpty(); + } + + @Test + public void inject_metadata_nullCarrier() { + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, TraceFlags.getSampled(), TraceState.getDefault()), + Context.current()); + grpcTraceBinContextPropagator.inject(context, null, MetadataSetter.getInstance()); + } + + @Test + public void inject_metadata_invalidSpan() { + Metadata carrier = new Metadata(); + Context context = withSpanContext(SpanContext.getInvalid(), Context.current()); + grpcTraceBinContextPropagator.inject(context, carrier, MetadataSetter.getInstance()); + assertThat(carrier.keys()).isEmpty(); + } + + @Test + public void inject_metadata_SampledContext() { + verify_inject_metadata(TraceFlags.getSampled(), hexStringToByteArray(TRACE_HEADER_SAMPLED)); + } + + @Test + public void inject_metadataSetter_NotSampledContext() { + verify_inject_metadata(TraceFlags.getDefault(), hexStringToByteArray(TRACE_HEADER_NOT_SAMPLED)); + } + + private void verify_inject_metadata(TraceFlags traceFlags, byte[] bytes) { + Metadata metadata = new Metadata(); + Context context = + withSpanContext( + SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, traceFlags, TraceState.getDefault()), + Context.current()); + grpcTraceBinContextPropagator.inject(context, metadata, MetadataSetter.getInstance()); + byte[] injected = metadata.get(Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER)); + assertTrue(Arrays.equals(injected, bytes)); + } + + @Test + public void extract_metadata_nothing() { + assertThat(grpcTraceBinContextPropagator.extract( + Context.current(), new Metadata(), MetadataGetter.getInstance())) + .isSameInstanceAs(Context.current()); + } + + @Test + public void extract_metadata_nullCarrier() { + assertThat(grpcTraceBinContextPropagator.extract( + Context.current(), null, MetadataGetter.getInstance())) + .isSameInstanceAs(Context.current()); + } + + @Test + public void extract_metadata_SampledContext() { + verify_extract_metadata(TraceFlags.getSampled(), TRACE_HEADER_SAMPLED); + } + + @Test + public void extract_metadataGetter_NotSampledContext() { + verify_extract_metadata(TraceFlags.getDefault(), TRACE_HEADER_NOT_SAMPLED); + } + + private void verify_extract_metadata(TraceFlags traceFlags, String hex) { + Metadata carrier = new Metadata(); + carrier.put(Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER), + hexStringToByteArray(hex)); + Context result = grpcTraceBinContextPropagator.extract(Context.current(), carrier, + MetadataGetter.getInstance()); + assertThat(getSpanContext(result)).isEqualTo(SpanContext.create( + TRACE_ID_BASE16, SPAN_ID_BASE16, traceFlags, TraceState.getDefault())); + } + + @Test + public void extract_metadata_invalidBinaryFormat() { + GrpcTraceBinContextPropagator propagator = new GrpcTraceBinContextPropagator( + new Metadata.BinaryMarshaller() { + @Override + public byte[] toBytes(SpanContext value) { + return new byte[0]; + } + + @Override + public SpanContext parseBytes(byte[] serialized) { + throw new IllegalArgumentException("failed to byte"); + } + }); + Metadata carrier = new Metadata(); + carrier.put(Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER), + hexStringToByteArray(TRACE_HEADER_SAMPLED)); + assertThat(propagator.extract(Context.current(), carrier, MetadataGetter.getInstance())) + .isSameInstanceAs(Context.current()); + } + + @Test + public void extract_metadata_invalidBinaryFormatVersion() { + Metadata carrier = new Metadata(); + carrier.put(Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER), + hexStringToByteArray("0100" + TRACE_ID_BASE16 + "01" + SPAN_ID_BASE16 + "0201")); + assertThat(grpcTraceBinContextPropagator.extract( + Context.current(), carrier, MetadataGetter.getInstance())) + .isSameInstanceAs(Context.current()); + } + + private static String encode(String hex) { + return BASE64_ENCODING_OMIT_PADDING.encode(hexStringToByteArray(hex)); + } + + private static byte[] hexStringToByteArray(String s) { + int len = s.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(s.charAt(i), 16) << 4) + + Character.digit(s.charAt(i + 1), 16)); + } + return data; + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataGetterTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataGetterTest.java new file mode 100644 index 00000000000..5934240e5c2 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataGetterTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import io.grpc.Metadata; +import java.nio.charset.Charset; +import java.util.Iterator; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MetadataGetterTest { + private final MetadataGetter metadataGetter = MetadataGetter.getInstance(); + + @Test + public void getBinaryGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadata.put(grpc_trace_bin_key, b); + assertArrayEquals(b, metadataGetter.getBinary(metadata, "grpc-trace-bin")); + } + + @Test + public void getBinaryEmptyMetadata() { + assertNull(metadataGetter.getBinary(new Metadata(), "grpc-trace-bin")); + } + + @Test + public void getBinaryNotGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("another-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadata.put(grpc_trace_bin_key, b); + assertNull(metadataGetter.getBinary(metadata, "another-bin")); + } + + @Test + public void getTextEmptyMetadata() { + assertNull(metadataGetter.get(new Metadata(), "a-key")); + } + + @Test + public void getTextBinHeader() { + assertNull(metadataGetter.get(new Metadata(), "a-key-bin")); + } + + @Test + public void getTestGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadata.put(grpc_trace_bin_key, b); + assertEquals(BASE64_ENCODING_OMIT_PADDING.encode(b), + metadataGetter.get(metadata, "grpc-trace-bin")); + } + + @Test + public void getText() { + Metadata metadata = new Metadata(); + Metadata.Key other_key = + Metadata.Key.of("other", Metadata.ASCII_STRING_MARSHALLER); + metadata.put(other_key, "header-value"); + assertEquals("header-value", metadataGetter.get(metadata, "other")); + + Iterator iterator = metadataGetter.keys(metadata).iterator(); + assertTrue(iterator.hasNext()); + assertEquals("other", iterator.next()); + assertFalse(iterator.hasNext()); + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataSetterTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataSetterTest.java new file mode 100644 index 00000000000..fcd85480bb9 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/MetadataSetterTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static io.grpc.InternalMetadata.BASE64_ENCODING_OMIT_PADDING; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +import io.grpc.Metadata; +import java.nio.charset.Charset; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MetadataSetterTest { + private final MetadataSetter metadataSetter = MetadataSetter.getInstance(); + + @Test + public void setGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadataSetter.set(metadata, "grpc-trace-bin", b); + assertArrayEquals(b, metadata.get(grpc_trace_bin_key)); + } + + @Test + public void setOtherBinaryKey() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + Metadata.Key other_key = + Metadata.Key.of("for-test-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadataSetter.set(metadata, other_key.name(), b); + assertNull(metadata.get(other_key)); + } + + @Test + public void setText() { + Metadata metadata = new Metadata(); + String v = "generated"; + Metadata.Key textKey = + Metadata.Key.of("text-key", Metadata.ASCII_STRING_MARSHALLER); + metadataSetter.set(metadata, textKey.name(), v); + assertEquals(metadata.get(textKey), v); + } + + @Test + public void setTextBin() { + Metadata metadata = new Metadata(); + Metadata.Key other_key = + Metadata.Key.of("for-test-bin", Metadata.BINARY_BYTE_MARSHALLER); + metadataSetter.set(metadata, other_key.name(), "generated"); + assertNull(metadata.get(other_key)); + } + + @Test + public void setTextGrpcTraceBin() { + Metadata metadata = new Metadata(); + byte[] b = "generated".getBytes(Charset.defaultCharset()); + metadataSetter.set(metadata, "grpc-trace-bin", BASE64_ENCODING_OMIT_PADDING.encode(b)); + + Metadata.Key grpc_trace_bin_key = + Metadata.Key.of("grpc-trace-bin", Metadata.BINARY_BYTE_MARSHALLER); + assertArrayEquals(metadata.get(grpc_trace_bin_key), b); + } +} diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java new file mode 100644 index 00000000000..68cba17e802 --- /dev/null +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java @@ -0,0 +1,582 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.opentelemetry; + +import static io.grpc.ClientStreamTracer.NAME_RESOLUTION_DELAYED; +import static io.grpc.opentelemetry.OpenTelemetryTracingModule.OTEL_TRACING_SCOPE_NAME; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableSet; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ClientInterceptors; +import io.grpc.ClientStreamTracer; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerServiceDefinition; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.opentelemetry.OpenTelemetryTracingModule.CallAttemptsTracerFactory; +import io.grpc.testing.GrpcServerRule; +import io.opentelemetry.api.OpenTelemetry; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanBuilder; +import io.opentelemetry.api.trace.SpanId; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.TraceId; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.ContextPropagators; +import io.opentelemetry.context.propagation.TextMapPropagator; +import io.opentelemetry.sdk.testing.junit4.OpenTelemetryRule; +import io.opentelemetry.sdk.trace.data.EventData; +import io.opentelemetry.sdk.trace.data.SpanData; +import java.io.InputStream; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class OpenTelemetryTracingModuleTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + private static final ClientStreamTracer.StreamInfo STREAM_INFO = + ClientStreamTracer.StreamInfo.newBuilder() + .setCallOptions(CallOptions.DEFAULT.withOption(NAME_RESOLUTION_DELAYED, 10L)).build(); + private static final CallOptions.Key CUSTOM_OPTION = + CallOptions.Key.createWithDefault("option1", "default"); + private static final CallOptions CALL_OPTIONS = + CallOptions.DEFAULT.withOption(CUSTOM_OPTION, "customvalue"); + + private static class StringInputStream extends InputStream { + final String string; + + StringInputStream(String string) { + this.string = string; + } + + @Override + public int read() { + // InProcessTransport doesn't actually read bytes from the InputStream. The InputStream is + // passed to the InProcess server and consumed by MARSHALLER.parse(). + throw new UnsupportedOperationException("Should not be called"); + } + } + + private static final MethodDescriptor.Marshaller MARSHALLER = + new MethodDescriptor.Marshaller() { + @Override + public InputStream stream(String value) { + return new StringInputStream(value); + } + + @Override + public String parse(InputStream stream) { + return ((StringInputStream) stream).string; + } + }; + + private final MethodDescriptor method = + MethodDescriptor.newBuilder() + .setType(MethodDescriptor.MethodType.UNKNOWN) + .setRequestMarshaller(MARSHALLER) + .setResponseMarshaller(MARSHALLER) + .setFullMethodName("package1.service2/method3") + .build(); + + @Rule + public final OpenTelemetryRule openTelemetryRule = OpenTelemetryRule.create(); + @Rule + public final GrpcServerRule grpcServerRule = new GrpcServerRule().directExecutor(); + private Tracer tracerRule; + @Mock + private Tracer mockTracer; + @Mock + TextMapPropagator mockPropagator; + @Mock + private Span mockClientSpan; + @Mock + private Span mockAttemptSpan; + @Mock + private ServerCall.Listener mockServerCallListener; + @Mock + private ClientCall.Listener mockClientCallListener; + @Mock + private SpanBuilder mockSpanBuilder; + @Mock + private OpenTelemetry mockOpenTelemetry; + @Captor + private ArgumentCaptor eventNameCaptor; + @Captor + private ArgumentCaptor attributesCaptor; + @Captor + private ArgumentCaptor statusCaptor; + + @Before + public void setUp() { + tracerRule = openTelemetryRule.getOpenTelemetry().getTracer(OTEL_TRACING_SCOPE_NAME); + when(mockOpenTelemetry.getTracer(OTEL_TRACING_SCOPE_NAME)).thenReturn(mockTracer); + when(mockOpenTelemetry.getPropagators()).thenReturn(ContextPropagators.create(mockPropagator)); + when(mockSpanBuilder.startSpan()).thenReturn(mockAttemptSpan); + when(mockSpanBuilder.setParent(any())).thenReturn(mockSpanBuilder); + when(mockTracer.spanBuilder(any())).thenReturn(mockSpanBuilder); + } + + // Use mock instead of OpenTelemetryRule to verify inOrder and propagator. + @Test + public void clientBasicTracingMocking() { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(mockOpenTelemetry); + CallAttemptsTracerFactory callTracer = + tracingModule.newClientCallTracer(mockClientSpan, method); + Metadata headers = new Metadata(); + ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.createPendingStream(); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); + + verify(mockTracer).spanBuilder(eq("Attempt.package1.service2.method3")); + verify(mockPropagator).inject(any(), eq(headers), eq(MetadataSetter.getInstance())); + verify(mockClientSpan, never()).end(); + verify(mockAttemptSpan, never()).end(); + + clientStreamTracer.outboundMessage(0); + clientStreamTracer.outboundMessageSent(0, 882, -1); + clientStreamTracer.inboundMessage(0); + clientStreamTracer.outboundMessage(1); + clientStreamTracer.outboundMessageSent(1, -1, 27); + clientStreamTracer.inboundMessageRead(0, 255, 90); + + clientStreamTracer.streamClosed(Status.OK); + callTracer.callEnded(Status.OK); + + InOrder inOrder = inOrder(mockClientSpan, mockAttemptSpan); + inOrder.verify(mockAttemptSpan) + .setAttribute("previous-rpc-attempts", 0); + inOrder.verify(mockAttemptSpan) + .setAttribute("transparent-retry", false); + inOrder.verify(mockClientSpan).addEvent("Delayed name resolution complete"); + inOrder.verify(mockAttemptSpan).addEvent("Delayed LB pick complete"); + inOrder.verify(mockAttemptSpan, times(3)).addEvent( + eventNameCaptor.capture(), attributesCaptor.capture() + ); + List events = eventNameCaptor.getAllValues(); + List attributes = attributesCaptor.getAllValues(); + assertEquals( + "Outbound message sent" , + events.get(0)); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 882) + .build(), + attributes.get(0)); + + assertEquals( + "Outbound message sent" , + events.get(1)); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 1) + .put("message-size", 27) + .build(), + attributes.get(1)); + + assertEquals( + "Inbound compressed message" , + events.get(2)); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 255) + .build(), + attributes.get(2)); + + inOrder.verify(mockAttemptSpan).setStatus(StatusCode.OK); + inOrder.verify(mockAttemptSpan).end(); + inOrder.verify(mockClientSpan).setStatus(StatusCode.OK); + inOrder.verify(mockClientSpan).end(); + inOrder.verifyNoMoreInteractions(); + } + + @Test + public void clientBasicTracingRule() { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + Span clientSpan = tracerRule.spanBuilder("test-client-span").startSpan(); + CallAttemptsTracerFactory callTracer = + tracingModule.newClientCallTracer(clientSpan, method); + Metadata headers = new Metadata(); + ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.createPendingStream(); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); + clientStreamTracer.outboundMessage(0); + clientStreamTracer.outboundMessageSent(0, 882, -1); + clientStreamTracer.inboundMessage(0); + clientStreamTracer.outboundMessage(1); + clientStreamTracer.outboundMessageSent(1, -1, 27); + clientStreamTracer.inboundMessageRead(0, 255, -1); + clientStreamTracer.inboundUncompressedSize(288); + clientStreamTracer.inboundMessageRead(1, 128, 128); + clientStreamTracer.inboundMessage(1); + clientStreamTracer.inboundUncompressedSize(128); + + clientStreamTracer.streamClosed(Status.OK); + callTracer.callEnded(Status.OK); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 2); + SpanData attemptSpanData = spans.get(0); + SpanData clientSpanData = spans.get(1); + assertEquals(attemptSpanData.getName(), "Attempt.package1.service2.method3"); + assertEquals(clientSpanData.getName(), "test-client-span"); + assertEquals(headers.keys(), ImmutableSet.of("traceparent")); + String spanContext = headers.get( + Metadata.Key.of("traceparent", Metadata.ASCII_STRING_MARSHALLER)); + assertEquals(spanContext.substring(3, 3 + TraceId.getLength()), + spans.get(1).getSpanContext().getTraceId()); + + // parent(client) span data + List clientSpanEvents = clientSpanData.getEvents(); + assertEquals(clientSpanEvents.size(), 3); + assertEquals( + "Delayed name resolution complete", + clientSpanEvents.get(0).getName()); + assertTrue(clientSpanEvents.get(0).getAttributes().isEmpty()); + + assertEquals( + "Inbound message received" , + clientSpanEvents.get(1).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size", 288) + .build(), + clientSpanEvents.get(1).getAttributes()); + + assertEquals( + "Inbound message received" , + clientSpanEvents.get(2).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 1) + .put("message-size", 128) + .build(), + clientSpanEvents.get(2).getAttributes()); + assertEquals(clientSpanData.hasEnded(), true); + + // child(attempt) span data + List attemptSpanEvents = attemptSpanData.getEvents(); + assertEquals(clientSpanEvents.size(), 3); + assertEquals( + "Delayed LB pick complete", + attemptSpanEvents.get(0).getName()); + assertTrue(clientSpanEvents.get(0).getAttributes().isEmpty()); + + assertEquals( + "Outbound message sent" , + attemptSpanEvents.get(1).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 882) + .build(), + attemptSpanEvents.get(1).getAttributes()); + + assertEquals( + "Outbound message sent" , + attemptSpanEvents.get(2).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 1) + .put("message-size", 27) + .build(), + attemptSpanEvents.get(2).getAttributes()); + + assertEquals( + "Inbound compressed message" , + attemptSpanEvents.get(3).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 255) + .build(), + attemptSpanEvents.get(3).getAttributes()); + + assertEquals(attemptSpanData.hasEnded(), true); + } + + @Test + public void clientInterceptor() { + testClientInterceptors(false); + } + + @Test + public void clientInterceptorNonDefaultOtelContext() { + testClientInterceptors(true); + } + + private void testClientInterceptors(boolean nonDefaultOtelContext) { + final AtomicReference capturedMetadata = new AtomicReference<>(); + grpcServerRule.getServiceRegistry().addService( + ServerServiceDefinition.builder("package1.service2").addMethod( + method, new ServerCallHandler() { + @Override + public ServerCall.Listener startCall( + ServerCall call, Metadata headers) { + capturedMetadata.set(headers); + call.sendHeaders(new Metadata()); + call.sendMessage("Hello"); + call.close( + Status.PERMISSION_DENIED.withDescription("No you don't"), new Metadata()); + return mockServerCallListener; + } + }).build()); + + final AtomicReference capturedCallOptions = new AtomicReference<>(); + ClientInterceptor callOptionsCaptureInterceptor = new ClientInterceptor() { + @Override + public ClientCall interceptCall( + MethodDescriptor method, CallOptions callOptions, Channel next) { + capturedCallOptions.set(callOptions); + return next.newCall(method, callOptions); + } + }; + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + Channel interceptedChannel = + ClientInterceptors.intercept( + grpcServerRule.getChannel(), callOptionsCaptureInterceptor, + tracingModule.getClientInterceptor()); + Span parentSpan = tracerRule.spanBuilder("test-parent-span").startSpan(); + ClientCall call; + + if (nonDefaultOtelContext) { + try (Scope scope = io.opentelemetry.context.Context.current().with(parentSpan) + .makeCurrent()) { + call = interceptedChannel.newCall(method, CALL_OPTIONS); + } + } else { + call = interceptedChannel.newCall(method, CALL_OPTIONS); + } + assertEquals("customvalue", capturedCallOptions.get().getOption(CUSTOM_OPTION)); + assertEquals(1, capturedCallOptions.get().getStreamTracerFactories().size()); + assertTrue( + capturedCallOptions.get().getStreamTracerFactories().get(0) + instanceof CallAttemptsTracerFactory); + + // Make the call + Metadata headers = new Metadata(); + call.start(mockClientCallListener, headers); + + // End the call + call.halfClose(); + call.request(1); + parentSpan.end(); + + verify(mockClientCallListener).onClose(statusCaptor.capture(), any(Metadata.class)); + Status status = statusCaptor.getValue(); + assertEquals(Status.Code.PERMISSION_DENIED, status.getCode()); + assertEquals("No you don't", status.getDescription()); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 3); + + SpanData clientSpan = spans.get(1); + SpanData attemptSpan = spans.get(0); + if (nonDefaultOtelContext) { + assertEquals(clientSpan.getParentSpanContext(), parentSpan.getSpanContext()); + } else { + assertEquals(clientSpan.getParentSpanContext(), + Span.fromContext(Context.root()).getSpanContext()); + } + String spanContext = capturedMetadata.get().get( + Metadata.Key.of("traceparent", Metadata.ASCII_STRING_MARSHALLER)); + // W3C format: 00--- + assertEquals(spanContext.substring(3, 3 + TraceId.getLength()), + attemptSpan.getSpanContext().getTraceId()); + assertEquals(spanContext.substring(3 + TraceId.getLength() + 1, + 3 + TraceId.getLength() + 1 + SpanId.getLength()), + attemptSpan.getSpanContext().getSpanId()); + + assertEquals(attemptSpan.getParentSpanContext(), clientSpan.getSpanContext()); + assertTrue(clientSpan.hasEnded()); + assertEquals(clientSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(clientSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + assertTrue(attemptSpan.hasEnded()); + assertTrue(attemptSpan.hasEnded()); + assertEquals(attemptSpan.getStatus().getStatusCode(), StatusCode.ERROR); + assertEquals(attemptSpan.getStatus().getDescription(), "PERMISSION_DENIED: No you don't"); + } + + @Test + public void clientStreamNeverCreatedStillRecordTracing() { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(mockOpenTelemetry); + CallAttemptsTracerFactory callTracer = + tracingModule.newClientCallTracer(mockClientSpan, method); + + callTracer.callEnded(Status.DEADLINE_EXCEEDED.withDescription("3 seconds")); + verify(mockClientSpan).end(); + verify(mockClientSpan).setStatus(eq(StatusCode.ERROR), + eq("DEADLINE_EXCEEDED: 3 seconds")); + verifyNoMoreInteractions(mockClientSpan); + } + + @Test + public void serverBasicTracingNoHeaders() { + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule( + openTelemetryRule.getOpenTelemetry()); + ServerStreamTracer.Factory tracerFactory = tracingModule.getServerTracerFactory(); + ServerStreamTracer serverStreamTracer = + tracerFactory.newServerStreamTracer(method.getFullMethodName(), new Metadata()); + assertSame(Span.fromContext(Context.current()), Span.getInvalid()); + + serverStreamTracer.outboundMessage(0); + serverStreamTracer.outboundMessageSent(0, 882, 998); + serverStreamTracer.inboundMessage(0); + serverStreamTracer.outboundMessage(1); + serverStreamTracer.outboundMessageSent(1, -1, 27); + serverStreamTracer.inboundMessageRead(0, 90, -1); + serverStreamTracer.inboundUncompressedSize(255); + + serverStreamTracer.streamClosed(Status.CANCELLED); + + List spans = openTelemetryRule.getSpans(); + assertEquals(spans.size(), 1); + assertEquals(spans.get(0).getName(), "Recv.package1.service2.method3"); + assertEquals(spans.get(0).getParentSpanContext(), Span.getInvalid().getSpanContext()); + + List events = spans.get(0).getEvents(); + assertEquals(events.size(), 4); + assertEquals( + "Outbound message sent" , + events.get(0).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 882) + .put("message-size", 998) + .build(), + events.get(0).getAttributes()); + + assertEquals( + "Outbound message sent" , + events.get(1).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 1) + .put("message-size", 27) + .build(), + events.get(1).getAttributes()); + + assertEquals( + "Inbound compressed message" , + events.get(2).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size-compressed", 90) + .build(), + events.get(2).getAttributes()); + + assertEquals( + "Inbound message received" , + events.get(3).getName()); + assertEquals( + io.opentelemetry.api.common.Attributes.builder() + .put("sequence-number", 0) + .put("message-size", 255) + .build(), + events.get(3).getAttributes()); + + assertEquals(spans.get(0).hasEnded(), true); + } + + @Test + public void grpcTraceBinPropagator() { + when(mockOpenTelemetry.getPropagators()).thenReturn( + ContextPropagators.create(GrpcTraceBinContextPropagator.defaultInstance())); + ArgumentCaptor contextArgumentCaptor = ArgumentCaptor.forClass(Context.class); + OpenTelemetryTracingModule tracingModule = new OpenTelemetryTracingModule(mockOpenTelemetry); + Span testClientSpan = tracerRule.spanBuilder("test-client-span").startSpan(); + CallAttemptsTracerFactory callTracer = + tracingModule.newClientCallTracer(testClientSpan, method); + Span testAttemptSpan = tracerRule.spanBuilder("test-attempt-span").startSpan(); + when(mockSpanBuilder.startSpan()).thenReturn(testAttemptSpan); + + Metadata headers = new Metadata(); + ClientStreamTracer clientStreamTracer = callTracer.newClientStreamTracer(STREAM_INFO, headers); + clientStreamTracer.streamCreated(Attributes.EMPTY, headers); + clientStreamTracer.streamClosed(Status.CANCELLED); + + Metadata.Key key = Metadata.Key.of( + GrpcTraceBinContextPropagator.GRPC_TRACE_BIN_HEADER, Metadata.BINARY_BYTE_MARSHALLER); + assertTrue(Arrays.equals(BinaryFormat.getInstance().toBytes(testAttemptSpan.getSpanContext()), + headers.get(key) + )); + verify(mockSpanBuilder).setParent(contextArgumentCaptor.capture()); + assertEquals(testClientSpan, Span.fromContext(contextArgumentCaptor.getValue())); + + Span serverSpan = tracerRule.spanBuilder("test-server-span").startSpan(); + when(mockSpanBuilder.startSpan()).thenReturn(serverSpan); + ServerStreamTracer.Factory tracerFactory = tracingModule.getServerTracerFactory(); + ServerStreamTracer serverStreamTracer = + tracerFactory.newServerStreamTracer(method.getFullMethodName(), headers); + serverStreamTracer.streamClosed(Status.CANCELLED); + + verify(mockSpanBuilder, times(2)) + .setParent(contextArgumentCaptor.capture()); + assertEquals(testAttemptSpan.getSpanContext(), + Span.fromContext(contextArgumentCaptor.getValue()).getSpanContext()); + } + + @Test + public void generateTraceSpanName() { + assertEquals( + "Sent.io.grpc.Foo", OpenTelemetryTracingModule.generateTraceSpanName( + false, "io.grpc/Foo")); + assertEquals( + "Recv.io.grpc.Bar", OpenTelemetryTracingModule.generateTraceSpanName( + true, "io.grpc/Bar")); + } +} diff --git a/protobuf-lite/BUILD.bazel b/protobuf-lite/BUILD.bazel index 087723e95fb..dad794e8b58 100644 --- a/protobuf-lite/BUILD.bazel +++ b/protobuf-lite/BUILD.bazel @@ -10,7 +10,6 @@ java_library( "//api", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), ] + select({ ":android": ["@com_google_protobuf//:protobuf_javalite"], "//conditions:default": ["@com_google_protobuf//:protobuf_java"], diff --git a/protobuf/BUILD.bazel b/protobuf/BUILD.bazel index 47cc8f9d032..724c78ca6ee 100644 --- a/protobuf/BUILD.bazel +++ b/protobuf/BUILD.bazel @@ -13,6 +13,5 @@ java_library( artifact("com.google.api.grpc:proto-google-common-protos"), artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), ], ) diff --git a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java index 0dcffadeb40..d0661ba3be8 100644 --- a/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java +++ b/rls/src/main/java/io/grpc/rls/CachingRlsLbClient.java @@ -329,7 +329,7 @@ final CachedRouteLookupResponse get(final RouteLookupRequest request) { final CacheEntry cacheEntry; cacheEntry = linkedHashLruCache.read(request); if (cacheEntry == null) { - logger.log(ChannelLogLevel.DEBUG, "No cache entry found, making a new lrs request"); + logger.log(ChannelLogLevel.DEBUG, "No cache entry found, making a new RLS request"); PendingCacheEntry pendingEntry = pendingCallCache.get(request); if (pendingEntry != null) { return CachedRouteLookupResponse.pendingResponse(pendingEntry); @@ -988,7 +988,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { new Object[]{serviceName, methodName, args.getHeaders(), response}); if (response.getHeaderData() != null && !response.getHeaderData().isEmpty()) { - logger.log(ChannelLogLevel.DEBUG, "Updating LRS metadata from the LRS response headers"); + logger.log(ChannelLogLevel.DEBUG, "Updating RLS metadata from the RLS response headers"); Metadata headers = args.getHeaders(); headers.discardAll(RLS_DATA_KEY); headers.put(RLS_DATA_KEY, response.getHeaderData()); @@ -997,7 +997,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { logger.log(ChannelLogLevel.DEBUG, "defaultTarget = {0}", defaultTarget); boolean hasFallback = defaultTarget != null && !defaultTarget.isEmpty(); if (response.hasData()) { - logger.log(ChannelLogLevel.DEBUG, "LRS response has data, proceed with selecting a picker"); + logger.log(ChannelLogLevel.DEBUG, "RLS response has data, proceed with selecting a picker"); ChildPolicyWrapper childPolicyWrapper = response.getChildPolicyWrapper(); SubchannelPicker picker = (childPolicyWrapper != null) ? childPolicyWrapper.getPicker() : null; diff --git a/services/build.gradle b/services/build.gradle index de716c9fa1d..fade7aef3fb 100644 --- a/services/build.gradle +++ b/services/build.gradle @@ -27,11 +27,10 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-util'), - libraries.protobuf.java.util, - libraries.guava.jre // JRE required by protobuf-java-util + libraries.guava.jre, // JRE required by protobuf-java-util + libraries.protobuf.java.util runtimeOnly libraries.errorprone.annotations, - libraries.j2objc.annotations, // Explicit dependency to keep in step with version used by guava libraries.gson // to fix checkUpperBoundDeps error here compileOnly libraries.javax.annotation testImplementation project(':grpc-testing'), diff --git a/stub/BUILD.bazel b/stub/BUILD.bazel index 8950a1cfd3f..6d06e01f918 100644 --- a/stub/BUILD.bazel +++ b/stub/BUILD.bazel @@ -12,7 +12,6 @@ java_library( artifact("com.google.code.findbugs:jsr305"), artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), ], ) diff --git a/stub/src/main/java/io/grpc/stub/MetadataUtils.java b/stub/src/main/java/io/grpc/stub/MetadataUtils.java index addf54c0f81..4208d3ca652 100644 --- a/stub/src/main/java/io/grpc/stub/MetadataUtils.java +++ b/stub/src/main/java/io/grpc/stub/MetadataUtils.java @@ -22,10 +22,15 @@ import io.grpc.Channel; import io.grpc.ClientCall; import io.grpc.ClientInterceptor; +import io.grpc.ExperimentalApi; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.Metadata; import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.Status; import java.util.concurrent.atomic.AtomicReference; @@ -143,4 +148,63 @@ public void onClose(Status status, Metadata trailers) { } } } + + /** + * Returns a ServerInterceptor that adds the specified Metadata to every response stream, one way + * or another. + * + *

If, absent this interceptor, a stream would have headers, 'extras' will be added to those + * headers. Otherwise, 'extras' will be sent as trailers. This pattern is useful when you have + * some fixed information, server identity say, that should be included no matter how the call + * turns out. The fallback to trailers avoids artificially committing clients to error responses + * that could otherwise be retried (see https://grpc.io/docs/guides/retry/ for more). + * + *

For correct operation, be sure to arrange for this interceptor to run *before* any others + * that might add headers. + * + * @param extras the Metadata to be added to each stream. Caller gives up ownership. + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/11462") + public static ServerInterceptor newAttachMetadataServerInterceptor(Metadata extras) { + return new MetadataAttachingServerInterceptor(extras); + } + + private static final class MetadataAttachingServerInterceptor implements ServerInterceptor { + + private final Metadata extras; + + MetadataAttachingServerInterceptor(Metadata extras) { + this.extras = extras; + } + + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + return next.startCall(new MetadataAttachingServerCall<>(call), headers); + } + + final class MetadataAttachingServerCall + extends SimpleForwardingServerCall { + boolean headersSent; + + MetadataAttachingServerCall(ServerCall delegate) { + super(delegate); + } + + @Override + public void sendHeaders(Metadata headers) { + headers.merge(extras); + headersSent = true; + super.sendHeaders(headers); + } + + @Override + public void close(Status status, Metadata trailers) { + if (!headersSent) { + trailers.merge(extras); + } + super.close(status, trailers); + } + } + } } diff --git a/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java b/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java new file mode 100644 index 00000000000..f9890ac0433 --- /dev/null +++ b/stub/src/test/java/io/grpc/stub/MetadataUtilsTest.java @@ -0,0 +1,175 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.stub; + +import static com.google.common.truth.Truth.assertThat; +import static io.grpc.stub.MetadataUtils.newAttachMetadataServerInterceptor; +import static io.grpc.stub.MetadataUtils.newCaptureMetadataInterceptor; +import static org.junit.Assert.fail; + +import com.google.common.collect.ImmutableList; +import io.grpc.CallOptions; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptors; +import io.grpc.ServerMethodDefinition; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; +import io.grpc.StringMarshaller; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.testing.GrpcCleanupRule; +import java.io.IOException; +import java.util.Iterator; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MetadataUtilsTest { + + @Rule public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + private static final String SERVER_NAME = "test"; + private static final Metadata.Key FOO_KEY = + Metadata.Key.of("foo-key", Metadata.ASCII_STRING_MARSHALLER); + + private final MethodDescriptor echoMethod = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/echo") + .setType(MethodDescriptor.MethodType.UNARY) + .build(); + + private final ServerCallHandler echoCallHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + respObserver.onNext(req); + respObserver.onCompleted(); + }); + + MethodDescriptor echoServerStreamingMethod = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/echoStream") + .setType(MethodDescriptor.MethodType.SERVER_STREAMING) + .build(); + + private final AtomicReference trailersCapture = new AtomicReference<>(); + private final AtomicReference headersCapture = new AtomicReference<>(); + + @Test + public void shouldAttachHeadersToResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test").addMethod(echoMethod, echoCallHandler).build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + + String response = + ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello"); + assertThat(response).isEqualTo("hello"); + assertThat(trailersCapture.get() == null || !trailersCapture.get().containsKey(FOO_KEY)) + .isTrue(); + assertThat(headersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + @Test + public void shouldAttachTrailersWhenNoResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test") + .addMethod( + ServerMethodDefinition.create( + echoServerStreamingMethod, + ServerCalls.asyncUnaryCall( + (req, respObserver) -> respObserver.onCompleted()))) + .build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + + Iterator response = + ClientCalls.blockingServerStreamingCall( + channel, echoServerStreamingMethod, CallOptions.DEFAULT, "hello"); + assertThat(response.hasNext()).isFalse(); + assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue(); + assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + @Test + public void shouldAttachTrailersToErrorResponse() throws IOException { + Metadata extras = new Metadata(); + extras.put(FOO_KEY, "foo-value"); + + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test") + .addMethod( + echoMethod, + ServerCalls.asyncUnaryCall( + (req, respObserver) -> + respObserver.onError(Status.INVALID_ARGUMENT.asRuntimeException()))) + .build(), + ImmutableList.of(newAttachMetadataServerInterceptor(extras))); + grpcCleanup.register(newInProcessServerBuilder().addService(serviceDef).build().start()); + + ManagedChannel channel = + grpcCleanup.register( + newInProcessChannelBuilder() + .intercept(newCaptureMetadataInterceptor(headersCapture, trailersCapture)) + .build()); + try { + ClientCalls.blockingUnaryCall(channel, echoMethod, CallOptions.DEFAULT, "hello"); + fail(); + } catch (StatusRuntimeException e) { + assertThat(e.getStatus()).isNotNull(); + assertThat(e.getStatus().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + } + assertThat(headersCapture.get() == null || !headersCapture.get().containsKey(FOO_KEY)).isTrue(); + assertThat(trailersCapture.get().get(FOO_KEY)).isEqualTo("foo-value"); + } + + private static InProcessServerBuilder newInProcessServerBuilder() { + return InProcessServerBuilder.forName(SERVER_NAME).directExecutor(); + } + + private static InProcessChannelBuilder newInProcessChannelBuilder() { + return InProcessChannelBuilder.forName(SERVER_NAME).directExecutor(); + } +} diff --git a/testing/BUILD.bazel b/testing/BUILD.bazel index 668a666c2fe..78f9b840754 100644 --- a/testing/BUILD.bazel +++ b/testing/BUILD.bazel @@ -18,7 +18,6 @@ java_library( "//util", artifact("com.google.code.findbugs:jsr305"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), artifact("com.google.truth:truth"), artifact("junit:junit"), ], diff --git a/util/BUILD.bazel b/util/BUILD.bazel index 7a38063a983..8fb00e21d56 100644 --- a/util/BUILD.bazel +++ b/util/BUILD.bazel @@ -15,7 +15,6 @@ java_library( artifact("com.google.code.findbugs:jsr305"), artifact("com.google.errorprone:error_prone_annotations"), artifact("com.google.guava:guava"), - artifact("com.google.j2objc:j2objc-annotations"), artifact("org.codehaus.mojo:animal-sniffer-annotations"), ], ) diff --git a/util/src/main/java/io/grpc/util/ForwardingSubchannel.java b/util/src/main/java/io/grpc/util/ForwardingSubchannel.java index 51f2583186e..416be378162 100644 --- a/util/src/main/java/io/grpc/util/ForwardingSubchannel.java +++ b/util/src/main/java/io/grpc/util/ForwardingSubchannel.java @@ -74,11 +74,17 @@ public Object getInternalSubchannel() { return delegate().getInternalSubchannel(); } + @Override public void updateAddresses(List addrs) { delegate().updateAddresses(addrs); } + @Override + public Attributes getConnectedAddressAttributes() { + return delegate().getConnectedAddressAttributes(); + } + @Override public String toString() { return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString(); diff --git a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java index c5f774984fe..626c2e1104e 100644 --- a/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/MultiChildLoadBalancer.java @@ -16,7 +16,6 @@ package io.grpc.util; -import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.IDLE; @@ -26,7 +25,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.EquivalentAddressGroup; @@ -37,10 +35,10 @@ import io.grpc.internal.PickFirstLoadBalancerProvider; import java.net.SocketAddress; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -81,29 +79,27 @@ protected MultiChildLoadBalancer(Helper helper) { /** * Override to utilize parsing of the policy configuration or alternative helper/lb generation. + * Override this if keys are not Endpoints or if child policies have configuration. */ - protected Map createChildLbMap(ResolvedAddresses resolvedAddresses) { - Map childLbMap = new HashMap<>(); - List addresses = resolvedAddresses.getAddresses(); - for (EquivalentAddressGroup eag : addresses) { - Endpoint endpoint = new Endpoint(eag); // keys need to be just addresses - ChildLbState existingChildLbState = childLbStates.get(endpoint); - if (existingChildLbState != null) { - childLbMap.put(endpoint, existingChildLbState); - } else { - childLbMap.put(endpoint, - createChildLbState(endpoint, null, getInitialPicker(), resolvedAddresses)); - } - } - return childLbMap; + protected Map createChildAddressesMap( + ResolvedAddresses resolvedAddresses) { + Map childAddresses = new HashMap<>(); + for (EquivalentAddressGroup eag : resolvedAddresses.getAddresses()) { + ResolvedAddresses addresses = resolvedAddresses.toBuilder() + .setAddresses(Collections.singletonList(eag)) + .setAttributes(Attributes.newBuilder().set(IS_PETIOLE_POLICY, true).build()) + .setLoadBalancingPolicyConfig(null) + .build(); + childAddresses.put(new Endpoint(eag), addresses); + } + return childAddresses; } /** * Override to create an instance of a subclass. */ - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) { - return new ChildLbState(key, pickFirstLbProvider, policyConfig, initialPicker); + protected ChildLbState createChildLbState(Object key) { + return new ChildLbState(key, pickFirstLbProvider); } /** @@ -131,41 +127,6 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { } } - /** - * Override this if your keys are not of type Endpoint. - * @param key Key to identify the ChildLbState - * @param resolvedAddresses list of addresses which include attributes - * @param childConfig a load balancing policy config. This field is optional. - * @return a fully loaded ResolvedAddresses object for the specified key - */ - protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, - Object childConfig) { - Endpoint endpointKey; - if (key instanceof EquivalentAddressGroup) { - endpointKey = new Endpoint((EquivalentAddressGroup) key); - } else { - checkArgument(key instanceof Endpoint, "key is wrong type"); - endpointKey = (Endpoint) key; - } - - // Retrieve the non-stripped version - EquivalentAddressGroup eagToUse = null; - for (EquivalentAddressGroup currEag : resolvedAddresses.getAddresses()) { - if (endpointKey.equals(new Endpoint(currEag))) { - eagToUse = currEag; - break; - } - } - - checkNotNull(eagToUse, key + " no longer present in load balancer children"); - - return resolvedAddresses.toBuilder() - .setAddresses(Collections.singletonList(eagToUse)) - .setAttributes(Attributes.newBuilder().set(IS_PETIOLE_POLICY, true).build()) - .setLoadBalancingPolicyConfig(childConfig) - .build(); - } - /** * Handle the name resolution error. * @@ -174,37 +135,11 @@ protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses reso @Override public void handleNameResolutionError(Status error) { if (currentConnectivityState != READY) { - helper.updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); + helper.updateBalancingState( + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } } - /** - * Handle the name resolution error only for the specified child. - * - *

Override if you need special handling. - */ - protected void handleNameResolutionError(ChildLbState child, Status error) { - child.lb.handleNameResolutionError(error); - } - - /** - * Creates a picker representing the state before any connections have been established. - * - *

Override to produce a custom picker. - */ - protected SubchannelPicker getInitialPicker() { - return new FixedResultPicker(PickResult.withNoResult()); - } - - /** - * Creates a new picker representing an error status. - * - *

Override to produce a custom picker when there are errors. - */ - protected SubchannelPicker getErrorPicker(Status error) { - return new FixedResultPicker(PickResult.withError(error)); - } - @Override public void shutdown() { logger.log(Level.FINE, "Shutdown"); @@ -223,50 +158,38 @@ protected final AcceptResolvedAddrRetVal acceptResolvedAddressesInternal( ResolvedAddresses resolvedAddresses) { logger.log(Level.FINE, "Received resolution result: {0}", resolvedAddresses); - // Subclass handles any special manipulation to create appropriate types of keyed ChildLbStates - Map newChildren = createChildLbMap(resolvedAddresses); + Map newChildAddresses = createChildAddressesMap(resolvedAddresses); // Handle error case - if (newChildren.isEmpty()) { + if (newChildAddresses.isEmpty()) { Status unavailableStatus = Status.UNAVAILABLE.withDescription( "NameResolver returned no usable address. " + resolvedAddresses); handleNameResolutionError(unavailableStatus); return new AcceptResolvedAddrRetVal(unavailableStatus, null); } - addMissingChildren(newChildren); - - updateChildrenWithResolvedAddresses(resolvedAddresses, newChildren); - - return new AcceptResolvedAddrRetVal(Status.OK, getRemovedChildren(newChildren.keySet())); - } + updateChildrenWithResolvedAddresses(newChildAddresses); - protected final void addMissingChildren(Map newChildren) { - // Do adds and identify reused children - for (Map.Entry entry : newChildren.entrySet()) { - final Object key = entry.getKey(); - if (!childLbStates.containsKey(key)) { - childLbStates.put(key, entry.getValue()); - } - } + return new AcceptResolvedAddrRetVal(Status.OK, getRemovedChildren(newChildAddresses.keySet())); } - protected final void updateChildrenWithResolvedAddresses(ResolvedAddresses resolvedAddresses, - Map newChildren) { - for (Map.Entry entry : newChildren.entrySet()) { - Object childConfig = entry.getValue().getConfig(); + private void updateChildrenWithResolvedAddresses( + Map newChildAddresses) { + for (Map.Entry entry : newChildAddresses.entrySet()) { ChildLbState childLbState = childLbStates.get(entry.getKey()); - ResolvedAddresses childAddresses = - getChildAddresses(entry.getKey(), resolvedAddresses, childConfig); - childLbState.setResolvedAddresses(childAddresses); // update child - childLbState.lb.handleResolvedAddresses(childAddresses); // update child LB + if (childLbState == null) { + childLbState = createChildLbState(entry.getKey()); + childLbStates.put(entry.getKey(), childLbState); + } + childLbState.setResolvedAddresses(entry.getValue()); // update child + childLbState.lb.handleResolvedAddresses(entry.getValue()); // update child LB } } /** * Identifies which children have been removed (are not part of the newChildKeys). */ - protected final List getRemovedChildren(Set newChildKeys) { + private List getRemovedChildren(Set newChildKeys) { List removedChildren = new ArrayList<>(); // Do removals for (Object key : ImmutableList.copyOf(childLbStates.keySet())) { @@ -308,11 +231,6 @@ protected final Helper getHelper() { return helper; } - @VisibleForTesting - public final ImmutableMap getImmutableChildMap() { - return ImmutableMap.copyOf(childLbStates); - } - @VisibleForTesting public final Collection getChildLbStates() { return childLbStates.values(); @@ -361,17 +279,13 @@ protected final List getReadyChildren() { public class ChildLbState { private final Object key; private ResolvedAddresses resolvedAddresses; - private final Object config; private final LoadBalancer lb; private ConnectivityState currentState; - private SubchannelPicker currentPicker; + private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); - public ChildLbState(Object key, LoadBalancer.Factory policyFactory, Object childConfig, - SubchannelPicker initialPicker) { + public ChildLbState(Object key, LoadBalancer.Factory policyFactory) { this.key = key; - this.currentPicker = initialPicker; - this.config = childConfig; this.lb = policyFactory.newLoadBalancer(createChildHelper()); this.currentState = CONNECTING; } @@ -411,13 +325,6 @@ public final SubchannelPicker getCurrentPicker() { return currentPicker; } - protected final Subchannel getSubchannels(PickSubchannelArgs args) { - if (getCurrentPicker() == null) { - return null; - } - return getCurrentPicker().pickSubchannel(args).getSubchannel(); - } - public final ConnectivityState getCurrentState() { return currentState; } @@ -442,10 +349,6 @@ protected final void setResolvedAddresses(ResolvedAddresses newAddresses) { resolvedAddresses = newAddresses; } - private Object getConfig() { - return config; - } - @VisibleForTesting public final ResolvedAddresses getResolvedAddresses() { return resolvedAddresses; @@ -463,13 +366,11 @@ protected class ChildLbStateHelper extends ForwardingLoadBalancerHelper { /** * Update current state and picker for this child and then use * {@link #updateOverallBalancingState()} for the parent LB. - * - *

Override this if you don't want to automatically request a connection when in IDLE */ @Override public void updateBalancingState(final ConnectivityState newState, final SubchannelPicker newPicker) { - if (!childLbStates.containsKey(key)) { + if (currentState == SHUTDOWN) { return; } @@ -478,9 +379,6 @@ public void updateBalancingState(final ConnectivityState newState, // If we are already in the process of resolving addresses, the overall balancing state // will be updated at the end of it, and we don't need to trigger that update here. if (!resolvingAddresses) { - if (newState == IDLE) { - lb.requestConnection(); - } updateOverallBalancingState(); } } @@ -494,25 +392,27 @@ protected Helper delegate() { /** * Endpoint is an optimization to quickly lookup and compare EquivalentAddressGroup address sets. - * Ignores the attributes, orders the addresses in a deterministic manner and converts each - * address into a string for easy comparison. Also caches the hashcode. - * Is used as a key for ChildLbState for most load balancers (ClusterManagerLB uses a String). + * It ignores the attributes. Is used as a key for ChildLbState for most load balancers + * (ClusterManagerLB uses a String). */ protected static class Endpoint { - final String[] addrs; + final Collection addrs; final int hashCode; public Endpoint(EquivalentAddressGroup eag) { checkNotNull(eag, "eag"); - addrs = new String[eag.getAddresses().size()]; - int i = 0; + if (eag.getAddresses().size() < 10) { + addrs = eag.getAddresses(); + } else { + // This is expected to be very unlikely in practice + addrs = new HashSet<>(eag.getAddresses()); + } + int sum = 0; for (SocketAddress address : eag.getAddresses()) { - addrs[i++] = address.toString(); + sum += address.hashCode(); } - Arrays.sort(addrs); - - hashCode = Arrays.hashCode(addrs); + hashCode = sum; } @Override @@ -525,24 +425,21 @@ public boolean equals(Object other) { if (this == other) { return true; } - if (other == null) { - return false; - } if (!(other instanceof Endpoint)) { return false; } Endpoint o = (Endpoint) other; - if (o.hashCode != hashCode || o.addrs.length != addrs.length) { + if (o.hashCode != hashCode || o.addrs.size() != addrs.size()) { return false; } - return Arrays.equals(o.addrs, this.addrs); + return o.addrs.containsAll(addrs); } @Override public String toString() { - return Arrays.toString(addrs); + return addrs.toString(); } } diff --git a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java index 7c235bb3640..22940e875ac 100644 --- a/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/RoundRobinLoadBalancer.java @@ -42,7 +42,7 @@ */ final class RoundRobinLoadBalancer extends MultiChildLoadBalancer { private final AtomicInteger sequence = new AtomicInteger(new Random().nextInt()); - private SubchannelPicker currentPicker = new EmptyPicker(); + private SubchannelPicker currentPicker = new FixedResultPicker(PickResult.withNoResult()); public RoundRobinLoadBalancer(Helper helper) { super(helper); @@ -68,7 +68,7 @@ protected void updateOverallBalancingState() { } if (isConnecting) { - updateBalancingState(CONNECTING, new EmptyPicker()); + updateBalancingState(CONNECTING, new FixedResultPicker(PickResult.withNoResult())); } else { updateBalancingState(TRANSIENT_FAILURE, createReadyPicker(getChildLbStates())); } @@ -95,6 +95,24 @@ private SubchannelPicker createReadyPicker(Collection children) { return new ReadyPicker(pickerList, sequence); } + @Override + protected ChildLbState createChildLbState(Object key) { + return new ChildLbState(key, pickFirstLbProvider) { + @Override + protected ChildLbStateHelper createChildHelper() { + return new ChildLbStateHelper() { + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + super.updateBalancingState(newState, newPicker); + if (!resolvingAddresses && newState == IDLE) { + getLb().requestConnection(); + } + } + }; + } + }; + } + @VisibleForTesting static class ReadyPicker extends SubchannelPicker { private final List subchannelPickers; // non-empty @@ -161,22 +179,4 @@ public boolean equals(Object o) { && new HashSet<>(subchannelPickers).containsAll(other.subchannelPickers); } } - - @VisibleForTesting - static final class EmptyPicker extends SubchannelPicker { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withNoResult(); - } - - @Override - public int hashCode() { - return getClass().hashCode(); - } - - @Override - public boolean equals(Object o) { - return o instanceof EmptyPicker; - } - } } diff --git a/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java b/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java index df226d5aee8..6bfd6d7a659 100644 --- a/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/MultiChildLoadBalancerTest.java @@ -21,7 +21,6 @@ import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.SHUTDOWN; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -34,6 +33,7 @@ import static org.mockito.Mockito.verify; import com.google.common.collect.Lists; +import com.google.common.testing.EqualsTester; import io.grpc.Attributes; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; @@ -244,37 +244,28 @@ public void testEndpoint_toString() { @Test public void testEndpoint_equals() { - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1"), - createEndpoint(Attributes.EMPTY, "addr1")); - - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr2", "addr1")); - - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(affinity, "addr2", "addr1")); - - assertEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2").hashCode(), - createEndpoint(affinity, "addr2", "addr1").hashCode()); - - } - - @Test - public void testEndpoint_notEquals() { - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr1", "addr3")); - - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1"), - createEndpoint(Attributes.EMPTY, "addr1", "addr2")); - - assertNotEquals( - createEndpoint(Attributes.EMPTY, "addr1", "addr2"), - createEndpoint(Attributes.EMPTY, "addr1")); + new EqualsTester() + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1"), + createEndpoint(Attributes.EMPTY, "addr1")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2"), + createEndpoint(Attributes.EMPTY, "addr2", "addr1"), + createEndpoint(affinity, "addr1", "addr2")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr3")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10"), + createEndpoint(Attributes.EMPTY, "addr2", "addr1", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr11")) + .addEqualityGroup( + createEndpoint(Attributes.EMPTY, "addr1", "addr2", "addr3", "addr4", "addr5", "addr6", + "addr7", "addr8", "addr9", "addr10", "addr11")) + .testEquals(); } private String addressesOnlyString(EquivalentAddressGroup eag) { diff --git a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java index 449230f0f45..743bbbef796 100644 --- a/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/RoundRobinLoadBalancerTest.java @@ -46,7 +46,9 @@ import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.CreateSubchannelArgs; +import io.grpc.LoadBalancer.FixedResultPicker; import io.grpc.LoadBalancer.Helper; +import io.grpc.LoadBalancer.PickResult; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; @@ -54,7 +56,6 @@ import io.grpc.Status; import io.grpc.internal.TestUtils; import io.grpc.util.MultiChildLoadBalancer.ChildLbState; -import io.grpc.util.RoundRobinLoadBalancer.EmptyPicker; import io.grpc.util.RoundRobinLoadBalancer.ReadyPicker; import java.net.SocketAddress; import java.util.ArrayList; @@ -84,6 +85,8 @@ @RunWith(JUnit4.class) public class RoundRobinLoadBalancerTest { private static final Attributes.Key MAJOR_KEY = Attributes.Key.create("major-key"); + private static final SubchannelPicker EMPTY_PICKER = + new FixedResultPicker(PickResult.withNoResult()); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @@ -248,7 +251,7 @@ public void pickAfterStateChange() throws Exception { ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); Subchannel subchannel = subchannels.get(Arrays.asList(childLbState.getEag())); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); assertThat(childLbState.getCurrentState()).isEqualTo(CONNECTING); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY)); @@ -261,7 +264,7 @@ public void pickAfterStateChange() throws Exception { ConnectivityStateInfo.forTransientFailure(error)); assertThat(childLbState.getCurrentState()).isEqualTo(TRANSIENT_FAILURE); AbstractTestHelper.refreshInvokedAndUpdateBS(inOrder, CONNECTING, mockHelper, pickerCaptor); - assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class); + assertThat(pickerCaptor.getValue()).isEqualTo(EMPTY_PICKER); deliverSubchannelState(subchannel, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); @@ -277,7 +280,7 @@ public void ignoreShutdownSubchannelStateChange() { InOrder inOrder = inOrder(mockHelper); Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); loadBalancer.shutdown(); for (ChildLbState child : loadBalancer.getChildLbStates()) { @@ -297,7 +300,7 @@ public void stayTransientFailureUntilReady() { Status addressesAcceptanceStatus = acceptAddresses(servers, Attributes.EMPTY); assertThat(addressesAcceptanceStatus.isOk()).isTrue(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); Map childToSubChannelMap = new HashMap<>(); // Simulate state transitions for each subchannel individually. @@ -336,7 +339,7 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); verify(mockHelper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); // Simulate state transitions for each subchannel individually. for (ChildLbState child : loadBalancer.getChildLbStates()) { @@ -352,7 +355,7 @@ public void refreshNameResolutionWhenSubchannelConnectionBroken() { deliverSubchannelState(sc, ConnectivityStateInfo.forNonError(IDLE)); inOrder.verify(mockHelper).refreshNameResolution(); verify(sc, times(2)).requestConnection(); - inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class)); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), eq(EMPTY_PICKER)); } AbstractTestHelper.verifyNoMoreMeaningfulInteractions(mockHelper); @@ -461,7 +464,7 @@ public void subchannelStateIsolation() throws Exception { Iterator pickers = pickerCaptor.getAllValues().iterator(); // The picker is incrementally updated as subchannels become READY assertEquals(CONNECTING, stateIterator.next()); - assertThat(pickers.next()).isInstanceOf(EmptyPicker.class); + assertThat(pickers.next()).isEqualTo(EMPTY_PICKER); assertEquals(READY, stateIterator.next()); assertThat(getList(pickers.next())).containsExactly(sc1); assertEquals(READY, stateIterator.next()); @@ -492,8 +495,8 @@ public void readyPicker_emptyList() { @Test public void internalPickerComparisons() { - SubchannelPicker empty1 = new EmptyPicker(); - SubchannelPicker empty2 = new EmptyPicker(); + SubchannelPicker empty1 = new FixedResultPicker(PickResult.withNoResult()); + SubchannelPicker empty2 = new FixedResultPicker(PickResult.withNoResult()); AtomicInteger seq = new AtomicInteger(0); acceptAddresses(servers, Attributes.EMPTY); // create subchannels diff --git a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java index b0239c56703..bdeff9d17c5 100644 --- a/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java +++ b/util/src/testFixtures/java/io/grpc/util/AbstractTestHelper.java @@ -276,7 +276,7 @@ public String toString() { } } - public static class FakeSocketAddress extends SocketAddress { + public static final class FakeSocketAddress extends SocketAddress { private static final long serialVersionUID = 0L; final String name; @@ -288,6 +288,20 @@ public static class FakeSocketAddress extends SocketAddress { public String toString() { return "FakeSocketAddress-" + name; } + + @Override + public boolean equals(Object o) { + if (!(o instanceof FakeSocketAddress)) { + return false; + } + FakeSocketAddress that = (FakeSocketAddress) o; + return this.name.equals(that.name); + } + + @Override + public int hashCode() { + return name.hashCode(); + } } } diff --git a/xds/build.gradle b/xds/build.gradle index a1d5aa753cb..a738145a2a0 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -52,6 +52,7 @@ dependencies { project(':grpc-services'), project(':grpc-auth'), project(path: ':grpc-alts', configuration: 'shadow'), + libraries.guava, libraries.gson, libraries.re2j, libraries.auto.value.annotations, diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index 773fdf20563..3f1eb3e7e4f 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -206,8 +206,9 @@ private void handleClusterDiscovered() { } loopStatus = Status.UNAVAILABLE.withDescription(String.format( "CDS error: circular aggregate clusters directly under %s for " - + "root cluster %s, named %s", - clusterState.name, root.name, namesCausingLoops)); + + "root cluster %s, named %s, xDS node ID: %s", + clusterState.name, root.name, namesCausingLoops, + xdsClient.getBootstrapInfo().node().getId())); } } } @@ -224,9 +225,9 @@ private void handleClusterDiscovered() { childLb.shutdown(); childLb = null; } - Status unavailable = - Status.UNAVAILABLE.withDescription("CDS error: found 0 leaf (logical DNS or EDS) " - + "clusters for root cluster " + root.name); + Status unavailable = Status.UNAVAILABLE.withDescription(String.format( + "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster %s" + + " xDS node ID: %s", root.name, xdsClient.getBootstrapInfo().node().getId())); helper.updateBalancingState( TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(unavailable))); return; @@ -288,11 +289,14 @@ private void addAncestors(Set ancestors, ClusterState clusterState, } private void handleClusterDiscoveryError(Status error) { + String description = error.getDescription() == null ? "" : error.getDescription() + " "; + Status errorWithNodeId = error.withDescription( + description + "xDS node ID: " + xdsClient.getBootstrapInfo().node().getId()); if (childLb != null) { - childLb.handleNameResolutionError(error); + childLb.handleNameResolutionError(errorWithNodeId); } else { helper.updateBalancingState( - TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(errorWithNodeId))); } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 702b2aa6caa..0ea2c7dd75f 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -27,6 +27,7 @@ import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.InternalLogId; import io.grpc.LoadBalancer; @@ -59,6 +60,7 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; /** @@ -77,10 +79,8 @@ final class ClusterImplLoadBalancer extends LoadBalancer { Strings.isNullOrEmpty(System.getenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING")) || Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING")); - private static final Attributes.Key ATTR_CLUSTER_LOCALITY_STATS = - Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityStats"); - private static final Attributes.Key ATTR_CLUSTER_LOCALITY_NAME = - Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocalityName"); + private static final Attributes.Key> ATTR_CLUSTER_LOCALITY = + Attributes.Key.create("io.grpc.xds.ClusterImplLoadBalancer.clusterLocality"); private final XdsLogger logger; private final Helper helper; @@ -213,36 +213,45 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { List addresses = withAdditionalAttributes(args.getAddresses()); - Locality locality = args.getAddresses().get(0).getAttributes().get( - InternalXdsAttributes.ATTR_LOCALITY); // all addresses should be in the same locality - String localityName = args.getAddresses().get(0).getAttributes().get( - InternalXdsAttributes.ATTR_LOCALITY_NAME); - // Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain - // attributes with its locality, including endpoints in LOGICAL_DNS clusters. - // In case of not (which really shouldn't), loads are aggregated under an empty locality. - if (locality == null) { - locality = Locality.create("", "", ""); - localityName = ""; - } - final ClusterLocalityStats localityStats = - (lrsServerInfo == null) - ? null - : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster, - edsServiceName, locality); - + // This value for ClusterLocality is not recommended for general use. + // Currently, we extract locality data from the first address, even before the subchannel is + // READY. + // This is mainly to accommodate scenarios where a Load Balancing API (like "pick first") + // might return the subchannel before it is READY. Typically, we wouldn't report load for such + // selections because the channel will disregard the chosen (not-ready) subchannel. + // However, we needed to ensure this case is handled. + ClusterLocality clusterLocality = createClusterLocalityFromAttributes( + args.getAddresses().get(0).getAttributes()); + AtomicReference localityAtomicReference = new AtomicReference<>( + clusterLocality); Attributes attrs = args.getAttributes().toBuilder() - .set(ATTR_CLUSTER_LOCALITY_STATS, localityStats) - .set(ATTR_CLUSTER_LOCALITY_NAME, localityName) + .set(ATTR_CLUSTER_LOCALITY, localityAtomicReference) .build(); args = args.toBuilder().setAddresses(addresses).setAttributes(attrs).build(); final Subchannel subchannel = delegate().createSubchannel(args); return new ForwardingSubchannel() { + @Override + public void start(SubchannelStateListener listener) { + delegate().start(new SubchannelStateListener() { + @Override + public void onSubchannelState(ConnectivityStateInfo newState) { + if (newState.getState().equals(ConnectivityState.READY)) { + // Get locality based on the connected address attributes + ClusterLocality updatedClusterLocality = createClusterLocalityFromAttributes( + subchannel.getConnectedAddressAttributes()); + ClusterLocality oldClusterLocality = localityAtomicReference + .getAndSet(updatedClusterLocality); + oldClusterLocality.release(); + } + listener.onSubchannelState(newState); + } + }); + } + @Override public void shutdown() { - if (localityStats != null) { - localityStats.release(); - } + localityAtomicReference.get().release(); delegate().shutdown(); } @@ -274,6 +283,28 @@ private List withAdditionalAttributes( return newAddresses; } + private ClusterLocality createClusterLocalityFromAttributes(Attributes addressAttributes) { + Locality locality = addressAttributes.get(InternalXdsAttributes.ATTR_LOCALITY); + String localityName = addressAttributes.get(InternalXdsAttributes.ATTR_LOCALITY_NAME); + + // Endpoint addresses resolved by ClusterResolverLoadBalancer should always contain + // attributes with its locality, including endpoints in LOGICAL_DNS clusters. + // In case of not (which really shouldn't), loads are aggregated under an empty + // locality. + if (locality == null) { + locality = Locality.create("", "", ""); + localityName = ""; + } + + final ClusterLocalityStats localityStats = + (lrsServerInfo == null) + ? null + : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster, + edsServiceName, locality); + + return new ClusterLocality(localityStats, localityName); + } + @Override protected Helper delegate() { return helper; @@ -361,18 +392,23 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { "Cluster max concurrent requests limit exceeded")); } } - final ClusterLocalityStats stats = - result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY_STATS); - if (stats != null) { - String localityName = - result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY_NAME); - args.getPickDetailsConsumer().addOptionalLabel("grpc.lb.locality", localityName); - - ClientStreamTracer.Factory tracerFactory = new CountingStreamTracerFactory( - stats, inFlights, result.getStreamTracerFactory()); - ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance() - .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats)); - return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory); + final AtomicReference clusterLocality = + result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY); + + if (clusterLocality != null) { + ClusterLocalityStats stats = clusterLocality.get().getClusterLocalityStats(); + if (stats != null) { + String localityName = + result.getSubchannel().getAttributes().get(ATTR_CLUSTER_LOCALITY).get() + .getClusterLocalityName(); + args.getPickDetailsConsumer().addOptionalLabel("grpc.lb.locality", localityName); + + ClientStreamTracer.Factory tracerFactory = new CountingStreamTracerFactory( + stats, inFlights, result.getStreamTracerFactory()); + ClientStreamTracer.Factory orcaTracerFactory = OrcaPerRequestUtil.getInstance() + .newOrcaClientStreamTracerFactory(tracerFactory, new OrcaPerRpcListener(stats)); + return PickResult.withSubchannel(result.getSubchannel(), orcaTracerFactory); + } } } return result; @@ -447,4 +483,33 @@ public void onLoadReport(MetricReport report) { stats.recordBackendLoadMetricStats(report.getNamedMetrics()); } } + + /** + * Represents the {@link ClusterLocalityStats} and network locality name of a cluster. + */ + static final class ClusterLocality { + private final ClusterLocalityStats clusterLocalityStats; + private final String clusterLocalityName; + + @VisibleForTesting + ClusterLocality(ClusterLocalityStats localityStats, String localityName) { + this.clusterLocalityStats = localityStats; + this.clusterLocalityName = localityName; + } + + ClusterLocalityStats getClusterLocalityStats() { + return clusterLocalityStats; + } + + String getClusterLocalityName() { + return clusterLocalityName; + } + + @VisibleForTesting + void release() { + if (clusterLocalityStats != null) { + clusterLocalityStats.release(); + } + } + } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java index 9e9ca5e1da3..c175b847c63 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancer.java @@ -23,11 +23,11 @@ import com.google.common.base.MoreObjects; import io.grpc.ConnectivityState; import io.grpc.InternalLogId; -import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancer; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import io.grpc.xds.client.XdsLogger; @@ -70,30 +70,28 @@ class ClusterManagerLoadBalancer extends MultiChildLoadBalancer { } @Override - protected ResolvedAddresses getChildAddresses(Object key, ResolvedAddresses resolvedAddresses, - Object childConfig) { - return resolvedAddresses.toBuilder().setLoadBalancingPolicyConfig(childConfig).build(); + protected ChildLbState createChildLbState(Object key) { + return new ClusterManagerLbState(key, GracefulSwitchLoadBalancerFactory.INSTANCE); } @Override - protected Map createChildLbMap(ResolvedAddresses resolvedAddresses) { + protected Map createChildAddressesMap( + ResolvedAddresses resolvedAddresses) { ClusterManagerConfig config = (ClusterManagerConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); - Map newChildPolicies = new HashMap<>(); + Map childAddresses = new HashMap<>(); if (config != null) { - for (Entry entry : config.childPolicies.entrySet()) { - ChildLbState child = getChildLbState(entry.getKey()); - if (child == null) { - child = new ClusterManagerLbState(entry.getKey(), - entry.getValue().getProvider(), entry.getValue().getConfig(), getInitialPicker()); - } - newChildPolicies.put(entry.getKey(), child); + for (Map.Entry childPolicy : config.childPolicies.entrySet()) { + ResolvedAddresses addresses = resolvedAddresses.toBuilder() + .setLoadBalancingPolicyConfig(childPolicy.getValue()) + .build(); + childAddresses.put(childPolicy.getKey(), addresses); } } logger.log( XdsLogLevel.INFO, - "Received cluster_manager lb config: child names={0}", newChildPolicies.keySet()); - return newChildPolicies; + "Received cluster_manager lb config: child names={0}", childAddresses.keySet()); + return childAddresses; } /** @@ -108,8 +106,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { resolvedAddresses.getLoadBalancingPolicyConfig(); ClusterManagerConfig lastConfig = (ClusterManagerConfig) lastResolvedAddresses.getLoadBalancingPolicyConfig(); - Map adjChildPolicies = new HashMap<>(config.childPolicies); - for (Entry entry : lastConfig.childPolicies.entrySet()) { + Map adjChildPolicies = new HashMap<>(config.childPolicies); + for (Entry entry : lastConfig.childPolicies.entrySet()) { ClusterManagerLbState state = (ClusterManagerLbState) getChildLbState(entry.getKey()); if (adjChildPolicies.containsKey(entry.getKey())) { if (state.deletionTimer != null) { @@ -183,11 +181,12 @@ public void handleNameResolutionError(Status error) { for (ChildLbState state : getChildLbStates()) { if (((ClusterManagerLbState) state).deletionTimer == null) { gotoTransientFailure = false; - handleNameResolutionError(state, error); + state.getLb().handleNameResolutionError(error); } } if (gotoTransientFailure) { - getHelper().updateBalancingState(TRANSIENT_FAILURE, getErrorPicker(error)); + getHelper().updateBalancingState( + TRANSIENT_FAILURE, new FixedResultPicker(PickResult.withError(error))); } } @@ -201,9 +200,8 @@ private class ClusterManagerLbState extends ChildLbState { @Nullable ScheduledHandle deletionTimer; - public ClusterManagerLbState(Object key, LoadBalancerProvider policyProvider, - Object childConfig, SubchannelPicker initialPicker) { - super(key, policyProvider, childConfig, initialPicker); + public ClusterManagerLbState(Object key, LoadBalancer.Factory policyFactory) { + super(key, policyFactory); } @Override @@ -236,8 +234,8 @@ class DeletionTask implements Runnable { public void run() { ClusterManagerConfig config = (ClusterManagerConfig) lastResolvedAddresses.getLoadBalancingPolicyConfig(); - Map childPolicies = new HashMap<>(config.childPolicies); - PolicySelection removed = childPolicies.remove(getKey()); + Map childPolicies = new HashMap<>(config.childPolicies); + Object removed = childPolicies.remove(getKey()); assert removed != null; config = new ClusterManagerConfig(childPolicies); lastResolvedAddresses = @@ -259,9 +257,7 @@ private class ClusterManagerChildHelper extends ChildLbStateHelper { @Override public void updateBalancingState(final ConnectivityState newState, final SubchannelPicker newPicker) { - // If we are already in the process of resolving addresses, the overall balancing state - // will be updated at the end of it, and we don't need to trigger that update here. - if (getChildLbState(getKey()) == null) { + if (getCurrentState() == ConnectivityState.SHUTDOWN) { return; } @@ -269,10 +265,21 @@ public void updateBalancingState(final ConnectivityState newState, // when the child instance exits deactivated state. setCurrentState(newState); setCurrentPicker(newPicker); + // If we are already in the process of resolving addresses, the overall balancing state + // will be updated at the end of it, and we don't need to trigger that update here. if (deletionTimer == null && !resolvingAddresses) { updateOverallBalancingState(); } } } } + + static final class GracefulSwitchLoadBalancerFactory extends LoadBalancer.Factory { + static final LoadBalancer.Factory INSTANCE = new GracefulSwitchLoadBalancerFactory(); + + @Override + public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { + return new GracefulSwitchLoadBalancer(helper); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java index 9c97d3fe966..7a7e16286f8 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/ClusterManagerLoadBalancerProvider.java @@ -26,12 +26,9 @@ import io.grpc.NameResolver.ConfigOrError; import io.grpc.Status; import io.grpc.internal.JsonUtil; -import io.grpc.internal.ServiceConfigUtil; -import io.grpc.internal.ServiceConfigUtil.LbConfig; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; import java.util.Objects; import javax.annotation.Nullable; @@ -73,7 +70,7 @@ public String getPolicyName() { @Override public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { - Map parsedChildPolicies = new LinkedHashMap<>(); + Map parsedChildPolicies = new LinkedHashMap<>(); try { Map childPolicies = JsonUtil.getObject(rawConfig, "childPolicy"); if (childPolicies == null || childPolicies.isEmpty()) { @@ -86,27 +83,19 @@ public ConfigOrError parseLoadBalancingPolicyConfig(Map rawConfig) { return ConfigOrError.fromError(Status.INTERNAL.withDescription( "No config for child " + name + " in cluster_manager LB policy: " + rawConfig)); } - List childConfigCandidates = - ServiceConfigUtil.unwrapLoadBalancingConfigList( - JsonUtil.getListOfObjects(childPolicy, "lbPolicy")); - if (childConfigCandidates == null || childConfigCandidates.isEmpty()) { - return ConfigOrError.fromError(Status.INTERNAL.withDescription( - "No config specified for child " + name + " in cluster_manager Lb policy: " - + rawConfig)); - } LoadBalancerRegistry registry = lbRegistry != null ? lbRegistry : LoadBalancerRegistry.getDefaultRegistry(); - ConfigOrError selectedConfig = - ServiceConfigUtil.selectLbPolicyFromList(childConfigCandidates, registry); - if (selectedConfig.getError() != null) { - Status error = selectedConfig.getError(); + ConfigOrError childConfig = GracefulSwitchLoadBalancer.parseLoadBalancingPolicyConfig( + JsonUtil.getListOfObjects(childPolicy, "lbPolicy"), registry); + if (childConfig.getError() != null) { + Status error = childConfig.getError(); return ConfigOrError.fromError( Status.INTERNAL .withCause(error.getCause()) .withDescription(error.getDescription()) - .augmentDescription("Failed to select config for child " + name)); + .augmentDescription("Failed to parse config for child " + name)); } - parsedChildPolicies.put(name, (PolicySelection) selectedConfig.getConfig()); + parsedChildPolicies.put(name, childConfig.getConfig()); } } catch (RuntimeException e) { return ConfigOrError.fromError( @@ -122,9 +111,9 @@ public LoadBalancer newLoadBalancer(Helper helper) { } static class ClusterManagerConfig { - final Map childPolicies; + final Map childPolicies; - ClusterManagerConfig(Map childPolicies) { + ClusterManagerConfig(Map childPolicies) { this.childPolicies = Collections.unmodifiableMap(childPolicies); } diff --git a/xds/src/main/java/io/grpc/xds/CsdsService.java b/xds/src/main/java/io/grpc/xds/CsdsService.java index 0102836660c..a296beb45d0 100644 --- a/xds/src/main/java/io/grpc/xds/CsdsService.java +++ b/xds/src/main/java/io/grpc/xds/CsdsService.java @@ -39,6 +39,8 @@ import io.grpc.xds.client.XdsClient.ResourceMetadata.ResourceMetadataStatus; import io.grpc.xds.client.XdsClient.ResourceMetadata.UpdateFailureState; import io.grpc.xds.client.XdsResourceType; +import java.util.ArrayList; +import java.util.List; import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -117,43 +119,58 @@ public void onCompleted() { private boolean handleRequest( ClientStatusRequest request, StreamObserver responseObserver) { - StatusException error; - try { - responseObserver.onNext(getConfigDumpForRequest(request)); - return true; - } catch (StatusException e) { - error = e; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - logger.log(Level.FINE, "Server interrupted while building CSDS config dump", e); - error = Status.ABORTED.withDescription("Thread interrupted").withCause(e).asException(); - } catch (RuntimeException e) { - logger.log(Level.WARNING, "Unexpected error while building CSDS config dump", e); - error = - Status.INTERNAL.withDescription("Unexpected internal error").withCause(e).asException(); - } - responseObserver.onError(error); - return false; - } + StatusException error = null; - private ClientStatusResponse getConfigDumpForRequest(ClientStatusRequest request) - throws StatusException, InterruptedException { if (request.getNodeMatchersCount() > 0) { - throw new StatusException( + error = new StatusException( Status.INVALID_ARGUMENT.withDescription("node_matchers not supported")); + } else { + List targets = xdsClientPoolFactory.getTargets(); + List clientConfigs = new ArrayList<>(targets.size()); + + for (int i = 0; i < targets.size() && error == null; i++) { + try { + ClientConfig clientConfig = getConfigForRequest(targets.get(i)); + if (clientConfig != null) { + clientConfigs.add(clientConfig); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.log(Level.FINE, "Server interrupted while building CSDS config dump", e); + error = Status.ABORTED.withDescription("Thread interrupted").withCause(e).asException(); + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Unexpected error while building CSDS config dump", e); + error = Status.INTERNAL.withDescription("Unexpected internal error").withCause(e) + .asException(); + } + } + + try { + responseObserver.onNext(getStatusResponse(clientConfigs)); + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Unexpected error while processing CSDS config dump", e); + error = Status.INTERNAL.withDescription("Unexpected internal error").withCause(e) + .asException(); + } } - ObjectPool xdsClientPool = xdsClientPoolFactory.get(); + if (error == null) { + return true; // All clients reported without error + } + responseObserver.onError(error); + return false; + } + + private ClientConfig getConfigForRequest(String target) throws InterruptedException { + ObjectPool xdsClientPool = xdsClientPoolFactory.get(target); if (xdsClientPool == null) { - return ClientStatusResponse.getDefaultInstance(); + return null; } XdsClient xdsClient = null; try { xdsClient = xdsClientPool.getObject(); - return ClientStatusResponse.newBuilder() - .addConfig(getClientConfigForXdsClient(xdsClient)) - .build(); + return getClientConfigForXdsClient(xdsClient, target); } finally { if (xdsClient != null) { xdsClientPool.returnObject(xdsClient); @@ -161,9 +178,18 @@ private ClientStatusResponse getConfigDumpForRequest(ClientStatusRequest request } } + private ClientStatusResponse getStatusResponse(List clientConfigs) { + if (clientConfigs.isEmpty()) { + return ClientStatusResponse.getDefaultInstance(); + } + return ClientStatusResponse.newBuilder().addAllConfig(clientConfigs).build(); + } + @VisibleForTesting - static ClientConfig getClientConfigForXdsClient(XdsClient xdsClient) throws InterruptedException { + static ClientConfig getClientConfigForXdsClient(XdsClient xdsClient, String target) + throws InterruptedException { ClientConfig.Builder builder = ClientConfig.newBuilder() + .setClientScope(target) .setNode(xdsClient.getBootstrapInfo().node().toEnvoyProtoNode()); Map, Map> metadataByType = diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java index 39b9ed0d095..0073cce1a88 100644 --- a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -36,6 +36,6 @@ public static void setDefaultProviderBootstrapOverride(Map bootstrap) public static ObjectPool getOrCreate(String target) throws XdsInitializationException { - return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(); + return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target); } } diff --git a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java index f96c171ee9c..6c13530ff49 100644 --- a/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/LeastRequestLoadBalancer.java @@ -126,9 +126,8 @@ protected void updateOverallBalancingState() { } @Override - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker, ResolvedAddresses unused) { - return new LeastRequestLbState(key, pickFirstLbProvider, policyConfig, initialPicker); + protected ChildLbState createChildLbState(Object key) { + return new LeastRequestLbState(key, pickFirstLbProvider); } private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) { @@ -320,13 +319,25 @@ public String toString() { protected class LeastRequestLbState extends ChildLbState { private final AtomicInteger activeRequests = new AtomicInteger(0); - public LeastRequestLbState(Object key, LoadBalancerProvider policyProvider, - Object childConfig, SubchannelPicker initialPicker) { - super(key, policyProvider, childConfig, initialPicker); + public LeastRequestLbState(Object key, LoadBalancerProvider policyProvider) { + super(key, policyProvider); } int getActiveRequests() { return activeRequests.get(); } + + @Override + protected ChildLbStateHelper createChildHelper() { + return new ChildLbStateHelper() { + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + super.updateBalancingState(newState, newPicker); + if (!resolvingAddresses && newState == IDLE) { + getLb().requestConnection(); + } + } + }; + } } } diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 3b7e451f2a5..4f93974b52c 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -27,7 +27,6 @@ import com.google.common.base.MoreObjects; import com.google.common.collect.HashMultiset; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Multiset; import com.google.common.primitives.UnsignedInteger; import io.grpc.Attributes; @@ -42,6 +41,7 @@ import io.grpc.xds.client.XdsLogger.XdsLogLevel; import java.net.SocketAddress; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -89,19 +89,11 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { try { resolvingAddresses = true; - // Subclass handles any special manipulation to create appropriate types of ChildLbStates - Map newChildren = createChildLbMap(resolvedAddresses); - - if (newChildren.isEmpty()) { - addressValidityStatus = Status.UNAVAILABLE.withDescription( - "Ring hash lb error: EDS resolution was successful, but there were no valid addresses"); - handleNameResolutionError(addressValidityStatus); - return addressValidityStatus; + AcceptResolvedAddrRetVal acceptRetVal = acceptResolvedAddressesInternal(resolvedAddresses); + if (!acceptRetVal.status.isOk()) { + return acceptRetVal.status; } - addMissingChildren(newChildren); - updateChildrenWithResolvedAddresses(resolvedAddresses, newChildren); - // Now do the ringhash specific logic with weights and building the ring RingHashConfig config = (RingHashConfig) resolvedAddresses.getLoadBalancingPolicyConfig(); if (config == null) { @@ -145,7 +137,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { // clusters and resolver can remove them in service config. updateOverallBalancingState(); - shutdownRemoved(getRemovedChildren(newChildren.keySet())); + shutdownRemoved(acceptRetVal.removedChildren); } finally { this.resolvingAddresses = false; } @@ -221,15 +213,14 @@ protected void updateOverallBalancingState() { overallState = TRANSIENT_FAILURE; } - RingHashPicker picker = new RingHashPicker(syncContext, ring, getImmutableChildMap()); + RingHashPicker picker = new RingHashPicker(syncContext, ring, getChildLbStates()); getHelper().updateBalancingState(overallState, picker); this.currentConnectivityState = overallState; } @Override - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker, ResolvedAddresses resolvedAddresses) { - return new RingHashChildLbState((Endpoint)key); + protected ChildLbState createChildLbState(Object key) { + return new ChildLbState(key, lazyLbFactory); } private Status validateAddrList(List addrList) { @@ -353,13 +344,12 @@ private static final class RingHashPicker extends SubchannelPicker { private RingHashPicker( SynchronizationContext syncContext, List ring, - ImmutableMap subchannels) { + Collection children) { this.syncContext = syncContext; this.ring = ring; - pickableSubchannels = new HashMap<>(subchannels.size()); - for (Map.Entry entry : subchannels.entrySet()) { - RingHashChildLbState childLbState = (RingHashChildLbState) entry.getValue(); - pickableSubchannels.put((Endpoint)entry.getKey(), + pickableSubchannels = new HashMap<>(children.size()); + for (ChildLbState childLbState : children) { + pickableSubchannels.put((Endpoint)childLbState.getKey(), new SubchannelView(childLbState, childLbState.getCurrentState())); } } @@ -405,7 +395,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { for (int i = 0; i < ring.size(); i++) { int index = (targetIndex + i) % ring.size(); SubchannelView subchannelView = pickableSubchannels.get(ring.get(index).addrKey); - RingHashChildLbState childLbState = subchannelView.childLbState; + ChildLbState childLbState = subchannelView.childLbState; if (subchannelView.connectivityState == READY) { return childLbState.getCurrentPicker().pickSubchannel(args); @@ -427,7 +417,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { } // return the pick from the original subchannel hit by hash, which is probably an error - RingHashChildLbState originalSubchannel = + ChildLbState originalSubchannel = pickableSubchannels.get(ring.get(targetIndex).addrKey).childLbState; return originalSubchannel.getCurrentPicker().pickSubchannel(args); } @@ -439,10 +429,10 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { * state changes. */ private static final class SubchannelView { - private final RingHashChildLbState childLbState; + private final ChildLbState childLbState; private final ConnectivityState connectivityState; - private SubchannelView(RingHashChildLbState childLbState, ConnectivityState state) { + private SubchannelView(ChildLbState childLbState, ConnectivityState state) { this.childLbState = childLbState; this.connectivityState = state; } @@ -487,41 +477,4 @@ public String toString() { .toString(); } } - - class RingHashChildLbState extends MultiChildLoadBalancer.ChildLbState { - - public RingHashChildLbState(Endpoint key) { - super(key, lazyLbFactory, null, EMPTY_PICKER); - } - - @Override - protected ChildLbStateHelper createChildHelper() { - return new RingHashChildHelper(); - } - - // Need to expose this to the LB class - @Override - protected void shutdown() { - super.shutdown(); - } - - private class RingHashChildHelper extends ChildLbStateHelper { - @Override - public void updateBalancingState(final ConnectivityState newState, - final SubchannelPicker newPicker) { - setCurrentState(newState); - setCurrentPicker(newPicker); - - if (getChildLbState(getKey()) == null) { - return; - } - - // If we are already in the process of resolving addresses, the overall balancing state - // will be updated at the end of it, and we don't need to trigger that update here. - if (!resolvingAddresses) { - updateOverallBalancingState(); - } - } - } - } } diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index 5ae1f5bbce5..c9195896d82 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -20,6 +20,7 @@ import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; @@ -32,6 +33,7 @@ import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.internal.security.TlsContextManagerImpl; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicReference; import java.util.logging.Level; @@ -53,7 +55,7 @@ final class SharedXdsClientPoolProvider implements XdsClientPoolFactory { private final Bootstrapper bootstrapper; private final Object lock = new Object(); private final AtomicReference> bootstrapOverride = new AtomicReference<>(); - private volatile ObjectPool xdsClientPool; + private final Map> targetToXdsClientMap = new ConcurrentHashMap<>(); SharedXdsClientPoolProvider() { this(new GrpcBootstrapperImpl()); @@ -75,16 +77,16 @@ public void setBootstrapOverride(Map bootstrap) { @Override @Nullable - public ObjectPool get() { - return xdsClientPool; + public ObjectPool get(String target) { + return targetToXdsClientMap.get(target); } @Override - public ObjectPool getOrCreate() throws XdsInitializationException { - ObjectPool ref = xdsClientPool; + public ObjectPool getOrCreate(String target) throws XdsInitializationException { + ObjectPool ref = targetToXdsClientMap.get(target); if (ref == null) { synchronized (lock) { - ref = xdsClientPool; + ref = targetToXdsClientMap.get(target); if (ref == null) { BootstrapInfo bootstrapInfo; Map rawBootstrap = bootstrapOverride.get(); @@ -96,13 +98,20 @@ public ObjectPool getOrCreate() throws XdsInitializationException { if (bootstrapInfo.servers().isEmpty()) { throw new XdsInitializationException("No xDS server provided"); } - ref = xdsClientPool = new RefCountedXdsClientObjectPool(bootstrapInfo); + ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target); + targetToXdsClientMap.put(target, ref); } } } return ref; } + @Override + public ImmutableList getTargets() { + return ImmutableList.copyOf(targetToXdsClientMap.keySet()); + } + + private static class SharedXdsClientPoolProviderHolder { private static final SharedXdsClientPoolProvider instance = new SharedXdsClientPoolProvider(); } @@ -110,7 +119,11 @@ private static class SharedXdsClientPoolProviderHolder { @ThreadSafe @VisibleForTesting static class RefCountedXdsClientObjectPool implements ObjectPool { + + private static final ExponentialBackoffPolicy.Provider BACKOFF_POLICY_PROVIDER = + new ExponentialBackoffPolicy.Provider(); private final BootstrapInfo bootstrapInfo; + private final String target; // The target associated with the xDS client. private final Object lock = new Object(); @GuardedBy("lock") private ScheduledExecutorService scheduler; @@ -120,8 +133,9 @@ static class RefCountedXdsClientObjectPool implements ObjectPool { private int refCount; @VisibleForTesting - RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo) { + RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target) { this.bootstrapInfo = checkNotNull(bootstrapInfo); + this.target = target; } @Override @@ -136,7 +150,7 @@ public XdsClient getObject() { DEFAULT_XDS_TRANSPORT_FACTORY, bootstrapInfo, scheduler, - new ExponentialBackoffPolicy.Provider(), + BACKOFF_POLICY_PROVIDER, GrpcUtil.STOPWATCH_SUPPLIER, TimeProvider.SYSTEM_TIME_PROVIDER, MessagePrinter.INSTANCE, @@ -167,5 +181,10 @@ XdsClient getXdsClientForTest() { return xdsClient; } } + + public String getTarget() { + return target; + } } + } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java index 115857d43ff..73764c63c80 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancer.java @@ -17,7 +17,6 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkElementIndex; import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; @@ -30,7 +29,6 @@ import io.grpc.Deadline.Ticker; import io.grpc.DoubleHistogramMetricInstrument; import io.grpc.EquivalentAddressGroup; -import io.grpc.ExperimentalApi; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; import io.grpc.LongCounterMetricInstrument; @@ -40,18 +38,16 @@ import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext.ScheduledHandle; import io.grpc.services.MetricReport; -import io.grpc.util.ForwardingLoadBalancerHelper; import io.grpc.util.ForwardingSubchannel; import io.grpc.util.MultiChildLoadBalancer; import io.grpc.xds.orca.OrcaOobUtil; import io.grpc.xds.orca.OrcaOobUtil.OrcaOobReportListener; import io.grpc.xds.orca.OrcaPerRequestUtil; import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; +import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Random; import java.util.Set; import java.util.concurrent.ScheduledExecutorService; @@ -89,7 +85,6 @@ * * See related documentation: https://cloud.google.com/service-mesh/legacy/load-balancing-apis/proxyless-configure-advanced-traffic-management#custom-lb-config */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { private static final LongCounterMetricInstrument RR_FALLBACK_COUNTER; @@ -137,12 +132,12 @@ final class WeightedRoundRobinLoadBalancer extends MultiChildLoadBalancer { } public WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker) { - this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, new Random()); + this(helper, ticker, new Random()); } - public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random random) { - super(helper); - helper.setLoadBalancer(this); + @VisibleForTesting + WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) { + super(OrcaOobUtil.newOrcaReportingHelper(helper)); this.ticker = checkNotNull(ticker, "ticker"); this.infTime = ticker.nanoTime() + Long.MAX_VALUE; this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); @@ -152,17 +147,9 @@ public WeightedRoundRobinLoadBalancer(WrrHelper helper, Ticker ticker, Random ra log.log(Level.FINE, "weighted_round_robin LB created"); } - @VisibleForTesting - WeightedRoundRobinLoadBalancer(Helper helper, Ticker ticker, Random random) { - this(new WrrHelper(OrcaOobUtil.newOrcaReportingHelper(helper)), ticker, random); - } - @Override - protected ChildLbState createChildLbState(Object key, Object policyConfig, - SubchannelPicker initialPicker, ResolvedAddresses unused) { - ChildLbState childLbState = new WeightedChildLbState(key, pickFirstLbProvider, policyConfig, - initialPicker); - return childLbState; + protected ChildLbState createChildLbState(Object key) { + return new WeightedChildLbState(key, pickFirstLbProvider); } @Override @@ -242,9 +229,44 @@ protected void updateOverallBalancingState() { } private SubchannelPicker createReadyPicker(Collection activeList) { - return new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), - config.enableOobLoadReport, config.errorUtilizationPenalty, sequence, getHelper(), - locality); + WeightedRoundRobinPicker picker = new WeightedRoundRobinPicker(ImmutableList.copyOf(activeList), + config.enableOobLoadReport, config.errorUtilizationPenalty, sequence); + updateWeight(picker); + return picker; + } + + private void updateWeight(WeightedRoundRobinPicker picker) { + Helper helper = getHelper(); + float[] newWeights = new float[picker.children.size()]; + AtomicInteger staleEndpoints = new AtomicInteger(); + AtomicInteger notYetUsableEndpoints = new AtomicInteger(); + for (int i = 0; i < picker.children.size(); i++) { + double newWeight = ((WeightedChildLbState) picker.children.get(i)).getWeight(staleEndpoints, + notYetUsableEndpoints); + helper.getMetricRecorder() + .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, + ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality)); + newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; + } + + if (staleEndpoints.get() > 0) { + helper.getMetricRecorder() + .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(), + ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality)); + } + if (notYetUsableEndpoints.get() > 0) { + helper.getMetricRecorder() + .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(), + ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality)); + } + boolean weightsEffective = picker.updateWeight(newWeights); + if (!weightsEffective) { + helper.getMetricRecorder() + .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()), + ImmutableList.of(locality)); + } } private void updateBalancingState(ConnectivityState state, SubchannelPicker picker) { @@ -265,9 +287,13 @@ final class WeightedChildLbState extends ChildLbState { private OrcaReportListener orcaReportListener; - public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider, Object childConfig, - SubchannelPicker initialPicker) { - super(key, policyProvider, childConfig, initialPicker); + public WeightedChildLbState(Object key, LoadBalancerProvider policyProvider) { + super(key, policyProvider); + } + + @Override + protected ChildLbStateHelper createChildHelper() { + return new WrrChildLbStateHelper(); } private double getWeight(AtomicInteger staleEndpoints, AtomicInteger notYetUsableEndpoints) { @@ -305,6 +331,21 @@ public void removeSubchannel(WrrSubchannel wrrSubchannel) { subchannels.remove(wrrSubchannel); } + final class WrrChildLbStateHelper extends ChildLbStateHelper { + @Override + public Subchannel createSubchannel(CreateSubchannelArgs args) { + return new WrrSubchannel(super.createSubchannel(args), WeightedChildLbState.this); + } + + @Override + public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) { + super.updateBalancingState(newState, newPicker); + if (!resolvingAddresses && newState == ConnectivityState.IDLE) { + getLb().requestConnection(); + } + } + } + final class OrcaReportListener implements OrcaPerRequestReportListener, OrcaOobReportListener { private final float errorUtilizationPenalty; @@ -342,7 +383,7 @@ private final class UpdateWeightTask implements Runnable { @Override public void run() { if (currentPicker != null && currentPicker instanceof WeightedRoundRobinPicker) { - ((WeightedRoundRobinPicker) currentPicker).updateWeight(); + updateWeight((WeightedRoundRobinPicker) currentPicker); } weightUpdateTimer = syncContext.schedule(this, config.weightUpdatePeriodNanos, TimeUnit.NANOSECONDS, timeService); @@ -374,32 +415,6 @@ public void shutdown() { super.shutdown(); } - private static final class WrrHelper extends ForwardingLoadBalancerHelper { - private final Helper delegate; - private WeightedRoundRobinLoadBalancer wrr; - - WrrHelper(Helper helper) { - this.delegate = helper; - } - - void setLoadBalancer(WeightedRoundRobinLoadBalancer lb) { - this.wrr = lb; - } - - @Override - protected Helper delegate() { - return delegate; - } - - @Override - public Subchannel createSubchannel(CreateSubchannelArgs args) { - checkElementIndex(0, args.getAddresses().size(), "Empty address group"); - WeightedChildLbState childLbState = - (WeightedChildLbState) wrr.getChildLbStateEag(args.getAddresses().get(0)); - return wrr.new WrrSubchannel(delegate().createSubchannel(args), childLbState); - } - } - @VisibleForTesting final class WrrSubchannel extends ForwardingSubchannel { private final Subchannel delegate; @@ -438,53 +453,50 @@ public void shutdown() { @VisibleForTesting static final class WeightedRoundRobinPicker extends SubchannelPicker { - private final List children; - private final Map subchannelToReportListenerMap = - new HashMap<>(); + // Parallel lists (column-based storage instead of normal row-based storage of List). + // The ith element of children corresponds to the ith element of pickers, listeners, and even + // updateWeight(float[]). + private final List children; // May only be accessed from sync context + private final List pickers; + private final List reportListeners; private final boolean enableOobLoadReport; private final float errorUtilizationPenalty; private final AtomicInteger sequence; private final int hashCode; - private final LoadBalancer.Helper helper; - private final String locality; private volatile StaticStrideScheduler scheduler; WeightedRoundRobinPicker(List children, boolean enableOobLoadReport, - float errorUtilizationPenalty, AtomicInteger sequence, LoadBalancer.Helper helper, - String locality) { + float errorUtilizationPenalty, AtomicInteger sequence) { checkNotNull(children, "children"); Preconditions.checkArgument(!children.isEmpty(), "empty child list"); this.children = children; + List pickers = new ArrayList<>(children.size()); + List reportListeners = new ArrayList<>(children.size()); for (ChildLbState child : children) { WeightedChildLbState wChild = (WeightedChildLbState) child; - for (WrrSubchannel subchannel : wChild.subchannels) { - this.subchannelToReportListenerMap - .put(subchannel, wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); - } + pickers.add(wChild.getCurrentPicker()); + reportListeners.add(wChild.getOrCreateOrcaListener(errorUtilizationPenalty)); } + this.pickers = pickers; + this.reportListeners = reportListeners; this.enableOobLoadReport = enableOobLoadReport; this.errorUtilizationPenalty = errorUtilizationPenalty; this.sequence = checkNotNull(sequence, "sequence"); - this.helper = helper; - this.locality = checkNotNull(locality, "locality"); - // For equality we treat children as a set; use hash code as defined by Set + // For equality we treat pickers as a set; use hash code as defined by Set int sum = 0; - for (ChildLbState child : children) { - sum += child.hashCode(); + for (SubchannelPicker picker : pickers) { + sum += picker.hashCode(); } this.hashCode = sum ^ Boolean.hashCode(enableOobLoadReport) ^ Float.hashCode(errorUtilizationPenalty); - - updateWeight(); } @Override public PickResult pickSubchannel(PickSubchannelArgs args) { - ChildLbState childLbState = children.get(scheduler.pick()); - WeightedChildLbState wChild = (WeightedChildLbState) childLbState; - PickResult pickResult = childLbState.getCurrentPicker().pickSubchannel(args); + int pick = scheduler.pick(); + PickResult pickResult = pickers.get(pick).pickSubchannel(args); Subchannel subchannel = pickResult.getSubchannel(); if (subchannel == null) { return pickResult; @@ -492,48 +504,16 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { if (!enableOobLoadReport) { return PickResult.withSubchannel(subchannel, OrcaPerRequestUtil.getInstance().newOrcaClientStreamTracerFactory( - subchannelToReportListenerMap.getOrDefault(subchannel, - wChild.getOrCreateOrcaListener(errorUtilizationPenalty)))); + reportListeners.get(pick))); } else { return PickResult.withSubchannel(subchannel); } } - private void updateWeight() { - float[] newWeights = new float[children.size()]; - AtomicInteger staleEndpoints = new AtomicInteger(); - AtomicInteger notYetUsableEndpoints = new AtomicInteger(); - for (int i = 0; i < children.size(); i++) { - double newWeight = ((WeightedChildLbState) children.get(i)).getWeight(staleEndpoints, - notYetUsableEndpoints); - // TODO: add locality label once available - helper.getMetricRecorder() - .recordDoubleHistogram(ENDPOINT_WEIGHTS_HISTOGRAM, newWeight, - ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); - newWeights[i] = newWeight > 0 ? (float) newWeight : 0.0f; - } - if (staleEndpoints.get() > 0) { - // TODO: add locality label once available - helper.getMetricRecorder() - .addLongCounter(ENDPOINT_WEIGHT_STALE_COUNTER, staleEndpoints.get(), - ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); - } - if (notYetUsableEndpoints.get() > 0) { - // TODO: add locality label once available - helper.getMetricRecorder() - .addLongCounter(ENDPOINT_WEIGHT_NOT_YET_USEABLE_COUNTER, notYetUsableEndpoints.get(), - ImmutableList.of(helper.getChannelTarget()), ImmutableList.of(locality)); - } - + /** Returns {@code true} if weights are different than round_robin. */ + private boolean updateWeight(float[] newWeights) { this.scheduler = new StaticStrideScheduler(newWeights, sequence); - if (this.scheduler.usesRoundRobin()) { - // TODO: locality label once available - helper.getMetricRecorder() - .addLongCounter(RR_FALLBACK_COUNTER, 1, ImmutableList.of(helper.getChannelTarget()), - ImmutableList.of(locality)); - } + return !this.scheduler.usesRoundRobin(); } @Override @@ -541,7 +521,8 @@ public String toString() { return MoreObjects.toStringHelper(WeightedRoundRobinPicker.class) .add("enableOobLoadReport", enableOobLoadReport) .add("errorUtilizationPenalty", errorUtilizationPenalty) - .add("list", children).toString(); + .add("pickers", pickers) + .toString(); } @VisibleForTesting @@ -568,8 +549,8 @@ public boolean equals(Object o) { && sequence == other.sequence && enableOobLoadReport == other.enableOobLoadReport && Float.compare(errorUtilizationPenalty, other.errorUtilizationPenalty) == 0 - && children.size() == other.children.size() - && new HashSet<>(children).containsAll(other.children); + && pickers.size() == other.pickers.size() + && new HashSet<>(pickers).containsAll(other.pickers); } } diff --git a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java index 161e7c4ed0c..433ea34b857 100644 --- a/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/WeightedRoundRobinLoadBalancerProvider.java @@ -18,7 +18,6 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Deadline; -import io.grpc.ExperimentalApi; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -32,7 +31,6 @@ /** * Provides a {@link WeightedRoundRobinLoadBalancer}. * */ -@ExperimentalApi("https://github.com/grpc/grpc-java/issues/9885") @Internal public final class WeightedRoundRobinLoadBalancerProvider extends LoadBalancerProvider { diff --git a/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java b/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java index c649b3b3069..313eb675116 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientPoolFactory.java @@ -19,6 +19,7 @@ import io.grpc.internal.ObjectPool; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsInitializationException; +import java.util.List; import java.util.Map; import javax.annotation.Nullable; @@ -26,7 +27,9 @@ interface XdsClientPoolFactory { void setBootstrapOverride(Map bootstrap); @Nullable - ObjectPool get(); + ObjectPool get(String target); - ObjectPool getOrCreate() throws XdsInitializationException; + ObjectPool getOrCreate(String target) throws XdsInitializationException; + + List getTargets(); } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 9ad9b6e82f0..ca73b7d8451 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -66,6 +66,7 @@ import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; +import java.net.URI; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -104,6 +105,7 @@ final class XdsNameResolver extends NameResolver { private final XdsLogger logger; @Nullable private final String targetAuthority; + private final String target; private final String serviceAuthority; // Encoded version of the service authority as per // https://datatracker.ietf.org/doc/html/rfc3986#section-3.2. @@ -133,23 +135,24 @@ final class XdsNameResolver extends NameResolver { private boolean receivedConfig; XdsNameResolver( - @Nullable String targetAuthority, String name, @Nullable String overrideAuthority, + URI targetUri, String name, @Nullable String overrideAuthority, ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, @Nullable Map bootstrapOverride) { - this(targetAuthority, name, overrideAuthority, serviceConfigParser, syncContext, scheduler, - SharedXdsClientPoolProvider.getDefaultProvider(), ThreadSafeRandomImpl.instance, - FilterRegistry.getDefaultRegistry(), bootstrapOverride); + this(targetUri, targetUri.getAuthority(), name, overrideAuthority, serviceConfigParser, + syncContext, scheduler, SharedXdsClientPoolProvider.getDefaultProvider(), + ThreadSafeRandomImpl.instance, FilterRegistry.getDefaultRegistry(), bootstrapOverride); } @VisibleForTesting XdsNameResolver( - @Nullable String targetAuthority, String name, @Nullable String overrideAuthority, - ServiceConfigParser serviceConfigParser, + URI targetUri, @Nullable String targetAuthority, String name, + @Nullable String overrideAuthority, ServiceConfigParser serviceConfigParser, SynchronizationContext syncContext, ScheduledExecutorService scheduler, XdsClientPoolFactory xdsClientPoolFactory, ThreadSafeRandom random, FilterRegistry filterRegistry, @Nullable Map bootstrapOverride) { this.targetAuthority = targetAuthority; + target = targetUri.toString(); // The name might have multiple slashes so encode it before verifying. serviceAuthority = checkNotNull(name, "name"); @@ -180,7 +183,7 @@ public String getServiceAuthority() { public void start(Listener2 listener) { this.listener = checkNotNull(listener, "listener"); try { - xdsClientPool = xdsClientPoolFactory.getOrCreate(); + xdsClientPool = xdsClientPoolFactory.getOrCreate(target); } catch (Exception e) { listener.onError( Status.UNAVAILABLE.withDescription("Failed to initialize xDS").withCause(e)); @@ -812,10 +815,12 @@ private void cleanUpRoutes(String error) { // the config selector handles the error message itself. Once the LB API allows providing // failure information for addresses yet still providing a service config, the config seector // could be avoided. + String errorWithNodeId = + error + ", xDS node ID: " + xdsClient.getBootstrapInfo().node().getId(); listener.onResult(ResolutionResult.newBuilder() .setAttributes(Attributes.newBuilder() .set(InternalConfigSelector.KEY, - new FailingConfigSelector(Status.UNAVAILABLE.withDescription(error))) + new FailingConfigSelector(Status.UNAVAILABLE.withDescription(errorWithNodeId))) .build()) .setServiceConfig(emptyServiceConfig) .build()); diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java index 598be07fcd8..8d0e59eaa91 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolverProvider.java @@ -78,7 +78,7 @@ public XdsNameResolver newNameResolver(URI targetUri, Args args) { targetUri); String name = targetPath.substring(1); return new XdsNameResolver( - targetUri.getAuthority(), name, args.getOverrideAuthority(), + targetUri, name, args.getOverrideAuthority(), args.getServiceConfigParser(), args.getSynchronizationContext(), args.getScheduledExecutorService(), bootstrapOverride); diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index bf8603fb3e4..bd622a71124 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -171,7 +171,7 @@ public void run() { private void internalStart() { try { - xdsClientPool = xdsClientPoolFactory.getOrCreate(); + xdsClientPool = xdsClientPoolFactory.getOrCreate(""); } catch (Exception e) { StatusException statusException = Status.UNAVAILABLE.withDescription( "Failed to initialize xDS").withCause(e).asException(); @@ -425,7 +425,8 @@ public void onResourceDoesNotExist(final String resourceName) { return; } StatusException statusException = Status.UNAVAILABLE.withDescription( - "Listener " + resourceName + " unavailable").asException(); + String.format("Listener %s unavailable, xDS node ID: %s", resourceName, + xdsClient.getBootstrapInfo().node().getId())).asException(); handleConfigNotFound(statusException); } @@ -434,9 +435,12 @@ public void onError(final Status error) { if (stopped) { return; } - logger.log(Level.FINE, "Error from XdsClient", error); + String description = error.getDescription() == null ? "" : error.getDescription() + " "; + Status errorWithNodeId = error.withDescription( + description + "xDS node ID: " + xdsClient.getBootstrapInfo().node().getId()); + logger.log(Level.FINE, "Error from XdsClient", errorWithNodeId); if (!isServing) { - listener.onNotServing(error.asException()); + listener.onNotServing(errorWithNodeId.asException()); } } @@ -664,8 +668,11 @@ public void run() { if (!routeDiscoveryStates.containsKey(resourceName)) { return; } + String description = error.getDescription() == null ? "" : error.getDescription() + " "; + Status errorWithNodeId = error.withDescription( + description + "xDS node ID: " + xdsClient.getBootstrapInfo().node().getId()); logger.log(Level.WARNING, "Error loading RDS resource {0} from XdsClient: {1}.", - new Object[]{resourceName, error}); + new Object[]{resourceName, errorWithNodeId}); maybeUpdateSelector(); } }); diff --git a/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java b/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java index 393cce16194..be9d3587d14 100644 --- a/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java +++ b/xds/src/main/java/io/grpc/xds/client/LoadStatsManager2.java @@ -91,7 +91,7 @@ private synchronized void releaseClusterDropCounter( String cluster, @Nullable String edsServiceName) { checkState(allDropStats.containsKey(cluster) && allDropStats.get(cluster).containsKey(edsServiceName), - "stats for cluster %s, edsServiceName %s not exits", cluster, edsServiceName); + "stats for cluster %s, edsServiceName %s do not exist", cluster, edsServiceName); ReferenceCounted ref = allDropStats.get(cluster).get(edsServiceName); ref.release(); } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 0884587cd95..da32332a2a5 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -58,7 +58,9 @@ import io.grpc.xds.LeastRequestLoadBalancer.LeastRequestConfig; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; +import io.grpc.xds.client.EnvoyProtoData; import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; @@ -94,6 +96,16 @@ public class CdsLoadBalancer2Test { private static final String DNS_HOST_NAME = "backend-service-dns.googleapis.com:443"; private static final ServerInfo LRS_SERVER_INFO = ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); + private static final String SERVER_URI = "trafficdirector.googleapis.com"; + private static final String NODE_ID = + "projects/42/networks/default/nodes/5c85b298-6f5b-4722-b74a-f7d1f0ccf5ad"; + private static final EnvoyProtoData.Node BOOTSTRAP_NODE = + EnvoyProtoData.Node.newBuilder().setId(NODE_ID).build(); + private static final BootstrapInfo BOOTSTRAP_INFO = BootstrapInfo.builder() + .servers(ImmutableList.of( + ServerInfo.create(SERVER_URI, InsecureChannelCredentials.create()))) + .node(BOOTSTRAP_NODE) + .build(); private final UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); private final OutlierDetection outlierDetection = OutlierDetection.create( @@ -211,7 +223,8 @@ public void nonAggregateCluster_resourceNotExist_returnErrorPicker() { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); + "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER + + " xDS node ID: " + NODE_ID); assertPicker(pickerCaptor.getValue(), unavailable, null); assertThat(childBalancers).isEmpty(); } @@ -254,7 +267,8 @@ public void nonAggregateCluster_resourceRevoked() { xdsClient.deliverResourceNotExist(CLUSTER); assertThat(childBalancer.shutdown).isTrue(); Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); + "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER + + " xDS node ID: " + NODE_ID); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); assertPicker(pickerCaptor.getValue(), unavailable, null); @@ -331,7 +345,8 @@ public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); + "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER + + " xDS node ID: " + NODE_ID); assertPicker(pickerCaptor.getValue(), unavailable, null); assertThat(childBalancers).isEmpty(); } @@ -379,7 +394,8 @@ public void aggregateCluster_descendantClustersRevoked() { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); + "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER + + " xDS node ID: " + NODE_ID); assertPicker(pickerCaptor.getValue(), unavailable, null); assertThat(childBalancer.shutdown).isTrue(); assertThat(childBalancers).isEmpty(); @@ -418,7 +434,8 @@ public void aggregateCluster_rootClusterRevoked() { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); + "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER + + " xDS node ID: " + NODE_ID); assertPicker(pickerCaptor.getValue(), unavailable, null); assertThat(childBalancer.shutdown).isTrue(); assertThat(childBalancers).isEmpty(); @@ -466,7 +483,8 @@ public void aggregateCluster_intermediateClusterChanges() { verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER); + "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER + + " xDS node ID: " + NODE_ID); assertPicker(pickerCaptor.getValue(), unavailable, null); assertThat(childBalancer.shutdown).isTrue(); assertThat(childBalancers).isEmpty(); @@ -507,7 +525,7 @@ public void aggregateCluster_withLoops() { Status unavailable = Status.UNAVAILABLE.withDescription( "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," - + " cluster-02.googleapis.com]"); + + " cluster-02.googleapis.com], xDS node ID: " + NODE_ID); assertPicker(pickerCaptor.getValue(), unavailable, null); } @@ -549,7 +567,7 @@ public void aggregateCluster_withLoops_afterEds() { Status unavailable = Status.UNAVAILABLE.withDescription( "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," - + " cluster-02.googleapis.com]"); + + " cluster-02.googleapis.com], xDS node ID: " + NODE_ID); assertPicker(pickerCaptor.getValue(), unavailable, null); } @@ -617,7 +635,7 @@ public void aggregateCluster_discoveryErrorBeforeChildLbCreated_returnErrorPicke eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); Status expectedError = Status.UNAVAILABLE.withDescription( "Unable to load CDS cluster-foo.googleapis.com. xDS server returned: " - + "RESOURCE_EXHAUSTED: OOM"); + + "RESOURCE_EXHAUSTED: OOM xDS node ID: " + NODE_ID); assertPicker(pickerCaptor.getValue(), expectedError, null); assertThat(childBalancers).isEmpty(); } @@ -647,7 +665,8 @@ public void aggregateCluster_discoveryErrorAfterChildLbCreated_propagateToChildL @Test public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErrorPicker() { - Status upstreamError = Status.UNAVAILABLE.withDescription("unreachable"); + Status upstreamError = Status.UNAVAILABLE.withDescription( + "unreachable xDS node ID: " + NODE_ID); loadBalancer.handleNameResolutionError(upstreamError); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); @@ -821,6 +840,11 @@ public void cancelXdsResourceWatch(XdsResourceType } } + @Override + public BootstrapInfo getBootstrapInfo() { + return BOOTSTRAP_INFO; + } + private void deliverCdsUpdate(String clusterName, CdsUpdate update) { if (watchers.containsKey(clusterName)) { List> resourceWatchers = diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 4e12a5717ae..aaaed9554f4 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.Truth.assertThat; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; @@ -29,6 +30,7 @@ import io.grpc.CallOptions; import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; +import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.InsecureChannelCredentials; import io.grpc.LoadBalancer; @@ -40,7 +42,9 @@ import io.grpc.LoadBalancer.ResolvedAddresses; import io.grpc.LoadBalancer.Subchannel; import io.grpc.LoadBalancer.SubchannelPicker; +import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; +import io.grpc.LoadBalancerRegistry; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.Status; @@ -76,9 +80,11 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Queue; import java.util.concurrent.atomic.AtomicLong; import javax.annotation.Nonnull; import javax.annotation.Nullable; @@ -145,7 +151,7 @@ public AtomicLong getOrCreate(String cluster, @Nullable String edsServiceName) { return new AtomicLong(); } }; - private final Helper helper = new FakeLbHelper(); + private final FakeLbHelper helper = new FakeLbHelper(); private PickSubchannelArgs pickSubchannelArgs = new PickSubchannelArgsImpl( TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT, new PickDetailsConsumer() {}); @@ -272,9 +278,10 @@ public void pick_addsLocalityLabel() { EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(currentState).isEqualTo(ConnectivityState.READY); PickDetailsConsumer detailsConsumer = mock(PickDetailsConsumer.class); @@ -300,9 +307,10 @@ public void recordLoadStats() { EquivalentAddressGroup endpoint = makeAddress("endpoint-addr", locality); deliverAddressesAndConfig(Collections.singletonList(endpoint), config); FakeLoadBalancer leafBalancer = Iterables.getOnlyElement(downstreamBalancers); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + Subchannel subchannel = leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); assertThat(currentState).isEqualTo(ConnectivityState.READY); PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isTrue(); @@ -357,7 +365,7 @@ public void recordLoadStats() { TOLERANCE).of(0.009); streamTracer3.streamClosed(Status.OK); - subchannel.shutdown(); // stats recorder released + subchannel.shutdown(); // stats recorder released clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); // Locality load is reported for one last time in case of loads occurred since the previous // load report. @@ -373,6 +381,95 @@ public void recordLoadStats() { assertThat(clusterStats.upstreamLocalityStatsList()).isEmpty(); // no longer reported } + // TODO(dnvindhya): This test has been added as a fix to verify + // https://github.com/grpc/grpc-java/issues/11434. + // Once we update PickFirstLeafLoadBalancer as default LoadBalancer, update the test. + @Test + public void pickFirstLoadReport_onUpdateAddress() { + Locality locality1 = + Locality.create("test-region", "test-zone", "test-subzone"); + Locality locality2 = + Locality.create("other-region", "other-zone", "other-subzone"); + + LoadBalancerProvider pickFirstProvider = LoadBalancerRegistry + .getDefaultRegistry().getProvider("pick_first"); + Object pickFirstConfig = pickFirstProvider.parseLoadBalancingPolicyConfig(new HashMap<>()) + .getConfig(); + ClusterImplConfig config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, + null, Collections.emptyList(), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(pickFirstProvider, + pickFirstConfig), + null, Collections.emptyMap()); + EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr1", locality1); + EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr2", locality2); + deliverAddressesAndConfig(Arrays.asList(endpoint1, endpoint2), config); + + // Leaf balancer is created by Pick First. Get FakeSubchannel created to update attributes + // A real subchannel would get these attributes from the connected address's EAG locality. + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); + PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + + ClientStreamTracer streamTracer1 = result.getStreamTracerFactory().newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); // first RPC call + streamTracer1.streamClosed(Status.OK); + + ClusterStats clusterStats = Iterables.getOnlyElement( + loadStatsManager.getClusterStatsReports(CLUSTER)); + UpstreamLocalityStats localityStats = Iterables.getOnlyElement( + clusterStats.upstreamLocalityStatsList()); + assertThat(localityStats.locality()).isEqualTo(locality1); + assertThat(localityStats.totalIssuedRequests()).isEqualTo(1L); + assertThat(localityStats.totalSuccessfulRequests()).isEqualTo(1L); + assertThat(localityStats.totalErrorRequests()).isEqualTo(0L); + + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.IDLE)); + loadBalancer.requestConnection(); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.CONNECTING)); + + // Faksubchannel mimics update address and returns different locality + fakeSubchannel.setConnectedEagIndex(1); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + result = currentPicker.pickSubchannel(pickSubchannelArgs); + assertThat(result.getStatus().isOk()).isTrue(); + ClientStreamTracer streamTracer2 = result.getStreamTracerFactory().newClientStreamTracer( + ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); // second RPC call + streamTracer2.streamClosed(Status.UNAVAILABLE); + + clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); + List upstreamLocalityStatsList = + clusterStats.upstreamLocalityStatsList(); + UpstreamLocalityStats localityStats1 = Iterables.find(upstreamLocalityStatsList, + upstreamLocalityStats -> upstreamLocalityStats.locality().equals(locality1)); + assertThat(localityStats1.totalIssuedRequests()).isEqualTo(0L); + assertThat(localityStats1.totalSuccessfulRequests()).isEqualTo(0L); + assertThat(localityStats1.totalErrorRequests()).isEqualTo(0L); + UpstreamLocalityStats localityStats2 = Iterables.find(upstreamLocalityStatsList, + upstreamLocalityStats -> upstreamLocalityStats.locality().equals(locality2)); + assertThat(localityStats2.totalIssuedRequests()).isEqualTo(1L); + assertThat(localityStats2.totalSuccessfulRequests()).isEqualTo(0L); + assertThat(localityStats2.totalErrorRequests()).isEqualTo(1L); + + loadBalancer.shutdown(); + loadBalancer = null; + // No more references are held for localityStats1 hence dropped. + // Locality load is reported for one last time in case of loads occurred since the previous + // load report. + clusterStats = Iterables.getOnlyElement(loadStatsManager.getClusterStatsReports(CLUSTER)); + localityStats2 = Iterables.getOnlyElement(clusterStats.upstreamLocalityStatsList()); + + assertThat(localityStats2.locality()).isEqualTo(locality2); + assertThat(localityStats2.totalIssuedRequests()).isEqualTo(0L); + assertThat(localityStats2.totalSuccessfulRequests()).isEqualTo(0L); + assertThat(localityStats2.totalErrorRequests()).isEqualTo(0L); + assertThat(localityStats2.totalRequestsInProgress()).isEqualTo(0L); + + assertThat(loadStatsManager.getClusterStatsReports(CLUSTER)).isEmpty(); + } + @Test public void dropRpcsWithRespectToLbConfigDropCategories() { LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); @@ -391,9 +488,11 @@ public void dropRpcsWithRespectToLbConfigDropCategories() { assertThat(leafBalancer.name).isEqualTo("round_robin"); assertThat(Iterables.getOnlyElement(leafBalancer.addresses).getAddresses()) .isEqualTo(endpoint.getAddresses()); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); assertThat(result.getStatus().isOk()).isFalse(); @@ -470,9 +569,11 @@ private void subtest_maxConcurrentRequests_appliedByLbConfig(boolean enableCircu assertThat(leafBalancer.name).isEqualTo("round_robin"); assertThat(Iterables.getOnlyElement(leafBalancer.addresses).getAddresses()) .isEqualTo(endpoint.getAddresses()); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); assertThat(currentState).isEqualTo(ConnectivityState.READY); for (int i = 0; i < maxConcurrentRequests; i++) { PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); @@ -562,9 +663,11 @@ private void subtest_maxConcurrentRequests_appliedWithDefaultValue( assertThat(leafBalancer.name).isEqualTo("round_robin"); assertThat(Iterables.getOnlyElement(leafBalancer.addresses).getAddresses()) .isEqualTo(endpoint.getAddresses()); - Subchannel subchannel = leafBalancer.helper.createSubchannel( - CreateSubchannelArgs.newBuilder().setAddresses(leafBalancer.addresses).build()); - leafBalancer.deliverSubchannelState(subchannel, ConnectivityState.READY); + leafBalancer.createSubChannel(); + FakeSubchannel fakeSubchannel = helper.subchannels.poll(); + fakeSubchannel.setConnectedEagIndex(0); + fakeSubchannel.updateState(ConnectivityStateInfo.forNonError(ConnectivityState.READY)); + assertThat(currentState).isEqualTo(ConnectivityState.READY); assertThat(currentState).isEqualTo(ConnectivityState.READY); for (int i = 0; i < ClusterImplLoadBalancer.DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS; i++) { PickResult result = currentPicker.pickSubchannel(pickSubchannelArgs); @@ -830,19 +933,24 @@ public void shutdown() { downstreamBalancers.remove(this); } - void deliverSubchannelState(final Subchannel subchannel, ConnectivityState state) { - SubchannelPicker picker = new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withSubchannel(subchannel); + Subchannel createSubChannel() { + Subchannel subchannel = helper.createSubchannel( + CreateSubchannelArgs.newBuilder().setAddresses(addresses).build()); + subchannel.start(infoObject -> { + if (infoObject.getState() == ConnectivityState.READY) { + helper.updateBalancingState( + ConnectivityState.READY, + new FixedResultPicker(PickResult.withSubchannel(subchannel))); } - }; - helper.updateBalancingState(state, picker); + }); + return subchannel; } } private final class FakeLbHelper extends LoadBalancer.Helper { + private final Queue subchannels = new LinkedList<>(); + @Override public SynchronizationContext getSynchronizationContext() { return syncContext; @@ -857,7 +965,9 @@ public void updateBalancingState( @Override public Subchannel createSubchannel(CreateSubchannelArgs args) { - return new FakeSubchannel(args.getAddresses(), args.getAttributes()); + FakeSubchannel subchannel = new FakeSubchannel(args.getAddresses(), args.getAttributes()); + subchannels.add(subchannel); + return subchannel; } @Override @@ -869,17 +979,27 @@ public ManagedChannel createOobChannel(EquivalentAddressGroup eag, String author public String getAuthority() { return AUTHORITY; } + + @Override + public void refreshNameResolution() {} } private static final class FakeSubchannel extends Subchannel { private final List eags; private final Attributes attrs; + private SubchannelStateListener listener; + private Attributes connectedAttributes; private FakeSubchannel(List eags, Attributes attrs) { this.eags = eags; this.attrs = attrs; } + @Override + public void start(SubchannelStateListener listener) { + this.listener = checkNotNull(listener, "listener"); + } + @Override public void shutdown() { } @@ -901,6 +1021,19 @@ public Attributes getAttributes() { @Override public void updateAddresses(List addrs) { } + + @Override + public Attributes getConnectedAddressAttributes() { + return connectedAttributes; + } + + public void updateState(ConnectivityStateInfo newState) { + listener.onSubchannelState(newState); + } + + public void setConnectedEagIndex(int eagIndex) { + this.connectedAttributes = eags.get(eagIndex).getAttributes(); + } } private final class FakeXdsClient extends XdsClient { diff --git a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerProviderTest.java index 515f6fef3ef..40943658520 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerProviderTest.java @@ -26,7 +26,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.internal.JsonParser; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import java.io.IOException; import java.util.Map; @@ -133,10 +133,9 @@ public ConfigOrError parseLoadBalancingPolicyConfig( assertThat(config.childPolicies) .containsExactly( "child1", - new PolicySelection( - lbProviderFoo, fooConfig), + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(lbProviderFoo, fooConfig), "child2", - new PolicySelection(lbProviderBar, barConfig)); + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(lbProviderBar, barConfig)); } @Test diff --git a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java index f55b0d73f79..aa0e205dd8f 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterManagerLoadBalancerTest.java @@ -52,10 +52,11 @@ import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; import io.grpc.internal.PickSubchannelArgsImpl; -import io.grpc.internal.ServiceConfigUtil.PolicySelection; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.util.GracefulSwitchLoadBalancer; import io.grpc.xds.ClusterManagerLoadBalancerProvider.ClusterManagerConfig; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; @@ -288,16 +289,27 @@ private void deliverResolvedAddresses(final Map childPolicies, b .build()); } + // Prevent ClusterManagerLB from detecting different providers even when the configuration is the + // same. + private Map, FakeLoadBalancerProvider> fakeLoadBalancerProviderCache + = new HashMap<>(); + private ClusterManagerConfig buildConfig(Map childPolicies, boolean failing) { - Map childPolicySelections = new LinkedHashMap<>(); + Map childConfigs = new LinkedHashMap<>(); for (String name : childPolicies.keySet()) { String childPolicyName = childPolicies.get(name); Object childConfig = lbConfigInventory.get(name); - PolicySelection policy = - new PolicySelection(new FakeLoadBalancerProvider(childPolicyName, failing), childConfig); - childPolicySelections.put(name, policy); + FakeLoadBalancerProvider lbProvider = + fakeLoadBalancerProviderCache.get(Arrays.asList(childPolicyName, failing)); + if (lbProvider == null) { + lbProvider = new FakeLoadBalancerProvider(childPolicyName, failing); + fakeLoadBalancerProviderCache.put(Arrays.asList(childPolicyName, failing), lbProvider); + } + Object policy = + GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig(lbProvider, childConfig); + childConfigs.put(name, policy); } - return new ClusterManagerConfig(childPolicySelections); + return new ClusterManagerConfig(childConfigs); } private static PickResult pickSubchannel(SubchannelPicker picker, String clusterName) { diff --git a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java index bf330b1007a..63b9cda043c 100644 --- a/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java +++ b/xds/src/test/java/io/grpc/xds/CsdsServiceTest.java @@ -54,6 +54,7 @@ import io.grpc.xds.client.XdsClient.ResourceMetadata; import io.grpc.xds.client.XdsClient.ResourceMetadata.ResourceMetadataStatus; import io.grpc.xds.client.XdsResourceType; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -85,6 +86,7 @@ public class CsdsServiceTest { private static final XdsResourceType CDS = XdsClusterResource.getInstance(); private static final XdsResourceType RDS = XdsRouteConfigureResource.getInstance(); private static final XdsResourceType EDS = XdsEndpointResource.getInstance(); + public static final String FAKE_CLIENT_SCOPE = "fake"; @RunWith(JUnit4.class) public static class ServiceTests { @@ -198,13 +200,13 @@ public void streamClientStatus_happyPath() { @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { // xDS client not ready on the first call, then becomes ready. if (!calledOnce) { calledOnce = true; return null; } else { - return super.get(); + return super.get(target); } } }); @@ -267,11 +269,51 @@ public void streamClientStatus_onClientError() { assertThat(responseObserver.getError()).isNull(); } + @Test + public void multipleXdsClients() { + FakeXdsClient xdsClient1 = new FakeXdsClient(); + FakeXdsClient xdsClient2 = new FakeXdsClient(); + Map clientMap = new HashMap<>(); + clientMap.put("target1", xdsClient1); + clientMap.put("target2", xdsClient2); + FakeXdsClientPoolFactory factory = new FakeXdsClientPoolFactory(clientMap); + CsdsService csdsService = new CsdsService(factory); + grpcServerRule.getServiceRegistry().addService(csdsService); + + StreamRecorder responseObserver = StreamRecorder.create(); + StreamObserver requestObserver = + csdsAsyncStub.streamClientStatus(responseObserver); + + requestObserver.onNext(REQUEST); + requestObserver.onCompleted(); + + List responses = responseObserver.getValues(); + assertThat(responses).hasSize(1); + Collection targets = verifyMultiResponse(responses.get(0), 2); + assertThat(targets).containsExactly("target1", "target2"); + responseObserver.onCompleted(); + } + private void verifyResponse(ClientStatusResponse response) { assertThat(response.getConfigCount()).isEqualTo(1); ClientConfig clientConfig = response.getConfig(0); verifyClientConfigNode(clientConfig); verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); + assertThat(clientConfig.getClientScope()).isEmpty(); + } + + private Collection verifyMultiResponse(ClientStatusResponse response, int numExpected) { + assertThat(response.getConfigCount()).isEqualTo(numExpected); + + List clientScopes = new ArrayList<>(); + for (int i = 0; i < numExpected; i++) { + ClientConfig clientConfig = response.getConfig(i); + verifyClientConfigNode(clientConfig); + verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); + clientScopes.add(clientConfig.getClientScope()); + } + + return clientScopes; } private void verifyRequestInvalidResponseStatus(Status status) { @@ -350,9 +392,11 @@ public Map> getSubscribedResourceTypesWithTypeUrl() { ); } }; - ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(fakeXdsClient); + ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(fakeXdsClient, + FAKE_CLIENT_SCOPE); verifyClientConfigNode(clientConfig); + assertThat(clientConfig.getClientScope()).isEqualTo(FAKE_CLIENT_SCOPE); // Minimal verification to confirm that the data/metadata XdsClient provides, // is propagated to the correct resource types. @@ -390,9 +434,11 @@ public Map> getSubscribedResourceTypesWithTypeUrl() { @Test public void getClientConfigForXdsClient_noSubscribedResources() throws InterruptedException { - ClientConfig clientConfig = CsdsService.getClientConfigForXdsClient(XDS_CLIENT_NO_RESOURCES); + ClientConfig clientConfig = + CsdsService.getClientConfigForXdsClient(XDS_CLIENT_NO_RESOURCES, FAKE_CLIENT_SCOPE); verifyClientConfigNode(clientConfig); verifyClientConfigNoResources(XDS_CLIENT_NO_RESOURCES, clientConfig); + assertThat(clientConfig.getClientScope()).isEqualTo(FAKE_CLIENT_SCOPE); } } @@ -460,22 +506,35 @@ public Collection getSubscribedResources(ServerInfo serverInfo, public Map> getSubscribedResourceTypesWithTypeUrl() { return ImmutableMap.of(); } + } private static class FakeXdsClientPoolFactory implements XdsClientPoolFactory { - @Nullable private final XdsClient xdsClient; + private final Map xdsClientMap = new HashMap<>(); + private boolean isOldStyle + ; private FakeXdsClientPoolFactory(@Nullable XdsClient xdsClient) { - this.xdsClient = xdsClient; + if (xdsClient != null) { + xdsClientMap.put("", xdsClient); + } + isOldStyle = true; + } + + private FakeXdsClientPoolFactory(Map xdsClientMap) { + this.xdsClientMap.putAll(xdsClientMap); + isOldStyle = false; } @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { + String targetToUse = isOldStyle ? "" : target; + return new ObjectPool() { @Override public XdsClient getObject() { - return xdsClient; + return xdsClientMap.get(targetToUse); } @Override @@ -485,13 +544,18 @@ public XdsClient returnObject(Object object) { }; } + @Override + public List getTargets() { + return new ArrayList<>(xdsClientMap.keySet()); + } + @Override public void setBootstrapOverride(Map bootstrap) { throw new UnsupportedOperationException("Should not be called"); } @Override - public ObjectPool getOrCreate() { + public ObjectPool getOrCreate(String target) { throw new UnsupportedOperationException("Should not be called"); } } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java index 6b04edcb9b8..d41630cdb4a 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java @@ -340,8 +340,7 @@ public XdsTransport create(ServerInfo serverInfo) { } }; - xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, - ignoreResourceDeletion()); + xdsServerInfo = ServerInfo.create(SERVER_URI, CHANNEL_CREDENTIALS, ignoreResourceDeletion()); BootstrapInfo bootstrapInfo = Bootstrapper.BootstrapInfo.builder() .servers(Collections.singletonList(xdsServerInfo)) @@ -2974,6 +2973,7 @@ public void flowControlAbsent() throws Exception { anotherWatcher, fakeWatchClock.getScheduledExecutorService()); verifyResourceMetadataRequested(CDS, CDS_RESOURCE); verifyResourceMetadataRequested(CDS, anotherCdsResource); + DiscoveryRpcCall call = resourceDiscoveryCalls.poll(); call.verifyRequest(CDS, Arrays.asList(CDS_RESOURCE, anotherCdsResource), "", "", NODE); assertThat(fakeWatchClock.runDueTasks()).isEqualTo(2); diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java index 2b2ce5cbd72..40a9bff514f 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplV3Test.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -144,7 +145,8 @@ public StreamObserver streamAggregatedResources( assertThat(adsEnded.get()).isTrue(); // ensure previous call was ended adsEnded.set(false); @SuppressWarnings("unchecked") - StreamObserver requestObserver = mock(StreamObserver.class); + StreamObserver requestObserver = + mock(StreamObserver.class, delegatesTo(new MockStreamObserver())); DiscoveryRpcCall call = new DiscoveryRpcCallV3(requestObserver, responseObserver); resourceDiscoveryCalls.offer(call); Context.current().addListener( @@ -874,6 +876,19 @@ public boolean matches(DiscoveryRequest argument) { } return node.equals(argument.getNode()); } + + @Override + public String toString() { + return "DiscoveryRequestMatcher{" + + "node=" + node + + ", versionInfo='" + versionInfo + '\'' + + ", typeUrl='" + typeUrl + '\'' + + ", resources=" + resources + + ", responseNonce='" + responseNonce + '\'' + + ", errorCode=" + errorCode + + ", errorMessages=" + errorMessages + + '}'; + } } /** @@ -901,4 +916,23 @@ public boolean matches(LoadStatsRequest argument) { return actual.equals(expected); } } + + private static class MockStreamObserver implements StreamObserver { + private final List requests = new ArrayList<>(); + + @Override + public void onNext(DiscoveryRequest value) { + requests.add(value); + } + + @Override + public void onError(Throwable t) { + // Ignore + } + + @Override + public void onCompleted() { + // Ignore + } + } } diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index de871cdd8f1..047ba71bbe0 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -66,7 +66,6 @@ import io.grpc.testing.TestMethodDescriptors; import io.grpc.util.AbstractTestHelper; import io.grpc.util.MultiChildLoadBalancer.ChildLbState; -import io.grpc.xds.RingHashLoadBalancer.RingHashChildLbState; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import java.lang.Thread.UncaughtExceptionHandler; import java.net.SocketAddress; @@ -177,8 +176,7 @@ public void subchannelNotAutoReconnectAfterReenteringIdle() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); verify(helper).updateBalancingState(eq(IDLE), pickerCaptor.capture()); - RingHashChildLbState childLbState = - (RingHashChildLbState) loadBalancer.getChildLbStates().iterator().next(); + ChildLbState childLbState = loadBalancer.getChildLbStates().iterator().next(); assertThat(subchannels.get(Collections.singletonList(childLbState.getEag()))).isNull(); // Picking subchannel triggers connection. @@ -422,7 +420,7 @@ public void skipFailingHosts_pickNextNonFailingHost() { assertThat(addressesAcceptanceStatus.isOk()).isTrue(); // Create subchannel for the first address - ((RingHashChildLbState) loadBalancer.getChildLbStateEag(servers.get(0))).getCurrentPicker() + loadBalancer.getChildLbStateEag(servers.get(0)).getCurrentPicker() .pickSubchannel(getDefaultPickSubchannelArgs(hashFunc.hashVoid())); verifyConnection(1); diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index 0687b51aea6..ee164938b2d 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -51,6 +51,7 @@ public class SharedXdsClientPoolProviderTest { @Rule public final ExpectedException thrown = ExpectedException.none(); private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build(); + private static final String DUMMY_TARGET = "dummy"; @Mock private GrpcBootstrapperImpl bootstrapper; @@ -63,8 +64,8 @@ public void noServer() throws XdsInitializationException { SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); thrown.expect(XdsInitializationException.class); thrown.expectMessage("No xDS server provided"); - provider.getOrCreate(); - assertThat(provider.get()).isNull(); + provider.getOrCreate(DUMMY_TARGET); + assertThat(provider.get(DUMMY_TARGET)).isNull(); } @Test @@ -75,12 +76,12 @@ public void sharedXdsClientObjectPool() throws XdsInitializationException { when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); - assertThat(provider.get()).isNull(); - ObjectPool xdsClientPool = provider.getOrCreate(); + assertThat(provider.get(DUMMY_TARGET)).isNull(); + ObjectPool xdsClientPool = provider.getOrCreate(DUMMY_TARGET); verify(bootstrapper).bootstrap(); - assertThat(provider.getOrCreate()).isSameInstanceAs(xdsClientPool); - assertThat(provider.get()).isNotNull(); - assertThat(provider.get()).isSameInstanceAs(xdsClientPool); + assertThat(provider.getOrCreate(DUMMY_TARGET)).isSameInstanceAs(xdsClientPool); + assertThat(provider.get(DUMMY_TARGET)).isNotNull(); + assertThat(provider.get(DUMMY_TARGET)).isSameInstanceAs(xdsClientPool); verifyNoMoreInteractions(bootstrapper); } @@ -90,7 +91,7 @@ public void refCountedXdsClientObjectPool_delayedCreation() { BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); RefCountedXdsClientObjectPool xdsClientPool = - new RefCountedXdsClientObjectPool(bootstrapInfo); + new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET); assertThat(xdsClientPool.getXdsClientForTest()).isNull(); XdsClient xdsClient = xdsClientPool.getObject(); assertThat(xdsClientPool.getXdsClientForTest()).isNotNull(); @@ -103,7 +104,7 @@ public void refCountedXdsClientObjectPool_refCounted() { BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); RefCountedXdsClientObjectPool xdsClientPool = - new RefCountedXdsClientObjectPool(bootstrapInfo); + new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET); // getObject once XdsClient xdsClient = xdsClientPool.getObject(); assertThat(xdsClient).isNotNull(); @@ -123,7 +124,7 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh BootstrapInfo bootstrapInfo = BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); RefCountedXdsClientObjectPool xdsClientPool = - new RefCountedXdsClientObjectPool(bootstrapInfo); + new RefCountedXdsClientObjectPool(bootstrapInfo, DUMMY_TARGET); XdsClient xdsClient1 = xdsClientPool.getObject(); assertThat(xdsClientPool.returnObject(xdsClient1)).isNull(); assertThat(xdsClient1.isShutDown()).isTrue(); diff --git a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java index dd98f1e1ae6..05ad1f56ece 100644 --- a/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/WeightedRoundRobinLoadBalancerTest.java @@ -244,7 +244,7 @@ public void wrrLifeCycle() { String weightedPickerStr = weightedPicker.toString(); assertThat(weightedPickerStr).contains("enableOobLoadReport=false"); assertThat(weightedPickerStr).contains("errorUtilizationPenalty=1.0"); - assertThat(weightedPickerStr).contains("list="); + assertThat(weightedPickerStr).contains("pickers="); WeightedChildLbState weightedChild1 = (WeightedChildLbState) getChild(weightedPicker, 0); WeightedChildLbState weightedChild2 = (WeightedChildLbState) getChild(weightedPicker, 1); diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java index 149c1d6170d..0b8e89de721 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientFederationTest.java @@ -72,12 +72,13 @@ public class XdsClientFederationTest { private ObjectPool xdsClientPool; private XdsClient xdsClient; + private static final String DUMMY_TARGET = "dummy"; @Before public void setUp() throws XdsInitializationException { SharedXdsClientPoolProvider clientPoolProvider = new SharedXdsClientPoolProvider(); clientPoolProvider.setBootstrapOverride(defaultBootstrapOverride()); - xdsClientPool = clientPoolProvider.getOrCreate(); + xdsClientPool = clientPoolProvider.getOrCreate(DUMMY_TARGET); xdsClient = xdsClientPool.getObject(); } diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index 28871850e72..24c2a43b83a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -95,11 +95,15 @@ import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsResourceType; import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -166,15 +170,22 @@ public ConfigOrError parseServiceConfig(Map rawServiceConfig) { private XdsNameResolver resolver; private TestCall testCall; private boolean originalEnableTimeout; + private URI targetUri; @Before public void setUp() { + try { + targetUri = new URI(AUTHORITY); + } catch (URISyntaxException e) { + targetUri = null; + } + originalEnableTimeout = XdsNameResolver.enableTimeout; XdsNameResolver.enableTimeout = true; FilterRegistry filterRegistry = FilterRegistry.newRegistry().register( new FaultFilter(mockRandom, new AtomicLong()), RouterFilter.INSTANCE); - resolver = new XdsNameResolver(null, AUTHORITY, null, + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, filterRegistry, null); } @@ -199,16 +210,22 @@ public void setBootstrapOverride(Map bootstrap) { @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { throw new UnsupportedOperationException("Should not be called"); } @Override - public ObjectPool getOrCreate() throws XdsInitializationException { + public ObjectPool getOrCreate(String target) throws XdsInitializationException { throw new XdsInitializationException("Fail to read bootstrap file"); } + + @Override + public List getTargets() { + return null; + } }; - resolver = new XdsNameResolver(null, AUTHORITY, null, + + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); @@ -221,7 +238,7 @@ public ObjectPool getOrCreate() throws XdsInitializationException { @Test public void resolving_withTargetAuthorityNotFound() { - resolver = new XdsNameResolver( + resolver = new XdsNameResolver(targetUri, "notfound.google.com", AUTHORITY, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); @@ -243,7 +260,7 @@ public void resolving_noTargetAuthority_templateWithoutXdstp() { String serviceAuthority = "[::FFFF:129.144.52.38]:80"; expectedLdsResourceName = "[::FFFF:129.144.52.38]:80/id=1"; resolver = new XdsNameResolver( - null, serviceAuthority, null, serviceConfigParser, syncContext, + targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); @@ -264,7 +281,7 @@ public void resolving_noTargetAuthority_templateWithXdstp() { "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "%5B::FFFF:129.144.52.38%5D:80?id=1"; resolver = new XdsNameResolver( - null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, + targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); verify(mockListener, never()).onError(any(Status.class)); @@ -284,7 +301,7 @@ public void resolving_noTargetAuthority_xdstpWithMultipleSlashes() { "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "path/to/service?id=1"; resolver = new XdsNameResolver( - null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, + targetUri, null, serviceAuthority, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); @@ -311,7 +328,7 @@ public void resolving_targetAuthorityInAuthoritiesMap() { .build(); expectedLdsResourceName = "xdstp://xds.authority.com/envoy.config.listener.v3.Listener/" + "%5B::FFFF:129.144.52.38%5D:80?bar=2&foo=1"; // query param canonified - resolver = new XdsNameResolver( + resolver = new XdsNameResolver(targetUri, "xds.authority.com", serviceAuthority, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); @@ -343,7 +360,7 @@ public void resolving_ldsResourceUpdateRdsName() { .clientDefaultListenerResourceNameTemplate("test-%s") .node(Node.newBuilder().build()) .build(); - resolver = new XdsNameResolver(null, AUTHORITY, null, + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); // use different ldsResourceName and service authority. The virtualhost lookup should use @@ -524,7 +541,7 @@ public void resolving_matchingVirtualHostNotFound_matchingOverrideAuthority() { Collections.singletonList(route), ImmutableMap.of()); - resolver = new XdsNameResolver(null, AUTHORITY, "random", + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, "random", serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); @@ -547,7 +564,7 @@ public void resolving_matchingVirtualHostNotFound_notMatchingOverrideAuthority() Collections.singletonList(route), ImmutableMap.of()); - resolver = new XdsNameResolver(null, AUTHORITY, "random", + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, "random", serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); @@ -558,7 +575,7 @@ public void resolving_matchingVirtualHostNotFound_notMatchingOverrideAuthority() @Test public void resolving_matchingVirtualHostNotFoundForOverrideAuthority() { - resolver = new XdsNameResolver(null, AUTHORITY, AUTHORITY, + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, AUTHORITY, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); @@ -643,8 +660,8 @@ public void resolved_fallbackToHttpMaxStreamDurationAsTimeout() { public void retryPolicyInPerMethodConfigGeneratedByResolverIsValid() { ServiceConfigParser realParser = new ScParser( true, 5, 5, new AutoConfiguredLoadBalancerFactory("pick-first")); - resolver = new XdsNameResolver(null, AUTHORITY, null, realParser, syncContext, scheduler, - xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, realParser, syncContext, + scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); RetryPolicy retryPolicy = RetryPolicy.create( @@ -847,7 +864,7 @@ public void resolved_rpcHashingByChannelId() { resolver.shutdown(); reset(mockListener); when(mockRandom.nextLong()).thenReturn(123L); - resolver = new XdsNameResolver(null, AUTHORITY, null, serviceConfigParser, + resolver = new XdsNameResolver(targetUri, null, AUTHORITY, null, serviceConfigParser, syncContext, scheduler, xdsClientPoolFactory, mockRandom, FilterRegistry.getDefaultRegistry(), null); resolver.start(mockListener); @@ -1896,17 +1913,20 @@ private PickSubchannelArgs newPickSubchannelArgs( } private final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { + Set targets = new HashSet<>(); + @Override public void setBootstrapOverride(Map bootstrap) {} @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { throw new UnsupportedOperationException("Should not be called"); } @Override - public ObjectPool getOrCreate() throws XdsInitializationException { + public ObjectPool getOrCreate(String target) throws XdsInitializationException { + targets.add(target); return new ObjectPool() { @Override public XdsClient getObject() { @@ -1919,6 +1939,16 @@ public XdsClient returnObject(Object object) { } }; } + + @Override + public List getTargets() { + if (targets.isEmpty()) { + List targetList = new ArrayList<>(); + targetList.add(targetUri.toString()); + return targetList; + } + return new ArrayList<>(targets); + } } private class FakeXdsClient extends XdsClient { diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 5d59e97335e..791318c5355 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -146,12 +146,12 @@ public void setBootstrapOverride(Map bootstrap) { @Override @Nullable - public ObjectPool get() { + public ObjectPool get(String target) { throw new UnsupportedOperationException("Should not be called"); } @Override - public ObjectPool getOrCreate() throws XdsInitializationException { + public ObjectPool getOrCreate(String target) throws XdsInitializationException { return new ObjectPool() { @Override public XdsClient getObject() { @@ -165,6 +165,11 @@ public XdsClient returnObject(Object object) { } }; } + + @Override + public List getTargets() { + return Collections.singletonList("fake-target"); + } } static final class FakeXdsClient extends XdsClient { diff --git a/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java b/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java index c51327dc84d..cc12e3863ba 100644 --- a/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java +++ b/xds/src/test/java/io/grpc/xds/XdsTestControlPlaneService.java @@ -135,12 +135,15 @@ public void run() { new Object[]{value.getResourceNamesList(), value.getErrorDetail()}); return; } + String resourceType = value.getTypeUrl(); - if (!value.getResponseNonce().isEmpty() - && !String.valueOf(xdsNonces.get(resourceType)).equals(value.getResponseNonce())) { + if (!value.getResponseNonce().isEmpty() && xdsNonces.containsKey(resourceType) + && !String.valueOf(xdsNonces.get(resourceType).get(responseObserver)) + .equals(value.getResponseNonce())) { logger.log(Level.FINE, "Resource nonce does not match, ignore."); return; } + Set requestedResourceNames = new HashSet<>(value.getResourceNamesList()); if (subscribers.get(resourceType).containsKey(responseObserver) && subscribers.get(resourceType).get(responseObserver) @@ -149,9 +152,11 @@ public void run() { value.getResourceNamesList()); return; } + if (!xdsNonces.get(resourceType).containsKey(responseObserver)) { xdsNonces.get(resourceType).put(responseObserver, new AtomicInteger(0)); } + DiscoveryResponse response = generateResponse(resourceType, String.valueOf(xdsVersions.get(resourceType)), String.valueOf(xdsNonces.get(resourceType).get(responseObserver)),