diff --git a/pkg/manager/remote.go b/pkg/manager/remote.go index f149fa9..2dda681 100644 --- a/pkg/manager/remote.go +++ b/pkg/manager/remote.go @@ -9,6 +9,8 @@ import ( "llamactl/pkg/config" "llamactl/pkg/instance" "net/http" + "net/url" + "strings" "sync" "time" ) @@ -80,6 +82,38 @@ func (rm *remoteManager) removeInstance(instanceName string) { // --- HTTP request helpers --- +// validateInstanceNameForURL ensures the instance name is safe for use in URLs. +func validateInstanceNameForURL(name string) (string, error) { + if name == "" { + return "", fmt.Errorf("instance name cannot be empty") + } + + // Check for path separators and parent directory references + // This prevents path traversal and SSRF attacks + if strings.Contains(name, "/") || strings.Contains(name, "\\") || strings.Contains(name, "..") { + return "", fmt.Errorf("invalid instance name: %s (cannot contain path separators or '..')", name) + } + + // Check for URL-unsafe characters that could be used for injection + // Reject names with special URL characters that could allow URL manipulation + unsafeChars := []string{"?", "&", "#", "%", "=", "@", ":", " "} + for _, char := range unsafeChars { + if strings.Contains(name, char) { + return "", fmt.Errorf("invalid instance name: %s (cannot contain URL-unsafe characters)", name) + } + } + + // Additional validation: use url.PathEscape to ensure the name doesn't change + // when URL-encoded (indicating it contains characters that need encoding) + // This catches any other characters that could cause issues in URLs + escaped := url.PathEscape(name) + if escaped != name { + return "", fmt.Errorf("invalid instance name: %s (contains characters requiring URL encoding)", name) + } + + return name, nil +} + // makeRemoteRequest creates and executes an HTTP request to a remote node with context support. func (rm *remoteManager) makeRemoteRequest(ctx context.Context, nodeConfig *config.NodeConfig, method, path string, body any) (*http.Response, error) { var reqBody io.Reader @@ -139,7 +173,12 @@ func parseRemoteResponse(resp *http.Response, result any) error { // CreateInstance creates a new instance on a remote node. func (rm *remoteManager) createInstance(ctx context.Context, node *config.NodeConfig, name string, opts *instance.Options) (*instance.Instance, error) { - path := fmt.Sprintf("%s%s/", apiBasePath, name) + validatedName, err := validateInstanceNameForURL(name) + if err != nil { + return nil, err + } + + path := fmt.Sprintf("%s%s/", apiBasePath, validatedName) resp, err := rm.makeRemoteRequest(ctx, node, "POST", path, opts) if err != nil { @@ -156,7 +195,12 @@ func (rm *remoteManager) createInstance(ctx context.Context, node *config.NodeCo // GetInstance retrieves an instance by name from a remote node. func (rm *remoteManager) getInstance(ctx context.Context, node *config.NodeConfig, name string) (*instance.Instance, error) { - path := fmt.Sprintf("%s%s/", apiBasePath, name) + validatedName, err := validateInstanceNameForURL(name) + if err != nil { + return nil, err + } + + path := fmt.Sprintf("%s%s/", apiBasePath, validatedName) resp, err := rm.makeRemoteRequest(ctx, node, "GET", path, nil) if err != nil { return nil, err @@ -172,7 +216,12 @@ func (rm *remoteManager) getInstance(ctx context.Context, node *config.NodeConfi // UpdateInstance updates an existing instance on a remote node. func (rm *remoteManager) updateInstance(ctx context.Context, node *config.NodeConfig, name string, opts *instance.Options) (*instance.Instance, error) { - path := fmt.Sprintf("%s%s/", apiBasePath, name) + validatedName, err := validateInstanceNameForURL(name) + if err != nil { + return nil, err + } + + path := fmt.Sprintf("%s%s/", apiBasePath, validatedName) resp, err := rm.makeRemoteRequest(ctx, node, "PUT", path, opts) if err != nil { @@ -189,7 +238,12 @@ func (rm *remoteManager) updateInstance(ctx context.Context, node *config.NodeCo // DeleteInstance deletes an instance from a remote node. func (rm *remoteManager) deleteInstance(ctx context.Context, node *config.NodeConfig, name string) error { - path := fmt.Sprintf("%s%s/", apiBasePath, name) + validatedName, err := validateInstanceNameForURL(name) + if err != nil { + return err + } + + path := fmt.Sprintf("%s%s/", apiBasePath, validatedName) resp, err := rm.makeRemoteRequest(ctx, node, "DELETE", path, nil) if err != nil { return err @@ -200,7 +254,12 @@ func (rm *remoteManager) deleteInstance(ctx context.Context, node *config.NodeCo // StartInstance starts an instance on a remote node. func (rm *remoteManager) startInstance(ctx context.Context, node *config.NodeConfig, name string) (*instance.Instance, error) { - path := fmt.Sprintf("%s%s/start", apiBasePath, name) + validatedName, err := validateInstanceNameForURL(name) + if err != nil { + return nil, err + } + + path := fmt.Sprintf("%s%s/start", apiBasePath, validatedName) resp, err := rm.makeRemoteRequest(ctx, node, "POST", path, nil) if err != nil { return nil, err @@ -216,7 +275,12 @@ func (rm *remoteManager) startInstance(ctx context.Context, node *config.NodeCon // StopInstance stops an instance on a remote node. func (rm *remoteManager) stopInstance(ctx context.Context, node *config.NodeConfig, name string) (*instance.Instance, error) { - path := fmt.Sprintf("%s%s/stop", apiBasePath, name) + validatedName, err := validateInstanceNameForURL(name) + if err != nil { + return nil, err + } + + path := fmt.Sprintf("%s%s/stop", apiBasePath, validatedName) resp, err := rm.makeRemoteRequest(ctx, node, "POST", path, nil) if err != nil { return nil, err @@ -232,7 +296,12 @@ func (rm *remoteManager) stopInstance(ctx context.Context, node *config.NodeConf // RestartInstance restarts an instance on a remote node. func (rm *remoteManager) restartInstance(ctx context.Context, node *config.NodeConfig, name string) (*instance.Instance, error) { - path := fmt.Sprintf("%s%s/restart", apiBasePath, name) + validatedName, err := validateInstanceNameForURL(name) + if err != nil { + return nil, err + } + + path := fmt.Sprintf("%s%s/restart", apiBasePath, validatedName) resp, err := rm.makeRemoteRequest(ctx, node, "POST", path, nil) if err != nil { return nil, err @@ -248,7 +317,12 @@ func (rm *remoteManager) restartInstance(ctx context.Context, node *config.NodeC // GetInstanceLogs retrieves logs for an instance from a remote node. func (rm *remoteManager) getInstanceLogs(ctx context.Context, node *config.NodeConfig, name string, numLines int) (string, error) { - path := fmt.Sprintf("%s%s/logs?lines=%d", apiBasePath, name, numLines) + validatedName, err := validateInstanceNameForURL(name) + if err != nil { + return "", err + } + + path := fmt.Sprintf("%s%s/logs?lines=%d", apiBasePath, validatedName, numLines) resp, err := rm.makeRemoteRequest(ctx, node, "GET", path, nil) if err != nil { return "", err diff --git a/pkg/manager/remote_test.go b/pkg/manager/remote_test.go new file mode 100644 index 0000000..f628921 --- /dev/null +++ b/pkg/manager/remote_test.go @@ -0,0 +1,68 @@ +package manager + +import ( + "testing" +) + +// TestValidateInstanceNameForURL tests the URL validation function +func TestValidateInstanceNameForURL(t *testing.T) { + tests := []struct { + name string + input string + shouldError bool + }{ + // Valid names + {"valid simple name", "test-instance", false}, + {"valid with underscores", "my_instance", false}, + {"valid with numbers", "instance123", false}, + {"valid with dashes", "test-name-with-dashes", false}, + + // Invalid names - path traversal + {"path traversal with ..", "../../etc/passwd", true}, + {"path traversal multiple", "../../../etc/shadow", true}, + {"path traversal in middle", "foo/../bar", true}, + {"double dots variation", ".../", true}, + + // Invalid names - path separators + {"forward slash", "foo/bar", true}, + {"backslash", "foo\\bar", true}, + {"absolute path", "/etc/passwd", true}, + + // Invalid names - URL-unsafe characters + {"question mark", "test?param=value", true}, + {"ampersand", "test¶m", true}, + {"hash", "test#anchor", true}, + {"percent", "test%20space", true}, + {"equals", "test=value", true}, + {"at sign", "test@example", true}, + {"colon", "test:8080", true}, + {"space", "test instance", true}, + + // Invalid names - empty + {"empty string", "", true}, + + // Invalid names - characters requiring encoding + {"unicode", "test\u00e9", true}, + {"newline", "test\n", true}, + {"tab", "test\t", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validatedName, err := validateInstanceNameForURL(tt.input) + + if tt.shouldError { + if err == nil { + t.Errorf("Expected error for input %q, but got none", tt.input) + } + } else { + if err != nil { + t.Errorf("Expected no error for input %q, but got: %v", tt.input, err) + } + if validatedName != tt.input { + t.Errorf("Expected validated name to be %q, but got %q", tt.input, validatedName) + } + } + }) + } +}