diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index f7d54a6216a7..6c8b20d99b92 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -2275,6 +2275,14 @@ def _empty_like(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] return self.block_builder.emit(relax.op.zeros_like(x)) + def _randn(self, node: fx.Node) -> relax.Var: + import torch + + dtype = self._convert_data_type( + node.kwargs.get("dtype", torch.get_default_dtype()), self.env + ) + return self.block_builder.emit(relax.op.zeros(node.args[0], dtype)) + def _eye(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) n = args[0] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index b6b9723c131f..748936090d75 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1476,6 +1476,7 @@ def create_convert_map( "ones_like.default": lambda node: self.block_builder.emit( relax.op.ones_like(self.env[node.args[0]]) ), + "randn.default": self._randn, "zero_.default": self._zeros_inplace, "zeros.default": self._zeros, "zeros_like.default": self._zeros_like, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 9f8842ddcb69..8d42d32c7679 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -8719,5 +8719,28 @@ def main( verify_model(UpsampleNearest2dSize(), example_args, {}, expected_size) +def test_randn(): + """Test basic torch.randn operation.""" + + class Randn(Module): + def forward(self, x): + return torch.randn(4, 8) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 3), dtype="float32") + ) -> R.Tuple(R.Tensor((4, 8), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((4, 8), dtype="float32") = R.zeros(R.shape([4, 8]), dtype="float32") + gv: R.Tuple(R.Tensor((4, 8), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(2, 3, dtype=torch.float32),) + verify_model(Randn(), example_args, {}, Expected) + + if __name__ == "__main__": tvm.testing.main()