From 549f69f0f5c103d2700f974404a3e6dc290c70bd Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 28 May 2025 07:38:58 -0700 Subject: [PATCH] GH-46627: [Swift] Support Decimal128 --- swift/Arrow/Sources/Arrow/ArrowArray.swift | 20 +++++++++++ .../Sources/Arrow/ArrowArrayBuilder.swift | 28 +++++++++++++-- .../Sources/Arrow/ArrowBufferBuilder.swift | 4 ++- swift/Arrow/Sources/Arrow/ArrowDecoder.swift | 15 ++++++-- .../Sources/Arrow/ArrowReaderHelper.swift | 24 ++++++++++++- swift/Arrow/Sources/Arrow/ArrowType.swift | 34 +++++++++++++++++-- swift/Arrow/Sources/Arrow/ProtoUtil.swift | 9 +++++ 7 files changed, 124 insertions(+), 10 deletions(-) diff --git a/swift/Arrow/Sources/Arrow/ArrowArray.swift b/swift/Arrow/Sources/Arrow/ArrowArray.swift index 4fc1b8b9fc7..9f7e15e5b60 100644 --- a/swift/Arrow/Sources/Arrow/ArrowArray.swift +++ b/swift/Arrow/Sources/Arrow/ArrowArray.swift @@ -97,6 +97,8 @@ public class ArrowArrayHolderImpl: ArrowArrayHolder { return try ArrowArrayHolderImpl(FixedArray(with)) case .float: return try ArrowArrayHolderImpl(FixedArray(with)) + case .decimal128: + return try ArrowArrayHolderImpl(FixedArray(with)) case .date32: return try ArrowArrayHolderImpl(Date32Array(with)) case .date64: @@ -233,6 +235,24 @@ public class Date64Array: ArrowArray { public class Time32Array: FixedArray {} public class Time64Array: FixedArray {} +public class Decimal128Array: FixedArray { + public override subscript(_ index: UInt) -> Decimal? { + if self.arrowData.isNull(index) { + return nil + } + let scale: Int32 = switch self.arrowData.type.id { + case .decimal128(_, let scale): + scale + default: + 18 + } + let byteOffset = self.arrowData.stride * Int(index) + let value = self.arrowData.buffers[1].rawPointer.advanced(by: byteOffset).load( + as: UInt64.self) + return Decimal(value) / pow(10, Int(scale)) + } +} + public class BinaryArray: ArrowArray { public struct Options { public var printAsHex = false diff --git a/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift b/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift index 005cad79dae..8d22a499375 100644 --- a/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift +++ b/swift/Arrow/Sources/Arrow/ArrowArrayBuilder.swift @@ -119,6 +119,12 @@ public class Time64ArrayBuilder: ArrowArrayBuilder, T } } +public class Decimal128ArrayBuilder: ArrowArrayBuilder, Decimal128Array> { + fileprivate convenience init(precision: Int32, scale: Int32) throws { + try self.init(ArrowTypeDecimal128(precision: precision, scale: scale)) + } +} + public class StructArrayBuilder: ArrowArrayBuilder { let builders: [any ArrowArrayHolderBuilder] let fields: [ArrowField] @@ -197,6 +203,8 @@ public class ArrowArrayBuilders { return try ArrowArrayBuilders.loadBoolArrayBuilder() } else if builderType == Date.self || builderType == Date?.self { return try ArrowArrayBuilders.loadDate64ArrayBuilder() + } else if builderType == Decimal.self || builderType == Decimal?.self { + return try ArrowArrayBuilders.loadDecimal128ArrayBuilder(38, 18) } else { throw ArrowError.invalid("Invalid type for builder: \(builderType)") } @@ -215,7 +223,8 @@ public class ArrowArrayBuilders { type == UInt8.self || type == UInt16.self || type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self || - type == Float.self || type == Date.self + type == Float.self || type == Date.self || + type == Decimal.self || type == Decimal?.self } public static func loadStructArrayBuilderForType(_ obj: T) throws -> StructArrayBuilder { @@ -279,12 +288,18 @@ public class ArrowArrayBuilders { throw ArrowError.invalid("Expected arrow type for \(arrowType.id) not found") } return try Time64ArrayBuilder(timeType.unit) + case .decimal128: + guard let decimalType = arrowType as? ArrowTypeDecimal128 else { + throw ArrowError.invalid("Expected ArrowTypeDecimal128 for decimal128 type") + } + return try Decimal128ArrayBuilder(precision: decimalType.precision, scale: decimalType.scale) default: throw ArrowError.unknownType("Builder not found for arrow type: \(arrowType.id)") } } - public static func loadNumberArrayBuilder() throws -> NumberArrayBuilder { + public static func loadNumberArrayBuilder( // swiftlint:disable:this cyclomatic_complexity + ) throws -> NumberArrayBuilder { let type = T.self if type == Int8.self { return try NumberArrayBuilder() @@ -306,6 +321,8 @@ public class ArrowArrayBuilders { return try NumberArrayBuilder() } else if type == Double.self { return try NumberArrayBuilder() + } else if type == Decimal.self { + return try NumberArrayBuilder() } else { throw ArrowError.unknownType("Type is invalid for NumberArrayBuilder") } @@ -338,4 +355,11 @@ public class ArrowArrayBuilders { public static func loadTime64ArrayBuilder(_ unit: ArrowTime64Unit) throws -> Time64ArrayBuilder { return try Time64ArrayBuilder(unit) } + + public static func loadDecimal128ArrayBuilder( + _ precision: Int32 = 38, + _ scale: Int32 = 18 + ) throws -> Decimal128ArrayBuilder { + return try Decimal128ArrayBuilder(precision: precision, scale: scale) + } } diff --git a/swift/Arrow/Sources/Arrow/ArrowBufferBuilder.swift b/swift/Arrow/Sources/Arrow/ArrowBufferBuilder.swift index 47f9c40354b..6ad74ba0a65 100644 --- a/swift/Arrow/Sources/Arrow/ArrowBufferBuilder.swift +++ b/swift/Arrow/Sources/Arrow/ArrowBufferBuilder.swift @@ -118,7 +118,7 @@ public class FixedBufferBuilder: ValuesBufferBuilder, ArrowBufferBuilder { return [nulls, values] } - fileprivate static func defaultValueForType() throws -> T { + fileprivate static func defaultValueForType() throws -> T { // swiftlint:disable:this cyclomatic_complexity let type = T.self if type == Int8.self { return Int8(0) as! T // swiftlint:disable:this force_cast @@ -140,6 +140,8 @@ public class FixedBufferBuilder: ValuesBufferBuilder, ArrowBufferBuilder { return Float(0) as! T // swiftlint:disable:this force_cast } else if type == Double.self { return Double(0) as! T // swiftlint:disable:this force_cast + } else if type == Decimal.self { + return Decimal(0) as! T // swiftlint:disable:this force_cast } throw ArrowError.unknownType("Unable to determine default value") diff --git a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift index 35dd4dcd1e8..78e563ee056 100644 --- a/swift/Arrow/Sources/Arrow/ArrowDecoder.swift +++ b/swift/Arrow/Sources/Arrow/ArrowDecoder.swift @@ -164,7 +164,8 @@ private struct ArrowUnkeyedDecoding: UnkeyedDecodingContainer { type == UInt8.self || type == UInt16.self || type == UInt32.self || type == UInt64.self || type == String.self || type == Double.self || - type == Float.self || type == Date.self { + type == Float.self || type == Date.self || + type == Decimal.self || type == Decimal?.self { defer {increment()} return try self.decoder.doDecode(self.currentIndex)! } else { @@ -263,8 +264,12 @@ private struct ArrowKeyedDecoding: KeyedDecodingContainerProtoco return try self.decoder.doDecode(key)! } + func decode(_ type: Decimal.Type, forKey key: Key) throws -> Decimal { + return try self.decoder.doDecode(key)! + } + func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { - if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self { + if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self || type == Decimal.self { return try self.decoder.doDecode(key)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") @@ -366,8 +371,12 @@ private struct ArrowSingleValueDecoding: SingleValueDecodingContainer { return try self.decoder.doDecode(self.decoder.singleRBCol)! } + func decode(_ type: Decimal.Type) throws -> Decimal { + return try self.decoder.doDecode(self.decoder.singleRBCol)! + } + func decode(_ type: T.Type) throws -> T where T: Decodable { - if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self { + if ArrowArrayBuilders.isValidBuilderType(type) || type == Date.self || type == Decimal.self { return try self.decoder.doDecode(self.decoder.singleRBCol)! } else { throw ArrowError.invalid("Type \(type) is currently not supported") diff --git a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift index 18cf41ad25a..35a72070a90 100644 --- a/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift +++ b/swift/Arrow/Sources/Arrow/ArrowReaderHelper.swift @@ -44,6 +44,20 @@ private func makeStringHolder(_ buffers: [ArrowBuffer], } } +private func makeDecimalHolder(_ field: ArrowField, + buffers: [ArrowBuffer], + nullCount: UInt +) -> Result { + do { + let arrowData = try ArrowData(field.type, buffers: buffers, nullCount: nullCount) + return .success(ArrowArrayHolderImpl(try Decimal128Array(arrowData))) + } catch let error as ArrowError { + return .failure(error) + } catch { + return .failure(.unknownError("\(error)")) + } +} + private func makeDateHolder(_ field: ArrowField, buffers: [ArrowBuffer], nullCount: UInt @@ -178,6 +192,8 @@ func makeArrayHolder( // swiftlint:disable:this cyclomatic_complexity return makeFixedHolder(Float.self, field: field, buffers: buffers, nullCount: nullCount) case .double: return makeFixedHolder(Double.self, field: field, buffers: buffers, nullCount: nullCount) + case .decimal128: + return makeDecimalHolder(field, buffers: buffers, nullCount: nullCount) case .string: return makeStringHolder(buffers, nullCount: nullCount) case .binary: @@ -203,7 +219,7 @@ func makeBuffer(_ buffer: org_apache_arrow_flatbuf_Buffer, fileData: Data, func isFixedPrimitive(_ type: org_apache_arrow_flatbuf_Type_) -> Bool { switch type { - case .int, .bool, .floatingpoint, .date, .time: + case .int, .bool, .floatingpoint, .date, .time, .decimal: return true default: return false @@ -243,6 +259,12 @@ func findArrowType( // swiftlint:disable:this cyclomatic_complexity function_bod default: return ArrowType(ArrowType.ArrowUnknown) } + case .decimal: + let dataType = field.type(type: org_apache_arrow_flatbuf_Decimal.self)! + if dataType.bitWidth == 128 { + return ArrowType(ArrowType.ArrowDecimal128) + } + return ArrowType(ArrowType.ArrowUnknown) case .utf8: return ArrowType(ArrowType.ArrowString) case .binary: diff --git a/swift/Arrow/Sources/Arrow/ArrowType.swift b/swift/Arrow/Sources/Arrow/ArrowType.swift index b44f8591859..14a300d2451 100644 --- a/swift/Arrow/Sources/Arrow/ArrowType.swift +++ b/swift/Arrow/Sources/Arrow/ArrowType.swift @@ -37,13 +37,13 @@ public enum ArrowError: Error { case invalid(String) } -public enum ArrowTypeId { +public enum ArrowTypeId: Sendable, Equatable { case binary case boolean case date32 case date64 case dateType - case decimal128 + case decimal128(_ precision: Int32, _ scale: Int32) case decimal256 case dictionary case double @@ -122,6 +122,23 @@ public class ArrowTypeTime64: ArrowType { } } +public class ArrowTypeDecimal128: ArrowType { + let precision: Int32 + let scale: Int32 + + public init(precision: Int32, scale: Int32) { + self.precision = precision + self.scale = scale + super.init(ArrowType.ArrowDecimal128) + } + + public override var cDataFormatId: String { + get throws { + return "d:\(precision),\(scale)" + } + } +} + public class ArrowNestedType: ArrowType { let fields: [ArrowField] public init(_ info: ArrowType.Info, fields: [ArrowField]) { @@ -142,6 +159,7 @@ public class ArrowType { public static let ArrowUInt64 = Info.primitiveInfo(ArrowTypeId.uint64) public static let ArrowFloat = Info.primitiveInfo(ArrowTypeId.float) public static let ArrowDouble = Info.primitiveInfo(ArrowTypeId.double) + public static let ArrowDecimal128 = Info.primitiveInfo(ArrowTypeId.decimal128(38, 18)) public static let ArrowUnknown = Info.primitiveInfo(ArrowTypeId.unknown) public static let ArrowString = Info.variableInfo(ArrowTypeId.string) public static let ArrowBool = Info.primitiveInfo(ArrowTypeId.boolean) @@ -206,12 +224,16 @@ public class ArrowType { return ArrowType.ArrowFloat } else if type == Double.self { return ArrowType.ArrowDouble + } else if type == Decimal.self { + return ArrowType.ArrowDecimal128 } else { return ArrowType.ArrowUnknown } } - public static func infoForNumericType(_ type: T.Type) -> ArrowType.Info { + public static func infoForNumericType( // swiftlint:disable:this cyclomatic_complexity + _ type: T.Type + ) -> ArrowType.Info { if type == Int8.self { return ArrowType.ArrowInt8 } else if type == Int16.self { @@ -232,6 +254,8 @@ public class ArrowType { return ArrowType.ArrowFloat } else if type == Double.self { return ArrowType.ArrowDouble + } else if type == Decimal.self { + return ArrowType.ArrowDecimal128 } else { return ArrowType.ArrowUnknown } @@ -260,6 +284,8 @@ public class ArrowType { return MemoryLayout.stride case .double: return MemoryLayout.stride + case .decimal128: + return 16 // Decimal 128 (= 16 * 8) bits case .boolean: return MemoryLayout.stride case .date32: @@ -304,6 +330,8 @@ public class ArrowType { return "f" case ArrowTypeId.double: return "g" + case ArrowTypeId.decimal128(let precision, let scale): + return "d:\(precision),\(scale)" case ArrowTypeId.boolean: return "b" case ArrowTypeId.date32: diff --git a/swift/Arrow/Sources/Arrow/ProtoUtil.swift b/swift/Arrow/Sources/Arrow/ProtoUtil.swift index 88cfb0bfcde..43253218e16 100644 --- a/swift/Arrow/Sources/Arrow/ProtoUtil.swift +++ b/swift/Arrow/Sources/Arrow/ProtoUtil.swift @@ -44,6 +44,15 @@ func fromProto( // swiftlint:disable:this cyclomatic_complexity function_body_le } else if floatType.precision == .double { arrowType = ArrowType(ArrowType.ArrowDouble) } + case .decimal: + let decimalType = field.type(type: org_apache_arrow_flatbuf_Decimal.self)! + if decimalType.bitWidth == 128 && decimalType.precision <= 38 { + let arrowDecimal128 = ArrowTypeId.decimal128(decimalType.precision, decimalType.scale) + arrowType = ArrowType(ArrowType.Info.primitiveInfo(arrowDecimal128)) + } else { + // Unsupport yet + arrowType = ArrowType(ArrowType.ArrowUnknown) + } case .utf8: arrowType = ArrowType(ArrowType.ArrowString) case .binary: