
Commit 6fccd5a

rascani and claude authored
Cortex-M: Fuse relu activation into quantized_add (#18462)
### Summary

ResNet8 has skip connections of the form relu(add(conv(x), skip(x))). The ActivationFusionPass only fused relu into conv/linear, leaving 3 unfused relu ops that fell through to the portable aten::relu.out, which incorrectly clamps int8 tensors to the literal value 0 instead of the quantized zero_point, causing numerical mismatches on the FVP.

- Add fused activation patterns (relu, hardtanh, clamp) for add/add_ to BINARY_OP_PATTERNS in quantizer_support.py so the quantizer produces activation-aware quantization bounds.
- Add aten.add.Tensor to the ActivationFusionPass FUSE_OPS.
- Update QuantizedOpFusionPass to read the activation bounds from output_qparams and pass them to quantized_add.
- Update the quantized_add operator (schema, meta, impl, C++) to accept activation_min/activation_max parameters.

Co-authored-by: Claude <noreply@anthropic.com>
1 parent fd30125 commit 6fccd5a
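To see why clamping at a literal 0 is wrong, here is a minimal standalone sketch (the qparams are invented for illustration; real values come from the quantizer) of relu on an int8 quantized tensor:

```python
import torch

# Made-up example qparams, purely illustrative.
scale, zero_point = 0.05, -10

x_fp = torch.tensor([-0.4, -0.1, 0.0, 0.3])
x_q = torch.clamp(torch.round(x_fp / scale) + zero_point, -128, 127).to(torch.int8)

# Correct relu in the quantized domain: clamp at zero_point, which is the
# int8 encoding of the real value 0.0.
relu_ok = torch.clamp(x_q, min=zero_point)

# What portable aten::relu.out did: clamp at the literal 0, which here
# encodes the real value (0 - zero_point) * scale = 0.5, not 0.0.
relu_bad = torch.clamp(x_q, min=0)

print((relu_ok.to(torch.int32) - zero_point) * scale)   # tensor([0.0, 0.0, 0.0, 0.3])
print((relu_bad.to(torch.int32) - zero_point) * scale)  # tensor([0.5, 0.5, 0.5, 0.5])
```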

8 files changed

Lines changed: 103 additions & 12 deletions


backends/cortex_m/ops/op_quantized_add.cpp

Lines changed: 6 additions & 4 deletions

```diff
@@ -26,6 +26,8 @@ Tensor& quantized_add_out(
     const int64_t output_zero_point,
     const int64_t output_multiplier,
     const int64_t output_shift,
+    const int64_t activation_min,
+    const int64_t activation_max,
     Tensor& out) {
   // Validate tensor types and dim order
   bool channel_broadcast = is_channel_broadcast(input1_int8, input2_int8);
@@ -69,8 +71,8 @@ Tensor& quantized_add_out(
 
   // Left shift to maximize precision
   const int32_t left_shift = 20;
-  const int32_t activation_min = std::numeric_limits<int8_t>::min();
-  const int32_t activation_max = std::numeric_limits<int8_t>::max();
+  const int32_t act_min = static_cast<int32_t>(activation_min);
+  const int32_t act_max = static_cast<int32_t>(activation_max);
 
   ET_LOG(
       Debug,
@@ -121,8 +123,8 @@ Tensor& quantized_add_out(
       static_cast<int32_t>(out_zp),
       output_mult,
       output_shift_val,
-      activation_min,
-      activation_max,
+      act_min,
+      act_max,
       adds_per_loop);
 
   if (status != ARM_CMSIS_NN_SUCCESS) {
```
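The net effect on the kernel is that the final clamp bounds are no longer hard-coded to the int8 limits but supplied by the caller. A rough scalar model of that tail end of the computation (a sketch of the assumed semantics, not the CMSIS-NN source):

```python
def clamp_tail(requantized: int, out_zero_point: int,
               activation_min: int, activation_max: int) -> int:
    """Scalar model of the kernel's final step: add the output zero_point,
    then clamp to the activation bounds supplied by the fusion pass."""
    val = requantized + out_zero_point
    return max(activation_min, min(activation_max, val))

# Before this commit the bounds were fixed at the int8 limits:
assert clamp_tail(200, 0, -128, 127) == 127
# Now a fused activation can tighten them, e.g. bounds read from output_qparams:
assert clamp_tail(-50, 0, -20, 127) == -20
```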

backends/cortex_m/ops/operators.py

Lines changed: 10 additions & 3 deletions

```diff
@@ -124,15 +124,16 @@ def dequantize_per_tensor_impl(
     "quantized_add("
     "Tensor self, int self_zero_point, int self_multiplier, int self_shift, "
     "Tensor other, int other_zero_point, int other_multiplier, int other_shift, "
-    "int output_zero_point, int output_multiplier, int output_shift) -> Tensor"
+    "int output_zero_point, int output_multiplier, int output_shift, "
+    "int activation_min, int activation_max) -> Tensor"
 )
 
-# Define the operator schema with multipliers and shifts (11 args + out tensor)
 lib.define(
     "quantized_add.out("
     "Tensor self, int self_zero_point, int self_multiplier, int self_shift, "
     "Tensor other, int other_zero_point, int other_multiplier, int other_shift, "
     "int output_zero_point, int output_multiplier, int output_shift, "
+    "int activation_min, int activation_max, "
     "*, Tensor(a!) out) -> Tensor(a!)"
 )
 
@@ -150,6 +151,8 @@ def quantized_add_meta(
     output_zero_point: int,
     output_multiplier: int,
     output_shift: int,
+    activation_min: int,
+    activation_max: int,
 ) -> torch.Tensor:
     assert self.shape == other.shape or is_channel_broadcast(self, other), (
         "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — "
@@ -175,6 +178,8 @@ def quantized_add_impl(
     output_zero_point: int,
     output_multiplier: int,
     output_shift: int,
+    activation_min: int,
+    activation_max: int,
 ) -> torch.Tensor:
     assert self.shape == other.shape or is_channel_broadcast(self, other), (
         "Cortex-M quantized_add: broadcasting is not yet supported except for channel dim — "
@@ -188,7 +193,9 @@ def quantized_add_impl(
 
     result_fp = self_fp + other_fp
     result_quantized = requantize_cmsis(result_fp, output_multiplier, output_shift)
-    result = torch.clamp(result_quantized + output_zero_point, -128, 127).to(torch.int8)
+    result = torch.clamp(
+        result_quantized + output_zero_point, activation_min, activation_max
+    ).to(torch.int8)
     return result
 
```
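With the schema change, every caller must now supply the two extra integers. A hypothetical invocation sketch (the namespace is assumed from the schema above, and the tensor contents, multipliers, and shifts are placeholders, not values a real quantizer would produce):

```python
import torch

# Placeholder int8 operands; real ones come out of quantize_per_tensor.
x_q = torch.zeros(2, 4, dtype=torch.int8)
y_q = torch.zeros(2, 4, dtype=torch.int8)

out = torch.ops.cortex_m.quantized_add(
    x_q, 0, 1073741824, 0,   # self: zero_point, multiplier, shift (placeholders)
    y_q, 0, 1073741824, 0,   # other: zero_point, multiplier, shift (placeholders)
    0, 1073741824, 0,        # output: zero_point, multiplier, shift (placeholders)
    -128, 127,               # new: activation_min / activation_max
)
```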
backends/cortex_m/ops/operators.yaml

Lines changed: 1 addition & 1 deletion

```diff
@@ -17,7 +17,7 @@
   - arg_meta: null
     kernel_name: cortex_m::dequantize_per_tensor_out
 
-- func: cortex_m::quantized_add.out(Tensor self, int self_zero_point, int self_multiplier, int self_shift, Tensor other, int other_zero_point, int other_multiplier, int other_shift, int output_zero_point, int output_multiplier, int output_shift, *, Tensor(a!) out) -> Tensor(a!)
+- func: cortex_m::quantized_add.out(Tensor self, int self_zero_point, int self_multiplier, int self_shift, Tensor other, int other_zero_point, int other_multiplier, int other_shift, int output_zero_point, int output_multiplier, int output_shift, int activation_min, int activation_max, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null
```
backends/cortex_m/passes/activation_fusion_pass.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -40,6 +40,7 @@ class ActivationFusionPass(ExportPass):
     FUSE_OPS = {
         exir_ops.edge.aten.linear.default,
         exir_ops.edge.aten.convolution.default,
+        exir_ops.edge.aten.add.Tensor,
     }
 
     def _get_validated_qparams(self, node, input_node):
@@ -85,7 +86,7 @@ def _get_validated_qparams(self, node, input_node):
                     else qmax
                 )
             case _:
-                raise RuntimeError("Unexpected target {node.target}.")
+                raise RuntimeError(f"Unexpected target {node.target}.")
 
         # If the minimal quantized value is larger than the qmin, it means that the quantized range contains
         # invalid values [qmin, ..., quantized_min_val-1], indicating bad quantization parameters.
```
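Besides extending FUSE_OPS, the second hunk fixes a missing f-string prefix, so the error message now actually interpolates the offending target. Conceptually, the pass rewrites a producer/activation pair along these lines (a schematic sketch of activation fusion in general, not the pass's actual traversal code):

```python
# Schematic: the trailing relu node is removed, and its effect is carried by
# the producer's output quantization bounds, which QuantizedOpFusionPass later
# forwards to the kernel as activation_min/activation_max.
for node in graph_module.graph.nodes:
    if node.target not in ActivationFusionPass.FUSE_OPS:
        continue
    users = list(node.users)
    if len(users) == 1 and users[0].target is exir_ops.edge.aten.relu.default:
        relu = users[0]
        relu.replace_all_uses_with(node)   # consumers now read the add directly
        graph_module.graph.erase_node(relu)
```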

backends/cortex_m/passes/quantized_op_fusion_pass.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -70,6 +70,9 @@ def _get_add_replacement(self, args, meta):
             max_scale_2x / (output_scale * (1 << SHIFT_INT8))
         )
 
+        activation_min = meta["output_qparams"][0].qmin
+        activation_max = meta["output_qparams"][0].qmax
+
         args = (
             args[0],
             zero_point1,
@@ -82,6 +85,8 @@ def _get_add_replacement(self, args, meta):
             output_zero_point,
             output_mult,
             output_shift,
+            activation_min,
+            activation_max,
         )
 
         return exir_ops.edge.cortex_m.quantized_add.default, args
```
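For reference, the replacement node now carries 13 positional arguments. The layout below mirrors the tuple built in _get_add_replacement, with field names taken from the operator schema (a documentation-only aid, not code used by the pass):

```python
from typing import NamedTuple

import torch


class QuantizedAddArgs(NamedTuple):
    """Argument order of cortex_m::quantized_add after this commit."""
    self_tensor: torch.Tensor
    self_zero_point: int
    self_multiplier: int
    self_shift: int
    other_tensor: torch.Tensor
    other_zero_point: int
    other_multiplier: int
    other_shift: int
    output_zero_point: int
    output_multiplier: int
    output_shift: int
    activation_min: int  # new: output_qparams[0].qmin
    activation_max: int  # new: output_qparams[0].qmax
```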

backends/cortex_m/quantizer/quantizer_support.py

Lines changed: 6 additions & 0 deletions

```diff
@@ -17,6 +17,12 @@
 
 BINARY_OP_PATTERNS = {
     (torch.ops.aten.add.Tensor,): CortexMAddMulCheck,
+    (torch.ops.aten.add.Tensor, torch.ops.aten.relu.default): CortexMAddMulCheck,
+    (torch.ops.aten.add.Tensor, torch.ops.aten.relu_.default): CortexMAddMulCheck,
+    (torch.ops.aten.add.Tensor, torch.ops.aten.hardtanh.default): CortexMAddMulCheck,
+    (torch.ops.aten.add.Tensor, torch.ops.aten.hardtanh_.default): CortexMAddMulCheck,
+    (torch.ops.aten.add.Tensor, torch.ops.aten.clamp.default): CortexMAddMulCheck,
+    (torch.ops.aten.add.Tensor, torch.ops.aten.clamp_.default): CortexMAddMulCheck,
     (torch.ops.aten.add_.Tensor,): CortexMAddMulCheck,
     (torch.ops.aten.mul.Tensor,): CortexMAddMulCheck,
     (torch.ops.aten.mul_.Tensor,): CortexMAddMulCheck,
```
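Each tuple describes a producer/activation chain that the quantizer annotates as one unit. Illustratively, a module like the following now matches the (add.Tensor, relu.default) pattern instead of leaving the relu unannotated (a toy example, not taken from the test suite):

```python
import torch


class AddThenRelu(torch.nn.Module):
    def forward(self, x, y):
        # Traces to aten.add.Tensor followed by aten.relu.default,
        # i.e. the second pattern in BINARY_OP_PATTERNS above.
        return torch.relu(x + y)
```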

backends/cortex_m/test/models/test_nn_modules.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -188,9 +188,7 @@ def forward(self, x):
     ),
 }
 
-xfails: dict[str, xfail_type] = {
-    "conv_add_relu": "Activation fusion does not support relu after add",
-}
+xfails: dict[str, xfail_type] = {}
 
 
 @parametrize("test_case", test_cases, xfails=xfails, strict=False)
```

backends/cortex_m/test/ops/test_add.py

Lines changed: 72 additions & 0 deletions

```diff
@@ -73,6 +73,50 @@ class CortexMAlphaAdd(ModelAlpha):
 }
 
 
+class CortexMAddReLU(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_relu_default": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def __init__(self):
+        super().__init__()
+        self.relu = torch.nn.ReLU()
+
+    def forward(self, x, y):
+        return self.relu(x + y)
+
+
+class CortexMAddHardtanh(torch.nn.Module):
+    ops_before_transforms = {
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor": 1,
+        "executorch_exir_dialects_edge__ops_aten_hardtanh_default": 1,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default": 3,
+        "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default": 3,
+    }
+
+    ops_after_transforms = {
+        "executorch_exir_dialects_edge__ops_cortex_m_quantized_add_default": 1,
+        "executorch_exir_dialects_edge__ops_cortex_m_quantize_per_tensor_default": 2,
+        "executorch_exir_dialects_edge__ops_cortex_m_dequantize_per_tensor_default": 1,
+    }
+
+    def __init__(self, min_val=-0.5, max_val=0.5):
+        super().__init__()
+        self.act = torch.nn.Hardtanh(min_val=min_val, max_val=max_val)
+
+    def forward(self, x, y):
+        return self.act(x + y)
+
+
 test_cases = {
     "self_rank_1": McuTestCase(
         CortexMSelfAdd(),
@@ -149,6 +193,34 @@ class CortexMAlphaAdd(ModelAlpha):
             ramp_tensor(-20, 20, (4, 5)),
         ),
     ),
+    "add_relu": McuTestCase(
+        CortexMAddReLU(),
+        (
+            ramp_tensor(-5, 5, (2, 4)),
+            ramp_tensor(-3, 3, (2, 4)),
+        ),
+    ),
+    "add_relu_channels_last": McuTestCase(
+        CortexMAddReLU(),
+        (
+            ramp_tensor(-5, 5, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+            ramp_tensor(-3, 3, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
+    "add_hardtanh": McuTestCase(
+        CortexMAddHardtanh(min_val=-0.5, max_val=0.5),
+        (
+            ramp_tensor(-2, 2, (2, 4)),
+            ramp_tensor(-1, 1, (2, 4)),
+        ),
+    ),
+    "add_hardtanh_channels_last": McuTestCase(
+        CortexMAddHardtanh(min_val=-1.0, max_val=1.0),
+        (
+            ramp_tensor(-3, 3, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+            ramp_tensor(-2, 2, (1, 4, 8, 8)).to(memory_format=torch.channels_last),
+        ),
+    ),
 }
 
 
```
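To run just the new cases locally, an invocation along the lines of `python -m pytest backends/cortex_m/test/ops/test_add.py -k "add_relu or add_hardtanh"` should select them, though the exact command depends on the repo's usual test runner setup.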