29 changes: 19 additions & 10 deletions backends/nxp/aten_passes/neutron_aten_pass_manager.py
@@ -36,21 +36,30 @@
PassType = type[Callable[[torch.fx.GraphModule], PassResult]]


def _get_default_passes(neutron_target_spec, qat_mode: bool = False) -> list[PassType]:
passes = [
SplitGroupConvolution(),
SplitGRUBasedOnNumLayers(),
RemoveNodesWithKnownOutputs(),
FuseLinearAndAddPass(),
MoveActivationBeforeConcat(neutron_target_spec),
]

if not qat_mode:
        # In QAT mode, fusing should happen after training,
        # to preserve the batch norm statistics update mechanism.
passes.append(FuseBatchNormWithConvPass())
passes.append(FuseBatchNormWithLinearPass())

return passes


class NeutronAtenPassManager(PassManager):

def __init__(
self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None
):
passes: list[PassType] = passes or [
FuseBatchNormWithConvPass(),
FuseBatchNormWithLinearPass(),
SplitGroupConvolution(),
SplitGRUBasedOnNumLayers(),
RemoveNodesWithKnownOutputs(),
FuseLinearAndAddPass(),
MoveActivationBeforeConcat(neutron_target_spec),
]

passes: list[PassType] = passes or _get_default_passes(neutron_target_spec)
super().__init__(passes)

def __call__(self, module: nn.Module) -> PassResult:
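For orientation, a minimal usage sketch of the refactored pass selection (a sketch, not repository code: it assumes an existing NeutronTargetSpec instance `spec` and a torch.fx.GraphModule `graph_module`):

# PTQ: the default pipeline still includes the batch-norm fusing passes.
ptq_passes = _get_default_passes(spec)

# QAT: the fusing passes are omitted so batch norm keeps updating its statistics during
# training; conv + batch norm are expected to be folded later, at convert time (see the
# new test below).
qat_passes = _get_default_passes(spec, qat_mode=True)

pass_manager = NeutronAtenPassManager(spec, qat_passes)
result = pass_manager(graph_module)    # returns a torch.fx PassResult
transformed_gm = result.graph_module   # the transformed GraphModule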
10 changes: 9 additions & 1 deletion backends/nxp/quantizer/neutron_quantizer.py
@@ -6,6 +6,7 @@

import torch
from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
_get_default_passes,
NeutronAtenPassManager,
)

@@ -17,6 +18,7 @@
AddmmPattern,
AddTensorPattern,
AvgPoolPattern,
BatchNormPattern,
CatPattern,
Conv1dPattern,
Conv2dPattern,
@@ -245,6 +247,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec, is_qat: bool = False)
OpQuantizer(AddTensorPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(AddmmPattern(self, is_qat=is_qat), static_fc_qconfig),
OpQuantizer(AvgPoolPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(BatchNormPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(CatPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv1dPattern(is_qat=is_qat), static_qconfig),
OpQuantizer(Conv2dPattern(self, is_qat=is_qat), static_qconfig),
@@ -293,7 +296,12 @@ def transform_for_annotation(
) -> torch.fx.GraphModule:
model.graph.eliminate_dead_code() # Remove dead code to simplify the graph for the passes.

model = NeutronAtenPassManager(self.neutron_target_spec)(model).graph_module
pass_manager = NeutronAtenPassManager(
self.neutron_target_spec,
_get_default_passes(self.neutron_target_spec, self.is_qat),
)

model = pass_manager(model).graph_module

model.graph.eliminate_dead_code() # Remove dead code again, in case it was created by the passes.

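Taken together with the pass-manager change above, the intended QAT flow is roughly the sketch below. The prepare/convert import path is an assumption (newer ExecuTorch code routes through the torchao pt2e equivalents), `model` and `neutron_target_spec` are placeholders, and the tests in this PR go through a local `_prepare_for_quantization` helper instead:

import torch
from torch.export import export
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e

quantizer = NeutronQuantizer(neutron_target_spec, is_qat=True)

example_inputs = (torch.randn(1, 3, 32, 32),)
exported = export(model, example_inputs, strict=True)

# transform_for_annotation() runs during prepare and, with is_qat=True, now skips the
# batch-norm fusing passes, so conv -> batch_norm pairs survive into the prepared graph.
prepared = prepare_qat_pt2e(exported.module(), quantizer)

# ... QAT training steps on `prepared` would go here ...

quantized = convert_pt2e(prepared)  # conv + batch norm are folded here; batch_norm leaves the graph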
59 changes: 57 additions & 2 deletions backends/nxp/quantizer/patterns.py
@@ -153,6 +153,27 @@ def get_anchors(
)


class BatchNormPattern(QuantizationPattern):
def __init__(self, is_qat: bool):
super().__init__(is_qat=is_qat)

def partition_types(self) -> list[OpOverload]:
# BatchNorm quantization is needed only when in QAT mode
return [torch.ops.aten.batch_norm.default] if self.is_qat else []

def get_anchors(
self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
) -> PartitionAnchors | None:
node = fused_partition[0].nodes[-1]

return PartitionAnchors(
inputs=[],
weights=[],
biases=[],
output=[(node,)],
)


def get_anchors_for_fixed_quant_specs(
fused_partition: list[fx.GraphModule],
scale: float,
@@ -356,6 +377,14 @@ def get_anchors(
)


def _is_batch_norm(node_: Node) -> bool:
return node_.op == "call_function" and node_.target in [
torch.ops.aten.batch_norm.default,
torch.ops.aten.native_batch_norm.default,
torch.ops.aten._native_batch_norm_legit_no_training.default,
]


class ConvPattern(QuantizationPattern):
@abstractmethod
def partition_types(self) -> list[OpOverload]:
@@ -398,11 +427,20 @@ def get_anchors(
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

output_specs = [(conv_node,)]
        # For QAT to be numerically correct, there must be no quantization between
        # the convolution node and the batch norm node.
if self.is_qat:
conv_users = conv_node.users
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
if possibly_bn and _is_batch_norm(possibly_bn):
output_specs = []

return PartitionAnchors(
inputs=[(conv_node, NodeArgsIdx(0))],
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
biases=bias,
output=[(conv_node,)],
output=output_specs,
)


@@ -479,6 +517,14 @@ def get_anchors(
output = []
activation.meta["quantization_annotation"].input_qspec_map = {}

        # For QAT to be numerically correct, there must be no quantization between
        # the convolution node and the batch norm node.
if self.is_qat:
conv_users = conv_node.users
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
if possibly_bn and _is_batch_norm(possibly_bn):
output = []

return PartitionAnchors(
inputs=[(conv_node, NodeArgsIdx(0))],
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
@@ -524,11 +570,20 @@ def get_anchors(
if len(conv_node.args) > 2 and conv_node.args[2] is not None:
bias = [(conv_node, NodeArgsIdx(2), bias_quantization_qspec)]

output_specs = [(conv_node,)]
        # For QAT to be numerically correct, there must be no quantization between
        # the convolution node and the batch norm node.
if self.is_qat:
conv_users = conv_node.users
possibly_bn = list(conv_users.keys())[0] if len(conv_users) == 1 else None
if possibly_bn and _is_batch_norm(possibly_bn):
output_specs = []

return PartitionAnchors(
inputs=[(conv_node, NodeArgsIdx(0))],
weights=[(conv_node, NodeArgsIdx(1), weight_quantization_spec)],
biases=bias,
output=[(conv_node,)],
output=output_specs,
)


24 changes: 24 additions & 0 deletions backends/nxp/tests/models.py
@@ -457,6 +457,30 @@ def forward(self, x):
return self.pool(x)


class ConvBNModule(torch.nn.Module):
def __init__(self, conv_module, conv_bias, bn_affine):
super().__init__()

if conv_module == "conv1d":
self.conv = torch.nn.Conv1d(3, 64, 3, padding=1, bias=conv_bias)
self.bn = torch.nn.BatchNorm1d(64, affine=bn_affine)
elif conv_module == "conv2d":
self.conv = torch.nn.Conv2d(3, 64, 3, padding=1, bias=conv_bias)
self.bn = torch.nn.BatchNorm2d(64, affine=bn_affine)
elif conv_module == "conv1d_t":
self.conv = torch.nn.ConvTranspose1d(3, 64, 3, padding=1, bias=conv_bias)
self.bn = torch.nn.BatchNorm1d(64, affine=bn_affine)
elif conv_module == "conv2d_t":
self.conv = torch.nn.ConvTranspose2d(3, 64, 3, padding=1, bias=conv_bias)
self.bn = torch.nn.BatchNorm2d(64, affine=bn_affine)
else:
raise ValueError(f"Unknown conv_module: {conv_module}")

def forward(self, x):
x = self.conv(x)
return self.bn(x)


class MulTensorModule(torch.nn.Module):
def __init__(self):
super().__init__()
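A quick sanity check of the helper above (shapes follow the new tests; the assert is illustrative):

m = ConvBNModule(conv_module="conv2d", conv_bias=True, bn_affine=True)
y = m(torch.randn(1, 3, 32, 32))
assert y.shape == (1, 64, 32, 32)  # kernel 3 with padding=1 keeps the 32x32 spatial size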
48 changes: 48 additions & 0 deletions backends/nxp/tests/test_quantizer.py
@@ -636,3 +636,51 @@ def test_qat_produces_same_graph_as_ptq():
qat_quantized_model.graph.nodes, ptq_quantized_model.graph.nodes
)
)


# TODO: conv1d_t is currently unsupported, add when resolved
@pytest.mark.parametrize("conv_module", ["conv1d", "conv2d", "conv2d_t"])
@pytest.mark.parametrize("conv_bias", [True, False])
@pytest.mark.parametrize("bn_affine", [True, False])
def test_torchao_native_conv_bn_qat_fusing(conv_module, conv_bias, bn_affine):
if not conv_bias:
pytest.skip("Conv without bias is not supported.")

if conv_module.startswith("conv1d"):
input_shape = (1, 3, 32)
elif conv_module.startswith("conv2d"):
input_shape = (1, 3, 32, 32)

model = models.ConvBNModule(
conv_module=conv_module,
conv_bias=conv_bias,
bn_affine=bn_affine,
)
model.eval()

exported_model = export(model, (torch.randn(*input_shape),), strict=True)
prepared_model = _prepare_for_quantization(exported_model, is_qat=True)
quantized_model = convert_pt2e(prepared_model)

def is_conv(node):
return node.op == "call_function" and node.target in [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
torch.ops.aten.conv_transpose2d.input,
]

graph_nodes = list(quantized_model.graph.nodes)
conv_node = next(n for n in graph_nodes if is_conv(n))
conv_node_args = conv_node.args

if len(conv_node_args) > 3:
conv_node_args = conv_node_args[:3]

assert len([n for n in graph_nodes if "batch_norm" in n.name]) == 0
assert (
len(conv_node.users) == 1
and list(conv_node.users.keys())[0].target
== torch.ops.quantized_decomposed.quantize_per_tensor.default
)
assert all(arg.name.startswith("dequantize") for arg in conv_node_args)
assert len(graph_nodes) == 15
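If these structural assertions ever need debugging, the converted graph can be inspected with a standard torch.fx utility (not part of this PR):

# The batch_norm node should be gone, and the conv node should sit between
# dequantize_per_tensor and quantize_per_tensor nodes.
quantized_model.graph.print_tabular()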
13 changes: 11 additions & 2 deletions examples/nxp/aot_neutron_compile.py
@@ -149,6 +149,13 @@ def get_model_and_inputs_from_name(model_name: str):
default=False,
help="Produce a quantized model",
)
parser.add_argument(
"--use_qat",
action="store_true",
required=False,
default=False,
help="Use QAT mode for quantization (does not include QAT training)",
    )

Reviewer comment (Collaborator): If the quantization aware training is not possible using this module, why include it? Just to show how it can be triggered? If so, perhaps a separate example module, or even just a README might be better in my opinion.

parser.add_argument(
"-s",
"--so_library",
@@ -218,8 +225,10 @@ def get_model_and_inputs_from_name(model_name: str):
"No calibration inputs available, using the example inputs instead"
)
calibration_inputs = example_inputs
quantizer = NeutronQuantizer(neutron_target_spec)
module = calibrate_and_quantize(module, calibration_inputs, quantizer)
quantizer = NeutronQuantizer(neutron_target_spec, args.use_qat)
module = calibrate_and_quantize(
module, calibration_inputs, quantizer, is_qat=args.use_qat
)

if args.so_library is not None:
logging.debug(f"Loading libraries: {args.so_library}")