Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm-ffi
Submodule tvm-ffi updated 147 files
68 changes: 66 additions & 2 deletions python/tvm/arith/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,71 @@
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.arith"""
import tvm_ffi
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from arith import ConstIntBound, IntConstraints, IntConstraintsTransform, IntGroupBounds, IntervalSet, IterMark, IterSplitExpr, IterSumExpr, ModularSet
from collections.abc import Mapping, Sequence
from ir import IntImm, IntSet, PrimExpr, Range
from tir import Buffer, PrimFunc, Stmt, Var
from tvm_ffi import Object
from typing import Any
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)


tvm_ffi.init_ffi_api("arith", __name__)

# tvm-ffi-stubgen(begin): global/arith
# fmt: off
_FFI_INIT_FUNC("arith", __name__)
if TYPE_CHECKING:
def ConstIntBound(_0: int, _1: int, /) -> ConstIntBound: ...
def CreateAnalyzer(*args: Any) -> Any: ...
def DeduceBound(_0: PrimExpr, _1: PrimExpr, _2: Mapping[Var, IntSet], _3: Mapping[Var, IntSet], /) -> IntSet: ...
def DetectClipBound(_0: PrimExpr, _1: Sequence[Var], /) -> Sequence[PrimExpr]: ...
def DetectCommonSubExpr(_0: PrimExpr, _1: int, /) -> Mapping[PrimExpr, IntImm]: ...
def DetectIterMap(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: PrimExpr, _3: int, _4: bool, /) -> Object: ...
def DetectLinearEquation(_0: PrimExpr, _1: Sequence[Var], /) -> Sequence[PrimExpr]: ...
def DomainTouched(_0: Stmt, _1: Buffer, _2: bool, _3: bool, /) -> Sequence[Range]: ...
def DomainTouchedAccessMap(_0: PrimFunc, /) -> Mapping[Buffer, Sequence[Object]]: ...
def EstimateRegionLowerBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ...
def EstimateRegionStrictBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ...
def EstimateRegionUpperBound(_0: Sequence[Range], _1: Mapping[Var, Range], _2: PrimExpr, /) -> Sequence[IntSet] | None: ...
def IntConstraints(_0: Sequence[Var], _1: Mapping[Var, Range], _2: Sequence[PrimExpr], /) -> IntConstraints: ...
def IntConstraintsTransform(_0: IntConstraints, _1: IntConstraints, _2: Mapping[Var, PrimExpr], _3: Mapping[Var, PrimExpr], /) -> IntConstraintsTransform: ...
def IntGroupBounds(_0: PrimExpr, _1: Sequence[PrimExpr], _2: Sequence[PrimExpr], _3: Sequence[PrimExpr], /) -> IntGroupBounds: ...
def IntGroupBounds_FindBestRange(*args: Any) -> Any: ...
def IntGroupBounds_from_range(_0: Range, /) -> IntGroupBounds: ...
def IntSetIsEverything(_0: IntSet, /) -> bool: ...
def IntSetIsNothing(_0: IntSet, /) -> bool: ...
def IntervalSet(_0: PrimExpr, _1: PrimExpr, /) -> IntervalSet: ...
def IntervalSetGetMax(_0: IntSet, /) -> PrimExpr: ...
def IntervalSetGetMin(_0: IntSet, /) -> PrimExpr: ...
def InverseAffineIterMap(_0: Sequence[IterSumExpr], _1: Sequence[PrimExpr], /) -> Mapping[Var, PrimExpr]: ...
def IterMapSimplify(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: PrimExpr, _3: int, _4: bool, /) -> Sequence[PrimExpr]: ...
def IterMark(_0: PrimExpr, _1: PrimExpr, /) -> IterMark: ...
def IterSplitExpr(_0: IterMark, _1: PrimExpr, _2: PrimExpr, _3: PrimExpr, /) -> IterSplitExpr: ...
def IterSumExpr(_0: Sequence[IterSplitExpr], _1: PrimExpr, /) -> IterSumExpr: ...
def ModularSet(_0: int, _1: int, /) -> ModularSet: ...
def NarrowPredicateExpression(_0: PrimExpr, _1: Mapping[Var, Range], /) -> PrimExpr: ...
def NegInf() -> PrimExpr: ...
def NormalizeIterMapToExpr(_0: PrimExpr, /) -> PrimExpr: ...
def NormalizeToIterSum(_0: PrimExpr, _1: Mapping[Var, Range], /) -> IterSumExpr: ...
def PosInf() -> PrimExpr: ...
def PresburgerSet(_0: PrimExpr, /) -> IntSet: ...
def SolveInequalitiesAsCondition(*args: Any) -> Any: ...
def SolveInequalitiesDeskewRange(*args: Any) -> Any: ...
def SolveInequalitiesToRange(*args: Any) -> Any: ...
def SolveLinearEquations(*args: Any) -> Any: ...
def SubspaceDivide(_0: Sequence[PrimExpr], _1: Mapping[Var, Range], _2: Sequence[Var], _3: PrimExpr, _4: int, _5: bool, /) -> Sequence[Sequence[IterMark]]: ...
def UnionLowerBound(_0: Sequence[IntSet], /) -> IntSet: ...
def intset_interval(_0: PrimExpr, _1: PrimExpr, /) -> IntSet: ...
def intset_single_point(_0: PrimExpr, /) -> IntSet: ...
def intset_vector(_0: PrimExpr, /) -> IntSet: ...
# fmt: on
# tvm-ffi-stubgen(end)
16 changes: 16 additions & 0 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name
"""Arithmetic data structure and utility"""
# tvm-ffi-stubgen(begin): import-section
# tvm-ffi-stubgen(end)
import enum
from typing import Union

Expand Down Expand Up @@ -51,6 +53,13 @@ class Extension(enum.Flag):
class ModularSet(Object):
"""Represent range of (coeff * x + base) for x in Z"""

# tvm-ffi-stubgen(begin): object/arith.ModularSet
# fmt: off
coeff: int
base: int
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, coeff, base):
self.__init_handle_by_constructor__(_ffi_api.ModularSet, coeff, base)

Expand All @@ -68,6 +77,13 @@ class ConstIntBound(Object):
The maximum value of the bound.
"""

# tvm-ffi-stubgen(begin): object/arith.ConstIntBound
# fmt: off
min_value: int
max_value: int
# fmt: on
# tvm-ffi-stubgen(end)

POS_INF = (1 << 63) - 1
NEG_INF = -POS_INF

Expand Down
27 changes: 27 additions & 0 deletions python/tvm/arith/int_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@
# specific language governing permissions and limitations
# under the License.
"""Integer set."""
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ir import PrimExpr
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
import tvm_ffi
from tvm.runtime import Object
from . import _ffi_api
Expand All @@ -24,6 +34,11 @@
class IntSet(Object):
"""Represent a set of integer in one dimension."""

# tvm-ffi-stubgen(begin): object/ir.IntSet
# fmt: off
# fmt: on
# tvm-ffi-stubgen(end)

def is_nothing(self):
"""Whether the set represent nothing"""
return _ffi_api.IntSetIsNothing(self)
Expand Down Expand Up @@ -78,6 +93,13 @@ class IntervalSet(IntSet):
The maximum value in the interval.
"""

# tvm-ffi-stubgen(begin): object/arith.IntervalSet
# fmt: off
min_value: PrimExpr
max_value: PrimExpr
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, min_value, max_value):
self.__init_handle_by_constructor__(_ffi_api.IntervalSet, min_value, max_value)

Expand All @@ -86,6 +108,11 @@ def __init__(self, min_value, max_value):
class PresburgerSet(IntSet):
"""Represent of Presburger Set"""

# tvm-ffi-stubgen(begin): object/arith.PresburgerSet
# fmt: off
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self):
self.__init_handle_by_constructor__(_ffi_api.PresburgerSet)

Expand Down
38 changes: 38 additions & 0 deletions python/tvm/arith/int_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@
# specific language governing permissions and limitations
# under the License.
"""integer constraints data structures and solvers"""
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from ir import PrimExpr, Range
from tir import Var
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
import tvm_ffi
from tvm.runtime import Object
from . import _ffi_api
Expand All @@ -40,6 +52,15 @@ class IntGroupBounds(Object):
the upper bounds (include)
"""

# tvm-ffi-stubgen(begin): object/arith.IntGroupBounds
# fmt: off
coef: PrimExpr
lower: Sequence[PrimExpr]
equal: Sequence[PrimExpr]
upper: Sequence[PrimExpr]
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, coef, lower, equal, upper):
self.__init_handle_by_constructor__(_ffi_api.IntGroupBounds, coef, lower, equal, upper)

Expand Down Expand Up @@ -81,6 +102,14 @@ class IntConstraints(Object):
The relations between the variables (either equations or inequalities)
"""

# tvm-ffi-stubgen(begin): object/arith.IntConstraints
# fmt: off
variables: Sequence[Var]
ranges: Mapping[Var, Range]
relations: Sequence[PrimExpr]
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, variables, ranges, relations):
self.__init_handle_by_constructor__(_ffi_api.IntConstraints, variables, ranges, relations)

Expand Down Expand Up @@ -113,6 +142,15 @@ class IntConstraintsTransform(Object):
e.g., {m -> a, n -> -b}
"""

# tvm-ffi-stubgen(begin): object/arith.IntConstraintsTransform
# fmt: off
src: IntConstraints
dst: IntConstraints
src_to_dst: Mapping[Var, PrimExpr]
dst_to_src: Mapping[Var, PrimExpr]
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, src, dst, src_to_dst, dst_to_src):
self.__init_handle_by_constructor__(
_ffi_api.IntConstraintsTransform, src, dst, src_to_dst, dst_to_src
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""Iterator (quasi)affine mapping patterns."""
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Sequence
from ir import PrimExpr
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)
from enum import IntEnum
import tvm_ffi
from tvm.runtime import Object
Expand All @@ -26,6 +37,11 @@
class IterMapExpr(PrimExpr):
"""Base class of all IterMap expressions."""

# tvm-ffi-stubgen(begin): object/arith.IterMapExpr
# fmt: off
# fmt: on
# tvm-ffi-stubgen(end)


@tvm_ffi.register_object("arith.IterMark")
class IterMark(Object):
Expand All @@ -40,6 +56,13 @@ class IterMark(Object):
The extent of the iterator.
"""

# tvm-ffi-stubgen(begin): object/arith.IterMark
# fmt: off
source: PrimExpr
extent: PrimExpr
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, source, extent):
self.__init_handle_by_constructor__(_ffi_api.IterMark, source, extent)

Expand All @@ -65,6 +88,15 @@ class IterSplitExpr(IterMapExpr):
Additional scale to the split.
"""

# tvm-ffi-stubgen(begin): object/arith.IterSplitExpr
# fmt: off
source: IterMark
lower_factor: PrimExpr
extent: PrimExpr
scale: PrimExpr
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, source, lower_factor, extent, scale):
self.__init_handle_by_constructor__(
_ffi_api.IterSplitExpr, source, lower_factor, extent, scale
Expand All @@ -86,6 +118,13 @@ class IterSumExpr(IterMapExpr):
The base offset.
"""

# tvm-ffi-stubgen(begin): object/arith.IterSumExpr
# fmt: off
args: Sequence[IterSplitExpr]
base: PrimExpr
# fmt: on
# tvm-ffi-stubgen(end)

def __init__(self, args, base):
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)

Expand All @@ -94,6 +133,14 @@ def __init__(self, args, base):
class IterMapResult(Object):
"""Result of iter map detection."""

# tvm-ffi-stubgen(begin): object/arith.IterMapResult
# fmt: off
indices: Sequence[IterSumExpr]
errors: Sequence[str]
padding_predicate: PrimExpr
# fmt: on
# tvm-ffi-stubgen(end)


class IterMapLevel(IntEnum):
"""Possible kinds of iter mapping check level."""
Expand Down
15 changes: 13 additions & 2 deletions python/tvm/contrib/cutlass/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""FFI API for CUTLASS BYOC."""
import tvm_ffi
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
from typing import TYPE_CHECKING
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)

tvm_ffi.init_ffi_api("contrib.cutlass", __name__)

# tvm-ffi-stubgen(begin): global/contrib.cutlass
_FFI_INIT_FUNC("contrib.cutlass", __name__)
# tvm-ffi-stubgen(end)
14 changes: 12 additions & 2 deletions python/tvm/contrib/msc/core/_ffi_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,17 @@
# specific language governing permissions and limitations
# under the License.
"""tvm.contrib.msc.core._ffi_api"""
# tvm-ffi-stubgen(begin): import-section
# fmt: off
# isort: off
from __future__ import annotations
from tvm_ffi import init_ffi_api as _FFI_INIT_FUNC
from typing import TYPE_CHECKING
# isort: on
# fmt: on
# tvm-ffi-stubgen(end)

import tvm_ffi

tvm_ffi.init_ffi_api("msc.core", __name__)
# tvm-ffi-stubgen(begin): global/msc.core
_FFI_INIT_FUNC("msc.core", __name__)
# tvm-ffi-stubgen(end)
Loading
Loading