From b6d76337bdf3bb718687e90199fab6cbeeabb6ae Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Thu, 23 Apr 2026 17:12:31 -0700 Subject: [PATCH 1/2] Multi-GPU initial setup for TransformerBridge --- .../model_bridge/test_multi_gpu_bridge.py | 299 ++++++++++++++++++ transformer_lens/model_bridge/bridge.py | 81 ++++- .../model_bridge/sources/transformers.py | 90 +++++- transformer_lens/utilities/__init__.py | 3 + transformer_lens/utilities/multi_gpu.py | 109 ++++++- 5 files changed, 564 insertions(+), 18 deletions(-) create mode 100644 tests/acceptance/model_bridge/test_multi_gpu_bridge.py diff --git a/tests/acceptance/model_bridge/test_multi_gpu_bridge.py b/tests/acceptance/model_bridge/test_multi_gpu_bridge.py new file mode 100644 index 000000000..58f84395d --- /dev/null +++ b/tests/acceptance/model_bridge/test_multi_gpu_bridge.py @@ -0,0 +1,299 @@ +"""Multi-GPU support tests for TransformerBridge. + +CPU-runnable tests exercise the resolver / param-plumbing / .to() guard / +validation logic. Tests requiring real multi-GPU hardware are marked skipif. +""" + +from typing import Dict, Union + +import pytest +import torch + +from transformer_lens.model_bridge import TransformerBridge +from transformer_lens.utilities.multi_gpu import ( + count_unique_devices, + find_embedding_device, + resolve_device_map, +) + +# ---------- CPU-runnable tests ---------- + + +class TestResolveDeviceMap: + def test_no_multi_device_returns_none(self): + dm, mm = resolve_device_map(None, None, None) + assert dm is None and mm is None + dm, mm = resolve_device_map(1, None, None) + assert dm is None and mm is None + dm, mm = resolve_device_map(0, None, None) + assert dm is None and mm is None + + def test_explicit_device_map_string_passes_through(self): + dm, mm = resolve_device_map(None, "auto", None) + assert dm == "auto" + assert mm is None + + def test_explicit_device_map_dict_passes_through(self): + explicit: Dict[str, Union[str, int]] = {"transformer.h.0": 0} + dm, mm = resolve_device_map(None, explicit, None) + assert dm is explicit + assert mm is None + + def test_user_max_memory_passes_through(self): + user_mm: Dict[Union[str, int], str] = {0: "20GiB"} + dm, mm = resolve_device_map(None, "auto", None, max_memory=user_mm) + assert dm == "auto" + assert mm is user_mm + + def test_device_and_device_map_mutually_exclusive(self): + with pytest.raises(ValueError, match="mutually exclusive"): + resolve_device_map(None, "auto", "cuda") + + def test_n_devices_without_cuda_raises(self): + if torch.cuda.is_available(): + pytest.skip("CUDA available; this test targets the no-CUDA path.") + with pytest.raises(ValueError, match="requires CUDA"): + resolve_device_map(2, None, None) + + def test_n_devices_exceeds_visible_raises(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA required.") + too_many = torch.cuda.device_count() + 1 + with pytest.raises(ValueError, match="only"): + resolve_device_map(too_many, None, None) + + def test_n_devices_returns_balanced_string_and_max_memory(self): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("Requires 2+ CUDA devices.") + dm, mm = resolve_device_map(2, None, None) + # device_map must be a string directive (HF device_map dicts are keyed by + # submodule path — int keys would fail to match any submodule). 
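+        # (Illustrative only: a dict-form device_map for GPT-2 would look like
+        #  {"transformer.wte": 0, "transformer.h.0": 0, ..., "transformer.h.11": 1, "lm_head": 1},
+        #  i.e. keyed by submodule path, never by bare device ints.)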
+ assert dm == "balanced" + assert isinstance(mm, dict) + assert set(mm.keys()) == {0, 1} + + def test_n_devices_respects_user_max_memory(self): + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + pytest.skip("Requires 2+ CUDA devices.") + user_mm: Dict[Union[str, int], str] = {0: "10GiB", 1: "10GiB"} + dm, mm = resolve_device_map(2, None, None, max_memory=user_mm) + assert dm == "balanced" + assert mm == user_mm + + def test_cpu_value_in_device_map_rejected(self): + bad: Dict[str, Union[str, int]] = {"transformer.h.0": "cpu"} + with pytest.raises(ValueError, match="not supported"): + resolve_device_map(None, bad, None) + + def test_disk_value_in_device_map_rejected(self): + bad: Dict[str, Union[str, int]] = {"transformer.h.0": "disk"} + with pytest.raises(ValueError, match="not supported"): + resolve_device_map(None, bad, None) + + +class TestFindEmbeddingDevice: + def test_returns_none_for_no_device_map(self): + class Stub: + pass + + assert find_embedding_device(Stub()) is None + + def test_uses_get_input_embeddings_when_available(self): + # A stub model with both hf_device_map AND get_input_embeddings should + # consult the embedding module, not the first dict entry. This is the key + # difference from the insertion-order heuristic — covers the multimodal / + # encoder-decoder case where the first map entry is the vision tower. + embed = torch.nn.Embedding(10, 4) + embed = embed.to("cpu") + + class Stub: + hf_device_map = {"vision_tower.stuff": 1, "language_model.embed_tokens": "cpu"} + + def get_input_embeddings(self): + return embed + + result = find_embedding_device(Stub()) + assert result is not None + assert result.type == "cpu" + + def test_falls_back_to_first_entry_when_get_input_embeddings_unavailable(self): + class Stub: + hf_device_map = {"embed_tokens": "cpu", "layers.0": "cpu"} + + assert find_embedding_device(Stub()) == torch.device("cpu") + + def test_handles_int_device_ids_in_fallback(self): + class Stub: + hf_device_map = {"embed_tokens": 0, "layers.0": 1} + + result = find_embedding_device(Stub()) + assert result is not None + assert result.type == "cuda" + assert result.index == 0 + + def test_handles_get_input_embeddings_returning_none(self): + class Stub: + hf_device_map = {"embed_tokens": "cpu"} + + def get_input_embeddings(self): + return None + + assert find_embedding_device(Stub()) == torch.device("cpu") + + +class TestCountUniqueDevices: + def test_no_map_returns_1(self): + class Stub: + pass + + assert count_unique_devices(Stub()) == 1 + + def test_counts_unique_values(self): + class Stub: + hf_device_map = {"a": 0, "b": 0, "c": 1, "d": 1, "e": 2} + + assert count_unique_devices(Stub()) == 3 + + +class TestBootParamValidation: + def test_device_and_device_map_mutually_exclusive(self): + with pytest.raises(ValueError, match="mutually exclusive"): + TransformerBridge.boot_transformers("gpt2", device="cpu", device_map="auto") + + def test_preloaded_with_device_map_rejected(self, gpt2_bridge): + # Passing both hf_model= and device_map/n_devices is ambiguous — the device_map + # would be silently ignored. We raise so the caller isn't surprised. 
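+        # (The supported route for a pre-dispatched model is to call
+        #  AutoModelForCausalLM.from_pretrained(..., device_map="auto") yourself and pass the
+        #  result as hf_model=; the bridge then derives cfg.device / cfg.n_devices from its
+        #  hf_device_map.)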
+ with pytest.raises(ValueError, match="only supported when the bridge loads"): + TransformerBridge.boot_transformers( + "gpt2", hf_model=gpt2_bridge.original_model, device_map="auto" + ) + + def test_preloaded_with_n_devices_rejected(self, gpt2_bridge): + with pytest.raises(ValueError, match="only supported when the bridge loads"): + TransformerBridge.boot_transformers( + "gpt2", hf_model=gpt2_bridge.original_model, n_devices=2 + ) + + +class TestSingleDevicePathUnchanged: + def test_cpu_load_default_unchanged(self, gpt2_bridge): + # If any of our changes broke the baseline path, existing bridge tests would + # catch it too — this is a smoke check that n_devices stays 1 on the default path. + assert gpt2_bridge.cfg.n_devices == 1 + assert gpt2_bridge.cfg.device is not None + + +class TestToMethodGuardsMultiDevice: + def test_to_warns_and_drops_device_when_n_devices_gt_1(self, gpt2_bridge): + # Simulate a dispatched model by bumping n_devices — we don't need multi-GPU + # hardware to verify the .to() guard path. + original_n_devices = gpt2_bridge.cfg.n_devices + gpt2_bridge.cfg.n_devices = 2 + try: + with pytest.warns(UserWarning, match="ignored"): + gpt2_bridge.to("cpu") + assert next(gpt2_bridge.original_model.parameters()).device.type == "cpu" + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + + def test_to_still_honors_dtype_under_multi_device(self, gpt2_bridge): + original_n_devices = gpt2_bridge.cfg.n_devices + original_dtype = next(gpt2_bridge.original_model.parameters()).dtype + gpt2_bridge.cfg.n_devices = 2 + try: + with pytest.warns(UserWarning, match="ignored"): + gpt2_bridge.to("cpu", torch.float64) + assert next(gpt2_bridge.original_model.parameters()).dtype == torch.float64 + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + gpt2_bridge.original_model.to(original_dtype) + + +class TestRunWithCacheGuardsMultiDevice: + def test_run_with_cache_device_arg_warns_under_multi_device(self, gpt2_bridge): + original_n_devices = gpt2_bridge.cfg.n_devices + gpt2_bridge.cfg.n_devices = 2 + try: + with pytest.warns(UserWarning, match="ignored"): + gpt2_bridge.run_with_cache(torch.tensor([[1, 2, 3]]), device="cpu") + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + + +class TestStackedWeightsHandleCrossDevice: + def test_stack_gathers_across_devices(self, gpt2_bridge): + # Fake multi-device state by flipping cfg.n_devices. The GPT-2 bridge's weights + # still live on CPU, so gathering to cfg.device (also CPU) is a no-op — but the + # code path we care about (the `if n_devices > 1` branch) is exercised. + original_n_devices = gpt2_bridge.cfg.n_devices + gpt2_bridge.cfg.n_devices = 2 + try: + # None of these should raise, even with n_devices>1. + W_Q = gpt2_bridge.W_Q + W_K = gpt2_bridge.W_K + W_V = gpt2_bridge.W_V + W_O = gpt2_bridge.W_O + assert W_Q.shape[0] == gpt2_bridge.cfg.n_layers + assert W_K.shape[0] == gpt2_bridge.cfg.n_layers + assert W_V.shape[0] == gpt2_bridge.cfg.n_layers + assert W_O.shape[0] == gpt2_bridge.cfg.n_layers + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + + def test_accumulated_bias_handles_cross_device(self, gpt2_bridge): + original_n_devices = gpt2_bridge.cfg.n_devices + gpt2_bridge.cfg.n_devices = 2 + try: + # Exercises the .to(accumulated.device) branch without requiring real GPUs. 
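+            # (accumulated_bias sums each block's attention b_O, and optionally the MLP b_out,
+            #  moving every addend onto the running tensor's device before adding.)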
+ bias = gpt2_bridge.accumulated_bias(layer=gpt2_bridge.cfg.n_layers - 1) + assert bias.shape == (gpt2_bridge.cfg.d_model,) + finally: + gpt2_bridge.cfg.n_devices = original_n_devices + + +# ---------- Multi-GPU tests (require real hardware) ---------- + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2+ CUDA devices") +class TestMultiDeviceIntegration: + def test_n_devices_matches_single_device_logits(self): + single = TransformerBridge.boot_transformers("gpt2", device="cuda:0") + multi = TransformerBridge.boot_transformers("gpt2", n_devices=2) + + assert multi.cfg.n_devices == 2 + assert single.cfg.n_devices == 1 + + tokens = torch.tensor([[1, 2, 3, 4]]) + logits_single = single(tokens).to("cpu") + logits_multi = multi(tokens).to("cpu") + assert torch.allclose(logits_single, logits_multi, atol=1e-4, rtol=1e-4) + + def test_parameters_distributed_across_devices(self): + bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2) + cuda_indices = { + p.device.index for p in bridge.original_model.parameters() if p.device.type == "cuda" + } + assert cuda_indices == {0, 1} + + def test_generate_works_with_multi_device(self): + bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2) + out = bridge.generate("Hello", max_new_tokens=3, do_sample=False) + assert isinstance(out, str) + assert len(out) > len("Hello") + + def test_stacked_weights_work_across_devices(self): + # Real multi-device exercise of _stack_block_params (no spoofed n_devices). + bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2) + W_Q = bridge.W_Q + assert W_Q.shape[0] == bridge.cfg.n_layers + # After stacking, all elements should be on cfg.device (the embedding device). + assert bridge.cfg.device is not None + assert W_Q.device == torch.device(bridge.cfg.device) + + def test_preloaded_device_map_model(self): + from transformers import AutoModelForCausalLM + + hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto") + bridge = TransformerBridge.boot_transformers("gpt2", hf_model=hf_model) + assert bridge.cfg.n_devices >= 1 + assert bridge.cfg.device is not None diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index c2bd5dbd5..cfb3ae7d1 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -167,13 +167,17 @@ def boot_transformers( trust_remote_code: bool = False, model_class: Optional[type] = None, hf_model: Optional[Any] = None, + device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None, + n_devices: Optional[int] = None, + max_memory: Optional[Dict[Union[str, int], str]] = None, ) -> "TransformerBridge": """Boot a model from HuggingFace (alias for sources.transformers.boot). Args: model_name: The name of the model to load. hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. - device: The device to use. If None, will be determined automatically. + device: The device to use. If None, will be determined automatically. Mutually exclusive + with ``device_map``. dtype: The dtype to use for the model. tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. @@ -182,7 +186,17 @@ def boot_transformers( auto-detected class (e.g., BertForNextSentencePrediction). hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. 
Useful for models loaded with custom configurations (e.g., quantization via - BitsAndBytesConfig). When provided, load_weights is ignored. + BitsAndBytesConfig). When provided, load_weights is ignored. If the pre-loaded + model was built with a ``device_map``, ``cfg.device`` and ``cfg.n_devices`` are + derived from its ``hf_device_map`` automatically. + device_map: HuggingFace-style device map for multi-GPU inference. Pass ``"auto"``, + ``"balanced"``, ``"sequential"``, or an explicit ``{submodule_path: device}`` dict. + Mutually exclusive with ``device``. + n_devices: Convenience shortcut: split the model across this many CUDA devices. + Translated to a ``max_memory`` dict over devices 0..n_devices-1 and passed as + ``device_map`` to HF. Requires CUDA with at least this many visible devices. + max_memory: Optional per-device memory budget, passed through to HF's dispatcher. + Only used when ``device_map`` or ``n_devices`` is in effect. Returns: The bridge to the loaded model. @@ -199,6 +213,9 @@ def boot_transformers( trust_remote_code=trust_remote_code, model_class=model_class, hf_model=hf_model, + device_map=device_map, + n_devices=n_devices, + max_memory=max_memory, ) @property @@ -1095,6 +1112,12 @@ def _stack_block_params( if reshape_fn is not None: w = reshape_fn(w) weights.append(w) + # Under a device_map split, per-block tensors live on different devices. + # torch.stack requires a common device; gather onto cfg.device (the embedding / + # input device — a natural "home" for cross-layer reductions). + if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None: + target_device = torch.device(self.cfg.device) + weights = [w.to(target_device) for w in weights] return torch.stack(weights, dim=0) def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor: @@ -1300,17 +1323,17 @@ def accumulated_bias( block = self.blocks[i] b_O = self._get_block_variant_bias(block) if b_O is not None: - accumulated = accumulated + b_O + accumulated = accumulated + b_O.to(accumulated.device) if include_mlp_biases and "mlp" in block._modules: b_out = getattr(block.mlp, "b_out", None) if b_out is not None: - accumulated = accumulated + b_out + accumulated = accumulated + b_out.to(accumulated.device) if mlp_input: assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer" block = self.blocks[layer] b_O = self._get_block_variant_bias(block) if b_O is not None: - accumulated = accumulated + b_O + accumulated = accumulated + b_O.to(accumulated.device) return accumulated def all_composition_scores(self, mode: str) -> CompositionScores: @@ -1334,6 +1357,10 @@ def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tenso if reshape_fn is not None: w = reshape_fn(w) weights.append(w) + # See _stack_block_params: gather per-block tensors onto cfg.device when split. 
+ if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None: + target_device = torch.device(self.cfg.device) + weights = [w.to(target_device) for w in weights] return torch.stack(weights, dim=0) W_V = _stack("attn.W_V", self._reshape_qkv) @@ -1940,12 +1967,24 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor: hooks.append((hook_dict[block_hook_name], block_hook_name)) filtered_kwargs = kwargs.copy() if cache_device is not None: - self.original_model = self.original_model.to(cache_device) - if processed_args and isinstance(processed_args[0], torch.Tensor): - processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:]) - for key, value in filtered_kwargs.items(): - if isinstance(value, torch.Tensor): - filtered_kwargs[key] = value.to(cache_device) + if getattr(self.cfg, "n_devices", 1) > 1: + # Moving a dispatched model to a single device collapses accelerate's + # split and breaks its routing hooks. The cache will stay spread across + # the per-layer devices; callers can .to(cache_device) on cache entries + # after the fact if they need a single-device cache. + warnings.warn( + f"run_with_cache(device={cache_device!r}) ignored: model is dispatched " + f"across {self.cfg.n_devices} devices via device_map. Cached activations " + "will remain on their per-layer devices.", + stacklevel=2, + ) + else: + self.original_model = self.original_model.to(cache_device) + if processed_args and isinstance(processed_args[0], torch.Tensor): + processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:]) + for key, value in filtered_kwargs.items(): + if isinstance(value, torch.Tensor): + filtered_kwargs[key] = value.to(cache_device) try: if "output_attentions" not in filtered_kwargs: filtered_kwargs["output_attentions"] = True @@ -2833,12 +2872,30 @@ def to(self, *args, **kwargs) -> "TransformerBridge": if "dtype" in kwargs: target_dtype = kwargs["dtype"] + # Moving a multi-device (device_map-dispatched) model to a single device would + # collapse the split and break accelerate's hook routing. Warn and drop the + # device move; still honor dtype changes. + if target_device is not None and getattr(self.cfg, "n_devices", 1) > 1: + warnings.warn( + f"TransformerBridge.to({target_device!r}) ignored: model is dispatched " + f"across {self.cfg.n_devices} devices via device_map. Reload with " + "device=... (and no device_map/n_devices) to move to a single device.", + stacklevel=2, + ) + target_device = None + if target_device is not None: move_to_and_update_config(self, target_device, print_details) if target_dtype is not None: move_to_and_update_config(self, target_dtype, print_details) - # Move the original model with all original args/kwargs (with print_details removed) + # Move the original model with all original args/kwargs (with print_details removed). + # When we've nulled target_device for multi-GPU safety, strip device args so the + # underlying module isn't moved either. + if target_device is None and (len(args) > 0 or "device" in kwargs): + kwargs.pop("device", None) + # Filter positional args: drop devices/strings, keep dtypes. 
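+            # (e.g. bridge.to("cuda:1", torch.float16) on a dispatched model keeps the
+            #  torch.float16 positional arg but drops the "cuda:1" device string.)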
+ args = tuple(a for a in args if not isinstance(a, (torch.device, str))) self.original_model = self.original_model.to(*args, **kwargs) return self diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 0169da4dc..95dcdadaa 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -286,13 +286,17 @@ def boot( trust_remote_code: bool = False, model_class: Any | None = None, hf_model: Any | None = None, + device_map: str | dict[str, str | int] | None = None, + n_devices: int | None = None, + max_memory: dict[str | int, str] | None = None, ) -> TransformerBridge: """Boot a model from HuggingFace. Args: model_name: The name of the model to load. hf_config_overrides: Optional overrides applied to the HuggingFace config before model load. - device: The device to use. If None, will be determined automatically. + device: The device to use. If None, will be determined automatically. Mutually exclusive + with ``device_map``. dtype: The dtype to use for the model. tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created. load_weights: If False, load model without weights (on meta device) for config inspection only. @@ -302,6 +306,12 @@ def boot( hf_model: Optional pre-loaded HuggingFace model to use instead of loading one. Useful for models loaded with custom configurations (e.g., quantization via BitsAndBytesConfig). When provided, load_weights is ignored. + device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for + multi-GPU inference. Passed straight to ``from_pretrained``. Mutually exclusive + with ``device``. + n_devices: Convenience: split the model across this many CUDA devices (translated to a + ``max_memory`` dict internally). Requires CUDA with at least this many visible devices. + max_memory: Optional per-device memory budget for HF's dispatcher. Returns: The bridge to the loaded model. @@ -376,9 +386,49 @@ def boot( if attn_logit_softcapping is not None: bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping) adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config) - if device is None: - device = get_device() - adapter.cfg.device = str(device) + # Pre-loaded models carry their own weight placement (possibly set by the caller via + # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is + # ambiguous and would silently be ignored, so fail loudly. + if hf_model is not None and ( + device_map is not None or n_devices is not None or max_memory is not None + ): + raise ValueError( + "device_map / n_devices / max_memory are only supported when the bridge loads " + "the HF model itself. When passing hf_model=..., apply device_map via " + "AutoModel.from_pretrained before handing the model to the bridge." + ) + # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on + # that layer's device. The bridge currently allocates the stateful cache on a single + # cfg.device, so cross-device splits would silently misplace the cache. Block this + # combination until a v2 addresses per-layer stateful cache placement. + if (n_devices is not None and n_devices > 1) or device_map is not None: + if getattr(bridge_config, "is_stateful", False): + raise ValueError( + "Multi-device splits are not yet supported for stateful (SSM / Mamba) " + "architectures: the stateful cache allocation is single-device. 
" + "Load on one device, or wait for v2 support." + ) + # Resolve device_map before defaulting `device` — the two are mutually exclusive, and + # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a + # device_map + max_memory pair here so downstream code only needs to check the + # resolved values. + from transformer_lens.utilities.multi_gpu import ( + count_unique_devices, + find_embedding_device, + resolve_device_map, + ) + + resolved_device_map, resolved_max_memory = resolve_device_map( + n_devices, device_map, device, max_memory + ) + if resolved_device_map is None: + if device is None: + device = get_device() + adapter.cfg.device = str(device) + else: + # cfg.device will be set from hf_device_map after the model is loaded. + # Provisionally keep it None; find_embedding_device fills it in below. + adapter.cfg.device = None if model_class is None: model_class = get_hf_model_class_for_architecture(architecture) # Ensure pad_token_id exists (v5 raises AttributeError if missing) @@ -393,6 +443,10 @@ def boot( model_kwargs["token"] = _hf_token if trust_remote_code: model_kwargs["trust_remote_code"] = True + if resolved_device_map is not None: + model_kwargs["device_map"] = resolved_device_map + if resolved_max_memory is not None: + model_kwargs["max_memory"] = resolved_max_memory if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None: model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation else: @@ -410,12 +464,38 @@ def boot( hf_model = model_class.from_config(hf_config, **from_config_kwargs) else: hf_model = model_class.from_pretrained(model_name, **model_kwargs) - if device is not None: + # Skip explicit .to(device) when accelerate has placed weights via device_map. + if resolved_device_map is None and device is not None: hf_model = hf_model.to(device) # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq) for param in hf_model.parameters(): if param.is_floating_point() and param.dtype != dtype: param.data = param.data.to(dtype=dtype) + # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers: + # - fresh loads with a resolved device_map (set above) + # - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto") + hf_device_map_post = getattr(hf_model, "hf_device_map", None) + if hf_device_map_post: + # Pre-loaded path can still smuggle CPU/disk offload in; validate here too. + offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)} + forbidden = offload_values & {"cpu", "disk", "meta"} + if forbidden and ((n_devices is not None and n_devices > 1) or device_map is not None): + # Fresh-load path: we set the device_map ourselves, so this shouldn't happen — + # but if the user asked for n_devices>1 and somehow got CPU offload, surface it. + raise ValueError( + f"hf_device_map contains unsupported offload targets: {sorted(forbidden)}. " + "v1 multi-device support is GPU-only." + ) + embedding_device = find_embedding_device(hf_model) + if embedding_device is not None: + adapter.cfg.device = str(embedding_device) + adapter.cfg.n_devices = count_unique_devices(hf_model) + elif adapter.cfg.device is None: + # Pre-loaded single-device model with no hf_device_map — fall back to first param. 
+ try: + adapter.cfg.device = str(next(hf_model.parameters()).device) + except StopIteration: + adapter.cfg.device = "cpu" adapter.prepare_model(hf_model) tokenizer = tokenizer default_padding_side = getattr(adapter.cfg, "default_padding_side", None) diff --git a/transformer_lens/utilities/__init__.py b/transformer_lens/utilities/__init__.py index de5e79db6..56975f90d 100644 --- a/transformer_lens/utilities/__init__.py +++ b/transformer_lens/utilities/__init__.py @@ -50,10 +50,13 @@ # Re-export multi-GPU helpers here (devices.py must not import multi_gpu directly) from .multi_gpu import ( calculate_available_device_cuda_memory, + count_unique_devices, determine_available_memory_for_available_devices, + find_embedding_device, get_best_available_cuda_device, get_best_available_device, get_device_for_block_index, + resolve_device_map, sort_devices_based_on_available_memory, ) from .slice import Slice, SliceInput diff --git a/transformer_lens/utilities/multi_gpu.py b/transformer_lens/utilities/multi_gpu.py index 8b5afaf58..0584604af 100644 --- a/transformer_lens/utilities/multi_gpu.py +++ b/transformer_lens/utilities/multi_gpu.py @@ -5,7 +5,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union import torch @@ -141,3 +141,110 @@ def get_device_for_block_index( return device device_index = (device.index or 0) + (index // layers_per_device) return torch.device(device.type, device_index) + + +_UNSUPPORTED_DEVICE_MAP_VALUES = {"cpu", "disk", "meta"} +"""v1 multi-GPU scope is GPU-only. CPU offload and disk offload cause dtype-cast loops to +silently miss offloaded params (meta tensors), and cross-layer hook routing has different +semantics. Reject them explicitly until a v2 scopes those paths.""" + + +def _validate_device_map_values( + device_map: Union[str, Dict[str, Union[str, int]]], +) -> None: + """Reject CPU / disk / meta values in a user-supplied device_map dict.""" + if isinstance(device_map, str): + # "balanced_low_0" is fine — still GPU-only; "cpu" as a string-form device_map + # would tell HF to put everything on CPU, which is single-device and meaningless + # as a multi-GPU config. We allow strings through; HF will validate them. + return + for key, value in device_map.items(): + normalized = str(value).lower() if isinstance(value, str) else None + if normalized in _UNSUPPORTED_DEVICE_MAP_VALUES: + raise ValueError( + f"device_map[{key!r}]={value!r} is not supported. Multi-device bridge " + f"support is GPU-only in v1; CPU / disk / meta offload routes are excluded." + ) + + +def resolve_device_map( + n_devices: Optional[int], + device_map: Optional[Union[str, Dict[str, Union[str, int]]]], + device: Optional[Union[str, torch.device]], + max_memory: Optional[Dict[Union[str, int], str]] = None, +) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]: + """Resolve ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs. + + Returns ``(device_map, max_memory)`` tuple ready to pass into ``model_kwargs``. + + Semantics: + - Explicit ``device_map`` wins; it's validated and passed through unchanged (user- + provided ``max_memory`` is passed through too). + - ``n_devices=None`` or ``1``: returns ``(None, None)`` — single-device path. + - ``n_devices > 1``: returns ``("balanced", {0: "auto", ..., n-1: "auto"})``. 
+ ``"balanced"`` is accelerate's string directive for balanced layer dispatch; + the ``max_memory`` dict caps visibility to exactly ``n_devices`` GPUs. + """ + if device_map is not None and device is not None: + raise ValueError("device and device_map are mutually exclusive — pass one.") + if device_map is not None: + _validate_device_map_values(device_map) + return device_map, max_memory + if n_devices is None or n_devices <= 1: + return None, max_memory + if not torch.cuda.is_available(): + raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.") + if torch.cuda.device_count() < n_devices: + raise ValueError( + f"n_devices={n_devices} but only {torch.cuda.device_count()} CUDA devices present." + ) + resolved_max_memory: Dict[Union[str, int], str] = ( + dict(max_memory) if max_memory else {i: "auto" for i in range(n_devices)} + ) + return "balanced", resolved_max_memory + + +def find_embedding_device(hf_model: Any) -> Optional[torch.device]: + """Return the device that input tokens should be placed on for a dispatched HF model. + + When a model is loaded with ``device_map``, accelerate populates ``hf_device_map`` + and inserts pre/post-forward hooks that route activations. Input tensors must land on + the device of whichever module first *consumes* them — the input embedding. Returns + ``None`` for single-device models (no ``hf_device_map`` set). + + Resolves via ``hf_model.get_input_embeddings()`` rather than dict insertion order to + cover encoder-decoder / multimodal / audio architectures where the first entry in + ``hf_device_map`` is not the text-token embedding (e.g. the vision tower on LLaVA). + """ + hf_device_map = getattr(hf_model, "hf_device_map", None) + if not hf_device_map: + return None + # Preferred: ask the model for its input embedding module and read its device. + get_input_embeddings = getattr(hf_model, "get_input_embeddings", None) + if callable(get_input_embeddings): + try: + embed_module = get_input_embeddings() + except (AttributeError, NotImplementedError): + embed_module = None + if embed_module is not None: + try: + param = next(embed_module.parameters()) + return param.device + except StopIteration: + pass + # Fallback: first entry in hf_device_map. Less reliable but better than nothing. + first_device = next(iter(hf_device_map.values())) + if isinstance(first_device, int): + return torch.device("cuda", first_device) + return torch.device(first_device) + + +def count_unique_devices(hf_model: Any) -> int: + """Count the number of unique devices across a dispatched HF model's ``hf_device_map``. + + Returns 1 if the model has no ``hf_device_map`` (single-device load). 
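+
+    For example, a dispatch map like ``{"wte": 0, "h.0": 0, "h.6": 1}`` spans 2 unique devices.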
+ """ + hf_device_map = getattr(hf_model, "hf_device_map", None) + if not hf_device_map: + return 1 + return len(set(hf_device_map.values())) From 052da5b3d15ae13b88c788735d6fe59fba665f31 Mon Sep 17 00:00:00 2001 From: jlarson4 Date: Wed, 29 Apr 2026 09:09:31 -0500 Subject: [PATCH 2/2] Added additional documentation note --- transformer_lens/model_bridge/sources/transformers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py index 1c54f8b50..afc002eee 100644 --- a/transformer_lens/model_bridge/sources/transformers.py +++ b/transformer_lens/model_bridge/sources/transformers.py @@ -286,10 +286,12 @@ def boot( trust_remote_code: bool = False, model_class: Any | None = None, hf_model: Any | None = None, + n_ctx: int | None = None, + # Experimental – Have not been fully tested on multi-gpu devices + # Use at your own risk, report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues device_map: str | dict[str, str | int] | None = None, n_devices: int | None = None, max_memory: dict[str | int, str] | None = None, - n_ctx: int | None = None, ) -> TransformerBridge: """Boot a model from HuggingFace.