diff --git a/go/arrow/array/dictionary.go b/go/arrow/array/dictionary.go index d0a1c4dc97e..89b002862b2 100644 --- a/go/arrow/array/dictionary.go +++ b/go/arrow/array/dictionary.go @@ -421,6 +421,9 @@ type dictionaryBuilder struct { deltaOffset int memoTable hashing.MemoTable idxBuilder IndexBuilder + + initialDictOffset int + initialDict arrow.Array } // NewDictionaryBuilderWithDict initializes a dictionary builder and inserts the values from `init` as the first @@ -447,6 +450,11 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType dt: dt, } + if init != nil { + bldr.initialDictOffset = init.Len() + bldr.initialDict = init + } + switch dt.ValueType.ID() { case arrow.NULL: ret := &NullDictionaryBuilder{bldr} @@ -454,189 +462,74 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType return ret case arrow.UINT8: ret := &Uint8DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Uint8)); err != nil { - panic(err) - } - } return ret case arrow.INT8: ret := &Int8DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Int8)); err != nil { - panic(err) - } - } return ret case arrow.UINT16: ret := &Uint16DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Uint16)); err != nil { - panic(err) - } - } return ret case arrow.INT16: ret := &Int16DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Int16)); err != nil { - panic(err) - } - } return ret case arrow.UINT32: ret := &Uint32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Uint32)); err != nil { - panic(err) - } - } return ret case arrow.INT32: ret := &Int32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Int32)); err != nil { - panic(err) - } - } return ret case arrow.UINT64: ret := &Uint64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Uint64)); err != nil { - panic(err) - } - } return ret case arrow.INT64: ret := &Int64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Int64)); err != nil { - panic(err) - } - } return ret case arrow.FLOAT16: ret := &Float16DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Float16)); err != nil { - panic(err) - } - } return ret case arrow.FLOAT32: ret := &Float32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Float32)); err != nil { - panic(err) - } - } return ret case arrow.FLOAT64: ret := &Float64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Float64)); err != nil { - panic(err) - } - } return ret case arrow.STRING: ret := &BinaryDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertStringDictValues(init.(*String)); err != nil { - panic(err) - } - } return ret case arrow.BINARY: ret := &BinaryDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Binary)); err != nil { - panic(err) - } - } return ret case arrow.FIXED_SIZE_BINARY: ret := &FixedSizeBinaryDictionaryBuilder{ bldr, dt.ValueType.(*arrow.FixedSizeBinaryType).ByteWidth, } - if init != nil { - if err = ret.InsertDictValues(init.(*FixedSizeBinary)); err != nil { - panic(err) - } - } return ret case arrow.DATE32: ret := &Date32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Date32)); err != nil { - panic(err) - } - } return ret case arrow.DATE64: ret := &Date64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Date64)); err != nil { - panic(err) - } - } return ret case arrow.TIMESTAMP: ret := &TimestampDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Timestamp)); err != nil { - panic(err) - } - } return ret case arrow.TIME32: ret := &Time32DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Time32)); err != nil { - panic(err) - } - } return ret case arrow.TIME64: ret := &Time64DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Time64)); err != nil { - panic(err) - } - } return ret case arrow.INTERVAL_MONTHS: ret := &MonthIntervalDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*MonthInterval)); err != nil { - panic(err) - } - } return ret case arrow.INTERVAL_DAY_TIME: ret := &DayTimeDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*DayTimeInterval)); err != nil { - panic(err) - } - } return ret case arrow.DECIMAL128: ret := &Decimal128DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Decimal128)); err != nil { - panic(err) - } - } return ret case arrow.DECIMAL256: ret := &Decimal256DictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Decimal256)); err != nil { - panic(err) - } - } return ret case arrow.LIST: case arrow.STRUCT: @@ -648,22 +541,12 @@ func NewDictionaryBuilderWithDict(mem memory.Allocator, dt *arrow.DictionaryType case arrow.FIXED_SIZE_LIST: case arrow.DURATION: ret := &DurationDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*Duration)); err != nil { - panic(err) - } - } return ret case arrow.LARGE_STRING: case arrow.LARGE_BINARY: case arrow.LARGE_LIST: case arrow.INTERVAL_MONTH_DAY_NANO: ret := &MonthDayNanoDictionaryBuilder{bldr} - if init != nil { - if err = ret.InsertDictValues(init.(*MonthDayNanoInterval)); err != nil { - panic(err) - } - } return ret } @@ -804,6 +687,16 @@ func (b *dictionaryBuilder) NewDictionaryArray() *Dictionary { a.refCount = 1 indices := b.newData() + + if b.initialDict != nil { + dict := MakeFromData(indices.dictionary) + finalDict, err := Concatenate([]arrow.Array{b.initialDict, dict}, b.mem) + if err != nil { + panic(err) + } + indices.dictionary = finalDict.Data().(*Data) + } + a.setData(indices) indices.Release() return a @@ -849,14 +742,14 @@ func (b *dictionaryBuilder) insertDictBytes(val []byte) error { func (b *dictionaryBuilder) appendValue(val interface{}) error { idx, _, err := b.memoTable.GetOrInsert(val) - b.idxBuilder.Append(idx) + b.idxBuilder.Append(idx + b.initialDictOffset) b.length += 1 return err } func (b *dictionaryBuilder) appendBytes(val []byte) error { idx, _, err := b.memoTable.GetOrInsertBytes(val) - b.idxBuilder.Append(idx) + b.idxBuilder.Append(idx + b.initialDictOffset) b.length += 1 return err } diff --git a/go/arrow/array/dictionary_test.go b/go/arrow/array/dictionary_test.go index d0878fa3b03..75a1d20124c 100644 --- a/go/arrow/array/dictionary_test.go +++ b/go/arrow/array/dictionary_test.go @@ -1914,3 +1914,47 @@ func BenchmarkBinaryDictionaryBuilder(b *testing.B) { assert.NoError(b, builder.AppendString(randString())) } } + +func TestBinaryDictionaryBuilderWithInitDict(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + dictType := &arrow.DictionaryType{IndexType: &arrow.Int32Type{}, ValueType: arrow.BinaryTypes.String} + bldr := array.NewDictionaryBuilder(mem, dictType).(*array.BinaryDictionaryBuilder) + defer bldr.Release() + + bldr.AppendString("drop-1") + bldr.AppendString("keep-2") + bldr.AppendString("keep-3") + + arr := bldr.NewArray().(*array.Dictionary) + defer arr.Release() + indices := arr.Indices().(*array.Int32) + dict := arr.Dictionary().(*array.String) + + keepIndices := map[int]struct{}{} + for i := 0; i < dict.Len(); i++ { + if strings.HasPrefix(dict.Value(i), "keep-") { + keepIndices[i] = struct{}{} + } + } + + bldr2 := array.NewDictionaryBuilderWithDict(mem, dictType, dict).(*array.BinaryDictionaryBuilder) + indices2 := bldr2.IndexBuilder().Builder.(*array.Int32Builder) + defer bldr2.Release() + + for i := 0; i < arr.Len(); i++ { + if _, ok := keepIndices[i]; ok { + indices2.Append(indices.Value(i)) + } + } + + arr2 := bldr2.NewArray().(*array.Dictionary) + defer arr2.Release() + dict2 := arr2.Dictionary().(*array.String) + + assert.Equal(t, 2, arr2.Len()) + assert.Equal(t, 3, dict2.Len()) + assert.Equal(t, "keep-2", dict2.Value(arr2.GetValueIndex(0))) + assert.Equal(t, "keep-3", dict2.Value(arr2.GetValueIndex(1))) +}