diff --git a/dev/bench_arrow_scan.py b/dev/bench_arrow_scan.py index 20218ac749..16361bb14b 100755 --- a/dev/bench_arrow_scan.py +++ b/dev/bench_arrow_scan.py @@ -39,6 +39,7 @@ import sys import time import tracemalloc +import warnings from dataclasses import asdict, dataclass from pathlib import Path from typing import Any @@ -84,6 +85,12 @@ class RunResult: tracemalloc_peak_mb: float +FALLBACK_WARNING_MARKERS = ( + "Falling back to PyArrow scan because pyiceberg-core cannot handle this scan", + "Falling back to native task-based scan because Rust-planned scan failed", +) + + def _catalog_props(config: CatalogConfig) -> dict[str, str]: return { "type": "rest", @@ -120,6 +127,7 @@ def scenarios(namespace: str, table_prefix: str, rows: int, delete_rows: int) -> many = _table_name(namespace, table_prefix, "many_files") partitioned = _table_name(namespace, table_prefix, "partitioned") deletes = _table_name(namespace, table_prefix, "pos_deletes") + many_manifests = _table_name(namespace, table_prefix, "many_manifests") halfway = rows // 2 delete_halfway = delete_rows // 2 return [ @@ -131,6 +139,8 @@ def scenarios(namespace: str, table_prefix: str, rows: int, delete_rows: int) -> Scenario("partitioned_project", partitioned, selected_fields=("id", "part")), Scenario("pos_deletes_full", deletes), Scenario("pos_deletes_filter", deletes, row_filter=f"id >= {delete_halfway}", selected_fields=("id", "value")), + Scenario("many_manifests_full", many_manifests), + Scenario("many_manifests_filter", many_manifests, row_filter="id = 50", selected_fields=("id", "value")), ] @@ -150,12 +160,18 @@ def provision(args: argparse.Namespace, config: CatalogConfig) -> None: many = _table_name(args.namespace, args.table_prefix, "many_files") partitioned = _table_name(args.namespace, args.table_prefix, "partitioned") deletes = _table_name(args.namespace, args.table_prefix, "pos_deletes") + many_manifests = _table_name(args.namespace, args.table_prefix, "many_manifests") if args.refresh: - for identifier in (many, partitioned, deletes): + for identifier in (many, partitioned, deletes, many_manifests): spark.sql(f"DROP TABLE IF EXISTS rest.{identifier}") - if _table_exists(catalog, many) and _table_exists(catalog, partitioned) and _table_exists(catalog, deletes): + if ( + _table_exists(catalog, many) + and _table_exists(catalog, partitioned) + and _table_exists(catalog, deletes) + and _table_exists(catalog, many_manifests) + ): print("Benchmark tables already exist; use --refresh to recreate them.") return @@ -176,6 +192,12 @@ def provision(args: argparse.Namespace, config: CatalogConfig) -> None: ).tableProperty("write.update.mode", "merge-on-read").tableProperty("write.merge.mode", "merge-on-read").createOrReplace() spark.sql(f"DELETE FROM rest.{deletes} WHERE id % 20 = 0") + print(f"Creating {many_manifests}: planning-heavy table with many manifests") + small_df = _benchmark_dataframe(spark, 100).repartition(1) + small_df.writeTo(f"rest.{many_manifests}").using("iceberg").tableProperty("format-version", "2").createOrReplace() + for _ in range(20): + small_df.writeTo(f"rest.{many_manifests}").append() + def _benchmark_dataframe(spark: Any, rows: int) -> Any: from pyspark.sql import functions as F @@ -196,17 +218,23 @@ def _table_exists(catalog: Any, identifier: str) -> bool: return False -def validate_scenarios(config: CatalogConfig, scenario_list: list[Scenario]) -> None: +def validate_scenarios(config: CatalogConfig, scenario_list: list[Scenario], engines: list[str]) -> None: failures: list[str] = [] for scenario in scenario_list: - pyarrow_summary = _run_scan(config, scenario, "pyarrow", validate_only=True) - native_summary = _run_scan(config, scenario, "native", validate_only=True) + summaries = {} + for engine in engines: + summaries[engine] = _run_scan(config, scenario, engine, validate_only=True) + + ref_engine = engines[0] + ref_summary = summaries[ref_engine] comparable_keys = ("rows", "batches", "columns", "checksum") - mismatches = [key for key in comparable_keys if pyarrow_summary.get(key) != native_summary.get(key)] - if mismatches: - failures.append( - f"{scenario.name}: mismatched {', '.join(mismatches)} pyarrow={pyarrow_summary} native={native_summary}" - ) + for engine in engines[1:]: + mismatches = [key for key in comparable_keys if ref_summary.get(key) != summaries[engine].get(key)] + if mismatches: + failures.append( + f"{scenario.name}: mismatched {', '.join(mismatches)} " + f"between {ref_engine}={ref_summary} and {engine}={summaries[engine]}" + ) if failures: raise RuntimeError("Validation failed:\n" + "\n".join(failures)) @@ -214,7 +242,8 @@ def validate_scenarios(config: CatalogConfig, scenario_list: list[Scenario]) -> def _run_scan(config: CatalogConfig, scenario: Scenario, engine: str, validate_only: bool = False) -> dict[str, Any]: from pyiceberg.catalog import load_catalog - os.environ["PYICEBERG_RUST_ARROW_SCAN"] = "1" if engine == "native" else "0" + os.environ["PYICEBERG_RUST_ARROW_SCAN"] = "1" if engine == "native-task" else "0" + os.environ["PYICEBERG_RUST_PLANNED_ARROW_SCAN"] = "1" if engine == "native-planned" else "0" catalog = load_catalog("default", **_catalog_props(config)) table = catalog.load_table(scenario.table) scan_kwargs: dict[str, Any] = {} @@ -232,12 +261,23 @@ def _run_scan(config: CatalogConfig, scenario: Scenario, engine: str, validate_o batches = 0 checksum = 0 columns: list[str] | None = None - for batch in table.scan(**scan_kwargs).to_arrow_batch_reader(): - if columns is None: - columns = batch.schema.names - rows += batch.num_rows - batches += 1 - checksum += _batch_checksum(batch) + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + for batch in table.scan(**scan_kwargs).to_arrow_batch_reader(): + if columns is None: + columns = batch.schema.names + rows += batch.num_rows + batches += 1 + checksum += _batch_checksum(batch) + + fallback_warnings = [ + str(warning.message) + for warning in caught_warnings + if any(marker in str(warning.message) for marker in FALLBACK_WARNING_MARKERS) + ] + if fallback_warnings: + raise RuntimeError(f"{engine} {scenario.name} used a fallback scan path: {fallback_warnings}") + elapsed_ms = (time.perf_counter() - start) * 1000 _, peak = tracemalloc.get_traced_memory() rss_after_mb = _current_rss_mb() @@ -364,13 +404,10 @@ def _rss_mb(process: Any, default: float = 0.0) -> float: return default -def run_benchmarks(args: argparse.Namespace, config: CatalogConfig, scenario_list: list[Scenario]) -> list[RunResult]: +def run_benchmarks( + args: argparse.Namespace, config: CatalogConfig, scenario_list: list[Scenario], engines: list[str] +) -> list[RunResult]: results: list[RunResult] = [] - engines = [] - if not args.native_only: - engines.append("pyarrow") - if not args.pyarrow_only: - engines.append("native") for scenario in scenario_list: for engine in engines: @@ -466,6 +503,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--skip-validation", action="store_true", help="Skip parity validation before timed runs") parser.add_argument("--native-only", action="store_true") parser.add_argument("--pyarrow-only", action="store_true") + parser.add_argument("--engines", default="pyarrow,native-task,native-planned", help="Comma-separated engines to run") parser.add_argument("--json-out", type=Path) parser.add_argument("--markdown-out", type=Path) return parser.parse_args() @@ -490,15 +528,22 @@ def main() -> None: if args.native_only and args.pyarrow_only: raise ValueError("--native-only and --pyarrow-only are mutually exclusive") + if args.native_only: + engines = ["native-task", "native-planned"] + elif args.pyarrow_only: + engines = ["pyarrow"] + else: + engines = [e.strip() for e in args.engines.split(",") if e.strip()] + if not args.skip_provision: provision(args, config) if not args.skip_validation: print("Validating native and PyArrow scan parity...") - validate_scenarios(config, scenario_list) + validate_scenarios(config, scenario_list, engines) print("Running benchmark. Memory caveat: RSS is for the Python benchmark/client process, not Spark/REST/MinIO containers.") - results = run_benchmarks(args, config, scenario_list) + results = run_benchmarks(args, config, scenario_list, engines) summary_rows = summarize(results) table = markdown_table(summary_rows) print("\n" + table)