diff --git a/README.md b/README.md index c54846f..d6f53e9 100644 --- a/README.md +++ b/README.md @@ -488,16 +488,27 @@ let response = try await session.respond { } ``` -You can tune MLX KV-cache behavior per request with model-specific options: +You can tune MLX request behavior per call with model-specific options, +including KV-cache settings and optional media preprocessing: ```swift var options = GenerationOptions(temperature: 0.7) -options[custom: MLXLanguageModel.self] = .init( - maxKVSize: 4096, - kvBits: 4, - kvGroupSize: 64, - quantizedKVStart: 128 +var mlxOptions = MLXLanguageModel.CustomGenerationOptions.default +mlxOptions.kvCache = .init( + maxSize: 4096, + bits: 4, + groupSize: 64, + quantizedStart: 128 ) +// Apply a deterministic preprocessing step for image inputs. +mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512)) +// Inject extra template context consumed by model-specific chat templates. +mlxOptions.additionalContext = [ + "user_name": .string("Alice"), + "turn_count": .int(3), + "verbose": .bool(true), +] +options[custom: MLXLanguageModel.self] = mlxOptions let response = try await session.respond( to: "Summarize this transcript", @@ -505,6 +516,15 @@ let response = try await session.respond( ) ``` +You can specify `userInputProcessing` to enforce a consistent image +preprocessing step +(for example, fixed dimensions for predictable latency, memory usage, and vision behavior). +By default, images are passed through without an explicit resize override +(`resize: nil`), so MLX applies its default media processing behavior. + +You can also set `additionalContext` to provide extra JSON template variables +for model-specific chat templates. 
+ GPU cache behavior can be configured when creating the model: ```swift diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index 8f23b2f..0ef37ef 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -196,38 +196,122 @@ import Foundation /// Set these values through ``GenerationOptions`` using /// `GenerationOptions[custom: MLXLanguageModel.self]`. public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions, Codable { - /// Limits how many tokens the KV cache retains. - /// - /// Set this to `nil` to use the backend default. - public var maxKVSize: Int? - /// Sets the KV-cache quantization bit width. - /// - /// Set this to `nil` to disable KV quantization. - public var kvBits: Int? - /// Sets the token group size used for KV quantization. - public var kvGroupSize: Int - /// Sets the token offset where quantized KV storage starts. - public var quantizedKVStart: Int + /// Configures KV-cache behavior for MLX generation. + public struct KVCache: Codable, Equatable, Sendable { + /// Limits how many tokens the KV cache retains. + /// + /// Set this to `nil` to use the backend default. + public var maxSize: Int? + + /// Sets the KV-cache quantization bit width. + /// + /// Set this to `nil` to disable KV quantization. + public var bits: Int? + + /// Sets the token group size used for KV quantization. + public var groupSize: Int + + /// Sets the token offset where quantized KV storage starts. + public var quantizedStart: Int + + /// Default KV-cache options used when none are provided at runtime. + /// By default, the token group size is 64 and the quantized start is 0. + public static var `default`: Self { + .init( + maxSize: nil, + bits: nil, + groupSize: 64, + quantizedStart: 0 + ) + } + + /// Creates KV-cache configuration for MLX generation. 
+ /// + /// - Parameters: + /// - maxSize: The maximum number of tokens to retain in KV cache storage. + /// Pass `nil` to use the backend default. + /// - bits: The KV-cache quantization bit width. + /// Pass `nil` to disable KV quantization. + /// - groupSize: The token group size used for KV quantization. + /// - quantizedStart: The token index where quantized KV storage begins. + public init( + maxSize: Int?, + bits: Int?, + groupSize: Int, + quantizedStart: Int + ) { + self.maxSize = maxSize + self.bits = bits + self.groupSize = groupSize + self.quantizedStart = quantizedStart + } + } + /// KV-cache configuration used for generation. + public var kvCache: KVCache + + /// Configures media preprocessing applied before model input. + public struct UserInputProcessing: Codable, Equatable, Sendable { + /// Optional resize target applied to media before tokenization. + public var resize: CGSize? + + /// Creates user-input processing configuration. + /// + /// - Parameter resize: Optional target size for media resizing. + init(resize: CGSize?) { + self.resize = resize + } + + /// Creates processing that resizes media to a fixed size. + /// + /// - Parameter size: Target size used for resizing media inputs. + public static func resize(to size: CGSize) -> Self { + .init(resize: size) + } + + var mlxValue: MLXLMCommon.UserInput.Processing { + .init(resize: resize) + } + } + /// Processing to apply to user media before input preparation. + public var userInputProcessing: UserInputProcessing? + + var processingForUserInput: MLXLMCommon.UserInput.Processing { + userInputProcessing?.mlxValue + ?? .init(resize: nil) + } + + /// Additional key-value pairs injected into the chat template rendering context. + public var additionalContext: [String: AnyLanguageModel.JSONValue]? + + var additionalContextForUserInput: [String: any Sendable]? { + additionalContext?.mapValues { $0.toSendable() } + } /// Creates MLX-specific generation options. 
/// /// - Parameters: - /// - maxKVSize: The maximum number of tokens to retain in KV cache storage. - /// Pass `nil` to use the backend default. - /// - kvBits: The KV-cache quantization bit width. - /// Pass `nil` to disable KV quantization. - /// - kvGroupSize: The token group size used for KV quantization. - /// - quantizedKVStart: The token index where quantized KV storage begins. + /// - kvCache: KV-cache configuration used for generation. + /// - userInputProcessing: Processing to apply to user media before input preparation. + /// Pass `nil` to let MLX use its default media handling. + /// - additionalContext: Additional key-value pairs injected into the chat + /// template rendering context. public init( - maxKVSize: Int? = nil, - kvBits: Int? = nil, - kvGroupSize: Int = 64, - quantizedKVStart: Int = 0 + kvCache: KVCache, + userInputProcessing: UserInputProcessing?, + additionalContext: [String: AnyLanguageModel.JSONValue]? ) { - self.maxKVSize = maxKVSize - self.kvBits = kvBits - self.kvGroupSize = kvGroupSize - self.quantizedKVStart = quantizedKVStart + self.kvCache = kvCache + self.additionalContext = additionalContext + self.userInputProcessing = userInputProcessing + } + + /// Default MLX generation options used when none are provided at runtime. + public static var `default`: Self { + .init( + kvCache: .default, + userInputProcessing: nil, + additionalContext: nil + ) } } @@ -761,15 +845,16 @@ import Foundation } private func makeUserInput( - session: LanguageModelSession, - fallbackPrompt: String, - tools: [ToolSpec]? + chat: [MLXLMCommon.Chat.Message], + tools: [ToolSpec]?, + processing: MLXLMCommon.UserInput.Processing = .init(resize: nil), + additionalContext: [String: any Sendable]?
= nil ) -> MLXLMCommon.UserInput { - let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: fallbackPrompt) return MLXLMCommon.UserInput( chat: chat, - processing: .init(resize: .init(width: 512, height: 512)), - tools: tools + processing: processing, + tools: tools, + additionalContext: additionalContext ) } @@ -813,6 +898,12 @@ import Foundation // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters let generateParameters = toGenerateParameters(options) + // Extract additional context from custom options + let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput + let userInputProcessing = + options[custom: MLXLanguageModel.self]?.processingForUserInput + ?? .init(resize: nil) + // Build chat history from full transcript var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) @@ -825,10 +916,11 @@ import Foundation // Loop until no more tool calls while true { // Build user input with current chat history and tools - let userInput = MLXLMCommon.UserInput( + let userInput = makeUserInput( chat: chat, - processing: .init(resize: .init(width: 512, height: 512)), - tools: toolSpecs + tools: toolSpecs, + processing: userInputProcessing, + additionalContext: additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) let resolved = resolveCache( @@ -991,10 +1083,20 @@ import Foundation // Build chat inside task to avoid Sendable issues let generateParameters = toGenerateParameters(options) - let userInput = makeUserInput( + let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput + let userInputProcessing = + options[custom: MLXLanguageModel.self]?.processingForUserInput + ??
.init(resize: nil) + let chat = convertTranscriptToMLXChat( session: session, - fallbackPrompt: prompt.description, - tools: nil + fallbackPrompt: prompt.description + ) + + let userInput = makeUserInput( + chat: chat, + tools: nil, + processing: userInputProcessing, + additionalContext: additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) let resolved = resolveCache( @@ -1092,7 +1194,7 @@ import Foundation let newCache = context.model.newCache(parameters: params) let userInput = MLXLMCommon.UserInput( chat: [.init(role: .system, content: instructions)], - processing: .init(resize: .init(width: 512, height: 512)), + processing: .init(resize: nil), tools: toolSpecs ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1116,10 +1218,10 @@ import Foundation let custom = options[custom: MLXLanguageModel.self] return MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: custom?.maxKVSize, - kvBits: custom?.kvBits, - kvGroupSize: custom?.kvGroupSize ?? 64, - quantizedKVStart: custom?.quantizedKVStart ?? 0, + maxKVSize: custom?.kvCache.maxSize, + kvBits: custom?.kvCache.bits, + kvGroupSize: custom?.kvCache.groupSize ?? 64, + quantizedKVStart: custom?.kvCache.quantizedStart ?? 0, temperature: Float(options.temperature ?? 0.6), topP: 1.0, repetitionPenalty: nil, @@ -1132,10 +1234,10 @@ import Foundation let custom = options[custom: MLXLanguageModel.self] return MLXLMCommon.GenerateParameters( maxTokens: options.maximumResponseTokens, - maxKVSize: custom?.maxKVSize, - kvBits: custom?.kvBits, - kvGroupSize: custom?.kvGroupSize ?? 64, - quantizedKVStart: custom?.quantizedKVStart ?? 0, + maxKVSize: custom?.kvCache.maxSize, + kvBits: custom?.kvCache.bits, + kvGroupSize: custom?.kvCache.groupSize ?? 64, + quantizedKVStart: custom?.kvCache.quantizedStart ?? 0, temperature: Float(options.temperature ?? 
0.2), topP: 0.95, repetitionPenalty: 1.1, @@ -1489,7 +1591,7 @@ import Foundation { header += ". Expected value: \(constString)" } else if let enumValues = jsonSchema.enum, !enumValues.isEmpty, - let data = try? encoder.encode(JSONValue.array(enumValues)), + let data = try? encoder.encode(enumValues), let enumString = String(data: data, encoding: .utf8) { header += ". Allowed values: \(enumString)" @@ -1529,10 +1631,17 @@ import Foundation let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt) + + let additionalContext = options[custom: MLXLanguageModel.self]?.additionalContextForUserInput + let userInputProcessing = + options[custom: MLXLanguageModel.self]?.processingForUserInput + ?? .init(resize: nil) + let userInput = MLXLMCommon.UserInput( chat: chat, - processing: .init(resize: .init(width: 512, height: 512)), - tools: nil + processing: userInputProcessing, + tools: nil, + additionalContext: additionalContext ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1773,4 +1882,18 @@ import Foundation return sampledToken.item(Int.self) } } + extension AnyLanguageModel.JSONValue { + /// Recursively converts a `JSONValue` to its primitive Swift equivalent.
+ func toSendable() -> any Sendable { + switch self { + case .string(let s): return s + case .int(let i): return i + case .double(let d): return d + case .bool(let b): return b + case .null: return NSNull() + case .array(let arr): return arr.map { $0.toSendable() } + case .object(let obj): return obj.mapValues { $0.toSendable() } + } + } + } #endif // MLX diff --git a/Sources/AnyLanguageModel/Shared/JSONValue.swift b/Sources/AnyLanguageModel/Shared/JSONValue.swift new file mode 100644 index 0000000..5d3c315 --- /dev/null +++ b/Sources/AnyLanguageModel/Shared/JSONValue.swift @@ -0,0 +1,4 @@ +import enum JSONSchema.JSONValue + +/// A type-safe representation of JSON values used by AnyLanguageModel APIs. +public typealias JSONValue = JSONSchema.JSONValue diff --git a/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift b/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift index 89539a1..3a4a812 100644 --- a/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift +++ b/Tests/AnyLanguageModelTests/CustomGenerationOptionsTests.swift @@ -867,38 +867,50 @@ struct GeminiCustomOptionsTests { struct MLXCustomOptionsTests { @Test func initialization() { let options = MLXLanguageModel.CustomGenerationOptions( - maxKVSize: 4096, - kvBits: 4, - kvGroupSize: 64, - quantizedKVStart: 128 + kvCache: .init( + maxSize: 4096, + bits: 4, + groupSize: 64, + quantizedStart: 128 + ), + userInputProcessing: nil, + additionalContext: nil ) - #expect(options.maxKVSize == 4096) - #expect(options.kvBits == 4) - #expect(options.kvGroupSize == 64) - #expect(options.quantizedKVStart == 128) + #expect(options.kvCache.maxSize == 4096) + #expect(options.kvCache.bits == 4) + #expect(options.kvCache.groupSize == 64) + #expect(options.kvCache.quantizedStart == 128) } @Test func integrationWithGenerationOptions() { var options = GenerationOptions(temperature: 0.7) options[custom: MLXLanguageModel.self] = .init( - maxKVSize: 2048, - kvBits: 8, - kvGroupSize: 32, - 
quantizedKVStart: 256 + kvCache: .init( + maxSize: 2048, + bits: 8, + groupSize: 32, + quantizedStart: 256 + ), + userInputProcessing: nil, + additionalContext: nil ) let retrieved = options[custom: MLXLanguageModel.self] - #expect(retrieved?.maxKVSize == 2048) - #expect(retrieved?.kvBits == 8) - #expect(retrieved?.kvGroupSize == 32) - #expect(retrieved?.quantizedKVStart == 256) + #expect(retrieved?.kvCache.maxSize == 2048) + #expect(retrieved?.kvCache.bits == 8) + #expect(retrieved?.kvCache.groupSize == 32) + #expect(retrieved?.kvCache.quantizedStart == 256) } @Test func codable() throws { let options = MLXLanguageModel.CustomGenerationOptions( - maxKVSize: 8192, - kvBits: 4, - kvGroupSize: 64, - quantizedKVStart: 0 + kvCache: .init( + maxSize: 8192, + bits: 4, + groupSize: 64, + quantizedStart: 0 + ), + userInputProcessing: nil, + additionalContext: nil ) let data = try JSONEncoder().encode(options) let decoded = try JSONDecoder().decode(MLXLanguageModel.CustomGenerationOptions.self, from: data) diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index 937b32e..bb048d3 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -154,7 +154,11 @@ import Testing ) ]) let session = LanguageModelSession(model: visionModel, transcript: transcript) - let response = try await session.respond(to: "") + var options = GenerationOptions() + var mlxOptions = MLXLanguageModel.CustomGenerationOptions.default + mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512)) + options[custom: MLXLanguageModel.self] = mlxOptions + let response = try await session.respond(to: "", options: options) #expect(!response.content.isEmpty) } @@ -168,7 +172,11 @@ import Testing ) ]) let session = LanguageModelSession(model: visionModel, transcript: transcript) - let response = try await session.respond(to: "") + var options = 
GenerationOptions() + var mlxOptions = MLXLanguageModel.CustomGenerationOptions.default + mlxOptions.userInputProcessing = .resize(to: CGSize(width: 512, height: 512)) + options[custom: MLXLanguageModel.self] = mlxOptions + let response = try await session.respond(to: "", options: options) #expect(!response.content.isEmpty) } @@ -255,6 +263,28 @@ import Testing #expect([Priority.low, Priority.medium, Priority.high].contains(response.content)) } + @Test func withAdditionalContext() async throws { + let session = LanguageModelSession(model: model) + + var options = GenerationOptions( + temperature: 0.7, + maximumResponseTokens: 32 + ) + var custom = MLXLanguageModel.CustomGenerationOptions.default + custom.additionalContext = [ + "user_name": JSONValue.string("Alice"), + "turn_count": JSONValue.int(3), + "verbose": JSONValue.bool(true), + ] + options[custom: MLXLanguageModel.self] = custom + + let response = try await session.respond( + to: "Say hello", + options: options + ) + #expect(!response.content.isEmpty) + } + @Test func unavailableForNonexistentModel() async { let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test") await model.removeFromCache()