From 691db4772896ae7c77e5737eb28a3bb2c16c6cbe Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 27 Sep 2022 17:34:47 -0400 Subject: [PATCH 1/9] ARROW-17871: [Go] Initial binary arithmetic --- go/arrow/compute/arithmetic.go | 141 + go/arrow/compute/arithmetic_test.go | 229 + go/arrow/compute/internal/exec/utils.go | 12 + go/arrow/compute/internal/kernels/Makefile | 18 +- .../internal/kernels/_lib/base_arithmetic.cc | 243 + .../kernels/_lib/base_arithmetic_avx2_amd64.s | 12671 ++++++++++++++ .../kernels/_lib/base_arithmetic_sse4_amd64.s | 13530 +++++++++++++++ .../internal/kernels/_lib/cast_numeric.cc | 18 +- .../compute/internal/kernels/_lib/safe-math.h | 1072 ++ .../compute/internal/kernels/_lib/types.h | 477 + .../internal/kernels/base_arithmetic.go | 141 + .../internal/kernels/base_arithmetic_amd64.go | 83 + .../kernels/base_arithmetic_avx2_amd64.go | 46 + .../kernels/base_arithmetic_avx2_amd64.s | 12857 ++++++++++++++ .../kernels/base_arithmetic_sse4_amd64.go | 46 + .../kernels/base_arithmetic_sse4_amd64.s | 13806 ++++++++++++++++ .../kernels/basic_arithmetic_noasm.go | 32 + go/arrow/compute/internal/kernels/helpers.go | 68 + .../internal/kernels/scalar_arithmetic.go | 45 + go/arrow/compute/internal/kernels/types.go | 1 - go/arrow/compute/registry.go | 1 + go/arrow/compute/utils.go | 159 + go/arrow/datatype.go | 10 + 23 files changed, 55687 insertions(+), 19 deletions(-) create mode 100644 go/arrow/compute/arithmetic.go create mode 100644 go/arrow/compute/arithmetic_test.go create mode 100644 go/arrow/compute/internal/kernels/_lib/base_arithmetic.cc create mode 100644 go/arrow/compute/internal/kernels/_lib/base_arithmetic_avx2_amd64.s create mode 100644 go/arrow/compute/internal/kernels/_lib/base_arithmetic_sse4_amd64.s create mode 100644 go/arrow/compute/internal/kernels/_lib/safe-math.h create mode 100644 go/arrow/compute/internal/kernels/_lib/types.h create mode 100644 go/arrow/compute/internal/kernels/base_arithmetic.go create mode 100644 go/arrow/compute/internal/kernels/base_arithmetic_amd64.go create mode 100644 go/arrow/compute/internal/kernels/base_arithmetic_avx2_amd64.go create mode 100644 go/arrow/compute/internal/kernels/base_arithmetic_avx2_amd64.s create mode 100644 go/arrow/compute/internal/kernels/base_arithmetic_sse4_amd64.go create mode 100644 go/arrow/compute/internal/kernels/base_arithmetic_sse4_amd64.s create mode 100644 go/arrow/compute/internal/kernels/basic_arithmetic_noasm.go create mode 100644 go/arrow/compute/internal/kernels/scalar_arithmetic.go diff --git a/go/arrow/compute/arithmetic.go b/go/arrow/compute/arithmetic.go new file mode 100644 index 00000000000..113b70b391a --- /dev/null +++ b/go/arrow/compute/arithmetic.go @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute + +import ( + "context" + "fmt" + "strings" + + "github.com/apache/arrow/go/v10/arrow" + "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" + "github.com/apache/arrow/go/v10/arrow/compute/internal/kernels" +) + +type arithmeticFunction struct { + ScalarFunction +} + +func (fn *arithmeticFunction) checkDecimals(vals ...arrow.DataType) error { + if !hasDecimal(vals...) { + return nil + } + + if len(vals) != 2 { + return nil + } + + op := fn.name[:strings.Index(fn.name, "_")] + switch op { + case "add", "subtract": + return castBinaryDecimalArgs(decPromoteAdd, vals...) + case "multiply": + return castBinaryDecimalArgs(decPromoteMultiply, vals...) + case "divide": + return castBinaryDecimalArgs(decPromoteDivide, vals...) + default: + return fmt.Errorf("%w: invalid decimal function: %s", arrow.ErrInvalid, fn.name) + } +} + +func (fn *arithmeticFunction) DispatchBest(vals ...arrow.DataType) (exec.Kernel, error) { + if err := fn.checkArity(len(vals)); err != nil { + return nil, err + } + + if err := fn.checkDecimals(vals...); err != nil { + return nil, err + } + + if kn, err := fn.DispatchExact(vals...); err == nil { + return kn, nil + } + + ensureDictionaryDecoded(vals...) + + // only promote types for binary funcs + if len(vals) == 2 { + replaceNullWithOtherType(vals...) + if unit, istime := commonTemporalResolution(vals...); istime { + replaceTemporalTypes(unit, vals...) + } else { + if dt := commonNumeric(vals...); dt != nil { + replaceTypes(dt, vals...) + } + } + } + + return fn.DispatchExact(vals...) +} + +var ( + addDoc FunctionDoc +) + +func RegisterScalarArithmetic(reg FunctionRegistry) { + addFn := &arithmeticFunction{*NewScalarFunction("add", Binary(), addDoc)} + for _, k := range kernels.GetArithmeticKernels(kernels.OpAdd) { + if err := addFn.AddKernel(k); err != nil { + panic(err) + } + } + + reg.AddFunction(addFn, false) + + addCheckedFn := &arithmeticFunction{*NewScalarFunction("add_checked", Binary(), addDoc)} + for _, k := range kernels.GetArithmeticKernels(kernels.OpAddChecked) { + if err := addCheckedFn.AddKernel(k); err != nil { + panic(err) + } + } + + reg.AddFunction(addCheckedFn, false) + + subFn := &arithmeticFunction{*NewScalarFunction("sub", Binary(), addDoc)} + for _, k := range kernels.GetArithmeticKernels(kernels.OpSub) { + if err := subFn.AddKernel(k); err != nil { + panic(err) + } + } + + reg.AddFunction(subFn, false) + + subCheckedFn := &arithmeticFunction{*NewScalarFunction("sub_checked", Binary(), addDoc)} + for _, k := range kernels.GetArithmeticKernels(kernels.OpSubChecked) { + if err := subCheckedFn.AddKernel(k); err != nil { + panic(err) + } + } + + reg.AddFunction(subCheckedFn, false) +} + +func Add(ctx context.Context, opts ArithmeticOptions, left, right Datum) (Datum, error) { + fn := "add" + if opts.CheckOverflow { + fn = "add_checked" + } + return CallFunction(ctx, fn, nil, left, right) +} + +func Subtract(ctx context.Context, opts ArithmeticOptions, left, right Datum) (Datum, error) { + fn := "sub" + if opts.CheckOverflow { + fn = "sub_checked" + } + return CallFunction(ctx, fn, nil, left, right) +} diff --git a/go/arrow/compute/arithmetic_test.go b/go/arrow/compute/arithmetic_test.go new file mode 100644 index 00000000000..527e63bc7af --- /dev/null +++ b/go/arrow/compute/arithmetic_test.go @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compute_test + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/apache/arrow/go/v10/arrow" + "github.com/apache/arrow/go/v10/arrow/array" + "github.com/apache/arrow/go/v10/arrow/compute" + "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" + "github.com/apache/arrow/go/v10/arrow/memory" + "github.com/apache/arrow/go/v10/arrow/scalar" + "github.com/stretchr/testify/suite" +) + +type binaryFunc = func(context.Context, compute.ArithmeticOptions, compute.Datum, compute.Datum) (compute.Datum, error) + +type BinaryArithmeticSuite[T exec.NumericTypes] struct { + suite.Suite + + mem *memory.CheckedAllocator + opts compute.ArithmeticOptions + ctx context.Context +} + +func (BinaryArithmeticSuite[T]) DataType() arrow.DataType { + return exec.GetDataType[T]() +} + +func (b *BinaryArithmeticSuite[T]) SetupTest() { + b.mem = memory.NewCheckedAllocator(memory.DefaultAllocator) + b.opts.CheckOverflow = false + b.ctx = compute.WithAllocator(context.TODO(), b.mem) +} + +func (b *BinaryArithmeticSuite[T]) TearDownTest() { + b.mem.AssertSize(b.T(), 0) +} + +func (b *BinaryArithmeticSuite[T]) makeNullScalar() scalar.Scalar { + return scalar.MakeNullScalar(b.DataType()) +} + +func (b *BinaryArithmeticSuite[T]) makeScalar(val T) scalar.Scalar { + return scalar.MakeScalar(val) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopScalars(fn binaryFunc, lhs, rhs T, expected T) { + left, right := b.makeScalar(lhs), b.makeScalar(rhs) + exp := b.makeScalar(expected) + + actual, err := fn(b.ctx, b.opts, &compute.ScalarDatum{Value: left}, &compute.ScalarDatum{Value: right}) + b.NoError(err) + sc := actual.(*compute.ScalarDatum).Value + + b.Truef(scalar.Equals(exp, sc), "expected: %s\ngot: %s", exp, sc) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopScArr(fn binaryFunc, lhs T, rhs, expected string) { + left := b.makeScalar(lhs) + b.assertBinopScalarArr(fn, left, rhs, expected) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopScalarArr(fn binaryFunc, lhs scalar.Scalar, rhs, expected string) { + right, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(rhs)) + defer right.Release() + exp, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(expected)) + defer exp.Release() + + actual, err := fn(b.ctx, b.opts, &compute.ScalarDatum{Value: lhs}, &compute.ArrayDatum{Value: right.Data()}) + b.NoError(err) + defer actual.Release() + assertDatumsEqual(b.T(), &compute.ArrayDatum{Value: exp.Data()}, actual) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopArrSc(fn binaryFunc, lhs string, rhs T, expected string) { + right := b.makeScalar(rhs) + b.assertBinopArrScalar(fn, lhs, right, expected) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopArrScalar(fn binaryFunc, lhs string, rhs scalar.Scalar, expected string) { + left, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(lhs)) + defer left.Release() + exp, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(expected)) + defer exp.Release() + + actual, err := fn(b.ctx, b.opts, &compute.ArrayDatum{Value: left.Data()}, &compute.ScalarDatum{Value: rhs}) + b.NoError(err) + defer actual.Release() + assertDatumsEqual(b.T(), &compute.ArrayDatum{Value: exp.Data()}, actual) +} + +func (b *BinaryArithmeticSuite[T]) assertBinopArrays(fn binaryFunc, lhs, rhs, expected string) { + left, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(lhs)) + defer left.Release() + right, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(rhs)) + defer right.Release() + exp, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(expected)) + defer exp.Release() + + b.assertBinop(fn, left, right, exp) +} + +func (b *BinaryArithmeticSuite[T]) assertBinop(fn binaryFunc, left, right, expected arrow.Array) { + actual, err := fn(b.ctx, b.opts, &compute.ArrayDatum{Value: left.Data()}, &compute.ArrayDatum{Value: right.Data()}) + b.Require().NoError(err) + defer actual.Release() + assertDatumsEqual(b.T(), &compute.ArrayDatum{Value: expected.Data()}, actual) + + // also check (Scalar, Scalar) operations + for i := 0; i < expected.Len(); i++ { + s, err := scalar.GetScalar(expected, i) + b.Require().NoError(err) + lhs, _ := scalar.GetScalar(left, i) + rhs, _ := scalar.GetScalar(right, i) + + actual, err := fn(b.ctx, b.opts, &compute.ScalarDatum{Value: lhs}, &compute.ScalarDatum{Value: rhs}) + b.NoError(err) + b.Truef(scalar.Equals(s, actual.(*compute.ScalarDatum).Value), "expected: %s\ngot: %s", s, actual) + } +} + +func (b *BinaryArithmeticSuite[T]) setOverflowCheck(value bool) { + b.opts.CheckOverflow = value +} + +func (b *BinaryArithmeticSuite[T]) assertBinopErr(fn binaryFunc, lhs, rhs, expectedMsg string) { + left, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(lhs)) + defer left.Release() + right, _, _ := array.FromJSON(b.mem, b.DataType(), strings.NewReader(rhs)) + defer right.Release() + + _, err := fn(b.ctx, b.opts, &compute.ArrayDatum{left.Data()}, &compute.ArrayDatum{Value: right.Data()}) + b.ErrorIs(err, arrow.ErrInvalid) + b.ErrorContains(err, expectedMsg) +} + +func (b *BinaryArithmeticSuite[T]) TestAdd() { + b.Run(b.DataType().String(), func() { + for _, overflow := range []bool{false, true} { + b.Run(fmt.Sprintf("overflow=%t", overflow), func() { + b.setOverflowCheck(overflow) + + b.assertBinopArrays(compute.Add, `[]`, `[]`, `[]`) + b.assertBinopArrays(compute.Add, `[3, 2, 6]`, `[1, 0, 2]`, `[4, 2, 8]`) + // nulls on one side + b.assertBinopArrays(compute.Add, `[null, 1, null]`, `[3, 4, 5]`, `[null, 5, null]`) + b.assertBinopArrays(compute.Add, `[3, 4, 5]`, `[null, 1, null]`, `[null, 5, null]`) + // nulls on both sides + b.assertBinopArrays(compute.Add, `[null, 1, 2]`, `[3, 4, null]`, `[null, 5, null]`) + // all nulls + b.assertBinopArrays(compute.Add, `[null]`, `[null]`, `[null]`) + + // scalar on the left + b.assertBinopScArr(compute.Add, 3, `[1, 2]`, `[4, 5]`) + b.assertBinopScArr(compute.Add, 3, `[null, 2]`, `[null, 5]`) + b.assertBinopScalarArr(compute.Add, b.makeNullScalar(), `[1, 2]`, `[null, null]`) + b.assertBinopScalarArr(compute.Add, b.makeNullScalar(), `[null, 2]`, `[null, null]`) + // scalar on the right + b.assertBinopArrSc(compute.Add, `[1, 2]`, 3, `[4, 5]`) + b.assertBinopArrSc(compute.Add, `[null, 2]`, 3, `[null, 5]`) + b.assertBinopArrScalar(compute.Add, `[1, 2]`, b.makeNullScalar(), `[null, null]`) + b.assertBinopArrScalar(compute.Add, `[null, 2]`, b.makeNullScalar(), `[null, null]`) + }) + } + }) +} + +func (b *BinaryArithmeticSuite[T]) TestSub() { + b.Run(b.DataType().String(), func() { + for _, overflow := range []bool{false, true} { + b.Run(fmt.Sprintf("overflow=%t", overflow), func() { + b.setOverflowCheck(overflow) + + b.assertBinopArrays(compute.Subtract, `[]`, `[]`, `[]`) + b.assertBinopArrays(compute.Subtract, `[3, 2, 6]`, `[1, 0, 2]`, `[2, 2, 4]`) + // nulls on one side + b.assertBinopArrays(compute.Subtract, `[null, 4, null]`, `[2, 1, 0]`, `[null, 3, null]`) + b.assertBinopArrays(compute.Subtract, `[3, 4, 5]`, `[null, 1, null]`, `[null, 3, null]`) + // nulls on both sides + b.assertBinopArrays(compute.Subtract, `[null, 4, 3]`, `[2, 1, null]`, `[null, 3, null]`) + // all nulls + b.assertBinopArrays(compute.Subtract, `[null]`, `[null]`, `[null]`) + + // scalar on the left + b.assertBinopScArr(compute.Subtract, 3, `[1, 2]`, `[2, 1]`) + b.assertBinopScArr(compute.Subtract, 3, `[null, 2]`, `[null, 1]`) + b.assertBinopScalarArr(compute.Subtract, b.makeNullScalar(), `[1, 2]`, `[null, null]`) + b.assertBinopScalarArr(compute.Subtract, b.makeNullScalar(), `[null, 2]`, `[null, null]`) + // scalar on the right + b.assertBinopArrSc(compute.Subtract, `[4, 5]`, 3, `[1, 2]`) + b.assertBinopArrSc(compute.Subtract, `[null, 5]`, 3, `[null, 2]`) + b.assertBinopArrScalar(compute.Subtract, `[1, 2]`, b.makeNullScalar(), `[null, null]`) + b.assertBinopArrScalar(compute.Subtract, `[null, 2]`, b.makeNullScalar(), `[null, null]`) + }) + } + }) +} + +func TestBinaryArithmetic(t *testing.T) { + suite.Run(t, new(BinaryArithmeticSuite[int8])) + suite.Run(t, new(BinaryArithmeticSuite[uint8])) + suite.Run(t, new(BinaryArithmeticSuite[int16])) + suite.Run(t, new(BinaryArithmeticSuite[uint16])) + suite.Run(t, new(BinaryArithmeticSuite[int32])) + suite.Run(t, new(BinaryArithmeticSuite[uint32])) + suite.Run(t, new(BinaryArithmeticSuite[int64])) + suite.Run(t, new(BinaryArithmeticSuite[uint64])) + suite.Run(t, new(BinaryArithmeticSuite[float32])) + suite.Run(t, new(BinaryArithmeticSuite[float64])) +} diff --git a/go/arrow/compute/internal/exec/utils.go b/go/arrow/compute/internal/exec/utils.go index 876e3f38ece..903748a1176 100644 --- a/go/arrow/compute/internal/exec/utils.go +++ b/go/arrow/compute/internal/exec/utils.go @@ -135,6 +135,13 @@ func Min[T constraints.Ordered](a, b T) T { return b } +func Max[T constraints.Ordered](a, b T) T { + if a > b { + return a + } + return b +} + // OptionsInit should be used in the case where a KernelState is simply // represented with a specific type by value (instead of pointer). // This will initialize the KernelState as a value-copied instance of @@ -172,6 +179,11 @@ func GetDataType[T NumericTypes | bool | string]() arrow.DataType { return typMap[reflect.TypeOf(z)] } +func GetType[T NumericTypes | bool | string]() arrow.Type { + var z T + return typMap[reflect.TypeOf(z)].ID() +} + type arrayBuilder[T NumericTypes] interface { array.Builder Append(T) diff --git a/go/arrow/compute/internal/kernels/Makefile b/go/arrow/compute/internal/kernels/Makefile index 752c38d412d..96238cc9a12 100644 --- a/go/arrow/compute/internal/kernels/Makefile +++ b/go/arrow/compute/internal/kernels/Makefile @@ -36,7 +36,8 @@ ALL_SOURCES := $(shell find . -path ./_lib -prune -o -name '*.go' -name '*.s' -n .PHONEY: assembly INTEL_SOURCES := \ - cast_numeric_avx2_amd64.s cast_numeric_sse4_amd64.s constant_factor_avx2_amd64.s constant_factor_sse4_amd64.s + cast_numeric_avx2_amd64.s cast_numeric_sse4_amd64.s constant_factor_avx2_amd64.s \ + constant_factor_sse4_amd64.s base_arithmetic_avx2_amd64.s base_arithmetic_sse4_amd64.s # # ARROW-15336: DO NOT add the assembly target for Arm64 (ARM_SOURCES) until c2goasm added the Arm64 support. @@ -55,6 +56,15 @@ _lib/cast_numeric_sse4_amd64.s: _lib/cast_numeric.cc _lib/cast_numeric_neon.s: _lib/cast_numeric.cc $(CXX) -std=c++17 -S $(C_FLAGS_NEON) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ +_lib/base_arithmetic_avx2_amd64.s: _lib/base_arithmetic.cc + $(CXX) -std=c++17 -S $(C_FLAGS) $(ASM_FLAGS_AVX2) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ + +_lib/base_arithmetic_sse4_amd64.s: _lib/base_arithmetic.cc + $(CXX) -std=c++17 -S $(C_FLAGS) $(ASM_FLAGS_SSE4) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ + +_lib/base_arithmetic_neon.s: _lib/base_arithmetic.cc + $(CXX) -std=c++17 -S $(C_FLAGS_NEON) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ + _lib/constant_factor_avx2_amd64.s: _lib/constant_factor.c $(CC) -S $(C_FLAGS) $(ASM_FLAGS_AVX2) $^ -o $@ ; $(PERL_FIXUP_ROTATE) $@ @@ -76,6 +86,12 @@ constant_factor_avx2_amd64.s: _lib/constant_factor_avx2_amd64.s constant_factor_sse4_amd64.s: _lib/constant_factor_sse4_amd64.s $(C2GOASM) -a -f $^ $@ +base_arithmetic_avx2_amd64.s: _lib/base_arithmetic_avx2_amd64.s + $(C2GOASM) -a -f $^ $@ + +base_arithmetic_sse4_amd64.s: _lib/base_arithmetic_sse4_amd64.s + $(C2GOASM) -a -f $^ $@ + clean: rm -f $(INTEL_SOURCES) rm -f $(addprefix _lib/,$(INTEL_SOURCES)) diff --git a/go/arrow/compute/internal/kernels/_lib/base_arithmetic.cc b/go/arrow/compute/internal/kernels/_lib/base_arithmetic.cc new file mode 100644 index 00000000000..335434702c5 --- /dev/null +++ b/go/arrow/compute/internal/kernels/_lib/base_arithmetic.cc @@ -0,0 +1,243 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "types.h" +#include "safe-math.h" + + // Define functions AddWithOverflow, SubtractWithOverflow, MultiplyWithOverflow +// with the signature `bool(T u, T v, T* out)` where T is an integer type. +// On overflow, these functions return true. Otherwise, false is returned +// and `out` is updated with the result of the operation. + +#define OP_WITH_OVERFLOW(_func_name, _psnip_op, _type, _psnip_type) \ + static inline bool _func_name(_type u, _type v, _type* out) { \ + return !psnip_safe_##_psnip_type##_##_psnip_op(out, u, v); \ + } + +#define OPS_WITH_OVERFLOW(_func_name, _psnip_op) \ + OP_WITH_OVERFLOW(_func_name, _psnip_op, int8_t, int8) \ + OP_WITH_OVERFLOW(_func_name, _psnip_op, int16_t, int16) \ + OP_WITH_OVERFLOW(_func_name, _psnip_op, int32_t, int32) \ + OP_WITH_OVERFLOW(_func_name, _psnip_op, int64_t, int64) \ + OP_WITH_OVERFLOW(_func_name, _psnip_op, uint8_t, uint8) \ + OP_WITH_OVERFLOW(_func_name, _psnip_op, uint16_t, uint16) \ + OP_WITH_OVERFLOW(_func_name, _psnip_op, uint32_t, uint32) \ + OP_WITH_OVERFLOW(_func_name, _psnip_op, uint64_t, uint64) + +OPS_WITH_OVERFLOW(AddWithOverflow, add) +OPS_WITH_OVERFLOW(SubtractWithOverflow, sub) +OPS_WITH_OVERFLOW(MultiplyWithOverflow, mul) +OPS_WITH_OVERFLOW(DivideWithOverflow, div) + +enum class optype : int8_t { + ADD, + ADD_CHECKED, + SUB, + SUB_CHECKED, +}; + +template +using is_unsigned_integer_value = bool_constant && is_unsigned_v>; + +template +using is_signed_integer_value = bool_constant && is_signed_v>; + +template +using enable_if_signed_integer_t = enable_if_t::value, R>; + +template +using enable_if_unsigned_integer_t = enable_if_t::value, R>; + +template +using enable_if_integer_t = enable_if_t< + is_signed_integer_value::value || is_unsigned_integer_value::value, R>; + +template +using enable_if_floating_t = enable_if_t, R>; + +struct Add { + template + static constexpr enable_if_floating_t Call(Arg0 left, Arg1 right, bool*) { + return left + right; + } + + template + static constexpr enable_if_integer_t Call(Arg0 left, Arg1 right, bool*) { + return left + right; + } +}; + +struct Sub { + template + static constexpr enable_if_floating_t Call(Arg0 left, Arg1 right, bool*) { + return left - right; + } + + template + static constexpr enable_if_integer_t Call(Arg0 left, Arg1 right, bool*) { + return left - right; + } +}; + +struct AddChecked { + template + static constexpr enable_if_floating_t Call(Arg0 left, Arg1 right, bool*) { + return left + right; + } + + template + static constexpr enable_if_integer_t Call(Arg0 left, Arg1 right, bool* failure) { + static_assert(is_same::value && is_same::value, ""); + T result = 0; + if (AddWithOverflow(left, right, &result)) { + *failure = true; + } + return result; + } +}; + + +struct SubChecked { + template + static constexpr enable_if_floating_t Call(Arg0 left, Arg1 right, bool*) { + return left - right; + } + + template + static constexpr enable_if_integer_t Call(Arg0 left, Arg1 right, bool* failure) { + static_assert(is_same::value && is_same::value, ""); + T result = 0; + if (SubtractWithOverflow(left, right, &result)) { + *failure = true; + } + return result; + } +}; + +template +struct arithmetic_op_arr_arr_impl { + static inline void exec(const void* in_left, const void* in_right, void* out, const int len) { + const T* left = reinterpret_cast(in_left); + const T* right = reinterpret_cast(in_right); + T* output = reinterpret_cast(out); + + bool failure = false; + for (int i = 0; i < len; ++i) { + output[i] = Op::template Call(left[i], right[i], &failure); + } + } +}; + +template +struct arithmetic_op_arr_scalar_impl { + static inline void exec(const void* in_left, const void* scalar_right, void* out, const int len) { + const T* left = reinterpret_cast(in_left); + const T right = *reinterpret_cast(scalar_right); + T* output = reinterpret_cast(out); + + bool failure = false; + for (int i = 0; i < len; ++i) { + output[i] = Op::template Call(left[i], right, &failure); + } + } +}; + +template +struct arithmetic_op_scalar_arr_impl { + static inline void exec(const void* scalar_left, const void* in_right, void* out, const int len) { + const T left = *reinterpret_cast(scalar_left); + const T* right = reinterpret_cast(in_right); + T* output = reinterpret_cast(out); + + bool failure = false; + for (int i = 0; i < len; ++i) { + output[i] = Op::template Call(left, right[i], &failure); + } + } +}; + + +template typename Impl> +static inline void arithmetic_op(const int type, const void* in_left, const void* in_right, void* output, const int len) { + const auto intype = static_cast(type); + + switch (intype) { + case arrtype::UINT8: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::INT8: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::UINT16: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::INT16: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::UINT32: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::INT32: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::UINT64: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::INT64: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::FLOAT32: + Impl::exec(in_left, in_right, output, len); + break; + case arrtype::FLOAT64: + Impl::exec(in_left, in_right, output, len); + break; + default: + break; + } +} + +template