@@ -17,6 +17,8 @@ import transform.SyntheticMembers._
1717import util .Property
1818import annotation .{tailrec , constructorOnly }
1919
20+ import scala .collection .mutable
21+
2022/** Synthesize terms for special classes */
2123class Synthesizer (typer : Typer )(using @ constructorOnly c : Context ):
2224 import ast .tpd ._
@@ -339,7 +341,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
339341 if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
340342 val elemLabels = cls.children.map(c => ConstantType (Constant (c.name.toString)))
341343
342- def solve (sym : Symbol ): Type = sym match
344+ def solve (target : Type )( sym : Symbol ): Type = sym match
343345 case childClass : ClassSymbol =>
344346 assert(childClass.isOneOf(Case | Sealed ))
345347 if childClass.is(Module ) then
@@ -350,36 +352,50 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
350352 // Compute the the full child type by solving the subtype constraint
351353 // `C[X1, ..., Xn] <: P`, where
352354 //
353- // - P is the current `mirroredType `
355+ // - P is the current `targetPart `
354356 // - C is the child class, with type parameters X1, ..., Xn
355357 //
356358 // Contravariant type parameters are minimized, all other type parameters are maximized.
357- def instantiate (using Context ) =
358- val poly = constrained(info, untpd. EmptyTree )._1
359+ def instantiate (targetPart : Type )( using Context ) =
360+ val poly = constrained(info)
359361 val resType = poly.finalResultType
360- val target = mirroredType match
361- case tp : HKTypeLambda => tp.resultType
362- case tp => tp
363- resType <:< target
362+ resType <:< targetPart // record constraints
364363 val tparams = poly.paramRefs
365364 val variances = childClass.typeParams.map(_.paramVarianceSign)
366365 val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
367366 TypeComparer .instanceType(tparam, fromBelow = variance < 0 ))
368367 resType.substParams(poly, instanceTypes)
369- instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
368+
369+ def instantiateAll (using Context ): Type =
370+
371+ // instantiate for each part of a union type, compute lub of the results
372+ def loop (explore : List [Type ], acc : mutable.ListBuffer [Type ]): Type = explore match
373+ case OrType (tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc )
374+ case tp :: rest => loop(rest , acc += instantiate(tp))
375+ case _ => TypeComparer .lub(acc.toList)
376+
377+ def instantiateLub (tp1 : Type , tp2 : Type ): Type =
378+ loop(tp1 :: tp2 :: Nil , new mutable.ListBuffer [Type ])
379+
380+ target match
381+ case OrType (tp1, tp2) => instantiateLub(tp1, tp2)
382+ case _ => instantiate(target)
383+
384+ instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass))
370385 case _ =>
371386 childClass.typeRef
372387 case child => child.termRef
373388 end solve
374389
375390 val (monoType, elemsType) = mirroredType match
376391 case mirroredType : HKTypeLambda =>
392+ val target = mirroredType.resultType
377393 val elems = mirroredType.derivedLambdaType(
378- resType = TypeOps .nestedPairs(cls.children.map(solve))
394+ resType = TypeOps .nestedPairs(cls.children.map(solve(target) ))
379395 )
380396 (mkMirroredMonoType(mirroredType), elems)
381- case _ =>
382- val elems = TypeOps .nestedPairs(cls.children.map(solve))
397+ case target =>
398+ val elems = TypeOps .nestedPairs(cls.children.map(solve(target) ))
383399 (mirroredType, elems)
384400
385401 val mirrorType =
0 commit comments