1 change: 1 addition & 0 deletions _doc/conf.py
@@ -144,6 +144,7 @@ def linkcode_resolve(domain, info):
     ("py:class", "torch.utils._pytree.Context"),
     ("py:class", "torch.utils._pytree.KeyEntry"),
     ("py:class", "torch.utils._pytree.TreeSpec"),
+    ("py:class", "torch.utils._sympy.value_ranges.ValueRanges"),
     ("py:class", "transformers.BartForConditionalGeneration"),
     ("py:class", "transformers.LlamaConfig"),
     ("py:class", "transformers.cache_utils.Cache"),
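For context: these tuples follow the format of Sphinx's nitpick_ignore option, a list of (domain:reftype, target) pairs that silence "reference target not found" warnings for types Sphinx cannot resolve. Assuming that is the list being extended here, a minimal equivalent in conf.py would be:

    # Sketch of a Sphinx conf.py fragment (assumes the entries above live in
    # nitpick_ignore): each pair suppresses one unresolved-reference warning.
    nitpick_ignore = [
        ("py:class", "torch.utils._sympy.value_ranges.ValueRanges"),
    ]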
11 changes: 9 additions & 2 deletions _doc/final/plot_export_tiny_llm_method_generate.py
@@ -28,7 +28,14 @@
 
 
 def generate_text(
-    prompt, model, tokenizer, max_length=50, temperature=1, top_k=50, top_p=0.95
+    prompt,
+    model,
+    tokenizer,
+    max_length=50,
+    temperature=1,
+    top_k=50,
+    top_p=0.95,
+    do_sample=False,
 ):
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"]
@@ -41,7 +48,7 @@ def generate_text(
         temperature=temperature,
         top_k=top_k,
         top_p=top_p,
-        do_sample=True,
+        do_sample=do_sample,
     )
 
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
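The new do_sample flag defaults to False, which makes generate() run greedy decoding; a deterministic output is easier to compare against an exported model. A minimal sketch, reusing the generate_text function above (the checkpoint name is only an assumption, any small causal LM works):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # assumed tiny checkpoint, swap in any causal LM
    tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
    model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

    # do_sample=False -> greedy decoding, so two runs produce identical text
    text_a = generate_text("Continue: it rains...", model, tokenizer)
    text_b = generate_text("Continue: it rains...", model, tokenizer)
    assert text_a == text_b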
11 changes: 11 additions & 0 deletions onnx_diagnostic/export/api.py
@@ -428,6 +428,16 @@ def _reorder_kwargs(self, kwargs):
                 new_kwargs[k] = v
         return new_kwargs
 
+    def is_empty_cache(self, cache):
+        if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
+            if len(cache.layers) == 1 and cache.layers[0].keys is None:
+                return True
+            if len(cache.layers) == 0:
+                return True
+        if cache is None:
+            return True
+        return False
+
     def forward(self, *args, **kwargs):
         if not self._export_done:
             inp_args = args
@@ -443,6 +453,7 @@ def forward(self, *args, **kwargs):
                     if v is not None
                     and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
                     and not isinstance(v, (bool, int, float))
+                    and not self.is_empty_cache(v)
                 }
             )
             inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
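A standalone sketch of the condition being added (assumes a recent transformers release, where DynamicCache keeps its state in a layers list; a freshly created cache has either no layers or one placeholder layer whose keys are still None):

    from transformers.cache_utils import DynamicCache

    cache = DynamicCache()
    empty = cache is None or (
        cache.__class__.__name__ == "DynamicCache"
        and hasattr(cache, "layers")
        and (
            len(cache.layers) == 0
            or (len(cache.layers) == 1 and cache.layers[0].keys is None)
        )
    )
    # True for a fresh cache -> the wrapper drops it from the recorded inputs
    print(empty)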
2 changes: 1 addition & 1 deletion onnx_diagnostic/export/dynamic_shapes.py
@@ -834,7 +834,7 @@ def guess_dynamic_shape_object(
     """Guesses the dynamic shapes for one argument."""
     if len(objs) == 0:
         return None
-    set_types = set(type(o) for o in objs)
+    set_types = set(type(o) for o in objs if o is not None)
     assert (
         len(set_types) == 1
     ), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
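Why the new `if o is not None` filter matters, in a standalone sketch: when the same optional argument is None in one recorded call and a tensor in another, the unfiltered set holds two types and the assertion above would fail.

    import torch

    objs = [None, torch.ones(2, 3), torch.ones(2, 5)]
    print({type(o) for o in objs})                   # {NoneType, Tensor}: ambiguous
    print({type(o) for o in objs if o is not None})  # {Tensor}: the single type asserted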
9 changes: 9 additions & 0 deletions onnx_diagnostic/helpers/cache_helper.py
@@ -832,4 +832,13 @@ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
         # This is used to expand the cache when it does not contain enough layers.
         # This is needed since transformers>4.55.3
         cache.layer_class_to_replicate = cache.layers[0].__class__
+    assert (
+        not hasattr(cache, "layers")
+        or len(cache.layers) != 1
+        or cache.layers[0].keys is not None
+    ), (
+        f"Size mismatch between {len(cache.layers)=}, "
+        f"first key={cache.layers[0].keys}, "  # type: ignore[attr-defined]
+        f"first value={cache.layers[0].values}"  # type: ignore[attr-defined]
+    )
     return cache
9 changes: 7 additions & 2 deletions onnx_diagnostic/helpers/torch_helper.py
@@ -851,9 +851,14 @@ def torch_deepcopy(value: Any) -> Any:
         from .cache_helper import CacheKeyValue
 
         ca = CacheKeyValue(value)
-        return make_dynamic_cache(
-            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
+        pairs = list(zip(ca.key_cache, ca.value_cache))
+        assert not hasattr(value, "layers") or len(value.layers) == len(pairs), (
+            f"Size mismatch between {len(value.layers)=} and {len(pairs)=}. "
+            f"value={string_type(value, with_shape=True)}, "
+            f"first key={value.layers[0].keys}, "
+            f"first value={value.layers[0].values}"
         )
+        return make_dynamic_cache(torch_deepcopy(pairs), cls_layers=ca.cls_layers)
     if value.__class__.__name__ == "StaticCache":
         from .cache_helper import CacheKeyValue
 
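Usage sketch for the guarded path (make_dynamic_cache and torch_deepcopy are the package's own helpers; the shapes are arbitrary): torch_deepcopy rebuilds a DynamicCache from its (key, value) pairs, and the new assertion catches caches whose layer count no longer matches the pairs extracted by CacheKeyValue.

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    # one layer, key and value tensors of identical shape
    cache = make_dynamic_cache([(torch.randn(1, 2, 5, 4), torch.randn(1, 2, 5, 4))])
    copy = torch_deepcopy(cache)  # independent tensors, same structure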
45 changes: 32 additions & 13 deletions onnx_diagnostic/investigate/input_observer.py
@@ -32,7 +32,7 @@ def _flatten_unflatten_for_dynamic_shapes(
     like replace them by a shape
 
     Returns:
-        the serialized object
+        the flattened object
     """
     if isinstance(obj, torch.Tensor):
         return change_function(obj) if change_function else obj
@@ -104,14 +104,22 @@ def _infer_dynamic_dimensions(
 class InputCandidate:
     """Retains one set of inputs given to the forward method or any
     other method the class :class:`InputObserver` is stealing from.
+    Any class is allowed as long as it can be flattened.
 
     Args:
-        args: Positional arguments.
-        kwargs: Optional arguments.
-        clone: Clone the inputs before storing them. Some tensors
+        args:
+            Positional arguments.
+        kwargs:
+            Optional arguments.
+        clone:
+            Clones the inputs before storing them. Some tensors
             may be modified inplace, the original value must be retained.
-        cst_kwargs: Any optional arguments constant over multiple calls.
+        cst_kwargs:
+            Any optional arguments constant over multiple calls.
+            int, float, str, bool values must be stored here.
 
+    The constructor flattens the received arguments.
+    Any necessary flattening function should have been registered first.
     """
 
     def __init__(
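Since the docstring now requires any flattening function to be registered first, here is a sketch of such a registration with torch's pytree utilities (the Pair class is invented for illustration):

    import torch
    from torch.utils import _pytree as pytree

    class Pair:
        def __init__(self, a: torch.Tensor, b: torch.Tensor):
            self.a, self.b = a, b

    pytree.register_pytree_node(
        Pair,
        lambda p: ([p.a, p.b], None),           # flatten: children and context
        lambda children, ctx: Pair(*children),  # unflatten
    )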
@@ -671,18 +679,20 @@ class InputObserver:
     >>> )
 
     With LLM:
+
     >>> input_observer = InputObserver()
     >>> with input_observer(model):
     >>>     model.generate(input_ids)
     >>> ep = torch.export.export(  # or torch.onnx.export
     >>>     model,
-    >>>     ()
+    >>>     (),
     >>>     kwargs=input_observer.infer_arguments(),
     >>>     dynamic_shapes=input_observer.infer_dynamic_shapes(),
     >>> )
 
     Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`,
-    :ref:`l-plot-whisper-tiny-export-input-observer`.
+    :ref:`l-plot-whisper-tiny-export-input-observer`,
+    :ref:`l-plot-gemma3-tiny-export-input-observer`.
     """
 
     def __init__(self, missing: dict[str, Any] | None = None):
@@ -865,17 +875,26 @@ def check_discrepancies(
         with the saved onnx model.
 
         Args:
-            onnx_model: ONNX Model to verify.
-            atol: Absolute tolerance, recommended values, 1e-4 for float, 1e-2 for float16.
-            rtol: Relative tolerance.
-            hist: Thresholds, the function determines the number of discrepancies
+            onnx_model:
+                ONNX Model to verify.
+            atol:
+                Absolute tolerance, recommended values, 1e-4 for float, 1e-2 for float16.
+            rtol:
+                Relative tolerance.
+            hist:
+                Thresholds, the function determines the number of discrepancies
                 above these thresholds.
-            progress_bar: Shows a progress bar (requires :epkg:`tqdm`).
-            include_io: Shows inputs/outputs shapes in the summary
+            progress_bar:
+                Shows a progress bar (requires :epkg:`tqdm`).
+            include_io:
+                Shows inputs/outputs shapes in the summary
                 returned by this function.
 
         Returns:
             A list of dictionaries, ready to be consumed by a dataframe.
+
+        The function catches exceptions and shows the error in the returned
+        summary.
         """
         sess = OnnxruntimeEvaluator(onnx_model, whole=True)
         input_names = sess.input_names
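A hedged usage sketch built only from the docstring above (it assumes check_discrepancies is a method of InputObserver, as the surrounding diff suggests; the exact signature may differ): after recording calls with the observer and exporting the model, compare the recorded PyTorch outputs with the saved ONNX file.

    import pandas

    summary = input_observer.check_discrepancies(
        "model.onnx",       # ONNX model to verify
        atol=1e-4,          # recommended for float32 (1e-2 for float16)
        progress_bar=True,  # requires tqdm
    )
    print(pandas.DataFrame(summary))  # one row per recorded call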