diff --git a/go/arrow/array/binary.go b/go/arrow/array/binary.go index 850fb09b4a8..474a35e17fc 100644 --- a/go/arrow/array/binary.go +++ b/go/arrow/array/binary.go @@ -17,6 +17,8 @@ package array import ( + "fmt" + "strings" "unsafe" "github.com/apache/arrow/go/arrow" @@ -80,6 +82,24 @@ func (a *Binary) ValueBytes() []byte { return a.valueBytes[a.valueOffsets[beg]:a.valueOffsets[end]] } +func (a *Binary) String() string { + o := new(strings.Builder) + o.WriteString("[") + for i := 0; i < a.Len(); i++ { + if i > 0 { + o.WriteString(" ") + } + switch { + case a.IsNull(i): + o.WriteString("(null)") + default: + fmt.Fprintf(o, "%q", a.ValueString(i)) + } + } + o.WriteString("]") + return o.String() +} + func (a *Binary) setData(data *Data) { if len(data.buffers) != 3 { panic("len(data.buffers) != 3") diff --git a/go/arrow/array/binary_test.go b/go/arrow/array/binary_test.go index 2af45dee60f..a7bbd47568b 100644 --- a/go/arrow/array/binary_test.go +++ b/go/arrow/array/binary_test.go @@ -405,3 +405,26 @@ func TestBinaryValueBytes(t *testing.T) { assert.Equal(t, []byte{'h', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 'q'}, slice.ValueBytes()) } + +func TestBinaryStringer(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + values := []string{"a", "bc", "", "é", "", "hijk", "lm", "", "opq", "", "tu"} + valids := []bool{true, true, false, true, false, true, true, true, true, false, true} + + b := NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + defer b.Release() + + b.AppendStringValues(values, valids) + + arr := b.NewArray().(*Binary) + defer arr.Release() + + got := arr.String() + want := `["a" "bc" (null) "é" (null) "hijk" "lm" "" "opq" (null) "tu"]` + + if got != want { + t.Fatalf("invalid stringer:\ngot= %s\nwant=%s\n", got, want) + } +} diff --git a/go/arrow/array/list.go b/go/arrow/array/list.go index b571c50c74c..9b7464171fc 100644 --- a/go/arrow/array/list.go +++ b/go/arrow/array/list.go @@ -79,6 +79,11 @@ func (a 
*List) Len() int { return a.array.Len() } func (a *List) Offsets() []int32 { return a.offsets } +func (a *List) Retain() { + a.array.Retain() + a.values.Retain() +} + func (a *List) Release() { a.array.Release() a.values.Release() diff --git a/go/arrow/array/list_test.go b/go/arrow/array/list_test.go index 68a50ccca92..0853776e4de 100644 --- a/go/arrow/array/list_test.go +++ b/go/arrow/array/list_test.go @@ -55,6 +55,9 @@ func TestListArray(t *testing.T) { arr := lb.NewArray().(*array.List) defer arr.Release() + arr.Retain() + arr.Release() + if got, want := arr.DataType().ID(), arrow.LIST; got != want { t.Fatalf("got=%v, want=%v", got, want) } diff --git a/go/arrow/array/string.go b/go/arrow/array/string.go index cd0fd434a90..dbc340dc98e 100644 --- a/go/arrow/array/string.go +++ b/go/arrow/array/string.go @@ -46,7 +46,8 @@ func NewStringData(data *Data) *String { } // Value returns the slice at index i. This value should not be mutated. -func (a *String) Value(i int) string { return a.values[a.offsets[i]:a.offsets[i+1]] } +func (a *String) Value(i int) string { return a.values[a.offsets[i]:a.offsets[i+1]] } +func (a *String) ValueOffset(i int) int { return int(a.offsets[i]) } func (a *String) String() string { o := new(strings.Builder) diff --git a/go/arrow/array/string_test.go b/go/arrow/array/string_test.go index 016fc74c225..828b32a191a 100644 --- a/go/arrow/array/string_test.go +++ b/go/arrow/array/string_test.go @@ -30,8 +30,9 @@ func TestStringArray(t *testing.T) { defer mem.AssertSize(t, 0) var ( - want = []string{"hello", "世界", "", "bye"} - valids = []bool{true, true, false, true} + want = []string{"hello", "世界", "", "bye"} + valids = []bool{true, true, false, true} + offsets = []int{0, 5, 11, 11, 14} ) sb := array.NewStringBuilder(mem) @@ -79,6 +80,13 @@ func TestStringArray(t *testing.T) { t.Fatalf("arr[%d]: got=%q, want=%q", i, got, want[i]) } } + + if got, want := arr.ValueOffset(i), offsets[i]; got != want { + t.Fatalf("arr-offset-beg[%d]: got=%d, 
want=%d", i, got, want) + } + if got, want := arr.ValueOffset(i+1), offsets[i+1]; got != want { + t.Fatalf("arr-offset-end[%d]: got=%d, want=%d", i+1, got, want) + } } sub := array.MakeFromData(arr.Data()) diff --git a/go/arrow/array/struct.go b/go/arrow/array/struct.go index 55fd9135329..3e5200942e2 100644 --- a/go/arrow/array/struct.go +++ b/go/arrow/array/struct.go @@ -65,6 +65,13 @@ func (a *Struct) setData(data *Data) { } } +func (a *Struct) Retain() { + a.array.Retain() + for _, f := range a.fields { + f.Retain() + } +} + func (a *Struct) Release() { a.array.Release() for _, f := range a.fields { diff --git a/go/arrow/array/struct_test.go b/go/arrow/array/struct_test.go index d9701cee1e2..3b2b6b6b8e0 100644 --- a/go/arrow/array/struct_test.go +++ b/go/arrow/array/struct_test.go @@ -79,6 +79,9 @@ func TestStructArray(t *testing.T) { arr := sb.NewArray().(*array.Struct) defer arr.Release() + arr.Retain() + arr.Release() + if got, want := arr.DataType().ID(), arrow.STRUCT; got != want { t.Fatalf("got=%v, want=%v", got, want) } diff --git a/go/arrow/internal/arrdata/arrdata.go b/go/arrow/internal/arrdata/arrdata.go new file mode 100644 index 00000000000..bbdb3c08fb0 --- /dev/null +++ b/go/arrow/internal/arrdata/arrdata.go @@ -0,0 +1,550 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package arrdata exports arrays and records data ready to be used for tests. +package arrdata // import "github.com/apache/arrow/go/arrow/internal/arrdata" + +import ( + "fmt" + "sort" + + "github.com/apache/arrow/go/arrow" + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/memory" +) + +var ( + Records = make(map[string][]array.Record) + RecordNames []string +) + +func init() { + Records["primitives"] = makePrimitiveRecords() + Records["structs"] = makeStructsRecords() + Records["lists"] = makeListsRecords() + Records["strings"] = makeStringsRecords() + + for k := range Records { + RecordNames = append(RecordNames, k) + } + sort.Strings(RecordNames) +} + +func makePrimitiveRecords() []array.Record { + mem := memory.NewGoAllocator() + + meta := arrow.NewMetadata( + []string{"k1", "k2", "k3"}, + []string{"v1", "v2", "v3"}, + ) + + schema := arrow.NewSchema( + []arrow.Field{ + arrow.Field{Name: "bools", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + arrow.Field{Name: "int8s", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + arrow.Field{Name: "int16s", Type: arrow.PrimitiveTypes.Int16, Nullable: true}, + arrow.Field{Name: "int32s", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + arrow.Field{Name: "int64s", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + arrow.Field{Name: "uint8s", Type: arrow.PrimitiveTypes.Uint8, Nullable: true}, + arrow.Field{Name: "uint16s", Type: arrow.PrimitiveTypes.Uint16, Nullable: true}, + arrow.Field{Name: "uint32s", Type: arrow.PrimitiveTypes.Uint32, Nullable: true}, + arrow.Field{Name: "uint64s", Type: arrow.PrimitiveTypes.Uint64, Nullable: true}, + arrow.Field{Name: "float32s", Type: arrow.PrimitiveTypes.Float32, Nullable: true}, + arrow.Field{Name: "float64s", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + }, &meta, + ) + + mask := []bool{true, false, false, true, true} + 
chunks := [][]array.Interface{ + []array.Interface{ + arrayOf(mem, []bool{true, false, true, false, true}, mask), + arrayOf(mem, []int8{-1, -2, -3, -4, -5}, mask), + arrayOf(mem, []int16{-1, -2, -3, -4, -5}, mask), + arrayOf(mem, []int32{-1, -2, -3, -4, -5}, mask), + arrayOf(mem, []int64{-1, -2, -3, -4, -5}, mask), + arrayOf(mem, []uint8{+1, +2, +3, +4, +5}, mask), + arrayOf(mem, []uint16{+1, +2, +3, +4, +5}, mask), + arrayOf(mem, []uint32{+1, +2, +3, +4, +5}, mask), + arrayOf(mem, []uint64{+1, +2, +3, +4, +5}, mask), + arrayOf(mem, []float32{+1, +2, +3, +4, +5}, mask), + arrayOf(mem, []float64{+1, +2, +3, +4, +5}, mask), + }, + []array.Interface{ + arrayOf(mem, []bool{true, false, true, false, true}, mask), + arrayOf(mem, []int8{-11, -12, -13, -14, -15}, mask), + arrayOf(mem, []int16{-11, -12, -13, -14, -15}, mask), + arrayOf(mem, []int32{-11, -12, -13, -14, -15}, mask), + arrayOf(mem, []int64{-11, -12, -13, -14, -15}, mask), + arrayOf(mem, []uint8{+11, +12, +13, +14, +15}, mask), + arrayOf(mem, []uint16{+11, +12, +13, +14, +15}, mask), + arrayOf(mem, []uint32{+11, +12, +13, +14, +15}, mask), + arrayOf(mem, []uint64{+11, +12, +13, +14, +15}, mask), + arrayOf(mem, []float32{+11, +12, +13, +14, +15}, mask), + arrayOf(mem, []float64{+11, +12, +13, +14, +15}, mask), + }, + []array.Interface{ + arrayOf(mem, []bool{true, false, true, false, true}, mask), + arrayOf(mem, []int8{-21, -22, -23, -24, -25}, mask), + arrayOf(mem, []int16{-21, -22, -23, -24, -25}, mask), + arrayOf(mem, []int32{-21, -22, -23, -24, -25}, mask), + arrayOf(mem, []int64{-21, -22, -23, -24, -25}, mask), + arrayOf(mem, []uint8{+21, +22, +23, +24, +25}, mask), + arrayOf(mem, []uint16{+21, +22, +23, +24, +25}, mask), + arrayOf(mem, []uint32{+21, +22, +23, +24, +25}, mask), + arrayOf(mem, []uint64{+21, +22, +23, +24, +25}, mask), + arrayOf(mem, []float32{+21, +22, +23, +24, +25}, mask), + arrayOf(mem, []float64{+21, +22, +23, +24, +25}, mask), + }, + } + + defer func() { + for _, chunk := range chunks { 
+ for _, col := range chunk { + col.Release() + } + } + }() + + recs := make([]array.Record, len(chunks)) + for i, chunk := range chunks { + recs[i] = array.NewRecord(schema, chunk, -1) + } + + return recs +} + +func makeStructsRecords() []array.Record { + mem := memory.NewGoAllocator() + + fields := []arrow.Field{ + {Name: "f1", Type: arrow.PrimitiveTypes.Int32}, + {Name: "f2", Type: arrow.BinaryTypes.String}, + } + dtype := arrow.StructOf(fields...) + schema := arrow.NewSchema([]arrow.Field{{Name: "struct_nullable", Type: dtype, Nullable: true}}, nil) + + bldr := array.NewStructBuilder(mem, dtype) + defer bldr.Release() + + mask := []bool{true, false, false, true, true, true, false, true} + chunks := [][]array.Interface{ + []array.Interface{ + structOf(mem, dtype, []array.Interface{ + arrayOf(mem, []int32{-1, -2, -3, -4, -5}, mask[:5]), + arrayOf(mem, []string{"111", "222", "333", "444", "555"}, mask[:5]), + }, []bool{true}), + }, + []array.Interface{ + structOf(mem, dtype, []array.Interface{ + arrayOf(mem, []int32{-11, -12, -13, -14, -15, -16, -17, -18}, mask), + arrayOf(mem, []string{"1", "2", "3", "4", "5", "6", "7", "8"}, mask), + }, []bool{true}), + }, + } + + defer func() { + for _, chunk := range chunks { + for _, col := range chunk { + col.Release() + } + } + }() + + recs := make([]array.Record, len(chunks)) + for i, chunk := range chunks { + recs[i] = array.NewRecord(schema, chunk, -1) + } + + return recs +} + +func makeListsRecords() []array.Record { + mem := memory.NewGoAllocator() + dtype := arrow.ListOf(arrow.PrimitiveTypes.Int32) + schema := arrow.NewSchema([]arrow.Field{ + {Name: "list_nullable", Type: dtype, Nullable: true}, + }, nil) + + mask := []bool{true, false, false, true, true} + + chunks := [][]array.Interface{ + []array.Interface{ + listOf(mem, []array.Interface{ + arrayOf(mem, []int32{1, 2, 3, 4, 5}, mask), + arrayOf(mem, []int32{11, 12, 13, 14, 15}, mask), + arrayOf(mem, []int32{21, 22, 23, 24, 25}, mask), + }, nil), + }, + 
[]array.Interface{ + listOf(mem, []array.Interface{ + arrayOf(mem, []int32{-1, -2, -3, -4, -5}, mask), + arrayOf(mem, []int32{-11, -12, -13, -14, -15}, mask), + arrayOf(mem, []int32{-21, -22, -23, -24, -25}, mask), + }, nil), + }, + []array.Interface{ + listOf(mem, []array.Interface{ + arrayOf(mem, []int32{-1, -2, -3, -4, -5}, mask), + arrayOf(mem, []int32{-11, -12, -13, -14, -15}, mask), + arrayOf(mem, []int32{-21, -22, -23, -24, -25}, mask), + }, []bool{true, false, true}), + }, + } + + defer func() { + for _, chunk := range chunks { + for _, col := range chunk { + col.Release() + } + } + }() + + recs := make([]array.Record, len(chunks)) + for i, chunk := range chunks { + recs[i] = array.NewRecord(schema, chunk, -1) + } + + return recs +} + +func makeStringsRecords() []array.Record { + mem := memory.NewGoAllocator() + schema := arrow.NewSchema([]arrow.Field{ + {Name: "strings", Type: arrow.BinaryTypes.String}, + {Name: "bytes", Type: arrow.BinaryTypes.Binary}, + }, nil) + + mask := []bool{true, false, false, true, true} + chunks := [][]array.Interface{ + []array.Interface{ + arrayOf(mem, []string{"1é", "2", "3", "4", "5"}, mask), + arrayOf(mem, [][]byte{[]byte("1é"), []byte("2"), []byte("3"), []byte("4"), []byte("5")}, mask), + }, + []array.Interface{ + arrayOf(mem, []string{"11", "22", "33", "44", "55"}, mask), + arrayOf(mem, [][]byte{[]byte("11"), []byte("22"), []byte("33"), []byte("44"), []byte("55")}, mask), + }, + []array.Interface{ + arrayOf(mem, []string{"111", "222", "333", "444", "555"}, mask), + arrayOf(mem, [][]byte{[]byte("111"), []byte("222"), []byte("333"), []byte("444"), []byte("555")}, mask), + }, + } + + defer func() { + for _, chunk := range chunks { + for _, col := range chunk { + col.Release() + } + } + }() + + recs := make([]array.Record, len(chunks)) + for i, chunk := range chunks { + recs[i] = array.NewRecord(schema, chunk, -1) + } + + return recs +} + +func arrayOf(mem memory.Allocator, a interface{}, valids []bool) array.Interface { + if 
mem == nil { + mem = memory.NewGoAllocator() + } + + switch a := a.(type) { + case []bool: + bldr := array.NewBooleanBuilder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewBooleanArray() + + case []int8: + bldr := array.NewInt8Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewInt8Array() + + case []int16: + bldr := array.NewInt16Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewInt16Array() + + case []int32: + bldr := array.NewInt32Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewInt32Array() + + case []int64: + bldr := array.NewInt64Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewInt64Array() + + case []uint8: + bldr := array.NewUint8Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewUint8Array() + + case []uint16: + bldr := array.NewUint16Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewUint16Array() + + case []uint32: + bldr := array.NewUint32Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewUint32Array() + + case []uint64: + bldr := array.NewUint64Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewUint64Array() + + case []float32: + bldr := array.NewFloat32Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewFloat32Array() + + case []float64: + bldr := array.NewFloat64Builder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewFloat64Array() + + case []string: + bldr := array.NewStringBuilder(mem) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewStringArray() + + case [][]byte: + bldr := array.NewBinaryBuilder(mem, arrow.BinaryTypes.Binary) + defer bldr.Release() + + bldr.AppendValues(a, valids) + return bldr.NewBinaryArray() + + default: + 
panic(fmt.Errorf("arrdata: invalid data slice type %T", a)) + } +} + +func listOf(mem memory.Allocator, values []array.Interface, valids []bool) *array.List { + if mem == nil { + mem = memory.NewGoAllocator() + } + + bldr := array.NewListBuilder(mem, values[0].DataType()) + defer bldr.Release() + + valid := func(i int) bool { + return valids[i] + } + + if valids == nil { + valid = func(i int) bool { return true } + } + + for i, value := range values { + bldr.Append(valid(i)) + buildArray(bldr.ValueBuilder(), value) + } + + return bldr.NewListArray() +} + +func structOf(mem memory.Allocator, dtype *arrow.StructType, fields []array.Interface, valids []bool) *array.Struct { + if mem == nil { + mem = memory.NewGoAllocator() + } + + bldr := array.NewStructBuilder(mem, dtype) + defer bldr.Release() + + if valids == nil { + valids = make([]bool, fields[0].Len()) + for i := range valids { + valids[i] = true + } + } + + for _, valid := range valids { + bldr.Append(valid) + for j := range dtype.Fields() { + fbldr := bldr.FieldBuilder(j) + buildArray(fbldr, fields[j]) + } + } + + return bldr.NewStructArray() +} + +func buildArray(bldr array.Builder, data array.Interface) { + defer data.Release() + + switch bldr := bldr.(type) { + case *array.BooleanBuilder: + data := data.(*array.Boolean) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Int8Builder: + data := data.(*array.Int8) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Int16Builder: + data := data.(*array.Int16) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Int32Builder: + data := data.(*array.Int32) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + 
default: + bldr.AppendNull() + } + } + + case *array.Int64Builder: + data := data.(*array.Int64) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Uint8Builder: + data := data.(*array.Uint8) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Uint16Builder: + data := data.(*array.Uint16) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Uint32Builder: + data := data.(*array.Uint32) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Uint64Builder: + data := data.(*array.Uint64) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Float32Builder: + data := data.(*array.Float32) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.Float64Builder: + data := data.(*array.Float64) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + + case *array.StringBuilder: + data := data.(*array.String) + for i := 0; i < data.Len(); i++ { + switch { + case data.IsValid(i): + bldr.Append(data.Value(i)) + default: + bldr.AppendNull() + } + } + } +} diff --git a/go/arrow/internal/bitutil/bitutil.go b/go/arrow/internal/bitutil/bitutil.go index 06a2a950331..b547d618a5a 100644 --- a/go/arrow/internal/bitutil/bitutil.go +++ b/go/arrow/internal/bitutil/bitutil.go @@ -30,12 +30,17 @@ var ( // IsMultipleOf8 returns whether v is a multiple of 8. 
func IsMultipleOf8(v int64) bool { return v&7 == 0 } +func BytesForBits(bits int64) int64 { return (bits + 7) >> 3 } + // NextPowerOf2 rounds x to the next power of two. func NextPowerOf2(x int) int { return 1 << uint(bits.Len(uint(x))) } // CeilByte rounds size to the next multiple of 8. func CeilByte(size int) int { return (size + 7) &^ 7 } +// CeilByte64 rounds size to the next multiple of 8. +func CeilByte64(size int64) int64 { return (size + 7) &^ 7 } + // BitIsSet returns true if the bit at index i in buf is set (1). func BitIsSet(buf []byte, i int) bool { return (buf[uint(i)/8] & BitMask[byte(i)%8]) != 0 } diff --git a/go/arrow/ipc/cmd/arrow-cat/main_test.go b/go/arrow/ipc/cmd/arrow-cat/main_test.go new file mode 100644 index 00000000000..0da306f09ce --- /dev/null +++ b/go/arrow/ipc/cmd/arrow-cat/main_test.go @@ -0,0 +1,380 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package main + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "testing" + + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/internal/arrdata" + "github.com/apache/arrow/go/arrow/ipc" + "github.com/apache/arrow/go/arrow/memory" +) + +func TestCatStream(t *testing.T) { + for _, tc := range []struct { + name string + want string + }{ + { + name: "primitives", + want: `record 1... + col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-1 (null) (null) -4 -5] + col[2] "int16s": [-1 (null) (null) -4 -5] + col[3] "int32s": [-1 (null) (null) -4 -5] + col[4] "int64s": [-1 (null) (null) -4 -5] + col[5] "uint8s": [1 (null) (null) 4 5] + col[6] "uint16s": [1 (null) (null) 4 5] + col[7] "uint32s": [1 (null) (null) 4 5] + col[8] "uint64s": [1 (null) (null) 4 5] + col[9] "float32s": [1 (null) (null) 4 5] + col[10] "float64s": [1 (null) (null) 4 5] +record 2... + col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-11 (null) (null) -14 -15] + col[2] "int16s": [-11 (null) (null) -14 -15] + col[3] "int32s": [-11 (null) (null) -14 -15] + col[4] "int64s": [-11 (null) (null) -14 -15] + col[5] "uint8s": [11 (null) (null) 14 15] + col[6] "uint16s": [11 (null) (null) 14 15] + col[7] "uint32s": [11 (null) (null) 14 15] + col[8] "uint64s": [11 (null) (null) 14 15] + col[9] "float32s": [11 (null) (null) 14 15] + col[10] "float64s": [11 (null) (null) 14 15] +record 3... + col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-21 (null) (null) -24 -25] + col[2] "int16s": [-21 (null) (null) -24 -25] + col[3] "int32s": [-21 (null) (null) -24 -25] + col[4] "int64s": [-21 (null) (null) -24 -25] + col[5] "uint8s": [21 (null) (null) 24 25] + col[6] "uint16s": [21 (null) (null) 24 25] + col[7] "uint32s": [21 (null) (null) 24 25] + col[8] "uint64s": [21 (null) (null) 24 25] + col[9] "float32s": [21 (null) (null) 24 25] + col[10] "float64s": [21 (null) (null) 24 25] +`, + }, + { + name: "structs", + want: `record 1... 
+ col[0] "struct_nullable": {[-1 (null) (null) -4 -5] ["111" (null) (null) "444" "555"]} +record 2... + col[0] "struct_nullable": {[-11 (null) (null) -14 -15 -16 (null) -18] ["1" (null) (null) "4" "5" "6" (null) "8"]} +`, + }, + { + name: "lists", + want: `record 1... + col[0] "list_nullable": [[1 (null) (null) 4 5] [11 (null) (null) 14 15] [21 (null) (null) 24 25]] +record 2... + col[0] "list_nullable": [[-1 (null) (null) -4 -5] [-11 (null) (null) -14 -15] [-21 (null) (null) -24 -25]] +record 3... + col[0] "list_nullable": [[-1 (null) (null) -4 -5] (null) [-21 (null) (null) -24 -25]] +`, + }, + { + name: "strings", + want: `record 1... + col[0] "strings": ["1é" (null) (null) "4" "5"] + col[1] "bytes": ["1é" (null) (null) "4" "5"] +record 2... + col[0] "strings": ["11" (null) (null) "44" "55"] + col[1] "bytes": ["11" (null) (null) "44" "55"] +record 3... + col[0] "strings": ["111" (null) (null) "444" "555"] + col[1] "bytes": ["111" (null) (null) "444" "555"] +`, + }, + } { + t.Run(tc.name, func(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + fname := func() string { + f, err := ioutil.TempFile("", "go-arrow-") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + w := ipc.NewWriter(f, ipc.WithSchema(arrdata.Records[tc.name][0].Schema()), ipc.WithAllocator(mem)) + defer w.Close() + + for _, rec := range arrdata.Records[tc.name] { + err = w.Write(rec) + if err != nil { + t.Fatal(err) + } + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + err = f.Close() + if err != nil { + t.Fatal(err) + } + + return f.Name() + }() + defer os.Remove(fname) + + f, err := os.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + w := new(bytes.Buffer) + err = processStream(w, f) + if err != nil { + t.Fatal(err) + } + + if got, want := w.String(), tc.want; got != want { + t.Fatalf("invalid output:\ngot:\n%s\nwant:\n%s\n", got, want) + } + }) + } +} + +func TestCatFile(t *testing.T) { + for 
_, tc := range []struct { + name string + want string + stream bool + }{ + { + stream: true, + name: "primitives", + want: `record 1... + col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-1 (null) (null) -4 -5] + col[2] "int16s": [-1 (null) (null) -4 -5] + col[3] "int32s": [-1 (null) (null) -4 -5] + col[4] "int64s": [-1 (null) (null) -4 -5] + col[5] "uint8s": [1 (null) (null) 4 5] + col[6] "uint16s": [1 (null) (null) 4 5] + col[7] "uint32s": [1 (null) (null) 4 5] + col[8] "uint64s": [1 (null) (null) 4 5] + col[9] "float32s": [1 (null) (null) 4 5] + col[10] "float64s": [1 (null) (null) 4 5] +record 2... + col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-11 (null) (null) -14 -15] + col[2] "int16s": [-11 (null) (null) -14 -15] + col[3] "int32s": [-11 (null) (null) -14 -15] + col[4] "int64s": [-11 (null) (null) -14 -15] + col[5] "uint8s": [11 (null) (null) 14 15] + col[6] "uint16s": [11 (null) (null) 14 15] + col[7] "uint32s": [11 (null) (null) 14 15] + col[8] "uint64s": [11 (null) (null) 14 15] + col[9] "float32s": [11 (null) (null) 14 15] + col[10] "float64s": [11 (null) (null) 14 15] +record 3... + col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-21 (null) (null) -24 -25] + col[2] "int16s": [-21 (null) (null) -24 -25] + col[3] "int32s": [-21 (null) (null) -24 -25] + col[4] "int64s": [-21 (null) (null) -24 -25] + col[5] "uint8s": [21 (null) (null) 24 25] + col[6] "uint16s": [21 (null) (null) 24 25] + col[7] "uint32s": [21 (null) (null) 24 25] + col[8] "uint64s": [21 (null) (null) 24 25] + col[9] "float32s": [21 (null) (null) 24 25] + col[10] "float64s": [21 (null) (null) 24 25] +`, + }, + { + name: "primitives", + want: `version: V4 +record 1/3... 
+ col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-1 (null) (null) -4 -5] + col[2] "int16s": [-1 (null) (null) -4 -5] + col[3] "int32s": [-1 (null) (null) -4 -5] + col[4] "int64s": [-1 (null) (null) -4 -5] + col[5] "uint8s": [1 (null) (null) 4 5] + col[6] "uint16s": [1 (null) (null) 4 5] + col[7] "uint32s": [1 (null) (null) 4 5] + col[8] "uint64s": [1 (null) (null) 4 5] + col[9] "float32s": [1 (null) (null) 4 5] + col[10] "float64s": [1 (null) (null) 4 5] +record 2/3... + col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-11 (null) (null) -14 -15] + col[2] "int16s": [-11 (null) (null) -14 -15] + col[3] "int32s": [-11 (null) (null) -14 -15] + col[4] "int64s": [-11 (null) (null) -14 -15] + col[5] "uint8s": [11 (null) (null) 14 15] + col[6] "uint16s": [11 (null) (null) 14 15] + col[7] "uint32s": [11 (null) (null) 14 15] + col[8] "uint64s": [11 (null) (null) 14 15] + col[9] "float32s": [11 (null) (null) 14 15] + col[10] "float64s": [11 (null) (null) 14 15] +record 3/3... + col[0] "bools": [true (null) (null) false true] + col[1] "int8s": [-21 (null) (null) -24 -25] + col[2] "int16s": [-21 (null) (null) -24 -25] + col[3] "int32s": [-21 (null) (null) -24 -25] + col[4] "int64s": [-21 (null) (null) -24 -25] + col[5] "uint8s": [21 (null) (null) 24 25] + col[6] "uint16s": [21 (null) (null) 24 25] + col[7] "uint32s": [21 (null) (null) 24 25] + col[8] "uint64s": [21 (null) (null) 24 25] + col[9] "float32s": [21 (null) (null) 24 25] + col[10] "float64s": [21 (null) (null) 24 25] +`, + }, + { + stream: true, + name: "structs", + want: `record 1... + col[0] "struct_nullable": {[-1 (null) (null) -4 -5] ["111" (null) (null) "444" "555"]} +record 2... + col[0] "struct_nullable": {[-11 (null) (null) -14 -15 -16 (null) -18] ["1" (null) (null) "4" "5" "6" (null) "8"]} +`, + }, + { + name: "structs", + want: `version: V4 +record 1/2... + col[0] "struct_nullable": {[-1 (null) (null) -4 -5] ["111" (null) (null) "444" "555"]} +record 2/2... 
+ col[0] "struct_nullable": {[-11 (null) (null) -14 -15 -16 (null) -18] ["1" (null) (null) "4" "5" "6" (null) "8"]} +`, + }, + { + stream: true, + name: "lists", + want: `record 1... + col[0] "list_nullable": [[1 (null) (null) 4 5] [11 (null) (null) 14 15] [21 (null) (null) 24 25]] +record 2... + col[0] "list_nullable": [[-1 (null) (null) -4 -5] [-11 (null) (null) -14 -15] [-21 (null) (null) -24 -25]] +record 3... + col[0] "list_nullable": [[-1 (null) (null) -4 -5] (null) [-21 (null) (null) -24 -25]] +`, + }, + { + name: "lists", + want: `version: V4 +record 1/3... + col[0] "list_nullable": [[1 (null) (null) 4 5] [11 (null) (null) 14 15] [21 (null) (null) 24 25]] +record 2/3... + col[0] "list_nullable": [[-1 (null) (null) -4 -5] [-11 (null) (null) -14 -15] [-21 (null) (null) -24 -25]] +record 3/3... + col[0] "list_nullable": [[-1 (null) (null) -4 -5] (null) [-21 (null) (null) -24 -25]] +`, + }, + { + stream: true, + name: "strings", + want: `record 1... + col[0] "strings": ["1é" (null) (null) "4" "5"] + col[1] "bytes": ["1é" (null) (null) "4" "5"] +record 2... + col[0] "strings": ["11" (null) (null) "44" "55"] + col[1] "bytes": ["11" (null) (null) "44" "55"] +record 3... + col[0] "strings": ["111" (null) (null) "444" "555"] + col[1] "bytes": ["111" (null) (null) "444" "555"] +`, + }, + { + name: "strings", + want: `version: V4 +record 1/3... + col[0] "strings": ["1é" (null) (null) "4" "5"] + col[1] "bytes": ["1é" (null) (null) "4" "5"] +record 2/3... + col[0] "strings": ["11" (null) (null) "44" "55"] + col[1] "bytes": ["11" (null) (null) "44" "55"] +record 3/3... 
+ col[0] "strings": ["111" (null) (null) "444" "555"] + col[1] "bytes": ["111" (null) (null) "444" "555"] +`, + }, + } { + t.Run(fmt.Sprintf("%s-stream=%v", tc.name, tc.stream), func(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + fname := func() string { + f, err := ioutil.TempFile("", "go-arrow-") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + var w interface { + io.Closer + Write(array.Record) error + } + + switch { + case tc.stream: + w = ipc.NewWriter(f, ipc.WithSchema(arrdata.Records[tc.name][0].Schema()), ipc.WithAllocator(mem)) + default: + w, err = ipc.NewFileWriter(f, ipc.WithSchema(arrdata.Records[tc.name][0].Schema()), ipc.WithAllocator(mem)) + if err != nil { + t.Fatal(err) + } + } + defer w.Close() + + for _, rec := range arrdata.Records[tc.name] { + err = w.Write(rec) + if err != nil { + t.Fatal(err) + } + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + err = f.Close() + if err != nil { + t.Fatal(err) + } + + return f.Name() + }() + defer os.Remove(fname) + + w := new(bytes.Buffer) + err := processFile(w, fname) + if err != nil { + t.Fatal(err) + } + + if got, want := w.String(), tc.want; got != want { + t.Fatalf("invalid output:\ngot:\n%s\nwant:\n%s\n", got, want) + } + }) + } +} diff --git a/go/arrow/ipc/cmd/arrow-ls/main_test.go b/go/arrow/ipc/cmd/arrow-ls/main_test.go new file mode 100644 index 00000000000..7f77304f790 --- /dev/null +++ b/go/arrow/ipc/cmd/arrow-ls/main_test.go @@ -0,0 +1,270 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "os" + "testing" + + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/internal/arrdata" + "github.com/apache/arrow/go/arrow/ipc" + "github.com/apache/arrow/go/arrow/memory" +) + +func TestLsStream(t *testing.T) { + for _, tc := range []struct { + name string + want string + }{ + { + name: "primitives", + want: `schema: + fields: 11 + - bools: type=bool, nullable + - int8s: type=int8, nullable + - int16s: type=int16, nullable + - int32s: type=int32, nullable + - int64s: type=int64, nullable + - uint8s: type=uint8, nullable + - uint16s: type=uint16, nullable + - uint32s: type=uint32, nullable + - uint64s: type=uint64, nullable + - float32s: type=float32, nullable + - float64s: type=float64, nullable +metadata: ["k1": "v1", "k2": "v2", "k3": "v3"] +records: 3 +`, + }, + { + name: "structs", + want: `schema: + fields: 1 + - struct_nullable: type=struct, nullable +records: 2 +`, + }, + { + name: "lists", + want: `schema: + fields: 1 + - list_nullable: type=list, nullable +records: 3 +`, + }, + } { + t.Run(tc.name, func(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + fname := func() string { + f, err := ioutil.TempFile("", "go-arrow-") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + w := ipc.NewWriter(f, ipc.WithSchema(arrdata.Records[tc.name][0].Schema()), ipc.WithAllocator(mem)) + defer w.Close() + + for _, rec := range arrdata.Records[tc.name] { + err = w.Write(rec) + if err != nil { + 
t.Fatal(err) + } + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + err = f.Close() + if err != nil { + t.Fatal(err) + } + + return f.Name() + }() + defer os.Remove(fname) + + f, err := os.Open(fname) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + w := new(bytes.Buffer) + err = processStream(w, f) + if err != nil { + t.Fatal(err) + } + + if got, want := w.String(), tc.want; got != want { + t.Fatalf("invalid output:\ngot:\n%s\nwant:\n%s\n", got, want) + } + }) + } +} + +func TestLsFile(t *testing.T) { + for _, tc := range []struct { + stream bool + name string + want string + }{ + { + stream: true, + name: "primitives", + want: `schema: + fields: 11 + - bools: type=bool, nullable + - int8s: type=int8, nullable + - int16s: type=int16, nullable + - int32s: type=int32, nullable + - int64s: type=int64, nullable + - uint8s: type=uint8, nullable + - uint16s: type=uint16, nullable + - uint32s: type=uint32, nullable + - uint64s: type=uint64, nullable + - float32s: type=float32, nullable + - float64s: type=float64, nullable +metadata: ["k1": "v1", "k2": "v2", "k3": "v3"] +records: 3 +`, + }, + { + name: "primitives", + want: `version: V4 +schema: + fields: 11 + - bools: type=bool, nullable + - int8s: type=int8, nullable + - int16s: type=int16, nullable + - int32s: type=int32, nullable + - int64s: type=int64, nullable + - uint8s: type=uint8, nullable + - uint16s: type=uint16, nullable + - uint32s: type=uint32, nullable + - uint64s: type=uint64, nullable + - float32s: type=float32, nullable + - float64s: type=float64, nullable +metadata: ["k1": "v1", "k2": "v2", "k3": "v3"] +records: 3 +`, + }, + { + stream: true, + name: "structs", + want: `schema: + fields: 1 + - struct_nullable: type=struct, nullable +records: 2 +`, + }, + { + name: "structs", + want: `version: V4 +schema: + fields: 1 + - struct_nullable: type=struct, nullable +records: 2 +`, + }, + { + stream: true, + name: "lists", + want: `schema: + fields: 1 + - list_nullable: type=list, 
nullable +records: 3 +`, + }, + { + name: "lists", + want: `version: V4 +schema: + fields: 1 + - list_nullable: type=list, nullable +records: 3 +`, + }, + } { + t.Run(fmt.Sprintf("%s-stream=%v", tc.name, tc.stream), func(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + fname := func() string { + f, err := ioutil.TempFile("", "go-arrow-") + if err != nil { + t.Fatal(err) + } + defer f.Close() + + var w interface { + io.Closer + Write(array.Record) error + } + + switch { + case tc.stream: + w = ipc.NewWriter(f, ipc.WithSchema(arrdata.Records[tc.name][0].Schema()), ipc.WithAllocator(mem)) + default: + w, err = ipc.NewFileWriter(f, ipc.WithSchema(arrdata.Records[tc.name][0].Schema()), ipc.WithAllocator(mem)) + if err != nil { + t.Fatal(err) + } + } + defer w.Close() + + for _, rec := range arrdata.Records[tc.name] { + err = w.Write(rec) + if err != nil { + t.Fatal(err) + } + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + err = f.Close() + if err != nil { + t.Fatal(err) + } + + return f.Name() + }() + defer os.Remove(fname) + + w := new(bytes.Buffer) + err := processFile(w, fname) + if err != nil { + t.Fatal(err) + } + + if got, want := w.String(), tc.want; got != want { + t.Fatalf("invalid output:\ngot:\n%s\nwant:\n%s\n", got, want) + } + }) + } +} diff --git a/go/arrow/ipc/file_reader.go b/go/arrow/ipc/file_reader.go index 8fa3009d11e..81462ba0883 100644 --- a/go/arrow/ipc/file_reader.go +++ b/go/arrow/ipc/file_reader.go @@ -50,19 +50,16 @@ type FileReader struct { // NewFileReader opens an Arrow file using the provided reader r. func NewFileReader(r ReadAtSeeker, opts ...Option) (*FileReader, error) { var ( + cfg = newConfig(opts...) 
+ err error + f = FileReader{ r: r, fields: make(dictTypeMap), memo: newMemo(), } - cfg = newConfig() - err error ) - for _, opt := range opts { - opt(cfg) - } - if cfg.footer.offset <= 0 { cfg.footer.offset, err = f.r.Seek(0, io.SeekEnd) if err != nil { @@ -81,6 +78,10 @@ func NewFileReader(r ReadAtSeeker, opts ...Option) (*FileReader, error) { return nil, errors.Wrap(err, "arrow/ipc: could not decode schema") } + if cfg.schema != nil && !cfg.schema.Equal(f.schema) { + return nil, errors.Errorf("arrow/ipc: inconsitent schema for reading (got: %v, want: %v)", f.schema, cfg.schema) + } + return &f, err } @@ -88,7 +89,7 @@ func (f *FileReader) readFooter() error { var err error if f.footer.offset <= int64(len(Magic)*2+4) { - return fmt.Errorf("arrow/ipc: file too small (%d)", f.footer.offset) + return fmt.Errorf("arrow/ipc: file too small (size=%d)", f.footer.offset) } eof := int64(len(Magic) + 4) diff --git a/go/arrow/ipc/file_test.go b/go/arrow/ipc/file_test.go new file mode 100644 index 00000000000..1f0d8214094 --- /dev/null +++ b/go/arrow/ipc/file_test.go @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ipc_test + +import ( + "io" + "io/ioutil" + "os" + "testing" + + "github.com/apache/arrow/go/arrow/internal/arrdata" + "github.com/apache/arrow/go/arrow/ipc" + "github.com/apache/arrow/go/arrow/memory" +) + +func TestFile(t *testing.T) { + for name, recs := range arrdata.Records { + t.Run(name, func(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + f, err := ioutil.TempFile("", "arrow-ipc-") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + { + w, err := ipc.NewFileWriter(f, ipc.WithSchema(recs[0].Schema()), ipc.WithAllocator(mem)) + if err != nil { + t.Fatal(err) + } + defer w.Close() + + for i, rec := range recs { + err = w.Write(rec) + if err != nil { + t.Fatalf("could not write record[%d]: %v", i, err) + } + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + + err = f.Sync() + if err != nil { + t.Fatalf("could not sync data to disk: %v", err) + } + + _, err = f.Seek(0, io.SeekStart) + if err != nil { + t.Fatalf("could not seek to start: %v", err) + } + } + + { + r, err := ipc.NewFileReader(f, ipc.WithSchema(recs[0].Schema()), ipc.WithAllocator(mem)) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if got, want := r.NumRecords(), len(recs); got != want { + t.Fatalf("invalid number of records. got=%d, want=%d", got, want) + } + + for i := 0; i < r.NumRecords(); i++ { + rec, err := r.Record(i) + if err != nil { + t.Fatalf("could not read record %d: %v", i, err) + } + if !cmpRecs(rec, recs[i]) { + t.Fatalf("records[%d] differ", i) + } + } + + err = r.Close() + if err != nil { + t.Fatal(err) + } + } + }) + } + +} diff --git a/go/arrow/ipc/file_writer.go b/go/arrow/ipc/file_writer.go new file mode 100644 index 00000000000..1fb6f529302 --- /dev/null +++ b/go/arrow/ipc/file_writer.go @@ -0,0 +1,333 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipc // import "github.com/apache/arrow/go/arrow/ipc" + +import ( + "encoding/binary" + "io" + + "github.com/apache/arrow/go/arrow" + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/internal/bitutil" + "github.com/apache/arrow/go/arrow/internal/flatbuf" + "github.com/apache/arrow/go/arrow/memory" + "github.com/pkg/errors" +) + +type payloadWriter interface { + start() error + write(payload) error + Close() error +} + +type pwriter struct { + w io.WriteSeeker + pos int64 + + schema *arrow.Schema + dicts []fileBlock + recs []fileBlock +} + +func (w *pwriter) start() error { + var err error + + err = w.updatePos() + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not update position while in start") + } + + // only necessary to align to 8-byte boundary at the start of the file + _, err = w.Write(Magic) + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not write magic Arrow bytes") + } + + err = w.align(kArrowIPCAlignment) + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not align start block") + } + + return err +} + +func (w *pwriter) write(p payload) error { + blk := fileBlock{Offset: w.pos, Meta: 0, Body: p.size} + n, err := writeIPCPayload(w, p) + if err != nil { + return err + } + + blk.Meta = int32(n) 
+ + err = w.updatePos() + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not update position while in write-payload") + } + + switch byte(p.msg) { + case flatbuf.MessageHeaderDictionaryBatch: + w.dicts = append(w.dicts, blk) + case flatbuf.MessageHeaderRecordBatch: + w.recs = append(w.recs, blk) + } + + return nil +} + +func (w *pwriter) Close() error { + var err error + + // write file footer + err = w.updatePos() + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not update position while in close") + } + + pos := w.pos + err = writeFileFooter(w.schema, w.dicts, w.recs, w) + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not write file footer") + } + + // write file footer length + err = w.updatePos() // not strictly needed as we passed w to writeFileFooter... + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not compute file footer length") + } + + size := w.pos - pos + if size <= 0 { + return errors.Errorf("arrow/ipc: invalid file footer size (size=%d)", size) + } + + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, uint32(size)) + _, err = w.Write(buf) + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not write file footer size") + } + + _, err = w.Write(Magic) + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not write Arrow magic bytes") + } + + return nil +} + +func (w *pwriter) updatePos() error { + var err error + w.pos, err = w.w.Seek(0, io.SeekCurrent) + return err +} + +func (w *pwriter) align(align int32) error { + remainder := paddedLength(w.pos, align) - w.pos + if remainder == 0 { + return nil + } + + _, err := w.Write(paddingBytes[:int(remainder)]) + return err +} + +func (w *pwriter) Write(p []byte) (int, error) { + n, err := w.w.Write(p) + w.pos += int64(n) + return n, err +} + +func writeIPCPayload(w io.Writer, p payload) (int, error) { + n, err := writeMessage(p.meta, kArrowIPCAlignment, w) + if err != nil { + return n, err + } + + // now write the buffers + 
for _, buf := range p.body { + var ( + size int64 + padding int64 + ) + + // the buffer might be null if we are handling zero row lengths. + if buf != nil { + size = int64(buf.Len()) + padding = bitutil.CeilByte64(size) - size + } + + if size > 0 { + _, err = w.Write(buf.Bytes()) + if err != nil { + return n, errors.Wrap(err, "arrow/ipc: could not write payload message body") + } + } + + if padding > 0 { + _, err = w.Write(paddingBytes[:padding]) + if err != nil { + return n, errors.Wrap(err, "arrow/ipc: could not write payload message padding") + } + } + } + + return n, err +} + +type payload struct { + msg MessageType + meta *memory.Buffer + body []*memory.Buffer + size int64 // length of body +} + +func (p *payload) Release() { + if p.meta != nil { + p.meta.Release() + p.meta = nil + } + for i, b := range p.body { + if b == nil { + continue + } + b.Release() + p.body[i] = nil + } +} + +type payloads []payload + +func (ps payloads) Release() { + for i := range ps { + ps[i].Release() + } +} + +// FileWriter is an Arrow file writer. +type FileWriter struct { + w io.WriteSeeker + + mem memory.Allocator + + header struct { + started bool + offset int64 + } + + footer struct { + written bool + } + + pw payloadWriter + + schema *arrow.Schema +} + +// NewFileWriter opens an Arrow file using the provided writer w. +func NewFileWriter(w io.WriteSeeker, opts ...Option) (*FileWriter, error) { + var ( + cfg = newConfig(opts...) 
+ err error + ) + + f := FileWriter{ + w: w, + pw: &pwriter{w: w, schema: cfg.schema, pos: -1}, + mem: cfg.alloc, + schema: cfg.schema, + } + + pos, err := f.w.Seek(0, io.SeekCurrent) + if err != nil { + return nil, errors.Errorf("arrow/ipc: could not seek current position: %v", err) + } + f.header.offset = pos + + return &f, err +} + +func (f *FileWriter) Close() error { + err := f.checkStarted() + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not write empty file") + } + + if f.footer.written { + return nil + } + + err = f.pw.Close() + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not close payload writer") + } + f.footer.written = true + + return nil +} + +func (f *FileWriter) Write(rec array.Record) error { + schema := rec.Schema() + if schema == nil || !schema.Equal(f.schema) { + return errInconsistentSchema + } + + if err := f.checkStarted(); err != nil { + return errors.Wrap(err, "arrow/ipc: could not write header") + } + + const allow64b = true + var ( + data = payload{msg: MessageRecordBatch} + enc = newRecordEncoder(f.mem, 0, kMaxNestingDepth, allow64b) + ) + defer data.Release() + + if err := enc.Encode(&data, rec); err != nil { + return errors.Wrap(err, "arrow/ipc: could not encode record to payload") + } + + return f.pw.write(data) +} + +func (f *FileWriter) checkStarted() error { + if !f.header.started { + return f.start() + } + return nil +} + +func (f *FileWriter) start() error { + f.header.started = true + err := f.pw.start() + if err != nil { + return err + } + + // write out schema payloads + ps := payloadsFromSchema(f.schema, f.mem, nil) + defer ps.Release() + + for _, data := range ps { + err = f.pw.write(data) + if err != nil { + return err + } + } + + return nil +} diff --git a/go/arrow/ipc/ipc.go b/go/arrow/ipc/ipc.go index fc83e94609a..470baf2fba1 100644 --- a/go/arrow/ipc/ipc.go +++ b/go/arrow/ipc/ipc.go @@ -19,14 +19,32 @@ package ipc // import "github.com/apache/arrow/go/arrow/ipc" import ( "io" + 
"github.com/apache/arrow/go/arrow" "github.com/apache/arrow/go/arrow/memory" ) const ( errNotArrowFile = errString("arrow/ipc: not an Arrow file") errInconsistentFileMetadata = errString("arrow/ipc: file is smaller than indicated metadata size") + errInconsistentSchema = errString("arrow/ipc: tried to write record batch with different schema") + errMaxRecursion = errString("arrow/ipc: max recursion depth reached") + errBigArray = errString("arrow/ipc: array larger than 2^31-1 in length") + + kArrowAlignment = 64 // buffers are padded to 64b boundaries (for SIMD) + kTensorAlignment = 64 // tensors are padded to 64b boundaries + kArrowIPCAlignment = 8 // align on 8b boundaries in IPC +) + +var ( + paddingBytes [kArrowAlignment]byte + kEOS = [4]byte{0, 0, 0, 0} // end of stream message ) +func paddedLength(nbytes int64, alignment int32) int64 { + align := int64(alignment) + return ((nbytes + align - 1) / align) * align +} + type errString string func (s errString) Error() string { @@ -41,15 +59,22 @@ type ReadAtSeeker interface { type config struct { alloc memory.Allocator + schema *arrow.Schema footer struct { offset int64 } } -func newConfig() *config { - return &config{ +func newConfig(opts ...Option) *config { + cfg := &config{ alloc: memory.NewGoAllocator(), } + + for _, opt := range opts { + opt(cfg) + } + + return cfg } // Option is a functional option to configure opening or creating Arrow files @@ -69,3 +94,10 @@ func WithAllocator(mem memory.Allocator) Option { cfg.alloc = mem } } + +// WithSchema specifies the Arrow schema to be used for reading or writing. +func WithSchema(schema *arrow.Schema) Option { + return func(cfg *config) { + cfg.schema = schema + } +} diff --git a/go/arrow/ipc/ipc_test.go b/go/arrow/ipc/ipc_test.go new file mode 100644 index 00000000000..01be713cd30 --- /dev/null +++ b/go/arrow/ipc/ipc_test.go @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipc_test + +import ( + "fmt" + "io" + "strings" + + "github.com/apache/arrow/go/arrow/array" +) + +func cmpRecs(r1, r2 array.Record) bool { + // FIXME(sbinet): impl+use arrow.Record.Equal ? + + if !r1.Schema().Equal(r2.Schema()) { + return false + } + if r1.NumCols() != r2.NumCols() { + return false + } + if r1.NumRows() != r2.NumRows() { + return false + } + + var ( + txt1 = new(strings.Builder) + txt2 = new(strings.Builder) + ) + + printRec(txt1, r1) + printRec(txt2, r2) + + return txt1.String() == txt2.String() +} + +func printRec(w io.Writer, rec array.Record) { + for i, col := range rec.Columns() { + fmt.Fprintf(w, " col[%d] %q: %v\n", i, rec.ColumnName(i), col) + } +} diff --git a/go/arrow/ipc/metadata.go b/go/arrow/ipc/metadata.go index 93b05315370..b9718bd1073 100644 --- a/go/arrow/ipc/metadata.go +++ b/go/arrow/ipc/metadata.go @@ -17,8 +17,10 @@ package ipc // import "github.com/apache/arrow/go/arrow/ipc" import ( + "encoding/binary" "fmt" "io" + "sort" "github.com/apache/arrow/go/arrow" "github.com/apache/arrow/go/arrow/internal/flatbuf" @@ -43,6 +45,8 @@ const ( kMaxNestingDepth = 64 ) +type startVecFunc func(b *flatbuffers.Builder, n int) flatbuffers.UOffsetT + type fieldMetadata struct { Len int64 Nulls int64 @@ -62,6 +66,16 @@ type fileBlock struct { r 
io.ReaderAt } +func fileBlocksToFB(b *flatbuffers.Builder, blocks []fileBlock, start startVecFunc) flatbuffers.UOffsetT { + start(b, len(blocks)) + for i := len(blocks) - 1; i >= 0; i-- { + blk := blocks[i] + flatbuf.CreateBlock(b, blk.Offset, blk.Meta, blk.Body) + } + + return b.EndVector(len(blocks)) +} + func (blk fileBlock) NewMessage() (*Message, error) { var ( err error @@ -90,6 +104,36 @@ func (blk fileBlock) section() io.Reader { return io.NewSectionReader(blk.r, blk.Offset, int64(blk.Meta)+blk.Body) } +func unitFromFB(unit flatbuf.TimeUnit) arrow.TimeUnit { + switch unit { + case flatbuf.TimeUnitSECOND: + return arrow.Second + case flatbuf.TimeUnitMILLISECOND: + return arrow.Millisecond + case flatbuf.TimeUnitMICROSECOND: + return arrow.Microsecond + case flatbuf.TimeUnitNANOSECOND: + return arrow.Nanosecond + default: + panic(errors.Errorf("arrow/ipc: invalid flatbuf.TimeUnit(%d) value", unit)) + } +} + +func unitToFB(unit arrow.TimeUnit) flatbuf.TimeUnit { + switch unit { + case arrow.Second: + return flatbuf.TimeUnitSECOND + case arrow.Millisecond: + return flatbuf.TimeUnitMILLISECOND + case arrow.Microsecond: + return flatbuf.TimeUnitMICROSECOND + case arrow.Nanosecond: + return flatbuf.TimeUnitNANOSECOND + default: + panic(errors.Errorf("arrow/ipc: invalid arrow.TimeUnit(%d) value", unit)) + } +} + // initFB is a helper function to handle flatbuffers' polymorphism. 
func initFB(t interface { Table() flatbuffers.Table @@ -110,7 +154,7 @@ func fieldFromFB(field *flatbuf.Field, memo *dictMemo) (arrow.Field, error) { o.Name = string(field.Name()) o.Nullable = field.Nullable() - o.Metadata, err = metadataFrom(field) + o.Metadata, err = metadataFromFB(field) if err != nil { return o, err } @@ -137,15 +181,221 @@ func fieldFromFB(field *flatbuf.Field, memo *dictMemo) (arrow.Field, error) { return o, errors.Wrapf(err, "arrow/ipc: could not convert field type") } default: - // log.Printf("encoding: %v", encoding.Id()) - // n := field.ChildrenLength() - // log.Printf("children: %v", n) panic("not implemented") // FIXME(sbinet) } return o, nil } +func fieldToFB(b *flatbuffers.Builder, field arrow.Field, memo *dictMemo) flatbuffers.UOffsetT { + var visitor = fieldVisitor{b: b, memo: memo, meta: make(map[string]string)} + return visitor.result(field) +} + +type fieldVisitor struct { + b *flatbuffers.Builder + memo *dictMemo + dtype flatbuf.Type + offset flatbuffers.UOffsetT + kids []flatbuffers.UOffsetT + meta map[string]string +} + +func (fv *fieldVisitor) visit(dt arrow.DataType) { + switch dt := dt.(type) { + case *arrow.NullType: + fv.dtype = flatbuf.TypeNull + flatbuf.NullStart(fv.b) + fv.offset = flatbuf.NullEnd(fv.b) + + case *arrow.BooleanType: + fv.dtype = flatbuf.TypeBool + flatbuf.BoolStart(fv.b) + fv.offset = flatbuf.BoolEnd(fv.b) + + case *arrow.Uint8Type: + fv.dtype = flatbuf.TypeInt + fv.offset = intToFB(fv.b, int32(dt.BitWidth()), false) + + case *arrow.Uint16Type: + fv.dtype = flatbuf.TypeInt + fv.offset = intToFB(fv.b, int32(dt.BitWidth()), false) + + case *arrow.Uint32Type: + fv.dtype = flatbuf.TypeInt + fv.offset = intToFB(fv.b, int32(dt.BitWidth()), false) + + case *arrow.Uint64Type: + fv.dtype = flatbuf.TypeInt + fv.offset = intToFB(fv.b, int32(dt.BitWidth()), false) + + case *arrow.Int8Type: + fv.dtype = flatbuf.TypeInt + fv.offset = intToFB(fv.b, int32(dt.BitWidth()), true) + + case *arrow.Int16Type: + fv.dtype = 
flatbuf.TypeInt + fv.offset = intToFB(fv.b, int32(dt.BitWidth()), true) + + case *arrow.Int32Type: + fv.dtype = flatbuf.TypeInt + fv.offset = intToFB(fv.b, int32(dt.BitWidth()), true) + + case *arrow.Int64Type: + fv.dtype = flatbuf.TypeInt + fv.offset = intToFB(fv.b, int32(dt.BitWidth()), true) + + case *arrow.Float32Type: + fv.dtype = flatbuf.TypeFloatingPoint + fv.offset = floatToFB(fv.b, int32(dt.BitWidth())) + + case *arrow.Float64Type: + fv.dtype = flatbuf.TypeFloatingPoint + fv.offset = floatToFB(fv.b, int32(dt.BitWidth())) + + case *arrow.FixedSizeBinaryType: + fv.dtype = flatbuf.TypeFixedSizeBinary + flatbuf.FixedSizeBinaryStart(fv.b) + flatbuf.FixedSizeBinaryAddByteWidth(fv.b, int32(dt.ByteWidth)) + fv.offset = flatbuf.FixedSizeBinaryEnd(fv.b) + + case *arrow.BinaryType: + fv.dtype = flatbuf.TypeBinary + flatbuf.BinaryStart(fv.b) + fv.offset = flatbuf.BinaryEnd(fv.b) + + case *arrow.StringType: + fv.dtype = flatbuf.TypeUtf8 + flatbuf.Utf8Start(fv.b) + fv.offset = flatbuf.Utf8End(fv.b) + + case *arrow.Date32Type: + fv.dtype = flatbuf.TypeDate + flatbuf.DateStart(fv.b) + flatbuf.DateAddUnit(fv.b, flatbuf.DateUnitDAY) + fv.offset = flatbuf.DateEnd(fv.b) + + case *arrow.Date64Type: + fv.dtype = flatbuf.TypeDate + flatbuf.DateStart(fv.b) + flatbuf.DateAddUnit(fv.b, flatbuf.DateUnitMILLISECOND) + fv.offset = flatbuf.DateEnd(fv.b) + + case *arrow.Time32Type: + fv.dtype = flatbuf.TypeTime + flatbuf.TimeStart(fv.b) + flatbuf.TimeAddUnit(fv.b, unitToFB(dt.Unit)) + flatbuf.TimeAddBitWidth(fv.b, 32) + fv.offset = flatbuf.TimeEnd(fv.b) + + case *arrow.Time64Type: + fv.dtype = flatbuf.TypeTime + flatbuf.TimeStart(fv.b) + flatbuf.TimeAddUnit(fv.b, unitToFB(dt.Unit)) + flatbuf.TimeAddBitWidth(fv.b, 64) + fv.offset = flatbuf.TimeEnd(fv.b) + + case *arrow.TimestampType: + fv.dtype = flatbuf.TypeTimestamp + unit := unitToFB(dt.Unit) + tz := fv.b.CreateString(dt.TimeZone) + flatbuf.TimestampStart(fv.b) + flatbuf.TimestampAddUnit(fv.b, unit) + 
flatbuf.TimestampAddTimezone(fv.b, tz) + fv.offset = flatbuf.TimestampEnd(fv.b) + + case *arrow.StructType: + fv.dtype = flatbuf.TypeStruct_ + offsets := make([]flatbuffers.UOffsetT, len(dt.Fields())) + for i, field := range dt.Fields() { + offsets[i] = fieldToFB(fv.b, field, fv.memo) + } + flatbuf.Struct_Start(fv.b) + for i := len(offsets) - 1; i >= 0; i-- { + fv.b.PrependUOffsetT(offsets[i]) + } + fv.offset = flatbuf.Struct_End(fv.b) + fv.kids = append(fv.kids, offsets...) + + case *arrow.ListType: + fv.dtype = flatbuf.TypeList + fv.kids = append(fv.kids, fieldToFB(fv.b, arrow.Field{Name: "item", Type: dt.Elem()}, fv.memo)) + flatbuf.ListStart(fv.b) + fv.offset = flatbuf.ListEnd(fv.b) + + default: + err := errors.Errorf("arrow/ipc: invalid data type %v", dt) + panic(err) // FIXME(sbinet): implement all data-types. + } +} + +func (fv *fieldVisitor) result(field arrow.Field) flatbuffers.UOffsetT { + nameFB := fv.b.CreateString(field.Name) + + fv.visit(field.Type) + + flatbuf.FieldStartChildrenVector(fv.b, len(fv.kids)) + for i := len(fv.kids) - 1; i >= 0; i-- { + fv.b.PrependUOffsetT(fv.kids[i]) + } + kidsFB := fv.b.EndVector(len(fv.kids)) + + var dictFB flatbuffers.UOffsetT + if field.Type.ID() == arrow.DICTIONARY { + panic("not implemented") // FIXME(sbinet) + } + + var ( + metaFB flatbuffers.UOffsetT + kvs []flatbuffers.UOffsetT + ) + for i, k := range field.Metadata.Keys() { + v := field.Metadata.Values()[i] + kk := fv.b.CreateString(k) + vv := fv.b.CreateString(v) + flatbuf.KeyValueStart(fv.b) + flatbuf.KeyValueAddKey(fv.b, kk) + flatbuf.KeyValueAddValue(fv.b, vv) + kvs = append(kvs, flatbuf.KeyValueEnd(fv.b)) + } + { + keys := make([]string, 0, len(fv.meta)) + for k := range fv.meta { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + v := fv.meta[k] + kk := fv.b.CreateString(k) + vv := fv.b.CreateString(v) + flatbuf.KeyValueStart(fv.b) + flatbuf.KeyValueAddKey(fv.b, kk) + flatbuf.KeyValueAddValue(fv.b, vv) + kvs = append(kvs, 
flatbuf.KeyValueEnd(fv.b)) + } + } + if len(kvs) > 0 { + flatbuf.FieldStartCustomMetadataVector(fv.b, len(kvs)) + for i := len(kvs) - 1; i >= 0; i-- { + fv.b.PrependUOffsetT(kvs[i]) + } + metaFB = fv.b.EndVector(len(kvs)) + } + + flatbuf.FieldStart(fv.b) + flatbuf.FieldAddName(fv.b, nameFB) + flatbuf.FieldAddNullable(fv.b, field.Nullable) + flatbuf.FieldAddTypeType(fv.b, fv.dtype) + flatbuf.FieldAddType(fv.b, fv.offset) + flatbuf.FieldAddDictionary(fv.b, dictFB) + flatbuf.FieldAddChildren(fv.b, kidsFB) + flatbuf.FieldAddCustomMetadata(fv.b, metaFB) + + offset := flatbuf.FieldEnd(fv.b) + + return offset +} + func fieldFromFBDict(field *flatbuf.Field) (arrow.Field, error) { var ( o = arrow.Field{ @@ -170,7 +420,7 @@ func fieldFromFBDict(field *flatbuf.Field) (arrow.Field, error) { } } - meta, err := metadataFrom(field) + meta, err := metadataFromFB(field) if err != nil { return o, errors.Wrap(err, "arrow/ipc: metadata for field from dict") } @@ -299,6 +549,13 @@ func intFromFB(data flatbuf.Int) (arrow.DataType, error) { } } +func intToFB(b *flatbuffers.Builder, bw int32, isSigned bool) flatbuffers.UOffsetT { + flatbuf.IntStart(b) + flatbuf.IntAddBitWidth(b, bw) + flatbuf.IntAddIsSigned(b, isSigned) + return flatbuf.IntEnd(b) +} + func floatFromFB(data flatbuf.FloatingPoint) (arrow.DataType, error) { switch p := data.Precision(); p { case flatbuf.PrecisionHALF: @@ -312,12 +569,31 @@ func floatFromFB(data flatbuf.FloatingPoint) (arrow.DataType, error) { } } +func floatToFB(b *flatbuffers.Builder, bw int32) flatbuffers.UOffsetT { + switch bw { + case 16: + flatbuf.FloatingPointStart(b) + flatbuf.FloatingPointAddPrecision(b, flatbuf.PrecisionHALF) + return flatbuf.FloatingPointEnd(b) + case 32: + flatbuf.FloatingPointStart(b) + flatbuf.FloatingPointAddPrecision(b, flatbuf.PrecisionSINGLE) + return flatbuf.FloatingPointEnd(b) + case 64: + flatbuf.FloatingPointStart(b) + flatbuf.FloatingPointAddPrecision(b, flatbuf.PrecisionDOUBLE) + return flatbuf.FloatingPointEnd(b) + 
default: + panic(errors.Errorf("arrow/ipc: invalid floating point precision %d-bits", bw)) + } +} + type customMetadataer interface { CustomMetadataLength() int CustomMetadata(*flatbuf.KeyValue, int) bool } -func metadataFrom(md customMetadataer) (arrow.Metadata, error) { +func metadataFromFB(md customMetadataer) (arrow.Metadata, error) { var ( keys = make([]string, md.CustomMetadataLength()) vals = make([]string, md.CustomMetadataLength()) @@ -335,6 +611,29 @@ func metadataFrom(md customMetadataer) (arrow.Metadata, error) { return arrow.NewMetadata(keys, vals), nil } +func metadataToFB(b *flatbuffers.Builder, meta arrow.Metadata, start startVecFunc) flatbuffers.UOffsetT { + if meta.Len() == 0 { + return 0 + } + + n := meta.Len() + kvs := make([]flatbuffers.UOffsetT, n) + for i := range kvs { + k := b.CreateString(meta.Keys()[i]) + v := b.CreateString(meta.Values()[i]) + flatbuf.KeyValueStart(b) + flatbuf.KeyValueAddKey(b, k) + flatbuf.KeyValueAddValue(b, v) + kvs[i] = flatbuf.KeyValueEnd(b) + } + + start(b, n) + for i := n - 1; i >= 0; i-- { + b.PrependUOffsetT(kvs[i]) + } + return b.EndVector(n) +} + func schemaFromFB(schema *flatbuf.Schema, memo *dictMemo) (*arrow.Schema, error) { var ( err error @@ -353,7 +652,7 @@ func schemaFromFB(schema *flatbuf.Schema, memo *dictMemo) (*arrow.Schema, error) } } - md, err := metadataFrom(schema) + md, err := metadataFromFB(schema) if err != nil { return nil, errors.Wrapf(err, "arrow/ipc: could not convert schema metadata from flatbuf") } @@ -361,6 +660,29 @@ func schemaFromFB(schema *flatbuf.Schema, memo *dictMemo) (*arrow.Schema, error) return arrow.NewSchema(fields, &md), nil } +func schemaToFB(b *flatbuffers.Builder, schema *arrow.Schema, memo *dictMemo) flatbuffers.UOffsetT { + fields := make([]flatbuffers.UOffsetT, len(schema.Fields())) + for i, field := range schema.Fields() { + fields[i] = fieldToFB(b, field, memo) + } + + flatbuf.SchemaStartFieldsVector(b, len(fields)) + for i := len(fields) - 1; i >= 0; i-- { + 
b.PrependUOffsetT(fields[i]) + } + fieldsFB := b.EndVector(len(fields)) + + metaFB := metadataToFB(b, schema.Metadata(), flatbuf.SchemaStartCustomMetadataVector) + + flatbuf.SchemaStart(b) + flatbuf.SchemaAddEndianness(b, flatbuf.EndiannessLittle) + flatbuf.SchemaAddFields(b, fieldsFB) + flatbuf.SchemaAddCustomMetadata(b, metaFB) + offset := flatbuf.SchemaEnd(b) + + return offset +} + func dictTypesFromFB(schema *flatbuf.Schema) (dictTypeMap, error) { var ( err error @@ -407,3 +729,161 @@ func visitField(field *flatbuf.Field, dict dictTypeMap) (dictTypeMap, error) { } return dict, err } + +// payloadsFromSchema returns a slice of payloads corresponding to the given schema. +// Callers of payloadsFromSchema will need to call Release after use. +func payloadsFromSchema(schema *arrow.Schema, mem memory.Allocator, memo *dictMemo) payloads { + dict := newMemo() + + ps := make(payloads, 1, dict.Len()+1) + ps[0].msg = MessageSchema + ps[0].meta = writeSchemaMessage(schema, mem, &dict) + + // append dictionaries. 
+ if dict.Len() > 0 { + panic("payloads-from-schema: not-implemented") + // for id, arr := range dict.id2dict { + // // GetSchemaPayloads: writer.cc:535 + // } + } + + if memo != nil { + *memo = dict + } + + return ps +} + +func writeFBBuilder(b *flatbuffers.Builder, mem memory.Allocator) *memory.Buffer { + raw := b.FinishedBytes() + buf := memory.NewResizableBuffer(mem) + buf.Resize(len(raw)) + copy(buf.Bytes(), raw) + return buf +} + +func writeMessageFB(b *flatbuffers.Builder, mem memory.Allocator, hdrType flatbuf.MessageHeader, hdr flatbuffers.UOffsetT, bodyLen int64) *memory.Buffer { + + flatbuf.MessageStart(b) + flatbuf.MessageAddVersion(b, int16(currentMetadataVersion)) + flatbuf.MessageAddHeaderType(b, hdrType) + flatbuf.MessageAddHeader(b, hdr) + flatbuf.MessageAddBodyLength(b, bodyLen) + msg := flatbuf.MessageEnd(b) + b.Finish(msg) + + return writeFBBuilder(b, mem) +} + +func writeSchemaMessage(schema *arrow.Schema, mem memory.Allocator, dict *dictMemo) *memory.Buffer { + b := flatbuffers.NewBuilder(1024) + schemaFB := schemaToFB(b, schema, dict) + return writeMessageFB(b, mem, flatbuf.MessageHeaderSchema, schemaFB, 0) +} + +func writeFileFooter(schema *arrow.Schema, dicts, recs []fileBlock, w io.Writer) error { + var ( + b = flatbuffers.NewBuilder(1024) + memo = newMemo() + ) + + schemaFB := schemaToFB(b, schema, &memo) + dictsFB := fileBlocksToFB(b, dicts, flatbuf.FooterStartDictionariesVector) + recsFB := fileBlocksToFB(b, recs, flatbuf.FooterStartRecordBatchesVector) + + flatbuf.FooterStart(b) + flatbuf.FooterAddVersion(b, int16(currentMetadataVersion)) + flatbuf.FooterAddSchema(b, schemaFB) + flatbuf.FooterAddDictionaries(b, dictsFB) + flatbuf.FooterAddRecordBatches(b, recsFB) + footer := flatbuf.FooterEnd(b) + + b.Finish(footer) + + _, err := w.Write(b.FinishedBytes()) + return err +} + +func writeRecordMessage(mem memory.Allocator, size, bodyLength int64, fields []fieldMetadata, meta []bufferMetadata) *memory.Buffer { + b := 
flatbuffers.NewBuilder(0) + recFB := recordToFB(b, size, bodyLength, fields, meta) + return writeMessageFB(b, mem, flatbuf.MessageHeaderRecordBatch, recFB, bodyLength) +} + +func recordToFB(b *flatbuffers.Builder, size, bodyLength int64, fields []fieldMetadata, meta []bufferMetadata) flatbuffers.UOffsetT { + fieldsFB := writeFieldNodes(b, fields, flatbuf.RecordBatchStartNodesVector) + metaFB := writeBuffers(b, meta, flatbuf.RecordBatchStartBuffersVector) + + flatbuf.RecordBatchStart(b) + flatbuf.RecordBatchAddLength(b, size) + flatbuf.RecordBatchAddNodes(b, fieldsFB) + flatbuf.RecordBatchAddBuffers(b, metaFB) + return flatbuf.RecordBatchEnd(b) +} + +func writeFieldNodes(b *flatbuffers.Builder, fields []fieldMetadata, start startVecFunc) flatbuffers.UOffsetT { + + start(b, len(fields)) + for i := len(fields) - 1; i >= 0; i-- { + field := fields[i] + if field.Offset != 0 { + panic(errors.Errorf("arrow/ipc: field metadata for IPC must have offset 0")) + } + flatbuf.CreateFieldNode(b, field.Len, field.Nulls) + } + + return b.EndVector(len(fields)) +} + +func writeBuffers(b *flatbuffers.Builder, buffers []bufferMetadata, start startVecFunc) flatbuffers.UOffsetT { + start(b, len(buffers)) + for i := len(buffers) - 1; i >= 0; i-- { + buffer := buffers[i] + flatbuf.CreateBuffer(b, buffer.Offset, buffer.Len) + } + return b.EndVector(len(buffers)) +} + +func writeMessage(msg *memory.Buffer, alignment int32, w io.Writer) (int, error) { + var ( + n int + err error + ) + + // ARROW-3212: we do not make any assumption on whether the output stream is aligned or not. 
+ paddedMsgLen := int32(msg.Len()) + 4 + remainder := paddedMsgLen % alignment + if remainder != 0 { + paddedMsgLen += alignment - remainder + } + + // the returned message size includes the length prefix, the flatbuffer, + padding + n = int(paddedMsgLen) + + tmp := make([]byte, 4) + + // write the flatbuffer size prefix, including padding + sizeFB := paddedMsgLen - 4 + binary.LittleEndian.PutUint32(tmp, uint32(sizeFB)) + _, err = w.Write(tmp) + if err != nil { + return n, errors.Wrap(err, "arrow/ipc: could not write message flatbuffer size prefix") + } + + // write the flatbuffer + _, err = w.Write(msg.Bytes()) + if err != nil { + return n, errors.Wrap(err, "arrow/ipc: could not write message flatbuffer") + } + + // write any padding + padding := paddedMsgLen - int32(msg.Len()) - 4 + if padding > 0 { + _, err = w.Write(paddingBytes[:padding]) + if err != nil { + return n, errors.Wrap(err, "arrow/ipc: could not write message padding bytes") + } + } + + return n, err +} diff --git a/go/arrow/ipc/metadata_test.go b/go/arrow/ipc/metadata_test.go new file mode 100644 index 00000000000..974267239e4 --- /dev/null +++ b/go/arrow/ipc/metadata_test.go @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ipc // import "github.com/apache/arrow/go/arrow/ipc" + +import ( + "bytes" + "reflect" + "testing" + + "github.com/apache/arrow/go/arrow" + "github.com/apache/arrow/go/arrow/internal/flatbuf" + flatbuffers "github.com/google/flatbuffers/go" +) + +func TestRWSchema(t *testing.T) { + meta := arrow.NewMetadata([]string{"k1", "k2", "k3"}, []string{"v1", "v2", "v3"}) + for _, tc := range []struct { + schema *arrow.Schema + memo dictMemo + }{ + { + schema: arrow.NewSchema([]arrow.Field{ + {Name: "f1", Type: arrow.PrimitiveTypes.Int64}, + {Name: "f2", Type: arrow.PrimitiveTypes.Uint16}, + {Name: "f3", Type: arrow.PrimitiveTypes.Float64}, + }, &meta), + memo: newMemo(), + }, + } { + t.Run("", func(t *testing.T) { + b := flatbuffers.NewBuilder(0) + + offset := schemaToFB(b, tc.schema, &tc.memo) + b.Finish(offset) + + buf := b.FinishedBytes() + + fb := flatbuf.GetRootAsSchema(buf, 0) + got, err := schemaFromFB(fb, &tc.memo) + if err != nil { + t.Fatal(err) + } + + if !got.Equal(tc.schema) { + t.Fatalf("r/w schema failed:\ngot = %#v\nwant= %#v\n", got, tc.schema) + } + + { + got := got.Metadata() + want := tc.schema.Metadata() + if got.Len() != want.Len() { + t.Fatalf("invalid metadata len: got=%d, want=%d", got.Len(), want.Len()) + } + if got, want := got.Keys(), want.Keys(); !reflect.DeepEqual(got, want) { + t.Fatalf("invalid metadata keys:\ngot =%v\nwant=%v\n", got, want) + } + if got, want := got.Values(), want.Values(); !reflect.DeepEqual(got, want) { + t.Fatalf("invalid metadata values:\ngot =%v\nwant=%v\n", got, want) + } + } + }) + } +} + +func TestRWFooter(t *testing.T) { + for _, tc := range []struct { + schema *arrow.Schema + dicts []fileBlock + recs []fileBlock + }{ + { + schema: arrow.NewSchema([]arrow.Field{ + {Name: "f1", Type: arrow.PrimitiveTypes.Int64}, + {Name: "f2", Type: arrow.PrimitiveTypes.Uint16}, + {Name: "f3", Type: arrow.PrimitiveTypes.Float64}, + }, nil), + dicts: []fileBlock{ + {Offset: 1, Meta: 2, Body: 3}, + {Offset: 4, Meta: 5, Body: 
6}, + {Offset: 7, Meta: 8, Body: 9}, + }, + recs: []fileBlock{ + {Offset: 0, Meta: 10, Body: 30}, + {Offset: 10, Meta: 30, Body: 60}, + {Offset: 20, Meta: 30, Body: 40}, + }, + }, + } { + t.Run("", func(t *testing.T) { + o := new(bytes.Buffer) + + err := writeFileFooter(tc.schema, tc.dicts, tc.recs, o) + if err != nil { + t.Fatal(err) + } + + footer := flatbuf.GetRootAsFooter(o.Bytes(), 0) + + if got, want := MetadataVersion(footer.Version()), currentMetadataVersion; got != want { + t.Errorf("invalid metadata version: got=%[1]d %#[1]x, want=%[2]d %#[2]x", int16(got), int16(want)) + } + + schema, err := schemaFromFB(footer.Schema(nil), nil) + if err != nil { + t.Fatal(err) + } + + if !schema.Equal(tc.schema) { + t.Fatalf("schema r/w error:\ngot= %v\nwant=%v", schema, tc.schema) + } + + if got, want := footer.DictionariesLength(), len(tc.dicts); got != want { + t.Fatalf("dicts len differ: got=%d, want=%d", got, want) + } + + for i, dict := range tc.dicts { + var blk flatbuf.Block + if !footer.Dictionaries(&blk, i) { + t.Fatalf("could not get dictionary %d", i) + } + got := fileBlock{Offset: blk.Offset(), Meta: blk.MetaDataLength(), Body: blk.BodyLength()} + want := dict + if got != want { + t.Errorf("dict[%d] differ:\ngot= %v\nwant=%v", i, got, want) + } + } + + if got, want := footer.RecordBatchesLength(), len(tc.recs); got != want { + t.Fatalf("recs len differ: got=%d, want=%d", got, want) + } + + for i, rec := range tc.recs { + var blk flatbuf.Block + if !footer.RecordBatches(&blk, i) { + t.Fatalf("could not get record %d", i) + } + got := fileBlock{Offset: blk.Offset(), Meta: blk.MetaDataLength(), Body: blk.BodyLength()} + want := rec + if got != want { + t.Errorf("record[%d] differ:\ngot= %v\nwant=%v", i, got, want) + } + } + }) + } +} diff --git a/go/arrow/ipc/reader.go b/go/arrow/ipc/reader.go index dd98f7c2a71..919c7a5d2b0 100644 --- a/go/arrow/ipc/reader.go +++ b/go/arrow/ipc/reader.go @@ -62,7 +62,7 @@ func NewReader(r io.Reader, opts ...Option) (*Reader, 
error) { mem: cfg.alloc, } - err := rr.readSchema() + err := rr.readSchema(cfg.schema) if err != nil { return nil, errors.Wrap(err, "arrow/ipc: could not read schema from stream") } @@ -76,7 +76,7 @@ func (r *Reader) Err() error { return r.err } func (r *Reader) Schema() *arrow.Schema { return r.schema } -func (r *Reader) readSchema() error { +func (r *Reader) readSchema(schema *arrow.Schema) error { msg, err := r.r.Message() if err != nil { return errors.Wrap(err, "arrow/ipc: could not read message schema") @@ -106,6 +106,11 @@ func (r *Reader) readSchema() error { return errors.Wrap(err, "arrow/ipc: could not decode schema from message schema") } + // check the provided schema match the one read from stream. + if schema != nil && !schema.Equal(r.schema) { + return errInconsistentSchema + } + return nil } diff --git a/go/arrow/ipc/stream_test.go b/go/arrow/ipc/stream_test.go new file mode 100644 index 00000000000..64ea6cca296 --- /dev/null +++ b/go/arrow/ipc/stream_test.go @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ipc_test + +import ( + "io" + "io/ioutil" + "os" + "testing" + + "github.com/apache/arrow/go/arrow/internal/arrdata" + "github.com/apache/arrow/go/arrow/ipc" + "github.com/apache/arrow/go/arrow/memory" +) + +func TestStream(t *testing.T) { + for name, recs := range arrdata.Records { + t.Run(name, func(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + f, err := ioutil.TempFile("", "arrow-ipc-") + if err != nil { + t.Fatal(err) + } + defer f.Close() + defer os.Remove(f.Name()) + + { + w := ipc.NewWriter(f, ipc.WithSchema(recs[0].Schema()), ipc.WithAllocator(mem)) + defer w.Close() + + for i, rec := range recs { + err = w.Write(rec) + if err != nil { + t.Fatalf("could not write record[%d]: %v", i, err) + } + } + + err = w.Close() + if err != nil { + t.Fatal(err) + } + } + + err = f.Sync() + if err != nil { + t.Fatalf("could not sync data to disk: %v", err) + } + + _, err = f.Seek(0, io.SeekStart) + if err != nil { + t.Fatalf("could not seek to start: %v", err) + } + + { + r, err := ipc.NewReader(f, ipc.WithSchema(recs[0].Schema()), ipc.WithAllocator(mem)) + if err != nil { + t.Fatal(err) + } + defer r.Release() + + n := 0 + for r.Next() { + rec := r.Record() + if !cmpRecs(rec, recs[n]) { + t.Fatalf("records[%d] differ", n) + } + n++ + } + + if len(recs) != n { + t.Fatalf("invalid number of records. got=%d, want=%d", n, len(recs)) + } + } + }) + } +} diff --git a/go/arrow/ipc/writer.go b/go/arrow/ipc/writer.go new file mode 100644 index 00000000000..e1aff5796d7 --- /dev/null +++ b/go/arrow/ipc/writer.go @@ -0,0 +1,385 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ipc // import "github.com/apache/arrow/go/arrow/ipc" + +import ( + "fmt" + "io" + "math" + + "github.com/apache/arrow/go/arrow" + "github.com/apache/arrow/go/arrow/array" + "github.com/apache/arrow/go/arrow/internal/bitutil" + "github.com/apache/arrow/go/arrow/memory" + "github.com/pkg/errors" +) + +type swriter struct { + w io.Writer + pos int64 +} + +func (w *swriter) start() error { return nil } +func (w *swriter) Close() error { + _, err := w.Write(kEOS[:]) + return err +} + +func (w *swriter) write(p payload) error { + _, err := writeIPCPayload(w, p) + if err != nil { + return err + } + return nil +} + +func (w *swriter) Write(p []byte) (int, error) { + n, err := w.w.Write(p) + w.pos += int64(n) + return n, err +} + +// Writer is an Arrow stream writer. +type Writer struct { + w io.Writer + + mem memory.Allocator + pw payloadWriter + + started bool + schema *arrow.Schema +} + +// NewWriter returns a writer that writes records to the provided output stream. +func NewWriter(w io.Writer, opts ...Option) *Writer { + cfg := newConfig(opts...) 
+ return &Writer{ + w: w, + mem: cfg.alloc, + pw: &swriter{w: w}, + schema: cfg.schema, + } +} + +func (w *Writer) Close() error { + if w.pw == nil { + return nil + } + + err := w.pw.Close() + if err != nil { + return errors.Wrap(err, "arrow/ipc: could not close payload writer") + } + w.pw = nil + + return nil +} + +func (w *Writer) Write(rec array.Record) error { + if !w.started { + err := w.start() + if err != nil { + return err + } + } + + schema := rec.Schema() + if schema == nil || !schema.Equal(w.schema) { + return errInconsistentSchema + } + + const allow64b = true + var ( + data = payload{msg: MessageRecordBatch} + enc = newRecordEncoder(w.mem, 0, kMaxNestingDepth, allow64b) + ) + defer data.Release() + + if err := enc.Encode(&data, rec); err != nil { + return errors.Wrap(err, "arrow/ipc: could not encode record to payload") + } + + return w.pw.write(data) +} + +func (w *Writer) start() error { + w.started = true + + // write out schema payloads + ps := payloadsFromSchema(w.schema, w.mem, nil) + defer ps.Release() + + for _, data := range ps { + err := w.pw.write(data) + if err != nil { + return err + } + } + + return nil +} + +type recordEncoder struct { + mem memory.Allocator + + fields []fieldMetadata + meta []bufferMetadata + + depth int64 + start int64 + allow64b bool +} + +func newRecordEncoder(mem memory.Allocator, startOffset, maxDepth int64, allow64b bool) *recordEncoder { + return &recordEncoder{ + mem: mem, + start: startOffset, + depth: maxDepth, + allow64b: allow64b, + } +} + +func (w *recordEncoder) Encode(p *payload, rec array.Record) error { + + // perform depth-first traversal of the row-batch + for i, col := range rec.Columns() { + err := w.visit(p, col) + if err != nil { + return errors.Wrapf(err, "arrow/ipc: could not encode column %d (%q)", i, rec.ColumnName(i)) + } + } + + // position for the start of a buffer relative to the passed frame of reference. + // may be 0 or some other position in an address space. 
+ offset := w.start + w.meta = make([]bufferMetadata, len(p.body)) + + // construct the metadata for the record batch header + for i, buf := range p.body { + var ( + size int64 + padding int64 + ) + // the buffer might be null if we are handling zero row lengths. + if buf != nil { + size = int64(buf.Len()) + padding = bitutil.CeilByte64(size) - size + } + w.meta[i] = bufferMetadata{ + Offset: offset, + Len: size + padding, + } + offset += size + padding + } + + p.size = offset - w.start + if !bitutil.IsMultipleOf8(p.size) { + panic("not aligned") + } + + return w.encodeMetadata(p, rec.NumRows()) +} + +func (w *recordEncoder) visit(p *payload, arr array.Interface) error { + if w.depth <= 0 { + return errMaxRecursion + } + + if !w.allow64b && arr.Len() > math.MaxInt32 { + return errBigArray + } + + // add all common elements + w.fields = append(w.fields, fieldMetadata{ + Len: int64(arr.Len()), + Nulls: int64(arr.NullN()), + Offset: 0, + }) + + switch arr.NullN() { + case 0: + p.body = append(p.body, nil) + default: + data := arr.Data() + bitmap := newTruncatedBitmap(w.mem, int64(data.Offset()), int64(data.Len()), data.Buffers()[0]) + p.body = append(p.body, bitmap) + } + + switch dtype := arr.DataType().(type) { + case *arrow.NullType: + p.body = append(p.body, nil) + + case *arrow.BooleanType: + data := arr.Data() + p.body = append(p.body, newTruncatedBitmap(w.mem, int64(data.Offset()), int64(data.Len()), data.Buffers()[1])) + + case arrow.FixedWidthDataType: + data := arr.Data() + values := data.Buffers()[1] + typeWidth := dtype.BitWidth() / 8 + minLength := paddedLength(int64(arr.Len())*int64(typeWidth), kArrowAlignment) + + switch { + case needTruncate(int64(data.Offset()), values, minLength): + panic("not implemented") // FIXME(sbinet) writer.cc:212 + default: + values.Retain() + } + p.body = append(p.body, values) + + case *arrow.BinaryType: + arr := arr.(*array.Binary) + voffsets, err := w.getZeroBasedValueOffsets(arr) + if err != nil { + return 
errors.Wrapf(err, "could not retrieve zero-based value offsets from %T", arr) + } + data := arr.Data() + values := data.Buffers()[2] + + var totalDataBytes int64 + if voffsets != nil { + totalDataBytes = int64(len(arr.ValueBytes())) + } + + switch { + case needTruncate(int64(data.Offset()), values, totalDataBytes): + panic("not implemented") // FIXME(sbinet) writer.cc:264 + default: + values.Retain() + } + p.body = append(p.body, voffsets) + p.body = append(p.body, values) + + case *arrow.StringType: + arr := arr.(*array.String) + voffsets, err := w.getZeroBasedValueOffsets(arr) + if err != nil { + return errors.Wrapf(err, "could not retrieve zero-based value offsets from %T", arr) + } + data := arr.Data() + values := data.Buffers()[2] + + var totalDataBytes int64 + if voffsets != nil { + totalDataBytes = int64(arr.ValueOffset(arr.Len()) - arr.ValueOffset(0)) + } + + switch { + case needTruncate(int64(data.Offset()), values, totalDataBytes): + panic("not implemented") // FIXME(sbinet) writer.cc:264 + default: + values.Retain() + } + p.body = append(p.body, voffsets) + p.body = append(p.body, values) + + case *arrow.StructType: + w.depth-- + arr := arr.(*array.Struct) + for i := 0; i < arr.NumField(); i++ { + err := w.visit(p, arr.Field(i)) + if err != nil { + return errors.Wrapf(err, "could not visit field %d of struct-array", i) + } + } + w.depth++ + + case *arrow.ListType: + arr := arr.(*array.List) + voffsets, err := w.getZeroBasedValueOffsets(arr) + if err != nil { + return errors.Wrapf(err, "could not retrieve zero-based value offsets for array %T", arr) + } + p.body = append(p.body, voffsets) + + w.depth-- + var ( + values = arr.ListValues() + mustRelease = false + values_offset int64 + values_length int64 + ) + defer func() { + if mustRelease { + values.Release() + } + }() + + if voffsets != nil { + values_offset = int64(arr.Offsets()[0]) + values_length = int64(arr.Offsets()[arr.Len()]) - values_offset + } + + if len(arr.Offsets()) != 0 || values_length < 
int64(values.Len()) { + // must also slice the values + values = array.NewSlice(values, values_offset, values_length) + mustRelease = true + } + err = w.visit(p, values) + + if err != nil { + return errors.Wrapf(err, "could not visit list element for array %T", arr) + } + w.depth++ + + default: + panic(errors.Errorf("arrow/ipc: unknown array %T (dtype=%T)", arr, dtype)) + } + + return nil +} + +func (w *recordEncoder) getZeroBasedValueOffsets(arr array.Interface) (*memory.Buffer, error) { + data := arr.Data() + voffsets := data.Buffers()[1] + if data.Offset() != 0 { + // FIXME(sbinet): writer.cc:231 + panic(fmt.Errorf("not implemented offset=%d", data.Offset())) + } + + voffsets.Retain() + return voffsets, nil +} + +func (w *recordEncoder) encodeMetadata(p *payload, nrows int64) error { + p.meta = writeRecordMessage(w.mem, nrows, p.size, w.fields, w.meta) + return nil +} + +func newTruncatedBitmap(mem memory.Allocator, offset, length int64, input *memory.Buffer) *memory.Buffer { + if input == nil { + return nil + } + + minLength := paddedLength(bitutil.BytesForBits(length), kArrowAlignment) + switch { + case offset != 0 || minLength < int64(input.Len()): + // with a sliced array / non-zero offset, we must copy the bitmap + panic("not implemented") // FIXME(sbinet): writer.cc:75 + default: + input.Retain() + return input + } +} + +func needTruncate(offset int64, buf *memory.Buffer, minLength int64) bool { + if buf == nil { + return false + } + return offset != 0 || minLength < int64(buf.Len()) +}