Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ object FunctionRegistry {
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[TransformValues]("transform_values"),
expression[TransformKeys]("transform_keys"),
expression[MapZipWith]("map_zip_with"),
expression[ZipWith]("zip_with"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ case class TransformKeys(
}

@transient lazy val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function


override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
Expand All @@ -550,6 +550,54 @@ case class TransformKeys(
override def prettyName: String = "transform_keys"
}

/**
 * Returns a map with the same keys as the input map, in which every value has been replaced by
 * the result of applying the lambda function to the corresponding (key, value) entry.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Transforms values in the map using the function.",
  examples = """
    Examples:
      > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1);
       {1:2,2:3,3:4}
      > SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
       {1:2,2:4,3:6}
  """,
  since = "2.4.0")
case class TransformValues(
    argument: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  // Nullability of the result follows the nullability of the input map expression.
  override def nullable: Boolean = argument.nullable

  // Destructures the input map type once; evaluated lazily and not serialized.
  @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

  // The key type is carried over unchanged; the value type and value nullability are whatever
  // the lambda function produces.
  override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)

  /**
   * Binds the lambda function to its two arguments: the map key (declared non-nullable) and the
   * map value (nullable according to the input map type).
   */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction)
    : TransformValues = {
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  // The bound lambda must take exactly two variables: the key followed by the value.
  @transient lazy val LambdaFunction(
    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

  /**
   * Evaluates the lambda once per map entry and collects the results as the new values.
   *
   * @param inputRow the row the expression is evaluated against
   * @param argumentValue the evaluated input map; cast to `MapData`
   * @return an `ArrayBasedMapData` reusing the input key array with the transformed values
   */
  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val map = argumentValue.asInstanceOf[MapData]
    // Fetch the key/value arrays once instead of on every loop iteration.
    val keys = map.keyArray()
    val values = map.valueArray()
    val resultValues = new GenericArrayData(new Array[Any](map.numElements))
    var i = 0
    while (i < map.numElements) {
      // Expose the current entry to the lambda through its bound variables.
      keyVar.value.set(keys.get(i, keyVar.dataType))
      valueVar.value.set(values.get(i, valueVar.dataType))
      resultValues.update(i, functionForEval.eval(inputRow))
      i += 1
    }
    // Keys are not transformed, so the original key array is reused as-is.
    new ArrayBasedMapData(keys, resultValues)
  }

  override def prettyName: String = "transform_values"
}

/**
* Merges two given maps into a single map by applying function to the pair of values with
* the same key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
aggregate(expr, zero, merge, identity)
}

/**
 * Test helper: wraps `f` in a two-argument lambda typed after the key/value types of `expr`
 * (keys non-nullable, values nullable per the map type) and builds a `TransformValues` over it.
 */
def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  val mapType = expr.dataType.asInstanceOf[MapType]
  val lambda =
    createLambda(mapType.keyType, false, mapType.valueType, mapType.valueContainsNull, f)
  TransformValues(expr, lambda)
}

test("ArrayTransform") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
Expand Down Expand Up @@ -358,6 +363,74 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
}

test("TransformValues") {
  // Integer-keyed, integer-valued inputs: no nulls, a null value, empty map, and null map.
  val ai0 = Literal.create(
    Map(1 -> 1, 2 -> 2, 3 -> 3),
    MapType(IntegerType, IntegerType, valueContainsNull = false))
  val ai1 = Literal.create(
    Map(1 -> 1, 2 -> null, 3 -> 3),
    MapType(IntegerType, IntegerType, valueContainsNull = true))
  val ai2 = Literal.create(
    Map.empty[Int, Int],
    MapType(IntegerType, IntegerType, valueContainsNull = true))
  val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

  val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1
  // Ignores the value entirely, so null values are replaced by k * k.
  val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k

  checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4))
  checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(
    transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4))
  checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(
    transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int])
  checkEvaluation(transformValues(ai3, plusOne), null)

  // String-keyed, string-valued inputs: no nulls, a null value, empty map, and null map.
  val as0 = Literal.create(
    Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
    MapType(StringType, StringType, valueContainsNull = false))
  val as1 = Literal.create(
    Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"),
    MapType(StringType, StringType, valueContainsNull = true))
  // The Scala-side map is typed with Scala value types; the Catalyst type is given separately.
  val as2 = Literal.create(Map.empty[String, String],
    MapType(StringType, StringType, valueContainsNull = true))
  val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true))

  val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
  // Changes the value type of the map from string to integer.
  val valueTypeUpdate: (Expression, Expression) => Expression =
    (k, v) => Length(v) + 1

  checkEvaluation(
    transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx"))
  checkEvaluation(transformValues(as0, valueTypeUpdate),
    Map("a" -> 3, "bb" -> 3, "ccc" -> 3))
  checkEvaluation(
    transformValues(transformValues(as0, concatValue), concatValue),
    Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx"))
  checkEvaluation(transformValues(as1, concatValue),
    Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx"))
  checkEvaluation(transformValues(as1, valueTypeUpdate),
    Map("a" -> 3, "bb" -> null, "ccc" -> 3))
  checkEvaluation(
    transformValues(transformValues(as1, concatValue), concatValue),
    Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx"))
  checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String])
  checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int])
  checkEvaluation(
    transformValues(transformValues(as2, concatValue), valueTypeUpdate),
    Map.empty[String, Int])
  checkEvaluation(transformValues(as3, concatValue), null)

  // Mixed key/value types: the lambda only reads the integer key.
  val ax0 = Literal.create(
    Map(1 -> "x", 2 -> "y", 3 -> "z"),
    MapType(IntegerType, StringType, valueContainsNull = false))

  checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
}

test("MapZipWith") {
def map_zip_with(
left: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,12 @@ select transform_keys(ys, (k, v) -> k + 1) as v from nested;

-- Transform Keys in a map using values
select transform_keys(ys, (k, v) -> k + v) as v from nested;

-- Identity transform: the lambda returns each value unchanged
select transform_values(ys, (k, v) -> v) as v from nested;

-- Transform values in a map by adding the constant 1 to each value
select transform_values(ys, (k, v) -> v + 1) as v from nested;

-- Transform values in a map using both the key and the value of each entry
select transform_values(ys, (k, v) -> k + v) as v from nested;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 20
-- Number of queries: 27


-- !query 0
Expand Down Expand Up @@ -226,3 +226,30 @@ struct<v:map<int,int>>
-- !query 23 output
{10:5,12:6,8:4}
{2:1,4:2,6:3}


-- !query 24
select transform_values(ys, (k, v) -> v) as v from nested
-- !query 24 schema
struct<v:map<int,int>>
-- !query 24 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 25
select transform_values(ys, (k, v) -> v + 1) as v from nested
-- !query 25 schema
struct<v:map<int,int>>
-- !query 25 output
{1:2,2:3,3:4}
{4:5,5:6,6:7}


-- !query 26
select transform_values(ys, (k, v) -> k + v) as v from nested
-- !query 26 schema
struct<v:map<int,int>>
-- !query 26 output
{1:2,2:4,3:6}
{4:8,5:10,6:12}
Loading