Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions go/arrow/flight/flightsql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,6 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
if err != nil {
return nil, err
}

wr, err := p.writeBindParameters(pstream, desc)
if err != nil {
return nil, err
Expand All @@ -1133,9 +1132,7 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
return nil, err
}
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil && err != io.EOF {
if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1173,9 +1170,7 @@ func (p *PreparedStatement) ExecutePut(ctx context.Context, opts ...grpc.CallOpt
return err
}
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil && err != io.EOF {
if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
return err
}
}
Expand Down Expand Up @@ -1219,9 +1214,7 @@ func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor *fl
return nil, err
}
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil && err != io.EOF {
if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -1313,6 +1306,29 @@ func (p *PreparedStatement) writeBindParameters(pstream pb.FlightService_DoPutCl
}
}

func (p *PreparedStatement) captureDoPutPreparedStatementHandle(pstream pb.FlightService_DoPutClient) error {
var (
result *pb.PutResult
preparedStatementResult pb.DoPutPreparedStatementResult
err error
)
if result, err = pstream.Recv(); err != nil && err != io.EOF {
return err
}
// skip if server does not provide a response (legacy server)
if result == nil {
return nil
}
if err = proto.Unmarshal(result.GetAppMetadata(), &preparedStatementResult); err != nil {
return err
}
handle := preparedStatementResult.GetPreparedStatementHandle()
if handle != nil {
p.handle = handle
}
return nil
}

// DatasetSchema may be nil if the server did not return it when creating the
// Prepared Statement.
func (p *PreparedStatement) DatasetSchema() *arrow.Schema { return p.datasetSchema }
Expand Down
35 changes: 24 additions & 11 deletions go/arrow/flight/flightsql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,24 +408,26 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() {

func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
const query = "query"
const handle = "handle"
const updatedHandle = "updated handle"

// create and close actions
cmd := &pb.ActionCreatePreparedStatementRequest{Query: query}
action := getAction(cmd)
action.Type = flightsql.CreatePreparedStatementActionType
closeAct := getAction(&pb.ActionClosePreparedStatementRequest{PreparedStatementHandle: []byte(query)})
closeAct := getAction(&pb.ActionClosePreparedStatementRequest{PreparedStatementHandle: []byte(updatedHandle)})
closeAct.Type = flightsql.ClosePreparedStatementActionType

// results from createprepared statement
result := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(query),
actionResult := &pb.ActionCreatePreparedStatementResult{
PreparedStatementHandle: []byte(handle),
}
schema := arrow.NewSchema([]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int64, Nullable: true}}, nil)
result.ParameterSchema = flight.SerializeSchema(schema, memory.DefaultAllocator)
actionResult.ParameterSchema = flight.SerializeSchema(schema, memory.DefaultAllocator)

// mocked client stream
var out anypb.Any
out.MarshalFrom(result)
out.MarshalFrom(actionResult)
data, _ := proto.Marshal(&out)

createRsp := &mockDoActionClient{}
Expand All @@ -443,7 +445,12 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
s.mockClient.On("DoAction", flightsql.CreatePreparedStatementActionType, action.Body, s.callOpts).Return(createRsp, nil)
s.mockClient.On("DoAction", flightsql.ClosePreparedStatementActionType, closeAct.Body, s.callOpts).Return(closeRsp, nil)

expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)})
expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)})

// mocked DoPut result
doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)}
resdata, _ := proto.Marshal(doPutPreparedStatementResult)
putResult := &pb.PutResult{ AppMetadata: resdata }

// mocked client stream for DoPut
mockedPut := &mockDoPutClient{}
Expand All @@ -452,29 +459,30 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
return proto.Equal(expectedDesc, fd.FlightDescriptor)
})).Return(nil).Twice() // first sends schema message, second sends data
mockedPut.On("CloseSend").Return(nil)
mockedPut.On("Recv").Return((*pb.PutResult)(nil), nil)
mockedPut.On("Recv").Return(putResult, nil)

infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)}
infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)}
desc := getDesc(infoCmd)
s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil)

prepared, err := s.sqlClient.Prepare(context.TODO(), query, s.callOpts...)
s.NoError(err)
defer prepared.Close(context.TODO(), s.callOpts...)

s.Equal(string(prepared.Handle()), "query")
s.Equal(string(prepared.Handle()), handle)

paramSchema := prepared.ParameterSchema()
rec, _, err := array.RecordFromJSON(memory.DefaultAllocator, paramSchema, strings.NewReader(`[{"id": 1}]`))
s.NoError(err)
defer rec.Release()

s.Equal(string(prepared.Handle()), "query")
s.Equal(string(prepared.Handle()), handle)

prepared.SetParameters(rec)
info, err := prepared.Execute(context.TODO(), s.callOpts...)
s.NoError(err)
s.Equal(&emptyFlightInfo, info)
s.Equal(string(prepared.Handle()), updatedHandle)
}

func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {
Expand Down Expand Up @@ -516,6 +524,11 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {

expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)})

// mocked DoPut result
doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)}
resdata, _ := proto.Marshal(doPutPreparedStatementResult)
putResult := &pb.PutResult{ AppMetadata: resdata }

// mocked client stream for DoPut
mockedPut := &mockDoPutClient{}
s.mockClient.On("DoPut", s.callOpts).Return(mockedPut, nil)
Expand All @@ -528,7 +541,7 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {
return fd.FlightDescriptor == nil
})).Return(nil).Times(3)
mockedPut.On("CloseSend").Return(nil)
mockedPut.On("Recv").Return((*pb.PutResult)(nil), nil)
mockedPut.On("Recv").Return(putResult, nil)

infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)}
desc := getDesc(infoCmd)
Expand Down
10 changes: 5 additions & 5 deletions go/arrow/flight/flightsql/driver/driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1768,16 +1768,16 @@ func (s *MockServer) CreatePreparedStatement(ctx context.Context, req flightsql.
}, nil
}

func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) error {
func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flightsql.PreparedStatementQuery, r flight.MessageReader, w flight.MetadataWriter) ([]byte, error) {
if s.ExpectedPreparedStatementSchema != nil {
if !s.ExpectedPreparedStatementSchema.Equal(r.Schema()) {
return errors.New("parameter schema: unexpected")
return nil, errors.New("parameter schema: unexpected")
}
return nil
return qry.GetPreparedStatementHandle(), nil
}

if s.PreparedStatementParameterSchema != nil && !s.PreparedStatementParameterSchema.Equal(r.Schema()) {
return fmt.Errorf("parameter schema: %w", arrow.ErrInvalid)
return nil, fmt.Errorf("parameter schema: %w", arrow.ErrInvalid)
}

// GH-35328: it's rare, but this function can complete execution and return
Expand All @@ -1791,7 +1791,7 @@ func (s *MockServer) DoPutPreparedStatementQuery(ctx context.Context, qry flight
for r.Next() {
}

return nil
return qry.GetPreparedStatementHandle(), nil
}

func (s *MockServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (*arrow.Schema, <-chan flight.StreamChunk, error) {
Expand Down
8 changes: 4 additions & 4 deletions go/arrow/flight/flightsql/example/sqlite_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,21 +618,21 @@ func getParamsForStatement(rdr flight.MessageReader) (params [][]interface{}, er
return params, rdr.Err()
}

func (s *SQLiteFlightSQLServer) DoPutPreparedStatementQuery(_ context.Context, cmd flightsql.PreparedStatementQuery, rdr flight.MessageReader, _ flight.MetadataWriter) error {
func (s *SQLiteFlightSQLServer) DoPutPreparedStatementQuery(_ context.Context, cmd flightsql.PreparedStatementQuery, rdr flight.MessageReader, _ flight.MetadataWriter) ([]byte, error) {
val, ok := s.prepared.Load(string(cmd.GetPreparedStatementHandle()))
if !ok {
return status.Error(codes.InvalidArgument, "prepared statement not found")
return nil, status.Error(codes.InvalidArgument, "prepared statement not found")
}

stmt := val.(Statement)
args, err := getParamsForStatement(rdr)
if err != nil {
return status.Errorf(codes.Internal, "error gathering parameters for prepared statement query: %s", err.Error())
return nil, status.Errorf(codes.Internal, "error gathering parameters for prepared statement query: %s", err.Error())
}

stmt.params = args
s.prepared.Store(string(cmd.GetPreparedStatementHandle()), stmt)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a Go expert, but is this caching the prepared statement with parameters against the original handle?
Should a stateless implementation provide a new handle, presumably containing query and parameters?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is caching the prepared statement parameters against the original handle.

According to the updated spec, it is optional for the server to respond with an updated handle. In this example, we aren't leveraging the statelessness to embed the arguments and everything into the handle itself.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for clarifying @zeroshade . Looks like I have a grasp of Go then.
I think if possible it would be worth having an example of stateless use in this PR. That might be a bigger change since this server implementation is based on stateful behaviour.
I guess I'm curious how we would create a new handle based on query and parameters or if we need a new proto message from which we can generate a handle.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the handles are completely opaque to the protocol, it would depend entirely on how a given server chooses to represent its handle. The example server just uses a random character string as the handle which is a key into the map. But the handle could just as easily be a serialized protobuf message, etc. if desired.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I get that it's opaque so really doesn't matter here. I'll post a Java implementation PR soon with a stateless example which might be useful, but very much down to individuals what they might pick from it.

Copy link
Copy Markdown
Contributor Author

@erratic-pattern erratic-pattern Apr 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I decided an actual stateless implementation was out of scope for this change, and just made the minimal change to achieve compatibility with the new spec, which allows for either stateless or stateful server implementation. There is certainly an argument to be made for stateless implementation, but I leave that decision up to the maintainers.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The title of this PR is misworded then, so I will update it to reflect what the change is.

return nil
return cmd.GetPreparedStatementHandle(), nil
}

func (s *SQLiteFlightSQLServer) DoPutPreparedStatementUpdate(ctx context.Context, cmd flightsql.PreparedStatementUpdate, rdr flight.MessageReader) (int64, error) {
Expand Down
17 changes: 13 additions & 4 deletions go/arrow/flight/flightsql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,8 @@ func (BaseServer) DoPutCommandSubstraitPlan(context.Context, StatementSubstraitP
return 0, status.Error(codes.Unimplemented, "DoPutCommandSubstraitPlan not implemented")
}

func (BaseServer) DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) error {
return status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery not implemented")
func (BaseServer) DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) ([]byte, error) {
return nil, status.Error(codes.Unimplemented, "DoPutPreparedStatementQuery not implemented")
}

func (BaseServer) DoPutPreparedStatementUpdate(context.Context, PreparedStatementUpdate, flight.MessageReader) (int64, error) {
Expand Down Expand Up @@ -677,7 +677,7 @@ type Server interface {
// Currently anything written to the writer will be ignored. It is in the
// interface for potential future enhancements to avoid having to change
// the interface in the future.
DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) error
DoPutPreparedStatementQuery(context.Context, PreparedStatementQuery, flight.MessageReader, flight.MetadataWriter) ([]byte, error)
// DoPutPreparedStatementUpdate executes an update SQL Prepared statement
// for the specified statement handle. The reader allows providing a sequence
// of uploaded record batches to bind the parameters to. Returns the number
Expand Down Expand Up @@ -990,7 +990,16 @@ func (f *flightSqlServer) DoPut(stream flight.FlightService_DoPutServer) error {
}
return stream.Send(out)
case *pb.CommandPreparedStatementQuery:
return f.srv.DoPutPreparedStatementQuery(stream.Context(), cmd, rdr, &putMetadataWriter{stream})
handle, err := f.srv.DoPutPreparedStatementQuery(stream.Context(), cmd, rdr, &putMetadataWriter{stream})
if err != nil {
return err
}
result := pb.DoPutPreparedStatementResult{PreparedStatementHandle: handle}
out := &flight.PutResult{}
if out.AppMetadata, err = proto.Marshal(&result); err != nil {
return status.Errorf(codes.Internal, "failed to marshal PutResult: %s", err.Error())
}
return stream.Send(out)
case *pb.CommandPreparedStatementUpdate:
recordCount, err := f.srv.DoPutPreparedStatementUpdate(stream.Context(), cmd, rdr)
if err != nil {
Expand Down
Loading