diff --git a/internal/db/test/test.go b/internal/db/test/test.go index 2852ee943..dc5273dc3 100644 --- a/internal/db/test/test.go +++ b/internal/db/test/test.go @@ -7,6 +7,8 @@ import ( "os" "path" "path/filepath" + "regexp" + "strings" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/network" @@ -25,31 +27,43 @@ const ( DISABLE_PGTAP = "drop extension if exists pgtap" ) +var irPattern = regexp.MustCompile(`(?im)^\s*\\ir\s+['"]?([^'"\s]+)['"]?`) + func Run(ctx context.Context, testFiles []string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { // Build test command if len(testFiles) == 0 { - absTestsDir, err := filepath.Abs(utils.DbTestsDir) - if err != nil { - return errors.Errorf("failed to resolve tests dir: %w", err) - } - testFiles = append(testFiles, absTestsDir) + testFiles = append(testFiles, utils.DbTestsDir) + } + allFiles, err := traverseImports(testFiles, fsys) + if err != nil { + return err } - binds := make([]string, len(testFiles)) + testFileSet := make(map[string]struct{}, len(testFiles)) + for _, tf := range testFiles { + testFileSet[tf] = struct{}{} + } + binds := make([]string, len(allFiles)) cmd := []string{"pg_prove", "--ext", ".pg", "--ext", ".sql", "-r"} var workingDir string - for i, fp := range testFiles { + for i, fp := range allFiles { if !filepath.IsAbs(fp) { fp = filepath.Join(utils.CurrentDirAbs, fp) } dockerPath := utils.ToDockerPath(fp) - cmd = append(cmd, dockerPath) - binds[i] = fmt.Sprintf("%s:%s:ro", fp, dockerPath) if workingDir == "" { workingDir = dockerPath if path.Ext(dockerPath) != "" { workingDir = path.Dir(dockerPath) } } + if _, isTestFile := testFileSet[allFiles[i]]; isTestFile { + relPath := dockerPath + if path.Ext(dockerPath) != "" && path.Dir(dockerPath) == workingDir { + relPath = path.Base(dockerPath) + } + cmd = append(cmd, relPath) + } + binds[i] = fmt.Sprintf("%s:%s:ro", fp, dockerPath) } if viper.GetBool("DEBUG") { cmd = append(cmd, "--verbose") @@ -107,3 +121,38 @@ func Run(ctx context.Context, testFiles []string, config pgconn.Config, fsys afe os.Stderr, ) } + +func traverseImports(testFiles []string, fsys afero.Fs) ([]string, error) { + seen := map[string]struct{}{} + q := append([]string{}, testFiles...) + result := []string{} + for len(q) > 0 { + curr := q[len(q)-1] + q = q[:len(q)-1] + if _, ok := seen[curr]; ok { + continue + } + seen[curr] = struct{}{} + result = append(result, curr) + info, err := fsys.Stat(curr) + if err != nil { + return nil, errors.Errorf("failed to stat %s: %w", curr, err) + } + if info.IsDir() { + continue + } + data, err := afero.ReadFile(fsys, curr) + if err != nil { + return nil, errors.Errorf("failed to read %s: %w", curr, err) + } + for _, m := range irPattern.FindAllStringSubmatch(string(data), -1) { + if len(m) < 2 { + continue + } + importPath := strings.TrimSpace(m[1]) + resolved := filepath.Join(filepath.Dir(curr), importPath) + q = append(q, resolved) + } + } + return result, nil +} diff --git a/internal/db/test/test_test.go b/internal/db/test/test_test.go index 063f9bcd1..f82f844bf 100644 --- a/internal/db/test/test_test.go +++ b/internal/db/test/test_test.go @@ -30,6 +30,7 @@ func TestRunCommand(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() require.NoError(t, utils.WriteConfig(fsys, false)) + require.NoError(t, afero.WriteFile(fsys, "nested", []byte("SELECT 1;"), 0644)) // Setup mock postgres conn := pgtest.NewConn() defer conn.Close(t) @@ -53,6 +54,7 @@ func TestRunCommand(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() require.NoError(t, utils.WriteConfig(fsys, false)) + require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755)) // Run test err := Run(context.Background(), nil, dbConfig, fsys) // Check error @@ -63,6 +65,7 @@ func TestRunCommand(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() require.NoError(t, utils.WriteConfig(fsys, false)) + require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755)) // Setup mock postgres conn := pgtest.NewConn() defer conn.Close(t) @@ -79,6 +82,7 @@ func TestRunCommand(t *testing.T) { // Setup in-memory fs fsys := afero.NewMemMapFs() require.NoError(t, utils.WriteConfig(fsys, false)) + require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755)) // Setup mock postgres conn := pgtest.NewConn() defer conn.Close(t) @@ -99,3 +103,39 @@ func TestRunCommand(t *testing.T) { assert.Empty(t, apitest.ListUnmatchedRequests()) }) } + +func TestTraverseImports(t *testing.T) { + t.Run("handles file with \\ir import", func(t *testing.T) { + fsys := afero.NewMemMapFs() + require.NoError(t, afero.WriteFile(fsys, "main.sql", []byte("\\ir helper.sql"), 0644)) + require.NoError(t, afero.WriteFile(fsys, "helper.sql", []byte("SELECT 1;"), 0644)) + + result, err := traverseImports([]string{"main.sql"}, fsys) + + assert.NoError(t, err) + assert.Len(t, result, 2) + }) + + t.Run("handles nested \\ir imports", func(t *testing.T) { + fsys := afero.NewMemMapFs() + require.NoError(t, afero.WriteFile(fsys, "main.sql", []byte("\\ir level1.sql"), 0644)) + require.NoError(t, afero.WriteFile(fsys, "level1.sql", []byte("\\ir level2.sql"), 0644)) + require.NoError(t, afero.WriteFile(fsys, "level2.sql", []byte("SELECT 1;"), 0644)) + + result, err := traverseImports([]string{"main.sql"}, fsys) + + assert.NoError(t, err) + assert.Len(t, result, 3) + }) + + t.Run("handles circular imports", func(t *testing.T) { + fsys := afero.NewMemMapFs() + require.NoError(t, afero.WriteFile(fsys, "a.sql", []byte("\\ir b.sql"), 0644)) + require.NoError(t, afero.WriteFile(fsys, "b.sql", []byte("\\ir a.sql"), 0644)) + + result, err := traverseImports([]string{"a.sql"}, fsys) + + assert.NoError(t, err) + assert.Len(t, result, 2) + }) +}