diff --git a/pkg/manager/remote.go b/pkg/manager/remote.go index 2dda681..acabcd0 100644 --- a/pkg/manager/remote.go +++ b/pkg/manager/remote.go @@ -10,7 +10,6 @@ import ( "llamactl/pkg/instance" "net/http" "net/url" - "strings" "sync" "time" ) @@ -82,38 +81,6 @@ func (rm *remoteManager) removeInstance(instanceName string) { // --- HTTP request helpers --- -// validateInstanceNameForURL ensures the instance name is safe for use in URLs. -func validateInstanceNameForURL(name string) (string, error) { - if name == "" { - return "", fmt.Errorf("instance name cannot be empty") - } - - // Check for path separators and parent directory references - // This prevents path traversal and SSRF attacks - if strings.Contains(name, "/") || strings.Contains(name, "\\") || strings.Contains(name, "..") { - return "", fmt.Errorf("invalid instance name: %s (cannot contain path separators or '..')", name) - } - - // Check for URL-unsafe characters that could be used for injection - // Reject names with special URL characters that could allow URL manipulation - unsafeChars := []string{"?", "&", "#", "%", "=", "@", ":", " "} - for _, char := range unsafeChars { - if strings.Contains(name, char) { - return "", fmt.Errorf("invalid instance name: %s (cannot contain URL-unsafe characters)", name) - } - } - - // Additional validation: use url.PathEscape to ensure the name doesn't change - // when URL-encoded (indicating it contains characters that need encoding) - // This catches any other characters that could cause issues in URLs - escaped := url.PathEscape(name) - if escaped != name { - return "", fmt.Errorf("invalid instance name: %s (contains characters requiring URL encoding)", name) - } - - return name, nil -} - // makeRemoteRequest creates and executes an HTTP request to a remote node with context support. func (rm *remoteManager) makeRemoteRequest(ctx context.Context, nodeConfig *config.NodeConfig, method, path string, body any) (*http.Response, error) { var reqBody io.Reader @@ -173,12 +140,11 @@ func parseRemoteResponse(resp *http.Response, result any) error { // CreateInstance creates a new instance on a remote node. func (rm *remoteManager) createInstance(ctx context.Context, node *config.NodeConfig, name string, opts *instance.Options) (*instance.Instance, error) { - validatedName, err := validateInstanceNameForURL(name) - if err != nil { - return nil, err - } + // URL-encode the instance name to safely include it in the URL path + // This prevents SSRF and URL injection attacks + escapedName := url.PathEscape(name) - path := fmt.Sprintf("%s%s/", apiBasePath, validatedName) + path := fmt.Sprintf("%s%s/", apiBasePath, escapedName) resp, err := rm.makeRemoteRequest(ctx, node, "POST", path, opts) if err != nil { @@ -195,12 +161,10 @@ func (rm *remoteManager) createInstance(ctx context.Context, node *config.NodeCo // GetInstance retrieves an instance by name from a remote node. func (rm *remoteManager) getInstance(ctx context.Context, node *config.NodeConfig, name string) (*instance.Instance, error) { - validatedName, err := validateInstanceNameForURL(name) - if err != nil { - return nil, err - } + // URL-encode the instance name to safely include it in the URL path + escapedName := url.PathEscape(name) - path := fmt.Sprintf("%s%s/", apiBasePath, validatedName) + path := fmt.Sprintf("%s%s/", apiBasePath, escapedName) resp, err := rm.makeRemoteRequest(ctx, node, "GET", path, nil) if err != nil { return nil, err @@ -216,12 +180,10 @@ func (rm *remoteManager) getInstance(ctx context.Context, node *config.NodeConfi // UpdateInstance updates an existing instance on a remote node. func (rm *remoteManager) updateInstance(ctx context.Context, node *config.NodeConfig, name string, opts *instance.Options) (*instance.Instance, error) { - validatedName, err := validateInstanceNameForURL(name) - if err != nil { - return nil, err - } + // URL-encode the instance name to safely include it in the URL path + escapedName := url.PathEscape(name) - path := fmt.Sprintf("%s%s/", apiBasePath, validatedName) + path := fmt.Sprintf("%s%s/", apiBasePath, escapedName) resp, err := rm.makeRemoteRequest(ctx, node, "PUT", path, opts) if err != nil { @@ -238,12 +200,10 @@ func (rm *remoteManager) updateInstance(ctx context.Context, node *config.NodeCo // DeleteInstance deletes an instance from a remote node. func (rm *remoteManager) deleteInstance(ctx context.Context, node *config.NodeConfig, name string) error { - validatedName, err := validateInstanceNameForURL(name) - if err != nil { - return err - } + // URL-encode the instance name to safely include it in the URL path + escapedName := url.PathEscape(name) - path := fmt.Sprintf("%s%s/", apiBasePath, validatedName) + path := fmt.Sprintf("%s%s/", apiBasePath, escapedName) resp, err := rm.makeRemoteRequest(ctx, node, "DELETE", path, nil) if err != nil { return err @@ -254,12 +214,10 @@ func (rm *remoteManager) deleteInstance(ctx context.Context, node *config.NodeCo // StartInstance starts an instance on a remote node. func (rm *remoteManager) startInstance(ctx context.Context, node *config.NodeConfig, name string) (*instance.Instance, error) { - validatedName, err := validateInstanceNameForURL(name) - if err != nil { - return nil, err - } + // URL-encode the instance name to safely include it in the URL path + escapedName := url.PathEscape(name) - path := fmt.Sprintf("%s%s/start", apiBasePath, validatedName) + path := fmt.Sprintf("%s%s/start", apiBasePath, escapedName) resp, err := rm.makeRemoteRequest(ctx, node, "POST", path, nil) if err != nil { return nil, err @@ -275,12 +233,10 @@ func (rm *remoteManager) startInstance(ctx context.Context, node *config.NodeCon // StopInstance stops an instance on a remote node. func (rm *remoteManager) stopInstance(ctx context.Context, node *config.NodeConfig, name string) (*instance.Instance, error) { - validatedName, err := validateInstanceNameForURL(name) - if err != nil { - return nil, err - } + // URL-encode the instance name to safely include it in the URL path + escapedName := url.PathEscape(name) - path := fmt.Sprintf("%s%s/stop", apiBasePath, validatedName) + path := fmt.Sprintf("%s%s/stop", apiBasePath, escapedName) resp, err := rm.makeRemoteRequest(ctx, node, "POST", path, nil) if err != nil { return nil, err @@ -296,12 +252,10 @@ func (rm *remoteManager) stopInstance(ctx context.Context, node *config.NodeConf // RestartInstance restarts an instance on a remote node. func (rm *remoteManager) restartInstance(ctx context.Context, node *config.NodeConfig, name string) (*instance.Instance, error) { - validatedName, err := validateInstanceNameForURL(name) - if err != nil { - return nil, err - } + // URL-encode the instance name to safely include it in the URL path + escapedName := url.PathEscape(name) - path := fmt.Sprintf("%s%s/restart", apiBasePath, validatedName) + path := fmt.Sprintf("%s%s/restart", apiBasePath, escapedName) resp, err := rm.makeRemoteRequest(ctx, node, "POST", path, nil) if err != nil { return nil, err @@ -317,12 +271,10 @@ func (rm *remoteManager) restartInstance(ctx context.Context, node *config.NodeC // GetInstanceLogs retrieves logs for an instance from a remote node. func (rm *remoteManager) getInstanceLogs(ctx context.Context, node *config.NodeConfig, name string, numLines int) (string, error) { - validatedName, err := validateInstanceNameForURL(name) - if err != nil { - return "", err - } + // URL-encode the instance name to safely include it in the URL path + escapedName := url.PathEscape(name) - path := fmt.Sprintf("%s%s/logs?lines=%d", apiBasePath, validatedName, numLines) + path := fmt.Sprintf("%s%s/logs?lines=%d", apiBasePath, escapedName, numLines) resp, err := rm.makeRemoteRequest(ctx, node, "GET", path, nil) if err != nil { return "", err diff --git a/pkg/manager/remote_test.go b/pkg/manager/remote_test.go deleted file mode 100644 index f628921..0000000 --- a/pkg/manager/remote_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package manager - -import ( - "testing" -) - -// TestValidateInstanceNameForURL tests the URL validation function -func TestValidateInstanceNameForURL(t *testing.T) { - tests := []struct { - name string - input string - shouldError bool - }{ - // Valid names - {"valid simple name", "test-instance", false}, - {"valid with underscores", "my_instance", false}, - {"valid with numbers", "instance123", false}, - {"valid with dashes", "test-name-with-dashes", false}, - - // Invalid names - path traversal - {"path traversal with ..", "../../etc/passwd", true}, - {"path traversal multiple", "../../../etc/shadow", true}, - {"path traversal in middle", "foo/../bar", true}, - {"double dots variation", ".../", true}, - - // Invalid names - path separators - {"forward slash", "foo/bar", true}, - {"backslash", "foo\\bar", true}, - {"absolute path", "/etc/passwd", true}, - - // Invalid names - URL-unsafe characters - {"question mark", "test?param=value", true}, - {"ampersand", "test¶m", true}, - {"hash", "test#anchor", true}, - {"percent", "test%20space", true}, - {"equals", "test=value", true}, - {"at sign", "test@example", true}, - {"colon", "test:8080", true}, - {"space", "test instance", true}, - - // Invalid names - empty - {"empty string", "", true}, - - // Invalid names - characters requiring encoding - {"unicode", "test\u00e9", true}, - {"newline", "test\n", true}, - {"tab", "test\t", true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - validatedName, err := validateInstanceNameForURL(tt.input) - - if tt.shouldError { - if err == nil { - t.Errorf("Expected error for input %q, but got none", tt.input) - } - } else { - if err != nil { - t.Errorf("Expected no error for input %q, but got: %v", tt.input, err) - } - if validatedName != tt.input { - t.Errorf("Expected validated name to be %q, but got %q", tt.input, validatedName) - } - } - }) - } -}