planit

Declaratively build workflow DAGs for SLURM clusters. Define your pipeline as a tree of Step, Chain, and Parallel nodes, and planit handles dependency chains (afterok) and submission via submitit.

Installation

uv add git+https://github.com/WPoelman/planit.git
# or
pip install git+https://github.com/WPoelman/planit.git

As an aside, someone is name-squatting planit on PyPI. I requested a transfer...

Quick start

import logging

import submitit

from planit import Chain, Parallel, Plan, SlurmArgs, Step

logging.basicConfig(level=logging.INFO, format="%(message)s")

GPU = SlurmArgs(time="02:00:00", partition="gpu_a100", gpus_per_node=1, cpus_per_gpu=18)
CPU = SlurmArgs(time="01:00:00", partition="batch", cpus_per_task=8)

plan = Plan("experiment",
    Chain(
        Step("download", CPU, download_data),
        Step("preprocess", CPU, preprocess),
        Parallel(
            Step("train_model_a", GPU, train, "model_a"),
            Step("train_model_b", GPU, train, "model_b"),
        ),
        Step("evaluate", CPU, evaluate),
    ),
)

plan.describe()

executor = submitit.AutoExecutor(folder="slurm_logs")
plan.submit(executor)

describe() prints the DAG and a best-case time estimate:

Plan: experiment
└── ▼ Chain [5:00:00]
    ├── ● download [01:00:00]
    ├── ● preprocess [01:00:00]
    ├── ⇉ Parallel [2:00:00]
    │   ├── ● train_model_a [02:00:00]
    │   └── ● train_model_b [02:00:00]
    └── ● evaluate [01:00:00]
Time Estimate (not taking queuing into account): 5:00:00

Design

Nodes

A plan is a tree built from three Node types:

  • Step(name, slurm_args, func, *args, **kwargs): a single SLURM job. func is called with *args and **kwargs when the job runs.
  • Chain(*nodes): runs children one after another. Each child waits for the previous one to finish (afterok).
  • Parallel(*nodes): runs children concurrently. All children start after the same parent finishes.

With nesting, these can generate different kinds of workflow DAGs:

Chain(
    Step("setup", CPU, setup_fn), # no arguments for 'setup_fn'
    Parallel(
        Step("branch_a", GPU, work_a, "config_a"), # 'work_a' will be called as work_a("config_a")
        Chain(
            Step("branch_b_prep", CPU, prep_b, option=1), # prep_b(option=1)
            Step("branch_b_run", GPU, work_b),
        ),
    ),
    Step("aggregate", CPU, combine_results),
)

For example, you can declaratively define all the experiments you need to run:

Plan("thesis", Chain(
    Step("download_data", CPU, download),
    Parallel(
        Chain(
            Step("preprocess_en", CPU, preprocess, "en"),
            Parallel(
                # Gets called as train("en", "clm", epochs=10, lr=3e-5)
                Step("en_clm", GPU, train, "en", "clm", epochs=10, lr=3e-5),
                Step("en_mlm", GPU, train, "en", "mlm", epochs=10, lr=3e-5),
            ),
        ),
        Chain(
            Step("preprocess_nl", CPU, preprocess, "nl"),
            Parallel(
                Step("nl_clm", GPU, train, "nl", "clm", epochs=10, lr=3e-5),
                Step("nl_mlm", GPU, train, "nl", "mlm", epochs=10, lr=3e-5),
            ),
        ),
    ),
    Step("evaluate_all", CPU, evaluate),
    Step("generate_plots", CPU, plot),
))

If we assume CPU jobs take 1 hour and GPU jobs 2 hours, plan.describe() prints the following DAG:

Plan: thesis
└── ▼ Chain [6:00:00]
    ├── ● download_data [01:00:00]
    ├── ⇉ Parallel [3:00:00]
    │   ├── ▼ Chain [3:00:00]
    │   │   ├── ● preprocess_en [01:00:00]
    │   │   └── ⇉ Parallel [2:00:00]
    │   │       ├── ● en_clm [02:00:00]
    │   │       └── ● en_mlm [02:00:00]
    │   └── ▼ Chain [3:00:00]
    │       ├── ● preprocess_nl [01:00:00]
    │       └── ⇉ Parallel [2:00:00]
    │           ├── ● nl_clm [02:00:00]
    │           └── ● nl_mlm [02:00:00]
    ├── ● evaluate_all [01:00:00]
    └── ● generate_plots [01:00:00]
Time Estimate (not taking queuing into account): 6:00:00

The 6 hours is the critical path:

  • download_data (1h) +
  • slowest parallel branch (preprocess_* 1h + *_clm/*_mlm 2h = 3h) +
  • evaluate_all (1h) +
  • generate_plots (1h)
  • = 6h

This is of course without any potential queue time or jobs finishing early. It's only an estimate of the total requested time.
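The estimate can be reproduced by hand: a Chain sums its children's estimates, while a Parallel takes its slowest branch. A minimal sketch of that recursion (illustrative only, not planit's actual implementation), applied to the thesis plan above:

```python
from datetime import timedelta

def estimate(node):
    """Chain sums its children; Parallel takes the slowest branch."""
    kind, children = node
    times = [estimate(c) if isinstance(c, tuple) else c for c in children]
    return sum(times, timedelta()) if kind == "chain" else max(times)

H = timedelta(hours=1)  # assumed: CPU steps take 1h, GPU steps 2h

thesis = ("chain", [
    H,                                                 # download_data
    ("parallel", [
        ("chain", [H, ("parallel", [2 * H, 2 * H])]),  # en branch
        ("chain", [H, ("parallel", [2 * H, 2 * H])]),  # nl branch
    ]),
    H,                                                 # evaluate_all
    H,                                                 # generate_plots
])

print(estimate(thesis))  # 6:00:00
```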

Lastly, you can also programmatically generate parts of the DAG:

configs = ["small", "medium", "large"]

Plan("grid_search", Chain(
    Step("prepare", CPU, prepare_data),
    Parallel(*[Step(f"train_{cfg}", GPU, train, cfg) for cfg in configs]),
    Step("compare", CPU, compare_results),
))

This results in:

Plan: grid_search
└── ▼ Chain [4:00:00]
    ├── ● prepare [01:00:00]
    ├── ⇉ Parallel [2:00:00]
    │   ├── ● train_small [02:00:00]
    │   ├── ● train_medium [02:00:00]
    │   └── ● train_large [02:00:00]
    └── ● compare [01:00:00]
Time Estimate (not taking queuing into account): 4:00:00

This is especially powerful when you have a grid of experimental variables to search over. I like to use itertools for this:

import itertools

# Experimental variables
all_vars = [
    [1, 2, 3],       # first variable
    ["a", "b", "c"], # second variable
    [True, False],   # ...
]

steps = [
    Step(f"step_{idx}", CPU, func, *exp_args)
    for idx, exp_args in enumerate(itertools.product(*all_vars))
]

# This creates Steps that call `func` with:
#  func(1, 'a', True),
#  func(1, 'a', False),
#  func(1, 'b', True),
#  func(1, 'b', False),
#  func(1, 'c', True),
# ...
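The step_{idx} names above are opaque. If you prefer names that encode the variable values, you can derive them from the tuples themselves (a small sketch, independent of planit):

```python
import itertools

all_vars = [
    [1, 2, 3],       # first variable
    ["a", "b", "c"], # second variable
    [True, False],   # ...
]

# Build a descriptive name per configuration, e.g. "step_1_a_True",
# instead of an opaque running index.
names = [
    "step_" + "_".join(str(v) for v in exp_args)
    for exp_args in itertools.product(*all_vars)
]

print(names[0], names[-1])  # step_1_a_True step_3_c_False
```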

SLURM config

SlurmArgs

This is a dataclass for defining the SLURM parameters for your job:

from planit import MailType, SlurmArgs

args = SlurmArgs(
    time="03:00:00",
    partition="gpu_a100",
    gpus_per_node=1,
    cpus_per_gpu=18,
    cluster="wice",
    account="my-account",
    mail_type=[MailType.BEGIN, MailType.END, MailType.FAIL],
    mail_user="me@university.edu",
)

This is tailored to my own work on the VSC and may not cover all the configurations you might need on your cluster (see the next section for an alternative). CPU and GPU configurations I regularly use on the VSC are included in example_vsc_args.py.

Available fields:

Field              Type             Default   Description
time               str              required  Wall time (HH:MM:SS, MM:SS, or days-HH:MM:SS)
partition          str              required  SLURM partition
gpus_per_node      int              0         GPUs per node
nodes              int              1         Number of nodes
cpus_per_task      int | None       None      CPUs per task
cpus_per_gpu       int | None       None      CPUs per GPU
mem_gb             int | None       None      Memory in GB
account            str | None       None      Account name
cluster            str | None       None      Cluster name
mail_type          list[MailType]   []        Mail notification types
mail_user          str | None       None      Mail recipient
additional_params  dict             {}        Extra slurm_additional_parameters
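The three accepted time formats (HH:MM:SS, MM:SS, days-HH:MM:SS) can be parsed along these lines; this is a sketch of the idea, not planit's internals:

```python
from datetime import timedelta

def parse_slurm_time(value: str) -> timedelta:
    """Parse HH:MM:SS, MM:SS, or days-HH:MM:SS into a timedelta."""
    days = 0
    if "-" in value:
        day_part, value = value.split("-", 1)
        days = int(day_part)
    parts = [int(p) for p in value.split(":")]
    if len(parts) == 3:
        h, m, s = parts
    elif len(parts) == 2:
        h, (m, s) = 0, parts
    else:
        raise ValueError(f"unrecognized time format: {value!r}")
    return timedelta(days=days, hours=h, minutes=m, seconds=s)

print(parse_slurm_time("2-01:30:00"))  # 2 days, 1:30:00
```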

Raw dict

You can pass a raw dict directly if SlurmArgs doesn't fit your cluster. This dict should be compatible with submitit. It must (at least) include "slurm_time" so planit can estimate durations:

args = {
    "slurm_time": "02:00:00",
    "slurm_partition": "gpu_v100",
    "gpus_per_node": 1,
    "slurm_additional_parameters": {
        "clusters": "my-cluster",
        "account": "my-account",
        "cpus_per_gpu": 4,
    },
}
Step("train", args, train_fn)

Communication and errors

Communication between jobs is expected to happen through the filesystem. For example, one step writes a checkpoint file and the next step reads it. planit does not pass return values between steps; it only manages the dependency graph, slurm args, and submission.
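For instance, a hypothetical preprocess step can write its result to disk for a downstream train step to pick up (the paths and function names here are illustrative, not part of planit):

```python
import json
from pathlib import Path

OUT = Path("outputs")

def preprocess():
    # Upstream step: write the shared artifact to the filesystem.
    OUT.mkdir(exist_ok=True)
    (OUT / "corpus.json").write_text(json.dumps(["doc1", "doc2"]))

def train(model_name):
    # Downstream step: read the artifact back instead of receiving
    # a return value, then write its own output for later steps.
    corpus = json.loads((OUT / "corpus.json").read_text())
    (OUT / f"{model_name}.ckpt").write_text(f"trained on {len(corpus)} docs")
```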

planit only uses afterok dependencies: a job only starts if all its parents succeeded. If a job fails, SLURM automatically cancels all downstream dependents. If you need to re-run the entire plan, it's up to the step functions to know if something is already done! This is also part of the idea of letting communication happen through the filesystem.
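A common pattern is to make each step function idempotent: check for its own output file and skip the work if it already exists, so re-submitting the whole plan only redoes the missing pieces. A minimal sketch (illustrative, not part of planit):

```python
from pathlib import Path

def preprocess(out_file="preprocessed.txt"):
    out = Path(out_file)
    if out.exists():
        return "skipped"  # already done on a previous run
    out.write_text("expensive result")
    return "done"
```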

Debugging locally

You can use submitit's "debug" executor to run your plan locally without a cluster, keeping everything in one process:

executor = submitit.AutoExecutor(folder="slurm_logs", cluster="debug")
plan.submit(executor)
plan.wait()  # blocks until done, runs parallel branches concurrently

wait() walks the DAG structure using threads so parallel branches execute concurrently, mirroring real cluster behavior. This is intended for debugging and short jobs. On a real cluster you should probably not use wait(), unless you know queue times and job durations will be short.
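The idea behind a threaded walk over the tree can be sketched roughly like this (illustrative only, not planit's actual code): Chain children run in order, Parallel children each run in their own thread.

```python
from concurrent.futures import ThreadPoolExecutor

def run(node):
    kind, payload = node
    if kind == "step":
        payload()  # payload is the step's function here
    elif kind == "chain":
        for child in payload:   # sequential: each child waits for the last
            run(child)
    else:  # parallel
        with ThreadPoolExecutor() as pool:
            list(pool.map(run, payload))  # concurrent branches
```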


             _____
          .-'.  ':'-.
        .''::: .:    '.
       /   :::::'      \
      ;.    ':' `       ;
      |       '..       |
      ; '      ::::.    ;
       \       '::::   /
        '.      :::  .'
jgs        '-.___'_.-'

Credit: https://www.asciiart.eu/art/a5e06526e7b3ae4b
