Add topology-derived sparse attention kernel #22
Requesting review: this PR moves the earlier triton-lang/triton#10768 work here following the repo-fit feedback and now includes focused correctness/validation tests plus a checked-in benchmark runner. @ThomasRaoux, if there is a better reviewer for kernel examples, could you route this to the right maintainer?
ThomasRaoux
left a comment
looks good to me, just one comment
Thanks, fixed. I moved the optional pandas dependency handling into benchmark_utils.compare_benchmarks, where pandas is actually used, and restored benchmarking/__init__.py to the normal exports. I also added a regression test for importing benchmark_utils without pandas installed. Are there any other blockers? @ThomasRaoux, how can I move this forward?
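A minimal sketch of the deferred-import pattern described in this comment, for reference. The function body and the injectable `import_module` parameter are hypothetical and exist only to keep the example testable; the actual signature of `benchmark_utils.compare_benchmarks` is not shown in this thread.

```python
import importlib


def compare_benchmarks(rows, import_module=importlib.import_module):
    """Hypothetical stand-in: tabulate benchmark rows with pandas.

    pandas is imported here, at call time, so that importing the
    surrounding module never fails when pandas is not installed.
    """
    try:
        pd = import_module("pandas")
    except ImportError as exc:
        # Fail only when the pandas-backed helper is actually called.
        raise ImportError(
            "compare_benchmarks needs pandas; install it to use this helper"
        ) from exc
    return pd.DataFrame(rows)
```

With this shape, a regression test can import the module in an environment without pandas and only expect an `ImportError` when `compare_benchmarks` itself is called.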
Should we just make pandas an explicit dependency? I thought that was already the case.
I see two reasonable designs:
I am currently working with the second design. The first design is also valid, though, on the assumption that anyone running Triton already has pandas installed.
Summary
[seq, dim] and [batch, heads, seq, dim] inputs with the same CSR schedule shared across batch/head lanes.
Closes #21.
This was originally attempted in triton-lang/triton#10768. Maintainer feedback there said the Triton core repo was probably not the right place and that a repo of Triton kernels would be a better fit, so this PR moves the contribution here instead of reopening the core PR.
Local validation
wsl.exe -d Ubuntu -- bash -lc "cd /mnt/c/Users/seal/Documents/GitHub/kernels && /home/seal/.cache/codex-triton-topology-venv/bin/python -m py_compile benchmarking/__init__.py benchmarking/topology_sparse_attention.py kernels/topology_sparse_attention.py test/test_topology_sparse_attention.py && /home/seal/.cache/codex-triton-topology-venv/bin/python -m pytest -s --tb=short test/test_topology_sparse_attention.py"
Result: 17 passed.
git diff --check -- .
Result: no whitespace errors. Git emitted local CRLF warnings only.
Benchmark smoke command:
wsl.exe -d Ubuntu -- bash -lc "cd /mnt/c/Users/seal/Documents/GitHub/kernels && /home/seal/.cache/codex-triton-topology-venv/bin/python -m benchmarking.topology_sparse_attention --seq 1024 --rounds 3"
Result: command completed and printed the markdown benchmark table.
Local benchmark
To reproduce the full table:
Measured on NVIDIA GeForce RTX 4060 Laptop GPU, PyTorch 2.12.1+cu130, Triton 3.7.1, dim=64, block_size=64, 50 timing rounds.
The main Triton-side comparison is Triton sparse vs dense CSR: both sides use this PR's scheduled-attention Triton kernel, but the dense CSR path visits every causal block. The SDPA number is a full-causal optimized baseline, not the same sparse mask; the 1024-token case is slower than SDPA, so this should not be read as a universal SDPA speedup claim.
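To make the sparse vs dense CSR comparison concrete, here is a small host-side sketch of how the two block schedules differ in the number of blocks visited. It is an illustration only: `causal_csr` and the sparse pattern (first block plus a 2-block local window) are assumptions made for this example, not the topology-derived schedule or the helper names used in this PR.

```python
def causal_csr(num_blocks, keep):
    """Build a CSR block schedule (row_ptr, col_idx) over causal key blocks.

    keep(q_block, k_block) decides whether a causal block is visited.
    """
    row_ptr, col_idx = [0], []
    for q in range(num_blocks):
        for k in range(q + 1):  # causal: key blocks at or before the query block
            if keep(q, k):
                col_idx.append(k)
        row_ptr.append(len(col_idx))
    return row_ptr, col_idx


num_blocks = 1024 // 64  # seq=1024, block_size=64 gives 16 blocks

# Dense CSR baseline: every causal block is visited.
dense_ptr, dense_cols = causal_csr(num_blocks, lambda q, k: True)

# Illustrative sparse pattern (assumption): first block plus a 2-block window.
sparse_ptr, sparse_cols = causal_csr(
    num_blocks, lambda q, k: k == 0 or q - k < 2
)

print(len(dense_cols), len(sparse_cols))  # 136 45
```

Because both schedules are fed to the same kernel, the ratio of visited blocks (136 vs 45 here) is the work reduction being measured, independent of how the kernel compares with SDPA.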