From 1e3b315c734b6298227d4e5a511c1ca5bb45fd8c Mon Sep 17 00:00:00 2001 From: abhijeet-lele Date: Wed, 2 Mar 2022 21:03:19 +0530 Subject: [PATCH 1/4] Suggested changes to handle nested row in an array --- .../java/org/apache/beam/sdk/values/Row.java | 112 ++++++++------- .../sql/example/BeamSqlNestedExample.java | 131 ++++++++++++++++++ 2 files changed, 194 insertions(+), 49 deletions(-) create mode 100644 sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java index 9dd02d32c2f3..17fe025b94be 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java @@ -31,6 +31,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import java.util.stream.Collector; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -86,8 +87,8 @@ */ @Experimental(Kind.SCHEMAS) @SuppressWarnings({ - "nullness", // TODO(https://issues.apache.org/jira/browse/BEAM-10402) - "rawtypes" + "nullness", // TODO(https://issues.apache.org/jira/browse/BEAM-10402) + "rawtypes" }) public abstract class Row implements Serializable { private final Schema schema; @@ -111,8 +112,21 @@ public abstract class Row implements Serializable { /** Return a list of data values. Any LogicalType values are returned as base values. * */ public List getBaseValues() { return IntStream.range(0, getFieldCount()) - .mapToObj(i -> getBaseValue(i)) - .collect(Collectors.toList()); + .mapToObj(i -> { + List values = new ArrayList<>(); + FieldType fieldType = this.getSchema().getField(i).getType(); + if(fieldType.getTypeName().equals(TypeName.ROW)) { + Row row = this.getBaseValue(i, Row.class); + List rowValues = row.getBaseValues(); + if(null != rowValues) { + values.addAll(rowValues); + } + } else { + values.add(this.getBaseValue(i)); + } + return values.stream(); + }).flatMap(Function.identity()).collect(Collectors.toList()); + } /** Get value by field name, {@link ClassCastException} is thrown if type doesn't match. 
*/ @@ -465,13 +479,13 @@ public static boolean deepEquals(Object a, Object b, Schema.FieldType fieldType) return Arrays.equals((byte[]) a, (byte[]) b); } else if (fieldType.getTypeName() == TypeName.ARRAY) { return deepEqualsForCollection( - (Collection) a, (Collection) b, fieldType.getCollectionElementType()); + (Collection) a, (Collection) b, fieldType.getCollectionElementType()); } else if (fieldType.getTypeName() == TypeName.ITERABLE) { return deepEqualsForIterable( - (Iterable) a, (Iterable) b, fieldType.getCollectionElementType()); + (Iterable) a, (Iterable) b, fieldType.getCollectionElementType()); } else if (fieldType.getTypeName() == Schema.TypeName.MAP) { return deepEqualsForMap( - (Map) a, (Map) b, fieldType.getMapValueType()); + (Map) a, (Map) b, fieldType.getMapValueType()); } else { return Objects.equals(a, b); } @@ -488,7 +502,7 @@ public static int deepHashCode(Object a, Schema.FieldType fieldType) { return deepHashCodeForIterable((Iterable) a, fieldType.getCollectionElementType()); } else if (fieldType.getTypeName() == Schema.TypeName.MAP) { return deepHashCodeForMap( - (Map) a, fieldType.getMapKeyType(), fieldType.getMapValueType()); + (Map) a, fieldType.getMapKeyType(), fieldType.getMapValueType()); } else { return Objects.hashCode(a); } @@ -523,7 +537,7 @@ static boolean deepEqualsForMap(Map a, Map b, Schema.FieldTyp } static int deepHashCodeForMap( - Map a, Schema.FieldType keyType, Schema.FieldType valueType) { + Map a, Schema.FieldType keyType, Schema.FieldType valueType) { int h = 0; for (Map.Entry e : a.entrySet()) { @@ -537,7 +551,7 @@ static int deepHashCodeForMap( } static boolean deepEqualsForCollection( - Collection a, Collection b, Schema.FieldType elementType) { + Collection a, Collection b, Schema.FieldType elementType) { if (a == b) { return true; } @@ -550,7 +564,7 @@ static boolean deepEqualsForCollection( } static boolean deepEqualsForIterable( - Iterable a, Iterable b, Schema.FieldType elementType) { + Iterable a, Iterable b, Schema.FieldType elementType) { if (a == b) { return true; } @@ -605,7 +619,7 @@ private String toString(Schema.FieldType fieldType, Object value, boolean includ builder.append("["); for (Object element : (Iterable) value) { builder.append( - toString(fieldType.getCollectionElementType(), element, includeFieldNames)); + toString(fieldType.getCollectionElementType(), element, includeFieldNames)); builder.append(", "); } builder.append("]"); @@ -617,7 +631,7 @@ private String toString(Schema.FieldType fieldType, Object value, boolean includ builder.append(toString(fieldType.getMapKeyType(), entry.getKey(), includeFieldNames)); builder.append(", "); builder.append( - toString(fieldType.getMapValueType(), entry.getValue(), includeFieldNames)); + toString(fieldType.getMapValueType(), entry.getValue(), includeFieldNames)); builder.append("), "); } builder.append("}"); @@ -684,7 +698,7 @@ public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { /** Set a field value using a FieldAccessDescriptor. 
*/ public FieldValueBuilder withFieldValue( - FieldAccessDescriptor fieldAccessDescriptor, Object value) { + FieldAccessDescriptor fieldAccessDescriptor, Object value) { FieldAccessDescriptor fieldAccess = fieldAccessDescriptor.resolve(getSchema()); checkArgument(fieldAccess.referencesSingleField(), ""); fieldOverrides.addOverride(fieldAccess, new FieldOverride(value)); @@ -697,11 +711,11 @@ public FieldValueBuilder withFieldValue( */ public FieldValueBuilder withFieldValues(Map values) { values.entrySet().stream() - .forEach( - e -> - fieldOverrides.addOverride( - FieldAccessDescriptor.withFieldNames(e.getKey()).resolve(getSchema()), - new FieldOverride(e.getValue()))); + .forEach( + e -> + fieldOverrides.addOverride( + FieldAccessDescriptor.withFieldNames(e.getKey()).resolve(getSchema()), + new FieldOverride(e.getValue()))); return this; } @@ -711,19 +725,19 @@ public FieldValueBuilder withFieldValues(Map values) { */ public FieldValueBuilder withFieldAccessDescriptors(Map values) { values.entrySet().stream() - .forEach(e -> fieldOverrides.addOverride(e.getKey(), new FieldOverride(e.getValue()))); + .forEach(e -> fieldOverrides.addOverride(e.getKey(), new FieldOverride(e.getValue()))); return this; } public Row build() { Row row = - (Row) - new RowFieldMatcher() - .match( - new CapturingRowCases(getSchema(), this.fieldOverrides), - FieldType.row(getSchema()), - new RowPosition(FieldAccessDescriptor.create()), - sourceRow); + (Row) + new RowFieldMatcher() + .match( + new CapturingRowCases(getSchema(), this.fieldOverrides), + FieldType.row(getSchema()), + new RowPosition(FieldAccessDescriptor.create()), + sourceRow); return row; } } @@ -759,7 +773,7 @@ public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { /** Set a field value using a FieldAccessDescriptor. */ public FieldValueBuilder withFieldValue( - FieldAccessDescriptor fieldAccessDescriptor, Object value) { + FieldAccessDescriptor fieldAccessDescriptor, Object value) { checkState(values.isEmpty()); return new FieldValueBuilder(schema, null).withFieldValue(fieldAccessDescriptor, value); } @@ -829,7 +843,7 @@ public int nextFieldId() { @Internal public Row withFieldValueGetters( - Factory> fieldValueGetterFactory, Object getterTarget) { + Factory> fieldValueGetterFactory, Object getterTarget) { checkState(getterTarget != null, "getters require withGetterTarget."); return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); } @@ -839,23 +853,23 @@ public Row build() { if (!values.isEmpty() && values.size() != schema.getFieldCount()) { throw new IllegalArgumentException( - "Row expected " - + schema.getFieldCount() - + " fields. initialized with " - + values.size() - + " fields."); + "Row expected " + + schema.getFieldCount() + + " fields. initialized with " + + values.size() + + " fields."); } if (!values.isEmpty()) { FieldOverrides fieldOverrides = new FieldOverrides(schema, this.values); if (!fieldOverrides.isEmpty()) { return (Row) - new RowFieldMatcher() - .match( - new CapturingRowCases(schema, fieldOverrides), - FieldType.row(schema), - new RowPosition(FieldAccessDescriptor.create()), - null); + new RowFieldMatcher() + .match( + new CapturingRowCases(schema, fieldOverrides), + FieldType.row(schema), + new RowPosition(FieldAccessDescriptor.create()), + null); } } return new RowWithStorage(schema, Collections.emptyList()); @@ -865,19 +879,19 @@ public Row build() { /** Creates a {@link Row} from the list of values and {@link #getSchema()}. 
*/ public static Collector, Row> toRow(Schema schema) { return Collector.of( - () -> new ArrayList<>(schema.getFieldCount()), - List::add, - (left, right) -> { - left.addAll(right); - return left; - }, - values -> Row.withSchema(schema).addValues(values).build()); + () -> new ArrayList<>(schema.getFieldCount()), + List::add, + (left, right) -> { + left.addAll(right); + return left; + }, + values -> Row.withSchema(schema).addValues(values).build()); } /** Creates a new record filled with nulls. */ public static Row nullRow(Schema schema) { return Row.withSchema(schema) - .addValues(Collections.nCopies(schema.getFieldCount(), null)) - .build(); + .addValues(Collections.nCopies(schema.getFieldCount(), null)) + .build(); } } diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java new file mode 100644 index 000000000000..21e889414542 --- /dev/null +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql.example; + +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.extensions.sql.SqlTransform; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.SimpleFunction; +import org.apache.beam.sdk.values.*; + +import java.util.Arrays; + +/** + * This is a quick example, which uses Beam SQL DSL to create a data pipeline. + * + *
+ * <p>Run the example from the Beam source root with
+ *
+ * <pre>
+ *   ./gradlew :sdks:java:extensions:sql:runBasicExample
+ * </pre>
+ *
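+ * <p>This example demonstrates querying an array of nested rows with CROSS JOIN UNNEST, the
+ * case tracked by BEAM-14026.
+ *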
+ * <p>The above command executes the example locally using the direct runner. Running the pipeline
+ * on other runners requires additional setup and is out of scope of the SQL examples. Please
+ * consult the Beam documentation on how to run pipelines.
+ */
+class BeamSqlNestedExample {
+
+  public static void main(String[] args) {
+    PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
+    Pipeline p = Pipeline.create(options);
+
+    // define the row schema for level3 (the innermost rows)
+    Schema level3Type =
+        Schema.builder().addInt32Field("c1").addStringField("c2").addDoubleField("c3").build();
+
+    Row level3Row1 = Row.withSchema(level3Type).addValues(1, "row", 1.0).build();
+    Row level3Row2 = Row.withSchema(level3Type).addValues(2, "row", 2.0).build();
+    Row level3Row3 = Row.withSchema(level3Type).addValues(3, "row", 3.0).build();
+
+    // define the row schema for level2, which nests a level3 row
+    Schema level2Type =
+        Schema.builder()
+            .addInt32Field("b1")
+            .addStringField("b2")
+            .addRowField("b3", level3Type)
+            .addDoubleField("b4")
+            .build();
+
+    Row level2Row1 = Row.withSchema(level2Type).addValues(1, "row", level3Row1, 1.0).build();
+    Row level2Row2 = Row.withSchema(level2Type).addValues(2, "row", level3Row2, 2.0).build();
+    Row level2Row3 = Row.withSchema(level2Type).addValues(3, "row", level3Row3, 3.0).build();
+
+    // define the row schema for level1, which holds an array of level2 rows
+    Schema level1Type =
+        Schema.builder()
+            .addInt32Field("a1")
+            .addStringField("a2")
+            .addDoubleField("a3")
+            .addArrayField("a4", Schema.FieldType.row(level2Type))
+            .build();
+    Row level1Row1 =
+        Row.withSchema(level1Type)
+            .addValues(1, "row", 1.0, Arrays.asList(level2Row1, level2Row2, level2Row3))
+            .build();
+    Row level1Row2 =
+        Row.withSchema(level1Type)
+            .addValues(2, "row", 2.0, Arrays.asList(level2Row1, level2Row2, level2Row3))
+            .build();
+    Row level1Row3 =
+        Row.withSchema(level1Type)
+            .addValues(3, "row", 3.0, Arrays.asList(level2Row1, level2Row2, level2Row3))
+            .build();
+
+    // create a source PCollection with Create.of()
+    PCollection<Row> inputTable =
+        PBegin.in(p).apply(Create.of(level1Row1, level1Row2, level1Row3).withRowSchema(level1Type));
+
+    String sql =
+        "select t.a1, t.a2, t.a3, d.b1, d.b2, d.b4, "
+            + "d.b3.c1, d.b3.c2, d.b3.c3 from test t cross join unnest(t.a4) d";
+    // Case 1. run a simple SQL query over input PCollection with BeamSql.simpleQuery;
+    PCollection<Row> dfTemp =
+        PCollectionTuple.of(new TupleTag<>("test"), inputTable).apply(SqlTransform.query(sql));
+
+    // print the output record of case 1;
+    Schema dfTempSchema = dfTemp.getSchema();
+    // without the fix it will throw the following exception:
+    // Caused by: java.lang.IllegalArgumentException: Row expected 10 fields. initialized with 8 fields.
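+    // The arithmetic behind "10 vs. 8", inferred from the schemas above: Calcite flattens the
+    // nested ROW b3 when it derives the unnest output schema, so the output row joins the 4
+    // level1 base values (including the array a4) with the 6 flattened level2 values
+    // (b1, b2, c1, c2, c3, b4), i.e. 10 fields, while nestedRow.getBaseValues() supplies only
+    // the 4 top-level values (b1, b2, b3, b4), i.e. 8 values together with the left side.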
+
+
+    // with the changes in Row.java
+    dfTemp
+        .apply(
+            "log_result",
+            MapElements.via(
+                new SimpleFunction<Row, Row>() {
+                  @Override
+                  public Row apply(Row input) {
+                    // expect output:
+                    // PCOLLECTION: [1, row, 1.0, 1, row, 1.0, 1, row, 1.0]
+                    // PCOLLECTION: [1, row, 1.0, 2, row, 2.0, 2, row, 2.0]
+                    // PCOLLECTION: [1, row, 1.0, 3, row, 3.0, 3, row, 3.0]
+                    // PCOLLECTION: [3, row, 3.0, 1, row, 1.0, 1, row, 1.0]
+                    // PCOLLECTION: [3, row, 3.0, 2, row, 2.0, 2, row, 2.0]
+                    // PCOLLECTION: [3, row, 3.0, 3, row, 3.0, 3, row, 3.0]
+                    // PCOLLECTION: [2, row, 2.0, 1, row, 1.0, 1, row, 1.0]
+                    // PCOLLECTION: [2, row, 2.0, 2, row, 2.0, 2, row, 2.0]
+                    // PCOLLECTION: [2, row, 2.0, 3, row, 3.0, 3, row, 3.0]
+                    System.out.println("PCOLLECTION: " + input.getValues());
+                    return input;
+                  }
+                }))
+        .setRowSchema(dfTempSchema);
+
+    p.run().waitUntilFinish();
+  }
+}

From 866686551e9a761446a28f5c65baa293cdfba61e Mon Sep 17 00:00:00 2001
From: abhijeet-lele
Date: Wed, 2 Mar 2022 22:23:19 +0530
Subject: [PATCH 2/4] Beam-14026 Suggested changes to handle nested row in an
 array

---
 sdks/java/extensions/sql/build.gradle                    | 8 ++++++++
 .../sdk/extensions/sql/example/BeamSqlNestedExample.java | 2 +-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle
index e6f021865225..8c0860517f77 100644
--- a/sdks/java/extensions/sql/build.gradle
+++ b/sdks/java/extensions/sql/build.gradle
@@ -206,6 +206,14 @@ task runBasicExample(type: JavaExec) {
   args = ["--runner=DirectRunner"]
 }

+// Run SQL example that unnests rows nested in an array
+task runNestedRowInArrayExample(type: JavaExec) {
+  description = "Run SQL example for nested rows in an array"
+  mainClass = "org.apache.beam.sdk.extensions.sql.example.BeamSqlNestedExample"
+  classpath = sourceSets.main.runtimeClasspath
+  args = ["--runner=DirectRunner"]
+}
+
 // Run SQL example on POJO inputs
 task runPojoExample(type: JavaExec) {
   description = "Run SQL example for PCollections of POJOs"
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java
index 21e889414542..5c45cc464ecf 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java
@@ -35,7 +35,7 @@
 * <p>Run the example from the Beam source root with
 *
 * <pre>
- *   ./gradlew :sdks:java:extensions:sql:runBasicExample
+ *   ./gradlew :sdks:java:extensions:sql:runNestedRowInArrayExample
 * </pre>
 *
 * <p>The above command executes the example locally using the direct runner. Running the pipeline

From 4d40ba0ba19bc4ce63d0e4576bc1e6900caa5f22 Mon Sep 17 00:00:00 2001
From: abhijeet-lele
Date: Thu, 3 Mar 2022 23:39:56 +0530
Subject: [PATCH 3/4] Beam-14026 Enhanced by segregating the code from
 getBaseValues; enhanced the test case and the example.

---
 .../java/org/apache/beam/sdk/values/Row.java  |  14 ++-
 sdks/java/extensions/sql/build.gradle         |   2 +-
 ...Example.java => BeamSqlUnnestExample.java} |   5 +-
 .../sql/impl/rel/BeamUnnestRel.java           |   2 +-
 .../sql/BeamSqlDslUnnestRowsTest.java         | 119 ++++++++++++++++++
 5 files changed, 134 insertions(+), 8 deletions(-)
 rename sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/{BeamSqlNestedExample.java => BeamSqlUnnestExample.java} (97%)
 create mode 100644 sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUnnestRowsTest.java

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
index 17fe025b94be..da8536896e15 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
@@ -109,15 +109,17 @@ public abstract class Row implements Serializable {
   /** Return the list of data values. */
   public abstract List getValues();

-  /** Return a list of data values. Any LogicalType values are returned as base values. * */
-  public List getBaseValues() {
+  /**
+   * This is a recursive call to get all the values of the nested rows. The recursion is bounded
+   * by the amount of nesting within the data. This mirrors the unnest behavior of Calcite
+   * towards the schema.
+   */
+  public List getNestedRowBaseValues() {
     return IntStream.range(0, getFieldCount())
         .mapToObj(i -> {
           List values = new ArrayList<>();
           FieldType fieldType = this.getSchema().getField(i).getType();
           if(fieldType.getTypeName().equals(TypeName.ROW)) {
             Row row = this.getBaseValue(i, Row.class);
-            List rowValues = row.getBaseValues();
+            List rowValues = row.getNestedRowBaseValues();
             if(null != rowValues) {
               values.addAll(rowValues);
             }
           } else {
             values.add(this.getBaseValue(i));
           }
           return values.stream();
         }).flatMap(Function.identity()).collect(Collectors.toList());
+  }

+  /** Return a list of data values. Any LogicalType values are returned as base values. * */
+  public List getBaseValues() {
+    return IntStream.range(0, getFieldCount())
+        .mapToObj(i -> getBaseValue(i))
+        .collect(Collectors.toList());
   }

   /** Get value by field name, {@link ClassCastException} is thrown if type doesn't match.
 */
diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle
index 8c0860517f77..4fda5e0b0ea8 100644
--- a/sdks/java/extensions/sql/build.gradle
+++ b/sdks/java/extensions/sql/build.gradle
@@ -209,7 +209,7 @@ task runBasicExample(type: JavaExec) {
 // Run SQL example that unnests rows nested in an array
 task runNestedRowInArrayExample(type: JavaExec) {
   description = "Run SQL example for nested rows in an array"
-  mainClass = "org.apache.beam.sdk.extensions.sql.example.BeamSqlNestedExample"
+  mainClass = "org.apache.beam.sdk.extensions.sql.example.BeamSqlUnnestExample"
   classpath = sourceSets.main.runtimeClasspath
   args = ["--runner=DirectRunner"]
 }
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlUnnestExample.java
similarity index 97%
rename from sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java
rename to sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlUnnestExample.java
index 5c45cc464ecf..b5cb42e07a91 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlNestedExample.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlUnnestExample.java
@@ -42,7 +42,7 @@
 * on other runners requires additional setup and is out of scope of the SQL examples. Please
 * consult the Beam documentation on how to run pipelines.
 */
-class BeamSqlNestedExample {
+class BeamSqlUnnestExample {

   public static void main(String[] args) {
     PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
     Pipeline p = Pipeline.create(options);
@@ -87,8 +87,7 @@ public static void main(String[] args) {
     PCollection<Row> inputTable =
         PBegin.in(p).apply(Create.of(level1Row1, level1Row2, level1Row3).withRowSchema(level1Type));

-    String sql =
-        "select t.a1, t.a2, t.a3, d.b1, d.b2, d.b4, "
-            + "d.b3.c1, d.b3.c2, d.b3.c3 from test t cross join unnest(t.a4) d";
+    String sql = "select t.a1, t.a2, t.a3, d.b1, d.b2, d.b4, d.b3.c1, d.b3.c2, d.b3.c3 from test t cross join unnest(t.a4) d";
     // Case 1.
run a simple SQL query over input PCollection with BeamSql.simpleQuery; PCollection dfTemp = PCollectionTuple.of(new TupleTag<>("test"), inputTable).apply(SqlTransform.query(sql)); diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java index 655d75a01e1b..9e6145eb93c3 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java @@ -157,7 +157,7 @@ public void process(@Element Row row, OutputReceiver out) { out.output( Row.withSchema(outputSchema) .addValues(row.getBaseValues()) - .addValues(nestedRow.getBaseValues()) + .addValues(nestedRow.getNestedRowBaseValues()) .build()); } else { out.output( diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUnnestRowsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUnnestRowsTest.java new file mode 100644 index 000000000000..cd3e10db64ae --- /dev/null +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUnnestRowsTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.extensions.sql; + +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.*; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import java.util.Arrays; + +/** + * Tests for nested rows handling. + */ +public class BeamSqlDslUnnestRowsTest { + + @Rule + public final TestPipeline pipeline = TestPipeline.create(); + @Rule + public ExpectedException exceptions = ExpectedException.none(); + + /** + * TODO([BEAM-14026]): This is a test of the incorrect behavior unnest + * because calcite flattens the row. 
+ */ + @Test + public void testUnnestArrayWithNestedRows() { + + Schema level3Type = + Schema.builder().addInt32Field("c1").addStringField("c2").addDoubleField("c3").build(); + + Row level3Row1 = Row.withSchema(level3Type).addValues(1, "row", 1.0).build(); + Row level3Row2 = Row.withSchema(level3Type).addValues(2, "row", 2.0).build(); + Row level3Row3 = Row.withSchema(level3Type).addValues(3, "row", 3.0).build(); + + // define the input row format level3 + Schema level2Type = + Schema.builder().addInt32Field("b1") + .addStringField("b2") + .addRowField("b3", level3Type) + .addDoubleField("b4").build(); + + + Row level2Row1 = Row.withSchema(level2Type).addValues(1, "row", level3Row1, 1.0).build(); + Row level2Row2 = Row.withSchema(level2Type).addValues(2, "row", level3Row2, 2.0).build(); + Row level2Row3 = Row.withSchema(level2Type).addValues(3, "row", level3Row3, 3.0).build(); + + // define the input row format level3 + Schema level1Type = + Schema.builder().addInt32Field("a1") + .addStringField("a2") + .addDoubleField("a3") + .addArrayField("a4", Schema.FieldType.row(level2Type)) + .build(); + Row level1Row1 = Row.withSchema(level1Type).addValues(1, "row", 1.0, + Arrays.asList(level2Row1, level2Row2, level2Row3)).build(); + Row level1Row2 = Row.withSchema(level1Type).addValues(2, "row", 2.0, + Arrays.asList(level2Row1, level2Row2, level2Row3)).build(); + Row level1Row3 = Row.withSchema(level1Type).addValues(3, "row", 3.0, + Arrays.asList(level2Row1, level2Row2, level2Row3)).build(); + + + // create a source PCollection with Create.of(); + PCollection inputTable = + PBegin.in(pipeline).apply(Create.of(level1Row1, level1Row2, level1Row3).withRowSchema(level1Type)); + + String sql = "select t.a1, t.a2, t.a3, d.b1, d.b2, d.b4, " + + "d.b3.c1, d.b3.c2, d.b3.c3 from test t cross join unnest(t.a4) d"; + // Case 1. run a simple SQL query over input PCollection with BeamSql.simpleQuery; + PCollection result = + PCollectionTuple.of(new TupleTag<>("test"), inputTable).apply(SqlTransform.query(sql)); + + + Schema resultSchema = + Schema.builder().addInt32Field("a1") + .addStringField("a2") + .addDoubleField("a3") + .addInt32Field("b1") + .addStringField("b2") + .addDoubleField("b4") + .addInt32Field("c1") + .addStringField("c2") + .addDoubleField("c3") + .build(); + + PAssert.that(result) + .containsInAnyOrder( + Row.withSchema(resultSchema).addValues(1, "row", 1.0, 1, "row", 1.0, 1, "row", 1.0).build(), + Row.withSchema(resultSchema).addValues(1, "row", 1.0, 2, "row", 2.0, 2, "row", 2.0).build(), + Row.withSchema(resultSchema).addValues(1, "row", 1.0, 3, "row", 3.0, 3, "row", 3.0).build(), + Row.withSchema(resultSchema).addValues(3, "row", 3.0, 1, "row", 1.0, 1, "row", 1.0).build(), + Row.withSchema(resultSchema).addValues(3, "row", 3.0, 2, "row", 2.0, 2, "row", 2.0).build(), + Row.withSchema(resultSchema).addValues(3, "row", 3.0, 3, "row", 3.0, 3, "row", 3.0).build(), + Row.withSchema(resultSchema).addValues(2, "row", 2.0, 1, "row", 1.0, 1, "row", 1.0).build(), + Row.withSchema(resultSchema).addValues(2, "row", 2.0, 2, "row", 2.0, 2, "row", 2.0).build(), + Row.withSchema(resultSchema).addValues(2, "row", 2.0, 3, "row", 3.0, 3, "row", 3.0).build() + ); + pipeline.run(); + } + +} From 9c2942e91dd46c3092179ead22f9ab20762f8f05 Mon Sep 17 00:00:00 2001 From: abhijeet-lele Date: Fri, 4 Mar 2022 20:30:03 +0530 Subject: [PATCH 4/4] Beam-14026 The code is moved from Row to avoid impact to the public interface. The code is moved to BeamUnnestRel.java since its the caller class. 
The example code was duplicated, hence it was dropped; build.gradle was updated to remove the
example task.
---
 .../java/org/apache/beam/sdk/values/Row.java  | 120 +++++++--------
 sdks/java/extensions/sql/build.gradle         |   8 -
 .../sql/example/BeamSqlUnnestExample.java     | 130 -----------------
 .../sql/impl/rel/BeamUnnestRel.java           |  35 ++++-
 .../sql/BeamSqlDslUnnestRowsTest.java         | 137 ++++++++++--------
 5 files changed, 163 insertions(+), 267 deletions(-)
 delete mode 100644 sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlUnnestExample.java

diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
index da8536896e15..9dd02d32c2f3 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/Row.java
@@ -31,7 +31,6 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import java.util.function.Function;
 import java.util.stream.Collector;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
@@ -87,8 +86,8 @@
 */
 @Experimental(Kind.SCHEMAS)
 @SuppressWarnings({
-    "nullness", // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
-    "rawtypes"
+  "nullness", // TODO(https://issues.apache.org/jira/browse/BEAM-10402)
+  "rawtypes"
 })
 public abstract class Row implements Serializable {
   private final Schema schema;
@@ -109,32 +108,11 @@ public abstract class Row implements Serializable {
   /** Return the list of data values. */
   public abstract List getValues();

-  /**
-   * This is a recursive call to get all the values of the nested rows. The recursion is bounded
-   * by the amount of nesting within the data. This mirrors the unnest behavior of Calcite
-   * towards the schema.
-   */
-  public List getNestedRowBaseValues() {
-    return IntStream.range(0, getFieldCount())
-        .mapToObj(i -> {
-          List values = new ArrayList<>();
-          FieldType fieldType = this.getSchema().getField(i).getType();
-          if(fieldType.getTypeName().equals(TypeName.ROW)) {
-            Row row = this.getBaseValue(i, Row.class);
-            List rowValues = row.getNestedRowBaseValues();
-            if(null != rowValues) {
-              values.addAll(rowValues);
-            }
-          } else {
-            values.add(this.getBaseValue(i));
-          }
-          return values.stream();
-        }).flatMap(Function.identity()).collect(Collectors.toList());
-  }
-
   /** Return a list of data values. Any LogicalType values are returned as base values. * */
   public List getBaseValues() {
     return IntStream.range(0, getFieldCount())
-        .mapToObj(i -> getBaseValue(i))
-        .collect(Collectors.toList());
+        .mapToObj(i -> getBaseValue(i))
+        .collect(Collectors.toList());
   }

   /** Get value by field name, {@link ClassCastException} is thrown if type doesn't match.
*/ @@ -487,13 +465,13 @@ public static boolean deepEquals(Object a, Object b, Schema.FieldType fieldType) return Arrays.equals((byte[]) a, (byte[]) b); } else if (fieldType.getTypeName() == TypeName.ARRAY) { return deepEqualsForCollection( - (Collection) a, (Collection) b, fieldType.getCollectionElementType()); + (Collection) a, (Collection) b, fieldType.getCollectionElementType()); } else if (fieldType.getTypeName() == TypeName.ITERABLE) { return deepEqualsForIterable( - (Iterable) a, (Iterable) b, fieldType.getCollectionElementType()); + (Iterable) a, (Iterable) b, fieldType.getCollectionElementType()); } else if (fieldType.getTypeName() == Schema.TypeName.MAP) { return deepEqualsForMap( - (Map) a, (Map) b, fieldType.getMapValueType()); + (Map) a, (Map) b, fieldType.getMapValueType()); } else { return Objects.equals(a, b); } @@ -510,7 +488,7 @@ public static int deepHashCode(Object a, Schema.FieldType fieldType) { return deepHashCodeForIterable((Iterable) a, fieldType.getCollectionElementType()); } else if (fieldType.getTypeName() == Schema.TypeName.MAP) { return deepHashCodeForMap( - (Map) a, fieldType.getMapKeyType(), fieldType.getMapValueType()); + (Map) a, fieldType.getMapKeyType(), fieldType.getMapValueType()); } else { return Objects.hashCode(a); } @@ -545,7 +523,7 @@ static boolean deepEqualsForMap(Map a, Map b, Schema.FieldTyp } static int deepHashCodeForMap( - Map a, Schema.FieldType keyType, Schema.FieldType valueType) { + Map a, Schema.FieldType keyType, Schema.FieldType valueType) { int h = 0; for (Map.Entry e : a.entrySet()) { @@ -559,7 +537,7 @@ static int deepHashCodeForMap( } static boolean deepEqualsForCollection( - Collection a, Collection b, Schema.FieldType elementType) { + Collection a, Collection b, Schema.FieldType elementType) { if (a == b) { return true; } @@ -572,7 +550,7 @@ static boolean deepEqualsForCollection( } static boolean deepEqualsForIterable( - Iterable a, Iterable b, Schema.FieldType elementType) { + Iterable a, Iterable b, Schema.FieldType elementType) { if (a == b) { return true; } @@ -627,7 +605,7 @@ private String toString(Schema.FieldType fieldType, Object value, boolean includ builder.append("["); for (Object element : (Iterable) value) { builder.append( - toString(fieldType.getCollectionElementType(), element, includeFieldNames)); + toString(fieldType.getCollectionElementType(), element, includeFieldNames)); builder.append(", "); } builder.append("]"); @@ -639,7 +617,7 @@ private String toString(Schema.FieldType fieldType, Object value, boolean includ builder.append(toString(fieldType.getMapKeyType(), entry.getKey(), includeFieldNames)); builder.append(", "); builder.append( - toString(fieldType.getMapValueType(), entry.getValue(), includeFieldNames)); + toString(fieldType.getMapValueType(), entry.getValue(), includeFieldNames)); builder.append("), "); } builder.append("}"); @@ -706,7 +684,7 @@ public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { /** Set a field value using a FieldAccessDescriptor. 
*/ public FieldValueBuilder withFieldValue( - FieldAccessDescriptor fieldAccessDescriptor, Object value) { + FieldAccessDescriptor fieldAccessDescriptor, Object value) { FieldAccessDescriptor fieldAccess = fieldAccessDescriptor.resolve(getSchema()); checkArgument(fieldAccess.referencesSingleField(), ""); fieldOverrides.addOverride(fieldAccess, new FieldOverride(value)); @@ -719,11 +697,11 @@ public FieldValueBuilder withFieldValue( */ public FieldValueBuilder withFieldValues(Map values) { values.entrySet().stream() - .forEach( - e -> - fieldOverrides.addOverride( - FieldAccessDescriptor.withFieldNames(e.getKey()).resolve(getSchema()), - new FieldOverride(e.getValue()))); + .forEach( + e -> + fieldOverrides.addOverride( + FieldAccessDescriptor.withFieldNames(e.getKey()).resolve(getSchema()), + new FieldOverride(e.getValue()))); return this; } @@ -733,19 +711,19 @@ public FieldValueBuilder withFieldValues(Map values) { */ public FieldValueBuilder withFieldAccessDescriptors(Map values) { values.entrySet().stream() - .forEach(e -> fieldOverrides.addOverride(e.getKey(), new FieldOverride(e.getValue()))); + .forEach(e -> fieldOverrides.addOverride(e.getKey(), new FieldOverride(e.getValue()))); return this; } public Row build() { Row row = - (Row) - new RowFieldMatcher() - .match( - new CapturingRowCases(getSchema(), this.fieldOverrides), - FieldType.row(getSchema()), - new RowPosition(FieldAccessDescriptor.create()), - sourceRow); + (Row) + new RowFieldMatcher() + .match( + new CapturingRowCases(getSchema(), this.fieldOverrides), + FieldType.row(getSchema()), + new RowPosition(FieldAccessDescriptor.create()), + sourceRow); return row; } } @@ -781,7 +759,7 @@ public FieldValueBuilder withFieldValue(Integer fieldId, Object value) { /** Set a field value using a FieldAccessDescriptor. */ public FieldValueBuilder withFieldValue( - FieldAccessDescriptor fieldAccessDescriptor, Object value) { + FieldAccessDescriptor fieldAccessDescriptor, Object value) { checkState(values.isEmpty()); return new FieldValueBuilder(schema, null).withFieldValue(fieldAccessDescriptor, value); } @@ -851,7 +829,7 @@ public int nextFieldId() { @Internal public Row withFieldValueGetters( - Factory> fieldValueGetterFactory, Object getterTarget) { + Factory> fieldValueGetterFactory, Object getterTarget) { checkState(getterTarget != null, "getters require withGetterTarget."); return new RowWithGetters(schema, fieldValueGetterFactory, getterTarget); } @@ -861,23 +839,23 @@ public Row build() { if (!values.isEmpty() && values.size() != schema.getFieldCount()) { throw new IllegalArgumentException( - "Row expected " - + schema.getFieldCount() - + " fields. initialized with " - + values.size() - + " fields."); + "Row expected " + + schema.getFieldCount() + + " fields. initialized with " + + values.size() + + " fields."); } if (!values.isEmpty()) { FieldOverrides fieldOverrides = new FieldOverrides(schema, this.values); if (!fieldOverrides.isEmpty()) { return (Row) - new RowFieldMatcher() - .match( - new CapturingRowCases(schema, fieldOverrides), - FieldType.row(schema), - new RowPosition(FieldAccessDescriptor.create()), - null); + new RowFieldMatcher() + .match( + new CapturingRowCases(schema, fieldOverrides), + FieldType.row(schema), + new RowPosition(FieldAccessDescriptor.create()), + null); } } return new RowWithStorage(schema, Collections.emptyList()); @@ -887,19 +865,19 @@ public Row build() { /** Creates a {@link Row} from the list of values and {@link #getSchema()}. 
*/ public static Collector, Row> toRow(Schema schema) { return Collector.of( - () -> new ArrayList<>(schema.getFieldCount()), - List::add, - (left, right) -> { - left.addAll(right); - return left; - }, - values -> Row.withSchema(schema).addValues(values).build()); + () -> new ArrayList<>(schema.getFieldCount()), + List::add, + (left, right) -> { + left.addAll(right); + return left; + }, + values -> Row.withSchema(schema).addValues(values).build()); } /** Creates a new record filled with nulls. */ public static Row nullRow(Schema schema) { return Row.withSchema(schema) - .addValues(Collections.nCopies(schema.getFieldCount(), null)) - .build(); + .addValues(Collections.nCopies(schema.getFieldCount(), null)) + .build(); } } diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle index 4fda5e0b0ea8..e6f021865225 100644 --- a/sdks/java/extensions/sql/build.gradle +++ b/sdks/java/extensions/sql/build.gradle @@ -206,14 +206,6 @@ task runBasicExample(type: JavaExec) { args = ["--runner=DirectRunner"] } -// Run basic SQL example -task runNestedRowInArrayExample(type: JavaExec) { - description = "Run basic SQL example" - mainClass = "org.apache.beam.sdk.extensions.sql.example.BeamSqlUnnestExample" - classpath = sourceSets.main.runtimeClasspath - args = ["--runner=DirectRunner"] -} - // Run SQL example on POJO inputs task runPojoExample(type: JavaExec) { description = "Run SQL example for PCollections of POJOs" diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlUnnestExample.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlUnnestExample.java deleted file mode 100644 index b5cb42e07a91..000000000000 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/example/BeamSqlUnnestExample.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.extensions.sql.example; - -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.extensions.sql.SqlTransform; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.schemas.Schema; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.SimpleFunction; -import org.apache.beam.sdk.values.*; - -import java.util.Arrays; - -/** - * This is a quick example, which uses Beam SQL DSL to create a data pipeline. - * - *
- * <p>Run the example from the Beam source root with
- *
- * <pre>
- *   ./gradlew :sdks:java:extensions:sql:runNestedRowInArrayExample
- * </pre>
- *
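- * <p>This example demonstrates querying an array of nested rows with CROSS JOIN UNNEST, the
- * case tracked by BEAM-14026.
- *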
- * <p>The above command executes the example locally using the direct runner. Running the pipeline
- * on other runners requires additional setup and is out of scope of the SQL examples. Please
- * consult the Beam documentation on how to run pipelines.
- */
-class BeamSqlUnnestExample {
-
-  public static void main(String[] args) {
-    PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
-    Pipeline p = Pipeline.create(options);
-
-    // define the row schema for level3 (the innermost rows)
-    Schema level3Type =
-        Schema.builder().addInt32Field("c1").addStringField("c2").addDoubleField("c3").build();
-
-    Row level3Row1 = Row.withSchema(level3Type).addValues(1, "row", 1.0).build();
-    Row level3Row2 = Row.withSchema(level3Type).addValues(2, "row", 2.0).build();
-    Row level3Row3 = Row.withSchema(level3Type).addValues(3, "row", 3.0).build();
-
-    // define the row schema for level2, which nests a level3 row
-    Schema level2Type =
-        Schema.builder()
-            .addInt32Field("b1")
-            .addStringField("b2")
-            .addRowField("b3", level3Type)
-            .addDoubleField("b4")
-            .build();
-
-    Row level2Row1 = Row.withSchema(level2Type).addValues(1, "row", level3Row1, 1.0).build();
-    Row level2Row2 = Row.withSchema(level2Type).addValues(2, "row", level3Row2, 2.0).build();
-    Row level2Row3 = Row.withSchema(level2Type).addValues(3, "row", level3Row3, 3.0).build();
-
-    // define the row schema for level1, which holds an array of level2 rows
-    Schema level1Type =
-        Schema.builder()
-            .addInt32Field("a1")
-            .addStringField("a2")
-            .addDoubleField("a3")
-            .addArrayField("a4", Schema.FieldType.row(level2Type))
-            .build();
-    Row level1Row1 =
-        Row.withSchema(level1Type)
-            .addValues(1, "row", 1.0, Arrays.asList(level2Row1, level2Row2, level2Row3))
-            .build();
-    Row level1Row2 =
-        Row.withSchema(level1Type)
-            .addValues(2, "row", 2.0, Arrays.asList(level2Row1, level2Row2, level2Row3))
-            .build();
-    Row level1Row3 =
-        Row.withSchema(level1Type)
-            .addValues(3, "row", 3.0, Arrays.asList(level2Row1, level2Row2, level2Row3))
-            .build();
-
-    // create a source PCollection with Create.of()
-    PCollection<Row> inputTable =
-        PBegin.in(p).apply(Create.of(level1Row1, level1Row2, level1Row3).withRowSchema(level1Type));
-
-    String sql = "select t.a1, t.a2, t.a3, d.b1, d.b2, d.b4, d.b3.c1, d.b3.c2, d.b3.c3 from test t cross join unnest(t.a4) d";
-    // Case 1. run a simple SQL query over input PCollection with BeamSql.simpleQuery;
-    PCollection<Row> dfTemp =
-        PCollectionTuple.of(new TupleTag<>("test"), inputTable).apply(SqlTransform.query(sql));
-
-    // print the output record of case 1;
-    Schema dfTempSchema = dfTemp.getSchema();
-    // without the fix it will throw the following exception:
-    // Caused by: java.lang.IllegalArgumentException: Row expected 10 fields. initialized with 8 fields.
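-    // The arithmetic behind "10 vs. 8", inferred from the schemas above: Calcite flattens the
-    // nested ROW b3 when it derives the unnest output schema, so the output row joins the 4
-    // level1 base values (including the array a4) with the 6 flattened level2 values
-    // (b1, b2, c1, c2, c3, b4), i.e. 10 fields, while nestedRow.getBaseValues() supplies only
-    // the 4 top-level values (b1, b2, b3, b4), i.e. 8 values together with the left side.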
- - - // with the changes in the Row.Java - dfTemp - .apply( - "log_result", - MapElements.via( - new SimpleFunction() { - @Override - public Row apply(Row input) { - // expect output: - // PCOLLECTION: [1, row, 1.0, 1, row, 1.0, 1, row, 1.0] - // PCOLLECTION: [1, row, 1.0, 2, row, 2.0, 2, row, 2.0] - // PCOLLECTION: [1, row, 1.0, 3, row, 3.0, 3, row, 3.0] - // PCOLLECTION: [3, row, 3.0, 1, row, 1.0, 1, row, 1.0] - // PCOLLECTION: [3, row, 3.0, 2, row, 2.0, 2, row, 2.0] - // PCOLLECTION: [3, row, 3.0, 3, row, 3.0, 3, row, 3.0] - // PCOLLECTION: [2, row, 2.0, 1, row, 1.0, 1, row, 1.0] - // PCOLLECTION: [2, row, 2.0, 2, row, 2.0, 2, row, 2.0] - // PCOLLECTION: [2, row, 2.0, 3, row, 3.0, 3, row, 3.0] - - System.out.println("PCOLLECTION: " + input.getValues()); - return input; - } - })) - .setRowSchema(dfTempSchema); - - p.run().waitUntilFinish(); - } -} diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java index 9e6145eb93c3..d454dfe1aa00 100644 --- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java +++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamUnnestRel.java @@ -17,8 +17,13 @@ */ package org.apache.beam.sdk.extensions.sql.impl.rel; +import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; import org.apache.beam.sdk.extensions.sql.impl.planner.BeamCostModel; import org.apache.beam.sdk.extensions.sql.impl.planner.BeamRelMetadataQuery; import org.apache.beam.sdk.extensions.sql.impl.planner.NodeStats; @@ -129,6 +134,34 @@ private UnnestFn(Schema outputSchema, List unnestIndices) { this.outputSchema = outputSchema; this.unnestIndices = unnestIndices; } + /** + * This is recursive call to get all the values of the nested rows. The recusion is bounded by + * the amount of nesting with in the data. This mirrors the unnest behavior of calcite towards + * schema. 
* + */ + private List getNestedRowBaseValues(Row nestedRow) { + return IntStream.range(0, nestedRow.getFieldCount()) + .mapToObj( + (i) -> { + List values = new ArrayList<>(); + Schema.FieldType fieldType = nestedRow.getSchema().getField(i).getType(); + if (fieldType.getTypeName().equals(Schema.TypeName.ROW)) { + @Nullable Row row = nestedRow.getBaseValue(i, Row.class); + if (row == null) { + return Stream.builder().build(); + } + List rowValues = getNestedRowBaseValues(row); + if (null != rowValues) { + values.addAll(rowValues); + } + } else { + values.add(nestedRow.getBaseValue(i)); + } + return values.stream(); + }) + .flatMap(Function.identity()) + .collect(Collectors.toList()); + } @ProcessElement public void process(@Element Row row, OutputReceiver out) { @@ -157,7 +190,7 @@ public void process(@Element Row row, OutputReceiver out) { out.output( Row.withSchema(outputSchema) .addValues(row.getBaseValues()) - .addValues(nestedRow.getNestedRowBaseValues()) + .addValues(getNestedRowBaseValues(nestedRow)) .build()); } else { out.output( diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUnnestRowsTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUnnestRowsTest.java index cd3e10db64ae..d72bc249154a 100644 --- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUnnestRowsTest.java +++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslUnnestRowsTest.java @@ -17,35 +17,33 @@ */ package org.apache.beam.sdk.extensions.sql; +import java.util.Arrays; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.*; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.TupleTag; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; -import java.util.Arrays; -/** - * Tests for nested rows handling. - */ +/** Tests for nested rows handling. */ public class BeamSqlDslUnnestRowsTest { - @Rule - public final TestPipeline pipeline = TestPipeline.create(); - @Rule - public ExpectedException exceptions = ExpectedException.none(); + @Rule public final TestPipeline pipeline = TestPipeline.create(); /** - * TODO([BEAM-14026]): This is a test of the incorrect behavior unnest - * because calcite flattens the row. + * TODO([BEAM-14026]): This is a test of the incorrect behavior unnest because calcite flattens + * the row. 
*/ @Test public void testUnnestArrayWithNestedRows() { Schema level3Type = - Schema.builder().addInt32Field("c1").addStringField("c2").addDoubleField("c3").build(); + Schema.builder().addInt32Field("c1").addStringField("c2").addDoubleField("c3").build(); Row level3Row1 = Row.withSchema(level3Type).addValues(1, "row", 1.0).build(); Row level3Row2 = Row.withSchema(level3Type).addValues(2, "row", 2.0).build(); @@ -53,11 +51,12 @@ public void testUnnestArrayWithNestedRows() { // define the input row format level3 Schema level2Type = - Schema.builder().addInt32Field("b1") - .addStringField("b2") - .addRowField("b3", level3Type) - .addDoubleField("b4").build(); - + Schema.builder() + .addInt32Field("b1") + .addStringField("b2") + .addRowField("b3", level3Type) + .addDoubleField("b4") + .build(); Row level2Row1 = Row.withSchema(level2Type).addValues(1, "row", level3Row1, 1.0).build(); Row level2Row2 = Row.withSchema(level2Type).addValues(2, "row", level3Row2, 2.0).build(); @@ -65,55 +64,79 @@ public void testUnnestArrayWithNestedRows() { // define the input row format level3 Schema level1Type = - Schema.builder().addInt32Field("a1") - .addStringField("a2") - .addDoubleField("a3") - .addArrayField("a4", Schema.FieldType.row(level2Type)) - .build(); - Row level1Row1 = Row.withSchema(level1Type).addValues(1, "row", 1.0, - Arrays.asList(level2Row1, level2Row2, level2Row3)).build(); - Row level1Row2 = Row.withSchema(level1Type).addValues(2, "row", 2.0, - Arrays.asList(level2Row1, level2Row2, level2Row3)).build(); - Row level1Row3 = Row.withSchema(level1Type).addValues(3, "row", 3.0, - Arrays.asList(level2Row1, level2Row2, level2Row3)).build(); - + Schema.builder() + .addInt32Field("a1") + .addStringField("a2") + .addDoubleField("a3") + .addArrayField("a4", Schema.FieldType.row(level2Type)) + .build(); + Row level1Row1 = + Row.withSchema(level1Type) + .addValues(1, "row", 1.0, Arrays.asList(level2Row1, level2Row2, level2Row3)) + .build(); + Row level1Row2 = + Row.withSchema(level1Type) + .addValues(2, "row", 2.0, Arrays.asList(level2Row1, level2Row2, level2Row3)) + .build(); + Row level1Row3 = + Row.withSchema(level1Type) + .addValues(3, "row", 3.0, Arrays.asList(level2Row1, level2Row2, level2Row3)) + .build(); // create a source PCollection with Create.of(); PCollection inputTable = - PBegin.in(pipeline).apply(Create.of(level1Row1, level1Row2, level1Row3).withRowSchema(level1Type)); + PBegin.in(pipeline) + .apply(Create.of(level1Row1, level1Row2, level1Row3).withRowSchema(level1Type)); - String sql = "select t.a1, t.a2, t.a3, d.b1, d.b2, d.b4, " + - "d.b3.c1, d.b3.c2, d.b3.c3 from test t cross join unnest(t.a4) d"; + String sql = + "select t.a1, t.a2, t.a3, d.b1, d.b2, d.b4, " + + "d.b3.c1, d.b3.c2, d.b3.c3 from test t cross join unnest(t.a4) d"; // Case 1. 
run a simple SQL query over input PCollection with BeamSql.simpleQuery; PCollection result = - PCollectionTuple.of(new TupleTag<>("test"), inputTable).apply(SqlTransform.query(sql)); - + PCollectionTuple.of(new TupleTag<>("test"), inputTable).apply(SqlTransform.query(sql)); Schema resultSchema = - Schema.builder().addInt32Field("a1") - .addStringField("a2") - .addDoubleField("a3") - .addInt32Field("b1") - .addStringField("b2") - .addDoubleField("b4") - .addInt32Field("c1") - .addStringField("c2") - .addDoubleField("c3") - .build(); + Schema.builder() + .addInt32Field("a1") + .addStringField("a2") + .addDoubleField("a3") + .addInt32Field("b1") + .addStringField("b2") + .addDoubleField("b4") + .addInt32Field("c1") + .addStringField("c2") + .addDoubleField("c3") + .build(); PAssert.that(result) - .containsInAnyOrder( - Row.withSchema(resultSchema).addValues(1, "row", 1.0, 1, "row", 1.0, 1, "row", 1.0).build(), - Row.withSchema(resultSchema).addValues(1, "row", 1.0, 2, "row", 2.0, 2, "row", 2.0).build(), - Row.withSchema(resultSchema).addValues(1, "row", 1.0, 3, "row", 3.0, 3, "row", 3.0).build(), - Row.withSchema(resultSchema).addValues(3, "row", 3.0, 1, "row", 1.0, 1, "row", 1.0).build(), - Row.withSchema(resultSchema).addValues(3, "row", 3.0, 2, "row", 2.0, 2, "row", 2.0).build(), - Row.withSchema(resultSchema).addValues(3, "row", 3.0, 3, "row", 3.0, 3, "row", 3.0).build(), - Row.withSchema(resultSchema).addValues(2, "row", 2.0, 1, "row", 1.0, 1, "row", 1.0).build(), - Row.withSchema(resultSchema).addValues(2, "row", 2.0, 2, "row", 2.0, 2, "row", 2.0).build(), - Row.withSchema(resultSchema).addValues(2, "row", 2.0, 3, "row", 3.0, 3, "row", 3.0).build() - ); + .containsInAnyOrder( + Row.withSchema(resultSchema) + .addValues(1, "row", 1.0, 1, "row", 1.0, 1, "row", 1.0) + .build(), + Row.withSchema(resultSchema) + .addValues(1, "row", 1.0, 2, "row", 2.0, 2, "row", 2.0) + .build(), + Row.withSchema(resultSchema) + .addValues(1, "row", 1.0, 3, "row", 3.0, 3, "row", 3.0) + .build(), + Row.withSchema(resultSchema) + .addValues(3, "row", 3.0, 1, "row", 1.0, 1, "row", 1.0) + .build(), + Row.withSchema(resultSchema) + .addValues(3, "row", 3.0, 2, "row", 2.0, 2, "row", 2.0) + .build(), + Row.withSchema(resultSchema) + .addValues(3, "row", 3.0, 3, "row", 3.0, 3, "row", 3.0) + .build(), + Row.withSchema(resultSchema) + .addValues(2, "row", 2.0, 1, "row", 1.0, 1, "row", 1.0) + .build(), + Row.withSchema(resultSchema) + .addValues(2, "row", 2.0, 2, "row", 2.0, 2, "row", 2.0) + .build(), + Row.withSchema(resultSchema) + .addValues(2, "row", 2.0, 3, "row", 3.0, 3, "row", 3.0) + .build()); pipeline.run(); } - }
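For reference, the traversal at the heart of this series can be exercised outside a pipeline.
Below is a minimal standalone sketch of the same depth-first flattening; the class name
NestedRowFlattenSketch and method name flatten are illustrative, not part of the patch, and it
assumes only beam-sdks-java-core on the classpath.

import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.values.Row;

/** Illustrative only: depth-first flattening of nested ROW fields, as the patched UnnestFn does. */
public class NestedRowFlattenSketch {

  // Recursively collect base values, expanding any ROW field into its component values.
  // The recursion depth is bounded by the nesting depth of the schema.
  static List<Object> flatten(Row row) {
    List<Object> out = new ArrayList<>();
    for (int i = 0; i < row.getFieldCount(); i++) {
      Schema.FieldType type = row.getSchema().getField(i).getType();
      if (type.getTypeName().equals(Schema.TypeName.ROW)) {
        Row nested = row.getBaseValue(i, Row.class);
        if (nested != null) {
          out.addAll(flatten(nested));
        }
      } else {
        out.add(row.getBaseValue(i));
      }
    }
    return out;
  }

  public static void main(String[] args) {
    Schema inner = Schema.builder().addInt32Field("c1").addStringField("c2").build();
    Schema outer = Schema.builder().addInt32Field("b1").addRowField("b2", inner).build();
    Row innerRow = Row.withSchema(inner).addValues(7, "x").build();
    Row outerRow = Row.withSchema(outer).addValues(1, innerRow).build();
    // Prints [1, 7, x]: the nested ROW b2 is replaced in place by its component values.
    System.out.println(flatten(outerRow));
  }
}

Running the sketch prints [1, 7, x]: a nested ROW contributes its leaf values in place, which is
exactly why the unnested output row in the example must carry 10 base values rather than 8.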