diff --git a/python/python/tests/test_json.py b/python/python/tests/test_json.py index d6a911c6aec..0a9e328b256 100644 --- a/python/python/tests/test_json.py +++ b/python/python/tests/test_json.py @@ -213,6 +213,16 @@ def test_json_path_queries(): result = dataset.to_table( filter="json_extract(data, '$.user.name') = '\"Alice\"'" ) + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE " + "json_extract(data, '$.user.name') = '\"Alice\"'" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 1 assert result["id"][0].as_py() == 1 @@ -255,19 +265,53 @@ def test_json_get_functions(): # Test json_get_string result = dataset.to_table(filter="json_get_string(data, 'name') = 'Alice'") + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE json_get_string(data, 'name') = 'Alice'" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 1 assert result["id"][0].as_py() == 1 # Test json_get_int with type coercion result = dataset.to_table(filter="json_get_int(data, 'age') > 28") + sql = ( + dataset.sql("SELECT * FROM dataset WHERE json_get_int(data, 'age') > 28") + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 2 # Alice (30) and Charlie ("35" -> 35) # Test json_get_bool with type coercion result = dataset.to_table(filter="json_get_bool(data, 'active') = true") + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE json_get_bool(data, 'active') = true" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 2 # Alice (true) and Charlie ("true" -> true) # Test json_get_float result = dataset.to_table(filter="json_get_float(data, 'score') > 90") + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE json_get_float(data, 'score') > 90" + ) + 
.build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 2 # Alice (95.5) and Charlie ("92" -> 92.0) @@ -304,6 +348,18 @@ def test_nested_json_access(): 'name') = 'Alice'""" ) + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE " + "json_get_string(" + "json_get(json_get(data, 'user'), 'profile'), " + "'name') = 'Alice'" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 1 assert result["id"][0].as_py() == 1 @@ -311,6 +367,16 @@ def test_nested_json_access(): result = dataset.to_table( filter="json_extract(data, '$.user.profile.settings.theme') = '\"dark\"'" ) + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE " + "json_extract(data, '$.user.profile.settings.theme') = '\"dark\"'" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 1 assert result["id"][0].as_py() == 1 @@ -342,16 +408,44 @@ def test_json_array_operations(): result = dataset.to_table( filter="json_array_contains(data, '$.items', 'apple')" ) + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE " + "json_array_contains(data, '$.items', 'apple')" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 1 assert result["id"][0].as_py() == 1 # Test array length result = dataset.to_table(filter="json_array_length(data, '$.counts') > 3") + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE json_array_length(data, '$.counts') > 3" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 1 assert result["id"][0].as_py() == 1 # Test empty array result = dataset.to_table(filter="json_array_length(data, '$.items') = 0") + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE 
json_array_length(data, '$.items') = 0" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 1 assert result["id"][0].as_py() == 3 @@ -400,7 +494,17 @@ def test_json_filter_append_missing_json_cast(tmp_path: Path): result = dataset.to_table( filter="json_get(article_metadata, 'article_journal') IS NOT NULL" ) + sql = ( + dataset.sql( + "SELECT * FROM dataset WHERE " + "json_get(article_metadata, 'article_journal') IS NOT NULL" + ) + .build() + .to_batch_records() + ) + sql_result = pa.Table.from_batches(sql) + assert result == sql_result assert result.num_rows == 3 assert result.column("article_journal").to_pylist() == [ "Cell", diff --git a/rust/lance/src/dataset/sql.rs b/rust/lance/src/dataset/sql.rs index bf05a080aa3..2e5ec9f42bb 100644 --- a/rust/lance/src/dataset/sql.rs +++ b/rust/lance/src/dataset/sql.rs @@ -2,12 +2,14 @@ // SPDX-FileCopyrightText: Copyright The Lance Authors use crate::datafusion::LanceTableProvider; +use crate::dataset::utils::SchemaAdapter; use crate::Dataset; use arrow_array::RecordBatch; use datafusion::dataframe::DataFrame; use datafusion::execution::SendableRecordBatchStream; use datafusion::prelude::SessionContext; use futures::TryStreamExt; +use lance_datafusion::udf::register_functions; use std::sync::Arc; /// A SQL builder to prepare options for running SQL queries against a Lance dataset. 
@@ -75,6 +77,7 @@ impl SqlQueryBuilder { row_addr, )), )?; + register_functions(&ctx); let df = ctx.sql(&self.sql).await?; Ok(SqlQuery::new(df)) } @@ -90,7 +93,18 @@ impl SqlQuery { } pub async fn into_stream(self) -> lance_core::Result<SendableRecordBatchStream> { - self.dataframe.execute_stream().await.map_err(|e| e.into()) + let exec_node = self + .dataframe + .execute_stream() + .await + .map_err(lance_core::Error::from)?; + let schema = exec_node.schema(); + if SchemaAdapter::requires_logical_conversion(&schema) { + let adapter = SchemaAdapter::new(schema); + Ok(adapter.to_logical_stream(exec_node)) + } else { + Ok(exec_node) + } } pub async fn into_batch_records(self) -> lance_core::Result<Vec<RecordBatch>> { @@ -109,11 +123,18 @@ impl SqlQuery { #[cfg(test)] mod tests { use crate::utils::test::{assert_string_matches, DatagenExt, FragmentCount, FragmentRowCount}; + use std::collections::HashMap; + use std::sync::Arc; + use crate::Dataset; use all_asserts::assert_true; use arrow_array::cast::AsArray; use arrow_array::types::{Int32Type, Int64Type, UInt64Type}; - + use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray}; + use arrow_schema::Schema as ArrowSchema; + use arrow_schema::{DataType, Field}; + use lance_arrow::json::ARROW_JSON_EXT_NAME; + use lance_arrow::ARROW_EXT_NAME_KEY; use lance_datagen::{array, gen_batch}; #[tokio::test] @@ -280,4 +301,68 @@ mod tests { ]], row_count: 1 }"#; assert_string_matches(&plan, expected_pattern).unwrap(); } + + #[tokio::test] + async fn test_nested_json_access() { + let json_rows = vec![ + Some(r#"{"user": {"profile": {"name": "Alice", "settings": {"theme": "dark"}}}}"#), + Some(r#"{"user": {"profile": {"name": "Bob", "settings": {"theme": "light"}}}}"#), + ]; + let json_array = StringArray::from(json_rows); + let id_array = Int32Array::from(vec![1, 2]); + + let mut metadata = HashMap::new(); + metadata.insert( + ARROW_EXT_NAME_KEY.to_string(), + ARROW_JSON_EXT_NAME.to_string(), + ); + + let schema = Arc::new(ArrowSchema::new(vec![ +
Field::new("id", DataType::Int32, false), + Field::new("data", DataType::Utf8, true).with_metadata(metadata), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(id_array), Arc::new(json_array)], + ) + .unwrap(); + + let reader = RecordBatchIterator::new(vec![Ok(batch.clone())], schema.clone()); + let ds = Dataset::write(reader, "memory://test_nested_json_access", None) + .await + .unwrap(); + + let results = ds + .sql( + "SELECT id FROM dataset WHERE \ + json_get_string(json_get(json_get(data, 'user'), 'profile'), 'name') = 'Alice'", + ) + .build() + .await + .unwrap() + .into_batch_records() + .await + .unwrap(); + let batch = results.into_iter().next().unwrap(); + pretty_assertions::assert_eq!(batch.num_rows(), 1); + pretty_assertions::assert_eq!(batch.num_columns(), 1); + pretty_assertions::assert_eq!(batch.column(0).as_primitive::<Int32Type>().value(0), 1); + + let results = ds + .sql( + "SELECT id FROM dataset WHERE \ + json_extract(data, '$.user.profile.settings.theme') = '\"dark\"'", + ) + .build() + .await + .unwrap() + .into_batch_records() + .await + .unwrap(); + let batch = results.into_iter().next().unwrap(); + pretty_assertions::assert_eq!(batch.num_rows(), 1); + pretty_assertions::assert_eq!(batch.num_columns(), 1); + pretty_assertions::assert_eq!(batch.column(0).as_primitive::<Int32Type>().value(0), 1); + } }