diff --git a/README.md b/README.md index 52dedf9..51f4799 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ Initial development is done in **Go** (`opti-sql-go`), which serves as the prima - `/operators` - SQL operator implementations (filter, join, aggregation, project) - `/physical-optimizer` - Query plan parsing and optimization - `/substrait` - Substrait plan integration +- `/operators/OPERATORS.md` - concise reference for operator constructors, behavior and examples ## Branching Model diff --git a/src/Backend/opti-sql-go/Expr/expr.go b/src/Backend/opti-sql-go/Expr/expr.go index 665990d..4899a15 100644 --- a/src/Backend/opti-sql-go/Expr/expr.go +++ b/src/Backend/opti-sql-go/Expr/expr.go @@ -260,10 +260,8 @@ func NewLiteralResolve(Type arrow.DataType, Value any) *LiteralResolve { castVal = float64(v) } default: - fmt.Printf("%v did not match any case, of type %T\n", v, v) castVal = Value } - fmt.Printf("sotred as -> %v\t%v\n", Type, castVal) return &LiteralResolve{Type: Type, Value: castVal} } func EvalLiteral(l *LiteralResolve, batch *operators.RecordBatch) (arrow.Array, error) { @@ -448,37 +446,36 @@ func EvalBinary(b *BinaryExpr, batch *operators.RecordBatch) (arrow.Array, error if err != nil { return nil, err } + ctx := context.Background() opt := compute.ArithmeticOptions{} switch b.Op { // arithmetic case Addition: - datum, err := compute.Add(context.TODO(), opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + datum, err := compute.Add(ctx, opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) if err != nil { return nil, err } return unpackDatum(datum) case Subtraction: - datum, err := compute.Subtract(context.TODO(), opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + datum, err := compute.Subtract(ctx, opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) if err != nil { return nil, err } return unpackDatum(datum) case Multiplication: - datum, err := compute.Multiply(context.TODO(), opt, compute.NewDatum(leftArr), 
compute.NewDatum(rightArr)) + datum, err := compute.Multiply(ctx, opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) if err != nil { return nil, err } return unpackDatum(datum) case Division: - datum, err := compute.Divide(context.TODO(), opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) + datum, err := compute.Divide(ctx, opt, compute.NewDatum(leftArr), compute.NewDatum(rightArr)) if err != nil { return nil, err } return unpackDatum(datum) - // comparisions TODO: - // These return a boolean array case Equal: if leftArr.DataType() != rightArr.DataType() { return nil, ErrCantCompareDifferentTypes(leftArr.DataType(), rightArr.DataType()) @@ -593,6 +590,7 @@ func NewScalarFunction(function supportedFunctions, Argument Expression) *Scalar } func EvalScalarFunction(s *ScalarFunction, batch *operators.RecordBatch) (arrow.Array, error) { + ctx := context.Background() switch s.Function { case Upper: arr, err := EvalExpression(s.Arguments, batch) @@ -612,7 +610,7 @@ func EvalScalarFunction(s *ScalarFunction, batch *operators.RecordBatch) (arrow. if err != nil { return nil, err } - datum, err := compute.AbsoluteValue(context.TODO(), compute.ArithmeticOptions{}, compute.NewDatum(arr)) + datum, err := compute.AbsoluteValue(ctx, compute.ArithmeticOptions{}, compute.NewDatum(arr)) if err != nil { return nil, err } @@ -622,7 +620,7 @@ func EvalScalarFunction(s *ScalarFunction, batch *operators.RecordBatch) (arrow. 
if err != nil { return nil, err } - datum, err := compute.Round(context.TODO(), compute.DefaultRoundOptions, compute.NewDatum(arr)) + datum, err := compute.Round(ctx, compute.DefaultRoundOptions, compute.NewDatum(arr)) if err != nil { return nil, err } @@ -657,7 +655,7 @@ func EvalCast(c *CastExpr, batch *operators.RecordBatch) (arrow.Array, error) { // Use Arrow compute kernel to cast castOpts := compute.SafeCastOptions(c.TargetType) - out, err := compute.CastArray(context.TODO(), arr, castOpts) + out, err := compute.CastArray(context.Background(), arr, castOpts) if err != nil { return nil, fmt.Errorf("cast error: cannot cast %s to %s: %w", arr.DataType(), c.TargetType, err) diff --git a/src/Backend/opti-sql-go/Expr/expr_test.go b/src/Backend/opti-sql-go/Expr/expr_test.go index 7f839bc..f0d2f43 100644 --- a/src/Backend/opti-sql-go/Expr/expr_test.go +++ b/src/Backend/opti-sql-go/Expr/expr_test.go @@ -1,7 +1,6 @@ package Expr import ( - "fmt" "log" "opti-sql-go/operators" "testing" @@ -1550,7 +1549,7 @@ func TestLikeOperatorSQL(t *testing.T) { t.Run("name starts with a", func(t *testing.T) { rc := generateTestColumns() sqlStatment := "A%" - whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, sqlStatment)) boolMask, err := EvalExpression(whereStatment, rc) if err != nil { t.Fatalf("unexpected error from EvalExpression") @@ -1572,7 +1571,7 @@ func TestLikeOperatorSQL(t *testing.T) { t.Run("name contains li", func(t *testing.T) { rc := generateTestColumns() sqlStatment := "%li%" - whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, sqlStatment)) boolMask, err := EvalExpression(whereStatment, rc) if err 
!= nil { @@ -1624,7 +1623,7 @@ func TestLikeOperatorSQL(t *testing.T) { t.Run("name is exactly 5 letters", func(t *testing.T) { rc := generateTestColumns() sqlStatment := "_____" - whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, sqlStatment)) boolMask, err := EvalExpression(whereStatment, rc) if err != nil { @@ -1650,7 +1649,7 @@ func TestLikeOperatorSQL(t *testing.T) { t.Run("name starts with Ch", func(t *testing.T) { rc := generateTestColumns() sqlStatment := "Ch%" - whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, string(sqlStatment))) + whereStatment := NewBinaryExpr(NewColumnResolve("name"), Like, NewLiteralResolve(arrow.BinaryTypes.String, sqlStatment)) boolMask, err := EvalExpression(whereStatment, rc) if err != nil { @@ -1727,7 +1726,6 @@ func TestNullCheckExpr(t *testing.T) { defer maskArr.Release() boolMask := maskArr.(*array.Boolean) - fmt.Printf("boolean mask:\t%v\n", boolMask) if boolMask.Len() != 5 { t.Fatalf("expected length 5 mask, got %d", boolMask.Len()) } diff --git a/src/Backend/opti-sql-go/config/config.go b/src/Backend/opti-sql-go/config/config.go index 627136b..17154fe 100644 --- a/src/Backend/opti-sql-go/config/config.go +++ b/src/Backend/opti-sql-go/config/config.go @@ -92,7 +92,6 @@ var configInstance *Config = &Config{ EnableQueryStats: true, EnableMemoryStats: true, }, - // TODO: remove hardcoded secretes before production. 
we are just testing for now Secretes: secretesConfig{ AccessKey: "DO8013ZT6VDHJ2EM94RN", SecretKey: "kPvQSMt6naiwe/FhDnzXpYmVE5yzJUsIR0/OJpsUNzo", diff --git a/src/Backend/opti-sql-go/main.go b/src/Backend/opti-sql-go/main.go index f277de6..82e1eb8 100644 --- a/src/Backend/opti-sql-go/main.go +++ b/src/Backend/opti-sql-go/main.go @@ -6,8 +6,6 @@ import ( "os" ) -// TODO: in the project operators make sure the record batches account for the RowCount field properly. - func main() { if len(os.Args) > 1 { if err := config.Decode(os.Args[1]); err != nil { diff --git a/src/Backend/opti-sql-go/operators/Join/hashJoin.go b/src/Backend/opti-sql-go/operators/Join/hashJoin.go index e630a70..13a6969 100644 --- a/src/Backend/opti-sql-go/operators/Join/hashJoin.go +++ b/src/Backend/opti-sql-go/operators/Join/hashJoin.go @@ -17,10 +17,7 @@ import ( "github.com/apache/arrow/go/v17/arrow/memory" ) -// TODO: clean up PR and push again -// TODO: write intergration test for operators to work together // TODO: see ticket #27 -// TODO: take small break from this project to work on inverted index search for a couple days var ( ErrInvalidJoinClauseCount = func(l, r int) error { @@ -395,7 +392,7 @@ func (hj *HashJoinExec) buildOutputArrays( leftIdxArr arrow.Array, rightIdxArr arrow.Array, ) ([]arrow.Array, error) { - ctx := context.TODO() + ctx := context.Background() output := make([]arrow.Array, hj.schema.NumFields()) for i := range len(leftCols) { diff --git a/src/Backend/opti-sql-go/operators/OPERATORS.md b/src/Backend/opti-sql-go/operators/OPERATORS.md new file mode 100644 index 0000000..1302f6b --- /dev/null +++ b/src/Backend/opti-sql-go/operators/OPERATORS.md @@ -0,0 +1,158 @@ +# Operators — quick reference + +This document gives a concise overview of the operator model used in this repository, how to construct the most common operators, and what each operator's constructor expects and why. 
Placeholders like `Expr.Expression` and `RecordBatch` refer to the repository types found under `Expr` and `operators/record.go`. + +## What is an Operator? + +An operator implements the `operators.Operator` interface: + +- `Next(n uint16) (*operators.RecordBatch, error)` — return up to `n` rows (many operators ignore the exact n and read/produce what they need). Returns `io.EOF` when finished. +- `Schema() *arrow.Schema` — the operator's output schema. +- `Close() error` — release resources (files, network handles, etc.). + +The basic data unit is `operators.RecordBatch` (schema + Arrow arrays + rowcount). Operators compose: the output of one operator becomes the input (child) of the next. + +## Leaf (source) operators + +Leaf operators are the pipeline entry points. They read data from some storage and produce `RecordBatch` values. + +- CSV source + - Constructor: `project.NewProjectCSVLeaf(io.Reader)` + - Inputs: an `io.Reader` (file, buffer). Produces typed Arrow arrays from CSV columns. + - Notes: simple, fast for local CSVs. Use when you want a streaming CSV source. + +- Parquet source + - Constructor: (parquet reader; see project package) + - Inputs: parquet file handle. Produces Arrow arrays preserving parquet types. + +- In-memory source + - Constructor: `project.NewInMemoryProjectExec(names []string, columns []any)` + - Inputs: column names and Go slices (used heavily in unit tests). + - Notes: useful for deterministic test inputs and small-memory datasets. + +- S3 / NetworkResource + - use `project.NewStreamReader` to create a network file reader. this just means it allows chunk reading of files not on local disk. + - Notes: the repository supports reading remote files; a configuration option lets you download the full remote file first to avoid per-request network latency when the operator needs repeated random access (e.g., for Parquet or when sorting). This is exposed as a NetworkResource / download option in the project/source constructors. 
+ - the result of `project.NewStreamReader(fileName)` can be passed directly to `project.NewProjectCSVLeaf(io.Reader)` and `project.NewParquetSource(readSeeker)`. This was intentional so it's as seamless as possible to work with S3 files + +## How to construct operators — summary of common operators + +The pattern is consistent: each operator has a `NewXxx...` constructor that takes one or more child operators, expression descriptors, or configuration params. + +### Project (Select) +- Constructor: `project.NewProjectExec(child operators.Operator, exprs []Expr.Expression)` +- Purpose: evaluate a list of projection expressions (column refs, scalar functions, aliases) and return a batch with only the requested columns. +- What to pass in: + - `child` — the input operator to project from (leaf or intermediate op). + - `exprs` — expressions created with `Expr.NewColumnResolve`, `Expr.NewLiteralResolve`, `Expr.NewAlias`, `Expr.NewScalarFunction`, etc. +- Why: keeps expression evaluation centralized and lets downstream operators work with a narrow schema. + +### Filter +- Constructor: `filter.NewFilterExec(child operators.Operator, predicate Expr.Expression)` +- Purpose: apply boolean predicates to input rows and emit only matching rows. +- What to pass in: + - `child` — operator producing input rows. + - `predicate` — an `Expr.Expression` that evaluates to boolean (can combine binary operators, scalar functions, null checks). +- Why: decouples predicate evaluation from projection and other operators; filter may buffer results across batches to serve limit-like requests. + +### Limit +- Constructor: `filter.NewLimitExec(child operators.Operator, limit uint64)` +- Purpose: stop the pipeline after `limit` rows are emitted. +- What to pass in: the `child` operator and the numeric `limit`. +- Why: simple consumer-side cap; implemented as a thin operator above any child. 
+ +### Distinct +- Constructor: `filter.NewDistinctExec(child operators.Operator, colExprs []Expr.Expression)` +- Purpose: remove duplicate rows on the selected key columns. +- What to pass in: `child` and the list of key column expressions. +- Why: used to produce unique values for a given set of columns; often followed by `Sort` for deterministic order. + +### Sort / TopK +- Constructors: + - `aggr.NewSortExec(child operators.Operator, sortKeys []aggr.SortKey)` — fully sorts input + - `aggr.NewTopKSortExec(child operators.Operator, sortKeys []aggr.SortKey, k uint16)` — keep top-k +- Purpose: order rows by one or more columns. +- What to pass in: + - `child` — input operator + - `sortKeys` — built with `aggr.NewSortKey(expr Expr.Expression, asc bool)`; multiple keys are combined with `aggr.CombineSortKeys(...)`. +- Why: some consumers require sorted input (ORDER BY) or only the top-k entries (TopK). +- Notes: current implementations read data into memory and sort; care must be taken for large datasets. + +### GroupBy / Aggregation +- Constructors: + - `aggr.NewGroupByExec(child operators.Operator, groupExpr []aggr.AggregateFunctions, groupBy []Expr.Expression)` — group-by with aggregates + - `aggr.NewGlobalAggrExec(child operators.Operator, aggExprs []aggr.AggregateFunctions)` — global aggregation (no GROUP BY) +- Purpose: compute aggregates (SUM, AVG, COUNT, MIN, MAX) grouped by one or more columns. +- What to pass in: + - `child` — input operator + - `groupExpr` / `aggExprs` — list of `aggr.AggregateFunctions` (built with `aggr.NewAggregateFunctions(aggr.AggrFunc, Expr.Expression)`) describing the aggregate function and its child expression (usually a column). + - `groupBy` — expressions for the group-by keys (column resolves). +- Why: central place for aggregator logic; constructors validate types (numeric types for SUM/AVG) and construct the output schema. 
+ +### Join (HashJoin) +- Constructor: `join.NewHashJoinExec(left, right operators.Operator, clause join.JoinClause, joinType join.JoinType, filters []Expr.Expression)` +- Purpose: perform hash-based joins (Inner, Left, Right). +- What to pass in: + - `left`, `right` — child operators for the two sides of the join (usually scans or projections) + - `clause` — `join.NewJoinClause(leftExprs []Expr.Expression, rightExprs []Expr.Expression)` describing which columns pair together (supports multiple equality clauses) + - `joinType` — `join.InnerJoin`, `join.LeftJoin`, etc. + - `filters` — optional post-join filters (not always used) | still need to implement this but no time soon, as these can just be treated as Filter Operations +- Why: joins combine rows from two inputs. The constructor validates schema compatibility and builds the combined output schema (prefixing duplicate column names with `left_`/`right_`). +- Implementation notes: the HashJoin reads the entirety of both children (current implementation) into memory and builds a hash table on the right side for probing. + +## Common constructor patterns & rationale + +- Child operator(s) always come first: most operators are constructed around one input (`child`) or two (`left`, `right`). This makes pipelines composable. +- Expressions are passed as `Expr.Expression` objects. Use the `Expr` package helpers to build column resolves, literals, scalar functions, binary operators and aliases. +- Constructors perform validation: type checking for aggregates, matching # of join expressions, or validity of projection expressions — this fails fast at construction time instead of at runtime. +- Many blocking operators (Sort, GroupBy, Join) read the full input before producing output. Be careful with large inputs — these operators are not yet externalized (spill-to-disk) and may require configuration or chunking for large datasets. 
+ +## Practical examples (pseudocode) + +- Project + Filter + Limit pipeline: + +```go +src := project.NewProjectCSVLeaf(fileReader) +pred := Expr.NewBinaryExpr(Expr.NewColumnResolve("age"), Expr.GreaterThan, Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int64, 30)) +filt, _ := filter.NewFilterExec(src, pred) +projExprs := Expr.NewExpressions(Expr.NewColumnResolve("id"), Expr.NewColumnResolve("name")) +proj, _ := project.NewProjectExec(filt, projExprs) +lim, _ := filter.NewLimitExec(proj, 10) +batch, _ := lim.Next(10) +``` + +- GroupBy example: + +```go +col := func(n string) Expr.Expression { return Expr.NewColumnResolve(n) } +aggs := []aggr.AggregateFunctions{{AggrFunc: aggr.Sum, Child: col("salary")}} +gb, _ := aggr.NewGroupByExec(src, aggs, []Expr.Expression{col("department")}) +result, _ := gb.Next(1000) +``` + +- HashJoin example (equality on `id`): + +```go +clause := join.NewJoinClause([]Expr.Expression{Expr.NewColumnResolve("id")}, []Expr.Expression{Expr.NewColumnResolve("id")}) +j, _ := join.NewHashJoinExec(leftSrc, rightSrc, clause, join.InnerJoin, nil) +batch, _ := j.Next(100) +``` + +## Notes & best practices + +- Always call `Close()` on the root operator when done (after `Next` returns `io.EOF`) to release files and network handles. +- Use `project.NewInMemoryProjectExec` for tests — it builds reproducible `RecordBatch` inputs quickly. +- When writing pipelines that may read remote files, prefer to configure the source to download the whole file if the operator will need random access or many read passes (sorting, joining, grouping). This avoids repeated network calls and unpredictable latency. +- Watch out for duplicate column names after joins: the join constructor prefixes with `left_`/`right_` when needed. + +## Where to look next in the codebase +- `operators/record.go` — `Operator` interface and `RecordBatch` helpers (builder, PrettyPrint). +- `operators/project/` — project implementations and CSV/parquet readers. 
+- `operators/filter/` — Filter, Limit, Distinct operator implementations. +- `operators/aggr/` — Sort, TopK, GroupBy and aggregate implementations. +- `operators/Join/` — HashJoin implementation. + +Reading the tests +----------------- + +For concrete examples of how SQL statements map to operator pipelines, read the integration/unit tests in `operators/test/` (and other test files under `operators/`). The tests build real pipelines (CSV/InMemory -> Filter/Project/Join/GroupBy/Sort/etc.) and show the exact constructor calls and expressions used to represent SQL queries. They are the best source of truth for small end-to-end examples. \ No newline at end of file diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy.go b/src/Backend/opti-sql-go/operators/aggr/groupBy.go index 962a450..7ca86ea 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy.go @@ -430,6 +430,16 @@ func buildDynamicArray(mem memory.Allocator, dt arrow.DataType, values []any) ar // =========================== // UNSUPPORTED TYPE // =========================== + case arrow.BOOL: + b := array.NewBooleanBuilder(mem) + for _, v := range values { + if v == nil { + b.AppendNull() + } else { + b.Append(castToBool(v)) + } + } + return b.NewArray() default: panic(fmt.Sprintf("unsupported dynamic array type: %v", dt)) } @@ -440,3 +450,9 @@ func buildFloatArray(mem memory.Allocator, values []float64) arrow.Array { b.AppendValues(values, nil) return b.NewArray() } +func castToBool(v any) bool { + if v == "true" || v == true { + return true + } + return false +} diff --git a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go index 10756f0..41434ac 100644 --- a/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go +++ b/src/Backend/opti-sql-go/operators/aggr/groupBy_test.go @@ -150,7 +150,6 @@ func TestNewGroupByExecAndSchema(t *testing.T) { if schema == nil { t.Fatalf("schema 
should not be nil") } - fmt.Println(schema) // group-by + 1 agg = 2 fields if got, want := schema.NumFields(), 2; got != want { @@ -198,7 +197,6 @@ func TestNewGroupByExecAndSchema(t *testing.T) { } schema := gb.Schema() - fmt.Printf("schema: %v\n", schema) wantFields := len(groupBy) + len(aggs) if schema.NumFields() != wantFields { t.Fatalf("expected %d fields, got %d", wantFields, schema.NumFields()) diff --git a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go index 1fcccdd..0f7c3b5 100644 --- a/src/Backend/opti-sql-go/operators/aggr/singleAggr.go +++ b/src/Backend/opti-sql-go/operators/aggr/singleAggr.go @@ -266,7 +266,7 @@ func validAggrType(dt arrow.DataType) bool { } func castArrayToFloat64(arr arrow.Array) (arrow.Array, error) { - outDatum, err := compute.CastArray(context.TODO(), arr, compute.NewCastOptions(&arrow.Float64Type{}, true)) + outDatum, err := compute.CastArray(context.Background(), arr, compute.NewCastOptions(&arrow.Float64Type{}, true)) if err != nil { return nil, err } diff --git a/src/Backend/opti-sql-go/operators/aggr/sort.go b/src/Backend/opti-sql-go/operators/aggr/sort.go index 18ca64a..1b731f8 100644 --- a/src/Backend/opti-sql-go/operators/aggr/sort.go +++ b/src/Backend/opti-sql-go/operators/aggr/sort.go @@ -119,7 +119,7 @@ func (s *SortExec) Next(n uint16) (*operators.RecordBatch, error) { idxArray := idxToArrowArray(idx, mem) defer idxArray.Release() for i := range len(allColumns) { - arr, err := compute.TakeArray(context.TODO(), allColumns[i], idxArray) + arr, err := compute.TakeArray(context.Background(), allColumns[i], idxArray) if err != nil { return nil, err } @@ -160,7 +160,7 @@ func (s *SortExec) Close() error { return s.input.Close() } func (s *SortExec) consumeSortedBatch(readsize uint64, mem memory.Allocator) ([]arrow.Array, error) { - ctx := context.TODO() + ctx := context.Background() resultColumns := make([]arrow.Array, len(s.schema.Fields())) offsetArray := 
genoffsetTakeIdx(s.consumedOffset, readsize, mem) defer offsetArray.Release() @@ -347,7 +347,7 @@ func joinArrays(existing, newarrs []arrow.Array, mem memory.Allocator) ([]arrow. } func (t *TopKSortExec) consumeSortedBatch(readsize uint64, mem memory.Allocator) ([]arrow.Array, error) { - ctx := context.TODO() + ctx := context.Background() resultColumns := make([]arrow.Array, len(t.schema.Fields())) offsetArray := genoffsetTakeIdx(t.consumedOffset, readsize, mem) defer offsetArray.Release() diff --git a/src/Backend/opti-sql-go/operators/filter/filter.go b/src/Backend/opti-sql-go/operators/filter/filter.go index a476ac8..d09f4a2 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter.go +++ b/src/Backend/opti-sql-go/operators/filter/filter.go @@ -122,7 +122,7 @@ func (f *FilterExec) Close() error { func ApplyBooleanMask(col arrow.Array, mask *array.Boolean) (arrow.Array, error) { datum, err := compute.Filter( - context.TODO(), + context.Background(), compute.NewDatum(col), compute.NewDatum(mask), *compute.DefaultFilterOptions(), @@ -163,12 +163,11 @@ func validPredicates(pred Expr.Expression, schema *arrow.Schema) bool { if err != nil { return false } - //TODO: allow for nulls to be comparable fmt.Printf("dt1:\t%v\ndt2:\t%v\n", dt1, dt2) if !arrow.TypeEqual(dt1, dt2) { return false } - // recursively validate children + fmt.Printf("left:\t%v\nright:\t%v\n", p.Left, p.Right) return validPredicates(p.Left, schema) && validPredicates(p.Right, schema) @@ -177,6 +176,8 @@ func validPredicates(pred Expr.Expression, schema *arrow.Schema) bool { case *Expr.NullCheckExpr: return validPredicates(p.Expr, schema) + case *Expr.ScalarFunction: + return true default: return false } @@ -215,16 +216,17 @@ func (f *FilterExec) sliceFilterCols(n int64, mem memory.Allocator) ([]arrow.Arr defer keepArr.Release() // For each column: materialize output slice + update buffer + ctx := context.Background() for i, col := range f.bufferedCols { // emit slice - sliceOut, err := 
compute.TakeArray(context.TODO(), col, emitArr) + sliceOut, err := compute.TakeArray(ctx, col, emitArr) if err != nil { return nil, err } out[i] = sliceOut // keep remaining slice - keepSlice, err := compute.TakeArray(context.TODO(), col, keepArr) + keepSlice, err := compute.TakeArray(ctx, col, keepArr) if err != nil { return nil, err } diff --git a/src/Backend/opti-sql-go/operators/filter/filter_test.go b/src/Backend/opti-sql-go/operators/filter/filter_test.go index 9facb8c..8e90489 100644 --- a/src/Backend/opti-sql-go/operators/filter/filter_test.go +++ b/src/Backend/opti-sql-go/operators/filter/filter_test.go @@ -2,7 +2,6 @@ package filter import ( "errors" - "fmt" "io" "opti-sql-go/Expr" "testing" @@ -333,16 +332,14 @@ func TestFilterBuffer(t *testing.T) { if err != nil { t.Fatalf("failed to create filter exec: %v", err) } - rc, err := f.Next(5) + _, err = f.Next(5) if err != nil { t.Fatalf("unexpected error: %v", err) } - fmt.Printf("First Batch:\t%v\n", rc.PrettyPrint()) - rc, err = f.Next(5) + _, err = f.Next(5) if err != nil { t.Fatalf("unexpected error: %v", err) } - fmt.Printf("second Batch:\t%v\n", rc.PrettyPrint()) }) } diff --git a/src/Backend/opti-sql-go/operators/filter/limit.go b/src/Backend/opti-sql-go/operators/filter/limit.go index d25b848..6a5aa86 100644 --- a/src/Backend/opti-sql-go/operators/filter/limit.go +++ b/src/Backend/opti-sql-go/operators/filter/limit.go @@ -110,6 +110,7 @@ func (d *DistinctExec) Next(n uint16) (*operators.RecordBatch, error) { return nil, io.EOF } mem := memory.NewGoAllocator() + ctx := context.Background() if !d.consumedInput { for { childBatch, err := d.input.Next(math.MaxUint16) @@ -161,7 +162,7 @@ func (d *DistinctExec) Next(n uint16) (*operators.RecordBatch, error) { takeArray := idxToArrowArray(idxTracker, mem) for i := range len(childBatch.Columns) { largeArray := childBatch.Columns[i] - uniqueElements, err := compute.TakeArray(context.TODO(), largeArray, takeArray) + uniqueElements, err := 
compute.TakeArray(ctx, largeArray, takeArray) if err != nil { return nil, err } @@ -209,7 +210,7 @@ func (d *DistinctExec) Close() error { return d.input.Close() } func (d *DistinctExec) consumeDistinctArrays(readSize uint64, mem memory.Allocator) ([]arrow.Array, error) { - ctx := context.TODO() + ctx := context.Background() resultColumns := make([]arrow.Array, len(d.schema.Fields())) offsetArray := genoffsetTakeIdx(d.consumedOffset, readSize, mem) defer offsetArray.Release() diff --git a/src/Backend/opti-sql-go/operators/project/parquet.go b/src/Backend/opti-sql-go/operators/project/parquet.go index 42d5c14..50aa856 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet.go +++ b/src/Backend/opti-sql-go/operators/project/parquet.go @@ -49,7 +49,7 @@ func NewParquetSource(r parquet.ReaderAtSeeker) (*ParquetSource, error) { if err != nil { return nil, err } - rdr, err := arrowReader.GetRecordReader(context.TODO(), nil, nil) + rdr, err := arrowReader.GetRecordReader(context.Background(), nil, nil) if err != nil { return nil, err } @@ -98,7 +98,7 @@ func NewParquetSourcePushDown(r parquet.ReaderAtSeeker, columns []string) (*Parq wantedColumnsIDX = append(wantedColumnsIDX, idx_array...) 
} - rdr, err := arrowReader.GetRecordReader(context.TODO(), wantedColumnsIDX, nil) + rdr, err := arrowReader.GetRecordReader(context.Background(), wantedColumnsIDX, nil) if err != nil { return nil, err } diff --git a/src/Backend/opti-sql-go/operators/project/parquet_test.go b/src/Backend/opti-sql-go/operators/project/parquet_test.go index ff28535..c051d9f 100644 --- a/src/Backend/opti-sql-go/operators/project/parquet_test.go +++ b/src/Backend/opti-sql-go/operators/project/parquet_test.go @@ -34,7 +34,6 @@ schema: metadata: ["PARQUET:field_id": "-1"] - lon: type=float64, nullable */ -// TODO: more to their own files later down the line func existIn(str string, arr []string) bool { for _, a := range arr { if a == str { diff --git a/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go b/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go index 47435b7..3832a39 100644 --- a/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go +++ b/src/Backend/opti-sql-go/operators/project/projectExecExpr_test.go @@ -834,10 +834,3 @@ func TestProjectExec_FunctionExpr(t *testing.T) { } }) } - -/* -complex expr -ex: alias(function(column |operator| literal) |operator| literal) -TODO: not the most important thing right now since we know basic expression are fine -*/ -func TestProjectExec_ComplexExpr(t *testing.T) {} diff --git a/src/Backend/opti-sql-go/operators/test/intergration_test.go b/src/Backend/opti-sql-go/operators/test/intergration_test.go index af51020..15786a9 100644 --- a/src/Backend/opti-sql-go/operators/test/intergration_test.go +++ b/src/Backend/opti-sql-go/operators/test/intergration_test.go @@ -6,6 +6,8 @@ import ( "io" "opti-sql-go/Expr" "opti-sql-go/operators" + join "opti-sql-go/operators/Join" + aggr "opti-sql-go/operators/aggr" "opti-sql-go/operators/filter" "opti-sql-go/operators/project" "os" @@ -114,11 +116,11 @@ func TestSelectFilterLimit(t *testing.T) { } if batch == nil { - t.Logf("(1.A) got nil batch (possibly EOF)") + 
t.Logf("(1A) got nil batch (possibly EOF)") return } - t.Logf("(1.A) batch:\n%v\n", batch.PrettyPrint()) + t.Logf("(1A) batch:\n%v\n", batch.PrettyPrint()) }) // (1.B) SELECT username, age_years FROM source1 WHERE is_active = true AND age_years < 25 LIMIT 3; @@ -162,14 +164,14 @@ func TestSelectFilterLimit(t *testing.T) { } if batch == nil { - t.Logf("(1.B) got nil batch (possibly EOF)") + t.Logf("(1B) got nil batch (possibly EOF)") return } - t.Logf("(1.B) batch:\n%v\n", batch.PrettyPrint()) + t.Logf("(1B) batch:\n%v\n", batch.PrettyPrint()) }) // (1.C) SELECT id, favorite_color FROM source1 WHERE favorite_color = 'Red' LIMIT 7; - t.Run("(1.C)", func(t *testing.T) { + t.Run("1C", func(t *testing.T) { src := source1Project() pred := Expr.NewBinaryExpr( @@ -203,101 +205,503 @@ func TestSelectFilterLimit(t *testing.T) { } if batch == nil { - t.Logf("(1.C) got nil batch (possibly EOF)") + t.Logf("(1C) got nil batch (possibly EOF)") return } - t.Logf("(1.C) batch:\n%v\n", batch.PrettyPrint()) + t.Logf("(1C) batch:\n%v\n", batch.PrettyPrint()) }) } -/* -(2) -Operators: Filter, Scalar functions -sql query: -(2.A)SELECT id, username, LOWER(favorite_color) as fav_color_lower FROM source1 WHERE UPPER(favorite_color) = 'BLUE'; -(2.B)SELECT username, LOWER(email_address) AS email_lower -FROM source1 -WHERE UPPER(username) = 'ALICE'; -*/ +// ------------------------------------------------------------------------- +// (2) Operators: Filter, Scalar functions +// (2.A) SELECT id, username, LOWER(favorite_color) as fav_color_lower FROM source1 WHERE UPPER(favorite_color) = 'BLUE'; +// (2.B) SELECT username, LOWER(email_address) AS email_lower FROM source1 WHERE UPPER(username) = 'ALICE'; +func TestFilterScalarFunctions(t *testing.T) { + // (2.A) SELECT id, username, LOWER(favorite_color) as fav_color_lower FROM source1 WHERE UPPER(favorite_color) = 'BLUE'; + t.Run("2A", func(t *testing.T) { + src := source1Project() -/* -(3) -Operators: select, Sort -sql query: -(3.A)SELECT id, 
account_balance_usd, username -FROM source1 -ORDER BY account_balance_usd ASC -(3.B)SELECT id, favorite_color -FROM source1 -ORDER BY favorite_color ASC; -*/ + pred := Expr.NewBinaryExpr( + Expr.NewScalarFunction(Expr.Upper, Expr.NewColumnResolve("favorite_color")), + Expr.Equal, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, "BLUE"), + ) -/* -(4) -Operators: Join(INNER), Select -SQL: -(4.A)SELECT s1.id, s1.username, s2.department_name -FROM source1 AS s1 -INNER JOIN source2 AS s2 -ON s1.favorite_color = s2.manager_name; -(4.B)SELECT s1.id, s1.email_address, s2.department_name -FROM source1 AS s1 -INNER JOIN source2 AS s2 -ON s1.favorite_color = s2.manager_name; -*/ + filt, err := filter.NewFilterExec(src, pred) + if err != nil { + t.Fatalf("filter init failed: %v", err) + } -/* -(5) -Operators: GroupBy, Aggregation(SUM, AVG), Select -SQL: -(5.A)SELECT favorite_color, AVG(age_years) AS avg_age, SUM(account_balance_usd) AS total_balance -FROM source1 -GROUP BY favorite_color; -(5.B)SELECT is_active, COUNT(*) AS active_count, AVG(age_years) AS avg_age -FROM source1 -GROUP BY is_active; + exprs := Expr.NewExpressions( + Expr.NewColumnResolve("id"), + Expr.NewColumnResolve("username"), + Expr.NewAlias(Expr.NewScalarFunction(Expr.Lower, Expr.NewColumnResolve("favorite_color")), "fav_color_lower"), + ) + proj, err := project.NewProjectExec(filt, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } -*/ + batch, err := proj.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(2A) got nil batch (possibly EOF)") + return + } + t.Logf("(2A) batch:\n%v\n", batch.PrettyPrint()) + }) -/* -(6) -Operators: Distinct, Sort(DESC) -SQL: -(6.A)SELECT DISTINCT favorite_color -FROM source1 -ORDER BY favorite_color DESC; -(6.B)SELECT DISTINCT is_active -FROM source1 -ORDER BY is_active DESC; + // (2.B) SELECT username, LOWER(email_address) AS email_lower FROM source1 WHERE UPPER(username) = 
'ALICE'; + t.Run("2B", func(t *testing.T) { + src := source1Project() -*/ + pred := Expr.NewBinaryExpr( + Expr.NewScalarFunction(Expr.Upper, Expr.NewColumnResolve("username")), + Expr.Equal, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, "ALICE"), + ) -/* -(7) -Operators: Join(INNER), Filter, Projection, Limit + filt, err := filter.NewFilterExec(src, pred) + if err != nil { + t.Fatalf("filter init failed: %v", err) + } -SQL: -(7.A)SELECT s1.id, s1.username, s2.department_name -FROM source1 AS s1 -INNER JOIN source2 AS s2 -ON s1.favorite_color = s2.manager_name -WHERE s1.age_years > 30 -LIMIT 5; -(7.B)SELECT s1.username, s2.manager_email -FROM source1 AS s1 -JOIN source2 AS s2 -ON s1.favorite_color = s2.manager_name -WHERE s2.department_name = 'Engineering' -LIMIT 3; -(7.C)SELECT s1.id, s2.manager_name -FROM source1 s1 -JOIN source2 s2 -ON s1.favorite_color = s2.manager_name -WHERE s1.account_balance_usd > 10000 -LIMIT 2; -*/ + exprs := Expr.NewExpressions( + Expr.NewColumnResolve("username"), + Expr.NewAlias(Expr.NewScalarFunction(Expr.Lower, Expr.NewColumnResolve("email_address")), "email_lower"), + ) + proj, err := project.NewProjectExec(filt, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + batch, err := proj.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch != nil { + t.Fatalf("was expecting an empty batch but recieved %s\n", batch.PrettyPrint()) + return + } + }) +} + +// ------------------------------------------------------------------------- +// (3) Operators: select, Sort +// (3.A) SELECT id, account_balance_usd, username FROM source1 ORDER BY account_balance_usd ASC +// (3.B) SELECT id, favorite_color FROM source1 ORDER BY favorite_color ASC; +func TestSelectSort(t *testing.T) { + // (3.A) SELECT id, account_balance_usd, username FROM source1 ORDER BY account_balance_usd ASC + t.Run("3A", func(t *testing.T) { + src := source1Project() + exprs := Expr.NewExpressions( + 
Expr.NewColumnResolve("id"), + Expr.NewColumnResolve("account_balance_usd"), + Expr.NewColumnResolve("username"), + ) + proj, err := project.NewProjectExec(src, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + sk := aggr.NewSortKey(Expr.NewColumnResolve("account_balance_usd"), true) + sortExec, err := aggr.NewSortExec(proj, aggr.CombineSortKeys(sk)) + if err != nil { + t.Fatalf("sort init failed: %v", err) + } + batch, err := sortExec.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(3A) got nil batch (possibly EOF)") + return + } + t.Logf("(3A) batch:\n%v\n", batch.PrettyPrint()) + }) + + // (3.B) SELECT id, favorite_color FROM source1 ORDER BY favorite_color ASC; + t.Run("3B", func(t *testing.T) { + src := source1Project() + exprs := Expr.NewExpressions( + Expr.NewColumnResolve("id"), + Expr.NewColumnResolve("favorite_color"), + ) + proj, err := project.NewProjectExec(src, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + sk := aggr.NewSortKey(Expr.NewColumnResolve("favorite_color"), true) + sortExec, err := aggr.NewSortExec(proj, aggr.CombineSortKeys(sk)) + if err != nil { + t.Fatalf("sort init failed: %v", err) + } + batch, err := sortExec.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(3B) got nil batch (possibly EOF)") + return + } + t.Logf("(3B) batch:\n%v\n", batch.PrettyPrint()) + }) +} + +// ------------------------------------------------------------------------- +// (4) Operators: Join(INNER), Select +// (4.A) SELECT s1.id, s1.username, s2.department_name FROM source1 AS s1 INNER JOIN source2 AS s2 ON s1.id = s2.id; +// (4.B) SELECT s1.id, s1.email_address, s2.department_name FROM source1 AS s1 INNER JOIN source2 AS s2 ON s1.id = s2.id; +func TestJoinSelect(t *testing.T) { + // (4.A) SELECT s1.id, s1.username, s2.department_name FROM source1 AS 
s1 INNER JOIN source2 AS s2 ON s1.favorite_color = s2.manager_name; + t.Run("4A", func(t *testing.T) { + src1 := source1Project() + src2 := source2Project() + clause := join.NewJoinClause( + []Expr.Expression{Expr.NewColumnResolve("id")}, + []Expr.Expression{Expr.NewColumnResolve("id")}, + ) + j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) + if err != nil { + t.Fatalf("join init failed: %v", err) + } + exprs := Expr.NewExpressions( + Expr.NewAlias(Expr.NewColumnResolve("left_id"), "id"), + Expr.NewColumnResolve("username"), + Expr.NewColumnResolve("department_name"), + ) + t.Logf("\t%v\n", j.Schema()) + proj, err := project.NewProjectExec(j, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + batch, err := proj.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(4A) got nil batch (possibly EOF)") + return + } + t.Logf("(4A) batch:\n%v\n", batch.PrettyPrint()) + }) + + // (4.B) SELECT s1.id, s1.email_address, s2.department_name FROM source1 AS s1 INNER JOIN source2 AS s2 ON s1.id = s2.id; + t.Run("4B", func(t *testing.T) { + src1 := source1Project() + src2 := source2Project() + clause := join.NewJoinClause( + []Expr.Expression{Expr.NewColumnResolve("id")}, + []Expr.Expression{Expr.NewColumnResolve("id")}, + ) + j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) + if err != nil { + t.Fatalf("join init failed: %v", err) + } + exprs := Expr.NewExpressions( + Expr.NewAlias(Expr.NewColumnResolve("left_id"), "cool_guy_id"), + Expr.NewColumnResolve("email_address"), + Expr.NewColumnResolve("department_name"), + ) + proj, err := project.NewProjectExec(j, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + batch, err := proj.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(4B) got nil batch (possibly EOF)") + return + } + 
t.Logf("(4B) batch:\n%v\n", batch.PrettyPrint()) + }) +} + +func TestGroupByAggregation(t *testing.T) { + // (5.A) SELECT favorite_color, AVG(age_years) AS avg_age, SUM(account_balance_usd) AS total_balance FROM source1 GROUP BY favorite_color order by avg_age; + t.Run("5A", func(t *testing.T) { + src := source1Project() + + groupBy := []Expr.Expression{Expr.NewColumnResolve("favorite_color")} + aggs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Avg, Expr.NewColumnResolve("age_years")), + aggr.NewAggregateFunctions(aggr.Sum, Expr.NewColumnResolve("account_balance_usd")), + } + + gb, err := aggr.NewGroupByExec(src, aggs, groupBy) + if err != nil { + t.Fatalf("groupby init failed: %v", err) + } + sortExec, err := aggr.NewSortExec(gb, aggr.CombineSortKeys(aggr.NewSortKey(Expr.NewColumnResolve("avg_Column(age_years)"), true))) + if err != nil { + t.Fatalf("sort init failed: %v", err) + } + + batch, err := sortExec.Next(1000) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(5A) got nil batch (possibly EOF)") + return + } + t.Logf("(5A) batch:\n%v\n", batch.PrettyPrint()) + }) + + // (5.B) SELECT is_active, COUNT(*) AS active_count, AVG(age_years) AS avg_age FROM source1 GROUP BY is_active; + t.Run("5B", func(t *testing.T) { + src := source1Project() + groupBy := []Expr.Expression{Expr.NewColumnResolve("is_active")} + aggs := []aggr.AggregateFunctions{ + aggr.NewAggregateFunctions(aggr.Count, Expr.NewColumnResolve("id")), + aggr.NewAggregateFunctions(aggr.Avg, Expr.NewColumnResolve("age_years")), + } + + gb, err := aggr.NewGroupByExec(src, aggs, groupBy) + if err != nil { + t.Fatalf("groupby init failed: %v", err) + } + + batch, err := gb.Next(1000) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(5B) got nil batch (possibly EOF)") + return + } + t.Logf("(5B) batch:\n%v\n", batch.PrettyPrint()) + }) +} + +// 
TestDistinctSort runs DISTINCT + Sort pipelines for source1 +// (6.A)SELECT DISTINCT favorite_color +// FROM source1 +// ORDER BY favorite_color DESC; +// (6.B)SELECT DISTINCT is_active +// FROM source1 +// ORDER BY is_active DESC; +func TestDistinctSort(t *testing.T) { + // (6.A) SELECT DISTINCT favorite_color FROM source1 ORDER BY favorite_color DESC; + t.Run("6A", func(t *testing.T) { + src := source1Project() + + cols := []Expr.Expression{Expr.NewColumnResolve("favorite_color")} + distinct, err := filter.NewDistinctExec(src, cols) + if err != nil { + t.Fatalf("distinct init failed: %v", err) + } + + sk := aggr.NewSortKey(Expr.NewColumnResolve("favorite_color"), false) // DESC + sortExec, err := aggr.NewSortExec(distinct, aggr.CombineSortKeys(sk)) + if err != nil { + t.Fatalf("sort init failed: %v", err) + } + proj, err := project.NewProjectExec(sortExec, Expr.NewExpressions(Expr.NewColumnResolve("favorite_color"))) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + batch, err := proj.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(6A) got nil batch (possibly EOF)") + return + } + t.Logf("(6A) batch:\n%v\n", batch.PrettyPrint()) + }) + + // (6.B) SELECT DISTINCT is_active FROM source1 ORDER BY is_active DESC; + t.Run("6B", func(t *testing.T) { + src := source1Project() + + cols := []Expr.Expression{Expr.NewColumnResolve("is_active")} + distinct, err := filter.NewDistinctExec(src, cols) + if err != nil { + t.Fatalf("distinct init failed: %v", err) + } + + sk := aggr.NewSortKey(Expr.NewColumnResolve("is_active"), false) // DESC + sortExec, err := aggr.NewSortExec(distinct, aggr.CombineSortKeys(sk)) + if err != nil { + t.Fatalf("sort init failed: %v", err) + } + proj, err := project.NewProjectExec(sortExec, Expr.NewExpressions(Expr.NewColumnResolve("is_active"))) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + batch, err := proj.Next(100) + if 
err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(6B) got nil batch (possibly EOF)") + return + } + t.Logf("(6B) batch:\n%v\n", batch.PrettyPrint()) + }) +} + +// TestJoinFilterProjLimit runs join + filter + project + limit pipelines +// (7.A)SELECT s1.id, s1.username, s2.department_name FROM source1 AS s1 INNER JOIN source2 AS s2 ON s1.id = s2.id WHERE s1.age_years > 30 LIMIT 5; +// (7.B)SELECT s1.username, s2.manager_email FROM source1 AS s1 JOIN source2 AS s2 ON s1.id = s2.id WHERE s2.department_name = 'Engineering' LIMIT 3; +// (7.C)SELECT s1.id, s2.manager_name FROM source1 s1 JOIN source2 s2 ON s1.id = s2.id WHERE s1.account_balance_usd > 10000 LIMIT 2; +func TestJoinFilterProjLimit(t *testing.T) { + // (7.A)SELECT s1.id, s1.username, s2.department_name FROM source1 AS s1 INNER JOIN source2 AS s2 ON s1.id = s2.id WHERE s1.age_years > 30 LIMIT 5; + t.Run("7A", func(t *testing.T) { + src1 := source1Project() + src2 := source2Project() + clause := join.NewJoinClause( + []Expr.Expression{Expr.NewColumnResolve("id")}, + []Expr.Expression{Expr.NewColumnResolve("id")}, + ) + j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) + if err != nil { + t.Fatalf("join init failed: %v", err) + } + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("age_years"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Int64, 30), + ) + + filt, err := filter.NewFilterExec(j, pred) + if err != nil { + t.Fatalf("filter init failed: %v", err) + } + + exprs := Expr.NewExpressions( + Expr.NewAlias(Expr.NewColumnResolve("left_id"), "id"), + Expr.NewColumnResolve("username"), + Expr.NewAlias(Expr.NewColumnResolve("department_name"), "deptartment"), + ) + proj, err := project.NewProjectExec(filt, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + lim, err := filter.NewLimitExec(proj, 5) + if err != nil { + t.Fatalf("limit init failed: %v", err) + } + + batch, err := 
lim.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(7A) got nil batch (possibly EOF)") + return + } + t.Logf("(7A) batch:\n%v\n", batch.PrettyPrint()) + }) + + // (7.B)SELECT s1.username, s2.manager_email FROM source1 AS s1 JOIN source2 AS s2 ON s1.id = s2.id WHERE s2.department_name = 'Engineering' LIMIT 3; + t.Run("7B", func(t *testing.T) { + src1 := source1Project() + src2 := source2Project() + clause := join.NewJoinClause( + []Expr.Expression{Expr.NewColumnResolve("id")}, + []Expr.Expression{Expr.NewColumnResolve("id")}, + ) + j, err := join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) + if err != nil { + t.Fatalf("join init failed: %v", err) + } + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("department_name"), + Expr.Equal, + Expr.NewLiteralResolve(arrow.BinaryTypes.String, "Engineering"), + ) + + filt, err := filter.NewFilterExec(j, pred) + if err != nil { + t.Fatalf("filter init failed: %v", err) + } + + exprs := Expr.NewExpressions( + Expr.NewColumnResolve("username"), + Expr.NewColumnResolve("manager_email"), + ) + proj, err := project.NewProjectExec(filt, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + lim, err := filter.NewLimitExec(proj, 3) + if err != nil { + t.Fatalf("limit init failed: %v", err) + } + + batch, err := lim.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(7B) got nil batch (possibly EOF)") + return + } + t.Logf("(7B) batch:\n%v\n", batch.PrettyPrint()) + }) + + // (7.C)SELECT s1.id, s2.manager_name FROM source1 s1 JOIN source2 s2 ON s1.id = s2.id WHERE s1.account_balance_usd > 10000 LIMIT 2; + t.Run("7C", func(t *testing.T) { + src1 := source1Project() + src2 := source2Project() + clause := join.NewJoinClause( + []Expr.Expression{Expr.NewColumnResolve("id")}, + []Expr.Expression{Expr.NewColumnResolve("id")}, + ) + j, err := 
join.NewHashJoinExec(src1, src2, clause, join.InnerJoin, nil) + if err != nil { + t.Fatalf("join init failed: %v", err) + } + + pred := Expr.NewBinaryExpr( + Expr.NewColumnResolve("account_balance_usd"), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 10000.0), + ) + + filt, err := filter.NewFilterExec(j, pred) + if err != nil { + t.Fatalf("filter init failed: %v", err) + } + + exprs := Expr.NewExpressions( + Expr.NewColumnResolve("left_id"), + Expr.NewColumnResolve("manager_name"), + ) + proj, err := project.NewProjectExec(filt, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + lim, err := filter.NewLimitExec(proj, 2) + if err != nil { + t.Fatalf("limit init failed: %v", err) + } + + batch, err := lim.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(7C) got nil batch (possibly EOF)") + return + } + t.Logf("(7C) batch:\n%v\n", batch.PrettyPrint()) + }) +} /* (8) @@ -312,29 +716,149 @@ FROM source1 WHERE ABS(account_balance_usd) > 5000; */ -/* -(9) -Operators: Sort (multiple columns), Select +// TestScalarAbsRound runs scalar ABS/ROUND with Filter + Projection +// (8.A)SELECT id, ROUND(ABS(average_session_minutes)) AS rounded_session FROM source1 WHERE ABS(average_session_minutes) > 5; +// (8.B)SELECT username, ROUND(account_balance_usd) AS rounded_balance FROM source1 WHERE ABS(account_balance_usd) > 5000; +func TestScalarAbsRound(t *testing.T) { + // (8.A)SELECT id, ROUND(ABS(average_session_minutes)) AS rounded_session FROM source1 WHERE ABS(average_session_minutes) > 5; + t.Run("8A", func(t *testing.T) { + src := source1Project() -SQL: -(9.A)SELECT id, username, age_years -FROM source1 -ORDER BY age_years DESC, username ASC; -(9.B)SELECT id, email_address, age_years -FROM source1 -ORDER BY age_years ASC, email_address DESC; + // predicate: ABS(average_session_minutes) > 5 + pred := Expr.NewBinaryExpr( + 
Expr.NewScalarFunction(Expr.Abs, Expr.NewColumnResolve("average_session_minutes")), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 5.0), + ) -*/ + filt, err := filter.NewFilterExec(src, pred) + if err != nil { + t.Fatalf("filter init failed: %v", err) + } -/* -(10) -Operators: Join (INNER, multiple conditions), Select, Sort (multiple columns) + // projection: id, ROUND(ABS(average_session_minutes)) as rounded_session + roundExpr := Expr.NewScalarFunction(Expr.Round, Expr.NewScalarFunction(Expr.Abs, Expr.NewColumnResolve("average_session_minutes"))) + exprs := Expr.NewExpressions( + Expr.NewColumnResolve("id"), + Expr.NewAlias(roundExpr, "rounded_session"), + ) + proj, err := project.NewProjectExec(filt, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } -(10.A)SELECT s1.id, s1.username, s2.manager_name, s2.budget -FROM source1 AS s1 -INNER JOIN source2 AS s2 - ON s1.favorite_color = s2.manager_name - AND s1.region = s2.region -ORDER BY s2.budget DESC, s1.username ASC; + batch, err := proj.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(8A) got nil batch (possibly EOF)") + return + } + t.Logf("(8A) batch:\n%v\n", batch.PrettyPrint()) + }) -*/ + // (8.B)SELECT username, ROUND(account_balance_usd) AS rounded_balance FROM source1 WHERE ABS(account_balance_usd) > 5000; + t.Run("8B", func(t *testing.T) { + src := source1Project() + + // predicate: ABS(account_balance_usd) > 5000 + pred := Expr.NewBinaryExpr( + Expr.NewScalarFunction(Expr.Abs, Expr.NewColumnResolve("account_balance_usd")), + Expr.GreaterThan, + Expr.NewLiteralResolve(arrow.PrimitiveTypes.Float64, 5000.0), + ) + + filt, err := filter.NewFilterExec(src, pred) + if err != nil { + t.Fatalf("filter init failed: %v", err) + } + + roundExpr := Expr.NewScalarFunction(Expr.Round, Expr.NewColumnResolve("account_balance_usd")) + exprs := Expr.NewExpressions( + 
Expr.NewColumnResolve("username"), + Expr.NewAlias(roundExpr, "rounded_balance"), + ) + proj, err := project.NewProjectExec(filt, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + batch, err := proj.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(8B) got nil batch (possibly EOF)") + return + } + t.Logf("(8B) batch:\n%v\n", batch.PrettyPrint()) + }) +} + +// TestSelectMultiSort runs multi-column ORDER BY tests +// (9.A)SELECT id, username, age_years FROM source1 ORDER BY age_years DESC, username ASC; +// (9.B)SELECT id, email_address, age_years FROM source1 ORDER BY age_years ASC, email_address DESC; +func TestSelectMultiSort(t *testing.T) { + // (9.A) + t.Run("9A", func(t *testing.T) { + src := source1Project() + exprs := Expr.NewExpressions( + Expr.NewColumnResolve("id"), + Expr.NewColumnResolve("username"), + Expr.NewColumnResolve("age_years"), + ) + proj, err := project.NewProjectExec(src, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + sk1 := aggr.NewSortKey(Expr.NewColumnResolve("age_years"), false) // DESC + sk2 := aggr.NewSortKey(Expr.NewColumnResolve("username"), true) // ASC + sortExec, err := aggr.NewSortExec(proj, aggr.CombineSortKeys(sk1, sk2)) + if err != nil { + t.Fatalf("sort init failed: %v", err) + } + + batch, err := sortExec.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(9A) got nil batch (possibly EOF)") + return + } + t.Logf("(9A) batch:\n%v\n", batch.PrettyPrint()) + }) + + // (9.B) + t.Run("9B", func(t *testing.T) { + src := source1Project() + exprs := Expr.NewExpressions( + Expr.NewColumnResolve("id"), + Expr.NewColumnResolve("email_address"), + Expr.NewColumnResolve("age_years"), + ) + proj, err := project.NewProjectExec(src, exprs) + if err != nil { + t.Fatalf("project init failed: %v", err) + } + + sk1 := 
aggr.NewSortKey(Expr.NewColumnResolve("age_years"), true) // ASC + sk2 := aggr.NewSortKey(Expr.NewColumnResolve("email_address"), false) // DESC + sortExec, err := aggr.NewSortExec(proj, aggr.CombineSortKeys(sk1, sk2)) + if err != nil { + t.Fatalf("sort init failed: %v", err) + } + + batch, err := sortExec.Next(100) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("unexpected error: %v", err) + } + if batch == nil { + t.Logf("(9B) got nil batch (possibly EOF)") + return + } + t.Logf("(9B) batch:\n%v\n", batch.PrettyPrint()) + }) +} diff --git a/src/Backend/opti-sql-go/operators/test/t1_test.go b/src/Backend/opti-sql-go/operators/test/t1_test.go index 71bf2b1..dd728fb 100644 --- a/src/Backend/opti-sql-go/operators/test/t1_test.go +++ b/src/Backend/opti-sql-go/operators/test/t1_test.go @@ -2,7 +2,6 @@ package test import ( "errors" - "fmt" "io" "math" "opti-sql-go/Expr" @@ -991,7 +990,6 @@ func TestHavingExec(t *testing.T) { ) hv, _ := aggr.NewHavingExec(gb, having) - fmt.Printf("\t%v\n", hv.Schema()) batch, err := hv.Next(500) if err != nil { t.Fatalf("having next failed: %v", err) diff --git a/src/Backend/opti-sql-go/substrait/substrait_test.go b/src/Backend/opti-sql-go/substrait/substrait_test.go index 122ad0b..fe23790 100644 --- a/src/Backend/opti-sql-go/substrait/substrait_test.go +++ b/src/Backend/opti-sql-go/substrait/substrait_test.go @@ -31,13 +31,13 @@ func TestDummyInput(t *testing.T) { dummyRequest := &QueryExecutionRequest{ SqlStatement: "SELECT * FROM table", SubstraitLogical: []byte("CgJTUxIMCgpTZWxlY3QgKiBGUk9NIHRhYmxl"), - Id: "GenerateDTODOHaasdavdasvasdvada", + Id: "GenerateDTMoneyOHaasdavdasvasdvada", Source: &SourceType{ S3Source: "s3://my-bucket/data/table.parquet", Mime: "application/vnd.apache.parquet", }, } - resp, err := ss.ExecuteQuery(context.TODO(), dummyRequest) + resp, err := ss.ExecuteQuery(context.Background(), dummyRequest) if err != nil { t.Errorf("Expected no error, got %v", err) }