fix: Add NVIDIA Blackwell (RTX 50xx, sm_120) GPU support#4155

Open
Hasham-dev wants to merge 1 commit into lllyasviel:main from Hasham-dev:fix/blackwell-rtx50xx-support

@Hasham-dev

Summary

Minimal fixes to support NVIDIA Blackwell architecture GPUs (RTX 5050/5060/5070/5080/5090, compute capability sm_120) on Fooocus.

  • Use bfloat16 as the UNet dtype on Blackwell GPUs (compute capability major >= 12), which have native bf16 tensor core support
  • Skip manual_cast for bf16 weights to avoid unnecessary dtype casting overhead
  • Fix the numpy TypeError raised for bfloat16 tensors in modules/patch.py and extras/ip_adapter.py: numpy has no bfloat16 dtype, so tensors are converted to float32 before .numpy() calls
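The dtype selection described above can be sketched roughly as follows. This is an illustration only, not the code from ldm_patched/modules/model_management.py: the helper names (`is_blackwell_or_newer`, `pick_unet_dtype_name`, `current_device_capability`) and the `float16` fallback are assumptions made for the example. The only PyTorch call used is `torch.cuda.get_device_capability()`, and it is imported lazily so the sketch also runs on machines without PyTorch or a GPU.

```python
from typing import Optional, Tuple

# Assumption taken from the PR text: Blackwell (sm_120) reports major version 12.
BLACKWELL_MIN_MAJOR = 12


def is_blackwell_or_newer(capability: Tuple[int, int]) -> bool:
    """Return True when the (major, minor) compute capability is 12.x or higher."""
    major, _minor = capability
    return major >= BLACKWELL_MIN_MAJOR


def pick_unet_dtype_name(capability: Tuple[int, int], default: str = "float16") -> str:
    """Pick bfloat16 on Blackwell-class GPUs, otherwise keep the (assumed) default."""
    return "bfloat16" if is_blackwell_or_newer(capability) else default


def current_device_capability() -> Optional[Tuple[int, int]]:
    """Return (major, minor) for the current CUDA device, or None if unavailable."""
    try:
        import torch  # imported lazily so the helpers above work without PyTorch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


if __name__ == "__main__":
    print(pick_unet_dtype_name((12, 0)))  # Blackwell-class capability
    print(pick_unet_dtype_name((8, 9)))   # older capability, falls back to the default
    print(current_device_capability())
```

Keeping the decision in a pure function of the capability tuple makes it testable without Blackwell hardware; only `current_device_capability` touches the GPU.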

Changes (3 files, +13 -2 lines)

| File | Change |
| --- | --- |
| ldm_patched/modules/model_management.py | Auto-detect Blackwell GPUs and use bf16 dtype; skip manual_cast for bf16 |
| modules/patch.py | Fix bf16→numpy crash in patched_unet_forward |
| extras/ip_adapter.py | Fix bf16→numpy crash in IP-Adapter attention patcher |
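The bf16→numpy fix listed for modules/patch.py and extras/ip_adapter.py follows a common pattern, sketched below. The helper name `tensor_to_numpy` is hypothetical and does not appear in the PR; the sketch only assumes that `torch` and `numpy` are installed. It runs on CPU, so no Blackwell GPU is needed to try it.

```python
import numpy as np
import torch


def tensor_to_numpy(t: torch.Tensor) -> np.ndarray:
    """Convert a tensor to a NumPy array, upcasting bfloat16 to float32 first."""
    if t.dtype == torch.bfloat16:
        # NumPy has no native bfloat16 dtype, so calling .numpy() directly on a
        # bf16 tensor fails; upcasting to float32 is lossless for bf16 values.
        t = t.float()
    return t.detach().cpu().numpy()


# Values chosen to be exactly representable in bfloat16.
x = torch.tensor([0.5, 1.0, 2.0], dtype=torch.bfloat16)
arr = tensor_to_numpy(x)
print(arr.dtype)  # float32
```

Upcasting only when the dtype is bfloat16 leaves float16 and float32 tensors on their existing conversion path, which keeps the change small.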

Testing

  • GPU: NVIDIA GeForce RTX 5070 (sm_120, 11.5GB VRAM)
  • CUDA: 12.8
  • PyTorch: 2.12.0.dev (nightly, cu128)
  • Result: Image generation works at ~3.2 it/s at 896x1152, including Image Prompt (IP-Adapter) mode
  • VAE note: Users will also need madebyollin/sdxl-vae-fp16-fix VAE for stable bf16 decoding (the default SDXL VAE overflows in bf16)

Fixes

#3862, #4123, #4141

Prerequisites

Users with Blackwell GPUs need PyTorch nightly with CUDA 12.8 support:

pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
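After installing, it can be useful to confirm that the PyTorch build actually ships kernels for the GPU. The sketch below is one way to do that, not part of this PR: the helper names are hypothetical, and the exact-match check against `torch.cuda.get_arch_list()` is a conservative heuristic (a build may still run through PTX forward compatibility even when the exact `sm_` tag is absent).

```python
from typing import Iterable, List, Optional, Tuple


def arch_tag(capability: Tuple[int, int]) -> str:
    """Format a (major, minor) compute capability as an sm_ tag, e.g. (12, 0) -> sm_120."""
    major, minor = capability
    return f"sm_{major}{minor}"


def build_supports(capability: Tuple[int, int], arch_list: Iterable[str]) -> bool:
    """Return True when the build's architecture list contains the device's sm_ tag."""
    return arch_tag(capability) in set(arch_list)


def installed_arch_list() -> Optional[List[str]]:
    """Return the CUDA architectures the installed PyTorch was compiled for, if any."""
    try:
        import torch  # imported lazily so the helpers above work without PyTorch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_arch_list()


if __name__ == "__main__":
    archs = installed_arch_list()
    if archs is None:
        print("PyTorch with CUDA is not available on this machine")
    else:
        print("sm_120 in build:", build_supports((12, 0), archs))
```

If the check reports that `sm_120` is missing, the installed wheel most likely predates Blackwell support and the nightly cu128 install above is worth repeating in a clean environment.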

- Use bfloat16 dtype for UNet on Blackwell GPUs (compute major >= 12)
  which have native bf16 tensor core support
- Skip manual_cast for bfloat16 weights to avoid unnecessary casting
- Fix numpy TypeError with bfloat16 tensors in patch.py and
  ip_adapter.py by converting to float32 before .numpy() calls

Tested on RTX 5070 (sm_120, CUDA 12.8) with PyTorch nightly (cu128).
Generates images at ~3.2 it/s including Image Prompt (IP-Adapter) mode.

Fixes lllyasviel#3862, lllyasviel#4123, lllyasviel#4141