Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/_schema/assets/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class IndexAssetSchema(ArtifactSchema):
name = fields.Str(required=True, allow_none=False)
id = ArmStr(azureml_type=AzureMLResourceType.INDEX, dump_only=True)
stage = fields.Str(default="Development")
stage = fields.Str(dump_default="Development")
path = fields.Str(
required=True,
metadata={
Expand Down
31 changes: 25 additions & 6 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
import traceback
import typing
from abc import abstractmethod
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Union

from marshmallow import RAISE, fields
from marshmallow.exceptions import ValidationError
from marshmallow.fields import Field, Nested
from marshmallow.utils import FieldInstanceResolutionError, from_iso_datetime, resolve_field_instance
from marshmallow.fields import _resolve_field_instance as resolve_field_instance # Private API in marshmallow 4.0

from ..._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource, parse_name_label, parse_name_version
from ..._utils._experimental import _is_warning_cached
Expand Down Expand Up @@ -249,7 +250,9 @@ def _serialize(self, value, attr, obj, **kwargs):

def _validate(self, value):
try:
from_iso_datetime(value)
# Handle 'Z' UTC indicator which fromisoformat doesn't support
normalized_value = value.replace('Z', '+00:00') if isinstance(value, str) else value
datetime.fromisoformat(normalized_value)
except Exception as e:
raise ValidationError(f"Not a valid ISO8601-formatted datetime string: {value}") from e

Expand Down Expand Up @@ -418,9 +421,19 @@ class NestedField(Nested):
"""anticipates the default coming in next marshmallow version, unknown=True."""

def __init__(self, *args, **kwargs):
if kwargs.get("unknown") is None:
kwargs["unknown"] = RAISE
# In marshmallow 4.0, 'unknown' is only valid on Schema classes, not Field classes
# Extract it from kwargs and apply to schema instance later
self._unknown_value = kwargs.pop("unknown", RAISE)
super().__init__(*args, **kwargs)

@property
def schema(self):
"""Override to apply unknown setting to the nested schema instance."""
nested_schema = super().schema
# Apply the unknown setting to the schema instance
if nested_schema is not None and hasattr(nested_schema, 'unknown'):
nested_schema.unknown = self._unknown_value
return nested_schema


# Note: Currently contains a bug where the order in which fields are inputted can potentially cause a bug
Expand All @@ -432,14 +445,17 @@ class UnionField(fields.Field):
"""A field that can be one of multiple types."""

def __init__(self, union_fields: List[fields.Field], is_strict=False, **kwargs):
# In marshmallow 4.0, 'unknown' is only valid on Schema classes, not Field classes
# Remove it from kwargs if present
kwargs.pop("unknown", None)
super().__init__(**kwargs)
try:
# add the validation and make sure union_fields must be subclasses or instances of
# marshmallow.base.FieldABC
self._union_fields = [resolve_field_instance(cls_or_instance) for cls_or_instance in union_fields]
# TODO: make serialization/de-serialization work in the same way as json schema when is_strict is True
self.is_strict = is_strict # S\When True, combine fields with oneOf instead of anyOf at schema generation
except FieldInstanceResolutionError as error:
except Exception as error:
raise ValueError(
'Elements of "union_fields" must be subclasses or instances of marshmallow.base.FieldABC.'
) from error
Expand Down Expand Up @@ -552,6 +568,9 @@ def __init__(
resolve_field_instance(cls_or_instance) for cls_or_instance in type_sensitive_fields
]

# In marshmallow 4.0, 'unknown' is only valid on Schema classes, not Field classes
# Remove it from kwargs before passing to parent UnionField
kwargs.pop("unknown", None)
super(TypeSensitiveUnionField, self).__init__(union_fields, **kwargs)
self._type_field_name = type_field_name

Expand Down Expand Up @@ -876,7 +895,7 @@ def __init__(self, experimental_field: fields.Field, **kwargs):
try:
self._experimental_field = resolve_field_instance(experimental_field)
self.required = experimental_field.required
except FieldInstanceResolutionError as error:
except Exception as error:
raise ValueError(
'"experimental_field" must be subclasses or instances of marshmallow.base.FieldABC.'
) from error
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

# pylint: disable=unused-argument

from marshmallow import fields, post_load
from marshmallow import fields, post_load, validate
from azure.ai.ml._schema.core.schema import PatchedSchemaMeta


class ComputeConfigurationSchema(metaclass=PatchedSchemaMeta):
compute_type = fields.Str(allowed_values=["ServerlessSpark"])
compute_type = fields.Str(validate=validate.OneOf(["ServerlessSpark"]))


class ServerlessSparkComputeSchema(ComputeConfigurationSchema):
Expand Down
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
"pyyaml>=5.1.0,<7.0.0",
"azure-core>=1.23.0",
"azure-mgmt-core>=1.3.0",
"marshmallow>=3.5,<4.0.0",
"marshmallow>=3.5,<5.0.0",
"jsonschema>=4.0.0,<5.0.0",
"tqdm<5.0.0",
# Used for PR 825138
Expand Down
Loading