From d4fafc9c938ca0eb303f38d8461c02edd111abd4 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 29 Apr 2024 02:22:29 -0400 Subject: [PATCH] GH-41427: [Go] Fix stateless prepared statements --- go/arrow/flight/flightsql/client.go | 93 ++++++++++-------------- go/arrow/flight/flightsql/client_test.go | 10 +-- 2 files changed, 45 insertions(+), 58 deletions(-) diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go index e594191c35fd..c6794820dc17 100644 --- a/go/arrow/flight/flightsql/client.go +++ b/go/arrow/flight/flightsql/client.go @@ -1119,24 +1119,10 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption return nil, err } - if p.hasBindParameters() { - pstream, err := p.client.Client.DoPut(ctx, opts...) - if err != nil { - return nil, err - } - wr, err := p.writeBindParameters(pstream, desc) - if err != nil { - return nil, err - } - if err = wr.Close(); err != nil { - return nil, err - } - pstream.CloseSend() - if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil { - return nil, err - } + desc, err = p.bindParameters(ctx, desc, opts...) + if err != nil { + return nil, err } - return p.client.getFlightInfo(ctx, desc, opts...) } @@ -1156,23 +1142,9 @@ func (p *PreparedStatement) ExecutePut(ctx context.Context, opts ...grpc.CallOpt return err } - if p.hasBindParameters() { - pstream, err := p.client.Client.DoPut(ctx, opts...) - if err != nil { - return err - } - - wr, err := p.writeBindParameters(pstream, desc) - if err != nil { - return err - } - if err = wr.Close(); err != nil { - return err - } - pstream.CloseSend() - if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil { - return err - } + _, err = p.bindParameters(ctx, desc, opts...) + if err != nil { + return err } return nil @@ -1200,23 +1172,9 @@ func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor *fl } if retryDescriptor == nil { - if p.hasBindParameters() { - pstream, err := p.client.Client.DoPut(ctx, opts...) - if err != nil { - return nil, err - } - - wr, err := p.writeBindParameters(pstream, desc) - if err != nil { - return nil, err - } - if err = wr.Close(); err != nil { - return nil, err - } - pstream.CloseSend() - if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil { - return nil, err - } + desc, err = p.bindParameters(ctx, desc, opts...) + if err != nil { + return nil, err } } return p.client.Client.PollFlightInfo(ctx, desc, opts...) @@ -1248,7 +1206,7 @@ func (p *PreparedStatement) ExecuteUpdate(ctx context.Context, opts ...grpc.Call return } if p.hasBindParameters() { - wr, err = p.writeBindParameters(pstream, desc) + wr, err = p.writeBindParametersToStream(pstream, desc) if err != nil { return } @@ -1283,7 +1241,36 @@ func (p *PreparedStatement) hasBindParameters() bool { return (p.paramBinding != nil && p.paramBinding.NumRows() > 0) || (p.streamBinding != nil) } -func (p *PreparedStatement) writeBindParameters(pstream pb.FlightService_DoPutClient, desc *pb.FlightDescriptor) (*flight.Writer, error) { +func (p *PreparedStatement) bindParameters(ctx context.Context, desc *pb.FlightDescriptor, opts ...grpc.CallOption) (*flight.FlightDescriptor, error) { + if p.hasBindParameters() { + pstream, err := p.client.Client.DoPut(ctx, opts...) + if err != nil { + return nil, err + } + wr, err := p.writeBindParametersToStream(pstream, desc) + if err != nil { + return nil, err + } + if err = wr.Close(); err != nil { + return nil, err + } + pstream.CloseSend() + if err = p.captureDoPutPreparedStatementHandle(pstream); err != nil { + return nil, err + } + + cmd := pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle} + desc, err = descForCommand(&cmd) + if err != nil { + return nil, err + } + return desc, nil + } + return desc, nil +} + +// XXX: this does not capture the updated handle. Prefer bindParameters. +func (p *PreparedStatement) writeBindParametersToStream(pstream pb.FlightService_DoPutClient, desc *pb.FlightDescriptor) (*flight.Writer, error) { if p.paramBinding != nil { wr := flight.NewRecordWriter(pstream, ipc.WithSchema(p.paramBinding.Schema())) wr.SetFlightDescriptor(desc) diff --git a/go/arrow/flight/flightsql/client_test.go b/go/arrow/flight/flightsql/client_test.go index 727fe02aa706..33da79167c4a 100644 --- a/go/arrow/flight/flightsql/client_test.go +++ b/go/arrow/flight/flightsql/client_test.go @@ -448,9 +448,9 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() { expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)}) // mocked DoPut result - doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)} + doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(updatedHandle)} resdata, _ := proto.Marshal(doPutPreparedStatementResult) - putResult := &pb.PutResult{ AppMetadata: resdata } + putResult := &pb.PutResult{AppMetadata: resdata} // mocked client stream for DoPut mockedPut := &mockDoPutClient{} @@ -461,7 +461,7 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() { mockedPut.On("CloseSend").Return(nil) mockedPut.On("Recv").Return(putResult, nil) - infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(handle)} + infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(updatedHandle)} desc := getDesc(infoCmd) s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil) @@ -525,9 +525,9 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecuteReaderBinding() { expectedDesc := getDesc(&pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)}) // mocked DoPut result - doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)} + doPutPreparedStatementResult := &pb.DoPutPreparedStatementResult{PreparedStatementHandle: []byte(query)} resdata, _ := proto.Marshal(doPutPreparedStatementResult) - putResult := &pb.PutResult{ AppMetadata: resdata } + putResult := &pb.PutResult{AppMetadata: resdata} // mocked client stream for DoPut mockedPut := &mockDoPutClient{}