22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
44//! TurboQuant array definition: stores quantized coordinate codes, norms,
5- //! centroids (codebook), rotation signs, and optional QJL correction fields .
5+ //! centroids (codebook), and rotation signs .
66
77use vortex_array:: ArrayId ;
88use vortex_array:: ArrayRef ;
@@ -32,38 +32,6 @@ pub struct TurboQuantMetadata {
3232 /// MSE bits per coordinate (1-8).
3333 #[ prost( uint32, tag = "2" ) ]
3434 pub bit_width : u32 ,
35- /// Whether QJL correction children are present.
36- #[ prost( bool , tag = "3" ) ]
37- pub has_qjl : bool ,
38- }
39-
40- /// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased
41- /// inner product estimation. When present, adds 3 additional children.
42- #[ derive( Clone , Debug ) ]
43- pub struct QjlCorrection {
44- /// Sign bits: `BoolArray`, length `num_rows * padded_dim`.
45- pub ( crate ) signs : ArrayRef ,
46- /// Residual norms: `PrimitiveArray<f32>`, length `num_rows`.
47- pub ( crate ) residual_norms : ArrayRef ,
48- /// QJL rotation signs: `BoolArray`, length `3 * padded_dim` (inverse order).
49- pub ( crate ) rotation_signs : ArrayRef ,
50- }
51-
52- impl QjlCorrection {
53- /// The QJL sign bits.
54- pub fn signs ( & self ) -> & ArrayRef {
55- & self . signs
56- }
57-
58- /// The residual norms.
59- pub fn residual_norms ( & self ) -> & ArrayRef {
60- & self . residual_norms
61- }
62-
63- /// The QJL rotation signs (BoolArray, inverse application order).
64- pub fn rotation_signs ( & self ) -> & ArrayRef {
65- & self . rotation_signs
66- }
6735}
6836
6937/// Slot positions for TurboQuantArray children.
@@ -74,23 +42,17 @@ pub(crate) enum Slot {
7442 Norms = 1 ,
7543 Centroids = 2 ,
7644 RotationSigns = 3 ,
77- QjlSigns = 4 ,
78- QjlResidualNorms = 5 ,
79- QjlRotationSigns = 6 ,
8045}
8146
8247impl Slot {
83- pub ( crate ) const COUNT : usize = 7 ;
48+ pub ( crate ) const COUNT : usize = 4 ;
8449
8550 pub ( crate ) fn name ( self ) -> & ' static str {
8651 match self {
8752 Self :: Codes => "codes" ,
8853 Self :: Norms => "norms" ,
8954 Self :: Centroids => "centroids" ,
9055 Self :: RotationSigns => "rotation_signs" ,
91- Self :: QjlSigns => "qjl_signs" ,
92- Self :: QjlResidualNorms => "qjl_residual_norms" ,
93- Self :: QjlRotationSigns => "qjl_rotation_signs" ,
9456 }
9557 }
9658
@@ -100,26 +62,18 @@ impl Slot {
10062 1 => Self :: Norms ,
10163 2 => Self :: Centroids ,
10264 3 => Self :: RotationSigns ,
103- 4 => Self :: QjlSigns ,
104- 5 => Self :: QjlResidualNorms ,
105- 6 => Self :: QjlRotationSigns ,
10665 _ => vortex_error:: vortex_panic!( "invalid slot index {idx}" ) ,
10766 }
10867 }
10968}
11069
11170/// TurboQuant array.
11271///
113- /// Slots (always present):
114- /// - 0: `codes` — `FixedSizeListArray<u8>` (quantized indices, list_size=padded_dim)
115- /// - 1: `norms` — `PrimitiveArray<f32>` (one per vector row)
116- /// - 2: `centroids` — `PrimitiveArray<f32>` (codebook, length 2^bit_width)
117- /// - 3: `rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order)
118- ///
119- /// Optional QJL slots (None when MSE-only):
120- /// - 4: `qjl_signs` — `FixedSizeListArray<u8>` (num_rows * padded_dim, 1-bit)
121- /// - 5: `qjl_residual_norms` — `PrimitiveArray<f32>` (one per row)
122- /// - 6: `qjl_rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit, QJL rotation)
72+ /// Slots:
73+ /// - 0: `codes` -- `FixedSizeListArray<u8>` (quantized indices, list_size=padded_dim).
74+ /// - 1: `norms` -- `PrimitiveArray<f32>` (one per vector row).
75+ /// - 2: `centroids` -- `PrimitiveArray<f32>` (codebook, length 2^bit_width).
76+ /// - 3: `rotation_signs` -- `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order).
12377#[ derive( Clone , Debug ) ]
12478pub struct TurboQuantData {
12579 pub ( crate ) dtype : DType ,
@@ -130,9 +84,9 @@ pub struct TurboQuantData {
13084}
13185
13286impl TurboQuantData {
133- /// Build a TurboQuant array with MSE-only encoding (no QJL correction) .
87+ /// Build a TurboQuant array.
13488 #[ allow( clippy:: too_many_arguments) ]
135- pub fn try_new_mse (
89+ pub fn try_new (
13690 dtype : DType ,
13791 codes : ArrayRef ,
13892 norms : ArrayRef ,
@@ -143,7 +97,7 @@ impl TurboQuantData {
14397 ) -> VortexResult < Self > {
14498 vortex_ensure ! (
14599 ( 1 ..=8 ) . contains( & bit_width) ,
146- "MSE bit_width must be 1-8, got {bit_width}"
100+ "bit_width must be 1-8, got {bit_width}"
147101 ) ;
148102 let mut slots = vec ! [ None ; Slot :: COUNT ] ;
149103 slots[ Slot :: Codes as usize ] = Some ( codes) ;
@@ -159,39 +113,6 @@ impl TurboQuantData {
159113 } )
160114 }
161115
162- /// Build a TurboQuant array with QJL correction (MSE + QJL).
163- #[ allow( clippy:: too_many_arguments) ]
164- pub fn try_new_qjl (
165- dtype : DType ,
166- codes : ArrayRef ,
167- norms : ArrayRef ,
168- centroids : ArrayRef ,
169- rotation_signs : ArrayRef ,
170- qjl : QjlCorrection ,
171- dimension : u32 ,
172- bit_width : u8 ,
173- ) -> VortexResult < Self > {
174- vortex_ensure ! (
175- ( 1 ..=8 ) . contains( & bit_width) ,
176- "MSE bit_width must be 1-8, got {bit_width}"
177- ) ;
178- let mut slots = vec ! [ None ; Slot :: COUNT ] ;
179- slots[ Slot :: Codes as usize ] = Some ( codes) ;
180- slots[ Slot :: Norms as usize ] = Some ( norms) ;
181- slots[ Slot :: Centroids as usize ] = Some ( centroids) ;
182- slots[ Slot :: RotationSigns as usize ] = Some ( rotation_signs) ;
183- slots[ Slot :: QjlSigns as usize ] = Some ( qjl. signs ) ;
184- slots[ Slot :: QjlResidualNorms as usize ] = Some ( qjl. residual_norms ) ;
185- slots[ Slot :: QjlRotationSigns as usize ] = Some ( qjl. rotation_signs ) ;
186- Ok ( Self {
187- dtype,
188- slots,
189- dimension,
190- bit_width,
191- stats_set : Default :: default ( ) ,
192- } )
193- }
194-
195116 /// The vector dimension d.
196117 pub fn dimension ( & self ) -> u32 {
197118 self . dimension
@@ -207,11 +128,6 @@ impl TurboQuantData {
207128 self . dimension . next_power_of_two ( )
208129 }
209130
210- /// Whether QJL correction is present.
211- pub fn has_qjl ( & self ) -> bool {
212- self . slots [ Slot :: QjlSigns as usize ] . is_some ( )
213- }
214-
215131 fn slot ( & self , idx : usize ) -> & ArrayRef {
216132 self . slots [ idx]
217133 . as_ref ( )
@@ -237,20 +153,4 @@ impl TurboQuantData {
237153 pub fn rotation_signs ( & self ) -> & ArrayRef {
238154 self . slot ( Slot :: RotationSigns as usize )
239155 }
240-
241- /// The optional QJL correction fields, reconstructed from slots.
242- pub fn qjl ( & self ) -> Option < QjlCorrection > {
243- Some ( QjlCorrection {
244- signs : self . slots [ Slot :: QjlSigns as usize ] . clone ( ) ?,
245- residual_norms : self . slots [ Slot :: QjlResidualNorms as usize ] . clone ( ) ?,
246- rotation_signs : self . slots [ Slot :: QjlRotationSigns as usize ] . clone ( ) ?,
247- } )
248- }
249-
250- /// Set the QJL correction fields on this array.
251- pub ( crate ) fn set_qjl ( & mut self , qjl : QjlCorrection ) {
252- self . slots [ Slot :: QjlSigns as usize ] = Some ( qjl. signs ) ;
253- self . slots [ Slot :: QjlResidualNorms as usize ] = Some ( qjl. residual_norms ) ;
254- self . slots [ Slot :: QjlRotationSigns as usize ] = Some ( qjl. rotation_signs ) ;
255- }
256156}
0 commit comments