Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 50 additions & 16 deletions ggml/src/ggml-backend-meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1205,40 +1205,57 @@ static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, gg

if (split_state.n_segments != 1) {
GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(tensor->ne[3] == 1);

size_t offset_data = 0;
std::vector<size_t> simple_offsets(n_bufs, 0);
if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
GGML_ASSERT(tensor->ne[2] == 1);

const size_t row_stride = tensor->nb[1];
GGML_ASSERT(offset % row_stride == 0);
GGML_ASSERT(size % row_stride == 0);
const int64_t r_start = offset / row_stride;
const int64_t r_count = size / row_stride;
GGML_ASSERT(r_start + r_count <= tensor->ne[1]);

const int64_t blck_size = ggml_blck_size(tensor->type);
for (size_t s = 0; s < split_state.n_segments; s++) {
for (size_t j = 0; j < n_bufs; j++) {
ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes,
tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]);
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes,
r_count, simple_tensor->nb[1], tensor->nb[1]);
offset_data += nbytes;
simple_offsets[j] += nbytes;
}
}
GGML_ASSERT(offset_data*tensor->ne[1] == size);
GGML_ASSERT(offset_data*r_count == size);
return;
}
GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);

const size_t row_stride = tensor->nb[2];
GGML_ASSERT(offset % row_stride == 0);
GGML_ASSERT(size % row_stride == 0);
const int64_t r_start = offset / row_stride;
const int64_t r_count = size / row_stride;
GGML_ASSERT(r_start + r_count <= tensor->ne[2]);

for (size_t s = 0; s < split_state.n_segments; s++) {
for (size_t j = 0; j < n_bufs; j++) {
ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, simple_offsets[j], nbytes,
tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]);
ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data,
simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes,
r_count, simple_tensor->nb[2], tensor->nb[2]);
offset_data += nbytes;
simple_offsets[j] += nbytes;
}
}
GGML_ASSERT(offset_data*tensor->ne[2] == size);
GGML_ASSERT(offset_data*r_count == size);
return;
}

Expand Down Expand Up @@ -1295,40 +1312,57 @@ static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, co

if (split_state.n_segments != 1) {
GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS);
GGML_ASSERT(offset == 0);
GGML_ASSERT(size == ggml_nbytes(tensor));
GGML_ASSERT(tensor->ne[3] == 1);

size_t offset_data = 0;
std::vector<size_t> simple_offsets(n_bufs, 0);
if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) {
GGML_ASSERT(tensor->ne[2] == 1);

const size_t row_stride = tensor->nb[1];
GGML_ASSERT(offset % row_stride == 0);
GGML_ASSERT(size % row_stride == 0);
const int64_t r_start = offset / row_stride;
const int64_t r_count = size / row_stride;
GGML_ASSERT(r_start + r_count <= tensor->ne[1]);

const int64_t blck_size = ggml_blck_size(tensor->type);
for (size_t s = 0; s < split_state.n_segments; s++) {
for (size_t j = 0; j < n_bufs; j++) {
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0);
const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0];
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes,
tensor->ne[1], simple_tensor->nb[1], tensor->nb[1]);
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
simple_offsets[j] + r_start * simple_tensor->nb[1], nbytes,
r_count, simple_tensor->nb[1], tensor->nb[1]);
offset_data += nbytes;
simple_offsets[j] += nbytes;
}
}
GGML_ASSERT(offset_data*tensor->ne[1] == size);
GGML_ASSERT(offset_data*r_count == size);
return;
}
GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1);

const size_t row_stride = tensor->nb[2];
GGML_ASSERT(offset % row_stride == 0);
GGML_ASSERT(size % row_stride == 0);
const int64_t r_start = offset / row_stride;
const int64_t r_count = size / row_stride;
GGML_ASSERT(r_start + r_count <= tensor->ne[2]);

for (size_t s = 0; s < split_state.n_segments; s++) {
for (size_t j = 0; j < n_bufs; j++) {
const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j);
const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1];
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, simple_offsets[j], nbytes,
tensor->ne[2], simple_tensor->nb[2], tensor->nb[2]);
ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data,
simple_offsets[j] + r_start * simple_tensor->nb[2], nbytes,
r_count, simple_tensor->nb[2], tensor->nb[2]);
offset_data += nbytes;
simple_offsets[j] += nbytes;
}
}
GGML_ASSERT(offset_data*tensor->ne[2] == size);
GGML_ASSERT(offset_data*r_count == size);
return;
}

Expand Down
Loading