From eef77f76fd0f7fddf70f2240f92c720d6cefe3fa Mon Sep 17 00:00:00 2001 From: turbofei Date: Thu, 9 Jan 2020 23:16:44 +0800 Subject: [PATCH 1/5] [SPARK-30374] Cast String to Integer Type, throw exception on format invalid and overflow. --- .../apache/spark/unsafe/types/UTF8String.java | 7 +++-- .../spark/sql/catalyst/expressions/Cast.scala | 28 ++++++++++++++++--- .../org/apache/spark/sql/SQLQuerySuite.scala | 6 ++++ 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 3754a1a0374a8..cb5881e0edff5 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1075,8 +1075,8 @@ public UTF8String translate(Map dict) { * Wrapper over `long` to allow result of parsing long from string to be accessed via reference. * This is done solely for better performance and is not expected to be used by end users. 
*/ - public static class LongWrapper implements Serializable { - public transient long value = 0; + public static class LongWrapper extends IntWrapper { + public transient long value = 0l; } /** @@ -1088,6 +1088,7 @@ public static class LongWrapper implements Serializable { */ public static class IntWrapper implements Serializable { public transient int value = 0; + public transient boolean formatInvalid = false; } /** @@ -1140,6 +1141,7 @@ public boolean toLong(LongWrapper toLongResult) { if (b >= '0' && b <= '9') { digit = b - '0'; } else { + toLongResult.formatInvalid = true; return false; } @@ -1233,6 +1235,7 @@ public boolean toInt(IntWrapper intWrapper) { if (b >= '0' && b <= '9') { digit = b - '0'; } else { + intWrapper.formatInvalid = true; return false; } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index fa27a48419dbb..9679b0b67cf7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -480,11 +480,29 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[UTF8String](_, s => IntervalUtils.safeStringToInterval(s)) } + private[this] def onStringToIntegerFailed( + intWrapper: IntWrapper, + str: UTF8String, + ansiEnabled: Boolean, + typeName: String): Any = { + if (ansiEnabled) { + if (intWrapper.formatInvalid) { + throw new ArithmeticException(s"Invalid input syntax for type integer: $str") + } else { + throw new ArithmeticException(s"Casting $str to $typeName causes overflow") + } + } else { + null + } + } + // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => val result = new LongWrapper() - buildCast[UTF8String](_, s => if 
(s.toLong(result)) result.value else { + onStringToIntegerFailed(result, s, ansiEnabled, "Long") + }) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => @@ -501,7 +519,9 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => val result = new IntWrapper() - buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) + buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else { + onStringToIntegerFailed(result, s, ansiEnabled, "Int") + }) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -523,7 +543,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[UTF8String](_, s => if (s.toShort(result)) { result.value.toShort } else { - null + onStringToIntegerFailed(result, s, ansiEnabled, "Short") }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -564,7 +584,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[UTF8String](_, s => if (s.toByte(result)) { result.value.toByte } else { - null + onStringToIntegerFailed(result, s, ansiEnabled, "Byte") }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cf24372e0e0b9..04366caebbc72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3384,6 +3384,12 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { assert(exp.getMessage.contains("Resources not found")) } } + + test("SPARK-26128: Throw exception on overflow when casting string to Integer.") { + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + sql("SELECT 
CAST(CAST(2147483648 as STRING) AS INTEGER)").show() + } + } } case class Foo(bar: Option[String]) From 264e4d52765c91b79b99239adf6d9c99120f5a40 Mon Sep 17 00:00:00 2001 From: turbofei Date: Fri, 10 Jan 2020 00:29:00 +0800 Subject: [PATCH 2/5] fix ut --- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 04366caebbc72..53c37f245acf5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3385,9 +3385,15 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { } } - test("SPARK-26128: Throw exception on overflow when casting string to Integer.") { + test("SPARK-30472: ANSI SQL: Throw exception on format invalid and overflow when casting " + + "String to Integer type.") { withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { - sql("SELECT CAST(CAST(2147483648 as STRING) AS INTEGER)").show() + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('abc' as STRING) AS INTEGER)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('2147483648' as STRING) AS INTEGER)").collect() + ) } } } From 0bd12f40c83cccdde041b2bf107def8b630d7b15 Mon Sep 17 00:00:00 2001 From: turbofei Date: Fri, 10 Jan 2020 00:30:33 +0800 Subject: [PATCH 3/5] fix style --- .../src/main/java/org/apache/spark/unsafe/types/UTF8String.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index cb5881e0edff5..a041abbe6f98c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1076,7 +1076,7 @@ 
public UTF8String translate(Map dict) { * This is done solely for better performance and is not expected to be used by end users. */ public static class LongWrapper extends IntWrapper { - public transient long value = 0l; + public transient long value = 0; } /** From df3b4d60242702fb319596e7a31227ae093571ea Mon Sep 17 00:00:00 2001 From: turbofei Date: Fri, 10 Jan 2020 00:34:32 +0800 Subject: [PATCH 4/5] refactor --- .../org/apache/spark/unsafe/types/UTF8String.java | 3 ++- .../spark/sql/catalyst/expressions/Cast.scala | 13 ++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index a041abbe6f98c..9b64144af64be 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1075,8 +1075,9 @@ public UTF8String translate(Map dict) { * Wrapper over `long` to allow result of parsing long from string to be accessed via reference. * This is done solely for better performance and is not expected to be used by end users. 
*/ - public static class LongWrapper extends IntWrapper { + public static class LongWrapper implements Serializable { public transient long value = 0; + public transient boolean formatInvalid = false; } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 9679b0b67cf7e..a61863c9a129b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -481,12 +481,11 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } private[this] def onStringToIntegerFailed( - intWrapper: IntWrapper, str: UTF8String, - ansiEnabled: Boolean, + formatInvalid: Boolean, typeName: String): Any = { if (ansiEnabled) { - if (intWrapper.formatInvalid) { + if (formatInvalid) { throw new ArithmeticException(s"Invalid input syntax for type integer: $str") } else { throw new ArithmeticException(s"Casting $str to $typeName causes overflow") @@ -501,7 +500,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case StringType => val result = new LongWrapper() buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else { - onStringToIntegerFailed(result, s, ansiEnabled, "Long") + onStringToIntegerFailed(s, result.formatInvalid, "Long") }) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) @@ -520,7 +519,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit case StringType => val result = new IntWrapper() buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else { - onStringToIntegerFailed(result, s, ansiEnabled, "Int") + onStringToIntegerFailed(s, result.formatInvalid, "Int") }) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) @@ -543,7 +542,7 @@ abstract class CastBase extends UnaryExpression with 
TimeZoneAwareExpression wit buildCast[UTF8String](_, s => if (s.toShort(result)) { result.value.toShort } else { - onStringToIntegerFailed(result, s, ansiEnabled, "Short") + onStringToIntegerFailed(s, result.formatInvalid, "Short") }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -584,7 +583,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit buildCast[UTF8String](_, s => if (s.toByte(result)) { result.value.toByte } else { - onStringToIntegerFailed(result, s, ansiEnabled, "Byte") + onStringToIntegerFailed(s, result.formatInvalid, "Byte") }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) From 0c7c65dadb5d095e8ec501344ce2708b80b3e365 Mon Sep 17 00:00:00 2001 From: turbofei Date: Fri, 10 Jan 2020 00:41:52 +0800 Subject: [PATCH 5/5] add ut --- .../org/apache/spark/sql/SQLQuerySuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 53c37f245acf5..59676eb973254 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3391,9 +3391,27 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { intercept[ArithmeticException]( sql("SELECT CAST(CAST('abc' as STRING) AS INTEGER)").collect() ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('abc' as STRING) AS BYTE)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('abc' as STRING) AS SHORT)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('abc' as STRING) AS LONG)").collect() + ) intercept[ArithmeticException]( sql("SELECT CAST(CAST('2147483648' as STRING) AS INTEGER)").collect() ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('128' as STRING) AS BYTE)").collect() + ) + 
intercept[ArithmeticException]( + sql("SELECT CAST(CAST('32768' as STRING) AS SHORT)").collect() + ) + intercept[ArithmeticException]( + sql("SELECT CAST(CAST('9223372036854775808' as STRING) AS LONG)").collect() + ) } } }