Declaratively build workflow DAGs for SLURM clusters. Define your pipeline as a tree of `Step`, `Chain`, and `Parallel` nodes, and planit handles the dependency chains (`afterok`) and submission via submitit.
```sh
uv add git+https://github.com/WPoelman/planit.git
# or
pip install git+https://github.com/WPoelman/planit.git
```

As an aside, someone is name-squatting `planit` on PyPI. I requested a transfer...
```python
import logging

import submitit

from planit import Parallel, Plan, Chain, SlurmArgs, Step

logging.basicConfig(level=logging.INFO, format="%(message)s")

GPU = SlurmArgs(time="02:00:00", partition="gpu_a100", gpus_per_node=1, cpus_per_gpu=18)
CPU = SlurmArgs(time="01:00:00", partition="batch", cpus_per_task=8)

plan = Plan("experiment",
    Chain(
        Step("download", CPU, download_data),
        Step("preprocess", CPU, preprocess),
        Parallel(
            Step("train_model_a", GPU, train, "model_a"),
            Step("train_model_b", GPU, train, "model_b"),
        ),
        Step("evaluate", CPU, evaluate),
    ),
)

plan.describe()

executor = submitit.AutoExecutor(folder="slurm_logs")
plan.submit(executor)
```

`describe()` prints the DAG and a best-case time estimate:
```
Plan: experiment
└── ▼ Chain [5:00:00]
    ├── ● download [01:00:00]
    ├── ● preprocess [01:00:00]
    ├── ⇉ Parallel [2:00:00]
    │   ├── ● train_model_a [02:00:00]
    │   └── ● train_model_b [02:00:00]
    └── ● evaluate [01:00:00]
Time Estimate (not taking queuing into account): 5:00:00
```
A plan is a tree built from three `Node` types:

- `Step(name, slurm_args, func, *args, **kwargs)`: a single SLURM job. `func` is called with `*args` and `**kwargs` when the job runs.
- `Chain(*nodes)`: runs children one after another. Each child waits for the previous one to finish (`afterok`).
- `Parallel(*nodes)`: runs children concurrently. All children start after the same parent finishes.
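To build intuition for how nesting translates into `afterok` edges, here is a simplified sketch (not planit's actual implementation) that derives the dependency pairs from a nested chain/parallel structure, using plain tuples instead of planit's node classes:

```python
# Hypothetical sketch: how nested chain/parallel structures translate into
# afterok dependency edges. Nodes are tuples: ("step", name),
# ("chain", *children), or ("parallel", *children).

def leaves(node):
    """Return the step names a dependent would have to wait on."""
    kind, children = node[0], node[1:]
    if kind == "step":
        return [children[0]]
    if kind == "chain":
        return leaves(children[-1])  # only the last child must finish
    # parallel: every branch must finish
    return [name for child in children for name in leaves(child)]

def edges(node, parents=()):
    """Yield (step_name, depends_on) pairs."""
    kind, children = node[0], node[1:]
    if kind == "step":
        yield children[0], list(parents)
    elif kind == "chain":
        for child in children:
            yield from edges(child, parents)
            parents = leaves(child)  # the next child waits on this one
    else:  # parallel: all branches share the same parents
        for child in children:
            yield from edges(child, parents)

plan = ("chain",
        ("step", "a"),
        ("parallel", ("step", "b"), ("step", "c")),
        ("step", "d"))
deps = dict(edges(plan))
# {'a': [], 'b': ['a'], 'c': ['a'], 'd': ['b', 'c']}
```

So `b` and `c` both wait on `a`, while `d` waits on both parallel branches.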
With nesting, these can generate different kinds of workflow DAGs:
```python
Chain(
    Step("setup", CPU, setup_fn),  # no arguments for 'setup_fn'
    Parallel(
        Step("branch_a", GPU, work_a, "config_a"),  # 'work_a' will be called as work_a("config_a")
        Chain(
            Step("branch_b_prep", CPU, prep_b, option=1),  # prep_b(option=1)
            Step("branch_b_run", GPU, work_b),
        ),
    ),
    Step("aggregate", CPU, combine_results),
)
```

For example, you can declaratively define all the experiments you need to run:
```python
Plan("thesis", Chain(
    Step("download_data", CPU, download),
    Parallel(
        Chain(
            Step("preprocess_en", CPU, preprocess, "en"),
            Parallel(
                # Gets called as train("en", "clm", epochs=10, lr=3e-5)
                Step("en_clm", GPU, train, "en", "clm", epochs=10, lr=3e-5),
                Step("en_mlm", GPU, train, "en", "mlm", epochs=10, lr=3e-5),
            ),
        ),
        Chain(
            Step("preprocess_nl", CPU, preprocess, "nl"),
            Parallel(
                Step("nl_clm", GPU, train, "nl", "clm", epochs=10, lr=3e-5),
                Step("nl_mlm", GPU, train, "nl", "mlm", epochs=10, lr=3e-5),
            ),
        ),
    ),
    Step("evaluate_all", CPU, evaluate),
    Step("generate_plots", CPU, plot),
))
```

If we assume a CPU job takes 1 hour and a GPU job 2, this generates the following DAG with `plan.describe()`:
```
Plan: thesis
└── ▼ Chain [6:00:00]
    ├── ● download_data [01:00:00]
    ├── ⇉ Parallel [3:00:00]
    │   ├── ▼ Chain [3:00:00]
    │   │   ├── ● preprocess_en [01:00:00]
    │   │   └── ⇉ Parallel [2:00:00]
    │   │       ├── ● en_clm [02:00:00]
    │   │       └── ● en_mlm [02:00:00]
    │   └── ▼ Chain [3:00:00]
    │       ├── ● preprocess_nl [01:00:00]
    │       └── ⇉ Parallel [2:00:00]
    │           ├── ● nl_clm [02:00:00]
    │           └── ● nl_mlm [02:00:00]
    ├── ● evaluate_all [01:00:00]
    └── ● generate_plots [01:00:00]
Time Estimate (not taking queuing into account): 6:00:00
```
The 6 hours is the critical path: `download_data` (1h) + the slowest parallel branch (`preprocess_*` 1h + `*_clm`/`*_mlm` 2h = 3h) + `evaluate_all` (1h) + `generate_plots` (1h) = 6h.
This is of course without any potential queue time or jobs finishing early. It's only an estimate of the total requested time.
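Conceptually, such a best-case estimate is just a recursive fold over the tree: Chains sum their children's durations, Parallels take the slowest branch. A hypothetical sketch (not planit's implementation), with durations in hours and nodes as plain tuples:

```python
# Best-case time estimate: chain = sum of children, parallel = max of
# children. ("step", hours) leaves carry the requested wall time.

def estimate(node):
    kind, children = node[0], node[1:]
    if kind == "step":
        return children[0]                         # the step's wall time
    if kind == "chain":
        return sum(estimate(c) for c in children)  # sequential: add up
    return max(estimate(c) for c in children)      # parallel: slowest branch

# The "thesis" plan from above, in this toy representation:
thesis = ("chain",
          ("step", 1),                                                       # download_data
          ("parallel",
           ("chain", ("step", 1), ("parallel", ("step", 2), ("step", 2))),   # en branch
           ("chain", ("step", 1), ("parallel", ("step", 2), ("step", 2)))),  # nl branch
          ("step", 1),                                                       # evaluate_all
          ("step", 1))                                                       # generate_plots

print(estimate(thesis))  # 6
```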
Lastly, you can also programmatically generate parts of the DAG:
```python
configs = ["small", "medium", "large"]

Plan("grid_search", Chain(
    Step("prepare", CPU, prepare_data),
    Parallel(*[Step(f"train_{cfg}", GPU, train, cfg) for cfg in configs]),
    Step("compare", CPU, compare_results),
))
```

Which results in:
```
Plan: grid_search
└── ▼ Chain [4:00:00]
    ├── ● prepare [01:00:00]
    ├── ⇉ Parallel [2:00:00]
    │   ├── ● train_small [02:00:00]
    │   ├── ● train_medium [02:00:00]
    │   └── ● train_large [02:00:00]
    └── ● compare [01:00:00]
Time Estimate (not taking queuing into account): 4:00:00
```
This is especially powerful when you have experimental variables whose combinations all need to be run. I like to use `itertools` for this:
```python
import itertools

# Experimental variables
all_vars = [
    [1, 2, 3],        # first variable
    ["a", "b", "c"],  # second variable
    [True, False],    # ...
]

steps = [
    Step(f"step_{idx}", CPU, func, *exp_args)
    for idx, exp_args in enumerate(itertools.product(*all_vars))
]

# This creates Steps that call `func` with:
#   func(1, 'a', True),
#   func(1, 'a', False),
#   func(1, 'b', True),
#   func(1, 'b', False),
#   func(1, 'c', True),
#   ...
```

`SlurmArgs` is a dataclass for defining the SLURM parameters for your job:
```python
from planit import MailType, SlurmArgs

args = SlurmArgs(
    time="03:00:00",
    partition="gpu_a100",
    gpus_per_node=1,
    cpus_per_gpu=18,
    cluster="wice",
    account="my-account",
    mail_type=[MailType.BEGIN, MailType.END, MailType.FAIL],
    mail_user="me@university.edu",
)
```

This is tailored to my own work on the VSC and may not cover all the configurations you might need on your cluster (see the next section for an alternative). CPU and GPU configurations I regularly use on the VSC are included in `example_vsc_args.py`.
Available fields:

| Field | Type | Default | Description |
|---|---|---|---|
| `time` | `str` | required | Wall time (`HH:MM:SS`, `MM:SS`, or `days-HH:MM:SS`) |
| `partition` | `str` | required | SLURM partition |
| `gpus_per_node` | `int` | `0` | GPUs per node |
| `nodes` | `int` | `1` | Number of nodes |
| `cpus_per_task` | `int \| None` | `None` | CPUs per task |
| `cpus_per_gpu` | `int \| None` | `None` | CPUs per GPU |
| `mem_gb` | `int \| None` | `None` | Memory in GB |
| `account` | `str \| None` | `None` | Account name |
| `cluster` | `str \| None` | `None` | Cluster name |
| `mail_type` | `list[MailType]` | `[]` | Mail notification types |
| `mail_user` | `str \| None` | `None` | Mail recipient |
| `additional_params` | `dict` | `{}` | Extra `slurm_additional_parameters` |
You can pass a raw dict directly if `SlurmArgs` doesn't fit your cluster. This dict should be compatible with submitit. It must (at least) include `"slurm_time"` so planit can estimate durations:

```python
args = {
    "slurm_time": "02:00:00",
    "slurm_partition": "gpu_v100",
    "gpus_per_node": 1,
    "slurm_additional_parameters": {
        "clusters": "my-cluster",
        "account": "my-account",
        "cpus_per_gpu": 4,
    },
}

Step("train", args, train_fn)
```

Communication between jobs is expected to happen through the filesystem: for example, one step writes a checkpoint file and the next step reads it. planit does not pass return values between steps; it only manages the dependency graph, SLURM args, and submission.
planit only uses `afterok` dependencies: a job only starts if all its parents succeeded. If a job fails, SLURM automatically cancels all downstream dependents. If you need to re-run the entire plan, it's up to the step functions to know whether something is already done! This is also part of the idea of letting communication happen through the filesystem.
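For example, step functions can hand results to each other through files and return early when their output already exists, so the whole plan can be safely re-submitted. A hypothetical sketch (the paths and function names are illustrative, not part of planit):

```python
# Step functions that communicate via the filesystem and skip work that
# a previous run already completed.
from pathlib import Path

DATA = Path("data")

def preprocess(lang: str) -> None:
    out = DATA / f"{lang}.processed"
    if out.exists():  # done in a previous run: skip
        return
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(f"processed {lang}\n")

def train(lang: str) -> None:
    # Handoff from the previous step happens through the file it wrote.
    corpus = (DATA / f"{lang}.processed").read_text()
    (DATA / f"{lang}.model").write_text(f"model trained on: {corpus}")
```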
You can use submitit's "debug" executor to run your plan locally without a cluster (keeping everything in one process):

```python
executor = submitit.AutoExecutor(folder="slurm_logs", cluster="debug")
plan.submit(executor)
plan.wait()  # blocks until done, runs parallel branches concurrently
```

`wait()` walks the DAG structure using threads, so parallel branches execute concurrently, mirroring real cluster behavior. This is intended for debugging and short jobs. On a real cluster you should probably not use `wait()`, unless you know queue times and job durations will be short.
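The threaded traversal pattern can be illustrated with a small standalone sketch (not planit's implementation): Chain children run one after another, while Parallel branches are submitted to a thread pool and joined.

```python
# Simplified illustration of walking a chain/parallel tree with threads.
# Nodes are tuples: ("step", func), ("chain", *children), ("parallel", *children).
from concurrent.futures import ThreadPoolExecutor

def run(node):
    kind, children = node[0], node[1:]
    if kind == "step":
        children[0]()                       # run the step's function
    elif kind == "chain":
        for child in children:              # sequential
            run(child)
    else:  # parallel: branches run concurrently
        with ThreadPoolExecutor() as pool:
            for fut in [pool.submit(run, c) for c in children]:
                fut.result()                # join; re-raises any exception

results = []
run(("chain",
     ("step", lambda: results.append("prep")),
     ("parallel",
      ("step", lambda: results.append("a")),
      ("step", lambda: results.append("b"))),
     ("step", lambda: results.append("done"))))
# results starts with "prep" and ends with "done"; "a"/"b" order may vary
```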
```
       _____
    .-'.  ':'-.
  .''::: .:    '.
 /   :::::'      \
;.    ':' `      ;
|       '..      |
; '      ::::.   ;
 \       '::::   /
  '.      :::  .'
jgs '-.___'_.-'
```