Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python: [ "3.10", "3.11", "3.12", "3.13" ]
python: [ "3.10", "3.11", "3.12", "3.13", "3.14" ]
env:
UV_PYTHON: ${{ matrix.python }}
steps:
Expand Down
69 changes: 65 additions & 4 deletions python/restate/aws_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def request_to_receive(req: RestateLambdaRequest) -> Receive:
assert req["isBase64Encoded"]
body = base64.b64decode(req["body"])

# Decompress zstd-encoded request body
headers = {k.lower(): v for k, v in req.get("headers", {}).items()}
if "zstd" in headers.get("content-encoding", ""):
body = zstd_decompress(body)

events = cast(
list[HTTPRequestEvent],
[
Expand All @@ -80,15 +85,19 @@ async def recv() -> HTTPRequestEvent:
return recv


RESPONSE_COMPRESSION_THRESHOLD = 3 * 1024 * 1024


class ResponseCollector:
"""
Response collector from ASGI Send to Lambda
"""

def __init__(self):
def __init__(self, accept_encoding: str = ""):
self.body = bytearray()
self.headers = {}
self.headers: dict[str, str] = {}
self.status_code = 500
self.accept_encoding = accept_encoding

async def __call__(self, message: Union[HTTPResponseStartEvent, HTTPResponseBodyEvent]) -> None:
"""
Expand All @@ -105,11 +114,18 @@ def to_lambda_response(self) -> RestateLambdaResponse:
"""
Convert collected values to lambda response
"""
body: bytes | bytearray = self.body

# Compress response if it exceeds threshold and client accepts zstd
if len(body) > RESPONSE_COMPRESSION_THRESHOLD and "zstd" in self.accept_encoding and ZSTD_AVAILABLE:
body = zstd_compress(body)
self.headers["content-encoding"] = "zstd"

return {
"statusCode": self.status_code,
"headers": self.headers,
"isBase64Encoded": True,
"body": base64.b64encode(self.body).decode(),
"body": base64.b64encode(body).decode(),
}


Expand All @@ -134,7 +150,8 @@ def lambda_handler(event: RestateLambdaRequest, _context: Any) -> RestateLambdaR

scope = create_scope(event)
recv = request_to_receive(event)
send = ResponseCollector()
req_headers = {k.lower(): v for k, v in event.get("headers", {}).items()}
send = ResponseCollector(accept_encoding=req_headers.get("accept-encoding", ""))

asgi_instance = asgi_app(scope, recv, send)
asgi_task = loop.create_task(asgi_instance) # type: ignore[var-annotated, arg-type]
Expand All @@ -143,3 +160,47 @@ def lambda_handler(event: RestateLambdaRequest, _context: Any) -> RestateLambdaR
return send.to_lambda_response()

return lambda_handler


def _check_zstd_available() -> bool:
"""Return True if zstd compression is available (Python 3.14+)."""
try:
import compression.zstd # type: ignore[import-not-found]

return compression.zstd is not None
except ImportError:
return False


ZSTD_AVAILABLE = _check_zstd_available()


def is_lambda_compression_supported():
"""Return 'zstd' if running on Lambda and compression.zstd is available (Python 3.14+), else None."""
if is_running_on_lambda() and ZSTD_AVAILABLE:
return "zstd"
return None


def zstd_compress(data: bytes | bytearray) -> bytes:
"""Compress data using zstd."""
try:
import compression.zstd # type: ignore[import-not-found]
except ImportError as e:
raise RuntimeError(
"zstd compression requested but compression.zstd is not available. "
"Python 3.14+ is required for zstd compression support."
) from e
return compression.zstd.compress(data)


def zstd_decompress(data: bytes) -> bytes:
"""Decompress zstd-compressed data."""
try:
import compression.zstd # type: ignore[import-not-found]
except ImportError as e:
raise RuntimeError(
"Received zstd-compressed request but compression.zstd is not available. "
"Python 3.14+ is required for zstd compression support."
) from e
return compression.zstd.decompress(data)
5 changes: 5 additions & 0 deletions python/restate/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from restate.handler import TypeHint
from restate.object import VirtualObject
from restate.workflow import Workflow
from restate.aws_lambda import is_lambda_compression_supported


class ProtocolMode(Enum):
Expand Down Expand Up @@ -159,6 +160,7 @@ def __init__(
self.minProtocolVersion = minProtocolVersion
self.maxProtocolVersion = maxProtocolVersion
self.services = services
self.lambdaCompression = is_lambda_compression_supported()


PROTOCOL_MODES = {"bidi": ProtocolMode.BIDI_STREAM, "request_response": ProtocolMode.REQUEST_RESPONSE}
Expand Down Expand Up @@ -235,6 +237,9 @@ def compute_discovery_json(

# Validate that new discovery fields aren't used with older protocol versions
if version <= 3:
# Strip lambdaCompression for older discovery versions
ep.lambdaCompression = None

for service in ep.services:
if service.retryPolicyInitialInterval is not None:
raise ValueError("retryPolicyInitialInterval is only supported in discovery protocol version 4")
Expand Down