From 1d1493da02c925d9bfe0ca36b577b035f4011e28 Mon Sep 17 00:00:00 2001 From: jilen Date: Wed, 9 Oct 2024 11:35:38 +0800 Subject: [PATCH] handle type params in constructor --- .../internals/ConstructorCrimper.scala | 62 ++++++++++++------- .../test-cases/simpleGenericClass.success | 14 +++++ 2 files changed, 52 insertions(+), 24 deletions(-) create mode 100644 tests/src/test/resources/test-cases/simpleGenericClass.success diff --git a/macros/src/main/scala-3/com/softwaremill/macwire/internals/ConstructorCrimper.scala b/macros/src/main/scala-3/com/softwaremill/macwire/internals/ConstructorCrimper.scala index 230291f8..10ec1627 100644 --- a/macros/src/main/scala-3/com/softwaremill/macwire/internals/ConstructorCrimper.scala +++ b/macros/src/main/scala-3/com/softwaremill/macwire/internals/ConstructorCrimper.scala @@ -51,7 +51,28 @@ private[macwire] class ConstructorCrimper[Q <: Quotes, T: Type](using val q: Q)( ctor } - lazy val constructorParamLists: Option[List[List[Symbol]]] = constructor.map(_.paramSymss) + private def constructorParamTypes(ctorType: TypeRepr): List[List[TypeRepr]] = { + ctorType match { + case MethodType(_, paramTypes, retType) => + paramTypes.map(_.widen.simplified) :: constructorParamTypes(retType.simplified) + case _ => + Nil + } + } + + lazy val constructorParamLists: Option[List[List[(Symbol, TypeRepr)]]] = { + constructor.map { c => + // paramSymss contains both type arg symbols (generic types) and value arg symbols + val symLists = c.paramSymss.filter(_.forall(!_.isTypeDef)) + val ctorType = + if (targetType.typeArgs.isEmpty) targetType.memberType(c) + else targetType.memberType(c).appliedTo(targetType.typeArgs) + val typeLists = constructorParamTypes(ctorType) + symLists.zip(typeLists).map { case (syms, tpes) => + syms.zip(tpes) + } + } + } lazy val constructorArgs: Option[List[List[Term]]] = log.withBlock("Looking for targetConstructor arguments") { constructorParamLists.map(wireConstructorParamsWithImplicitLookups) @@ -62,35 +83,28 @@ private[macwire] class ConstructorCrimper[Q <: Quotes, T: Type](using val q: Q)( constructorValue <- constructor constructorArgsValue <- constructorArgs } yield { - val constructionMethodTree: Term = Select(New(TypeIdent(targetType.typeSymbol)), constructorValue) + val constructionMethodTree: Term = { + val ctor = Select(New(TypeIdent(targetType.typeSymbol)), constructorValue) + if (targetType.typeArgs.isEmpty) ctor else ctor.appliedToTypes(targetType.typeArgs) + } constructorArgsValue.foldLeft(constructionMethodTree)((acc: Term, args: List[Term]) => Apply(acc, args)) } } - def wireConstructorParams(paramLists: List[List[Symbol]]): List[List[Term]] = - paramLists.map(_.map(p => dependencyResolver.resolve(p, /*SI-4751*/ paramType(p)))) - - def wireConstructorParamsWithImplicitLookups(paramLists: List[List[Symbol]]): List[List[Term]] = paramLists.map { - case params if params.forall(_.flags is Flags.Implicit) => params.map(resolveImplicitOrFail) - case params => params.map(p => dependencyResolver.resolve(p, /*SI-4751*/ paramType(p))) - } - - private def resolveImplicitOrFail(param: Symbol): Term = Implicits.search(paramType(param)) match { - case iss: ImplicitSearchSuccess => iss.tree - case isf: ImplicitSearchFailure => report.throwError(s"Failed to resolve an implicit for [$param].") - } + def wireConstructorParams(paramLists: List[List[(Symbol, TypeRepr)]]): List[List[Term]] = + paramLists.map(_.map(p => dependencyResolver.resolve(p._1, /*SI-4751*/ p._2))) - private def paramType(param: Symbol): TypeRepr = { - // val (sym: Symbol, tpeArgs: List[Type]) = targetTypeD match { - // case TypeRef(_, sym, tpeArgs) => (sym, tpeArgs) - // case t => abort(s"Target type not supported for wiring: $t. Please file a bug report with your use-case.") - // } - // val pTpe = param.signature.substituteTypes(sym.asClass.typeParams, tpeArgs) - // if (param.asTerm.isByNameParam) pTpe.typeArgs.head else pTpe + def wireConstructorParamsWithImplicitLookups(paramLists: List[List[(Symbol, TypeRepr)]]): List[List[Term]] = + paramLists.map { + case params if params.forall(_._1.flags is Flags.Implicit) => params.map(resolveImplicitOrFail) + case params => params.map(p => dependencyResolver.resolve(p._1, /*SI-4751*/ p._2)) + } - //FIXME assertion error in test inheritanceHKT.success, selfTypeHKT.success - Ref(param).tpe.widen - } + private def resolveImplicitOrFail(param: Symbol, paramType: TypeRepr): Term = + Implicits.search(paramType) match { + case iss: ImplicitSearchSuccess => iss.tree + case isf: ImplicitSearchFailure => report.throwError(s"Failed to resolve an implicit for [$param].") + } /** In some cases there is one extra (phantom) constructor. This happens when extended trait has implicit param: * diff --git a/tests/src/test/resources/test-cases/simpleGenericClass.success b/tests/src/test/resources/test-cases/simpleGenericClass.success new file mode 100644 index 00000000..26112b5c --- /dev/null +++ b/tests/src/test/resources/test-cases/simpleGenericClass.success @@ -0,0 +1,14 @@ +#include commonHKTClasses + +case class LocalA[X]() +case class LocalB[X](la: LocalA[X]) + +object Test { + val la = wire[LocalA[Int]] + val lb = wire[LocalB[Int]] + val a = wire[A[IO]] + val b = wire[B[IO]] +} + +require(Test.lb.la == Test.la) +require(Test.b.a == Test.a)