
Guided Strategy Discovery

Public implementation for the paper Generalized Behavior Learning from Diverse Demonstrations, published at ICLR 2025!

Code for an earlier workshop version can be found on the branch corlw.

Setup

Our implementation uses PyTorch. We recommend setting up a conda environment; env.yaml provides the environment file.

conda env create -f env.yaml

Additionally, mujoco210 is required. Please download it and set it up from the official source.
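A minimal sketch of a typical mujoco210 setup on Linux follows (the download URL, install location, and LD_LIBRARY_PATH export are assumptions based on the standard mujoco-py convention; follow the official instructions for your platform):

# download and extract MuJoCo 2.1.0 to the location expected by mujoco-py
mkdir -p ~/.mujoco
wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz
tar -xzf mujoco210-linux-x86_64.tar.gz -C ~/.mujoco
# make the MuJoCo shared libraries visible at runtime
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin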

Training

To train models for any of the domains, start from the provided base commands and modify command line arguments as indicated. For sweep ranges and exact hyperparameter values for each domain and setting, please refer to Table 2 in the Appendix.

The importance weight of diversity, lambda_I, is implemented as 1 - dr_cc; for example, to set lambda_I=0.2, use --dr_cc 0.8. For regularization methods (SN, GSD), the regularization magnitude lambda_C is implemented as dl_scale; for example, to set lambda_C=0.1, use --dl_scale 0.1. Distillation strength can be set with --crew_regcf. An example of applying this mapping follows the first base command below.

Base commands

We provide the base commands for each domain to train InfoGAIL (IG).

PointMaze

python code/vild_main.py \
    --env_id -5 --c_data 1 --v_data 1 --max_step 2500000 --big_batch_size 1000 \
    --bc_step 0 \
    --il_method infogsdr --rl_method ppo \
    --encode_dim 2 --encode_sampling normal \
    --info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
    --ac_dec 0 --ac_rew 1 \
    --dr_cc 0.9 --gp_lambda 0.01 --learning_rate_d 1e-3 \
    --p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
    --reg_dec 0 --sn_dec 0 \
    --dl_scale 0.0001 --dl_linit 100 --dl_llr 1e-3 --dl_slack 1e-6  --dl_l2m 1 \
    --cond_rew 0 --dl_ztype prior \
    --seed 1 --nthreads 1
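For example, following the mapping described above, running PointMaze with lambda_I=0.2 keeps the base command and changes only the diversity weight flag (a sketch):

--dr_cc 0.8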

HalfCheetah

python code/vild_main.py \
    --env_id -20 --c_data 1 --max_step 15000000 \
    --il_method infogsdr --rl_method ppo \
    --wd_policy 0 \
    --encode_dim 2 --encode_sampling normal \
    --info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
    --ac_dec 1 --ac_rew 1 \
    --dr_cc 0.9 --gp_lambda 0.1 --learning_rate_d 1e-3 \
    --p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
    --reg_dec 0 --sn_dec 0 \
    --dl_scale 0.001 --dl_linit 100 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 0 \
    --cond_rew 0 --dl_ztype prior \
    --seed 1 --nthreads 2

DriveLaneshift

python code/vild_main.py \
    --env_id -51 --c_data 1 --max_step 5000000 \
    --il_method infogsdr --rl_method ppobc --bc_cf 0.1 --big_batch_size 1000 \
    --wd_policy 1e-4 --warmstart 0 \
    --encode_dim 2 --encode_sampling normal \
    --info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
    --ac_dec 1 --ac_rew 1 \
    --dr_cc 0.9 --gp_lambda 0.1 --learning_rate_d 1e-3 \
    --p_step 1 --lr_p 1e-4 --wd_p 0 --lr_dk 0 --dec_gclip 25 --post_rclip 1  \
    --reg_dec 0 --sn_dec 0 \
    --dl_scale 0.001 --dl_linit 500 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 0 \
    --cond_rew 0 --dl_ztype prior \
    --seed 1 --nthreads 1

FetchPickPlace

python code/vild_main.py \
    --env_id -43 --c_data 1 --max_step 10000000 \
    --il_method infogsdr --rl_method ppobc --bc_cf 0.1 --norm_obs 1 \
    --wd_policy 1e-4 \
    --encode_dim 2 --encode_sampling normal \
    --info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
    --ac_dec 1 --ac_rew 1 \
    --dr_cc 0.9 --gp_lambda 0.1 --learning_rate_d 1e-3 \
    --p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
    --reg_dec 0 --sn_dec 0 \
    --dl_scale 0.001 --dl_linit 5000 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 0 \
    --cond_rew 0 --dl_ztype prior \
    --seed 1 --nthreads 2

TableTennis

python code/vild_main.py \
    --env_id -60 --c_data 0 --v_data 5 --max_step 20000000 \
    --il_method infogsdr --rl_method trpo --norm_obs 1 \
    --wd_policy 0 --warmstart 0 \
    --encode_dim 2 --encode_sampling normal \
    --info_loss_type bce --clip_discriminator 0 --offset_reward 0 \
    --ac_dec 1 --ac_rew 1 \
    --dr_cc 0.9 --gp_lambda 0.1 --learning_rate_d 1e-3 --d_step 5 \
    --p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
    --reg_dec 0 --sn_dec 0 \
    --dl_scale 0.001 --dl_linit 100 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 0 \
    --cond_rew 0 --dl_ztype prior \
    --seed 1 --nthreads 1

Modifications to run various evaluation settings

Generalization Setting

Interpolation

--c_data 1

Extrapolation

--c_data 2

Approaches

IG+Lipz

--sn_dec 1

IG+Con

--cond_rew 1 --crew_grew 0
--dec_ezsample 0 --encode_ezsample 0 --crew_expxm 2

IG+ConDist

--cond_rew 1 --crew_grew 1 --crew_regcf 0.02 --clip_discriminator 10
--dec_ezsample 0 --encode_ezsample 0 --crew_expxm 2

IG+ConDist+Lipz

--cond_rew 1 --crew_grew 1 --crew_regcf 0.02 --clip_discriminator 10
--dec_ezsample 0 --encode_ezsample 0 --crew_expxm 2
--sn_dec 1

GSD (Ours)

--cond_rew 1 --crew_grew 1 --crew_regcf 0.02 --clip_discriminator 10
--dec_ezsample 0 --encode_ezsample 0 --crew_expxm 2
--sn_dec 1
--reg_dec 1 --dl_type disc --dl_l2m 1 --dl_ztype nglobal --dl_linit 100
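For example, training GSD on PointMaze in the interpolation setting combines the PointMaze base command with the flags above; a sketch, with exact per-domain and per-setting hyperparameter values deferred to Table 2 in the Appendix:

# sketch: PointMaze base command with the GSD flags applied
python code/vild_main.py \
    --env_id -5 --c_data 1 --v_data 1 --max_step 2500000 --big_batch_size 1000 \
    --bc_step 0 \
    --il_method infogsdr --rl_method ppo \
    --encode_dim 2 --encode_sampling normal \
    --info_loss_type bce --clip_discriminator 10 --offset_reward 0 \
    --ac_dec 0 --ac_rew 1 \
    --dr_cc 0.9 --gp_lambda 0.01 --learning_rate_d 1e-3 \
    --p_step 1 --lr_p 1e-3 --wd_p 0 --lr_dk 0 \
    --reg_dec 1 --sn_dec 1 --dl_type disc \
    --dl_scale 0.0001 --dl_linit 100 --dl_llr 1e-3 --dl_slack 1e-6 --dl_l2m 1 \
    --cond_rew 1 --crew_grew 1 --crew_regcf 0.02 --crew_expxm 2 \
    --dec_ezsample 0 --encode_ezsample 0 \
    --dl_ztype nglobal \
    --seed 1 --nthreads 1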

Evaluation

Use the command below to perform evaluation, setting the arguments as indicated for each domain; the per-domain argument changes are listed after the command, followed by a composed example for TableTennis. The sampling process is parallelized by default; to set the desired number of threads, modify code/run_model.py:L630 to set NPARALLEL=<desired_nthreads>.

python code/run_model.py \
    --env_id -20 --c_data 1 \
    --mode prior --num_eps 1 --bgt_info `seq 10 10 51` --num_info 1500 \
    --test_seed 1 \
    --ckptpath results_IL/path/to/dir/ckpt_policy_T15000000.pt

HalfCheetah: --env_id -20 --num_eps 5
DriveLaneshift: --env_id -51
FetchPickPlace: --env_id -43
TableTennis: --env_id -60 --c_data 0 --mode vstats --num_eps 5 --num_info 200
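For example, evaluating a TableTennis checkpoint composes the command above with the TableTennis arguments; a sketch, where the checkpoint directory and filename are placeholders for your trained model:

python code/run_model.py \
    --env_id -60 --c_data 0 \
    --mode vstats --num_eps 5 --bgt_info `seq 10 10 51` --num_info 200 \
    --test_seed 1 \
    --ckptpath results_IL/path/to/dir/ckpt_policy_T20000000.pt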

The same commands apply to both generalization settings. The script logs a significant amount of information to the console; the lines of interest take the form:

  • VL<K>-<GTD> <mean> <std>
    • Denotes the recovery metric (least MAE) reported in the paper.
  • RT-VL<K>-<GTD> <mean> <std> <max> <min>
    • Denotes the reward obtained by the behavior corresponding to the z that minimized MAE with the desired GT factor value.
  • DV50 <val> (Only TableTennis)
    • Denotes the entropy calculated for the set of GT factor values.

Here, <K> denotes the number of samples considered, and <GTD> corresponds to the desired GT factor value among [1, 2, 3, 4, 5] (canonicalized across domains).
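For quick manual inspection of a saved console log, the relevant lines can be filtered with grep; a sketch, assuming the metric lines appear verbatim as described above and that the console output was saved to files such as runc0.log and runt.log:

grep -E "(RT-)?VL[0-9]+-[0-9]" runc0.log
grep "DV50" runt.log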

This information should be averaged over train seeds to construct the figures in the paper. The commands below do so from log files, one per domain. Each log file contains the console outputs for all methods and 5 train seeds; a sketch of how such a log can be assembled is shown after the commands.

Interpolation:

python code/print_latex.py 1 runc0.log
python code/print_latex.py 1 rund0.log
python code/print_latex.py 1 runf0.log

Extrapolation:

python code/print_latex.py 3 runc2.log
python code/print_latex.py 3 rund2.log
python code/print_latex.py 3 runf2.log

Entropy (TableTennis):

python code/print_latex.py 0 runt.log
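Each log file is simply the collected evaluation console output; a minimal sketch of appending one evaluation run to a per-domain log follows (the pairing of log filenames to domains and the checkpoint path are assumptions; repeat for each method and seed):

python code/run_model.py \
    --env_id -20 --c_data 1 \
    --mode prior --num_eps 5 --bgt_info `seq 10 10 51` --num_info 1500 \
    --test_seed 1 \
    --ckptpath results_IL/path/to/dir/ckpt_policy_T15000000.pt >> runc0.log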

License

Code is available under MIT License.

Citation

If you use this work and/or this codebase in your research, we would greatly appreciate being cited as shown below!

@inproceedings{sreeramdass2025generalized,
    title={Generalized Behavior Learning from Diverse Demonstrations},
    author={Sreeramdass, Varshith and Paleja, Rohan R and Chen, Letian and Van Waveren, Sanne and Gombolay, Matthew},
    booktitle={International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/pdf?id=Q7EjHroO1w}
}
