From 06d90d481d84da7bd08a248080f86f3d29c15629 Mon Sep 17 00:00:00 2001 From: Rob Elliott Date: Thu, 18 Dec 2025 11:03:15 +0000 Subject: [PATCH 1/2] Arm backend: Add log1p decomposition * Decompose log1p and mark as supported in TOSA FP profile * Add log1p operator tests with data around the epsilon value * Add log1p to table-op LUT mapping for quantized path Change-Id: I34695a95e76e1291f0ccd52df8b7254260e74240 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + backends/arm/_passes/decompose_log1p_pass.py | 97 +++++++++++++++++++ backends/arm/_passes/insert_table_ops.py | 1 + .../tosa_profile_supported_op_lists.py | 4 +- .../arm/quantizer/quantization_annotator.py | 1 + backends/arm/test/ops/test_log1p.py | 78 +++++++++++++++ 7 files changed, 183 insertions(+), 1 deletion(-) create mode 100644 backends/arm/_passes/decompose_log1p_pass.py create mode 100644 backends/arm/test/ops/test_log1p.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 2904e64a658..798c25c9ee0 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -56,6 +56,7 @@ from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa from .decompose_linalg_vector_norm_pass import DecomposeLinalgVectorNormPass # noqa from .decompose_linear_pass import DecomposeLinearPass # noqa +from .decompose_log1p_pass import DecomposeLog1pPass # noqa from .decompose_logit_pass import DecomposeLogitPass # noqa from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index dc418f18d27..7c77d779cfa 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -59,6 +59,7 @@ DecomposeLeakyReLUPass, DecomposeLinalgVectorNormPass, DecomposeLinearPass, + DecomposeLog1pPass, DecomposeLogitPass, DecomposeMaskedFillPass, DecomposeMaxPool2dPass, @@ -268,6 +269,7 @@ def _tosa_pipeline( DecomposeEluPass(), DecomposeExpm1Pass(), DecomposeIntPowPass(), + DecomposeLog1pPass(), PromoteBoolOperandsPass(), DecomposeSinhPass(), DecomposeSignPass(), diff --git a/backends/arm/_passes/decompose_log1p_pass.py b/backends/arm/_passes/decompose_log1p_pass.py new file mode 100644 index 00000000000..4531d66c560 --- /dev/null +++ b/backends/arm/_passes/decompose_log1p_pass.py @@ -0,0 +1,97 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Set, Type + +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass +from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass +from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass +from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import ( + ReplaceScalarWithTensorByProfilePass, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass + + +class DecomposeLog1pPass(ArmPass): + """Decompose log1p into a small polynomial with a log fallback for larger inputs.""" + + _passes_required_after: Set[Type[ExportPass]] = { + InsertTableOpsPass, + MatchArgRanksPass, + MatchArgDtypePass, + ReplaceScalarWithTensorByProfilePass, + } + + _supported_ops = { + exir_ops.edge.aten.log1p.default, + } + + def _poly(self, x, meta): + # 6-term Taylor: x - x^2/2 + x^3/3 - x^4/4 + x^5/5 - x^6/6 + op_mul = exir_ops.edge.aten.mul.Tensor + op_mul_scalar = exir_ops.edge.aten.mul.Scalar + op_add = exir_ops.edge.aten.add.Tensor + + x2 = super().call_operator(op_mul, (x, x), {}, meta, updated=True) + x3 = super().call_operator(op_mul, (x2, x), {}, meta, updated=True) + x4 = super().call_operator(op_mul, (x3, x), {}, meta, updated=True) + x5 = super().call_operator(op_mul, (x4, x), {}, meta, updated=True) + x6 = super().call_operator(op_mul, (x5, x), {}, meta, updated=True) + + t2 = super().call_operator(op_mul_scalar, (x2, -0.5), {}, meta, updated=True) + t3 = super().call_operator( + op_mul_scalar, (x3, 1.0 / 3.0), {}, meta, updated=True + ) + t4 = super().call_operator(op_mul_scalar, (x4, -0.25), {}, meta, updated=True) + t5 = super().call_operator(op_mul_scalar, (x5, 0.2), {}, meta, updated=True) + t6 = super().call_operator( + op_mul_scalar, (x6, -1.0 / 6.0), {}, meta, updated=True + ) + + acc = super().call_operator(op_add, (x, t2), {}, meta, updated=True) + acc = super().call_operator(op_add, (acc, t3), {}, meta, updated=True) + acc = super().call_operator(op_add, (acc, t4), {}, meta, updated=True) + acc = super().call_operator(op_add, (acc, t5), {}, meta, updated=True) + acc = super().call_operator(op_add, (acc, t6), {}, meta, updated=True) + return acc + + def call_operator(self, op, args, kwargs, meta): + if op not in self._supported_ops: + return super().call_operator(op, args, kwargs, meta, updated=False) + + is_quantized = ( + len(meta.data.get("input_qparams", {})) > 0 + and len(meta.data.get("output_qparams", {})) > 0 + ) + if is_quantized: + # Quantized log1p should be handled by LUT/table instead of decomposition. + return super().call_operator(op, args, kwargs, meta) + + logging.info("Decomposing log1p via polynomial + log branch for FP profile.") + + x = args[0] + approx = self._poly(x, meta) + + # For x > 1e-3, use log(1+x) directly. + op_add_scalar = exir_ops.edge.aten.add.Scalar + op_gt = exir_ops.edge.aten.gt.Scalar + op_where = exir_ops.edge.aten.where.self + op_log = exir_ops.edge.aten.log.default + + one_plus_x = super().call_operator( + op_add_scalar, (x, 1.0), {}, meta, updated=True + ) + log_branch = super().call_operator( + op_log, (one_plus_x,), {}, meta, updated=True + ) + + mask = super().call_operator(op_gt, (x, 1e-3), {}, meta, updated=True) + result = super().call_operator( + op_where, (mask, log_branch, approx), {}, meta, updated=True + ) + return result diff --git a/backends/arm/_passes/insert_table_ops.py b/backends/arm/_passes/insert_table_ops.py index 27de85e5ba9..a342b9d56b6 100644 --- a/backends/arm/_passes/insert_table_ops.py +++ b/backends/arm/_passes/insert_table_ops.py @@ -37,6 +37,7 @@ class TableOps: exir_ops.edge.aten.expm1.default: torch.expm1, exir_ops.edge.aten.floor.default: torch.floor, exir_ops.edge.aten.log.default: torch.log, + exir_ops.edge.aten.log1p.default: torch.log1p, exir_ops.edge.aten.reciprocal.default: torch.reciprocal, exir_ops.edge.aten.rsqrt.default: torch.rsqrt, exir_ops.edge.aten.sigmoid.default: torch.sigmoid, diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index f4f72690345..22b89b1bfe9 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025 Arm LimitAed and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -54,6 +54,7 @@ exir_ops.edge.aten.exp.default, exir_ops.edge.aten.expm1.default, exir_ops.edge.aten.log.default, + exir_ops.edge.aten.log1p.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, exir_ops.edge.aten.split_copy.Tensor, @@ -154,6 +155,7 @@ exir_ops.edge.aten.erf.default, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.expm1.default, + exir_ops.edge.aten.log1p.default, exir_ops.edge.aten.log.default, exir_ops.edge.aten.linear.default, exir_ops.edge.aten.split_with_sizes_copy.default, diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 60f739c09ad..66e799aadc4 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -369,6 +369,7 @@ def _match_pattern( torch.ops.aten.gelu.default, torch.ops.aten.sinh.default, torch.ops.aten.atan.default, + torch.ops.aten.log1p.default, torch.ops.aten.acosh.default, torch.ops.aten.sign.default, torch.ops.aten.asin.default, diff --git a/backends/arm/test/ops/test_log1p.py b/backends/arm/test/ops/test_log1p.py new file mode 100644 index 00000000000..5455c428be1 --- /dev/null +++ b/backends/arm/test/ops/test_log1p.py @@ -0,0 +1,78 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import ( + TosaPipelineFP, + TosaPipelineINT, + VgfPipeline, +) + +aten_op = "torch.ops.aten.log1p.default" +exir_op = "executorch_exir_dialects_edge__ops_aten_log1p_default" + +input_t1 = Tuple[torch.Tensor] + + +def _tensor(values): + return torch.tensor(values, dtype=torch.float32) + + +test_data_suite = { + "tiny_positive": lambda: (_tensor([5e-4, 8e-4, 9e-4, 1e-3, 1.2e-3]),), + "straddle_eps": lambda: (_tensor([5e-4, 1e-3, 2e-3, -5e-4, -1e-3]),), + "mixed_range": lambda: (_tensor([1e-4, 5e-4, 2e-3, 1e-2, 5e-2]),), +} + + +class Log1p(torch.nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.log1p(x) + + +@common.parametrize("test_data", test_data_suite) +def test_log1p_tosa_FP(test_data: input_t1): + pipeline = TosaPipelineFP[input_t1](Log1p(), test_data(), aten_op, exir_op) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +def test_log1p_tosa_INT(test_data: input_t1): + pipeline = TosaPipelineINT[input_t1]( + Log1p(), + test_data(), + aten_op, + exir_op, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_atan_vgf_no_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + Log1p(), + test_data(), + aten_op, + exir_op, + quantize=False, + ) + pipeline.run() + + +@common.parametrize("test_data", test_data_suite) +@common.SkipIfNoModelConverter +def test_atan_vgf_quant(test_data: Tuple): + pipeline = VgfPipeline[input_t1]( + Log1p(), + test_data(), + aten_op, + exir_op, + quantize=True, + ) + pipeline.run() From aeaf9d7c5c7d7056128c494fb0e5196b2517a7e5 Mon Sep 17 00:00:00 2001 From: Rob Elliott Date: Thu, 18 Dec 2025 15:07:10 +0000 Subject: [PATCH 2/2] minor review fixes Signed-off-by: Rob Elliott --- .../arm/operator_support/tosa_profile_supported_op_lists.py | 2 +- backends/arm/test/ops/test_log1p.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/arm/operator_support/tosa_profile_supported_op_lists.py b/backends/arm/operator_support/tosa_profile_supported_op_lists.py index 22b89b1bfe9..e74d0a98e37 100644 --- a/backends/arm/operator_support/tosa_profile_supported_op_lists.py +++ b/backends/arm/operator_support/tosa_profile_supported_op_lists.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm LimitAed and/or its affiliates. +# Copyright 2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. diff --git a/backends/arm/test/ops/test_log1p.py b/backends/arm/test/ops/test_log1p.py index 5455c428be1..a8ca4a9421b 100644 --- a/backends/arm/test/ops/test_log1p.py +++ b/backends/arm/test/ops/test_log1p.py @@ -54,7 +54,7 @@ def test_log1p_tosa_INT(test_data: input_t1): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_atan_vgf_no_quant(test_data: Tuple): +def test_log1p_vgf_no_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Log1p(), test_data(), @@ -67,7 +67,7 @@ def test_atan_vgf_no_quant(test_data: Tuple): @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter -def test_atan_vgf_quant(test_data: Tuple): +def test_log1p_vgf_quant(test_data: Tuple): pipeline = VgfPipeline[input_t1]( Log1p(), test_data(),