diff --git a/LLama/Exceptions/RuntimeError.cs b/LLama/Exceptions/RuntimeError.cs index bf20543ca..9b3dbd242 100644 --- a/LLama/Exceptions/RuntimeError.cs +++ b/LLama/Exceptions/RuntimeError.cs @@ -77,6 +77,21 @@ public MissingTemplateException(string message) } } +/// +/// `llama_decode` return a non-zero status code +/// +public class TemplateNotFoundException + : RuntimeError +{ + /// + public TemplateNotFoundException(string name) + : base($"llama_model_chat_template failed: Tried to retrieve template '{name}' but it couldn't be found.\n" + + $"This might mean that the model was exported incorrectly, or that this is a base model that contains no templates.\n" + + $"This exception can be disabled by passing 'strict=false' as a parameter when retrieving the template.") + { + } +} + /// /// `llama_get_logits_ith` returned null, indicating that the index was invalid /// diff --git a/LLama/LLamaTemplate.cs b/LLama/LLamaTemplate.cs index 73abdb919..ed18778d7 100644 --- a/LLama/LLamaTemplate.cs +++ b/LLama/LLamaTemplate.cs @@ -105,19 +105,21 @@ public bool AddAssistant /// /// Construct a new template, using the default model template /// - /// - /// - public LLamaTemplate(SafeLlamaModelHandle model, string? name = null) - : this(model.GetTemplate(name)) + /// The native handle of the loaded model. + /// The name of the template, in case there are many or differently named. Set to 'null' for the default behaviour of finding an appropriate match. + /// Setting this to true will cause the call to throw if no valid templates are found. + public LLamaTemplate(SafeLlamaModelHandle model, string? name = null, bool strict = true) + : this(model.GetTemplate(name, strict)) { } /// /// Construct a new template, using the default model template /// - /// - public LLamaTemplate(LLamaWeights weights) - : this(weights.NativeHandle) + /// The handle of the loaded model's weights. + /// Setting this to true will cause the call to throw if no valid templates are found. + public LLamaTemplate(LLamaWeights weights, bool strict = true) + : this(weights.NativeHandle, strict: strict) { } diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 9439c2bb3..0fd39176b 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -603,15 +603,22 @@ internal IReadOnlyDictionary ReadMetadata() /// Get the default chat template. Returns nullptr if not available /// If name is NULL, returns the default chat template /// - /// + /// The name of the template, in case there are many or differently named. Set to 'null' for the default behaviour of finding an appropriate match. + /// Setting this to true will cause the call to throw if no valid templates are found. /// - public string? GetTemplate(string? name = null) + public string? GetTemplate(string? name = null, bool strict = true) { unsafe { var bytesPtr = llama_model_chat_template(this, name); if (bytesPtr == null) - return null; + { + if (strict) + throw new TemplateNotFoundException(name ?? "default template"); + else + return null; + + } // Find null terminator var spanBytes = new Span(bytesPtr, int.MaxValue);