diff --git a/descriptors.go b/descriptors.go index 1c101af..1ad784e 100644 --- a/descriptors.go +++ b/descriptors.go @@ -19,9 +19,10 @@ type descriptorSet struct { seen map[string]struct{} ignoreFiles map[string]struct{} descProto string + protoPath string } -func newDescriptorSet(ignoreFiles []string, d string) *descriptorSet { +func newDescriptorSet(ignoreFiles []string, d, p string) *descriptorSet { ifm := make(map[string]struct{}, len(ignoreFiles)) for _, ignore := range ignoreFiles { ifm[ignore] = struct{}{} @@ -30,6 +31,7 @@ func newDescriptorSet(ignoreFiles []string, d string) *descriptorSet { seen: make(map[string]struct{}), ignoreFiles: ifm, descProto: d, + protoPath: p, } } @@ -67,6 +69,8 @@ func (d *descriptorSet) marshalTo(w io.Writer) error { "protoc", "--decode", "google.protobuf.FileDescriptorSet", + "--proto_path", + d.protoPath, d.descProto, } diff --git a/main.go b/main.go index e75875d..a35803d 100644 --- a/main.go +++ b/main.go @@ -70,7 +70,7 @@ func main() { } } - descProto, err := descriptorProto(c.Includes.After) + descProto, protoPath, err := descriptorProto(c.Includes.After) if err != nil { log.Fatalln(err) } @@ -78,7 +78,7 @@ func main() { // Aggregate descriptors for each descriptor prefix. descriptorSets := map[string]*descriptorSet{} for _, stable := range c.Descriptors { - descriptorSets[stable.Prefix] = newDescriptorSet(stable.IgnoreFiles, descProto) + descriptorSets[stable.Prefix] = newDescriptorSet(stable.IgnoreFiles, descProto, protoPath) } shouldGenerateDescriptors := func(p string) bool { @@ -340,19 +340,19 @@ func gopathJoin(gopath, element string) string { } // descriptorProto returns the full path to google/protobuf/descriptor.proto -// which might be different depending on whether it was installed. The argument -// is the list of paths to check. -func descriptorProto(paths []string) (string, error) { +// (which might be different depending on whether it was installed), and +// the path where it was found. The argument is the list of paths to check. +func descriptorProto(paths []string) (string, string, error) { const descProto = "google/protobuf/descriptor.proto" for _, dir := range paths { file := path.Join(dir, descProto) if _, err := os.Stat(file); err == nil { - return file, err + return file, dir, err } } - return "", fmt.Errorf("File %q not found (looked in: %v)", descProto, paths) + return "", "", fmt.Errorf("File %q not found (looked in: %v)", descProto, paths) } var errVendorNotFound = fmt.Errorf("no vendor dir found")