diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py index 7ce826a27f..668b18f9f1 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker.py +++ b/nemo_rl/models/generation/vllm/vllm_worker.py @@ -375,7 +375,11 @@ def _patch_vllm_sampler(): # overriden by quant config, however vllm complains if this not passed self.precision = "bfloat16" - vllm_kwargs["hf_overrides"] = self.cfg["vllm_cfg"].get("hf_overrides", {}) or {} + if not isinstance(vllm_kwargs.get("hf_overrides"), dict): + vllm_kwargs["hf_overrides"] = {} + vllm_kwargs["hf_overrides"].update( + self.cfg["vllm_cfg"].get("hf_overrides", {}) or {} + ) llm_kwargs = dict( model=self.model_name,