diff --git a/tests/acceptance/model_bridge/test_multi_gpu_bridge.py b/tests/acceptance/model_bridge/test_multi_gpu_bridge.py
new file mode 100644
index 000000000..58f84395d
--- /dev/null
+++ b/tests/acceptance/model_bridge/test_multi_gpu_bridge.py
@@ -0,0 +1,299 @@
+"""Multi-GPU support tests for TransformerBridge.
+
+CPU-runnable tests exercise the resolver / param-plumbing / .to() guard /
+validation logic. Tests requiring real multi-GPU hardware are marked skipif.
+"""
+
+from typing import Dict, Union
+
+import pytest
+import torch
+
+from transformer_lens.model_bridge import TransformerBridge
+from transformer_lens.utilities.multi_gpu import (
+    count_unique_devices,
+    find_embedding_device,
+    resolve_device_map,
+)
+
+# ---------- CPU-runnable tests ----------
+
+
+class TestResolveDeviceMap:
+    def test_no_multi_device_returns_none(self):
+        dm, mm = resolve_device_map(None, None, None)
+        assert dm is None and mm is None
+        dm, mm = resolve_device_map(1, None, None)
+        assert dm is None and mm is None
+        dm, mm = resolve_device_map(0, None, None)
+        assert dm is None and mm is None
+
+    def test_explicit_device_map_string_passes_through(self):
+        dm, mm = resolve_device_map(None, "auto", None)
+        assert dm == "auto"
+        assert mm is None
+
+    def test_explicit_device_map_dict_passes_through(self):
+        explicit: Dict[str, Union[str, int]] = {"transformer.h.0": 0}
+        dm, mm = resolve_device_map(None, explicit, None)
+        assert dm is explicit
+        assert mm is None
+
+    def test_user_max_memory_passes_through(self):
+        user_mm: Dict[Union[str, int], str] = {0: "20GiB"}
+        dm, mm = resolve_device_map(None, "auto", None, max_memory=user_mm)
+        assert dm == "auto"
+        assert mm is user_mm
+
+    def test_device_and_device_map_mutually_exclusive(self):
+        with pytest.raises(ValueError, match="mutually exclusive"):
+            resolve_device_map(None, "auto", "cuda")
+
+    def test_n_devices_without_cuda_raises(self):
+        if torch.cuda.is_available():
+            pytest.skip("CUDA available; this test targets the no-CUDA path.")
+        with pytest.raises(ValueError, match="requires CUDA"):
+            resolve_device_map(2, None, None)
+
+    def test_n_devices_exceeds_visible_raises(self):
+        if not torch.cuda.is_available():
+            pytest.skip("CUDA required.")
+        too_many = torch.cuda.device_count() + 1
+        with pytest.raises(ValueError, match="only"):
+            resolve_device_map(too_many, None, None)
+
+    def test_n_devices_returns_balanced_string_and_max_memory(self):
+        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
+            pytest.skip("Requires 2+ CUDA devices.")
+        dm, mm = resolve_device_map(2, None, None)
+        # device_map must be a string directive (HF device_map dicts are keyed by
+        # submodule path — int keys would fail to match any submodule).
+        assert dm == "balanced"
+        assert isinstance(mm, dict)
+        assert set(mm.keys()) == {0, 1}
+
+    def test_n_devices_respects_user_max_memory(self):
+        if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
+            pytest.skip("Requires 2+ CUDA devices.")
+        user_mm: Dict[Union[str, int], str] = {0: "10GiB", 1: "10GiB"}
+        dm, mm = resolve_device_map(2, None, None, max_memory=user_mm)
+        assert dm == "balanced"
+        assert mm == user_mm
+
+    def test_cpu_value_in_device_map_rejected(self):
+        bad: Dict[str, Union[str, int]] = {"transformer.h.0": "cpu"}
+        with pytest.raises(ValueError, match="not supported"):
+            resolve_device_map(None, bad, None)
+
+    def test_disk_value_in_device_map_rejected(self):
+        bad: Dict[str, Union[str, int]] = {"transformer.h.0": "disk"}
+        with pytest.raises(ValueError, match="not supported"):
+            resolve_device_map(None, bad, None)
+
+
+class TestFindEmbeddingDevice:
+    def test_returns_none_for_no_device_map(self):
+        class Stub:
+            pass
+
+        assert find_embedding_device(Stub()) is None
+
+    def test_uses_get_input_embeddings_when_available(self):
+        # A stub model with both hf_device_map AND get_input_embeddings should
+        # consult the embedding module, not the first dict entry. This is the key
+        # difference from the insertion-order heuristic — covers the multimodal /
+        # encoder-decoder case where the first map entry is the vision tower.
+        embed = torch.nn.Embedding(10, 4)
+        embed = embed.to("cpu")
+
+        class Stub:
+            hf_device_map = {"vision_tower.stuff": 1, "language_model.embed_tokens": "cpu"}
+
+            def get_input_embeddings(self):
+                return embed
+
+        result = find_embedding_device(Stub())
+        assert result is not None
+        assert result.type == "cpu"
+
+    def test_falls_back_to_first_entry_when_get_input_embeddings_unavailable(self):
+        class Stub:
+            hf_device_map = {"embed_tokens": "cpu", "layers.0": "cpu"}
+
+        assert find_embedding_device(Stub()) == torch.device("cpu")
+
+    def test_handles_int_device_ids_in_fallback(self):
+        class Stub:
+            hf_device_map = {"embed_tokens": 0, "layers.0": 1}
+
+        result = find_embedding_device(Stub())
+        assert result is not None
+        assert result.type == "cuda"
+        assert result.index == 0
+
+    def test_handles_get_input_embeddings_returning_none(self):
+        class Stub:
+            hf_device_map = {"embed_tokens": "cpu"}
+
+            def get_input_embeddings(self):
+                return None
+
+        assert find_embedding_device(Stub()) == torch.device("cpu")
+
+
+class TestCountUniqueDevices:
+    def test_no_map_returns_1(self):
+        class Stub:
+            pass
+
+        assert count_unique_devices(Stub()) == 1
+
+    def test_counts_unique_values(self):
+        class Stub:
+            hf_device_map = {"a": 0, "b": 0, "c": 1, "d": 1, "e": 2}
+
+        assert count_unique_devices(Stub()) == 3
+
+
+class TestBootParamValidation:
+    def test_device_and_device_map_mutually_exclusive(self):
+        with pytest.raises(ValueError, match="mutually exclusive"):
+            TransformerBridge.boot_transformers("gpt2", device="cpu", device_map="auto")
+
+    def test_preloaded_with_device_map_rejected(self, gpt2_bridge):
+        # Passing both hf_model= and device_map/n_devices is ambiguous — the device_map
+        # would be silently ignored. We raise so the caller isn't surprised.
+        with pytest.raises(ValueError, match="only supported when the bridge loads"):
+            TransformerBridge.boot_transformers(
+                "gpt2", hf_model=gpt2_bridge.original_model, device_map="auto"
+            )
+
+    def test_preloaded_with_n_devices_rejected(self, gpt2_bridge):
+        with pytest.raises(ValueError, match="only supported when the bridge loads"):
+            TransformerBridge.boot_transformers(
+                "gpt2", hf_model=gpt2_bridge.original_model, n_devices=2
+            )
+
+
+class TestSingleDevicePathUnchanged:
+    def test_cpu_load_default_unchanged(self, gpt2_bridge):
+        # If any of our changes broke the baseline path, existing bridge tests would
+        # catch it too — this is a smoke check that n_devices stays 1 on the default path.
+        assert gpt2_bridge.cfg.n_devices == 1
+        assert gpt2_bridge.cfg.device is not None
+
+
+class TestToMethodGuardsMultiDevice:
+    def test_to_warns_and_drops_device_when_n_devices_gt_1(self, gpt2_bridge):
+        # Simulate a dispatched model by bumping n_devices — we don't need multi-GPU
+        # hardware to verify the .to() guard path.
+        original_n_devices = gpt2_bridge.cfg.n_devices
+        gpt2_bridge.cfg.n_devices = 2
+        try:
+            with pytest.warns(UserWarning, match="ignored"):
+                gpt2_bridge.to("cpu")
+            assert next(gpt2_bridge.original_model.parameters()).device.type == "cpu"
+        finally:
+            gpt2_bridge.cfg.n_devices = original_n_devices
+
+    def test_to_still_honors_dtype_under_multi_device(self, gpt2_bridge):
+        original_n_devices = gpt2_bridge.cfg.n_devices
+        original_dtype = next(gpt2_bridge.original_model.parameters()).dtype
+        gpt2_bridge.cfg.n_devices = 2
+        try:
+            with pytest.warns(UserWarning, match="ignored"):
+                gpt2_bridge.to("cpu", torch.float64)
+            assert next(gpt2_bridge.original_model.parameters()).dtype == torch.float64
+        finally:
+            gpt2_bridge.cfg.n_devices = original_n_devices
+            gpt2_bridge.original_model.to(original_dtype)
+
+
+class TestRunWithCacheGuardsMultiDevice:
+    def test_run_with_cache_device_arg_warns_under_multi_device(self, gpt2_bridge):
+        original_n_devices = gpt2_bridge.cfg.n_devices
+        gpt2_bridge.cfg.n_devices = 2
+        try:
+            with pytest.warns(UserWarning, match="ignored"):
+                gpt2_bridge.run_with_cache(torch.tensor([[1, 2, 3]]), device="cpu")
+        finally:
+            gpt2_bridge.cfg.n_devices = original_n_devices
+
+
+class TestStackedWeightsHandleCrossDevice:
+    def test_stack_gathers_across_devices(self, gpt2_bridge):
+        # Fake multi-device state by flipping cfg.n_devices. The GPT-2 bridge's weights
+        # still live on CPU, so gathering to cfg.device (also CPU) is a no-op — but the
+        # code path we care about (the `if n_devices > 1` branch) is exercised.
+        original_n_devices = gpt2_bridge.cfg.n_devices
+        gpt2_bridge.cfg.n_devices = 2
+        try:
+            # None of these should raise, even with n_devices>1.
+            W_Q = gpt2_bridge.W_Q
+            W_K = gpt2_bridge.W_K
+            W_V = gpt2_bridge.W_V
+            W_O = gpt2_bridge.W_O
+            assert W_Q.shape[0] == gpt2_bridge.cfg.n_layers
+            assert W_K.shape[0] == gpt2_bridge.cfg.n_layers
+            assert W_V.shape[0] == gpt2_bridge.cfg.n_layers
+            assert W_O.shape[0] == gpt2_bridge.cfg.n_layers
+        finally:
+            gpt2_bridge.cfg.n_devices = original_n_devices
+
+    def test_accumulated_bias_handles_cross_device(self, gpt2_bridge):
+        original_n_devices = gpt2_bridge.cfg.n_devices
+        gpt2_bridge.cfg.n_devices = 2
+        try:
+            # Exercises the .to(accumulated.device) branch without requiring real GPUs.
+            bias = gpt2_bridge.accumulated_bias(layer=gpt2_bridge.cfg.n_layers - 1)
+            assert bias.shape == (gpt2_bridge.cfg.d_model,)
+        finally:
+            gpt2_bridge.cfg.n_devices = original_n_devices
+
+
+# ---------- Multi-GPU tests (require real hardware) ----------
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires 2+ CUDA devices")
+class TestMultiDeviceIntegration:
+    def test_n_devices_matches_single_device_logits(self):
+        single = TransformerBridge.boot_transformers("gpt2", device="cuda:0")
+        multi = TransformerBridge.boot_transformers("gpt2", n_devices=2)
+
+        assert multi.cfg.n_devices == 2
+        assert single.cfg.n_devices == 1
+
+        tokens = torch.tensor([[1, 2, 3, 4]])
+        logits_single = single(tokens).to("cpu")
+        logits_multi = multi(tokens).to("cpu")
+        assert torch.allclose(logits_single, logits_multi, atol=1e-4, rtol=1e-4)
+
+    def test_parameters_distributed_across_devices(self):
+        bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2)
+        cuda_indices = {
+            p.device.index for p in bridge.original_model.parameters() if p.device.type == "cuda"
+        }
+        assert cuda_indices == {0, 1}
+
+    def test_generate_works_with_multi_device(self):
+        bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2)
+        out = bridge.generate("Hello", max_new_tokens=3, do_sample=False)
+        assert isinstance(out, str)
+        assert len(out) > len("Hello")
+
+    def test_stacked_weights_work_across_devices(self):
+        # Real multi-device exercise of _stack_block_params (no spoofed n_devices).
+        bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2)
+        W_Q = bridge.W_Q
+        assert W_Q.shape[0] == bridge.cfg.n_layers
+        # After stacking, all elements should be on cfg.device (the embedding device).
+        assert bridge.cfg.device is not None
+        assert W_Q.device == torch.device(bridge.cfg.device)
+
+    def test_preloaded_device_map_model(self):
+        from transformers import AutoModelForCausalLM
+
+        hf_model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
+        bridge = TransformerBridge.boot_transformers("gpt2", hf_model=hf_model)
+        assert bridge.cfg.n_devices >= 1
+        assert bridge.cfg.device is not None
diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py
index d0601aa63..7c9f7b915 100644
--- a/transformer_lens/model_bridge/bridge.py
+++ b/transformer_lens/model_bridge/bridge.py
@@ -169,6 +169,9 @@ def boot_transformers(
         trust_remote_code: bool = False,
         model_class: Optional[type] = None,
         hf_model: Optional[Any] = None,
+        device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
+        n_devices: Optional[int] = None,
+        max_memory: Optional[Dict[Union[str, int], str]] = None,
         n_ctx: Optional[int] = None,
     ) -> "TransformerBridge":
         """Boot a model from HuggingFace (alias for sources.transformers.boot).
@@ -181,7 +184,8 @@
         Args:
             model_name: The name of the model to load.
             hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
-            device: The device to use. If None, will be determined automatically.
+            device: The device to use. If None, will be determined automatically. Mutually exclusive
+                with ``device_map``.
             dtype: The dtype to use for the model.
             tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created.
             load_weights: If False, load model without weights (on meta device) for config inspection only.
@@ -190,7 +194,17 @@
                 auto-detected class (e.g., BertForNextSentencePrediction).
             hf_model: Optional pre-loaded HuggingFace model to use instead of loading one.
                 Useful for models loaded with custom configurations (e.g., quantization via
-                BitsAndBytesConfig). When provided, load_weights is ignored.
+                BitsAndBytesConfig). When provided, load_weights is ignored. If the pre-loaded
+                model was built with a ``device_map``, ``cfg.device`` and ``cfg.n_devices`` are
+                derived from its ``hf_device_map`` automatically.
+            device_map: HuggingFace-style device map for multi-GPU inference. Pass ``"auto"``,
+                ``"balanced"``, ``"sequential"``, or an explicit ``{submodule_path: device}`` dict.
+                Mutually exclusive with ``device``.
+            n_devices: Convenience shortcut: split the model across this many CUDA devices.
+                Translated internally to ``device_map="balanced"`` plus a ``max_memory`` dict over
+                devices 0..n_devices-1. Requires CUDA with at least this many visible devices.
+            max_memory: Optional per-device memory budget, passed through to HF's dispatcher.
+                Only used when ``device_map`` or ``n_devices`` is in effect.
             n_ctx: Optional context length override. Writes to the appropriate HF config field
                 for this model automatically (callers don't need to know the field name). Warns if
                 larger than the model's default context length.
@@ -210,6 +224,9 @@
                 trust_remote_code=trust_remote_code,
                 model_class=model_class,
                 hf_model=hf_model,
+                device_map=device_map,
+                n_devices=n_devices,
+                max_memory=max_memory,
                 n_ctx=n_ctx,
             )

@@ -1109,6 +1126,12 @@ def _stack_block_params(
             if reshape_fn is not None:
                 w = reshape_fn(w)
             weights.append(w)
+        # Under a device_map split, per-block tensors live on different devices.
+        # torch.stack requires a common device; gather onto cfg.device (the embedding /
+        # input device — a natural "home" for cross-layer reductions).
+        if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None:
+            target_device = torch.device(self.cfg.device)
+            weights = [w.to(target_device) for w in weights]
         return torch.stack(weights, dim=0)

     def _reshape_qkv(self, w: torch.Tensor) -> torch.Tensor:
@@ -1314,17 +1337,17 @@ def accumulated_bias(
             block = self.blocks[i]
             b_O = self._get_block_variant_bias(block)
             if b_O is not None:
-                accumulated = accumulated + b_O
+                accumulated = accumulated + b_O.to(accumulated.device)
             if include_mlp_biases and "mlp" in block._modules:
                 b_out = getattr(block.mlp, "b_out", None)
                 if b_out is not None:
-                    accumulated = accumulated + b_out
+                    accumulated = accumulated + b_out.to(accumulated.device)
         if mlp_input:
             assert layer < self.cfg.n_layers, "Cannot include attn_bias from beyond the final layer"
             block = self.blocks[layer]
             b_O = self._get_block_variant_bias(block)
             if b_O is not None:
-                accumulated = accumulated + b_O
+                accumulated = accumulated + b_O.to(accumulated.device)
         return accumulated

     def all_composition_scores(self, mode: str) -> CompositionScores:
@@ -1348,6 +1371,10 @@ def _stack(attr_path: str, reshape_fn: Optional[Callable] = None) -> torch.Tensor:
                 if reshape_fn is not None:
                     w = reshape_fn(w)
                 weights.append(w)
+            # See _stack_block_params: gather per-block tensors onto cfg.device when split.
+            if getattr(self.cfg, "n_devices", 1) > 1 and weights and self.cfg.device is not None:
+                target_device = torch.device(self.cfg.device)
+                weights = [w.to(target_device) for w in weights]
             return torch.stack(weights, dim=0)

         W_V = _stack("attn.W_V", self._reshape_qkv)
@@ -1954,12 +1981,24 @@ def stop_hook(tensor: torch.Tensor, *, hook: Any) -> torch.Tensor:
                 hooks.append((hook_dict[block_hook_name], block_hook_name))
         filtered_kwargs = kwargs.copy()
         if cache_device is not None:
-            self.original_model = self.original_model.to(cache_device)
-            if processed_args and isinstance(processed_args[0], torch.Tensor):
-                processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:])
-            for key, value in filtered_kwargs.items():
-                if isinstance(value, torch.Tensor):
-                    filtered_kwargs[key] = value.to(cache_device)
+            if getattr(self.cfg, "n_devices", 1) > 1:
+                # Moving a dispatched model to a single device collapses accelerate's
+                # split and breaks its routing hooks. The cache will stay spread across
+                # the per-layer devices; callers can .to(cache_device) on cache entries
+                # after the fact if they need a single-device cache.
+                warnings.warn(
+                    f"run_with_cache(device={cache_device!r}) ignored: model is dispatched "
+                    f"across {self.cfg.n_devices} devices via device_map. Cached activations "
+                    "will remain on their per-layer devices.",
+                    stacklevel=2,
+                )
+            else:
+                self.original_model = self.original_model.to(cache_device)
+                if processed_args and isinstance(processed_args[0], torch.Tensor):
+                    processed_args = [processed_args[0].to(cache_device)] + list(processed_args[1:])
+                for key, value in filtered_kwargs.items():
+                    if isinstance(value, torch.Tensor):
+                        filtered_kwargs[key] = value.to(cache_device)
         try:
             if "output_attentions" not in filtered_kwargs:
                 filtered_kwargs["output_attentions"] = True
@@ -3070,12 +3109,30 @@ def to(self, *args, **kwargs) -> "TransformerBridge":
         if "dtype" in kwargs:
             target_dtype = kwargs["dtype"]

+        # Moving a multi-device (device_map-dispatched) model to a single device would
+        # collapse the split and break accelerate's hook routing. Warn and drop the
+        # device move; still honor dtype changes.
+        if target_device is not None and getattr(self.cfg, "n_devices", 1) > 1:
+            warnings.warn(
+                f"TransformerBridge.to({target_device!r}) ignored: model is dispatched "
+                f"across {self.cfg.n_devices} devices via device_map. Reload with "
+                "device=... (and no device_map/n_devices) to move to a single device.",
+                stacklevel=2,
+            )
+            target_device = None
+
         if target_device is not None:
             move_to_and_update_config(self, target_device, print_details)
         if target_dtype is not None:
             move_to_and_update_config(self, target_dtype, print_details)

-        # Move the original model with all original args/kwargs (with print_details removed)
+        # Move the original model with all original args/kwargs (with print_details removed).
+        # When we've nulled target_device for multi-GPU safety, strip device args so the
+        # underlying module isn't moved either.
+        if target_device is None and (len(args) > 0 or "device" in kwargs):
+            kwargs.pop("device", None)
+            # Filter positional args: drop devices/strings, keep dtypes.
+            args = tuple(a for a in args if not isinstance(a, (torch.device, str)))
         self.original_model = self.original_model.to(*args, **kwargs)

         return self
diff --git a/transformer_lens/model_bridge/sources/transformers.py b/transformer_lens/model_bridge/sources/transformers.py
index 99b90a968..afc002eee 100644
--- a/transformer_lens/model_bridge/sources/transformers.py
+++ b/transformer_lens/model_bridge/sources/transformers.py
@@ -287,13 +287,19 @@ def boot(
     model_class: Any | None = None,
     hf_model: Any | None = None,
     n_ctx: int | None = None,
+    # Experimental – not yet fully tested on multi-GPU setups.
+    # Use at your own risk and report any issues here: https://github.com/TransformerLensOrg/TransformerLens/issues
+    device_map: str | dict[str, str | int] | None = None,
+    n_devices: int | None = None,
+    max_memory: dict[str | int, str] | None = None,
 ) -> TransformerBridge:
     """Boot a model from HuggingFace.

     Args:
         model_name: The name of the model to load.
         hf_config_overrides: Optional overrides applied to the HuggingFace config before model load.
-        device: The device to use. If None, will be determined automatically.
+        device: The device to use. If None, will be determined automatically. Mutually exclusive
+            with ``device_map``.
         dtype: The dtype to use for the model.
         tokenizer: Optional pre-initialized tokenizer to use; if not provided one will be created.
         load_weights: If False, load model without weights (on meta device) for config inspection only.
@@ -303,6 +309,12 @@
         hf_model: Optional pre-loaded HuggingFace model to use instead of loading one.
             Useful for models loaded with custom configurations (e.g., quantization via
            BitsAndBytesConfig). When provided, load_weights is ignored.
+        device_map: HuggingFace-style device map (``"auto"``, ``"balanced"``, dict, etc.) for
+            multi-GPU inference. Passed straight to ``from_pretrained``. Mutually exclusive
+            with ``device``.
+        n_devices: Convenience: split the model across this many CUDA devices (translated to a
+            ``device_map``/``max_memory`` pair internally). Requires CUDA with at least this many visible devices.
+        max_memory: Optional per-device memory budget for HF's dispatcher.
         n_ctx: Optional context length override. The bridge normally uses the model's documented
             max context from the HF config. Setting this writes to whichever HF field the model
             uses (n_positions / max_position_embeddings / etc.), so callers don't need to know
@@ -430,9 +442,49 @@
     if attn_logit_softcapping is not None:
         bridge_config.attn_scores_soft_cap = float(attn_logit_softcapping)
     adapter = ArchitectureAdapterFactory.select_architecture_adapter(bridge_config)
-    if device is None:
-        device = get_device()
-    adapter.cfg.device = str(device)
+    # Pre-loaded models carry their own weight placement (possibly set by the caller via
+    # device_map). Passing device_map / n_devices / max_memory alongside hf_model= is
+    # ambiguous and would silently be ignored, so fail loudly.
+    if hf_model is not None and (
+        device_map is not None or n_devices is not None or max_memory is not None
+    ):
+        raise ValueError(
+            "device_map / n_devices / max_memory are only supported when the bridge loads "
+            "the HF model itself. When passing hf_model=..., apply device_map via "
+            "AutoModel.from_pretrained before handing the model to the bridge."
+        )
+    # Stateful/SSM (e.g. Mamba) models keep a per-layer recurrent cache that must live on
+    # that layer's device. The bridge currently allocates the stateful cache on a single
+    # cfg.device, so cross-device splits would silently misplace the cache. Block this
+    # combination until a v2 addresses per-layer stateful cache placement.
+    if (n_devices is not None and n_devices > 1) or device_map is not None:
+        if getattr(bridge_config, "is_stateful", False):
+            raise ValueError(
+                "Multi-device splits are not yet supported for stateful (SSM / Mamba) "
+                "architectures: the stateful cache allocation is single-device. "
+                "Load on one device, or wait for v2 support."
+            )
+    # Resolve device_map before defaulting `device` — the two are mutually exclusive, and
+    # the resolver raises on conflict. If n_devices>1 is passed, it's translated into a
+    # device_map + max_memory pair here so downstream code only needs to check the
+    # resolved values.
+    from transformer_lens.utilities.multi_gpu import (
+        count_unique_devices,
+        find_embedding_device,
+        resolve_device_map,
+    )
+
+    resolved_device_map, resolved_max_memory = resolve_device_map(
+        n_devices, device_map, device, max_memory
+    )
+    if resolved_device_map is None:
+        if device is None:
+            device = get_device()
+        adapter.cfg.device = str(device)
+    else:
+        # cfg.device will be set from hf_device_map after the model is loaded.
+        # Provisionally keep it None; find_embedding_device fills it in below.
+        adapter.cfg.device = None
     if model_class is None:
         model_class = get_hf_model_class_for_architecture(architecture)
     # Ensure pad_token_id exists (v5 raises AttributeError if missing)
@@ -447,6 +499,10 @@
         model_kwargs["token"] = _hf_token
     if trust_remote_code:
         model_kwargs["trust_remote_code"] = True
+    if resolved_device_map is not None:
+        model_kwargs["device_map"] = resolved_device_map
+    if resolved_max_memory is not None:
+        model_kwargs["max_memory"] = resolved_max_memory
     if hasattr(adapter.cfg, "attn_implementation") and adapter.cfg.attn_implementation is not None:
         model_kwargs["attn_implementation"] = adapter.cfg.attn_implementation
     else:
@@ -482,12 +538,38 @@
                 f"weight mismatch."
             ) from e
         raise
-    if device is not None:
+    # Skip explicit .to(device) when accelerate has placed weights via device_map.
+    if resolved_device_map is None and device is not None:
         hf_model = hf_model.to(device)
     # Cast params to dtype; preserve float32 buffers (e.g., RotaryEmbedding.inv_freq)
     for param in hf_model.parameters():
         if param.is_floating_point() and param.dtype != dtype:
             param.data = param.data.to(dtype=dtype)
+    # Derive cfg.device / cfg.n_devices from hf_device_map when present. This covers:
+    #   - fresh loads with a resolved device_map (set above)
+    #   - pre-loaded hf_model that the caller dispatched themselves (e.g., device_map="auto")
+    hf_device_map_post = getattr(hf_model, "hf_device_map", None)
+    if hf_device_map_post:
+        # Pre-loaded path can still smuggle CPU/disk offload in; validate here too.
+        offload_values = {str(v).lower() for v in hf_device_map_post.values() if isinstance(v, str)}
+        forbidden = offload_values & {"cpu", "disk", "meta"}
+        if forbidden and ((n_devices is not None and n_devices > 1) or device_map is not None):
+            # Fresh-load path: we set the device_map ourselves, so this shouldn't happen —
+            # but if the user asked for n_devices>1 and somehow got CPU offload, surface it.
+            raise ValueError(
+                f"hf_device_map contains unsupported offload targets: {sorted(forbidden)}. "
+                "v1 multi-device support is GPU-only."
+            )
+        embedding_device = find_embedding_device(hf_model)
+        if embedding_device is not None:
+            adapter.cfg.device = str(embedding_device)
+        adapter.cfg.n_devices = count_unique_devices(hf_model)
+    elif adapter.cfg.device is None:
+        # Pre-loaded single-device model with no hf_device_map — fall back to first param.
+        try:
+            adapter.cfg.device = str(next(hf_model.parameters()).device)
+        except StopIteration:
+            adapter.cfg.device = "cpu"
     # #7: Verify the n_ctx override actually took effect on the loaded model.
     # If HF's config class silently dropped or normalized the value, warn so
     # the user doesn't get misled into thinking longer sequences are supported.
diff --git a/transformer_lens/utilities/__init__.py b/transformer_lens/utilities/__init__.py
index de5e79db6..56975f90d 100644
--- a/transformer_lens/utilities/__init__.py
+++ b/transformer_lens/utilities/__init__.py
@@ -50,10 +50,13 @@
 # Re-export multi-GPU helpers here (devices.py must not import multi_gpu directly)
 from .multi_gpu import (
     calculate_available_device_cuda_memory,
+    count_unique_devices,
     determine_available_memory_for_available_devices,
+    find_embedding_device,
     get_best_available_cuda_device,
     get_best_available_device,
     get_device_for_block_index,
+    resolve_device_map,
     sort_devices_based_on_available_memory,
 )
 from .slice import Slice, SliceInput
diff --git a/transformer_lens/utilities/multi_gpu.py b/transformer_lens/utilities/multi_gpu.py
index 8b5afaf58..0584604af 100644
--- a/transformer_lens/utilities/multi_gpu.py
+++ b/transformer_lens/utilities/multi_gpu.py
@@ -5,7 +5,7 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

 import torch
@@ -141,3 +141,110 @@ def get_device_for_block_index(
         return device
     device_index = (device.index or 0) + (index // layers_per_device)
     return torch.device(device.type, device_index)
+
+
+_UNSUPPORTED_DEVICE_MAP_VALUES = {"cpu", "disk", "meta"}
+"""v1 multi-GPU scope is GPU-only. CPU offload and disk offload cause dtype-cast loops to
+silently miss offloaded params (meta tensors), and cross-layer hook routing has different
+semantics. Reject them explicitly until a v2 scopes those paths."""
+
+
+def _validate_device_map_values(
+    device_map: Union[str, Dict[str, Union[str, int]]],
+) -> None:
+    """Reject CPU / disk / meta values in a user-supplied device_map dict."""
+    if isinstance(device_map, str):
+        # "balanced_low_0" is fine — still GPU-only; "cpu" as a string-form device_map
+        # would tell HF to put everything on CPU, which is single-device and meaningless
+        # as a multi-GPU config. We allow strings through; HF will validate them.
+        return
+    for key, value in device_map.items():
+        normalized = str(value).lower() if isinstance(value, str) else None
+        if normalized in _UNSUPPORTED_DEVICE_MAP_VALUES:
+            raise ValueError(
+                f"device_map[{key!r}]={value!r} is not supported. Multi-device bridge "
+                f"support is GPU-only in v1; CPU / disk / meta offload routes are excluded."
+            )
+
+
+def resolve_device_map(
+    n_devices: Optional[int],
+    device_map: Optional[Union[str, Dict[str, Union[str, int]]]],
+    device: Optional[Union[str, torch.device]],
+    max_memory: Optional[Dict[Union[str, int], str]] = None,
+) -> Tuple[Optional[Union[str, Dict[str, Union[str, int]]]], Optional[Dict[Union[str, int], str]]]:
+    """Resolve ``n_devices`` / ``device_map`` / ``device`` into HF ``from_pretrained`` kwargs.
+
+    Returns a ``(device_map, max_memory)`` tuple ready to pass into ``model_kwargs``.
+
+    Semantics:
+      - Explicit ``device_map`` wins; it's validated and passed through unchanged (a
+        user-provided ``max_memory`` is passed through too).
+      - ``n_devices=None`` or ``1``: returns ``(None, max_memory)`` — the single-device path.
+      - ``n_devices > 1``: returns ``("balanced", {0: "auto", ..., n-1: "auto"})``.
+        ``"balanced"`` is accelerate's string directive for balanced layer dispatch;
+        the ``max_memory`` dict caps visibility to exactly ``n_devices`` GPUs.
+    """
+    if device_map is not None and device is not None:
+        raise ValueError("device and device_map are mutually exclusive — pass one.")
+    if device_map is not None:
+        _validate_device_map_values(device_map)
+        return device_map, max_memory
+    if n_devices is None or n_devices <= 1:
+        return None, max_memory
+    if not torch.cuda.is_available():
+        raise ValueError(f"n_devices={n_devices} requires CUDA, which is not available.")
+    if torch.cuda.device_count() < n_devices:
+        raise ValueError(
+            f"n_devices={n_devices} but only {torch.cuda.device_count()} CUDA devices present."
+        )
+    resolved_max_memory: Dict[Union[str, int], str] = (
+        dict(max_memory) if max_memory else {i: "auto" for i in range(n_devices)}
+    )
+    return "balanced", resolved_max_memory
+
+
+def find_embedding_device(hf_model: Any) -> Optional[torch.device]:
+    """Return the device that input tokens should be placed on for a dispatched HF model.
+
+    When a model is loaded with ``device_map``, accelerate populates ``hf_device_map``
+    and inserts pre/post-forward hooks that route activations. Input tensors must land on
+    the device of whichever module first *consumes* them — the input embedding. Returns
+    ``None`` for single-device models (no ``hf_device_map`` set).
+
+    Resolves via ``hf_model.get_input_embeddings()`` rather than dict insertion order to
+    cover encoder-decoder / multimodal / audio architectures where the first entry in
+    ``hf_device_map`` is not the text-token embedding (e.g. the vision tower on LLaVA).
+    """
+    hf_device_map = getattr(hf_model, "hf_device_map", None)
+    if not hf_device_map:
+        return None
+    # Preferred: ask the model for its input embedding module and read its device.
+    get_input_embeddings = getattr(hf_model, "get_input_embeddings", None)
+    if callable(get_input_embeddings):
+        try:
+            embed_module = get_input_embeddings()
+        except (AttributeError, NotImplementedError):
+            embed_module = None
+        if embed_module is not None:
+            try:
+                param = next(embed_module.parameters())
+                return param.device
+            except StopIteration:
+                pass
+    # Fallback: first entry in hf_device_map. Less reliable but better than nothing.
+    first_device = next(iter(hf_device_map.values()))
+    if isinstance(first_device, int):
+        return torch.device("cuda", first_device)
+    return torch.device(first_device)
+
+
+def count_unique_devices(hf_model: Any) -> int:
+    """Count the number of unique devices across a dispatched HF model's ``hf_device_map``.
+
+    Returns 1 if the model has no ``hf_device_map`` (single-device load).
+    """
+    hf_device_map = getattr(hf_model, "hf_device_map", None)
+    if not hf_device_map:
+        return 1
+    return len(set(hf_device_map.values()))
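
For reference, a minimal usage sketch of the arguments this diff adds (assuming a host with two visible CUDA devices; "gpt2" matches the checkpoint used in the tests above, and the rest is illustrative rather than part of the change):

    import torch
    from transformer_lens.model_bridge import TransformerBridge

    # Split the model across two GPUs; per resolve_device_map this becomes
    # device_map="balanced" plus a max_memory dict covering devices 0 and 1.
    bridge = TransformerBridge.boot_transformers("gpt2", n_devices=2)

    # Or hand HF an explicit device_map (and optional max_memory) directly.
    bridge = TransformerBridge.boot_transformers(
        "gpt2", device_map="auto", max_memory={0: "10GiB", 1: "10GiB"}
    )

    # Under a multi-device split, cached activations stay on their per-layer devices.
    logits, cache = bridge.run_with_cache(torch.tensor([[1, 2, 3, 4]]))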