diff --git a/spark/pom.xml b/spark/pom.xml index 822b989395..3b832e37a2 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -173,8 +173,8 @@ under the License. software.amazon.awssdk s3 - - + + @@ -188,6 +188,19 @@ under the License. 1.5.2 test + + + org.eclipse.jetty + jetty-server + 9.4.53.v20231009 + test + + + org.eclipse.jetty + jetty-servlet + 9.4.53.v20231009 + test + @@ -203,6 +216,19 @@ under the License. 1.8.1 test + + + org.eclipse.jetty + jetty-server + 9.4.53.v20231009 + test + + + org.eclipse.jetty + jetty-servlet + 9.4.53.v20231009 + test + @@ -215,6 +241,19 @@ under the License. 1.10.0 test + + + org.eclipse.jetty + jetty-server + 11.0.24 + test + + + org.eclipse.jetty + jetty-servlet + 11.0.24 + test + diff --git a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala index fc9fd8e6eb..2d772063e4 100644 --- a/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala +++ b/spark/src/main/scala/org/apache/comet/iceberg/IcebergReflection.scala @@ -204,9 +204,11 @@ object IcebergReflection extends Logging { case _: NoSuchMethodException => try { // If not directly available, access via operations/metadata - val opsMethod = table.getClass.getMethod("operations") + val opsMethod = table.getClass.getDeclaredMethod("operations") + opsMethod.setAccessible(true) val ops = opsMethod.invoke(table) - val currentMethod = ops.getClass.getMethod("current") + val currentMethod = ops.getClass.getDeclaredMethod("current") + currentMethod.setAccessible(true) val metadata = currentMethod.invoke(ops) val formatVersionMethod = metadata.getClass.getMethod("formatVersion") Some(formatVersionMethod.invoke(metadata).asInstanceOf[Int]) @@ -274,10 +276,12 @@ object IcebergReflection extends Logging { */ def getTableMetadata(table: Any): Option[Any] = { try { - val operationsMethod = table.getClass.getMethod("operations") + val operationsMethod = table.getClass.getDeclaredMethod("operations") + operationsMethod.setAccessible(true) val operations = operationsMethod.invoke(table) - val currentMethod = operations.getClass.getMethod("current") + val currentMethod = operations.getClass.getDeclaredMethod("current") + currentMethod.setAccessible(true) Some(currentMethod.invoke(operations)) } catch { case e: Exception => diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index 8d15223d00..69bce75559 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -284,22 +284,56 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com // Extract all Iceberg metadata once using reflection. // If any required reflection fails, this returns None, and we fall back to Spark. // First get metadataLocation and catalogProperties which are needed by the factory. - val metadataLocationOpt = IcebergReflection - .getTable(scanExec.scan) - .flatMap(IcebergReflection.getMetadataLocation) + val tableOpt = IcebergReflection.getTable(scanExec.scan) + + val metadataLocationOpt = tableOpt.flatMap { table => + IcebergReflection.getMetadataLocation(table) + } val metadataOpt = metadataLocationOpt.flatMap { metadataLocation => try { val session = org.apache.spark.sql.SparkSession.active val hadoopConf = session.sessionState.newHadoopConf() + + // For REST catalogs, the metadata file may not exist on disk since metadata + // is fetched via HTTP. Check if file exists; if not, use table location instead. val metadataUri = new java.net.URI(metadataLocation) - val hadoopS3Options = NativeConfig.extractObjectStoreOptions(hadoopConf, metadataUri) + + val metadataFile = new java.io.File(metadataUri.getPath) + + val effectiveLocation = + if (!metadataFile.exists() && metadataUri.getScheme == "file") { + // Metadata file doesn't exist (REST catalog with InMemoryFileIO or similar) + // Use table location instead for FileIO initialization + + tableOpt + .flatMap { table => + try { + val locationMethod = table.getClass.getMethod("location") + val tableLocation = locationMethod.invoke(table).asInstanceOf[String] + Some(tableLocation) + } catch { + case _: Exception => + Some(metadataLocation) + } + } + .getOrElse(metadataLocation) + } else { + metadataLocation + } + + val effectiveUri = new java.net.URI(effectiveLocation) + + val hadoopS3Options = NativeConfig.extractObjectStoreOptions(hadoopConf, effectiveUri) + val catalogProperties = org.apache.comet.serde.operator.CometIcebergNativeScan .hadoopToIcebergS3Properties(hadoopS3Options) - CometIcebergNativeScanMetadata - .extract(scanExec.scan, metadataLocation, catalogProperties) + val result = CometIcebergNativeScanMetadata + .extract(scanExec.scan, effectiveLocation, catalogProperties) + + result } catch { case e: Exception => logError( @@ -319,21 +353,18 @@ case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] with Com // Now perform all validation using the pre-extracted metadata // Check if table uses a FileIO implementation compatible with iceberg-rust + val fileIOCompatible = IcebergReflection.getFileIO(metadata.table) match { - case Some(fileIO) => - val fileIOClassName = fileIO.getClass.getName - if (fileIOClassName == "org.apache.iceberg.inmemory.InMemoryFileIO") { - fallbackReasons += "Comet does not support InMemoryFileIO table locations" - false - } else { - true - } + case Some(_) => + // InMemoryFileIO is now supported with table location fallback for REST catalogs + true case None => fallbackReasons += "Could not check FileIO compatibility" false } // Check Iceberg table format version + val formatVersionSupported = IcebergReflection.getFormatVersion(metadata.table) match { case Some(formatVersion) => if (formatVersion > 2) { diff --git a/spark/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java b/spark/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java new file mode 100644 index 0000000000..7d5d6ce6b2 --- /dev/null +++ b/spark/src/test/java/org/apache/iceberg/rest/RESTCatalogAdapter.java @@ -0,0 +1,655 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.rest; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; + +import org.apache.iceberg.BaseTable; +import org.apache.iceberg.BaseTransaction; +import org.apache.iceberg.Table; +import org.apache.iceberg.Transaction; +import org.apache.iceberg.Transactions; +import org.apache.iceberg.catalog.Catalog; +import org.apache.iceberg.catalog.Namespace; +import org.apache.iceberg.catalog.SupportsNamespaces; +import org.apache.iceberg.catalog.TableIdentifier; +import org.apache.iceberg.catalog.ViewCatalog; +import org.apache.iceberg.exceptions.AlreadyExistsException; +import org.apache.iceberg.exceptions.CommitFailedException; +import org.apache.iceberg.exceptions.CommitStateUnknownException; +import org.apache.iceberg.exceptions.ForbiddenException; +import org.apache.iceberg.exceptions.NamespaceNotEmptyException; +import org.apache.iceberg.exceptions.NoSuchIcebergTableException; +import org.apache.iceberg.exceptions.NoSuchNamespaceException; +import org.apache.iceberg.exceptions.NoSuchTableException; +import org.apache.iceberg.exceptions.NoSuchViewException; +import org.apache.iceberg.exceptions.NotAuthorizedException; +import org.apache.iceberg.exceptions.RESTException; +import org.apache.iceberg.exceptions.UnprocessableEntityException; +import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Splitter; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.rest.requests.CommitTransactionRequest; +import org.apache.iceberg.rest.requests.CreateNamespaceRequest; +import org.apache.iceberg.rest.requests.CreateTableRequest; +import org.apache.iceberg.rest.requests.CreateViewRequest; +import org.apache.iceberg.rest.requests.RegisterTableRequest; +import org.apache.iceberg.rest.requests.RenameTableRequest; +import org.apache.iceberg.rest.requests.ReportMetricsRequest; +import org.apache.iceberg.rest.requests.UpdateNamespacePropertiesRequest; +import org.apache.iceberg.rest.requests.UpdateTableRequest; +import org.apache.iceberg.rest.responses.ConfigResponse; +import org.apache.iceberg.rest.responses.CreateNamespaceResponse; +import org.apache.iceberg.rest.responses.ErrorResponse; +import org.apache.iceberg.rest.responses.GetNamespaceResponse; +import org.apache.iceberg.rest.responses.ListNamespacesResponse; +import org.apache.iceberg.rest.responses.ListTablesResponse; +import org.apache.iceberg.rest.responses.LoadTableResponse; +import org.apache.iceberg.rest.responses.LoadViewResponse; +import org.apache.iceberg.rest.responses.OAuthTokenResponse; +import org.apache.iceberg.rest.responses.UpdateNamespacePropertiesResponse; +import org.apache.iceberg.util.Pair; +import org.apache.iceberg.util.PropertyUtil; + +/** Adaptor class to translate REST requests into {@link Catalog} API calls. */ +public class RESTCatalogAdapter implements RESTClient { + private static final Splitter SLASH = Splitter.on('/'); + + private static final Map, Integer> EXCEPTION_ERROR_CODES = + ImmutableMap., Integer>builder() + .put(IllegalArgumentException.class, 400) + .put(ValidationException.class, 400) + .put(NamespaceNotEmptyException.class, 400) // TODO: should this be more specific? + .put(NotAuthorizedException.class, 401) + .put(ForbiddenException.class, 403) + .put(NoSuchNamespaceException.class, 404) + .put(NoSuchTableException.class, 404) + .put(NoSuchViewException.class, 404) + .put(NoSuchIcebergTableException.class, 404) + .put(UnsupportedOperationException.class, 406) + .put(AlreadyExistsException.class, 409) + .put(CommitFailedException.class, 409) + .put(UnprocessableEntityException.class, 422) + .put(CommitStateUnknownException.class, 500) + .buildOrThrow(); + + private final Catalog catalog; + private final SupportsNamespaces asNamespaceCatalog; + private final ViewCatalog asViewCatalog; + + public RESTCatalogAdapter(Catalog catalog) { + this.catalog = catalog; + this.asNamespaceCatalog = + catalog instanceof SupportsNamespaces ? (SupportsNamespaces) catalog : null; + this.asViewCatalog = catalog instanceof ViewCatalog ? (ViewCatalog) catalog : null; + } + + enum HTTPMethod { + GET, + HEAD, + POST, + DELETE + } + + enum Route { + TOKENS(HTTPMethod.POST, "v1/oauth/tokens", null, OAuthTokenResponse.class), + SEPARATE_AUTH_TOKENS_URI( + HTTPMethod.POST, "https://auth-server.com/token", null, OAuthTokenResponse.class), + CONFIG(HTTPMethod.GET, "v1/config", null, ConfigResponse.class), + LIST_NAMESPACES(HTTPMethod.GET, "v1/namespaces", null, ListNamespacesResponse.class), + CREATE_NAMESPACE( + HTTPMethod.POST, + "v1/namespaces", + CreateNamespaceRequest.class, + CreateNamespaceResponse.class), + LOAD_NAMESPACE(HTTPMethod.GET, "v1/namespaces/{namespace}", null, GetNamespaceResponse.class), + DROP_NAMESPACE(HTTPMethod.DELETE, "v1/namespaces/{namespace}"), + UPDATE_NAMESPACE( + HTTPMethod.POST, + "v1/namespaces/{namespace}/properties", + UpdateNamespacePropertiesRequest.class, + UpdateNamespacePropertiesResponse.class), + LIST_TABLES(HTTPMethod.GET, "v1/namespaces/{namespace}/tables", null, ListTablesResponse.class), + CREATE_TABLE( + HTTPMethod.POST, + "v1/namespaces/{namespace}/tables", + CreateTableRequest.class, + LoadTableResponse.class), + LOAD_TABLE( + HTTPMethod.GET, "v1/namespaces/{namespace}/tables/{name}", null, LoadTableResponse.class), + REGISTER_TABLE( + HTTPMethod.POST, + "v1/namespaces/{namespace}/register", + RegisterTableRequest.class, + LoadTableResponse.class), + UPDATE_TABLE( + HTTPMethod.POST, + "v1/namespaces/{namespace}/tables/{name}", + UpdateTableRequest.class, + LoadTableResponse.class), + DROP_TABLE(HTTPMethod.DELETE, "v1/namespaces/{namespace}/tables/{name}"), + RENAME_TABLE(HTTPMethod.POST, "v1/tables/rename", RenameTableRequest.class, null), + REPORT_METRICS( + HTTPMethod.POST, + "v1/namespaces/{namespace}/tables/{name}/metrics", + ReportMetricsRequest.class, + null), + COMMIT_TRANSACTION( + HTTPMethod.POST, "v1/transactions/commit", CommitTransactionRequest.class, null), + LIST_VIEWS(HTTPMethod.GET, "v1/namespaces/{namespace}/views", null, ListTablesResponse.class), + LOAD_VIEW( + HTTPMethod.GET, "v1/namespaces/{namespace}/views/{name}", null, LoadViewResponse.class), + CREATE_VIEW( + HTTPMethod.POST, + "v1/namespaces/{namespace}/views", + CreateViewRequest.class, + LoadViewResponse.class), + UPDATE_VIEW( + HTTPMethod.POST, + "v1/namespaces/{namespace}/views/{name}", + UpdateTableRequest.class, + LoadViewResponse.class), + RENAME_VIEW(HTTPMethod.POST, "v1/views/rename", RenameTableRequest.class, null), + DROP_VIEW(HTTPMethod.DELETE, "v1/namespaces/{namespace}/views/{name}"); + + private final HTTPMethod method; + private final int requiredLength; + private final Map requirements; + private final Map variables; + private final Class requestClass; + private final Class responseClass; + + Route(HTTPMethod method, String pattern) { + this(method, pattern, null, null); + } + + Route( + HTTPMethod method, + String pattern, + Class requestClass, + Class responseClass) { + this.method = method; + + // parse the pattern into requirements and variables + List parts = SLASH.splitToList(pattern); + ImmutableMap.Builder requirementsBuilder = ImmutableMap.builder(); + ImmutableMap.Builder variablesBuilder = ImmutableMap.builder(); + for (int pos = 0; pos < parts.size(); pos += 1) { + String part = parts.get(pos); + if (part.startsWith("{") && part.endsWith("}")) { + variablesBuilder.put(pos, part.substring(1, part.length() - 1)); + } else { + requirementsBuilder.put(pos, part); + } + } + + this.requestClass = requestClass; + this.responseClass = responseClass; + + this.requiredLength = parts.size(); + this.requirements = requirementsBuilder.build(); + this.variables = variablesBuilder.build(); + } + + private boolean matches(HTTPMethod requestMethod, List requestPath) { + return method == requestMethod + && requiredLength == requestPath.size() + && requirements.entrySet().stream() + .allMatch( + requirement -> + requirement + .getValue() + .equalsIgnoreCase(requestPath.get(requirement.getKey()))); + } + + private Map variables(List requestPath) { + ImmutableMap.Builder vars = ImmutableMap.builder(); + variables.forEach((key, value) -> vars.put(value, requestPath.get(key))); + return vars.build(); + } + + public static Pair> from(HTTPMethod method, String path) { + List parts = SLASH.splitToList(path); + for (Route candidate : Route.values()) { + if (candidate.matches(method, parts)) { + return Pair.of(candidate, candidate.variables(parts)); + } + } + + return null; + } + + public Class requestClass() { + return requestClass; + } + + public Class responseClass() { + return responseClass; + } + } + + private static OAuthTokenResponse handleOAuthRequest(Object body) { + Map request = (Map) castRequest(Map.class, body); + String grantType = request.get("grant_type"); + switch (grantType) { + case "client_credentials": + return OAuthTokenResponse.builder() + .withToken("client-credentials-token:sub=" + request.get("client_id")) + .withTokenType("Bearer") + .build(); + + case "urn:ietf:params:oauth:grant-type:token-exchange": + String actor = request.get("actor_token"); + String token = + String.format( + "token-exchange-token:sub=%s%s", + request.get("subject_token"), actor != null ? ",act=" + actor : ""); + return OAuthTokenResponse.builder() + .withToken(token) + .withIssuedTokenType("urn:ietf:params:oauth:token-type:access_token") + .withTokenType("Bearer") + .build(); + + default: + throw new UnsupportedOperationException("Unsupported grant_type: " + grantType); + } + } + + @SuppressWarnings({"MethodLength", "checkstyle:CyclomaticComplexity"}) + public T handleRequest( + Route route, Map vars, Object body, Class responseType) { + switch (route) { + case TOKENS: + return castResponse(responseType, handleOAuthRequest(body)); + + case CONFIG: + return castResponse(responseType, ConfigResponse.builder().build()); + + case LIST_NAMESPACES: + if (asNamespaceCatalog != null) { + Namespace ns; + if (vars.containsKey("parent")) { + ns = + Namespace.of( + RESTUtil.NAMESPACE_SPLITTER + .splitToStream(vars.get("parent")) + .toArray(String[]::new)); + } else { + ns = Namespace.empty(); + } + + return castResponse(responseType, CatalogHandlers.listNamespaces(asNamespaceCatalog, ns)); + } + break; + + case CREATE_NAMESPACE: + if (asNamespaceCatalog != null) { + CreateNamespaceRequest request = castRequest(CreateNamespaceRequest.class, body); + return castResponse( + responseType, CatalogHandlers.createNamespace(asNamespaceCatalog, request)); + } + break; + + case LOAD_NAMESPACE: + if (asNamespaceCatalog != null) { + Namespace namespace = namespaceFromPathVars(vars); + return castResponse( + responseType, CatalogHandlers.loadNamespace(asNamespaceCatalog, namespace)); + } + break; + + case DROP_NAMESPACE: + if (asNamespaceCatalog != null) { + CatalogHandlers.dropNamespace(asNamespaceCatalog, namespaceFromPathVars(vars)); + return null; + } + break; + + case UPDATE_NAMESPACE: + if (asNamespaceCatalog != null) { + Namespace namespace = namespaceFromPathVars(vars); + UpdateNamespacePropertiesRequest request = + castRequest(UpdateNamespacePropertiesRequest.class, body); + return castResponse( + responseType, + CatalogHandlers.updateNamespaceProperties(asNamespaceCatalog, namespace, request)); + } + break; + + case LIST_TABLES: + { + Namespace namespace = namespaceFromPathVars(vars); + return castResponse(responseType, CatalogHandlers.listTables(catalog, namespace)); + } + + case CREATE_TABLE: + { + Namespace namespace = namespaceFromPathVars(vars); + CreateTableRequest request = castRequest(CreateTableRequest.class, body); + request.validate(); + if (request.stageCreate()) { + return castResponse( + responseType, CatalogHandlers.stageTableCreate(catalog, namespace, request)); + } else { + return castResponse( + responseType, CatalogHandlers.createTable(catalog, namespace, request)); + } + } + + case DROP_TABLE: + { + if (PropertyUtil.propertyAsBoolean(vars, "purgeRequested", false)) { + CatalogHandlers.purgeTable(catalog, identFromPathVars(vars)); + } else { + CatalogHandlers.dropTable(catalog, identFromPathVars(vars)); + } + return null; + } + + case LOAD_TABLE: + { + TableIdentifier ident = identFromPathVars(vars); + return castResponse(responseType, CatalogHandlers.loadTable(catalog, ident)); + } + + case REGISTER_TABLE: + { + Namespace namespace = namespaceFromPathVars(vars); + RegisterTableRequest request = castRequest(RegisterTableRequest.class, body); + return castResponse( + responseType, CatalogHandlers.registerTable(catalog, namespace, request)); + } + + case UPDATE_TABLE: + { + TableIdentifier ident = identFromPathVars(vars); + UpdateTableRequest request = castRequest(UpdateTableRequest.class, body); + return castResponse(responseType, CatalogHandlers.updateTable(catalog, ident, request)); + } + + case RENAME_TABLE: + { + RenameTableRequest request = castRequest(RenameTableRequest.class, body); + CatalogHandlers.renameTable(catalog, request); + return null; + } + + case REPORT_METRICS: + { + // nothing to do here other than checking that we're getting the correct request + castRequest(ReportMetricsRequest.class, body); + return null; + } + + case COMMIT_TRANSACTION: + { + CommitTransactionRequest request = castRequest(CommitTransactionRequest.class, body); + commitTransaction(catalog, request); + return null; + } + + case LIST_VIEWS: + { + if (null != asViewCatalog) { + Namespace namespace = namespaceFromPathVars(vars); + return castResponse(responseType, CatalogHandlers.listViews(asViewCatalog, namespace)); + } + break; + } + + case CREATE_VIEW: + { + if (null != asViewCatalog) { + Namespace namespace = namespaceFromPathVars(vars); + CreateViewRequest request = castRequest(CreateViewRequest.class, body); + return castResponse( + responseType, CatalogHandlers.createView(asViewCatalog, namespace, request)); + } + break; + } + + case LOAD_VIEW: + { + if (null != asViewCatalog) { + TableIdentifier ident = identFromPathVars(vars); + return castResponse(responseType, CatalogHandlers.loadView(asViewCatalog, ident)); + } + break; + } + + case UPDATE_VIEW: + { + if (null != asViewCatalog) { + TableIdentifier ident = identFromPathVars(vars); + UpdateTableRequest request = castRequest(UpdateTableRequest.class, body); + return castResponse( + responseType, CatalogHandlers.updateView(asViewCatalog, ident, request)); + } + break; + } + + case RENAME_VIEW: + { + if (null != asViewCatalog) { + RenameTableRequest request = castRequest(RenameTableRequest.class, body); + CatalogHandlers.renameView(asViewCatalog, request); + return null; + } + break; + } + + case DROP_VIEW: + { + if (null != asViewCatalog) { + CatalogHandlers.dropView(asViewCatalog, identFromPathVars(vars)); + return null; + } + break; + } + + default: + if (responseType == OAuthTokenResponse.class) { + return castResponse(responseType, handleOAuthRequest(body)); + } + } + + return null; + } + + /** + * This is a very simplistic approach that only validates the requirements for each table and does + * not do any other conflict detection. Therefore, it does not guarantee true transactional + * atomicity, which is left to the implementation details of a REST server. + */ + private static void commitTransaction(Catalog catalog, CommitTransactionRequest request) { + List transactions = Lists.newArrayList(); + + for (UpdateTableRequest tableChange : request.tableChanges()) { + Table table = catalog.loadTable(tableChange.identifier()); + if (table instanceof BaseTable) { + Transaction transaction = + Transactions.newTransaction( + tableChange.identifier().toString(), ((BaseTable) table).operations()); + transactions.add(transaction); + + BaseTransaction.TransactionTable txTable = + (BaseTransaction.TransactionTable) transaction.table(); + + // this performs validations and makes temporary commits that are in-memory + CatalogHandlers.commit(txTable.operations(), tableChange); + } else { + throw new IllegalStateException("Cannot wrap catalog that does not produce BaseTable"); + } + } + + // only commit if validations passed previously + transactions.forEach(Transaction::commitTransaction); + } + + public T execute( + HTTPMethod method, + String path, + Map queryParams, + Object body, + Class responseType, + Map headers, + Consumer errorHandler) { + ErrorResponse.Builder errorBuilder = ErrorResponse.builder(); + Pair> routeAndVars = Route.from(method, path); + if (routeAndVars != null) { + try { + ImmutableMap.Builder vars = ImmutableMap.builder(); + if (queryParams != null) { + vars.putAll(queryParams); + } + vars.putAll(routeAndVars.second()); + + return handleRequest(routeAndVars.first(), vars.build(), body, responseType); + + } catch (RuntimeException e) { + configureResponseFromException(e, errorBuilder); + } + + } else { + errorBuilder + .responseCode(400) + .withType("BadRequestException") + .withMessage(String.format("No route for request: %s %s", method, path)); + } + + ErrorResponse error = errorBuilder.build(); + errorHandler.accept(error); + + // if the error handler doesn't throw an exception, throw a generic one + throw new RESTException("Unhandled error: %s", error); + } + + @Override + public T delete( + String path, + Class responseType, + Map headers, + Consumer errorHandler) { + return execute(HTTPMethod.DELETE, path, null, null, responseType, headers, errorHandler); + } + + @Override + public T delete( + String path, + Map queryParams, + Class responseType, + Map headers, + Consumer errorHandler) { + return execute(HTTPMethod.DELETE, path, queryParams, null, responseType, headers, errorHandler); + } + + @Override + public T post( + String path, + RESTRequest body, + Class responseType, + Map headers, + Consumer errorHandler) { + return execute(HTTPMethod.POST, path, null, body, responseType, headers, errorHandler); + } + + @Override + public T get( + String path, + Map queryParams, + Class responseType, + Map headers, + Consumer errorHandler) { + return execute(HTTPMethod.GET, path, queryParams, null, responseType, headers, errorHandler); + } + + @Override + public void head(String path, Map headers, Consumer errorHandler) { + execute(HTTPMethod.HEAD, path, null, null, null, headers, errorHandler); + } + + @Override + public T postForm( + String path, + Map formData, + Class responseType, + Map headers, + Consumer errorHandler) { + return execute(HTTPMethod.POST, path, null, formData, responseType, headers, errorHandler); + } + + @Override + public void close() throws IOException { + // The calling test is responsible for closing the underlying catalog backing this REST catalog + // so that the underlying backend catalog is not closed and reopened during the REST catalog's + // initialize method when fetching the server configuration. + } + + private static class BadResponseType extends RuntimeException { + private BadResponseType(Class responseType, Object response) { + super( + String.format("Invalid response object, not a %s: %s", responseType.getName(), response)); + } + } + + private static class BadRequestType extends RuntimeException { + private BadRequestType(Class requestType, Object request) { + super(String.format("Invalid request object, not a %s: %s", requestType.getName(), request)); + } + } + + public static T castRequest(Class requestType, Object request) { + if (requestType.isInstance(request)) { + return requestType.cast(request); + } + + throw new BadRequestType(requestType, request); + } + + public static T castResponse(Class responseType, Object response) { + if (responseType.isInstance(response)) { + return responseType.cast(response); + } + + throw new BadResponseType(responseType, response); + } + + public static void configureResponseFromException( + Exception exc, ErrorResponse.Builder errorBuilder) { + errorBuilder + .responseCode(EXCEPTION_ERROR_CODES.getOrDefault(exc.getClass(), 500)) + .withType(exc.getClass().getSimpleName()) + .withMessage(exc.getMessage()) + .withStackTrace(exc); + } + + private static Namespace namespaceFromPathVars(Map pathVars) { + return RESTUtil.decodeNamespace(pathVars.get("namespace")); + } + + private static TableIdentifier identFromPathVars(Map pathVars) { + return TableIdentifier.of( + namespaceFromPathVars(pathVars), RESTUtil.decodeString(pathVars.get("name"))); + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala index 8a666dc76f..174b091050 100644 --- a/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometIcebergNativeSuite.scala @@ -26,12 +26,14 @@ import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.comet.CometIcebergNativeScanExec import org.apache.spark.sql.execution.SparkPlan +import org.apache.comet.iceberg.RESTCatalogHelper + /** * Test suite for native Iceberg scan using FileScanTasks and iceberg-rust. * * Note: Requires Iceberg dependencies to be added to pom.xml */ -class CometIcebergNativeSuite extends CometTestBase { +class CometIcebergNativeSuite extends CometTestBase with RESTCatalogHelper { // Skip these tests if Iceberg is not available in classpath private def icebergAvailable: Boolean = { @@ -2242,6 +2244,45 @@ class CometIcebergNativeSuite extends CometTestBase { } } + test("REST catalog with native Iceberg scan") { + assume(icebergAvailable, "Iceberg not available in classpath") + + withRESTCatalog { (restUri, _, warehouseDir) => + withSQLConf( + "spark.sql.catalog.rest_cat" -> "org.apache.iceberg.spark.SparkCatalog", + "spark.sql.catalog.rest_cat.catalog-impl" -> "org.apache.iceberg.rest.RESTCatalog", + "spark.sql.catalog.rest_cat.uri" -> restUri, + "spark.sql.catalog.rest_cat.warehouse" -> warehouseDir.getAbsolutePath, + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true", + CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.key -> "true") { + + // Create namespace first (REST catalog requires explicit namespace creation) + spark.sql("CREATE NAMESPACE rest_cat.db") + + // Create a table via REST catalog + spark.sql(""" + CREATE TABLE rest_cat.db.test_table ( + id INT, + name STRING, + value DOUBLE + ) USING iceberg + """) + + spark.sql(""" + INSERT INTO rest_cat.db.test_table + VALUES (1, 'Alice', 10.5), (2, 'Bob', 20.3), (3, 'Charlie', 30.7) + """) + + checkIcebergNativeScan("SELECT * FROM rest_cat.db.test_table ORDER BY id") + + spark.sql("DROP TABLE rest_cat.db.test_table") + spark.sql("DROP NAMESPACE rest_cat.db") + } + } + } + // Helper to create temp directory def withTempIcebergDir(f: File => Unit): Unit = { val dir = Files.createTempDirectory("comet-iceberg-test").toFile diff --git a/spark/src/test/spark-3.x/org/apache/comet/iceberg/RESTCatalogHelper.scala b/spark/src/test/spark-3.x/org/apache/comet/iceberg/RESTCatalogHelper.scala new file mode 100644 index 0000000000..6230ee33e1 --- /dev/null +++ b/spark/src/test/spark-3.x/org/apache/comet/iceberg/RESTCatalogHelper.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.iceberg + +import java.io.File +import java.nio.file.Files + +/** Helper trait for setting up REST catalog with Jetty 9.4 (javax.servlet) for Spark 3.x */ +trait RESTCatalogHelper { + + /** Helper to set up REST catalog with embedded Jetty server (Spark 3.x / Jetty 9.4) */ + def withRESTCatalog(f: (String, org.eclipse.jetty.server.Server, File) => Unit): Unit = { + import org.apache.iceberg.inmemory.InMemoryCatalog + import org.apache.iceberg.CatalogProperties + import org.apache.iceberg.rest.{RESTCatalogAdapter, RESTCatalogServlet} + import org.eclipse.jetty.server.Server + import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} + import org.eclipse.jetty.server.handler.gzip.GzipHandler + + val warehouseDir = Files.createTempDirectory("comet-rest-catalog-test").toFile + val backendCatalog = new InMemoryCatalog() + backendCatalog.initialize( + "in-memory", + java.util.Map.of(CatalogProperties.WAREHOUSE_LOCATION, warehouseDir.getAbsolutePath)) + + val adapter = new RESTCatalogAdapter(backendCatalog) + val servlet = new RESTCatalogServlet(adapter) + + val servletContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) + servletContext.setContextPath("/") + val servletHolder = new ServletHolder(servlet.asInstanceOf[javax.servlet.Servlet]) + servletHolder.setInitParameter("javax.ws.rs.Application", "ServiceListPublic") + servletContext.addServlet(servletHolder, "/*") + servletContext.setVirtualHosts(null) + servletContext.setGzipHandler(new GzipHandler()) + + val httpServer = new Server(0) // random port + httpServer.setHandler(servletContext) + + try { + httpServer.start() + val restUri = httpServer.getURI.toString.stripSuffix("/") + f(restUri, httpServer, warehouseDir) + } finally { + try { + httpServer.stop() + httpServer.join() + } catch { + case _: Exception => // ignore cleanup errors + } + try { + backendCatalog.close() + } catch { + case _: Exception => // ignore cleanup errors + } + def deleteRecursively(file: File): Unit = { + if (file.isDirectory) { + file.listFiles().foreach(deleteRecursively) + } + file.delete() + } + deleteRecursively(warehouseDir) + } + } +} diff --git a/spark/src/test/spark-3.x/org/apache/iceberg/rest/RESTCatalogServlet.java b/spark/src/test/spark-3.x/org/apache/iceberg/rest/RESTCatalogServlet.java new file mode 100644 index 0000000000..88de30f2d8 --- /dev/null +++ b/spark/src/test/spark-3.x/org/apache/iceberg/rest/RESTCatalogServlet.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.rest; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.iceberg.exceptions.RESTException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.io.CharStreams; +import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.RESTCatalogAdapter.Route; +import org.apache.iceberg.rest.responses.ErrorResponse; +import org.apache.iceberg.util.Pair; + +import static java.lang.String.format; + +/** + * The RESTCatalogServlet provides a servlet implementation used in combination with a + * RESTCatalogAdaptor to proxy the REST Spec to any Catalog implementation. + * Modified version of Iceberg's org/apache/iceberg/rest/RESTCatalogServlet.java + */ +public class RESTCatalogServlet extends HttpServlet { + private static final Logger LOG = LoggerFactory.getLogger(RESTCatalogServlet.class); + + private final RESTCatalogAdapter restCatalogAdapter; + private final Map responseHeaders = + ImmutableMap.of("Content-Type", "application/json"); + + public RESTCatalogServlet(RESTCatalogAdapter restCatalogAdapter) { + this.restCatalogAdapter = restCatalogAdapter; + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException { + execute(ServletRequestContext.from(request), response); + } + + @Override + protected void doHead(HttpServletRequest request, HttpServletResponse response) + throws IOException { + execute(ServletRequestContext.from(request), response); + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws IOException { + execute(ServletRequestContext.from(request), response); + } + + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws IOException { + execute(ServletRequestContext.from(request), response); + } + + protected void execute(ServletRequestContext context, HttpServletResponse response) + throws IOException { + response.setStatus(HttpServletResponse.SC_OK); + responseHeaders.forEach(response::setHeader); + + if (context.error().isPresent()) { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + RESTObjectMapper.mapper().writeValue(response.getWriter(), context.error().get()); + return; + } + + try { + Object responseBody = + restCatalogAdapter.execute( + context.method(), + context.path(), + context.queryParams(), + context.body(), + context.route().responseClass(), + context.headers(), + handle(response)); + + if (responseBody != null) { + RESTObjectMapper.mapper().writeValue(response.getWriter(), responseBody); + } + } catch (RESTException e) { + LOG.error("Error processing REST request", e); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } catch (Exception e) { + LOG.error("Unexpected exception when processing REST request", e); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + + protected Consumer handle(HttpServletResponse response) { + return (errorResponse) -> { + response.setStatus(errorResponse.code()); + try { + RESTObjectMapper.mapper().writeValue(response.getWriter(), errorResponse); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + + public static class ServletRequestContext { + private HTTPMethod method; + private Route route; + private String path; + private Map headers; + private Map queryParams; + private Object body; + + private ErrorResponse errorResponse; + + private ServletRequestContext(ErrorResponse errorResponse) { + this.errorResponse = errorResponse; + } + + private ServletRequestContext( + HTTPMethod method, + Route route, + String path, + Map headers, + Map queryParams, + Object body) { + this.method = method; + this.route = route; + this.path = path; + this.headers = headers; + this.queryParams = queryParams; + this.body = body; + } + + static ServletRequestContext from(HttpServletRequest request) throws IOException { + HTTPMethod method = HTTPMethod.valueOf(request.getMethod()); + String path = request.getRequestURI().substring(1); + Pair> routeContext = Route.from(method, path); + + if (routeContext == null) { + return new ServletRequestContext( + ErrorResponse.builder() + .responseCode(400) + .withType("BadRequestException") + .withMessage(format("No route for request: %s %s", method, path)) + .build()); + } + + Route route = routeContext.first(); + Object requestBody = null; + if (route.requestClass() != null) { + requestBody = + RESTObjectMapper.mapper().readValue(request.getReader(), route.requestClass()); + } else if (route == Route.TOKENS) { + try (Reader reader = new InputStreamReader(request.getInputStream())) { + requestBody = RESTUtil.decodeFormData(CharStreams.toString(reader)); + } + } + + Map queryParams = + request.getParameterMap().entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue()[0])); + Map headers = + Collections.list(request.getHeaderNames()).stream() + .collect(Collectors.toMap(Function.identity(), request::getHeader)); + + return new ServletRequestContext(method, route, path, headers, queryParams, requestBody); + } + + public HTTPMethod method() { + return method; + } + + public Route route() { + return route; + } + + public String path() { + return path; + } + + public Map headers() { + return headers; + } + + public Map queryParams() { + return queryParams; + } + + public Object body() { + return body; + } + + public Optional error() { + return Optional.ofNullable(errorResponse); + } + } +} diff --git a/spark/src/test/spark-4.0/org/apache/comet/iceberg/RESTCatalogHelper.scala b/spark/src/test/spark-4.0/org/apache/comet/iceberg/RESTCatalogHelper.scala new file mode 100644 index 0000000000..ccd03c544d --- /dev/null +++ b/spark/src/test/spark-4.0/org/apache/comet/iceberg/RESTCatalogHelper.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.iceberg + +import java.io.File +import java.nio.file.Files + +/** Helper trait for setting up REST catalog with Jetty 11 (jakarta.servlet) for Spark 4.0 */ +trait RESTCatalogHelper { + + /** Helper to set up REST catalog with embedded Jetty server (Spark 4.0 / Jetty 11) */ + def withRESTCatalog(f: (String, org.eclipse.jetty.server.Server, File) => Unit): Unit = { + import org.apache.iceberg.inmemory.InMemoryCatalog + import org.apache.iceberg.CatalogProperties + import org.apache.iceberg.rest.{RESTCatalogAdapter, RESTCatalogServlet} + import org.eclipse.jetty.server.Server + import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} + import org.eclipse.jetty.server.handler.gzip.GzipHandler + + val warehouseDir = Files.createTempDirectory("comet-rest-catalog-test").toFile + val backendCatalog = new InMemoryCatalog() + backendCatalog.initialize( + "in-memory", + java.util.Map.of(CatalogProperties.WAREHOUSE_LOCATION, warehouseDir.getAbsolutePath)) + + val adapter = new RESTCatalogAdapter(backendCatalog) + val servlet = new RESTCatalogServlet(adapter) + + val servletContext = new ServletContextHandler(ServletContextHandler.NO_SESSIONS) + servletContext.setContextPath("/") + val servletHolder = new ServletHolder(servlet.asInstanceOf[jakarta.servlet.Servlet]) + servletHolder.setInitParameter("jakarta.ws.rs.Application", "ServiceListPublic") + servletContext.addServlet(servletHolder, "/*") + servletContext.setVirtualHosts(null) + servletContext.insertHandler(new GzipHandler()) + + val httpServer = new Server(0) // random port + httpServer.setHandler(servletContext) + + try { + httpServer.start() + val restUri = httpServer.getURI.toString.stripSuffix("/") + f(restUri, httpServer, warehouseDir) + } finally { + try { + httpServer.stop() + httpServer.join() + } catch { + case _: Exception => // ignore cleanup errors + } + try { + backendCatalog.close() + } catch { + case _: Exception => // ignore cleanup errors + } + def deleteRecursively(file: File): Unit = { + if (file.isDirectory) { + file.listFiles().foreach(deleteRecursively) + } + file.delete() + } + deleteRecursively(warehouseDir) + } + } +} diff --git a/spark/src/test/spark-4.0/org/apache/iceberg/rest/RESTCatalogServlet.java b/spark/src/test/spark-4.0/org/apache/iceberg/rest/RESTCatalogServlet.java new file mode 100644 index 0000000000..b54dacac48 --- /dev/null +++ b/spark/src/test/spark-4.0/org/apache/iceberg/rest/RESTCatalogServlet.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.rest; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.iceberg.exceptions.RESTException; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.relocated.com.google.common.io.CharStreams; +import org.apache.iceberg.rest.RESTCatalogAdapter.HTTPMethod; +import org.apache.iceberg.rest.RESTCatalogAdapter.Route; +import org.apache.iceberg.rest.responses.ErrorResponse; +import org.apache.iceberg.util.Pair; + +import static java.lang.String.format; + +/** + * The RESTCatalogServlet provides a servlet implementation used in combination with a + * RESTCatalogAdaptor to proxy the REST Spec to any Catalog implementation. + * Modified version of Iceberg's org/apache/iceberg/rest/RESTCatalogServlet.java + */ +public class RESTCatalogServlet extends HttpServlet { + private static final Logger LOG = LoggerFactory.getLogger(RESTCatalogServlet.class); + + private final RESTCatalogAdapter restCatalogAdapter; + private final Map responseHeaders = + ImmutableMap.of("Content-Type", "application/json"); + + public RESTCatalogServlet(RESTCatalogAdapter restCatalogAdapter) { + this.restCatalogAdapter = restCatalogAdapter; + } + + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws IOException { + execute(ServletRequestContext.from(request), response); + } + + @Override + protected void doHead(HttpServletRequest request, HttpServletResponse response) + throws IOException { + execute(ServletRequestContext.from(request), response); + } + + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws IOException { + execute(ServletRequestContext.from(request), response); + } + + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws IOException { + execute(ServletRequestContext.from(request), response); + } + + protected void execute(ServletRequestContext context, HttpServletResponse response) + throws IOException { + response.setStatus(HttpServletResponse.SC_OK); + responseHeaders.forEach(response::setHeader); + + if (context.error().isPresent()) { + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + RESTObjectMapper.mapper().writeValue(response.getWriter(), context.error().get()); + return; + } + + try { + Object responseBody = + restCatalogAdapter.execute( + context.method(), + context.path(), + context.queryParams(), + context.body(), + context.route().responseClass(), + context.headers(), + handle(response)); + + if (responseBody != null) { + RESTObjectMapper.mapper().writeValue(response.getWriter(), responseBody); + } + } catch (RESTException e) { + LOG.error("Error processing REST request", e); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } catch (Exception e) { + LOG.error("Unexpected exception when processing REST request", e); + response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + } + } + + protected Consumer handle(HttpServletResponse response) { + return (errorResponse) -> { + response.setStatus(errorResponse.code()); + try { + RESTObjectMapper.mapper().writeValue(response.getWriter(), errorResponse); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + + public static class ServletRequestContext { + private HTTPMethod method; + private Route route; + private String path; + private Map headers; + private Map queryParams; + private Object body; + + private ErrorResponse errorResponse; + + private ServletRequestContext(ErrorResponse errorResponse) { + this.errorResponse = errorResponse; + } + + private ServletRequestContext( + HTTPMethod method, + Route route, + String path, + Map headers, + Map queryParams, + Object body) { + this.method = method; + this.route = route; + this.path = path; + this.headers = headers; + this.queryParams = queryParams; + this.body = body; + } + + static ServletRequestContext from(HttpServletRequest request) throws IOException { + HTTPMethod method = HTTPMethod.valueOf(request.getMethod()); + String path = request.getRequestURI().substring(1); + Pair> routeContext = Route.from(method, path); + + if (routeContext == null) { + return new ServletRequestContext( + ErrorResponse.builder() + .responseCode(400) + .withType("BadRequestException") + .withMessage(format("No route for request: %s %s", method, path)) + .build()); + } + + Route route = routeContext.first(); + Object requestBody = null; + if (route.requestClass() != null) { + requestBody = + RESTObjectMapper.mapper().readValue(request.getReader(), route.requestClass()); + } else if (route == Route.TOKENS) { + try (Reader reader = new InputStreamReader(request.getInputStream())) { + requestBody = RESTUtil.decodeFormData(CharStreams.toString(reader)); + } + } + + Map queryParams = + request.getParameterMap().entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue()[0])); + Map headers = + Collections.list(request.getHeaderNames()).stream() + .collect(Collectors.toMap(Function.identity(), request::getHeader)); + + return new ServletRequestContext(method, route, path, headers, queryParams, requestBody); + } + + public HTTPMethod method() { + return method; + } + + public Route route() { + return route; + } + + public String path() { + return path; + } + + public Map headers() { + return headers; + } + + public Map queryParams() { + return queryParams; + } + + public Object body() { + return body; + } + + public Optional error() { + return Optional.ofNullable(errorResponse); + } + } +}