diff --git a/shell/platform/android/io/flutter/embedding/engine/FlutterJNI.java b/shell/platform/android/io/flutter/embedding/engine/FlutterJNI.java index b4aaff8bcde33..17f918174af50 100644 --- a/shell/platform/android/io/flutter/embedding/engine/FlutterJNI.java +++ b/shell/platform/android/io/flutter/embedding/engine/FlutterJNI.java @@ -902,8 +902,11 @@ private native void nativeInvokePlatformMessageEmptyResponseCallback( // TODO(mattcarroll): differentiate between channel responses and platform responses. @UiThread public void invokePlatformMessageResponseCallback( - int responseId, @Nullable ByteBuffer message, int position) { + int responseId, @NonNull ByteBuffer message, int position) { ensureRunningOnMainThread(); + if (!message.isDirect()) { + throw new IllegalArgumentException("Expected a direct ByteBuffer."); + } if (isAttached()) { nativeInvokePlatformMessageResponseCallback( nativeShellHolderId, responseId, message, position); diff --git a/shell/platform/android/platform_view_android.cc b/shell/platform/android/platform_view_android.cc index 1420fbebe4b9b..aa83d9cacbfeb 100644 --- a/shell/platform/android/platform_view_android.cc +++ b/shell/platform/android/platform_view_android.cc @@ -220,6 +220,7 @@ void PlatformViewAndroid::InvokePlatformMessageResponseCallback( return; uint8_t* response_data = static_cast(env->GetDirectBufferAddress(java_response_data)); + FML_DCHECK(response_data != nullptr); std::vector response = std::vector( response_data, response_data + java_response_position); auto message_response = std::move(it->second); diff --git a/shell/platform/android/test/io/flutter/embedding/engine/FlutterJNITest.java b/shell/platform/android/test/io/flutter/embedding/engine/FlutterJNITest.java index 5f8eeea590172..5f057c36609a0 100644 --- a/shell/platform/android/test/io/flutter/embedding/engine/FlutterJNITest.java +++ b/shell/platform/android/test/io/flutter/embedding/engine/FlutterJNITest.java @@ -17,6 +17,7 @@ import io.flutter.embedding.engine.systemchannels.LocalizationChannel; import io.flutter.plugin.localization.LocalizationPlugin; import io.flutter.plugin.platform.PlatformViewsController; +import java.nio.ByteBuffer; import java.util.Locale; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Test; @@ -237,4 +238,11 @@ public void createOverlaySurface__callsPlatformViewsController() { // --- Verify Results --- verify(platformViewsController, times(1)).createOverlaySurface(); } + + @Test(expected = IllegalArgumentException.class) + public void invokePlatformMessageResponseCallback__wantsDirectBuffer() { + FlutterJNI flutterJNI = new FlutterJNI(); + ByteBuffer buffer = ByteBuffer.allocate(4); + flutterJNI.invokePlatformMessageResponseCallback(0, buffer, buffer.position()); + } }