From feca67a0922d97749380d5ef8944cefe55aa6c2b Mon Sep 17 00:00:00 2001
From: shivsood
Date: Wed, 13 Nov 2019 17:56:13 -0800
Subject: [PATCH] This is a port of SPARK-29644 to 2.4

What changes were proposed in this pull request?

Corrected the ShortType and ByteType mappings to SMALLINT and TINYINT, and corrected the setter methods so that ShortType and ByteType values are set with setShort() and setByte(). Changes are in JdbcUtils.scala. Fixed unit test cases where applicable and added new E2E test cases that read/write tables using ShortType and ByteType.

Problems:

- In master, lines 547 and 551 of JdbcUtils.scala set ShortType and ByteType values as integers rather than as Short and Byte respectively (the issue was pointed out by @maropu):

```
case ShortType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setInt(pos + 1, row.getShort(pos))

case ByteType =>
  (stmt: PreparedStatement, row: Row, pos: Int) =>
    stmt.setInt(pos + 1, row.getByte(pos))
```

- At JdbcUtils.scala line 247, TINYINT is wrongly interpreted as IntegerType in getCatalystType():

```
case java.sql.Types.TINYINT => IntegerType
```

- At line 172, ShortType was wrongly mapped to INTEGER:

```
case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT))
```

- Throughout the tests, ShortType and ByteType were being interpreted as IntegerType. A sketch of the corrected mappings follows.
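To make the three fixes concrete side by side, here is a minimal sketch of the corrected mappings. This is an illustration only, not the patched JdbcUtils code: the object and method names (TypeMappingSketch, toJdbcTypeName, toCatalystType, setColumn) are invented for this description, while the case arms mirror the diff below (the real logic lives in JdbcUtils.getCommonJDBCType, getCatalystType and makeSetter).

```
import java.sql.{PreparedStatement, Types}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

// Illustrative sketch only; names are invented for this summary.
object TypeMappingSketch {
  // DataFrame column type -> SQL type name used for CREATE TABLE (after the fix).
  def toJdbcTypeName(dt: DataType): Option[String] = dt match {
    case ShortType => Some("SMALLINT") // was "INTEGER"
    case ByteType => Some("TINYINT")   // was "BYTE"
    case _ => None
  }

  // java.sql.Types code from table metadata -> Catalyst type (after the fix).
  def toCatalystType(sqlType: Int): Option[DataType] = sqlType match {
    case Types.SMALLINT => Some(ShortType) // was IntegerType
    case Types.TINYINT => Some(ByteType)   // was IntegerType
    case _ => None
  }

  // Row value -> PreparedStatement parameter (after the fix).
  def setColumn(stmt: PreparedStatement, row: Row, pos: Int, dt: DataType): Unit = dt match {
    case ShortType => stmt.setShort(pos + 1, row.getShort(pos)) // was stmt.setInt
    case ByteType => stmt.setByte(pos + 1, row.getByte(pos))    // was stmt.setInt
    case _ => throw new IllegalArgumentException(s"unsupported type $dt")
  }
}
```

With these mappings a ShortType column is created as SMALLINT, read back as ShortType, and written with setShort(), so a JDBC round-trip no longer widens it to Integer; the same holds for ByteType/TINYINT/setByte().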
Why are the changes needed?

A value of a given type should be written and read using the matching JDBC type and setter.

Does this PR introduce any user-facing change?

No.

How was this patch tested?

Corrected unit test cases where applicable and validated in CI/CD. Added/fixed test cases in MsSqlServerIntegrationSuite.scala, PostgresIntegrationSuite.scala and MySQLIntegrationSuite.scala to write/read tables from a DataFrame with columns of ShortType and ByteType. Validated manually as follows:

./build/mvn install -DskipTests
./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.12
---
 .../jdbc/MsSqlServerIntegrationSuite.scala    | 48 +++++++++++++++++--
 .../sql/jdbc/MySQLIntegrationSuite.scala      |  4 +-
 .../datasources/jdbc/JdbcUtils.scala          | 12 ++---
 .../org/apache/spark/sql/jdbc/JDBCSuite.scala |  4 +-
 .../spark/sql/jdbc/JDBCWriteSuite.scala       | 42 ++++++++++++++++
 5 files changed, 97 insertions(+), 13 deletions(-)

diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index efd7ca74c796b..f1cd3343b7925 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -59,7 +59,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     """
       |INSERT INTO numbers VALUES (
       |0,
-      |255, 32767, 2147483647, 9223372036854775807,
+      |127, 32767, 2147483647, 9223372036854775807,
       |123456789012345.123456789012345, 123456789012345.123456789012345,
       |123456789012345.123456789012345,
       |123, 12345.12,
@@ -119,7 +119,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     val types = row.toSeq.map(x => x.getClass.toString)
     assert(types.length == 12)
     assert(types(0).equals("class java.lang.Boolean"))
-    assert(types(1).equals("class java.lang.Integer"))
+    assert(types(1).equals("class java.lang.Byte"))
     assert(types(2).equals("class java.lang.Short"))
     assert(types(3).equals("class java.lang.Integer"))
     assert(types(4).equals("class java.lang.Long"))
@@ -131,7 +131,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     assert(types(10).equals("class java.math.BigDecimal"))
     assert(types(11).equals("class java.math.BigDecimal"))
     assert(row.getBoolean(0) == false)
-    assert(row.getInt(1) == 255)
+    assert(row.getByte(1) == 127)
     assert(row.getShort(2) == 32767)
     assert(row.getInt(3) == 2147483647)
     assert(row.getLong(4) == 9223372036854775807L)
@@ -202,4 +202,46 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
     df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
     df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
   }
+
+  test("SPARK-29644: Write tables with ShortType") {
+    import testImplicits._
+    val df = Seq(-32768.toShort, 0.toShort, 1.toShort, 38.toShort, 32768.toShort).toDF("a")
+    val tablename = "shorttable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Short")
+  }
+
+  test("SPARK-29644: Write tables with ByteType") {
+    import testImplicits._
+    val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 128.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", jdbcUrl)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Byte")
java.lang.Byte") + } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index 9cd5c4ec41a52..5b08093d930b1 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -82,7 +82,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { assert(types.length == 9) assert(types(0).equals("class java.lang.Boolean")) assert(types(1).equals("class java.lang.Long")) - assert(types(2).equals("class java.lang.Integer")) + assert(types(2).equals("class java.lang.Short")) assert(types(3).equals("class java.lang.Integer")) assert(types(4).equals("class java.lang.Integer")) assert(types(5).equals("class java.lang.Long")) @@ -91,7 +91,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { assert(types(8).equals("class java.lang.Double")) assert(rows(0).getBoolean(0) == false) assert(rows(0).getLong(1) == 0x225) - assert(rows(0).getInt(2) == 17) + assert(rows(0).getShort(2) == 17) assert(rows(0).getInt(3) == 77777) assert(rows(0).getInt(4) == 123456789) assert(rows(0).getLong(5) == 123456789012345L) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index c0f628ff04108..f19778f6f05f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -165,8 +165,8 @@ object JdbcUtils extends Logging { case LongType => Option(JdbcType("BIGINT", java.sql.Types.BIGINT)) case DoubleType => Option(JdbcType("DOUBLE PRECISION", java.sql.Types.DOUBLE)) case FloatType => Option(JdbcType("REAL", java.sql.Types.FLOAT)) - case ShortType => Option(JdbcType("INTEGER", java.sql.Types.SMALLINT)) - case ByteType => Option(JdbcType("BYTE", java.sql.Types.TINYINT)) + case ShortType => Option(JdbcType("SMALLINT", java.sql.Types.SMALLINT)) + case ByteType => Option(JdbcType("TINYINT", java.sql.Types.TINYINT)) case BooleanType => Option(JdbcType("BIT(1)", java.sql.Types.BIT)) case StringType => Option(JdbcType("TEXT", java.sql.Types.CLOB)) case BinaryType => Option(JdbcType("BLOB", java.sql.Types.BLOB)) @@ -230,7 +230,7 @@ object JdbcUtils extends Logging { case java.sql.Types.REF => StringType case java.sql.Types.REF_CURSOR => null case java.sql.Types.ROWID => LongType - case java.sql.Types.SMALLINT => IntegerType + case java.sql.Types.SMALLINT => ShortType case java.sql.Types.SQLXML => StringType case java.sql.Types.STRUCT => StringType case java.sql.Types.TIME => TimestampType @@ -239,7 +239,7 @@ object JdbcUtils extends Logging { case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => null - case java.sql.Types.TINYINT => IntegerType + case java.sql.Types.TINYINT => ByteType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType case _ => @@ -541,11 +541,11 @@ object JdbcUtils extends Logging { case ShortType => (stmt: PreparedStatement, row: Row, pos: Int) => - stmt.setInt(pos + 1, row.getShort(pos)) + stmt.setShort(pos + 1, row.getShort(pos)) case ByteType => (stmt: PreparedStatement, row: Row, pos: Int) => - stmt.setInt(pos + 1, 
+        stmt.setByte(pos + 1, row.getByte(pos))
 
     case BooleanType =>
       (stmt: PreparedStatement, row: Row, pos: Int) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 2dcedc3fc1cc2..348c1a749c97d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -578,8 +578,8 @@ class JDBCSuite extends QueryTest
     assert(rows.length === 1)
     assert(rows(0).getInt(0) === 1)
     assert(rows(0).getBoolean(1) === false)
-    assert(rows(0).getInt(2) === 3)
-    assert(rows(0).getInt(3) === 4)
+    assert(rows(0).getByte(2) === 3.toByte)
+    assert(rows(0).getShort(3) === 4.toShort)
     assert(rows(0).getLong(4) === 1234567890123L)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index b751ec2de4825..e8155f42d3695 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -543,4 +543,46 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     }.getMessage
     assert(errMsg.contains("Statement was canceled or the session timed out"))
   }
+
+  test("SPARK-29644: Write tables with ShortType") {
+    import testImplicits._
+    val df = Seq(-32768.toShort, 0.toShort, 1.toShort, 38.toShort, 32768.toShort).toDF("a")
+    val tablename = "shorttable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Short")
+  }
+
+  test("SPARK-29644: Write tables with ByteType") {
+    import testImplicits._
+    val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 128.toByte).toDF("a")
+    val tablename = "bytetable"
+    df.write
+      .format("jdbc")
+      .mode("overwrite")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .save()
+    val df2 = spark.read
+      .format("jdbc")
+      .option("url", url)
+      .option("dbtable", tablename)
+      .load()
+    assert(df.count == df2.count)
+    val rows = df2.collect()
+    val colType = rows(0).toSeq.map(x => x.getClass.toString)
+    assert(colType(0) == "class java.lang.Byte")
+  }
 }
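For a quick manual check outside the test suites, a round-trip along the lines of the new tests can be run in a small driver program. This is a sketch under stated assumptions, not part of the patch: the url, table name, and the use of an in-memory H2 database (as in JDBCWriteSuite) are illustrative, and an H2 driver is assumed to be on the classpath.

```
import org.apache.spark.sql.SparkSession

object Spark29644Check {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("spark-29644-check")
      .getOrCreate()
    import spark.implicits._

    // Illustrative url; DB_CLOSE_DELAY=-1 keeps the in-memory db alive between connections.
    val url = "jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1"
    Seq(1.toShort, 2.toShort).toDF("a").write
      .format("jdbc").mode("overwrite")
      .option("url", url).option("dbtable", "shortcheck")
      .save()

    val back = spark.read.format("jdbc")
      .option("url", url).option("dbtable", "shortcheck")
      .load()
    // With this patch the column reads back as ShortType; before, it widened to IntegerType.
    back.printSchema()
    spark.stop()
  }
}
```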