diff --git a/go/arrow/compute/internal/kernels/vector_sort.go b/go/arrow/compute/internal/kernels/vector_sort.go
new file mode 100644
index 00000000000..e04b01a7840
--- /dev/null
+++ b/go/arrow/compute/internal/kernels/vector_sort.go
@@ -0,0 +1,210 @@
+package kernels
+
+import (
+	"bytes"
+	"fmt"
+	"sort"
+
+	"github.com/apache/arrow/go/v12/arrow"
+	"github.com/apache/arrow/go/v12/arrow/array"
+	"github.com/apache/arrow/go/v12/arrow/compute/internal/exec"
+	"github.com/apache/arrow/go/v12/arrow/memory"
+)
+
+// SortIndicesOptions holds the options of the "sort_indices" function.
+type SortIndicesOptions struct {
+	// NOTE(review): presumably selects where nulls are placed in the
+	// result; sortExec does not read this option yet.
+	NullPlacement NullPlacement `compute:"null_placement"`
+}
+
+// TypeName returns the name of this options type.
+func (s *SortIndicesOptions) TypeName() string {
+	return "SortIndicesOptions"
+}
+
+// Order is a sort direction. NOTE(review): not used by sortExec yet.
+type Order int
+
+const (
+	Ascending Order = iota
+	Descending
+)
+
+// NullPlacement is the type of SortIndicesOptions.NullPlacement.
+type NullPlacement int
+
+const (
+	AtStart NullPlacement = iota
+	AtEnd
+)
+
+// GetVectorSortingKernels returns one sort_indices kernel for every type in
+// primitiveTypes. Each kernel takes exactly that input type and outputs an
+// Int64 array of indices.
+func GetVectorSortingKernels() []exec.VectorKernel {
+	var base exec.VectorKernel
+	// NOTE(review): chunk-wise execution sorts every chunk on its own, so
+	// the indices of a multi-chunk input are chunk-local rather than global.
+	base.CanExecuteChunkWise = true
+	base.OutputChunked = false
+	// sortExec installs buffers it builds itself and never emits nulls, so
+	// nothing should be pre-allocated for the output.
+	base.NullHandling = exec.NullNoOutput
+	base.MemAlloc = exec.MemNoPrealloc
+	base.ExecFn = sortExec
+	// The result is a flat array of int64 indices, not a list of them.
+	outType := exec.NewOutputType(arrow.PrimitiveTypes.Int64)
+	kernels := make([]exec.VectorKernel, 0, len(primitiveTypes))
+	for _, ty := range primitiveTypes {
+		base.Signature = &exec.KernelSignature{
+			InputTypes: []exec.InputType{exec.NewExactInput(ty)},
+			OutType:    outType,
+		}
+		kernels = append(kernels, base)
+	}
+	return kernels
+}
+
+// sortableValue is the set of fixed-width value types that can be ordered
+// with the built-in comparison operators.
+type sortableValue interface {
+	~int8 | ~int16 | ~int32 | ~int64 |
+		~uint8 | ~uint16 | ~uint32 | ~uint64 |
+		~float32 | ~float64
+}
+
+// indexLess reports whether the value at logical position i of an array
+// sorts strictly before the value at logical position j.
+type indexLess func(i, j int64) bool
+
+// neverLess treats all positions as equal; a stable sort then keeps them in
+// their original order.
+func neverLess(_, _ int64) bool { return false }
+
+// numericIndexLess builds an indexLess over the values buffer of arr,
+// reinterpreted as []T by cast. Values are compared numerically and NaN
+// sorts after every other value.
+func numericIndexLess[T sortableValue](cast func([]byte) []T, arr *exec.ArraySpan) indexLess {
+	if arr.Len == 0 {
+		return neverLess
+	}
+	// Apply the span offset so that position 0 is the first logical value.
+	vals := cast(arr.Buffers[1].Buf)[arr.Offset : arr.Offset+arr.Len]
+	return func(i, j int64) bool {
+		a, b := vals[i], vals[j]
+		// x != x holds only for NaN; integers never take this branch.
+		if a != a || b != b {
+			return a == a && b != b
+		}
+		return a < b
+	}
+}
+
+// boolIndexLess builds an indexLess over the bit-packed values buffer of a
+// boolean array, ordering false before true.
+func boolIndexLess(arr *exec.ArraySpan) indexLess {
+	bits, offset := arr.Buffers[1].Buf, arr.Offset
+	isSet := func(i int64) bool {
+		pos := offset + i
+		return bits[pos>>3]&(1<<uint(pos&7)) != 0
+	}
+	return func(i, j int64) bool { return !isSet(i) && isSet(j) }
+}
+
+// binaryIndexLess builds an indexLess for variable-length binary and string
+// arrays whose offsets buffer is reinterpreted as []O by cast. Values are
+// compared bytewise.
+func binaryIndexLess[O int32 | int64](cast func([]byte) []O, arr *exec.ArraySpan) indexLess {
+	if arr.Len == 0 {
+		return neverLess
+	}
+	// n values are delimited by n+1 offsets.
+	offsets := cast(arr.Buffers[1].Buf)[arr.Offset : arr.Offset+arr.Len+1]
+	data := arr.Buffers[2].Buf
+	value := func(i int64) []byte {
+		return data[int(offsets[i]):int(offsets[i+1])]
+	}
+	return func(i, j int64) bool {
+		return bytes.Compare(value(i), value(j)) < 0
+	}
+}
+
+// indexLessFor selects the comparison matching the type of arr. It returns
+// an error wrapping arrow.ErrNotImplemented for types it cannot order.
+func indexLessFor(arr *exec.ArraySpan) (indexLess, error) {
+	switch arr.Type.ID() {
+	case arrow.NULL:
+		return neverLess, nil
+	case arrow.BOOL:
+		return boolIndexLess(arr), nil
+	case arrow.INT8:
+		return numericIndexLess(arrow.Int8Traits.CastFromBytes, arr), nil
+	case arrow.INT16:
+		return numericIndexLess(arrow.Int16Traits.CastFromBytes, arr), nil
+	case arrow.INT32:
+		return numericIndexLess(arrow.Int32Traits.CastFromBytes, arr), nil
+	case arrow.INT64:
+		return numericIndexLess(arrow.Int64Traits.CastFromBytes, arr), nil
+	case arrow.UINT8:
+		return numericIndexLess(arrow.Uint8Traits.CastFromBytes, arr), nil
+	case arrow.UINT16:
+		return numericIndexLess(arrow.Uint16Traits.CastFromBytes, arr), nil
+	case arrow.UINT32:
+		return numericIndexLess(arrow.Uint32Traits.CastFromBytes, arr), nil
+	case arrow.UINT64:
+		return numericIndexLess(arrow.Uint64Traits.CastFromBytes, arr), nil
+	case arrow.FLOAT32:
+		return numericIndexLess(arrow.Float32Traits.CastFromBytes, arr), nil
+	case arrow.FLOAT64:
+		return numericIndexLess(arrow.Float64Traits.CastFromBytes, arr), nil
+	case arrow.DATE32:
+		return numericIndexLess(arrow.Date32Traits.CastFromBytes, arr), nil
+	case arrow.DATE64:
+		return numericIndexLess(arrow.Date64Traits.CastFromBytes, arr), nil
+	case arrow.BINARY, arrow.STRING:
+		return binaryIndexLess(arrow.Int32Traits.CastFromBytes, arr), nil
+	case arrow.LARGE_BINARY, arrow.LARGE_STRING:
+		return binaryIndexLess(arrow.Int64Traits.CastFromBytes, arr), nil
+	}
+	return nil, fmt.Errorf("%w: sort_indices does not support type %s",
+		arrow.ErrNotImplemented, arr.Type)
+}
+
+// sortExec writes to out the Int64 indices that stably sort the first value
+// of batch in ascending order.
+//
+// TODO: validity is ignored, so null slots are ordered by whatever value
+// their slot holds, and SortIndicesOptions is not consulted.
+func sortExec(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error {
+	inArr := &batch.Values[0].Array
+
+	less, err := indexLessFor(inArr)
+	if err != nil {
+		return err
+	}
+
+	// Start from the identity permutation [0, 1, ..., n-1].
+	indices := make([]int64, inArr.Len)
+	for i := range indices {
+		indices[i] = int64(i)
+	}
+	// The function is documented as a stable sort, so ties must keep their
+	// input order.
+	sort.SliceStable(indices, func(i, j int) bool {
+		return less(indices[i], indices[j])
+	})
+
+	// NOTE(review): this allocates from memory.DefaultAllocator instead of
+	// an allocator taken from ctx.
+	builder := array.NewInt64Builder(memory.DefaultAllocator)
+	defer builder.Release()
+	builder.AppendValues(indices, nil)
+
+	// NOTE(review): outArr is not released because out keeps using its
+	// buffers; confirm the ownership semantics of SetMembers.
+	outArr := builder.NewArray()
+	out.SetMembers(outArr.Data())
+
+	return nil
+}
diff --git a/go/arrow/compute/registry.go b/go/arrow/compute/registry.go
index 3fbb12d65e4..f929abfd577 100644
--- a/go/arrow/compute/registry.go
+++ b/go/arrow/compute/registry.go
@@ -53,6 +53,7 @@ func GetFunctionRegistry() FunctionRegistry {
 		RegisterScalarComparisons(registry)
 		RegisterVectorHash(registry)
 		RegisterVectorRunEndFuncs(registry)
+		RegisterVectorSorting(registry)
 	})
 	return registry
 }
diff --git a/go/arrow/compute/vector_sort_indices.go b/go/arrow/compute/vector_sort_indices.go
new file mode 100644
index 00000000000..cc681d4425c
--- /dev/null
+++ b/go/arrow/compute/vector_sort_indices.go
@@ -0,0 +1,34 @@
+package compute
+
+import (
+	"context"
+
+	"github.com/apache/arrow/go/v12/arrow/compute/internal/kernels"
+)
+
+var (
+	sortIndicesDoc = FunctionDoc{
+		Summary:     "Return the indices that would sort an array",
+		Description: "This function computes an array of indices that define a stable sort of the input array, record batch or table. By default, null values are considered greater than any other value and are therefore sorted at the end of the input. For floating-point types, NaNs are considered greater than any other non-null value, but smaller than null values.",
+		ArgNames:    []string{"array"},
+	}
+)
+
+type SortIndicesOptions = kernels.SortIndicesOptions
+
+// RegisterVectorSorting registers functions related to vector sorting, such as sort_indices.
+func RegisterVectorSorting(registry FunctionRegistry) {
+	fn := NewVectorFunction("sort_indices", Unary(), sortIndicesDoc)
+	fn.defaultOpts = &kernels.SortIndicesOptions{}
+	for _, kernel := range kernels.GetVectorSortingKernels() {
+		if err := fn.AddKernel(kernel); err != nil {
+			panic(err)
+		}
+	}
+	registry.AddFunction(fn, false)
+}
+
+// SortIndices calls the "sort_indices" function on input using opts.
+func SortIndices(ctx context.Context, opts kernels.SortIndicesOptions, input Datum) (Datum, error) {
+	return CallFunction(ctx, "sort_indices", &opts, input)
+}
diff --git a/go/arrow/compute/vector_sort_indices_test.go b/go/arrow/compute/vector_sort_indices_test.go
new file mode 100644
index 00000000000..e7be25e205b
--- /dev/null
+++ b/go/arrow/compute/vector_sort_indices_test.go
@@ -0,0 +1,122 @@
+package compute_test
+
+import (
+	"context"
+	"strings"
+	"testing"
+
+	"github.com/apache/arrow/go/v12/arrow"
+	"github.com/apache/arrow/go/v12/arrow/array"
+	"github.com/apache/arrow/go/v12/arrow/compute"
+	"github.com/apache/arrow/go/v12/arrow/memory"
+	"github.com/stretchr/testify/suite"
+)
+
+// SortIndicesSuite runs sort_indices over one JSON-described input and
+// compares the result with the expected indices.
+type SortIndicesSuite struct {
+	suite.Suite
+	mem *memory.CheckedAllocator
+
+	valueType     arrow.DataType
+	jsonData      []string
+	expectIndices []int64
+
+	expected compute.Datum
+	input    compute.Datum
+
+	ctx context.Context
+}
+
+// SetupTest builds the input and expected datums with a checked allocator.
+func (s *SortIndicesSuite) SetupTest() {
+	s.mem = memory.NewCheckedAllocator(memory.DefaultAllocator)
+	s.ctx = compute.WithAllocator(context.Background(), s.mem)
+
+	chunks := make([]arrow.Array, len(s.jsonData))
+	for idx, doc := range s.jsonData {
+		chunk, _, err := array.FromJSON(s.mem, s.valueType, strings.NewReader(doc))
+		s.Require().NoError(err)
+		chunks[idx] = chunk
+	}
+
+	bldr := array.NewInt64Builder(s.mem)
+	bldr.AppendValues(s.expectIndices, nil)
+	s.expected = &compute.ArrayDatum{Value: bldr.NewArray().Data()}
+	s.input = &compute.ChunkedDatum{
+		Value: arrow.NewChunked(chunks[0].DataType(), chunks),
+	}
+
+	// Drop the local references now that both datums are built.
+	for _, chunk := range chunks {
+		chunk.Release()
+	}
+	bldr.Release()
+}
+
+// TearDownTest releases the datums and asserts the allocator size is zero.
+func (s *SortIndicesSuite) TearDownTest() {
+	s.expected.Release()
+	s.input.Release()
+	s.mem.AssertSize(s.T(), 0)
+}
+
+// TestSortIndices calls compute.SortIndices on the input and compares
+// the result with the expected datum.
+func (s *SortIndicesSuite) TestSortIndices() {
+	opts := compute.SortIndicesOptions{NullPlacement: 0}
+	got, err := compute.SortIndices(s.ctx, opts, s.input)
+	s.Require().NoError(err)
+	defer got.Release()
+
+	assertDatumsEqual(s.T(), s.expected, got, nil, nil)
+}
+
+// TestSortIndicesFunctions runs SortIndicesSuite once per table entry.
+func TestSortIndicesFunctions(t *testing.T) {
+	// base64 encoded for testing fixed size binary
+	const (
+		valAba = `YWJh`
+		valAbc = `YWJj`
+		valAbd = `YWJk`
+	)
+
+	type testCase struct {
+		name      string
+		data      []string
+		expect    []int64
+		valueType arrow.DataType
+	}
+	cases := []testCase{
+		{
+			name:      "simple int32",
+			data:      []string{`[1, 1, 0, -5, -5, -5, 255, 255]`},
+			expect:    []int64{3, 4, 5, 2, 0, 1, 6, 7},
+			valueType: arrow.PrimitiveTypes.Int32,
+		},
+		//{"uint32 with nulls", []string{`[null, 1, 1, null, null, 5]`}, arrow.PrimitiveTypes.Uint32},
+		//{"boolean", []string{`[true, true, true, false, false]`}, arrow.FixedWidthTypes.Boolean},
+		//{"boolean no runs", []string{`[true, false, true, false, true, false, true, false, true]`}, arrow.FixedWidthTypes.Boolean},
+		//{"float64 len=1", []string{`[1.0]`}, arrow.PrimitiveTypes.Float64},
+		//{"bool chunks", []string{`[true, true]`, `[true, false, null, null, false]`, `[null, null]`}, arrow.FixedWidthTypes.Boolean},
+		//{"float32 chunked", []string{`[1, 1, 0, -5, -5]`, `[-5, 255, 255]`}, arrow.PrimitiveTypes.Float32},
+		//{"str", []string{`["foo", "foo", "foo", "bar", "bar", "baz", "bar", "bar", "foo", "foo"]`}, arrow.BinaryTypes.String},
+		//{"large str", []string{`["foo", "foo", "foo", "bar", "bar", "baz", "bar", "bar", "foo", "foo"]`}, arrow.BinaryTypes.LargeString},
+		//{"str chunked", []string{`["foo", "foo", null]`, `["foo", "bar", "bar"]`, `[null, null, "baz"]`, `[null]`}, arrow.BinaryTypes.String},
+		//{"empty arrs", []string{`[]`}, arrow.PrimitiveTypes.Float32},
+		//{"empty str array", []string{`[]`}, arrow.BinaryTypes.String},
+		//{"empty chunked", []string{`[]`, `[]`, `[]`}, arrow.FixedWidthTypes.Boolean},
+		//{"fsb", []string{`["` + valAba + `", "` + valAba + `", null, "` + valAbc + `", "` + valAbd + `", "` + valAbd + `", "` + valAbd + `"]`}, &arrow.FixedSizeBinaryType{ByteWidth: 3}},
+		//{"fsb chunked", []string{`["` + valAba + `", "` + valAba + `", null]`, `["` + valAbc + `", "` + valAbd + `", "` + valAbd + `", "` + valAbd + `"]`, `[]`}, &arrow.FixedSizeBinaryType{ByteWidth: 3}}
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			suite.Run(t, &SortIndicesSuite{
+				valueType:     tc.valueType,
+				jsonData:      tc.data,
+				expectIndices: tc.expect,
+			})
+		})
+	}
+}