From d0f6390f84f0ea03e2c4e9e15a8f1a7d2622c741 Mon Sep 17 00:00:00 2001 From: yanghua Date: Wed, 25 Jun 2025 20:35:50 +0800 Subject: [PATCH 01/14] feat: support sql api for dataset --- rust/lance/src/dataset.rs | 61 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index f02ba0baf4c..f899d557e98 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -7,6 +7,7 @@ use arrow_array::{RecordBatch, RecordBatchReader}; use byteorder::{ByteOrder, LittleEndian}; use chrono::{prelude::*, Duration}; +use datafusion::prelude::SessionContext; use deepsize::DeepSizeOf; use futures::future::BoxFuture; use futures::stream::{self, BoxStream, StreamExt, TryStreamExt}; @@ -74,6 +75,7 @@ use self::refs::Tags; use self::scanner::{DatasetRecordBatchStream, Scanner}; use self::transaction::{Operation, Transaction}; use self::write::write_fragments_internal; +use crate::datafusion::LanceTableProvider; use crate::datatypes::Schema; use crate::error::box_error; use crate::io::commit::{ @@ -1449,6 +1451,28 @@ impl Dataset { *self = self.checkout_version(latest_version).await?; Ok(()) } + + /// Run a SQL query against the dataset. + pub async fn sql( + &self, + sql: &str, + table: &str, + with_row_id: bool, + with_row_addr: bool, + ) -> Result> { + let ctx = SessionContext::new(); + ctx.register_table( + table, + Arc::new(LanceTableProvider::new( + Arc::new(self.clone()), + with_row_id, + with_row_addr, + )), + )?; + let df = ctx.sql(sql).await?; + let result = df.collect().await?; + Ok(result) + } } pub(crate) struct NewTransactionResult<'a> { @@ -1893,7 +1917,9 @@ mod tests { use crate::dataset::transaction::DataReplacementGroup; use crate::dataset::WriteMode::Overwrite; use crate::index::vector::VectorIndexParams; - use crate::utils::test::copy_test_data_to_tmp; + use crate::utils::test::{ + copy_test_data_to_tmp, DatagenExt, FragmentCount, FragmentRowCount, TestDatasetGenerator, + }; use arrow::array::{as_struct_array, AsArray, GenericListBuilder, GenericStringBuilder}; use arrow::compute::concat_batches; @@ -1930,6 +1956,7 @@ mod tests { use lance_table::format::{DataFile, WriterVersion}; use all_asserts::assert_true; + use arrow_array::types::Int64Type; use lance_testing::datagen::generate_random_array; use pretty_assertions::assert_eq; use rand::seq::SliceRandom; @@ -6395,6 +6422,7 @@ mod tests { ); } +<<<<<<< HEAD #[rstest] #[tokio::test] async fn test_fragment_id_zero_not_reused() { @@ -6479,7 +6507,7 @@ mod tests { schema.clone(), vec![Arc::new(UInt32Array::from_iter_values(0..30))], ) - .unwrap(); + .unwrap(); let batches = RecordBatchIterator::new(vec![Ok(data)], schema.clone()); let write_params = WriteParams { max_rows_per_file: 10, // Force multiple fragments @@ -6509,7 +6537,7 @@ mod tests { schema.clone(), vec![Arc::new(UInt32Array::from_iter_values(100..120))], ) - .unwrap(); + .unwrap(); let batches = RecordBatchIterator::new(vec![Ok(data)], schema.clone()); let write_params = WriteParams { mode: WriteMode::Append, @@ -6526,4 +6554,31 @@ mod tests { assert_eq!(dataset.get_fragments()[1].id(), 4); assert_eq!(dataset.manifest.max_fragment_id(), Some(4)); } + + #[tokio::test] + async fn test_sql() { + let test_dir = tempdir().unwrap(); + let test_uri = test_dir.path().to_str().unwrap(); + let ds = gen() + .col("x", array::step::()) + .col("y", array::step_custom::(0, 2)) + .into_dataset( + test_uri, + FragmentCount::from(10), + FragmentRowCount::from(10), + ) + .await + .unwrap(); + + let results = ds + .sql("SELECT SUM(x) FROM foo WHERE y > 100", "foo", true, true) + .await + .unwrap(); + assert_eq!(results.len(), 1); + let results = results.into_iter().next().unwrap(); + assert_eq!(results.num_columns(), 1); + assert_eq!(results.num_rows(), 1); + // SUM(0..100) - SUM(0..50) = 3675 + assert_eq!(results.column(0).as_primitive::().value(0), 3675); + } } From bb40714b42d6513357caf584ecabdfd27327bbba Mon Sep 17 00:00:00 2001 From: yanghua Date: Wed, 25 Jun 2025 21:53:31 +0800 Subject: [PATCH 02/14] feat: support sql api for dataset --- rust/lance/src/dataset.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index f899d557e98..14dce1df9b2 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -6580,5 +6580,16 @@ mod tests { assert_eq!(results.num_rows(), 1); // SUM(0..100) - SUM(0..50) = 3675 assert_eq!(results.column(0).as_primitive::().value(0), 3675); + + // verify _rowid and _rowaddr + let results = ds + .sql("SELECT x, y, _rowid, _rowaddr FROM foo where y > 100", "foo", true, true) + .await + .unwrap(); + assert_eq!(results.len(), 1); + let results = results.into_iter().next().unwrap(); + assert_eq!(results.num_columns(), 4); + assert_true!(results.column(2).as_primitive::().value(0) > 100); + assert_true!(results.column(3).as_primitive::().value(0) > 100); } } From d0a35bf33ab90ff50d7a1aeaef8169c1f0084fd4 Mon Sep 17 00:00:00 2001 From: yanghua Date: Wed, 25 Jun 2025 21:54:34 +0800 Subject: [PATCH 03/14] feat: support sql api for dataset --- rust/lance/src/dataset.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 14dce1df9b2..b30ccdce7c6 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -6583,7 +6583,12 @@ mod tests { // verify _rowid and _rowaddr let results = ds - .sql("SELECT x, y, _rowid, _rowaddr FROM foo where y > 100", "foo", true, true) + .sql( + "SELECT x, y, _rowid, _rowaddr FROM foo where y > 100", + "foo", + true, + true, + ) .await .unwrap(); assert_eq!(results.len(), 1); From 03d9bd3fd6f41b33c77865845ed0ae67b0e3bc49 Mon Sep 17 00:00:00 2001 From: yanghua Date: Fri, 27 Jun 2025 10:52:38 +0800 Subject: [PATCH 04/14] feat: support sql api for dataset --- rust/lance/src/dataset.rs | 107 ++++++++++++++++++++++++++++---------- 1 file changed, 79 insertions(+), 28 deletions(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index b30ccdce7c6..ed5fb4e4706 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -325,6 +325,69 @@ impl From for ProjectionRequest { } } +/// Customize the params of dataset's sql API. +#[derive(Clone, Debug)] +pub struct SqlOptions { + /// the dataset to run the SQL query + dataset: Option, + + /// the SQL query to run + sql: String, + + /// the name of the table to register in the datafusion context + table_name: String, + + /// if true, the query result will include the internal row id + row_id: bool, + + /// if true, the query result will include the internal row address + row_addr: bool, +} + +impl SqlOptions { + pub fn table_name(mut self, table_name: &str) -> Self { + self.table_name = table_name.to_string(); + self + } + + pub fn with_row_id(mut self, row_id: bool) -> Self { + self.row_id = row_id; + self + } + + pub fn with_row_addr(mut self, row_addr: bool) -> Self { + self.row_addr = row_addr; + self + } + + pub async fn execute(self) -> Result> { + let ctx = SessionContext::new(); + ctx.register_table( + self.table_name, + Arc::new(LanceTableProvider::new( + Arc::new(self.dataset.unwrap()), + self.row_id, + self.row_addr, + )), + )?; + let df = ctx.sql(&self.sql).await?; + let result = df.collect().await?; + Ok(result) + } +} + +impl Default for SqlOptions { + fn default() -> Self { + Self { + dataset: None, + sql: "".to_string(), + table_name: "".to_string(), + row_id: false, + row_addr: false, + } + } +} + impl Dataset { /// Open an existing dataset. /// @@ -1453,25 +1516,13 @@ impl Dataset { } /// Run a SQL query against the dataset. - pub async fn sql( - &self, - sql: &str, - table: &str, - with_row_id: bool, - with_row_addr: bool, - ) -> Result> { - let ctx = SessionContext::new(); - ctx.register_table( - table, - Arc::new(LanceTableProvider::new( - Arc::new(self.clone()), - with_row_id, - with_row_addr, - )), - )?; - let df = ctx.sql(sql).await?; - let result = df.collect().await?; - Ok(result) + /// The underlying SQL engine is DataFusion. + pub fn sql(&mut self, sql: &str) -> SqlOptions { + SqlOptions { + dataset: Some(self.clone()), + sql: sql.to_string(), + ..Default::default() + } } } @@ -6559,7 +6610,7 @@ mod tests { async fn test_sql() { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); - let ds = gen() + let mut ds = gen() .col("x", array::step::()) .col("y", array::step_custom::(0, 2)) .into_dataset( @@ -6571,7 +6622,9 @@ mod tests { .unwrap(); let results = ds - .sql("SELECT SUM(x) FROM foo WHERE y > 100", "foo", true, true) + .sql("SELECT SUM(x) FROM foo WHERE y > 100") + .table_name("foo") + .execute() .await .unwrap(); assert_eq!(results.len(), 1); @@ -6581,14 +6634,12 @@ mod tests { // SUM(0..100) - SUM(0..50) = 3675 assert_eq!(results.column(0).as_primitive::().value(0), 3675); - // verify _rowid and _rowaddr let results = ds - .sql( - "SELECT x, y, _rowid, _rowaddr FROM foo where y > 100", - "foo", - true, - true, - ) + .sql("SELECT x, y, _rowid, _rowaddr FROM foo where y > 100") + .table_name("foo") + .with_row_id(true) + .with_row_addr(true) + .execute() .await .unwrap(); assert_eq!(results.len(), 1); From 13a9b904a8d4db0b267084a053238557673c3c6f Mon Sep 17 00:00:00 2001 From: yanghua Date: Thu, 10 Jul 2025 23:43:46 +0800 Subject: [PATCH 05/14] feat(rust): support sql api --- rust/lance-core/src/error.rs | 5 +++ rust/lance/src/dataset.rs | 54 ++++-------------------- rust/lance/src/dataset/sql.rs | 78 +++++++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 46 deletions(-) create mode 100644 rust/lance/src/dataset/sql.rs diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index 6eed53b010e..25448e01f64 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -111,6 +111,11 @@ pub enum Error { minor_version: u16, location: Location, }, + #[snafu(display("DataFusion inner error: {source}, {location}"))] + DataFusionInnerError { + source: BoxedError, + location: Location, + }, } impl Error { diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index ed5fb4e4706..af64dee8a74 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -7,7 +7,6 @@ use arrow_array::{RecordBatch, RecordBatchReader}; use byteorder::{ByteOrder, LittleEndian}; use chrono::{prelude::*, Duration}; -use datafusion::prelude::SessionContext; use deepsize::DeepSizeOf; use futures::future::BoxFuture; use futures::stream::{self, BoxStream, StreamExt, TryStreamExt}; @@ -61,6 +60,7 @@ pub mod refs; pub(crate) mod rowids; pub mod scanner; mod schema_evolution; +mod sql; pub mod statistics; mod take; pub mod transaction; @@ -75,7 +75,6 @@ use self::refs::Tags; use self::scanner::{DatasetRecordBatchStream, Scanner}; use self::transaction::{Operation, Transaction}; use self::write::write_fragments_internal; -use crate::datafusion::LanceTableProvider; use crate::datatypes::Schema; use crate::error::box_error; use crate::io::commit::{ @@ -344,50 +343,6 @@ pub struct SqlOptions { row_addr: bool, } -impl SqlOptions { - pub fn table_name(mut self, table_name: &str) -> Self { - self.table_name = table_name.to_string(); - self - } - - pub fn with_row_id(mut self, row_id: bool) -> Self { - self.row_id = row_id; - self - } - - pub fn with_row_addr(mut self, row_addr: bool) -> Self { - self.row_addr = row_addr; - self - } - - pub async fn execute(self) -> Result> { - let ctx = SessionContext::new(); - ctx.register_table( - self.table_name, - Arc::new(LanceTableProvider::new( - Arc::new(self.dataset.unwrap()), - self.row_id, - self.row_addr, - )), - )?; - let df = ctx.sql(&self.sql).await?; - let result = df.collect().await?; - Ok(result) - } -} - -impl Default for SqlOptions { - fn default() -> Self { - Self { - dataset: None, - sql: "".to_string(), - table_name: "".to_string(), - row_id: false, - row_addr: false, - } - } -} - impl Dataset { /// Open an existing dataset. /// @@ -1517,6 +1472,7 @@ impl Dataset { /// Run a SQL query against the dataset. /// The underlying SQL engine is DataFusion. + /// Please refer to the DataFusion documentation for supported SQL syntax. pub fn sql(&mut self, sql: &str) -> SqlOptions { SqlOptions { dataset: Some(self.clone()), @@ -6626,6 +6582,9 @@ mod tests { .table_name("foo") .execute() .await + .unwrap() + .collect() + .await .unwrap(); assert_eq!(results.len(), 1); let results = results.into_iter().next().unwrap(); @@ -6641,6 +6600,9 @@ mod tests { .with_row_addr(true) .execute() .await + .unwrap() + .collect() + .await .unwrap(); assert_eq!(results.len(), 1); let results = results.into_iter().next().unwrap(); diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs new file mode 100644 index 00000000000..89a65ec3bb3 --- /dev/null +++ b/rust/lance/src/dataset/sql.rs @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use crate::datafusion::LanceTableProvider; +use crate::dataset::SqlOptions; +use arrow_array::RecordBatch; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::prelude::SessionContext; +use lance_core::Error; +use snafu::location; +use std::sync::Arc; + +impl SqlOptions { + pub fn table_name(mut self, table_name: &str) -> Self { + self.table_name = table_name.to_string(); + self + } + + pub fn with_row_id(mut self, row_id: bool) -> Self { + self.row_id = row_id; + self + } + + pub fn with_row_addr(mut self, row_addr: bool) -> Self { + self.row_addr = row_addr; + self + } + + pub async fn execute(self) -> lance_core::Result { + let ctx = SessionContext::new(); + ctx.register_table( + self.table_name, + Arc::new(LanceTableProvider::new( + Arc::new(self.dataset.unwrap()), + self.row_id, + self.row_addr, + )), + )?; + let df = ctx.sql(&self.sql).await?; + let result_stream = df.execute_stream().await.unwrap(); + Ok(QueryResult { + stream: result_stream, + }) + } +} + +impl Default for SqlOptions { + fn default() -> Self { + Self { + dataset: None, + sql: "".to_string(), + table_name: "".to_string(), + row_id: false, + row_addr: false, + } + } +} + +pub struct QueryResult { + stream: SendableRecordBatchStream, +} + +impl QueryResult { + pub fn into_stream(self) -> SendableRecordBatchStream { + self.stream + } + + pub async fn collect(self) -> lance_core::Result> { + use futures::TryStreamExt; + self.stream + .try_collect::>() + .await + .map_err(|e| Error::DataFusionInnerError { + source: e.into(), + location: location!(), + }) + } +} From b611a9095146bb413dab4de503b32cf6d96efadd Mon Sep 17 00:00:00 2001 From: yanghua Date: Thu, 10 Jul 2025 23:46:34 +0800 Subject: [PATCH 06/14] feat(rust): support sql api --- rust/lance/src/dataset.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index af64dee8a74..415c57bb3f5 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -6429,7 +6429,6 @@ mod tests { ); } -<<<<<<< HEAD #[rstest] #[tokio::test] async fn test_fragment_id_zero_not_reused() { @@ -6514,7 +6513,7 @@ mod tests { schema.clone(), vec![Arc::new(UInt32Array::from_iter_values(0..30))], ) - .unwrap(); + .unwrap(); let batches = RecordBatchIterator::new(vec![Ok(data)], schema.clone()); let write_params = WriteParams { max_rows_per_file: 10, // Force multiple fragments @@ -6544,7 +6543,7 @@ mod tests { schema.clone(), vec![Arc::new(UInt32Array::from_iter_values(100..120))], ) - .unwrap(); + .unwrap(); let batches = RecordBatchIterator::new(vec![Ok(data)], schema.clone()); let write_params = WriteParams { mode: WriteMode::Append, From 9dcba816ade1195c1149d460cb9be330ef6e328e Mon Sep 17 00:00:00 2001 From: yanghua Date: Thu, 10 Jul 2025 23:49:09 +0800 Subject: [PATCH 07/14] feat(rust): support sql api --- rust/lance/src/dataset.rs | 19 ------------------- rust/lance/src/dataset/sql.rs | 20 ++++++++++++++++++++ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 415c57bb3f5..18781875164 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -324,25 +324,6 @@ impl From for ProjectionRequest { } } -/// Customize the params of dataset's sql API. -#[derive(Clone, Debug)] -pub struct SqlOptions { - /// the dataset to run the SQL query - dataset: Option, - - /// the SQL query to run - sql: String, - - /// the name of the table to register in the datafusion context - table_name: String, - - /// if true, the query result will include the internal row id - row_id: bool, - - /// if true, the query result will include the internal row address - row_addr: bool, -} - impl Dataset { /// Open an existing dataset. /// diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index 89a65ec3bb3..148878bb40e 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -3,6 +3,7 @@ use crate::datafusion::LanceTableProvider; use crate::dataset::SqlOptions; +use crate::Dataset; use arrow_array::RecordBatch; use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::SessionContext; @@ -10,6 +11,25 @@ use lance_core::Error; use snafu::location; use std::sync::Arc; +/// Customize the params of dataset's sql API. +#[derive(Clone, Debug)] +pub struct SqlOptions { + /// the dataset to run the SQL query + dataset: Option, + + /// the SQL query to run + sql: String, + + /// the name of the table to register in the datafusion context + table_name: String, + + /// if true, the query result will include the internal row id + row_id: bool, + + /// if true, the query result will include the internal row address + row_addr: bool, +} + impl SqlOptions { pub fn table_name(mut self, table_name: &str) -> Self { self.table_name = table_name.to_string(); From 6de037e413cb73a627c9d0675da5aeaa23021824 Mon Sep 17 00:00:00 2001 From: yanghua Date: Thu, 10 Jul 2025 23:59:23 +0800 Subject: [PATCH 08/14] feat(rust): support sql api --- rust/lance/src/dataset.rs | 1 + rust/lance/src/dataset/sql.rs | 11 +++++------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 18781875164..8415c440cc1 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -75,6 +75,7 @@ use self::refs::Tags; use self::scanner::{DatasetRecordBatchStream, Scanner}; use self::transaction::{Operation, Transaction}; use self::write::write_fragments_internal; +use crate::dataset::sql::SqlOptions; use crate::datatypes::Schema; use crate::error::box_error; use crate::io::commit::{ diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index 148878bb40e..f057b464440 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -2,7 +2,6 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use crate::datafusion::LanceTableProvider; -use crate::dataset::SqlOptions; use crate::Dataset; use arrow_array::RecordBatch; use datafusion::execution::SendableRecordBatchStream; @@ -15,19 +14,19 @@ use std::sync::Arc; #[derive(Clone, Debug)] pub struct SqlOptions { /// the dataset to run the SQL query - dataset: Option, + pub(crate) dataset: Option, /// the SQL query to run - sql: String, + pub(crate) sql: String, /// the name of the table to register in the datafusion context - table_name: String, + pub(crate) table_name: String, /// if true, the query result will include the internal row id - row_id: bool, + pub(crate) row_id: bool, /// if true, the query result will include the internal row address - row_addr: bool, + pub(crate) row_addr: bool, } impl SqlOptions { From b243c0269e34c2e4a60f40d9f6c4c13b387323b4 Mon Sep 17 00:00:00 2001 From: yanghua Date: Fri, 11 Jul 2025 17:53:17 +0800 Subject: [PATCH 09/14] feat(rust): support sql api --- rust/lance-core/src/error.rs | 5 -- rust/lance/src/dataset.rs | 60 ++---------------- rust/lance/src/dataset/sql.rs | 116 ++++++++++++++++++++++++++++------ 3 files changed, 99 insertions(+), 82 deletions(-) diff --git a/rust/lance-core/src/error.rs b/rust/lance-core/src/error.rs index 25448e01f64..6eed53b010e 100644 --- a/rust/lance-core/src/error.rs +++ b/rust/lance-core/src/error.rs @@ -111,11 +111,6 @@ pub enum Error { minor_version: u16, location: Location, }, - #[snafu(display("DataFusion inner error: {source}, {location}"))] - DataFusionInnerError { - source: BoxedError, - location: Location, - }, } impl Error { diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 8415c440cc1..304b077fd3e 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -75,7 +75,7 @@ use self::refs::Tags; use self::scanner::{DatasetRecordBatchStream, Scanner}; use self::transaction::{Operation, Transaction}; use self::write::write_fragments_internal; -use crate::dataset::sql::SqlOptions; +use crate::dataset::sql::SqlBuilder; use crate::datatypes::Schema; use crate::error::box_error; use crate::io::commit::{ @@ -1455,8 +1455,8 @@ impl Dataset { /// Run a SQL query against the dataset. /// The underlying SQL engine is DataFusion. /// Please refer to the DataFusion documentation for supported SQL syntax. - pub fn sql(&mut self, sql: &str) -> SqlOptions { - SqlOptions { + pub fn sql(&mut self, sql: &str) -> SqlBuilder { + SqlBuilder { dataset: Some(self.clone()), sql: sql.to_string(), ..Default::default() @@ -1906,9 +1906,7 @@ mod tests { use crate::dataset::transaction::DataReplacementGroup; use crate::dataset::WriteMode::Overwrite; use crate::index::vector::VectorIndexParams; - use crate::utils::test::{ - copy_test_data_to_tmp, DatagenExt, FragmentCount, FragmentRowCount, TestDatasetGenerator, - }; + use crate::utils::test::copy_test_data_to_tmp; use arrow::array::{as_struct_array, AsArray, GenericListBuilder, GenericStringBuilder}; use arrow::compute::concat_batches; @@ -1945,7 +1943,6 @@ mod tests { use lance_table::format::{DataFile, WriterVersion}; use all_asserts::assert_true; - use arrow_array::types::Int64Type; use lance_testing::datagen::generate_random_array; use pretty_assertions::assert_eq; use rand::seq::SliceRandom; @@ -6542,53 +6539,4 @@ mod tests { assert_eq!(dataset.get_fragments()[1].id(), 4); assert_eq!(dataset.manifest.max_fragment_id(), Some(4)); } - - #[tokio::test] - async fn test_sql() { - let test_dir = tempdir().unwrap(); - let test_uri = test_dir.path().to_str().unwrap(); - let mut ds = gen() - .col("x", array::step::()) - .col("y", array::step_custom::(0, 2)) - .into_dataset( - test_uri, - FragmentCount::from(10), - FragmentRowCount::from(10), - ) - .await - .unwrap(); - - let results = ds - .sql("SELECT SUM(x) FROM foo WHERE y > 100") - .table_name("foo") - .execute() - .await - .unwrap() - .collect() - .await - .unwrap(); - assert_eq!(results.len(), 1); - let results = results.into_iter().next().unwrap(); - assert_eq!(results.num_columns(), 1); - assert_eq!(results.num_rows(), 1); - // SUM(0..100) - SUM(0..50) = 3675 - assert_eq!(results.column(0).as_primitive::().value(0), 3675); - - let results = ds - .sql("SELECT x, y, _rowid, _rowaddr FROM foo where y > 100") - .table_name("foo") - .with_row_id(true) - .with_row_addr(true) - .execute() - .await - .unwrap() - .collect() - .await - .unwrap(); - assert_eq!(results.len(), 1); - let results = results.into_iter().next().unwrap(); - assert_eq!(results.num_columns(), 4); - assert_true!(results.column(2).as_primitive::().value(0) > 100); - assert_true!(results.column(3).as_primitive::().value(0) > 100); - } } diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index f057b464440..46bad84489a 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -4,15 +4,14 @@ use crate::datafusion::LanceTableProvider; use crate::Dataset; use arrow_array::RecordBatch; +use datafusion::dataframe::DataFrame; use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::SessionContext; -use lance_core::Error; -use snafu::location; use std::sync::Arc; /// Customize the params of dataset's sql API. #[derive(Clone, Debug)] -pub struct SqlOptions { +pub struct SqlBuilder { /// the dataset to run the SQL query pub(crate) dataset: Option, @@ -20,7 +19,7 @@ pub struct SqlOptions { pub(crate) sql: String, /// the name of the table to register in the datafusion context - pub(crate) table_name: String, + pub(crate) table_name: Option, /// if true, the query result will include the internal row id pub(crate) row_id: bool, @@ -29,17 +28,25 @@ pub struct SqlOptions { pub(crate) row_addr: bool, } -impl SqlOptions { +impl SqlBuilder { + /// The table name to register in the datafusion context. + /// This is used to specify a "table name" for the dataset. + /// So that you can run SQL queries against it. + /// If not set, the default table name is "dataset". pub fn table_name(mut self, table_name: &str) -> Self { - self.table_name = table_name.to_string(); + self.table_name = Some(table_name.to_string()); self } + /// Specify if the query result should include the internal row id. + /// If true, the query result will include an additional column named "_rowid". pub fn with_row_id(mut self, row_id: bool) -> Self { self.row_id = row_id; self } + /// Specify if the query result should include the internal row address. + /// If true, the query result will include an additional column named "_rowaddr". pub fn with_row_addr(mut self, row_addr: bool) -> Self { self.row_addr = row_addr; self @@ -48,7 +55,7 @@ impl SqlOptions { pub async fn execute(self) -> lance_core::Result { let ctx = SessionContext::new(); ctx.register_table( - self.table_name, + self.table_name.unwrap(), Arc::new(LanceTableProvider::new( Arc::new(self.dataset.unwrap()), self.row_id, @@ -56,19 +63,16 @@ impl SqlOptions { )), )?; let df = ctx.sql(&self.sql).await?; - let result_stream = df.execute_stream().await.unwrap(); - Ok(QueryResult { - stream: result_stream, - }) + Ok(QueryResult::new(df)) } } -impl Default for SqlOptions { +impl Default for SqlBuilder { fn default() -> Self { Self { dataset: None, sql: "".to_string(), - table_name: "".to_string(), + table_name: Some("dataset".to_string()), row_id: false, row_addr: false, } @@ -76,22 +80,92 @@ impl Default for SqlOptions { } pub struct QueryResult { - stream: SendableRecordBatchStream, + dataframe: DataFrame, } impl QueryResult { - pub fn into_stream(self) -> SendableRecordBatchStream { - self.stream + pub fn new(dataframe: DataFrame) -> Self { + Self { dataframe } + } + + pub async fn into_stream(self) -> SendableRecordBatchStream { + self.dataframe.execute_stream().await.unwrap() } pub async fn collect(self) -> lance_core::Result> { use futures::TryStreamExt; - self.stream + Ok(self + .dataframe + .execute_stream() + .await + .unwrap() .try_collect::>() + .await?) + } + + pub fn dataframe(&self) -> &DataFrame { + &self.dataframe + } + + pub async fn explain(&self, verbose: bool, analyze: bool) -> DataFrame { + self.dataframe.clone().explain(verbose, analyze).unwrap() + } +} + +#[cfg(test)] +mod tests { + use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; + use all_asserts::assert_true; + use arrow_array::cast::AsArray; + use arrow_array::types::{Int32Type, Int64Type, UInt64Type}; + use lance_datagen::{array, gen}; + + #[tokio::test] + async fn test_sql() { + let mut ds = gen() + .col("x", array::step::()) + .col("y", array::step_custom::(0, 2)) + .into_dataset( + "memory://test_sql_dataset", + FragmentCount::from(10), + FragmentRowCount::from(10), + ) + .await + .unwrap(); + + let results = ds + .sql("SELECT SUM(x) FROM foo WHERE y > 100") + .table_name("foo") + .execute() + .await + .unwrap() + .collect() + .await + .unwrap(); + pretty_assertions::assert_eq!(results.len(), 1); + let results = results.into_iter().next().unwrap(); + pretty_assertions::assert_eq!(results.num_columns(), 1); + pretty_assertions::assert_eq!(results.num_rows(), 1); + // SUM(0..100) - SUM(0..50) = 3675 + pretty_assertions::assert_eq!(results.column(0).as_primitive::().value(0), 3675); + + let results = ds + .sql("SELECT x, y, _rowid, _rowaddr FROM foo where y > 100") + .table_name("foo") + .with_row_id(true) + .with_row_addr(true) + .execute() + .await + .unwrap() + .collect() .await - .map_err(|e| Error::DataFusionInnerError { - source: e.into(), - location: location!(), - }) + .unwrap(); + let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); + let expect_rows = ds.count_rows(Some("y > 100".to_string())).await.unwrap(); + pretty_assertions::assert_eq!(total_rows, expect_rows); + let results = results.into_iter().next().unwrap(); + pretty_assertions::assert_eq!(results.num_columns(), 4); + assert_true!(results.column(2).as_primitive::().value(0) > 100); + assert_true!(results.column(3).as_primitive::().value(0) > 100); } } From 932eea8015f55791df1fc4729bb97220b5ae42a3 Mon Sep 17 00:00:00 2001 From: yanghua Date: Fri, 11 Jul 2025 18:56:59 +0800 Subject: [PATCH 10/14] feat(rust): support sql api --- rust/lance/src/dataset/sql.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index 46bad84489a..0bbd1b9cf49 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -9,22 +9,22 @@ use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::SessionContext; use std::sync::Arc; -/// Customize the params of dataset's sql API. +/// A SQL builder to prepare options for running SQL queries against a Lance dataset. #[derive(Clone, Debug)] pub struct SqlBuilder { - /// the dataset to run the SQL query + /// The dataset to run the SQL query pub(crate) dataset: Option, - /// the SQL query to run + /// The SQL query to run pub(crate) sql: String, /// the name of the table to register in the datafusion context pub(crate) table_name: Option, - /// if true, the query result will include the internal row id + /// If true, the query result will include the internal row id pub(crate) row_id: bool, - /// if true, the query result will include the internal row address + /// If true, the query result will include the internal row address pub(crate) row_addr: bool, } From 8efa748cc108b12997a2e0be838522c7d54e378f Mon Sep 17 00:00:00 2001 From: yanghua Date: Tue, 15 Jul 2025 16:57:53 +0800 Subject: [PATCH 11/14] refactor code --- rust/lance/src/dataset.rs | 6 +- rust/lance/src/dataset/sql.rs | 114 +++++++++++++++++++++++----------- 2 files changed, 80 insertions(+), 40 deletions(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 304b077fd3e..0b58eb566d6 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -1456,11 +1456,7 @@ impl Dataset { /// The underlying SQL engine is DataFusion. /// Please refer to the DataFusion documentation for supported SQL syntax. pub fn sql(&mut self, sql: &str) -> SqlBuilder { - SqlBuilder { - dataset: Some(self.clone()), - sql: sql.to_string(), - ..Default::default() - } + SqlBuilder::new(self.clone(), sql) } } diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index 0bbd1b9cf49..d35fa7b1c0d 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -3,7 +3,7 @@ use crate::datafusion::LanceTableProvider; use crate::Dataset; -use arrow_array::RecordBatch; +use arrow_array::{Array, RecordBatch, StringArray}; use datafusion::dataframe::DataFrame; use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::SessionContext; @@ -13,7 +13,7 @@ use std::sync::Arc; #[derive(Clone, Debug)] pub struct SqlBuilder { /// The dataset to run the SQL query - pub(crate) dataset: Option, + pub(crate) dataset: Dataset, /// The SQL query to run pub(crate) sql: String, @@ -22,13 +22,23 @@ pub struct SqlBuilder { pub(crate) table_name: Option, /// If true, the query result will include the internal row id - pub(crate) row_id: bool, + pub(crate) row_id: Option, /// If true, the query result will include the internal row address - pub(crate) row_addr: bool, + pub(crate) row_addr: Option, } impl SqlBuilder { + pub fn new(dataset: Dataset, sql: &str) -> Self { + Self { + dataset, + sql: sql.to_string(), + table_name: Some("dataset".to_string()), + row_id: None, + row_addr: None, + } + } + /// The table name to register in the datafusion context. /// This is used to specify a "table name" for the dataset. /// So that you can run SQL queries against it. @@ -41,49 +51,39 @@ impl SqlBuilder { /// Specify if the query result should include the internal row id. /// If true, the query result will include an additional column named "_rowid". pub fn with_row_id(mut self, row_id: bool) -> Self { - self.row_id = row_id; + self.row_id = Some(row_id); self } /// Specify if the query result should include the internal row address. /// If true, the query result will include an additional column named "_rowaddr". pub fn with_row_addr(mut self, row_addr: bool) -> Self { - self.row_addr = row_addr; + self.row_addr = Some(row_addr); self } - pub async fn execute(self) -> lance_core::Result { + pub async fn build(self) -> lance_core::Result { let ctx = SessionContext::new(); + let row_id = self.row_id.unwrap_or(false); + let row_addr = self.row_addr.unwrap_or(false); ctx.register_table( self.table_name.unwrap(), Arc::new(LanceTableProvider::new( - Arc::new(self.dataset.unwrap()), - self.row_id, - self.row_addr, + Arc::new(self.dataset.clone()), + row_id, + row_addr, )), )?; let df = ctx.sql(&self.sql).await?; - Ok(QueryResult::new(df)) + Ok(SqlQueryBuilder::new(df)) } } -impl Default for SqlBuilder { - fn default() -> Self { - Self { - dataset: None, - sql: "".to_string(), - table_name: Some("dataset".to_string()), - row_id: false, - row_addr: false, - } - } -} - -pub struct QueryResult { +pub struct SqlQueryBuilder { dataframe: DataFrame, } -impl QueryResult { +impl SqlQueryBuilder { pub fn new(dataframe: DataFrame) -> Self { Self { dataframe } } @@ -92,7 +92,7 @@ impl QueryResult { self.dataframe.execute_stream().await.unwrap() } - pub async fn collect(self) -> lance_core::Result> { + pub async fn into_batch_records(self) -> lance_core::Result> { use futures::TryStreamExt; Ok(self .dataframe @@ -103,12 +103,30 @@ impl QueryResult { .await?) } - pub fn dataframe(&self) -> &DataFrame { - &self.dataframe + pub fn into_dataframe(self) -> DataFrame { + self.dataframe } - pub async fn explain(&self, verbose: bool, analyze: bool) -> DataFrame { - self.dataframe.clone().explain(verbose, analyze).unwrap() + pub async fn into_explain_plan( + self, + verbose: bool, + analyze: bool, + ) -> lance_core::Result { + let explained_df = self.dataframe.explain(verbose, analyze)?; + let batches = explained_df.collect().await?; + let mut lines = Vec::new(); + for batch in &batches { + let column = batch.column(0); + let array = column + .as_any() + .downcast_ref::() + .expect("Expected StringArray in 'plan' column for DataFrame.explain"); + for i in 0..array.len() { + lines.push(array.value(i).to_string()); + } + } + + Ok(lines.join("\n")) } } @@ -121,7 +139,7 @@ mod tests { use lance_datagen::{array, gen}; #[tokio::test] - async fn test_sql() { + async fn test_sql_execute() { let mut ds = gen() .col("x", array::step::()) .col("y", array::step_custom::(0, 2)) @@ -136,10 +154,10 @@ mod tests { let results = ds .sql("SELECT SUM(x) FROM foo WHERE y > 100") .table_name("foo") - .execute() + .build() .await .unwrap() - .collect() + .into_batch_records() .await .unwrap(); pretty_assertions::assert_eq!(results.len(), 1); @@ -154,10 +172,10 @@ mod tests { .table_name("foo") .with_row_id(true) .with_row_addr(true) - .execute() + .build() .await .unwrap() - .collect() + .into_batch_records() .await .unwrap(); let total_rows: usize = results.iter().map(|batch| batch.num_rows()).sum(); @@ -168,4 +186,30 @@ mod tests { assert_true!(results.column(2).as_primitive::().value(0) > 100); assert_true!(results.column(3).as_primitive::().value(0) > 100); } + + #[tokio::test] + async fn test_sql_explain_plan() { + let mut ds = gen() + .col("x", array::step::()) + .col("y", array::step_custom::(0, 2)) + .into_dataset( + "memory://test_sql_explain_plan", + FragmentCount::from(2), + FragmentRowCount::from(5), + ) + .await + .unwrap(); + + let builder = ds + .sql("SELECT SUM(x) FROM foo WHERE y > 2") + .table_name("foo") + .build() + .await + .unwrap(); + + let plan = builder.into_explain_plan(true, false).await.unwrap(); + + // 检查 explain plan 输出包含关键字 + assert!(plan.contains("Aggregate") || plan.contains("SUM")); + } } From b4a485b53f72af7b1b6e6a9ddb0dec2a793a11d4 Mon Sep 17 00:00:00 2001 From: yanghua Date: Tue, 15 Jul 2025 17:00:29 +0800 Subject: [PATCH 12/14] refactor code --- rust/lance/src/dataset/sql.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index d35fa7b1c0d..f6ce0b5c06a 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -209,7 +209,6 @@ mod tests { let plan = builder.into_explain_plan(true, false).await.unwrap(); - // 检查 explain plan 输出包含关键字 assert!(plan.contains("Aggregate") || plan.contains("SUM")); } } From 75c84dad612eb67e3c3c2c16d2b16c44f6d2632f Mon Sep 17 00:00:00 2001 From: yanghua Date: Wed, 16 Jul 2025 11:43:40 +0800 Subject: [PATCH 13/14] refactor code --- rust/lance/src/dataset.rs | 6 +++--- rust/lance/src/dataset/sql.rs | 36 +++++++++++++++++------------------ 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 0b58eb566d6..0d3eab904a6 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -75,7 +75,7 @@ use self::refs::Tags; use self::scanner::{DatasetRecordBatchStream, Scanner}; use self::transaction::{Operation, Transaction}; use self::write::write_fragments_internal; -use crate::dataset::sql::SqlBuilder; +use crate::dataset::sql::SqlQueryBuilder; use crate::datatypes::Schema; use crate::error::box_error; use crate::io::commit::{ @@ -1455,8 +1455,8 @@ impl Dataset { /// Run a SQL query against the dataset. /// The underlying SQL engine is DataFusion. /// Please refer to the DataFusion documentation for supported SQL syntax. - pub fn sql(&mut self, sql: &str) -> SqlBuilder { - SqlBuilder::new(self.clone(), sql) + pub fn sql(&mut self, sql: &str) -> SqlQueryBuilder { + SqlQueryBuilder::new(self.clone(), sql) } } diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index f6ce0b5c06a..9ca0e52f660 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -11,7 +11,7 @@ use std::sync::Arc; /// A SQL builder to prepare options for running SQL queries against a Lance dataset. #[derive(Clone, Debug)] -pub struct SqlBuilder { +pub struct SqlQueryBuilder { /// The dataset to run the SQL query pub(crate) dataset: Dataset, @@ -19,23 +19,23 @@ pub struct SqlBuilder { pub(crate) sql: String, /// the name of the table to register in the datafusion context - pub(crate) table_name: Option, + pub(crate) table_name: String, /// If true, the query result will include the internal row id - pub(crate) row_id: Option, + pub(crate) with_row_id: bool, /// If true, the query result will include the internal row address - pub(crate) row_addr: Option, + pub(crate) with_row_addr: bool, } -impl SqlBuilder { +impl SqlQueryBuilder { pub fn new(dataset: Dataset, sql: &str) -> Self { Self { dataset, sql: sql.to_string(), - table_name: Some("dataset".to_string()), - row_id: None, - row_addr: None, + table_name: "dataset".to_string(), + with_row_id: false, + with_row_addr: false, } } @@ -44,30 +44,30 @@ impl SqlBuilder { /// So that you can run SQL queries against it. /// If not set, the default table name is "dataset". pub fn table_name(mut self, table_name: &str) -> Self { - self.table_name = Some(table_name.to_string()); + self.table_name = table_name.to_string(); self } /// Specify if the query result should include the internal row id. /// If true, the query result will include an additional column named "_rowid". pub fn with_row_id(mut self, row_id: bool) -> Self { - self.row_id = Some(row_id); + self.with_row_id = row_id; self } /// Specify if the query result should include the internal row address. /// If true, the query result will include an additional column named "_rowaddr". pub fn with_row_addr(mut self, row_addr: bool) -> Self { - self.row_addr = Some(row_addr); + self.with_row_addr = row_addr; self } - pub async fn build(self) -> lance_core::Result { + pub async fn build(self) -> lance_core::Result { let ctx = SessionContext::new(); - let row_id = self.row_id.unwrap_or(false); - let row_addr = self.row_addr.unwrap_or(false); + let row_id = self.with_row_id; + let row_addr = self.with_row_addr; ctx.register_table( - self.table_name.unwrap(), + self.table_name, Arc::new(LanceTableProvider::new( Arc::new(self.dataset.clone()), row_id, @@ -75,15 +75,15 @@ impl SqlBuilder { )), )?; let df = ctx.sql(&self.sql).await?; - Ok(SqlQueryBuilder::new(df)) + Ok(SqlQuery::new(df)) } } -pub struct SqlQueryBuilder { +pub struct SqlQuery { dataframe: DataFrame, } -impl SqlQueryBuilder { +impl SqlQuery { pub fn new(dataframe: DataFrame) -> Self { Self { dataframe } } From f940194233736a4d8c616001b0ca64f50e422839 Mon Sep 17 00:00:00 2001 From: yanghua Date: Wed, 16 Jul 2025 13:17:01 +0800 Subject: [PATCH 14/14] refactor code --- rust/lance/src/dataset.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 0d3eab904a6..81558a54c98 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -60,7 +60,7 @@ pub mod refs; pub(crate) mod rowids; pub mod scanner; mod schema_evolution; -mod sql; +pub mod sql; pub mod statistics; mod take; pub mod transaction;