From 5290cd4764ef6d4189178f624c7232285815779b Mon Sep 17 00:00:00 2001 From: calvinlfer Date: Mon, 11 Dec 2023 17:45:11 -0500 Subject: [PATCH 1/3] Implement tracing RequestInterceptor for Tapir - includes enriched logs - includes traced response headers --- build.sbt | 17 +- project/Versions.scala | 8 +- .../tapir/examples/ExampleServerApp.scala | 25 ++- .../zio/extras/tapir/TapirServerTracer.scala | 110 ------------- .../zio/extras/tapir/TraceInterceptor.scala | 154 ++++++++++++++++++ .../extras/tapir/TraceInterceptorSpec.scala | 50 ++++++ 6 files changed, 229 insertions(+), 135 deletions(-) delete mode 100644 tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TapirServerTracer.scala create mode 100644 tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptor.scala create mode 100644 tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala diff --git a/build.sbt b/build.sbt index acc42aa..68021f6 100644 --- a/build.sbt +++ b/build.sbt @@ -291,16 +291,21 @@ lazy val tapir = .in(file("tapir")) .settings(kindProjectorSettings*) .settings( - name := "trace4cats-zio-extras-tapir", - libraryDependencies += "com.softwaremill.sttp.tapir" %% "tapir-core" % Versions.tapir + name := "trace4cats-zio-extras-tapir", + libraryDependencies ++= + Seq( + "com.softwaremill.sttp.tapir" %% "tapir-server" % Versions.tapir, + "com.softwaremill.sttp.tapir" %% "tapir-zio" % Versions.tapir % Test, + "com.softwaremill.sttp.tapir" %% "tapir-http4s-server" % Versions.tapir % Test + ) ) - .dependsOn(core) + .dependsOn(core % "compile->compile;test->test") lazy val tapirExample = project .in(file("tapir-examples")) .settings( - name := "trace4cats-zio-extras-zio-sttp-examples", + name := "trace4cats-zio-extras-zio--sttp-examples", publish / skip := true, libraryDependencies ++= Seq( @@ -375,7 +380,7 @@ lazy val skunk = "org.tpolecat" %% "skunk-core" % Versions.skunk, "io.zonky.test" % "embedded-postgres" % Versions.embeddedPostgres % Test, "dev.zio" %% "zio-logging-slf4j" % Versions.zioLogging % Test, - "ch.qos.logback" % "logback-classic" % "1.4.11" % Test + "ch.qos.logback" % "logback-classic" % "1.4.14" % Test ) ) .dependsOn( @@ -403,7 +408,7 @@ lazy val zioKafka = "dev.zio" %% "zio-kafka" % Versions.zioKafka, "io.github.embeddedkafka" %% "embedded-kafka" % Versions.kafkaEmbedded % Test, "dev.zio" %% "zio-logging-slf4j" % Versions.zioLogging % Test, - "ch.qos.logback" % "logback-classic" % "1.4.11" % Test + "ch.qos.logback" % "logback-classic" % "1.4.14" % Test ), excludeDependencies ++= { CrossVersion.partialVersion(scalaVersion.value) match { diff --git a/project/Versions.scala b/project/Versions.scala index 9f315e4..c295ee9 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -7,16 +7,16 @@ object Versions { val trace4CatsJaegarExporter = "0.14.0" val trace4CatsHttp4sCommon = "0.14.1" val zio = "2.0.19" - val zioLogging = "2.1.15" + val zioLogging = "2.1.16" val zioInteropCats = "23.1.0.0" val zioHttp = "3.0.0-RC3" val http4s = "0.23.24" - val tapir = "1.9.1" + val tapir = "1.9.4" val sttp = "3.9.1" val virgil = "1.0.4" val doobie = "1.0.0-RC5" - val skunk = "0.6.1" - val embeddedPostgres = "2.0.4" + val skunk = "0.6.2" + val embeddedPostgres = "2.0.6" val scribe = "3.12.2" val zioKafka = "2.7.0" val kafkaEmbedded = "3.6.0" diff --git a/tapir-examples/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/examples/ExampleServerApp.scala b/tapir-examples/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/examples/ExampleServerApp.scala index 18653c3..b39d9d9 100644 --- a/tapir-examples/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/examples/ExampleServerApp.scala +++ b/tapir-examples/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/examples/ExampleServerApp.scala @@ -5,13 +5,14 @@ import fs2.io.net.Network import io.circe.Codec as CirceCodec import io.circe.generic.semiauto.deriveCodec import io.kaizensolutions.trace4cats.zio.extras.ZTracer -import io.kaizensolutions.trace4cats.zio.extras.tapir.TapirServerTracer +import io.kaizensolutions.trace4cats.zio.extras.tapir.TraceInterceptor import org.http4s.ember.server.EmberServerBuilder import sttp.model.{Headers, StatusCode} import sttp.tapir.* import sttp.tapir.json.circe.* import sttp.tapir.server.ServerEndpoint -import sttp.tapir.server.http4s.Http4sServerInterpreter +import sttp.tapir.server.http4s.{Http4sServerInterpreter, Http4sServerOptions} +import trace4cats.kernel.ToHeaders import zio.* import zio.interop.catz.* @@ -22,7 +23,7 @@ object ExampleServerApp extends ZIOAppDefault { def countCharacters(tracer: ZTracer)(in: Request): UIO[Either[NoCharacters, Int]] = { val l = in.input.length val out = tracer.spanSource() { - if (l > 0) ZIO.succeed(l) + if (l > 0) ZIO.logInfo(s"Received ${in.input}").as(l) else ZIO.fail(NoCharacters("Please supply at least 1 character to count")) } @@ -41,22 +42,16 @@ object ExampleServerApp extends ZIOAppDefault { def serverEndpoint(tracer: ZTracer): ServerEndpoint.Full[Unit, Unit, Request, NoCharacters, Int, Any, Task] = countCharactersEndpoint.serverLogic(countCharacters(tracer)) - def tracedServerEndpoint(tracer: ZTracer): ServerEndpoint.Full[Unit, Unit, Request, NoCharacters, Int, Any, Task] = - TapirServerTracer - .traceEndpoint[Request, NoCharacters, Int, Any, Any, Throwable]( - tracer = tracer, - serverEndpoint = serverEndpoint(tracer), - extractRequestHeaders = _.headers, - extractResponseHeaders = _ => Headers(Nil) - ) - override val run: ZIO[ZIOAppArgs & Scope, Any, Any] = { val program = for { tracer <- ZIO.service[ZTracer] - endpoint = tracedServerEndpoint(tracer) - httpApp = Http4sServerInterpreter[Task]().toRoutes(endpoint).orNotFound - port <- ZIO.fromEither(Port.fromInt(8080).toRight(new RuntimeException("Invalid Port"))) + endpoint = serverEndpoint(tracer) + serverOptions = Http4sServerOptions + .default[Task] + .prependInterceptor(TraceInterceptor(tracer, headerFormat = ToHeaders.b3Single)) + httpApp = Http4sServerInterpreter[Task](serverOptions).toRoutes(endpoint).orNotFound + port <- ZIO.fromEither(Port.fromInt(8080).toRight(new RuntimeException("Invalid Port"))) server <- EmberServerBuilder .default[Task] .withHostOption(Host.fromString("localhost")) diff --git a/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TapirServerTracer.scala b/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TapirServerTracer.scala deleted file mode 100644 index 0605347..0000000 --- a/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TapirServerTracer.scala +++ /dev/null @@ -1,110 +0,0 @@ -package io.kaizensolutions.trace4cats.zio.extras.tapir - -import trace4cats.model.{AttributeValue, SpanKind, SpanStatus, TraceHeaders} -import io.kaizensolutions.trace4cats.zio.extras.ZTracer -import sttp.model.{HeaderNames, Headers} -import sttp.tapir.Endpoint -import sttp.tapir.internal.* -import sttp.tapir.server.ServerEndpoint -import zio.{Exit, ZIO} - -object TapirServerTracer { - - /** - * @param tracer - * is the ZTracer to use for tracing server endpoints - * @param serverEndpoint - * is the server endpoint to be traced - * @param extractRequestHeaders - * extracts the headers from the request - * @param extractResponseHeaders - * extracts the headers from the response - * @param spanNamer - * defines how a span is named - * @param errorToSpanStatus - * maps an error to a span status - * @param dropHeadersWhen - * is the predicate to use for determining whether to drop the headers - * @tparam I - * Tapir's I (Input) - * @tparam E - * Tapir's E (Error) - * @tparam O - * Tapir's O (Output) - * @tparam R - * Tapir's R (endpoint capabilities) - * @tparam EffectEnv - * is the ZIO Effect's environment - * @tparam EffectErr - * is the ZIO Effect's error type - * @return - */ - def traceEndpoint[I, E, O, R, EffectEnv, EffectErr]( - tracer: ZTracer, - serverEndpoint: ServerEndpoint.Full[Unit, Unit, I, E, O, R, ZIO[EffectEnv, EffectErr, *]], - extractRequestHeaders: I => Headers, - extractResponseHeaders: O => Headers, - spanNamer: (Endpoint[?, I, ?, ?, ?], I) => String = methodWithPathTemplateSpanNamer[I], - errorToSpanStatus: E => SpanStatus = (e: E) => SpanStatus.Internal(e.toString), - dropHeadersWhen: String => Boolean = HeaderNames.isSensitive - ): ServerEndpoint.Full[Unit, Unit, I, E, O, R, ZIO[EffectEnv, EffectErr, *]] = - ServerEndpoint.public( - endpoint = serverEndpoint.endpoint, - logic = { monadError => input => - val reqHeaders = extractRequestHeaders(input) - val traceHeaders = TraceHeaders.of(reqHeaders.headers.map(h => (h.name, h.value))*) - val spanName = spanNamer(serverEndpoint.endpoint, input) - - tracer.fromHeaders(headers = traceHeaders, kind = SpanKind.Server, name = spanName) { span => - tracer.putAll(requestFields(reqHeaders, dropHeadersWhen)*) *> - serverEndpoint - .logic(monadError)(())(input) - .flatMap { - case left @ Left(error) => - span - .setStatus(errorToSpanStatus(error)) - .as(left) - - case right @ Right(success) => - tracer - .putAll(responseFields(extractResponseHeaders(success), dropHeadersWhen)*) - .as(right) - } - .onExit { - case Exit.Failure(cause) if cause.isDie => - span.setStatus(SpanStatus.Internal(cause.prettyPrint)) - - case _ => - ZIO.unit - } - } - } - ) - - def methodWithPathTemplateSpanNamer[I]: (Endpoint[?, I, ?, ?, ?], I) => String = - (endpoint, _) => endpoint.input.method.fold("ANY")(_.method) + " " + endpoint.showPathTemplate() - - private def requestFields( - hs: Headers, - dropHeadersWhen: String => Boolean - ): List[(String, AttributeValue)] = - headerFields(hs, "req", dropHeadersWhen) - - private def responseFields( - hs: Headers, - dropHeadersWhen: String => Boolean - ): List[(String, AttributeValue)] = - headerFields(hs, "resp", dropHeadersWhen) - - private def headerFields( - hs: Headers, - `type`: String, - dropHeadersWhen: String => Boolean - ): List[(String, AttributeValue)] = - hs.headers - .filter(h => !dropHeadersWhen(h.name)) - .map { h => - (s"${`type`}.header.${h.name}", h.value: AttributeValue) - } - .toList -} diff --git a/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptor.scala b/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptor.scala new file mode 100644 index 0000000..28cca51 --- /dev/null +++ b/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptor.scala @@ -0,0 +1,154 @@ +package io.kaizensolutions.trace4cats.zio.extras.tapir + +import io.kaizensolutions.trace4cats.zio.extras.{ZSpan, ZTracer} +import sttp.model.{Header, HeaderNames, StatusCode} +import sttp.monad.MonadError +import sttp.tapir.model.ServerRequest +import sttp.tapir.server.ServerEndpoint +import sttp.tapir.server.interceptor.* +import sttp.tapir.server.model.ServerResponse +import trace4cats.model.{SpanKind, SpanStatus, TraceHeaders} +import trace4cats.{AttributeValue, ToHeaders} +import zio.* + +final class TraceInterceptor private ( + private val tracer: ZTracer, + private val dropHeadersWhen: String => Boolean, + private val spanNamer: PartialFunction[ServerRequest, String], + private val enrichResponseHeadersWithTraceIds: Boolean, + private val enrichLogs: Boolean, + private val headerFormat: ToHeaders +) extends RequestInterceptor[Task] { + + override def apply[R, B]( + responder: Responder[Task, B], + requestHandler: EndpointInterceptor[Task] => RequestHandler[Task, R, B] + ): RequestHandler[Task, R, B] = new RequestHandler[Task, R, B] { + private def spanNamerTotal(input: ServerRequest): String = + if (spanNamer.isDefinedAt(input)) spanNamer(input) + else input.showShort + + override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, Task]])(implicit + monad: MonadError[Task] + ): Task[RequestResult[B]] = + if (endpoints.nonEmpty) { + // only run tracing if there are endpoints to match + val spanName = spanNamerTotal(request) + val traceHeaders = TraceHeaders.of(request.headers.map(h => (h.name, h.value))*) + + tracer.fromHeaders(traceHeaders, name = spanName, kind = SpanKind.Server) { span => + val logTraceContext = + if (enrichLogs) ZIOAspect.annotated(annotations = extractKVHeaders(span, headerFormat).toList*) + else ZIOAspect.identity + + enrichSpanFromRequest(request, dropHeadersWhen, span) *> + (requestHandler(EndpointInterceptor.noop)(request, endpoints) @@ logTraceContext) + .foldZIO( + error => span.setStatus(SpanStatus.Internal(error.toString)) *> ZIO.fail(error), + { + case res @ RequestResult.Response(response) => + enrichSpanFromResponse(response, dropHeadersWhen, span).as( + if (enrichResponseHeadersWithTraceIds) + RequestResult.Response(response.addHeaders(toHttpHeaders(span, headerFormat))) + else res + ) + + case res @ RequestResult.Failure(failures) => + span + .putAll( + failures.map(f => + s"error.${f.failingInput.show}" -> AttributeValue.stringToTraceValue(f.failure.toString) + )* + ) + .as(res) + } + ) + } + } else requestHandler(EndpointInterceptor.noop)(request, endpoints) + } + + private def toHttpHeaders(span: ZSpan, whichHeaders: ToHeaders): Seq[Header] = + span + .extractHeaders(whichHeaders) + .values + .collect { case (k, v) if v.nonEmpty => Header(k.toString, v) } + .toSeq + + private def extractKVHeaders(span: ZSpan, whichHeaders: ToHeaders): Map[String, String] = + span + .extractHeaders(whichHeaders) + .values + .collect { case (k, v) if v.nonEmpty => (k.toString, v) } + + private def enrichSpanFromRequest( + request: ServerRequest, + dropHeadersWhen: String => Boolean, + span: ZSpan + ): UIO[Unit] = + if (span.isSampled) span.putAll(requestFields(request.headers, dropHeadersWhen)*) + else ZIO.unit + + private def enrichSpanFromResponse[A]( + response: ServerResponse[A], + dropHeadersWhen: String => Boolean, + span: ZSpan + ): UIO[Unit] = { + val respFields = responseFields(response.headers, dropHeadersWhen) + val spanRespAttrs = if (span.isSampled) span.putAll(respFields*) else ZIO.unit + spanRespAttrs *> span.setStatus(toSpanStatus(response.code)) + } + + private def toSpanStatus(value: StatusCode): SpanStatus = + value match { + case StatusCode.BadRequest => SpanStatus.Internal("Bad Request") + case StatusCode.Unauthorized => SpanStatus.Unauthenticated + case StatusCode.Forbidden => SpanStatus.PermissionDenied + case StatusCode.NotFound => SpanStatus.NotFound + case StatusCode.TooManyRequests => SpanStatus.Unavailable + case StatusCode.BadGateway => SpanStatus.Unavailable + case StatusCode.ServiceUnavailable => SpanStatus.Unavailable + case StatusCode.GatewayTimeout => SpanStatus.Unavailable + case status if status.isSuccess => SpanStatus.Ok + case _ => SpanStatus.Unknown + } + + private def requestFields( + hs: Seq[Header], + dropHeadersWhen: String => Boolean + ): List[(String, AttributeValue)] = + headerFields(hs, "req", dropHeadersWhen) + + private def responseFields( + hs: Seq[Header], + dropHeadersWhen: String => Boolean + ): List[(String, AttributeValue)] = + headerFields(hs, "resp", dropHeadersWhen) + + private def headerFields( + hs: Seq[Header], + `type`: String, + dropHeadersWhen: String => Boolean + ): List[(String, AttributeValue)] = + hs.filter(h => !dropHeadersWhen(h.name)) + .map { h => + (s"${`type`}.header.${h.name}", h.value: AttributeValue) + } + .toList +} +object TraceInterceptor { + def apply( + tracer: ZTracer, + dropHeadersWhen: String => Boolean = HeaderNames.isSensitive, + spanNamer: PartialFunction[ServerRequest, String] = _.showShort, + enrichResponseHeadersWithTraceIds: Boolean = true, + enrichLogs: Boolean = true, + headerFormat: ToHeaders = ToHeaders.standard + ): TraceInterceptor = new TraceInterceptor( + tracer, + dropHeadersWhen, + spanNamer, + enrichResponseHeadersWithTraceIds, + enrichLogs, + headerFormat + ) +} diff --git a/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala b/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala new file mode 100644 index 0000000..959d685 --- /dev/null +++ b/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala @@ -0,0 +1,50 @@ +package io.kaizensolutions.trace4cats.zio.extras.tapir + +import io.kaizensolutions.trace4cats.zio.extras.{InMemorySpanCompleter, ZTracer} +import org.http4s.* +import org.http4s.syntax.all.* +import org.typelevel.ci.CIString +import sttp.model.StatusCode +import sttp.tapir.server.http4s.{Http4sServerInterpreter, Http4sServerOptions} +import sttp.tapir.ztapir.* +import trace4cats.{ToHeaders, TraceProcess} +import zio.interop.catz.* +import zio.test.* +import zio.{Scope, Task, ZIOAspect} + +object TraceInterceptorSpec extends ZIOSpecDefault { + final class TestEndpoint(tracer: ZTracer) { + private val testEndpoint = + endpoint.get + .in("hello") + .out(statusCode(StatusCode.Ok)) + + val serverLogic: ZServerEndpoint[Any, Any] = + testEndpoint.zServerLogic(_ => tracer.retrieveCurrentSpan.tap(_.put("hello", "hello")).unit) + } + + def spec: Spec[TestEnvironment & Scope, Throwable] = suite("TraceInterceptor specification")( + test("traces http requests") { + for { + result <- InMemorySpanCompleter.entryPoint(TraceProcess("tapir-trace-interceptor-test")) + (sc, ep) = result + tracer <- InMemorySpanCompleter.toZTracer(ep) + interceptor = TraceInterceptor(tracer, headerFormat = ToHeaders.w3c) + endpoint = new TestEndpoint(tracer) + httpApp = Http4sServerInterpreter[Task]( + Http4sServerOptions + .default[Task] + .prependInterceptor(interceptor) + ).toRoutes(endpoint.serverLogic).orNotFound + response <- httpApp.run(Request(uri = uri"/hello")) + spans <- sc.retrieveCollected + _ = println(response) + _ = println(spans) + } yield assertTrue( + response.headers.get(CIString("traceparent")).isDefined, + spans.exists(_.name == "GET /hello"), + spans.exists(_.attributes.contains("hello")) + ) + } + ) +} From e516e3ca9451eb82b99db5b6207b1d0f611789df Mon Sep 17 00:00:00 2001 From: calvinlfer Date: Tue, 12 Dec 2023 10:33:55 -0500 Subject: [PATCH 2/3] Rework Trace Interceptor to provide a better experience - paths now reflect the templated version automatically --- .../zio/extras/tapir/TraceInterceptor.scala | 141 ++++++++++-------- .../extras/tapir/TraceInterceptorSpec.scala | 18 ++- 2 files changed, 90 insertions(+), 69 deletions(-) diff --git a/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptor.scala b/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptor.scala index 28cca51..79f233f 100644 --- a/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptor.scala +++ b/tapir/src/main/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptor.scala @@ -6,6 +6,7 @@ import sttp.monad.MonadError import sttp.tapir.model.ServerRequest import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.interceptor.* +import sttp.tapir.server.interpreter.BodyListener import sttp.tapir.server.model.ServerResponse import trace4cats.model.{SpanKind, SpanStatus, TraceHeaders} import trace4cats.{AttributeValue, ToHeaders} @@ -14,7 +15,6 @@ import zio.* final class TraceInterceptor private ( private val tracer: ZTracer, private val dropHeadersWhen: String => Boolean, - private val spanNamer: PartialFunction[ServerRequest, String], private val enrichResponseHeadersWithTraceIds: Boolean, private val enrichLogs: Boolean, private val headerFormat: ToHeaders @@ -24,47 +24,81 @@ final class TraceInterceptor private ( responder: Responder[Task, B], requestHandler: EndpointInterceptor[Task] => RequestHandler[Task, R, B] ): RequestHandler[Task, R, B] = new RequestHandler[Task, R, B] { - private def spanNamerTotal(input: ServerRequest): String = - if (spanNamer.isDefinedAt(input)) spanNamer(input) - else input.showShort + private val tracingEndpointInterceptor = new TraceEndpointInterceptor( + tracer, + dropHeadersWhen, + enrichResponseHeadersWithTraceIds, + enrichLogs, + headerFormat + ) override def apply(request: ServerRequest, endpoints: List[ServerEndpoint[R, Task]])(implicit monad: MonadError[Task] ): Task[RequestResult[B]] = - if (endpoints.nonEmpty) { - // only run tracing if there are endpoints to match - val spanName = spanNamerTotal(request) - val traceHeaders = TraceHeaders.of(request.headers.map(h => (h.name, h.value))*) - - tracer.fromHeaders(traceHeaders, name = spanName, kind = SpanKind.Server) { span => - val logTraceContext = - if (enrichLogs) ZIOAspect.annotated(annotations = extractKVHeaders(span, headerFormat).toList*) - else ZIOAspect.identity - - enrichSpanFromRequest(request, dropHeadersWhen, span) *> - (requestHandler(EndpointInterceptor.noop)(request, endpoints) @@ logTraceContext) - .foldZIO( - error => span.setStatus(SpanStatus.Internal(error.toString)) *> ZIO.fail(error), - { - case res @ RequestResult.Response(response) => - enrichSpanFromResponse(response, dropHeadersWhen, span).as( - if (enrichResponseHeadersWithTraceIds) - RequestResult.Response(response.addHeaders(toHttpHeaders(span, headerFormat))) - else res - ) - - case res @ RequestResult.Failure(failures) => - span - .putAll( - failures.map(f => - s"error.${f.failingInput.show}" -> AttributeValue.stringToTraceValue(f.failure.toString) - )* - ) - .as(res) - } - ) - } - } else requestHandler(EndpointInterceptor.noop)(request, endpoints) + requestHandler(tracingEndpointInterceptor)(request, endpoints) + } +} +object TraceInterceptor { + def apply( + tracer: ZTracer, + dropHeadersWhen: String => Boolean = HeaderNames.isSensitive, + enrichResponseHeadersWithTraceIds: Boolean = true, + enrichLogs: Boolean = true, + headerFormat: ToHeaders = ToHeaders.standard + ): TraceInterceptor = new TraceInterceptor( + tracer, + dropHeadersWhen, + enrichResponseHeadersWithTraceIds, + enrichLogs, + headerFormat + ) +} + +private class TraceEndpointInterceptor( + private val tracer: ZTracer, + private val dropHeadersWhen: String => Boolean, + private val enrichResponseHeadersWithTraceIds: Boolean, + private val enrichLogs: Boolean, + private val headerFormat: ToHeaders +) extends EndpointInterceptor[Task] { + override def apply[B]( + responder: Responder[Task, B], + endpointHandler: EndpointHandler[Task, B] + ): EndpointHandler[Task, B] = new EndpointHandler[Task, B] { + + override def onDecodeSuccess[A, U, I]( + ctx: DecodeSuccessContext[Task, A, U, I] + )(implicit monad: MonadError[Task], bodyListener: BodyListener[Task, B]): Task[ServerResponse[B]] = { + val spanName = ctx.endpoint.showShort + val request = ctx.request + val traceHeaders = TraceHeaders.of(request.headers.map(h => (h.name, h.value))*) + tracer.fromHeaders(traceHeaders, name = spanName, kind = SpanKind.Server) { span => + val logTraceContext = + if (enrichLogs) ZIOAspect.annotated(annotations = extractKVHeaders(span, headerFormat).toList*) + else ZIOAspect.identity + + enrichSpanFromRequest(request, dropHeadersWhen, span) *> + (endpointHandler.onDecodeSuccess(ctx) @@ logTraceContext) + .foldZIO( + error => span.setStatus(SpanStatus.Internal(error.toString)) *> ZIO.fail(error), + serverResponse => + enrichSpanFromResponse(serverResponse, dropHeadersWhen, span).as( + if (enrichResponseHeadersWithTraceIds) serverResponse.addHeaders(toHttpHeaders(span, headerFormat)) + else serverResponse + ) + ) + } + } + + override def onSecurityFailure[A]( + ctx: SecurityFailureContext[Task, A] + )(implicit monad: MonadError[Task], bodyListener: BodyListener[Task, B]): Task[ServerResponse[B]] = + endpointHandler.onSecurityFailure(ctx) + + override def onDecodeFailure( + ctx: DecodeFailureContext + )(implicit monad: MonadError[Task], bodyListener: BodyListener[Task, B]): Task[Option[ServerResponse[B]]] = + endpointHandler.onDecodeFailure(ctx) } private def toHttpHeaders(span: ZSpan, whichHeaders: ToHeaders): Seq[Header] = @@ -115,40 +149,21 @@ final class TraceInterceptor private ( private def requestFields( hs: Seq[Header], dropHeadersWhen: String => Boolean - ): List[(String, AttributeValue)] = + ): Seq[(String, AttributeValue)] = headerFields(hs, "req", dropHeadersWhen) private def responseFields( hs: Seq[Header], dropHeadersWhen: String => Boolean - ): List[(String, AttributeValue)] = + ): Seq[(String, AttributeValue)] = headerFields(hs, "resp", dropHeadersWhen) private def headerFields( hs: Seq[Header], `type`: String, dropHeadersWhen: String => Boolean - ): List[(String, AttributeValue)] = - hs.filter(h => !dropHeadersWhen(h.name)) - .map { h => - (s"${`type`}.header.${h.name}", h.value: AttributeValue) - } - .toList -} -object TraceInterceptor { - def apply( - tracer: ZTracer, - dropHeadersWhen: String => Boolean = HeaderNames.isSensitive, - spanNamer: PartialFunction[ServerRequest, String] = _.showShort, - enrichResponseHeadersWithTraceIds: Boolean = true, - enrichLogs: Boolean = true, - headerFormat: ToHeaders = ToHeaders.standard - ): TraceInterceptor = new TraceInterceptor( - tracer, - dropHeadersWhen, - spanNamer, - enrichResponseHeadersWithTraceIds, - enrichLogs, - headerFormat - ) + ): Seq[(String, AttributeValue)] = + hs.filter(h => !dropHeadersWhen(h.name)).map { h => + (s"${`type`}.header.${h.name}", h.value: AttributeValue) + } } diff --git a/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala b/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala index 959d685..76f2441 100644 --- a/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala +++ b/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala @@ -10,17 +10,23 @@ import sttp.tapir.ztapir.* import trace4cats.{ToHeaders, TraceProcess} import zio.interop.catz.* import zio.test.* -import zio.{Scope, Task, ZIOAspect} +import zio.{Scope, Task} object TraceInterceptorSpec extends ZIOSpecDefault { final class TestEndpoint(tracer: ZTracer) { private val testEndpoint = endpoint.get .in("hello") + .in(path[String]("name")) + .in("greeting") .out(statusCode(StatusCode.Ok)) val serverLogic: ZServerEndpoint[Any, Any] = - testEndpoint.zServerLogic(_ => tracer.retrieveCurrentSpan.tap(_.put("hello", "hello")).unit) + testEndpoint.zServerLogic(name => + tracer.withSpan("moshi") { span => + span.put("hello", name).unit + } + ) } def spec: Spec[TestEnvironment & Scope, Throwable] = suite("TraceInterceptor specification")( @@ -36,13 +42,13 @@ object TraceInterceptorSpec extends ZIOSpecDefault { .default[Task] .prependInterceptor(interceptor) ).toRoutes(endpoint.serverLogic).orNotFound - response <- httpApp.run(Request(uri = uri"/hello")) + response <- httpApp.run(Request(uri = uri"/hello/cal/greeting")) spans <- sc.retrieveCollected - _ = println(response) - _ = println(spans) } yield assertTrue( response.headers.get(CIString("traceparent")).isDefined, - spans.exists(_.name == "GET /hello"), + response.status == Status.Ok, + spans.exists(_.name == "GET /hello/{name}/greeting"), + spans.find(_.name == "moshi").exists(_.context.parent.isDefined), spans.exists(_.attributes.contains("hello")) ) } From 8b538681320623bc48c774bef4f6d921573ff09c Mon Sep 17 00:00:00 2001 From: calvinlfer Date: Tue, 12 Dec 2023 11:07:47 -0500 Subject: [PATCH 3/3] Fix documentation --- build.sbt | 1 + docs/Integrations/Database/virgil.md | 5 +++ docs/Integrations/HTTP/tapir.md | 34 +++++++++++-------- .../extras/tapir/TraceInterceptorSpec.scala | 4 +-- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/build.sbt b/build.sbt index 68021f6..f17b629 100644 --- a/build.sbt +++ b/build.sbt @@ -455,6 +455,7 @@ lazy val docs = sttp, sttpExample, tapir, + tapirExample, virgil, doobie, skunk, diff --git a/docs/Integrations/Database/virgil.md b/docs/Integrations/Database/virgil.md index 46a6e0c..e4ac366 100644 --- a/docs/Integrations/Database/virgil.md +++ b/docs/Integrations/Database/virgil.md @@ -14,11 +14,16 @@ import com.datastax.oss.driver.api.core.CqlSessionBuilder import io.kaizensolutions.trace4cats.zio.extras.ZTracer import io.kaizensolutions.virgil.trace4cats.zio.extras.TracedCQLExecutor import io.kaizensolutions.virgil.* +import io.kaizensolutions.virgil.codecs.* import io.kaizensolutions.virgil.cql.* import io.kaizensolutions.virgil.dsl.* import zio.* case class Person(id: Int, age: Int, name: String) +object Person { + // explicit declaration needed for Scala 3 + implicit val personCodec: CqlRowDecoder.Object[Person] = CqlRowDecoder.derive[Person] +} val insert: ZIO[CQLExecutor, Throwable, Unit] = for { diff --git a/docs/Integrations/HTTP/tapir.md b/docs/Integrations/HTTP/tapir.md index 4841c3b..1d5ad3f 100644 --- a/docs/Integrations/HTTP/tapir.md +++ b/docs/Integrations/HTTP/tapir.md @@ -9,18 +9,19 @@ This process requires each Tapir endpoint to reveal headers that hold trace info ```scala mdoc:compile-only import io.kaizensolutions.trace4cats.zio.extras.ZTracer -import io.kaizensolutions.trace4cats.zio.extras.tapir.TapirServerTracer +import io.kaizensolutions.trace4cats.zio.extras.tapir.TraceInterceptor import sttp.tapir.* import sttp.tapir.server.ServerEndpoint -import sttp.model.{Header, Headers, StatusCode} +import sttp.tapir.server.http4s.{Http4sServerInterpreter, Http4sServerOptions} +import sttp.model.StatusCode import zio.* +import zio.interop.catz.* final class CountCharactersEndpoint(tracer: ZTracer) { - private val countCharactersEndpoint: Endpoint[Unit, (String, List[Header]), Unit, Int, Any] = + private val countCharactersEndpoint: Endpoint[Unit, String, Unit, Int, Any] = endpoint.post .in("count" / "characters") .in(stringBody) - .in(headers) .errorOut(statusCode(StatusCode.BadRequest)) .out(plainBody[Int]) @@ -29,17 +30,22 @@ final class CountCharactersEndpoint(tracer: ZTracer) { else ZIO.succeed(raw.length) } - val countCharactersServerEndpoint: ServerEndpoint.Full[Unit, Unit, (String, List[Header]), Unit, Int, Any, Task] = - countCharactersEndpoint.serverLogic { case (raw, _) => countCharactersServerLogic(raw).either } - - val tracedEndpoint: ServerEndpoint.Full[Unit, Unit, (String, List[Header]), Unit, Int, Any, Task] = - TapirServerTracer.traceEndpoint( - tracer = tracer, - serverEndpoint = countCharactersServerEndpoint, - extractRequestHeaders = (input: (String, Seq[Header])) => Headers(input._2.toList), - extractResponseHeaders = (_: Int) => Headers(Nil) - ) + val countCharactersServerEndpoint: ServerEndpoint.Full[Unit, Unit, String, Unit, Int, Any, Task] = + countCharactersEndpoint.serverLogic { raw => countCharactersServerLogic(raw).either } } + +val http4sApp = + for { + tracer <- ZIO.service[ZTracer] + interceptor = TraceInterceptor(tracer) + endpoint = new CountCharactersEndpoint(tracer) + httpApp = Http4sServerInterpreter[Task]( + Http4sServerOptions + .default[Task] + .prependInterceptor(interceptor) + ).toRoutes(endpoint.countCharactersServerEndpoint).orNotFound + } yield httpApp + ``` The tracedEndpoint can then be used when compiling your Tapir endpoints down to the server's representation. diff --git a/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala b/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala index 76f2441..cbe2c27 100644 --- a/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala +++ b/tapir/src/test/scala/io/kaizensolutions/trace4cats/zio/extras/tapir/TraceInterceptorSpec.scala @@ -6,7 +6,7 @@ import org.http4s.syntax.all.* import org.typelevel.ci.CIString import sttp.model.StatusCode import sttp.tapir.server.http4s.{Http4sServerInterpreter, Http4sServerOptions} -import sttp.tapir.ztapir.* +import sttp.tapir.ztapir.{path as pathParam, *} import trace4cats.{ToHeaders, TraceProcess} import zio.interop.catz.* import zio.test.* @@ -17,7 +17,7 @@ object TraceInterceptorSpec extends ZIOSpecDefault { private val testEndpoint = endpoint.get .in("hello") - .in(path[String]("name")) + .in(pathParam[String]("name")) .in("greeting") .out(statusCode(StatusCode.Ok))