Proposal
I've seen other issues mention this in the context of loading quantized models, so I'm wondering whether there is any plan for local model loading, i.e. an easy interface for loading local models and running interpretability analysis on them.
Since this is a well-known current limitation, I'm sharing basic examples below, but let me know if you need anything else from my end.
Motivation
I can load Llama-2-7b without any issue, but I am unable to load its compressed counterparts.
Pitch
Working example - the base model loads in both Hugging Face and HookedTransformer:
from transformers import AutoModel
import transformer_lens

model_name = "meta-llama/Llama-2-7b-chat-hf"
hf_model = AutoModel.from_pretrained(model_name)  # loads fine
print(hf_model)
hooked_model = transformer_lens.HookedTransformer.from_pretrained(model_name)  # also fine
print(hooked_model)
The HF model is:
LlamaModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
(rotary_emb): LlamaRotaryEmbedding()
)
and the Hooked model is:
HookedTransformer(
(embed): Embed()
(hook_embed): HookPoint()
(blocks): ModuleList(
(0-31): 32 x TransformerBlock(
(ln1): RMSNormPre(
(hook_scale): HookPoint()
(hook_normalized): HookPoint()
)
(ln2): RMSNormPre(
(hook_scale): HookPoint()
(hook_normalized): HookPoint()
)
(attn): Attention(
(hook_k): HookPoint()
(hook_q): HookPoint()
(hook_v): HookPoint()
(hook_z): HookPoint()
(hook_attn_scores): HookPoint()
(hook_pattern): HookPoint()
(hook_result): HookPoint()
(hook_rot_k): HookPoint()
(hook_rot_q): HookPoint()
)
(mlp): GatedMLP(
(hook_pre): HookPoint()
(hook_pre_linear): HookPoint()
(hook_post): HookPoint()
)
(hook_attn_in): HookPoint()
(hook_q_input): HookPoint()
(hook_k_input): HookPoint()
(hook_v_input): HookPoint()
(hook_mlp_in): HookPoint()
(hook_attn_out): HookPoint()
(hook_mlp_out): HookPoint()
(hook_resid_pre): HookPoint()
(hook_resid_mid): HookPoint()
(hook_resid_post): HookPoint()
)
)
(ln_final): RMSNormPre(
(hook_scale): HookPoint()
(hook_normalized): HookPoint()
)
(unembed): Unembed()
)
Pruned model
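For completeness, the calls that produce the output below presumably look like this (the local path is a placeholder, matching the error message further down); the AutoModel call succeeds, while the HookedTransformer call raises:

from transformers import AutoModel
import transformer_lens

local_path = "/path/to/compressed_model/"  # placeholder for the pruned checkpoint

# Succeeds: transformers loads the local checkpoint without complaint
hf_model = AutoModel.from_pretrained(local_path)
print(hf_model)

# Fails: from_pretrained only resolves official model names
hooked_model = transformer_lens.HookedTransformer.from_pretrained(local_path)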
LlamaModel(
(embed_tokens): Embedding(32000, 4096)
(layers): ModuleList(
(0-31): 32 x LlamaDecoderLayer(
(self_attn): LlamaSdpaAttention(
(q_proj): Linear(in_features=4096, out_features=4096, bias=False)
(k_proj): Linear(in_features=4096, out_features=4096, bias=False)
(v_proj): Linear(in_features=4096, out_features=4096, bias=False)
(o_proj): Linear(in_features=4096, out_features=4096, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
(up_proj): Linear(in_features=4096, out_features=11008, bias=False)
(down_proj): Linear(in_features=11008, out_features=4096, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
(rotary_emb): LlamaRotaryEmbedding()
)
but HookedTransformer.from_pretrained fails with:
ValueError: /path/to/compressed_model/ not found. Valid official model names (excl aliases) ....
It's a similar story for quantized models.
Basically, what I'm proposing is an easy interface/class that lets users bring their own custom models (at least those that follow the same architecture as an officially supported model).
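As one possible shape for this interface: HookedTransformer.from_pretrained already accepts hf_model and tokenizer arguments (normally used for fine-tuned variants of supported models), so something along these lines might already work for local checkpoints that keep the official architecture. This is an untested sketch, the local path is a placeholder, and it would presumably break wherever pruning or quantization changes tensor shapes:

from transformers import AutoModelForCausalLM, AutoTokenizer
import transformer_lens

local_path = "/path/to/compressed_model/"  # placeholder

# Take weights and tokenizer from the local checkpoint; note the
# causal-LM class, since the weight conversion needs the lm_head.
hf_model = AutoModelForCausalLM.from_pretrained(local_path)
tokenizer = AutoTokenizer.from_pretrained(local_path)

# The official name supplies only the architecture config;
# the actual weights come from the local hf_model.
hooked_model = transformer_lens.HookedTransformer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    hf_model=hf_model,
    tokenizer=tokenizer,
)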