[TorchToLinalg] Upcast low precision dtypes for direct backward conv lowering #4408
base: main
Conversation
Signed-off-by: zjgarvey <zjgarvey@gmail.com>
// CHECK-LABEL: func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,16,33,33],bf16>, %[[VAL_1:.*]]: !torch.vtensor<[2,128,64,64],bf16>,
// CHECK-SAME: %[[VAL_2:.*]]: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
func.func @convolution_backward_weights_2x2s_2x2p_2x2d_4g_bf16(%arg0: !torch.vtensor<[2,16,33,33],bf16>, %arg1: !torch.vtensor<[2,128,64,64],bf16>, %arg2: !torch.vtensor<[16,32,2,2],bf16>) -> (!torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>) {
Does this need an e2e test? Will one of the tests from 3cebce2 error out if changed to bf16?
The current e2e tests use the decomposition for this op, so they won't encounter this logic anyway. Maybe there is a case where the decomposition fails, in which case we would be able to test this e2e. In any case, I'll try it out locally and see.
Yeah, I tried locally, and we don't even support bfloat16 in the e2e tests.
Hacking a few things through to enable bf16 testing, all of the tests report numerics mismatches against pytorch for this dtype, whether the decomposition is enabled or not. But at least I get one fewer mismatched individual tensor with these changes than without (when locally disabling the decomposition).
I wouldn't expect much else honestly. Testing lower-precision dtypes through our ref-backend against pytorch's cpu implementation seems a bit hyper-specific. IIRC pytorch cpu often accumulates to float64 instead of float32, but I'd have to double check that. If the device ends up being important for choosing accumulator dtype, we can try to push through a device-info attribute to inform accumulator type selection in the future.
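To make the accumulator-dtype point concrete, here is a small standalone Python sketch (my illustration, not part of this patch) contrasting a single f32 accumulation rounded once to bf16 with a reduction that keeps its running sum in bf16:

import torch

torch.manual_seed(0)
x = torch.randn(4096)  # f32 values

# Accumulate in f32, then round the final value to bf16 once.
rounded_once = x.sum().to(torch.bfloat16)

# Simulate an accumulator held in bf16: the running sum is rounded
# to bf16 after every addition, compounding rounding error.
acc = torch.tensor(0.0, dtype=torch.bfloat16)
for v in x:
    acc = acc + v.to(torch.bfloat16)

# Compare the two results; the running bf16 sum typically drifts.
print(rounded_once.item(), acc.item())

The gap between the two results is the kind of mismatch an e2e numerics comparison would surface.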
Signed-off-by: zjgarvey <zjgarvey@gmail.com>
a-sidorova left a comment:
@zjgarvey thank you for the quick fix!
%3 = torch.prim.ListConstruct %false, %true, %true : (!torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
%result0, %result1, %result2 = torch.aten.convolution_backward %arg0, %arg1, %arg2, %0, %1, %1, %1, %false, %2, %int4, %3 : !torch.vtensor<[2,16,33,33],bf16>, !torch.vtensor<[2,128,64,64],bf16>, !torch.vtensor<[16,32,2,2],bf16>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int, !torch.list<bool> -> !torch.none, !torch.vtensor<[16,32,2,2],bf16>, !torch.vtensor<[16],bf16>
nit: for better test coverage we could set [true, true, true] here to validate the calculation of the input gradient as well, as in the sketch below. But I don't insist.
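For reference, an eager-mode sketch of the same call with the full output mask (shapes and parameters copied from the bf16 test above; the variable names are mine, and this is illustrative rather than part of the patch):

import torch

grad_out = torch.randn(2, 16, 33, 33, dtype=torch.bfloat16)
inp = torch.randn(2, 128, 64, 64, dtype=torch.bfloat16)
weight = torch.randn(16, 32, 2, 2, dtype=torch.bfloat16)

grad_input, grad_weight, grad_bias = torch.ops.aten.convolution_backward(
    grad_out, inp, weight,
    [16],                # bias_sizes
    [2, 2],              # stride
    [2, 2],              # padding
    [2, 2],              # dilation
    False,               # transposed
    [0, 0],              # output_padding
    4,                   # groups
    [True, True, True],  # output_mask: also compute the input gradient
)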
The newer direct lowering for backward convolution accumulates directly in low-precision types like bf16. This patch adds a check against the default accumulator type: if that type doesn't match the result types for the op, the lowering upcasts and additionally introduces a downcasting elementwise op (post-convolution and pre-collapsing for groups).
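In eager-mode terms, the new behavior corresponds roughly to the following Python sketch (an assumption-laden illustration, not the actual C++ lowering; the helper name conv_backward_with_upcast and the f32 accumulator choice are mine):

import torch

def conv_backward_with_upcast(grad_out, inp, weight, stride, padding,
                              dilation, groups):
    # Upcast the low-precision operands to the assumed default
    # accumulator type before running the backward convolution.
    acc_dtype = torch.float32
    _, grad_w, grad_b = torch.ops.aten.convolution_backward(
        grad_out.to(acc_dtype), inp.to(acc_dtype), weight.to(acc_dtype),
        [weight.shape[0]],    # bias_sizes
        stride, padding, dilation,
        False,                # transposed
        [0, 0],               # output_padding
        groups,
        [False, True, True],  # weight and bias gradients only
    )
    # Downcast elementwise back to the op's result dtype, mirroring
    # the truncation op the lowering inserts after the convolution.
    return grad_w.to(grad_out.dtype), grad_b.to(grad_out.dtype)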