diff --git a/Cargo.lock b/Cargo.lock index 592a797bf..6af46fa1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. -version = 3 +version = 4 [[package]] name = "abi_stable" @@ -1632,6 +1632,7 @@ dependencies = [ "pyo3-async-runtimes", "pyo3-build-config", "pyo3-log", + "serde_json", "tokio", "url", "uuid", diff --git a/Cargo.toml b/Cargo.toml index 44bb88186..3e632bafc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -62,6 +62,7 @@ datafusion-substrait = { version = "52", optional = true } datafusion-proto = { version = "52" } datafusion-ffi = { version = "52" } prost = "0.14.1" # keep in line with `datafusion-substrait` +serde_json = "1" uuid = { version = "1.18", features = ["v4"] } mimalloc = { version = "0.1", optional = true, default-features = false, features = [ "local_dynamic_tls", diff --git a/python/datafusion/plan.py b/python/datafusion/plan.py index 0b7bebcb3..fb54fd624 100644 --- a/python/datafusion/plan.py +++ b/python/datafusion/plan.py @@ -98,6 +98,12 @@ def to_proto(self) -> bytes: """ return self._raw_plan.to_proto() + def __eq__(self, other: LogicalPlan) -> bool: + """Test equality.""" + if not isinstance(other, LogicalPlan): + return False + return self._raw_plan.__eq__(other._raw_plan) + class ExecutionPlan: """Represent nodes in the DataFusion Physical Plan.""" diff --git a/python/datafusion/substrait.py b/python/datafusion/substrait.py index f10adfb0c..3115238fa 100644 --- a/python/datafusion/substrait.py +++ b/python/datafusion/substrait.py @@ -67,6 +67,26 @@ def encode(self) -> bytes: """ return self.plan_internal.encode() + def to_json(self) -> str: + """Get the JSON representation of the Substrait plan. + + Returns: + A JSON representation of the Substrait plan. + """ + return self.plan_internal.to_json() + + @staticmethod + def from_json(json: str) -> Plan: + """Parse a plan from a JSON string representation. + + Args: + json: JSON representation of a Substrait plan. + + Returns: + Plan object representing the Substrait plan. + """ + return Plan(substrait_internal.Plan.from_json(json)) + @deprecated("Use `Plan` instead.") class plan(Plan): # noqa: N801 diff --git a/python/tests/test_substrait.py b/python/tests/test_substrait.py index 43aa327d4..a5f59ba7e 100644 --- a/python/tests/test_substrait.py +++ b/python/tests/test_substrait.py @@ -74,3 +74,76 @@ def test_substrait_file_serialization(ctx, tmp_path, path_to_str): expected_actual_plan = ss.Consumer.from_substrait_plan(ctx, actual_plan) assert str(expected_logical_plan) == str(expected_actual_plan) + + +def test_json_processing_round_trip(ctx: SessionContext): + ctx.register_record_batches("t", [[pa.record_batch({"a": [1]})]]) + original_logical_plan = ctx.sql("SELECT * FROM t").logical_plan() + + substrait_plan = ss.Producer.to_substrait_plan(original_logical_plan, ctx) + json_plan = substrait_plan.to_json() + + expected = """\ + "relations": [ + { + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [ + 1 + ] + } + }, + "input": { + "read": { + "baseSchema": { + "names": [ + "a" + ], + "struct": { + "types": [ + { + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + }, + "namedTable": { + "names": [ + "t" + ] + } + } + }, + "expressions": [ + { + "selection": { + "directReference": { + "structField": {} + }, + "rootReference": {} + } + } + ] + } + }, + "names": [ + "a" + ] + } + } + ]""" + + assert expected in json_plan + + round_trip_substrait_plan = ss.Plan.from_json(json_plan) + round_trip_logical_plan = ss.Consumer.from_substrait_plan( + ctx, round_trip_substrait_plan + ) + + assert round_trip_logical_plan == original_logical_plan diff --git a/src/sql/logical.rs b/src/sql/logical.rs index 786118199..cd2ed73d3 100644 --- a/src/sql/logical.rs +++ b/src/sql/logical.rs @@ -66,8 +66,8 @@ use crate::expr::unnest::PyUnnest; use crate::expr::values::PyValues; use crate::expr::window::PyWindowExpr; -#[pyclass(frozen, name = "LogicalPlan", module = "datafusion", subclass)] -#[derive(Debug, Clone)] +#[pyclass(frozen, name = "LogicalPlan", module = "datafusion", subclass, eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct PyLogicalPlan { pub(crate) plan: Arc, } diff --git a/src/substrait.rs b/src/substrait.rs index ea8eaf506..1cbf3256c 100644 --- a/src/substrait.rs +++ b/src/substrait.rs @@ -23,7 +23,7 @@ use pyo3::prelude::*; use pyo3::types::PyBytes; use crate::context::PySessionContext; -use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err}; +use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err}; use crate::sql::logical::PyLogicalPlan; use crate::utils::wait_for_future; @@ -42,6 +42,19 @@ impl PyPlan { .map_err(PyDataFusionError::EncodeError)?; Ok(PyBytes::new(py, &proto_bytes).into()) } + + /// Get the JSON representation of the substrait plan + fn to_json(&self) -> PyDataFusionResult { + let json = serde_json::to_string_pretty(&self.plan).map_err(to_datafusion_err)?; + Ok(json) + } + + /// Parse a Substrait Plan from its JSON representation + #[staticmethod] + fn from_json(json: &str) -> PyDataFusionResult { + let plan: Plan = serde_json::from_str(json).map_err(to_datafusion_err)?; + Ok(PyPlan { plan }) + } } impl From for Plan {