From 791c6cddaafdb9e32870abb1f00ddb255c90cbec Mon Sep 17 00:00:00 2001
From: David Li
Date: Sun, 13 Nov 2022 10:47:44 -0500
Subject: [PATCH] ARROW-18300: [Java] Implement parameters for JDBC driver
---
.../ArrowFlightJdbcFlightStreamResultSet.java | 2 +-
...owFlightJdbcVectorSchemaRootResultSet.java | 2 +-
.../driver/jdbc/ArrowFlightMetaImpl.java | 46 ++++--
.../AvaticaParameterFromArrowTypeVisitor.java | 146 ++++++++++++++++++
.../client/ArrowFlightSqlClientHandler.java | 15 ++
.../driver/jdbc/utils/TypedValueBinder.java | 86 +++++++++++
.../ArrowFlightPreparedStatementTest.java | 31 ++++
.../jdbc/utils/MockFlightSqlProducer.java | 60 ++++++-
.../jdbc/utils/TypedValueBinderTest.java | 22 +++
9 files changed, 395 insertions(+), 15 deletions(-)
create mode 100644 java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/AvaticaParameterFromArrowTypeVisitor.java
create mode 100644 java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinder.java
create mode 100644 java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinderTest.java
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java
index 4c01cb6e581..c6db84c7e78 100644
--- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java
@@ -92,7 +92,7 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo(
final TimeZone timeZone = TimeZone.getDefault();
final QueryState state = new QueryState();
- final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null);
+ final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null);
final AvaticaResultSetMetaData resultSetMetaData =
new AvaticaResultSetMetaData(null, null, signature);
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java
index 9e377e51dec..72875597322 100644
--- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java
@@ -74,7 +74,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot(
final TimeZone timeZone = TimeZone.getDefault();
final QueryState state = new QueryState();
- final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null);
+ final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null);
final AvaticaResultSetMetaData resultSetMetaData =
new AvaticaResultSetMetaData(null, null, signature);
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java
index f825e7d13ce..2589e0c639e 100644
--- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java
@@ -27,9 +27,12 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement;
+import org.apache.arrow.driver.jdbc.utils.TypedValueBinder;
import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaConnection;
import org.apache.calcite.avatica.AvaticaParameter;
import org.apache.calcite.avatica.ColumnMetaData;
@@ -54,17 +57,31 @@ public ArrowFlightMetaImpl(final AvaticaConnection connection) {
setDefaultConnectionProperties();
}
- static Signature newSignature(final String sql) {
+ static Signature newSignature(final String sql, Schema parameterSchema) {
+ final List parameters;
+ if (parameterSchema == null) {
+ parameters = Collections.emptyList();
+ } else {
+ parameters = parameterSchema.getFields()
+ .stream()
+ .map(AvaticaParameterFromArrowTypeVisitor::fromArrowField)
+ .collect(Collectors.toList());
+ }
+ Map internalParameters = Collections.emptyMap();
return new Signature(
new ArrayList(),
sql,
- Collections.emptyList(),
- Collections.emptyMap(),
- null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor
+ parameters,
+ internalParameters,
+ /*cursorFactory*/null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor
StatementType.SELECT
);
}
+ private ArrowFlightConnection getConnection() {
+ return (ArrowFlightConnection) connection;
+ }
+
@Override
public void closeStatement(final StatementHandle statementHandle) {
PreparedStatement preparedStatement =
@@ -86,17 +103,26 @@ public ExecuteResult execute(final StatementHandle statementHandle,
Preconditions.checkArgument(connection.id.equals(statementHandle.connectionId),
"Connection IDs are not consistent");
if (statementHandle.signature == null) {
+ // TODO: refactor update/select queries out into separate methods
// Update query
final StatementHandleKey key = new StatementHandleKey(statementHandle);
PreparedStatement preparedStatement = statementHandlePreparedStatementMap.get(key);
if (preparedStatement == null) {
throw new IllegalStateException("Prepared statement not found: " + statementHandle);
}
- long updatedCount = preparedStatement.executeUpdate();
+
+ long updatedCount;
+ try (final TypedValueBinder binder =
+ new TypedValueBinder(preparedStatement, getConnection().getBufferAllocator())) {
+ binder.bind(typedValues);
+ updatedCount = preparedStatement.executeUpdate();
+ }
return new ExecuteResult(Collections.singletonList(MetaResultSet.count(statementHandle.connectionId,
statementHandle.id, updatedCount)));
} else {
// TODO Why is maxRowCount ignored?
+ // TODO: TypedValues
+ // TODO: should move execution eagerly here instead of deferring it
return new ExecuteResult(
Collections.singletonList(MetaResultSet.create(
statementHandle.connectionId, statementHandle.id,
@@ -134,9 +160,8 @@ public Frame fetch(final StatementHandle statementHandle, final long offset,
public StatementHandle prepare(final ConnectionHandle connectionHandle,
final String query, final long maxRowCount) {
final StatementHandle handle = super.createStatement(connectionHandle);
- handle.signature = newSignature(query);
- final PreparedStatement preparedStatement =
- ((ArrowFlightConnection) connection).getClientHandler().prepare(query);
+ final PreparedStatement preparedStatement = getConnection().getClientHandler().prepare(query);
+ handle.signature = newSignature(query, preparedStatement.getParameterSchema());
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
return handle;
}
@@ -157,11 +182,10 @@ public ExecuteResult prepareAndExecute(final StatementHandle handle,
final PrepareCallback callback)
throws NoSuchStatementException {
try {
- final PreparedStatement preparedStatement =
- ((ArrowFlightConnection) connection).getClientHandler().prepare(query);
+ final PreparedStatement preparedStatement = getConnection().getClientHandler().prepare(query);
final StatementType statementType = preparedStatement.getType();
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
- final Signature signature = newSignature(query);
+ final Signature signature = newSignature(query, preparedStatement.getParameterSchema());
final long updateCount =
statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1;
synchronized (callback.getMonitor()) {
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/AvaticaParameterFromArrowTypeVisitor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/AvaticaParameterFromArrowTypeVisitor.java
new file mode 100644
index 00000000000..c02f88fb7a7
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/AvaticaParameterFromArrowTypeVisitor.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+import java.sql.Types;
+
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.calcite.avatica.AvaticaParameter;
+
+/**
+ * Turn an Arrow Field into an equivalent AvaticaParameter.
+ */
+class AvaticaParameterFromArrowTypeVisitor implements ArrowType.ArrowTypeVisitor {
+ private final Field field;
+
+ AvaticaParameterFromArrowTypeVisitor(Field field) {
+ this.field = field;
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Null type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Struct type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.List type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.LargeList type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.FixedSizeList type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Union type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Map type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Int type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.FloatingPoint type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Utf8 type) {
+ return new AvaticaParameter(/*signed*/false, /*precision*/0, /*scale*/0, Types.VARCHAR, "VARCHAR",
+ String.class.getName(),
+ field.getName());
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.LargeUtf8 type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Binary type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.LargeBinary type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.FixedSizeBinary type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Bool type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Decimal type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Date type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Time type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Timestamp type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Interval type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ @Override
+ public AvaticaParameter visit(ArrowType.Duration type) {
+ throw new UnsupportedOperationException("Creating parameter with Arrow type " + type);
+ }
+
+ static AvaticaParameter fromArrowField(Field field) {
+ return field.getType().accept(new AvaticaParameterFromArrowTypeVisitor(field));
+ }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
index 7b059ab02f8..e944128d42f 100644
--- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
@@ -48,6 +48,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.Meta.StatementType;
import org.slf4j.Logger;
@@ -155,6 +156,10 @@ public interface PreparedStatement extends AutoCloseable {
*/
Schema getDataSetSchema();
+ Schema getParameterSchema();
+
+ void setParameters(VectorSchemaRoot parameters);
+
@Override
void close();
}
@@ -190,6 +195,16 @@ public Schema getDataSetSchema() {
return preparedStatement.getResultSetSchema();
}
+ @Override
+ public Schema getParameterSchema() {
+ return preparedStatement.getParameterSchema();
+ }
+
+ @Override
+ public void setParameters(VectorSchemaRoot parameters) {
+ preparedStatement.setParameters(parameters);
+ }
+
@Override
public void close() {
try {
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinder.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinder.java
new file mode 100644
index 00000000000..026a421e596
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinder.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.calcite.avatica.remote.TypedValue;
+
+/**
+ * Bind {@link TypedValue}s to a {@link VectorSchemaRoot}.
+ */
+public class TypedValueBinder implements AutoCloseable {
+ private final PreparedStatement preparedStatement;
+ private final VectorSchemaRoot parameters;
+
+ public TypedValueBinder(PreparedStatement preparedStatement, BufferAllocator bufferAllocator) {
+ this.parameters = VectorSchemaRoot.create(preparedStatement.getParameterSchema(), bufferAllocator);
+ this.preparedStatement = preparedStatement;
+ }
+
+ /**
+ * Bind the given Avatica values to the prepared statement.
+ * @param typedValues The parameter values.
+ */
+ public void bind(List typedValues) {
+ if (preparedStatement.getParameterSchema().getFields().size() != typedValues.size()) {
+ throw new IllegalStateException(
+ String.format("Prepared statement has %s parameters, but only received %s",
+ preparedStatement.getParameterSchema().getFields().size(),
+ typedValues.size()));
+ }
+
+ for (int i = 0; i < typedValues.size(); i++) {
+ final TypedValue param = typedValues.get(i);
+ final FieldVector vector = parameters.getVector(i);
+ switch (param.type) {
+ case STRING:
+ bindValue((String) param.value, vector);
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ String.format("Binding JDBC type %s to Arrow Flight SQL statement", param.type));
+ }
+ }
+
+ if (!typedValues.isEmpty()) {
+ parameters.setRowCount(1);
+ preparedStatement.setParameters(parameters);
+ }
+ }
+
+ private void bindValue(String value, FieldVector vector) {
+ if (vector instanceof VarCharVector) {
+ ((VarCharVector) vector).setSafe(0, value.getBytes(StandardCharsets.UTF_8));
+ } else {
+ throw new UnsupportedOperationException(
+ String.format("Binding String to parameter of Arrow type %s", vector.getField().getType()));
+ }
+ }
+
+ @Override
+ public void close() {
+ parameters.close();
+ }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
index 8af529296fa..ce5ee64b796 100644
--- a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
@@ -20,13 +20,19 @@
import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
+import java.util.Collections;
import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.Text;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.ClassRule;
@@ -79,6 +85,16 @@ public void testReturnColumnCount() throws SQLException {
}
}
+ @Test
+ public void testSelectQueryWithParameters() throws SQLException {
+ String query = "Fake select";
+ PRODUCER.addUpdateQuery(query, /*updatedRows*/42);
+ try (final PreparedStatement stmt = connection.prepareStatement(query)) {
+ int updated = stmt.executeUpdate();
+ assertEquals(42, updated);
+ }
+ }
+
@Test
public void testUpdateQuery() throws SQLException {
String query = "Fake update";
@@ -88,4 +104,19 @@ public void testUpdateQuery() throws SQLException {
assertEquals(42, updated);
}
}
+
+ @Test
+ public void testUpdateQueryWithParameters() throws SQLException {
+ String query = "Fake update with parameters";
+ PRODUCER.addUpdateQuery(query, /*updatedRows*/42);
+ PRODUCER.addExpectedParameters(query,
+ new Schema(Collections.singletonList(Field.nullable("", ArrowType.Utf8.INSTANCE))),
+ Collections.singletonList(Collections.singletonList(new Text("foo".getBytes(StandardCharsets.UTF_8)))));
+ try (final PreparedStatement stmt = connection.prepareStatement(query)) {
+ // TODO: make sure this is validated on the server too
+ stmt.setString(1, "foo");
+ int updated = stmt.executeUpdate();
+ assertEquals(42, updated);
+ }
+ }
}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
index cc8fae9722f..eec3a307b3f 100644
--- a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
@@ -34,6 +34,7 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.Objects;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
@@ -74,6 +75,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.WriteChannel;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -96,7 +98,9 @@ public final class MockFlightSqlProducer implements FlightSqlProducer {
private final Map>>
updateResultProviders =
new HashMap<>();
- private SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder();
+ private final SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder();
+ private final Map parameterSchemas = new HashMap<>();
+ private final Map>> expectedParameterValues = new HashMap<>();
private static FlightInfo getFightInfoExportedAndImportedKeys(final Message message,
final FlightDescriptor descriptor) {
@@ -189,6 +193,13 @@ void addUpdateQuery(final String sqlCommand,
format("Attempted to overwrite pre-existing query: <%s>.", sqlCommand));
}
+
+ /** Registers parameters expected to be provided with a prepared statement. */
+ public void addExpectedParameters(String query, Schema parameterSchema, List> expectedValues) {
+ parameterSchemas.put(query, parameterSchema);
+ expectedParameterValues.put(query, expectedValues);
+ }
+
@Override
public void createPreparedStatement(final ActionCreatePreparedStatementRequest request,
final CallContext callContext,
@@ -213,13 +224,19 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r
resultBuilder.setDatasetSchema(datasetSchemaBytes);
} else if (updateResultProviders.containsKey(query)) {
preparedStatements.put(preparedStatementHandle, query);
-
} else {
listener.onError(
CallStatus.INVALID_ARGUMENT.withDescription("Query not found").toRuntimeException());
return;
}
+ final Schema parameterSchema = parameterSchemas.get(query);
+ if (parameterSchema != null) {
+ final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+ MessageSerializer.serialize(new WriteChannel(Channels.newChannel(outputStream)), parameterSchema);
+ resultBuilder.setParameterSchema(ByteString.copyFrom(outputStream.toByteArray()));
+ }
+
listener.onNext(new Result(pack(resultBuilder.build()).toByteArray()));
} catch (final Throwable t) {
listener.onError(t);
@@ -336,6 +353,45 @@ public Runnable acceptPutPreparedStatementUpdate(
final String query = Preconditions.checkNotNull(
preparedStatements.get(handle),
format("No query registered under handle: <%s>.", handle));
+ final List> expectedValues = expectedParameterValues.get(query);
+ if (expectedValues != null) {
+ int index = 0;
+ while (flightStream.next()) {
+ final VectorSchemaRoot root = flightStream.getRoot();
+ for (int i = 0; i < root.getRowCount(); i++) {
+ if (index >= expectedValues.size()) {
+ streamListener.onError(CallStatus.INVALID_ARGUMENT
+ .withDescription("More parameter rows provided than expected")
+ .toRuntimeException());
+ return () -> { };
+ }
+ List