146 changes: 120 additions & 26 deletions eng/emitter-package-lock.json

Large diffs are not rendered by default.

30 changes: 16 additions & 14 deletions eng/emitter-package.json
@@ -1,23 +1,25 @@
{
"name": "dist/src/index.js",
"main": "dist/src/index.js",
"dependencies": {
"@azure-tools/typespec-python": "0.56.1"
"@azure-tools/typespec-python": "https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-js-test-autorest@local/npm/registry/@azure-tools/typespec-python/-/typespec-python-0.56.1-alpha.20251231.1.tgz"
},
"devDependencies": {
"@typespec/compiler": "^1.7.1",
"@typespec/http": "^1.7.0",
"@typespec/rest": "~0.77.0",
"@typespec/versioning": "~0.77.0",
"@typespec/openapi": "^1.7.0",
"@typespec/events": "~0.77.0",
"@typespec/sse": "~0.77.0",
"@typespec/streams": "~0.77.0",
"@typespec/xml": "~0.77.0",
"@azure-tools/typespec-autorest": "~0.63.0",
"@azure-tools/typespec-azure-core": "~0.63.0",
"@azure-tools/typespec-azure-resource-manager": "~0.63.0",
"@azure-tools/typespec-autorest": "~0.63.1",
"@azure-tools/typespec-azure-rulesets": "~0.63.0",
"@azure-tools/typespec-client-generator-core": "~0.63.1",
"@azure-tools/typespec-liftr-base": "0.11.0"
"@typespec/compiler": "^1.7.0",
"@typespec/events": "~0.77.0",
"@typespec/http": "^1.7.0",
"@typespec/openapi": "^1.7.0",
"@typespec/rest": "~0.77.0",
"@typespec/sse": "~0.77.0",
"@typespec/streams": "~0.77.0",
"@typespec/versioning": "~0.77.0",
"@typespec/xml": "~0.77.0"
},
"overrides": {
"@autorest/python": "https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-js-test-autorest@local/npm/registry/@autorest/python/-/python-6.45.1-alpha.20251231.1.tgz"
}
}
}
113 changes: 110 additions & 3 deletions sdk/ai/azure-ai-agents/azure/ai/agents/_utils/model_base.py
@@ -1,3 +1,4 @@
# pylint: disable=line-too-long,useless-suppression,too-many-lines
# coding=utf-8
# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
@@ -170,6 +171,21 @@ def default(self, o): # pylint: disable=too-many-return-statements
r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT"
)

_ARRAY_ENCODE_MAPPING = {
"pipeDelimited": "|",
"spaceDelimited": " ",
"commaDelimited": ",",
"newlineDelimited": "\n",
}


def _deserialize_array_encoded(delimit: str, attr):
if isinstance(attr, str):
if attr == "":
return []
return attr.split(delimit)
return attr


def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
"""Deserialize ISO-8601 formatted string into Datetime object.
@@ -314,6 +330,8 @@ def _deserialize_int_as_str(attr):
def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None):
if annotation is int and rf and rf._format == "str":
return _deserialize_int_as_str
if annotation is str and rf and rf._format in _ARRAY_ENCODE_MAPPING:
return functools.partial(_deserialize_array_encoded, _ARRAY_ENCODE_MAPPING[rf._format])
if rf and rf._format:
return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format)
return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore
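
For context, a minimal sketch of what the new delimiter-aware deserializer does. The helper names come from this diff; the standalone calls below are illustrative only and are not part of the change:

# Illustrative use of the private helper added above.
from azure.ai.agents._utils.model_base import _deserialize_array_encoded

assert _deserialize_array_encoded("|", "red|green|blue") == ["red", "green", "blue"]
assert _deserialize_array_encoded("|", "") == []                    # empty string becomes an empty list
assert _deserialize_array_encoded("|", ["a", "b"]) == ["a", "b"]    # non-strings pass through untouched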
@@ -352,9 +370,39 @@ def __contains__(self, key: typing.Any) -> bool:
return key in self._data

def __getitem__(self, key: str) -> typing.Any:
# If this key has been deserialized (for mutable types), we need to handle serialization
if hasattr(self, "_attr_to_rest_field"):
cache_attr = f"_deserialized_{key}"
if hasattr(self, cache_attr):
rf = _get_rest_field(getattr(self, "_attr_to_rest_field"), key)
if rf:
value = self._data.get(key)
if isinstance(value, (dict, list, set)):
# For mutable types, serialize and return
# But also update _data with serialized form and clear flag
# so mutations via this returned value affect _data
serialized = _serialize(value, rf._format)
# If serialized form is same type (no transformation needed),
# return _data directly so mutations work
if isinstance(serialized, type(value)) and serialized == value:
return self._data.get(key)
# Otherwise return serialized copy and clear flag
try:
object.__delattr__(self, cache_attr)
except AttributeError:
pass
# Store serialized form back
self._data[key] = serialized
return serialized
return self._data.__getitem__(key)

def __setitem__(self, key: str, value: typing.Any) -> None:
# Clear any cached deserialized value when setting through dictionary access
cache_attr = f"_deserialized_{key}"
try:
object.__delattr__(self, cache_attr)
except AttributeError:
pass
self._data.__setitem__(key, value)

def __delitem__(self, key: str) -> None:
@@ -482,6 +530,8 @@ def _is_model(obj: typing.Any) -> bool:

def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements
if isinstance(o, list):
if format in _ARRAY_ENCODE_MAPPING and all(isinstance(x, str) for x in o):
return _ARRAY_ENCODE_MAPPING[format].join(o)
return [_serialize(x, format) for x in o]
if isinstance(o, dict):
return {k: _serialize(v, format) for k, v in o.items()}
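
The serialization side mirrors this: a list made up entirely of strings and tagged with one of the delimited formats collapses back into a single string, so the two helpers round-trip. Again, these calls are illustrative only:

# Illustrative round trip through the format-aware helpers in this file.
from azure.ai.agents._utils.model_base import _serialize, _deserialize_array_encoded

encoded = _serialize(["red", "green", "blue"], format="pipeDelimited")   # -> "red|green|blue"
assert _deserialize_array_encoded("|", encoded) == ["red", "green", "blue"]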
@@ -637,6 +687,10 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self:
if not rf._rest_name_input:
rf._rest_name_input = attr
cls._attr_to_rest_field: dict[str, _RestField] = dict(attr_to_rest_field.items())
cls._backcompat_attr_to_rest_field: dict[str, _RestField] = {
Model._get_backcompat_attribute_name(cls._attr_to_rest_field, attr): rf
for attr, rf in cls._attr_to_rest_field.items()
}
cls._calculated.add(f"{cls.__module__}.{cls.__qualname__}")

return super().__new__(cls)
@@ -646,6 +700,16 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None:
if hasattr(base, "__mapping__"):
base.__mapping__[discriminator or cls.__name__] = cls # type: ignore

@classmethod
def _get_backcompat_attribute_name(cls, attr_to_rest_field: dict[str, "_RestField"], attr_name: str) -> str:
rest_field_obj = attr_to_rest_field.get(attr_name) # pylint: disable=protected-access
if rest_field_obj is None:
return attr_name
original_tsp_name = getattr(rest_field_obj, "_original_tsp_name", None) # pylint: disable=protected-access
if original_tsp_name:
return original_tsp_name
return attr_name

@classmethod
def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]:
for v in cls.__dict__.values():
@@ -766,6 +830,17 @@ def _deserialize_sequence(
return obj
if isinstance(obj, ET.Element):
obj = list(obj)
try:
if (
isinstance(obj, str)
and isinstance(deserializer, functools.partial)
and isinstance(deserializer.args[0], functools.partial)
and deserializer.args[0].func == _deserialize_array_encoded # pylint: disable=comparison-with-callable
):
# encoded string may be deserialized to sequence
return deserializer(obj)
except: # pylint: disable=bare-except
pass
return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)


@@ -971,6 +1046,7 @@ def _failsafe_deserialize_xml(
return None


# pylint: disable=too-many-instance-attributes
class _RestField:
def __init__(
self,
Expand All @@ -983,6 +1059,7 @@ def __init__(
format: typing.Optional[str] = None,
is_multipart_file_input: bool = False,
xml: typing.Optional[dict[str, typing.Any]] = None,
original_tsp_name: typing.Optional[str] = None,
):
self._type = type
self._rest_name_input = name
@@ -994,10 +1071,15 @@ def __init__(
self._format = format
self._is_multipart_file_input = is_multipart_file_input
self._xml = xml if xml is not None else {}
self._original_tsp_name = original_tsp_name

@property
def _class_type(self) -> typing.Any:
return getattr(self._type, "args", [None])[0]
result = getattr(self._type, "args", [None])[0]
# type may be wrapped by nested functools.partial so we need to check for that
if isinstance(result, functools.partial):
return getattr(result, "args", [None])[0]
return result

@property
def _rest_name(self) -> str:
@@ -1008,14 +1090,37 @@ def _rest_name(self) -> str:
def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin
# by this point, type and rest_name will have a value bc we default
# them in __new__ of the Model class
item = obj.get(self._rest_name)
# Use _data.get() directly to avoid triggering __getitem__ which clears the cache
item = obj._data.get(self._rest_name)
if item is None:
return item
if self._is_model:
return item
return _deserialize(self._type, _serialize(item, self._format), rf=self)

# For mutable types, we want mutations to directly affect _data
# Check if we've already deserialized this value
cache_attr = f"_deserialized_{self._rest_name}"
if hasattr(obj, cache_attr):
# Return the value from _data directly (it's been deserialized in place)
return obj._data.get(self._rest_name)

deserialized = _deserialize(self._type, _serialize(item, self._format), rf=self)

# For mutable types, store the deserialized value back in _data
# so mutations directly affect _data
if isinstance(deserialized, (dict, list, set)):
obj._data[self._rest_name] = deserialized
object.__setattr__(obj, cache_attr, True) # Mark as deserialized
return deserialized

return deserialized

def __set__(self, obj: Model, value) -> None:
# Clear the cached deserialized object when setting a new value
cache_attr = f"_deserialized_{self._rest_name}"
if hasattr(obj, cache_attr):
object.__delattr__(obj, cache_attr)

if value is None:
# we want to wipe out entries if users set attr to None
try:
Expand Down Expand Up @@ -1045,6 +1150,7 @@ def rest_field(
format: typing.Optional[str] = None,
is_multipart_file_input: bool = False,
xml: typing.Optional[dict[str, typing.Any]] = None,
original_tsp_name: typing.Optional[str] = None,
) -> typing.Any:
return _RestField(
name=name,
@@ -1054,6 +1160,7 @@ def rest_field(
format=format,
is_multipart_file_input=is_multipart_file_input,
xml=xml,
original_tsp_name=original_tsp_name,
)


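Taken together, the descriptor and mapping changes appear intended to make attribute access hand back the deserialized (mutable) value, cached in _data so in-place edits stick, while dictionary access keeps returning the wire representation. A rough usage sketch under that reading; the model class, field name, and payload below are made up for illustration and are not taken from the PR:

# Hypothetical model built on Model/rest_field from this module.
from typing import List, Optional
from azure.ai.agents._utils.model_base import Model, rest_field

class ExampleModel(Model):
    tags: Optional[List[str]] = rest_field(format="pipeDelimited")

m = ExampleModel({"tags": "red|green"})
m.tags.append("blue")                   # attribute access deserializes in place, so the mutation lands in _data
assert m.tags == ["red", "green", "blue"]
assert m["tags"] == "red|green|blue"    # dictionary access re-serializes to the delimited wire form
m["tags"] = "solid"                     # __setitem__ drops any cached deserialized value
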
17 changes: 14 additions & 3 deletions sdk/ai/azure-ai-agents/azure/ai/agents/_utils/serialization.py
@@ -821,13 +821,20 @@ def serialize_basic(cls, data, data_type, **kwargs):
:param str data_type: Type of object in the iterable.
:rtype: str, int, float, bool
:return: serialized object
:raises TypeError: raise if data_type is not one of str, int, float, bool.
"""
custom_serializer = cls._get_custom_serializers(data_type, **kwargs)
if custom_serializer:
return custom_serializer(data)
if data_type == "str":
return cls.serialize_unicode(data)
return eval(data_type)(data) # nosec # pylint: disable=eval-used
if data_type == "int":
return int(data)
if data_type == "float":
return float(data)
if data_type == "bool":
return bool(data)
raise TypeError("Unknown basic data type: {}".format(data_type))

@classmethod
def serialize_unicode(cls, data):
@@ -1757,7 +1764,7 @@ def deserialize_basic(self, attr, data_type): # pylint: disable=too-many-return
:param str data_type: deserialization data type.
:return: Deserialized basic type.
:rtype: str, int, float or bool
:raises TypeError: if string format is not valid.
:raises TypeError: if string format is not valid or data_type is not one of str, int, float, bool.
"""
# If we're here, data is supposed to be a basic type.
# If it's still an XML node, take the text
@@ -1783,7 +1790,11 @@

if data_type == "str":
return self.deserialize_unicode(attr)
return eval(data_type)(attr) # nosec # pylint: disable=eval-used
if data_type == "int":
return int(attr)
if data_type == "float":
return float(attr)
raise TypeError("Unknown basic data type: {}".format(data_type))

@staticmethod
def deserialize_unicode(data):
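Both basic-type paths now dispatch explicitly instead of going through eval, so only the documented basic types are converted and anything else fails fast with a TypeError. A small sketch of the resulting behavior; calling the helper directly like this is illustrative only:

# Illustrative only: exercising the explicit dispatch that replaces eval().
from azure.ai.agents._utils.serialization import Serializer

assert Serializer.serialize_basic(5, "int") == 5
assert Serializer.serialize_basic("3.5", "float") == 3.5
try:
    Serializer.serialize_basic(object(), "complex")   # previously eval("complex") would have been attempted
except TypeError:
    pass                                              # now rejected with an explicit error
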
9 changes: 5 additions & 4 deletions sdk/ai/azure-ai-agents/azure/ai/agents/_utils/utils.py
@@ -49,19 +49,20 @@ def serialize_multipart_data_entry(data_entry: Any) -> Any:

def prepare_multipart_form_data(
body: Mapping[str, Any], multipart_fields: list[str], data_fields: list[str]
) -> tuple[list[FileType], dict[str, Any]]:
) -> list[FileType]:
files: list[FileType] = []
data: dict[str, Any] = {}
for multipart_field in multipart_fields:
multipart_entry = body.get(multipart_field)
if isinstance(multipart_entry, list):
files.extend([(multipart_field, e) for e in multipart_entry])
elif multipart_entry:
files.append((multipart_field, multipart_entry))

# if files is empty, sdk core library can't handle multipart/form-data correctly, so
# we put data fields into files with filename as None to avoid that scenario.
for data_field in data_fields:
data_entry = body.get(data_field)
if data_entry:
data[data_field] = serialize_multipart_data_entry(data_entry)
files.append((data_field, str(serialize_multipart_data_entry(data_entry))))

return files, data
return files
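
The helper now returns just the files list: plain data fields are appended alongside the real file parts so the request is always built as multipart/form-data. A rough usage sketch with made-up field names:

# Illustrative call; "file" and "purpose" are hypothetical field names.
from azure.ai.agents._utils.utils import prepare_multipart_form_data

body = {"file": ("notes.txt", b"hello"), "purpose": "assistants"}
files = prepare_multipart_form_data(body, multipart_fields=["file"], data_fields=["purpose"])
# Roughly: [("file", ("notes.txt", b"hello")), ("purpose", "assistants")] -- no separate data dict any more.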