diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 89e7b83..710801a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python: [ "3.10", "3.11", "3.12", "3.13" ] + python: [ "3.10", "3.11", "3.12", "3.13", "3.14" ] env: UV_PYTHON: ${{ matrix.python }} steps: diff --git a/python/restate/aws_lambda.py b/python/restate/aws_lambda.py index aa335d9..3eb8244 100644 --- a/python/restate/aws_lambda.py +++ b/python/restate/aws_lambda.py @@ -62,6 +62,11 @@ def request_to_receive(req: RestateLambdaRequest) -> Receive: assert req["isBase64Encoded"] body = base64.b64decode(req["body"]) + # Decompress zstd-encoded request body + headers = {k.lower(): v for k, v in req.get("headers", {}).items()} + if "zstd" in headers.get("content-encoding", ""): + body = zstd_decompress(body) + events = cast( list[HTTPRequestEvent], [ @@ -80,15 +85,19 @@ async def recv() -> HTTPRequestEvent: return recv +RESPONSE_COMPRESSION_THRESHOLD = 3 * 1024 * 1024 + + class ResponseCollector: """ Response collector from ASGI Send to Lambda """ - def __init__(self): + def __init__(self, accept_encoding: str = ""): self.body = bytearray() - self.headers = {} + self.headers: dict[str, str] = {} self.status_code = 500 + self.accept_encoding = accept_encoding async def __call__(self, message: Union[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> None: """ @@ -105,11 +114,18 @@ def to_lambda_response(self) -> RestateLambdaResponse: """ Convert collected values to lambda response """ + body: bytes | bytearray = self.body + + # Compress response if it exceeds threshold and client accepts zstd + if len(body) > RESPONSE_COMPRESSION_THRESHOLD and "zstd" in self.accept_encoding and ZSTD_AVAILABLE: + body = zstd_compress(body) + self.headers["content-encoding"] = "zstd" + return { "statusCode": self.status_code, "headers": self.headers, "isBase64Encoded": True, - "body": base64.b64encode(self.body).decode(), + "body": base64.b64encode(body).decode(), } @@ -134,7 +150,8 @@ def lambda_handler(event: RestateLambdaRequest, _context: Any) -> RestateLambdaR scope = create_scope(event) recv = request_to_receive(event) - send = ResponseCollector() + req_headers = {k.lower(): v for k, v in event.get("headers", {}).items()} + send = ResponseCollector(accept_encoding=req_headers.get("accept-encoding", "")) asgi_instance = asgi_app(scope, recv, send) asgi_task = loop.create_task(asgi_instance) # type: ignore[var-annotated, arg-type] @@ -143,3 +160,47 @@ def lambda_handler(event: RestateLambdaRequest, _context: Any) -> RestateLambdaR return send.to_lambda_response() return lambda_handler + + +def _check_zstd_available() -> bool: + """Return True if zstd compression is available (Python 3.14+).""" + try: + import compression.zstd # type: ignore[import-not-found] + + return compression.zstd is not None + except ImportError: + return False + + +ZSTD_AVAILABLE = _check_zstd_available() + + +def is_lambda_compression_supported(): + """Return 'zstd' if running on Lambda and compression.zstd is available (Python 3.14+), else None.""" + if is_running_on_lambda() and ZSTD_AVAILABLE: + return "zstd" + return None + + +def zstd_compress(data: bytes | bytearray) -> bytes: + """Compress data using zstd.""" + try: + import compression.zstd # type: ignore[import-not-found] + except ImportError as e: + raise RuntimeError( + "zstd compression requested but compression.zstd is not available. " + "Python 3.14+ is required for zstd compression support." + ) from e + return compression.zstd.compress(data) + + +def zstd_decompress(data: bytes) -> bytes: + """Decompress zstd-compressed data.""" + try: + import compression.zstd # type: ignore[import-not-found] + except ImportError as e: + raise RuntimeError( + "Received zstd-compressed request but compression.zstd is not available. " + "Python 3.14+ is required for zstd compression support." + ) from e + return compression.zstd.decompress(data) diff --git a/python/restate/discovery.py b/python/restate/discovery.py index 343a09d..e8365a9 100644 --- a/python/restate/discovery.py +++ b/python/restate/discovery.py @@ -34,6 +34,7 @@ from restate.handler import TypeHint from restate.object import VirtualObject from restate.workflow import Workflow +from restate.aws_lambda import is_lambda_compression_supported class ProtocolMode(Enum): @@ -159,6 +160,7 @@ def __init__( self.minProtocolVersion = minProtocolVersion self.maxProtocolVersion = maxProtocolVersion self.services = services + self.lambdaCompression = is_lambda_compression_supported() PROTOCOL_MODES = {"bidi": ProtocolMode.BIDI_STREAM, "request_response": ProtocolMode.REQUEST_RESPONSE} @@ -235,6 +237,9 @@ def compute_discovery_json( # Validate that new discovery fields aren't used with older protocol versions if version <= 3: + # Strip lambdaCompression for older discovery versions + ep.lambdaCompression = None + for service in ep.services: if service.retryPolicyInitialInterval is not None: raise ValueError("retryPolicyInitialInterval is only supported in discovery protocol version 4")