Skip to content

Commit eff1033

Browse files
Copilot and zingo
authored and committed
Revert "Eliminate redundant NCHW↔NHWC permute_copy and NHWC-safe view_copy transposes in ToTosaMemoryFormatPass" (pytorch#18555)
Reverts pytorch#18314. ### Summary Reverts the NHWC-safe reshape detection and redundant `permute_copy` elimination introduced in pytorch#18314 (`ToTosaMemoryFormatPass`). ### Test plan Full revert of pytorch#18314 — no new logic introduced. Relevant tests: ``` pytest backends/arm/test/passes/test_to_tosa_memory_format.py pytest backends/arm/test/ -v ``` These require the Arm/TOSA toolchain and cannot run in the automation environment. <!-- START COPILOT ORIGINAL PROMPT --> <details> <summary>Original prompt</summary> > Create a revert pull request for merged PR pytorch#18314 in repository pytorch/executorch. > > Requirements: > 1) Create a new branch off the default branch (main unless repo indicates otherwise). > 2) Revert the merge commit(s) introduced by PR pytorch#18314 ("Eliminate redundant NCHW↔NHWC permute_copy and NHWC-safe view_copy transposes in ToTosaMemoryFormatPass"). Use `git revert -m 1 <merge_sha>`-equivalent semantics if it was merged via a merge commit; otherwise revert the relevant commits. > 3) Ensure the revert builds/tests as per repo conventions: > - Run the existing CI-relevant unit tests or minimal test suite if available in repo documentation. > - If running tests is not possible in the automation environment, add a note in the PR description explaining what should be run. > 4) PR title should be: "Revert \"Eliminate redundant NCHW↔NHWC permute_copy and NHWC-safe view_copy transposes in ToTosaMemoryFormatPass\"" > 5) PR description should include: > - Link to PR pytorch#18314 > - Brief explanation that this PR reverts pytorch#18314 > - Any conflicts resolved or follow-up actions needed. > > Acceptance criteria: > - The revert PR cleanly applies to the base branch and restores behavior prior to pytorch#18314. > - No unrelated formatting or refactoring changes. 
> > > The following is the prior conversation context from the user's chat exploration (may be truncated): > > User: can you create a revert PR > > Assistant: I’m waiting for you to accept the pending action request; once you do, I can proceed to create the revert pull request for pytorch#18314. > > </details> <!-- START COPILOT CODING AGENT SUFFIX --> *This pull request was created from Copilot chat.* > <!-- START COPILOT CODING AGENT TIPS --> --- 💬 Send tasks to Copilot coding agent from [Slack](https://gh.io/cca-slack-docs) and [Teams](https://gh.io/cca-teams-docs) to turn conversations into code. Copilot posts an update in your thread when it's finished. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: zingo <368986+zingo@users.noreply.github.com>
1 parent be7eb44 commit eff1033

4 files changed

Lines changed: 9 additions & 404 deletions

File tree

backends/arm/_passes/to_tosa_memory_format_pass.py

Lines changed: 5 additions & 255 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
)
1717
from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER
1818
from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node
19-
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
2019
from executorch.exir import ExportedProgram
2120
from executorch.exir.dialects._ops import ops as exir_ops
2221
from executorch.exir.pass_base import ExportPass, PassResult
@@ -188,75 +187,12 @@ def memory_format_differs(shape, spatial_rank):
188187
channel_dim = shape[channel_idx]
189188
return channel_dim > 1 and any(dim > 1 for dim in spatial_dims)
190189

191-
@staticmethod
192-
def _is_nhwc_safe_reshape(
193-
input_shape, output_shape, cl_order: tuple[int, ...]
194-
) -> bool:
195-
"""Return ``True`` when a 4-D+ reshape can operate directly on NHWC
196-
data.
197-
198-
A reshape is NHWC-safe when its shape_indices are monotonic, both the
199-
batch dimension (index 0) and the channel dimension (last index) are
200-
preserved alone in their output groups, and every merged group contains
201-
only dims that are contiguous in the NHWC physical layout.
202-
203-
"""
204-
rank_in = len(input_shape)
205-
rank_out = len(output_shape)
206-
if rank_in < 4 or rank_out < 4:
207-
return False
208-
209-
indices = ToTosaMemoryFormatPass._get_shape_indices(
210-
list(input_shape), list(output_shape)
211-
)
212-
if indices is None or not ToTosaMemoryFormatPass._is_monotonic(indices):
213-
return False
214-
215-
# The channel dim (last axis in NHWC) and batch dim (index 0)
216-
# must each appear alone — merging either with spatial dims
217-
# would reorder data or change element pairing semantics.
218-
channel_idx = rank_in - 1
219-
batch_idx = 0
220-
for group in indices:
221-
if channel_idx in group and len(group) != 1:
222-
return False
223-
if batch_idx in group and len(group) != 1:
224-
return False
225-
226-
batch_found = any(batch_idx in g for g in indices)
227-
channel_found = any(channel_idx in g for g in indices)
228-
if not (batch_found and channel_found):
229-
return False
230-
231-
# Merged dims must be contiguous in the NHWC physical layout.
232-
# The TOSA RESHAPE operates on row-major data in NHWC order,
233-
# so only dims adjacent in that order can be validly merged.
234-
nhwc_pos = [0] * rank_in
235-
for pos, dim in enumerate(cl_order):
236-
nhwc_pos[dim] = pos
237-
for group in indices:
238-
if len(group) <= 1:
239-
continue
240-
positions = sorted(nhwc_pos[d] for d in group)
241-
for i in range(1, len(positions)):
242-
if positions[i] != positions[i - 1] + 1:
243-
return False
244-
245-
return True
246-
247190
@staticmethod
248191
def is_channel_reshape(
249192
input_shape, output_shape, input_spatial_rank, output_spatial_rank
250193
):
251194
"""Check whether a reshape touches the logical channel or consolidated
252-
batch dimensions in a way that would invalidate dim-order annotations.
253-
254-
Returns ``False`` (no transposes needed) when either:
255-
- The reshape does not change the channel or batch dimensions at all, OR
256-
- The reshape is NHWC-safe: monotonic shape_indices with both batch
257-
(index 0) and channel (last index) preserved alone in their output
258-
groups, meaning the view_copy can operate directly on NHWC data.
259-
195+
batch dimensions, which would invalidate dim-order annotations.
260196
"""
261197

262198
valid_ranks = {4, 5, 6}
@@ -284,27 +220,7 @@ def get_batch_prod_dim(shape, spatial_rank):
284220
N_old = get_batch_prod_dim(input_shape, input_spatial_rank)
285221
N_new = get_batch_prod_dim(output_shape, output_spatial_rank)
286222

287-
if (N_old == N_new) and (C_old == C_new):
288-
return False
289-
290-
# The reshape touches batch/channel dims — check whether it is
291-
# NHWC-safe (can operate directly on NHWC data without transposes).
292-
# This optimisation is only valid when both tensors use the same
293-
# channels-last permutation; when the spatial rank changes relative
294-
# to the tensor rank the NHWC axis mapping differs and the reshape
295-
# would scramble data.
296-
in_cl = ToTosaMemoryFormatPass._channels_last_order(
297-
len(input_shape), input_spatial_rank
298-
)
299-
out_cl = ToTosaMemoryFormatPass._channels_last_order(
300-
len(output_shape), output_spatial_rank
301-
)
302-
if in_cl == out_cl and ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
303-
input_shape, output_shape, in_cl
304-
):
305-
return False
306-
307-
return True
223+
return (N_old != N_new) or (C_old != C_new)
308224

309225
@staticmethod
310226
def insert_input_transpose(node, input_node, graph_module):
@@ -355,7 +271,7 @@ def insert_output_transpose(node, graph_module):
355271
# Guard: mem_format must be a true permutation for the current rank
356272
assert sorted(mem_format) == list(
357273
range(rank)
358-
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"
274+
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
359275

360276
with graph_module.graph.inserting_after(node):
361277
permute_node = create_node(
@@ -380,65 +296,6 @@ def insert_output_transpose(node, graph_module):
380296
for user in users:
381297
user.replace_input_with(node, permute_node)
382298

383-
@staticmethod
384-
def _get_shape_indices(
385-
src_shape: list[int], tgt_shape: list[int]
386-
) -> list[list[int]] | None:
387-
"""Greedy dimension matching for reshape operations.
388-
389-
For each target dimension, greedily consumes contiguous source
390-
dimensions whose product equals the target size. Size-1 target
391-
dimensions that do not correspond to any source dimension produce
392-
empty index lists (inserted dims).
393-
394-
Returns ``None`` when no valid mapping exists.
395-
396-
"""
397-
src_idx = 0
398-
result: list[list[int]] = []
399-
400-
for tgt_dim in tgt_shape:
401-
if tgt_dim <= 0:
402-
return None
403-
404-
indices: list[int] = []
405-
remaining = tgt_dim
406-
407-
while src_idx < len(src_shape):
408-
if src_shape[src_idx] == 0:
409-
return None
410-
if remaining % src_shape[src_idx] != 0:
411-
break
412-
indices.append(src_idx)
413-
remaining //= src_shape[src_idx]
414-
src_idx += 1
415-
if remaining == 1:
416-
break
417-
418-
if remaining != 1:
419-
return None
420-
421-
result.append(indices)
422-
423-
if src_idx != len(src_shape):
424-
return None
425-
426-
return result
427-
428-
@staticmethod
429-
def _is_monotonic(indices: list[list[int]]) -> bool:
430-
"""Return ``True`` when all non-empty index groups are strictly ordered
431-
— i.e. each group's indices follow the previous group's.
432-
"""
433-
last_max = -1
434-
for group in indices:
435-
if not group:
436-
continue
437-
if group[0] <= last_max:
438-
return False
439-
last_max = group[-1]
440-
return True
441-
442299
@staticmethod
443300
def _insert_view_transpose(
444301
input_shape, output_shape, node, input_node, graph_module
@@ -472,110 +329,6 @@ def _insert_view_transpose(
472329
) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr):
473330
ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
474331

475-
@staticmethod
476-
def _is_input_channels_last(input_node: torch.fx.Node, cl_order: list[int]) -> bool:
477-
"""Return True if *input_node* is already in channels-last order.
478-
479-
Only when the input is in NHWC does a cl_order/cl_inv permute duplicate
480-
the tosa_dim_order annotation. When the input is in NCHW (e.g. from a
481-
placeholder or non-spatial op) the permute is the model's intended
482-
computation and must be kept.
483-
484-
"""
485-
input_dim_order = input_node.meta.get("tosa_dim_order")
486-
if input_dim_order is None:
487-
return True
488-
return list(input_dim_order) == cl_order
489-
490-
@staticmethod
491-
def _is_semantic_permute(input_node: torch.fx.Node) -> bool:
492-
"""Return True if the permute's input traces back to a shape-
493-
manipulation op through transpose/permute nodes.
494-
495-
Walk upstream through tosa.TRANSPOSE and aten.permute_copy nodes
496-
(chained permutes arise from decomposition passes, e.g. unfold ->
497-
as_strided + movedim -> permute_copy). If a shape-manipulation op is
498-
found, the permute is semantic, not a format conversion.
499-
500-
"""
501-
upstream: torch.fx.Node | object = input_node
502-
while isinstance(upstream, torch.fx.Node) and upstream.target in (
503-
exir_ops.backend.tosa.TRANSPOSE.default,
504-
exir_ops.edge.aten.permute_copy.default,
505-
exir_ops.edge.aten.permute.default,
506-
):
507-
upstream = upstream.args[0]
508-
return isinstance(upstream, torch.fx.Node) and upstream.target in (
509-
exir_ops.edge.aten.view_copy.default,
510-
exir_ops.edge.aten.reshape.default,
511-
exir_ops.edge.aten.as_strided.default,
512-
exir_ops.edge.aten.as_strided_copy.default,
513-
)
514-
515-
def _try_replace_redundant_permute(
516-
self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
517-
) -> bool:
518-
"""Remove a permute_copy if it duplicates tosa_dim_order.
519-
520-
When a permute_copy's permutation matches the channels-last order
521-
(or its inverse) AND the input is already in NHWC dim_order, the
522-
permute does the same NCHW<>NHWC conversion that tosa_dim_order
523-
already handles — keeping both would double-convert. Remove the
524-
permute by wiring its users directly to its input.
525-
526-
Returns ``True`` if the node was removed.
527-
528-
"""
529-
if node.target not in (
530-
exir_ops.edge.aten.permute_copy.default,
531-
exir_ops.edge.aten.permute.default,
532-
):
533-
return False
534-
535-
perm_arg = node.args[1]
536-
assert isinstance(perm_arg, (list, tuple))
537-
perm = list(perm_arg)
538-
rank = len(perm)
539-
sr = node.meta.get("tosa_spatial_rank", 0)
540-
541-
if rank < 3 or sr < 1:
542-
return False
543-
544-
cl_order = list(self._channels_last_order(rank, sr))
545-
cl_inv = list(self._channels_last_inverse_order(rank, sr))
546-
if perm != cl_order and perm != cl_inv:
547-
return False
548-
549-
input_node = node.args[0]
550-
if not isinstance(input_node, torch.fx.Node):
551-
return False
552-
553-
if not self._is_input_channels_last(input_node, cl_order):
554-
return False
555-
556-
if self._is_semantic_permute(input_node):
557-
return False
558-
559-
output_shape = list(node.meta["val"].shape)
560-
with graph_module.graph.inserting_before(node):
561-
const_shape_node = graph_module.graph.call_function(
562-
exir_ops.backend.tosa.CONST_SHAPE.default,
563-
(output_shape,),
564-
)
565-
const_shape_node.meta["val"] = output_shape
566-
const_shape_node.meta["tosa_dim_order"] = node.meta.get(
567-
"tosa_dim_order", tuple(range(rank))
568-
)
569-
const_shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
570-
view_node = graph_module.graph.call_function(
571-
exir_ops.edge.aten.view_copy.default,
572-
(input_node, const_shape_node),
573-
)
574-
view_node.meta = dict(node.meta)
575-
node.replace_all_uses_with(view_node)
576-
graph_module.graph.erase_node(node)
577-
return True
578-
579332
def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
580333
"""Transposes are needed for operators transforming the input to a
581334
different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-
@@ -592,15 +345,12 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
592345
- 1D/2D tensors
593346
594347
"""
595-
for node in list(graph_module.graph.nodes):
348+
for node in graph_module.graph.nodes:
596349
if node.op != "call_function":
597350
continue
598351

599-
if self._try_replace_redundant_permute(node, graph_module):
600-
continue
601-
602352
# Transpose views
603-
if node.target == exir_ops.edge.aten.view_copy.default:
353+
elif node.target == exir_ops.edge.aten.view_copy.default:
604354
input_node = node.args[0]
605355
input_shape = input_node.meta["val"].shape
606356
output_shape = node.meta["val"].shape

backends/arm/operators/op_tosa_conv3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class Conv3dVisitor(Conv2dVisitor):
1515
target = "tosa.CONV3D.default"
1616

1717
def _get_tosa_op(self):
18-
import tosa_serializer as ts # type: ignore
18+
import serializer.tosa_serializer as ts # type: ignore
1919

2020
return ts.Op.CONV3D
2121

0 commit comments

Comments (0)