From 4b08376fadbf9fd7f5824d36d868846ef5a233f3 Mon Sep 17 00:00:00 2001 From: Kilowhisky Date: Sun, 5 May 2024 20:43:56 -0700 Subject: [PATCH 1/3] Add support for CORS in http requests --- internal/server/server.go | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/internal/server/server.go b/internal/server/server.go index 26c4e5fd..49b06488 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -867,6 +867,7 @@ func (s *Server) handleInputCommand(client *Client, msg *Message) error { "Connection: close\r\n"+ "Content-Length: %d\r\n"+ "Content-Type: application/json; charset=utf-8\r\n"+ + "Access-Control-Allow-Origin: *\r\n"+ "\r\n", status, len(res)+2) if err != nil { return err @@ -1444,6 +1445,21 @@ func readNextHTTPCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Wri } method := parts[0] path := parts[1] + // Handle CORS request for allowed origins + if method == "OPTIONS" { + if wr == nil { + return false, errors.New("connection is nil") + } + corshead := "HTTP/1.1 204 No Content\r\n"+ + "Connection: close\r\n"+ + "Access-Control-Allow-Origin: *\r\n"+ + "Access-Control-Allow-Methods: POST, GET, OPTIONS\r\n\r\n" + + if _, err = wr.Write([]byte(corshead)); err != nil { + return false, err + } + return false, nil + } if len(path) == 0 || path[0] != '/' { return false, errInvalidHTTP } @@ -1528,7 +1544,7 @@ func readNextHTTPCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Wri func readNextCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Writer) ( complete bool, args [][]byte, kind redcon.Kind, leftover []byte, err error, ) { - if packet[0] == 'G' || packet[0] == 'P' { + if packet[0] == 'G' || packet[0] == 'P' || packet[0] == 'O' { // could be an HTTP request var line []byte for i := 1; i < len(packet); i++ { From e653b765b74fb532421ec77cdd0a019434237b31 Mon Sep 17 00:00:00 2001 From: Kilowhisky Date: Mon, 6 May 2024 21:59:32 -0700 Subject: [PATCH 2/3] Add tests --- tests/proto_test.go | 52 +++++++++++++++++++++++++++++++++++++++++++++ tests/tests_test.go | 1 + 2 files changed, 53 insertions(+) create mode 100644 tests/proto_test.go diff --git a/tests/proto_test.go b/tests/proto_test.go new file mode 100644 index 00000000..40f87dfe --- /dev/null +++ b/tests/proto_test.go @@ -0,0 +1,52 @@ +package tests + +import ( + "fmt" + "net/http" +) + +func subTestProto(g *testGroup) { + g.regSubTest("HTTP CORS", proto_HTTP_CORS_test) +} + +func proto_HTTP_CORS_test(mc *mockServer) error { + // Make CORS request for GET /SERVER + morigin := "http://my-test-origin" + url := fmt.Sprintf("http://127.0.0.1:%d/SERVER", mc.port) + req, err := http.NewRequest(http.MethodOptions, url, nil) + if err != nil { + return err + } + req.Header.Add("Origin", morigin) + req.Header.Add("Access-Control-Request-Method", "GET") + req.Header.Add("Access-Control-Request-Headers", "Authorization") + resp, err := http.DefaultClient.Do(req) + + // Validate CORS response + if err != nil { + return err + } + if resp.StatusCode != 204 { + return fmt.Errorf("expected http stuats '204', got '%d'", resp.StatusCode) + } + origin := resp.Header.Get("Access-Control-Allow-Origin") + methods := resp.Header.Get("Access-Control-Allow-Methods") + if !(origin == "*" || origin == morigin) { + return fmt.Errorf("expected http access-control-allow-origin value '*', got '%s'", origin) + } + if methods != "POST, GET, OPTIONS" { + return fmt.Errorf("expected http access-control-allow-Methods value 'POST, GET, OPTIONS', got '%s'", methods) + } + + // Make the actual request now + resp, err = http.Get(url) + if err != nil { + return err + } + origin = resp.Header.Get("Access-Control-Allow-Origin") + if !(origin == "*" || origin == morigin) { + return fmt.Errorf("expected http access-control-allow-origin value '*', got '%s'", origin) + } + + return nil +} diff --git a/tests/tests_test.go b/tests/tests_test.go index d796f048..260ab9a0 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -57,6 +57,7 @@ func TestIntegration(t *testing.T) { regTestGroup("follower", subTestFollower) regTestGroup("aof", subTestAOF) regTestGroup("monitor", subTestMonitor) + regTestGroup("proto", subTestProto) runTestGroups(t) } From 1a437d38555ace0123491ebd62dea451e3fb22c4 Mon Sep 17 00:00:00 2001 From: Kilowhisky Date: Mon, 6 May 2024 22:07:07 -0700 Subject: [PATCH 3/3] Access-Control-Allow-Headers is apparently required by the spec --- internal/server/server.go | 7 ++++--- tests/proto_test.go | 4 ++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 49b06488..910e83ec 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1450,9 +1450,10 @@ func readNextHTTPCommand(packet []byte, argsIn [][]byte, msg *Message, wr io.Wri if wr == nil { return false, errors.New("connection is nil") } - corshead := "HTTP/1.1 204 No Content\r\n"+ - "Connection: close\r\n"+ - "Access-Control-Allow-Origin: *\r\n"+ + corshead := "HTTP/1.1 204 No Content\r\n" + + "Connection: close\r\n" + + "Access-Control-Allow-Origin: *\r\n" + + "Access-Control-Allow-Headers: *, Authorization\r\n" + "Access-Control-Allow-Methods: POST, GET, OPTIONS\r\n\r\n" if _, err = wr.Write([]byte(corshead)); err != nil { diff --git a/tests/proto_test.go b/tests/proto_test.go index 40f87dfe..80b74625 100644 --- a/tests/proto_test.go +++ b/tests/proto_test.go @@ -31,12 +31,16 @@ func proto_HTTP_CORS_test(mc *mockServer) error { } origin := resp.Header.Get("Access-Control-Allow-Origin") methods := resp.Header.Get("Access-Control-Allow-Methods") + headers := resp.Header.Get("Access-Control-Allow-Headers") if !(origin == "*" || origin == morigin) { return fmt.Errorf("expected http access-control-allow-origin value '*', got '%s'", origin) } if methods != "POST, GET, OPTIONS" { return fmt.Errorf("expected http access-control-allow-Methods value 'POST, GET, OPTIONS', got '%s'", methods) } + if headers != "*, Authorization" { + return fmt.Errorf("expected http access-control-allow-headers value '*, Authorization', got '%s'", headers) + } // Make the actual request now resp, err = http.Get(url)