203 changes: 203 additions & 0 deletions backends/nxp/aten_passes/decompose_split_to_slices_pass.py
@@ -0,0 +1,203 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple, TypeAlias

import torch
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class DecomposeSplitToSlicesPass(PassBase):
"""
    The `split` operator returns multiple tensors by partitioning `x` along `dim`. Each partition can be produced
    by one `slice` operator, so replacing the `split` operator with multiple `slice` operators yields the same results.


┌─────────────▼─────────────┐
│ x │
└─────────────┬─────────────┘
┌─────────────────────▼─────────────────────┐
│ aten.split / aten.split_with_sizes │
└─────────────────────┬─────────────────────┘
┌────────────────────┼─────────────────────────┐
│ │ │
┌────────▼────────┐ ┌────────▼────────┐ ┌────────▼────────┐
│ getitem(0) │ │ getitem(1) │ ... │ getitem(N-1) │
└────────┬────────┘ └────────┬────────┘ └────────┬────────┘
│ │ │
▼ ▼ ▼
out0 out1 out(N-1)


|
|
replace with
|
|


┌─────────────▼─────────────┐
│ x │
└─────────────┬─────────────┘
┌────────────────────┼─────────────────────────┐
│ │ │
┌────────▼────────┐ ┌────────▼────────┐ ┌────────▼────────┐
│ aten.slice(x, │ │ aten.slice(x, │ ... │ (more slices) │
│ dim,s0,e0 │ │ dim,s1,e1) │ ... │ │
└────────┬────────┘ └────────┬────────┘ └────────┬────────┘
│ │ │
│ │ │
▼ ▼ ▼
out0 out1 outN-1

"""

graph_module: GraphModule

@staticmethod
def _is_split_with_sizes(node: Node) -> bool:
return (
node.op == "call_function"
and node.target == torch.ops.aten.split_with_sizes.default
)

@staticmethod
def _is_regular_split(node: Node) -> bool:
is_split_tensor = (
node.op == "call_function" and node.target == torch.ops.aten.split.Tensor
)

is_split_default = (
node.op == "call_function" and node.target == torch.ops.aten.split.default
)

return is_split_tensor or is_split_default

def _create_slice_node(self, *slice_args) -> Node:
slice_target = torch.ops.aten.slice_copy.Tensor
slice_node = self.graph_module.graph.call_function(slice_target, slice_args)

slice_node.meta["source_fn_stack"] = [(slice_node.name, torch.slice_copy)]

x_val = slice_args[0].meta["val"]
        # Infer the slice output's shape by running the target op on fake inputs, so the new node gets `val` metadata.
        with FakeTensorMode() as mode:
fake_input = FakeTensor.from_tensor(
torch.empty(x_val.shape, dtype=x_val.dtype), mode
)
output_shape = slice_target(fake_input, *slice_args[1:]).shape
slice_node.meta["val"] = FakeTensor.from_tensor(
torch.empty(output_shape, dtype=x_val.dtype), mode
)

return slice_node

SlicesArgs: TypeAlias = Tuple[List[int], List[int], int]

def _get_slices_args(self, input_node: Node, split_node: Node) -> SlicesArgs:
split_nodes_chunks = split_node.meta["val"]

# Sometimes chunks are in tuples
if isinstance(split_nodes_chunks, tuple):
split_nodes_chunks = list(split_nodes_chunks)

if not isinstance(split_nodes_chunks, list):
raise RuntimeError("Faulty split chunks")

# Get split dim
dim = -1
Collaborator:
Why are you guessing the dim? Is it not possible to read it from the arguments?
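For context on the comment above: both default ATen schemas, `aten.split.Tensor(self, split_size, dim=0)` and `aten.split_with_sizes(self, split_sizes, dim=0)`, carry `dim` as the third argument, so a sketch of reading it directly (an illustration, not code from this PR) could look like:

# Hypothetical alternative to shape-based inference: `dim` is the third positional
# argument in both split schemas and defaults to 0 when omitted.
dim = split_node.args[2] if len(split_node.args) > 2 else split_node.kwargs.get("dim", 0)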

input_shape = input_node.meta["val"].shape
for possible_dim in range(len(split_nodes_chunks[0].shape)):
if split_nodes_chunks[0].shape[possible_dim] != input_shape[possible_dim]:
dim = possible_dim
break

if dim == -1:
raise RuntimeError("Could not determine dim param")

# Get slices start, end params
starts = []
ends = []

curr_start = 0
for s in split_nodes_chunks:
starts.append(curr_start)
ends.append(curr_start + s.shape[dim])
curr_start += s.shape[dim]

return starts, ends, dim

def _replace_split_with_slices(self, input_node, split_node, starts, ends, dim):
# Replace getitem nodes after split with slices
getitem_nodes = list(split_node.users.keys())
slice_nodes = []
for i in range(len(starts)):
slice_arguments = (input_node, dim, starts[i], ends[i])
with self.graph_module.graph.inserting_after(split_node):
slice_node = self._create_slice_node(*slice_arguments)
slice_nodes.append(slice_node)

getitem_node = getitem_nodes[i]
getitem_node.replace_all_uses_with(slice_node)

self.graph_module.graph.erase_node(getitem_node)

# Wire split node correctly to the input node
split_node.replace_all_uses_with(input_node)
self.graph_module.graph.erase_node(split_node)

return True
Collaborator:
Why return True? I don't see any other possible return value.


def call(self, graph_module: GraphModule) -> Optional[PassResult]:
self.graph_module = graph_module
made_changes = False

if not any(map(self._is_regular_split, graph_module.graph.nodes)) and not any(
map(self._is_split_with_sizes, graph_module.graph.nodes)
):
return PassResult(graph_module, made_changes)

for node in graph_module.graph.nodes:
# Skip if not split
is_split_with_sizes = self._is_split_with_sizes(node)
is_regular_split = self._is_regular_split(node)

if not is_split_with_sizes and not is_regular_split:
continue

# Get split args
split_node = node
input_node = split_node.all_input_nodes[0]
split_nodes_chunks = split_node.meta["val"]

# Check if split is even necessary - if not, remove it
if len(split_nodes_chunks) == 1:
getitem_node = split_node.next
Collaborator:
The split_node.next is merely the node that is stored right after the split_node in the graph.nodes. The nodes don't have to be connected at all. Please use something like list(split_node.users)[0] instead.

getitem_node.replace_all_uses_with(input_node)

self.graph_module.graph.erase_node(getitem_node)
self.graph_module.graph.erase_node(split_node)

made_changes = True
continue

# Get arguments for the new slices
starts, ends, dim = self._get_slices_args(input_node, split_node)

# Replace split with slices and restructure the graph
made_changes = self._replace_split_with_slices(
input_node, split_node, starts, ends, dim
)

self.graph_module.recompile()
self.graph_module.graph.eliminate_dead_code()

return PassResult(self.graph_module, made_changes)
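
For reference, a small standalone check (not part of this PR) of the equivalence the pass's docstring describes: every chunk produced by `torch.split` can be reproduced by one `aten.slice_copy` call over the corresponding start/end range along the split dimension.

import torch

x = torch.randn(2, 7, 3)
dim = 1
sections = [2, 4, 1]  # must sum to x.shape[dim]

chunks = torch.split(x, sections, dim=dim)

# Rebuild every chunk with a single slice along `dim`, mirroring what the pass emits.
start = 0
for chunk, size in zip(chunks, sections):
    end = start + size
    sliced = torch.ops.aten.slice_copy.Tensor(x, dim, start, end)
    assert torch.equal(chunk, sliced)
    start = end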
4 changes: 4 additions & 0 deletions backends/nxp/aten_passes/neutron_aten_pass_manager.py
@@ -7,6 +9,9 @@

import torch

from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
DecomposeSplitToSlicesPass,
)
from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import (
FuseBatchNormWithConvPass,
)
@@ -42,6 +45,7 @@ def __init__(
self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None
):
passes: list[PassType] = passes or [
DecomposeSplitToSlicesPass(),
FuseBatchNormWithConvPass(),
FuseBatchNormWithLinearPass(),
SplitGroupConvolution(),
34 changes: 33 additions & 1 deletion backends/nxp/tests/models.py
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
from typing import Callable, Collection, Union
from typing import Callable, Collection, Tuple, Union

import torch

@@ -633,6 +633,38 @@ def forward(self, x):
return self.activation(x)


class GRUModel(nn.Module):
def __init__(self, num_layers=1):
super().__init__()
self.gru = torch.nn.GRU(8, 8, num_layers=num_layers)

def forward(self, input_):
# `input_` has shape [sequence_length, batch_size, input_size] ([8, 1, 8])
return self.gru(
input_, None
) # The initial hidden is `None`, which will result in a `Zeros` node being added.


class SplitWithSize(torch.nn.Module):
def __init__(self, split_size, dim):
super().__init__()
self.split_size = split_size
self.dim = dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
return torch.split(x, self.split_size, self.dim)


class SplitWithSections(torch.nn.Module):
def __init__(self, sections, dim):
super().__init__()
self.sections = sections
self.dim = dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
return torch.split(x, self.sections, self.dim)


class MiniConvNetWithRegressionHead(torch.nn.Module):
def __init__(self):
super().__init__()
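
For illustration, a minimal sketch (not the repository's actual test code) of exercising one of the new models against the pass; it assumes `torch.export` tracing and direct invocation of `DecomposeSplitToSlicesPass`, which may differ from how `backends/nxp/tests` drives it.

import torch

from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
    DecomposeSplitToSlicesPass,
)
from executorch.backends.nxp.tests.models import SplitWithSections

model = SplitWithSections(sections=[2, 3, 3], dim=1).eval()
example_input = (torch.randn(1, 8, 4),)

# Export to an ATen-level graph; export populates the `val` metadata the pass relies on.
exported = torch.export.export(model, example_input)
graph_module = exported.module()

# Run the pass directly on the exported GraphModule.
result = DecomposeSplitToSlicesPass().call(graph_module)

# After the pass, the split should be gone and slice_copy nodes should remain.
targets = {
    node.target
    for node in result.graph_module.graph.nodes
    if node.op == "call_function"
}
assert torch.ops.aten.split_with_sizes.default not in targets
assert torch.ops.aten.slice_copy.Tensor in targets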