From 4698d05db5e874cc6cb7aa3dced022809bf3ba3d Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 13 Aug 2015 12:24:42 +0800
Subject: [PATCH 1/2] support metadata in withColumn

---
 .../spark/ml/classification/OneVsRest.scala   |  6 +++---
 .../apache/spark/ml/feature/Bucketizer.scala  |  2 +-
 .../spark/ml/feature/VectorIndexer.scala      |  2 +-
 .../spark/ml/feature/VectorSlicer.scala       |  3 +--
 .../org/apache/spark/sql/DataFrame.scala      | 19 +++++++++++++++++++
 5 files changed, 25 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 1741f19dc911c..ec06824a7372a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (
 
     // output label and label metadata as prediction
     aggregatedDataset
-      .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
+      .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
       .drop(accColName)
   }
 
@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
       // TODO: use when ... otherwise after SPARK-7321 is merged
       val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata()
      val labelColName = "mc2b$" + index
-      val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta)
-      val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta)
+      val trainingDataset =
+        multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta)
       val classifier = getClassifier
       val paramMap = new ParamMap()
       paramMap.put(classifier.labelCol -> labelColName)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 67e4785bc3553..89d27584073f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
     }
     val newCol = bucketizer(dataset($(inputCol)))
     val newField = prepOutputField(dataset.schema)
-    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+    dataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
   private def prepOutputField(schema: StructType): StructField = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index c73bdccdef5fa..489b6e5173a07 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
     val newField = prepOutputField(dataset.schema)
     val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
     val newCol = transformUDF(dataset($(inputCol)))
-    dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata))
+    dataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
   override def transformSchema(schema: StructType): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 772bebeff214b..c5c2272270792 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
         case features: SparseVector => features.slice(inds)
       }
     }
-    dataset.withColumn($(outputCol),
-      slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+    dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata())
   }
 
   /** Get the feature indices in order: indices, names */
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 27b994f1f0caf..eda83c260d5b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1149,6 +1149,25 @@ class DataFrame private[sql](
     }
   }
 
+  /**
+   * Returns a new [[DataFrame]] by adding a column with medadata.
+   * @group dfops
+   * @since 1.5.0
+   */
+  def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
+    val resolver = sqlContext.analyzer.resolver
+    val replaced = schema.exists(f => resolver(f.name, colName))
+    if (replaced) {
+      val colNames = schema.map { field =>
+        val name = field.name
+        if (resolver(name, colName)) col.as(colName, metadata) else Column(name)
+      }
+      select(colNames : _*)
+    } else {
+      select(Column("*"), col.as(colName, metadata))
+    }
+  }
+
   /**
    * Returns a new [[DataFrame]] with a column renamed.
    * This is a no-op if schema doesn't contain existingName.

From 39ce9c738e159b1c2a5517e6bd8d6f5d1b952b14 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 14 Aug 2015 10:45:16 +0800
Subject: [PATCH 2/2] make it private

---
 .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index eda83c260d5b4..3d3528839df63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1150,11 +1150,9 @@ class DataFrame private[sql](
   }
 
   /**
-   * Returns a new [[DataFrame]] by adding a column with medadata.
-   * @group dfops
-   * @since 1.5.0
+   * Returns a new [[DataFrame]] by adding a column with metadata.
    */
-  def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
+  private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = {
     val resolver = sqlContext.analyzer.resolver
     val replaced = schema.exists(f => resolver(f.name, colName))
     if (replaced) {
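
Usage sketch (not part of the patch series above): a minimal example of how the new
three-argument withColumn overload would be called. The names tagWithMetadata, df and
rawLabel are made up for illustration; since the second commit makes the overload
private[spark], code outside the org.apache.spark packages would use the existing
col.as(colName, metadata) spelling instead.

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.MetadataBuilder

    // Hypothetical caller inside org.apache.spark: add (or replace) a column and
    // attach column-level metadata to it in one step, without repeating the
    // output column name the way the .as(colName, metadata) form requires.
    def tagWithMetadata(df: DataFrame): DataFrame = {
      val meta = new MetadataBuilder().putString("comment", "derived binary label").build()
      // Pre-patch equivalent: df.withColumn("label", col("rawLabel").as("label", meta))
      df.withColumn("label", col("rawLabel"), meta)
    }

This is the pattern the MLlib call sites above (OneVsRest, Bucketizer, VectorIndexer,
VectorSlicer) switch to: they previously passed column.as(outputCol, metadata) as the
second argument and now pass the column and the metadata separately.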