Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 54 additions & 8 deletions ggml/src/ggml-cpu/arch/x86/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -2369,6 +2369,30 @@ static const int8_t keven_signs_q2xs[1024] = {
};
#endif

#if defined(__AVX2__)
// for _mm_srlv_epi32, shifts to 7 bit signs in xxs quantizations
static const __attribute__((aligned(16))) uint32_t ksigns_shift_xxs[4] = {0, 7, 14, 21};
// for _mm(256)_shuffle_epi8, has 0x80 at indices that are encoded with odd bit counts
static const __attribute__((aligned(32))) uint32_t ksigns_popc_odd[8] = {
0x00808000, 0x80000080, 0x80000080, 0x00808000,
0x00808000, 0x80000080, 0x80000080, 0x00808000,
};
// for _mm256_shuffle_epi8, broadcasts bytes 0, 2, 4, 6 / 8, 10, 12, 14
static const __attribute__((aligned(32))) uint64_t ksigns_bcast_1[4] = {
0x0000000000000000ULL, 0x0202020202020202ULL,
0x0404040404040404ULL, 0x0606060606060606ULL,
};
static const __attribute__((aligned(32))) uint64_t ksigns_bcast_2[4] = {
0x0808080808080808ULL, 0x0A0A0A0A0A0A0A0AULL,
0x0C0C0C0C0C0C0C0CULL, 0x0E0E0E0E0E0E0E0EULL,
};
// for _mm256_cmpeq_epi8 / _mm256_and_si256 to check bits after broadcast
static const __attribute__((aligned(32))) uint8_t ksigns_bitsel[32] = {
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,
};
#endif

void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
Expand All @@ -2384,11 +2408,16 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const

#if defined(__AVX2__)

const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;

uint32_t aux32[4];
const uint8_t * aux8 = (const uint8_t *)aux32;

const __m128i ks_shift = _mm_load_si128((const __m128i *)ksigns_shift_xxs);
const __m128i ks_mask = _mm_set1_epi32(0x7F);
const __m128i popc_odd = _mm_load_si128((const __m128i *)ksigns_popc_odd);
const __m256i ks_bc_1 = _mm256_load_si256((const __m256i *)ksigns_bcast_1);
const __m256i ks_bc_2 = _mm256_load_si256((const __m256i *)ksigns_bcast_2);
const __m256i ks_bsel = _mm256_load_si256((const __m256i *)ksigns_bitsel);

__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
Expand All @@ -2402,12 +2431,29 @@ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const
memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
signs64[(aux32[1] >> 7) & 127], signs64[(aux32[1] >> 0) & 127]);
const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
signs64[(aux32[3] >> 7) & 127], signs64[(aux32[3] >> 0) & 127]);
const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);

__m128i s_l = _mm_set1_epi32(aux32[1]);
__m128i s_h = _mm_set1_epi32(aux32[3]);
// shift each value to their offset, then zero out garbage
s_l = _mm_srlv_epi32(s_l, ks_shift);
s_h = _mm_srlv_epi32(s_h, ks_shift);
s_l = _mm_and_si128(s_l, ks_mask);
s_h = _mm_and_si128(s_h, ks_mask);
// pack, count bits via xor+lut, correct bit 8
__m128i signs_128 = _mm_packus_epi32(s_l, s_h);
const __m128i cnt4 = _mm_xor_si128(_mm_srli_epi16(signs_128, 4), signs_128);
const __m128i popc = _mm_shuffle_epi8(popc_odd, cnt4);
signs_128 = _mm_or_si128(signs_128, popc);
// expand to 256 bits, then broadcast to 8 bytes each
__m256i signs_256 = _mm256_broadcastsi128_si256(signs_128);
const __m256i s1_b = _mm256_shuffle_epi8(signs_256, ks_bc_1);
const __m256i s2_b = _mm256_shuffle_epi8(signs_256, ks_bc_2);
// set 0xFF in bytes that contain bit, then invert via xor+sub
const __m256i s1 = _mm256_cmpeq_epi8(_mm256_and_si256(s1_b, ks_bsel), ks_bsel);
const __m256i s2 = _mm256_cmpeq_epi8(_mm256_and_si256(s2_b, ks_bsel), ks_bsel);
const __m256i q8s_1 = _mm256_sub_epi8(_mm256_xor_si256(q8_1, s1), s1);
const __m256i q8s_2 = _mm256_sub_epi8(_mm256_xor_si256(q8_2, s2), s2);

const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
const uint16_t ls1 = aux32[1] >> 28;
Expand Down
Loading