From a0e68a4ee0abafd5d9b8d1ef482d2e4dbd7e8ab6 Mon Sep 17 00:00:00 2001
From: sa_ddam213
Date: Tue, 22 Aug 2023 21:27:12 +1200
Subject: [PATCH 1/4] Update Web example to use multi-context

---
 LLama.Web/Common/InferenceOptions.cs          |  99 +++++++++
 LLama.Web/Common/LLamaOptions.cs              |   9 -
 LLama.Web/Common/ModelOptions.cs              | 189 +++++++---------
 LLama.Web/Common/ParameterOptions.cs          |  99 ---------
 .../{PromptOptions.cs => PromptConfig.cs}     |   4 +-
 LLama.Web/Common/ServiceResult.cs             |  41 ----
 LLama.Web/Common/SessionConfig.cs             |  13 ++
 LLama.Web/Hubs/ISessionClient.cs              |   2 -
 LLama.Web/Hubs/SessionConnectionHub.cs        |  56 ++----
 LLama.Web/Models/CancelModel.cs               |   7 -
 LLama.Web/Models/InferTokenModel.cs           |  12 ++
 LLama.Web/Models/ModelSession.cs              | 154 ++++++++----
 LLama.Web/Models/ResponseFragment.cs          |  18 --
 LLama.Web/Pages/Executor/Instruct.cshtml      |  96 ---------
 LLama.Web/Pages/Executor/Instruct.cshtml.cs   |  34 ----
 LLama.Web/Pages/Executor/Instruct.cshtml.css  |   4 -
 LLama.Web/Pages/Executor/Interactive.cshtml   |  96 ---------
 .../Pages/Executor/Interactive.cshtml.cs      |  34 ----
 .../Pages/Executor/Interactive.cshtml.css     |   4 -
 LLama.Web/Pages/Executor/Stateless.cshtml     |  97 ---------
 LLama.Web/Pages/Executor/Stateless.cshtml.cs  |  34 ----
 LLama.Web/Pages/Executor/Stateless.cshtml.css |   4 -
 LLama.Web/Pages/Index.cshtml                  | 117 ++++++++++-
 LLama.Web/Pages/Index.cshtml.cs               |  25 ++-
 LLama.Web/Pages/Shared/_ChatTemplates.cshtml  |  24 +--
 LLama.Web/Pages/Shared/_Layout.cshtml         |  32 +--
 LLama.Web/Pages/Shared/_Parameters.cshtml     | 137 +++++++++++++
 LLama.Web/Pages/Shared/_Parameters.cshtml.cs  |  12 ++
 LLama.Web/Program.cs                          |   3 +-
 LLama.Web/README.md                           |  62 ++++--
 .../Services/ConnectionSessionService.cs      |  94 ---------
 LLama.Web/Services/IModelService.cs           |  14 ++
 LLama.Web/Services/IModelSessionService.cs    |   8 +-
 LLama.Web/Services/ModelService.cs            | 170 ++++++++++++++++
 LLama.Web/Services/ModelSessionService.cs     | 138 +++++++++++++
 LLama.Web/appsettings.json                    |  59 +++---
 LLama.Web/wwwroot/css/site.css                |  17 +-
 LLama.Web/wwwroot/js/sessionConnectionChat.js | 139 ++++++++-----
 LLama.Web/wwwroot/js/site.js                  |   8 +-
 LLamaSharp.sln                                |   2 +-
 40 files changed, 1163 insertions(+), 1004 deletions(-)
 create mode 100644 LLama.Web/Common/InferenceOptions.cs
 delete mode 100644 LLama.Web/Common/ParameterOptions.cs
 rename LLama.Web/Common/{PromptOptions.cs => PromptConfig.cs} (63%)
 delete mode 100644 LLama.Web/Common/ServiceResult.cs
 create mode 100644 LLama.Web/Common/SessionConfig.cs
 delete mode 100644 LLama.Web/Models/CancelModel.cs
 create mode 100644 LLama.Web/Models/InferTokenModel.cs
 delete mode 100644 LLama.Web/Models/ResponseFragment.cs
 delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml
 delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml.cs
 delete mode 100644 LLama.Web/Pages/Executor/Instruct.cshtml.css
 delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml
 delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml.cs
 delete mode 100644 LLama.Web/Pages/Executor/Interactive.cshtml.css
 delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml
 delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml.cs
 delete mode 100644 LLama.Web/Pages/Executor/Stateless.cshtml.css
 create mode 100644 LLama.Web/Pages/Shared/_Parameters.cshtml
 create mode 100644 LLama.Web/Pages/Shared/_Parameters.cshtml.cs
 delete mode 100644 LLama.Web/Services/ConnectionSessionService.cs
 create mode 100644 LLama.Web/Services/IModelService.cs
 create mode 100644 LLama.Web/Services/ModelService.cs
 create mode 100644 LLama.Web/Services/ModelSessionService.cs
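Before the file-by-file changes, a sketch of the shape this patch introduces may help. The block below is illustrative, not part of the diff: the two singleton registrations mirror the Program.cs change further down, while the SignalR/Razor lines stand in for the app's existing setup.

```csharp
// Illustrative overview of the wiring this patch introduces (not part of the
// diff itself). One IModelService holds the loaded LLamaWeights plus the
// per-connection LLamaContexts; one IModelSessionService holds a ModelSession
// per SignalR connection id.
using LLama.Web.Hubs;
using LLama.Web.Services;

var builder = WebApplication.CreateBuilder(args);
builder.Services.AddRazorPages(); // existing app setup, abbreviated
builder.Services.AddSignalR();

// Registrations added by this patch (see Program.cs below)
builder.Services.AddSingleton<IModelService, ModelService>();
builder.Services.AddSingleton<IModelSessionService, ModelSessionService>();

var app = builder.Build();
app.MapRazorPages();
app.MapHub<SessionConnectionHub>("/SessionConnectionHub"); // endpoint the JS client connects to
app.Run();
```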
diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
new file mode 100644
index 000000000..c091243b8
--- /dev/null
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -0,0 +1,99 @@
+using LLama.Common;
+using LLama.Abstractions;
+
+namespace LLama.Web.Common
+{
+    public class InferenceOptions : IInferenceParams
+    {
+        public string Name { get; set; }
+
+
+        ///
+        /// number of tokens to keep from initial prompt
+        ///
+        public int TokensKeep { get; set; } = 0;
+        ///
+        /// how many new tokens to predict (n_predict), set to -1 to infinitely generate a response
+        /// until it completes.
+        ///
+        public int MaxTokens { get; set; } = -1;
+        ///
+        /// logit bias for specific tokens
+        ///
+        public Dictionary? LogitBias { get; set; } = null;
+
+        ///
+        /// Sequences where the model will stop generating further tokens.
+        ///
+        public IEnumerable AntiPrompts { get; set; } = Array.Empty();
+        ///
+        /// path to file for saving/loading model eval state
+        ///
+        public string PathSession { get; set; } = string.Empty;
+        ///
+        /// string to suffix user inputs with
+        ///
+        public string InputSuffix { get; set; } = string.Empty;
+        ///
+        /// string to prefix user inputs with
+        ///
+        public string InputPrefix { get; set; } = string.Empty;
+        ///
+        /// 0 or lower to use vocab size
+        ///
+        public int TopK { get; set; } = 40;
+        ///
+        /// 1.0 = disabled
+        ///
+        public float TopP { get; set; } = 0.95f;
+        ///
+        /// 1.0 = disabled
+        ///
+        public float TfsZ { get; set; } = 1.0f;
+        ///
+        /// 1.0 = disabled
+        ///
+        public float TypicalP { get; set; } = 1.0f;
+        ///
+        /// 1.0 = disabled
+        ///
+        public float Temperature { get; set; } = 0.8f;
+        ///
+        /// 1.0 = disabled
+        ///
+        public float RepeatPenalty { get; set; } = 1.1f;
+        ///
+        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
+        ///
+        public int RepeatLastTokensCount { get; set; } = 64;
+        ///
+        /// frequency penalty coefficient
+        /// 0.0 = disabled
+        ///
+        public float FrequencyPenalty { get; set; } = .0f;
+        ///
+        /// presence penalty coefficient
+        /// 0.0 = disabled
+        ///
+        public float PresencePenalty { get; set; } = .0f;
+        ///
+        /// Mirostat uses tokens instead of words.
+        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
+ /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + /// + public MirostatType Mirostat { get; set; } = MirostatType.Disable; + /// + /// target entropy + /// + public float MirostatTau { get; set; } = 5.0f; + /// + /// learning rate + /// + public float MirostatEta { get; set; } = 0.1f; + /// + /// consider newlines as a repeatable token (penalize_nl) + /// + public bool PenalizeNL { get; set; } = true; + } +} diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs index 1ac0d829f..2348dd133 100644 --- a/LLama.Web/Common/LLamaOptions.cs +++ b/LLama.Web/Common/LLamaOptions.cs @@ -3,18 +3,9 @@ public class LLamaOptions { public List Models { get; set; } - public List Prompts { get; set; } = new List(); - public List Parameters { get; set; } = new List(); public void Initialize() { - foreach (var prompt in Prompts) - { - if (File.Exists(prompt.Path)) - { - prompt.Prompt = File.ReadAllText(prompt.Path).Trim(); - } - } } } } diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index e8b89dee8..32431d008 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -4,112 +4,111 @@ namespace LLama.Web.Common { public class ModelOptions : IModelParams { - public string Name { get; set; } public int MaxInstances { get; set; } - /// - /// Model context size (n_ctx) - /// - public int ContextSize { get; set; } = 512; - /// - /// the GPU that is used for scratch and small tensors - /// - public int MainGpu { get; set; } = 0; - /// - /// if true, reduce VRAM usage at the cost of performance - /// - public bool LowVram { get; set; } = false; - /// - /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) - /// - public int GpuLayerCount { get; set; } = 20; - /// - /// Seed for the random number generator (seed) - /// - public int Seed { get; set; } = 1686349486; - /// - /// Use f16 instead of f32 for memory kv (memory_f16) - /// - public bool UseFp16Memory { get; set; } = true; - /// - /// Use mmap for faster loads (use_mmap) - /// - public bool UseMemorymap { get; set; } = true; - /// - /// Use mlock to keep model in memory (use_mlock) - /// - public bool UseMemoryLock { get; set; } = false; - /// - /// Compute perplexity over the prompt (perplexity) - /// - public bool Perplexity { get; set; } = false; - /// - /// Model path (model) - /// - public string ModelPath { get; set; } - /// - /// model alias - /// - public string ModelAlias { get; set; } = "unknown"; - /// - /// lora adapter path (lora_adapter) - /// - public string LoraAdapter { get; set; } = string.Empty; - /// - /// base model path for the lora adapter (lora_base) - /// - public string LoraBase { get; set; } = string.Empty; - /// - /// Number of threads (-1 = autodetect) (n_threads) - /// - public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1); - /// - /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) - /// - public int BatchSize { get; set; } = 512; + /// + /// Model context size (n_ctx) + /// + public int ContextSize { get; set; } = 512; + /// + /// the GPU that is used for scratch and small tensors + /// + public int MainGpu { get; set; } = 0; + /// + /// if true, reduce VRAM usage at the cost of performance + /// + public bool LowVram { get; set; } = false; + /// + /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) + /// + public int GpuLayerCount { get; set; } = 20; + /// + /// Seed for the random number generator (seed) + /// + public int Seed { get; set; } = 1686349486; + /// + /// 
Use f16 instead of f32 for memory kv (memory_f16) + /// + public bool UseFp16Memory { get; set; } = true; + /// + /// Use mmap for faster loads (use_mmap) + /// + public bool UseMemorymap { get; set; } = true; + /// + /// Use mlock to keep model in memory (use_mlock) + /// + public bool UseMemoryLock { get; set; } = false; + /// + /// Compute perplexity over the prompt (perplexity) + /// + public bool Perplexity { get; set; } = false; + /// + /// Model path (model) + /// + public string ModelPath { get; set; } + /// + /// model alias + /// + public string ModelAlias { get; set; } = "unknown"; + /// + /// lora adapter path (lora_adapter) + /// + public string LoraAdapter { get; set; } = string.Empty; + /// + /// base model path for the lora adapter (lora_base) + /// + public string LoraBase { get; set; } = string.Empty; + /// + /// Number of threads (-1 = autodetect) (n_threads) + /// + public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1); + /// + /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// + public int BatchSize { get; set; } = 512; - /// - /// Whether to convert eos to newline during the inference. - /// - public bool ConvertEosToNewLine { get; set; } = false; + /// + /// Whether to convert eos to newline during the inference. + /// + public bool ConvertEosToNewLine { get; set; } = false; - /// - /// Whether to use embedding mode. (embedding) Note that if this is set to true, - /// The LLamaModel won't produce text response anymore. - /// - public bool EmbeddingMode { get; set; } = false; + /// + /// Whether to use embedding mode. (embedding) Note that if this is set to true, + /// The LLamaModel won't produce text response anymore. + /// + public bool EmbeddingMode { get; set; } = false; - /// - /// how split tensors should be distributed across GPUs - /// - public float[] TensorSplits { get; set; } + /// + /// how split tensors should be distributed across GPUs + /// + public float[] TensorSplits { get; set; } - /// - /// Grouped-Query Attention - /// - public int GroupedQueryAttention { get; set; } = 1; + /// + /// Grouped-Query Attention + /// + public int GroupedQueryAttention { get; set; } = 1; - /// - /// RMS Norm Epsilon - /// - public float RmsNormEpsilon { get; set; } = 5e-6f; + /// + /// RMS Norm Epsilon + /// + public float RmsNormEpsilon { get; set; } = 5e-6f; - /// - /// RoPE base frequency - /// - public float RopeFrequencyBase { get; set; } = 10000.0f; + /// + /// RoPE base frequency + /// + public float RopeFrequencyBase { get; set; } = 10000.0f; - /// - /// RoPE frequency scaling factor - /// - public float RopeFrequencyScale { get; set; } = 1.0f; + /// + /// RoPE frequency scaling factor + /// + public float RopeFrequencyScale { get; set; } = 1.0f; - /// - /// Use experimental mul_mat_q kernels - /// - public bool MulMatQ { get; set; } + /// + /// Use experimental mul_mat_q kernels + /// + public bool MulMatQ { get; set; } - } + } } diff --git a/LLama.Web/Common/ParameterOptions.cs b/LLama.Web/Common/ParameterOptions.cs deleted file mode 100644 index 7677f04ae..000000000 --- a/LLama.Web/Common/ParameterOptions.cs +++ /dev/null @@ -1,99 +0,0 @@ -using LLama.Common; -using LLama.Abstractions; - -namespace LLama.Web.Common -{ - public class ParameterOptions : IInferenceParams - { - public string Name { get; set; } - - - - /// - /// number of tokens to keep from initial prompt - /// - public int TokensKeep { get; set; } = 0; - /// - /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate 
response - /// until it complete. - /// - public int MaxTokens { get; set; } = -1; - /// - /// logit bias for specific tokens - /// - public Dictionary? LogitBias { get; set; } = null; - - /// - /// Sequences where the model will stop generating further tokens. - /// - public IEnumerable AntiPrompts { get; set; } = Array.Empty(); - /// - /// path to file for saving/loading model eval state - /// - public string PathSession { get; set; } = string.Empty; - /// - /// string to suffix user inputs with - /// - public string InputSuffix { get; set; } = string.Empty; - /// - /// string to prefix user inputs with - /// - public string InputPrefix { get; set; } = string.Empty; - /// - /// 0 or lower to use vocab size - /// - public int TopK { get; set; } = 40; - /// - /// 1.0 = disabled - /// - public float TopP { get; set; } = 0.95f; - /// - /// 1.0 = disabled - /// - public float TfsZ { get; set; } = 1.0f; - /// - /// 1.0 = disabled - /// - public float TypicalP { get; set; } = 1.0f; - /// - /// 1.0 = disabled - /// - public float Temperature { get; set; } = 0.8f; - /// - /// 1.0 = disabled - /// - public float RepeatPenalty { get; set; } = 1.1f; - /// - /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n) - /// - public int RepeatLastTokensCount { get; set; } = 64; - /// - /// frequency penalty coefficient - /// 0.0 = disabled - /// - public float FrequencyPenalty { get; set; } = .0f; - /// - /// presence penalty coefficient - /// 0.0 = disabled - /// - public float PresencePenalty { get; set; } = .0f; - /// - /// Mirostat uses tokens instead of words. - /// algorithm described in the paper https://arxiv.org/abs/2007.14966. - /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - /// - public MirostatType Mirostat { get; set; } = MirostatType.Disable; - /// - /// target entropy - /// - public float MirostatTau { get; set; } = 5.0f; - /// - /// learning rate - /// - public float MirostatEta { get; set; } = 0.1f; - /// - /// consider newlines as a repeatable token (penalize_nl) - /// - public bool PenalizeNL { get; set; } = true; - } -} diff --git a/LLama.Web/Common/PromptOptions.cs b/LLama.Web/Common/PromptConfig.cs similarity index 63% rename from LLama.Web/Common/PromptOptions.cs rename to LLama.Web/Common/PromptConfig.cs index 4e44a5d12..38cef2908 100644 --- a/LLama.Web/Common/PromptOptions.cs +++ b/LLama.Web/Common/PromptConfig.cs @@ -1,9 +1,7 @@ namespace LLama.Web.Common { - public class PromptOptions + public class PromptConfig { - public string Name { get; set; } - public string Path { get; set; } public string Prompt { get; set; } public List AntiPrompt { get; set; } public List OutputFilter { get; set; } diff --git a/LLama.Web/Common/ServiceResult.cs b/LLama.Web/Common/ServiceResult.cs deleted file mode 100644 index 709a6d3aa..000000000 --- a/LLama.Web/Common/ServiceResult.cs +++ /dev/null @@ -1,41 +0,0 @@ -namespace LLama.Web.Common -{ - public class ServiceResult : ServiceResult, IServiceResult - { - public T Value { get; set; } - } - - - public class ServiceResult - { - public string Error { get; set; } - - public bool HasError - { - get { return !string.IsNullOrEmpty(Error); } - } - - public static IServiceResult FromValue(T value) - { - return new ServiceResult - { - Value = value, - }; - } - - public static IServiceResult FromError(string error) - { - return new ServiceResult - { - Error = error, - }; - } - } - - public interface IServiceResult - { - T Value { get; set; } - string Error { get; set; } - bool HasError { get; } - } -} diff --git 
a/LLama.Web/Common/SessionConfig.cs b/LLama.Web/Common/SessionConfig.cs new file mode 100644 index 000000000..3781c7d11 --- /dev/null +++ b/LLama.Web/Common/SessionConfig.cs @@ -0,0 +1,13 @@ +namespace LLama.Web.Common +{ + public class SessionConfig + { + public string Model { get; set; } + public string Prompt { get; set; } + public LLamaExecutorType ExecutorType { get; set; } = LLamaExecutorType.Instruct; + public string AntiPrompt { get; set; } = string.Empty; + public string OutputFilter { get; set; } = string.Empty; + public string InputSuffix { get; set; } = string.Empty; + public string InputPrefix { get; set; } = string.Empty; + } +} diff --git a/LLama.Web/Hubs/ISessionClient.cs b/LLama.Web/Hubs/ISessionClient.cs index 9e9dc0f19..051fa7a9d 100644 --- a/LLama.Web/Hubs/ISessionClient.cs +++ b/LLama.Web/Hubs/ISessionClient.cs @@ -1,12 +1,10 @@ using LLama.Web.Common; -using LLama.Web.Models; namespace LLama.Web.Hubs { public interface ISessionClient { Task OnStatus(string connectionId, SessionConnectionStatus status); - Task OnResponse(ResponseFragment fragment); Task OnError(string error); } } diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs index 080866c6b..a6ee6df66 100644 --- a/LLama.Web/Hubs/SessionConnectionHub.cs +++ b/LLama.Web/Hubs/SessionConnectionHub.cs @@ -2,16 +2,15 @@ using LLama.Web.Models; using LLama.Web.Services; using Microsoft.AspNetCore.SignalR; -using System.Diagnostics; namespace LLama.Web.Hubs { public class SessionConnectionHub : Hub { private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; + private readonly IModelSessionService _modelSessionService; - public SessionConnectionHub(ILogger logger, ConnectionSessionService modelSessionService) + public SessionConnectionHub(ILogger logger, IModelSessionService modelSessionService) { _logger = logger; _modelSessionService = modelSessionService; @@ -27,29 +26,26 @@ public override async Task OnConnectedAsync() } - public override async Task OnDisconnectedAsync(Exception? 
exception)
+        public override async Task OnDisconnectedAsync(Exception exception)
         {
             _logger.Log(LogLevel.Information, "[OnDisconnectedAsync], Id: {0}", Context.ConnectionId);
 
             // Remove the connection's session on disconnect
-            await _modelSessionService.RemoveAsync(Context.ConnectionId);
+            await _modelSessionService.CloseAsync(Context.ConnectionId);
             await base.OnDisconnectedAsync(exception);
         }
 
 
         [HubMethodName("LoadModel")]
-        public async Task OnLoadModel(LLamaExecutorType executorType, string modelName, string promptName, string parameterName)
+        public async Task OnLoadModel(SessionConfig sessionConfig, InferenceOptions inferenceConfig)
         {
-            _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}, Model: {1}, Prompt: {2}, Parameter: {3}", Context.ConnectionId, modelName, promptName, parameterName);
-
-            // Remove existing connections session
-            await _modelSessionService.RemoveAsync(Context.ConnectionId);
+            _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
 
             // Create model session
-            var modelSessionResult = await _modelSessionService.CreateAsync(executorType, Context.ConnectionId, modelName, promptName, parameterName);
-            if (modelSessionResult.HasError)
+            var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig);
+            if (modelSession is null)
             {
-                await Clients.Caller.OnError(modelSessionResult.Error);
+                await Clients.Caller.OnError("Failed to create model session");
                 return;
             }
@@ -59,40 +55,12 @@ public async Task OnLoadModel(LLamaExecutorType executorType, string modelName,
 
         [HubMethodName("SendPrompt")]
-        public async Task OnSendPrompt(string prompt)
+        public IAsyncEnumerable OnSendPrompt(string prompt, InferenceOptions inferConfig, CancellationToken cancellationToken)
         {
             _logger.Log(LogLevel.Information, "[OnSendPrompt] - New prompt received, Connection: {0}", Context.ConnectionId);
 
-            // Get connections session
-            var modelSession = await _modelSessionService.GetAsync(Context.ConnectionId);
-            if (modelSession is null)
-            {
-                await Clients.Caller.OnError("No model has been loaded");
-                return;
-            }
-
-
-            // Create unique response id
-            var responseId = Guid.NewGuid().ToString();
-
-            // Send begin of response
-            await Clients.Caller.OnResponse(new ResponseFragment(responseId, isFirst: true));
-
-            // Send content of response
-            var stopwatch = Stopwatch.GetTimestamp();
-            await foreach (var fragment in modelSession.InferAsync(prompt, CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted)))
-            {
-                await Clients.Caller.OnResponse(new ResponseFragment(responseId, fragment));
-            }
-
-            // Send end of response
-            var elapsedTime = Stopwatch.GetElapsedTime(stopwatch);
-            var signature = modelSession.IsInferCanceled()
$"Inference cancelled after {elapsedTime.TotalSeconds:F0} seconds" - : $"Inference completed in {elapsedTime.TotalSeconds:F0} seconds"; - await Clients.Caller.OnResponse(new ResponseFragment(responseId, signature, isLast: true)); - _logger.Log(LogLevel.Information, "[OnSendPrompt] - Inference complete, Connection: {0}, Elapsed: {1}, Canceled: {2}", Context.ConnectionId, elapsedTime, modelSession.IsInferCanceled()); + var linkedCancelationToken = CancellationTokenSource.CreateLinkedTokenSource(Context.ConnectionAborted, cancellationToken); + return _modelSessionService.InferAsync(Context.ConnectionId, prompt, inferConfig, linkedCancelationToken.Token); } - } } diff --git a/LLama.Web/Models/CancelModel.cs b/LLama.Web/Models/CancelModel.cs deleted file mode 100644 index de7e3f2f7..000000000 --- a/LLama.Web/Models/CancelModel.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace LLama.Web.Models -{ - public class CancelModel - { - public string ConnectionId { get; set; } - } -} diff --git a/LLama.Web/Models/InferTokenModel.cs b/LLama.Web/Models/InferTokenModel.cs new file mode 100644 index 000000000..f25c12f6a --- /dev/null +++ b/LLama.Web/Models/InferTokenModel.cs @@ -0,0 +1,12 @@ +namespace LLama.Web.Models +{ + public record InferTokenModel(int TokenId, float Probability, string Content, InferTokenType Type, int Elapsed); + + public enum InferTokenType + { + Begin = 0, + Content = 2, + End = 4, + Cancel = 10 + } +} diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs index c53676f24..8b7b77b0f 100644 --- a/LLama.Web/Models/ModelSession.cs +++ b/LLama.Web/Models/ModelSession.cs @@ -1,68 +1,154 @@ using LLama.Abstractions; +using LLama.Common; using LLama.Web.Common; namespace LLama.Web.Models { - public class ModelSession : IDisposable + public class ModelSession { - private bool _isFirstInteraction = true; - private ModelOptions _modelOptions; - private PromptOptions _promptOptions; - private ParameterOptions _inferenceOptions; - private ITextStreamTransform _outputTransform; - private ILLamaExecutor _executor; + private readonly LLamaContext _context; + private readonly ILLamaExecutor _executor; + private readonly SessionConfig _sessionParams; + private readonly PromptConfig _promptParams; + private readonly ITextStreamTransform _outputTransform; + + private IInferenceParams _inferenceParams; private CancellationTokenSource _cancellationTokenSource; - public ModelSession(ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions) - { - _executor = executor; - _modelOptions = modelOptions; - _promptOptions = promptOptions; - _inferenceOptions = parameterOptions; - - _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty()).Distinct() ?? _inferenceOptions.AntiPrompts; - if (_promptOptions.OutputFilter?.Count > 0) - _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5); - } - public string ModelName + /// + /// Initializes a new instance of the class. + /// + /// The context. + /// The session identifier. + /// The session configuration. + /// The inference parameters. 
diff --git a/LLama.Web/Models/ModelSession.cs b/LLama.Web/Models/ModelSession.cs
index c53676f24..8b7b77b0f 100644
--- a/LLama.Web/Models/ModelSession.cs
+++ b/LLama.Web/Models/ModelSession.cs
@@ -1,68 +1,154 @@
 using LLama.Abstractions;
+using LLama.Common;
 using LLama.Web.Common;
 
 namespace LLama.Web.Models
 {
-    public class ModelSession : IDisposable
+    public class ModelSession
     {
-        private bool _isFirstInteraction = true;
-        private ModelOptions _modelOptions;
-        private PromptOptions _promptOptions;
-        private ParameterOptions _inferenceOptions;
-        private ITextStreamTransform _outputTransform;
-        private ILLamaExecutor _executor;
+        private readonly LLamaContext _context;
+        private readonly ILLamaExecutor _executor;
+        private readonly SessionConfig _sessionParams;
+        private readonly PromptConfig _promptParams;
+        private readonly ITextStreamTransform _outputTransform;
+
+        private IInferenceParams _inferenceParams;
         private CancellationTokenSource _cancellationTokenSource;
 
-        public ModelSession(ILLamaExecutor executor, ModelOptions modelOptions, PromptOptions promptOptions, ParameterOptions parameterOptions)
-        {
-            _executor = executor;
-            _modelOptions = modelOptions;
-            _promptOptions = promptOptions;
-            _inferenceOptions = parameterOptions;
-
-            _inferenceOptions.AntiPrompts = _promptOptions.AntiPrompt?.Concat(_inferenceOptions.AntiPrompts ?? Enumerable.Empty()).Distinct() ?? _inferenceOptions.AntiPrompts;
-            if (_promptOptions.OutputFilter?.Count > 0)
-                _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptOptions.OutputFilter, redundancyLength: 5);
-        }
 
-        public string ModelName
+        ///
+        /// Initializes a new instance of the class.
+        ///
+        /// The context.
+        /// The session configuration.
+        /// The inference parameters.
+        public ModelSession(LLamaContext context, SessionConfig sessionConfig, IInferenceParams inferenceParams = null)
         {
-            get { return _modelOptions.Name; }
+            _context = context;
+            _sessionParams = sessionConfig;
+            _inferenceParams = inferenceParams;
+
+            // Executor
+            _executor = sessionConfig.ExecutorType switch
+            {
+                LLamaExecutorType.Interactive => new InteractiveExecutor(_context),
+                LLamaExecutorType.Instruct => new InstructExecutor(_context),
+                LLamaExecutorType.Stateless => new StatelessExecutor(_context),
+                _ => default
+            };
+
+            // Initial Prompt
+            _promptParams = new PromptConfig
+            {
+                Prompt = _sessionParams.Prompt,
+                AntiPrompt = CommaSeperatedToList(_sessionParams.AntiPrompt),
+                OutputFilter = CommaSeperatedToList(_sessionParams.OutputFilter),
+            };
+
+            // Output Filter
+            if (_promptParams.OutputFilter?.Count > 0)
+                _outputTransform = new LLamaTransforms.KeywordTextOutputStreamTransform(_promptParams.OutputFilter, redundancyLength: 8);
         }
 
-        public IAsyncEnumerable InferAsync(string message, CancellationTokenSource cancellationTokenSource)
+
+        ///
+        /// Gets the name of the model.
+        ///
+        public string ModelName => _sessionParams.Model;
+
+
+        ///
+        /// Initializes the prompt.
+        ///
+        /// The inference parameters.
+        /// The cancellation token.
+        public async Task InitializePrompt(IInferenceParams inferenceParams = null, CancellationToken cancellationToken = default)
         {
-            _cancellationTokenSource = cancellationTokenSource;
-            if (_isFirstInteraction)
+            ConfigureInferenceParams(inferenceParams);
+
+            if (_executor is StatelessExecutor)
+                return;
+
+            // Run Initial prompt
+            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
+            await foreach (var _ in _executor.InferAsync(_sessionParams.Prompt, inferenceParams, _cancellationTokenSource.Token))
             {
-                _isFirstInteraction = false;
-                message = _promptOptions.Prompt + message;
-            }
+                // We don't really need the response of the initial prompt, so exit on the first token
+                break;
+            }
+        }
+
+
+        ///
+        /// Runs inference on the model context
+        ///
+        /// The message.
+        /// The inference parameters.
+        /// The cancellation token.
+        ///
+        public IAsyncEnumerable InferAsync(string message, IInferenceParams inferenceParams = null, CancellationToken cancellationToken = default)
+        {
+            ConfigureInferenceParams(inferenceParams);
 
+            _cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
             if (_outputTransform is not null)
-                return _outputTransform.TransformAsync(_executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token));
+                return _outputTransform.TransformAsync(_executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token));
 
-            return _executor.InferAsync(message, _inferenceOptions, _cancellationTokenSource.Token);
+            return _executor.InferAsync(message, inferenceParams, _cancellationTokenSource.Token);
         }
 
+        ///
+        /// Cancels the current inference.
+        ///
         public void CancelInfer()
         {
             _cancellationTokenSource?.Cancel();
         }
 
+
+        ///
+        /// Determines whether inference is canceled.
+        ///
+        ///
+        /// true if inference is canceled; otherwise, false.
+        ///
         public bool IsInferCanceled()
         {
-            return _cancellationTokenSource.IsCancellationRequested;
+            return _cancellationTokenSource?.IsCancellationRequested ?? false;
         }
 
-        public void Dispose()
+
+        ///
+        /// Configures the inference parameters.
+        ///
+        /// The inference parameters.
+        private void ConfigureInferenceParams(IInferenceParams inferenceParams)
+        {
+            // If not null set as default
+            if (inferenceParams is not null)
+                _inferenceParams = inferenceParams;
+
+            // If null set to new
+            if (_inferenceParams is null)
+                _inferenceParams = new InferenceParams();
+
+            // Merge Antiprompts
+            var antiPrompts = new List();
+            antiPrompts.AddRange(_promptParams.AntiPrompt ?? Enumerable.Empty());
+            antiPrompts.AddRange(_inferenceParams.AntiPrompts ?? Enumerable.Empty());
+            _inferenceParams.AntiPrompts = antiPrompts.Distinct();
+        }
+
+
+        private static List CommaSeperatedToList(string value)
         {
-            _inferenceOptions = null;
-            _outputTransform = null;
+            if (string.IsNullOrEmpty(value))
+                return null;
 
-            _executor?.Context.Dispose();
-            _executor = null;
+            return value.Split(",", StringSplitOptions.RemoveEmptyEntries)
+                .Select(x => x.Trim())
+                .ToList();
         }
     }
 }
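To see the session lifecycle above in one place: construct the session, feed the initial prompt once, then stream completions. A hedged sketch — not part of the patch — assuming a LLamaContext already exists for the connection and using placeholder values:

```csharp
// Illustrative use of ModelSession (not part of the patch). Assumes a
// LLamaContext created for this connection, e.g. via IModelService below.
using LLama;
using LLama.Web.Common;
using LLama.Web.Models;

async Task RunSessionAsync(LLamaContext context, CancellationToken ct)
{
    var config = new SessionConfig
    {
        Model = "WizardLM-7B",                 // placeholder model name
        ExecutorType = LLamaExecutorType.Instruct,
        Prompt = "Below is an instruction that describes a task.",
        AntiPrompt = "User:",                  // comma-separated; split by CommaSeperatedToList
    };

    var session = new ModelSession(context, config, new InferenceOptions { Temperature = 0.8f });

    // Feed the initial prompt once; this is a no-op for StatelessExecutor
    await session.InitializePrompt(cancellationToken: ct);

    // Stream a reply; the session's anti-prompts are merged into the
    // inference parameters by ConfigureInferenceParams
    await foreach (var token in session.InferAsync("User: Hello", cancellationToken: ct))
        Console.Write(token);
}
```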
diff --git a/LLama.Web/Models/ResponseFragment.cs b/LLama.Web/Models/ResponseFragment.cs
deleted file mode 100644
index 02f27f13e..000000000
--- a/LLama.Web/Models/ResponseFragment.cs
+++ /dev/null
@@ -1,18 +0,0 @@
-namespace LLama.Web.Models
-{
-    public class ResponseFragment
-    {
-        public ResponseFragment(string id, string content = null, bool isFirst = false, bool isLast = false)
-        {
-            Id = id;
-            IsLast = isLast;
-            IsFirst = isFirst;
-            Content = content;
-        }
-
-        public string Id { get; set; }
-        public string Content { get; set; }
-        public bool IsLast { get; set; }
-        public bool IsFirst { get; set; }
-    }
-}
diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml b/LLama.Web/Pages/Executor/Instruct.cshtml
deleted file mode 100644
index 9f8cb2d89..000000000
--- a/LLama.Web/Pages/Executor/Instruct.cshtml
+++ /dev/null
@@ -1,96 +0,0 @@
-@page
-@model InstructModel
-@{
-
-}
-@Html.AntiForgeryToken()
-
- -
-
-

Instruct

-
- Hub: Disconnected -
-
- -
- Model - -
- -
- Parameters - -
- -
- Prompt - - -
- -
-
-
- -
-
- -
-
-
- -
-
-
-
- -
-
- -
-
- -
-
- - -
-
-
-
- -
-
- -@{ await Html.RenderPartialAsync("_ChatTemplates"); } - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.cs b/LLama.Web/Pages/Executor/Instruct.cshtml.cs deleted file mode 100644 index 18a58253b..000000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class InstructModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public InstructModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Instruct.cshtml.css b/LLama.Web/Pages/Executor/Instruct.cshtml.css deleted file mode 100644 index ed9a1d59f..000000000 --- a/LLama.Web/Pages/Executor/Instruct.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml b/LLama.Web/Pages/Executor/Interactive.cshtml deleted file mode 100644 index 916b59ca8..000000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml +++ /dev/null @@ -1,96 +0,0 @@ -@page -@model InteractiveModel -@{ - -} -@Html.AntiForgeryToken() -
- -
-
-

Interactive

-
- Hub: Disconnected -
-
- -
- Model - -
- -
- Parameters - -
- -
- Prompt - - -
- -
-
-
- -
-
- -
-
-
- -
-
-
-
- -
-
- -
-
- -
-
- - -
-
-
-
- -
-
- -@{ await Html.RenderPartialAsync("_ChatTemplates");} - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.cs b/LLama.Web/Pages/Executor/Interactive.cshtml.cs deleted file mode 100644 index 7179a4405..000000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class InteractiveModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public InteractiveModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Interactive.cshtml.css b/LLama.Web/Pages/Executor/Interactive.cshtml.css deleted file mode 100644 index ed9a1d59f..000000000 --- a/LLama.Web/Pages/Executor/Interactive.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml b/LLama.Web/Pages/Executor/Stateless.cshtml deleted file mode 100644 index b5d8eea37..000000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml +++ /dev/null @@ -1,97 +0,0 @@ -@page -@model StatelessModel -@{ - -} -@Html.AntiForgeryToken() -
- -
-
-

Stateless

-
- Hub: Disconnected -
-
- -
- Model - -
- -
- Parameters - -
- -
- Prompt - - -
- -
-
-
- -
-
- -
-
-
- -
-
-
-
- -
-
- -
-
- -
-
- - -
-
-
-
- -
-
- -@{ await Html.RenderPartialAsync("_ChatTemplates"); } - - -@section Scripts { - - -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.cs b/LLama.Web/Pages/Executor/Stateless.cshtml.cs deleted file mode 100644 index f88c4b832..000000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml.cs +++ /dev/null @@ -1,34 +0,0 @@ -using LLama.Web.Common; -using LLama.Web.Models; -using LLama.Web.Services; -using Microsoft.AspNetCore.Mvc; -using Microsoft.AspNetCore.Mvc.RazorPages; -using Microsoft.Extensions.Options; - -namespace LLama.Web.Pages -{ - public class StatelessModel : PageModel - { - private readonly ILogger _logger; - private readonly ConnectionSessionService _modelSessionService; - - public StatelessModel(ILogger logger, IOptions options, ConnectionSessionService modelSessionService) - { - _logger = logger; - Options = options.Value; - _modelSessionService = modelSessionService; - } - - public LLamaOptions Options { get; set; } - - public void OnGet() - { - } - - public async Task OnPostCancel(CancelModel model) - { - await _modelSessionService.CancelAsync(model.ConnectionId); - return new JsonResult(default); - } - } -} \ No newline at end of file diff --git a/LLama.Web/Pages/Executor/Stateless.cshtml.css b/LLama.Web/Pages/Executor/Stateless.cshtml.css deleted file mode 100644 index ed9a1d59f..000000000 --- a/LLama.Web/Pages/Executor/Stateless.cshtml.css +++ /dev/null @@ -1,4 +0,0 @@ -.section-content { - flex: 1; - overflow-y: scroll; -} diff --git a/LLama.Web/Pages/Index.cshtml b/LLama.Web/Pages/Index.cshtml index b5f0c15fc..dcb088371 100644 --- a/LLama.Web/Pages/Index.cshtml +++ b/LLama.Web/Pages/Index.cshtml @@ -1,10 +1,119 @@ @page @model IndexModel @{ - ViewData["Title"] = "Home page"; + ViewData["Title"] = "Inference Demo"; } -
-

Welcome

-

Learn about building Web apps with ASP.NET Core.

+@Html.AntiForgeryToken() +
+ +
+
+
+ @ViewData["Title"] +
+
+ Socket: Disconnected +
+
+ +
+
+
+
+ Model + @Html.DropDownListFor(m => m.SessionOptions.Model, new SelectList(Model.Options.Models, "Name", "Name"), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) +
+
+ Executor + @Html.DropDownListFor(m => m.SessionOptions.ExecutorType, Html.GetEnumSelectList(), new { @class = "form-control prompt-control" ,required="required", autocomplete="off"}) +
+ + +
+
+
+ +
+
+
+ + +
+
+ +
+
+
+ +
+
+
+
+ +
+
+ +
+
+ +
+
+ + +
+
+
+
+ +
+ +@{ + await Html.RenderPartialAsync("_ChatTemplates"); +} + +@section Scripts { + + +} \ No newline at end of file diff --git a/LLama.Web/Pages/Index.cshtml.cs b/LLama.Web/Pages/Index.cshtml.cs index 477c9bfbe..576867859 100644 --- a/LLama.Web/Pages/Index.cshtml.cs +++ b/LLama.Web/Pages/Index.cshtml.cs @@ -1,5 +1,7 @@ -using Microsoft.AspNetCore.Mvc; +using LLama.Web.Common; +using Microsoft.AspNetCore.Mvc; using Microsoft.AspNetCore.Mvc.RazorPages; +using Microsoft.Extensions.Options; namespace LLama.Web.Pages { @@ -7,14 +9,33 @@ public class IndexModel : PageModel { private readonly ILogger _logger; - public IndexModel(ILogger logger) + public IndexModel(ILogger logger, IOptions options) { _logger = logger; + Options = options.Value; } + public LLamaOptions Options { get; set; } + + [BindProperty] + public SessionConfig SessionOptions { get; set; } + + [BindProperty] + public InferenceOptions InferenceOptions { get; set; } + public void OnGet() { + SessionOptions = new SessionConfig + { + Prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.", + AntiPrompt = "User:", + // OutputFilter = "User:, Response:" + }; + InferenceOptions = new InferenceOptions + { + Temperature = 0.8f + }; } } } \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml index 156440124..cd768f1f5 100644 --- a/LLama.Web/Pages/Shared/_ChatTemplates.cshtml +++ b/LLama.Web/Pages/Shared/_ChatTemplates.cshtml @@ -12,7 +12,7 @@
- {{text}} + {{text}}
{{date}}
@@ -26,9 +26,7 @@
- - - +
@@ -41,20 +39,6 @@
- \ No newline at end of file diff --git a/LLama.Web/Pages/Shared/_Layout.cshtml b/LLama.Web/Pages/Shared/_Layout.cshtml index 23132bfa4..16d6ad527 100644 --- a/LLama.Web/Pages/Shared/_Layout.cshtml +++ b/LLama.Web/Pages/Shared/_Layout.cshtml @@ -3,7 +3,7 @@ - @ViewData["Title"] - LLama.Web + @ViewData["Title"] - LLamaSharp.Web @@ -13,24 +13,26 @@
-
- @RenderBody() -
+
+ @RenderBody() +
- © 2023 - LLama.Web + © 2023 - LLamaSharp.Web
diff --git a/LLama.Web/Pages/Shared/_Parameters.cshtml b/LLama.Web/Pages/Shared/_Parameters.cshtml new file mode 100644 index 000000000..165b65a87 --- /dev/null +++ b/LLama.Web/Pages/Shared/_Parameters.cshtml @@ -0,0 +1,137 @@ +@page +@model LLama.Abstractions.IInferenceParams +@{ +} + +
+
+ MaxTokens +
+ @Html.TextBoxFor(m => m.MaxTokens, new { @type="range", @class = "slider", min="-1", max="2048", step="1" }) + +
+
+ +
+ TokensKeep +
+ @Html.TextBoxFor(m => m.TokensKeep, new { @type="range", @class = "slider", min="0", max="2048", step="1" }) + +
+
+
+ +
+
+ TopK +
+ @Html.TextBoxFor(m => m.TopK, new { @type="range", @class = "slider", min="-1", max="100", step="1" }) + +
+
+ +
+ TopP +
+ @Html.TextBoxFor(m => m.TopP, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+
+ + + +
+
+ TypicalP +
+ @Html.TextBoxFor(m => m.TypicalP, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+ +
+ Temperature +
+ @Html.TextBoxFor(m => m.Temperature, new { @type="range", @class = "slider", min="0.0", max="1.5", step="0.01" }) + +
+
+
+ +
+
+ RepeatPenalty +
+ @Html.TextBoxFor(m => m.RepeatPenalty, new { @type="range", @class = "slider", min="0.0", max="2.0", step="0.01" }) + +
+
+ +
+ RepeatLastTokensCount +
+ @Html.TextBoxFor(m => m.RepeatLastTokensCount, new { @type="range", @class = "slider", min="0", max="2048", step="1" }) + +
+
+
+ +
+
+ FrequencyPenalty +
+ @Html.TextBoxFor(m => m.FrequencyPenalty, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+ +
+ PresencePenalty +
+ @Html.TextBoxFor(m => m.PresencePenalty, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+
+ +
+
+ TfsZ +
+ @Html.TextBoxFor(m => m.TfsZ, new { @type="range", @class = "slider",min="0.0", max="1.0", step="0.01" }) + +
+
+
+ - +
+ + +
+
+
+ + +
+ Mirostat + @Html.DropDownListFor(m => m.Mirostat, Html.GetEnumSelectList(), new { @class = "form-control form-select" }) +
+ +
+
+ MirostatTau +
+ @Html.TextBoxFor(m => m.MirostatTau, new { @type="range", @class = "slider", min="0.0", max="10.0", step="0.01" }) + +
+
+ +
+ MirostatEta +
+ @Html.TextBoxFor(m => m.MirostatEta, new { @type="range", @class = "slider", min="0.0", max="1.0", step="0.01" }) + +
+
+
\ No newline at end of file
diff --git a/LLama.Web/Pages/Shared/_Parameters.cshtml.cs b/LLama.Web/Pages/Shared/_Parameters.cshtml.cs
new file mode 100644
index 000000000..60c3e3348
--- /dev/null
+++ b/LLama.Web/Pages/Shared/_Parameters.cshtml.cs
@@ -0,0 +1,12 @@
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.AspNetCore.Mvc.RazorPages;
+
+namespace LLama.Web.Pages.Shared
+{
+    public class _ParametersModel : PageModel
+    {
+        public void OnGet()
+        {
+        }
+    }
+}
diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs
index 6db653a14..ee63590e0 100644
--- a/LLama.Web/Program.cs
+++ b/LLama.Web/Program.cs
@@ -20,7 +20,8 @@ public static void Main(string[] args)
                 .BindConfiguration(nameof(LLamaOptions));
 
             // Services DI
-            builder.Services.AddSingleton<ConnectionSessionService>();
+            builder.Services.AddSingleton<IModelService, ModelService>();
+            builder.Services.AddSingleton<IModelSessionService, ModelSessionService>();
 
             var app = builder.Build();
 
diff --git a/LLama.Web/README.md b/LLama.Web/README.md
index 9b6786e6b..611182b20 100644
--- a/LLama.Web/README.md
+++ b/LLama.Web/README.md
@@ -1,37 +1,55 @@
 ## LLama.Web - Basic ASP.NET Core examples of LLamaSharp in action
-LLama.Web has no heavy dependencies and no extra frameworks ove bootstrap and jquery to keep the examples clean and easy to copy over to your own project
+LLama.Web has no heavy dependencies and no extra frameworks over bootstrap and jquery, to keep the examples clean and easy to copy over to your own project
 
 ## Websockets
 Using SignalR websockets simplifies the streaming of responses and per-connection model management
-
-
 ## Setup
-You can setup Models, Prompts and Inference parameters in the appsettings.json
+You can set up Models and parameters in the appsettings.json
 
 **Models**
 You can add multiple models to the options for quick selection in the UI; options are based on ModelParams so it's fully configurable
 
-**Parameters**
-You can add multiple sets of inference parameters to the options for quick selection in the UI, options are based on InferenceParams so its fully configurable
-
-**Prompts**
-You can add multiple sets of prompts to the options for quick selection in the UI
-
 Example:
 ```json
-    {
-        "Name": "Alpaca",
-        "Path": "D:\\Repositories\\AI\\Prompts\\alpaca.txt",
-        "Prompt": "Alternativly to can set a prompt text directly and omit the Path"
-        "AntiPrompt": [
-            "User:"
-        ],
-        "OutputFilter": [
-            "Response:",
-            "User:"
-        ]
-    }
+{
+  "Logging": {
+    "LogLevel": {
+      "Default": "Information",
+      "Microsoft.AspNetCore": "Warning"
+    }
+  },
+  "AllowedHosts": "*",
+  "LLamaConfig": {
+    "Models": [{
+      "Name": "WizardLM-7B",
+      "MaxInstances": 2,
+      "ModelPath": "D:\\Repositories\\Models\\wizardLM-7B.ggmlv3.q4_0.bin",
+      "ContextSize": 512,
+      "BatchSize": 512,
+      "Threads": -1,
+      "GpuLayerCount": 20,
+      "UseMemorymap": true,
+      "UseMemoryLock": false,
+      "MainGpu": 0,
+      "LowVram": false,
+      "Seed": 1686349486,
+      "UseFp16Memory": true,
+      "Perplexity": false,
+      "ModelAlias": "unknown",
+      "LoraAdapter": "",
+      "LoraBase": "",
+      "ConvertEosToNewLine": false,
+      "EmbeddingMode": false,
+      "TensorSplits": null,
+      "GroupedQueryAttention": 1,
+      "RmsNormEpsilon": 0.000005,
+      "RopeFrequencyBase": 10000.0,
+      "RopeFrequencyScale": 1.0,
+      "MulMatQ": false
+    }]
+  }
+}
 ```
diff --git a/LLama.Web/Services/ConnectionSessionService.cs b/LLama.Web/Services/ConnectionSessionService.cs
deleted file mode 100644
index 7dfcde397..000000000
--- a/LLama.Web/Services/ConnectionSessionService.cs
+++ /dev/null
@@ -1,94 +0,0 @@
-using LLama.Abstractions;
-using LLama.Web.Common;
-using LLama.Web.Models;
-using Microsoft.Extensions.Options;
-using 
System.Collections.Concurrent; -using System.Drawing; - -namespace LLama.Web.Services -{ - /// - /// Example Service for handling a model session for a websockets connection lifetime - /// Each websocket connection will create its own unique session and context allowing you to use multiple tabs to compare prompts etc - /// - public class ConnectionSessionService : IModelSessionService - { - private readonly LLamaOptions _options; - private readonly ILogger _logger; - private readonly ConcurrentDictionary _modelSessions; - - public ConnectionSessionService(ILogger logger, IOptions options) - { - _logger = logger; - _options = options.Value; - _modelSessions = new ConcurrentDictionary(); - } - - public Task GetAsync(string connectionId) - { - _modelSessions.TryGetValue(connectionId, out var modelSession); - return Task.FromResult(modelSession); - } - - public Task> CreateAsync(LLamaExecutorType executorType, string connectionId, string modelName, string promptName, string parameterName) - { - var modelOption = _options.Models.FirstOrDefault(x => x.Name == modelName); - if (modelOption is null) - return Task.FromResult(ServiceResult.FromError($"Model option '{modelName}' not found")); - - var promptOption = _options.Prompts.FirstOrDefault(x => x.Name == promptName); - if (promptOption is null) - return Task.FromResult(ServiceResult.FromError($"Prompt option '{promptName}' not found")); - - var parameterOption = _options.Parameters.FirstOrDefault(x => x.Name == parameterName); - if (parameterOption is null) - return Task.FromResult(ServiceResult.FromError($"Parameter option '{parameterName}' not found")); - - - //Max instance - var currentInstances = _modelSessions.Count(x => x.Value.ModelName == modelOption.Name); - if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances) - return Task.FromResult(ServiceResult.FromError("Maximum model instances reached")); - - // Create model - var llamaModel = new LLamaContext(modelOption); - - // Create executor - ILLamaExecutor executor = executorType switch - { - LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel), - LLamaExecutorType.Instruct => new InstructExecutor(llamaModel), - LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel), - _ => default - }; - - // Create session - var modelSession = new ModelSession(executor, modelOption, promptOption, parameterOption); - if (!_modelSessions.TryAdd(connectionId, modelSession)) - return Task.FromResult(ServiceResult.FromError("Failed to create model session")); - - return Task.FromResult(ServiceResult.FromValue(modelSession)); - } - - public Task RemoveAsync(string connectionId) - { - if (_modelSessions.TryRemove(connectionId, out var modelSession)) - { - modelSession.CancelInfer(); - modelSession.Dispose(); - return Task.FromResult(true); - } - return Task.FromResult(false); - } - - public Task CancelAsync(string connectionId) - { - if (_modelSessions.TryGetValue(connectionId, out var modelSession)) - { - modelSession.CancelInfer(); - return Task.FromResult(true); - } - return Task.FromResult(false); - } - } -} diff --git a/LLama.Web/Services/IModelService.cs b/LLama.Web/Services/IModelService.cs new file mode 100644 index 000000000..9085e0ba5 --- /dev/null +++ b/LLama.Web/Services/IModelService.cs @@ -0,0 +1,14 @@ +using LLama.Web.Common; + +namespace LLama.Web.Services +{ + public interface IModelService + { + Task CreateContext(string modelName, string key); + Task GetContext(string modelName, string key); + Task GetModel(string modelName); + Task 
LoadModel(ModelOptions modelParams);
+        Task RemoveContext(string modelName, string key);
+        Task GetOrCreateModelAndContext(string modelName, string key);
+    }
+}
diff --git a/LLama.Web/Services/IModelSessionService.cs b/LLama.Web/Services/IModelSessionService.cs
index 4ee0d483f..2b944b711 100644
--- a/LLama.Web/Services/IModelSessionService.cs
+++ b/LLama.Web/Services/IModelSessionService.cs
@@ -6,11 +6,9 @@ namespace LLama.Web.Services
 {
     public interface IModelSessionService
     {
-        Task GetAsync(string sessionId);
-        Task> CreateAsync(LLamaExecutorType executorType, string sessionId, string modelName, string promptName, string parameterName);
-        Task RemoveAsync(string sessionId);
+        Task CloseAsync(string sessionId);
         Task CancelAsync(string sessionId);
+        Task CreateAsync(string sessionId, SessionConfig sessionConfig, IInferenceParams inferenceParams = null, CancellationToken cancellationToken = default);
+        IAsyncEnumerable InferAsync(string sessionId, string prompt, IInferenceParams inferenceParams = null, CancellationToken cancellationToken = default);
     }
-
-
 }
diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs
new file mode 100644
index 000000000..98da04282
--- /dev/null
+++ b/LLama.Web/Services/ModelService.cs
@@ -0,0 +1,170 @@
+using LLama.Web.Common;
+using Microsoft.Extensions.Options;
+using System.Collections.Concurrent;
+using System.Text;
+
+namespace LLama.Web.Services
+{
+    public class ModelService : IModelService
+    {
+        private readonly LLamaOptions _configuration;
+        private readonly ILogger _logger;
+        private readonly SemaphoreSlim _modelLock = new SemaphoreSlim(1, 1);
+        private readonly ConcurrentDictionary _modelInstances;
+        private readonly ConcurrentDictionary _contextInstances;
+
+
+        ///
+        /// Initializes a new instance of the class.
+        ///
+        /// The logger.
+        /// The options.
+        public ModelService(ILogger logger, IOptions options)
+        {
+            _logger = logger;
+            _configuration = options.Value;
+            _modelInstances = new ConcurrentDictionary();
+            _contextInstances = new ConcurrentDictionary();
+        }
+
+
+        ///
+        /// Loads a model with the provided configuration.
+        ///
+        /// The model configuration.
+        ///
+        public async Task LoadModel(ModelOptions modelConfig)
+        {
+            if (_modelInstances.TryGetValue(modelConfig.Name, out LLamaWeights existingModel))
+                return existingModel;
+
+            // Model loading can take some time, so take a lock here
+            await _modelLock.WaitAsync();
+
+            try
+            {
+                // Catch anyone waiting behind the lock
+                if (_modelInstances.TryGetValue(modelConfig.Name, out LLamaWeights model))
+                    return model;
+
+                model = LLamaWeights.LoadFromFile(modelConfig);
+                if (!_modelInstances.TryAdd(modelConfig.Name, model))
+                    throw new Exception("Failed to add model");
+
+                return model;
+            }
+            finally
+            {
+                _modelLock.Release();
+            }
+        }
+
+
+        ///
+        /// Gets a model by name.
+        ///
+        /// Name of the model.
+        ///
+        public Task GetModel(string modelName)
+        {
+            _modelInstances.TryGetValue(modelName, out LLamaWeights model);
+            return Task.FromResult(model);
+        }
+
+
+        ///
+        /// Gets a context from the specified model.
+        ///
+        /// Name of the model.
+        /// The key.
+        ///
+        /// Model not found
+        public Task GetContext(string modelName, string key)
+        {
+            if (!_modelInstances.TryGetValue(modelName, out LLamaWeights model))
+                return Task.FromException(new Exception("Model not found"));
+
+            _contextInstances.TryGetValue(ContextKey(modelName, key), out LLamaContext context);
+            return Task.FromResult(context);
+        }
+
+
+        ///
+        /// Creates a context on the specified model.
+        ///
+        /// Name of the model.
+        /// The key.
+        ///
+        /// Model not found
+        public Task CreateContext(string modelName, string key)
+        {
+            if (!_modelInstances.TryGetValue(modelName, out LLamaWeights model))
+                return Task.FromException(new Exception("Model not found"));
+
+            var modelConfig = _configuration.Models.FirstOrDefault(x => x.Name == modelName);
+            var context = model.CreateContext(modelConfig, Encoding.UTF8);
+            if (!_contextInstances.TryAdd(ContextKey(modelName, key), context))
+                return Task.FromException(new Exception("Failed to add context"));
+
+            return Task.FromResult(context);
+        }
+
+
+        ///
+        /// Removes a context from the specified model.
+        ///
+        /// Name of the model.
+        /// The key.
+        ///
+        /// Model not found
+        public Task RemoveContext(string modelName, string key)
+        {
+            if (!_modelInstances.TryGetValue(modelName, out LLamaWeights model))
+                return Task.FromException(new Exception("Model not found"));
+
+            if (_contextInstances.TryRemove(ContextKey(modelName, key), out LLamaContext context))
+            {
+                context.Dispose();
+                return Task.FromResult(true);
+            }
+            return Task.FromResult(false);
+        }
+
+
+        ///
+        /// Loads, gets, or creates a model and a context.
+        ///
+        /// Name of the model.
+        /// The key.
+        ///
+        /// Model option '{modelName}' not found
+        public async Task GetOrCreateModelAndContext(string modelName, string key)
+        {
+            if (_modelInstances.TryGetValue(modelName, out LLamaWeights model))
+            {
+                // Get or Create Context
+                return await GetContext(modelName, key)
+                    ?? await CreateContext(modelName, key);
+            }
+
+            // Get model configuration
+            var modelConfig = _configuration.Models.FirstOrDefault(x => x.Name == modelName);
+            if (modelConfig is null)
+                throw new Exception($"Model option '{modelName}' not found");
+
+            // Load Model
+            model = await LoadModel(modelConfig);
+
+            // Get or Create Context
+            return await GetContext(modelName, key)
+                ?? await CreateContext(modelName, key);
+        }
+
+
+        ///
+        /// Create a key for the context collection using the model and key provided.
+        ///
+        /// Name of the model.
+        /// The context key.
+        private static string ContextKey(string modelName, string contextKey) => $"{modelName}-{contextKey}";
+    }
+}
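ModelService is the heart of the multi-context change: LLamaWeights are loaded once per model behind the SemaphoreSlim (double-checked, so concurrent callers wait for the first load instead of loading twice), while each connection gets its own cheap LLamaContext keyed `{model}-{key}`. A sketch of the intended consumption pattern — the consumer class and method names are hypothetical, not part of the patch:

```csharp
// Hypothetical consumer of IModelService (illustration only).
using LLama.Web.Services;

public class ConnectionLifecycleExample
{
    private readonly IModelService _modelService;

    public ConnectionLifecycleExample(IModelService modelService)
        => _modelService = modelService;

    public async Task OnConnectedAsync(string connectionId)
    {
        // Loads the weights on first use, reuses them afterwards, and
        // creates a context that is unique to this connection
        var context = await _modelService.GetOrCreateModelAndContext("WizardLM-7B", connectionId);
        // ... hand the context to a ModelSession / executor ...
    }

    public async Task OnDisconnectedAsync(string connectionId)
    {
        // Dispose only this connection's context; the shared weights stay loaded
        await _modelService.RemoveContext("WizardLM-7B", connectionId);
    }
}
```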
diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs
new file mode 100644
index 000000000..9cf56a4de
--- /dev/null
+++ b/LLama.Web/Services/ModelSessionService.cs
@@ -0,0 +1,138 @@
+using LLama.Abstractions;
+using LLama.Web.Common;
+using LLama.Web.Models;
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Runtime.CompilerServices;
+
+namespace LLama.Web.Services
+{
+    ///
+    /// Example service for handling a model session for a websocket connection's lifetime.
+    /// Each websocket connection will create its own unique session and context, allowing you to use multiple tabs to compare prompts etc.
+    public class ModelSessionService : IModelSessionService
+    {
+        private readonly IModelService _modelService;
+        private readonly ConcurrentDictionary _modelSessions;
+
+
+        ///
+        /// Initializes a new instance of the class.
+        ///
+        /// The model service.
+        public ModelSessionService(IModelService modelService)
+        {
+            _modelService = modelService;
+            _modelSessions = new ConcurrentDictionary();
+        }
+
+
+        ///
+        /// Creates a new ModelSession
+        ///
+        /// The session identifier.
+        /// The session configuration.
+        /// The inference parameters.
+        /// The cancellation token.
+ /// + /// + /// Session with id {sessionId} already exists + /// or + /// Failed to create model session + /// + public async Task CreateAsync(string sessionId, SessionConfig sessionConfig, IInferenceParams inferenceParams = null, CancellationToken cancellationToken = default) + { + if (_modelSessions.TryGetValue(sessionId, out _)) + throw new Exception($"Session with id {sessionId} already exists"); + + // Create context + var context = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId.ToString()); + + // Create session + var modelSession = new ModelSession(context, sessionConfig, inferenceParams); + if (!_modelSessions.TryAdd(sessionId, modelSession)) + throw new Exception($"Failed to create model session"); + + // Run initial Prompt + await modelSession.InitializePrompt(inferenceParams, cancellationToken); + return modelSession; + + } + + + /// + /// Runs inference on the current ModelSession + /// + /// The session identifier. + /// The prompt. + /// The inference parameters. + /// The cancellation token. + /// + /// Inference is already running for this session + public async IAsyncEnumerable InferAsync(string sessionId, string prompt, IInferenceParams inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + if (!_modelSessions.TryGetValue(sessionId, out var modelSession)) + yield break; + + + // Send begin of response + var stopwatch = Stopwatch.GetTimestamp(); + yield return new InferTokenModel(default, default, default, InferTokenType.Begin, GetElapsed(stopwatch)); + + // Send content of response + await foreach (var token in modelSession.InferAsync(prompt, inferenceParams, cancellationToken)) + yield return new InferTokenModel(default, default, token, InferTokenType.Content, GetElapsed(stopwatch)); + + // Send end of response + var elapsedTime = GetElapsed(stopwatch); + var endTokenType = modelSession.IsInferCanceled() ? InferTokenType.Cancel : InferTokenType.End; + var signature = endTokenType == InferTokenType.Cancel + ? $"Inference cancelled after {elapsedTime / 1000:F0} seconds" + : $"Inference completed in {elapsedTime / 1000:F0} seconds"; + yield return new InferTokenModel(default, default, signature, endTokenType, elapsedTime); + } + + + /// + /// Closes the session + /// + /// The session identifier. + /// + public async Task CloseAsync(string sessionId) + { + if (_modelSessions.TryRemove(sessionId, out var modelSession)) + { + modelSession.CancelInfer(); + return await _modelService.RemoveContext(modelSession.ModelName, sessionId.ToString()); + } + return false; + } + + + /// + /// Cancels the current action. + /// + /// The session identifier. + /// + public Task CancelAsync(string sessionId) + { + if (_modelSessions.TryGetValue(sessionId, out var modelSession)) + { + modelSession.CancelInfer(); + return Task.FromResult(true); + } + return Task.FromResult(false); + } + + /// + /// Gets the elapsed time in milliseconds. + /// + /// The timestamp. 
+ /// + private static int GetElapsed(long timestamp) + { + return (int)Stopwatch.GetElapsedTime(timestamp).TotalMilliseconds; + } + } +} diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json index 9f340a9c7..34340a87e 100644 --- a/LLama.Web/appsettings.json +++ b/LLama.Web/appsettings.json @@ -11,44 +11,31 @@ { "Name": "WizardLM-7B", "MaxInstances": 2, + "ModelPath": "D:\\Repositories\\AI\\Models\\wizardLM-7B.ggmlv3.q4_0.bin", - "ContextSize": 2048 - } - ], - "Parameters": [ - { - "Name": "Default", - "Temperature": 0.6 - } - ], - "Prompts": [ - { - "Name": "None", - "Prompt": "" - }, - { - "Name": "Alpaca", - "Path": "D:\\Repositories\\AI\\Prompts\\alpaca.txt", - "AntiPrompt": [ - "User:" - ], - "OutputFilter": [ - "Response:", - "User:" - ] - }, - { - "Name": "ChatWithBob", - "Path": "D:\\Repositories\\AI\\Prompts\\chat-with-bob.txt", - "AntiPrompt": [ - "User:" - ], - "OutputFilter": [ - "Bob:", - "User:" - ] + "ContextSize": 512, + "BatchSize": 512, + "Threads": -1, + "GpuLayerCount": 20, + "UseMemorymap": true, + "UseMemoryLock": false, + "MainGpu": 0, + "LowVram": false, + "Seed": 1686349486, + "UseFp16Memory": true, + "Perplexity": false, + "ModelAlias": "unknown", + "LoraAdapter": "", + "LoraBase": "", + "ConvertEosToNewLine": false, + "EmbeddingMode": false, + "TensorSplits": null, + "GroupedQueryAttention": 1, + "RmsNormEpsilon": 0.000005, + "RopeFrequencyBase": 10000.0, + "RopeFrequencyScale": 1.0, + "MulMatQ": false } ] - } } diff --git a/LLama.Web/wwwroot/css/site.css b/LLama.Web/wwwroot/css/site.css index d10ef9757..1603f8cb3 100644 --- a/LLama.Web/wwwroot/css/site.css +++ b/LLama.Web/wwwroot/css/site.css @@ -27,8 +27,21 @@ footer { } } -.btn:focus, .btn:active:focus, .btn-link.nav-link:focus, .form-control:focus, .form-check-input:focus { - box-shadow: 0 0 0 0.1rem white, 0 0 0 0.25rem #258cfb; +#scroll-container { + flex: 1; + overflow-y: scroll; +} + +#output-container .content { + white-space: break-spaces; } +.slider-container > .slider { + width: 100%; +} + +.slider-container > label { + width: 50px; + text-align: center; +} \ No newline at end of file diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js b/LLama.Web/wwwroot/js/sessionConnectionChat.js index 472b59718..d677892ab 100644 --- a/LLama.Web/wwwroot/js/sessionConnectionChat.js +++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js @@ -1,26 +1,26 @@ -const createConnectionSessionChat = (LLamaExecutorType) => { +const createConnectionSessionChat = () => { const outputErrorTemplate = $("#outputErrorTemplate").html(); const outputInfoTemplate = $("#outputInfoTemplate").html(); const outputUserTemplate = $("#outputUserTemplate").html(); const outputBotTemplate = $("#outputBotTemplate").html(); - const sessionDetailsTemplate = $("#sessionDetailsTemplate").html(); + const signatureTemplate = $("#signatureTemplate").html(); - let connectionId; + let inferenceSession; const connection = new signalR.HubConnectionBuilder().withUrl("/SessionConnectionHub").build(); const scrollContainer = $("#scroll-container"); const outputContainer = $("#output-container"); const chatInput = $("#input"); - const onStatus = (connection, status) => { - connectionId = connection; if (status == Enums.SessionConnectionStatus.Connected) { $("#socket").text("Connected").addClass("text-success"); } else if (status == Enums.SessionConnectionStatus.Loaded) { + loaderHide(); enableControls(); - $("#session-details").html(Mustache.render(sessionDetailsTemplate, { model: getSelectedModel(), prompt: getSelectedPrompt(), parameter: 
getSelectedParameter() })); + $("#load").hide(); + $("#unload").show(); onInfo(`New model session successfully started`) } } @@ -36,30 +36,31 @@ const createConnectionSessionChat = (LLamaExecutorType) => { let responseContent; let responseContainer; - let responseFirstFragment; + let responseFirstToken; const onResponse = (response) => { if (!response) return; - if (response.isFirst) { - outputContainer.append(Mustache.render(outputBotTemplate, response)); - responseContainer = $(`#${response.id}`); + if (response.type == Enums.InferTokenType.Begin) { + const uniqueId = randomString(); + outputContainer.append(Mustache.render(outputBotTemplate, { id: uniqueId, ...response })); + responseContainer = $(`#${uniqueId}`); responseContent = responseContainer.find(".content"); - responseFirstFragment = true; + responseFirstToken = true; scrollToBottom(true); return; } - if (response.isLast) { + if (response.type == Enums.InferTokenType.End || response.type == Enums.InferTokenType.Cancel) { enableControls(); - responseContainer.find(".signature").append(response.content); + responseContainer.find(".signature").append(Mustache.render(signatureTemplate, response)); scrollToBottom(); } else { - if (responseFirstFragment) { + if (responseFirstToken) { responseContent.empty(); - responseFirstFragment = false; + responseFirstToken = false; responseContainer.find(".date").append(getDateTime()); } responseContent.append(response.content); @@ -67,45 +68,88 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } } - const sendPrompt = async () => { const text = chatInput.val(); if (text) { + chatInput.val(null); disableControls(); outputContainer.append(Mustache.render(outputUserTemplate, { text: text, date: getDateTime() })); - await connection.invoke('SendPrompt', text); - chatInput.val(null); + inferenceSession = await connection + .stream("SendPrompt", text, serializeFormToJson('SessionParameters')) + .subscribe({ + next: onResponse, + complete: onResponse, + error: onError, + }); scrollToBottom(true); } } const cancelPrompt = async () => { - await ajaxPostJsonAsync('?handler=Cancel', { connectionId: connectionId }); + if (inferenceSession) + inferenceSession.dispose(); } const loadModel = async () => { - const modelName = getSelectedModel(); - const promptName = getSelectedPrompt(); - const parameterName = getSelectedParameter(); - if (!modelName || !promptName || !parameterName) { - onError("Please select a valid Model, Parameter and Prompt"); - return; - } + const sessionParams = serializeFormToJson('SessionParameters'); + loaderShow(); + disableControls(); + disablePromptControls(); + $("#load").attr("disabled", "disabled"); + // TODO: Split parameters sets + await connection.invoke('LoadModel', sessionParams, sessionParams); + } + + const unloadModel = async () => { disableControls(); - await connection.invoke('LoadModel', LLamaExecutorType, modelName, promptName, parameterName); + enablePromptControls(); + $("#load").removeAttr("disabled"); } + const serializeFormToJson = (form) => { + const formDataJson = {}; + const formData = new FormData(document.getElementById(form)); + formData.forEach((value, key) => { + + if (key.includes(".")) + key = key.split(".")[1]; + + // Convert number strings to numbers + if (!isNaN(value) && value.trim() !== "") { + formDataJson[key] = parseFloat(value); + } + // Convert boolean strings to booleans + else if (value === "true" || value === "false") { + formDataJson[key] = (value === "true"); + } + else { + formDataJson[key] = value; + } + }); + return 
formDataJson; + } const enableControls = () => { $(".input-control").removeAttr("disabled"); } - const disableControls = () => { $(".input-control").attr("disabled", "disabled"); } + const enablePromptControls = () => { + $("#load").show(); + $("#unload").hide(); + $(".prompt-control").removeAttr("disabled"); + activatePromptTab(); + } + + const disablePromptControls = () => { + $(".prompt-control").attr("disabled", "disabled"); + activateParamsTab(); + } + const clearOutput = () => { outputContainer.empty(); } @@ -117,27 +161,14 @@ const createConnectionSessionChat = (LLamaExecutorType) => { customPrompt.text(selectedValue); } - - const getSelectedModel = () => { - return $("option:selected", "#Model").val(); - } - - - const getSelectedParameter = () => { - return $("option:selected", "#Parameter").val(); - } - - - const getSelectedPrompt = () => { - return $("option:selected", "#Prompt").val(); - } - - const getDateTime = () => { const dateTime = new Date(); return dateTime.toLocaleString(); } + const randomString = () => { + return Math.random().toString(36).slice(2); + } const scrollToBottom = (force) => { const scrollTop = scrollContainer.scrollTop(); @@ -151,10 +182,25 @@ const createConnectionSessionChat = (LLamaExecutorType) => { } } + const activatePromptTab = () => { + $("#nav-prompt-tab").trigger("click"); + } + const activateParamsTab = () => { + $("#nav-params-tab").trigger("click"); + } + + const loaderShow = () => { + $(".spinner").show(); + } + + const loaderHide = () => { + $(".spinner").hide(); + } // Map UI functions $("#load").on("click", loadModel); + $("#unload").on("click", unloadModel); $("#send").on("click", sendPrompt); $("#clear").on("click", clearOutput); $("#cancel").on("click", cancelPrompt); @@ -165,7 +211,10 @@ const createConnectionSessionChat = (LLamaExecutorType) => { sendPrompt(); } }); - + $(".slider").on("input", function (e) { + const slider = $(this); + slider.next().text(slider.val()); + }).trigger("input"); // Map signalr functions diff --git a/LLama.Web/wwwroot/js/site.js b/LLama.Web/wwwroot/js/site.js index 2f679669b..9f896f83f 100644 --- a/LLama.Web/wwwroot/js/site.js +++ b/LLama.Web/wwwroot/js/site.js @@ -32,8 +32,6 @@ const ajaxGetJsonAsync = (url, data) => { } - - const Enums = { SessionConnectionStatus: Object.freeze({ Disconnected: 0, @@ -45,6 +43,12 @@ const Enums = { Instruct: 1, Stateless: 2 }), + InferTokenType: Object.freeze({ + Begin: 0, + Content: 2, + End: 4, + Cancel: 10 + }), GetName: (enumType, enumKey) => { return Object.keys(enumType)[enumKey] }, diff --git a/LLamaSharp.sln b/LLamaSharp.sln index 2e00196c2..f1b4b18d5 100644 --- a/LLamaSharp.sln +++ b/LLamaSharp.sln @@ -11,7 +11,7 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LLamaSharp", "LLama\LLamaSh EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LLama.WebAPI", "LLama.WebAPI\LLama.WebAPI.csproj", "{D3CEC57A-9027-4DA4-AAAC-612A1EB50ADF}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "LLama.Web", "LLama.Web\LLama.Web.csproj", "{C3531DB2-1B2B-433C-8DE6-3541E3620DB1}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "LLama.Web", "LLama.Web\LLama.Web.csproj", "{C3531DB2-1B2B-433C-8DE6-3541E3620DB1}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution From a0bcd982697aeb6353d6a8ec5fbe01ae50fc0a3f Mon Sep 17 00:00:00 2001 From: sa_ddam213 Date: Tue, 22 Aug 2023 23:05:24 +1200 Subject: [PATCH 2/4] Make multi model caching configurable --- LLama.Web/Common/LLamaOptions.cs | 5 +- 
 LLama.Web/Hubs/SessionConnectionHub.cs        |  1 +
 LLama.Web/Services/ModelService.cs            | 62 ++++++++++++++++++-
 LLama.Web/Services/ModelSessionService.cs     |  4 +-
 LLama.Web/appsettings.json                    |  1 +
 LLama.Web/wwwroot/js/sessionConnectionChat.js |  1 +
 6 files changed, 68 insertions(+), 6 deletions(-)

diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs
index 2348dd133..ad58676f7 100644
--- a/LLama.Web/Common/LLamaOptions.cs
+++ b/LLama.Web/Common/LLamaOptions.cs
@@ -1,7 +1,10 @@
-namespace LLama.Web.Common
+using LLama.Web.Services;
+
+namespace LLama.Web.Common
 {
     public class LLamaOptions
     {
+        public ModelCacheType ModelCacheType { get; set; }
         public List<ModelOptions> Models { get; set; }
 
         public void Initialize()
diff --git a/LLama.Web/Hubs/SessionConnectionHub.cs b/LLama.Web/Hubs/SessionConnectionHub.cs
index a6ee6df66..28046d4f9 100644
--- a/LLama.Web/Hubs/SessionConnectionHub.cs
+++ b/LLama.Web/Hubs/SessionConnectionHub.cs
@@ -42,6 +42,7 @@ public async Task OnLoadModel(SessionConfig sessionConfig, InferenceOptions inferenceConfig)
             _logger.Log(LogLevel.Information, "[OnLoadModel] - Load new model, Connection: {0}", Context.ConnectionId);
 
             // Create model session
+            await _modelSessionService.CloseAsync(Context.ConnectionId);
             var modelSession = await _modelSessionService.CreateAsync(Context.ConnectionId, sessionConfig, inferenceConfig);
             if (modelSession is null)
             {
diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs
index 98da04282..150167f1f 100644
--- a/LLama.Web/Services/ModelService.cs
+++ b/LLama.Web/Services/ModelService.cs
@@ -1,10 +1,18 @@
 using LLama.Web.Common;
+using Microsoft.AspNetCore.DataProtection.KeyManagement;
+using Microsoft.AspNetCore.Mvc.ModelBinding;
 using Microsoft.Extensions.Options;
 using System.Collections.Concurrent;
 using System.Text;
 
 namespace LLama.Web.Services
 {
+    public enum ModelCacheType
+    {
+        Single = 0,
+        Multiple = 1
+    }
+
     public class ModelService : IModelService
     {
         private readonly LLamaOptions _configuration;
@@ -46,9 +54,12 @@ public async Task<LLamaWeights> LoadModel(ModelOptions modelConfig)
             if (_modelInstances.TryGetValue(modelConfig.Name, out LLamaWeights model))
                 return model;
 
+            if (_configuration.ModelCacheType == ModelCacheType.Single)
+                await UnloadModels();
+
             model = LLamaWeights.LoadFromFile(modelConfig);
             if (!_modelInstances.TryAdd(modelConfig.Name, model))
-                throw new Exception("Failed to add model");
+                throw new Exception("Failed to add model");
 
             return model;
         }
@@ -59,6 +70,41 @@ public async Task<LLamaWeights> LoadModel(ModelOptions modelConfig)
         }
 
 
+        /// <summary>
+        /// Unloads the model.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <returns></returns>
+        public Task<bool> UnloadModel(string modelName)
+        {
+            if (_modelInstances.TryRemove(modelName, out LLamaWeights model))
+            {
+                foreach (var contextKey in ContextKeys(modelName))
+                {
+                    if (!_contextInstances.TryRemove(contextKey, out var context))
+                        continue;
+
+                    context?.Dispose();
+                }
+                model?.Dispose();
+                return Task.FromResult(true);
+            }
+            return Task.FromResult(false);
+        }
+
+
+        /// <summary>
+        /// Unloads all models.
+        /// </summary>
+        public async Task UnloadModels()
+        {
+            foreach (var modelName in _modelInstances.Keys)
+            {
+                await UnloadModel(modelName);
+            }
+        }
+
+
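With `ModelCacheType.Single`, `LoadModel` evicts everything already cached before loading the new weights, and eviction disposes both the weights and their contexts. A hedged usage sketch (hypothetical harness, not part of the PR) of what that implies for callers:

```csharp
using System.Threading.Tasks;
using LLama;
using LLama.Web.Common;
using LLama.Web.Services;

public static class SingleModeDemo
{
    // Assumes at least two entries in LLamaOptions.Models.
    public static async Task Run(IModelService modelService, LLamaOptions options)
    {
        LLamaWeights first = await modelService.LoadModel(options.Models[0]);

        // In Single mode this triggers UnloadModels() first, so 'first' and any
        // contexts created from it are disposed before the second load begins.
        LLamaWeights second = await modelService.LoadModel(options.Models[1]);

        // Touching 'first' past this point would be a use-after-dispose bug.
    }
}
```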
         /// <summary>
         /// Gets a model by name.
         /// </summary>
@@ -119,7 +165,7 @@ public Task<LLamaContext> CreateContext(string modelName, string key)
         public Task<bool> RemoveContext(string modelName, string key)
         {
             if (!_modelInstances.TryGetValue(modelName, out LLamaWeights model))
-                return Task.FromException<bool>(new Exception("Model not found"));
+                return Task.FromResult(false);
 
             if (_contextInstances.TryRemove(ContextKey(modelName, key), out LLamaContext context))
             {
@@ -160,11 +206,21 @@ public async Task<LLamaContext> GetOrCreateModelAndContext(string modelName, string key)
         }
 
 
+        /// <summary>
+        /// Gets a list of context keys for the model name provided.
+        /// </summary>
+        /// <param name="modelName">Name of the model.</param>
+        /// <returns></returns>
+        private IEnumerable<string> ContextKeys(string modelName)
+        {
+            return _contextInstances.Keys.Where(x => x.StartsWith($"{modelName}:"));
+        }
+
         /// <summary>
         /// Create a key for the context collection using the model and key provided.
         /// </summary>
         /// <param name="modelName">Name of the model.</param>
         /// <param name="contextKey">The context key.</param>
-        private static string ContextKey(string modelName, string contextKey) => $"{modelName}-{contextKey}";
+        private string ContextKey(string modelName, string contextKey) => $"{modelName}:{contextKey}";
     }
 }
diff --git a/LLama.Web/Services/ModelSessionService.cs b/LLama.Web/Services/ModelSessionService.cs
index 9cf56a4de..522515a6c 100644
--- a/LLama.Web/Services/ModelSessionService.cs
+++ b/LLama.Web/Services/ModelSessionService.cs
@@ -43,8 +43,8 @@ public ModelSessionService(IModelService modelService)
         /// </exception>
         public async Task<ModelSession> CreateAsync(string sessionId, SessionConfig sessionConfig, IInferenceParams inferenceParams = null, CancellationToken cancellationToken = default)
         {
-            if (_modelSessions.TryGetValue(sessionId, out _))
-                throw new Exception($"Session with id {sessionId} already exists");
+            if (_modelSessions.TryGetValue(sessionId, out var existingSession))
+                return existingSession;
 
             // Create context
             var context = await _modelService.GetOrCreateModelAndContext(sessionConfig.Model, sessionId.ToString());
diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json
index 34340a87e..00a5056c0 100644
--- a/LLama.Web/appsettings.json
+++ b/LLama.Web/appsettings.json
@@ -7,6 +7,7 @@
   },
   "AllowedHosts": "*",
   "LLamaOptions": {
+    "ModelCacheType": "Multiple",
     "Models": [
       {
         "Name": "WizardLM-7B",
diff --git a/LLama.Web/wwwroot/js/sessionConnectionChat.js b/LLama.Web/wwwroot/js/sessionConnectionChat.js
index d677892ab..85adde2cb 100644
--- a/LLama.Web/wwwroot/js/sessionConnectionChat.js
+++ b/LLama.Web/wwwroot/js/sessionConnectionChat.js
@@ -102,6 +102,7 @@ const createConnectionSessionChat = () => {
     }
 
     const unloadModel = async () => {
+        await cancelPrompt();
         disableControls();
         enablePromptControls();
         $("#load").removeAttr("disabled");
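Worth calling out in this commit: the context-key delimiter changes from `-` to `:`, which is what makes the `StartsWith` scan in `ContextKeys` safe. Model names may themselves contain hyphens (the configured `WizardLM-7B` does), so a `-` delimiter would let one model's prefix match another model's keys. A small illustration with hypothetical key values:

```csharp
using System;
using System.Linq;

class ContextKeyDemo
{
    static void Main()
    {
        // Old '-' delimiter: a scan for model "WizardLM" also matches
        // keys belonging to a different model named "WizardLM-7B".
        var oldKeys = new[] { "WizardLM-abc123", "WizardLM-7B-abc123" };
        Console.WriteLine(oldKeys.Count(k => k.StartsWith("WizardLM-")));  // 2, both match

        // New ':' delimiter: the model name is unambiguously terminated.
        var newKeys = new[] { "WizardLM:abc123", "WizardLM-7B:abc123" };
        Console.WriteLine(newKeys.Count(k => k.StartsWith("WizardLM:")));  // 1, only "WizardLM"
    }
}
```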
From 0bf4fc0dd4763ba87aea115d79465a93cbb037f4 Mon Sep 17 00:00:00 2001
From: sa_ddam213
Date: Tue, 22 Aug 2023 23:56:26 +1200
Subject: [PATCH 3/4] Preload model support

---
 LLama.Web/Common/LLamaOptions.cs    |  4 +---
 LLama.Web/Common/ModelCacheType.cs  | 10 ++++++++++
 LLama.Web/Program.cs                |  1 +
 LLama.Web/Services/IModelService.cs |  3 +++
 LLama.Web/Services/LoaderService.cs | 22 +++++++++++++++++++
 LLama.Web/Services/ModelService.cs  | 30 +++++++++++++++++++++--------
 LLama.Web/appsettings.json          |  2 +-
 7 files changed, 60 insertions(+), 12 deletions(-)
 create mode 100644 LLama.Web/Common/ModelCacheType.cs
 create mode 100644 LLama.Web/Services/LoaderService.cs

diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs
index ad58676f7..5db4ea959 100644
--- a/LLama.Web/Common/LLamaOptions.cs
+++ b/LLama.Web/Common/LLamaOptions.cs
@@ -1,6 +1,4 @@
-using LLama.Web.Services;
-
-namespace LLama.Web.Common
+namespace LLama.Web.Common
 {
     public class LLamaOptions
     {
diff --git a/LLama.Web/Common/ModelCacheType.cs b/LLama.Web/Common/ModelCacheType.cs
new file mode 100644
index 000000000..71f88b356
--- /dev/null
+++ b/LLama.Web/Common/ModelCacheType.cs
@@ -0,0 +1,10 @@
+namespace LLama.Web.Common
+{
+    public enum ModelCacheType
+    {
+        Single = 0,
+        Multiple = 1,
+        PreloadSingle = 2,
+        PreloadMultiple = 3,
+    }
+}
diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs
index ee63590e0..52eb80d6f 100644
--- a/LLama.Web/Program.cs
+++ b/LLama.Web/Program.cs
@@ -20,6 +20,7 @@ public static void Main(string[] args)
                 .BindConfiguration(nameof(LLamaOptions));
 
             // Services DI
+            builder.Services.AddHostedService<LoaderService>();
             builder.Services.AddSingleton<IModelService, ModelService>();
             builder.Services.AddSingleton<IModelSessionService, ModelSessionService>();
 
diff --git a/LLama.Web/Services/IModelService.cs b/LLama.Web/Services/IModelService.cs
index 9085e0ba5..9f2285109 100644
--- a/LLama.Web/Services/IModelService.cs
+++ b/LLama.Web/Services/IModelService.cs
@@ -7,7 +7,10 @@ public interface IModelService
         Task<LLamaContext> CreateContext(string modelName, string key);
         Task<LLamaContext> GetContext(string modelName, string key);
         Task<LLamaWeights> GetModel(string modelName);
+        Task LoadModels();
         Task<LLamaWeights> LoadModel(ModelOptions modelParams);
+        Task<bool> UnloadModel(string modelName);
+        Task UnloadModels();
         Task<bool> RemoveContext(string modelName, string key);
         Task<LLamaContext> GetOrCreateModelAndContext(string modelName, string key);
     }
diff --git a/LLama.Web/Services/LoaderService.cs b/LLama.Web/Services/LoaderService.cs
new file mode 100644
index 000000000..2eb3a5a66
--- /dev/null
+++ b/LLama.Web/Services/LoaderService.cs
@@ -0,0 +1,22 @@
+namespace LLama.Web.Services
+{
+    public class LoaderService : IHostedService
+    {
+        private readonly IModelService _modelService;
+
+        public LoaderService(IModelService modelService)
+        {
+            _modelService = modelService;
+        }
+
+        public async Task StartAsync(CancellationToken cancellationToken)
+        {
+            await _modelService.LoadModels();
+        }
+
+        public async Task StopAsync(CancellationToken cancellationToken)
+        {
+            await _modelService.UnloadModels();
+        }
+    }
+}
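`ModelCacheType` now packs two independent flags into one enum: how many models may be cached (Single vs Multiple) and whether they load at startup via `LoaderService` (the Preload variants). Hypothetical helper predicates, not part of the PR, that make the dispatch in `LoadModel`/`LoadModels` explicit:

```csharp
namespace LLama.Web.Common
{
    // Hypothetical helpers (not in this PR): decode the two flags that
    // ModelCacheType packs into a single enum value.
    public static class ModelCacheTypeExtensions
    {
        // At most one model may be resident at a time.
        public static bool IsSingle(this ModelCacheType type) =>
            type is ModelCacheType.Single or ModelCacheType.PreloadSingle;

        // Models should be loaded eagerly by LoaderService at startup.
        public static bool IsPreload(this ModelCacheType type) =>
            type is ModelCacheType.PreloadSingle or ModelCacheType.PreloadMultiple;
    }
}
```

With these, the eviction guard in `LoadModel` would read `if (_configuration.ModelCacheType.IsSingle()) await UnloadModels();`, and `LoadModels` could early-return on `!IsPreload()`.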
diff --git a/LLama.Web/Services/ModelService.cs b/LLama.Web/Services/ModelService.cs
index 150167f1f..c08956747 100644
--- a/LLama.Web/Services/ModelService.cs
+++ b/LLama.Web/Services/ModelService.cs
@@ -1,17 +1,10 @@
 using LLama.Web.Common;
-using Microsoft.AspNetCore.DataProtection.KeyManagement;
-using Microsoft.AspNetCore.Mvc.ModelBinding;
 using Microsoft.Extensions.Options;
 using System.Collections.Concurrent;
 using System.Text;
 
 namespace LLama.Web.Services
 {
-    public enum ModelCacheType
-    {
-        Single = 0,
-        Multiple = 1
-    }
-
     public class ModelService : IModelService
     {
         private readonly LLamaOptions _configuration;
@@ -54,7 +47,8 @@ public async Task<LLamaWeights> LoadModel(ModelOptions modelConfig)
             if (_modelInstances.TryGetValue(modelConfig.Name, out LLamaWeights model))
                 return model;
 
-            if (_configuration.ModelCacheType == ModelCacheType.Single)
+            // If in single mode unload any other models
+            if (_configuration.ModelCacheType == ModelCacheType.Single || _configuration.ModelCacheType == ModelCacheType.PreloadSingle)
                 await UnloadModels();
 
             model = LLamaWeights.LoadFromFile(modelConfig);
@@ -64,6 +64,26 @@
+        /// <summary>
+        /// Loads the models.
+        /// </summary>
+        public async Task LoadModels()
+        {
+            if (_configuration.ModelCacheType == ModelCacheType.Single
+                || _configuration.ModelCacheType == ModelCacheType.Multiple)
+                return;
+
+            foreach (var modelConfig in _configuration.Models)
+            {
+                await LoadModel(modelConfig);
+
+                // Only preload the first model if in PreloadSingle mode
+                if (_configuration.ModelCacheType == ModelCacheType.PreloadSingle)
+                    break;
+            }
+        }
+
+
         /// <summary>
         /// Unloads the model.
         /// </summary>
diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json
index 00a5056c0..9148ce0cf 100644
--- a/LLama.Web/appsettings.json
+++ b/LLama.Web/appsettings.json
@@ -7,7 +7,7 @@
   },
   "AllowedHosts": "*",
   "LLamaOptions": {
-    "ModelCacheType": "Multiple",
+    "ModelCacheType": "PreloadSingle",
     "Models": [
       {
         "Name": "WizardLM-7B",

From 8193dfc2d4afc11229930d28ba641f7bfd9be394 Mon Sep 17 00:00:00 2001
From: sa_ddam213
Date: Thu, 24 Aug 2023 08:02:32 +1200
Subject: [PATCH 4/4] Add app switch

---
 LLama.Web/Common/AppType.cs              |  8 ++++
 LLama.Web/Common/LLamaOptions.cs         |  1 +
 LLama.Web/Controllers/ModelController.cs | 27 +++++++++++
 LLama.Web/LLama.Web.csproj               | 12 ++++-
 LLama.Web/Program.cs                     | 61 +++++++++++++++++-------
 LLama.Web/appsettings.json               |  3 +-
 6 files changed, 93 insertions(+), 19 deletions(-)
 create mode 100644 LLama.Web/Common/AppType.cs
 create mode 100644 LLama.Web/Controllers/ModelController.cs

diff --git a/LLama.Web/Common/AppType.cs b/LLama.Web/Common/AppType.cs
new file mode 100644
index 000000000..9697cd895
--- /dev/null
+++ b/LLama.Web/Common/AppType.cs
@@ -0,0 +1,8 @@
+namespace LLama.Web.Common
+{
+    public enum AppType
+    {
+        Web = 0,
+        WebApi = 1
+    }
+}
diff --git a/LLama.Web/Common/LLamaOptions.cs b/LLama.Web/Common/LLamaOptions.cs
index 5db4ea959..95ce1566a 100644
--- a/LLama.Web/Common/LLamaOptions.cs
+++ b/LLama.Web/Common/LLamaOptions.cs
@@ -2,6 +2,7 @@
 {
     public class LLamaOptions
     {
+        public AppType AppType { get; set; }
         public ModelCacheType ModelCacheType { get; set; }
         public List<ModelOptions> Models { get; set; }
 
diff --git a/LLama.Web/Controllers/ModelController.cs b/LLama.Web/Controllers/ModelController.cs
new file mode 100644
index 000000000..39b697eb5
--- /dev/null
+++ b/LLama.Web/Controllers/ModelController.cs
@@ -0,0 +1,27 @@
+using LLama.Web.Common;
+using LLama.Web.Services;
+using Microsoft.AspNetCore.Mvc;
+using Microsoft.Extensions.Options;
+
+namespace LLama.Web.Controllers
+{
+    [ApiController]
+    [Route("[controller]")]
+    public class ModelController : ControllerBase
+    {
+        private readonly LLamaOptions _configuration;
+        private readonly IModelService _modelService;
+
+        public ModelController(IOptions<LLamaOptions> options, IModelService modelService)
+        {
+            _modelService = modelService;
+            _configuration = options.Value;
+        }
+
+        [HttpGet("GetAll")]
+        public async Task<IActionResult> GetModels()
+        {
+            return Ok(_configuration.Models);
+        }
+    }
+}
diff --git a/LLama.Web/LLama.Web.csproj b/LLama.Web/LLama.Web.csproj
index d0e15a62d..e0b6f860c 100644
--- a/LLama.Web/LLama.Web.csproj
+++ b/LLama.Web/LLama.Web.csproj
@@ -7,11 +7,21 @@
 
-
+
+
+
+
+
+
+
+
+
+
 
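The new controller serves the configured `ModelOptions` list at `GET /Model/GetAll` (the `[Route("[controller]")]` template strips the `Controller` suffix). A hedged client sketch; the localhost base address is an assumption, adjust it to your launch profile:

```csharp
using System;
using System.Net.Http;
using System.Threading.Tasks;

class ModelApiDemo
{
    static async Task Main()
    {
        // Assumed base address - check Properties/launchSettings.json for yours.
        using var http = new HttpClient { BaseAddress = new Uri("https://localhost:5001") };

        // Returns the Models array from appsettings.json, serialized with
        // indentation and string enums per the AddJsonOptions setup in the
        // Program.cs changes that follow.
        var json = await http.GetStringAsync("/Model/GetAll");
        Console.WriteLine(json);
    }
}
```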
diff --git a/LLama.Web/Program.cs b/LLama.Web/Program.cs
index 52eb80d6f..b74c027fb 100644
--- a/LLama.Web/Program.cs
+++ b/LLama.Web/Program.cs
@@ -1,6 +1,7 @@
 using LLama.Web.Common;
 using LLama.Web.Hubs;
 using LLama.Web.Services;
+using System.Text.Json.Serialization;
 
 namespace LLama.Web
 {
@@ -14,38 +15,64 @@ public static void Main(string[] args)
             builder.Services.AddRazorPages();
             builder.Services.AddSignalR();
 
-            // Load InteractiveOptions
+            // Load LLamaOptions
             builder.Services.AddOptions<LLamaOptions>()
-                .PostConfigure(x => x.Initialize())
-                .BindConfiguration(nameof(LLamaOptions));
+                .PostConfigure(options => options.Initialize())
+                .BindConfiguration(nameof(LLamaOptions));
 
             // Services DI
             builder.Services.AddHostedService<LoaderService>();
             builder.Services.AddSingleton<IModelService, ModelService>();
             builder.Services.AddSingleton<IModelSessionService, ModelSessionService>();
 
-            var app = builder.Build();
 
-            // Configure the HTTP request pipeline.
-            if (!app.Environment.IsDevelopment())
+            var configuration = builder.Configuration.GetSection(nameof(LLamaOptions)).Get<LLamaOptions>();
+            if (configuration.AppType == AppType.Web)
             {
-                app.UseExceptionHandler("/Error");
-                // The default HSTS value is 30 days. You may want to change this for production scenarios, see https://aka.ms/aspnetcore-hsts.
-                app.UseHsts();
-            }
+                var app = builder.Build();
 
-            app.UseHttpsRedirection();
-            app.UseStaticFiles();
+                // Configure the HTTP request pipeline.
+                if (!app.Environment.IsDevelopment())
+                {
+                    app.UseExceptionHandler("/Error");
+                    // The default HSTS value is 30 days. You may want to change this for production scenarios, see https://aka.ms/aspnetcore-hsts.
+                    app.UseHsts();
+                }
 
-            app.UseRouting();
+                app.UseHttpsRedirection();
+                app.UseStaticFiles();
+                app.UseRouting();
+                app.UseAuthorization();
+                app.MapRazorPages();
+                app.MapHub<SessionConnectionHub>(nameof(SessionConnectionHub));
+                app.Run();
+            }
+            else if (configuration.AppType == AppType.WebApi)
+            {
 
-            app.UseAuthorization();
+                // Add Controllers
+                builder.Services.AddControllers().AddJsonOptions(options =>
+                {
+                    options.JsonSerializerOptions.WriteIndented = true;
+                    options.JsonSerializerOptions.Converters.Add(new JsonStringEnumConverter());
+                });
 
-            app.MapRazorPages();
+                // Add Swagger/OpenAPI https://aka.ms/aspnetcore/swashbuckle
+                builder.Services.AddEndpointsApiExplorer();
+                builder.Services.AddSwaggerGen(options => options.UseInlineDefinitionsForEnums());
 
-            app.MapHub(nameof(SessionConnectionHub));
+                var app = builder.Build();
 
-            app.Run();
+                app.UseSwagger();
+                app.UseSwaggerUI(options => {
+                    options.SwaggerEndpoint("/swagger/v1/swagger.json", "v1");
+                    options.RoutePrefix = string.Empty;
+                });
+
+                app.UseAuthorization();
+                app.MapControllers();
+                app.Run();
+            }
         }
     }
 }
\ No newline at end of file
diff --git a/LLama.Web/appsettings.json b/LLama.Web/appsettings.json
index 6d2217a43..36f2018aa 100644
--- a/LLama.Web/appsettings.json
+++ b/LLama.Web/appsettings.json
@@ -7,7 +7,8 @@
   },
   "AllowedHosts": "*",
   "LLamaOptions": {
-    "ModelCacheType": "PreloadSingle",
+    "AppType": "WebApi",
+    "ModelCacheType": "Single",
     "Models": [
       {
         "Name": "WizardLM-7B",
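One wiring detail in the final Program.cs deserves a note: the `AppType` branch re-binds `LLamaOptions` straight from configuration, so that startup copy never passes through the `PostConfigure(options => options.Initialize())` hook that DI-resolved options do. Harmless while only `AppType` is read, but worth isolating if `Initialize()` ever performs real normalization. A hypothetical helper sketch (not in the PR) that makes the raw bind explicit:

```csharp
using LLama.Web.Common;
using Microsoft.Extensions.Configuration;

public static class StartupConfig
{
    // Hypothetical helper: reads only the pipeline switch. Note this raw bind
    // skips PostConfigure, so Initialize() has NOT run on the returned
    // instance - read nothing from it except AppType.
    public static AppType ReadAppType(IConfiguration configuration)
    {
        var options = configuration.GetSection(nameof(LLamaOptions)).Get<LLamaOptions>();
        return options?.AppType ?? AppType.Web;
    }
}
```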