Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ci/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ function gg_run_open_llama_3b_v2 {

(time ./bin/imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

(time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

function check_ppl {
qnt="$1"
Expand Down Expand Up @@ -517,7 +518,10 @@ function gg_run_open_llama_7b_v2 {

(time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

(time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state --model -ngl 10 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state --model -fa -ngl 10 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state --model -ngl 99 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/save-load-state --model -fa -ngl 99 ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

function check_ppl {
qnt="$1"
Expand Down
17 changes: 13 additions & 4 deletions ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
[metal_function release]; \
GGML_METAL_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
(int) kernel->pipeline.threadExecutionWidth); \
if (error) { \
GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
[metal_library release]; \
Expand Down Expand Up @@ -2646,13 +2643,25 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(nqptg % 8 == 0);
GGML_ASSERT(ncpsg % 32 == 0);

int64_t nsgmax = 2;

while (true) {
const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
if (smem > ctx->device.maxThreadgroupMemoryLength) {
break;
}
nsgmax *= 2;
}
nsgmax /= 2;

// simdgroups per threadgroup (a.k.a. warps)
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;

const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);

//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);

[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];

[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
Expand Down
176 changes: 125 additions & 51 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2039,8 +2039,8 @@ struct llama_kv_cache {
bool has_shift = false;
bool do_defrag = false;
bool do_copy = false;
// with recurrent state models, a cell can hold the state for more than one past token
bool recurrent = false;
bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token
bool v_trans = true; // the value tensor is transposed

// Note: The value of head isn't only used to optimize searching
// for a free KV slot. llama_decode_internal also uses it, so it
Expand Down Expand Up @@ -2339,11 +2339,14 @@ struct llama_context {

static bool llama_kv_cache_init(
struct llama_kv_cache & cache,
const llama_model & model,
const llama_context * ctx,
ggml_type type_k,
ggml_type type_v,
uint32_t kv_size,
bool offload) {
const llama_model & model = ctx->model;
const llama_cparams & cparams = ctx->cparams;

const struct llama_hparams & hparams = model.hparams;

const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
Expand All @@ -2354,6 +2357,7 @@ static bool llama_kv_cache_init(

// TODO: find a nicer way to add other recurrent model architectures
cache.recurrent = model.arch == LLM_ARCH_MAMBA;
cache.v_trans = !cparams.flash_attn;

// TODO: support mixed reccurent Transformer architectues
// NOTE: (!a || b) is a logical implication (a -> b)
Expand Down Expand Up @@ -2566,6 +2570,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
}
cache.head = 0;
cache.used = 0;

for (auto & buf : cache.bufs) {
ggml_backend_buffer_clear(buf, 0);
}
}

static bool llama_kv_cache_seq_rm(
Expand Down Expand Up @@ -15704,7 +15712,7 @@ struct llama_context * llama_new_context_with_model(
}
ctx->backends.push_back(ctx->backend_cpu);

if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;
Expand Down Expand Up @@ -16303,6 +16311,7 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
const size_t s_kv_head = sizeof(uint32_t);
const size_t s_kv_size = sizeof(uint32_t);
const size_t s_kv_used = sizeof(uint32_t);
const size_t s_v_trans = sizeof(uint32_t);
const size_t s_kv = ctx->kv_self.total_size();
const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
Expand All @@ -16320,10 +16329,14 @@ size_t llama_state_get_size(const struct llama_context * ctx) {
+ s_kv_head
+ s_kv_size
+ s_kv_used
+ s_v_trans
+ s_kv
+ s_kv_cells
);

// on session change it is very likely that the state size has changed - so we need to update this function
static_assert(LLAMA_SESSION_VERSION == 6, "So you just bumped the session version - good. But did you remember to update llama_state_get_size?");

return s_total;
}

Expand Down Expand Up @@ -16469,11 +16482,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
const uint32_t kv_size = kv_self.size;
const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
const uint32_t kv_used = kv_self.used;
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;

data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
data_ctx->write(&kv_head, sizeof(kv_head));
data_ctx->write(&kv_size, sizeof(kv_size));
data_ctx->write(&kv_used, sizeof(kv_used));
data_ctx->write(&v_trans, sizeof(v_trans));

if (kv_buf_size) {
const size_t pre_kv_buf_size = data_ctx->get_size_written();
Expand All @@ -16486,7 +16501,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), 0, tmp_buf.size());
data_ctx->write(tmp_buf.data(), tmp_buf.size());

if (kv_self.recurrent) {
if (kv_self.recurrent || !kv_self.v_trans) {
// v is contiguous for recurrent models
// TODO: use other tensors for state models than k and v
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
Expand Down Expand Up @@ -16619,11 +16634,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
uint32_t kv_head;
uint32_t kv_size;
uint32_t kv_used;
uint32_t v_trans;

memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans);

GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition

if (kv_self.size != kv_size) {
// the KV cache needs to be big enough to load all the KV cells from the saved state
Expand All @@ -16633,6 +16652,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
__func__, kv_head, kv_size, kv_self.size);
}

llama_kv_cache_clear(ctx);

if (kv_buf_size) {
const size_t pre_kv_buf_size = inp - src;

Expand All @@ -16644,7 +16665,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
inp += k_size;

if (kv_self.recurrent) {
if (kv_self.recurrent || !kv_self.v_trans) {
// v is contiguous for recurrent models
// TODO: use other tensors for state models than k and v
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
Expand All @@ -16666,8 +16687,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
}

llama_kv_cache_clear(ctx);

ctx->kv_self.head = kv_head;
ctx->kv_self.used = kv_used;

Expand Down Expand Up @@ -16927,28 +16946,49 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
}
}

// For the values, they are transposed, so we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (int il = 0; il < (int)n_layer; ++il) {
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
data_ctx.write(&v_type_i, sizeof(v_type_i));
// TODO: simplify, reduce copy-paste
if (!kv_self.v_trans) {
for (int il = 0; il < (int)n_layer; ++il) {
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
data_ctx.write(&v_type_i, sizeof(v_type_i));

// Write element size
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
data_ctx.write(&v_size_el, sizeof(v_size_el));
// Write row size of value
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
data_ctx.write(&v_size_row, sizeof(v_size_row));

// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
// Read each range of cells of v_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
tmp_buf.resize(range_size * v_size_el);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
tmp_buf.resize(range_size * v_size_row);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
data_ctx.write(tmp_buf.data(), tmp_buf.size());
}
}
} else {
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (int il = 0; il < (int)n_layer; ++il) {
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
data_ctx.write(&v_type_i, sizeof(v_type_i));

// Write element size
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
data_ctx.write(&v_size_el, sizeof(v_size_el));

// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
tmp_buf.resize(range_size * v_size_el);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
data_ctx.write(tmp_buf.data(), tmp_buf.size());
}
}
}
}

return data_ctx.get_size_written();
Expand Down Expand Up @@ -17073,41 +17113,75 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
}
}

// For each layer, read the values for each cell (transposed)
for (int il = 0; il < (int)n_layer; ++il) {
// Read type of value
int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0;
}
// TODO: simplify, reduce copy-paste
if (!kv_self.v_trans) {
for (int il = 0; il < (int)n_layer; ++il) {
// Read type of value
int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0;
}

// Read element size of value
size_t v_size_el_ref;
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
inp += sizeof(v_size_el_ref);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
return 0;
}
// Read row size of value
size_t v_size_row_ref;
memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
inp += sizeof(v_size_row_ref);
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
return 0;
}

if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
inp += cell_count * v_size_el;
if (cell_count) {
// Read and set the values for the whole cell range
ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
inp += cell_count * v_size_row;
}
}
} else {
// For each layer, read the values for each cell (transposed)
for (int il = 0; il < (int)n_layer; ++il) {
// Read type of value
int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0;
}

// Read element size of value
size_t v_size_el_ref;
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
inp += sizeof(v_size_el_ref);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
return 0;
}

if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
inp += cell_count * v_size_el;
}
}
}
}

const size_t nread = inp - src;

return nread;
}

Expand Down
4 changes: 2 additions & 2 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 5
#define LLAMA_SESSION_VERSION 6

#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 1
Expand Down Expand Up @@ -543,7 +543,7 @@ extern "C" {
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);

// Clear the KV cache
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx);

Expand Down
Loading