@@ -110,6 +110,25 @@ func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
110110 return %0 : !torch.vtensor <[3 ,4 ,5 ],f32 >
111111}
112112
113+ // CHECK-LABEL: @test_bernoulli
114+ func.func @test_bernoulli (%arg0: !torch.vtensor <[10 ],f64 >) -> !torch.vtensor <[10 ],f64 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 15 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
115+ // CHECK: %[[NONE:.*]] = torch.constant.none
116+ // CHECK: %0 = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f64>, !torch.none -> !torch.vtensor<[10],f64>
117+ %0 = torch.operator " onnx.Bernoulli" (%arg0 ) : (!torch.vtensor <[10 ],f64 >) -> !torch.vtensor <[10 ],f64 >
118+ return %0 : !torch.vtensor <[10 ],f64 >
119+ }
120+
121+ // CHECK-LABEL: @test_bernoulli_double
122+ func.func @test_bernoulli_double (%arg0: !torch.vtensor <[10 ],f32 >) -> !torch.vtensor <[10 ],f64 > attributes {torch.onnx_meta.ir_version = 8 : si64 , torch.onnx_meta.opset_version = 15 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
123+ // CHECK: %[[NONE:.*]] = torch.constant.none
124+ // CHECK: %[[BERNOULLI:.*]] = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
125+ // CHECK: %[[DTYPE:.*]] = torch.constant.int 7
126+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
127+ // CHECK: torch.aten.to.dtype %[[BERNOULLI]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f64>
128+ %0 = torch.operator " onnx.Bernoulli" (%arg0 ) {torch.onnx.dtype = 11 : si64 } : (!torch.vtensor <[10 ],f32 >) -> !torch.vtensor <[10 ],f64 >
129+ return %0 : !torch.vtensor <[10 ],f64 >
130+ }
131+
113132// CHECK-LABEL: @test_bitshift_left_uint8
114133func.func @test_bitshift_left_uint8 (%arg0: !torch.vtensor <[3 ],ui8 >, %arg1: !torch.vtensor <[3 ],ui8 >) -> !torch.vtensor <[3 ],ui8 > attributes {torch.onnx_meta.ir_version = 6 : si64 , torch.onnx_meta.opset_version = 11 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
115134 // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>
@@ -323,6 +342,46 @@ func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torc
323342 return %0 : !torch.vtensor <[3 ,4 ],f32 >
324343}
325344
345+ // CHECK-LABEL: @test_castlike_BFLOAT16_to_FLOAT
346+ func.func @test_castlike_BFLOAT16_to_FLOAT (%arg0: !torch.vtensor <[3 ,4 ],bf16 >, %arg1: !torch.vtensor <[1 ],f32 >) -> !torch.vtensor <[3 ,4 ],f32 > attributes {torch.onnx_meta.ir_version = 9 : si64 , torch.onnx_meta.opset_version = 19 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
347+ // CHECK: %[[INT:.*]] = torch.constant.int 6
348+ // CHECK: %[[NONE:.*]] = torch.constant.none
349+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
350+ // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
351+ %0 = torch.operator " onnx.CastLike" (%arg0 , %arg1 ) : (!torch.vtensor <[3 ,4 ],bf16 >, !torch.vtensor <[1 ],f32 >) -> !torch.vtensor <[3 ,4 ],f32 >
352+ return %0 : !torch.vtensor <[3 ,4 ],f32 >
353+ }
354+
355+ // CHECK-LABEL: @test_castlike_DOUBLE_to_FLOAT
356+ func.func @test_castlike_DOUBLE_to_FLOAT (%arg0: !torch.vtensor <[3 ,4 ],f64 >, %arg1: !torch.vtensor <[1 ],f32 >) -> !torch.vtensor <[3 ,4 ],f32 > attributes {torch.onnx_meta.ir_version = 9 : si64 , torch.onnx_meta.opset_version = 19 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
357+ // CHECK: %[[INT:.*]] = torch.constant.int 6
358+ // CHECK: %[[NONE:.*]] = torch.constant.none
359+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
360+ // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
361+ %0 = torch.operator " onnx.CastLike" (%arg0 , %arg1 ) : (!torch.vtensor <[3 ,4 ],f64 >, !torch.vtensor <[1 ],f32 >) -> !torch.vtensor <[3 ,4 ],f32 >
362+ return %0 : !torch.vtensor <[3 ,4 ],f32 >
363+ }
364+
365+ // CHECK-LABEL: @test_castlike_FLOAT_to_DOUBLE
366+ func.func @test_castlike_FLOAT_to_DOUBLE (%arg0: !torch.vtensor <[3 ,4 ],f32 >, %arg1: !torch.vtensor <[1 ],f64 >) -> !torch.vtensor <[3 ,4 ],f64 > attributes {torch.onnx_meta.ir_version = 9 : si64 , torch.onnx_meta.opset_version = 19 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
367+ // CHECK: %[[INT:.*]] = torch.constant.int 7
368+ // CHECK: %[[NONE:.*]] = torch.constant.none
369+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
370+ // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64>
371+ %0 = torch.operator " onnx.CastLike" (%arg0 , %arg1 ) : (!torch.vtensor <[3 ,4 ],f32 >, !torch.vtensor <[1 ],f64 >) -> !torch.vtensor <[3 ,4 ],f64 >
372+ return %0 : !torch.vtensor <[3 ,4 ],f64 >
373+ }
374+
375+ // CHECK-LABEL: @test_castlike_FLOAT16_to_FLOAT
376+ func.func @test_castlike_FLOAT16_to_FLOAT (%arg0: !torch.vtensor <[3 ,4 ],f16 >, %arg1: !torch.vtensor <[1 ],f32 >) -> !torch.vtensor <[3 ,4 ],f32 > attributes {torch.onnx_meta.ir_version = 9 : si64 , torch.onnx_meta.opset_version = 19 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
377+ // CHECK: %[[INT:.*]] = torch.constant.int 6
378+ // CHECK: %[[NONE:.*]] = torch.constant.none
379+ // CHECK: %[[FALSE:.*]] = torch.constant.bool false
380+ // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
381+ %0 = torch.operator " onnx.CastLike" (%arg0 , %arg1 ) : (!torch.vtensor <[3 ,4 ],f16 >, !torch.vtensor <[1 ],f32 >) -> !torch.vtensor <[3 ,4 ],f32 >
382+ return %0 : !torch.vtensor <[3 ,4 ],f32 >
383+ }
384+
326385// CHECK-LABEL: @test_ceil_example
327386func.func @test_ceil_example (%arg0: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
328387 // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32>
0 commit comments