@@ -8,10 +8,13 @@ import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
88import Denotations ._ , SymDenotations ._ , Scopes ._ , StdNames ._ , NameOps ._ , Names ._
99import ast .tpd
1010
11- class SpecializeFunctions extends MiniPhaseTransform with DenotTransformer {
11+ import scala .collection .mutable
12+ import scala .annotation .tailrec
13+
14+ class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
1215 import ast .tpd ._
1316
14- val phaseName = " specializeFunction1 "
17+ val phaseName = " specializeFunctions "
1518
1619 // Setup ---------------------------------------------------------------------
1720 private [this ] val functionName = " JFunction" .toTermName
@@ -41,191 +44,115 @@ class SpecializeFunctions extends MiniPhaseTransform with DenotTransformer {
4144
4245 // Transformations -----------------------------------------------------------
4346
44- /** Transforms all classes extending `Function1[-T1, +R]` so that
45- * they instead extend the specialized version `JFunction$mp...`
46- */
47- def transform (ref : SingleDenotation )(implicit ctx : Context ) = ref match {
48- case cref @ ShouldTransformDenot (targets) => {
49- val specializedSymbols : Map [Symbol , (Symbol , Symbol )] = (for (SpecializationTarget (target, args, ret, original) <- targets) yield {
50- val arity = args.length
51- val specializedParent = ctx.getClassIfDefined {
52- functionPkg ++ specializedName(functionName ++ arity, args, ret)
53- }
47+ def transformInfo (tp : Type , sym : Symbol )(implicit ctx : Context ) = tp match {
48+ case tp : ClassInfo if ! sym.is(Flags .Package ) && (tp.decls ne EmptyScope ) => {
49+ val newDecls = tp.decls.cloneScope
50+ def newParents = tp.parents.mapConserve { parent =>
51+ if (defn.isPlainFunctionClass(parent.symbol)) {
52+ val typeParams = tp.typeRef.baseArgTypes(parent.classSymbol)
53+ val interface = specInterface(typeParams)
54+
55+ if (interface.exists) {
56+ val specializedApply : Symbol = {
57+ val specializedMethodName = specializedName(nme.apply, typeParams)
58+ ctx.newSymbol(
59+ sym,
60+ specializedMethodName,
61+ Flags .Override | Flags .Method ,
62+ interface.info.decls.lookup(specializedMethodName).info
63+ )
64+ }
5465
55- val specializedApply : Symbol = {
56- val specializedMethodName = specializedName(nme.apply, args, ret)
57- ctx.newSymbol(
58- cref.symbol,
59- specializedMethodName,
60- Flags .Override | Flags .Method ,
61- specializedParent.info.decls.lookup(specializedMethodName).info
62- )
66+ newDecls.enter(specializedApply)
67+ interface.typeRef
68+ }
69+ else parent
6370 }
71+ else parent
72+ }
6473
65- original -> (specializedParent, specializedApply )
66- }).toMap
74+ tp.derivedClassInfo(classParents = newParents, decls = newDecls )
75+ }
6776
68- def specializeApplys (scope : Scope ): Scope = {
69- val alteredScope = scope.cloneScope
70- specializedSymbols.values.foreach { case (_, apply) =>
71- alteredScope.enter(apply)
72- }
73- alteredScope
74- }
77+ case _ => tp
78+ }
7579
76- def replace (in : List [TypeRef ]): List [TypeRef ] =
77- in.map { tref =>
78- val sym = tref.symbol
79- specializedSymbols.get(sym).map { case (specializedParent, _) =>
80- specializedParent.typeRef
81- }
82- .getOrElse(tref)
80+ override def transformTemplate (tree : Template )(implicit ctx : Context , info : TransformerInfo ) = {
81+ val buf = new mutable.ListBuffer [Tree ]
82+ val newBody = tree.body.mapConserve {
83+ case dt : DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => {
84+ val specializedApply = ctx.owner.info.decls.lookup {
85+ specializedName(
86+ nme.apply,
87+ dt.vparamss.head.map(_.symbol.info) :+ dt.tpe.widen.finalResultType
88+ )
8389 }
8490
85- val ClassInfo (prefix, cls, parents, decls, info) = cref.classInfo
86- val newParents = replace(parents)
87- val newInfo = ClassInfo (prefix, cls, newParents, specializeApplys(decls), info)
88- cref.copySymDenotation(info = newInfo)
91+ if (specializedApply.exists) {
92+ val apply = specializedApply.asTerm
93+ val specializedDecl =
94+ polyDefDef(apply, trefs => vrefss => {
95+ dt.rhs
96+ .changeOwner(dt.symbol, apply)
97+ .subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
98+ })
99+
100+ buf += specializedDecl
101+
102+ // create a forwarding to the specialized apply
103+ cpy.DefDef (dt)(rhs = {
104+ tpd
105+ .ref(apply)
106+ .appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
107+ })
108+ } else dt
109+ }
110+ case x => x
89111 }
90- case _ => ref
91- }
92112
93- /** Transform the class definition's `Template`:
94- *
95- * - change the tree to have the correct parent
96- * - add the specialized apply method to the template body
97- * - forward the old `apply` to the specialized version
98- */
99- override def transformTemplate (tree : Template )(implicit ctx : Context , info : TransformerInfo ) =
100- tree match {
101- case tmpl @ ShouldTransformTree (targets) => {
102- val symbolMap = (for ((tree, SpecializationTarget (target, args, ret, orig)) <- targets) yield {
103- val arity = args.length
104- val specializedParent = TypeTree {
105- ctx.requiredClassRef(functionPkg ++ specializedName(functionName ++ arity, args, ret))
106- }
107- val specializedMethodName = specializedName(nme.apply, args, ret)
108- val specializedApply = ctx.owner.info.decls.lookup(specializedMethodName)
109-
110- if (specializedApply.exists)
111- Some (orig -> (specializedParent, specializedApply.asTerm))
112- else None
113- }).flatten.toMap
114-
115- val body0 = tmpl.body.foldRight(List .empty[Tree ]) {
116- case (tree : DefDef , acc) if tree.name == nme.apply => {
117- val inheritedFrom =
118- tree.symbol.allOverriddenSymbols
119- .map(_.owner)
120- .map(symbolMap.get)
121- .flatten
122- .toList
123- .headOption
124-
125- inheritedFrom.map { case (parent, apply) =>
126- val forwardingBody = tpd
127- .ref(apply)
128- .appliedToArgs(tree.vparamss.head.map(vparam => ref(vparam.symbol)))
129-
130- val applyWithForwarding = cpy.DefDef (tree)(rhs = forwardingBody)
131-
132- val specializedApplyDefDef =
133- polyDefDef(apply, trefs => vrefss => {
134- tree.rhs
135- .changeOwner(tree.symbol, apply)
136- .subst(tree.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
137- })
138-
139- applyWithForwarding :: specializedApplyDefDef :: acc
140- }
141- .getOrElse(tree :: acc)
142- }
143- case (tree, acc) => tree :: acc
144- }
113+ val newParents = tree.parents.mapConserve { parent =>
114+ if (defn.isPlainFunctionClass(parent.symbol)) {
115+ val typeParams = tree.tpe.baseArgTypes(parent.symbol)
116+ val interface = specInterface(typeParams)
145117
146- val specializedParents = tree.parents.map { t =>
147- symbolMap
148- .get(t.symbol)
149- .map { case (newSym, _) => newSym }
150- .getOrElse(t)
118+ if (interface.exists) TypeTree (interface.info)
119+ else parent
151120 }
152-
153- cpy.Template (tmpl)(parents = specializedParents, body = body0)
154- }
155- case _ => tree
121+ else parent
156122 }
157123
124+ cpy.Template (tree)(parents = newParents, body = buf.toList ++ newBody)
125+ }
126+
158127 /** Dispatch to specialized `apply`s in user code */
159128 override def transformApply (tree : Apply )(implicit ctx : Context , info : TransformerInfo ) = {
160129 import ast .Trees ._
161130 tree match {
162- case Apply (select @ Select (id, nme.apply), arg :: Nil ) =>
131+ case Apply (select @ Select (id, nme.apply), arg :: Nil ) => {
163132 val params = List (arg.tpe, tree.tpe)
164- val specializedApply = nme.apply.specializedFor(params , params.map(_.typeSymbol.name), Nil , Nil )
165- val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists { sym =>
166- sym.is(Flags .Override ) && (sym.name eq specializedApply)
133+ val specializedApply = specializedName( nme.apply, params)
134+ val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists {
135+ sym => sym .is(Flags .Override ) && (sym.name eq specializedApply)
167136 }
168137
169138 if (hasOverridenSpecializedApply) tpd.Apply (tpd.Select (id, specializedApply), arg :: Nil )
170139 else tree
140+ }
171141 case _ => tree
172142 }
173143 }
174144
175- private def specializedName (name : Name , args : List [Type ], ret : Type )(implicit ctx : Context ): Name = {
176- val typeParams = args :+ ret
177- name.specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
178- }
179-
180- // Extractors ----------------------------------------------------------------
181- private object ShouldTransformDenot {
182- def unapply (cref : ClassDenotation )(implicit ctx : Context ): Option [Seq [SpecializationTarget ]] =
183- if (! cref.classParents.map(_.symbol).exists(defn.isPlainFunctionClass)) None
184- else Some (getSpecTargets(cref.typeRef))
185- }
145+ @ inline private def specializedName (name : Name , args : List [Type ])(implicit ctx : Context ): Name =
146+ name.specializedFor(args, args.map(_.typeSymbol.name), Nil , Nil )
186147
187- private object ShouldTransformTree {
188- def unapply (tree : Template )(implicit ctx : Context ): Option [Seq [(Tree , SpecializationTarget )]] = {
189- val treeToTargets = tree.parents
190- .map(t => (t, getSpecTargets(t.tpe)))
191- .filter(_._2.nonEmpty)
192- .map { case (t, xs) => (t, xs.head) }
148+ @ inline private def specInterface (typeParams : List [Type ])(implicit ctx : Context ) = {
149+ val args = typeParams.init
150+ val ret = typeParams.last
193151
194- if (treeToTargets.isEmpty) None else Some (treeToTargets)
195- }
196- }
152+ val specName =
153+ (functionName ++ args.length)
154+ .specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
197155
198- private case class SpecializationTarget (target : Symbol ,
199- params : List [Type ],
200- ret : Type ,
201- original : Symbol )
202-
203- /** Gets all valid specialization targets on `tpe`, allowing multiple
204- * implementations of FunctionX traits
205- */
206- private def getSpecTargets (tpe : Type )(implicit ctx : Context ): List [SpecializationTarget ] = {
207- val functionParents =
208- tpe.classSymbols.iterator
209- .flatMap(_.baseClasses)
210- .filter(defn.isPlainFunctionClass)
211-
212- val tpeCls = tpe.widenDealias
213- functionParents.map { sym =>
214- val typeParams = tpeCls.baseArgTypes(sym)
215- val args = typeParams.init
216- val ret = typeParams.last
217-
218- val interfaceName =
219- (functionName ++ args.length)
220- .specializedFor(typeParams, typeParams.map(_.typeSymbol.name), Nil , Nil )
221-
222- val interface = ctx.getClassIfDefined(functionPkg ++ interfaceName)
223-
224- if (interface.exists) Some {
225- SpecializationTarget (interface, args, ret, sym)
226- }
227- else None
228- }
229- .flatten.toList
156+ ctx.getClassIfDefined(functionPkg ++ specName)
230157 }
231158}
0 commit comments