From bf7c1b9b7e60f530457094725f4e0eb64738d47d Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 24 Dec 2025 10:12:58 -0800 Subject: [PATCH] Update TVM-FFI to v0.1.7-rc0 --- 3rdparty/tvm-ffi | 2 +- python/tvm/arith/_ffi_api.py | 68 ++- python/tvm/arith/analyzer.py | 16 + python/tvm/arith/int_set.py | 27 + python/tvm/arith/int_solver.py | 38 ++ python/tvm/arith/iter_affine_map.py | 47 ++ python/tvm/contrib/cutlass/_ffi_api.py | 15 +- python/tvm/contrib/msc/core/_ffi_api.py | 14 +- python/tvm/contrib/msc/core/ir/graph.py | 26 + .../msc/framework/tensorflow/_ffi_api.py | 14 +- .../msc/framework/tensorrt/_ffi_api.py | 14 +- .../contrib/msc/framework/torch/_ffi_api.py | 14 +- .../tvm/contrib/msc/framework/tvm/_ffi_api.py | 14 +- python/tvm/contrib/msc/plugin/_ffi_api.py | 14 +- python/tvm/contrib/msc/plugin/op/_ffi_api.py | 14 +- python/tvm/driver/_ffi_api.py | 15 +- python/tvm/ir/_ffi_api.py | 93 ++- python/tvm/ir/attrs.py | 22 + python/tvm/ir/base.py | 46 ++ python/tvm/ir/diagnostics/__init__.py | 33 ++ python/tvm/ir/diagnostics/_ffi_api.py | 31 +- python/tvm/ir/expr.py | 43 ++ python/tvm/ir/function.py | 16 + python/tvm/ir/global_info.py | 26 + python/tvm/ir/instrument.py | 14 + python/tvm/ir/module.py | 21 + python/tvm/ir/op.py | 23 + python/tvm/ir/supply.py | 12 + python/tvm/ir/transform.py | 50 ++ python/tvm/ir/type.py | 52 ++ python/tvm/ir/type_relation.py | 8 + python/tvm/meta_schedule/_ffi_api.py | 177 +++++- python/tvm/meta_schedule/arg_info.py | 22 + python/tvm/meta_schedule/builder/builder.py | 40 ++ .../meta_schedule/cost_model/cost_model.py | 12 + python/tvm/meta_schedule/database/database.py | 40 ++ .../meta_schedule/database/json_database.py | 9 + .../meta_schedule/database/memory_database.py | 18 + .../database/ordered_union_database.py | 17 + .../database/schedule_fn_database.py | 17 + .../meta_schedule/database/union_database.py | 17 + python/tvm/meta_schedule/extracted_task.py | 22 + .../feature_extractor/feature_extractor.py | 12 + 
.../feature_extractor/per_store_feature.py | 12 + .../measure_callback/add_to_database.py | 6 + .../measure_callback/measure_callback.py | 12 + .../measure_callback/remove_build_artifact.py | 6 + .../measure_callback/update_cost_model.py | 6 + .../mutator/mutate_compute_location.py | 7 + .../meta_schedule/mutator/mutate_parallel.py | 8 + .../mutator/mutate_thread_binding.py | 7 + .../meta_schedule/mutator/mutate_tile_size.py | 7 + .../meta_schedule/mutator/mutate_unroll.py | 7 + python/tvm/meta_schedule/mutator/mutator.py | 7 + .../disallow_async_strided_mem_copy.py | 7 + .../postproc/disallow_dynamic_loop.py | 7 + python/tvm/meta_schedule/postproc/postproc.py | 12 + .../postproc/rewrite_cooperative_fetch.py | 7 + .../meta_schedule/postproc/rewrite_layout.py | 7 + .../rewrite_parallel_vectorize_unroll.py | 7 + .../postproc/rewrite_reduction_block.py | 7 + .../postproc/rewrite_tensorize.py | 7 + .../postproc/rewrite_unbound_block.py | 7 + .../meta_schedule/postproc/verify_gpu_code.py | 7 + .../postproc/verify_vtcm_limit.py | 7 + python/tvm/meta_schedule/profiler.py | 7 + python/tvm/meta_schedule/runner/runner.py | 42 ++ .../schedule_rule/add_rfactor.py | 9 + .../schedule_rule/apply_custom_rule.py | 16 + .../meta_schedule/schedule_rule/auto_bind.py | 7 + .../schedule_rule/auto_inline.py | 28 + .../schedule_rule/cross_thread_reduction.py | 19 + .../schedule_rule/multi_level_tiling.py | 33 ++ .../parallel_vectorize_unroll.py | 20 + .../schedule_rule/random_compute_location.py | 7 + .../schedule_rule/schedule_rule.py | 12 + .../search_strategy/evolutionary_search.py | 16 + .../search_strategy/replay_func.py | 7 + .../search_strategy/replay_trace.py | 8 + .../search_strategy/search_strategy.py | 29 + .../space_generator/post_order_apply.py | 7 + .../space_generator/schedule_fn.py | 7 + .../space_generator/space_generator.py | 25 + .../space_generator/space_generator_union.py | 17 + .../task_scheduler/gradient_based.py | 9 + .../task_scheduler/round_robin.py | 8 + 
.../task_scheduler/task_scheduler.py | 40 ++ python/tvm/meta_schedule/tune_context.py | 24 + python/tvm/relax/_ffi_api.py | 141 ++++- python/tvm/relax/analysis/_ffi_api.py | 51 +- python/tvm/relax/backend/_ffi_api.py | 28 +- .../backend/adreno/transform/_ffi_api.py | 23 +- python/tvm/relax/binding_rewrite.py | 17 + python/tvm/relax/block_builder.py | 7 + python/tvm/relax/distributed/_ffi_api.py | 30 +- python/tvm/relax/distributed/global_info.py | 20 + python/tvm/relax/distributed/struct_info.py | 33 ++ .../relax/distributed/transform/_ffi_api.py | 24 +- python/tvm/relax/dpl/rewrite.py | 36 ++ python/tvm/relax/exec_builder.py | 7 + python/tvm/relax/expr.py | 161 +++++ python/tvm/relax/expr_functor.py | 21 + python/tvm/relax/op/_ffi_api.py | 168 +++++- python/tvm/relax/op/builtin/_ffi_api.py | 23 +- python/tvm/relax/op/ccl/_ffi_api.py | 24 +- python/tvm/relax/op/distributed/_ffi_api.py | 28 +- python/tvm/relax/op/grad/_ffi_api.py | 28 +- python/tvm/relax/op/image/_ffi_api.py | 24 +- python/tvm/relax/op/memory/_ffi_api.py | 27 +- python/tvm/relax/op/nn/_ffi_api.py | 58 +- python/tvm/relax/op/op_attrs.py | 551 ++++++++++++++++++ python/tvm/relax/op/vision/_ffi_api.py | 21 +- python/tvm/relax/op/vm/_ffi_api.py | 25 +- python/tvm/relax/struct_info.py | 55 ++ python/tvm/relax/testing/transform.py | 18 + python/tvm/relax/training/_ffi_api.py | 22 +- python/tvm/relax/transform/_ffi_api.py | 94 ++- python/tvm/relax/transform/transform.py | 47 ++ python/tvm/relax/ty.py | 33 ++ python/tvm/rpc/_ffi_api.py | 30 +- python/tvm/runtime/_ffi_api.py | 55 +- python/tvm/runtime/_tensor.py | 7 + python/tvm/runtime/disco/_ffi_api.py | 61 +- python/tvm/runtime/disco/session.py | 27 + python/tvm/runtime/profiling/__init__.py | 26 + python/tvm/runtime/profiling/_ffi_api.py | 34 +- python/tvm/runtime/script_printer.py | 37 ++ python/tvm/script/_ffi_api.py | 22 +- python/tvm/script/ir_builder/_ffi_api.py | 32 +- python/tvm/script/ir_builder/ir/_ffi_api.py | 32 +- 
.../tvm/script/ir_builder/relax/_ffi_api.py | 40 +- .../ir_builder/relax/distributed/_ffi_api.py | 26 +- python/tvm/script/ir_builder/tir/_ffi_api.py | 217 ++++++- python/tvm/script/printer/_ffi_api.py | 59 +- python/tvm/script/printer/doc.py | 197 +++++++ python/tvm/target/_ffi_api.py | 40 +- python/tvm/target/target.py | 19 + python/tvm/target/virtual_device.py | 9 + python/tvm/te/_ffi_api.py | 36 +- python/tvm/te/tensor.py | 71 +++ python/tvm/testing/_ffi_api.py | 92 ++- python/tvm/testing/attrs.py | 20 + python/tvm/tir/_ffi_api.py | 179 +++++- python/tvm/tir/analysis/_ffi_api.py | 41 +- python/tvm/tir/block_dependence_info.py | 7 + python/tvm/tir/block_scope.py | 21 + python/tvm/tir/buffer.py | 33 ++ python/tvm/tir/data_layout.py | 30 + python/tvm/tir/expr.py | 263 +++++++++ python/tvm/tir/function.py | 37 ++ python/tvm/tir/functor.py | 12 + python/tvm/tir/schedule/_ffi_api.py | 125 +++- python/tvm/tir/schedule/analysis.py | 30 + python/tvm/tir/schedule/state.py | 18 + python/tvm/tir/stmt.py | 163 ++++++ python/tvm/tir/transform/_ffi_api.py | 106 +++- python/tvm/tir/transform/function_pass.py | 16 + python/tvm/tir/transform/transform.py | 62 ++ 158 files changed, 5885 insertions(+), 101 deletions(-) diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index ae346ec92a3c..89be2d3ff90d 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit ae346ec92a3c386f1376064ae086aae72947c329 +Subproject commit 89be2d3ff90d021355975a52d72e5eefc02bc335 diff --git a/python/tvm/arith/_ffi_api.py b/python/tvm/arith/_ffi_api.py index 519423aa4e1f..12ddd90af574 100644 --- a/python/tvm/arith/_ffi_api.py +++ b/python/tvm/arith/_ffi_api.py @@ -15,7 +15,71 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.arith""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from arith import ConstIntBound, IntConstraints, IntConstraintsTransform, IntGroupBounds, IntervalSet, IterMark, IterSplitExpr, IterSumExpr, ModularSet + from collections.abc import Mapping, Sequence + from ir import IntImm, IntSet, PrimExpr, Range + from tir import Buffer, PrimFunc, Stmt, Var + from tvm_ffi import Object + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("arith", __name__) + +# tvm-ffi-stubgen(begin): global/arith +# fmt: off +_FFI_INIT_FUNC("arith", __name__) +if TYPE_CHECKING: + def ConstIntBound(_0: int, _1: int, /) -> ConstIntBound: ... + def CreateAnalyzer(*args: Any) -> Any: ... + def DeduceBound(_0: PrimExpr, _1: PrimExpr, _2: Mapping[Var, IntSet], _3: Mapping[Var, IntSet], /) -> IntSet: ... + def DetectClipBound(_0: PrimExpr, _1: Sequence[Var], /) -> Sequence[PrimExpr]: ... + def DetectCommonSubExpr(_0: PrimExpr, _1: int, /) -> Mapping[PrimExpr, IntImm]: ... + def DetectIterMap(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: PrimExpr, _3: int, _4: bool, /) -> Object: ... + def DetectLinearEquation(_0: PrimExpr, _1: Sequence[Var], /) -> Sequence[PrimExpr]: ... + def DomainTouched(_0: Stmt, _1: Buffer, _2: bool, _3: bool, /) -> Sequence[Range]: ... + def DomainTouchedAccessMap(_0: PrimFunc, /) -> Mapping[Buffer, Sequence[Object]]: ... + def EstimateRegionLowerBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ... + def EstimateRegionStrictBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ... + def EstimateRegionUpperBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ... 
+ def IntConstraints(_0: Sequence[Var], _1: Mapping[Var, Range], _2: Sequence[PrimExpr], /) -> IntConstraints: ... + def IntConstraintsTransform(_0: IntConstraints, _1: IntConstraints, _2: Mapping[Var, PrimExpr], _3: Mapping[Var, PrimExpr], /) -> IntConstraintsTransform: ... + def IntGroupBounds(_0: PrimExpr, _1: Sequence[PrimExpr], _2: Sequence[PrimExpr], _3: Sequence[PrimExpr], /) -> IntGroupBounds: ... + def IntGroupBounds_FindBestRange(*args: Any) -> Any: ... + def IntGroupBounds_from_range(_0: Range, /) -> IntGroupBounds: ... + def IntSetIsEverything(_0: IntSet, /) -> bool: ... + def IntSetIsNothing(_0: IntSet, /) -> bool: ... + def IntervalSet(_0: PrimExpr, _1: PrimExpr, /) -> IntervalSet: ... + def IntervalSetGetMax(_0: IntSet, /) -> PrimExpr: ... + def IntervalSetGetMin(_0: IntSet, /) -> PrimExpr: ... + def InverseAffineIterMap(_0: Sequence[IterSumExpr], _1: Sequence[PrimExpr], /) -> Mapping[Var, PrimExpr]: ... + def IterMapSimplify(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: PrimExpr, _3: int, _4: bool, /) -> Sequence[PrimExpr]: ... + def IterMark(_0: PrimExpr, _1: PrimExpr, /) -> IterMark: ... + def IterSplitExpr(_0: IterMark, _1: PrimExpr, _2: PrimExpr, _3: PrimExpr, /) -> IterSplitExpr: ... + def IterSumExpr(_0: Sequence[IterSplitExpr], _1: PrimExpr, /) -> IterSumExpr: ... + def ModularSet(_0: int, _1: int, /) -> ModularSet: ... + def NarrowPredicateExpression(_0: PrimExpr, _1: Mapping[Var, Range], /) -> PrimExpr: ... + def NegInf() -> PrimExpr: ... + def NormalizeIterMapToExpr(_0: PrimExpr, /) -> PrimExpr: ... + def NormalizeToIterSum(_0: PrimExpr, _1: Mapping[Var, Range], /) -> IterSumExpr: ... + def PosInf() -> PrimExpr: ... + def PresburgerSet(_0: PrimExpr, /) -> IntSet: ... + def SolveInequalitiesAsCondition(*args: Any) -> Any: ... + def SolveInequalitiesDeskewRange(*args: Any) -> Any: ... + def SolveInequalitiesToRange(*args: Any) -> Any: ... + def SolveLinearEquations(*args: Any) -> Any: ... 
+ def SubspaceDivide(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: Sequence[Var], _3: PrimExpr, _4: int, _5: bool, /) -> Sequence[Sequence[IterMark]]: ... + def UnionLowerBound(_0: Sequence[IntSet], /) -> IntSet: ... + def intset_interval(_0: PrimExpr, _1: PrimExpr, /) -> IntSet: ... + def intset_single_point(_0: PrimExpr, /) -> IntSet: ... + def intset_vector(_0: PrimExpr, /) -> IntSet: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index c5c8fc067cc8..ec57b263c598 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name """Arithmetic data structure and utility""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import enum from typing import Union @@ -51,6 +53,13 @@ class Extension(enum.Flag): class ModularSet(Object): """Represent range of (coeff * x + base) for x in Z""" + # tvm-ffi-stubgen(begin): object/arith.ModularSet + # fmt: off + coeff: int + base: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, coeff, base): self.__init_handle_by_constructor__(_ffi_api.ModularSet, coeff, base) @@ -68,6 +77,13 @@ class ConstIntBound(Object): The maximum value of the bound. """ + # tvm-ffi-stubgen(begin): object/arith.ConstIntBound + # fmt: off + min_value: int + max_value: int + # fmt: on + # tvm-ffi-stubgen(end) + POS_INF = (1 << 63) - 1 NEG_INF = -POS_INF diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index fc6c20dec1ce..2b865801d509 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. 
"""Integer set.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import PrimExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.runtime import Object from . import _ffi_api @@ -24,6 +34,11 @@ class IntSet(Object): """Represent a set of integer in one dimension.""" + # tvm-ffi-stubgen(begin): object/ir.IntSet + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def is_nothing(self): """Whether the set represent nothing""" return _ffi_api.IntSetIsNothing(self) @@ -78,6 +93,13 @@ class IntervalSet(IntSet): The maximum value in the interval. """ + # tvm-ffi-stubgen(begin): object/arith.IntervalSet + # fmt: off + min_value: PrimExpr + max_value: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, min_value, max_value): self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value) @@ -86,6 +108,11 @@ def __init__(self, min_value, max_value): class PresburgerSet(IntSet): """Represent of Presburger Set""" + # tvm-ffi-stubgen(begin): object/arith.PresburgerSet + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): self.__init_handle_by_constructor__(_ffi_api.PresburgerSet) diff --git a/python/tvm/arith/int_solver.py b/python/tvm/arith/int_solver.py index 72e4c46896ff..c7f7cafb02e2 100644 --- a/python/tvm/arith/int_solver.py +++ b/python/tvm/arith/int_solver.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """integer constraints data structures and solvers""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import PrimExpr, Range + from tir import Var +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.runtime import Object from . 
import _ffi_api @@ -40,6 +52,15 @@ class IntGroupBounds(Object): the upper bounds (include) """ + # tvm-ffi-stubgen(begin): object/arith.IntGroupBounds + # fmt: off + coef: PrimExpr + lower: Sequence[PrimExpr] + equal: Sequence[PrimExpr] + upper: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, coef, lower, equal, upper): self.__init_handle_by_constructor__(_ffi_api.IntGroupBounds, coef, lower, equal, upper) @@ -81,6 +102,14 @@ class IntConstraints(Object): The relations between the variables (either equations or inequalities) """ + # tvm-ffi-stubgen(begin): object/arith.IntConstraints + # fmt: off + variables: Sequence[Var] + ranges: Mapping[Var, Range] + relations: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, variables, ranges, relations): self.__init_handle_by_constructor__(_ffi_api.IntConstraints, variables, ranges, relations) @@ -113,6 +142,15 @@ class IntConstraintsTransform(Object): e.g., {m -> a, n -> -b} """ + # tvm-ffi-stubgen(begin): object/arith.IntConstraintsTransform + # fmt: off + src: IntConstraints + dst: IntConstraints + src_to_dst: Mapping[Var, PrimExpr] + dst_to_src: Mapping[Var, PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, src, dst, src_to_dst, dst_to_src): self.__init_handle_by_constructor__( _ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 69ad3022fb4a..4dd5a658cbb5 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. 
"""Iterator (quasi)affine mapping patterns.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import PrimExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from enum import IntEnum import tvm_ffi from tvm.runtime import Object @@ -26,6 +37,11 @@ class IterMapExpr(PrimExpr): """Base class of all IterMap expressions.""" + # tvm-ffi-stubgen(begin): object/arith.IterMapExpr + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("arith.IterMark") class IterMark(Object): @@ -40,6 +56,13 @@ class IterMark(Object): The extent of the iterator. """ + # tvm-ffi-stubgen(begin): object/arith.IterMark + # fmt: off + source: PrimExpr + extent: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, source, extent): self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent) @@ -65,6 +88,15 @@ class IterSplitExpr(IterMapExpr): Additional scale to the split. """ + # tvm-ffi-stubgen(begin): object/arith.IterSplitExpr + # fmt: off + source: IterMark + lower_factor: PrimExpr + extent: PrimExpr + scale: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, source, lower_factor, extent, scale): self.__init_handle_by_constructor__( _ffi_api.IterSplitExpr, source, lower_factor, extent, scale @@ -86,6 +118,13 @@ class IterSumExpr(IterMapExpr): The base offset. 
""" + # tvm-ffi-stubgen(begin): object/arith.IterSumExpr + # fmt: off + args: Sequence[IterSplitExpr] + base: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, args, base): self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base) @@ -94,6 +133,14 @@ def __init__(self, args, base): class IterMapResult(Object): """Result of iter map detection.""" + # tvm-ffi-stubgen(begin): object/arith.IterMapResult + # fmt: off + indices: Sequence[IterSumExpr] + errors: Sequence[str] + padding_predicate: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + class IterMapLevel(IntEnum): """Possible kinds of iter mapping check level.""" diff --git a/python/tvm/contrib/cutlass/_ffi_api.py b/python/tvm/contrib/cutlass/_ffi_api.py index d57825835b6b..d6824bf02101 100644 --- a/python/tvm/contrib/cutlass/_ffi_api.py +++ b/python/tvm/contrib/cutlass/_ffi_api.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """FFI API for CUTLASS BYOC.""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("contrib.cutlass", __name__) + +# tvm-ffi-stubgen(begin): global/contrib.cutlass +_FFI_INIT_FUNC("contrib.cutlass", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/core/_ffi_api.py b/python/tvm/contrib/msc/core/_ffi_api.py index ff027a0dec8e..e2446770e2db 100644 --- a/python/tvm/contrib/msc/core/_ffi_api.py +++ b/python/tvm/contrib/msc/core/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. 
"""tvm.contrib.msc.core._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.core", __name__) +# tvm-ffi-stubgen(begin): global/msc.core +_FFI_INIT_FUNC("msc.core", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 6b40be4bf9de..c0cb4adf5eb6 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -16,6 +16,8 @@ # under the License. """tvm.contrib.msc.core.ir.graph""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Dict, Tuple, List, Optional, Union, Iterable, Any import numpy as np import tvm_ffi @@ -46,6 +48,9 @@ class MSCTensor(Object): The prims of the tensor. """ + # tvm-ffi-stubgen(begin): object/msc.core.MSCTensor + # tvm-ffi-stubgen(end) + def __init__( self, name: str, @@ -199,6 +204,9 @@ def ndim(self) -> int: class BaseJoint(Object): """Base class of all MSC Nodes.""" + # tvm-ffi-stubgen(begin): object/msc.core.BaseJoint + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("msc.core.MSCJoint") class MSCJoint(BaseJoint): @@ -224,6 +232,9 @@ class MSCJoint(BaseJoint): The weights of the node. """ + # tvm-ffi-stubgen(begin): object/msc.core.MSCJoint + # tvm-ffi-stubgen(end) + def __init__( self, index: int, @@ -443,6 +454,9 @@ class MSCPrim(BaseJoint): The parents of the prim. """ + # tvm-ffi-stubgen(begin): object/msc.core.MSCPrim + # tvm-ffi-stubgen(end) + def __init__( self, index: int, name: str, optype: str, attrs: Dict[str, str], parents: List[BaseJoint] ): @@ -477,6 +491,9 @@ class WeightJoint(BaseJoint): The friends of the node. 
""" + # tvm-ffi-stubgen(begin): object/msc.core.WeightJoint + # tvm-ffi-stubgen(end) + def __init__( self, index: int, @@ -567,6 +584,9 @@ def has_attr(self, key: str) -> bool: class BaseGraph(Object): """Base class of all MSC Graphs.""" + # tvm-ffi-stubgen(begin): object/msc.core.BaseGraph + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("msc.core.MSCGraph") class MSCGraph(BaseGraph): @@ -584,6 +604,9 @@ class MSCGraph(BaseGraph): The output names of the graph. """ + # tvm-ffi-stubgen(begin): object/msc.core.MSCGraph + # tvm-ffi-stubgen(end) + def __init__( self, name: str, @@ -969,6 +992,9 @@ class WeightGraph(BaseGraph): The nodes of the graph. """ + # tvm-ffi-stubgen(begin): object/msc.core.WeightGraph + # tvm-ffi-stubgen(end) + def __init__( self, name: str, diff --git a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py index f7cd2ea43e3e..a51b18e030f8 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorflow/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. 
"""tvm.contrib.msc.framework.tensorflow._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.framework.tensorflow", __name__) +# tvm-ffi-stubgen(begin): global/msc.framework.tensorflow +_FFI_INIT_FUNC("msc.framework.tensorflow", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py index a09ab875fbed..146282ec8f90 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tensorrt/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.framework.tensorrt._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.framework.tensorrt", __name__) +# tvm-ffi-stubgen(begin): global/msc.framework.tensorrt +_FFI_INIT_FUNC("msc.framework.tensorrt", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/framework/torch/_ffi_api.py b/python/tvm/contrib/msc/framework/torch/_ffi_api.py index d1f27a53bdcf..a3e88799a8fb 100644 --- a/python/tvm/contrib/msc/framework/torch/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/torch/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. 
"""tvm.contrib.msc.framework.torch._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.framework.torch", __name__) +# tvm-ffi-stubgen(begin): global/msc.framework.torch +_FFI_INIT_FUNC("msc.framework.torch", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py index c9f63e21eaef..239f553fe977 100644 --- a/python/tvm/contrib/msc/framework/tvm/_ffi_api.py +++ b/python/tvm/contrib/msc/framework/tvm/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.framework.tvm._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.framework.tvm", __name__) +# tvm-ffi-stubgen(begin): global/msc.framework.tvm +_FFI_INIT_FUNC("msc.framework.tvm", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/plugin/_ffi_api.py b/python/tvm/contrib/msc/plugin/_ffi_api.py index 88f9204f3a02..c1b1f1b0e6d2 100644 --- a/python/tvm/contrib/msc/plugin/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. 
"""tvm.contrib.msc.plugin._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.plugin", __name__) +# tvm-ffi-stubgen(begin): global/msc.plugin +_FFI_INIT_FUNC("msc.plugin", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/contrib/msc/plugin/op/_ffi_api.py b/python/tvm/contrib/msc/plugin/op/_ffi_api.py index 8ca5071cdaf6..d01bdaab758a 100644 --- a/python/tvm/contrib/msc/plugin/op/_ffi_api.py +++ b/python/tvm/contrib/msc/plugin/op/_ffi_api.py @@ -15,7 +15,17 @@ # specific language governing permissions and limitations # under the License. """tvm.contrib.msc.plugin.op._ffi_api""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("msc.plugin.op", __name__) +# tvm-ffi-stubgen(begin): global/msc.plugin.op +_FFI_INIT_FUNC("msc.plugin.op", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/driver/_ffi_api.py b/python/tvm/driver/_ffi_api.py index e56426fd5182..62e3efce6d04 100644 --- a/python/tvm/driver/_ffi_api.py +++ b/python/tvm/driver/_ffi_api.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.driver""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("driver", __name__) + +# tvm-ffi-stubgen(begin): global/driver +_FFI_INIT_FUNC("driver", __name__) +# tvm-ffi-stubgen(end) diff --git a/python/tvm/ir/_ffi_api.py b/python/tvm/ir/_ffi_api.py index 798e69fca507..4f80a5d4da78 100644 --- a/python/tvm/ir/_ffi_api.py +++ b/python/tvm/ir/_ffi_api.py @@ -15,7 +15,96 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.ir""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import BaseFunc, DictAttrs, DummyGlobalInfo, EnvFunc, FloatImm, FuncType, GlobalInfo, GlobalVar, GlobalVarSupply, IRModule, IntImm, NameSupply, Op, PointerType, PrimExpr, PrimType, Range, RelaxExpr, SequentialSpan, SourceName, Span, StructInfo, TensorMapType, TupleType, Type, VDevice + from target import Target + from tvm_ffi import Object, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("ir", __name__) + +# tvm-ffi-stubgen(begin): global/ir +# fmt: off +_FFI_INIT_FUNC("ir", __name__) +if TYPE_CHECKING: + def BaseFuncCopy(_0: BaseFunc, /) -> BaseFunc: ... + def BaseFuncWithAttr(_0: ObjectRValueRef[BaseFunc], _1: str, _2: Any, /) -> BaseFunc: ... + def BaseFuncWithAttrs(_0: ObjectRValueRef[BaseFunc], _1: Mapping[str, Any], /) -> BaseFunc: ... + def BaseFuncWithoutAttr(_0: ObjectRValueRef[BaseFunc], _1: str, /) -> BaseFunc: ... + def BaseFunc_Attrs(_0: BaseFunc, /) -> DictAttrs: ... 
+ def DebugPrint(_0: Object, /) -> str: ... + def DictAttrsGetDict(_0: DictAttrs, /) -> Mapping[str, Any]: ... + def DummyGlobalInfo() -> DummyGlobalInfo: ... + def EnvFuncCall(*args: Any) -> Any: ... + def EnvFuncGet(_0: str, /) -> EnvFunc: ... + def EnvFuncGetFunction(_0: EnvFunc, /) -> Callable[..., Any]: ... + def ExprStructInfo(_0: RelaxExpr, /) -> StructInfo: ... + def FloatImm(_0: dtype, _1: float, _2: Span, /) -> FloatImm: ... + def FuncType(_0: Sequence[Type], _1: Type, /) -> FuncType: ... + def GetOp(_0: str, /) -> Op: ... + def GlobalVar(_0: str, /) -> GlobalVar: ... + def GlobalVarSupply_FreshGlobal(_0: GlobalVarSupply, _1: str, _2: bool, /) -> GlobalVar: ... + def GlobalVarSupply_IRModule(_0: IRModule, /) -> GlobalVarSupply: ... + def GlobalVarSupply_IRModules(_0: Sequence[IRModule], /) -> GlobalVarSupply: ... + def GlobalVarSupply_NameSupply(_0: NameSupply, /) -> GlobalVarSupply: ... + def GlobalVarSupply_ReserveGlobalVar(_0: GlobalVarSupply, _1: GlobalVar, _2: bool, /) -> None: ... + def GlobalVarSupply_UniqueGlobalFor(_0: GlobalVarSupply, _1: str, _2: bool, /) -> GlobalVar: ... + def IRModule(_0: Mapping[GlobalVar, BaseFunc], _1: Object, _2: Mapping[str, Sequence[GlobalInfo]], /) -> IRModule: ... + def IntImm(_0: dtype, _1: int, _2: Span, /) -> IntImm: ... + def ListOpNames() -> Sequence[str]: ... + def Module_Add(_0: IRModule, _1: GlobalVar, _2: Object, _3: bool, /) -> IRModule: ... + def Module_Clone(_0: IRModule, /) -> IRModule: ... + def Module_ContainGlobalVar(_0: IRModule, _1: str, /) -> bool: ... + def Module_Contains(_0: IRModule, _1: str | GlobalVar, /) -> bool: ... + def Module_FromExpr(_0: RelaxExpr, _1: Mapping[GlobalVar, BaseFunc], /) -> IRModule: ... + def Module_GetAttr(_0: IRModule, _1: str, /) -> Object: ... + def Module_GetAttrs(_0: IRModule, /) -> Object: ... + def Module_GetGlobalVar(_0: IRModule, _1: str, /) -> GlobalVar: ... + def Module_GetGlobalVars(_0: IRModule, /) -> Sequence[GlobalVar]: ... 
+ def Module_Lookup(_0: IRModule, _1: GlobalVar, /) -> BaseFunc: ... + def Module_Lookup_str(_0: IRModule, _1: str, /) -> BaseFunc: ... + def Module_Remove(_0: IRModule, _1: str | GlobalVar, /) -> IRModule: ... + def Module_ReplaceGlobalVars(_0: IRModule, _1: Mapping[str | GlobalVar, str | GlobalVar], /) -> IRModule: ... + def Module_Update(_0: IRModule, _1: IRModule, /) -> None: ... + def Module_UpdateFunction(_0: IRModule, _1: GlobalVar, _2: BaseFunc, /) -> None: ... + def Module_UpdateGlobalInfo(_0: IRModule, _1: str, _2: Sequence[GlobalInfo], /) -> None: ... + def Module_WithAttr(_0: ObjectRValueRef[IRModule], _1: str, _2: Any, /) -> IRModule: ... + def Module_WithAttrs(_0: ObjectRValueRef[IRModule], _1: Mapping[str, Any], /) -> IRModule: ... + def Module_WithoutAttr(_0: ObjectRValueRef[IRModule], _1: str, /) -> IRModule: ... + def NameSupply(_0: str, /) -> NameSupply: ... + def NameSupply_ContainsName(_0: NameSupply, _1: str, _2: bool, /) -> bool: ... + def NameSupply_FreshName(_0: NameSupply, _1: str, _2: bool, _3: bool, /) -> str: ... + def NameSupply_ReserveName(_0: NameSupply, _1: str, _2: bool, /) -> str: ... + def OpAddArgument(_0: Op, _1: str, _2: str, _3: str, /) -> None: ... + def OpGetAttr(_0: Op, _1: str, /) -> Any: ... + def OpHasAttr(_0: Op, _1: str, /) -> bool: ... + def OpResetAttr(_0: Op, _1: str, /) -> None: ... + def OpSetAttr(_0: Op, _1: str, _2: Any, _3: int, /) -> None: ... + def OpSetAttrsTypeKey(_0: Op, _1: str, /) -> None: ... + def OpSetNumInputs(_0: Op, _1: int, /) -> None: ... + def OpSetSupportLevel(_0: Op, _1: int, /) -> None: ... + def PointerType(_0: Type, _1: str, /) -> PointerType: ... + def PrimType(_0: dtype, /) -> PrimType: ... + def Range(_0: PrimExpr, _1: PrimExpr | None, _2: Span, /) -> Range: ... + def Range_from_min_extent(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Range: ... + def RegisterOp(_0: str, _1: str, /) -> None: ... + def RegisterOpAttr(_0: str, _1: str, _2: Any, _3: int, /) -> None: ... 
+ def RegisterOpLowerIntrinsic(_0: str, _1: Callable[..., Any], _2: str, _3: int, /) -> None: ... + def SequentialSpan(_0: Sequence[Span], /) -> SequentialSpan: ... + def SourceName(_0: str, /) -> SourceName: ... + def Span(_0: SourceName, _1: int, _2: int, _3: int, _4: int, /) -> Span: ... + def TensorMapType(_0: Span, /) -> TensorMapType: ... + def TupleType(_0: Sequence[Type], /) -> TupleType: ... + def VDevice(_0: Target, _1: int, _2: str, /) -> VDevice: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index fb408cdb8c70..c199c6ea8282 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """TVM Attribute module, which is mainly used for defining attributes of operators.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.runtime import Object @@ -30,6 +41,11 @@ class Attrs(Object): Attrs is passed as the first argument to these functions. """ + # tvm-ffi-stubgen(begin): object/ir.Attrs + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def get_int_tuple(self, key): """Get a python int tuple of a key @@ -77,6 +93,12 @@ def __getitem__(self, item): class DictAttrs(Attrs): """Dictionary attributes.""" + # tvm-ffi-stubgen(begin): object/ir.DictAttrs + # fmt: off + __dict__: Mapping[str, Any] + # fmt: on + # tvm-ffi-stubgen(end) + def _dict(self): """Get internal dict""" return _ffi_api.DictAttrsGetDict(self) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 651ab392039c..3f4aca0d00fe 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. 
"""Common base structures.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import Source + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm.error from tvm_ffi import get_global_func, register_object from tvm.runtime import Object, _ffi_node_api @@ -28,6 +40,11 @@ class Node(Object): @register_object("ir.SourceMap") class SourceMap(Object): + # tvm-ffi-stubgen(begin): object/ir.SourceMap + # fmt: off + source_map: Mapping[SourceName, Source] + # fmt: on + # tvm-ffi-stubgen(end) def add(self, name, content): return get_global_func("SourceMapAdd")(self, name, content) @@ -42,6 +59,12 @@ class SourceName(Object): The name of the source. """ + # tvm-ffi-stubgen(begin): object/ir.SourceName + # fmt: off + name: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, name): self.__init_handle_by_constructor__(_ffi_api.SourceName, name) # type: ignore # pylint: disable=no-member @@ -62,6 +85,16 @@ class Span(Object): The column offset of the location. """ + # tvm-ffi-stubgen(begin): object/ir.Span + # fmt: off + source_name: SourceName + line: int + column: int + end_line: int + end_column: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, source_name, line, end_line, column, end_column): self.__init_handle_by_constructor__( _ffi_api.Span, source_name, line, end_line, column, end_column # type: ignore # pylint: disable=no-member @@ -81,6 +114,12 @@ class SequentialSpan(Object): The array of spans. """ + # tvm-ffi-stubgen(begin): object/ir.SequentialSpan + # fmt: off + spans: Sequence[Span] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, spans): self.__init_handle_by_constructor__(_ffi_api.SequentialSpan, spans) @@ -92,6 +131,13 @@ class EnvFunc(Object): This is a global function object that can be serialized by its name. 
""" + # tvm-ffi-stubgen(begin): object/ir.EnvFunc + # fmt: off + name: str + func: Callable[..., Any] + # fmt: on + # tvm-ffi-stubgen(end) + def __call__(self, *args): return _ffi_api.EnvFuncCall(self, *args) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py index 4a521dfa587e..e9b7ebaf4361 100644 --- a/python/tvm/ir/diagnostics/__init__.py +++ b/python/tvm/ir/diagnostics/__init__.py @@ -21,6 +21,18 @@ three key abstractions: a Diagnostic, the DiagnosticContext, and the DiagnosticRenderer. """ +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IRModule, Span + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import enum import tvm_ffi from . import _ffi_api @@ -73,6 +85,14 @@ class DiagnosticLevel(enum.IntEnum): class Diagnostic(Object): """A single diagnostic object from TVM.""" + # tvm-ffi-stubgen(begin): object/Diagnostic + # fmt: off + level: int + span: Span + message: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, level, span, message): self.__init_handle_by_constructor__(_ffi_api.Diagnostic, level, span, message) @@ -84,6 +104,12 @@ class DiagnosticRenderer(Object): form of the diagnostics for either human or computer consumption. """ + # tvm-ffi-stubgen(begin): object/DiagnosticRenderer + # fmt: off + renderer: Callable[[DiagnosticContext], None] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, render_func): self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func) @@ -107,6 +133,13 @@ class DiagnosticContext(Object): and contains a renderer. 
""" + # tvm-ffi-stubgen(begin): object/DiagnosticContext + # fmt: off + module: IRModule + diagnostics: Sequence[Diagnostic] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, module, renderer): self.__init_handle_by_constructor__(_ffi_api.DiagnosticContext, module, renderer) diff --git a/python/tvm/ir/diagnostics/_ffi_api.py b/python/tvm/ir/diagnostics/_ffi_api.py index 65fb2cc896f3..8c183ac50291 100644 --- a/python/tvm/ir/diagnostics/_ffi_api.py +++ b/python/tvm/ir/diagnostics/_ffi_api.py @@ -15,7 +15,34 @@ # specific language governing permissions and limitations # under the License. """FFI for TVM diagnostics.""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import IRModule, Span + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("diagnostics", __name__) + +# tvm-ffi-stubgen(begin): global/diagnostics +# fmt: off +_FFI_INIT_FUNC("diagnostics", __name__) +if TYPE_CHECKING: + def ClearRenderer() -> None: ... + def Default(_0: IRModule, /) -> DiagnosticContext: ... + def DefaultRenderer() -> DiagnosticRenderer: ... + def Diagnostic(_0: int, _1: Span, _2: str, /) -> Diagnostic: ... + def DiagnosticContext(_0: IRModule, _1: DiagnosticRenderer, /) -> DiagnosticContext: ... + def DiagnosticContextRender(_0: DiagnosticContext, /) -> None: ... + def DiagnosticRenderer(_0: Callable[[DiagnosticContext], None], /) -> DiagnosticRenderer: ... + def DiagnosticRendererRender(_0: DiagnosticRenderer, _1: DiagnosticContext, /) -> None: ... + def Emit(_0: DiagnosticContext, _1: Diagnostic, /) -> None: ... + def GetRenderer() -> DiagnosticRenderer: ... 
+# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 19abb6bd1eae..b2bd5a9f374f 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Common expressions data structures in the IR.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import Span + from tvm_ffi import Object, dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from numbers import Number from typing import Optional @@ -30,6 +41,12 @@ class BaseExpr(Node): """Base class of all the expressions.""" + # tvm-ffi-stubgen(begin): object/ir.BaseExpr + # fmt: off + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + span: Optional[Span] @@ -41,6 +58,12 @@ class PrimExpr(BaseExpr): optimizations and integer analysis. """ + # tvm-ffi-stubgen(begin): object/ir.PrimExpr + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + dtype: str @@ -48,6 +71,12 @@ class PrimExpr(BaseExpr): class RelaxExpr(BaseExpr): """Base class of all non-primitive expressions.""" + # tvm-ffi-stubgen(begin): object/ir.RelaxExpr + # fmt: off + struct_info_: Object | None + # fmt: on + # tvm-ffi-stubgen(end) + @property def struct_info(self) -> Optional["tvm.relax.StructInfo"]: """Get the struct info field @@ -73,6 +102,12 @@ class GlobalVar(RelaxExpr): The name of the variable. """ + # tvm-ffi-stubgen(begin): object/ir.GlobalVar + # fmt: off + name_hint: str + # fmt: on + # tvm-ffi-stubgen(end) + name_hint: str def __init__(self, name_hint: str): @@ -131,6 +166,14 @@ class Range(Node, Scriptable): if the end argument is not None. Otherwise, it creates `[0, begin)`. 
""" + # tvm-ffi-stubgen(begin): object/ir.Range + # fmt: off + min: PrimExpr + extent: PrimExpr + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + min: PrimExpr extent: PrimExpr span: Optional[Span] diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index 75718503aae1..76e1897366e0 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -16,6 +16,16 @@ # under the License. # pylint: disable=invalid-name """Function definitions.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import DictAttrs +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Union, Dict from enum import IntEnum @@ -40,6 +50,12 @@ class CallingConv(IntEnum): class BaseFunc(RelaxExpr): """Base class of all functions.""" + # tvm-ffi-stubgen(begin): object/ir.BaseFunc + # fmt: off + attrs: DictAttrs + # fmt: on + # tvm-ffi-stubgen(end) + @property def attrs(self): """Return the attrs member of the function.""" diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py index 185e10b88cce..f5731ec0d9c3 100644 --- a/python/tvm/ir/global_info.py +++ b/python/tvm/ir/global_info.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. 
"""Global Info.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm import tvm_ffi from tvm.runtime.object import Object @@ -25,6 +35,11 @@ class GlobalInfo(Object): """Base node for all global info that can appear in the IR""" + # tvm-ffi-stubgen(begin): object/ir.GlobalInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __eq__(self, other): """Compare two struct info for structural equivalence.""" return tvm.ir.structural_equal(self, other) @@ -39,6 +54,10 @@ def same_as(self, other): @tvm_ffi.register_object("ir.DummyGlobalInfo") class DummyGlobalInfo(GlobalInfo): + # tvm-ffi-stubgen(begin): object/ir.DummyGlobalInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.DummyGlobalInfo, @@ -47,6 +66,13 @@ def __init__(self) -> None: @tvm_ffi.register_object("ir.VDevice") class VDevice(GlobalInfo): + # tvm-ffi-stubgen(begin): object/ir.VDevice + # fmt: off + target: Target + vdevice_id: int + memory_scope: str + # fmt: on + # tvm-ffi-stubgen(end) def __init__( self, target=None, diff --git a/python/tvm/ir/instrument.py b/python/tvm/ir/instrument.py index 0f1bcf3adfda..32e2f5c7dd6a 100644 --- a/python/tvm/ir/instrument.py +++ b/python/tvm/ir/instrument.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Common pass instrumentation across IR variants.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import functools import inspect import re @@ -43,6 +45,12 @@ class PassInstrument(tvm.runtime.Object): """ + # tvm-ffi-stubgen(begin): object/instrument.PassInstrument + # fmt: off + name: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): # initialize handle in case pi_cls creation failed. 
cls = type(self) @@ -234,6 +242,12 @@ def create_pass_instrument(pi_cls): class PassTimingInstrument(tvm.runtime.Object): """A wrapper to create a passes time instrument that implemented in C++""" + # tvm-ffi-stubgen(begin): object/instrument.PassInstrument + # fmt: off + name: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): self.__init_handle_by_constructor__(_ffi_instrument_api.MakePassTimingInstrument) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 21c86c05ec4c..df29574c20b9 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -18,6 +18,17 @@ from __future__ import annotations +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import BaseFunc, DictAttrs, GlobalInfo, GlobalVar, SourceMap +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Dict, Union import tvm @@ -43,6 +54,16 @@ class IRModule(Node, Scriptable): Map of global var to BaseFunc """ + # tvm-ffi-stubgen(begin): object/ir.IRModule + # fmt: off + functions: Mapping[GlobalVar, BaseFunc] + global_var_map_: Mapping[str, GlobalVar] + source_map: SourceMap + attrs: DictAttrs + global_infos: Mapping[str, Sequence[GlobalInfo]] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, functions=None, attrs=None, global_infos=None): if functions is None: functions = {} diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index 5b62d3fe8df7..53e5bdb15600 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -16,6 +16,17 @@ # under the License. 
# pylint: disable=invalid-name """Primitive operators in the TVM IR.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import AttrFieldInfo, FuncType +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from . import _ffi_api @@ -26,6 +37,18 @@ class Op(RelaxExpr): """Primitive operator in the IR.""" + # tvm-ffi-stubgen(begin): object/ir.Op + # fmt: off + name: str + op_type: FuncType + description: str + arguments: Sequence[AttrFieldInfo] + attrs_type_key: str + num_inputs: int + support_level: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): raise RuntimeError("Cannot create op, use get instead") diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index bc38089b2254..749f2fbfb66e 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Suppliers that are used to guarantee uniqueness of names and GlobalVars.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import tvm import tvm_ffi from tvm import Object, IRModule @@ -30,6 +32,11 @@ class NameSupply(Object): prefix: The prefix to be added to the generated names. """ + # tvm-ffi-stubgen(begin): object/ir.NameSupply + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, prefix=""): self.__init_handle_by_constructor__(_ffi_api.NameSupply, prefix) @@ -91,6 +98,11 @@ class GlobalVarSupply(Object): The IRModules used to build this GlobalVarSupply or a NameSupply. 
""" + # tvm-ffi-stubgen(begin): object/ir.GlobalVarSupply + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, value=None): if value is None: name_supply = NameSupply("") diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index fd9a2ac3b212..1384fff3db4a 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -16,6 +16,18 @@ # under the License. # pylint: disable=invalid-name,unused-argument """Common pass infrastructure across IR variants.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from instrument import PassInstrument + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import inspect import functools @@ -44,6 +56,15 @@ class PassInfo(tvm.runtime.Object): The list of passes that are required by a certain pass. """ + # tvm-ffi-stubgen(begin): object/transform.PassInfo + # fmt: off + opt_level: int + name: str + required: Sequence[str] + traceable: bool + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, opt_level, name, required=None, traceable=False): self.__init_handle_by_constructor__( _ffi_transform_api.PassInfo, opt_level, name, required, traceable @@ -73,6 +94,17 @@ class PassContext(tvm.runtime.Object): Additional configurations for specific passes. """ + # tvm-ffi-stubgen(begin): object/transform.PassContext + # fmt: off + opt_level: int + required_pass: Sequence[str] + disabled_pass: Sequence[str] + instruments: Sequence[PassInstrument] + config: Mapping[str, Any] + diag_ctx: DiagnosticContext | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, opt_level=2, @@ -145,6 +177,11 @@ class Pass(tvm.runtime.Object): conveniently interact with the base class. 
""" + # tvm-ffi-stubgen(begin): object/transform.Pass + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + @property def info(self): """Get the pass meta.""" @@ -177,6 +214,12 @@ class ModulePass(Pass): The same rule applies to FunctionPass as well. """ + # tvm-ffi-stubgen(begin): object/transform.ModulePass + # fmt: off + pass_info: PassInfo + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("transform.Sequential") class Sequential(Pass): @@ -205,6 +248,13 @@ class Sequential(Pass): The list of passes that the sequential pass is dependent on. """ + # tvm-ffi-stubgen(begin): object/transform.Sequential + # fmt: off + pass_info: PassInfo + passes: Sequence[Pass] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, passes=None, opt_level=0, name="sequential", required=None, traceable=False): passes = passes if passes else [] if not isinstance(passes, (list, tuple)): diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index 68bed8fb69f0..de1e6a06a85e 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Unified type system in the project.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import Span + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm import tvm_ffi from tvm.runtime import Scriptable @@ -27,6 +39,12 @@ class Type(Node, Scriptable): """The base class of all types.""" + # tvm-ffi-stubgen(begin): object/ir.Type + # fmt: off + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __eq__(self, other): """Compare two types for structural equivalence.""" return bool(tvm.ir.structural_equal(self, other)) @@ -49,6 +67,12 @@ class PrimType(Type): The runtime data type relates to the primtype. 
""" + # tvm-ffi-stubgen(begin): object/ir.PrimType + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, dtype): self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype) @@ -66,6 +90,13 @@ class PointerType(Type): The storage scope into which the pointer addresses. """ + # tvm-ffi-stubgen(begin): object/ir.PointerType + # fmt: off + element_type: Type + storage_scope: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, element_type, storage_scope=""): self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type, storage_scope) @@ -80,6 +111,13 @@ class TupleType(Type): The fields in the tuple """ + # tvm-ffi-stubgen(begin): object/ir.TupleType + # fmt: off + fields: Sequence[Type] + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, fields): self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) @@ -102,6 +140,14 @@ class FuncType(Type): The return type. """ + # tvm-ffi-stubgen(begin): object/ir.FuncType + # fmt: off + arg_types: Sequence[Type] + ret_type: Type + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, arg_types, ret_type): self.__init_handle_by_constructor__( _ffi_api.FuncType, @@ -120,6 +166,12 @@ class TensorMapType(Type): The span information. """ + # tvm-ffi-stubgen(begin): object/ir.TensorMapType + # fmt: off + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, span=None): self.__init_handle_by_constructor__( _ffi_api.TensorMapType, span # pylint: disable=no-member diff --git a/python/tvm/ir/type_relation.py b/python/tvm/ir/type_relation.py index 70950958024d..c625577bcadf 100644 --- a/python/tvm/ir/type_relation.py +++ b/python/tvm/ir/type_relation.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
"""Type relation and function for type checking.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import tvm_ffi from .type import Type, TypeConstraint @@ -39,6 +41,9 @@ class TypeCall(Type): The type function application. """ + # tvm-ffi-stubgen(begin): object/TypeCall + # tvm-ffi-stubgen(end) + def __init__(self, func, args): self.__init_handle_by_constructor__(_ffi_api.TypeCall, func, args) @@ -71,5 +76,8 @@ class TypeRelation(TypeConstraint): The type relation. """ + # tvm-ffi-stubgen(begin): object/TypeRelation + # tvm-ffi-stubgen(end) + def __init__(self, func, args, num_inputs, attrs): self.__init_handle_by_constructor__(_ffi_api.TypeRelation, func, args, num_inputs, attrs) diff --git a/python/tvm/meta_schedule/_ffi_api.py b/python/tvm/meta_schedule/_ffi_api.py index 1a06aef5a482..f12a70e0e396 100644 --- a/python/tvm/meta_schedule/_ffi_api.py +++ b/python/tvm/meta_schedule/_ffi_api.py @@ -15,6 +15,179 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.meta_schedule""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ctypes import c_void_p + from ir import FloatImm, IRModule, IntImm + from meta_schedule import ArgInfo, Builder, BuilderInput, BuilderResult, CostModel, Database, EvolutionarySearch, ExtractedTask, FeatureExtractor, MeasureCallback, MeasureCandidate, Mutator, Postproc, Profiler, Runner, RunnerFuture, RunnerInput, RunnerResult, ScheduleRule, SearchStrategy, SpaceGenerator, TaskScheduler, TensorInfo, TuneContext, TuningRecord, Workload + from target import Target + from tir import BlockRV, PrimFunc, Schedule, Trace + from tvm_ffi import Object, Shape, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("meta_schedule", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/meta_schedule +# fmt: off +_FFI_INIT_FUNC("meta_schedule", __name__) +if TYPE_CHECKING: + def ArgInfoAsJSON(_0: ArgInfo, /) -> Object: ... + def ArgInfoFromEntryFunc(_0: IRModule, _1: bool, /) -> Sequence[ArgInfo]: ... + def ArgInfoFromJSON(_0: Object, /) -> ArgInfo: ... + def ArgInfoFromPrimFunc(_0: PrimFunc, /) -> Sequence[ArgInfo]: ... + def BuilderBuild(_0: Builder, _1: Sequence[BuilderInput], /) -> Sequence[BuilderResult]: ... + def BuilderInput(_0: IRModule, _1: Target, _2: Mapping[str, Tensor] | None, /) -> BuilderInput: ... + def BuilderPyBuilder(_0: Callable[[Sequence[BuilderInput]], Sequence[BuilderResult]], /) -> Builder: ... + def BuilderResult(_0: str | None, _1: str | None, /) -> BuilderResult: ... + def CostModelLoad(_0: CostModel, _1: str, /) -> None: ... + def CostModelPredict(_0: CostModel, _1: TuneContext, _2: Sequence[MeasureCandidate], _3: c_void_p, /) -> None: ... 
+ def CostModelPyCostModel(_0: Callable[[str], None], _1: Callable[[str], None], _2: Callable[[TuneContext, Sequence[MeasureCandidate], Sequence[RunnerResult]], None], _3: Callable[[TuneContext, Sequence[MeasureCandidate], c_void_p], None], _4: Callable[[], str], /) -> CostModel: ... + def CostModelSave(_0: CostModel, _1: str, /) -> None: ... + def CostModelUpdate(_0: CostModel, _1: TuneContext, _2: Sequence[MeasureCandidate], _3: Sequence[RunnerResult], /) -> None: ... + def DatabaseCommitTuningRecord(_0: Database, _1: TuningRecord, /) -> None: ... + def DatabaseCommitWorkload(_0: Database, _1: IRModule, /) -> Workload: ... + def DatabaseCurrent() -> Database | None: ... + def DatabaseDumpPruned(_0: Database, _1: Database, /) -> None: ... + def DatabaseEnterWithScope(_0: Database, /) -> None: ... + def DatabaseExitWithScope(_0: Database, /) -> None: ... + def DatabaseGetAllTuningRecords(_0: Database, /) -> Sequence[TuningRecord]: ... + def DatabaseGetTopK(_0: Database, _1: Workload, _2: int, /) -> Sequence[TuningRecord]: ... + def DatabaseHasWorkload(_0: Database, _1: IRModule, /) -> bool: ... + def DatabaseJSONDatabase(_0: str, _1: str, _2: bool, _3: str, /) -> Database: ... + def DatabaseMemoryDatabase(_0: str, /) -> Database: ... + def DatabaseOrderedUnionDatabase(_0: Sequence[Database], /) -> Database: ... + def DatabasePyDatabase(_0: Callable[[IRModule], bool], _1: Callable[[IRModule], Workload], _2: Callable[[TuningRecord], None], _3: Callable[[Workload, int], Sequence[TuningRecord]], _4: Callable[[], Sequence[TuningRecord]], _5: Callable[[IRModule, Target, str], TuningRecord | None], _6: Callable[[IRModule, Target, str], Schedule | None], _7: Callable[[IRModule, Target, str], IRModule | None], _8: Callable[[], int], _9: str, /) -> Database: ... + def DatabaseQueryIRModule(_0: Database, _1: IRModule, _2: Target, _3: str, /) -> IRModule | None: ... + def DatabaseQuerySchedule(_0: Database, _1: IRModule, _2: Target, _3: str, /) -> Schedule | None: ... 
+ def DatabaseQueryTuningRecord(_0: Database, _1: IRModule, _2: Target, _3: str, /) -> TuningRecord | None: ... + def DatabaseScheduleFnDatabase(_0: Callable[[Schedule], bool], _1: str, /) -> Database: ... + def DatabaseSize(_0: Database, /) -> int: ... + def DatabaseUnionDatabase(_0: Sequence[Database], /) -> Database: ... + def ExtractedTask(_0: str, _1: IRModule, _2: Target, _3: Sequence[IRModule], _4: int, /) -> ExtractedTask: ... + def FeatureExtractorExtractFrom(_0: FeatureExtractor, _1: TuneContext, _2: Sequence[MeasureCandidate], /) -> Sequence[Tensor]: ... + def FeatureExtractorPerStoreFeature(_0: int, _1: int, _2: int, _3: bool, /) -> FeatureExtractor: ... + def FeatureExtractorPyFeatureExtractor(_0: Callable[[TuneContext, Sequence[MeasureCandidate]], Sequence[Tensor]], _1: Callable[[], str], /) -> FeatureExtractor: ... + def MeasureCallbackAddToDatabase() -> MeasureCallback: ... + def MeasureCallbackApply(_0: MeasureCallback, _1: TaskScheduler, _2: int, _3: Sequence[MeasureCandidate], _4: Sequence[BuilderResult], _5: Sequence[RunnerResult], /) -> None: ... + def MeasureCallbackDefault() -> Sequence[MeasureCallback]: ... + def MeasureCallbackPyMeasureCallback(_0: Callable[[TaskScheduler, int, Sequence[MeasureCandidate], Sequence[BuilderResult], Sequence[RunnerResult]], None], _1: Callable[[], str], /) -> MeasureCallback: ... + def MeasureCallbackRemoveBuildArtifact() -> MeasureCallback: ... + def MeasureCallbackUpdateCostModel() -> MeasureCallback: ... + def MeasureCandidate(_0: Schedule, _1: Sequence[ArgInfo] | None, /) -> MeasureCandidate: ... + def MutateThreadBinding() -> Mutator: ... + def MutatorApply(_0: Mutator, _1: Trace, _2: int, /) -> Trace | None: ... + def MutatorClone(_0: Mutator, /) -> Mutator: ... + def MutatorDefaultCUDA() -> Mapping[Mutator, FloatImm]: ... + def MutatorDefaultCUDATensorCore() -> Mapping[Mutator, FloatImm]: ... + def MutatorDefaultHexagon() -> Mapping[Mutator, FloatImm]: ... 
+ def MutatorDefaultLLVM() -> Mapping[Mutator, FloatImm]: ... + def MutatorInitializeWithTuneContext(_0: Mutator, _1: TuneContext, /) -> None: ... + def MutatorMutateComputeLocation() -> Mutator: ... + def MutatorMutateParallel(_0: int, /) -> Mutator: ... + def MutatorMutateTileSize() -> Mutator: ... + def MutatorMutateUnroll() -> Mutator: ... + def MutatorPyMutator(_0: Callable[[TuneContext], None], _1: Callable[[Trace, int], Trace | None], _2: Callable[[], Mutator], _3: Callable[[], str], /) -> Mutator: ... + def PostprocApply(_0: Postproc, _1: Schedule, /) -> bool: ... + def PostprocClone(_0: Postproc, /) -> Postproc: ... + def PostprocDefaultCUDA() -> Sequence[Postproc]: ... + def PostprocDefaultCUDATensorCore() -> Sequence[Postproc]: ... + def PostprocDefaultHexagon() -> Sequence[Postproc]: ... + def PostprocDefaultLLVM() -> Sequence[Postproc]: ... + def PostprocDisallowAsyncStridedMemCopy() -> Postproc: ... + def PostprocDisallowDynamicLoop() -> Postproc: ... + def PostprocInitializeWithTuneContext(_0: Postproc, _1: TuneContext, /) -> None: ... + def PostprocPyPostproc(_0: Callable[[TuneContext], None], _1: Callable[[Schedule], bool], _2: Callable[[], Postproc], _3: Callable[[], str], /) -> Postproc: ... + def PostprocRewriteCooperativeFetch() -> Postproc: ... + def PostprocRewriteLayout() -> Postproc: ... + def PostprocRewriteParallelVectorizeUnroll() -> Postproc: ... + def PostprocRewriteReductionBlock() -> Postproc: ... + def PostprocRewriteTensorize(_0: bool, /) -> Postproc: ... + def PostprocRewriteUnboundBlock(_0: int, /) -> Postproc: ... + def PostprocVerifyGPUCode() -> Postproc: ... + def PostprocVerifyVTCMLimit() -> Postproc: ... + def Profiler() -> Profiler: ... + def ProfilerCurrent() -> Profiler | None: ... + def ProfilerEnterWithScope(_0: Profiler, /) -> None: ... + def ProfilerExitWithScope(_0: Profiler, /) -> None: ... + def ProfilerGet(_0: Profiler, /) -> Mapping[str, FloatImm]: ... + def ProfilerTable(_0: Profiler, /) -> str: ... 
+ def ProfilerTimedScope(_0: str, /) -> Callable[..., Any]: ... + def RunnerFuture(_0: Callable[[], bool], _1: Callable[[], RunnerResult], /) -> RunnerFuture: ... + def RunnerFutureDone(_0: RunnerFuture, /) -> bool: ... + def RunnerFutureResult(_0: RunnerFuture, /) -> RunnerResult: ... + def RunnerInput(_0: str, _1: str, _2: Sequence[ArgInfo], /) -> RunnerInput: ... + def RunnerPyRunner(_0: Callable[[Sequence[RunnerInput]], Sequence[RunnerFuture]], /) -> Runner: ... + def RunnerResult(_0: Sequence[FloatImm] | None, _1: str | None, /) -> RunnerResult: ... + def RunnerRun(_0: Runner, _1: Sequence[RunnerInput], /) -> Sequence[RunnerFuture]: ... + def ScheduleRuleAddRFactor(_0: int, _1: IntImm | None, /) -> ScheduleRule: ... + def ScheduleRuleApply(_0: ScheduleRule, _1: Schedule, _2: BlockRV, /) -> Sequence[Schedule]: ... + def ScheduleRuleApplyCustomRule() -> ScheduleRule: ... + def ScheduleRuleAutoBind(_0: int, _1: Sequence[IntImm], _2: int, /) -> ScheduleRule: ... + def ScheduleRuleAutoInline(_0: bool, _1: bool, _2: bool, _3: bool, _4: bool, _5: bool, _6: Sequence[str] | None, /) -> ScheduleRule: ... + def ScheduleRuleClone(_0: ScheduleRule, /) -> ScheduleRule: ... + def ScheduleRuleCrossThreadReduction(_0: Sequence[IntImm], /) -> ScheduleRule: ... + def ScheduleRuleDefaultARM(_0: str, /) -> Sequence[ScheduleRule]: ... + def ScheduleRuleDefaultCUDA() -> Sequence[ScheduleRule]: ... + def ScheduleRuleDefaultCUDATensorCore() -> Sequence[ScheduleRule]: ... + def ScheduleRuleDefaultHexagon() -> Sequence[ScheduleRule]: ... + def ScheduleRuleDefaultLLVM() -> Sequence[ScheduleRule]: ... + def ScheduleRuleInitializeWithTuneContext(_0: ScheduleRule, _1: TuneContext, /) -> None: ... + def ScheduleRuleInlineConstantScalars() -> ScheduleRule: ... + def ScheduleRuleMultiLevelTiling(_0: str, _1: Sequence[str] | None, _2: IntImm | None, _3: Sequence[IntImm] | None, _4: Mapping[str, Any] | None, _5: Mapping[str, Any] | None, _6: Callable[..., Any] | None, /) -> ScheduleRule: ... 
+ def ScheduleRuleMultiLevelTilingTensorCore(_0: Sequence[Mapping[str, str]], _1: str, _2: Sequence[str] | None, _3: IntImm | None, _4: Sequence[IntImm] | None, _5: Mapping[str, Any] | None, _6: Mapping[str, Any] | None, _7: bool, /) -> ScheduleRule: ... + def ScheduleRuleMultiLevelTilingWideVector(_0: str, _1: IntImm, _2: IntImm | None, _3: Mapping[str, Any] | None, _4: Mapping[str, Any] | None, /) -> ScheduleRule: ... + def ScheduleRuleMultiLevelTilingWithIntrin(_0: str, _1: str, _2: Sequence[str] | None, _3: IntImm | None, _4: Sequence[IntImm] | None, _5: Mapping[str, Any] | None, _6: Mapping[str, Any] | None, /) -> ScheduleRule: ... + def ScheduleRuleParallelizeVectorizeUnroll(_0: int, _1: int, _2: Sequence[IntImm], _3: bool, /) -> ScheduleRule: ... + def ScheduleRulePyScheduleRule(_0: Callable[[TuneContext], None], _1: Callable[[Schedule, BlockRV], Sequence[Schedule]], _2: Callable[[], ScheduleRule], _3: Callable[[], str], /) -> ScheduleRule: ... + def ScheduleRuleRandomComputeLocation() -> ScheduleRule: ... + def ScheduleUsingAnchorTrace(_0: Schedule, _1: Trace, _2: Target, /) -> None: ... + def SearchStrategyClone(_0: SearchStrategy, /) -> SearchStrategy: ... + def SearchStrategyEvolutionarySearch(_0: int, _1: float, _2: int, _3: int, _4: int, _5: float, _6: int, _7: float, /) -> SearchStrategy: ... + def SearchStrategyEvolutionarySearchEvolveWithCostModel(_0: EvolutionarySearch, _1: Sequence[Schedule], _2: int, /) -> Sequence[Schedule]: ... + def SearchStrategyEvolutionarySearchSampleInitPopulation(_0: EvolutionarySearch, _1: int, /) -> Sequence[Schedule]: ... + def SearchStrategyGenerateMeasureCandidates(_0: SearchStrategy, /) -> Sequence[MeasureCandidate] | None: ... + def SearchStrategyInitializeWithTuneContext(_0: SearchStrategy, _1: TuneContext, /) -> None: ... + def SearchStrategyNotifyRunnerResults(_0: SearchStrategy, _1: Sequence[MeasureCandidate], _2: Sequence[RunnerResult], /) -> None: ... 
+ def SearchStrategyPostTuning(_0: SearchStrategy, /) -> None: ... + def SearchStrategyPreTuning(_0: SearchStrategy, _1: int, _2: int, _3: Sequence[Schedule], _4: Database | None, _5: CostModel | None, /) -> None: ... + def SearchStrategyPySearchStrategy(_0: Callable[[TuneContext], None], _1: Callable[[int, int, Sequence[Schedule], Database | None, CostModel | None], None], _2: Callable[[], None], _3: Callable[[], Sequence[MeasureCandidate] | None], _4: Callable[[Sequence[MeasureCandidate], Sequence[RunnerResult]], None], _5: Callable[[], SearchStrategy], /) -> SearchStrategy: ... + def SearchStrategyReplayFunc() -> SearchStrategy: ... + def SearchStrategyReplayTrace(_0: int, /) -> SearchStrategy: ... + def SpaceGeneratorClone(_0: SpaceGenerator, /) -> SpaceGenerator: ... + def SpaceGeneratorGenerateDesignSpace(_0: SpaceGenerator, _1: IRModule, /) -> Sequence[Schedule]: ... + def SpaceGeneratorInitializeWithTuneContext(_0: SpaceGenerator, _1: TuneContext, /) -> None: ... + def SpaceGeneratorPostOrderApply(_0: Callable[..., Any], _1: Sequence[ScheduleRule] | None, _2: Sequence[Postproc] | None, _3: Mapping[Mutator, FloatImm] | None, /) -> SpaceGenerator: ... + def SpaceGeneratorPySpaceGenerator(_0: Sequence[ScheduleRule] | None, _1: Sequence[Postproc] | None, _2: Mapping[Mutator, FloatImm] | None, _3: Callable[[TuneContext], None], _4: Callable[[IRModule], Sequence[Schedule]], _5: Callable[[], SpaceGenerator], /) -> SpaceGenerator: ... + def SpaceGeneratorScheduleFn(_0: Callable[..., Any], _1: Sequence[ScheduleRule] | None, _2: Sequence[Postproc] | None, _3: Mapping[Mutator, FloatImm] | None, /) -> SpaceGenerator: ... + def SpaceGeneratorSpaceGeneratorUnion(_0: Sequence[SpaceGenerator], _1: Sequence[ScheduleRule] | None, _2: Sequence[Postproc] | None, _3: Mapping[Mutator, FloatImm] | None, /) -> SpaceGenerator: ... + def TaskSchedulerGradientBased(_0: Callable[..., Any], _1: float, _2: int, _3: int, /) -> TaskScheduler: ... 
+ def TaskSchedulerJoinRunningTask(_0: TaskScheduler, _1: int, /) -> Sequence[RunnerResult]: ... + def TaskSchedulerNextTaskId(_0: TaskScheduler, /) -> int: ... + def TaskSchedulerPrintTuningStatistics(_0: TaskScheduler, /) -> None: ... + def TaskSchedulerPyTaskScheduler(_0: Callable[..., Any], _1: Callable[[], int], _2: Callable[[int], Sequence[RunnerResult]], _3: Callable[[Sequence[TuneContext], Sequence[FloatImm], int, int, int, Builder, Runner, Sequence[MeasureCallback], Database | None, CostModel | None], None], /) -> TaskScheduler: ... + def TaskSchedulerRoundRobin(_0: Callable[..., Any], /) -> TaskScheduler: ... + def TaskSchedulerTerminateTask(_0: TaskScheduler, _1: int, /) -> None: ... + def TaskSchedulerTouchTask(_0: TaskScheduler, _1: int, /) -> None: ... + def TaskSchedulerTune(_0: TaskScheduler, _1: Sequence[TuneContext], _2: Sequence[FloatImm], _3: int, _4: int, _5: int, _6: Builder, _7: Runner, _8: Sequence[MeasureCallback], _9: Database | None, _10: CostModel | None, /) -> None: ... + def TensorInfo(_0: dtype, _1: Shape, /) -> TensorInfo: ... + def TuneContext(_0: IRModule | None, _1: Target | None, _2: SpaceGenerator | None, _3: SearchStrategy | None, _4: str | None, _5: int, _6: int, _7: Callable[..., Any], /) -> TuneContext: ... + def TuneContextClone(_0: TuneContext, /) -> TuneContext: ... + def TuneContextInitialize(_0: TuneContext, /) -> None: ... + def TuningRecord(_0: Trace, _1: Workload, _2: Sequence[FloatImm] | None, _3: Target | None, _4: Sequence[ArgInfo] | None, /) -> TuningRecord: ... + def TuningRecordAsJSON(_0: TuningRecord, /) -> Object: ... + def TuningRecordAsMeasureCandidate(_0: TuningRecord, /) -> MeasureCandidate: ... + def TuningRecordFromJSON(_0: Object, _1: Workload, /) -> TuningRecord: ... + def Workload(_0: IRModule, /) -> Workload: ... + def WorkloadAsJSON(_0: Workload, /) -> Object: ... + def WorkloadFromJSON(_0: Object, /) -> Workload: ... + def _SHash2Hex(_0: Object, /) -> str: ... 
+# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/meta_schedule/arg_info.py b/python/tvm/meta_schedule/arg_info.py index 3f8d721ed1f0..415f165aa808 100644 --- a/python/tvm/meta_schedule/arg_info.py +++ b/python/tvm/meta_schedule/arg_info.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. """The argument information""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tvm_ffi import Shape, dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Any, List, Union from tvm_ffi import register_object @@ -30,6 +40,11 @@ class ArgInfo(Object): """Argument information""" + # tvm-ffi-stubgen(begin): object/meta_schedule.ArgInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def as_json(self) -> Any: """Converts the ArgInfo to its corresponding JSON representation.""" return _json_de_tvm(_ffi_api.ArgInfoAsJSON(self)) # type: ignore # pylint: disable=no-member @@ -97,6 +112,13 @@ class TensorInfo(ArgInfo): The shape of the tensor. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.TensorInfo + # fmt: off + dtype: dtype + shape: Shape + # fmt: on + # tvm-ffi-stubgen(end) + dtype: DataType shape: ShapeTuple diff --git a/python/tvm/meta_schedule/builder/builder.py b/python/tvm/meta_schedule/builder/builder.py index 39493781404a..d1024eac7468 100644 --- a/python/tvm/meta_schedule/builder/builder.py +++ b/python/tvm/meta_schedule/builder/builder.py @@ -15,6 +15,20 @@ # specific language governing permissions and limitations # under the License. 
"""Meta Schedule builders that translate IRModule to runtime.Module, and then export""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IRModule + from target import Target + from tvm_ffi import Tensor + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable, Dict, List, Optional, Union # isort: off @@ -43,6 +57,14 @@ class BuilderInput(Object): The parameters for Relax build module """ + # tvm-ffi-stubgen(begin): object/meta_schedule.BuilderInput + # fmt: off + mod: IRModule + target: Target + params: Mapping[str, Tensor] | None + # fmt: on + # tvm-ffi-stubgen(end) + mod: IRModule target: Target params: Optional[Dict[str, Tensor]] @@ -84,6 +106,13 @@ class BuilderResult(Object): The error message. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.BuilderResult + # fmt: off + artifact_path: str | None + error_msg: str | None + # fmt: on + # tvm-ffi-stubgen(end) + artifact_path: Optional[str] error_msg: Optional[str] @@ -112,6 +141,11 @@ def __init__( class Builder(Object): """The abstract builder interface.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Builder + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + BuilderType = Union["Builder", Literal["local"]] def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]: @@ -165,6 +199,12 @@ class _PyBuilder(Builder): See also: PyBuilder """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyBuilder + # fmt: off + f_build: Callable[[Sequence[BuilderInput]], Sequence[BuilderResult]] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_build: Callable = None): """Constructor.""" diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index f51d2f2ac89b..129fc276acfa 100644 --- 
a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule CostModel.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import ctypes from typing import Callable, List, Union @@ -38,6 +40,11 @@ class CostModel(Object): """Cost model.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.CostModel + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + CostModelType = Union["CostModel", Literal["xgb", "mlp", "random"]] def load(self, path: str) -> None: @@ -159,6 +166,11 @@ class _PyCostModel(CostModel): See also: PyCostModel """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyCostModel + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_load: Callable = None, diff --git a/python/tvm/meta_schedule/database/database.py b/python/tvm/meta_schedule/database/database.py index 08bcbd33c7ad..e8d7f4f36ef2 100644 --- a/python/tvm/meta_schedule/database/database.py +++ b/python/tvm/meta_schedule/database/database.py @@ -15,6 +15,20 @@ # specific language governing permissions and limitations # under the License. 
"""TuningRecord database""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import FloatImm, IRModule + from meta_schedule import ArgInfo + from target import Target + from tir import Trace +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Any, Callable, List, Optional, Union # isort: off @@ -43,6 +57,12 @@ class Workload(Object): The workload's IRModule """ + # tvm-ffi-stubgen(begin): object/meta_schedule.Workload + # fmt: off + mod: IRModule + # fmt: on + # tvm-ffi-stubgen(end) + mod: IRModule def __init__(self, mod: IRModule) -> None: @@ -97,6 +117,16 @@ class TuningRecord(Object): The argument information of the tuning record. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.TuningRecord + # fmt: off + trace: Trace + workload: Workload + run_secs: Sequence[FloatImm] | None + target: Target | None + args_info: Sequence[ArgInfo] | None + # fmt: on + # tvm-ffi-stubgen(end) + trace: Trace workload: Workload run_secs: Optional[List[float]] @@ -164,6 +194,11 @@ def from_json(json_obj: Any, workload: Workload) -> "TuningRecord": class Database(Object): """The abstract database interface.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Database + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + DatabaseType = Union["Database", Literal["json", "memory"]] def has_workload(self, mod: IRModule) -> bool: @@ -436,6 +471,11 @@ class _PyDatabase(Database): See also: PyDatabase """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyDatabase + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_has_workload: Callable = None, diff --git a/python/tvm/meta_schedule/database/json_database.py b/python/tvm/meta_schedule/database/json_database.py index 7c6f7459cacc..333e7b65a3f9 100644 --- a/python/tvm/meta_schedule/database/json_database.py +++ 
b/python/tvm/meta_schedule/database/json_database.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """The default database that uses a JSON File to store tuning records""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import os.path as osp from typing import Optional @@ -46,6 +48,13 @@ class JSONDatabase(Database): For the definition of the anchor block, see tir/analysis/analysis.py. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.JSONDatabase + # fmt: off + path_workload: str + path_tuning_record: str + # fmt: on + # tvm-ffi-stubgen(end) + path_workload: str path_tuning_record: str diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index 1d6d4121231c..1fcbaec04758 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """A database that stores TuningRecords in memory""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import TuningRecord, Workload +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -39,6 +50,13 @@ class MemoryDatabase(Database): For the definition of the anchor block, see tir/analysis/analysis.py. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MemoryDatabase + # fmt: off + records: Sequence[TuningRecord] + workloads: Sequence[Workload] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, module_equality: str = "structural", diff --git a/python/tvm/meta_schedule/database/ordered_union_database.py b/python/tvm/meta_schedule/database/ordered_union_database.py index 717d2f3001c9..6284b2024a0d 100644 --- a/python/tvm/meta_schedule/database/ordered_union_database.py +++ b/python/tvm/meta_schedule/database/ordered_union_database.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """A database consists of multiple databases.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import Database +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -98,6 +109,12 @@ class OrderedUnionDatabase(Database): merged_db.query_tuning_record(..., target_workload) """ + # tvm-ffi-stubgen(begin): object/meta_schedule.OrderedUnionDatabase + # fmt: off + databases: Sequence[Database] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, *databases: Database) -> None: """Construct a merged database from multiple databases. diff --git a/python/tvm/meta_schedule/database/schedule_fn_database.py b/python/tvm/meta_schedule/database/schedule_fn_database.py index 74b2a6eb60da..3c85c99ba268 100644 --- a/python/tvm/meta_schedule/database/schedule_fn_database.py +++ b/python/tvm/meta_schedule/database/schedule_fn_database.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. 
"""A database for injecting handcrafted schedule functions.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tir import Schedule + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable from tvm_ffi import register_object @@ -45,6 +56,12 @@ class ScheduleFnDatabase(Database): For the definition of the anchor block, see tir/analysis/analysis.py. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ScheduleFnDatabase + # fmt: off + schedule_fn: Callable[[Schedule], bool] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, schedule_fn: Callable[[Schedule], bool], diff --git a/python/tvm/meta_schedule/database/union_database.py b/python/tvm/meta_schedule/database/union_database.py index 3a1afbe32adf..43724767dc77 100644 --- a/python/tvm/meta_schedule/database/union_database.py +++ b/python/tvm/meta_schedule/database/union_database.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """A database consists of multiple databases.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import Database +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -98,6 +109,12 @@ class UnionDatabase(Database): merged_db.query_tuning_record(..., target_workload) """ + # tvm-ffi-stubgen(begin): object/meta_schedule.UnionDatabase + # fmt: off + databases: Sequence[Database] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, *databases: Database) -> None: """Construct a merged database from multiple databases. 
diff --git a/python/tvm/meta_schedule/extracted_task.py b/python/tvm/meta_schedule/extracted_task.py index df66e774e595..beb179c2c502 100644 --- a/python/tvm/meta_schedule/extracted_task.py +++ b/python/tvm/meta_schedule/extracted_task.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Extracted tasks from high-level IR.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IRModule + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List from tvm_ffi import register_object @@ -43,6 +55,16 @@ class ExtractedTask(Object): The weight of the task """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ExtractedTask + # fmt: off + task_name: str + mod: IRModule + target: Target + dispatched: Sequence[IRModule] + weight: int + # fmt: on + # tvm-ffi-stubgen(end) + task_name: str mod: IRModule dispatched: List[IRModule] diff --git a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py index b50a22142943..5a5d9f61f041 100644 --- a/python/tvm/meta_schedule/feature_extractor/feature_extractor.py +++ b/python/tvm/meta_schedule/feature_extractor/feature_extractor.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
"""Meta Schedule FeatureExtractor.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Callable, List, Union # isort: off @@ -36,6 +38,11 @@ class FeatureExtractor(Object): """Extractor for features from measure candidates for use in cost model.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.FeatureExtractor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + FeatureExtractorType = Union[Literal["per-store-feature"], "FeatureExtractor"] def extract_from( @@ -83,6 +90,11 @@ class _PyFeatureExtractor(FeatureExtractor): See also: PyFeatureExtractor """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyFeatureExtractor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_extract_from: Callable, f_as_string: Callable = None): """Constructor.""" diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py index 673a722955d2..db8c4d66bcb9 100644 --- a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py +++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py @@ -18,6 +18,8 @@ """We extract one feature vector per BufferStoreNode statement in a TIR Stmt, so we call this feature as "per-store" feature. """ +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -40,6 +42,16 @@ class PerStoreFeature(FeatureExtractor): Whether to extract features in the workload in tuning context or not. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.PerStoreFeature + # fmt: off + buffers_per_store: int + arith_intensity_curve_num_samples: int + cache_line_bytes: int + extract_workload: bool + feature_vector_length: int + # fmt: on + # tvm-ffi-stubgen(end) + buffers_per_store: int """The number of buffers in each BufferStore; Pad or truncate if necessary.""" arith_intensity_curve_num_samples: int # pylint: disable=invalid-name diff --git a/python/tvm/meta_schedule/measure_callback/add_to_database.py b/python/tvm/meta_schedule/measure_callback/add_to_database.py index e0a6f5a273fc..793dec6971ab 100644 --- a/python/tvm/meta_schedule/measure_callback/add_to_database.py +++ b/python/tvm/meta_schedule/measure_callback/add_to_database.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """A callback that adds the measurement results into the database""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -23,6 +25,10 @@ @register_object("meta_schedule.AddToDatabase") class AddToDatabase(MeasureCallback): + # tvm-ffi-stubgen(begin): object/meta_schedule.AddToDatabase + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self) -> None: """A callback that adds the measurement results into the database""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/measure_callback/measure_callback.py b/python/tvm/meta_schedule/measure_callback/measure_callback.py index 885f70e88de8..5f7720a12e76 100644 --- a/python/tvm/meta_schedule/measure_callback/measure_callback.py +++ b/python/tvm/meta_schedule/measure_callback/measure_callback.py @@ -16,6 +16,8 @@ # under the License. 
"""Meta Schedule MeasureCallback.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, List, Union # isort: off @@ -40,6 +42,11 @@ class MeasureCallback(Object): """Rules to apply after measure results is available.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MeasureCallback + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + CallbackListType = Union[List["MeasureCallback"], "MeasureCallback", Literal["default"]] def apply( @@ -91,6 +98,11 @@ class _PyMeasureCallback(MeasureCallback): See also: PyMeasureCallback """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyMeasureCallback + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_apply: Callable, f_as_string: Callable = None): """Constructor.""" diff --git a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py index 23808b7e99d7..ce8e1ef2688e 100644 --- a/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py +++ b/python/tvm/meta_schedule/measure_callback/remove_build_artifact.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """A callback that removes the build artifacts from the disk""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. 
import _ffi_api @@ -23,6 +25,10 @@ @register_object("meta_schedule.RemoveBuildArtifact") class RemoveBuildArtifact(MeasureCallback): + # tvm-ffi-stubgen(begin): object/meta_schedule.RemoveBuildArtifact + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self) -> None: """A callback that removes the build artifacts from the disk""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/measure_callback/update_cost_model.py b/python/tvm/meta_schedule/measure_callback/update_cost_model.py index 7cf60c095b97..5fd69f45b099 100644 --- a/python/tvm/meta_schedule/measure_callback/update_cost_model.py +++ b/python/tvm/meta_schedule/measure_callback/update_cost_model.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """A measure callback that updates the cost model""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -23,6 +25,10 @@ @register_object("meta_schedule.UpdateCostModel") class UpdateCostModel(MeasureCallback): + # tvm-ffi-stubgen(begin): object/meta_schedule.UpdateCostModel + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self) -> None: """A measure callback that updates the cost model""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/mutator/mutate_compute_location.py b/python/tvm/meta_schedule/mutator/mutate_compute_location.py index 620e0062cbff..5a2d607d0f42 100644 --- a/python/tvm/meta_schedule/mutator/mutate_compute_location.py +++ b/python/tvm/meta_schedule/mutator/mutate_compute_location.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """A mutator that mutates the compute-at location decision of SampleComputeLocation""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. 
import _ffi_api @@ -25,6 +27,11 @@ class MutateComputeLocation(Mutator): """A mutator that mutates the compute-at location decision of SampleComputeLocation""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateComputeLocation + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.MutatorMutateComputeLocation, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/mutator/mutate_parallel.py b/python/tvm/meta_schedule/mutator/mutate_parallel.py index fc077cd0d4aa..24a11fd6f77c 100644 --- a/python/tvm/meta_schedule/mutator/mutate_parallel.py +++ b/python/tvm/meta_schedule/mutator/mutate_parallel.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the parallel extent""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -25,6 +27,12 @@ class MutateParallel(Mutator): """Mutator that mutates the parallel extent""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateParallel + # fmt: off + max_jobs_per_core: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, max_jobs_per_core: int) -> None: """Mutator that mutates the parallel extent""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py index 4c9fa44c50a0..56c74f41cb6e 100644 --- a/python/tvm/meta_schedule/mutator/mutate_thread_binding.py +++ b/python/tvm/meta_schedule/mutator/mutate_thread_binding.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the thread binding extent""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. 
import _ffi_api @@ -25,6 +27,11 @@ class MutateThreadBinding(Mutator): """Mutator that mutates the binding extent""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateThreadBinding + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: """Mutator that mutates the binding extent""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/mutator/mutate_tile_size.py b/python/tvm/meta_schedule/mutator/mutate_tile_size.py index f40894f5ba0f..801ccff4c0fb 100644 --- a/python/tvm/meta_schedule/mutator/mutate_tile_size.py +++ b/python/tvm/meta_schedule/mutator/mutate_tile_size.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates the decision of instruction Sample-Perfect-Tile""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -25,6 +27,11 @@ class MutateTileSize(Mutator): """Mutator that mutates the decision of instruction Sample-Perfect-Tile""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateTileSize + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.MutatorMutateTileSize, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/mutator/mutate_unroll.py b/python/tvm/meta_schedule/mutator/mutate_unroll.py index 97999c2888f8..2b9691d23cd5 100644 --- a/python/tvm/meta_schedule/mutator/mutate_unroll.py +++ b/python/tvm/meta_schedule/mutator/mutate_unroll.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Mutator that mutates auto unroll step""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. 
import _ffi_api @@ -25,6 +27,11 @@ class MutateUnroll(Mutator): """Mutator that mutates auto unroll step""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MutateUnroll + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.MutatorMutateUnroll, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/mutator/mutator.py b/python/tvm/meta_schedule/mutator/mutator.py index 211e2c2b5015..f29b03832005 100644 --- a/python/tvm/meta_schedule/mutator/mutator.py +++ b/python/tvm/meta_schedule/mutator/mutator.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Meta Schedule Mutator.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, Dict, Optional # isort: off @@ -120,6 +122,11 @@ class _PyMutator(Mutator): See also: PyMutator """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyMutator + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_initialize_with_tune_context: Callable = None, diff --git a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py index 5c18475ea0ca..f4c9624f396c 100644 --- a/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py +++ b/python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that checks if the IRModule has any strided memory copies""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. 
import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class DisallowAsyncStridedMemCopy(Postproc): """A postprocessor that disallows schedules that use async strided mem copies.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.DisallowAsyncStridedMemCopy + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py index da604e42cc81..2541a00153d6 100644 --- a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py +++ b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that checks if the IRModule has any loop with non-constant extent""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class DisallowDynamicLoop(Postproc): """A postprocessor that checks if the IRModule has any loop with non-constant extent""" + # tvm-ffi-stubgen(begin): object/meta_schedule.DisallowDynamicLoop + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocDisallowDynamicLoop, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/postproc.py b/python/tvm/meta_schedule/postproc/postproc.py index 8e89ad2fe138..a4089929e677 100644 --- a/python/tvm/meta_schedule/postproc/postproc.py +++ b/python/tvm/meta_schedule/postproc/postproc.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
"""Meta Schedule Postproc.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, List # isort: off @@ -37,6 +39,11 @@ class Postproc(Object): """Rules to apply a postprocessor to a schedule.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Postproc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the postprocessor with a tune context. @@ -114,6 +121,11 @@ class _PyPostproc(Postproc): See also: PyPostproc """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyPostproc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_initialize_with_tune_context: Callable = None, diff --git a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py index d20c22d0f6d8..5f681a7c9a93 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py +++ b/python/tvm/meta_schedule/postproc/rewrite_cooperative_fetch.py @@ -17,6 +17,8 @@ """A postprocessor that rewrites the cooperative fetch annotation to actual vectorized cooperative fetching in loop bindings.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -28,6 +30,11 @@ class RewriteCooperativeFetch(Postproc): cooperative fetching in loop bindings. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteCooperativeFetch + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteCooperativeFetch, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_layout.py b/python/tvm/meta_schedule/postproc/rewrite_layout.py index 73b6dde9f76a..f5e99fe435c2 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_layout.py +++ b/python/tvm/meta_schedule/postproc/rewrite_layout.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that rewrites the layout of input tensor""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -26,6 +28,11 @@ class RewriteLayout(Postproc): """A postprocessor that rewrites the layout of input tensor""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteLayout + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteLayout, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py index 30235517f9c6..3aa1c8c14ed7 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -17,6 +17,8 @@ """A postprocessor that applies parallelization, vectorization and auto unrolling according to the annotation of each block""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. 
import _ffi_api from .postproc import Postproc @@ -27,6 +29,11 @@ class RewriteParallelVectorizeUnroll(Postproc): """A postprocessor that applies parallelization, vectorization and auto unrolling according to the annotation of each block""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteParallelVectorizeUnroll + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteParallelVectorizeUnroll, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py index 5bbe2b88381e..b748e9adb77f 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_reduction_block.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that rewrites reduction block by moving the init block out.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class RewriteReductionBlock(Postproc): """A postprocessor that rewrites reduction block by moving the init block out.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteReductionBlock + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteReductionBlock, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py index 8f0edb869586..0a2d59fac9e0 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_tensorize.py +++ b/python/tvm/meta_schedule/postproc/rewrite_tensorize.py @@ -16,6 +16,8 @@ # under the License. 
"""A postprocessor that tensorize related components.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -31,6 +33,11 @@ class RewriteTensorize(Postproc): Whether or not vectorize the initialization loop produced by DecomposeReduction """ + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteTensorize + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, vectorize_init_loop=False) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteTensorize, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py index b274c2f55c11..4820afab989b 100644 --- a/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py +++ b/python/tvm/meta_schedule/postproc/rewrite_unbound_block.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that adds thread binding to unbound blocks""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api @@ -26,6 +28,11 @@ class RewriteUnboundBlock(Postproc): """A postprocessor that adds thread binding to unbound blocks""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RewriteUnboundBlock + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, max_threadblocks: int = 256) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocRewriteUnboundBlock, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py index 48fbe8f4b14c..c42f2a8ae44f 100644 --- a/python/tvm/meta_schedule/postproc/verify_gpu_code.py +++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py @@ -16,6 +16,8 @@ # under the License. 
"""A postprocessor that verifies if the GPU code is correct""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class VerifyGPUCode(Postproc): """A postprocessor that verifies if the GPU code is correct""" + # tvm-ffi-stubgen(begin): object/meta_schedule.VerifyGPUCode + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocVerifyGPUCode, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py index 96ece2270bbc..745c6f684d96 100644 --- a/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py +++ b/python/tvm/meta_schedule/postproc/verify_vtcm_limit.py @@ -16,6 +16,8 @@ # under the License. """A postprocessor that verifies the VTCM usage of a given schedule.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi.registry import register_object from .. import _ffi_api from .postproc import Postproc @@ -25,6 +27,11 @@ class VerifyVTCMLimit(Postproc): """Verifies that the VTCM usage of a given schedule is within the provided limit.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.VerifyVTCMLimit + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.PostprocVerifyVTCMLimit, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/profiler.py b/python/tvm/meta_schedule/profiler.py index 1a41f589de4c..cd5c6564414c 100644 --- a/python/tvm/meta_schedule/profiler.py +++ b/python/tvm/meta_schedule/profiler.py @@ -16,6 +16,8 @@ # under the License. 
# pylint: disable=used-before-assignment """A context manager that profiles tuning time cost for different parts.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from contextlib import contextmanager from typing import Dict, Optional @@ -29,6 +31,11 @@ class Profiler(Object): """Tuning time profiler.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Profiler + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.Profiler, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/runner/runner.py b/python/tvm/meta_schedule/runner/runner.py index 0d7cd32bd7a5..a829a24d8e44 100644 --- a/python/tvm/meta_schedule/runner/runner.py +++ b/python/tvm/meta_schedule/runner/runner.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Runners""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import FloatImm + from meta_schedule import ArgInfo +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable, List, Optional, Union # isort: off @@ -43,6 +55,14 @@ class RunnerInput(Object): The argument information. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.RunnerInput + # fmt: off + artifact_path: str + device_type: str + args_info: Sequence[ArgInfo] + # fmt: on + # tvm-ffi-stubgen(end) + artifact_path: str device_type: str args_info: List[ArgInfo] @@ -84,6 +104,13 @@ class RunnerResult(Object): The error message, if any. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RunnerResult + # fmt: off + run_secs: Sequence[FloatImm] | None + error_msg: str | None + # fmt: on + # tvm-ffi-stubgen(end) + run_secs: Optional[List[float]] error_msg: Optional[str] @@ -118,6 +145,11 @@ class RunnerFuture(Object): See also: PyRunnerFuture """ + # tvm-ffi-stubgen(begin): object/meta_schedule.RunnerFuture + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_done: Callable, f_result: Callable = None) -> None: """Constructor""" @@ -167,6 +199,11 @@ def result(self) -> RunnerResult: class Runner(Object): """The abstract runner interface""" + # tvm-ffi-stubgen(begin): object/meta_schedule.Runner + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + RunnerType = Union["Runner", Literal["local", "rpc"]] def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]: @@ -214,6 +251,11 @@ class _PyRunner(Runner): See also: PyRunner """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyRunner + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, f_run: Callable = None) -> None: """Constructor""" diff --git a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py index 2bef40fffe74..3812d6f0290f 100644 --- a/python/tvm/meta_schedule/schedule_rule/add_rfactor.py +++ b/python/tvm/meta_schedule/schedule_rule/add_rfactor.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Add-rfactor Rule that add-rfactor to some blocks if needed""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Optional from tvm_ffi import register_object @@ -37,6 +39,13 @@ class AddRFactor(ScheduleRule): The maximum size of the innermost factor. None means no limit. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.AddRFactor + # fmt: off + max_jobs_per_core: int + max_innermost_factor: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, max_jobs_per_core: int = 16, diff --git a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py index 2e383c75eb91..fde1d249b777 100644 --- a/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/apply_custom_rule.py @@ -16,6 +16,16 @@ # under the License. """Create a rule that applies customized rules registered using block attribute `schedule_rule`. The rule will be dispatched according to target keys.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -27,6 +37,12 @@ class ApplyCustomRule(ScheduleRule): """A rule that applies customized rules registered using block attribute `schedule_rule`. The rule will be dispatched according to target keys.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.ApplyCustomRule + # fmt: off + target_: Target | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.ScheduleRuleApplyCustomRule, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/schedule_rule/auto_bind.py b/python/tvm/meta_schedule/schedule_rule/auto_bind.py index 0704b03f740f..05d794a249f4 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_bind.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_bind.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
"""Auto-bind Rule that binds blocks to threads if needed""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import List, Optional from tvm_ffi import register_object @@ -37,6 +39,11 @@ class AutoBind(ScheduleRule): The maximum number of threads per block, if it is known when this schedule rule is created. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.AutoBind + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, max_threadblocks: int = 256, diff --git a/python/tvm/meta_schedule/schedule_rule/auto_inline.py b/python/tvm/meta_schedule/schedule_rule/auto_inline.py index b789dd750707..cb0f2d42a6a2 100644 --- a/python/tvm/meta_schedule/schedule_rule/auto_inline.py +++ b/python/tvm/meta_schedule/schedule_rule/auto_inline.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Auto-Inline. Rule that inlines spatial blocks if it satisfies some conditions""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import Op +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Optional from tvm_ffi import register_object @@ -45,6 +56,18 @@ class AutoInline(ScheduleRule): The operators that are disallowed in auto inline """ + # tvm-ffi-stubgen(begin): object/meta_schedule.AutoInline + # fmt: off + into_producer: bool + into_consumer: bool + inline_const_tensor: bool + disallow_if_then_else: bool + require_injective: bool + require_ordered: bool + disallow_op: Sequence[Op] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, into_producer: bool, @@ -76,6 +99,11 @@ class InlineConstantScalars(ScheduleRule): InlineConstantScalars before AutoInline. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.InlineConstantScalars + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, ) -> None: diff --git a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py index 0c79d4f08bac..7370977652e9 100644 --- a/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py +++ b/python/tvm/meta_schedule/schedule_rule/cross_thread_reduction.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Rules which apply cross-thread reduction to some reduction blocks correspondingly when needed""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List from tvm_ffi import register_object @@ -34,6 +45,14 @@ class CrossThreadReduction(ScheduleRule): Candidates of thread axis extent (values are required to be positive). """ + # tvm-ffi-stubgen(begin): object/meta_schedule.CrossThreadReduction + # fmt: off + max_threads_per_block: int + warp_size: int + thread_extents: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, thread_extents: List[int]) -> None: self.__init_handle_by_constructor__( _ffi_api.ScheduleRuleCrossThreadReduction, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 41c97a7862b4..c0fe81eb456f 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. 
"""Multi-level tiling with reuse.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Callable from tvm.tir.schedule import Schedule, BlockRV @@ -70,6 +80,14 @@ class MultiLevelTiling(ScheduleRule): for a block that should be tiled (based on the block name, for example). """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MultiLevelTiling + # fmt: off + structure: str + tile_binds: Sequence[str] + max_innermost_factor: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, structure: str, @@ -119,6 +137,11 @@ class MultiLevelTilingWithIntrin(ScheduleRule): Data reuse configuration for writing. None means no reuse. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MultiLevelTilingWithIntrin + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, intrin_name: str, @@ -173,6 +196,11 @@ class MultiLevelTilingTensorCore(ScheduleRule): Whether to use the software pipeline. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MultiLevelTilingTensorCore + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, intrin_groups: List[Mapping[str, str]], @@ -216,6 +244,11 @@ class MultiLevelTilingWideVector(ScheduleRule): Data reuse configuration for writing. None means no reuse. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.MultiLevelTilingWideVector + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, structure: str, diff --git a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py index 259620b3f715..7c459173e8c6 100644 --- a/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py +++ b/python/tvm/meta_schedule/schedule_rule/parallel_vectorize_unroll.py @@ -16,6 +16,17 @@ # under the License. """Rule that mark parallelize, vectorize and unroll to the root block. The mark will be applied to each block in a follow-up post processor""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Optional from tvm_ffi import register_object @@ -46,6 +57,15 @@ class ParallelizeVectorizeUnroll(ScheduleRule): Whether to explicitly unroll the loop, or just add an "unroll" pragma """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ParallelizeVectorizeUnroll + # fmt: off + max_jobs_per_core: int + max_vectorize_extent: int + unroll_max_steps: Sequence[IntImm] + unroll_explicit: bool + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, max_jobs_per_core: int = 16, diff --git a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py index 8f1c96f6eb0a..d9916db594a5 100644 --- a/python/tvm/meta_schedule/schedule_rule/random_compute_location.py +++ b/python/tvm/meta_schedule/schedule_rule/random_compute_location.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
"""Rule that randomly select a compute-at location for a free block""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -25,6 +27,11 @@ class RandomComputeLocation(ScheduleRule): """A rule that randomly select a compute-at location for a free block""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RandomComputeLocation + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__( _ffi_api.ScheduleRuleRandomComputeLocation, # type: ignore # pylint: disable=no-member diff --git a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py index 98c81e5b8f30..39903ed4969f 100644 --- a/python/tvm/meta_schedule/schedule_rule/schedule_rule.py +++ b/python/tvm/meta_schedule/schedule_rule/schedule_rule.py @@ -18,6 +18,8 @@ Meta Schedule schedule rules are used for modification of blocks in a schedule. See also PostOrderApply. """ +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, List # isort: off @@ -40,6 +42,11 @@ class ScheduleRule(Object): """Rules to modify a block in a schedule.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.ScheduleRule + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def _initialize_with_tune_context(self, context: "TuneContext") -> None: """Initialize the schedule rule with a tune context. 
@@ -121,6 +128,11 @@ class _PyScheduleRule(ScheduleRule): See also: PyScheduleRule """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyScheduleRule + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_initialize_with_tune_context: Callable = None, diff --git a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py index 04f9310e6e0d..c816c2a5a690 100644 --- a/python/tvm/meta_schedule/search_strategy/evolutionary_search.py +++ b/python/tvm/meta_schedule/search_strategy/evolutionary_search.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Evolutionary Search Strategy""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -47,6 +49,20 @@ class EvolutionarySearch(SearchStrategy): The ratio of greedy selected samples in the final picks. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.EvolutionarySearch + # fmt: off + population_size: int + num_empty_iters_before_early_stop: int + init_measured_ratio: float + init_min_unmeasured: int + max_fail_count: int + genetic_num_iters: int + genetic_mutate_prob: float + genetic_max_fail_count: int + eps_greedy: float + # fmt: on + # tvm-ffi-stubgen(end) + population_size: int init_measured_ratio: int init_min_unmeasured: int diff --git a/python/tvm/meta_schedule/search_strategy/replay_func.py b/python/tvm/meta_schedule/search_strategy/replay_func.py index 682c9638c513..0ced919f671b 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_func.py +++ b/python/tvm/meta_schedule/search_strategy/replay_func.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. 
import _ffi_api @@ -35,6 +37,11 @@ class ReplayFunc(SearchStrategy): Total number of trials for one task """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ReplayFunc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self): """Constructor""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/search_strategy/replay_trace.py b/python/tvm/meta_schedule/search_strategy/replay_trace.py index e04a440da68a..a29643d1d26b 100644 --- a/python/tvm/meta_schedule/search_strategy/replay_trace.py +++ b/python/tvm/meta_schedule/search_strategy/replay_trace.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Replay Trace Search Strategy""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -33,6 +35,12 @@ class ReplayTrace(SearchStrategy): Max number of failures during trace replaying. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ReplayTrace + # fmt: off + max_fail_count: int + # fmt: on + # tvm-ffi-stubgen(end) + max_fail_count: int def __init__(self, max_fail_count: int = 100): diff --git a/python/tvm/meta_schedule/search_strategy/search_strategy.py b/python/tvm/meta_schedule/search_strategy/search_strategy.py index cfb45dafdeb2..958bb98df17e 100644 --- a/python/tvm/meta_schedule/search_strategy/search_strategy.py +++ b/python/tvm/meta_schedule/search_strategy/search_strategy.py @@ -18,6 +18,18 @@ Meta Schedule search strategy that generates the measure candidates for measurement. 
""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import ArgInfo + from tir import Schedule +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, List, Optional, Union # isort: off @@ -50,6 +62,13 @@ class MeasureCandidate(Object): The argument information. """ + # tvm-ffi-stubgen(begin): object/meta_schedule.MeasureCandidate + # fmt: off + sch: Schedule + args_info: Sequence[ArgInfo] + # fmt: on + # tvm-ffi-stubgen(end) + sch: Schedule args_info: List[ArgInfo] @@ -78,6 +97,11 @@ def __init__( class SearchStrategy(Object): """Search strategy is the class that generates the measure candidates.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.SearchStrategy + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + SearchStrategyType = Union[ "SearchStrategy", Literal[ @@ -228,6 +252,11 @@ class _PySearchStrategy(SearchStrategy): See also: PySearchStrategy """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PySearchStrategy + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_initialize_with_tune_context: Callable = None, diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 45b81bdf3e59..9ab087270e98 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Post Order Apply Space Generator.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -42,6 +44,11 @@ class PostOrderApply(SpaceGenerator): all blocks will have schedules generated. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.PostOrderApply + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_block_filter=None, diff --git a/python/tvm/meta_schedule/space_generator/schedule_fn.py b/python/tvm/meta_schedule/space_generator/schedule_fn.py index d01cd7fdcbd1..85dbcbaf0d0a 100644 --- a/python/tvm/meta_schedule/space_generator/schedule_fn.py +++ b/python/tvm/meta_schedule/space_generator/schedule_fn.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Union of meta Schedule design space generators.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -36,6 +38,11 @@ class ScheduleFn(SpaceGenerator): - 3) [Schedule] -> List[Schedule] """ + # tvm-ffi-stubgen(begin): object/meta_schedule.ScheduleFn + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, sch_fn: SpaceGenerator.ScheduleFnType, diff --git a/python/tvm/meta_schedule/space_generator/space_generator.py b/python/tvm/meta_schedule/space_generator/space_generator.py index 35f9e2236764..852b26a32ee2 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator.py +++ b/python/tvm/meta_schedule/space_generator/space_generator.py @@ -18,6 +18,18 @@ Meta Schedule design space generators that generates design space for generation of measure candidates. 
""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import FloatImm + from meta_schedule import Mutator, Postproc, ScheduleRule +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union # isort: off @@ -42,6 +54,14 @@ class SpaceGenerator(Object): """The abstract design space generator interface.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.SpaceGenerator + # fmt: off + sch_rules: Sequence[ScheduleRule] | None + postprocs: Sequence[Postproc] | None + mutator_probs: Mapping[Mutator, FloatImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + ScheduleFnType = Union[ Callable[[Schedule], None], # No output Callable[[Schedule], Schedule], # Single output @@ -193,6 +213,11 @@ class _PySpaceGenerator(SpaceGenerator): See also: PySpaceGenerator """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PySpaceGenerator + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, sch_rules: ScheduleRuleType = "from-target", diff --git a/python/tvm/meta_schedule/space_generator/space_generator_union.py b/python/tvm/meta_schedule/space_generator/space_generator_union.py index 0b8ceb453116..24d4c7ceb612 100644 --- a/python/tvm/meta_schedule/space_generator/space_generator_union.py +++ b/python/tvm/meta_schedule/space_generator/space_generator_union.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. 
"""Union of meta Schedule design space generators.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import SpaceGenerator +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List from tvm_ffi import register_object @@ -33,6 +44,12 @@ class SpaceGeneratorUnion(SpaceGenerator): """Union of design space generators.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.SpaceGeneratorUnion + # fmt: off + space_generators: Sequence[SpaceGenerator] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, space_generators: List[SpaceGenerator], diff --git a/python/tvm/meta_schedule/task_scheduler/gradient_based.py b/python/tvm/meta_schedule/task_scheduler/gradient_based.py index 18d7e2be614a..008a4dfffb87 100644 --- a/python/tvm/meta_schedule/task_scheduler/gradient_based.py +++ b/python/tvm/meta_schedule/task_scheduler/gradient_based.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Gradient Based Task Scheduler""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -28,6 +30,13 @@ class GradientBased(TaskScheduler): """Gradient Based Task Scheduler""" + # tvm-ffi-stubgen(begin): object/meta_schedule.GradientBased + # fmt: off + alpha: float + window_size: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, *, diff --git a/python/tvm/meta_schedule/task_scheduler/round_robin.py b/python/tvm/meta_schedule/task_scheduler/round_robin.py index 78504608f9ab..96ee6824425d 100644 --- a/python/tvm/meta_schedule/task_scheduler/round_robin.py +++ b/python/tvm/meta_schedule/task_scheduler/round_robin.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
"""Round Robin Task Scheduler""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from .. import _ffi_api @@ -28,6 +30,12 @@ class RoundRobin(TaskScheduler): """Round Robin Task Scheduler""" + # tvm-ffi-stubgen(begin): object/meta_schedule.RoundRobin + # fmt: off + task_id: int + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: """Constructor.""" self.__init_handle_by_constructor__( diff --git a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py index 4513f6081560..56f78c6b4ed8 100644 --- a/python/tvm/meta_schedule/task_scheduler/task_scheduler.py +++ b/python/tvm/meta_schedule/task_scheduler/task_scheduler.py @@ -15,6 +15,17 @@ # specific language governing permissions and limitations # under the License. """Auto-tuning Task Scheduler""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from meta_schedule import BuilderResult, CostModel, Database, MeasureCallback, MeasureCandidate, RunnerFuture, TuneContext +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable, List, Optional, Union # isort: off @@ -42,6 +53,20 @@ class TaskRecord(Object): """The running record of a task.""" + # tvm-ffi-stubgen(begin): object/meta_schedule.TaskRecord + # fmt: off + ctx: TuneContext + task_weight: float + flop: float + is_terminated: bool + build_error_count: int + run_error_count: int + measure_candidates: Sequence[MeasureCandidate] | None + builder_results: Sequence[BuilderResult] | None + runner_futures: Sequence[RunnerFuture] | None + # fmt: on + # tvm-ffi-stubgen(end) + ctx: TuneContext task_weight: float flop: float @@ -57,6 +82,16 @@ class TaskRecord(Object): class TaskScheduler(Object): """The abstract task scheduler interface.""" + # tvm-ffi-stubgen(begin): 
object/meta_schedule.TaskScheduler + # fmt: off + tasks_: Sequence[TaskRecord] + measure_callbacks_: Sequence[MeasureCallback] + database_: Database | None + cost_model_: CostModel | None + remaining_tasks_: int + # fmt: on + # tvm-ffi-stubgen(end) + tasks_: List[TaskRecord] measure_callbacks_: List[MeasureCallback] database_: Optional[Database] @@ -198,6 +233,11 @@ class _PyTaskScheduler(TaskScheduler): See also: PyTaskScheduler """ + # tvm-ffi-stubgen(begin): object/meta_schedule.PyTaskScheduler + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_next_task_id: Callable, diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 35a8d468a75c..a6816445aaa0 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -16,6 +16,18 @@ # under the License. """Meta Schedule tuning context.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import IRModule + from meta_schedule import SearchStrategy, SpaceGenerator + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import TYPE_CHECKING, List, Optional, Union # isort: off @@ -84,6 +96,18 @@ class TuneContext(Object): The number of threads to be used, None means using the logical cpu count. 
""" + # tvm-ffi-stubgen(begin): object/meta_schedule.TuneContext + # fmt: off + mod: IRModule | None + target: Target | None + space_generator: SpaceGenerator | None + search_strategy: SearchStrategy | None + task_name: str | None + num_threads: int + rand_state: int + # fmt: on + # tvm-ffi-stubgen(end) + mod: Optional[IRModule] target: Optional[Target] space_generator: Optional["SpaceGenerator"] diff --git a/python/tvm/relax/_ffi_api.py b/python/tvm/relax/_ffi_api.py index c5e98a22eaaf..89df5defbe5d 100644 --- a/python/tvm/relax/_ffi_api.py +++ b/python/tvm/relax/_ffi_api.py @@ -15,6 +15,143 @@ # specific language governing permissions and limitations # under the License. """FFI API for Relax.""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from expr_functor import PyExprMutator, PyExprVisitor + from ir import Attrs, BaseFunc, DictAttrs, EnvFunc, GlobalVar, IRModule, IntImm, PrimExpr, RelaxExpr, Span, StructInfo, VDevice + from relax import BlockBuilder, DataflowBlockRewrite, DynTensorType, ExecBuilder, FuncStructInfo, Id, ObjectStructInfo, ObjectType, PackedFuncType, PrimStructInfo, ShapeStructInfo, ShapeType, TensorStructInfo, TupleStructInfo + from relax.expr import Binding, BindingBlock, Call, Constant, DataTypeImm, DataflowBlock, DataflowVar, ExternFunc, Function, If, MatchCast, PrimValue, SeqExpr, ShapeExpr, StringImm, Tuple, TupleGetItem, Var, VarBinding + from target import Target + from te import Tensor + from tir import Var + from tvm_ffi import Module, Object, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax", __name__) + +# tvm-ffi-stubgen(begin): global/relax +# fmt: off +_FFI_INIT_FUNC("relax", __name__) +if TYPE_CHECKING: + def BindingBlock(_0: 
Sequence[Binding], _1: Span, /) -> BindingBlock: ... + def BlockBuilderAddFunction(_0: BlockBuilder, _1: BaseFunc, _2: str, /) -> GlobalVar: ... + def BlockBuilderBeginBindingBlock(_0: BlockBuilder, /) -> None: ... + def BlockBuilderBeginDataflowBlock(_0: BlockBuilder, /) -> None: ... + def BlockBuilderBeginScope(_0: BlockBuilder, _1: Sequence[Var] | None, /) -> None: ... + def BlockBuilderCreate(_0: IRModule | None, /) -> BlockBuilder: ... + def BlockBuilderCurrentBlockIsDataFlow(_0: BlockBuilder, /) -> bool: ... + def BlockBuilderEmit(_0: BlockBuilder, _1: RelaxExpr, _2: str, /) -> Var: ... + def BlockBuilderEmitMatchCast(_0: BlockBuilder, _1: RelaxExpr, _2: StructInfo, _3: str, /) -> Var: ... + def BlockBuilderEmitNormalized(_0: BlockBuilder, _1: Binding, /) -> None: ... + def BlockBuilderEmitOutput(_0: BlockBuilder, _1: RelaxExpr, _2: str, /) -> Var: ... + def BlockBuilderEndBlock(_0: BlockBuilder, /) -> BindingBlock: ... + def BlockBuilderEndScope(_0: BlockBuilder, /) -> None: ... + def BlockBuilderFinalize(_0: BlockBuilder, /) -> IRModule: ... + def BlockBuilderGetContextIRModule(_0: BlockBuilder, /) -> IRModule: ... + def BlockBuilderGetUniqueName(_0: BlockBuilder, _1: str, /) -> str: ... + def BlockBuilderLookupBinding(_0: BlockBuilder, _1: Var, /) -> RelaxExpr | None: ... + def BlockBuilderNormalize(_0: BlockBuilder, _1: RelaxExpr, /) -> RelaxExpr: ... + def BlockBuilderUpdateFunction(_0: BlockBuilder, _1: GlobalVar, _2: BaseFunc, /) -> None: ... + def Call(_0: RelaxExpr, _1: Sequence[RelaxExpr], _2: Attrs, _3: Sequence[StructInfo], _4: Span, /) -> Call: ... + def Constant(_0: Tensor, _1: StructInfo | None, _2: Span, /) -> Constant: ... + def CopyWithNewVars(_0: Function, /) -> Function: ... + def DataTypeImm(_0: dtype, _1: Span, /) -> DataTypeImm: ... + def DataflowBlock(_0: Sequence[Binding], _1: Span, /) -> DataflowBlock: ... + def DataflowBlockRewrite(_0: DataflowBlock, _1: Function, /) -> DataflowBlockRewrite: ... 
+ def DataflowVar(_0: str, _1: StructInfo | None, _2: Span, /) -> DataflowVar: ... + def DataflowVarFromId(_0: Id, _1: StructInfo | None, _2: Span, /) -> DataflowVar: ... + def ExecBuilderC(_0: ExecBuilder, _1: int, /) -> int: ... + def ExecBuilderConvertConstant(*args: Any) -> Any: ... + def ExecBuilderCreate() -> ExecBuilder: ... + def ExecBuilderDeclareFunction(_0: ExecBuilder, _1: str, _2: int, /) -> None: ... + def ExecBuilderEmitCall(_0: ExecBuilder, _1: str, _2: Sequence[IntImm], _3: int, /) -> None: ... + def ExecBuilderEmitFunction(_0: ExecBuilder, _1: str, _2: int, _3: Sequence[str] | None, /) -> None: ... + def ExecBuilderEmitGoto(_0: ExecBuilder, _1: int, /) -> None: ... + def ExecBuilderEmitIf(_0: ExecBuilder, _1: int, _2: int, /) -> None: ... + def ExecBuilderEmitRet(_0: ExecBuilder, _1: int, /) -> None: ... + def ExecBuilderEndFunction(_0: ExecBuilder, _1: str, /) -> None: ... + def ExecBuilderF(_0: ExecBuilder, _1: str, /) -> int: ... + def ExecBuilderGet(_0: ExecBuilder, /) -> Module: ... + def ExecBuilderImm(_0: ExecBuilder, _1: int, /) -> int: ... + def ExecBuilderR(_0: ExecBuilder, _1: int, /) -> int: ... + def ExecutableLoadFromFile(_0: str, /) -> Module: ... + def ExprMutatorVisitBinding(_0: PyExprMutator, _1: Binding, /) -> None: ... + def ExprMutatorVisitBindingBlock(_0: PyExprMutator, _1: BindingBlock, /) -> BindingBlock: ... + def ExprMutatorVisitExpr(_0: PyExprMutator, _1: RelaxExpr, /) -> RelaxExpr: ... + def ExprMutatorVisitVarDef(_0: PyExprMutator, _1: Var, /) -> Var: ... + def ExprVisitorVisitBinding(_0: PyExprVisitor, _1: Binding, /) -> None: ... + def ExprVisitorVisitBindingBlock(_0: PyExprVisitor, _1: BindingBlock, /) -> None: ... + def ExprVisitorVisitExpr(_0: PyExprVisitor, _1: RelaxExpr, /) -> None: ... + def ExprVisitorVisitSpan(_0: PyExprVisitor, _1: Span, /) -> None: ... + def ExprVisitorVisitVarDef(_0: PyExprVisitor, _1: Var, /) -> None: ... + def ExternFunc(_0: str, _1: StructInfo | None, _2: Span, /) -> ExternFunc: ... 
+ def FuncStructInfo(_0: Sequence[StructInfo], _1: StructInfo, _2: bool, _3: Span, /) -> FuncStructInfo: ... + def FuncStructInfoOpaqueFunc(_0: StructInfo | None, _1: EnvFunc | None, _2: bool, _3: Span, /) -> FuncStructInfo: ... + def FuncWithAttr(_0: BaseFunc, _1: str, _2: Object, /) -> Function | None: ... + def FuncWithAttrs(_0: BaseFunc, _1: Mapping[str, Any], /) -> Function | None: ... + def FuncWithoutAttr(_0: BaseFunc, _1: str, /) -> Function | None: ... + def Function(_0: Sequence[Var], _1: RelaxExpr, _2: StructInfo | None, _3: bool, _4: DictAttrs, _5: Span, /) -> Function: ... + def FunctionBindParams(_0: Function, _1: Mapping[Any, Object], /) -> Function: ... + def FunctionBindSymbolicVars(_0: Function, _1: Mapping[Var | str, PrimExpr], /) -> Function: ... + def FunctionCreateEmpty(_0: Sequence[Var], _1: StructInfo, _2: bool, _3: DictAttrs, _4: Span, /) -> Function: ... + def FunctionInlineFunctions(_0: Function, _1: Mapping[str | GlobalVar, Function], /) -> Function: ... + def GetShapeOf(_0: RelaxExpr, /) -> RelaxExpr: ... + def If(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: Span, /) -> If: ... + def MakePyExprMutator(_0: BlockBuilder, _1: Callable[..., Any], _2: Callable[..., Any], _3: Callable[..., Any], _4: Callable[..., Any], _5: Callable[..., Any], _6: Callable[..., Any], _7: Callable[..., Any], _8: Callable[..., Any], _9: Callable[..., Any], _10: Callable[..., Any], _11: Callable[..., Any], _12: Callable[..., Any], _13: Callable[..., Any], _14: Callable[..., Any], _15: Callable[..., Any], _16: Callable[..., Any], _17: Callable[..., Any], _18: Callable[..., Any], _19: Callable[..., Any], _20: Callable[..., Any], _21: Callable[..., Any], _22: Callable[..., Any], _23: Callable[..., Any], _24: Callable[..., Any], _25: Callable[..., Any], _26: Callable[..., Any], _27: Callable[..., Any], /) -> PyExprMutator: ... 
+ def MakePyExprVisitor(_0: Callable[..., Any], _1: Callable[..., Any], _2: Callable[..., Any], _3: Callable[..., Any], _4: Callable[..., Any], _5: Callable[..., Any], _6: Callable[..., Any], _7: Callable[..., Any], _8: Callable[..., Any], _9: Callable[..., Any], _10: Callable[..., Any], _11: Callable[..., Any], _12: Callable[..., Any], _13: Callable[..., Any], _14: Callable[..., Any], _15: Callable[..., Any], _16: Callable[..., Any], _17: Callable[..., Any], _18: Callable[..., Any], _19: Callable[..., Any], _20: Callable[..., Any], _21: Callable[..., Any], _22: Callable[..., Any], _23: Callable[..., Any], _24: Callable[..., Any], _25: Callable[..., Any], _26: Callable[..., Any], /) -> PyExprVisitor: ... + def MatchCast(_0: Var, _1: RelaxExpr, _2: StructInfo, _3: Span, /) -> MatchCast: ... + def ObjectStructInfo(_0: Span, /) -> ObjectStructInfo: ... + def ObjectType(_0: Span, /) -> ObjectType: ... + def PackedFuncType(_0: Span, /) -> PackedFuncType: ... + def PrimStructInfoFromDtype(_0: dtype, _1: Span, /) -> PrimStructInfo: ... + def PrimStructInfoFromValue(_0: PrimExpr, _1: Span, /) -> PrimStructInfo: ... + def PrimValue(_0: PrimExpr, _1: Span, /) -> PrimValue: ... + def PyExprMutatorGetVarRemap(_0: PyExprMutator, _1: Id, /) -> Var: ... + def PyExprMutatorLookupBinding(_0: PyExprMutator, _1: Var, /) -> RelaxExpr | None: ... + def PyExprMutatorSetVarRemap(_0: PyExprMutator, _1: Id, _2: Var, /) -> Var: ... + def PyExprMutatorVisitBinding(_0: PyExprMutator, _1: Binding, /) -> None: ... + def PyExprMutatorVisitBindingBlock(_0: PyExprMutator, _1: BindingBlock, /) -> BindingBlock: ... + def PyExprMutatorVisitExpr(_0: PyExprMutator, _1: RelaxExpr, /) -> RelaxExpr: ... + def PyExprMutatorVisitExprPostOrder(_0: PyExprMutator, _1: RelaxExpr, /) -> RelaxExpr: ... + def PyExprMutatorVisitVarDef(_0: PyExprMutator, _1: Var, /) -> Var: ... + def PyExprMutatorVisitWithNewScope(_0: PyExprMutator, _1: RelaxExpr, /) -> RelaxExpr: ... 
+ def PyExprMutatorWithStructInfo(_0: PyExprMutator, _1: Var, _2: StructInfo, /) -> Var: ... + def PyExprVisitorVisitBinding(_0: PyExprVisitor, _1: Binding, /) -> None: ... + def PyExprVisitorVisitBindingBlock(_0: PyExprVisitor, _1: BindingBlock, /) -> None: ... + def PyExprVisitorVisitExpr(_0: PyExprVisitor, _1: RelaxExpr, /) -> None: ... + def PyExprVisitorVisitVarDef(_0: PyExprVisitor, _1: Var, /) -> None: ... + def SeqExpr(_0: Sequence[BindingBlock], _1: RelaxExpr, _2: Span, /) -> SeqExpr: ... + def ShapeExpr(_0: Sequence[PrimExpr], _1: Span, /) -> ShapeExpr: ... + def ShapeStructInfo(_0: Sequence[PrimExpr] | None, _1: int, _2: Span, /) -> ShapeStructInfo: ... + def ShapeType(_0: int, _1: Span, /) -> ShapeType: ... + def StringImm(_0: str, _1: Span, /) -> StringImm: ... + def StructInfoIsBaseOf(_0: StructInfo, _1: StructInfo, /) -> bool: ... + def TETensor(_0: RelaxExpr, _1: Mapping[Var, PrimExpr], _2: str, /) -> Tensor: ... + def TensorStructInfo(_0: RelaxExpr | None, _1: dtype | None, _2: int, _3: VDevice, _4: Span, /) -> TensorStructInfo: ... + def TensorType(_0: int, _1: dtype, _2: Span, /) -> DynTensorType: ... + def Tuple(_0: Sequence[RelaxExpr], _1: Span, /) -> Tuple: ... + def TupleGetItem(_0: RelaxExpr, _1: int, _2: Span, /) -> TupleGetItem: ... + def TupleStructInfo(_0: Sequence[StructInfo], _1: Span, /) -> TupleStructInfo: ... + def UpdateStructInfo(_0: RelaxExpr, _1: StructInfo, /) -> None: ... + def VMCodeGen(_0: ExecBuilder, _1: IRModule, /) -> IRModule: ... + def VMLink(_0: ExecBuilder, _1: Target, _2: Module | None, _3: Sequence[Module], _4: Mapping[str, Tensor], /) -> Module: ... + def VMTIRCodeGen(_0: ExecBuilder, _1: IRModule, /) -> IRModule: ... + def Var(_0: str, _1: StructInfo | None, _2: Span, /) -> Var: ... + def VarBinding(_0: Var, _1: RelaxExpr, _2: Span, /) -> VarBinding: ... + def VarFromId(_0: Id, _1: StructInfo | None, _2: Span, /) -> Var: ... 
+ def dfb_rewrite_add(_0: DataflowBlockRewrite, _1: RelaxExpr, _2: str | None, _3: bool, /) -> None: ... + def dfb_rewrite_add_binding(_0: DataflowBlockRewrite, _1: Binding, /) -> None: ... + def dfb_rewrite_mutate_irmodule(_0: DataflowBlockRewrite, _1: IRModule, /) -> IRModule: ... + def dfb_rewrite_remove_all_unused(_0: DataflowBlockRewrite, /) -> None: ... + def dfb_rewrite_remove_unused(_0: DataflowBlockRewrite, _1: Var, _2: bool, /) -> None: ... + def dfb_rewrite_replace_all_uses(_0: DataflowBlockRewrite, _1: Var, _2: Var, /) -> None: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/analysis/_ffi_api.py b/python/tvm/relax/analysis/_ffi_api.py index 0a230fbd8bb6..3c570e4b3fde 100644 --- a/python/tvm/relax/analysis/_ffi_api.py +++ b/python/tvm/relax/analysis/_ffi_api.py @@ -14,6 +14,53 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import GlobalVar, IRModule, PrimExpr, RelaxExpr, StructInfo, Type + from relax import BlockBuilder, FuncStructInfo + from relax.expr import Binding, Call, DataflowBlock, Function, Var + from tir import Block, IndexMap, PrimFunc, Var + from tvm_ffi import Object + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.analysis", __name__) + +# tvm-ffi-stubgen(begin): global/relax.analysis +# fmt: off +_FFI_INIT_FUNC("relax.analysis", __name__) +if TYPE_CHECKING: + def CollectNonNegativeExpressions(_0: StructInfo, /) -> Sequence[PrimExpr]: ... + def DefinableTIRVarsInStructInfo(_0: StructInfo, /) -> Sequence[Var]: ... + def DefinedSymbolicVars(_0: RelaxExpr, /) -> Sequence[Var]: ... 
+ def DeriveCallRetStructInfo(_0: FuncStructInfo, _1: Call, _2: BlockBuilder, /) -> StructInfo: ... + def EraseToWellDefined(_0: StructInfo, _1: Mapping[Var, PrimExpr], _2: Mapping[Var, RelaxExpr], /) -> StructInfo: ... + def FreeSymbolicVars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def GetStaticType(_0: StructInfo, /) -> Type: ... + def StructInfoBaseCheck(_0: StructInfo, _1: StructInfo, /) -> int: ... + def StructInfoLCA(_0: StructInfo, _1: StructInfo, /) -> StructInfo: ... + def TIRVarsInStructInfo(_0: StructInfo, /) -> Sequence[Var]: ... + def all_global_vars(_0: RelaxExpr, /) -> Sequence[GlobalVar]: ... + def all_vars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def bound_vars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def computable_at_compile_time(_0: Function, /) -> Sequence[Var]: ... + def contains_impure_call(_0: RelaxExpr, _1: RelaxExpr | None, /) -> bool: ... + def detect_recursion(_0: IRModule, /) -> Sequence[Sequence[GlobalVar]]: ... + def free_vars(_0: RelaxExpr, /) -> Sequence[Var]: ... + def get_var2val(_0: Function, /) -> Mapping[Var, RelaxExpr]: ... + def has_reshape_pattern(_0: PrimFunc, /) -> bool: ... + def name_to_binding(_0: Function, /) -> Mapping[str, Sequence[Binding]]: ... + def post_order_visit(_0: RelaxExpr, _1: Callable[..., Any], /) -> None: ... + def remove_all_unused(_0: RelaxExpr, /) -> RelaxExpr: ... + def suggest_layout_transforms(_0: PrimFunc, _1: Sequence[IndexMap], /) -> Mapping[Block, Mapping[Object, IndexMap]]: ... + def udchain(_0: DataflowBlock, /) -> Mapping[Var, Sequence[Var]]: ... + def well_formed(_0: IRModule | Function, _1: bool, /) -> bool: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/backend/_ffi_api.py b/python/tvm/relax/backend/_ffi_api.py index 97a999788b93..6fcf84e3e644 100644 --- a/python/tvm/relax/backend/_ffi_api.py +++ b/python/tvm/relax/backend/_ffi_api.py @@ -15,7 +15,31 @@ # specific language governing permissions and limitations # under the License. 
"""FFI API for Relax backend.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IRModule + from meta_schedule import ExtractedTask + from relax.transform import FusionPattern + from target import Target +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -import tvm_ffi -tvm_ffi.init_ffi_api("relax.backend", __name__) +# tvm-ffi-stubgen(begin): global/relax.backend +# fmt: off +_FFI_INIT_FUNC("relax.backend", __name__) +if TYPE_CHECKING: + def GetPattern(_0: str, /) -> FusionPattern | None: ... + def GetPatternsWithPrefix(_0: str, /) -> Sequence[FusionPattern]: ... + def MetaScheduleExtractTask(_0: IRModule, _1: Target, _2: str, /) -> Sequence[ExtractedTask]: ... + def RegisterPatterns(_0: Sequence[FusionPattern], /) -> None: ... + def RemovePatterns(_0: Sequence[str], /) -> None: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/backend/adreno/transform/_ffi_api.py b/python/tvm/relax/backend/adreno/transform/_ffi_api.py index d665ba02a70e..1cd2b63d8f3b 100644 --- a/python/tvm/relax/backend/adreno/transform/_ffi_api.py +++ b/python/tvm/relax/backend/adreno/transform/_ffi_api.py @@ -14,6 +14,25 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations """FFI APIs for Adreno transform""" -import tvm.ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from target import Target + from transform import Pass +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm.ffi.init_ffi_api("relax.backend.adreno.transform", __name__) + +# tvm-ffi-stubgen(begin): global/relax.backend.adreno.transform +# fmt: off +_FFI_INIT_FUNC("relax.backend.adreno.transform", __name__) +if TYPE_CHECKING: + def AnnotateCustomMemoryScope(_0: Target, /) -> Pass: ... + def FoldVDeviceScopeChange() -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py index 077f8feebb90..6ee11c9a503b 100644 --- a/python/tvm/relax/binding_rewrite.py +++ b/python/tvm/relax/binding_rewrite.py @@ -17,6 +17,16 @@ # pylint: disable=no-else-return, invalid-name """Developer API of add/remove/replace bindings in Relax.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from relax.expr import DataflowBlock, Function +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Optional import tvm @@ -39,6 +49,13 @@ class DataflowBlockRewrite(Object): use mutate_irmodule which rewrites the old function that registered in the constructor. """ + # tvm-ffi-stubgen(begin): object/relax.DataflowBlockRewrite + # fmt: off + dfb: DataflowBlock + root_fn: Function | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, dfb: DataflowBlock, root_fn: Function): """ Construct a rewriter with the DataflowBlock to rewrite and its root function. 
diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 8c777eb53756..3dc6979bcbf4 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -17,6 +17,8 @@ # pylint: disable=no-else-return, invalid-name, unused-argument, import-outside-toplevel """Developer API of constructing Relax AST.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Any, Callable, Dict, List, Optional, Sequence, Union import tvm @@ -150,6 +152,11 @@ class BlockBuilder(Object): mod = bb.get() """ + # tvm-ffi-stubgen(begin): object/relax.BlockBuilder + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + _stack = [] @staticmethod diff --git a/python/tvm/relax/distributed/_ffi_api.py b/python/tvm/relax/distributed/_ffi_api.py index 71185a1276da..5a389164175a 100644 --- a/python/tvm/relax/distributed/_ffi_api.py +++ b/python/tvm/relax/distributed/_ffi_api.py @@ -15,6 +15,32 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.distributed""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm, Range, Span + from relax import DTensorStructInfo, TensorStructInfo + from relax.distributed import DeviceMesh, Placement, PlacementSpec + from tvm_ffi import Shape +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.distributed", __name__) + +# tvm-ffi-stubgen(begin): global/relax.distributed +# fmt: off +_FFI_INIT_FUNC("relax.distributed", __name__) +if TYPE_CHECKING: + def DTensorStructInfo(_0: TensorStructInfo, _1: DeviceMesh, _2: Placement, _3: Span, /) -> DTensorStructInfo: ... + def DeviceMesh(_0: Shape, _1: Sequence[IntImm], _2: Range | None, /) -> DeviceMesh: ... 
+ def Placement(_0: Sequence[PlacementSpec], /) -> Placement: ... + def PlacementFromText(_0: str, /) -> Placement: ... + def Replica() -> PlacementSpec: ... + def Sharding(_0: int, /) -> PlacementSpec: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/distributed/global_info.py b/python/tvm/relax/distributed/global_info.py index 34d3f2da4720..fd5dccf9473b 100644 --- a/python/tvm/relax/distributed/global_info.py +++ b/python/tvm/relax/distributed/global_info.py @@ -16,6 +16,18 @@ # under the License. # pylint: disable=redefined-builtin, invalid-name """Global Info Data structures for distributed tensor.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm, Range + from tvm_ffi import Shape +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Union, Tuple import tvm_ffi @@ -39,6 +51,14 @@ class DeviceMesh(GlobalInfo): Represents the device id in the mesh """ + # tvm-ffi-stubgen(begin): object/relax.distributed.DeviceMesh + # fmt: off + shape: Shape + device_ids: Sequence[IntImm] + device_range: Range | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, shape: Union[ShapeTuple, List[int], Tuple[int]], device_ids: Union[List[int], Range] ): diff --git a/python/tvm/relax/distributed/struct_info.py b/python/tvm/relax/distributed/struct_info.py index 554c83e47490..688fa25223aa 100644 --- a/python/tvm/relax/distributed/struct_info.py +++ b/python/tvm/relax/distributed/struct_info.py @@ -16,6 +16,18 @@ # under the License. 
# pylint: disable=redefined-builtin, invalid-name """Struct Info for distributed tensor.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from relax import TensorStructInfo + from relax.distributed import DeviceMesh +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import enum from typing import List import tvm_ffi @@ -46,6 +58,13 @@ class PlacementSpec(Object): The kind of placement spec. Possible values: kSharding and kReplica. """ + # tvm-ffi-stubgen(begin): object/relax.distributed.PlacementSpec + # fmt: off + axis: int + kind: int + # fmt: on + # tvm-ffi-stubgen(end) + axis: int kind: PlacementSpecKind @@ -90,6 +109,12 @@ class Placement(Object): The placement spec for each dimension of the device mesh. """ + # tvm-ffi-stubgen(begin): object/relax.distributed.Placement + # fmt: off + dim_specs: Sequence[PlacementSpec] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, dim_specs: List[PlacementSpec]): self.__init_handle_by_constructor__(_ffi_api.Placement, dim_specs) # type: ignore @@ -125,6 +150,14 @@ class DTensorStructInfo(StructInfo): """ + # tvm-ffi-stubgen(begin): object/relax.DTensorStructInfo + # fmt: off + device_mesh: DeviceMesh + placement: Placement + tensor_sinfo: TensorStructInfo + # fmt: on + # tvm-ffi-stubgen(end) + tensor_sinfo: TensorStructInfo device_mesh: DeviceMesh placement: Placement diff --git a/python/tvm/relax/distributed/transform/_ffi_api.py b/python/tvm/relax/distributed/transform/_ffi_api.py index 35808cc2bc93..1e60f4c1c454 100644 --- a/python/tvm/relax/distributed/transform/_ffi_api.py +++ b/python/tvm/relax/distributed/transform/_ffi_api.py @@ -14,6 +14,26 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.distributed.transform""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from transform import Pass +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.distributed.transform", __name__) + +# tvm-ffi-stubgen(begin): global/relax.distributed.transform +# fmt: off +_FFI_INIT_FUNC("relax.distributed.transform", __name__) +if TYPE_CHECKING: + def LegalizeRedistribute() -> Pass: ... + def LowerDistIR() -> Pass: ... + def LowerGlobalViewToLocalView() -> Pass: ... + def PropagateSharding() -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 6dd730e83147..1ac60cb6d8e1 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -16,6 +16,19 @@ # under the License. 
"""APIs for pattern-based rewriting.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import RelaxExpr + from relax.dpl import DFPattern + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Dict, Callable, Union from tvm.ir import IRModule @@ -32,6 +45,11 @@ class PatternMatchingRewriter(Object): """A pattern-matching rewriter for Relax""" + # tvm-ffi-stubgen(begin): object/relax.dpl.PatternMatchingRewriter + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + @staticmethod def from_pattern( pattern: DFPattern, @@ -173,6 +191,12 @@ def __or__(self, other: "PatternMatchingRewriter") -> "PatternMatchingRewriter": @register_object("relax.dpl.ExprPatternRewriter") class ExprPatternRewriter(PatternMatchingRewriter): + # tvm-ffi-stubgen(begin): object/relax.dpl.ExprPatternRewriter + # fmt: off + pattern: DFPattern + func: Callable[[RelaxExpr, Mapping[DFPattern, RelaxExpr]], RelaxExpr | None] + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self, pattern, func): self.__init_handle_by_constructor__( ffi.PatternRewriter, @@ -183,6 +207,12 @@ def __init__(self, pattern, func): @register_object("relax.dpl.OrRewriter") class OrRewriter(PatternMatchingRewriter): + # tvm-ffi-stubgen(begin): object/relax.dpl.OrRewriter + # fmt: off + lhs: PatternMatchingRewriter + rhs: PatternMatchingRewriter + # fmt: on + # tvm-ffi-stubgen(end) def __init__(self, lhs, rhs): self.__init_handle_by_constructor__( ffi.OrRewriter, @@ -193,6 +223,12 @@ def __init__(self, lhs, rhs): @register_object("relax.dpl.TupleRewriter") class TupleRewriter(PatternMatchingRewriter): + # tvm-ffi-stubgen(begin): object/relax.dpl.TupleRewriter + # fmt: off + patterns: Sequence[DFPattern] + func: Callable[[RelaxExpr, Mapping[DFPattern, RelaxExpr]], RelaxExpr | None] + # fmt: on + # tvm-ffi-stubgen(end) def 
__init__(self, patterns, func): self.__init_handle_by_constructor__( ffi.TupleRewriter, diff --git a/python/tvm/relax/exec_builder.py b/python/tvm/relax/exec_builder.py index 50d6c0679eca..0dd80e95e585 100644 --- a/python/tvm/relax/exec_builder.py +++ b/python/tvm/relax/exec_builder.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name """A builder to build Relax VM executable.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from enum import IntEnum from typing import Optional, Union, List import tvm @@ -60,6 +62,11 @@ def __exit__(self, ptype, value, trace): class ExecBuilder(tvm_ffi.core.Object): """A builder to emit instructions and build executable for the virtual machine.""" + # tvm-ffi-stubgen(begin): object/relax.ExecBuilder + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self) -> None: self.__init_handle_by_constructor__(_ffi_api.ExecBuilderCreate) # type: ignore diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index e9bc9a7a3e98..02ee90501d75 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -15,6 +15,19 @@ # specific language governing permissions and limitations # under the License. """The expression nodes of Relax.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import Attrs, PrimExpr, RelaxExpr, Span + from tvm_ffi import Tensor, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import typing from numbers import Number from typing import Any, Callable, Dict, List, Optional, Union, Mapping @@ -48,6 +61,12 @@ class Id(Object): Guaranteed to be stable across all passes. 
""" + # tvm-ffi-stubgen(begin): object/relax.Id + # fmt: off + name_hint: str + # fmt: on + # tvm-ffi-stubgen(end) + name_hint: str def __init__(self): @@ -64,6 +83,12 @@ class StructInfo(Node, Scriptable): and runtime structural information. """ + # tvm-ffi-stubgen(begin): object/ir.StructInfo + # fmt: off + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + def __eq__(self, other): """Compare two struct info for structural equivalence.""" return tvm.ir.structural_equal(self, other) @@ -555,6 +580,15 @@ class Call(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.Call + # fmt: off + op: RelaxExpr + args: Sequence[RelaxExpr] + attrs: Attrs + sinfo_args: Sequence[StructInfo] + # fmt: on + # tvm-ffi-stubgen(end) + op: Expr args: List[Expr] attrs: tvm.ir.Attrs @@ -595,6 +629,14 @@ class If(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.If + # fmt: off + cond: RelaxExpr + true_branch: SeqExpr + false_branch: SeqExpr + # fmt: on + # tvm-ffi-stubgen(end) + cond: Expr true_branch: Expr false_branch: Expr @@ -621,6 +663,12 @@ class Tuple(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.Tuple + # fmt: off + fields: Sequence[RelaxExpr] + # fmt: on + # tvm-ffi-stubgen(end) + fields: List[Expr] span: Optional[Span] @@ -659,6 +707,13 @@ class TupleGetItem(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.TupleGetItem + # fmt: off + tuple_value: RelaxExpr + index: int + # fmt: on + # tvm-ffi-stubgen(end) + tuple_value: Expr index: int span: Optional[Span] @@ -682,6 +737,12 @@ class ShapeExpr(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.ShapeExpr + # fmt: off + values: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + values: List[PrimExpr] span: Optional[Span] @@ -727,6 +788,12 @@ class 
Constant(ExprWithOp): Scalar constants are represented by ndim-0 constant tensors. """ + # tvm-ffi-stubgen(begin): object/relax.expr.Constant + # fmt: off + data: Tensor + # fmt: on + # tvm-ffi-stubgen(end) + data: tvm.runtime.Tensor span: Optional[Span] @@ -757,6 +824,12 @@ class Var(ExprWithOp): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.Var + # fmt: off + vid: Id + # fmt: on + # tvm-ffi-stubgen(end) + vid: Id span: Optional[Span] @@ -806,6 +879,11 @@ class DataflowVar(Var): Span that points to original source code """ + # tvm-ffi-stubgen(begin): object/relax.expr.DataflowVar + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + vid: Id span: Optional[Span] @@ -841,6 +919,12 @@ def __init__( class PrimValue(Expr, Scriptable): """The prim expr representing the value.""" + # tvm-ffi-stubgen(begin): object/relax.expr.PrimValue + # fmt: off + value: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + value: PrimExpr def __init__(self, value: Union[PrimExpr, int], span: Optional[Span] = None) -> None: @@ -853,6 +937,12 @@ def __init__(self, value: Union[PrimExpr, int], span: Optional[Span] = None) -> class StringImm(Expr, Scriptable): """Represent a string literal constant.""" + # tvm-ffi-stubgen(begin): object/relax.expr.StringImm + # fmt: off + value: str + # fmt: on + # tvm-ffi-stubgen(end) + value: str span: Optional[Span] @@ -864,6 +954,12 @@ def __init__(self, value: str, span: Optional[Span] = None) -> None: class DataTypeImm(Expr, Scriptable): """Represent a data type constant.""" + # tvm-ffi-stubgen(begin): object/relax.expr.DataTypeImm + # fmt: off + value: dtype + # fmt: on + # tvm-ffi-stubgen(end) + value: DataType span: Optional[Span] @@ -875,6 +971,13 @@ def __init__(self, value: Union[DataType, str], span: Optional[Span] = None) -> class Binding(Node, Scriptable): """The base class of a binding in Relax.""" + # tvm-ffi-stubgen(begin): object/relax.expr.Binding + # fmt: off + span: Span + var: Var + # fmt: on + # 
tvm-ffi-stubgen(end) + var: Var span: Optional[Span] @@ -899,6 +1002,13 @@ class MatchCast(Binding): The struct info to match cast to. """ + # tvm-ffi-stubgen(begin): object/relax.expr.MatchCast + # fmt: off + value: RelaxExpr + struct_info: StructInfo + # fmt: on + # tvm-ffi-stubgen(end) + struct_info: StructInfo value: Expr span: Optional[Span] @@ -925,6 +1035,12 @@ class VarBinding(Binding): """ + # tvm-ffi-stubgen(begin): object/relax.expr.VarBinding + # fmt: off + value: RelaxExpr + # fmt: on + # tvm-ffi-stubgen(end) + var: Var value: Expr span: Optional[Span] @@ -938,6 +1054,13 @@ class BindingBlock(Node, Scriptable): """base class of binding block, bindings inside can be impure (with side effect or control flow)""" + # tvm-ffi-stubgen(begin): object/relax.expr.BindingBlock + # fmt: off + bindings: Sequence[Binding] + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + bindings: List[Binding] span: Optional[Span] @@ -949,6 +1072,11 @@ def __init__(self, bindings: List[Binding], span: Optional[Span] = None) -> None class DataflowBlock(BindingBlock): """dataflow block, bindings inside are pure (no side effect and no control flow)""" + # tvm-ffi-stubgen(begin): object/relax.expr.DataflowBlock + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + bindings: List[Binding] span: Optional[Span] @@ -961,6 +1089,13 @@ def __init__(self, bindings: List[Binding], span: Optional[Span] = None) -> None class SeqExpr(ExprWithOp): """A sequence of binding blocks followed by an expression.""" + # tvm-ffi-stubgen(begin): object/relax.expr.SeqExpr + # fmt: off + blocks: Sequence[BindingBlock] + body: RelaxExpr + # fmt: on + # tvm-ffi-stubgen(end) + blocks: List[BindingBlock] body: Expr span: Optional[Span] @@ -973,6 +1108,15 @@ def __init__(self, blocks: List[BindingBlock], body: Expr, span: Optional[Span] class Function(BaseFunc, Scriptable): """A Relax function.""" + # tvm-ffi-stubgen(begin): object/relax.expr.Function + # fmt: off + params: Sequence[Var] + body: SeqExpr + 
ret_struct_info: StructInfo + is_pure: bool + # fmt: on + # tvm-ffi-stubgen(end) + params: List[Var] body: Expr ret_struct_info: StructInfo @@ -1112,6 +1256,12 @@ def inline_functions( class ExternFunc(BaseFunc, ExprWithOp): """extern function, which represents a PackedFunc.""" + # tvm-ffi-stubgen(begin): object/relax.expr.ExternFunc + # fmt: off + global_symbol: str + # fmt: on + # tvm-ffi-stubgen(end) + global_symbol: String span: Optional[Span] @@ -1183,6 +1333,17 @@ def const( class TEPlaceholderOp(tvm.te.tensor.Operation): """The placeholder op that represents a relax expression.""" + # tvm-ffi-stubgen(begin): object/relax.TEPlaceholderOp + # fmt: off + name: str + tag: str + attrs: Mapping[str, Any] + value: RelaxExpr + shape: Sequence[PrimExpr] + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + def te_tensor( value: Expr, tir_var_map: Dict[tvm.tir.Var, tvm.tir.PrimExpr], name: str = "rxplaceholder" diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py index e5e77251c66d..40174f72218b 100644 --- a/python/tvm/relax/expr_functor.py +++ b/python/tvm/relax/expr_functor.py @@ -16,6 +16,16 @@ # under the License. 
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ """The expression functor of Relax.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from relax import BlockBuilder +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Callable, Optional import tvm_ffi @@ -272,6 +282,11 @@ class _PyExprVisitor(tvm_ffi.core.Object): See also: visitor, PyExprVisitor """ + # tvm-ffi-stubgen(begin): object/expr_functor.PyExprVisitor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_visit_expr: Callable = None, @@ -792,6 +807,12 @@ class _PyExprMutator(Object): See also: mutator, PyExprmutator """ + # tvm-ffi-stubgen(begin): object/expr_functor.PyExprMutator + # fmt: off + builder_: BlockBuilder + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, builder: BlockBuilder = None, diff --git a/python/tvm/relax/op/_ffi_api.py b/python/tvm/relax/op/_ffi_api.py index 867c43e4d85b..ed4d541841c8 100644 --- a/python/tvm/relax/op/_ffi_api.py +++ b/python/tvm/relax/op/_ffi_api.py @@ -14,6 +14,170 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import Attrs, IntImm, PrimExpr, RelaxExpr, StructInfo, VDevice + from relax import TensorStructInfo + from relax.expr import PrimValue, StringImm, Tuple + from tir import IndexMap + from tvm_ffi import dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op +# fmt: off +_FFI_INIT_FUNC("relax.op", __name__) +if TYPE_CHECKING: + def abs(_0: RelaxExpr, /) -> RelaxExpr: ... + def acos(_0: RelaxExpr, /) -> RelaxExpr: ... + def acosh(_0: RelaxExpr, /) -> RelaxExpr: ... + def add(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def arange(_0: PrimValue, _1: PrimValue, _2: PrimValue, _3: dtype, /) -> RelaxExpr: ... + def argmax(_0: RelaxExpr, _1: int | None, _2: bool, /) -> RelaxExpr: ... + def argmin(_0: RelaxExpr, _1: int | None, _2: bool, /) -> RelaxExpr: ... + def argsort(_0: RelaxExpr, _1: int, _2: bool, _3: dtype, /) -> RelaxExpr: ... + def asin(_0: RelaxExpr, /) -> RelaxExpr: ... + def asinh(_0: RelaxExpr, /) -> RelaxExpr: ... + def assert_op(_0: RelaxExpr, _1: Sequence[RelaxExpr], _2: StringImm, /) -> RelaxExpr: ... + def astype(_0: RelaxExpr, _1: dtype, /) -> RelaxExpr: ... + def atan(_0: RelaxExpr, /) -> RelaxExpr: ... + def atanh(_0: RelaxExpr, /) -> RelaxExpr: ... + def bitwise_and(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def bitwise_not(_0: RelaxExpr, /) -> RelaxExpr: ... + def bitwise_or(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def bitwise_xor(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... 
+ def broadcast_to(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def bucketize(_0: RelaxExpr, _1: RelaxExpr, _2: bool, _3: bool, /) -> RelaxExpr: ... + def call_builtin_with_ctx(_0: RelaxExpr, _1: Tuple, _2: Sequence[StructInfo], /) -> RelaxExpr: ... + def call_dps_packed(_0: RelaxExpr, _1: Tuple, _2: Sequence[TensorStructInfo], /) -> RelaxExpr: ... + def call_inplace_packed(_0: RelaxExpr, _1: Sequence[RelaxExpr], _2: Sequence[IntImm], _3: Sequence[StructInfo], /) -> RelaxExpr: ... + def call_pure_packed(_0: RelaxExpr, _1: Sequence[RelaxExpr], _2: Attrs, _3: Sequence[StructInfo], /) -> RelaxExpr: ... + def call_py_func(_0: StringImm, _1: Tuple, _2: Sequence[TensorStructInfo], /) -> RelaxExpr: ... + def call_tir(_0: RelaxExpr, _1: Tuple, _2: Sequence[TensorStructInfo], _3: RelaxExpr | None, /) -> RelaxExpr: ... + def call_tir_inplace(_0: RelaxExpr, _1: Tuple, _2: Sequence[IntImm], _3: Sequence[TensorStructInfo], _4: RelaxExpr | None, /) -> RelaxExpr: ... + def call_tir_with_grad(_0: RelaxExpr, _1: Tuple, _2: Sequence[TensorStructInfo], _3: str, _4: Mapping[str, Any], _5: RelaxExpr | None, /) -> RelaxExpr: ... + def ceil(_0: RelaxExpr, /) -> RelaxExpr: ... + def clip(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, /) -> RelaxExpr: ... + def collapse_sum_like(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def collapse_sum_to(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def concat(_0: RelaxExpr, _1: int | None, /) -> RelaxExpr: ... + def cos(_0: RelaxExpr, /) -> RelaxExpr: ... + def cosh(_0: RelaxExpr, /) -> RelaxExpr: ... + def cumprod(_0: RelaxExpr, _1: int | None, _2: dtype | None, _3: IntImm, /) -> RelaxExpr: ... + def cumsum(_0: RelaxExpr, _1: int | None, _2: dtype | None, _3: IntImm, /) -> RelaxExpr: ... + def dequantize(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: dtype, /) -> RelaxExpr: ... + def divide(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... 
+ def dynamic_strided_slice(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, /) -> RelaxExpr: ... + def einsum(_0: RelaxExpr, _1: str, /) -> RelaxExpr: ... + def equal(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def erf(_0: RelaxExpr, /) -> RelaxExpr: ... + def ewise_fma(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, /) -> RelaxExpr: ... + def exp(_0: RelaxExpr, /) -> RelaxExpr: ... + def expand_dims(_0: RelaxExpr, _1: Sequence[IntImm], /) -> RelaxExpr: ... + def eye(_0: PrimValue, _1: PrimValue, _2: PrimValue, _3: dtype, /) -> RelaxExpr: ... + def eye_like(_0: RelaxExpr, _1: PrimValue, _2: dtype | None, /) -> RelaxExpr: ... + def flatten(_0: RelaxExpr, /) -> RelaxExpr: ... + def flip(_0: RelaxExpr, _1: IntImm, /) -> RelaxExpr: ... + def floor(_0: RelaxExpr, /) -> RelaxExpr: ... + def floor_divide(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def floor_mod(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def full(_0: RelaxExpr | Sequence[PrimExpr], _1: RelaxExpr, _2: dtype | None, /) -> RelaxExpr: ... + def full_like(_0: RelaxExpr, _1: RelaxExpr, _2: dtype | None, /) -> RelaxExpr: ... + def gather_elements(_0: RelaxExpr, _1: RelaxExpr, _2: int, /) -> RelaxExpr: ... + def gather_nd(_0: RelaxExpr, _1: RelaxExpr, _2: int, /) -> RelaxExpr: ... + def greater(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def greater_equal(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def hamming_window(_0: PrimValue, _1: PrimValue, _2: PrimValue, _3: PrimValue, _4: dtype, /) -> RelaxExpr: ... + def hint_on_device(*args: Any) -> Any: ... + def index_put(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: bool, /) -> RelaxExpr: ... + def index_tensor(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def invoke_closure(_0: RelaxExpr, _1: Tuple, _2: Sequence[StructInfo], /) -> RelaxExpr: ... + def invoke_pure_closure(_0: RelaxExpr, _1: Tuple, _2: Sequence[StructInfo], /) -> RelaxExpr: ... + def isfinite(_0: RelaxExpr, /) -> RelaxExpr: ... 
+ def isinf(_0: RelaxExpr, /) -> RelaxExpr: ... + def isnan(_0: RelaxExpr, /) -> RelaxExpr: ... + def layout_transform(_0: RelaxExpr, _1: IndexMap, _2: PrimValue | None, _3: Sequence[IntImm] | None, _4: Sequence[IntImm] | None, /) -> RelaxExpr: ... + def left_shift(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def less(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def less_equal(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def log(_0: RelaxExpr, /) -> RelaxExpr: ... + def log_add_exp(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def logical_and(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def logical_not(_0: RelaxExpr, /) -> RelaxExpr: ... + def logical_or(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def logical_xor(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def make_closure(_0: RelaxExpr, _1: Tuple, /) -> RelaxExpr: ... + def matmul(_0: RelaxExpr, _1: RelaxExpr, _2: dtype | None, /) -> RelaxExpr: ... + def max(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def maximum(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def mean(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def meshgrid(_0: RelaxExpr, _1: str | None, /) -> RelaxExpr: ... + def min(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def minimum(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def mod(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def multinomial_from_uniform(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: dtype, /) -> RelaxExpr: ... + def multiply(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def negative(_0: RelaxExpr, /) -> RelaxExpr: ... + def nonzero(_0: RelaxExpr, /) -> RelaxExpr: ... + def not_equal(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def null_value() -> RelaxExpr: ... + def one_hot(_0: RelaxExpr, _1: PrimValue, _2: PrimValue, _3: int, _4: int, /) -> RelaxExpr: ... 
+ def ones(_0: RelaxExpr, _1: dtype, /) -> RelaxExpr: ... + def ones_like(_0: RelaxExpr, _1: dtype | None, /) -> RelaxExpr: ... + def outer(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def permute_dims(_0: RelaxExpr, _1: Sequence[IntImm] | None, /) -> RelaxExpr: ... + def power(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def print(_0: Sequence[RelaxExpr], _1: StringImm, /) -> RelaxExpr: ... + def prod(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def quantize(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: dtype, /) -> RelaxExpr: ... + def repeat(_0: RelaxExpr, _1: int, _2: int | None, /) -> RelaxExpr: ... + def reshape(_0: RelaxExpr, _1: RelaxExpr | Sequence[PrimExpr], /) -> RelaxExpr: ... + def right_shift(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def round(_0: RelaxExpr, /) -> RelaxExpr: ... + def rsqrt(_0: RelaxExpr, /) -> RelaxExpr: ... + def scatter_elements(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: str, /) -> RelaxExpr: ... + def scatter_nd(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: str, /) -> RelaxExpr: ... + def shape_of(_0: RelaxExpr, /) -> RelaxExpr: ... + def shape_to_tensor(_0: RelaxExpr, /) -> RelaxExpr: ... + def sigmoid(_0: RelaxExpr, /) -> RelaxExpr: ... + def sign(_0: RelaxExpr, /) -> RelaxExpr: ... + def sin(_0: RelaxExpr, /) -> RelaxExpr: ... + def sinh(_0: RelaxExpr, /) -> RelaxExpr: ... + def slice_scatter(_0: RelaxExpr, _1: RelaxExpr, _2: int, _3: PrimValue, _4: PrimValue, _5: PrimValue, /) -> RelaxExpr: ... + def sort(_0: RelaxExpr, _1: int, _2: bool, /) -> RelaxExpr: ... + def split(_0: RelaxExpr, _1: IntImm | Sequence[IntImm], _2: int, /) -> RelaxExpr: ... + def sqrt(_0: RelaxExpr, /) -> RelaxExpr: ... + def square(_0: RelaxExpr, /) -> RelaxExpr: ... + def squeeze(_0: RelaxExpr, _1: Sequence[IntImm] | None, /) -> RelaxExpr: ... + def stack(_0: RelaxExpr, _1: IntImm | None, /) -> RelaxExpr: ... 
+ def std(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def strided_slice(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, _4: RelaxExpr | None, _5: bool, /) -> RelaxExpr: ... + def subtract(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def sum(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def take(_0: RelaxExpr, _1: RelaxExpr, _2: int | None, _3: str, /) -> RelaxExpr: ... + def tan(_0: RelaxExpr, /) -> RelaxExpr: ... + def tanh(_0: RelaxExpr, /) -> RelaxExpr: ... + def tensor_to_shape(_0: RelaxExpr, /) -> RelaxExpr: ... + def tile(_0: RelaxExpr, _1: Sequence[IntImm], /) -> RelaxExpr: ... + def to_vdevice(_0: RelaxExpr, _1: VDevice, /) -> RelaxExpr: ... + def topk(_0: RelaxExpr, _1: int, _2: int, _3: str, _4: bool, _5: dtype, /) -> RelaxExpr: ... + def tril(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def triu(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def trunc(_0: RelaxExpr, /) -> RelaxExpr: ... + def unique(_0: RelaxExpr, _1: PrimValue, _2: PrimValue, _3: PrimValue, _4: PrimValue, _5: PrimValue | None, /) -> RelaxExpr: ... + def variance(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: bool, /) -> RelaxExpr: ... + def where(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, /) -> RelaxExpr: ... + def wrap_param(_0: RelaxExpr, _1: dtype, /) -> RelaxExpr: ... + def zeros(_0: RelaxExpr, _1: dtype, /) -> RelaxExpr: ... + def zeros_like(_0: RelaxExpr, _1: dtype | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py index 0e5955f6e47d..ae4cd27e1c94 100644 --- a/python/tvm/relax/op/builtin/_ffi_api.py +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -14,6 +14,25 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.builtin""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr + from relax.expr import DataTypeImm, PrimValue, StringImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.builtin", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.builtin +# fmt: off +_FFI_INIT_FUNC("relax.op.builtin", __name__) +if TYPE_CHECKING: + def alloc_tensor(_0: RelaxExpr, _1: DataTypeImm, _2: PrimValue, _3: StringImm, /) -> RelaxExpr: ... + def stop_lift_params(_0: RelaxExpr, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/ccl/_ffi_api.py b/python/tvm/relax/op/ccl/_ffi_api.py index f0796d3da318..d74fd7d48031 100644 --- a/python/tvm/relax/op/ccl/_ffi_api.py +++ b/python/tvm/relax/op/ccl/_ffi_api.py @@ -15,6 +15,26 @@ # specific language governing permissions and limitations # under the License. """Operators serving for Collective Communications Library (CCL) operators""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.ccl", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.ccl +# fmt: off +_FFI_INIT_FUNC("relax.op.ccl", __name__) +if TYPE_CHECKING: + def allgather(_0: RelaxExpr, _1: int, _2: bool, /) -> RelaxExpr: ... + def allreduce(_0: RelaxExpr, _1: str, _2: bool, /) -> RelaxExpr: ... + def broadcast_from_worker0(_0: RelaxExpr, /) -> RelaxExpr: ... 
+ def scatter_from_worker0(_0: RelaxExpr, _1: int, _2: int, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/distributed/_ffi_api.py b/python/tvm/relax/op/distributed/_ffi_api.py index fa1c163794b9..4065214bb0c4 100644 --- a/python/tvm/relax/op/distributed/_ffi_api.py +++ b/python/tvm/relax/op/distributed/_ffi_api.py @@ -15,6 +15,30 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.op.distributed""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import RelaxExpr + from relax import DTensorStructInfo + from relax.distributed import DeviceMesh, Placement + from relax.expr import Tuple +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.dist", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.dist +# fmt: off +_FFI_INIT_FUNC("relax.op.dist", __name__) +if TYPE_CHECKING: + def annotate_sharding(_0: RelaxExpr, _1: DeviceMesh, _2: Placement, /) -> RelaxExpr: ... + def call_tir_local_view(_0: RelaxExpr, _1: Tuple, _2: Sequence[DTensorStructInfo], _3: RelaxExpr | None, /) -> RelaxExpr: ... + def redistribute(_0: RelaxExpr, _1: DeviceMesh, _2: Placement, /) -> RelaxExpr: ... + def redistribute_replica_to_shard(_0: RelaxExpr, _1: int, _2: int, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/grad/_ffi_api.py b/python/tvm/relax/op/grad/_ffi_api.py index 1a8ebb09aa8d..49bb114d65db 100644 --- a/python/tvm/relax/op/grad/_ffi_api.py +++ b/python/tvm/relax/op/grad/_ffi_api.py @@ -15,6 +15,30 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.relax.op.grad""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm, RelaxExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.grad", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.grad +# fmt: off +_FFI_INIT_FUNC("relax.op.grad", __name__) +if TYPE_CHECKING: + def avg_pool2d_backward(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: Sequence[IntImm], _6: bool, _7: bool, _8: str, _9: str | None, /) -> RelaxExpr: ... + def end_checkpoint(_0: RelaxExpr, /) -> RelaxExpr: ... + def max_pool2d_backward(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: Sequence[IntImm], _6: bool, _7: bool, _8: str, _9: str | None, /) -> RelaxExpr: ... + def nll_loss_backward(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr | None, _4: str, _5: int, /) -> RelaxExpr: ... + def no_grad(_0: RelaxExpr, /) -> RelaxExpr: ... + def start_checkpoint(_0: RelaxExpr, /) -> RelaxExpr: ... + def take_backward(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/image/_ffi_api.py b/python/tvm/relax/op/image/_ffi_api.py index 8147a155cb76..c734db1bd753 100644 --- a/python/tvm/relax/op/image/_ffi_api.py +++ b/python/tvm/relax/op/image/_ffi_api.py @@ -15,6 +15,26 @@ # specific language governing permissions and limitations # under the License. 
"""Constructor APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import FloatImm, RelaxExpr + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.image", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.image +# fmt: off +_FFI_INIT_FUNC("relax.op.image", __name__) +if TYPE_CHECKING: + def grid_sample(_0: RelaxExpr, _1: RelaxExpr, _2: str, _3: str, _4: str, _5: bool, /) -> RelaxExpr: ... + def resize2d(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[FloatImm], _3: str, _4: str, _5: str, _6: str, _7: float, _8: int, _9: float, _10: dtype | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/memory/_ffi_api.py b/python/tvm/relax/op/memory/_ffi_api.py index 05dbf534c7f5..5d67cc449283 100644 --- a/python/tvm/relax/op/memory/_ffi_api.py +++ b/python/tvm/relax/op/memory/_ffi_api.py @@ -14,6 +14,29 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.memory""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr + from relax.expr import DataTypeImm, PrimValue, StringImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.memory", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.memory +# fmt: off +_FFI_INIT_FUNC("relax.op.memory", __name__) +if TYPE_CHECKING: + def alloc_storage(_0: RelaxExpr, _1: PrimValue, _2: StringImm, _3: DataTypeImm, /) -> RelaxExpr: ... 
+ def alloc_tensor(_0: RelaxExpr, _1: PrimValue, _2: RelaxExpr, _3: DataTypeImm, /) -> RelaxExpr: ... + def ensure_zero_offset(_0: RelaxExpr, /) -> RelaxExpr: ... + def kill_storage(_0: RelaxExpr, /) -> RelaxExpr: ... + def kill_tensor(_0: RelaxExpr, /) -> RelaxExpr: ... + def view(_0: RelaxExpr, _1: RelaxExpr | None, _2: RelaxExpr | None, _3: RelaxExpr | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/nn/_ffi_api.py b/python/tvm/relax/op/nn/_ffi_api.py index d58fa186fc7c..cf446282d6a3 100644 --- a/python/tvm/relax/op/nn/_ffi_api.py +++ b/python/tvm/relax/op/nn/_ffi_api.py @@ -15,6 +15,60 @@ # specific language governing permissions and limitations # under the License. """Constructor APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import FloatImm, IntImm, RelaxExpr + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.nn", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.nn +# fmt: off +_FFI_INIT_FUNC("relax.op.nn", __name__) +if TYPE_CHECKING: + def adaptive_avg_pool1d(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: str, _3: str | None, /) -> RelaxExpr: ... + def adaptive_avg_pool2d(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: str, _3: str | None, /) -> RelaxExpr: ... + def adaptive_avg_pool3d(_0: RelaxExpr, _1: Sequence[IntImm] | None, _2: str, _3: str | None, /) -> RelaxExpr: ... + def attention(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr | None, _4: FloatImm | None, _5: str | None, _6: IntImm | None, /) -> RelaxExpr: ... 
+ def attention_var_len(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, _4: RelaxExpr, _5: RelaxExpr, _6: RelaxExpr, _7: FloatImm | None, _8: str | None, _9: IntImm | None, /) -> RelaxExpr: ... + def avg_pool1d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def avg_pool2d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def avg_pool3d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def batch_norm(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, _4: RelaxExpr, _5: int, _6: float, _7: bool, _8: bool, _9: float, _10: bool, /) -> RelaxExpr: ... + def conv1d(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: int, _6: str, _7: str, _8: str | None, _9: dtype | None, /) -> RelaxExpr: ... + def conv1d_transpose(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: Sequence[IntImm], _6: int, _7: str, _8: str, _9: str | None, _10: dtype | None, /) -> RelaxExpr: ... + def conv2d(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: int, _6: str, _7: str, _8: str | None, _9: dtype | None, /) -> RelaxExpr: ... + def conv2d_transpose(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: Sequence[IntImm], _6: int, _7: str, _8: str, _9: str | None, _10: dtype | None, /) -> RelaxExpr: ... + def conv3d(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: int, _6: str, _7: str, _8: str | None, _9: dtype | None, /) -> RelaxExpr: ... 
+ def cross_entropy_with_logits(_0: RelaxExpr, _1: RelaxExpr, /) -> RelaxExpr: ... + def dropout(_0: RelaxExpr, _1: float, /) -> RelaxExpr: ... + def gelu(_0: RelaxExpr, /) -> RelaxExpr: ... + def gelu_tanh(_0: RelaxExpr, /) -> RelaxExpr: ... + def group_norm(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: int, _5: Sequence[IntImm], _6: float, _7: bool, _8: bool, /) -> RelaxExpr: ... + def instance_norm(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: int, _4: Sequence[IntImm], _5: float, _6: bool, _7: bool, /) -> RelaxExpr: ... + def layer_norm(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: Sequence[IntImm], _4: float, _5: bool, _6: bool, /) -> RelaxExpr: ... + def leakyrelu(_0: RelaxExpr, _1: float, /) -> RelaxExpr: ... + def log_softmax(_0: RelaxExpr, _1: int, /) -> RelaxExpr: ... + def max_pool1d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def max_pool2d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def max_pool3d(_0: RelaxExpr, _1: Sequence[IntImm], _2: Sequence[IntImm], _3: Sequence[IntImm], _4: Sequence[IntImm], _5: bool, _6: bool, _7: str, _8: str | None, /) -> RelaxExpr: ... + def nll_loss(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr | None, _3: str, _4: int, /) -> RelaxExpr: ... + def pad(_0: RelaxExpr, _1: Sequence[IntImm], _2: str, _3: float, /) -> RelaxExpr: ... + def pixel_shuffle(_0: RelaxExpr, _1: int, /) -> RelaxExpr: ... + def prelu(_0: RelaxExpr, _1: RelaxExpr, _2: int, /) -> RelaxExpr: ... + def relu(_0: RelaxExpr, /) -> RelaxExpr: ... + def rms_norm(_0: RelaxExpr, _1: RelaxExpr, _2: Sequence[IntImm], _3: float, /) -> RelaxExpr: ... + def selu(_0: RelaxExpr, /) -> RelaxExpr: ... + def silu(_0: RelaxExpr, /) -> RelaxExpr: ... + def softmax(_0: RelaxExpr, _1: int, /) -> RelaxExpr: ... 
+ def softplus(_0: RelaxExpr, _1: float, _2: float, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 229a789a45ef..9bebbff5c8b6 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -15,6 +15,21 @@ # specific language governing permissions and limitations # under the License. """The attributes node used for Relax operators""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import FloatImm, IntImm, VDevice + from relax.expr import PrimValue + from tir import IndexMap + from tvm_ffi import Object, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm.ir import Attrs import tvm_ffi @@ -23,342 +38,878 @@ class CallTIRWithGradAttrs(Attrs): """Attributes used in call_tir_with_grad operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.CallTIRWithGradAttrs + # fmt: off + te_grad_name: str + te_grad_kwargs: Mapping[str, Any] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.InitAttrs") class InitAttrs(Attrs): """Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.InitAttrs + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.TriluAttrs") class TriluAttrs(Attrs): """Attributes used in tril and triu operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.TriluAttrs + # fmt: off + k: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AstypeAttrs") class AstypeAttrs(Attrs): """Attributes used in astype operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AstypeAttrs + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + 
@tvm_ffi.register_object("relax.attrs.TakeAttrs") class TakeAttrs(Attrs): """Attributes used in take operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.TakeAttrs + # fmt: off + axis: int | None + mode: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.StridedSliceAttrs") class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.StridedSliceAttrs + # fmt: off + assume_inbound: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.MatmulAttrs") class MatmulAttrs(Attrs): """Attributes for matmul operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.MatmulAttrs + # fmt: off + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv2DAttrs") class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv2DAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv3DAttrs") class Conv3DAttrs(Attrs): """Attributes for nn.conv3d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv3DAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv2DTransposeAttrs") class Conv2DTransposeAttrs(Attrs): """Attributes for nn.conv2d_transpose""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv2DTransposeAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + output_padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + 
# fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Pool2DAttrs") class Pool2DAttrs(Attrs): """Attributes for nn.max_pool2d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Pool2DAttrs + # fmt: off + pool_size: Sequence[IntImm] + strides: Sequence[IntImm] + dilation: Sequence[IntImm] + padding: Sequence[IntImm] + ceil_mode: bool + count_include_pad: bool + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AdaptivePool2DAttrs") class AdaptivePool2DAttrs(Attrs): """Attributes for 2d adaptive pool operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AdaptivePool2DAttrs + # fmt: off + output_size: Sequence[IntImm] | None + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SoftmaxAttrs") class SoftmaxAttrs(Attrs): """Attributes for nn.softmax""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SoftmaxAttrs + # fmt: off + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.BatchNormAttrs") class BatchNormAttrs(Attrs): """Attributes used in batch_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.BatchNormAttrs + # fmt: off + axis: int + epsilon: float + center: bool + scale: bool + momentum: float + training: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.LayerNormAttrs") class LayerNormAttrs(Attrs): """Attributes used in layer_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.LayerNormAttrs + # fmt: off + axes: Sequence[IntImm] + epsilon: float + center: bool + scale: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.InstanceNormAttrs") class InstanceNormAttrs(Attrs): """Attributes used in instance_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.InstanceNormAttrs + # fmt: off + channel_axis: int + axes: Sequence[IntImm] + epsilon: float + center: bool + scale: bool + # fmt: on + # 
tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.DropoutAttrs") class DropoutAttrs(Attrs): """Attributes for dropout operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.DropoutAttrs + # fmt: off + rate: float + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.StatisticalAttrs") class StatisticalAttrs(Attrs): """Attributes used in statistical operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.StatisticalAttrs + # fmt: off + axis: Sequence[IntImm] | None + keepdims: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ConcatAttrs") class ConcatAttrs(Attrs): """Attributes for concat operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ConcatAttrs + # fmt: off + axis: int | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ExpandDimsAttrs") class ExpandDimsAttrs(Attrs): """Attributes for expand_dims operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ExpandDimsAttrs + # fmt: off + axis: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.PermuteDimsAttrs") class PermuteDimsAttrs(Attrs): """Attributes for permute_dims operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.PermuteDimsAttrs + # fmt: off + axes: Sequence[IntImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SortAttrs") class SortAttrs(Attrs): """Attributes for sort operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SortAttrs + # fmt: off + axis: int + descending: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ArgsortAttrs") class ArgsortAttrs(Attrs): """Attributes for argsort operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ArgsortAttrs + # fmt: off + axis: int + descending: bool + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SplitAttrs") class SplitAttrs(Attrs): """Attributes used in split operator""" 
+ # tvm-ffi-stubgen(begin): object/relax.attrs.SplitAttrs + # fmt: off + indices_or_sections: Object + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SqueezeAttrs") class SqueezeAttrs(Attrs): """Attributes for squeeze operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SqueezeAttrs + # fmt: off + axis: Sequence[IntImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.StackAttrs") class StackAttrs(Attrs): """Attributes for concat operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.StackAttrs + # fmt: off + axis: IntImm | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.IndexPutAttrs") class IndexPutAttrs(Attrs): """Attributes for index_put operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.IndexPutAttrs + # fmt: off + accumulate: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.LayoutTransformAttrs") class LayoutTransformAttrs(Attrs): """Attributes used in layout_transform operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.LayoutTransformAttrs + # fmt: off + index_map: IndexMap + pad_value: PrimValue | None + axis_separators: Sequence[IntImm] | None + input_axis_separators: Sequence[IntImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Resize2DAttrs") class Resize2DAttrs(Attrs): """Attributes used in image resize2d operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Resize2DAttrs + # fmt: off + roi: Sequence[FloatImm] + layout: str + method: str + coordinate_transformation_mode: str + rounding_method: str + cubic_alpha: float + cubic_exclude: int + extrapolation_value: float + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ArgmaxArgminAttrs") class ArgmaxArgminAttrs(Attrs): """Attributes for argmax/argmin operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ArgmaxArgminAttrs + # fmt: off + axis: int | 
None + keepdims: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.RepeatAttrs") class RepeatAttrs(Attrs): """Attributes for repeat operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.RepeatAttrs + # fmt: off + repeats: int + axis: int | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.TileAttrs") class TileAttrs(Attrs): """Attributes for tile operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.TileAttrs + # fmt: off + repeats: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ScanopAttrs") class ScanopAttrs(Attrs): """Attributes for scan operators""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ScanopAttrs + # fmt: off + axis: int | None + dtype: dtype + exclusive: IntImm + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.TopKAttrs") class TopKAttrs(Attrs): """Attributes for topk operators""" + # tvm-ffi-stubgen(begin): object/relax.attrs.TopKAttrs + # fmt: off + k: int + axis: int + ret_type: str + largest: bool + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.EinsumAttrs") class EinsumAttrs(Attrs): """Attributes for einsum operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.EinsumAttrs + # fmt: off + subscripts: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.FlipAttrs") class FlipAttrs(Attrs): """Attributes for flip operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.FlipAttrs + # fmt: off + axis: IntImm + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.PadAttrs") class PadAttrs(Attrs): """Attributes used in pad operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.PadAttrs + # fmt: off + pad_width: Sequence[IntImm] + pad_value: float + pad_mode: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.MultinomialFromUniformAttrs") class 
MultinomialFromUniformAttrs(Attrs): """Attributes for multinomial_from_uniform operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.MultinomialFromUniformAttrs + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.CallInplacePackedAttrs") class CallInplacePackedAttrs(Attrs): """Attributes used in call_inplace_packed operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.CallInplacePackedAttrs + # fmt: off + inplace_indices: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.CallTIRInplaceAttrs") class CallTIRInplaceAttrs(Attrs): """Attributes used in call_tir_inplace operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.CallTIRInplaceAttrs + # fmt: off + inplace_indices: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ToVDeviceAttrs") class ToVDeviceAttrs(Attrs): """Attributes used in to_vdevice operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ToVDeviceAttrs + # fmt: off + dst_vdevice: VDevice + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.HintOnDeviceAttrs") class HintOnDeviceAttrs(Attrs): """Attributes used in hint_on_device operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.HintOnDeviceAttrs + # fmt: off + device_type: int + index: int + memory_scope: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ScatterCollectiveAttrs") class ScatterCollectiveAttrs(Attrs): """Attributes used in scatter collective operators""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ScatterCollectiveAttrs + # fmt: off + num_workers: int + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AttentionAttrs") class AttentionAttrs(Attrs): """Attributes used in attention operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AttentionAttrs + # fmt: off + scale: FloatImm | None + causal_mask: str | None + window_size: 
IntImm | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AllClassNonMaximumSuppressionAttrs") class AllClassNonMaximumSuppressionAttrs(Attrs): """Attributes for vision.all_class_non_max_suppression""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AllClassNonMaximumSuppressionAttrs + # fmt: off + output_format: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv1DAttrs") class Conv1DAttrs(Attrs): """Attributes for nn.conv1d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv1DAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Conv1DTransposeAttrs") class Conv1DTransposeAttrs(Attrs): """Attributes for nn.conv1d_transpose""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Conv1DTransposeAttrs + # fmt: off + strides: Sequence[IntImm] + padding: Sequence[IntImm] + output_padding: Sequence[IntImm] + dilation: Sequence[IntImm] + groups: int + data_layout: str + kernel_layout: str + out_layout: str + out_dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Pool1DAttrs") class Pool1DAttrs(Attrs): """Attributes for nn.max_pool1d and nn.avg_pool1d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Pool1DAttrs + # fmt: off + pool_size: Sequence[IntImm] + strides: Sequence[IntImm] + dilation: Sequence[IntImm] + padding: Sequence[IntImm] + ceil_mode: bool + count_include_pad: bool + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.Pool3DAttrs") class Pool3DAttrs(Attrs): """Attributes for nn.max_pool3d and nn.avg_pool3d""" + # tvm-ffi-stubgen(begin): object/relax.attrs.Pool3DAttrs + # fmt: off + pool_size: Sequence[IntImm] + strides: Sequence[IntImm] + dilation: Sequence[IntImm] + padding: 
Sequence[IntImm] + ceil_mode: bool + count_include_pad: bool + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AdaptivePool1DAttrs") class AdaptivePool1DAttrs(Attrs): """Attributes for 1d adaptive pool operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AdaptivePool1DAttrs + # fmt: off + output_size: Sequence[IntImm] | None + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AdaptivePool3DAttrs") class AdaptivePool3DAttrs(Attrs): """Attributes for 3d adaptive pool operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AdaptivePool3DAttrs + # fmt: off + output_size: Sequence[IntImm] | None + layout: str + out_layout: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.LeakyReluAttrs") class LeakyReluAttrs(Attrs): """Attributes used in leaky_relu operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.LeakyReluAttrs + # fmt: off + alpha: float + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SoftplusAttrs") class SoftplusAttrs(Attrs): """Attributes used in softplus operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SoftplusAttrs + # fmt: off + beta: float + threshold: float + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.PReluAttrs") class PReluAttrs(Attrs): """Attributes used in prelu operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.PReluAttrs + # fmt: off + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.PixelShuffleAttrs") class PixelShuffleAttrs(Attrs): """Attributes used in pixel_shuffle operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.PixelShuffleAttrs + # fmt: off + upscale_factor: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.GroupNormAttrs") class GroupNormAttrs(Attrs): """Attributes used in group_norm operator""" + # tvm-ffi-stubgen(begin): 
object/relax.attrs.GroupNormAttrs + # fmt: off + num_groups: int + channel_axis: int + axes: Sequence[IntImm] + epsilon: float + center: bool + scale: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.RMSNormAttrs") class RMSNormAttrs(Attrs): """Attributes used in rms_norm operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.RMSNormAttrs + # fmt: off + axes: Sequence[IntImm] + epsilon: float + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.NLLLossAttrs") class NLLLossAttrs(Attrs): """Attributes used in nll_loss operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.NLLLossAttrs + # fmt: off + reduction: str + ignore_index: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AllReduceAttrs") class AllReduceAttrs(Attrs): """Attributes used in allreduce operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AllReduceAttrs + # fmt: off + op_type: str + in_group: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.AllGatherAttrs") class AllGatherAttrs(Attrs): """Attributes used in allgather operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.AllGatherAttrs + # fmt: off + num_workers: int + in_group: bool + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.WrapParamAttrs") class WrapParamAttrs(Attrs): """Attributes used in wrap_param operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.WrapParamAttrs + # fmt: off + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.QuantizeAttrs") class QuantizeAttrs(Attrs): """Attributes used in quantize/dequantize operators""" + # tvm-ffi-stubgen(begin): object/relax.attrs.QuantizeAttrs + # fmt: off + out_dtype: dtype + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.GatherElementsAttrs") class GatherElementsAttrs(Attrs): """Attributes for gather_elements operator""" + # 
tvm-ffi-stubgen(begin): object/relax.attrs.GatherElementsAttrs + # fmt: off + axis: IntImm + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.GatherNDAttrs") class GatherNDAttrs(Attrs): """Attributes for gather_nd operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.GatherNDAttrs + # fmt: off + batch_dims: IntImm + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.MeshgridAttrs") class MeshgridAttrs(Attrs): """Attributes for meshgrid operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.MeshgridAttrs + # fmt: off + indexing: str | None + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ScatterElementsAttrs") class ScatterElementsAttrs(Attrs): """Attributes for scatter_elements operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ScatterElementsAttrs + # fmt: off + axis: IntImm + reduction: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.ScatterNDAttrs") class ScatterNDAttrs(Attrs): """Attributes for scatter_nd operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.ScatterNDAttrs + # fmt: off + reduction: str + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.SliceScatterAttrs") class SliceScatterAttrs(Attrs): """Attributes for slice_scatter operator""" + # tvm-ffi-stubgen(begin): object/relax.attrs.SliceScatterAttrs + # fmt: off + axis: int + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.attrs.OneHotAttrs") class OneHotAttrs(Attrs): """Attributes for one_hot operator""" + + # tvm-ffi-stubgen(begin): object/relax.attrs.OneHotAttrs + # fmt: off + depth: int + axis: int + # fmt: on + # tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/vision/_ffi_api.py b/python/tvm/relax/op/vision/_ffi_api.py index 8af761dc5a00..4532e527298c 100644 --- a/python/tvm/relax/op/vision/_ffi_api.py +++ b/python/tvm/relax/op/vision/_ffi_api.py @@ -15,6 +15,23 @@ # specific language governing permissions and 
limitations # under the License. """Constructor APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.vision", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.vision +# fmt: off +_FFI_INIT_FUNC("relax.op.vision", __name__) +if TYPE_CHECKING: + def all_class_non_max_suppression(_0: RelaxExpr, _1: RelaxExpr, _2: RelaxExpr, _3: RelaxExpr, _4: RelaxExpr, _5: str, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/op/vm/_ffi_api.py b/python/tvm/relax/op/vm/_ffi_api.py index eed64e53f036..15f7f55903ee 100644 --- a/python/tvm/relax/op/vm/_ffi_api.py +++ b/python/tvm/relax/op/vm/_ffi_api.py @@ -14,6 +14,27 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.relax.op.vm""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import RelaxExpr + from relax.expr import DataTypeImm, PrimValue, StringImm, Tuple +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.op.vm", __name__) + +# tvm-ffi-stubgen(begin): global/relax.op.vm +# fmt: off +_FFI_INIT_FUNC("relax.op.vm", __name__) +if TYPE_CHECKING: + def alloc_storage(_0: RelaxExpr, _1: PrimValue, _2: DataTypeImm, _3: StringImm, /) -> RelaxExpr: ... + def alloc_tensor(_0: RelaxExpr, _1: PrimValue, _2: RelaxExpr, _3: DataTypeImm, /) -> RelaxExpr: ... + def call_tir_dyn(_0: RelaxExpr, _1: Tuple, /) -> RelaxExpr: ... + def kill_object(_0: RelaxExpr, /) -> RelaxExpr: ... 
+# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py index e8f6c42435da..c7f32aaa9b6f 100644 --- a/python/tvm/relax/struct_info.py +++ b/python/tvm/relax/struct_info.py @@ -16,6 +16,18 @@ # under the License. # pylint: disable=invalid-name, unused-import """The struct info nodes of the Relax language.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import EnvFunc, PrimExpr, RelaxExpr, StructInfo, VDevice + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Optional, Union import tvm_ffi @@ -33,6 +45,11 @@ class ObjectStructInfo(StructInfo): """StructInfo of an Object.""" + # tvm-ffi-stubgen(begin): object/relax.ObjectStructInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ObjectStructInfo, span) # type: ignore @@ -49,6 +66,13 @@ class PrimStructInfo(StructInfo): value. """ + # tvm-ffi-stubgen(begin): object/relax.PrimStructInfo + # fmt: off + value: PrimExpr | None + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + value: Optional[PrimExpr] dtype: str @@ -124,6 +148,13 @@ class ShapeStructInfo(StructInfo): Do not specify values and ndim at the same time. """ + # tvm-ffi-stubgen(begin): object/relax.ShapeStructInfo + # fmt: off + values: Sequence[PrimExpr] | None + ndim: int + # fmt: on + # tvm-ffi-stubgen(end) + values: Optional[List[PrimExpr]] ndim: int span: Span @@ -159,6 +190,15 @@ class TensorStructInfo(StructInfo): Do not specify shape and ndim at the same time. 
""" + # tvm-ffi-stubgen(begin): object/relax.TensorStructInfo + # fmt: off + shape: RelaxExpr | None + dtype: dtype + vdevice: VDevice | None + ndim: int + # fmt: on + # tvm-ffi-stubgen(end) + shape: Optional[Expr] dtype: str vdevice: Optional[VDevice] @@ -190,6 +230,12 @@ class TupleStructInfo(StructInfo): The struct info of the fields. """ + # tvm-ffi-stubgen(begin): object/relax.TupleStructInfo + # fmt: off + fields: Sequence[StructInfo] + # fmt: on + # tvm-ffi-stubgen(end) + fields: List[StructInfo] span: Span @@ -216,6 +262,15 @@ class FuncStructInfo(StructInfo): we still consider it impure. """ + # tvm-ffi-stubgen(begin): object/relax.FuncStructInfo + # fmt: off + params: Sequence[StructInfo] | None + ret: StructInfo + derive_func: EnvFunc | None + purity: bool + # fmt: on + # tvm-ffi-stubgen(end) + params: Optional[List[StructInfo]] ret: StructInfo derive_func: Optional[EnvFunc] diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 617ba73f09f4..5b19bafa6b05 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -17,6 +17,17 @@ # pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ """Relax transformation passes for testing""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import logging import os from typing import Dict, List, Set, Tuple @@ -86,6 +97,13 @@ class InplaceOpportunity(Object): Indices of arguments that are eligible to be used as in-place targets. 
""" + # tvm-ffi-stubgen(begin): object/relax.transform.InplaceOpportunity + # fmt: off + binding_idx: IntImm + arg_idxs: Sequence[IntImm] + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, _binding_idx, _arg_idxs): raise NotImplementedError("Constructor for InplaceOpportunity not exposed!") diff --git a/python/tvm/relax/training/_ffi_api.py b/python/tvm/relax/training/_ffi_api.py index 25f395830341..c7f1729db4e5 100644 --- a/python/tvm/relax/training/_ffi_api.py +++ b/python/tvm/relax/training/_ffi_api.py @@ -15,6 +15,24 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.relax.training""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from relax.expr import Function + from transform import Pass +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.training", __name__) + +# tvm-ffi-stubgen(begin): global/relax.training +# fmt: off +_FFI_INIT_FUNC("relax.training", __name__) +if TYPE_CHECKING: + def AppendLoss(_0: str, _1: Function, _2: int, _3: str | None, /) -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/transform/_ffi_api.py b/python/tvm/relax/transform/_ffi_api.py index 25d6ecd75385..b928aa7c64d7 100644 --- a/python/tvm/relax/transform/_ffi_api.py +++ b/python/tvm/relax/transform/_ffi_api.py @@ -14,6 +14,96 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations """FFI APIs for tvm.transform""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IRModule, IntImm, PrimExpr, RelaxExpr, StructInfo, VDevice + from relax import DataflowBlockPass, FunctionPass, MatchResult + from relax.dpl import DFPattern + from relax.expr import DataflowBlock, Function, Var + from relax.transform import FusionPattern + from tir import IndexMap, PrimFunc, Var + from transform import Pass, PassContext, PassInfo + from tvm_ffi import Object, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("relax.transform", __name__) + +# tvm-ffi-stubgen(begin): global/relax.transform +# fmt: off +_FFI_INIT_FUNC("relax.transform", __name__) +if TYPE_CHECKING: + def AdjustMatmulOrder() -> Pass: ... + def AllocateWorkspace() -> Pass: ... + def AlterOpImpl(_0: Mapping[str, PrimFunc], _1: Mapping[str, Sequence[IndexMap]], _2: Mapping[str, Sequence[Sequence[IntImm]] | None], _3: Mapping[str, Sequence[Sequence[IntImm]] | None], /) -> Pass: ... + def AnnotateTIROpPattern() -> Pass: ... + def AttachAttrLayoutFreeBuffers() -> Pass: ... + def AttachGlobalSymbol() -> Pass: ... + def BindParams(_0: str, _1: Mapping[Any, Object], /) -> Pass: ... + def BindSymbolicVars(_0: Mapping[Var | str, PrimExpr], _1: str | None, /) -> Pass: ... + def BundleModelParams(_0: str | None, /) -> Pass: ... + def CallTIRRewrite() -> Pass: ... + def CanonicalizeBindings() -> Pass: ... + def CombineParallelMatmul(_0: Callable[[Var, Sequence[Var], Sequence[Var], Mapping[Var, RelaxExpr]], bool], /) -> Pass: ... + def ComputePrimValue() -> Pass: ... + def ConvertLayout(_0: Mapping[str, Sequence[str]], /) -> Pass: ... 
+ def ConvertToDataflow(_0: int, /) -> Pass: ... + def DataflowUseInplaceCalls() -> Pass: ... + def DeadCodeElimination(_0: Sequence[str], /) -> Pass: ... + def DecomposeOpsForInference(_0: str | None, /) -> Pass: ... + def DecomposeOpsForTraining(_0: str | None, /) -> Pass: ... + def EliminateCommonSubexpr(_0: bool, /) -> Pass: ... + def ExpandMatmulOfSum() -> Pass: ... + def ExpandTupleArguments() -> Pass: ... + def FewShotTuning(_0: int, _1: bool, /) -> Pass: ... + def FoldConstant() -> Pass: ... + def FuseOps(_0: int, /) -> Pass: ... + def FuseOpsByPattern(_0: Sequence[FusionPattern], _1: bool, _2: bool, _3: Sequence[str], /) -> Pass: ... + def FuseTIR() -> Pass: ... + def FusionPattern(_0: str, _1: DFPattern, _2: Mapping[str, DFPattern], _3: Callable[..., Any] | None, _4: Callable[..., Any] | None, /) -> FusionPattern: ... + def Gradient(_0: str, _1: Sequence[Var] | None, _2: int, /) -> Pass: ... + def InlinePrivateFunctions() -> Pass: ... + def KillAfterLastUse() -> Pass: ... + def LambdaLift() -> Pass: ... + def LazyGetInput() -> Pass: ... + def LazySetOutput() -> Pass: ... + def LegalizeOps(_0: Mapping[str, Callable[..., Any]] | None, _1: Sequence[str] | None, _2: bool, /) -> Pass: ... + def LiftTransformParams(_0: IntImm | Sequence[str], /) -> Pass: ... + def LowerAllocTensor() -> Pass: ... + def LowerRuntimeBuiltin() -> Pass: ... + def MakeDataflowBlockPass(_0: Callable[[ObjectRValueRef[DataflowBlock], IRModule, PassContext], DataflowBlock], _1: PassInfo, /) -> DataflowBlockPass: ... + def MakeFunctionPass(_0: Callable[[ObjectRValueRef[Function], IRModule, PassContext], Function], _1: PassInfo, /) -> FunctionPass: ... + def MergeCompositeFunctions() -> Pass: ... + def MetaScheduleApplyDatabase(_0: str | None, _1: bool, /) -> Pass: ... + def MetaScheduleTuneIRMod(_0: Mapping[str, Tensor], _1: str, _2: IntImm, _3: IntImm | None, _4: Sequence[str] | None, /) -> Pass: ... + def MetaScheduleTuneTIR(_0: str, _1: IntImm, /) -> Pass: ... 
+ def Normalize() -> Pass: ... + def NormalizeGlobalVar() -> Pass: ... + def RealizeVDevice() -> Pass: ... + def RemovePurityChecking() -> Pass: ... + def RemoveUnusedOutputs() -> Pass: ... + def RemoveUnusedParameters() -> Pass: ... + def ReorderPermuteDimsAfterConcat() -> Pass: ... + def ReorderTakeAfterMatmul() -> Pass: ... + def RewriteCUDAGraph() -> Pass: ... + def RewriteDataflowReshape() -> Pass: ... + def RunCodegen(_0: Mapping[str, Mapping[str, Any]] | None, _1: Sequence[str], /) -> Pass: ... + def SpecializePrimFuncBasedOnCallSite() -> Pass: ... + def SplitCallTIRByPattern(_0: Sequence[PrimFunc], _1: Callable[[Sequence[MatchResult]], Sequence[Any]], /) -> Pass: ... + def SplitLayoutRewritePreproc() -> Pass: ... + def StaticPlanBlockMemory() -> Pass: ... + def ToMixedPrecision(_0: dtype, _1: Sequence[str] | None, /) -> Pass: ... + def ToNonDataflow() -> Pass: ... + def TopologicalSort(_0: str, _1: str, /) -> Pass: ... + def UpdateParamStructInfo(_0: Callable[[Var], StructInfo | None], /) -> Pass: ... + def UpdateVDevice(_0: VDevice, _1: int, /) -> Pass: ... + def VMShapeLower(_0: bool, /) -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 46efc17e3d4f..6ab829299170 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -16,6 +16,21 @@ # under the License. 
# pylint: disable=invalid-name """Relax transformation passes.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import RelaxExpr + from relax.dpl import DFPattern + from relax.expr import Var + from transform import PassInfo + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import functools import inspect import types @@ -42,11 +57,23 @@ class FunctionPass(tvm.ir.transform.Pass): pass class should be created through `function_pass`. """ + # tvm-ffi-stubgen(begin): object/relax.FunctionPass + # fmt: off + pass_info: PassInfo + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("relax.DataflowBlockPass") class DataflowBlockPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.DataflowBlock in a module.""" + # tvm-ffi-stubgen(begin): object/relax.DataflowBlockPass + # fmt: off + pass_info: PassInfo + # fmt: on + # tvm-ffi-stubgen(end) + def Gradient( func_name: str, require_grads: Optional[Union[Var, List[Var]]] = None, target_index: int = 0 @@ -847,6 +874,16 @@ class PatternCheckContext(Object): matched expression. """ + # tvm-ffi-stubgen(begin): object/relax.transform.PatternCheckContext + # fmt: off + matched_expr: RelaxExpr + annotated_expr: Mapping[str, RelaxExpr] + matched_bindings: Mapping[Var, RelaxExpr] + var_usages: Mapping[Var, Sequence[Var]] + value_to_bound_var: Mapping[RelaxExpr, Var] + # fmt: on + # tvm-ffi-stubgen(end) + matched_expr: Expr annotated_expr: Mapping[str, Expr] matched_bindings: Mapping[Var, Expr] @@ -877,6 +914,16 @@ class FusionPattern(Object): The function to check whether the match result is accepted. 
""" + # tvm-ffi-stubgen(begin): object/relax.transform.FusionPattern + # fmt: off + name: str + pattern: DFPattern + annotation_patterns: Mapping[str, DFPattern] + check: Callable[..., Any] | None + attrs_getter: Callable[..., Any] | None + # fmt: on + # tvm-ffi-stubgen(end) + name: str pattern: DFPattern annotation_patterns: Mapping[str, DFPattern] diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py index ebf757f38136..d1ba64d76223 100644 --- a/python/tvm/relax/ty.py +++ b/python/tvm/relax/ty.py @@ -16,6 +16,16 @@ # under the License. # pylint: disable=invalid-name, unused-import """The type nodes of the Relax language.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.ir import Type, TupleType, FuncType, Span @@ -32,6 +42,12 @@ class ShapeType(Type): The size of the shape. """ + # tvm-ffi-stubgen(begin): object/relax.ShapeType + # fmt: off + ndim: int + # fmt: on + # tvm-ffi-stubgen(end) + # TODO(relax-team): consider make ndim mandatory def __init__(self, ndim: int = -1, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span) # type: ignore @@ -42,6 +58,11 @@ class ObjectType(Type): """A type that corresponds to tvm::runtime::Object, is base of all possible object values in TVM.""" + # tvm-ffi-stubgen(begin): object/relax.ObjectType + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.ObjectType, span) # type: ignore @@ -61,6 +82,13 @@ class TensorType(Type): The content data type. 
""" + # tvm-ffi-stubgen(begin): object/relax.DynTensorType + # fmt: off + ndim: int + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.TensorType, ndim, dtype, span) # type: ignore @@ -69,5 +97,10 @@ def __init__(self, ndim=-1, dtype="float32", span: Span = None) -> None: class PackedFuncType(Type): """The type of ExternFunc in Relax.""" + # tvm-ffi-stubgen(begin): object/relax.PackedFuncType + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, span: Span = None) -> None: self.__init_handle_by_constructor__(_ffi_api.PackedFuncType, span) # type: ignore diff --git a/python/tvm/rpc/_ffi_api.py b/python/tvm/rpc/_ffi_api.py index 80fd79e31348..2ba1da29eabb 100644 --- a/python/tvm/rpc/_ffi_api.py +++ b/python/tvm/rpc/_ffi_api.py @@ -15,7 +15,33 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.rpc""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from tvm_ffi import Module + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("rpc", __name__) + +# tvm-ffi-stubgen(begin): global/rpc +# fmt: off +_FFI_INIT_FUNC("rpc", __name__) +if TYPE_CHECKING: + def Connect(*args: Any) -> Any: ... + def CreateEventDrivenServer(_0: Callable[..., Any], _1: str, _2: str, /) -> Callable[..., Any]: ... + def CreatePipeClient(*args: Any) -> Any: ... + def ImportRemoteModule(_0: Module, _1: Module, /) -> None: ... + def LoadRemoteModule(_0: Module, _1: str, /) -> Module: ... + def LocalSession() -> Module: ... + def ReturnException(_0: int, _1: str, /) -> None: ... + def ServerLoop(*args: Any) -> Any: ... + def SessTableIndex(*args: Any) -> Any: ... 
+# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/runtime/_ffi_api.py b/python/tvm/runtime/_ffi_api.py index c713b379c384..8774925ca0e8 100644 --- a/python/tvm/runtime/_ffi_api.py +++ b/python/tvm/runtime/_ffi_api.py @@ -15,7 +15,58 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.runtime""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ctypes import c_void_p + from tvm_ffi import Device, Module, Shape, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) + # Exports functions registered in runtime namespace. -tvm_ffi.init_ffi_api("runtime", __name__) +# tvm-ffi-stubgen(begin): global/runtime +# fmt: off +_FFI_INIT_FUNC("runtime", __name__) +if TYPE_CHECKING: + def CSourceModuleCreate(_0: str, _1: str, _2: Sequence[str] | None, _3: Sequence[str] | None, /) -> Module: ... + def Device_SetStream(_0: Device, _1: int, /) -> None: ... + def Device_StreamCreate(_0: Device, /) -> int: ... + def Device_StreamFree(_0: Device, _1: int, /) -> None: ... + def Device_StreamSync(_0: Device, _1: int, /) -> None: ... + def Device_StreamSyncFromTo(_0: Device, _1: int, _2: int, /) -> None: ... + def GetCudaDeviceCount() -> int: ... + def GetCudaFreeMemory() -> str: ... + def GetDeviceAttr(*args: Any) -> Any: ... + def LoadParams(_0: bytes, /) -> Mapping[str, Tensor]: ... + def LoadParamsFromFile(_0: str, /) -> Mapping[str, Tensor]: ... + def ModuleImportsBlobName() -> str: ... + def ModuleLoadStaticLibrary(_0: str, _1: Sequence[str], /) -> Module: ... + def ModulePackImportsToC(_0: Module, _1: bool, _2: str, /) -> str: ... + def ModulePackImportsToLLVM(_0: Module, _1: bool, _2: str, _3: str, /) -> Module: ... 
+ def ModulePackImportsToTensor(_0: Module, /) -> Tensor: ... + def NumThreads() -> int: ... + def RPCTimeEvaluator(_0: Module | None, _1: str, _2: int, _3: int, _4: int, _5: int, _6: int, _7: int, _8: int, _9: int, _10: int, _11: str, /) -> Callable[..., Any]: ... + def RuntimeEnabled(_0: str, /) -> bool: ... + def SaveParams(_0: Mapping[str, Tensor], /) -> bytes: ... + def SaveParamsToFile(_0: Mapping[str, Tensor], _1: str, /) -> None: ... + def SourceModuleCreate(_0: str, _1: str, /) -> Module: ... + def TVMSetStream(_0: int, _1: int, _2: c_void_p, /) -> None: ... + def TVMTensorAllocWithScope(_0: Shape, _1: dtype, _2: Device, _3: str | None, /) -> Tensor: ... + def TVMTensorCopyFromBytes(_0: Tensor, _1: c_void_p, _2: int, /) -> None: ... + def TVMTensorCopyFromTo(_0: Tensor, _1: Tensor, /) -> None: ... + def TVMTensorCopyToBytes(_0: Tensor, _1: c_void_p, _2: int, /) -> None: ... + def TVMTensorCreateView(_0: Tensor, _1: Shape, _2: dtype, _3: int, /) -> Tensor: ... + def _datatype_get_type_registered(*args: Any) -> Any: ... + def config_threadpool(*args: Any) -> Any: ... + def cuTensorMapEncodeTiled(*args: Any) -> Any: ... + def get_cuda_stream() -> c_void_p: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/runtime/_tensor.py b/python/tvm/runtime/_tensor.py index 3affbf55d563..68da1e0ca79d 100644 --- a/python/tvm/runtime/_tensor.py +++ b/python/tvm/runtime/_tensor.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-import, redefined-outer-name """Runtime Tensor API""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import ctypes import warnings from typing import Optional @@ -72,6 +74,11 @@ class Tensor(tvm_ffi.core.Tensor): how can we use TVM in existing project which might have their own array containers. 
""" + # tvm-ffi-stubgen(begin): object/ffi.Tensor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __setitem__(self, in_slice, value): """Set ndarray value""" if ( diff --git a/python/tvm/runtime/disco/_ffi_api.py b/python/tvm/runtime/disco/_ffi_api.py index 2caeef293ea5..f69842d447ac 100644 --- a/python/tvm/runtime/disco/_ffi_api.py +++ b/python/tvm/runtime/disco/_ffi_api.py @@ -15,6 +15,63 @@ # specific language governing permissions and limitations # under the License. """FFI APIs from C++""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from runtime.disco import DRef, Session + from tvm_ffi import Device, Module, Object, Shape, Tensor, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("runtime.disco", __name__) + +# tvm-ffi-stubgen(begin): global/runtime.disco +# fmt: off +_FFI_INIT_FUNC("runtime.disco", __name__) +if TYPE_CHECKING: + def DRefDebugCopyFrom(_0: DRef, _1: int, _2: Any, /) -> None: ... + def DRefDebugGetFromRemote(_0: DRef, _1: int, /) -> Any: ... + def RemoteSocketSession(_0: str, _1: int, _2: int, /) -> None: ... + def SessionCallPacked(*args: Any) -> Any: ... + def SessionCopyFromWorker0(_0: Session, _1: Tensor, _2: DRef, /) -> None: ... + def SessionCopyToWorker0(_0: Session, _1: Tensor, _2: DRef, /) -> None: ... + def SessionGetGlobalFunc(_0: Session, _1: str, /) -> DRef: ... + def SessionGetNumWorkers(_0: Session, /) -> int: ... + def SessionInitCCL(_0: Session, _1: str, _2: Shape, /) -> None: ... + def SessionProcess(_0: int, _1: int, _2: str, _3: str, /) -> Session: ... + def SessionShutdown(_0: Session, /) -> None: ... + def SessionSyncWorker(_0: Session, _1: int, /) -> None: ... + def SessionThreaded(_0: int, _1: int, /) -> Session: ... 
+ def ShardLoader(_0: str, _1: str, _2: str, _3: Module | None, /) -> Object: ... + def ShardLoaderLoad(_0: Object, _1: Shape, /) -> Tensor: ... + def ShardLoaderLoadAll(_0: Object, /) -> Sequence[Tensor]: ... + def ShardLoaderLoadAllPresharded(_0: Object, /) -> Sequence[Tensor]: ... + def ShardLoaderLoadParamOnWorker0(_0: Object, _1: int, /) -> Tensor: ... + def ShardLoaderLoadPresharded(_0: Object, _1: Shape, /) -> Tensor: ... + def SocketSession(_0: int, _1: int, _2: int, _3: str, _4: int, /) -> Session: ... + def WorkerProcess(_0: int, _1: int, _2: int, _3: int, _4: int, /) -> None: ... + def allgather(_0: Tensor, _1: bool, _2: Tensor, /) -> None: ... + def allreduce(_0: Tensor, _1: Shape, _2: bool, _3: Tensor, /) -> None: ... + def bind_worker_to_cpu_core(_0: Shape, /) -> None: ... + def broadcast_from_worker0(_0: Tensor, _1: bool, _2: Tensor, /) -> None: ... + def device() -> Device: ... + def empty(_0: Shape, _1: dtype, _2: Device | None, _3: bool, _4: bool, /) -> Tensor | None: ... + def gather_to_worker0(_0: Tensor, _1: bool, _2: Tensor | None, /) -> None: ... + def load_vm_module(_0: str, _1: Device | None, /) -> Module: ... + def recv_from_prev_group(_0: Tensor, /) -> None: ... + def recv_from_worker(_0: Tensor, _1: int, /) -> None: ... + def recv_from_worker0(_0: Tensor, /) -> None: ... + def scatter_from_worker0(_0: Tensor | None, _1: bool, _2: Tensor, /) -> None: ... + def send_to_next_group(_0: Tensor, /) -> None: ... + def send_to_worker(_0: Tensor, _1: int, /) -> None: ... + def socket_session_init_workers(_0: int, _1: int, _2: int, _3: int, /) -> None: ... + def worker_id() -> Shape: ... + def worker_rank() -> int: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index f2c2dfc791ab..be7791be85f5 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -18,6 +18,8 @@ with the distributed runtime. 
""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import logging import os import pickle @@ -40,6 +42,11 @@ class DRef(Object): to each object, and the worker process uses this id to refer to the object residing on itself. """ + # tvm-ffi-stubgen(begin): object/runtime.disco.DRef + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def debug_get_from_remote(self, worker_id: int) -> Any: """Get the value of a DRef from a remote worker. It is only used for debugging purposes. @@ -103,6 +110,11 @@ class Session(Object): """A Disco interactive session. It allows users to interact with the Disco command queue with various PackedFunc calling convention.""" + # tvm-ffi-stubgen(begin): object/runtime.disco.Session + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def _get_cached_method(self, name: str) -> Callable: if "_cache" not in self.__dict__: cache = self._cache = {} # pylint: disable=attribute-defined-outside-init @@ -532,6 +544,11 @@ def _clear_ipc_memory_pool(self): class ThreadedSession(Session): """A Disco session backed by multi-threading.""" + # tvm-ffi-stubgen(begin): object/runtime.disco.ThreadedSession + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, num_workers: int, num_groups: int = 1) -> None: """Create a disco session backed by multiple threads in the same process.""" self.__init_handle_by_constructor__( @@ -545,6 +562,11 @@ def __init__(self, num_workers: int, num_groups: int = 1) -> None: class ProcessSession(Session): """A Disco session backed by pipe-based multi-processing.""" + # tvm-ffi-stubgen(begin): object/runtime.disco.ProcessSession + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, num_workers: int, @@ -593,6 +615,11 @@ def _create_socket_session_local_workers(num_workers) -> Session: class SocketSession(Session): """A Disco session backed by socket-based multi-node communication.""" + # tvm-ffi-stubgen(begin): object/runtime.disco.SocketSession + # fmt: off + # fmt: on + # 
tvm-ffi-stubgen(end) + def __init__( self, num_nodes: int, diff --git a/python/tvm/runtime/profiling/__init__.py b/python/tvm/runtime/profiling/__init__.py index 3ca831ac4200..e2074eb28e41 100644 --- a/python/tvm/runtime/profiling/__init__.py +++ b/python/tvm/runtime/profiling/__init__.py @@ -16,6 +16,8 @@ # under the License. """Registration of profiling objects in python.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Dict, Sequence, Optional from ... import ffi as _ffi from . import _ffi_api @@ -35,6 +37,9 @@ class Report(Object): Per-device metrics collected over the entire run. """ + # tvm-ffi-stubgen(begin): object/runtime.profiling.Report + # tvm-ffi-stubgen(end) + def __init__( self, calls: Sequence[Dict[str, Object]], @@ -176,6 +181,9 @@ def from_json(cls, s): class Count(Object): """A integer count of something""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.Count + # tvm-ffi-stubgen(end) + def __init__(self, count: int): self.__init_handle_by_constructor__(_ffi_api.Count, count) @@ -184,6 +192,9 @@ def __init__(self, count: int): class Duration(Object): """A duration of something""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.Duration + # tvm-ffi-stubgen(end) + def __init__(self, duration: float): self.__init_handle_by_constructor__(_ffi_api.Duration, duration) @@ -192,6 +203,9 @@ def __init__(self, duration: float): class Percent(Object): """A Percent of something""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.Percent + # tvm-ffi-stubgen(end) + def __init__(self, percent: float): self.__init_handle_by_constructor__(_ffi_api.Percent, percent) @@ -200,6 +214,9 @@ def __init__(self, percent: float): class Ratio(Object): """A Ratio of two things""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.Ratio + # tvm-ffi-stubgen(end) + def __init__(self, ratio: float): self.__init_handle_by_constructor__(_ffi_api.Ratio, ratio) @@ -208,11 +225,17 @@ def __init__(self, ratio: float): class 
MetricCollector(Object): """Interface for user defined profiling metric collection.""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.MetricCollector + # tvm-ffi-stubgen(end) + @_ffi.register_object("runtime.profiling.DeviceWrapper") class DeviceWrapper(Object): """Wraps a tvm.runtime.Device""" + # tvm-ffi-stubgen(begin): object/runtime.profiling.DeviceWrapper + # tvm-ffi-stubgen(end) + def __init__(self, dev: Device): self.__init_handle_by_constructor__(_ffi_api.DeviceWrapper, dev) @@ -279,6 +302,9 @@ class PAPIMetricCollector(MetricCollector): Application Programming Interface (PAPI). """ + # tvm-ffi-stubgen(begin): object/runtime.profiling.PAPIMetricCollector + # tvm-ffi-stubgen(end) + def __init__(self, metric_names: Optional[Dict[Device, Sequence[str]]] = None): """ Parameters diff --git a/python/tvm/runtime/profiling/_ffi_api.py b/python/tvm/runtime/profiling/_ffi_api.py index 883e3ca6e778..927d0a77ab5c 100644 --- a/python/tvm/runtime/profiling/_ffi_api.py +++ b/python/tvm/runtime/profiling/_ffi_api.py @@ -15,6 +15,36 @@ # specific language governing permissions and limitations # under the License. """FFI for profiling""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from runtime.profiling import DeviceWrapper, MetricCollector, Report + from tvm_ffi import Device, Module, Object + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("runtime.profiling", __name__) + +# tvm-ffi-stubgen(begin): global/runtime.profiling +# fmt: off +_FFI_INIT_FUNC("runtime.profiling", __name__) +if TYPE_CHECKING: + def AsCSV(_0: Report, /) -> str: ... + def AsJSON(_0: Report, /) -> str: ... + def AsTable(_0: Report, _1: bool, _2: bool, _3: bool, /) -> str: ... + def Count(_0: int, /) -> Object: ... 
+ def DeviceWrapper(_0: Device, /) -> DeviceWrapper: ... + def Duration(_0: float, /) -> Object: ... + def FromJSON(_0: str, /) -> Report: ... + def Percent(_0: float, /) -> Object: ... + def ProfileFunction(_0: Module, _1: str, _2: int, _3: int, _4: int, _5: Sequence[MetricCollector], /) -> Callable[..., Any]: ... + def Ratio(_0: float, /) -> Object: ... + def Report(_0: Sequence[Mapping[str, Any]], _1: Mapping[str, Mapping[str, Any]], _2: Mapping[str, Any], /) -> Report: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 7442cd99172f..4d2d561aa1ff 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. """Configuration of TVMScript printer""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from tvm_ffi import Object, dtype + from tvm_ffi.reflection import AccessPath +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import os from typing import Dict, List, Optional, Sequence @@ -29,6 +41,31 @@ class PrinterConfig(Object): """Configuration of TVMScript printer""" + # tvm-ffi-stubgen(begin): object/script.PrinterConfig + # fmt: off + binding_names: Sequence[str] + show_meta: bool + ir_prefix: str + tir_prefix: str + relax_prefix: str + module_alias: str + buffer_dtype: dtype + int_dtype: dtype + float_dtype: dtype + verbose_expr: bool + indent_spaces: int + print_line_numbers: bool + num_context_lines: int + syntax_sugar: bool + show_object_address: bool + show_all_struct_info: bool + path_to_underline: Sequence[AccessPath] + path_to_annotate: Mapping[AccessPath, str] + obj_to_underline: Sequence[Object] + obj_to_annotate: Mapping[Object, str] + # fmt: on + # tvm-ffi-stubgen(end) + binding_names: 
Sequence[str] show_meta: bool ir_prefix: str diff --git a/python/tvm/script/_ffi_api.py b/python/tvm/script/_ffi_api.py index 1354d3f2ec2c..f866edba616a 100644 --- a/python/tvm/script/_ffi_api.py +++ b/python/tvm/script/_ffi_api.py @@ -14,7 +14,25 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from tir import Buffer, PrimFunc +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script", __name__) + +# tvm-ffi-stubgen(begin): global/script +# fmt: off +_FFI_INIT_FUNC("script", __name__) +if TYPE_CHECKING: + def Complete(_0: PrimFunc, _1: Sequence[Buffer], /) -> PrimFunc: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/_ffi_api.py b/python/tvm/script/ir_builder/_ffi_api.py index c8a9597d5292..7b60db417169 100644 --- a/python/tvm/script/ir_builder/_ffi_api.py +++ b/python/tvm/script/ir_builder/_ffi_api.py @@ -15,6 +15,34 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from script.ir_builder import IRBuilder, IRBuilderFrame + from tvm_ffi import Object + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.ir_builder", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder +# fmt: off +_FFI_INIT_FUNC("script.ir_builder", __name__) +if TYPE_CHECKING: + def IRBuilder() -> IRBuilder: ... 
+ def IRBuilderCurrent() -> IRBuilder: ... + def IRBuilderEnter(_0: IRBuilder, /) -> None: ... + def IRBuilderExit(_0: IRBuilder, /) -> None: ... + def IRBuilderFrameAddCallback(_0: IRBuilderFrame, _1: Callable[[], None], /) -> None: ... + def IRBuilderFrameEnter(_0: IRBuilderFrame, /) -> None: ... + def IRBuilderFrameExit(_0: IRBuilderFrame, /) -> None: ... + def IRBuilderGet(_0: IRBuilder, /) -> Object: ... + def IRBuilderIsInScope() -> bool: ... + def IRBuilderName(_0: str, _1: Object, /) -> Object: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/ir/_ffi_api.py b/python/tvm/script/ir_builder/ir/_ffi_api.py index e319c3d4612e..831faf434166 100644 --- a/python/tvm/script/ir_builder/ir/_ffi_api.py +++ b/python/tvm/script/ir_builder/ir/_ffi_api.py @@ -15,6 +15,34 @@ # specific language governing permissions and limitations # under the License. """FFI APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import BaseFunc, GlobalInfo, GlobalVar, VDevice + from script.ir_builder import IRModuleFrame + from tvm_ffi import Object + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.ir_builder.ir", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder.ir +# fmt: off +_FFI_INIT_FUNC("script.ir_builder.ir", __name__) +if TYPE_CHECKING: + def DeclFunction(_0: str, _1: BaseFunc, /) -> GlobalVar: ... + def DefFunction(_0: str, _1: BaseFunc, /) -> None: ... + def IRModule() -> IRModuleFrame: ... + def LookupVDevice(_0: str, _1: int, /) -> VDevice: ... + def ModuleAttrs(_0: Mapping[str, Any], _1: bool, /) -> None: ... + def ModuleGetAttr(_0: str, /) -> Object | None: ... 
+ def ModuleGlobalInfos(_0: Mapping[str, Sequence[GlobalInfo]], /) -> None: ... + def ModuleSetAttr(_0: str, _1: Object | None, _2: bool, /) -> None: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py index f6c53336ff4c..ef1b58ff3697 100644 --- a/python/tvm/script/ir_builder/relax/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -15,6 +15,42 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder.relax""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IntImm, RelaxExpr, StructInfo + from relax.expr import Var, VarBinding + from script.ir_builder.relax import BlockFrame, ElseFrame, FunctionFrame, IfFrame, SeqExprFrame, ThenFrame + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder.relax +# fmt: off +_FFI_INIT_FUNC("script.ir_builder.relax", __name__) +if TYPE_CHECKING: + def Arg(_0: str, _1: StructInfo, /) -> Var: ... + def BindingBlock() -> BlockFrame: ... + def Dataflow() -> BlockFrame: ... + def DataflowBlockOutput(_0: Sequence[Var], /) -> None: ... + def Else() -> ElseFrame: ... + def Emit(_0: RelaxExpr, _1: StructInfo | None, /) -> Var: ... + def EmitMatchCast(_0: RelaxExpr, _1: StructInfo, /) -> Var: ... + def EmitVarBinding(_0: VarBinding, /) -> Var: ... + def FuncAttrs(_0: Mapping[str, Any], /) -> None: ... + def FuncName(_0: str, /) -> None: ... + def FuncRetStructInfo(_0: StructInfo, /) -> None: ... + def FuncRetValue(_0: RelaxExpr, /) -> None: ... 
+ def Function(_0: IntImm, _1: IntImm, /) -> FunctionFrame: ... + def If(_0: RelaxExpr, /) -> IfFrame: ... + def SeqExpr() -> SeqExprFrame: ... + def Then() -> ThenFrame: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py index b82fa37e8f3f..5eed61bab3b6 100644 --- a/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py +++ b/python/tvm/script/ir_builder/relax/distributed/_ffi_api.py @@ -15,8 +15,26 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.script.ir_builder.relax.distributed""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import RelaxExpr + from relax import DTensorStructInfo + from relax.expr import Tuple +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api( - "script.ir_builder.relax.distributed", __name__ -) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder.relax.distributed +# fmt: off +_FFI_INIT_FUNC("script.ir_builder.relax.distributed", __name__) +if TYPE_CHECKING: + def call_tir_dist(_0: RelaxExpr, _1: Tuple, _2: Sequence[DTensorStructInfo], _3: RelaxExpr | None, /) -> RelaxExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/ir_builder/tir/_ffi_api.py b/python/tvm/script/ir_builder/tir/_ffi_api.py index 4385b2ec13d0..b612a32af8e5 100644 --- a/python/tvm/script/ir_builder/tir/_ffi_api.py +++ b/python/tvm/script/ir_builder/tir/_ffi_api.py @@ -15,6 +15,219 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IntImm, PrimExpr, Range, Type + from script.ir_builder.tir import AllocateConstFrame, AllocateFrame, AssertFrame, AttrFrame, BlockFrame, BlockInitFrame, DeclBufferFrame, ElseFrame, ForFrame, IfFrame, LaunchThreadFrame, LetFrame, PrimFuncFrame, RealizeFrame, ThenFrame, WhileFrame + from tir import Buffer, BufferRegion, Var + from tvm_ffi import Object, Tensor, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.ir_builder.tir", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.ir_builder.tir +# fmt: off +_FFI_INIT_FUNC("script.ir_builder.tir", __name__) +if TYPE_CHECKING: + def AllocBuffer(_0: Sequence[PrimExpr], _1: dtype, _2: Var | None, _3: Sequence[PrimExpr], _4: PrimExpr, _5: str, _6: int, _7: int, _8: str, _9: Sequence[IntImm] | None, /) -> Buffer: ... + def Allocate(_0: Sequence[PrimExpr], _1: dtype, _2: str, _3: PrimExpr | None, _4: Mapping[str, Any] | None, /) -> AllocateFrame: ... + def AllocateConst(_0: Tensor, _1: dtype, _2: Sequence[PrimExpr], _3: Mapping[str, Any] | None, /) -> AllocateConstFrame: ... + def Arg(_0: str, _1: Object, /) -> Object: ... + def Assert(_0: PrimExpr, _1: str, /) -> AssertFrame: ... + def Attr(_0: Any, _1: str, _2: PrimExpr, /) -> AttrFrame: ... + def AxisOpaque(_0: Range, _1: PrimExpr, _2: dtype, /) -> Var: ... + def AxisReduce(_0: Range, _1: PrimExpr, _2: dtype, /) -> Var: ... + def AxisRemap(_0: str, _1: Sequence[PrimExpr], _2: dtype, /) -> Sequence[Var]: ... + def AxisScan(_0: Range, _1: PrimExpr, _2: dtype, /) -> Var: ... + def AxisSpatial(_0: Range, _1: PrimExpr, _2: dtype, /) -> Var: ... 
+ def BFloat16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def BFloat16x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Block(_0: str, _1: bool, /) -> BlockFrame: ... + def BlockAttrs(_0: Mapping[str, Any], /) -> None: ... + def Boolean(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Buffer(_0: Sequence[PrimExpr], _1: dtype, _2: str, _3: Var | None, _4: Sequence[PrimExpr] | None, _5: PrimExpr | None, _6: str, _7: int, _8: int, _9: str, _10: Sequence[IntImm] | None, /) -> Buffer: ... + def BufferStore(_0: Buffer, _1: PrimExpr, _2: Sequence[PrimExpr], _3: PrimExpr | None, /) -> None: ... + def DeclBuffer(_0: Sequence[PrimExpr], _1: dtype, _2: str, _3: Var | None, _4: Sequence[PrimExpr] | None, _5: PrimExpr | None, _6: str, _7: int, _8: int, _9: str, _10: Sequence[IntImm] | None, /) -> DeclBufferFrame: ... + def Else() -> ElseFrame: ... + def EnvThread(_0: str, _1: dtype, /) -> Var: ... + def Evaluate(_0: PrimExpr, /) -> None: ... + def Float16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float16x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float32x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... 
+ def Float32x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FN(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float4E2M1FNx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float64x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FN(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E2M3FNx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FN(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float6E3M2FNx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... 
+ def Float8E3M4x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E3M4x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZ(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3B11FNUZx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FN(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZ(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNUZx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3FNx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E4M3x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... 
+ def Float8E4M3x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZ(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2FNUZx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E5M2x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNU(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8E8M0FNUx8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Float8x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def FuncAttrs(_0: Mapping[str, Any], /) -> None: ... + def FuncName(_0: str, /) -> None: ... + def FuncRet(_0: Type, /) -> Type: ... + def Grid(_0: Sequence[PrimExpr], /) -> ForFrame: ... + def Handle(_0: dtype, _1: str, _2: bool, _3: bool, /) -> Var: ... 
+ def If(_0: PrimExpr, /) -> IfFrame: ... + def Init() -> BlockInitFrame: ... + def Int16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int16x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int32x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int64x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Int8x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def LaunchThread(_0: Var | str, _1: PrimExpr, /) -> LaunchThreadFrame: ... + def LegacyLetStmt(_0: Var, _1: PrimExpr, /) -> LetFrame: ... + def LetStmt(_0: PrimExpr, _1: Type | None, _2: Var | None, /) -> LetFrame: ... 
+ def MatchBuffer(_0: Object, _1: Sequence[PrimExpr], _2: dtype, _3: Var | None, _4: Sequence[PrimExpr], _5: PrimExpr, _6: str, _7: int, _8: int, _9: str, _10: Sequence[IntImm] | None, /) -> Buffer: ... + def Parallel(_0: PrimExpr, _1: PrimExpr, _2: Mapping[str, Any] | None, _3: PrimExpr | None, /) -> ForFrame: ... + def PrimFunc(_0: bool, /) -> PrimFuncFrame: ... + def Ptr(_0: dtype, _1: str, _2: bool, /) -> PrimExpr: ... + def Reads(_0: Sequence[Object], /) -> None: ... + def Realize(_0: BufferRegion, _1: str, _2: PrimExpr, /) -> RealizeFrame: ... + def Serial(_0: PrimExpr, _1: PrimExpr, _2: Mapping[str, Any] | None, _3: PrimExpr | None, /) -> ForFrame: ... + def TensormapHandle() -> Var: ... + def Then() -> ThenFrame: ... + def ThreadBinding(_0: PrimExpr, _1: PrimExpr, _2: str, _3: Mapping[str, Any] | None, /) -> ForFrame: ... + def UInt16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt16x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt32x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt64x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... 
+ def UInt64x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x16(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x32(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x4(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x64(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def UInt8x8(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Unroll(_0: PrimExpr, _1: PrimExpr, _2: Mapping[str, Any] | None, _3: PrimExpr | None, /) -> ForFrame: ... + def Vectorized(_0: PrimExpr, _1: PrimExpr, _2: Mapping[str, Any] | None, _3: PrimExpr | None, /) -> ForFrame: ... + def Void(_0: PrimExpr | None, _1: bool, /) -> PrimExpr: ... + def Where(_0: PrimExpr, /) -> None: ... + def While(_0: PrimExpr, /) -> WhileFrame: ... + def Writes(_0: Sequence[Object], /) -> None: ... + def max(_0: PrimExpr, _1: PrimExpr, /) -> PrimExpr: ... + def min(_0: PrimExpr, _1: PrimExpr, /) -> PrimExpr: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py index 967d0d824ba2..2b1b920e46fc 100644 --- a/python/tvm/script/printer/_ffi_api.py +++ b/python/tvm/script/printer/_ffi_api.py @@ -15,6 +15,61 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.script.printer""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from script import PrinterConfig + from script.printer import AssertDoc, AssignDoc, AttrAccessDoc, CallDoc, ClassDoc, CommentDoc, DictDoc, Doc, DocStringDoc, ExprDoc, ExprStmtDoc, ForDoc, FunctionDoc, IdDoc, IfDoc, IndexDoc, LambdaDoc, ListDoc, LiteralDoc, OperationDoc, ReturnDoc, ScopeDoc, SliceDoc, StmtBlockDoc, StmtDoc, TupleDoc, WhileDoc + from tvm_ffi import Object + from tvm_ffi.reflection import AccessPath +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("script.printer", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/script.printer +# fmt: off +_FFI_INIT_FUNC("script.printer", __name__) +if TYPE_CHECKING: + def AssertDoc(_0: ExprDoc, _1: ExprDoc | None, /) -> AssertDoc: ... + def AssignDoc(_0: ExprDoc, _1: ExprDoc | None, _2: ExprDoc | None, /) -> AssignDoc: ... + def AttrAccessDoc(_0: ExprDoc, _1: str, /) -> AttrAccessDoc: ... + def CallDoc(_0: ExprDoc, _1: Sequence[ExprDoc], _2: Sequence[str], _3: Sequence[ExprDoc], /) -> CallDoc: ... + def ClassDoc(_0: IdDoc, _1: Sequence[ExprDoc], _2: Sequence[StmtDoc], /) -> ClassDoc: ... + def CommentDoc(_0: str, /) -> CommentDoc: ... + def DictDoc(_0: Sequence[ExprDoc], _1: Sequence[ExprDoc], /) -> DictDoc: ... + def DocSetSourcePaths(_0: Doc, _1: Sequence[AccessPath], /) -> None: ... + def DocStringDoc(_0: str, /) -> DocStringDoc: ... + def DocToPythonScript(_0: Doc, _1: PrinterConfig, /) -> str: ... + def ExprDocAttr(_0: ExprDoc, _1: str, /) -> ExprDoc: ... + def ExprDocCall(_0: ExprDoc, _1: Sequence[ExprDoc], _2: Sequence[str], _3: Sequence[ExprDoc], /) -> ExprDoc: ... + def ExprDocIndex(_0: ExprDoc, _1: Sequence[Doc], /) -> ExprDoc: ... 
+ def ExprStmtDoc(_0: ExprDoc, /) -> ExprStmtDoc: ... + def ForDoc(_0: ExprDoc, _1: ExprDoc, _2: Sequence[StmtDoc], /) -> ForDoc: ... + def FunctionDoc(_0: IdDoc, _1: Sequence[AssignDoc], _2: Sequence[ExprDoc], _3: ExprDoc | None, _4: Sequence[StmtDoc], /) -> FunctionDoc: ... + def IdDoc(_0: str, /) -> IdDoc: ... + def IfDoc(_0: ExprDoc, _1: Sequence[StmtDoc], _2: Sequence[StmtDoc], /) -> IfDoc: ... + def IndexDoc(_0: ExprDoc, _1: Sequence[Doc], /) -> IndexDoc: ... + def LambdaDoc(_0: Sequence[IdDoc], _1: ExprDoc, /) -> LambdaDoc: ... + def ListDoc(_0: Sequence[ExprDoc], /) -> ListDoc: ... + def LiteralDocBoolean(_0: bool, _1: AccessPath | None, /) -> LiteralDoc: ... + def LiteralDocFloat(_0: float, _1: AccessPath | None, /) -> LiteralDoc: ... + def LiteralDocInt(_0: int, _1: AccessPath | None, /) -> LiteralDoc: ... + def LiteralDocNone(_0: AccessPath | None, /) -> LiteralDoc: ... + def LiteralDocStr(_0: str, _1: AccessPath | None, /) -> LiteralDoc: ... + def OperationDoc(_0: int, _1: Sequence[ExprDoc], /) -> OperationDoc: ... + def ReprPrintRelax(_0: Object, _1: PrinterConfig, /) -> str: ... + def ReturnDoc(_0: ExprDoc, /) -> ReturnDoc: ... + def ScopeDoc(_0: ExprDoc | None, _1: ExprDoc, _2: Sequence[StmtDoc], /) -> ScopeDoc: ... + def SliceDoc(_0: ExprDoc | None, _1: ExprDoc | None, _2: ExprDoc | None, /) -> SliceDoc: ... + def StmtBlockDoc(_0: Sequence[StmtDoc], /) -> StmtBlockDoc: ... + def StmtDocSetComment(_0: StmtDoc, _1: str | None, /) -> None: ... + def TupleDoc(_0: Sequence[ExprDoc], /) -> TupleDoc: ... + def WhileDoc(_0: ExprDoc, _1: Sequence[StmtDoc], /) -> WhileDoc: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 62d8c563dd3f..544642b2752c 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -16,6 +16,18 @@ # under the License. 
"""Doc types for TVMScript Unified Printer""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from tvm_ffi.reflection import AccessPath + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from enum import IntEnum, unique from typing import Dict, List, Optional, Sequence, Tuple, Union @@ -31,11 +43,22 @@ class Doc(Object): """Base class of all Docs""" + # tvm-ffi-stubgen(begin): object/script.printer.Doc + # fmt: off + source_paths: Sequence[AccessPath] + # fmt: on + # tvm-ffi-stubgen(end) + @register_object("script.printer.ExprDoc") class ExprDoc(Doc): """Base class of all expression Docs""" + # tvm-ffi-stubgen(begin): object/script.printer.ExprDoc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def attr(self, name: str) -> "AttrAccessDoc": """ Create a doc that represents attribute access on self. @@ -106,6 +129,12 @@ def __iter__(self): class StmtDoc(Doc): """Base class of statement doc""" + # tvm-ffi-stubgen(begin): object/script.printer.StmtDoc + # fmt: off + comment: str | None + # fmt: on + # tvm-ffi-stubgen(end) + @register_object("script.printer.StmtBlockDoc") class StmtBlockDoc(Doc): @@ -115,6 +144,12 @@ class StmtBlockDoc(Doc): list of StmtDoc. 
""" + # tvm-ffi-stubgen(begin): object/script.printer.StmtBlockDoc + # fmt: off + stmts: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + stmts: Sequence[StmtDoc] def __init__(self, stmts: List[StmtDoc]): @@ -125,6 +160,12 @@ def __init__(self, stmts: List[StmtDoc]): class LiteralDoc(ExprDoc): """Doc that represents literal value""" + # tvm-ffi-stubgen(begin): object/script.printer.LiteralDoc + # fmt: off + value: Any + # fmt: on + # tvm-ffi-stubgen(end) + value: Union[str, IntImm, FloatImm, None] def __init__( @@ -166,6 +207,12 @@ def __init__( class IdDoc(ExprDoc): """Doc that represents identifier""" + # tvm-ffi-stubgen(begin): object/script.printer.IdDoc + # fmt: off + name: str + # fmt: on + # tvm-ffi-stubgen(end) + name: str def __init__(self, name: str): @@ -176,6 +223,13 @@ def __init__(self, name: str): class AttrAccessDoc(ExprDoc): """Doc that represents attribute access on an expression""" + # tvm-ffi-stubgen(begin): object/script.printer.AttrAccessDoc + # fmt: off + value: ExprDoc + name: str + # fmt: on + # tvm-ffi-stubgen(end) + value: ExprDoc name: str @@ -187,6 +241,13 @@ def __init__(self, value: ExprDoc, name: str): class IndexDoc(ExprDoc): """Doc that represents index access on an expression""" + # tvm-ffi-stubgen(begin): object/script.printer.IndexDoc + # fmt: off + value: ExprDoc + indices: Sequence[Doc] + # fmt: on + # tvm-ffi-stubgen(end) + value: ExprDoc indices: Sequence[Union[ExprDoc, "SliceDoc"]] @@ -198,6 +259,15 @@ def __init__(self, value: ExprDoc, indices: List[Union[ExprDoc, "SliceDoc"]]): class CallDoc(ExprDoc): """Doc that represents function call""" + # tvm-ffi-stubgen(begin): object/script.printer.CallDoc + # fmt: off + callee: ExprDoc + args: Sequence[ExprDoc] + kwargs_keys: Sequence[str] + kwargs_values: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + callee: ExprDoc args: Sequence[ExprDoc] kwargs_keys: Sequence[str] @@ -271,6 +341,13 @@ class OperationDoc(ExprDoc): if-then-else expression). 
""" + # tvm-ffi-stubgen(begin): object/script.printer.OperationDoc + # fmt: off + kind: int + operands: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + kind: OperationKind operands: Sequence[ExprDoc] @@ -282,6 +359,13 @@ def __init__(self, kind: OperationKind, operands: List[ExprDoc]): class LambdaDoc(ExprDoc): """Doc that represents lambda function""" + # tvm-ffi-stubgen(begin): object/script.printer.LambdaDoc + # fmt: off + args: Sequence[IdDoc] + body: ExprDoc + # fmt: on + # tvm-ffi-stubgen(end) + args: Sequence[IdDoc] body: ExprDoc @@ -293,6 +377,12 @@ def __init__(self, args: List[IdDoc], body: ExprDoc): class TupleDoc(ExprDoc): """Doc that represents tuple literal""" + # tvm-ffi-stubgen(begin): object/script.printer.TupleDoc + # fmt: off + elements: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + elements: Sequence[ExprDoc] def __init__(self, elements: List[ExprDoc]): @@ -303,6 +393,12 @@ def __init__(self, elements: List[ExprDoc]): class ListDoc(ExprDoc): """Doc that represents list literal""" + # tvm-ffi-stubgen(begin): object/script.printer.ListDoc + # fmt: off + elements: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + elements: Sequence[ExprDoc] def __init__(self, elements: List[ExprDoc]): @@ -313,6 +409,13 @@ def __init__(self, elements: List[ExprDoc]): class DictDoc(ExprDoc): """Doc that represents dict literal""" + # tvm-ffi-stubgen(begin): object/script.printer.DictDoc + # fmt: off + keys: Sequence[ExprDoc] + values: Sequence[ExprDoc] + # fmt: on + # tvm-ffi-stubgen(end) + keys: Sequence[ExprDoc] values: Sequence[ExprDoc] @@ -330,6 +433,14 @@ class SliceDoc(ExprDoc): This doc can only appear in `IndexDoc.indices`. 
""" + # tvm-ffi-stubgen(begin): object/script.printer.SliceDoc + # fmt: off + start: ExprDoc | None + stop: ExprDoc | None + step: ExprDoc | None + # fmt: on + # tvm-ffi-stubgen(end) + start: Optional[ExprDoc] stop: Optional[ExprDoc] step: Optional[ExprDoc] @@ -347,6 +458,14 @@ def __init__( class AssignDoc(StmtDoc): """Doc that represents assign statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.AssignDoc + # fmt: off + lhs: ExprDoc + rhs: ExprDoc | None + annotation: ExprDoc | None + # fmt: on + # tvm-ffi-stubgen(end) + lhs: ExprDoc rhs: Optional[ExprDoc] annotation: Optional[ExprDoc] @@ -364,6 +483,14 @@ def __init__(self, lhs: ExprDoc, rhs: Optional[ExprDoc], annotation: Optional[Ex class IfDoc(StmtDoc): """Doc that represent if-then-else statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.IfDoc + # fmt: off + predicate: ExprDoc + then_branch: Sequence[StmtDoc] + else_branch: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + predicate: ExprDoc then_branch: Sequence[StmtDoc] else_branch: Sequence[StmtDoc] @@ -381,6 +508,13 @@ def __init__(self, predicate: ExprDoc, then_branch: List[StmtDoc], else_branch: class WhileDoc(StmtDoc): """Doc that represents while statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.WhileDoc + # fmt: off + predicate: ExprDoc + body: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + predicate: ExprDoc body: Sequence[StmtDoc] @@ -392,6 +526,14 @@ def __init__(self, predicate: ExprDoc, body: List[StmtDoc]): class ForDoc(StmtDoc): """Doc that represents for statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.ForDoc + # fmt: off + lhs: ExprDoc + rhs: ExprDoc + body: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + lhs: ExprDoc rhs: ExprDoc body: Sequence[StmtDoc] @@ -411,6 +553,14 @@ class ScopeDoc(StmtDoc): """ + # tvm-ffi-stubgen(begin): object/script.printer.ScopeDoc + # fmt: off + lhs: ExprDoc | None + rhs: ExprDoc + body: Sequence[StmtDoc] + # fmt: on + # 
tvm-ffi-stubgen(end) + lhs: Optional[ExprDoc] rhs: ExprDoc body: Sequence[StmtDoc] @@ -423,6 +573,12 @@ def __init__(self, lhs: Optional[ExprDoc], rhs: ExprDoc, body: List[StmtDoc]): class ExprStmtDoc(StmtDoc): """Doc that represents an expression as statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.ExprStmtDoc + # fmt: off + expr: ExprDoc + # fmt: on + # tvm-ffi-stubgen(end) + expr: ExprDoc def __init__(self, expr: ExprDoc): @@ -433,6 +589,13 @@ def __init__(self, expr: ExprDoc): class AssertDoc(StmtDoc): """Doc that represents assert statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.AssertDoc + # fmt: off + test: ExprDoc + msg: ExprDoc | None + # fmt: on + # tvm-ffi-stubgen(end) + test: ExprDoc msg: Optional[ExprDoc] @@ -444,6 +607,12 @@ def __init__(self, test: ExprDoc, msg: Optional[ExprDoc] = None): class ReturnDoc(StmtDoc): """Doc that represents return statement.""" + # tvm-ffi-stubgen(begin): object/script.printer.ReturnDoc + # fmt: off + value: ExprDoc + # fmt: on + # tvm-ffi-stubgen(end) + value: ExprDoc def __init__(self, value: ExprDoc): @@ -454,6 +623,16 @@ def __init__(self, value: ExprDoc): class FunctionDoc(StmtDoc): """Doc that represents function definition.""" + # tvm-ffi-stubgen(begin): object/script.printer.FunctionDoc + # fmt: off + name: IdDoc + args: Sequence[AssignDoc] + decorators: Sequence[ExprDoc] + return_type: ExprDoc | None + body: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + name: IdDoc args: Sequence[AssignDoc] decorators: Sequence[ExprDoc] @@ -482,6 +661,14 @@ def __init__( class ClassDoc(StmtDoc): """Doc that represents class definition.""" + # tvm-ffi-stubgen(begin): object/script.printer.ClassDoc + # fmt: off + name: IdDoc + decorators: Sequence[ExprDoc] + body: Sequence[StmtDoc] + # fmt: on + # tvm-ffi-stubgen(end) + name: IdDoc decorators: Sequence[ExprDoc] body: Sequence[StmtDoc] @@ -499,6 +686,11 @@ def __init__(self, name: IdDoc, decorators: List[ExprDoc], body: List[StmtDoc]): class 
CommentDoc(StmtDoc): """Doc that represents comment.""" + # tvm-ffi-stubgen(begin): object/script.printer.CommentDoc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, comment: str): self.__init_handle_by_constructor__( _ffi_api.CommentDoc, comment # type: ignore # pylint: disable=no-member @@ -509,6 +701,11 @@ def __init__(self, comment: str): class DocStringDoc(StmtDoc): """Doc that represents docstring.""" + # tvm-ffi-stubgen(begin): object/script.printer.DocStringDoc + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, docs: str): self.__init_handle_by_constructor__( _ffi_api.DocStringDoc, docs # type: ignore # pylint: disable=no-member diff --git a/python/tvm/target/_ffi_api.py b/python/tvm/target/_ffi_api.py index 8b9f6c73bd4e..f858326ed940 100644 --- a/python/tvm/target/_ffi_api.py +++ b/python/tvm/target/_ffi_api.py @@ -15,7 +15,43 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.target""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IRModule + from target import Target, TargetKind, VirtualDevice + from tvm_ffi import Device, Module + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("target", __name__) + +# tvm-ffi-stubgen(begin): global/target +# fmt: off +_FFI_INIT_FUNC("target", __name__) +if TYPE_CHECKING: + def Build(_0: IRModule, _1: Target, /) -> Module: ... + def ListTargetKindOptions(_0: TargetKind, /) -> Mapping[str, str]: ... + def ListTargetKindOptionsFromName(_0: str, /) -> Mapping[str, str]: ... + def ListTargetKinds() -> Sequence[str]: ... + def Target(*args: Any) -> Any: ... + def TargetCurrent(_0: bool, /) -> Target: ... 
+ def TargetEnterScope(_0: Target, /) -> None: ... + def TargetExitScope(_0: Target, /) -> None: ... + def TargetExport(_0: Target, /) -> Mapping[str, Any]: ... + def TargetGetDeviceType(_0: Target, /) -> int: ... + def TargetGetFeature(_0: Target, _1: str, /) -> Any: ... + def TargetKindGetAttr(_0: TargetKind, _1: str, /) -> Any: ... + def TargetTagAddTag(_0: str, _1: Mapping[str, Any], _2: bool, /) -> Target: ... + def TargetTagListTags() -> Mapping[str, Target]: ... + def VirtualDevice_ForDeviceTargetAndMemoryScope(_0: Device, _1: Target, _2: str, /) -> VirtualDevice: ... + def WithHost(_0: Target, _1: Target, /) -> Target: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index eb6e25f0450c..243844f1110a 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -34,6 +34,14 @@ class TargetKind(Object): """Kind of a compilation target""" + # tvm-ffi-stubgen(begin): object/target.TargetKind + # fmt: off + name: str + default_device_type: int + default_keys: Sequence[str] + # fmt: on + # tvm-ffi-stubgen(end) + @property def options(self): """Returns the dict of available option names and types""" @@ -68,6 +76,17 @@ class Target(Object): - :py:func:`tvm.target.intel_graphics` create Intel Graphics target """ + # tvm-ffi-stubgen(begin): object/target.Target + # fmt: off + kind: TargetKind + tag: str + keys: Sequence[str] + attrs: Mapping[str, Any] + features: Mapping[str, Any] + host: Object | None + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, target, host=None): """Construct a TVM target object from 1) Raw target string diff --git a/python/tvm/target/virtual_device.py b/python/tvm/target/virtual_device.py index e509c5670750..e8494dee05f6 100644 --- a/python/tvm/target/virtual_device.py +++ b/python/tvm/target/virtual_device.py @@ -26,6 +26,15 @@ class VirtualDevice(tvm_ffi.core.Object): """A compile time representation for where data is to be stored at runtime, and how to 
compile code to compute it.""" + # tvm-ffi-stubgen(begin): object/target.VirtualDevice + # fmt: off + device_type_int: int + virtual_device_id: int + target: Target + memory_scope: str + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, device=None, target=None, memory_scope="") -> None: if device is None: # The 'unconstrained' device has device type -1 and device id -1. diff --git a/python/tvm/te/_ffi_api.py b/python/tvm/te/_ffi_api.py index 172fff01d7ff..f8b58d6f2fe9 100644 --- a/python/tvm/te/_ffi_api.py +++ b/python/tvm/te/_ffi_api.py @@ -15,7 +15,39 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.te""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import PrimExpr + from te import ComputeOp, ExternOp, Operation, ScanOp, Tensor + from tir import Buffer, IterVar, Stmt + from tvm_ffi import dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("te", __name__) + +# tvm-ffi-stubgen(begin): global/te +# fmt: off +_FFI_INIT_FUNC("te", __name__) +if TYPE_CHECKING: + def ComputeOp(_0: str, _1: str, _2: Mapping[str, Any] | None, _3: Sequence[IterVar], _4: Sequence[PrimExpr], /) -> ComputeOp: ... + def CreatePrimFunc(*args: Any) -> Any: ... + def ExternOp(_0: str, _1: str, _2: Mapping[str, Any] | None, _3: Sequence[Tensor], _4: Sequence[Buffer], _5: Sequence[Buffer], _6: Stmt, /) -> ExternOp: ... + def OpGetOutput(_0: Operation, _1: int, /) -> Tensor: ... + def OpInputTensors(_0: Operation, /) -> Sequence[Tensor]: ... + def OpNumOutputs(_0: Operation, /) -> int: ... + def Placeholder(_0: PrimExpr | Sequence[PrimExpr], _1: dtype, _2: str, /) -> Tensor: ... 
+ def ScanOp(_0: str, _1: str, _2: Mapping[str, Any] | None, _3: IterVar, _4: Sequence[Tensor], _5: Sequence[Tensor], _6: Sequence[Tensor], _7: Sequence[Tensor], /) -> ScanOp: ... + def Tensor(_0: Sequence[PrimExpr], _1: dtype, _2: Operation, _3: int, /) -> Tensor: ... + def TensorEqual(_0: Tensor, _1: Tensor, /) -> bool: ... + def TensorHash(_0: Tensor, /) -> int: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index 4ef1b67969c8..e187f467d803 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -16,6 +16,20 @@ # under the License. """Tensor class for computation declaration.""" # pylint: disable=invalid-name +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import PrimExpr + from tir import Buffer, IterVar, Stmt + from tvm_ffi import dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import tvm_ffi from tvm.runtime import Object, ObjectConvertible @@ -52,6 +66,15 @@ def dtype(self): class Tensor(DataProducer, _expr.ExprOp): """Tensor object, to construct, see function.Tensor""" + # tvm-ffi-stubgen(begin): object/te.Tensor + # fmt: off + shape: Sequence[PrimExpr] + dtype: dtype + op: Operation + value_index: int + # fmt: on + # tvm-ffi-stubgen(end) + def __call__(self, *indices): ndim = self.ndim if len(indices) != ndim: @@ -96,6 +119,14 @@ def name(self): class Operation(Object): """Represent an operation that generates a tensor""" + # tvm-ffi-stubgen(begin): object/te.Operation + # fmt: off + name: str + tag: str + attrs: Mapping[str, Any] + # fmt: on + # tvm-ffi-stubgen(end) + def output(self, index): """Get the index-th output of the operation @@ -126,22 +157,62 @@ def input_tensors(self): class PlaceholderOp(Operation): """Placeholder operation.""" + # tvm-ffi-stubgen(begin): object/te.PlaceholderOp 
+ # fmt: off + shape: Sequence[PrimExpr] + dtype: dtype + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("te.BaseComputeOp") class BaseComputeOp(Operation): """Compute operation.""" + # tvm-ffi-stubgen(begin): object/te.BaseComputeOp + # fmt: off + axis: Sequence[IterVar] + reduce_axis: Sequence[IterVar] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("te.ComputeOp") class ComputeOp(BaseComputeOp): """Scalar operation.""" + # tvm-ffi-stubgen(begin): object/te.ComputeOp + # fmt: off + body: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("te.ScanOp") class ScanOp(Operation): """Scan operation.""" + # tvm-ffi-stubgen(begin): object/te.ScanOp + # fmt: off + scan_axis: IterVar + init: Sequence[Tensor] + update: Sequence[Tensor] + state_placeholder: Sequence[Tensor] + inputs: Sequence[Tensor] + spatial_axis_: Sequence[IterVar] + # fmt: on + # tvm-ffi-stubgen(end) + @tvm_ffi.register_object("te.ExternOp") class ExternOp(Operation): """External operation.""" + + # tvm-ffi-stubgen(begin): object/te.ExternOp + # fmt: off + inputs: Sequence[Tensor] + input_placeholders: Sequence[Buffer] + output_placeholders: Sequence[Buffer] + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) diff --git a/python/tvm/testing/_ffi_api.py b/python/tvm/testing/_ffi_api.py index b7a0b59fd0e4..cde5169a567f 100644 --- a/python/tvm/testing/_ffi_api.py +++ b/python/tvm/testing/_ffi_api.py @@ -15,10 +15,98 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.testing""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IntImm, PrimExpr + from tvm_ffi import Device, Module, Object, Shape, Tensor, dtype + from tvm_ffi.testing import TestIntPair + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) + # must import testing before init_ffi_api import tvm_ffi.testing -tvm_ffi.init_ffi_api("testing", __name__) +# tvm-ffi-stubgen(begin): global/testing +# fmt: off +_FFI_INIT_FUNC("testing", __name__) +if TYPE_CHECKING: + def AcceptsArrayOfPrimExpr(_0: Sequence[PrimExpr], /) -> Object: ... + def AcceptsArrayOfVariant(_0: Sequence[Callable[..., Any] | PrimExpr], /) -> Object: ... + def AcceptsBool(_0: bool, /) -> bool: ... + def AcceptsInt(_0: int, /) -> int: ... + def AcceptsMapOfPrimExpr(_0: Mapping[Any, PrimExpr], /) -> Object: ... + def AcceptsMapReturnsMap(_0: Mapping[Any, Any], /) -> Object: ... + def AcceptsMapReturnsValue(_0: Mapping[Any, Any], _1: Any, /) -> Any: ... + def AcceptsObjectRefArray(_0: Sequence[Any], /) -> Any: ... + def AcceptsPrimExpr(_0: PrimExpr, /) -> Object: ... + def AcceptsVariant(_0: str | IntImm, /) -> str: ... + def ErrorTest(_0: int, _1: int, /) -> None: ... + def FrontendTestModule() -> Module: ... + def GetShapeElem(_0: Shape, _1: int, /) -> int: ... + def GetShapeSize(_0: Shape, /) -> int: ... + def ReturnsVariant(_0: int, /) -> str | IntImm: ... + def TestIntPairSum(_0: TestIntPair, /) -> int: ... + def add_one(_0: int, /) -> int: ... + def apply(*args: Any) -> Any: ... + def device_test(*args: Any) -> Any: ... + def dump_events() -> None: ... + def echo(*args: Any) -> Any: ... + def get_add_one_c_symbol() -> int: ... + def get_mlir_add_one_c_symbol() -> int: ... + def identity_cpp(*args: Any) -> Any: ... 
+ def make_unregistered_object() -> Object: ... + def nop(*args: Any) -> Any: ... + def object_use_count(_0: Object, /) -> int: ... + def optional_tensor_view_has_value(_0: Tensor | None, /) -> bool: ... + def record_event(*args: Any) -> Any: ... + def reset_events(*args: Any) -> Any: ... + def run_check_signal(_0: int, /) -> None: ... + def schema_arr_map_opt(_0: Sequence[int | None], _1: Mapping[str, Sequence[int]], _2: str | None, /) -> Mapping[str, Sequence[int]]: ... + def schema_id_any(_0: Any, /) -> Any: ... + def schema_id_arr(_0: Sequence[Any], /) -> Sequence[Any]: ... + def schema_id_arr_int(_0: Sequence[int], /) -> Sequence[int]: ... + def schema_id_arr_obj(_0: Sequence[Object], /) -> Sequence[Object]: ... + def schema_id_arr_str(_0: Sequence[str], /) -> Sequence[str]: ... + def schema_id_bool(_0: bool, /) -> bool: ... + def schema_id_bytes(_0: bytes, /) -> bytes: ... + def schema_id_device(_0: Device, /) -> Device: ... + def schema_id_dltensor(_0: Tensor, /) -> Tensor: ... + def schema_id_dtype(_0: dtype, /) -> dtype: ... + def schema_id_float(_0: float, /) -> float: ... + def schema_id_func(_0: Callable[..., Any], /) -> Callable[..., Any]: ... + def schema_id_func_typed(_0: Callable[[int, float, Callable[..., Any]], None], /) -> Callable[[int, float, Callable[..., Any]], None]: ... + def schema_id_int(_0: int, /) -> int: ... + def schema_id_map(_0: Mapping[Any, Any], /) -> Mapping[Any, Any]: ... + def schema_id_map_str_int(_0: Mapping[str, int], /) -> Mapping[str, int]: ... + def schema_id_map_str_obj(_0: Mapping[str, Object], /) -> Mapping[str, Object]: ... + def schema_id_map_str_str(_0: Mapping[str, str], /) -> Mapping[str, str]: ... + def schema_id_object(_0: Object, /) -> Object: ... + def schema_id_opt_int(_0: int | None, /) -> int | None: ... + def schema_id_opt_obj(_0: Object | None, /) -> Object | None: ... + def schema_id_opt_str(_0: str | None, /) -> str | None: ... + def schema_id_string(_0: str, /) -> str: ... 
+ def schema_id_tensor(_0: Tensor, /) -> Tensor: ... + def schema_id_variant_int_str(_0: int | str, /) -> int | str: ... + def schema_no_args() -> int: ... + def schema_no_args_no_return() -> None: ... + def schema_no_return(_0: int, /) -> None: ... + def schema_packed(*args: Any) -> Any: ... + def schema_tensor_view_input(_0: Tensor, /) -> None: ... + def schema_variant_mix(_0: int | str | Sequence[int], /) -> int | str | Sequence[int]: ... + def sleep_in_ffi(_0: float, /) -> None: ... + def test_check_eq_callback(*args: Any) -> Any: ... + def test_raise_error(_0: str, _1: str, /) -> None: ... + def test_wrap_callback(*args: Any) -> Any: ... + def test_wrap_callback_suppress_err(*args: Any) -> Any: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/testing/attrs.py b/python/tvm/testing/attrs.py index 4e946ce6d4b9..38c8f7bed9f6 100644 --- a/python/tvm/testing/attrs.py +++ b/python/tvm/testing/attrs.py @@ -16,6 +16,17 @@ # under the License. # pylint: disable=invalid-name, import-outside-toplevel, unused-variable """Testing utilities for attrs""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import EnvFunc, PrimExpr +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from tvm_ffi import register_object from ..ir import Attrs @@ -24,5 +35,14 @@ class TestAttrs(Attrs): """Attrs used for testing purposes""" + # tvm-ffi-stubgen(begin): object/attrs.TestAttrs + # fmt: off + axis: int + name: str + padding: Sequence[PrimExpr] + func: EnvFunc + # fmt: on + # tvm-ffi-stubgen(end) + __all__ = ["TestAttrs"] diff --git a/python/tvm/tir/_ffi_api.py b/python/tvm/tir/_ffi_api.py index 4140cda741dd..5e2cd0c2e747 100644 --- a/python/tvm/tir/_ffi_api.py +++ b/python/tvm/tir/_ffi_api.py @@ -15,7 +15,182 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.tir""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import DictAttrs, IRModule, PrimExpr, Range, RelaxExpr, Span, Type + from tir import Add, Allocate, AllocateConst, And, AssertStmt, AttrStmt, BijectiveLayout, Block, BlockDependenceInfo, BlockRealize, BlockScope, Broadcast, Buffer, BufferLoad, BufferRealize, BufferRegion, BufferStore, Call, Cast, CommReducer, DataProducer, DeclBuffer, Dependency, Div, EQ, Evaluate, FloorDiv, FloorMod, For, GE, GT, IfThenElse, IndexMap, IterVar, LE, LT, Layout, Let, LetStmt, MatchBufferRegion, Max, Min, Mod, Mul, NE, Not, Or, PrimFunc, ProducerLoad, PyStmtExprMutator, PyStmtExprVisitor, Ramp, Reduce, Select, SeqStmt, Shuffle, SizeVar, Stmt, StmtSRef, StringImm, Sub, TensorIntrin, Var, While + from tvm_ffi import Object, Tensor, dtype + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("tir", __name__) + +# tvm-ffi-stubgen(begin): global/tir +# fmt: off +_FFI_INIT_FUNC("tir", __name__) +if TYPE_CHECKING: + def Add(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Add: ... + def Allocate(_0: Var, _1: dtype, _2: Sequence[PrimExpr], _3: PrimExpr, _4: Stmt, _5: Mapping[str, Any], _6: Span, /) -> Allocate: ... + def AllocateConst(_0: Var, _1: dtype, _2: Sequence[PrimExpr], _3: Object, _4: Stmt, _5: Mapping[str, Any] | None, _6: Span, /) -> AllocateConst: ... + def And(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> And: ... + def AssertStmt(_0: PrimExpr, _1: StringImm, _2: Stmt, _3: Span, /) -> AssertStmt: ... + def AttrStmt(_0: Any, _1: str, _2: PrimExpr, _3: Stmt, _4: Span, /) -> AttrStmt: ... + def BijectiveLayout(_0: Layout, _1: Layout, /) -> BijectiveLayout: ... 
+ def BijectiveLayoutBackwardIndex(_0: BijectiveLayout, _1: Sequence[PrimExpr], /) -> Sequence[PrimExpr]: ... + def BijectiveLayoutBackwardShape(_0: BijectiveLayout, _1: Sequence[PrimExpr], /) -> Sequence[PrimExpr]: ... + def BijectiveLayoutForwardIndex(_0: BijectiveLayout, _1: Sequence[PrimExpr], /) -> Sequence[PrimExpr]: ... + def BijectiveLayoutForwardShape(_0: BijectiveLayout, _1: Sequence[PrimExpr], /) -> Sequence[PrimExpr]: ... + def Block(_0: Sequence[IterVar], _1: Sequence[BufferRegion], _2: Sequence[BufferRegion], _3: str, _4: Stmt, _5: Stmt | None, _6: Sequence[Buffer], _7: Sequence[MatchBufferRegion], _8: Mapping[str, Any], _9: Span, /) -> Block: ... + def BlockDependenceInfo(_0: IRModule, /) -> BlockDependenceInfo: ... + def BlockDependenceInfoGetBlockScope(_0: BlockDependenceInfo, _1: StmtSRef, /) -> BlockScope: ... + def BlockDependenceInfoGetSRef(_0: BlockDependenceInfo, _1: Stmt, /) -> StmtSRef | None: ... + def BlockRealize(_0: Sequence[PrimExpr], _1: PrimExpr, _2: Block, _3: Span, /) -> BlockRealize: ... + def BlockScopeGetDepsByDst(_0: BlockScope, _1: StmtSRef, /) -> Sequence[Dependency]: ... + def BlockScopeGetDepsBySrc(_0: BlockScope, _1: StmtSRef, /) -> Sequence[Dependency]: ... + def Broadcast(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Broadcast: ... + def Buffer(*args: Any) -> Any: ... + def BufferAccessPtr(_0: Buffer, _1: int, _2: dtype, _3: int, _4: PrimExpr, _5: PrimExpr | None, /) -> PrimExpr: ... + def BufferGetFlattenedBuffer(_0: Buffer, /) -> Buffer: ... + def BufferLoad(_0: Buffer, _1: Sequence[PrimExpr], _2: PrimExpr | None, _3: Span, /) -> BufferLoad: ... + def BufferOffsetOf(_0: Buffer, _1: Sequence[PrimExpr], /) -> Sequence[PrimExpr]: ... + def BufferRealize(_0: Buffer, _1: Sequence[Range], _2: PrimExpr, _3: Stmt, _4: Span, /) -> BufferRealize: ... + def BufferRegion(_0: Buffer, _1: Sequence[Range], /) -> BufferRegion: ... + def BufferStorageScope(_0: Buffer, /) -> str: ... 
+ def BufferStore(_0: Buffer, _1: PrimExpr, _2: Sequence[PrimExpr], _3: PrimExpr | None, _4: Span, /) -> BufferStore: ... + def BufferVLoad(_0: Buffer, _1: Sequence[PrimExpr], _2: dtype, _3: PrimExpr | None, /) -> PrimExpr: ... + def BufferVStore(_0: Buffer, _1: Sequence[PrimExpr], _2: PrimExpr, _3: PrimExpr | None, /) -> Stmt: ... + def Call(_0: dtype | None, _1: RelaxExpr, _2: Sequence[str | dtype | IterVar | BufferRegion | PrimExpr], _3: Span, /) -> Call: ... + def Cast(_0: dtype, _1: PrimExpr, _2: Span, /) -> Cast: ... + def CommReducer(_0: Sequence[Var], _1: Sequence[Var], _2: Sequence[PrimExpr], _3: Sequence[PrimExpr], _4: Span, /) -> CommReducer: ... + def CommReducerCombine(_0: CommReducer, _1: Sequence[PrimExpr], _2: Sequence[PrimExpr], /) -> Sequence[PrimExpr]: ... + def DeclBuffer(_0: Buffer, _1: Stmt, _2: Span, /) -> DeclBuffer: ... + def Div(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Div: ... + def EQ(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> EQ: ... + def Evaluate(_0: PrimExpr, _1: Span, /) -> Evaluate: ... + def FloorDiv(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> FloorDiv: ... + def FloorMod(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> FloorMod: ... + def For(_0: Var, _1: PrimExpr, _2: PrimExpr, _3: int, _4: Stmt, _5: IterVar | None, _6: Mapping[str, Any] | None, _7: PrimExpr | None, _8: Span, /) -> For: ... + def GE(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> GE: ... + def GT(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> GT: ... + def IRTransform(_0: Stmt, _1: Callable[..., Any], _2: Callable[..., Any], _3: Sequence[str] | None, /) -> Stmt: ... + def IfThenElse(_0: PrimExpr, _1: Stmt, _2: Stmt, _3: Span, /) -> IfThenElse: ... + def IndexMap(_0: Sequence[Var], _1: Sequence[PrimExpr], _2: IndexMap | None, /) -> IndexMap: ... + def IndexMapInverse(_0: IndexMap, _1: Sequence[Range], /) -> IndexMap: ... + def IndexMapMapIndices(_0: IndexMap, _1: Sequence[PrimExpr], /) -> Sequence[PrimExpr]: ... 
+ def IndexMapMapShape(_0: IndexMap, _1: Sequence[PrimExpr], /) -> Sequence[PrimExpr]: ... + def IndexMapMapTensor(_0: IndexMap, _1: Tensor, /) -> Tensor: ... + def IndexMapNonSurjectiveInverse(_0: IndexMap, _1: Sequence[Range], /) -> Sequence[Object]: ... + def IterVar(_0: Range, _1: Var, _2: int, _3: str, _4: Span, /) -> IterVar: ... + def LE(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> LE: ... + def LT(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> LT: ... + def Layout(_0: str, _1: dtype, /) -> Layout: ... + def LayoutFactorOf(_0: Layout, _1: str, /) -> int: ... + def LayoutGetItem(_0: Layout, _1: int, /) -> str: ... + def LayoutIndexOf(_0: Layout, _1: str, /) -> int: ... + def LayoutNdim(_0: Layout, /) -> int: ... + def Let(_0: Var, _1: PrimExpr, _2: PrimExpr, _3: Span, /) -> Let: ... + def LetStmt(_0: Var, _1: PrimExpr, _2: Stmt, _3: Span, /) -> LetStmt: ... + def MakePyStmtExprMutator(_0: Callable[..., Any], _1: Callable[..., Any], _2: Callable[..., Any], _3: Callable[..., Any], _4: Callable[..., Any], _5: Callable[..., Any], _6: Callable[..., Any], _7: Callable[..., Any], _8: Callable[..., Any], _9: Callable[..., Any], _10: Callable[..., Any], _11: Callable[..., Any], _12: Callable[..., Any], _13: Callable[..., Any], _14: Callable[..., Any], _15: Callable[..., Any], _16: Callable[..., Any], _17: Callable[..., Any], _18: Callable[..., Any], _19: Callable[..., Any], _20: Callable[..., Any], _21: Callable[..., Any], _22: Callable[..., Any], _23: Callable[..., Any], _24: Callable[..., Any], _25: Callable[..., Any], _26: Callable[..., Any], _27: Callable[..., Any], _28: Callable[..., Any], _29: Callable[..., Any], _30: Callable[..., Any], _31: Callable[..., Any], _32: Callable[..., Any], _33: Callable[..., Any], _34: Callable[..., Any], _35: Callable[..., Any], _36: Callable[..., Any], _37: Callable[..., Any], _38: Callable[..., Any], _39: Callable[..., Any], _40: Callable[..., Any], _41: Callable[..., Any], _42: Callable[..., Any], _43: Callable[..., Any], _44: 
Callable[..., Any], _45: Callable[..., Any], _46: Callable[..., Any], _47: Callable[..., Any], _48: Callable[..., Any], _49: Callable[..., Any], /) -> PyStmtExprMutator: ... + def MakePyStmtExprVisitor(_0: Callable[..., Any], _1: Callable[..., Any], _2: Callable[..., Any], _3: Callable[..., Any], _4: Callable[..., Any], _5: Callable[..., Any], _6: Callable[..., Any], _7: Callable[..., Any], _8: Callable[..., Any], _9: Callable[..., Any], _10: Callable[..., Any], _11: Callable[..., Any], _12: Callable[..., Any], _13: Callable[..., Any], _14: Callable[..., Any], _15: Callable[..., Any], _16: Callable[..., Any], _17: Callable[..., Any], _18: Callable[..., Any], _19: Callable[..., Any], _20: Callable[..., Any], _21: Callable[..., Any], _22: Callable[..., Any], _23: Callable[..., Any], _24: Callable[..., Any], _25: Callable[..., Any], _26: Callable[..., Any], _27: Callable[..., Any], _28: Callable[..., Any], _29: Callable[..., Any], _30: Callable[..., Any], _31: Callable[..., Any], _32: Callable[..., Any], _33: Callable[..., Any], _34: Callable[..., Any], _35: Callable[..., Any], _36: Callable[..., Any], _37: Callable[..., Any], _38: Callable[..., Any], _39: Callable[..., Any], _40: Callable[..., Any], _41: Callable[..., Any], _42: Callable[..., Any], _43: Callable[..., Any], _44: Callable[..., Any], _45: Callable[..., Any], _46: Callable[..., Any], _47: Callable[..., Any], _48: Callable[..., Any], _49: Callable[..., Any], /) -> PyStmtExprVisitor: ... + def MatchBufferRegion(_0: Buffer, _1: BufferRegion, /) -> MatchBufferRegion: ... + def Max(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Max: ... + def Min(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Min: ... + def Mod(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Mod: ... + def Mul(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Mul: ... + def NE(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> NE: ... + def Not(_0: PrimExpr, _1: Span, /) -> Not: ... + def Or(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Or: ... 
+ def PostOrderVisit(_0: Object, _1: Callable[..., Any], /) -> None: ... + def PreOrderVisit(_0: Object, _1: Callable[..., Any], /) -> None: ... + def PrimFunc(_0: Sequence[Var], _1: Stmt, _2: Type, _3: Mapping[Var, Buffer], _4: DictAttrs, _5: Span, /) -> PrimFunc: ... + def ProducerLoad(_0: DataProducer, _1: Sequence[PrimExpr], _2: Span, /) -> ProducerLoad: ... + def PyStmtExprMutatorDefaultVisitExpr(_0: PyStmtExprMutator, _1: PrimExpr, /) -> None: ... + def PyStmtExprMutatorDefaultVisitStmt(_0: PyStmtExprMutator, _1: Stmt, /) -> None: ... + def PyStmtExprMutatorVisitExpr(_0: PyStmtExprMutator, _1: PrimExpr, /) -> PrimExpr: ... + def PyStmtExprMutatorVisitStmt(_0: PyStmtExprMutator, _1: Stmt, /) -> Stmt: ... + def PyStmtExprVisitorDefaultVisitExpr(_0: PyStmtExprVisitor, _1: PrimExpr, /) -> None: ... + def PyStmtExprVisitorDefaultVisitStmt(_0: PyStmtExprVisitor, _1: Stmt, /) -> None: ... + def PyStmtExprVisitorVisitExpr(_0: PyStmtExprVisitor, _1: PrimExpr, /) -> None: ... + def PyStmtExprVisitorVisitStmt(_0: PyStmtExprVisitor, _1: Stmt, /) -> None: ... + def Ramp(_0: PrimExpr, _1: PrimExpr, _2: PrimExpr, _3: Span, /) -> Ramp: ... + def Reduce(_0: CommReducer, _1: Sequence[PrimExpr], _2: Sequence[IterVar], _3: PrimExpr, _4: int, _5: Sequence[PrimExpr], _6: Span, /) -> Reduce: ... + def RenewDefs(_0: PrimFunc, /) -> PrimFunc: ... + def Select(_0: PrimExpr, _1: PrimExpr, _2: PrimExpr, _3: Span, /) -> Select: ... + def SeqStmt(_0: Sequence[Stmt], _1: Span, /) -> SeqStmt: ... + def Shuffle(_0: Sequence[PrimExpr], _1: Sequence[PrimExpr], _2: Span, /) -> Shuffle: ... + def SizeVar(_0: str, _1: dtype, _2: Span, /) -> SizeVar: ... + def Specialize(_0: PrimFunc, _1: Mapping[Var, Buffer | PrimExpr], /) -> PrimFunc: ... + def StmtSRefInlineMark() -> StmtSRef: ... + def StmtSRefParent(_0: StmtSRef, /) -> StmtSRef | None: ... + def StmtSRefRootMark() -> StmtSRef: ... + def StmtSRefStmt(_0: StmtSRef, /) -> Stmt | None: ... + def StringImm(_0: str, _1: Span, /) -> StringImm: ... 
+ def Sub(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> Sub: ... + def Substitute(_0: Object, _1: Mapping[Var, PrimExpr], /) -> Object: ... + def TensorIntrin(_0: PrimFunc, _1: PrimFunc, /) -> TensorIntrin: ... + def TensorIntrinGet(_0: str, _1: bool, /) -> TensorIntrin | None: ... + def TensorIntrinRegister(_0: str, _1: TensorIntrin, _2: bool, /) -> None: ... + def Var(_0: str, _1: Any, _2: Span, /) -> Var: ... + def While(_0: PrimExpr, _1: Stmt, _2: Span, /) -> While: ... + def _OpAdd(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpAnd(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpCeilDiv(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpDiv(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpEQ(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpFloorDiv(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpFloorMod(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpGE(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpGT(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpIfThenElse(_0: PrimExpr, _1: PrimExpr, _2: PrimExpr, _3: Span, /) -> PrimExpr: ... + def _OpIndexDiv(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpIndexMod(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpLE(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpLT(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpLogAddExp(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpMax(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpMin(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpMod(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpMul(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpNE(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... 
+ def _OpOr(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpPow(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpSub(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpTruncDiv(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _OpTruncMod(_0: PrimExpr, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def _cast(_0: dtype, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def abs(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def bitwise_and(*args: Any) -> Any: ... + def bitwise_not(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def bitwise_or(*args: Any) -> Any: ... + def bitwise_xor(*args: Any) -> Any: ... + def break_loop(_0: Span, /) -> PrimExpr: ... + def ceil(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def const_true(_0: dtype, _1: Span, /) -> PrimExpr: ... + def continue_loop(_0: Span, /) -> PrimExpr: ... + def convert(_0: PrimExpr | Sequence[PrimExpr], /) -> PrimExpr | Sequence[PrimExpr]: ... + def floor(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def infinity(_0: dtype, _1: Span, /) -> PrimExpr: ... + def isfinite(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def isinf(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def isnan(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def left_shift(*args: Any) -> Any: ... + def likely(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def max_value(_0: dtype, _1: Span, /) -> PrimExpr: ... + def min_value(_0: dtype, _1: Span, /) -> PrimExpr: ... + def nearbyint(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def reinterpret(_0: dtype, _1: PrimExpr, _2: Span, /) -> PrimExpr: ... + def ret(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def right_shift(*args: Any) -> Any: ... + def round(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... + def thread_return(_0: Span, /) -> PrimExpr: ... + def trunc(_0: PrimExpr, _1: Span, /) -> PrimExpr: ... 
+# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/tir/analysis/_ffi_api.py b/python/tvm/tir/analysis/_ffi_api.py index 9e5d094c1a82..1ba289ae828b 100644 --- a/python/tvm/tir/analysis/_ffi_api.py +++ b/python/tvm/tir/analysis/_ffi_api.py @@ -15,7 +15,44 @@ # specific language governing permissions and limitations # under the License. """FFI APIs for tvm.tir.analysis""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IRModule, IntImm, PrimExpr + from tir import Block, Buffer, BufferRegion, PrimFunc, Stmt, Var + from transform import Pass + from tvm_ffi import Object + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("tir.analysis", __name__) + +# tvm-ffi-stubgen(begin): global/tir.analysis +# fmt: off +_FFI_INIT_FUNC("tir.analysis", __name__) +if TYPE_CHECKING: + def EstimateTIRFlops(_0: Object, /) -> float: ... + def GetBlockAccessRegion(_0: Block, _1: Mapping[Var, Buffer], /) -> Sequence[Sequence[BufferRegion]]: ... + def GetBlockReadWriteRegion(_0: Block, _1: Mapping[Var, Buffer], /) -> Sequence[Sequence[BufferRegion]]: ... + def OOBChecker() -> Pass: ... + def UndefinedVars(*args: Any) -> Any: ... + def VerifyWellFormed(_0: Object, _1: bool, /) -> bool: ... + def _identify_memcpy(_0: Stmt, /) -> Sequence[Object]: ... + def calculate_allocated_bytes(_0: Object, /) -> Mapping[str, Mapping[str, IntImm]]: ... + def detect_buffer_access_lca(_0: PrimFunc, /) -> Mapping[Buffer, Stmt | None]: ... + def expr_deep_equal(_0: PrimExpr, _1: PrimExpr, /) -> bool: ... + def find_anchor_block(_0: IRModule, /) -> Block | None: ... + def get_vtcm_compaction_passes() -> Sequence[Pass]: ... + def is_pure_function(_0: PrimFunc, _1: bool, /) -> bool: ... 
+ def verify_gpu_code(_0: PrimFunc, _1: Mapping[str, PrimExpr], /) -> bool: ... + def verify_memory(_0: PrimFunc, /) -> bool: ... + def verify_ssa(_0: PrimFunc, /) -> bool: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/tir/block_dependence_info.py b/python/tvm/tir/block_dependence_info.py index 7bd6b418fc72..355f47461eb7 100644 --- a/python/tvm/tir/block_dependence_info.py +++ b/python/tvm/tir/block_dependence_info.py @@ -17,6 +17,8 @@ """Define BlockDependenceInfoNode that uses the BlockScope and StmtSRef objects to store the block level dependences""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Union, Optional from tvm_ffi import register_object from tvm.ir.module import IRModule @@ -43,6 +45,11 @@ class BlockDependenceInfo(Object): sref can be directly accessed as sref->parent """ + # tvm-ffi-stubgen(begin): object/tir.BlockDependenceInfo + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + mod: IRModule def __init__(self, mod: Union[IRModule, PrimFunc]): diff --git a/python/tvm/tir/block_scope.py b/python/tvm/tir/block_scope.py index d63771fae93e..7dfbe3777fd2 100644 --- a/python/tvm/tir/block_scope.py +++ b/python/tvm/tir/block_scope.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. """Definition of two pillar data structure for TensorIR scheduling: StmtSRef, BlockScope.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from enum import IntEnum from typing import List, Optional, Union @@ -40,6 +42,12 @@ class StmtSRef(Object): uniquely determined by the TensorIR AST. 
""" + # tvm-ffi-stubgen(begin): object/tir.StmtSRef + # fmt: off + seq_index: int + # fmt: on + # tvm-ffi-stubgen(end) + seq_index: int @property @@ -102,6 +110,14 @@ class Dependency(Object): The dependency kind """ + # tvm-ffi-stubgen(begin): object/tir.Dependency + # fmt: off + src: StmtSRef + dst: StmtSRef + kind: int + # fmt: on + # tvm-ffi-stubgen(end) + src: StmtSRef dst: StmtSRef kind: DepKind @@ -124,6 +140,11 @@ class BlockScope(Object): - Child block: The scope leaf blocks under the scope root or a specific internal sref """ + # tvm-ffi-stubgen(begin): object/tir.BlockScope + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def get_deps_by_src(self, block: StmtSRef) -> List[Dependency]: """Get all dependencies whose `src` is the target`block`. diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 259017608275..2b1354803643 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -15,6 +15,19 @@ # specific language governing permissions and limitations # under the License. 
"""Abstraction for array data structures.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import IntImm, PrimExpr, Span + from tir import Var + from tvm_ffi import dtype +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from numbers import Integral import tvm_ffi @@ -40,6 +53,22 @@ class Buffer(Object, Scriptable): decl_buffer : Declare a buffer """ + # tvm-ffi-stubgen(begin): object/tir.Buffer + # fmt: off + data: Var + dtype: dtype + shape: Sequence[PrimExpr] + strides: Sequence[PrimExpr] + axis_separators: Sequence[IntImm] + elem_offset: PrimExpr + name: str + data_alignment: int + offset_factor: int + buffer_type: int + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + READ = 1 WRITE = 2 @@ -352,4 +381,8 @@ def decl_buffer( @tvm_ffi.register_object("tir.DataProducer") class DataProducer(Object): + # tvm-ffi-stubgen(begin): object/tir.DataProducer + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) pass diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index f9c0e0cdc7ce..9002cbb88119 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. 
"""Data layout.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import PrimExpr + from tir import IterVar +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import Union import tvm_ffi @@ -37,6 +49,13 @@ class Layout(Object): layout : Declare a layout """ + # tvm-ffi-stubgen(begin): object/tir.Layout + # fmt: off + name: str + axes: Sequence[IterVar] + # fmt: on + # tvm-ffi-stubgen(end) + def __len__(self): return _ffi_api.LayoutNdim(self) # type: ignore @@ -102,6 +121,17 @@ class BijectiveLayout(Object): bijective_layout : Declare a layout """ + # tvm-ffi-stubgen(begin): object/tir.BijectiveLayout + # fmt: off + src_layout: Layout + dst_layout: Layout + index_forward_rule: Sequence[PrimExpr] + index_backward_rule: Sequence[PrimExpr] + shape_forward_rule: Sequence[PrimExpr] + shape_backward_rule: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + def forward_index(self, index): """Given the indices of the src-layout, infer the dst index. diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index f5476230c19b..4e5e43a883a2 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -27,6 +27,18 @@ assert(isinstance(y, tvm.tir.Add)) assert(y.a == x) """ +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Sequence + from ir import PrimExpr, Range, RelaxExpr, Span, Type + from tir import Buffer, DataProducer +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Optional, Union import tvm_ffi @@ -365,6 +377,13 @@ class Var(PrimExprWithOp): The location of this expression in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.Var + # fmt: off + name: str + type_annotation: Type + # fmt: on + # tvm-ffi-stubgen(end) + name_hint: str type_annotation: ir.Type @@ -389,6 +408,11 @@ class SizeVar(Var): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.SizeVar + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + # pylint: disable=super-init-not-called def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore @@ -423,6 +447,15 @@ class IterVar(ExprOp, Object, Scriptable): te.reduce_axis: Create reduce axis IterVar. """ + # tvm-ffi-stubgen(begin): object/tir.IterVar + # fmt: off + dom: Range + var: Var + iter_type: int + thread_tag: str + # fmt: on + # tvm-ffi-stubgen(end) + DataPar = 0 ThreadIndex = 1 CommReduce = 2 @@ -489,6 +522,16 @@ class CommReducer(Object, Scriptable): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.CommReducer + # fmt: off + lhs: Sequence[Var] + rhs: Sequence[Var] + result: Sequence[PrimExpr] + identity_element: Sequence[PrimExpr] + span: Span + # fmt: on + # tvm-ffi-stubgen(end) + lhs: List[Var] rhs: List[Var] result: List[PrimExpr] @@ -535,6 +578,17 @@ class Reduce(PrimExprWithOp): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Reduce + # fmt: off + combiner: CommReducer + source: Sequence[PrimExpr] + init: Sequence[PrimExpr] + axis: Sequence[IterVar] + condition: PrimExpr + value_index: int + # fmt: on + # tvm-ffi-stubgen(end) + combiner: CommReducer source: List[PrimExpr] init: List[PrimExpr] @@ -574,6 +628,12 @@ class FloatImm(ConstExpr): The location of this expression in the source code. 
""" + # tvm-ffi-stubgen(begin): object/ir.FloatImm + # fmt: off + value: float + # fmt: on + # tvm-ffi-stubgen(end) + value: float def __init__(self, dtype: str, value: float, span: Optional[Span] = None) -> None: @@ -601,6 +661,12 @@ class IntImm(ConstExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/ir.IntImm + # fmt: off + value: int + # fmt: on + # tvm-ffi-stubgen(end) + value: int def __init__(self, dtype: str, value: int, span: Optional[Span] = None) -> None: @@ -640,6 +706,12 @@ class StringImm(ConstExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.StringImm + # fmt: off + value: str + # fmt: on + # tvm-ffi-stubgen(end) + value: str def __init__(self, value: str, span: Optional[Span] = None) -> None: @@ -675,6 +747,12 @@ class Cast(PrimExprWithOp): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Cast + # fmt: off + value: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + value: PrimExpr def __init__(self, dtype, value, span: Optional[Span] = None) -> None: @@ -697,6 +775,13 @@ class Add(BinaryOpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Add + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Add, a, b, span) # type: ignore @@ -717,6 +802,13 @@ class Sub(BinaryOpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Sub + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Sub, a, b, span) # type: ignore @@ -737,6 +829,13 @@ class Mul(BinaryOpExpr): The location of this expression in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.Mul + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Mul, a, b, span) # type: ignore @@ -757,6 +856,13 @@ class Div(BinaryOpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Div + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Div, a, b, span) # type: ignore @@ -777,6 +883,13 @@ class Mod(BinaryOpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Mod + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Mod, a, b, span) # type: ignore @@ -797,6 +910,13 @@ class FloorDiv(BinaryOpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.FloorDiv + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.FloorDiv, a, b, span) # type: ignore @@ -817,6 +937,13 @@ class FloorMod(BinaryOpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.FloorMod + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.FloorMod, a, b, span) # type: ignore @@ -837,6 +964,13 @@ class Min(BinaryOpExpr): The location of this expression in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.Min + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Min, a, b, span) # type: ignore @@ -857,6 +991,13 @@ class Max(BinaryOpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Max + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Max, a, b, span) # type: ignore @@ -877,6 +1018,13 @@ class EQ(CmpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.EQ + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.EQ, a, b, span) # type: ignore @@ -897,6 +1045,13 @@ class NE(CmpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.NE + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.NE, a, b, span) # type: ignore @@ -917,6 +1072,13 @@ class LT(CmpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.LT + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LT, a, b, span) # type: ignore @@ -937,6 +1099,13 @@ class LE(CmpExpr): The location of this expression in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.LE + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span) # type: ignore @@ -957,6 +1126,13 @@ class GT(CmpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.GT + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.GT, a, b, span) # type: ignore @@ -977,6 +1153,13 @@ class GE(CmpExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.GE + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.GE, a, b, span) # type: ignore @@ -997,6 +1180,13 @@ class And(LogicalExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.And + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, a: PrimExpr, b: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.And, a, b, span) # type: ignore @@ -1017,6 +1207,13 @@ class Or(LogicalExpr): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Or + # fmt: off + a: PrimExpr + b: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + a: PrimExpr b: PrimExpr @@ -1037,6 +1234,12 @@ class Not(LogicalExpr): The location of this expression in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.Not + # fmt: off + a: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + a: PrimExpr def __init__(self, a: PrimExpr, span: Optional[Span] = None) -> None: @@ -1069,6 +1272,14 @@ class Select(PrimExprWithOp): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Select + # fmt: off + condition: PrimExpr + true_value: PrimExpr + false_value: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + condition: PrimExpr true_value: PrimExpr false_value: PrimExpr @@ -1107,6 +1318,14 @@ class BufferLoad(PrimExprWithOp): loaded. The number lanes of the mask must be equal to the number of lanes being loaded. """ + # tvm-ffi-stubgen(begin): object/tir.BufferLoad + # fmt: off + buffer: Buffer + indices: Sequence[PrimExpr] + predicate: PrimExpr | None + # fmt: on + # tvm-ffi-stubgen(end) + buffer: Buffer indices: List[PrimExpr] @@ -1138,6 +1357,13 @@ class ProducerLoad(PrimExprWithOp): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.ProducerLoad + # fmt: off + producer: DataProducer + indices: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + producer: DataProducer indices: List[PrimExpr] @@ -1168,6 +1394,14 @@ class Ramp(PrimExprWithOp): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Ramp + # fmt: off + base: PrimExpr + stride: PrimExpr + lanes: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + base: PrimExpr stride: PrimExpr lanes: PrimExpr @@ -1196,6 +1430,13 @@ class Broadcast(PrimExprWithOp): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Broadcast + # fmt: off + value: PrimExpr + lanes: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + value: PrimExpr lanes: PrimExpr @@ -1219,6 +1460,13 @@ class Shuffle(PrimExprWithOp): The location of this expression in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.Shuffle + # fmt: off + vectors: Sequence[PrimExpr] + indices: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + vectors: List[PrimExpr] indices: List[PrimExpr] @@ -1261,6 +1509,13 @@ class Call(PrimExprWithOp): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Call + # fmt: off + op: RelaxExpr + args: Sequence[PrimExpr] + # fmt: on + # tvm-ffi-stubgen(end) + op: Op args: List[PrimExpr] @@ -1300,6 +1555,14 @@ class Let(PrimExprWithOp): The location of this expression in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Let + # fmt: off + var: Var + value: PrimExpr + body: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + var: Var value: PrimExpr body: PrimExpr diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 5b365e124cfc..b5c6c10ac22f 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -17,6 +17,19 @@ # pylint: disable=unrecognized-inline-option """Function data types.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import PrimExpr, Type + from tir import Buffer, Stmt, Var + from tvm_ffi import Object +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import collections import inspect from typing import Callable, List, Mapping, Optional, Tuple, Union @@ -59,6 +72,15 @@ class PrimFunc(BaseFunc, Scriptable): The location of this itervar in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.PrimFunc + # fmt: off + params: Sequence[Var] + ret_type: Type + buffer_map: Mapping[Var, Buffer] + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, params, @@ -188,6 +210,13 @@ class TensorIntrin(Object): The function of the implementation for the execution. 
""" + # tvm-ffi-stubgen(begin): object/tir.TensorIntrin + # fmt: off + desc: PrimFunc + impl: PrimFunc + # fmt: on + # tvm-ffi-stubgen(end) + def __init__(self, desc, impl): self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc, impl) @@ -249,6 +278,14 @@ class IndexMap(Object): index map. """ + # tvm-ffi-stubgen(begin): object/tir.IndexMap + # fmt: off + initial_indices: Sequence[Var] + final_indices: Sequence[PrimExpr] + inverse_index_map: Object | None + # fmt: on + # tvm-ffi-stubgen(end) + initial_indices: List[Var] final_indices: List[PrimExpr] diff --git a/python/tvm/tir/functor.py b/python/tvm/tir/functor.py index c2594835fedf..4b46ba7fef5a 100644 --- a/python/tvm/tir/functor.py +++ b/python/tvm/tir/functor.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, arguments-differ """The expression and statement functor of TIR.""" +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) from typing import Callable import tvm_ffi @@ -154,6 +156,11 @@ class _PyStmtExprVisitor(tvm_ffi.core.Object): See also: PyStmtExprVisitor, stmt_expr_visitor """ + # tvm-ffi-stubgen(begin): object/tir.PyStmtExprVisitor + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_visit_stmt: Callable = None, @@ -988,6 +995,11 @@ class _PyStmtExprMutator(tvm_ffi.core.Object): See also: stmt_expr_mutator, PyStmtExprMutator """ + # tvm-ffi-stubgen(begin): object/tir.PyStmtExprMutator + # fmt: off + # fmt: on + # tvm-ffi-stubgen(end) + def __init__( self, f_visit_stmt: Callable = None, diff --git a/python/tvm/tir/schedule/_ffi_api.py b/python/tvm/tir/schedule/_ffi_api.py index 5087112b892a..5977bba55fa5 100644 --- a/python/tvm/tir/schedule/_ffi_api.py +++ b/python/tvm/tir/schedule/_ffi_api.py @@ -15,6 +15,127 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.tir.schedule""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import FloatImm, GlobalVar, IRModule, IntImm, PrimExpr + from tir import Block, BlockRV, BlockRealize, BlockScope, Buffer, For, IndexMap, Instruction, InstructionKind, LoopRV, PrimFunc, Schedule, ScheduleState, Stmt, StmtSRef, Trace + from tir.schedule import AutoTensorizeMappingInfo, TensorizeInfo + from tvm_ffi import Object + from typing import Any, Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("tir.schedule", __name__) # pylint: disable=protected-access + +# tvm-ffi-stubgen(begin): global/tir.schedule +# fmt: off +_FFI_INIT_FUNC("tir.schedule", __name__) +if TYPE_CHECKING: + def BlockRV() -> BlockRV: ... + def CanDecomposePadding(_0: Schedule, _1: BlockRV, _2: LoopRV, /) -> bool: ... + def ConcreteSchedule(_0: IRModule, _1: int, _2: int, _3: int, _4: bool, /) -> Schedule: ... + def GetAutoTensorizeMappingInfo(_0: Schedule, _1: BlockRV, _2: PrimFunc, /) -> AutoTensorizeMappingInfo | None: ... + def GetBlockRealize(_0: Schedule, _1: BlockRV, /) -> BlockRealize: ... + def GetLoopIterType(_0: Schedule, _1: LoopRV, /) -> str: ... + def GetTensorizeLoopMapping(_0: Schedule, _1: BlockRV, _2: PrimFunc, _3: bool, /) -> TensorizeInfo | None: ... + def HasBlock(_0: Schedule, _1: str, /) -> bool: ... + def HasIfThenElse(_0: Stmt, /) -> bool: ... + def Instruction(_0: InstructionKind, _1: Sequence[Any], _2: Sequence[Any], _3: Sequence[Any], /) -> Instruction: ... + def InstructionKindGet(_0: str, /) -> InstructionKind: ... + def IsOutputBlock(_0: Schedule, _1: BlockRV, /) -> bool: ... + def IsReductionBlock(_0: Schedule, _1: BlockRV, _2: BlockRV, /) -> bool: ... + def IsSpatialPrimFunc(_0: PrimFunc, /) -> bool: ... 
+ def IsTrivialBinding(_0: Schedule, _1: BlockRV, /) -> bool: ... + def LoopRV() -> LoopRV: ... + def NormalizePrimFunc(_0: Schedule, /) -> Object | None: ... + def RegisterReducer(_0: int, _1: Callable[..., Any], _2: Callable[..., Any], /) -> None: ... + def ScheduleAddUnitLoop(_0: Schedule, _1: Object, /) -> LoopRV: ... + def ScheduleAnnotate(_0: Schedule, _1: Object, _2: str, _3: Any, /) -> None: ... + def ScheduleAnnotateBufferAccess(_0: Schedule, _1: BlockRV, _2: int, _3: int, _4: IndexMap, /) -> None: ... + def ScheduleBind(_0: Schedule, _1: LoopRV, _2: str, /) -> None: ... + def ScheduleBlockize(_0: Schedule, _1: Object, _2: bool, /) -> BlockRV: ... + def ScheduleCacheIndex(_0: Schedule, _1: BlockRV, _2: str, _3: int, /) -> Sequence[BlockRV]: ... + def ScheduleCacheInplace(_0: Schedule, _1: BlockRV, _2: int, _3: str, /) -> Sequence[BlockRV]: ... + def ScheduleCacheRead(_0: Schedule, _1: BlockRV, _2: int, _3: str, _4: Sequence[BlockRV], /) -> BlockRV: ... + def ScheduleCacheWrite(_0: Schedule, _1: BlockRV, _2: int, _3: str, _4: Sequence[BlockRV], /) -> BlockRV: ... + def ScheduleComputeAt(_0: Schedule, _1: BlockRV, _2: LoopRV, _3: bool, _4: int, /) -> None: ... + def ScheduleComputeInline(_0: Schedule, _1: BlockRV, /) -> None: ... + def ScheduleCopy(_0: Schedule, /) -> Schedule: ... + def ScheduleDecomposePadding(_0: Schedule, _1: BlockRV, _2: LoopRV, /) -> BlockRV: ... + def ScheduleDecomposeReduction(_0: Schedule, _1: BlockRV, _2: LoopRV, /) -> BlockRV: ... + def ScheduleEnterPostproc(_0: Schedule, /) -> None: ... + def ScheduleForkSeed(_0: Schedule, /) -> int: ... + def ScheduleFuse(_0: Schedule, _1: Sequence[LoopRV], _2: bool, /) -> LoopRV: ... + def ScheduleFuseReductionEpilogue(_0: Schedule, _1: BlockRV, _2: BlockRV, /) -> None: ... + def ScheduleGet(_0: Schedule, _1: Object, /) -> Object: ... + def ScheduleGetBlock(_0: Schedule, _1: str, _2: str | None, /) -> BlockRV: ... 
+ def ScheduleGetChildBlocks(_0: Schedule, _1: Object, /) -> Sequence[BlockRV]: ... + def ScheduleGetConsumers(_0: Schedule, _1: BlockRV, /) -> Sequence[BlockRV]: ... + def ScheduleGetFuncWorkingOn(_0: Schedule, /) -> GlobalVar | None: ... + def ScheduleGetLoops(_0: Schedule, _1: BlockRV, /) -> Sequence[LoopRV]: ... + def ScheduleGetMod(_0: Schedule, /) -> IRModule: ... + def ScheduleGetOutputBlocks(_0: Schedule, _1: BlockRV, /) -> Sequence[BlockRV]: ... + def ScheduleGetProducers(_0: Schedule, _1: BlockRV, /) -> Sequence[BlockRV]: ... + def ScheduleGetSRef(_0: Schedule, _1: Object, /) -> Object | None: ... + def ScheduleGetState(_0: Schedule, /) -> ScheduleState: ... + def ScheduleGetTrace(_0: Schedule, /) -> Trace | None: ... + def ScheduleLoopPartition(_0: Schedule, _1: LoopRV, _2: Sequence[PrimExpr | None], _3: bool, /) -> Sequence[LoopRV]: ... + def ScheduleMerge(_0: Schedule, _1: Sequence[LoopRV], /) -> LoopRV: ... + def SchedulePadEinsum(_0: Schedule, _1: BlockRV, _2: Sequence[IntImm], /) -> None: ... + def ScheduleParallel(_0: Schedule, _1: LoopRV, /) -> None: ... + def ScheduleRFactor(_0: Schedule, _1: LoopRV, _2: int, /) -> BlockRV: ... + def ScheduleReIndex(_0: Schedule, _1: BlockRV, _2: int, _3: int, /) -> BlockRV: ... + def ScheduleReadAt(_0: Schedule, _1: LoopRV, _2: BlockRV, _3: int, _4: str, /) -> BlockRV: ... + def ScheduleReindexCacheRead(_0: Schedule, _1: BlockRV, _2: int, _3: str, _4: IndexMap, /) -> BlockRV: ... + def ScheduleReindexCacheWrite(_0: Schedule, _1: BlockRV, _2: int, _3: str, _4: IndexMap, /) -> BlockRV: ... + def ScheduleRemoveRV(_0: Schedule, _1: Object, /) -> None: ... + def ScheduleReorder(_0: Schedule, _1: Sequence[LoopRV], /) -> None: ... + def ScheduleReorderBlockIterVar(_0: Schedule, _1: BlockRV, _2: Sequence[IntImm], /) -> None: ... + def ScheduleReverseComputeAt(_0: Schedule, _1: BlockRV, _2: LoopRV, _3: bool, _4: int, /) -> None: ... + def ScheduleReverseComputeInline(_0: Schedule, _1: BlockRV, /) -> None: ... 
+ def ScheduleRollingBuffer(_0: Schedule, _1: BlockRV, _2: int, /) -> None: ... + def ScheduleSampleCategorical(_0: Schedule, _1: Sequence[IntImm], _2: Sequence[FloatImm], _3: IntImm | None, /) -> PrimExpr: ... + def ScheduleSampleComputeLocation(_0: Schedule, _1: BlockRV, _2: IntImm | None, /) -> LoopRV: ... + def ScheduleSamplePartitionedTile(_0: Schedule, _1: LoopRV, _2: int, _3: int, _4: int, _5: Sequence[IntImm] | None, /) -> Sequence[PrimExpr]: ... + def ScheduleSamplePerfectTile(_0: Schedule, _1: LoopRV, _2: int, _3: int, _4: Sequence[IntImm] | None, /) -> Sequence[PrimExpr]: ... + def ScheduleSeed(_0: Schedule, _1: int, /) -> None: ... + def ScheduleSetAxisSeparator(_0: Schedule, _1: BlockRV, _2: int, _3: int, _4: Sequence[IntImm], /) -> None: ... + def ScheduleSetScope(_0: Schedule, _1: BlockRV, _2: int, _3: str, /) -> None: ... + def ScheduleSplit(_0: Schedule, _1: LoopRV, _2: Sequence[PrimExpr | None], _3: bool, _4: bool, /) -> Sequence[LoopRV]: ... + def ScheduleState(_0: IRModule, _1: int, _2: bool, /) -> ScheduleState: ... + def ScheduleStateGetBlockScope(_0: ScheduleState, _1: StmtSRef, /) -> BlockScope: ... + def ScheduleStateGetCachedFlags(_0: ScheduleState, _1: StmtSRef, /) -> Sequence[IntImm]: ... + def ScheduleStateGetSRef(_0: ScheduleState, _1: Stmt, /) -> StmtSRef | None: ... + def ScheduleStateReplace(_0: ScheduleState, _1: StmtSRef, _2: Stmt, _3: Mapping[Block, Block], /) -> None: ... + def ScheduleStorageAlign(_0: Schedule, _1: BlockRV, _2: int, _3: int, _4: int, _5: int, /) -> None: ... + def ScheduleTensorize(_0: Schedule, _1: Object, _2: str, _3: bool, /) -> None: ... + def ScheduleTransformBlockLayout(_0: Schedule, _1: BlockRV, _2: IndexMap, /) -> None: ... + def ScheduleTransformLayout(_0: Schedule, _1: BlockRV, _2: int, _3: int, _4: IndexMap, _5: IndexMap | None, _6: bool, /) -> None: ... + def ScheduleUnannotate(_0: Schedule, _1: Object, _2: str, /) -> None: ... + def ScheduleUnroll(_0: Schedule, _1: LoopRV, /) -> None: ... 
+ def ScheduleUnsafeHideBufferAccess(_0: Schedule, _1: BlockRV, _2: str, _3: Sequence[IntImm], /) -> None: ... + def ScheduleUnsafeSetDType(_0: Schedule, _1: BlockRV, _2: int, _3: str, /) -> None: ... + def ScheduleVectorize(_0: Schedule, _1: LoopRV, /) -> None: ... + def ScheduleWorkOn(_0: Schedule, _1: str, /) -> None: ... + def ScheduleWriteAt(_0: Schedule, _1: LoopRV, _2: BlockRV, _3: int, _4: str, /) -> BlockRV: ... + def SuggestIndexMap(_0: Buffer, _1: Sequence[PrimExpr], _2: Sequence[For], _3: PrimExpr, /) -> IndexMap | None: ... + def TileWithTensorIntrin(_0: Schedule, _1: BlockRV, _2: str, _3: bool, /) -> LoopRV | None: ... + def Trace(_0: Sequence[Instruction] | None, _1: Mapping[Instruction, Any] | None, /) -> Trace: ... + def TraceAppend(_0: Trace, _1: Instruction, _2: Object | None, /) -> None: ... + def TraceApplyJSONToSchedule(_0: Object, _1: Schedule, /) -> None: ... + def TraceApplyToSchedule(_0: Trace, _1: Schedule, _2: bool, _3: Callable[[Instruction, Sequence[Any], Sequence[Any], Any], Any], /) -> None: ... + def TraceAsJSON(_0: Trace, _1: bool, /) -> Object: ... + def TraceAsPython(_0: Trace, _1: bool, /) -> Sequence[str]: ... + def TraceGetDecision(_0: Trace, _1: Instruction, /) -> Any: ... + def TracePop(_0: Trace, /) -> Instruction | None: ... + def TraceSimplified(_0: Trace, _1: bool, /) -> Trace: ... + def TraceWithDecision(_0: Trace, _1: Instruction, _2: Any, _3: bool, /) -> Trace: ... + def TracedSchedule(_0: IRModule, _1: int, _2: int, _3: int, _4: bool, /) -> Schedule: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 66eab497eb5a..03a67d51d67a 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -15,6 +15,18 @@ # specific language governing permissions and limitations # under the License. 
"""Analysis used in TensorIR scheduling""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IntImm, PrimExpr + from tir import Buffer, For, IndexMap, IterVar, StmtSRef +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from typing import List, Optional import tvm_ffi @@ -66,6 +78,14 @@ def suggest_index_map( class TensorizeInfo(Object): """Necessary information used for tensorization.""" + # tvm-ffi-stubgen(begin): object/tir.schedule.TensorizeInfo + # fmt: off + loop_map: Mapping[StmtSRef, For] + desc_loop_indexer: Mapping[For, IntImm] + block_iter_paddings: Sequence[IntImm] | None + # fmt: on + # tvm-ffi-stubgen(end) + def get_tensorize_loop_mapping( sch: Schedule, block: BlockRV, desc_func: PrimFunc, allow_padding: bool = False @@ -94,6 +114,16 @@ def get_tensorize_loop_mapping( class AutoTensorizeMappingInfo(Object): """Necessary information used to perform transformations for tensorization.""" + # tvm-ffi-stubgen(begin): object/tir.schedule.AutoTensorizeMappingInfo + # fmt: off + mappings: Sequence[IndexMap] + lhs_buffer_map: Mapping[Buffer, Buffer] + rhs_buffer_indices: Mapping[Buffer, Sequence[PrimExpr]] + lhs_iters: Sequence[IterVar] + rhs_iters: Sequence[IterVar] + # fmt: on + # tvm-ffi-stubgen(end) + def get_auto_tensorize_mapping_info( sch: Schedule, block: BlockRV, desc_func: PrimFunc diff --git a/python/tvm/tir/schedule/state.py b/python/tvm/tir/schedule/state.py index 36436fe95783..f321bd44a53c 100644 --- a/python/tvm/tir/schedule/state.py +++ b/python/tvm/tir/schedule/state.py @@ -16,6 +16,16 @@ # under the License. 
# pylint: disable=invalid-name """This file defines ScheduleState, the core data structure of TensorIR scheduling.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from ir import IRModule +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from collections import namedtuple from enum import IntEnum from typing import Dict, Optional, Union @@ -102,6 +112,14 @@ class ScheduleState(Object): defaults to `True`. """ + # tvm-ffi-stubgen(begin): object/tir.ScheduleState + # fmt: off + mod: IRModule + debug_mask: int + enable_check: bool + # fmt: on + # tvm-ffi-stubgen(end) + mod: IRModule debug_mask: int diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 448ace3ade63..d56aec8575fb 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -26,6 +26,20 @@ assert isinstance(st, tvm.tir.stmt.BufferStore) assert(st.buffer == buffer) """ +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + from ir import IntImm, PrimExpr, Range + from tir import Buffer, IterVar, Stmt, Var + from tvm_ffi import Tensor, dtype + from typing import Any +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) from enum import IntEnum from typing import List, Mapping, Optional, Union @@ -61,6 +75,14 @@ class LetStmt(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.LetStmt + # fmt: off + var: Var + value: PrimExpr + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) + var: Var value: PrimExpr body: Stmt @@ -91,6 +113,14 @@ class AssertStmt(Stmt): The location of the stmt in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.AssertStmt + # fmt: off + condition: PrimExpr + message: PrimExpr + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) + condition: PrimExpr message: PrimExpr body: Stmt @@ -156,6 +186,19 @@ class For(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.For + # fmt: off + loop_var: Var + min: PrimExpr + extent: PrimExpr + kind: int + body: Stmt + thread_binding: IterVar | None + annotations: Mapping[str, Any] + step: PrimExpr | None + # fmt: on + # tvm-ffi-stubgen(end) + loop_var: Var min: PrimExpr extent: PrimExpr @@ -208,6 +251,13 @@ class While(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.While + # fmt: off + condition: PrimExpr + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) + condition: PrimExpr body: Stmt span: Optional[Span] @@ -240,6 +290,15 @@ class BufferStore(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.BufferStore + # fmt: off + buffer: Buffer + value: PrimExpr + indices: Sequence[PrimExpr] + predicate: PrimExpr | None + # fmt: on + # tvm-ffi-stubgen(end) + buffer: Buffer value: PrimExpr indices: List[PrimExpr] @@ -281,6 +340,15 @@ class BufferRealize(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.BufferRealize + # fmt: off + buffer: Buffer + bounds: Sequence[Range] + condition: PrimExpr + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) + buffer: Buffer bounds: List[Range] condition: PrimExpr @@ -328,6 +396,17 @@ class Allocate(Stmt): The location of the stmt in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.Allocate + # fmt: off + buffer_var: Var + dtype: dtype + extents: Sequence[PrimExpr] + condition: PrimExpr + body: Stmt + annotations: Mapping[str, Any] + # fmt: on + # tvm-ffi-stubgen(end) + buffer_var: Var dtype: str extents: List[PrimExpr] @@ -391,6 +470,18 @@ class AllocateConst(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.AllocateConst + # fmt: off + buffer_var: Var + data: Tensor | None + irmod_storage_idx: IntImm | None + dtype: dtype + extents: Sequence[PrimExpr] + body: Stmt + annotations: Mapping[str, Any] + # fmt: on + # tvm-ffi-stubgen(end) + buffer_var: Var dtype: str extents: List[PrimExpr] @@ -438,6 +529,13 @@ class DeclBuffer(Stmt): The location of this DeclBuffer in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.DeclBuffer + # fmt: off + buffer: Buffer + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) + buffer: Buffer body: Stmt span: Optional[Span] @@ -468,6 +566,15 @@ class AttrStmt(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.AttrStmt + # fmt: off + node: Any + attr_key: str + value: PrimExpr + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) + node: Object attr_key: str value: PrimExpr @@ -495,6 +602,12 @@ class SeqStmt(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.SeqStmt + # fmt: off + seq: Sequence[Stmt] + # fmt: on + # tvm-ffi-stubgen(end) + seq: List[Stmt] span: Optional[Span] @@ -527,6 +640,14 @@ class IfThenElse(Stmt): The location of the stmt in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.IfThenElse + # fmt: off + condition: PrimExpr + then_case: Stmt + else_case: Stmt | None + # fmt: on + # tvm-ffi-stubgen(end) + condition: PrimExpr then_case: Stmt else_case: Optional[Stmt] @@ -556,6 +677,12 @@ class Evaluate(Stmt): The location of the stmt in the source code. 
""" + # tvm-ffi-stubgen(begin): object/tir.Evaluate + # fmt: off + value: PrimExpr + # fmt: on + # tvm-ffi-stubgen(end) + value: PrimExpr span: Optional[Span] @@ -576,6 +703,13 @@ class BufferRegion(Object, Scriptable): The region array of the buffer region """ + # tvm-ffi-stubgen(begin): object/tir.BufferRegion + # fmt: off + buffer: Buffer + region: Sequence[Range] + # fmt: on + # tvm-ffi-stubgen(end) + buffer: Buffer region: List[Range] @@ -596,6 +730,13 @@ class MatchBufferRegion(Object, Scriptable): The region of source buffer """ + # tvm-ffi-stubgen(begin): object/tir.MatchBufferRegion + # fmt: off + buffer: Buffer + source: BufferRegion + # fmt: on + # tvm-ffi-stubgen(end) + buffer: Buffer source: BufferRegion @@ -642,6 +783,20 @@ class Block(Stmt): The location of this block in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.Block + # fmt: off + iter_vars: Sequence[IterVar] + reads: Sequence[BufferRegion] + writes: Sequence[BufferRegion] + name_hint: str + alloc_buffers: Sequence[Buffer] + match_buffers: Sequence[MatchBufferRegion] + annotations: Mapping[str, Any] + init: Stmt | None + body: Stmt + # fmt: on + # tvm-ffi-stubgen(end) + iter_vars: List[IterVar] reads: List[BufferRegion] writes: List[BufferRegion] @@ -706,6 +861,14 @@ class BlockRealize(Stmt): The location of this block_realize in the source code. """ + # tvm-ffi-stubgen(begin): object/tir.BlockRealize + # fmt: off + iter_values: Sequence[PrimExpr] + predicate: PrimExpr + block: Block + # fmt: on + # tvm-ffi-stubgen(end) + iter_values: List[PrimExpr] predicate: PrimExpr block: Block diff --git a/python/tvm/tir/transform/_ffi_api.py b/python/tvm/tir/transform/_ffi_api.py index 67896ec05dda..89cc0bcf9f4c 100644 --- a/python/tvm/tir/transform/_ffi_api.py +++ b/python/tvm/tir/transform/_ffi_api.py @@ -15,7 +15,109 @@ # specific language governing permissions and limitations # under the License. 
"""FFI APIs for tvm.tir.transform""" -import tvm_ffi +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Mapping + from ir import IRModule, PrimExpr + from target import Target + from tir import IterVar, PrimFunc, PrimFuncPass + from transform import Pass, PassContext, PassInfo + from typing import Callable +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) -tvm_ffi.init_ffi_api("tir.transform", __name__) + +# tvm-ffi-stubgen(begin): global/tir.transform +# fmt: off +_FFI_INIT_FUNC("tir.transform", __name__) +if TYPE_CHECKING: + def AnnotateDeviceRegions() -> Pass: ... + def AnnotateEntryFunc() -> Pass: ... + def AnnotateIrregularLoop() -> Pass: ... + def BF16ComputeLegalize() -> Pass: ... + def BF16StorageLegalize() -> Pass: ... + def BindTarget(_0: Target, /) -> Pass: ... + def CanonicalizeLoop() -> Pass: ... + def CombineContextCall() -> Pass: ... + def CommonSubexprElimTIR(_0: bool, _1: bool, /) -> Pass: ... + def CompactBufferAllocation(_0: bool, /) -> Pass: ... + def ConvertBlocksToOpaque() -> Pass: ... + def ConvertForLoopsToSerial() -> Pass: ... + def ConvertSSA() -> Pass: ... + def CreatePrimFuncPass(_0: Callable[[ObjectRValueRef[PrimFunc], IRModule, PassContext], PrimFunc], _1: PassInfo, /) -> PrimFuncPass: ... + def DecorateDeviceScope() -> Pass: ... + def DefaultGPUSchedule() -> Pass: ... + def ExtractPrimFuncConstants() -> Pass: ... + def FP8ComputeLegalize(_0: str, /) -> Pass: ... + def FP8StorageLegalize() -> Pass: ... + def Filter(_0: Callable[[PrimFunc], bool], /) -> Pass: ... + def FlattenBuffer() -> Pass: ... + def ForceNarrowIndexToInt32() -> Pass: ... + def HoistExpression() -> Pass: ... + def HoistIfThenElse() -> Pass: ... + def HoistIfThenElseBasic() -> Pass: ... + def InferFragment() -> Pass: ... + def InjectDoubleBuffer() -> Pass: ... 
+ def InjectPTXAsyncCopy() -> Pass: ... + def InjectPTXLDG32(_0: bool, /) -> Pass: ... + def InjectPermutedLayout() -> Pass: ... + def InjectRollingBuffer() -> Pass: ... + def InjectSoftwarePipeline() -> Pass: ... + def InjectVirtualThread() -> Pass: ... + def InlinePrivateFunctions() -> Pass: ... + def InstrumentBoundCheckers() -> Pass: ... + def InstrumentProfileIntrinsics() -> Pass: ... + def LiftThreadBinding() -> Pass: ... + def LoopPartition() -> Pass: ... + def LowerAsyncDMA() -> Pass: ... + def LowerAutoCopy() -> Pass: ... + def LowerCrossThreadReduction() -> Pass: ... + def LowerCustomDatatypes() -> Pass: ... + def LowerDeviceKernelLaunch() -> Pass: ... + def LowerDeviceStorageAccessInfo() -> Pass: ... + def LowerInitBlock() -> Pass: ... + def LowerIntrin() -> Pass: ... + def LowerMatchBuffer() -> Pass: ... + def LowerOpaqueBlock() -> Pass: ... + def LowerTVMBuiltin() -> Pass: ... + def LowerThreadAllreduce() -> Pass: ... + def LowerVtcmAlloc() -> Pass: ... + def LowerWarpMemory() -> Pass: ... + def MakePackedAPI() -> Pass: ... + def MakeUnpackedAPI() -> Pass: ... + def ManifestSharedMemoryLocalStage() -> Pass: ... + def MergeSharedMemoryAllocations() -> Pass: ... + def NarrowDataType(_0: int, /) -> Pass: ... + def PlanAndUpdateBufferAllocationLocation() -> Pass: ... + def PointerValueTypeRewrite() -> Pass: ... + def ReduceBranchingThroughOvercompute() -> Pass: ... + def RemapThreadAxis(_0: Mapping[str, IterVar], /) -> Pass: ... + def RemoveAssume() -> Pass: ... + def RemoveNoOp() -> Pass: ... + def RemoveStoreUndef() -> Pass: ... + def RemoveWeightLayoutRewriteBlock(_0: bool, /) -> Pass: ... + def RenormalizeSplitPattern() -> Pass: ... + def RewriteUnsafeSelect() -> Pass: ... + def Simplify() -> Pass: ... + def SkipAssert() -> Pass: ... + def SplitHostDevice() -> Pass: ... + def StorageRewrite() -> Pass: ... + def ThreadSync(_0: str, /) -> Pass: ... + def TransformMmaBufferLayout() -> Pass: ... + def UnifyThreadBinding() -> Pass: ... 
+ def UnrollLoop() -> Pass: ... + def UseAssumeToReduceBranches() -> Pass: ... + def VectorizeLoop(_0: bool, /) -> Pass: ... + def VerifyGPUCode(_0: Mapping[str, PrimExpr], /) -> Pass: ... + def VerifyMemory() -> Pass: ... + def VerifySSA() -> Pass: ... + def VerifyVTCMLimit(_0: Target | None, /) -> Pass: ... +# fmt: on +# tvm-ffi-stubgen(end) diff --git a/python/tvm/tir/transform/function_pass.py b/python/tvm/tir/transform/function_pass.py index a85eabd970e1..d8e1d09344f7 100644 --- a/python/tvm/tir/transform/function_pass.py +++ b/python/tvm/tir/transform/function_pass.py @@ -15,6 +15,16 @@ # specific language governing permissions and limitations # under the License. """TIR specific function pass support.""" +# tvm-ffi-stubgen(begin): import-section +# fmt: off +# isort: off +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from transform import PassInfo +# isort: on +# fmt: on +# tvm-ffi-stubgen(end) import inspect import functools from typing import Callable, List, Optional, Union @@ -31,6 +41,12 @@ class PrimFuncPass(Pass): pass class should be created through py:func:`tvm.tir.transform.function_pass`. 
""" + # tvm-ffi-stubgen(begin): object/tir.PrimFuncPass + # fmt: off + pass_info: PassInfo + # fmt: on + # tvm-ffi-stubgen(end) + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass""" diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 88cf4720d3a6..0cce7c70374d 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -18,6 +18,8 @@ # pylint: disable=invalid-name, unsupported-binary-operation +# tvm-ffi-stubgen(begin): import-section +# tvm-ffi-stubgen(end) import enum from typing import Callable, Optional @@ -54,6 +56,14 @@ def _transform(func, mod, ctx): class LoopPartitionConfig(_ir.Attrs): """Config for loop partition pass""" + # tvm-ffi-stubgen(begin): object/tir.transform.LoopPartitionConfig + # fmt: off + partition_const_loop: bool + no_unroll_loop_with_extent_one: bool + unroll_loop_with_partition_hint_no_interval: bool + # fmt: on + # tvm-ffi-stubgen(end) + def LoopPartition(): """Inject virtual thread loops. @@ -98,6 +108,12 @@ def InjectVirtualThread(): class InjectDoubleBufferConfig(_ir.Attrs): """Config for inject double buffer pass""" + # tvm-ffi-stubgen(begin): object/tir.transform.InjectDoubleBufferConfig + # fmt: off + split_loop: int + # fmt: on + # tvm-ffi-stubgen(end) + def InjectDoubleBuffer(): """Inject double buffer statements. @@ -165,6 +181,16 @@ def PointerValueTypeRewrite(): class UnrollLoopConfig(_ir.Attrs): """Config for unroll loop pass""" + # tvm-ffi-stubgen(begin): object/tir.transform.UnrollLoopConfig + # fmt: off + auto_max_step: int + auto_max_depth: int + auto_max_extent: int + explicit_unroll: int + unroll_local_access: int + # fmt: on + # tvm-ffi-stubgen(end) + def UnrollLoop(): """Unroll the constant loop marked by unroll. 
@@ -183,6 +209,12 @@ def UnrollLoop(): class ReduceBranchingThroughOvercomputeConfig(_ir.Attrs): """Config for reduce branching through overcompute pass""" + # tvm-ffi-stubgen(begin): object/tir.transform.ReduceBranchingThroughOvercomputeConfig + # fmt: off + use_dataflow_analysis: bool + # fmt: on + # tvm-ffi-stubgen(end) + def ReduceBranchingThroughOvercompute(): """Reduce branching by introducing overcompute @@ -199,6 +231,13 @@ def ReduceBranchingThroughOvercompute(): class RemoveNoOpConfig(_ir.Attrs): """Config for remove no op pass""" + # tvm-ffi-stubgen(begin): object/tir.transform.RemoveNoOpConfig + # fmt: off + use_dataflow_analysis: bool + max_simplification_steps: int + # fmt: on + # tvm-ffi-stubgen(end) + def RemoveNoOp(): """Remove No Op from the Stmt. @@ -308,6 +347,16 @@ def RewriteUnsafeSelect(): class SimplifyConfig(_ir.Attrs): """Config for simplify pass""" + # tvm-ffi-stubgen(begin): object/tir.transform.SimplifyConfig + # fmt: off + transitively_prove_inequalities: bool + propagate_knowns_to_prove_conditional: bool + propagate_knowns_to_simplify_expressions: bool + convert_boolean_to_and_of_ors: bool + apply_constraints_to_boolean_branches: bool + # fmt: on + # tvm-ffi-stubgen(end) + def Simplify(): """Run arithmetic simplifications on the statements and expressions. 
@@ -656,6 +705,12 @@ def VerifyVTCMLimit(limit=None): class HoistIfThenElseConfig(_ir.Attrs): """Config for hoist if then else pass""" + # tvm-ffi-stubgen(begin): object/tir.transform.HoistIfThenElseConfig + # fmt: off + support_block_scope_hoisting: bool + # fmt: on + # tvm-ffi-stubgen(end) + # pylint: disable=no-else-return,inconsistent-return-statements def HoistIfThenElse(variant: Optional[str] = None): @@ -740,6 +795,13 @@ class HoistedLetBindings(enum.Flag): class HoistExpressionConfig(_ir.Attrs): """Config for hoist expression pass""" + # tvm-ffi-stubgen(begin): object/tir.transform.HoistExpressionConfig + # fmt: off + hoisted_conditionals: int + hoisted_let_bindings: int + # fmt: on + # tvm-ffi-stubgen(end) + def HoistExpression(): """Generalized verison of HoistIfThenElse.