diff --git a/gpu_test/test_kernels.py b/gpu_test/test_kernels.py index 99ab192..dc0bdf7 100644 --- a/gpu_test/test_kernels.py +++ b/gpu_test/test_kernels.py @@ -234,6 +234,43 @@ def test_multi_param(kernel_runner: KernelRunner) -> None: assert result == [20, 40, 60, 80] +# --- Matmul --- + + +def test_naive_matmul_i64(kernel_runner: KernelRunner) -> None: + """Naive i64 matmul: C = A(2x4) * B(4x3) -> C(2x3).""" + # Work partition: one thread per output element. + # GLOBAL-ID maps to (row, col) with row = gid / N, col = gid MOD N. + result = kernel_runner.run( + forth_source=( + "PARAM A 8\n" + "PARAM B 12\n" + "PARAM C 6\n" + "GLOBAL-ID\n" + "DUP 3 /\n" + "SWAP 3 MOD\n" + "0\n" + "4 0 DO\n" + "2 PICK\n" + "I SWAP 4 * +\n" + "CELLS A + @\n" + "I 3 * 3 PICK + CELLS B + @\n" + "* +\n" + "LOOP\n" + "2 PICK 3 * 2 PICK +\n" + "CELLS C + !" + ), + params={ + "A": [1, 2, 3, 4, 5, 6, 7, 8], + "B": [1, 0, 2, 0, 1, 2, 1, 0, 1, 2, 1, 0], + }, + block=(6, 1, 1), + output_param=2, + output_count=6, + ) + assert result == [12, 6, 9, 28, 14, 29] + + # --- User-Defined Words --- diff --git a/test/Pipeline/matmul-naive.forth b/test/Pipeline/matmul-naive.forth new file mode 100644 index 0000000..fe87d58 --- /dev/null +++ b/test/Pipeline/matmul-naive.forth @@ -0,0 +1,27 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --convert-forth-to-memref --convert-forth-to-gpu | %FileCheck %s --check-prefix=MID + +\ Verify that a naive integer matmul kernel survives the full pipeline. +\ CHECK: gpu.binary @warpforth_module + +\ Verify the kernel signature at the memref+gpu stage. +\ MID: gpu.func @main(%arg0: memref<8xi64> {forth.param_name = "A"}, %arg1: memref<12xi64> {forth.param_name = "B"}, %arg2: memref<6xi64> {forth.param_name = "C"}) kernel + +PARAM A 8 +PARAM B 12 +PARAM C 6 + +\ M=2, N=3, K=4. One thread computes C[row, col] where gid = row*N + col. +GLOBAL-ID +DUP 3 / +SWAP 3 MOD +0 +4 0 DO + 2 PICK + I SWAP 4 * + + CELLS A + @ + I 3 * 3 PICK + CELLS B + @ + * + +LOOP +2 PICK 3 * 2 PICK + +CELLS C + !