[TOSA] MultiheadAttention legalization #4382
@@ -2295,17 +2295,223 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
};
} // namespace

static Value getSoftmaxResult(Operation *op, Value self, Value dim,
                              Type resultType, Type accumulatorType,
                              PatternRewriter &rewriter);

namespace {
// Decompose scaled dot product attention into matmul/softmax pipeline when
// there is no masking, dropout, causal, or GQA behaviour.
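// For the supported case the emitted pipeline is:
//   scores = matmul(query, transpose(key, -2, -1)), reshaped and scaled by
//   `scale` (defaulting to 1/sqrt(head_dim)), followed by a numerically
//   stable softmax over the last dimension and a final matmul with `value`.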
class DecomposeAtenScaledDotProductAttentionOp
    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();

    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
      return rewriter.notifyMatchFailure(
          op, "attention mask decomposition not implemented");

    double dropoutP;
    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
        dropoutP != 0.0)
      return rewriter.notifyMatchFailure(
          op, "expected dropout_p to be the constant 0.0");

    bool isCausal;
    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
        isCausal)
      return rewriter.notifyMatchFailure(op,
                                         "causal attention not supported yet");

    bool enableGqa;
    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
        enableGqa)
      return rewriter.notifyMatchFailure(op,
                                         "grouped-query attention unsupported");

    Value query = op.getQuery();
    Value key = op.getKey();
    Value value = op.getValue();

    auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
    auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
    auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
    if (!queryTensorType || !keyTensorType || !valueTensorType)
      return rewriter.notifyMatchFailure(op, "expected tensor inputs");
    if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
        !valueTensorType.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "expected tensor inputs to have known shapes");
    auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
    auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
    auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
      return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
    if (!queryValueTensorType.hasDtype() || !keyValueTensorType.hasDtype() ||
        !valueValueTensorType.hasDtype())
      return rewriter.notifyMatchFailure(
          op, "expected tensor inputs to have dtypes");
    Type queryDtype = queryValueTensorType.getOptionalDtype();
    if (queryDtype != keyValueTensorType.getOptionalDtype() ||
        queryDtype != valueValueTensorType.getOptionalDtype())
      return rewriter.notifyMatchFailure(
          op, "expected query, key, and value to share dtype");

    Value oneInt =
        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
    Value zeroInt =
        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
    Value rank = AtenDimOp::create(rewriter, loc, query);
    Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
    Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
    Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
    Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
    Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
    int64_t queryRank = querySizes.size();
    if (queryRank < 3 || queryRank > 4)
      return rewriter.notifyMatchFailure(
          op, "expected query tensor rank to be 3 or 4");
    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
    ArrayRef<int64_t> valueSizes = valueValueTensorType.getSizes();
    if (static_cast<int64_t>(keySizes.size()) != queryRank ||
        static_cast<int64_t>(valueSizes.size()) != queryRank)
      return rewriter.notifyMatchFailure(
          op, "expected query, key, and value to share rank");
    bool hasExplicitHeadDim = queryRank == 4;
    Value numHeadsSize =
        hasExplicitHeadDim
            ? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
            : oneInt;
    Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
    auto listIntType =
        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));

    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
      if (staticDim != Torch::kUnknownSize)
        return ConstantIntOp::create(rewriter, loc,
                                     rewriter.getI64IntegerAttr(staticDim));
      return fallback;
    };

    Value scaleFloat;
    if (isa<Torch::NoneType>(op.getScale().getType())) {
      Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
      Value oneFloat =
          ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
      scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
    } else {
      scaleFloat = op.getScale();
    }

    Value negTwo =
        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
    Value negOne =
        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));

    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
    if (keyTransposedSizes.size() < 2)
      return rewriter.notifyMatchFailure(
          op, "expected key tensor rank >= 2 for transpose");
    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
              keyTransposedSizes[keyTransposedSizes.size() - 2]);
    ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
    std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
    Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
        keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
        keyValueTensorType.getOptionalSparsity());
    Value keyTransposed = AtenTransposeIntOp::create(
        rewriter, loc, keyTransposedType, key, negTwo, negOne);
    SmallVector<Value> keyDims;
    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
                             Value fallback) -> Value {
      return getDimValue(idx < staticDims.size() ? staticDims[idx]
                                                 : Torch::kUnknownSize,
                         fallback);
    };
    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
    if (hasExplicitHeadDim) {
      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
    } else {
      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
    }
    Value keyTransposeShapeList =
        PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
    keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
                                       keyTransposed, keyTransposeShapeList);

    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
      if (index < 0)
        index += sizes.size();
      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
        return Torch::kUnknownSize;
      return sizes[index];
    };
    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
    int64_t querySeqStatic = getStaticDim(querySizes, -2);
    int64_t keySeqStatic = getStaticDim(keySizes, -2);
    int64_t queryHeadsStatic =
        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
    SmallVector<int64_t, 4> scoresSizes;
    if (hasExplicitHeadDim)
      scoresSizes.assign(
          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
    else
      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
    Type scoresType = ValueTensorType::get(
        op->getContext(),
        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
        queryValueTensorType.getOptionalDtype(),
        queryValueTensorType.getOptionalSparsity());
    Value scores =
        AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
    SmallVector<Value> scoresDims;
    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
    unsigned seqIndex = 1;
    if (hasExplicitHeadDim) {
      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
      seqIndex = 2;
    }
    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
    Value scoresShapeList =
        PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
    scores =
        AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
    Value scaledScores =
        AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);

    Value softmax = getSoftmaxResult(op.getOperation(), scaledScores, negOne,
                                     scoresType, scoresType, rewriter);
    if (!softmax)
      return rewriter.notifyMatchFailure(op,
                                         "failed to compute softmax scores");

    Value output =
        AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);

    rewriter.replaceOp(op, output);
    return success();
  }
};
} // namespace

// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
// To avoid overflow we use the following decomposition rule:
// x_max = max(input, dim, keepdim = True)
// unnorm = aten.exp(input - x_max)
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
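// Subtracting the per-`dim` maximum does not change the result, since softmax
// is shift invariant: exp(x - m) / sum(exp(x - m)) == exp(x) / sum(exp(x)).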
-template <typename OpTy>
-static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
-                              Type accumulatorType, PatternRewriter &rewriter) {
-  Location loc = op.getLoc();
-  Value dim = op.getDim();
+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
  if (resultType != accumulatorType)
    self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
  Value xMax =

@@ -2362,8 +2568,9 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {

    Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

-   Value result = getSoftmaxResult(op, self, resultTensorType,
-                                   accumulatorTensorType, rewriter);
+   Value result =
+       getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                        accumulatorTensorType, rewriter);
    if (!result)
      return failure();
    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),

@@ -2411,8 +2618,9 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {

    Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

-   Value result = getSoftmaxResult(op, self, resultTensorType,
-                                   accumulatorTensorType, rewriter);
+   Value result =
+       getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                        accumulatorTensorType, rewriter);
    if (!result)
      return op.emitError("failed to get softmax result");
    rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,

@@ -13084,6 +13292,8 @@ class DecomposeComplexOpsPass
    legalOpsSet.clear();
    legalOpsSet.insert(legalOps.begin(), legalOps.end());

    patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);

Member
Is this pattern needed anymore with the change in python/torch_mlir/extras/fx_decomp_util.py?
Contributor (Author)
Please correct me if I misunderstand, but I believe we still need the MLIR-side pattern. The new entry in python/torch_mlir/extras/fx_decomp_util.py only affects the FX/ExportedProgram import path. Other frontends—TorchScript, AOTAutograd, or anyone who feeds raw Torch dialect into torch-mlir-opt—never touch that Python list, so they can still produce torch.aten.scaled_dot_product_attention. For those cases the rewrite in lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp is what lowers sdpa into the matmul/softmax pipeline so that downstream -convert-torch-to-tosa or -convert-torch-to-linalg keeps working.
Member
Based on torch-mlir/docs/development.md (line 244 at 0844d4d), the fx_importer path is the only maintained path. The rest have been deprecated but the code still exists. Maybe @sjarus / @zjgarvey can confirm / correct that understanding, and we can discuss whether it's still valuable to have this decomposition pattern or we can rely on PyTorch's decomposition.
I don't have any flag on adding this, just want to make sure that it will actually be exercised.
Collaborator
Yes, the fx_importer path is the only maintained one. I have found attention to be a bit frustrating, however. For example, running decompositions on an exported program with an sdpa op sometimes converts sdpa into a slightly different attention op, even when attention itself isn't getting decomposed. Merely running decompositions at all actually retraces the graph with a different tool, and may select different ops, further varied based on other factors like the torch device used by the inputs. In any case, I don't mind adding a decomposition pattern. We have a bit more control with a pattern like this as opposed to fx decompositions.
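For readers unfamiliar with that retracing behaviour, a minimal sketch using the public torch.export API is shown below; it is not part of this PR, and the module, shapes, and printed contents are assumptions for illustration only.

# Illustrative sketch (not from this PR): export a module that calls sdpa,
# then run decompositions, which retraces the graph and may pick a different
# attention op depending on the decomposition table and input device.
import torch
import torch.nn.functional as F

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(2, 4, 8, 16)  # (batch, heads, seq_len, head_dim); arbitrary sizes
ep = torch.export.export(Attn(), (q, k, v))
print(ep.graph)   # often contains scaled_dot_product_attention or a device-specific variant

ep2 = ep.run_decompositions()  # retraces; the sdpa call may be replaced or expanded here
print(ep2.graph)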
    addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
        patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);