14 changes: 8 additions & 6 deletions README.md
@@ -60,20 +60,22 @@ We provide a sample SFT experiment that uses the [SQuAD dataset](https://rajpurk

#### Single Node

The experiment is set up to run on 8 GPUs. If using a machine that has access to 8 GPUs, you can launch the experiment as follows:
The default SFT experiment is configured to run on a single GPU. To launch the experiment:

```sh
uv run python examples/run_sft.py
```

This trains `Llama3.1-8B` on 8 GPUs. To run on a single GPU, we'll have to override a few of the experiment settings. We replace the 8B model with a smaller 1B model, decrease the batch size, and update the cluster configuration to use a single GPU:
This trains `Llama-3.2-1B` on a single GPU using the SQuAD dataset.

If you have access to more GPUs, you can scale the experiment up accordingly. To run on 8 GPUs, we update the cluster configuration, switch to an 8B Llama base model, and increase the batch size:

```sh
uv run python examples/run_sft.py \
policy.model_name="meta-llama/Llama-3.2-1B" \
policy.train_global_batch_size=16 \
sft.val_global_batch_size=16 \
cluster.gpus_per_node=1
policy.model_name="meta-llama/Meta-Llama-3-8B" \
policy.train_global_batch_size=128 \
sft.val_global_batch_size=128 \
cluster.gpus_per_node=8
```

Refer to [sft.yaml](examples/configs/sft.yaml) for a full list of parameters that can be overridden.
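
The CLI overrides above map one-to-one onto keys in the YAML config, so the 8-GPU invocation is equivalent to editing the file directly. A minimal sketch of the corresponding fragment, assuming the config structure shown in this PR's diff (only the changed keys are listed):

```yaml
policy:
  model_name: "meta-llama/Meta-Llama-3-8B"
  train_global_batch_size: 128

sft:
  val_global_batch_size: 128

cluster:
  gpus_per_node: 8
```

Editing the YAML is convenient for settings you change once; CLI overrides are better suited to per-run sweeps.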
27 changes: 10 additions & 17 deletions examples/configs/sft.yaml
@@ -1,9 +1,9 @@
# SFT Algorithm Configuration
sft:
max_num_steps: 1000
max_num_steps: 60
val_period: 10
val_batches: 8
val_global_batch_size: 128
val_global_batch_size: 32
val_micro_batch_size: 1
val_at_start: true
seed: 42
@@ -17,10 +17,10 @@ checkpointing:
save_period: 10

policy:
model_name: "meta-llama/Meta-Llama-3-8B"
train_global_batch_size: 128
model_name: "meta-llama/Llama-3.2-1B"
train_global_batch_size: 32
train_micro_batch_size: 1
max_total_sequence_length: 2048
max_total_sequence_length: 1024
precision: "float32"

optimizer:
@@ -30,32 +30,25 @@ policy:
weight_decay: 0.1
betas: [0.9, 0.98]
eps: 1e-5

scheduler:
name: "torch.optim.lr_scheduler.LinearLR"
kwargs:
start_factor: 0.0196078
end_factor: 1.0
total_iters: 50

data:
max_input_seq_length: ${policy.max_total_sequence_length}
dataset_name: "squad"

logger:
log_dir: "logs" # Base directory for all logs
wandb_enabled: false
tensorboard_enabled: false
wandb_enabled: true # Make sure to run "wandb login [Your API key]" before launching
tensorboard_enabled: true
monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
wandb:
project: "sft-dev"
name: "sft-dev-logger"
name: "sft-dev-${data.dataset_name}"
tensorboard:
log_dir: "tb_logs"
log_dir: "tb_logs-sft-dev-${data.dataset_name}"
gpu_monitoring:
collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
gpus_per_node: 8
gpus_per_node: 1
num_nodes: 1
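
The `cluster` block is also where multi-node runs are configured. A minimal sketch extending the two keys above to two 8-GPU nodes (values are illustrative; launcher-specific settings, if any, are not shown):

```yaml
cluster:
  gpus_per_node: 8
  num_nodes: 2
```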