From 36092caa20d0265c705221dbcbf2f64ff3e5b1b5 Mon Sep 17 00:00:00 2001 From: Umang Mundhra Date: Mon, 4 Nov 2024 12:26:36 +0530 Subject: [PATCH] Fix AddStaticFiles method (#1147) --- pkg/gofr/gofr.go | 27 ++++++++++++++++----------- pkg/gofr/http/router.go | 19 +++++++++++++++---- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/pkg/gofr/gofr.go b/pkg/gofr/gofr.go index ab4ae49c4..6e7adc5be 100644 --- a/pkg/gofr/gofr.go +++ b/pkg/gofr/gofr.go @@ -88,6 +88,22 @@ func New() *App { app.httpServer.certFile = app.Config.GetOrDefault("CERT_FILE", "") app.httpServer.keyFile = app.Config.GetOrDefault("KEY_FILE", "") + // Add Default routes + app.add(http.MethodGet, "/.well-known/health", healthHandler) + app.add(http.MethodGet, "/.well-known/alive", liveHandler) + app.add(http.MethodGet, "/favicon.ico", faviconHandler) + + // If the openapi.json file exists in the static directory, set up routes for OpenAPI and Swagger documentation. + if _, err = os.Stat("./static/" + gofrHTTP.DefaultSwaggerFileName); err == nil { + // Route to serve the OpenAPI JSON specification file. + app.add(http.MethodGet, "/.well-known/"+gofrHTTP.DefaultSwaggerFileName, OpenAPIHandler) + // Route to serve the Swagger UI, providing a user interface for the API documentation. + app.add(http.MethodGet, "/.well-known/swagger", SwaggerUIHandler) + // Catchall route: any request to /.well-known/{name} (e.g., /.well-known/other) + // will be handled by the SwaggerUIHandler, serving the Swagger UI. + app.add(http.MethodGet, "/.well-known/{name}", SwaggerUIHandler) + } + if app.Config.Get("APP_ENV") == "DEBUG" { app.httpServer.RegisterProfilingRoutes() } @@ -226,17 +242,6 @@ func (a *App) Shutdown(ctx context.Context) error { } func (a *App) httpServerSetup() { - // Add Default routes - a.add(http.MethodGet, "/.well-known/health", healthHandler) - a.add(http.MethodGet, "/.well-known/alive", liveHandler) - a.add(http.MethodGet, "/favicon.ico", faviconHandler) - - if _, err := os.Stat("./static/openapi.json"); err == nil { - a.add(http.MethodGet, "/.well-known/openapi.json", OpenAPIHandler) - a.add(http.MethodGet, "/.well-known/swagger", SwaggerUIHandler) - a.add(http.MethodGet, "/.well-known/{name}", SwaggerUIHandler) - } - // TODO: find a way to read REQUEST_TIMEOUT config only once and log it there. currently doing it twice one for populating // the value and other for logging requestTimeout := a.Config.Get("REQUEST_TIMEOUT") diff --git a/pkg/gofr/http/router.go b/pkg/gofr/http/router.go index 48c3ecee1..137e260f2 100644 --- a/pkg/gofr/http/router.go +++ b/pkg/gofr/http/router.go @@ -2,14 +2,16 @@ package http import ( "net/http" - "os" "path/filepath" "strings" "github.com/gorilla/mux" + "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" ) +const DefaultSwaggerFileName = "openapi.json" + // Router is responsible for routing HTTP request. type Router struct { mux.Router @@ -56,6 +58,13 @@ func (rou *Router) AddStaticFiles(endpoint, dirName string) { cfg := staticFileConfig{directoryName: dirName} fileServer := http.FileServer(http.Dir(cfg.directoryName)) + + if endpoint == "/" { + rou.Router.NewRoute().PathPrefix("/").Handler(cfg.staticHandler(fileServer)) + + return + } + rou.Router.NewRoute().PathPrefix(endpoint + "/").Handler(http.StripPrefix(endpoint, cfg.staticHandler(fileServer))) } @@ -67,9 +76,11 @@ func (staticConfig staticFileConfig) staticHandler(fileServer http.Handler) http fileName := filePath[len(filePath)-1] - const defaultSwaggerFileName = "openapi.json" - - if _, err := os.Stat(filepath.Clean(filepath.Join(staticConfig.directoryName, url))); fileName == defaultSwaggerFileName && err == nil { + // Prevent direct access to the openapi.json file via static file routes. + // The file should only be accessible through the explicitly defined /.well-known/swagger or + // /.well-known/openapi.json for controlled access. + absPath, err := filepath.Abs(filepath.Join(staticConfig.directoryName, url)) + if err != nil || !strings.HasPrefix(absPath, staticConfig.directoryName) || (fileName == DefaultSwaggerFileName && err == nil) { w.WriteHeader(http.StatusForbidden) _, _ = w.Write([]byte("403 forbidden"))