Validate paths for static file server

This commit is contained in:
2024-09-30 19:31:20 +02:00
parent 58fe6355bc
commit ab27b36aad
2 changed files with 22 additions and 12 deletions

View File

@@ -58,20 +58,26 @@ func main() {
api.SetupRoutes(r, database, fs) api.SetupRoutes(r, database, fs)
}) })
// Set up static file server // Set up static file server with path validation
staticPath := os.Getenv("NOVAMD_STATIC_PATH") staticPath := os.Getenv("NOVAMD_STATIC_PATH")
if staticPath == "" { if staticPath == "" {
staticPath = "../frontend/dist" staticPath = "../frontend/dist"
} }
fileServer := http.FileServer(http.Dir(staticPath)) fileServer := http.FileServer(http.Dir(staticPath))
r.Get("/*", func(w http.ResponseWriter, r *http.Request) { r.Get("/*", func(w http.ResponseWriter, r *http.Request) {
filePath := filepath.Join(staticPath, r.URL.Path) requestedPath := r.URL.Path
_, err := os.Stat(filePath) validatedPath, err := filesystem.ValidatePath(staticPath, requestedPath)
if err != nil {
http.Error(w, "Invalid path", http.StatusBadRequest)
return
}
_, err = os.Stat(validatedPath)
if os.IsNotExist(err) { if os.IsNotExist(err) {
http.ServeFile(w, r, filepath.Join(staticPath, "index.html")) http.ServeFile(w, r, filepath.Join(staticPath, "index.html"))
return return
} }
fileServer.ServeHTTP(w, r) http.StripPrefix("/", fileServer).ServeHTTP(w, r)
}) })
// Start server // Start server
@@ -81,4 +87,4 @@ func main() {
} }
log.Printf("Server starting on port %s", port) log.Printf("Server starting on port %s", port)
log.Fatal(http.ListenAndServe(":"+port, r)) log.Fatal(http.ListenAndServe(":"+port, r))
} }

View File

@@ -2,6 +2,7 @@ package filesystem
import ( import (
"errors" "errors"
"fmt"
"novamd/internal/gitutils" "novamd/internal/gitutils"
"novamd/internal/models" "novamd/internal/models"
"os" "os"
@@ -47,27 +48,30 @@ func (fs *FileSystem) InitializeGitRepo() error {
return fs.GitRepo.EnsureRepo() return fs.GitRepo.EnsureRepo()
} }
// validatePath checks if the given path is within the root directory func ValidatePath(rootDir, path string) (string, error) {
func (fs *FileSystem) validatePath(path string) (string, error) { fullPath := filepath.Join(rootDir, path)
fullPath := filepath.Join(fs.RootDir, path)
cleanPath := filepath.Clean(fullPath) cleanPath := filepath.Clean(fullPath)
if !strings.HasPrefix(cleanPath, fs.RootDir) { if !strings.HasPrefix(cleanPath, filepath.Clean(rootDir)) {
return "", errors.New("invalid path: outside of root directory") return "", fmt.Errorf("invalid path: outside of root directory")
} }
relPath, err := filepath.Rel(fs.RootDir, cleanPath) relPath, err := filepath.Rel(rootDir, cleanPath)
if err != nil { if err != nil {
return "", err return "", err
} }
if strings.HasPrefix(relPath, "..") { if strings.HasPrefix(relPath, "..") {
return "", errors.New("invalid path: outside of root directory") return "", fmt.Errorf("invalid path: outside of root directory")
} }
return cleanPath, nil return cleanPath, nil
} }
func (fs *FileSystem) validatePath(path string) (string, error) {
return ValidatePath(fs.RootDir, path)
}
func (fs *FileSystem) ListFilesRecursively() ([]FileNode, error) { func (fs *FileSystem) ListFilesRecursively() ([]FileNode, error) {
return fs.walkDirectory(fs.RootDir) return fs.walkDirectory(fs.RootDir)
} }