diff --git a/tests/_util.py b/tests/_util.py index 42dd54cb..d3ac99cb 100644 --- a/tests/_util.py +++ b/tests/_util.py @@ -20,6 +20,7 @@ # limitations under the License. # import collections.abc as collections_abc +import inspect import pytest import sklearn.utils.estimator_checks @@ -62,17 +63,64 @@ def parametrize_slow(arg_names, fast_arguments, slow_arguments): ) -def parametrize_with_checks_slow(fast_arguments, slow_arguments): - fast_params = [ - list( - sklearn.utils.estimator_checks.check_estimator( - fast_arg, generate_only=True))[0] for fast_arg in fast_arguments - ] - slow_params = [ - list( - sklearn.utils.estimator_checks.check_estimator( - slow_arg, generate_only=True))[0] for slow_arg in slow_arguments - ] +def parametrize_with_checks_slow(fast_arguments, slow_arguments, generate_only=True): + """Parametrize tests with sklearn estimator checks, supporting fast/slow test modes. + + Args: + fast_arguments: List of estimators to use for fast tests. + slow_arguments: List of estimators to use for slow tests. + generate_only: If True, only generate tests without running them (default: True). + This parameter is only used with sklearn < 1.5. In newer versions, + tests are always generated (not run immediately). + + Returns: + A pytest parametrize decorator configured with fast and slow test parameters. + """ + # Check if check_estimator supports generate_only parameter (sklearn < 1.5) + check_estimator_sig = inspect.signature(sklearn.utils.estimator_checks.check_estimator) + supports_generate_only = 'generate_only' in check_estimator_sig.parameters + + def _get_first_check_for_estimator(estimator): + """Helper to get the first check for a given estimator in new sklearn API. + + Note: We only return the first check to match the original behavior with [0]. + This maintains consistency with the legacy API which also returned [0]. + """ + try: + decorator = sklearn.utils.estimator_checks.parametrize_with_checks([estimator]) + # Extract the generator from the decorator's internal structure + # Note: This accesses pytest/sklearn internals and may be fragile across versions + gen = decorator.mark.args[1] + # Convert to list and take first element to avoid generator exhaustion issues + checks_list = list(gen) + return checks_list[0] if checks_list else None + except (AttributeError, IndexError, TypeError) as e: + raise RuntimeError( + f"Failed to extract checks from sklearn.utils.estimator_checks.parametrize_with_checks. " + f"This may be due to sklearn API changes. Error: {e}" + ) + + if supports_generate_only: + # Old sklearn API (<= 1.4.x): use check_estimator with generate_only=True + fast_params = [ + list( + sklearn.utils.estimator_checks.check_estimator( + fast_arg, generate_only=generate_only))[0] for fast_arg in fast_arguments + ] + slow_params = [ + list( + sklearn.utils.estimator_checks.check_estimator( + slow_arg, generate_only=generate_only))[0] for slow_arg in slow_arguments + ] + else: + # New sklearn API (>= 1.5): use parametrize_with_checks to get test params + # For each estimator, get the first check + fast_params = [_get_first_check_for_estimator(fast_arg) for fast_arg in fast_arguments] + slow_params = [_get_first_check_for_estimator(slow_arg) for slow_arg in slow_arguments] + # Filter out any None values + fast_params = [p for p in fast_params if p is not None] + slow_params = [p for p in slow_params if p is not None] + return parametrize_slow("estimator,check", fast_params, slow_params)