From aa376ace7bf27c83758a1fb970663020ff7eb385 Mon Sep 17 00:00:00 2001 From: Vichym Date: Wed, 11 Feb 2026 16:12:02 -0800 Subject: [PATCH] chore: incremental emulator update --- .../aws-lambda-rie/internal/app.go | 6 +- .../aws-lambda-rie/internal/invoke/consts.go | 12 +++ .../internal/invoke/rie_invoke_request.go | 50 +++++++++-- .../invoke/rie_invoke_request_test.go | 84 +++++++++++++++++-- .../aws-lambda-rie/test/rie_test.go | 2 - .../interop/error_utils_test.go | 8 +- .../invoke/invoke_router.go | 2 +- .../invoke/invoke_router_test.go | 2 +- .../invoke/metrics.go | 3 +- .../invoke/metrics_test.go | 28 +++++++ .../rapid/model/client_error.go | 2 +- .../lambda-managed-instances/rapid/sandbox.go | 8 +- .../rapid/shutdown_metrics.go | 2 +- .../rapid/shutdown_metrics_test.go | 6 +- .../rapidcore/env/environment.go | 2 +- .../rapidcore/env/environment_test.go | 1 + .../lambda-managed-instances/raptor/app.go | 9 +- .../raptor/app_test.go | 23 ++--- .../lambda-managed-instances/raptor/server.go | 8 +- .../testutils/functional/extension_actions.go | 2 - .../testutils/functional/extensions_client.go | 2 - .../testutils/functional/fluxpump_server.go | 2 - .../testutils/functional/httputils.go | 2 - .../functional/in_memory_events_api.go | 2 - .../functional/process_supervisor.go | 2 - .../testutils/functional/runtime_actions.go | 2 - .../testutils/functional/runtime_client.go | 2 - .../testutils/functional/supv.go | 2 - .../testutils/test_data.go | 3 +- internal/lambda/rapidcore/server.go | 2 +- internal/lambda/rie/handlers.go | 38 +++++++-- 31 files changed, 242 insertions(+), 77 deletions(-) create mode 100644 internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/consts.go diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/app.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/app.go index a25091e..18dbe8e 100644 --- a/internal/lambda-managed-instances/aws-lambda-rie/internal/app.go +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/app.go @@ -60,7 +60,11 @@ func (h *HTTPHandler) invoke(w http.ResponseWriter, r *http.Request) { return } - invokeReq := rieinvoke.NewRieInvokeRequest(r, w) + invokeReq, err := rieinvoke.NewRieInvokeRequest(r, w) + if err != nil { + h.respondWithError(w, err) + return + } ctx := logging.WithInvokeID(r.Context(), invokeReq.InvokeID()) metrics := invoke.NewInvokeMetrics(nil, &noOpCounter{}) diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/consts.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/consts.go new file mode 100644 index 0000000..ddab8e8 --- /dev/null +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/consts.go @@ -0,0 +1,12 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package invoke + +const ( + RequestIdHeader = "X-Amzn-RequestId" + + ClientContextHeader = "X-Amz-Client-Context" + + CognitoIdentityHeader = "X-Amz-Cognito-Identity" +) diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request.go index 26a1cae..608d8e4 100644 --- a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request.go +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request.go @@ -4,8 +4,12 @@ package invoke import ( + "encoding/base64" + "encoding/json" "errors" + "fmt" "io" + "log/slog" "net/http" "time" @@ -16,6 +20,11 @@ import ( "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" ) +type cognitoIdentity struct { + CognitoIdentityID string `json:"cognitoIdentityId"` + CognitoIdentityPoolID string `json:"cognitoIdentityPoolId"` +} + type rieInvokeRequest struct { request *http.Request writer http.ResponseWriter @@ -35,18 +44,47 @@ type rieInvokeRequest struct { functionVersionID string } -func NewRieInvokeRequest(request *http.Request, writer http.ResponseWriter) *rieInvokeRequest { +func NewRieInvokeRequest(request *http.Request, writer http.ResponseWriter) (*rieInvokeRequest, model.AppError) { contentType := request.Header.Get(invoke.ДontentTypeHeader) if contentType == "" { contentType = "application/json" } - invokeID := request.Header.Get("X-Amzn-RequestId") + invokeID := request.Header.Get(RequestIdHeader) if invokeID == "" { invokeID = uuid.New().String() } + clientContext := "" + if encodedClientContext := request.Header.Get(ClientContextHeader); encodedClientContext != "" { + decodedClientContext, err := base64.StdEncoding.DecodeString(encodedClientContext) + if err != nil { + slog.Warn("Failed to decode X-Amz-Client-Context header", "err", err) + return nil, model.NewClientError( + fmt.Errorf("X-Amz-Client-Context must be a valid base64 encoded string: %w", err), + model.ErrorSeverityInvalid, + model.ErrorMalformedRequest, + ) + } + clientContext = string(decodedClientContext) + } + + var cognitoIdentityId, cognitoIdentityPoolId string + if cognitoIdentityHeader := request.Header.Get(CognitoIdentityHeader); cognitoIdentityHeader != "" { + var cognito cognitoIdentity + if err := json.Unmarshal([]byte(cognitoIdentityHeader), &cognito); err != nil { + slog.Warn("Failed to parse X-Amz-Cognito-Identity header", "err", err) + return nil, model.NewClientError( + fmt.Errorf("X-Amz-Cognito-Identity must be a valid JSON string: %w", err), + model.ErrorSeverityInvalid, + model.ErrorMalformedRequest, + ) + } + cognitoIdentityId = cognito.CognitoIdentityID + cognitoIdentityPoolId = cognito.CognitoIdentityPoolID + } + req := &rieInvokeRequest{ request: request, writer: writer, @@ -56,13 +94,13 @@ func NewRieInvokeRequest(request *http.Request, writer http.ResponseWriter) *rie responseBandwidthRate: 2 * 1024 * 1024, responseBandwidthBurstSize: 6 * 1024 * 1024, traceId: request.Header.Get(invoke.TraceIdHeader), - cognitoIdentityId: "", - cognitoIdentityPoolId: "", - clientContext: request.Header.Get("X-Amz-Client-Context"), + cognitoIdentityId: cognitoIdentityId, + cognitoIdentityPoolId: cognitoIdentityPoolId, + clientContext: clientContext, responseMode: request.Header.Get(invoke.ResponseModeHeader), } - return req + return req, nil } func (r *rieInvokeRequest) ContentType() string { diff --git a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request_test.go b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request_test.go index f817b38..7ead546 100644 --- a/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request_test.go +++ b/internal/lambda-managed-instances/aws-lambda-rie/internal/invoke/rie_invoke_request_test.go @@ -10,14 +10,18 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/aws/aws-lambda-runtime-interface-emulator/internal/lambda-managed-instances/rapid/model" ) func TestNewRieInvokeRequest(t *testing.T) { tests := []struct { - name string - request func() *http.Request - writer http.ResponseWriter - want *rieInvokeRequest + name string + request func() *http.Request + writer http.ResponseWriter + want *rieInvokeRequest + wantError bool + wantErrorContain string }{ { name: "no_headers_in_request", @@ -37,6 +41,7 @@ func TestNewRieInvokeRequest(t *testing.T) { cognitoIdentityPoolId: "", clientContext: "", }, + wantError: false, }, { name: "all_headers_present_in_request", @@ -46,6 +51,7 @@ func TestNewRieInvokeRequest(t *testing.T) { r.Header.Set("X-Amzn-Trace-Id", "Root=1-5e1b4151-5ac6c58f3375aa3c7c6b73c9") r.Header.Set("X-Amz-Client-Context", "eyJjdXN0b20iOnsidGVzdCI6InZhbHVlIn19") r.Header.Set("X-Amzn-RequestId", "test-invoke-id") + r.Header.Set("X-Amz-Cognito-Identity", `{"cognitoIdentityId":"us-east-1:12345678-1234-1234-1234-123456789012","cognitoIdentityPoolId":"us-east-1:87654321-4321-4321-4321-210987654321"}`) require.NoError(t, err) return r }, @@ -57,16 +63,80 @@ func TestNewRieInvokeRequest(t *testing.T) { responseBandwidthRate: 2 * 1024 * 1024, responseBandwidthBurstSize: 6 * 1024 * 1024, traceId: "Root=1-5e1b4151-5ac6c58f3375aa3c7c6b73c9", - cognitoIdentityId: "", + cognitoIdentityId: "us-east-1:12345678-1234-1234-1234-123456789012", + cognitoIdentityPoolId: "us-east-1:87654321-4321-4321-4321-210987654321", + clientContext: `{"custom":{"test":"value"}}`, + }, + wantError: false, + }, + { + name: "malformed_cognito_identity_header", + request: func() *http.Request { + r, err := http.NewRequest("GET", "http://localhost/", nil) + r.Header.Set("X-Amzn-RequestId", "test-invoke-id") + r.Header.Set("X-Amz-Cognito-Identity", "not-valid-json{") + require.NoError(t, err) + return r + }, + writer: httptest.NewRecorder(), + want: nil, + wantError: true, + wantErrorContain: "X-Amz-Cognito-Identity must be a valid JSON string", + }, + { + name: "malformed_client_context_header", + request: func() *http.Request { + r, err := http.NewRequest("GET", "http://localhost/", nil) + r.Header.Set("X-Amzn-RequestId", "test-invoke-id") + r.Header.Set("X-Amz-Client-Context", "not-valid-base64!!!") + require.NoError(t, err) + return r + }, + writer: httptest.NewRecorder(), + want: nil, + wantError: true, + wantErrorContain: "X-Amz-Client-Context must be a valid base64 encoded string", + }, + { + name: "partial_cognito_identity_header", + request: func() *http.Request { + r, err := http.NewRequest("GET", "http://localhost/", nil) + r.Header.Set("X-Amzn-RequestId", "test-invoke-id") + r.Header.Set("X-Amz-Cognito-Identity", `{"cognitoIdentityId":"us-east-1:only-id"}`) + require.NoError(t, err) + return r + }, + writer: httptest.NewRecorder(), + want: &rieInvokeRequest{ + invokeID: "test-invoke-id", + contentType: "application/json", + maxPayloadSize: 6*1024*1024 + 100, + responseBandwidthRate: 2 * 1024 * 1024, + responseBandwidthBurstSize: 6 * 1024 * 1024, + traceId: "", + cognitoIdentityId: "us-east-1:only-id", cognitoIdentityPoolId: "", - clientContext: "eyJjdXN0b20iOnsidGVzdCI6InZhbHVlIn19", + clientContext: "", }, + wantError: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := tt.request() - got := NewRieInvokeRequest(r, tt.writer) + got, err := NewRieInvokeRequest(r, tt.writer) + + if tt.wantError { + assert.NotNil(t, err) + assert.Nil(t, got) + assert.Equal(t, model.ErrorMalformedRequest, err.ErrorType()) + assert.Equal(t, http.StatusBadRequest, err.ReturnCode()) + assert.Contains(t, err.Error(), tt.wantErrorContain) + return + } + + assert.Nil(t, err) + require.NotNil(t, got) tt.want.request = r tt.want.writer = tt.writer diff --git a/internal/lambda-managed-instances/aws-lambda-rie/test/rie_test.go b/internal/lambda-managed-instances/aws-lambda-rie/test/rie_test.go index 4992ec9..2df9523 100644 --- a/internal/lambda-managed-instances/aws-lambda-rie/test/rie_test.go +++ b/internal/lambda-managed-instances/aws-lambda-rie/test/rie_test.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package test import ( diff --git a/internal/lambda-managed-instances/interop/error_utils_test.go b/internal/lambda-managed-instances/interop/error_utils_test.go index e6153a0..8163f19 100644 --- a/internal/lambda-managed-instances/interop/error_utils_test.go +++ b/internal/lambda-managed-instances/interop/error_utils_test.go @@ -18,22 +18,22 @@ func TestBuildStatusFromError(t *testing.T) { expected ResponseStatus }{ { - name: "nil error", + name: "nilError", err: nil, expected: Success, }, { - name: "sandbox timeout error", + name: "sandboxTimeoutError", err: model.NewCustomerError(model.ErrorSandboxTimedout), expected: Timeout, }, { - name: "customer error", + name: "customerError", err: model.NewCustomerError(model.ErrorFunctionUnknown), expected: Error, }, { - name: "runtime error", + name: "platformError", err: model.NewPlatformError(nil, model.ErrorReasonUnknownError), expected: Failure, }, diff --git a/internal/lambda-managed-instances/invoke/invoke_router.go b/internal/lambda-managed-instances/invoke/invoke_router.go index f67bf95..4066bfe 100644 --- a/internal/lambda-managed-instances/invoke/invoke_router.go +++ b/internal/lambda-managed-instances/invoke/invoke_router.go @@ -105,7 +105,7 @@ func (ir *InvokeRouter) Invoke(ctx context.Context, initData interop.InitStaticD if !ir.runningInvokes.SetIfAbsent(invokeReq.InvokeID(), idleRuntime) { logging.Warn(ctx, "InvokeRouter error: duplicated invokeId") - return model.NewClientError(ErrInvokeIdAlreadyExists, model.ErrorSeverityError, model.ErrorDublicatedInvokeId), false + return model.NewClientError(ErrInvokeIdAlreadyExists, model.ErrorSeverityError, model.ErrorDuplicatedInvokeId), false } defer ir.runningInvokes.Remove(invokeReq.InvokeID()) diff --git a/internal/lambda-managed-instances/invoke/invoke_router_test.go b/internal/lambda-managed-instances/invoke/invoke_router_test.go index 03e8bf6..7e5b638 100644 --- a/internal/lambda-managed-instances/invoke/invoke_router_test.go +++ b/internal/lambda-managed-instances/invoke/invoke_router_test.go @@ -190,7 +190,7 @@ func TestInvokeFailure_DublicatedInvokeId(t *testing.T) { err = <-ch assert.Error(t, err) - assert.Equal(t, model.ErrorDublicatedInvokeId, err.ErrorType()) + assert.Equal(t, model.ErrorDuplicatedInvokeId, err.ErrorType()) close(respChannel) err = <-ch diff --git a/internal/lambda-managed-instances/invoke/metrics.go b/internal/lambda-managed-instances/invoke/metrics.go index e7ecbc1..f629da0 100644 --- a/internal/lambda-managed-instances/invoke/metrics.go +++ b/internal/lambda-managed-instances/invoke/metrics.go @@ -352,7 +352,8 @@ func (e *invokeMetrics) buildMetrics() []servicelogs.Metric { switch e.error.(type) { case model.ClientError: clientErrCnt = 1 - if e.error.ErrorType() != model.ErrorRuntimeUnavailable { + if e.error.ErrorType() != model.ErrorRuntimeUnavailable && + e.error.ErrorType() != model.ErrorDuplicatedInvokeId { nonCustomerErrCnt = 1 } diff --git a/internal/lambda-managed-instances/invoke/metrics_test.go b/internal/lambda-managed-instances/invoke/metrics_test.go index d146845..d5cb56d 100644 --- a/internal/lambda-managed-instances/invoke/metrics_test.go +++ b/internal/lambda-managed-instances/invoke/metrics_test.go @@ -415,6 +415,34 @@ func Test_invokeMetrics_ServiceLogs(t *testing.T) { {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 0}, }, }, + { + name: "duplicated_invoke_id_error", + expectedBytes: 0, + metricFlow: func(ev *invokeMetrics, mocks *invokeMetricsMocks) { + mocks.timeStamp = mocks.timeStamp.Add(time.Second) + ev.AttachInvokeRequest(&mocks.invokeReq) + ev.AttachDependencies(&mocks.initData, &mocks.eventsApi) + ev.UpdateConcurrencyMetrics(5, 3) + mocks.error = model.NewClientError(nil, model.ErrorSeverityError, model.ErrorDuplicatedInvokeId) + }, + expectedProps: []servicelogs.Property{ + {Name: "RequestId", Value: "invoke-id"}, + }, + expectedDims: []servicelogs.Dimension{ + {Name: "RequestMode", Value: "Streaming"}, + }, + expectedMetrics: []servicelogs.Metric{ + {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "InflightRequestCount", Value: 5}, + {Type: servicelogs.CounterType, Key: "IdleRuntimesCount", Value: 3}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, + {Type: servicelogs.CounterType, Key: "ClientError", Value: 1}, + {Type: servicelogs.CounterType, Key: "CustomerError", Value: 0}, + {Type: servicelogs.CounterType, Key: "PlatformError", Value: 0}, + {Type: servicelogs.CounterType, Key: "ClientErrorReason-Client.DuplicatedInvokeId", Value: 1}, + {Type: servicelogs.CounterType, Key: "NonCustomerError", Value: 0}, + }, + }, { name: "runtime_timeout_flow", expectedBytes: 100, diff --git a/internal/lambda-managed-instances/rapid/model/client_error.go b/internal/lambda-managed-instances/rapid/model/client_error.go index a67e2a1..5780277 100644 --- a/internal/lambda-managed-instances/rapid/model/client_error.go +++ b/internal/lambda-managed-instances/rapid/model/client_error.go @@ -11,7 +11,7 @@ const ( ErrorInitIncomplete ErrorType = "Client.InitIncomplete" ErrorEnvironmentUnhealthy ErrorType = "Client.ExecutionEnvironmentUnhealthy" ErrorRuntimeUnavailable ErrorType = "Runtime.Unavailable" - ErrorDublicatedInvokeId ErrorType = "Client.DuplicatedInvokeId" + ErrorDuplicatedInvokeId ErrorType = "Client.DuplicatedInvokeId" ErrorInvalidFunctionVersion ErrorType = "ErrInvalidFunctionVersion" ErrorInvalidMaxPayloadSize ErrorType = "ErrInvalidMaxPayloadSize" ErrorInvalidResponseBandwidthRate ErrorType = "ErrInvalidResponseBandwidthRate" diff --git a/internal/lambda-managed-instances/rapid/sandbox.go b/internal/lambda-managed-instances/rapid/sandbox.go index 17df54d..2461855 100644 --- a/internal/lambda-managed-instances/rapid/sandbox.go +++ b/internal/lambda-managed-instances/rapid/sandbox.go @@ -120,12 +120,14 @@ func (r *rapidContext) HandleShutdown(shutdownCause model.AppError, metrics inte if err != nil { slog.Warn("Error during shutdown Context shutdown", "err", err) - return model.WrapErrorIntoPlatformFatalError(err, model.ErrSandboxShutdownFailed) + + return nil } duration := metrics.CreateDurationMetric(interop.ShutdownRuntimeServerDuration) - if err := r.server.Shutdown(); err != nil { - slog.Error("Error during runtime server shutdown", "err", err) + + if err := r.server.Close(); err != nil { + slog.Error("Error during runtime server close", "err", err) } duration.Done() diff --git a/internal/lambda-managed-instances/rapid/shutdown_metrics.go b/internal/lambda-managed-instances/rapid/shutdown_metrics.go index a6e2a0d..6e60bc1 100644 --- a/internal/lambda-managed-instances/rapid/shutdown_metrics.go +++ b/internal/lambda-managed-instances/rapid/shutdown_metrics.go @@ -119,7 +119,7 @@ func (m *shutdownMetrics) buildMetrics() { switch key := metric.metricName; { case key == interop.TotalDurationMetric: totalDuration = metric.duration - case key == interop.ShutdownRuntimeDuration, key == interop.ShutdownExtensionsDuration, key == interop.ShutdownWaitAllProcessesDuration, m.killProcessDurationRegex.MatchString(key): + case key == interop.ShutdownRuntimeDuration, key == interop.ShutdownExtensionsDuration, key == interop.ShutdownWaitAllProcessesDuration, key == interop.ShutdownAbortInvokesDurationMetric, m.killProcessDurationRegex.MatchString(key): sumCustomerDuration += metric.duration } diff --git a/internal/lambda-managed-instances/rapid/shutdown_metrics_test.go b/internal/lambda-managed-instances/rapid/shutdown_metrics_test.go index 20893fe..d24f5ee 100644 --- a/internal/lambda-managed-instances/rapid/shutdown_metrics_test.go +++ b/internal/lambda-managed-instances/rapid/shutdown_metrics_test.go @@ -70,7 +70,7 @@ func Test_shutdownMetrics(t *testing.T) { {Type: servicelogs.TimerType, Key: "KillruntimeDuration", Value: 1000000}, {Type: servicelogs.TimerType, Key: "WaitCustomerProcessesExitDuration", Value: 2000000}, {Type: servicelogs.TimerType, Key: "ShutdownRuntimeServerDuration", Value: 1000000}, - {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 0}, {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 0}, {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 0}, @@ -109,7 +109,7 @@ func Test_shutdownMetrics(t *testing.T) { {Type: servicelogs.TimerType, Key: "StopRuntimeDuration", Value: 1000000}, {Type: servicelogs.TimerType, Key: "WaitCustomerProcessesExitDuration", Value: 2000000}, {Type: servicelogs.TimerType, Key: "ShutdownRuntimeServerDuration", Value: 1000000}, - {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 5}, {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 2}, {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 3}, @@ -156,7 +156,7 @@ func Test_shutdownMetrics(t *testing.T) { {Type: servicelogs.TimerType, Key: "TotalDuration", Value: 2000000}, {Type: servicelogs.TimerType, Key: "AbortInvokeDuration", Value: 1000000}, {Type: servicelogs.TimerType, Key: "ShutdownRuntimeServerDuration", Value: 1000000}, - {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 2000000}, + {Type: servicelogs.TimerType, Key: "PlatformOverheadDuration", Value: 1000000}, {Type: servicelogs.CounterType, Key: "TotalExtensionsCount", Value: 5}, {Type: servicelogs.CounterType, Key: "InternalExtensionsCount", Value: 2}, {Type: servicelogs.CounterType, Key: "ExternalExtensionsCount", Value: 3}, diff --git a/internal/lambda-managed-instances/rapidcore/env/environment.go b/internal/lambda-managed-instances/rapidcore/env/environment.go index 6624677..919378f 100644 --- a/internal/lambda-managed-instances/rapidcore/env/environment.go +++ b/internal/lambda-managed-instances/rapidcore/env/environment.go @@ -96,6 +96,7 @@ func SetupEnvironment(config *model.InitRequestMessage, runtimePort, runtimeLogg AWS_LAMBDA_FUNCTION_MEMORY_SIZE: strconv.Itoa(config.MemorySizeBytes / 1024 / 1024), AWS_LAMBDA_FUNCTION_NAME: config.TaskName, AWS_LAMBDA_FUNCTION_VERSION: config.FunctionVersion, + AWS_LAMBDA_MAX_CONCURRENCY: strconv.Itoa(config.RuntimeWorkerCount), AWS_REGION: config.AwsRegion, AWS_SECRET_ACCESS_KEY: config.AwsSecret, AWS_SESSION_TOKEN: config.AwsSession, @@ -127,7 +128,6 @@ func getRuntimeOnlyEnvVars(common model.KVMap, config *model.InitRequestMessage, runtimeOnlyVars := model.KVMap{ AWS_LAMBDA_LOG_GROUP_NAME: config.LogGroupName, AWS_LAMBDA_LOG_STREAM_NAME: config.LogStreamName, - AWS_LAMBDA_MAX_CONCURRENCY: strconv.Itoa(config.RuntimeWorkerCount), _AWS_XRAY_DAEMON_ADDRESS: config.XRayDaemonAddress, _AWS_XRAY_DAEMON_PORT: "2000", AWS_XRAY_CONTEXT_MISSING: "LOG_ERROR", diff --git a/internal/lambda-managed-instances/rapidcore/env/environment_test.go b/internal/lambda-managed-instances/rapidcore/env/environment_test.go index f29578a..7a6712d 100644 --- a/internal/lambda-managed-instances/rapidcore/env/environment_test.go +++ b/internal/lambda-managed-instances/rapidcore/env/environment_test.go @@ -54,6 +54,7 @@ func TestSetupEnvironment(t *testing.T) { AWS_LAMBDA_FUNCTION_VERSION: "$LATEST", AWS_LAMBDA_LOG_FORMAT: "json", AWS_LAMBDA_LOG_LEVEL: "info", + AWS_LAMBDA_MAX_CONCURRENCY: "1", AWS_REGION: "us-west-2", AWS_SECRET_ACCESS_KEY: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", AWS_SESSION_TOKEN: "FwoGZXIvYXdzEMj//////////wEaDM1Qz0oN8BNwV9GqyyLVAebxhwq9ZGqojXZe1UTJkzK6F9V+VZHhT5JSWYzJUKEwOqOkQyQXJpfJsYHfkJEXtR6Kh9mXnEbqKi", diff --git a/internal/lambda-managed-instances/raptor/app.go b/internal/lambda-managed-instances/raptor/app.go index 9a9ac06..d9a5853 100644 --- a/internal/lambda-managed-instances/raptor/app.go +++ b/internal/lambda-managed-instances/raptor/app.go @@ -33,6 +33,7 @@ type App struct { err atomic.Value doneCh chan struct{} + shutdownStartedCh chan struct{} telemetryFDSocketPath string raptorLogger raptorLogger } @@ -48,6 +49,7 @@ func StartApp(deps rapid.Dependencies, telemetryFDSocketPath string, raptorLogge rapidCtx: rapidCtx, state: internal.NewStateGuard(), doneCh: make(chan struct{}), + shutdownStartedCh: make(chan struct{}), telemetryFDSocketPath: telemetryFDSocketPath, raptorLogger: raptorLogger, } @@ -77,7 +79,11 @@ func (a *App) Init(ctx context.Context, init *internalModel.InitRequestMessage, if initErr != nil { logging.Err(ctx, "Received Init error", initErr) - a.Shutdown(initErr) + go func() { + a.Shutdown(initErr) + }() + + <-a.shutdownStartedCh return initErr } @@ -139,6 +145,7 @@ func (a *App) Shutdown(shutdownReason model.AppError) { if shutdownReason != nil { a.err.Store(shutdownReason) } + close(a.shutdownStartedCh) var shutdownErr model.AppError if shutdownErr = a.rapidCtx.HandleShutdown(shutdownReason, metrics); shutdownErr != nil { diff --git a/internal/lambda-managed-instances/raptor/app_test.go b/internal/lambda-managed-instances/raptor/app_test.go index bf5c9d2..581712a 100644 --- a/internal/lambda-managed-instances/raptor/app_test.go +++ b/internal/lambda-managed-instances/raptor/app_test.go @@ -46,12 +46,13 @@ func TestAppInitFailure(t *testing.T) { expectedErr := model.NewCustomerError(model.ErrorReasonRuntimeExecFailed, model.WithSeverity(model.ErrorSeverityFatal)) mockRapidCtx, app, initRequest, initMetrics, _, _ := setupAppTest(t) mockRapidCtx.On("HandleInit", mock.Anything, mock.Anything, mock.Anything).Return(expectedErr) - mockRapidCtx.On("HandleShutdown", mock.Anything, mock.Anything).Return(nil) + mockRapidCtx.On("HandleShutdown", mock.Anything, mock.Anything).Return(nil).Maybe() initErr := app.Init(context.Background(), initRequest, initMetrics) assert.Equal(t, expectedErr, initErr) - assert.Equal(t, internal.Shutdown, app.state.GetState()) + state := app.state.GetState() + assert.True(t, state == internal.Shutdown || state == internal.ShuttingDown, "Expected state to be either Shutdown or ShuttingDown, got %s", state) mockRapidCtx.AssertExpectations(t) } @@ -99,10 +100,11 @@ func TestStartProcessTerminationMonitor(t *testing.T) { mockLogger.On("Log", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything) app := &App{ - rapidCtx: mockRapidCtx, - state: internal.NewStateGuard(), - doneCh: make(chan struct{}), - raptorLogger: mockLogger, + rapidCtx: mockRapidCtx, + state: internal.NewStateGuard(), + doneCh: make(chan struct{}), + shutdownStartedCh: make(chan struct{}), + raptorLogger: mockLogger, } assert.Equal(t, internal.Idle, app.state.GetState()) @@ -223,10 +225,11 @@ func setupAppTest(t *testing.T) (*interop.MockRapidContext, *App, *internalModel mockLogger.On("SetInitData", mock.Anything).Maybe() app := &App{ - rapidCtx: mockRapidCtx, - state: internal.NewStateGuard(), - doneCh: make(chan struct{}), - raptorLogger: mockLogger, + rapidCtx: mockRapidCtx, + state: internal.NewStateGuard(), + doneCh: make(chan struct{}), + shutdownStartedCh: make(chan struct{}), + raptorLogger: mockLogger, } app.StartProcessTerminationMonitor() diff --git a/internal/lambda-managed-instances/raptor/server.go b/internal/lambda-managed-instances/raptor/server.go index 27c00ce..23374f8 100644 --- a/internal/lambda-managed-instances/raptor/server.go +++ b/internal/lambda-managed-instances/raptor/server.go @@ -4,7 +4,6 @@ package raptor import ( - "context" "log/slog" "net" "net/http" @@ -45,13 +44,10 @@ func StartServer(shutdownHandler shutdownHandler, handler http.Handler, addr Add func (s *Server) Shutdown(err error) { s.shutdownOnce.Do(func() { - s.shutdownHandler.Shutdown(model.NewClientError(err, model.ErrorSeverityFatal, model.ErrorExecutionEnvironmentShutdown)) - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() slog.Info("Shutting down HTTP server...") - if err := s.httpServer.Shutdown(ctx); err != nil { - slog.Warn("could not gracefully shutdown EA http server", "err", err) + if err := s.httpServer.Close(); err != nil { + slog.Warn("error shutdown EA http server", "err", err) } if err != nil { diff --git a/internal/lambda-managed-instances/testutils/functional/extension_actions.go b/internal/lambda-managed-instances/testutils/functional/extension_actions.go index 6fa06bd..ef67d1f 100644 --- a/internal/lambda-managed-instances/testutils/functional/extension_actions.go +++ b/internal/lambda-managed-instances/testutils/functional/extension_actions.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/functional/extensions_client.go b/internal/lambda-managed-instances/testutils/functional/extensions_client.go index 40117e9..ae5c24a 100644 --- a/internal/lambda-managed-instances/testutils/functional/extensions_client.go +++ b/internal/lambda-managed-instances/testutils/functional/extensions_client.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/functional/fluxpump_server.go b/internal/lambda-managed-instances/testutils/functional/fluxpump_server.go index 7ff5e64..c08ce93 100644 --- a/internal/lambda-managed-instances/testutils/functional/fluxpump_server.go +++ b/internal/lambda-managed-instances/testutils/functional/fluxpump_server.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/functional/httputils.go b/internal/lambda-managed-instances/testutils/functional/httputils.go index 82c2837..3878116 100644 --- a/internal/lambda-managed-instances/testutils/functional/httputils.go +++ b/internal/lambda-managed-instances/testutils/functional/httputils.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/functional/in_memory_events_api.go b/internal/lambda-managed-instances/testutils/functional/in_memory_events_api.go index 773553e..86a648e 100644 --- a/internal/lambda-managed-instances/testutils/functional/in_memory_events_api.go +++ b/internal/lambda-managed-instances/testutils/functional/in_memory_events_api.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/functional/process_supervisor.go b/internal/lambda-managed-instances/testutils/functional/process_supervisor.go index dbffbf8..8fb4c22 100644 --- a/internal/lambda-managed-instances/testutils/functional/process_supervisor.go +++ b/internal/lambda-managed-instances/testutils/functional/process_supervisor.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/functional/runtime_actions.go b/internal/lambda-managed-instances/testutils/functional/runtime_actions.go index 601c027..ea1559a 100644 --- a/internal/lambda-managed-instances/testutils/functional/runtime_actions.go +++ b/internal/lambda-managed-instances/testutils/functional/runtime_actions.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/functional/runtime_client.go b/internal/lambda-managed-instances/testutils/functional/runtime_client.go index 5963d3d..c866016 100644 --- a/internal/lambda-managed-instances/testutils/functional/runtime_client.go +++ b/internal/lambda-managed-instances/testutils/functional/runtime_client.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/functional/supv.go b/internal/lambda-managed-instances/testutils/functional/supv.go index 70b16d8..51c59cd 100644 --- a/internal/lambda-managed-instances/testutils/functional/supv.go +++ b/internal/lambda-managed-instances/testutils/functional/supv.go @@ -1,8 +1,6 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -//go:build test - package functional import ( diff --git a/internal/lambda-managed-instances/testutils/test_data.go b/internal/lambda-managed-instances/testutils/test_data.go index 7a5c42c..ccbc0cf 100644 --- a/internal/lambda-managed-instances/testutils/test_data.go +++ b/internal/lambda-managed-instances/testutils/test_data.go @@ -80,9 +80,8 @@ func JsonEncode(payload model.InitRequestMessage) string { func WithInvalidPayload() InitPayloadOption { return func(p *model.InitRequestMessage) { - p.AwsKey = "AKIAIOSFODNN7EXAMPLE" - p.RuntimeBinaryCommand = nil + p.AwsKey = "" } } diff --git a/internal/lambda/rapidcore/server.go b/internal/lambda/rapidcore/server.go index 1ec0215..c0c24e3 100644 --- a/internal/lambda/rapidcore/server.go +++ b/internal/lambda/rapidcore/server.go @@ -658,7 +658,7 @@ func (s *Server) Invoke(responseWriter http.ResponseWriter, invoke *interop.Invo // The logic would be almost identical, except that init failures could manifest // through return values of FastInvoke and not Reserve() - reserveResp, err := s.Reserve("", "", "") + reserveResp, err := s.Reserve(invoke.ID, "", "") if err != nil { log.Infof("ReserveFailed: %s", err) } diff --git a/internal/lambda/rie/handlers.go b/internal/lambda/rie/handlers.go index c4310fd..1abe154 100644 --- a/internal/lambda/rie/handlers.go +++ b/internal/lambda/rie/handlers.go @@ -6,6 +6,7 @@ package rie import ( "bytes" "encoding/base64" + "encoding/json" "fmt" "io/ioutil" "math" @@ -116,14 +117,37 @@ func InvokeHandler(w http.ResponseWriter, r *http.Request, sandbox Sandbox, bs i } invokeStart := time.Now() + invokeID := r.Header.Get("X-Amzn-RequestId") + if invokeID == "" { + invokeID = uuid.New().String() + } + + // Parse X-Amz-Cognito-Identity header (JSON with cognitoIdentityId and cognitoIdentityPoolId fields) + var cognitoIdentityID, cognitoIdentityPoolID string + if cognitoIdentityHeader := r.Header.Get("X-Amz-Cognito-Identity"); cognitoIdentityHeader != "" { + var cognitoIdentity struct { + CognitoIdentityID string `json:"cognitoIdentityId"` + CognitoIdentityPoolID string `json:"cognitoIdentityPoolId"` + } + if err := json.Unmarshal([]byte(cognitoIdentityHeader), &cognitoIdentity); err != nil { + log.Errorf("Failed to parse X-Amz-Cognito-Identity header: %s", err) + w.WriteHeader(500) + return + } + cognitoIdentityID = cognitoIdentity.CognitoIdentityID + cognitoIdentityPoolID = cognitoIdentity.CognitoIdentityPoolID + } + invokePayload := &interop.Invoke{ - ID: uuid.New().String(), - InvokedFunctionArn: fmt.Sprintf("arn:aws:lambda:us-east-1:012345678912:function:%s", GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function")), - TraceID: r.Header.Get("X-Amzn-Trace-Id"), - LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), - TenantID: interop.TenantID(r.Header.Get("X-Amz-Tenant-Id")), - Payload: bytes.NewReader(bodyBytes), - ClientContext: string(rawClientContext), + ID: invokeID, + InvokedFunctionArn: fmt.Sprintf("arn:aws:lambda:us-east-1:012345678912:function:%s", GetenvWithDefault("AWS_LAMBDA_FUNCTION_NAME", "test_function")), + TraceID: r.Header.Get("X-Amzn-Trace-Id"), + LambdaSegmentID: r.Header.Get("X-Amzn-Segment-Id"), + TenantID: interop.TenantID(r.Header.Get("X-Amz-Tenant-Id")), + Payload: bytes.NewReader(bodyBytes), + ClientContext: string(rawClientContext), + CognitoIdentityID: cognitoIdentityID, + CognitoIdentityPoolID: cognitoIdentityPoolID, } fmt.Println("START RequestId: " + invokePayload.ID + " Version: " + functionVersion)