Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 1 addition & 54 deletions rust/lance-linalg/src/distance/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,6 @@ use lance_core::utils::cpu::SimdSupport;
use lance_core::utils::cpu::SIMD_SUPPORT;
use num_traits::{real::Real, AsPrimitive, Num};

use crate::simd::{
f32::{f32x16, f32x8},
SIMD,
};
use crate::Result;

/// Default implementation of dot product.
Expand Down Expand Up @@ -142,56 +138,7 @@ impl Dot for f16 {
/// Dot product for `f32` slices.
///
/// The previous body of `dot` was a hand-unrolled routine built on the
/// `f32x8` / `f32x16` SIMD wrappers (64-element unrolled chunks, 8-element
/// chunks, then a scalar tail). It has been replaced by a single call to the
/// generic `dot_scalar` helper, so the `f32` path no longer contains any
/// `unsafe` pointer loads of its own.
impl Dot for f32 {
    /// Returns the dot product of `x` and `y` as an `f32`.
    ///
    /// All work is delegated to `dot_scalar::<Self, Self, 16>`.
    // NOTE(review): the const argument `16` is presumably the chunk/lane width
    // `dot_scalar` uses to help auto-vectorization — confirm against the
    // definition of `dot_scalar`, which is outside this view.
    // NOTE(review): how mismatched slice lengths are handled is decided by
    // `dot_scalar`; verify before relying on it.
    #[inline]
    fn dot(x: &[Self], y: &[Self]) -> f32 {
        dot_scalar::<Self, Self, 16>(x, y)
    }
}

Expand Down
Loading