diff --git a/.github/_typos.toml b/.github/_typos.toml
index fb576b499..d1f4f98b1 100644
--- a/.github/_typos.toml
+++ b/.github/_typos.toml
@@ -17,4 +17,6 @@ extend-exclude = [
[default.extend-words]
# Used in a comment in SafeLLamaSamplerHandle.cs, as a prefix of "hello"
-teh = "hel"
\ No newline at end of file
+teh = "hel"
+# ot is shorthand for llama.cpp's override-tensor parameter
+ot = "ot"
diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs
index 3fab9ed3e..59cf70bf5 100644
--- a/LLama.Unittest/ModelsParamsTests.cs
+++ b/LLama.Unittest/ModelsParamsTests.cs
@@ -41,6 +41,11 @@ public void SerializeRoundTripSystemTextJson()
actual.MetadataOverrides = null!;
expected.MetadataOverrides = null!;
+ // Same deal for the tensor buffer overrides
+ Assert.True(expected.TensorBufferOverrides.SequenceEqual(actual.TensorBufferOverrides));
+ actual.TensorBufferOverrides = null!;
+ expected.TensorBufferOverrides = null!;
+
// Check encoding is the same
var b1 = expected.Encoding.GetBytes("Hello");
var b2 = actual.Encoding.GetBytes("Hello");
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index a67a11a96..9824c0922 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -26,6 +26,9 @@ public class ModelOptions
/// <inheritdoc />
public GPUSplitMode? SplitMode { get; set; }
+ /// <inheritdoc />
+ public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();
+
/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index cbbacafe5..8a752e190 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -38,6 +38,12 @@ public interface IModelParams
/// </summary>
GPUSplitMode? SplitMode { get; }
+ /// <summary>
+ /// Buffer type overrides for specific tensor patterns, allowing you to specify hardware devices to use for individual tensors or sets of tensors.
+ /// Equivalent to --override-tensor or -ot on the llama.cpp command line, or tensor_buft_overrides internally.
+ /// </summary>
+ List<TensorBufferOverride> TensorBufferOverrides { get; }
+
/// <summary>
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
/// </summary>
diff --git a/LLama/Abstractions/TensorBufferOverride.cs b/LLama/Abstractions/TensorBufferOverride.cs
new file mode 100644
index 000000000..e8ec3f136
--- /dev/null
+++ b/LLama/Abstractions/TensorBufferOverride.cs
@@ -0,0 +1,45 @@
+using System;
+
+namespace LLama.Abstractions
+{
+ /// <summary>
+ /// Represents a mapping between a tensor name pattern and a specific buffer type
+ /// </summary>
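+ /// <remarks>
+ /// A minimal usage sketch, assuming a backend that exposes a "CPU" buffer type
+ /// (buffer type names are backend dependent):
+ /// <code>
+ /// var parameters = new ModelParams("model.gguf");
+ /// // Keep tensors whose names match the regex "ffn_.*" in system memory
+ /// parameters.TensorBufferOverrides.Add(new TensorBufferOverride("ffn_.*", "CPU"));
+ /// </code>
+ /// </remarks>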
+ public class TensorBufferOverride
+ {
+ /// <summary>
+ /// Pattern to match tensor names. This is a regular expression. You can check the tensor names via model.Metadata.
+ /// </summary>
+ public string Pattern { get; set; }
+
+ /// <summary>
+ /// Buffer type to use for matching tensors. Examples: CPU, GPU0, GPU1
+ /// </summary>
+ public string BufferType { get; set; }
+
+ /// <summary>
+ /// Creates a new tensor buffer override
+ /// </summary>
+ /// <param name="pattern">Pattern to match tensor names</param>
+ /// <param name="bufferType">Buffer type to use for matching tensors</param>
+ public TensorBufferOverride(string pattern, string bufferType)
+ {
+ if (string.IsNullOrEmpty(pattern))
+ throw new ArgumentException("Pattern cannot be null or empty", nameof(pattern));
+ if (string.IsNullOrEmpty(bufferType))
+ throw new ArgumentException("Buffer type cannot be null or empty", nameof(bufferType));
+
+ Pattern = pattern;
+ BufferType = bufferType;
+ }
+ }
+}
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 7e4b1a967..23f5681be 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -21,6 +21,9 @@ public record ModelParams
/// <inheritdoc />
public GPUSplitMode? SplitMode { get; set; }
+ /// <inheritdoc />
+ public List<TensorBufferOverride> TensorBufferOverrides { get; set; } = new();
+
/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;
diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index 588564e33..ad3a3591c 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -45,6 +45,22 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
}
+ // Add tensor buffer overrides, if any
+ if (@params.TensorBufferOverrides.Count > 0)
+ {
+ var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper();
+ disposer.Add(bufferOverrideHelper);
+
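+ // Note: AddOverride silently skips an override whose buffer type name does not
+ // match any available backend buffer type (it returns false in that case)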
+ foreach (var tensorOverride in @params.TensorBufferOverrides)
+ {
+ bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType);
+ }
+
+ bufferOverrideHelper.ApplyToModelParams(ref result);
+ }
+
if (@params.MetadataOverrides.Count == 0)
{
unsafe
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index 562896d7b..acb024852 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -13,11 +13,11 @@ public unsafe struct LLamaModelParams
/// todo: add support for llama_model_params.devices
///
private IntPtr devices;
-
- // NULL-terminated list of buffer types to use for tensors that match a pattern
- // actual type: llama_model_tensor_buft_override*
- // todo: add support for tensor_buft_overrides
- private IntPtr tensor_buft_overrides;
+
+ /// <summary>
+ /// NULL-terminated list of buffer types to use for tensors that match a pattern
+ /// </summary>
+ public LLamaModelTensorBufferOverride* tensor_buft_overrides;
/// <summary>
/// // number of layers to store in VRAM
diff --git a/LLama/Native/LLamaModelTensorBufferOverride.cs b/LLama/Native/LLamaModelTensorBufferOverride.cs
new file mode 100644
index 000000000..c3a6eac6f
--- /dev/null
+++ b/LLama/Native/LLamaModelTensorBufferOverride.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native
+{
+ /// <summary>
+ /// Represents a mapping between a tensor name pattern and a backend buffer type.
+ /// Original type: llama_model_tensor_buft_override
+ /// </summary>
+ [StructLayout(LayoutKind.Sequential)]
+ public struct LLamaModelTensorBufferOverride
+ {
+ /// <summary>
+ /// Tensor name pattern to match
+ /// </summary>
+ public IntPtr Pattern;
+
+ /// <summary>
+ /// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type
+ /// </summary>
+ public IntPtr BufferType;
+ }
+}
diff --git a/LLama/Native/LLamaTensorBufferOverrideHelper.cs b/LLama/Native/LLamaTensorBufferOverrideHelper.cs
new file mode 100644
index 000000000..daab52b62
--- /dev/null
+++ b/LLama/Native/LLamaTensorBufferOverrideHelper.cs
@@ -0,0 +1,142 @@
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace LLama.Native
+{
+ /// <summary>
+ /// Helper for creating and managing tensor buffer overrides
+ /// </summary>
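+ /// <remarks>
+ /// Owns the unmanaged memory backing the native override array, so it must not be
+ /// disposed until the model has been loaded; IModelParamsExtensions arranges this by
+ /// adding the helper to the IDisposable returned from ToLlamaModelParams.
+ /// </remarks>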
+ internal class LLamaTensorBufferOverrideHelper : IDisposable
+ {
+ private readonly List<IntPtr> _allocatedMemory = new();
+ private readonly List<LLamaModelTensorBufferOverride> _overrides = new();
+ private IntPtr _overrideArray = IntPtr.Zero;
+ private readonly Dictionary<string, IntPtr> _bufferTypeCache = new();
+
+ /// <summary>
+ /// Get all available buffer types
+ /// </summary>
+ /// <returns>Dictionary mapping buffer type names to their handles</returns>
+ public Dictionary<string, IntPtr> GetAvailableBufferTypes()
+ {
+ var result = new Dictionary<string, IntPtr>();
+
+ nuint count = NativeApi.ggml_backend_dev_count();
+ for (nuint i = 0; i < count; i++)
+ {
+ IntPtr dev = NativeApi.ggml_backend_dev_get(i);
+ IntPtr buft = NativeApi.ggml_backend_dev_buffer_type(dev);
+
+ if (buft != IntPtr.Zero)
+ {
+ IntPtr namePtr = NativeApi.ggml_backend_buft_name(buft);
+ string name = Marshal.PtrToStringAnsi(namePtr) ?? string.Empty;
+
+ if (!string.IsNullOrEmpty(name))
+ {
+ result[name] = buft;
+ _bufferTypeCache[name] = buft;
+ }
+ }
+ }
+
+ return result;
+ }
+
+ /// <summary>
+ /// Add a tensor buffer override
+ /// </summary>
+ /// <param name="pattern">Tensor name pattern to match</param>
+ /// <param name="bufferTypeName">Name of the buffer type to use</param>
+ /// <returns>True if the override was added successfully</returns>
+ public bool AddOverride(string pattern, string bufferTypeName)
+ {
+ if (string.IsNullOrEmpty(pattern) || string.IsNullOrEmpty(bufferTypeName))
+ return false;
+
+ // Get all buffer types if cache is empty
+ if (_bufferTypeCache.Count == 0)
+ {
+ GetAvailableBufferTypes();
+ }
+
+ // Check if we have this buffer type
+ if (!_bufferTypeCache.TryGetValue(bufferTypeName, out IntPtr bufferType))
+ return false;
+
+ // Allocate memory for the pattern string and keep track of it
+ byte[] patternBytes = Encoding.UTF8.GetBytes(pattern + "\0");
+ IntPtr patternPtr = Marshal.AllocHGlobal(patternBytes.Length);
+ Marshal.Copy(patternBytes, 0, patternPtr, patternBytes.Length);
+ _allocatedMemory.Add(patternPtr);
+
+ // Create the override
+ var @override = new LLamaModelTensorBufferOverride
+ {
+ Pattern = patternPtr,
+ BufferType = bufferType
+ };
+
+ _overrides.Add(@override);
+ return true;
+ }
+
+ /// <summary>
+ /// Apply the overrides to model parameters
+ /// </summary>
+ /// <param name="modelParams">Model parameters to update</param>
+ public unsafe void ApplyToModelParams(ref LLamaModelParams modelParams)
+ {
+ if (_overrides.Count == 0)
+ {
+ modelParams.tensor_buft_overrides = null;
+ return;
+ }
+
+ // Free the previous array if it exists, and forget it so Dispose does not free it twice
+ if (_overrideArray != IntPtr.Zero)
+ {
+ Marshal.FreeHGlobal(_overrideArray);
+ _allocatedMemory.Remove(_overrideArray);
+ }
+
+ // Allocate memory for the array + null terminator
+ int size = Marshal.SizeOf<LLamaModelTensorBufferOverride>() * (_overrides.Count + 1);
+ _overrideArray = Marshal.AllocHGlobal(size);
+ _allocatedMemory.Add(_overrideArray);
+
+ // Copy overrides to array
+ for (int i = 0; i < _overrides.Count; i++)
+ {
+ IntPtr elemPtr = IntPtr.Add(_overrideArray, i * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
+ Marshal.StructureToPtr(_overrides[i], elemPtr, false);
+ }
+
+ // Add null terminator
+ IntPtr nullTermPtr = IntPtr.Add(_overrideArray, _overrides.Count * Marshal.SizeOf<LLamaModelTensorBufferOverride>());
+ Marshal.StructureToPtr(new LLamaModelTensorBufferOverride { Pattern = IntPtr.Zero, BufferType = IntPtr.Zero }, nullTermPtr, false);
+
+ // Update model params
+ modelParams.tensor_buft_overrides = (LLamaModelTensorBufferOverride*)_overrideArray.ToPointer();
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ foreach (IntPtr ptr in _allocatedMemory)
+ {
+ Marshal.FreeHGlobal(ptr);
+ }
+ _allocatedMemory.Clear();
+ _overrides.Clear();
+ _overrideArray = IntPtr.Zero;
+ }
+ }
+}
diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs
index 2d5be063f..4555ed0d2 100644
--- a/LLama/Native/NativeApi.Load.cs
+++ b/LLama/Native/NativeApi.Load.cs
@@ -107,6 +107,10 @@ private static void SetDllImportResolver()
internal const string libraryName = "llama";
internal const string llavaLibraryName = "llava_shared";
+ internal const string ggmlLibraryName = "ggml";
+ internal const string ggmlBaseLibraryName = "ggml-base";
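+ // Note: DllImport entries in NativeApi reference these names directly; the custom
+ // import resolver is assumed to resolve them alongside the main llama library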
private static INativeLibrary? _loadedLLamaLibrary = null;
private static INativeLibrary? _loadedLLavaLibrary = null;
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index d238753fe..87cf02c78 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -447,5 +447,46 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
// it would expose the raw pointer to the model, without properly wrapping it in a SafeLLamaModelHandle.
//[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
//public static void llama_model* llama_get_model(SafeLLamaContextHandle ctx);
+
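+ // The four ggml entry points below are used together to enumerate the available
+ // backend buffer types. A sketch of the intended call pattern (see
+ // LLamaTensorBufferOverrideHelper.GetAvailableBufferTypes for the real usage):
+ //
+ //   for (nuint i = 0; i < ggml_backend_dev_count(); i++)
+ //   {
+ //       var buft = ggml_backend_dev_buffer_type(ggml_backend_dev_get(i));
+ //       var name = Marshal.PtrToStringAnsi(ggml_backend_buft_name(buft));
+ //   }
+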
+ /// <summary>
+ /// Get the number of available backend devices
+ /// </summary>
+ /// <returns>Count of available backend devices</returns>
+ [DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern nuint ggml_backend_dev_count();
+
+ /// <summary>
+ /// Get a backend device by index
+ /// </summary>
+ /// <param name="i">Device index</param>
+ /// <returns>Pointer to the backend device</returns>
+ [DllImport(ggmlLibraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern IntPtr ggml_backend_dev_get(nuint i);
+
+ /// <summary>
+ /// Get the buffer type for a backend device
+ /// </summary>
+ /// <param name="dev">Backend device pointer</param>
+ /// <returns>Pointer to the buffer type</returns>
+ [DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern IntPtr ggml_backend_dev_buffer_type(IntPtr dev);
+
+ /// <summary>
+ /// Get the name of a buffer type
+ /// </summary>
+ /// <param name="buft">Buffer type pointer</param>
+ /// <returns>Name of the buffer type</returns>
+ [DllImport(ggmlBaseLibraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern IntPtr ggml_backend_buft_name(IntPtr buft);
}
}