diff --git a/http/interceptor.go b/http/interceptor.go index 92047d307..7b333cc0a 100644 --- a/http/interceptor.go +++ b/http/interceptor.go @@ -43,7 +43,7 @@ func (i *rwInterceptor) WriteHeader(statusCode int) { i.statusCode = statusCode if it := i.tx.ProcessResponseHeaders(statusCode, i.proto); it != nil { - i.statusCode = obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode) + i.statusCode = ObtainStatusCodeFromInterruptionOrDefault(it, i.statusCode) i.flushWriteHeader() return } @@ -109,13 +109,13 @@ func (i *rwInterceptor) Header() http.Header { var _ http.ResponseWriter = (*rwInterceptor)(nil) -// wrap wraps the interceptor into a response writer that also preserves +// Wrap wraps the interceptor into a response writer that also preserves // the http interfaces implemented by the original response writer to avoid // the observer effect. It also returns the response processor which takes care // of the response body copyback from the transaction buffer. // // Heavily inspired in https://github.com/openzipkin/zipkin-go/blob/master/middleware/http/server.go#L218 -func wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) ( +func Wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) ( http.ResponseWriter, func(types.Transaction, *http.Request) error, ) { // nolint:gocyclo @@ -136,7 +136,7 @@ func wrap(w http.ResponseWriter, r *http.Request, tx types.Transaction) ( i.flushWriteHeader() return err } else if it != nil { - i.overrideWriteHeader(obtainStatusCodeFromInterruptionOrDefault(it, i.statusCode)) + i.overrideWriteHeader(ObtainStatusCodeFromInterruptionOrDefault(it, i.statusCode)) i.flushWriteHeader() return nil } diff --git a/http/interceptor_test.go b/http/interceptor_test.go index e8424705b..c5c498536 100644 --- a/http/interceptor_test.go +++ b/http/interceptor_test.go @@ -24,7 +24,7 @@ func TestWriteHeader(t *testing.T) { tx := waf.NewTransaction() req, _ := http.NewRequest("GET", "", nil) res := httptest.NewRecorder() - rw, responseProcessor := wrap(res, req, tx) + rw, responseProcessor := Wrap(res, req, tx) rw.WriteHeader(204) rw.WriteHeader(205) // although we called WriteHeader, status code should be applied until diff --git a/http/middleware.go b/http/middleware.go index 06bc99044..d07076935 100644 --- a/http/middleware.go +++ b/http/middleware.go @@ -18,12 +18,12 @@ import ( "github.com/corazawaf/coraza/v3/types" ) -// processRequest fills all transaction variables from an http.Request object +// ProcessRequest fills all transaction variables from an http.Request object // Most implementations of Coraza will probably use http.Request objects // so this will implement all phase 0, 1 and 2 variables // Note: This function will stop after an interruption // Note: Do not manually fill any request variables -func processRequest(tx types.Transaction, req *http.Request) (*types.Interruption, error) { +func ProcessRequest(tx types.Transaction, req *http.Request) (*types.Interruption, error) { var ( client string cport int @@ -139,20 +139,20 @@ func WrapHandler(waf coraza.WAF, h http.Handler) http.Handler { // ProcessRequest is just a wrapper around ProcessConnection, ProcessURI, // ProcessRequestHeaders and ProcessRequestBody. // It fails if any of these functions returns an error and it stops on interruption. - if it, err := processRequest(tx, r); err != nil { + if it, err := ProcessRequest(tx, r); err != nil { tx.DebugLogger().Error().Err(err).Msg("Failed to process request") return } else if it != nil { - w.WriteHeader(obtainStatusCodeFromInterruptionOrDefault(it, http.StatusOK)) + w.WriteHeader(ObtainStatusCodeFromInterruptionOrDefault(it, http.StatusOK)) return } - ww, processResponse := wrap(w, r, tx) + ww, ProcessResponse := Wrap(w, r, tx) // We continue with the other middlewares by catching the response h.ServeHTTP(ww, r) - if err := processResponse(tx, r); err != nil { + if err := ProcessResponse(tx, r); err != nil { tx.DebugLogger().Error().Err(err).Msg("Failed to close the response") return } @@ -161,9 +161,9 @@ func WrapHandler(waf coraza.WAF, h http.Handler) http.Handler { return http.HandlerFunc(fn) } -// obtainStatusCodeFromInterruptionOrDefault returns the desired status code derived from the interruption +// ObtainStatusCodeFromInterruptionOrDefault returns the desired status code derived from the interruption // on a "deny" action or a default value. -func obtainStatusCodeFromInterruptionOrDefault(it *types.Interruption, defaultStatusCode int) int { +func ObtainStatusCodeFromInterruptionOrDefault(it *types.Interruption, defaultStatusCode int) int { if it.Action == "deny" { statusCode := it.Status if statusCode == 0 { diff --git a/http/middleware_test.go b/http/middleware_test.go index 7cbb15e21..0a267335e 100644 --- a/http/middleware_test.go +++ b/http/middleware_test.go @@ -32,7 +32,7 @@ func TestProcessRequest(t *testing.T) { req, _ := http.NewRequest("POST", "https://www.coraza.io/test", strings.NewReader("test=456")) waf, _ := coraza.NewWAF(coraza.NewWAFConfig()) tx := waf.NewTransaction().(*corazawaf.Transaction) - if _, err := processRequest(tx, req); err != nil { + if _, err := ProcessRequest(tx, req); err != nil { t.Fatal(err) } if tx.Variables().RequestMethod().Get() != "POST" { @@ -48,7 +48,7 @@ func TestProcessRequestEngineOff(t *testing.T) { // TODO(jcchavezs): Shall we make RuleEngine a first class method in WAF config? waf, _ := coraza.NewWAF(coraza.NewWAFConfig().WithDirectives("SecRuleEngine OFF")) tx := waf.NewTransaction().(*corazawaf.Transaction) - if _, err := processRequest(tx, req); err != nil { + if _, err := ProcessRequest(tx, req); err != nil { t.Fatal(err) } if tx.Variables().RequestMethod().Get() != "POST" { @@ -66,7 +66,7 @@ func TestProcessRequestMultipart(t *testing.T) { req := createMultipartRequest(t) - if _, err := processRequest(tx, req); err != nil { + if _, err := ProcessRequest(tx, req); err != nil { t.Fatal(err) } @@ -95,7 +95,7 @@ SecRule &REQUEST_HEADERS:Transfer-Encoding "!@eq 0" "id:1,phase:1,deny" req, _ := http.NewRequest("GET", "https://www.coraza.io/test", nil) req.TransferEncoding = []string{"chunked"} - it, err := processRequest(tx, req) + it, err := ProcessRequest(tx, req) if err != nil { t.Fatal(err) } @@ -187,7 +187,7 @@ func TestDirectiveSecAuditLog(t *testing.T) { t.Errorf("Description HTTP request parsing failed") } - _, err = processRequest(tx, req) + _, err = ProcessRequest(tx, req) if err != nil { t.Errorf("Failed to load the HTTP request") } @@ -504,7 +504,7 @@ func TestObtainStatusCodeFromInterruptionOrDefault(t *testing.T) { for name, tCase := range tCases { t.Run(name, func(t *testing.T) { want := tCase.expectedCode - have := obtainStatusCodeFromInterruptionOrDefault(&types.Interruption{ + have := ObtainStatusCodeFromInterruptionOrDefault(&types.Interruption{ Status: tCase.interruptionCode, Action: tCase.interruptionAction, }, tCase.defaultCode)