Skip to content

Commit

Permalink
Merge pull request #35 from kaizen-solutions/tapir
Browse files Browse the repository at this point in the history
Implement tracing RequestInterceptor for Tapir
  • Loading branch information
calvinlfer authored Dec 12, 2023
2 parents 8879387 + 8b53868 commit ce24507
Show file tree
Hide file tree
Showing 8 changed files with 276 additions and 149 deletions.
18 changes: 12 additions & 6 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -291,16 +291,21 @@ lazy val tapir =
.in(file("tapir"))
.settings(kindProjectorSettings*)
.settings(
name := "trace4cats-zio-extras-tapir",
libraryDependencies += "com.softwaremill.sttp.tapir" %% "tapir-core" % Versions.tapir
name := "trace4cats-zio-extras-tapir",
libraryDependencies ++=
Seq(
"com.softwaremill.sttp.tapir" %% "tapir-server" % Versions.tapir,
"com.softwaremill.sttp.tapir" %% "tapir-zio" % Versions.tapir % Test,
"com.softwaremill.sttp.tapir" %% "tapir-http4s-server" % Versions.tapir % Test
)
)
.dependsOn(core)
.dependsOn(core % "compile->compile;test->test")

lazy val tapirExample =
project
.in(file("tapir-examples"))
.settings(
name := "trace4cats-zio-extras-zio-sttp-examples",
name := "trace4cats-zio-extras-zio--sttp-examples",
publish / skip := true,
libraryDependencies ++=
Seq(
Expand Down Expand Up @@ -375,7 +380,7 @@ lazy val skunk =
"org.tpolecat" %% "skunk-core" % Versions.skunk,
"io.zonky.test" % "embedded-postgres" % Versions.embeddedPostgres % Test,
"dev.zio" %% "zio-logging-slf4j" % Versions.zioLogging % Test,
"ch.qos.logback" % "logback-classic" % "1.4.11" % Test
"ch.qos.logback" % "logback-classic" % "1.4.14" % Test
)
)
.dependsOn(
Expand Down Expand Up @@ -403,7 +408,7 @@ lazy val zioKafka =
"dev.zio" %% "zio-kafka" % Versions.zioKafka,
"io.github.embeddedkafka" %% "embedded-kafka" % Versions.kafkaEmbedded % Test,
"dev.zio" %% "zio-logging-slf4j" % Versions.zioLogging % Test,
"ch.qos.logback" % "logback-classic" % "1.4.11" % Test
"ch.qos.logback" % "logback-classic" % "1.4.14" % Test
),
excludeDependencies ++= {
CrossVersion.partialVersion(scalaVersion.value) match {
Expand Down Expand Up @@ -450,6 +455,7 @@ lazy val docs =
sttp,
sttpExample,
tapir,
tapirExample,
virgil,
doobie,
skunk,
Expand Down
5 changes: 5 additions & 0 deletions docs/Integrations/Database/virgil.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@ import com.datastax.oss.driver.api.core.CqlSessionBuilder
import io.kaizensolutions.trace4cats.zio.extras.ZTracer
import io.kaizensolutions.virgil.trace4cats.zio.extras.TracedCQLExecutor
import io.kaizensolutions.virgil.*
import io.kaizensolutions.virgil.codecs.*
import io.kaizensolutions.virgil.cql.*
import io.kaizensolutions.virgil.dsl.*
import zio.*

case class Person(id: Int, age: Int, name: String)
object Person {
// explicit declaration needed for Scala 3
implicit val personCodec: CqlRowDecoder.Object[Person] = CqlRowDecoder.derive[Person]
}

val insert: ZIO[CQLExecutor, Throwable, Unit] =
for {
Expand Down
34 changes: 20 additions & 14 deletions docs/Integrations/HTTP/tapir.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,19 @@ This process requires each Tapir endpoint to reveal headers that hold trace info

```scala mdoc:compile-only
import io.kaizensolutions.trace4cats.zio.extras.ZTracer
import io.kaizensolutions.trace4cats.zio.extras.tapir.TapirServerTracer
import io.kaizensolutions.trace4cats.zio.extras.tapir.TraceInterceptor
import sttp.tapir.*
import sttp.tapir.server.ServerEndpoint
import sttp.model.{Header, Headers, StatusCode}
import sttp.tapir.server.http4s.{Http4sServerInterpreter, Http4sServerOptions}
import sttp.model.StatusCode
import zio.*
import zio.interop.catz.*

final class CountCharactersEndpoint(tracer: ZTracer) {
private val countCharactersEndpoint: Endpoint[Unit, (String, List[Header]), Unit, Int, Any] =
private val countCharactersEndpoint: Endpoint[Unit, String, Unit, Int, Any] =
endpoint.post
.in("count" / "characters")
.in(stringBody)
.in(headers)
.errorOut(statusCode(StatusCode.BadRequest))
.out(plainBody[Int])

Expand All @@ -29,17 +30,22 @@ final class CountCharactersEndpoint(tracer: ZTracer) {
else ZIO.succeed(raw.length)
}

val countCharactersServerEndpoint: ServerEndpoint.Full[Unit, Unit, (String, List[Header]), Unit, Int, Any, Task] =
countCharactersEndpoint.serverLogic { case (raw, _) => countCharactersServerLogic(raw).either }

val tracedEndpoint: ServerEndpoint.Full[Unit, Unit, (String, List[Header]), Unit, Int, Any, Task] =
TapirServerTracer.traceEndpoint(
tracer = tracer,
serverEndpoint = countCharactersServerEndpoint,
extractRequestHeaders = (input: (String, Seq[Header])) => Headers(input._2.toList),
extractResponseHeaders = (_: Int) => Headers(Nil)
)
val countCharactersServerEndpoint: ServerEndpoint.Full[Unit, Unit, String, Unit, Int, Any, Task] =
countCharactersEndpoint.serverLogic { raw => countCharactersServerLogic(raw).either }
}

val http4sApp =
for {
tracer <- ZIO.service[ZTracer]
interceptor = TraceInterceptor(tracer)
endpoint = new CountCharactersEndpoint(tracer)
httpApp = Http4sServerInterpreter[Task](
Http4sServerOptions
.default[Task]
.prependInterceptor(interceptor)
).toRoutes(endpoint.countCharactersServerEndpoint).orNotFound
} yield httpApp

```

The tracedEndpoint can then be used when compiling your Tapir endpoints down to the server's representation.
Expand Down
8 changes: 4 additions & 4 deletions project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ object Versions {
val trace4CatsJaegarExporter = "0.14.0"
val trace4CatsHttp4sCommon = "0.14.1"
val zio = "2.0.19"
val zioLogging = "2.1.15"
val zioLogging = "2.1.16"
val zioInteropCats = "23.1.0.0"
val zioHttp = "3.0.0-RC3"
val http4s = "0.23.24"
val tapir = "1.9.1"
val tapir = "1.9.4"
val sttp = "3.9.1"
val virgil = "1.0.4"
val doobie = "1.0.0-RC5"
val skunk = "0.6.1"
val embeddedPostgres = "2.0.4"
val skunk = "0.6.2"
val embeddedPostgres = "2.0.6"
val scribe = "3.12.2"
val zioKafka = "2.7.0"
val kafkaEmbedded = "3.6.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ import fs2.io.net.Network
import io.circe.Codec as CirceCodec
import io.circe.generic.semiauto.deriveCodec
import io.kaizensolutions.trace4cats.zio.extras.ZTracer
import io.kaizensolutions.trace4cats.zio.extras.tapir.TapirServerTracer
import io.kaizensolutions.trace4cats.zio.extras.tapir.TraceInterceptor
import org.http4s.ember.server.EmberServerBuilder
import sttp.model.{Headers, StatusCode}
import sttp.tapir.*
import sttp.tapir.json.circe.*
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.http4s.Http4sServerInterpreter
import sttp.tapir.server.http4s.{Http4sServerInterpreter, Http4sServerOptions}
import trace4cats.kernel.ToHeaders
import zio.*
import zio.interop.catz.*

Expand All @@ -22,7 +23,7 @@ object ExampleServerApp extends ZIOAppDefault {
def countCharacters(tracer: ZTracer)(in: Request): UIO[Either[NoCharacters, Int]] = {
val l = in.input.length
val out = tracer.spanSource() {
if (l > 0) ZIO.succeed(l)
if (l > 0) ZIO.logInfo(s"Received ${in.input}").as(l)
else ZIO.fail(NoCharacters("Please supply at least 1 character to count"))
}

Expand All @@ -41,22 +42,16 @@ object ExampleServerApp extends ZIOAppDefault {
def serverEndpoint(tracer: ZTracer): ServerEndpoint.Full[Unit, Unit, Request, NoCharacters, Int, Any, Task] =
countCharactersEndpoint.serverLogic(countCharacters(tracer))

def tracedServerEndpoint(tracer: ZTracer): ServerEndpoint.Full[Unit, Unit, Request, NoCharacters, Int, Any, Task] =
TapirServerTracer
.traceEndpoint[Request, NoCharacters, Int, Any, Any, Throwable](
tracer = tracer,
serverEndpoint = serverEndpoint(tracer),
extractRequestHeaders = _.headers,
extractResponseHeaders = _ => Headers(Nil)
)

override val run: ZIO[ZIOAppArgs & Scope, Any, Any] = {
val program =
for {
tracer <- ZIO.service[ZTracer]
endpoint = tracedServerEndpoint(tracer)
httpApp = Http4sServerInterpreter[Task]().toRoutes(endpoint).orNotFound
port <- ZIO.fromEither(Port.fromInt(8080).toRight(new RuntimeException("Invalid Port")))
endpoint = serverEndpoint(tracer)
serverOptions = Http4sServerOptions
.default[Task]
.prependInterceptor(TraceInterceptor(tracer, headerFormat = ToHeaders.b3Single))
httpApp = Http4sServerInterpreter[Task](serverOptions).toRoutes(endpoint).orNotFound
port <- ZIO.fromEither(Port.fromInt(8080).toRight(new RuntimeException("Invalid Port")))
server <- EmberServerBuilder
.default[Task]
.withHostOption(Host.fromString("localhost"))
Expand Down

This file was deleted.

Loading

0 comments on commit ce24507

Please sign in to comment.