From d570e5472b3067f3f47d7411d2ac3d31d4a6aa66 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 26 Dec 2025 19:38:28 +0800 Subject: [PATCH 1/2] Refactoring Duplicate cuBLAS/hipBLAS Tests --- .../python/relax/test_codegen_blas_common.py | 127 ++++++++++++++++++ tests/python/relax/test_codegen_cublas.py | 103 ++------------ tests/python/relax/test_codegen_hipblas.py | 66 ++------- 3 files changed, 149 insertions(+), 147 deletions(-) create mode 100644 tests/python/relax/test_codegen_blas_common.py diff --git a/tests/python/relax/test_codegen_blas_common.py b/tests/python/relax/test_codegen_blas_common.py new file mode 100644 index 000000000000..bc9022ed5ad4 --- /dev/null +++ b/tests/python/relax/test_codegen_blas_common.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Shared test utilities for cuBLAS and hipBLAS codegen tests.""" +import numpy as np + +import tvm +from tvm import relax +from tvm.relax.testing import get_relax_matmul_module + + +def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False): + dev = tvm.device(target, 0) + with tvm.transform.PassContext( + config={ + "relax.backend.use_cuda_graph": cuda_graph, + "relax.transform.apply_legalize_ops": legalize, + } + ): + ex = tvm.compile(mod, target) + vm = relax.VirtualMachine(ex, dev) + f = vm["main"] + inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] + + # For cuda graph, run the compiled function twice to make sure that we can launch the cached + # graph on the second run. + if cuda_graph: + f(*inputs) + + return f(*inputs).numpy() + + +def to_concrete_shape(symbolic_shape, var_table): + result = [] + for dim in symbolic_shape: + if not isinstance(dim, tvm.tir.expr.Var): + result.append(dim) + continue + + if dim not in var_table: + var_table[dim] = np.random.randint(10, 50) + result.append(var_table[dim]) + + return tuple(result) + + +def run_matmul_offload_test( + x_shape, + y_shape, + transpose_y, + epilogue, + in_dtype, + out_dtype, + epilogue_table, + partition_fn, + target, +): + """Shared test logic for matmul offload tests across different BLAS backends. + + Parameters + ---------- + x_shape : tuple + Shape of the first input tensor. + y_shape : tuple + Shape of the second input tensor. + transpose_y : bool + Whether to transpose the second input. + epilogue : str + Type of epilogue operation. + in_dtype : str + Input data type. + out_dtype : str + Output data type. + epilogue_table : dict + Mapping of epilogue names to (with_bias, activation) tuples. + partition_fn : callable + Function to partition the module for the specific BLAS backend. + target : str + Target device (e.g., "cuda" or "rocm"). 
+    """
+    with_bias, activation = epilogue_table[epilogue]
+    var_table = {}
+    concrete_x_shape = to_concrete_shape(x_shape, var_table)
+    concrete_y_shape = to_concrete_shape(y_shape, var_table)
+    x = np.random.randn(*concrete_x_shape).astype(in_dtype)
+    y = np.random.randn(*concrete_y_shape).astype(in_dtype)
+
+    if transpose_y:
+        y = np.swapaxes(y, -2, -1)
+        y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
+
+    if with_bias:
+        bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
+        args = (x, y, bias)
+    else:
+        bias = None
+        args = (x, y)
+
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=bias.shape if with_bias else None,
+        transposed_y=transpose_y,
+        activation=activation,
+    )
+
+    offload_mod = partition_fn(mod)
+    offload_mod = relax.transform.RunCodegen()(offload_mod)
+    out = build_and_run(offload_mod, args, target)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index 32666ebd1d8c..f305b47884ab 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -27,6 +27,8 @@
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import relax as relax_builder
 
+from test_codegen_blas_common import build_and_run, to_concrete_shape, run_matmul_offload_test
+
 try:
     import ml_dtypes
 except ImportError:
@@ -41,27 +43,6 @@ def reset_seed():
 pytestmark = tvm.testing.requires_cublas.marks()
 
 
-def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
-    dev = tvm.device(target, 0)
-    with tvm.transform.PassContext(
-        config={
-            "relax.backend.use_cuda_graph": cuda_graph,
-            "relax.transform.apply_legalize_ops": legalize,
-        }
-    ):
-        ex = tvm.compile(mod, target)
-    vm = relax.VirtualMachine(ex, dev)
-    f = vm["main"]
-    inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np]
-
-    # For cuda graph, run the compiled function twice to make sure that we can launch the cached
-    # graph on the second run.
- if cuda_graph: - f(*inputs) - - return f(*inputs).numpy() - - def get_result_with_relax_cublas_offload(mod, np_inputs, cuda_graph=False, bind_constants=False): mod = partition_for_cublas(mod, bind_constants=bind_constants) mod = relax.transform.RunCodegen()(mod) @@ -69,20 +50,6 @@ def get_result_with_relax_cublas_offload(mod, np_inputs, cuda_graph=False, bind_ return build_and_run(mod, np_inputs, "cuda", cuda_graph) -def _to_concrete_shape(symbolic_shape, var_table): - result = [] - for dim in symbolic_shape: - if not isinstance(dim, tvm.tir.expr.Var): - result.append(dim) - continue - - if dim not in var_table: - var_table[dim] = np.random.randint(10, 50) - result.append(var_table[dim]) - - return tuple(result) - - _vars = { "a": tvm.tir.expr.Var("a", "int64"), "b": tvm.tir.expr.Var("b", "int64"), @@ -204,39 +171,18 @@ def test_matmul_offload( in_dtype, out_dtype, ): - with_bias, activation = _epilogue_table[epilogue] - var_table = {} - concrete_x_shape = _to_concrete_shape(x_shape, var_table) - concrete_y_shape = _to_concrete_shape(y_shape, var_table) - x = np.random.randn(*concrete_x_shape).astype(in_dtype) - y = np.random.randn(*concrete_y_shape).astype(in_dtype) - - if transpose_y: - y = np.swapaxes(y, -2, -1) - y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2]) - - if with_bias: - bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype) - args = (x, y, bias) - else: - bias = None - args = (x, y) - - mod = get_relax_matmul_module( + run_matmul_offload_test( x_shape, y_shape, + transpose_y, + epilogue, in_dtype, out_dtype, - bias_shape=bias.shape if with_bias else None, - transposed_y=transpose_y, - activation=activation, + _epilogue_table, + partition_for_cublas, + "cuda", ) - out = get_result_with_relax_cublas_offload(mod, args) - ref = build_and_run(mod, args, "llvm", legalize=True) - - tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) - @pytest.mark.parametrize( "x_shape, y_shape, transpose_y, epilogue", @@ -265,39 +211,18 @@ def test_matmul_igemm_offload( ): in_dtype = "int8" out_dtype = "int32" - with_bias, activation = _epilogue_table[epilogue] - var_table = {} - concrete_x_shape = _to_concrete_shape(x_shape, var_table) - concrete_y_shape = _to_concrete_shape(y_shape, var_table) - x = np.random.randn(*concrete_x_shape).astype(in_dtype) - y = np.random.randn(*concrete_y_shape).astype(in_dtype) - - if transpose_y: - y = np.swapaxes(y, -2, -1) - y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2]) - - if with_bias: - bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype) - args = (x, y, bias) - else: - bias = None - args = (x, y) - - mod = get_relax_matmul_module( + run_matmul_offload_test( x_shape, y_shape, + transpose_y, + epilogue, in_dtype, out_dtype, - bias_shape=bias.shape if with_bias else None, - transposed_y=transpose_y, - activation=activation, + _epilogue_table, + partition_for_cublas, + "cuda", ) - out = get_result_with_relax_cublas_offload(mod, args) - ref = build_and_run(mod, args, "llvm", legalize=True) - - tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) - @tvm.testing.requires_cuda_compute_version(9) @pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed") diff --git a/tests/python/relax/test_codegen_hipblas.py b/tests/python/relax/test_codegen_hipblas.py index 286acc44f1f1..2569fbb242dd 100644 --- a/tests/python/relax/test_codegen_hipblas.py +++ b/tests/python/relax/test_codegen_hipblas.py @@ -25,6 +25,8 @@ from tvm.relax.testing import get_relax_matmul_module from tvm.script import relax as R +from 
test_codegen_blas_common import run_matmul_offload_test + try: import ml_dtypes except ImportError: @@ -39,37 +41,6 @@ def reset_seed(): pytestmark = tvm.testing.requires_hipblas.marks() -def build_and_run(mod, inputs_np, target, legalize=False): - dev = tvm.device(target, 0) - with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}): - ex = tvm.compile(mod, target) - vm = relax.VirtualMachine(ex, dev) - f = vm["main"] - inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np] - return f(*inputs).numpy() - - -def get_result_with_relax_cublas_offload(mod, np_inputs): - mod = partition_for_hipblas(mod) - mod = relax.transform.RunCodegen()(mod) - - return build_and_run(mod, np_inputs, "rocm") - - -def _to_concrete_shape(symbolic_shape, var_table): - result = [] - for dim in symbolic_shape: - if not isinstance(dim, tvm.tir.expr.Var): - result.append(dim) - continue - - if dim not in var_table: - var_table[dim] = np.random.randint(10, 50) - result.append(var_table[dim]) - - return tuple(result) - - _vars = { "a": tvm.tir.expr.Var("a", "int64"), "b": tvm.tir.expr.Var("b", "int64"), @@ -118,39 +89,18 @@ def test_matmul_offload( in_dtype, out_dtype, ): - with_bias, activation = _epilogue_table[epilogue] - var_table = {} - concrete_x_shape = _to_concrete_shape(x_shape, var_table) - concrete_y_shape = _to_concrete_shape(y_shape, var_table) - x = np.random.randn(*concrete_x_shape).astype(in_dtype) - y = np.random.randn(*concrete_y_shape).astype(in_dtype) - - if transpose_y: - y = np.swapaxes(y, -2, -1) - y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2]) - - if with_bias: - bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype) - args = (x, y, bias) - else: - bias = None - args = (x, y) - - mod = get_relax_matmul_module( + run_matmul_offload_test( x_shape, y_shape, + transpose_y, + epilogue, in_dtype, out_dtype, - bias_shape=bias.shape if with_bias else None, - transposed_y=transpose_y, - activation=activation, + _epilogue_table, + partition_for_hipblas, + "rocm", ) - out = get_result_with_relax_cublas_offload(mod, args) - ref = build_and_run(mod, args, "llvm", legalize=True) - - tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) - def test_hipblas_partition_matmul_without_bias(): # hipBLAS does not handle 2D bias (residual input) From 536237643449ad860bfc42536be3f266497be595 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Wed, 31 Dec 2025 22:49:14 +0800 Subject: [PATCH 2/2] Remove redundant decorator --- tests/python/relax/test_frontend_dynamo.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index b3eac1d42709..48350ac73611 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -319,7 +319,6 @@ def _convert_data_type(input_type): raise NotImplementedError("input_type {} is not handled yet".format(input_type)) -@tvm.testing.requires_gpu def test_ones(): import torch from torch.nn import Module @@ -350,7 +349,6 @@ def main( ) -@tvm.testing.requires_gpu def test_full(): import torch from torch.nn import Module @@ -381,7 +379,6 @@ def main( ) -@tvm.testing.requires_gpu def test_gelu(): import torch from torch.nn import Module @@ -433,7 +430,6 @@ def main( ) -@tvm.testing.requires_gpu def test_masked_fill(): import torch from torch.nn import Module @@ -470,7 +466,6 @@ def main( ) -@tvm.testing.requires_gpu def test_getitem(): import torch 
from torch.nn import Module @@ -546,7 +541,6 @@ def forward(self, input1): version.parse(torch_version) >= version.parse("2.6.0"), reason="Need to support dynamic arange in Relax", ) -@tvm.testing.requires_gpu def test_arange(): import torch from torch.nn import Module
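
Usage note (illustrative only, not part of the diff above): the shared helper keeps everything backend-specific behind the `partition_fn` and `target` arguments, which is why the cuBLAS tests pass `partition_for_cublas` with "cuda" while the hipBLAS tests pass `partition_for_hipblas` with "rocm". The one subtle helper is `to_concrete_shape`, which substitutes a random concrete size for each symbolic `tvm.tir.Var` dimension and reuses the same size when the same Var appears in several shapes. A minimal sketch, assuming `test_codegen_blas_common.py` is importable from the test directory:

import tvm
from test_codegen_blas_common import to_concrete_shape

m = tvm.tir.expr.Var("m", "int64")
var_table = {}

# Symbolic dimensions become random sizes in [10, 50); concrete dimensions pass through.
x_shape = to_concrete_shape((m, 64), var_table)  # e.g. (23, 64)
y_shape = to_concrete_shape((64, m), var_table)  # the value chosen for m is reused

assert x_shape[0] == y_shape[1]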