diff --git a/shell/platform/android/io/flutter/embedding/engine/FlutterJNI.java b/shell/platform/android/io/flutter/embedding/engine/FlutterJNI.java index 17f918174af50..6b817e35aa565 100644 --- a/shell/platform/android/io/flutter/embedding/engine/FlutterJNI.java +++ b/shell/platform/android/io/flutter/embedding/engine/FlutterJNI.java @@ -814,7 +814,7 @@ public void setPlatformMessageHandler(@Nullable PlatformMessageHandler platformM @SuppressWarnings("unused") @VisibleForTesting public void handlePlatformMessage( - @NonNull final String channel, byte[] message, final int replyId) { + @NonNull final String channel, ByteBuffer message, final int replyId) { if (platformMessageHandler != null) { platformMessageHandler.handleMessageFromDart(channel, message, replyId); } @@ -825,7 +825,7 @@ public void handlePlatformMessage( // Called by native to respond to a platform message that we sent. // TODO(mattcarroll): determine if reply is nonull or nullable @SuppressWarnings("unused") - private void handlePlatformMessageResponse(int replyId, byte[] reply) { + private void handlePlatformMessageResponse(int replyId, ByteBuffer reply) { if (platformMessageHandler != null) { platformMessageHandler.handlePlatformMessageResponse(replyId, reply); } diff --git a/shell/platform/android/io/flutter/embedding/engine/dart/DartMessenger.java b/shell/platform/android/io/flutter/embedding/engine/dart/DartMessenger.java index fc03bb1f3c6f6..8a5e44f637a1a 100644 --- a/shell/platform/android/io/flutter/embedding/engine/dart/DartMessenger.java +++ b/shell/platform/android/io/flutter/embedding/engine/dart/DartMessenger.java @@ -75,14 +75,18 @@ public void send( @Override public void handleMessageFromDart( - @NonNull final String channel, @Nullable byte[] message, final int replyId) { + @NonNull final String channel, @Nullable ByteBuffer message, final int replyId) { Log.v(TAG, "Received message from Dart over channel '" + channel + "'"); BinaryMessenger.BinaryMessageHandler handler = messageHandlers.get(channel); if (handler != null) { try { Log.v(TAG, "Deferring to registered handler to process message."); - final ByteBuffer buffer = (message == null ? null : ByteBuffer.wrap(message)); - handler.onMessage(buffer, new Reply(flutterJNI, replyId)); + handler.onMessage(message, new Reply(flutterJNI, replyId)); + if (message != null && message.isDirect()) { + // This ensures that if a user retains an instance to the ByteBuffer and it happens to + // be direct they will get a deterministic error. + message.limit(0); + } } catch (Exception ex) { Log.e(TAG, "Uncaught exception in binary message listener", ex); flutterJNI.invokePlatformMessageEmptyResponseCallback(replyId); @@ -96,13 +100,18 @@ public void handleMessageFromDart( } @Override - public void handlePlatformMessageResponse(int replyId, @Nullable byte[] reply) { + public void handlePlatformMessageResponse(int replyId, @Nullable ByteBuffer reply) { Log.v(TAG, "Received message reply from Dart."); BinaryMessenger.BinaryReply callback = pendingReplies.remove(replyId); if (callback != null) { try { Log.v(TAG, "Invoking registered callback for reply from Dart."); - callback.reply(reply == null ? null : ByteBuffer.wrap(reply)); + callback.reply(reply); + if (reply != null && reply.isDirect()) { + // This ensures that if a user retains an instance to the ByteBuffer and it happens to + // be direct they will get a deterministic error. + reply.limit(0); + } } catch (Exception ex) { Log.e(TAG, "Uncaught exception in binary message reply handler", ex); } catch (Error err) { diff --git a/shell/platform/android/io/flutter/embedding/engine/dart/PlatformMessageHandler.java b/shell/platform/android/io/flutter/embedding/engine/dart/PlatformMessageHandler.java index ed8a5b044daad..22bb6f195666f 100644 --- a/shell/platform/android/io/flutter/embedding/engine/dart/PlatformMessageHandler.java +++ b/shell/platform/android/io/flutter/embedding/engine/dart/PlatformMessageHandler.java @@ -6,11 +6,12 @@ import androidx.annotation.NonNull; import androidx.annotation.Nullable; +import java.nio.ByteBuffer; /** Handler that receives messages from Dart code. */ public interface PlatformMessageHandler { void handleMessageFromDart( - @NonNull final String channel, @Nullable byte[] message, final int replyId); + @NonNull final String channel, @Nullable ByteBuffer message, final int replyId); - void handlePlatformMessageResponse(int replyId, @Nullable byte[] reply); + void handlePlatformMessageResponse(int replyId, @Nullable ByteBuffer reply); } diff --git a/shell/platform/android/io/flutter/plugin/common/BinaryCodec.java b/shell/platform/android/io/flutter/plugin/common/BinaryCodec.java index 348de24c2be4c..0bcae697f50d1 100644 --- a/shell/platform/android/io/flutter/plugin/common/BinaryCodec.java +++ b/shell/platform/android/io/flutter/plugin/common/BinaryCodec.java @@ -18,8 +18,31 @@ public final class BinaryCodec implements MessageCodec { // This codec must match the Dart codec of the same name in package flutter/services. public static final BinaryCodec INSTANCE = new BinaryCodec(); + /** + * A BinaryCodec that returns direct ByteBuffers from `decodeMessage` for better performance. + * + * @see BinaryCodec.BinaryCodec(boolean) + */ + public static final BinaryCodec INSTANCE_DIRECT = new BinaryCodec(true); - private BinaryCodec() {} + private final boolean returnsDirectByteBufferFromDecoding; + + private BinaryCodec() { + this.returnsDirectByteBufferFromDecoding = false; + } + + /** + * A constructor for BinaryCodec. + * + * @param returnsDirectByteBufferFromDecoding `true` means that the Codec will return direct + * ByteBuffers from `decodeMessage`. Direct ByteBuffers will have better performance but will + * be invalid beyond the scope of the `decodeMessage` call. `false` means Flutter will copy + * the encoded message to Java's memory, so the ByteBuffer will be valid beyond the + * decodeMessage call, at the cost of a copy. + */ + private BinaryCodec(boolean returnsDirectByteBufferFromDecoding) { + this.returnsDirectByteBufferFromDecoding = returnsDirectByteBufferFromDecoding; + } @Override public ByteBuffer encodeMessage(ByteBuffer message) { @@ -28,6 +51,13 @@ public ByteBuffer encodeMessage(ByteBuffer message) { @Override public ByteBuffer decodeMessage(ByteBuffer message) { - return message; + if (returnsDirectByteBufferFromDecoding) { + return message; + } else { + ByteBuffer result = ByteBuffer.allocate(message.capacity()); + result.put(message); + result.rewind(); + return result; + } } } diff --git a/shell/platform/android/io/flutter/plugin/common/MessageCodec.java b/shell/platform/android/io/flutter/plugin/common/MessageCodec.java index 2e9d88718f9a7..a87214a14f38e 100644 --- a/shell/platform/android/io/flutter/plugin/common/MessageCodec.java +++ b/shell/platform/android/io/flutter/plugin/common/MessageCodec.java @@ -26,6 +26,10 @@ public interface MessageCodec { /** * Decodes the specified message from binary. * + *

Warning: The ByteBuffer is "direct" and it won't be valid beyond this call. Storing + * the ByteBuffer and using it later and will lead to a {@code java.nio.BufferUnderflowException}. + * If you want to retain the data you'll need to copy it. + * * @param message the {@link ByteBuffer} message, possibly null. * @return a T value representation of the bytes between the given buffer's current position and * its limit, or null, if message is null. diff --git a/shell/platform/android/platform_view_android_jni_impl.cc b/shell/platform/android/platform_view_android_jni_impl.cc index 16e58e46ce280..dc0cc7ec39836 100644 --- a/shell/platform/android/platform_view_android_jni_impl.cc +++ b/shell/platform/android/platform_view_android_jni_impl.cc @@ -828,7 +828,7 @@ bool RegisterApi(JNIEnv* env) { g_handle_platform_message_method = env->GetMethodID(g_flutter_jni_class->obj(), "handlePlatformMessage", - "(Ljava/lang/String;[BI)V"); + "(Ljava/lang/String;Ljava/nio/ByteBuffer;I)V"); if (g_handle_platform_message_method == nullptr) { FML_LOG(ERROR) << "Could not locate handlePlatformMessage method"; @@ -836,7 +836,8 @@ bool RegisterApi(JNIEnv* env) { } g_handle_platform_message_response_method = env->GetMethodID( - g_flutter_jni_class->obj(), "handlePlatformMessageResponse", "(I[B)V"); + g_flutter_jni_class->obj(), "handlePlatformMessageResponse", + "(ILjava/nio/ByteBuffer;)V"); if (g_handle_platform_message_response_method == nullptr) { FML_LOG(ERROR) << "Could not locate handlePlatformMessageResponse method"; @@ -1107,11 +1108,10 @@ void PlatformViewAndroidJNIImpl::FlutterViewHandlePlatformMessage( fml::jni::StringToJavaString(env, message->channel()); if (message->hasData()) { - fml::jni::ScopedJavaLocalRef message_array( - env, env->NewByteArray(message->data().GetSize())); - env->SetByteArrayRegion( - message_array.obj(), 0, message->data().GetSize(), - reinterpret_cast(message->data().GetMapping())); + fml::jni::ScopedJavaLocalRef message_array( + env, env->NewDirectByteBuffer( + const_cast(message->data().GetMapping()), + message->data().GetSize())); env->CallVoidMethod(java_object.obj(), g_handle_platform_message_method, java_channel.obj(), message_array.obj(), responseId); } else { @@ -1141,10 +1141,9 @@ void PlatformViewAndroidJNIImpl::FlutterViewHandlePlatformMessageResponse( nullptr); } else { // Convert the vector to a Java byte array. - fml::jni::ScopedJavaLocalRef data_array( - env, env->NewByteArray(data->GetSize())); - env->SetByteArrayRegion(data_array.obj(), 0, data->GetSize(), - reinterpret_cast(data->GetMapping())); + fml::jni::ScopedJavaLocalRef data_array( + env, env->NewDirectByteBuffer(const_cast(data->GetMapping()), + data->GetSize())); env->CallVoidMethod(java_object.obj(), g_handle_platform_message_response_method, responseId, diff --git a/shell/platform/android/test/io/flutter/embedding/engine/dart/DartMessengerTest.java b/shell/platform/android/test/io/flutter/embedding/engine/dart/DartMessengerTest.java index 3529448fc85b6..af4008e98c72e 100644 --- a/shell/platform/android/test/io/flutter/embedding/engine/dart/DartMessengerTest.java +++ b/shell/platform/android/test/io/flutter/embedding/engine/dart/DartMessengerTest.java @@ -1,11 +1,13 @@ package io.flutter.embedding.engine.dart; +import static junit.framework.TestCase.assertEquals; import static junit.framework.TestCase.assertNotNull; import static junit.framework.TestCase.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import io.flutter.embedding.engine.FlutterJNI; +import io.flutter.plugin.common.BinaryMessenger; import io.flutter.plugin.common.BinaryMessenger.BinaryMessageHandler; import java.nio.ByteBuffer; import org.junit.Test; @@ -46,9 +48,80 @@ public void itHandlesErrors() { .onMessage(any(ByteBuffer.class), any(DartMessenger.Reply.class)); messenger.setMessageHandler("test", throwingHandler); - messenger.handleMessageFromDart("test", new byte[] {}, 0); + messenger.handleMessageFromDart("test", ByteBuffer.allocate(0), 0); assertNotNull(reportingHandler.latestException); assertTrue(reportingHandler.latestException instanceof AssertionError); currentThread.setUncaughtExceptionHandler(savedHandler); } + + @Test + public void givesDirectByteBuffer() { + // Setup test. + final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class); + final DartMessenger messenger = new DartMessenger(fakeFlutterJni); + final String channel = "foobar"; + final boolean[] wasDirect = {false}; + final BinaryMessenger.BinaryMessageHandler handler = + (message, reply) -> { + wasDirect[0] = message.isDirect(); + }; + messenger.setMessageHandler(channel, handler); + final ByteBuffer message = ByteBuffer.allocateDirect(4 * 2); + message.rewind(); + message.putChar('a'); + message.putChar('b'); + message.putChar('c'); + message.putChar('d'); + messenger.handleMessageFromDart(channel, message, /*replyId=*/ 123); + assertTrue(wasDirect[0]); + } + + @Test + public void directByteBufferLimitZeroAfterUsage() { + // Setup test. + final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class); + final DartMessenger messenger = new DartMessenger(fakeFlutterJni); + final String channel = "foobar"; + final ByteBuffer[] byteBuffers = {null}; + final int bufferSize = 4 * 2; + final BinaryMessenger.BinaryMessageHandler handler = + (message, reply) -> { + byteBuffers[0] = message; + assertEquals(bufferSize, byteBuffers[0].limit()); + }; + messenger.setMessageHandler(channel, handler); + final ByteBuffer message = ByteBuffer.allocateDirect(bufferSize); + message.rewind(); + message.putChar('a'); + message.putChar('b'); + message.putChar('c'); + message.putChar('d'); + messenger.handleMessageFromDart(channel, message, /*replyId=*/ 123); + assertNotNull(byteBuffers[0]); + assertTrue(byteBuffers[0].isDirect()); + assertEquals(0, byteBuffers[0].limit()); + } + + @Test + public void directByteBufferLimitZeroAfterReply() { + // Setup test. + final FlutterJNI fakeFlutterJni = mock(FlutterJNI.class); + final DartMessenger messenger = new DartMessenger(fakeFlutterJni); + final ByteBuffer message = ByteBuffer.allocateDirect(4 * 2); + final String channel = "foobar"; + message.rewind(); + message.putChar('a'); + message.putChar('b'); + message.putChar('c'); + message.putChar('d'); + final ByteBuffer[] byteBuffers = {null}; + BinaryMessenger.BinaryReply callback = + (reply) -> { + assertTrue(reply.isDirect()); + byteBuffers[0] = reply; + }; + messenger.send(channel, null, callback); + messenger.handlePlatformMessageResponse(1, message); + assertEquals(0, byteBuffers[0].limit()); + } } diff --git a/shell/platform/android/test/io/flutter/plugin/platform/PlatformViewsControllerTest.java b/shell/platform/android/test/io/flutter/plugin/platform/PlatformViewsControllerTest.java index 1b75c81705f94..12db03678b61c 100644 --- a/shell/platform/android/test/io/flutter/plugin/platform/PlatformViewsControllerTest.java +++ b/shell/platform/android/test/io/flutter/plugin/platform/PlatformViewsControllerTest.java @@ -630,12 +630,10 @@ public void checkInputConnectionProxy__falseIfViewIsNull() { assertFalse(shouldProxying); } - private static byte[] encodeMethodCall(MethodCall call) { + private static ByteBuffer encodeMethodCall(MethodCall call) { final ByteBuffer buffer = StandardMethodCodec.INSTANCE.encodeMethodCall(call); buffer.rewind(); - final byte[] dest = new byte[buffer.remaining()]; - buffer.get(dest); - return dest; + return buffer; } private static void createPlatformView(