diff --git a/go/parquet/pqarrow/file_writer_test.go b/go/parquet/pqarrow/file_writer_test.go index 5b807389a3e..b2d111cb0f9 100644 --- a/go/parquet/pqarrow/file_writer_test.go +++ b/go/parquet/pqarrow/file_writer_test.go @@ -26,7 +26,10 @@ import ( "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/parquet" + "github.com/apache/arrow/go/v18/parquet/file" + "github.com/apache/arrow/go/v18/parquet/internal/encoding" "github.com/apache/arrow/go/v18/parquet/pqarrow" + pqschema "github.com/apache/arrow/go/v18/parquet/schema" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -133,3 +136,56 @@ func TestFileWriterBuffered(t *testing.T) { require.NoError(t, writer.Close()) assert.Equal(t, 4, writer.NumRows()) } + +func TestFileWriterWithLogicalTypes(t *testing.T) { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "string", Nullable: true, Type: arrow.BinaryTypes.String}, + {Name: "json", Nullable: true, Type: arrow.BinaryTypes.String}, + }, nil) + + data := `[ + { "string": "{\"key\":\"value\"}", "json": "{\"key\":\"value\"}" }, + { "string": null, "json": null } + ]` + + logicalTypes := []pqschema.LogicalType{ + nil, + pqschema.JSONLogicalType{}, + } + + alloc := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer alloc.AssertSize(t, 0) + + record, _, err := array.RecordFromJSON(alloc, schema, strings.NewReader(data)) + require.NoError(t, err) + defer record.Release() + + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + sink := encoding.NewBufferWriter(0, mem) + defer sink.Release() + + writer, err := pqarrow.NewFileWriter( + schema, + sink, + parquet.NewWriterProperties( + parquet.WithAllocator(alloc), + ), + pqarrow.NewArrowWriterProperties( + pqarrow.WithAllocator(alloc), + pqarrow.WithCustomLogicalTypes(logicalTypes), + ), + ) + require.NoError(t, err) + + require.NoError(t, writer.Write(record)) + require.NoError(t, 
writer.Close()) + + reader, err := file.NewParquetReader(bytes.NewReader(sink.Bytes())) + require.NoError(t, err) + assert.EqualValues(t, 2, reader.NumRows()) + + parquetSchema := reader.MetaData().Schema + assert.EqualValues(t, "String", parquetSchema.Column(0).LogicalType().String()) + assert.EqualValues(t, "JSON", parquetSchema.Column(1).LogicalType().String()) +} diff --git a/go/parquet/pqarrow/properties.go b/go/parquet/pqarrow/properties.go index 25a299c86f5..5c105070a02 100755 --- a/go/parquet/pqarrow/properties.go +++ b/go/parquet/pqarrow/properties.go @@ -22,6 +22,7 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/parquet/internal/encoding" + "github.com/apache/arrow/go/v18/parquet/schema" ) // ArrowWriterProperties are used to determine how to manipulate the arrow data @@ -33,7 +34,9 @@ type ArrowWriterProperties struct { coerceTimestampUnit arrow.TimeUnit allowTruncatedTimestamps bool storeSchema bool - noMapLogicalType bool + noMapLogicalType bool // if true, do not set Logical type for arrow.MAP + customLogicalTypes []schema.LogicalType // specify to customize the Logical types of the output parquet schema + // compliantNestedTypes bool } @@ -119,6 +122,13 @@ func WithNoMapLogicalType() WriterOption { } } +// WithCustomLogicalTypes overrides the logical types of the top-level output columns by index; a nil entry keeps the inferred logical type for that column. +func WithCustomLogicalTypes(logicalTypes []schema.LogicalType) WriterOption { + return func(c *config) { + c.props.customLogicalTypes = logicalTypes + } +} + // func WithCompliantNestedTypes(enabled bool) WriterOption { // return func(c *config) { // c.props.compliantNestedTypes = enabled diff --git a/go/parquet/pqarrow/schema.go index ce5cc6f9050..b3aa5f0e2e2 100644 --- a/go/parquet/pqarrow/schema.go +++ b/go/parquet/pqarrow/schema.go @@ -239,7 +239,7 @@ func structToNode(typ *arrow.StructType, name string, nullable bool, props *parq children := make(schema.FieldList, 0, typ.NumFields()) for _, f := range 
typ.Fields() { - n, err := fieldToNode(f.Name, f, props, arrprops) + n, err := fieldToNode(f.Name, f, props, arrprops, nil) if err != nil { return nil, err } @@ -249,7 +249,7 @@ func structToNode(typ *arrow.StructType, name string, nullable bool, props *parq return schema.NewGroupNode(name, repFromNullable(nullable), children, -1) } -func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { +func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties, customLogicalType schema.LogicalType) (schema.Node, error) { var ( logicalType schema.LogicalType = schema.NoLogicalType{} typ parquet.Type @@ -358,7 +358,7 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties elem = field.Type.(*arrow.FixedSizeListType).Elem() } - child, err := fieldToNode(name, arrow.Field{Name: name, Type: elem, Nullable: true}, props, arrprops) + child, err := fieldToNode(name, arrow.Field{Name: name, Type: elem, Nullable: true}, props, arrprops, nil) if err != nil { return nil, err } @@ -368,7 +368,7 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties // parquet has no dictionary type, dictionary is encoding, not schema level dictType := field.Type.(*arrow.DictionaryType) return fieldToNode(name, arrow.Field{Name: name, Type: dictType.ValueType, Nullable: field.Nullable, Metadata: field.Metadata}, - props, arrprops) + props, arrprops, customLogicalType) case arrow.EXTENSION: return fieldToNode(name, arrow.Field{ Name: name, @@ -378,15 +378,15 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties ipc.ExtensionTypeKeyName: field.Type.(arrow.ExtensionType).ExtensionName(), ipc.ExtensionMetadataKeyName: field.Type.(arrow.ExtensionType).Serialize(), }), - }, props, arrprops) + }, props, arrprops, customLogicalType) case arrow.MAP: mapType := field.Type.(*arrow.MapType) - keyNode, 
err := fieldToNode("key", mapType.KeyField(), props, arrprops) + keyNode, err := fieldToNode("key", mapType.KeyField(), props, arrprops, nil) if err != nil { return nil, err } - valueNode, err := fieldToNode("value", mapType.ItemField(), props, arrprops) + valueNode, err := fieldToNode("value", mapType.ItemField(), props, arrprops, nil) if err != nil { return nil, err } @@ -406,6 +406,10 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties return nil, fmt.Errorf("%w: support for %s", arrow.ErrNotImplemented, field.Type.ID()) } + if customLogicalType != nil { + logicalType = customLogicalType + } + return schema.NewPrimitiveNodeLogical(name, repType, logicalType, typ, length, fieldIDFromMeta(field.Metadata)) } @@ -441,8 +445,12 @@ func ToParquet(sc *arrow.Schema, props *parquet.WriterProperties, arrprops Arrow } nodes := make(schema.FieldList, 0, sc.NumFields()) - for _, f := range sc.Fields() { - n, err := fieldToNode(f.Name, f, props, arrprops) + for i, f := range sc.Fields() { + var logicalType schema.LogicalType + if arrprops.customLogicalTypes != nil && i < len(arrprops.customLogicalTypes) { + logicalType = arrprops.customLogicalTypes[i] + } + n, err := fieldToNode(f.Name, f, props, arrprops, logicalType) if err != nil { return nil, err }