diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index ad3a3591c..2939318da 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -3,6 +3,7 @@ using System.Text; using LLama.Abstractions; using LLama.Native; +using System.Collections.Generic; namespace LLama.Extensions; @@ -45,20 +46,13 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer; } - // Add tensor buffer overrides, if any - if (@params.TensorBufferOverrides.Count > 0) + // Add tensor buffer overrides + unsafe { - var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper(); - disposer.Add(bufferOverrideHelper); - - foreach (var tensorOverride in @params.TensorBufferOverrides) - { - bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType); - } - - bufferOverrideHelper.ApplyToModelParams(ref result); + result.tensor_buft_overrides = ConvertOverrides(@params.TensorBufferOverrides, disposer); } + // Add metadata overrides if (@params.MetadataOverrides.Count == 0) { unsafe @@ -106,4 +100,69 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam return disposer; } + + /// + /// Get a map from name of device (`ggml_backend_buft_name`) to the device type (`ggml_backend_dev_buffer_type`) + /// + /// Dictionary mapping buffer type names to their handles + private static IReadOnlyDictionary GetAvailableBufferTypes() + { + var result = new Dictionary(); + + var count = NativeApi.ggml_backend_dev_count(); + for (nuint i = 0; i < count; i++) + { + var dev = NativeApi.ggml_backend_dev_get(i); + var buft = NativeApi.ggml_backend_dev_buffer_type(dev); + + var name = Marshal.PtrToStringAnsi(NativeApi.ggml_backend_buft_name(buft)); + if (string.IsNullOrEmpty(name)) + continue; + + result[name] = buft; + } + + return result; + } + + private static unsafe LLamaModelTensorBufferOverride* ConvertOverrides(List overrides, GroupDisposable disposer) + { + // Early out if there are no overrides + if (overrides.Count == 0) + return null; + + var bufferTypes = GetAvailableBufferTypes(); + + var overridesCount = 0; + var overridesArray = new LLamaModelTensorBufferOverride[overrides.Count + 1]; + + foreach (var @override in overrides) + { + // Check if we have this buffer type + if (!bufferTypes.TryGetValue(@override.BufferType, out var bufferType)) + continue; + + // Create null terminated string and pin this memory so it can be passed to native code + var patternBytes = Encoding.UTF8.GetBytes(@override.Pattern + "\0"); + var patternPin = patternBytes.AsMemory().Pin(); + disposer.Add(patternPin); + + // Add the item to the overridesArray + overridesArray[overridesCount++] = new() + { + Pattern = (byte*)patternPin.Pointer, + BufferType = bufferType + }; + } + + // Early out if there were no valid overrides + if (overridesCount == 0) + return null; + + // Pin it so it can be safely passed across to native code + var overrideArrayPin = overridesArray.AsMemory().Pin(); + disposer.Add(overrideArrayPin); + + return (LLamaModelTensorBufferOverride*)overrideArrayPin.Pointer; + } } \ No newline at end of file diff --git a/LLama/Native/LLamaModelTensorBufferOverride.cs b/LLama/Native/LLamaModelTensorBufferOverride.cs index c3a6eac6f..3b7d3fa99 100644 --- a/LLama/Native/LLamaModelTensorBufferOverride.cs +++ b/LLama/Native/LLamaModelTensorBufferOverride.cs @@ -7,12 +7,12 @@ namespace LLama.Native /// Original type: llama_model_tensor_buft_override /// [StructLayout(LayoutKind.Sequential)] - public struct LLamaModelTensorBufferOverride + public unsafe struct LLamaModelTensorBufferOverride { /// /// Tensor name pattern to match /// - public IntPtr Pattern; + public byte* Pattern; /// /// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type diff --git a/LLama/Native/LLamaTensorBufferOverrideHelper.cs b/LLama/Native/LLamaTensorBufferOverrideHelper.cs deleted file mode 100644 index daab52b62..000000000 --- a/LLama/Native/LLamaTensorBufferOverrideHelper.cs +++ /dev/null @@ -1,135 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Text; - -namespace LLama.Native -{ - /// - /// Helper for creating and managing tensor buffer overrides - /// - internal class LLamaTensorBufferOverrideHelper : IDisposable - { - private readonly List _allocatedMemory = new(); - private readonly List _overrides = new(); - private IntPtr _overrideArray = IntPtr.Zero; - private readonly Dictionary _bufferTypeCache = new(); - - /// - /// Get all available buffer types - /// - /// Dictionary mapping buffer type names to their handles - public Dictionary GetAvailableBufferTypes() - { - var result = new Dictionary(); - - nuint count = NativeApi.ggml_backend_dev_count(); - for (nuint i = 0; i < count; i++) - { - IntPtr dev = NativeApi.ggml_backend_dev_get(i); - IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev); - - if (buft != IntPtr.Zero) - { - IntPtr namePtr = NativeApi.ggml_backend_buft_name(buft); - string name = Marshal.PtrToStringAnsi(namePtr) ?? string.Empty; - - if (!string.IsNullOrEmpty(name)) - { - result[name] = buft; - _bufferTypeCache[name] = buft; - } - } - } - - return result; - } - - /// - /// Add a tensor buffer override - /// - /// Tensor name pattern to match - /// Name of the buffer type to use - /// True if the override was added successfully - public bool AddOverride(string pattern, string bufferTypeName) - { - if (string.IsNullOrEmpty(pattern) || string.IsNullOrEmpty(bufferTypeName)) - return false; - - // Get all buffer types if cache is empty - if (_bufferTypeCache.Count == 0) - { - GetAvailableBufferTypes(); - } - - // Check if we have this buffer type - if (!_bufferTypeCache.TryGetValue(bufferTypeName, out IntPtr bufferType)) - return false; - - // Allocate memory for the pattern string and keep track of it - byte[] patternBytes = Encoding.UTF8.GetBytes(pattern + "\0"); - IntPtr patternPtr = Marshal.AllocHGlobal(patternBytes.Length); - Marshal.Copy(patternBytes, 0, patternPtr, patternBytes.Length); - _allocatedMemory.Add(patternPtr); - - // Create the override - var @override = new LLamaModelTensorBufferOverride - { - Pattern = patternPtr, - BufferType = bufferType - }; - - _overrides.Add(@override); - return true; - } - - /// - /// Apply the overrides to model parameters - /// - /// Model parameters to update - public unsafe void ApplyToModelParams(ref LLamaModelParams modelParams) - { - if (_overrides.Count == 0) - { - modelParams.tensor_buft_overrides = null; - return; - } - - // Free previous array if it exists - if (_overrideArray != IntPtr.Zero) - { - Marshal.FreeHGlobal(_overrideArray); - } - - // Allocate memory for the array + null terminator - int size = Marshal.SizeOf() * (_overrides.Count + 1); - _overrideArray = Marshal.AllocHGlobal(size); - _allocatedMemory.Add(_overrideArray); - - // Copy overrides to array - for (int i = 0; i < _overrides.Count; i++) - { - IntPtr elemPtr = IntPtr.Add(_overrideArray, i * Marshal.SizeOf()); - Marshal.StructureToPtr(_overrides[i], elemPtr, false); - } - - // Add null terminator - IntPtr nullTermPtr = IntPtr.Add(_overrideArray, _overrides.Count * Marshal.SizeOf()); - Marshal.StructureToPtr(new LLamaModelTensorBufferOverride { Pattern = IntPtr.Zero, BufferType = IntPtr.Zero }, nullTermPtr, false); - - // Update model params - modelParams.tensor_buft_overrides = (LLamaModelTensorBufferOverride*)_overrideArray.ToPointer(); - } - - /// - public void Dispose() - { - foreach (IntPtr ptr in _allocatedMemory) - { - Marshal.FreeHGlobal(ptr); - } - _allocatedMemory.Clear(); - _overrides.Clear(); - _overrideArray = IntPtr.Zero; - } - } -}