diff --git a/init.go b/init.go index 807d201d1c..2d0e1704eb 100644 --- a/init.go +++ b/init.go @@ -9,11 +9,22 @@ import ( "github.com/go-task/task/v3/taskfile" ) -const defaultFilename = "Taskfile.yml" +var defaultFilename = "Taskfile.yml" + +// SetDefaultFilename sets the default filename for testing purposes. +func SetDefaultFilename(name string) { + defaultFilename = name +} //go:embed taskfile/templates/default.yml var DefaultTaskfile string +func init() { + if name := os.Getenv("TASKFILE_DEFAULT_NAME"); name != "" { + defaultFilename = name + } +} + // InitTaskfile creates a new Taskfile at path. // // path can be either a file path or a directory path. diff --git a/init_test.go b/init_test.go index 41095d65e1..748e85ae79 100644 --- a/init_test.go +++ b/init_test.go @@ -30,6 +30,37 @@ func TestInitDir(t *testing.T) { _ = os.Remove(file) } +func TestInitDirWithCustomDefaultName(t *testing.T) { + const dir = "testdata/init" + + // Set environment variable before running the test + t.Setenv("TASKFILE_DEFAULT_NAME", "Taskfile.yaml") + + file := filepathext.SmartJoin(dir, "Taskfile.yaml") + defaultFile := filepathext.SmartJoin(dir, "Taskfile.yml") + + // Clean up any existing files + _ = os.Remove(file) + _ = os.Remove(defaultFile) + if _, err := os.Stat(file); err == nil { + t.Errorf("Taskfile.yaml should not exist") + } + + // Manually call init logic + task.SetDefaultFilename("Taskfile.yaml") + defer task.SetDefaultFilename("Taskfile.yml") + + if _, err := task.InitTaskfile(dir); err != nil { + t.Error(err) + } + + if _, err := os.Stat(file); err != nil { + t.Errorf("Taskfile.yaml should exist") + } + + _ = os.Remove(file) +} + func TestInitFile(t *testing.T) { t.Parallel()