@@ -2,30 +2,34 @@ package dotty.tools.dotc
22package ast
33
44import core ._
5- import Symbols ._ , Types ._ , Contexts ._ , Flags ._ , Constants ._
6- import StdNames .nme
7-
8- /** Generate proxy classes for @main functions.
9- * A function like
10- *
11- * @main def f(x: S, ys: T*) = ...
12- *
13- * would be translated to something like
14- *
15- * import CommandLineParser._
16- * class f {
17- * @static def main(args: Array[String]): Unit =
18- * try
19- * f(
20- * parseArgument[S](args, 0),
21- * parseRemainingArguments[T](args, 1): _*
22- * )
23- * catch case err: ParseError => showError(err)
24- * }
25- */
26- object MainProxies {
5+ import Symbols ._ , Types ._ , Contexts ._ , Decorators ._ , util .Spans ._ , Flags ._ , Constants ._
6+ import StdNames .{nme , tpnme }
7+ import ast .Trees ._
8+ import Names .{Name , TermName }
9+ import Comments .Comment
10+ import NameKinds .DefaultGetterName
11+ import Annotations .Annotation
2712
28- def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
13+ object MainProxies {
14+ /** Generate proxy classes for @main functions.
15+ * A function like
16+ *
17+ * @main def f(x: S, ys: T*) = ...
18+ *
19+ * would be translated to something like
20+ *
21+ * import CommandLineParser._
22+ * class f {
23+ * @static def main(args: Array[String]): Unit =
24+ * try
25+ * f(
26+ * parseArgument[S](args, 0),
27+ * parseRemainingArguments[T](args, 1): _*
28+ * )
29+ * catch case err: ParseError => showError(err)
30+ * }
31+ */
32+ def mainProxiesOld (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
2933 import tpd ._
3034 def mainMethods (stats : List [Tree ]): List [Symbol ] = stats.flatMap {
3135 case stat : DefDef if stat.symbol.hasAnnotation(defn.MainAnnot ) =>
@@ -35,11 +39,11 @@ object MainProxies {
3539 case _ =>
3640 Nil
3741 }
38- mainMethods(stats).flatMap(mainProxy )
42+ mainMethods(stats).flatMap(mainProxyOld )
3943 }
4044
4145 import untpd ._
42- def mainProxy (mainFun : Symbol )(using Context ): List [TypeDef ] = {
46+ def mainProxyOld (mainFun : Symbol )(using Context ): List [TypeDef ] = {
4347 val mainAnnotSpan = mainFun.getAnnotation(defn.MainAnnot ).get.tree.span
4448 def pos = mainFun.sourcePos
4549 val argsRef = Ident (nme.args)
@@ -114,4 +118,311 @@ object MainProxies {
114118 }
115119 result
116120 }
121+
122+ private type DefaultValueSymbols = Map [Int , Symbol ]
123+ private type ParameterAnnotationss = Seq [Seq [Annotation ]]
124+
125+ /**
126+ * Generate proxy classes for main functions.
127+ * A function like
128+ *
129+ * /* *
130+ * * Lorem ipsum dolor sit amet
131+ * * consectetur adipiscing elit.
132+ * *
133+ * * @param x my param x
134+ * * @param ys all my params y
135+ * */
136+ * @main(80) def f(
137+ * @main.Alias("myX") x: S,
138+ * ys: T*
139+ * ) = ...
140+ *
141+ * would be translated to something like
142+ *
143+ * final class f {
144+ * static def main(args: Array[String]): Unit = {
145+ * val cmd = new main(80).command(
146+ * args,
147+ * "f",
148+ * "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
149+ * new scala.annotation.MainAnnotation.ParameterInfos("x", "S")
150+ * .withDocumentation("my param x")
151+ * .withAnnotations(new scala.main.Alias("myX")),
152+ * new scala.annotation.MainAnnotation.ParameterInfos("ys", "T")
153+ * .withDocumentation("all my params y")
154+ * )
155+ *
156+ * val args0: () => S = cmd.argGetter[S]("x", None)
157+ * val args1: () => Seq[T] = cmd.varargGetter[T]("ys")
158+ *
159+ * cmd.run(f(args0(), args1()*))
160+ * }
161+ * }
162+ */
163+ def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
164+ import tpd ._
165+
166+ /**
167+ * Computes the symbols of the default values of the function. Since they cannot be infered anymore at this
168+ * point of the compilation, they must be explicitely passed by [[mainProxy ]].
169+ */
170+ def defaultValueSymbols (scope : Tree , funSymbol : Symbol ): DefaultValueSymbols =
171+ scope match {
172+ case TypeDef (_, template : Template ) =>
173+ template.body.flatMap((_ : Tree ) match {
174+ case dd : DefDef if dd.name.is(DefaultGetterName ) && dd.name.firstPart == funSymbol.name =>
175+ val DefaultGetterName .NumberedInfo (index) = dd.name.info
176+ List (index -> dd.symbol)
177+ case _ => Nil
178+ }).toMap
179+ case _ => Map .empty
180+ }
181+
182+ /** Computes the list of main methods present in the code. */
183+ def mainMethods (scope : Tree , stats : List [Tree ]): List [(Symbol , ParameterAnnotationss , DefaultValueSymbols , Option [Comment ])] = stats.flatMap {
184+ case stat : DefDef =>
185+ val sym = stat.symbol
186+ sym.annotations.filter(_.matches(defn.MainAnnot )) match {
187+ case Nil =>
188+ Nil
189+ case _ :: Nil =>
190+ val paramAnnotations = stat.paramss.flatMap(_.map(
191+ valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotParameterAnnotation ))
192+ ))
193+ (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
194+ case mainAnnot :: others =>
195+ report.error(s " method cannot have multiple main annotations " , mainAnnot.tree)
196+ Nil
197+ }
198+ case stat @ TypeDef (_, impl : Template ) if stat.symbol.is(Module ) =>
199+ mainMethods(stat, impl.body)
200+ case _ =>
201+ Nil
202+ }
203+
204+ // Assuming that the top-level object was already generated, all main methods will have a scope
205+ mainMethods(EmptyTree , stats).flatMap(mainProxy)
206+ }
207+
208+ def mainProxy (mainFun : Symbol , paramAnnotations : ParameterAnnotationss , defaultValueSymbols : DefaultValueSymbols , docComment : Option [Comment ])(using Context ): List [TypeDef ] = {
209+ val mainAnnot = mainFun.getAnnotation(defn.MainAnnot ).get
210+ def pos = mainFun.sourcePos
211+ val cmdName : TermName = Names .termName(" cmd" )
212+
213+ val documentation = new Documentation (docComment)
214+
215+ /** A literal value (Boolean, Int, String, etc.) */
216+ inline def lit (any : Any ): Literal = Literal (Constant (any))
217+
218+ /** None */
219+ inline def none : Tree = ref(defn.NoneModule .termRef)
220+
221+ /** Some(value) */
222+ inline def some (value : Tree ): Tree = Apply (ref(defn.SomeClass .companionModule.termRef), value)
223+
224+ /** () => value */
225+ def unitToValue (value : Tree ): Tree =
226+ val anonName = nme.ANON_FUN
227+ val defdef = DefDef (anonName, List (Nil ), TypeTree (), value)
228+ Block (defdef, Closure (Nil , Ident (anonName), EmptyTree ))
229+
230+ /**
231+ * Creates a list of references and definitions of arguments, the first referencing the second.
232+ * The goal is to create the
233+ * `val args0: () => S = cmd.argGetter[S]("x", None)`
234+ * part of the code.
235+ * For each tuple, the first element is a ref to `args0`, the second is the whole definition, the third
236+ * is the ParameterInfos definition associated to this argument.
237+ */
238+ def createArgs (mt : MethodType , cmdName : TermName ): List [(Tree , ValDef , Tree )] =
239+ mt.paramInfos.zip(mt.paramNames).zipWithIndex.map {
240+ case ((formal, paramName), n) =>
241+ val argName = nme.args ++ n.toString
242+ val isRepeated = formal.isRepeatedParam
243+
244+ val (argRef, formalType, getterSym) = {
245+ val argRef0 = Apply (Ident (argName), Nil )
246+ if formal.isRepeatedParam then
247+ (repeated(argRef0), formal.argTypes.head, defn.MainAnnotCommand_varargGetter )
248+ else (argRef0, formal, defn.MainAnnotCommand_argGetter )
249+ }
250+
251+ // The ParameterInfos
252+ val parameterInfos = {
253+ val param = paramName.toString
254+ val paramInfosTree = New (
255+ TypeTree (defn.MainAnnotParameterInfos .typeRef),
256+ // Arguments to be passed to ParameterInfos' constructor
257+ List (List (lit(param), lit(formalType.show)))
258+ )
259+
260+ /*
261+ * Assignations to be made after the creation of the ParameterInfos.
262+ * For example:
263+ * args0paramInfos.withDocumentation("my param x")
264+ * is represented by the pair
265+ * defn.MainAnnotationParameterInfos_withDocumentation -> List(lit("my param x"))
266+ */
267+ var assignations : List [(Symbol , List [Tree ])] = Nil
268+ for (doc <- documentation.argDocs.get(param))
269+ assignations = (defn.MainAnnotationParameterInfos_withDocumentation -> List (lit(doc))) :: assignations
270+
271+ val instanciatedAnnots = paramAnnotations(n).map(instanciateAnnotation).toList
272+ if instanciatedAnnots.nonEmpty then
273+ assignations = (defn.MainAnnotationParameterInfos_withAnnotations -> instanciatedAnnots) :: assignations
274+
275+ assignations.foldLeft[Tree ](paramInfosTree){ case (tree, (setterSym, values)) => Apply (Select (tree, setterSym.name), values) }
276+ }
277+
278+ val argParams =
279+ if formal.isRepeatedParam then
280+ List (lit(paramName.toString))
281+ else
282+ val defaultValueGetterOpt = defaultValueSymbols.get(n) match {
283+ case None =>
284+ none
285+ case Some (dvSym) =>
286+ some(unitToValue(ref(dvSym.termRef)))
287+ }
288+ List (lit(paramName.toString), defaultValueGetterOpt)
289+
290+ val argDef = ValDef (
291+ argName,
292+ TypeTree (),
293+ Apply (TypeApply (Select (Ident (cmdName), getterSym.name), TypeTree (formalType) :: Nil ), argParams),
294+ )
295+
296+ (argRef, argDef, parameterInfos)
297+ }
298+ end createArgs
299+
300+ /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
301+ def instanciateAnnotation (annot : Annotation ): Tree =
302+ val argss = {
303+ def recurse (t : tpd.Tree , acc : List [List [Tree ]]): List [List [Tree ]] = t match {
304+ case Apply (t, args : List [tpd.Tree ]) => recurse(t, extractArgs(args) :: acc)
305+ case _ => acc
306+ }
307+
308+ def extractArgs (args : List [tpd.Tree ]): List [Tree ] =
309+ args.flatMap {
310+ case Typed (SeqLiteral (varargs, _), _) => varargs.map(arg => TypedSplice (arg))
311+ case arg : Select if arg.name.is(DefaultGetterName ) => Nil // Ignore default values, they will be added later by the compiler
312+ case arg => List (TypedSplice (arg))
313+ }
314+
315+ recurse(annot.tree, Nil )
316+ }
317+
318+ New (TypeTree (annot.symbol.typeRef), argss)
319+ end instanciateAnnotation
320+
321+ var result : List [TypeDef ] = Nil
322+ if (! mainFun.owner.isStaticOwner)
323+ report.error(s " main method is not statically accessible " , pos)
324+ else {
325+ var args : List [ValDef ] = Nil
326+ var mainCall : Tree = ref(mainFun.termRef)
327+ var parameterInfoss : List [Tree ] = Nil
328+
329+ mainFun.info match {
330+ case _ : ExprType =>
331+ case mt : MethodType =>
332+ if (mt.isImplicitMethod) {
333+ report.error(s " main method cannot have implicit parameters " , pos)
334+ }
335+ else mt.resType match {
336+ case restpe : MethodType =>
337+ report.error(s " main method cannot be curried " , pos)
338+ Nil
339+ case _ =>
340+ val (argRefs, argVals, paramInfoss) = createArgs(mt, cmdName).unzip3
341+ args = argVals
342+ mainCall = Apply (mainCall, argRefs)
343+ parameterInfoss = paramInfoss
344+ }
345+ case _ : PolyType =>
346+ report.error(s " main method cannot have type parameters " , pos)
347+ case _ =>
348+ report.error(s " main can only annotate a method " , pos)
349+ }
350+
351+ val cmd = ValDef (
352+ cmdName,
353+ TypeTree (),
354+ Apply (
355+ Select (instanciateAnnotation(mainAnnot), defn.MainAnnot_command .name),
356+ Ident (nme.args) :: lit(mainFun.showName) :: lit(documentation.mainDoc) :: parameterInfoss
357+ )
358+ )
359+ val run = Apply (Select (Ident (cmdName), defn.MainAnnotCommand_run .name), mainCall)
360+ val body = Block (cmd :: args, run)
361+ val mainArg = ValDef (nme.args, TypeTree (defn.ArrayType .appliedTo(defn.StringType )), EmptyTree )
362+ .withFlags(Param )
363+ /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
364+ * The annotations will be retype-checked in another scope that may not have the same imports.
365+ */
366+ def insertTypeSplices = new TreeMap {
367+ override def transform (tree : Tree )(using Context ): Tree = tree match
368+ case tree : tpd.Ident @ unchecked => TypedSplice (tree)
369+ case tree => super .transform(tree)
370+ }
371+ val annots = mainFun.annotations
372+ .filterNot(_.matches(defn.MainAnnot ))
373+ .map(annot => insertTypeSplices.transform(annot.tree))
374+ val mainMeth = DefDef (nme.main, (mainArg :: Nil ) :: Nil , TypeTree (defn.UnitType ), body)
375+ .withFlags(JavaStatic )
376+ .withAnnotations(annots)
377+ val mainTempl = Template (emptyConstructor, Nil , Nil , EmptyValDef , mainMeth :: Nil )
378+ val mainCls = TypeDef (mainFun.name.toTypeName, mainTempl)
379+ .withFlags(Final | Invisible )
380+ if (! ctx.reporter.hasErrors) result = mainCls.withSpan(mainAnnot.tree.span.toSynthetic) :: Nil
381+ }
382+ result
383+ }
384+
385+ /** A class responsible for extracting the docstrings of a method. */
386+ private class Documentation (docComment : Option [Comment ]):
387+ import util .CommentParsing ._
388+
389+ /** The main part of the documentation. */
390+ lazy val mainDoc : String = _mainDoc
391+ /** The parameters identified by @param. Maps from parameter name to its documentation. */
392+ lazy val argDocs : Map [String , String ] = _argDocs
393+
394+ private var _mainDoc : String = " "
395+ private var _argDocs : Map [String , String ] = Map ()
396+
397+ docComment match {
398+ case Some (comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
399+ case None =>
400+ }
401+
402+ private def cleanComment (raw : String ): String =
403+ var lines : Seq [String ] = raw.trim.split('\n ' ).toSeq
404+ lines = lines.map(l => l.substring(skipLineLead(l, - 1 ), l.length).trim)
405+ var s = lines.foldLeft(" " ) {
406+ case (" " , s2) => s2
407+ case (s1, " " ) if s1.last == '\n ' => s1 // Multiple newlines are kept as single newlines
408+ case (s1, " " ) => s1 + '\n '
409+ case (s1, s2) if s1.last == '\n ' => s1 + s2
410+ case (s1, s2) => s1 + ' ' + s2
411+ }
412+ s.replaceAll(raw " \[\[ " , " " ).replaceAll(raw " \]\] " , " " ).trim
413+
414+ private def parseDocComment (raw : String ): Unit =
415+ // Positions of the sections (@) in the docstring
416+ val tidx : List [(Int , Int )] = tagIndex(raw)
417+
418+ // Parse main comment
419+ var mainComment : String = raw.substring(skipLineLead(raw, 0 ), startTag(raw, tidx))
420+ _mainDoc = cleanComment(mainComment)
421+
422+ // Parse arguments comments
423+ val argsCommentsSpans : Map [String , (Int , Int )] = paramDocs(raw, " @param" , tidx)
424+ val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
425+ val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end) })
426+ _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
427+ end Documentation
117428}
0 commit comments