diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0053503501047..8f777997bf615 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -924,27 +924,36 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String } private[this] def changePrecision(d: ExprValue, decimalType: DecimalType, - evPrim: ExprValue, evNull: ExprValue): Block = - code""" - if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { - $evPrim = $d; - } else { - $evNull = true; - } - """ + evPrim: ExprValue, evNull: ExprValue, canNullSafeCast: Boolean): Block = { + if (canNullSafeCast) { + code""" + |$d.changePrecision(${decimalType.precision}, ${decimalType.scale}); + |$evPrim = $d; + """.stripMargin + } else { + code""" + |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) { + | $evPrim = $d; + |} else { + | $evNull = true; + |} + """.stripMargin + } + } private[this] def castToDecimalCode( from: DataType, target: DecimalType, ctx: CodegenContext): CastFunction = { val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal]) + val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target) from match { case StringType => (c, evPrim, evNull) => code""" try { Decimal $tmp = Decimal.apply(new java.math.BigDecimal($c.toString())); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { $evNull = true; } @@ -953,7 +962,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => code""" Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case DateType => // date can't cast to decimal in Hive @@ -964,19 +973,19 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String code""" Decimal $tmp = Decimal.apply( scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)})); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case DecimalType() => (c, evPrim, evNull) => code""" Decimal $tmp = $c.clone(); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case x: IntegralType => (c, evPrim, evNull) => code""" Decimal $tmp = Decimal.apply((long) $c); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} """ case x: FractionalType => // All other numeric types can be represented precisely as Doubles @@ -984,7 +993,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String code""" try { Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c)); - ${changePrecision(tmp, target, evPrim, evNull)} + ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast)} } catch (java.lang.NumberFormatException e) { $evNull = true; }