diff --git a/logger.go b/logger.go
index 0f53ebc..72e9346 100644
--- a/logger.go
+++ b/logger.go
@@ -86,16 +86,24 @@ func newPodEventLogger(ctx context.Context, opts podEventLoggerOptions) (*podEve
 			},
 			maxRetries: opts.maxRetries,
 		},
+		doneChan: make(chan struct{}),
 	}
 
+	// Start the single work goroutine; doneChan is closed when it exits.
+	go reporter.lq.work(reporter.ctx, reporter.doneChan)
+
 	// If no namespaces are provided, we listen for events in all namespaces.
 	if len(opts.namespaces) == 0 {
 		if err := reporter.initNamespace(""); err != nil {
+			// Full teardown (Close always returns nil): also closes stopChan/errChan.
+			_ = reporter.Close()
 			return nil, fmt.Errorf("init namespace: %w", err)
 		}
 	} else {
 		for _, namespace := range opts.namespaces {
 			if err := reporter.initNamespace(namespace); err != nil {
+				// Earlier namespaces may already be running; Close tears all of them down.
+				_ = reporter.Close()
 				return nil, err
 			}
 		}
@@ -119,6 +127,11 @@ type podEventLogger struct {
 
 	// hasSyncedFuncs tracks informer cache sync functions for testing
 	hasSyncedFuncs []cache.InformerSynced
+
+	// closeOnce guards the teardown in Close() so repeated calls do not double-close channels
+	closeOnce sync.Once
+	// doneChan is closed by the work goroutine when it exits; Close() blocks on it
+	doneChan chan struct{}
 }
 
 // resolveEnvValue resolves the value of an environment variable, supporting both
@@ -161,8 +174,6 @@ func (p *podEventLogger) initNamespace(namespace string) error {
 	// This is to prevent us from sending duplicate events.
 	startTime := time.Now()
 
-	go p.lq.work(p.ctx)
-
 	podFactory := informers.NewSharedInformerFactoryWithOptions(p.client, 0, informers.WithNamespace(namespace), informers.WithTweakListOptions(func(lo *v1.ListOptions) {
 		lo.FieldSelector = p.fieldSelector
 		lo.LabelSelector = p.labelSelector
@@ -411,10 +422,15 @@ func (p *podEventLogger) sendDelete(token string) {
 }
 
 // Close stops the pod event logger and releases all resources.
+// Close is idempotent and safe to call multiple times.
 func (p *podEventLogger) Close() error {
-	p.cancelFunc()
-	close(p.stopChan)
-	close(p.errChan)
+	p.closeOnce.Do(func() {
+		p.cancelFunc()
+		close(p.stopChan)
+		close(p.errChan)
+	})
+	// Every caller (first or repeated) returns only after the work goroutine has exited
+	<-p.doneChan
 	return nil
 }
 
@@ -503,7 +519,12 @@ type logQueuer struct {
 	maxRetries int
 }
 
-func (l *logQueuer) work(ctx context.Context) {
+// work services l.q while ctx is live; on exit it runs cleanup and then closes done.
+func (l *logQueuer) work(ctx context.Context, done chan struct{}) {
+	// Deferred calls run LIFO: cleanup finishes before done is closed.
+	defer close(done)
+	defer l.cleanup()
+
 	for ctx.Err() == nil {
 		select {
 		case log := <-l.q:
@@ -521,6 +542,19 @@ func (l *logQueuer) work(ctx context.Context) {
 	}
 }
 
+// cleanup stops all retry timers and empties l.retries when the work loop exits.
+func (l *logQueuer) cleanup() {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+
+	for token, rs := range l.retries {
+		if rs != nil && rs.timer != nil {
+			rs.timer.Stop()
+		}
+		delete(l.retries, token)
+	}
+}
+
 func (l *logQueuer) newLogger(ctx context.Context, log agentLog) (agentLoggerLifecycle, error) {
 	client := agentsdk.New(l.coderURL, agentsdk.WithFixedToken(log.agentToken))
 	logger := l.logger.With(slog.F("resource_name", log.resourceName))
diff --git a/logger_test.go b/logger_test.go
index fd1dbcf..df509e8 100644
--- a/logger_test.go
+++ b/logger_test.go
@@ -675,7 +675,7 @@ func Test_logQueuer(t *testing.T) {
 
 		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 		defer cancel()
-		go lq.work(ctx)
+		go lq.work(ctx, make(chan struct{}))
 
 		ch <- agentLog{
 			op: opLog,
@@ -742,7 +742,7 @@ func Test_logQueuer(t *testing.T) {
 
 		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 		defer cancel()
-		go lq.work(ctx)
+		go lq.work(ctx, make(chan struct{}))
 
 		token := "retry-token"
 		ch <- agentLog{
@@ -905,7 +905,7 @@ func Test_logQueuer(t *testing.T) {
 
 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 		defer cancel()
-		go lq.work(ctx)
+		go lq.work(ctx, make(chan struct{}))
 
 		token := "max-retry-token"
 		ch <- agentLog{
@@ -1111,7 +1111,7 @@ func Test_logCache(t *testing.T) {
 
 		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 		defer cancel()
-		go lq.work(ctx)
+		go lq.work(ctx, make(chan struct{}))
 
 		token := "test-token"
 
@@ -1179,6 +1179,92 @@ func Test_logCache(t *testing.T) {
 	})
 }
 
+// TestCloseIdempotent verifies that calling Close twice neither errors nor panics.
+func TestCloseIdempotent(t *testing.T) {
+	t.Parallel()
+
+	api := newFakeAgentAPI(t)
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	agentURL, err := url.Parse(api.server.URL)
+	require.NoError(t, err)
+	namespace := "test-namespace"
+
+	client := fake.NewSimpleClientset()
+
+	cMock := quartz.NewMock(t)
+	reporter, err := newPodEventLogger(ctx, podEventLoggerOptions{
+		client:      client,
+		coderURL:    agentURL,
+		namespaces:  []string{namespace},
+		logger:      slogtest.Make(t, nil).Leveled(slog.LevelDebug),
+		logDebounce: 5 * time.Second,
+		clock:       cMock,
+	})
+	require.NoError(t, err)
+
+	// First close should succeed
+	err = reporter.Close()
+	require.NoError(t, err)
+
+	// Second close should not panic (idempotent)
+	err = reporter.Close()
+	require.NoError(t, err)
+}
+
+// TestCloseDuringProcessing verifies Close returns cleanly once a pod's log source is registered.
+func TestCloseDuringProcessing(t *testing.T) {
+	t.Parallel()
+
+	api := newFakeAgentAPI(t)
+
+	ctx := testutil.Context(t, testutil.WaitShort)
+	agentURL, err := url.Parse(api.server.URL)
+	require.NoError(t, err)
+	namespace := "test-namespace"
+
+	client := fake.NewSimpleClientset()
+
+	cMock := quartz.NewMock(t)
+	reporter, err := newPodEventLogger(ctx, podEventLoggerOptions{
+		client:      client,
+		coderURL:    agentURL,
+		namespaces:  []string{namespace},
+		logger:      slogtest.Make(t, nil).Leveled(slog.LevelDebug),
+		logDebounce: 5 * time.Second,
+		clock:       cMock,
+	})
+	require.NoError(t, err)
+
+	// Create a pod to trigger processing
+	pod := &corev1.Pod{
+		ObjectMeta: v1.ObjectMeta{
+			Name:      "test-pod-close",
+			Namespace: namespace,
+			CreationTimestamp: v1.Time{
+				Time: time.Now().Add(time.Hour),
+			},
+		},
+		Spec: corev1.PodSpec{
+			Containers: []corev1.Container{{
+				Env: []corev1.EnvVar{{
+					Name:  "CODER_AGENT_TOKEN",
+					Value: "test-token",
+				}},
+			}},
+		},
+	}
+	_, err = client.CoreV1().Pods(namespace).Create(ctx, pod, v1.CreateOptions{})
+	require.NoError(t, err)
+
+	// Wait for log source to be registered
+	_ = testutil.RequireReceive(ctx, t, api.logSource)
+
+	// Close while processing is active
+	err = reporter.Close()
+	require.NoError(t, err)
+}
+
 func newFakeAgentAPI(t *testing.T) *fakeAgentAPI {
 	logger := slogtest.Make(t, nil)
 	mux := drpcmux.New()