Adapt omniquant to transformers 4.41.0 #107

Description

@zijunx

We adapted OmniQuant to transformers 4.41.0 because of hardware restrictions. However, the interface changed in transformers/models/llama/modeling_llama.py (LlamaRotaryEmbedding, around line 95):
```python
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See huggingface/transformers#29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
```

Since the forward signature changed, we updated the rotary-embedding call in the forward of OmniQuant/models/int_llama_layer.py from

`cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)`

to

`cos, sin = self.rotary_emb(value_states, position_ids=position_ids)`
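
For reference, a minimal before/after sketch of that call site. The surrounding `apply_rotary_pos_emb` line and variable names are assumed to follow the stock LlamaAttention layout and may differ slightly from the actual int_llama_layer.py:

```python
# Sketch only: assumes the attention forward in int_llama_layer.py mirrors the
# stock LlamaAttention; names besides rotary_emb/position_ids are illustrative.

# Older interface (as in OmniQuant's copy of the modeling code): rotary_emb is
# driven by seq_len, and apply_rotary_pos_emb gathers cos/sin by position_ids.
# cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# transformers 4.41 interface: rotary_emb takes position_ids and already returns
# per-position cos/sin, so apply_rotary_pos_emb no longer needs position_ids.
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
```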

Then a new error occurs:

```
Attention mask should be of size (1, 1, 2048, 2048), but is torch.Size([1, 1, 2048, 2049])
```

Under transformers 4.37 the attention mask size is (1, 1, 2048, 2048). Both shapes are reported from layer [0]. Do you have any idea why this happens?
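
Not from the original report, but one way to localize the extra column is to check whether the 2049-wide mask already reaches the first decoder layer or is produced inside the quantized attention. A small diagnostic sketch, assuming a standard LlamaForCausalLM-style object; `model` and the hook name are illustrative:

```python
# Hypothetical diagnostic (not OmniQuant code): print the mask/position_ids
# shapes that actually reach layer 0's attention during one calibration batch.
def inspect_layer0(module, args, kwargs):
    mask = kwargs.get("attention_mask")
    pos = kwargs.get("position_ids")
    print("attention_mask:", None if mask is None else tuple(mask.shape))
    print("position_ids:", None if pos is None else tuple(pos.shape))

handle = model.model.layers[0].self_attn.register_forward_pre_hook(
    inspect_layer0, with_kwargs=True
)
# ... run one 2048-token calibration batch, then:
handle.remove()
```

If the mask arriving at layer 0 is already 2049 wide, the off-by-one comes from how the mask / KV length is prepared before the layers run, rather than from the rotary-embedding change itself.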
