From 1c7de596f28d636338877108759aa31908037bcf Mon Sep 17 00:00:00 2001
From: yangzhong
Date: Fri, 16 Dec 2022 18:27:02 +0800
Subject: [PATCH 1/5] Remove Ballista related things in the datafusion.proto

---
 datafusion/proto/proto/datafusion.proto | 75 +------------------------
 1 file changed, 1 insertion(+), 74 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 97ba57a7e0da..4cf7151ea997 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1330,77 +1330,4 @@ message ColumnStats {
   ScalarValue max_value = 2;
   uint32 null_count = 3;
   uint32 distinct_count = 4;
-}
-
-message PartitionLocation {
-  // partition_id of the map stage who produces the shuffle.
-  uint32 map_partition_id = 1;
-  // partition_id of the shuffle, a composition of(job_id + map_stage_id + partition_id).
-  PartitionId partition_id = 2;
-  ExecutorMetadata executor_meta = 3;
-  PartitionStats partition_stats = 4;
-  string path = 5;
-}
-
-// Unique identifier for a materialized partition of data
-message PartitionId {
-  string job_id = 1;
-  uint32 stage_id = 2;
-  uint32 partition_id = 4;
-}
-
-// Used by scheduler
-message ExecutorMetadata {
-  string id = 1;
-  string host = 2;
-  uint32 port = 3;
-  uint32 grpc_port = 4;
-  ExecutorSpecification specification = 5;
-}
-
-// Used by grpc
-message ExecutorRegistration {
-  string id = 1;
-  // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/tokio-rs/prost/issues/430 and https://github.com/tokio-rs/prost/pull/455)
-  // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3)
-  oneof optional_host {
-    string host = 2;
-  }
-  uint32 port = 3;
-  uint32 grpc_port = 4;
-  ExecutorSpecification specification = 5;
-}
-
-message ExecutorHeartbeat {
-  string executor_id = 1;
-  // Unix epoch-based timestamp in seconds
-  uint64 timestamp = 2;
-  repeated ExecutorMetric metrics = 3;
-  ExecutorStatus status = 4;
-}
-
-message ExecutorSpecification {
-  repeated ExecutorResource resources = 1;
-}
-
-message ExecutorResource {
-  // TODO add more resources
-  oneof resource {
-    uint32 task_slots = 1;
-  }
-}
-
-message ExecutorMetric {
-  // TODO add more metrics
-  oneof metric {
-    uint64 available_memory = 1;
-  }
-}
-
-message ExecutorStatus {
-  oneof status {
-    string active = 1;
-    string dead = 2;
-    string unknown = 3;
-  }
-}
+}
\ No newline at end of file

From c772d6a020d68cf52b4b2163e223ab002fe7e1f7 Mon Sep 17 00:00:00 2001
From: yangzhong
Date: Fri, 16 Dec 2022 18:35:57 +0800
Subject: [PATCH 2/5] Add missing fields for FileScanExecConf, like
 output_ordering and config_options

---
 datafusion/proto/proto/datafusion.proto       |    7 +
 datafusion/proto/src/generated/pbjson.rs      | 1371 ++---------------
 datafusion/proto/src/generated/prost.rs       |  140 +-
 .../proto/src/physical_plan/from_proto.rs     |  129 +-
 datafusion/proto/src/physical_plan/mod.rs     |   71 +-
 .../proto/src/physical_plan/to_proto.rs       |   32 +
 6 files changed, 330 insertions(+), 1420 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 4cf7151ea997..47f457622d14 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1139,6 +1139,13 @@ message FileScanExecConf {
   Statistics statistics = 6;
   repeated string table_partition_cols = 7;
   string object_store_url = 8;
+  repeated PhysicalSortExprNode output_ordering = 9;
+  repeated ConfigOption options = 10;
+}
+
+message ConfigOption {
+  string key = 1;
+  ScalarValue value = 2;
 }
 
 message ParquetScanExecNode {
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index 13236a935839..ad482de8284a 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -3057,6 +3057,114 @@ impl<'de> serde::Deserialize<'de> for ColumnStats {
         deserializer.deserialize_struct("datafusion.ColumnStats", FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for ConfigOption {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if !self.key.is_empty() {
+            len += 1;
+        }
+        if self.value.is_some() {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion.ConfigOption", len)?;
+        if !self.key.is_empty() {
+            struct_ser.serialize_field("key", &self.key)?;
+        }
+        if let Some(v) = self.value.as_ref() {
+            struct_ser.serialize_field("value", v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for ConfigOption {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "key",
+            "value",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Key,
+            Value,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "key" => Ok(GeneratedField::Key),
+                            "value" => Ok(GeneratedField::Value),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = ConfigOption;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.ConfigOption")
+            }
+
+            fn visit_map<V>(self, mut map: V) -> std::result::Result<ConfigOption, V::Error>
+            where
+                V: serde::de::MapAccess<'de>,
+            {
+                let mut key__ = None;
+                let mut value__ = None;
+                while let Some(k) = map.next_key()? 
{ + match k { + GeneratedField::Key => { + if key__.is_some() { + return Err(serde::de::Error::duplicate_field("key")); + } + key__ = Some(map.next_value()?); + } + GeneratedField::Value => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("value")); + } + value__ = map.next_value()?; + } + } + } + Ok(ConfigOption { + key: key__.unwrap_or_default(), + value: value__, + }) + } + } + deserializer.deserialize_struct("datafusion.ConfigOption", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CreateCatalogNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -5327,873 +5435,21 @@ impl<'de> serde::Deserialize<'de> for EmptyMessage { formatter.write_str("struct datafusion.EmptyMessage") } - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - while map.next_key::()?.is_some() { - let _ = map.next_value::()?; - } - Ok(EmptyMessage { - }) - } - } - deserializer.deserialize_struct("datafusion.EmptyMessage", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for EmptyRelationNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.produce_one_row { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.EmptyRelationNode", len)?; - if self.produce_one_row { - struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for EmptyRelationNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "produce_one_row", - "produceOneRow", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - ProduceOneRow, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = EmptyRelationNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.EmptyRelationNode") - } - - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut produce_one_row__ = None; - while let Some(k) = map.next_key()? 
{ - match k { - GeneratedField::ProduceOneRow => { - if produce_one_row__.is_some() { - return Err(serde::de::Error::duplicate_field("produceOneRow")); - } - produce_one_row__ = Some(map.next_value()?); - } - } - } - Ok(EmptyRelationNode { - produce_one_row: produce_one_row__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ExecutorHeartbeat { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.executor_id.is_empty() { - len += 1; - } - if self.timestamp != 0 { - len += 1; - } - if !self.metrics.is_empty() { - len += 1; - } - if self.status.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ExecutorHeartbeat", len)?; - if !self.executor_id.is_empty() { - struct_ser.serialize_field("executorId", &self.executor_id)?; - } - if self.timestamp != 0 { - struct_ser.serialize_field("timestamp", ToString::to_string(&self.timestamp).as_str())?; - } - if !self.metrics.is_empty() { - struct_ser.serialize_field("metrics", &self.metrics)?; - } - if let Some(v) = self.status.as_ref() { - struct_ser.serialize_field("status", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ExecutorHeartbeat { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "executor_id", - "executorId", - "timestamp", - "metrics", - "status", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - ExecutorId, - Timestamp, - Metrics, - Status, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "executorId" | "executor_id" => Ok(GeneratedField::ExecutorId), - "timestamp" => Ok(GeneratedField::Timestamp), - "metrics" => Ok(GeneratedField::Metrics), - "status" => Ok(GeneratedField::Status), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExecutorHeartbeat; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExecutorHeartbeat") - } - - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut executor_id__ = None; - let mut timestamp__ = None; - let mut metrics__ = None; - let mut status__ = None; - while let Some(k) = map.next_key()? 
{ - match k { - GeneratedField::ExecutorId => { - if executor_id__.is_some() { - return Err(serde::de::Error::duplicate_field("executorId")); - } - executor_id__ = Some(map.next_value()?); - } - GeneratedField::Timestamp => { - if timestamp__.is_some() { - return Err(serde::de::Error::duplicate_field("timestamp")); - } - timestamp__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Metrics => { - if metrics__.is_some() { - return Err(serde::de::Error::duplicate_field("metrics")); - } - metrics__ = Some(map.next_value()?); - } - GeneratedField::Status => { - if status__.is_some() { - return Err(serde::de::Error::duplicate_field("status")); - } - status__ = map.next_value()?; - } - } - } - Ok(ExecutorHeartbeat { - executor_id: executor_id__.unwrap_or_default(), - timestamp: timestamp__.unwrap_or_default(), - metrics: metrics__.unwrap_or_default(), - status: status__, - }) - } - } - deserializer.deserialize_struct("datafusion.ExecutorHeartbeat", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ExecutorMetadata { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.id.is_empty() { - len += 1; - } - if !self.host.is_empty() { - len += 1; - } - if self.port != 0 { - len += 1; - } - if self.grpc_port != 0 { - len += 1; - } - if self.specification.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ExecutorMetadata", len)?; - if !self.id.is_empty() { - struct_ser.serialize_field("id", &self.id)?; - } - if !self.host.is_empty() { - struct_ser.serialize_field("host", &self.host)?; - } - if self.port != 0 { - struct_ser.serialize_field("port", &self.port)?; - } - if self.grpc_port != 0 { - struct_ser.serialize_field("grpcPort", &self.grpc_port)?; - } - if let Some(v) = self.specification.as_ref() { - struct_ser.serialize_field("specification", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ExecutorMetadata { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "id", - "host", - "port", - "grpc_port", - "grpcPort", - "specification", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Id, - Host, - Port, - GrpcPort, - Specification, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "id" => Ok(GeneratedField::Id), - "host" => Ok(GeneratedField::Host), - "port" => Ok(GeneratedField::Port), - "grpcPort" | "grpc_port" => Ok(GeneratedField::GrpcPort), - "specification" => Ok(GeneratedField::Specification), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExecutorMetadata; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { - formatter.write_str("struct datafusion.ExecutorMetadata") - } - - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut id__ = None; - let mut host__ = None; - let mut port__ = None; - let mut grpc_port__ = None; - let mut specification__ = None; - while let Some(k) = map.next_key()? { - match k { - GeneratedField::Id => { - if id__.is_some() { - return Err(serde::de::Error::duplicate_field("id")); - } - id__ = Some(map.next_value()?); - } - GeneratedField::Host => { - if host__.is_some() { - return Err(serde::de::Error::duplicate_field("host")); - } - host__ = Some(map.next_value()?); - } - GeneratedField::Port => { - if port__.is_some() { - return Err(serde::de::Error::duplicate_field("port")); - } - port__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::GrpcPort => { - if grpc_port__.is_some() { - return Err(serde::de::Error::duplicate_field("grpcPort")); - } - grpc_port__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Specification => { - if specification__.is_some() { - return Err(serde::de::Error::duplicate_field("specification")); - } - specification__ = map.next_value()?; - } - } - } - Ok(ExecutorMetadata { - id: id__.unwrap_or_default(), - host: host__.unwrap_or_default(), - port: port__.unwrap_or_default(), - grpc_port: grpc_port__.unwrap_or_default(), - specification: specification__, - }) - } - } - deserializer.deserialize_struct("datafusion.ExecutorMetadata", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ExecutorMetric { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.metric.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ExecutorMetric", len)?; - if let Some(v) = self.metric.as_ref() { - match v { - executor_metric::Metric::AvailableMemory(v) => { - struct_ser.serialize_field("availableMemory", ToString::to_string(&v).as_str())?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ExecutorMetric { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "available_memory", - "availableMemory", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - AvailableMemory, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "availableMemory" | "available_memory" => Ok(GeneratedField::AvailableMemory), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExecutorMetric; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct 
datafusion.ExecutorMetric") - } - - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut metric__ = None; - while let Some(k) = map.next_key()? { - match k { - GeneratedField::AvailableMemory => { - if metric__.is_some() { - return Err(serde::de::Error::duplicate_field("availableMemory")); - } - metric__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| executor_metric::Metric::AvailableMemory(x.0)); - } - } - } - Ok(ExecutorMetric { - metric: metric__, - }) - } - } - deserializer.deserialize_struct("datafusion.ExecutorMetric", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ExecutorRegistration { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.id.is_empty() { - len += 1; - } - if self.port != 0 { - len += 1; - } - if self.grpc_port != 0 { - len += 1; - } - if self.specification.is_some() { - len += 1; - } - if self.optional_host.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ExecutorRegistration", len)?; - if !self.id.is_empty() { - struct_ser.serialize_field("id", &self.id)?; - } - if self.port != 0 { - struct_ser.serialize_field("port", &self.port)?; - } - if self.grpc_port != 0 { - struct_ser.serialize_field("grpcPort", &self.grpc_port)?; - } - if let Some(v) = self.specification.as_ref() { - struct_ser.serialize_field("specification", v)?; - } - if let Some(v) = self.optional_host.as_ref() { - match v { - executor_registration::OptionalHost::Host(v) => { - struct_ser.serialize_field("host", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ExecutorRegistration { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "id", - "port", - "grpc_port", - "grpcPort", - "specification", - "host", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Id, - Port, - GrpcPort, - Specification, - Host, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "id" => Ok(GeneratedField::Id), - "port" => Ok(GeneratedField::Port), - "grpcPort" | "grpc_port" => Ok(GeneratedField::GrpcPort), - "specification" => Ok(GeneratedField::Specification), - "host" => Ok(GeneratedField::Host), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExecutorRegistration; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExecutorRegistration") - } - - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut id__ = None; - let mut port__ = None; - let mut grpc_port__ = None; - let 
mut specification__ = None; - let mut optional_host__ = None; - while let Some(k) = map.next_key()? { - match k { - GeneratedField::Id => { - if id__.is_some() { - return Err(serde::de::Error::duplicate_field("id")); - } - id__ = Some(map.next_value()?); - } - GeneratedField::Port => { - if port__.is_some() { - return Err(serde::de::Error::duplicate_field("port")); - } - port__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::GrpcPort => { - if grpc_port__.is_some() { - return Err(serde::de::Error::duplicate_field("grpcPort")); - } - grpc_port__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Specification => { - if specification__.is_some() { - return Err(serde::de::Error::duplicate_field("specification")); - } - specification__ = map.next_value()?; - } - GeneratedField::Host => { - if optional_host__.is_some() { - return Err(serde::de::Error::duplicate_field("host")); - } - optional_host__ = map.next_value::<::std::option::Option<_>>()?.map(executor_registration::OptionalHost::Host); - } - } - } - Ok(ExecutorRegistration { - id: id__.unwrap_or_default(), - port: port__.unwrap_or_default(), - grpc_port: grpc_port__.unwrap_or_default(), - specification: specification__, - optional_host: optional_host__, - }) - } - } - deserializer.deserialize_struct("datafusion.ExecutorRegistration", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ExecutorResource { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.resource.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ExecutorResource", len)?; - if let Some(v) = self.resource.as_ref() { - match v { - executor_resource::Resource::TaskSlots(v) => { - struct_ser.serialize_field("taskSlots", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ExecutorResource { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "task_slots", - "taskSlots", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - TaskSlots, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "taskSlots" | "task_slots" => Ok(GeneratedField::TaskSlots), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExecutorResource; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExecutorResource") - } - - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut resource__ = None; - while let Some(k) = map.next_key()? 
{ - match k { - GeneratedField::TaskSlots => { - if resource__.is_some() { - return Err(serde::de::Error::duplicate_field("taskSlots")); - } - resource__ = map.next_value::<::std::option::Option<::pbjson::private::NumberDeserialize<_>>>()?.map(|x| executor_resource::Resource::TaskSlots(x.0)); - } - } - } - Ok(ExecutorResource { - resource: resource__, - }) - } - } - deserializer.deserialize_struct("datafusion.ExecutorResource", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for ExecutorSpecification { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.resources.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.ExecutorSpecification", len)?; - if !self.resources.is_empty() { - struct_ser.serialize_field("resources", &self.resources)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for ExecutorSpecification { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "resources", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Resources, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "resources" => Ok(GeneratedField::Resources), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExecutorSpecification; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExecutorSpecification") - } - - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut resources__ = None; - while let Some(k) = map.next_key()? 
{ - match k { - GeneratedField::Resources => { - if resources__.is_some() { - return Err(serde::de::Error::duplicate_field("resources")); - } - resources__ = Some(map.next_value()?); - } - } + while map.next_key::()?.is_some() { + let _ = map.next_value::()?; } - Ok(ExecutorSpecification { - resources: resources__.unwrap_or_default(), + Ok(EmptyMessage { }) } } - deserializer.deserialize_struct("datafusion.ExecutorSpecification", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.EmptyMessage", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for ExecutorStatus { +impl serde::Serialize for EmptyRelationNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result where @@ -6201,43 +5457,30 @@ impl serde::Serialize for ExecutorStatus { { use serde::ser::SerializeStruct; let mut len = 0; - if self.status.is_some() { + if self.produce_one_row { len += 1; } - let mut struct_ser = serializer.serialize_struct("datafusion.ExecutorStatus", len)?; - if let Some(v) = self.status.as_ref() { - match v { - executor_status::Status::Active(v) => { - struct_ser.serialize_field("active", v)?; - } - executor_status::Status::Dead(v) => { - struct_ser.serialize_field("dead", v)?; - } - executor_status::Status::Unknown(v) => { - struct_ser.serialize_field("unknown", v)?; - } - } + let mut struct_ser = serializer.serialize_struct("datafusion.EmptyRelationNode", len)?; + if self.produce_one_row { + struct_ser.serialize_field("produceOneRow", &self.produce_one_row)?; } struct_ser.end() } } -impl<'de> serde::Deserialize<'de> for ExecutorStatus { +impl<'de> serde::Deserialize<'de> for EmptyRelationNode { #[allow(deprecated)] fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { const FIELDS: &[&str] = &[ - "active", - "dead", - "unknown", + "produce_one_row", + "produceOneRow", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { - Active, - Dead, - Unknown, + ProduceOneRow, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -6259,9 +5502,7 @@ impl<'de> serde::Deserialize<'de> for ExecutorStatus { E: serde::de::Error, { match value { - "active" => Ok(GeneratedField::Active), - "dead" => Ok(GeneratedField::Dead), - "unknown" => Ok(GeneratedField::Unknown), + "produceOneRow" | "produce_one_row" => Ok(GeneratedField::ProduceOneRow), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -6271,45 +5512,33 @@ impl<'de> serde::Deserialize<'de> for ExecutorStatus { } struct GeneratedVisitor; impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = ExecutorStatus; + type Value = EmptyRelationNode; fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.ExecutorStatus") + formatter.write_str("struct datafusion.EmptyRelationNode") } - fn visit_map(self, mut map: V) -> std::result::Result + fn visit_map(self, mut map: V) -> std::result::Result where V: serde::de::MapAccess<'de>, { - let mut status__ = None; + let mut produce_one_row__ = None; while let Some(k) = map.next_key()? 
{ match k { - GeneratedField::Active => { - if status__.is_some() { - return Err(serde::de::Error::duplicate_field("active")); - } - status__ = map.next_value::<::std::option::Option<_>>()?.map(executor_status::Status::Active); - } - GeneratedField::Dead => { - if status__.is_some() { - return Err(serde::de::Error::duplicate_field("dead")); - } - status__ = map.next_value::<::std::option::Option<_>>()?.map(executor_status::Status::Dead); - } - GeneratedField::Unknown => { - if status__.is_some() { - return Err(serde::de::Error::duplicate_field("unknown")); + GeneratedField::ProduceOneRow => { + if produce_one_row__.is_some() { + return Err(serde::de::Error::duplicate_field("produceOneRow")); } - status__ = map.next_value::<::std::option::Option<_>>()?.map(executor_status::Status::Unknown); + produce_one_row__ = Some(map.next_value()?); } } } - Ok(ExecutorStatus { - status: status__, + Ok(EmptyRelationNode { + produce_one_row: produce_one_row__.unwrap_or_default(), }) } } - deserializer.deserialize_struct("datafusion.ExecutorStatus", FIELDS, GeneratedVisitor) + deserializer.deserialize_struct("datafusion.EmptyRelationNode", FIELDS, GeneratedVisitor) } } impl serde::Serialize for ExplainExecNode { @@ -6921,6 +6150,12 @@ impl serde::Serialize for FileScanExecConf { if !self.object_store_url.is_empty() { len += 1; } + if !self.output_ordering.is_empty() { + len += 1; + } + if !self.options.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.FileScanExecConf", len)?; if !self.file_groups.is_empty() { struct_ser.serialize_field("fileGroups", &self.file_groups)?; @@ -6943,6 +6178,12 @@ impl serde::Serialize for FileScanExecConf { if !self.object_store_url.is_empty() { struct_ser.serialize_field("objectStoreUrl", &self.object_store_url)?; } + if !self.output_ordering.is_empty() { + struct_ser.serialize_field("outputOrdering", &self.output_ordering)?; + } + if !self.options.is_empty() { + struct_ser.serialize_field("options", &self.options)?; + } struct_ser.end() } } @@ -6963,6 +6204,9 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { "tablePartitionCols", "object_store_url", "objectStoreUrl", + "output_ordering", + "outputOrdering", + "options", ]; #[allow(clippy::enum_variant_names)] @@ -6974,6 +6218,8 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { Statistics, TablePartitionCols, ObjectStoreUrl, + OutputOrdering, + Options, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -7002,6 +6248,8 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { "statistics" => Ok(GeneratedField::Statistics), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), "objectStoreUrl" | "object_store_url" => Ok(GeneratedField::ObjectStoreUrl), + "outputOrdering" | "output_ordering" => Ok(GeneratedField::OutputOrdering), + "options" => Ok(GeneratedField::Options), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -7028,6 +6276,8 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { let mut statistics__ = None; let mut table_partition_cols__ = None; let mut object_store_url__ = None; + let mut output_ordering__ = None; + let mut options__ = None; while let Some(k) = map.next_key()? 
{ match k { GeneratedField::FileGroups => { @@ -7075,6 +6325,18 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { } object_store_url__ = Some(map.next_value()?); } + GeneratedField::OutputOrdering => { + if output_ordering__.is_some() { + return Err(serde::de::Error::duplicate_field("outputOrdering")); + } + output_ordering__ = Some(map.next_value()?); + } + GeneratedField::Options => { + if options__.is_some() { + return Err(serde::de::Error::duplicate_field("options")); + } + options__ = Some(map.next_value()?); + } } } Ok(FileScanExecConf { @@ -7085,6 +6347,8 @@ impl<'de> serde::Deserialize<'de> for FileScanExecConf { statistics: statistics__, table_partition_cols: table_partition_cols__.unwrap_or_default(), object_store_url: object_store_url__.unwrap_or_default(), + output_ordering: output_ordering__.unwrap_or_default(), + options: options__.unwrap_or_default(), }) } } @@ -12798,303 +12062,6 @@ impl<'de> serde::Deserialize<'de> for PartialTableReference { deserializer.deserialize_struct("datafusion.PartialTableReference", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PartitionId { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.job_id.is_empty() { - len += 1; - } - if self.stage_id != 0 { - len += 1; - } - if self.partition_id != 0 { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.PartitionId", len)?; - if !self.job_id.is_empty() { - struct_ser.serialize_field("jobId", &self.job_id)?; - } - if self.stage_id != 0 { - struct_ser.serialize_field("stageId", &self.stage_id)?; - } - if self.partition_id != 0 { - struct_ser.serialize_field("partitionId", &self.partition_id)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for PartitionId { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "job_id", - "jobId", - "stage_id", - "stageId", - "partition_id", - "partitionId", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - JobId, - StageId, - PartitionId, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "jobId" | "job_id" => Ok(GeneratedField::JobId), - "stageId" | "stage_id" => Ok(GeneratedField::StageId), - "partitionId" | "partition_id" => Ok(GeneratedField::PartitionId), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionId; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartitionId") - } - - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut job_id__ = None; - let mut stage_id__ = None; - let mut partition_id__ = None; - while let 
Some(k) = map.next_key()? { - match k { - GeneratedField::JobId => { - if job_id__.is_some() { - return Err(serde::de::Error::duplicate_field("jobId")); - } - job_id__ = Some(map.next_value()?); - } - GeneratedField::StageId => { - if stage_id__.is_some() { - return Err(serde::de::Error::duplicate_field("stageId")); - } - stage_id__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::PartitionId => { - if partition_id__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionId")); - } - partition_id__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - } - } - Ok(PartitionId { - job_id: job_id__.unwrap_or_default(), - stage_id: stage_id__.unwrap_or_default(), - partition_id: partition_id__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.PartitionId", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for PartitionLocation { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.map_partition_id != 0 { - len += 1; - } - if self.partition_id.is_some() { - len += 1; - } - if self.executor_meta.is_some() { - len += 1; - } - if self.partition_stats.is_some() { - len += 1; - } - if !self.path.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.PartitionLocation", len)?; - if self.map_partition_id != 0 { - struct_ser.serialize_field("mapPartitionId", &self.map_partition_id)?; - } - if let Some(v) = self.partition_id.as_ref() { - struct_ser.serialize_field("partitionId", v)?; - } - if let Some(v) = self.executor_meta.as_ref() { - struct_ser.serialize_field("executorMeta", v)?; - } - if let Some(v) = self.partition_stats.as_ref() { - struct_ser.serialize_field("partitionStats", v)?; - } - if !self.path.is_empty() { - struct_ser.serialize_field("path", &self.path)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for PartitionLocation { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "map_partition_id", - "mapPartitionId", - "partition_id", - "partitionId", - "executor_meta", - "executorMeta", - "partition_stats", - "partitionStats", - "path", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - MapPartitionId, - PartitionId, - ExecutorMeta, - PartitionStats, - Path, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "mapPartitionId" | "map_partition_id" => Ok(GeneratedField::MapPartitionId), - "partitionId" | "partition_id" => Ok(GeneratedField::PartitionId), - "executorMeta" | "executor_meta" => Ok(GeneratedField::ExecutorMeta), - "partitionStats" | "partition_stats" => Ok(GeneratedField::PartitionStats), - "path" => Ok(GeneratedField::Path), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - 
deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PartitionLocation; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PartitionLocation") - } - - fn visit_map(self, mut map: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut map_partition_id__ = None; - let mut partition_id__ = None; - let mut executor_meta__ = None; - let mut partition_stats__ = None; - let mut path__ = None; - while let Some(k) = map.next_key()? { - match k { - GeneratedField::MapPartitionId => { - if map_partition_id__.is_some() { - return Err(serde::de::Error::duplicate_field("mapPartitionId")); - } - map_partition_id__ = - Some(map.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::PartitionId => { - if partition_id__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionId")); - } - partition_id__ = map.next_value()?; - } - GeneratedField::ExecutorMeta => { - if executor_meta__.is_some() { - return Err(serde::de::Error::duplicate_field("executorMeta")); - } - executor_meta__ = map.next_value()?; - } - GeneratedField::PartitionStats => { - if partition_stats__.is_some() { - return Err(serde::de::Error::duplicate_field("partitionStats")); - } - partition_stats__ = map.next_value()?; - } - GeneratedField::Path => { - if path__.is_some() { - return Err(serde::de::Error::duplicate_field("path")); - } - path__ = Some(map.next_value()?); - } - } - } - Ok(PartitionLocation { - map_partition_id: map_partition_id__.unwrap_or_default(), - partition_id: partition_id__, - executor_meta: executor_meta__, - partition_stats: partition_stats__, - path: path__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("datafusion.PartitionLocation", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for PartitionMode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 1405e1eba638..a377f4b55a50 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1504,6 +1504,17 @@ pub struct FileScanExecConf { pub table_partition_cols: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, #[prost(string, tag = "8")] pub object_store_url: ::prost::alloc::string::String, + #[prost(message, repeated, tag = "9")] + pub output_ordering: ::prost::alloc::vec::Vec, + #[prost(message, repeated, tag = "10")] + pub options: ::prost::alloc::vec::Vec, +} +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ConfigOption { + #[prost(string, tag = "1")] + pub key: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub value: ::core::option::Option, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct ParquetScanExecNode { @@ -1769,135 +1780,6 @@ pub struct ColumnStats { #[prost(uint32, tag = "4")] pub distinct_count: u32, } -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PartitionLocation { - /// partition_id of the map stage who produces the shuffle. - #[prost(uint32, tag = "1")] - pub map_partition_id: u32, - /// partition_id of the shuffle, a composition of(job_id + map_stage_id + partition_id). 
-    #[prost(message, optional, tag = "2")]
-    pub partition_id: ::core::option::Option<PartitionId>,
-    #[prost(message, optional, tag = "3")]
-    pub executor_meta: ::core::option::Option<ExecutorMetadata>,
-    #[prost(message, optional, tag = "4")]
-    pub partition_stats: ::core::option::Option<PartitionStats>,
-    #[prost(string, tag = "5")]
-    pub path: ::prost::alloc::string::String,
-}
-/// Unique identifier for a materialized partition of data
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct PartitionId {
-    #[prost(string, tag = "1")]
-    pub job_id: ::prost::alloc::string::String,
-    #[prost(uint32, tag = "2")]
-    pub stage_id: u32,
-    #[prost(uint32, tag = "4")]
-    pub partition_id: u32,
-}
-/// Used by scheduler
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ExecutorMetadata {
-    #[prost(string, tag = "1")]
-    pub id: ::prost::alloc::string::String,
-    #[prost(string, tag = "2")]
-    pub host: ::prost::alloc::string::String,
-    #[prost(uint32, tag = "3")]
-    pub port: u32,
-    #[prost(uint32, tag = "4")]
-    pub grpc_port: u32,
-    #[prost(message, optional, tag = "5")]
-    pub specification: ::core::option::Option<ExecutorSpecification>,
-}
-/// Used by grpc
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ExecutorRegistration {
-    #[prost(string, tag = "1")]
-    pub id: ::prost::alloc::string::String,
-    #[prost(uint32, tag = "3")]
-    pub port: u32,
-    #[prost(uint32, tag = "4")]
-    pub grpc_port: u32,
-    #[prost(message, optional, tag = "5")]
-    pub specification: ::core::option::Option<ExecutorSpecification>,
-    /// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see <https://github.com/tokio-rs/prost/issues/430> and <https://github.com/tokio-rs/prost/pull/455>)
-    /// this syntax is ugly but is binary compatible with the "optional" keyword (see <https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3>)
-    #[prost(oneof = "executor_registration::OptionalHost", tags = "2")]
-    pub optional_host: ::core::option::Option<executor_registration::OptionalHost>,
-}
-/// Nested message and enum types in `ExecutorRegistration`.
-pub mod executor_registration {
-    /// "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see <https://github.com/tokio-rs/prost/issues/430> and <https://github.com/tokio-rs/prost/pull/455>)
-    /// this syntax is ugly but is binary compatible with the "optional" keyword (see <https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3>)
-    #[derive(Clone, PartialEq, ::prost::Oneof)]
-    pub enum OptionalHost {
-        #[prost(string, tag = "2")]
-        Host(::prost::alloc::string::String),
-    }
-}
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ExecutorHeartbeat {
-    #[prost(string, tag = "1")]
-    pub executor_id: ::prost::alloc::string::String,
-    /// Unix epoch-based timestamp in seconds
-    #[prost(uint64, tag = "2")]
-    pub timestamp: u64,
-    #[prost(message, repeated, tag = "3")]
-    pub metrics: ::prost::alloc::vec::Vec<ExecutorMetric>,
-    #[prost(message, optional, tag = "4")]
-    pub status: ::core::option::Option<ExecutorStatus>,
-}
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ExecutorSpecification {
-    #[prost(message, repeated, tag = "1")]
-    pub resources: ::prost::alloc::vec::Vec<ExecutorResource>,
-}
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ExecutorResource {
-    /// TODO add more resources
-    #[prost(oneof = "executor_resource::Resource", tags = "1")]
-    pub resource: ::core::option::Option<executor_resource::Resource>,
-}
-/// Nested message and enum types in `ExecutorResource`.
-pub mod executor_resource {
-    /// TODO add more resources
-    #[derive(Clone, PartialEq, ::prost::Oneof)]
-    pub enum Resource {
-        #[prost(uint32, tag = "1")]
-        TaskSlots(u32),
-    }
-}
-#[derive(Clone, PartialEq, ::prost::Message)]
-pub struct ExecutorMetric {
-    /// TODO add more metrics
-    #[prost(oneof = "executor_metric::Metric", tags = "1")]
-    pub metric: ::core::option::Option<executor_metric::Metric>,
-}
-/// Nested message and enum types in `ExecutorMetric`.
-pub mod executor_metric { - /// TODO add more metrics - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Metric { - #[prost(uint64, tag = "1")] - AvailableMemory(u64), - } -} -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct ExecutorStatus { - #[prost(oneof = "executor_status::Status", tags = "1, 2, 3")] - pub status: ::core::option::Option, -} -/// Nested message and enum types in `ExecutorStatus`. -pub mod executor_status { - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Status { - #[prost(string, tag = "1")] - Active(::prost::alloc::string::String), - #[prost(string, tag = "2")] - Dead(::prost::alloc::string::String), - #[prost(string, tag = "3")] - Unknown(::prost::alloc::string::String), - } -} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum JoinType { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 825399825ff0..54012eabd5e5 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -18,6 +18,7 @@ //! Serde code to convert from protocol buffers to Rust data structures. use crate::protobuf; +use arrow::datatypes::DataType; use chrono::TimeZone; use chrono::Utc; use datafusion::arrow::datatypes::Schema; @@ -28,7 +29,7 @@ use datafusion::execution::context::ExecutionProps; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::window_function::WindowFunction; use datafusion::physical_expr::expressions::DateTimeIntervalExpr; -use datafusion::physical_expr::ScalarFunctionExpr; +use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::file_format::FileScanConfig; use datafusion::physical_plan::{ expressions::{ @@ -50,6 +51,7 @@ use crate::convert_required; use crate::from_proto::from_proto_binary_op; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::JoinSide; +use datafusion::physical_plan::sorts::sort::SortOptions; use parking_lot::RwLock; impl From<&protobuf::PhysicalColumn> for Column { @@ -304,6 +306,96 @@ pub fn parse_protobuf_hash_partitioning( } } +pub fn parse_protobuf_file_scan_config( + proto: &protobuf::FileScanExecConf, + registry: &dyn FunctionRegistry, +) -> Result { + let schema: Arc = Arc::new(convert_required!(proto.schema)?); + let projection = proto + .projection + .iter() + .map(|i| *i as usize) + .collect::>(); + let projection = if projection.is_empty() { + None + } else { + Some(projection) + }; + let statistics = convert_required!(proto.statistics)?; + + let file_groups: Vec> = proto + .file_groups + .iter() + .map(|f| f.try_into()) + .collect::, _>>()?; + + let object_store_url = match proto.object_store_url.is_empty() { + false => ObjectStoreUrl::parse(&proto.object_store_url)?, + true => ObjectStoreUrl::local_filesystem(), + }; + + // extract types of partition columns + let table_partition_cols = proto + .table_partition_cols + .iter() + .map(|col| { + Ok(( + col.to_owned(), + schema.field_with_name(col)?.data_type().clone(), + )) + }) + .collect::, DataFusionError>>()?; + + let output_ordering = proto + .output_ordering + .iter() + .map(|o| { + let expr = o + .expr + .as_ref() + .map(|e| parse_physical_expr(e.as_ref(), registry, &schema)) + .unwrap()?; + Ok(PhysicalSortExpr { + expr, + options: SortOptions { + descending: !o.asc, + nulls_first: o.nulls_first, + }, + }) + }) + .collect::, DataFusionError>>()?; + let output_ordering = if 
output_ordering.is_empty() { + None + } else { + Some(output_ordering) + }; + + let mut config_options = ConfigOptions::new(); + for option in proto.options.iter() { + config_options.set( + &option.key, + option + .value + .as_ref() + .map(|value| value.try_into()) + .transpose()? + .unwrap(), + ); + } + + Ok(FileScanConfig { + object_store_url, + file_schema: schema, + file_groups, + statistics, + projection, + limit: proto.limit.as_ref().map(|sl| sl.limit as usize), + table_partition_cols, + output_ordering, + config_options: Arc::new(RwLock::new(config_options)), + }) +} + impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile { type Error = DataFusionError; @@ -381,41 +473,6 @@ impl TryInto for &protobuf::Statistics { } } -impl TryInto for &protobuf::FileScanExecConf { - type Error = DataFusionError; - - fn try_into(self) -> Result { - let schema = Arc::new(convert_required!(self.schema)?); - let projection = self - .projection - .iter() - .map(|i| *i as usize) - .collect::>(); - let projection = if projection.is_empty() { - None - } else { - Some(projection) - }; - let statistics = convert_required!(self.statistics)?; - - Ok(FileScanConfig { - config_options: Arc::new(RwLock::new(ConfigOptions::new())), // TODO add serde - object_store_url: ObjectStoreUrl::parse(&self.object_store_url)?, - file_schema: schema, - file_groups: self - .file_groups - .iter() - .map(|f| f.try_into()) - .collect::, _>>()?, - statistics, - projection, - limit: self.limit.as_ref().map(|sl| sl.limit as usize), - table_partition_cols: vec![], - output_ordering: None, - }) - } -} - impl From for datafusion::physical_plan::joins::utils::JoinSide { fn from(t: JoinSide) -> Self { match t { diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index f6b6a0bfe84b..dedafd94bdb0 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -21,10 +21,7 @@ use std::sync::Arc; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; -use datafusion::config::ConfigOptions; use datafusion::datasource::file_format::file_type::FileCompressionType; -use datafusion::datasource::listing::PartitionedFile; -use datafusion::datasource::object_store::ObjectStoreUrl; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::WindowFrame; @@ -35,9 +32,7 @@ use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::explain::ExplainExec; use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr}; -use datafusion::physical_plan::file_format::{ - AvroExec, CsvExec, FileScanConfig, ParquetExec, -}; +use datafusion::physical_plan::file_format::{AvroExec, CsvExec, ParquetExec}; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter}; use datafusion::physical_plan::joins::CrossJoinExec; @@ -53,14 +48,15 @@ use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, }; use datafusion_common::DataFusionError; -use parking_lot::RwLock; use prost::bytes::BufMut; use prost::Message; use crate::common::proto_error; use crate::common::{csv_delimiter_to_string, str_to_byte}; use crate::from_proto::parse_expr; -use crate::physical_plan::from_proto::parse_physical_expr; +use crate::physical_plan::from_proto::{ + parse_physical_expr, 
parse_protobuf_file_scan_config, +}; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::physical_plan_node::PhysicalPlanType; use crate::protobuf::repartition_exec_node::PartitionMethod; @@ -152,7 +148,10 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok(Arc::new(FilterExec::try_new(predicate, input)?)) } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap())?, + parse_protobuf_file_scan_config( + scan.base_conf.as_ref().unwrap(), + registry, + )?, scan.has_header, str_to_byte(&scan.delimiter)?, FileCompressionType::UNCOMPRESSED, @@ -164,14 +163,20 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| parse_expr(expr, registry)) .transpose()?; Ok(Arc::new(ParquetExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap())?, + parse_protobuf_file_scan_config( + scan.base_conf.as_ref().unwrap(), + registry, + )?, predicate, None, ))) } - PhysicalPlanType::AvroScan(scan) => Ok(Arc::new(AvroExec::new( - decode_scan_config(scan.base_conf.as_ref().unwrap())?, - ))), + PhysicalPlanType::AvroScan(scan) => { + Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config( + scan.base_conf.as_ref().unwrap(), + registry, + )?))) + } PhysicalPlanType::CoalesceBatches(coalesce_batches) => { let input: Arc = into_physical_plan!( coalesce_batches.input, @@ -1120,46 +1125,6 @@ impl AsExecutionPlan for PhysicalPlanNode { } } -fn decode_scan_config( - proto: &protobuf::FileScanExecConf, -) -> Result { - let schema = Arc::new(convert_required!(proto.schema)?); - let projection = proto - .projection - .iter() - .map(|i| *i as usize) - .collect::>(); - let projection = if projection.is_empty() { - None - } else { - Some(projection) - }; - let statistics = convert_required!(proto.statistics)?; - - let file_groups: Vec> = proto - .file_groups - .iter() - .map(|f| f.try_into()) - .collect::, _>>()?; - - let object_store_url = match proto.object_store_url.is_empty() { - false => ObjectStoreUrl::parse(&proto.object_store_url)?, - true => ObjectStoreUrl::local_filesystem(), - }; - - Ok(FileScanConfig { - config_options: Arc::new(RwLock::new(ConfigOptions::new())), // TODO add serde - object_store_url, - file_schema: schema, - file_groups, - statistics, - projection, - limit: proto.limit.as_ref().map(|sl| sl.limit as usize), - table_partition_cols: vec![], - output_ordering: None, - }) -} - pub trait AsExecutionPlan: Debug + Send + Sync + Clone { fn try_decode(buf: &[u8]) -> Result where diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index cce992192c80..20efbbf2bbcd 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -40,6 +40,7 @@ use datafusion::physical_plan::expressions::{Avg, BinaryExpr, Column, Max, Min, use datafusion::physical_plan::{AggregateExpr, PhysicalExpr}; use crate::protobuf; +use crate::protobuf::{ConfigOption, PhysicalSortExprNode}; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::expressions::DateTimeIntervalExpr; use datafusion::physical_expr::ScalarFunctionExpr; @@ -440,6 +441,35 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { .map(|p| p.as_slice().try_into()) .collect::, _>>()?; + let output_ordering = if let Some(output_ordering) = &conf.output_ordering { + output_ordering + .iter() + .map(|o| { + let expr = o.expr.clone().try_into()?; + Ok(PhysicalSortExprNode { + expr: Some(Box::new(expr)), + asc: !o.options.descending, + 
nulls_first: o.options.nulls_first, + }) + }) + .collect::, DataFusionError>>()? + } else { + vec![] + }; + let options = { + let config_options = conf.config_options.read().options().clone(); + config_options + .into_iter() + .map(|(key, value)| { + let value = (&value).try_into()?; + Ok(ConfigOption { + key, + value: Some(value), + }) + }) + .collect::, DataFusionError>>()? + }; + Ok(protobuf::FileScanExecConf { file_groups, statistics: Some((&conf.statistics).into()), @@ -458,6 +488,8 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { .map(|x| x.0.clone()) .collect::>(), object_store_url: conf.object_store_url.to_string(), + output_ordering, + options, }) } } From 05b5bc2a35680275f55e296076e827bba3b7b8f5 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Fri, 16 Dec 2022 23:41:50 +0800 Subject: [PATCH 3/5] Reorganize the logical plan related code in proto to be consistent with the physical plan code --- datafusion/proto/src/bytes/mod.rs | 87 +- datafusion/proto/src/common.rs | 42 +- datafusion/proto/src/lib.rs | 1374 +-------------- .../src/{ => logical_plan}/from_proto.rs | 245 +-- .../{logical_plan.rs => logical_plan/mod.rs} | 1537 +++++++++++++++-- .../proto/src/{ => logical_plan}/to_proto.rs | 28 +- .../proto/src/physical_plan/from_proto.rs | 6 +- datafusion/proto/src/physical_plan/mod.rs | 33 +- 8 files changed, 1642 insertions(+), 1710 deletions(-) rename datafusion/proto/src/{ => logical_plan}/from_proto.rs (95%) rename datafusion/proto/src/{logical_plan.rs => logical_plan/mod.rs} (51%) rename datafusion/proto/src/{ => logical_plan}/to_proto.rs (98%) diff --git a/datafusion/proto/src/bytes/mod.rs b/datafusion/proto/src/bytes/mod.rs index 1eb946bd5099..6163efb1f109 100644 --- a/datafusion/proto/src/bytes/mod.rs +++ b/datafusion/proto/src/bytes/mod.rs @@ -16,16 +16,16 @@ // under the License. //! 
Serialization / Deserialization to Bytes -use crate::logical_plan::{AsLogicalPlan, LogicalExtensionCodec}; -use crate::physical_plan::{AsExecutionPlan, PhysicalExtensionCodec}; -use crate::{from_proto::parse_expr, protobuf}; -use arrow::datatypes::SchemaRef; -use datafusion::datasource::TableProvider; +use crate::logical_plan::{ + self, AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec, +}; +use crate::physical_plan::{ + AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, +}; +use crate::protobuf; use datafusion::physical_plan::functions::make_scalar_function; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::{ - create_udaf, create_udf, Expr, Extension, LogicalPlan, Volatility, -}; +use datafusion_expr::{create_udaf, create_udf, Expr, LogicalPlan, Volatility}; use prost::{ bytes::{Bytes, BytesMut}, Message, @@ -137,7 +137,7 @@ impl Serializeable for Expr { DataFusionError::Plan(format!("Error decoding expr as protobuf: {}", e)) })?; - parse_expr(&protobuf, registry).map_err(|e| { + logical_plan::from_proto::parse_expr(&protobuf, registry).map_err(|e| { DataFusionError::Plan(format!("Error parsing protobuf into Expr: {}", e)) }) } @@ -272,75 +272,6 @@ pub fn physical_plan_from_bytes_with_extension_codec( protobuf.try_into_physical_plan(ctx, &ctx.runtime_env(), extension_codec) } -#[derive(Debug)] -struct DefaultLogicalExtensionCodec {} - -impl LogicalExtensionCodec for DefaultLogicalExtensionCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[LogicalPlan], - _ctx: &SessionContext, - ) -> Result { - Err(DataFusionError::NotImplemented( - "No extension codec provided".to_string(), - )) - } - - fn try_encode(&self, _node: &Extension, _buf: &mut Vec) -> Result<()> { - Err(DataFusionError::NotImplemented( - "No extension codec provided".to_string(), - )) - } - - fn try_decode_table_provider( - &self, - _buf: &[u8], - _schema: SchemaRef, - _ctx: &SessionContext, - ) -> std::result::Result, DataFusionError> { - Err(DataFusionError::NotImplemented( - "No codec provided to for TableProviders".to_string(), - )) - } - - fn try_encode_table_provider( - &self, - _node: Arc, - _buf: &mut Vec, - ) -> std::result::Result<(), DataFusionError> { - Err(DataFusionError::NotImplemented( - "No codec provided to for TableProviders".to_string(), - )) - } -} - -#[derive(Debug)] -pub struct DefaultPhysicalExtensionCodec {} - -impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[Arc], - _registry: &dyn FunctionRegistry, - ) -> Result> { - Err(DataFusionError::NotImplemented( - "PhysicalExtensionCodec is not provided".to_string(), - )) - } - - fn try_encode( - &self, - _node: Arc, - _buf: &mut Vec, - ) -> Result<()> { - Err(DataFusionError::NotImplemented( - "PhysicalExtensionCodec is not provided".to_string(), - )) - } -} - #[cfg(test)] mod test { use super::*; diff --git a/datafusion/proto/src/common.rs b/datafusion/proto/src/common.rs index 9388f2e51d36..d74083850abb 100644 --- a/datafusion/proto/src/common.rs +++ b/datafusion/proto/src/common.rs @@ -33,6 +33,46 @@ pub fn str_to_byte(s: &String) -> Result { Ok(s.as_bytes()[0]) } -pub(crate) fn proto_error>(message: S) -> DataFusionError { +pub fn byte_to_string(b: u8) -> Result { + let b = &[b]; + let b = std::str::from_utf8(b) + .map_err(|_| DataFusionError::Internal("Invalid CSV delimiter".to_owned()))?; + Ok(b.to_owned()) +} + +#[macro_export] +macro_rules! 
convert_required { + ($PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + Ok(field.try_into()?) + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + +#[macro_export] +macro_rules! into_required { + ($PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + Ok(field.into()) + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + +#[macro_export] +macro_rules! convert_box_required { + ($PB:expr) => {{ + if let Some(field) = $PB.as_ref() { + field.as_ref().try_into() + } else { + Err(proto_error("Missing required field in protobuf")) + } + }}; +} + +pub fn proto_error>(message: S) -> DataFusionError { DataFusionError::Internal(message.into()) } diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index bfd978f2e327..872c26408ebe 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -17,1385 +17,13 @@ //! Serde code for logical plans and expressions. -use datafusion_common::DataFusionError; - pub mod bytes; -mod common; -pub mod from_proto; +pub mod common; pub mod generated; pub mod logical_plan; pub mod physical_plan; -pub mod to_proto; pub use generated::datafusion as protobuf; #[cfg(doctest)] doc_comment::doctest!("../README.md", readme_example_test); - -impl From for DataFusionError { - fn from(e: from_proto::Error) -> Self { - DataFusionError::Plan(e.to_string()) - } -} - -impl From for DataFusionError { - fn from(e: to_proto::Error) -> Self { - DataFusionError::Plan(e.to_string()) - } -} - -#[cfg(test)] -mod roundtrip_tests { - use super::from_proto::parse_expr; - use super::protobuf; - use crate::bytes::{ - logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, - logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, - }; - use crate::logical_plan::LogicalExtensionCodec; - use arrow::datatypes::{Schema, SchemaRef}; - use arrow::{ - array::ArrayRef, - datatypes::{ - DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, - TimeUnit, UnionMode, - }, - }; - use datafusion::datasource::datasource::TableProviderFactory; - use datafusion::datasource::TableProvider; - use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; - use datafusion::physical_plan::functions::make_scalar_function; - use datafusion::prelude::{ - create_udf, CsvReadOptions, SessionConfig, SessionContext, - }; - use datafusion::test_util::{TestTableFactory, TestTableProvider}; - use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; - use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; - use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; - use datafusion_expr::{ - col, lit, Accumulator, AggregateFunction, - BuiltinScalarFunction::{Sqrt, Substr}, - Expr, LogicalPlan, Operator, Volatility, - }; - use datafusion_expr::{ - create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, - }; - use prost::Message; - use std::any::Any; - use std::collections::HashMap; - use std::fmt; - use std::fmt::Debug; - use std::fmt::Formatter; - use std::sync::Arc; - - #[cfg(feature = "json")] - fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { - let string = serde_json::to_string(proto).unwrap(); - let back: protobuf::LogicalExprNode = serde_json::from_str(&string).unwrap(); - assert_eq!(proto, &back); - } - - #[cfg(not(feature = "json"))] - fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} - - // Given a DataFusion logical Expr, convert it to protobuf and back, using debug 
formatting to test - // equality. - fn roundtrip_expr_test(initial_struct: T, ctx: SessionContext) - where - for<'a> &'a T: TryInto + Debug, - E: Debug, - { - let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); - let round_trip: Expr = parse_expr(&proto, &ctx).unwrap(); - - assert_eq!( - format!("{:?}", &initial_struct), - format!("{:?}", round_trip) - ); - - roundtrip_json_test(&proto); - } - - fn new_box_field(name: &str, dt: DataType, nullable: bool) -> Box { - Box::new(Field::new(name, dt, nullable)) - } - - #[tokio::test] - async fn roundtrip_logical_plan() -> Result<(), DataFusionError> { - let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) - .await?; - let scan = ctx.table("t1")?.to_logical_plan()?; - let topk_plan = LogicalPlan::Extension(Extension { - node: Arc::new(TopKPlanNode::new(3, scan, col("revenue"))), - }); - let extension_codec = TopKExtensionCodec {}; - let bytes = - logical_plan_to_bytes_with_extension_codec(&topk_plan, &extension_codec)?; - let logical_round_trip = - logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec)?; - assert_eq!( - format!("{:?}", topk_plan), - format!("{:?}", logical_round_trip) - ); - Ok(()) - } - - #[derive(Clone, PartialEq, Eq, ::prost::Message)] - pub struct TestTableProto { - /// URL of the table root - #[prost(string, tag = "1")] - pub url: String, - } - - #[derive(Debug)] - pub struct TestTableProviderCodec {} - - impl LogicalExtensionCodec for TestTableProviderCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[LogicalPlan], - _ctx: &SessionContext, - ) -> Result { - Err(DataFusionError::NotImplemented( - "No extension codec provided".to_string(), - )) - } - - fn try_encode( - &self, - _node: &Extension, - _buf: &mut Vec, - ) -> Result<(), DataFusionError> { - Err(DataFusionError::NotImplemented( - "No extension codec provided".to_string(), - )) - } - - fn try_decode_table_provider( - &self, - buf: &[u8], - schema: SchemaRef, - _ctx: &SessionContext, - ) -> Result, DataFusionError> { - let msg = TestTableProto::decode(buf).map_err(|_| { - DataFusionError::Internal("Error decoding test table".to_string()) - })?; - let provider = TestTableProvider { - url: msg.url, - schema, - }; - Ok(Arc::new(provider)) - } - - fn try_encode_table_provider( - &self, - node: Arc, - buf: &mut Vec, - ) -> Result<(), DataFusionError> { - let table = node - .as_ref() - .as_any() - .downcast_ref::() - .expect("Can't encode non-test tables"); - let msg = TestTableProto { - url: table.url.clone(), - }; - msg.encode(buf).map_err(|_| { - DataFusionError::Internal("Error encoding test table".to_string()) - }) - } - } - - #[tokio::test] - async fn roundtrip_custom_tables() -> Result<(), DataFusionError> { - let mut table_factories: HashMap> = - HashMap::new(); - table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); - let cfg = RuntimeConfig::new().with_table_factories(table_factories); - let env = RuntimeEnv::new(cfg).unwrap(); - let ses = SessionConfig::new(); - let ctx = SessionContext::with_config_rt(ses, Arc::new(env)); - - let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';"; - ctx.sql(sql).await.unwrap(); - - let codec = TestTableProviderCodec {}; - let scan = ctx.table("t")?.to_logical_plan()?; - let bytes = logical_plan_to_bytes_with_extension_codec(&scan, &codec)?; - let logical_round_trip = - logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; - 
assert_eq!(format!("{:?}", scan), format!("{:?}", logical_round_trip)); - Ok(()) - } - - #[tokio::test] - async fn roundtrip_logical_plan_aggregation() -> Result<(), DataFusionError> { - let ctx = SessionContext::new(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Decimal128(15, 2), true), - ]); - - ctx.register_csv( - "t1", - "testdata/test.csv", - CsvReadOptions::default().schema(&schema), - ) - .await?; - - let query = - "SELECT a, SUM(b + 1) as b_sum FROM t1 GROUP BY a ORDER BY b_sum DESC"; - let plan = ctx.sql(query).await?.to_logical_plan()?; - - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_single_count_distinct() -> Result<(), DataFusionError> { - let ctx = SessionContext::new(); - - let schema = Schema::new(vec![ - Field::new("a", DataType::Int64, true), - Field::new("b", DataType::Decimal128(15, 2), true), - ]); - - ctx.register_csv( - "t1", - "testdata/test.csv", - CsvReadOptions::default().schema(&schema), - ) - .await?; - - let query = "SELECT a, COUNT(DISTINCT b) as b_cd FROM t1 GROUP BY a"; - let plan = ctx.sql(query).await?.to_logical_plan()?; - - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); - - Ok(()) - } - - #[tokio::test] - async fn roundtrip_logical_plan_with_extension() -> Result<(), DataFusionError> { - let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) - .await?; - let plan = ctx.table("t1")?.to_logical_plan()?; - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); - Ok(()) - } - - #[tokio::test] - async fn roundtrip_logical_plan_with_view_scan() -> Result<(), DataFusionError> { - let ctx = SessionContext::new(); - ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) - .await?; - ctx.sql("CREATE VIEW view_t1(a, b) AS SELECT a, b FROM t1") - .await?; - let plan = ctx.sql("SELECT * FROM view_t1").await?.to_logical_plan()?; - let bytes = logical_plan_to_bytes(&plan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; - assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); - Ok(()) - } - - pub mod proto { - #[derive(Clone, PartialEq, ::prost::Message)] - pub struct TopKPlanProto { - #[prost(uint64, tag = "1")] - pub k: u64, - - #[prost(message, optional, tag = "2")] - pub expr: ::core::option::Option, - } - - #[derive(Clone, PartialEq, Eq, ::prost::Message)] - pub struct TopKExecProto { - #[prost(uint64, tag = "1")] - pub k: u64, - } - } - - struct TopKPlanNode { - k: usize, - input: LogicalPlan, - /// The sort expression (this example only supports a single sort - /// expr) - expr: Expr, - } - - impl TopKPlanNode { - pub fn new(k: usize, input: LogicalPlan, expr: Expr) -> Self { - Self { k, input, expr } - } - } - - impl Debug for TopKPlanNode { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - self.fmt_for_explain(f) - } - } - - impl UserDefinedLogicalNode for TopKPlanNode { - fn as_any(&self) -> &dyn Any { - self - } - - fn inputs(&self) -> Vec<&LogicalPlan> { - vec![&self.input] - } - - /// Schema for TopK is the same as the input - fn 
schema(&self) -> &DFSchemaRef { - self.input.schema() - } - - fn expressions(&self) -> Vec { - vec![self.expr.clone()] - } - - /// For example: `TopK: k=10` - fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "TopK: k={}", self.k) - } - - fn from_template( - &self, - exprs: &[Expr], - inputs: &[LogicalPlan], - ) -> Arc { - assert_eq!(inputs.len(), 1, "input size inconsistent"); - assert_eq!(exprs.len(), 1, "expression size inconsistent"); - Arc::new(TopKPlanNode { - k: self.k, - input: inputs[0].clone(), - expr: exprs[0].clone(), - }) - } - } - - #[derive(Debug)] - pub struct TopKExtensionCodec {} - - impl LogicalExtensionCodec for TopKExtensionCodec { - fn try_decode( - &self, - buf: &[u8], - inputs: &[LogicalPlan], - ctx: &SessionContext, - ) -> Result { - if let Some((input, _)) = inputs.split_first() { - let proto = proto::TopKPlanProto::decode(buf).map_err(|e| { - DataFusionError::Internal(format!( - "failed to decode logical plan: {:?}", - e - )) - })?; - - if let Some(expr) = proto.expr.as_ref() { - let node = TopKPlanNode::new( - proto.k as usize, - input.clone(), - parse_expr(expr, ctx)?, - ); - - Ok(Extension { - node: Arc::new(node), - }) - } else { - Err(DataFusionError::Internal( - "invalid plan, no expr".to_string(), - )) - } - } else { - Err(DataFusionError::Internal( - "invalid plan, no input".to_string(), - )) - } - } - - fn try_encode( - &self, - node: &Extension, - buf: &mut Vec, - ) -> Result<(), DataFusionError> { - if let Some(exec) = node.node.as_any().downcast_ref::() { - let proto = proto::TopKPlanProto { - k: exec.k as u64, - expr: Some((&exec.expr).try_into()?), - }; - - proto.encode(buf).map_err(|e| { - DataFusionError::Internal(format!( - "failed to encode logical plan: {:?}", - e - )) - })?; - - Ok(()) - } else { - Err(DataFusionError::Internal( - "unsupported plan type".to_string(), - )) - } - } - - fn try_decode_table_provider( - &self, - _buf: &[u8], - _schema: SchemaRef, - _ctx: &SessionContext, - ) -> Result, DataFusionError> { - Err(DataFusionError::Internal( - "unsupported plan type".to_string(), - )) - } - - fn try_encode_table_provider( - &self, - _node: Arc, - _buf: &mut Vec, - ) -> Result<(), DataFusionError> { - Err(DataFusionError::Internal( - "unsupported plan type".to_string(), - )) - } - } - - #[test] - fn scalar_values_error_serialization() { - let should_fail_on_seralize: Vec = vec![ - // Should fail due to empty values - ScalarValue::Struct( - Some(vec![]), - Box::new(vec![Field::new("item", DataType::Int16, true)]), - ), - // Should fail due to inconsistent types in the list - ScalarValue::new_list( - Some(vec![ - ScalarValue::Int16(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_box_field("item", DataType::Int16, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::List(new_box_field("item", DataType::Int16, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(None), - ScalarValue::Float32(Some(32.0)), - ]), - DataType::Int16, - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list( - None, - DataType::List(new_box_field("level2", DataType::Float32, true)), - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ]), - DataType::List(new_box_field("level2", DataType::Float32, true)), - ), - ScalarValue::new_list( - 
None, - DataType::List(new_box_field( - "lists are typed inconsistently", - DataType::Int16, - true, - )), - ), - ]), - DataType::List(new_box_field( - "level1", - DataType::List(new_box_field("level2", DataType::Float32, true)), - true, - )), - ), - ]; - - for test_case in should_fail_on_seralize.into_iter() { - let proto: Result = - (&test_case).try_into(); - - // Validation is also done on read, so if serialization passed - // also try to convert back to ScalarValue - if let Ok(proto) = proto { - let res: Result = (&proto).try_into(); - assert!( - res.is_err(), - "The value {:?} unexpectedly serialized without error:{:?}", - test_case, - res - ); - } - } - } - - #[test] - fn round_trip_scalar_values() { - let should_pass: Vec = vec![ - ScalarValue::Boolean(None), - ScalarValue::Float32(None), - ScalarValue::Float64(None), - ScalarValue::Int8(None), - ScalarValue::Int16(None), - ScalarValue::Int32(None), - ScalarValue::Int64(None), - ScalarValue::UInt8(None), - ScalarValue::UInt16(None), - ScalarValue::UInt32(None), - ScalarValue::UInt64(None), - ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), - ScalarValue::new_list(None, DataType::Boolean), - ScalarValue::Date32(None), - ScalarValue::Boolean(Some(true)), - ScalarValue::Boolean(Some(false)), - ScalarValue::Float32(Some(1.0)), - ScalarValue::Float32(Some(f32::MAX)), - ScalarValue::Float32(Some(f32::MIN)), - ScalarValue::Float32(Some(-2000.0)), - ScalarValue::Float64(Some(1.0)), - ScalarValue::Float64(Some(f64::MAX)), - ScalarValue::Float64(Some(f64::MIN)), - ScalarValue::Float64(Some(-2000.0)), - ScalarValue::Int8(Some(i8::MIN)), - ScalarValue::Int8(Some(i8::MAX)), - ScalarValue::Int8(Some(0)), - ScalarValue::Int8(Some(-15)), - ScalarValue::Int16(Some(i16::MIN)), - ScalarValue::Int16(Some(i16::MAX)), - ScalarValue::Int16(Some(0)), - ScalarValue::Int16(Some(-15)), - ScalarValue::Int32(Some(i32::MIN)), - ScalarValue::Int32(Some(i32::MAX)), - ScalarValue::Int32(Some(0)), - ScalarValue::Int32(Some(-15)), - ScalarValue::Int64(Some(i64::MIN)), - ScalarValue::Int64(Some(i64::MAX)), - ScalarValue::Int64(Some(0)), - ScalarValue::Int64(Some(-15)), - ScalarValue::UInt8(Some(u8::MAX)), - ScalarValue::UInt8(Some(0)), - ScalarValue::UInt16(Some(u16::MAX)), - ScalarValue::UInt16(Some(0)), - ScalarValue::UInt32(Some(u32::MAX)), - ScalarValue::UInt32(Some(0)), - ScalarValue::UInt64(Some(u64::MAX)), - ScalarValue::UInt64(Some(0)), - ScalarValue::Utf8(Some(String::from("Test string "))), - ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), - ScalarValue::Date32(Some(0)), - ScalarValue::Date32(Some(i32::MAX)), - ScalarValue::Date32(None), - ScalarValue::Date64(Some(0)), - ScalarValue::Date64(Some(i64::MAX)), - ScalarValue::Date64(None), - ScalarValue::Time32Second(Some(0)), - ScalarValue::Time32Second(Some(i32::MAX)), - ScalarValue::Time32Second(None), - ScalarValue::Time32Millisecond(Some(0)), - ScalarValue::Time32Millisecond(Some(i32::MAX)), - ScalarValue::Time32Millisecond(None), - ScalarValue::Time64Microsecond(Some(0)), - ScalarValue::Time64Microsecond(Some(i64::MAX)), - ScalarValue::Time64Microsecond(None), - ScalarValue::Time64Nanosecond(Some(0)), - ScalarValue::Time64Nanosecond(Some(i64::MAX)), - ScalarValue::Time64Nanosecond(None), - ScalarValue::TimestampNanosecond(Some(0), None), - ScalarValue::TimestampNanosecond(Some(i64::MAX), None), - ScalarValue::TimestampNanosecond(Some(0), Some("UTC".to_string())), - ScalarValue::TimestampNanosecond(None, None), - ScalarValue::TimestampMicrosecond(Some(0), None), - 
ScalarValue::TimestampMicrosecond(Some(i64::MAX), None), - ScalarValue::TimestampMicrosecond(Some(0), Some("UTC".to_string())), - ScalarValue::TimestampMicrosecond(None, None), - ScalarValue::TimestampMillisecond(Some(0), None), - ScalarValue::TimestampMillisecond(Some(i64::MAX), None), - ScalarValue::TimestampMillisecond(Some(0), Some("UTC".to_string())), - ScalarValue::TimestampMillisecond(None, None), - ScalarValue::TimestampSecond(Some(0), None), - ScalarValue::TimestampSecond(Some(i64::MAX), None), - ScalarValue::TimestampSecond(Some(0), Some("UTC".to_string())), - ScalarValue::TimestampSecond(None, None), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 0))), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(1, 2))), - ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value( - i32::MAX, - i32::MAX, - ))), - ScalarValue::IntervalDayTime(None), - ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNanoType::make_value(0, 0, 0), - )), - ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNanoType::make_value(1, 2, 3), - )), - ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNanoType::make_value(i32::MAX, i32::MAX, i64::MAX), - )), - ScalarValue::IntervalMonthDayNano(None), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ]), - DataType::Float32, - ), - ScalarValue::new_list( - Some(vec![ - ScalarValue::new_list(None, DataType::Float32), - ScalarValue::new_list( - Some(vec![ - ScalarValue::Float32(Some(-213.1)), - ScalarValue::Float32(None), - ScalarValue::Float32(Some(5.5)), - ScalarValue::Float32(Some(2.0)), - ScalarValue::Float32(Some(1.0)), - ]), - DataType::Float32, - ), - ]), - DataType::List(new_box_field("item", DataType::Float32, true)), - ), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(Some("foo".into()))), - ), - ScalarValue::Dictionary( - Box::new(DataType::Int32), - Box::new(ScalarValue::Utf8(None)), - ), - ScalarValue::Binary(Some(b"bar".to_vec())), - ScalarValue::Binary(None), - ScalarValue::LargeBinary(Some(b"bar".to_vec())), - ScalarValue::LargeBinary(None), - ScalarValue::Struct( - Some(vec![ - ScalarValue::Int32(Some(23)), - ScalarValue::Boolean(Some(false)), - ]), - Box::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Boolean, false), - ]), - ), - ScalarValue::Struct( - None, - Box::new(vec![ - Field::new("a", DataType::Int32, true), - Field::new("a", DataType::Boolean, false), - ]), - ), - ScalarValue::FixedSizeBinary( - b"bar".to_vec().len() as i32, - Some(b"bar".to_vec()), - ), - ScalarValue::FixedSizeBinary(0, None), - ScalarValue::FixedSizeBinary(5, None), - ]; - - for test_case in should_pass.into_iter() { - let proto: super::protobuf::ScalarValue = (&test_case) - .try_into() - .expect("failed conversion to protobuf"); - - let roundtrip: ScalarValue = (&proto) - .try_into() - .expect("failed conversion from protobuf"); - - assert_eq!( - test_case, roundtrip, - "ScalarValue was not the same after round trip!\n\n\ - Input: {:?}\n\nRoundtrip: {:?}", - test_case, roundtrip - ); - } - } - - #[test] - fn round_trip_scalar_types() { - let should_pass: Vec = vec![ - DataType::Boolean, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Float32, - 
DataType::Float64, - DataType::Date32, - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Utf8, - DataType::LargeUtf8, - // Recursive list tests - DataType::List(new_box_field("level1", DataType::Boolean, true)), - DataType::List(new_box_field( - "Level1", - DataType::List(new_box_field("level2", DataType::Date32, true)), - true, - )), - ]; - - for test_case in should_pass.into_iter() { - let field = Field::new("item", test_case, true); - let proto: super::protobuf::Field = (&field).try_into().unwrap(); - let roundtrip: Field = (&proto).try_into().unwrap(); - assert_eq!(format!("{:?}", field), format!("{:?}", roundtrip)); - } - } - - #[test] - fn round_trip_datatype() { - let test_cases: Vec = vec![ - DataType::Null, - DataType::Boolean, - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Float16, - DataType::Float32, - DataType::Float64, - // Add more timestamp tests - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Date32, - DataType::Date64, - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Microsecond), - DataType::Time32(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Second), - DataType::Time64(TimeUnit::Millisecond), - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), - DataType::Duration(TimeUnit::Second), - DataType::Duration(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Microsecond), - DataType::Duration(TimeUnit::Nanosecond), - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::DayTime), - DataType::Binary, - DataType::FixedSizeBinary(0), - DataType::FixedSizeBinary(1234), - DataType::FixedSizeBinary(-432), - DataType::LargeBinary, - DataType::Utf8, - DataType::LargeUtf8, - DataType::Decimal128(7, 12), - // Recursive list tests - DataType::List(new_box_field("Level1", DataType::Binary, true)), - DataType::List(new_box_field( - "Level1", - DataType::List(new_box_field( - "Level2", - DataType::FixedSizeBinary(53), - false, - )), - true, - )), - // Fixed size lists - DataType::FixedSizeList(new_box_field("Level1", DataType::Binary, true), 4), - DataType::FixedSizeList( - new_box_field( - "Level1", - DataType::List(new_box_field( - "Level2", - DataType::FixedSizeBinary(53), - false, - )), - true, - ), - 41, - ), - // Struct Testing - DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ]), - DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - Field::new( - "nested_struct", - DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ]), - true, - ), - ]), - DataType::Union( - vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ], - vec![7, 5, 3], - UnionMode::Sparse, - ), - DataType::Union( - vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - Field::new( - "nested_struct", - DataType::Struct(vec![ - Field::new("nullable", 
DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ]), - true, - ), - ], - vec![5, 8, 1], - UnionMode::Dense, - ), - DataType::Dictionary( - Box::new(DataType::Utf8), - Box::new(DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ])), - ), - DataType::Dictionary( - Box::new(DataType::Decimal128(10, 50)), - Box::new(DataType::FixedSizeList( - new_box_field("Level1", DataType::Binary, true), - 4, - )), - ), - ]; - - for test_case in test_cases.into_iter() { - let proto: super::protobuf::ArrowType = (&test_case).try_into().unwrap(); - let roundtrip: DataType = (&proto).try_into().unwrap(); - assert_eq!(format!("{:?}", test_case), format!("{:?}", roundtrip)); - } - } - - #[test] - fn roundtrip_null_scalar_values() { - let test_types = vec![ - ScalarValue::Boolean(None), - ScalarValue::Float32(None), - ScalarValue::Float64(None), - ScalarValue::Int8(None), - ScalarValue::Int16(None), - ScalarValue::Int32(None), - ScalarValue::Int64(None), - ScalarValue::UInt8(None), - ScalarValue::UInt16(None), - ScalarValue::UInt32(None), - ScalarValue::UInt64(None), - ScalarValue::Utf8(None), - ScalarValue::LargeUtf8(None), - ScalarValue::Date32(None), - ScalarValue::TimestampMicrosecond(None, None), - ScalarValue::TimestampNanosecond(None, None), - ScalarValue::List( - None, - Box::new(Field::new("item", DataType::Boolean, false)), - ), - ]; - - for test_case in test_types.into_iter() { - let proto_scalar: super::protobuf::ScalarValue = - (&test_case).try_into().unwrap(); - let returned_scalar: datafusion::scalar::ScalarValue = - (&proto_scalar).try_into().unwrap(); - assert_eq!( - format!("{:?}", &test_case), - format!("{:?}", returned_scalar) - ); - } - } - - #[test] - fn roundtrip_not() { - let test_expr = Expr::Not(Box::new(lit(1.0_f32))); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_is_null() { - let test_expr = Expr::IsNull(Box::new(col("id"))); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_is_not_null() { - let test_expr = Expr::IsNotNull(Box::new(col("id"))); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_between() { - let test_expr = Expr::Between(Between::new( - Box::new(lit(1.0_f32)), - true, - Box::new(lit(2.0_f32)), - Box::new(lit(3.0_f32)), - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_binary_op() { - fn test(op: Operator) { - let test_expr = Expr::BinaryExpr(BinaryExpr::new( - Box::new(lit(1.0_f32)), - op, - Box::new(lit(2.0_f32)), - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - test(Operator::StringConcat); - test(Operator::RegexNotIMatch); - test(Operator::RegexNotMatch); - test(Operator::RegexIMatch); - test(Operator::RegexMatch); - test(Operator::Like); - test(Operator::NotLike); - test(Operator::ILike); - test(Operator::NotILike); - test(Operator::BitwiseShiftRight); - test(Operator::BitwiseShiftLeft); - test(Operator::BitwiseAnd); - test(Operator::BitwiseOr); - test(Operator::BitwiseXor); - test(Operator::IsDistinctFrom); - test(Operator::IsNotDistinctFrom); - test(Operator::And); - test(Operator::Or); - test(Operator::Eq); - test(Operator::NotEq); - test(Operator::Lt); - test(Operator::LtEq); - 
test(Operator::Gt); - test(Operator::GtEq); - } - - #[test] - fn roundtrip_case() { - let test_expr = Expr::Case(Case::new( - Some(Box::new(lit(1.0_f32))), - vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(lit(4.0_f32))), - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_case_with_null() { - let test_expr = Expr::Case(Case::new( - Some(Box::new(lit(1.0_f32))), - vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], - Some(Box::new(Expr::Literal(ScalarValue::Null))), - )); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_null_literal() { - let test_expr = Expr::Literal(ScalarValue::Null); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_cast() { - let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_sort_expr() { - let test_expr = Expr::Sort { - expr: Box::new(lit(1.0_f32)), - asc: true, - nulls_first: true, - }; - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_negative() { - let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_inlist() { - let test_expr = Expr::InList { - expr: Box::new(lit(1.0_f32)), - list: vec![lit(2.0_f32)], - negated: true, - }; - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_wildcard() { - let test_expr = Expr::Wildcard; - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_sqrt() { - let test_expr = Expr::ScalarFunction { - fun: Sqrt, - args: vec![col("col")], - }; - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_like() { - fn like(negated: bool, escape_char: Option) { - let test_expr = Expr::Like(Like::new( - negated, - Box::new(col("col")), - Box::new(lit("[0-9]+")), - escape_char, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - like(true, Some('X')); - like(false, Some('\\')); - like(true, None); - like(false, None); - } - - #[test] - fn roundtrip_ilike() { - fn ilike(negated: bool, escape_char: Option) { - let test_expr = Expr::ILike(Like::new( - negated, - Box::new(col("col")), - Box::new(lit("[0-9]+")), - escape_char, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - ilike(true, Some('X')); - ilike(false, Some('\\')); - ilike(true, None); - ilike(false, None); - } - - #[test] - fn roundtrip_similar_to() { - fn similar_to(negated: bool, escape_char: Option) { - let test_expr = Expr::SimilarTo(Like::new( - negated, - Box::new(col("col")), - Box::new(lit("[0-9]+")), - escape_char, - )); - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - similar_to(true, Some('X')); - similar_to(false, Some('\\')); - similar_to(true, None); - similar_to(false, None); - } - - #[test] - fn roundtrip_count() { - let test_expr = Expr::AggregateFunction { - fun: AggregateFunction::Count, - args: vec![col("bananas")], - distinct: false, - filter: None, - }; - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_count_distinct() { - let test_expr = Expr::AggregateFunction { 
- fun: AggregateFunction::Count, - args: vec![col("bananas")], - distinct: true, - filter: None, - }; - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_approx_percentile_cont() { - let test_expr = Expr::AggregateFunction { - fun: AggregateFunction::ApproxPercentileCont, - args: vec![col("bananas"), lit(0.42_f32)], - distinct: false, - filter: None, - }; - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_aggregate_udf() { - #[derive(Debug)] - struct Dummy {} - - impl Accumulator for Dummy { - fn state(&self) -> datafusion::error::Result> { - Ok(vec![]) - } - - fn update_batch( - &mut self, - _values: &[ArrayRef], - ) -> datafusion::error::Result<()> { - Ok(()) - } - - fn merge_batch( - &mut self, - _states: &[ArrayRef], - ) -> datafusion::error::Result<()> { - Ok(()) - } - - fn evaluate(&self) -> datafusion::error::Result { - Ok(ScalarValue::Float64(None)) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) - } - } - - let dummy_agg = create_udaf( - // the name; used to represent it in plan descriptions and in the registry, to use in SQL. - "dummy_agg", - // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. - DataType::Float64, - // the return type; DataFusion expects this to match the type returned by `evaluate`. - Arc::new(DataType::Float64), - Volatility::Immutable, - // This is the accumulator factory; DataFusion uses it to create new accumulators. - Arc::new(|_| Ok(Box::new(Dummy {}))), - // This is the description of the state. `state()` must match the types here. - Arc::new(vec![DataType::Float64, DataType::UInt32]), - ); - - let test_expr = Expr::AggregateUDF { - fun: Arc::new(dummy_agg.clone()), - args: vec![lit(1.0_f64)], - filter: Some(Box::new(lit(true))), - }; - - let ctx = SessionContext::new(); - ctx.register_udaf(dummy_agg); - - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_scalar_udf() { - let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); - - let scalar_fn = make_scalar_function(fn_impl); - - let udf = create_udf( - "dummy", - vec![DataType::Utf8], - Arc::new(DataType::Utf8), - Volatility::Immutable, - scalar_fn, - ); - - let test_expr = Expr::ScalarUDF { - fun: Arc::new(udf.clone()), - args: vec![lit("")], - }; - - let ctx = SessionContext::new(); - ctx.register_udf(udf); - - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_grouping_sets() { - let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ - vec![col("a")], - vec![col("b")], - vec![col("a"), col("b")], - ])); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_rollup() { - let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_cube() { - let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, ctx); - } - - #[test] - fn roundtrip_substr() { - // substr(string, position) - let test_expr = Expr::ScalarFunction { - fun: Substr, - args: vec![col("col"), lit(1_i64)], - }; - - // substr(string, position, count) - let test_expr_with_count = Expr::ScalarFunction { - fun: Substr, - args: vec![col("col"), lit(1_i64), lit(1_i64)], - }; - - let ctx = SessionContext::new(); - roundtrip_expr_test(test_expr, 
ctx.clone()); - roundtrip_expr_test(test_expr_with_count, ctx); - } - #[test] - fn roundtrip_window() { - let ctx = SessionContext::new(); - - // 1. without window_frame - let test_expr1 = Expr::WindowFunction { - fun: WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, - ), - args: vec![], - partition_by: vec![col("col1")], - order_by: vec![col("col2")], - window_frame: WindowFrame::new(true), - }; - - // 2. with default window_frame - let test_expr2 = Expr::WindowFunction { - fun: WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, - ), - args: vec![], - partition_by: vec![col("col1")], - order_by: vec![col("col2")], - window_frame: WindowFrame::new(true), - }; - - // 3. with window_frame with row numbers - let range_number_frame = WindowFrame { - units: WindowFrameUnits::Range, - start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), - end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), - }; - - let test_expr3 = Expr::WindowFunction { - fun: WindowFunction::BuiltInWindowFunction( - datafusion_expr::window_function::BuiltInWindowFunction::Rank, - ), - args: vec![], - partition_by: vec![col("col1")], - order_by: vec![col("col2")], - window_frame: range_number_frame, - }; - - // 4. test with AggregateFunction - let row_number_frame = WindowFrame { - units: WindowFrameUnits::Rows, - start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), - end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), - }; - - let test_expr4 = Expr::WindowFunction { - fun: WindowFunction::AggregateFunction(AggregateFunction::Max), - args: vec![col("col1")], - partition_by: vec![col("col1")], - order_by: vec![col("col2")], - window_frame: row_number_frame, - }; - - roundtrip_expr_test(test_expr1, ctx.clone()); - roundtrip_expr_test(test_expr2, ctx.clone()); - roundtrip_expr_test(test_expr3, ctx.clone()); - roundtrip_expr_test(test_expr4, ctx); - } -} diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs similarity index 95% rename from datafusion/proto/src/from_proto.rs rename to datafusion/proto/src/logical_plan/from_proto.rs index ebe1870d3311..503a777d4675 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. 
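
[Editor's note] The changes to this file below replace the old `impl TryFrom<&i32> for protobuf::TimeUnit` / `IntervalUnit` / `AggregateFunction` conversions with free helper functions (`parse_i32_to_time_unit`, `parse_i32_to_interval_unit`, `parse_i32_to_aggregate_function`). As a minimal, self-contained sketch of that pattern — the `PbTimeUnit`, `ArrowTimeUnit`, and `Error` types here are stand-ins for the prost-generated `protobuf::TimeUnit`, Arrow's `TimeUnit`, and this crate's `from_proto::Error`, not the real definitions:

```rust
// Stand-in for the prost-generated protobuf enum (stored as raw i32 on the wire).
#[derive(Debug, Clone, Copy)]
enum PbTimeUnit {
    Second = 0,
    Millisecond = 1,
    Microsecond = 2,
    Nanosecond = 3,
}

// Stand-in for Arrow's TimeUnit.
#[derive(Debug, PartialEq)]
enum ArrowTimeUnit {
    Second,
    Millisecond,
    Microsecond,
    Nanosecond,
}

// Stand-in for the crate's decode error type.
#[derive(Debug)]
struct Error(String);

impl PbTimeUnit {
    // prost generates an associated `from_i32` like this for enum fields;
    // unknown discriminants yield None.
    fn from_i32(value: i32) -> Option<Self> {
        match value {
            0 => Some(Self::Second),
            1 => Some(Self::Millisecond),
            2 => Some(Self::Microsecond),
            3 => Some(Self::Nanosecond),
            _ => None,
        }
    }
}

// Infallible mapping from the wire enum to the domain enum.
impl From<PbTimeUnit> for ArrowTimeUnit {
    fn from(t: PbTimeUnit) -> Self {
        match t {
            PbTimeUnit::Second => Self::Second,
            PbTimeUnit::Millisecond => Self::Millisecond,
            PbTimeUnit::Microsecond => Self::Microsecond,
            PbTimeUnit::Nanosecond => Self::Nanosecond,
        }
    }
}

// The free function replaces the old `TryFrom<&i32>` impl: decode the raw i32,
// surface unknown discriminants as an error, then reuse the infallible `From`.
fn parse_i32_to_time_unit(value: &i32) -> Result<ArrowTimeUnit, Error> {
    PbTimeUnit::from_i32(*value)
        .map(|t| t.into())
        .ok_or_else(|| Error(format!("unknown TimeUnit discriminant: {value}")))
}

fn main() {
    assert_eq!(parse_i32_to_time_unit(&2).unwrap(), ArrowTimeUnit::Microsecond);
    assert!(parse_i32_to_time_unit(&42).is_err());
}
```

One plausible reading of the design choice: collapsing the two-step `protobuf::TimeUnit::try_from(i)?.into()` dance into a single named helper makes the fallible i32-to-domain-enum decode explicit at call sites (e.g. `parse_i32_to_time_unit(time_unit)?` in the hunks that follow).
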
-use crate::protobuf::plan_type::PlanTypeEnum::{ - FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, - OptimizedLogicalPlan, OptimizedPhysicalPlan, -}; -use crate::protobuf::{self, PlaceholderNode}; use crate::protobuf::{ + self, + plan_type::PlanTypeEnum::{ + FinalLogicalPlan, FinalPhysicalPlan, InitialLogicalPlan, InitialPhysicalPlan, + OptimizedLogicalPlan, OptimizedPhysicalPlan, + }, CubeNode, GroupingSetNode, OptimizedLogicalPlanType, OptimizedPhysicalPlanType, - RollupNode, + PlaceholderNode, RollupNode, }; use arrow::datatypes::{ DataType, Field, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionMode, @@ -32,7 +32,6 @@ use datafusion_common::{ Column, DFField, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, ScalarValue, }; -use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; use datafusion_expr::{ abs, acos, array, ascii, asin, atan, atan2, bit_length, btrim, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, date_bin, @@ -43,10 +42,11 @@ use datafusion_expr::{ sha384, sha512, signum, sin, split_part, sqrt, starts_with, strpos, substr, substring, tan, to_hex, to_timestamp_micros, to_timestamp_millis, to_timestamp_seconds, translate, trim, trunc, upper, uuid, AggregateFunction, - Between, BuiltInWindowFunction, BuiltinScalarFunction, Case, Expr, GetIndexedField, - GroupingSet, + Between, BinaryExpr, BuiltInWindowFunction, BuiltinScalarFunction, Case, Cast, Expr, + GetIndexedField, GroupingSet, GroupingSet::GroupingSets, - Like, Operator, WindowFrame, WindowFrameBound, WindowFrameUnits, + JoinConstraint, JoinType, Like, Operator, TryCast, WindowFrame, WindowFrameBound, + WindowFrameUnits, }; use std::sync::Arc; @@ -274,27 +274,27 @@ impl TryFrom<&protobuf::arrow_type::ArrowTypeEnum> for DataType { arrow_type::ArrowTypeEnum::Date32(_) => DataType::Date32, arrow_type::ArrowTypeEnum::Date64(_) => DataType::Date64, arrow_type::ArrowTypeEnum::Duration(time_unit) => { - DataType::Duration(protobuf::TimeUnit::try_from(time_unit)?.into()) + DataType::Duration(parse_i32_to_time_unit(time_unit)?) } arrow_type::ArrowTypeEnum::Timestamp(protobuf::Timestamp { time_unit, timezone, }) => DataType::Timestamp( - protobuf::TimeUnit::try_from(time_unit)?.into(), + parse_i32_to_time_unit(time_unit)?, match timezone.len() { 0 => None, _ => Some(timezone.to_owned()), }, ), arrow_type::ArrowTypeEnum::Time32(time_unit) => { - DataType::Time32(protobuf::TimeUnit::try_from(time_unit)?.into()) + DataType::Time32(parse_i32_to_time_unit(time_unit)?) } arrow_type::ArrowTypeEnum::Time64(time_unit) => { - DataType::Time64(protobuf::TimeUnit::try_from(time_unit)?.into()) + DataType::Time64(parse_i32_to_time_unit(time_unit)?) + } + arrow_type::ArrowTypeEnum::Interval(interval_unit) => { + DataType::Interval(parse_i32_to_interval_unit(interval_unit)?) 
} - arrow_type::ArrowTypeEnum::Interval(interval_unit) => DataType::Interval( - protobuf::IntervalUnit::try_from(interval_unit)?.into(), - ), arrow_type::ArrowTypeEnum::Decimal(protobuf::Decimal { precision, scale, @@ -523,15 +523,6 @@ impl From for BuiltInWindowFunction { } } -impl TryFrom<&i32> for protobuf::AggregateFunction { - type Error = Error; - - fn try_from(value: &i32) -> Result { - protobuf::AggregateFunction::from_i32(*value) - .ok_or_else(|| Error::unknown("AggregateFunction", *value)) - } -} - impl TryFrom<&protobuf::Schema> for Schema { type Error = Error; @@ -715,6 +706,117 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } } +impl TryFrom for WindowFrame { + type Error = Error; + + fn try_from(window: protobuf::WindowFrame) -> Result { + let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units) + .ok_or_else(|| Error::unknown("WindowFrameUnits", window.window_frame_units))? + .into(); + let start_bound = window.start_bound.required("start_bound")?; + let end_bound = window + .end_bound + .map(|end_bound| match end_bound { + protobuf::window_frame::EndBound::Bound(end_bound) => { + end_bound.try_into() + } + }) + .transpose()? + .unwrap_or(WindowFrameBound::CurrentRow); + Ok(Self { + units, + start_bound, + end_bound, + }) + } +} + +impl TryFrom for WindowFrameBound { + type Error = Error; + + fn try_from(bound: protobuf::WindowFrameBound) -> Result { + let bound_type = + protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type) + .ok_or_else(|| { + Error::unknown("WindowFrameBoundType", bound.window_frame_bound_type) + })?; + match bound_type { + protobuf::WindowFrameBoundType::CurrentRow => Ok(Self::CurrentRow), + protobuf::WindowFrameBoundType::Preceding => match bound.bound_value { + Some(x) => Ok(Self::Preceding(ScalarValue::try_from(&x)?)), + None => Ok(Self::Preceding(ScalarValue::UInt64(None))), + }, + protobuf::WindowFrameBoundType::Following => match bound.bound_value { + Some(x) => Ok(Self::Following(ScalarValue::try_from(&x)?)), + None => Ok(Self::Following(ScalarValue::UInt64(None))), + }, + } + } +} + +impl From for TimeUnit { + fn from(time_unit: protobuf::TimeUnit) -> Self { + match time_unit { + protobuf::TimeUnit::Second => TimeUnit::Second, + protobuf::TimeUnit::Millisecond => TimeUnit::Millisecond, + protobuf::TimeUnit::Microsecond => TimeUnit::Microsecond, + protobuf::TimeUnit::Nanosecond => TimeUnit::Nanosecond, + } + } +} + +impl From for IntervalUnit { + fn from(interval_unit: protobuf::IntervalUnit) -> Self { + match interval_unit { + protobuf::IntervalUnit::YearMonth => IntervalUnit::YearMonth, + protobuf::IntervalUnit::DayTime => IntervalUnit::DayTime, + protobuf::IntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano, + } + } +} + +impl From for JoinType { + fn from(t: protobuf::JoinType) -> Self { + match t { + protobuf::JoinType::Inner => JoinType::Inner, + protobuf::JoinType::Left => JoinType::Left, + protobuf::JoinType::Right => JoinType::Right, + protobuf::JoinType::Full => JoinType::Full, + protobuf::JoinType::Leftsemi => JoinType::LeftSemi, + protobuf::JoinType::Rightsemi => JoinType::RightSemi, + protobuf::JoinType::Leftanti => JoinType::LeftAnti, + protobuf::JoinType::Rightanti => JoinType::RightAnti, + } + } +} + +impl From for JoinConstraint { + fn from(t: protobuf::JoinConstraint) -> Self { + match t { + protobuf::JoinConstraint::On => JoinConstraint::On, + protobuf::JoinConstraint::Using => JoinConstraint::Using, + } + } +} + +pub fn parse_i32_to_time_unit(value: &i32) -> Result { + 
protobuf::TimeUnit::from_i32(*value) + .map(|t| t.into()) + .ok_or_else(|| Error::unknown("TimeUnit", *value)) +} + +pub fn parse_i32_to_interval_unit(value: &i32) -> Result { + protobuf::IntervalUnit::from_i32(*value) + .map(|t| t.into()) + .ok_or_else(|| Error::unknown("IntervalUnit", *value)) +} + +pub fn parse_i32_to_aggregate_function(value: &i32) -> Result { + protobuf::AggregateFunction::from_i32(*value) + .map(|a| a.into()) + .ok_or_else(|| Error::unknown("AggregateFunction", *value)) +} + /// Ensures that all `values` are of type DataType::List and have the /// same type as field fn validate_list_values(field: &Field, values: &[ScalarValue]) -> Result<(), Error> { @@ -818,7 +920,7 @@ pub fn parse_expr( match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { - let aggr_function = protobuf::AggregateFunction::try_from(i)?.into(); + let aggr_function = parse_i32_to_aggregate_function(i)?; Ok(Expr::WindowFunction { fun: datafusion_expr::window_function::WindowFunction::AggregateFunction( @@ -852,7 +954,7 @@ pub fn parse_expr( } } ExprType::AggregateExpr(expr) => { - let fun = protobuf::AggregateFunction::try_from(&expr.aggr_function)?.into(); + let fun = parse_i32_to_aggregate_function(&expr.aggr_function)?; Ok(Expr::AggregateFunction { fun, @@ -1247,93 +1349,6 @@ fn parse_escape_char(s: &str) -> Result, DataFusionError> { } } -impl TryFrom for WindowFrame { - type Error = Error; - - fn try_from(window: protobuf::WindowFrame) -> Result { - let units = protobuf::WindowFrameUnits::from_i32(window.window_frame_units) - .ok_or_else(|| Error::unknown("WindowFrameUnits", window.window_frame_units))? - .into(); - let start_bound = window.start_bound.required("start_bound")?; - let end_bound = window - .end_bound - .map(|end_bound| match end_bound { - protobuf::window_frame::EndBound::Bound(end_bound) => { - end_bound.try_into() - } - }) - .transpose()? 
- .unwrap_or(WindowFrameBound::CurrentRow); - Ok(Self { - units, - start_bound, - end_bound, - }) - } -} - -impl TryFrom for WindowFrameBound { - type Error = Error; - - fn try_from(bound: protobuf::WindowFrameBound) -> Result { - let bound_type = - protobuf::WindowFrameBoundType::from_i32(bound.window_frame_bound_type) - .ok_or_else(|| { - Error::unknown("WindowFrameBoundType", bound.window_frame_bound_type) - })?; - match bound_type { - protobuf::WindowFrameBoundType::CurrentRow => Ok(Self::CurrentRow), - protobuf::WindowFrameBoundType::Preceding => match bound.bound_value { - Some(x) => Ok(Self::Preceding(ScalarValue::try_from(&x)?)), - None => Ok(Self::Preceding(ScalarValue::UInt64(None))), - }, - protobuf::WindowFrameBoundType::Following => match bound.bound_value { - Some(x) => Ok(Self::Following(ScalarValue::try_from(&x)?)), - None => Ok(Self::Following(ScalarValue::UInt64(None))), - }, - } - } -} - -impl TryFrom<&i32> for protobuf::TimeUnit { - type Error = Error; - - fn try_from(value: &i32) -> Result { - protobuf::TimeUnit::from_i32(*value) - .ok_or_else(|| Error::unknown("TimeUnit", *value)) - } -} - -impl From for TimeUnit { - fn from(time_unit: protobuf::TimeUnit) -> Self { - match time_unit { - protobuf::TimeUnit::Second => TimeUnit::Second, - protobuf::TimeUnit::Millisecond => TimeUnit::Millisecond, - protobuf::TimeUnit::Microsecond => TimeUnit::Microsecond, - protobuf::TimeUnit::Nanosecond => TimeUnit::Nanosecond, - } - } -} - -impl TryFrom<&i32> for protobuf::IntervalUnit { - type Error = Error; - - fn try_from(value: &i32) -> Result { - protobuf::IntervalUnit::from_i32(*value) - .ok_or_else(|| Error::unknown("IntervalUnit", *value)) - } -} - -impl From for IntervalUnit { - fn from(interval_unit: protobuf::IntervalUnit) -> Self { - match interval_unit { - protobuf::IntervalUnit::YearMonth => IntervalUnit::YearMonth, - protobuf::IntervalUnit::DayTime => IntervalUnit::DayTime, - protobuf::IntervalUnit::MonthDayNano => IntervalUnit::MonthDayNano, - } - } -} - // panic here because no better way to convert from Vec to Array fn vec_to_array(v: Vec) -> [T; N] { v.try_into().unwrap_or_else(|v: Vec| { diff --git a/datafusion/proto/src/logical_plan.rs b/datafusion/proto/src/logical_plan/mod.rs similarity index 51% rename from datafusion/proto/src/logical_plan.rs rename to datafusion/proto/src/logical_plan/mod.rs index 09ae9e41e7ec..b5a4945df9a9 100644 --- a/datafusion/proto/src/logical_plan.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -15,18 +15,17 @@ // specific language governing permissions and limitations // under the License. 
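
[Editor's note] For orientation on the rename below: `logical_plan.rs` becomes a directory module (`logical_plan/mod.rs`) that declares `from_proto` and `to_proto` as submodules — mirroring the existing `physical_plan` layout — and hosts the `From<Error> for DataFusionError` bridges that previously lived in `lib.rs`. A rough, self-contained sketch of that shape, with stand-in error types rather than the crate's real `DataFusionError::Plan` variant:

```rust
// Stand-in for datafusion_common::DataFusionError.
mod common {
    #[derive(Debug)]
    pub struct DataFusionError(pub String);
}

// Sketch of the new `logical_plan` directory module.
mod logical_plan {
    // In the real crate these are `logical_plan/from_proto.rs` and
    // `logical_plan/to_proto.rs`, declared from `logical_plan/mod.rs`.
    pub mod from_proto {
        #[derive(Debug)]
        pub struct Error(pub String); // stand-in for the real decode error
    }
    pub mod to_proto {
        #[derive(Debug)]
        pub struct Error(pub String); // stand-in for the real encode error
    }

    use crate::common::DataFusionError;

    // Bridging the submodule errors here (moved out of lib.rs) keeps `?`
    // ergonomic for callers that work in terms of DataFusionError.
    impl From<from_proto::Error> for DataFusionError {
        fn from(e: from_proto::Error) -> Self {
            DataFusionError(format!("Error parsing protobuf: {}", e.0))
        }
    }

    impl From<to_proto::Error> for DataFusionError {
        fn from(e: to_proto::Error) -> Self {
            DataFusionError(format!("Error encoding protobuf: {}", e.0))
        }
    }
}

fn main() {
    let err: common::DataFusionError =
        logical_plan::from_proto::Error("missing required field".into()).into();
    println!("{err:?}");
}
```

The practical upshot, as the hunks below show, is that serde plumbing shared by both plan kinds (`byte_to_string`, `str_to_byte`, `proto_error`, the `convert_required!` family) moves into `common`, while plan-specific conversions live under their respective `logical_plan`/`physical_plan` trees.
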
+use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::CustomTableScanNode; use crate::{ - from_proto::{self, parse_expr}, + convert_required, protobuf::{ self, listing_table_scan_node::FileFormatType, logical_plan_node::LogicalPlanType, LogicalExtensionNode, LogicalPlanNode, }, - to_proto, }; use arrow::datatypes::{DataType, Schema, SchemaRef}; -use datafusion::datasource::TableProvider; use datafusion::{ datasource::{ file_format::{ @@ -34,17 +33,18 @@ use datafusion::{ }, listing::{ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl}, view::ViewTable, + TableProvider, }, datasource::{provider_as_source, source_as_provider}, prelude::SessionContext, }; use datafusion_common::{context, Column, DataFusionError, OwnedTableReference}; -use datafusion_expr::logical_plan::{builder::project, Prepare}; use datafusion_expr::{ logical_plan::{ - Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, - CrossJoin, Distinct, EmptyRelation, Extension, Join, JoinConstraint, JoinType, - Limit, Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, + builder::project, Aggregate, CreateCatalog, CreateCatalogSchema, + CreateExternalTable, CreateView, CrossJoin, Distinct, EmptyRelation, Extension, + Join, JoinConstraint, Limit, Prepare, Projection, Repartition, Sort, + SubqueryAlias, TableScan, Values, Window, }, Expr, LogicalPlan, LogicalPlanBuilder, }; @@ -53,24 +53,19 @@ use prost::Message; use std::fmt::Debug; use std::sync::Arc; -fn byte_to_string(b: u8) -> Result { - let b = &[b]; - let b = std::str::from_utf8(b) - .map_err(|_| DataFusionError::Internal("Invalid CSV delimiter".to_owned()))?; - Ok(b.to_owned()) -} +pub mod from_proto; +pub mod to_proto; -fn str_to_byte(s: &str) -> Result { - if s.len() != 1 { - return Err(DataFusionError::Internal( - "Invalid CSV delimiter".to_owned(), - )); +impl From for DataFusionError { + fn from(e: from_proto::Error) -> Self { + DataFusionError::Plan(e.to_string()) } - Ok(s.as_bytes()[0]) } -pub(crate) fn proto_error>(message: S) -> DataFusionError { - DataFusionError::Internal(message.into()) +impl From for DataFusionError { + fn from(e: to_proto::Error) -> Self { + DataFusionError::Plan(e.to_string()) + } } pub trait AsLogicalPlan: Debug + Send + Sync + Clone { @@ -183,87 +178,6 @@ macro_rules! into_logical_plan { }}; } -#[macro_export] -macro_rules! convert_required { - ($PB:expr) => {{ - if let Some(field) = $PB.as_ref() { - Ok(field.try_into()?) - } else { - Err(proto_error("Missing required field in protobuf")) - } - }}; -} - -#[macro_export] -macro_rules! into_required { - ($PB:expr) => {{ - if let Some(field) = $PB.as_ref() { - Ok(field.into()) - } else { - Err(proto_error("Missing required field in protobuf")) - } - }}; -} - -#[macro_export] -macro_rules! 
convert_box_required { - ($PB:expr) => {{ - if let Some(field) = $PB.as_ref() { - field.as_ref().try_into() - } else { - Err(proto_error("Missing required field in protobuf")) - } - }}; -} - -impl From<protobuf::JoinType> for JoinType { - fn from(t: protobuf::JoinType) -> Self { - match t { - protobuf::JoinType::Inner => JoinType::Inner, - protobuf::JoinType::Left => JoinType::Left, - protobuf::JoinType::Right => JoinType::Right, - protobuf::JoinType::Full => JoinType::Full, - protobuf::JoinType::Leftsemi => JoinType::LeftSemi, - protobuf::JoinType::Rightsemi => JoinType::RightSemi, - protobuf::JoinType::Leftanti => JoinType::LeftAnti, - protobuf::JoinType::Rightanti => JoinType::RightAnti, - } - } -} - -impl From<JoinType> for protobuf::JoinType { - fn from(t: JoinType) -> Self { - match t { - JoinType::Inner => protobuf::JoinType::Inner, - JoinType::Left => protobuf::JoinType::Left, - JoinType::Right => protobuf::JoinType::Right, - JoinType::Full => protobuf::JoinType::Full, - JoinType::LeftSemi => protobuf::JoinType::Leftsemi, - JoinType::RightSemi => protobuf::JoinType::Rightsemi, - JoinType::LeftAnti => protobuf::JoinType::Leftanti, - JoinType::RightAnti => protobuf::JoinType::Rightanti, - } - } -} - -impl From<protobuf::JoinConstraint> for JoinConstraint { - fn from(t: protobuf::JoinConstraint) -> Self { - match t { - protobuf::JoinConstraint::On => JoinConstraint::On, - protobuf::JoinConstraint::Using => JoinConstraint::Using, - } - } -} - -impl From<JoinConstraint> for protobuf::JoinConstraint { - fn from(t: JoinConstraint) -> Self { - match t { - JoinConstraint::On => protobuf::JoinConstraint::On, - JoinConstraint::Using => protobuf::JoinConstraint::Using, - } - } -} - fn from_owned_table_reference( table_ref: Option<&protobuf::OwnedTableReference>, error_context: &str, @@ -312,27 +226,26 @@ impl AsLogicalPlan for LogicalPlanNode { match plan { LogicalPlanType::Values(values) => { let n_cols = values.n_cols as usize; - let values: Vec<Vec<Expr>> = - if values.values_list.is_empty() { - Ok(Vec::new()) - } else if values.values_list.len() % n_cols != 0 { - Err(DataFusionError::Internal(format!( - "Invalid values list length, expect {} to be divisible by {}", - values.values_list.len(), - n_cols - ))) - } else { - values - .values_list - .chunks_exact(n_cols) - .map(|r| { - r.iter() - .map(|expr| parse_expr(expr, ctx)) - .collect::<Result<Vec<_>, from_proto::Error>>() - }) - .collect::<Result<Vec<_>, _>>() - .map_err(|e| e.into()) - }?; + let values: Vec<Vec<Expr>> = if values.values_list.is_empty() { + Ok(Vec::new()) + } else if values.values_list.len() % n_cols != 0 { + Err(DataFusionError::Internal(format!( + "Invalid values list length, expect {} to be divisible by {}", + values.values_list.len(), + n_cols + ))) + } else { + values + .values_list + .chunks_exact(n_cols) + .map(|r| { + r.iter() + .map(|expr| from_proto::parse_expr(expr, ctx)) + .collect::<Result<Vec<_>, from_proto::Error>>() + }) + .collect::<Result<Vec<_>, _>>() + .map_err(|e| e.into()) + }?; LogicalPlanBuilder::values(values)?.build() } LogicalPlanType::Projection(projection) => { @@ -341,7 +254,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr: Vec<Expr> = projection .expr .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<_>, _>>()?; let new_proj = project(input, expr)?; @@ -362,7 +275,7 @@ impl AsLogicalPlan for LogicalPlanNode { let expr: Expr = selection .expr .as_ref() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .transpose()?
.ok_or_else(|| { DataFusionError::Internal("expression required".to_string()) @@ -376,7 +289,7 @@ impl AsLogicalPlan for LogicalPlanNode { let window_expr = window .window_expr .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<Expr>, _>>()?; LogicalPlanBuilder::from(input).window(window_expr)?.build() } @@ -386,12 +299,12 @@ impl AsLogicalPlan for LogicalPlanNode { let group_expr = aggregate .group_expr .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<Expr>, _>>()?; let aggr_expr = aggregate .aggr_expr .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<Expr>, _>>()?; LogicalPlanBuilder::from(input) .aggregate(group_expr, aggr_expr)? @@ -413,13 +326,13 @@ impl AsLogicalPlan for LogicalPlanNode { let filters = scan .filters .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<Expr>, _>>()?; let file_sort_order = scan .file_sort_order .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<Expr>, _>>()?; // Protobuf doesn't distinguish between "not present" @@ -512,7 +425,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filters = scan .filters .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<Expr>, _>>()?; let provider = extension_codec.try_decode_table_provider( &scan.custom_table_data, @@ -534,7 +447,7 @@ impl AsLogicalPlan for LogicalPlanNode { let sort_expr: Vec<Expr> = sort .expr .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<Expr>, _>>()?; LogicalPlanBuilder::from(input).sort(sort_expr)?.build() } @@ -556,7 +469,7 @@ impl AsLogicalPlan for LogicalPlanNode { }) => Partitioning::Hash( pb_hash_expr .iter() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .collect::<Result<Vec<Expr>, _>>()?, partition_count as usize, ), @@ -718,7 +631,7 @@ impl AsLogicalPlan for LogicalPlanNode { let filter: Option<Expr> = join .filter .as_ref() - .map(|expr| parse_expr(expr, ctx)) + .map(|expr| from_proto::parse_expr(expr, ctx)) .map_or(Ok(None), |v| v.map(Some))?; let builder = LogicalPlanBuilder::from(into_logical_plan!( @@ -1426,3 +1339,1359 @@ impl AsLogicalPlan for LogicalPlanNode { } } } + +#[cfg(test)] +mod roundtrip_tests { + use super::from_proto::parse_expr; + use super::protobuf; + use crate::bytes::{ + logical_plan_from_bytes, logical_plan_from_bytes_with_extension_codec, + logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, + }; + use crate::logical_plan::LogicalExtensionCodec; + use arrow::datatypes::{Schema, SchemaRef}; + use arrow::{ + array::ArrayRef, + datatypes::{ + DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, + TimeUnit, UnionMode, + }, + }; + use datafusion::datasource::datasource::TableProviderFactory; + use datafusion::datasource::TableProvider; + use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use datafusion::physical_plan::functions::make_scalar_function; + use datafusion::prelude::{ + create_udf, CsvReadOptions, SessionConfig, SessionContext, + }; + use datafusion::test_util::{TestTableFactory, TestTableProvider}; + use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; + use datafusion_expr::expr::{Between, BinaryExpr, Case, Cast, GroupingSet, Like}; + use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNode}; + use datafusion_expr::{ + col,
lit, Accumulator, AggregateFunction, + BuiltinScalarFunction::{Sqrt, Substr}, + Expr, LogicalPlan, Operator, Volatility, + }; + use datafusion_expr::{ + create_udaf, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + }; + use prost::Message; + use std::any::Any; + use std::collections::HashMap; + use std::fmt; + use std::fmt::Debug; + use std::fmt::Formatter; + use std::sync::Arc; + + #[cfg(feature = "json")] + fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { + let string = serde_json::to_string(proto).unwrap(); + let back: protobuf::LogicalExprNode = serde_json::from_str(&string).unwrap(); + assert_eq!(proto, &back); + } + + #[cfg(not(feature = "json"))] + fn roundtrip_json_test(_proto: &protobuf::LogicalExprNode) {} + + // Given a DataFusion logical Expr, convert it to protobuf and back, using debug formatting to test + // equality. + fn roundtrip_expr_test<T, E>(initial_struct: T, ctx: SessionContext) + where + for<'a> &'a T: TryInto<protobuf::LogicalExprNode, Error = E> + Debug, + E: Debug, + { + let proto: protobuf::LogicalExprNode = (&initial_struct).try_into().unwrap(); + let round_trip: Expr = parse_expr(&proto, &ctx).unwrap(); + + assert_eq!( + format!("{:?}", &initial_struct), + format!("{:?}", round_trip) + ); + + roundtrip_json_test(&proto); + } + + fn new_box_field(name: &str, dt: DataType, nullable: bool) -> Box<Field> { + Box::new(Field::new(name, dt, nullable)) + } + + #[tokio::test] + async fn roundtrip_logical_plan() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + .await?; + let scan = ctx.table("t1")?.to_logical_plan()?; + let topk_plan = LogicalPlan::Extension(Extension { + node: Arc::new(TopKPlanNode::new(3, scan, col("revenue"))), + }); + let extension_codec = TopKExtensionCodec {}; + let bytes = + logical_plan_to_bytes_with_extension_codec(&topk_plan, &extension_codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec)?; + assert_eq!( + format!("{:?}", topk_plan), + format!("{:?}", logical_round_trip) + ); + Ok(()) + } + + #[derive(Clone, PartialEq, Eq, ::prost::Message)] + pub struct TestTableProto { + /// URL of the table root + #[prost(string, tag = "1")] + pub url: String, + } + + #[derive(Debug)] + pub struct TestTableProviderCodec {} + + impl LogicalExtensionCodec for TestTableProviderCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[LogicalPlan], + _ctx: &SessionContext, + ) -> Result<Extension, DataFusionError> { + Err(DataFusionError::NotImplemented( + "No extension codec provided".to_string(), + )) + } + + fn try_encode( + &self, + _node: &Extension, + _buf: &mut Vec<u8>, + ) -> Result<(), DataFusionError> { + Err(DataFusionError::NotImplemented( + "No extension codec provided".to_string(), + )) + } + + fn try_decode_table_provider( + &self, + buf: &[u8], + schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result<Arc<dyn TableProvider>, DataFusionError> { + let msg = TestTableProto::decode(buf).map_err(|_| { + DataFusionError::Internal("Error decoding test table".to_string()) + })?; + let provider = TestTableProvider { + url: msg.url, + schema, + }; + Ok(Arc::new(provider)) + } + + fn try_encode_table_provider( + &self, + node: Arc<dyn TableProvider>, + buf: &mut Vec<u8>, + ) -> Result<(), DataFusionError> { + let table = node + .as_ref() + .as_any() + .downcast_ref::<TestTableProvider>() + .expect("Can't encode non-test tables"); + let msg = TestTableProto { + url: table.url.clone(), + }; + msg.encode(buf).map_err(|_| { + DataFusionError::Internal("Error encoding test table".to_string()) + }) + } + } + +
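+    // Editor's sketch, not part of the upstream patch: the Expr -> protobuf
+    // -> Expr round trip that `roundtrip_expr_test` above automates, written
+    // out once by hand. It assumes only items this module already imports
+    // (`parse_expr`, `protobuf`, `BinaryExpr`, `Operator`, `col`, `lit`).
+    #[test]
+    fn roundtrip_binary_expr_sketch() {
+        let ctx = SessionContext::new();
+        let expr = Expr::BinaryExpr(BinaryExpr::new(
+            Box::new(col("a")),
+            Operator::Plus,
+            Box::new(lit(1_i64)),
+        ));
+        // Expr -> protobuf message ...
+        let proto: protobuf::LogicalExprNode = (&expr).try_into().unwrap();
+        // ... and back, using the context to resolve any UDF references.
+        let round_trip: Expr = parse_expr(&proto, &ctx).unwrap();
+        assert_eq!(format!("{:?}", expr), format!("{:?}", round_trip));
+    }
+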
#[tokio::test] + async fn roundtrip_custom_tables() -> Result<(), DataFusionError> { + let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> = + HashMap::new(); + table_factories.insert("TESTTABLE".to_string(), Arc::new(TestTableFactory {})); + let cfg = RuntimeConfig::new().with_table_factories(table_factories); + let env = RuntimeEnv::new(cfg).unwrap(); + let ses = SessionConfig::new(); + let ctx = SessionContext::with_config_rt(ses, Arc::new(env)); + + let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';"; + ctx.sql(sql).await.unwrap(); + + let codec = TestTableProviderCodec {}; + let scan = ctx.table("t")?.to_logical_plan()?; + let bytes = logical_plan_to_bytes_with_extension_codec(&scan, &codec)?; + let logical_round_trip = + logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; + assert_eq!(format!("{:?}", scan), format!("{:?}", logical_round_trip)); + Ok(()) + } + + #[tokio::test] + async fn roundtrip_logical_plan_aggregation() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = + "SELECT a, SUM(b + 1) as b_sum FROM t1 GROUP BY a ORDER BY b_sum DESC"; + let plan = ctx.sql(query).await?.to_logical_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); + + Ok(()) + } + + #[tokio::test] + async fn roundtrip_single_count_distinct() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + + let schema = Schema::new(vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Decimal128(15, 2), true), + ]); + + ctx.register_csv( + "t1", + "testdata/test.csv", + CsvReadOptions::default().schema(&schema), + ) + .await?; + + let query = "SELECT a, COUNT(DISTINCT b) as b_cd FROM t1 GROUP BY a"; + let plan = ctx.sql(query).await?.to_logical_plan()?; + + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); + + Ok(()) + } + + #[tokio::test] + async fn roundtrip_logical_plan_with_extension() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + .await?; + let plan = ctx.table("t1")?.to_logical_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); + Ok(()) + } + + #[tokio::test] + async fn roundtrip_logical_plan_with_view_scan() -> Result<(), DataFusionError> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) + .await?; + ctx.sql("CREATE VIEW view_t1(a, b) AS SELECT a, b FROM t1") + .await?; + let plan = ctx.sql("SELECT * FROM view_t1").await?.to_logical_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); + Ok(()) + } + + pub mod proto { + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct TopKPlanProto { + #[prost(uint64, tag = "1")] + pub k: u64, + +
#[prost(message, optional, tag = "2")] + pub expr: ::core::option::Option<crate::protobuf::LogicalExprNode>, + } + + #[derive(Clone, PartialEq, Eq, ::prost::Message)] + pub struct TopKExecProto { + #[prost(uint64, tag = "1")] + pub k: u64, + } + } + + struct TopKPlanNode { + k: usize, + input: LogicalPlan, + /// The sort expression (this example only supports a single sort + /// expr) + expr: Expr, + } + + impl TopKPlanNode { + pub fn new(k: usize, input: LogicalPlan, expr: Expr) -> Self { + Self { k, input, expr } + } + } + + impl Debug for TopKPlanNode { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + self.fmt_for_explain(f) + } + } + + impl UserDefinedLogicalNode for TopKPlanNode { + fn as_any(&self) -> &dyn Any { + self + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + /// Schema for TopK is the same as the input + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec<Expr> { + vec![self.expr.clone()] + } + + /// For example: `TopK: k=10` + fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "TopK: k={}", self.k) + } + + fn from_template( + &self, + exprs: &[Expr], + inputs: &[LogicalPlan], + ) -> Arc<dyn UserDefinedLogicalNode> { + assert_eq!(inputs.len(), 1, "input size inconsistent"); + assert_eq!(exprs.len(), 1, "expression size inconsistent"); + Arc::new(TopKPlanNode { + k: self.k, + input: inputs[0].clone(), + expr: exprs[0].clone(), + }) + } + } + + #[derive(Debug)] + pub struct TopKExtensionCodec {} + + impl LogicalExtensionCodec for TopKExtensionCodec { + fn try_decode( + &self, + buf: &[u8], + inputs: &[LogicalPlan], + ctx: &SessionContext, + ) -> Result<Extension, DataFusionError> { + if let Some((input, _)) = inputs.split_first() { + let proto = proto::TopKPlanProto::decode(buf).map_err(|e| { + DataFusionError::Internal(format!( + "failed to decode logical plan: {:?}", + e + )) + })?; + + if let Some(expr) = proto.expr.as_ref() { + let node = TopKPlanNode::new( + proto.k as usize, + input.clone(), + parse_expr(expr, ctx)?, + ); + + Ok(Extension { + node: Arc::new(node), + }) + } else { + Err(DataFusionError::Internal( + "invalid plan, no expr".to_string(), + )) + } + } else { + Err(DataFusionError::Internal( + "invalid plan, no input".to_string(), + )) + } + } + + fn try_encode( + &self, + node: &Extension, + buf: &mut Vec<u8>, + ) -> Result<(), DataFusionError> { + if let Some(exec) = node.node.as_any().downcast_ref::<TopKPlanNode>() { + let proto = proto::TopKPlanProto { + k: exec.k as u64, + expr: Some((&exec.expr).try_into()?), + }; + + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!( + "failed to encode logical plan: {:?}", + e + )) + })?; + + Ok(()) + } else { + Err(DataFusionError::Internal( + "unsupported plan type".to_string(), + )) + } + } + + fn try_decode_table_provider( + &self, + _buf: &[u8], + _schema: SchemaRef, + _ctx: &SessionContext, + ) -> Result<Arc<dyn TableProvider>, DataFusionError> { + Err(DataFusionError::Internal( + "unsupported plan type".to_string(), + )) + } + + fn try_encode_table_provider( + &self, + _node: Arc<dyn TableProvider>, + _buf: &mut Vec<u8>, + ) -> Result<(), DataFusionError> { + Err(DataFusionError::Internal( + "unsupported plan type".to_string(), + )) + } + } + + #[test] + fn scalar_values_error_serialization() { + let should_fail_on_serialize: Vec<ScalarValue> = vec![ + // Should fail due to empty values + ScalarValue::Struct( + Some(vec![]), + Box::new(vec![Field::new("item", DataType::Int16, true)]), + ), + // Should fail due to inconsistent types in the list + ScalarValue::new_list( + Some(vec![ + ScalarValue::Int16(None), + ScalarValue::Float32(Some(32.0)), + ]), +
DataType::List(new_box_field("item", DataType::Int16, true)), + ), + ScalarValue::new_list( + Some(vec![ + ScalarValue::Float32(None), + ScalarValue::Float32(Some(32.0)), + ]), + DataType::List(new_box_field("item", DataType::Int16, true)), + ), + ScalarValue::new_list( + Some(vec![ + ScalarValue::Float32(None), + ScalarValue::Float32(Some(32.0)), + ]), + DataType::Int16, + ), + ScalarValue::new_list( + Some(vec![ + ScalarValue::new_list( + None, + DataType::List(new_box_field("level2", DataType::Float32, true)), + ), + ScalarValue::new_list( + Some(vec![ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ]), + DataType::List(new_box_field("level2", DataType::Float32, true)), + ), + ScalarValue::new_list( + None, + DataType::List(new_box_field( + "lists are typed inconsistently", + DataType::Int16, + true, + )), + ), + ]), + DataType::List(new_box_field( + "level1", + DataType::List(new_box_field("level2", DataType::Float32, true)), + true, + )), + ), + ]; + + for test_case in should_fail_on_serialize.into_iter() { + let proto: Result<super::protobuf::ScalarValue, super::to_proto::Error> = + (&test_case).try_into(); + + // Validation is also done on read, so if serialization passed + // also try to convert back to ScalarValue + if let Ok(proto) = proto { + let res: Result<ScalarValue, super::from_proto::Error> = (&proto).try_into(); + assert!( + res.is_err(), + "The value {:?} unexpectedly serialized without error:{:?}", + test_case, + res + ); + } + } + } + + #[test] + fn round_trip_scalar_values() { + let should_pass: Vec<ScalarValue> = vec![ + ScalarValue::Boolean(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), + ScalarValue::new_list(None, DataType::Boolean), + ScalarValue::Date32(None), + ScalarValue::Boolean(Some(true)), + ScalarValue::Boolean(Some(false)), + ScalarValue::Float32(Some(1.0)), + ScalarValue::Float32(Some(f32::MAX)), + ScalarValue::Float32(Some(f32::MIN)), + ScalarValue::Float32(Some(-2000.0)), + ScalarValue::Float64(Some(1.0)), + ScalarValue::Float64(Some(f64::MAX)), + ScalarValue::Float64(Some(f64::MIN)), + ScalarValue::Float64(Some(-2000.0)), + ScalarValue::Int8(Some(i8::MIN)), + ScalarValue::Int8(Some(i8::MAX)), + ScalarValue::Int8(Some(0)), + ScalarValue::Int8(Some(-15)), + ScalarValue::Int16(Some(i16::MIN)), + ScalarValue::Int16(Some(i16::MAX)), + ScalarValue::Int16(Some(0)), + ScalarValue::Int16(Some(-15)), + ScalarValue::Int32(Some(i32::MIN)), + ScalarValue::Int32(Some(i32::MAX)), + ScalarValue::Int32(Some(0)), + ScalarValue::Int32(Some(-15)), + ScalarValue::Int64(Some(i64::MIN)), + ScalarValue::Int64(Some(i64::MAX)), + ScalarValue::Int64(Some(0)), + ScalarValue::Int64(Some(-15)), + ScalarValue::UInt8(Some(u8::MAX)), + ScalarValue::UInt8(Some(0)), + ScalarValue::UInt16(Some(u16::MAX)), + ScalarValue::UInt16(Some(0)), + ScalarValue::UInt32(Some(u32::MAX)), + ScalarValue::UInt32(Some(0)), + ScalarValue::UInt64(Some(u64::MAX)), + ScalarValue::UInt64(Some(0)), + ScalarValue::Utf8(Some(String::from("Test string "))), + ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))), + ScalarValue::Date32(Some(0)), + ScalarValue::Date32(Some(i32::MAX)), + ScalarValue::Date32(None), + ScalarValue::Date64(Some(0)), + ScalarValue::Date64(Some(i64::MAX)), +
ScalarValue::Date64(None), + ScalarValue::Time32Second(Some(0)), + ScalarValue::Time32Second(Some(i32::MAX)), + ScalarValue::Time32Second(None), + ScalarValue::Time32Millisecond(Some(0)), + ScalarValue::Time32Millisecond(Some(i32::MAX)), + ScalarValue::Time32Millisecond(None), + ScalarValue::Time64Microsecond(Some(0)), + ScalarValue::Time64Microsecond(Some(i64::MAX)), + ScalarValue::Time64Microsecond(None), + ScalarValue::Time64Nanosecond(Some(0)), + ScalarValue::Time64Nanosecond(Some(i64::MAX)), + ScalarValue::Time64Nanosecond(None), + ScalarValue::TimestampNanosecond(Some(0), None), + ScalarValue::TimestampNanosecond(Some(i64::MAX), None), + ScalarValue::TimestampNanosecond(Some(0), Some("UTC".to_string())), + ScalarValue::TimestampNanosecond(None, None), + ScalarValue::TimestampMicrosecond(Some(0), None), + ScalarValue::TimestampMicrosecond(Some(i64::MAX), None), + ScalarValue::TimestampMicrosecond(Some(0), Some("UTC".to_string())), + ScalarValue::TimestampMicrosecond(None, None), + ScalarValue::TimestampMillisecond(Some(0), None), + ScalarValue::TimestampMillisecond(Some(i64::MAX), None), + ScalarValue::TimestampMillisecond(Some(0), Some("UTC".to_string())), + ScalarValue::TimestampMillisecond(None, None), + ScalarValue::TimestampSecond(Some(0), None), + ScalarValue::TimestampSecond(Some(i64::MAX), None), + ScalarValue::TimestampSecond(Some(0), Some("UTC".to_string())), + ScalarValue::TimestampSecond(None, None), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(0, 0))), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value(1, 2))), + ScalarValue::IntervalDayTime(Some(IntervalDayTimeType::make_value( + i32::MAX, + i32::MAX, + ))), + ScalarValue::IntervalDayTime(None), + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(0, 0, 0), + )), + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(1, 2, 3), + )), + ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNanoType::make_value(i32::MAX, i32::MAX, i64::MAX), + )), + ScalarValue::IntervalMonthDayNano(None), + ScalarValue::new_list( + Some(vec![ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ]), + DataType::Float32, + ), + ScalarValue::new_list( + Some(vec![ + ScalarValue::new_list(None, DataType::Float32), + ScalarValue::new_list( + Some(vec![ + ScalarValue::Float32(Some(-213.1)), + ScalarValue::Float32(None), + ScalarValue::Float32(Some(5.5)), + ScalarValue::Float32(Some(2.0)), + ScalarValue::Float32(Some(1.0)), + ]), + DataType::Float32, + ), + ]), + DataType::List(new_box_field("item", DataType::Float32, true)), + ), + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(Some("foo".into()))), + ), + ScalarValue::Dictionary( + Box::new(DataType::Int32), + Box::new(ScalarValue::Utf8(None)), + ), + ScalarValue::Binary(Some(b"bar".to_vec())), + ScalarValue::Binary(None), + ScalarValue::LargeBinary(Some(b"bar".to_vec())), + ScalarValue::LargeBinary(None), + ScalarValue::Struct( + Some(vec![ + ScalarValue::Int32(Some(23)), + ScalarValue::Boolean(Some(false)), + ]), + Box::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Boolean, false), + ]), + ), + ScalarValue::Struct( + None, + Box::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("a", DataType::Boolean, false), + ]), + ), + ScalarValue::FixedSizeBinary( + b"bar".to_vec().len() as i32, + 
Some(b"bar".to_vec()), + ), + ScalarValue::FixedSizeBinary(0, None), + ScalarValue::FixedSizeBinary(5, None), + ]; + + for test_case in should_pass.into_iter() { + let proto: super::protobuf::ScalarValue = (&test_case) + .try_into() + .expect("failed conversion to protobuf"); + + let roundtrip: ScalarValue = (&proto) + .try_into() + .expect("failed conversion from protobuf"); + + assert_eq!( + test_case, roundtrip, + "ScalarValue was not the same after round trip!\n\n\ + Input: {:?}\n\nRoundtrip: {:?}", + test_case, roundtrip + ); + } + } + + #[test] + fn round_trip_scalar_types() { + let should_pass: Vec = vec![ + DataType::Boolean, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float32, + DataType::Float64, + DataType::Date32, + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Utf8, + DataType::LargeUtf8, + // Recursive list tests + DataType::List(new_box_field("level1", DataType::Boolean, true)), + DataType::List(new_box_field( + "Level1", + DataType::List(new_box_field("level2", DataType::Date32, true)), + true, + )), + ]; + + for test_case in should_pass.into_iter() { + let field = Field::new("item", test_case, true); + let proto: super::protobuf::Field = (&field).try_into().unwrap(); + let roundtrip: Field = (&proto).try_into().unwrap(); + assert_eq!(format!("{:?}", field), format!("{:?}", roundtrip)); + } + } + + #[test] + fn round_trip_datatype() { + let test_cases: Vec = vec![ + DataType::Null, + DataType::Boolean, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Float16, + DataType::Float32, + DataType::Float64, + // Add more timestamp tests + DataType::Timestamp(TimeUnit::Millisecond, None), + DataType::Date32, + DataType::Date64, + DataType::Time32(TimeUnit::Second), + DataType::Time32(TimeUnit::Millisecond), + DataType::Time32(TimeUnit::Microsecond), + DataType::Time32(TimeUnit::Nanosecond), + DataType::Time64(TimeUnit::Second), + DataType::Time64(TimeUnit::Millisecond), + DataType::Time64(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Nanosecond), + DataType::Duration(TimeUnit::Second), + DataType::Duration(TimeUnit::Millisecond), + DataType::Duration(TimeUnit::Microsecond), + DataType::Duration(TimeUnit::Nanosecond), + DataType::Interval(IntervalUnit::YearMonth), + DataType::Interval(IntervalUnit::DayTime), + DataType::Binary, + DataType::FixedSizeBinary(0), + DataType::FixedSizeBinary(1234), + DataType::FixedSizeBinary(-432), + DataType::LargeBinary, + DataType::Utf8, + DataType::LargeUtf8, + DataType::Decimal128(7, 12), + // Recursive list tests + DataType::List(new_box_field("Level1", DataType::Binary, true)), + DataType::List(new_box_field( + "Level1", + DataType::List(new_box_field( + "Level2", + DataType::FixedSizeBinary(53), + false, + )), + true, + )), + // Fixed size lists + DataType::FixedSizeList(new_box_field("Level1", DataType::Binary, true), 4), + DataType::FixedSizeList( + new_box_field( + "Level1", + DataType::List(new_box_field( + "Level2", + DataType::FixedSizeBinary(53), + false, + )), + true, + ), + 41, + ), + // Struct Testing + DataType::Struct(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ]), + DataType::Struct(vec![ + Field::new("nullable", DataType::Boolean, 
false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + Field::new( + "nested_struct", + DataType::Struct(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ]), + true, + ), + ]), + DataType::Union( + vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ], + vec![7, 5, 3], + UnionMode::Sparse, + ), + DataType::Union( + vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + Field::new( + "nested_struct", + DataType::Struct(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ]), + true, + ), + ], + vec![5, 8, 1], + UnionMode::Dense, + ), + DataType::Dictionary( + Box::new(DataType::Utf8), + Box::new(DataType::Struct(vec![ + Field::new("nullable", DataType::Boolean, false), + Field::new("name", DataType::Utf8, false), + Field::new("datatype", DataType::Binary, false), + ])), + ), + DataType::Dictionary( + Box::new(DataType::Decimal128(10, 50)), + Box::new(DataType::FixedSizeList( + new_box_field("Level1", DataType::Binary, true), + 4, + )), + ), + ]; + + for test_case in test_cases.into_iter() { + let proto: super::protobuf::ArrowType = (&test_case).try_into().unwrap(); + let roundtrip: DataType = (&proto).try_into().unwrap(); + assert_eq!(format!("{:?}", test_case), format!("{:?}", roundtrip)); + } + } + + #[test] + fn roundtrip_null_scalar_values() { + let test_types = vec![ + ScalarValue::Boolean(None), + ScalarValue::Float32(None), + ScalarValue::Float64(None), + ScalarValue::Int8(None), + ScalarValue::Int16(None), + ScalarValue::Int32(None), + ScalarValue::Int64(None), + ScalarValue::UInt8(None), + ScalarValue::UInt16(None), + ScalarValue::UInt32(None), + ScalarValue::UInt64(None), + ScalarValue::Utf8(None), + ScalarValue::LargeUtf8(None), + ScalarValue::Date32(None), + ScalarValue::TimestampMicrosecond(None, None), + ScalarValue::TimestampNanosecond(None, None), + ScalarValue::List( + None, + Box::new(Field::new("item", DataType::Boolean, false)), + ), + ]; + + for test_case in test_types.into_iter() { + let proto_scalar: super::protobuf::ScalarValue = + (&test_case).try_into().unwrap(); + let returned_scalar: datafusion::scalar::ScalarValue = + (&proto_scalar).try_into().unwrap(); + assert_eq!( + format!("{:?}", &test_case), + format!("{:?}", returned_scalar) + ); + } + } + + #[test] + fn roundtrip_not() { + let test_expr = Expr::Not(Box::new(lit(1.0_f32))); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_is_null() { + let test_expr = Expr::IsNull(Box::new(col("id"))); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_is_not_null() { + let test_expr = Expr::IsNotNull(Box::new(col("id"))); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_between() { + let test_expr = Expr::Between(Between::new( + Box::new(lit(1.0_f32)), + true, + Box::new(lit(2.0_f32)), + Box::new(lit(3.0_f32)), + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_binary_op() { + fn test(op: Operator) { + let test_expr = 
Expr::BinaryExpr(BinaryExpr::new( + Box::new(lit(1.0_f32)), + op, + Box::new(lit(2.0_f32)), + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + test(Operator::StringConcat); + test(Operator::RegexNotIMatch); + test(Operator::RegexNotMatch); + test(Operator::RegexIMatch); + test(Operator::RegexMatch); + test(Operator::Like); + test(Operator::NotLike); + test(Operator::ILike); + test(Operator::NotILike); + test(Operator::BitwiseShiftRight); + test(Operator::BitwiseShiftLeft); + test(Operator::BitwiseAnd); + test(Operator::BitwiseOr); + test(Operator::BitwiseXor); + test(Operator::IsDistinctFrom); + test(Operator::IsNotDistinctFrom); + test(Operator::And); + test(Operator::Or); + test(Operator::Eq); + test(Operator::NotEq); + test(Operator::Lt); + test(Operator::LtEq); + test(Operator::Gt); + test(Operator::GtEq); + } + + #[test] + fn roundtrip_case() { + let test_expr = Expr::Case(Case::new( + Some(Box::new(lit(1.0_f32))), + vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], + Some(Box::new(lit(4.0_f32))), + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_case_with_null() { + let test_expr = Expr::Case(Case::new( + Some(Box::new(lit(1.0_f32))), + vec![(Box::new(lit(2.0_f32)), Box::new(lit(3.0_f32)))], + Some(Box::new(Expr::Literal(ScalarValue::Null))), + )); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_null_literal() { + let test_expr = Expr::Literal(ScalarValue::Null); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_cast() { + let test_expr = Expr::Cast(Cast::new(Box::new(lit(1.0_f32)), DataType::Boolean)); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_sort_expr() { + let test_expr = Expr::Sort { + expr: Box::new(lit(1.0_f32)), + asc: true, + nulls_first: true, + }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_negative() { + let test_expr = Expr::Negative(Box::new(lit(1.0_f32))); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_inlist() { + let test_expr = Expr::InList { + expr: Box::new(lit(1.0_f32)), + list: vec![lit(2.0_f32)], + negated: true, + }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_wildcard() { + let test_expr = Expr::Wildcard; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_sqrt() { + let test_expr = Expr::ScalarFunction { + fun: Sqrt, + args: vec![col("col")], + }; + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_like() { + fn like(negated: bool, escape_char: Option<char>) { + let test_expr = Expr::Like(Like::new( + negated, + Box::new(col("col")), + Box::new(lit("[0-9]+")), + escape_char, + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + like(true, Some('X')); + like(false, Some('\\')); + like(true, None); + like(false, None); + } + + #[test] + fn roundtrip_ilike() { + fn ilike(negated: bool, escape_char: Option<char>) { + let test_expr = Expr::ILike(Like::new( + negated, + Box::new(col("col")), + Box::new(lit("[0-9]+")), + escape_char, + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + ilike(true, Some('X')); + ilike(false, Some('\\')); +
ilike(true, None); + ilike(false, None); + } + + #[test] + fn roundtrip_similar_to() { + fn similar_to(negated: bool, escape_char: Option<char>) { + let test_expr = Expr::SimilarTo(Like::new( + negated, + Box::new(col("col")), + Box::new(lit("[0-9]+")), + escape_char, + )); + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + similar_to(true, Some('X')); + similar_to(false, Some('\\')); + similar_to(true, None); + similar_to(false, None); + } + + #[test] + fn roundtrip_count() { + let test_expr = Expr::AggregateFunction { + fun: AggregateFunction::Count, + args: vec![col("bananas")], + distinct: false, + filter: None, + }; + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_count_distinct() { + let test_expr = Expr::AggregateFunction { + fun: AggregateFunction::Count, + args: vec![col("bananas")], + distinct: true, + filter: None, + }; + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_approx_percentile_cont() { + let test_expr = Expr::AggregateFunction { + fun: AggregateFunction::ApproxPercentileCont, + args: vec![col("bananas"), lit(0.42_f32)], + distinct: false, + filter: None, + }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_aggregate_udf() { + #[derive(Debug)] + struct Dummy {} + + impl Accumulator for Dummy { + fn state(&self) -> datafusion::error::Result<Vec<ScalarValue>> { + Ok(vec![]) + } + + fn update_batch( + &mut self, + _values: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn merge_batch( + &mut self, + _states: &[ArrayRef], + ) -> datafusion::error::Result<()> { + Ok(()) + } + + fn evaluate(&self) -> datafusion::error::Result<ScalarValue> { + Ok(ScalarValue::Float64(None)) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + } + + let dummy_agg = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "dummy_agg", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + DataType::Float64, + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + Volatility::Immutable, + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|_| Ok(Box::new(Dummy {}))), + // This is the description of the state. `state()` must match the types here.
+ Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + let test_expr = Expr::AggregateUDF { + fun: Arc::new(dummy_agg.clone()), + args: vec![lit(1.0_f64)], + filter: Some(Box::new(lit(true))), + }; + + let ctx = SessionContext::new(); + ctx.register_udaf(dummy_agg); + + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_scalar_udf() { + let fn_impl = |args: &[ArrayRef]| Ok(Arc::new(args[0].clone()) as ArrayRef); + + let scalar_fn = make_scalar_function(fn_impl); + + let udf = create_udf( + "dummy", + vec![DataType::Utf8], + Arc::new(DataType::Utf8), + Volatility::Immutable, + scalar_fn, + ); + + let test_expr = Expr::ScalarUDF { + fun: Arc::new(udf.clone()), + args: vec![lit("")], + }; + + let ctx = SessionContext::new(); + ctx.register_udf(udf); + + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_grouping_sets() { + let test_expr = Expr::GroupingSet(GroupingSet::GroupingSets(vec![ + vec![col("a")], + vec![col("b")], + vec![col("a"), col("b")], + ])); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_rollup() { + let test_expr = Expr::GroupingSet(GroupingSet::Rollup(vec![col("a"), col("b")])); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_cube() { + let test_expr = Expr::GroupingSet(GroupingSet::Cube(vec![col("a"), col("b")])); + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx); + } + + #[test] + fn roundtrip_substr() { + // substr(string, position) + let test_expr = Expr::ScalarFunction { + fun: Substr, + args: vec![col("col"), lit(1_i64)], + }; + + // substr(string, position, count) + let test_expr_with_count = Expr::ScalarFunction { + fun: Substr, + args: vec![col("col"), lit(1_i64), lit(1_i64)], + }; + + let ctx = SessionContext::new(); + roundtrip_expr_test(test_expr, ctx.clone()); + roundtrip_expr_test(test_expr_with_count, ctx); + } + #[test] + fn roundtrip_window() { + let ctx = SessionContext::new(); + + // 1. without window_frame + let test_expr1 = Expr::WindowFunction { + fun: WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + args: vec![], + partition_by: vec![col("col1")], + order_by: vec![col("col2")], + window_frame: WindowFrame::new(true), + }; + + // 2. with default window_frame + let test_expr2 = Expr::WindowFunction { + fun: WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + args: vec![], + partition_by: vec![col("col1")], + order_by: vec![col("col2")], + window_frame: WindowFrame::new(true), + }; + + // 3. with window_frame with row numbers + let range_number_frame = WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + }; + + let test_expr3 = Expr::WindowFunction { + fun: WindowFunction::BuiltInWindowFunction( + datafusion_expr::window_function::BuiltInWindowFunction::Rank, + ), + args: vec![], + partition_by: vec![col("col1")], + order_by: vec![col("col2")], + window_frame: range_number_frame, + }; + + // 4. 
test with AggregateFunction + let row_number_frame = WindowFrame { + units: WindowFrameUnits::Rows, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt64(Some(2))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt64(Some(2))), + }; + + let test_expr4 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Max), + args: vec![col("col1")], + partition_by: vec![col("col1")], + order_by: vec![col("col2")], + window_frame: row_number_frame, + }; + + roundtrip_expr_test(test_expr1, ctx.clone()); + roundtrip_expr_test(test_expr2, ctx.clone()); + roundtrip_expr_test(test_expr3, ctx.clone()); + roundtrip_expr_test(test_expr4, ctx); + } +} diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs similarity index 98% rename from datafusion/proto/src/to_proto.rs rename to datafusion/proto/src/logical_plan/to_proto.rs index fdbcd060e731..b9350fdc1b39 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -39,8 +39,8 @@ use datafusion_expr::expr::{ }; use datafusion_expr::{ logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction, - BuiltInWindowFunction, BuiltinScalarFunction, Expr, WindowFrame, WindowFrameBound, - WindowFrameUnits, WindowFunction, + BuiltInWindowFunction, BuiltinScalarFunction, Expr, JoinConstraint, JoinType, + WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, }; #[derive(Debug)] @@ -1332,6 +1332,30 @@ impl From<OwnedTableReference> for protobuf::OwnedTableReference { } } +impl From<JoinType> for protobuf::JoinType { + fn from(t: JoinType) -> Self { + match t { + JoinType::Inner => protobuf::JoinType::Inner, + JoinType::Left => protobuf::JoinType::Left, + JoinType::Right => protobuf::JoinType::Right, + JoinType::Full => protobuf::JoinType::Full, + JoinType::LeftSemi => protobuf::JoinType::Leftsemi, + JoinType::RightSemi => protobuf::JoinType::Rightsemi, + JoinType::LeftAnti => protobuf::JoinType::Leftanti, + JoinType::RightAnti => protobuf::JoinType::Rightanti, + } + } +} + +impl From<JoinConstraint> for protobuf::JoinConstraint { + fn from(t: JoinConstraint) -> Self { + match t { + JoinConstraint::On => protobuf::JoinConstraint::On, + JoinConstraint::Using => protobuf::JoinConstraint::Using, + } + } +} + /// Creates a scalar protobuf value from an optional value (T), and /// encoding None as the appropriate datatype fn create_proto_scalar<T, F: FnOnce(&T) -> protobuf::scalar_value::Value>( diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 54012eabd5e5..f4b4b6631f71 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -48,7 +48,7 @@ use std::sync::Arc; use crate::common::proto_error; use crate::convert_required; -use crate::from_proto::from_proto_binary_op; +use crate::logical_plan; use crate::protobuf::physical_expr_node::ExprType; use crate::protobuf::JoinSide; use datafusion::physical_plan::sorts::sort::SortOptions; use parking_lot::RwLock; @@ -83,7 +83,7 @@ pub(crate) fn parse_physical_expr( "left", input_schema, )?, - from_proto_binary_op(&binary_expr.op)?, + logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, parse_required_physical_box_expr( &binary_expr.r, registry, @@ -93,7 +93,7 @@ pub(crate) fn parse_physical_expr( )), ExprType::DateTimeIntervalExpr(expr) => Arc::new(DateTimeIntervalExpr::try_new( parse_required_physical_box_expr(&expr.l, registry, "left", input_schema)?, - from_proto_binary_op(&expr.op)?, +
logical_plan::from_proto::from_proto_binary_op(&expr.op)?, parse_required_physical_box_expr(&expr.r, registry, "right", input_schema)?, input_schema, )?), diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index dedafd94bdb0..a5ea75282777 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -53,7 +53,7 @@ use prost::Message; use crate::common::proto_error; use crate::common::{csv_delimiter_to_string, str_to_byte}; -use crate::from_proto::parse_expr; +use crate::logical_plan; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_protobuf_file_scan_config, }; @@ -160,7 +160,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let predicate = scan .pruning_predicate .as_ref() - .map(|expr| parse_expr(expr, registry)) + .map(|expr| logical_plan::from_proto::parse_expr(expr, registry)) .transpose()?; Ok(Arc::new(ParquetExec::new( parse_protobuf_file_scan_config( @@ -1165,6 +1165,32 @@ pub trait PhysicalExtensionCodec: Debug + Send + Sync { ) -> Result<(), DataFusionError>; } +#[derive(Debug)] +pub struct DefaultPhysicalExtensionCodec {} + +impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc<dyn ExecutionPlan>], + _registry: &dyn FunctionRegistry, + ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> { + Err(DataFusionError::NotImplemented( + "PhysicalExtensionCodec is not provided".to_string(), + )) + } + + fn try_encode( + &self, + _node: Arc<dyn ExecutionPlan>, + _buf: &mut Vec<u8>, + ) -> Result<(), DataFusionError> { + Err(DataFusionError::NotImplemented( + "PhysicalExtensionCodec is not provided".to_string(), + )) + } +} + #[macro_export] macro_rules! into_physical_plan { ($PB:expr, $REG:expr, $RUNTIME:expr, $CODEC:expr) => {{ @@ -1184,8 +1210,7 @@ mod roundtrip_tests { use std::sync::Arc; use super::super::protobuf; - use crate::bytes::DefaultPhysicalExtensionCodec; - use crate::physical_plan::AsExecutionPlan; + use crate::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::datatypes::IntervalUnit; use datafusion::config::ConfigOptions; From fb9ce98c1415c656022ece39bf21da3c73fa20a0 Mon Sep 17 00:00:00 2001 From: yangzhong Date: Sat, 17 Dec 2022 00:02:25 +0800 Subject: [PATCH 4/5] Reorganize the physical plan code in proto --- .../proto/src/physical_plan/from_proto.rs | 39 +++++------ .../proto/src/physical_plan/to_proto.rs | 69 ++++++++----------- 2 files changed, 44 insertions(+), 64 deletions(-) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index f4b4b6631f71..5ff408ebb974 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -50,7 +50,7 @@ use crate::common::proto_error; use crate::convert_required; use crate::logical_plan; use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::JoinSide; +use datafusion::physical_plan::joins::utils::JoinSide; use datafusion::physical_plan::sorts::sort::SortOptions; use parking_lot::RwLock; @@ -450,36 +450,31 @@ impl From<&protobuf::ColumnStats> for ColumnStatistics { } } -impl TryInto<Statistics> for &protobuf::Statistics { +impl From<protobuf::JoinSide> for JoinSide { + fn from(t: protobuf::JoinSide) -> Self { + match t { + protobuf::JoinSide::LeftSide => JoinSide::Left, + protobuf::JoinSide::RightSide => JoinSide::Right, + } + } +} + +impl TryFrom<&protobuf::Statistics> for Statistics { type Error = DataFusionError; - fn try_into(self) -> Result<Statistics, Self::Error>
{ - let column_statistics = self - .column_stats - .iter() - .map(|s| s.into()) - .collect::<Vec<_>>(); + fn try_from(s: &protobuf::Statistics) -> Result<Self, Self::Error> { + let column_statistics = + s.column_stats.iter().map(|s| s.into()).collect::<Vec<_>>(); Ok(Statistics { - num_rows: Some(self.num_rows as usize), - total_byte_size: Some(self.total_byte_size as usize), + num_rows: Some(s.num_rows as usize), + total_byte_size: Some(s.total_byte_size as usize), // No column statistic (None) is encoded with empty array column_statistics: if column_statistics.is_empty() { None } else { Some(column_statistics) }, - is_exact: self.is_exact, + is_exact: s.is_exact, }) } } - -impl From<JoinSide> for datafusion::physical_plan::joins::utils::JoinSide { - fn from(t: JoinSide) -> Self { - match t { - JoinSide::LeftSide => datafusion::physical_plan::joins::utils::JoinSide::Left, - JoinSide::RightSide => { - datafusion::physical_plan::joins::utils::JoinSide::Right - } - } - } -} diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 20efbbf2bbcd..943f4f70da6c 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -44,96 +44,85 @@ use crate::protobuf::{ConfigOption, PhysicalSortExprNode}; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::expressions::DateTimeIntervalExpr; use datafusion::physical_expr::ScalarFunctionExpr; +use datafusion::physical_plan::joins::utils::JoinSide; use datafusion_common::DataFusionError; -impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> { +impl TryFrom<Arc<dyn AggregateExpr>> for protobuf::PhysicalExprNode { type Error = DataFusionError; - fn try_into(self) -> Result<protobuf::PhysicalExprNode, Self::Error> { + fn try_from(a: Arc<dyn AggregateExpr>) -> Result<Self, Self::Error> { use datafusion::physical_plan::expressions; use protobuf::AggregateFunction; let mut distinct = false; - let aggr_function = if self.as_any().downcast_ref::<expressions::Avg>().is_some() { + let aggr_function = if a.as_any().downcast_ref::<expressions::Avg>().is_some() { Ok(AggregateFunction::Avg.into()) - } else if self.as_any().downcast_ref::<expressions::Sum>().is_some() { + } else if a.as_any().downcast_ref::<expressions::Sum>().is_some() { Ok(AggregateFunction::Sum.into()) - } else if self.as_any().downcast_ref::<expressions::Count>().is_some() { + } else if a.as_any().downcast_ref::<expressions::Count>().is_some() { Ok(AggregateFunction::Count.into()) - } else if self.as_any().downcast_ref::<expressions::DistinctCount>().is_some() { + } else if a.as_any().downcast_ref::<expressions::DistinctCount>().is_some() { distinct = true; Ok(AggregateFunction::Count.into()) - } else if self.as_any().downcast_ref::<expressions::Min>().is_some() { + } else if a.as_any().downcast_ref::<expressions::Min>().is_some() { Ok(AggregateFunction::Min.into()) - } else if self.as_any().downcast_ref::<expressions::Max>().is_some() { + } else if a.as_any().downcast_ref::<expressions::Max>().is_some() { Ok(AggregateFunction::Max.into()) - } else if self + } else if a .as_any() .downcast_ref::<expressions::ApproxDistinct>() .is_some() { Ok(AggregateFunction::ApproxDistinct.into()) - } else if self - .as_any() - .downcast_ref::<expressions::ArrayAgg>() - .is_some() - { + } else if a.as_any().downcast_ref::<expressions::ArrayAgg>().is_some() { Ok(AggregateFunction::ArrayAgg.into()) - } else if self - .as_any() - .downcast_ref::<expressions::Variance>() - .is_some() - { + } else if a.as_any().downcast_ref::<expressions::Variance>().is_some() { Ok(AggregateFunction::Variance.into()) - } else if self + } else if a .as_any() .downcast_ref::<expressions::VariancePop>() .is_some() { Ok(AggregateFunction::VariancePop.into()) - } else if self + } else if a .as_any() .downcast_ref::<expressions::Covariance>() .is_some() { Ok(AggregateFunction::Covariance.into()) - } else if self + } else if a .as_any() .downcast_ref::<expressions::CovariancePop>() .is_some() { Ok(AggregateFunction::CovariancePop.into()) - } else if self - .as_any() - .downcast_ref::<expressions::Stddev>() - .is_some() - { + }
else if a.as_any().downcast_ref::<expressions::Stddev>().is_some() { Ok(AggregateFunction::Stddev.into()) + } else if a + .as_any() + .downcast_ref::<expressions::StddevPop>() + .is_some() + { Ok(AggregateFunction::StddevPop.into()) - } else if self + } else if a .as_any() .downcast_ref::<expressions::Correlation>() .is_some() { Ok(AggregateFunction::Correlation.into()) - } else if self + } else if a .as_any() .downcast_ref::<expressions::ApproxPercentileCont>() .is_some() { Ok(AggregateFunction::ApproxPercentileCont.into()) - } else if self + } else if a .as_any() .downcast_ref::<expressions::ApproxPercentileContWithWeight>() .is_some() { Ok(AggregateFunction::ApproxPercentileContWithWeight.into()) - } else if self + } else if a .as_any() .downcast_ref::<expressions::ApproxMedian>() .is_some() @@ -142,10 +131,10 @@ impl TryInto<protobuf::PhysicalExprNode> for Arc<dyn AggregateExpr> { } else { Err(DataFusionError::NotImplemented(format!( "Aggregate function not supported: {:?}", - self + a ))) }?; - let expressions: Vec<protobuf::PhysicalExprNode> = self + let expressions: Vec<protobuf::PhysicalExprNode> = a .expressions() .iter() .map(|e| e.clone().try_into()) .collect::<Result<Vec<_>, Self::Error>>()? @@ -494,15 +483,11 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { } } -impl From<datafusion::physical_plan::joins::utils::JoinSide> for protobuf::JoinSide { - fn from(t: datafusion::physical_plan::joins::utils::JoinSide) -> Self { +impl From<JoinSide> for protobuf::JoinSide { + fn from(t: JoinSide) -> Self { match t { - datafusion::physical_plan::joins::utils::JoinSide::Left => { - protobuf::JoinSide::LeftSide - } - datafusion::physical_plan::joins::utils::JoinSide::Right => { - protobuf::JoinSide::RightSide - } + JoinSide::Left => protobuf::JoinSide::LeftSide, + JoinSide::Right => protobuf::JoinSide::RightSide, } } } From 1d2271ee247db9fa4788784279fdf4025093741c Mon Sep 17 00:00:00 2001 From: yangzhong Date: Sat, 17 Dec 2022 00:50:16 +0800 Subject: [PATCH 5/5] Remove datafusion prefix in datafusion.proto --- datafusion/proto/proto/datafusion.proto | 78 ++++++++++++------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 47f457622d14..2bc67dcbafb8 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -100,8 +100,8 @@ message ListingTableScanNode { repeated string paths = 2; string file_extension = 3; ProjectionColumns projection = 4; - datafusion.Schema schema = 5; - repeated datafusion.LogicalExprNode filters = 6; + Schema schema = 5; + repeated LogicalExprNode filters = 6; repeated string table_partition_cols = 7; bool collect_stat = 8; uint32 target_partitions = 9; @@ -110,13 +110,13 @@ message ListingTableScanNode { ParquetFormat parquet = 11; AvroFormat avro = 12; } - repeated datafusion.LogicalExprNode file_sort_order = 13; + repeated LogicalExprNode file_sort_order = 13; } message ViewTableScanNode { string table_name = 1; LogicalPlanNode input = 2; - datafusion.Schema schema = 3; + Schema schema = 3; ProjectionColumns projection = 4; string definition = 5; } @@ -125,14 +125,14 @@ message ViewTableScanNode { message CustomTableScanNode { string table_name = 1; ProjectionColumns projection = 2; - datafusion.Schema schema = 3; - repeated datafusion.LogicalExprNode filters = 4; + Schema schema = 3; + repeated LogicalExprNode filters = 4; bytes custom_table_data = 5; } message ProjectionNode { LogicalPlanNode input = 1; - repeated datafusion.LogicalExprNode expr = 2; + repeated LogicalExprNode expr = 2; oneof optional_alias { string alias = 3; } } message SelectionNode { LogicalPlanNode input = 1; - datafusion.LogicalExprNode expr = 2; + LogicalExprNode expr = 2; } message SortNode { LogicalPlanNode input = 1; - repeated
datafusion.LogicalExprNode expr = 2; + repeated LogicalExprNode expr = 2; // Maximum number of highest/lowest rows to fetch; negative means no limit int64 fetch = 3; } @@ -159,7 +159,7 @@ message RepartitionNode { } message HashRepartition { - repeated datafusion.LogicalExprNode hash_expr = 1; + repeated LogicalExprNode hash_expr = 1; uint64 partition_count = 2; } @@ -173,7 +173,7 @@ message CreateExternalTableNode { string location = 2; string file_type = 3; bool has_header = 4; - datafusion.DfSchema schema = 5; + DfSchema schema = 5; repeated string table_partition_cols = 6; bool if_not_exists = 7; string delimiter = 8; @@ -191,13 +191,13 @@ message PrepareNode { message CreateCatalogSchemaNode { string schema_name = 1; bool if_not_exists = 2; - datafusion.DfSchema schema = 3; + DfSchema schema = 3; } message CreateCatalogNode { string catalog_name = 1; bool if_not_exists = 2; - datafusion.DfSchema schema = 3; + DfSchema schema = 3; } message CreateViewNode { @@ -212,7 +212,7 @@ message CreateViewNode { // the list is flattened, and with the field n_cols it can be parsed and partitioned into rows message ValuesNode { uint64 n_cols = 1; - repeated datafusion.LogicalExprNode values_list = 2; + repeated LogicalExprNode values_list = 2; } message AnalyzeNode { @@ -227,13 +227,13 @@ message ExplainNode { message AggregateNode { LogicalPlanNode input = 1; - repeated datafusion.LogicalExprNode group_expr = 2; - repeated datafusion.LogicalExprNode aggr_expr = 3; + repeated LogicalExprNode group_expr = 2; + repeated LogicalExprNode aggr_expr = 3; } message WindowNode { LogicalPlanNode input = 1; - repeated datafusion.LogicalExprNode window_expr = 2; + repeated LogicalExprNode window_expr = 2; } enum JoinType { @@ -257,8 +257,8 @@ message JoinNode { LogicalPlanNode right = 2; JoinType join_type = 3; JoinConstraint join_constraint = 4; - repeated datafusion.Column left_join_column = 5; - repeated datafusion.Column right_join_column = 6; + repeated Column left_join_column = 5; + repeated Column right_join_column = 6; bool null_equals_null = 7; LogicalExprNode filter = 8; } @@ -285,7 +285,7 @@ message LimitNode { } message SelectionExecNode { - datafusion.LogicalExprNode expr = 1; + LogicalExprNode expr = 1; } message SubqueryAliasNode { @@ -679,7 +679,7 @@ message WindowFrameBound { /////////////////////////////////////////////////////////////////////////////////////////////////// message Schema { - repeated datafusion.Field columns = 1; + repeated Field columns = 1; } message Field { @@ -993,7 +993,7 @@ message PhysicalExprNode { // column references PhysicalColumn column = 1; - datafusion.ScalarValue literal = 2; + ScalarValue literal = 2; // binary expressions PhysicalBinaryExprNode binary_expr = 3; @@ -1026,19 +1026,19 @@ message PhysicalExprNode { message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; - datafusion.ArrowType return_type = 4; + ArrowType return_type = 4; } message PhysicalAggregateExprNode { - datafusion.AggregateFunction aggr_function = 1; + AggregateFunction aggr_function = 1; repeated PhysicalExprNode expr = 2; bool distinct = 3; } message PhysicalWindowExprNode { oneof window_function { - datafusion.AggregateFunction aggr_function = 1; - datafusion.BuiltInWindowFunction built_in_function = 2; + AggregateFunction aggr_function = 1; + BuiltInWindowFunction built_in_function = 2; // udaf = 3 } PhysicalExprNode expr = 4; @@ -1098,19 +1098,19 @@ message PhysicalCaseNode { message PhysicalScalarFunctionNode { string name = 1; - 
datafusion.ScalarFunction fun = 2; + ScalarFunction fun = 2; repeated PhysicalExprNode args = 3; - datafusion.ArrowType return_type = 4; + ArrowType return_type = 4; } message PhysicalTryCastNode { PhysicalExprNode expr = 1; - datafusion.ArrowType arrow_type = 2; + ArrowType arrow_type = 2; } message PhysicalCastNode { PhysicalExprNode expr = 1; - datafusion.ArrowType arrow_type = 2; + ArrowType arrow_type = 2; } message PhysicalNegativeNode { @@ -1133,7 +1133,7 @@ message ScanLimit { message FileScanExecConf { repeated FileGroup file_groups = 1; - datafusion.Schema schema = 2; + Schema schema = 2; repeated uint32 projection = 4; ScanLimit limit = 5; Statistics statistics = 6; @@ -1150,7 +1150,7 @@ message ConfigOption { message ParquetScanExecNode { FileScanExecConf base_conf = 1; - datafusion.LogicalExprNode pruning_predicate = 2; + LogicalExprNode pruning_predicate = 2; } message CsvScanExecNode { @@ -1173,7 +1173,7 @@ message HashJoinExecNode { PhysicalPlanNode left = 1; PhysicalPlanNode right = 2; repeated JoinOn on = 3; - datafusion.JoinType join_type = 4; + JoinType join_type = 4; PartitionMode partition_mode = 6; bool null_equals_null = 7; JoinFilter filter = 8; @@ -1184,8 +1184,8 @@ message UnionExecNode { } message ExplainExecNode { - datafusion.Schema schema = 1; - repeated datafusion.StringifiedPlan stringified_plans = 2; + Schema schema = 1; + repeated StringifiedPlan stringified_plans = 2; bool verbose = 3; } @@ -1206,7 +1206,7 @@ message JoinOn { message EmptyExecNode { bool produce_one_row = 1; - datafusion.Schema schema = 2; + Schema schema = 2; } message ProjectionExecNode { @@ -1225,7 +1225,7 @@ message WindowAggExecNode { PhysicalPlanNode input = 1; repeated PhysicalExprNode window_expr = 2; repeated string window_expr_name = 3; - datafusion.Schema input_schema = 4; + Schema input_schema = 4; } message AggregateExecNode { @@ -1236,7 +1236,7 @@ message AggregateExecNode { repeated string group_expr_name = 5; repeated string aggr_expr_name = 6; // we need the input schema to the partial aggregate to pass to the final aggregate - datafusion.Schema input_schema = 7; + Schema input_schema = 7; repeated PhysicalExprNode null_expr = 8; repeated bool groups = 9; } @@ -1292,7 +1292,7 @@ message RepartitionExecNode{ message JoinFilter{ PhysicalExprNode expression = 1; repeated ColumnIndex column_indices = 2; - datafusion.Schema schema = 3; + Schema schema = 3; } message ColumnIndex{
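
A minimal usage sketch for the `DefaultPhysicalExtensionCodec` introduced earlier in this series: it satisfies the `PhysicalExtensionCodec` parameter for plans that contain no extension nodes, returning `NotImplemented` only if extension decoding or encoding is actually attempted. This is an editor's illustration, not part of the patches; the `EmptyExec` plan and the function name are chosen only for the example, and the crate paths assume the module layout established by this series.

use std::sync::Arc;

use arrow::datatypes::Schema;
use datafusion::physical_plan::{empty::EmptyExec, ExecutionPlan};
use datafusion::prelude::SessionContext;
use datafusion_common::DataFusionError;
use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
use datafusion_proto::protobuf::PhysicalPlanNode;

fn roundtrip_physical_plan() -> Result<(), DataFusionError> {
    let ctx = SessionContext::new();
    // Any plan without extension nodes will do; EmptyExec keeps the example small.
    let plan: Arc<dyn ExecutionPlan> =
        Arc::new(EmptyExec::new(false, Arc::new(Schema::empty())));
    let codec = DefaultPhysicalExtensionCodec {};
    // ExecutionPlan -> protobuf message ...
    let proto = PhysicalPlanNode::try_from_physical_plan(plan.clone(), &codec)?;
    // ... and back, resolving functions through the SessionContext.
    let runtime = ctx.runtime_env();
    let round_trip = proto.try_into_physical_plan(&ctx, runtime.as_ref(), &codec)?;
    assert_eq!(format!("{:?}", plan), format!("{:?}", round_trip));
    Ok(())
}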