diff --git a/backend/internal/api/handlers.go b/backend/internal/api/handlers.go index f59b71f..008e5de 100644 --- a/backend/internal/api/handlers.go +++ b/backend/internal/api/handlers.go @@ -16,7 +16,7 @@ func ListFiles(fs *filesystem.FileSystem) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { files, err := fs.ListFilesRecursively() if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, "Failed to list files", http.StatusInternalServerError) return } @@ -30,7 +30,7 @@ func GetFileContent(fs *filesystem.FileSystem) http.HandlerFunc { filePath := strings.TrimPrefix(r.URL.Path, "/api/v1/files/") content, err := fs.GetFileContent(filePath) if err != nil { - http.Error(w, err.Error(), http.StatusNotFound) + http.Error(w, "Failed to read file", http.StatusNotFound) return } diff --git a/backend/internal/filesystem/filesystem.go b/backend/internal/filesystem/filesystem.go index 2211918..fed5a98 100644 --- a/backend/internal/filesystem/filesystem.go +++ b/backend/internal/filesystem/filesystem.go @@ -1,8 +1,10 @@ package filesystem import ( + "errors" "os" "path/filepath" + "strings" ) type FileSystem struct { @@ -19,6 +21,27 @@ func New(rootDir string) *FileSystem { return &FileSystem{RootDir: rootDir} } +// validatePath checks if the given path is within the root directory +func (fs *FileSystem) validatePath(path string) (string, error) { + fullPath := filepath.Join(fs.RootDir, path) + cleanPath := filepath.Clean(fullPath) + + if !strings.HasPrefix(cleanPath, fs.RootDir) { + return "", errors.New("invalid path: outside of root directory") + } + + relPath, err := filepath.Rel(fs.RootDir, cleanPath) + if err != nil { + return "", err + } + + if strings.HasPrefix(relPath, "..") { + return "", errors.New("invalid path: outside of root directory") + } + + return cleanPath, nil +} + func (fs *FileSystem) ListFilesRecursively() ([]FileNode, error) { return fs.walkDirectory(fs.RootDir) } @@ -55,14 +78,20 @@ func (fs *FileSystem) walkDirectory(dir string) ([]FileNode, error) { } func (fs *FileSystem) GetFileContent(filePath string) ([]byte, error) { - fullPath := filepath.Join(fs.RootDir, filePath) + fullPath, err := fs.validatePath(filePath) + if err != nil { + return nil, err + } return os.ReadFile(fullPath) } func (fs *FileSystem) SaveFile(filePath string, content []byte) error { - fullPath := filepath.Join(fs.RootDir, filePath) - dir := filepath.Dir(fullPath) + fullPath, err := fs.validatePath(filePath) + if err != nil { + return err + } + dir := filepath.Dir(fullPath) if err := os.MkdirAll(dir, 0755); err != nil { return err } @@ -71,6 +100,9 @@ func (fs *FileSystem) SaveFile(filePath string, content []byte) error { } func (fs *FileSystem) DeleteFile(filePath string) error { - fullPath := filepath.Join(fs.RootDir, filePath) + fullPath, err := fs.validatePath(filePath) + if err != nil { + return err + } return os.Remove(fullPath) }