diff --git a/compass/distillation/distillation.py b/compass/distillation/distillation.py index d97ea40..3867f9b 100644 --- a/compass/distillation/distillation.py +++ b/compass/distillation/distillation.py @@ -158,7 +158,7 @@ class ESDistillationPolicyWrapper(nn.Module): def __init__(self, distillation_policy_ckpt_path: str, embodiment_type: str): super().__init__() # Load the checkpoint and remove the prefix if any. - state_dict = torch.load(distillation_policy_ckpt_path)['state_dict'] + state_dict = torch.load(distillation_policy_ckpt_path, weights_only=False)['state_dict'] state_dict = {k.removeprefix('model.'): v for k, v in state_dict.items()} # Load the state dict. self.model = ESDistillationPolicy()