Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ datafusion-substrait = { version = "52", optional = true }
datafusion-proto = { version = "52" }
datafusion-ffi = { version = "52" }
prost = "0.14.1" # keep in line with `datafusion-substrait`
serde_json = "1"
uuid = { version = "1.18", features = ["v4"] }
mimalloc = { version = "0.1", optional = true, default-features = false, features = [
"local_dynamic_tls",
Expand Down
6 changes: 6 additions & 0 deletions python/datafusion/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def to_proto(self) -> bytes:
"""
return self._raw_plan.to_proto()

def __eq__(self, other: object) -> bool:
    """Return whether ``other`` wraps an equal logical plan.

    Args:
        other: Object to compare against; may be of any type.

    Returns:
        The result of comparing the two underlying raw plans when ``other``
        is a ``LogicalPlan``. For any other type ``NotImplemented`` is
        returned so Python can try the reflected comparison and then fall
        back to its default behaviour.
    """
    # NOTE(review): a class that defines __eq__ without also defining
    # __hash__ gets __hash__ set to None, making instances unhashable.
    # No __hash__ is visible from here -- confirm that is intended.
    if not isinstance(other, LogicalPlan):
        return NotImplemented
    # Use the operator instead of calling ``__eq__`` directly: a direct dunder
    # call may hand back ``NotImplemented`` (a truthy object) rather than a
    # real boolean.
    return self._raw_plan == other._raw_plan


class ExecutionPlan:
"""Represent nodes in the DataFusion Physical Plan."""
Expand Down
20 changes: 20 additions & 0 deletions python/datafusion/substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ def encode(self) -> bytes:
"""
return self.plan_internal.encode()

def to_json(self) -> str:
    """Serialize this Substrait plan to JSON.

    Returns:
        The JSON text produced by the wrapped internal plan.
    """
    serialized = self.plan_internal.to_json()
    return serialized

@staticmethod
def from_json(json: str) -> Plan:
    """Build a plan from Substrait JSON text.

    Args:
        json: A Substrait plan encoded as a JSON string.

    Returns:
        A ``Plan`` wrapping the internal plan parsed from ``json``.
    """
    internal_plan = substrait_internal.Plan.from_json(json)
    return Plan(internal_plan)


@deprecated("Use `Plan` instead.")
class plan(Plan): # noqa: N801
Expand Down
73 changes: 73 additions & 0 deletions python/tests/test_substrait.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,76 @@ def test_substrait_file_serialization(ctx, tmp_path, path_to_str):
expected_actual_plan = ss.Consumer.from_substrait_plan(ctx, actual_plan)

assert str(expected_logical_plan) == str(expected_actual_plan)


def test_json_processing_round_trip(ctx: SessionContext):
    """Check Substrait JSON export and the JSON -> plan round trip.

    A one-column table ``t`` is registered, ``SELECT * FROM t`` is converted
    to a Substrait plan, and the plan is serialized with ``to_json``. The
    ``relations`` section of that JSON is compared against the expected
    document, then the JSON is parsed back with ``Plan.from_json`` and
    consumed into a logical plan that must equal the original.
    """
    # Local import keeps this test self-contained.
    import json

    ctx.register_record_batches("t", [[pa.record_batch({"a": [1]})]])
    original_logical_plan = ctx.sql("SELECT * FROM t").logical_plan()

    substrait_plan = ss.Producer.to_substrait_plan(original_logical_plan, ctx)
    json_plan = substrait_plan.to_json()

    expected = """\
"relations": [
{
"root": {
"input": {
"project": {
"common": {
"emit": {
"outputMapping": [
1
]
}
},
"input": {
"read": {
"baseSchema": {
"names": [
"a"
],
"struct": {
"types": [
{
"i64": {
"nullability": "NULLABILITY_NULLABLE"
}
}
],
"nullability": "NULLABILITY_REQUIRED"
}
},
"namedTable": {
"names": [
"t"
]
}
}
},
"expressions": [
{
"selection": {
"directReference": {
"structField": {}
},
"rootReference": {}
}
}
]
}
},
"names": [
"a"
]
}
}
]"""

    # Compare parsed JSON rather than raw text so the check does not depend
    # on the serializer's indentation or line breaks. ``expected`` is a single
    # ``"relations": [...]`` member, so wrapping it in braces yields a
    # complete JSON object.
    expected_relations = json.loads("{" + expected + "}")["relations"]
    assert json.loads(json_plan)["relations"] == expected_relations

    round_trip_substrait_plan = ss.Plan.from_json(json_plan)
    round_trip_logical_plan = ss.Consumer.from_substrait_plan(
        ctx, round_trip_substrait_plan
    )

    assert round_trip_logical_plan == original_logical_plan
4 changes: 2 additions & 2 deletions src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ use crate::expr::unnest::PyUnnest;
use crate::expr::values::PyValues;
use crate::expr::window::PyWindowExpr;

/// Wrapper struct exported to Python as `LogicalPlan` in the `datafusion`
/// module; it holds the plan behind an `Arc`.
///
/// NOTE(review): this span is a diff hunk, so both the before and the after
/// attribute pairs appear. The second pair differs from the first by adding
/// `eq` to `#[pyclass]` and `PartialEq, Eq` to the derive list.
#[pyclass(frozen, name = "LogicalPlan", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
#[pyclass(frozen, name = "LogicalPlan", module = "datafusion", subclass, eq)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PyLogicalPlan {
// Reference-counted logical plan; visible inside this crate only.
pub(crate) plan: Arc<LogicalPlan>,
}
Expand Down
15 changes: 14 additions & 1 deletion src/substrait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use pyo3::prelude::*;
use pyo3::types::PyBytes;

use crate::context::PySessionContext;
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err};
use crate::errors::{PyDataFusionError, PyDataFusionResult, py_datafusion_err, to_datafusion_err};
use crate::sql::logical::PyLogicalPlan;
use crate::utils::wait_for_future;

Expand All @@ -42,6 +42,19 @@ impl PyPlan {
.map_err(PyDataFusionError::EncodeError)?;
Ok(PyBytes::new(py, &proto_bytes).into())
}

/// Serialize the Substrait plan as a pretty-printed JSON string.
///
/// # Errors
///
/// Returns an error when `serde_json` fails to serialize the plan.
fn to_json(&self) -> PyDataFusionResult<String> {
    // `?` is kept so the mapped error is converted into the result's error type.
    Ok(serde_json::to_string_pretty(&self.plan).map_err(to_datafusion_err)?)
}

/// Build a plan by deserializing its Substrait JSON text.
///
/// # Errors
///
/// Returns an error when `json` cannot be deserialized into a `Plan`.
#[staticmethod]
fn from_json(json: &str) -> PyDataFusionResult<PyPlan> {
    let parsed = serde_json::from_str::<Plan>(json).map_err(to_datafusion_err)?;
    Ok(PyPlan { plan: parsed })
}
}

impl From<PyPlan> for Plan {
Expand Down