diff --git a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go
index 4a872e291c6a..c71ead208364 100644
--- a/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go
+++ b/sdks/go/pkg/beam/core/runtime/metricsx/metricsx.go
@@ -24,6 +24,7 @@ import (
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/metrics"
 	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+	"golang.org/x/exp/slog"
 )
 
 // FromMonitoringInfos extracts metrics from monitored states and
@@ -139,7 +140,7 @@ func groupByType(p *pipepb.Pipeline, minfos []*pipepb.MonitoringInfo) (
 		}
 	}
 	if len(errs) > 0 {
-		log.Printf("Warning: %v errors during metrics processing: %v\n", len(errs), errs)
+		slog.Debug("errors during metrics processing", "count", len(errs), "errors", errs)
 	}
 	return counters, distributions, gauges, msecs, pcols
 }
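The hunk above swaps a formatted log.Printf for structured slog logging. A minimal, self-contained sketch of that key/value style — the errs slice is invented for illustration, and Info is used only so the default handler prints the record (the patch itself logs at Debug):

package main

import (
	"errors"

	"golang.org/x/exp/slog"
)

func main() {
	errs := []error{errors.New("bad payload"), errors.New("unknown urn")}
	// Key/value pairs replace the formatted string, so a handler can filter
	// on "count" or "errors" without parsing the message text.
	slog.Info("errors during metrics processing", "count", len(errs), "errors", errs)
}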
diff --git a/sdks/go/pkg/beam/runners/prism/internal/environments.go b/sdks/go/pkg/beam/runners/prism/internal/environments.go
new file mode 100644
index 000000000000..5830325bd054
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/environments.go
@@ -0,0 +1,199 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io"
+	"os"
+
+	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
+	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
+	"golang.org/x/exp/slog"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/credentials/insecure"
+	"google.golang.org/protobuf/proto"
+
+	dtyp "github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/container"
+	"github.com/docker/docker/api/types/mount"
+	dcli "github.com/docker/docker/client"
+	"github.com/docker/docker/pkg/stdcopy"
+)
+
+// TODO move environment handling to the worker package.
+
+func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *worker.W) error {
+	logger := slog.With(slog.String("envID", wk.Env))
+	// TODO fix broken abstraction.
+	// We're starting a worker pool here, because that's the loopback environment.
+	// It's sort of a mess, largely because of loopback, which has
+	// a different flow from a provisioned docker container.
+	e := j.Pipeline.GetComponents().GetEnvironments()[env]
+	switch e.GetUrn() {
+	case urns.EnvExternal:
+		ep := &pipepb.ExternalPayload{}
+		if err := (proto.UnmarshalOptions{}).Unmarshal(e.GetPayload(), ep); err != nil {
+			logger.Error("unmarshaling external environment payload", "error", err)
+		}
+		go func() {
+			externalEnvironment(ctx, ep, wk)
+			slog.Debug("environment stopped", slog.String("job", j.String()))
+		}()
+		return nil
+	case urns.EnvDocker:
+		dp := &pipepb.DockerPayload{}
+		if err := (proto.UnmarshalOptions{}).Unmarshal(e.GetPayload(), dp); err != nil {
+			logger.Error("unmarshaling docker environment payload", "error", err)
+		}
+		return dockerEnvironment(ctx, logger, dp, wk, j.ArtifactEndpoint())
+	default:
+		return fmt.Errorf("environment %v with urn %v unimplemented", env, e.GetUrn())
+	}
+}
+
+func externalEnvironment(ctx context.Context, ep *pipepb.ExternalPayload, wk *worker.W) {
+	conn, err := grpc.Dial(ep.GetEndpoint().GetUrl(), grpc.WithTransportCredentials(insecure.NewCredentials()))
+	if err != nil {
+		panic(fmt.Sprintf("unable to dial sdk worker %v: %v", ep.GetEndpoint().GetUrl(), err))
+	}
+	defer conn.Close()
+	pool := fnpb.NewBeamFnExternalWorkerPoolClient(conn)
+
+	endpoint := &pipepb.ApiServiceDescriptor{
+		Url: wk.Endpoint(),
+	}
+	pool.StartWorker(ctx, &fnpb.StartWorkerRequest{
+		WorkerId:          wk.ID,
+		ControlEndpoint:   endpoint,
+		LoggingEndpoint:   endpoint,
+		ArtifactEndpoint:  endpoint,
+		ProvisionEndpoint: endpoint,
+		Params:            ep.GetParams(),
+	})
+	// Job processing happens here, but is orchestrated by other goroutines.
+	// This goroutine blocks until the context is cancelled, signalling
+	// that the pool runner should stop the worker.
+	<-ctx.Done()
+
+	// The previous context is cancelled by this point, so we need a new one
+	// for this request.
+	pool.StopWorker(context.Background(), &fnpb.StopWorkerRequest{
+		WorkerId: wk.ID,
+	})
+}
+
+func dockerEnvironment(ctx context.Context, logger *slog.Logger, dp *pipepb.DockerPayload, wk *worker.W, artifactEndpoint string) error {
+	logger = logger.With("worker_id", wk.ID, "image", dp.GetContainerImage())
+
+	// TODO consider preserving client?
+	cli, err := dcli.NewClientWithOpts(dcli.FromEnv, dcli.WithAPIVersionNegotiation())
+	if err != nil {
+		return fmt.Errorf("couldn't connect to docker: %w", err)
+	}
+
+	// TODO abstract mounting cloud-specific auths better.
+	const gcloudCredsEnv = "GOOGLE_APPLICATION_CREDENTIALS"
+	gcloudCredsFile, ok := os.LookupEnv(gcloudCredsEnv)
+	var mounts []mount.Mount
+	var envs []string
+	if ok {
+		_, err := os.Stat(gcloudCredsFile)
+		// Only mount the credentials file if it exists.
+		if err == nil {
+			dockerGcloudCredsFile := "/docker_cred_file.json"
+			mounts = append(mounts, mount.Mount{
+				Type:   "bind",
+				Source: gcloudCredsFile,
+				Target: dockerGcloudCredsFile,
+			})
+			credEnv := fmt.Sprintf("%v=%v", gcloudCredsEnv, dockerGcloudCredsFile)
+			envs = append(envs, credEnv)
+		}
+	}
+
+	if rc, err := cli.ImagePull(ctx, dp.GetContainerImage(), dtyp.ImagePullOptions{}); err == nil {
+		// Copy the output, but discard it so we can wait until the image pull is finished.
+		io.Copy(io.Discard, rc)
+		rc.Close()
+	} else {
+		logger.Warn("unable to pull image", "error", err)
+	}
+
+	ccr, err := cli.ContainerCreate(ctx, &container.Config{
+		Image: dp.GetContainerImage(),
+		Cmd: []string{
+			fmt.Sprintf("--id=%v-%v", wk.JobKey, wk.Env),
+			fmt.Sprintf("--control_endpoint=%v", wk.Endpoint()),
+			fmt.Sprintf("--artifact_endpoint=%v", artifactEndpoint),
+			fmt.Sprintf("--provision_endpoint=%v", wk.Endpoint()),
+			fmt.Sprintf("--logging_endpoint=%v", wk.Endpoint()),
+		},
+		Env: envs,
+		Tty: false,
+	}, &container.HostConfig{
+		NetworkMode: "host",
+		Mounts:      mounts,
+	}, nil, nil, "")
+	if err != nil {
+		cli.Close()
+		return fmt.Errorf("unable to create container image %v with docker for env %v, err: %w", dp.GetContainerImage(), wk.Env, err)
+	}
+	containerID := ccr.ID
+	logger = logger.With("container", containerID)
+
+	if err := cli.ContainerStart(ctx, containerID, dtyp.ContainerStartOptions{}); err != nil {
+		cli.Close()
+		return fmt.Errorf("unable to start container image %v with docker for env %v, err: %w", dp.GetContainerImage(), wk.Env, err)
+	}
+
+	// Start goroutine to wait on container state.
+	go func() {
+		defer cli.Close()
+
+		statusCh, errCh := cli.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
+		select {
+		case <-ctx.Done():
+			// Can't use the command context, since it's already canceled here.
+			err := cli.ContainerKill(context.Background(), containerID, "")
+			if err != nil {
+				logger.Error("docker container kill error", "error", err)
+			}
+		case err := <-errCh:
+			if err != nil {
+				logger.Error("docker container wait error", "error", err)
+			}
+		case resp := <-statusCh:
+			logger.Info("docker container has self terminated", "status_code", resp.StatusCode)
+
+			rc, err := cli.ContainerLogs(ctx, containerID, dtyp.ContainerLogsOptions{Details: true, ShowStdout: true, ShowStderr: true})
+			if err != nil {
+				logger.Error("docker container logs error", "error", err)
+			}
+			defer rc.Close()
+			var buf bytes.Buffer
+			stdcopy.StdCopy(&buf, &buf, rc)
+			logger.Error("container self terminated", "log", buf.String())
+		}
+	}()
+
+	return nil
+}
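Because dockerEnvironment starts the container with Tty: false, Docker multiplexes stdout and stderr onto a single stream, which is why the wait goroutine runs the ContainerLogs output through stdcopy.StdCopy. A self-contained sketch of that demultiplexing, using an in-memory buffer as a stand-in for the container log stream:

package main

import (
	"bytes"
	"fmt"

	"github.com/docker/docker/pkg/stdcopy"
)

func main() {
	// Build a multiplexed stream in memory; NewStdWriter adds the same
	// 8-byte frame headers the docker daemon uses.
	var mux bytes.Buffer
	stdcopy.NewStdWriter(&mux, stdcopy.Stdout).Write([]byte("worker started\n"))
	stdcopy.NewStdWriter(&mux, stdcopy.Stderr).Write([]byte("worker warning\n"))

	// StdCopy reads the frame headers and routes each payload to the
	// matching writer; the patch sends both to one buffer for the job log.
	var out, errOut bytes.Buffer
	if _, err := stdcopy.StdCopy(&out, &errOut, &mux); err != nil {
		panic(err)
	}
	fmt.Printf("stdout: %q\nstderr: %q\n", out.String(), errOut.String())
}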
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index b2f9d866603a..e0c67105d451 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -24,7 +24,6 @@ import (
 
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
-	fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
 	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
@@ -32,8 +31,6 @@
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/worker"
 	"golang.org/x/exp/maps"
 	"golang.org/x/exp/slog"
-	"google.golang.org/grpc"
-	"google.golang.org/grpc/credentials/insecure"
 	"google.golang.org/protobuf/proto"
 )
 
@@ -54,30 +51,21 @@ func RunPipeline(j *jobservices.Job) {
 		return
 	}
 	env, _ := getOnlyPair(envs)
-	wk := worker.New(j.String()+"_"+env, env) // Cheating by having the worker id match the environment id.
-	go wk.Serve()
-	timeout := time.Minute
-	time.AfterFunc(timeout, func() {
-		if wk.Connected() {
-			return
-		}
-		err := fmt.Errorf("prism %v didn't get control connection after %v", wk, timeout)
+	wk, err := makeWorker(env, j)
+	if err != nil {
 		j.Failed(err)
-		j.CancelFn(err)
-	})
-
+		return
+	}
 	// When this function exits, we cancel the context to clear
 	// any related job resources.
 	defer func() {
 		j.CancelFn(fmt.Errorf("runPipeline returned, cleaning up"))
 	}()
-	go runEnvironment(j.RootCtx, j, env, wk)
 
 	j.SendMsg("running " + j.String())
 	j.Running()
 
-	err := executePipeline(j.RootCtx, wk, j)
-	if err != nil {
+	if err := executePipeline(j.RootCtx, wk, j); err != nil {
 		j.Failed(err)
 		return
 	}
@@ -90,57 +78,27 @@ func RunPipeline(j *jobservices.Job) {
 	j.Done()
 }
 
-// TODO move environment handling to the worker package.
-
-func runEnvironment(ctx context.Context, j *jobservices.Job, env string, wk *worker.W) {
-	// TODO fix broken abstraction.
-	// We're starting a worker pool here, because that's the loopback environment.
-	// It's sort of a mess, largely because of loopback, which has
-	// a different flow from a provisioned docker container.
-	e := j.Pipeline.GetComponents().GetEnvironments()[env]
-	switch e.GetUrn() {
-	case urns.EnvExternal:
-		ep := &pipepb.ExternalPayload{}
-		if err := (proto.UnmarshalOptions{}).Unmarshal(e.GetPayload(), ep); err != nil {
-			slog.Error("unmarshing environment payload", err, slog.String("envID", wk.Env))
-		}
-		externalEnvironment(ctx, ep, wk)
-		slog.Debug("environment stopped", slog.String("envID", wk.String()), slog.String("job", j.String()))
-	default:
-		panic(fmt.Sprintf("environment %v with urn %v unimplemented", env, e.GetUrn()))
-	}
-}
-
-func externalEnvironment(ctx context.Context, ep *pipepb.ExternalPayload, wk *worker.W) {
-	conn, err := grpc.Dial(ep.GetEndpoint().GetUrl(), grpc.WithTransportCredentials(insecure.NewCredentials()))
-	if err != nil {
-		panic(fmt.Sprintf("unable to dial sdk worker %v: %v", ep.GetEndpoint().GetUrl(), err))
-	}
-	defer conn.Close()
-	pool := fnpb.NewBeamFnExternalWorkerPoolClient(conn)
-
-	endpoint := &pipepb.ApiServiceDescriptor{
-		Url: wk.Endpoint(),
+// makeWorker creates and starts a worker for the given environment.
+func makeWorker(env string, j *jobservices.Job) (*worker.W, error) {
+	wk := worker.New(j.String()+"_"+env, env)
+	wk.EnvPb = j.Pipeline.GetComponents().GetEnvironments()[env]
+	wk.JobKey = j.JobKey()
+	wk.ArtifactEndpoint = j.ArtifactEndpoint()
+	go wk.Serve()
+	if err := runEnvironment(j.RootCtx, j, env, wk); err != nil {
+		return nil, fmt.Errorf("failed to start environment %v for job %v: %w", env, j, err)
 	}
-	pool.StartWorker(ctx, &fnpb.StartWorkerRequest{
-		WorkerId:          wk.ID,
-		ControlEndpoint:   endpoint,
-		LoggingEndpoint:   endpoint,
-		ArtifactEndpoint:  endpoint,
-		ProvisionEndpoint: endpoint,
-		Params:            nil,
-	})
-
-	// Job processing happens here, but orchestrated by other goroutines
-	// This goroutine blocks until the context is cancelled, signalling
-	// that the pool runner should stop the worker.
-	<-ctx.Done()
-
-	// Previous context cancelled so we need a new one
-	// for this request.
-	pool.StopWorker(context.Background(), &fnpb.StopWorkerRequest{
-		WorkerId: wk.ID,
+	// Check for the control connection succeeding after we've created the environment successfully.
+	timeout := 1 * time.Minute
+	time.AfterFunc(timeout, func() {
+		if wk.Connected() {
+			return
+		}
+		err := fmt.Errorf("prism %v didn't get control connection to %v after %v", wk, wk.Endpoint(), timeout)
+		j.Failed(err)
+		j.CancelFn(err)
 	})
+	return wk, nil
 }
 
 type transformExecuter interface {
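The makeWorker helper above arms a watchdog rather than blocking: time.AfterFunc fires once after the timeout, and the callback is a no-op if the control connection already arrived. A runnable sketch of the pattern with invented stand-in types (fakeWorker is not the prism worker API):

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type fakeWorker struct{ connected atomic.Bool }

func (w *fakeWorker) Connected() bool { return w.connected.Load() }

func main() {
	wk := &fakeWorker{}
	timeout := 50 * time.Millisecond

	// The callback runs once, in its own goroutine, after the timeout; the
	// job is failed only if the worker still hasn't connected by then.
	time.AfterFunc(timeout, func() {
		if wk.Connected() {
			return
		}
		fmt.Println("fail job:", fmt.Errorf("no control connection after %v", timeout))
	})

	// Simulate the control connection arriving before the deadline,
	// making the watchdog a no-op.
	wk.connected.Store(true)
	time.Sleep(100 * time.Millisecond)
}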
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go
index e66def5b0fe8..99b786d45980 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/artifact.go
@@ -16,11 +16,14 @@
 package jobservices
 
 import (
+	"bytes"
+	"context"
 	"fmt"
 	"io"
 
 	jobpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
 	"golang.org/x/exp/slog"
+	"google.golang.org/protobuf/encoding/prototext"
 )
 
 func (s *Server) ReverseArtifactRetrievalService(stream jobpb.ArtifactStagingService_ReverseArtifactRetrievalServiceServer) error {
@@ -47,7 +50,7 @@ func (s *Server) ReverseArtifactRetrievalService(stream jobpb.ArtifactStagingSer
 				},
 			},
 		})
-		var count int
+		var buf bytes.Buffer
 		for {
 			in, err := stream.Recv()
 			if err == io.EOF {
@@ -56,26 +59,61 @@ func (s *Server) ReverseArtifactRetrievalService(stream jobpb.ArtifactStagingSer
 			if err != nil {
 				return err
 			}
-			if in.IsLast {
-				slog.Debug("GetArtifact finish",
+			if in.GetIsLast() {
+				slog.Debug("GetArtifact finished",
 					slog.Group("dep",
 						slog.String("urn", dep.GetTypeUrn()),
 						slog.String("payload", string(dep.GetTypePayload()))),
-					slog.Int("bytesReceived", count))
+					slog.Int("bytesReceived", buf.Len()),
+					slog.String("rtype", fmt.Sprintf("%T", in.GetResponse())),
+				)
 				break
 			}
 			// Here's where we go through each environment's artifacts.
-			// We do nothing with them.
+			// We buffer them in memory so GetArtifact can serve them back to workers.
 			switch req := in.GetResponse().(type) {
 			case *jobpb.ArtifactResponseWrapper_GetArtifactResponse:
-				count += len(req.GetArtifactResponse.GetData())
+				buf.Write(req.GetArtifactResponse.GetData())
+
 			case *jobpb.ArtifactResponseWrapper_ResolveArtifactResponse:
 				err := fmt.Errorf("unexpected ResolveArtifactResponse to GetArtifact: %v", in.GetResponse())
 				slog.Error("GetArtifact failure", err)
 				return err
 			}
 		}
+		if len(s.artifacts) == 0 {
+			s.artifacts = map[string][]byte{}
+		}
+		s.artifacts[string(dep.GetTypePayload())] = buf.Bytes()
 		}
 	}
 	return nil
 }
+
+func (s *Server) ResolveArtifacts(_ context.Context, req *jobpb.ResolveArtifactsRequest) (*jobpb.ResolveArtifactsResponse, error) {
+	return &jobpb.ResolveArtifactsResponse{
+		Replacements: req.GetArtifacts(),
+	}, nil
+}
+
+func (s *Server) GetArtifact(req *jobpb.GetArtifactRequest, stream jobpb.ArtifactRetrievalService_GetArtifactServer) error {
+	info := req.GetArtifact()
+	buf, ok := s.artifacts[string(info.GetTypePayload())]
+	if !ok {
+		pt := prototext.Format(info)
+		slog.Warn("unable to provide artifact to worker", "artifact_info", pt)
+		return fmt.Errorf("unable to provide %v to worker", pt)
+	}
+	chunk := 128 * 1024 * 1024 // 128 MB
+	var i int
+	for i+chunk < len(buf) {
+		stream.Send(&jobpb.GetArtifactResponse{
+			Data: buf[i : i+chunk],
+		})
+		i += chunk
+	}
+	stream.Send(&jobpb.GetArtifactResponse{
+		Data: buf[i:],
+	})
+	return nil
+}
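GetArtifact above streams a staged artifact back in fixed-size pieces: full chunk-sized slices while they fit, then one final (possibly short, possibly empty) remainder, so even a zero-length artifact yields one response. The same loop in isolation, with a tiny chunk size so the slicing is visible:

package main

import "fmt"

// chunks mirrors the send loop in GetArtifact, collecting slices instead of
// streaming them over gRPC.
func chunks(buf []byte, chunk int) [][]byte {
	var out [][]byte
	var i int
	for i+chunk < len(buf) {
		out = append(out, buf[i:i+chunk])
		i += chunk
	}
	// The final piece covers the remainder, matching the trailing stream.Send.
	return append(out, buf[i:])
}

func main() {
	for _, c := range chunks([]byte("abcdefghij"), 4) {
		fmt.Printf("%q\n", c) // "abcd", "efgh", "ij"
	}
}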
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index 10d36066391f..87b0ec007bfb 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -68,6 +68,8 @@ type Job struct {
 	key     string
 	jobName string
 
+	artifactEndpoint string
+
 	Pipeline *pipepb.Pipeline
 	options  *structpb.Struct
 
@@ -88,6 +90,10 @@ type Job struct {
 	metrics metricsStore
 }
 
+func (j *Job) ArtifactEndpoint() string {
+	return j.artifactEndpoint
+}
+
 // ContributeTentativeMetrics returns the datachannel read index, and any unknown monitoring short ids.
 func (j *Job) ContributeTentativeMetrics(payloads *fnpb.ProcessBundleProgressResponse) (int64, []string) {
 	return j.metrics.ContributeTentativeMetrics(payloads)
@@ -113,6 +119,10 @@ func (j *Job) LogValue() slog.Value {
 		slog.String("name", j.jobName))
 }
 
+func (j *Job) JobKey() string {
+	return j.key
+}
+
 func (j *Job) SendMsg(msg string) {
 	j.streamCond.L.Lock()
 	defer j.streamCond.L.Unlock()
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index e626a05b51e1..213e33a78379 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -79,6 +79,8 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo
 		streamCond: sync.NewCond(&sync.Mutex{}),
 		RootCtx:    rootCtx,
 		CancelFn:   cancelFn,
+
+		artifactEndpoint: s.Endpoint(),
 	}
 
 	// Queue initial state of the job.
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go
index e3fb7766b519..bf2db814813c 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/server.go
@@ -29,6 +29,7 @@ import (
 type Server struct {
 	jobpb.UnimplementedJobServiceServer
 	jobpb.UnimplementedArtifactStagingServiceServer
+	jobpb.UnimplementedArtifactRetrievalServiceServer
 	fnpb.UnimplementedProvisionServiceServer
 
 	// Server management
@@ -42,6 +43,9 @@ type Server struct {
 
 	// execute defines how a job is executed.
 	execute func(*Job)
+
+	// Artifact hack: staged artifact bytes, keyed by the artifact's type payload.
+	artifacts map[string][]byte
 }
 
 // NewServer acquires the indicated port.
@@ -60,6 +64,7 @@ func NewServer(port int, execute func(*Job)) *Server {
 	s.server = grpc.NewServer(opts...)
 	jobpb.RegisterJobServiceServer(s.server, s)
 	jobpb.RegisterArtifactStagingServiceServer(s.server, s)
+	jobpb.RegisterArtifactRetrievalServiceServer(s.server, s)
 	return s
 }
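The artifacts map added to Server is allocated lazily: the len check in ReverseArtifactRetrievalService guards the first write, because reading from a nil Go map is fine but writing to one panics. A minimal sketch of that behavior:

package main

import "fmt"

func main() {
	var artifacts map[string][]byte // nil until allocated; len of a nil map is 0

	// artifacts["k"] = nil here would panic: assignment to entry in nil map.
	if len(artifacts) == 0 {
		artifacts = map[string][]byte{}
	}
	artifacts["type-payload-key"] = []byte("staged bytes")
	fmt.Println(len(artifacts), "artifact stored")
}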
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index 0ad7ccb37032..3a862a143b73 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -57,6 +57,9 @@ type W struct {
 
 	ID, Env string
 
+	JobKey, ArtifactEndpoint string
+	EnvPb                    *pipepb.Environment
+
 	// Server management
 	lis    net.Listener
 	server *grpc.Server
@@ -107,6 +110,7 @@ func New(id, env string) *W {
 	fnpb.RegisterBeamFnDataServer(wk.server, wk)
 	fnpb.RegisterBeamFnLoggingServer(wk.server, wk)
 	fnpb.RegisterBeamFnStateServer(wk.server, wk)
+	fnpb.RegisterProvisionServiceServer(wk.server, wk)
 	return wk
 }
@@ -164,10 +168,15 @@ func (wk *W) GetProvisionInfo(_ context.Context, _ *fnpb.GetProvisionInfoRequest
 		RunnerCapabilities: []string{
 			urns.CapabilityMonitoringInfoShortIDs,
 		},
-		LoggingEndpoint:  endpoint,
-		ControlEndpoint:  endpoint,
-		ArtifactEndpoint: endpoint,
-		// TODO add this job's RetrievalToken
+		LoggingEndpoint: endpoint,
+		ControlEndpoint: endpoint,
+		ArtifactEndpoint: &pipepb.ApiServiceDescriptor{
+			Url: wk.ArtifactEndpoint,
+		},
+
+		RetrievalToken: wk.JobKey,
+		Dependencies:   wk.EnvPb.GetDependencies(),
+
 		// TODO add this job's artifact Dependencies
 		Metadata: map[string]string{
diff --git a/sdks/go/pkg/beam/runners/universal/runnerlib/stage.go b/sdks/go/pkg/beam/runners/universal/runnerlib/stage.go
index 732f4382ab5d..d5cc6aa7327a 100644
--- a/sdks/go/pkg/beam/runners/universal/runnerlib/stage.go
+++ b/sdks/go/pkg/beam/runners/universal/runnerlib/stage.go
@@ -44,7 +44,7 @@ func Stage(ctx context.Context, id, endpoint, binary, st string) (retrievalToken
 	defer cc.Close()
 
 	if err := StageViaPortableAPI(ctx, cc, binary, st); err == nil {
-		return "", nil
+		return st, nil
 	}
 	log.Warnf(ctx, "unable to stage with PortableAPI: %v; falling back to legacy", err)
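One last reference point: GetArtifact logs unservable artifacts with prototext.Format, which renders any protobuf message in a readable text form. A sketch against the real pipeline_v1.ArtifactInformation message — the field values here are invented:

package main

import (
	"fmt"

	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
	"google.golang.org/protobuf/encoding/prototext"
)

func main() {
	info := &pipepb.ArtifactInformation{
		TypeUrn:     "beam:artifact:type:file:v1",
		TypePayload: []byte("invented payload"),
	}
	// Format is intended for logs and debugging, not for wire use.
	fmt.Println(prototext.Format(info))
}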