diff --git a/internal/tools/uvmboot/lcow.go b/internal/tools/uvmboot/lcow.go index 4d78c79848..4d029a9d04 100644 --- a/internal/tools/uvmboot/lcow.go +++ b/internal/tools/uvmboot/lcow.go @@ -4,6 +4,7 @@ package main import ( "context" + "fmt" "io" "os" "strings" @@ -17,8 +18,10 @@ import ( ) const ( + bootFilesPathArgName = "boot-files-path" consolePipeArgName = "console-pipe" kernelDirectArgName = "kernel-direct" + kernelFileArgName = "kernel-file" forwardStdoutArgName = "fwd-stdout" forwardStderrArgName = "fwd-stderr" outputHandlingArgName = "output-handling" @@ -26,10 +29,14 @@ const ( rootFSTypeArgName = "root-fs-type" vpMemMaxCountArgName = "vpmem-max-count" vpMemMaxSizeArgName = "vpmem-max-size" + scsiMountsArgName = "mount" + shareFilesArgName = "share" + securityPolicyArgName = "security-policy" ) var ( - lcowUseTerminal bool + lcowUseTerminal bool + lcowDisableTimeSync bool ) var lcowCommand = cli.Command{ @@ -45,6 +52,10 @@ var lcowCommand = cli.Command{ Name: rootFSTypeArgName, Usage: "Either 'initrd' or 'vhd'. (default: 'vhd' if rootfs.vhd exists)", }, + cli.StringFlag{ + Name: bootFilesPathArgName, + Usage: "The `path` to the boot files directory", + }, cli.UintFlag{ Name: vpMemMaxCountArgName, Usage: "Number of VPMem devices on the UVM. Uses hcsshim default if not specified", @@ -57,6 +68,19 @@ var lcowCommand = cli.Command{ Name: kernelDirectArgName, Usage: "Use kernel direct booting for UVM (default: true on builds >= 18286)", }, + cli.StringFlag{ + Name: kernelFileArgName, + Usage: "The kernel `file` to use; either 'kernel' or 'vmlinux'. (default: 'kernel')", + }, + cli.BoolFlag{ + Name: "disable-time-sync", + Usage: "Disable the time synchronization service", + Destination: &lcowDisableTimeSync, + }, + cli.StringFlag{ + Name: securityPolicyArgName, + Usage: "Security policy to set on the UVM. Leave empty to use an open door policy", + }, cli.StringFlag{ Name: execCommandLineArgName, Usage: "Command to execute in the UVM.", @@ -82,66 +106,32 @@ var lcowCommand = cli.Command{ Usage: "create the process in the UVM with a TTY enabled", Destination: &lcowUseTerminal, }, + cli.StringSliceFlag{ + Name: scsiMountsArgName, + Usage: "List of VHDs to SCSI mount into the UVM. Use repeat instances to add multiple. " + + "Value is of the form `'host[,guest[,w]]'`, where 'host' is path to the VHD, " + + `'guest' is an optional mount path inside the UVM, and 'w' mounts the VHD as writeable`, + }, + cli.StringSliceFlag{ + Name: shareFilesArgName, + Usage: "List of paths or files to plan9 share into the UVM. Use repeat instances to add multiple. " + + "Value is of the form `'host,guest[,w]' where 'host' is path to the VHD, " + + `'guest' is the mount path inside the UVM, and 'w' sets the shared files to writeable`, + }, }, Action: func(c *cli.Context) error { runMany(c, func(id string) error { - options := uvm.NewDefaultOptionsLCOW(id, "") - setGlobalOptions(c, options.Options) - useGcs := c.GlobalBool(gcsArgName) - options.UseGuestConnection = useGcs + ctx := context.Background() - if c.IsSet(kernelDirectArgName) { - options.KernelDirect = c.Bool(kernelDirectArgName) - } - if c.IsSet(rootFSTypeArgName) { - switch strings.ToLower(c.String(rootFSTypeArgName)) { - case "initrd": - options.RootFSFile = uvm.InitrdFile - options.PreferredRootFSType = uvm.PreferredRootFSTypeInitRd - case "vhd": - options.RootFSFile = uvm.VhdFile - options.PreferredRootFSType = uvm.PreferredRootFSTypeVHD - default: - logrus.Fatalf("Unrecognized value '%s' for option %s", c.String(rootFSTypeArgName), rootFSTypeArgName) - } - } - if c.IsSet(kernelArgsArgName) { - options.KernelBootOptions = c.String(kernelArgsArgName) - } - if c.IsSet(vpMemMaxCountArgName) { - options.VPMemDeviceCount = uint32(c.Uint(vpMemMaxCountArgName)) - } - if c.IsSet(vpMemMaxSizeArgName) { - options.VPMemSizeBytes = c.Uint64(vpMemMaxSizeArgName) * memory.MiB // convert from MB to bytes - } - if !useGcs { - if c.IsSet(execCommandLineArgName) { - options.ExecCommandLine = c.String(execCommandLineArgName) - } - if c.IsSet(forwardStdoutArgName) { - options.ForwardStdout = c.Bool(forwardStdoutArgName) - } - if c.IsSet(forwardStderrArgName) { - options.ForwardStderr = c.Bool(forwardStderrArgName) - } - if c.IsSet(outputHandlingArgName) { - switch strings.ToLower(c.String(outputHandlingArgName)) { - case "stdout": - options.OutputHandler = uvm.OutputHandler(func(r io.Reader) { - _, _ = io.Copy(os.Stdout, r) - }) - default: - logrus.Fatalf("Unrecognized value '%s' for option %s", c.String(outputHandlingArgName), outputHandlingArgName) - } - } - } - if c.IsSet(consolePipeArgName) { - options.ConsolePipe = c.String(consolePipeArgName) + options, err := createLCOWOptions(ctx, c, id) + if err != nil { + return err } - if err := runLCOW(context.TODO(), options, c); err != nil { + if err := runLCOW(ctx, options, c); err != nil { return err } + return nil }) @@ -149,27 +139,136 @@ var lcowCommand = cli.Command{ }, } +func init() { + lcowCommand.CustomHelpTemplate = cli.CommandHelpTemplate + "EXAMPLES:\n" + + `.\uvmboot.exe -gcs lcow -boot-files-path "C:\ContainerPlat\LinuxBootFiles" -root-fs-type vhd -t -exec "/bin/bash"` +} + +func createLCOWOptions(_ context.Context, c *cli.Context, id string) (*uvm.OptionsLCOW, error) { + options := uvm.NewDefaultOptionsLCOW(id, "") + setGlobalOptions(c, options.Options) + + // boot + if c.IsSet(bootFilesPathArgName) { + options.BootFilesPath = c.String(bootFilesPathArgName) + } + + // kernel + if c.IsSet(kernelDirectArgName) { + options.KernelDirect = c.Bool(kernelDirectArgName) + } + if c.IsSet(kernelFileArgName) { + switch strings.ToLower(c.String(kernelFileArgName)) { + case uvm.KernelFile: + options.KernelFile = uvm.KernelFile + case uvm.UncompressedKernelFile: + options.KernelFile = uvm.UncompressedKernelFile + default: + return nil, unrecognizedError(c.String(kernelFileArgName), kernelFileArgName) + } + } + if c.IsSet(kernelArgsArgName) { + options.KernelBootOptions = c.String(kernelArgsArgName) + } + + // rootfs + if c.IsSet(rootFSTypeArgName) { + switch strings.ToLower(c.String(rootFSTypeArgName)) { + case "initrd": + options.RootFSFile = uvm.InitrdFile + options.PreferredRootFSType = uvm.PreferredRootFSTypeInitRd + case "vhd": + options.RootFSFile = uvm.VhdFile + options.PreferredRootFSType = uvm.PreferredRootFSTypeVHD + default: + return nil, unrecognizedError(c.String(rootFSTypeArgName), rootFSTypeArgName) + } + } + + if c.IsSet(vpMemMaxCountArgName) { + options.VPMemDeviceCount = uint32(c.Uint(vpMemMaxCountArgName)) + } + if c.IsSet(vpMemMaxSizeArgName) { + options.VPMemSizeBytes = c.Uint64(vpMemMaxSizeArgName) * memory.MiB // convert from MB to bytes + } + + // GCS + options.UseGuestConnection = useGCS + if !useGCS { + if c.IsSet(execCommandLineArgName) { + options.ExecCommandLine = c.String(execCommandLineArgName) + } + if c.IsSet(forwardStdoutArgName) { + options.ForwardStdout = c.Bool(forwardStdoutArgName) + } + if c.IsSet(forwardStderrArgName) { + options.ForwardStderr = c.Bool(forwardStderrArgName) + } + if c.IsSet(outputHandlingArgName) { + switch strings.ToLower(c.String(outputHandlingArgName)) { + case "stdout": + options.OutputHandler = uvm.OutputHandler(func(r io.Reader) { + _, _ = io.Copy(os.Stdout, r) + }) + default: + return nil, unrecognizedError(c.String(outputHandlingArgName), outputHandlingArgName) + } + } + } + if c.IsSet(consolePipeArgName) { + options.ConsolePipe = c.String(consolePipeArgName) + } + + // general settings + if lcowDisableTimeSync { + options.DisableTimeSyncService = true + } + + if c.IsSet(securityPolicyArgName) { + options.SecurityPolicy = c.String(options.SecurityPolicy) + options.SecurityPolicyEnabled = true + } + + return options, nil +} + func runLCOW(ctx context.Context, options *uvm.OptionsLCOW, c *cli.Context) error { - uvm, err := uvm.CreateLCOW(ctx, options) + vm, err := uvm.CreateLCOW(ctx, options) if err != nil { return err } - defer uvm.Close() + defer vm.Close() - if err := uvm.Start(ctx); err != nil { + if err := vm.Start(ctx); err != nil { + return err + } + + if c.IsSet(securityPolicyArgName) { + if err := vm.SetSecurityPolicy(ctx, options.SecurityPolicy); err != nil { + return fmt.Errorf("could not set UVM security policy: %w", err) + } + logrus.WithField("policy", options.SecurityPolicy).Debug("Set UVM security policy") + } + + if err := mountSCSI(ctx, c, vm); err != nil { + return err + } + + if err := shareFiles(ctx, c, vm); err != nil { return err } if options.UseGuestConnection { - if err := execViaGcs(uvm, c); err != nil { + if err := execViaGcs(vm, c); err != nil { return err } - _ = uvm.Terminate(ctx) - _ = uvm.Wait() - return uvm.ExitError() + _ = vm.Terminate(ctx) + _ = vm.Wait() + + return vm.ExitError() } - return uvm.Wait() + return vm.Wait() } func execViaGcs(vm *uvm.UtilityVM, c *cli.Context) error { @@ -197,5 +296,6 @@ func execViaGcs(vm *uvm.UtilityVM, c *cli.Context) error { cmd.Stderr = os.Stdout // match non-GCS behavior and forward to stdout } } + return cmd.Run() } diff --git a/internal/tools/uvmboot/main.go b/internal/tools/uvmboot/main.go index 3d89af513d..508b5d799f 100644 --- a/internal/tools/uvmboot/main.go +++ b/internal/tools/uvmboot/main.go @@ -4,11 +4,13 @@ package main import ( "fmt" + "log" "os" "sync" "time" "github.com/Microsoft/hcsshim/internal/uvm" + "github.com/Microsoft/hcsshim/internal/winapi" "github.com/sirupsen/logrus" "github.com/urfave/cli" ) @@ -21,12 +23,17 @@ const ( measureArgName = "measure" parallelArgName = "parallel" countArgName = "count" - debugArgName = "debug" - gcsArgName = "gcs" execCommandLineArgName = "exec" ) +var ( + debug bool + useGCS bool +) + +type uvmRunFunc func(string) error + func main() { app := cli.NewApp() app.Name = "uvmboot" @@ -64,12 +71,14 @@ func main() { Usage: "Enable deferred commit on the UVM", }, cli.BoolFlag{ - Name: debugArgName, - Usage: "Enable debug level logging in HCSShim", + Name: "debug", + Usage: "Enable debug information", + Destination: &debug, }, cli.BoolFlag{ - Name: gcsArgName, - Usage: "Launch the GCS and perform requested operations via its RPC interface", + Name: "gcs", + Usage: "Launch the GCS and perform requested operations via its RPC interface", + Destination: &useGCS, }, } @@ -79,17 +88,21 @@ func main() { } app.Before = func(c *cli.Context) error { - if c.GlobalBool("debug") { + if !winapi.IsElevated() { + log.Fatal(c.App.Name + " must be run in an elevated context") + } + + if debug { logrus.SetLevel(logrus.DebugLevel) } else { logrus.SetLevel(logrus.WarnLevel) } + return nil } - err := app.Run(os.Args) - if err != nil { - logrus.Fatal(err) + if err := app.Run(os.Args); err != nil { + logrus.Fatalf("%v\n", err) } } @@ -108,7 +121,9 @@ func setGlobalOptions(c *cli.Context, options *uvm.Options) { } } -func runMany(c *cli.Context, runFunc func(id string) error) { +// todo: add a context here to propagate cancel/timeouts to runFunc uvm + +func runMany(c *cli.Context, runFunc uvmRunFunc) { parallelCount := c.GlobalInt(parallelArgName) var wg sync.WaitGroup @@ -118,8 +133,7 @@ func runMany(c *cli.Context, runFunc func(id string) error) { go func() { for i := range workChan { id := fmt.Sprintf("uvmboot-%d", i) - err := runFunc(id) - if err != nil { + if err := runFunc(id); err != nil { logrus.WithField("uvm-id", id).WithError(err).Error("failed to run UVM") } } @@ -138,3 +152,7 @@ func runMany(c *cli.Context, runFunc func(id string) error) { fmt.Println("Elapsed time:", time.Since(start)) } } + +func unrecognizedError(name, value string) error { + return fmt.Errorf("unrecognized value '%s' for option %s", name, value) +} diff --git a/internal/tools/uvmboot/mounts.go b/internal/tools/uvmboot/mounts.go new file mode 100644 index 0000000000..327d1d0b08 --- /dev/null +++ b/internal/tools/uvmboot/mounts.go @@ -0,0 +1,116 @@ +//go:build windows + +package main + +import ( + "context" + "fmt" + "strings" + + "github.com/Microsoft/hcsshim/internal/uvm" + "github.com/sirupsen/logrus" + "github.com/urfave/cli" +) + +func mountSCSI(ctx context.Context, c *cli.Context, vm *uvm.UtilityVM) error { + for _, m := range parseMounts(c, scsiMountsArgName) { + if _, err := vm.AddSCSI( + ctx, + m.host, + m.guest, + !m.writable, + false, // encrypted + []string{}, + uvm.VMAccessTypeIndividual, + ); err != nil { + return fmt.Errorf("could not mount disk %s: %w", m.host, err) + } else { + logrus.WithFields(logrus.Fields{ + "host": m.host, + "guest": m.guest, + "writable": m.writable, + }).Debug("Mounted SCSI disk") + } + } + + return nil +} + +func shareFiles(ctx context.Context, c *cli.Context, vm *uvm.UtilityVM) error { + switch os := vm.OS(); os { + case "linux": + return shareFilesLCOW(ctx, c, vm) + default: + return fmt.Errorf("file shares are not supported for %s UVMs", os) + } +} + +func shareFilesLCOW(ctx context.Context, c *cli.Context, vm *uvm.UtilityVM) error { + for _, s := range parseMounts(c, shareFilesArgName) { + if s.guest == "" { + return fmt.Errorf("file shares %q has invalid quest destination: %q", s.host, s.guest) + } + + if err := vm.Share(ctx, s.host, s.guest, !s.writable); err != nil { + return fmt.Errorf("could not share file or directory %s: %w", s.host, err) + } else { + logrus.WithFields(logrus.Fields{ + "host": s.host, + "guest": s.guest, + "writable": s.writable, + }).Debug("Shared path") + } + } + + return nil +} + +type mount struct { + host string + guest string + writable bool +} + +// parseMounts parses the mounts stored under the cli StringSlice argument, `n` +func parseMounts(c *cli.Context, n string) []mount { + if c.IsSet(n) { + ss := c.StringSlice(n) + ms := make([]mount, 0, len(ss)) + for _, s := range ss { + logrus.Debugf("parsing %q", s) + + if m, err := mountFromString(s); err == nil { + ms = append(ms, m) + } + } + + return ms + } + + return nil +} + +func mountFromString(s string) (m mount, _ error) { + ps := strings.Split(s, ",") + + l := len(ps) + if l == 0 { // shouldn't happen, but just in case + return m, fmt.Errorf("could not parse string %q", s) + } + + if l > 3 { + return m, fmt.Errorf("too many parts in %q", s) + } + + m.host = ps[0] + + if l >= 2 { + m.guest = ps[1] + } + + if l == 3 && strings.ToLower(ps[2]) == "w" { + m.writable = true + } + + return m, nil +} diff --git a/internal/winapi/elevation.go b/internal/winapi/elevation.go new file mode 100644 index 0000000000..40cbf8712f --- /dev/null +++ b/internal/winapi/elevation.go @@ -0,0 +1,11 @@ +//go:build windows + +package winapi + +import ( + "golang.org/x/sys/windows" +) + +func IsElevated() bool { + return windows.GetCurrentProcessToken().IsElevated() +} diff --git a/test/vendor/github.com/Microsoft/hcsshim/internal/winapi/elevation.go b/test/vendor/github.com/Microsoft/hcsshim/internal/winapi/elevation.go new file mode 100644 index 0000000000..40cbf8712f --- /dev/null +++ b/test/vendor/github.com/Microsoft/hcsshim/internal/winapi/elevation.go @@ -0,0 +1,11 @@ +//go:build windows + +package winapi + +import ( + "golang.org/x/sys/windows" +) + +func IsElevated() bool { + return windows.GetCurrentProcessToken().IsElevated() +}