37 changes: 36 additions & 1 deletion lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -4016,8 +4016,28 @@ LogicalResult ConvertAtenOp<AtenTransposeIntOp>::matchAndRewrite(
transposedDims[dim0] = dim1;
transposedDims[dim1] = dim0;

Type resultType = getTypeConverter()->convertType(op.getType());
if (auto rankedSelf = dyn_cast<RankedTensorType>(selfType)) {
SmallVector<int64_t> transposedShape(rankedSelf.getRank(),
ShapedType::kDynamic);
if (rankedSelf.hasStaticShape()) {
auto staticShape =
llvm::to_vector(makeShapeTorchCompatible(rankedSelf.getShape()));
auto dim0Index = static_cast<size_t>(dim0);
auto dim1Index = static_cast<size_t>(dim1);
if (dim0Index < staticShape.size() && dim1Index < staticShape.size())
std::swap(staticShape[dim0Index], staticShape[dim1Index]);
for (size_t i = 0; i < staticShape.size(); ++i)
transposedShape[i] = staticShape[i];
}
auto rankedResult = RankedTensorType::get(
makeShapeLLVMCompatible(transposedShape), rankedSelf.getElementType());
if (auto converted = getTypeConverter()->convertType(rankedResult))
resultType = converted;
}

rewriter.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(op.getType()), adaptor.getSelf(),
op, resultType, adaptor.getSelf(),
rewriter.getDenseI32ArrayAttr(transposedDims));

return success();
@@ -9402,6 +9422,21 @@ void populateTorchToTosaConversionLegalOps(ConversionTarget &target) {
target.addLegalOp<ConstantDeviceOp>();
target.addLegalOp<PrimListConstructOp>();
target.addLegalOp<PrimTupleConstructOp>();
target.addDynamicallyLegalOp<tensor::CastOp>([](tensor::CastOp op) -> bool {
auto sourceType = dyn_cast<RankedTensorType>(op.getSource().getType());
auto resultType = dyn_cast<RankedTensorType>(op.getType());
if (!sourceType || !resultType)
return true;
if (sourceType.getElementType() != resultType.getElementType())
return true;
if (!sourceType.hasStaticShape())
return true;
if (!resultType.hasStaticShape())
return true;
if (sourceType == resultType)
return true;
return false;
});
}

std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
228 changes: 219 additions & 9 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2295,17 +2295,223 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
};
} // namespace

static Value getSoftmaxResult(Operation *op, Value self, Value dim,
Type resultType, Type accumulatorType,
PatternRewriter &rewriter);

namespace {
// Decompose scaled dot product attention into matmul/softmax pipeline when
// there is no masking, dropout, causal, or GQA behaviour.
class DecomposeAtenScaledDotProductAttentionOp
: public OpRewritePattern<AtenScaledDotProductAttentionOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
return rewriter.notifyMatchFailure(
op, "attention mask decomposition not implemented");

double dropoutP;
if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
dropoutP != 0.0)
return rewriter.notifyMatchFailure(
op, "expected dropout_p to be the constant 0.0");

bool isCausal;
if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
isCausal)
return rewriter.notifyMatchFailure(op,
"causal attention not supported yet");

bool enableGqa;
if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
enableGqa)
return rewriter.notifyMatchFailure(op,
"grouped-query attention unsupported");

Value query = op.getQuery();
Value key = op.getKey();
Value value = op.getValue();

auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
if (!queryTensorType || !keyTensorType || !valueTensorType)
return rewriter.notifyMatchFailure(op, "expected tensor inputs");
if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
!valueTensorType.hasSizes())
return rewriter.notifyMatchFailure(
op, "expected tensor inputs to have known shapes");
auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
if (!queryValueTensorType.hasDtype() || !keyValueTensorType.hasDtype() ||
!valueValueTensorType.hasDtype())
return rewriter.notifyMatchFailure(
op, "expected tensor inputs to have dtypes");
Type queryDtype = queryValueTensorType.getOptionalDtype();
if (queryDtype != keyValueTensorType.getOptionalDtype() ||
queryDtype != valueValueTensorType.getOptionalDtype())
return rewriter.notifyMatchFailure(
op, "expected query, key, and value to share dtype");

Value oneInt =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
Value zeroInt =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
Value rank = AtenDimOp::create(rewriter, loc, query);
Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
int64_t queryRank = querySizes.size();
if (queryRank < 3 || queryRank > 4)
return rewriter.notifyMatchFailure(
op, "expected query tensor rank to be 3 or 4");
ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
ArrayRef<int64_t> valueSizes = valueValueTensorType.getSizes();
if (static_cast<int64_t>(keySizes.size()) != queryRank ||
static_cast<int64_t>(valueSizes.size()) != queryRank)
return rewriter.notifyMatchFailure(
op, "expected query, key, and value to share rank");
bool hasExplicitHeadDim = queryRank == 4;
Value numHeadsSize =
hasExplicitHeadDim
? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
: oneInt;
Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
auto listIntType =
Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));

auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
if (staticDim != Torch::kUnknownSize)
return ConstantIntOp::create(rewriter, loc,
rewriter.getI64IntegerAttr(staticDim));
return fallback;
};

Value scaleFloat;
if (isa<Torch::NoneType>(op.getScale().getType())) {
Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
Value oneFloat =
ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
} else {
scaleFloat = op.getScale();
}

Value negTwo =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
Value negOne =
ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));

SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
if (keyTransposedSizes.size() < 2)
return rewriter.notifyMatchFailure(
op, "expected key tensor rank >= 2 for transpose");
std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
keyTransposedSizes[keyTransposedSizes.size() - 2]);
ArrayRef<int64_t> keyTransposedRef(keyTransposedSizes);
std::optional<ArrayRef<int64_t>> keyTransposedOpt(keyTransposedRef);
Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
keyValueTensorType.getOptionalSparsity());
Value keyTransposed = AtenTransposeIntOp::create(
rewriter, loc, keyTransposedType, key, negTwo, negOne);
SmallVector<Value> keyDims;
auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
Value fallback) -> Value {
return getDimValue(idx < staticDims.size() ? staticDims[idx]
: Torch::kUnknownSize,
fallback);
};
keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
if (hasExplicitHeadDim) {
keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
} else {
keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
}
Value keyTransposeShapeList =
PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
keyTransposed, keyTransposeShapeList);

auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
if (index < 0)
index += sizes.size();
if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
return Torch::kUnknownSize;
return sizes[index];
};
int64_t queryBatchStatic = getStaticDim(querySizes, 0);
int64_t querySeqStatic = getStaticDim(querySizes, -2);
int64_t keySeqStatic = getStaticDim(keySizes, -2);
int64_t queryHeadsStatic =
hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
SmallVector<int64_t, 4> scoresSizes;
if (hasExplicitHeadDim)
scoresSizes.assign(
{queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
else
scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
Type scoresType = ValueTensorType::get(
op->getContext(),
ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
queryValueTensorType.getOptionalDtype(),
queryValueTensorType.getOptionalSparsity());
Value scores =
AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
SmallVector<Value> scoresDims;
scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
unsigned seqIndex = 1;
if (hasExplicitHeadDim) {
scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
seqIndex = 2;
}
scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
Value scoresShapeList =
PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
scores =
AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
Value scaledScores =
AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);

Value softmax = getSoftmaxResult(op.getOperation(), scaledScores, negOne,
scoresType, scoresType, rewriter);
if (!softmax)
return rewriter.notifyMatchFailure(op,
"failed to compute softmax scores");

Value output =
AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);

rewriter.replaceOp(op, output);
return success();
}
};
} // namespace

// Calculates the softmax function on the given `input` tensor. Softmax(x) =
// exp(x)/sum(exp(x)).
// To avoid overflow we use the following decomposition rule:
// x_max = max(input, dim, keepdim = True)
// unnorm = aten.exp(input - x_max)
// softmax = unnorm / sum(unnorm, dim, keepdim = True)
template <typename OpTy>
static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
Type accumulatorType, PatternRewriter &rewriter) {
Location loc = op.getLoc();
Value dim = op.getDim();
static Value getSoftmaxResult(Operation *op, Value self, Value dim,
Type resultType, Type accumulatorType,
PatternRewriter &rewriter) {
Location loc = op->getLoc();
if (resultType != accumulatorType)
self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
Value xMax =
@@ -2362,8 +2568,9 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {

Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

Value result = getSoftmaxResult(op, self, resultTensorType,
accumulatorTensorType, rewriter);
Value result =
getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
accumulatorTensorType, rewriter);
if (!result)
return failure();
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -2411,8 +2618,9 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {

Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);

Value result = getSoftmaxResult(op, self, resultTensorType,
accumulatorTensorType, rewriter);
Value result =
getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
accumulatorTensorType, rewriter);
if (!result)
return op.emitError("failed to get softmax result");
rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
@@ -13084,6 +13292,8 @@ class DecomposeComplexOpsPass
legalOpsSet.clear();
legalOpsSet.insert(legalOps.begin(), legalOps.end());

patterns.add<DecomposeAtenScaledDotProductAttentionOp>(context);

Member:
Is this pattern needed anymore with the change in fx_decomp_util?

Contributor (author):
Please correct me if I misunderstand, but I believe we still need the MLIR-side pattern. The new entry in python/torch_mlir/extras/fx_decomp_util.py only affects the FX/ExportedProgram import path. Other frontends—TorchScript, AOTAutograd, or anyone who feeds raw Torch dialect into torch-mlir-opt—never touch that Python list, so they can still produce torch.aten.scaled_dot_product_attention. For those cases the rewrite in lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp is what lowers sdpa into the matmul/softmax pipeline so that downstream -convert-torch-to-tosa or -convert-torch-to-linalg keeps working.
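
For reference, the matmul/softmax pipeline this pattern emits corresponds roughly to the following PyTorch-level sketch (illustrative only, under the same restrictions the pattern checks for: no mask, no dropout, non-causal, no GQA; the helper name is made up):

```python
# Rough PyTorch-level equivalent of what the decomposition produces.
# Not the pattern itself; just the math it implements.
import math
import torch

def sdpa_reference(q, k, v, scale=None):
    if scale is None:
        scale = 1.0 / math.sqrt(q.shape[-1])                # default scale: 1/sqrt(head_dim)
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale   # Q @ K^T, scaled
    attn = torch.softmax(scores, dim=-1)                    # softmax over the key dimension
    return torch.matmul(attn, v)                            # weighted sum of the values
```

The max-subtraction in getSoftmaxResult is the usual overflow guard; torch.softmax applies the same trick internally, so the two agree numerically.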

Member:
Based on:

> This path doesn't give access to the current generation work that is being driven via the fx_importer

IIUC, the fx_importer path is the only maintained path; the rest have been deprecated, but the code still exists. Maybe @sjarus / @zjgarvey can confirm / correct that understanding, and we can discuss whether it's still valuable to have this decomposition pattern or whether we can rely on PyTorch's decomposition.

I don't have any objection to adding this; I just want to make sure that it will actually be exercised.
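
For the alternative raised above (relying on PyTorch's decomposition on the FX path), the flow would look roughly like the sketch below. It assumes aten.scaled_dot_product_attention has an entry in torch._decomp; the exact key registered in fx_decomp_util.py may differ, and model/example_args are placeholders.

```python
# Sketch: let PyTorch decompose sdpa before the FX graph is imported.
import torch
from torch._decomp import get_decompositions

decomp_table = get_decompositions([torch.ops.aten.scaled_dot_product_attention])
# exported = torch.export.export(model, example_args)    # placeholder model/inputs
# exported = exported.run_decompositions(decomp_table)   # sdpa rewritten here
```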

Collaborator:
Yes, the fx_importer is the only path we should expect to support.

I have found attention to be a bit frustrating, however. For example, running decompositions on an exported program with an sdpa op sometimes converts sdpa into a slightly different attention op, even when attention itself isn't being decomposed. Merely running decompositions at all retraces the graph with a different tool, and the ops it selects can vary further based on other factors, like the torch device used by the inputs.

In any case, I don't mind adding a decomposition pattern. We have a bit more control with a pattern like this as opposed to fx decompositions.
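
One way to see which form actually reaches the Torch dialect on the fx_importer path, and therefore whether this pattern is exercised, is a quick import-and-inspect script; this is a sketch assuming the torch_mlir.fx.export_and_import helper with default settings.

```python
# Import a small sdpa module and inspect the resulting Torch-dialect IR.
# If torch.aten.scaled_dot_product_attention shows up, this C++ pattern is what
# decomposes it; if not, PyTorch's decompositions already rewrote it on import.
import torch
from torch_mlir import fx

class Attn(torch.nn.Module):
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)

q = torch.randn(2, 4, 8, 16)
print(fx.export_and_import(Attn(), q, q, q))
```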

addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);
20 changes: 20 additions & 0 deletions lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//

#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"

using namespace mlir;
@@ -40,6 +41,25 @@ static void setupValueTensorToBuiltinTensorConversion(
return {};
return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]);
});
typeConverter.addTargetMaterialization([](OpBuilder &builder, Type type,
ValueRange inputs,
Location loc) -> Value {
if (inputs.size() != 1)
return Value();
auto fromType = dyn_cast<RankedTensorType>(inputs[0].getType());
auto toType = dyn_cast<RankedTensorType>(type);
if (!fromType || !toType)
return Value();
if (fromType == toType)
return inputs[0];
if (fromType.getElementType() != toType.getElementType())
return Value();
if (!toType.hasStaticShape())
return Value();
if (!tensor::CastOp::areCastCompatible(inputs[0].getType(), toType))
return Value();
return tensor::CastOp::create(builder, loc, toType, inputs[0]);
});
auto sourceMaterialization = [](OpBuilder &builder,
Torch::ValueTensorType type,
ValueRange inputs, Location loc) -> Value {
10 changes: 0 additions & 10 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -50,11 +50,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
}

LINALG_CRASHING_SET = {
@@ -953,11 +950,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"SubIntModule_basic",
"TensorToIntZeroRank_basic",
"UpSampleNearest2dDynamicFactor_basic",
@@ -3978,11 +3972,8 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"ScaledDotProductAttentionGQAModule_basic",
# error: 'tosa.scatter' op requires dimensions K >= W
"IndexPut1DFloatNonAccumulateModule_basic",
@@ -4887,7 +4878,6 @@
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterAddDynamicModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",