Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions gpu_test/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,43 @@ def test_multi_param(kernel_runner: KernelRunner) -> None:
assert result == [20, 40, 60, 80]


# --- Matmul ---


def test_naive_matmul_i64(kernel_runner: KernelRunner) -> None:
"""Naive i64 matmul: C = A(2x4) * B(4x3) -> C(2x3)."""
# Work partition: one thread per output element.
# GLOBAL-ID maps to (row, col) with row = gid / N, col = gid MOD N.
result = kernel_runner.run(
forth_source=(
"PARAM A 8\n"
"PARAM B 12\n"
"PARAM C 6\n"
"GLOBAL-ID\n"
"DUP 3 /\n"
"SWAP 3 MOD\n"
"0\n"
"4 0 DO\n"
"2 PICK\n"
"I SWAP 4 * +\n"
"CELLS A + @\n"
"I 3 * 3 PICK + CELLS B + @\n"
"* +\n"
"LOOP\n"
"2 PICK 3 * 2 PICK +\n"
"CELLS C + !"
),
params={
"A": [1, 2, 3, 4, 5, 6, 7, 8],
"B": [1, 0, 2, 0, 1, 2, 1, 0, 1, 2, 1, 0],
},
block=(6, 1, 1),
output_param=2,
output_count=6,
)
assert result == [12, 6, 9, 28, 14, 29]


# --- User-Defined Words ---


Expand Down
27 changes: 27 additions & 0 deletions test/Pipeline/matmul-naive.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --convert-forth-to-memref --convert-forth-to-gpu | %FileCheck %s --check-prefix=MID

\ Verify that a naive integer matmul kernel survives the full pipeline.
\ CHECK: gpu.binary @warpforth_module

\ Verify the kernel signature at the memref+gpu stage.
\ MID: gpu.func @main(%arg0: memref<8xi64> {forth.param_name = "A"}, %arg1: memref<12xi64> {forth.param_name = "B"}, %arg2: memref<6xi64> {forth.param_name = "C"}) kernel

PARAM A 8
PARAM B 12
PARAM C 6

\ M=2, N=3, K=4. One thread computes C[row, col] where gid = row*N + col.
GLOBAL-ID
DUP 3 /
SWAP 3 MOD
0
4 0 DO
2 PICK
I SWAP 4 * +
CELLS A + @
I 3 * 3 PICK + CELLS B + @
* +
LOOP
2 PICK 3 * 2 PICK +
CELLS C + !