diff --git a/gpu_test/test_kernels.py b/gpu_test/test_kernels.py index c09b8b5..5a99cfa 100644 --- a/gpu_test/test_kernels.py +++ b/gpu_test/test_kernels.py @@ -154,6 +154,36 @@ def test_do_loop(kernel_runner: KernelRunner) -> None: assert result == [0, 1, 2, 3, 4] +def test_multi_while(kernel_runner: KernelRunner) -> None: + """Multi-WHILE: two exit conditions from the same loop (interleaved CF). + + 20 BEGIN DUP 10 > WHILE DUP 2 MOD 0= WHILE 1 - REPEAT THEN + Decrements while >10 AND even. 20→19 (odd, WHILE(2) exit) → result 19. + """ + result = kernel_runner.run( + forth_source=( + "PARAM DATA 256\n" + "20 BEGIN DUP 10 > WHILE DUP 2 MOD 0= WHILE 1 - REPEAT THEN\n" + "0 CELLS DATA + !" + ), + ) + assert result[0] == 19 + + +def test_while_until(kernel_runner: KernelRunner) -> None: + """WHILE+UNTIL: two different exit mechanisms from the same loop (interleaved CF). + + 10 BEGIN DUP 0 > WHILE 1 - DUP 5 = UNTIL THEN + Decrements while >0, stops early at 5. 10→9→…→5 (UNTIL exit) → result 5. + """ + result = kernel_runner.run( + forth_source=( + "PARAM DATA 256\n10 BEGIN DUP 0 > WHILE 1 - DUP 5 = UNTIL THEN\n0 CELLS DATA + !" + ), + ) + assert result[0] == 5 + + # --- GPU Indexing --- diff --git a/include/warpforth/Conversion/Passes.td b/include/warpforth/Conversion/Passes.td index 044b663..38342a1 100644 --- a/include/warpforth/Conversion/Passes.td +++ b/include/warpforth/Conversion/Passes.td @@ -28,7 +28,7 @@ def ConvertForthToMemRef let dependentDialects = ["mlir::memref::MemRefDialect", "mlir::arith::ArithDialect", "mlir::LLVM::LLVMDialect", - "mlir::scf::SCFDialect"]; + "mlir::cf::ControlFlowDialect"]; } def ConvertForthToGPU : Pass<"convert-forth-to-gpu", "mlir::ModuleOp"> { diff --git a/include/warpforth/Dialect/Forth/ForthOps.td b/include/warpforth/Dialect/Forth/ForthOps.td index 79ed571..f13f789 100644 --- a/include/warpforth/Dialect/Forth/ForthOps.td +++ b/include/warpforth/Dialect/Forth/ForthOps.td @@ -8,7 +8,6 @@ #define FORTH_OPS include "warpforth/Dialect/Forth/ForthDialect.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" //===----------------------------------------------------------------------===// @@ -446,102 +445,48 @@ def Forth_ZeroEqOp : Forth_StackOpBase<"zero_eq"> { } //===----------------------------------------------------------------------===// -// Control flow operations. +// Control flow support operations. //===----------------------------------------------------------------------===// -def Forth_YieldOp : Forth_Op<"yield", [Pure, Terminator, ReturnLike, - ParentOneOf<["IfOp", "BeginUntilOp", "BeginWhileRepeatOp", "DoLoopOp"]>]> { - let summary = "Yield stack from control flow region"; +def Forth_PopFlagOp : Forth_Op<"pop_flag", [Pure]> { + let summary = "Pop top of stack as boolean flag"; let description = [{ - Yields the current stack state from a control flow region back to - the parent operation. Acts as a region terminator. - When the optional `while_cond` attribute is present, the yield acts as - a WHILE condition (continue when flag is non-zero) rather than - UNTIL (exit when flag is non-zero). + Pops the top value from the stack and returns it as an i1 flag + (non-zero = true, zero = false). Used by IF, UNTIL, WHILE. + Forth semantics: ( flag -- ) }]; - let arguments = (ins Forth_StackType:$result, OptionalAttr:$while_cond); + let arguments = (ins Forth_StackType:$input_stack); + let results = (outs Forth_StackType:$output_stack, I1:$flag); let assemblyFormat = [{ - $result (`while_cond` $while_cond^)? attr-dict `:` type($result) + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) `,` type($flag) }]; } -def Forth_IfOp : Forth_Op<"if", [RecursiveMemoryEffects, - DeclareOpInterfaceMethods]> { - let summary = "Conditional execution"; +def Forth_PopOp : Forth_Op<"pop", [Pure]> { + let summary = "Pop top of stack as i64 value"; let description = [{ - Conditional execution. If the flag is non-zero, the then region executes; - otherwise the else region executes. Each region must yield the resulting - stack. - Forth semantics: flag IF then-body ELSE else-body THEN + Pops the top value from the stack and returns it as an i64. + Used by DO to pop start and limit. + Forth semantics: ( x -- ) }]; let arguments = (ins Forth_StackType:$input_stack); - let results = (outs Forth_StackType:$output_stack); - let regions = (region SizedRegion<1>:$then_region, - SizedRegion<1>:$else_region); - let hasCustomAssemblyFormat = 1; -} - -def Forth_BeginUntilOp : Forth_Op<"begin_until", [RecursiveMemoryEffects, - DeclareOpInterfaceMethods]> { - let summary = "Post-test loop (do-while)"; - let description = [{ - BEGIN/UNTIL loop. Executes body, pops flag. If flag is zero, loops back. - If non-zero, exits. Stack effect: ( -- ) with flag consumed each iteration. - }]; - let arguments = (ins Forth_StackType:$input_stack); - let results = (outs Forth_StackType:$output_stack); - let regions = (region SizedRegion<1>:$body_region); - let hasCustomAssemblyFormat = 1; -} - -def Forth_DoLoopOp : Forth_Op<"do_loop", [RecursiveMemoryEffects, - DeclareOpInterfaceMethods]> { - let summary = "Counted loop (DO/LOOP)"; - let description = [{ - Pops start and limit from the stack, iterates from start to limit-1. - Use forth.loop_index (I word) inside to access the current loop index. - Stack effect: ( limit start -- ) - }]; - let arguments = (ins Forth_StackType:$input_stack); - let results = (outs Forth_StackType:$output_stack); - let regions = (region SizedRegion<1>:$body_region); - let hasCustomAssemblyFormat = 1; -} - -def Forth_BeginWhileRepeatOp : Forth_Op<"begin_while_repeat", - [RecursiveMemoryEffects, - DeclareOpInterfaceMethods]> { - let summary = "Pre-test loop (BEGIN/WHILE/REPEAT)"; - let description = [{ - BEGIN/WHILE/REPEAT loop. The condition region runs first, WHILE pops flag. - If flag is non-zero, the body region executes and loops back to condition. - If flag is zero, the loop exits. - Stack effect: ( -- ) with flag consumed each iteration. + let results = (outs Forth_StackType:$output_stack, I64:$value); + let assemblyFormat = [{ + $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) `,` type($value) }]; - let arguments = (ins Forth_StackType:$input_stack); - let results = (outs Forth_StackType:$output_stack); - let regions = (region SizedRegion<1>:$condition_region, - SizedRegion<1>:$body_region); - let hasCustomAssemblyFormat = 1; } -def Forth_LoopIndexOp : Forth_Op<"loop_index", [Pure]> { - let summary = "Push loop index onto stack (I/J/K words)"; +def Forth_PushValueOp : Forth_Op<"push_value", [Pure]> { + let summary = "Push dynamic i64 value onto stack"; let description = [{ - Pushes the loop index at the given nesting depth onto the stack. - depth=0 is I (innermost), depth=1 is J, depth=2 is K. - Only valid inside nested forth.do_loop bodies at sufficient depth. - ( -- index ) + Pushes a dynamic i64 value onto the stack. Used by I/J/K to push + the loop counter value. + Forth semantics: ( -- x ) }]; - let arguments = (ins Forth_StackType:$input_stack, - DefaultValuedAttr:$depth); + let arguments = (ins Forth_StackType:$input_stack, I64:$value); let results = (outs Forth_StackType:$output_stack); let assemblyFormat = [{ - $input_stack attr-dict `:` type($input_stack) `->` type($output_stack) + $input_stack `,` $value attr-dict `:` type($input_stack) `,` type($value) `->` type($output_stack) }]; } diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 9d431fd..329ecf0 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -12,7 +12,6 @@ add_mlir_library(MLIRConversionPasses MLIRGPUToNVVMTransforms MLIRGPUTransforms MLIRReconcileUnrealizedCasts - MLIRSCFToControlFlow MLIRTransforms ) diff --git a/lib/Conversion/ForthToMemRef/CMakeLists.txt b/lib/Conversion/ForthToMemRef/CMakeLists.txt index cee46f8..ff651f1 100644 --- a/lib/Conversion/ForthToMemRef/CMakeLists.txt +++ b/lib/Conversion/ForthToMemRef/CMakeLists.txt @@ -13,6 +13,6 @@ add_mlir_conversion_library(MLIRForthToMemRefConversion MLIRArithDialect MLIRLLVMDialect MLIRFuncDialect - MLIRSCFDialect + MLIRControlFlowDialect MLIRForth ) diff --git a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp index 706af42..9a70304 100644 --- a/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp +++ b/lib/Conversion/ForthToMemRef/ForthToMemRef.cpp @@ -6,11 +6,11 @@ #include "warpforth/Conversion/ForthToMemRef/ForthToMemRef.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" @@ -32,18 +32,19 @@ constexpr int64_t kStackSize = 256; class ForthToMemRefTypeConverter : public TypeConverter { public: ForthToMemRefTypeConverter() { + // Pass-through for all non-stack types (must be registered first) + addConversion([](Type type) { return type; }); + + // Stack type: !forth.stack -> memref<256xi64> + index addConversion( - [&](Type type, + [&](forth::StackType type, SmallVectorImpl &results) -> std::optional { - if (auto stackType = dyn_cast(type)) { - auto memrefType = MemRefType::get( - {kStackSize}, IntegerType::get(type.getContext(), 64)); - auto indexType = IndexType::get(type.getContext()); - results.push_back(memrefType); - results.push_back(indexType); - return success(); - } - return std::nullopt; + auto memrefType = MemRefType::get( + {kStackSize}, IntegerType::get(type.getContext(), 64)); + auto indexType = IndexType::get(type.getContext()); + results.push_back(memrefType); + results.push_back(indexType); + return success(); }); } }; @@ -353,6 +354,7 @@ struct PickOpConversion : public OpConversionPattern { /// Conversion pattern for forth.roll operation. /// Rotates nth element to top: ( xn ... x0 n -- xn-1 ... x0 xn ) +/// Uses a CF-based loop to shift elements down. struct RollOpConversion : public OpConversionPattern { RollOpConversion(const TypeConverter &typeConverter, MLIRContext *context) : OpConversionPattern(typeConverter, context) {} @@ -366,14 +368,15 @@ struct RollOpConversion : public OpConversionPattern { Value memref = inputStack[0]; Value stackPtr = inputStack[1]; + auto indexType = rewriter.getIndexType(); + // Pop n from stack Value nI64 = rewriter.create(loc, memref, stackPtr); Value one = rewriter.create(loc, 1); Value spAfterPop = rewriter.create(loc, stackPtr, one); // Cast n to index - Value nIdx = - rewriter.create(loc, rewriter.getIndexType(), nI64); + Value nIdx = rewriter.create(loc, indexType, nI64); // Compute address of the element to roll: SP' - n Value rolledAddr = rewriter.create(loc, spAfterPop, nIdx); @@ -382,18 +385,38 @@ struct RollOpConversion : public OpConversionPattern { Value rolledValue = rewriter.create(loc, memref, rolledAddr); - // Shift elements down: for i in [rolledAddr, SP') : memref[i] = memref[i+1] - auto forOp = rewriter.create(loc, rolledAddr, spAfterPop, one); - - // Insert ops at start of the auto-created body, before the yield - rewriter.setInsertionPointToStart(forOp.getBody()); - Value iv = forOp.getInductionVar(); - Value iPlusOne = rewriter.create(loc, iv, one); - Value shiftedVal = rewriter.create(loc, memref, iPlusOne); - rewriter.create(loc, shiftedVal, memref, iv); - - // Store saved value at top (SP') - rewriter.setInsertionPointAfter(forOp); + // Split block to create CF-based shift loop. + Block *currentBlock = rewriter.getInsertionBlock(); + Block *continueBlock = rewriter.splitBlock(currentBlock, op->getIterator()); + + // Create loop header and body blocks (inserted before continueBlock). + Block *headerBlock = rewriter.createBlock(continueBlock); + headerBlock->addArgument(indexType, loc); // induction variable + Block *bodyBlock = rewriter.createBlock(continueBlock); + bodyBlock->addArgument(indexType, loc); // induction variable + + // currentBlock -> headerBlock(rolledAddr) + rewriter.setInsertionPointToEnd(currentBlock); + rewriter.create(loc, headerBlock, ValueRange{rolledAddr}); + + // headerBlock: check iv < spAfterPop + rewriter.setInsertionPointToStart(headerBlock); + Value iv = headerBlock->getArgument(0); + Value cond = rewriter.create(loc, arith::CmpIPredicate::slt, + iv, spAfterPop); + rewriter.create(loc, cond, bodyBlock, ValueRange{iv}, + continueBlock, ValueRange{}); + + // bodyBlock: shift memref[iv] = memref[iv+1], branch back to header + rewriter.setInsertionPointToStart(bodyBlock); + Value biv = bodyBlock->getArgument(0); + Value next = rewriter.create(loc, biv, one); + Value shiftedVal = rewriter.create(loc, memref, next); + rewriter.create(loc, shiftedVal, memref, biv); + rewriter.create(loc, headerBlock, ValueRange{next}); + + // continueBlock: store rolled value at top, then rest of original block + rewriter.setInsertionPoint(op); rewriter.create(loc, rolledValue, memref, spAfterPop); // Net effect: SP' = SP - 1 (consumed n) @@ -763,338 +786,167 @@ struct GlobalIdOpConversion : public OpConversionPattern { } }; -/// Conversion pattern for forth.yield operation. -/// Context-aware: inside scf.while's `before` region (from BeginUntilOp or -/// BeginWhileRepeatOp), emits flag-pop + scf.condition; otherwise emits -/// scf.yield with SP. -struct YieldOpConversion : public OpConversionPattern { - YieldOpConversion(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} +/// Conversion pattern for forth.pop_flag operation. +/// Pops top of stack, compares != 0, returns (memref, newSP, i1 flag). +struct PopFlagOpConversion : public OpConversionPattern { + PopFlagOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::YieldOp op, OneToNOpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - ValueRange adaptedResult = adaptor.getOperands()[0]; - Value memref = adaptedResult[0]; - Value sp = adaptedResult[1]; // index - - // Check if we're inside scf.while's `before` region. - auto *parentOp = op->getParentOp(); - if (auto whileOp = dyn_cast(parentOp)) { - if (op->getParentRegion() == &whileOp.getBefore()) { - // Pop flag from stack top. - Value flag = rewriter.create(loc, memref, sp); - Value one = rewriter.create(loc, 1); - Value spAfterPop = rewriter.create(loc, sp, one); - - Value zero = rewriter.create(loc, 0, 64); - Value keepGoing; - if (op.getWhileCond()) { - // WHILE semantics: continue when flag is non-zero. - keepGoing = rewriter.create( - loc, arith::CmpIPredicate::ne, flag, zero); - } else { - // UNTIL semantics: exit on non-zero; keep going when flag == 0. - keepGoing = rewriter.create( - loc, arith::CmpIPredicate::eq, flag, zero); - } - - rewriter.replaceOpWithNewOp(op, keepGoing, - ValueRange{spAfterPop}); - return success(); - } - } - - // Default: emit scf.yield with just the SP. - rewriter.replaceOpWithNewOp(op, ValueRange{sp}); - return success(); - } -}; - -/// Conversion pattern for forth.if operation. -/// Loads the flag from the stack top, creates scf.if with the condition, -/// and inlines the region content after converting block args. -struct IfOpConversion : public OpConversionPattern { - IfOpConversion(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} - using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; - - LogicalResult - matchAndRewrite(forth::IfOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(forth::PopFlagOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ValueRange inputStack = adaptor.getOperands()[0]; Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - // Load flag from stack top. - Value flag = rewriter.create(loc, memref, stackPtr); - - // Condition: flag != 0. - Value zero = rewriter.create(loc, 0, 64); - Value cond = rewriter.create(loc, arith::CmpIPredicate::ne, - flag, zero); - - // Create scf.if with index result (SP). - auto indexType = rewriter.getIndexType(); - auto scfIf = rewriter.create(loc, TypeRange{indexType}, cond, - /*addElseBlock=*/true); - - // Convert block signatures and inline regions into scf.if. - // convertRegionTypes converts !forth.stack block arg to {memref, index} - // and inserts tracked materializations (unrealized_conversion_cast). - // We inline into scf.if and mergeBlocks to substitute the converted - // block args with parent-scope values. The original materialization - // cast stays intact (tracked by the framework). When the framework - // later converts the inlined inner ops, their adaptors unwrap the - // cast to get {memref, index}. - auto convertRegion = [&](Region &srcRegion, - Region &dstRegion) -> LogicalResult { - if (failed(rewriter.convertRegionTypes(&srcRegion, *getTypeConverter()))) - return failure(); - - rewriter.eraseBlock(&dstRegion.front()); - rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.end()); + // Load top value + Value topValue = rewriter.create(loc, memref, stackPtr); - Block &blockWithArgs = dstRegion.front(); - Block *newBlock = rewriter.createBlock(&dstRegion); - rewriter.mergeBlocks(&blockWithArgs, newBlock, {memref, stackPtr}); - return success(); - }; + // Decrement SP + Value one = rewriter.create(loc, 1); + Value newSP = rewriter.create(loc, stackPtr, one); - if (failed(convertRegion(op.getThenRegion(), scfIf.getThenRegion()))) - return failure(); - if (failed(convertRegion(op.getElseRegion(), scfIf.getElseRegion()))) - return failure(); + // Compare != 0 + Value zero = rewriter.create(loc, 0, 64); + Value flag = rewriter.create(loc, arith::CmpIPredicate::ne, + topValue, zero); - // Replace forth.if with {memref, scf.if result SP}. - rewriter.replaceOpWithMultiple(op, {{memref, scfIf.getResult(0)}}); + // Result 0: output_stack -> {memref, newSP} + // Result 1: flag -> i1 (passes through unchanged) + rewriter.replaceOpWithMultiple(op, {{memref, newSP}, {flag}}); return success(); } }; -/// Conversion pattern for forth.begin_until operation. -/// Creates scf.while with the body as the `before` region (condition test), -/// and an identity `after` region. -struct BeginUntilOpConversion - : public OpConversionPattern { - BeginUntilOpConversion(const TypeConverter &typeConverter, - MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} +/// Conversion pattern for forth.pop operation. +/// Pops top of stack, returns (memref, newSP, i64 value). +struct PopOpConversion : public OpConversionPattern { + PopOpConversion(const TypeConverter &typeConverter, MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::BeginUntilOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(forth::PopOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ValueRange inputStack = adaptor.getOperands()[0]; Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - auto indexType = rewriter.getIndexType(); - - // Create scf.while with index result, stackPtr as iter arg. - auto whileOp = rewriter.create(loc, TypeRange{indexType}, - ValueRange{stackPtr}); + // Load top value + Value topValue = rewriter.create(loc, memref, stackPtr); - // Before region (body): convert + inline. - Region &bodyRegion = op.getBodyRegion(); - if (failed(rewriter.convertRegionTypes(&bodyRegion, *getTypeConverter()))) - return failure(); + // Decrement SP + Value one = rewriter.create(loc, 1); + Value newSP = rewriter.create(loc, stackPtr, one); - // scf.while's before region starts empty (no auto-created blocks). - // Inline the body region into before. - rewriter.inlineRegionBefore(bodyRegion, whileOp.getBefore(), - whileOp.getBefore().end()); - - // Merge the block args: replace converted {memref, index} with - // {memref, beforeSP}. - Block &beforeBlock = whileOp.getBefore().front(); - Block *newBeforeBlock = rewriter.createBlock(&whileOp.getBefore()); - newBeforeBlock->addArgument(indexType, loc); - Value beforeSP = newBeforeBlock->getArgument(0); - rewriter.mergeBlocks(&beforeBlock, newBeforeBlock, {memref, beforeSP}); - - // After region (identity): just yield the SP. - Block *afterBlock = rewriter.createBlock(&whileOp.getAfter()); - afterBlock->addArgument(indexType, loc); - Value afterSP = afterBlock->getArgument(0); - rewriter.setInsertionPointToStart(afterBlock); - rewriter.create(loc, ValueRange{afterSP}); - - // Replace forth.begin_until with {memref, whileOp result}. - rewriter.replaceOpWithMultiple(op, {{memref, whileOp.getResult(0)}}); + // Result 0: output_stack -> {memref, newSP} + // Result 1: value -> i64 (passes through unchanged) + rewriter.replaceOpWithMultiple(op, {{memref, newSP}, {topValue}}); return success(); } }; -/// Conversion pattern for forth.begin_while_repeat operation. -/// Creates scf.while with the condition as the `before` region, -/// and the body as the `after` region. -struct BeginWhileRepeatOpConversion - : public OpConversionPattern { - BeginWhileRepeatOpConversion(const TypeConverter &typeConverter, - MLIRContext *context) - : OpConversionPattern(typeConverter, context) { - } +/// Conversion pattern for forth.push_value operation. +/// Pushes a dynamic i64 value onto the stack. +struct PushValueOpConversion : public OpConversionPattern { + PushValueOpConversion(const TypeConverter &typeConverter, + MLIRContext *context) + : OpConversionPattern(typeConverter, context) {} using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::BeginWhileRepeatOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(forth::PushValueOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); ValueRange inputStack = adaptor.getOperands()[0]; Value memref = inputStack[0]; Value stackPtr = inputStack[1]; - auto indexType = rewriter.getIndexType(); + // The value operand (i64) passes through as-is (not a converted type). + Value value = adaptor.getOperands()[1][0]; - // Create scf.while with index result, stackPtr as iter arg. - auto whileOp = rewriter.create(loc, TypeRange{indexType}, - ValueRange{stackPtr}); + Value newSP = pushValue(loc, rewriter, memref, stackPtr, value); - // Before region (condition): convert + inline. - Region &condRegion = op.getConditionRegion(); - if (failed(rewriter.convertRegionTypes(&condRegion, *getTypeConverter()))) - return failure(); + rewriter.replaceOpWithMultiple(op, {{memref, newSP}}); + return success(); + } +}; - rewriter.inlineRegionBefore(condRegion, whileOp.getBefore(), - whileOp.getBefore().end()); +/// Custom FuncOp conversion that calls convertRegionTypes to convert ALL +/// block args (including non-entry blocks used by CF branch ops). +/// The built-in pattern only converts the entry block. +struct FuncOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; - Block &beforeBlock = whileOp.getBefore().front(); - Block *newBeforeBlock = rewriter.createBlock(&whileOp.getBefore()); - newBeforeBlock->addArgument(indexType, loc); - Value beforeSP = newBeforeBlock->getArgument(0); - rewriter.mergeBlocks(&beforeBlock, newBeforeBlock, {memref, beforeSP}); + LogicalResult + matchAndRewrite(func::FuncOp funcOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto type = funcOp.getFunctionType(); - // After region (body): convert + inline. - Region &bodyRegion = op.getBodyRegion(); - if (failed(rewriter.convertRegionTypes(&bodyRegion, *getTypeConverter()))) + TypeConverter::SignatureConversion result(type.getNumInputs()); + SmallVector newResults; + if (failed(getTypeConverter()->convertSignatureArgs(type.getInputs(), + result)) || + failed(getTypeConverter()->convertTypes(type.getResults(), newResults))) return failure(); - rewriter.inlineRegionBefore(bodyRegion, whileOp.getAfter(), - whileOp.getAfter().end()); - - Block &afterBlock = whileOp.getAfter().front(); - Block *newAfterBlock = rewriter.createBlock(&whileOp.getAfter()); - newAfterBlock->addArgument(indexType, loc); - Value afterSP = newAfterBlock->getArgument(0); - rewriter.mergeBlocks(&afterBlock, newAfterBlock, {memref, afterSP}); + if (!funcOp.getFunctionBody().empty()) { + if (failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(), + *getTypeConverter(), &result))) + return failure(); + } - // Replace forth.begin_while_repeat with {memref, whileOp result}. - rewriter.replaceOpWithMultiple(op, {{memref, whileOp.getResult(0)}}); + auto newType = FunctionType::get(rewriter.getContext(), + result.getConvertedTypes(), newResults); + rewriter.modifyOpInPlace(funcOp, [&] { funcOp.setType(newType); }); return success(); } }; -/// Conversion pattern for forth.do_loop operation. -/// Pops start and limit from the stack, creates scf.for from start to limit. -struct DoLoopOpConversion : public OpConversionPattern { - DoLoopOpConversion(const TypeConverter &typeConverter, MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} +/// Conversion pattern for cf::BranchOp with 1:N type conversion. +/// The built-in populateBranchOpInterfaceTypeConversionPattern uses the old +/// ArrayRef signature and crashes on 1:N conversions. +struct BranchOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::DoLoopOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(cf::BranchOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - ValueRange inputStack = adaptor.getOperands()[0]; - Value memref = inputStack[0]; - Value stackPtr = inputStack[1]; - - auto indexType = rewriter.getIndexType(); - Value one = rewriter.create(loc, 1); - - // Pop start (TOS): load memref[SP], SP -= 1 - Value startI64 = rewriter.create(loc, memref, stackPtr); - Value spAfterStart = rewriter.create(loc, stackPtr, one); - - // Pop limit (new TOS): load memref[SP-1], SP -= 1 - Value limitI64 = rewriter.create(loc, memref, spAfterStart); - Value spAfterPops = rewriter.create(loc, spAfterStart, one); - - // Cast i64 to index for scf.for bounds - Value startIdx = - rewriter.create(loc, indexType, startI64); - Value limitIdx = - rewriter.create(loc, indexType, limitI64); - Value stepIdx = one; // step = 1 - - // Create scf.for %iv = start to limit step 1 iter_args(%sp = spAfterPops) - auto forOp = rewriter.create(loc, startIdx, limitIdx, stepIdx, - ValueRange{spAfterPops}); - - // Convert body region types and inline into scf.for - Region &bodyRegion = op.getBodyRegion(); - if (failed(rewriter.convertRegionTypes(&bodyRegion, *getTypeConverter()))) - return failure(); - - // Erase the auto-created body block of scf.for - rewriter.eraseBlock(forOp.getBody()); - - // Inline the converted body region into scf.for - rewriter.inlineRegionBefore(bodyRegion, forOp.getRegion(), - forOp.getRegion().end()); - - // Merge block args: the converted block has {memref, index} args. - // Replace with {memref, iter_arg SP}. - Block &bodyBlock = forOp.getRegion().front(); - Block *newBlock = rewriter.createBlock(&forOp.getRegion()); - // scf.for block has: induction var, then iter args - newBlock->addArgument(indexType, loc); // induction variable - newBlock->addArgument(indexType, loc); // SP iter arg - Value spIterArg = newBlock->getArgument(1); - rewriter.mergeBlocks(&bodyBlock, newBlock, {memref, spIterArg}); - - // Replace forth.do_loop with {memref, forOp result SP} - rewriter.replaceOpWithMultiple(op, {{memref, forOp.getResult(0)}}); + SmallVector newOperands; + for (ValueRange vals : adaptor.getOperands()) + llvm::append_range(newOperands, vals); + rewriter.replaceOpWithNewOp(op, op.getDest(), newOperands); return success(); } }; -/// Conversion pattern for forth.loop_index operation (I word). -/// Finds the enclosing scf.for and pushes its induction variable onto the -/// stack. -struct LoopIndexOpConversion : public OpConversionPattern { - LoopIndexOpConversion(const TypeConverter &typeConverter, - MLIRContext *context) - : OpConversionPattern(typeConverter, context) {} +/// Conversion pattern for cf::CondBranchOp with 1:N type conversion. +struct CondBranchOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; using OneToNOpAdaptor = OpConversionPattern::OneToNOpAdaptor; LogicalResult - matchAndRewrite(forth::LoopIndexOp op, OneToNOpAdaptor adaptor, + matchAndRewrite(cf::CondBranchOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - ValueRange inputStack = adaptor.getOperands()[0]; - Value memref = inputStack[0]; - Value stackPtr = inputStack[1]; + SmallVector trueOperands, falseOperands; - // Get the depth attribute (default 0 = I) - int64_t depth = op.getDepth(); - - // Walk up through enclosing scf.for ops - Operation *current = op.getOperation(); - scf::ForOp targetFor; - for (int64_t i = 0; i <= depth; ++i) { - targetFor = current->getParentOfType(); - if (!targetFor) - return rewriter.notifyMatchFailure( - op, "not enough enclosing scf.for ops for depth " + - std::to_string(depth)); - current = targetFor.getOperation(); - } + Value condition = adaptor.getOperands()[0][0]; - // Get induction variable and cast index to i64 - Value iv = targetFor.getInductionVar(); - Value ivI64 = - rewriter.create(loc, rewriter.getI64Type(), iv); + unsigned trueCount = op.getTrueDestOperands().size(); + for (unsigned i = 1; i <= trueCount; ++i) + llvm::append_range(trueOperands, adaptor.getOperands()[i]); - // Push onto stack - Value newSP = pushValue(loc, rewriter, memref, stackPtr, ivI64); + for (unsigned i = 1 + trueCount; i < adaptor.getOperands().size(); ++i) + llvm::append_range(falseOperands, adaptor.getOperands()[i]); - rewriter.replaceOpWithMultiple(op, {{memref, newSP}}); + rewriter.replaceOpWithNewOp( + op, condition, op.getTrueDest(), trueOperands, op.getFalseDest(), + falseOperands); return success(); } }; @@ -1111,9 +963,9 @@ struct ConvertForthToMemRefPass // Mark Forth dialect as illegal (to be converted) target.addIllegalDialect(); - // Mark MemRef, Arith, LLVM, and SCF dialects as legal + // Mark MemRef, Arith, LLVM, and CF dialects as legal target.addLegalDialect(); + LLVM::LLVMDialect, cf::ControlFlowDialect>(); // Mark IntrinsicOp as legal (to be lowered later) target.addLegalOp(); @@ -1121,11 +973,15 @@ struct ConvertForthToMemRefPass // Use dynamic legality for func operations to ensure they're properly // converted target.addDynamicallyLegalOp([&](func::FuncOp op) { - // Function is legal if its signature doesn't contain forth.stack types - return llvm::none_of(op.getFunctionType().getInputs(), - [&](Type t) { return isa(t); }) && - llvm::none_of(op.getFunctionType().getResults(), - [&](Type t) { return isa(t); }); + auto isStack = [](Type t) { return isa(t); }; + if (llvm::any_of(op.getFunctionType().getInputs(), isStack) || + llvm::any_of(op.getFunctionType().getResults(), isStack)) + return false; + // Also check non-entry block args (CF control flow) + for (Block &block : op.getFunctionBody()) + if (llvm::any_of(block.getArgumentTypes(), isStack)) + return false; + return true; }); target.addDynamicallyLegalOp([&](func::CallOp op) { @@ -1140,6 +996,13 @@ struct ConvertForthToMemRefPass [&](Type t) { return isa(t); }); }); + // CF ops are legal but need dynamic legality for type-converted block args + target.addDynamicallyLegalOp( + [&](Operation *op) { + return llvm::none_of(op->getOperandTypes(), + [](Type t) { return isa(t); }); + }); + ForthToMemRefTypeConverter typeConverter; RewritePatternSet patterns(context); @@ -1153,10 +1016,8 @@ struct ConvertForthToMemRefPass NotOpConversion, LshiftOpConversion, RshiftOpConversion, EqOpConversion, LtOpConversion, GtOpConversion, NeOpConversion, LeOpConversion, GeOpConversion, ZeroEqOpConversion, ParamRefOpConversion, - LoadOpConversion, StoreOpConversion, IfOpConversion, - BeginUntilOpConversion, BeginWhileRepeatOpConversion, - DoLoopOpConversion, LoopIndexOpConversion, YieldOpConversion>( - typeConverter, context); + LoadOpConversion, StoreOpConversion, PopFlagOpConversion, + PopOpConversion, PushValueOpConversion>(typeConverter, context); // Add GPU indexing op conversion patterns patterns.add>(typeConverter, @@ -1187,9 +1048,9 @@ struct ConvertForthToMemRefPass // GlobalIdOp has custom pattern patterns.add(typeConverter, context); - // Add built-in function conversion patterns - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); + // Custom FuncOp + branch patterns for 1:N type conversion + patterns.add( + typeConverter, context); populateCallOpTypeConversionPattern(patterns, typeConverter); populateReturnOpTypeConversionPattern(patterns, typeConverter); diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index aceecfa..5e59eb9 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -8,7 +8,6 @@ #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" -#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" @@ -23,34 +22,31 @@ namespace mlir { namespace warpforth { void buildWarpForthPipeline(OpPassManager &pm) { - // Stage 1: Lower Forth to MemRef + // Stage 1: Lower Forth to MemRef (CF ops pass through as-is) pm.addPass(createConvertForthToMemRefPass()); - // Stage 2: Lower SCF to CF - pm.addPass(createConvertSCFToCFPass()); - - // Stage 3: Convert to GPU dialect (includes private address space annotation) + // Stage 2: Convert to GPU dialect (includes private address space annotation) pm.addPass(createConvertForthToGPUPass()); - // Stage 4: Normalize MemRefs for GPU + // Stage 3: Normalize MemRefs for GPU pm.addPass(createCanonicalizerPass()); - // Stage 5: Attach NVVM target to GPU modules (sm_70 = Volta architecture) + // Stage 4: Attach NVVM target to GPU modules (sm_70 = Volta architecture) pm.addPass(createGpuNVVMAttachTarget()); - // Stage 6: Lower GPU to NVVM with bare pointers + // Stage 5: Lower GPU to NVVM with bare pointers ConvertGpuOpsToNVVMOpsOptions gpuToNVVMOptions; gpuToNVVMOptions.useBarePtrCallConv = true; pm.addNestedPass( createConvertGpuOpsToNVVMOps(gpuToNVVMOptions)); - // Stage 7: Lower NVVM to LLVM + // Stage 6: Lower NVVM to LLVM pm.addPass(createConvertNVVMToLLVMPass()); - // Stage 8: Reconcile type conversions + // Stage 7: Reconcile type conversions pm.addPass(createReconcileUnrealizedCastsPass()); - // Stage 9: Compile GPU module to PTX binary + // Stage 8: Compile GPU module to PTX binary GpuModuleToBinaryPassOptions binaryOptions; binaryOptions.compilationTarget = "isa"; // Output PTX assembly pm.addPass(createGpuModuleToBinaryPass(binaryOptions)); diff --git a/lib/Dialect/Forth/ForthDialect.cpp b/lib/Dialect/Forth/ForthDialect.cpp index 5521f64..e439f48 100644 --- a/lib/Dialect/Forth/ForthDialect.cpp +++ b/lib/Dialect/Forth/ForthDialect.cpp @@ -21,240 +21,6 @@ using namespace mlir::forth; #define GET_OP_CLASSES #include "warpforth/Dialect/Forth/ForthOps.cpp.inc" -//===----------------------------------------------------------------------===// -// IfOp RegionBranchOpInterface. -//===----------------------------------------------------------------------===// - -void IfOp::getSuccessorRegions(RegionBranchPoint point, - SmallVectorImpl ®ions) { - if (point.isParent()) { - // From parent: branch into then or else region. - regions.push_back( - RegionSuccessor(&getThenRegion(), getThenRegion().getArguments())); - regions.push_back( - RegionSuccessor(&getElseRegion(), getElseRegion().getArguments())); - return; - } - // From either region: return to parent with op results. - regions.push_back(RegionSuccessor(getOperation()->getResults())); -} - -OperandRange IfOp::getEntrySuccessorOperands(RegionBranchPoint point) { - return getOperation()->getOperands(); -} - -//===----------------------------------------------------------------------===// -// IfOp custom assembly format. -//===----------------------------------------------------------------------===// - -void IfOp::print(OpAsmPrinter &p) { - p << ' ' << getInputStack() << " : " << getInputStack().getType() << " -> " - << getOutputStack().getType() << ' '; - p.printRegion(getThenRegion()); - p << " else "; - p.printRegion(getElseRegion()); -} - -ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand inputStack; - Type inputType, outputType; - - if (parser.parseOperand(inputStack) || parser.parseColon() || - parser.parseType(inputType) || parser.parseArrow() || - parser.parseType(outputType) || - parser.resolveOperand(inputStack, inputType, result.operands)) - return failure(); - - result.addTypes(outputType); - - // Parse then region. - auto *thenRegion = result.addRegion(); - if (parser.parseRegion(*thenRegion)) - return failure(); - - // Parse "else" keyword and else region. - if (parser.parseKeyword("else")) - return failure(); - - auto *elseRegion = result.addRegion(); - if (parser.parseRegion(*elseRegion)) - return failure(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// BeginUntilOp RegionBranchOpInterface. -//===----------------------------------------------------------------------===// - -void BeginUntilOp::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point.isParent()) { - // From parent: enter the body region. - regions.push_back( - RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); - return; - } - // From body: loop back to body or exit to parent. - regions.push_back( - RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); - regions.push_back(RegionSuccessor(getOperation()->getResults())); -} - -OperandRange BeginUntilOp::getEntrySuccessorOperands(RegionBranchPoint point) { - return getOperation()->getOperands(); -} - -//===----------------------------------------------------------------------===// -// BeginUntilOp custom assembly format. -//===----------------------------------------------------------------------===// - -void BeginUntilOp::print(OpAsmPrinter &p) { - p << ' ' << getInputStack() << " : " << getInputStack().getType() << " -> " - << getOutputStack().getType() << ' '; - p.printRegion(getBodyRegion()); -} - -ParseResult BeginUntilOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand inputStack; - Type inputType, outputType; - - if (parser.parseOperand(inputStack) || parser.parseColon() || - parser.parseType(inputType) || parser.parseArrow() || - parser.parseType(outputType) || - parser.resolveOperand(inputStack, inputType, result.operands)) - return failure(); - - result.addTypes(outputType); - - auto *bodyRegion = result.addRegion(); - if (parser.parseRegion(*bodyRegion)) - return failure(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// BeginWhileRepeatOp RegionBranchOpInterface. -//===----------------------------------------------------------------------===// - -void BeginWhileRepeatOp::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point.isParent()) { - // From parent: enter the condition region. - regions.push_back(RegionSuccessor(&getConditionRegion(), - getConditionRegion().getArguments())); - return; - } - if (point.getRegionOrNull() == &getConditionRegion()) { - // From condition: enter body or exit to parent. - regions.push_back( - RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); - regions.push_back(RegionSuccessor(getOperation()->getResults())); - return; - } - // From body: loop back to condition. - regions.push_back(RegionSuccessor(&getConditionRegion(), - getConditionRegion().getArguments())); -} - -OperandRange -BeginWhileRepeatOp::getEntrySuccessorOperands(RegionBranchPoint point) { - return getOperation()->getOperands(); -} - -//===----------------------------------------------------------------------===// -// BeginWhileRepeatOp custom assembly format. -//===----------------------------------------------------------------------===// - -void BeginWhileRepeatOp::print(OpAsmPrinter &p) { - p << ' ' << getInputStack() << " : " << getInputStack().getType() << " -> " - << getOutputStack().getType() << ' '; - p.printRegion(getConditionRegion()); - p << " do "; - p.printRegion(getBodyRegion()); -} - -ParseResult BeginWhileRepeatOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand inputStack; - Type inputType, outputType; - - if (parser.parseOperand(inputStack) || parser.parseColon() || - parser.parseType(inputType) || parser.parseArrow() || - parser.parseType(outputType) || - parser.resolveOperand(inputStack, inputType, result.operands)) - return failure(); - - result.addTypes(outputType); - - // Parse condition region. - auto *condRegion = result.addRegion(); - if (parser.parseRegion(*condRegion)) - return failure(); - - // Parse "do" keyword and body region. - if (parser.parseKeyword("do")) - return failure(); - - auto *bodyRegion = result.addRegion(); - if (parser.parseRegion(*bodyRegion)) - return failure(); - - return success(); -} - -//===----------------------------------------------------------------------===// -// DoLoopOp RegionBranchOpInterface. -//===----------------------------------------------------------------------===// - -void DoLoopOp::getSuccessorRegions(RegionBranchPoint point, - SmallVectorImpl ®ions) { - if (point.isParent()) { - // From parent: enter the body region. - regions.push_back( - RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); - return; - } - // From body: loop back to body or exit to parent. - regions.push_back( - RegionSuccessor(&getBodyRegion(), getBodyRegion().getArguments())); - regions.push_back(RegionSuccessor(getOperation()->getResults())); -} - -OperandRange DoLoopOp::getEntrySuccessorOperands(RegionBranchPoint point) { - return getOperation()->getOperands(); -} - -//===----------------------------------------------------------------------===// -// DoLoopOp custom assembly format. -//===----------------------------------------------------------------------===// - -void DoLoopOp::print(OpAsmPrinter &p) { - p << ' ' << getInputStack() << " : " << getInputStack().getType() << " -> " - << getOutputStack().getType() << ' '; - p.printRegion(getBodyRegion()); -} - -ParseResult DoLoopOp::parse(OpAsmParser &parser, OperationState &result) { - OpAsmParser::UnresolvedOperand inputStack; - Type inputType, outputType; - - if (parser.parseOperand(inputStack) || parser.parseColon() || - parser.parseType(inputType) || parser.parseArrow() || - parser.parseType(outputType) || - parser.resolveOperand(inputStack, inputType, result.operands)) - return failure(); - - result.addTypes(outputType); - - auto *bodyRegion = result.addRegion(); - if (parser.parseRegion(*bodyRegion)) - return failure(); - - return success(); -} - //===----------------------------------------------------------------------===// // Forth dialect. //===----------------------------------------------------------------------===// diff --git a/lib/Translation/ForthToMLIR/CMakeLists.txt b/lib/Translation/ForthToMLIR/CMakeLists.txt index 83be194..c7abcb2 100644 --- a/lib/Translation/ForthToMLIR/CMakeLists.txt +++ b/lib/Translation/ForthToMLIR/CMakeLists.txt @@ -10,5 +10,8 @@ add_mlir_library(MLIRForthTranslation MLIRSupport MLIRTranslateLib MLIRForth + MLIRArithDialect + MLIRControlFlowDialect MLIRFuncDialect + MLIRMemRefDialect ) diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp index c2e79ff..d000ea4 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.cpp +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.cpp @@ -5,7 +5,10 @@ //===----------------------------------------------------------------------===// #include "ForthToMLIR.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" @@ -358,15 +361,18 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, .getResult(); } else if (word == "I" || word == "J" || word == "K") { int64_t depth = (word == "I") ? 0 : (word == "J") ? 1 : 2; - if (doLoopDepth < depth + 1) { + if (static_cast(loopStack.size()) < depth + 1) { (void)emitError("'" + word.str() + "' requires " + std::to_string(depth + 1) + " nested DO/LOOP(s)"); return nullptr; } - return builder - .create(loc, stackType, inputStack, - builder.getI64IntegerAttr(depth)) - .getResult(); + // Load counter from the appropriate loop context + auto &ctx = loopStack[loopStack.size() - 1 - depth]; + Value c0 = builder.create(loc, 0); + Value idx = + builder.create(loc, ctx.counter, ValueRange{c0}); + return builder.create(loc, stackType, inputStack, idx) + .getOutputStack(); } // Unknown word @@ -374,22 +380,30 @@ Value ForthParser::emitOperation(StringRef word, Value inputStack, } //===----------------------------------------------------------------------===// -// Body parsing — shared by word definitions, main, and control flow regions. +// Body parsing - shared by word definitions and main. +// Control flow words are handled inline using cf.br/cf.cond_br. //===----------------------------------------------------------------------===// -LogicalResult -ForthParser::parseBody(Value &stack, - llvm::function_ref isStopWord) { +Block *ForthParser::createStackBlock(Region *region, Location loc) { + auto *block = new Block(); + region->push_back(block); + block->addArgument(forth::StackType::get(context), loc); + return block; +} + +std::pair ForthParser::emitPopFlag(Location loc, Value stack) { + auto popFlag = builder.create( + loc, forth::StackType::get(context), builder.getI1Type(), stack); + return {popFlag.getOutputStack(), popFlag.getFlag()}; +} + +LogicalResult ForthParser::parseBody(Value &stack) { Type stackType = forth::StackType::get(context); while (currentToken.kind != Token::Kind::EndOfFile && currentToken.kind != Token::Kind::Semicolon && currentToken.kind != Token::Kind::Colon) { - // Check if current word is a stop word. - if (currentToken.kind == Token::Kind::Word && isStopWord(currentToken.text)) - break; - // Skip PARAM declarations at top level. if (!inWordDefinition && currentToken.kind == Token::Kind::Word && currentToken.text == "PARAM") { @@ -408,272 +422,231 @@ ForthParser::parseBody(Value &stack, .getResult(); consume(); } else if (currentToken.kind == Token::Kind::Word) { - if (currentToken.text == "IF") { - Location tokenLoc = getLoc(); - consume(); // consume IF - stack = parseIf(stack, tokenLoc); - if (!stack) - return failure(); - } else if (currentToken.text == "BEGIN") { - Location tokenLoc = getLoc(); - consume(); // consume BEGIN - if (isWhileLoop()) - stack = parseBeginWhileRepeat(stack, tokenLoc); - else - stack = parseBeginUntil(stack, tokenLoc); - if (!stack) - return failure(); - } else if (currentToken.text == "DO") { - Location tokenLoc = getLoc(); - consume(); // consume DO - stack = parseDoLoop(stack, tokenLoc); - if (!stack) - return failure(); - } else { - Location tokenLoc = getLoc(); - Value newStack = emitOperation(currentToken.text, stack, tokenLoc); - if (!newStack) - return emitError("unknown word: " + currentToken.text); - stack = newStack; + Location loc = getLoc(); + StringRef word = currentToken.text; + + //=== IF === + if (word == "IF") { consume(); - } - } else { - return emitError("unexpected token"); - } - } + Region *parentRegion = builder.getInsertionBlock()->getParent(); - return success(); -} + auto [s1, flag] = emitPopFlag(loc, stack); -//===----------------------------------------------------------------------===// -// IF / ELSE / THEN parsing. -//===----------------------------------------------------------------------===// + auto *thenBlock = createStackBlock(parentRegion, loc); + auto *joinBlock = createStackBlock(parentRegion, loc); -Value ForthParser::parseIf(Value inputStack, Location loc) { - Type stackType = forth::StackType::get(context); + // Branch: true -> then, false -> join. + builder.create(loc, flag, thenBlock, ValueRange{s1}, + joinBlock, ValueRange{s1}); - // Create forth.if op. inputStack has the flag on top. - // Regions capture inputStack from the enclosing scope (no block args). - // Each region starts with forth.drop to pop the flag. - auto ifOp = builder.create(loc, stackType, inputStack); - - auto isElseOrThen = [](StringRef word) { - return word == "ELSE" || word == "THEN"; - }; - auto isThen = [](StringRef word) { return word == "THEN"; }; - - // --- Then region --- - Block *thenBlock = new Block(); - thenBlock->addArgument(stackType, loc); - ifOp.getThenRegion().push_back(thenBlock); - - builder.setInsertionPointToStart(thenBlock); - Value thenArg = thenBlock->getArgument(0); - // Drop the flag from the block arg. - Value thenStack = - builder.create(loc, stackType, thenArg).getResult(); - if (failed(parseBody(thenStack, isElseOrThen))) - return nullptr; - builder.create(getLoc(), thenStack, /*while_cond=*/nullptr); - - // --- Else region --- - Block *elseBlock = new Block(); - elseBlock->addArgument(stackType, loc); - ifOp.getElseRegion().push_back(elseBlock); - - if (currentToken.kind == Token::Kind::Word && currentToken.text == "ELSE") { - consume(); // consume ELSE - builder.setInsertionPointToStart(elseBlock); - Value elseArg = elseBlock->getArgument(0); - Value elseStack = - builder.create(loc, stackType, elseArg).getResult(); - if (failed(parseBody(elseStack, isThen))) - return nullptr; - builder.create(getLoc(), elseStack, /*while_cond=*/nullptr); - } else { - // No ELSE clause — just drop the flag and yield (identity). - builder.setInsertionPointToStart(elseBlock); - Value elseArg = elseBlock->getArgument(0); - Value elseStack = - builder.create(loc, stackType, elseArg).getResult(); - builder.create(loc, elseStack, /*while_cond=*/nullptr); - } + // Push join block for THEN/ELSE to pick up. + cfStack.push_back({CFTag::Orig, joinBlock}); - // Consume THEN. - if (currentToken.kind != Token::Kind::Word || currentToken.text != "THEN") { - (void)emitError("expected 'THEN'"); - return nullptr; - } - consume(); // consume THEN + // Continue parsing in then block. + builder.setInsertionPointToStart(thenBlock); + stack = thenBlock->getArgument(0); - // Restore insertion point to after the forth.if op. - builder.setInsertionPointAfter(ifOp); - return ifOp.getOutputStack(); -} + //=== ELSE === + } else if (word == "ELSE") { + consume(); + Region *parentRegion = builder.getInsertionBlock()->getParent(); -//===----------------------------------------------------------------------===// -// BEGIN / UNTIL parsing. -//===----------------------------------------------------------------------===// + auto *mergeBlock = createStackBlock(parentRegion, loc); -Value ForthParser::parseBeginUntil(Value inputStack, Location loc) { - Type stackType = forth::StackType::get(context); + // End of then-body: branch to merge. + builder.create(loc, mergeBlock, ValueRange{stack}); - // Create forth.begin_until op. - auto beginUntilOp = - builder.create(loc, stackType, inputStack); + // Pop the false-path block (from IF) - this becomes else-body start. + auto [tag, joinBlock] = cfStack.pop_back_val(); - auto isUntil = [](StringRef word) { return word == "UNTIL"; }; + // Push merge block for THEN to pick up. + cfStack.push_back({CFTag::Orig, mergeBlock}); - // --- Body region --- - Block *bodyBlock = new Block(); - bodyBlock->addArgument(stackType, loc); - beginUntilOp.getBodyRegion().push_back(bodyBlock); + // Continue parsing in the else (false-path) block. + builder.setInsertionPointToStart(joinBlock); + stack = joinBlock->getArgument(0); - builder.setInsertionPointToStart(bodyBlock); - Value bodyStack = bodyBlock->getArgument(0); - if (failed(parseBody(bodyStack, isUntil))) - return nullptr; - builder.create(getLoc(), bodyStack, /*while_cond=*/nullptr); + //=== THEN === + } else if (word == "THEN") { + consume(); - // Consume UNTIL. - if (currentToken.kind != Token::Kind::Word || currentToken.text != "UNTIL") { - (void)emitError("expected 'UNTIL'"); - return nullptr; - } - consume(); // consume UNTIL + // Pop the join/merge block. + auto [tag, joinBlock] = cfStack.pop_back_val(); - // Restore insertion point to after the forth.begin_until op. - builder.setInsertionPointAfter(beginUntilOp); - return beginUntilOp.getOutputStack(); -} + // Branch from current block to join. + builder.create(loc, joinBlock, ValueRange{stack}); -//===----------------------------------------------------------------------===// -// BEGIN / WHILE / REPEAT lookahead + parsing. -//===----------------------------------------------------------------------===// + // Continue parsing after the join. + builder.setInsertionPointToStart(joinBlock); + stack = joinBlock->getArgument(0); -bool ForthParser::isWhileLoop() { - // Save lexer position and current token. - const char *savedPos = lexer.getPosition(); - Token savedToken = currentToken; + //=== BEGIN === + } else if (word == "BEGIN") { + consume(); + Region *parentRegion = builder.getInsertionBlock()->getParent(); - int depth = 0; - while (currentToken.kind != Token::Kind::EndOfFile) { - if (currentToken.kind == Token::Kind::Word) { - if (currentToken.text == "BEGIN" || currentToken.text == "DO") - ++depth; - else if (depth == 0 && currentToken.text == "UNTIL") { - // Found UNTIL at our nesting level → not a WHILE loop. - lexer.setPosition(savedPos); - currentToken = savedToken; - return false; - } else if (depth == 0 && currentToken.text == "WHILE") { - // Found WHILE at our nesting level → is a WHILE loop. - lexer.setPosition(savedPos); - currentToken = savedToken; - return true; - } else if (currentToken.text == "UNTIL" || currentToken.text == "LOOP" || - currentToken.text == "REPEAT") - --depth; - } - consume(); - } + auto *loopBlock = createStackBlock(parentRegion, loc); - // Reached EOF without finding UNTIL or WHILE — restore and return false. - lexer.setPosition(savedPos); - currentToken = savedToken; - return false; -} + // Branch to loop header. + builder.create(loc, loopBlock, ValueRange{stack}); -Value ForthParser::parseBeginWhileRepeat(Value inputStack, Location loc) { - Type stackType = forth::StackType::get(context); + // Push loop header as backward reference. + cfStack.push_back({CFTag::Dest, loopBlock}); - // Create forth.begin_while_repeat op. - auto bwrOp = - builder.create(loc, stackType, inputStack); + // Continue parsing in loop body. + builder.setInsertionPointToStart(loopBlock); + stack = loopBlock->getArgument(0); - auto isWhile = [](StringRef word) { return word == "WHILE"; }; - auto isRepeat = [](StringRef word) { return word == "REPEAT"; }; + //=== UNTIL === + } else if (word == "UNTIL") { + consume(); + Region *parentRegion = builder.getInsertionBlock()->getParent(); - // --- Condition region --- - Block *condBlock = new Block(); - condBlock->addArgument(stackType, loc); - bwrOp.getConditionRegion().push_back(condBlock); + auto [s1, flag] = emitPopFlag(loc, stack); - builder.setInsertionPointToStart(condBlock); - Value condStack = condBlock->getArgument(0); - if (failed(parseBody(condStack, isWhile))) - return nullptr; - // Terminate with forth.yield {while_cond} to indicate WHILE semantics. - builder.create(getLoc(), condStack, - /*while_cond=*/builder.getUnitAttr()); + auto [tag, loopBlock] = cfStack.pop_back_val(); - // Consume WHILE. - if (currentToken.kind != Token::Kind::Word || currentToken.text != "WHILE") { - (void)emitError("expected 'WHILE'"); - return nullptr; - } - consume(); // consume WHILE + auto *exitBlock = createStackBlock(parentRegion, loc); - // --- Body region --- - Block *bodyBlock = new Block(); - bodyBlock->addArgument(stackType, loc); - bwrOp.getBodyRegion().push_back(bodyBlock); + // true -> exit, false -> loop back. + builder.create(loc, flag, exitBlock, ValueRange{s1}, + loopBlock, ValueRange{s1}); - builder.setInsertionPointToStart(bodyBlock); - Value bodyStack = bodyBlock->getArgument(0); - if (failed(parseBody(bodyStack, isRepeat))) - return nullptr; - builder.create(getLoc(), bodyStack, /*while_cond=*/nullptr); + // Continue after exit. + builder.setInsertionPointToStart(exitBlock); + stack = exitBlock->getArgument(0); - // Consume REPEAT. - if (currentToken.kind != Token::Kind::Word || currentToken.text != "REPEAT") { - (void)emitError("expected 'REPEAT'"); - return nullptr; - } - consume(); // consume REPEAT + //=== WHILE === + } else if (word == "WHILE") { + consume(); + Region *parentRegion = builder.getInsertionBlock()->getParent(); - // Restore insertion point to after the forth.begin_while_repeat op. - builder.setInsertionPointAfter(bwrOp); - return bwrOp.getOutputStack(); -} + auto [s1, flag] = emitPopFlag(loc, stack); -//===----------------------------------------------------------------------===// -// DO / LOOP parsing. -//===----------------------------------------------------------------------===// + auto [tag, loopBlock] = cfStack.pop_back_val(); -Value ForthParser::parseDoLoop(Value inputStack, Location loc) { - Type stackType = forth::StackType::get(context); + auto *bodyBlock = createStackBlock(parentRegion, loc); + auto *exitBlock = createStackBlock(parentRegion, loc); - // Create forth.do_loop op. - auto doLoopOp = builder.create(loc, stackType, inputStack); + // true -> body, false -> exit. + builder.create(loc, flag, bodyBlock, ValueRange{s1}, + exitBlock, ValueRange{s1}); - auto isLoop = [](StringRef word) { return word == "LOOP"; }; + // Push exit (forward ref) then loop header (backward ref). + cfStack.push_back({CFTag::Orig, exitBlock}); + cfStack.push_back({CFTag::Dest, loopBlock}); - // --- Body region --- - Block *bodyBlock = new Block(); - bodyBlock->addArgument(stackType, loc); - doLoopOp.getBodyRegion().push_back(bodyBlock); + // Continue parsing in body. + builder.setInsertionPointToStart(bodyBlock); + stack = bodyBlock->getArgument(0); - builder.setInsertionPointToStart(bodyBlock); - Value bodyStack = bodyBlock->getArgument(0); - ++doLoopDepth; - if (failed(parseBody(bodyStack, isLoop))) { - --doLoopDepth; - return nullptr; - } - --doLoopDepth; - builder.create(getLoc(), bodyStack, /*while_cond=*/nullptr); + //=== REPEAT === + } else if (word == "REPEAT") { + consume(); - // Consume LOOP. - if (currentToken.kind != Token::Kind::Word || currentToken.text != "LOOP") { - (void)emitError("expected 'LOOP'"); - return nullptr; + // Pop loop header (from WHILE's re-push). + auto [destTag, loopBlock] = cfStack.pop_back_val(); + + // Branch back to loop header. + builder.create(loc, loopBlock, ValueRange{stack}); + + // Pop exit block (from WHILE). + auto [origTag, exitBlock] = cfStack.pop_back_val(); + + // Continue after exit. + builder.setInsertionPointToStart(exitBlock); + stack = exitBlock->getArgument(0); + + //=== DO === + } else if (word == "DO") { + consume(); + Region *parentRegion = builder.getInsertionBlock()->getParent(); + auto i64Type = builder.getI64Type(); + + // Pop start and limit from the Forth stack. + auto popStart = + builder.create(loc, stackType, i64Type, stack); + Value s1 = popStart.getOutputStack(); + Value start = popStart.getValue(); + + auto popLimit = + builder.create(loc, stackType, i64Type, s1); + Value s2 = popLimit.getOutputStack(); + Value limit = popLimit.getValue(); + + // Allocate counter storage. + auto counterType = MemRefType::get({1}, i64Type); + Value counter = builder.create(loc, counterType); + Value c0 = builder.create(loc, 0); + builder.create(loc, start, counter, ValueRange{c0}); + + // Create check, body, and exit blocks. + auto *checkBlock = createStackBlock(parentRegion, loc); + auto *bodyBlock = createStackBlock(parentRegion, loc); + auto *exitBlock = createStackBlock(parentRegion, loc); + + // Branch to check. + builder.create(loc, checkBlock, ValueRange{s2}); + + // --- Check block: load counter, compare < limit --- + builder.setInsertionPointToStart(checkBlock); + Value checkC0 = builder.create(loc, 0); + Value idx = + builder.create(loc, counter, ValueRange{checkC0}); + Value cond = builder.create( + loc, arith::CmpIPredicate::slt, idx, limit); + builder.create( + loc, cond, bodyBlock, ValueRange{checkBlock->getArgument(0)}, + exitBlock, ValueRange{checkBlock->getArgument(0)}); + + // Push loop context for I/J/K. + loopStack.push_back({counter, limit, checkBlock, exitBlock}); + + // Continue parsing in body. + builder.setInsertionPointToStart(bodyBlock); + stack = bodyBlock->getArgument(0); + + //=== LOOP === + } else if (word == "LOOP") { + consume(); + auto i64Type = builder.getI64Type(); + + if (loopStack.empty()) { + return emitError("LOOP without matching DO"); + } + + auto ctx = loopStack.pop_back_val(); + + // Increment counter: load, add 1, store. + Value c0 = builder.create(loc, 0); + Value idx = + builder.create(loc, ctx.counter, ValueRange{c0}); + Value one = builder.create( + loc, i64Type, builder.getI64IntegerAttr(1)); + Value next = builder.create(loc, idx, one); + builder.create(loc, next, ctx.counter, ValueRange{c0}); + + // Branch back to check. + builder.create(loc, ctx.check, ValueRange{stack}); + + // Continue after exit. + builder.setInsertionPointToStart(ctx.exit); + stack = ctx.exit->getArgument(0); + + //=== Normal word === + } else { + Value newStack = emitOperation(currentToken.text, stack, loc); + if (!newStack) + return emitError("unknown word: " + currentToken.text); + stack = newStack; + consume(); + } + } else { + return emitError("unexpected token"); + } } - consume(); // consume LOOP - // Restore insertion point to after the forth.do_loop op. - builder.setInsertionPointAfter(doLoopOp); - return doLoopOp.getOutputStack(); + return success(); } //===----------------------------------------------------------------------===// @@ -706,15 +679,15 @@ LogicalResult ForthParser::parseWordDefinition() { builder.setInsertionPointToStart(entryBlock); // Parse word body until ';' - if (failed(parseBody(resultStack, [](StringRef) { return false; }))) + if (failed(parseBody(resultStack))) return failure(); if (currentToken.kind != Token::Kind::Semicolon) { return emitError("unterminated word definition: missing ';'"); } - // Add return - builder.setInsertionPointToEnd(entryBlock); + // Add return at current insertion point (may differ from entry block + // if the word body contains control flow). builder.create(loc, resultStack); // Register the word @@ -751,7 +724,7 @@ LogicalResult ForthParser::parseOperations(Value &stack) { } // parseBody handles numbers, words, and IF/ELSE/THEN. - if (failed(parseBody(stack, [](StringRef) { return false; }))) + if (failed(parseBody(stack))) return failure(); } @@ -826,9 +799,12 @@ OwningOpRef ForthParser::parseModule() { OwningOpRef forth::parseForthSource(llvm::SourceMgr &sourceMgr, MLIRContext *context) { - // Ensure the Forth dialect is loaded + // Ensure required dialects are loaded context->loadDialect(); context->loadDialect(); + context->loadDialect(); + context->loadDialect(); + context->loadDialect(); // Create parser and parse the module ForthParser parser(sourceMgr, context); @@ -843,6 +819,8 @@ void mlir::forth::registerForthToMLIRTranslation() { return forth::parseForthSource(sourceMgr, context); }, [](DialectRegistry ®istry) { - registry.insert(); + registry.insert(); }); } diff --git a/lib/Translation/ForthToMLIR/ForthToMLIR.h b/lib/Translation/ForthToMLIR/ForthToMLIR.h index 0369f0a..264175b 100644 --- a/lib/Translation/ForthToMLIR/ForthToMLIR.h +++ b/lib/Translation/ForthToMLIR/ForthToMLIR.h @@ -47,10 +47,6 @@ class ForthLexer { /// Reset lexer to beginning of buffer. void reset(); - /// Save/restore lexer position for lookahead. - const char *getPosition() const { return curPtr; } - void setPosition(const char *pos) { curPtr = pos; } - private: llvm::SourceMgr &sourceMgr; unsigned bufferID; @@ -84,7 +80,22 @@ class ForthParser { std::unordered_set wordDefs; std::vector paramDecls; bool inWordDefinition = false; - int doLoopDepth = 0; + + /// Control flow stack tag: Orig = forward reference (IF->THEN, WHILE->exit), + /// Dest = backward reference (BEGIN->loop header). + enum class CFTag { Orig, Dest }; + + /// Control flow stack for IF/ELSE/THEN, BEGIN/UNTIL, BEGIN/WHILE/REPEAT. + SmallVector> cfStack; + + /// Loop context for DO/LOOP with I/J/K support. + struct LoopContext { + Value counter; // memref<1xi64> alloca for the loop counter + Value limit; // i64 loop limit + Block *check; // condition check block + Block *exit; // loop exit block + }; + SmallVector loopStack; /// Scan for `param ` declarations (pre-pass). void scanParamDeclarations(); @@ -105,26 +116,16 @@ class ForthParser { /// Returns the updated stack value or nullptr on error. Value emitOperation(StringRef word, Value inputStack, Location loc); - /// Parse a sequence of Forth operations until a stop word is hit. - /// The stop word is NOT consumed. Returns the final stack value. - LogicalResult parseBody(Value &stack, - llvm::function_ref isStopWord); - - /// Parse an IF/ELSE/THEN construct, creating a forth.if op. - Value parseIf(Value inputStack, Location loc); - - /// Parse a BEGIN/UNTIL loop, creating a forth.begin_until op. - Value parseBeginUntil(Value inputStack, Location loc); - - /// Parse a BEGIN/WHILE/REPEAT loop, creating a forth.begin_while_repeat op. - Value parseBeginWhileRepeat(Value inputStack, Location loc); + /// Create a new block with a single !forth.stack argument, appended to + /// the given region. + Block *createStackBlock(Region *region, Location loc); - /// Lookahead: is the current BEGIN a WHILE loop (vs UNTIL)? - /// Saves and restores lexer position. - bool isWhileLoop(); + /// Pop a flag from the stack: emits forth.pop_flag and returns + /// the updated stack and the i1 flag value. + std::pair emitPopFlag(Location loc, Value stack); - /// Parse a DO/LOOP counted loop, creating a forth.do_loop op. - Value parseDoLoop(Value inputStack, Location loc); + /// Parse a sequence of Forth operations, handling control flow inline. + LogicalResult parseBody(Value &stack); /// Parse a user-defined word definition. LogicalResult parseWordDefinition(); diff --git a/test/Conversion/ForthToMemRef/begin-until.mlir b/test/Conversion/ForthToMemRef/begin-until.mlir index 7ecb7bc..1df7aeb 100644 --- a/test/Conversion/ForthToMemRef/begin-until.mlir +++ b/test/Conversion/ForthToMemRef/begin-until.mlir @@ -1,36 +1,46 @@ // RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s +// Test: BEGIN...UNTIL loop conversion to memref with CF-based control flow +// Forth: 10 BEGIN 1 - DUP 0= UNTIL + // CHECK-LABEL: func.func private @main -// Verify scf.while with index iter arg: -// CHECK: scf.while (%{{.*}} = %{{.*}}) : (index) -> index { +// Stack allocation and literal 10 push: +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi64> +// CHECK: %[[C10:.*]] = arith.constant 10 : i64 +// CHECK: memref.store %[[C10]], %[[ALLOCA]] +// CHECK: cf.br ^bb1 -// Body: operations + flag pop + condition -// CHECK: arith.subi -// CHECK: arith.cmpi eq -// CHECK: arith.extsi -// CHECK: memref.load -// CHECK: arith.subi -// CHECK: arith.cmpi eq -// CHECK: scf.condition(%{{.*}}) %{{.*}} : index +// Loop body: push 1, subtract, dup, zero_eq, pop_flag, cond_br +// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 1 : i64 +// CHECK: memref.store +// CHECK: arith.subi +// CHECK: memref.store +// CHECK: memref.store +// CHECK: arith.cmpi eq +// CHECK: arith.extsi +// CHECK: memref.store +// CHECK: arith.cmpi ne +// CHECK: cf.cond_br %{{.*}}, ^bb2(%{{.*}}: memref<256xi64>, index), ^bb1(%{{.*}}: memref<256xi64>, index) -// After region: identity yield -// CHECK: } do { -// CHECK: scf.yield %{{.*}} : index -// CHECK: } +// Exit block +// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: return module { func.func private @main() { %0 = forth.stack !forth.stack %1 = forth.literal %0 10 : !forth.stack -> !forth.stack - %2 = forth.begin_until %1 : !forth.stack -> !forth.stack { - ^bb0(%arg0: !forth.stack): - %3 = forth.literal %arg0 1 : !forth.stack -> !forth.stack - %4 = forth.sub %3 : !forth.stack -> !forth.stack - %5 = forth.dup %4 : !forth.stack -> !forth.stack - %6 = forth.zero_eq %5 : !forth.stack -> !forth.stack - forth.yield %6 : !forth.stack - } + cf.br ^bb1(%1 : !forth.stack) + ^bb1(%2: !forth.stack): + %3 = forth.literal %2 1 : !forth.stack -> !forth.stack + %4 = forth.sub %3 : !forth.stack -> !forth.stack + %5 = forth.dup %4 : !forth.stack -> !forth.stack + %6 = forth.zero_eq %5 : !forth.stack -> !forth.stack + %output_stack, %flag = forth.pop_flag %6 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag, ^bb2(%output_stack : !forth.stack), ^bb1(%output_stack : !forth.stack) + ^bb2(%7: !forth.stack): return } } diff --git a/test/Conversion/ForthToMemRef/begin-while-repeat.mlir b/test/Conversion/ForthToMemRef/begin-while-repeat.mlir index 58978bc..ba23318 100644 --- a/test/Conversion/ForthToMemRef/begin-while-repeat.mlir +++ b/test/Conversion/ForthToMemRef/begin-while-repeat.mlir @@ -1,44 +1,56 @@ // RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s +// Test: BEGIN...WHILE...REPEAT loop conversion to memref with CF-based control flow +// Forth: 10 BEGIN DUP 0 > WHILE 1 - REPEAT + // CHECK-LABEL: func.func private @main -// Verify scf.while with index iter arg: -// CHECK: scf.while (%{{.*}} = %{{.*}}) : (index) -> index { +// Stack allocation and literal 10 push: +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi64> +// CHECK: %[[C10:.*]] = arith.constant 10 : i64 +// CHECK: memref.store %[[C10]], %[[ALLOCA]] +// CHECK: cf.br ^bb1 + +// Condition block: DUP, push 0, compare >, pop_flag, cond_br +// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load +// CHECK: memref.store +// CHECK: arith.constant 0 : i64 +// CHECK: memref.store +// CHECK: arith.cmpi sgt +// CHECK: arith.extsi +// CHECK: memref.store +// CHECK: arith.cmpi ne +// CHECK: cf.cond_br %{{.*}}, ^bb2(%{{.*}}: memref<256xi64>, index), ^bb3(%{{.*}}: memref<256xi64>, index) -// Condition region: operations + flag pop + condition (ne for WHILE) -// CHECK: memref.load -// CHECK: arith.cmpi sgt -// CHECK: arith.extsi -// CHECK: memref.load -// CHECK: arith.subi -// CHECK: arith.cmpi ne -// CHECK: scf.condition(%{{.*}}) %{{.*}} : index +// Body block: push 1, subtract, branch back to condition +// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 1 : i64 +// CHECK: memref.store +// CHECK: arith.subi +// CHECK: memref.store +// CHECK: cf.br ^bb1 -// Body region: operations + yield -// CHECK: } do { -// CHECK: arith.addi -// CHECK: memref.store -// CHECK: memref.load -// CHECK: arith.subi -// CHECK: scf.yield %{{.*}} : index -// CHECK: } +// Exit block +// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: return module { func.func private @main() { %0 = forth.stack !forth.stack %1 = forth.literal %0 10 : !forth.stack -> !forth.stack - %2 = forth.begin_while_repeat %1 : !forth.stack -> !forth.stack { - ^bb0(%arg0: !forth.stack): - %3 = forth.dup %arg0 : !forth.stack -> !forth.stack - %4 = forth.literal %3 0 : !forth.stack -> !forth.stack - %5 = forth.gt %4 : !forth.stack -> !forth.stack - forth.yield %5 while_cond : !forth.stack - } do { - ^bb0(%arg1: !forth.stack): - %6 = forth.literal %arg1 1 : !forth.stack -> !forth.stack - %7 = forth.sub %6 : !forth.stack -> !forth.stack - forth.yield %7 : !forth.stack - } + cf.br ^bb1(%1 : !forth.stack) + ^bb1(%2: !forth.stack): + %3 = forth.dup %2 : !forth.stack -> !forth.stack + %4 = forth.literal %3 0 : !forth.stack -> !forth.stack + %5 = forth.gt %4 : !forth.stack -> !forth.stack + %output_stack, %flag = forth.pop_flag %5 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag, ^bb2(%output_stack : !forth.stack), ^bb3(%output_stack : !forth.stack) + ^bb2(%6: !forth.stack): + %7 = forth.literal %6 1 : !forth.stack -> !forth.stack + %8 = forth.sub %7 : !forth.stack -> !forth.stack + cf.br ^bb1(%8 : !forth.stack) + ^bb3(%9: !forth.stack): return } } diff --git a/test/Conversion/ForthToMemRef/control-flow.mlir b/test/Conversion/ForthToMemRef/control-flow.mlir index 8c77c64..121c0a8 100644 --- a/test/Conversion/ForthToMemRef/control-flow.mlir +++ b/test/Conversion/ForthToMemRef/control-flow.mlir @@ -1,46 +1,69 @@ // RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s +// Test: IF/ELSE/THEN and IF/THEN conversion to memref with CF-based control flow +// Forth: 1 IF 42 ELSE 99 THEN 0 IF 7 THEN + // CHECK-LABEL: func.func private @main -// Verify flag load and condition: -// CHECK: %[[FLAG:.*]] = memref.load %{{.*}}[%[[SP:.*]]] -// CHECK: %[[ZERO:.*]] = arith.constant 0 : i64 -// CHECK: %[[COND:.*]] = arith.cmpi ne, %[[FLAG]], %[[ZERO]] : i64 - -// Verify scf.if with index result: -// CHECK: scf.if %[[COND]] -> (index) { - -// Then branch: drop (subi) + literal push + yield -// CHECK: arith.subi -// CHECK: arith.constant 42 : i64 -// CHECK: arith.addi -// CHECK: memref.store -// CHECK: scf.yield %{{.*}} : index - -// Else branch: drop (subi) + literal push + yield -// CHECK: } else { -// CHECK: arith.subi -// CHECK: arith.constant 99 : i64 -// CHECK: arith.addi -// CHECK: memref.store -// CHECK: scf.yield %{{.*}} : index -// CHECK: } +// Stack allocation and literal 1 push: +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi64> +// CHECK: %[[C1:.*]] = arith.constant 1 : i64 +// CHECK: memref.store %[[C1]], %[[ALLOCA]] + +// Pop flag and conditional branch: +// CHECK: %[[FLAG1:.*]] = memref.load +// CHECK: %[[ZERO1:.*]] = arith.constant 0 : i64 +// CHECK: %[[COND1:.*]] = arith.cmpi ne, %[[FLAG1]], %[[ZERO1]] : i64 +// CHECK: cf.cond_br %[[COND1]], ^bb1(%[[ALLOCA]], %{{.*}} : memref<256xi64>, index), ^bb2(%[[ALLOCA]], %{{.*}} : memref<256xi64>, index) + +// Then branch: push 42 +// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 42 : i64 +// CHECK: memref.store +// CHECK: cf.br ^bb3 + +// Else branch: push 99 +// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 99 : i64 +// CHECK: memref.store +// CHECK: cf.br ^bb3 + +// Merge block: push 0, pop flag, second conditional branch +// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 0 : i64 +// CHECK: memref.store +// CHECK: arith.cmpi ne +// CHECK: cf.cond_br + +// Second IF true branch: push 7 +// CHECK: ^bb4(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 7 : i64 +// CHECK: memref.store + +// Final merge and return +// CHECK: ^bb5(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: return module { func.func private @main() { %0 = forth.stack !forth.stack %1 = forth.literal %0 1 : !forth.stack -> !forth.stack - %2 = forth.if %1 : !forth.stack -> !forth.stack { - ^bb0(%arg0: !forth.stack): - %3 = forth.drop %arg0 : !forth.stack -> !forth.stack - %4 = forth.literal %3 42 : !forth.stack -> !forth.stack - forth.yield %4 : !forth.stack - } else { - ^bb0(%arg0: !forth.stack): - %3 = forth.drop %arg0 : !forth.stack -> !forth.stack - %4 = forth.literal %3 99 : !forth.stack -> !forth.stack - forth.yield %4 : !forth.stack - } + %output_stack, %flag = forth.pop_flag %1 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) + ^bb1(%2: !forth.stack): + %3 = forth.literal %2 42 : !forth.stack -> !forth.stack + cf.br ^bb3(%3 : !forth.stack) + ^bb2(%4: !forth.stack): + %5 = forth.literal %4 99 : !forth.stack -> !forth.stack + cf.br ^bb3(%5 : !forth.stack) + ^bb3(%6: !forth.stack): + %7 = forth.literal %6 0 : !forth.stack -> !forth.stack + %output_stack_0, %flag_1 = forth.pop_flag %7 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag_1, ^bb4(%output_stack_0 : !forth.stack), ^bb5(%output_stack_0 : !forth.stack) + ^bb4(%8: !forth.stack): + %9 = forth.literal %8 7 : !forth.stack -> !forth.stack + cf.br ^bb5(%9 : !forth.stack) + ^bb5(%10: !forth.stack): return } } diff --git a/test/Conversion/ForthToMemRef/do-loop.mlir b/test/Conversion/ForthToMemRef/do-loop.mlir index 4957189..62fa347 100644 --- a/test/Conversion/ForthToMemRef/do-loop.mlir +++ b/test/Conversion/ForthToMemRef/do-loop.mlir @@ -1,35 +1,75 @@ // RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s +// Test: DO...LOOP with I conversion to memref with CF-based control flow +// Forth: 10 0 DO I LOOP + // CHECK-LABEL: func.func private @main -// Verify start and limit are popped from stack: +// Stack allocation and push 10, 0: +// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<256xi64> +// CHECK: arith.constant 10 : i64 +// CHECK: memref.store +// CHECK: arith.constant 0 : i64 +// CHECK: memref.store + +// Pop start and limit from stack: // CHECK: memref.load // CHECK: arith.subi // CHECK: memref.load // CHECK: arith.subi -// Verify scf.for with index bounds and iter arg: -// CHECK: arith.index_cast -// CHECK: arith.index_cast -// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index) { +// Loop counter alloca and initialization: +// CHECK: %[[COUNTER:.*]] = memref.alloca() : memref<1xi64> +// CHECK: memref.store %{{.*}}, %[[COUNTER]] +// CHECK: cf.br ^bb1 + +// Loop header: load counter, compare < limit, cond_br +// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[COUNTER]] +// CHECK: arith.cmpi slt +// CHECK: cf.cond_br %{{.*}}, ^bb2(%{{.*}}: memref<256xi64>, index), ^bb3(%{{.*}}: memref<256xi64>, index) + +// Loop body: push I (load counter, push to stack), increment counter +// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[COUNTER]] +// CHECK: memref.store +// CHECK: memref.load %[[COUNTER]] +// CHECK: arith.constant 1 : i64 +// CHECK: arith.addi +// CHECK: memref.store %{{.*}}, %[[COUNTER]] +// CHECK: cf.br ^bb1 -// Verify loop_index pushes induction variable: -// CHECK: arith.index_cast -// CHECK: arith.addi -// CHECK: memref.store -// CHECK: scf.yield %{{.*}} : index -// CHECK: } +// Exit block +// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: return module { func.func private @main() { %0 = forth.stack !forth.stack %1 = forth.literal %0 10 : !forth.stack -> !forth.stack %2 = forth.literal %1 0 : !forth.stack -> !forth.stack - %3 = forth.do_loop %2 : !forth.stack -> !forth.stack { - ^bb0(%arg0: !forth.stack): - %4 = forth.loop_index %arg0 : !forth.stack -> !forth.stack - forth.yield %4 : !forth.stack - } + %output_stack, %value = forth.pop %2 : !forth.stack -> !forth.stack, i64 + %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 + %alloca = memref.alloca() : memref<1xi64> + %c0 = arith.constant 0 : index + memref.store %value, %alloca[%c0] : memref<1xi64> + cf.br ^bb1(%output_stack_0 : !forth.stack) + ^bb1(%3: !forth.stack): + %c0_2 = arith.constant 0 : index + %4 = memref.load %alloca[%c0_2] : memref<1xi64> + %5 = arith.cmpi slt, %4, %value_1 : i64 + cf.cond_br %5, ^bb2(%3 : !forth.stack), ^bb3(%3 : !forth.stack) + ^bb2(%6: !forth.stack): + %c0_3 = arith.constant 0 : index + %7 = memref.load %alloca[%c0_3] : memref<1xi64> + %8 = forth.push_value %6, %7 : !forth.stack, i64 -> !forth.stack + %c0_4 = arith.constant 0 : index + %9 = memref.load %alloca[%c0_4] : memref<1xi64> + %c1_i64 = arith.constant 1 : i64 + %10 = arith.addi %9, %c1_i64 : i64 + memref.store %10, %alloca[%c0_4] : memref<1xi64> + cf.br ^bb1(%8 : !forth.stack) + ^bb3(%11: !forth.stack): return } } diff --git a/test/Conversion/ForthToMemRef/nested-control-flow.mlir b/test/Conversion/ForthToMemRef/nested-control-flow.mlir index f7face6..4ac0dab 100644 --- a/test/Conversion/ForthToMemRef/nested-control-flow.mlir +++ b/test/Conversion/ForthToMemRef/nested-control-flow.mlir @@ -1,149 +1,290 @@ // RUN: %warpforth-opt --convert-forth-to-memref %s | %FileCheck %s -// === Nested IF → nested scf.if === -// CHECK-LABEL: func.func private @test_nested_if -// CHECK: scf.if %{{.*}} -> (index) { -// CHECK: scf.if %{{.*}} -> (index) { -// CHECK: scf.yield -// CHECK: } else { -// CHECK: scf.yield -// CHECK: } -// CHECK: scf.yield -// CHECK: } else { -// CHECK: scf.yield -// CHECK: } +// === Nested IF: 1 IF 2 IF 3 THEN THEN === +// CHECK-LABEL: func.func private @TEST__NESTED__IF +// CHECK: arith.constant 1 : i64 +// CHECK: memref.store +// CHECK: arith.cmpi ne +// CHECK: cf.cond_br %{{.*}}, ^bb1({{.*}}), ^bb2({{.*}}) + +// Inner IF: push 2, pop_flag, cond_br +// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 2 : i64 +// CHECK: memref.store +// CHECK: arith.cmpi ne +// CHECK: cf.cond_br %{{.*}}, ^bb3({{.*}}), ^bb4({{.*}}) + +// Outer merge -> return +// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: return + +// Inner true: push 3 +// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 3 : i64 +// CHECK: memref.store +// CHECK: cf.br ^bb4 + +// Inner merge -> outer merge +// CHECK: ^bb4(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: cf.br ^bb2 + +// === IF inside DO: 10 0 DO I 5 > IF I THEN LOOP === +// CHECK-LABEL: func.func private @TEST__IF__INSIDE__DO + +// DO loop setup: pop start/limit, alloca counter +// CHECK: arith.constant 10 : i64 +// CHECK: arith.constant 0 : i64 +// CHECK: %[[COUNTER1:.*]] = memref.alloca() : memref<1xi64> +// CHECK: memref.store %{{.*}}, %[[COUNTER1]] +// CHECK: cf.br ^bb1 + +// DO loop header: check counter < limit +// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[COUNTER1]] +// CHECK: arith.cmpi slt +// CHECK: cf.cond_br + +// DO loop body: push I, push 5, compare >, pop_flag, cond_br (IF) +// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[COUNTER1]] +// CHECK: memref.store +// CHECK: arith.constant 5 : i64 +// CHECK: arith.cmpi sgt +// CHECK: arith.cmpi ne +// CHECK: cf.cond_br + +// DO loop exit -> return +// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: return + +// IF true: push I +// CHECK: ^bb4(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[COUNTER1]] +// CHECK: memref.store +// CHECK: cf.br ^bb5 + +// IF merge: increment counter, loop back +// CHECK: ^bb5(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[COUNTER1]] +// CHECK: arith.addi +// CHECK: memref.store %{{.*}}, %[[COUNTER1]] +// CHECK: cf.br ^bb1 + +// === Nested DO with J: 3 0 DO 4 0 DO J I + LOOP LOOP === +// CHECK-LABEL: func.func private @TEST__NESTED__DO__J + +// Outer DO setup +// CHECK: arith.constant 3 : i64 +// CHECK: arith.constant 0 : i64 +// CHECK: %[[OUTER:.*]] = memref.alloca() : memref<1xi64> +// CHECK: memref.store %{{.*}}, %[[OUTER]] +// CHECK: cf.br ^bb1 + +// Outer loop header: check outer counter < outer limit +// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[OUTER]] +// CHECK: arith.cmpi slt +// CHECK: cf.cond_br + +// Outer loop body: inner DO setup (4 0 DO) +// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 4 : i64 +// CHECK: arith.constant 0 : i64 +// CHECK: %[[INNER:.*]] = memref.alloca() : memref<1xi64> +// CHECK: memref.store %{{.*}}, %[[INNER]] +// CHECK: cf.br ^bb4 + +// Outer loop exit -> return +// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: return + +// Inner loop header: check inner counter < inner limit +// CHECK: ^bb4(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[INNER]] +// CHECK: arith.cmpi slt +// CHECK: cf.cond_br + +// Inner loop body: J (load outer counter), I (load inner counter), add +// CHECK: ^bb5(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[OUTER]] +// CHECK: memref.store +// CHECK: memref.load %[[INNER]] +// CHECK: memref.store +// CHECK: arith.addi +// CHECK: memref.store +// CHECK: memref.load %[[INNER]] +// CHECK: arith.addi +// CHECK: memref.store %{{.*}}, %[[INNER]] +// CHECK: cf.br ^bb4 + +// Inner loop exit -> increment outer counter, loop back +// CHECK: ^bb6(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load %[[OUTER]] +// CHECK: arith.addi +// CHECK: memref.store %{{.*}}, %[[OUTER]] +// CHECK: cf.br ^bb1 + +// === BEGIN/WHILE/REPEAT inside IF: 5 IF BEGIN DUP WHILE 1 - REPEAT THEN === +// CHECK-LABEL: func.func private @TEST__WHILE__INSIDE__IF + +// Push 5, pop_flag, cond_br (IF) +// CHECK: arith.constant 5 : i64 +// CHECK: memref.store +// CHECK: arith.cmpi ne +// CHECK: cf.cond_br %{{.*}}, ^bb1({{.*}}), ^bb2({{.*}}) + +// IF true -> jump to WHILE header +// CHECK: ^bb1(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: cf.br ^bb3 + +// IF false / WHILE exit merge -> return +// CHECK: ^bb2(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: return + +// WHILE condition: DUP, pop_flag, cond_br +// CHECK: ^bb3(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: memref.load +// CHECK: memref.store +// CHECK: arith.cmpi ne +// CHECK: cf.cond_br %{{.*}}, ^bb4({{.*}}), ^bb5({{.*}}) + +// WHILE body: push 1, subtract, loop back +// CHECK: ^bb4(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: arith.constant 1 : i64 +// CHECK: arith.subi +// CHECK: cf.br ^bb3 + +// WHILE exit -> merge with IF false +// CHECK: ^bb5(%{{.*}}: memref<256xi64>, %{{.*}}: index): +// CHECK: cf.br ^bb2 module { - func.func private @test_nested_if() { - %0 = forth.stack !forth.stack - %1 = forth.literal %0 1 : !forth.stack -> !forth.stack - %2 = forth.if %1 : !forth.stack -> !forth.stack { - ^bb0(%arg0: !forth.stack): - %3 = forth.drop %arg0 : !forth.stack -> !forth.stack - %4 = forth.literal %3 1 : !forth.stack -> !forth.stack - %5 = forth.if %4 : !forth.stack -> !forth.stack { - ^bb0(%arg1: !forth.stack): - %6 = forth.drop %arg1 : !forth.stack -> !forth.stack - %7 = forth.literal %6 42 : !forth.stack -> !forth.stack - forth.yield %7 : !forth.stack - } else { - ^bb0(%arg1: !forth.stack): - %6 = forth.drop %arg1 : !forth.stack -> !forth.stack - forth.yield %6 : !forth.stack - } - forth.yield %5 : !forth.stack - } else { - ^bb0(%arg0: !forth.stack): - %3 = forth.drop %arg0 : !forth.stack -> !forth.stack - forth.yield %3 : !forth.stack - } - return + func.func private @TEST__NESTED__IF(%arg0: !forth.stack) -> !forth.stack { + %0 = forth.literal %arg0 1 : !forth.stack -> !forth.stack + %output_stack, %flag = forth.pop_flag %0 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) + ^bb1(%1: !forth.stack): + %2 = forth.literal %1 2 : !forth.stack -> !forth.stack + %output_stack_0, %flag_1 = forth.pop_flag %2 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag_1, ^bb3(%output_stack_0 : !forth.stack), ^bb4(%output_stack_0 : !forth.stack) + ^bb2(%3: !forth.stack): + return %3 : !forth.stack + ^bb3(%4: !forth.stack): + %5 = forth.literal %4 3 : !forth.stack -> !forth.stack + cf.br ^bb4(%5 : !forth.stack) + ^bb4(%6: !forth.stack): + cf.br ^bb2(%6 : !forth.stack) } - - // === IF inside DO/LOOP → scf.if inside scf.for === - // CHECK-LABEL: func.func private @test_if_inside_do - // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index) { - // CHECK: scf.if %{{.*}} -> (index) { - // CHECK: scf.yield - // CHECK: } else { - // CHECK: scf.yield - // CHECK: } - // CHECK: scf.yield - // CHECK: } - - func.func private @test_if_inside_do() { - %0 = forth.stack !forth.stack - %1 = forth.literal %0 10 : !forth.stack -> !forth.stack - %2 = forth.literal %1 0 : !forth.stack -> !forth.stack - %3 = forth.do_loop %2 : !forth.stack -> !forth.stack { - ^bb0(%arg0: !forth.stack): - %4 = forth.loop_index %arg0 : !forth.stack -> !forth.stack - %5 = forth.literal %4 5 : !forth.stack -> !forth.stack - %6 = forth.gt %5 : !forth.stack -> !forth.stack - %7 = forth.if %6 : !forth.stack -> !forth.stack { - ^bb0(%arg1: !forth.stack): - %8 = forth.drop %arg1 : !forth.stack -> !forth.stack - %9 = forth.literal %8 99 : !forth.stack -> !forth.stack - forth.yield %9 : !forth.stack - } else { - ^bb0(%arg1: !forth.stack): - %8 = forth.drop %arg1 : !forth.stack -> !forth.stack - forth.yield %8 : !forth.stack - } - forth.yield %7 : !forth.stack - } - return + func.func private @TEST__IF__INSIDE__DO(%arg0: !forth.stack) -> !forth.stack { + %0 = forth.literal %arg0 10 : !forth.stack -> !forth.stack + %1 = forth.literal %0 0 : !forth.stack -> !forth.stack + %output_stack, %value = forth.pop %1 : !forth.stack -> !forth.stack, i64 + %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 + %alloca = memref.alloca() : memref<1xi64> + %c0 = arith.constant 0 : index + memref.store %value, %alloca[%c0] : memref<1xi64> + cf.br ^bb1(%output_stack_0 : !forth.stack) + ^bb1(%2: !forth.stack): + %c0_2 = arith.constant 0 : index + %3 = memref.load %alloca[%c0_2] : memref<1xi64> + %4 = arith.cmpi slt, %3, %value_1 : i64 + cf.cond_br %4, ^bb2(%2 : !forth.stack), ^bb3(%2 : !forth.stack) + ^bb2(%5: !forth.stack): + %c0_3 = arith.constant 0 : index + %6 = memref.load %alloca[%c0_3] : memref<1xi64> + %7 = forth.push_value %5, %6 : !forth.stack, i64 -> !forth.stack + %8 = forth.literal %7 5 : !forth.stack -> !forth.stack + %9 = forth.gt %8 : !forth.stack -> !forth.stack + %output_stack_4, %flag = forth.pop_flag %9 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag, ^bb4(%output_stack_4 : !forth.stack), ^bb5(%output_stack_4 : !forth.stack) + ^bb3(%10: !forth.stack): + return %10 : !forth.stack + ^bb4(%11: !forth.stack): + %c0_5 = arith.constant 0 : index + %12 = memref.load %alloca[%c0_5] : memref<1xi64> + %13 = forth.push_value %11, %12 : !forth.stack, i64 -> !forth.stack + cf.br ^bb5(%13 : !forth.stack) + ^bb5(%14: !forth.stack): + %c0_6 = arith.constant 0 : index + %15 = memref.load %alloca[%c0_6] : memref<1xi64> + %c1_i64 = arith.constant 1 : i64 + %16 = arith.addi %15, %c1_i64 : i64 + memref.store %16, %alloca[%c0_6] : memref<1xi64> + cf.br ^bb1(%14 : !forth.stack) } - - // === Nested DO/LOOP with J (depth=1) === - // CHECK-LABEL: func.func private @test_nested_do_j - // Outer scf.for - // CHECK: scf.for %[[OUTER_IV:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index) { - // Inner scf.for - // CHECK: scf.for %[[INNER_IV:.*]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (index) { - // J pushes outer IV, I pushes inner IV - // CHECK: arith.index_cast %[[OUTER_IV]] - // CHECK: arith.index_cast %[[INNER_IV]] - // CHECK: scf.yield - // CHECK: } - // CHECK: scf.yield - // CHECK: } - - func.func private @test_nested_do_j() { - %0 = forth.stack !forth.stack - %1 = forth.literal %0 3 : !forth.stack -> !forth.stack - %2 = forth.literal %1 0 : !forth.stack -> !forth.stack - %3 = forth.do_loop %2 : !forth.stack -> !forth.stack { - ^bb0(%arg0: !forth.stack): - %4 = forth.literal %arg0 4 : !forth.stack -> !forth.stack - %5 = forth.literal %4 0 : !forth.stack -> !forth.stack - %6 = forth.do_loop %5 : !forth.stack -> !forth.stack { - ^bb0(%arg1: !forth.stack): - %7 = forth.loop_index %arg1 {depth = 1 : i64} : !forth.stack -> !forth.stack - %8 = forth.loop_index %7 : !forth.stack -> !forth.stack - %9 = forth.add %8 : !forth.stack -> !forth.stack - forth.yield %9 : !forth.stack - } - forth.yield %6 : !forth.stack - } - return + func.func private @TEST__NESTED__DO__J(%arg0: !forth.stack) -> !forth.stack { + %0 = forth.literal %arg0 3 : !forth.stack -> !forth.stack + %1 = forth.literal %0 0 : !forth.stack -> !forth.stack + %output_stack, %value = forth.pop %1 : !forth.stack -> !forth.stack, i64 + %output_stack_0, %value_1 = forth.pop %output_stack : !forth.stack -> !forth.stack, i64 + %alloca = memref.alloca() : memref<1xi64> + %c0 = arith.constant 0 : index + memref.store %value, %alloca[%c0] : memref<1xi64> + cf.br ^bb1(%output_stack_0 : !forth.stack) + ^bb1(%2: !forth.stack): + %c0_2 = arith.constant 0 : index + %3 = memref.load %alloca[%c0_2] : memref<1xi64> + %4 = arith.cmpi slt, %3, %value_1 : i64 + cf.cond_br %4, ^bb2(%2 : !forth.stack), ^bb3(%2 : !forth.stack) + ^bb2(%5: !forth.stack): + %6 = forth.literal %5 4 : !forth.stack -> !forth.stack + %7 = forth.literal %6 0 : !forth.stack -> !forth.stack + %output_stack_3, %value_4 = forth.pop %7 : !forth.stack -> !forth.stack, i64 + %output_stack_5, %value_6 = forth.pop %output_stack_3 : !forth.stack -> !forth.stack, i64 + %alloca_7 = memref.alloca() : memref<1xi64> + %c0_8 = arith.constant 0 : index + memref.store %value_4, %alloca_7[%c0_8] : memref<1xi64> + cf.br ^bb4(%output_stack_5 : !forth.stack) + ^bb3(%8: !forth.stack): + return %8 : !forth.stack + ^bb4(%9: !forth.stack): + %c0_9 = arith.constant 0 : index + %10 = memref.load %alloca_7[%c0_9] : memref<1xi64> + %11 = arith.cmpi slt, %10, %value_6 : i64 + cf.cond_br %11, ^bb5(%9 : !forth.stack), ^bb6(%9 : !forth.stack) + ^bb5(%12: !forth.stack): + %c0_10 = arith.constant 0 : index + %13 = memref.load %alloca[%c0_10] : memref<1xi64> + %14 = forth.push_value %12, %13 : !forth.stack, i64 -> !forth.stack + %c0_11 = arith.constant 0 : index + %15 = memref.load %alloca_7[%c0_11] : memref<1xi64> + %16 = forth.push_value %14, %15 : !forth.stack, i64 -> !forth.stack + %17 = forth.add %16 : !forth.stack -> !forth.stack + %c0_12 = arith.constant 0 : index + %18 = memref.load %alloca_7[%c0_12] : memref<1xi64> + %c1_i64 = arith.constant 1 : i64 + %19 = arith.addi %18, %c1_i64 : i64 + memref.store %19, %alloca_7[%c0_12] : memref<1xi64> + cf.br ^bb4(%17 : !forth.stack) + ^bb6(%20: !forth.stack): + %c0_13 = arith.constant 0 : index + %21 = memref.load %alloca[%c0_13] : memref<1xi64> + %c1_i64_14 = arith.constant 1 : i64 + %22 = arith.addi %21, %c1_i64_14 : i64 + memref.store %22, %alloca[%c0_13] : memref<1xi64> + cf.br ^bb1(%20 : !forth.stack) } - - // === BEGIN/WHILE/REPEAT inside IF === - // CHECK-LABEL: func.func private @test_while_inside_if - // CHECK: scf.if %{{.*}} -> (index) { - // CHECK: scf.while - // CHECK: scf.condition - // CHECK: } do { - // CHECK: scf.yield - // CHECK: } - // CHECK: scf.yield - // CHECK: } else { - // CHECK: scf.yield - // CHECK: } - - func.func private @test_while_inside_if() { + func.func private @TEST__WHILE__INSIDE__IF(%arg0: !forth.stack) -> !forth.stack { + %0 = forth.literal %arg0 5 : !forth.stack -> !forth.stack + %output_stack, %flag = forth.pop_flag %0 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag, ^bb1(%output_stack : !forth.stack), ^bb2(%output_stack : !forth.stack) + ^bb1(%1: !forth.stack): + cf.br ^bb3(%1 : !forth.stack) + ^bb2(%2: !forth.stack): + return %2 : !forth.stack + ^bb3(%3: !forth.stack): + %4 = forth.dup %3 : !forth.stack -> !forth.stack + %output_stack_0, %flag_1 = forth.pop_flag %4 : !forth.stack -> !forth.stack, i1 + cf.cond_br %flag_1, ^bb4(%output_stack_0 : !forth.stack), ^bb5(%output_stack_0 : !forth.stack) + ^bb4(%5: !forth.stack): + %6 = forth.literal %5 1 : !forth.stack -> !forth.stack + %7 = forth.sub %6 : !forth.stack -> !forth.stack + cf.br ^bb3(%7 : !forth.stack) + ^bb5(%8: !forth.stack): + cf.br ^bb2(%8 : !forth.stack) + } + func.func private @main() { %0 = forth.stack !forth.stack - %1 = forth.literal %0 1 : !forth.stack -> !forth.stack - %2 = forth.if %1 : !forth.stack -> !forth.stack { - ^bb0(%arg0: !forth.stack): - %3 = forth.drop %arg0 : !forth.stack -> !forth.stack - %4 = forth.begin_while_repeat %3 : !forth.stack -> !forth.stack { - ^bb0(%arg1: !forth.stack): - %5 = forth.dup %arg1 : !forth.stack -> !forth.stack - forth.yield %5 while_cond : !forth.stack - } do { - ^bb0(%arg1: !forth.stack): - %5 = forth.literal %arg1 1 : !forth.stack -> !forth.stack - %6 = forth.sub %5 : !forth.stack -> !forth.stack - forth.yield %6 : !forth.stack - } - forth.yield %4 : !forth.stack - } else { - ^bb0(%arg0: !forth.stack): - %3 = forth.drop %arg0 : !forth.stack -> !forth.stack - forth.yield %3 : !forth.stack - } return } } diff --git a/test/Conversion/ForthToMemRef/stack-manipulation.mlir b/test/Conversion/ForthToMemRef/stack-manipulation.mlir index bd69acd..e8a1721 100644 --- a/test/Conversion/ForthToMemRef/stack-manipulation.mlir +++ b/test/Conversion/ForthToMemRef/stack-manipulation.mlir @@ -55,16 +55,22 @@ // CHECK: %[[PICK_VAL:.*]] = memref.load %{{.*}}[%[[PICK_ADDR]]] : memref<256xi64> // CHECK: memref.store %[[PICK_VAL]] -// roll: load n, index_cast, subi (dynamic), load saved, scf.for with load/store, store saved +// roll: load n, index_cast, subi (dynamic), load saved, cf loop with load/store, store saved // CHECK: %[[ROLL_N:.*]] = memref.load %{{.*}}[%{{.*}}] : memref<256xi64> // CHECK: %[[ROLL_SP1:.*]] = arith.subi // CHECK: %[[ROLL_NIDX:.*]] = arith.index_cast %[[ROLL_N]] // CHECK: %[[ROLL_ADDR:.*]] = arith.subi %[[ROLL_SP1]], %[[ROLL_NIDX]] // CHECK: %[[ROLL_SAVED:.*]] = memref.load %{{.*}}[%[[ROLL_ADDR]]] : memref<256xi64> -// CHECK: scf.for %[[ROLL_IV:.*]] = %[[ROLL_ADDR]] to %[[ROLL_SP1]] -// CHECK: %[[ROLL_NEXT:.*]] = arith.addi %[[ROLL_IV]] -// CHECK: %[[ROLL_SHIFTED:.*]] = memref.load %{{.*}}[%[[ROLL_NEXT]]] : memref<256xi64> -// CHECK: memref.store %[[ROLL_SHIFTED]], %{{.*}}[%[[ROLL_IV]]] : memref<256xi64> +// CHECK: cf.br ^[[ROLL_HDR:.*]](%[[ROLL_ADDR]] : index) +// CHECK: ^[[ROLL_HDR]](%[[ROLL_IV:.*]]: index): +// CHECK: %[[ROLL_CMP:.*]] = arith.cmpi slt, %[[ROLL_IV]], %[[ROLL_SP1]] : index +// CHECK: cf.cond_br %[[ROLL_CMP]], ^[[ROLL_BODY:.*]](%[[ROLL_IV]] : index), ^[[ROLL_EXIT:.*]] +// CHECK: ^[[ROLL_BODY]](%[[ROLL_BIV:.*]]: index): +// CHECK: %[[ROLL_NEXT:.*]] = arith.addi %[[ROLL_BIV]] +// CHECK: %[[ROLL_SHIFTED:.*]] = memref.load %{{.*}}[%[[ROLL_NEXT]]] : memref<256xi64> +// CHECK: memref.store %[[ROLL_SHIFTED]], %{{.*}}[%[[ROLL_BIV]]] : memref<256xi64> +// CHECK: cf.br ^[[ROLL_HDR]](%[[ROLL_NEXT]] : index) +// CHECK: ^[[ROLL_EXIT]]: // CHECK: memref.store %[[ROLL_SAVED]], %{{.*}}[%[[ROLL_SP1]]] : memref<256xi64> module { diff --git a/test/Pipeline/interleaved-control-flow.forth b/test/Pipeline/interleaved-control-flow.forth new file mode 100644 index 0000000..544b917 --- /dev/null +++ b/test/Pipeline/interleaved-control-flow.forth @@ -0,0 +1,30 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --warpforth-pipeline | %FileCheck %s +\ RUN: %warpforth-translate --forth-to-mlir %s | %warpforth-opt --convert-forth-to-memref --convert-forth-to-gpu | %FileCheck %s --check-prefix=MID + +\ Verify interleaved control flow (multi-WHILE, WHILE+UNTIL) compiles +\ through the full pipeline to gpu.binary. +\ CHECK: gpu.binary @warpforth_module + +\ Verify intermediate MLIR: gpu.func with cf branches, no scf ops +\ MID: gpu.module @warpforth_module +\ MID: gpu.func @main(%arg0: memref<4xi64> {forth.param_name = "DATA"}) kernel +\ MID: gpu.return + +\ Multi-WHILE: two cond_br exits + one unconditional back-edge +\ MID: func.func private @MULTI_WHILE +\ MID: cf.cond_br +\ MID: cf.cond_br +\ MID: cf.br + +\ WHILE+UNTIL: WHILE exit + UNTIL exit merge at THEN +\ MID: func.func private @WHILE_UNTIL +\ MID: cf.cond_br +\ MID: cf.cond_br +\ MID: cf.br + +PARAM DATA 4 +: multi-while + BEGIN DUP 10 > WHILE DUP 2 MOD 0= WHILE 1 - REPEAT DROP THEN ; +: while-until + BEGIN DUP 0 > WHILE 1 - DUP 5 = UNTIL THEN ; +multi-while while-until DATA 0 CELLS + ! diff --git a/test/Translation/Forth/begin-until.forth b/test/Translation/Forth/begin-until.forth index 23936be..dbd2c27 100644 --- a/test/Translation/Forth/begin-until.forth +++ b/test/Translation/Forth/begin-until.forth @@ -1,15 +1,17 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s -\ Verify BEGIN/UNTIL parsing produces forth.begin_until with body region +\ Verify BEGIN/UNTIL generates loop with pop_flag + cond_br -\ CHECK: %[[S0:.*]] = forth.stack -\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] 10 -\ CHECK: %[[LOOP:.*]] = forth.begin_until %[[S1]] -\ CHECK: ^bb0(%[[ARG:.*]]: !forth.stack): -\ CHECK: forth.literal %[[ARG]] 1 -\ CHECK: forth.sub -\ CHECK: forth.dup -\ CHECK: forth.zero_eq -\ CHECK: forth.yield -\ CHECK: } +\ CHECK: %[[S0:.*]] = forth.stack !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack +\ CHECK-NEXT: cf.br ^bb1(%[[S1]] : !forth.stack) +\ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): +\ CHECK-NEXT: %[[L1:.*]] = forth.literal %[[B1]] 1 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[SUB:.*]] = forth.sub %[[L1]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[DUP:.*]] = forth.dup %[[SUB]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[ZEQ:.*]] = forth.zero_eq %[[DUP]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[PF:.*]], %[[FLAG:.*]] = forth.pop_flag %[[ZEQ]] : !forth.stack -> !forth.stack, i1 +\ CHECK-NEXT: cf.cond_br %[[FLAG]], ^bb2(%[[PF]] : !forth.stack), ^bb1(%[[PF]] : !forth.stack) +\ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): +\ CHECK-NEXT: return 10 BEGIN 1 - DUP 0= UNTIL diff --git a/test/Translation/Forth/begin-while-repeat.forth b/test/Translation/Forth/begin-while-repeat.forth index e8aea5d..873f44c 100644 --- a/test/Translation/Forth/begin-while-repeat.forth +++ b/test/Translation/Forth/begin-while-repeat.forth @@ -1,20 +1,20 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s -\ Verify BEGIN/WHILE/REPEAT parsing produces forth.begin_while_repeat -\ with condition and body regions +\ Verify BEGIN/WHILE/REPEAT generates condition check + body loop with cond_br -\ CHECK: %[[S0:.*]] = forth.stack -\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] 10 -\ CHECK: %[[LOOP:.*]] = forth.begin_while_repeat %[[S1]] -\ CHECK: ^bb0(%[[CARG:.*]]: !forth.stack): -\ CHECK: forth.dup -\ CHECK: forth.literal -\ CHECK: forth.gt -\ CHECK: forth.yield %{{.*}} while_cond -\ CHECK: } do { -\ CHECK: ^bb0(%[[BARG:.*]]: !forth.stack): -\ CHECK: forth.literal -\ CHECK: forth.sub -\ CHECK: forth.yield -\ CHECK: } +\ CHECK: %[[S0:.*]] = forth.stack !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack +\ CHECK-NEXT: cf.br ^bb1(%[[S1]] : !forth.stack) +\ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): +\ CHECK-NEXT: %[[DUP:.*]] = forth.dup %[[B1]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L0:.*]] = forth.literal %[[DUP]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[GT:.*]] = forth.gt %[[L0]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[PF:.*]], %[[FLAG:.*]] = forth.pop_flag %[[GT]] : !forth.stack -> !forth.stack, i1 +\ CHECK-NEXT: cf.cond_br %[[FLAG]], ^bb2(%[[PF]] : !forth.stack), ^bb3(%[[PF]] : !forth.stack) +\ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): +\ CHECK-NEXT: %[[L1:.*]] = forth.literal %[[B2]] 1 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[SUB:.*]] = forth.sub %[[L1]] : !forth.stack -> !forth.stack +\ CHECK-NEXT: cf.br ^bb1(%[[SUB]] : !forth.stack) +\ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): +\ CHECK-NEXT: return 10 BEGIN DUP 0 > WHILE 1 - REPEAT diff --git a/test/Translation/Forth/control-flow.forth b/test/Translation/Forth/control-flow.forth index d208790..8e03361 100644 --- a/test/Translation/Forth/control-flow.forth +++ b/test/Translation/Forth/control-flow.forth @@ -1,33 +1,28 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s -\ Verify IF/ELSE/THEN parsing produces forth.if with block-arg regions +\ Verify IF/ELSE/THEN generates pop_flag + cond_br control flow \ Basic IF/ELSE/THEN -\ CHECK: %[[S0:.*]] = forth.stack -\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] 1 -\ CHECK: %[[IF1:.*]] = forth.if %[[S1]] -\ CHECK: ^bb0(%[[ARG1:.*]]: !forth.stack): -\ CHECK: forth.drop %[[ARG1]] -\ CHECK: forth.literal %{{.*}} 42 -\ CHECK: forth.yield -\ CHECK: } else { -\ CHECK: ^bb0(%[[ARG2:.*]]: !forth.stack): -\ CHECK: forth.drop %[[ARG2]] -\ CHECK: forth.literal %{{.*}} 99 -\ CHECK: forth.yield -\ CHECK: } +\ CHECK: %[[S0:.*]] = forth.stack !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 1 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[PF1:.*]], %[[FLAG1:.*]] = forth.pop_flag %[[S1]] : !forth.stack -> !forth.stack, i1 +\ CHECK-NEXT: cf.cond_br %[[FLAG1]], ^bb1(%[[PF1]] : !forth.stack), ^bb2(%[[PF1]] : !forth.stack) +\ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): +\ CHECK-NEXT: %[[L42:.*]] = forth.literal %[[B1]] 42 : !forth.stack -> !forth.stack +\ CHECK-NEXT: cf.br ^bb3(%[[L42]] : !forth.stack) +\ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): +\ CHECK-NEXT: %[[L99:.*]] = forth.literal %[[B2]] 99 : !forth.stack -> !forth.stack +\ CHECK-NEXT: cf.br ^bb3(%[[L99]] : !forth.stack) 1 IF 42 ELSE 99 THEN -\ Basic IF/THEN (no ELSE — identity drop+yield in else region) -\ CHECK: %[[S2:.*]] = forth.literal %[[IF1]] 0 -\ CHECK: %[[IF2:.*]] = forth.if %[[S2]] -\ CHECK: ^bb0(%[[ARG3:.*]]: !forth.stack): -\ CHECK: forth.drop %[[ARG3]] -\ CHECK: forth.literal %{{.*}} 7 -\ CHECK: forth.yield -\ CHECK: } else { -\ CHECK: ^bb0(%[[ARG4:.*]]: !forth.stack): -\ CHECK: forth.drop %[[ARG4]] -\ CHECK: forth.yield -\ CHECK: } +\ Basic IF/THEN (no ELSE - fallthrough on false) +\ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): +\ CHECK-NEXT: %[[S2:.*]] = forth.literal %[[B3]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[PF2:.*]], %[[FLAG2:.*]] = forth.pop_flag %[[S2]] : !forth.stack -> !forth.stack, i1 +\ CHECK-NEXT: cf.cond_br %[[FLAG2]], ^bb4(%[[PF2]] : !forth.stack), ^bb5(%[[PF2]] : !forth.stack) +\ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): +\ CHECK-NEXT: %[[L7:.*]] = forth.literal %[[B4]] 7 : !forth.stack -> !forth.stack +\ CHECK-NEXT: cf.br ^bb5(%[[L7]] : !forth.stack) +\ CHECK: ^bb5(%[[B5:.*]]: !forth.stack): +\ CHECK-NEXT: return 0 IF 7 THEN diff --git a/test/Translation/Forth/do-loop.forth b/test/Translation/Forth/do-loop.forth index 35eb704..4a4a7ff 100644 --- a/test/Translation/Forth/do-loop.forth +++ b/test/Translation/Forth/do-loop.forth @@ -1,13 +1,31 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s -\ Verify DO/LOOP parsing produces forth.do_loop with forth.loop_index +\ Verify DO/LOOP generates loop counter with memref.alloca, pop, cmpi, cond_br -\ CHECK: %[[S0:.*]] = forth.stack -\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] 10 -\ CHECK: %[[S2:.*]] = forth.literal %[[S1]] 0 -\ CHECK: %[[LOOP:.*]] = forth.do_loop %[[S2]] -\ CHECK: ^bb0(%[[ARG:.*]]: !forth.stack): -\ CHECK: forth.loop_index %[[ARG]] -\ CHECK: forth.yield -\ CHECK: } +\ CHECK: %[[S0:.*]] = forth.stack !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 10 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[S2:.*]] = forth.literal %[[S1]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[OS:.*]], %[[VAL:.*]] = forth.pop %[[S2]] : !forth.stack -> !forth.stack, i64 +\ CHECK-NEXT: %[[OS2:.*]], %[[LIM:.*]] = forth.pop %[[OS]] : !forth.stack -> !forth.stack, i64 +\ CHECK-NEXT: %[[ALLOCA:.*]] = memref.alloca() : memref<1xi64> +\ CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +\ CHECK-NEXT: memref.store %[[VAL]], %[[ALLOCA]][%[[C0]]] : memref<1xi64> +\ CHECK-NEXT: cf.br ^bb1(%[[OS2]] : !forth.stack) +\ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): +\ CHECK-NEXT: %[[C0_2:.*]] = arith.constant 0 : index +\ CHECK-NEXT: %[[LOAD1:.*]] = memref.load %[[ALLOCA]][%[[C0_2]]] : memref<1xi64> +\ CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LOAD1]], %[[LIM]] : i64 +\ CHECK-NEXT: cf.cond_br %[[CMP]], ^bb2(%[[B1]] : !forth.stack), ^bb3(%[[B1]] : !forth.stack) +\ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): +\ CHECK-NEXT: %[[C0_3:.*]] = arith.constant 0 : index +\ CHECK-NEXT: %[[LOAD2:.*]] = memref.load %[[ALLOCA]][%[[C0_3]]] : memref<1xi64> +\ CHECK-NEXT: %[[PUSH:.*]] = forth.push_value %[[B2]], %[[LOAD2]] : !forth.stack, i64 -> !forth.stack +\ CHECK-NEXT: %[[C0_4:.*]] = arith.constant 0 : index +\ CHECK-NEXT: %[[LOAD3:.*]] = memref.load %[[ALLOCA]][%[[C0_4]]] : memref<1xi64> +\ CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : i64 +\ CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[LOAD3]], %[[C1]] : i64 +\ CHECK-NEXT: memref.store %[[ADDI]], %[[ALLOCA]][%[[C0_4]]] : memref<1xi64> +\ CHECK-NEXT: cf.br ^bb1(%[[PUSH]] : !forth.stack) +\ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): +\ CHECK-NEXT: return 10 0 DO I LOOP diff --git a/test/Translation/Forth/interleaved-control-flow.forth b/test/Translation/Forth/interleaved-control-flow.forth new file mode 100644 index 0000000..9e38d5b --- /dev/null +++ b/test/Translation/Forth/interleaved-control-flow.forth @@ -0,0 +1,88 @@ +\ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s + +\ These test interleaved control flow patterns that require a compile-time +\ control-flow stack (cf-stack) and cannot be expressed with structured +\ region-holding ops. + +\ === Multi-WHILE: two exit conditions from the same loop === +\ Loop exits via WHILE(1) when value <= 10, or via WHILE(2) when value is odd. +\ The DROP between REPEAT and THEN cleans up the stack on the WHILE(2) exit path. + +\ CHECK-LABEL: func.func private @MULTI_WHILE +\ Entry: branch to loop header +\ CHECK: cf.br ^bb1 + +\ Loop header: DUP 10 > → WHILE(1) +\ CHECK: ^bb1(%[[H:.*]]: !forth.stack): +\ CHECK: forth.dup +\ CHECK: forth.literal %{{.*}} 10 +\ CHECK-NEXT: %{{.*}} = forth.gt +\ CHECK: forth.pop_flag +\ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : !forth.stack), ^bb3(%{{.*}} : !forth.stack) + +\ WHILE(1) body: DUP 2 MOD 0= → WHILE(2) +\ CHECK: ^bb2(%{{.*}}: !forth.stack): +\ CHECK: forth.dup +\ CHECK: forth.literal %{{.*}} 2 +\ CHECK-NEXT: %{{.*}} = forth.mod +\ CHECK-NEXT: %{{.*}} = forth.zero_eq +\ CHECK: forth.pop_flag +\ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb4(%{{.*}} : !forth.stack), ^bb5(%{{.*}} : !forth.stack) + +\ WHILE(1) exit / return block (also reached from THEN) +\ CHECK: ^bb3(%{{.*}}: !forth.stack): +\ CHECK-NEXT: return + +\ WHILE(2) body: 1 - → REPEAT (branch back to loop header) +\ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): +\ CHECK-NEXT: %{{.*}} = forth.literal %[[B4]] 1 +\ CHECK-NEXT: %{{.*}} = forth.sub +\ CHECK-NEXT: cf.br ^bb1 + +\ WHILE(2) exit: DROP → THEN (branch to WHILE(1) exit) +\ CHECK: ^bb5(%{{.*}}: !forth.stack): +\ CHECK-NEXT: %{{.*}} = forth.drop +\ CHECK-NEXT: cf.br ^bb3 + +: multi-while + BEGIN DUP 10 > WHILE DUP 2 MOD 0= WHILE 1 - REPEAT DROP THEN ; + +\ === WHILE+UNTIL: two different exit mechanisms from the same loop === +\ WHILE checks the pre-condition (value > 0), UNTIL checks a post-condition +\ (value = 5). The loop has two distinct exit paths that merge at THEN. + +\ CHECK-LABEL: func.func private @WHILE_UNTIL +\ Entry: branch to loop header +\ CHECK: cf.br ^bb1 + +\ Loop header: DUP 0 > → WHILE +\ CHECK: ^bb1(%{{.*}}: !forth.stack): +\ CHECK: forth.dup +\ CHECK: forth.literal %{{.*}} 0 +\ CHECK-NEXT: %{{.*}} = forth.gt +\ CHECK: forth.pop_flag +\ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb2(%{{.*}} : !forth.stack), ^bb3(%{{.*}} : !forth.stack) + +\ WHILE body + UNTIL: 1 - DUP 5 = UNTIL +\ UNTIL true exits to ^bb4, UNTIL false loops back to ^bb1 +\ CHECK: ^bb2(%[[W:.*]]: !forth.stack): +\ CHECK-NEXT: %{{.*}} = forth.literal %[[W]] 1 +\ CHECK-NEXT: %{{.*}} = forth.sub +\ CHECK: forth.dup +\ CHECK: forth.literal %{{.*}} 5 +\ CHECK-NEXT: %{{.*}} = forth.eq +\ CHECK: forth.pop_flag +\ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb4(%{{.*}} : !forth.stack), ^bb1(%{{.*}} : !forth.stack) + +\ WHILE exit / return block (also reached from THEN after UNTIL exit) +\ CHECK: ^bb3(%{{.*}}: !forth.stack): +\ CHECK-NEXT: return + +\ UNTIL exit → THEN (branch to WHILE exit) +\ CHECK: ^bb4(%{{.*}}: !forth.stack): +\ CHECK-NEXT: cf.br ^bb3 + +: while-until + BEGIN DUP 0 > WHILE 1 - DUP 5 = UNTIL THEN ; + +multi-while while-until diff --git a/test/Translation/Forth/nested-control-flow.forth b/test/Translation/Forth/nested-control-flow.forth index e0a43bf..2cbed2f 100644 --- a/test/Translation/Forth/nested-control-flow.forth +++ b/test/Translation/Forth/nested-control-flow.forth @@ -1,83 +1,154 @@ \ RUN: %warpforth-translate --forth-to-mlir %s | %FileCheck %s \ === Nested IF === -\ CHECK: %[[S0:.*]] = forth.stack -\ CHECK: %[[S1:.*]] = forth.literal %[[S0]] 1 -\ CHECK: %[[IF1:.*]] = forth.if %[[S1]] -\ CHECK: ^bb0(%[[A1:.*]]: !forth.stack): -\ CHECK: %[[D1:.*]] = forth.drop %[[A1]] -\ CHECK: %[[L2:.*]] = forth.literal %[[D1]] 2 -\ CHECK: %[[IF2:.*]] = forth.if %[[L2]] -\ CHECK: ^bb0(%[[A2:.*]]: !forth.stack): -\ CHECK: forth.drop %[[A2]] -\ CHECK: forth.literal %{{.*}} 3 -\ CHECK: forth.yield -\ CHECK: } else { -\ CHECK: ^bb0(%[[A3:.*]]: !forth.stack): -\ CHECK: forth.drop %[[A3]] -\ CHECK: forth.yield -\ CHECK: } -\ CHECK: forth.yield -\ CHECK: } else { -\ CHECK: ^bb0(%[[A4:.*]]: !forth.stack): -\ CHECK: forth.drop %[[A4]] -\ CHECK: forth.yield -\ CHECK: } +\ CHECK: %[[S0:.*]] = forth.stack !forth.stack +\ CHECK-NEXT: %[[S1:.*]] = forth.literal %[[S0]] 1 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[PF1:.*]], %[[FL1:.*]] = forth.pop_flag %[[S1]] : !forth.stack -> !forth.stack, i1 +\ CHECK-NEXT: cf.cond_br %[[FL1]], ^bb1(%[[PF1]] : !forth.stack), ^bb2(%[[PF1]] : !forth.stack) +\ CHECK: ^bb1(%[[B1:.*]]: !forth.stack): +\ CHECK-NEXT: %[[L2:.*]] = forth.literal %[[B1]] 2 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[PF2:.*]], %[[FL2:.*]] = forth.pop_flag %[[L2]] : !forth.stack -> !forth.stack, i1 +\ CHECK-NEXT: cf.cond_br %[[FL2]], ^bb3(%[[PF2]] : !forth.stack), ^bb4(%[[PF2]] : !forth.stack) 1 IF 2 IF 3 THEN THEN \ === IF inside DO === -\ CHECK: %[[S2:.*]] = forth.literal %[[IF1]] 10 -\ CHECK: %[[S3:.*]] = forth.literal %[[S2]] 0 -\ CHECK: %[[LOOP1:.*]] = forth.do_loop %[[S3]] -\ CHECK: ^bb0(%[[BA:.*]]: !forth.stack): -\ CHECK: %[[LI:.*]] = forth.loop_index %[[BA]] -\ CHECK: %[[L5:.*]] = forth.literal %[[LI]] 5 -\ CHECK: %[[GT:.*]] = forth.gt %[[L5]] -\ CHECK: forth.if %[[GT]] -\ CHECK: forth.loop_index -\ CHECK: forth.yield +\ CHECK: ^bb2(%[[B2:.*]]: !forth.stack): +\ CHECK-NEXT: %[[L10:.*]] = forth.literal %[[B2]] 10 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[L0A:.*]] = forth.literal %[[L10]] 0 : !forth.stack -> !forth.stack +\ CHECK-NEXT: %[[POP1:.*]], %[[V1:.*]] = forth.pop %[[L0A]] : !forth.stack -> !forth.stack, i64 +\ CHECK-NEXT: %[[POP2:.*]], %[[V2:.*]] = forth.pop %[[POP1]] : !forth.stack -> !forth.stack, i64 +\ CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloca() : memref<1xi64> +\ CHECK-NEXT: %{{.*}} = arith.constant 0 : index +\ CHECK-NEXT: memref.store %[[V1]], %[[ALLOC1]][%{{.*}}] : memref<1xi64> +\ CHECK-NEXT: cf.br ^bb5(%[[POP2]] : !forth.stack) 10 0 DO I 5 > IF I THEN LOOP +\ Nested IF: true branch pushes 3, then merges +\ CHECK: ^bb3(%[[B3:.*]]: !forth.stack): +\ CHECK-NEXT: %[[L3:.*]] = forth.literal %[[B3]] 3 : !forth.stack -> !forth.stack +\ CHECK-NEXT: cf.br ^bb4(%[[L3]] : !forth.stack) +\ CHECK: ^bb4(%[[B4:.*]]: !forth.stack): +\ CHECK-NEXT: cf.br ^bb2(%[[B4]] : !forth.stack) + +\ DO loop header: check index < limit +\ CHECK: ^bb5(%[[B5:.*]]: !forth.stack): +\ CHECK: arith.cmpi slt +\ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb6(%[[B5]] : !forth.stack), ^bb7(%[[B5]] : !forth.stack) + +\ DO loop body: I 5 > IF I THEN +\ CHECK: ^bb6(%[[B6:.*]]: !forth.stack): +\ CHECK: forth.push_value %[[B6]] +\ CHECK: forth.literal %{{.*}} 5 +\ CHECK-NEXT: %{{.*}} = forth.gt +\ CHECK: forth.pop_flag +\ CHECK-NEXT: cf.cond_br %{{.*}}, ^bb8(%{{.*}} : !forth.stack), ^bb9(%{{.*}} : !forth.stack) + \ === Nested DO with J === -\ CHECK: forth.do_loop -\ CHECK: forth.do_loop -\ CHECK: forth.loop_index %{{.*}} {depth = 1 : i64} -\ CHECK: forth.loop_index %{{.*}} -\ CHECK: forth.add +\ After first DO loop exits: bb7 sets up nested DO (3 0 DO) +\ CHECK: ^bb7(%[[B7:.*]]: !forth.stack): +\ CHECK-NEXT: %{{.*}} = forth.literal %[[B7]] 3 3 0 DO 4 0 DO J I + LOOP LOOP +\ IF I true branch: push loop index +\ CHECK: ^bb8(%[[B8:.*]]: !forth.stack): +\ CHECK: forth.push_value %[[B8]] +\ CHECK-NEXT: cf.br ^bb9 + +\ Loop increment and back-edge +\ CHECK: ^bb9(%{{.*}}: !forth.stack): +\ CHECK: arith.addi +\ CHECK: memref.store +\ CHECK: cf.br ^bb5 + +\ Outer DO loop (3 0 DO) header +\ CHECK: ^bb10(%{{.*}}: !forth.stack): +\ CHECK: arith.cmpi slt +\ CHECK: cf.cond_br + +\ Inner DO setup (4 0 DO) +\ CHECK: ^bb11(%{{.*}}: !forth.stack): +\ CHECK: forth.literal %{{.*}} 4 +\ CHECK: forth.literal %{{.*}} 0 +\ CHECK: forth.pop +\ CHECK: forth.pop +\ CHECK: memref.alloca() + \ === Triple-nested DO with K === -\ CHECK: forth.do_loop -\ CHECK: forth.do_loop -\ CHECK: forth.do_loop -\ CHECK: forth.loop_index %{{.*}} {depth = 2 : i64} -\ CHECK: forth.loop_index %{{.*}} {depth = 1 : i64} -\ CHECK: forth.loop_index %{{.*}} -\ CHECK: forth.add -\ CHECK: forth.add +\ After nested DO exits: sets up triple-nested DO (2 0 DO) +\ CHECK: ^bb12(%{{.*}}: !forth.stack): +\ CHECK: forth.literal %{{.*}} 2 2 0 DO 2 0 DO 2 0 DO K J I + + LOOP LOOP LOOP +\ Inner loop of J I + (bb13 header, bb14 body) +\ CHECK: ^bb13(%{{.*}}: !forth.stack): +\ CHECK: arith.cmpi slt +\ CHECK: cf.cond_br + +\ J I + body +\ CHECK: ^bb14(%{{.*}}: !forth.stack): +\ CHECK: forth.push_value +\ CHECK: forth.push_value +\ CHECK: forth.add + +\ Outer loop increment (bb15) +\ CHECK: ^bb15(%{{.*}}: !forth.stack): +\ CHECK: arith.addi +\ CHECK: cf.br ^bb10 + +\ Triple-nested outer loop header (bb16) +\ CHECK: ^bb16(%{{.*}}: !forth.stack): +\ CHECK: arith.cmpi slt +\ CHECK: cf.cond_br + +\ Triple-nested middle loop setup (bb17) +\ CHECK: ^bb17(%{{.*}}: !forth.stack): +\ CHECK: forth.literal %{{.*}} 2 +\ CHECK: forth.literal %{{.*}} 0 + \ === BEGIN/WHILE inside IF === -\ CHECK: forth.if -\ CHECK: forth.begin_while_repeat -\ CHECK: forth.dup -\ CHECK: forth.yield %{{.*}} while_cond -\ CHECK: } do { -\ CHECK: forth.literal %{{.*}} 1 -\ CHECK: forth.sub -\ CHECK: forth.yield +\ After triple-nested exits: 5 IF BEGIN DUP WHILE 1 - REPEAT THEN +\ CHECK: ^bb18(%{{.*}}: !forth.stack): +\ CHECK: forth.literal %{{.*}} 5 +\ CHECK: forth.pop_flag +\ CHECK-NEXT: cf.cond_br 5 IF BEGIN DUP WHILE 1 - REPEAT THEN +\ bb25: IF true branch -> jump to begin/while header +\ CHECK: ^bb25(%{{.*}}: !forth.stack): +\ CHECK-NEXT: cf.br ^bb27 + +\ bb26: IF false branch (and WHILE exit) -> jump to BEGIN/UNTIL +\ CHECK: ^bb26(%{{.*}}: !forth.stack): +\ CHECK-NEXT: cf.br ^bb30 + +\ WHILE condition: DUP + pop_flag +\ CHECK: ^bb27(%{{.*}}: !forth.stack): +\ CHECK: forth.dup +\ CHECK: forth.pop_flag +\ CHECK-NEXT: cf.cond_br + +\ WHILE body: 1 - +\ CHECK: ^bb28(%[[B28:.*]]: !forth.stack): +\ CHECK-NEXT: %{{.*}} = forth.literal %[[B28]] 1 +\ CHECK-NEXT: %{{.*}} = forth.sub + \ === IF inside BEGIN/UNTIL === -\ CHECK: forth.begin_until -\ CHECK: forth.dup -\ CHECK: forth.literal %{{.*}} 10 -\ CHECK: forth.lt -\ CHECK: forth.if -\ CHECK: forth.literal %{{.*}} 1 -\ CHECK: forth.add -\ CHECK: forth.dup -\ CHECK: forth.literal %{{.*}} 20 -\ CHECK: forth.eq -\ CHECK: forth.yield +\ BEGIN/UNTIL header: DUP 10 < +\ CHECK: ^bb30(%{{.*}}: !forth.stack): +\ CHECK: forth.dup +\ CHECK: forth.literal %{{.*}} 10 +\ CHECK-NEXT: %{{.*}} = forth.lt BEGIN DUP 10 < IF 1 + THEN DUP 20 = UNTIL + +\ IF true branch: 1 + +\ CHECK: ^bb31(%[[B31:.*]]: !forth.stack): +\ CHECK-NEXT: %{{.*}} = forth.literal %[[B31]] 1 +\ CHECK-NEXT: %{{.*}} = forth.add + +\ UNTIL condition: DUP 20 = +\ CHECK: ^bb32(%{{.*}}: !forth.stack): +\ CHECK: forth.dup +\ CHECK: forth.literal %{{.*}} 20 +\ CHECK-NEXT: %{{.*}} = forth.eq +\ CHECK: forth.pop_flag +\ CHECK-NEXT: cf.cond_br