From 1eab710a9bdbaaffa438d3c917cdc78875e5d76d Mon Sep 17 00:00:00 2001 From: "sota.n" Date: Mon, 18 May 2026 18:57:32 +0900 Subject: [PATCH] feat: add --check-env preflight flag for OOM risk detection Adds a --check-env CLI flag that collects physical hardware characteristics (GPU VRAM, CPU RAM, disk space) and model memory estimates before quantization starts, then classifies OOM risk as safe/warning/danger. Exits with code 1 on danger; otherwise prints a report and proceeds with quantization. - onecomp/utils/vram_estimator.py: new EnvironmentSnapshot, ModelMemoryProfile, EnvCheckResult dataclasses; check_environment() and print_env_report() functions reusing existing weight_memory_gb() and estimate_target_bitwidth() - onecomp/utils/__init__.py: export 5 new public symbols - onecomp/cli.py: --check-env argparse flag with preflight invocation - onecomp/runner.py: check_env=False kwarg in auto_run() for library API use - pyproject.toml: optional extras [check-env] = ["psutil>=5.9"] Co-Authored-By: Claude Sonnet 4.6 --- onecomp/cli.py | 25 +++ onecomp/runner.py | 19 +++ onecomp/utils/__init__.py | 5 + onecomp/utils/vram_estimator.py | 274 ++++++++++++++++++++++++++++++++ pyproject.toml | 2 + 5 files changed, 325 insertions(+) diff --git a/onecomp/cli.py b/onecomp/cli.py index 37497a4..0e684f0 100644 --- a/onecomp/cli.py +++ b/onecomp/cli.py @@ -63,6 +63,14 @@ def main(): default="auto", help='save directory (default: auto-generated, "none" to skip)', ) + parser.add_argument( + "--check-env", + action="store_true", + help=( + "Print an environment and memory report before quantization. " + "Exits with code 1 if OOM risk is 'danger'." + ), + ) parser.add_argument( "--version", action="version", @@ -76,6 +84,23 @@ def main(): # Lazy import to keep --help fast from .runner import Runner # pylint: disable=import-outside-toplevel + if args.check_env: + import sys # pylint: disable=import-outside-toplevel + from .utils.vram_estimator import ( # pylint: disable=import-outside-toplevel + check_environment, + print_env_report, + ) + + env_result = check_environment( + args.model_id, + total_vram_gb=args.total_vram_gb, + group_size=args.groupsize, + save_dir=save_dir if isinstance(save_dir, str) and save_dir != "auto" else None, + ) + print_env_report(env_result, total_vram_gb_override=args.total_vram_gb) + if env_result.risk == "danger": + sys.exit(1) + Runner.auto_run( model_id=args.model_id, wbits=args.wbits, diff --git a/onecomp/runner.py b/onecomp/runner.py index f7153d4..85b0c02 100644 --- a/onecomp/runner.py +++ b/onecomp/runner.py @@ -417,6 +417,7 @@ def auto_run( evaluate: bool = True, eval_original_model: bool = False, save_dir: str = "auto", + check_env: bool = False, **kwargs, ): """One-liner quantization with sensible defaults. @@ -487,6 +488,24 @@ def auto_run( setup_logger() logger = getLogger(__name__) + if check_env: + from .utils.vram_estimator import ( # pylint: disable=import-outside-toplevel + check_environment, + print_env_report, + ) + + env_result = check_environment( + model_id, + total_vram_gb=total_vram_gb, + group_size=groupsize, + save_dir=save_dir if isinstance(save_dir, str) and save_dir != "auto" else None, + ) + print_env_report(env_result, total_vram_gb_override=total_vram_gb) + if env_result.risk == "danger": + raise RuntimeError( + f"Environment check failed (OOM risk=danger): {env_result.risk_detail}" + ) + candidate_bits = (2, 3, 4, 8) if wbits is None: diff --git a/onecomp/utils/__init__.py b/onecomp/utils/__init__.py index 8bbcba6..88b0523 100644 --- a/onecomp/utils/__init__.py +++ b/onecomp/utils/__init__.py @@ -25,6 +25,11 @@ effective_bits_for_quantizer, weight_memory_gb, VRAMBitwidthEstimation, + EnvironmentSnapshot, + ModelMemoryProfile, + EnvCheckResult, + check_environment, + print_env_report, ) from .model_inputs import add_model_specific_inputs diff --git a/onecomp/utils/vram_estimator.py b/onecomp/utils/vram_estimator.py index 7320c38..84e81cc 100644 --- a/onecomp/utils/vram_estimator.py +++ b/onecomp/utils/vram_estimator.py @@ -153,6 +153,42 @@ class VRAMBitwidthEstimation: meta_bits_per_param: float +@dataclass +class EnvironmentSnapshot: + """Physical hardware readings at check-env time.""" + + gpu_count: int + gpu_name: str | None + gpu_total_vram_gb: float | None + gpu_free_vram_gb: float | None + ram_total_gb: float | None + ram_available_gb: float | None + disk_available_gb: float | None + disk_path: str + + +@dataclass +class ModelMemoryProfile: + """Derived memory footprint for the target model.""" + + total_params: int + fp16_gb: float + quantized_gb: dict + calibration_overhead_gb: float + + +@dataclass +class EnvCheckResult: + """Composite result returned by check_environment().""" + + model_id: str + env: EnvironmentSnapshot + model: ModelMemoryProfile + estimation: VRAMBitwidthEstimation | None + risk: str + risk_detail: str + + def estimate_target_bitwidth( model: torch.nn.Module, vram_ratio: float = 0.70, @@ -322,3 +358,241 @@ def estimate_wbits_from_vram( wbits=wbits, logger=logger, ) + + +def check_environment( + model_id: str, + *, + total_vram_gb: float | None = None, + group_size: int = 128, + save_dir: str | None = None, + vram_ratio: float = 0.80, + calibration_overhead_ratio: float = 0.15, +) -> EnvCheckResult: + """Collect hardware info and estimate OOM risk before quantization. + + Loads the model architecture on a ``meta`` device (no GPU/CPU memory) + to count parameters, then compares available VRAM against estimated + memory requirements at 2/4/8-bit quantization. + + Args: + model_id: Hugging Face model ID or local path. + total_vram_gb: Override GPU VRAM in GB for estimation math only. + Physical GPU readings are always from the real device. + group_size: GPTQ group size for metadata calculation. + save_dir: Path used for disk-space check. Defaults to cwd. + vram_ratio: Fraction of VRAM allocated for the estimation budget. + calibration_overhead_ratio: Calibration activation buffer as a + fraction of the FP16 model footprint (default 15 %). + + Returns: + :class:`EnvCheckResult` with hardware snapshot, memory profile, + VRAM estimation, and risk level (``"safe"``, ``"warning"``, + ``"danger"``, or ``"unknown"``). + """ + import os + import pathlib + import shutil + + from transformers import AutoConfig, AutoModelForCausalLM + + # --- GPU snapshot -------------------------------------------------------- + gpu_count = torch.cuda.device_count() + if gpu_count > 0: + dev = torch.cuda.current_device() + props = torch.cuda.get_device_properties(dev) + gpu_name = props.name + gpu_total_vram_gb = props.total_memory / _BYTES_PER_GB + try: + free_bytes, _ = torch.cuda.mem_get_info(dev) + gpu_free_vram_gb = free_bytes / _BYTES_PER_GB + except Exception: + gpu_free_vram_gb = None + else: + gpu_name = None + gpu_total_vram_gb = None + gpu_free_vram_gb = None + + # --- CPU RAM (psutil optional) ------------------------------------------- + try: + import psutil + + vm = psutil.virtual_memory() + ram_total_gb = vm.total / _BYTES_PER_GB + ram_available_gb = vm.available / _BYTES_PER_GB + except ImportError: + ram_total_gb = None + ram_available_gb = None + + # --- Disk space (stdlib) ------------------------------------------------- + check_path = save_dir if save_dir else os.getcwd() + p = pathlib.Path(check_path) + while not p.exists(): + p = p.parent + disk_available_gb = shutil.disk_usage(p).free / _BYTES_PER_GB + + # --- Model memory profile ------------------------------------------------ + config = AutoConfig.from_pretrained(model_id) + with torch.device("meta"): + model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16) + + total_params = sum(p.numel() for p in model.parameters()) + fp16_gb = (total_params * 2) / _BYTES_PER_GB + quantized_gb = {b: weight_memory_gb(total_params, b, group_size) for b in (2, 4, 8)} + calibration_overhead_gb = fp16_gb * calibration_overhead_ratio + + # --- VRAM bitwidth estimation (reuse existing) --------------------------- + try: + estimation = estimate_target_bitwidth( + model, + vram_ratio=vram_ratio, + total_vram_gb=total_vram_gb, + group_size=group_size, + ) + except (RuntimeError, ValueError): + estimation = None + + # --- OOM risk assessment ------------------------------------------------- + # Use free VRAM (runtime reality) when available; fall back to override. + effective_vram = gpu_free_vram_gb if gpu_free_vram_gb is not None else total_vram_gb + + if effective_vram is None: + risk = "unknown" + risk_detail = "No GPU detected and no --total-vram-gb provided." + else: + need_4bit = quantized_gb[4] + calibration_overhead_gb + if effective_vram >= fp16_gb * 1.2: + risk = "safe" + risk_detail = ( + f"Free VRAM ({effective_vram:.1f} GB) comfortably fits " + f"even FP16 weights ({fp16_gb:.1f} GB × 1.2)." + ) + elif effective_vram >= need_4bit: + risk = "warning" + risk_detail = ( + f"Free VRAM ({effective_vram:.1f} GB) fits 4-bit quantized " + f"weights but is tight (calibration overhead included)." + ) + else: + risk = "danger" + risk_detail = ( + f"Free VRAM ({effective_vram:.1f} GB) is insufficient for " + f"4-bit + calibration ({need_4bit:.1f} GB needed)." + ) + + return EnvCheckResult( + model_id=model_id, + env=EnvironmentSnapshot( + gpu_count=gpu_count, + gpu_name=gpu_name, + gpu_total_vram_gb=gpu_total_vram_gb, + gpu_free_vram_gb=gpu_free_vram_gb, + ram_total_gb=ram_total_gb, + ram_available_gb=ram_available_gb, + disk_available_gb=disk_available_gb, + disk_path=str(p), + ), + model=ModelMemoryProfile( + total_params=total_params, + fp16_gb=fp16_gb, + quantized_gb=quantized_gb, + calibration_overhead_gb=calibration_overhead_gb, + ), + estimation=estimation, + risk=risk, + risk_detail=risk_detail, + ) + + +def print_env_report(result: EnvCheckResult, *, total_vram_gb_override: float | None = None) -> None: + """Print a human-readable environment and OOM risk report to stdout. + + Args: + result: The :class:`EnvCheckResult` from :func:`check_environment`. + total_vram_gb_override: When not ``None``, annotates the VRAM budget + line with ``[--total-vram-gb override]``. + """ + _W = 60 + _SEP = "=" * _W + _COL = 22 + + def _row(label: str, value: str) -> str: + return f" {label:<{_COL}}: {value}" + + risk_labels = { + "safe": "SAFE", + "warning": "WARNING", + "danger": "DANGER !!", + "unknown": "UNKNOWN", + } + risk_label = risk_labels.get(result.risk, result.risk.upper()) + + e = result.env + m = result.model + + print(_SEP) + print(" OneComp Environment Check") + print(_SEP) + print() + + # Hardware + print("Hardware") + print(_row("GPU count", str(e.gpu_count))) + if e.gpu_name is not None: + print(_row("GPU name", e.gpu_name)) + if e.gpu_total_vram_gb is not None: + label = "GPU VRAM (total)" + value = f"{e.gpu_total_vram_gb:.1f} GB" + if total_vram_gb_override is not None: + value += " [physical]" + print(_row(label, value)) + if total_vram_gb_override is not None: + print(_row("VRAM budget used", f"{total_vram_gb_override:.1f} GB [--total-vram-gb override]")) + if e.gpu_free_vram_gb is not None: + print(_row("GPU VRAM (free)", f"{e.gpu_free_vram_gb:.1f} GB")) + if e.ram_total_gb is not None: + print(_row("CPU RAM (total)", f"{e.ram_total_gb:.1f} GB")) + print(_row("CPU RAM (avail)", f"{e.ram_available_gb:.1f} GB")) + else: + print(_row("CPU RAM", "n/a (install psutil for RAM info)")) + print(_row("Disk (avail)", f"{e.disk_available_gb:.1f} GB [{e.disk_path}]")) + print() + + # Model + print(f"Model: {result.model_id}") + print(_row("Parameters", f"{m.total_params:,}")) + print(_row("FP16 footprint", f"{m.fp16_gb:.2f} GB")) + print() + + # Memory estimates + gs = "(group_size varies)" + print(f"Memory Estimates") + for bits in (2, 4, 8): + print(_row(f"{bits}-bit quantized", f"{m.quantized_gb[bits]:.2f} GB")) + print(_row("Calib. overhead", f"{m.calibration_overhead_gb:.2f} GB (15% of FP16)")) + print(_row("4-bit + overhead", f"{m.quantized_gb[4] + m.calibration_overhead_gb:.2f} GB")) + print() + + # OOM risk + print("OOM Risk Assessment") + print(_row("Risk level", risk_label)) + detail_words = result.risk_detail.split() + detail_line = "" + detail_lines = [] + for word in detail_words: + if len(detail_line) + len(word) + 1 > 34: + detail_lines.append(detail_line) + detail_line = word + else: + detail_line = (detail_line + " " + word).lstrip() + if detail_line: + detail_lines.append(detail_line) + for i, dl in enumerate(detail_lines): + if i == 0: + print(_row("Detail", dl)) + else: + print(f" {'':<{_COL}} {dl}") + print() + if result.estimation is not None: + print(_row("Recommended wbits", f"{result.estimation.target_bitwidth:.2f} (VRAM-estimated)")) + print(_SEP) diff --git a/pyproject.toml b/pyproject.toml index b874ba1..9f82301 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,10 +53,12 @@ cu124 = ["torch", "torchvision"] cu126 = ["torch", "torchvision"] cu128 = ["torch", "torchvision"] cu130 = ["torch", "torchvision"] +check-env = ["psutil>=5.9"] dev = [ "black", "hydra-core", "pylint", + "psutil>=5.9", "pytest", ] visualize = [