Skip to content

Commit b0c8667

Browse files
committed
Update SDK code ai_we_0
1 parent 18442a9 commit b0c8667

File tree

9 files changed

+559
-234
lines changed

9 files changed

+559
-234
lines changed

sdk/ai/azure-ai-agents/azure/ai/agents/_utils/model_base.py

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pylint: disable=line-too-long,useless-suppression,too-many-lines
12
# coding=utf-8
23
# --------------------------------------------------------------------------
34
# Copyright (c) Microsoft Corporation. All rights reserved.
@@ -170,6 +171,21 @@ def default(self, o): # pylint: disable=too-many-return-statements
170171
r"(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s\d{4}\s\d{2}:\d{2}:\d{2}\sGMT"
171172
)
172173

174+
_ARRAY_ENCODE_MAPPING = {
175+
"pipeDelimited": "|",
176+
"spaceDelimited": " ",
177+
"commaDelimited": ",",
178+
"newlineDelimited": "\n",
179+
}
180+
181+
182+
def _deserialize_array_encoded(delimit: str, attr):
183+
if isinstance(attr, str):
184+
if attr == "":
185+
return []
186+
return attr.split(delimit)
187+
return attr
188+
173189

174190
def _deserialize_datetime(attr: typing.Union[str, datetime]) -> datetime:
175191
"""Deserialize ISO-8601 formatted string into Datetime object.
@@ -314,6 +330,8 @@ def _deserialize_int_as_str(attr):
314330
def get_deserializer(annotation: typing.Any, rf: typing.Optional["_RestField"] = None):
315331
if annotation is int and rf and rf._format == "str":
316332
return _deserialize_int_as_str
333+
if annotation is str and rf and rf._format in _ARRAY_ENCODE_MAPPING:
334+
return functools.partial(_deserialize_array_encoded, _ARRAY_ENCODE_MAPPING[rf._format])
317335
if rf and rf._format:
318336
return _DESERIALIZE_MAPPING_WITHFORMAT.get(rf._format)
319337
return _DESERIALIZE_MAPPING.get(annotation) # pyright: ignore
@@ -352,9 +370,39 @@ def __contains__(self, key: typing.Any) -> bool:
352370
return key in self._data
353371

354372
def __getitem__(self, key: str) -> typing.Any:
373+
# If this key has been deserialized (for mutable types), we need to handle serialization
374+
if hasattr(self, "_attr_to_rest_field"):
375+
cache_attr = f"_deserialized_{key}"
376+
if hasattr(self, cache_attr):
377+
rf = _get_rest_field(getattr(self, "_attr_to_rest_field"), key)
378+
if rf:
379+
value = self._data.get(key)
380+
if isinstance(value, (dict, list, set)):
381+
# For mutable types, serialize and return
382+
# But also update _data with serialized form and clear flag
383+
# so mutations via this returned value affect _data
384+
serialized = _serialize(value, rf._format)
385+
# If serialized form is same type (no transformation needed),
386+
# return _data directly so mutations work
387+
if isinstance(serialized, type(value)) and serialized == value:
388+
return self._data.get(key)
389+
# Otherwise return serialized copy and clear flag
390+
try:
391+
object.__delattr__(self, cache_attr)
392+
except AttributeError:
393+
pass
394+
# Store serialized form back
395+
self._data[key] = serialized
396+
return serialized
355397
return self._data.__getitem__(key)
356398

357399
def __setitem__(self, key: str, value: typing.Any) -> None:
400+
# Clear any cached deserialized value when setting through dictionary access
401+
cache_attr = f"_deserialized_{key}"
402+
try:
403+
object.__delattr__(self, cache_attr)
404+
except AttributeError:
405+
pass
358406
self._data.__setitem__(key, value)
359407

360408
def __delitem__(self, key: str) -> None:
@@ -482,6 +530,8 @@ def _is_model(obj: typing.Any) -> bool:
482530

483531
def _serialize(o, format: typing.Optional[str] = None): # pylint: disable=too-many-return-statements
484532
if isinstance(o, list):
533+
if format in _ARRAY_ENCODE_MAPPING and all(isinstance(x, str) for x in o):
534+
return _ARRAY_ENCODE_MAPPING[format].join(o)
485535
return [_serialize(x, format) for x in o]
486536
if isinstance(o, dict):
487537
return {k: _serialize(v, format) for k, v in o.items()}
@@ -637,6 +687,10 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> Self:
637687
if not rf._rest_name_input:
638688
rf._rest_name_input = attr
639689
cls._attr_to_rest_field: dict[str, _RestField] = dict(attr_to_rest_field.items())
690+
cls._backcompat_attr_to_rest_field: dict[str, _RestField] = {
691+
Model._get_backcompat_attribute_name(cls._attr_to_rest_field, attr): rf
692+
for attr, rf in cls._attr_to_rest_field.items()
693+
}
640694
cls._calculated.add(f"{cls.__module__}.{cls.__qualname__}")
641695

642696
return super().__new__(cls)
@@ -646,6 +700,16 @@ def __init_subclass__(cls, discriminator: typing.Optional[str] = None) -> None:
646700
if hasattr(base, "__mapping__"):
647701
base.__mapping__[discriminator or cls.__name__] = cls # type: ignore
648702

703+
@classmethod
704+
def _get_backcompat_attribute_name(cls, attr_to_rest_field: dict[str, "_RestField"], attr_name: str) -> str:
705+
rest_field_obj = attr_to_rest_field.get(attr_name) # pylint: disable=protected-access
706+
if rest_field_obj is None:
707+
return attr_name
708+
original_tsp_name = getattr(rest_field_obj, "_original_tsp_name", None) # pylint: disable=protected-access
709+
if original_tsp_name:
710+
return original_tsp_name
711+
return attr_name
712+
649713
@classmethod
650714
def _get_discriminator(cls, exist_discriminators) -> typing.Optional["_RestField"]:
651715
for v in cls.__dict__.values():
@@ -766,6 +830,17 @@ def _deserialize_sequence(
766830
return obj
767831
if isinstance(obj, ET.Element):
768832
obj = list(obj)
833+
try:
834+
if (
835+
isinstance(obj, str)
836+
and isinstance(deserializer, functools.partial)
837+
and isinstance(deserializer.args[0], functools.partial)
838+
and deserializer.args[0].func == _deserialize_array_encoded # pylint: disable=comparison-with-callable
839+
):
840+
# encoded string may be deserialized to sequence
841+
return deserializer(obj)
842+
except: # pylint: disable=bare-except
843+
pass
769844
return type(obj)(_deserialize(deserializer, entry, module) for entry in obj)
770845

771846

@@ -971,6 +1046,7 @@ def _failsafe_deserialize_xml(
9711046
return None
9721047

9731048

1049+
# pylint: disable=too-many-instance-attributes
9741050
class _RestField:
9751051
def __init__(
9761052
self,
@@ -983,6 +1059,7 @@ def __init__(
9831059
format: typing.Optional[str] = None,
9841060
is_multipart_file_input: bool = False,
9851061
xml: typing.Optional[dict[str, typing.Any]] = None,
1062+
original_tsp_name: typing.Optional[str] = None,
9861063
):
9871064
self._type = type
9881065
self._rest_name_input = name
@@ -994,10 +1071,15 @@ def __init__(
9941071
self._format = format
9951072
self._is_multipart_file_input = is_multipart_file_input
9961073
self._xml = xml if xml is not None else {}
1074+
self._original_tsp_name = original_tsp_name
9971075

9981076
@property
9991077
def _class_type(self) -> typing.Any:
1000-
return getattr(self._type, "args", [None])[0]
1078+
result = getattr(self._type, "args", [None])[0]
1079+
# type may be wrapped by nested functools.partial so we need to check for that
1080+
if isinstance(result, functools.partial):
1081+
return getattr(result, "args", [None])[0]
1082+
return result
10011083

10021084
@property
10031085
def _rest_name(self) -> str:
@@ -1008,14 +1090,37 @@ def _rest_name(self) -> str:
10081090
def __get__(self, obj: Model, type=None): # pylint: disable=redefined-builtin
10091091
# by this point, type and rest_name will have a value bc we default
10101092
# them in __new__ of the Model class
1011-
item = obj.get(self._rest_name)
1093+
# Use _data.get() directly to avoid triggering __getitem__ which clears the cache
1094+
item = obj._data.get(self._rest_name)
10121095
if item is None:
10131096
return item
10141097
if self._is_model:
10151098
return item
1016-
return _deserialize(self._type, _serialize(item, self._format), rf=self)
1099+
1100+
# For mutable types, we want mutations to directly affect _data
1101+
# Check if we've already deserialized this value
1102+
cache_attr = f"_deserialized_{self._rest_name}"
1103+
if hasattr(obj, cache_attr):
1104+
# Return the value from _data directly (it's been deserialized in place)
1105+
return obj._data.get(self._rest_name)
1106+
1107+
deserialized = _deserialize(self._type, _serialize(item, self._format), rf=self)
1108+
1109+
# For mutable types, store the deserialized value back in _data
1110+
# so mutations directly affect _data
1111+
if isinstance(deserialized, (dict, list, set)):
1112+
obj._data[self._rest_name] = deserialized
1113+
object.__setattr__(obj, cache_attr, True) # Mark as deserialized
1114+
return deserialized
1115+
1116+
return deserialized
10171117

10181118
def __set__(self, obj: Model, value) -> None:
1119+
# Clear the cached deserialized object when setting a new value
1120+
cache_attr = f"_deserialized_{self._rest_name}"
1121+
if hasattr(obj, cache_attr):
1122+
object.__delattr__(obj, cache_attr)
1123+
10191124
if value is None:
10201125
# we want to wipe out entries if users set attr to None
10211126
try:
@@ -1045,6 +1150,7 @@ def rest_field(
10451150
format: typing.Optional[str] = None,
10461151
is_multipart_file_input: bool = False,
10471152
xml: typing.Optional[dict[str, typing.Any]] = None,
1153+
original_tsp_name: typing.Optional[str] = None,
10481154
) -> typing.Any:
10491155
return _RestField(
10501156
name=name,
@@ -1054,6 +1160,7 @@ def rest_field(
10541160
format=format,
10551161
is_multipart_file_input=is_multipart_file_input,
10561162
xml=xml,
1163+
original_tsp_name=original_tsp_name,
10571164
)
10581165

10591166

sdk/ai/azure-ai-agents/azure/ai/agents/_utils/serialization.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -821,13 +821,20 @@ def serialize_basic(cls, data, data_type, **kwargs):
821821
:param str data_type: Type of object in the iterable.
822822
:rtype: str, int, float, bool
823823
:return: serialized object
824+
:raises TypeError: raise if data_type is not one of str, int, float, bool.
824825
"""
825826
custom_serializer = cls._get_custom_serializers(data_type, **kwargs)
826827
if custom_serializer:
827828
return custom_serializer(data)
828829
if data_type == "str":
829830
return cls.serialize_unicode(data)
830-
return eval(data_type)(data) # nosec # pylint: disable=eval-used
831+
if data_type == "int":
832+
return int(data)
833+
if data_type == "float":
834+
return float(data)
835+
if data_type == "bool":
836+
return bool(data)
837+
raise TypeError("Unknown basic data type: {}".format(data_type))
831838

832839
@classmethod
833840
def serialize_unicode(cls, data):
@@ -1757,7 +1764,7 @@ def deserialize_basic(self, attr, data_type): # pylint: disable=too-many-return
17571764
:param str data_type: deserialization data type.
17581765
:return: Deserialized basic type.
17591766
:rtype: str, int, float or bool
1760-
:raises TypeError: if string format is not valid.
1767+
:raises TypeError: if string format is not valid or data_type is not one of str, int, float, bool.
17611768
"""
17621769
# If we're here, data is supposed to be a basic type.
17631770
# If it's still an XML node, take the text
@@ -1783,7 +1790,11 @@ def deserialize_basic(self, attr, data_type): # pylint: disable=too-many-return
17831790

17841791
if data_type == "str":
17851792
return self.deserialize_unicode(attr)
1786-
return eval(data_type)(attr) # nosec # pylint: disable=eval-used
1793+
if data_type == "int":
1794+
return int(attr)
1795+
if data_type == "float":
1796+
return float(attr)
1797+
raise TypeError("Unknown basic data type: {}".format(data_type))
17871798

17881799
@staticmethod
17891800
def deserialize_unicode(data):

sdk/ai/azure-ai-agents/azure/ai/agents/_utils/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,20 @@ def serialize_multipart_data_entry(data_entry: Any) -> Any:
4949

5050
def prepare_multipart_form_data(
5151
body: Mapping[str, Any], multipart_fields: list[str], data_fields: list[str]
52-
) -> tuple[list[FileType], dict[str, Any]]:
52+
) -> list[FileType]:
5353
files: list[FileType] = []
54-
data: dict[str, Any] = {}
5554
for multipart_field in multipart_fields:
5655
multipart_entry = body.get(multipart_field)
5756
if isinstance(multipart_entry, list):
5857
files.extend([(multipart_field, e) for e in multipart_entry])
5958
elif multipart_entry:
6059
files.append((multipart_field, multipart_entry))
6160

61+
# if files is empty, sdk core library can't handle multipart/form-data correctly, so
62+
# we put data fields into files with filename as None to avoid that scenario.
6263
for data_field in data_fields:
6364
data_entry = body.get(data_field)
6465
if data_entry:
65-
data[data_field] = serialize_multipart_data_entry(data_entry)
66+
files.append((data_field, str(serialize_multipart_data_entry(data_entry))))
6667

67-
return files, data
68+
return files

0 commit comments

Comments
 (0)