37 changes: 25 additions & 12 deletions convert-falcon-hf-to-gguf.py
@@ -80,7 +80,7 @@ def count_model_parts(dir_model: str) -> int:
with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
hparams = json.load(f)

if hparams["architectures"][0] != "RWForCausalLM":
if hparams["architectures"][0] not in ("RWForCausalLM", "FalconForCausalLM"):
print("Model architecture not supported: " + hparams["architectures"][0])

sys.exit()
@@ -93,19 +93,34 @@ def count_model_parts(dir_model: str) -> int:

print("gguf: get model metadata")

block_count = hparams["n_layer"]
if "n_layer" in hparams:
block_count = hparams["n_layer"]
elif "num_hidden_layers" in hparams:
block_count = hparams["num_hidden_layers"]
else:
print("No block count found")

sys.exit()

if "n_head" in hparams:
n_head = hparams["n_head"]
elif "num_attention_heads" in hparams:
n_head = hparams["num_attention_heads"]
else:
print("No head count found")

sys.exit()

n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1

gguf_writer.add_name("Falcon")
gguf_writer.add_context_length(2048) # not in config.json
gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["n_head"])
if "n_head_kv" in hparams:
gguf_writer.add_head_count_kv(hparams["n_head_kv"])
else:
gguf_writer.add_head_count_kv(1)
gguf_writer.add_head_count(n_head)
gguf_writer.add_head_count_kv(n_head_kv)
gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"])
gguf_writer.add_file_type(ftype)

@@ -133,7 +148,7 @@ def count_model_parts(dir_model: str) -> int:

print("gguf: get gpt2 tokenizer vocab")

vocab_size = len(tokenizer_json["model"]["vocab"])
vocab_size = hparams["vocab_size"]

# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
tokenizer = AutoTokenizer.from_pretrained(dir_model)
@@ -190,10 +205,8 @@ def count_model_parts(dir_model: str) -> int:
tensor_map = gguf.get_tensor_name_map(ARCH,block_count)

# params for qkv transform
n_head = hparams["n_head"]
n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1

head_dim = hparams["hidden_size"] // n_head
parallel_attn = hparams["parallel_attn"]

# tensor info
print("gguf: get tensor metadata")
@@ -228,7 +241,7 @@ def count_model_parts(dir_model: str) -> int:
# in contiguous fashion.
# ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py

if "query_key_value" in name:
if "query_key_value" in name and parallel_attn:
qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
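For reference, the qkv weight re-layout that the new `parallel_attn` guard skips for Falcon-RW models can be exercised on its own. Below is a minimal standalone sketch: the Falcon-7B-style dimensions and the random stand-in for the fused `query_key_value` weight are assumptions, and the trailing `v`/`torch.cat` lines are collapsed out of the hunk above, so they are reconstructed here from the referenced falcon_convert.py and may differ in detail.

import torch

# Assumed Falcon-7B-style dimensions, for illustration only.
n_head = 71
n_head_kv = 1
hidden_size = 4544
head_dim = hidden_size // n_head

# Random stand-in for the fused query_key_value weight from the HF checkpoint.
data = torch.randn((n_head + 2 * n_head_kv) * head_dim, hidden_size)

# Regroup the per-kv-head interleaved [q ... q, k, v] blocks so that all queries
# come first, then all keys, then all values, stored contiguously.
qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)  # reconstructed; collapsed in the hunk
data = torch.cat((q, k, v)).reshape_as(data)                        # reconstructed; collapsed in the hunk

print(data.shape)  # torch.Size([4672, 4544])

For Falcon-RW-1B, `parallel_attn` appears to be false in config.json, so this transform is now skipped and the fused weight is written to GGUF unchanged, which is what the 1B-specific tensor shapes in llama.cpp below account for.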
2 changes: 2 additions & 0 deletions gguf-py/gguf/gguf.py
@@ -144,6 +144,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2",
MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
},
@@ -291,6 +292,7 @@ def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> dict:
tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt
tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf
tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth
tensor_map["transformer.h."+str(i)+".post_attention_layernorm"] = mapped_to # falcon-rw

# Feed-forward up
mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_UP, None)
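For reference, a minimal sketch of how the convert script consumes this mapping (as in `gguf.get_tensor_name_map(ARCH, block_count)` in the hunk further up). The block count of 24 is taken from the Falcon-RW-1B case added to llama.cpp below, and the returned values are assumed to be the per-block `blk.{bid}.*` names.

import gguf

block_count = 24  # Falcon-RW-1B (assumed here; the convert script derives this from config.json)
tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.FALCON, block_count)

# With the new falcon-rw entry, the post-attention layer norm resolves to the
# same GGUF tensor name as the llama-hf spelling of the equivalent weight:
print(tensor_map["transformer.h.0.post_attention_layernorm"])  # e.g. blk.0.ffn_norm
print(tensor_map["model.layers.0.post_attention_layernorm"])   # e.g. blk.0.ffn_norm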
72 changes: 66 additions & 6 deletions llama.cpp
@@ -321,6 +321,7 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
@@ -830,6 +831,7 @@ static llama_state g_state;
// available llama models
enum e_model {
MODEL_UNKNOWN,
MODEL_1B,
MODEL_3B,
MODEL_7B,
MODEL_13B,
@@ -899,15 +901,20 @@ struct llama_layer {
struct ggml_tensor * wk;
struct ggml_tensor * wv;
struct ggml_tensor * wo;
struct ggml_tensor * wo_b;
struct ggml_tensor * wqkv;
struct ggml_tensor * wqkv_b;

// normalization
struct ggml_tensor * ffn_norm;
struct ggml_tensor * ffn_norm_b;

// ff
struct ggml_tensor * w1; // ffn_gate
struct ggml_tensor * w2; // ffn_down
struct ggml_tensor * w3; // ffn_up
struct ggml_tensor * w1; // ffn_gate
struct ggml_tensor * w2; // ffn_down
struct ggml_tensor * w2_b; // ff_down bias
struct ggml_tensor * w3; // ffn_up
struct ggml_tensor * w3_b; // ff_up bias
};

struct llama_kv_cache {
@@ -1524,6 +1531,7 @@ std::string llama_model_ftype_name(enum llama_ftype ftype) {

static const char * llama_model_type_name(e_model type) {
switch (type) {
case MODEL_1B: return "1B";
case MODEL_3B: return "3B";
case MODEL_7B: return "7B";
case MODEL_13B: return "13B";
@@ -1626,6 +1634,7 @@ static void llm_load_hparams(
GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS));

switch (hparams.n_layer) {
case 24: model.type = e_model::MODEL_1B; break;
case 32: model.type = e_model::MODEL_7B; break;
case 60: model.type = e_model::MODEL_40B; break;
default: model.type = e_model::MODEL_UNKNOWN;
@@ -2005,11 +2014,41 @@ static void llm_load_tensors(
}
}

layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
// TODO: For 1B, we need e.g. {2048, 6144} (3*n_embd for the un-reshaped multi-head QKV) but the usual calculation gives us e.g. {2048, 2176} (n_embd + 2*n_embd_gqa with n_head_kv = 1).
// I think this is because we skip the QKV reshaping in the conversion script (maybe because parallel attention is disabled?)
if (model.type == MODEL_1B) {
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd * 3}, backend_split);
// TODO - The config.json has a `bias: true` -- can we plumb that through here to figure out if we need to include it?
layer.wqkv_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd * 3}, backend);
} else {
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
}

layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
// TODO - The config.json has a `bias: true` -- can we plumb that through here to figure out if we need to include it?
if (model.type == MODEL_1B) {
layer.wo_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split);
}

// TODO: Can we figure out if we need this dynamically?
if (model.type == MODEL_1B) {
layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend);
}

layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split);

// TODO - The config.json has a `bias: true` -- can we plumb that through here to figure out if we need to include it?
if (model.type == MODEL_1B) {
layer.w2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split);
}

layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
// TODO - The config.json has a `bias: true` -- can we plumb that through here to figure out if we need to include it?
if (model.type == MODEL_1B) {
// TODO - where does 4 come from?
layer.w3_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_embd * 4}, backend_split);
}

if (backend == GGML_BACKEND_GPU) {
vram_weights +=
@@ -2636,6 +2675,11 @@ static struct ggml_cgraph * llm_build_falcon(
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
offload_func_kq(cur);

if (model.layers[il].wqkv_b) { // Falcon-RW-1B
cur = ggml_add(ctx0, cur, model.layers[il].wqkv_b);
offload_func(cur);
}

// Note that the strides for Kcur, Vcur are set up so that the
// resulting views are misaligned with the tensor's storage
// (by applying the K/V offset we shift the tensor's original
@@ -2747,6 +2791,12 @@ static struct ggml_cgraph * llm_build_falcon(

cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
offload_func(cur);

if (model.layers[il].wo_b) { // Falcon-RW-1B
cur = ggml_add(ctx0, cur, model.layers[il].wo_b);
offload_func(cur);
}

ggml_set_name(cur, "result_wo");
}

@@ -2759,10 +2809,20 @@ static struct ggml_cgraph * llm_build_falcon(
cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
offload_func(cur);

if (model.layers[il].w3_b) { // Falcon-RW-1B
cur = ggml_add(ctx0, cur, model.layers[il].w3_b);
offload_func(cur);
}

cur = ggml_gelu(ctx0, cur);
offload_func(cur);
cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
offload_func(cur);

if (model.layers[il].w2_b) { // Falcon-RW-1B
cur = ggml_add(ctx0, cur, model.layers[il].w2_b);
offload_func(cur);
}
}

cur = ggml_add(ctx0, cur, attn_out);
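As a rough illustration of what the new bias terms compute, here is a torch sketch of the Falcon-RW-1B feed-forward path with the optional `ffn_up`/`ffn_down` biases added right after the corresponding matmuls. The dimensions are toy values chosen for the sketch, and the same add-after-matmul pattern applies to the new `wqkv_b` and `wo_b` tensors in the attention block.

import torch
import torch.nn.functional as F

# Toy dimensions for illustration (not the real Falcon-RW-1B sizes).
n_embd, n_ff, n_tokens = 8, 32, 3

x = torch.randn(n_tokens, n_embd)
w3, w3_b = torch.randn(n_ff, n_embd), torch.randn(n_ff)    # ffn_up weight and bias
w2, w2_b = torch.randn(n_embd, n_ff), torch.randn(n_embd)  # ffn_down weight and bias

# Mirrors the graph above: mul_mat(w3, x), optional bias, GELU,
# mul_mat(w2, ...), optional bias.
cur = x @ w3.T
if w3_b is not None:   # bias only present for Falcon-RW-1B
    cur = cur + w3_b
cur = F.gelu(cur)
cur = cur @ w2.T
if w2_b is not None:   # bias only present for Falcon-RW-1B
    cur = cur + w2_b

print(cur.shape)  # torch.Size([3, 8])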