diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index ad3a3591c..2939318da 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -3,6 +3,7 @@
using System.Text;
using LLama.Abstractions;
using LLama.Native;
+using System.Collections.Generic;
namespace LLama.Extensions;
@@ -45,20 +46,13 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
}
- // Add tensor buffer overrides, if any
- if (@params.TensorBufferOverrides.Count > 0)
+ // Add tensor buffer overrides
+ unsafe
{
- var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper();
- disposer.Add(bufferOverrideHelper);
-
- foreach (var tensorOverride in @params.TensorBufferOverrides)
- {
- bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType);
- }
-
- bufferOverrideHelper.ApplyToModelParams(ref result);
+ result.tensor_buft_overrides = ConvertOverrides(@params.TensorBufferOverrides, disposer);
}
+ // Add metadata overrides
if (@params.MetadataOverrides.Count == 0)
{
unsafe
@@ -106,4 +100,69 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
return disposer;
}
+
+    /// <summary>
+    /// Get a map from the buffer type name (`ggml_backend_buft_name`) to the buffer type handle (`ggml_backend_dev_buffer_type`) of each device
+    /// </summary>
+    /// <returns>Dictionary mapping buffer type names to their handles</returns>
+    private static IReadOnlyDictionary<string, IntPtr> GetAvailableBufferTypes()
+    {
+        var result = new Dictionary<string, IntPtr>();
+
+        var count = NativeApi.ggml_backend_dev_count();
+        for (nuint i = 0; i < count; i++)
+        {
+            var dev = NativeApi.ggml_backend_dev_get(i);
+            var buft = NativeApi.ggml_backend_dev_buffer_type(dev);
+            if (buft == IntPtr.Zero)
+                continue; // No buffer type for this device: never hand a null handle to ggml_backend_buft_name
+
+            var name = Marshal.PtrToStringAnsi(NativeApi.ggml_backend_buft_name(buft));
+            if (!string.IsNullOrEmpty(name))
+                result[name] = buft; // If two devices report the same name the later one wins
+        }
+
+        return result;
+    }
+
+    /// <summary>
+    /// Convert managed tensor buffer overrides into a pinned native array with one trailing zeroed element.
+    /// Entries with an empty pattern or buffer type, or a buffer type that no device reports, are skipped.
+    /// </summary>
+    /// <param name="overrides">Overrides to convert</param>
+    /// <param name="disposer">Receives every pin created here; disposing it releases the memory behind the returned pointer</param>
+    /// <returns>Pointer to the pinned array, or null if there are no usable overrides</returns>
+    private static unsafe LLamaModelTensorBufferOverride* ConvertOverrides(List<TensorBufferOverride> overrides, GroupDisposable disposer)
+    {
+        // Early out if there are no overrides
+        if (overrides.Count == 0)
+            return null;
+
+        var bufferTypes = GetAvailableBufferTypes();
+        var overridesCount = 0;
+        // One extra element is left at its default (all zero) value so the array always ends with an empty entry
+        var overridesArray = new LLamaModelTensorBufferOverride[overrides.Count + 1];
+
+        foreach (var @override in overrides)
+        {
+            // Skip incomplete entries: a null buffer type would make TryGetValue throw instead of being ignored
+            if (string.IsNullOrEmpty(@override.Pattern) || string.IsNullOrEmpty(@override.BufferType))
+                continue;
+            if (!bufferTypes.TryGetValue(@override.BufferType, out var bufferType))
+                continue;
+            // Create null terminated string and pin this memory so it can be passed to native code
+            var patternBytes = Encoding.UTF8.GetBytes(@override.Pattern + "\0");
+            var patternPin = patternBytes.AsMemory().Pin();
+            disposer.Add(patternPin);
+            overridesArray[overridesCount++] = new() { Pattern = (byte*)patternPin.Pointer, BufferType = bufferType };
+        }
+
+        // Early out if there were no valid overrides (nothing has been pinned in that case)
+        if (overridesCount == 0)
+            return null;
+        // Pin it so it can be safely passed across to native code
+        var overrideArrayPin = overridesArray.AsMemory().Pin();
+        disposer.Add(overrideArrayPin);
+        return (LLamaModelTensorBufferOverride*)overrideArrayPin.Pointer;
+    }
}
\ No newline at end of file
diff --git a/LLama/Native/LLamaModelTensorBufferOverride.cs b/LLama/Native/LLamaModelTensorBufferOverride.cs
index c3a6eac6f..3b7d3fa99 100644
--- a/LLama/Native/LLamaModelTensorBufferOverride.cs
+++ b/LLama/Native/LLamaModelTensorBufferOverride.cs
@@ -7,12 +7,12 @@ namespace LLama.Native
/// Original type: llama_model_tensor_buft_override
///
[StructLayout(LayoutKind.Sequential)]
- public struct LLamaModelTensorBufferOverride
+ public unsafe struct LLamaModelTensorBufferOverride
{
///
/// Tensor name pattern to match
///
- public IntPtr Pattern;
+ public byte* Pattern;
///
/// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type
diff --git a/LLama/Native/LLamaTensorBufferOverrideHelper.cs b/LLama/Native/LLamaTensorBufferOverrideHelper.cs
deleted file mode 100644
index daab52b62..000000000
--- a/LLama/Native/LLamaTensorBufferOverrideHelper.cs
+++ /dev/null
@@ -1,135 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace LLama.Native
-{
- ///
- /// Helper for creating and managing tensor buffer overrides
- ///
- internal class LLamaTensorBufferOverrideHelper : IDisposable
- {
- private readonly List _allocatedMemory = new();
- private readonly List _overrides = new();
- private IntPtr _overrideArray = IntPtr.Zero;
- private readonly Dictionary _bufferTypeCache = new();
-
- ///
- /// Get all available buffer types
- ///
- /// Dictionary mapping buffer type names to their handles
- public Dictionary GetAvailableBufferTypes()
- {
- var result = new Dictionary();
-
- nuint count = NativeApi.ggml_backend_dev_count();
- for (nuint i = 0; i < count; i++)
- {
- IntPtr dev = NativeApi.ggml_backend_dev_get(i);
- IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev);
-
- if (buft != IntPtr.Zero)
- {
- IntPtr namePtr = NativeApi.ggml_backend_buft_name(buft);
- string name = Marshal.PtrToStringAnsi(namePtr) ?? string.Empty;
-
- if (!string.IsNullOrEmpty(name))
- {
- result[name] = buft;
- _bufferTypeCache[name] = buft;
- }
- }
- }
-
- return result;
- }
-
- ///
- /// Add a tensor buffer override
- ///
- /// Tensor name pattern to match
- /// Name of the buffer type to use
- /// True if the override was added successfully
- public bool AddOverride(string pattern, string bufferTypeName)
- {
- if (string.IsNullOrEmpty(pattern) || string.IsNullOrEmpty(bufferTypeName))
- return false;
-
- // Get all buffer types if cache is empty
- if (_bufferTypeCache.Count == 0)
- {
- GetAvailableBufferTypes();
- }
-
- // Check if we have this buffer type
- if (!_bufferTypeCache.TryGetValue(bufferTypeName, out IntPtr bufferType))
- return false;
-
- // Allocate memory for the pattern string and keep track of it
- byte[] patternBytes = Encoding.UTF8.GetBytes(pattern + "\0");
- IntPtr patternPtr = Marshal.AllocHGlobal(patternBytes.Length);
- Marshal.Copy(patternBytes, 0, patternPtr, patternBytes.Length);
- _allocatedMemory.Add(patternPtr);
-
- // Create the override
- var @override = new LLamaModelTensorBufferOverride
- {
- Pattern = patternPtr,
- BufferType = bufferType
- };
-
- _overrides.Add(@override);
- return true;
- }
-
- ///
- /// Apply the overrides to model parameters
- ///
- /// Model parameters to update
- public unsafe void ApplyToModelParams(ref LLamaModelParams modelParams)
- {
- if (_overrides.Count == 0)
- {
- modelParams.tensor_buft_overrides = null;
- return;
- }
-
- // Free previous array if it exists
- if (_overrideArray != IntPtr.Zero)
- {
- Marshal.FreeHGlobal(_overrideArray);
- }
-
- // Allocate memory for the array + null terminator
- int size = Marshal.SizeOf() * (_overrides.Count + 1);
- _overrideArray = Marshal.AllocHGlobal(size);
- _allocatedMemory.Add(_overrideArray);
-
- // Copy overrides to array
- for (int i = 0; i < _overrides.Count; i++)
- {
- IntPtr elemPtr = IntPtr.Add(_overrideArray, i * Marshal.SizeOf());
- Marshal.StructureToPtr(_overrides[i], elemPtr, false);
- }
-
- // Add null terminator
- IntPtr nullTermPtr = IntPtr.Add(_overrideArray, _overrides.Count * Marshal.SizeOf());
- Marshal.StructureToPtr(new LLamaModelTensorBufferOverride { Pattern = IntPtr.Zero, BufferType = IntPtr.Zero }, nullTermPtr, false);
-
- // Update model params
- modelParams.tensor_buft_overrides = (LLamaModelTensorBufferOverride*)_overrideArray.ToPointer();
- }
-
- ///
- public void Dispose()
- {
- foreach (IntPtr ptr in _allocatedMemory)
- {
- Marshal.FreeHGlobal(ptr);
- }
- _allocatedMemory.Clear();
- _overrides.Clear();
- _overrideArray = IntPtr.Zero;
- }
- }
-}