diff --git a/LICENSE b/LICENSE
index 79de57d6670..1cb5bd54bdb 100644
--- a/LICENSE
+++ b/LICENSE
@@ -226,3 +226,212 @@ under the MIT license:
SOFTWARE.
https://github.com/pola-rs/polars/blob/main/LICENSE
+
+--------------------------------------------------------------------------------
+
+This project includes code from the Apache Spark project, which is licensed
+under the Apache License:
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+https://github.com/apache/spark/blob/master/LICENSE
\ No newline at end of file
diff --git a/java/spark/pom.xml b/java/spark/pom.xml
index 4c6f183f5e4..c34eb78b320 100644
--- a/java/spark/pom.xml
+++ b/java/spark/pom.xml
@@ -23,6 +23,36 @@
2.12
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+ 3.2.1
+
+
+ scala-compile-first
+ process-resources
+
+ compile
+
+
+
+ scala-test-compile
+ process-test-resources
+
+ testCompile
+
+
+
+
+
+ -feature
+
+
+
+
+
scala-2.13
@@ -88,11 +118,18 @@
org.apache.spark
spark-sql_${scala.compat.version}
${spark.version}
+ provided
org.junit.jupiter
junit-jupiter
test
+
+ org.scalatest
+ scalatest_${scala.compat.version}
+ 3.2.10
+ test
+
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java
new file mode 100644
index 00000000000..ad634ec92a4
--- /dev/null
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java
@@ -0,0 +1,18 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark;
+
+public class LanceConstant {
+ public static final String ROW_ID = "_rowid";
+}
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java
index 71adfab123f..bd10a527672 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java
@@ -16,22 +16,41 @@
import com.lancedb.lance.spark.write.SparkWrite;
import com.google.common.collect.ImmutableSet;
+import org.apache.spark.sql.connector.catalog.MetadataColumn;
+import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns;
import org.apache.spark.sql.connector.catalog.SupportsRead;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
import org.apache.spark.sql.connector.write.WriteBuilder;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import java.util.Set;
/** Lance Spark Dataset. */
-public class LanceDataset implements SupportsRead, SupportsWrite {
+public class LanceDataset implements SupportsRead, SupportsWrite, SupportsMetadataColumns {
private static final Set CAPABILITIES =
ImmutableSet.of(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE);
+ public static final MetadataColumn[] METADATA_COLUMNS =
+ new MetadataColumn[] {
+ new MetadataColumn() {
+ @Override
+ public String name() {
+ return LanceConstant.ROW_ID;
+ }
+
+ @Override
+ public DataType dataType() {
+ return DataTypes.LongType;
+ }
+ }
+ };
+
LanceConfig options;
private final StructType sparkSchema;
@@ -70,4 +89,9 @@ public Set capabilities() {
public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) {
return new SparkWrite.SparkWriteBuilder(sparkSchema, options);
}
+
+ @Override
+ public MetadataColumn[] metadataColumns() {
+ return METADATA_COLUMNS;
+ }
}
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java b/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java
index efe39c068f5..a9edf57108d 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java
@@ -34,6 +34,7 @@ public class SparkOptions {
private static final String max_row_per_file = "max_row_per_file";
private static final String max_rows_per_group = "max_rows_per_group";
private static final String max_bytes_per_file = "max_bytes_per_file";
+ private static final String batch_size = "batch_size";
public static ReadOptions genReadOptionFromConfig(LanceConfig config) {
ReadOptions.Builder builder = new ReadOptions.Builder();
@@ -85,4 +86,12 @@ private static Map genStorageOptions(LanceConfig config) {
}
return storageOptions;
}
+
+ public static int getBatchSize(LanceConfig config) {
+ Map options = config.getOptions();
+ if (options.containsKey(batch_size)) {
+ return Integer.parseInt(options.get(batch_size));
+ }
+ return 512;
+ }
}
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
index d3239107e4f..6225967f443 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
@@ -27,7 +27,7 @@
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.util.ArrowUtils;
+import org.apache.spark.sql.util.LanceArrowUtils;
import java.time.ZoneId;
import java.util.List;
@@ -40,7 +40,7 @@ public static Optional getSchema(LanceConfig config) {
String uri = config.getDatasetUri();
ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
try (Dataset dataset = Dataset.open(allocator, uri, options)) {
- return Optional.of(ArrowUtils.fromArrowSchema(dataset.getSchema()));
+ return Optional.of(LanceArrowUtils.fromArrowSchema(dataset.getSchema()));
} catch (IllegalArgumentException e) {
// dataset not found
return Optional.empty();
@@ -49,7 +49,7 @@ public static Optional getSchema(LanceConfig config) {
public static Optional getSchema(String datasetUri) {
try (Dataset dataset = Dataset.open(datasetUri, allocator)) {
- return Optional.of(ArrowUtils.fromArrowSchema(dataset.getSchema()));
+ return Optional.of(LanceArrowUtils.fromArrowSchema(dataset.getSchema()));
} catch (IllegalArgumentException e) {
// dataset not found
return Optional.empty();
@@ -89,7 +89,7 @@ public static void appendFragments(LanceConfig config, List fr
public static LanceArrowWriter getArrowWriter(StructType sparkSchema, int batchSize) {
return new LanceArrowWriter(
- allocator, ArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), batchSize);
+ allocator, LanceArrowUtils.toArrowSchema(sparkSchema, "UTC", false, false), batchSize);
}
public static List createFragment(
@@ -104,7 +104,7 @@ public static void createDataset(String datasetUri, StructType sparkSchema, Writ
Dataset.create(
allocator,
datasetUri,
- ArrowUtils.toArrowSchema(sparkSchema, ZoneId.systemDefault().getId(), true, false),
+ LanceArrowUtils.toArrowSchema(sparkSchema, ZoneId.systemDefault().getId(), true, false),
params)
.close();
}
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java
index 1cac598f7e0..d9406b0ac7e 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentColumnarBatchScanner.java
@@ -18,8 +18,8 @@
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
-import org.apache.spark.sql.vectorized.ArrowColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.apache.spark.sql.vectorized.LanceArrowColumnVector;
import java.io.IOException;
@@ -51,8 +51,8 @@ public boolean loadNextBatch() throws IOException {
currentColumnarBatch =
new ColumnarBatch(
root.getFieldVectors().stream()
- .map(ArrowColumnVector::new)
- .toArray(ArrowColumnVector[]::new),
+ .map(LanceArrowColumnVector::new)
+ .toArray(LanceArrowColumnVector[]::new),
root.getRowCount());
return true;
}
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java
index a1004acf260..e60d95994ce 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java
@@ -20,6 +20,7 @@
import com.lancedb.lance.ipc.LanceScanner;
import com.lancedb.lance.ipc.ScanOptions;
import com.lancedb.lance.spark.LanceConfig;
+import com.lancedb.lance.spark.LanceConstant;
import com.lancedb.lance.spark.SparkOptions;
import com.lancedb.lance.spark.read.LanceInputPartition;
@@ -59,6 +60,8 @@ public static LanceFragmentScanner create(
if (inputPartition.getWhereCondition().isPresent()) {
scanOptions.filter(inputPartition.getWhereCondition().get());
}
+ scanOptions.batchSize(SparkOptions.getBatchSize(config));
+ scanOptions.withRowId(getWithRowId(inputPartition.getSchema()));
scanner = fragment.newScan(scanOptions.build());
} catch (Throwable t) {
if (scanner != null) {
@@ -100,6 +103,15 @@ public void close() throws IOException {
}
private static List getColumnNames(StructType schema) {
- return Arrays.stream(schema.fields()).map(StructField::name).collect(Collectors.toList());
+ return Arrays.stream(schema.fields())
+ .map(StructField::name)
+ .filter(name -> !name.equals(LanceConstant.ROW_ID))
+ .collect(Collectors.toList());
+ }
+
+ private static boolean getWithRowId(StructType schema) {
+ return Arrays.stream(schema.fields())
+ .map(StructField::name)
+ .anyMatch(name -> name.equals(LanceConstant.ROW_ID));
}
}
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
index 1b7a78736dc..4e735996768 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
@@ -94,7 +94,8 @@ protected WriterFactory(StructType schema, LanceConfig config) {
@Override
public DataWriter createWriter(int partitionId, long taskId) {
- LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, 1024);
+ int batch_size = SparkOptions.getBatchSize(config);
+ LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, batch_size);
WriteParams params = SparkOptions.genWriteParamsFromConfig(config);
Callable> fragmentCreator =
() -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params);
diff --git a/java/spark/src/main/java/org/apache/spark/sql/vectorized/LanceArrowColumnVector.java b/java/spark/src/main/java/org/apache/spark/sql/vectorized/LanceArrowColumnVector.java
new file mode 100644
index 00000000000..9b43a7a3bd5
--- /dev/null
+++ b/java/spark/src/main/java/org/apache/spark/sql/vectorized/LanceArrowColumnVector.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.vectorized;
+
+import org.apache.arrow.vector.UInt8Vector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.util.LanceArrowUtils;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class LanceArrowColumnVector extends ColumnVector {
+ private UInt8Accessor uInt8Accessor;
+ private ArrowColumnVector arrowColumnVector;
+
+ public LanceArrowColumnVector(ValueVector vector) {
+ super(LanceArrowUtils.fromArrowField(vector.getField()));
+ if (vector instanceof UInt8Vector) {
+ uInt8Accessor = new UInt8Accessor((UInt8Vector) vector);
+ } else {
+ arrowColumnVector = new ArrowColumnVector(vector);
+ }
+ }
+
+ @Override
+ public void close() {
+ if (uInt8Accessor != null) {
+ uInt8Accessor.close();
+ }
+ if (arrowColumnVector != null) {
+ arrowColumnVector.close();
+ }
+ }
+
+ @Override
+ public boolean hasNull() {
+ if (uInt8Accessor != null) {
+ return uInt8Accessor.getNullCount() > 0;
+ }
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.hasNull();
+ }
+ return false;
+ }
+
+ @Override
+ public int numNulls() {
+ if (uInt8Accessor != null) {
+ return uInt8Accessor.getNullCount();
+ }
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.numNulls();
+ }
+ return 0;
+ }
+
+ @Override
+ public boolean isNullAt(int rowId) {
+ if (uInt8Accessor != null) {
+ return uInt8Accessor.isNullAt(rowId);
+ }
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.isNullAt(rowId);
+ }
+ return false;
+ }
+
+ @Override
+ public boolean getBoolean(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getBoolean(rowId);
+ }
+ return false;
+ }
+
+ @Override
+ public byte getByte(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getByte(rowId);
+ }
+ return 0;
+ }
+
+ @Override
+ public short getShort(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getShort(rowId);
+ }
+ return 0;
+ }
+
+ @Override
+ public int getInt(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getInt(rowId);
+ }
+ return 0;
+ }
+
+ @Override
+ public long getLong(int rowId) {
+ if (uInt8Accessor != null) {
+ return uInt8Accessor.getLong(rowId);
+ }
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getLong(rowId);
+ }
+ return 0L;
+ }
+
+ @Override
+ public float getFloat(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getFloat(rowId);
+ }
+ return 0;
+ }
+
+ @Override
+ public double getDouble(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getDouble(rowId);
+ }
+ return 0;
+ }
+
+ @Override
+ public ColumnarArray getArray(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getArray(rowId);
+ }
+ return null;
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getMap(ordinal);
+ }
+ return null;
+ }
+
+ @Override
+ public Decimal getDecimal(int rowId, int precision, int scale) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getDecimal(rowId, precision, scale);
+ }
+ return null;
+ }
+
+ @Override
+ public UTF8String getUTF8String(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getUTF8String(rowId);
+ }
+ return null;
+ }
+
+ @Override
+ public byte[] getBinary(int rowId) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getBinary(rowId);
+ }
+ return new byte[0];
+ }
+
+ @Override
+ public ColumnVector getChild(int ordinal) {
+ if (arrowColumnVector != null) {
+ return arrowColumnVector.getChild(ordinal);
+ }
+ return null;
+ }
+}
diff --git a/java/spark/src/main/java/org/apache/spark/sql/vectorized/UInt8Accessor.java b/java/spark/src/main/java/org/apache/spark/sql/vectorized/UInt8Accessor.java
new file mode 100644
index 00000000000..bbefd355e77
--- /dev/null
+++ b/java/spark/src/main/java/org/apache/spark/sql/vectorized/UInt8Accessor.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.vectorized;
+
+import org.apache.arrow.vector.UInt8Vector;
+
+// UInt8Accessor can't extend the ArrowVectorAccessor since it's package private.
+public class UInt8Accessor {
+ private final UInt8Vector accessor;
+
+ UInt8Accessor(UInt8Vector vector) {
+ this.accessor = vector;
+ }
+
+ final long getLong(int rowId) {
+ return accessor.getObjectNoOverflow(rowId).longValueExact();
+ }
+
+ final boolean isNullAt(int rowId) {
+ return accessor.isNull(rowId);
+ }
+
+ final int getNullCount() {
+ return accessor.getNullCount();
+ }
+
+ final void close() {
+ accessor.close();
+ }
+}
diff --git a/java/spark/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala b/java/spark/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala
new file mode 100644
index 00000000000..d1e67f1fee6
--- /dev/null
+++ b/java/spark/src/main/scala/org/apache/spark/sql/util/LanceArrowUtils.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * The following code is originally from https://github.com/apache/spark/blob/master/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+ * and is licensed under the Apache license:
+ *
+ * License: Apache License 2.0, Copyright 2014 and onwards The Apache Software Foundation.
+ * https://github.com/apache/spark/blob/master/LICENSE
+ *
+ * It has been modified by the Lance developers to fit the needs of the Lance project.
+ */
+
+package org.apache.spark.sql.util
+
+import com.lancedb.lance.spark.LanceConstant
+import org.apache.arrow.vector.complex.MapVector
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.sql.types._
+
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConverters._
+
+object LanceArrowUtils {
+ def fromArrowField(field: Field): DataType = {
+ field.getType match {
+ case int: ArrowType.Int if !int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
+ case _ => ArrowUtils.fromArrowField(field)
+ }
+ }
+
+ def fromArrowSchema(schema: Schema): StructType = {
+ StructType(schema.getFields.asScala.map { field =>
+ val dt = fromArrowField(field)
+ StructField(field.getName, dt, field.isNullable)
+ }.toArray)
+ }
+
+ def toArrowSchema(
+ schema: StructType,
+ timeZoneId: String,
+ errorOnDuplicatedFieldNames: Boolean,
+ largeVarTypes: Boolean = false): Schema = {
+ new Schema(schema.map { field =>
+ toArrowField(
+ field.name,
+ deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames),
+ field.nullable,
+ timeZoneId,
+ largeVarTypes)
+ }.asJava)
+ }
+
+ def toArrowField(
+ name: String,
+ dt: DataType,
+ nullable: Boolean,
+ timeZoneId: String,
+ largeVarTypes: Boolean = false): Field = {
+ dt match {
+ case ArrayType(elementType, containsNull) =>
+ val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
+ new Field(name, fieldType,
+ Seq(toArrowField("element", elementType, containsNull, timeZoneId,
+ largeVarTypes)).asJava)
+ case StructType(fields) =>
+ val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+ new Field(name, fieldType,
+ fields.map { field =>
+ toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes)
+ }.toSeq.asJava)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
+ // Note: Map Type struct can not be null, Struct Type key field can not be null
+ new Field(name, mapType,
+ Seq(toArrowField(MapVector.DATA_VECTOR_NAME,
+ new StructType()
+ .add(MapVector.KEY_NAME, keyType, nullable = false)
+ .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
+ nullable = false,
+ timeZoneId,
+ largeVarTypes)).asJava)
+ case udt: UserDefinedType[_] =>
+ toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
+ case dataType =>
+ val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId,
+ largeVarTypes, name), null)
+ new Field(name, fieldType, Seq.empty[Field].asJava)
+ }
+ }
+
+ private def toArrowType(
+ dt: DataType,
+ timeZoneId: String,
+ largeVarTypes: Boolean = false,
+ name: String): ArrowType = dt match {
+ case LongType if name.equals(LanceConstant.ROW_ID) => new ArrowType.Int(8 * 8, false)
+ case _ => ArrowUtils.toArrowType(dt, timeZoneId, largeVarTypes)
+ }
+
+ private def deduplicateFieldNames(
+ dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
+ case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames)
+ case st @ StructType(fields) =>
+ val newNames = if (st.names.toSet.size == st.names.length) {
+ st.names
+ } else {
+ if (errorOnDuplicatedFieldNames) {
+ throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names)
+ }
+ val genNewName = st.names.groupBy(identity).map {
+ case (name, names) if names.length > 1 =>
+ val i = new AtomicInteger()
+ name -> { () => s"${name}_${i.getAndIncrement()}" }
+ case (name, _) => name -> { () => name }
+ }
+ st.names.map(genNewName(_)())
+ }
+ val newFields =
+ fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) =>
+ StructField(
+ name, deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), nullable, metadata)
+ }
+ StructType(newFields)
+ case ArrayType(elementType, containsNull) =>
+ ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ MapType(
+ deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames),
+ deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames),
+ valueContainsNull)
+ case _ => dt
+ }
+}
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java b/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java
index 0dfde5f471c..e9f3581ef17 100644
--- a/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/TestUtils.java
@@ -37,6 +37,12 @@ public static class TestTable1Config {
Arrays.asList(1L, 2L, 3L, -1L),
Arrays.asList(2L, 4L, 6L, -2L),
Arrays.asList(3L, 6L, 9L, -3L));
+ public static final List<List<Long>> expectedValuesWithRowId =
+ Arrays.asList(
+ Arrays.asList(0L, 0L, 0L, 0L, 0L),
+ Arrays.asList(1L, 2L, 3L, -1L, 1L),
+ Arrays.asList(2L, 4L, 6L, -2L, (1L << 32) + 0L),
+ Arrays.asList(3L, 6L, 9L, -3L, (1L << 32) + 1L));
public static final LanceConfig lanceConfig;
public static final StructType schema =
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowId.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowId.java
new file mode 100644
index 00000000000..9cf02bb6220
--- /dev/null
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadWithRowId.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.lancedb.lance.spark.read;
+
+import com.lancedb.lance.spark.LanceConfig;
+import com.lancedb.lance.spark.LanceDataSource;
+import com.lancedb.lance.spark.TestUtils;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class SparkConnectorReadWithRowId {
+ private static SparkSession spark;
+ private static String dbPath;
+ private static Dataset<Row> data;
+
+ @BeforeAll
+ static void setup() {
+ spark =
+ SparkSession.builder()
+ .appName("spark-lance-connector-test")
+ .master("local")
+ .config("spark.sql.catalog.lance", "com.lancedb.lance.spark.LanceCatalog")
+ .getOrCreate();
+ dbPath = TestUtils.TestTable1Config.dbPath;
+ data =
+ spark
+ .read()
+ .format(LanceDataSource.name)
+ .option(
+ LanceConfig.CONFIG_DATASET_URI,
+ LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName))
+ .load();
+ }
+
+ @AfterAll
+ static void tearDown() {
+ if (spark != null) {
+ spark.stop();
+ }
+ }
+
+ private void validateData(Dataset<Row> data, List<List<Long>> expectedValues) {
+ List<Row> rows = data.collectAsList();
+ assertEquals(expectedValues.size(), rows.size());
+
+ for (int i = 0; i < rows.size(); i++) {
+ Row row = rows.get(i);
+ List<Long> expectedRow = expectedValues.get(i);
+ assertEquals(expectedRow.size(), row.size());
+
+ for (int j = 0; j < expectedRow.size(); j++) {
+ long expectedValue = expectedRow.get(j);
+ long actualValue = row.getLong(j);
+ assertEquals(expectedValue, actualValue, "Mismatch at row " + i + " column " + j);
+ }
+ }
+ }
+
+ @Test
+ public void readAllWithoutRowId() {
+ validateData(data, TestUtils.TestTable1Config.expectedValues);
+ }
+
+ @Test
+ public void readAllWithRowId() {
+ validateData(
+ data.select("x", "y", "b", "c", "_rowid"),
+ TestUtils.TestTable1Config.expectedValuesWithRowId);
+ }
+
+ @Test
+ public void select() {
+ validateData(
+ data.select("y", "b", "_rowid"),
+ TestUtils.TestTable1Config.expectedValuesWithRowId.stream()
+ .map(row -> Arrays.asList(row.get(1), row.get(2), row.get(4)))
+ .collect(Collectors.toList()));
+ }
+
+ @Test
+ public void filterSelect() {
+ validateData(
+ data.select("y", "b", "_rowid").filter("y > 3"),
+ TestUtils.TestTable1Config.expectedValuesWithRowId.stream()
+ .map(
+ row ->
+ Arrays.asList(
+ row.get(1),
+ row.get(2),
+ row.get(4))) // "y" is at index 1, "b" is at index 2, "_rowid" is at index 4
+ .filter(row -> row.get(0) > 3)
+ .collect(Collectors.toList()));
+ }
+
+ @Test
+ public void filterSelectByRowId() {
+ validateData(
+ data.select("y", "b", "_rowid").filter("_rowid > 3"),
+ TestUtils.TestTable1Config.expectedValuesWithRowId.stream()
+ .map(
+ row ->
+ Arrays.asList(
+ row.get(1),
+ row.get(2),
+ row.get(4))) // "y" is at index 1, "b" is at index 2, "_rowid" is at index 4
+ .filter(row -> row.get(2) > 3)
+ .collect(Collectors.toList()));
+ }
+}
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java
index 1e51609f5ef..229fd7ba778 100644
--- a/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/BatchAppendTest.java
@@ -33,7 +33,7 @@
import org.apache.spark.sql.connector.write.DataWriterFactory;
import org.apache.spark.sql.connector.write.WriterCommitMessage;
import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.util.ArrowUtils;
+import org.apache.spark.sql.util.LanceArrowUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.io.TempDir;
@@ -58,7 +58,7 @@ public void testLanceDataWriter(TestInfo testInfo) throws Exception {
// Append data to lance dataset
LanceConfig config = LanceConfig.from(datasetUri);
- StructType sparkSchema = ArrowUtils.fromArrowSchema(schema);
+ StructType sparkSchema = LanceArrowUtils.fromArrowSchema(schema);
BatchAppend batchAppend = new BatchAppend(sparkSchema, config);
DataWriterFactory factor = batchAppend.createBatchWriterFactory(() -> 1);
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java
index bb5293b4e87..d94cdb13269 100644
--- a/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/write/LanceDataWriterTest.java
@@ -26,7 +26,7 @@
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.util.ArrowUtils;
+import org.apache.spark.sql.util.LanceArrowUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.io.TempDir;
@@ -49,7 +49,7 @@ public void testLanceDataWriter(TestInfo testInfo) throws IOException {
Schema schema = new Schema(Collections.singletonList(field));
LanceConfig config =
LanceConfig.from(tempDir.resolve(datasetName + LanceConfig.LANCE_FILE_SUFFIX).toString());
- StructType sparkSchema = ArrowUtils.fromArrowSchema(schema);
+ StructType sparkSchema = LanceArrowUtils.fromArrowSchema(schema);
LanceDataWriter.WriterFactory writerFactory =
new LanceDataWriter.WriterFactory(sparkSchema, config);
LanceDataWriter dataWriter = (LanceDataWriter) writerFactory.createWriter(0, 0);
diff --git a/java/spark/src/test/scala/org/apache/spark/sql/util/LanceArrowUtilsSuite.scala b/java/spark/src/test/scala/org/apache/spark/sql/util/LanceArrowUtilsSuite.scala
new file mode 100644
index 00000000000..0636f7664a8
--- /dev/null
+++ b/java/spark/src/test/scala/org/apache/spark/sql/util/LanceArrowUtilsSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * The following code is originally from https://github.com/apache/spark/blob/master/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
+ * and is licensed under the Apache license:
+ *
+ * License: Apache License 2.0, Copyright 2014 and onwards The Apache Software Foundation.
+ * https://github.com/apache/spark/blob/master/LICENSE
+ *
+ * It has been modified by the Lance developers to fit the needs of the Lance project.
+ */
+
+package org.apache.spark.sql.util
+
+import com.lancedb.lance.spark.LanceConstant
+import org.apache.arrow.vector.types.pojo.ArrowType
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.sql.types._
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.time.ZoneId
+
+class LanceArrowUtilsSuite extends AnyFunSuite {
+ def roundtrip(dt: DataType, fieldName: String = "value"): Unit = {
+ dt match {
+ case schema: StructType =>
+ assert(LanceArrowUtils.fromArrowSchema(LanceArrowUtils.toArrowSchema(schema, null, true)) === schema)
+ case _ =>
+ roundtrip(new StructType().add(fieldName, dt))
+ }
+ }
+
+ test("unsigned long") {
+ roundtrip(BooleanType, LanceConstant.ROW_ID)
+ val arrowType = LanceArrowUtils.toArrowField(LanceConstant.ROW_ID, LongType, true, "Beijing")
+ assert(arrowType.getType.asInstanceOf[ArrowType.Int].getBitWidth === 64)
+ assert(!arrowType.getType.asInstanceOf[ArrowType.Int].getIsSigned)
+ }
+
+ test("simple") {
+ roundtrip(BooleanType)
+ roundtrip(ByteType)
+ roundtrip(ShortType)
+ roundtrip(IntegerType)
+ roundtrip(LongType)
+ roundtrip(FloatType)
+ roundtrip(DoubleType)
+ roundtrip(StringType)
+ roundtrip(BinaryType)
+ roundtrip(DecimalType.SYSTEM_DEFAULT)
+ roundtrip(DateType)
+ roundtrip(YearMonthIntervalType())
+ roundtrip(DayTimeIntervalType())
+ }
+
+ test("timestamp") {
+
+ def roundtripWithTz(timeZoneId: String): Unit = {
+ val schema = new StructType().add("value", TimestampType)
+ val arrowSchema = LanceArrowUtils.toArrowSchema(schema, timeZoneId, true)
+ val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp]
+ assert(fieldType.getTimezone() === timeZoneId)
+ assert(LanceArrowUtils.fromArrowSchema(arrowSchema) === schema)
+ }
+
+ roundtripWithTz(ZoneId.systemDefault().getId)
+ roundtripWithTz("Asia/Tokyo")
+ roundtripWithTz("UTC")
+ }
+
+ test("array") {
+ roundtrip(ArrayType(IntegerType, containsNull = true))
+ roundtrip(ArrayType(IntegerType, containsNull = false))
+ roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = true))
+ roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true))
+ roundtrip(ArrayType(ArrayType(IntegerType, containsNull = true), containsNull = false))
+ roundtrip(ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = false))
+ }
+
+ test("struct") {
+ roundtrip(new StructType())
+ roundtrip(new StructType().add("i", IntegerType))
+ roundtrip(new StructType().add("arr", ArrayType(IntegerType)))
+ roundtrip(new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType)))
+ roundtrip(new StructType().add(
+ "struct",
+ new StructType().add("i", IntegerType).add("arr", ArrayType(IntegerType))))
+ }
+
+ test("struct with duplicated field names") {
+
+ def check(dt: DataType, expected: DataType): Unit = {
+ val schema = new StructType().add("value", dt)
+ intercept[SparkUnsupportedOperationException] {
+ LanceArrowUtils.toArrowSchema(schema, null, true)
+ }
+ assert(LanceArrowUtils.fromArrowSchema(LanceArrowUtils.toArrowSchema(schema, null, false))
+ === new StructType().add("value", expected))
+ }
+
+ roundtrip(new StructType().add("i", IntegerType).add("i", StringType))
+
+ check(new StructType().add("i", IntegerType).add("i", StringType),
+ new StructType().add("i_0", IntegerType).add("i_1", StringType))
+ check(ArrayType(new StructType().add("i", IntegerType).add("i", StringType)),
+ ArrayType(new StructType().add("i_0", IntegerType).add("i_1", StringType)))
+ check(MapType(StringType, new StructType().add("i", IntegerType).add("i", StringType)),
+ MapType(StringType, new StructType().add("i_0", IntegerType).add("i_1", StringType)))
+ }
+
+}
diff --git a/java/spark/src/test/scala/org/apache/spark/sql/vectorized/LanceArrowColumnVectorSuite.scala b/java/spark/src/test/scala/org/apache/spark/sql/vectorized/LanceArrowColumnVectorSuite.scala
new file mode 100644
index 00000000000..18bf378136f
--- /dev/null
+++ b/java/spark/src/test/scala/org/apache/spark/sql/vectorized/LanceArrowColumnVectorSuite.scala
@@ -0,0 +1,519 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * The following code is originally from https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/vectorized/ArrowColumnVectorSuite.scala
+ * and is licensed under the Apache license:
+ *
+ * License: Apache License 2.0, Copyright 2014 and onwards The Apache Software Foundation.
+ * https://github.com/apache/spark/blob/master/LICENSE
+ *
+ * It has been modified by the Lance developers to fit the needs of the Lance project.
+ */
+
+package org.apache.spark.sql.vectorized
+
+import com.lancedb.lance.spark.LanceConstant
+import org.apache.spark.sql.util.{ArrowUtils, LanceArrowUtils}
+import org.apache.spark.sql.types._
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.scalatest.funsuite.AnyFunSuite
+import org.apache.spark.unsafe.types.UTF8String
+
+class LanceArrowColumnVectorSuite extends AnyFunSuite {
+ test("boolean") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("boolean", BooleanType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[BitVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ vector.setSafe(i, if (i % 2 == 0) 1 else 0)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === BooleanType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getBoolean(i) === (i % 2 == 0))
+ }
+ assert(columnVector.isNullAt(10))
+
+ assert(columnVector.getBooleans(0, 10) === (0 until 10).map(i => (i % 2 == 0)))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+
+ test("byte") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("byte", ByteType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[TinyIntVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ vector.setSafe(i, i.toByte)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === ByteType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getByte(i) === i.toByte)
+ }
+ assert(columnVector.isNullAt(10))
+
+ assert(columnVector.getBytes(0, 10) === (0 until 10).map(i => i.toByte))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("short") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("short", ShortType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[SmallIntVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ vector.setSafe(i, i.toShort)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === ShortType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getShort(i) === i.toShort)
+ }
+ assert(columnVector.isNullAt(10))
+
+ assert(columnVector.getShorts(0, 10) === (0 until 10).map(i => i.toShort))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("int") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("int", IntegerType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[IntVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ vector.setSafe(i, i)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === IntegerType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getInt(i) === i)
+ }
+ assert(columnVector.isNullAt(10))
+
+ assert(columnVector.getInts(0, 10) === (0 until 10))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("long") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("long", LongType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[BigIntVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ vector.setSafe(i, i.toLong)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === LongType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getLong(i) === i.toLong)
+ }
+ assert(columnVector.isNullAt(10))
+
+ assert(columnVector.getLongs(0, 10) === (0 until 10).map(i => i.toLong))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("unsigned long") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("unsigned long", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField(LanceConstant.ROW_ID, LongType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[UInt8Vector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ vector.setSafe(i, i.toLong)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === LongType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getLong(i) === i.toLong)
+ }
+ assert(columnVector.isNullAt(10))
+
+ assert(columnVector.getLongs(0, 10) === (0 until 10).map(i => i.toLong))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("float") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("float", FloatType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[Float4Vector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ vector.setSafe(i, i.toFloat)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === FloatType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getFloat(i) === i.toFloat)
+ }
+ assert(columnVector.isNullAt(10))
+
+ assert(columnVector.getFloats(0, 10) === (0 until 10).map(i => i.toFloat))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("double") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("double", DoubleType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[Float8Vector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ vector.setSafe(i, i.toDouble)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === DoubleType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getDouble(i) === i.toDouble)
+ }
+ assert(columnVector.isNullAt(10))
+
+ assert(columnVector.getDoubles(0, 10) === (0 until 10).map(i => i.toDouble))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("string") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("string", StringType, nullable = true, null)
+ .createVector(allocator).asInstanceOf[VarCharVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ vector.setSafe(i, utf8, 0, utf8.length)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === StringType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i"))
+ }
+ assert(columnVector.isNullAt(10))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("large_string") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("string", StringType, nullable = true, null, true)
+ .createVector(allocator).asInstanceOf[LargeVarCharVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ vector.setSafe(i, utf8, 0, utf8.length)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === StringType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getUTF8String(i) === UTF8String.fromString(s"str$i"))
+ }
+ assert(columnVector.isNullAt(10))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("binary") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("binary", BinaryType, nullable = true, null, false)
+ .createVector(allocator).asInstanceOf[VarBinaryVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ vector.setSafe(i, utf8, 0, utf8.length)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === BinaryType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8"))
+ }
+ assert(columnVector.isNullAt(10))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("large_binary") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("binary", BinaryType, nullable = true, null, true)
+ .createVector(allocator).asInstanceOf[LargeVarBinaryVector]
+ vector.allocateNew()
+
+ (0 until 10).foreach { i =>
+ val utf8 = s"str$i".getBytes("utf8")
+ vector.setSafe(i, utf8, 0, utf8.length)
+ }
+ vector.setNull(10)
+ vector.setValueCount(11)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === BinaryType)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ (0 until 10).foreach { i =>
+ assert(columnVector.getBinary(i) === s"str$i".getBytes("utf8"))
+ }
+ assert(columnVector.isNullAt(10))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("array") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue)
+ val vector = LanceArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null)
+ .createVector(allocator).asInstanceOf[ListVector]
+ vector.allocateNew()
+ val elementVector = vector.getDataVector().asInstanceOf[IntVector]
+
+ // [1, 2]
+ vector.startNewValue(0)
+ elementVector.setSafe(0, 1)
+ elementVector.setSafe(1, 2)
+ vector.endValue(0, 2)
+
+ // [3, null, 5]
+ vector.startNewValue(1)
+ elementVector.setSafe(2, 3)
+ elementVector.setNull(3)
+ elementVector.setSafe(4, 5)
+ vector.endValue(1, 3)
+
+ // null
+
+ // []
+ vector.startNewValue(3)
+ vector.endValue(3, 0)
+
+ elementVector.setValueCount(5)
+ vector.setValueCount(4)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === ArrayType(IntegerType))
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ val array0 = columnVector.getArray(0)
+ assert(array0.numElements() === 2)
+ assert(array0.getInt(0) === 1)
+ assert(array0.getInt(1) === 2)
+
+ val array1 = columnVector.getArray(1)
+ assert(array1.numElements() === 3)
+ assert(array1.getInt(0) === 3)
+ assert(array1.isNullAt(1))
+ assert(array1.getInt(2) === 5)
+
+ assert(columnVector.isNullAt(2))
+
+ val array3 = columnVector.getArray(3)
+ assert(array3.numElements() === 0)
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("non nullable struct") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue)
+ val schema = new StructType().add("int", IntegerType).add("long", LongType)
+ val vector = LanceArrowUtils.toArrowField("struct", schema, nullable = false, null)
+ .createVector(allocator).asInstanceOf[StructVector]
+
+ vector.allocateNew()
+ val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector]
+ val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector]
+
+ vector.setIndexDefined(0)
+ intVector.setSafe(0, 1)
+ longVector.setSafe(0, 1L)
+
+ vector.setIndexDefined(1)
+ intVector.setSafe(1, 2)
+ longVector.setNull(1)
+
+ vector.setValueCount(2)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === schema)
+ assert(!columnVector.hasNull)
+ assert(columnVector.numNulls === 0)
+
+ val row0 = columnVector.getStruct(0)
+ assert(row0.getInt(0) === 1)
+ assert(row0.getLong(1) === 1L)
+
+ val row1 = columnVector.getStruct(1)
+ assert(row1.getInt(0) === 2)
+ assert(row1.isNullAt(1))
+
+ columnVector.close()
+ allocator.close()
+ }
+
+ test("struct") {
+ val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue)
+ val schema = new StructType().add("int", IntegerType).add("long", LongType)
+ val vector = LanceArrowUtils.toArrowField("struct", schema, nullable = true, null)
+ .createVector(allocator).asInstanceOf[StructVector]
+ vector.allocateNew()
+ val intVector = vector.getChildByOrdinal(0).asInstanceOf[IntVector]
+ val longVector = vector.getChildByOrdinal(1).asInstanceOf[BigIntVector]
+
+ // (1, 1L)
+ vector.setIndexDefined(0)
+ intVector.setSafe(0, 1)
+ longVector.setSafe(0, 1L)
+
+ // (2, null)
+ vector.setIndexDefined(1)
+ intVector.setSafe(1, 2)
+ longVector.setNull(1)
+
+ // (null, 3L)
+ vector.setIndexDefined(2)
+ intVector.setNull(2)
+ longVector.setSafe(2, 3L)
+
+ // null
+ vector.setNull(3)
+
+ // (5, 5L)
+ vector.setIndexDefined(4)
+ intVector.setSafe(4, 5)
+ longVector.setSafe(4, 5L)
+
+ intVector.setValueCount(5)
+ longVector.setValueCount(5)
+ vector.setValueCount(5)
+
+ val columnVector = new LanceArrowColumnVector(vector)
+ assert(columnVector.dataType === schema)
+ assert(columnVector.hasNull)
+ assert(columnVector.numNulls === 1)
+
+ val row0 = columnVector.getStruct(0)
+ assert(row0.getInt(0) === 1)
+ assert(row0.getLong(1) === 1L)
+
+ val row1 = columnVector.getStruct(1)
+ assert(row1.getInt(0) === 2)
+ assert(row1.isNullAt(1))
+
+ val row2 = columnVector.getStruct(2)
+ assert(row2.isNullAt(0))
+ assert(row2.getLong(1) === 3L)
+
+ assert(columnVector.isNullAt(3))
+
+ val row4 = columnVector.getStruct(4)
+ assert(row4.getInt(0) === 5)
+ assert(row4.getLong(1) === 5L)
+
+ columnVector.close()
+ allocator.close()
+ }
+}