diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index bdb666fd658..14d48b1a615 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -141,6 +141,7 @@ go/arrow/unionmode_string.go go/arrow/compute/go.sum go/arrow/compute/datumkind_string.go go/arrow/compute/funckind_string.go +go/arrow/compute/internal/kernels/_lib/vendored/* go/*.tmpldata go/*.s go/parquet/internal/gen-go/parquet/GoUnusedProtection__.go diff --git a/go/arrow/compute/arithmetic.go b/go/arrow/compute/arithmetic.go new file mode 100644 index 00000000000..49c3c24160e --- /dev/null +++ b/go/arrow/compute/arithmetic.go @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "context" + "fmt" + + "github.com/apache/arrow/go/v10/arrow" + "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" + "github.com/apache/arrow/go/v10/arrow/compute/internal/kernels" +) + +type arithmeticFunction struct { + ScalarFunction + + promote decimalPromotion +} + +func (fn *arithmeticFunction) checkDecimals(vals ...arrow.DataType) error { + if !hasDecimal(vals...) { + return nil + } + + if len(vals) != 2 { + return nil + } + + if fn.promote == decPromoteNone { + return fmt.Errorf("%w: invalid decimal function: %s", arrow.ErrInvalid, fn.name) + } + + return castBinaryDecimalArgs(fn.promote, vals...) +} + +func (fn *arithmeticFunction) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) { + if err := fn.checkArity(len(vals)); err != nil { + return nil, err + } + + if err := fn.checkDecimals(vals...); err != nil { + return nil, err + } + + if kn, err := fn.DispatchExact(vals...); err == nil { + return kn, nil + } + + ensureDictionaryDecoded(vals...) + + // only promote types for binary funcs + if len(vals) == 2 { + replaceNullWithOtherType(vals...) + if unit, istime := commonTemporalResolution(vals...); istime { + replaceTemporalTypes(unit, vals...) + } else { + if dt := commonNumeric(vals...); dt != nil { + replaceTypes(dt, vals...) + } + } + } + + return fn.DispatchExact(vals...) +} + +var ( + addDoc FunctionDoc +) + +func RegisterScalarArithmetic(reg FunctionRegistry) { + addFn := &arithmeticFunction{*NewScalarFunction("add_unchecked", Binary(), addDoc), decPromoteAdd} + for _, k := range kernels.GetArithmeticKernels(kernels.OpAdd) { + if err := addFn.AddKernel(k); err != nil { + panic(err) + } + } + + reg.AddFunction(addFn, false) + + addCheckedFn := &arithmeticFunction{*NewScalarFunction("add", Binary(), addDoc), decPromoteAdd} + for _, k := range kernels.GetArithmeticKernels(kernels.OpAddChecked) { + if err := addCheckedFn.AddKernel(k); err != nil { + panic(err) + } + } + + reg.AddFunction(addCheckedFn, false) + + subFn := &arithmeticFunction{*NewScalarFunction("sub_unchecked", Binary(), addDoc), decPromoteAdd} + for _, k := range kernels.GetArithmeticKernels(kernels.OpSub) { + if err := subFn.AddKernel(k); err != nil { + panic(err) + } + } + + reg.AddFunction(subFn, false) + + subCheckedFn := &arithmeticFunction{*NewScalarFunction("sub", Binary(), addDoc), decPromoteAdd} + for _, k := range kernels.GetArithmeticKernels(kernels.OpSubChecked) { + if err := subCheckedFn.AddKernel(k); err != nil { + panic(err) + } + } + + reg.AddFunction(subCheckedFn, false) +} + +// Add performs an addition between the passed in arguments (scalar or array) +// and returns the result. If one argument is a scalar and the other is an +// array, the scalar value is added to each value of the array. +// +// ArithmeticOptions specifies whether or not to check for overflows, +// performance is faster if not explicitly checking for overflows but +// will error on an overflow if CheckOverflow is true. +func Add(ctx context.Context, opts ArithmeticOptions, left, right Datum) (Datum, error) { + fn := "add" + if opts.NoCheckOverflow { + fn = "add_unchecked" + } + return CallFunction(ctx, fn, nil, left, right) +} + +// Sub performs a subtraction between the passed in arguments (scalar or array) +// and returns the result. If one argument is a scalar and the other is an +// array, the scalar value is subtracted from each value of the array. +// +// ArithmeticOptions specifies whether or not to check for overflows, +// performance is faster if not explicitly checking for overflows but +// will error on an overflow if CheckOverflow is true. +func Subtract(ctx context.Context, opts ArithmeticOptions, left, right Datum) (Datum, error) { + fn := "sub" + if opts.NoCheckOverflow { + fn = "sub_unchecked" + } + return CallFunction(ctx, fn, nil, left, right) +} diff --git a/go/arrow/compute/arithmetic_test.go b/go/arrow/compute/arithmetic_test.go new file mode 100644 index 00000000000..2da7a62fe86 --- /dev/null +++ b/go/arrow/compute/arithmetic_test.go @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute_test + +import ( + "context" + "fmt" + "math" + "strings" + "testing" + + "github.com/apache/arrow/go/v10/arrow" + "github.com/apache/arrow/go/v10/arrow/array" + "github.com/apache/arrow/go/v10/arrow/compute" + "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" + "github.com/apache/arrow/go/v10/arrow/internal/testing/gen" + "github.com/apache/arrow/go/v10/arrow/memory" + "github.com/apache/arrow/go/v10/arrow/scalar" + "github.com/klauspost/cpuid/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +var ( + CpuCacheSizes = [...]int{ // defaults + 32 * 1024, // level 1: 32K + 256 * 1024, // level 2: 256K + 3072 * 1024, // level 3: 3M + } +) + +func init() { + if cpuid.CPU.Cache.L1D != -1 { + CpuCacheSizes[0] = cpuid.CPU.Cache.L1D + } + if cpuid.CPU.Cache.L2 != -1 { + CpuCacheSizes[1] = cpuid.CPU.Cache.L2 + } + if cpuid.CPU.Cache.L3 != -1 { + CpuCacheSizes[2] = cpuid.CPU.Cache.L3 + } +} + +type binaryArithmeticFunc = func(context.Context, compute.ArithmeticOptions, compute.Datum, compute.Datum) (compute.Datum, error) + +type binaryFunc = func(left, right compute.Datum) (compute.Datum, error) + +func assertScalarEquals(t *testing.T, expected, actual scalar.Scalar) { + assert.Truef(t, scalar.Equals(expected, actual), "expected: %s\ngot: %s", expected, actual) +} + +func assertBinop(t *testing.T, fn binaryFunc, left, right, expected arrow.Array) { + actual, err := fn(&compute.ArrayDatum{Value: left.Data()}, &compute.ArrayDatum{Value: right.Data()}) + require.NoError(t, err) + defer actual.Release() + assertDatumsEqual(t, &compute.ArrayDatum{Value: expected.Data()}, actual) + + // also check (Scalar, Scalar) operations + for i := 0; i < expected.Len(); i++ { + s, err := scalar.GetScalar(expected, i) + require.NoError(t, err) + lhs, _ := scalar.GetScalar(left, i) + rhs, _ := scalar.GetScalar(right, i) + + actual, err := fn(&compute.ScalarDatum{Value: lhs}, &compute.ScalarDatum{Value: rhs}) + assert.NoError(t, err) + assertScalarEquals(t, s, actual.(*compute.ScalarDatum).Value) + } +} + +func assertBinopErr(t *testing.T, fn binaryFunc, left, right arrow.Array, expectedMsg string) { + _, err := fn(&compute.ArrayDatum{left.Data()}, &compute.ArrayDatum{Value: right.Data()}) + assert.ErrorIs(t, err, arrow.ErrInvalid) + assert.ErrorContains(t, err, expectedMsg) +} + +type BinaryFuncTestSuite struct { + suite.Suite + + mem *memory.CheckedAllocator + ctx context.Context +} + +func (b *BinaryFuncTestSuite) SetupTest() { + b.mem = memory.NewCheckedAllocator(memory.DefaultAllocator) + b.ctx = compute.WithAllocator(context.TODO(), b.mem) +} + +func (b *BinaryFuncTestSuite) TearDownTest() { + b.mem.AssertSize(b.T(), 0) +} + +type Float16BinaryFuncTestSuite struct { + BinaryFuncTestSuite +} + +func (b *Float16BinaryFuncTestSuite) assertBinopErr(fn binaryFunc, lhs, rhs string) { + left, _, _ := array.FromJSON(b.mem, arrow.FixedWidthTypes.Float16, strings.NewReader(lhs), array.WithUseNumber()) + defer left.Release() + right, _, _ := array.FromJSON(b.mem, arrow.FixedWidthTypes.Float16, strings.NewReader(rhs), array.WithUseNumber()) + defer right.Release() + + _, err := fn(&compute.ArrayDatum{left.Data()}, &compute.ArrayDatum{right.Data()}) + b.ErrorIs(err, arrow.ErrNotImplemented) +} + +func (b *Float16BinaryFuncTestSuite) TestAdd() { + for _, overflow := range []bool{false, true} { + b.Run(fmt.Sprintf("no_overflow_check=%t", overflow), func() { + opts := compute.ArithmeticOptions{NoCheckOverflow: overflow} + b.assertBinopErr(func(left, right compute.Datum) (compute.Datum, error) { + return compute.Add(b.ctx, opts, left, right) + }, `[1.5]`, `[1.5]`) + }) + } +} + +func (b *Float16BinaryFuncTestSuite) TestSub() { + for _, overflow := range []bool{false, true} { + b.Run(fmt.Sprintf("no_overflow_check=%t", overflow), func() { + opts := compute.ArithmeticOptions{NoCheckOverflow: overflow} + b.assertBinopErr(func(left, right compute.Datum) (compute.Datum, error) { + return compute.Subtract(b.ctx, opts, left, right) + }, `[1.5]`, `[1.5]`) + }) + } +} + +type BinaryArithmeticSuite[T exec.NumericTypes] struct { + BinaryFuncTestSuite + + opts compute.ArithmeticOptions + min, max T +} + +func (BinaryArithmeticSuite[T]) DataType() arrow.DataType { + return exec.GetDataType[T]() +} + +func (b *BinaryArithmeticSuite[T]) SetupTest() { + b.BinaryFuncTestSuite.SetupTest() + b.opts.NoCheckOverflow = false +} + +func (b *BinaryArithmeticSuite[T]) makeNullScalar() scalar.Scalar { + return scalar.MakeNullScalar(b.DataType()) +} + +func (b *BinaryArithmeticSuite[T]) makeScalar(val T) scalar.Scalar { + return scalar.MakeScalar(val) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopScalars(fn binaryArithmeticFunc, lhs, rhs T, expected T) { + left, right := b.makeScalar(lhs), b.makeScalar(rhs) + exp := b.makeScalar(expected) + + actual, err := fn(b.ctx, b.opts, &compute.ScalarDatum{Value: left}, &compute.ScalarDatum{Value: right}) + b.NoError(err) + sc := actual.(*compute.ScalarDatum).Value + + assertScalarEquals(b.T(), exp, sc) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopScalarValArr(fn binaryArithmeticFunc, lhs T, rhs, expected string) { + left := b.makeScalar(lhs) + b.assertBinopScalarArr(fn, left, rhs, expected) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopScalarArr(fn binaryArithmeticFunc, lhs scalar.Scalar, rhs, expected string) { + right, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(rhs)) + defer right.Release() + exp, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(expected)) + defer exp.Release() + + actual, err := fn(b.ctx, b.opts, &compute.ScalarDatum{Value: lhs}, &compute.ArrayDatum{Value: right.Data()}) + b.NoError(err) + defer actual.Release() + assertDatumsEqual(b.T(), &compute.ArrayDatum{Value: exp.Data()}, actual) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopArrScalarVal(fn binaryArithmeticFunc, lhs string, rhs T, expected string) { + right := b.makeScalar(rhs) + b.assertBinopArrScalar(fn, lhs, right, expected) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopArrScalar(fn binaryArithmeticFunc, lhs string, rhs scalar.Scalar, expected string) { + left, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(lhs)) + defer left.Release() + exp, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(expected)) + defer exp.Release() + + actual, err := fn(b.ctx, b.opts, &compute.ArrayDatum{Value: left.Data()}, &compute.ScalarDatum{Value: rhs}) + b.NoError(err) + defer actual.Release() + assertDatumsEqual(b.T(), &compute.ArrayDatum{Value: exp.Data()}, actual) +} + +func (b *BinaryArithmeticSuite[T]) assertBinop(fn binaryArithmeticFunc, lhs, rhs, expected string) { + left, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(lhs)) + defer left.Release() + right, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(rhs)) + defer right.Release() + exp, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(expected)) + defer exp.Release() + + assertBinop(b.T(), func(left, right compute.Datum) (compute.Datum, error) { + return fn(b.ctx, b.opts, left, right) + }, left, right, exp) +} + +func (b *BinaryArithmeticSuite[T]) setOverflowCheck(value bool) { + b.opts.NoCheckOverflow = value +} + +func (b *BinaryArithmeticSuite[T]) assertBinopErr(fn binaryArithmeticFunc, lhs, rhs, expectedMsg string) { + left, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(lhs), array.WithUseNumber()) + defer left.Release() + right, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(rhs), array.WithUseNumber()) + defer right.Release() + + assertBinopErr(b.T(), func(left, right compute.Datum) (compute.Datum, error) { + return fn(b.ctx, b.opts, left, right) + }, left, right, expectedMsg) +} + +func (b *BinaryArithmeticSuite[T]) TestAdd() { + b.Run(b.DataType().String(), func() { + for _, overflow := range []bool{false, true} { + b.Run(fmt.Sprintf("no_overflow_check=%t", overflow), func() { + b.setOverflowCheck(overflow) + + b.assertBinop(compute.Add, `[]`, `[]`, `[]`) + b.assertBinop(compute.Add, `[3, 2, 6]`, `[1, 0, 2]`, `[4, 2, 8]`) + // nulls on one side + b.assertBinop(compute.Add, `[null, 1, null]`, `[3, 4, 5]`, `[null, 5, null]`) + b.assertBinop(compute.Add, `[3, 4, 5]`, `[null, 1, null]`, `[null, 5, null]`) + // nulls on both sides + b.assertBinop(compute.Add, `[null, 1, 2]`, `[3, 4, null]`, `[null, 5, null]`) + // all nulls + b.assertBinop(compute.Add, `[null]`, `[null]`, `[null]`) + + // scalar on the left + b.assertBinopScalarValArr(compute.Add, 3, `[1, 2]`, `[4, 5]`) + b.assertBinopScalarValArr(compute.Add, 3, `[null, 2]`, `[null, 5]`) + b.assertBinopScalarArr(compute.Add, b.makeNullScalar(), `[1, 2]`, `[null, null]`) + b.assertBinopScalarArr(compute.Add, b.makeNullScalar(), `[null, 2]`, `[null, null]`) + // scalar on the right + b.assertBinopArrScalarVal(compute.Add, `[1, 2]`, 3, `[4, 5]`) + b.assertBinopArrScalarVal(compute.Add, `[null, 2]`, 3, `[null, 5]`) + b.assertBinopArrScalar(compute.Add, `[1, 2]`, b.makeNullScalar(), `[null, null]`) + b.assertBinopArrScalar(compute.Add, `[null, 2]`, b.makeNullScalar(), `[null, null]`) + + if !arrow.IsFloating(b.DataType().ID()) && !overflow { + val := fmt.Sprintf("[%v]", b.max) + b.assertBinopErr(compute.Add, val, val, "overflow") + } + }) + } + }) +} + +func (b *BinaryArithmeticSuite[T]) TestSub() { + b.Run(b.DataType().String(), func() { + for _, overflow := range []bool{false, true} { + b.Run(fmt.Sprintf("no_overflow_check=%t", overflow), func() { + b.setOverflowCheck(overflow) + + b.assertBinop(compute.Subtract, `[]`, `[]`, `[]`) + b.assertBinop(compute.Subtract, `[3, 2, 6]`, `[1, 0, 2]`, `[2, 2, 4]`) + // nulls on one side + b.assertBinop(compute.Subtract, `[null, 4, null]`, `[2, 1, 0]`, `[null, 3, null]`) + b.assertBinop(compute.Subtract, `[3, 4, 5]`, `[null, 1, null]`, `[null, 3, null]`) + // nulls on both sides + b.assertBinop(compute.Subtract, `[null, 4, 3]`, `[2, 1, null]`, `[null, 3, null]`) + // all nulls + b.assertBinop(compute.Subtract, `[null]`, `[null]`, `[null]`) + + // scalar on the left + b.assertBinopScalarValArr(compute.Subtract, 3, `[1, 2]`, `[2, 1]`) + b.assertBinopScalarValArr(compute.Subtract, 3, `[null, 2]`, `[null, 1]`) + b.assertBinopScalarArr(compute.Subtract, b.makeNullScalar(), `[1, 2]`, `[null, null]`) + b.assertBinopScalarArr(compute.Subtract, b.makeNullScalar(), `[null, 2]`, `[null, null]`) + // scalar on the right + b.assertBinopArrScalarVal(compute.Subtract, `[4, 5]`, 3, `[1, 2]`) + b.assertBinopArrScalarVal(compute.Subtract, `[null, 5]`, 3, `[null, 2]`) + b.assertBinopArrScalar(compute.Subtract, `[1, 2]`, b.makeNullScalar(), `[null, null]`) + b.assertBinopArrScalar(compute.Subtract, `[null, 2]`, b.makeNullScalar(), `[null, null]`) + + if !arrow.IsFloating(b.DataType().ID()) && !overflow { + b.assertBinopErr(compute.Subtract, fmt.Sprintf("[%v]", b.min), fmt.Sprintf("[%v]", b.max), "overflow") + } + }) + } + }) +} + +func TestBinaryArithmetic(t *testing.T) { + suite.Run(t, &BinaryArithmeticSuite[int8]{min: math.MinInt8, max: math.MaxInt8}) + suite.Run(t, &BinaryArithmeticSuite[uint8]{min: 0, max: math.MaxUint8}) + suite.Run(t, &BinaryArithmeticSuite[int16]{min: math.MinInt16, max: math.MaxInt16}) + suite.Run(t, &BinaryArithmeticSuite[uint16]{min: 0, max: math.MaxUint16}) + suite.Run(t, &BinaryArithmeticSuite[int32]{min: math.MinInt32, max: math.MaxInt32}) + suite.Run(t, &BinaryArithmeticSuite[uint32]{min: 0, max: math.MaxUint32}) + suite.Run(t, &BinaryArithmeticSuite[int64]{min: math.MinInt64, max: math.MaxInt64}) + suite.Run(t, &BinaryArithmeticSuite[uint64]{min: 0, max: math.MaxUint64}) + suite.Run(t, &BinaryArithmeticSuite[float32]{min: -math.MaxFloat32, max: math.MaxFloat32}) + suite.Run(t, &BinaryArithmeticSuite[float64]{min: -math.MaxFloat64, max: math.MaxFloat64}) + suite.Run(t, new(Float16BinaryFuncTestSuite)) +} + +func TestBinaryArithmeticDispatchBest(t *testing.T) { + for _, name := range []string{"add", "sub"} { + for _, suffix := range []string{"", "_unchecked"} { + name += suffix + t.Run(name, func(t *testing.T) { + + tests := []struct { + left, right arrow.DataType + expected arrow.DataType + }{ + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32}, + {arrow.PrimitiveTypes.Int32, arrow.Null, arrow.PrimitiveTypes.Int32}, + {arrow.Null, arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int8, arrow.PrimitiveTypes.Int32}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int16, arrow.PrimitiveTypes.Int32}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int32}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Int64}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Uint8, arrow.PrimitiveTypes.Int32}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Uint16, arrow.PrimitiveTypes.Int32}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Uint32, arrow.PrimitiveTypes.Int64}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Uint64, arrow.PrimitiveTypes.Int64}, + {arrow.PrimitiveTypes.Uint8, arrow.PrimitiveTypes.Uint8, arrow.PrimitiveTypes.Uint8}, + {arrow.PrimitiveTypes.Uint8, arrow.PrimitiveTypes.Uint16, arrow.PrimitiveTypes.Uint16}, + {arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Float32}, + {arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Float32}, + {arrow.PrimitiveTypes.Float64, arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float64}, + {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int8, ValueType: arrow.PrimitiveTypes.Float64}, + arrow.PrimitiveTypes.Float64, arrow.PrimitiveTypes.Float64}, + {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int8, ValueType: arrow.PrimitiveTypes.Float64}, + arrow.PrimitiveTypes.Int16, arrow.PrimitiveTypes.Float64}, + } + + for _, tt := range tests { + CheckDispatchBest(t, name, []arrow.DataType{tt.left, tt.right}, []arrow.DataType{tt.expected, tt.expected}) + } + }) + } + } +} + +const seed = 0x94378165 + +type binaryOp = func(ctx context.Context, left, right compute.Datum) (compute.Datum, error) + +func Add(ctx context.Context, left, right compute.Datum) (compute.Datum, error) { + var opts compute.ArithmeticOptions + return compute.Add(ctx, opts, left, right) +} + +func Subtract(ctx context.Context, left, right compute.Datum) (compute.Datum, error) { + var opts compute.ArithmeticOptions + return compute.Subtract(ctx, opts, left, right) +} + +func AddUnchecked(ctx context.Context, left, right compute.Datum) (compute.Datum, error) { + opts := compute.ArithmeticOptions{NoCheckOverflow: true} + return compute.Add(ctx, opts, left, right) +} + +func SubtractUnchecked(ctx context.Context, left, right compute.Datum) (compute.Datum, error) { + opts := compute.ArithmeticOptions{NoCheckOverflow: true} + return compute.Subtract(ctx, opts, left, right) +} + +func arrayScalarKernel(b *testing.B, sz int, nullProp float64, op binaryOp, dt arrow.DataType) { + b.Run("array scalar", func(b *testing.B) { + var ( + mem = memory.NewCheckedAllocator(memory.DefaultAllocator) + arraySize = int64(sz / dt.(arrow.FixedWidthDataType).Bytes()) + min int64 = 6 + max = min + 15 + sc, _ = scalar.MakeScalarParam(6, dt) + rhs compute.Datum = &compute.ScalarDatum{Value: sc} + rng = gen.NewRandomArrayGenerator(seed, mem) + ) + + lhs := rng.Numeric(dt.ID(), arraySize, min, max, nullProp) + b.Cleanup(func() { + lhs.Release() + }) + + var ( + res compute.Datum + err error + ctx = context.Background() + left = &compute.ArrayDatum{Value: lhs.Data()} + ) + + b.SetBytes(arraySize) + b.ResetTimer() + for n := 0; n < b.N; n++ { + res, err = op(ctx, left, rhs) + b.StopTimer() + if err != nil { + b.Fatal(err) + } + res.Release() + b.StartTimer() + } + }) +} + +func arrayArrayKernel(b *testing.B, sz int, nullProp float64, op binaryOp, dt arrow.DataType) { + b.Run("array array", func(b *testing.B) { + var ( + mem = memory.NewCheckedAllocator(memory.DefaultAllocator) + arraySize = int64(sz / dt.(arrow.FixedWidthDataType).Bytes()) + rmin int64 = 1 + rmax = rmin + 6 // 7 + lmin = rmax + 1 // 8 + lmax = lmin + 6 // 14 + rng = gen.NewRandomArrayGenerator(seed, mem) + ) + + lhs := rng.Numeric(dt.ID(), arraySize, lmin, lmax, nullProp) + rhs := rng.Numeric(dt.ID(), arraySize, rmin, rmax, nullProp) + b.Cleanup(func() { + lhs.Release() + rhs.Release() + }) + var ( + res compute.Datum + err error + ctx = context.Background() + left = &compute.ArrayDatum{Value: lhs.Data()} + right = &compute.ArrayDatum{Value: rhs.Data()} + ) + + b.SetBytes(arraySize) + b.ResetTimer() + for n := 0; n < b.N; n++ { + res, err = op(ctx, left, right) + b.StopTimer() + if err != nil { + b.Fatal(err) + } + res.Release() + b.StartTimer() + } + }) +} + +func BenchmarkScalarArithmetic(b *testing.B) { + args := []struct { + sz int + nullProb float64 + }{ + {CpuCacheSizes[2], 0}, + {CpuCacheSizes[2], 0.5}, + {CpuCacheSizes[2], 1}, + } + + testfns := []struct { + name string + op binaryOp + }{ + {"Add", Add}, + {"AddUnchecked", AddUnchecked}, + {"Subtract", Subtract}, + {"SubtractUnchecked", SubtractUnchecked}, + } + + for _, dt := range numericTypes { + b.Run(dt.String(), func(b *testing.B) { + for _, benchArgs := range args { + b.Run(fmt.Sprintf("sz=%d/nullprob=%.2f", benchArgs.sz, benchArgs.nullProb), func(b *testing.B) { + for _, tfn := range testfns { + b.Run(tfn.name, func(b *testing.B) { + arrayArrayKernel(b, benchArgs.sz, benchArgs.nullProb, tfn.op, dt) + arrayScalarKernel(b, benchArgs.sz, benchArgs.nullProb, tfn.op, dt) + }) + } + }) + } + }) + } +} diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go index c8f07e23aef..cb5c4f8a758 100644 --- a/go/arrow/compute/cast_test.go +++ b/go/arrow/compute/cast_test.go @@ -34,7 +34,6 @@ import ( "github.com/apache/arrow/go/v10/arrow/internal/testing/types" "github.com/apache/arrow/go/v10/arrow/memory" "github.com/apache/arrow/go/v10/arrow/scalar" - "github.com/klauspost/cpuid/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -2732,26 +2731,6 @@ func TestCasts(t *testing.T) { const rngseed = 0x94378165 -var ( - CpuCacheSizes = [...]int{ // defaults - 32 * 1024, // level 1: 32K - 256 * 1024, // level 2: 256K - 3072 * 1024, // level 3: 3M - } -) - -func init() { - if cpuid.CPU.Cache.L1D != -1 { - CpuCacheSizes[0] = cpuid.CPU.Cache.L1D - } - if cpuid.CPU.Cache.L2 != -1 { - CpuCacheSizes[1] = cpuid.CPU.Cache.L2 - } - if cpuid.CPU.Cache.L3 != -1 { - CpuCacheSizes[2] = cpuid.CPU.Cache.L3 - } -} - func benchmarkNumericCast(b *testing.B, fromType, toType arrow.DataType, opts compute.CastOptions, size, min, max int64, nullprob float64) { rng := gen.NewRandomArrayGenerator(rngseed, memory.DefaultAllocator) arr := rng.Numeric(fromType.ID(), size, min, max, nullprob) diff --git a/go/arrow/compute/exec.go b/go/arrow/compute/exec.go index 3709424b9e4..b7f4962806c 100644 --- a/go/arrow/compute/exec.go +++ b/go/arrow/compute/exec.go @@ -99,6 +99,17 @@ func execInternal(ctx context.Context, fn Function, opts FunctionOptions, passed return } + // cast arguments if necessary + for i, arg := range args { + if !arrow.TypeEqual(inTypes[i], arg.(ArrayLikeDatum).Type()) { + args[i], err = CastDatum(ctx, arg, SafeCastOptions(inTypes[i])) + if err != nil { + return nil, err + } + defer args[i].Release() + } + } + kctx := &exec.KernelCtx{Ctx: ctx, Kernel: k} init := k.GetInitFn() kinitArgs := exec.KernelInitArgs{Kernel: k, Inputs: inTypes, Options: opts} diff --git a/go/arrow/compute/executor.go b/go/arrow/compute/executor.go index 8098f2f8edd..f51c59deaf0 100644 --- a/go/arrow/compute/executor.go +++ b/go/arrow/compute/executor.go @@ -242,7 +242,7 @@ func propagateNulls(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ArraySp } var ( - arrsWithNulls = make([]*exec.ArraySpan, 0) + arrsWithNulls = make([]*exec.ArraySpan, 0, len(batch.Values)) isAllNull bool prealloc bool = out.Buffers[0].Buf != nil ) @@ -596,6 +596,7 @@ func (s *scalarExecutor) executeSpans(data chan<- Datum) (err error) { resultOffset = nextOffset } if err != nil { + prealloc.Release() return } diff --git a/go/arrow/compute/expression.go b/go/arrow/compute/expression.go index 644de5cf5c9..aa6e3661afa 100644 --- a/go/arrow/compute/expression.go +++ b/go/arrow/compute/expression.go @@ -485,7 +485,7 @@ const ( ) type ArithmeticOptions struct { - CheckOverflow bool `compute:"check_overflow"` + NoCheckOverflow bool `compute:"check_overflow"` } func (ArithmeticOptions) TypeName() string { return "ArithmeticOptions" } diff --git a/go/arrow/compute/functions_test.go b/go/arrow/compute/functions_test.go index 78dbd8be5e4..1f167f0232c 100644 --- a/go/arrow/compute/functions_test.go +++ b/go/arrow/compute/functions_test.go @@ -19,8 +19,10 @@ package compute_test import ( "testing" + "github.com/apache/arrow/go/v10/arrow" "github.com/apache/arrow/go/v10/arrow/compute" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestArityBasics(t *testing.T) { @@ -44,3 +46,22 @@ func TestArityBasics(t *testing.T) { assert.Equal(t, 2, varargs.NArgs) assert.True(t, varargs.IsVarArgs) } + +func CheckDispatchBest(t *testing.T, funcName string, originalTypes, expected []arrow.DataType) { + fn, exists := compute.GetFunctionRegistry().GetFunction(funcName) + require.True(t, exists) + + vals := make([]arrow.DataType, len(originalTypes)) + copy(vals, originalTypes) + + actualKernel, err := fn.DispatchBest(vals...) + require.NoError(t, err) + expKernel, err := fn.DispatchExact(expected...) + require.NoError(t, err) + + assert.Same(t, expKernel, actualKernel) + assert.Equal(t, len(expected), len(vals)) + for i, v := range vals { + assert.True(t, arrow.TypeEqual(v, expected[i]), v.String(), expected[i].String()) + } +} diff --git a/go/arrow/compute/internal/exec/span.go b/go/arrow/compute/internal/exec/span.go index ca6caf436b9..1e8a719d347 100644 --- a/go/arrow/compute/internal/exec/span.go +++ b/go/arrow/compute/internal/exec/span.go @@ -86,6 +86,21 @@ type ArraySpan struct { Children []ArraySpan } +// if an error is encountered, call Release on a preallocated span +// to ensure it releases any self-allocated buffers, it will +// not call release on buffers it doesn't own (SelfAlloc != true) +func (a *ArraySpan) Release() { + for _, c := range a.Children { + c.Release() + } + + for _, b := range a.Buffers { + if b.SelfAlloc { + b.Owner.Release() + } + } +} + func (a *ArraySpan) MayHaveNulls() bool { return atomic.LoadInt64(&a.Nulls) != 0 && a.Buffers[0].Buf != nil } @@ -114,7 +129,7 @@ func (a *ArraySpan) NumBuffers() int { return getNumBuffers(a.Type) } // MakeData generates an arrow.ArrayData object for this ArraySpan, // properly updating the buffer ref count if necessary. func (a *ArraySpan) MakeData() arrow.ArrayData { - bufs := make([]*memory.Buffer, a.NumBuffers()) + var bufs [3]*memory.Buffer for i := range bufs { b := a.GetBuffer(i) bufs[i] = b @@ -155,7 +170,7 @@ func (a *ArraySpan) MakeData() arrow.ArrayData { } if dt.ID() == arrow.DICTIONARY { - result := array.NewData(a.Type, length, bufs, nil, nulls, off) + result := array.NewData(a.Type, length, bufs[:a.NumBuffers()], nil, nulls, off) dict := a.Dictionary().MakeData() defer dict.Release() result.SetDictionary(dict) @@ -173,7 +188,7 @@ func (a *ArraySpan) MakeData() arrow.ArrayData { children[i] = d } } - return array.NewData(a.Type, length, bufs, children, nulls, off) + return array.NewData(a.Type, length, bufs[:a.NumBuffers()], children, nulls, off) } // MakeArray is a convenience function for calling array.MakeFromData(a.MakeData()) @@ -186,14 +201,24 @@ func (a *ArraySpan) MakeArray() arrow.Array { // SetSlice updates the offset and length of this ArraySpan to refer to // a specific slice of the underlying buffers. func (a *ArraySpan) SetSlice(off, length int64) { - a.Offset, a.Len = off, length + if off == a.Offset && length == a.Len { + // don't modify the nulls if the slice is the entire span + return + } + if a.Type.ID() != arrow.NULL { if a.Nulls != 0 { - a.Nulls = array.UnknownNullCount + if a.Nulls == a.Len { + a.Nulls = length + } else { + a.Nulls = array.UnknownNullCount + } } } else { - a.Nulls = a.Len + a.Nulls = length } + + a.Offset, a.Len = off, length } // GetBuffer returns the buffer for the requested index. If this buffer diff --git a/go/arrow/compute/internal/exec/utils.go b/go/arrow/compute/internal/exec/utils.go index 876e3f38ece..57fe3183c6e 100644 --- a/go/arrow/compute/internal/exec/utils.go +++ b/go/arrow/compute/internal/exec/utils.go @@ -135,6 +135,13 @@ func Min[T constraints.Ordered](a, b T) T { return b } +func Max[T constraints.Ordered](a, b T) T { + if a > b { + return a + } + return b +} + // OptionsInit should be used in the case where a KernelState is simply // represented with a specific type by value (instead of pointer). // This will initialize the KernelState as a value-copied instance of @@ -165,13 +172,26 @@ var typMap = map[reflect.Type]arrow.DataType{ reflect.TypeOf(arrow.Date32(0)): arrow.FixedWidthTypes.Date32, reflect.TypeOf(arrow.Date64(0)): arrow.FixedWidthTypes.Date64, reflect.TypeOf(true): arrow.FixedWidthTypes.Boolean, + reflect.TypeOf(float16.Num{}): arrow.FixedWidthTypes.Float16, } -func GetDataType[T NumericTypes | bool | string]() arrow.DataType { +// GetDataType returns the appropriate arrow.DataType for the given type T +// only for non-parametric types. This uses a map and reflection internally +// so don't call this in a tight loop, instead call this once and then use +// a closure with the result. +func GetDataType[T NumericTypes | bool | string | float16.Num]() arrow.DataType { var z T return typMap[reflect.TypeOf(z)] } +// GetType returns the appropriate arrow.Type type T, only for non-parameteric +// types. This uses a map and reflection internally so don't call this in +// a tight loop, instead call it once and then use a closure with the result. +func GetType[T NumericTypes | bool | string]() arrow.Type { + var z T + return typMap[reflect.TypeOf(z)].ID() +} + type arrayBuilder[T NumericTypes] interface { array.Builder Append(T) diff --git a/go/arrow/compute/internal/kernels/Makefile b/go/arrow/compute/internal/kernels/Makefile index 752c38d412d..96238cc9a12 100644 --- a/go/arrow/compute/internal/kernels/Makefile +++ b/go/arrow/compute/internal/kernels/Makefile @@ -36,7 +36,8 @@ ALL_SOURCES := $(shell find . -path ./_lib -prune -o -name '*.go' -name '*.s' -n .PHONEY: assembly INTEL_SOURCES := \ - cast_numeric_avx2_amd64.s cast_numeric_sse4_amd64.s constant_factor_avx2_amd64.s constant_factor_sse4_amd64.s + cast_numeric_avx2_amd64.s cast_numeric_sse4_amd64.s constant_factor_avx2_amd64.s \ + constant_factor_sse4_amd64.s base_arithmetic_avx2_amd64.s base_arithmetic_sse4_amd64.s # # ARROW-15336: DO NOT add the assembly target for Arm64 (ARM_SOURCES) until c2goasm added the Arm64 support. @@ -55,6 +56,15 @@ _lib/cast_numeric_sse4_amd64.s: _lib/cast_numeric.cc _lib/cast_numeric_neon.s: _lib/cast_numeric.cc $(CXX) -std=c++17 -S $(C_FLAGS_NEON) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ +_lib/base_arithmetic_avx2_amd64.s: _lib/base_arithmetic.cc + $(CXX) -std=c++17 -S $(C_FLAGS) $(ASM_FLAGS_AVX2) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ + +_lib/base_arithmetic_sse4_amd64.s: _lib/base_arithmetic.cc + $(CXX) -std=c++17 -S $(C_FLAGS) $(ASM_FLAGS_SSE4) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ + +_lib/base_arithmetic_neon.s: _lib/base_arithmetic.cc + $(CXX) -std=c++17 -S $(C_FLAGS_NEON) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ + _lib/constant_factor_avx2_amd64.s: _lib/constant_factor.c $(CC) -S $(C_FLAGS) $(ASM_FLAGS_AVX2) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ @@ -76,6 +86,12 @@ constant_factor_avx2_amd64.s: _lib/constant_factor_avx2_amd64.s constant_factor_sse4_amd64.s: _lib/constant_factor_sse4_amd64.s $(C2GOASM) -a -f $^ $@ +base_arithmetic_avx2_amd64.s: _lib/base_arithmetic_avx2_amd64.s + $(C2GOASM) -a -f $^ $@ + +base_arithmetic_sse4_amd64.s: _lib/base_arithmetic_sse4_amd64.s + $(C2GOASM) -a -f $^ $@ + clean: rm -f $(INTEL_SOURCES) rm -f $(addprefix _lib/,$(INTEL_SOURCES)) diff --git a/go/arrow/compute/internal/kernels/_lib/base_arithmetic.cc b/go/arrow/compute/internal/kernels/_lib/base_arithmetic.cc new file mode 100644 index 00000000000..dc2234bfb35 --- /dev/null +++ b/go/arrow/compute/internal/kernels/_lib/base_arithmetic.cc @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "types.h" +#include "vendored/safe-math.h" + +// Corresponds to equivalent ArithmeticOp enum in base_arithmetic.go +// for passing across which operation to perform. This allows simpler +// implementation at the cost of having to pass the extra int8 and +// perform a switch. +// +// In cases of small arrays, this is completely negligible. In cases +// of large arrays, the time saved by using SIMD here is significantly +// worth the cost. +enum class optype : int8_t { + ADD, + SUB, + + // this impl doesn't actually perform any overflow checks as we need + // to only run overflow checks on non-null entries + ADD_CHECKED, + SUB_CHECKED, +}; + +struct Add { + template + static constexpr T Call(Arg0 left, Arg1 right) { + if constexpr (is_arithmetic_v) + return left + right; + } +}; + +struct Sub { + template + static constexpr T Call(Arg0 left, Arg1 right) { + if constexpr (is_arithmetic_v) + return left - right; + } +}; + +struct AddChecked { + template + static constexpr T Call(Arg0 left, Arg1 right) { + static_assert(is_same::value && is_same::value, ""); + if constexpr(is_arithmetic_v) { + return left + right; + } + } +}; + + +struct SubChecked { + template + static constexpr T Call(Arg0 left, Arg1 right) { + static_assert(is_same::value && is_same::value, ""); + if constexpr(is_arithmetic_v) { + return left - right; + } + } +}; + +template +struct arithmetic_op_arr_arr_impl { + static inline void exec(const void* in_left, const void* in_right, void* out, const int len) { + const T* left = reinterpret_cast(in_left); + const T* right = reinterpret_cast(in_right); + T* output = reinterpret_cast(out); + + for (int i = 0; i < len; ++i) { + output[i] = Op::template Call(left[i], right[i]); + } + } +}; + +template +struct arithmetic_op_arr_scalar_impl { + static inline void exec(const void* in_left, const void* scalar_right, void* out, const int len) { + const T* left = reinterpret_cast(in_left); + const T right = *reinterpret_cast(scalar_right); + T* output = reinterpret_cast(out); + + for (int i = 0; i < len; ++i) { + output[i] = Op::template Call(left[i], right); + } + } +}; + +template +struct arithmetic_op_scalar_arr_impl { + static inline void exec(const void* scalar_left, const void* in_right, void* out, const int len) { + const T left = *reinterpret_cast(scalar_left); + const T* right = reinterpret_cast(in_right); + T* output = reinterpret_cast(out); + + for (int i = 0; i < len; ++i) { + output[i] = Op::template Call(left, right[i]); + } + } +}; + + +template typename Impl> +static inline void arithmetic_op(const int type, const void* in_left, const void* in_right, void* output, const int len) { + const auto intype = static_cast(type); + + switch (intype) { + case arrtype::UINT8: + return Impl::exec(in_left, in_right, output, len); + case arrtype::INT8: + return Impl::exec(in_left, in_right, output, len); + case arrtype::UINT16: + return Impl::exec(in_left, in_right, output, len); + case arrtype::INT16: + return Impl::exec(in_left, in_right, output, len); + case arrtype::UINT32: + return Impl::exec(in_left, in_right, output, len); + case arrtype::INT32: + return Impl::exec(in_left, in_right, output, len); + case arrtype::UINT64: + return Impl::exec(in_left, in_right, output, len); + case arrtype::INT64: + return Impl::exec(in_left, in_right, output, len); + case arrtype::FLOAT32: + return Impl::exec(in_left, in_right, output, len); + case arrtype::FLOAT64: + return Impl::exec(in_left, in_right, output, len); + default: + break; + } +} + +template