Skip to content

Commit 6d390d6

Browse files
committed
remove QJL
We are going to implement this later as a separate encoding (if we decide to implement it at all because word on the street is that the MSE + QJL is not actually better than MSE on its own). Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent 57bac7f commit 6d390d6

11 files changed

Lines changed: 171 additions & 869 deletions

File tree

vortex-btrblocks/src/builder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ impl BtrBlocksCompressorBuilder {
140140

141141
/// Adds the TurboQuant lossy vector quantization scheme.
142142
///
143-
/// When enabled, [`Vector`] extension arrays are compressed using the TurboQuant algorithm with
144-
/// QJL correction for unbiased inner product estimation.
143+
/// When enabled, [`Vector`] extension arrays are compressed using the TurboQuant algorithm
144+
/// with MSE-optimal scalar quantization.
145145
///
146146
/// # Panics
147147
///

vortex-tensor/src/encodings/turboquant/array.rs

Lines changed: 10 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
//! TurboQuant array definition: stores quantized coordinate codes, norms,
5-
//! centroids (codebook), rotation signs, and optional QJL correction fields.
5+
//! centroids (codebook), and rotation signs.
66
77
use vortex_array::ArrayId;
88
use vortex_array::ArrayRef;
@@ -32,38 +32,6 @@ pub struct TurboQuantMetadata {
3232
/// MSE bits per coordinate (1-8).
3333
#[prost(uint32, tag = "2")]
3434
pub bit_width: u32,
35-
/// Whether QJL correction children are present.
36-
#[prost(bool, tag = "3")]
37-
pub has_qjl: bool,
38-
}
39-
40-
/// Optional QJL (Quantized Johnson-Lindenstrauss) correction for unbiased
41-
/// inner product estimation. When present, adds 3 additional children.
42-
#[derive(Clone, Debug)]
43-
pub struct QjlCorrection {
44-
/// Sign bits: `BoolArray`, length `num_rows * padded_dim`.
45-
pub(crate) signs: ArrayRef,
46-
/// Residual norms: `PrimitiveArray<f32>`, length `num_rows`.
47-
pub(crate) residual_norms: ArrayRef,
48-
/// QJL rotation signs: `BoolArray`, length `3 * padded_dim` (inverse order).
49-
pub(crate) rotation_signs: ArrayRef,
50-
}
51-
52-
impl QjlCorrection {
53-
/// The QJL sign bits.
54-
pub fn signs(&self) -> &ArrayRef {
55-
&self.signs
56-
}
57-
58-
/// The residual norms.
59-
pub fn residual_norms(&self) -> &ArrayRef {
60-
&self.residual_norms
61-
}
62-
63-
/// The QJL rotation signs (BoolArray, inverse application order).
64-
pub fn rotation_signs(&self) -> &ArrayRef {
65-
&self.rotation_signs
66-
}
6735
}
6836

6937
/// Slot positions for TurboQuantArray children.
@@ -74,23 +42,17 @@ pub(crate) enum Slot {
7442
Norms = 1,
7543
Centroids = 2,
7644
RotationSigns = 3,
77-
QjlSigns = 4,
78-
QjlResidualNorms = 5,
79-
QjlRotationSigns = 6,
8045
}
8146

8247
impl Slot {
83-
pub(crate) const COUNT: usize = 7;
48+
pub(crate) const COUNT: usize = 4;
8449

8550
pub(crate) fn name(self) -> &'static str {
8651
match self {
8752
Self::Codes => "codes",
8853
Self::Norms => "norms",
8954
Self::Centroids => "centroids",
9055
Self::RotationSigns => "rotation_signs",
91-
Self::QjlSigns => "qjl_signs",
92-
Self::QjlResidualNorms => "qjl_residual_norms",
93-
Self::QjlRotationSigns => "qjl_rotation_signs",
9456
}
9557
}
9658

@@ -100,26 +62,18 @@ impl Slot {
10062
1 => Self::Norms,
10163
2 => Self::Centroids,
10264
3 => Self::RotationSigns,
103-
4 => Self::QjlSigns,
104-
5 => Self::QjlResidualNorms,
105-
6 => Self::QjlRotationSigns,
10665
_ => vortex_error::vortex_panic!("invalid slot index {idx}"),
10766
}
10867
}
10968
}
11069

11170
/// TurboQuant array.
11271
///
113-
/// Slots (always present):
114-
/// - 0: `codes` — `FixedSizeListArray<u8>` (quantized indices, list_size=padded_dim)
115-
/// - 1: `norms` — `PrimitiveArray<f32>` (one per vector row)
116-
/// - 2: `centroids` — `PrimitiveArray<f32>` (codebook, length 2^bit_width)
117-
/// - 3: `rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order)
118-
///
119-
/// Optional QJL slots (None when MSE-only):
120-
/// - 4: `qjl_signs` — `FixedSizeListArray<u8>` (num_rows * padded_dim, 1-bit)
121-
/// - 5: `qjl_residual_norms` — `PrimitiveArray<f32>` (one per row)
122-
/// - 6: `qjl_rotation_signs` — `BitPackedArray` (3 * padded_dim, 1-bit, QJL rotation)
72+
/// Slots:
73+
/// - 0: `codes` -- `FixedSizeListArray<u8>` (quantized indices, list_size=padded_dim).
74+
/// - 1: `norms` -- `PrimitiveArray<f32>` (one per vector row).
75+
/// - 2: `centroids` -- `PrimitiveArray<f32>` (codebook, length 2^bit_width).
76+
/// - 3: `rotation_signs` -- `BitPackedArray` (3 * padded_dim, 1-bit u8 0/1, inverse order).
12377
#[derive(Clone, Debug)]
12478
pub struct TurboQuantData {
12579
pub(crate) dtype: DType,
@@ -130,9 +84,9 @@ pub struct TurboQuantData {
13084
}
13185

13286
impl TurboQuantData {
133-
/// Build a TurboQuant array with MSE-only encoding (no QJL correction).
87+
/// Build a TurboQuant array.
13488
#[allow(clippy::too_many_arguments)]
135-
pub fn try_new_mse(
89+
pub fn try_new(
13690
dtype: DType,
13791
codes: ArrayRef,
13892
norms: ArrayRef,
@@ -143,7 +97,7 @@ impl TurboQuantData {
14397
) -> VortexResult<Self> {
14498
vortex_ensure!(
14599
(1..=8).contains(&bit_width),
146-
"MSE bit_width must be 1-8, got {bit_width}"
100+
"bit_width must be 1-8, got {bit_width}"
147101
);
148102
let mut slots = vec![None; Slot::COUNT];
149103
slots[Slot::Codes as usize] = Some(codes);
@@ -159,39 +113,6 @@ impl TurboQuantData {
159113
})
160114
}
161115

162-
/// Build a TurboQuant array with QJL correction (MSE + QJL).
163-
#[allow(clippy::too_many_arguments)]
164-
pub fn try_new_qjl(
165-
dtype: DType,
166-
codes: ArrayRef,
167-
norms: ArrayRef,
168-
centroids: ArrayRef,
169-
rotation_signs: ArrayRef,
170-
qjl: QjlCorrection,
171-
dimension: u32,
172-
bit_width: u8,
173-
) -> VortexResult<Self> {
174-
vortex_ensure!(
175-
(1..=8).contains(&bit_width),
176-
"MSE bit_width must be 1-8, got {bit_width}"
177-
);
178-
let mut slots = vec![None; Slot::COUNT];
179-
slots[Slot::Codes as usize] = Some(codes);
180-
slots[Slot::Norms as usize] = Some(norms);
181-
slots[Slot::Centroids as usize] = Some(centroids);
182-
slots[Slot::RotationSigns as usize] = Some(rotation_signs);
183-
slots[Slot::QjlSigns as usize] = Some(qjl.signs);
184-
slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms);
185-
slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs);
186-
Ok(Self {
187-
dtype,
188-
slots,
189-
dimension,
190-
bit_width,
191-
stats_set: Default::default(),
192-
})
193-
}
194-
195116
/// The vector dimension d.
196117
pub fn dimension(&self) -> u32 {
197118
self.dimension
@@ -207,11 +128,6 @@ impl TurboQuantData {
207128
self.dimension.next_power_of_two()
208129
}
209130

210-
/// Whether QJL correction is present.
211-
pub fn has_qjl(&self) -> bool {
212-
self.slots[Slot::QjlSigns as usize].is_some()
213-
}
214-
215131
fn slot(&self, idx: usize) -> &ArrayRef {
216132
self.slots[idx]
217133
.as_ref()
@@ -237,20 +153,4 @@ impl TurboQuantData {
237153
pub fn rotation_signs(&self) -> &ArrayRef {
238154
self.slot(Slot::RotationSigns as usize)
239155
}
240-
241-
/// The optional QJL correction fields, reconstructed from slots.
242-
pub fn qjl(&self) -> Option<QjlCorrection> {
243-
Some(QjlCorrection {
244-
signs: self.slots[Slot::QjlSigns as usize].clone()?,
245-
residual_norms: self.slots[Slot::QjlResidualNorms as usize].clone()?,
246-
rotation_signs: self.slots[Slot::QjlRotationSigns as usize].clone()?,
247-
})
248-
}
249-
250-
/// Set the QJL correction fields on this array.
251-
pub(crate) fn set_qjl(&mut self, qjl: QjlCorrection) {
252-
self.slots[Slot::QjlSigns as usize] = Some(qjl.signs);
253-
self.slots[Slot::QjlResidualNorms as usize] = Some(qjl.residual_norms);
254-
self.slots[Slot::QjlRotationSigns as usize] = Some(qjl.rotation_signs);
255-
}
256156
}

0 commit comments

Comments
 (0)