From 033bf11925d461b05814be7fec9a5f5355b1900c Mon Sep 17 00:00:00 2001 From: Vaclav Novak Date: Tue, 16 Dec 2025 17:49:48 +0100 Subject: [PATCH] feat: added support for aten.split --- .../decompose_split_to_slices_pass.py | 203 +++++++++++++++ .../aten_passes/neutron_aten_pass_manager.py | 4 + backends/nxp/tests/models.py | 34 ++- .../tests/test_decompose_split_to_slices.py | 243 ++++++++++++++++++ 4 files changed, 483 insertions(+), 1 deletion(-) create mode 100644 backends/nxp/aten_passes/decompose_split_to_slices_pass.py create mode 100644 backends/nxp/tests/test_decompose_split_to_slices.py diff --git a/backends/nxp/aten_passes/decompose_split_to_slices_pass.py b/backends/nxp/aten_passes/decompose_split_to_slices_pass.py new file mode 100644 index 00000000000..3cad897f3b1 --- /dev/null +++ b/backends/nxp/aten_passes/decompose_split_to_slices_pass.py @@ -0,0 +1,203 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional, Tuple, TypeAlias + +import torch +from torch._subclasses import FakeTensor, FakeTensorMode +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult + + +class DecomposeSplitToSlicesPass(PassBase): + """ + The `split` operator returns multiple tensors by partitioning `x` along `dim`. Each partitioning can be done + using one `slice` operator. Replacing the `split` operator with multiple `slice` operators will yield the same results. + + + │ + ┌─────────────▼─────────────┐ + │ x │ + └─────────────┬─────────────┘ + │ + ┌─────────────────────▼─────────────────────┐ + │ aten.split / aten.split_with_sizes │ + └─────────────────────┬─────────────────────┘ + │ + ┌────────────────────┼─────────────────────────┐ + │ │ │ + ┌────────▼────────┐ ┌────────▼────────┐ ┌────────▼────────┐ + │ getitem(0) │ │ getitem(1) │ ... │ getitem(N-1) │ + └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ + │ │ │ + ▼ ▼ ▼ + out0 out1 out(N-1) + + + | + | + replace with + | + | + ▼ + + + │ + ┌─────────────▼─────────────┐ + │ x │ + └─────────────┬─────────────┘ + │ + ┌────────────────────┼─────────────────────────┐ + │ │ │ + ┌────────▼────────┐ ┌────────▼────────┐ ┌────────▼────────┐ + │ aten.slice(x, │ │ aten.slice(x, │ ... │ (more slices) │ + │ dim,s0,e0 │ │ dim,s1,e1) │ ... │ │ + └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ + │ │ │ + │ │ │ + ▼ ▼ ▼ + out0 out1 outN-1 + + """ + + graph_module: GraphModule + + @staticmethod + def _is_split_with_sizes(node: Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops.aten.split_with_sizes.default + ) + + @staticmethod + def _is_regular_split(node: Node) -> bool: + is_split_tensor = ( + node.op == "call_function" and node.target == torch.ops.aten.split.Tensor + ) + + is_split_default = ( + node.op == "call_function" and node.target == torch.ops.aten.split.default + ) + + return is_split_tensor or is_split_default + + def _create_slice_node(self, *slice_args) -> Node: + slice_target = torch.ops.aten.slice_copy.Tensor + slice_node = self.graph_module.graph.call_function(slice_target, slice_args) + + slice_node.meta["source_fn_stack"] = [(slice_node.name, torch.slice_copy)] + + x_val = slice_args[0].meta["val"] + with FakeTensorMode() as mode: + fake_input = FakeTensor.from_tensor( + torch.empty(x_val.shape, dtype=x_val.dtype), mode + ) + output_shape = slice_target(fake_input, *slice_args[1:]).shape + slice_node.meta["val"] = FakeTensor.from_tensor( + torch.empty(output_shape, dtype=x_val.dtype), mode + ) + + return slice_node + + SlicesArgs: TypeAlias = Tuple[List[int], List[int], int] + + def _get_slices_args(self, input_node: Node, split_node: Node) -> SlicesArgs: + split_nodes_chunks = split_node.meta["val"] + + # Sometimes chunks are in tuples + if isinstance(split_nodes_chunks, tuple): + split_nodes_chunks = list(split_nodes_chunks) + + if not isinstance(split_nodes_chunks, list): + raise RuntimeError("Faulty split chunks") + + # Get split dim + dim = -1 + input_shape = input_node.meta["val"].shape + for possible_dim in range(len(split_nodes_chunks[0].shape)): + if split_nodes_chunks[0].shape[possible_dim] != input_shape[possible_dim]: + dim = possible_dim + break + + if dim == -1: + raise RuntimeError("Could not determine dim param") + + # Get slices start, end params + starts = [] + ends = [] + + curr_start = 0 + for s in split_nodes_chunks: + starts.append(curr_start) + ends.append(curr_start + s.shape[dim]) + curr_start += s.shape[dim] + + return starts, ends, dim + + def _replace_split_with_slices(self, input_node, split_node, starts, ends, dim): + # Replace getitem nodes after split with slices + getitem_nodes = list(split_node.users.keys()) + slice_nodes = [] + for i in range(len(starts)): + slice_arguments = (input_node, dim, starts[i], ends[i]) + with self.graph_module.graph.inserting_after(split_node): + slice_node = self._create_slice_node(*slice_arguments) + slice_nodes.append(slice_node) + + getitem_node = getitem_nodes[i] + getitem_node.replace_all_uses_with(slice_node) + + self.graph_module.graph.erase_node(getitem_node) + + # Wire split node correctly to the input node + split_node.replace_all_uses_with(input_node) + self.graph_module.graph.erase_node(split_node) + + return True + + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + self.graph_module = graph_module + made_changes = False + + if not any(map(self._is_regular_split, graph_module.graph.nodes)) and not any( + map(self._is_split_with_sizes, graph_module.graph.nodes) + ): + return PassResult(graph_module, made_changes) + + for node in graph_module.graph.nodes: + # Skip if not split + is_split_with_sizes = self._is_split_with_sizes(node) + is_regular_split = self._is_regular_split(node) + + if not is_split_with_sizes and not is_regular_split: + continue + + # Get split args + split_node = node + input_node = split_node.all_input_nodes[0] + split_nodes_chunks = split_node.meta["val"] + + # Check if split is even necessary - if not, remove it + if len(split_nodes_chunks) == 1: + getitem_node = split_node.next + getitem_node.replace_all_uses_with(input_node) + + self.graph_module.graph.erase_node(getitem_node) + self.graph_module.graph.erase_node(split_node) + + made_changes = True + continue + + # Get arguments for the new slices + starts, ends, dim = self._get_slices_args(input_node, split_node) + + # Replace split with slices and restructure the graph + made_changes = self._replace_split_with_slices( + input_node, split_node, starts, ends, dim + ) + + self.graph_module.recompile() + self.graph_module.graph.eliminate_dead_code() + + return PassResult(self.graph_module, made_changes) diff --git a/backends/nxp/aten_passes/neutron_aten_pass_manager.py b/backends/nxp/aten_passes/neutron_aten_pass_manager.py index 35205c76c68..a02bb7ba1bc 100644 --- a/backends/nxp/aten_passes/neutron_aten_pass_manager.py +++ b/backends/nxp/aten_passes/neutron_aten_pass_manager.py @@ -7,6 +7,9 @@ import torch +from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import ( + DecomposeSplitToSlicesPass, +) from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import ( FuseBatchNormWithConvPass, ) @@ -42,6 +45,7 @@ def __init__( self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None ): passes: list[PassType] = passes or [ + DecomposeSplitToSlicesPass(), FuseBatchNormWithConvPass(), FuseBatchNormWithLinearPass(), SplitGroupConvolution(), diff --git a/backends/nxp/tests/models.py b/backends/nxp/tests/models.py index e2b41aab8de..3e2fc66cc5f 100644 --- a/backends/nxp/tests/models.py +++ b/backends/nxp/tests/models.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import math -from typing import Callable, Collection, Union +from typing import Callable, Collection, Tuple, Union import torch @@ -633,6 +633,38 @@ def forward(self, x): return self.activation(x) +class GRUModel(nn.Module): + def __init__(self, num_layers=1): + super().__init__() + self.gru = torch.nn.GRU(8, 8, num_layers=num_layers) + + def forward(self, input_): + # `input_` has shape [sequence_length, batch_size, input_size] ([8, 1, 8]) + return self.gru( + input_, None + ) # The initial hidden is `None`, which will result in a `Zeros` node being added. + + +class SplitWithSize(torch.nn.Module): + def __init__(self, split_size, dim): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + return torch.split(x, self.split_size, self.dim) + + +class SplitWithSections(torch.nn.Module): + def __init__(self, sections, dim): + super().__init__() + self.sections = sections + self.dim = dim + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]: + return torch.split(x, self.sections, self.dim) + + class MiniConvNetWithRegressionHead(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/nxp/tests/test_decompose_split_to_slices.py b/backends/nxp/tests/test_decompose_split_to_slices.py new file mode 100644 index 00000000000..99e97b9820d --- /dev/null +++ b/backends/nxp/tests/test_decompose_split_to_slices.py @@ -0,0 +1,243 @@ +# Copyright 2025 NXP +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import pytest +import torch + +from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import ( + DecomposeSplitToSlicesPass, + NeutronAtenPassManager, + SplitGRUBasedOnNumLayers, +) +from executorch.backends.nxp.tests.executorch_pipeline import neutron_target_spec +from executorch.backends.nxp.tests.executors import graph_contains_any_of_ops +from executorch.backends.nxp.tests.models import ( + GRUModel, + SplitWithSections, + SplitWithSize, +) + + +@pytest.mark.parametrize( + "input_shape, split_size, dim", + [ + pytest.param((8,), 3, 0, id="1D."), + pytest.param((4, 8), 5, 1, id="2D."), + ], +) +def test_decompose_split_with_size(mocker, input_shape, split_size, dim): + model = SplitWithSize(split_size, dim) + example_input = torch.rand(input_shape) + + exir_program_aten = torch.export.export(model, (example_input,)).module() + + # Check "aten.split.Tensor" is present + assert graph_contains_any_of_ops( + exir_program_aten.graph, [torch.ops.aten.split.Tensor] + ) + outputs_before = [o.detach().numpy() for o in exir_program_aten(example_input)] + + # Apply the optimization. + NeutronAtenPassManager(neutron_target_spec, [DecomposeSplitToSlicesPass()])( + exir_program_aten + ) + + # Make sure no "Split" is in the model. + assert not graph_contains_any_of_ops( + exir_program_aten.graph, + [ + torch.ops.aten.split.Tensor, + torch.ops.aten.split.default, + torch.ops.aten.split_with_sizes.default, + ], + ) + + # Check correct placement of slices + nodes = list(exir_program_aten.graph.nodes) + slices_count = input_shape[dim] // split_size + # Slice nodes start appearing at index 1 + slices_start_idx = 1 + + for i in range(0, slices_count): + assert nodes[slices_start_idx + i].target == torch.ops.aten.slice_copy.Tensor + + outputs_after = [o.detach().numpy() for o in exir_program_aten(example_input)] + + # Make sure the model still produces the exact same output. + assert len(outputs_before) == len(outputs_after) + + for i in range(len(outputs_before)): + assert np.allclose(outputs_before[i], outputs_after[i]) + + +@pytest.mark.parametrize( + "input_shape, sections, dim", + [ + pytest.param((8,), [5, 3], 0, id="1D."), + pytest.param((4, 8), [3, 3, 2], 1, id="2D."), + ], +) +def test_decompose_split_with_section(mocker, input_shape, sections, dim): + model = SplitWithSections(sections, dim) + example_input = torch.rand(input_shape) + + exir_program_aten = torch.export.export(model, (example_input,)).module() + + # Check "aten.split_with_sizes" is present + assert graph_contains_any_of_ops( + exir_program_aten.graph, [torch.ops.aten.split_with_sizes.default] + ) + outputs_before = [o.detach().numpy() for o in exir_program_aten(example_input)] + + # Apply the optimization. + NeutronAtenPassManager(neutron_target_spec, [DecomposeSplitToSlicesPass()])( + exir_program_aten + ) + + # Make sure no "Split" is in the model. + assert not graph_contains_any_of_ops( + exir_program_aten.graph, + [ + torch.ops.aten.split.Tensor, + torch.ops.aten.split.default, + torch.ops.aten.split_with_sizes.default, + ], + ) + + # Check correct placement of slices + nodes = list(exir_program_aten.graph.nodes) + slices_count = len(sections) + # Slice nodes start appearing at index 1 + slices_start_idx = 1 + + for i in range(0, slices_count): + assert nodes[slices_start_idx + i].target == torch.ops.aten.slice_copy.Tensor + + outputs_after = [o.detach().numpy() for o in exir_program_aten(example_input)] + + # Make sure the model still produces the exact same output. + assert len(outputs_before) == len(outputs_after) + + for i in range(len(outputs_before)): + assert np.allclose(outputs_before[i], outputs_after[i]) + + +@pytest.mark.parametrize( + "gru_layers", + [ + pytest.param(2, id="2 GRU layers"), + ], +) +def test_decompose_gru_with_split_default(mocker, gru_layers): + model = GRUModel(gru_layers).eval() + + input_shape = (8, 1, 8) + example_input = (torch.ones(input_shape),) + + exir_program_aten = torch.export.export(model, example_input).module() + + # Apply the pass to split the `aten.gru.input` into multiple instances, which adds a split operator + NeutronAtenPassManager(neutron_target_spec, [SplitGRUBasedOnNumLayers()])( + exir_program_aten + ) + + # Check "aten.split.default" is present + assert graph_contains_any_of_ops( + exir_program_aten.graph, [torch.ops.aten.split.default] + ) + + outputs_before = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Apply the optimization. + NeutronAtenPassManager(neutron_target_spec, [DecomposeSplitToSlicesPass()])( + exir_program_aten + ) + + # Make sure no "Split" is in the model. + assert not graph_contains_any_of_ops( + exir_program_aten.graph, + [ + torch.ops.aten.split.Tensor, + torch.ops.aten.split.default, + torch.ops.aten.split_with_sizes.default, + ], + ) + + # Check correct placement of slices + nodes = list(exir_program_aten.graph.nodes) + slices_count = gru_layers + # Slice nodes start appearing at index 10 for gru_layer=2, for gru_layer=3 they start at index 14... + slices_start_idx = 4 * gru_layers + 2 + + for i in range(0, slices_count): + assert nodes[slices_start_idx + i].target == torch.ops.aten.slice_copy.Tensor + + outputs_after = [o.detach().numpy() for o in exir_program_aten(*example_input)] + + # Make sure the model still produces the exact same output. + assert len(outputs_before) == len(outputs_after) + + for i in range(len(outputs_before)): + assert np.allclose(outputs_before[i], outputs_after[i]) + + +@pytest.mark.parametrize( + "input_shape, size_or_sections, dim", + [ + # pytest.param((8, 4), 4, 1, id="2D, one chunk using split size."), + pytest.param( + (8, 4), + 5, + 1, + id="2D, one chunk using split size, chunk size over the limit.", + ), + pytest.param((8, 4), [4], 1, id="2D, one chunk using sections."), + ], +) +def test_decompose_split_with_one_chunk(mocker, input_shape, size_or_sections, dim): + if isinstance(size_or_sections, list): + model = SplitWithSections(size_or_sections, dim) + else: + model = SplitWithSize(size_or_sections, dim) + example_input = torch.rand(input_shape) + + exir_program_aten = torch.export.export(model, (example_input,)).module() + + # Check "aten.split" is present + assert graph_contains_any_of_ops( + exir_program_aten.graph, + [torch.ops.aten.split.Tensor, torch.ops.aten.split_with_sizes.default], + ) + outputs_before = [o.detach().numpy() for o in exir_program_aten(example_input)] + + # Apply the optimization. + NeutronAtenPassManager(neutron_target_spec, [DecomposeSplitToSlicesPass()])( + exir_program_aten + ) + + # Make sure no "Split" is in the model. + assert not graph_contains_any_of_ops( + exir_program_aten.graph, + [ + torch.ops.aten.split.Tensor, + torch.ops.aten.split.default, + torch.ops.aten.split_with_sizes.default, + ], + ) + + # Make sure there are no "aten.slice_copy.Tensor" either. Since the split was done using one chunk, + # slicing is unnecessary + assert not graph_contains_any_of_ops( + exir_program_aten.graph, [torch.ops.aten.slice_copy.Tensor] + ) + + outputs_after = [o.detach().numpy() for o in exir_program_aten(example_input)] + + # Make sure the model still produces the exact same output. + assert len(outputs_before) == len(outputs_after) + + for i in range(len(outputs_before)): + assert np.allclose(outputs_before[i], outputs_after[i])