127 changes: 127 additions & 0 deletions tests/python/relax/test_codegen_blas_common.py
@@ -0,0 +1,127 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Shared test utilities for cuBLAS and hipBLAS codegen tests."""
import numpy as np

import tvm
from tvm import relax
from tvm.relax.testing import get_relax_matmul_module


def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
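    """Compile mod for target and run its "main" function on the NumPy inputs.

    legalize and cuda_graph toggle the corresponding PassContext options; with
    cuda_graph enabled the function is executed twice so the cached graph is exercised.
    """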
dev = tvm.device(target, 0)
with tvm.transform.PassContext(
config={
"relax.backend.use_cuda_graph": cuda_graph,
"relax.transform.apply_legalize_ops": legalize,
}
):
ex = tvm.compile(mod, target)
vm = relax.VirtualMachine(ex, dev)
f = vm["main"]
inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np]

# For cuda graph, run the compiled function twice to make sure that we can launch the cached
# graph on the second run.
if cuda_graph:
f(*inputs)

return f(*inputs).numpy()


def to_concrete_shape(symbolic_shape, var_table):
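    """Replace symbolic tir.Var dimensions with random concrete sizes.

    Drawn values are cached in var_table so that repeated variables map to the
    same concrete extent.
    """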
result = []
for dim in symbolic_shape:
if not isinstance(dim, tvm.tir.expr.Var):
result.append(dim)
continue

if dim not in var_table:
var_table[dim] = np.random.randint(10, 50)
result.append(var_table[dim])

return tuple(result)


def run_matmul_offload_test(
x_shape,
y_shape,
transpose_y,
epilogue,
in_dtype,
out_dtype,
epilogue_table,
partition_fn,
target,
):
"""Shared test logic for matmul offload tests across different BLAS backends.

Parameters
----------
x_shape : tuple
Shape of the first input tensor.
y_shape : tuple
Shape of the second input tensor.
transpose_y : bool
Whether to transpose the second input.
epilogue : str
Type of epilogue operation.
in_dtype : str
Input data type.
out_dtype : str
Output data type.
epilogue_table : dict
Mapping of epilogue names to (with_bias, activation) tuples.
partition_fn : callable
Function to partition the module for the specific BLAS backend.
target : str
Target device (e.g., "cuda" or "rocm").
"""
with_bias, activation = epilogue_table[epilogue]
var_table = {}
concrete_x_shape = to_concrete_shape(x_shape, var_table)
concrete_y_shape = to_concrete_shape(y_shape, var_table)
x = np.random.randn(*concrete_x_shape).astype(in_dtype)
y = np.random.randn(*concrete_y_shape).astype(in_dtype)

if transpose_y:
y = np.swapaxes(y, -2, -1)
y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])
Review comment from a contributor on lines +101 to +103 (severity: medium):

The transpose_y handling is correct but confusing. y_shape first describes the shape of the y operand of the matmul, and is then reassigned to describe the shape of the y argument of the Relax function, which differs when transpose_y is true. This dual-purpose use of y_shape makes the code hard to reason about.

A refactor using separate variables (e.g., y_matmul_shape and y_arg_shape) would greatly improve readability and maintainability.
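
A minimal sketch of that suggestion, for illustration only; y_matmul_shape and y_arg_shape are the hypothetical names proposed above, not identifiers in this patch:

# Hypothetical refactor of run_matmul_offload_test, sketched from the comment above;
# only the shape bookkeeping changes, the surrounding logic stays as in the patch.
y_matmul_shape = y_shape            # shape of the y operand as the matmul sees it
y_arg_shape = y_matmul_shape        # shape of the y argument of the Relax function
if transpose_y:
    y = np.swapaxes(y, -2, -1)
    y_arg_shape = (*y_matmul_shape[:-2], y_matmul_shape[-1], y_matmul_shape[-2])

# ... bias and args handling unchanged ...

mod = get_relax_matmul_module(
    x_shape,
    y_arg_shape,
    in_dtype,
    out_dtype,
    bias_shape=bias.shape if with_bias else None,
    transposed_y=transpose_y,
    activation=activation,
)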


if with_bias:
bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
args = (x, y, bias)
else:
bias = None
args = (x, y)

mod = get_relax_matmul_module(
x_shape,
y_shape,
in_dtype,
out_dtype,
bias_shape=bias.shape if with_bias else None,
transposed_y=transpose_y,
activation=activation,
)

    # Build the offloaded module for the BLAS target, but keep the original module
    # so the reference result comes from the legalized CPU build.
    offloaded = partition_fn(mod)
    offloaded = relax.transform.RunCodegen()(offloaded)
    out = build_and_run(offloaded, args, target)
    ref = build_and_run(mod, args, "llvm", legalize=True)

tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
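
For reference, a minimal sketch of how a backend-specific test might drive this helper. The epilogue entries, the shapes, and the import path of partition_for_cublas are illustrative assumptions, not taken from this patch:

# Illustrative usage sketch; epilogue entries, shapes, and the partition import path
# are assumptions, not part of this patch.
from tvm.script import relax as R

from test_codegen_blas_common import run_matmul_offload_test

try:  # the partition helper has moved between TVM versions; hedge both locations
    from tvm.relax.backend.cuda.cublas import partition_for_cublas
except ImportError:
    from tvm.relax.backend.contrib.cublas import partition_for_cublas

_epilogue_table = {
    "none": (False, None),
    "bias": (True, None),
    "relu": (True, R.nn.relu),
}

run_matmul_offload_test(
    x_shape=(8, 32),
    y_shape=(32, 16),
    transpose_y=False,
    epilogue="bias",
    in_dtype="float16",
    out_dtype="float16",
    epilogue_table=_epilogue_table,
    partition_fn=partition_for_cublas,
    target="cuda",
)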
103 changes: 14 additions & 89 deletions tests/python/relax/test_codegen_cublas.py
@@ -27,6 +27,8 @@
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder

from test_codegen_blas_common import build_and_run, to_concrete_shape, run_matmul_offload_test

try:
import ml_dtypes
except ImportError:
@@ -41,48 +43,13 @@ def reset_seed():
pytestmark = tvm.testing.requires_cublas.marks()


def build_and_run(mod, inputs_np, target, legalize=False, cuda_graph=False):
dev = tvm.device(target, 0)
with tvm.transform.PassContext(
config={
"relax.backend.use_cuda_graph": cuda_graph,
"relax.transform.apply_legalize_ops": legalize,
}
):
ex = tvm.compile(mod, target)
vm = relax.VirtualMachine(ex, dev)
f = vm["main"]
inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np]

# For cuda graph, run the compiled function twice to make sure that we can launch the cached
# graph on the second run.
if cuda_graph:
f(*inputs)

return f(*inputs).numpy()


def get_result_with_relax_cublas_offload(mod, np_inputs, cuda_graph=False, bind_constants=False):
mod = partition_for_cublas(mod, bind_constants=bind_constants)
mod = relax.transform.RunCodegen()(mod)

return build_and_run(mod, np_inputs, "cuda", cuda_graph)


def _to_concrete_shape(symbolic_shape, var_table):
result = []
for dim in symbolic_shape:
if not isinstance(dim, tvm.tir.expr.Var):
result.append(dim)
continue

if dim not in var_table:
var_table[dim] = np.random.randint(10, 50)
result.append(var_table[dim])

return tuple(result)


_vars = {
"a": tvm.tir.expr.Var("a", "int64"),
"b": tvm.tir.expr.Var("b", "int64"),
@@ -204,39 +171,18 @@ def test_matmul_offload(
in_dtype,
out_dtype,
):
with_bias, activation = _epilogue_table[epilogue]
var_table = {}
concrete_x_shape = _to_concrete_shape(x_shape, var_table)
concrete_y_shape = _to_concrete_shape(y_shape, var_table)
x = np.random.randn(*concrete_x_shape).astype(in_dtype)
y = np.random.randn(*concrete_y_shape).astype(in_dtype)

if transpose_y:
y = np.swapaxes(y, -2, -1)
y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])

if with_bias:
bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
args = (x, y, bias)
else:
bias = None
args = (x, y)

mod = get_relax_matmul_module(
run_matmul_offload_test(
x_shape,
y_shape,
transpose_y,
epilogue,
in_dtype,
out_dtype,
bias_shape=bias.shape if with_bias else None,
transposed_y=transpose_y,
activation=activation,
_epilogue_table,
partition_for_cublas,
"cuda",
)

out = get_result_with_relax_cublas_offload(mod, args)
ref = build_and_run(mod, args, "llvm", legalize=True)

tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, epilogue",
@@ -265,39 +211,18 @@ def test_matmul_igemm_offload(
):
in_dtype = "int8"
out_dtype = "int32"
with_bias, activation = _epilogue_table[epilogue]
var_table = {}
concrete_x_shape = _to_concrete_shape(x_shape, var_table)
concrete_y_shape = _to_concrete_shape(y_shape, var_table)
x = np.random.randn(*concrete_x_shape).astype(in_dtype)
y = np.random.randn(*concrete_y_shape).astype(in_dtype)

if transpose_y:
y = np.swapaxes(y, -2, -1)
y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])

if with_bias:
bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
args = (x, y, bias)
else:
bias = None
args = (x, y)

mod = get_relax_matmul_module(
run_matmul_offload_test(
x_shape,
y_shape,
transpose_y,
epilogue,
in_dtype,
out_dtype,
bias_shape=bias.shape if with_bias else None,
transposed_y=transpose_y,
activation=activation,
_epilogue_table,
partition_for_cublas,
"cuda",
)

out = get_result_with_relax_cublas_offload(mod, args)
ref = build_and_run(mod, args, "llvm", legalize=True)

tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


@tvm.testing.requires_cuda_compute_version(9)
@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
66 changes: 8 additions & 58 deletions tests/python/relax/test_codegen_hipblas.py
@@ -25,6 +25,8 @@
from tvm.relax.testing import get_relax_matmul_module
from tvm.script import relax as R

from test_codegen_blas_common import run_matmul_offload_test

try:
import ml_dtypes
except ImportError:
@@ -39,37 +41,6 @@ def reset_seed():
pytestmark = tvm.testing.requires_hipblas.marks()


def build_and_run(mod, inputs_np, target, legalize=False):
dev = tvm.device(target, 0)
with tvm.transform.PassContext(config={"relax.transform.apply_legalize_ops": legalize}):
ex = tvm.compile(mod, target)
vm = relax.VirtualMachine(ex, dev)
f = vm["main"]
inputs = [tvm.runtime.tensor(inp, dev) for inp in inputs_np]
return f(*inputs).numpy()


def get_result_with_relax_cublas_offload(mod, np_inputs):
mod = partition_for_hipblas(mod)
mod = relax.transform.RunCodegen()(mod)

return build_and_run(mod, np_inputs, "rocm")


def _to_concrete_shape(symbolic_shape, var_table):
result = []
for dim in symbolic_shape:
if not isinstance(dim, tvm.tir.expr.Var):
result.append(dim)
continue

if dim not in var_table:
var_table[dim] = np.random.randint(10, 50)
result.append(var_table[dim])

return tuple(result)


_vars = {
"a": tvm.tir.expr.Var("a", "int64"),
"b": tvm.tir.expr.Var("b", "int64"),
@@ -118,39 +89,18 @@ def test_matmul_offload(
in_dtype,
out_dtype,
):
with_bias, activation = _epilogue_table[epilogue]
var_table = {}
concrete_x_shape = _to_concrete_shape(x_shape, var_table)
concrete_y_shape = _to_concrete_shape(y_shape, var_table)
x = np.random.randn(*concrete_x_shape).astype(in_dtype)
y = np.random.randn(*concrete_y_shape).astype(in_dtype)

if transpose_y:
y = np.swapaxes(y, -2, -1)
y_shape = (*y_shape[:-2], y_shape[-1], y_shape[-2])

if with_bias:
bias = np.random.randn(concrete_y_shape[-1]).astype(out_dtype)
args = (x, y, bias)
else:
bias = None
args = (x, y)

mod = get_relax_matmul_module(
run_matmul_offload_test(
x_shape,
y_shape,
transpose_y,
epilogue,
in_dtype,
out_dtype,
bias_shape=bias.shape if with_bias else None,
transposed_y=transpose_y,
activation=activation,
_epilogue_table,
partition_for_hipblas,
"rocm",
)

out = get_result_with_relax_cublas_offload(mod, args)
ref = build_and_run(mod, args, "llvm", legalize=True)

tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)


def test_hipblas_partition_matmul_without_bias():
# hipBLAS does not handle 2D bias (residual input)