From 4bce4f26dcfc443e5d45bda813711be959fa4e7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 3 Feb 2026 19:12:12 +0100 Subject: [PATCH 01/14] fix missing pixel_values --- .../plot_export_gemma3_tiny_input_observer.py | 3 +- .../ut_investigate/test_input_observer.py | 74 ++++++++++++---- .../test_input_observer_transformers.py | 88 +++++++++++++++++++ onnx_diagnostic/investigate/input_observer.py | 31 ++++++- 4 files changed, 174 insertions(+), 22 deletions(-) diff --git a/_doc/final/plot_export_gemma3_tiny_input_observer.py b/_doc/final/plot_export_gemma3_tiny_input_observer.py index 2b57d8b5..d719423c 100644 --- a/_doc/final/plot_export_gemma3_tiny_input_observer.py +++ b/_doc/final/plot_export_gemma3_tiny_input_observer.py @@ -9,6 +9,7 @@ """ import pandas +import torch from onnx_diagnostic import doc from onnx_diagnostic.helpers import string_type from onnx_diagnostic.export.api import to_onnx @@ -50,7 +51,7 @@ # %% # Captures inputs and outputs for the model. -observer = InputObserver() +observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896)))) with ( register_additional_serialization_functions(patch_transformers=True), observer(pipe.model), diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index b40dac7d..e41756b2 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -878,25 +878,61 @@ def forward(self, x=None, y=None): # self.assertEqual(2, len(args)) # self.assertEqual(len([v for v in args.values() if v is not None]), 2) - def test_infer_dynamic_shapes_exception(self): - """ - dict(input_ids:T7s1x282, - pixel_values:T1s1x3x896x896, - attention_mask:T7s1x282, - position_ids:T7s1x282, - token_type_ids:T7s1x282,cache_position:T7s282 - ) - dict(input_ids:T7s1x1,attention_mask:T7s1x283,position_ids:T7s1x1, - past_key_values:DynamicCache( - DynamicSlidingWindowLayer(T16s1x1x282x32, T16s1x1x282x32), - DynamicLayer(T16s1x1x282x32, T16s1x1x282x32)), - token_type_ids:T7s1x1,cache_position:T7s1) - dict(input_ids:T7s1x1,attention_mask:T7s1x284,position_ids:T7s1x1, - past_key_values:DynamicCache( - DynamicSlidingWindowLayer(T16s1x1x283x32, T16s1x1x283x32), - DynamicLayer(T16s1x1x283x32, T16s1x1x283x32)), - token_type_ids:T7s1x1,cache_position:T7s1) - """ + def test_infer_dynamic_shapes_missing(self): + class Model(torch.nn.Module): + def forward( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + token_type_ids=None, + cache_position=None, + ): + return input_ids + + inputs = [ + dict( + input_ids=torch.ones((1, 282), dtype=torch.int64), + pixel_values=torch.ones((1, 3, 896, 896), dtype=torch.int64), + attention_mask=torch.ones((1, 282), dtype=torch.int64), + position_ids=torch.ones((1, 282), dtype=torch.int64), + token_type_ids=torch.ones((1, 282), dtype=torch.int64), + cache_position=torch.ones((282,), dtype=torch.int64), + ), + dict( + input_ids=torch.ones((1, 1), dtype=torch.int64), + attention_mask=torch.ones((1, 283), dtype=torch.int64), + position_ids=torch.ones((1, 1), dtype=torch.int64), + past_key_values=torch.rand((1, 1, 282, 32)), + ), + dict( + input_ids=torch.ones((1, 1), dtype=torch.int64), + attention_mask=torch.ones((1, 284), dtype=torch.int64), + position_ids=torch.ones((1, 1), dtype=torch.int64), + past_key_values=torch.rand((1, 1, 283, 32)), + ), + ] + + model = Model() + observer = 
InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896)))) + with observer(model): + for kwargs in inputs: + model(**kwargs) + + shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True) + cst = torch.export.Dim.DYNAMIC + expected = { + "input_ids": {0: cst, 1: cst}, + "pixel_values": {0: cst}, + "attention_mask": {0: cst, 1: cst}, + "position_ids": {0: cst, 1: cst}, + "past_key_values": {0: cst, 2: cst}, + "token_type_ids": {0: cst, 1: cst}, + "cache_position": {0: cst}, + } + self.assertEqual(expected, shapes) if __name__ == "__main__": diff --git a/_unittests/ut_investigate/test_input_observer_transformers.py b/_unittests/ut_investigate/test_input_observer_transformers.py index 52284fed..6faf4e84 100644 --- a/_unittests/ut_investigate/test_input_observer_transformers.py +++ b/_unittests/ut_investigate/test_input_observer_transformers.py @@ -216,6 +216,94 @@ def forward(self, cache): args["cache"].cross_attention_cache.layers[0].keys.shape, (1, 6, 1500, 64) ) + def test_infer_dynamic_shapes_missing_pixels(self): + import transformers + + class Model(torch.nn.Module): + def forward( + self, + input_ids=None, + pixel_values=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + token_type_ids=None, + cache_position=None, + ): + return input_ids + + inputs = [ + dict( + input_ids=torch.ones((1, 282), dtype=torch.int64), + pixel_values=torch.ones((1, 3, 896, 896), dtype=torch.int64), + attention_mask=torch.ones((1, 282), dtype=torch.int64), + position_ids=torch.ones((1, 282), dtype=torch.int64), + token_type_ids=torch.ones((1, 282), dtype=torch.int64), + cache_position=torch.ones((282,), dtype=torch.int64), + ), + dict( + input_ids=torch.ones((1, 1), dtype=torch.int64), + attention_mask=torch.ones((1, 283), dtype=torch.int64), + position_ids=torch.ones((1, 1), dtype=torch.int64), + past_key_values=make_dynamic_cache( + [ + (torch.rand((1, 1, 282, 32)), torch.rand((1, 1, 282, 32))), + (torch.rand((1, 1, 282, 32)), torch.rand((1, 1, 282, 32))), + ], + cls_layers=[ + transformers.cache_utils.DynamicSlidingWindowLayer, + transformers.cache_utils.DynamicLayer, + ], + ), + token_type_ids=torch.ones((1, 1), dtype=torch.int64), + cache_position=torch.ones((1,), dtype=torch.int64), + ), + dict( + input_ids=torch.ones((1, 1), dtype=torch.int64), + attention_mask=torch.ones((1, 284), dtype=torch.int64), + position_ids=torch.ones((1, 1), dtype=torch.int64), + past_key_values=make_dynamic_cache( + [ + (torch.rand((1, 1, 283, 32)), torch.rand((1, 1, 283, 32))), + (torch.rand((1, 1, 283, 32)), torch.rand((1, 1, 283, 32))), + ], + cls_layers=[ + transformers.cache_utils.DynamicSlidingWindowLayer, + transformers.cache_utils.DynamicLayer, + ], + ), + token_type_ids=torch.ones((1, 1), dtype=torch.int64), + cache_position=torch.ones((1,), dtype=torch.int64), + ), + ] + + model = Model() + observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896)))) + with ( + register_additional_serialization_functions(patch_transformers=True), + observer(model), + ): + for kwargs in inputs: + model(**kwargs) + + shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True) + cst = torch.export.Dim.DYNAMIC + expected = { + "input_ids": {0: cst, 1: cst}, + "pixel_values": {0: cst}, + "attention_mask": {0: cst, 1: cst}, + "position_ids": {0: cst, 1: cst}, + "past_key_values": [ + {0: cst, 2: cst}, + {0: cst, 2: cst}, + {0: cst, 2: cst}, + {0: cst, 2: cst}, + ], + "token_type_ids": {0: cst, 1: cst}, + "cache_position": {0: cst}, + } + 
self.assertEqual(expected, shapes) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 9233cd04..1e97d705 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -285,12 +285,23 @@ class InputObserverInfo: to be the same in the ordered dictionaries `add_inputs` receive. default_values: Default values defined by the signature of the function, any value equal to that is ignore to simplify the export. + missing: If a named argument (in kwargs) is missing, + a default value will be taken in this dictionary, + this is used when after the prefill step, an argument + disappears (such as `pixel_values`) and another one + is added (such as `past_key_values`). + The values are only to infer dynamic shapes and arguments, + not to run the model. """ def __init__( - self, signature_names: list[str], default_values: dict[str, int | bool | str | float] + self, + signature_names: list[str], + default_values: dict[str, int | bool | str | float], + missing: dict[str, Any], ): self.default_values = default_values + self.missing = missing self.inputs: list[InputCandidate] = [] self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = [] self.flat_outputs: list[list[torch.Tensor | None]] = [] @@ -323,6 +334,11 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): if v is not None and not isinstance(v, (int, float, bool)) } + # adds missing attributes + for k, v in self.missing.items(): + if k not in kwargs: + kwargs[k] = v + # kwargs may come in a different ordeer teach. # dictionaries are ordered and torch.export.export expects # dynamic shapes an kwargs to follow the same order. @@ -631,6 +647,15 @@ class InputObserver: This information is used to infer dynamic shapes and export arguments. + Args: + missing: If a named argument (in kwargs) is missing, + a default value will be taken in this dictionary, + this is used when after the prefill step, an argument + disappears (such as `pixel_values`) and another one + is added (such as `past_key_values`). + The values are only to infer dynamic shapes and arguments, + not to run the model. + Examples -------- >>> input_observer = InputObserver() @@ -658,8 +683,9 @@ class InputObserver: :ref:`l-plot-whisper-tiny-export-input-observer`. 
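    A minimal, hedged sketch of the new ``missing`` argument (an editor's
    illustration mirroring ``test_infer_dynamic_shapes_missing`` added in this
    patch, not library documentation); ``model`` and the list of call kwargs
    ``all_kwargs`` are assumed to be defined by the caller:

    >>> observer = InputObserver(
    ...     missing=dict(pixel_values=torch.empty((0, 3, 896, 896)))
    ... )
    >>> with observer(model):
    ...     for kwargs in all_kwargs:
    ...         _ = model(**kwargs)
    >>> dynamic_shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)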
""" - def __init__(self): + def __init__(self, missing: dict[str, Any] | None = None): self.info: InputObserverInfo | None = None # type: ignore[annotation-unchecked] + self.missing = missing or {} def _replaced_method( self, @@ -715,6 +741,7 @@ def __call__( if p.default != inspect.Parameter.empty and isinstance(p.default, (int, bool, str, float)) }, + missing=self.missing, ) n_already_stored = len(self.info) lambda_method = lambda *args, _cm=captured_method, _snc=( # noqa: E731 From 7e34f871d1004079ac78b08f46148b2f7bdc82e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 3 Feb 2026 19:13:33 +0100 Subject: [PATCH 02/14] changelogs --- CHANGELOGS.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 38b99ba2..52c21ffa 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -5,7 +5,7 @@ Change Logs ++++++ -* :pr:`400`, :pr:`401`:, :pr:`402`: improves InputObserver (investigations), add it the documentation +* :pr:`400`, :pr:`401`:, :pr:`402`, :pr:`403`: improves InputObserver (investigations), add it the documentation * :pr:`399`: update CI 0.8.11 From e13a426758690534779541316a7b374820e1e395 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Feb 2026 09:03:58 +0100 Subject: [PATCH 03/14] update cache serialization --- _unittests/ut_helpers/test_cache_helper.py | 22 +++++++ .../ut_investigate/test_input_observer.py | 4 ++ .../test_patch_serialization_transformers.py | 12 +++- onnx_diagnostic/helpers/cache_helper.py | 64 ++++++++++++------- .../serialization/transformers_impl.py | 30 ++++++++- 5 files changed, 105 insertions(+), 27 deletions(-) diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index eea74ba7..16b444c5 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -373,6 +373,28 @@ def test_make_dynamic_cache_2_types(self): ) self.assertEqual(0, max_diff(cache, cache)["abs"]) + @requires_transformers("4.57") + def test_make_dynamic_cache_2_types_kwargs(self): + cache = make_dynamic_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ], + cls_layers=[ + transformers.cache_utils.DynamicLayer, + transformers.cache_utils.DynamicSlidingWindowLayer, + ], + cls_kwargs=[{}, dict(sliding_window=12)], + ) + text = self.string_type(cache, with_shape=True) + self.assertEqual( + "DynamicCache(DynamicLayer(T1s4x5x6x7, T1s4x5x6x7), " + "DynamicSlidingWindowLayer(T1s4x5x6x7, T1s4x5x6x7))", + text, + ) + self.assertEqual(0, max_diff(cache, cache)["abs"]) + self.assertEqual(cache.layers[1].sliding_window, 12) + @requires_transformers("4.57") def test_unflatten_flatten_mixed_layers(self): with torch_export_patches(patch_transformers=True): diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index e41756b2..2629c97a 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -906,12 +906,16 @@ def forward( attention_mask=torch.ones((1, 283), dtype=torch.int64), position_ids=torch.ones((1, 1), dtype=torch.int64), past_key_values=torch.rand((1, 1, 282, 32)), + token_type_ids=torch.ones((1, 1), dtype=torch.int64), + cache_position=torch.ones((1,), dtype=torch.int64), ), dict( input_ids=torch.ones((1, 1), dtype=torch.int64), attention_mask=torch.ones((1, 284), dtype=torch.int64), position_ids=torch.ones((1, 1), 
dtype=torch.int64), past_key_values=torch.rand((1, 1, 283, 32)), + token_type_ids=torch.ones((1, 1), dtype=torch.int64), + cache_position=torch.ones((1,), dtype=torch.int64), ), ] diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py index 0f69c99c..0771dfb7 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py @@ -275,13 +275,19 @@ def test_sliding_window_cache_flatten(self): @unittest.skipIf(make_sliding_window_cache, "transformers<5") def test_sliding_window_cache_flatten5(self): cache = make_dynamic_cache( - [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))], + [ + (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))), + (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))), + ], cls_layers="DynamicSlidingWindowLayer", + cls_kwargs=[dict(sliding_window=11), dict(sliding_window=12)], ) + self.assertEqual(cache.layers[0].sliding_window, 11) + self.assertEqual(cache.layers[1].sliding_window, 12) with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( - "#2[T1s4x4x4x4,T1s4x4x4x4]", + "#4[T1s4x4x4x4,T1s4x4x4x4,T1s4x4x4x4,T1s4x4x4x4]", self.string_type(flat, with_shape=True), ) cache2 = torch.utils._pytree.tree_unflatten(flat, _spec) @@ -292,6 +298,8 @@ def test_sliding_window_cache_flatten5(self): self.assertEqual( [type(lay) for lay in cache.layers], [type(lay) for lay in cache2.layers] ) + self.assertEqual(cache2.layers[0].sliding_window, 11) + self.assertEqual(cache2.layers[1].sliding_window, 12) @ignore_warnings(UserWarning) @requires_torch("2.7.99") diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index eef4e925..cb98052f 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -4,6 +4,19 @@ import transformers import transformers.cache_utils +KWARGS_LAYER = {} +if hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer"): + KWARGS_LAYER.update( + { + transformers.cache_utils.DynamicSlidingWindowLayer: lambda tensor: { + "sliding_window": tensor.shape[2] + }, + transformers.cache_utils.StaticSlidingWindowLayer: lambda tensor: { + "sliding_window": tensor.shape[2] + }, + } + ) + class CacheKeyValue: """ @@ -185,6 +198,7 @@ def _preprocess_key_value_pairs( def make_dynamic_cache( key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], cls_layers: Optional[Union[str, List[type]]] = None, + cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None, ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. 
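Editor's note: the signature above gains a ``cls_kwargs`` argument. For orientation, a hedged sketch of how it is exercised, adapted from ``test_make_dynamic_cache_2_types_kwargs`` introduced in this patch; it assumes transformers>=4.57 so that ``DynamicSlidingWindowLayer`` exists.

import torch
import transformers
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

# One dict of layer-constructor kwargs per (key, value) pair;
# the second layer gets an explicit sliding window of 12.
cache = make_dynamic_cache(
    [
        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
    ],
    cls_layers=[
        transformers.cache_utils.DynamicLayer,
        transformers.cache_utils.DynamicSlidingWindowLayer,
    ],
    cls_kwargs=[{}, dict(sliding_window=12)],
)
assert cache.layers[1].sliding_window == 12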
@@ -194,6 +208,8 @@ def make_dynamic_cache( :param cls_layers: to select the appropriate class to use on each layer, if specified, sliding_window is ignored, it can be a string if all layers are expected to follow the same class + :param cls_kwargs: arguments used to build a specific layer, + such as ``sliding_window`` for ``DynamicSlidingWindowLayer`` :return: :class:`transformers.cache_utils.DynamicCache` Example: @@ -227,35 +243,33 @@ def make_dynamic_cache( if isinstance(cls_layers, str): assert hasattr( transformers.cache_utils, cls_layers - ), f"Unable to find class {cls_layers!r} in transformers.cache_utils" - cls_kwargs = {} - cls_layer = getattr(transformers.cache_utils, cls_layers) - if cls_layers == "DynamicSlidingWindowLayer": - cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2] - assert isinstance( - cls_kwargs["sliding_window"], int - ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}" - elif cls_layers is not None and isinstance(cls_layers, list): + ), f"Missing layer class {cls_layers!r}" + cls_layers = getattr(transformers.cache_utils, cls_layers) + if cls_layers and not isinstance(cls_layers, list): + cls_layers = [cls_layers for _ in key_value_pairs] + if cls_layers is not None and isinstance(cls_layers, list): assert len(cls_layers) == len(key_value_pairs), ( f"Length mismatch {len(key_value_pairs)} expected but " f"{len(cls_layers)} layer types are given." ) - cls_kwargs = [{} for _kv in key_value_pairs] # type: ignore[assignment] + if cls_kwargs is None: + cls_kwargs = [{} for _kv in key_value_pairs] # type: ignore[assignment] + assert len(cls_layers) == len(cls_kwargs), ( + f"Length mismatch {len(cls_kwargs)} expected but " + f"{len(cls_layers)} layer types are given, " + f"cls_layers={cls_layers}, cls_kwargs={cls_kwargs}" + ) cls_layer = None - if ( - hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer") - and transformers.cache_utils.DynamicSlidingWindowLayer in cls_layers - ): - assert ( - key_value_pairs and key_value_pairs[0] - ), f"not implemented for type(key_value_pairs[0])={type(key_value_pairs[0])}" - for kv, clsy, kws in zip(key_value_pairs, cls_layers, cls_kwargs): - if clsy == transformers.cache_utils.DynamicSlidingWindowLayer: - kws["sliding_window"] = kv[0].shape[2] # type: ignore[index] - assert isinstance( - kws["sliding_window"], int # type: ignore[index] - ), f"sliding_window must be an integer but shape={kv[0].shape}" + assert ( + key_value_pairs and key_value_pairs[0] + ), f"not implemented for type(key_value_pairs[0])={type(key_value_pairs[0])}" + for kv, clsy, kws in zip(key_value_pairs, cls_layers, cls_kwargs): + default_values = KWARGS_LAYER.get(clsy, lambda tensor: {})(kv[0]) + for k, v in default_values.items(): + if k not in kws: + kws[k] = v else: + assert cls_kwargs is None, "cls_layers must be a list if cls_kwargs is specified" assert ( cls_layers is None ), f"cls_layers must be list or a string but it is {cls_layers}" @@ -267,6 +281,10 @@ def make_dynamic_cache( ) if cls_layer is not None: + assert isinstance(cls_kwargs, dict), ( + f"one layer = one set of arguments, cls_layer={cls_layer}, " + f"cls_kwargs={cls_kwargs}" + ) cls_layers = [cls_layer for _ in key_value_pairs] cls_kwargs = ( cls_kwargs # type: ignore[assignment] diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 4dedb248..b93c4601 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py 
+++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -24,10 +24,12 @@ from . import make_serialization_function_for_dataclass SUPPORTED_DATACLASSES: Set[type] = set() + WRONG_REGISTRATIONS = { DynamicCache: "4.50", BaseModelOutput: None, } + SHORTEN_LAYER_NAMES = { "DynamicLayer": "D", "DynamicSlidingWindowLayer": "W", @@ -39,6 +41,20 @@ "X": "StaticSlidingWindowLayer", } +KWARGS_LAYER_NAMES = { + "DynamicLayer": lambda layer: "", + "DynamicSlidingWindowLayer": lambda layer: str(layer.sliding_window), + "StaticLayer": lambda layer: "", + "StaticSlidingWindowLayer": lambda layer: str(layer.sliding_window), +} + +PARSE_LAYER_NAMES = { + "DynamicLayer": lambda skw: {}, + "DynamicSlidingWindowLayer": lambda skw: dict(sliding_window=int(skw[1:])), + "StaticLayer": lambda skw: {}, + "StaticSlidingWindowLayer": lambda skw: dict(sliding_window=int(skw[1:])), +} + def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]: ca = CacheKeyValue(cache) @@ -59,7 +75,11 @@ def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytr keys = [] for i in range(len(ca.key_cache)): letter = SHORTEN_LAYER_NAMES[ca.cls_layers[i].__name__] - keys.extend([f"key_{letter}{i}", f"value_{letter}{i}"]) + if hasattr(cache, "layers"): + kwargs = KWARGS_LAYER_NAMES[ca.cls_layers[i].__name__](cache.layers[i]) + else: + kwargs = "" + keys.extend([f"key_{letter}{kwargs}_{i}", f"value_{letter}{kwargs}_{i}"]) return flat, keys @@ -86,10 +106,16 @@ def _unflatten_cache( res = make_cache(list(zip(values[::2], values[1::2]))) else: cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2] + cls_kwargs = [ + PARSE_LAYER_NAMES[SHORTEN_LAYER_NAMES[name.split("_")[1][0]]](name.split("_")[1]) + for name in context + ][::2] cls_layers = [ getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names ] - res = make_cache(list(zip(values[::2], values[1::2])), cls_layers=cls_layers) + res = make_cache( + list(zip(values[::2], values[1::2])), cls_layers=cls_layers, cls_kwargs=cls_kwargs + ) assert output_type is None or isinstance( res, output_type From c35401bae2cf632f937aaf1da32b0e882df8661b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Feb 2026 09:22:17 +0100 Subject: [PATCH 04/14] fix serialization --- _doc/final/plot_export_gemma3_tiny_input_observer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_doc/final/plot_export_gemma3_tiny_input_observer.py b/_doc/final/plot_export_gemma3_tiny_input_observer.py index d719423c..af7e6493 100644 --- a/_doc/final/plot_export_gemma3_tiny_input_observer.py +++ b/_doc/final/plot_export_gemma3_tiny_input_observer.py @@ -26,7 +26,7 @@ pipe = pipeline( "image-text-to-text", model=model_id, - device="cuda", + device="cpu", trust_remote_code=True, max_new_tokens=3, ) From 2099f1333013f98d4d94a5bdd86659974321da69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Feb 2026 09:26:26 +0100 Subject: [PATCH 05/14] upgrade version --- CHANGELOGS.rst | 8 ++++---- _doc/index.rst | 2 +- onnx_diagnostic/__init__.py | 2 +- pyproject.toml | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 52c21ffa..561d2874 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,11 +1,11 @@ Change Logs =========== -0.8.12 -++++++ - +0.9.0 ++++++ -* :pr:`400`, :pr:`401`:, :pr:`402`, :pr:`403`: improves InputObserver (investigations), add it the documentation +* :pr:`403`: update the 
serialization of SlidingWindowCache to include parameter slidinw_window +* :pr:`400`, :pr:`401`:, :pr:`402`: improves InputObserver (investigations), add it the documentation * :pr:`399`: update CI 0.8.11 diff --git a/_doc/index.rst b/_doc/index.rst index 77c31840..0f8bf986 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -240,7 +240,7 @@ The function replaces dynamic dimensions defined as strings by Older versions ============== -* `0.8.12 <../v0.8.12/index.html>`_ +* `0.9.0 <../v0.9.0/index.html>`_ * `0.8.11 <../v0.8.11/index.html>`_ * `0.7.16 <../v0.7.16/index.html>`_ * `0.6.3 <../v0.6.3/index.html>`_ diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index 28c63942..0562b32b 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -3,5 +3,5 @@ Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.8.12" +__version__ = "0.9.0" __author__ = "Xavier Dupré" diff --git a/pyproject.toml b/pyproject.toml index c0fe58b5..c8c0779e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "onnx-diagnostic" -version = "0.8.12" +version = "0.9.0" description = "Tools to help converting pytorch models into ONNX." readme = "README.rst" authors = [ From ecf00b064c445b65de620d9f5fab34f37fe03757 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Feb 2026 15:53:19 +0100 Subject: [PATCH 06/14] fix issues --- onnx_diagnostic/helpers/cache_helper.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index cb98052f..12c9ca8a 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -246,7 +246,7 @@ def make_dynamic_cache( ), f"Missing layer class {cls_layers!r}" cls_layers = getattr(transformers.cache_utils, cls_layers) if cls_layers and not isinstance(cls_layers, list): - cls_layers = [cls_layers for _ in key_value_pairs] + cls_layers = [cls_layers for _ in key_value_pairs] # type: ignore[misc] if cls_layers is not None and isinstance(cls_layers, list): assert len(cls_layers) == len(key_value_pairs), ( f"Length mismatch {len(key_value_pairs)} expected but " @@ -267,7 +267,7 @@ def make_dynamic_cache( default_values = KWARGS_LAYER.get(clsy, lambda tensor: {})(kv[0]) for k, v in default_values.items(): if k not in kws: - kws[k] = v + kws[k] = v # type: ignore[index] else: assert cls_kwargs is None, "cls_layers must be a list if cls_kwargs is specified" assert ( @@ -321,16 +321,19 @@ def make_dynamic_cache( if hasattr(cache, "layers") and ( cls_layer is None or cls_layer != transformers.cache_utils.DynamicLayer ): - assert isinstance( - cls_kwargs, list - ), f"Wrong type {type(cls_kwargs)} for cls_kwargs" - assert len(cls_kwargs) == len( - cls_layers - ), f"Length mismatch between cls_kwargs={cls_kwargs} and cls_layers={cls_layers}" - assert len(cls_kwargs) == len(key_value_pairs), ( - f"Length mismatch between cls_kwargs={cls_kwargs} and " - f"len(key_value_pairs)={len(key_value_pairs)}" + assert isinstance(cls_layers, list) and isinstance(cls_kwargs, list), ( + f"Wrong type {type(cls_layers)} for cls_layers or " + f"{type(cls_kwargs)} for cls_kwargs" ) + assert len(cls_kwargs) == len(cls_layers) and len(cls_kwargs) == len( + key_value_pairs + ), ( + f"Length mismatch between len(cls_kwargs)={len(cls_kwargs)}, " + f"len(cls_layers)={len(cls_layers)}, " + f"len(key_value_pairs)={len(key_value_pairs)}, " + 
f"cls_kwargs={cls_kwargs}, cls_layers={cls_layers}" + ) + del cache.layers[:] cache.layers.extend( [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)] # type: ignore[operator, arg-type] ) From 7e1c7ad0d88b5555659e94026f362a7e8c3de2f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Feb 2026 18:05:20 +0100 Subject: [PATCH 07/14] fix cache serialization --- CHANGELOGS.rst | 2 +- .../plot_export_gemma3_tiny_input_observer.py | 15 +++- .../test_input_observer_transformers.py | 1 + .../test_patch_transformers.py | 34 ++++++++ onnx_diagnostic/helpers/cache_helper.py | 5 +- onnx_diagnostic/investigate/input_observer.py | 4 +- .../onnx_export_errors.py | 34 +++++++- .../_patch_transformers_masking_utils.py | 82 +++++++++++++++++++ .../patches/patch_torch.py | 10 ++- .../patches/patch_transformers.py | 1 + 10 files changed, 180 insertions(+), 8 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 561d2874..0aff1a74 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,7 @@ Change Logs 0.9.0 +++++ -* :pr:`403`: update the serialization of SlidingWindowCache to include parameter slidinw_window +* :pr:`403`: update the serialization of SlidingWindowCache to include parameter slidinw_window, patch for sdpa_mask * :pr:`400`, :pr:`401`:, :pr:`402`: improves InputObserver (investigations), add it the documentation * :pr:`399`: update CI diff --git a/_doc/final/plot_export_gemma3_tiny_input_observer.py b/_doc/final/plot_export_gemma3_tiny_input_observer.py index af7e6493..9be541c9 100644 --- a/_doc/final/plot_export_gemma3_tiny_input_observer.py +++ b/_doc/final/plot_export_gemma3_tiny_input_observer.py @@ -29,6 +29,7 @@ device="cpu", trust_remote_code=True, max_new_tokens=3, + dtype=torch.float16, ) messages = [ {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, @@ -51,7 +52,9 @@ # %% # Captures inputs and outputs for the model. -observer = InputObserver(missing=dict(pixel_values=torch.empty((0, 3, 896, 896)))) +observer = InputObserver( + missing=dict(pixel_values=torch.empty((0, 3, 896, 896), dtype=torch.float16)) +) with ( register_additional_serialization_functions(patch_transformers=True), observer(pipe.model), @@ -77,7 +80,7 @@ filename = "plot_export_gemma3_tiny_input_observer.onnx" -with torch_export_patches(patch_transformers=True): +with torch_export_patches(patch_transformers=True, patch_torch=True, stop_if_static=2): to_onnx( pipe.model, args=(), @@ -89,9 +92,15 @@ # %% # Let's measure the discrepancies. -data = observer.check_discrepancies(filename, progress_bar=True) +data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2) print(pandas.DataFrame(data)) +# %% +# Let's show the errors. 
+for row in data: + if not row["SUCCESS"] and "error" in row: + print(row["error"]) + # %% doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400) diff --git a/_unittests/ut_investigate/test_input_observer_transformers.py b/_unittests/ut_investigate/test_input_observer_transformers.py index 6faf4e84..7b1b75af 100644 --- a/_unittests/ut_investigate/test_input_observer_transformers.py +++ b/_unittests/ut_investigate/test_input_observer_transformers.py @@ -216,6 +216,7 @@ def forward(self, cache): args["cache"].cross_attention_cache.layers[0].keys.shape, (1, 6, 1500, 64) ) + @requires_transformers("4.57") def test_infer_dynamic_shapes_missing_pixels(self): import transformers diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py index 3fcbb7a9..4982d06c 100644 --- a/_unittests/ut_torch_export_patches/test_patch_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py @@ -65,6 +65,40 @@ def test_sdpa_mask_recent_torch(self): got = patched_sdpa_mask_recent_torch(**kwargs) self.assertEqualArray(expected, got) + @requires_transformers("4.99") + def test_sdpa_mask_patched(self): + sdpa_mask = transformers.masking_utils.sdpa_mask + patched_sdpa_mask = patch_transformers.patched_sdpa_mask + kwargs = { + "batch_size": 1, + "cache_position": torch.tensor([3], dtype=torch.int64), + "kv_length": 4, + "kv_offset": 0, + "mask_function": transformers.masking_utils.causal_mask_function, + "attention_mask": torch.tensor([[True, True, True, True]]), + "local_size": None, + "allow_is_causal_skip": True, + "allow_is_bidirectional_skip": False, + } + expected = sdpa_mask(**kwargs) + got = patched_sdpa_mask(**kwargs) + self.assertEqual(expected, got) + + kwargs = { + "batch_size": 1, + "cache_position": torch.tensor([3], dtype=torch.int64), + "kv_length": 4, + "kv_offset": 0, + "mask_function": transformers.masking_utils.causal_mask_function, + "attention_mask": torch.tensor([[True, True, True, True]]), + "local_size": None, + "allow_is_causal_skip": False, + "allow_is_bidirectional_skip": False, + } + expected = sdpa_mask(**kwargs) + got = patched_sdpa_mask(**kwargs) + self.assertEqualArray(expected, got) + @requires_transformers("4.99") def test_sdpa_mask_recent_torch_is_running(self): def _copy_vmap_for_bhqkv(mask_function, bh_indices=True): diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 12c9ca8a..371160cd 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -360,6 +360,7 @@ def make_dynamic_cache( def make_dynamic_cache( key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], cls_layers: Optional[Union[str, List[type]]] = None, + cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None, ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. @@ -391,7 +392,9 @@ def make_dynamic_cache( ) print(string_type(past_key_values, with_shape=True)) """ - assert not cls_layers, "cls_layers cannot be used for transformers<5." + assert ( + not cls_layers and not cls_kwargs + ), "cls_layers, cls_kwargs cannot be used for transformers<5." 
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore for i, (key, value) in enumerate(key_value_pairs): diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 1e97d705..f7cf28c1 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -128,6 +128,7 @@ def __init__( self._position_to_args_kwargs: list[int | str] | None = None self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None self.cst_kwargs = cst_kwargs.copy() + assert "use_cache" not in self.cst_kwargs if clone: self.flat_list = [ @@ -327,6 +328,7 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): if k in self.signature_names and isinstance(v, (int, float, bool, str)) and v != self.default_values.get(k, None) + and self.default_values.get(k, None) is not None } kwargs = { k: v @@ -580,7 +582,7 @@ def infer_arguments( dynamic_shapes = self.infer_dynamic_shapes(return_flat=True) # type checking assert isinstance(dynamic_shapes, tuple) - aligned_flat_list = aligned_flat_list.copy() + aligned_flat_list = list(aligned_flat_list) for index in range(len(aligned_flat_list)): if aligned_flat_list[index] is not None: continue diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 8e2f9de3..5867257f 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -562,6 +562,7 @@ def _patch_transformers( "[torch_export_patches] patches " "transformers.masking_utils.sdpa_mask_recent_torch" ) + f_transformers_sdpa_mask_recent_torch = masking_utils.sdpa_mask_recent_torch masking_utils.sdpa_mask_recent_torch = ( patch_transformers_list.patched_sdpa_mask_recent_torch @@ -574,7 +575,9 @@ def _patch_transformers( ) if masking_utils.sdpa_mask == f_transformers_sdpa_mask_recent_torch: if verbose: - print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask") + print( + "[torch_export_patches] patches transformers.masking_utils.sdpa_mask (1)" + ) f_transformers_sdpa_mask = masking_utils.sdpa_mask masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask_recent_torch if patch_details: @@ -583,9 +586,38 @@ def _patch_transformers( f_transformers_sdpa_mask, patch_transformers_list.patched_sdpa_mask_recent_torch, ) + elif hasattr(masking_utils, "sdpa_mask"): + if verbose: + print( + "[torch_export_patches] patches transformers.masking_utils.sdpa_mask (2)" + ) + f_transformers_sdpa_mask = masking_utils.sdpa_mask + masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask + if patch_details: + patch_details.append( + "transformers", + f_transformers_sdpa_mask, + patch_transformers_list.patched_sdpa_mask, + ) else: f_transformers_sdpa_mask = None + if ( # vmap + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "sdpa_mask") + ): + if verbose: + print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask (3)") + f_transformers_sdpa_mask = masking_utils.sdpa_mask + masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask + if patch_details: + patch_details.append( + "transformers", + f_transformers_sdpa_mask, + patch_transformers_list.patched_sdpa_mask, + ) + if ( # eager_mask masking_utils and patch_transformers_list.patch_masking_utils diff --git 
a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py index dac67e46..cd350087 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py @@ -36,6 +36,8 @@ _ignore_bidirectional_mask_sdpa = None bidirectional_mask_function = None + from transformers.masking_utils import _non_vmap_expansion_sdpa + def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" from ...helpers import string_type @@ -180,3 +182,83 @@ def patched_sdpa_mask_recent_torch( batch_arange, head_arange, cache_position, kv_arange ) return causal_mask + + def patched_sdpa_mask( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: torch.Tensor | None = None, + local_size: int | None = None, + allow_is_causal_skip: bool = True, + allow_is_bidirectional_skip: bool = False, + allow_torch_fix: bool = True, + use_vmap: bool = False, + **kwargs, + ) -> torch.Tensor | None: + """manual patch for function ``transformers.masking_utils.sdpa_mask``.""" + q_length = cache_position.shape[0] + + # Potentially pad the 2D mask + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) + + # Under specific conditions, we can avoid materializing the mask + # 1. Causal masks can rely on the `is_causal` argument + # 2. Bidirectional do not need any further processing (no bias) + if allow_is_causal_skip and _ignore_causal_mask_sdpa( + padding_mask, q_length, kv_length, kv_offset, local_size + ): + return None + if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa( + padding_mask, kv_length, local_size + ): + return None + + # Potentially add the padding 2D mask + if padding_mask is not None: + mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) + + batch_arange = torch.arange(batch_size, device=cache_position.device) + head_arange = torch.arange(1, device=cache_position.device) + # Similar to `kv_arange = torch.arange(start=kv_offset, + # end=kv_offset + kv_length, device=cache_position.device)` + # but without data-dependent slicing (i.e. torch.compile friendly) + kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset + + # Actual mask creation + # Option 1: Fast non-vmap mask creation (default) + # PATCHED + use_vmap = False + if not use_vmap: + # Apply mask function element-wise through broadcasting + attention_mask = mask_function( + *_non_vmap_expansion_sdpa(batch_arange, head_arange, cache_position, kv_arange) + ) + # Expand the mask to match batch size + # and query length if they weren't used in the mask function + attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length) + + # Option 2: Vmap mask creation (torch>=2.6 and custom patterns) + # elif _is_torch_greater_or_equal_than_2_6: + # This creates the 4D mask easily. 
+ # Note that we need this context manager as vmap cannot handle slicing a tensor from + # scalar tensor (it internally calls `.item()` which vmap does not allow, + # but this context works around it + # We don't need to add an offset to the mask_function either, + # as we vmap directly the correct indices for k and kv indices + # with TransformGetItemToIndex(): + # attention_mask = _vmap_expansion_sdpa(mask_function)( + # batch_arange, head_arange, cache_position, kv_arange + # ) + + # Option 3: Error out since it indicates that the user did something custom, + # which they shouldn't have (torch<2.6) + else: + raise ValueError( + "The vmap functionality for mask creation " + "is only supported from torch>=2.6. " + "Please update your torch version or use " + "`use_vmap=False` with index-based masks." + ) + return attention_mask diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 0f23328d..30bb8521 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -188,6 +188,11 @@ def patched__broadcast_shapes(*_shapes): return common_shape +def value_ranges_is_positive(value_ranges: torch.utils._sympy.value_ranges.ValueRanges): + """Tells if an interval is equivalent to a positive or null integer.""" + return value_ranges.lower == 0 and value_ranges.upper > 4623372036854775806 + + class patched_ShapeEnv: def _check_frozen( @@ -281,7 +286,10 @@ def _set_replacement( ) self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) tgt_bound = self.bound_sympy(tgt) - assert tgt_bound.issubset( + assert ( + value_ranges_is_positive(tgt_bound) + and value_ranges_is_positive(src_bound) + ) or tgt_bound.issubset( src_bound ), f"{tgt_bound=} not a subset of {src_bound=}" diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index 07417855..89b71250 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -38,6 +38,7 @@ from ._patch_transformers_masking_utils import ( patched__vmap_for_bhqkv, patched_eager_mask, + patched_sdpa_mask, patched_sdpa_mask_recent_torch, ) From 905ef812451bc8c9b6f380249a32dbc8e3b5be5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Feb 2026 18:14:28 +0100 Subject: [PATCH 08/14] fix sdpa_mask --- .../onnx_export_errors.py | 1 + .../_patch_transformers_masking_utils.py | 20 ++++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 5867257f..3a574b6d 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -606,6 +606,7 @@ def _patch_transformers( masking_utils and patch_transformers_list.patch_masking_utils and hasattr(masking_utils, "sdpa_mask") + and f_transformers_sdpa_mask is None ): if verbose: print("[torch_export_patches] patches transformers.masking_utils.sdpa_mask (3)") diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py index cd350087..66b65d98 100644 --- 
a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py @@ -36,7 +36,25 @@ _ignore_bidirectional_mask_sdpa = None bidirectional_mask_function = None - from transformers.masking_utils import _non_vmap_expansion_sdpa + try: + from transformers.masking_utils import _non_vmap_expansion_sdpa + except ImportError: + + def _non_vmap_expansion_sdpa( + batch_indices: torch.Tensor, + head_indices: torch.Tensor, + q_indices: torch.Tensor, + kv_indices: torch.Tensor, + ): + """ + https://github.com/huggingface/optimum-onnx/blob/ + c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365 + """ + batch_indices = batch_indices[:, None, None, None] + head_indices = head_indices[None, :, None, None] + q_indices = q_indices[None, None, :, None] + kv_indices = kv_indices[None, None, None, :] + return batch_indices, head_indices, q_indices, kv_indices def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" From 6efcb11095e3916ef2c1d9727121488dd5486522 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Feb 2026 18:29:41 +0100 Subject: [PATCH 09/14] fix check_discrepancies --- .../plot_export_gemma3_tiny_input_observer.py | 6 +++-- onnx_diagnostic/investigate/input_observer.py | 25 +++++++++++-------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/_doc/final/plot_export_gemma3_tiny_input_observer.py b/_doc/final/plot_export_gemma3_tiny_input_observer.py index 9be541c9..ce675056 100644 --- a/_doc/final/plot_export_gemma3_tiny_input_observer.py +++ b/_doc/final/plot_export_gemma3_tiny_input_observer.py @@ -92,8 +92,10 @@ # %% # Let's measure the discrepancies. -data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2) -print(pandas.DataFrame(data)) +data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True) +df = pandas.DataFrame(data) +df.to_excel("plot_export_gemma3_tiny_input_observer.xlsx") +print(df) # %% # Let's show the errors. diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index f7cf28c1..bf922d52 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -859,22 +859,20 @@ def check_discrepancies( rtol: float = 0.1, hist=(0.1, 0.01), progress_bar: bool = False, + include_io: bool = True, ) -> list[dict[str, str | int | float]]: """Computes the discrepancies between the saved inputs and outputs with the saved onnx model. Args: - onnx_model: - ONNX Model to verify. - atol: - Absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16. - rtol: - Relative tolerance. - hist: - Thresholds, the function determines the number of discrepancies + onnx_model: ONNX Model to verify. + atol: Absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16. + rtol: Relative tolerance. + hist: Thresholds, the function determines the number of discrepancies above these thresholds. - progress_bar: - Shows a progress bar (requires :epkg:`tqdm`). + progress_bar: Shows a progress bar (requires :epkg:`tqdm`). + include_io: Shows inputs/outputs shapes in the summary + returned by this function. Returns: A list of dictionaries, ready to be consumed by a dataframe. 
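Editor's note: a hedged sketch of calling the extended ``check_discrepancies`` (it mirrors the Gemma-3 example script updated earlier in this series; ``observer`` and ``filename`` are assumed to come from the capture and export steps).

import pandas

# `observer` captured the model calls, `filename` is the exported ONNX model.
data = observer.check_discrepancies(
    filename, progress_bar=True, atol=1e-2, include_io=True
)
df = pandas.DataFrame(data)
# With include_io=True the rows also carry "inputs", "outputs_torch"
# and "outputs_ort" columns describing the shapes involved.
print(df)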
@@ -921,6 +919,9 @@ def check_discrepancies( if error: diff: dict[str, Any] = dict(error=error, SUCCESS=False) else: + # The last output may be empty and torch could skip it. + while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0: + ort_outputs.pop() diff = max_diff(outputs, ort_outputs, hist=lhist) if "rep" in diff and isinstance(diff["rep"], dict): diff.update(diff["rep"]) @@ -941,5 +942,9 @@ def check_discrepancies( n_empty=n_empty, ) ) + if include_io: + diff["inputs"] = string_type(feeds, with_shape=True) + diff["outputs_torch"] = string_type(outputs, with_shape=True) + diff["outputs_ort"] = string_type(ort_outputs, with_shape=True) data.append(diff) return data From 03b126665261b18a447c5d6beecac2bba310d258 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Feb 2026 18:35:36 +0100 Subject: [PATCH 10/14] fix example --- _unittests/ut_investigate/test_input_observer.py | 2 +- onnx_diagnostic/investigate/input_observer.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 2629c97a..645becd3 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -695,7 +695,7 @@ def forward(self, x, y, z=None, w=None): exporter="custom", filename=proto_name, ) - data = observer.check_discrepancies(proto_name, progress_bar=False) + data = observer.check_discrepancies(proto_name, progress_bar=False, include_io=True) df = pandas.DataFrame(data) self.assertLess(df["abs"].max(), 1e-5) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index bf922d52..02ac9a4c 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -920,8 +920,9 @@ def check_discrepancies( diff: dict[str, Any] = dict(error=error, SUCCESS=False) else: # The last output may be empty and torch could skip it. - while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0: - ort_outputs.pop() + if isinstance(outputs, list): + while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0: + ort_outputs.pop() diff = max_diff(outputs, ort_outputs, hist=lhist) if "rep" in diff and isinstance(diff["rep"], dict): diff.update(diff["rep"]) From 1ff7cc5b2485f3cbf8be259a52126b8f11b36cee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 5 Feb 2026 00:23:49 +0100 Subject: [PATCH 11/14] fix patches --- onnx_diagnostic/investigate/input_observer.py | 8 ++-- .../onnx_export_errors.py | 37 ++++++------------- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 02ac9a4c..358372d2 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -860,7 +860,7 @@ def check_discrepancies( hist=(0.1, 0.01), progress_bar: bool = False, include_io: bool = True, - ) -> list[dict[str, str | int | float]]: + ) -> list[dict[str, str | int | float | bool]]: """Computes the discrepancies between the saved inputs and outputs with the saved onnx model. 
@@ -917,13 +917,13 @@ def check_discrepancies( duration = time.perf_counter() - begin if error: - diff: dict[str, Any] = dict(error=error, SUCCESS=False) + diff: dict[str, str | int | float | bool] = dict(error=error, SUCCESS=False) else: # The last output may be empty and torch could skip it. - if isinstance(outputs, list): + if isinstance(outputs, list) and isinstance(ort_outputs, list): while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0: ort_outputs.pop() - diff = max_diff(outputs, ort_outputs, hist=lhist) + diff = max_diff(outputs, ort_outputs, hist=lhist) # type: ignore[bad-assignment] if "rep" in diff and isinstance(diff["rep"], dict): diff.update(diff["rep"]) del diff["rep"] diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 3a574b6d..79c4d9e7 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -586,21 +586,6 @@ def _patch_transformers( f_transformers_sdpa_mask, patch_transformers_list.patched_sdpa_mask_recent_torch, ) - elif hasattr(masking_utils, "sdpa_mask"): - if verbose: - print( - "[torch_export_patches] patches transformers.masking_utils.sdpa_mask (2)" - ) - f_transformers_sdpa_mask = masking_utils.sdpa_mask - masking_utils.sdpa_mask = patch_transformers_list.patched_sdpa_mask - if patch_details: - patch_details.append( - "transformers", - f_transformers_sdpa_mask, - patch_transformers_list.patched_sdpa_mask, - ) - else: - f_transformers_sdpa_mask = None if ( # vmap masking_utils @@ -775,17 +760,17 @@ def _unpatch_transformers( "transformers.masking_utils.sdpa_mask_recent_torch" ) - if f_transformers_sdpa_mask is not None: - assert f_transformers_sdpa_mask.__name__ in ( - "sdpa_mask", - "sdpa_mask_recent_torch", - ), ( - f"corrupted function 'sdpa_mask', its name is " - f"{f_transformers_sdpa_mask.__name__!r}" - ) - masking_utils.sdpa_mask = f_transformers_sdpa_mask - if verbose: - print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask") + if f_transformers_sdpa_mask is not None: + assert f_transformers_sdpa_mask.__name__ in ( + "sdpa_mask", + "sdpa_mask_recent_torch", + ), ( + f"corrupted function 'sdpa_mask', its name is " + f"{f_transformers_sdpa_mask.__name__!r}" + ) + masking_utils.sdpa_mask = f_transformers_sdpa_mask + if verbose: + print("[torch_export_patches] restored transformers.masking_utils.sdpa_mask") if ( # eager_mask masking_utils From 97a22690bb54cbe9071471b96073349e95546d88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 5 Feb 2026 00:48:08 +0100 Subject: [PATCH 12/14] fix patches --- onnx_diagnostic/investigate/input_observer.py | 2 +- .../patches/_patch_transformers_masking_utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 358372d2..a7ec8321 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -923,7 +923,7 @@ def check_discrepancies( if isinstance(outputs, list) and isinstance(ort_outputs, list): while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0: ort_outputs.pop() - diff = max_diff(outputs, ort_outputs, hist=lhist) # type: ignore[bad-assignment] + diff = max_diff(outputs, ort_outputs, hist=lhist) if "rep" in diff and isinstance(diff["rep"], dict): diff.update(diff["rep"]) del diff["rep"] diff --git 
a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py index 66b65d98..70472e60 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py @@ -169,7 +169,7 @@ def patched_sdpa_mask_recent_torch( if ( allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa - and _ignore_bidirectional_mask_sdpa(padding_mask) + and _ignore_bidirectional_mask_sdpa(padding_mask, kv_length, kv_offset) ): return None From eeb79e00c4a1a6a7584a63e7c6585e6dd91c75f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 5 Feb 2026 00:52:44 +0100 Subject: [PATCH 13/14] style --- onnx_diagnostic/investigate/input_observer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index a7ec8321..2e09d6a2 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -923,7 +923,7 @@ def check_discrepancies( if isinstance(outputs, list) and isinstance(ort_outputs, list): while len(ort_outputs) > len(outputs) and ort_outputs[-1].numel() == 0: ort_outputs.pop() - diff = max_diff(outputs, ort_outputs, hist=lhist) + diff = max_diff(outputs, ort_outputs, hist=lhist) # type: ignore[assignment] if "rep" in diff and isinstance(diff["rep"], dict): diff.update(diff["rep"]) del diff["rep"] From cd682d86b890fa3b6fafcf94ad6839faa7fead69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 5 Feb 2026 10:30:34 +0100 Subject: [PATCH 14/14] fix patch --- .../patches/_patch_transformers_masking_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py index 70472e60..46e2b33f 100644 --- a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py @@ -166,12 +166,13 @@ def patched_sdpa_mask_recent_torch( padding_mask, q_length, kv_length, kv_offset, local_size ): return None - if ( - allow_is_bidirectional_skip - and _ignore_bidirectional_mask_sdpa - and _ignore_bidirectional_mask_sdpa(padding_mask, kv_length, kv_offset) - ): - return None + if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa: + # transformers<=5.0: 1 parameter, 3 for transformers>5.0 + n_parameters = len(inspect.signature(_ignore_bidirectional_mask_sdpa).parameters) + if _ignore_bidirectional_mask_sdpa( + *[padding_mask, kv_length, kv_offset][:n_parameters] + ): + return None if mask_function is bidirectional_mask_function: if padding_mask is not None:
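Editor's note on the final hunk: probing the arity with ``inspect.signature`` is a general pattern for calling a helper whose signature changed across transformers versions. A standalone, hedged sketch follows; the two callables are hypothetical stand-ins and only positional parameters are assumed.

import inspect

def call_with_supported_args(fn, *candidate_args):
    # Pass only as many positional arguments as the callee declares.
    n_parameters = len(inspect.signature(fn).parameters)
    return fn(*candidate_args[:n_parameters])

# Hypothetical stand-ins for the one-argument (older) and
# three-argument (newer) variants of _ignore_bidirectional_mask_sdpa.
def old_variant(padding_mask):
    return padding_mask is None

def new_variant(padding_mask, kv_length, kv_offset):
    return padding_mask is None and kv_offset == 0

assert call_with_supported_args(old_variant, None, 4, 0) is True
assert call_with_supported_args(new_variant, None, 4, 0) is True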