diff --git a/descriptors.go b/descriptors.go index 1c101af..04e8d44 100644 --- a/descriptors.go +++ b/descriptors.go @@ -19,9 +19,10 @@ type descriptorSet struct { seen map[string]struct{} ignoreFiles map[string]struct{} descProto string + includeDir string } -func newDescriptorSet(ignoreFiles []string, d string) *descriptorSet { +func newDescriptorSet(ignoreFiles []string, d string, i string) *descriptorSet { ifm := make(map[string]struct{}, len(ignoreFiles)) for _, ignore := range ignoreFiles { ifm[ignore] = struct{}{} @@ -30,6 +31,7 @@ func newDescriptorSet(ignoreFiles []string, d string) *descriptorSet { seen: make(map[string]struct{}), ignoreFiles: ifm, descProto: d, + includeDir: i, } } @@ -56,7 +58,7 @@ func (d *descriptorSet) add(descs ...*descriptor.FileDescriptorProto) { // // This is equivalent to the following command: // -// cat merged.pb | protoc --decode google.protobuf.FileDescriptorSet /path/to/google/protobuf/descriptor.proto +// cat merged.pb | protoc -I /path/to --decode google.protobuf.FileDescriptorSet /path/to/google/protobuf/descriptor.proto func (d *descriptorSet) marshalTo(w io.Writer) error { p, err := proto.Marshal(&d.merged) if err != nil { @@ -65,6 +67,8 @@ func (d *descriptorSet) marshalTo(w io.Writer) error { args := []string{ "protoc", + "-I", + d.includeDir, "--decode", "google.protobuf.FileDescriptorSet", d.descProto, diff --git a/main.go b/main.go index e75875d..a6f4ddd 100644 --- a/main.go +++ b/main.go @@ -70,7 +70,7 @@ func main() { } } - descProto, err := descriptorProto(c.Includes.After) + descProto, includeDir, err := descriptorProto(c.Includes.After) if err != nil { log.Fatalln(err) } @@ -78,7 +78,7 @@ func main() { // Aggregate descriptors for each descriptor prefix. descriptorSets := map[string]*descriptorSet{} for _, stable := range c.Descriptors { - descriptorSets[stable.Prefix] = newDescriptorSet(stable.IgnoreFiles, descProto) + descriptorSets[stable.Prefix] = newDescriptorSet(stable.IgnoreFiles, descProto, includeDir) } shouldGenerateDescriptors := func(p string) bool { @@ -342,17 +342,17 @@ func gopathJoin(gopath, element string) string { // descriptorProto returns the full path to google/protobuf/descriptor.proto // which might be different depending on whether it was installed. The argument // is the list of paths to check. -func descriptorProto(paths []string) (string, error) { +func descriptorProto(paths []string) (string, string, error) { const descProto = "google/protobuf/descriptor.proto" for _, dir := range paths { file := path.Join(dir, descProto) if _, err := os.Stat(file); err == nil { - return file, err + return file, dir, err } } - return "", fmt.Errorf("File %q not found (looked in: %v)", descProto, paths) + return "", "", fmt.Errorf("File %q not found (looked in: %v)", descProto, paths) } var errVendorNotFound = fmt.Errorf("no vendor dir found")