diff --git a/_doc/conf.py b/_doc/conf.py
index 8c5fb587..2a5b8928 100644
--- a/_doc/conf.py
+++ b/_doc/conf.py
@@ -144,6 +144,7 @@ def linkcode_resolve(domain, info):
     ("py:class", "torch.utils._pytree.Context"),
     ("py:class", "torch.utils._pytree.KeyEntry"),
     ("py:class", "torch.utils._pytree.TreeSpec"),
+    ("py:class", "torch.utils._sympy.value_ranges.ValueRanges"),
     ("py:class", "transformers.BartForConditionalGeneration"),
     ("py:class", "transformers.LlamaConfig"),
     ("py:class", "transformers.cache_utils.Cache"),
diff --git a/_doc/final/plot_export_tiny_llm_method_generate.py b/_doc/final/plot_export_tiny_llm_method_generate.py
index 6859ebe1..7f658ee1 100644
--- a/_doc/final/plot_export_tiny_llm_method_generate.py
+++ b/_doc/final/plot_export_tiny_llm_method_generate.py
@@ -28,7 +28,14 @@
 def generate_text(
-    prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
+    prompt,
+    model,
+    tokenizer,
+    max_length=50,
+    temperature=1,
+    top_k=50,
+    top_p=0.95,
+    do_sample=False,
 ):
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"]
@@ -41,7 +48,7 @@ def generate_text(
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
-        do_sample=True,
+        do_sample=do_sample,
     )

     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py
index dfcd15be..03898e78 100644
--- a/onnx_diagnostic/export/api.py
+++ b/onnx_diagnostic/export/api.py
@@ -428,6 +428,16 @@ def _reorder_kwargs(self, kwargs):
             new_kwargs[k] = v
         return new_kwargs

+    def is_empty_cache(self, cache):
+        if cache is None:
+            return True
+        if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
+            if len(cache.layers) == 0:
+                return True
+            if len(cache.layers) == 1 and cache.layers[0].keys is None:
+                return True
+        return False
+
     def forward(self, *args, **kwargs):
         if not self._export_done:
             inp_args = args
@@ -443,6 +453,7 @@ def forward(self, *args, **kwargs):
                     if v is not None
                     and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
                     and not isinstance(v, (bool, int, float))
+                    and not self.is_empty_cache(v)
                 }
             )
             inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py
index a8c3aadd..f02614a6 100644
--- a/onnx_diagnostic/export/dynamic_shapes.py
+++ b/onnx_diagnostic/export/dynamic_shapes.py
@@ -834,7 +834,7 @@ def guess_dynamic_shape_object(
     """Guesses the dynamic shapes for one argument."""
     if len(objs) == 0:
         return None
-    set_types = set(type(o) for o in objs)
+    set_types = set(type(o) for o in objs if o is not None)
     assert (
         len(set_types) == 1
     ), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
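A minimal standalone sketch of the cache-emptiness test added to onnx_diagnostic/export/api.py above, assuming a transformers version recent enough to expose DynamicCache.layers (>4.55.3, as the patch itself notes); the free function below only mirrors the new method for illustration::

    from transformers.cache_utils import DynamicCache

    def is_empty_cache(cache) -> bool:
        # A cache counts as empty when it is None, has no layers, or has a
        # single lazily-created layer whose key tensor was never filled.
        if cache is None:
            return True
        if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
            if len(cache.layers) == 0:
                return True
            if len(cache.layers) == 1 and cache.layers[0].keys is None:
                return True
        return False

    # Depending on the transformers version, a fresh DynamicCache has either
    # no layers or one unfilled layer; both cases count as empty.
    print(is_empty_cache(DynamicCache()))  # True
    print(is_empty_cache(None))            # True

Skipping such empty caches keeps them out of the recorded example inputs, where they would otherwise be captured as real inputs by the exporter.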
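The one-line change in guess_dynamic_shape_object matters when the same argument alternates between a tensor and None across observed calls; a short illustration (the observed list is made up)::

    import torch

    # An optional argument may be a tensor on some calls and None on others;
    # filtering None keeps the single-type assertion sound.
    observed = [torch.ones(2, 3), None, torch.ones(4, 3)]
    set_types = set(type(o) for o in observed if o is not None)
    assert set_types == {torch.Tensor}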
diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index 371160cd..303ebb48 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -832,4 +832,13 @@ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_
         # This is used to expand the cache when it does not contain enough layers.
         # This is needed since transformers>4.55.3.
         cache.layer_class_to_replicate = cache.layers[0].__class__
+    assert (
+        not hasattr(cache, "layers")
+        or len(cache.layers) != 1
+        or cache.layers[0].keys is not None
+    ), (
+        f"Unexpected cache with a single empty layer: {len(cache.layers)=}, "
+        f"first key={cache.layers[0].keys}, "  # type: ignore[attr-defined]
+        f"first value={cache.layers[0].values}"  # type: ignore[attr-defined]
+    )
     return cache
diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py
index 1a8f4d0f..8888dbef 100644
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -851,9 +851,14 @@ def torch_deepcopy(value: Any) -> Any:
         from .cache_helper import CacheKeyValue

         ca = CacheKeyValue(value)
-        return make_dynamic_cache(
-            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
-        )
+        pairs = list(zip(ca.key_cache, ca.value_cache))
+        assert not hasattr(value, "layers") or len(value.layers) == len(pairs), (
+            f"Size mismatch between {len(value.layers)=} and {len(pairs)=}. "
+            f"value={string_type(value, with_shape=True)}, "
+            f"first key={value.layers[0].keys}, "
+            f"first value={value.layers[0].values}"
+        )
+        return make_dynamic_cache(torch_deepcopy(pairs), cls_layers=ca.cls_layers)

     if value.__class__.__name__ == "StaticCache":
         from .cache_helper import CacheKeyValue
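Both new assertions guard the same invariant: a layered DynamicCache exposes exactly one (key, value) pair per layer, so a deep copy must preserve the layer count. A sketch of the invariant in plain transformers, assuming the backward-compatible key_cache/value_cache accessors still exist alongside layers::

    import torch
    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    # update(key_states, value_states, layer_idx) fills one layer per index.
    for layer_idx in range(2):
        k = torch.zeros(1, 2, 4, 8)  # (batch, heads, seq_len, head_dim)
        v = torch.zeros(1, 2, 4, 8)
        cache.update(k, v, layer_idx)

    pairs = list(zip(cache.key_cache, cache.value_cache))
    assert not hasattr(cache, "layers") or len(cache.layers) == len(pairs)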
diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py
index 2e09d6a2..e642f623 100644
--- a/onnx_diagnostic/investigate/input_observer.py
+++ b/onnx_diagnostic/investigate/input_observer.py
@@ -32,7 +32,7 @@ def _flatten_unflatten_for_dynamic_shapes(
         like replacing them by a shape

     Returns:
-        the serialized object
+        the flattened object
     """
     if isinstance(obj, torch.Tensor):
         return change_function(obj) if change_function else obj
@@ -104,14 +104,22 @@ def _infer_dynamic_dimensions(
 class InputCandidate:
     """Retains one set of inputs given to the forward method or any other
     method the class :class:`InputObserver` is stealing from.
+    Any class is allowed as long as it can be flattened.

     Args:
-        args: Positional arguments.
-        kwargs: Optional arguments.
-        clone: Clone the inputs before storing them. Some tensors
+        args:
+            Positional arguments.
+        kwargs:
+            Optional arguments.
+        clone:
+            Clones the inputs before storing them. Some tensors
             may be modified in place, so the original values must be retained.
-        cst_kwargs: Any optional arguments constant over multiple calls.
+        cst_kwargs:
+            Any optional arguments constant over multiple calls.
             int, float, str, bool values must be stored here.
+
+    The constructor flattens the received arguments.
+    Any necessary flattening function should have been registered first.
     """

     def __init__(
@@ -671,18 +679,20 @@ class InputObserver:
     >>> )

     With LLM:

+    >>> input_observer = InputObserver()
     >>> with input_observer(model):
     >>>     model.generate(input_ids)
     >>> ep = torch.export.export(  # or torch.onnx.export
     >>>     model,
-    >>>     ()
+    >>>     (),
     >>>     kwargs=input_observer.infer_arguments(),
     >>>     dynamic_shapes=input_observer.infer_dynamic_shapes(),
     >>> )

     Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`,
-    :ref:`l-plot-whisper-tiny-export-input-observer`.
+    :ref:`l-plot-whisper-tiny-export-input-observer`,
+    :ref:`l-plot-gemma3-tiny-export-input-observer`.
     """

     def __init__(self, missing: dict[str, Any] | None = None):
@@ -865,17 +875,26 @@ def check_discrepancies(
         with the saved onnx model.

         Args:
-            onnx_model: ONNX Model to verify.
-            atol: Absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16.
-            rtol: Relative tolerance.
-            hist: Thresholds, the function determines the number of discrepancies
+            onnx_model:
+                ONNX model to verify.
+            atol:
+                Absolute tolerance; recommended values: 1e-4 for float32, 1e-2 for float16.
+            rtol:
+                Relative tolerance.
+            hist:
+                Thresholds; the function counts the number of discrepancies
                 above these thresholds.
-            progress_bar: Shows a progress bar (requires :epkg:`tqdm`).
-            include_io: Shows inputs/outputs shapes in the summary
+            progress_bar:
+                Shows a progress bar (requires :epkg:`tqdm`).
+            include_io:
+                Shows input/output shapes in the summary
                 returned by this function.

         Returns:
             A list of dictionaries, ready to be consumed by a dataframe.
+
+        The function catches exceptions and reports the error in the
+        returned summary.
         """
         sess = OnnxruntimeEvaluator(onnx_model, whole=True)
         input_names = sess.input_names
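Putting the pieces together, a hypothetical end-to-end sketch of the workflow the docstrings above describe; the toy module, file name, and tolerance are illustrative, and check_discrepancies is assumed here to be a method of InputObserver that replays the recorded inputs against the saved model::

    import torch
    from onnx_diagnostic.investigate.input_observer import InputObserver

    class Toy(torch.nn.Module):
        def forward(self, x, y):
            return x + y

    model = Toy()
    observer = InputObserver()
    with observer(model):
        # Two calls with different shapes let the observer infer dynamic dims.
        model(x=torch.randn(2, 3), y=torch.randn(2, 3))
        model(x=torch.randn(4, 3), y=torch.randn(4, 3))

    ep = torch.onnx.export(
        model,
        (),
        kwargs=observer.infer_arguments(),
        dynamic_shapes=observer.infer_dynamic_shapes(),
        dynamo=True,
    )
    ep.save("toy.onnx")

    # Returns a list of dictionaries, ready for pandas.DataFrame(summary).
    summary = observer.check_discrepancies("toy.onnx", atol=1e-4)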