@@ -14,23 +14,28 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
1414 val phaseName = " specializeFunction1"
1515
1616 // Setup ---------------------------------------------------------------------
17- private [this ] val functionName = " JFunction1 " .toTermName
17+ private [this ] val functionName = " JFunction " .toTermName
1818 private [this ] val functionPkg = " scala.compat.java8." .toTermName
1919 private [this ] var argTypes : Set [Symbol ] = _
2020 private [this ] var retTypes : Set [Symbol ] = _
2121
2222 override def prepareForUnit (tree : Tree )(implicit ctx : Context ) = {
23- retTypes = Set (defn.BooleanClass ,
24- defn.DoubleClass ,
25- defn.FloatClass ,
23+ retTypes = Set (defn.UnitClass ,
24+ defn.BooleanClass ,
2625 defn.IntClass ,
26+ defn.FloatClass ,
2727 defn.LongClass ,
28- defn.UnitClass )
28+ defn.DoubleClass ,
29+ /* only for Function0: */
30+ defn.ByteClass ,
31+ defn.ShortClass ,
32+ defn.CharClass )
2933
30- argTypes = Set (defn.DoubleClass ,
31- defn.FloatClass ,
32- defn.IntClass ,
33- defn.LongClass )
34+ argTypes = Set (defn.IntClass ,
35+ defn.LongClass ,
36+ defn.DoubleClass ,
37+ /* only for Function1: */
38+ defn.FloatClass )
3439 this
3540 }
3641
@@ -40,37 +45,46 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
4045 * they instead extend the specialized version `JFunction$mp...`
4146 */
4247 def transform (ref : SingleDenotation )(implicit ctx : Context ) = ref match {
43- case ShouldTransformDenot (cref, t1, r, func1) => {
44- val specializedFunction : Symbol =
45- ctx.getClassIfDefined(functionPkg ++ specializedName(functionName, t1, r))
46-
47- def replaceFunction1 (in : List [TypeRef ]): List [TypeRef ] =
48- in.mapConserve { tp =>
49- if (tp.isRef(defn.FunctionClass (1 )) && (specializedFunction ne NoSymbol ))
50- specializedFunction.typeRef
51- else tp
48+ case cref @ ShouldTransformDenot (targets) => {
49+ val specializedSymbols : Map [Symbol , (Symbol , Symbol )] = (for (SpecializationTarget (target, args, ret, original) <- targets) yield {
50+ val arity = args.length
51+ val specializedParent = ctx.getClassIfDefined {
52+ functionPkg ++ specializedName(functionName ++ arity, args, ret)
5253 }
5354
54- def specializeApply (scope : Scope ): Scope =
55- if ((specializedFunction ne NoSymbol ) && (scope.lookup(nme.apply) ne NoSymbol )) {
56- def specializedApply : Symbol = {
57- val specializedMethodName = specializedName(nme.apply, t1, r)
58- ctx.newSymbol(
59- cref.symbol,
60- specializedMethodName,
61- Flags .Override | Flags .Method ,
62- specializedFunction.info.decls.lookup(specializedMethodName).info
63- )
64- }
55+ val specializedApply : Symbol = {
56+ val specializedMethodName = specializedName(nme.apply, args, ret)
57+ ctx.newSymbol(
58+ cref.symbol,
59+ specializedMethodName,
60+ Flags .Override | Flags .Method ,
61+ specializedParent.info.decls.lookup(specializedMethodName).info
62+ )
63+ }
64+
65+ original -> (specializedParent, specializedApply)
66+ }).toMap
6567
66- val alteredScope = scope.cloneScope
67- alteredScope.enter(specializedApply)
68- alteredScope
68+ def specializeApplys (scope : Scope ): Scope = {
69+ val alteredScope = scope.cloneScope
70+ specializedSymbols.values.foreach { case (_, apply) =>
71+ alteredScope.enter(apply)
72+ }
73+ alteredScope
74+ }
75+
76+ def replace (in : List [TypeRef ]): List [TypeRef ] =
77+ in.map { tref =>
78+ val sym = tref.symbol
79+ specializedSymbols.get(sym).map { case (specializedParent, _) =>
80+ specializedParent.typeRef
81+ }
82+ .getOrElse(tref)
6983 }
70- else scope
7184
7285 val ClassInfo (prefix, cls, parents, decls, info) = cref.classInfo
73- val newInfo = ClassInfo (prefix, cls, replaceFunction1(in = parents), specializeApply(decls), info)
86+ val newParents = replace(parents)
87+ val newInfo = ClassInfo (prefix, cls, newParents, specializeApplys(decls), info)
7488 cref.copySymDenotation(info = newInfo)
7589 }
7690 case _ => ref
@@ -84,37 +98,51 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
8498 */
8599 override def transformTemplate (tree : Template )(implicit ctx : Context , info : TransformerInfo ) =
86100 tree match {
87- case tmpl @ ShouldTransformTree (func1, t1, r) => {
88- val specializedFunc1 =
89- TypeTree (ctx.requiredClassRef(functionPkg ++ specializedName(functionName, t1, r)))
101+ case tmpl @ ShouldTransformTree (targets) => {
102+ val symbolMap = (for ((tree, SpecializationTarget (target, args, ret, orig)) <- targets) yield {
103+ val arity = args.length
104+ val specializedParent = TypeTree {
105+ ctx.requiredClassRef(functionPkg ++ specializedName(functionName ++ arity, args, ret))
106+ }
107+ val specializedMethodName = specializedName(nme.apply, args, ret)
108+ val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
90109
91- val parents = tmpl.parents.mapConserve { t =>
92- if (func1.isDefined && (func1.get eq t)) specializedFunc1 else t
93- }
110+ orig -> (specializedParent, specializedApply)
111+ }).toMap
94112
95- val body = tmpl.body.foldRight(List .empty[Tree ]) {
113+ val body0 = tmpl.body.foldRight(List .empty[Tree ]) {
96114 case (tree : DefDef , acc) if tree.name == nme.apply => {
97- val specializedMethodName = specializedName(nme.apply, t1, r)
98- val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName).asTerm
99-
100- val forwardingBody =
101- tpd.ref(specializedApply)
102- .appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
103-
104- val applyWithForwarding = cpy.DefDef (tree)(rhs = forwardingBody)
105-
106- val specializedApplyDefDef = polyDefDef(specializedApply, trefs => vrefss => {
107- tree.rhs
108- .changeOwner(tree.symbol, specializedApply)
109- .subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
110- })
111-
112- applyWithForwarding :: specializedApplyDefDef :: acc
115+ val inheritedFrom =
116+ tree.symbol.allOverriddenSymbols
117+ .map(_.owner)
118+ .map(symbolMap.get)
119+ .flatten
120+ .toList
121+ .headOption
122+
123+ inheritedFrom.map { case (parent, apply) =>
124+ val forwardingBody = tpd
125+ .ref(apply)
126+ .appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
127+
128+ val applyWithForwarding = cpy.DefDef (tree)(rhs = forwardingBody)
129+
130+ val specializedApplyDefDef =
131+ polyDefDef(apply, trefs => vrefss => {
132+ tree.rhs
133+ .changeOwner(tree.symbol, apply)
134+ .subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
135+ })
136+
137+ applyWithForwarding :: specializedApplyDefDef :: acc
138+ }
139+ .getOrElse(tree :: acc)
113140 }
114141 case (tree, acc) => tree :: acc
115142 }
143+ val parents = symbolMap.map { case (_, (parent, _)) => parent }
116144
117- cpy.Template (tmpl)(parents = parents, body = body )
145+ cpy.Template (tmpl)(parents = parents.toList , body = body0 )
118146 }
119147 case _ => tree
120148 }
@@ -136,28 +164,60 @@ class SpecializeFunction1 extends MiniPhaseTransform with DenotTransformer {
136164 }
137165 }
138166
139- private def specializedName (name : Name , t1 : Type , r : Type )(implicit ctx : Context ): Name =
140- name.specializedFor(List (t1, r), List (t1, r).map(_.typeSymbol.name), Nil , Nil )
167+ private def specializedName (name : Name , args : List [Type ], ret : Type )(implicit ctx : Context ): Name = {
168+ val typeParams = args :+ ret
169+ name.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
170+ }
141171
142172 // Extractors ----------------------------------------------------------------
143173 private object ShouldTransformDenot {
144- def unapply (cref : ClassDenotation )(implicit ctx : Context ): Option [( ClassDenotation , Type , Type , Type ) ] =
145- if (! cref.classParents.exists (_.isRef (defn.FunctionClass ( 1 )) )) None
146- else getFunc1( cref.typeRef).map { case (t1, r, func1) => (cref, t1, r, func1) }
174+ def unapply (cref : ClassDenotation )(implicit ctx : Context ): Option [Seq [ SpecializationTarget ] ] =
175+ if (! cref.classParents.map (_.symbol).exists (defn.isFunctionClass )) None
176+ else Some (getSpecTargets( cref.typeRef))
147177 }
148178
149179 private object ShouldTransformTree {
150- def unapply (tree : Template )(implicit ctx : Context ): Option [(Option [Tree ], Type , Type )] =
151- tree.parents.find(_.tpe.isRef(defn.FunctionClass (1 ))).flatMap { t =>
152- getFunc1(t.tpe).map { case (t1, r, _) => (Some (t), t1, r) }
153- }
180+ def unapply (tree : Template )(implicit ctx : Context ): Option [Seq [(Tree , SpecializationTarget )]] = {
181+ val treeToTargets = tree.parents
182+ .map(t => (t, getSpecTargets(t.tpe)))
183+ .filter(_._2.nonEmpty)
184+ .map { case (t, xs) => (t, xs.head) }
185+
186+ if (treeToTargets.isEmpty) None else Some (treeToTargets)
187+ }
154188 }
155189
156- private def getFunc1 (tpe : Type )(implicit ctx : Context ): Option [(Type , Type , Type )] =
157- tpe.baseTypeWithArgs(defn.FunctionClass (1 )) match {
158- case func1 @ RefinedType (RefinedType (parent, _, t1), _, r) if (
159- argTypes.contains(t1.typeSymbol) && retTypes.contains(r.typeSymbol)
160- ) => Some ((t1, r, func1))
161- case _ => None
190+ private case class SpecializationTarget (target : Symbol ,
191+ params : List [Type ],
192+ ret : Type ,
193+ original : Symbol )
194+
195+ /** Gets all valid specialization targets on `tpe`, allowing multiple
196+ * implementations of FunctionX traits
197+ */
198+ private def getSpecTargets (tpe : Type )(implicit ctx : Context ): List [SpecializationTarget ] = {
199+ val functionParents =
200+ tpe.classSymbols.iterator
201+ .flatMap(_.baseClasses)
202+ .filter(defn.isFunctionClass)
203+
204+ val tpeCls = tpe.widenDealias
205+ functionParents.map { sym =>
206+ val typeParams = tpeCls.baseArgTypes(sym)
207+ val args = typeParams.init
208+ val ret = typeParams.last
209+
210+ val interfaceName =
211+ (functionName ++ args.length)
212+ .specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
213+
214+ val interface = ctx.getClassIfDefined(functionPkg ++ interfaceName)
215+
216+ if (interface.exists) Some {
217+ SpecializationTarget (interface, args, ret, sym)
218+ }
219+ else None
162220 }
221+ .flatten.toList
222+ }
163223}
0 commit comments