diff --git a/internal/guest/bridge/bridge_v2.go b/internal/guest/bridge/bridge_v2.go index f7924a2576..99777ca84a 100644 --- a/internal/guest/bridge/bridge_v2.go +++ b/internal/guest/bridge/bridge_v2.go @@ -199,7 +199,7 @@ func (b *Bridge) execProcessV2(r *Request) (_ RequestResponse, err error) { var c *hcsv2.Container if params.IsExternal || request.ContainerID == hcsv2.UVMContainerID { pid, err = b.hostState.RunExternalProcess(ctx, params, conSettings) - } else if c, err = b.hostState.GetContainer(request.ContainerID); err == nil { + } else if c, err = b.hostState.GetCreatedContainer(request.ContainerID); err == nil { // We found a V2 container. Treat this as a V2 process. if params.OCIProcess == nil { pid, err = c.Start(ctx, conSettings) @@ -267,7 +267,7 @@ func (b *Bridge) signalContainerV2(ctx context.Context, span *trace.Span, r *Req b.quitChan <- true b.hostState.Shutdown() } else { - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetCreatedContainer(request.ContainerID) if err != nil { return nil, err } @@ -296,7 +296,7 @@ func (b *Bridge) signalProcessV2(r *Request) (_ RequestResponse, err error) { trace.Int64Attribute("pid", int64(request.ProcessID)), trace.Int64Attribute("signal", int64(request.Options.Signal))) - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetCreatedContainer(request.ContainerID) if err != nil { return nil, err } @@ -344,7 +344,7 @@ func (b *Bridge) getPropertiesV2(r *Request) (_ RequestResponse, err error) { return nil, errors.New("getPropertiesV2 is not supported against the UVM") } - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetCreatedContainer(request.ContainerID) if err != nil { return nil, err } @@ -407,7 +407,7 @@ func (b *Bridge) waitOnProcessV2(r *Request) (_ RequestResponse, err error) { } exitCodeChan, doneChan = p.Wait() } else { - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetCreatedContainer(request.ContainerID) if err != nil { return nil, err } @@ -453,7 +453,7 @@ func (b *Bridge) resizeConsoleV2(r *Request) (_ RequestResponse, err error) { trace.Int64Attribute("height", int64(request.Height)), trace.Int64Attribute("width", int64(request.Width))) - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetCreatedContainer(request.ContainerID) if err != nil { return nil, err } @@ -514,7 +514,7 @@ func (b *Bridge) deleteContainerStateV2(r *Request) (_ RequestResponse, err erro return nil, errors.Wrapf(err, "failed to unmarshal JSON in message \"%s\"", r.Message) } - c, err := b.hostState.GetContainer(request.ContainerID) + c, err := b.hostState.GetCreatedContainer(request.ContainerID) if err != nil { return nil, err } diff --git a/internal/guest/runtime/hcsv2/container.go b/internal/guest/runtime/hcsv2/container.go index d890d0c2b0..cc4304c2f3 100644 --- a/internal/guest/runtime/hcsv2/container.go +++ b/internal/guest/runtime/hcsv2/container.go @@ -6,6 +6,7 @@ package hcsv2 import ( "context" "sync" + "sync/atomic" "syscall" "github.com/containerd/cgroups" @@ -28,6 +29,18 @@ import ( "github.com/Microsoft/hcsshim/internal/protocol/guestresource" ) +// containerStatus has been introduced to enable parallel container creation +type containerStatus uint32 + +const ( + // containerCreating is the default status set on a Container object, when + // no underlying runtime container or init process has been assigned + containerCreating containerStatus = iota + // containerCreated is the status when a runtime container and init process + // have been assigned, but runtime start command has not been issued yet + containerCreated +) + type Container struct { id string vsock transport.Transport @@ -43,6 +56,9 @@ type Container struct { processesMutex sync.Mutex processes map[uint32]*containerProcess + + // Only access atomically through getStatus/setStatus. + status containerStatus } func (c *Container) Start(ctx context.Context, conSettings stdio.ConnectionSettings) (int, error) { @@ -220,3 +236,12 @@ func (c *Container) GetStats(ctx context.Context) (*v1.Metrics, error) { func (c *Container) modifyContainerConstraints(ctx context.Context, rt guestrequest.RequestType, cc *guestresource.LCOWContainerConstraints) (err error) { return c.Update(ctx, cc.Linux) } + +func (c *Container) getStatus() containerStatus { + val := atomic.LoadUint32((*uint32)(&c.status)) + return containerStatus(val) +} + +func (c *Container) setStatus(st containerStatus) { + atomic.StoreUint32((*uint32)(&c.status), uint32(st)) +} diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 568f0ed448..380112a00d 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -16,11 +16,8 @@ import ( "syscall" "time" - "github.com/Microsoft/hcsshim/internal/guest/policy" - "github.com/mattn/go-shellwords" - "github.com/pkg/errors" - "github.com/Microsoft/hcsshim/internal/guest/gcserr" + "github.com/Microsoft/hcsshim/internal/guest/policy" "github.com/Microsoft/hcsshim/internal/guest/prot" "github.com/Microsoft/hcsshim/internal/guest/runtime" "github.com/Microsoft/hcsshim/internal/guest/spec" @@ -36,6 +33,8 @@ import ( "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "github.com/Microsoft/hcsshim/pkg/annotations" "github.com/Microsoft/hcsshim/pkg/securitypolicy" + "github.com/mattn/go-shellwords" + "github.com/pkg/errors" ) // UVMContainerID is the ContainerID that will be sent on any prot.MessageBase @@ -123,19 +122,30 @@ func (h *Host) RemoveContainer(id string) { delete(h.containers, id) } -func (h *Host) getContainerLocked(id string) (*Container, error) { +func (h *Host) GetCreatedContainer(id string) (*Container, error) { + h.containersMutex.Lock() + defer h.containersMutex.Unlock() + if c, ok := h.containers[id]; !ok { return nil, gcserr.NewHresultError(gcserr.HrVmcomputeSystemNotFound) } else { + if c.getStatus() != containerCreated { + return nil, fmt.Errorf("container is not in state \"created\": %w", + gcserr.NewHresultError(gcserr.HrVmcomputeInvalidState)) + } return c, nil } } -func (h *Host) GetContainer(id string) (*Container, error) { +func (h *Host) AddContainer(id string, c *Container) error { h.containersMutex.Lock() defer h.containersMutex.Unlock() - return h.getContainerLocked(id) + if _, ok := h.containers[id]; ok { + return gcserr.NewHresultError(gcserr.HrVmcomputeSystemAlreadyExists) + } + h.containers[id] = c + return nil } func setupSandboxMountsPath(id string) (err error) { @@ -162,12 +172,25 @@ func setupSandboxHugePageMountsPath(id string) error { } func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VMHostedContainerSettingsV2) (_ *Container, err error) { - h.containersMutex.Lock() - defer h.containersMutex.Unlock() + criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType] + c := &Container{ + id: id, + vsock: h.vsock, + spec: settings.OCISpecification, + isSandbox: criType == "sandbox", + exitType: prot.NtUnexpectedExit, + processes: make(map[uint32]*containerProcess), + status: containerCreating, + } - if _, ok := h.containers[id]; ok { - return nil, gcserr.NewHresultError(gcserr.HrVmcomputeSystemAlreadyExists) + if err := h.AddContainer(id, c); err != nil { + return nil, err } + defer func() { + if err != nil { + h.RemoveContainer(id) + } + }() err = h.securityPolicyEnforcer.EnforceCreateContainerPolicy( id, @@ -175,13 +198,11 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM settings.OCISpecification.Process.Env, settings.OCISpecification.Process.Cwd, ) - if err != nil { return nil, errors.Wrapf(err, "container creation denied due to policy") } var namespaceID string - criType, isCRI := settings.OCISpecification.Annotations[annotations.KubernetesContainerType] // for sandbox container sandboxID is same as container id sandboxID := id if isCRI { @@ -290,15 +311,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM return nil, errors.Wrapf(err, "failed to get container init process") } - c := &Container{ - id: id, - vsock: h.vsock, - spec: settings.OCISpecification, - isSandbox: criType == "sandbox", - container: con, - exitType: prot.NtUnexpectedExit, - processes: make(map[uint32]*containerProcess), - } + c.container = con c.initProcess = newProcess(c, settings.OCISpecification.Process, init, uint32(c.container.Pid()), true) // Sandbox or standalone, move the networks to the container namespace @@ -318,7 +331,7 @@ func (h *Host) CreateContainer(ctx context.Context, id string, settings *prot.VM } } - h.containers[id] = c + c.setStatus(containerCreated) return c, nil } @@ -337,7 +350,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * case guestresource.ResourceTypeVPCIDevice: return modifyMappedVPCIDevice(ctx, req.RequestType, req.Settings.(*guestresource.LCOWMappedVPCIDevice)) case guestresource.ResourceTypeContainerConstraints: - c, err := h.GetContainer(containerID) + c, err := h.GetCreatedContainer(containerID) if err != nil { return err } @@ -355,7 +368,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * } func (h *Host) modifyContainerSettings(ctx context.Context, containerID string, req *guestrequest.ModificationRequest) error { - c, err := h.GetContainer(containerID) + c, err := h.GetCreatedContainer(containerID) if err != nil { return err } diff --git a/test/cri-containerd/policy_test.go b/test/cri-containerd/policy_test.go index 9d1348218d..2892c774e4 100644 --- a/test/cri-containerd/policy_test.go +++ b/test/cri-containerd/policy_test.go @@ -8,6 +8,7 @@ import ( "fmt" "strings" "testing" + "time" runtime "k8s.io/cri-api/pkg/apis/runtime/v1alpha2" @@ -232,13 +233,30 @@ func Test_RunContainers_WithSyncHooks_ValidWaitPath(t *testing.T) { cidWriter := createContainer(t, client, ctx, writerReq) cidWaiter := createContainer(t, client, ctx, waiterReq) - startContainer(t, client, ctx, cidWriter) - defer removeContainer(t, client, ctx, cidWriter) - defer stopContainer(t, client, ctx, cidWriter) - - startContainer(t, client, ctx, cidWaiter) - defer removeContainer(t, client, ctx, cidWaiter) - defer stopContainer(t, client, ctx, cidWaiter) + errChan := make(chan error) + go func() { + _, err := client.StartContainer(ctx, &runtime.StartContainerRequest{ContainerId: cidWaiter}) + errChan <- err + defer removeContainer(t, client, ctx, cidWaiter) + defer stopContainer(t, client, ctx, cidWaiter) + }() + + // give some time for the first go routine to kick in. + time.Sleep(time.Second) + + go func() { + _, err := client.StartContainer(ctx, &runtime.StartContainerRequest{ContainerId: cidWriter}) + errChan <- err + defer removeContainer(t, client, ctx, cidWriter) + defer stopContainer(t, client, ctx, cidWriter) + }() + + for i := 0; i < 2; i++ { + if err := <-errChan; err != nil { + close(errChan) + t.Fatalf("failed to start container: %s", err) + } + } } func Test_RunContainers_WithSyncHooks_InvalidWaitPath(t *testing.T) {