From 207da9b7f6bbe38aba0b257388289e1e08b04279 Mon Sep 17 00:00:00 2001
From: Guan-Ming Chiu
Date: Wed, 7 Jan 2026 12:54:55 +0800
Subject: [PATCH 1/2] [Relax] Add FRelaxInferLayout for scatter_nd operator

---
 src/relax/op/tensor/manipulate.cc             | 23 ++++
 .../relax/test_transform_convert_layout.py    | 52 +++++++++++++++++++
 2 files changed, 75 insertions(+)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 7c5682d462fc..583b94937088 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2780,6 +2780,28 @@ StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) {
   return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
 }
 
+InferLayoutOutput InferLayoutScatterND(
+    const Call& call, const ffi::Map<String, Array<String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
+
+  LayoutDecision layout = data_layout;
+
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  return InferLayoutOutput({layout, indices_layout, updates_layout}, {layout}, Attrs(call->attrs));
+}
+
 TVM_REGISTER_OP("relax.scatter_nd")
     .set_attrs_type<ScatterNDAttrs>()
     .set_num_inputs(3)
@@ -2787,6 +2809,7 @@ TVM_REGISTER_OP("relax.scatter_nd")
     .add_argument("indices", "Tensor", "The indices tensor.")
     .add_argument("updates", "Tensor", "The input tensor of updates.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutScatterND)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.scatter_nd */
diff --git a/tests/python/relax/test_transform_convert_layout.py b/tests/python/relax/test_transform_convert_layout.py
index 26990bc44db3..221d680ebcfe 100644
--- a/tests/python/relax/test_transform_convert_layout.py
+++ b/tests/python/relax/test_transform_convert_layout.py
@@ -5382,5 +5382,57 @@ def main(
     verify(Input, Expected)
 
 
+def test_conv2d_scatter_nd():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"),
+            w: R.Tensor((4, 3, 3, 3), "float32"),
+            indices: R.Tensor((2, 1), "int64"),
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                updates: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(data)
+                gv = R.scatter_nd(data, indices, updates)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
+            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
+            indices: R.Tensor((2, 1), dtype="int64"),
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                updates: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(data)
+                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.scatter_nd(
+                    data, indices, updates, reduction="update"
+                )
+                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv)
+            return gv
+
+    verify(Input, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

From cc38602819866c96f25e54ae0523b46d75e6e08b Mon Sep 17 00:00:00 2001
From: Guan-Ming Chiu
Date: Wed, 7 Jan 2026 16:10:04 +0800
Subject: [PATCH 2/2] Apply review comments

---
 src/relax/op/tensor/manipulate.cc | 33 +++++++++++++++++++++++--------
 1 file changed, 25 insertions(+), 8 deletions(-)

diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index 583b94937088..3170b28eebf7 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2789,17 +2789,34 @@ InferLayoutOutput InferLayoutScatterND(
   LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
   LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
 
-  LayoutDecision layout = data_layout;
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
+  ICHECK(data_sinfo != nullptr) << "Invalid Call";
+  ICHECK(updates_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!data_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+  ICHECK(!updates_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
 
-  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
-    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
-    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
-    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
-    int ndim = tensor_sinfo->ndim;
-    layout = LayoutDecision(InitialLayout(ndim));
+  LayoutDecision layout = data_layout;
+  LayoutDecision out_updates_layout = updates_layout;
+
+  // Check if data has a sub-indexed layout
+  bool has_sub_indexed_layout = layout->layout.ndim() != layout->layout.ndim_primal();
+
+  if (has_sub_indexed_layout) {
+    // Fall back to initial layouts for both data and updates
+    layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
+    out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim));
+  } else if (data_sinfo->ndim == updates_sinfo->ndim) {
+    // When data and updates have the same rank, apply the same layout to both
+    out_updates_layout = layout;
+  } else {
+    // Different ranks - fall back to initial layouts for both
+    layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
+    out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim));
   }
 
-  return InferLayoutOutput({layout, indices_layout, updates_layout}, {layout}, Attrs(call->attrs));
+  return InferLayoutOutput({layout, indices_layout, out_updates_layout}, {layout},
+                           Attrs(call->attrs));
 }
 
 TVM_REGISTER_OP("relax.scatter_nd")