Public implementation for the paper "Generalized Behavior Learning from Diverse Demonstrations", published at ICLR 2025!
Code for an earlier workshop version can be found on the corlw branch.
Our implementation uses PyTorch. We recommend setting up a conda environment; env.yaml provides the environment file.
conda env create -f env.yaml
Additionally, mujoco210 is required. Please download it and set it up from the official source.
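As a rough sketch of a typical mujoco210 setup on Linux (the archive name and install path below are the conventional ones from the official release; adjust for your system):

```shell
# A sketch of the conventional mujoco210 install on Linux; assumes
# mujoco210-linux-x86_64.tar.gz has been downloaded from the official release.
mkdir -p ~/.mujoco
tar -xzf mujoco210-linux-x86_64.tar.gz -C ~/.mujoco
# mujoco-py locates the native library via this path at import time.
export LD_LIBRARY_PATH="$HOME/.mujoco/mujoco210/bin:$LD_LIBRARY_PATH"
```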
To run models for any of the domains, start from the provided base commands and modify the command-line arguments as indicated. For sweep ranges and exact hyperparameter values for each domain and setting, please refer to Table 2 in the Appendix.
The importance weight of diversity, lambda_I, is implemented as 1 - dr_cc; for example, to set lambda_I = 0.2, use --dr_cc 0.8. For the regularization methods (SN, GSD), the regularization magnitude lambda_C is implemented as dl_scale; for example, to set lambda_C = 0.1, use --dl_scale 0.1. Distillation strength can be set with --crew_regcf.
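For reference, the mapping from paper notation to flag values can be written down as a tiny helper (hypothetical; the scripts themselves simply take the flag values):

```python
# Hypothetical helper translating paper hyperparameters into vild_main.py
# flag values; the scripts themselves take the flag values directly.

def flags_from_paper(lambda_i, lambda_c=None, distill=None):
    """Map paper notation (lambda_I, lambda_C, distillation strength) to flags."""
    flags = {"--dr_cc": 1.0 - lambda_i}  # lambda_I is implemented as 1 - dr_cc
    if lambda_c is not None:
        flags["--dl_scale"] = lambda_c   # lambda_C maps directly to dl_scale
    if distill is not None:
        flags["--crew_regcf"] = distill  # distillation strength
    return flags

# e.g. lambda_I = 0.2  ->  --dr_cc 0.8
```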
We provide the base commands below for each domain to train InfoGAIL (IG).
python code/vild_main.py \
--env_id -5 --c_data 1 --v_data 1 --max_step 2500000 --big_batch_size 1000 \
--bc_step 0 \
--il_method infogsdr --rl_method ppo \
--encode_dim 2 --encode_sampling normal \
--info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
--ac_dec 0 --ac_rew 1 \
--dr_cc 0.9 --gp_lambda 0.01 --learning_rate_d 1e-3 \
--p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
--reg_dec 0 --sn_dec 0 \
--dl_scale 0.0001 --dl_linit 100 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 1 \
--cond_rew 0 --dl_ztype prior \
--seed 1 --nthreads 1
HalfCheetah:
python code/vild_main.py \
--env_id -20 --c_data 1 --max_step 15000000 \
--il_method infogsdr --rl_method ppo \
--wd_policy 0 \
--encode_dim 2 --encode_sampling normal \
--info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
--ac_dec 1 --ac_rew 1 \
--dr_cc 0.9 --gp_lambda 0.1 --learning_rate_d 1e-3 \
--p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
--reg_dec 0 --sn_dec 0 \
--dl_scale 0.001 --dl_linit 100 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 0 \
--cond_rew 0 --dl_ztype prior \
--seed 1 --nthreads 2
DriveLaneshift:
python code/vild_main.py \
--env_id -51 --c_data 1 --max_step 5000000 \
--il_method infogsdr --rl_method ppobc --bc_cf 0.1 --big_batch_size 1000 \
--wd_policy 1e-4 --warmstart 0 \
--encode_dim 2 --encode_sampling normal \
--info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
--ac_dec 1 --ac_rew 1 \
--dr_cc 0.9 --gp_lambda 0.1 --learning_rate_d 1e-3 \
--p_step 1 --lr_p 1e-4 --wd_p 0 --lr_dk 0 --dec_gclip 25 --post_rclip 1 \
--reg_dec 0 --sn_dec 0 \
--dl_scale 0.001 --dl_linit 500 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 0 \
--cond_rew 0 --dl_ztype prior \
--seed 1 --nthreads 1
FetchPickPlace:
python code/vild_main.py \
--env_id -43 --c_data 1 --max_step 10000000 \
--il_method infogsdr --rl_method ppobc --bc_cf 0.1 --norm_obs 1 \
--wd_policy 1e-4 \
--encode_dim 2 --encode_sampling normal \
--info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
--ac_dec 1 --ac_rew 1 \
--dr_cc 0.9 --gp_lambda 0.1 --learning_rate_d 1e-3 \
--p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
--reg_dec 0 --sn_dec 0 \
--dl_scale 0.001 --dl_linit 5000 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 0 \
--cond_rew 0 --dl_ztype prior \
--seed 1 --nthreads 2
TableTennis:
python code/vild_main.py \
--env_id -60 --c_data 0 --v_data 5 --max_step 20000000 \
--il_method infogsdr --rl_method trpo --norm_obs 1 \
--wd_policy 0 --warmstart 0 \
--encode_dim 2 --encode_sampling normal \
--info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
--ac_dec 1 --ac_rew 1 \
--dr_cc 0.9 --gp_lambda 0.1 --learning_rate_d 1e-3 --d_step 5 \
--p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
--reg_dec 0 --sn_dec 0 \
--dl_scale 0.001 --dl_linit 100 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 0 \
--cond_rew 0 --dl_ztype prior \
--seed 1 --nthreads 1
Interpolation
--c_data 1
Extrapolation
--c_data 2
IG+Lipz
--sn_dec 1
IG+Con
--cond_rew 1 --crew_grew 0
--dec_ezsample 0 --encode_ezsample 0 --crew_expxm 2
IG+ConDist
--cond_rew 1 --crew_grew 1 --crew_regcf 0.02 --clip_discriminator 10
--dec_ezsample 0 --encode_ezsample 0 --crew_expxm 2
IG+ConDist+Lipz
--cond_rew 1 --crew_grew 1 --crew_regcf 0.02 --clip_discriminator 10
--dec_ezsample 0 --encode_ezsample 0 --crew_expxm 2
--sn_dec 1
GSD (Ours)
--cond_rew 1 --crew_grew 1 --crew_regcf 0.02 --clip_discriminator 10
--dec_ezsample 0 --encode_ezsample 0 --crew_expxm 2
--sn_dec 1
--reg_dec 1 --dl_type disc --dl_l2m 1 --dl_ztype nglobal --dl_linit 100
Please use the command below to perform evaluation, setting the arguments as indicated for each domain. The sampling process is parallelized by default; to set the desired number of threads, modify code/run_model.py:L630 to set NPARALLEL=<desired_nthreads>.
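If you prefer not to edit the file by hand, a one-liner such as the following can make the change (a sketch; it assumes the assignment in run_model.py looks like `NPARALLEL = <n>`):

```shell
# Rewrite the NPARALLEL assignment in place (assumes a line of the form
# "NPARALLEL = <n>"; adjust the pattern if the file differs). Here, 8 threads.
sed -i -E 's/^(NPARALLEL[[:space:]]*=).*/\1 8/' code/run_model.py
```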
python code/run_model.py \
--env_id -20 --c_data 1 \
--mode prior --num_eps 1 --bgt_info `seq 10 10 51` --num_info 1500 \
--test_seed 1 \
--ckptpath results_IL/path/to/dir/ckpt_policy_T15000000.pt
HalfCheetah: --env_id -20 --num_eps 5
DriveLaneshift: --env_id -51
FetchPickPlace: --env_id -43
TableTennis: --env_id -60 --c_data 0 --mode vstats --num_eps 5 --num_info 200
The same commands apply to both generalization settings. The script logs a significant amount of information to the console; the lines of interest take the following forms:
VL<K>-<GTD> <mean> <std> - the recovery metric (least MAE) reported in the paper.
RT-VL<K>-<GTD> <mean> <std> <max> <min> - the reward obtained by the behavior corresponding to the z that minimized MAE with the desired GT factor value.
DV50 <val> (TableTennis only) - the entropy calculated over the set of GT factor values.
Here, <K> denotes the number of samples considered, and <GTD> corresponds to the desired GT factor value among [1, 2, 3, 4, 5] (canonicalized across domains).
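Should you want to aggregate these lines yourself, the recovery-metric lines can be pulled out with a small parser (hypothetical; print_latex.py below is the supported route):

```python
import re

# Hypothetical parser for the "VL<K>-<GTD> <mean> <std>" log lines described
# above; the repo's print_latex.py is the supported aggregation route.
VL_RE = re.compile(r"^VL(\d+)-(\d+)\s+(\S+)\s+(\S+)")

def parse_vl(line):
    """Return (K, GTD, mean, std) for a VL line, or None for other lines."""
    m = VL_RE.match(line)
    if m is None:
        return None
    k, gtd, mean, std = m.groups()
    return int(k), int(gtd), float(mean), float(std)
```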
These values should be averaged over training seeds to construct the figures in the paper. The commands below do so from log files, one per domain; each log file contains the console outputs for all methods across 5 training seeds.
Interpolation:
python code/print_latex.py 1 runc0.log
python code/print_latex.py 1 rund0.log
python code/print_latex.py 1 runf0.log
Extrapolation:
python code/print_latex.py 3 runc2.log
python code/print_latex.py 3 rund2.log
python code/print_latex.py 3 runf2.log
Entropy (TableTennis):
python code/print_latex.py 0 runt.log
Code is available under the MIT License.
If you use this work and/or this codebase in your research, we would greatly appreciate being cited as shown below!
@inproceedings{sreeramdass2025generalized,
title={Generalized Behavior Learning from Diverse Demonstrations},
author={Sreeramdass, Varshith and Paleja, Rohan R and Chen, Letian and Van Waveren, Sanne and Gombolay, Matthew},
booktitle={International Conference on Learning Representations},
year={2025},
url={https://openreview.net/pdf?id=Q7EjHroO1w}
}