diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index 4c01cb6e581..c6db84c7e78 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -92,7 +92,7 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo( final TimeZone timeZone = TimeZone.getDefault(); final QueryState state = new QueryState(); - final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null); + final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null); final AvaticaResultSetMetaData resultSetMetaData = new AvaticaResultSetMetaData(null, null, signature); diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java index 9e377e51dec..72875597322 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java @@ -74,7 +74,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot( final TimeZone timeZone = TimeZone.getDefault(); final QueryState state = new QueryState(); - final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null); + final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null); final AvaticaResultSetMetaData resultSetMetaData = new AvaticaResultSetMetaData(null, null, signature); diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index f825e7d13ce..2589e0c639e 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -27,9 +27,12 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; +import org.apache.arrow.driver.jdbc.utils.TypedValueBinder; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.AvaticaConnection; import org.apache.calcite.avatica.AvaticaParameter; import org.apache.calcite.avatica.ColumnMetaData; @@ -54,17 +57,31 @@ public ArrowFlightMetaImpl(final AvaticaConnection connection) { setDefaultConnectionProperties(); } - static Signature newSignature(final String sql) { + static Signature newSignature(final String sql, Schema parameterSchema) { + final List parameters; + if (parameterSchema == null) { + parameters = Collections.emptyList(); + } else { + parameters = parameterSchema.getFields() + .stream() + .map(AvaticaParameterFromArrowTypeVisitor::fromArrowField) + .collect(Collectors.toList()); + } + Map internalParameters = Collections.emptyMap(); return new Signature( new ArrayList(), sql, - Collections.emptyList(), - Collections.emptyMap(), - null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor + parameters, + internalParameters, + /*cursorFactory*/null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor StatementType.SELECT ); } + private ArrowFlightConnection getConnection() { + return (ArrowFlightConnection) connection; + } + @Override public void closeStatement(final StatementHandle statementHandle) { PreparedStatement preparedStatement = @@ -86,17 +103,26 @@ public ExecuteResult execute(final StatementHandle statementHandle, Preconditions.checkArgument(connection.id.equals(statementHandle.connectionId), "Connection IDs are not consistent"); if (statementHandle.signature == null) { + // TODO: refactor update/select queries out into separate methods // Update query final StatementHandleKey key = new StatementHandleKey(statementHandle); PreparedStatement preparedStatement = statementHandlePreparedStatementMap.get(key); if (preparedStatement == null) { throw new IllegalStateException("Prepared statement not found: " + statementHandle); } - long updatedCount = preparedStatement.executeUpdate(); + + long updatedCount; + try (final TypedValueBinder binder = + new TypedValueBinder(preparedStatement, getConnection().getBufferAllocator())) { + binder.bind(typedValues); + updatedCount = preparedStatement.executeUpdate(); + } return new ExecuteResult(Collections.singletonList(MetaResultSet.count(statementHandle.connectionId, statementHandle.id, updatedCount))); } else { // TODO Why is maxRowCount ignored? + // TODO: TypedValues + // TODO: should move execution eagerly here instead of deferring it return new ExecuteResult( Collections.singletonList(MetaResultSet.create( statementHandle.connectionId, statementHandle.id, @@ -134,9 +160,8 @@ public Frame fetch(final StatementHandle statementHandle, final long offset, public StatementHandle prepare(final ConnectionHandle connectionHandle, final String query, final long maxRowCount) { final StatementHandle handle = super.createStatement(connectionHandle); - handle.signature = newSignature(query); - final PreparedStatement preparedStatement = - ((ArrowFlightConnection) connection).getClientHandler().prepare(query); + final PreparedStatement preparedStatement = getConnection().getClientHandler().prepare(query); + handle.signature = newSignature(query, preparedStatement.getParameterSchema()); statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement); return handle; } @@ -157,11 +182,10 @@ public ExecuteResult prepareAndExecute(final StatementHandle handle, final PrepareCallback callback) throws NoSuchStatementException { try { - final PreparedStatement preparedStatement = - ((ArrowFlightConnection) connection).getClientHandler().prepare(query); + final PreparedStatement preparedStatement = getConnection().getClientHandler().prepare(query); final StatementType statementType = preparedStatement.getType(); statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement); - final Signature signature = newSignature(query); + final Signature signature = newSignature(query, preparedStatement.getParameterSchema()); final long updateCount = statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1; synchronized (callback.getMonitor()) { diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/AvaticaParameterFromArrowTypeVisitor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/AvaticaParameterFromArrowTypeVisitor.java new file mode 100644 index 00000000000..c02f88fb7a7 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/AvaticaParameterFromArrowTypeVisitor.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Types; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.calcite.avatica.AvaticaParameter; + +/** + * Turn an Arrow Field into an equivalent AvaticaParameter. + */ +class AvaticaParameterFromArrowTypeVisitor implements ArrowType.ArrowTypeVisitor { + private final Field field; + + AvaticaParameterFromArrowTypeVisitor(Field field) { + this.field = field; + } + + @Override + public AvaticaParameter visit(ArrowType.Null type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Struct type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.List type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.LargeList type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.FixedSizeList type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Union type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Map type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Int type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.FloatingPoint type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Utf8 type) { + return new AvaticaParameter(/*signed*/false, /*precision*/0, /*scale*/0, Types.VARCHAR, "VARCHAR", + String.class.getName(), + field.getName()); + } + + @Override + public AvaticaParameter visit(ArrowType.LargeUtf8 type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Binary type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.LargeBinary type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.FixedSizeBinary type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Bool type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Decimal type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Date type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Time type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Timestamp type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Interval type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + @Override + public AvaticaParameter visit(ArrowType.Duration type) { + throw new UnsupportedOperationException("Creating parameter with Arrow type " + type); + } + + static AvaticaParameter fromArrowField(Field field) { + return field.getType().accept(new AvaticaParameterFromArrowTypeVisitor(field)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index 7b059ab02f8..e944128d42f 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -48,6 +48,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.Meta.StatementType; import org.slf4j.Logger; @@ -155,6 +156,10 @@ public interface PreparedStatement extends AutoCloseable { */ Schema getDataSetSchema(); + Schema getParameterSchema(); + + void setParameters(VectorSchemaRoot parameters); + @Override void close(); } @@ -190,6 +195,16 @@ public Schema getDataSetSchema() { return preparedStatement.getResultSetSchema(); } + @Override + public Schema getParameterSchema() { + return preparedStatement.getParameterSchema(); + } + + @Override + public void setParameters(VectorSchemaRoot parameters) { + preparedStatement.setParameters(parameters); + } + @Override public void close() { try { diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinder.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinder.java new file mode 100644 index 00000000000..026a421e596 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinder.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.nio.charset.StandardCharsets; +import java.util.List; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.calcite.avatica.remote.TypedValue; + +/** + * Bind {@link TypedValue}s to a {@link VectorSchemaRoot}. + */ +public class TypedValueBinder implements AutoCloseable { + private final PreparedStatement preparedStatement; + private final VectorSchemaRoot parameters; + + public TypedValueBinder(PreparedStatement preparedStatement, BufferAllocator bufferAllocator) { + this.parameters = VectorSchemaRoot.create(preparedStatement.getParameterSchema(), bufferAllocator); + this.preparedStatement = preparedStatement; + } + + /** + * Bind the given Avatica values to the prepared statement. + * @param typedValues The parameter values. + */ + public void bind(List typedValues) { + if (preparedStatement.getParameterSchema().getFields().size() != typedValues.size()) { + throw new IllegalStateException( + String.format("Prepared statement has %s parameters, but only received %s", + preparedStatement.getParameterSchema().getFields().size(), + typedValues.size())); + } + + for (int i = 0; i < typedValues.size(); i++) { + final TypedValue param = typedValues.get(i); + final FieldVector vector = parameters.getVector(i); + switch (param.type) { + case STRING: + bindValue((String) param.value, vector); + break; + default: + throw new UnsupportedOperationException( + String.format("Binding JDBC type %s to Arrow Flight SQL statement", param.type)); + } + } + + if (!typedValues.isEmpty()) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + } + } + + private void bindValue(String value, FieldVector vector) { + if (vector instanceof VarCharVector) { + ((VarCharVector) vector).setSafe(0, value.getBytes(StandardCharsets.UTF_8)); + } else { + throw new UnsupportedOperationException( + String.format("Binding String to parameter of Arrow type %s", vector.getField().getType())); + } + } + + @Override + public void close() { + parameters.close(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java index 8af529296fa..ce5ee64b796 100644 --- a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -20,13 +20,19 @@ import static org.hamcrest.CoreMatchers.equalTo; import static org.junit.jupiter.api.Assertions.assertEquals; +import java.nio.charset.StandardCharsets; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; +import java.util.Collections; import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.ClassRule; @@ -79,6 +85,16 @@ public void testReturnColumnCount() throws SQLException { } } + @Test + public void testSelectQueryWithParameters() throws SQLException { + String query = "Fake select"; + PRODUCER.addUpdateQuery(query, /*updatedRows*/42); + try (final PreparedStatement stmt = connection.prepareStatement(query)) { + int updated = stmt.executeUpdate(); + assertEquals(42, updated); + } + } + @Test public void testUpdateQuery() throws SQLException { String query = "Fake update"; @@ -88,4 +104,19 @@ public void testUpdateQuery() throws SQLException { assertEquals(42, updated); } } + + @Test + public void testUpdateQueryWithParameters() throws SQLException { + String query = "Fake update with parameters"; + PRODUCER.addUpdateQuery(query, /*updatedRows*/42); + PRODUCER.addExpectedParameters(query, + new Schema(Collections.singletonList(Field.nullable("", ArrowType.Utf8.INSTANCE))), + Collections.singletonList(Collections.singletonList(new Text("foo".getBytes(StandardCharsets.UTF_8))))); + try (final PreparedStatement stmt = connection.prepareStatement(query)) { + // TODO: make sure this is validated on the server too + stmt.setString(1, "foo"); + int updated = stmt.executeUpdate(); + assertEquals(42, updated); + } + } } diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java index cc8fae9722f..eec3a307b3f 100644 --- a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java @@ -34,6 +34,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.UUID; import java.util.function.BiConsumer; import java.util.function.Consumer; @@ -74,6 +75,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.WriteChannel; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; @@ -96,7 +98,9 @@ public final class MockFlightSqlProducer implements FlightSqlProducer { private final Map>> updateResultProviders = new HashMap<>(); - private SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder(); + private final SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder(); + private final Map parameterSchemas = new HashMap<>(); + private final Map>> expectedParameterValues = new HashMap<>(); private static FlightInfo getFightInfoExportedAndImportedKeys(final Message message, final FlightDescriptor descriptor) { @@ -189,6 +193,13 @@ void addUpdateQuery(final String sqlCommand, format("Attempted to overwrite pre-existing query: <%s>.", sqlCommand)); } + + /** Registers parameters expected to be provided with a prepared statement. */ + public void addExpectedParameters(String query, Schema parameterSchema, List> expectedValues) { + parameterSchemas.put(query, parameterSchema); + expectedParameterValues.put(query, expectedValues); + } + @Override public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext callContext, @@ -213,13 +224,19 @@ public void createPreparedStatement(final ActionCreatePreparedStatementRequest r resultBuilder.setDatasetSchema(datasetSchemaBytes); } else if (updateResultProviders.containsKey(query)) { preparedStatements.put(preparedStatementHandle, query); - } else { listener.onError( CallStatus.INVALID_ARGUMENT.withDescription("Query not found").toRuntimeException()); return; } + final Schema parameterSchema = parameterSchemas.get(query); + if (parameterSchema != null) { + final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(outputStream)), parameterSchema); + resultBuilder.setParameterSchema(ByteString.copyFrom(outputStream.toByteArray())); + } + listener.onNext(new Result(pack(resultBuilder.build()).toByteArray())); } catch (final Throwable t) { listener.onError(t); @@ -336,6 +353,45 @@ public Runnable acceptPutPreparedStatementUpdate( final String query = Preconditions.checkNotNull( preparedStatements.get(handle), format("No query registered under handle: <%s>.", handle)); + final List> expectedValues = expectedParameterValues.get(query); + if (expectedValues != null) { + int index = 0; + while (flightStream.next()) { + final VectorSchemaRoot root = flightStream.getRoot(); + for (int i = 0; i < root.getRowCount(); i++) { + if (index >= expectedValues.size()) { + streamListener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("More parameter rows provided than expected") + .toRuntimeException()); + return () -> { }; + } + List expectedRow = expectedValues.get(index++); + if (root.getFieldVectors().size() != expectedRow.size()) { + streamListener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Parameter count mismatch") + .toRuntimeException()); + return () -> { }; + } + + for (int paramIndex = 0; paramIndex < expectedRow.size(); paramIndex++) { + Object expected = expectedRow.get(paramIndex); + Object actual = root.getVector(paramIndex).getObject(i); + if (!Objects.equals(expected, actual)) { + streamListener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Parameter mismatch. Expected: %s Actual: %s", expected, actual)) + .toRuntimeException()); + return () -> { }; + } + } + } + } + if (index < expectedValues.size()) { + streamListener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Fewer parameter rows provided than expected") + .toRuntimeException()); + return () -> { }; + } + } return acceptPutStatement( CommandStatementUpdate.newBuilder().setQuery(query).build(), callContext, flightStream, streamListener); diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinderTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinderTest.java new file mode 100644 index 00000000000..693bf9d31ec --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/TypedValueBinderTest.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +class TypedValueBinderTest { + +}