diff --git a/pyproject.toml b/pyproject.toml
index df57acf10..7375af482 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
 [project]
 authors=[{name="Neel Nanda", email="77788841+TransformerLensOrg@users.noreply.github.com"}]
 dependencies=[
-    "accelerate>=0.23.0",  # Needed for Llama Models
+    "accelerate>=0.23.0", # Needed for Llama Models
     "beartype>=0.14.1",
     "better-abc>=0.0.3",
     "datasets>=2.7.1",
@@ -85,13 +85,11 @@
     "-W ignore::beartype.roar.BeartypeDecorHintPep585DeprecationWarning",
 ]
 doctest_optionflags="NORMALIZE_WHITESPACE ELLIPSIS FLOAT_CMP"
-markers=[
-    "slow: marks tests as slow (deselect with '-m \"not slow\"')",
-]
 filterwarnings=[
-    "ignore:pkg_resources is deprecated as an API:DeprecationWarning",
     "ignore:distutils Version classes are deprecated:DeprecationWarning",
+    "ignore:pkg_resources is deprecated as an API:DeprecationWarning",
 ]
+markers=["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
 pythonpath=["."]
 testpaths=["tests", "transformer_lens"] # Only test these directories
 
diff --git a/transformer_lens/pretrained/weight_conversions/nanogpt.py b/transformer_lens/pretrained/weight_conversions/nanogpt.py
index 0a138cdd0..235575861 100644
--- a/transformer_lens/pretrained/weight_conversions/nanogpt.py
+++ b/transformer_lens/pretrained/weight_conversions/nanogpt.py
@@ -29,6 +29,8 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
     if "transformer.ln_f.bias" in old_state_dict:
         bias = True
         new_state_dict["ln_final.b"] = old_state_dict["transformer.ln_f.bias"]
+    else:
+        new_state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab, dtype=cfg.dtype)
 
     for layer in range(cfg.n_layers):
         layer_key = f"transformer.h.{layer}"
@@ -43,6 +45,11 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
             old_state_dict[f"{layer_key}.ln_2.weight"]
         )
 
+        new_state_dict[f"blocks.{layer}.attn.mask"] = torch.tril(
+            torch.ones((cfg.n_ctx, cfg.n_ctx)).bool()
+        )
+        new_state_dict[f"blocks.{layer}.attn.IGNORE"] = torch.tensor(-torch.inf)
+
         W = old_state_dict[f"{layer_key}.attn.c_attn.weight"]
         W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
         W_Q = einops.rearrange(W_Q, "(i h) m->i m h", i=cfg.n_heads)
@@ -84,5 +91,22 @@ def convert_nanogpt_weights(old_state_dict, cfg: HookedTransformerConfig):
             new_state_dict[f"blocks.{layer}.attn.b_O"] = old_state_dict[
                 f"{layer_key}.attn.c_proj.bias"
             ]
+        else:
+            if cfg.d_mlp is None:
+                raise ValueError(
+                    "cfg.d_mlp must be set to convert nanoGPT weights for the no-bias case."
+                )
+            new_state_dict[f"blocks.{layer}.mlp.b_out"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
+            new_state_dict[f"blocks.{layer}.mlp.b_in"] = torch.zeros(cfg.d_mlp, dtype=cfg.dtype)
+            new_state_dict[f"blocks.{layer}.attn.b_Q"] = torch.zeros(
+                cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+            )
+            new_state_dict[f"blocks.{layer}.attn.b_K"] = torch.zeros(
+                cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+            )
+            new_state_dict[f"blocks.{layer}.attn.b_V"] = torch.zeros(
+                cfg.n_heads, cfg.d_head, dtype=cfg.dtype
+            )
+            new_state_dict[f"blocks.{layer}.attn.b_O"] = torch.zeros(cfg.d_model, dtype=cfg.dtype)
 
     return new_state_dict
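For reference, the "slow" marker that the pyproject.toml hunk collapses onto one line is the hook tests use to opt out of the default run. A minimal sketch of how a test would use it; the test name and sleep are illustrative, not from this PR:

import time

import pytest


@pytest.mark.slow  # registered under [tool.pytest.ini_options] markers in pyproject.toml
def test_full_checkpoint_roundtrip():  # hypothetical test, for illustration only
    time.sleep(5)  # stands in for an expensive end-to-end check
    assert True

Such tests are then deselected with pytest -m "not slow", exactly as the marker's description says.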
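The nanogpt.py changes handle checkpoints trained with nanoGPT's bias=False, whose converted state dicts previously lacked the bias, mask, and IGNORE entries that HookedTransformer expects. The sketch below is not part of the PR: it shows one way to load such a checkpoint and check the backfilled entries. The checkpoint path is hypothetical, the "model"/"model_args" layout is assumed from nanoGPT's train.py, and the 4x MLP width is nanoGPT's fixed convention:

import torch

from transformer_lens import HookedTransformerConfig
from transformer_lens.pretrained.weight_conversions.nanogpt import convert_nanogpt_weights

# Assumed: a bias-free nanoGPT checkpoint saved by nanoGPT's train.py.
ckpt = torch.load("out/ckpt.pt", map_location="cpu")  # hypothetical path
args = ckpt["model_args"]
cfg = HookedTransformerConfig(
    n_layers=args["n_layer"],
    n_heads=args["n_head"],
    d_model=args["n_embd"],
    d_head=args["n_embd"] // args["n_head"],
    d_mlp=4 * args["n_embd"],  # nanoGPT's MLP hidden size is always 4 * d_model
    n_ctx=args["block_size"],
    d_vocab=args["vocab_size"],
    act_fn="gelu",
)

new_state_dict = convert_nanogpt_weights(ckpt["model"], cfg)

# The new else branches should have backfilled zero biases...
assert torch.all(new_state_dict["unembed.b_U"] == 0)
for layer in range(cfg.n_layers):
    for name in ("attn.b_Q", "attn.b_K", "attn.b_V", "attn.b_O", "mlp.b_in", "mlp.b_out"):
        assert torch.all(new_state_dict[f"blocks.{layer}.{name}"] == 0)
    # ...plus a lower-triangular causal mask and the -inf masking sentinel.
    mask = new_state_dict[f"blocks.{layer}.attn.mask"]
    assert torch.equal(mask, torch.tril(torch.ones_like(mask)))
    assert new_state_dict[f"blocks.{layer}.attn.IGNORE"] == -torch.inf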