From 365cc5bdc873899ae22600771dd1e39424b097d3 Mon Sep 17 00:00:00 2001 From: Ashley Coleman Date: Fri, 19 Dec 2025 15:33:20 -0700 Subject: [PATCH 1/2] [SM6.10] Reserve LinAlg DXIL Opcodes --- docs/DXIL.rst | 28 +- include/dxc/DXIL/DxilConstants.h | 125 +++++- include/dxc/DXIL/DxilInstructions.h | 638 ++++++++++++++++++++++++++++ lib/DXIL/DxilOperations.cpp | 376 +++++++++++++++- utils/hct/hctdb.py | 288 +++++++++++++ 5 files changed, 1448 insertions(+), 7 deletions(-) diff --git a/docs/DXIL.rst b/docs/DXIL.rst index e9514212dd..a5f0536c4e 100644 --- a/docs/DXIL.rst +++ b/docs/DXIL.rst @@ -3062,9 +3062,9 @@ Given width, offset: Opcode Table ExperimentalOps, id=32768: Experimental DXIL operations -========== ======================================== ================================================================== +========== ======================================== =================================================================================================================== ID Name Description -========== ======================================== ================================================================== +========== ======================================== =================================================================================================================== 2147483648 ExperimentalNop nop does nothing 2147483649 GetGroupWaveIndex returns the index of the wave in the thread group 2147483650 GetGroupWaveCount returns the number of waves in the thread group @@ -3076,7 +3076,29 @@ ID Name Description 2147483656 RayQuery_CandidateTriangleObjectPosition returns candidate triangle vertices in object space as <9 x float> 2147483657 RayQuery_CommittedTriangleObjectPosition returns committed triangle vertices in object space as <9 x float> 2147483658 HitObject_TriangleObjectPosition returns triangle vertices in object space as <9 x float> -========== ======================================== ================================================================== +2147483659 CreateMatrix creates a handle to a Matrix +2147483660 FillMatrix fills a matrix with a scalar value +2147483661 CopyConvertMatrix Converts and copies the element and use type of the source matrix to the destination matrix with optional transpose +2147483662 MatrixLoadFromDescriptor fills a matrix with data from a [RW]ByteAddressBuffer +2147483663 MatrixLoadFromMemory fills a matrix with data from a groupshared array +2147483664 MatrixLength returns the number of elements stored in thread-local storage on the active thread for the provided matrix +2147483665 MatrixGetCoordinate returns a two element vector containing the column and row of the matrix that the thread-local index corresponds to +2147483666 MatrixGetElement returns the element of the matrix corresponding to the provided thread-local index +2147483667 MatrixSetElement sets the element of the matrix corresponding to the provided thread-local index +2147483668 MatrixStoreToDescriptor stores a matrix to a RWByteAddressBuffer +2147483669 MatrixStoreToMemory stores a matrix to groupshared memory +2147483670 MatrixQueryAccumulatorLayout returns comptime 0 when accumulator matrix are A layout, 1 when B layout +2147483671 MatrixMulOp applies a multiplication op to matrix C using A and B as parameters +2147483672 MatrixAccumulate accumulate A or B matrix into Accumulator matrix following LHS += RHS +2147483673 MatrixVecMul Multiplies a MxK dimension matrix and a K sized input vector +2147483674 MatrixVecMulAdd Multiplies a MxK dimension matrix and a K sized input vector then adds a M sized bias vector +2147483675 MatrixAccumulateToDescriptor accumulates a matrix to a RWByteAddressBuffer +2147483676 MatrixAccumulateToMemory accumulates a matrix to groupshared memory +2147483677 MatrixOuterProduct Outer products an M sized vector and a K sized vector producing an MxK matrix +2147483678 LinAlgMatrixReserved0 reserved +2147483679 LinAlgMatrixReserved1 reserved +2147483680 LinAlgMatrixReserved2 reserved +========== ======================================== =================================================================================================================== .. OPCODES-RST:END diff --git a/include/dxc/DXIL/DxilConstants.h b/include/dxc/DXIL/DxilConstants.h index 83615dabfd..2c00080ada 100644 --- a/include/dxc/DXIL/DxilConstants.h +++ b/include/dxc/DXIL/DxilConstants.h @@ -511,6 +511,11 @@ namespace ExperimentalOps { static const OpCodeTableID TableID = OpCodeTableID::ExperimentalOps; // Enumeration for ExperimentalOps DXIL operations enum class OpCode : unsigned { + // + LinAlgMatrixReserved0 = 30, // reserved + LinAlgMatrixReserved1 = 31, // reserved + LinAlgMatrixReserved2 = 32, // reserved + // Group Wave Ops GetGroupWaveCount = 2, // returns the number of waves in the thread group GetGroupWaveIndex = 1, // returns the index of the wave in the thread group @@ -523,6 +528,43 @@ enum class OpCode : unsigned { RayQuery_CommittedTriangleObjectPosition = 9, // returns committed triangle vertices in object space as <9 x float> + // Linear Algebra Operations + CopyConvertMatrix = + 13, // Converts and copies the element and use type of the source matrix + // to the destination matrix with optional transpose + CreateMatrix = 11, // creates a handle to a Matrix + FillMatrix = 12, // fills a matrix with a scalar value + MatrixAccumulate = 24, // accumulate A or B matrix into Accumulator matrix + // following LHS += RHS + MatrixAccumulateToDescriptor = + 27, // accumulates a matrix to a RWByteAddressBuffer + MatrixAccumulateToMemory = 28, // accumulates a matrix to groupshared memory + MatrixGetCoordinate = + 17, // returns a two element vector containing the column and row of the + // matrix that the thread-local index corresponds to + MatrixGetElement = 18, // returns the element of the matrix corresponding to + // the provided thread-local index + MatrixLength = 16, // returns the number of elements stored in thread-local + // storage on the active thread for the provided matrix + MatrixLoadFromDescriptor = + 14, // fills a matrix with data from a [RW]ByteAddressBuffer + MatrixLoadFromMemory = + 15, // fills a matrix with data from a groupshared array + MatrixMulOp = + 23, // applies a multiplication op to matrix C using A and B as parameters + MatrixOuterProduct = 29, // Outer products an M sized vector and a K sized + // vector producing an MxK matrix + MatrixQueryAccumulatorLayout = 22, // returns comptime 0 when accumulator + // matrix are A layout, 1 when B layout + MatrixSetElement = 19, // sets the element of the matrix corresponding to the + // provided thread-local index + MatrixStoreToDescriptor = 20, // stores a matrix to a RWByteAddressBuffer + MatrixStoreToMemory = 21, // stores a matrix to groupshared memory + MatrixVecMul = + 25, // Multiplies a MxK dimension matrix and a K sized input vector + MatrixVecMulAdd = 26, // Multiplies a MxK dimension matrix and a K sized input + // vector then adds a M sized bias vector + // No-op ExperimentalNop = 0, // nop does nothing @@ -538,7 +580,7 @@ enum class OpCode : unsigned { HitObject_TriangleObjectPosition = 10, // returns triangle vertices in object space as <9 x float> - NumOpCodes = 11, // exclusive last value of enumeration + NumOpCodes = 33, // exclusive last value of enumeration }; } // namespace ExperimentalOps static const unsigned NumOpCodeTables = 2; @@ -1187,6 +1229,66 @@ enum class OpCode : unsigned { EXP_OPCODE(ExperimentalOps, HitObject_TriangleObjectPosition), // returns triangle vertices in // object space as <9 x float> + EXP_OPCODE(ExperimentalOps, CreateMatrix), // creates a handle to a Matrix + EXP_OPCODE(ExperimentalOps, FillMatrix), // fills a matrix with a scalar value + EXP_OPCODE(ExperimentalOps, + CopyConvertMatrix), // Converts and copies the element and use type + // of the source matrix to the destination + // matrix with optional transpose + EXP_OPCODE(ExperimentalOps, + MatrixLoadFromDescriptor), // fills a matrix with data from a + // [RW]ByteAddressBuffer + EXP_OPCODE(ExperimentalOps, MatrixLoadFromMemory), // fills a matrix with data + // from a groupshared array + EXP_OPCODE( + ExperimentalOps, + MatrixLength), // returns the number of elements stored in thread-local + // storage on the active thread for the provided matrix + EXP_OPCODE(ExperimentalOps, + MatrixGetCoordinate), // returns a two element vector containing + // the column and row of the matrix that the + // thread-local index corresponds to + EXP_OPCODE( + ExperimentalOps, + MatrixGetElement), // returns the element of the matrix corresponding to + // the provided thread-local index + EXP_OPCODE(ExperimentalOps, + MatrixSetElement), // sets the element of the matrix corresponding + // to the provided thread-local index + EXP_OPCODE( + ExperimentalOps, + MatrixStoreToDescriptor), // stores a matrix to a RWByteAddressBuffer + EXP_OPCODE(ExperimentalOps, + MatrixStoreToMemory), // stores a matrix to groupshared memory + EXP_OPCODE( + ExperimentalOps, + MatrixQueryAccumulatorLayout), // returns comptime 0 when accumulator + // matrix are A layout, 1 when B layout + EXP_OPCODE(ExperimentalOps, + MatrixMulOp), // applies a multiplication op to matrix C using A + // and B as parameters + EXP_OPCODE(ExperimentalOps, + MatrixAccumulate), // accumulate A or B matrix into Accumulator + // matrix following LHS += RHS + EXP_OPCODE(ExperimentalOps, + MatrixVecMul), // Multiplies a MxK dimension matrix and a K sized + // input vector + EXP_OPCODE( + ExperimentalOps, + MatrixVecMulAdd), // Multiplies a MxK dimension matrix and a K sized input + // vector then adds a M sized bias vector + EXP_OPCODE(ExperimentalOps, + MatrixAccumulateToDescriptor), // accumulates a matrix to a + // RWByteAddressBuffer + EXP_OPCODE( + ExperimentalOps, + MatrixAccumulateToMemory), // accumulates a matrix to groupshared memory + EXP_OPCODE(ExperimentalOps, + MatrixOuterProduct), // Outer products an M sized vector and a K + // sized vector producing an MxK matrix + EXP_OPCODE(ExperimentalOps, LinAlgMatrixReserved0), // reserved + EXP_OPCODE(ExperimentalOps, LinAlgMatrixReserved1), // reserved + EXP_OPCODE(ExperimentalOps, LinAlgMatrixReserved2), // reserved }; // OPCODE-ENUM:END #undef EXP_OPCODE @@ -1342,8 +1444,27 @@ enum class OpCodeClass : unsigned { CreateHandleForLib, // Linear Algebra Operations + CopyConvertMatrix, + CreateMatrix, + FillMatrix, MatVecMul, MatVecMulAdd, + MatrixAccumulate, + MatrixAccumulateToDescriptor, + MatrixAccumulateToMemory, + MatrixGetCoordinate, + MatrixGetElement, + MatrixLength, + MatrixLoadFromDescriptor, + MatrixLoadFromMemory, + MatrixMulOp, + MatrixOuterProduct, + MatrixQueryAccumulatorLayout, + MatrixSetElement, + MatrixStoreToDescriptor, + MatrixStoreToMemory, + MatrixVecMul, + MatrixVecMulAdd, OuterProductAccumulate, VectorAccumulate, @@ -1532,7 +1653,7 @@ enum class OpCodeClass : unsigned { NodeOutputIsValid, OutputComplete, - NumOpClasses = 204, // exclusive last value of enumeration + NumOpClasses = 223, // exclusive last value of enumeration }; // OPCODECLASS-ENUM:END diff --git a/include/dxc/DXIL/DxilInstructions.h b/include/dxc/DXIL/DxilInstructions.h index 2bef121009..4646a0b872 100644 --- a/include/dxc/DXIL/DxilInstructions.h +++ b/include/dxc/DXIL/DxilInstructions.h @@ -10499,5 +10499,643 @@ struct DxilInst_HitObject_TriangleObjectPosition { llvm::Value *get_hitObject() const { return Instr->getOperand(1); } void set_hitObject(llvm::Value *val) { Instr->setOperand(1, val); } }; + +/// This instruction creates a handle to a Matrix +struct DxilInst_CreateMatrix { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_CreateMatrix(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::CreateMatrix); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (1 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } +}; + +/// This instruction fills a matrix with a scalar value +struct DxilInst_FillMatrix { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_FillMatrix(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::FillMatrix); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (3 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_value = 2, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_value() const { return Instr->getOperand(2); } + void set_value(llvm::Value *val) { Instr->setOperand(2, val); } +}; + +/// This instruction Converts and copies the element and use type of the source +/// matrix to the destination matrix with optional transpose +struct DxilInst_CopyConvertMatrix { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_CopyConvertMatrix(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::CopyConvertMatrix); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (4 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_destMatrixRef = 1, + arg_srcMatrixRef = 2, + arg_transpose = 3, + }; + // Accessors + llvm::Value *get_destMatrixRef() const { return Instr->getOperand(1); } + void set_destMatrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_srcMatrixRef() const { return Instr->getOperand(2); } + void set_srcMatrixRef(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_transpose() const { return Instr->getOperand(3); } + void set_transpose(llvm::Value *val) { Instr->setOperand(3, val); } +}; + +/// This instruction fills a matrix with data from a [RW]ByteAddressBuffer +struct DxilInst_MatrixLoadFromDescriptor { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixLoadFromDescriptor(llvm::Instruction *pInstr) + : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::MatrixLoadFromDescriptor); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_handle = 2, + arg_offset = 3, + arg_stride = 4, + arg_layout = 5, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_handle() const { return Instr->getOperand(2); } + void set_handle(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_offset() const { return Instr->getOperand(3); } + void set_offset(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_stride() const { return Instr->getOperand(4); } + void set_stride(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_layout() const { return Instr->getOperand(5); } + void set_layout(llvm::Value *val) { Instr->setOperand(5, val); } +}; + +/// This instruction fills a matrix with data from a groupshared array +struct DxilInst_MatrixLoadFromMemory { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixLoadFromMemory(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::MatrixLoadFromMemory); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_groupsharedArr = 2, + arg_offset = 3, + arg_stride = 4, + arg_layout = 5, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); } + void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_offset() const { return Instr->getOperand(3); } + void set_offset(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_stride() const { return Instr->getOperand(4); } + void set_stride(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_layout() const { return Instr->getOperand(5); } + void set_layout(llvm::Value *val) { Instr->setOperand(5, val); } +}; + +/// This instruction returns the number of elements stored in thread-local +/// storage on the active thread for the provided matrix +struct DxilInst_MatrixLength { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixLength(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::MatrixLength); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (2 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } +}; + +/// This instruction returns a two element vector containing the column and row +/// of the matrix that the thread-local index corresponds to +struct DxilInst_MatrixGetCoordinate { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixGetCoordinate(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::MatrixGetCoordinate); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (3 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_threadLocalIndex = 2, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_threadLocalIndex() const { return Instr->getOperand(2); } + void set_threadLocalIndex(llvm::Value *val) { Instr->setOperand(2, val); } +}; + +/// This instruction returns the element of the matrix corresponding to the +/// provided thread-local index +struct DxilInst_MatrixGetElement { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixGetElement(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::MatrixGetElement); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (3 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_threadLocalIndex = 2, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_threadLocalIndex() const { return Instr->getOperand(2); } + void set_threadLocalIndex(llvm::Value *val) { Instr->setOperand(2, val); } +}; + +/// This instruction sets the element of the matrix corresponding to the +/// provided thread-local index +struct DxilInst_MatrixSetElement { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixSetElement(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::MatrixSetElement); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (4 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_threadLocalIndex = 2, + arg_value = 3, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_threadLocalIndex() const { return Instr->getOperand(2); } + void set_threadLocalIndex(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_value() const { return Instr->getOperand(3); } + void set_value(llvm::Value *val) { Instr->setOperand(3, val); } +}; + +/// This instruction stores a matrix to a RWByteAddressBuffer +struct DxilInst_MatrixStoreToDescriptor { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixStoreToDescriptor(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::MatrixStoreToDescriptor); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_handle = 2, + arg_offset = 3, + arg_stride = 4, + arg_layout = 5, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_handle() const { return Instr->getOperand(2); } + void set_handle(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_offset() const { return Instr->getOperand(3); } + void set_offset(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_stride() const { return Instr->getOperand(4); } + void set_stride(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_layout() const { return Instr->getOperand(5); } + void set_layout(llvm::Value *val) { Instr->setOperand(5, val); } +}; + +/// This instruction stores a matrix to groupshared memory +struct DxilInst_MatrixStoreToMemory { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixStoreToMemory(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::MatrixStoreToMemory); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_groupsharedArr = 2, + arg_offset = 3, + arg_stride = 4, + arg_layout = 5, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); } + void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_offset() const { return Instr->getOperand(3); } + void set_offset(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_stride() const { return Instr->getOperand(4); } + void set_stride(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_layout() const { return Instr->getOperand(5); } + void set_layout(llvm::Value *val) { Instr->setOperand(5, val); } +}; + +/// This instruction returns comptime 0 when accumulator matrix are A layout, 1 +/// when B layout +struct DxilInst_MatrixQueryAccumulatorLayout { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixQueryAccumulatorLayout(llvm::Instruction *pInstr) + : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::MatrixQueryAccumulatorLayout); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (1 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } +}; + +/// This instruction applies a multiplication op to matrix C using A and B as +/// parameters +struct DxilInst_MatrixMulOp { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixMulOp(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, hlsl::OP::OpCode::MatrixMulOp); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (4 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRefA = 1, + arg_matrixRefB = 2, + arg_matrixRefC = 3, + }; + // Accessors + llvm::Value *get_matrixRefA() const { return Instr->getOperand(1); } + void set_matrixRefA(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_matrixRefB() const { return Instr->getOperand(2); } + void set_matrixRefB(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_matrixRefC() const { return Instr->getOperand(3); } + void set_matrixRefC(llvm::Value *val) { Instr->setOperand(3, val); } +}; + +/// This instruction accumulate A or B matrix into Accumulator matrix following +/// LHS += RHS +struct DxilInst_MatrixAccumulate { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixAccumulate(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::MatrixAccumulate); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (3 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRefRHS = 1, + arg_matrixRefLHS = 2, + }; + // Accessors + llvm::Value *get_matrixRefRHS() const { return Instr->getOperand(1); } + void set_matrixRefRHS(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_matrixRefLHS() const { return Instr->getOperand(2); } + void set_matrixRefLHS(llvm::Value *val) { Instr->setOperand(2, val); } +}; + +/// This instruction Multiplies a MxK dimension matrix and a K sized input +/// vector +struct DxilInst_MatrixVecMul { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixVecMul(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::MatrixVecMul); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (4 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_inputVector = 2, + arg_interpretation = 3, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_inputVector() const { return Instr->getOperand(2); } + void set_inputVector(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_interpretation() const { return Instr->getOperand(3); } + void set_interpretation(llvm::Value *val) { Instr->setOperand(3, val); } +}; + +/// This instruction Multiplies a MxK dimension matrix and a K sized input +/// vector then adds a M sized bias vector +struct DxilInst_MatrixVecMulAdd { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixVecMulAdd(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::MatrixVecMulAdd); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_inputVector = 2, + arg_inputInterpretation = 3, + arg_biasVector = 4, + arg_biasInterpretation = 5, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_inputVector() const { return Instr->getOperand(2); } + void set_inputVector(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_inputInterpretation() const { return Instr->getOperand(3); } + void set_inputInterpretation(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_biasVector() const { return Instr->getOperand(4); } + void set_biasVector(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_biasInterpretation() const { return Instr->getOperand(5); } + void set_biasInterpretation(llvm::Value *val) { Instr->setOperand(5, val); } +}; + +/// This instruction accumulates a matrix to a RWByteAddressBuffer +struct DxilInst_MatrixAccumulateToDescriptor { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixAccumulateToDescriptor(llvm::Instruction *pInstr) + : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::MatrixAccumulateToDescriptor); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_handle = 2, + arg_offset = 3, + arg_stride = 4, + arg_layout = 5, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_handle() const { return Instr->getOperand(2); } + void set_handle(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_offset() const { return Instr->getOperand(3); } + void set_offset(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_stride() const { return Instr->getOperand(4); } + void set_stride(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_layout() const { return Instr->getOperand(5); } + void set_layout(llvm::Value *val) { Instr->setOperand(5, val); } +}; + +/// This instruction accumulates a matrix to groupshared memory +struct DxilInst_MatrixAccumulateToMemory { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixAccumulateToMemory(llvm::Instruction *pInstr) + : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst( + Instr, hlsl::OP::OpCode::MatrixAccumulateToMemory); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (6 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_groupsharedArr = 2, + arg_offset = 3, + arg_stride = 4, + arg_layout = 5, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_groupsharedArr() const { return Instr->getOperand(2); } + void set_groupsharedArr(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_offset() const { return Instr->getOperand(3); } + void set_offset(llvm::Value *val) { Instr->setOperand(3, val); } + llvm::Value *get_stride() const { return Instr->getOperand(4); } + void set_stride(llvm::Value *val) { Instr->setOperand(4, val); } + llvm::Value *get_layout() const { return Instr->getOperand(5); } + void set_layout(llvm::Value *val) { Instr->setOperand(5, val); } +}; + +/// This instruction Outer products an M sized vector and a K sized vector +/// producing an MxK matrix +struct DxilInst_MatrixOuterProduct { + llvm::Instruction *Instr; + // Construction and identification + DxilInst_MatrixOuterProduct(llvm::Instruction *pInstr) : Instr(pInstr) {} + operator bool() const { + return hlsl::OP::IsDxilOpFuncCallInst(Instr, + hlsl::OP::OpCode::MatrixOuterProduct); + } + // Validation support + bool isAllowed() const { return true; } + bool isArgumentListValid() const { + if (4 != llvm::dyn_cast(Instr)->getNumArgOperands()) + return false; + return true; + } + // Metadata + bool requiresUniformInputs() const { return false; } + // Operand indexes + enum OperandIdx { + arg_matrixRef = 1, + arg_vectorA = 2, + arg_vectorB = 3, + }; + // Accessors + llvm::Value *get_matrixRef() const { return Instr->getOperand(1); } + void set_matrixRef(llvm::Value *val) { Instr->setOperand(1, val); } + llvm::Value *get_vectorA() const { return Instr->getOperand(2); } + void set_vectorA(llvm::Value *val) { Instr->setOperand(2, val); } + llvm::Value *get_vectorB() const { return Instr->getOperand(3); } + void set_vectorB(llvm::Value *val) { Instr->setOperand(3, val); } +}; // INSTR-HELPER:END } // namespace hlsl diff --git a/lib/DXIL/DxilOperations.cpp b/lib/DXIL/DxilOperations.cpp index 0e8835c69c..dde21c56c3 100644 --- a/lib/DXIL/DxilOperations.cpp +++ b/lib/DXIL/DxilOperations.cpp @@ -2822,6 +2822,185 @@ static const OP::OpCodeProperty ExperimentalOps_OpCodeProps[] = { 1, {{0x2}}, {{0x0}}}, // Overloads: f + + // Linear Algebra Operations + {OC::CreateMatrix, + "CreateMatrix", + OCC::CreateMatrix, + "createMatrix", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::FillMatrix, + "FillMatrix", + OCC::FillMatrix, + "fillMatrix", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::CopyConvertMatrix, + "CopyConvertMatrix", + OCC::CopyConvertMatrix, + "copyConvertMatrix", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixLoadFromDescriptor, + "MatrixLoadFromDescriptor", + OCC::MatrixLoadFromDescriptor, + "matrixLoadFromDescriptor", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixLoadFromMemory, + "MatrixLoadFromMemory", + OCC::MatrixLoadFromMemory, + "matrixLoadFromMemory", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixLength, + "MatrixLength", + OCC::MatrixLength, + "matrixLength", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixGetCoordinate, + "MatrixGetCoordinate", + OCC::MatrixGetCoordinate, + "matrixGetCoordinate", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixGetElement, + "MatrixGetElement", + OCC::MatrixGetElement, + "matrixGetElement", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::MatrixSetElement, + "MatrixSetElement", + OCC::MatrixSetElement, + "matrixSetElement", + Attribute::None, + 1, + {{0x63}}, + {{0x0}}}, // Overloads: hfwi + {OC::MatrixStoreToDescriptor, + "MatrixStoreToDescriptor", + OCC::MatrixStoreToDescriptor, + "matrixStoreToDescriptor", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixStoreToMemory, + "MatrixStoreToMemory", + OCC::MatrixStoreToMemory, + "matrixStoreToMemory", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixQueryAccumulatorLayout, + "MatrixQueryAccumulatorLayout", + OCC::MatrixQueryAccumulatorLayout, + "matrixQueryAccumulatorLayout", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixMulOp, + "MatrixMulOp", + OCC::MatrixMulOp, + "matrixMulOp", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixAccumulate, + "MatrixAccumulate", + OCC::MatrixAccumulate, + "matrixAccumulate", + Attribute::None, + 0, + {}, + {}}, // Overloads: v + {OC::MatrixVecMul, + "MatrixVecMul", + OCC::MatrixVecMul, + "matrixVecMul", + Attribute::None, + 2, + {{0x400}, {0x400}}, + {{0x63}, {0x63}}}, // Overloads: getNumParams() <= 2) return nullptr; return FT->getParamType(2); @@ -6519,6 +6861,7 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::CalculateLOD: case OpCode::ReportHit: case OpCode::HitObject_FromRayQueryWithAttrs: + case OpCode::MatrixSetElement: if (FT->getNumParams() <= 3) return nullptr; return FT->getParamType(3); @@ -6629,6 +6972,22 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { case OpCode::GetGroupWaveIndex: case OpCode::GetGroupWaveCount: case OpCode::ClusterID: + case OpCode::CreateMatrix: + case OpCode::CopyConvertMatrix: + case OpCode::MatrixLoadFromDescriptor: + case OpCode::MatrixLoadFromMemory: + case OpCode::MatrixLength: + case OpCode::MatrixGetCoordinate: + case OpCode::MatrixStoreToDescriptor: + case OpCode::MatrixStoreToMemory: + case OpCode::MatrixQueryAccumulatorLayout: + case OpCode::MatrixMulOp: + case OpCode::MatrixAccumulate: + case OpCode::MatrixAccumulateToDescriptor: + case OpCode::MatrixAccumulateToMemory: + case OpCode::LinAlgMatrixReserved0: + case OpCode::LinAlgMatrixReserved1: + case OpCode::LinAlgMatrixReserved2: return Type::getVoidTy(Ctx); case OpCode::QuadVote: return IntegerType::get(Ctx, 1); @@ -6666,6 +7025,19 @@ llvm::Type *OP::GetOverloadType(OpCode opCode, llvm::Function *F) { return llvm::StructType::get(Ctx, {FT->getParamType(1), FT->getParamType(2)}); + case OpCode::MatrixVecMul: + case OpCode::MatrixVecMulAdd: + if (FT->getNumParams() < 3) + return nullptr; + return llvm::StructType::get(Ctx, + {FT->getReturnType(), FT->getParamType(2)}); + + case OpCode::MatrixOuterProduct: + if (FT->getNumParams() < 4) + return nullptr; + return llvm::StructType::get(Ctx, + {FT->getParamType(2), FT->getParamType(3)}); + // OPCODE-OLOAD-TYPES:END default: return Ty; diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index a3b22024ff..b074c5382d 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -1151,6 +1151,19 @@ def populate_categories_and_models_ExperimentalOps(self): "miss", ) + for i in insts( + "CreateMatrix,FillMatrix,CopyConvertMatrix," + + "MatrixLoadFromDescriptor,MatrixLoadFromMemory," + + "MatrixLength,MatrixGetCoordinate,MatrixGetElement,MatrixSetElement," + + "MatrixStoreToDescriptor,MatrixStoreToMemory," + + "MatrixQueryAccumulatorLayout,MatrixMulOp,MatrixAccumulate," + + "MatrixVecMul,MatrixVecMulAdd," + + "MatrixAccumulateToDescriptor,MatrixAccumulateToMemory," + + "MatrixOuterProduct" + ): + i.category = "Linear Algebra Operations" + i.shader_model = experimental_sm + def populate_llvm_instructions(self): # Add instructions that map to LLVM instructions. # This is basically include\llvm\IR\Instruction.def @@ -6300,6 +6313,281 @@ def populate_ExperimentalOps(self): ], ) + # Linear Algebra Ops + add_dxil_op( + "CreateMatrix", + "CreateMatrix", + "creates a handle to a Matrix", + "v", + "", + [ + db_dxil_param(0, "i32", "", "operation result"), # TODO: %dx.types.MatrixRef + ], + ) + + add_dxil_op( + "FillMatrix", + "FillMatrix", + "fills a matrix with a scalar value", + "hfwi", + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "matrixRef", "matrix to be filled"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "$o", "value", "value to fill matrix with"), + ], + ) + + add_dxil_op( + "CopyConvertMatrix", + "CopyConvertMatrix", + "Converts and copies the element and use type of the source matrix to the destination matrix with optional transpose", + "v", + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "destMatrixRef", "matrix to be filled"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "i32", "srcMatrixRef", "matrix to fill matrix with"), # TODO: %dx.types.MatrixRef + db_dxil_param(4, "i1", "transpose", "should the matrix be transposed"), + ], + ) + + add_dxil_op( + "MatrixLoadFromDescriptor", + "MatrixLoadFromDescriptor", + "fills a matrix with data from a [RW]ByteAddressBuffer", + "v", + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "matrixRef", "matrix to be filled"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "res", "handle", "byte address buffer to fill matrix with"), + db_dxil_param(4, "i32", "offset", "starting offset in the buffer"), + db_dxil_param(5, "i32", "stride", "number of bytes between the start of each row or column"), + db_dxil_param(6, "i32", "layout", "memory layout of matrix elements"), + ], + ) + + add_dxil_op( + "MatrixLoadFromMemory", + "MatrixLoadFromMemory", + "fills a matrix with data from a groupshared array", + "v", # TODO: overload needs to be updated + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "matrixRef", "matrix to be filled"), # TODO: %dx.types.MatrixRef + #TODO: [Ty] * addrspace(4), ; groupshared T[M * N] + db_dxil_param(3, "i32", "groupsharedArr", "groupshared array to fill matrix with"), + db_dxil_param(4, "i32", "offset", "starting offset in the array"), + db_dxil_param(5, "i32", "stride", "number of bytes between the start of each row or column"), + db_dxil_param(6, "i32", "layout", "memory layout of matrix elements"), + ], + ) + + add_dxil_op( + "MatrixLength", + "MatrixLength", + "returns the number of elements stored in thread-local storage on the active thread for the provided matrix", + "v", + "", + [ + db_dxil_param(0, "i32", "", "operation result"), + db_dxil_param(2, "i32", "matrixRef", "matrix to be examined"), # TODO: %dx.types.MatrixRef + ], + ) + + add_dxil_op( + "MatrixGetCoordinate", + "MatrixGetCoordinate", + "returns a two element vector containing the column and row of the matrix that the thread-local index corresponds to", + "v", + "", + [ + db_dxil_param(0, "i32", "", "operation result"), # TODO: <2 x i32> + db_dxil_param(2, "i32", "matrixRef", "matrix to be examined"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "i32", "threadLocalIndex", "thread-local index to be examined"), + ], + ) + + add_dxil_op( + "MatrixGetElement", + "MatrixGetElement", + "returns the element of the matrix corresponding to the provided thread-local index", + "hfwi", + "", + [ + db_dxil_param(0, "$o", "", "operation result"), + db_dxil_param(2, "i32", "matrixRef", "matrix to be examined"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "i32", "threadLocalIndex", "thread-local index to be examined"), + ], + ) + + add_dxil_op( + "MatrixSetElement", + "MatrixSetElement", + "sets the element of the matrix corresponding to the provided thread-local index", + "hfwi", + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "matrixRef", "matrix to be examined"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "i32", "threadLocalIndex", "thread-local index to be examined"), + db_dxil_param(4, "$o", "value", "value to set"), + ], + ) + + add_dxil_op( + "MatrixStoreToDescriptor", + "MatrixStoreToDescriptor", + "stores a matrix to a RWByteAddressBuffer", + "v", + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "matrixRef", "matrix to be stored"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "res", "handle", "byte address buffer to store into"), + db_dxil_param(4, "i32", "offset", "starting offset in the buffer"), + db_dxil_param(5, "i32", "stride", "number of bytes between the start of each row or column"), + db_dxil_param(6, "i32", "layout", "memory layout of matrix elements"), + ], + ) + + add_dxil_op( + "MatrixStoreToMemory", + "MatrixStoreToMemory", + "stores a matrix to groupshared memory", + "v", # TODO: overload needs to be updated + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "matrixRef", "matrix to be stored"), # TODO: %dx.types.MatrixRef + #TODO: [Ty] * addrspace(4), ; groupshared T[M * N] + db_dxil_param(3, "i32", "groupsharedArr", "groupshared array to store into"), + db_dxil_param(4, "i32", "offset", "starting offset in the array"), + db_dxil_param(5, "i32", "stride", "number of bytes between the start of each row or column"), + db_dxil_param(6, "i32", "layout", "memory layout of matrix elements"), + ], + ) + + add_dxil_op( + "MatrixQueryAccumulatorLayout", + "MatrixQueryAccumulatorLayout", + "returns comptime 0 when accumulator matrix are A layout, 1 when B layout", + "v", + "", + [ + db_dxil_param(0, "i32", "", "operation result"), + ], + ) + + add_dxil_op( + "MatrixMulOp", + "MatrixMulOp", + "applies a multiplication op to matrix C using A and B as parameters", + "v", + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "matrixRefA", "matrix A"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "i32", "matrixRefB", "matrix B"), # TODO: %dx.types.MatrixRef + db_dxil_param(4, "i32", "matrixRefC", "matrix C"), # TODO: %dx.types.MatrixRef + ], + ) + + add_dxil_op( + "MatrixAccumulate", + "MatrixAccumulate", + "accumulate A or B matrix into Accumulator matrix following LHS += RHS", + "v", + "", + [ + db_dxil_param(0, "v", "", ""), + db_dxil_param(2, "i32", "matrixRefRHS", "A or B matrix"), # TODO: %dx.types.MatrixRef + db_dxil_param(3, "i32", "matrixRefLHS", "Accumulator matrix"), # TODO: %dx.types.MatrixRef + ], + ) + + add_dxil_op( + "MatrixVecMul", + "MatrixVecMul", + "Multiplies a MxK dimension matrix and a K sized input vector", + " Date: Fri, 19 Dec 2025 17:01:16 -0700 Subject: [PATCH 2/2] format --- utils/hct/hctdb.py | 186 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 140 insertions(+), 46 deletions(-) diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index b074c5382d..a08ef9d32e 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -6321,7 +6321,9 @@ def populate_ExperimentalOps(self): "v", "", [ - db_dxil_param(0, "i32", "", "operation result"), # TODO: %dx.types.MatrixRef + db_dxil_param( + 0, "i32", "", "operation result" + ), # TODO: %dx.types.MatrixRef ], ) @@ -6333,7 +6335,9 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "matrixRef", "matrix to be filled"), # TODO: %dx.types.MatrixRef + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be filled" + ), # TODO: %dx.types.MatrixRef db_dxil_param(3, "$o", "value", "value to fill matrix with"), ], ) @@ -6346,8 +6350,12 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "destMatrixRef", "matrix to be filled"), # TODO: %dx.types.MatrixRef - db_dxil_param(3, "i32", "srcMatrixRef", "matrix to fill matrix with"), # TODO: %dx.types.MatrixRef + db_dxil_param( + 2, "i32", "destMatrixRef", "matrix to be filled" + ), # TODO: %dx.types.MatrixRef + db_dxil_param( + 3, "i32", "srcMatrixRef", "matrix to fill matrix with" + ), # TODO: %dx.types.MatrixRef db_dxil_param(4, "i1", "transpose", "should the matrix be transposed"), ], ) @@ -6360,10 +6368,19 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "matrixRef", "matrix to be filled"), # TODO: %dx.types.MatrixRef - db_dxil_param(3, "res", "handle", "byte address buffer to fill matrix with"), + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be filled" + ), # TODO: %dx.types.MatrixRef + db_dxil_param( + 3, "res", "handle", "byte address buffer to fill matrix with" + ), db_dxil_param(4, "i32", "offset", "starting offset in the buffer"), - db_dxil_param(5, "i32", "stride", "number of bytes between the start of each row or column"), + db_dxil_param( + 5, + "i32", + "stride", + "number of bytes between the start of each row or column", + ), db_dxil_param(6, "i32", "layout", "memory layout of matrix elements"), ], ) @@ -6372,15 +6389,24 @@ def populate_ExperimentalOps(self): "MatrixLoadFromMemory", "MatrixLoadFromMemory", "fills a matrix with data from a groupshared array", - "v", # TODO: overload needs to be updated + "v", # TODO: overload needs to be updated "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "matrixRef", "matrix to be filled"), # TODO: %dx.types.MatrixRef - #TODO: [Ty] * addrspace(4), ; groupshared T[M * N] - db_dxil_param(3, "i32", "groupsharedArr", "groupshared array to fill matrix with"), + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be filled" + ), # TODO: %dx.types.MatrixRef + # TODO: [Ty] * addrspace(4), ; groupshared T[M * N] + db_dxil_param( + 3, "i32", "groupsharedArr", "groupshared array to fill matrix with" + ), db_dxil_param(4, "i32", "offset", "starting offset in the array"), - db_dxil_param(5, "i32", "stride", "number of bytes between the start of each row or column"), + db_dxil_param( + 5, + "i32", + "stride", + "number of bytes between the start of each row or column", + ), db_dxil_param(6, "i32", "layout", "memory layout of matrix elements"), ], ) @@ -6393,7 +6419,9 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "i32", "", "operation result"), - db_dxil_param(2, "i32", "matrixRef", "matrix to be examined"), # TODO: %dx.types.MatrixRef + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be examined" + ), # TODO: %dx.types.MatrixRef ], ) @@ -6404,9 +6432,13 @@ def populate_ExperimentalOps(self): "v", "", [ - db_dxil_param(0, "i32", "", "operation result"), # TODO: <2 x i32> - db_dxil_param(2, "i32", "matrixRef", "matrix to be examined"), # TODO: %dx.types.MatrixRef - db_dxil_param(3, "i32", "threadLocalIndex", "thread-local index to be examined"), + db_dxil_param(0, "i32", "", "operation result"), # TODO: <2 x i32> + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be examined" + ), # TODO: %dx.types.MatrixRef + db_dxil_param( + 3, "i32", "threadLocalIndex", "thread-local index to be examined" + ), ], ) @@ -6418,8 +6450,12 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "$o", "", "operation result"), - db_dxil_param(2, "i32", "matrixRef", "matrix to be examined"), # TODO: %dx.types.MatrixRef - db_dxil_param(3, "i32", "threadLocalIndex", "thread-local index to be examined"), + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be examined" + ), # TODO: %dx.types.MatrixRef + db_dxil_param( + 3, "i32", "threadLocalIndex", "thread-local index to be examined" + ), ], ) @@ -6431,8 +6467,12 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "matrixRef", "matrix to be examined"), # TODO: %dx.types.MatrixRef - db_dxil_param(3, "i32", "threadLocalIndex", "thread-local index to be examined"), + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be examined" + ), # TODO: %dx.types.MatrixRef + db_dxil_param( + 3, "i32", "threadLocalIndex", "thread-local index to be examined" + ), db_dxil_param(4, "$o", "value", "value to set"), ], ) @@ -6445,10 +6485,17 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "matrixRef", "matrix to be stored"), # TODO: %dx.types.MatrixRef + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be stored" + ), # TODO: %dx.types.MatrixRef db_dxil_param(3, "res", "handle", "byte address buffer to store into"), db_dxil_param(4, "i32", "offset", "starting offset in the buffer"), - db_dxil_param(5, "i32", "stride", "number of bytes between the start of each row or column"), + db_dxil_param( + 5, + "i32", + "stride", + "number of bytes between the start of each row or column", + ), db_dxil_param(6, "i32", "layout", "memory layout of matrix elements"), ], ) @@ -6457,15 +6504,24 @@ def populate_ExperimentalOps(self): "MatrixStoreToMemory", "MatrixStoreToMemory", "stores a matrix to groupshared memory", - "v", # TODO: overload needs to be updated + "v", # TODO: overload needs to be updated "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "matrixRef", "matrix to be stored"), # TODO: %dx.types.MatrixRef - #TODO: [Ty] * addrspace(4), ; groupshared T[M * N] - db_dxil_param(3, "i32", "groupsharedArr", "groupshared array to store into"), + db_dxil_param( + 2, "i32", "matrixRef", "matrix to be stored" + ), # TODO: %dx.types.MatrixRef + # TODO: [Ty] * addrspace(4), ; groupshared T[M * N] + db_dxil_param( + 3, "i32", "groupsharedArr", "groupshared array to store into" + ), db_dxil_param(4, "i32", "offset", "starting offset in the array"), - db_dxil_param(5, "i32", "stride", "number of bytes between the start of each row or column"), + db_dxil_param( + 5, + "i32", + "stride", + "number of bytes between the start of each row or column", + ), db_dxil_param(6, "i32", "layout", "memory layout of matrix elements"), ], ) @@ -6489,9 +6545,15 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "matrixRefA", "matrix A"), # TODO: %dx.types.MatrixRef - db_dxil_param(3, "i32", "matrixRefB", "matrix B"), # TODO: %dx.types.MatrixRef - db_dxil_param(4, "i32", "matrixRefC", "matrix C"), # TODO: %dx.types.MatrixRef + db_dxil_param( + 2, "i32", "matrixRefA", "matrix A" + ), # TODO: %dx.types.MatrixRef + db_dxil_param( + 3, "i32", "matrixRefB", "matrix B" + ), # TODO: %dx.types.MatrixRef + db_dxil_param( + 4, "i32", "matrixRefC", "matrix C" + ), # TODO: %dx.types.MatrixRef ], ) @@ -6503,8 +6565,12 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "v", "", ""), - db_dxil_param(2, "i32", "matrixRefRHS", "A or B matrix"), # TODO: %dx.types.MatrixRef - db_dxil_param(3, "i32", "matrixRefLHS", "Accumulator matrix"), # TODO: %dx.types.MatrixRef + db_dxil_param( + 2, "i32", "matrixRefRHS", "A or B matrix" + ), # TODO: %dx.types.MatrixRef + db_dxil_param( + 3, "i32", "matrixRefLHS", "Accumulator matrix" + ), # TODO: %dx.types.MatrixRef ], ) @@ -6516,7 +6582,9 @@ def populate_ExperimentalOps(self): "", [ db_dxil_param(0, "$x0", "", "operation result"), - db_dxil_param(2, "i32", "matrixRef", "matrix to multiply"), # TODO: %dx.types.MatrixRef + db_dxil_param( + 2, "i32", "matrixRef", "matrix to multiply" + ), # TODO: %dx.types.MatrixRef db_dxil_param(3, "$x1", "inputVector", "K dim vector to multiply"), db_dxil_param(4, "i32", "interpretation", "vector interpretation type"), ], @@ -6526,16 +6594,22 @@ def populate_ExperimentalOps(self): "MatrixVecMulAdd", "MatrixVecMulAdd", "Multiplies a MxK dimension matrix and a K sized input vector then adds a M sized bias vector", - "