
Conversation

@zjgarvey
Collaborator

The newer direct lowering for backward convolution accumulates directly in lower-precision types like bf16. This patch adds a check for the default accumulator type; if that type doesn't match the result types for the op, the lowering also introduces a downcasting elementwise op (post-convolution and pre-collapsing for groups).

Signed-off-by: zjgarvey <zjgarvey@gmail.com>
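
To illustrate the idea, here is a minimal sketch (not the exact IR produced by the lowering) of the post-convolution downcast, assuming an f32 accumulator for the bf16 grouped backward-weight result used in the test below; shapes, names, and op choices are illustrative only:

// Assumed sketch: the grouped backward-weight convolution has already been
// accumulated into %conv_f32; truncate to bf16, then collapse the group dim.
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
func.func @downcast_after_grouped_conv(%conv_f32: tensor<4x4x32x2x2xf32>) -> tensor<16x32x2x2xbf16> {
  %init = tensor.empty() : tensor<4x4x32x2x2xbf16>
  // Elementwise downcast from the f32 accumulator to the bf16 result type.
  %down = linalg.generic
      {indexing_maps = [#map, #map],
       iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
      ins(%conv_f32 : tensor<4x4x32x2x2xf32>)
      outs(%init : tensor<4x4x32x2x2xbf16>) {
  ^bb0(%in: f32, %out: bf16):
    %t = arith.truncf %in : f32 to bf16
    linalg.yield %t : bf16
  } -> tensor<4x4x32x2x2xbf16>
  // Collapse the (groups, out_channels/groups) dims back to the weight shape.
  %weight_grad = tensor.collapse_shape %down [[0, 1], [2], [3], [4]]
      : tensor<4x4x32x2x2xbf16> into tensor<16x32x2x2xbf16>
  return %weight_grad : tensor<16x32x2x2xbf16>
}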
@zjgarvey requested review from IanWood1 and a-sidorova and removed request for a-sidorova on December 15, 2025 21:31
// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
Contributor

Does this need an e2e test? Will one of the tests from 3cebce2 error out if changed to bf16?

Collaborator Author

The current e2e tests use the decomposition for this op, so they won't encounter this logic anyway. Maybe there is a case where the decomposition fails, in which case we would be able to test this e2e. In any case, I'll try it out locally and see.

Collaborator Author

Yeah, I tried locally, and we don't even support bfloat16 in the e2e tests.

Hacking a few things through to enable testing bf16, all of the tests report numerics mismatches against PyTorch for this dtype, whether the decomposition is enabled or not; but I at least get one fewer mismatched individual tensor with these changes than without them (when locally disabling the decomposition).

I wouldn't expect much else, honestly. Testing lower-precision dtypes through our ref-backend against PyTorch's CPU implementation seems a bit hyper-specific. IIRC PyTorch CPU often accumulates in float64 instead of float32, but I'd have to double-check that. If the device ends up being important for choosing the accumulator dtype, we can try to push a device-info attribute through to inform accumulator type selection in the future.

Signed-off-by: zjgarvey <zjgarvey@gmail.com>
Collaborator

@a-sidorova left a comment

@zjgarvey thank you for the quick fix!

Comment on lines +480 to +481
%3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
%result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
Collaborator

nit: for better test coverage we could set [true, true, true] here to validate the calculation of the input gradient as well. But I don't insist.
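
For reference, a minimal sketch of what that variant of the test might look like; the grad-input result type below is an assumption (inferred from the input shape), not something taken from the patch:

%3 = torch.prim.ListConstruct %true, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
%result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>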
