diff --git a/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/Helpers.kt b/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/Helpers.kt new file mode 100644 index 0000000000..0f18c05cef --- /dev/null +++ b/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/Helpers.kt @@ -0,0 +1,301 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package mozilla.appservices.httpconfig + +// TODO: We'd like to be using the a-c log tooling here, but adding that +// dependency is slightly tricky (This also could run before its log sink +// is setup!). Since logging here very much helps debugging substitution +// issues, we just use logcat. + +import android.util.Log +import com.google.protobuf.CodedOutputStream +import com.google.protobuf.MessageLite +import com.sun.jna.DefaultTypeMapper +import com.sun.jna.FromNativeContext +import com.sun.jna.Library +import com.sun.jna.Memory +import com.sun.jna.Native +import com.sun.jna.Pointer +import com.sun.jna.ToNativeContext +import com.sun.jna.TypeConverter +import java.nio.ByteBuffer +import java.nio.ByteOrder + +/** + * A helper for converting a protobuf Message into a direct `java.nio.ByteBuffer` + * and its length. This avoids a copy when passing data to Rust, when compared + * to using an `Array` + */ + +fun T.toNioDirectBuffer(): Pair { + val len = this.serializedSize + val nioBuf = ByteBuffer.allocateDirect(len) + nioBuf.order(ByteOrder.nativeOrder()) + val output = CodedOutputStream.newInstance(nioBuf) + this.writeTo(output) + output.checkNoSpaceLeft() + return Pair(first = nioBuf, second = len) +} + +sealed class MegazordError : Exception { + /** + * The name of the component we were trying to initialize when we had the error. + */ + val componentName: String + + constructor(componentName: String, msg: String) : super(msg) { + this.componentName = componentName + } + + constructor(componentName: String, msg: String, cause: Throwable) : super(msg, cause) { + this.componentName = componentName + } +} + +class IncompatibleMegazordVersion( + componentName: String, + val componentVersion: String, + val megazordLibrary: String, + val megazordVersion: String?, +) : MegazordError( + componentName, + "Incompatible megazord version: library \"$componentName\" was compiled expecting " + + "app-services version \"$componentVersion\", but the megazord \"$megazordLibrary\" provides " + + "version \"${megazordVersion ?: "unknown"}\"", +) + +class MegazordNotInitialized(componentName: String) : MegazordError( + componentName, + "The application-services megazord has not yet been initialized, but is needed by \"$componentName\"", +) + +/** + * I think we'd expect this to be caused by the following two things both happening + * + * 1. Substitution not actually replacing the full megazord + * 2. Megazord initialization getting called after the first attempt to load something from the + * megazord, causing us to fall back to checking the full-megazord (and finding it, because + * of #1). + * + * It's very unlikely, but if it did happen it could be a memory safety error, so we check. + */ +class MultipleMegazordsPresent( + componentName: String, + val loadedMegazord: String, + val requestedMegazord: String, +) : MegazordError( + componentName, + "Multiple megazords are present, and bindings have already been loaded from " + + "\"$loadedMegazord\" when a request to load $componentName from $requestedMegazord " + + "is made. (This probably stems from an error in your build configuration)", +) + +internal const val FULL_MEGAZORD_LIBRARY: String = "megazord" + +internal fun lookupMegazordLibrary(componentName: String, componentVersion: String): String { + val mzLibrary = System.getProperty("mozilla.appservices.megazord.library") + Log.d("RustNativeSupport", "lib configured: ${mzLibrary ?: "none"}") + if (mzLibrary == null) { + // If it's null, then the megazord hasn't been initialized. + if (checkFullMegazord(componentName, componentVersion)) { + return FULL_MEGAZORD_LIBRARY + } + Log.e( + "RustNativeSupport", + "megazord not initialized, and default not present. failing to init $componentName", + ) + throw MegazordNotInitialized(componentName) + } + + // Assume it's properly initialized if it's been initialized at all + val mzVersion = System.getProperty("mozilla.appservices.megazord.version")!! + Log.d("RustNativeSupport", "lib version configured: $mzVersion") + + // We require exact equality, since we don't perform a major version + // bump if we change the ABI. In practice, this seems unlikely to + // cause problems, but we could come up with a scheme if this proves annoying. + if (componentVersion != mzVersion) { + Log.e( + "RustNativeSupport", + "version requested by component doesn't match initialized " + + "megazord version ($componentVersion != $mzVersion)", + ) + throw IncompatibleMegazordVersion(componentName, componentVersion, mzLibrary, mzVersion) + } + return mzLibrary +} + +/** + * Determine the megazord library name, and check that its version is + * compatible with the version of our bindings. Returns the megazord + * library name. + * + * Note: This is only public because it's called by an inline function. + * It should not be called by consumers. + */ +@Synchronized +fun findMegazordLibraryName(componentName: String, componentVersion: String): String { + Log.d("RustNativeSupport", "findMegazordLibraryName($componentName, $componentVersion") + val mzLibraryUsed = System.getProperty("mozilla.appservices.megazord.library.used") + Log.d("RustNativeSupport", "lib in use: ${mzLibraryUsed ?: "none"}") + val mzLibraryDetermined = lookupMegazordLibrary(componentName, componentVersion) + Log.d("RustNativeSupport", "settled on $mzLibraryDetermined") + + // If we've already initialized the megazord, that means we've probably already loaded bindings + // from it somewhere. It would be a big problem for us to use some bindings from one lib and + // some from another, so we just fail. + if (mzLibraryUsed != null && mzLibraryDetermined != mzLibraryUsed) { + Log.e( + "RustNativeSupport", + "Different than first time through ($mzLibraryDetermined != $mzLibraryUsed)!", + ) + throw MultipleMegazordsPresent(componentName, mzLibraryUsed, mzLibraryDetermined) + } + + // Mark that we're about to load bindings from the specified lib. Note that we don't do this + // in the case that the megazord check threw. + if (mzLibraryUsed != null) { + Log.d("RustNativeSupport", "setting first time through: $mzLibraryDetermined") + System.setProperty("mozilla.appservices.megazord.library.used", mzLibraryDetermined) + } + return mzLibraryDetermined +} + +/** + * Contains all the boilerplate for loading a library binding from the megazord, + * locating it if necessary, safety-checking versions, and setting up a fallback + * if loading fails. + * + * Indirect as in, we aren't using JNA direct mapping. Eventually we'd + * like to (it's faster), but that's a problem for another day. + */ +inline fun loadIndirect( + componentName: String, + componentVersion: String, +): Lib { + val mzLibrary = findMegazordLibraryName(componentName, componentVersion) + // Rust code always expects strings to be UTF-8 encoded. + // Unfortunately, the `STRING_ENCODING` option doesn't seem to apply + // to function arguments, only to return values, so, if the default encoding + // is not UTF-8, we need to use an explicit TypeMapper to ensure that + // strings are handled correctly. + // Further, see also https://github.com/mozilla/uniffi-rs/issues/1044 - if + // this code and our uniffi-generated code don't agree on the options, it's + // possible our megazord gets loaded twice, breaking things in + // creative/obscure ways. + // We used to unconditionally set `options[Library.OPTION_STRING_ENCODING] = "UTF-8"` + // but we now only do it if we really need to. This means in practice, both + // us and uniffi agree everywhere we care about. + // Assuming uniffi fixes this in the same way we've done it here, there should be + // no need to adjust anything once that issue is fixed. + val options: MutableMap = mutableMapOf() + if (Native.getDefaultStringEncoding() != "UTF-8") { + options[Library.OPTION_STRING_ENCODING] = "UTF-8" + options[Library.OPTION_TYPE_MAPPER] = UTF8TypeMapper() + } + return Native.load(mzLibrary, Lib::class.java, options) +} + +// See the comment on full_megazord_get_version for background +// on why this exists and what we use it for. +@Suppress("FunctionNaming", "ktlint:standard:function-naming") +internal interface LibMegazordFfi : Library { + // Note: Rust doesn't want us to free this string (because + // it's a pain for us to arrange here), so it is actually + // correct for us to return a String over the FFI for this. + fun full_megazord_get_version(): String? +} + +/** + * Try and load the full megazord library, call the function for getting its + * version, and check it against componentVersion. + * + * - If the megazord does not exist, returns false + * - If the megazord exists and the version is valid, returns true. + * - If the megazord exists and the version is invalid, throws a IncompatibleMegazordVersion error. + * (This is done here instead of returning false so that we can provide better info in the error) + */ +internal fun checkFullMegazord(componentName: String, componentVersion: String): Boolean { + return try { + Log.d( + "RustNativeSupport", + "No lib configured, trying full megazord", + ) + // It's not ideal to do this every time, but it should be rare, not too costly, + // and the workaround for the app is simple (just init the megazord). + val lib = Native.load(FULL_MEGAZORD_LIBRARY, LibMegazordFfi::class.java) + + val version = lib.full_megazord_get_version() + + Log.d( + "RustNativeSupport", + "found full megazord, it self-reports version as: ${version ?: "unknown"}", + ) + if (version == null) { + throw IncompatibleMegazordVersion( + componentName, + componentVersion, + FULL_MEGAZORD_LIBRARY, + null, + ) + } + + if (version != componentVersion) { + Log.e( + "RustNativeSupport", + "found default megazord, but versions don't match ($version != $componentVersion)", + ) + throw IncompatibleMegazordVersion( + componentName, + componentVersion, + FULL_MEGAZORD_LIBRARY, + version, + ) + } + + true + } catch (e: UnsatisfiedLinkError) { + Log.e("RustNativeSupport", "Default megazord not found: ${e.localizedMessage}") + if (componentVersion.startsWith("0.0.1-SNAPSHOT")) { + Log.i("RustNativeSupport", "It looks like you're using a local development build.") + Log.i("RustNativeSupport", "You may need to check that `rust.targets` contains the appropriate platforms.") + } + false + } +} + +/** + * A JNA TypeMapper that always converts strings as UTF-8 bytes. + * + * Rust always expects strings to be in UTF-8, but JNA defaults to using the + * system encoding. This is *often* UTF-8, but not always. In cases where it + * isn't you can use this TypeMapper to ensure Strings are correctly + * interpreted by Rust. + * + * The logic here is essentially the same as what JNA does by default + * with String values, but explicitly using a fixed UTF-8 encoding. + * + */ +public class UTF8TypeMapper : DefaultTypeMapper() { + init { + addTypeConverter(String::class.java, UTF8TypeConverter()) + } +} + +internal class UTF8TypeConverter : TypeConverter { + override fun toNative(value: Any, context: ToNativeContext): Any { + val bytes = (value as String).toByteArray(Charsets.UTF_8) + val mem = Memory(bytes.size.toLong() + 1L) + mem.write(0, bytes, 0, bytes.size) + mem.setByte(bytes.size.toLong(), 0) + return mem + } + + override fun fromNative(value: Any, context: FromNativeContext): Any { + return (value as Pointer).getString(0, Charsets.UTF_8.name()) + } + + override fun nativeType() = Pointer::class.java +} diff --git a/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/HttpConfig.kt b/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/HttpConfig.kt index 368e21a7e1..0b073c04ed 100644 --- a/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/HttpConfig.kt +++ b/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/HttpConfig.kt @@ -10,8 +10,11 @@ import mozilla.components.concept.fetch.Client import mozilla.components.concept.fetch.MutableHeaders import mozilla.components.concept.fetch.Request import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantReadWriteLock import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.concurrent.read +import kotlin.concurrent.write import mozilla.appservices.viaduct.allowAndroidEmulatorLoopback as rustAllowAndroidEmulatorLoopback /** @@ -31,9 +34,22 @@ class UnsupportedRequestMethodError(method: String) : */ @OptIn(ExperimentalAtomicApi::class) object RustHttpConfig { - // Used to only initialize the client once - // https://bugzilla.mozilla.org/show_bug.cgi?id=1989865. - private var backendInitialized = AtomicBoolean(false) + // Protects imp/client + private var lock = ReentrantReadWriteLock() + + // Add back once we move to the new backend again + // // Used to only initialize the client once + // // https://bugzilla.mozilla.org/show_bug.cgi?id=1989865. + // private var backendInitialized = AtomicBoolean(false) + + @Volatile + private var client: Lazy? = null + + // Important note to future maintainers: if you mess around with + // this code, you have to make sure `imp` can't get GCed. Extremely + // bad things will happen if it does! + @Volatile + private var imp: CallbackImpl? = null /** * Set the HTTP client to be used by all Rust code. @@ -41,8 +57,12 @@ object RustHttpConfig { */ @Synchronized fun setClient(c: Lazy) { - if (backendInitialized.compareAndSet(false, true)) { - initBackend(FetchBackend(c)) + lock.write { + client = c + if (imp == null) { + imp = CallbackImpl() + LibViaduct.INSTANCE.viaduct_initialize(imp!!) + } } } @@ -53,7 +73,9 @@ object RustHttpConfig { * are sure you are running on an emulator. */ fun allowAndroidEmulatorLoopback() { - rustAllowAndroidEmulatorLoopback() + lock.read { + LibViaduct.INSTANCE.viaduct_allow_android_emulator_loopback() + } } internal fun convertRequest(request: MsgTypes.Request): Request { @@ -81,6 +103,53 @@ object RustHttpConfig { useCaches = request.useCaches, ) } + + @Suppress("TooGenericExceptionCaught", "ReturnCount") + internal fun doFetch(b: RustBuffer.ByValue): RustBuffer.ByValue { + lock.read { + try { + val request = MsgTypes.Request.parseFrom(b.asCodedInputStream()) + val rb = try { + // Note: `client!!` is fine here, since if client is null, + // we wouldn't have yet initialized + val resp = client!!.value.fetch(convertRequest(request)) + val rb = MsgTypes.Response.newBuilder() + .setUrl(resp.url) + .setStatus(resp.status) + .setBody( + resp.body.useStream { + ByteString.readFrom(it) + }, + ) + + for (h in resp.headers) { + rb.putHeaders(h.name, h.value) + } + rb + } catch (e: Throwable) { + MsgTypes.Response.newBuilder() + .setExceptionMessage("fetch error: ${e.message ?: e.javaClass.canonicalName}") + } + val built = rb.build() + val needed = built.serializedSize + val outputBuf = LibViaduct.INSTANCE.viaduct_alloc_bytebuffer(needed) + try { + // This is only null if we passed a negative number or something to + // viaduct_alloc_bytebuffer. + val stream = outputBuf.asCodedOutputStream()!! + built.writeTo(stream) + return outputBuf + } catch (e: Throwable) { + // Note: we want to clean this up only if we are not returning it to rust. + LibViaduct.INSTANCE.viaduct_destroy_bytebuffer(outputBuf) + LibViaduct.INSTANCE.viaduct_log_error("Failed to write buffer: ${e.message}") + throw e + } + } finally { + LibViaduct.INSTANCE.viaduct_destroy_bytebuffer(b) + } + } + } } internal fun convertMethod(m: MsgTypes.Request.Method): Request.Method { @@ -96,3 +165,17 @@ internal fun convertMethod(m: MsgTypes.Request.Method): Request.Method { else -> throw UnsupportedRequestMethodError(m.toString()) } } + +internal class CallbackImpl : RawFetchCallback { + @Suppress("TooGenericExceptionCaught") + override fun invoke(b: RustBuffer.ByValue): RustBuffer.ByValue { + try { + return RustHttpConfig.doFetch(b) + } catch (e: Throwable) { + LibViaduct.INSTANCE.viaduct_log_error("doFetch failed: ${e.message}") + // This is our last resort. It's bad news should we fail to + // return something from this function. + return RustBuffer.ByValue() + } + } +} diff --git a/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/LibViaduct.kt b/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/LibViaduct.kt new file mode 100644 index 0000000000..8cf4b9ad58 --- /dev/null +++ b/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/LibViaduct.kt @@ -0,0 +1,40 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +@file:Suppress("ktlint:standard:function-naming", "ktlint:standard:property-naming") + +package mozilla.appservices.httpconfig + +import com.sun.jna.Callback +import com.sun.jna.Library +import org.mozilla.appservices.httpconfig.BuildConfig + +@Suppress("FunctionNaming", "TooGenericExceptionThrown") +internal interface LibViaduct : Library { + companion object { + internal var INSTANCE: LibViaduct = { + loadIndirect( + componentName = "viaduct", + componentVersion = BuildConfig.LIBRARY_VERSION, + ) + }() + } + + fun viaduct_destroy_bytebuffer(b: RustBuffer.ByValue) + + // Returns null buffer to indicate failure + fun viaduct_alloc_bytebuffer(sz: Int): RustBuffer.ByValue + + // Returns 0 to indicate redundant init. + fun viaduct_initialize(cb: RawFetchCallback): Byte + + // No return value, never fails. + fun viaduct_allow_android_emulator_loopback() + + fun viaduct_log_error(s: String) +} + +internal interface RawFetchCallback : Callback { + fun invoke(b: RustBuffer.ByValue): RustBuffer.ByValue +} diff --git a/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/RustBuffer.kt b/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/RustBuffer.kt new file mode 100644 index 0000000000..63301b41b5 --- /dev/null +++ b/components/viaduct/android/src/main/java/mozilla/appservices/httpconfig/RustBuffer.kt @@ -0,0 +1,75 @@ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +package mozilla.appservices.httpconfig + +import com.google.protobuf.CodedInputStream +import com.google.protobuf.CodedOutputStream +import com.sun.jna.Pointer +import com.sun.jna.Structure +import java.nio.ByteBuffer + +/** + * This is a mapping for the `ffi_support::ByteBuffer` struct. + * + * The name differs for two reasons. + * + * 1. To that the memory this type manages is allocated from rust code, + * and must subsequently be freed by rust code. + * + * 2. To avoid confusion with java's nio ByteBuffer, which we use for + * passing data *to* Rust without incurring additional copies. + * + * # Caveats: + * + * 1. It is for receiving data *FROM* Rust, and not the other direction. + * RustBuffer doesn't expose a way to inspect its contents from Rust. + * See `docs/howtos/passing-protobuf-data-over-ffi.md` for how to do + * this instead. + * + * 2. A `RustBuffer` passed into kotlin code must be freed by kotlin + * code *after* the protobuf message is completely deserialized. + * + * The rust code must expose a destructor for this purpose, + * and it should be called in the finally block after the data + * is read from the `CodedInputStream` (and not before). + * + * 3. You almost always should use `RustBuffer.ByValue` instead + * of `RustBuffer`. E.g. + * `fun mylib_get_stuff(some: X, args: Y): RustBuffer.ByValue` + * for the function returning the RustBuffer, and + * `fun mylib_destroy_bytebuffer(bb: RustBuffer.ByValue)`. + */ +@Structure.FieldOrder("len", "data") +open class RustBuffer : Structure() { + @JvmField var len: Long = 0 + + @JvmField var data: Pointer? = null + + @Suppress("TooGenericExceptionThrown") + fun asCodedInputStream(): CodedInputStream? { + return this.data?.let { + // We use a ByteArray instead of a ByteBuffer to avoid triggering the following code path: + // https://github.com/protocolbuffers/protobuf/blob/e667bf6eaaa2fb1ba2987c6538df81f88500d030/java/core/src/main/java/com/google/protobuf/CodedInputStream.java#L185-L187 + // Bug: https://github.com/protocolbuffers/protobuf/issues/7422 + if (this.len < Int.MIN_VALUE || this.len > Int.MAX_VALUE) { + throw RuntimeException("len does not fit in a int") + } + CodedInputStream.newInstance(it.getByteArray(0, this.len.toInt())) + } + } + + fun asCodedOutputStream(): CodedOutputStream? { + return this.data?.let { + // We use newSafeInstance through reflection to avoid triggering the following code path: + // https://github.com/protocolbuffers/protobuf/blob/e667bf6eaaa2fb1ba2987c6538df81f88500d030/java/core/src/main/java/com/google/protobuf/CodedOutputStream.java#L134-L136 + // Bug: https://github.com/protocolbuffers/protobuf/issues/7422 + val method = CodedOutputStream::class.java.getDeclaredMethod("newSafeInstance", ByteBuffer::class.java) + method.isAccessible = true + return method.invoke(null, it.getByteBuffer(0, this.len)) as CodedOutputStream + } + } + + class ByValue : RustBuffer(), Structure.ByValue +}