From 002837c4703acf90bf3fd49ef3e55036c684af75 Mon Sep 17 00:00:00 2001 From: Gayathri Srividya Rajavarapu Date: Sat, 30 May 2026 19:08:26 +0530 Subject: [PATCH] fix: normalize dictionary types in Arrow scans --- pyiceberg/io/pyarrow.py | 37 ++++++++++++++++++++++++++++++++++- tests/io/test_pyarrow.py | 42 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 1 deletion(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 4ec7a73afe..2b2f4e0ffc 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1205,6 +1205,37 @@ def _pyarrow_schema_ensure_small_types(schema: pa.Schema) -> pa.Schema: return visit_pyarrow(schema, _ConvertToSmallTypes()) +def _pyarrow_type_ensure_non_dictionary_types(field_type: pa.DataType) -> pa.DataType: + if pa.types.is_dictionary(field_type): + return _pyarrow_type_ensure_non_dictionary_types(field_type.value_type) + elif pa.types.is_struct(field_type): + return pa.struct([field.with_type(_pyarrow_type_ensure_non_dictionary_types(field.type)) for field in field_type]) + elif pa.types.is_list(field_type): + return pa.list_(field_type.value_field.with_type(_pyarrow_type_ensure_non_dictionary_types(field_type.value_type))) + elif pa.types.is_large_list(field_type): + return pa.large_list(field_type.value_field.with_type(_pyarrow_type_ensure_non_dictionary_types(field_type.value_type))) + elif pa.types.is_fixed_size_list(field_type): + return pa.list_( + field_type.value_field.with_type(_pyarrow_type_ensure_non_dictionary_types(field_type.value_type)), + field_type.list_size, + ) + elif pa.types.is_map(field_type): + return pa.map_( + field_type.key_field.with_type(_pyarrow_type_ensure_non_dictionary_types(field_type.key_type)), + field_type.item_field.with_type(_pyarrow_type_ensure_non_dictionary_types(field_type.item_type)), + keys_sorted=field_type.keys_sorted, + ) + return field_type + + +def _pyarrow_table_ensure_non_dictionary_types(table: pa.Table) -> pa.Table: + schema = pa.schema( + [field.with_type(_pyarrow_type_ensure_non_dictionary_types(field.type)) for field in table.schema], + metadata=table.schema.metadata, + ) + return table.cast(schema) if schema != table.schema else table + + @singledispatch def visit_pyarrow(obj: pa.DataType | pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> T: """Apply a pyarrow schema visitor to any point within a schema. @@ -1795,7 +1826,11 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table: # Note: cannot use pa.Table.from_batches(itertools.chain([first_batch], batches))) # as different batches can use different schema's (due to large_ types) result = pa.concat_tables( - (pa.Table.from_batches([batch]) for batch in itertools.chain([first_batch], batches)), promote_options="permissive" + ( + _pyarrow_table_ensure_non_dictionary_types(pa.Table.from_batches([batch])) + for batch in itertools.chain([first_batch], batches) + ), + promote_options="permissive", ) return result diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 2f36661a1f..1874cff50f 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -73,6 +73,7 @@ _ConvertToArrowSchema, _determine_partitions, _primitive_to_physical, + _pyarrow_table_ensure_non_dictionary_types, _read_deletes, _task_to_record_batches, _to_requested_schema, @@ -1301,6 +1302,47 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None: assert repr(result_table.schema) == "id: int32" +def test_arrow_scan_to_table_with_mixed_dictionary_and_plain_strings() -> None: + schema = Schema(NestedField(1, "foo", StringType(), required=False)) + scan = ArrowScan( + table_metadata=TableMetadataV2( + location="file://a/b/", + last_column_id=1, + format_version=2, + schemas=[schema], + partition_specs=[PartitionSpec()], + ), + io=PyArrowFileIO(), + projected_schema=schema, + row_filter=AlwaysTrue(), + ) + values = pa.array(["a"], type=pa.string()) + batches = iter([pa.record_batch([values], names=["foo"]), pa.record_batch([values.dictionary_encode()], names=["foo"])]) + + with patch.object(scan, "to_record_batches", return_value=batches): + assert scan.to_table([]).to_pydict() == {"foo": ["a", "a"]} + + +def test_pyarrow_table_ensure_non_dictionary_types_nested() -> None: + dictionary_values = pa.array(["a"]).dictionary_encode() + table = pa.table( + { + "struct": pa.StructArray.from_arrays([dictionary_values], names=["value"]), + "list": pa.ListArray.from_arrays(pa.array([0, 1]), dictionary_values), + } + ) + + normalized_table = _pyarrow_table_ensure_non_dictionary_types(table) + + assert normalized_table.schema == pa.schema( + [ + pa.field("struct", pa.struct([pa.field("value", pa.string())])), + pa.field("list", pa.list_(pa.string())), + ] + ) + assert normalized_table.to_pydict() == {"struct": [{"value": "a"}], "list": [["a"]]} + + def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCatalog) -> None: # Test by adding a non-partitioned data file to a partitioned table, verifying partition value # projection from manifest metadata.