Add block-transposed multi-vector representation for SIMD-friendly layouts #805
suri-kumkaran wants to merge 4 commits into `main`
Conversation
Pull request overview
This PR introduces a new multi-vector matrix representation (BlockTransposed<T, GROUP, PACK>) intended to replace/extend the prior kmeans-specific BlockTranspose layout and enable SIMD-friendly access patterns (including optional packed column interleaving).
Changes:
- Added the `multi_vector::BlockTransposed` representation with row proxy types, block views, constructors, and extensive tests.
- Migrated kmeans (Lloyd's, kmeans++) and product-quantization training/pivot codepaths to use `Mat<BlockTransposed<...>>` instead of `BlockTranspose`.
- Removed the old `BlockTranspose` implementation from `algorithms::kmeans::common` and stopped re-exporting it from `algorithms::kmeans`.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| diskann-quantization/src/product/train.rs | Switches PQ training to build/use Mat<BlockTransposed<...>> for SIMD-friendly kmeans operations. |
| diskann-quantization/src/product/tables/transposed/pivots.rs | Updates pivot chunk storage to Mat<BlockTransposed<...>> and adjusts group-size/construction calls. |
| diskann-quantization/src/multi_vector/mod.rs | Adds the new block_transposed module and re-exports its public types. |
| diskann-quantization/src/multi_vector/matrix.rs | Makes some internals pub(crate) to support the new representation’s accessors/constructors. |
| diskann-quantization/src/multi_vector/block_transposed.rs | New core implementation of block-transposed (+ packed) representation with row proxies, block views, constructors, and tests. |
| diskann-quantization/src/algorithms/kmeans/plusplus.rs | Replaces BlockTranspose usage with Mat<BlockTransposed<...>> and updates microkernel wiring. |
| diskann-quantization/src/algorithms/kmeans/mod.rs | Removes pub use common::BlockTranspose; export. |
| diskann-quantization/src/algorithms/kmeans/lloyds.rs | Replaces BlockTranspose usage with Mat<BlockTransposed<...>>. |
| diskann-quantization/src/algorithms/kmeans/common.rs | Removes the old BlockTranspose implementation and its tests, leaving square_norm. |
```rust
for row in 0..nrows {
    for col in 0..ncols {
        let idx = linear_index::<GROUP, PACK>(row, col, ncols);
        // SAFETY: idx < storage_len by construction, base points to a valid allocation.
        unsafe { *base.add(idx) = v[(row, col)] };
    }
}
```
from_strided fills the backing store in logical row-major order, which writes with a stride of GROUP (or more with PACK) and can be significantly less cache-friendly than filling in physical block order. The previous BlockTranspose::from_strided implementation filled per-block/per-column (contiguous writes). Consider restructuring this loop to iterate blocks + column-groups + rows to write mostly sequentially into the backing allocation.
Suggested change — replace the logical row-major fill with a physical-block-order fill:

```rust
// Iterate in physical block order: blocks -> column-groups -> rows -> packed columns.
// This makes writes mostly sequential in the backing allocation, improving cache locality.
let nblocks = (nrows + GROUP - 1) / GROUP;
let ncol_groups = (ncols + PACK - 1) / PACK;
for block in 0..nblocks {
    let row_base = block * GROUP;
    for col_group in 0..ncol_groups {
        let col_base = col_group * PACK;
        for in_block_row in 0..GROUP {
            let row = row_base + in_block_row;
            if row >= nrows {
                break;
            }
            for pack_idx in 0..PACK {
                let col = col_base + pack_idx;
                if col >= ncols {
                    break;
                }
                let idx = linear_index::<GROUP, PACK>(row, col, ncols);
                // SAFETY: idx < storage_len by construction, base points to a valid allocation.
                unsafe {
                    *base.add(idx) = v[(row, col)];
                }
            }
        }
    }
}
```
Codecov Report — ❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #805      +/-   ##
==========================================
- Coverage   89.45%   89.40%   -0.06%
==========================================
  Files         432      433       +1
  Lines       79452    80296     +844
==========================================
+ Hits        71075    71788     +713
- Misses       8377     8508     +131
```

Flags with carried forward coverage won't be shown.
**hildebrandmw** left a comment:
Thanks Suryansh! I've taken a look over and left a few stylistic comments, but there are a few larger ones that I will group together here:

1. We should be very careful with the unchecked arithmetic, particularly in `compute_capacity`. We honestly probably want two flavors of this function: one that verifies none of the intermediate steps overflow, which gets used in the constructor, and one that assumes everything has already been checked (this would be the current `compute_capacity` function). I would also recommend leaning on `usize::next_multiple_of` instead of doing the operation manually.

2. I'm getting more convinced that we do not want to use `Mat`/`MatRef`/`MatMut` in public APIs. This forces us into an awkward situation where we need to start adding a bunch of inherent methods to these types, which makes method discoverability a little harder and is not a pattern that can be replicated outside of `diskann-quantization`. Instead, I think it would be cleaner to have thin wrappers:

   ```rust
   struct BlockTransposed<T, const GROUP: usize, const PACK: usize> {
       data: Mat<Repr<T, GROUP, PACK>>,
   }

   struct BlockTransposedRef<'a, T, const GROUP: usize, const PACK: usize> {
       data: MatRef<'a, Repr<T, GROUP, PACK>>,
   }
   // etc.
   ```

   Inherent methods can be added to these at will, and the generated docs for these inherent methods will be clear (since they will not be grouped in with all the other inherent methods). This will also let you safely have `as_slice` as an inherent method to avoid the manual unsafe construction in tests.

   Also, I'll admit a small preference for using method delegation instead of stamping out a giant macro, something like:

   ```rust
   impl<T, const GROUP: usize, const PACK: usize> BlockTransposed<T, GROUP, PACK> {
       fn block(&self, block: usize) -> MatrixView<'_, T> {
           self.reborrow().block(block)
       }
   }
   ```

   But I think that is secondary.

3. In the tests, we should really have one that calls `collect` on a `Mat::row_iter_mut()` and writes to a bunch of rows in a `thread::scope` or something. Basically, something to make sure that working on multiple rows concurrently is fine (and when using Miri, that will add another level of safety).
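To illustrate point (1), here is a hedged sketch of the two suggested `compute_capacity` flavors. The capacity formula used here (padded rows times padded columns) is an assumption for illustration; the crate's real layout math may differ, but the checked/unchecked split and `usize::next_multiple_of` carry over either way.

```rust
/// Checked flavor for constructors: returns None if any step overflows.
fn checked_compute_capacity<const GROUP: usize, const PACK: usize>(
    nrows: usize,
    ncols: usize,
) -> Option<usize> {
    let padded_rows = nrows.checked_next_multiple_of(GROUP)?;
    let padded_cols = ncols.checked_next_multiple_of(PACK)?;
    padded_rows.checked_mul(padded_cols)
}

/// Unchecked flavor for hot paths, assuming the constructor already validated sizes.
fn compute_capacity<const GROUP: usize, const PACK: usize>(nrows: usize, ncols: usize) -> usize {
    nrows.next_multiple_of(GROUP) * ncols.next_multiple_of(PACK)
}

fn main() {
    // 10 rows padded to 16, 5 cols padded to 6 -> 96 elements.
    assert_eq!(checked_compute_capacity::<8, 2>(10, 5), Some(96));
    assert_eq!(compute_capacity::<8, 2>(10, 5), 96);
    // Overflow is reported instead of wrapping.
    assert_eq!(checked_compute_capacity::<8, 2>(usize::MAX, 2), None);
}
```

`next_multiple_of`/`checked_next_multiple_of` have been stable on the unsigned integer types since Rust 1.73.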
```rust
#[derive(Debug, Clone, Copy)]
pub struct BlockTransposedRow<'a, T, const GROUP: usize, const PACK: usize = 1> {
    /// Pointer to the element at `(row, col=0)` in the backing allocation.
    base: *const T,
```
It's possible we might be able to reuse bits::ptr::SlicePtr for representing a pointer with a lifetime. As an added benefit, SlicePtr is already NonNull, so can help Row use the niche optimization out-of-the-box. Finally, the SlicePtr type already has the correct variance behavior built-in, so we don't need to replicate the &'a T and &'a mut T phantom data fields.
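The niche-optimization benefit mentioned above can be demonstrated in miniature: a row proxy built on `NonNull` (as `SlicePtr` is) makes `Option` of the proxy cost no extra space. `Row` here is a simplified stand-in, not the PR's actual type.

```rust
use std::marker::PhantomData;
use std::ptr::NonNull;

struct Row<'a, T> {
    // NonNull can never be null, so Option<Row> reuses the null value as the
    // None discriminant instead of adding a separate tag.
    #[allow(dead_code)]
    base: NonNull<T>,
    _marker: PhantomData<&'a T>, // covariant in 'a and T, like &'a T
}

fn main() {
    assert_eq!(
        std::mem::size_of::<Option<Row<'static, f32>>>(),
        std::mem::size_of::<Row<'static, f32>>()
    );
}
```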
```rust
#[inline]
fn index(&self, col: usize) -> &Self::Output {
    assert!(
```
As a small stylistic suggestion - If get() returns a reference, then index can be expressed as a call to get with a panic on the None case - saving one extra unsafe call.
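A minimal sketch of that pattern, with `Row` as an illustrative stand-in for the PR's row proxy: all bounds logic lives in `get`, and `Index::index` is just `get` plus a panic, so only one method needs the unsafe access.

```rust
use std::ops::Index;

struct Row<'a, T> {
    data: &'a [T],
}

impl<'a, T> Row<'a, T> {
    // Single source of truth for bounds checking (and, in the real type,
    // for the unsafe pointer arithmetic).
    fn get(&self, col: usize) -> Option<&T> {
        self.data.get(col)
    }
}

impl<'a, T> Index<usize> for Row<'a, T> {
    type Output = T;

    fn index(&self, col: usize) -> &T {
        self.get(col)
            .unwrap_or_else(|| panic!("column index {col} out of bounds"))
    }
}

fn main() {
    let row = Row { data: &[1.0f32, 2.0, 3.0] };
    assert_eq!(row[1], 2.0);
    assert!(row.get(5).is_none());
}
```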
```rust
    // SAFETY: bounds checked.
    unsafe { &mut *self.base.add(col_offset::<GROUP, PACK>(col)) }
    }
}
```
Should we also implement std::ops::Index for RowMut?
Thanks, Mark, for the great suggestions. I have addressed (1) and (2) and will soon add tests to cover (3). Please help with a follow-up review.
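The concurrent-row test from point (3) could look roughly like the sketch below. `chunks_mut` over a flat `Vec` stands in for `Mat::row_iter_mut()` here, since the point is the shape of the test (collect the mutable row iterator, hand each row to a scoped thread) rather than the matrix API itself.

```rust
fn fill_rows_concurrently(nrows: usize, ncols: usize) -> Vec<u32> {
    let mut data = vec![0u32; nrows * ncols];
    // Stand-in for Mat::row_iter_mut().collect::<Vec<_>>().
    let rows: Vec<&mut [u32]> = data.chunks_mut(ncols).collect();
    std::thread::scope(|s| {
        for (i, row) in rows.into_iter().enumerate() {
            s.spawn(move || {
                // Each thread owns exactly one row; under Miri this checks
                // that concurrent per-row writes do not alias.
                for v in row.iter_mut() {
                    *v = i as u32;
                }
            });
        }
    });
    data
}

fn main() {
    let data = fill_rows_concurrently(4, 8);
    assert!(data[0..8].iter().all(|&v| v == 0));
    assert!(data[24..32].iter().all(|&v| v == 3));
}
```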
Pull request overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
```rust
/// To reclaim the memory later, reconstruct the `Box` via
/// `Box::from_raw(slice_from_raw_parts_mut(ptr, len))`.
```
The doc comment for box_into_nonnull describes reclaiming memory as Box::from_raw(slice_from_raw_parts_mut(ptr, len)), but ptr here is a NonNull<u8> (byte pointer) and len needs to be in elements of T. As written, this guidance is easy to follow incorrectly and could lead to reconstructing a Box with the wrong element type/length (UB). Consider updating the comment to show casting ptr back to *mut T (and clarifying that len is the T element count, not bytes).
Suggested change:

````rust
/// To reclaim the memory later, you must reconstruct the `Box<[T]>` using the
/// original element type `T` and the original element count `len` (in `T`
/// elements, **not** bytes). For example:
///
/// ```rust,ignore
/// // `ptr` is the `NonNull<u8>` returned by `box_into_nonnull::<T>`
/// // and `len` is the number of `T` elements in the original box.
/// unsafe {
///     let ptr_t: *mut T = ptr.cast().as_ptr();
///     let slice: &mut [T] = std::slice::from_raw_parts_mut(ptr_t, len);
///     let boxed: Box<[T]> = Box::from_raw(slice);
/// }
/// ```
````
```rust
///
/// To reclaim the memory later, reconstruct the `Box` via
/// `Box::from_raw(slice_from_raw_parts_mut(ptr, len))`.
pub(crate) fn box_into_nonnull<T>(b: Box<[T]>) -> NonNull<u8> {
```
I think this should return NonNull<T> and leave it to the caller to do the casting - similar to how the as_nonnull functions work. This also simplifies the safety documentation.
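A sketch of what that suggestion could look like, with a local stand-in for the crate's function (names and the round-trip helper are illustrative): returning `NonNull<T>` keeps the reclamation story type-correct, and callers who need bytes can `.cast::<u8>()` themselves.

```rust
use std::ptr::NonNull;

// Stand-in for the crate's box_into_nonnull, returning NonNull<T> as suggested.
fn box_into_nonnull<T>(b: Box<[T]>) -> NonNull<T> {
    // Box::into_raw yields *mut [T]; casting to *mut T drops the length metadata.
    NonNull::new(Box::into_raw(b) as *mut T).expect("Box pointers are never null")
}

// Round-trip: leak the box, then reclaim it with the original element count.
fn roundtrip<T>(v: Vec<T>) -> Box<[T]> {
    let b = v.into_boxed_slice();
    let len = b.len();
    let ptr = box_into_nonnull(b);
    // SAFETY: ptr and len came from the Box we just leaked; len is in T elements.
    unsafe { Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr.as_ptr(), len)) }
}

fn main() {
    assert_eq!(&*roundtrip(vec![1u32, 2, 3]), &[1, 2, 3]);
}
```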
Summary

Introduces `BlockTransposed<T, GROUP, PACK>`, a new `Repr` implementation that stores multi-vectors in a block-transposed layout optimized for SIMD processing. This enables faster multi-vector distance computations by ensuring data is arranged for efficient vectorized access patterns.

Key changes
- `BlockTransposed<T, GROUP, PACK>`: generic block-transposed matrix representation where groups of `GROUP` rows are stored transposed, with an optional `PACK` factor for column interleaving (e.g. `PACK = 2` for `vpmaddwd`-style instructions).
- Row proxies `BlockTransposedRow` / `BlockTransposedRowMut` with indexed access, iteration, and `Send`/`Sync` support.
- Constructors `new_block_transposed`, `from_strided`, `from_matrix_view`, generic over any `T: Copy + Default`.
- Block accessors `block()` / `block_mut()` / `remainder_block()` / `remainder_block_mut()` returning `MatrixView` / `MutMatrixView` with dimensions `(padded_ncols / PACK, GROUP * PACK)` for direct SIMD consumption.
- Invariants `GROUP > 0`, `PACK > 0`, `GROUP % PACK == 0`, enforced at monomorphization.
- Tests covering multiple element types (`f32`, `i32`, `u8`), bounds-checking panic tests, and Miri-compatible paths.
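As a mental model of such a layout, the indexing sketch below maps a logical `(row, col)` to a physical offset where each block of `GROUP` rows is stored column-group-major with `PACK`-wide interleaving. The exact formula is an assumption for illustration; the crate's `linear_index` may differ in details, but it shows why the same column of `GROUP` consecutive rows ends up contiguous (the transposition that makes SIMD loads cheap).

```rust
fn linear_index<const GROUP: usize, const PACK: usize>(
    row: usize,
    col: usize,
    ncols: usize,
) -> usize {
    let padded_ncols = ncols.next_multiple_of(PACK);
    let block = row / GROUP;          // which GROUP-row block
    let in_block_row = row % GROUP;   // row within that block
    let col_group = col / PACK;       // which PACK-column group
    let in_pack = col % PACK;         // column within the pack
    block * GROUP * padded_ncols      // skip earlier blocks
        + col_group * GROUP * PACK    // skip earlier column groups in this block
        + in_block_row * PACK         // position of this row within the group
        + in_pack
}

fn main() {
    // GROUP = 4, PACK = 1, ncols = 3: column 0 of rows 0..4 is contiguous.
    assert_eq!(linear_index::<4, 1>(0, 0, 3), 0);
    assert_eq!(linear_index::<4, 1>(1, 0, 3), 1);
    assert_eq!(linear_index::<4, 1>(0, 1, 3), 4);
    // PACK = 2 interleaves adjacent column pairs within a row group.
    assert_eq!(linear_index::<4, 2>(0, 1, 4), 1);
    assert_eq!(linear_index::<4, 2>(1, 0, 4), 2);
}
```

Under this model each block reads as a `(padded_ncols / PACK, GROUP * PACK)` matrix, matching the block-view dimensions listed above.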