diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index ae346ec92a3c..89be2d3ff90d 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit ae346ec92a3c386f1376064ae086aae72947c329 +Subproject commit 89be2d3ff90d021355975a52d72e5eefc02bc335 diff --git a/python/tvm/arith/_ffi_api.py b/python/tvm/arith/_ffi_api.py index 519423aa4e1f..12ddd90af574 100644 --- a/python/tvm/arith/_ffi_api.py +++ b/python/tvm/arith/_ffi_api.py @@ -15,7 +15,71 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.arith""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from arith import ConstIntBound, IntConstraints, IntConstraintsTransform, IntGroupBounds, IntervalSet, IterMark, IterSplitExpr, IterSumExpr, ModularSet + from collections.abc import Mapping, Sequence + from ir import IntImm, IntSet, PrimExpr, Range + from tir import Buffer, PrimFunc, Stmt, Var + from tvm_ffi import Object + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("arith", __name__) + +# tvm-ffi-stubgen(begin): global/arith +# fmt: off +_FFI_INIT_FUNC("arith", __name__) +if TYPE_CHECKING: + def ConstIntBound(_0: int, _1: int, /) -> ConstIntBound: ... + def CreateAnalyzer(*args: Any) -> Any: ... + def DeduceBound(_0: PrimExpr, _1: PrimExpr, _2: Mapping[Var, IntSet], _3: Mapping[Var, IntSet], /) -> IntSet: ... + def DetectClipBound(_0: PrimExpr, _1: Sequence[Var], /) -> Sequence[PrimExpr]: ... + def DetectCommonSubExpr(_0: PrimExpr, _1: int, /) -> Mapping[PrimExpr, IntImm]: ... + def DetectIterMap(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: PrimExpr, _3: int, _4: bool, /) -> Object: ... + def DetectLinearEquation(_0: PrimExpr, _1: Sequence[Var], /) -> Sequence[PrimExpr]: ... + def DomainTouched(_0: Stmt, _1: Buffer, _2: bool, _3: bool, /) -> Sequence[Range]: ... + def DomainTouchedAccessMap(_0: PrimFunc, /) -> Mapping[Buffer, Sequence[Object]]: ... + def EstimateRegionLowerBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ... + def EstimateRegionStrictBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ... + def EstimateRegionUpperBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ... + def IntConstraints(_0: Sequence[Var], _1: Mapping[Var, Range], _2: Sequence[PrimExpr], /) -> IntConstraints: ... + def IntConstraintsTransform(_0: IntConstraints, _1: IntConstraints, _2: Mapping[Var, PrimExpr], _3: Mapping[Var, PrimExpr], /) -> IntConstraintsTransform: ... + def IntGroupBounds(_0: PrimExpr, _1: Sequence[PrimExpr], _2: Sequence[PrimExpr], _3: Sequence[PrimExpr], /) -> IntGroupBounds: ... + def IntGroupBounds_FindBestRange(*args: Any) -> Any: ... + def IntGroupBounds_from_range(_0: Range, /) -> IntGroupBounds: ... + def IntSetIsEverything(_0: IntSet, /) -> bool: ... + def IntSetIsNothing(_0: IntSet, /) -> bool: ... + def IntervalSet(_0: PrimExpr, _1: PrimExpr, /) -> IntervalSet: ... + def IntervalSetGetMax(_0: IntSet, /) -> PrimExpr: ... + def IntervalSetGetMin(_0: IntSet, /) -> PrimExpr: ... + def InverseAffineIterMap(_0: Sequence[IterSumExpr], _1: Sequence[PrimExpr], /) -> Mapping[Var, PrimExpr]: ... + def IterMapSimplify(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: PrimExpr, _3: int, _4: bool, /) -> Sequence[PrimExpr]: ... + def IterMark(_0: PrimExpr, _1: PrimExpr, /) -> IterMark: ... + def IterSplitExpr(_0: IterMark, _1: PrimExpr, _2: PrimExpr, _3: PrimExpr, /) -> IterSplitExpr: ... + def IterSumExpr(_0: Sequence[IterSplitExpr], _1: PrimExpr, /) -> IterSumExpr: ... + def ModularSet(_0: int, _1: int, /) -> ModularSet: ... + def NarrowPredicateExpression(_0: PrimExpr, _1: Mapping[Var, Range], /) -> PrimExpr: ... + def NegInf() -> PrimExpr: ... + def NormalizeIterMapToExpr(_0: PrimExpr, /) -> PrimExpr: ... + def NormalizeToIterSum(_0: PrimExpr, _1: Mapping[Var, Range], /) -> IterSumExpr: ... + def PosInf() -> PrimExpr: ... + def PresburgerSet(_0: PrimExpr, /) -> IntSet: ... + def SolveInequalitiesAsCondition(*args: Any) -> Any: ... + def SolveInequalitiesDeskewRange(*args: Any) -> Any: ... + def SolveInequalitiesToRange(*args: Any) -> Any: ... + def SolveLinearEquations(*args: Any) -> Any: ... + def SubspaceDivide(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: Sequence[Var], _3: PrimExpr, _4: int, _5: bool, /) -> Sequence[Sequence[IterMark]]: ... + def UnionLowerBound(_0: Sequence[IntSet], /) -> IntSet: ... + def intset_interval(_0: PrimExpr, _1: PrimExpr, /) -> IntSet: ... + def intset_single_point(_0: PrimExpr, /) -> IntSet: ... + def intset_vector(_0: PrimExpr, /) -> IntSet: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index c5c8fc067cc8..ec57b263c598 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name """Arithmetic data structure and utility""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import enum from typing import Union @@ -51,6 +53,13 @@ class Extension(enum.Flag): class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z""" + # tvm-ffi-stubgen(begin): object/arith.ModularSet + # fmt: off + coeff: int + base: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, coeff, base): self.__init_handle_by_constructor__(_ffi_api.ModularSet, coeff, base) @@ -68,6 +77,13 @@ class ConstIntBound(Object): The maximum value of the bound. """ + # tvm-ffi-stubgen(begin): object/arith.ConstIntBound + # fmt: off + min_value: int + max_value: int + # fmt: on + # tvm-ffi-stubgen(end) + POS_INF = (1 << 63) - 1 NEG_INF = -POS_INF diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index fc6c20dec1ce..2b865801d509 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. """Integer set.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import PrimExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.runtime import Object from . import _ffi_api @@ -24,6 +34,11 @@ class IntSet(Object): """Represent a set of integer in one dimension.""" + # tvm-ffi-stubgen(begin): object/ir.IntSet + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def is_nothing(self): """Whether the set represent nothing""" return _ffi_api.IntSetIsNothing(self) @@ -78,6 +93,13 @@ class IntervalSet(IntSet): The maximum value in the interval. """ + # tvm-ffi-stubgen(begin): object/arith.IntervalSet + # fmt: off + min_value: PrimExpr + max_value: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, min_value, max_value): self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value) @@ -86,6 +108,11 @@ def __init__(self, min_value, max_value): class PresburgerSet(IntSet): """Represent of Presburger Set""" + # tvm-ffi-stubgen(begin): object/arith.PresburgerSet + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): self.__init_handle_by_constructor__(_ffi_api.PresburgerSet) diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index 72e4c46896ff..c7f7cafb02e2 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """integer constraints data structures and solvers""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import PrimExpr, Range + from tir import Var +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.runtime import Object from . import _ffi_api @@ -40,6 +52,15 @@ class IntGroupBounds(Object): the upper bounds (include) """ + # tvm-ffi-stubgen(begin): object/arith.IntGroupBounds + # fmt: off + coef: PrimExpr + lower: Sequence[PrimExpr] + equal: Sequence[PrimExpr] + upper: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, coef, lower, equal, upper): self.__init_handle_by_constructor__(_ffi_api.IntGroupBounds, coef, lower, equal, upper) @@ -81,6 +102,14 @@ class IntConstraints(Object): The relations between the variables (either equations or inequalities) """ + # tvm-ffi-stubgen(begin): object/arith.IntConstraints + # fmt: off + variables: Sequence[Var] + ranges: Mapping[Var, Range] + relations: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, variables, ranges, relations): self.__init_handle_by_constructor__(_ffi_api.IntConstraints, variables, ranges, relations) @@ -113,6 +142,15 @@ class IntConstraintsTransform(Object): e.g., {m -> a, n -> -b} """ + # tvm-ffi-stubgen(begin): object/arith.IntConstraintsTransform + # fmt: off + src: IntConstraints + dst: IntConstraints + src_to_dst: Mapping[Var, PrimExpr] + dst_to_src: Mapping[Var, PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, src, dst, src_to_dst, dst_to_src): self.__init_handle_by_constructor__( _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 69ad3022fb4a..4dd5a658cbb5 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Iterator (quasi)affine mapping patterns.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import PrimExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from enum import IntEnum import tvm_ffi from tvm.runtime import Object @@ -26,6 +37,11 @@ class IterMapExpr(PrimExpr): """Base class of all IterMap expressions.""" + # tvm-ffi-stubgen(begin): object/arith.IterMapExpr + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("arith.IterMark") class IterMark(Object): @@ -40,6 +56,13 @@ class IterMark(Object): The extent of the iterator. """ + # tvm-ffi-stubgen(begin): object/arith.IterMark + # fmt: off + source: PrimExpr + extent: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, source, extent): self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent) @@ -65,6 +88,15 @@ class IterSplitExpr(IterMapExpr): Additional scale to the split. """ + # tvm-ffi-stubgen(begin): object/arith.IterSplitExpr + # fmt: off + source: IterMark + lower_factor: PrimExpr + extent: PrimExpr + scale: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, source, lower_factor, extent, scale): self.__init_handle_by_constructor__( _ffi_api.IterSplitExpr, source, lower_factor, extent, scale @@ -86,6 +118,13 @@ class IterSumExpr(IterMapExpr): The base offset. """ + # tvm-ffi-stubgen(begin): object/arith.IterSumExpr + # fmt: off + args: Sequence[IterSplitExpr] + base: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) @@ -94,6 +133,14 @@ def __init__(self, args, base): class IterMapResult(Object): """Result of iter map detection.""" + # tvm-ffi-stubgen(begin): object/arith.IterMapResult + # fmt: off + indices: Sequence[IterSumExpr] + errors: Sequence[str] + padding_predicate: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + class IterMapLevel(IntEnum): """Possible kinds of iter mapping check level.""" diff --git a/python/tvm/contrib/cutlass/_ffi_api.py b/python/tvm/contrib/cutlass/_ffi_api.py index d57825835b6b..d6824bf02101 100644 --- a/python/tvm/contrib/cutlass/_ffi_api.py +++ b/python/tvm/contrib/cutlass/_ffi_api.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """FFI API for CUTLASS BYOC.""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("contrib.cutlass", __name__) + +# tvm-ffi-stubgen(begin): global/contrib.cutlass +_FFI_INIT_FUNC("contrib.cutlass", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/core/_ffi_api.py b/python/tvm/contrib/msc/core/_ffi_api.py index ff027a0dec8e..e2446770e2db 100644 --- a/python/tvm/contrib/msc/core/_ffi_api.py +++ b/python/tvm/contrib/msc/core/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.core._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.core", __name__) +# tvm-ffi-stubgen(begin): global/msc.core +_FFI_INIT_FUNC("msc.core", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 6b40be4bf9de..c0cb4adf5eb6 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -16,6 +16,8 @@ # under the License. """tvm.contrib.msc.core.ir.graph""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Dict, Tuple, List, Optional, Union, Iterable, Any import numpy as np import tvm_ffi @@ -46,6 +48,9 @@ class MSCTensor(Object): The prims of the tensor. """ + # tvm-ffi-stubgen(begin): object/msc.core.MSCTensor + # tvm-ffi-stubgen(end) + def __init__( self, name: str, @@ -199,6 +204,9 @@ def ndim(self) -> int: class BaseJoint(Object): """Base class of all MSC Nodes.""" + # tvm-ffi-stubgen(begin): object/msc.core.BaseJoint + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("msc.core.MSCJoint") class MSCJoint(BaseJoint): @@ -224,6 +232,9 @@ class MSCJoint(BaseJoint): The weights of the node. """ + # tvm-ffi-stubgen(begin): object/msc.core.MSCJoint + # tvm-ffi-stubgen(end) + def __init__( self, index: int, @@ -443,6 +454,9 @@ class MSCPrim(BaseJoint): The parents of the prim. """ + # tvm-ffi-stubgen(begin): object/msc.core.MSCPrim + # tvm-ffi-stubgen(end) + def __init__( self, index: int, name: str, optype: str, attrs: Dict[str, str], parents: List[BaseJoint] ): @@ -477,6 +491,9 @@ class WeightJoint(BaseJoint): The friends of the node. """ + # tvm-ffi-stubgen(begin): object/msc.core.WeightJoint + # tvm-ffi-stubgen(end) + def __init__( self, index: int, @@ -567,6 +584,9 @@ def has_attr(self, key: str) -> bool: class BaseGraph(Object): """Base class of all MSC Graphs.""" + # tvm-ffi-stubgen(begin): object/msc.core.BaseGraph + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("msc.core.MSCGraph") class MSCGraph(BaseGraph): @@ -584,6 +604,9 @@ class MSCGraph(BaseGraph): The output names of the graph. """ + # tvm-ffi-stubgen(begin): object/msc.core.MSCGraph + # tvm-ffi-stubgen(end) + def __init__( self, name: str, @@ -969,6 +992,9 @@ class WeightGraph(BaseGraph): The nodes of the graph. """ + # tvm-ffi-stubgen(begin): object/msc.core.WeightGraph + # tvm-ffi-stubgen(end) + def __init__( self, name: str, diff --git a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py index f7cd2ea43e3e..a51b18e030f8 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.framework.tensorflow._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.framework.tensorflow", __name__) +# tvm-ffi-stubgen(begin): global/msc.framework.tensorflow +_FFI_INIT_FUNC("msc.framework.tensorflow", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py index a09ab875fbed..146282ec8f90 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.framework.tensorrt._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.framework.tensorrt", __name__) +# tvm-ffi-stubgen(begin): global/msc.framework.tensorrt +_FFI_INIT_FUNC("msc.framework.tensorrt", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/framework/torch/_ffi_api.py b/python/tvm/contrib/msc/framework/torch/_ffi_api.py index d1f27a53bdcf..a3e88799a8fb 100644 --- a/python/tvm/contrib/msc/framework/torch/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/torch/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.framework.torch._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.framework.torch", __name__) +# tvm-ffi-stubgen(begin): global/msc.framework.torch +_FFI_INIT_FUNC("msc.framework.torch", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py index c9f63e21eaef..239f553fe977 100644 --- a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.framework.tvm._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.framework.tvm", __name__) +# tvm-ffi-stubgen(begin): global/msc.framework.tvm +_FFI_INIT_FUNC("msc.framework.tvm", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/plugin/_ffi_api.py b/python/tvm/contrib/msc/plugin/_ffi_api.py index 88f9204f3a02..c1b1f1b0e6d2 100644 --- a/python/tvm/contrib/msc/plugin/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.plugin._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.plugin", __name__) +# tvm-ffi-stubgen(begin): global/msc.plugin +_FFI_INIT_FUNC("msc.plugin", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/plugin/op/_ffi_api.py b/python/tvm/contrib/msc/plugin/op/_ffi_api.py index 8ca5071cdaf6..d01bdaab758a 100644 --- a/python/tvm/contrib/msc/plugin/op/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/op/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.plugin.op._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.plugin.op", __name__) +# tvm-ffi-stubgen(begin): global/msc.plugin.op +_FFI_INIT_FUNC("msc.plugin.op", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index e56426fd5182..62e3efce6d04 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.driver""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("driver", __name__) + +# tvm-ffi-stubgen(begin): global/driver +_FFI_INIT_FUNC("driver", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/ir/_ffi_api.py b/python/tvm/ir/_ffi_api.py index 798e69fca507..4f80a5d4da78 100644 --- a/python/tvm/ir/_ffi_api.py +++ b/python/tvm/ir/_ffi_api.py @@ -15,7 +15,96 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.ir""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import BaseFunc, DictAttrs, DummyGlobalInfo, EnvFunc, FloatImm, FuncType, GlobalInfo, GlobalVar, GlobalVarSupply, IRModule, IntImm, NameSupply, Op, PointerType, PrimExpr, PrimType, Range, RelaxExpr, SequentialSpan, SourceName, Span, StructInfo, TensorMapType, TupleType, Type, VDevice + from target import Target + from tvm_ffi import Object, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("ir", __name__) + +# tvm-ffi-stubgen(begin): global/ir +# fmt: off +_FFI_INIT_FUNC("ir", __name__) +if TYPE_CHECKING: + def BaseFuncCopy(_0: BaseFunc, /) -> BaseFunc: ... + def BaseFuncWithAttr(_0: ObjectRValueRef[BaseFunc], _1: str, _2: Any, /) -> BaseFunc: ... + def BaseFuncWithAttrs(_0: ObjectRValueRef[BaseFunc], _1: Mapping[str, Any], /) -> BaseFunc: ... + def BaseFuncWithoutAttr(_0: ObjectRValueRef[BaseFunc], _1: str, /) -> BaseFunc: ... + def BaseFunc_Attrs(_0: BaseFunc, /) -> DictAttrs: ... + def DebugPrint(_0: Object, /) -> str: ... + def DictAttrsGetDict(_0: DictAttrs, /) -> Mapping[str, Any]: ... + def DummyGlobalInfo() -> DummyGlobalInfo: ... + def EnvFuncCall(*args: Any) -> Any: ... + def EnvFuncGet(_0: str, /) -> EnvFunc: ... + def EnvFuncGetFunction(_0: EnvFunc, /) -> Callable[..., Any]: ... + def ExprStructInfo(_0: RelaxExpr, /) -> StructInfo: ... + def FloatImm(_0: dtype, _1: float, _2: Span, /) -> FloatImm: ... + def FuncType(_0: Sequence[Type], _1: Type, /) -> FuncType: ... + def GetOp(_0: str, /) -> Op: ... + def GlobalVar(_0: str, /) -> GlobalVar: ... + def GlobalVarSupply_FreshGlobal(_0: GlobalVarSupply, _1: str, _2: bool, /) -> GlobalVar: ... + def GlobalVarSupply_IRModule(_0: IRModule, /) -> GlobalVarSupply: ... + def GlobalVarSupply_IRModules(_0: Sequence[IRModule], /) -> GlobalVarSupply: ... + def GlobalVarSupply_NameSupply(_0: NameSupply, /) -> GlobalVarSupply: ... + def GlobalVarSupply_ReserveGlobalVar(_0: GlobalVarSupply, _1: GlobalVar, _2: bool, /) -> None: ... + def GlobalVarSupply_UniqueGlobalFor(_0: GlobalVarSupply, _1: str, _2: bool, /) -> GlobalVar: ... + def IRModule(_0: Mapping[GlobalVar, BaseFunc], _1: Object, _2: Mapping[str, Sequence[GlobalInfo]], /) -> IRModule: ... + def IntImm(_0: dtype, _1: int, _2: Span, /) -> IntImm: ... + def ListOpNames() -> Sequence[str]: ... + def Module_Add(_0: IRModule, _1: GlobalVar, _2: Object, _3: bool, /) -> IRModule: ... + def Module_Clone(_0: IRModule, /) -> IRModule: ... + def Module_ContainGlobalVar(_0: IRModule, _1: str, /) -> bool: ... + def Module_Contains(_0: IRModule, _1: str | GlobalVar, /) -> bool: ... + def Module_FromExpr(_0: RelaxExpr, _1: Mapping[GlobalVar, BaseFunc], /) -> IRModule: ... + def Module_GetAttr(_0: IRModule, _1: str, /) -> Object: ... + def Module_GetAttrs(_0: IRModule, /) -> Object: ... + def Module_GetGlobalVar(_0: IRModule, _1: str, /) -> GlobalVar: ... + def Module_GetGlobalVars(_0: IRModule, /) -> Sequence[GlobalVar]: ... + def Module_Lookup(_0: IRModule, _1: GlobalVar, /) -> BaseFunc: ... + def Module_Lookup_str(_0: IRModule, _1: str, /) -> BaseFunc: ... + def Module_Remove(_0: IRModule, _1: str | GlobalVar, /) -> IRModule: ... + def Module_ReplaceGlobalVars(_0: IRModule, _1: Mapping[str | GlobalVar, str | GlobalVar], /) -> IRModule: ... + def Module_Update(_0: IRModule, _1: IRModule, /) -> None: ... + def Module_UpdateFunction(_0: IRModule, _1: GlobalVar, _2: BaseFunc, /) -> None: ... + def Module_UpdateGlobalInfo(_0: IRModule, _1: str, _2: Sequence[GlobalInfo], /) -> None: ... + def Module_WithAttr(_0: ObjectRValueRef[IRModule], _1: str, _2: Any, /) -> IRModule: ... + def Module_WithAttrs(_0: ObjectRValueRef[IRModule], _1: Mapping[str, Any], /) -> IRModule: ... + def Module_WithoutAttr(_0: ObjectRValueRef[IRModule], _1: str, /) -> IRModule: ... + def NameSupply(_0: str, /) -> NameSupply: ... + def NameSupply_ContainsName(_0: NameSupply, _1: str, _2: bool, /) -> bool: ... + def NameSupply_FreshName(_0: NameSupply, _1: str, _2: bool, _3: bool, /) -> str: ... + def NameSupply_ReserveName(_0: NameSupply, _1: str, _2: bool, /) -> str: ... + def OpAddArgument(_0: Op, _1: str, _2: str, _3: str, /) -> None: ... + def OpGetAttr(_0: Op, _1: str, /) -> Any: ... + def OpHasAttr(_0: Op, _1: str, /) -> bool: ... + def OpResetAttr(_0: Op, _1: str, /) -> None: ... + def OpSetAttr(_0: Op, _1: str, _2: Any, _3: int, /) -> None: ... + def OpSetAttrsTypeKey(_0: Op, _1: str, /) -> None: ... + def OpSetNumInputs(_0: Op, _1: int, /) -> None: ... + def OpSetSupportLevel(_0: Op, _1: int, /) -> None: ... + def PointerType(_0: Type, _1: str, /) -> PointerType: ... + def PrimType(_0: dtype, /) -> PrimType: ... + def Range(_0: PrimExpr, _1: PrimExpr | None, _2: Span, /) -> Range: ... + def Range_from_min_extent(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Range: ... + def RegisterOp(_0: str, _1: str, /) -> None: ... + def RegisterOpAttr(_0: str, _1: str, _2: Any, _3: int, /) -> None: ... + def RegisterOpLowerIntrinsic(_0: str, _1: Callable[..., Any], _2: str, _3: int, /) -> None: ... + def SequentialSpan(_0: Sequence[Span], /) -> SequentialSpan: ... + def SourceName(_0: str, /) -> SourceName: ... + def Span(_0: SourceName, _1: int, _2: int, _3: int, _4: int, /) -> Span: ... + def TensorMapType(_0: Span, /) -> TensorMapType: ... + def TupleType(_0: Sequence[Type], /) -> TupleType: ... + def VDevice(_0: Target, _1: int, _2: str, /) -> VDevice: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index fb408cdb8c70..c199c6ea8282 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """TVM Attribute module, which is mainly used for defining attributes of operators.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.runtime import Object @@ -30,6 +41,11 @@ class Attrs(Object): Attrs is passed as the first argument to these functions. """ + # tvm-ffi-stubgen(begin): object/ir.Attrs + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def get_int_tuple(self, key): """Get a python int tuple of a key @@ -77,6 +93,12 @@ def __getitem__(self, item): class DictAttrs(Attrs): """Dictionary attributes.""" + # tvm-ffi-stubgen(begin): object/ir.DictAttrs + # fmt: off + __dict__: Mapping[str, Any] + # fmt: on + # tvm-ffi-stubgen(end) + def _dict(self): """Get internal dict""" return _ffi_api.DictAttrsGetDict(self) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 651ab392039c..3f4aca0d00fe 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Common base structures.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import Source + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm.error from tvm_ffi import get_global_func, register_object from tvm.runtime import Object, _ffi_node_api @@ -28,6 +40,11 @@ class Node(Object): @register_object("ir.SourceMap") class SourceMap(Object): + # tvm-ffi-stubgen(begin): object/ir.SourceMap + # fmt: off + source_map: Mapping[SourceName, Source] + # fmt: on + # tvm-ffi-stubgen(end) def add(self, name, content): return get_global_func("SourceMapAdd")(self, name, content) @@ -42,6 +59,12 @@ class SourceName(Object): The name of the source. """ + # tvm-ffi-stubgen(begin): object/ir.SourceName + # fmt: off + name: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, name): self.__init_handle_by_constructor__(_ffi_api.SourceName, name) # type: ignore # pylint: disable=no-member @@ -62,6 +85,16 @@ class Span(Object): The column offset of the location. """ + # tvm-ffi-stubgen(begin): object/ir.Span + # fmt: off + source_name: SourceName + line: int + column: int + end_line: int + end_column: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, source_name, line, end_line, column, end_column): self.__init_handle_by_constructor__( _ffi_api.Span, source_name, line, end_line, column, end_column # type: ignore # pylint: disable=no-member @@ -81,6 +114,12 @@ class SequentialSpan(Object): The array of spans. """ + # tvm-ffi-stubgen(begin): object/ir.SequentialSpan + # fmt: off + spans: Sequence[Span] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, spans): self.__init_handle_by_constructor__(_ffi_api.SequentialSpan, spans) @@ -92,6 +131,13 @@ class EnvFunc(Object): This is a global function object that can be serialized by its name. """ + # tvm-ffi-stubgen(begin): object/ir.EnvFunc + # fmt: off + name: str + func: Callable[..., Any] + # fmt: on + # tvm-ffi-stubgen(end) + def __call__(self, *args): return _ffi_api.EnvFuncCall(self, *args) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py index 4a521dfa587e..e9b7ebaf4361 100644 --- a/python/tvm/ir/diagnostics/__init__.py +++ b/python/tvm/ir/diagnostics/__init__.py @@ -21,6 +21,18 @@ three key abstractions: a Diagnostic, the DiagnosticContext, and the DiagnosticRenderer. """ +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IRModule, Span + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import enum import tvm_ffi from . import _ffi_api @@ -73,6 +85,14 @@ class DiagnosticLevel(enum.IntEnum): class Diagnostic(Object): """A single diagnostic object from TVM.""" + # tvm-ffi-stubgen(begin): object/Diagnostic + # fmt: off + level: int + span: Span + message: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, level, span, message): self.__init_handle_by_constructor__(_ffi_api.Diagnostic, level, span, message) @@ -84,6 +104,12 @@ class DiagnosticRenderer(Object): form of the diagnostics for either human or computer consumption. """ + # tvm-ffi-stubgen(begin): object/DiagnosticRenderer + # fmt: off + renderer: Callable[[DiagnosticContext], None] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, render_func): self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func) @@ -107,6 +133,13 @@ class DiagnosticContext(Object): and contains a renderer. """ + # tvm-ffi-stubgen(begin): object/DiagnosticContext + # fmt: off + module: IRModule + diagnostics: Sequence[Diagnostic] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, module, renderer): self.__init_handle_by_constructor__(_ffi_api.DiagnosticContext, module, renderer) diff --git a/python/tvm/ir/diagnostics/_ffi_api.py b/python/tvm/ir/diagnostics/_ffi_api.py index 65fb2cc896f3..8c183ac50291 100644 --- a/python/tvm/ir/diagnostics/_ffi_api.py +++ b/python/tvm/ir/diagnostics/_ffi_api.py @@ -15,7 +15,34 @@ # specific language governing permissions and limitations # under the License. """FFI for TVM diagnostics.""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import IRModule, Span + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("diagnostics", __name__) + +# tvm-ffi-stubgen(begin): global/diagnostics +# fmt: off +_FFI_INIT_FUNC("diagnostics", __name__) +if TYPE_CHECKING: + def ClearRenderer() -> None: ... + def Default(_0: IRModule, /) -> DiagnosticContext: ... + def DefaultRenderer() -> DiagnosticRenderer: ... + def Diagnostic(_0: int, _1: Span, _2: str, /) -> Diagnostic: ... + def DiagnosticContext(_0: IRModule, _1: DiagnosticRenderer, /) -> DiagnosticContext: ... + def DiagnosticContextRender(_0: DiagnosticContext, /) -> None: ... + def DiagnosticRenderer(_0: Callable[[DiagnosticContext], None], /) -> DiagnosticRenderer: ... + def DiagnosticRendererRender(_0: DiagnosticRenderer, _1: DiagnosticContext, /) -> None: ... + def Emit(_0: DiagnosticContext, _1: Diagnostic, /) -> None: ... + def GetRenderer() -> DiagnosticRenderer: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 19abb6bd1eae..b2bd5a9f374f 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Common expressions data structures in the IR.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import Span + from tvm_ffi import Object, dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from numbers import Number from typing import Optional @@ -30,6 +41,12 @@ class BaseExpr(Node): """Base class of all the expressions.""" + # tvm-ffi-stubgen(begin): object/ir.BaseExpr + # fmt: off + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + span: Optional[Span] @@ -41,6 +58,12 @@ class PrimExpr(BaseExpr): optimizations and integer analysis. """ + # tvm-ffi-stubgen(begin): object/ir.PrimExpr + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + dtype: str @@ -48,6 +71,12 @@ class PrimExpr(BaseExpr): class RelaxExpr(BaseExpr): """Base class of all non-primitive expressions.""" + # tvm-ffi-stubgen(begin): object/ir.RelaxExpr + # fmt: off + struct_info_: Object | None + # fmt: on + # tvm-ffi-stubgen(end) + @property def struct_info(self) -> Optional["tvm.relax.StructInfo"]: """Get the struct info field @@ -73,6 +102,12 @@ class GlobalVar(RelaxExpr): The name of the variable. """ + # tvm-ffi-stubgen(begin): object/ir.GlobalVar + # fmt: off + name_hint: str + # fmt: on + # tvm-ffi-stubgen(end) + name_hint: str def __init__(self, name_hint: str): @@ -131,6 +166,14 @@ class Range(Node, Scriptable): if the end argument is not None. Otherwise, it creates `[0, begin)`. """ + # tvm-ffi-stubgen(begin): object/ir.Range + # fmt: off + min: PrimExpr + extent: PrimExpr + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + min: PrimExpr extent: PrimExpr span: Optional[Span] diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index 75718503aae1..76e1897366e0 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -16,6 +16,16 @@ # under the License. # pylint: disable=invalid-name """Function definitions.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import DictAttrs +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Union, Dict from enum import IntEnum @@ -40,6 +50,12 @@ class CallingConv(IntEnum): class BaseFunc(RelaxExpr): """Base class of all functions.""" + # tvm-ffi-stubgen(begin): object/ir.BaseFunc + # fmt: off + attrs: DictAttrs + # fmt: on + # tvm-ffi-stubgen(end) + @property def attrs(self): """Return the attrs member of the function.""" diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py index 185e10b88cce..f5731ec0d9c3 100644 --- a/python/tvm/ir/global_info.py +++ b/python/tvm/ir/global_info.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. """Global Info.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm import tvm_ffi from tvm.runtime.object import Object @@ -25,6 +35,11 @@ class GlobalInfo(Object): """Base node for all global info that can appear in the IR""" + # tvm-ffi-stubgen(begin): object/ir.GlobalInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __eq__(self, other): """Compare two struct info for structural equivalence.""" return tvm.ir.structural_equal(self, other) @@ -39,6 +54,10 @@ def same_as(self, other): @tvm_ffi.register_object("ir.DummyGlobalInfo") class DummyGlobalInfo(GlobalInfo): + # tvm-ffi-stubgen(begin): object/ir.DummyGlobalInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.DummyGlobalInfo, @@ -47,6 +66,13 @@ def __init__(self) -> None: @tvm_ffi.register_object("ir.VDevice") class VDevice(GlobalInfo): + # tvm-ffi-stubgen(begin): object/ir.VDevice + # fmt: off + target: Target + vdevice_id: int + memory_scope: str + # fmt: on + # tvm-ffi-stubgen(end) def __init__( self, target=None, diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 0f1bcf3adfda..32e2f5c7dd6a 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Common pass instrumentation across IR variants.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import functools import inspect import re @@ -43,6 +45,12 @@ class PassInstrument(tvm.runtime.Object): """ + # tvm-ffi-stubgen(begin): object/instrument.PassInstrument + # fmt: off + name: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): # initialize handle in case pi_cls creation failed. cls = type(self) @@ -234,6 +242,12 @@ def create_pass_instrument(pi_cls): class PassTimingInstrument(tvm.runtime.Object): """A wrapper to create a passes time instrument that implemented in C++""" + # tvm-ffi-stubgen(begin): object/instrument.PassInstrument + # fmt: off + name: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassTimingInstrument) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 21c86c05ec4c..df29574c20b9 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -18,6 +18,17 @@ from __future__ import annotations +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import BaseFunc, DictAttrs, GlobalInfo, GlobalVar, SourceMap +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Dict, Union import tvm @@ -43,6 +54,16 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ + # tvm-ffi-stubgen(begin): object/ir.IRModule + # fmt: off + functions: Mapping[GlobalVar, BaseFunc] + global_var_map_: Mapping[str, GlobalVar] + source_map: SourceMap + attrs: DictAttrs + global_infos: Mapping[str, Sequence[GlobalInfo]] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, functions=None, attrs=None, global_infos=None): if functions is None: functions = {} diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index 5b62d3fe8df7..53e5bdb15600 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -16,6 +16,17 @@ # under the License. # pylint: disable=invalid-name """Primitive operators in the TVM IR.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import AttrFieldInfo, FuncType +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from . import _ffi_api @@ -26,6 +37,18 @@ class Op(RelaxExpr): """Primitive operator in the IR.""" + # tvm-ffi-stubgen(begin): object/ir.Op + # fmt: off + name: str + op_type: FuncType + description: str + arguments: Sequence[AttrFieldInfo] + attrs_type_key: str + num_inputs: int + support_level: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): raise RuntimeError("Cannot create op, use get instead") diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index bc38089b2254..749f2fbfb66e 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Suppliers that are used to guarantee uniqueness of names and GlobalVars.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import tvm import tvm_ffi from tvm import Object, IRModule @@ -30,6 +32,11 @@ class NameSupply(Object): prefix: The prefix to be added to the generated names. """ + # tvm-ffi-stubgen(begin): object/ir.NameSupply + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, prefix=""): self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix) @@ -91,6 +98,11 @@ class GlobalVarSupply(Object): The IRModules used to build this GlobalVarSupply or a NameSupply. """ + # tvm-ffi-stubgen(begin): object/ir.GlobalVarSupply + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, value=None): if value is None: name_supply = NameSupply("") diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index fd9a2ac3b212..1384fff3db4a 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -16,6 +16,18 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Common pass infrastructure across IR variants.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from instrument import PassInstrument + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import inspect import functools @@ -44,6 +56,15 @@ class PassInfo(tvm.runtime.Object): The list of passes that are required by a certain pass. """ + # tvm-ffi-stubgen(begin): object/transform.PassInfo + # fmt: off + opt_level: int + name: str + required: Sequence[str] + traceable: bool + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, opt_level, name, required=None, traceable=False): self.__init_handle_by_constructor__( _ffi_transform_api.PassInfo, opt_level, name, required, traceable @@ -73,6 +94,17 @@ class PassContext(tvm.runtime.Object): Additional configurations for specific passes. """ + # tvm-ffi-stubgen(begin): object/transform.PassContext + # fmt: off + opt_level: int + required_pass: Sequence[str] + disabled_pass: Sequence[str] + instruments: Sequence[PassInstrument] + config: Mapping[str, Any] + diag_ctx: DiagnosticContext | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, opt_level=2, @@ -145,6 +177,11 @@ class Pass(tvm.runtime.Object): conveniently interact with the base class. """ + # tvm-ffi-stubgen(begin): object/transform.Pass + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + @property def info(self): """Get the pass meta.""" @@ -177,6 +214,12 @@ class ModulePass(Pass): The same rule applies to FunctionPass as well. """ + # tvm-ffi-stubgen(begin): object/transform.ModulePass + # fmt: off + pass_info: PassInfo + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("transform.Sequential") class Sequential(Pass): @@ -205,6 +248,13 @@ class Sequential(Pass): The list of passes that the sequential pass is dependent on. """ + # tvm-ffi-stubgen(begin): object/transform.Sequential + # fmt: off + pass_info: PassInfo + passes: Sequence[Pass] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, passes=None, opt_level=0, name="sequential", required=None, traceable=False): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 68bed8fb69f0..de1e6a06a85e 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Unified type system in the project.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import Span + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm import tvm_ffi from tvm.runtime import Scriptable @@ -27,6 +39,12 @@ class Type(Node, Scriptable): """The base class of all types.""" + # tvm-ffi-stubgen(begin): object/ir.Type + # fmt: off + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __eq__(self, other): """Compare two types for structural equivalence.""" return bool(tvm.ir.structural_equal(self, other)) @@ -49,6 +67,12 @@ class PrimType(Type): The runtime data type relates to the primtype. """ + # tvm-ffi-stubgen(begin): object/ir.PrimType + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, dtype): self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype) @@ -66,6 +90,13 @@ class PointerType(Type): The storage scope into which the pointer addresses. """ + # tvm-ffi-stubgen(begin): object/ir.PointerType + # fmt: off + element_type: Type + storage_scope: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, element_type, storage_scope=""): self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type, storage_scope) @@ -80,6 +111,13 @@ class TupleType(Type): The fields in the tuple """ + # tvm-ffi-stubgen(begin): object/ir.TupleType + # fmt: off + fields: Sequence[Type] + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, fields): self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) @@ -102,6 +140,14 @@ class FuncType(Type): The return type. """ + # tvm-ffi-stubgen(begin): object/ir.FuncType + # fmt: off + arg_types: Sequence[Type] + ret_type: Type + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, arg_types, ret_type): self.__init_handle_by_constructor__( _ffi_api.FuncType, @@ -120,6 +166,12 @@ class TensorMapType(Type): The span information. """ + # tvm-ffi-stubgen(begin): object/ir.TensorMapType + # fmt: off + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, span=None): self.__init_handle_by_constructor__( _ffi_api.TensorMapType, span # pylint: disable=no-member diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py index 70950958024d..c625577bcadf 100644 --- a/python/tvm/ir/type_relation.py +++ b/python/tvm/ir/type_relation.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Type relation and function for type checking.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import tvm_ffi from .type import Type, TypeConstraint @@ -39,6 +41,9 @@ class TypeCall(Type): The type function application. """ + # tvm-ffi-stubgen(begin): object/TypeCall + # tvm-ffi-stubgen(end) + def __init__(self, func, args): self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args) @@ -71,5 +76,8 @@ class TypeRelation(TypeConstraint): The type relation. """ + # tvm-ffi-stubgen(begin): object/TypeRelation + # tvm-ffi-stubgen(end) + def __init__(self, func, args, num_inputs, attrs): self.__init_handle_by_constructor__(_ffi_api.TypeRelation, func, args, num_inputs, attrs) diff --git a/python/tvm/meta_schedule/_ffi_api.py b/python/tvm/meta_schedule/_ffi_api.py index 1a06aef5a482..f12a70e0e396 100644 --- a/python/tvm/meta_schedule/_ffi_api.py +++ b/python/tvm/meta_schedule/_ffi_api.py @@ -15,6 +15,179 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.meta_schedule""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ctypes import c_void_p + from ir import FloatImm, IRModule, IntImm + from meta_schedule import ArgInfo, Builder, BuilderInput, BuilderResult, CostModel, Database, EvolutionarySearch, ExtractedTask, FeatureExtractor, MeasureCallback, MeasureCandidate, Mutator, Postproc, Profiler, Runner, RunnerFuture, RunnerInput, RunnerResult, ScheduleRule, SearchStrategy, SpaceGenerator, TaskScheduler, TensorInfo, TuneContext, TuningRecord, Workload + from target import Target + from tir import BlockRV, PrimFunc, Schedule, Trace + from tvm_ffi import Object, Shape, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("meta_schedule", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/meta_schedule +# fmt: off +_FFI_INIT_FUNC("meta_schedule", __name__) +if TYPE_CHECKING: + def ArgInfoAsJSON(_0: ArgInfo, /) -> Object: ... + def ArgInfoFromEntryFunc(_0: IRModule, _1: bool, /) -> Sequence[ArgInfo]: ... + def ArgInfoFromJSON(_0: Object, /) -> ArgInfo: ... + def ArgInfoFromPrimFunc(_0: PrimFunc, /) -> Sequence[ArgInfo]: ... + def BuilderBuild(_0: Builder, _1: Sequence[BuilderInput], /) -> Sequence[BuilderResult]: ... + def BuilderInput(_0: IRModule, _1: Target, _2: Mapping[str, Tensor] | None, /) -> BuilderInput: ... + def BuilderPyBuilder(_0: Callable[[Sequence[BuilderInput]], Sequence[BuilderResult]], /) -> Builder: ... + def BuilderResult(_0: str | None, _1: str | None, /) -> BuilderResult: ... + def CostModelLoad(_0: CostModel, _1: str, /) -> None: ... + def CostModelPredict(_0: CostModel, _1: TuneContext, _2: Sequence[MeasureCandidate], _3: c_void_p, /) -> None: ... + def CostModelPyCostModel(_0: Callable[[str], None], _1: Callable[[str], None], _2: Callable[[TuneContext, Sequence[MeasureCandidate], Sequence[RunnerResult]], None], _3: Callable[[TuneContext, Sequence[MeasureCandidate], c_void_p], None], _4: Callable[[], str], /) -> CostModel: ... + def CostModelSave(_0: CostModel, _1: str, /) -> None: ... + def CostModelUpdate(_0: CostModel, _1: TuneContext, _2: Sequence[MeasureCandidate], _3: Sequence[RunnerResult], /) -> None: ... + def DatabaseCommitTuningRecord(_0: Database, _1: TuningRecord, /) -> None: ... + def DatabaseCommitWorkload(_0: Database, _1: IRModule, /) -> Workload: ... + def DatabaseCurrent() -> Database | None: ... + def DatabaseDumpPruned(_0: Database, _1: Database, /) -> None: ... + def DatabaseEnterWithScope(_0: Database, /) -> None: ... + def DatabaseExitWithScope(_0: Database, /) -> None: ... + def DatabaseGetAllTuningRecords(_0: Database, /) -> Sequence[TuningRecord]: ... + def DatabaseGetTopK(_0: Database, _1: Workload, _2: int, /) -> Sequence[TuningRecord]: ... + def DatabaseHasWorkload(_0: Database, _1: IRModule, /) -> bool: ... + def DatabaseJSONDatabase(_0: str, _1: str, _2: bool, _3: str, /) -> Database: ... + def DatabaseMemoryDatabase(_0: str, /) -> Database: ... + def DatabaseOrderedUnionDatabase(_0: Sequence[Database], /) -> Database: ... + def DatabasePyDatabase(_0: Callable[[IRModule], bool], _1: Callable[[IRModule], Workload], _2: Callable[[TuningRecord], None], _3: Callable[[Workload, int], Sequence[TuningRecord]], _4: Callable[[], Sequence[TuningRecord]], _5: Callable[[IRModule, Target, str], TuningRecord | None], _6: Callable[[IRModule, Target, str], Schedule | None], _7: Callable[[IRModule, Target, str], IRModule | None], _8: Callable[[], int], _9: str, /) -> Database: ... + def DatabaseQueryIRModule(_0: Database, _1: IRModule, _2: Target, _3: str, /) -> IRModule | None: ... + def DatabaseQuerySchedule(_0: Database, _1: IRModule, _2: Target, _3: str, /) -> Schedule | None: ... + def DatabaseQueryTuningRecord(_0: Database, _1: IRModule, _2: Target, _3: str, /) -> TuningRecord | None: ... + def DatabaseScheduleFnDatabase(_0: Callable[[Schedule], bool], _1: str, /) -> Database: ... + def DatabaseSize(_0: Database, /) -> int: ... + def DatabaseUnionDatabase(_0: Sequence[Database], /) -> Database: ... + def ExtractedTask(_0: str, _1: IRModule, _2: Target, _3: Sequence[IRModule], _4: int, /) -> ExtractedTask: ... + def FeatureExtractorExtractFrom(_0: FeatureExtractor, _1: TuneContext, _2: Sequence[MeasureCandidate], /) -> Sequence[Tensor]: ... + def FeatureExtractorPerStoreFeature(_0: int, _1: int, _2: int, _3: bool, /) -> FeatureExtractor: ... + def FeatureExtractorPyFeatureExtractor(_0: Callable[[TuneContext, Sequence[MeasureCandidate]], Sequence[Tensor]], _1: Callable[[], str], /) -> FeatureExtractor: ... + def MeasureCallbackAddToDatabase() -> MeasureCallback: ... + def MeasureCallbackApply(_0: MeasureCallback, _1: TaskScheduler, _2: int, _3: Sequence[MeasureCandidate], _4: Sequence[BuilderResult], _5: Sequence[RunnerResult], /) -> None: ... + def MeasureCallbackDefault() -> Sequence[MeasureCallback]: ... + def MeasureCallbackPyMeasureCallback(_0: Callable[[TaskScheduler, int, Sequence[MeasureCandidate], Sequence[BuilderResult], Sequence[RunnerResult]], None], _1: Callable[[], str], /) -> MeasureCallback: ... + def MeasureCallbackRemoveBuildArtifact() -> MeasureCallback: ... + def MeasureCallbackUpdateCostModel() -> MeasureCallback: ... + def MeasureCandidate(_0: Schedule, _1: Sequence[ArgInfo] | None, /) -> MeasureCandidate: ... + def MutateThreadBinding() -> Mutator: ... + def MutatorApply(_0: Mutator, _1: Trace, _2: int, /) -> Trace | None: ... + def MutatorClone(_0: Mutator, /) -> Mutator: ... + def MutatorDefaultCUDA() -> Mapping[Mutator, FloatImm]: ... + def MutatorDefaultCUDATensorCore() -> Mapping[Mutator, FloatImm]: ... + def MutatorDefaultHexagon() -> Mapping[Mutator, FloatImm]: ... + def MutatorDefaultLLVM() -> Mapping[Mutator, FloatImm]: ... + def MutatorInitializeWithTuneContext(_0: Mutator, _1: TuneContext, /) -> None: ... + def MutatorMutateComputeLocation() -> Mutator: ... + def MutatorMutateParallel(_0: int, /) -> Mutator: ... + def MutatorMutateTileSize() -> Mutator: ... + def MutatorMutateUnroll() -> Mutator: ... + def MutatorPyMutator(_0: Callable[[TuneContext], None], _1: Callable[[Trace, int], Trace | None], _2: Callable[[], Mutator], _3: Callable[[], str], /) -> Mutator: ... + def PostprocApply(_0: Postproc, _1: Schedule, /) -> bool: ... + def PostprocClone(_0: Postproc, /) -> Postproc: ... + def PostprocDefaultCUDA() -> Sequence[Postproc]: ... + def PostprocDefaultCUDATensorCore() -> Sequence[Postproc]: ... + def PostprocDefaultHexagon() -> Sequence[Postproc]: ... + def PostprocDefaultLLVM() -> Sequence[Postproc]: ... + def PostprocDisallowAsyncStridedMemCopy() -> Postproc: ... + def PostprocDisallowDynamicLoop() -> Postproc: ... + def PostprocInitializeWithTuneContext(_0: Postproc, _1: TuneContext, /) -> None: ... + def PostprocPyPostproc(_0: Callable[[TuneContext], None], _1: Callable[[Schedule], bool], _2: Callable[[], Postproc], _3: Callable[[], str], /) -> Postproc: ... + def PostprocRewriteCooperativeFetch() -> Postproc: ... + def PostprocRewriteLayout() -> Postproc: ... + def PostprocRewriteParallelVectorizeUnroll() -> Postproc: ... + def PostprocRewriteReductionBlock() -> Postproc: ... + def PostprocRewriteTensorize(_0: bool, /) -> Postproc: ... + def PostprocRewriteUnboundBlock(_0: int, /) -> Postproc: ... + def PostprocVerifyGPUCode() -> Postproc: ... + def PostprocVerifyVTCMLimit() -> Postproc: ... + def Profiler() -> Profiler: ... + def ProfilerCurrent() -> Profiler | None: ... + def ProfilerEnterWithScope(_0: Profiler, /) -> None: ... + def ProfilerExitWithScope(_0: Profiler, /) -> None: ... + def ProfilerGet(_0: Profiler, /) -> Mapping[str, FloatImm]: ... + def ProfilerTable(_0: Profiler, /) -> str: ... + def ProfilerTimedScope(_0: str, /) -> Callable[..., Any]: ... + def RunnerFuture(_0: Callable[[], bool], _1: Callable[[], RunnerResult], /) -> RunnerFuture: ... + def RunnerFutureDone(_0: RunnerFuture, /) -> bool: ... + def RunnerFutureResult(_0: RunnerFuture, /) -> RunnerResult: ... + def RunnerInput(_0: str, _1: str, _2: Sequence[ArgInfo], /) -> RunnerInput: ... + def RunnerPyRunner(_0: Callable[[Sequence[RunnerInput]], Sequence[RunnerFuture]], /) -> Runner: ... + def RunnerResult(_0: Sequence[FloatImm] | None, _1: str | None, /) -> RunnerResult: ... + def RunnerRun(_0: Runner, _1: Sequence[RunnerInput], /) -> Sequence[RunnerFuture]: ... + def ScheduleRuleAddRFactor(_0: int, _1: IntImm | None, /) -> ScheduleRule: ... + def ScheduleRuleApply(_0: ScheduleRule, _1: Schedule, _2: BlockRV, /) -> Sequence[Schedule]: ... + def ScheduleRuleApplyCustomRule() -> ScheduleRule: ... + def ScheduleRuleAutoBind(_0: int, _1: Sequence[IntImm], _2: int, /) -> ScheduleRule: ... + def ScheduleRuleAutoInline(_0: bool, _1: bool, _2: bool, _3: bool, _4: bool, _5: bool, _6: Sequence[str] | None, /) -> ScheduleRule: ... + def ScheduleRuleClone(_0: ScheduleRule, /) -> ScheduleRule: ... + def ScheduleRuleCrossThreadReduction(_0: Sequence[IntImm], /) -> ScheduleRule: ... + def ScheduleRuleDefaultARM(_0: str, /) -> Sequence[ScheduleRule]: ... + def ScheduleRuleDefaultCUDA() -> Sequence[ScheduleRule]: ... + def ScheduleRuleDefaultCUDATensorCore() -> Sequence[ScheduleRule]: ... + def ScheduleRuleDefaultHexagon() -> Sequence[ScheduleRule]: ... + def ScheduleRuleDefaultLLVM() -> Sequence[ScheduleRule]: ... + def ScheduleRuleInitializeWithTuneContext(_0: ScheduleRule, _1: TuneContext, /) -> None: ... + def ScheduleRuleInlineConstantScalars() -> ScheduleRule: ... + def ScheduleRuleMultiLevelTiling(_0: str, _1: Sequence[str] | None, _2: IntImm | None, _3: Sequence[IntImm] | None, _4: Mapping[str, Any] | None, _5: Mapping[str, Any] | None, _6: Callable[..., Any] | None, /) -> ScheduleRule: ... + def ScheduleRuleMultiLevelTilingTensorCore(_0: Sequence[Mapping[str, str]], _1: str, _2: Sequence[str] | None, _3: IntImm | None, _4: Sequence[IntImm] | None, _5: Mapping[str, Any] | None, _6: Mapping[str, Any] | None, _7: bool, /) -> ScheduleRule: ... + def ScheduleRuleMultiLevelTilingWideVector(_0: str, _1: IntImm, _2: IntImm | None, _3: Mapping[str, Any] | None, _4: Mapping[str, Any] | None, /) -> ScheduleRule: ... + def ScheduleRuleMultiLevelTilingWithIntrin(_0: str, _1: str, _2: Sequence[str] | None, _3: IntImm | None, _4: Sequence[IntImm] | None, _5: Mapping[str, Any] | None, _6: Mapping[str, Any] | None, /) -> ScheduleRule: ... + def ScheduleRuleParallelizeVectorizeUnroll(_0: int, _1: int, _2: Sequence[IntImm], _3: bool, /) -> ScheduleRule: ... + def ScheduleRulePyScheduleRule(_0: Callable[[TuneContext], None], _1: Callable[[Schedule, BlockRV], Sequence[Schedule]], _2: Callable[[], ScheduleRule], _3: Callable[[], str], /) -> ScheduleRule: ... + def ScheduleRuleRandomComputeLocation() -> ScheduleRule: ... + def ScheduleUsingAnchorTrace(_0: Schedule, _1: Trace, _2: Target, /) -> None: ... + def SearchStrategyClone(_0: SearchStrategy, /) -> SearchStrategy: ... + def SearchStrategyEvolutionarySearch(_0: int, _1: float, _2: int, _3: int, _4: int, _5: float, _6: int, _7: float, /) -> SearchStrategy: ... + def SearchStrategyEvolutionarySearchEvolveWithCostModel(_0: EvolutionarySearch, _1: Sequence[Schedule], _2: int, /) -> Sequence[Schedule]: ... + def SearchStrategyEvolutionarySearchSampleInitPopulation(_0: EvolutionarySearch, _1: int, /) -> Sequence[Schedule]: ... + def SearchStrategyGenerateMeasureCandidates(_0: SearchStrategy, /) -> Sequence[MeasureCandidate] | None: ... + def SearchStrategyInitializeWithTuneContext(_0: SearchStrategy, _1: TuneContext, /) -> None: ... + def SearchStrategyNotifyRunnerResults(_0: SearchStrategy, _1: Sequence[MeasureCandidate], _2: Sequence[RunnerResult], /) -> None: ... + def SearchStrategyPostTuning(_0: SearchStrategy, /) -> None: ... + def SearchStrategyPreTuning(_0: SearchStrategy, _1: int, _2: int, _3: Sequence[Schedule], _4: Database | None, _5: CostModel | None, /) -> None: ... + def SearchStrategyPySearchStrategy(_0: Callable[[TuneContext], None], _1: Callable[[int, int, Sequence[Schedule], Database | None, CostModel | None], None], _2: Callable[[], None], _3: Callable[[], Sequence[MeasureCandidate] | None], _4: Callable[[Sequence[MeasureCandidate], Sequence[RunnerResult]], None], _5: Callable[[], SearchStrategy], /) -> SearchStrategy: ... + def SearchStrategyReplayFunc() -> SearchStrategy: ... + def SearchStrategyReplayTrace(_0: int, /) -> SearchStrategy: ... + def SpaceGeneratorClone(_0: SpaceGenerator, /) -> SpaceGenerator: ... + def SpaceGeneratorGenerateDesignSpace(_0: SpaceGenerator, _1: IRModule, /) -> Sequence[Schedule]: ... + def SpaceGeneratorInitializeWithTuneContext(_0: SpaceGenerator, _1: TuneContext, /) -> None: ... + def SpaceGeneratorPostOrderApply(_0: Callable[..., Any], _1: Sequence[ScheduleRule] | None, _2: Sequence[Postproc] | None, _3: Mapping[Mutator, FloatImm] | None, /) -> SpaceGenerator: ... + def SpaceGeneratorPySpaceGenerator(_0: Sequence[ScheduleRule] | None, _1: Sequence[Postproc] | None, _2: Mapping[Mutator, FloatImm] | None, _3: Callable[[TuneContext], None], _4: Callable[[IRModule], Sequence[Schedule]], _5: Callable[[], SpaceGenerator], /) -> SpaceGenerator: ... + def SpaceGeneratorScheduleFn(_0: Callable[..., Any], _1: Sequence[ScheduleRule] | None, _2: Sequence[Postproc] | None, _3: Mapping[Mutator, FloatImm] | None, /) -> SpaceGenerator: ... + def SpaceGeneratorSpaceGeneratorUnion(_0: Sequence[SpaceGenerator], _1: Sequence[ScheduleRule] | None, _2: Sequence[Postproc] | None, _3: Mapping[Mutator, FloatImm] | None, /) -> SpaceGenerator: ... + def TaskSchedulerGradientBased(_0: Callable[..., Any], _1: float, _2: int, _3: int, /) -> TaskScheduler: ... + def TaskSchedulerJoinRunningTask(_0: TaskScheduler, _1: int, /) -> Sequence[RunnerResult]: ... + def TaskSchedulerNextTaskId(_0: TaskScheduler, /) -> int: ... + def TaskSchedulerPrintTuningStatistics(_0: TaskScheduler, /) -> None: ... + def TaskSchedulerPyTaskScheduler(_0: Callable[..., Any], _1: Callable[[], int], _2: Callable[[int], Sequence[RunnerResult]], _3: Callable[[Sequence[TuneContext], Sequence[FloatImm], int, int, int, Builder, Runner, Sequence[MeasureCallback], Database | None, CostModel | None], None], /) -> TaskScheduler: ... + def TaskSchedulerRoundRobin(_0: Callable[..., Any], /) -> TaskScheduler: ... + def TaskSchedulerTerminateTask(_0: TaskScheduler, _1: int, /) -> None: ... + def TaskSchedulerTouchTask(_0: TaskScheduler, _1: int, /) -> None: ... + def TaskSchedulerTune(_0: TaskScheduler, _1: Sequence[TuneContext], _2: Sequence[FloatImm], _3: int, _4: int, _5: int, _6: Builder, _7: Runner, _8: Sequence[MeasureCallback], _9: Database | None, _10: CostModel | None, /) -> None: ... + def TensorInfo(_0: dtype, _1: Shape, /) -> TensorInfo: ... + def TuneContext(_0: IRModule | None, _1: Target | None, _2: SpaceGenerator | None, _3: SearchStrategy | None, _4: str | None, _5: int, _6: int, _7: Callable[..., Any], /) -> TuneContext: ... + def TuneContextClone(_0: TuneContext, /) -> TuneContext: ... + def TuneContextInitialize(_0: TuneContext, /) -> None: ... + def TuningRecord(_0: Trace, _1: Workload, _2: Sequence[FloatImm] | None, _3: Target | None, _4: Sequence[ArgInfo] | None, /) -> TuningRecord: ... + def TuningRecordAsJSON(_0: TuningRecord, /) -> Object: ... + def TuningRecordAsMeasureCandidate(_0: TuningRecord, /) -> MeasureCandidate: ... + def TuningRecordFromJSON(_0: Object, _1: Workload, /) -> TuningRecord: ... + def Workload(_0: IRModule, /) -> Workload: ... + def WorkloadAsJSON(_0: Workload, /) -> Object: ... + def WorkloadFromJSON(_0: Object, /) -> Workload: ... + def _SHash2Hex(_0: Object, /) -> str: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/meta_schedule/arg_info.py b/python/tvm/meta_schedule/arg_info.py index 3f8d721ed1f0..415f165aa808 100644 --- a/python/tvm/meta_schedule/arg_info.py +++ b/python/tvm/meta_schedule/arg_info.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. """The argument information""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tvm_ffi import Shape, dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Any, List, Union from tvm_ffi import register_object @@ -30,6 +40,11 @@ class ArgInfo(Object): """Argument information""" + # tvm-ffi-stubgen(begin): object/meta_schedule.ArgInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def as_json(self) -> Any: """Converts the ArgInfo to its corresponding JSON representation.""" return _json_de_tvm(_ffi_api.ArgInfoAsJSON(self)) # type: ignore # pylint: disable=no-member @@ -97,6 +112,13 @@ class TensorInfo(ArgInfo): The shape of the tensor. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.TensorInfo + # fmt: off + dtype: dtype + shape: Shape + # fmt: on + # tvm-ffi-stubgen(end) + dtype: DataType shape: ShapeTuple diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index 39493781404a..d1024eac7468 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -15,6 +15,20 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule builders that translate IRModule to runtime.Module, and then export""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IRModule + from target import Target + from tvm_ffi import Tensor + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable, Dict, List, Optional, Union # isort: off @@ -43,6 +57,14 @@ class BuilderInput(Object): The parameters for Relax build module """ + # tvm-ffi-stubgen(begin): object/meta_schedule.BuilderInput + # fmt: off + mod: IRModule + target: Target + params: Mapping[str, Tensor] | None + # fmt: on + # tvm-ffi-stubgen(end) + mod: IRModule target: Target params: Optional[Dict[str, Tensor]] @@ -84,6 +106,13 @@ class BuilderResult(Object): The error message. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.BuilderResult + # fmt: off + artifact_path: str | None + error_msg: str | None + # fmt: on + # tvm-ffi-stubgen(end) + artifact_path: Optional[str] error_msg: Optional[str] @@ -112,6 +141,11 @@ def __init__( class Builder(Object): """The abstract builder interface.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Builder + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + BuilderType = Union["Builder", Literal["local"]] def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: @@ -165,6 +199,12 @@ class _PyBuilder(Builder): See also: PyBuilder """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyBuilder + # fmt: off + f_build: Callable[[Sequence[BuilderInput]], Sequence[BuilderResult]] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_build: Callable = None): """Constructor.""" diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index f51d2f2ac89b..129fc276acfa 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule CostModel.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import ctypes from typing import Callable, List, Union @@ -38,6 +40,11 @@ class CostModel(Object): """Cost model.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.CostModel + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + CostModelType = Union["CostModel", Literal["xgb", "mlp", "random"]] def load(self, path: str) -> None: @@ -159,6 +166,11 @@ class _PyCostModel(CostModel): See also: PyCostModel """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyCostModel + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_load: Callable = None, diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 08bcbd33c7ad..e8d7f4f36ef2 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -15,6 +15,20 @@ # specific language governing permissions and limitations # under the License. """TuningRecord database""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import FloatImm, IRModule + from meta_schedule import ArgInfo + from target import Target + from tir import Trace +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Any, Callable, List, Optional, Union # isort: off @@ -43,6 +57,12 @@ class Workload(Object): The workload's IRModule """ + # tvm-ffi-stubgen(begin): object/meta_schedule.Workload + # fmt: off + mod: IRModule + # fmt: on + # tvm-ffi-stubgen(end) + mod: IRModule def __init__(self, mod: IRModule) -> None: @@ -97,6 +117,16 @@ class TuningRecord(Object): The argument information of the tuning record. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.TuningRecord + # fmt: off + trace: Trace + workload: Workload + run_secs: Sequence[FloatImm] | None + target: Target | None + args_info: Sequence[ArgInfo] | None + # fmt: on + # tvm-ffi-stubgen(end) + trace: Trace workload: Workload run_secs: Optional[List[float]] @@ -164,6 +194,11 @@ def from_json(json_obj: Any, workload: Workload) -> "TuningRecord": class Database(Object): """The abstract database interface.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Database + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + DatabaseType = Union["Database", Literal["json", "memory"]] def has_workload(self, mod: IRModule) -> bool: @@ -436,6 +471,11 @@ class _PyDatabase(Database): See also: PyDatabase """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyDatabase + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_has_workload: Callable = None, diff --git a/python/tvm/meta_schedule/database/json_database.py b/python/tvm/meta_schedule/database/json_database.py index 7c6f7459cacc..333e7b65a3f9 100644 --- a/python/tvm/meta_schedule/database/json_database.py +++ b/python/tvm/meta_schedule/database/json_database.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """The default database that uses a JSON File to store tuning records""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import os.path as osp from typing import Optional @@ -46,6 +48,13 @@ class JSONDatabase(Database): For the definition of the anchor block, see tir/analysis/analysis.py. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.JSONDatabase + # fmt: off + path_workload: str + path_tuning_record: str + # fmt: on + # tvm-ffi-stubgen(end) + path_workload: str path_tuning_record: str diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index 1d6d4121231c..1fcbaec04758 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """A database that stores TuningRecords in memory""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import TuningRecord, Workload +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -39,6 +50,13 @@ class MemoryDatabase(Database): For the definition of the anchor block, see tir/analysis/analysis.py. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MemoryDatabase + # fmt: off + records: Sequence[TuningRecord] + workloads: Sequence[Workload] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, module_equality: str = "structural", diff --git a/python/tvm/meta_schedule/database/ordered_union_database.py b/python/tvm/meta_schedule/database/ordered_union_database.py index 717d2f3001c9..6284b2024a0d 100644 --- a/python/tvm/meta_schedule/database/ordered_union_database.py +++ b/python/tvm/meta_schedule/database/ordered_union_database.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """A database consists of multiple databases.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import Database +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -98,6 +109,12 @@ class OrderedUnionDatabase(Database): merged_db.query_tuning_record(..., target_workload) """ + # tvm-ffi-stubgen(begin): object/meta_schedule.OrderedUnionDatabase + # fmt: off + databases: Sequence[Database] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, *databases: Database) -> None: """Construct a merged database from multiple databases. diff --git a/python/tvm/meta_schedule/database/schedule_fn_database.py b/python/tvm/meta_schedule/database/schedule_fn_database.py index 74b2a6eb60da..3c85c99ba268 100644 --- a/python/tvm/meta_schedule/database/schedule_fn_database.py +++ b/python/tvm/meta_schedule/database/schedule_fn_database.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """A database for injecting handcrafted schedule functions.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tir import Schedule + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable from tvm_ffi import register_object @@ -45,6 +56,12 @@ class ScheduleFnDatabase(Database): For the definition of the anchor block, see tir/analysis/analysis.py. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ScheduleFnDatabase + # fmt: off + schedule_fn: Callable[[Schedule], bool] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, schedule_fn: Callable[[Schedule], bool], diff --git a/python/tvm/meta_schedule/database/union_database.py b/python/tvm/meta_schedule/database/union_database.py index 3a1afbe32adf..43724767dc77 100644 --- a/python/tvm/meta_schedule/database/union_database.py +++ b/python/tvm/meta_schedule/database/union_database.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """A database consists of multiple databases.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import Database +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -98,6 +109,12 @@ class UnionDatabase(Database): merged_db.query_tuning_record(..., target_workload) """ + # tvm-ffi-stubgen(begin): object/meta_schedule.UnionDatabase + # fmt: off + databases: Sequence[Database] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, *databases: Database) -> None: """Construct a merged database from multiple databases. diff --git a/python/tvm/meta_schedule/extracted_task.py b/python/tvm/meta_schedule/extracted_task.py index df66e774e595..beb179c2c502 100644 --- a/python/tvm/meta_schedule/extracted_task.py +++ b/python/tvm/meta_schedule/extracted_task.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Extracted tasks from high-level IR.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IRModule + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List from tvm_ffi import register_object @@ -43,6 +55,16 @@ class ExtractedTask(Object): The weight of the task """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ExtractedTask + # fmt: off + task_name: str + mod: IRModule + target: Target + dispatched: Sequence[IRModule] + weight: int + # fmt: on + # tvm-ffi-stubgen(end) + task_name: str mod: IRModule dispatched: List[IRModule] diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py index b50a22142943..5a5d9f61f041 100644 --- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule FeatureExtractor.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Callable, List, Union # isort: off @@ -36,6 +38,11 @@ class FeatureExtractor(Object): """Extractor for features from measure candidates for use in cost model.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.FeatureExtractor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + FeatureExtractorType = Union[Literal["per-store-feature"], "FeatureExtractor"] def extract_from( @@ -83,6 +90,11 @@ class _PyFeatureExtractor(FeatureExtractor): See also: PyFeatureExtractor """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyFeatureExtractor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_extract_from: Callable, f_as_string: Callable = None): """Constructor.""" diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py index 673a722955d2..db8c4d66bcb9 100644 --- a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py +++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py @@ -18,6 +18,8 @@ """We extract one feature vector per BufferStoreNode statement in a TIR Stmt, so we call this feature as "per-store" feature. """ +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -40,6 +42,16 @@ class PerStoreFeature(FeatureExtractor): Whether to extract features in the workload in tuning context or not. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PerStoreFeature + # fmt: off + buffers_per_store: int + arith_intensity_curve_num_samples: int + cache_line_bytes: int + extract_workload: bool + feature_vector_length: int + # fmt: on + # tvm-ffi-stubgen(end) + buffers_per_store: int """The number of buffers in each BufferStore; Pad or truncate if necessary.""" arith_intensity_curve_num_samples: int # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/measure_callback/add_to_database.py b/python/tvm/meta_schedule/measure_callback/add_to_database.py index e0a6f5a273fc..793dec6971ab 100644 --- a/python/tvm/meta_schedule/measure_callback/add_to_database.py +++ b/python/tvm/meta_schedule/measure_callback/add_to_database.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """A callback that adds the measurement results into the database""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -23,6 +25,10 @@ @register_object("meta_schedule.AddToDatabase") class AddToDatabase(MeasureCallback): + # tvm-ffi-stubgen(begin): object/meta_schedule.AddToDatabase + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self) -> None: """A callback that adds the measurement results into the database""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py index 885f70e88de8..5f7720a12e76 100644 --- a/python/tvm/meta_schedule/measure_callback/measure_callback.py +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -16,6 +16,8 @@ # under the License. """Meta Schedule MeasureCallback.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, List, Union # isort: off @@ -40,6 +42,11 @@ class MeasureCallback(Object): """Rules to apply after measure results is available.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MeasureCallback + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + CallbackListType = Union[List["MeasureCallback"], "MeasureCallback", Literal["default"]] def apply( @@ -91,6 +98,11 @@ class _PyMeasureCallback(MeasureCallback): See also: PyMeasureCallback """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyMeasureCallback + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_apply: Callable, f_as_string: Callable = None): """Constructor.""" diff --git a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py index 23808b7e99d7..ce8e1ef2688e 100644 --- a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py +++ b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """A callback that removes the build artifacts from the disk""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -23,6 +25,10 @@ @register_object("meta_schedule.RemoveBuildArtifact") class RemoveBuildArtifact(MeasureCallback): + # tvm-ffi-stubgen(begin): object/meta_schedule.RemoveBuildArtifact + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self) -> None: """A callback that removes the build artifacts from the disk""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/measure_callback/update_cost_model.py b/python/tvm/meta_schedule/measure_callback/update_cost_model.py index 7cf60c095b97..5fd69f45b099 100644 --- a/python/tvm/meta_schedule/measure_callback/update_cost_model.py +++ b/python/tvm/meta_schedule/measure_callback/update_cost_model.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """A measure callback that updates the cost model""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -23,6 +25,10 @@ @register_object("meta_schedule.UpdateCostModel") class UpdateCostModel(MeasureCallback): + # tvm-ffi-stubgen(begin): object/meta_schedule.UpdateCostModel + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self) -> None: """A measure callback that updates the cost model""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/mutator/mutate_compute_location.py b/python/tvm/meta_schedule/mutator/mutate_compute_location.py index 620e0062cbff..5a2d607d0f42 100644 --- a/python/tvm/meta_schedule/mutator/mutate_compute_location.py +++ b/python/tvm/meta_schedule/mutator/mutate_compute_location.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """A mutator that mutates the compute-at location decision of SampleComputeLocation""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -25,6 +27,11 @@ class MutateComputeLocation(Mutator): """A mutator that mutates the compute-at location decision of SampleComputeLocation""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateComputeLocation + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.MutatorMutateComputeLocation, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py index fc077cd0d4aa..24a11fd6f77c 100644 --- a/python/tvm/meta_schedule/mutator/mutate_parallel.py +++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the parallel extent""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -25,6 +27,12 @@ class MutateParallel(Mutator): """Mutator that mutates the parallel extent""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateParallel + # fmt: off + max_jobs_per_core: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, max_jobs_per_core: int) -> None: """Mutator that mutates the parallel extent""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py index 4c9fa44c50a0..56c74f41cb6e 100644 --- a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py +++ b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the thread binding extent""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -25,6 +27,11 @@ class MutateThreadBinding(Mutator): """Mutator that mutates the binding extent""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateThreadBinding + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: """Mutator that mutates the binding extent""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/mutator/mutate_tile_size.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py index f40894f5ba0f..801ccff4c0fb 100644 --- a/python/tvm/meta_schedule/mutator/mutate_tile_size.py +++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the decision of instruction Sample-Perfect-Tile""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -25,6 +27,11 @@ class MutateTileSize(Mutator): """Mutator that mutates the decision of instruction Sample-Perfect-Tile""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateTileSize + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.MutatorMutateTileSize, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/mutator/mutate_unroll.py b/python/tvm/meta_schedule/mutator/mutate_unroll.py index 97999c2888f8..2b9691d23cd5 100644 --- a/python/tvm/meta_schedule/mutator/mutate_unroll.py +++ b/python/tvm/meta_schedule/mutator/mutate_unroll.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates auto unroll step""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -25,6 +27,11 @@ class MutateUnroll(Mutator): """Mutator that mutates auto unroll step""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateUnroll + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.MutatorMutateUnroll, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index 211e2c2b5015..f29b03832005 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule Mutator.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, Dict, Optional # isort: off @@ -120,6 +122,11 @@ class _PyMutator(Mutator): See also: PyMutator """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyMutator + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_initialize_with_tune_context: Callable = None, diff --git a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py index 5c18475ea0ca..f4c9624f396c 100644 --- a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py +++ b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that checks if the IRModule has any strided memory copies""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class DisallowAsyncStridedMemCopy(Postproc): """A postprocessor that disallows schedules that use async strided mem copies.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.DisallowAsyncStridedMemCopy + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py index da604e42cc81..2541a00153d6 100644 --- a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py +++ b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that checks if the IRModule has any loop with non-constant extent""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class DisallowDynamicLoop(Postproc): """A postprocessor that checks if the IRModule has any loop with non-constant extent""" + # tvm-ffi-stubgen(begin): object/meta_schedule.DisallowDynamicLoop + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocDisallowDynamicLoop, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index 8e89ad2fe138..a4089929e677 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule Postproc.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, List # isort: off @@ -37,6 +39,11 @@ class Postproc(Object): """Rules to apply a postprocessor to a schedule.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Postproc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the postprocessor with a tune context. @@ -114,6 +121,11 @@ class _PyPostproc(Postproc): See also: PyPostproc """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyPostproc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_initialize_with_tune_context: Callable = None, diff --git a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py index d20c22d0f6d8..5f681a7c9a93 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py +++ b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py @@ -17,6 +17,8 @@ """A postprocessor that rewrites the cooperative fetch annotation to actual vectorized cooperative fetching in loop bindings.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -28,6 +30,11 @@ class RewriteCooperativeFetch(Postproc): cooperative fetching in loop bindings. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteCooperativeFetch + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteCooperativeFetch, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_layout.py b/python/tvm/meta_schedule/postproc/rewrite_layout.py index 73b6dde9f76a..f5e99fe435c2 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_layout.py +++ b/python/tvm/meta_schedule/postproc/rewrite_layout.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that rewrites the layout of input tensor""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -26,6 +28,11 @@ class RewriteLayout(Postproc): """A postprocessor that rewrites the layout of input tensor""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteLayout + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteLayout, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py index 30235517f9c6..3aa1c8c14ed7 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -17,6 +17,8 @@ """A postprocessor that applies parallelization, vectorization and auto unrolling according to the annotation of each block""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -27,6 +29,11 @@ class RewriteParallelVectorizeUnroll(Postproc): """A postprocessor that applies parallelization, vectorization and auto unrolling according to the annotation of each block""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteParallelVectorizeUnroll + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteParallelVectorizeUnroll, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py index 5bbe2b88381e..b748e9adb77f 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that rewrites reduction block by moving the init block out.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class RewriteReductionBlock(Postproc): """A postprocessor that rewrites reduction block by moving the init block out.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteReductionBlock + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py index 8f0edb869586..0a2d59fac9e0 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py +++ b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that tensorize related components.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -31,6 +33,11 @@ class RewriteTensorize(Postproc): Whether or not vectorize the initialization loop produced by DecomposeReduction """ + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteTensorize + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, vectorize_init_loop=False) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteTensorize, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py index b274c2f55c11..4820afab989b 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that adds thread binding to unbound blocks""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -26,6 +28,11 @@ class RewriteUnboundBlock(Postproc): """A postprocessor that adds thread binding to unbound blocks""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteUnboundBlock + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, max_threadblocks: int = 256) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteUnboundBlock, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py index 48fbe8f4b14c..c42f2a8ae44f 100644 --- a/python/tvm/meta_schedule/postproc/verify_gpu_code.py +++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that verifies if the GPU code is correct""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class VerifyGPUCode(Postproc): """A postprocessor that verifies if the GPU code is correct""" + # tvm-ffi-stubgen(begin): object/meta_schedule.VerifyGPUCode + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocVerifyGPUCode, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py index 96ece2270bbc..745c6f684d96 100644 --- a/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py +++ b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that verifies the VTCM usage of a given schedule.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class VerifyVTCMLimit(Postproc): """Verifies that the VTCM usage of a given schedule is within the provided limit.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.VerifyVTCMLimit + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocVerifyVTCMLimit, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py index 1a41f589de4c..cd5c6564414c 100644 --- a/python/tvm/meta_schedule/profiler.py +++ b/python/tvm/meta_schedule/profiler.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=used-before-assignment """A context manager that profiles tuning time cost for different parts.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from contextlib import contextmanager from typing import Dict, Optional @@ -29,6 +31,11 @@ class Profiler(Object): """Tuning time profiler.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Profiler + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.Profiler, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 0d7cd32bd7a5..a829a24d8e44 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Runners""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import FloatImm + from meta_schedule import ArgInfo +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable, List, Optional, Union # isort: off @@ -43,6 +55,14 @@ class RunnerInput(Object): The argument information. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.RunnerInput + # fmt: off + artifact_path: str + device_type: str + args_info: Sequence[ArgInfo] + # fmt: on + # tvm-ffi-stubgen(end) + artifact_path: str device_type: str args_info: List[ArgInfo] @@ -84,6 +104,13 @@ class RunnerResult(Object): The error message, if any. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.RunnerResult + # fmt: off + run_secs: Sequence[FloatImm] | None + error_msg: str | None + # fmt: on + # tvm-ffi-stubgen(end) + run_secs: Optional[List[float]] error_msg: Optional[str] @@ -118,6 +145,11 @@ class RunnerFuture(Object): See also: PyRunnerFuture """ + # tvm-ffi-stubgen(begin): object/meta_schedule.RunnerFuture + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_done: Callable, f_result: Callable = None) -> None: """Constructor""" @@ -167,6 +199,11 @@ def result(self) -> RunnerResult: class Runner(Object): """The abstract runner interface""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Runner + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + RunnerType = Union["Runner", Literal["local", "rpc"]] def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: @@ -214,6 +251,11 @@ class _PyRunner(Runner): See also: PyRunner """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyRunner + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_run: Callable = None) -> None: """Constructor""" diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py index 2bef40fffe74..3812d6f0290f 100644 --- a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py +++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Add-rfactor Rule that add-rfactor to some blocks if needed""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Optional from tvm_ffi import register_object @@ -37,6 +39,13 @@ class AddRFactor(ScheduleRule): The maximum size of the innermost factor. None means no limit. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.AddRFactor + # fmt: off + max_jobs_per_core: int + max_innermost_factor: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, max_jobs_per_core: int = 16, diff --git a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py index 2e383c75eb91..fde1d249b777 100644 --- a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py @@ -16,6 +16,16 @@ # under the License. """Create a rule that applies customized rules registered using block attribute `schedule_rule`. The rule will be dispatched according to target keys.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -27,6 +37,12 @@ class ApplyCustomRule(ScheduleRule): """A rule that applies customized rules registered using block attribute `schedule_rule`. The rule will be dispatched according to target keys.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.ApplyCustomRule + # fmt: off + target_: Target | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.ScheduleRuleApplyCustomRule, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/schedule_rule/auto_bind.py b/python/tvm/meta_schedule/schedule_rule/auto_bind.py index 0704b03f740f..05d794a249f4 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_bind.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_bind.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Auto-bind Rule that binds blocks to threads if needed""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import List, Optional from tvm_ffi import register_object @@ -37,6 +39,11 @@ class AutoBind(ScheduleRule): The maximum number of threads per block, if it is known when this schedule rule is created. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.AutoBind + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, max_threadblocks: int = 256, diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py index b789dd750707..cb0f2d42a6a2 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_inline.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import Op +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Optional from tvm_ffi import register_object @@ -45,6 +56,18 @@ class AutoInline(ScheduleRule): The operators that are disallowed in auto inline """ + # tvm-ffi-stubgen(begin): object/meta_schedule.AutoInline + # fmt: off + into_producer: bool + into_consumer: bool + inline_const_tensor: bool + disallow_if_then_else: bool + require_injective: bool + require_ordered: bool + disallow_op: Sequence[Op] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, into_producer: bool, @@ -76,6 +99,11 @@ class InlineConstantScalars(ScheduleRule): InlineConstantScalars before AutoInline. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.InlineConstantScalars + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, ) -> None: diff --git a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py index 0c79d4f08bac..7370977652e9 100644 --- a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py +++ b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Rules which apply cross-thread reduction to some reduction blocks correspondingly when needed""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List from tvm_ffi import register_object @@ -34,6 +45,14 @@ class CrossThreadReduction(ScheduleRule): Candidates of thread axis extent (values are required to be positive). """ + # tvm-ffi-stubgen(begin): object/meta_schedule.CrossThreadReduction + # fmt: off + max_threads_per_block: int + warp_size: int + thread_extents: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, thread_extents: List[int]) -> None: self.__init_handle_by_constructor__( _ffi_api.ScheduleRuleCrossThreadReduction, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 41c97a7862b4..c0fe81eb456f 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. """Multi-level tiling with reuse.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Callable from tvm.tir.schedule import Schedule, BlockRV @@ -70,6 +80,14 @@ class MultiLevelTiling(ScheduleRule): for a block that should be tiled (based on the block name, for example). """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MultiLevelTiling + # fmt: off + structure: str + tile_binds: Sequence[str] + max_innermost_factor: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, structure: str, @@ -119,6 +137,11 @@ class MultiLevelTilingWithIntrin(ScheduleRule): Data reuse configuration for writing. None means no reuse. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MultiLevelTilingWithIntrin + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, intrin_name: str, @@ -173,6 +196,11 @@ class MultiLevelTilingTensorCore(ScheduleRule): Whether to use the software pipeline. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MultiLevelTilingTensorCore + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, intrin_groups: List[Mapping[str, str]], @@ -216,6 +244,11 @@ class MultiLevelTilingWideVector(ScheduleRule): Data reuse configuration for writing. None means no reuse. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MultiLevelTilingWideVector + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, structure: str, diff --git a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py index 259620b3f715..7c459173e8c6 100644 --- a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py +++ b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py @@ -16,6 +16,17 @@ # under the License. """Rule that mark parallelize, vectorize and unroll to the root block. The mark will be applied to each block in a follow-up post processor""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Optional from tvm_ffi import register_object @@ -46,6 +57,15 @@ class ParallelizeVectorizeUnroll(ScheduleRule): Whether to explicitly unroll the loop, or just add an "unroll" pragma """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ParallelizeVectorizeUnroll + # fmt: off + max_jobs_per_core: int + max_vectorize_extent: int + unroll_max_steps: Sequence[IntImm] + unroll_explicit: bool + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, max_jobs_per_core: int = 16, diff --git a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py index 8f1c96f6eb0a..d9916db594a5 100644 --- a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py +++ b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Rule that randomly select a compute-at location for a free block""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -25,6 +27,11 @@ class RandomComputeLocation(ScheduleRule): """A rule that randomly select a compute-at location for a free block""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RandomComputeLocation + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 98c81e5b8f30..39903ed4969f 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -18,6 +18,8 @@ Meta Schedule schedule rules are used for modification of blocks in a schedule. See also PostOrderApply. """ +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, List # isort: off @@ -40,6 +42,11 @@ class ScheduleRule(Object): """Rules to modify a block in a schedule.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.ScheduleRule + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the schedule rule with a tune context. @@ -121,6 +128,11 @@ class _PyScheduleRule(ScheduleRule): See also: PyScheduleRule """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyScheduleRule + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_initialize_with_tune_context: Callable = None, diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index 04f9310e6e0d..c816c2a5a690 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Evolutionary Search Strategy""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -47,6 +49,20 @@ class EvolutionarySearch(SearchStrategy): The ratio of greedy selected samples in the final picks. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.EvolutionarySearch + # fmt: off + population_size: int + num_empty_iters_before_early_stop: int + init_measured_ratio: float + init_min_unmeasured: int + max_fail_count: int + genetic_num_iters: int + genetic_mutate_prob: float + genetic_max_fail_count: int + eps_greedy: float + # fmt: on + # tvm-ffi-stubgen(end) + population_size: int init_measured_ratio: int init_min_unmeasured: int diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py index 682c9638c513..0ced919f671b 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_func.py +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -35,6 +37,11 @@ class ReplayFunc(SearchStrategy): Total number of trials for one task """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ReplayFunc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): """Constructor""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index e04a440da68a..a29643d1d26b 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -33,6 +35,12 @@ class ReplayTrace(SearchStrategy): Max number of failures during trace replaying. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ReplayTrace + # fmt: off + max_fail_count: int + # fmt: on + # tvm-ffi-stubgen(end) + max_fail_count: int def __init__(self, max_fail_count: int = 100): diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index cfb45dafdeb2..958bb98df17e 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -18,6 +18,18 @@ Meta Schedule search strategy that generates the measure candidates for measurement. """ +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import ArgInfo + from tir import Schedule +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, List, Optional, Union # isort: off @@ -50,6 +62,13 @@ class MeasureCandidate(Object): The argument information. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MeasureCandidate + # fmt: off + sch: Schedule + args_info: Sequence[ArgInfo] + # fmt: on + # tvm-ffi-stubgen(end) + sch: Schedule args_info: List[ArgInfo] @@ -78,6 +97,11 @@ def __init__( class SearchStrategy(Object): """Search strategy is the class that generates the measure candidates.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.SearchStrategy + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + SearchStrategyType = Union[ "SearchStrategy", Literal[ @@ -228,6 +252,11 @@ class _PySearchStrategy(SearchStrategy): See also: PySearchStrategy """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PySearchStrategy + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_initialize_with_tune_context: Callable = None, diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 45b81bdf3e59..9ab087270e98 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Post Order Apply Space Generator.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -42,6 +44,11 @@ class PostOrderApply(SpaceGenerator): all blocks will have schedules generated. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PostOrderApply + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_block_filter=None, diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py index d01cd7fdcbd1..85dbcbaf0d0a 100644 --- a/python/tvm/meta_schedule/space_generator/schedule_fn.py +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Union of meta Schedule design space generators.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -36,6 +38,11 @@ class ScheduleFn(SpaceGenerator): - 3) [Schedule] -> List[Schedule] """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ScheduleFn + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, sch_fn: SpaceGenerator.ScheduleFnType, diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 35f9e2236764..852b26a32ee2 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -18,6 +18,18 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. """ +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import FloatImm + from meta_schedule import Mutator, Postproc, ScheduleRule +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union # isort: off @@ -42,6 +54,14 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.SpaceGenerator + # fmt: off + sch_rules: Sequence[ScheduleRule] | None + postprocs: Sequence[Postproc] | None + mutator_probs: Mapping[Mutator, FloatImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + ScheduleFnType = Union[ Callable[[Schedule], None], # No output Callable[[Schedule], Schedule], # Single output @@ -193,6 +213,11 @@ class _PySpaceGenerator(SpaceGenerator): See also: PySpaceGenerator """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PySpaceGenerator + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, sch_rules: ScheduleRuleType = "from-target", diff --git a/python/tvm/meta_schedule/space_generator/space_generator_union.py b/python/tvm/meta_schedule/space_generator/space_generator_union.py index 0b8ceb453116..24d4c7ceb612 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator_union.py +++ b/python/tvm/meta_schedule/space_generator/space_generator_union.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Union of meta Schedule design space generators.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import SpaceGenerator +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List from tvm_ffi import register_object @@ -33,6 +44,12 @@ class SpaceGeneratorUnion(SpaceGenerator): """Union of design space generators.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.SpaceGeneratorUnion + # fmt: off + space_generators: Sequence[SpaceGenerator] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, space_generators: List[SpaceGenerator], diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py index 18d7e2be614a..008a4dfffb87 100644 --- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Gradient Based Task Scheduler""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -28,6 +30,13 @@ class GradientBased(TaskScheduler): """Gradient Based Task Scheduler""" + # tvm-ffi-stubgen(begin): object/meta_schedule.GradientBased + # fmt: off + alpha: float + window_size: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, *, diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 78504608f9ab..96ee6824425d 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Round Robin Task Scheduler""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -28,6 +30,12 @@ class RoundRobin(TaskScheduler): """Round Robin Task Scheduler""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RoundRobin + # fmt: off + task_id: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: """Constructor.""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index 4513f6081560..56f78c6b4ed8 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Auto-tuning Task Scheduler""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import BuilderResult, CostModel, Database, MeasureCallback, MeasureCandidate, RunnerFuture, TuneContext +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable, List, Optional, Union # isort: off @@ -42,6 +53,20 @@ class TaskRecord(Object): """The running record of a task.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.TaskRecord + # fmt: off + ctx: TuneContext + task_weight: float + flop: float + is_terminated: bool + build_error_count: int + run_error_count: int + measure_candidates: Sequence[MeasureCandidate] | None + builder_results: Sequence[BuilderResult] | None + runner_futures: Sequence[RunnerFuture] | None + # fmt: on + # tvm-ffi-stubgen(end) + ctx: TuneContext task_weight: float flop: float @@ -57,6 +82,16 @@ class TaskRecord(Object): class TaskScheduler(Object): """The abstract task scheduler interface.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.TaskScheduler + # fmt: off + tasks_: Sequence[TaskRecord] + measure_callbacks_: Sequence[MeasureCallback] + database_: Database | None + cost_model_: CostModel | None + remaining_tasks_: int + # fmt: on + # tvm-ffi-stubgen(end) + tasks_: List[TaskRecord] measure_callbacks_: List[MeasureCallback] database_: Optional[Database] @@ -198,6 +233,11 @@ class _PyTaskScheduler(TaskScheduler): See also: PyTaskScheduler """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyTaskScheduler + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_next_task_id: Callable, diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 35a8d468a75c..a6816445aaa0 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,6 +16,18 @@ # under the License. """Meta Schedule tuning context.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import IRModule + from meta_schedule import SearchStrategy, SpaceGenerator + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, List, Optional, Union # isort: off @@ -84,6 +96,18 @@ class TuneContext(Object): The number of threads to be used, None means using the logical cpu count. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.TuneContext + # fmt: off + mod: IRModule | None + target: Target | None + space_generator: SpaceGenerator | None + search_strategy: SearchStrategy | None + task_name: str | None + num_threads: int + rand_state: int + # fmt: on + # tvm-ffi-stubgen(end) + mod: Optional[IRModule] target: Optional[Target] space_generator: Optional["SpaceGenerator"] diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py index c5e98a22eaaf..89df5defbe5d 100644 --- a/python/tvm/relax/_ffi_api.py +++ b/python/tvm/relax/_ffi_api.py @@ -15,6 +15,143 @@ # specific language governing permissions and limitations # under the License. """FFI API for Relax.""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from expr_functor import PyExprMutator, PyExprVisitor + from ir import Attrs, BaseFunc, DictAttrs, EnvFunc, GlobalVar, IRModule, IntImm, PrimExpr, RelaxExpr, Span, StructInfo, VDevice + from relax import BlockBuilder, DataflowBlockRewrite, DynTensorType, ExecBuilder, FuncStructInfo, Id, ObjectStructInfo, ObjectType, PackedFuncType, PrimStructInfo, ShapeStructInfo, ShapeType, TensorStructInfo, TupleStructInfo + from relax.expr import Binding, BindingBlock, Call, Constant, DataTypeImm, DataflowBlock, DataflowVar, ExternFunc, Function, If, MatchCast, PrimValue, SeqExpr, ShapeExpr, StringImm, Tuple, TupleGetItem, Var, VarBinding + from target import Target + from te import Tensor + from tir import Var + from tvm_ffi import Module, Object, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax", __name__) + +# tvm-ffi-stubgen(begin): global/relax +# fmt: off +_FFI_INIT_FUNC("relax", __name__) +if TYPE_CHECKING: + def BindingBlock(_0: Sequence[Binding], _1: Span, /) -> BindingBlock: ... + def BlockBuilderAddFunction(_0: BlockBuilder, _1: BaseFunc, _2: str, /) -> GlobalVar: ... + def BlockBuilderBeginBindingBlock(_0: BlockBuilder, /) -> None: ... + def BlockBuilderBeginDataflowBlock(_0: BlockBuilder, /) -> None: ... + def BlockBuilderBeginScope(_0: BlockBuilder, _1: Sequence[Var] | None, /) -> None: ... + def BlockBuilderCreate(_0: IRModule | None, /) -> BlockBuilder: ... + def BlockBuilderCurrentBlockIsDataFlow(_0: BlockBuilder, /) -> bool: ... + def BlockBuilderEmit(_0: BlockBuilder, _1: RelaxExpr, _2: str, /) -> Var: ... + def BlockBuilderEmitMatchCast(_0: BlockBuilder, _1: RelaxExpr, _2: StructInfo, _3: str, /) -> Var: ... + def BlockBuilderEmitNormalized(_0: BlockBuilder, _1: Binding, /) -> None: ... + def BlockBuilderEmitOutput(_0: BlockBuilder, _1: RelaxExpr, _2: str, /) -> Var: ... + def BlockBuilderEndBlock(_0: BlockBuilder, /) -> BindingBlock: ... + def BlockBuilderEndScope(_0: BlockBuilder, /) -> None: ... + def BlockBuilderFinalize(_0: BlockBuilder, /) -> IRModule: ... + def BlockBuilderGetContextIRModule(_0: BlockBuilder, /) -> IRModule: ... + def BlockBuilderGetUniqueName(_0: BlockBuilder, _1: str, /) -> str: ... + def BlockBuilderLookupBinding(_0: BlockBuilder, _1: Var, /) -> RelaxExpr | None: ... + def BlockBuilderNormalize(_0: BlockBuilder, _1: RelaxExpr, /) -> RelaxExpr: ... + def BlockBuilderUpdateFunction(_0: BlockBuilder, _1: GlobalVar, _2: BaseFunc, /) -> None: ... + def Call(_0: RelaxExpr, _1: Sequence[RelaxExpr], _2: Attrs, _3: Sequence[StructInfo], _4: Span, /) -> Call: ... + def Constant(_0: Tensor, _1: StructInfo | None, _2: Span, /) -> Constant: ... + def CopyWithNewVars(_0: Function, /) -> Function: ... + def DataTypeImm(_0: dtype, _1: Span, /) -> DataTypeImm: ... + def DataflowBlock(_0: Sequence[Binding], _1: Span, /) -> DataflowBlock: ... + def DataflowBlockRewrite(_0: DataflowBlock, _1: Function, /) -> DataflowBlockRewrite: ... + def DataflowVar(_0: str, _1: StructInfo | None, _2: Span, /) -> DataflowVar: ... + def DataflowVarFromId(_0: Id, _1: StructInfo | None, _2: Span, /) -> DataflowVar: ... + def ExecBuilderC(_0: ExecBuilder, _1: int, /) -> int: ... + def ExecBuilderConvertConstant(*args: Any) -> Any: ... + def ExecBuilderCreate() -> ExecBuilder: ... + def ExecBuilderDeclareFunction(_0: ExecBuilder, _1: str, _2: int, /) -> None: ... + def ExecBuilderEmitCall(_0: ExecBuilder, _1: str, _2: Sequence[IntImm], _3: int, /) -> None: ... + def ExecBuilderEmitFunction(_0: ExecBuilder, _1: str, _2: int, _3: Sequence[str] | None, /) -> None: ... + def ExecBuilderEmitGoto(_0: ExecBuilder, _1: int, /) -> None: ... + def ExecBuilderEmitIf(_0: ExecBuilder, _1: int, _2: int, /) -> None: ... + def ExecBuilderEmitRet(_0: ExecBuilder, _1: int, /) -> None: ... + def ExecBuilderEndFunction(_0: ExecBuilder, _1: str, /) -> None: ... + def ExecBuilderF(_0: ExecBuilder, _1: str, /) -> int: ... + def ExecBuilderGet(_0: ExecBuilder, /) -> Module: ... + def ExecBuilderImm(_0: ExecBuilder, _1: int, /) -> int: ... + def ExecBuilderR(_0: ExecBuilder, _1: int, /) -> int: ... + def ExecutableLoadFromFile(_0: str, /) -> Module: ... + def ExprMutatorVisitBinding(_0: PyExprMutator, _1: Binding, /) -> None: ... + def ExprMutatorVisitBindingBlock(_0: PyExprMutator, _1: BindingBlock, /) -> BindingBlock: ... + def ExprMutatorVisitExpr(_0: PyExprMutator, _1: RelaxExpr, /) -> RelaxExpr: ... + def ExprMutatorVisitVarDef(_0: PyExprMutator, _1: Var, /) -> Var: ... + def ExprVisitorVisitBinding(_0: PyExprVisitor, _1: Binding, /) -> None: ... + def ExprVisitorVisitBindingBlock(_0: PyExprVisitor, _1: BindingBlock, /) -> None: ... + def ExprVisitorVisitExpr(_0: PyExprVisitor, _1: RelaxExpr, /) -> None: ... + def ExprVisitorVisitSpan(_0: PyExprVisitor, _1: Span, /) -> None: ... + def ExprVisitorVisitVarDef(_0: PyExprVisitor, _1: Var, /) -> None: ... + def ExternFunc(_0: str, _1: StructInfo | None, _2: Span, /) -> ExternFunc: ... + def FuncStructInfo(_0: Sequence[StructInfo], _1: StructInfo, _2: bool, _3: Span, /) -> FuncStructInfo: ... + def FuncStructInfoOpaqueFunc(_0: StructInfo | None, _1: EnvFunc | None, _2: bool, _3: Span, /) -> FuncStructInfo: ... + def FuncWithAttr(_0: BaseFunc, _1: str, _2: Object, /) -> Function | None: ... + def FuncWithAttrs(_0: BaseFunc, _1: Mapping[str, Any], /) -> Function | None: ... + def FuncWithoutAttr(_0: BaseFunc, _1: str, /) -> Function | None: ... + def Function(_0: Sequence[Var], _1: RelaxExpr, _2: StructInfo | None, _3: bool, _4: DictAttrs, _5: Span, /) -> Function: ... + def FunctionBindParams(_0: Function, _1: Mapping[Any, Object], /) -> Function: ... + def FunctionBindSymbolicVars(_0: Function, _1: Mapping[Var | str, PrimExpr], /) -> Function: ... + def FunctionCreateEmpty(_0: Sequence[Var], _1: StructInfo, _2: bool, _3: DictAttrs, _4: Span, /) -> Function: ... + def FunctionInlineFunctions(_0: Function, _1: Mapping[str | GlobalVar, Function], /) -> Function: ... + def GetShapeOf(_0: RelaxExpr, /) -> RelaxExpr: ... + def If(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: Span, /) -> If: ... + def MakePyExprMutator(_0: BlockBuilder, _1: Callable[..., Any], _2: Callable[..., Any], _3: Callable[..., Any], _4: Callable[..., Any], _5: Callable[..., Any], _6: Callable[..., Any], _7: Callable[..., Any], _8: Callable[..., Any], _9: Callable[..., Any], _10: Callable[..., Any], _11: Callable[..., Any], _12: Callable[..., Any], _13: Callable[..., Any], _14: Callable[..., Any], _15: Callable[..., Any], _16: Callable[..., Any], _17: Callable[..., Any], _18: Callable[..., Any], _19: Callable[..., Any], _20: Callable[..., Any], _21: Callable[..., Any], _22: Callable[..., Any], _23: Callable[..., Any], _24: Callable[..., Any], _25: Callable[..., Any], _26: Callable[..., Any], _27: Callable[..., Any], /) -> PyExprMutator: ... + def MakePyExprVisitor(_0: Callable[..., Any], _1: Callable[..., Any], _2: Callable[..., Any], _3: Callable[..., Any], _4: Callable[..., Any], _5: Callable[..., Any], _6: Callable[..., Any], _7: Callable[..., Any], _8: Callable[..., Any], _9: Callable[..., Any], _10: Callable[..., Any], _11: Callable[..., Any], _12: Callable[..., Any], _13: Callable[..., Any], _14: Callable[..., Any], _15: Callable[..., Any], _16: Callable[..., Any], _17: Callable[..., Any], _18: Callable[..., Any], _19: Callable[..., Any], _20: Callable[..., Any], _21: Callable[..., Any], _22: Callable[..., Any], _23: Callable[..., Any], _24: Callable[..., Any], _25: Callable[..., Any], _26: Callable[..., Any], /) -> PyExprVisitor: ... + def MatchCast(_0: Var, _1: RelaxExpr, _2: StructInfo, _3: Span, /) -> MatchCast: ... + def ObjectStructInfo(_0: Span, /) -> ObjectStructInfo: ... + def ObjectType(_0: Span, /) -> ObjectType: ... + def PackedFuncType(_0: Span, /) -> PackedFuncType: ... + def PrimStructInfoFromDtype(_0: dtype, _1: Span, /) -> PrimStructInfo: ... + def PrimStructInfoFromValue(_0: PrimExpr, _1: Span, /) -> PrimStructInfo: ... + def PrimValue(_0: PrimExpr, _1: Span, /) -> PrimValue: ... + def PyExprMutatorGetVarRemap(_0: PyExprMutator, _1: Id, /) -> Var: ... + def PyExprMutatorLookupBinding(_0: PyExprMutator, _1: Var, /) -> RelaxExpr | None: ... + def PyExprMutatorSetVarRemap(_0: PyExprMutator, _1: Id, _2: Var, /) -> Var: ... + def PyExprMutatorVisitBinding(_0: PyExprMutator, _1: Binding, /) -> None: ... + def PyExprMutatorVisitBindingBlock(_0: PyExprMutator, _1: BindingBlock, /) -> BindingBlock: ... + def PyExprMutatorVisitExpr(_0: PyExprMutator, _1: RelaxExpr, /) -> RelaxExpr: ... + def PyExprMutatorVisitExprPostOrder(_0: PyExprMutator, _1: RelaxExpr, /) -> RelaxExpr: ... + def PyExprMutatorVisitVarDef(_0: PyExprMutator, _1: Var, /) -> Var: ... + def PyExprMutatorVisitWithNewScope(_0: PyExprMutator, _1: RelaxExpr, /) -> RelaxExpr: ... + def PyExprMutatorWithStructInfo(_0: PyExprMutator, _1: Var, _2: StructInfo, /) -> Var: ... + def PyExprVisitorVisitBinding(_0: PyExprVisitor, _1: Binding, /) -> None: ... + def PyExprVisitorVisitBindingBlock(_0: PyExprVisitor, _1: BindingBlock, /) -> None: ... + def PyExprVisitorVisitExpr(_0: PyExprVisitor, _1: RelaxExpr, /) -> None: ... + def PyExprVisitorVisitVarDef(_0: PyExprVisitor, _1: Var, /) -> None: ... + def SeqExpr(_0: Sequence[BindingBlock], _1: RelaxExpr, _2: Span, /) -> SeqExpr: ... + def ShapeExpr(_0: Sequence[PrimExpr], _1: Span, /) -> ShapeExpr: ... + def ShapeStructInfo(_0: Sequence[PrimExpr] | None, _1: int, _2: Span, /) -> ShapeStructInfo: ... + def ShapeType(_0: int, _1: Span, /) -> ShapeType: ... + def StringImm(_0: str, _1: Span, /) -> StringImm: ... + def StructInfoIsBaseOf(_0: StructInfo, _1: StructInfo, /) -> bool: ... + def TETensor(_0: RelaxExpr, _1: Mapping[Var, PrimExpr], _2: str, /) -> Tensor: ... + def TensorStructInfo(_0: RelaxExpr | None, _1: dtype | None, _2: int, _3: VDevice, _4: Span, /) -> TensorStructInfo: ... + def TensorType(_0: int, _1: dtype, _2: Span, /) -> DynTensorType: ... + def Tuple(_0: Sequence[RelaxExpr], _1: Span, /) -> Tuple: ... + def TupleGetItem(_0: RelaxExpr, _1: int, _2: Span, /) -> TupleGetItem: ... + def TupleStructInfo(_0: Sequence[StructInfo], _1: Span, /) -> TupleStructInfo: ... + def UpdateStructInfo(_0: RelaxExpr, _1: StructInfo, /) -> None: ... + def VMCodeGen(_0: ExecBuilder, _1: IRModule, /) -> IRModule: ... + def VMLink(_0: ExecBuilder, _1: Target, _2: Module | None, _3: Sequence[Module], _4: Mapping[str, Tensor], /) -> Module: ... + def VMTIRCodeGen(_0: ExecBuilder, _1: IRModule, /) -> IRModule: ... + def Var(_0: str, _1: StructInfo | None, _2: Span, /) -> Var: ... + def VarBinding(_0: Var, _1: RelaxExpr, _2: Span, /) -> VarBinding: ... + def VarFromId(_0: Id, _1: StructInfo | None, _2: Span, /) -> Var: ... + def dfb_rewrite_add(_0: DataflowBlockRewrite, _1: RelaxExpr, _2: str | None, _3: bool, /) -> None: ... + def dfb_rewrite_add_binding(_0: DataflowBlockRewrite, _1: Binding, /) -> None: ... + def dfb_rewrite_mutate_irmodule(_0: DataflowBlockRewrite, _1: IRModule, /) -> IRModule: ... + def dfb_rewrite_remove_all_unused(_0: DataflowBlockRewrite, /) -> None: ... + def dfb_rewrite_remove_unused(_0: DataflowBlockRewrite, _1: Var, _2: bool, /) -> None: ... + def dfb_rewrite_replace_all_uses(_0: DataflowBlockRewrite, _1: Var, _2: Var, /) -> None: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py index 0a230fbd8bb6..3c570e4b3fde 100644 --- a/python/tvm/relax/analysis/_ffi_api.py +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -14,6 +14,53 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import GlobalVar, IRModule, PrimExpr, RelaxExpr, StructInfo, Type + from relax import BlockBuilder, FuncStructInfo + from relax.expr import Binding, Call, DataflowBlock, Function, Var + from tir import Block, IndexMap, PrimFunc, Var + from tvm_ffi import Object + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.analysis", __name__) + +# tvm-ffi-stubgen(begin): global/relax.analysis +# fmt: off +_FFI_INIT_FUNC("relax.analysis", __name__) +if TYPE_CHECKING: + def CollectNonNegativeExpressions(_0: StructInfo, /) -> Sequence[PrimExpr]: ... + def DefinableTIRVarsInStructInfo(_0: StructInfo, /) -> Sequence[Var]: ... + def DefinedSymbolicVars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def DeriveCallRetStructInfo(_0: FuncStructInfo, _1: Call, _2: BlockBuilder, /) -> StructInfo: ... + def EraseToWellDefined(_0: StructInfo, _1: Mapping[Var, PrimExpr], _2: Mapping[Var, RelaxExpr], /) -> StructInfo: ... + def FreeSymbolicVars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def GetStaticType(_0: StructInfo, /) -> Type: ... + def StructInfoBaseCheck(_0: StructInfo, _1: StructInfo, /) -> int: ... + def StructInfoLCA(_0: StructInfo, _1: StructInfo, /) -> StructInfo: ... + def TIRVarsInStructInfo(_0: StructInfo, /) -> Sequence[Var]: ... + def all_global_vars(_0: RelaxExpr, /) -> Sequence[GlobalVar]: ... + def all_vars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def bound_vars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def computable_at_compile_time(_0: Function, /) -> Sequence[Var]: ... + def contains_impure_call(_0: RelaxExpr, _1: RelaxExpr | None, /) -> bool: ... + def detect_recursion(_0: IRModule, /) -> Sequence[Sequence[GlobalVar]]: ... + def free_vars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def get_var2val(_0: Function, /) -> Mapping[Var, RelaxExpr]: ... + def has_reshape_pattern(_0: PrimFunc, /) -> bool: ... + def name_to_binding(_0: Function, /) -> Mapping[str, Sequence[Binding]]: ... + def post_order_visit(_0: RelaxExpr, _1: Callable[..., Any], /) -> None: ... + def remove_all_unused(_0: RelaxExpr, /) -> RelaxExpr: ... + def suggest_layout_transforms(_0: PrimFunc, _1: Sequence[IndexMap], /) -> Mapping[Block, Mapping[Object, IndexMap]]: ... + def udchain(_0: DataflowBlock, /) -> Mapping[Var, Sequence[Var]]: ... + def well_formed(_0: IRModule | Function, _1: bool, /) -> bool: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py index 97a999788b93..6fcf84e3e644 100644 --- a/python/tvm/relax/backend/_ffi_api.py +++ b/python/tvm/relax/backend/_ffi_api.py @@ -15,7 +15,31 @@ # specific language governing permissions and limitations # under the License. """FFI API for Relax backend.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IRModule + from meta_schedule import ExtractedTask + from relax.transform import FusionPattern + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("relax.backend", __name__) +# tvm-ffi-stubgen(begin): global/relax.backend +# fmt: off +_FFI_INIT_FUNC("relax.backend", __name__) +if TYPE_CHECKING: + def GetPattern(_0: str, /) -> FusionPattern | None: ... + def GetPatternsWithPrefix(_0: str, /) -> Sequence[FusionPattern]: ... + def MetaScheduleExtractTask(_0: IRModule, _1: Target, _2: str, /) -> Sequence[ExtractedTask]: ... + def RegisterPatterns(_0: Sequence[FusionPattern], /) -> None: ... + def RemovePatterns(_0: Sequence[str], /) -> None: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/backend/adreno/transform/_ffi_api.py b/python/tvm/relax/backend/adreno/transform/_ffi_api.py index d665ba02a70e..1cd2b63d8f3b 100644 --- a/python/tvm/relax/backend/adreno/transform/_ffi_api.py +++ b/python/tvm/relax/backend/adreno/transform/_ffi_api.py @@ -14,6 +14,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for Adreno transform""" -import tvm.ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from target import Target + from transform import Pass +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm.ffi.init_ffi_api("relax.backend.adreno.transform", __name__) + +# tvm-ffi-stubgen(begin): global/relax.backend.adreno.transform +# fmt: off +_FFI_INIT_FUNC("relax.backend.adreno.transform", __name__) +if TYPE_CHECKING: + def AnnotateCustomMemoryScope(_0: Target, /) -> Pass: ... + def FoldVDeviceScopeChange() -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py index 077f8feebb90..6ee11c9a503b 100644 --- a/python/tvm/relax/binding_rewrite.py +++ b/python/tvm/relax/binding_rewrite.py @@ -17,6 +17,16 @@ # pylint: disable=no-else-return, invalid-name """Developer API of add/remove/replace bindings in Relax.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from relax.expr import DataflowBlock, Function +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Optional import tvm @@ -39,6 +49,13 @@ class DataflowBlockRewrite(Object): use mutate_irmodule which rewrites the old function that registered in the constructor. """ + # tvm-ffi-stubgen(begin): object/relax.DataflowBlockRewrite + # fmt: off + dfb: DataflowBlock + root_fn: Function | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, dfb: DataflowBlock, root_fn: Function): """ Construct a rewriter with the DataflowBlock to rewrite and its root function. diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 8c777eb53756..3dc6979bcbf4 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -17,6 +17,8 @@ # pylint: disable=no-else-return, invalid-name, unused-argument, import-outside-toplevel """Developer API of constructing Relax AST.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Any, Callable, Dict, List, Optional, Sequence, Union import tvm @@ -150,6 +152,11 @@ class BlockBuilder(Object): mod = bb.get() """ + # tvm-ffi-stubgen(begin): object/relax.BlockBuilder + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + _stack = [] @staticmethod diff --git a/python/tvm/relax/distributed/_ffi_api.py b/python/tvm/relax/distributed/_ffi_api.py index 71185a1276da..5a389164175a 100644 --- a/python/tvm/relax/distributed/_ffi_api.py +++ b/python/tvm/relax/distributed/_ffi_api.py @@ -15,6 +15,32 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.distributed""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm, Range, Span + from relax import DTensorStructInfo, TensorStructInfo + from relax.distributed import DeviceMesh, Placement, PlacementSpec + from tvm_ffi import Shape +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.distributed", __name__) + +# tvm-ffi-stubgen(begin): global/relax.distributed +# fmt: off +_FFI_INIT_FUNC("relax.distributed", __name__) +if TYPE_CHECKING: + def DTensorStructInfo(_0: TensorStructInfo, _1: DeviceMesh, _2: Placement, _3: Span, /) -> DTensorStructInfo: ... + def DeviceMesh(_0: Shape, _1: Sequence[IntImm], _2: Range | None, /) -> DeviceMesh: ... + def Placement(_0: Sequence[PlacementSpec], /) -> Placement: ... + def PlacementFromText(_0: str, /) -> Placement: ... + def Replica() -> PlacementSpec: ... + def Sharding(_0: int, /) -> PlacementSpec: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/distributed/global_info.py b/python/tvm/relax/distributed/global_info.py index 34d3f2da4720..fd5dccf9473b 100644 --- a/python/tvm/relax/distributed/global_info.py +++ b/python/tvm/relax/distributed/global_info.py @@ -16,6 +16,18 @@ # under the License. # pylint: disable=redefined-builtin, invalid-name """Global Info Data structures for distributed tensor.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm, Range + from tvm_ffi import Shape +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Union, Tuple import tvm_ffi @@ -39,6 +51,14 @@ class DeviceMesh(GlobalInfo): Represents the device id in the mesh """ + # tvm-ffi-stubgen(begin): object/relax.distributed.DeviceMesh + # fmt: off + shape: Shape + device_ids: Sequence[IntImm] + device_range: Range | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, shape: Union[ShapeTuple, List[int], Tuple[int]], device_ids: Union[List[int], Range] ): diff --git a/python/tvm/relax/distributed/struct_info.py b/python/tvm/relax/distributed/struct_info.py index 554c83e47490..688fa25223aa 100644 --- a/python/tvm/relax/distributed/struct_info.py +++ b/python/tvm/relax/distributed/struct_info.py @@ -16,6 +16,18 @@ # under the License. # pylint: disable=redefined-builtin, invalid-name """Struct Info for distributed tensor.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from relax import TensorStructInfo + from relax.distributed import DeviceMesh +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import enum from typing import List import tvm_ffi @@ -46,6 +58,13 @@ class PlacementSpec(Object): The kind of placement spec. Possible values: kSharding and kReplica. """ + # tvm-ffi-stubgen(begin): object/relax.distributed.PlacementSpec + # fmt: off + axis: int + kind: int + # fmt: on + # tvm-ffi-stubgen(end) + axis: int kind: PlacementSpecKind @@ -90,6 +109,12 @@ class Placement(Object): The placement spec for each dimension of the device mesh. """ + # tvm-ffi-stubgen(begin): object/relax.distributed.Placement + # fmt: off + dim_specs: Sequence[PlacementSpec] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, dim_specs: List[PlacementSpec]): self.__init_handle_by_constructor__(_ffi_api.Placement, dim_specs) # type: ignore @@ -125,6 +150,14 @@ class DTensorStructInfo(StructInfo): """ + # tvm-ffi-stubgen(begin): object/relax.DTensorStructInfo + # fmt: off + device_mesh: DeviceMesh + placement: Placement + tensor_sinfo: TensorStructInfo + # fmt: on + # tvm-ffi-stubgen(end) + tensor_sinfo: TensorStructInfo device_mesh: DeviceMesh placement: Placement diff --git a/python/tvm/relax/distributed/transform/_ffi_api.py b/python/tvm/relax/distributed/transform/_ffi_api.py index 35808cc2bc93..1e60f4c1c454 100644 --- a/python/tvm/relax/distributed/transform/_ffi_api.py +++ b/python/tvm/relax/distributed/transform/_ffi_api.py @@ -14,6 +14,26 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.distributed.transform""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from transform import Pass +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.distributed.transform", __name__) + +# tvm-ffi-stubgen(begin): global/relax.distributed.transform +# fmt: off +_FFI_INIT_FUNC("relax.distributed.transform", __name__) +if TYPE_CHECKING: + def LegalizeRedistribute() -> Pass: ... + def LowerDistIR() -> Pass: ... + def LowerGlobalViewToLocalView() -> Pass: ... + def PropagateSharding() -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 6dd730e83147..1ac60cb6d8e1 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -16,6 +16,19 @@ # under the License. """APIs for pattern-based rewriting.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import RelaxExpr + from relax.dpl import DFPattern + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Dict, Callable, Union from tvm.ir import IRModule @@ -32,6 +45,11 @@ class PatternMatchingRewriter(Object): """A pattern-matching rewriter for Relax""" + # tvm-ffi-stubgen(begin): object/relax.dpl.PatternMatchingRewriter + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + @staticmethod def from_pattern( pattern: DFPattern, @@ -173,6 +191,12 @@ def __or__(self, other: "PatternMatchingRewriter") -> "PatternMatchingRewriter": @register_object("relax.dpl.ExprPatternRewriter") class ExprPatternRewriter(PatternMatchingRewriter): + # tvm-ffi-stubgen(begin): object/relax.dpl.ExprPatternRewriter + # fmt: off + pattern: DFPattern + func: Callable[[RelaxExpr, Mapping[DFPattern, RelaxExpr]], RelaxExpr | None] + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self, pattern, func): self.__init_handle_by_constructor__( ffi.PatternRewriter, @@ -183,6 +207,12 @@ def __init__(self, pattern, func): @register_object("relax.dpl.OrRewriter") class OrRewriter(PatternMatchingRewriter): + # tvm-ffi-stubgen(begin): object/relax.dpl.OrRewriter + # fmt: off + lhs: PatternMatchingRewriter + rhs: PatternMatchingRewriter + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self, lhs, rhs): self.__init_handle_by_constructor__( ffi.OrRewriter, @@ -193,6 +223,12 @@ def __init__(self, lhs, rhs): @register_object("relax.dpl.TupleRewriter") class TupleRewriter(PatternMatchingRewriter): + # tvm-ffi-stubgen(begin): object/relax.dpl.TupleRewriter + # fmt: off + patterns: Sequence[DFPattern] + func: Callable[[RelaxExpr, Mapping[DFPattern, RelaxExpr]], RelaxExpr | None] + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self, patterns, func): self.__init_handle_by_constructor__( ffi.TupleRewriter, diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index 50d6c0679eca..0dd80e95e585 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name """A builder to build Relax VM executable.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from enum import IntEnum from typing import Optional, Union, List import tvm @@ -60,6 +62,11 @@ def __exit__(self, ptype, value, trace): class ExecBuilder(tvm_ffi.core.Object): """A builder to emit instructions and build executable for the virtual machine.""" + # tvm-ffi-stubgen(begin): object/relax.ExecBuilder + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate) # type: ignore diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index e9bc9a7a3e98..02ee90501d75 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -15,6 +15,19 @@ # specific language governing permissions and limitations # under the License. """The expression nodes of Relax.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import Attrs, PrimExpr, RelaxExpr, Span + from tvm_ffi import Tensor, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import typing from numbers import Number from typing import Any, Callable, Dict, List, Optional, Union, Mapping @@ -48,6 +61,12 @@ class Id(Object): Guaranteed to be stable across all passes. """ + # tvm-ffi-stubgen(begin): object/relax.Id + # fmt: off + name_hint: str + # fmt: on + # tvm-ffi-stubgen(end) + name_hint: str def __init__(self): @@ -64,6 +83,12 @@ class StructInfo(Node, Scriptable): and runtime structural information. """ + # tvm-ffi-stubgen(begin): object/ir.StructInfo + # fmt: off + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __eq__(self, other): """Compare two struct info for structural equivalence.""" return tvm.ir.structural_equal(self, other) @@ -555,6 +580,15 @@ class Call(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.Call + # fmt: off + op: RelaxExpr + args: Sequence[RelaxExpr] + attrs: Attrs + sinfo_args: Sequence[StructInfo] + # fmt: on + # tvm-ffi-stubgen(end) + op: Expr args: List[Expr] attrs: tvm.ir.Attrs @@ -595,6 +629,14 @@ class If(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.If + # fmt: off + cond: RelaxExpr + true_branch: SeqExpr + false_branch: SeqExpr + # fmt: on + # tvm-ffi-stubgen(end) + cond: Expr true_branch: Expr false_branch: Expr @@ -621,6 +663,12 @@ class Tuple(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.Tuple + # fmt: off + fields: Sequence[RelaxExpr] + # fmt: on + # tvm-ffi-stubgen(end) + fields: List[Expr] span: Optional[Span] @@ -659,6 +707,13 @@ class TupleGetItem(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.TupleGetItem + # fmt: off + tuple_value: RelaxExpr + index: int + # fmt: on + # tvm-ffi-stubgen(end) + tuple_value: Expr index: int span: Optional[Span] @@ -682,6 +737,12 @@ class ShapeExpr(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.ShapeExpr + # fmt: off + values: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + values: List[PrimExpr] span: Optional[Span] @@ -727,6 +788,12 @@ class Constant(ExprWithOp): Scalar constants are represented by ndim-0 constant tensors. """ + # tvm-ffi-stubgen(begin): object/relax.expr.Constant + # fmt: off + data: Tensor + # fmt: on + # tvm-ffi-stubgen(end) + data: tvm.runtime.Tensor span: Optional[Span] @@ -757,6 +824,12 @@ class Var(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.Var + # fmt: off + vid: Id + # fmt: on + # tvm-ffi-stubgen(end) + vid: Id span: Optional[Span] @@ -806,6 +879,11 @@ class DataflowVar(Var): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.DataflowVar + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + vid: Id span: Optional[Span] @@ -841,6 +919,12 @@ def __init__( class PrimValue(Expr, Scriptable): """The prim expr representing the value.""" + # tvm-ffi-stubgen(begin): object/relax.expr.PrimValue + # fmt: off + value: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + value: PrimExpr def __init__(self, value: Union[PrimExpr, int], span: Optional[Span] = None) -> None: @@ -853,6 +937,12 @@ def __init__(self, value: Union[PrimExpr, int], span: Optional[Span] = None) -> class StringImm(Expr, Scriptable): """Represent a string literal constant.""" + # tvm-ffi-stubgen(begin): object/relax.expr.StringImm + # fmt: off + value: str + # fmt: on + # tvm-ffi-stubgen(end) + value: str span: Optional[Span] @@ -864,6 +954,12 @@ def __init__(self, value: str, span: Optional[Span] = None) -> None: class DataTypeImm(Expr, Scriptable): """Represent a data type constant.""" + # tvm-ffi-stubgen(begin): object/relax.expr.DataTypeImm + # fmt: off + value: dtype + # fmt: on + # tvm-ffi-stubgen(end) + value: DataType span: Optional[Span] @@ -875,6 +971,13 @@ def __init__(self, value: Union[DataType, str], span: Optional[Span] = None) -> class Binding(Node, Scriptable): """The base class of a binding in Relax.""" + # tvm-ffi-stubgen(begin): object/relax.expr.Binding + # fmt: off + span: Span + var: Var + # fmt: on + # tvm-ffi-stubgen(end) + var: Var span: Optional[Span] @@ -899,6 +1002,13 @@ class MatchCast(Binding): The struct info to match cast to. """ + # tvm-ffi-stubgen(begin): object/relax.expr.MatchCast + # fmt: off + value: RelaxExpr + struct_info: StructInfo + # fmt: on + # tvm-ffi-stubgen(end) + struct_info: StructInfo value: Expr span: Optional[Span] @@ -925,6 +1035,12 @@ class VarBinding(Binding): """ + # tvm-ffi-stubgen(begin): object/relax.expr.VarBinding + # fmt: off + value: RelaxExpr + # fmt: on + # tvm-ffi-stubgen(end) + var: Var value: Expr span: Optional[Span] @@ -938,6 +1054,13 @@ class BindingBlock(Node, Scriptable): """base class of binding block, bindings inside can be impure (with side effect or control flow)""" + # tvm-ffi-stubgen(begin): object/relax.expr.BindingBlock + # fmt: off + bindings: Sequence[Binding] + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + bindings: List[Binding] span: Optional[Span] @@ -949,6 +1072,11 @@ def __init__(self, bindings: List[Binding], span: Optional[Span] = None) -> None class DataflowBlock(BindingBlock): """dataflow block, bindings inside are pure (no side effect and no control flow)""" + # tvm-ffi-stubgen(begin): object/relax.expr.DataflowBlock + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + bindings: List[Binding] span: Optional[Span] @@ -961,6 +1089,13 @@ def __init__(self, bindings: List[Binding], span: Optional[Span] = None) -> None class SeqExpr(ExprWithOp): """A sequence of binding blocks followed by an expression.""" + # tvm-ffi-stubgen(begin): object/relax.expr.SeqExpr + # fmt: off + blocks: Sequence[BindingBlock] + body: RelaxExpr + # fmt: on + # tvm-ffi-stubgen(end) + blocks: List[BindingBlock] body: Expr span: Optional[Span] @@ -973,6 +1108,15 @@ def __init__(self, blocks: List[BindingBlock], body: Expr, span: Optional[Span] class Function(BaseFunc, Scriptable): """A Relax function.""" + # tvm-ffi-stubgen(begin): object/relax.expr.Function + # fmt: off + params: Sequence[Var] + body: SeqExpr + ret_struct_info: StructInfo + is_pure: bool + # fmt: on + # tvm-ffi-stubgen(end) + params: List[Var] body: Expr ret_struct_info: StructInfo @@ -1112,6 +1256,12 @@ def inline_functions( class ExternFunc(BaseFunc, ExprWithOp): """extern function, which represents a PackedFunc.""" + # tvm-ffi-stubgen(begin): object/relax.expr.ExternFunc + # fmt: off + global_symbol: str + # fmt: on + # tvm-ffi-stubgen(end) + global_symbol: String span: Optional[Span] @@ -1183,6 +1333,17 @@ def const( class TEPlaceholderOp(tvm.te.tensor.Operation): """The placeholder op that represents a relax expression.""" + # tvm-ffi-stubgen(begin): object/relax.TEPlaceholderOp + # fmt: off + name: str + tag: str + attrs: Mapping[str, Any] + value: RelaxExpr + shape: Sequence[PrimExpr] + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + def te_tensor( value: Expr, tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = "rxplaceholder" diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index e5e77251c66d..40174f72218b 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -16,6 +16,16 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ """The expression functor of Relax.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from relax import BlockBuilder +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable, Optional import tvm_ffi @@ -272,6 +282,11 @@ class _PyExprVisitor(tvm_ffi.core.Object): See also: visitor, PyExprVisitor """ + # tvm-ffi-stubgen(begin): object/expr_functor.PyExprVisitor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_visit_expr: Callable = None, @@ -792,6 +807,12 @@ class _PyExprMutator(Object): See also: mutator, PyExprmutator """ + # tvm-ffi-stubgen(begin): object/expr_functor.PyExprMutator + # fmt: off + builder_: BlockBuilder + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, builder: BlockBuilder = None, diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py index 867c43e4d85b..ed4d541841c8 100644 --- a/python/tvm/relax/op/_ffi_api.py +++ b/python/tvm/relax/op/_ffi_api.py @@ -14,6 +14,170 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import Attrs, IntImm, PrimExpr, RelaxExpr, StructInfo, VDevice + from relax import TensorStructInfo + from relax.expr import PrimValue, StringImm, Tuple + from tir import IndexMap + from tvm_ffi import dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op +# fmt: off +_FFI_INIT_FUNC("relax.op", __name__) +if TYPE_CHECKING: + def abs(_0: RelaxExpr, /) -> RelaxExpr: ... + def acos(_0: RelaxExpr, /) -> RelaxExpr: ... + def acosh(_0: RelaxExpr, /) -> RelaxExpr: ... + def add(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def arange(_0: PrimValue, _1: PrimValue, _2: PrimValue, _3: dtype, /) -> RelaxExpr: ... + def argmax(_0: RelaxExpr, _1: int | None, _2: bool, /) -> RelaxExpr: ... + def argmin(_0: RelaxExpr, _1: int | None, _2: bool, /) -> RelaxExpr: ... + def argsort(_0: RelaxExpr, _1: int, _2: bool, _3: dtype, /) -> RelaxExpr: ... + def asin(_0: RelaxExpr, /) -> RelaxExpr: ... + def asinh(_0: RelaxExpr, /) -> RelaxExpr: ... + def assert_op(_0: RelaxExpr, _1: Sequence[RelaxExpr], _2: StringImm, /) -> RelaxExpr: ... + def astype(_0: RelaxExpr, _1: dtype, /) -> RelaxExpr: ... + def atan(_0: RelaxExpr, /) -> RelaxExpr: ... + def atanh(_0: RelaxExpr, /) -> RelaxExpr: ... + def bitwise_and(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def bitwise_not(_0: RelaxExpr, /) -> RelaxExpr: ... + def bitwise_or(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def bitwise_xor(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def broadcast_to(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def bucketize(_0: RelaxExpr, _1: RelaxExpr, _2: bool, _3: bool, /) -> RelaxExpr: ... + def call_builtin_with_ctx(_0: RelaxExpr, _1: Tuple, _2: Sequence[StructInfo], /) -> RelaxExpr: ... + def call_dps_packed(_0: RelaxExpr, _1: Tuple, _2: Sequence[TensorStructInfo], /) -> RelaxExpr: ... + def call_inplace_packed(_0: RelaxExpr, _1: Sequence[RelaxExpr], _2: Sequence[IntImm], _3: Sequence[StructInfo], /) -> RelaxExpr: ... + def call_pure_packed(_0: RelaxExpr, _1: Sequence[RelaxExpr], _2: Attrs, _3: Sequence[StructInfo], /) -> RelaxExpr: ... + def call_py_func(_0: StringImm, _1: Tuple, _2: Sequence[TensorStructInfo], /) -> RelaxExpr: ... + def call_tir(_0: RelaxExpr, _1: Tuple, _2: Sequence[TensorStructInfo], _3: RelaxExpr | None, /) -> RelaxExpr: ... + def call_tir_inplace(_0: RelaxExpr, _1: Tuple, _2: Sequence[IntImm], _3: Sequence[TensorStructInfo], _4: RelaxExpr | None, /) -> RelaxExpr: ... + def call_tir_with_grad(_0: RelaxExpr, _1: Tuple, _2: Sequence[TensorStructInfo], _3: str, _4: Mapping[str, Any], _5: RelaxExpr | None, /) -> RelaxExpr: ... + def ceil(_0: RelaxExpr, /) -> RelaxExpr: ... + def clip(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, /) -> RelaxExpr: ... + def collapse_sum_like(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def collapse_sum_to(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def concat(_0: RelaxExpr, _1: int | None, /) -> RelaxExpr: ... + def cos(_0: RelaxExpr, /) -> RelaxExpr: ... + def cosh(_0: RelaxExpr, /) -> RelaxExpr: ... + def cumprod(_0: RelaxExpr, _1: int | None, _2: dtype | None, _3: IntImm, /) -> RelaxExpr: ... + def cumsum(_0: RelaxExpr, _1: int | None, _2: dtype | None, _3: IntImm, /) -> RelaxExpr: ... + def dequantize(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: dtype, /) -> RelaxExpr: ... + def divide(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def dynamic_strided_slice(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, /) -> RelaxExpr: ... + def einsum(_0: RelaxExpr, _1: str, /) -> RelaxExpr: ... + def equal(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def erf(_0: RelaxExpr, /) -> RelaxExpr: ... + def ewise_fma(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, /) -> RelaxExpr: ... + def exp(_0: RelaxExpr, /) -> RelaxExpr: ... + def expand_dims(_0: RelaxExpr, _1: Sequence[IntImm], /) -> RelaxExpr: ... + def eye(_0: PrimValue, _1: PrimValue, _2: PrimValue, _3: dtype, /) -> RelaxExpr: ... + def eye_like(_0: RelaxExpr, _1: PrimValue, _2: dtype | None, /) -> RelaxExpr: ... + def flatten(_0: RelaxExpr, /) -> RelaxExpr: ... + def flip(_0: RelaxExpr, _1: IntImm, /) -> RelaxExpr: ... + def floor(_0: RelaxExpr, /) -> RelaxExpr: ... + def floor_divide(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def floor_mod(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def full(_0: RelaxExpr | Sequence[PrimExpr], _1: RelaxExpr, _2: dtype | None, /) -> RelaxExpr: ... + def full_like(_0: RelaxExpr, _1: RelaxExpr, _2: dtype | None, /) -> RelaxExpr: ... + def gather_elements(_0: RelaxExpr, _1: RelaxExpr, _2: int, /) -> RelaxExpr: ... + def gather_nd(_0: RelaxExpr, _1: RelaxExpr, _2: int, /) -> RelaxExpr: ... + def greater(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def greater_equal(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def hamming_window(_0: PrimValue, _1: PrimValue, _2: PrimValue, _3: PrimValue, _4: dtype, /) -> RelaxExpr: ... + def hint_on_device(*args: Any) -> Any: ... + def index_put(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: bool, /) -> RelaxExpr: ... + def index_tensor(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def invoke_closure(_0: RelaxExpr, _1: Tuple, _2: Sequence[StructInfo], /) -> RelaxExpr: ... + def invoke_pure_closure(_0: RelaxExpr, _1: Tuple, _2: Sequence[StructInfo], /) -> RelaxExpr: ... + def isfinite(_0: RelaxExpr, /) -> RelaxExpr: ... + def isinf(_0: RelaxExpr, /) -> RelaxExpr: ... + def isnan(_0: RelaxExpr, /) -> RelaxExpr: ... + def layout_transform(_0: RelaxExpr, _1: IndexMap, _2: PrimValue | None, _3: Sequence[IntImm] | None, _4: Sequence[IntImm] | None, /) -> RelaxExpr: ... + def left_shift(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def less(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def less_equal(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def log(_0: RelaxExpr, /) -> RelaxExpr: ... + def log_add_exp(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def logical_and(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def logical_not(_0: RelaxExpr, /) -> RelaxExpr: ... + def logical_or(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def logical_xor(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def make_closure(_0: RelaxExpr, _1: Tuple, /) -> RelaxExpr: ... + def matmul(_0: RelaxExpr, _1: RelaxExpr, _2: dtype | None, /) -> RelaxExpr: ... + def max(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def maximum(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def mean(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def meshgrid(_0: RelaxExpr, _1: str | None, /) -> RelaxExpr: ... + def min(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def minimum(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def mod(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def multinomial_from_uniform(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: dtype, /) -> RelaxExpr: ... + def multiply(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def negative(_0: RelaxExpr, /) -> RelaxExpr: ... + def nonzero(_0: RelaxExpr, /) -> RelaxExpr: ... + def not_equal(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def null_value() -> RelaxExpr: ... + def one_hot(_0: RelaxExpr, _1: PrimValue, _2: PrimValue, _3: int, _4: int, /) -> RelaxExpr: ... + def ones(_0: RelaxExpr, _1: dtype, /) -> RelaxExpr: ... + def ones_like(_0: RelaxExpr, _1: dtype | None, /) -> RelaxExpr: ... + def outer(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def permute_dims(_0: RelaxExpr, _1: Sequence[IntImm] | None, /) -> RelaxExpr: ... + def power(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def print(_0: Sequence[RelaxExpr], _1: StringImm, /) -> RelaxExpr: ... + def prod(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def quantize(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: dtype, /) -> RelaxExpr: ... + def repeat(_0: RelaxExpr, _1: int, _2: int | None, /) -> RelaxExpr: ... + def reshape(_0: RelaxExpr, _1: RelaxExpr | Sequence[PrimExpr], /) -> RelaxExpr: ... + def right_shift(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def round(_0: RelaxExpr, /) -> RelaxExpr: ... + def rsqrt(_0: RelaxExpr, /) -> RelaxExpr: ... + def scatter_elements(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: str, /) -> RelaxExpr: ... + def scatter_nd(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: str, /) -> RelaxExpr: ... + def shape_of(_0: RelaxExpr, /) -> RelaxExpr: ... + def shape_to_tensor(_0: RelaxExpr, /) -> RelaxExpr: ... + def sigmoid(_0: RelaxExpr, /) -> RelaxExpr: ... + def sign(_0: RelaxExpr, /) -> RelaxExpr: ... + def sin(_0: RelaxExpr, /) -> RelaxExpr: ... + def sinh(_0: RelaxExpr, /) -> RelaxExpr: ... + def slice_scatter(_0: RelaxExpr, _1: RelaxExpr, _2: int, _3: PrimValue, _4: PrimValue, _5: PrimValue, /) -> RelaxExpr: ... + def sort(_0: RelaxExpr, _1: int, _2: bool, /) -> RelaxExpr: ... + def split(_0: RelaxExpr, _1: IntImm | Sequence[IntImm], _2: int, /) -> RelaxExpr: ... + def sqrt(_0: RelaxExpr, /) -> RelaxExpr: ... + def square(_0: RelaxExpr, /) -> RelaxExpr: ... + def squeeze(_0: RelaxExpr, _1: Sequence[IntImm] | None, /) -> RelaxExpr: ... + def stack(_0: RelaxExpr, _1: IntImm | None, /) -> RelaxExpr: ... + def std(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def strided_slice(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, _4: RelaxExpr | None, _5: bool, /) -> RelaxExpr: ... + def subtract(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def sum(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def take(_0: RelaxExpr, _1: RelaxExpr, _2: int | None, _3: str, /) -> RelaxExpr: ... + def tan(_0: RelaxExpr, /) -> RelaxExpr: ... + def tanh(_0: RelaxExpr, /) -> RelaxExpr: ... + def tensor_to_shape(_0: RelaxExpr, /) -> RelaxExpr: ... + def tile(_0: RelaxExpr, _1: Sequence[IntImm], /) -> RelaxExpr: ... + def to_vdevice(_0: RelaxExpr, _1: VDevice, /) -> RelaxExpr: ... + def topk(_0: RelaxExpr, _1: int, _2: int, _3: str, _4: bool, _5: dtype, /) -> RelaxExpr: ... + def tril(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def triu(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def trunc(_0: RelaxExpr, /) -> RelaxExpr: ... + def unique(_0: RelaxExpr, _1: PrimValue, _2: PrimValue, _3: PrimValue, _4: PrimValue, _5: PrimValue | None, /) -> RelaxExpr: ... + def variance(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def where(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, /) -> RelaxExpr: ... + def wrap_param(_0: RelaxExpr, _1: dtype, /) -> RelaxExpr: ... + def zeros(_0: RelaxExpr, _1: dtype, /) -> RelaxExpr: ... + def zeros_like(_0: RelaxExpr, _1: dtype | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py index 0e5955f6e47d..ae4cd27e1c94 100644 --- a/python/tvm/relax/op/builtin/_ffi_api.py +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -14,6 +14,25 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.builtin""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr + from relax.expr import DataTypeImm, PrimValue, StringImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.builtin", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.builtin +# fmt: off +_FFI_INIT_FUNC("relax.op.builtin", __name__) +if TYPE_CHECKING: + def alloc_tensor(_0: RelaxExpr, _1: DataTypeImm, _2: PrimValue, _3: StringImm, /) -> RelaxExpr: ... + def stop_lift_params(_0: RelaxExpr, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/ccl/_ffi_api.py b/python/tvm/relax/op/ccl/_ffi_api.py index f0796d3da318..d74fd7d48031 100644 --- a/python/tvm/relax/op/ccl/_ffi_api.py +++ b/python/tvm/relax/op/ccl/_ffi_api.py @@ -15,6 +15,26 @@ # specific language governing permissions and limitations # under the License. """Operators serving for Collective Communications Library (CCL) operators""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.ccl", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.ccl +# fmt: off +_FFI_INIT_FUNC("relax.op.ccl", __name__) +if TYPE_CHECKING: + def allgather(_0: RelaxExpr, _1: int, _2: bool, /) -> RelaxExpr: ... + def allreduce(_0: RelaxExpr, _1: str, _2: bool, /) -> RelaxExpr: ... + def broadcast_from_worker0(_0: RelaxExpr, /) -> RelaxExpr: ... + def scatter_from_worker0(_0: RelaxExpr, _1: int, _2: int, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/distributed/_ffi_api.py b/python/tvm/relax/op/distributed/_ffi_api.py index fa1c163794b9..4065214bb0c4 100644 --- a/python/tvm/relax/op/distributed/_ffi_api.py +++ b/python/tvm/relax/op/distributed/_ffi_api.py @@ -15,6 +15,30 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.op.distributed""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import RelaxExpr + from relax import DTensorStructInfo + from relax.distributed import DeviceMesh, Placement + from relax.expr import Tuple +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.dist", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.dist +# fmt: off +_FFI_INIT_FUNC("relax.op.dist", __name__) +if TYPE_CHECKING: + def annotate_sharding(_0: RelaxExpr, _1: DeviceMesh, _2: Placement, /) -> RelaxExpr: ... + def call_tir_local_view(_0: RelaxExpr, _1: Tuple, _2: Sequence[DTensorStructInfo], _3: RelaxExpr | None, /) -> RelaxExpr: ... + def redistribute(_0: RelaxExpr, _1: DeviceMesh, _2: Placement, /) -> RelaxExpr: ... + def redistribute_replica_to_shard(_0: RelaxExpr, _1: int, _2: int, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/grad/_ffi_api.py b/python/tvm/relax/op/grad/_ffi_api.py index 1a8ebb09aa8d..49bb114d65db 100644 --- a/python/tvm/relax/op/grad/_ffi_api.py +++ b/python/tvm/relax/op/grad/_ffi_api.py @@ -15,6 +15,30 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.op.grad""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm, RelaxExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.grad", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.grad +# fmt: off +_FFI_INIT_FUNC("relax.op.grad", __name__) +if TYPE_CHECKING: + def avg_pool2d_backward(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: Sequence[IntImm], _6: bool, _7: bool, _8: str, _9: str | None, /) -> RelaxExpr: ... + def end_checkpoint(_0: RelaxExpr, /) -> RelaxExpr: ... + def max_pool2d_backward(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: Sequence[IntImm], _6: bool, _7: bool, _8: str, _9: str | None, /) -> RelaxExpr: ... + def nll_loss_backward(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr | None, _4: str, _5: int, /) -> RelaxExpr: ... + def no_grad(_0: RelaxExpr, /) -> RelaxExpr: ... + def start_checkpoint(_0: RelaxExpr, /) -> RelaxExpr: ... + def take_backward(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/image/_ffi_api.py b/python/tvm/relax/op/image/_ffi_api.py index 8147a155cb76..c734db1bd753 100644 --- a/python/tvm/relax/op/image/_ffi_api.py +++ b/python/tvm/relax/op/image/_ffi_api.py @@ -15,6 +15,26 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import FloatImm, RelaxExpr + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.image", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.image +# fmt: off +_FFI_INIT_FUNC("relax.op.image", __name__) +if TYPE_CHECKING: + def grid_sample(_0: RelaxExpr, _1: RelaxExpr, _2: str, _3: str, _4: str, _5: bool, /) -> RelaxExpr: ... + def resize2d(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[FloatImm], _3: str, _4: str, _5: str, _6: str, _7: float, _8: int, _9: float, _10: dtype | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/memory/_ffi_api.py b/python/tvm/relax/op/memory/_ffi_api.py index 05dbf534c7f5..5d67cc449283 100644 --- a/python/tvm/relax/op/memory/_ffi_api.py +++ b/python/tvm/relax/op/memory/_ffi_api.py @@ -14,6 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.memory""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr + from relax.expr import DataTypeImm, PrimValue, StringImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.memory", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.memory +# fmt: off +_FFI_INIT_FUNC("relax.op.memory", __name__) +if TYPE_CHECKING: + def alloc_storage(_0: RelaxExpr, _1: PrimValue, _2: StringImm, _3: DataTypeImm, /) -> RelaxExpr: ... + def alloc_tensor(_0: RelaxExpr, _1: PrimValue, _2: RelaxExpr, _3: DataTypeImm, /) -> RelaxExpr: ... + def ensure_zero_offset(_0: RelaxExpr, /) -> RelaxExpr: ... + def kill_storage(_0: RelaxExpr, /) -> RelaxExpr: ... + def kill_tensor(_0: RelaxExpr, /) -> RelaxExpr: ... + def view(_0: RelaxExpr, _1: RelaxExpr | None, _2: RelaxExpr | None, _3: RelaxExpr | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/nn/_ffi_api.py b/python/tvm/relax/op/nn/_ffi_api.py index d58fa186fc7c..cf446282d6a3 100644 --- a/python/tvm/relax/op/nn/_ffi_api.py +++ b/python/tvm/relax/op/nn/_ffi_api.py @@ -15,6 +15,60 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import FloatImm, IntImm, RelaxExpr + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.nn", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.nn +# fmt: off +_FFI_INIT_FUNC("relax.op.nn", __name__) +if TYPE_CHECKING: + def adaptive_avg_pool1d(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: str, _3: str | None, /) -> RelaxExpr: ... + def adaptive_avg_pool2d(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: str, _3: str | None, /) -> RelaxExpr: ... + def adaptive_avg_pool3d(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: str, _3: str | None, /) -> RelaxExpr: ... + def attention(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr | None, _4: FloatImm | None, _5: str | None, _6: IntImm | None, /) -> RelaxExpr: ... + def attention_var_len(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, _4: RelaxExpr, _5: RelaxExpr, _6: RelaxExpr, _7: FloatImm | None, _8: str | None, _9: IntImm | None, /) -> RelaxExpr: ... + def avg_pool1d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def avg_pool2d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def avg_pool3d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def batch_norm(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, _4: RelaxExpr, _5: int, _6: float, _7: bool, _8: bool, _9: float, _10: bool, /) -> RelaxExpr: ... + def conv1d(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: int, _6: str, _7: str, _8: str | None, _9: dtype | None, /) -> RelaxExpr: ... + def conv1d_transpose(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: Sequence[IntImm], _6: int, _7: str, _8: str, _9: str | None, _10: dtype | None, /) -> RelaxExpr: ... + def conv2d(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: int, _6: str, _7: str, _8: str | None, _9: dtype | None, /) -> RelaxExpr: ... + def conv2d_transpose(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: Sequence[IntImm], _6: int, _7: str, _8: str, _9: str | None, _10: dtype | None, /) -> RelaxExpr: ... + def conv3d(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: int, _6: str, _7: str, _8: str | None, _9: dtype | None, /) -> RelaxExpr: ... + def cross_entropy_with_logits(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def dropout(_0: RelaxExpr, _1: float, /) -> RelaxExpr: ... + def gelu(_0: RelaxExpr, /) -> RelaxExpr: ... + def gelu_tanh(_0: RelaxExpr, /) -> RelaxExpr: ... + def group_norm(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: int, _5: Sequence[IntImm], _6: float, _7: bool, _8: bool, /) -> RelaxExpr: ... + def instance_norm(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: Sequence[IntImm], _5: float, _6: bool, _7: bool, /) -> RelaxExpr: ... + def layer_norm(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: Sequence[IntImm], _4: float, _5: bool, _6: bool, /) -> RelaxExpr: ... + def leakyrelu(_0: RelaxExpr, _1: float, /) -> RelaxExpr: ... + def log_softmax(_0: RelaxExpr, _1: int, /) -> RelaxExpr: ... + def max_pool1d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def max_pool2d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def max_pool3d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def nll_loss(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr | None, _3: str, _4: int, /) -> RelaxExpr: ... + def pad(_0: RelaxExpr, _1: Sequence[IntImm], _2: str, _3: float, /) -> RelaxExpr: ... + def pixel_shuffle(_0: RelaxExpr, _1: int, /) -> RelaxExpr: ... + def prelu(_0: RelaxExpr, _1: RelaxExpr, _2: int, /) -> RelaxExpr: ... + def relu(_0: RelaxExpr, /) -> RelaxExpr: ... + def rms_norm(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: float, /) -> RelaxExpr: ... + def selu(_0: RelaxExpr, /) -> RelaxExpr: ... + def silu(_0: RelaxExpr, /) -> RelaxExpr: ... + def softmax(_0: RelaxExpr, _1: int, /) -> RelaxExpr: ... + def softplus(_0: RelaxExpr, _1: float, _2: float, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 229a789a45ef..9bebbff5c8b6 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -15,6 +15,21 @@ # specific language governing permissions and limitations # under the License. """The attributes node used for Relax operators""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import FloatImm, IntImm, VDevice + from relax.expr import PrimValue + from tir import IndexMap + from tvm_ffi import Object, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm.ir import Attrs import tvm_ffi @@ -23,342 +38,878 @@ class CallTIRWithGradAttrs(Attrs): """Attributes used in call_tir_with_grad operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.CallTIRWithGradAttrs + # fmt: off + te_grad_name: str + te_grad_kwargs: Mapping[str, Any] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.InitAttrs") class InitAttrs(Attrs): """Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.InitAttrs + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.TriluAttrs") class TriluAttrs(Attrs): """Attributes used in tril and triu operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.TriluAttrs + # fmt: off + k: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AstypeAttrs") class AstypeAttrs(Attrs): """Attributes used in astype operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AstypeAttrs + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes used in take operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.TakeAttrs + # fmt: off + axis: int | None + mode: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.StridedSliceAttrs") class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.StridedSliceAttrs + # fmt: off + assume_inbound: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.MatmulAttrs") class MatmulAttrs(Attrs): """Attributes for matmul operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.MatmulAttrs + # fmt: off + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv2DAttrs") class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv2DAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv3DAttrs") class Conv3DAttrs(Attrs): """Attributes for nn.conv3d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv3DAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv2DTransposeAttrs") class Conv2DTransposeAttrs(Attrs): """Attributes for nn.conv2d_transpose""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv2DTransposeAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + output_padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Pool2DAttrs") class Pool2DAttrs(Attrs): """Attributes for nn.max_pool2d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Pool2DAttrs + # fmt: off + pool_size: Sequence[IntImm] + strides: Sequence[IntImm] + dilation: Sequence[IntImm] + padding: Sequence[IntImm] + ceil_mode: bool + count_include_pad: bool + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AdaptivePool2DAttrs") class AdaptivePool2DAttrs(Attrs): """Attributes for 2d adaptive pool operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AdaptivePool2DAttrs + # fmt: off + output_size: Sequence[IntImm] | None + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SoftmaxAttrs") class SoftmaxAttrs(Attrs): """Attributes for nn.softmax""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SoftmaxAttrs + # fmt: off + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.BatchNormAttrs") class BatchNormAttrs(Attrs): """Attributes used in batch_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.BatchNormAttrs + # fmt: off + axis: int + epsilon: float + center: bool + scale: bool + momentum: float + training: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.LayerNormAttrs") class LayerNormAttrs(Attrs): """Attributes used in layer_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.LayerNormAttrs + # fmt: off + axes: Sequence[IntImm] + epsilon: float + center: bool + scale: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.InstanceNormAttrs") class InstanceNormAttrs(Attrs): """Attributes used in instance_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.InstanceNormAttrs + # fmt: off + channel_axis: int + axes: Sequence[IntImm] + epsilon: float + center: bool + scale: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.DropoutAttrs") class DropoutAttrs(Attrs): """Attributes for dropout operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.DropoutAttrs + # fmt: off + rate: float + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.StatisticalAttrs") class StatisticalAttrs(Attrs): """Attributes used in statistical operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.StatisticalAttrs + # fmt: off + axis: Sequence[IntImm] | None + keepdims: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ConcatAttrs") class ConcatAttrs(Attrs): """Attributes for concat operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ConcatAttrs + # fmt: off + axis: int | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ExpandDimsAttrs") class ExpandDimsAttrs(Attrs): """Attributes for expand_dims operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ExpandDimsAttrs + # fmt: off + axis: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.PermuteDimsAttrs") class PermuteDimsAttrs(Attrs): """Attributes for permute_dims operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.PermuteDimsAttrs + # fmt: off + axes: Sequence[IntImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SortAttrs") class SortAttrs(Attrs): """Attributes for sort operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SortAttrs + # fmt: off + axis: int + descending: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ArgsortAttrs") class ArgsortAttrs(Attrs): """Attributes for argsort operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ArgsortAttrs + # fmt: off + axis: int + descending: bool + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SplitAttrs") class SplitAttrs(Attrs): """Attributes used in split operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SplitAttrs + # fmt: off + indices_or_sections: Object + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SqueezeAttrs") class SqueezeAttrs(Attrs): """Attributes for squeeze operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SqueezeAttrs + # fmt: off + axis: Sequence[IntImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.StackAttrs") class StackAttrs(Attrs): """Attributes for concat operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.StackAttrs + # fmt: off + axis: IntImm | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.IndexPutAttrs") class IndexPutAttrs(Attrs): """Attributes for index_put operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.IndexPutAttrs + # fmt: off + accumulate: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.LayoutTransformAttrs") class LayoutTransformAttrs(Attrs): """Attributes used in layout_transform operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.LayoutTransformAttrs + # fmt: off + index_map: IndexMap + pad_value: PrimValue | None + axis_separators: Sequence[IntImm] | None + input_axis_separators: Sequence[IntImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Resize2DAttrs + # fmt: off + roi: Sequence[FloatImm] + layout: str + method: str + coordinate_transformation_mode: str + rounding_method: str + cubic_alpha: float + cubic_exclude: int + extrapolation_value: float + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ArgmaxArgminAttrs") class ArgmaxArgminAttrs(Attrs): """Attributes for argmax/argmin operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ArgmaxArgminAttrs + # fmt: off + axis: int | None + keepdims: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.RepeatAttrs") class RepeatAttrs(Attrs): """Attributes for repeat operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.RepeatAttrs + # fmt: off + repeats: int + axis: int | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.TileAttrs") class TileAttrs(Attrs): """Attributes for tile operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.TileAttrs + # fmt: off + repeats: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ScanopAttrs") class ScanopAttrs(Attrs): """Attributes for scan operators""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ScanopAttrs + # fmt: off + axis: int | None + dtype: dtype + exclusive: IntImm + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.TopKAttrs") class TopKAttrs(Attrs): """Attributes for topk operators""" + # tvm-ffi-stubgen(begin): object/relax.attrs.TopKAttrs + # fmt: off + k: int + axis: int + ret_type: str + largest: bool + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.EinsumAttrs") class EinsumAttrs(Attrs): """Attributes for einsum operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.EinsumAttrs + # fmt: off + subscripts: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.FlipAttrs") class FlipAttrs(Attrs): """Attributes for flip operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.FlipAttrs + # fmt: off + axis: IntImm + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.PadAttrs") class PadAttrs(Attrs): """Attributes used in pad operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.PadAttrs + # fmt: off + pad_width: Sequence[IntImm] + pad_value: float + pad_mode: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.MultinomialFromUniformAttrs") class MultinomialFromUniformAttrs(Attrs): """Attributes for multinomial_from_uniform operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.MultinomialFromUniformAttrs + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.CallInplacePackedAttrs") class CallInplacePackedAttrs(Attrs): """Attributes used in call_inplace_packed operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.CallInplacePackedAttrs + # fmt: off + inplace_indices: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.CallTIRInplaceAttrs") class CallTIRInplaceAttrs(Attrs): """Attributes used in call_tir_inplace operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.CallTIRInplaceAttrs + # fmt: off + inplace_indices: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ToVDeviceAttrs") class ToVDeviceAttrs(Attrs): """Attributes used in to_vdevice operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ToVDeviceAttrs + # fmt: off + dst_vdevice: VDevice + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.HintOnDeviceAttrs") class HintOnDeviceAttrs(Attrs): """Attributes used in hint_on_device operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.HintOnDeviceAttrs + # fmt: off + device_type: int + index: int + memory_scope: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ScatterCollectiveAttrs") class ScatterCollectiveAttrs(Attrs): """Attributes used in scatter collective operators""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ScatterCollectiveAttrs + # fmt: off + num_workers: int + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AttentionAttrs") class AttentionAttrs(Attrs): """Attributes used in attention operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AttentionAttrs + # fmt: off + scale: FloatImm | None + causal_mask: str | None + window_size: IntImm | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AllClassNonMaximumSuppressionAttrs") class AllClassNonMaximumSuppressionAttrs(Attrs): """Attributes for vision.all_class_non_max_suppression""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AllClassNonMaximumSuppressionAttrs + # fmt: off + output_format: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv1DAttrs") class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv1DAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv1DTransposeAttrs") class Conv1DTransposeAttrs(Attrs): """Attributes for nn.conv1d_transpose""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv1DTransposeAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + output_padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Pool1DAttrs") class Pool1DAttrs(Attrs): """Attributes for nn.max_pool1d and nn.avg_pool1d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Pool1DAttrs + # fmt: off + pool_size: Sequence[IntImm] + strides: Sequence[IntImm] + dilation: Sequence[IntImm] + padding: Sequence[IntImm] + ceil_mode: bool + count_include_pad: bool + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Pool3DAttrs") class Pool3DAttrs(Attrs): """Attributes for nn.max_pool3d and nn.avg_pool3d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Pool3DAttrs + # fmt: off + pool_size: Sequence[IntImm] + strides: Sequence[IntImm] + dilation: Sequence[IntImm] + padding: Sequence[IntImm] + ceil_mode: bool + count_include_pad: bool + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AdaptivePool1DAttrs") class AdaptivePool1DAttrs(Attrs): """Attributes for 1d adaptive pool operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AdaptivePool1DAttrs + # fmt: off + output_size: Sequence[IntImm] | None + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AdaptivePool3DAttrs") class AdaptivePool3DAttrs(Attrs): """Attributes for 3d adaptive pool operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AdaptivePool3DAttrs + # fmt: off + output_size: Sequence[IntImm] | None + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.LeakyReluAttrs") class LeakyReluAttrs(Attrs): """Attributes used in leaky_relu operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.LeakyReluAttrs + # fmt: off + alpha: float + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SoftplusAttrs") class SoftplusAttrs(Attrs): """Attributes used in softplus operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SoftplusAttrs + # fmt: off + beta: float + threshold: float + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.PReluAttrs") class PReluAttrs(Attrs): """Attributes used in prelu operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.PReluAttrs + # fmt: off + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.PixelShuffleAttrs") class PixelShuffleAttrs(Attrs): """Attributes used in pixel_shuffle operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.PixelShuffleAttrs + # fmt: off + upscale_factor: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.GroupNormAttrs") class GroupNormAttrs(Attrs): """Attributes used in group_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.GroupNormAttrs + # fmt: off + num_groups: int + channel_axis: int + axes: Sequence[IntImm] + epsilon: float + center: bool + scale: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.RMSNormAttrs") class RMSNormAttrs(Attrs): """Attributes used in rms_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.RMSNormAttrs + # fmt: off + axes: Sequence[IntImm] + epsilon: float + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.NLLLossAttrs") class NLLLossAttrs(Attrs): """Attributes used in nll_loss operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.NLLLossAttrs + # fmt: off + reduction: str + ignore_index: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AllReduceAttrs") class AllReduceAttrs(Attrs): """Attributes used in allreduce operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AllReduceAttrs + # fmt: off + op_type: str + in_group: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AllGatherAttrs") class AllGatherAttrs(Attrs): """Attributes used in allgather operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AllGatherAttrs + # fmt: off + num_workers: int + in_group: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.WrapParamAttrs") class WrapParamAttrs(Attrs): """Attributes used in wrap_param operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.WrapParamAttrs + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.QuantizeAttrs") class QuantizeAttrs(Attrs): """Attributes used in quantize/dequantize operators""" + # tvm-ffi-stubgen(begin): object/relax.attrs.QuantizeAttrs + # fmt: off + out_dtype: dtype + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.GatherElementsAttrs") class GatherElementsAttrs(Attrs): """Attributes for gather_elements operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.GatherElementsAttrs + # fmt: off + axis: IntImm + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.GatherNDAttrs") class GatherNDAttrs(Attrs): """Attributes for gather_nd operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.GatherNDAttrs + # fmt: off + batch_dims: IntImm + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.MeshgridAttrs") class MeshgridAttrs(Attrs): """Attributes for meshgrid operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.MeshgridAttrs + # fmt: off + indexing: str | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ScatterElementsAttrs") class ScatterElementsAttrs(Attrs): """Attributes for scatter_elements operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ScatterElementsAttrs + # fmt: off + axis: IntImm + reduction: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ScatterNDAttrs") class ScatterNDAttrs(Attrs): """Attributes for scatter_nd operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ScatterNDAttrs + # fmt: off + reduction: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SliceScatterAttrs") class SliceScatterAttrs(Attrs): """Attributes for slice_scatter operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SliceScatterAttrs + # fmt: off + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.OneHotAttrs") class OneHotAttrs(Attrs): """Attributes for one_hot operator""" + + # tvm-ffi-stubgen(begin): object/relax.attrs.OneHotAttrs + # fmt: off + depth: int + axis: int + # fmt: on + # tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/vision/_ffi_api.py b/python/tvm/relax/op/vision/_ffi_api.py index 8af761dc5a00..4532e527298c 100644 --- a/python/tvm/relax/op/vision/_ffi_api.py +++ b/python/tvm/relax/op/vision/_ffi_api.py @@ -15,6 +15,23 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.vision", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.vision +# fmt: off +_FFI_INIT_FUNC("relax.op.vision", __name__) +if TYPE_CHECKING: + def all_class_non_max_suppression(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, _4: RelaxExpr, _5: str, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/vm/_ffi_api.py b/python/tvm/relax/op/vm/_ffi_api.py index eed64e53f036..15f7f55903ee 100644 --- a/python/tvm/relax/op/vm/_ffi_api.py +++ b/python/tvm/relax/op/vm/_ffi_api.py @@ -14,6 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.vm""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr + from relax.expr import DataTypeImm, PrimValue, StringImm, Tuple +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.vm", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.vm +# fmt: off +_FFI_INIT_FUNC("relax.op.vm", __name__) +if TYPE_CHECKING: + def alloc_storage(_0: RelaxExpr, _1: PrimValue, _2: DataTypeImm, _3: StringImm, /) -> RelaxExpr: ... + def alloc_tensor(_0: RelaxExpr, _1: PrimValue, _2: RelaxExpr, _3: DataTypeImm, /) -> RelaxExpr: ... + def call_tir_dyn(_0: RelaxExpr, _1: Tuple, /) -> RelaxExpr: ... + def kill_object(_0: RelaxExpr, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py index e8f6c42435da..c7f32aaa9b6f 100644 --- a/python/tvm/relax/struct_info.py +++ b/python/tvm/relax/struct_info.py @@ -16,6 +16,18 @@ # under the License. # pylint: disable=invalid-name, unused-import """The struct info nodes of the Relax language.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import EnvFunc, PrimExpr, RelaxExpr, StructInfo, VDevice + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Optional, Union import tvm_ffi @@ -33,6 +45,11 @@ class ObjectStructInfo(StructInfo): """StructInfo of an Object.""" + # tvm-ffi-stubgen(begin): object/relax.ObjectStructInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore @@ -49,6 +66,13 @@ class PrimStructInfo(StructInfo): value. """ + # tvm-ffi-stubgen(begin): object/relax.PrimStructInfo + # fmt: off + value: PrimExpr | None + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + value: Optional[PrimExpr] dtype: str @@ -124,6 +148,13 @@ class ShapeStructInfo(StructInfo): Do not specify values and ndim at the same time. """ + # tvm-ffi-stubgen(begin): object/relax.ShapeStructInfo + # fmt: off + values: Sequence[PrimExpr] | None + ndim: int + # fmt: on + # tvm-ffi-stubgen(end) + values: Optional[List[PrimExpr]] ndim: int span: Span @@ -159,6 +190,15 @@ class TensorStructInfo(StructInfo): Do not specify shape and ndim at the same time. """ + # tvm-ffi-stubgen(begin): object/relax.TensorStructInfo + # fmt: off + shape: RelaxExpr | None + dtype: dtype + vdevice: VDevice | None + ndim: int + # fmt: on + # tvm-ffi-stubgen(end) + shape: Optional[Expr] dtype: str vdevice: Optional[VDevice] @@ -190,6 +230,12 @@ class TupleStructInfo(StructInfo): The struct info of the fields. """ + # tvm-ffi-stubgen(begin): object/relax.TupleStructInfo + # fmt: off + fields: Sequence[StructInfo] + # fmt: on + # tvm-ffi-stubgen(end) + fields: List[StructInfo] span: Span @@ -216,6 +262,15 @@ class FuncStructInfo(StructInfo): we still consider it impure. """ + # tvm-ffi-stubgen(begin): object/relax.FuncStructInfo + # fmt: off + params: Sequence[StructInfo] | None + ret: StructInfo + derive_func: EnvFunc | None + purity: bool + # fmt: on + # tvm-ffi-stubgen(end) + params: Optional[List[StructInfo]] ret: StructInfo derive_func: Optional[EnvFunc] diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 617ba73f09f4..5b19bafa6b05 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -17,6 +17,17 @@ # pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ """Relax transformation passes for testing""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import logging import os from typing import Dict, List, Set, Tuple @@ -86,6 +97,13 @@ class InplaceOpportunity(Object): Indices of arguments that are eligible to be used as in-place targets. """ + # tvm-ffi-stubgen(begin): object/relax.transform.InplaceOpportunity + # fmt: off + binding_idx: IntImm + arg_idxs: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, _binding_idx, _arg_idxs): raise NotImplementedError("Constructor for InplaceOpportunity not exposed!") diff --git a/python/tvm/relax/training/_ffi_api.py b/python/tvm/relax/training/_ffi_api.py index 25f395830341..c7f1729db4e5 100644 --- a/python/tvm/relax/training/_ffi_api.py +++ b/python/tvm/relax/training/_ffi_api.py @@ -15,6 +15,24 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.training""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from relax.expr import Function + from transform import Pass +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.training", __name__) + +# tvm-ffi-stubgen(begin): global/relax.training +# fmt: off +_FFI_INIT_FUNC("relax.training", __name__) +if TYPE_CHECKING: + def AppendLoss(_0: str, _1: Function, _2: int, _3: str | None, /) -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py index 25d6ecd75385..b928aa7c64d7 100644 --- a/python/tvm/relax/transform/_ffi_api.py +++ b/python/tvm/relax/transform/_ffi_api.py @@ -14,6 +14,96 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.transform""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IRModule, IntImm, PrimExpr, RelaxExpr, StructInfo, VDevice + from relax import DataflowBlockPass, FunctionPass, MatchResult + from relax.dpl import DFPattern + from relax.expr import DataflowBlock, Function, Var + from relax.transform import FusionPattern + from tir import IndexMap, PrimFunc, Var + from transform import Pass, PassContext, PassInfo + from tvm_ffi import Object, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.transform", __name__) + +# tvm-ffi-stubgen(begin): global/relax.transform +# fmt: off +_FFI_INIT_FUNC("relax.transform", __name__) +if TYPE_CHECKING: + def AdjustMatmulOrder() -> Pass: ... + def AllocateWorkspace() -> Pass: ... + def AlterOpImpl(_0: Mapping[str, PrimFunc], _1: Mapping[str, Sequence[IndexMap]], _2: Mapping[str, Sequence[Sequence[IntImm]] | None], _3: Mapping[str, Sequence[Sequence[IntImm]] | None], /) -> Pass: ... + def AnnotateTIROpPattern() -> Pass: ... + def AttachAttrLayoutFreeBuffers() -> Pass: ... + def AttachGlobalSymbol() -> Pass: ... + def BindParams(_0: str, _1: Mapping[Any, Object], /) -> Pass: ... + def BindSymbolicVars(_0: Mapping[Var | str, PrimExpr], _1: str | None, /) -> Pass: ... + def BundleModelParams(_0: str | None, /) -> Pass: ... + def CallTIRRewrite() -> Pass: ... + def CanonicalizeBindings() -> Pass: ... + def CombineParallelMatmul(_0: Callable[[Var, Sequence[Var], Sequence[Var], Mapping[Var, RelaxExpr]], bool], /) -> Pass: ... + def ComputePrimValue() -> Pass: ... + def ConvertLayout(_0: Mapping[str, Sequence[str]], /) -> Pass: ... + def ConvertToDataflow(_0: int, /) -> Pass: ... + def DataflowUseInplaceCalls() -> Pass: ... + def DeadCodeElimination(_0: Sequence[str], /) -> Pass: ... + def DecomposeOpsForInference(_0: str | None, /) -> Pass: ... + def DecomposeOpsForTraining(_0: str | None, /) -> Pass: ... + def EliminateCommonSubexpr(_0: bool, /) -> Pass: ... + def ExpandMatmulOfSum() -> Pass: ... + def ExpandTupleArguments() -> Pass: ... + def FewShotTuning(_0: int, _1: bool, /) -> Pass: ... + def FoldConstant() -> Pass: ... + def FuseOps(_0: int, /) -> Pass: ... + def FuseOpsByPattern(_0: Sequence[FusionPattern], _1: bool, _2: bool, _3: Sequence[str], /) -> Pass: ... + def FuseTIR() -> Pass: ... + def FusionPattern(_0: str, _1: DFPattern, _2: Mapping[str, DFPattern], _3: Callable[..., Any] | None, _4: Callable[..., Any] | None, /) -> FusionPattern: ... + def Gradient(_0: str, _1: Sequence[Var] | None, _2: int, /) -> Pass: ... + def InlinePrivateFunctions() -> Pass: ... + def KillAfterLastUse() -> Pass: ... + def LambdaLift() -> Pass: ... + def LazyGetInput() -> Pass: ... + def LazySetOutput() -> Pass: ... + def LegalizeOps(_0: Mapping[str, Callable[..., Any]] | None, _1: Sequence[str] | None, _2: bool, /) -> Pass: ... + def LiftTransformParams(_0: IntImm | Sequence[str], /) -> Pass: ... + def LowerAllocTensor() -> Pass: ... + def LowerRuntimeBuiltin() -> Pass: ... + def MakeDataflowBlockPass(_0: Callable[[ObjectRValueRef[DataflowBlock], IRModule, PassContext], DataflowBlock], _1: PassInfo, /) -> DataflowBlockPass: ... + def MakeFunctionPass(_0: Callable[[ObjectRValueRef[Function], IRModule, PassContext], Function], _1: PassInfo, /) -> FunctionPass: ... + def MergeCompositeFunctions() -> Pass: ... + def MetaScheduleApplyDatabase(_0: str | None, _1: bool, /) -> Pass: ... + def MetaScheduleTuneIRMod(_0: Mapping[str, Tensor], _1: str, _2: IntImm, _3: IntImm | None, _4: Sequence[str] | None, /) -> Pass: ... + def MetaScheduleTuneTIR(_0: str, _1: IntImm, /) -> Pass: ... + def Normalize() -> Pass: ... + def NormalizeGlobalVar() -> Pass: ... + def RealizeVDevice() -> Pass: ... + def RemovePurityChecking() -> Pass: ... + def RemoveUnusedOutputs() -> Pass: ... + def RemoveUnusedParameters() -> Pass: ... + def ReorderPermuteDimsAfterConcat() -> Pass: ... + def ReorderTakeAfterMatmul() -> Pass: ... + def RewriteCUDAGraph() -> Pass: ... + def RewriteDataflowReshape() -> Pass: ... + def RunCodegen(_0: Mapping[str, Mapping[str, Any]] | None, _1: Sequence[str], /) -> Pass: ... + def SpecializePrimFuncBasedOnCallSite() -> Pass: ... + def SplitCallTIRByPattern(_0: Sequence[PrimFunc], _1: Callable[[Sequence[MatchResult]], Sequence[Any]], /) -> Pass: ... + def SplitLayoutRewritePreproc() -> Pass: ... + def StaticPlanBlockMemory() -> Pass: ... + def ToMixedPrecision(_0: dtype, _1: Sequence[str] | None, /) -> Pass: ... + def ToNonDataflow() -> Pass: ... + def TopologicalSort(_0: str, _1: str, /) -> Pass: ... + def UpdateParamStructInfo(_0: Callable[[Var], StructInfo | None], /) -> Pass: ... + def UpdateVDevice(_0: VDevice, _1: int, /) -> Pass: ... + def VMShapeLower(_0: bool, /) -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 46efc17e3d4f..6ab829299170 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -16,6 +16,21 @@ # under the License. # pylint: disable=invalid-name """Relax transformation passes.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import RelaxExpr + from relax.dpl import DFPattern + from relax.expr import Var + from transform import PassInfo + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import functools import inspect import types @@ -42,11 +57,23 @@ class FunctionPass(tvm.ir.transform.Pass): pass class should be created through `function_pass`. """ + # tvm-ffi-stubgen(begin): object/relax.FunctionPass + # fmt: off + pass_info: PassInfo + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.DataflowBlockPass") class DataflowBlockPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.DataflowBlock in a module.""" + # tvm-ffi-stubgen(begin): object/relax.DataflowBlockPass + # fmt: off + pass_info: PassInfo + # fmt: on + # tvm-ffi-stubgen(end) + def Gradient( func_name: str, require_grads: Optional[Union[Var, List[Var]]] = None, target_index: int = 0 @@ -847,6 +874,16 @@ class PatternCheckContext(Object): matched expression. """ + # tvm-ffi-stubgen(begin): object/relax.transform.PatternCheckContext + # fmt: off + matched_expr: RelaxExpr + annotated_expr: Mapping[str, RelaxExpr] + matched_bindings: Mapping[Var, RelaxExpr] + var_usages: Mapping[Var, Sequence[Var]] + value_to_bound_var: Mapping[RelaxExpr, Var] + # fmt: on + # tvm-ffi-stubgen(end) + matched_expr: Expr annotated_expr: Mapping[str, Expr] matched_bindings: Mapping[Var, Expr] @@ -877,6 +914,16 @@ class FusionPattern(Object): The function to check whether the match result is accepted. """ + # tvm-ffi-stubgen(begin): object/relax.transform.FusionPattern + # fmt: off + name: str + pattern: DFPattern + annotation_patterns: Mapping[str, DFPattern] + check: Callable[..., Any] | None + attrs_getter: Callable[..., Any] | None + # fmt: on + # tvm-ffi-stubgen(end) + name: str pattern: DFPattern annotation_patterns: Mapping[str, DFPattern] diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index ebf757f38136..d1ba64d76223 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -16,6 +16,16 @@ # under the License. # pylint: disable=invalid-name, unused-import """The type nodes of the Relax language.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.ir import Type, TupleType, FuncType, Span @@ -32,6 +42,12 @@ class ShapeType(Type): The size of the shape. """ + # tvm-ffi-stubgen(begin): object/relax.ShapeType + # fmt: off + ndim: int + # fmt: on + # tvm-ffi-stubgen(end) + # TODO(relax-team): consider make ndim mandatory def __init__(self, ndim: int = -1, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore @@ -42,6 +58,11 @@ class ObjectType(Type): """A type that corresponds to tvm::runtime::Object, is base of all possible object values in TVM.""" + # tvm-ffi-stubgen(begin): object/relax.ObjectType + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) # type: ignore @@ -61,6 +82,13 @@ class TensorType(Type): The content data type. """ + # tvm-ffi-stubgen(begin): object/relax.DynTensorType + # fmt: off + ndim: int + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.TensorType, ndim, dtype, span) # type: ignore @@ -69,5 +97,10 @@ def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: class PackedFuncType(Type): """The type of ExternFunc in Relax.""" + # tvm-ffi-stubgen(begin): object/relax.PackedFuncType + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.PackedFuncType, span) # type: ignore diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py index 80fd79e31348..2ba1da29eabb 100644 --- a/python/tvm/rpc/_ffi_api.py +++ b/python/tvm/rpc/_ffi_api.py @@ -15,7 +15,33 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.rpc""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tvm_ffi import Module + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("rpc", __name__) + +# tvm-ffi-stubgen(begin): global/rpc +# fmt: off +_FFI_INIT_FUNC("rpc", __name__) +if TYPE_CHECKING: + def Connect(*args: Any) -> Any: ... + def CreateEventDrivenServer(_0: Callable[..., Any], _1: str, _2: str, /) -> Callable[..., Any]: ... + def CreatePipeClient(*args: Any) -> Any: ... + def ImportRemoteModule(_0: Module, _1: Module, /) -> None: ... + def LoadRemoteModule(_0: Module, _1: str, /) -> Module: ... + def LocalSession() -> Module: ... + def ReturnException(_0: int, _1: str, /) -> None: ... + def ServerLoop(*args: Any) -> Any: ... + def SessTableIndex(*args: Any) -> Any: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/runtime/_ffi_api.py b/python/tvm/runtime/_ffi_api.py index c713b379c384..8774925ca0e8 100644 --- a/python/tvm/runtime/_ffi_api.py +++ b/python/tvm/runtime/_ffi_api.py @@ -15,7 +15,58 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.runtime""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ctypes import c_void_p + from tvm_ffi import Device, Module, Shape, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) + # Exports functions registered in runtime namespace. -tvm_ffi.init_ffi_api("runtime", __name__) +# tvm-ffi-stubgen(begin): global/runtime +# fmt: off +_FFI_INIT_FUNC("runtime", __name__) +if TYPE_CHECKING: + def CSourceModuleCreate(_0: str, _1: str, _2: Sequence[str] | None, _3: Sequence[str] | None, /) -> Module: ... + def Device_SetStream(_0: Device, _1: int, /) -> None: ... + def Device_StreamCreate(_0: Device, /) -> int: ... + def Device_StreamFree(_0: Device, _1: int, /) -> None: ... + def Device_StreamSync(_0: Device, _1: int, /) -> None: ... + def Device_StreamSyncFromTo(_0: Device, _1: int, _2: int, /) -> None: ... + def GetCudaDeviceCount() -> int: ... + def GetCudaFreeMemory() -> str: ... + def GetDeviceAttr(*args: Any) -> Any: ... + def LoadParams(_0: bytes, /) -> Mapping[str, Tensor]: ... + def LoadParamsFromFile(_0: str, /) -> Mapping[str, Tensor]: ... + def ModuleImportsBlobName() -> str: ... + def ModuleLoadStaticLibrary(_0: str, _1: Sequence[str], /) -> Module: ... + def ModulePackImportsToC(_0: Module, _1: bool, _2: str, /) -> str: ... + def ModulePackImportsToLLVM(_0: Module, _1: bool, _2: str, _3: str, /) -> Module: ... + def ModulePackImportsToTensor(_0: Module, /) -> Tensor: ... + def NumThreads() -> int: ... + def RPCTimeEvaluator(_0: Module | None, _1: str, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: str, /) -> Callable[..., Any]: ... + def RuntimeEnabled(_0: str, /) -> bool: ... + def SaveParams(_0: Mapping[str, Tensor], /) -> bytes: ... + def SaveParamsToFile(_0: Mapping[str, Tensor], _1: str, /) -> None: ... + def SourceModuleCreate(_0: str, _1: str, /) -> Module: ... + def TVMSetStream(_0: int, _1: int, _2: c_void_p, /) -> None: ... + def TVMTensorAllocWithScope(_0: Shape, _1: dtype, _2: Device, _3: str | None, /) -> Tensor: ... + def TVMTensorCopyFromBytes(_0: Tensor, _1: c_void_p, _2: int, /) -> None: ... + def TVMTensorCopyFromTo(_0: Tensor, _1: Tensor, /) -> None: ... + def TVMTensorCopyToBytes(_0: Tensor, _1: c_void_p, _2: int, /) -> None: ... + def TVMTensorCreateView(_0: Tensor, _1: Shape, _2: dtype, _3: int, /) -> Tensor: ... + def _datatype_get_type_registered(*args: Any) -> Any: ... + def config_threadpool(*args: Any) -> Any: ... + def cuTensorMapEncodeTiled(*args: Any) -> Any: ... + def get_cuda_stream() -> c_void_p: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py index 3affbf55d563..68da1e0ca79d 100644 --- a/python/tvm/runtime/_tensor.py +++ b/python/tvm/runtime/_tensor.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-import, redefined-outer-name """Runtime Tensor API""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import ctypes import warnings from typing import Optional @@ -72,6 +74,11 @@ class Tensor(tvm_ffi.core.Tensor): how can we use TVM in existing project which might have their own array containers. """ + # tvm-ffi-stubgen(begin): object/ffi.Tensor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __setitem__(self, in_slice, value): """Set ndarray value""" if ( diff --git a/python/tvm/runtime/disco/_ffi_api.py b/python/tvm/runtime/disco/_ffi_api.py index 2caeef293ea5..f69842d447ac 100644 --- a/python/tvm/runtime/disco/_ffi_api.py +++ b/python/tvm/runtime/disco/_ffi_api.py @@ -15,6 +15,63 @@ # specific language governing permissions and limitations # under the License. """FFI APIs from C++""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from runtime.disco import DRef, Session + from tvm_ffi import Device, Module, Object, Shape, Tensor, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("runtime.disco", __name__) + +# tvm-ffi-stubgen(begin): global/runtime.disco +# fmt: off +_FFI_INIT_FUNC("runtime.disco", __name__) +if TYPE_CHECKING: + def DRefDebugCopyFrom(_0: DRef, _1: int, _2: Any, /) -> None: ... + def DRefDebugGetFromRemote(_0: DRef, _1: int, /) -> Any: ... + def RemoteSocketSession(_0: str, _1: int, _2: int, /) -> None: ... + def SessionCallPacked(*args: Any) -> Any: ... + def SessionCopyFromWorker0(_0: Session, _1: Tensor, _2: DRef, /) -> None: ... + def SessionCopyToWorker0(_0: Session, _1: Tensor, _2: DRef, /) -> None: ... + def SessionGetGlobalFunc(_0: Session, _1: str, /) -> DRef: ... + def SessionGetNumWorkers(_0: Session, /) -> int: ... + def SessionInitCCL(_0: Session, _1: str, _2: Shape, /) -> None: ... + def SessionProcess(_0: int, _1: int, _2: str, _3: str, /) -> Session: ... + def SessionShutdown(_0: Session, /) -> None: ... + def SessionSyncWorker(_0: Session, _1: int, /) -> None: ... + def SessionThreaded(_0: int, _1: int, /) -> Session: ... + def ShardLoader(_0: str, _1: str, _2: str, _3: Module | None, /) -> Object: ... + def ShardLoaderLoad(_0: Object, _1: Shape, /) -> Tensor: ... + def ShardLoaderLoadAll(_0: Object, /) -> Sequence[Tensor]: ... + def ShardLoaderLoadAllPresharded(_0: Object, /) -> Sequence[Tensor]: ... + def ShardLoaderLoadParamOnWorker0(_0: Object, _1: int, /) -> Tensor: ... + def ShardLoaderLoadPresharded(_0: Object, _1: Shape, /) -> Tensor: ... + def SocketSession(_0: int, _1: int, _2: int, _3: str, _4: int, /) -> Session: ... + def WorkerProcess(_0: int, _1: int, _2: int, _3: int, _4: int, /) -> None: ... + def allgather(_0: Tensor, _1: bool, _2: Tensor, /) -> None: ... + def allreduce(_0: Tensor, _1: Shape, _2: bool, _3: Tensor, /) -> None: ... + def bind_worker_to_cpu_core(_0: Shape, /) -> None: ... + def broadcast_from_worker0(_0: Tensor, _1: bool, _2: Tensor, /) -> None: ... + def device() -> Device: ... + def empty(_0: Shape, _1: dtype, _2: Device | None, _3: bool, _4: bool, /) -> Tensor | None: ... + def gather_to_worker0(_0: Tensor, _1: bool, _2: Tensor | None, /) -> None: ... + def load_vm_module(_0: str, _1: Device | None, /) -> Module: ... + def recv_from_prev_group(_0: Tensor, /) -> None: ... + def recv_from_worker(_0: Tensor, _1: int, /) -> None: ... + def recv_from_worker0(_0: Tensor, /) -> None: ... + def scatter_from_worker0(_0: Tensor | None, _1: bool, _2: Tensor, /) -> None: ... + def send_to_next_group(_0: Tensor, /) -> None: ... + def send_to_worker(_0: Tensor, _1: int, /) -> None: ... + def socket_session_init_workers(_0: int, _1: int, _2: int, _3: int, /) -> None: ... + def worker_id() -> Shape: ... + def worker_rank() -> int: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index f2c2dfc791ab..be7791be85f5 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -18,6 +18,8 @@ with the distributed runtime. """ +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import logging import os import pickle @@ -40,6 +42,11 @@ class DRef(Object): to each object, and the worker process uses this id to refer to the object residing on itself. """ + # tvm-ffi-stubgen(begin): object/runtime.disco.DRef + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def debug_get_from_remote(self, worker_id: int) -> Any: """Get the value of a DRef from a remote worker. It is only used for debugging purposes. @@ -103,6 +110,11 @@ class Session(Object): """A Disco interactive session. It allows users to interact with the Disco command queue with various PackedFunc calling convention.""" + # tvm-ffi-stubgen(begin): object/runtime.disco.Session + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def _get_cached_method(self, name: str) -> Callable: if "_cache" not in self.__dict__: cache = self._cache = {} # pylint: disable=attribute-defined-outside-init @@ -532,6 +544,11 @@ def _clear_ipc_memory_pool(self): class ThreadedSession(Session): """A Disco session backed by multi-threading.""" + # tvm-ffi-stubgen(begin): object/runtime.disco.ThreadedSession + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, num_workers: int, num_groups: int = 1) -> None: """Create a disco session backed by multiple threads in the same process.""" self.__init_handle_by_constructor__( @@ -545,6 +562,11 @@ def __init__(self, num_workers: int, num_groups: int = 1) -> None: class ProcessSession(Session): """A Disco session backed by pipe-based multi-processing.""" + # tvm-ffi-stubgen(begin): object/runtime.disco.ProcessSession + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, num_workers: int, @@ -593,6 +615,11 @@ def _create_socket_session_local_workers(num_workers) -> Session: class SocketSession(Session): """A Disco session backed by socket-based multi-node communication.""" + # tvm-ffi-stubgen(begin): object/runtime.disco.SocketSession + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, num_nodes: int, diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 3ca831ac4200..e2074eb28e41 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -16,6 +16,8 @@ # under the License. """Registration of profiling objects in python.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Dict, Sequence, Optional from ... import ffi as _ffi from . import _ffi_api @@ -35,6 +37,9 @@ class Report(Object): Per-device metrics collected over the entire run. """ + # tvm-ffi-stubgen(begin): object/runtime.profiling.Report + # tvm-ffi-stubgen(end) + def __init__( self, calls: Sequence[Dict[str, Object]], @@ -176,6 +181,9 @@ def from_json(cls, s): class Count(Object): """A integer count of something""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.Count + # tvm-ffi-stubgen(end) + def __init__(self, count: int): self.__init_handle_by_constructor__(_ffi_api.Count, count) @@ -184,6 +192,9 @@ def __init__(self, count: int): class Duration(Object): """A duration of something""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.Duration + # tvm-ffi-stubgen(end) + def __init__(self, duration: float): self.__init_handle_by_constructor__(_ffi_api.Duration, duration) @@ -192,6 +203,9 @@ def __init__(self, duration: float): class Percent(Object): """A Percent of something""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.Percent + # tvm-ffi-stubgen(end) + def __init__(self, percent: float): self.__init_handle_by_constructor__(_ffi_api.Percent, percent) @@ -200,6 +214,9 @@ def __init__(self, percent: float): class Ratio(Object): """A Ratio of two things""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.Ratio + # tvm-ffi-stubgen(end) + def __init__(self, ratio: float): self.__init_handle_by_constructor__(_ffi_api.Ratio, ratio) @@ -208,11 +225,17 @@ def __init__(self, ratio: float): class MetricCollector(Object): """Interface for user defined profiling metric collection.""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.MetricCollector + # tvm-ffi-stubgen(end) + @_ffi.register_object("runtime.profiling.DeviceWrapper") class DeviceWrapper(Object): """Wraps a tvm.runtime.Device""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.DeviceWrapper + # tvm-ffi-stubgen(end) + def __init__(self, dev: Device): self.__init_handle_by_constructor__(_ffi_api.DeviceWrapper, dev) @@ -279,6 +302,9 @@ class PAPIMetricCollector(MetricCollector): Application Programming Interface (PAPI). """ + # tvm-ffi-stubgen(begin): object/runtime.profiling.PAPIMetricCollector + # tvm-ffi-stubgen(end) + def __init__(self, metric_names: Optional[Dict[Device, Sequence[str]]] = None): """ Parameters diff --git a/python/tvm/runtime/profiling/_ffi_api.py b/python/tvm/runtime/profiling/_ffi_api.py index 883e3ca6e778..927d0a77ab5c 100644 --- a/python/tvm/runtime/profiling/_ffi_api.py +++ b/python/tvm/runtime/profiling/_ffi_api.py @@ -15,6 +15,36 @@ # specific language governing permissions and limitations # under the License. """FFI for profiling""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from runtime.profiling import DeviceWrapper, MetricCollector, Report + from tvm_ffi import Device, Module, Object + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("runtime.profiling", __name__) + +# tvm-ffi-stubgen(begin): global/runtime.profiling +# fmt: off +_FFI_INIT_FUNC("runtime.profiling", __name__) +if TYPE_CHECKING: + def AsCSV(_0: Report, /) -> str: ... + def AsJSON(_0: Report, /) -> str: ... + def AsTable(_0: Report, _1: bool, _2: bool, _3: bool, /) -> str: ... + def Count(_0: int, /) -> Object: ... + def DeviceWrapper(_0: Device, /) -> DeviceWrapper: ... + def Duration(_0: float, /) -> Object: ... + def FromJSON(_0: str, /) -> Report: ... + def Percent(_0: float, /) -> Object: ... + def ProfileFunction(_0: Module, _1: str, _2: int, _3: int, _4: int, _5: Sequence[MetricCollector], /) -> Callable[..., Any]: ... + def Ratio(_0: float, /) -> Object: ... + def Report(_0: Sequence[Mapping[str, Any]], _1: Mapping[str, Mapping[str, Any]], _2: Mapping[str, Any], /) -> Report: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 7442cd99172f..4d2d561aa1ff 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Configuration of TVMScript printer""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from tvm_ffi import Object, dtype + from tvm_ffi.reflection import AccessPath +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import os from typing import Dict, List, Optional, Sequence @@ -29,6 +41,31 @@ class PrinterConfig(Object): """Configuration of TVMScript printer""" + # tvm-ffi-stubgen(begin): object/script.PrinterConfig + # fmt: off + binding_names: Sequence[str] + show_meta: bool + ir_prefix: str + tir_prefix: str + relax_prefix: str + module_alias: str + buffer_dtype: dtype + int_dtype: dtype + float_dtype: dtype + verbose_expr: bool + indent_spaces: int + print_line_numbers: bool + num_context_lines: int + syntax_sugar: bool + show_object_address: bool + show_all_struct_info: bool + path_to_underline: Sequence[AccessPath] + path_to_annotate: Mapping[AccessPath, str] + obj_to_underline: Sequence[Object] + obj_to_annotate: Mapping[Object, str] + # fmt: on + # tvm-ffi-stubgen(end) + binding_names: Sequence[str] show_meta: bool ir_prefix: str diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/_ffi_api.py index 1354d3f2ec2c..f866edba616a 100644 --- a/python/tvm/script/_ffi_api.py +++ b/python/tvm/script/_ffi_api.py @@ -14,7 +14,25 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from tir import Buffer, PrimFunc +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script", __name__) + +# tvm-ffi-stubgen(begin): global/script +# fmt: off +_FFI_INIT_FUNC("script", __name__) +if TYPE_CHECKING: + def Complete(_0: PrimFunc, _1: Sequence[Buffer], /) -> PrimFunc: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/_ffi_api.py b/python/tvm/script/ir_builder/_ffi_api.py index c8a9597d5292..7b60db417169 100644 --- a/python/tvm/script/ir_builder/_ffi_api.py +++ b/python/tvm/script/ir_builder/_ffi_api.py @@ -15,6 +15,34 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from script.ir_builder import IRBuilder, IRBuilderFrame + from tvm_ffi import Object + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.ir_builder", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder +# fmt: off +_FFI_INIT_FUNC("script.ir_builder", __name__) +if TYPE_CHECKING: + def IRBuilder() -> IRBuilder: ... + def IRBuilderCurrent() -> IRBuilder: ... + def IRBuilderEnter(_0: IRBuilder, /) -> None: ... + def IRBuilderExit(_0: IRBuilder, /) -> None: ... + def IRBuilderFrameAddCallback(_0: IRBuilderFrame, _1: Callable[[], None], /) -> None: ... + def IRBuilderFrameEnter(_0: IRBuilderFrame, /) -> None: ... + def IRBuilderFrameExit(_0: IRBuilderFrame, /) -> None: ... + def IRBuilderGet(_0: IRBuilder, /) -> Object: ... + def IRBuilderIsInScope() -> bool: ... + def IRBuilderName(_0: str, _1: Object, /) -> Object: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/ir/_ffi_api.py b/python/tvm/script/ir_builder/ir/_ffi_api.py index e319c3d4612e..831faf434166 100644 --- a/python/tvm/script/ir_builder/ir/_ffi_api.py +++ b/python/tvm/script/ir_builder/ir/_ffi_api.py @@ -15,6 +15,34 @@ # specific language governing permissions and limitations # under the License. """FFI APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import BaseFunc, GlobalInfo, GlobalVar, VDevice + from script.ir_builder import IRModuleFrame + from tvm_ffi import Object + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder.ir +# fmt: off +_FFI_INIT_FUNC("script.ir_builder.ir", __name__) +if TYPE_CHECKING: + def DeclFunction(_0: str, _1: BaseFunc, /) -> GlobalVar: ... + def DefFunction(_0: str, _1: BaseFunc, /) -> None: ... + def IRModule() -> IRModuleFrame: ... + def LookupVDevice(_0: str, _1: int, /) -> VDevice: ... + def ModuleAttrs(_0: Mapping[str, Any], _1: bool, /) -> None: ... + def ModuleGetAttr(_0: str, /) -> Object | None: ... + def ModuleGlobalInfos(_0: Mapping[str, Sequence[GlobalInfo]], /) -> None: ... + def ModuleSetAttr(_0: str, _1: Object | None, _2: bool, /) -> None: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py index f6c53336ff4c..ef1b58ff3697 100644 --- a/python/tvm/script/ir_builder/relax/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -15,6 +15,42 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder.relax""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IntImm, RelaxExpr, StructInfo + from relax.expr import Var, VarBinding + from script.ir_builder.relax import BlockFrame, ElseFrame, FunctionFrame, IfFrame, SeqExprFrame, ThenFrame + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder.relax +# fmt: off +_FFI_INIT_FUNC("script.ir_builder.relax", __name__) +if TYPE_CHECKING: + def Arg(_0: str, _1: StructInfo, /) -> Var: ... + def BindingBlock() -> BlockFrame: ... + def Dataflow() -> BlockFrame: ... + def DataflowBlockOutput(_0: Sequence[Var], /) -> None: ... + def Else() -> ElseFrame: ... + def Emit(_0: RelaxExpr, _1: StructInfo | None, /) -> Var: ... + def EmitMatchCast(_0: RelaxExpr, _1: StructInfo, /) -> Var: ... + def EmitVarBinding(_0: VarBinding, /) -> Var: ... + def FuncAttrs(_0: Mapping[str, Any], /) -> None: ... + def FuncName(_0: str, /) -> None: ... + def FuncRetStructInfo(_0: StructInfo, /) -> None: ... + def FuncRetValue(_0: RelaxExpr, /) -> None: ... + def Function(_0: IntImm, _1: IntImm, /) -> FunctionFrame: ... + def If(_0: RelaxExpr, /) -> IfFrame: ... + def SeqExpr() -> SeqExprFrame: ... + def Then() -> ThenFrame: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py index b82fa37e8f3f..5eed61bab3b6 100644 --- a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py @@ -15,8 +15,26 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder.relax.distributed""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import RelaxExpr + from relax import DTensorStructInfo + from relax.expr import Tuple +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api( - "script.ir_builder.relax.distributed", __name__ -) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder.relax.distributed +# fmt: off +_FFI_INIT_FUNC("script.ir_builder.relax.distributed", __name__) +if TYPE_CHECKING: + def call_tir_dist(_0: RelaxExpr, _1: Tuple, _2: Sequence[DTensorStructInfo], _3: RelaxExpr | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/tir/_ffi_api.py b/python/tvm/script/ir_builder/tir/_ffi_api.py index 4385b2ec13d0..b612a32af8e5 100644 --- a/python/tvm/script/ir_builder/tir/_ffi_api.py +++ b/python/tvm/script/ir_builder/tir/_ffi_api.py @@ -15,6 +15,219 @@ # specific language governing permissions and limitations # under the License. """FFI APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IntImm, PrimExpr, Range, Type + from script.ir_builder.tir import AllocateConstFrame, AllocateFrame, AssertFrame, AttrFrame, BlockFrame, BlockInitFrame, DeclBufferFrame, ElseFrame, ForFrame, IfFrame, LaunchThreadFrame, LetFrame, PrimFuncFrame, RealizeFrame, ThenFrame, WhileFrame + from tir import Buffer, BufferRegion, Var + from tvm_ffi import Object, Tensor, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder.tir +# fmt: off +_FFI_INIT_FUNC("script.ir_builder.tir", __name__) +if TYPE_CHECKING: + def AllocBuffer(_0: Sequence[PrimExpr], _1: dtype, _2: Var | None, _3: Sequence[PrimExpr], _4: PrimExpr, _5: str, _6: int, _7: int, _8: str, _9: Sequence[IntImm] | None, /) -> Buffer: ... + def Allocate(_0: Sequence[PrimExpr], _1: dtype, _2: str, _3: PrimExpr | None, _4: Mapping[str, Any] | None, /) -> AllocateFrame: ... + def AllocateConst(_0: Tensor, _1: dtype, _2: Sequence[PrimExpr], _3: Mapping[str, Any] | None, /) -> AllocateConstFrame: ... + def Arg(_0: str, _1: Object, /) -> Object: ... + def Assert(_0: PrimExpr, _1: str, /) -> AssertFrame: ... + def Attr(_0: Any, _1: str, _2: PrimExpr, /) -> AttrFrame: ... + def AxisOpaque(_0: Range, _1: PrimExpr, _2: dtype, /) -> Var: ... + def AxisReduce(_0: Range, _1: PrimExpr, _2: dtype, /) -> Var: ... + def AxisRemap(_0: str, _1: Sequence[PrimExpr], _2: dtype, /) -> Sequence[Var]: ... + def AxisScan(_0: Range, _1: PrimExpr, _2: dtype, /) -> Var: ... + def AxisSpatial(_0: Range, _1: PrimExpr, _2: dtype, /) -> Var: ... + def BFloat16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Block(_0: str, _1: bool, /) -> BlockFrame: ... + def BlockAttrs(_0: Mapping[str, Any], /) -> None: ... + def Boolean(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Buffer(_0: Sequence[PrimExpr], _1: dtype, _2: str, _3: Var | None, _4: Sequence[PrimExpr] | None, _5: PrimExpr | None, _6: str, _7: int, _8: int, _9: str, _10: Sequence[IntImm] | None, /) -> Buffer: ... + def BufferStore(_0: Buffer, _1: PrimExpr, _2: Sequence[PrimExpr], _3: PrimExpr | None, /) -> None: ... + def DeclBuffer(_0: Sequence[PrimExpr], _1: dtype, _2: str, _3: Var | None, _4: Sequence[PrimExpr] | None, _5: PrimExpr | None, _6: str, _7: int, _8: int, _9: str, _10: Sequence[IntImm] | None, /) -> DeclBufferFrame: ... + def Else() -> ElseFrame: ... + def EnvThread(_0: str, _1: dtype, /) -> Var: ... + def Evaluate(_0: PrimExpr, /) -> None: ... + def Float16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FN(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FN(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FN(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZ(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FN(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZ(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZ(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNU(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def FuncAttrs(_0: Mapping[str, Any], /) -> None: ... + def FuncName(_0: str, /) -> None: ... + def FuncRet(_0: Type, /) -> Type: ... + def Grid(_0: Sequence[PrimExpr], /) -> ForFrame: ... + def Handle(_0: dtype, _1: str, _2: bool, _3: bool, /) -> Var: ... + def If(_0: PrimExpr, /) -> IfFrame: ... + def Init() -> BlockInitFrame: ... + def Int16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def LaunchThread(_0: Var | str, _1: PrimExpr, /) -> LaunchThreadFrame: ... + def LegacyLetStmt(_0: Var, _1: PrimExpr, /) -> LetFrame: ... + def LetStmt(_0: PrimExpr, _1: Type | None, _2: Var | None, /) -> LetFrame: ... + def MatchBuffer(_0: Object, _1: Sequence[PrimExpr], _2: dtype, _3: Var | None, _4: Sequence[PrimExpr], _5: PrimExpr, _6: str, _7: int, _8: int, _9: str, _10: Sequence[IntImm] | None, /) -> Buffer: ... + def Parallel(_0: PrimExpr, _1: PrimExpr, _2: Mapping[str, Any] | None, _3: PrimExpr | None, /) -> ForFrame: ... + def PrimFunc(_0: bool, /) -> PrimFuncFrame: ... + def Ptr(_0: dtype, _1: str, _2: bool, /) -> PrimExpr: ... + def Reads(_0: Sequence[Object], /) -> None: ... + def Realize(_0: BufferRegion, _1: str, _2: PrimExpr, /) -> RealizeFrame: ... + def Serial(_0: PrimExpr, _1: PrimExpr, _2: Mapping[str, Any] | None, _3: PrimExpr | None, /) -> ForFrame: ... + def TensormapHandle() -> Var: ... + def Then() -> ThenFrame: ... + def ThreadBinding(_0: PrimExpr, _1: PrimExpr, _2: str, _3: Mapping[str, Any] | None, /) -> ForFrame: ... + def UInt16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Unroll(_0: PrimExpr, _1: PrimExpr, _2: Mapping[str, Any] | None, _3: PrimExpr | None, /) -> ForFrame: ... + def Vectorized(_0: PrimExpr, _1: PrimExpr, _2: Mapping[str, Any] | None, _3: PrimExpr | None, /) -> ForFrame: ... + def Void(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Where(_0: PrimExpr, /) -> None: ... + def While(_0: PrimExpr, /) -> WhileFrame: ... + def Writes(_0: Sequence[Object], /) -> None: ... + def max(_0: PrimExpr, _1: PrimExpr, /) -> PrimExpr: ... + def min(_0: PrimExpr, _1: PrimExpr, /) -> PrimExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py index 967d0d824ba2..2b1b920e46fc 100644 --- a/python/tvm/script/printer/_ffi_api.py +++ b/python/tvm/script/printer/_ffi_api.py @@ -15,6 +15,61 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.printer""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from script import PrinterConfig + from script.printer import AssertDoc, AssignDoc, AttrAccessDoc, CallDoc, ClassDoc, CommentDoc, DictDoc, Doc, DocStringDoc, ExprDoc, ExprStmtDoc, ForDoc, FunctionDoc, IdDoc, IfDoc, IndexDoc, LambdaDoc, ListDoc, LiteralDoc, OperationDoc, ReturnDoc, ScopeDoc, SliceDoc, StmtBlockDoc, StmtDoc, TupleDoc, WhileDoc + from tvm_ffi import Object + from tvm_ffi.reflection import AccessPath +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.printer", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.printer +# fmt: off +_FFI_INIT_FUNC("script.printer", __name__) +if TYPE_CHECKING: + def AssertDoc(_0: ExprDoc, _1: ExprDoc | None, /) -> AssertDoc: ... + def AssignDoc(_0: ExprDoc, _1: ExprDoc | None, _2: ExprDoc | None, /) -> AssignDoc: ... + def AttrAccessDoc(_0: ExprDoc, _1: str, /) -> AttrAccessDoc: ... + def CallDoc(_0: ExprDoc, _1: Sequence[ExprDoc], _2: Sequence[str], _3: Sequence[ExprDoc], /) -> CallDoc: ... + def ClassDoc(_0: IdDoc, _1: Sequence[ExprDoc], _2: Sequence[StmtDoc], /) -> ClassDoc: ... + def CommentDoc(_0: str, /) -> CommentDoc: ... + def DictDoc(_0: Sequence[ExprDoc], _1: Sequence[ExprDoc], /) -> DictDoc: ... + def DocSetSourcePaths(_0: Doc, _1: Sequence[AccessPath], /) -> None: ... + def DocStringDoc(_0: str, /) -> DocStringDoc: ... + def DocToPythonScript(_0: Doc, _1: PrinterConfig, /) -> str: ... + def ExprDocAttr(_0: ExprDoc, _1: str, /) -> ExprDoc: ... + def ExprDocCall(_0: ExprDoc, _1: Sequence[ExprDoc], _2: Sequence[str], _3: Sequence[ExprDoc], /) -> ExprDoc: ... + def ExprDocIndex(_0: ExprDoc, _1: Sequence[Doc], /) -> ExprDoc: ... + def ExprStmtDoc(_0: ExprDoc, /) -> ExprStmtDoc: ... + def ForDoc(_0: ExprDoc, _1: ExprDoc, _2: Sequence[StmtDoc], /) -> ForDoc: ... + def FunctionDoc(_0: IdDoc, _1: Sequence[AssignDoc], _2: Sequence[ExprDoc], _3: ExprDoc | None, _4: Sequence[StmtDoc], /) -> FunctionDoc: ... + def IdDoc(_0: str, /) -> IdDoc: ... + def IfDoc(_0: ExprDoc, _1: Sequence[StmtDoc], _2: Sequence[StmtDoc], /) -> IfDoc: ... + def IndexDoc(_0: ExprDoc, _1: Sequence[Doc], /) -> IndexDoc: ... + def LambdaDoc(_0: Sequence[IdDoc], _1: ExprDoc, /) -> LambdaDoc: ... + def ListDoc(_0: Sequence[ExprDoc], /) -> ListDoc: ... + def LiteralDocBoolean(_0: bool, _1: AccessPath | None, /) -> LiteralDoc: ... + def LiteralDocFloat(_0: float, _1: AccessPath | None, /) -> LiteralDoc: ... + def LiteralDocInt(_0: int, _1: AccessPath | None, /) -> LiteralDoc: ... + def LiteralDocNone(_0: AccessPath | None, /) -> LiteralDoc: ... + def LiteralDocStr(_0: str, _1: AccessPath | None, /) -> LiteralDoc: ... + def OperationDoc(_0: int, _1: Sequence[ExprDoc], /) -> OperationDoc: ... + def ReprPrintRelax(_0: Object, _1: PrinterConfig, /) -> str: ... + def ReturnDoc(_0: ExprDoc, /) -> ReturnDoc: ... + def ScopeDoc(_0: ExprDoc | None, _1: ExprDoc, _2: Sequence[StmtDoc], /) -> ScopeDoc: ... + def SliceDoc(_0: ExprDoc | None, _1: ExprDoc | None, _2: ExprDoc | None, /) -> SliceDoc: ... + def StmtBlockDoc(_0: Sequence[StmtDoc], /) -> StmtBlockDoc: ... + def StmtDocSetComment(_0: StmtDoc, _1: str | None, /) -> None: ... + def TupleDoc(_0: Sequence[ExprDoc], /) -> TupleDoc: ... + def WhileDoc(_0: ExprDoc, _1: Sequence[StmtDoc], /) -> WhileDoc: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 62d8c563dd3f..544642b2752c 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -16,6 +16,18 @@ # under the License. """Doc types for TVMScript Unified Printer""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from tvm_ffi.reflection import AccessPath + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from enum import IntEnum, unique from typing import Dict, List, Optional, Sequence, Tuple, Union @@ -31,11 +43,22 @@ class Doc(Object): """Base class of all Docs""" + # tvm-ffi-stubgen(begin): object/script.printer.Doc + # fmt: off + source_paths: Sequence[AccessPath] + # fmt: on + # tvm-ffi-stubgen(end) + @register_object("script.printer.ExprDoc") class ExprDoc(Doc): """Base class of all expression Docs""" + # tvm-ffi-stubgen(begin): object/script.printer.ExprDoc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def attr(self, name: str) -> "AttrAccessDoc": """ Create a doc that represents attribute access on self. @@ -106,6 +129,12 @@ def __iter__(self): class StmtDoc(Doc): """Base class of statement doc""" + # tvm-ffi-stubgen(begin): object/script.printer.StmtDoc + # fmt: off + comment: str | None + # fmt: on + # tvm-ffi-stubgen(end) + @register_object("script.printer.StmtBlockDoc") class StmtBlockDoc(Doc): @@ -115,6 +144,12 @@ class StmtBlockDoc(Doc): list of StmtDoc. """ + # tvm-ffi-stubgen(begin): object/script.printer.StmtBlockDoc + # fmt: off + stmts: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + stmts: Sequence[StmtDoc] def __init__(self, stmts: List[StmtDoc]): @@ -125,6 +160,12 @@ def __init__(self, stmts: List[StmtDoc]): class LiteralDoc(ExprDoc): """Doc that represents literal value""" + # tvm-ffi-stubgen(begin): object/script.printer.LiteralDoc + # fmt: off + value: Any + # fmt: on + # tvm-ffi-stubgen(end) + value: Union[str, IntImm, FloatImm, None] def __init__( @@ -166,6 +207,12 @@ def __init__( class IdDoc(ExprDoc): """Doc that represents identifier""" + # tvm-ffi-stubgen(begin): object/script.printer.IdDoc + # fmt: off + name: str + # fmt: on + # tvm-ffi-stubgen(end) + name: str def __init__(self, name: str): @@ -176,6 +223,13 @@ def __init__(self, name: str): class AttrAccessDoc(ExprDoc): """Doc that represents attribute access on an expression""" + # tvm-ffi-stubgen(begin): object/script.printer.AttrAccessDoc + # fmt: off + value: ExprDoc + name: str + # fmt: on + # tvm-ffi-stubgen(end) + value: ExprDoc name: str @@ -187,6 +241,13 @@ def __init__(self, value: ExprDoc, name: str): class IndexDoc(ExprDoc): """Doc that represents index access on an expression""" + # tvm-ffi-stubgen(begin): object/script.printer.IndexDoc + # fmt: off + value: ExprDoc + indices: Sequence[Doc] + # fmt: on + # tvm-ffi-stubgen(end) + value: ExprDoc indices: Sequence[Union[ExprDoc, "SliceDoc"]] @@ -198,6 +259,15 @@ def __init__(self, value: ExprDoc, indices: List[Union[ExprDoc, "SliceDoc"]]): class CallDoc(ExprDoc): """Doc that represents function call""" + # tvm-ffi-stubgen(begin): object/script.printer.CallDoc + # fmt: off + callee: ExprDoc + args: Sequence[ExprDoc] + kwargs_keys: Sequence[str] + kwargs_values: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + callee: ExprDoc args: Sequence[ExprDoc] kwargs_keys: Sequence[str] @@ -271,6 +341,13 @@ class OperationDoc(ExprDoc): if-then-else expression). """ + # tvm-ffi-stubgen(begin): object/script.printer.OperationDoc + # fmt: off + kind: int + operands: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + kind: OperationKind operands: Sequence[ExprDoc] @@ -282,6 +359,13 @@ def __init__(self, kind: OperationKind, operands: List[ExprDoc]): class LambdaDoc(ExprDoc): """Doc that represents lambda function""" + # tvm-ffi-stubgen(begin): object/script.printer.LambdaDoc + # fmt: off + args: Sequence[IdDoc] + body: ExprDoc + # fmt: on + # tvm-ffi-stubgen(end) + args: Sequence[IdDoc] body: ExprDoc @@ -293,6 +377,12 @@ def __init__(self, args: List[IdDoc], body: ExprDoc): class TupleDoc(ExprDoc): """Doc that represents tuple literal""" + # tvm-ffi-stubgen(begin): object/script.printer.TupleDoc + # fmt: off + elements: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + elements: Sequence[ExprDoc] def __init__(self, elements: List[ExprDoc]): @@ -303,6 +393,12 @@ def __init__(self, elements: List[ExprDoc]): class ListDoc(ExprDoc): """Doc that represents list literal""" + # tvm-ffi-stubgen(begin): object/script.printer.ListDoc + # fmt: off + elements: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + elements: Sequence[ExprDoc] def __init__(self, elements: List[ExprDoc]): @@ -313,6 +409,13 @@ def __init__(self, elements: List[ExprDoc]): class DictDoc(ExprDoc): """Doc that represents dict literal""" + # tvm-ffi-stubgen(begin): object/script.printer.DictDoc + # fmt: off + keys: Sequence[ExprDoc] + values: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + keys: Sequence[ExprDoc] values: Sequence[ExprDoc] @@ -330,6 +433,14 @@ class SliceDoc(ExprDoc): This doc can only appear in `IndexDoc.indices`. """ + # tvm-ffi-stubgen(begin): object/script.printer.SliceDoc + # fmt: off + start: ExprDoc | None + stop: ExprDoc | None + step: ExprDoc | None + # fmt: on + # tvm-ffi-stubgen(end) + start: Optional[ExprDoc] stop: Optional[ExprDoc] step: Optional[ExprDoc] @@ -347,6 +458,14 @@ def __init__( class AssignDoc(StmtDoc): """Doc that represents assign statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.AssignDoc + # fmt: off + lhs: ExprDoc + rhs: ExprDoc | None + annotation: ExprDoc | None + # fmt: on + # tvm-ffi-stubgen(end) + lhs: ExprDoc rhs: Optional[ExprDoc] annotation: Optional[ExprDoc] @@ -364,6 +483,14 @@ def __init__(self, lhs: ExprDoc, rhs: Optional[ExprDoc], annotation: Optional[Ex class IfDoc(StmtDoc): """Doc that represent if-then-else statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.IfDoc + # fmt: off + predicate: ExprDoc + then_branch: Sequence[StmtDoc] + else_branch: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + predicate: ExprDoc then_branch: Sequence[StmtDoc] else_branch: Sequence[StmtDoc] @@ -381,6 +508,13 @@ def __init__(self, predicate: ExprDoc, then_branch: List[StmtDoc], else_branch: class WhileDoc(StmtDoc): """Doc that represents while statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.WhileDoc + # fmt: off + predicate: ExprDoc + body: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + predicate: ExprDoc body: Sequence[StmtDoc] @@ -392,6 +526,14 @@ def __init__(self, predicate: ExprDoc, body: List[StmtDoc]): class ForDoc(StmtDoc): """Doc that represents for statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.ForDoc + # fmt: off + lhs: ExprDoc + rhs: ExprDoc + body: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + lhs: ExprDoc rhs: ExprDoc body: Sequence[StmtDoc] @@ -411,6 +553,14 @@ class ScopeDoc(StmtDoc):