From 7a6804a0be33c091170f52d914d7df4062f78b2b Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Thu, 17 Apr 2025 11:27:29 -0700 Subject: [PATCH] Disable weights only for generalist policy loading --- compass/distillation/distillation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compass/distillation/distillation.py b/compass/distillation/distillation.py index d97ea40..3867f9b 100644 --- a/compass/distillation/distillation.py +++ b/compass/distillation/distillation.py @@ -158,7 +158,7 @@ class ESDistillationPolicyWrapper(nn.Module): def __init__(self, distillation_policy_ckpt_path: str, embodiment_type: str): super().__init__() # Load the checkpoint and remove the prefix if any. - state_dict = torch.load(distillation_policy_ckpt_path)['state_dict'] + state_dict = torch.load(distillation_policy_ckpt_path, weights_only=False)['state_dict'] state_dict = {k.removeprefix('model.'): v for k, v in state_dict.items()} # Load the state dict. self.model = ESDistillationPolicy()