-
Notifications
You must be signed in to change notification settings - Fork 147
feat: Count aggregate
#7267
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Count aggregate
#7267
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,266 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| use vortex_error::VortexExpect; | ||
| use vortex_error::VortexResult; | ||
|
|
||
| use crate::ArrayRef; | ||
| use crate::Columnar; | ||
| use crate::DynArray; | ||
| use crate::ExecutionCtx; | ||
| use crate::aggregate_fn::AggregateFnId; | ||
| use crate::aggregate_fn::AggregateFnVTable; | ||
| use crate::aggregate_fn::EmptyOptions; | ||
| use crate::dtype::DType; | ||
| use crate::dtype::Nullability; | ||
| use crate::dtype::PType; | ||
| use crate::scalar::Scalar; | ||
|
|
||
| /// Aggregate function that counts the non-null elements of an array. | ||
| /// | ||
| /// Applies to every input dtype. The partial state and the final result are both a | ||
| /// non-nullable `u64`, and the identity (empty) value is zero. | ||
| /// | ||
| /// Only the number of valid elements of each batch is needed, so batches are handled | ||
| /// in `try_accumulate` and never reach `accumulate`. | ||
| #[derive(Clone, Debug)] | ||
| pub struct Count; | ||
|
|
||
| impl AggregateFnVTable for Count { | ||
| type Options = EmptyOptions; | ||
| type Partial = u64; | ||
|
|
||
| /// Returns the identifier of this aggregate function, `vortex.count`. | ||
| fn id(&self) -> AggregateFnId { | ||
| AggregateFnId::new_ref("vortex.count") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are you sure (also @gatesn ) we want a global namespace of scalar-fns and aggregations? |
||
| } | ||
|
|
||
| /// Serializes the (empty) options. | ||
| /// | ||
| /// `Count` carries no options, so the metadata is always an empty byte vector. | ||
| fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> { | ||
| let metadata: Vec<u8> = Vec::new(); | ||
| Ok(Some(metadata)) | ||
| } | ||
|
|
||
| /// Rebuilds the options from serialized metadata. | ||
| /// | ||
| /// Neither the metadata bytes nor the session are inspected: the only possible | ||
| /// options value is `EmptyOptions`. | ||
| fn deserialize( | ||
| &self, | ||
| _metadata: &[u8], | ||
| _session: &vortex_session::VortexSession, | ||
| ) -> VortexResult<Self::Options> { | ||
| let options = EmptyOptions; | ||
| Ok(options) | ||
| } | ||
|
|
||
| /// Returns the result dtype, which is a non-nullable `u64` for every input dtype. | ||
| fn return_dtype(&self, _options: &Self::Options, _input_dtype: &DType) -> Option<DType> { | ||
| let count_dtype = DType::Primitive(PType::U64, Nullability::NonNullable); | ||
| Some(count_dtype) | ||
| } | ||
|
|
||
| /// Returns the dtype of a partial count, which is the same as the final result dtype. | ||
| fn partial_dtype(&self, options: &Self::Options, input_dtype: &DType) -> Option<DType> { | ||
| Self::return_dtype(self, options, input_dtype) | ||
| } | ||
|
|
||
| /// Returns the identity partial state: a count of zero. | ||
| fn empty_partial( | ||
| &self, | ||
| _options: &Self::Options, | ||
| _input_dtype: &DType, | ||
| ) -> VortexResult<Self::Partial> { | ||
| let identity: u64 = 0; | ||
| Ok(identity) | ||
| } | ||
|
|
||
| /// Adds another partial count, given as a scalar, into `partial`. | ||
| /// | ||
| /// `other` is expected to be a non-null `u64` primitive scalar; a null value is | ||
| /// treated as a broken invariant via `vortex_expect`. | ||
| fn combine_partials(&self, partial: &mut Self::Partial, other: Scalar) -> VortexResult<()> { | ||
| let val = other | ||
| .as_primitive() | ||
| .typed_value::<u64>() | ||
| .vortex_expect("count partial should not be null"); | ||
| // Plain `+=`: overflow panics when overflow checks are on (debug builds) and wraps | ||
| // otherwise. A count is not expected to exceed `u64::MAX`. | ||
| *partial += val; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about overflow?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. from
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that's probably fair for a count to not exceed u64 |
||
| Ok(()) | ||
| } | ||
|
|
||
| /// Converts the partial count into a non-nullable `u64` scalar. | ||
| fn to_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> { | ||
| let count = *partial; | ||
| Ok(Scalar::primitive(count, Nullability::NonNullable)) | ||
| } | ||
|
|
||
| /// Resets the partial count back to zero, the same value `empty_partial` produces. | ||
| fn reset(&self, partial: &mut Self::Partial) { | ||
| *partial = 0; | ||
| } | ||
|
|
||
| /// Always `false`: a count is never reported as saturated. | ||
| #[inline] | ||
| fn is_saturated(&self, _partial: &Self::Partial) -> bool { | ||
| false | ||
| } | ||
|
|
||
| /// Adds the number of valid (non-null) elements of `batch` to `state`. | ||
| /// | ||
| /// Always returns `Ok(true)`: every array is handled here, which is why | ||
| /// `accumulate` is never expected to run for `Count`. | ||
| fn try_accumulate( | ||
| &self, | ||
| state: &mut Self::Partial, | ||
| batch: &ArrayRef, | ||
| _ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<bool> { | ||
| let valid = batch.valid_count()? as u64; | ||
| *state += valid; | ||
| Ok(true) | ||
| } | ||
|
|
||
| /// Never expected to be called. | ||
| /// | ||
| /// `Count::try_accumulate` returns `Ok(true)` for every batch, so reaching this | ||
| /// method is treated as a bug and panics via `unreachable!`. | ||
| fn accumulate( | ||
| &self, | ||
| _partial: &mut Self::Partial, | ||
| _batch: &Columnar, | ||
| _ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<()> { | ||
| unreachable!("Count::try_accumulate handles all arrays") | ||
| } | ||
|
|
||
| /// Returns the array of partial counts unchanged. | ||
| /// | ||
| /// `partial_dtype` and `return_dtype` are the same for `Count`, so the partials | ||
| /// are passed through without any conversion. | ||
| fn finalize(&self, partials: ArrayRef) -> VortexResult<ArrayRef> { | ||
| Ok(partials) | ||
| } | ||
|
|
||
| /// Produces the final scalar for a single partial; identical to `to_scalar`. | ||
| fn finalize_scalar(&self, partial: &Self::Partial) -> VortexResult<Scalar> { | ||
| Self::to_scalar(self, partial) | ||
| } | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use vortex_buffer::buffer; | ||
| use vortex_error::VortexExpect; | ||
| use vortex_error::VortexResult; | ||
|
|
||
| use crate::ArrayRef; | ||
| use crate::ExecutionCtx; | ||
| use crate::IntoArray; | ||
| use crate::LEGACY_SESSION; | ||
| use crate::VortexSessionExecute; | ||
| use crate::aggregate_fn::Accumulator; | ||
| use crate::aggregate_fn::AggregateFnVTable; | ||
| use crate::aggregate_fn::DynAccumulator; | ||
| use crate::aggregate_fn::EmptyOptions; | ||
| use crate::aggregate_fn::fns::count::Count; | ||
| use crate::arrays::ChunkedArray; | ||
| use crate::arrays::ConstantArray; | ||
| use crate::arrays::PrimitiveArray; | ||
| use crate::dtype::DType; | ||
| use crate::dtype::Nullability; | ||
| use crate::dtype::PType; | ||
| use crate::scalar::Scalar; | ||
| use crate::validity::Validity; | ||
|
|
||
| /// Test helper: runs `Count` over `array` with a fresh accumulator and returns the | ||
| /// resulting count converted to `usize`. | ||
| pub fn count(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<usize> { | ||
| let mut accumulator = Accumulator::try_new(Count, EmptyOptions, array.dtype().clone())?; | ||
| accumulator.accumulate(array, ctx)?; | ||
| let scalar = accumulator.finish()?; | ||
|
|
||
| let total = scalar | ||
| .as_primitive() | ||
| .typed_value::<u64>() | ||
| .vortex_expect("count result should not be null"); | ||
| Ok(usize::try_from(total)?) | ||
| } | ||
|
|
||
| /// A non-nullable array counts every element. | ||
| #[test] | ||
| fn count_all_valid() -> VortexResult<()> { | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let values = | ||
| PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5], Validity::NonNullable).into_array(); | ||
| let counted = count(&values, &mut ctx)?; | ||
| assert_eq!(counted, 5); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Only the three valid elements out of five are counted. | ||
| #[test] | ||
| fn count_with_nulls() -> VortexResult<()> { | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let values = [Some(1i32), None, Some(3), None, Some(5)]; | ||
| let array = PrimitiveArray::from_option_iter(values).into_array(); | ||
| let counted = count(&array, &mut ctx)?; | ||
| assert_eq!(counted, 3); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// An array holding only nulls counts as zero. | ||
| #[test] | ||
| fn count_all_null() -> VortexResult<()> { | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let nulls = PrimitiveArray::from_option_iter::<i32, _>([None, None, None]).into_array(); | ||
| let counted = count(&nulls, &mut ctx)?; | ||
| assert_eq!(counted, 0); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Finishing an accumulator that never saw a batch yields the identity, zero. | ||
| #[test] | ||
| fn count_empty() -> VortexResult<()> { | ||
| let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); | ||
| let mut accumulator = Accumulator::try_new(Count, EmptyOptions, input_dtype)?; | ||
| let scalar = accumulator.finish()?; | ||
| let total = scalar.as_primitive().typed_value::<u64>(); | ||
| assert_eq!(total, Some(0)); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Counts add up across several batches fed into the same accumulator. | ||
| #[test] | ||
| fn count_multi_batch() -> VortexResult<()> { | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let input_dtype = DType::Primitive(PType::I32, Nullability::Nullable); | ||
| let mut accumulator = Accumulator::try_new(Count, EmptyOptions, input_dtype)?; | ||
|
|
||
| // Two valid elements in the first batch, one in the second. | ||
| let batches = [ | ||
| PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]).into_array(), | ||
| PrimitiveArray::from_option_iter([None, Some(5i32)]).into_array(), | ||
| ]; | ||
| for batch in &batches { | ||
| accumulator.accumulate(batch, &mut ctx)?; | ||
| } | ||
|
|
||
| let scalar = accumulator.finish()?; | ||
| let total = scalar.as_primitive().typed_value::<u64>(); | ||
| assert_eq!(total, Some(3)); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// `finish` starts a new group: the second result must not include the first batch. | ||
| #[test] | ||
| fn count_finish_resets_state() -> VortexResult<()> { | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let input_dtype = DType::Primitive(PType::I32, Nullability::Nullable); | ||
| let mut accumulator = Accumulator::try_new(Count, EmptyOptions, input_dtype)?; | ||
|
|
||
| let first_batch = PrimitiveArray::from_option_iter([Some(1i32), None]).into_array(); | ||
| accumulator.accumulate(&first_batch, &mut ctx)?; | ||
| let first = accumulator.finish()?; | ||
| let first_total = first.as_primitive().typed_value::<u64>(); | ||
| assert_eq!(first_total, Some(1)); | ||
|
|
||
| let second_batch = | ||
| PrimitiveArray::from_option_iter([Some(2i32), Some(3), None]).into_array(); | ||
| accumulator.accumulate(&second_batch, &mut ctx)?; | ||
| let second = accumulator.finish()?; | ||
| let second_total = second.as_primitive().typed_value::<u64>(); | ||
| assert_eq!(second_total, Some(2)); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Merging partial counts through the vtable sums them (5 + 3 = 8). | ||
| #[test] | ||
| fn count_state_merge() -> VortexResult<()> { | ||
| let input_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); | ||
| let mut state = Count.empty_partial(&EmptyOptions, &input_dtype)?; | ||
|
|
||
| for partial in [5u64, 3u64] { | ||
| let scalar = Scalar::primitive(partial, Nullability::NonNullable); | ||
| Count.combine_partials(&mut state, scalar)?; | ||
| } | ||
|
|
||
| // Take the merged value first, then reset the state, as a caller finishing a group would. | ||
| let merged = Count.to_scalar(&state)?; | ||
| Count.reset(&mut state); | ||
| let total = merged.as_primitive().typed_value::<u64>(); | ||
| assert_eq!(total, Some(8)); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// A constant array of a non-null value counts its full length. | ||
| #[test] | ||
| fn count_constant_non_null() -> VortexResult<()> { | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let constant = ConstantArray::new(42i32, 10).into_array(); | ||
| let counted = count(&constant, &mut ctx)?; | ||
| assert_eq!(counted, 10); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// A constant array of a null value counts as zero regardless of its length. | ||
| #[test] | ||
| fn count_constant_null() -> VortexResult<()> { | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let null_scalar = Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)); | ||
| let constant = ConstantArray::new(null_scalar, 10).into_array(); | ||
| let counted = count(&constant, &mut ctx)?; | ||
| assert_eq!(counted, 0); | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// A chunked array is counted across all of its chunks (2 + 1 valid elements). | ||
| #[test] | ||
| fn count_chunked() -> VortexResult<()> { | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let first = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3)]); | ||
| let second = PrimitiveArray::from_option_iter([None, Some(5i32), None]); | ||
| let chunk_dtype = first.dtype().clone(); | ||
| let chunks = vec![first.into_array(), second.into_array()]; | ||
| let chunked = ChunkedArray::try_new(chunks, chunk_dtype)?.into_array(); | ||
| let counted = count(&chunked, &mut ctx)?; | ||
| assert_eq!(counted, 3); | ||
| Ok(()) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,6 +105,22 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { | |
| /// final result is fully determined. | ||
| fn is_saturated(&self, state: &Self::Partial) -> bool; | ||
|
|
||
| /// Try to accumulate the raw array before decompression. | ||
| /// | ||
| /// Returns `true` if the array was handled, `false` to fall through to | ||
| /// the default kernel dispatch and canonicalization path. | ||
| /// | ||
| /// This is useful for aggregates that only depend on array metadata (e.g., validity) | ||
| /// rather than the encoded data, avoiding unnecessary decompression. | ||
| /// | ||
| /// The default implementation handles nothing: it leaves `_state` untouched and | ||
| /// returns `Ok(false)`. | ||
| fn try_accumulate( | ||
| &self, | ||
| _state: &mut Self::Partial, | ||
| _batch: &ArrayRef, | ||
| _ctx: &mut ExecutionCtx, | ||
| ) -> VortexResult<bool> { | ||
| Ok(false) | ||
| } | ||
|
Comment on lines
+115
to
+122
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. can you remove this. We want a different system to short cut decompression, similar to execute_parent
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. i think then let's wait for that system to merge - any pr i should follow?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this the equivalent of "reduce"? i.e. something that can operate without looking at data buffers. So maybe we just need to rename this to |
||
|
|
||
| /// Accumulate a new canonical array into the accumulator state. | ||
| fn accumulate( | ||
| &self, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.