-
Notifications
You must be signed in to change notification settings - Fork 788
NXP backend: added aten.split support #16276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,203 @@ | ||
| # Copyright 2025 NXP | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| from typing import List, Optional, Tuple, TypeAlias | ||
|
|
||
| import torch | ||
| from torch._subclasses import FakeTensor, FakeTensorMode | ||
| from torch.fx import GraphModule, Node | ||
| from torch.fx.passes.infra.pass_base import PassBase, PassResult | ||
|
|
||
|
|
||
| class DecomposeSplitToSlicesPass(PassBase): | ||
| """ | ||
| The `split` operator returns multiple tensors by partitioning `x` along `dim`. Each partitioning can be done | ||
| using one `slice` operator. Replacing the `split` operator with multiple `slice` operators will yield the same results. | ||
|
|
||
|
|
||
| │ | ||
| ┌─────────────▼─────────────┐ | ||
| │ x │ | ||
| └─────────────┬─────────────┘ | ||
| │ | ||
| ┌─────────────────────▼─────────────────────┐ | ||
| │ aten.split / aten.split_with_sizes │ | ||
| └─────────────────────┬─────────────────────┘ | ||
| │ | ||
| ┌────────────────────┼─────────────────────────┐ | ||
| │ │ │ | ||
| ┌────────▼────────┐ ┌────────▼────────┐ ┌────────▼────────┐ | ||
| │ getitem(0) │ │ getitem(1) │ ... │ getitem(N-1) │ | ||
| └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ | ||
| │ │ │ | ||
| ▼ ▼ ▼ | ||
| out0 out1 out(N-1) | ||
|
|
||
|
|
||
| | | ||
| | | ||
| replace with | ||
| | | ||
| | | ||
| ▼ | ||
|
|
||
|
|
||
| │ | ||
| ┌─────────────▼─────────────┐ | ||
| │ x │ | ||
| └─────────────┬─────────────┘ | ||
| │ | ||
| ┌────────────────────┼─────────────────────────┐ | ||
| │ │ │ | ||
| ┌────────▼────────┐ ┌────────▼────────┐ ┌────────▼────────┐ | ||
| │ aten.slice(x, │ │ aten.slice(x, │ ... │ (more slices) │ | ||
| │ dim,s0,e0 │ │ dim,s1,e1) │ ... │ │ | ||
| └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ | ||
| │ │ │ | ||
| │ │ │ | ||
| ▼ ▼ ▼ | ||
| out0 out1 outN-1 | ||
|
|
||
| """ | ||
|
|
||
| graph_module: GraphModule | ||
|
|
||
| @staticmethod | ||
| def _is_split_with_sizes(node: Node) -> bool: | ||
| return ( | ||
| node.op == "call_function" | ||
| and node.target == torch.ops.aten.split_with_sizes.default | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _is_regular_split(node: Node) -> bool: | ||
| is_split_tensor = ( | ||
| node.op == "call_function" and node.target == torch.ops.aten.split.Tensor | ||
| ) | ||
|
|
||
| is_split_default = ( | ||
| node.op == "call_function" and node.target == torch.ops.aten.split.default | ||
| ) | ||
|
|
||
| return is_split_tensor or is_split_default | ||
|
|
||
| def _create_slice_node(self, *slice_args) -> Node: | ||
| slice_target = torch.ops.aten.slice_copy.Tensor | ||
| slice_node = self.graph_module.graph.call_function(slice_target, slice_args) | ||
|
|
||
| slice_node.meta["source_fn_stack"] = [(slice_node.name, torch.slice_copy)] | ||
|
|
||
| x_val = slice_args[0].meta["val"] | ||
| with FakeTensorMode() as mode: | ||
| fake_input = FakeTensor.from_tensor( | ||
| torch.empty(x_val.shape, dtype=x_val.dtype), mode | ||
| ) | ||
| output_shape = slice_target(fake_input, *slice_args[1:]).shape | ||
| slice_node.meta["val"] = FakeTensor.from_tensor( | ||
| torch.empty(output_shape, dtype=x_val.dtype), mode | ||
| ) | ||
|
|
||
| return slice_node | ||
|
|
||
| SlicesArgs: TypeAlias = Tuple[List[int], List[int], int] | ||
|
|
||
| def _get_slices_args(self, input_node: Node, split_node: Node) -> SlicesArgs: | ||
| split_nodes_chunks = split_node.meta["val"] | ||
|
|
||
| # Sometimes chunks are in tuples | ||
| if isinstance(split_nodes_chunks, tuple): | ||
| split_nodes_chunks = list(split_nodes_chunks) | ||
|
|
||
| if not isinstance(split_nodes_chunks, list): | ||
| raise RuntimeError("Faulty split chunks") | ||
|
|
||
| # Get split dim | ||
| dim = -1 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are you guessing the |
||
| input_shape = input_node.meta["val"].shape | ||
| for possible_dim in range(len(split_nodes_chunks[0].shape)): | ||
| if split_nodes_chunks[0].shape[possible_dim] != input_shape[possible_dim]: | ||
| dim = possible_dim | ||
| break | ||
|
|
||
| if dim == -1: | ||
| raise RuntimeError("Could not determine dim param") | ||
|
|
||
| # Get slices start, end params | ||
| starts = [] | ||
| ends = [] | ||
|
|
||
| curr_start = 0 | ||
| for s in split_nodes_chunks: | ||
| starts.append(curr_start) | ||
| ends.append(curr_start + s.shape[dim]) | ||
| curr_start += s.shape[dim] | ||
|
|
||
| return starts, ends, dim | ||
|
|
||
| def _replace_split_with_slices(self, input_node, split_node, starts, ends, dim): | ||
| # Replace getitem nodes after split with slices | ||
| getitem_nodes = list(split_node.users.keys()) | ||
| slice_nodes = [] | ||
| for i in range(len(starts)): | ||
| slice_arguments = (input_node, dim, starts[i], ends[i]) | ||
| with self.graph_module.graph.inserting_after(split_node): | ||
| slice_node = self._create_slice_node(*slice_arguments) | ||
| slice_nodes.append(slice_node) | ||
|
|
||
| getitem_node = getitem_nodes[i] | ||
| getitem_node.replace_all_uses_with(slice_node) | ||
|
|
||
| self.graph_module.graph.erase_node(getitem_node) | ||
|
|
||
| # Wire split node correctly to the input node | ||
| split_node.replace_all_uses_with(input_node) | ||
| self.graph_module.graph.erase_node(split_node) | ||
|
|
||
| return True | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why |
||
|
|
||
| def call(self, graph_module: GraphModule) -> Optional[PassResult]: | ||
| self.graph_module = graph_module | ||
| made_changes = False | ||
|
|
||
| if not any(map(self._is_regular_split, graph_module.graph.nodes)) and not any( | ||
| map(self._is_split_with_sizes, graph_module.graph.nodes) | ||
| ): | ||
| return PassResult(graph_module, made_changes) | ||
|
|
||
| for node in graph_module.graph.nodes: | ||
| # Skip if not split | ||
| is_split_with_sizes = self._is_split_with_sizes(node) | ||
| is_regular_split = self._is_regular_split(node) | ||
|
|
||
| if not is_split_with_sizes and not is_regular_split: | ||
| continue | ||
|
|
||
| # Get split args | ||
| split_node = node | ||
| input_node = split_node.all_input_nodes[0] | ||
| split_nodes_chunks = split_node.meta["val"] | ||
|
|
||
| # Check if split is even necessary - if not, remove it | ||
| if len(split_nodes_chunks) == 1: | ||
| getitem_node = split_node.next | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
| getitem_node.replace_all_uses_with(input_node) | ||
|
|
||
| self.graph_module.graph.erase_node(getitem_node) | ||
| self.graph_module.graph.erase_node(split_node) | ||
|
|
||
| made_changes = True | ||
| continue | ||
|
|
||
| # Get arguments for the new slices | ||
| starts, ends, dim = self._get_slices_args(input_node, split_node) | ||
|
|
||
| # Replace split with slices and restructure the graph | ||
| made_changes = self._replace_split_with_slices( | ||
| input_node, split_node, starts, ends, dim | ||
| ) | ||
|
|
||
| self.graph_module.recompile() | ||
| self.graph_module.graph.eliminate_dead_code() | ||
|
|
||
| return PassResult(self.graph_module, made_changes) | ||
Uh oh!
There was an error while loading. Please reload this page.