From 4efe649d6df5138d5547e915e8f7eab519a30146 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Fri, 26 Jul 2024 10:20:50 -0700 Subject: [PATCH 1/2] simplify dot implemetation --- rust/lance-linalg/src/distance/dot.rs | 55 +-------------------------- 1 file changed, 1 insertion(+), 54 deletions(-) diff --git a/rust/lance-linalg/src/distance/dot.rs b/rust/lance-linalg/src/distance/dot.rs index ef35c22cac9..92e65c1fe68 100644 --- a/rust/lance-linalg/src/distance/dot.rs +++ b/rust/lance-linalg/src/distance/dot.rs @@ -19,10 +19,6 @@ use lance_core::utils::cpu::SimdSupport; use lance_core::utils::cpu::SIMD_SUPPORT; use num_traits::{real::Real, AsPrimitive, Num}; -use crate::simd::{ - f32::{f32x16, f32x8}, - SIMD, -}; use crate::Result; /// Default implementation of dot product. @@ -142,56 +138,7 @@ impl Dot for f16 { impl Dot for f32 { #[inline] fn dot(x: &[Self], y: &[Self]) -> f32 { - // Manually unrolled 8 times to get enough registers. - // TODO: avx512 can unroll more - let x_unrolled_chunks = x.chunks_exact(64); - let y_unrolled_chunks = y.chunks_exact(64); - - // 8 float32 SIMD - let x_aligned_chunks = x_unrolled_chunks.remainder().chunks_exact(8); - let y_aligned_chunks = y_unrolled_chunks.remainder().chunks_exact(8); - - let sum = if x_aligned_chunks.remainder().is_empty() { - 0.0 - } else { - debug_assert_eq!( - x_aligned_chunks.remainder().len(), - y_aligned_chunks.remainder().len() - ); - x_aligned_chunks - .remainder() - .iter() - .zip(y_aligned_chunks.remainder().iter()) - .map(|(&x, &y)| x * y) - .sum() - }; - - let mut sum8 = f32x8::zeros(); - x_aligned_chunks - .zip(y_aligned_chunks) - .for_each(|(x_chunk, y_chunk)| unsafe { - let x1 = f32x8::load_unaligned(x_chunk.as_ptr()); - let y1 = f32x8::load_unaligned(y_chunk.as_ptr()); - sum8 += x1 * y1; - }); - - let mut sum16 = f32x16::zeros(); - x_unrolled_chunks - .zip(y_unrolled_chunks) - .for_each(|(x, y)| unsafe { - let x1 = f32x16::load_unaligned(x.as_ptr()); - let x2 = f32x16::load_unaligned(x.as_ptr().add(16)); - let x3 = f32x16::load_unaligned(x.as_ptr().add(32)); - let x4 = f32x16::load_unaligned(x.as_ptr().add(48)); - - let y1 = f32x16::load_unaligned(y.as_ptr()); - let y2 = f32x16::load_unaligned(y.as_ptr().add(16)); - let y3 = f32x16::load_unaligned(y.as_ptr().add(32)); - let y4 = f32x16::load_unaligned(y.as_ptr().add(48)); - - sum16 += (x1 * y1 + x2 * y2) + (x3 * y3 + x4 * y4); - }); - sum16.reduce_sum() + sum8.reduce_sum() + sum + dot_scalar::(x, y) } } From 3b0a32e0614ef2c7959fd9c085bdadd0f5fb2ad5 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Tue, 30 Jul 2024 11:22:23 -0700 Subject: [PATCH 2/2] fix duration import