From bdbc1542cf6f0df1c0c25c4ceb500d693d287070 Mon Sep 17 00:00:00 2001 From: Daniel Canter Date: Wed, 16 Sep 2020 09:23:03 -0700 Subject: [PATCH] Add high level job object wrapper * Add high level job object wrapper. * Add extra job object bindings for stats usage. Signed-off-by: Daniel Canter --- internal/jobobject/iocp.go | 111 ++++++++ internal/jobobject/jobobject.go | 418 ++++++++++++++++++++++++++++ internal/queue/mq.go | 111 ++++++++ internal/queue/queue_test.go | 199 +++++++++++++ internal/winapi/iocp.go | 3 + internal/winapi/jobobject.go | 91 +++++- internal/winapi/winapi.go | 2 +- internal/winapi/zsyscall_windows.go | 13 + 8 files changed, 934 insertions(+), 14 deletions(-) create mode 100644 internal/jobobject/iocp.go create mode 100644 internal/jobobject/jobobject.go create mode 100644 internal/queue/mq.go create mode 100644 internal/queue/queue_test.go create mode 100644 internal/winapi/iocp.go diff --git a/internal/jobobject/iocp.go b/internal/jobobject/iocp.go new file mode 100644 index 0000000000..afbc92c5ba --- /dev/null +++ b/internal/jobobject/iocp.go @@ -0,0 +1,111 @@ +package jobobject + +import ( + "context" + "fmt" + "sync" + "unsafe" + + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/queue" + "github.com/Microsoft/hcsshim/internal/winapi" + "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +var ( + ioInitOnce sync.Once + initIOErr error + // Global iocp handle that will be re-used for every job object + ioCompletionPort windows.Handle + // Mapping of job handle to queue to place notifications in. + jobMap sync.Map +) + +// MsgAllProcessesExited is a type representing a message that every process in a job has exited. +type MsgAllProcessesExited struct{} + +// MsgUnimplemented represents a message that we are aware of, but that isn't implemented currently. +// This should not be treated as an error. +type MsgUnimplemented struct{} + +// pollIOCP polls the io completion port forever. +func pollIOCP(ctx context.Context, iocpHandle windows.Handle) { + var ( + overlapped uintptr + code uint32 + key uintptr + ) + + for { + err := winapi.GetQueuedCompletionStatus(iocpHandle, &code, &key, (**windows.Overlapped)(unsafe.Pointer(&overlapped)), windows.INFINITE) + if err != nil { + log.G(ctx).WithError(err).Error("failed to poll for job object message") + continue + } + if val, ok := jobMap.Load(key); ok { + msq, ok := val.(*queue.MessageQueue) + if !ok { + log.G(ctx).WithField("value", msq).Warn("encountered non queue type in job map") + continue + } + notification, err := parseMessage(code, overlapped) + if err != nil { + log.G(ctx).WithFields(logrus.Fields{ + "code": code, + "overlapped": overlapped, + }).Warn("failed to parse job object message") + continue + } + if err := msq.Write(notification); err == queue.ErrQueueClosed { + // Write will only return an error when the queue is closed. + // The only time a queue would ever be closed is when we call `Close` on + // the job it belongs to which also removes it from the jobMap, so something + // went wrong here. We can't return as this is reading messages for all jobs + // so just log it and move on. + log.G(ctx).WithFields(logrus.Fields{ + "code": code, + "overlapped": overlapped, + }).Warn("tried to write to a closed queue") + continue + } + } else { + log.G(ctx).Warn("received a message for a job not present in the mapping") + } + } +} + +func parseMessage(code uint32, overlapped uintptr) (interface{}, error) { + // Check code and parse out relevant information related to that notification + // that we care about. For now all we handle is the message that all processes + // in the job have exited. + switch code { + case winapi.JOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO: + return MsgAllProcessesExited{}, nil + // Other messages for completeness and a check to make sure that if we fall + // into the default case that this is a code we don't know how to handle. + case winapi.JOB_OBJECT_MSG_END_OF_JOB_TIME: + case winapi.JOB_OBJECT_MSG_END_OF_PROCESS_TIME: + case winapi.JOB_OBJECT_MSG_ACTIVE_PROCESS_LIMIT: + case winapi.JOB_OBJECT_MSG_NEW_PROCESS: + case winapi.JOB_OBJECT_MSG_EXIT_PROCESS: + case winapi.JOB_OBJECT_MSG_ABNORMAL_EXIT_PROCESS: + case winapi.JOB_OBJECT_MSG_PROCESS_MEMORY_LIMIT: + case winapi.JOB_OBJECT_MSG_JOB_MEMORY_LIMIT: + case winapi.JOB_OBJECT_MSG_NOTIFICATION_LIMIT: + default: + return nil, fmt.Errorf("unknown job notification type: %d", code) + } + return MsgUnimplemented{}, nil +} + +// Assigns an IO completion port to get notified of events for the registered job +// object. +func attachIOCP(job windows.Handle, iocp windows.Handle) error { + info := winapi.JOBOBJECT_ASSOCIATE_COMPLETION_PORT{ + CompletionKey: job, + CompletionPort: iocp, + } + _, err := windows.SetInformationJobObject(job, windows.JobObjectAssociateCompletionPortInformation, uintptr(unsafe.Pointer(&info)), uint32(unsafe.Sizeof(info))) + return err +} diff --git a/internal/jobobject/jobobject.go b/internal/jobobject/jobobject.go new file mode 100644 index 0000000000..143e6b3126 --- /dev/null +++ b/internal/jobobject/jobobject.go @@ -0,0 +1,418 @@ +package jobobject + +import ( + "context" + "fmt" + "sync" + "unsafe" + + "github.com/Microsoft/hcsshim/internal/queue" + "github.com/Microsoft/hcsshim/internal/winapi" + "github.com/pkg/errors" + "golang.org/x/sys/windows" +) + +// This file provides higher level constructs for the win32 job object API. +// Most of the core creation and management functions are already present in "golang.org/x/sys/windows" +// (CreateJobObject, AssignProcessToJobObject, etc.) as well as most of the limit information +// structs and associated limit flags. Whatever is not present from the job object API +// in golang.org/x/sys/windows is located in /internal/winapi. +// +// https://docs.microsoft.com/en-us/windows/win32/procthread/job-objects + +// JobObject is a high level wrapper around a Windows job object. Holds a handle to +// the job, a queue to receive iocp notifications about the lifecycle +// of the job and a mutex for synchronized handle access. +type JobObject struct { + handle windows.Handle + mq *queue.MessageQueue + handleLock sync.RWMutex +} + +// JobLimits represents the resource constraints that can be applied to a job object. +type JobLimits struct { + CPULimit uint32 + CPUWeight uint32 + MemoryLimitInBytes uint64 + MaxIOPS int64 + MaxBandwidth int64 +} + +type CPURateControlType uint32 + +const ( + WeightBased CPURateControlType = iota + RateBased +) + +// Processor resource controls +const ( + CPULimitMin = 1 + CPULimitMax = 10000 + CPUWeightMin = 1 + CPUWeightMax = 9 +) + +var ( + ErrAlreadyClosed = errors.New("the handle has already been closed") + ErrNotRegistered = errors.New("job is not registered to receive notifications") +) + +// Create creates a job object. +// +// `name` specifies the name of the job object if a named job object is desired. If name +// is an empty string, the job will not be assigned a name. +// +// `notifications` specifies if the job will be registered to receive notifications. +// If this is false, `PollNotifications` will return immediately with error `errNotRegistered`. +// +// Returns a JobObject structure and an error if there is one. +func Create(ctx context.Context, name string, notifications bool) (_ *JobObject, err error) { + var ( + jobName *uint16 + mq *queue.MessageQueue + ) + + if name != "" { + jobName, err = windows.UTF16PtrFromString(name) + if err != nil { + return nil, err + } + } + + jobHandle, err := windows.CreateJobObject(nil, jobName) + if err != nil { + return nil, err + } + + defer func() { + if err != nil { + windows.Close(jobHandle) + } + }() + + // If the IOCP we'll be using to receive messages for all jobs hasn't been + // created, create it and start polling. + if notifications { + ioInitOnce.Do(func() { + h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff) + if err != nil { + initIOErr = err + return + } + ioCompletionPort = h + go pollIOCP(ctx, h) + }) + + if initIOErr != nil { + return nil, initIOErr + } + + mq = queue.NewMessageQueue() + jobMap.Store(uintptr(jobHandle), mq) + if err = attachIOCP(jobHandle, ioCompletionPort); err != nil { + jobMap.Delete(uintptr(jobHandle)) + return nil, err + } + } + + return &JobObject{ + handle: jobHandle, + mq: mq, + }, nil +} + +// SetResourceLimits sets resource limits on the job object (cpu, memory, storage). +func (job *JobObject) SetResourceLimits(limits *JobLimits) error { + // Go through and check what limits were specified and apply them to the job. + if limits.MemoryLimitInBytes != 0 { + if err := job.SetMemoryLimit(limits.MemoryLimitInBytes); err != nil { + return errors.Wrap(err, "failed to set job object memory limit") + } + } + + if limits.CPULimit != 0 { + if err := job.SetCPULimit(RateBased, limits.CPULimit); err != nil { + return errors.Wrap(err, "failed to set job object cpu limit") + } + } else if limits.CPUWeight != 0 { + if err := job.SetCPULimit(WeightBased, limits.CPUWeight); err != nil { + return errors.Wrap(err, "failed to set job object cpu limit") + } + } + + if limits.MaxBandwidth != 0 || limits.MaxIOPS != 0 { + if err := job.SetIOLimit(limits.MaxBandwidth, limits.MaxIOPS); err != nil { + return errors.Wrap(err, "failed to set io limit on job object") + } + } + return nil +} + +// SetCPULimit sets the CPU limit specified on the job object. +func (job *JobObject) SetCPULimit(rateControlType CPURateControlType, rateControlValue uint32) error { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + + if job.handle == 0 { + return ErrAlreadyClosed + } + + var cpuInfo winapi.JOBOBJECT_CPU_RATE_CONTROL_INFORMATION + switch rateControlType { + case WeightBased: + if rateControlValue < CPUWeightMin || rateControlValue > CPUWeightMax { + return fmt.Errorf("processor weight value of `%d` is invalid", rateControlValue) + } + cpuInfo.ControlFlags = winapi.JOB_OBJECT_CPU_RATE_CONTROL_ENABLE | winapi.JOB_OBJECT_CPU_RATE_CONTROL_WEIGHT_BASED + cpuInfo.Value = rateControlValue + case RateBased: + if rateControlValue < CPULimitMin || rateControlValue > CPULimitMax { + return fmt.Errorf("processor rate of `%d` is invalid", rateControlValue) + } + cpuInfo.ControlFlags = winapi.JOB_OBJECT_CPU_RATE_CONTROL_ENABLE | winapi.JOB_OBJECT_CPU_RATE_CONTROL_HARD_CAP + cpuInfo.Value = rateControlValue + default: + return errors.New("invalid job object cpu rate control type") + } + + _, err := windows.SetInformationJobObject(job.handle, windows.JobObjectCpuRateControlInformation, uintptr(unsafe.Pointer(&cpuInfo)), uint32(unsafe.Sizeof(cpuInfo))) + if err != nil { + return fmt.Errorf("failed to set cpu limit info on job object: %s", err) + } + return nil +} + +// SetMemoryLimit sets the memory limit specified on the job object. +func (job *JobObject) SetMemoryLimit(memoryLimitInBytes uint64) error { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + + if job.handle == 0 { + return ErrAlreadyClosed + } + + var eliInfo windows.JOBOBJECT_EXTENDED_LIMIT_INFORMATION + eliInfo.JobMemoryLimit = uintptr(memoryLimitInBytes) + eliInfo.BasicLimitInformation.LimitFlags = windows.JOB_OBJECT_LIMIT_JOB_MEMORY + _, err := windows.SetInformationJobObject(job.handle, windows.JobObjectExtendedLimitInformation, uintptr(unsafe.Pointer(&eliInfo)), uint32(unsafe.Sizeof(eliInfo))) + if err != nil { + return fmt.Errorf("failed to set extended limit info on job object: %s", err) + } + return nil +} + +// SetIOLimit sets the IO limits specified on the job object. +func (job *JobObject) SetIOLimit(maxBandwidth, maxIOPS int64) error { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + + if job.handle == 0 { + return ErrAlreadyClosed + } + + ioInfo := winapi.JOBOBJECT_IO_RATE_CONTROL_INFORMATION{ + ControlFlags: winapi.JOB_OBJECT_IO_RATE_CONTROL_ENABLE, + } + if maxBandwidth != 0 { + ioInfo.MaxBandwidth = maxBandwidth + } + if maxIOPS != 0 { + ioInfo.MaxIops = maxIOPS + } + _, err := winapi.SetIoRateControlInformationJobObject(job.handle, &ioInfo) + if err != nil { + return fmt.Errorf("failed to set IO limit info on job object: %s", err) + } + return nil +} + +// PollNotification will poll for a job object notification. This call should only be called once +// per job (ideally in a goroutine loop) and will block if there is not a notification ready. +// This call will return immediately with error `ErrNotRegistered` if the job was not registered +// to receive notifications during `Create`. Internally, messages will be queued and there +// is no worry of messages being dropped. +func (job *JobObject) PollNotification() (interface{}, error) { + if job.mq == nil { + return nil, ErrNotRegistered + } + return job.mq.ReadOrWait() +} + +// Close closes the job object handle. +func (job *JobObject) Close() error { + job.handleLock.Lock() + defer job.handleLock.Unlock() + + if job.handle == 0 { + return ErrAlreadyClosed + } + + if err := windows.Close(job.handle); err != nil { + return err + } + + if job.mq != nil { + job.mq.Close() + } + // Handles now invalid so if the map entry to receive notifications for this job still + // exists remove it so we can stop receiving notifications. + if _, ok := jobMap.Load(uintptr(job.handle)); ok { + jobMap.Delete(uintptr(job.handle)) + } + + job.handle = 0 + return nil +} + +// Assign assigns a process to the job object. +func (job *JobObject) Assign(pid uint32) error { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + + if job.handle == 0 { + return ErrAlreadyClosed + } + + if pid == 0 { + return errors.New("invalid pid: 0") + } + hProc, err := windows.OpenProcess(winapi.PROCESS_ALL_ACCESS, true, pid) + if err != nil { + return err + } + defer windows.Close(hProc) + return windows.AssignProcessToJobObject(job.handle, hProc) +} + +// Terminate terminates the job, essentially calls TerminateProcess on every process in the +// job. +func (job *JobObject) Terminate(exitCode uint32) error { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + if job.handle == 0 { + return ErrAlreadyClosed + } + return windows.TerminateJobObject(job.handle, exitCode) +} + +// Pids returns all of the process IDs in the job object. +func (job *JobObject) Pids() ([]uint32, error) { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + + if job.handle == 0 { + return nil, ErrAlreadyClosed + } + + info := winapi.JOBOBJECT_BASIC_PROCESS_ID_LIST{} + err := winapi.QueryInformationJobObject( + job.handle, + winapi.JobObjectBasicProcessIdList, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + nil, + ) + + // This is either the case where there is only one process or no processes in + // the job. Any other case will result in ERROR_MORE_DATA. Check if info.NumberOfProcessIdsInList + // is 1 and just return this, otherwise return an empty slice. + if err == nil { + if info.NumberOfProcessIdsInList == 1 { + return []uint32{uint32(info.ProcessIdList[0])}, nil + } + // Return empty slice instead of nil to play well with the caller of this. + // Do not return an error if no processes are running inside the job + return []uint32{}, nil + } + + if err != winapi.ERROR_MORE_DATA { + return nil, fmt.Errorf("failed initial query for PIDs in job object: %s", err) + } + + jobBasicProcessIDListSize := unsafe.Sizeof(info) + (unsafe.Sizeof(info.ProcessIdList[0]) * uintptr(info.NumberOfAssignedProcesses-1)) + buf := make([]byte, jobBasicProcessIDListSize) + if err = winapi.QueryInformationJobObject( + job.handle, + winapi.JobObjectBasicProcessIdList, + uintptr(unsafe.Pointer(&buf[0])), + uint32(len(buf)), + nil, + ); err != nil { + return nil, fmt.Errorf("failed to query for PIDs in job object: %s", err) + } + + bufInfo := (*winapi.JOBOBJECT_BASIC_PROCESS_ID_LIST)(unsafe.Pointer(&buf[0])) + bufPids := bufInfo.AllPids() + pids := make([]uint32, bufInfo.NumberOfProcessIdsInList) + for i, bufPid := range bufPids { + pids[i] = uint32(bufPid) + } + return pids, nil +} + +// QueryMemoryStats gets the memory stats for the job object. +func (job *JobObject) QueryMemoryStats() (*winapi.JOBOBJECT_MEMORY_USAGE_INFORMATION, error) { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + + if job.handle == 0 { + return nil, ErrAlreadyClosed + } + + info := winapi.JOBOBJECT_MEMORY_USAGE_INFORMATION{} + if err := winapi.QueryInformationJobObject( + job.handle, + winapi.JobObjectMemoryUsageInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + nil, + ); err != nil { + return nil, fmt.Errorf("failed to query for job object memory stats: %s", err) + } + return &info, nil +} + +// QueryProcessorStats gets the processor stats for the job object. +func (job *JobObject) QueryProcessorStats() (*winapi.JOBOBJECT_BASIC_ACCOUNTING_INFORMATION, error) { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + + if job.handle == 0 { + return nil, ErrAlreadyClosed + } + + info := winapi.JOBOBJECT_BASIC_ACCOUNTING_INFORMATION{} + if err := winapi.QueryInformationJobObject( + job.handle, + winapi.JobObjectBasicAccountingInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + nil, + ); err != nil { + return nil, fmt.Errorf("failed to query for job object process stats: %s", err) + } + return &info, nil +} + +// QueryStorageStats gets the storage (I/O) stats for the job object. +func (job *JobObject) QueryStorageStats() (*winapi.JOBOBJECT_BASIC_AND_IO_ACCOUNTING_INFORMATION, error) { + job.handleLock.RLock() + defer job.handleLock.RUnlock() + + if job.handle == 0 { + return nil, ErrAlreadyClosed + } + + info := winapi.JOBOBJECT_BASIC_AND_IO_ACCOUNTING_INFORMATION{} + if err := winapi.QueryInformationJobObject( + job.handle, + winapi.JobObjectBasicAndIoAccountingInformation, + uintptr(unsafe.Pointer(&info)), + uint32(unsafe.Sizeof(info)), + nil, + ); err != nil { + return nil, fmt.Errorf("failed to query for job object storage stats: %s", err) + } + return &info, nil +} diff --git a/internal/queue/mq.go b/internal/queue/mq.go new file mode 100644 index 0000000000..e177c9a629 --- /dev/null +++ b/internal/queue/mq.go @@ -0,0 +1,111 @@ +package queue + +import ( + "errors" + "sync" +) + +var ( + ErrQueueClosed = errors.New("the queue is closed for reading and writing") + ErrQueueEmpty = errors.New("the queue is empty") +) + +// MessageQueue represents a threadsafe message queue to be used to retrieve or +// write messages to. +type MessageQueue struct { + m *sync.RWMutex + c *sync.Cond + messages []interface{} + closed bool +} + +// NewMessageQueue returns a new MessageQueue. +func NewMessageQueue() *MessageQueue { + m := &sync.RWMutex{} + return &MessageQueue{ + m: m, + c: sync.NewCond(m), + messages: []interface{}{}, + } +} + +// Write writes `msg` to the queue. +func (mq *MessageQueue) Write(msg interface{}) error { + mq.m.Lock() + defer mq.m.Unlock() + + if mq.closed { + return ErrQueueClosed + } + mq.messages = append(mq.messages, msg) + // Signal a waiter that there is now a value available in the queue. + mq.c.Signal() + return nil +} + +// Read will read a value from the queue if available, otherwise return an error. +func (mq *MessageQueue) Read() (interface{}, error) { + mq.m.Lock() + defer mq.m.Unlock() + if mq.closed { + return nil, ErrQueueClosed + } + if mq.isEmpty() { + return nil, ErrQueueEmpty + } + val := mq.messages[0] + mq.messages[0] = nil + mq.messages = mq.messages[1:] + return val, nil +} + +// ReadOrWait will read a value from the queue if available, else it will wait for a +// value to become available. This will block forever if nothing gets written or until +// the queue gets closed. +func (mq *MessageQueue) ReadOrWait() (interface{}, error) { + mq.m.Lock() + if mq.closed { + mq.m.Unlock() + return nil, ErrQueueClosed + } + if mq.isEmpty() { + for !mq.closed && mq.isEmpty() { + mq.c.Wait() + } + mq.m.Unlock() + return mq.Read() + } + val := mq.messages[0] + mq.messages[0] = nil + mq.messages = mq.messages[1:] + mq.m.Unlock() + return val, nil +} + +// IsEmpty returns if the queue is empty +func (mq *MessageQueue) IsEmpty() bool { + mq.m.RLock() + defer mq.m.RUnlock() + return len(mq.messages) == 0 +} + +// Nonexported empty check that doesn't lock so we can call this in Read and Write. +func (mq *MessageQueue) isEmpty() bool { + return len(mq.messages) == 0 +} + +// Close closes the queue for future writes or reads. Any attempts to read or write from the +// queue after close will return ErrQueueClosed. This is safe to call multiple times. +func (mq *MessageQueue) Close() { + mq.m.Lock() + defer mq.m.Unlock() + // Already closed + if mq.closed { + return + } + mq.messages = nil + mq.closed = true + // If there's anybody currently waiting on a value from ReadOrWait, we need to + // broadcast so the read(s) can return ErrQueueClosed. + mq.c.Broadcast() +} diff --git a/internal/queue/queue_test.go b/internal/queue/queue_test.go new file mode 100644 index 0000000000..6bf181beca --- /dev/null +++ b/internal/queue/queue_test.go @@ -0,0 +1,199 @@ +package queue + +import ( + "sync" + "testing" + "time" +) + +func TestReadWrite(t *testing.T) { + q := NewMessageQueue() + + // Reading from an empty queue should return ErrQueueEmpty + if _, err := q.Read(); err != ErrQueueEmpty { + t.Fatal("expected to receive `ErrQueueEmpty` for reading from empty queue") + } + + // Write 1 to the queue and read this later. + if err := q.Write(1); err != nil { + t.Fatal(err) + } + + // Read the value. Value will be dequeued. + if msg, err := q.Read(); err != nil || msg != 1 { + t.Fatal(err) + } + + // We just read a value, now try and read again and verify that we get ErrQueueEmpty again. + if _, err := q.Read(); err != ErrQueueEmpty { + t.Fatal(err) + } + + // Close the queue and verify that we get an error on write. + q.Close() + if err := q.Write(1); err != ErrQueueClosed { + t.Fatal(err) + } +} + +func TestReadOrWaitClose(t *testing.T) { + q := NewMessageQueue() + + go func() { + q.Write(1) + q.Write(2) + q.Write(3) + time.Sleep(time.Second * 5) + q.Close() + }() + + time.Sleep(time.Second * 2) + + read := 0 + for { + if _, err := q.ReadOrWait(); err != nil { + if err == ErrQueueClosed && read == 3 { + break + } + t.Fatal(err) + } + read++ + } +} + +func TestReadOrWait(t *testing.T) { + q := NewMessageQueue() + + go func() { + q.Write(1) + q.Write(2) + q.Write(3) + time.Sleep(time.Second * 5) + q.Write(4) + }() + + // Small sleep so that we can give time to ensure a value is written to the queue so we + // can test both states ReadOrWait could be in. These states being there is already a value + // ready for consumption and all we have to do is just read it, or we wait to get signalled of + // an available value. + time.Sleep(time.Second * 1) + timeout := time.After(time.Second * 20) + done := make(chan struct{}) + readErr := make(chan error) + + go func() { + for { + if msg, err := q.ReadOrWait(); err != nil { + readErr <- err + } else { + if msg == 4 { + done <- struct{}{} + break + } + } + } + }() + + select { + case <-timeout: + t.Fatal("timed out waiting for all queue values to be read") + case <-done: + case err := <-readErr: + t.Fatal(err) + } +} + +func TestMultipleReaders(t *testing.T) { + q := NewMessageQueue() + errChan := make(chan error) + done := make(chan struct{}) + go func() { + for i := 0; i < 50; i++ { + if err := q.Write(1); err != nil { + errChan <- err + } + } + }() + + wg := sync.WaitGroup{} + wg.Add(2) + + // Reader 1 + go func() { + for i := 0; i < 25; i++ { + if _, err := q.ReadOrWait(); err != nil { + errChan <- err + } + } + wg.Done() + }() + + // Reader 2 + go func() { + for i := 0; i < 25; i++ { + if _, err := q.ReadOrWait(); err != nil { + errChan <- err + } + } + wg.Done() + }() + + go func() { + wg.Wait() + done <- struct{}{} + }() + + timeout := time.After(time.Second * 20) + + select { + case err := <-errChan: + t.Fatalf("failed in read or write: %s", err) + case <-done: + case <-timeout: + t.Fatalf("timeout exceeded waiting for reads to complete") + } +} + +func TestMultipleReadersClose(t *testing.T) { + q := NewMessageQueue() + errChan := make(chan error) + done := make(chan struct{}) + + wg := sync.WaitGroup{} + wg.Add(2) + + // Reader 1 + go func() { + if _, err := q.ReadOrWait(); err != ErrQueueClosed { + errChan <- err + } + wg.Done() + }() + + // Reader 2 + go func() { + if _, err := q.ReadOrWait(); err != ErrQueueClosed { + errChan <- err + } + wg.Done() + }() + + go func() { + wg.Wait() + done <- struct{}{} + }() + + time.Sleep(time.Second * 2) + // Close the queue and this should signal both readers to return ErrQueueClosed. + q.Close() + + timeout := time.After(time.Second * 20) + + select { + case err := <-errChan: + t.Fatalf("failed in read or write: %s", err) + case <-done: + case <-timeout: + t.Fatalf("timeout exceeded waiting for reads to complete") + } +} diff --git a/internal/winapi/iocp.go b/internal/winapi/iocp.go new file mode 100644 index 0000000000..4e609cbf1c --- /dev/null +++ b/internal/winapi/iocp.go @@ -0,0 +1,3 @@ +package winapi + +//sys GetQueuedCompletionStatus(cphandle windows.Handle, qty *uint32, key *uintptr, overlapped **windows.Overlapped, timeout uint32) (err error) diff --git a/internal/winapi/jobobject.go b/internal/winapi/jobobject.go index 1ea5b18a3d..6cb411cc42 100644 --- a/internal/winapi/jobobject.go +++ b/internal/winapi/jobobject.go @@ -1,22 +1,24 @@ package winapi import ( + "unsafe" + "golang.org/x/sys/windows" ) // Messages that can be received from an assigned io completion port. // https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-jobobject_associate_completion_port const ( - JOB_OBJECT_MSG_END_OF_JOB_TIME = 1 - JOB_OBJECT_MSG_END_OF_PROCESS_TIME = 2 - JOB_OBJECT_MSG_ACTIVE_PROCESS_LIMIT = 3 - JOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO = 4 - JOB_OBJECT_MSG_NEW_PROCESS = 6 - JOB_OBJECT_MSG_EXIT_PROCESS = 7 - JOB_OBJECT_MSG_ABNORMAL_EXIT_PROCESS = 8 - JOB_OBJECT_MSG_PROCESS_MEMORY_LIMIT = 9 - JOB_OBJECT_MSG_JOB_MEMORY_LIMIT = 10 - JOB_OBJECT_MSG_NOTIFICATION_LIMIT = 11 + JOB_OBJECT_MSG_END_OF_JOB_TIME uint32 = 1 + JOB_OBJECT_MSG_END_OF_PROCESS_TIME uint32 = 2 + JOB_OBJECT_MSG_ACTIVE_PROCESS_LIMIT uint32 = 3 + JOB_OBJECT_MSG_ACTIVE_PROCESS_ZERO uint32 = 4 + JOB_OBJECT_MSG_NEW_PROCESS uint32 = 6 + JOB_OBJECT_MSG_EXIT_PROCESS uint32 = 7 + JOB_OBJECT_MSG_ABNORMAL_EXIT_PROCESS uint32 = 8 + JOB_OBJECT_MSG_PROCESS_MEMORY_LIMIT uint32 = 9 + JOB_OBJECT_MSG_JOB_MEMORY_LIMIT uint32 = 10 + JOB_OBJECT_MSG_NOTIFICATION_LIMIT uint32 = 11 ) // IO limit flags @@ -24,9 +26,11 @@ const ( // https://docs.microsoft.com/en-us/windows/win32/api/jobapi2/ns-jobapi2-jobobject_io_rate_control_information const JOB_OBJECT_IO_RATE_CONTROL_ENABLE = 0x1 +const JOBOBJECT_IO_ATTRIBUTION_CONTROL_ENABLE uint32 = 0x1 + // https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-jobobject_cpu_rate_control_information const ( - JOB_OBJECT_CPU_RATE_CONTROL_ENABLE = 1 << iota + JOB_OBJECT_CPU_RATE_CONTROL_ENABLE uint32 = 1 << iota JOB_OBJECT_CPU_RATE_CONTROL_WEIGHT_BASED JOB_OBJECT_CPU_RATE_CONTROL_HARD_CAP JOB_OBJECT_CPU_RATE_CONTROL_NOTIFY @@ -41,7 +45,9 @@ const ( JobObjectBasicProcessIdList uint32 = 3 JobObjectBasicAndIoAccountingInformation uint32 = 8 JobObjectLimitViolationInformation uint32 = 13 + JobObjectMemoryUsageInformation uint32 = 28 JobObjectNotificationLimitInformation2 uint32 = 33 + JobObjectIoAttribution uint32 = 42 ) // https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-jobobject_basic_limit_information @@ -60,7 +66,7 @@ type JOBOBJECT_BASIC_LIMIT_INFORMATION struct { // https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-jobobject_cpu_rate_control_information type JOBOBJECT_CPU_RATE_CONTROL_INFORMATION struct { ControlFlags uint32 - Rate uint32 + Value uint32 } // https://docs.microsoft.com/en-us/windows/win32/api/jobapi2/ns-jobapi2-jobobject_io_rate_control_information @@ -80,9 +86,68 @@ type JOBOBJECT_BASIC_PROCESS_ID_LIST struct { ProcessIdList [1]uintptr } +// AllPids returns all the process Ids in the job object. +func (p *JOBOBJECT_BASIC_PROCESS_ID_LIST) AllPids() []uintptr { + return (*[(1 << 27) - 1]uintptr)(unsafe.Pointer(&p.ProcessIdList[0]))[:p.NumberOfProcessIdsInList] +} + +// https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-jobobject_basic_accounting_information +type JOBOBJECT_BASIC_ACCOUNTING_INFORMATION struct { + TotalUserTime int64 + TotalKernelTime int64 + ThisPeriodTotalUserTime int64 + ThisPeriodTotalKernelTime int64 + TotalPageFaultCount uint32 + TotalProcesses uint32 + ActiveProcesses uint32 + TotalTerminateProcesses uint32 +} + +//https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-jobobject_basic_and_io_accounting_information +type JOBOBJECT_BASIC_AND_IO_ACCOUNTING_INFORMATION struct { + BasicInfo JOBOBJECT_BASIC_ACCOUNTING_INFORMATION + IoInfo windows.IO_COUNTERS +} + +// typedef struct _JOBOBJECT_MEMORY_USAGE_INFORMATION { +// ULONG64 JobMemory; +// ULONG64 PeakJobMemoryUsed; +// } JOBOBJECT_MEMORY_USAGE_INFORMATION, *PJOBOBJECT_MEMORY_USAGE_INFORMATION; +// +type JOBOBJECT_MEMORY_USAGE_INFORMATION struct { + JobMemory uint64 + PeakJobMemoryUsed uint64 +} + +// typedef struct _JOBOBJECT_IO_ATTRIBUTION_STATS { +// ULONG_PTR IoCount; +// ULONGLONG TotalNonOverlappedQueueTime; +// ULONGLONG TotalNonOverlappedServiceTime; +// ULONGLONG TotalSize; +// } JOBOBJECT_IO_ATTRIBUTION_STATS, *PJOBOBJECT_IO_ATTRIBUTION_STATS; +// +type JOBOBJECT_IO_ATTRIBUTION_STATS struct { + IoCount uintptr + TotalNonOverlappedQueueTime uint64 + TotalNonOverlappedServiceTime uint64 + TotalSize uint64 +} + +// typedef struct _JOBOBJECT_IO_ATTRIBUTION_INFORMATION { +// ULONG ControlFlags; +// JOBOBJECT_IO_ATTRIBUTION_STATS ReadStats; +// JOBOBJECT_IO_ATTRIBUTION_STATS WriteStats; +// } JOBOBJECT_IO_ATTRIBUTION_INFORMATION, *PJOBOBJECT_IO_ATTRIBUTION_INFORMATION; +// +type JOBOBJECT_IO_ATTRIBUTION_INFORMATION struct { + ControlFlags uint32 + ReadStats JOBOBJECT_IO_ATTRIBUTION_STATS + WriteStats JOBOBJECT_IO_ATTRIBUTION_STATS +} + // https://docs.microsoft.com/en-us/windows/win32/api/winnt/ns-winnt-jobobject_associate_completion_port type JOBOBJECT_ASSOCIATE_COMPLETION_PORT struct { - CompletionKey uintptr + CompletionKey windows.Handle CompletionPort windows.Handle } diff --git a/internal/winapi/winapi.go b/internal/winapi/winapi.go index 545a87dad4..77ea13e3e3 100644 --- a/internal/winapi/winapi.go +++ b/internal/winapi/winapi.go @@ -2,4 +2,4 @@ // be thought of as an extension to golang.org/x/sys/windows. package winapi -//go:generate go run ..\..\mksyscall_windows.go -output zsyscall_windows.go net.go jobobject.go path.go logon.go memory.go processor.go devices.go filesystem.go errors.go +//go:generate go run ..\..\mksyscall_windows.go -output zsyscall_windows.go net.go iocp.go jobobject.go path.go logon.go memory.go processor.go devices.go filesystem.go errors.go diff --git a/internal/winapi/zsyscall_windows.go b/internal/winapi/zsyscall_windows.go index fabe19c12b..0a990951d7 100644 --- a/internal/winapi/zsyscall_windows.go +++ b/internal/winapi/zsyscall_windows.go @@ -48,6 +48,7 @@ var ( procQueryInformationJobObject = modkernel32.NewProc("QueryInformationJobObject") procOpenJobObjectW = modkernel32.NewProc("OpenJobObjectW") procSetIoRateControlInformationJobObject = modkernel32.NewProc("SetIoRateControlInformationJobObject") + procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") procSearchPathW = modkernel32.NewProc("SearchPathW") procLogonUserW = modadvapi32.NewProc("LogonUserW") procRtlMoveMemory = modkernel32.NewProc("RtlMoveMemory") @@ -129,6 +130,18 @@ func SetIoRateControlInformationJobObject(jobHandle windows.Handle, ioRateContro return } +func GetQueuedCompletionStatus(cphandle windows.Handle, qty *uint32, key *uintptr, overlapped **windows.Overlapped, timeout uint32) (err error) { + r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(cphandle), uintptr(unsafe.Pointer(qty)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(overlapped)), uintptr(timeout), 0) + if r1 == 0 { + if e1 != 0 { + err = errnoErr(e1) + } else { + err = syscall.EINVAL + } + } + return +} + func SearchPath(lpPath *uint16, lpFileName *uint16, lpExtension *uint16, nBufferLength uint32, lpBuffer *uint16, lpFilePath **uint16) (size uint32, err error) { r0, _, e1 := syscall.Syscall6(procSearchPathW.Addr(), 6, uintptr(unsafe.Pointer(lpPath)), uintptr(unsafe.Pointer(lpFileName)), uintptr(unsafe.Pointer(lpExtension)), uintptr(nBufferLength), uintptr(unsafe.Pointer(lpBuffer)), uintptr(unsafe.Pointer(lpFilePath))) size = uint32(r0)