diff --git a/virtcontainers/factory/factory_test.go b/virtcontainers/factory/factory_test.go index 1b1f7888b4..452889d5cb 100644 --- a/virtcontainers/factory/factory_test.go +++ b/virtcontainers/factory/factory_test.go @@ -30,6 +30,7 @@ func TestNewFactory(t *testing.T) { config.VMConfig = vc.VMConfig{ HypervisorType: vc.MockHypervisor, AgentType: vc.NoopAgentType, + ProxyType: vc.NoopProxyType, } _, err = NewFactory(ctx, config, false) @@ -43,28 +44,33 @@ func TestNewFactory(t *testing.T) { } // direct - _, err = NewFactory(ctx, config, false) + f, err := NewFactory(ctx, config, false) assert.Nil(err) - _, err = NewFactory(ctx, config, true) + f.CloseFactory(ctx) + f, err = NewFactory(ctx, config, true) assert.Nil(err) + f.CloseFactory(ctx) // template config.Template = true - _, err = NewFactory(ctx, config, false) + f, err = NewFactory(ctx, config, false) assert.Nil(err) + f.CloseFactory(ctx) _, err = NewFactory(ctx, config, true) assert.Error(err) // Cache config.Cache = 10 - _, err = NewFactory(ctx, config, false) + f, err = NewFactory(ctx, config, false) assert.Nil(err) + f.CloseFactory(ctx) _, err = NewFactory(ctx, config, true) assert.Error(err) config.Template = false - _, err = NewFactory(ctx, config, false) + f, err = NewFactory(ctx, config, false) assert.Nil(err) + f.CloseFactory(ctx) _, err = NewFactory(ctx, config, true) assert.Error(err) } diff --git a/virtcontainers/factory/template/template.go b/virtcontainers/factory/template/template.go index beea512d23..d0c8c5650e 100644 --- a/virtcontainers/factory/template/template.go +++ b/virtcontainers/factory/template/template.go @@ -51,6 +51,11 @@ func New(ctx context.Context, config vc.VMConfig) base.FactoryBase { // fallback to direct factory if template is not supported. return direct.New(ctx, config) } + defer func() { + if err != nil { + t.close() + } + }() err = t.createTemplateVM(ctx) if err != nil { @@ -73,6 +78,10 @@ func (t *template) GetBaseVM(ctx context.Context, config vc.VMConfig) (*vc.VM, e // CloseFactory cleans up the template VM. func (t *template) CloseFactory(ctx context.Context) { + t.close() +} + +func (t *template) close() { syscall.Unmount(t.statePath, 0) os.RemoveAll(t.statePath) } @@ -86,10 +95,12 @@ func (t *template) prepareTemplateFiles() error { flags := uintptr(syscall.MS_NOSUID | syscall.MS_NODEV) opts := fmt.Sprintf("size=%dM", t.config.HypervisorConfig.MemorySize+8) if err = syscall.Mount("tmpfs", t.statePath, "tmpfs", flags, opts); err != nil { + t.close() return err } f, err := os.Create(t.statePath + "/memory") if err != nil { + t.close() return err } f.Close()