diff --git a/pkg/tools/script_resolver.go b/pkg/tools/script_resolver.go index 0b1c3d7..5983e1b 100644 --- a/pkg/tools/script_resolver.go +++ b/pkg/tools/script_resolver.go @@ -5,11 +5,15 @@ import ( "fmt" "os" "path" + "strings" ) // ResolveScript is resolving the target script. func ResolveScript(dir, name string) (string, error) { - script := path.Join(dir, fmt.Sprintf("%s.sh", name)) + script := path.Clean(path.Join(dir, fmt.Sprintf("%s.sh", name))) + if !strings.HasPrefix(script, dir) { + return "", errors.New("Invalid script path: " + name) + } if _, err := os.Stat(script); os.IsNotExist(err) { return "", errors.New("Script not found: " + script) } diff --git a/pkg/tools_test/script_resolver_test.go b/pkg/tools_test/script_resolver_test.go index 3e7a543..0b5f07c 100644 --- a/pkg/tools_test/script_resolver_test.go +++ b/pkg/tools_test/script_resolver_test.go @@ -8,7 +8,7 @@ import ( ) func TestResolveScript(t *testing.T) { - script, err := tools.ResolveScript("../../scripts", "echo") + script, err := tools.ResolveScript("../../scripts", "../scripts/echo") assert.Nil(t, err, "") assert.Equal(t, "../../scripts/echo.sh", script, "") } @@ -18,3 +18,9 @@ func TestNotResolveScript(t *testing.T) { assert.NotNil(t, err, "") assert.Equal(t, "Script not found: ../../scripts/foo.sh", err.Error(), "") } + +func TestResolveBadScript(t *testing.T) { + _, err := tools.ResolveScript("../../scripts", "../tests/test_simple") + assert.NotNil(t, err, "") + assert.Equal(t, "Invalid script path: ../tests/test_simple", err.Error(), "") +}