Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,13 @@
from operator import itemgetter

from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer
from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer

from itertools import chain, ifilter, imap

from py4j.protocol import Py4JError
from py4j.java_collections import ListConverter, MapConverter


__all__ = [
"StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
Expand Down Expand Up @@ -932,6 +936,39 @@ def _ssql_ctx(self):
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext

def registerFunction(self, name, f, returnType=StringType()):
    """Register a Python function as a UDF that can be called by name from SQL statements.

    Besides the name and the function itself, a return type may optionally be supplied.
    When no return type is given it defaults to a string and conversion will be done
    automatically. For any other return type, the object produced by the function must
    match the specified type.

    >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
    >>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
    [Row(c0=u'4')]
    >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
    >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
    [Row(c0=4)]
    >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
    >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
    [Row(c0=5)]
    """
    # Both py4j conversions below go through the same gateway client.
    gateway_client = self._sc._gateway._gateway_client

    # Every element of the incoming iterator is unpacked into positional
    # arguments for the user's function.
    apply_to_rows = lambda _, rows: imap(lambda args: f(*args), rows)
    command = (apply_to_rows,
               BatchedSerializer(PickleSerializer(), 1024),
               BatchedSerializer(PickleSerializer(), 1024))

    env = MapConverter().convert(self._sc.environment, gateway_client)
    includes = ListConverter().convert(self._sc._python_includes, gateway_client)

    # The command tuple is pickled with CloudPickleSerializer and passed to the
    # JVM-side registerPython call as a bytearray.
    self._ssql_ctx.registerPython(
        name,
        bytearray(CloudPickleSerializer().dumps(command)),
        env,
        includes,
        self._sc.pythonExec,
        self._sc._javaAccumulator,
        str(returnType))

def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}s.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,49 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.Expression
import scala.collection.mutable

/**
 * A catalog for looking up user defined functions, used by an [[Analyzer]].
 *
 * Implementations associate a function name with a `FunctionBuilder` and use it to
 * produce the [[Expression]] for a call to that function.
 */
trait FunctionRegistry {
  /** A function that turns a sequence of expressions into a single [[Expression]]. */
  type FunctionBuilder = Seq[Expression] => Expression

  /**
   * Registers `builder` under `name`.
   *
   * @param name    the name the function is registered under
   * @param builder the builder to associate with `name`
   */
  def registerFunction(name: String, builder: FunctionBuilder): Unit

  /**
   * Looks up the function called `name` and returns the expression for it.
   *
   * @param name     the name of the function to look up
   * @param children the expressions handed to the function's builder
   *                 (presumably the call's arguments -- confirm against callers)
   * @return the resulting [[Expression]]
   */
  def lookupFunction(name: String, children: Seq[Expression]): Expression
}

/**
 * A stackable [[FunctionRegistry]] mixin that consults its own table of registered
 * builders first and defers to the registry it is mixed into (`super`) for any name
 * that has not been registered here.
 */
trait OverrideFunctionRegistry extends FunctionRegistry {

  /** Builders registered through this trait, keyed by function name. */
  val functionBuilders: mutable.HashMap[String, FunctionBuilder] =
    new mutable.HashMap[String, FunctionBuilder]()

  /**
   * Registers `builder` under `name`, replacing any builder previously stored
   * under the same name.
   */
  def registerFunction(name: String, builder: FunctionBuilder): Unit = {
    functionBuilders(name) = builder
  }

  /**
   * Builds the expression for `name` from a locally registered builder when one
   * exists; otherwise delegates the lookup to the underlying registry.
   */
  abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
    functionBuilders.get(name) match {
      case Some(builder) => builder(children)
      // Not overridden here: fall through to the registry this trait is stacked on.
      case None => super.lookupFunction(name, children)
    }
  }
}

/**
 * A [[FunctionRegistry]] backed by a mutable hash map from function name to builder.
 */
class SimpleFunctionRegistry extends FunctionRegistry {

  /** Registered builders, keyed by function name. */
  val functionBuilders: mutable.HashMap[String, FunctionBuilder] =
    new mutable.HashMap[String, FunctionBuilder]()

  /**
   * Registers `builder` under `name`, replacing any builder previously stored
   * under the same name.
   */
  def registerFunction(name: String, builder: FunctionBuilder): Unit = {
    functionBuilders(name) = builder
  }

  /**
   * Builds the expression for the function registered as `name`.
   *
   * @throws NoSuchElementException if no function has been registered under `name`
   */
  override def lookupFunction(name: String, children: Seq[Expression]): Expression = {
    // Same exception type as a failed map lookup, but the message says that a
    // function is missing instead of the generic "key not found".
    val builder = functionBuilders.getOrElse(
      name, throw new NoSuchElementException(s"undefined function: $name"))
    builder(children)
  }
}

/**
* A trivial catalog that returns an error when a function is requested. Used for testing when all
* functions are already filled in and the analyser needs only to resolve attribute references.
*/
object EmptyFunctionRegistry extends FunctionRegistry {
def registerFunction(name: String, builder: FunctionBuilder) = ???

def lookupFunction(name: String, children: Seq[Expression]): Expression = {
throw new UnsupportedOperationException
}
Expand Down
Loading