Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions python/python/tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,16 @@ def test_json_path_queries():
result = dataset.to_table(
filter="json_extract(data, '$.user.name') = '\"Alice\"'"
)
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE "
"json_extract(data, '$.user.name') = '\"Alice\"'"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 1
assert result["id"][0].as_py() == 1

Expand Down Expand Up @@ -255,19 +265,53 @@ def test_json_get_functions():

# Test json_get_string
result = dataset.to_table(filter="json_get_string(data, 'name') = 'Alice'")
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE json_get_string(data, 'name') = 'Alice'"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 1
assert result["id"][0].as_py() == 1

# Test json_get_int with type coercion
result = dataset.to_table(filter="json_get_int(data, 'age') > 28")
sql = (
dataset.sql("SELECT * FROM dataset WHERE json_get_int(data, 'age') > 28")
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 2 # Alice (30) and Charlie ("35" -> 35)

# Test json_get_bool with type coercion
result = dataset.to_table(filter="json_get_bool(data, 'active') = true")
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE json_get_bool(data, 'active') = true"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 2 # Alice (true) and Charlie ("true" -> true)

# Test json_get_float
result = dataset.to_table(filter="json_get_float(data, 'score') > 90")
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE json_get_float(data, 'score') > 90"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 2 # Alice (95.5) and Charlie ("92" -> 92.0)


Expand Down Expand Up @@ -304,13 +348,35 @@ def test_nested_json_access():
'name')
= 'Alice'"""
)
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE "
"json_get_string("
"json_get(json_get(data, 'user'), 'profile'), "
"'name') = 'Alice'"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 1
assert result["id"][0].as_py() == 1

# Or use JSONPath for deep access
result = dataset.to_table(
filter="json_extract(data, '$.user.profile.settings.theme') = '\"dark\"'"
)
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE "
"json_extract(data, '$.user.profile.settings.theme') = '\"dark\"'"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 1
assert result["id"][0].as_py() == 1

Expand Down Expand Up @@ -342,16 +408,44 @@ def test_json_array_operations():
result = dataset.to_table(
filter="json_array_contains(data, '$.items', 'apple')"
)
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE "
"json_array_contains(data, '$.items', 'apple')"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 1
assert result["id"][0].as_py() == 1

# Test array length
result = dataset.to_table(filter="json_array_length(data, '$.counts') > 3")
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE json_array_length(data, '$.counts') > 3"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 1
assert result["id"][0].as_py() == 1

# Test empty array
result = dataset.to_table(filter="json_array_length(data, '$.items') = 0")
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE json_array_length(data, '$.items') = 0"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)
assert result == sql_result
assert result.num_rows == 1
assert result["id"][0].as_py() == 3

Expand Down Expand Up @@ -400,7 +494,17 @@ def test_json_filter_append_missing_json_cast(tmp_path: Path):
result = dataset.to_table(
filter="json_get(article_metadata, 'article_journal') IS NOT NULL"
)
sql = (
dataset.sql(
"SELECT * FROM dataset WHERE "
"json_get(article_metadata, 'article_journal') IS NOT NULL"
)
.build()
.to_batch_records()
)
sql_result = pa.Table.from_batches(sql)

assert result == sql_result
assert result.num_rows == 3
assert result.column("article_journal").to_pylist() == [
"Cell",
Expand Down
89 changes: 87 additions & 2 deletions rust/lance/src/dataset/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
// SPDX-FileCopyrightText: Copyright The Lance Authors

use crate::datafusion::LanceTableProvider;
use crate::dataset::utils::SchemaAdapter;
use crate::Dataset;
use arrow_array::RecordBatch;
use datafusion::dataframe::DataFrame;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;
use lance_datafusion::udf::register_functions;
use std::sync::Arc;

/// A SQL builder to prepare options for running SQL queries against a Lance dataset.
Expand Down Expand Up @@ -75,6 +77,7 @@ impl SqlQueryBuilder {
row_addr,
)),
)?;
register_functions(&ctx);
let df = ctx.sql(&self.sql).await?;
Ok(SqlQuery::new(df))
}
Expand All @@ -90,7 +93,18 @@ impl SqlQuery {
}

/// Executes the query and returns its results as a stream of record batches.
///
/// The underlying DataFusion dataframe is executed exactly once. If
/// [`SchemaAdapter::requires_logical_conversion`] reports that the output
/// schema needs conversion, the stream is wrapped with
/// [`SchemaAdapter::to_logical_stream`]; otherwise the DataFusion stream is
/// returned unchanged.
///
/// # Errors
///
/// Returns a [`lance_core::Error`] converted from the DataFusion error when
/// `execute_stream` fails.
pub async fn into_stream(self) -> lance_core::Result<SendableRecordBatchStream> {
    // `?` performs the DataFusion -> Lance error conversion via `From`.
    let stream = self.dataframe.execute_stream().await?;
    let schema = stream.schema();
    // NOTE(review): presumably this covers extension columns (e.g. JSON) whose
    // physical storage differs from the logical type -- confirm against SchemaAdapter.
    if SchemaAdapter::requires_logical_conversion(&schema) {
        Ok(SchemaAdapter::new(schema).to_logical_stream(stream))
    } else {
        Ok(stream)
    }
}

pub async fn into_batch_records(self) -> lance_core::Result<Vec<RecordBatch>> {
Expand All @@ -109,11 +123,18 @@ impl SqlQuery {
#[cfg(test)]
mod tests {
use crate::utils::test::{assert_string_matches, DatagenExt, FragmentCount, FragmentRowCount};
use std::collections::HashMap;
use std::sync::Arc;

use crate::Dataset;
use all_asserts::assert_true;
use arrow_array::cast::AsArray;
use arrow_array::types::{Int32Type, Int64Type, UInt64Type};

use arrow_array::{Int32Array, RecordBatch, RecordBatchIterator, StringArray};
use arrow_schema::Schema as ArrowSchema;
use arrow_schema::{DataType, Field};
use lance_arrow::json::ARROW_JSON_EXT_NAME;
use lance_arrow::ARROW_EXT_NAME_KEY;
use lance_datagen::{array, gen_batch};

#[tokio::test]
Expand Down Expand Up @@ -280,4 +301,68 @@ mod tests {
]], row_count: 1 }"#;
assert_string_matches(&plan, expected_pattern).unwrap();
}

/// Verifies that nested JSON values can be filtered through the SQL API, both
/// with chained `json_get` calls and with a JSONPath passed to `json_extract`.
#[tokio::test]
async fn test_nested_json_access() {
    // Field metadata marking the `data` column with the Arrow JSON extension name.
    let json_metadata = HashMap::from([(
        ARROW_EXT_NAME_KEY.to_string(),
        ARROW_JSON_EXT_NAME.to_string(),
    )]);
    let schema = Arc::new(ArrowSchema::new(vec![
        Field::new("id", DataType::Int32, false),
        Field::new("data", DataType::Utf8, true).with_metadata(json_metadata),
    ]));

    // Two rows: only id = 1 has name "Alice" and theme "dark".
    let ids = Int32Array::from(vec![1, 2]);
    let documents = StringArray::from(vec![
        Some(r#"{"user": {"profile": {"name": "Alice", "settings": {"theme": "dark"}}}}"#),
        Some(r#"{"user": {"profile": {"name": "Bob", "settings": {"theme": "light"}}}}"#),
    ]);
    let input =
        RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(documents)]).unwrap();

    let reader = RecordBatchIterator::new(vec![Ok(input)], schema);
    let ds = Dataset::write(reader, "memory://test_nested_json_access", None)
        .await
        .unwrap();

    // Both queries must select exactly the row with id = 1.
    let queries = [
        "SELECT id FROM dataset WHERE \
         json_get_string(json_get(json_get(data, 'user'), 'profile'), 'name') = 'Alice'",
        "SELECT id FROM dataset WHERE \
         json_extract(data, '$.user.profile.settings.theme') = '\"dark\"'",
    ];
    for query in queries {
        let batches = ds
            .sql(query)
            .build()
            .await
            .unwrap()
            .into_batch_records()
            .await
            .unwrap();
        let first = batches.first().unwrap();
        pretty_assertions::assert_eq!(first.num_rows(), 1);
        pretty_assertions::assert_eq!(first.num_columns(), 1);
        pretty_assertions::assert_eq!(first.column(0).as_primitive::<Int32Type>().value(0), 1);
    }
}
}
Loading