From d4fafc9c938ca0eb303f38d8461c02edd111abd4 Mon Sep 17 00:00:00 2001
From: David Li
Date: Mon, 29 Apr 2024 02:22:29 -0400
Subject: [PATCH] GH-41427: [Go] Fix stateless prepared statements
---
go/arrow/flight/flightsql/client.go | 93 ++++++++++--------------
go/arrow/flight/flightsql/client_test.go | 10 +--
2 files changed, 45 insertions(+), 58 deletions(-)
diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go
index e594191c35fd..c6794820dc17 100644
--- a/go/arrow/flight/flightsql/client.go
+++ b/go/arrow/flight/flightsql/client.go
@@ -1119,24 +1119,10 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
return nil, err
}
- if p.hasBindParameters() {
- pstream, err := p.client.Client.DoPut(ctx, opts...)
- if err != nil {
- return nil, err
- }
- wr, err := p.writeBindParameters(pstream, desc)
- if err != nil {
- return nil, err
- }
- if err = wr.Close(); err != nil {
- return nil, err
- }
- pstream.CloseSend()
- if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
- return nil, err
- }
+ desc, err = p.bindParameters(ctx, desc, opts...)
+ if err != nil {
+ return nil, err
}
-
return p.client.getFlightInfo(ctx, desc, opts...)
}
@@ -1156,23 +1142,9 @@ func (p *PreparedStatement) ExecutePut(ctx context.Context, opts ...grpc.CallOpt
return err
}
- if p.hasBindParameters() {
- pstream, err := p.client.Client.DoPut(ctx, opts...)
- if err != nil {
- return err
- }
-
- wr, err := p.writeBindParameters(pstream, desc)
- if err != nil {
- return err
- }
- if err = wr.Close(); err != nil {
- return err
- }
- pstream.CloseSend()
- if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
- return err
- }
+ _, err = p.bindParameters(ctx, desc, opts...)
+ if err != nil {
+ return err
}
return nil
@@ -1200,23 +1172,9 @@ func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor *fl
}
if retryDescriptor == nil {
- if p.hasBindParameters() {
- pstream, err := p.client.Client.DoPut(ctx, opts...)
- if err != nil {
- return nil, err
- }
-
- wr, err := p.writeBindParameters(pstream, desc)
- if err != nil {
- return nil, err
- }
- if err = wr.Close(); err != nil {
- return nil, err
- }
- pstream.CloseSend()
- if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
- return nil, err
- }
+ desc, err = p.bindParameters(ctx, desc, opts...)
+ if err != nil {
+ return nil, err
}
}
return p.client.Client.PollFlightInfo(ctx, desc, opts...)
@@ -1248,7 +1206,7 @@ func (p *PreparedStatement) ExecuteUpdate(ctx context.Context, opts ...grpc.Call
return
}
if p.hasBindParameters() {
- wr, err = p.writeBindParameters(pstream, desc)
+ wr, err = p.writeBindParametersToStream(pstream, desc)
if err != nil {
return
}
@@ -1283,7 +1241,36 @@ func (p *PreparedStatement) hasBindParameters() bool {
return (p.paramBinding != nil && p.paramBinding.NumRows() > 0) || (p.streamBinding != nil)
}
-func (p *PreparedStatement) writeBindParameters(pstream pb.FlightService_DoPutClient, desc *pb.FlightDescriptor) (*flight.Writer, error) {
+func (p *PreparedStatement) bindParameters(ctx context.Context, desc *pb.FlightDescriptor, opts ...grpc.CallOption) (*flight.FlightDescriptor, error) {
+ if p.hasBindParameters() {
+ pstream, err := p.client.Client.DoPut(ctx, opts...)
+ if err != nil {
+ return nil, err
+ }
+ wr, err := p.writeBindParametersToStream(pstream, desc)
+ if err != nil {
+ return nil, err
+ }
+ if err = wr.Close(); err != nil {
+ return nil, err
+ }
+ pstream.CloseSend()
+ if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil {
+ return nil, err
+ }
+
+ cmd := pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle}
+ desc, err = descForCommand(&cmd)
+ if err != nil {
+ return nil, err
+ }
+ return desc, nil
+ }
+ return desc, nil
+}
+
+// XXX: this does not capture the updated handle. Prefer bindParameters.
+func (p *PreparedStatement) writeBindParametersToStream(pstream pb.FlightService_DoPutClient, desc *pb.FlightDescriptor) (*flight.Writer, error) {
if p.paramBinding != nil {
wr := flight.NewRecordWriter(pstream, ipc.WithSchema(p.paramBinding.Schema()))
wr.SetFlightDescriptor(desc)
diff --git a/go/arrow/flight/flightsql/client_test.go b/go/arrow/flight/flightsql/client_test.go
index 727fe02aa706..33da79167c4a 100644
--- a/go/arrow/flight/flightsql/client_test.go
+++ b/go/arrow/flight/flightsql/client_test.go
@@ -448,9 +448,9 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)})
// mocked DoPut result
- doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)}
+ doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)}
resdata, _ := proto.Marshal(doPutPreparedStatementResult)
- putResult := &pb.PutResult{ AppMetadata: resdata }
+ putResult := &pb.PutResult{AppMetadata: resdata}
// mocked client stream for DoPut
mockedPut := &mockDoPutClient{}
@@ -461,7 +461,7 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() {
mockedPut.On("CloseSend").Return(nil)
mockedPut.On("Recv").Return(putResult, nil)
- infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)}
+ infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(updatedHandle)}
desc := getDesc(infoCmd)
s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil)
@@ -525,9 +525,9 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() {
expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)})
// mocked DoPut result
- doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)}
+ doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)}
resdata, _ := proto.Marshal(doPutPreparedStatementResult)
- putResult := &pb.PutResult{ AppMetadata: resdata }
+ putResult := &pb.PutResult{AppMetadata: resdata}
// mocked client stream for DoPut
mockedPut := &mockDoPutClient{}