
[KVCache] DSA for v1 cache manager #7787

Open

Moonchild1227 wants to merge 23 commits into PaddlePaddle:develop from Moonchild1227:feat/dsa-for-v1

Conversation

Moonchild1227 (Contributor) commented May 12, 2026

Motivation

Push the per-layer KV cache allocation logic down from CacheController into the AttentionBackend, making CacheController variant-agnostic. Add support for the DSA (DeepSeek V3.2-Exp-BF16) cache layout (uint8 key + uint8 indexer), and provide an extensible base for future attention variants (no CacheController changes required).

Modifications

  • base_attention_backend.py: add a default create_kv_cache() implementation (GQA/MHA key + value, with block_wise_fp8 scale support); add default create_host_kv_cache() / free_host_kv_cache() implementations (a sketch of the override pattern follows this list)
  • dsa_attention_backend.py: override create_kv_cache() to return {"key": uint8, "indexer": uint8}; override create_host_kv_cache() to raise NotImplementedError (host cache push-down is not supported yet)
  • mla_attention_backend.py: override create_kv_cache() to return {"key": tensor}; override create_host_kv_cache() to allocate only the key buffer
  • cache_controller.py: rewrite initialize_kv_cache / initialize_mtp_kv_cache to allocate uniformly via attn_backend.create_kv_cache(); add _format_cache_name(); rewrite initialize_host_cache / _free_host_cache to delegate to the backend; remove MLACacheController / DSACacheController / create_cache_controller()
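
As a rough illustration of the push-down, here is a minimal sketch of the default-plus-override pattern described above. The signature is abridged and the cache shapes are placeholders, not the actual FastDeploy values:

```python
# Minimal sketch only: abridged signature, placeholder shapes. The real code
# lives in base_attention_backend.py / dsa_attention_backend.py.
from typing import Any, Dict, Tuple

import paddle

CacheDict = Dict[Tuple[str, int], Any]  # keyed by (role, layer_idx)


class AttentionBackend:
    def create_kv_cache(
        self,
        num_layers: int,
        num_blocks: int,
        cache_dtype: str,
        kv_cache_quant_type: str = "",
        layer_offset: int = 0,
    ) -> CacheDict:
        """Default GQA/MHA layout: key + value (+ scales for block_wise_fp8)."""
        caches: CacheDict = {}
        kv_shape = [num_blocks, 64, 8, 128]  # placeholder block shape
        for i in range(num_layers):
            layer_idx = layer_offset + i
            caches[("key", layer_idx)] = paddle.zeros(kv_shape, dtype=cache_dtype)
            caches[("value", layer_idx)] = paddle.zeros(kv_shape, dtype=cache_dtype)
            if kv_cache_quant_type == "block_wise_fp8":
                caches[("key_scale", layer_idx)] = paddle.zeros([1], dtype="float32")
                caches[("value_scale", layer_idx)] = paddle.zeros([1], dtype="float32")
        return caches


class DSAAttentionBackend(AttentionBackend):
    def create_kv_cache(self, num_layers, num_blocks, cache_dtype, kv_cache_quant_type="", layer_offset=0):
        """DSA layout: uint8 key + uint8 indexer, no value cache."""
        caches = {}
        for i in range(num_layers):
            layer_idx = layer_offset + i
            caches[("key", layer_idx)] = paddle.zeros([num_blocks, 64, 656], dtype="uint8")  # placeholder shape
            caches[("indexer", layer_idx)] = paddle.zeros([num_blocks, 64, 132], dtype="uint8")  # placeholder shape
        return caches
```

Under this split, CacheController only iterates the returned dict, maps each (role, layer_idx) to a storage name, and registers the tensor, so a new attention variant means a new backend override rather than a new controller subclass.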

Usage or Command

N/A

Accuracy Tests

End-to-end /v1/chat/completions requests against DSA (DeepSeek V3.2-Exp-BF16) verified successfully.

```
# v0
python3 gsm8k.py
🎯 Evaluation Complete: Accuracy = 95.22% (657/690)
time: 28:44
# v1
python3 gsm8k.py
🎯 Evaluation Complete: Accuracy = 94.35% (651/690)
time: 23:38

# v1: 6 answers were actually correct but were emitted as markdown-formatted code, so the script failed to parse them.
```

MLA / GQA model verification to be added.
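
Since six v1 answers were scored wrong only because they arrived wrapped in markdown code fences, a fence-tolerant extraction step would close that gap. The gsm8k.py internals are not shown in this PR, so the function below is purely a hypothetical sketch of the idea:

```python
import re
from typing import Optional


def extract_answer(text: str) -> Optional[str]:
    """Hypothetical fence-tolerant answer extraction; not the actual gsm8k.py logic."""
    # Strip markdown code-fence markers but keep their contents.
    text = re.sub(r"`{3}[a-zA-Z]*", "", text)
    # Take the last number in the response as the answer candidate.
    matches = re.findall(r"-?\d[\d,]*(?:\.\d+)?", text)
    return matches[-1].replace(",", "") if matches else None
```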

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings May 12, 2026 06:31
paddle-bot Bot commented May 12, 2026

Thanks for your contribution!

paddle-bot added the contributor External developers label May 12, 2026
Copilot AI (Contributor) left a comment

Pull request overview

This PR pushes per-layer KV cache allocation down into the AttentionBackend (via a new create_kv_cache interface), so the cache_manager/v1 CacheController is only responsible for role→storage-name mapping, registration, and the optional set_data_ipc pin, reducing the controller's coupling to the different attention variants (GQA/MLA/DSA).

Changes:

  • AttentionBackend gains pin_kv_cache_for_cudagraph and a default create_kv_cache(...) (GQA/MHA: key/value, plus scales for fp8).
  • The MLA/DSA backends override create_kv_cache: MLA allocates only key; DSA returns key + indexer (uint8).
  • CacheController.initialize_kv_cache / initialize_mtp_kv_cache now call attn_backend.create_kv_cache per layer, with a new "indexer" role in the storage-name mapping and the cudagraph pin logic.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

| File | Description |
|---|---|
| fastdeploy/model_executor/layers/attention/base_attention_backend.py | Adds a generic KV cache allocation entry point and a cudagraph pin flag to the attention backend. |
| fastdeploy/model_executor/layers/attention/mla_attention_backend.py | MLA backend overrides KV cache allocation: allocates only the compressed latent key cache and requires pinning. |
| fastdeploy/model_executor/layers/attention/dsa_attention_backend.py | DSA backend overrides KV cache allocation: allocates uint8 key + uint8 indexer and requires pinning. |
| fastdeploy/cache_manager/v1/cache_controller.py | Controller refactored to role registration/name mapping + optional pin; the main model and MTP share one allocation path. |
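
For reference, the role→storage-name mapping the controller keeps could look like the sketch below. The name format strings are taken from the PR's diff; the method placement and the local_rank / self._device_id sources are assumptions for illustration:

```python
# Sketch of the role -> storage-name mapping. Name formats follow the PR's
# diff; the surrounding attributes are assumptions for illustration.
def _format_cache_name(self, role: str, layer_idx: int) -> str:
    local_rank = self.local_rank  # assumed attribute
    suffix = f"{layer_idx}_rank{local_rank}.device{self._device_id}"
    names = {
        "key": f"key_caches_{suffix}",
        "value": f"value_caches_{suffix}",
        "key_scale": f"key_cache_scales_{suffix}",
        "value_scale": f"value_cache_scales_{suffix}",
        "indexer": f"indexer_caches_{suffix}",
    }
    return names[role]
```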

PaddlePaddle-bot commented May 12, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-15 17:17:55

CI report generated from the code below (refreshed every 30 minutes):


1 Task overview

❌ 1 required task failed; fix with priority.

| Total runs (reruns) | Total tasks | ✅ Passed | ❌ Failed | ⏳ Running | ⏸️ Waiting | Skipped |
|---|---|---|---|---|---|---|
| 17 (0) | 17 | 12 | 1 | 4 | 0 | 0 |

2 Task status summary

2.1 Required tasks: 1/2 passed

Required tasks block merging; failures must be fixed first.

| Status | Task | Duration | Root cause | Suggested fix | Log | Rerun |
|---|---|---|---|---|---|---|
| ❌ | Approval | 8s | PR issue: new logger.info calls triggered the logging approval rule; RD approval missing | Ask xyxinyang or zyyzghb to review and approve | Job | - |
| | The remaining 1 required task passed | - | - | - | - | - |

2.2 Optional tasks — 11/15 passed

Optional tasks do not block merging; failures are informational only.

| Status | Task | Duration | Log | Rerun |
|---|---|---|---|---|
| | Trigger Jenkins for PR | - | Job | - |
| | xpu_build_test / xpu-build-test | - | Job | - |
| | Run iluvatar Tests / run_iluvatar_cases | - | Job | - |
| | FD-Build-Linux / fd-build | - | Job | - |
| | The remaining 11 optional tasks passed | - | - | - |

3 Failure details (required only)

Approval — code approval (confidence: high)

Approval

  • Status: ❌ failed
  • Error type: code compliance
  • Confidence: high
  • Root cause summary: the PR adds several logger.info calls, which triggered the logging-change approval rule; the designated RD approval is missing
  • Analyzer: generic analysis (fallback)

Root cause details:
The diff adds several logger.info( calls, which triggered FastDeploy's logging-change approval rule (check_approval.sh). The script reported "There are 1 approved errors", i.e. one unsatisfied approval condition: modifying logging behavior (.info/.debug/.error/log_request) requires approval from at least one of the FastDeploy RDs xyxinyang (zhouchong) or zyyzghb (zhangyongyue).

Key log:

```
Detected log modification in diff:
+        logger.info(
+        logger.info(
+        logger.info(
+        logger.info(f"[free_host_kv_cache]...")
0. You must have one FastDeploy RD (xyxinyang(zhouchong), zyyzghb(zhangyongyue)) approval for modifying logging behavior.
There are 1 approved errors.
##[error]Process completed with exit code 6.
```

Suggested fix:

  1. Have @xyxinyang (zhouchong) or @zyyzghb (zhangyongyue) click "Approve" on the PR

Fix summary: ask xyxinyang or zyyzghb to review and approve this PR

Related change: the PR adds several new logger.info( logging calls

Link: view log

codecov-commenter commented May 12, 2026

Codecov Report

❌ Patch coverage is 99.20000% with 1 line in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@e3541c2). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...xecutor/layers/attention/base_attention_backend.py | 98.27% | 0 Missing and 1 partial ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             develop    #7787   +/-   ##
==========================================
  Coverage           ?   63.48%
==========================================
  Files              ?      462
  Lines              ?    64310
  Branches           ?     9854
==========================================
  Hits               ?    40827
  Misses             ?    20708
  Partials           ?     2775
```
| Flag | Coverage Δ |
|---|---|
| GPU | 72.62% <99.20%> (?) |
| XPU | 7.11% <0.80%> (?) |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.



Copilot AI review requested due to automatic review settings May 12, 2026 08:11

Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.


Copilot AI review requested due to automatic review settings May 13, 2026 02:44
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (2)

fastdeploy/model_executor/layers/attention/base_attention_backend.py:137

  • The create_host_kv_cache docstring says "returns an empty dict when host alloc is unavailable", but the implementation raises RuntimeError when cuda_host_alloc is None, which makes it hard for callers (such as CacheController) to implement the documented fallback. Either return {} as documented and let the caller skip swap-space initialization, or fix the docstring and have the caller catch the exception explicitly.

```python
        Returns:
            Dict keyed by ``(role, layer_idx)``. Empty dict if host alloc is
            unavailable on the current platform.
        """
        if cuda_host_alloc is None:
```

fastdeploy/cache_manager/v1/cache_controller.py:544

  • initialize_host_cache only catches NotImplementedError, but the default AttentionBackend.create_host_kv_cache() raises RuntimeError when cuda_host_alloc is None (and some backends may raise RuntimeError as well), so enabling swap space fails initialization outright. Consider also catching RuntimeError here (and TypeError/AttributeError if necessary) and skipping host cache initialization with a warning, so platforms without pinned host alloc can degrade gracefully.

```python
        try:
            host_caches = attn_backend.create_host_kv_cache(
                num_layers=num_layers,
                num_blocks=num_host_blocks,
                cache_item_bytes=cache_item_bytes,
```

Copilot AI review requested due to automatic review settings May 14, 2026 11:58
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Comments suppressed due to low confidence (1)

fastdeploy/cache_manager/v1/cache_controller.py:371

  • initialize_mtp_kv_cache also builds cache_kvs_list directly from caches.items(), so it has the same ordering nondeterminism as initialize_kv_cache, while the MTP attention backend likewise indexes caches by layer_id. Either sort explicitly / build the list by role priority here as well, or make insertion order an explicit guarantee in the create_kv_cache contract.

```python
        cache_kvs_list: List[Any] = []
        for (role, layer_idx), tensor in caches.items():
            name = self._format_cache_name(role, layer_idx)
            self.cache_kvs_map[name] = tensor
            cache_kvs_list.append(tensor)
```


Copilot AI review requested due to automatic review settings May 15, 2026 03:36
Copilot AI (Contributor) left a comment

Copilot was unable to review this pull request because the user who requested the review is ineligible. To be eligible to request a review, you need a paid Copilot license, or your organization must enable Copilot code review.


Copilot AI review requested due to automatic review settings May 15, 2026 05:20
Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (2)

fastdeploy/cache_manager/v1/cache_controller.py:309

  • Building cache_kvs_list by iterating caches.items() implicitly treats the dict's insertion order as the semantic order of the cache list. But several attention backends access forward_meta.caches by fixed index at forward time (e.g. caches[2*layer_id] / caches[4*layer_id+2]), so the list order must be deterministic and consistent with the role layout. Consider sorting/concatenating explicitly by (layer_idx, role_priority) on the CacheController side (or having the backend return an ordered list plus a map), and stating the ordering contract in the docstring, so a backend whose dict uses a different insertion order cannot silently read the wrong cache at runtime.

```python
        cache_kvs_list: List[Any] = []
        for (role, layer_idx), tensor in caches.items():
            name = self._format_cache_name(role, layer_idx)
            self.cache_kvs_map[name] = tensor
            cache_kvs_list.append(tensor)
```

fastdeploy/cache_manager/v1/cache_controller.py:371

  • initialize_mtp_kv_cache likewise relies on the insertion order of caches.items() to build cache_kvs_list. Since the forward side reads forward_meta.caches at fixed indices (arranged by layer/role), this should match the main-model path: build the list in an explicit, stable layer/role order (or have the backend emit a stable order), so the MTP cache layout cannot drift from the main-model layout and misalign indices.

```python
        cache_kvs_list: List[Any] = []
        for (role, layer_idx), tensor in caches.items():
            name = self._format_cache_name(role, layer_idx)
            self.cache_kvs_map[name] = tensor
            cache_kvs_list.append(tensor)
```
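
A minimal sketch of the explicit ordering both comments ask for; the role priorities below are illustrative, since the actual layout contract would have to come from the backends:

```python
# Illustrative only: fix the list order to (layer_idx, role priority) instead
# of relying on dict insertion order. The priority values are assumptions.
ROLE_PRIORITY = {"key": 0, "value": 1, "key_scale": 2, "value_scale": 3, "indexer": 4}

cache_kvs_list: List[Any] = []
for (role, layer_idx), tensor in sorted(
    caches.items(), key=lambda item: (item[0][1], ROLE_PRIORITY.get(item[0][0], 99))
):
    name = self._format_cache_name(role, layer_idx)
    self.cache_kvs_map[name] = tensor
    cache_kvs_list.append(tensor)
```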


Copilot AI review requested due to automatic review settings May 15, 2026 08:57
Copilot AI (Contributor) left a comment

Copilot was unable to review this pull request because the user who requested the review is ineligible. To be eligible to request a review, you need a paid Copilot license, or your organization must enable Copilot code review.

PaddlePaddle-bot left a comment

🤖 Paddle-CI-Agent | pr_review | 2026-05-15 17:10:53

📋 Review summary

PR overview: pushes per-layer KV cache allocation down from CacheController into the individual AttentionBackends and adds the DSA cache layout (uint8 key + uint8 indexer), making CacheController variant-agnostic.
Scope of change: cache_manager/v1/, model_executor/layers/attention/, tests/cache_manager/, tests/layers/
Impact tag: [KVCache]

📝 PR convention check

The title carries a valid [KVCache] tag and is correctly formatted; all required sections of the description are filled in, and the checklist matches the actual changes. The PR is compliant; no changes requested.

Issues

| Severity | File | Summary |
|---|---|---|
| 🟡 Suggestion | mla_attention_backend.py:655 | create_host_kv_cache imports cuda_host_alloc directly, making the None check dead code; non-CUDA platforms raise ImportError instead of an exception the controller catches |
| ❓ Question | cache_controller.py | initialize_host_cache does not store attn_backend on self.attn_backend; calling the method directly sends _free_host_cache down the warning path and leaks pinned memory |

Overall assessment

The refactoring is architecturally clean, the responsibility split is sensible, and unit test coverage is thorough. There is one non-CUDA compatibility issue in the MLA host cache (mismatched import exception type) and one potential attn_backend lifecycle issue; both should be fixed before merging. MLA/GQA end-to-end accuracy results are still pending and should be added before final merge.

```python
    layer_offset: int = 0,
):
    """
    MLA host cache: only the compressed latent key buffer, no value, no scales.
```

🟡 Suggestion: importing cuda_host_alloc directly makes the None check dead code

This code uses from ... import cuda_host_alloc directly, so on an unsupported platform (e.g. a non-CUDA environment) the import raises ImportError; since cache_controller.initialize_host_cache only catches NotImplementedError, the exception propagates, and host cache initialization fails without a friendly message. Meanwhile, the if cuda_host_alloc is None check can never be true (a successful direct import never yields None), so it is dead code.

Following the approach in base_attention_backend.py, reference the module-level variable or guard the import with try/except:

```python
# Option 1: reuse the module-level variable already imported by the base class (recommended)
from fastdeploy.model_executor.layers.attention.base_attention_backend import cuda_host_alloc as _alloc
if _alloc is None:
    raise NotImplementedError("...")  # NotImplementedError matches DSA and is caught by the controller

# Option 2: the same try/except pattern as base_attention_backend.py
try:
    from fastdeploy.cache_manager.ops import cuda_host_alloc as _alloc
except Exception:
    _alloc = None
if _alloc is None:
    raise NotImplementedError("...")
```


Labels

contributor External developers


9 participants