diff --git a/tests/functional/dpo.sh b/tests/functional/dpo.sh index 1431f17e61..2421c5da6a 100755 --- a/tests/functional/dpo.sh +++ b/tests/functional/dpo.sh @@ -22,6 +22,8 @@ python -u $PROJECT_ROOT/examples/run_dpo.py \ cluster.gpus_per_node=2 \ dpo.max_num_steps=3 \ dpo.val_batches=1 \ + dpo.val_global_batch_size=8 \ + policy.train_global_batch_size=8 \ logger.tensorboard_enabled=true \ logger.log_dir=$LOG_DIR \ logger.wandb_enabled=false \