Skip to content

Commit d7cc5d7

Browse files
ssjia (SS-JIA)
authored and committed
[ET-VK] Fix embedding_q4gsw out-of-bounds access with dynamic shapes
The embedding_q4gsw shader used push constants for num_indices, out_height, and embed_dim that were captured at graph build time and never updated when input tensors were dynamically resized. This caused out-of-bounds GPU memory reads when the actual input was smaller than the initial allocation, resulting in VK_ERROR_DEVICE_LOST on Mali GPUs. The fix derives all shape-dependent values (embed_dim, out_height, num_indices) from the output tensor's sizes UBO, which is automatically updated on resize. Only truly constant values (group_size, is_linear_weight) remain as push constants. Root cause: With a 7-token input on a graph built for 256 tokens, the local workgroup rounding created an extra thread (y=7) that passed the stale bounds check (7 >= 256 == false) and read past the 7-element indices buffer. Differential Revision: [D98642319](https://our.internmc.facebook.com/intern/diff/D98642319/) ghstack-source-id: 359350851 Pull Request resolved: #18558
1 parent def3699 commit d7cc5d7

2 files changed

Lines changed: 22 additions & 26 deletions

File tree

backends/vulkan/runtime/graph/ops/glsl/embedding_q4gsw.glsl

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -45,11 +45,11 @@ $else:
4545
// Scales are ALWAYS buffer, loaded as scalar
4646
${layout_declare_tensor(B, "r", "t_scales", SCALES_DTYPE, "buffer")}
4747

48+
// Output sizes in WHCN order
49+
${layout_declare_ubo(B, "ivec4", "out_sizes")}
50+
4851
layout(push_constant) uniform PushConstants {
4952
int group_size;
50-
int embed_dim;
51-
int num_indices;
52-
int out_height;
5353
int is_linear_weight;
5454
};
5555

@@ -66,6 +66,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
6666
VEC4_T load_embedding_weights(
6767
const int embedding_idx,
6868
const int dim,
69+
const int embed_dim,
6970
const float scale) {
7071
const int n8 = embedding_idx >> 3;
7172
const int n_local = embedding_idx & 7;
@@ -96,6 +97,7 @@ VEC4_T load_embedding_weights(
9697
VEC4_T load_embedding_weights(
9798
const int embedding_idx,
9899
const int dim,
100+
const int embed_dim,
99101
const float scale) {
100102
const int blocks_per_row = embed_dim >> 5;
101103
const int block_in_row = dim >> 5;
@@ -124,7 +126,12 @@ void main() {
124126
const int y_idx = int(gl_GlobalInvocationID.y);
125127
const int z_idx = int(gl_GlobalInvocationID.z);
126128

129+
// out_sizes is in WHCN order: x=W(embed_dim), y=H, z=C, w=N
130+
const int embed_dim = out_sizes.x;
127131
const int blocks_per_row = embed_dim >> 5;
132+
const int out_height = out_sizes.y;
133+
const int num_indices = out_sizes.y * out_sizes.z * out_sizes.w;
134+
128135
const int indices_idx = z_idx * out_height + y_idx;
129136
if (block_in_row >= blocks_per_row || indices_idx >= num_indices) {
130137
return;
@@ -147,7 +154,7 @@ void main() {
147154
float(t_scales[embedding_idx * groups_per_row + dim / group_size]);
148155

149156
const VEC4_T vals =
150-
load_embedding_weights(embedding_idx, dim, scale);
157+
load_embedding_weights(embedding_idx, dim, embed_dim, scale);
151158

152159
#ifdef OUTPUT_BUFFER
153160
const int out_base = indices_idx * embed_dim + dim;

backends/vulkan/runtime/graph/ops/impl/EmbeddingQ4gsw.cpp

Lines changed: 11 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -64,14 +64,14 @@ void add_embedding_q4gsw_node(
6464
const ValueRef weight,
6565
const ValueRef weight_scales,
6666
const int32_t group_size,
67-
const int32_t embed_dim,
68-
const int32_t num_indices,
69-
const int32_t out_height,
7067
const int32_t is_linear_weight,
71-
const ValueRef out) {
68+
const ValueRef out,
69+
const ValueRef embed_dim_ref) {
7270
VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim);
7371
VK_CHECK_COND(graph.packed_dim_of(indices) == WHCN::kWidthDim);
74-
VK_CHECK_COND(embed_dim % 32 == 0, "embed_dim must be a multiple of 32");
72+
VK_CHECK_COND(
73+
graph.get_int(embed_dim_ref) % 32 == 0,
74+
"embed_dim must be a multiple of 32");
7575

7676
std::string kernel_name = "embedding_q4gsw";
7777
kernel_name.reserve(kShaderNameReserve);
@@ -91,21 +91,18 @@ void add_embedding_q4gsw_node(
9191

9292
std::vector<PushConstantDataInfo> push_constants = {
9393
PushConstantDataInfo(&group_size, sizeof(group_size)),
94-
PushConstantDataInfo(&embed_dim, sizeof(embed_dim)),
95-
PushConstantDataInfo(&num_indices, sizeof(num_indices)),
96-
PushConstantDataInfo(&out_height, sizeof(out_height)),
9794
PushConstantDataInfo(&is_linear_weight, sizeof(is_linear_weight)),
9895
};
9996

100-
ValueRef embed_dim_ref = graph.add_scalar<int64_t>(embed_dim);
97+
vkapi::ParamsBindList param_ubos = {graph.sizes_ubo(out)};
10198

10299
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
103100
graph,
104101
VK_KERNEL_FROM_STR(kernel_name),
105102
pick_embedding_q4gsw_global_wg_size,
106103
default_pick_local_wg_size,
107104
{{out, vkapi::kWrite}, {{indices, weight, weight_scales}, vkapi::kRead}},
108-
{},
105+
param_ubos,
109106
push_constants,
110107
{},
111108
{embed_dim_ref},
@@ -125,14 +122,8 @@ void embedding_q4gsw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
125122
graph.extract_scalar<bool>(is_linear_weight_ref) ? 1 : 0;
126123

127124
const std::vector<int64_t> weight_sizes = graph.sizes_of(weight_data);
128-
int32_t embed_dim = static_cast<int32_t>(weight_sizes.back() * 2);
129-
130-
const std::vector<int64_t> indices_sizes = graph.sizes_of(indices);
131-
int32_t num_indices = 1;
132-
for (auto s : indices_sizes) {
133-
num_indices *= static_cast<int32_t>(s);
134-
}
135-
int32_t out_height = static_cast<int32_t>(indices_sizes.back());
125+
int64_t embed_dim = weight_sizes.back() * 2;
126+
ValueRef embed_dim_ref = graph.add_scalar<int64_t>(embed_dim);
136127

137128
ValueRef weight;
138129
if (is_linear_weight) {
@@ -152,11 +143,9 @@ void embedding_q4gsw(ComputeGraph& graph, const std::vector<ValueRef>& args) {
152143
weight,
153144
weight_scales,
154145
group_size,
155-
embed_dim,
156-
num_indices,
157-
out_height,
158146
is_linear_weight,
159-
out);
147+
out,
148+
embed_dim_ref);
160149
}
161150

162151
REGISTER_OPERATORS {

0 commit comments

Comments (0)