diff --git a/.github/_typos.toml b/.github/_typos.toml index fb576b499..d1f4f98b1 100644 --- a/.github/_typos.toml +++ b/.github/_typos.toml @@ -17,4 +17,6 @@ extend-exclude = [ [default.extend-words] # Used in a comment in SafeLLamaSamplerHandle.cs, as a prefix of "hello" -teh = "hel" \ No newline at end of file +teh = "hel" +# ot is the shorthand version of llama.cpp's override-tensor parameter +ot = "ot" diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs index 3fab9ed3e..59cf70bf5 100644 --- a/LLama.Unittest/ModelsParamsTests.cs +++ b/LLama.Unittest/ModelsParamsTests.cs @@ -41,6 +41,11 @@ public void SerializeRoundTripSystemTextJson() actual.MetadataOverrides = null!; expected.MetadataOverrides = null!; + // Same deal + Assert.True(expected.TensorBufferOverrides.SequenceEqual(actual.TensorBufferOverrides)); + actual.TensorBufferOverrides = null!; + expected.TensorBufferOverrides = null!; + // Check encoding is the same var b1 = expected.Encoding.GetBytes("Hello"); var b2 = actual.Encoding.GetBytes("Hello"); diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index a67a11a96..9824c0922 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -26,6 +26,9 @@ public class ModelOptions /// public GPUSplitMode? SplitMode { get; set; } + /// + public List TensorBufferOverrides { get; set; } = new(); + /// public int GpuLayerCount { get; set; } = 20; diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index cbbacafe5..8a752e190 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -38,6 +38,12 @@ public interface IModelParams /// GPUSplitMode? SplitMode { get; } + /// + /// Buffer type overrides for specific tensor patterns, allowing you to specify hardware devices to use for individual tensors or sets of tensors. 
+ /// Equivalent to --override-tensor or -ot on the llama.cpp command line or tensor_buft_overrides internally. + /// + List TensorBufferOverrides { get; } + /// /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) /// diff --git a/LLama/Abstractions/TensorBufferOverride.cs b/LLama/Abstractions/TensorBufferOverride.cs new file mode 100644 index 000000000..e8ec3f136 --- /dev/null +++ b/LLama/Abstractions/TensorBufferOverride.cs @@ -0,0 +1,36 @@ +using System; + +namespace LLama.Abstractions +{ + /// + /// Represents a mapping between a tensor name pattern and a specific buffer type + /// + public class TensorBufferOverride + { + /// + /// Pattern to match tensor names. This is a regular expression. You can check the tensor names via the model.Metadata. + /// + public string Pattern { get; set; } + + /// + /// Buffer type to use for matching tensors. Examples: CPU, GPU0, GPU1 + /// + public string BufferType { get; set; } + + /// + /// Creates a new tensor buffer override + /// + /// Pattern to match tensor names + /// Buffer type to use for matching tensors + public TensorBufferOverride(string pattern, string bufferType) + { + if (string.IsNullOrEmpty(pattern)) + throw new ArgumentException("Pattern cannot be null or empty", nameof(pattern)); + if (string.IsNullOrEmpty(bufferType)) + throw new ArgumentException("Buffer type cannot be null or empty", nameof(bufferType)); + + Pattern = pattern; + BufferType = bufferType; + } + } +} diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 7e4b1a967..23f5681be 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -21,6 +21,9 @@ public record ModelParams /// public GPUSplitMode? 
SplitMode { get; set; } + /// + public List TensorBufferOverrides { get; set; } = new(); + /// public int GpuLayerCount { get; set; } = 20; diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index 588564e33..ad3a3591c 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -45,6 +45,20 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer; } + // Add tensor buffer overrides, if any + if (@params.TensorBufferOverrides.Count > 0) + { + var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper(); + disposer.Add(bufferOverrideHelper); + + foreach (var tensorOverride in @params.TensorBufferOverrides) + { + bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType); + } + + bufferOverrideHelper.ApplyToModelParams(ref result); + } + if (@params.MetadataOverrides.Count == 0) { unsafe diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs index 562896d7b..acb024852 100644 --- a/LLama/Native/LLamaModelParams.cs +++ b/LLama/Native/LLamaModelParams.cs @@ -13,11 +13,11 @@ public unsafe struct LLamaModelParams /// todo: add support for llama_model_params.devices /// private IntPtr devices; - - // NULL-terminated list of buffer types to use for tensors that match a pattern - // actual type: llama_model_tensor_buft_override* - // todo: add support for tensor_buft_overrides - private IntPtr tensor_buft_overrides; + + /// + /// NULL-terminated list of buffer types to use for tensors that match a pattern + /// + public LLamaModelTensorBufferOverride* tensor_buft_overrides; /// /// // number of layers to store in VRAM diff --git a/LLama/Native/LLamaModelTensorBufferOverride.cs b/LLama/Native/LLamaModelTensorBufferOverride.cs new file mode 100644 index 000000000..c3a6eac6f --- /dev/null +++ 
b/LLama/Native/LLamaModelTensorBufferOverride.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native
+{
+    /// <summary>
+    /// Represents a mapping between a tensor name pattern and a backend buffer type
+    /// Original type: llama_model_tensor_buft_override
+    /// </summary>
+    [StructLayout(LayoutKind.Sequential)]
+    public struct LLamaModelTensorBufferOverride
+    {
+        /// <summary>
+        /// Tensor name pattern to match
+        /// </summary>
+        public IntPtr Pattern;
+
+        /// <summary>
+        /// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type
+        /// </summary>
+        public IntPtr BufferType;
+    }
+}
diff --git a/LLama/Native/LLamaTensorBufferOverrideHelper.cs b/LLama/Native/LLamaTensorBufferOverrideHelper.cs
new file mode 100644
index 000000000..daab52b62
--- /dev/null
+++ b/LLama/Native/LLamaTensorBufferOverrideHelper.cs
@@ -0,0 +1,137 @@
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace LLama.Native
+{
+    /// <summary>
+    /// Helper for creating and managing tensor buffer overrides
+    /// </summary>
+    internal class LLamaTensorBufferOverrideHelper : IDisposable
+    {
+        private readonly List<IntPtr> _allocatedMemory = new();
+        private readonly List<LLamaModelTensorBufferOverride> _overrides = new();
+        private IntPtr _overrideArray = IntPtr.Zero;
+        private readonly Dictionary<string, IntPtr> _bufferTypeCache = new();
+
+        /// <summary>
+        /// Get all available buffer types
+        /// </summary>
+        /// <returns>Dictionary mapping buffer type names to their handles</returns>
+        public Dictionary<string, IntPtr> GetAvailableBufferTypes()
+        {
+            var result = new Dictionary<string, IntPtr>();
+
+            nuint count = NativeApi.ggml_backend_dev_count();
+            for (nuint i = 0; i < count; i++)
+            {
+                IntPtr dev = NativeApi.ggml_backend_dev_get(i);
+                IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev);
+
+                if (buft != IntPtr.Zero)
+                {
+                    IntPtr namePtr = NativeApi.ggml_backend_buft_name(buft);
+                    string name = Marshal.PtrToStringAnsi(namePtr) ?? string.Empty;
+
+                    if (!string.IsNullOrEmpty(name))
+                    {
+                        result[name] = buft;
+                        _bufferTypeCache[name] = buft;
+                    }
+                }
+            }
+
+            return result;
+        }
+
+        /// <summary>
+        /// Add a tensor buffer override
+        /// </summary>
+        /// <param name="pattern">Tensor name pattern to match</param>
+        /// <param name="bufferTypeName">Name of the buffer type to use</param>
+        /// <returns>True if the override was added successfully</returns>
+        public bool AddOverride(string pattern, string bufferTypeName)
+        {
+            if (string.IsNullOrEmpty(pattern) || string.IsNullOrEmpty(bufferTypeName))
+                return false;
+
+            // Get all buffer types if cache is empty
+            if (_bufferTypeCache.Count == 0)
+            {
+                GetAvailableBufferTypes();
+            }
+
+            // Check if we have this buffer type
+            if (!_bufferTypeCache.TryGetValue(bufferTypeName, out IntPtr bufferType))
+                return false;
+
+            // Allocate memory for the pattern string and keep track of it
+            byte[] patternBytes = Encoding.UTF8.GetBytes(pattern + "\0");
+            IntPtr patternPtr = Marshal.AllocHGlobal(patternBytes.Length);
+            Marshal.Copy(patternBytes, 0, patternPtr, patternBytes.Length);
+            _allocatedMemory.Add(patternPtr);
+
+            // Create the override
+            var @override = new LLamaModelTensorBufferOverride
+            {
+                Pattern = patternPtr,
+                BufferType = bufferType
+            };
+
+            _overrides.Add(@override);
+            return true;
+        }
+
+        /// <summary>
+        /// Apply the overrides to model parameters
+        /// </summary>
+        /// <param name="modelParams">Model parameters to update</param>
+        public unsafe void ApplyToModelParams(ref LLamaModelParams modelParams)
+        {
+            if (_overrides.Count == 0)
+            {
+                modelParams.tensor_buft_overrides = null;
+                return;
+            }
+
+            // Free previous array if it exists (forget it first so Dispose cannot double-free)
+            if (_overrideArray != IntPtr.Zero)
+            {
+                _allocatedMemory.Remove(_overrideArray);
+                Marshal.FreeHGlobal(_overrideArray);
+            }
+
+            // Allocate memory for the array + null terminator
+            int size = Marshal.SizeOf<LLamaModelTensorBufferOverride>() * (_overrides.Count + 1);
+            _overrideArray = Marshal.AllocHGlobal(size);
+            _allocatedMemory.Add(_overrideArray);
+
+            // Copy overrides to array
+            for (int i = 0; i < _overrides.Count; i++)
+            {
+                IntPtr elemPtr = IntPtr.Add(_overrideArray, i * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
+                Marshal.StructureToPtr(_overrides[i], elemPtr, false);
+            }
+
+            // Add null terminator
+            IntPtr nullTermPtr = IntPtr.Add(_overrideArray, _overrides.Count * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
+            Marshal.StructureToPtr(new LLamaModelTensorBufferOverride { Pattern = IntPtr.Zero, BufferType = IntPtr.Zero }, nullTermPtr, false);
+
+            // Update model params
+            modelParams.tensor_buft_overrides = (LLamaModelTensorBufferOverride*)_overrideArray.ToPointer();
+        }
+
+        /// <inheritdoc />
+        public void Dispose()
+        {
+            foreach (IntPtr ptr in _allocatedMemory)
+            {
+                Marshal.FreeHGlobal(ptr);
+            }
+            _allocatedMemory.Clear();
+            _overrides.Clear();
+            _overrideArray = IntPtr.Zero;
+        }
+    }
+}
diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs
index 2d5be063f..4555ed0d2 100644
--- a/LLama/Native/NativeApi.Load.cs
+++ b/LLama/Native/NativeApi.Load.cs
@@ -107,6 +107,8 @@ private static void SetDllImportResolver()
         internal const string libraryName = "llama";
         internal const string llavaLibraryName = "llava_shared";
+        internal const string ggmlLibraryName = "ggml";
+        internal const string ggmlBaseLibraryName = "ggml-base";
 
         private static INativeLibrary? _loadedLLamaLibrary = null;
         private static INativeLibrary? _loadedLLavaLibrary = null;
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index d238753fe..87cf02c78 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -447,5 +447,36 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
         // it would expose the raw pointer to the model, without properly wrapping it in a SafeLLamaModelHandle.
        //[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
        //public static void llama_model* llama_get_model(SafeLLamaContextHandle ctx);
+
+        /// <summary>
+        /// Get the number of available backend devices
+        /// </summary>
+        /// <returns>Count of available backend devices</returns>
+        [DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern nuint ggml_backend_dev_count();
+
+        /// <summary>
+        /// Get a backend device by index
+        /// </summary>
+        /// <param name="i">Device index</param>
+        /// <returns>Pointer to the backend device</returns>
+        [DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern IntPtr ggml_backend_dev_get(nuint i);
+
+        /// <summary>
+        /// Get the buffer type for a backend device
+        /// </summary>
+        /// <param name="dev">Backend device pointer</param>
+        /// <returns>Pointer to the buffer type</returns>
+        [DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern IntPtr ggml_backend_dev_buffer_type(IntPtr dev);
+
+        /// <summary>
+        /// Get the name of a buffer type
+        /// </summary>
+        /// <param name="buft">Buffer type pointer</param>
+        /// <returns>Name of the buffer type</returns>
+        [DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
+        public static extern IntPtr ggml_backend_buft_name(IntPtr buft);
    }
}