diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs
index c94c2be37b..430c9b1464 100644
--- a/native/core/src/execution/operators/scan.rs
+++ b/native/core/src/execution/operators/scan.rs
@@ -80,6 +80,8 @@ pub struct ScanExec {
     jvm_fetch_time: Time,
     /// Time spent in FFI
     arrow_ffi_time: Time,
+    /// Does the source of this scan re-use buffers?
+    pub has_buffer_reuse: bool,
 }
 
 impl ScanExec {
@@ -88,6 +90,7 @@ impl ScanExec {
         input_source: Option<Arc<GlobalRef>>,
         input_source_description: &str,
         data_types: Vec<DataType>,
+        has_buffer_reuse: bool,
     ) -> Result<Self, CometError> {
         let metrics_set = ExecutionPlanMetricsSet::default();
         let baseline_metrics = BaselineMetrics::new(&metrics_set, 0);
@@ -138,6 +141,7 @@ impl ScanExec {
             jvm_fetch_time,
             arrow_ffi_time,
             schema,
+            has_buffer_reuse,
         })
     }
 
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index fd4ea69b45..f71a6e7204 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -1216,8 +1216,13 @@ impl PhysicalPlanner {
                 };
 
                 // The `ScanExec` operator will take actual arrays from Spark during execution
-                let scan =
-                    ScanExec::new(self.exec_context_id, input_source, &scan.source, data_types)?;
+                let scan = ScanExec::new(
+                    self.exec_context_id,
+                    input_source,
+                    &scan.source,
+                    data_types,
+                    scan.has_buffer_reuse,
+                )?;
                 Ok((
                     vec![scan.clone()],
                     Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])),
@@ -2341,8 +2346,13 @@ impl From<ExpressionError> for DataFusionError {
 
 fn can_reuse_input_batch(op: &Arc<dyn ExecutionPlan>) -> bool {
     if op.as_any().is::<ProjectionExec>() || op.as_any().is::<LocalLimitExec>() {
         can_reuse_input_batch(op.children()[0])
+    } else if op.as_any().is::<ScanExec>() {
+        op.as_any()
+            .downcast_ref::<ScanExec>()
+            .unwrap()
+            .has_buffer_reuse
     } else {
-        op.as_any().is::<ScanExec>()
+        false
     }
 }
@@ -2594,6 +2604,7 @@ mod tests {
                     type_info: None,
                 }],
                 source: "".to_string(),
+                has_buffer_reuse: false,
             })),
         };
 
@@ -2667,6 +2678,7 @@ mod tests {
                     type_info: None,
                 }],
                 source: "".to_string(),
+                has_buffer_reuse: false,
             })),
         };
 
@@ -2877,6 +2889,7 @@ mod tests {
             op_struct: Some(OpStruct::Scan(spark_operator::Scan {
                 fields: vec![create_proto_datatype()],
                 source: "".to_string(),
+                has_buffer_reuse: false,
             })),
         }
     }
@@ -2919,6 +2932,7 @@ mod tests {
                 },
             ],
             source: "".to_string(),
+            has_buffer_reuse: false,
         })),
     };
 
@@ -3039,6 +3053,7 @@ mod tests {
                 },
             ],
             source: "".to_string(),
+            has_buffer_reuse: false,
         })),
     };
 
diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto
index 9a41f977a7..02830629c2 100644
--- a/native/proto/src/proto/operator.proto
+++ b/native/proto/src/proto/operator.proto
@@ -76,6 +76,8 @@ message Scan {
   // is purely for informational purposes when viewing native query plans in
   // debug mode.
   string source = 2;
+  // Specifies whether the source of the scan reuses buffers
+  bool has_buffer_reuse = 3;
 }
 
 message NativeScan {
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 7127315b6b..42ec85e9be 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2830,6 +2830,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
           scanBuilder.setSource(source)
         }
 
+        op match {
+          case scan: CometScanExec =>
+            scanBuilder.setHasBufferReuse(scan.scanImpl == CometConf.SCAN_NATIVE_COMET)
+          case _ =>
+        }
+
         val scanTypes = op.output.flatten { attr =>
           serializeDataType(attr.dataType)
         }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
index 66d2fac89c..1fac876e4d 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala
@@ -145,6 +145,9 @@ class CometNativeShuffleWriter[K, V](
       serializeDataType(attr.dataType)
     }
 
+    // TODO only set this if input is native_comet scan
+    scanBuilder.setHasBufferReuse(true)
+
     if (scanTypes.length == outputAttributes.length) {
       scanBuilder.addAllFields(scanTypes.asJava)