Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 98 additions & 32 deletions dead_pool_reaper.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import (
crand "crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/rand"
"strings"
"time"

"github.com/gomodule/redigo/redis"
"go.uber.org/multierr"
)

const (
Expand All @@ -20,6 +20,24 @@ const (
requeueKeysPerJob = 4
)

// ReapResult is a set of data that reaper works with.
type ReapResult struct {
// Err is any errors during the reaper cycle.
Err error
// NoPoolHeartBeatJobs is a collection of job names that have been adjusted
// due to outdated worker pool heartbeats.
NoPoolHeartBeatJobs []string
// UnknownPoolJobs is a set of job names that have been adjusted because the
// worker pools working on them are not part of the overall set of worker pools.
UnknownPoolJobs []string
// DanglingLockJobs is a set of job names that have been adjusted due to
// inconsistency in their "lock" and "lock_info" keys.
DanglingLockJobs []string
}

// ReaperHook can be used to monitor the reaper's actions.
type ReaperHook func() (afterHook func(ReapResult))

type deadPoolReaper struct {
namespace string
pool Pool
Expand All @@ -29,9 +47,17 @@ type deadPoolReaper struct {

stopChan chan struct{}
doneStoppingChan chan struct{}

hook ReaperHook
}

func newDeadPoolReaper(namespace string, pool Pool, curJobTypes []string, reapPeriod time.Duration) *deadPoolReaper {
func newDeadPoolReaper(
namespace string,
pool Pool,
curJobTypes []string,
reapPeriod time.Duration,
hook ReaperHook,
) *deadPoolReaper {
if reapPeriod == 0 {
reapPeriod = defaultReapPeriod
}
Expand All @@ -44,6 +70,7 @@ func newDeadPoolReaper(namespace string, pool Pool, curJobTypes []string, reapPe
curJobTypes: curJobTypes,
stopChan: make(chan struct{}),
doneStoppingChan: make(chan struct{}),
hook: hook,
}
}

Expand Down Expand Up @@ -105,82 +132,106 @@ func (r *deadPoolReaper) reap() (err error) {
err = r.releaseLock(lockValue)
}()

rErr := r.reapDeadPools()
cErr := r.clearUnknownPools()
reapResult := ReapResult{}
if r.hook != nil {
finish := r.hook()

if finish != nil {
defer func() { finish(reapResult) }()
}
}

deadPools, rErr := r.reapDeadPools()
if jobs := deadPools.getAllJobs(); len(jobs) != 0 {
Logger.Printf("Reaper: dead pools: %v", deadPools)

reapResult.NoPoolHeartBeatJobs = jobs
}

unknownPools, cErr := r.clearUnknownPools()
if jobs := unknownPools.getAllJobs(); len(jobs) != 0 {
Logger.Printf("Reaper: unknown pools: %v", unknownPools)

reapResult.UnknownPoolJobs = jobs
}

// TODO: consider refactoring requeueInProgressJobs and cleanStaleLockInfo
// and removing removeDanglingLocks. There was a block where lock is 1 and
// lock_info is 0.
dErr := r.removeDanglingLocks()
jobs, dErr := r.removeDanglingLocks()
if len(jobs) != 0 {
Logger.Printf("Reaper: dangling locks: %v", jobs)

reapResult.DanglingLockJobs = jobs
}

return multierr.Combine(err, rErr, cErr, dErr)
reapResult.Err = errors.Join(err, rErr, cErr, dErr)

return reapResult.Err
}

// reapDeadPools collects the IDs of expired heartbeat pools and releases the
// associated resources.
func (r *deadPoolReaper) reapDeadPools() error {
deadPoolIDs, err := r.findDeadPools()
func (r *deadPoolReaper) reapDeadPools() (poolsJobs, error) {
deadPools, err := r.findDeadPools()
if err != nil {
return err
return nil, err
}

Logger.Printf("Reaper: dead pools: %v", deadPoolIDs)

conn := r.pool.Get()
defer conn.Close()

// Cleanup all dead pools
for deadPoolID, jobTypes := range deadPoolIDs {
for deadPoolID, jobTypes := range deadPools {
lockJobTypes := jobTypes
// if we found jobs from the heartbeat, requeue them and remove the heartbeat
if len(jobTypes) > 0 {
if err = r.requeueInProgressJobs(deadPoolID, jobTypes); err != nil {
return err
return deadPools, err
}

if _, err = conn.Do("DEL", redisKeyHeartbeat(r.namespace, deadPoolID)); err != nil {
return err
return deadPools, err
}
} else {
// try to clean up locks for the current set of jobs if heartbeat was not found
lockJobTypes = r.curJobTypes
deadPools[deadPoolID] = r.curJobTypes
}

// Cleanup any stale lock info
if err = r.cleanStaleLockInfo(deadPoolID, lockJobTypes); err != nil {
return err
return deadPools, err
}

// Remove dead pool from worker pools set
if _, err = conn.Do("SREM", redisKeyWorkerPools(r.namespace), deadPoolID); err != nil {
return err
return deadPools, err
}
}

return nil
return deadPools, nil
}

// clearUnknownPools enumerates the lock_info keys, collects pool IDs that are
// not in the worker_pools set, and releases associated locks.
func (r *deadPoolReaper) clearUnknownPools() error {
func (r *deadPoolReaper) clearUnknownPools() (poolsJobs, error) {
unknownPools, err := r.getUnknownPools()
if err != nil {
return err
return nil, err
}

Logger.Printf("Reaper: unknown pools: %v", unknownPools)

for poolID, jobTypes := range unknownPools {
if err = r.requeueInProgressJobs(poolID, jobTypes); err != nil {
return err
return unknownPools, err
}

if err = r.cleanStaleLockInfo(poolID, jobTypes); err != nil {
return err
return unknownPools, err
}
}

return nil
return unknownPools, nil
}

func (r *deadPoolReaper) cleanStaleLockInfo(poolID string, jobTypes []string) error {
Expand Down Expand Up @@ -240,7 +291,7 @@ func (r *deadPoolReaper) requeueInProgressJobs(poolID string, jobTypes []string)
}

// findDeadPools returns staled pools IDs and associated jobs.
func (r *deadPoolReaper) findDeadPools() (map[string][]string, error) {
func (r *deadPoolReaper) findDeadPools() (poolsJobs, error) {
conn := r.pool.Get()
defer conn.Close()

Expand All @@ -250,7 +301,7 @@ func (r *deadPoolReaper) findDeadPools() (map[string][]string, error) {
return nil, err
}

deadPools := make(map[string][]string, len(workerPoolIDs))
deadPools := make(poolsJobs, len(workerPoolIDs))
for _, workerPoolID := range workerPoolIDs {
heartbeatKey := redisKeyHeartbeat(r.namespace, workerPoolID)
heartbeatAt, err := redis.Int64(conn.Do("HGET", heartbeatKey, "heartbeat_at"))
Expand Down Expand Up @@ -284,7 +335,7 @@ func (r *deadPoolReaper) findDeadPools() (map[string][]string, error) {

// getUnknownPools returns the IDs of the unknown pools and associated job types
// found in the lock_info keys.
func (r *deadPoolReaper) getUnknownPools() (map[string][]string, error) {
func (r *deadPoolReaper) getUnknownPools() (poolsJobs, error) {
scriptArgs := make([]interface{}, 0, len(r.curJobTypes)+2) // +2 for keys count and pools key
scriptArgs = append(scriptArgs, len(r.curJobTypes)+1) // +1 for pools key
scriptArgs = append(scriptArgs, redisKeyWorkerPools(r.namespace))
Expand All @@ -301,7 +352,7 @@ func (r *deadPoolReaper) getUnknownPools() (map[string][]string, error) {
return nil, err
}

var pools map[string][]string
var pools poolsJobs

if err := json.Unmarshal(data, &pools); err != nil {
return nil, err
Expand All @@ -323,7 +374,7 @@ func (r *deadPoolReaper) getUnknownPools() (map[string][]string, error) {

// removeDanglingLocks adjusts the lock keys according to the lock_info numbers.
// TODO: it's better to find where the inconsistency comes from.
func (r *deadPoolReaper) removeDanglingLocks() error {
func (r *deadPoolReaper) removeDanglingLocks() ([]string, error) {
keysCount := len(r.curJobTypes) * 2 // lock and lock_info keys
scriptArgs := make([]interface{}, 0, keysCount+1) // +1 for keys count arg
scriptArgs = append(scriptArgs, keysCount)
Expand All @@ -338,12 +389,15 @@ func (r *deadPoolReaper) removeDanglingLocks() error {

keys, err := redis.Strings(redisRemoveDanglingLocksScript.Do(conn, scriptArgs...))
if err != nil {
return err
return nil, err
}

Logger.Printf("Reaper: dangling locks: %v", keys)
// convert lock keys to job types
for i, k := range keys {
keys[i] = redisJobNameFromLockKey(r.namespace, k)
}

return nil
return keys, nil
}

// acquireLock acquires lock with a value and an expiration time for reap period.
Expand Down Expand Up @@ -380,3 +434,15 @@ func genValue() (string, error) {

return base64.StdEncoding.EncodeToString(b), nil
}

type poolsJobs map[string][]string

func (p poolsJobs) getAllJobs() []string {
r := make([]string, 0, len(p))

for _, jobs := range p {
r = append(r, jobs...)
}

return r
}
Loading