diff --git a/sdks/go/pkg/beam/provision/provision_test.go b/sdks/go/pkg/beam/provision/provision_test.go index f29bc9b3be51..92dc315fc062 100644 --- a/sdks/go/pkg/beam/provision/provision_test.go +++ b/sdks/go/pkg/beam/provision/provision_test.go @@ -16,8 +16,16 @@ package provision import ( + "context" + "fmt" + "log" + "net" "reflect" + "sync" "testing" + + fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1" + "google.golang.org/grpc" ) type s struct { @@ -52,3 +60,51 @@ func TestConversions(t *testing.T) { } } } + +type ProvisionServiceServicer struct { + fnpb.UnimplementedProvisionServiceServer +} + +func (p ProvisionServiceServicer) GetProvisionInfo(ctx context.Context, req *fnpb.GetProvisionInfoRequest) (*fnpb.GetProvisionInfoResponse, error) { + return &fnpb.GetProvisionInfoResponse{Info: &fnpb.ProvisionInfo{RetrievalToken: "token"}}, nil +} + +func setup(addr *string, wg *sync.WaitGroup) { + l, err := net.Listen("tcp", ":0") + defer l.Close() + if err != nil { + log.Fatalf("failed to find an open port: %v", err) + } + port := l.Addr().(*net.TCPAddr).Port + *addr = fmt.Sprintf(":%d", port) + + server := grpc.NewServer() + defer server.Stop() + + prs := &ProvisionServiceServicer{} + fnpb.RegisterProvisionServiceServer(server, prs) + + wg.Done() + + if err := server.Serve(l); err != nil { + log.Fatalf("cannot serve the server: %v", err) + } +} + +func TestProvisionInfo(t *testing.T) { + + endpoint := "" + var wg sync.WaitGroup + wg.Add(1) + go setup(&endpoint, &wg) + wg.Wait() + + got, err := Info(context.Background(), endpoint) + if err != nil { + t.Errorf("error in response: %v", err) + } + want := &fnpb.ProvisionInfo{RetrievalToken: "token"} + if got.GetRetrievalToken() != want.GetRetrievalToken() { + t.Errorf("provision.Info() = %v, want %v", got, want) + } +}