From e2620f8e28d262bf140926f47894a5f5c3f67470 Mon Sep 17 00:00:00 2001 From: "Taro L. Saito" Date: Mon, 11 Sep 2023 20:36:23 -0700 Subject: [PATCH] http (feature): Support Cors.newFilter (#3215) --- .../main/scala/wvlet/airframe/http/Http.scala | 1 + .../airframe/http/filter/CorsFilter.scala | 142 ++++++++++++++++++ .../wvlet/airframe/http/filter/CorsTest.scala | 129 ++++++++++++++++ 3 files changed, 272 insertions(+) create mode 100644 airframe-http/src/main/scala/wvlet/airframe/http/filter/CorsFilter.scala create mode 100644 airframe-http/src/test/scala/wvlet/airframe/http/filter/CorsTest.scala diff --git a/airframe-http/src/main/scala/wvlet/airframe/http/Http.scala b/airframe-http/src/main/scala/wvlet/airframe/http/Http.scala index ed18650576..7f83c9e4f7 100644 --- a/airframe-http/src/main/scala/wvlet/airframe/http/Http.scala +++ b/airframe-http/src/main/scala/wvlet/airframe/http/Http.scala @@ -60,6 +60,7 @@ object Http { def DELETE(uri: String) = request(HttpMethod.DELETE, uri) def PUT(uri: String) = request(HttpMethod.PUT, uri) def PATCH(uri: String) = request(HttpMethod.PATCH, uri) + def OPTIONS(uri: String) = request(HttpMethod.OPTIONS, uri) def response(status: HttpStatus = HttpStatus.Ok_200): HttpMessage.Response = { HttpMessage.Response.empty.withStatus(status) diff --git a/airframe-http/src/main/scala/wvlet/airframe/http/filter/CorsFilter.scala b/airframe-http/src/main/scala/wvlet/airframe/http/filter/CorsFilter.scala new file mode 100644 index 0000000000..b17cadefe2 --- /dev/null +++ b/airframe-http/src/main/scala/wvlet/airframe/http/filter/CorsFilter.scala @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http.filter + +import wvlet.airframe.http.Http +import wvlet.airframe.http.HttpMessage.{Request, Response} +import wvlet.airframe.http.{HttpMessage, HttpMethod, RxHttpEndpoint, RxHttpFilter} +import wvlet.airframe.rx.Rx + +import scala.concurrent.duration.Duration + +/** + * Implements https://fetch.spec.whatwg.org/#http-cors-protocol + */ +object Cors { + case class Policy( + allowsOrigin: String => Option[String], + allowsMethods: String => Option[Seq[String]], + allowsHeaders: Seq[String] => Option[Seq[String]], + exposedHeaders: Seq[String] = Nil, + supportsCredentials: Boolean = false, + maxAge: Option[Duration] = None + ) + + /** A CORS policy that lets you do whatever you want. Don't use this in production. */ + def unsafePermissivePolicy: Policy = Policy( + allowsOrigin = origin => Some(origin), + allowsMethods = method => Some(Seq(method)), + allowsHeaders = headers => Some(headers), + supportsCredentials = true + ) + + /** + * Create a new RxHttpFilter to add headers to support Cross-origin resource sharing (CORS). + * + * {{{ + * Cors.newFilter( + * Cors.Policy( + * allowsOrigin = origin => { origin match { + * case x if x.endsWith("mydomain.com") => Some(origin) + * case _ => None + * }}, + * allowsMethods = _ => Some(Seq(HttpMethod.POST)), + * allowsHeaders = headers => Some(headers) + * )) + * }}} + * + * @param policy + */ + def newFilter(policy: Policy): RxHttpFilter = new CorsFilter(policy) + + private class CorsFilter(policy: Policy) extends RxHttpFilter { + + private def getOrigin(request: Request): Option[String] = + request.header.get("Origin") + + private def getMethod(request: Request): Option[String] = + request.header.get("Access-Control-Request-Method") + + private def commaSpace = ", *".r + private def getHeaders(request: Request): Seq[String] = + request.header.get("Access-Control-Request-Headers") match { + case Some(value) => commaSpace.split(value).toSeq + case None => Nil + } + + private def setOriginAndCredential(resp: Response, origin: String): Response = { + var r = resp.withHeader("Access-Control-Allow-Origin", origin) + if (policy.supportsCredentials && origin != "*") { + r = r.withHeader("Access-Control-Allow-Credentials", "true") + } + r + } + + private def addExposedHeaders(response: Response): Response = { + if (policy.exposedHeaders.nonEmpty) + response.withHeader("Access-Control-Expose-Headers", policy.exposedHeaders.mkString(", ")) + else + response + } + + private def handlePreflight(request: Request): Option[Response] = { + getOrigin(request).flatMap { origin => + getMethod(request).flatMap { method => + val headers = getHeaders(request) + policy.allowsMethods(method).flatMap { allowedMethods => + policy.allowsHeaders(headers).map { allowedHeaders => + var resp = Http.response() + resp = setOriginAndCredential(resp, origin) + // max-age + policy.maxAge.foreach { maxAge => + resp = resp.withHeader("Access-Control-Max-Age", maxAge.toSeconds.toString) + } + // methods + resp = resp.withHeader("Access-Control-Allow-Methods", allowedMethods.mkString(", ")) + // headers + resp = resp.withHeader("Access-Control-Allow-Headers", allowedHeaders.mkString(", ")) + resp + } + } + } + } + } + + private def handleSimple(request: Request, response: Response): Response = { + getOrigin(request) + .map(origin => setOriginAndCredential(response, origin)) + .map(addExposedHeaders) + .getOrElse(response) + } + + override def apply(request: HttpMessage.Request, next: RxHttpEndpoint): Rx[HttpMessage.Response] = { + val resp: Rx[Response] = request.method match { + case HttpMethod.OPTIONS => + // Preflight request + handlePreflight(request) match { + case Some(resp) => + Rx.single(resp) + case None => + // No matching policy + Rx.single(Http.response()) + } + case _ => + next(request).map(handleSimple(request, _)) + } + resp.map { resp => + resp.withHeader("Vary", "Origin") + } + } + } +} diff --git a/airframe-http/src/test/scala/wvlet/airframe/http/filter/CorsTest.scala b/airframe-http/src/test/scala/wvlet/airframe/http/filter/CorsTest.scala new file mode 100644 index 0000000000..b932e38539 --- /dev/null +++ b/airframe-http/src/test/scala/wvlet/airframe/http/filter/CorsTest.scala @@ -0,0 +1,129 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http.filter + +import wvlet.airframe.http.{Http, HttpMessage, HttpMethod, RxHttpEndpoint, RxHttpFilter} +import wvlet.airframe.rx.Rx +import wvlet.airspec.AirSpec + +import java.util.concurrent.TimeUnit +import scala.concurrent.duration.Duration + +class CorsTest extends AirSpec { + + private val maxAge = Duration(1, TimeUnit.HOURS) + + private def policy = Cors.Policy( + allowsOrigin = { + case origin if origin.startsWith("console") => Some(origin) + case origin if origin.endsWith("td.com") => Some(origin) + case _ => None + }, + allowsMethods = method => Some(Seq("GET")), + allowsHeaders = headers => Some(headers), + exposedHeaders = Seq("x-airframe-rpc"), + supportsCredentials = true, + maxAge = Some(maxAge) + ) + + private val corsFilter = Cors.newFilter(policy) + private val endpoint = new RxHttpEndpoint { + override def apply(request: HttpMessage.Request): Rx[HttpMessage.Response] = { + Rx.single(Http.response()) + } + } + + test("CorsFilter handles preflight requests") { + val request = Http + .OPTIONS("/") + .withHeader("Origin", "thetd.com") + .withHeader("Access-Control-Request-Method", "BRR") + + corsFilter.apply(request, endpoint).map { resp => + resp.getHeader("Access-Control-Allow-Origin") shouldBe Some("thetd.com") + resp.getHeader("Access-Control-Allow-Credentials") shouldBe Some("true") + resp.getHeader("Access-Control-Allow-Methods") shouldBe Some("GET") + resp.getHeader("Vary") shouldBe Some("Origin") + resp.getHeader("Access-Control-Max-Age") shouldBe Some(maxAge.toSeconds.toString) + resp.contentString shouldBe empty + } + } + + test("CorsFilter responds to invalid preflight requests without CORS headers") { + val request = Http + .OPTIONS("/") + + corsFilter.apply(request, endpoint).map { resp => + resp.getHeader("Access-Control-Allow-Origin") shouldBe empty + resp.getHeader("Access-Control-Allow-Credentials") shouldBe empty + resp.getHeader("Access-Control-Allow-Methods") shouldBe empty + resp.getHeader("Access-Control-Max-Age") shouldBe empty + resp.getHeader("Vary") shouldBe Some("Origin") + resp.contentString shouldBe empty + } + } + + test("CorsFilter responds to unacceptable cross-origin requests without CORS headers") { + val request = Http + .OPTIONS("/") + .withHeader("Origin", "theclub") + + corsFilter.apply(request, endpoint).map { resp => + resp.getHeader("Access-Control-Allow-Origin") shouldBe empty + resp.getHeader("Access-Control-Allow-Credentials") shouldBe empty + resp.getHeader("Access-Control-Allow-Methods") shouldBe empty + resp.getHeader("Access-Control-Max-Age") shouldBe empty + resp.getHeader("Vary") shouldBe Some("Origin") + resp.contentString shouldBe empty + } + } + + test("CorsFilter handles simple requests") { + val request = Http + .GET("/") + .withHeader("Origin", "thetd.com") + + corsFilter.apply(request, endpoint).map { resp => + resp.getHeader("Access-Control-Allow-Origin") shouldBe Some("thetd.com") + resp.getHeader("Access-Control-Allow-Credentials") shouldBe Some("true") + resp.getHeader("Access-Control-Expose-Headers") shouldBe Some("x-airframe-rpc") + resp.contentString shouldBe empty + } + } + + test("CorsFilter handles simple requests with multiple origins") { + val request = Http + .GET("/") + .withHeader("Origin", "thetd.com,console") + + corsFilter.apply(request, endpoint).map { resp => + resp.getHeader("Access-Control-Allow-Origin") shouldBe Some("thetd.com,console") + resp.getHeader("Access-Control-Allow-Credentials") shouldBe Some("true") + resp.getHeader("Access-Control-Expose-Headers") shouldBe Some("x-airframe-rpc") + resp.contentString shouldBe empty + } + } + + test("CorsFilter does not add response headers to simple requests if request headers aren't present") { + val request = Http + .GET("/") + + corsFilter.apply(request, endpoint).map { resp => + resp.getHeader("Access-Control-Allow-Origin") shouldBe empty + resp.getHeader("Access-Control-Allow-Credentials") shouldBe empty + resp.getHeader("Access-Control-Expose-Headers") shouldBe empty + resp.contentString shouldBe empty + } + } +}