Add FLUX.2-klein-base-9B contrib model#146

Open
jimburtoft wants to merge 2 commits into aws-neuron:main from jimburtoft:contrib/flux2-klein

@jimburtoft
Contributor

Note: The below template includes items meant for model contributions only. For other contributions such as bug fixes, features, etc., only fill out the relevant portions of the form.

Description

Add FLUX.2-klein-base-9B (9.08B parameter diffusion transformer) as a contrib model with NxD Inference tensor parallelism on trn2.3xlarge.

FLUX.2-klein differs from FLUX.1 in several key ways: SwiGLU activation (vs GELU), pre-computed modulation, fused QKV+MLP projections in single-stream blocks, Qwen3-8B text encoder, and 32 latent channels with 4D RoPE. The implementation splits all fused SwiGLU projections into separate ColumnParallelLinear layers for correct TP sharding, and decomposes the massive fused to_qkv_mlp_proj into independent Q/K/V and MLP projections to stay within compiler instruction limits.

Model Information

Model Name: FLUX.2-klein-base-9B

Model Architecture: Diffusion transformer (DiT) with 8 double-stream MMDiT blocks + 24 single-stream DiT blocks, 4D RoPE, SwiGLU, pre-computed modulation

Purpose: Text-to-image generation

Checklist

Required Components

  • Accuracy Test (test/integration/test_model.py)

    • Integration test validates backbone cosine similarity vs CPU reference (0.9987 achieved)
    • Test compiles and runs the model on Neuron
  • README.md with the following sections:

    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types
    • Example Checkpoints: Links to compatible model checkpoints
    • Testing Instructions: Command to run the test suite
  • Source Code (src/)

    • modeling_flux2_klein.py: Full NxDI model implementation (~1350 lines)
    • application.py: Config factory, NeuronTransformerWrapper, NeuronFlux2KleinApplication
    • generate_flux2_klein.py: CLI entry point

Optional Components

  • Unit Tests (test/unit/ directory present, tests pending)

Folder Structure

/contrib/models/flux2-klein/
  README.md
  /samples
    hello_world_cat.png
  /src
    __init__.py
    modeling_flux2_klein.py
    application.py
    generate_flux2_klein.py
  /test
    __init__.py
    /unit
      __init__.py
    /integration
      __init__.py
      test_model.py

Testing

How did you test this change?

Tested on trn2.3xlarge (LNC=2, TP=4) with Neuron SDK 2.29 (DLAMI 20260410). I compiled the backbone, compared its output directly against the HF CPU reference on identical inputs, then ran the full end-to-end pipeline with 5 warm generations for benchmarking.

Test Results:

Backbone accuracy:

  • Cosine similarity (Neuron vs CPU): 0.9987
  • Max absolute difference: 0.15
  • Mean absolute difference: 0.022
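
For readers who want to reproduce the three accuracy metrics above, here is a minimal sketch of how they can be computed from two flattened output tensors. It uses plain Python lists so it runs anywhere. The function names are illustrative and are not taken from test_model.py.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length flat vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def abs_diff_stats(a, b):
    """Return (max absolute difference, mean absolute difference)."""
    diffs = [abs(x - y) for x, y in zip(a, b)]
    return max(diffs), sum(diffs) / len(diffs)


# Hypothetical stand-ins for the flattened Neuron and CPU backbone outputs.
neuron_out = [1.0, 2.0, 3.0]
cpu_out = [1.0, 2.5, 3.0]
similarity = cosine_similarity(neuron_out, cpu_out)
max_diff, mean_diff = abs_diff_stats(neuron_out, cpu_out)
```

Cosine similarity is insensitive to overall scale, which is why the absolute differences are reported alongside it.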

Benchmark (1024x1024, 30 steps, guidance_scale=4.0, classic CFG):

  • E2E generation time: 31.09s +/- 0.08s
  • Pipeline steps/sec: 0.96
  • Per-step latency: 1036ms
  • Backbone forward/sec: 1.93
  • Compilation time: ~135s (2.3 min)
  • Model load time: ~20s
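
The throughput figures above are consistent with simple division of the E2E time by the step count. The sketch below reproduces them, assuming that is how they were derived and that each step costs two backbone forwards under classic CFG.

```python
e2e_seconds = 31.09      # mean E2E generation time reported above
num_steps = 30           # denoising steps
forwards_per_step = 2    # classic CFG: positive + negative prompt

per_step_ms = e2e_seconds / num_steps * 1000
steps_per_sec = num_steps / e2e_seconds
forwards_per_sec = steps_per_sec * forwards_per_step

print(round(per_step_ms), round(steps_per_sec, 2), round(forwards_per_sec, 2))
# prints: 1036 0.96 1.93
```

Because the E2E time also covers work outside the denoising loop, the per-step latency computed this way is an upper bound on pure backbone time per step.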

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.29
  • Instance Type(s): trn2.3xlarge (LNC=2, TP=4)
  • PyTorch Version: 2.9
  • Python Version: 3.12

Additional Information

  • FLUX.2-klein is a gated model on HuggingFace requiring access approval before use
  • Text encoder (Qwen3-8B) runs on CPU (~5s per prompt) since it only executes once per image
  • The key technical challenge was SwiGLU + TP sharding: HuggingFace fuses gate and value into a single linear, but ColumnParallelLinear partitions contiguously, breaking the gate/value split boundary. Solved by splitting into separate parallel projections.
  • Classic CFG requires 2 forward passes per denoising step (positive + negative prompt)
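
To illustrate the SwiGLU sharding issue described above, here is a toy sketch in plain Python (lists instead of tensors, tiny sizes). It assumes the fused output is laid out as [gate | value] and that a column-parallel layer shards output rows contiguously. All names and sizes are hypothetical and are not taken from modeling_flux2_klein.py.

```python
import math


def matvec(weight, x):
    """Multiply a row-major weight matrix (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def chunk_rows(weight, tp_degree):
    """Contiguous split of output rows across tensor-parallel ranks."""
    step = len(weight) // tp_degree
    return [weight[r * step:(r + 1) * step] for r in range(tp_degree)]


def silu(v):
    return v / (1.0 + math.exp(-v))


hidden, tp_degree = 4, 2
x = [1.0, 2.0]
gate_w = [[float(i), 0.0] for i in range(1, hidden + 1)]
value_w = [[0.0, float(i)] for i in range(1, hidden + 1)]
fused_w = gate_w + value_w  # fused layout: gate rows first, then value rows

# Sharding the fused matrix: rank 0 holds only gate rows, rank 1 only value
# rows, so no rank can compute silu(gate) * value locally.
fused_shards = chunk_rows(fused_w, tp_degree)
rank0_is_all_gate = fused_shards[0] == gate_w
rank1_is_all_value = fused_shards[1] == value_w

# Splitting first, then sharding each projection: every rank holds matching
# gate and value rows and can apply SwiGLU without communication.
gate_shards = chunk_rows(gate_w, tp_degree)
value_shards = chunk_rows(value_w, tp_degree)
sharded_out = []
for rank in range(tp_degree):
    g = matvec(gate_shards[rank], x)
    v = matvec(value_shards[rank], x)
    sharded_out += [silu(gi) * vi for gi, vi in zip(g, v)]

reference = [silu(g) * v for g, v in zip(matvec(gate_w, x), matvec(value_w, x))]
```

With separate projections, concatenating the per-rank results reproduces the unsharded reference exactly.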

Add 2048x2048 benchmark results (191.44s, TP=4, 50 steps) validated by
xniwang on trn2.3xlarge. Document that 4096x4096 is NOT SUPPORTED due to
a fundamental model limitation (max_area=4MP), not a hardware constraint.
The model produces noise/gray at 4K on ALL devices including H100.

Maximum supported resolution: 2048x2048.
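
A guard for the resolution limit described above could look like the following sketch. It assumes "4MP" means 4 * 1024 * 1024 pixels, which is consistent with 2048x2048 being the largest supported size. The constant and function names are hypothetical.

```python
# Assumption: "max_area=4MP" is read as 4 * 1024**2 pixels (= 2048 * 2048).
MAX_AREA_PIXELS = 4 * 1024 * 1024


def is_supported_resolution(height, width, max_area=MAX_AREA_PIXELS):
    """Return True when the requested image area fits within the model limit."""
    return height * width <= max_area
```

Under that assumption, is_supported_resolution(2048, 2048) is True and is_supported_resolution(4096, 4096) is False.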
for _ in range(2):
    app(
        prompt=args.prompt,
        negative_prompt=args.negative_prompt,

os.environ["LOCAL_WORLD_SIZE"] = str(self.config.neuron_config.world_size)
if _HARDWARE == hardware.TRN2:
    os.environ["NEURON_RT_VIRTUAL_CORE_SIZE"] = "2"

I feel like it's a bad idea to set a runtime variable at compile time, in a function like get_compiler_args. Above, you decide which kernel to use based on if it's sharded or not, but it makes that decision based on this environment variable that's presumably set at compile time.

Philosophically this shouldn't happen because you could theoretically load multiple (pre-compiled) NEFFs in a single runtime.
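
One possible way to limit the side effect, sketched here with the standard library only, is to scope the variable to the block that needs it and restore the previous value afterwards. This is an illustration of the concern, not what the PR does, and scoped_env is a hypothetical helper.

```python
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(name, value):
    """Temporarily set an environment variable, restoring the old state on exit."""
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(name, None)  # variable was unset before
        else:
            os.environ[name] = previous
```

A caller would wrap only the step that needs the setting, for example `with scoped_env("NEURON_RT_VIRTUAL_CORE_SIZE", "2"): ...`, so the value does not leak into later code in the same process. Passing the core size explicitly to the kernel-selection logic would avoid the environment lookup altogether.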

# For FLUX.2-klein: latent is (B, H*W, 128) after pack
# At 1024x1024: H=W=64 (1024/16), so img_seq = 4096
vae_scale_factor = getattr(self.config, "vae_scale_factor", 16)
num_patches = self.config.height * self.config.width // (vae_scale_factor**2)

This assumes height * width is evenly divisible by vae_scale_factor**2; otherwise the floor division silently truncates the patch count. I can see why it is written this way, but an assertion or a warning when this occurs might be helpful.
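
A minimal sketch of such a check follows, assuming each dimension should be a multiple of vae_scale_factor. The function name count_patches is hypothetical and the default of 16 mirrors the snippet above.

```python
def count_patches(height, width, vae_scale_factor=16):
    """Number of latent patches, rejecting sizes that would be truncated."""
    if height % vae_scale_factor or width % vae_scale_factor:
        raise ValueError(
            f"height={height} and width={width} must both be multiples of "
            f"vae_scale_factor={vae_scale_factor}"
        )
    return (height // vae_scale_factor) * (width // vae_scale_factor)
```

Checking each dimension separately is stricter than checking the product: it also rejects sizes such as 8x32, whose area is divisible by 256 even though the height is not a multiple of 16.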

@lutfanm-aws

Approved because I managed to generate an image using the same prompt. Comments left are non-blocking.
