diff --git a/src/main/java/io/appium/java_client/AppiumDriver.java b/src/main/java/io/appium/java_client/AppiumDriver.java index 1e880c569..5c60ada7f 100644 --- a/src/main/java/io/appium/java_client/AppiumDriver.java +++ b/src/main/java/io/appium/java_client/AppiumDriver.java @@ -26,6 +26,7 @@ import io.appium.java_client.internal.CapabilityHelpers; import io.appium.java_client.internal.JsonToMobileElementConverter; import io.appium.java_client.remote.AppiumCommandExecutor; +import io.appium.java_client.remote.AppiumNewSessionCommandPayload; import io.appium.java_client.remote.MobileCapabilityType; import io.appium.java_client.service.local.AppiumDriverLocalService; import io.appium.java_client.service.local.AppiumServiceBuilder; @@ -34,6 +35,7 @@ import org.openqa.selenium.DeviceRotation; import org.openqa.selenium.MutableCapabilities; import org.openqa.selenium.ScreenOrientation; +import org.openqa.selenium.SessionNotCreatedException; import org.openqa.selenium.WebDriver; import org.openqa.selenium.WebDriverException; import org.openqa.selenium.WebElement; @@ -44,11 +46,13 @@ import org.openqa.selenium.remote.ErrorHandler; import org.openqa.selenium.remote.ExecuteMethod; import org.openqa.selenium.remote.HttpCommandExecutor; +import org.openqa.selenium.remote.RemoteWebDriver; import org.openqa.selenium.remote.Response; import org.openqa.selenium.remote.html5.RemoteLocationContext; import org.openqa.selenium.remote.http.HttpClient; import org.openqa.selenium.remote.http.HttpMethod; +import java.lang.reflect.Field; import java.net.URL; import java.util.Arrays; import java.util.LinkedHashSet; @@ -299,14 +303,33 @@ public boolean isBrowser() { @Override protected void startSession(Capabilities capabilities) { - super.startSession(capabilities); - // The RemoteWebDriver implementation overrides platformName - // so we need to restore it back to the original value - Object originalPlatformName = capabilities.getCapability(PLATFORM_NAME); - Capabilities originalCaps = super.getCapabilities(); - if (originalPlatformName != null && originalCaps instanceof MutableCapabilities) { - ((MutableCapabilities) super.getCapabilities()).setCapability(PLATFORM_NAME, - originalPlatformName); + Response response = execute(new AppiumNewSessionCommandPayload(capabilities)); + if (response == null) { + throw new SessionNotCreatedException( + "The underlying command executor returned a null response."); } + + Object responseValue = response.getValue(); + if (responseValue == null) { + throw new SessionNotCreatedException( + "The underlying command executor returned a response without payload: " + + response); + } + if (!(responseValue instanceof Map)) { + throw new SessionNotCreatedException( + "The underlying command executor returned a response with a non well formed payload: " + + response); + } + + @SuppressWarnings("unchecked") Map rawCapabilities = (Map) responseValue; + MutableCapabilities returnedCapabilities = new MutableCapabilities(rawCapabilities); + try { + Field capsField = RemoteWebDriver.class.getDeclaredField("capabilities"); + capsField.setAccessible(true); + capsField.set(this, returnedCapabilities); + } catch (NoSuchFieldException | IllegalAccessException e) { + throw new WebDriverException(e); + } + setSessionId(response.getSessionId()); } } diff --git a/src/main/java/io/appium/java_client/remote/AppiumCommandExecutor.java b/src/main/java/io/appium/java_client/remote/AppiumCommandExecutor.java index e98d1a2d9..0371748fe 100644 --- a/src/main/java/io/appium/java_client/remote/AppiumCommandExecutor.java +++ b/src/main/java/io/appium/java_client/remote/AppiumCommandExecutor.java @@ -145,7 +145,7 @@ private Response createSession(Command command) throws IOException { throw new SessionNotCreatedException("Session already exists"); } - ProtocolHandshake.Result result = new ProtocolHandshake().createSession( + ProtocolHandshake.Result result = new AppiumProtocolHandshake().createSession( getClient().with((httpHandler) -> (req) -> { req.setHeader(IDEMPOTENCY_KEY_HEADER, UUID.randomUUID().toString().toLowerCase()); return httpHandler.execute(req); diff --git a/src/main/java/io/appium/java_client/remote/AppiumNewSessionCommandPayload.java b/src/main/java/io/appium/java_client/remote/AppiumNewSessionCommandPayload.java new file mode 100644 index 000000000..fd81d801c --- /dev/null +++ b/src/main/java/io/appium/java_client/remote/AppiumNewSessionCommandPayload.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.appium.java_client.remote; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import org.openqa.selenium.Capabilities; +import org.openqa.selenium.internal.Require; +import org.openqa.selenium.remote.AcceptedW3CCapabilityKeys; +import org.openqa.selenium.remote.CommandPayload; + +import java.util.AbstractMap; +import java.util.Map; + +import static io.appium.java_client.internal.CapabilityHelpers.APPIUM_PREFIX; +import static org.openqa.selenium.remote.DriverCommand.NEW_SESSION; + +public class AppiumNewSessionCommandPayload extends CommandPayload { + private static final AcceptedW3CCapabilityKeys ACCEPTED_W3C_PATTERNS = new AcceptedW3CCapabilityKeys(); + + /** + * Appends "appium:" prefix to all non-prefixed non-standard capabilities. + * + * @param possiblyInvalidCapabilities user-provided capabilities mapping. + * @return Fixed capabilities mapping. + */ + private static Map makeW3CSafe(Capabilities possiblyInvalidCapabilities) { + Require.nonNull("Capabilities", possiblyInvalidCapabilities); + + return possiblyInvalidCapabilities.asMap().entrySet().stream() + .map((entry) -> ACCEPTED_W3C_PATTERNS.test(entry.getKey()) + ? entry + : new AbstractMap.SimpleEntry<>( + String.format("%s%s", APPIUM_PREFIX, entry.getKey()), entry.getValue())) + .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + /** + * Overrides the default new session behavior to + * only handle W3C capabilities. + * + * @param capabilities User-provided capabilities. + */ + public AppiumNewSessionCommandPayload(Capabilities capabilities) { + super(NEW_SESSION, ImmutableMap.of( + "capabilities", ImmutableSet.of(makeW3CSafe(capabilities)), + "desiredCapabilities", capabilities + )); + } +} diff --git a/src/main/java/io/appium/java_client/remote/AppiumProtocolHandshake.java b/src/main/java/io/appium/java_client/remote/AppiumProtocolHandshake.java new file mode 100644 index 000000000..98b128554 --- /dev/null +++ b/src/main/java/io/appium/java_client/remote/AppiumProtocolHandshake.java @@ -0,0 +1,136 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.appium.java_client.remote; + +import com.google.common.io.CountingOutputStream; +import com.google.common.io.FileBackedOutputStream; +import org.openqa.selenium.Capabilities; +import org.openqa.selenium.ImmutableCapabilities; +import org.openqa.selenium.SessionNotCreatedException; +import org.openqa.selenium.WebDriverException; +import org.openqa.selenium.internal.Either; +import org.openqa.selenium.json.Json; +import org.openqa.selenium.json.JsonOutput; +import org.openqa.selenium.remote.Command; +import org.openqa.selenium.remote.NewSessionPayload; +import org.openqa.selenium.remote.ProtocolHandshake; +import org.openqa.selenium.remote.http.HttpHandler; + +import java.io.BufferedInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.Map; +import java.util.Set; +import java.util.stream.Stream; + +import static java.nio.charset.StandardCharsets.UTF_8; + +@SuppressWarnings("UnstableApiUsage") +public class AppiumProtocolHandshake extends ProtocolHandshake { + private static void writeJsonPayload(NewSessionPayload srcPayload, Appendable destination) { + try (JsonOutput json = new Json().newOutput(destination)) { + json.beginObject(); + + json.name("capabilities"); + json.beginObject(); + + json.name("firstMatch"); + json.beginArray(); + json.beginObject(); + json.endObject(); + json.endArray(); + + json.name("alwaysMatch"); + try { + Method getW3CMethod = NewSessionPayload.class.getDeclaredMethod("getW3C"); + getW3CMethod.setAccessible(true); + //noinspection unchecked + ((Stream>) getW3CMethod.invoke(srcPayload)) + .findFirst() + .map(json::write) + .orElseGet(() -> { + json.beginObject(); + json.endObject(); + return null; + }); + } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { + throw new WebDriverException(e); + } + + json.endObject(); // Close "capabilities" object + + try { + Method writeMetaDataMethod = NewSessionPayload.class.getDeclaredMethod( + "writeMetaData", JsonOutput.class); + writeMetaDataMethod.setAccessible(true); + writeMetaDataMethod.invoke(srcPayload, json); + } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { + throw new WebDriverException(e); + } + + json.endObject(); + } + } + + @Override + public Result createSession(HttpHandler client, Command command) throws IOException { + //noinspection unchecked + Capabilities desired = ((Set>) command.getParameters().get("capabilities")) + .stream() + .findAny() + .map(ImmutableCapabilities::new) + .orElseGet(ImmutableCapabilities::new); + try (NewSessionPayload payload = NewSessionPayload.create(desired)) { + Either result = createSession(client, payload); + if (result.isRight()) { + return result.right(); + } + throw result.left(); + } + } + + @Override + public Either createSession( + HttpHandler client, NewSessionPayload payload) throws IOException { + int threshold = (int) Math.min(Runtime.getRuntime().freeMemory() / 10, Integer.MAX_VALUE); + FileBackedOutputStream os = new FileBackedOutputStream(threshold); + + try (CountingOutputStream counter = new CountingOutputStream(os); + Writer writer = new OutputStreamWriter(counter, UTF_8)) { + writeJsonPayload(payload, writer); + + try (InputStream rawIn = os.asByteSource().openBufferedStream(); + BufferedInputStream contentStream = new BufferedInputStream(rawIn)) { + Method createSessionMethod = ProtocolHandshake.class.getDeclaredMethod("createSession", + HttpHandler.class, InputStream.class, long.class); + createSessionMethod.setAccessible(true); + //noinspection unchecked + return (Either) createSessionMethod.invoke( + this, client, contentStream, counter.getCount() + ); + } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { + throw new WebDriverException(e); + } + } finally { + os.reset(); + } + } +}