Skip to content

Commit df64147

Browse files
authored
fix(storage): optimize gRPC writer with zero-copy and lazy allocation (#13481)
1 parent 9c80b8b commit df64147

File tree

2 files changed

+259
-19
lines changed

2 files changed

+259
-19
lines changed

storage/grpc_writer.go

Lines changed: 86 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"net/http"
2424
"net/url"
2525
"strings"
26+
"sync"
2627
"time"
2728

2829
gapic "cloud.google.com/go/storage/internal/apiv2"
@@ -185,7 +186,8 @@ func (c *grpcStorageClient) OpenWriter(params *openWriterParams, opts ...storage
185186
appendGen: params.appendGen,
186187
finalizeOnClose: params.finalizeOnClose,
187188

188-
buf: make([]byte, 0, chunkSize),
189+
buf: nil, // Allocated lazily on first buffered write.
190+
chunkSize: chunkSize,
189191
writeQuantum: writeQuantum,
190192
lastSegmentStart: lastSegmentStart,
191193
sendableUnits: sendableUnits,
@@ -267,7 +269,8 @@ type gRPCWriter struct {
267269
appendGen int64
268270
finalizeOnClose bool
269271

270-
buf []byte
272+
buf []byte
273+
chunkSize int
271274
// A writeQuantum is the largest quantity of data which can be sent to the
272275
// service in a single message.
273276
writeQuantum int
@@ -384,21 +387,26 @@ func (w *gRPCWriter) gatherFirstBuffer() error {
384387
for cmd := range w.writesChan {
385388
switch v := cmd.(type) {
386389
case *gRPCWriterCommandWrite:
387-
if len(w.buf)+len(v.p) <= cap(w.buf) {
388-
// We have not started sending yet, and we can stage all data without
389-
// starting a send. Compare against cap(w.buf) instead of
390-
// w.writeQuantum: that way we can perform a oneshot upload for objects
391-
// which fit in one chunk, even though we will cut the request into
392-
// w.writeQuantum units when we do start sending.
393-
origLen := len(w.buf)
394-
w.buf = w.buf[:origLen+len(v.p)]
395-
copy(w.buf[origLen:], v.p)
396-
close(v.done)
397-
} else {
398-
// Too large. Handle it in writeLoop.
390+
// If zero-copy one-shot is requested, OR the payload is larger than the buffer,
391+
// bypass buffering entirely and hand off to the writeLoop immediately.
392+
if w.forceOneShot || len(w.buf)+len(v.p) > w.chunkSize {
399393
w.currentCommand = cmd
400394
return nil
401395
}
396+
397+
// Otherwise, lazily allocate and stage the small write (normal buffered path)
398+
if w.buf == nil {
399+
w.buf = make([]byte, 0, w.chunkSize)
400+
}
401+
// We have not started sending yet, and we can stage all data without
402+
// starting a send. Compare against w.chunkSize instead of
403+
// w.writeQuantum: that way we can perform a oneshot upload for objects
404+
// which fit in one chunk, even though we will cut the request into
405+
// w.writeQuantum units when we do start sending.
406+
origLen := len(w.buf)
407+
w.buf = w.buf[:origLen+len(v.p)]
408+
copy(w.buf[origLen:], v.p)
409+
close(v.done)
402410
break
403411
case *gRPCWriterCommandClose:
404412
// If we get here, data (if any) fits in w.buf, so we can force oneshot.
@@ -568,17 +576,33 @@ type gRPCWriterCommand interface {
568576
}
569577

570578
// gRPCWriterCommandWrite carries a single user Write call into the writeLoop.
type gRPCWriterCommandWrite struct {
	// p is the caller-provided payload. On the zero-copy path it is handed
	// to the stream without copying, so the caller must not reuse it until
	// done is closed.
	p []byte
	// done is closed exactly once when the payload has been staged or fully
	// dispatched and the caller may reuse p.
	done chan struct{}
	// NOTE(review): initialOffset and hasStarted are not read within this
	// diff — confirm their consumers elsewhere in the file.
	initialOffset int64
	hasStarted    bool
	// closeOnce guards done against a double close, since multiple code
	// paths (buffered, zero-copy, empty write) can mark the command done.
	closeOnce sync.Once
}
574585

575586
func (c *gRPCWriterCommandWrite) handle(w *gRPCWriter, cs gRPCWriterCommandHandleChans) error {
576587
if len(c.p) == 0 {
577588
// No data to write.
578-
close(c.done)
589+
c.markDone()
590+
return nil
591+
}
592+
593+
// Zero-Copy send.
594+
if w.forceOneShot {
595+
err := c.zeroCopyWrite(w, cs)
596+
if err != nil {
597+
return err
598+
}
599+
// If zeroCopyWrite returns without error, the write is done.
579600
return nil
580601
}
581602

603+
if w.buf == nil {
604+
w.buf = make([]byte, 0, w.chunkSize)
605+
}
582606
wblen := len(w.buf)
583607
allKnownBytes := wblen + len(c.p)
584608
fullBufs := allKnownBytes / cap(w.buf)
@@ -605,7 +629,7 @@ func (c *gRPCWriterCommandWrite) handle(w *gRPCWriter, cs gRPCWriterCommandHandl
605629
return w.streamSender.err()
606630
}
607631
w.bufUnsentIdx = int(sentOffset - w.bufBaseOffset)
608-
close(c.done)
632+
c.markDone()
609633
return nil
610634
}
611635

@@ -698,10 +722,53 @@ func (c *gRPCWriterCommandWrite) handle(w *gRPCWriter, cs gRPCWriterCommandHandl
698722
w.buf = w.buf[:len(toCopyIn)]
699723
copy(w.buf, toCopyIn)
700724
w.bufUnsentIdx = int(sentOffset - w.bufBaseOffset)
701-
close(c.done)
725+
c.markDone()
726+
return nil
727+
}
728+
729+
func (c *gRPCWriterCommandWrite) zeroCopyWrite(w *gRPCWriter, cs gRPCWriterCommandHandleChans) error {
730+
// Pre-emptively get the context channel to avoid closure overhead in the loop.
731+
ctxDone := w.preRunCtx.Done()
732+
733+
// sendBufferToTarget handles the quantum breakdown.
734+
newOffset, ok := w.sendBufferToTarget(cs, c.p, w.bufBaseOffset, len(c.p), w.handleCompletion)
735+
if !ok {
736+
return w.streamSender.err()
737+
}
738+
739+
// Request an ack from the sender goroutine to ensure the buffer has been
740+
// dispatched to gRPC and is safe for the user to reuse.
741+
if !cs.deliverRequestUnlessCompleted(gRPCBidiWriteRequest{requestAck: true}, w.handleCompletion) {
742+
return w.streamSender.err()
743+
}
744+
745+
ackOutstanding := true
746+
747+
// Wait for server acknowledgement and sender transmissions to enable incremental progress.
748+
for ackOutstanding || w.bufBaseOffset < newOffset {
749+
select {
750+
case completion, ok := <-cs.completions:
751+
if !ok {
752+
return w.streamSender.err()
753+
}
754+
w.handleCompletion(completion)
755+
case <-cs.requestAcks:
756+
ackOutstanding = false
757+
case <-ctxDone:
758+
return w.preRunCtx.Err()
759+
}
760+
}
761+
762+
c.p = nil
763+
c.markDone()
702764
return nil
703765
}
704766

767+
// Helper to ensure we don't close done twice and keep the main logic clean.
768+
func (c *gRPCWriterCommandWrite) markDone() {
769+
c.closeOnce.Do(func() { close(c.done) })
770+
}
771+
705772
// gRPCWriterCommandFlush asks the writeLoop to flush buffered data; the
// flushed offset is delivered back to the caller on done.
type gRPCWriterCommandFlush struct {
	done chan int64
}

storage/grpc_writer_test.go

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515
package storage
1616

1717
import (
18+
"context"
19+
"sync"
1820
"testing"
1921

2022
"cloud.google.com/go/storage/internal/apiv2/storagepb"
23+
gax "github.com/googleapis/gax-go/v2"
2124
"google.golang.org/protobuf/proto"
2225
)
2326

@@ -121,3 +124,173 @@ func TestGetObjectChecksums(t *testing.T) {
121124
})
122125
}
123126
}
127+
func TestGRPCWriter_MemoryAllocationPaths(t *testing.T) {
128+
tests := []struct {
129+
name string
130+
chunkSize int
131+
dataSize int
132+
forceOneShot bool
133+
wantZeroCopy bool
134+
}{
135+
{
136+
name: "OneShot_ZeroCopy_1MB",
137+
chunkSize: 0,
138+
dataSize: 1 * 1024 * 1024, // 1 MiB
139+
forceOneShot: true,
140+
wantZeroCopy: true,
141+
},
142+
{
143+
name: "OneShot_ZeroCopy_10MB",
144+
chunkSize: 0,
145+
dataSize: 10 * 1024 * 1024, // 10 MiB
146+
forceOneShot: true,
147+
wantZeroCopy: true,
148+
},
149+
{
150+
name: "Resumable_Buffering",
151+
chunkSize: 2 * 1024 * 1024,
152+
dataSize: 1 * 1024 * 1024, // 1 MiB
153+
forceOneShot: false,
154+
wantZeroCopy: false,
155+
},
156+
{
157+
name: "Resumable_ZeroCopy",
158+
chunkSize: 1 * 1024 * 1024, // 1 MiB
159+
dataSize: 2 * 1024 * 1024, // 2 MiB
160+
forceOneShot: false,
161+
wantZeroCopy: true,
162+
},
163+
{
164+
name: "Resumable_Hybrid",
165+
chunkSize: 2 * 1024 * 1024, // 2 MiB
166+
dataSize: 3 * 1024 * 1024, // 3 MiB
167+
forceOneShot: false,
168+
wantZeroCopy: false,
169+
},
170+
}
171+
172+
for _, tt := range tests {
173+
t.Run(tt.name, func(t *testing.T) {
174+
data := make([]byte, tt.dataSize)
175+
data[0] = 1
176+
data[tt.dataSize-1] = 2
177+
chunkSize := gRPCChunkSize(tt.chunkSize)
178+
mockSender := &mockZeroCopySender{}
179+
w := &gRPCWriter{
180+
buf: nil, // Allocated lazily on first buffered write.
181+
chunkSize: chunkSize,
182+
forceOneShot: tt.forceOneShot,
183+
writeQuantum: maxPerMessageWriteSize,
184+
preRunCtx: context.Background(),
185+
sendableUnits: 10,
186+
writesChan: make(chan gRPCWriterCommand, 1),
187+
donec: make(chan struct{}),
188+
streamSender: mockSender,
189+
settings: &settings{},
190+
}
191+
w.progress = func(int64) {}
192+
w.setObj = func(*ObjectAttrs) {}
193+
w.setSize = func(int64) {}
194+
195+
go func() {
196+
w.writeLoop(context.Background())
197+
close(w.donec)
198+
}()
199+
200+
if _, err := w.Write(data); err != nil {
201+
t.Fatalf("Write failed: %v", err)
202+
}
203+
if err := w.Close(); err != nil {
204+
t.Fatalf("Close failed: %v", err)
205+
}
206+
mockSender.wg.Wait()
207+
208+
mockSender.mu.Lock()
209+
defer mockSender.mu.Unlock()
210+
211+
reqs := filterDataRequests(mockSender.requests)
212+
if len(reqs) == 0 {
213+
t.Fatalf("Expected at least 1 data request, got 0")
214+
}
215+
216+
// Verify memory address logic:
217+
// The last byte of the last request buffer should match the last byte of the input data for zero-copy.
218+
// For buffering/copying, the pointers must differ.
219+
idx := len(reqs) - 1
220+
bufIdx := len(reqs[idx].buf)
221+
isZeroCopy := &reqs[idx].buf[bufIdx-1] == &data[tt.dataSize-1]
222+
if isZeroCopy != tt.wantZeroCopy {
223+
if tt.wantZeroCopy && tt.forceOneShot {
224+
t.Errorf("One-shot upload bypassed zero-copy path; data was unexpectedly copied")
225+
} else if !tt.wantZeroCopy && !tt.forceOneShot {
226+
t.Errorf("Resumable upload bypassed buffering path; data was unexpectedly zero-copied")
227+
} else if tt.wantZeroCopy && !tt.forceOneShot {
228+
t.Errorf("Resumable upload bypassed zero-copy path; data was unexpectedly copied")
229+
}
230+
}
231+
})
232+
}
233+
}
234+
235+
type mockZeroCopySender struct {
236+
mu sync.Mutex
237+
requests []gRPCBidiWriteRequest
238+
errResult error
239+
wg sync.WaitGroup // Waits for all async operations to complete.
240+
}
241+
242+
func (m *mockZeroCopySender) connect(ctx context.Context, cs gRPCBufSenderChans, opts ...gax.CallOption) {
243+
m.wg.Add(1)
244+
go func() {
245+
defer m.wg.Done()
246+
247+
// Track active flush goroutines to prevent closing the channel prematurely.
248+
var completionWg sync.WaitGroup
249+
250+
defer func() {
251+
completionWg.Wait()
252+
close(cs.completions)
253+
}()
254+
255+
for req := range cs.requests {
256+
m.mu.Lock()
257+
m.requests = append(m.requests, req)
258+
m.mu.Unlock()
259+
260+
if req.requestAck {
261+
select {
262+
case cs.requestAcks <- struct{}{}:
263+
case <-ctx.Done():
264+
return
265+
}
266+
}
267+
268+
if req.flush {
269+
completionWg.Add(1)
270+
// Send completions asynchronously to avoid blocking the request loop.
271+
go func(offset int64) {
272+
defer completionWg.Done()
273+
select {
274+
case cs.completions <- gRPCBidiWriteCompletion{
275+
flushOffset: offset,
276+
}:
277+
case <-ctx.Done():
278+
}
279+
}(req.offset + int64(len(req.buf)))
280+
}
281+
}
282+
}()
283+
}
284+
285+
// err reports the error configured on the mock sender.
func (m *mockZeroCopySender) err() error {
	return m.errResult
}
286+
287+
// filterDataRequests returns only requests containing data, ignoring protocol overhead.
288+
func filterDataRequests(reqs []gRPCBidiWriteRequest) []gRPCBidiWriteRequest {
289+
var dataReqs []gRPCBidiWriteRequest
290+
for _, r := range reqs {
291+
if len(r.buf) > 0 {
292+
dataReqs = append(dataReqs, r)
293+
}
294+
}
295+
return dataReqs
296+
}

0 commit comments

Comments
 (0)