diff --git a/cmd/containerd-shim-runhcs-v1/main.go b/cmd/containerd-shim-runhcs-v1/main.go index e0ee3a4ac0..375b2dcb46 100644 --- a/cmd/containerd-shim-runhcs-v1/main.go +++ b/cmd/containerd-shim-runhcs-v1/main.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strings" + "time" "github.com/Microsoft/go-winio/pkg/etw" "github.com/Microsoft/go-winio/pkg/etwlogrus" @@ -41,6 +42,9 @@ var ( containerdBinaryFlag string idFlag string + + // gracefulShutdownTimeout is how long to wait for clean-up before just exiting + gracefulShutdownTimeout = 3 * time.Second ) func etwCallback(sourceID guid.GUID, state etw.ProviderState, level etw.Level, matchAnyKeyword uint64, matchAllKeyword uint64, filterData uintptr) { diff --git a/cmd/containerd-shim-runhcs-v1/serve.go b/cmd/containerd-shim-runhcs-v1/serve.go index ecef00c2e7..6a973f1c3c 100644 --- a/cmd/containerd-shim-runhcs-v1/serve.go +++ b/cmd/containerd-shim-runhcs-v1/serve.go @@ -169,11 +169,9 @@ var serveCommand = cli.Command{ ttrpcAddress := os.Getenv(ttrpcAddressEnv) ttrpcEventPublisher, err := newEventPublisher(ttrpcAddress) - if err != nil { return err } - defer func() { if err != nil { ttrpcEventPublisher.close() @@ -181,11 +179,13 @@ var serveCommand = cli.Command{ }() // Setup the ttrpc server - svc = &service{ - events: ttrpcEventPublisher, - tid: idFlag, - isSandbox: ctx.Bool("is-sandbox"), + svc, err = NewService(WithEventPublisher(ttrpcEventPublisher), + WithTID(idFlag), + WithIsSandbox(ctx.Bool("is-sandbox"))) + if err != nil { + return fmt.Errorf("failed to create new service: %w", err) } + s, err := ttrpc.NewServer(ttrpc.WithUnaryServerInterceptor(octtrpc.ServerInterceptor())) if err != nil { return err @@ -204,10 +204,10 @@ var serveCommand = cli.Command{ serrs := make(chan error, 1) defer close(serrs) go func() { - // TODO: JTERRY75 We should use a real context with cancellation shared by - // the service for shim shutdown gracefully. 
- ctx := context.Background() - if err := trapClosedConnErr(s.Serve(ctx, sl)); err != nil { + // Serve loops infinitely unless s.Shutdown or s.Close are called. + // Passed in context is used as parent context for handling requests, + // but canceling does not bring down ttrpc service. + if err := trapClosedConnErr(s.Serve(context.Background(), sl)); err != nil { logrus.WithError(err).Fatal("containerd-shim: ttrpc server failure") serrs <- err return @@ -221,8 +221,7 @@ var serveCommand = cli.Command{ case err := <-serrs: return err case <-time.After(2 * time.Millisecond): - // TODO: JTERRY75 this is terrible code. Contribue a change to - // ttrpc that you can: + // TODO: Contribute a change to ttrpc so that you can: // // go func () { errs <- s.Serve() } // select { @@ -232,12 +231,27 @@ var serveCommand = cli.Command{ // This is our best indication that we have not errored on creation // and are successfully serving the API. + // Closing stdout signals to containerd that shim started successfully os.Stdout.Close() } // Wait for the serve API to be shut down. 
- <-serrs - return nil + select { + case err = <-serrs: + // the ttrpc server shutdown without processing a shutdown request + case <-svc.Done(): + if !svc.gracefulShutdown { + // Return immediately, but still close ttrpc server, pipes, and spans + // Shouldn't need to os.Exit without clean up (ie, deferred `.Close()`s) + return nil + } + // currently the ttrpc shutdown is the only clean up to wait on + sctx, cancel := context.WithTimeout(context.Background(), gracefulShutdownTimeout) + defer cancel() + err = s.Shutdown(sctx) + } + + return err }, } diff --git a/cmd/containerd-shim-runhcs-v1/service.go b/cmd/containerd-shim-runhcs-v1/service.go index 0085168580..c1a080eff5 100644 --- a/cmd/containerd-shim-runhcs-v1/service.go +++ b/cmd/containerd-shim-runhcs-v1/service.go @@ -18,7 +18,29 @@ import ( "go.opencensus.io/trace" ) -var _ = (task.TaskService)(&service{}) +type ServiceOptions struct { + Events publisher + TID string + IsSandbox bool +} + +type ServiceOption func(*ServiceOptions) + +func WithEventPublisher(e publisher) ServiceOption { + return func(o *ServiceOptions) { + o.Events = e + } +} +func WithTID(tid string) ServiceOption { + return func(o *ServiceOptions) { + o.TID = tid + } +} +func WithIsSandbox(s bool) ServiceOption { + return func(o *ServiceOptions) { + o.IsSandbox = s + } +} type service struct { events publisher @@ -42,10 +64,35 @@ type service struct { taskOrPod atomic.Value // cl is the create lock. Since each shim MUST only track a single task or - // POD. `cl` is used to create the task or POD sandbox. It SHOULD not be + // POD. `cl` is used to create the task or POD sandbox. It SHOULD NOT be // taken when creating tasks in a POD sandbox as they can happen // concurrently. 
cl sync.Mutex + + // shutdown is closed to signal a shutdown request is received + shutdown chan struct{} + // shutdownOnce is responsible for closing `shutdown` and any other necessary cleanup + shutdownOnce sync.Once + // gracefulShutdown dictates whether to shut down gracefully and clean up resources + // or exit immediately + gracefulShutdown bool +} + +var _ = (task.TaskService)(&service{}) + +func NewService(o ...ServiceOption) (svc *service, err error) { + var opts ServiceOptions + for _, op := range o { + op(&opts) + } + + svc = &service{ + events: opts.Events, + tid: opts.TID, + isSandbox: opts.IsSandbox, + shutdown: make(chan struct{}), + } + return svc, nil } func (s *service) State(ctx context.Context, req *task.StateRequest) (resp *task.StateResponse, err error) { @@ -475,3 +522,16 @@ func (s *service) ComputeProcessorInfo(ctx context.Context, req *extendedtask.Co r, e := s.computeProcessorInfoInternal(ctx, req) return r, errdefs.ToGRPC(e) } + +func (s *service) Done() <-chan struct{} { + return s.shutdown +} + +func (s *service) IsShutdown() bool { + select { + case <-s.shutdown: + return true + default: + return false + } +} diff --git a/cmd/containerd-shim-runhcs-v1/service_internal.go b/cmd/containerd-shim-runhcs-v1/service_internal.go index 739531d3a1..27c1a2cfad 100644 --- a/cmd/containerd-shim-runhcs-v1/service_internal.go +++ b/cmd/containerd-shim-runhcs-v1/service_internal.go @@ -447,12 +447,13 @@ func (s *service) shutdownInternal(ctx context.Context, req *task.ShutdownReques return empty, nil } - if req.Now { - os.Exit(0) - } - // TODO: JTERRY75 if we dont use `now` issue a Shutdown to the ttrpc - // connection to drain any active requests. - os.Exit(0) + s.shutdownOnce.Do(func() { + // TODO: should taskOrPod be deleted/set to nil? + // TODO: are there any extra leftovers of the shimTask/Pod to clean? ie: verify all handles are closed? 
+ s.gracefulShutdown = !req.Now + close(s.shutdown) + }) + return empty, nil } diff --git a/cmd/containerd-shim-runhcs-v1/service_internal_podshim_test.go b/cmd/containerd-shim-runhcs-v1/service_internal_podshim_test.go index dc316e6ea3..0d1fcc7766 100644 --- a/cmd/containerd-shim-runhcs-v1/service_internal_podshim_test.go +++ b/cmd/containerd-shim-runhcs-v1/service_internal_podshim_test.go @@ -2,9 +2,11 @@ package main import ( "context" + "fmt" "math/rand" "strconv" "testing" + "time" "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options" "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/stats" @@ -16,11 +18,22 @@ import ( func setupPodServiceWithFakes(t *testing.T) (*service, *testShimTask, *testShimTask, *testShimExec) { tid := strconv.Itoa(rand.Int()) - s := service{ - tid: tid, - isSandbox: true, + + s, err := NewService(WithTID(tid), WithIsSandbox(true)) + if err != nil { + t.Fatalf("could not create service: %v", err) } + // clean up the service + t.Cleanup(func() { + if _, err := s.shutdownInternal(context.Background(), &task.ShutdownRequest{ + ID: s.tid, + Now: true, + }); err != nil { + t.Fatalf("could not shutdown service: %v", err) + } + }) + pod := &testShimPod{id: tid} // create init fake container @@ -45,7 +58,7 @@ func setupPodServiceWithFakes(t *testing.T) (*service, *testShimTask, *testShimT pod.tasks.Store(task.id, task) pod.tasks.Store(task2.id, task2) s.taskOrPod.Store(pod) - return &s, task, task2, task2exec2 + return s, task, task2, task2exec2 } func Test_PodShim_getPod_NotCreated_Error(t *testing.T) { @@ -723,3 +736,35 @@ func Test_PodShim_statsInternal_2ndTaskID_Success(t *testing.T) { }) } } + +func Test_PodShim_shutdownInternal(t *testing.T) { + for _, now := range []bool{true, false} { + t.Run(fmt.Sprintf("%s_Now_%t", t.Name(), now), func(t *testing.T) { + s, _, _, _ := setupPodServiceWithFakes(t) + + if s.IsShutdown() { + t.Fatal("service prematurely shutdown") + } + + _, err := 
s.shutdownInternal(context.Background(), &task.ShutdownRequest{ + ID: s.tid, + Now: now, + }) + if err != nil { + t.Fatalf("could not shut down service: %v", err) + } + + tm := time.NewTimer(5 * time.Millisecond) + select { + case <-tm.C: + t.Fatalf("shutdown channel did not close") + case <-s.Done(): + tm.Stop() + } + + if !s.IsShutdown() { + t.Fatal("service did not shutdown") + } + }) + } +} diff --git a/cmd/containerd-shim-runhcs-v1/service_internal_taskshim_test.go b/cmd/containerd-shim-runhcs-v1/service_internal_taskshim_test.go index 1ee8eb7157..43fbaa66d1 100644 --- a/cmd/containerd-shim-runhcs-v1/service_internal_taskshim_test.go +++ b/cmd/containerd-shim-runhcs-v1/service_internal_taskshim_test.go @@ -2,9 +2,11 @@ package main import ( "context" + "fmt" "math/rand" "strconv" "testing" + "time" "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/options" "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/stats" @@ -16,10 +18,22 @@ import ( func setupTaskServiceWithFakes(t *testing.T) (*service, *testShimTask, *testShimExec) { tid := strconv.Itoa(rand.Int()) - s := service{ - tid: tid, - isSandbox: false, + + s, err := NewService(WithTID(tid), WithIsSandbox(false)) + if err != nil { + t.Fatalf("could not create service: %v", err) } + + // clean up the service + t.Cleanup(func() { + if _, err := s.shutdownInternal(context.Background(), &task.ShutdownRequest{ + ID: s.tid, + Now: true, + }); err != nil { + t.Fatalf("could not shutdown service: %v", err) + } + }) + task := &testShimTask{ id: tid, exec: newTestShimExec(tid, tid, 10), @@ -29,7 +43,7 @@ func setupTaskServiceWithFakes(t *testing.T) (*service, *testShimTask, *testShim secondExec := newTestShimExec(tid, secondExecID, 101) task.execs[secondExecID] = secondExec s.taskOrPod.Store(task) - return &s, task, secondExec + return s, task, secondExec } func Test_TaskShim_getTask_NotCreated_Error(t *testing.T) { @@ -619,3 +633,35 @@ func Test_TaskShim_statsInternal_InitTaskID_Sucess(t 
*testing.T) { }) } } + +func Test_TaskShim_shutdownInternal(t *testing.T) { + for _, now := range []bool{true, false} { + t.Run(fmt.Sprintf("%s_Now_%t", t.Name(), now), func(t *testing.T) { + s, _, _ := setupTaskServiceWithFakes(t) + + if s.IsShutdown() { + t.Fatal("service prematurely shutdown") + } + + _, err := s.shutdownInternal(context.Background(), &task.ShutdownRequest{ + ID: s.tid, + Now: now, + }) + if err != nil { + t.Fatalf("could not shut down service: %v", err) + } + + tm := time.NewTimer(5 * time.Millisecond) + select { + case <-tm.C: + t.Fatalf("shutdown channel did not close") + case <-s.Done(): + tm.Stop() + } + + if !s.IsShutdown() { + t.Fatal("service did not shutdown") + } + }) + } +} diff --git a/cmd/containerd-shim-runhcs-v1/service_internal_test.go b/cmd/containerd-shim-runhcs-v1/service_internal_test.go index a6a0075e1f..2d0e483963 100644 --- a/cmd/containerd-shim-runhcs-v1/service_internal_test.go +++ b/cmd/containerd-shim-runhcs-v1/service_internal_test.go @@ -1,16 +1,20 @@ package main import ( + "context" + "fmt" "reflect" "testing" + "time" "github.com/Microsoft/hcsshim/cmd/containerd-shim-runhcs-v1/stats" v1 "github.com/containerd/cgroups/stats/v1" + "github.com/containerd/containerd/runtime/v2/task" "github.com/pkg/errors" ) func verifyExpectedError(t *testing.T, resp interface{}, actual, expected error) { - if actual == nil || errors.Cause(actual) != expected { + if actual == nil || errors.Cause(actual) != expected || !errors.Is(actual, expected) { t.Fatalf("expected error: %v, got: %v", expected, actual) } @@ -126,3 +130,38 @@ func verifyExpectedVirtualMachineStatistics(t *testing.T, v *stats.VirtualMachin t.Fatalf("expected VirtualMachineStatistics.Memory.WorkingSetBytes == 100, got: %d", v.Memory.WorkingSetBytes) } } + +func Test_Service_shutdownInternal(t *testing.T) { + for _, now := range []bool{true, false} { + t.Run(fmt.Sprintf("%s_Now_%t", t.Name(), now), func(t *testing.T) { + s, err := NewService(WithTID(t.Name())) + if 
err != nil { + t.Fatal(err) + } + + if s.IsShutdown() { + t.Fatal("service prematurely shutdown") + } + + _, err = s.shutdownInternal(context.Background(), &task.ShutdownRequest{ + ID: s.tid, + Now: now, + }) + if err != nil { + t.Fatalf("could not shut down service: %v", err) + } + + tm := time.NewTimer(5 * time.Millisecond) + select { + case <-tm.C: + t.Fatalf("shutdown channel did not close") + case <-s.Done(): + tm.Stop() + } + + if !s.IsShutdown() { + t.Fatal("service did not shutdown") + } + }) + } +}