
Commit 4707d3b

[MLIR][ONNX] Add OnnxToTorch support for Bernoulli and CastLike op
Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Parent: 35e8f86

2 files changed: 160 additions, 15 deletions


lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 101 additions & 15 deletions
@@ -15,6 +15,28 @@ using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::onnx_c;
 
+static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
+  int64_t dtypeIntTorch;
+  // TODO: Add complete mapping.
+  switch (dtypeIntOnnx) {
+  case 1:
+    dtypeIntTorch = 6; // float
+    break;
+  case 10:
+    dtypeIntTorch = 5; // half
+    break;
+  case 11:
+    dtypeIntTorch = 7; // double
+    break;
+  case 16:
+    dtypeIntTorch = 15; // bfloat16
+    break;
+  default:
+    dtypeIntTorch = -1; // No dtype
+  }
+  return dtypeIntTorch;
+}
+
 // Simple rewrites for the default domain.
 // See: https://onnx.ai/onnx/operators/
 // For operators that are effectively version invariant, we register with
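For reference, the integer codes on each side of this helper are enum values: ONNX's TensorProto.DataType on the input side (FLOAT=1, FLOAT16=10, DOUBLE=11, BFLOAT16=16) and torch's ScalarType on the output side (Half=5, Float=6, Double=7, BFloat16=15). Below is a minimal standalone sketch of the same mapping, useful for sanity-checking it outside the pass; the name mapOnnxDtypeToTorchDtype is a hypothetical stand-in, not part of this commit.

// Standalone sketch (illustration only). The enum values mirror
// ONNX TensorProto.DataType and torch ScalarType as noted above.
#include <cassert>
#include <cstdint>

static int64_t mapOnnxDtypeToTorchDtype(int64_t dtypeIntOnnx) {
  switch (dtypeIntOnnx) {
  case 1:  return 6;  // ONNX FLOAT    -> torch float
  case 10: return 5;  // ONNX FLOAT16  -> torch half
  case 11: return 7;  // ONNX DOUBLE   -> torch double
  case 16: return 15; // ONNX BFLOAT16 -> torch bfloat16
  default: return -1; // unmapped dtype; callers must check for -1
  }
}

int main() {
  assert(mapOnnxDtypeToTorchDtype(11) == 7); // double maps to double
  assert(mapOnnxDtypeToTorchDtype(3) == -1); // e.g. INT8 is not mapped yet
  return 0;
}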
@@ -311,6 +333,53 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         }
         return failure();
       });
+  patterns.onOp(
+      "Bernoulli", 15,
+      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input;
+        int64_t dtypeIntOnnx, dtypeIntTorch;
+        if (binder.tensorOperand(input) ||
+            binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        SmallString<64> name("torch.onnx.");
+        name.append("seed");
+        auto attr = binder.op->getAttr(name);
+        if (attr) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented: support not present for seed attribute");
+        }
+
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value bernoulli = rewriter.create<Torch::AtenBernoulliOp>(
+            binder.getLoc(), input.getType(), input, /*generator=*/none);
+
+        if (dtypeIntOnnx == -1) {
+          // True, if dtype attribute value is not present.
+          rewriter.replaceOp(binder.op, bernoulli);
+          return success();
+        }
+        dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+        if (dtypeIntTorch == -1) {
+          return rewriter.notifyMatchFailure(
+              binder.op,
+              "unimplemented support for the given dtype conversion");
+        }
+        Value constDtype = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    dtypeIntTorch));
+        Value cstFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+            binder.op, resultType, bernoulli, constDtype,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/none);
+        return success();
+      });
   patterns.onOp(
       "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
         Torch::ValueTensorType resultType;
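The Bernoulli pattern above distinguishes three cases: it bails out when a torch.onnx.seed attribute is present (the importer prefixes ONNX attributes with "torch.onnx."), emits a bare aten.bernoulli when the optional dtype attribute is absent (the -1 default), and otherwise appends an aten.to.dtype cast, failing the match for dtypes the helper cannot map. A compact, hypothetical model of that control flow follows; it uses plain C++ stand-ins, not the MLIR APIs.

// Illustration only: models which ops the pattern emits for a given input.
#include <cstdint>
#include <iostream>

// Same mapping as the earlier sketch, repeated to keep this self-contained.
static int64_t mapOnnxDtypeToTorchDtype(int64_t d) {
  switch (d) {
  case 1: return 6;   // float
  case 10: return 5;  // half
  case 11: return 7;  // double
  case 16: return 15; // bfloat16
  default: return -1; // unmapped
  }
}

static const char *lowerBernoulli(bool hasSeedAttr, int64_t dtypeIntOnnx) {
  if (hasSeedAttr)
    return "match failure: seed attribute unsupported";
  if (dtypeIntOnnx == -1) // dtype attribute not present
    return "aten.bernoulli";
  if (mapOnnxDtypeToTorchDtype(dtypeIntOnnx) == -1)
    return "match failure: unmapped dtype";
  return "aten.bernoulli + aten.to.dtype"; // cast result to requested dtype
}

int main() {
  std::cout << lowerBernoulli(false, 11) << "\n"; // bernoulli + to.dtype
  return 0;
}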
@@ -386,21 +455,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             binder.tensorResultType(resultType))
           return failure();
 
-        // TODO: Add complete mapping.
-        switch (dtypeIntOnnx) {
-        case 1:
-          dtypeIntTorch = 6; // float
-          break;
-        case 10:
-          dtypeIntTorch = 5; // half
-          break;
-        case 11:
-          dtypeIntTorch = 7; // double
-          break;
-        case 16:
-          dtypeIntTorch = 15; // bfloat16
-          break;
-        default:
+        dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+        if (dtypeIntTorch == -1) {
           return rewriter.notifyMatchFailure(
               binder.op,
               "unimplemented support for the given dtype conversion");
@@ -418,6 +474,36 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
             /*memory_format=*/none);
         return success();
       });
+  patterns.onOp(
+      "CastLike", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        Torch::ValueTensorType resultType;
+        Value input, target;
+        if (binder.tensorOperands(input, target) ||
+            binder.tensorResultType(resultType))
+          return failure();
+
+        // TODO: Add support to handle the `saturate` attribute.
+        // Ignoring it right now, since it's only used during the float8
+        // conversions which are not supported in Torch-MLIR right now.
+
+        Torch::ValueTensorType targetTy =
+            target.getType().cast<Torch::ValueTensorType>();
+        if (!targetTy.hasDtype()) {
+          return rewriter.notifyMatchFailure(binder.op,
+                                             "target tensor must have a dtype");
+        }
+        Type targetDtype = targetTy.getDtype();
+        Value constDtype = Torch::getDtypeIntValueForType(
+            rewriter, binder.getLoc(), targetDtype);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value cstFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
+            binder.op, resultType, input, constDtype,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/none);
+        return success();
+      });
   patterns.onOp("Ceil", 13,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
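Per the ONNX specification, CastLike(input, target) behaves like Cast(input, to=<target's element type>), so the pattern only needs to read the dtype off the target tensor type and emit the same aten.to.dtype call the Cast lowering produces. A minimal hypothetical model of that reduction, with plain C++ stand-ins rather than MLIR types:

// Illustration only: CastLike's cast dtype comes from the second operand's
// tensor type; a type without a known dtype is a match failure.
#include <iostream>
#include <optional>
#include <string>

enum class Dtype { Float, Half, Double, BFloat16 };

struct TensorType {
  std::optional<Dtype> dtype; // tensor types may leave the dtype unspecified
};

static std::string lowerCastLike(const TensorType &input,
                                 const TensorType &target) {
  (void)input; // the input operand only supplies the value being cast
  if (!target.dtype)
    return "match failure: target tensor must have a dtype";
  return "aten.to.dtype(input, dtype(target), /*non_blocking=*/false, "
         "/*copy=*/false)";
}

int main() {
  std::cout << lowerCastLike({Dtype::Float}, {Dtype::Double}) << "\n";
  return 0;
}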

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

Lines changed: 59 additions & 0 deletions
@@ -110,6 +110,25 @@ func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
   return %0 : !torch.vtensor<[3,4,5],f32>
 }
 
+// CHECK-LABEL: @test_bernoulli
+func.func @test_bernoulli(%arg0: !torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %0 = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f64>, !torch.none -> !torch.vtensor<[10],f64>
+  %0 = torch.operator "onnx.Bernoulli"(%arg0) : (!torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64>
+  return %0 : !torch.vtensor<[10],f64>
+}
+
+// CHECK-LABEL: @test_bernoulli_double
+func.func @test_bernoulli_double(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[BERNOULLI:.*]] = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
+  // CHECK: %[[DTYPE:.*]] = torch.constant.int 7
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.to.dtype %[[BERNOULLI]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f64>
+  %0 = torch.operator "onnx.Bernoulli"(%arg0) {torch.onnx.dtype = 11 : si64} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64>
+  return %0 : !torch.vtensor<[10],f64>
+}
+
 // CHECK-LABEL: @test_bitshift_left_uint8
 func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>
@@ -323,6 +342,46 @@ func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torc
   return %0 : !torch.vtensor<[3,4],f32>
 }
 
+// CHECK-LABEL: @test_castlike_BFLOAT16_to_FLOAT
+func.func @test_castlike_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT:.*]] = torch.constant.int 6
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],bf16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// CHECK-LABEL: @test_castlike_DOUBLE_to_FLOAT
+func.func @test_castlike_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT:.*]] = torch.constant.int 6
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f64>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
+// CHECK-LABEL: @test_castlike_FLOAT_to_DOUBLE
+func.func @test_castlike_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT:.*]] = torch.constant.int 7
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64>
+  %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64>
+  return %0 : !torch.vtensor<[3,4],f64>
+}
+
+// CHECK-LABEL: @test_castlike_FLOAT16_to_FLOAT
+func.func @test_castlike_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[INT:.*]] = torch.constant.int 6
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[FALSE:.*]] = torch.constant.bool false
+  // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
+  %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32>
+  return %0 : !torch.vtensor<[3,4],f32>
+}
+
 // CHECK-LABEL: @test_ceil_example
 func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32>
