From 049d87a75647d4c1cce8845de2b725b738ce8529 Mon Sep 17 00:00:00 2001 From: Lucy Qiu Date: Sun, 21 Dec 2025 20:03:16 -0800 Subject: [PATCH] FlatTensor alignment tests (#16317) Summary: 1. Pass segment_alignment from ExecutorchBackendConfig to flat tensor serializer Other changes: 2. Set segment_alignment=128 as default to match ExecutorchBackendConfig 3. Remove tensor_alignment from the config (we do not have multiple tensors per segment anymore) 4. Set flatbuffer alignment=16. I don't think we need any alignment at the moment here, but can't hurt. The PTE file aligns flatbuffer to the largest force_align value in the schema, which the PTD file doesn't have. 5. Additional tests for varying segment alignment. Reviewed By: larryliu0820 Differential Revision: D89422691 --- exir/program/_program.py | 10 ++++-- extension/flat_tensor/serialize/serialize.py | 22 +++++-------- extension/flat_tensor/test/test_serialize.py | 34 +++++++++++++------- 3 files changed, 39 insertions(+), 27 deletions(-) diff --git a/exir/program/_program.py b/exir/program/_program.py index 6ab797903a1..c5100394aa3 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -592,10 +592,13 @@ def __init__( self._constant_tensor_alignment: Optional[int] = constant_tensor_alignment self._delegate_alignment: Optional[int] = delegate_alignment from executorch.extension.flat_tensor.serialize.serialize import ( + FlatTensorConfig, FlatTensorSerializer, ) - self._data_serializer: DataSerializer = FlatTensorSerializer() + self._data_serializer: DataSerializer = FlatTensorSerializer( + FlatTensorConfig(self._segment_alignment) + ) def _get_emitter_output(self) -> EmitterOutput: if self._emitter_output is None: @@ -1851,10 +1854,13 @@ def __init__( # Serialize emitter output, ready to be written to a file. from executorch.extension.flat_tensor.serialize.serialize import ( + FlatTensorConfig, FlatTensorSerializer, ) - self._data_serializer = FlatTensorSerializer() + self._data_serializer = FlatTensorSerializer( + FlatTensorConfig(segment_alignment=backend_config.segment_alignment) + ) self._pte_data, self._tensor_data = serialize_for_executorch( self._emitter_output, backend_config, diff --git a/extension/flat_tensor/serialize/serialize.py b/extension/flat_tensor/serialize/serialize.py index 644a923f88d..94303958caa 100644 --- a/extension/flat_tensor/serialize/serialize.py +++ b/extension/flat_tensor/serialize/serialize.py @@ -39,6 +39,9 @@ # endian. _HEADER_BYTEORDER: Literal["little"] = "little" +# Alignment of the flatbuffer (after the header). +_FLATBUFFER_ALIGNMENT: int = 16 + # Current version. Keep in sync with c++ version number in serialize. _FLAT_TENSOR_VERSION: int = 0 @@ -95,8 +98,7 @@ def _deserialize_to_flat_tensor(flatbuffer: bytes) -> FlatTensor: @dataclass class FlatTensorConfig: - tensor_alignment: int = 16 - segment_alignment: int = 16 + segment_alignment: int = 128 @dataclass @@ -334,18 +336,13 @@ def serialize( ) flatbuffer_payload = _serialize_to_flatbuffer(flat_tensor) - padded_flatbuffer_length: int = aligned_size( - input_size=len(flatbuffer_payload), - alignment=self.config.tensor_alignment, - ) - padded_header_length: int = aligned_size( input_size=FlatTensorHeader.EXPECTED_LENGTH, - alignment=self.config.tensor_alignment, + alignment=_FLATBUFFER_ALIGNMENT, ) segment_base_offset = aligned_size( - padded_flatbuffer_length + padded_header_length, + len(flatbuffer_payload) + padded_header_length, self.config.segment_alignment, ) @@ -360,19 +357,16 @@ def serialize( # Pad header and payload to segment alignment. header_data = pad_to(header_data, padded_header_length) - original_flatbuffer_payload_size = len(flatbuffer_payload) - flatbuffer_payload.append( - b"\x00" * (padded_flatbuffer_length - len(flatbuffer_payload)) - ) injected_flatbuffer_data: bytes = _insert_flatbuffer_header( flatbuffer_data=flatbuffer_payload.__bytes__(), magic_regex=r"FT[0-9a-zA-Z][0-9a-zA-Z]", header_data=header_data, ) + injected_flatbuffer_data = pad_to(injected_flatbuffer_data, segment_base_offset) eh = _get_extended_header(injected_flatbuffer_data) assert eh is not None - assert eh.flatbuffer_size == original_flatbuffer_payload_size + assert eh.flatbuffer_size == len(flatbuffer_payload) assert eh.segment_base_offset == segment_base_offset assert eh.flatbuffer_offset == padded_header_length assert eh.segment_data_size == len(aggregated_segment_data) diff --git a/extension/flat_tensor/test/test_serialize.py b/extension/flat_tensor/test/test_serialize.py index 2446bd9213c..6ecd6911ac8 100644 --- a/extension/flat_tensor/test/test_serialize.py +++ b/extension/flat_tensor/test/test_serialize.py @@ -29,6 +29,7 @@ from executorch.extension.flat_tensor.serialize.serialize import ( _deserialize_to_flat_tensor, + _FLATBUFFER_ALIGNMENT, FlatTensorConfig, FlatTensorHeader, FlatTensorSerializer, @@ -109,8 +110,7 @@ def _check_named_data_entries( f"Named data record {key}.{field.name} does not match.", ) - def test_serialize(self) -> None: - config = FlatTensorConfig() + def _serialize_with_alignment(self, config: FlatTensorConfig) -> None: serializer: DataSerializer = FlatTensorSerializer(config) serialized_data = bytes(serializer.serialize(TEST_DATA_PAYLOAD)) @@ -120,15 +120,15 @@ def test_serialize(self) -> None: ) self.assertTrue(header.is_valid()) - # Header is aligned to config.segment_alignment, which is where the flatbuffer starts. + # Flatbuffer is non-empty. + self.assertTrue(header.flatbuffer_size > 0) + + # Align the flatbuffer to _FLATBUFFER_ALIGNMENT. self.assertEqual( header.flatbuffer_offset, - aligned_size(FlatTensorHeader.EXPECTED_LENGTH, config.segment_alignment), + aligned_size(FlatTensorHeader.EXPECTED_LENGTH, _FLATBUFFER_ALIGNMENT), ) - # Flatbuffer is non-empty. - self.assertTrue(header.flatbuffer_size > 0) - # Segment base offset is aligned to config.segment_alignment. expected_segment_base_offset = aligned_size( header.flatbuffer_offset + header.flatbuffer_size, config.segment_alignment @@ -180,12 +180,12 @@ def test_serialize(self) -> None: segments = flat_tensor.segments self.assertEqual(len(segments), 3) - # Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.tensor_alignment. + # Segment 0 contains fqn1, fqn2; 4 bytes, aligned to config.segment_alignment. self.assertEqual(segments[0].offset, 0) self.assertEqual(segments[0].size, len(TEST_BUFFER[0])) - # Segment 1 contains fqn3; 32 bytes, aligned to config.tensor_alignment. - self.assertEqual(segments[1].offset, config.tensor_alignment) + # Segment 1 contains fqn3; 32 bytes, aligned to config.segment_alignment. + self.assertEqual(segments[1].offset, config.segment_alignment) self.assertEqual(segments[1].size, len(TEST_BUFFER[1])) # Segment 2 contains key0; 17 bytes, aligned to 64. @@ -194,7 +194,7 @@ def test_serialize(self) -> None: ) self.assertEqual( segments[2].offset, - aligned_size(config.tensor_alignment * 3, custom_alignment), + aligned_size(config.segment_alignment * 2, custom_alignment), ) self.assertEqual(segments[2].size, len(TEST_BUFFER[2])) @@ -245,6 +245,18 @@ def test_serialize(self) -> None: self.assertEqual(segments[2].offset + segments[2].size, len(segment_data)) + def test_serialize_default_alignment(self) -> None: + config = FlatTensorConfig() + self._serialize_with_alignment(config) + + def test_serialize_align_4096(self) -> None: + config = FlatTensorConfig(segment_alignment=4096) + self._serialize_with_alignment(config) + + def test_serialize_align_1024(self) -> None: + config = FlatTensorConfig(segment_alignment=1024) + self._serialize_with_alignment(config) + def test_round_trip(self) -> None: # Serialize and then deserialize the test payload. Make sure it's reconstructed # properly.