diff --git a/test/sql/measures.test b/test/sql/measures.test index 8e503cb..547e24b 100644 --- a/test/sql/measures.test +++ b/test/sql/measures.test @@ -761,6 +761,34 @@ FROM daily_orders_v 2 320.0 840.0 3 270.0 840.0 +# ============================================================================= +# Test: GROUP BY alias for expression dimension +# ============================================================================= + +statement ok +CREATE TABLE monthly_sales (order_date DATE, region TEXT, amount DOUBLE); + +statement ok +INSERT INTO monthly_sales VALUES + ('2023-01-05', 'US', 100), ('2023-01-12', 'EU', 50), + ('2023-02-03', 'US', 200), ('2023-02-20', 'EU', 20); + +statement ok +CREATE VIEW monthly_sales_v AS +SELECT DATE_TRUNC('month', order_date) AS month, region, SUM(amount) AS MEASURE revenue +FROM monthly_sales +GROUP BY DATE_TRUNC('month', order_date), region; + +query IIRR rowsort +SEMANTIC SELECT month, region, AGGREGATE(revenue), AGGREGATE(revenue) AT (ALL region) AS month_total +FROM monthly_sales_v +; +---- +2023-01-01 EU 50.0 150.0 +2023-01-01 US 100.0 150.0 +2023-02-01 EU 20.0 220.0 +2023-02-01 US 200.0 220.0 + # ============================================================================= # Test: Multi-fact JOINs (wide tables) # ============================================================================= @@ -818,6 +846,40 @@ FROM fact_orders_v o JOIN fact_returns_v r ON o.year = r.year AND o.region = r.r 2023 EU 75.0 225.0 2023 US 150.0 225.0 +# ============================================================================= +# Test: JOIN with extra dimension from second table +# ============================================================================= + +statement ok +CREATE TABLE salesdetails (year INT, region TEXT, product TEXT, amount DOUBLE); + +statement ok +INSERT INTO salesdetails VALUES + (2022, 'US', 'Shoes', 2), (2022, 'US', 'Cars', 1), + (2022, 'EU', 'Shoes', 3), + (2023, 'US', 'Shoes', 4), (2023, 'US', 'Cars', 2), + (2023, 'EU', 'Cars', 5); + +statement ok +CREATE VIEW salesdetails_v AS +SELECT year, region, product, SUM(amount) AS MEASURE quantity +FROM salesdetails; + +query IIIRRR rowsort +SEMANTIC SELECT s.year, s.region, sd.product, + AGGREGATE(revenue) AS year_sales_revenue, + AGGREGATE(revenue) AT (ALL year) AS region_total, + AGGREGATE(quantity) AS product_qty +FROM sales_v s JOIN salesdetails_v sd ON s.year = sd.year AND s.region = sd.region +; +---- +2022 EU Shoes 50.0 125.0 3.0 +2022 US Cars 100.0 250.0 1.0 +2022 US Shoes 100.0 250.0 2.0 +2023 EU Cars 75.0 125.0 5.0 +2023 US Cars 150.0 250.0 2.0 +2023 US Shoes 150.0 250.0 4.0 + # ============================================================================= # Test: SET reaches beyond WHERE clause (paper semantics) # Per paper: SET should evaluate over data removed by outer WHERE clause diff --git a/yardstick-rs/src/sql/measures.rs b/yardstick-rs/src/sql/measures.rs index 2e01a52..485811c 100644 --- a/yardstick-rs/src/sql/measures.rs +++ b/yardstick-rs/src/sql/measures.rs @@ -7,7 +7,7 @@ //! //! Reference: https://arxiv.org/abs/2406.00251 -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::Mutex; use nom::{ @@ -1051,6 +1051,52 @@ fn group_by_matches_view(outer_cols: &[String], view_cols: &[String]) -> bool { !outer_set.is_empty() && outer_set == view_set } +fn filter_group_by_cols_for_measure( + outer_cols: &[String], + view_cols: &[String], + dimension_exprs: &HashMap, +) -> Vec { + if view_cols.is_empty() { + return outer_cols.to_vec(); + } + + let view_set: HashSet = view_cols + .iter() + .map(|col| normalize_group_by_col(col)) + .collect(); + + let mut alias_expr_norms: Vec<(String, String)> = Vec::new(); + alias_expr_norms.reserve(dimension_exprs.len()); + for (alias, expr) in dimension_exprs.iter() { + alias_expr_norms.push(( + normalize_group_by_col(alias), + normalize_group_by_col(expr), + )); + } + + outer_cols + .iter() + .filter(|col| { + let normalized_outer = normalize_group_by_col(col); + if view_set.contains(&normalized_outer) { + return true; + } + + if let Some(expr) = dimension_exprs.get(&normalized_outer) { + let normalized_expr = normalize_group_by_col(expr); + if view_set.contains(&normalized_expr) { + return true; + } + } + + alias_expr_norms.iter().any(|(alias_norm, expr_norm)| { + expr_norm == &normalized_outer && view_set.contains(alias_norm) + }) + }) + .cloned() + .collect() +} + fn can_use_view_measure_directly(resolved: &ResolvedMeasure, outer_group_by: &[String]) -> bool { group_by_matches_view(outer_group_by, &resolved.view_group_by_cols) } @@ -3416,6 +3462,11 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { for (measure_name, modifiers, start, end) in patterns { // Look up which view contains this measure (for JOIN support) let resolved = resolve_measure_source(&measure_name, &primary_table_name); + let measure_group_by_cols = filter_group_by_cols_for_measure( + &group_by_cols, + &resolved.view_group_by_cols, + &resolved.dimension_exprs, + ); // Non-decomposable measures are recomputed from base rows (including AT modifiers) @@ -3452,7 +3503,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { &resolved.source_view, outer_alias_ref, outer_where_ref, - &group_by_cols, + &measure_group_by_cols, ) } else if !resolved.is_decomposable { let outer_ref_for_non_decomp = @@ -3480,7 +3531,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { &base_relation_sql, outer_ref_for_non_decomp, outer_where_ref, - &group_by_cols, + &measure_group_by_cols, &modifiers, &resolved.dimension_exprs, &format!("_nd_{join_counter}"), @@ -3496,7 +3547,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { &base_relation_sql, outer_ref_for_non_decomp, outer_where_ref, - &group_by_cols, + &measure_group_by_cols, &modifiers, &resolved.dimension_exprs, ), @@ -3517,7 +3568,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { &resolved.source_view, outer_alias_ref, outer_where_ref, - &group_by_cols, + &measure_group_by_cols, ) }; result_sql = format!("{}{}{}", &result_sql[..start], expanded, &result_sql[end..]); @@ -3529,6 +3580,11 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { for (measure_name, start, end) in plain_calls { let resolved = resolve_measure_source(&measure_name, &primary_table_name); + let measure_group_by_cols = filter_group_by_cols_for_measure( + &group_by_cols, + &resolved.view_group_by_cols, + &resolved.dimension_exprs, + ); // For derived measures, use the expanded expression; otherwise use AGG(measure_name) @@ -3558,7 +3614,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { &base_relation_sql, outer_ref_for_non_decomp, outer_where_ref, - &group_by_cols, + &measure_group_by_cols, &[], // No modifiers for plain AGGREGATE() &resolved.dimension_exprs, &format!("_nd_{join_counter}"), @@ -3574,7 +3630,7 @@ pub fn expand_aggregate_with_at(sql: &str) -> AggregateExpandResult { &base_relation_sql, outer_ref_for_non_decomp, outer_where_ref, - &group_by_cols, + &measure_group_by_cols, &[], // No modifiers for plain AGGREGATE() &resolved.dimension_exprs, ),