From e152abe5087121c0a1b715e8190db2ecf56149c8 Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Thu, 15 Apr 2021 09:59:24 +0200 Subject: [PATCH] Improve invariant checks for erased terms Instead of erasing an erased term to `???` we erase it to `erasedValue[T]`. This has 2 advantages, first, the term does not lose its type, and second, the term is still marked as erased. The second implies that if there is a bug in the compiler or a macro where the term might end outside an erased context, the code will not compiler. Currently, the code compiles and then throws when calling the spurious `???`. See #11996. --- .../dotty/tools/dotc/transform/PostTyper.scala | 13 ++++++------- .../tools/dotc/transform/PruneErasedDefs.scala | 15 ++++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 8e9b95a0c572..a4d93d2fc318 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -243,10 +243,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase private object dropInlines extends TreeMap { override def transform(tree: Tree)(using Context): Tree = tree match { case Inlined(call, _, expansion) => - val newExpansion = tree.tpe match - case ConstantType(c) => Literal(c) - case _ => Typed(ref(defn.Predef_undefined), TypeTree(tree.tpe)) - cpy.Inlined(tree)(call, Nil, newExpansion.withSpan(tree.span)) + val newExpansion = PruneErasedDefs.trivialErasedTree(tree) + cpy.Inlined(tree)(call, Nil, newExpansion) case _ => super.transform(tree) } } @@ -282,7 +280,8 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase tpd.cpy.Apply(tree)( tree.fun, tree.args.mapConserve(arg => - if (methType.isImplicitMethod && arg.span.isSynthetic) ref(defn.Predef_undefined) + if (methType.isImplicitMethod && arg.span.isSynthetic) + PruneErasedDefs.trivialErasedTree(arg) else dropInlines.transform(arg))) else tree @@ -414,12 +413,12 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase // case x: (_: Tree[?]) case m @ MatchTypeTree(bounds, selector, cases) => // Analog to the case above for match types - def tranformIgnoringBoundsCheck(x: CaseDef): CaseDef = + def transformIgnoringBoundsCheck(x: CaseDef): CaseDef = withMode(Mode.Pattern)(super.transform(x)).asInstanceOf[CaseDef] cpy.MatchTypeTree(tree)( super.transform(bounds), super.transform(selector), - cases.mapConserve(tranformIgnoringBoundsCheck) + cases.mapConserve(transformIgnoringBoundsCheck) ) case Block(_, Closure(_, _, tpt)) if ExpandSAMs.needsWrapperClass(tpt.tpe) => superAcc.withInvalidCurrentClass(super.transform(tree)) diff --git a/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala b/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala index 5412736628db..b8e7a7924a79 100644 --- a/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala +++ b/compiler/src/dotty/tools/dotc/transform/PruneErasedDefs.scala @@ -22,6 +22,7 @@ import ast.tpd */ class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform => import tpd._ + import PruneErasedDefs._ override def phaseName: String = PruneErasedDefs.name @@ -39,19 +40,19 @@ class PruneErasedDefs extends MiniPhase with SymTransformer { thisTransform => override def transformValDef(tree: ValDef)(using Context): Tree = if !tree.symbol.isEffectivelyErased || tree.rhs.isEmpty then tree - else cpy.ValDef(tree)(rhs = trivialErasedTree(tree)) + else cpy.ValDef(tree)(rhs = trivialErasedTree(tree.rhs)) override def transformDefDef(tree: DefDef)(using Context): Tree = if !tree.symbol.isEffectivelyErased || tree.rhs.isEmpty then tree - else cpy.DefDef(tree)(rhs = trivialErasedTree(tree)) - - private def trivialErasedTree(tree: Tree)(using Context): Tree = - tree.tpe.widenTermRefExpr.dealias.normalized match - case ConstantType(c) => Literal(c) - case _ => ref(defn.Predef_undefined) + else cpy.DefDef(tree)(rhs = trivialErasedTree(tree.rhs)) } object PruneErasedDefs { + import tpd._ + val name: String = "pruneErasedDefs" + + def trivialErasedTree(tree: Tree)(using Context): Tree = + ref(defn.Compiletime_erasedValue).appliedToType(tree.tpe).withSpan(tree.span) }