src/relax/op/tensor/manipulate.cc (+40, -0)
@@ -2780,13 +2780,53 @@ StructInfo InferStructInfoScatterND(const Call& call, const BlockBuilder& ctx) {
  return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
}

InferLayoutOutput InferLayoutScatterND(
    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
    const VarLayoutMap& var_layout_map) {
  ICHECK(NoDesiredLayout(call, desired_layouts));

  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
  LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
  LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);

  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
  const auto* updates_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[2]);
  ICHECK(data_sinfo != nullptr) << "Invalid Call";
  ICHECK(updates_sinfo != nullptr) << "Invalid Call";
  ICHECK(!data_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
  ICHECK(!updates_sinfo->IsUnknownNdim()) << "Only support static ndim for now";

  LayoutDecision layout = data_layout;
  LayoutDecision out_updates_layout = updates_layout;

  // Check if data has a sub-indexed (packed) layout, e.g. NCHW4c
  bool has_sub_indexed_layout = layout->layout.ndim() != layout->layout.ndim_primal();

  if (has_sub_indexed_layout) {
    // Fall back to initial layouts for both data and updates
    layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
    out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim));
  } else if (data_sinfo->ndim == updates_sinfo->ndim) {
    // When data and updates have the same rank, apply the same layout to both
    out_updates_layout = layout;
  } else {
    // Different ranks: fall back to initial layouts for both
    layout = LayoutDecision(InitialLayout(data_sinfo->ndim));
    out_updates_layout = LayoutDecision(InitialLayout(updates_sinfo->ndim));
  }

  return InferLayoutOutput({layout, indices_layout, out_updates_layout}, {layout},
                           Attrs(call->attrs));
}

TVM_REGISTER_OP("relax.scatter_nd")
    .set_attrs_type<ScatterNDAttrs>()
    .set_num_inputs(3)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("indices", "Tensor", "The indices tensor.")
    .add_argument("updates", "Tensor", "The input tensor of updates.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterND)
    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutScatterND)
    .set_attr<Bool>("FPurity", Bool(true));

/* relax.scatter_nd */
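
A minimal sketch (not part of the diff) of how the new FRelaxInferLayout hook is exercised end to end: running relax.transform.ConvertLayout over a module that feeds a conv2d result into scatter_nd. The desired-layout map {"relax.nn.conv2d": ["NHWC", "OHWI"]} is an assumption here, chosen to match what the test below expects the verify helper to pass.

from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 3, 28, 28), "float32"),
        w: R.Tensor((4, 3, 3, 3), "float32"),
        indices: R.Tensor((2, 1), "int64"),
    ) -> R.Tensor(None, "float32", ndim=4):
        with R.dataflow():
            data = R.nn.conv2d(x, w, out_dtype="float32")
            updates = R.nn.relu(data)
            gv = R.scatter_nd(data, indices, updates)
            R.output(gv)
        return gv

# conv2d is rewritten to NHWC/OHWI; InferLayoutScatterND then propagates the
# NHWC decision onto scatter_nd's data and updates (equal ranks), so no extra
# layout transforms are inserted between the two ops.
mod = relax.transform.ConvertLayout({"relax.nn.conv2d": ["NHWC", "OHWI"]})(Module)
mod.show()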
tests/python/relax/test_transform_convert_layout.py (+52, -0)
@@ -5382,5 +5382,57 @@ def main(
    verify(Input, Expected)


def test_conv2d_scatter_nd():
    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            indices: R.Tensor((2, 1), "int64"),
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                updates: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(data)
                gv = R.scatter_nd(data, indices, updates)
                R.output(gv)
            return gv

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
            indices: R.Tensor((2, 1), dtype="int64"),
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                updates: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(data)
                lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.scatter_nd(
                    data, indices, updates, reduction="update"
                )
                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv2, axes=[0, 3, 1, 2]
                )
                R.output(gv)
            return gv

    verify(Input, Expected)


if __name__ == "__main__":
    tvm.testing.main()