diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index deebe0efcc335..95b78175d5561 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2052,11 +2052,12 @@ def element_at(col, extraction): [Row(element_at(data, 1)=u'a'), Row(element_at(data, 1)=None)] >>> df = spark.createDataFrame([({"a": 1.0, "b": 2.0},), ({},)], ['data']) - >>> df.select(element_at(df.data, "a")).collect() + >>> df.select(element_at(df.data, lit("a"))).collect() [Row(element_at(data, a)=1.0), Row(element_at(data, a)=None)] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.element_at(_to_java_column(col), extraction)) + return Column(sc._jvm.functions.element_at( + _to_java_column(col), lit(extraction)._jc)) # noqa: F821 'lit' is dynamically defined. @since(2.4)