Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions internal/db/test/test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"os"
"path"
"path/filepath"
"regexp"
"strings"

"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/network"
Expand All @@ -25,31 +27,43 @@ const (
DISABLE_PGTAP = "drop extension if exists pgtap"
)

var irPattern = regexp.MustCompile(`(?im)^\s*\\ir\s+['"]?([^'"\s]+)['"]?`)

func Run(ctx context.Context, testFiles []string, config pgconn.Config, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
// Build test command
if len(testFiles) == 0 {
absTestsDir, err := filepath.Abs(utils.DbTestsDir)
if err != nil {
return errors.Errorf("failed to resolve tests dir: %w", err)
}
testFiles = append(testFiles, absTestsDir)
testFiles = append(testFiles, utils.DbTestsDir)
}
allFiles, err := traverseImports(testFiles, fsys)
if err != nil {
return err
}
binds := make([]string, len(testFiles))
testFileSet := make(map[string]struct{}, len(testFiles))
for _, tf := range testFiles {
testFileSet[tf] = struct{}{}
}
binds := make([]string, len(allFiles))
cmd := []string{"pg_prove", "--ext", ".pg", "--ext", ".sql", "-r"}
var workingDir string
for i, fp := range testFiles {
for i, fp := range allFiles {
if !filepath.IsAbs(fp) {
fp = filepath.Join(utils.CurrentDirAbs, fp)
}
dockerPath := utils.ToDockerPath(fp)
cmd = append(cmd, dockerPath)
binds[i] = fmt.Sprintf("%s:%s:ro", fp, dockerPath)
if workingDir == "" {
workingDir = dockerPath
if path.Ext(dockerPath) != "" {
workingDir = path.Dir(dockerPath)
}
}
if _, isTestFile := testFileSet[allFiles[i]]; isTestFile {
relPath := dockerPath
if path.Ext(dockerPath) != "" && path.Dir(dockerPath) == workingDir {
relPath = path.Base(dockerPath)
}
cmd = append(cmd, relPath)
}
binds[i] = fmt.Sprintf("%s:%s:ro", fp, dockerPath)
}
if viper.GetBool("DEBUG") {
cmd = append(cmd, "--verbose")
Expand Down Expand Up @@ -107,3 +121,38 @@ func Run(ctx context.Context, testFiles []string, config pgconn.Config, fsys afe
os.Stderr,
)
}

func traverseImports(testFiles []string, fsys afero.Fs) ([]string, error) {
seen := map[string]struct{}{}
q := append([]string{}, testFiles...)
result := []string{}
for len(q) > 0 {
curr := q[len(q)-1]
q = q[:len(q)-1]
if _, ok := seen[curr]; ok {
continue
}
seen[curr] = struct{}{}
result = append(result, curr)
info, err := fsys.Stat(curr)
if err != nil {
return nil, errors.Errorf("failed to stat %s: %w", curr, err)
}
if info.IsDir() {
continue
}
data, err := afero.ReadFile(fsys, curr)
if err != nil {
return nil, errors.Errorf("failed to read %s: %w", curr, err)
}
for _, m := range irPattern.FindAllStringSubmatch(string(data), -1) {
if len(m) < 2 {
continue
}
importPath := strings.TrimSpace(m[1])
resolved := filepath.Join(filepath.Dir(curr), importPath)
q = append(q, resolved)
}
}
return result, nil
}
40 changes: 40 additions & 0 deletions internal/db/test/test_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestRunCommand(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteConfig(fsys, false))
require.NoError(t, afero.WriteFile(fsys, "nested", []byte("SELECT 1;"), 0644))
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
Expand All @@ -53,6 +54,7 @@ func TestRunCommand(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteConfig(fsys, false))
require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755))
// Run test
err := Run(context.Background(), nil, dbConfig, fsys)
// Check error
Expand All @@ -63,6 +65,7 @@ func TestRunCommand(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteConfig(fsys, false))
require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755))
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
Expand All @@ -79,6 +82,7 @@ func TestRunCommand(t *testing.T) {
// Setup in-memory fs
fsys := afero.NewMemMapFs()
require.NoError(t, utils.WriteConfig(fsys, false))
require.NoError(t, fsys.MkdirAll(utils.DbTestsDir, 0755))
// Setup mock postgres
conn := pgtest.NewConn()
defer conn.Close(t)
Expand All @@ -99,3 +103,39 @@ func TestRunCommand(t *testing.T) {
assert.Empty(t, apitest.ListUnmatchedRequests())
})
}

func TestTraverseImports(t *testing.T) {
t.Run("handles file with \\ir import", func(t *testing.T) {
fsys := afero.NewMemMapFs()
require.NoError(t, afero.WriteFile(fsys, "main.sql", []byte("\\ir helper.sql"), 0644))
require.NoError(t, afero.WriteFile(fsys, "helper.sql", []byte("SELECT 1;"), 0644))

result, err := traverseImports([]string{"main.sql"}, fsys)

assert.NoError(t, err)
assert.Len(t, result, 2)
})

t.Run("handles nested \\ir imports", func(t *testing.T) {
fsys := afero.NewMemMapFs()
require.NoError(t, afero.WriteFile(fsys, "main.sql", []byte("\\ir level1.sql"), 0644))
require.NoError(t, afero.WriteFile(fsys, "level1.sql", []byte("\\ir level2.sql"), 0644))
require.NoError(t, afero.WriteFile(fsys, "level2.sql", []byte("SELECT 1;"), 0644))

result, err := traverseImports([]string{"main.sql"}, fsys)

assert.NoError(t, err)
assert.Len(t, result, 3)
})

t.Run("handles circular imports", func(t *testing.T) {
fsys := afero.NewMemMapFs()
require.NoError(t, afero.WriteFile(fsys, "a.sql", []byte("\\ir b.sql"), 0644))
require.NoError(t, afero.WriteFile(fsys, "b.sql", []byte("\\ir a.sql"), 0644))

result, err := traverseImports([]string{"a.sql"}, fsys)

assert.NoError(t, err)
assert.Len(t, result, 2)
})
}
Loading