Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 70 additions & 11 deletions LLama/Extensions/IModelParamsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Text;
using LLama.Abstractions;
using LLama.Native;
using System.Collections.Generic;

namespace LLama.Extensions;

Expand Down Expand Up @@ -45,20 +46,13 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam
result.tensor_split = (float*)disposer.Add(@params.TensorSplits.Pin()).Pointer;
}

// Add tensor buffer overrides, if any
if (@params.TensorBufferOverrides.Count > 0)
// Add tensor buffer overrides
unsafe
{
var bufferOverrideHelper = new LLamaTensorBufferOverrideHelper();
disposer.Add(bufferOverrideHelper);

foreach (var tensorOverride in @params.TensorBufferOverrides)
{
bufferOverrideHelper.AddOverride(tensorOverride.Pattern, tensorOverride.BufferType);
}

bufferOverrideHelper.ApplyToModelParams(ref result);
result.tensor_buft_overrides = ConvertOverrides(@params.TensorBufferOverrides, disposer);
}

// Add metadata overrides
if (@params.MetadataOverrides.Count == 0)
{
unsafe
Expand Down Expand Up @@ -106,4 +100,69 @@ public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLam

return disposer;
}

/// <summary>
/// Get a map from name of device (`ggml_backend_buft_name`) to the device type (`ggml_backend_dev_buffer_type`)
/// </summary>
/// <returns>Dictionary mapping buffer type names to their handles</returns>
private static IReadOnlyDictionary<string, IntPtr> GetAvailableBufferTypes()
{
var result = new Dictionary<string, IntPtr>();

var count = NativeApi.ggml_backend_dev_count();
for (nuint i = 0; i < count; i++)
{
var dev = NativeApi.ggml_backend_dev_get(i);
var buft = NativeApi.ggml_backend_dev_buffer_type(dev);

var name = Marshal.PtrToStringAnsi(NativeApi.ggml_backend_buft_name(buft));
if (string.IsNullOrEmpty(name))
continue;

result[name] = buft;
}

return result;
}

private static unsafe LLamaModelTensorBufferOverride* ConvertOverrides(List<TensorBufferOverride> overrides, GroupDisposable disposer)
{
// Early out if there are no overrides
if (overrides.Count == 0)
return null;

var bufferTypes = GetAvailableBufferTypes();

var overridesCount = 0;
var overridesArray = new LLamaModelTensorBufferOverride[overrides.Count + 1];

foreach (var @override in overrides)
{
// Check if we have this buffer type
if (!bufferTypes.TryGetValue(@override.BufferType, out var bufferType))
continue;

// Create null terminated string and pin this memory so it can be passed to native code
var patternBytes = Encoding.UTF8.GetBytes(@override.Pattern + "\0");
var patternPin = patternBytes.AsMemory().Pin();
disposer.Add(patternPin);

// Add the item to the overridesArray
overridesArray[overridesCount++] = new()
{
Pattern = (byte*)patternPin.Pointer,
BufferType = bufferType
};
}

// Early out if there were no valid overrides
if (overridesCount == 0)
return null;

// Pin it so it can be safely passed across to native code
var overrideArrayPin = overridesArray.AsMemory().Pin();
disposer.Add(overrideArrayPin);

return (LLamaModelTensorBufferOverride*)overrideArrayPin.Pointer;
}
}
4 changes: 2 additions & 2 deletions LLama/Native/LLamaModelTensorBufferOverride.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ namespace LLama.Native
/// Original type: llama_model_tensor_buft_override
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct LLamaModelTensorBufferOverride
public unsafe struct LLamaModelTensorBufferOverride
{
/// <summary>
/// Tensor name pattern to match
/// </summary>
public IntPtr Pattern;
public byte* Pattern;

/// <summary>
/// Backend buffer type to use for matching tensors, as obtained via ggml_backend_dev_buffer_type
Expand Down
135 changes: 0 additions & 135 deletions LLama/Native/LLamaTensorBufferOverrideHelper.cs

This file was deleted.

Loading