diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index e370d35ad..04d136960 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -269,6 +269,9 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if type_name == "dict_items": # Convert to dict for order-insensitive comparison (handles unhashable values) return comparator(dict(orig), dict(new), superset_obj) + if type_name == "bytearray_iterator": + # bytearray_iterator needs element-wise comparison (order matters) + return comparator(list(orig), list(new)) if HAS_NUMPY: import numpy as np # type: ignore # noqa: PGH003 diff --git a/tests/test_comparator.py b/tests/test_comparator.py index a62d61d80..1c2d87695 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -2316,6 +2316,55 @@ def test_dict_views() -> None: assert not comparator(d.items(), [("a", 1), ("b", 2)]) +def test_bytearray_iterator() -> None: + """Test comparator support for bytearray_iterator.""" + # Basic equality + ba1 = bytearray([1, 2, 3]) + ba2 = bytearray([1, 2, 3]) + assert comparator(iter(ba1), iter(ba2)) + + # Different values + ba3 = bytearray([1, 2, 4]) + assert not comparator(iter(ba1), iter(ba3)) + + # Different lengths + ba4 = bytearray([1, 2]) + assert not comparator(iter(ba1), iter(ba4)) + + ba5 = bytearray([1, 2, 3, 4]) + assert not comparator(iter(ba1), iter(ba5)) + + # Empty bytearrays + empty1 = bytearray() + empty2 = bytearray() + assert comparator(iter(empty1), iter(empty2)) + + # Empty vs non-empty + assert not comparator(iter(empty1), iter(ba1)) + + # Single element + single1 = bytearray([42]) + single2 = bytearray([42]) + single3 = bytearray([43]) + assert comparator(iter(single1), iter(single2)) + assert not comparator(iter(single1), iter(single3)) + + # Large bytearrays + large1 = bytearray(range(256)) + large2 = bytearray(range(256)) + large3 = bytearray(range(255, -1, -1)) # reversed + assert comparator(iter(large1), iter(large2)) + assert not comparator(iter(large1), iter(large3)) + + # Verify type name is correct + assert type(iter(bytearray())).__name__ == "bytearray_iterator" + + # bytearray_iterator is not equal to list (different types) + ba = bytearray([1, 2, 3]) + assert not comparator(iter(ba), [1, 2, 3]) + assert not comparator([1, 2, 3], iter(ba)) + + def test_tensorflow_tensor() -> None: """Test comparator support for TensorFlow Tensor objects.""" try: