@@ -636,3 +636,55 @@ def test_qat_produces_same_graph_as_ptq():
             qat_quantized_model.graph.nodes, ptq_quantized_model.graph.nodes
         )
     )
+
+
+@pytest.mark.parametrize("conv_module", ["conv1d", "conv2d", "conv2d_t"])
+@pytest.mark.parametrize("conv_bias", [True, False])
+@pytest.mark.parametrize("bn_affine", [True, False])
+def test_torchao_native_conv_bn_qat_fusing(conv_module, conv_bias, bn_affine):
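+    """Conv + BN prepared for QAT should be folded into a single conv by convert_pt2e."""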
+    if not conv_bias and bn_affine:
+        pytest.skip("Conv without bias is only supported if BN has no affine.")
+
+    if conv_module.startswith("conv1d"):
+        input_shape = (1, 3, 32)
+    elif conv_module.startswith("conv2d"):
+        input_shape = (1, 3, 32, 32)
+
+    model = models.ConvBNModule(
+        conv_module=conv_module,
+        conv_bias=conv_bias,
+        bn_affine=bn_affine,
+    )
+    model.eval()
+
+    exported_model = export(model, (torch.randn(*input_shape),), strict=True)
+    prepared_model = _prepare_for_quantization(exported_model, is_qat=True)
+    quantized_model = convert_pt2e(prepared_model)
+
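+    # Matches every conv flavor the parametrization can produce.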
+    def is_conv(node):
+        return node.op == "call_function" and node.target in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv_transpose2d.input,
+        ]
+
+    graph_nodes = list(quantized_model.graph.nodes)
+    conv_node = next(n for n in graph_nodes if is_conv(n))
+    conv_node_args = conv_node.args
+
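+    # Only input, weight, and bias matter here; drop trailing args such as stride and padding.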
+    if len(conv_node_args) > 3:
+        conv_node_args = conv_node_args[:3]
+
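+    # BN must be fully folded away: no batch_norm nodes remain, and the conv's
+    # only consumer is the output quantize node.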
+    assert len([n for n in graph_nodes if "batch_norm" in n.name]) == 0
+    assert (
+        len(conv_node.users) == 1
+        and list(conv_node.users.keys())[0].target
+        == torch.ops.quantized_decomposed.quantize_per_tensor.default
+    )
+    assert all(arg.name.startswith("dequantize") for arg in conv_node_args)
+    assert len(graph_nodes) == 15  # total node count of the fully fused, quantized graph
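
Note: models.ConvBNModule and _prepare_for_quantization are project helpers defined elsewhere in this test file, so they do not appear in the hunk. Purely for context, here is a minimal sketch of what ConvBNModule might look like, inferred only from how the test calls it; the channel counts and kernel size are illustrative assumptions, not the repository's actual values:

import torch.nn as nn


class ConvBNModule(nn.Module):
    # Hypothetical sketch: one conv (1d / 2d / transposed 2d) followed by the
    # matching BatchNorm, mirroring the test's constructor arguments.
    def __init__(self, conv_module: str, conv_bias: bool, bn_affine: bool):
        super().__init__()
        if conv_module == "conv1d":
            self.conv = nn.Conv1d(3, 8, kernel_size=3, bias=conv_bias)
            self.bn = nn.BatchNorm1d(8, affine=bn_affine)
        elif conv_module == "conv2d":
            self.conv = nn.Conv2d(3, 8, kernel_size=3, bias=conv_bias)
            self.bn = nn.BatchNorm2d(8, affine=bn_affine)
        elif conv_module == "conv2d_t":
            self.conv = nn.ConvTranspose2d(3, 8, kernel_size=3, bias=conv_bias)
            self.bn = nn.BatchNorm2d(8, affine=bn_affine)
        else:
            raise ValueError(f"unknown conv_module: {conv_module}")

    def forward(self, x):
        return self.bn(self.conv(x))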