1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
@@ -56,6 +56,7 @@
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
from .decompose_linalg_vector_norm_pass import DecomposeLinalgVectorNormPass # noqa
from .decompose_linear_pass import DecomposeLinearPass # noqa
from .decompose_log1p_pass import DecomposeLog1pPass # noqa
from .decompose_logit_pass import DecomposeLogitPass # noqa
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -59,6 +59,7 @@
    DecomposeLeakyReLUPass,
    DecomposeLinalgVectorNormPass,
    DecomposeLinearPass,
    DecomposeLog1pPass,
    DecomposeLogitPass,
    DecomposeMaskedFillPass,
    DecomposeMaxPool2dPass,
@@ -268,6 +269,7 @@ def _tosa_pipeline(
        DecomposeEluPass(),
        DecomposeExpm1Pass(),
        DecomposeIntPowPass(),
        DecomposeLog1pPass(),
        PromoteBoolOperandsPass(),
        DecomposeSinhPass(),
        DecomposeSignPass(),
97 changes: 97 additions & 0 deletions backends/arm/_passes/decompose_log1p_pass.py
@@ -0,0 +1,97 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Set, Type

from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.match_arg_dtype_pass import MatchArgDtypePass
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.replace_scalar_with_tensor_pass import (
    ReplaceScalarWithTensorByProfilePass,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

logger = logging.getLogger(__name__)


class DecomposeLog1pPass(ArmPass):
    """Decompose log1p for the FP profile: a 6-term Taylor polynomial for
    small inputs, with a log(1 + x) fallback for x > 1e-3."""

    _passes_required_after: Set[Type[ExportPass]] = {
        InsertTableOpsPass,
        MatchArgRanksPass,
        MatchArgDtypePass,
        ReplaceScalarWithTensorByProfilePass,
    }

    _supported_ops = {
        exir_ops.edge.aten.log1p.default,
    }

    def _poly(self, x, meta):
        # 6-term Taylor expansion of log(1 + x) around 0:
        # x - x^2/2 + x^3/3 - x^4/4 + x^5/5 - x^6/6 (accurate for |x| << 1).
        op_mul = exir_ops.edge.aten.mul.Tensor
        op_mul_scalar = exir_ops.edge.aten.mul.Scalar
        op_add = exir_ops.edge.aten.add.Tensor

        x2 = super().call_operator(op_mul, (x, x), {}, meta, updated=True)
        x3 = super().call_operator(op_mul, (x2, x), {}, meta, updated=True)
        x4 = super().call_operator(op_mul, (x3, x), {}, meta, updated=True)
        x5 = super().call_operator(op_mul, (x4, x), {}, meta, updated=True)
        x6 = super().call_operator(op_mul, (x5, x), {}, meta, updated=True)

        t2 = super().call_operator(op_mul_scalar, (x2, -0.5), {}, meta, updated=True)
        t3 = super().call_operator(
            op_mul_scalar, (x3, 1.0 / 3.0), {}, meta, updated=True
        )
        t4 = super().call_operator(op_mul_scalar, (x4, -0.25), {}, meta, updated=True)
        t5 = super().call_operator(op_mul_scalar, (x5, 0.2), {}, meta, updated=True)
        t6 = super().call_operator(
            op_mul_scalar, (x6, -1.0 / 6.0), {}, meta, updated=True
        )

        acc = super().call_operator(op_add, (x, t2), {}, meta, updated=True)
        acc = super().call_operator(op_add, (acc, t3), {}, meta, updated=True)
        acc = super().call_operator(op_add, (acc, t4), {}, meta, updated=True)
        acc = super().call_operator(op_add, (acc, t5), {}, meta, updated=True)
        acc = super().call_operator(op_add, (acc, t6), {}, meta, updated=True)
        return acc

    def call_operator(self, op, args, kwargs, meta):
        if op not in self._supported_ops:
            return super().call_operator(op, args, kwargs, meta, updated=False)

        is_quantized = (
            len(meta.data.get("input_qparams", {})) > 0
            and len(meta.data.get("output_qparams", {})) > 0
        )
        if is_quantized:
            # Quantized log1p is handled by InsertTableOpsPass (LUT) instead
            # of this decomposition.
            return super().call_operator(op, args, kwargs, meta)

        logger.info("Decomposing log1p via polynomial + log branch for FP profile.")

        x = args[0]
        approx = self._poly(x, meta)

        # For x > 1e-3, 1 + x is representable accurately enough in fp32 to
        # compute log(1 + x) directly; below the threshold, the polynomial
        # avoids the cancellation that makes naive log(1 + x) inaccurate.
        op_add_scalar = exir_ops.edge.aten.add.Scalar
        op_gt = exir_ops.edge.aten.gt.Scalar
        op_where = exir_ops.edge.aten.where.self
        op_log = exir_ops.edge.aten.log.default

        one_plus_x = super().call_operator(
            op_add_scalar, (x, 1.0), {}, meta, updated=True
        )
        log_branch = super().call_operator(
            op_log, (one_plus_x,), {}, meta, updated=True
        )

        mask = super().call_operator(op_gt, (x, 1e-3), {}, meta, updated=True)
        result = super().call_operator(
            op_where, (mask, log_branch, approx), {}, meta, updated=True
        )
        return result
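Outside the compiler pipeline, the numerics of this decomposition are easy to sanity-check against torch.log1p. The following standalone sketch (not part of the change; log1p_decomposed is an illustrative name) mirrors the 6-term polynomial, the 1e-3 threshold, and the where-select from the pass above:

import torch

def log1p_decomposed(x: torch.Tensor) -> torch.Tensor:
    # Same 6-term Taylor polynomial as DecomposeLog1pPass._poly.
    x2 = x * x
    x3 = x2 * x
    x4 = x3 * x
    x5 = x4 * x
    x6 = x5 * x
    poly = x - 0.5 * x2 + x3 / 3.0 - 0.25 * x4 + 0.2 * x5 - x6 / 6.0
    # Same select as the pass: log(1 + x) above the threshold, else polynomial.
    return torch.where(x > 1e-3, torch.log(1.0 + x), poly)

x = torch.tensor([5e-4, 1e-3, 2e-3, -5e-4, -1e-3], dtype=torch.float32)
print((log1p_decomposed(x) - torch.log1p(x)).abs().max())  # expect float32 rounding-level error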
1 change: 1 addition & 0 deletions backends/arm/_passes/insert_table_ops.py
@@ -37,6 +37,7 @@ class TableOps:
        exir_ops.edge.aten.expm1.default: torch.expm1,
        exir_ops.edge.aten.floor.default: torch.floor,
        exir_ops.edge.aten.log.default: torch.log,
        exir_ops.edge.aten.log1p.default: torch.log1p,
        exir_ops.edge.aten.reciprocal.default: torch.reciprocal,
        exir_ops.edge.aten.rsqrt.default: torch.rsqrt,
        exir_ops.edge.aten.sigmoid.default: torch.sigmoid,
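With this entry registered, quantized log1p is lowered by InsertTableOpsPass as a lookup table instead of the polynomial decomposition. A minimal sketch of the idea, assuming int8 quantization with affine scale/zero-point (build_log1p_table and its parameters are illustrative names, not the backend's actual table-builder API): enumerate every possible quantized input, dequantize, apply torch.log1p, and requantize.

import torch

def build_log1p_table(in_scale: float, in_zp: int,
                      out_scale: float, out_zp: int) -> torch.Tensor:
    # Enumerate all 256 possible int8 input values.
    q_in = torch.arange(-128, 128, dtype=torch.int32)
    # Dequantize, apply the reference op, requantize.
    x = (q_in - in_zp).to(torch.float32) * in_scale
    y = torch.log1p(x)  # assumes the dequantized range keeps 1 + x > 0
    q_out = torch.round(y / out_scale) + out_zp
    return q_out.clamp(-128, 127).to(torch.int8)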
@@ -54,6 +54,7 @@
    exir_ops.edge.aten.exp.default,
    exir_ops.edge.aten.expm1.default,
    exir_ops.edge.aten.log.default,
    exir_ops.edge.aten.log1p.default,
    exir_ops.edge.aten.linear.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.split_copy.Tensor,
@@ -154,6 +155,7 @@
    exir_ops.edge.aten.erf.default,
    exir_ops.edge.aten.exp.default,
    exir_ops.edge.aten.expm1.default,
    exir_ops.edge.aten.log1p.default,
    exir_ops.edge.aten.log.default,
    exir_ops.edge.aten.linear.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
1 change: 1 addition & 0 deletions backends/arm/quantizer/quantization_annotator.py
@@ -369,6 +369,7 @@ def _match_pattern(
    torch.ops.aten.gelu.default,
    torch.ops.aten.sinh.default,
    torch.ops.aten.atan.default,
    torch.ops.aten.log1p.default,
    torch.ops.aten.acosh.default,
    torch.ops.aten.sign.default,
    torch.ops.aten.asin.default,
78 changes: 78 additions & 0 deletions backends/arm/test/ops/test_log1p.py
@@ -0,0 +1,78 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
    TosaPipelineFP,
    TosaPipelineINT,
    VgfPipeline,
)

aten_op = "torch.ops.aten.log1p.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_log1p_default"

input_t1 = Tuple[torch.Tensor]


def _tensor(values):
    return torch.tensor(values, dtype=torch.float32)


test_data_suite = {
    "tiny_positive": lambda: (_tensor([5e-4, 8e-4, 9e-4, 1e-3, 1.2e-3]),),
    "straddle_eps": lambda: (_tensor([5e-4, 1e-3, 2e-3, -5e-4, -1e-3]),),
    "mixed_range": lambda: (_tensor([1e-4, 5e-4, 2e-3, 1e-2, 5e-2]),),
}


class Log1p(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.log1p(x)


@common.parametrize("test_data", test_data_suite)
def test_log1p_tosa_FP(test_data: input_t1):
    pipeline = TosaPipelineFP[input_t1](Log1p(), test_data(), aten_op, exir_op)
    pipeline.run()


@common.parametrize("test_data", test_data_suite)
def test_log1p_tosa_INT(test_data: input_t1):
    pipeline = TosaPipelineINT[input_t1](
        Log1p(),
        test_data(),
        aten_op,
        exir_op,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_log1p_vgf_no_quant(test_data: input_t1):
    pipeline = VgfPipeline[input_t1](
        Log1p(),
        test_data(),
        aten_op,
        exir_op,
        quantize=False,
    )
    pipeline.run()


@common.parametrize("test_data", test_data_suite)
@common.SkipIfNoModelConverter
def test_log1p_vgf_quant(test_data: input_t1):
    pipeline = VgfPipeline[input_t1](
        Log1p(),
        test_data(),
        aten_op,
        exir_op,
        quantize=True,
    )
    pipeline.run()
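Note that "straddle_eps" deliberately places values on both sides of the pass's 1e-3 threshold, so both the polynomial and the log branch are exercised. Assuming a working ExecuTorch Arm test environment, the suite should be runnable directly with pytest (invocation illustrative):

pytest backends/arm/test/ops/test_log1p.py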