diff --git a/internal/cmd/backup.go b/internal/cmd/backup.go
index 9656e895..7f544040 100644
--- a/internal/cmd/backup.go
+++ b/internal/cmd/backup.go
@@ -39,6 +39,146 @@ const (
 	doNotReturnIfExists = false
 )
 
+// BackupConfig holds the configuration for creating a backup.
+type BackupConfig struct {
+	// PrefixFilter filters relationships to only include those with this prefix.
+	PrefixFilter string
+	// PageLimit defines the number of relationships to be read per page during backup.
+	PageLimit uint32
+	// RewriteLegacy indicates whether to rewrite legacy schema syntax.
+	RewriteLegacy bool
+}
+
+// ProgressTracker tracks backup progress for resumability.
+type ProgressTracker interface {
+	// GetCursor returns the stored cursor, or nil if no progress exists.
+	GetCursor() *v1.Cursor
+	// WriteCursor writes the current cursor to storage.
+	WriteCursor(cursor *v1.Cursor) error
+	// MarkComplete marks the backup as complete (e.g., removes progress file).
+	MarkComplete() error
+	// Close closes any underlying resources.
+	Close() error
+}
+
+// fileProgressTracker implements ProgressTracker using a file.
+type fileProgressTracker struct {
+	file   *os.File
+	cursor *v1.Cursor
+}
+
+func newFileProgressTracker(backupFileName string, backupAlreadyExisted bool) (*fileProgressTracker, error) {
+	progressFileName := toLockFileName(backupFileName)
+	var cursor *v1.Cursor
+	var fileMode int
+
+	readCursor, readErr := os.ReadFile(progressFileName)
+	if backupAlreadyExisted {
+		// Backup exists - we need a valid progress file to resume
+		// Check for errors first (except not-exist) to avoid masking permission/I/O errors
+		if readErr != nil && !os.IsNotExist(readErr) {
+			return nil, fmt.Errorf("failed to read progress file for existing backup: %w", readErr)
+		}
+		if os.IsNotExist(readErr) || len(readCursor) == 0 {
+			return nil, fmt.Errorf("backup file %s already exists", backupFileName)
+		}
+		// Successfully read the cursor
+		cursor = &v1.Cursor{
+			Token: string(readCursor),
+		}
+		// if backup existed and there is a progress marker, the latter should not be truncated
+		fileMode = os.O_WRONLY | os.O_CREATE
+		log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume")
+	} else {
+		// if a backup did not exist, make sure to truncate the progress file
+		fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
+	}
+
+	progressFile, err := os.OpenFile(progressFileName, fileMode, 0o644)
+	if err != nil {
+		return nil, fmt.Errorf("failed to open progress file: %w", err)
+	}
+
+	return &fileProgressTracker{
+		file:   progressFile,
+		cursor: cursor,
+	}, nil
+}
+
+func (f *fileProgressTracker) GetCursor() *v1.Cursor {
+	return f.cursor
+}
+
+func (f *fileProgressTracker) WriteCursor(cursor *v1.Cursor) error {
+	if cursor == nil {
+		return errors.New("cannot write nil cursor to progress file")
+	}
+
+	if err := f.file.Truncate(0); err != nil {
+		return fmt.Errorf("unable to truncate backup progress file: %w", err)
+	}
+
+	if _, err := f.file.Seek(0, 0); err != nil {
+		return fmt.Errorf("unable to seek backup progress file: %w", err)
+	}
+
+	if _, err := f.file.WriteString(cursor.Token); err != nil {
+		return fmt.Errorf("unable to write result cursor to backup progress file: %w", err)
+	}
+
+	// Sync to ensure cursor is durably persisted before continuing
+	if err := f.file.Sync(); err != nil {
+		return fmt.Errorf("unable to sync backup progress file: %w", err)
+	}
+
+	// Update in-memory cursor to keep it consistent with persisted state
+	f.cursor = cursor
+
+	return nil
+}
+
+func (f *fileProgressTracker) MarkComplete() error {
+	// Check if already closed/completed
+	if f.file == nil {
+		return nil
+	}
+
+	// Store the filename before closing, as we need to remove after close
+	// to ensure Windows compatibility (can't remove open files on Windows)
+	filename := f.file.Name()
+
+	// Close the file first to release the handle
+	if err := f.file.Sync(); err != nil {
+		return fmt.Errorf("failed to sync progress file before removal: %w", err)
+	}
+	if err := f.file.Close(); err != nil {
+		return fmt.Errorf("failed to close progress file before removal: %w", err)
+	}
+	f.file = nil // Mark as closed so Close() becomes a no-op
+
+	// Now remove the file - log warning but don't fail the backup for cleanup issues
+	if err := os.Remove(filename); err != nil {
+		log.Warn().
+			Str("progress-file", filename).
+			Err(err).
+			Msg("failed to remove progress file, consider removing it manually")
+		// Don't return error - the backup succeeded, this is just cleanup
+	}
+
+	return nil
+}
+
+func (f *fileProgressTracker) Close() error {
+	// Check if file is already closed (e.g., by MarkComplete)
+	if f.file == nil {
+		return nil
+	}
+	syncErr := f.file.Sync()
+	closeErr := f.file.Close()
+	f.file = nil
+	return errors.Join(syncErr, closeErr)
+}
+
 // cobraRunEFunc is the signature of a cobra.Command.RunE function.
 type cobraRunEFunc = func(cmd *cobra.Command, args []string) (err error)
 
@@ -270,8 +410,11 @@ func hasRelPrefix(rel *v1.Relationship, prefix string) bool {
 }
 
 func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
-	prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter")
-	pageLimit := cobrautil.MustGetUint32(cmd, "page-limit")
+	config := BackupConfig{
+		PrefixFilter:  cobrautil.MustGetString(cmd, "prefix-filter"),
+		PageLimit:     cobrautil.MustGetUint32(cmd, "page-limit"),
+		RewriteLegacy: cobrautil.MustGetBool(cmd, "rewrite-legacy"),
+	}
 
 	backupFileName, err := computeBackupFileName(cmd, args)
 	if err != nil {
@@ -288,26 +431,11 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
 		*e = errors.Join(*e, backupFile.Close())
 	}(&err)
 
-	// the goal of this file is to keep the bulk export cursor in case the process is terminated
-	// and we need to resume from where we left off. OCF does not support in-place record updates.
-	progressFile, cursor, err := openProgressFile(backupFileName, backupExists)
+	progressTracker, err := newFileProgressTracker(backupFileName, backupExists)
 	if err != nil {
 		return err
 	}
-
-	var backupCompleted bool
-	defer func(e *error) {
-		*e = errors.Join(*e, progressFile.Sync())
-		*e = errors.Join(*e, progressFile.Close())
-
-		if backupCompleted {
-			if err := os.Remove(progressFile.Name()); err != nil {
-				log.Warn().
-					Str("progress-file", progressFile.Name()).
- Msg("failed to remove progress file, consider removing it manually") - } - } - }(&err) + defer func(e *error) { *e = errors.Join(*e, progressTracker.Close()) }(&err) spiceClient, err := client.NewClient(cmd) if err != nil { @@ -322,7 +450,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { return fmt.Errorf("error creating backup file encoder: %w", err) } } else { - encoder, zedToken, err = encoderForNewBackup(cmd, spiceClient, backupFile) + encoder, zedToken, err = encoderForNewBackupWithConfig(cmd.Context(), spiceClient, backupFile, config) if err != nil { return err } @@ -330,12 +458,50 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { defer func(e *error) { *e = errors.Join(*e, encoder.Close()) }(&err) + backupCompleted, err := backupCreateImpl(cmd.Context(), spiceClient, encoder, progressTracker, config, zedToken) + if err != nil { + return err + } + + if backupCompleted { + if markErr := progressTracker.MarkComplete(); markErr != nil { + err = errors.Join(err, markErr) + } + } + + return err +} + +// backupCreateImpl performs the core backup logic. It is designed to be testable +// by accepting dependencies as parameters rather than creating them internally. +// +// Parameters: +// - ctx: Context for cancellation +// - spiceClient: The SpiceDB client to use for exporting relationships +// - encoder: The encoder to write relationships to +// - progressTracker: Tracks progress for resumability +// - config: Backup configuration options +// - zedToken: The token to use for consistency (nil if resuming from cursor) +// +// Returns: +// - backupCompleted: true if the backup completed successfully +// - err: any error that occurred during backup +func backupCreateImpl( + ctx context.Context, + spiceClient client.Client, + encoder *backupformat.Encoder, + progressTracker ProgressTracker, + config BackupConfig, + zedToken *v1.ZedToken, +) (backupCompleted bool, err error) { + cursor := progressTracker.GetCursor() + if zedToken == nil && cursor == nil { - return errors.New("malformed existing backup, consider recreating it") + return false, errors.New("malformed existing backup, consider recreating it") } req := &v1.ExportBulkRelationshipsRequest{ - OptionalLimit: pageLimit, + OptionalLimit: config.PageLimit, OptionalCursor: cursor, } @@ -348,10 +514,9 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { } } - ctx := cmd.Context() - relationshipReadStart := time.Now() - tick := time.Tick(5 * time.Second) + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() progressBar := console.CreateProgressBar("processing backup") var relsFilteredOut, relsProcessed uint64 defer func() { @@ -374,7 +539,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { err = takeBackup(ctx, spiceClient, req, func(response *v1.ExportBulkRelationshipsResponse) error { for _, rel := range response.Relationships { - if hasRelPrefix(rel, prefixFilter) { + if hasRelPrefix(rel, config.PrefixFilter) { if err := encoder.Append(rel); err != nil { return fmt.Errorf("error storing relationship: %w", err) } @@ -390,7 +555,7 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) { // progress fallback in case there is no TTY if !isatty.IsTerminal(os.Stderr.Fd()) { select { - case <-tick: + case <-ticker.C: log.Info(). Uint64("filtered", relsFilteredOut). Uint64("processed", relsProcessed). 
@@ -402,20 +567,18 @@ func backupCreateCmdFunc(cmd *cobra.Command, args []string) (err error) {
 			}
 		}
 
-		if err := writeProgress(progressFile, response); err != nil {
-			return err
+		if response.AfterResultCursor != nil {
+			if err := progressTracker.WriteCursor(response.AfterResultCursor); err != nil {
+				return err
+			}
 		}
 		return nil
 	})
 	if err != nil {
-		return err
+		return false, err
 	}
 
-	backupCompleted = true
-	// NOTE: we return err here because there's cleanup being done
-	// in the `defer` blocks that will modify the `err` if cleanup
-	// fails
-	return err
+	return true, nil
 }
 
 func takeBackup(ctx context.Context, spiceClient client.Client, req *v1.ExportBulkRelationshipsRequest, processResponse func(*v1.ExportBulkRelationshipsResponse) error) error {
@@ -479,12 +642,10 @@ func takeBackup(ctx context.Context, spiceClient client.Client, req *v1.ExportBu
 	return nil
 }
 
-// encoderForNewBackup creates a new encoder for a new zed backup file. It returns the ZedToken at which the backup
-// must be taken.
-func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.File) (*backupformat.Encoder, *v1.ZedToken, error) {
-	prefixFilter := cobrautil.MustGetString(cmd, "prefix-filter")
-
-	schemaResp, err := c.ReadSchema(cmd.Context(), &v1.ReadSchemaRequest{})
+// encoderForNewBackupWithConfig creates a new encoder for a new zed backup file using the provided config.
+// It returns the ZedToken at which the backup must be taken.
+func encoderForNewBackupWithConfig(ctx context.Context, c client.Client, backupFile *os.File, config BackupConfig) (*backupformat.Encoder, *v1.ZedToken, error) {
+	schemaResp, err := c.ReadSchema(ctx, &v1.ReadSchemaRequest{})
 	if err != nil {
 		return nil, nil, fmt.Errorf("error reading schema: %w", err)
 	}
@@ -495,14 +656,13 @@ func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.Fil
 
 	// Remove any invalid relations generated from old, backwards-incompat
 	// Serverless permission systems.
-	if cobrautil.MustGetBool(cmd, "rewrite-legacy") {
+	if config.RewriteLegacy {
 		schema = rewriteLegacy(schema)
 	}
 
 	// Skip any definitions without the provided prefix
-
-	if prefixFilter != "" {
-		schema, err = filterSchemaDefs(schema, prefixFilter)
+	if config.PrefixFilter != "" {
+		schema, err = filterSchemaDefs(schema, config.PrefixFilter)
 		if err != nil {
 			return nil, nil, err
 		}
@@ -518,64 +678,6 @@ func encoderForNewBackup(cmd *cobra.Command, c client.Client, backupFile *os.Fil
 	return encoder, zedToken, nil
 }
 
-func writeProgress(progressFile *os.File, relsResp *v1.ExportBulkRelationshipsResponse) error {
-	err := progressFile.Truncate(0)
-	if err != nil {
-		return fmt.Errorf("unable to truncate backup progress file: %w", err)
-	}
-
-	_, err = progressFile.Seek(0, 0)
-	if err != nil {
-		return fmt.Errorf("unable to seek backup progress file: %w", err)
-	}
-
-	_, err = progressFile.WriteString(relsResp.AfterResultCursor.Token)
-	if err != nil {
-		return fmt.Errorf("unable to write result cursor to backup progress file: %w", err)
-	}
-
-	return nil
-}
-
-// openProgressFile returns the progress marker file and the stored progress cursor if it exists, or creates
-// a new one if it does not exist. If the backup file exists, but the progress marker does not, it will return an error.
-//
-// The progress marker file keeps track of the last successful cursor received from the server, and is used to resume
-// backups in case of failure.
-func openProgressFile(backupFileName string, backupAlreadyExisted bool) (*os.File, *v1.Cursor, error) {
-	var cursor *v1.Cursor
-	progressFileName := toLockFileName(backupFileName)
-	var progressFile *os.File
-	// if a backup existed
-	var fileMode int
-	readCursor, err := os.ReadFile(progressFileName)
-	if backupAlreadyExisted {
-		if os.IsNotExist(err) || len(readCursor) == 0 {
-			return nil, nil, fmt.Errorf("backup file %s already exists", backupFileName)
-		}
-		if err == nil {
-			cursor = &v1.Cursor{
-				Token: string(readCursor),
-			}
-
-			// if backup existed and there is a progress marker, the latter should not be truncated to make sure the
-			// cursor stays around in case of a failure before we even start ingesting from bulk export
-			fileMode = os.O_WRONLY | os.O_CREATE
-			log.Info().Str("filename", backupFileName).Msg("backup file already exists, will resume")
-		}
-	} else {
-		// if a backup did not exist, make sure to truncate the progress file
-		fileMode = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
-	}
-
-	progressFile, err = os.OpenFile(progressFileName, fileMode, 0o644)
-	if err != nil {
-		return nil, nil, err
-	}
-
-	return progressFile, cursor, nil
-}
-
 func toLockFileName(backupFileName string) string {
 	return backupFileName + ".lock"
 }
diff --git a/internal/cmd/backup_test.go b/internal/cmd/backup_test.go
index 7ad02163..08148f7b 100644
--- a/internal/cmd/backup_test.go
+++ b/internal/cmd/backup_test.go
@@ -780,3 +780,426 @@ func (m *mockClientForBackup) ExportBulkRelationships(_ context.Context, req *v1
 func (m *mockClientForBackup) assertAllRecvCalls() {
 	require.Equal(m.t, len(m.recvCalls), m.recvCallIndex, "the number of provided recvCalls should match the number of invocations")
 }
+
+// mockProgressTracker is a test implementation of ProgressTracker.
+type mockProgressTracker struct {
+	cursor          *v1.Cursor
+	writtenCursors  []*v1.Cursor
+	markCompleteErr error
+	writeErr        error
+	completed       bool
+	closed          bool
+}
+
+func (m *mockProgressTracker) GetCursor() *v1.Cursor {
+	return m.cursor
+}
+
+func (m *mockProgressTracker) WriteCursor(cursor *v1.Cursor) error {
+	if m.writeErr != nil {
+		return m.writeErr
+	}
+	m.writtenCursors = append(m.writtenCursors, cursor)
+	return nil
+}
+
+func (m *mockProgressTracker) MarkComplete() error {
+	m.completed = true
+	return m.markCompleteErr
+}
+
+func (m *mockProgressTracker) Close() error {
+	m.closed = true
+	return nil
+}
+
+func TestBackupCreateImpl(t *testing.T) {
+	t.Parallel()
+
+	testRels := []*v1.Relationship{
+		{
+			Resource: &v1.ObjectReference{ObjectType: "test/resource", ObjectId: "1"},
+			Relation: "reader",
+			Subject:  &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: "test/user", ObjectId: "1"}},
+		},
+		{
+			Resource: &v1.ObjectReference{ObjectType: "test/resource", ObjectId: "2"},
+			Relation: "reader",
+			Subject:  &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: "test/user", ObjectId: "2"}},
+		},
+	}
+
+	t.Run("successful backup with relationships", func(t *testing.T) {
+		t.Parallel()
+
+		cursor := &v1.Cursor{Token: "after-cursor"}
+		mockClient := &mockClientForBackup{
+			t: t,
+			recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
+				func() (*v1.ExportBulkRelationshipsResponse, error) {
+					return &v1.ExportBulkRelationshipsResponse{
+						Relationships:     testRels,
+						AfterResultCursor: cursor,
+					}, nil
+				},
+			},
+		}
+
+		progressTracker := &mockProgressTracker{}
+		zedToken := &v1.ZedToken{Token: "test-token"}
+
+		// Create a temp file for the encoder
+		tmpFile, err := os.CreateTemp(t.TempDir(), "backup-test")
+		require.NoError(t, err)
+		defer func() { _ = tmpFile.Close() }()
+
+		encoder, err := backupformat.NewEncoder(tmpFile, testSchema, zedToken)
+		require.NoError(t, err)
+		t.Cleanup(func() { _ = encoder.Close() })
+
+		config := BackupConfig{PrefixFilter: "test"}
+
+		completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, zedToken)
+
+		require.NoError(t, err)
+		require.True(t, completed, "backup should be marked as completed")
+		require.Len(t, progressTracker.writtenCursors, 1, "should have written one cursor")
+		require.Equal(t, cursor.Token, progressTracker.writtenCursors[0].Token)
+
+		mockClient.assertAllRecvCalls()
+	})
+
+	t.Run("returns error when both zedToken and cursor are nil", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockClientForBackup{t: t}
+		progressTracker := &mockProgressTracker{cursor: nil}
+
+		// Create a temp file for the encoder
+		tmpFile, err := os.CreateTemp(t.TempDir(), "backup-test")
+		require.NoError(t, err)
+		defer func() { _ = tmpFile.Close() }()
+
+		encoder, err := backupformat.NewEncoder(tmpFile, testSchema, &v1.ZedToken{Token: "dummy"})
+		require.NoError(t, err)
+		t.Cleanup(func() { _ = encoder.Close() })
+
+		config := BackupConfig{}
+
+		completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, nil)
+
+		require.Error(t, err)
+		require.False(t, completed)
+		require.Contains(t, err.Error(), "malformed existing backup")
+	})
+
+	t.Run("handles permission denied error", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockClientForBackup{
+			t: t,
+			recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
+				func() (*v1.ExportBulkRelationshipsResponse, error) {
+					return nil, status.Error(codes.PermissionDenied, "unauthorized")
+				},
+			},
+		}
+
+		progressTracker := &mockProgressTracker{}
+		zedToken := &v1.ZedToken{Token: "test-token"}
+
+		// Create a temp file for the encoder
+		tmpFile, err := os.CreateTemp(t.TempDir(), "backup-test")
+		require.NoError(t, err)
+		defer func() { _ = tmpFile.Close() }()
+
+		encoder, err := backupformat.NewEncoder(tmpFile, testSchema, zedToken)
+		require.NoError(t, err)
+		t.Cleanup(func() { _ = encoder.Close() })
+
+		config := BackupConfig{PrefixFilter: "test"}
+
+		completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, zedToken)
+
+		require.Error(t, err)
+		require.False(t, completed)
+		require.Contains(t, err.Error(), "PermissionDenied")
+	})
+
+	t.Run("filters relationships by prefix", func(t *testing.T) {
+		t.Parallel()
+
+		mixedRels := []*v1.Relationship{
+			{
+				Resource: &v1.ObjectReference{ObjectType: "test/resource", ObjectId: "1"},
+				Relation: "reader",
+				Subject:  &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: "test/user", ObjectId: "1"}},
+			},
+			{
+				Resource: &v1.ObjectReference{ObjectType: "other/resource", ObjectId: "1"},
+				Relation: "reader",
+				Subject:  &v1.SubjectReference{Object: &v1.ObjectReference{ObjectType: "other/user", ObjectId: "1"}},
+			},
+		}
+
+		mockClient := &mockClientForBackup{
+			t: t,
+			recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
+				func() (*v1.ExportBulkRelationshipsResponse, error) {
+					return &v1.ExportBulkRelationshipsResponse{
+						Relationships:     mixedRels,
+						AfterResultCursor: &v1.Cursor{Token: "cursor"},
+					}, nil
+				},
+			},
+		}
+
+		progressTracker := &mockProgressTracker{}
+		zedToken := &v1.ZedToken{Token: "test-token"}
+
+		// Create a temp file for the encoder
+		tmpFile, err := os.CreateTemp(t.TempDir(), "backup-test")
+		require.NoError(t, err)
+		defer func() { _ = tmpFile.Close() }()
+
+		encoder, err := backupformat.NewEncoder(tmpFile, testSchema, zedToken)
+		require.NoError(t, err)
+		t.Cleanup(func() { _ = encoder.Close() })
+
+		// Only include relationships with "test" prefix
+		config := BackupConfig{PrefixFilter: "test"}
+
+		completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, zedToken)
+
+		require.NoError(t, err)
+		require.True(t, completed)
+	})
+
+	t.Run("resumes from cursor", func(t *testing.T) {
+		t.Parallel()
+
+		resumeCursor := &v1.Cursor{Token: "resume-from-here"}
+
+		mockClient := &mockClientForBackup{
+			t: t,
+			recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
+				func() (*v1.ExportBulkRelationshipsResponse, error) {
+					return &v1.ExportBulkRelationshipsResponse{
+						Relationships:     testRels,
+						AfterResultCursor: &v1.Cursor{Token: "new-cursor"},
+					}, nil
+				},
+			},
+			exportCalls: []func(t *testing.T, req *v1.ExportBulkRelationshipsRequest){
+				func(t *testing.T, req *v1.ExportBulkRelationshipsRequest) {
+					require.NotNil(t, req.OptionalCursor)
+					require.Equal(t, resumeCursor.Token, req.OptionalCursor.Token)
+					// When resuming from cursor, consistency should not be set
+					require.Nil(t, req.Consistency)
+				},
+			},
+		}
+
+		progressTracker := &mockProgressTracker{cursor: resumeCursor}
+
+		// Create a temp file for the encoder
+		tmpFile, err := os.CreateTemp(t.TempDir(), "backup-test")
+		require.NoError(t, err)
+		defer func() { _ = tmpFile.Close() }()
+
+		encoder, err := backupformat.NewEncoder(tmpFile, testSchema, &v1.ZedToken{Token: "dummy"})
+		require.NoError(t, err)
+		t.Cleanup(func() { _ = encoder.Close() })
+
+		config := BackupConfig{PrefixFilter: "test"}
+
+		// Pass nil zedToken to simulate resume scenario
+		completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, nil)
+
+		require.NoError(t, err)
+		require.True(t, completed)
+	})
+
+	t.Run("handles context cancellation", func(t *testing.T) {
+		t.Parallel()
+
+		mockClient := &mockClientForBackup{
+			t: t,
+			recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
+				func() (*v1.ExportBulkRelationshipsResponse, error) {
+					return nil, context.Canceled
+				},
+			},
+		}
+
+		progressTracker := &mockProgressTracker{}
+		zedToken := &v1.ZedToken{Token: "test-token"}
+
+		// Create a temp file for the encoder
+		tmpFile, err := os.CreateTemp(t.TempDir(), "backup-test")
+		require.NoError(t, err)
+		defer func() { _ = tmpFile.Close() }()
+
+		encoder, err := backupformat.NewEncoder(tmpFile, testSchema, zedToken)
+		require.NoError(t, err)
+		t.Cleanup(func() { _ = encoder.Close() })
+
+		config := BackupConfig{PrefixFilter: "test"}
+
+		completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, zedToken)
+
+		require.Error(t, err)
+		require.False(t, completed)
+		require.ErrorIs(t, err, context.Canceled)
+	})
+
+	t.Run("handles WriteCursor error", func(t *testing.T) {
+		t.Parallel()
+
+		writeErr := errors.New("disk full")
+		mockClient := &mockClientForBackup{
+			t: t,
+			recvCalls: []func() (*v1.ExportBulkRelationshipsResponse, error){
+				func() (*v1.ExportBulkRelationshipsResponse, error) {
+					return &v1.ExportBulkRelationshipsResponse{
+						Relationships:     testRels,
+						AfterResultCursor: &v1.Cursor{Token: "cursor"},
+					}, nil
+				},
+			},
+		}
+
+		progressTracker := &mockProgressTracker{writeErr: writeErr}
+		zedToken := &v1.ZedToken{Token: "test-token"}
+
+		// Create a temp file for the encoder
+		tmpFile, err := os.CreateTemp(t.TempDir(), "backup-test")
+		require.NoError(t, err)
+		defer func() { _ = tmpFile.Close() }()
+
+		encoder, err := backupformat.NewEncoder(tmpFile, testSchema, zedToken)
+		require.NoError(t, err)
+		t.Cleanup(func() { _ = encoder.Close() })
+
+		config := BackupConfig{PrefixFilter: "test"}
+
+		completed, err := backupCreateImpl(t.Context(), mockClient, encoder, progressTracker, config, zedToken)
+
+		require.Error(t, err)
+		require.False(t, completed)
+		require.ErrorIs(t, err, writeErr)
+	})
+
+	t.Run("mock MarkComplete returns error when configured", func(t *testing.T) {
+		t.Parallel()
+
+		// This test verifies the mockProgressTracker correctly returns markCompleteErr.
+		// Note: backupCreateImpl does NOT call MarkComplete - that's done by the
+		// caller (backupCreateCmdFunc). This test ensures the mock works correctly
+		// for integration testing scenarios.
+		markCompleteErr := errors.New("failed to remove progress file")
+		progressTracker := &mockProgressTracker{markCompleteErr: markCompleteErr}
+
+		err := progressTracker.MarkComplete()
+		require.Error(t, err)
+		require.ErrorIs(t, err, markCompleteErr)
+		require.True(t, progressTracker.completed, "completed flag should be set even on error")
+	})
+
+	t.Run("verifies progressTracker.closed is set on Close", func(t *testing.T) {
+		t.Parallel()
+
+		progressTracker := &mockProgressTracker{}
+		require.False(t, progressTracker.closed)
+
+		err := progressTracker.Close()
+		require.NoError(t, err)
+		require.True(t, progressTracker.closed)
+	})
+}
+
+func TestProgressTracker(t *testing.T) {
+	t.Parallel()
+
+	t.Run("newFileProgressTracker creates new file when backup doesn't exist", func(t *testing.T) {
+		t.Parallel()
+
+		backupFile := filepath.Join(t.TempDir(), "test-backup.zedbackup")
+
+		tracker, err := newFileProgressTracker(backupFile, false)
+		require.NoError(t, err)
+		defer func() { _ = tracker.Close() }()
+
+		require.Nil(t, tracker.GetCursor())
+		require.FileExists(t, toLockFileName(backupFile))
+	})
+
+	t.Run("newFileProgressTracker errors when backup exists without progress file", func(t *testing.T) {
+		t.Parallel()
+
+		backupFile := filepath.Join(t.TempDir(), "test-backup.zedbackup")
+
+		_, err := newFileProgressTracker(backupFile, true)
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "already exists")
+	})
+
+	t.Run("newFileProgressTracker resumes from existing progress file", func(t *testing.T) {
+		t.Parallel()
+
+		tmpDir := t.TempDir()
+		backupFile := filepath.Join(tmpDir, "test-backup.zedbackup")
+		progressFile := toLockFileName(backupFile)
+
+		// Create progress file with cursor
+		err := os.WriteFile(progressFile, []byte("test-cursor-token"), 0o600)
+		require.NoError(t, err)
+
+		tracker, err := newFileProgressTracker(backupFile, true)
+		require.NoError(t, err)
+		defer func() { _ = tracker.Close() }()
+
+		cursor := tracker.GetCursor()
+		require.NotNil(t, cursor)
+		require.Equal(t, "test-cursor-token", cursor.Token)
+	})
+
+	t.Run("WriteCursor updates progress file", func(t *testing.T) {
+		t.Parallel()
+
+		backupFile := filepath.Join(t.TempDir(), "test-backup.zedbackup")
+
+		tracker, err := newFileProgressTracker(backupFile, false)
+		require.NoError(t, err)
+		defer func() { _ = tracker.Close() }()
+
+		cursor := &v1.Cursor{Token: "new-cursor-token"}
+		err = tracker.WriteCursor(cursor)
+		require.NoError(t, err)
+
+		// Verify file contents
+		contents, err := os.ReadFile(toLockFileName(backupFile))
+		require.NoError(t, err)
+		require.Equal(t, "new-cursor-token", string(contents))
+	})
+
+	t.Run("MarkComplete removes progress file", func(t *testing.T) {
+		t.Parallel()
+
+		backupFile := filepath.Join(t.TempDir(), "test-backup.zedbackup")
+
+		tracker, err := newFileProgressTracker(backupFile, false)
+		require.NoError(t, err)
+
+		progressFileName := toLockFileName(backupFile)
+		require.FileExists(t, progressFileName)
+
+		err = tracker.MarkComplete()
+		require.NoError(t, err)
+		require.NoFileExists(t, progressFileName)
+
+		// Close should be a no-op after MarkComplete (file already closed)
+		err = tracker.Close()
+		require.NoError(t, err)
+	})
+}
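The ProgressTracker interface introduced in this patch is what decouples backupCreateImpl from the file system; the file-backed fileProgressTracker is just one implementation of it. The sketch below is illustrative only and is not part of the change: it re-declares the same four-method contract under a hypothetical name and pairs it with a minimal in-memory tracker, to show how a caller drives the interface (persist the server cursor after each exported page, then mark the run complete when the export finishes).

package main

import (
	"fmt"

	v1 "github.com/authzed/authzed-go/proto/authzed/api/v1"
)

// progressTracker mirrors the ProgressTracker interface added in backup.go
// (names here are illustrative, not the package's exported identifiers).
type progressTracker interface {
	GetCursor() *v1.Cursor
	WriteCursor(cursor *v1.Cursor) error
	MarkComplete() error
	Close() error
}

// memoryProgressTracker is a hypothetical in-memory variant: it satisfies the
// contract but offers no crash-resumability, which is exactly what the
// file-backed implementation in the patch exists to provide.
type memoryProgressTracker struct {
	cursor    *v1.Cursor
	completed bool
}

func (m *memoryProgressTracker) GetCursor() *v1.Cursor          { return m.cursor }
func (m *memoryProgressTracker) WriteCursor(c *v1.Cursor) error { m.cursor = c; return nil }
func (m *memoryProgressTracker) MarkComplete() error            { m.completed = true; return nil }
func (m *memoryProgressTracker) Close() error                   { return nil }

func main() {
	var tracker progressTracker = &memoryProgressTracker{}

	// After each exported page, record the server-provided cursor so a later
	// run could resume from it.
	_ = tracker.WriteCursor(&v1.Cursor{Token: "example-page-cursor"})
	fmt.Println("resume point:", tracker.GetCursor().Token)

	// Once the export finishes, mark the run complete and release resources.
	_ = tracker.MarkComplete()
	_ = tracker.Close()
}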