Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions msgio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@ import (
"fmt"
"io"
"math/rand"
str "strings"
"strings"
"sync"
"testing"
"time"

"github.com/libp2p/go-msgio/pbio/pb"
//lint:ignore SA1019 We are testing better errors when using the deprecated protoio package
"github.com/libp2p/go-msgio/protoio"
"github.com/multiformats/go-varint"
"google.golang.org/protobuf/proto"
)

func randBuf(r *rand.Rand, size int) []byte {
Expand Down Expand Up @@ -79,7 +85,7 @@ func TestMultiError(t *testing.T) {
}

twoErrors := multiErr([]error{errors.New("one"), errors.New("two")})
if eStr := twoErrors.Error(); !str.Contains(eStr, "one") && !str.Contains(eStr, "two") {
if eStr := twoErrors.Error(); !strings.Contains(eStr, "one") && !strings.Contains(eStr, "two") {
t.Fatal("Expected error messages not included")
}
}
Expand Down Expand Up @@ -328,3 +334,45 @@ func SubtestReadShortBuffer(t *testing.T, writer WriteCloser, reader ReadCloser)
t.Fatal("Expected short buffer error")
}
}

func TestHandleProtoGeneratedByGoogleProtobufInProtoio(t *testing.T) {
record := &pb.TestRecord{
Uint32: 42,
Uint64: 84,
Bytes: []byte("test bytes"),
String_: "test string",
Int32: -42,
Int64: -84,
}

recordBytes, err := proto.Marshal(record)
if err != nil {
t.Fatal(err)
}

for _, tc := range []string{"read", "write"} {
t.Run(tc, func(t *testing.T) {
var buf bytes.Buffer
readRecord := &pb.TestRecord{}
switch tc {
case "read":
buf.Write(varint.ToUvarint(uint64(len(recordBytes))))
buf.Write(recordBytes)

reader := protoio.NewDelimitedReader(&buf, 1024)
defer reader.Close()
err = reader.ReadMsg(readRecord)
case "write":
writer := protoio.NewDelimitedWriter(&buf)
err = writer.WriteMsg(record)
}
if err == nil {
t.Fatal("expected error")
}
expectedError := "google Protobuf message passed into a GoGo Protobuf"
if !strings.Contains(err.Error(), expectedError) {
t.Fatalf("expected error to contain '%s'", expectedError)
}
})
}
}
15 changes: 15 additions & 0 deletions protoio/isgoog.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package protoio

import (
"github.com/gogo/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

// isGoogleProtobufMsg checks if the given proto.Message was
// generated by the official Google protobuf compiler
func isGoogleProtobufMsg(msg proto.Message) bool {
_, ok := msg.(interface {
ProtoReflect() protoreflect.Message
})
return ok
}
19 changes: 18 additions & 1 deletion protoio/uvarint_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ package protoio

import (
"bufio"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -82,7 +83,23 @@ func (ur *uvarintReader) ReadMsg(msg proto.Message) (err error) {
if _, err := io.ReadFull(ur.r, buf); err != nil {
return err
}
return proto.Unmarshal(buf, msg)

// Hoist up gogo's proto.Unmarshal logic so we can also check if this is a google protobuf message
msg.Reset()
if u, ok := msg.(interface {
XXX_Unmarshal([]byte) error
}); ok {
return u.XXX_Unmarshal(buf)
} else if u, ok := msg.(interface {
Unmarshal([]byte) error
}); ok {
return u.Unmarshal(buf)
} else if isGoogleProtobufMsg(msg) {
return errors.New("google Protobuf message passed into a GoGo Protobuf reader. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto")
}

// Fallback to GoGo's proto.Unmarshal around this buffer
return proto.NewBuffer(buf).Unmarshal(msg)
}

func (ur *uvarintReader) Close() error {
Expand Down
5 changes: 5 additions & 0 deletions protoio/uvarint_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
package protoio

import (
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -80,6 +81,10 @@ func (uw *uvarintWriter) WriteMsg(msg proto.Message) (err error) {
}
}

if isGoogleProtobufMsg(msg) {
return errors.New("google Protobuf message passed into a GoGo Protobuf writer. Use github.com/libp2p/go-msgio/pbio instead of github.com/gogo/protobuf/proto")
}

// fallback
data, err = proto.Marshal(msg)
if err != nil {
Expand Down