Skip to content
Closed
70 changes: 59 additions & 11 deletions tests/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# limitations under the License.
#
import collections.abc as collections_abc
import inspect

import pytest
import sklearn.utils.estimator_checks
Expand Down Expand Up @@ -62,17 +63,64 @@ def parametrize_slow(arg_names, fast_arguments, slow_arguments):
)


def parametrize_with_checks_slow(fast_arguments, slow_arguments):
fast_params = [
list(
sklearn.utils.estimator_checks.check_estimator(
fast_arg, generate_only=True))[0] for fast_arg in fast_arguments
]
slow_params = [
list(
sklearn.utils.estimator_checks.check_estimator(
slow_arg, generate_only=True))[0] for slow_arg in slow_arguments
]
def parametrize_with_checks_slow(fast_arguments, slow_arguments, generate_only=True):
"""Parametrize tests with sklearn estimator checks, supporting fast/slow test modes.

Args:
fast_arguments: List of estimators to use for fast tests.
slow_arguments: List of estimators to use for slow tests.
generate_only: If True, only generate tests without running them (default: True).
This parameter is only used with sklearn < 1.5. In newer versions,
tests are always generated (not run immediately).

Returns:
A pytest parametrize decorator configured with fast and slow test parameters.
"""
# Check if check_estimator supports generate_only parameter (sklearn < 1.5)
check_estimator_sig = inspect.signature(sklearn.utils.estimator_checks.check_estimator)
supports_generate_only = 'generate_only' in check_estimator_sig.parameters

def _get_first_check_for_estimator(estimator):
"""Helper to get the first check for a given estimator in new sklearn API.

Note: We only return the first check to match the original behavior with [0].
This maintains consistency with the legacy API which also returned [0].
"""
try:
decorator = sklearn.utils.estimator_checks.parametrize_with_checks([estimator])
# Extract the generator from the decorator's internal structure
# Note: This accesses pytest/sklearn internals and may be fragile across versions
gen = decorator.mark.args[1]
# Convert to list and take first element to avoid generator exhaustion issues
checks_list = list(gen)
return checks_list[0] if checks_list else None
except (AttributeError, IndexError, TypeError) as e:
raise RuntimeError(
f"Failed to extract checks from sklearn.utils.estimator_checks.parametrize_with_checks. "
f"This may be due to sklearn API changes. Error: {e}"
)

if supports_generate_only:
# Old sklearn API (<= 1.4.x): use check_estimator with generate_only=True
fast_params = [
list(
sklearn.utils.estimator_checks.check_estimator(
fast_arg, generate_only=generate_only))[0] for fast_arg in fast_arguments
]
slow_params = [
list(
sklearn.utils.estimator_checks.check_estimator(
slow_arg, generate_only=generate_only))[0] for slow_arg in slow_arguments
]
else:
# New sklearn API (>= 1.5): use parametrize_with_checks to get test params
# For each estimator, get the first check
fast_params = [_get_first_check_for_estimator(fast_arg) for fast_arg in fast_arguments]
slow_params = [_get_first_check_for_estimator(slow_arg) for slow_arg in slow_arguments]
# Filter out any None values
fast_params = [p for p in fast_params if p is not None]
slow_params = [p for p in slow_params if p is not None]

return parametrize_slow("estimator,check", fast_params, slow_params)


Expand Down
Loading