@@ -40,49 +40,74 @@ class BetaReduce extends MiniPhase:
4040
4141 override def transformApply (app : Apply )(using Context ): Tree = app.fun match
4242 case Select (fn, nme.apply) if defn.isFunctionType(fn.tpe) =>
43- val app1 = BetaReduce (app, fn, app.args)
43+ val app1 = BetaReduce (app, fn, List (app.args))
44+ if app1 ne app then report.log(i " beta reduce $app -> $app1" )
45+ app1
46+ case TypeApply (Select (fn, nme.apply), targs) if fn.tpe.typeSymbol eq defn.PolyFunctionClass =>
47+ val app1 = BetaReduce (app, fn, List (targs, app.args))
4448 if app1 ne app then report.log(i " beta reduce $app -> $app1" )
4549 app1
4650 case _ =>
4751 app
4852
49-
5053object BetaReduce :
5154 import ast .tpd ._
5255
5356 val name : String = " betaReduce"
5457 val description : String = " reduce closure applications"
5558
5659 /** Beta-reduces a call to `fn` with arguments `argSyms` or returns `tree` */
57- def apply (original : Tree , fn : Tree , args : List [Tree ])(using Context ): Tree =
60+ def apply (original : Tree , fn : Tree , argss : List [List [ Tree ] ])(using Context ): Tree =
5861 fn match
5962 case Typed (expr, _) =>
60- BetaReduce (original, expr, args )
63+ BetaReduce (original, expr, argss )
6164 case Block ((anonFun : DefDef ) :: Nil , closure : Closure ) =>
62- BetaReduce (anonFun, args)
65+ BetaReduce (anonFun, argss)
66+ case Block ((TypeDef (_, template : Template )) :: Nil , Typed (Apply (Select (New (_), _), _), _)) if template.constr.rhs.isEmpty =>
67+ template.body match
68+ case (anonFun : DefDef ) :: Nil =>
69+ BetaReduce (anonFun, argss)
70+ case _ =>
71+ original
6372 case Block (stats, expr) =>
64- val tree = BetaReduce (original, expr, args )
73+ val tree = BetaReduce (original, expr, argss )
6574 if tree eq original then original
6675 else cpy.Block (fn)(stats, tree)
6776 case Inlined (call, bindings, expr) =>
68- val tree = BetaReduce (original, expr, args )
77+ val tree = BetaReduce (original, expr, argss )
6978 if tree eq original then original
7079 else cpy.Inlined (fn)(call, bindings, tree)
7180 case _ =>
7281 original
7382 end apply
7483
7584 /** Beta-reduces a call to `ddef` with arguments `args` */
76- def apply (ddef : DefDef , args : List [Tree ])(using Context ) =
77- val bindings = new ListBuffer [ValDef ]()
78- val expansion1 = reduceApplication(ddef, args , bindings)
85+ def apply (ddef : DefDef , argss : List [List [ Tree ] ])(using Context ) =
86+ val bindings = new ListBuffer [DefTree ]()
87+ val expansion1 = reduceApplication(ddef, argss , bindings)
7988 val bindings1 = bindings.result()
8089 seq(bindings1, expansion1)
8190
8291 /** Beta-reduces a call to `ddef` with arguments `args` and registers new bindings */
83- def reduceApplication (ddef : DefDef , args : List [Tree ], bindings : ListBuffer [ValDef ])(using Context ): Tree =
84- val vparams = ddef.termParamss.iterator.flatten.toList
92+ def reduceApplication (ddef : DefDef , argss : List [List [Tree ]], bindings : ListBuffer [DefTree ])(using Context ): Tree =
93+ assert(argss.size == 1 || argss.size == 2 )
94+ val targs = if argss.size == 2 then argss.head else Nil
95+ val args = argss.last
96+ val tparams = ddef.leadingTypeParams
97+ val vparams = ddef.termParamss.flatten
98+ assert(targs.hasSameLengthAs(tparams))
8599 assert(args.hasSameLengthAs(vparams))
100+
101+ val targSyms =
102+ for (targ, tparam) <- targs.zip(tparams) yield
103+ targ.tpe.dealias match
104+ case ref @ TypeRef (NoPrefix , _) =>
105+ ref.symbol
106+ case _ =>
107+ val binding = TypeDef (newSymbol(ctx.owner, tparam.name, EmptyFlags , targ.tpe, coord = targ.span)).withSpan(targ.span)
108+ bindings += binding
109+ binding.symbol
110+
86111 val argSyms =
87112 for (arg, param) <- args.zip(vparams) yield
88113 arg.tpe.dealias match
@@ -99,8 +124,8 @@ object BetaReduce:
99124 val expansion = TreeTypeMap (
100125 oldOwners = ddef.symbol :: Nil ,
101126 newOwners = ctx.owner :: Nil ,
102- substFrom = vparams.map(_.symbol),
103- substTo = argSyms
127+ substFrom = (tparams ::: vparams) .map(_.symbol),
128+ substTo = targSyms ::: argSyms
104129 ).transform(ddef.rhs)
105130
106131 val expansion1 = new TreeMap {
0 commit comments