Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ uv run ruff format gpu_test/

- **Stack Type**: `!forth.stack` - untyped stack, programmer ensures type safety
- **Operations**: All take stack as input and produce stack as output (except `forth.stack`)
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP I J K`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Supported Words**: literals, `DUP DROP SWAP OVER ROT NIP TUCK PICK ROLL`, `+ - * / MOD`, `AND OR XOR NOT LSHIFT RSHIFT`, `= < > <> <= >= 0=`, `@ !`, `CELLS`, `IF ELSE THEN`, `BEGIN UNTIL`, `BEGIN WHILE REPEAT`, `DO LOOP +LOOP I J K`, `LEAVE UNLOOP EXIT`, `TID-X/Y/Z BID-X/Y/Z BDIM-X/Y/Z GDIM-X/Y/Z GLOBAL-ID` (GPU indexing).
- **Kernel Parameters**: Declared with `PARAM <name> <size>`, each becomes a `memref<Nxi64>` function argument with `forth.param_name` attribute. Using a param name in code pushes its byte address onto the stack via `forth.param_ref`
- **Conversion**: `!forth.stack` → `memref<256xi64>` with explicit stack pointer
- **GPU**: Functions wrapped in `gpu.module`, `main` gets `gpu.kernel` attribute, configured with bare pointers for NVVM conversion
Expand Down
18 changes: 18 additions & 0 deletions gpu_test/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,24 @@ def test_do_loop(kernel_runner: KernelRunner) -> None:
assert result == [0, 1, 2, 3, 4]


def test_do_plus_loop(kernel_runner: KernelRunner) -> None:
"""DO/+LOOP: write I values 0, 2, 4, 6, 8 to DATA[0..4]."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n0\n10 0 DO\n I OVER CELLS DATA + !\n 1 +\n2 +LOOP\nDROP"),
output_count=5,
)
assert result == [0, 2, 4, 6, 8]


def test_do_plus_loop_negative(kernel_runner: KernelRunner) -> None:
"""DO/+LOOP with negative step: count down from 10 to 1."""
result = kernel_runner.run(
forth_source=("PARAM DATA 256\n0\n0 10 DO\n I OVER CELLS DATA + !\n 1 +\n-1 +LOOP\nDROP"),
output_count=10,
)
assert result == [10, 9, 8, 7, 6, 5, 4, 3, 2, 1]


def test_multi_while(kernel_runner: KernelRunner) -> None:
"""Multi-WHILE: two exit conditions from the same loop (interleaved CF).

Expand Down
81 changes: 51 additions & 30 deletions lib/Translation/ForthToMLIR/ForthToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,36 @@ std::pair<Value, Value> ForthParser::emitPopFlag(Location loc, Value stack) {
return {popFlag.getOutputStack(), popFlag.getFlag()};
}

void ForthParser::emitLoopEnd(Location loc, const LoopContext &ctx, Value step,
Value &stack) {
auto i64Type = builder.getI64Type();

// Load old counter, compute new = old + step, store.
Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value oldIdx =
builder.create<memref::LoadOp>(loc, ctx.counter, ValueRange{c0});
Value newIdx = builder.create<arith::AddIOp>(loc, oldIdx, step);
builder.create<memref::StoreOp>(loc, newIdx, ctx.counter, ValueRange{c0});

// Crossing test: ((oldIdx - limit) XOR (newIdx - limit)) < 0
// This correctly handles both positive and negative step values.
Value oldDiff = builder.create<arith::SubIOp>(loc, oldIdx, ctx.limit);
Value newDiff = builder.create<arith::SubIOp>(loc, newIdx, ctx.limit);
Value xorVal = builder.create<arith::XOrIOp>(loc, oldDiff, newDiff);
Value zero = builder.create<arith::ConstantOp>(loc, i64Type,
builder.getI64IntegerAttr(0));
Value crossed = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
xorVal, zero);

// If crossed → exit, otherwise → loop back to body.
builder.create<cf::CondBranchOp>(loc, crossed, ctx.exit, ValueRange{stack},
ctx.body, ValueRange{stack});

// Continue after exit.
builder.setInsertionPointToStart(ctx.exit);
stack = ctx.exit->getArgument(0);
}

LogicalResult ForthParser::parseBody(Value &stack) {
Type stackType = forth::StackType::get(context);

Expand Down Expand Up @@ -644,27 +674,15 @@ LogicalResult ForthParser::parseBody(Value &stack) {
Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
builder.create<memref::StoreOp>(loc, start, counter, ValueRange{c0});

// Create check, body, and exit blocks.
auto *checkBlock = createStackBlock(parentRegion, loc);
// Create body and exit blocks (post-test loop: always enters once).
auto *bodyBlock = createStackBlock(parentRegion, loc);
auto *exitBlock = createStackBlock(parentRegion, loc);

// Branch to check.
builder.create<cf::BranchOp>(loc, checkBlock, ValueRange{s2});

// --- Check block: load counter, compare < limit ---
builder.setInsertionPointToStart(checkBlock);
Value checkC0 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value idx =
builder.create<memref::LoadOp>(loc, counter, ValueRange{checkC0});
Value cond = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, idx, limit);
builder.create<cf::CondBranchOp>(
loc, cond, bodyBlock, ValueRange{checkBlock->getArgument(0)},
exitBlock, ValueRange{checkBlock->getArgument(0)});
// Branch directly to body.
builder.create<cf::BranchOp>(loc, bodyBlock, ValueRange{s2});

// Push loop context for I/J/K.
loopStack.push_back({counter, limit, checkBlock, exitBlock});
loopStack.push_back({counter, limit, bodyBlock, exitBlock});

// Continue parsing in body.
builder.setInsertionPointToStart(bodyBlock);
Expand All @@ -673,29 +691,32 @@ LogicalResult ForthParser::parseBody(Value &stack) {
//=== LOOP ===
} else if (word == "LOOP") {
consume();
auto i64Type = builder.getI64Type();

if (loopStack.empty()) {
return emitError("LOOP without matching DO");
}

auto ctx = loopStack.pop_back_val();

// Increment counter: load, add 1, store.
Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value idx =
builder.create<memref::LoadOp>(loc, ctx.counter, ValueRange{c0});
Value one = builder.create<arith::ConstantOp>(
loc, i64Type, builder.getI64IntegerAttr(1));
Value next = builder.create<arith::AddIOp>(loc, idx, one);
builder.create<memref::StoreOp>(loc, next, ctx.counter, ValueRange{c0});
loc, builder.getI64Type(), builder.getI64IntegerAttr(1));
emitLoopEnd(loc, ctx, one, stack);

// Branch back to check.
builder.create<cf::BranchOp>(loc, ctx.check, ValueRange{stack});
//=== +LOOP ===
} else if (word == "+LOOP") {
consume();

// Continue after exit.
builder.setInsertionPointToStart(ctx.exit);
stack = ctx.exit->getArgument(0);
if (loopStack.empty()) {
return emitError("+LOOP without matching DO");
}

auto ctx = loopStack.pop_back_val();

// Pop step from data stack.
auto popOp = builder.create<forth::PopOp>(
loc, forth::StackType::get(context), builder.getI64Type(), stack);
stack = popOp.getOutputStack();
Value step = popOp.getValue();
emitLoopEnd(loc, ctx, step, stack);

//=== Normal word ===
} else {
Expand Down
7 changes: 6 additions & 1 deletion lib/Translation/ForthToMLIR/ForthToMLIR.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class ForthParser {
struct LoopContext {
Value counter; // memref<1xi64> alloca for the loop counter
Value limit; // i64 loop limit
Block *check; // condition check block
Block *body; // loop body block
Block *exit; // loop exit block
};
SmallVector<LoopContext> loopStack;
Expand Down Expand Up @@ -127,6 +127,11 @@ class ForthParser {
/// Parse a sequence of Forth operations, handling control flow inline.
LogicalResult parseBody(Value &stack);

/// Emit the common loop-end logic for LOOP and +LOOP:
/// load counter, add step, store, crossing test, cond_br to exit or body.
void emitLoopEnd(Location loc, const LoopContext &ctx, Value step,
Value &stack);

/// Parse a user-defined word definition.
LogicalResult parseWordDefinition();
};
Expand Down
44 changes: 20 additions & 24 deletions test/Conversion/ForthToMemRef/do-loop.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s

// Test: DO...LOOP with I conversion to memref with CF-based control flow
// Test: DO...LOOP with I conversion to memref with post-test crossing check
// Forth: 10 0 DO I LOOP

// CHECK-LABEL: func.func private @main
Expand All @@ -23,24 +23,20 @@
// CHECK: memref.store %{{.*}}, %[[COUNTER]]
// CHECK: cf.br ^bb1

// Loop header: load counter, compare < limit, cond_br
// Loop body: push I (load counter, push to stack), crossing test
// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index):
// CHECK: memref.load %[[COUNTER]]
// CHECK: arith.cmpi slt
// CHECK: cf.cond_br %{{.*}}, ^bb2(%{{.*}}: memref<256xi64>, index), ^bb3(%{{.*}}: memref<256xi64>, index)

// Loop body: push I (load counter, push to stack), increment counter
// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index):
// CHECK: memref.load %[[COUNTER]]
// CHECK: memref.store
// CHECK: memref.load %[[COUNTER]]
// CHECK: arith.constant 1 : i64
// CHECK: arith.addi
// CHECK: memref.store %{{.*}}, %[[COUNTER]]
// CHECK: cf.br ^bb1
// CHECK: arith.subi
// CHECK: arith.subi
// CHECK: arith.xori
// CHECK: arith.cmpi slt
// CHECK: cf.cond_br

// Exit block
// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index):
// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index):
// CHECK: return

module {
Expand All @@ -57,19 +53,19 @@ module {
^bb1(%3: !forth.stack):
%c0_2 = arith.constant 0 : index
%4 = memref.load %alloca[%c0_2] : memref<1xi64>
%5 = arith.cmpi slt, %4, %value_1 : i64
cf.cond_br %5, ^bb2(%3 : !forth.stack), ^bb3(%3 : !forth.stack)
^bb2(%6: !forth.stack):
%c0_3 = arith.constant 0 : index
%7 = memref.load %alloca[%c0_3] : memref<1xi64>
%8 = forth.push_value %6, %7 : !forth.stack, i64 -> !forth.stack
%c0_4 = arith.constant 0 : index
%9 = memref.load %alloca[%c0_4] : memref<1xi64>
%5 = forth.push_value %3, %4 : !forth.stack, i64 -> !forth.stack
%c1_i64 = arith.constant 1 : i64
%10 = arith.addi %9, %c1_i64 : i64
memref.store %10, %alloca[%c0_4] : memref<1xi64>
cf.br ^bb1(%8 : !forth.stack)
^bb3(%11: !forth.stack):
%c0_3 = arith.constant 0 : index
%6 = memref.load %alloca[%c0_3] : memref<1xi64>
%7 = arith.addi %6, %c1_i64 : i64
memref.store %7, %alloca[%c0_3] : memref<1xi64>
%8 = arith.subi %6, %value_1 : i64
%9 = arith.subi %7, %value_1 : i64
%10 = arith.xori %8, %9 : i64
%c0_i64 = arith.constant 0 : i64
%11 = arith.cmpi slt, %10, %c0_i64 : i64
cf.cond_br %11, ^bb2(%5 : !forth.stack), ^bb1(%5 : !forth.stack)
^bb2(%12: !forth.stack):
return
}
}
26 changes: 16 additions & 10 deletions test/Conversion/ForthToMemRef/leave.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
// CHECK: %[[STACK:.*]] = memref.alloca() : memref<256xi64>
// CHECK: cf.br ^bb1(%[[STACK]], %{{.*}} : memref<256xi64>, index)
// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index):
// CHECK: cf.cond_br %{{.*}}, ^bb2(%{{.*}}: memref<256xi64>, index), ^bb3(%{{.*}}: memref<256xi64>, index)
// CHECK: cf.cond_br %true, ^bb2(%{{.*}}: memref<256xi64>, index), ^bb3(%{{.*}}: memref<256xi64>, index)
// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index):
// CHECK-NEXT: cf.br ^bb3(%{{.*}}: memref<256xi64>, index)
// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index):
// CHECK: return

module {
Expand All @@ -25,13 +23,21 @@ module {
memref.store %value, %alloca[%c0] : memref<1xi64>
cf.br ^bb1(%output_stack_0 : !forth.stack)
^bb1(%3: !forth.stack):
%c0_2 = arith.constant 0 : index
%4 = memref.load %alloca[%c0_2] : memref<1xi64>
%5 = arith.cmpi slt, %4, %value_1 : i64
cf.cond_br %5, ^bb2(%3 : !forth.stack), ^bb3(%3 : !forth.stack)
^bb2(%6: !forth.stack):
cf.br ^bb3(%6 : !forth.stack)
^bb3(%7: !forth.stack):
%true = arith.constant true
cf.cond_br %true, ^bb2(%3 : !forth.stack), ^bb3(%3 : !forth.stack)
^bb2(%4: !forth.stack):
return
^bb3(%5: !forth.stack):
%c1_i64 = arith.constant 1 : i64
%c0_2 = arith.constant 0 : index
%6 = memref.load %alloca[%c0_2] : memref<1xi64>
%7 = arith.addi %6, %c1_i64 : i64
memref.store %7, %alloca[%c0_2] : memref<1xi64>
%8 = arith.subi %6, %value_1 : i64
%9 = arith.subi %7, %value_1 : i64
%10 = arith.xori %8, %9 : i64
%c0_i64 = arith.constant 0 : i64
%11 = arith.cmpi slt, %10, %c0_i64 : i64
cf.cond_br %11, ^bb2(%5 : !forth.stack), ^bb1(%5 : !forth.stack)
}
}
3 changes: 1 addition & 2 deletions test/Pipeline/nested-control-flow.forth
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
\ MID: gpu.module @warpforth_module
\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel
\ MID: cf.br
\ MID: cf.cond_br
\ MID: gpu.return
\ MID: arith.xori

PARAM DATA 4
3 0 DO 4 0 DO J I + LOOP LOOP DATA 0 CELLS + !
7 changes: 7 additions & 0 deletions test/Pipeline/plus-loop-negative.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s

\ Verify that +LOOP with negative step through the full pipeline produces a gpu.binary
\ CHECK: gpu.binary @warpforth_module

PARAM DATA 4
0 10 DO I DATA 0 CELLS + ! -1 +LOOP
7 changes: 7 additions & 0 deletions test/Pipeline/plus-loop.forth
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s

\ Verify that +LOOP through the full pipeline produces a gpu.binary
\ CHECK: gpu.binary @warpforth_module

PARAM DATA 4
10 0 DO I DATA 0 CELLS + ! 2 +LOOP
26 changes: 13 additions & 13 deletions test/Translation/Forth/do-loop.forth
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s

\ Verify DO/LOOP generates loop counter with memref.alloca, pop, cmpi, cond_br
\ Verify DO/LOOP generates post-test loop with crossing test

\ CHECK: %[[S0:.*]] = forth.stack !forth.stack
\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack
Expand All @@ -14,18 +14,18 @@
\ CHECK: ^bb1(%[[B1:.*]]: !forth.stack):
\ CHECK-NEXT: %[[C0_2:.*]] = arith.constant 0 : index
\ CHECK-NEXT: %[[LOAD1:.*]] = memref.load %[[ALLOCA]][%[[C0_2]]] : memref<1xi64>
\ CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LOAD1]], %[[LIM]] : i64
\ CHECK-NEXT: cf.cond_br %[[CMP]], ^bb2(%[[B1]] : !forth.stack), ^bb3(%[[B1]] : !forth.stack)
\ CHECK: ^bb2(%[[B2:.*]]: !forth.stack):
\ CHECK-NEXT: %[[C0_3:.*]] = arith.constant 0 : index
\ CHECK-NEXT: %[[LOAD2:.*]] = memref.load %[[ALLOCA]][%[[C0_3]]] : memref<1xi64>
\ CHECK-NEXT: %[[PUSH:.*]] = forth.push_value %[[B2]], %[[LOAD2]] : !forth.stack, i64 -> !forth.stack
\ CHECK-NEXT: %[[C0_4:.*]] = arith.constant 0 : index
\ CHECK-NEXT: %[[LOAD3:.*]] = memref.load %[[ALLOCA]][%[[C0_4]]] : memref<1xi64>
\ CHECK-NEXT: %[[PUSH:.*]] = forth.push_value %[[B1]], %[[LOAD1]] : !forth.stack, i64 -> !forth.stack
\ CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : i64
\ CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[LOAD3]], %[[C1]] : i64
\ CHECK-NEXT: memref.store %[[ADDI]], %[[ALLOCA]][%[[C0_4]]] : memref<1xi64>
\ CHECK-NEXT: cf.br ^bb1(%[[PUSH]] : !forth.stack)
\ CHECK: ^bb3(%[[B3:.*]]: !forth.stack):
\ CHECK-NEXT: %[[C0_3:.*]] = arith.constant 0 : index
\ CHECK-NEXT: %[[OLD:.*]] = memref.load %[[ALLOCA]][%[[C0_3]]] : memref<1xi64>
\ CHECK-NEXT: %[[NEW:.*]] = arith.addi %[[OLD]], %[[C1]] : i64
\ CHECK-NEXT: memref.store %[[NEW]], %[[ALLOCA]][%[[C0_3]]] : memref<1xi64>
\ CHECK-NEXT: %[[D1:.*]] = arith.subi %[[OLD]], %[[LIM]] : i64
\ CHECK-NEXT: %[[D2:.*]] = arith.subi %[[NEW]], %[[LIM]] : i64
\ CHECK-NEXT: %[[XOR:.*]] = arith.xori %[[D1]], %[[D2]] : i64
\ CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : i64
\ CHECK-NEXT: %[[CROSSED:.*]] = arith.cmpi slt, %[[XOR]], %[[ZERO]] : i64
\ CHECK-NEXT: cf.cond_br %[[CROSSED]], ^bb2(%[[PUSH]] : !forth.stack), ^bb1(%[[PUSH]] : !forth.stack)
\ CHECK: ^bb2(%[[B2:.*]]: !forth.stack):
\ CHECK-NEXT: return
10 0 DO I LOOP
28 changes: 20 additions & 8 deletions test/Translation/Forth/leave-conditional.forth
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,31 @@

\ Verify conditional LEAVE preserves the loop backedge for non-LEAVE paths.

\ Branch directly to body (post-test loop)
\ CHECK: cf.br ^bb1(%{{.*}} : !forth.stack)
\ CHECK: ^bb1(%[[CHK:.*]]: !forth.stack):
\ CHECK: cf.cond_br %{{.*}}, ^bb2(%[[CHK]] : !forth.stack), ^bb[[EXIT:[0-9]+]](%[[CHK]] : !forth.stack)
\ CHECK: ^bb2(%[[B:.*]]: !forth.stack):
\ CHECK: cf.cond_br %{{.*}}, ^bb[[LEAVE:[0-9]+]](%{{.*}} : !forth.stack), ^bb[[JOIN:[0-9]+]](%{{.*}} : !forth.stack)
\ CHECK: ^bb[[EXIT]](%{{.*}}: !forth.stack):

\ Body: I 5 = IF → cond_br to LEAVE or THEN merge
\ CHECK: ^bb1(%[[B:.*]]: !forth.stack):
\ CHECK: forth.pop_flag
\ CHECK-NEXT: cf.cond_br %{{[^,]*}}, ^bb[[LEAVE:[0-9]+]](%{{[^)]*}} : !forth.stack), ^bb[[JOIN:[0-9]+]](%{{[^)]*}} : !forth.stack)

\ Exit: return
\ CHECK: ^bb[[EXIT:[0-9]+]](%{{.*}}: !forth.stack):
\ CHECK: return

\ LEAVE branch: unconditional jump to exit
\ CHECK: ^bb[[LEAVE]](%{{.*}}: !forth.stack):
\ CHECK: cf.cond_br %{{.*}}, ^bb[[EXIT]](%{{.*}} : !forth.stack), ^bb[[DEAD:[0-9]+]](%{{.*}} : !forth.stack)
\ CHECK: cf.cond_br %true, ^bb[[EXIT]](%{{.*}} : !forth.stack), ^bb[[DEAD:[0-9]+]](%{{.*}} : !forth.stack)

\ Join (THEN merge): 1 DROP, crossing test, loop back to body or exit
\ CHECK: ^bb[[JOIN]](%{{.*}}: !forth.stack):
\ CHECK: cf.br ^bb1(%{{.*}} : !forth.stack)
\ CHECK: arith.xori
\ CHECK: arith.cmpi slt
\ CHECK: cf.cond_br

\ Dead block from LEAVE
\ CHECK: ^bb[[DEAD]](%{{.*}}: !forth.stack):
\ CHECK: cf.br ^bb[[JOIN]](%{{.*}} : !forth.stack)
\ CHECK: cf.br ^bb[[JOIN]]

10 0 DO
I 5 = IF LEAVE THEN
Expand Down
Loading