-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathAiManager.cs
More file actions
363 lines (303 loc) · 12.6 KB
/
AiManager.cs
File metadata and controls
363 lines (303 loc) · 12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
using System.Security.Cryptography;
using System.Text;
using System.Text.Json;
using LLama;
using LLama.Common;
using LLama.Sampling;
namespace NoteUI;
/// <summary>
/// Central manager for the app's AI features: JSON settings persistence,
/// DPAPI-encrypted API-key storage, local GGUF model download/management,
/// and local inference via LLamaSharp. Cloud providers implement
/// <see cref="ICloudAiProvider"/> (declared elsewhere in the project).
/// NOTE(review): DPAPI (<see cref="ProtectedData"/>) makes this Windows-only.
/// </summary>
public class AiManager
{
    // ── Data models ──

    /// <summary>Identifier/display-name pair for a cloud-hosted model.</summary>
    public record ModelInfo(string Id, string Name);

    /// <summary>One chat turn. Role is a template role string (e.g. "user", "assistant").</summary>
    public record ChatMessage(string Role, string Content, DateTime Timestamp);

    /// <summary>
    /// A GGUF model file, either one of <see cref="PredefinedModels"/> (hosted on
    /// Hugging Face) or a user-imported file (IsPredefined = false, empty Repo).
    /// </summary>
    public record LocalModel(string Name, string Repo, string FileName, string Size, bool IsPredefined = true)
    {
        /// <summary>Direct download URL on Hugging Face for this file.</summary>
        public string DownloadUrl => $"https://huggingface.co/{Repo}/resolve/main/{FileName}";

        /// <summary>Where the file lives (or would live) under <see cref="ModelsDir"/>.</summary>
        public string LocalPath => Path.Combine(ModelsDir, FileName);

        /// <summary>True when the file is already present on disk.</summary>
        public bool IsInstalled => File.Exists(LocalPath);
    }

    /// <summary>User-tunable AI settings, persisted as indented JSON.</summary>
    public class AiSettings
    {
        public bool IsEnabled { get; set; } = true;
        public float Temperature { get; set; } = 0.7f;
        public int MaxTokens { get; set; } = 2048;
        public int ContextSize { get; set; } = 2048;
        // Number of transformer layers offloaded to the GPU; LoadModelAsync
        // falls back to 0 (CPU only) if loading with this value fails.
        public int GpuLayers { get; set; } = 20;
        public string SystemPrompt { get; set; } = "Tu es un assistant utile et concis.";
        // Last-used provider/model, restored on startup by the UI.
        public string LastProviderId { get; set; } = "";
        public string LastModelId { get; set; } = "";
        public string LastLocalModelFileName { get; set; } = "";
        // User overrides for the built-in prompts, keyed by PromptDefinitions key.
        public Dictionary<string, string> Prompts { get; set; } = new();
        public List<CustomPrompt> CustomPrompts { get; set; } = new();
    }

    /// <summary>A user-defined quick-action prompt with a short random id.</summary>
    public class CustomPrompt
    {
        public string Id { get; set; } = Guid.NewGuid().ToString("N")[..8];
        public string Title { get; set; } = "";
        public string Instruction { get; set; } = "";
    }

    /// <summary>Built-in prompt keys and their default English instructions.</summary>
    public static readonly (string Key, string DefaultPrompt)[] PromptDefinitions =
    [
        ("ai_improve_writing", "Improve writing quality, clarity, and flow while preserving meaning."),
        ("ai_fix_grammar_spelling", "Correct grammar, spelling, punctuation, and agreement mistakes."),
        ("ai_tone_professional", "Rewrite in a professional and formal tone."),
        ("ai_tone_friendly", "Rewrite in a friendly and approachable tone."),
        ("ai_tone_concise", "Rewrite in a concise tone with shorter, clearer sentences."),
    ];

    /// <summary>
    /// Resolves a prompt: a non-blank user override wins, otherwise the built-in
    /// default, otherwise the key itself (so callers never receive null).
    /// </summary>
    public string GetPrompt(string key)
    {
        if (Settings.Prompts.TryGetValue(key, out var custom) && !string.IsNullOrWhiteSpace(custom))
            return custom;
        return GetDefaultPrompt(key);
    }

    /// <summary>Built-in default for <paramref name="key"/>; the key itself when unknown.</summary>
    public static string GetDefaultPrompt(string key)
    {
        foreach (var (k, def) in PromptDefinitions)
            if (k == key)
                return def;
        return key;
    }

    // ── Paths ──
    private static readonly string BaseDir =
        Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.LocalApplicationData), "NoteUI");
    public static readonly string ModelsDir = Path.Combine(BaseDir, "ai_models");
    private static readonly string SettingsPath = Path.Combine(BaseDir, "ai_settings.json");
    private static readonly string KeysPath = Path.Combine(BaseDir, "ai_keys.dat");

    // Cached serializer options (CA1869: do not allocate per call).
    private static readonly JsonSerializerOptions SettingsJsonOptions = new() { WriteIndented = true };

    // Single shared HttpClient for model downloads (one instance per process,
    // per HttpClient usage guidelines; a new client per request risks socket
    // exhaustion). The long timeout accommodates multi-GB GGUF downloads.
    private static readonly HttpClient DownloadClient = new() { Timeout = TimeSpan.FromHours(2) };

    // ── State ──
    public AiSettings Settings { get; private set; } = new();
    private Dictionary<string, string> _apiKeys = new();

    // Local model state
    private LLamaWeights? _model;
    private ModelParams? _modelParams;
    private string? _loadedModelPath;
    private CancellationTokenSource? _inferenceCts;

    // ── Providers ──
    public static readonly ICloudAiProvider[] Providers =
    [
        new OpenAiProvider(),
        new ClaudeProvider(),
        new GeminiProvider(),
    ];

    // ── Predefined local models ──
    public static readonly LocalModel[] PredefinedModels =
    [
        new("Gemma 2 2B", "bartowski/gemma-2-2b-it-GGUF", "gemma-2-2b-it-Q4_K_M.gguf", "~1.6 GB"),
        new("Llama 3.2 3B", "bartowski/Llama-3.2-3B-Instruct-GGUF", "Llama-3.2-3B-Instruct-Q4_K_M.gguf", "~2 GB"),
        new("Llama 3.1 8B", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", "Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf", "~4.9 GB"),
    ];

    // ── Init ──

    /// <summary>Creates the models directory and loads settings and API keys from disk.</summary>
    public void Load()
    {
        Directory.CreateDirectory(ModelsDir);
        LoadSettings();
        LoadApiKeys();
    }

    // ── Settings persistence ──

    // Best-effort load: a missing or corrupt settings file yields defaults.
    private void LoadSettings()
    {
        try
        {
            if (File.Exists(SettingsPath))
            {
                var json = File.ReadAllText(SettingsPath);
                Settings = JsonSerializer.Deserialize<AiSettings>(json) ?? new();
            }
        }
        catch
        {
            Settings = new();
        }
    }

    /// <summary>Persists <see cref="Settings"/> to disk; failures are silently ignored (best-effort).</summary>
    public void SaveSettings()
    {
        try
        {
            Directory.CreateDirectory(BaseDir);
            var json = JsonSerializer.Serialize(Settings, SettingsJsonOptions);
            File.WriteAllText(SettingsPath, json);
        }
        catch
        {
            // Best-effort: losing one save must not crash the UI.
        }
    }

    // ── API Keys (DPAPI) ──

    /// <summary>Returns the stored key for a provider, or "" when none is set.</summary>
    public string GetApiKey(string providerId) =>
        _apiKeys.TryGetValue(providerId, out var key) ? key : "";

    /// <summary>Stores (or, for blank input, removes) a provider key and persists immediately.</summary>
    public void SetApiKey(string providerId, string key)
    {
        if (string.IsNullOrWhiteSpace(key))
            _apiKeys.Remove(providerId);
        else
            _apiKeys[providerId] = key;
        SaveApiKeys();
    }

    /// <summary>True when a non-blank key is stored for the provider.</summary>
    public bool HasApiKey(string providerId) =>
        _apiKeys.TryGetValue(providerId, out var k) && !string.IsNullOrWhiteSpace(k);

    // Keys are encrypted with DPAPI scoped to the current Windows user.
    // A missing or undecryptable file (e.g. copied from another machine)
    // resets to an empty key store.
    private void LoadApiKeys()
    {
        try
        {
            if (!File.Exists(KeysPath)) return;
            var encrypted = File.ReadAllBytes(KeysPath);
            var decrypted = ProtectedData.Unprotect(encrypted, null, DataProtectionScope.CurrentUser);
            var json = Encoding.UTF8.GetString(decrypted);
            _apiKeys = JsonSerializer.Deserialize<Dictionary<string, string>>(json) ?? new();
        }
        catch
        {
            _apiKeys = new();
        }
    }

    private void SaveApiKeys()
    {
        try
        {
            Directory.CreateDirectory(BaseDir);
            var json = JsonSerializer.Serialize(_apiKeys);
            var bytes = Encoding.UTF8.GetBytes(json);
            var encrypted = ProtectedData.Protect(bytes, null, DataProtectionScope.CurrentUser);
            File.WriteAllBytes(KeysPath, encrypted);
        }
        catch
        {
            // Best-effort: key save failure is non-fatal.
        }
    }

    // ── Provider helpers ──

    /// <summary>Finds a cloud provider by id, or null when unknown.</summary>
    public ICloudAiProvider? GetProvider(string id) =>
        Providers.FirstOrDefault(p => p.Id == id);

    // ── Local model management ──

    /// <summary>
    /// Scans <see cref="ModelsDir"/> for *.gguf files. Known files map to their
    /// <see cref="PredefinedModels"/> entry; unknown files become ad-hoc entries
    /// with a size string computed from the file length.
    /// </summary>
    public List<LocalModel> GetInstalledModels()
    {
        if (!Directory.Exists(ModelsDir)) return [];
        var installed = new List<LocalModel>();
        foreach (var file in Directory.GetFiles(ModelsDir, "*.gguf"))
        {
            var fileName = Path.GetFileName(file);
            var predefined = PredefinedModels.FirstOrDefault(m => m.FileName == fileName);
            if (predefined != null)
            {
                installed.Add(predefined);
            }
            else
            {
                var sizeMb = new FileInfo(file).Length / (1024.0 * 1024.0);
                var sizeStr = sizeMb > 1024 ? $"~{sizeMb / 1024:F1} GB" : $"~{sizeMb:F0} MB";
                installed.Add(new LocalModel(fileName, "", fileName, sizeStr, false));
            }
        }
        return installed;
    }

    /// <summary>Downloads a predefined model to its <see cref="LocalModel.LocalPath"/>.</summary>
    public Task DownloadModelAsync(LocalModel model, IProgress<(long downloaded, long? total)> progress, CancellationToken ct)
        => DownloadFileAsync(model.DownloadUrl, model.LocalPath, progress, ct);

    /// <summary>Downloads an arbitrary URL into <see cref="ModelsDir"/> under <paramref name="fileName"/>.</summary>
    public Task DownloadFromUrlAsync(string url, string fileName, IProgress<(long downloaded, long? total)> progress, CancellationToken ct)
        => DownloadFileAsync(url, Path.Combine(ModelsDir, fileName), progress, ct);

    // Streams the response to "<dest>.tmp" and atomically renames on success,
    // so a partially downloaded file is never mistaken for an installed model.
    // Progress is throttled to ~20 reports/second; "total" is null when the
    // server sends no Content-Length. On failure/cancellation the temp file
    // is removed (fix: previously it was leaked on disk).
    private static async Task DownloadFileAsync(string url, string destPath, IProgress<(long downloaded, long? total)> progress, CancellationToken ct)
    {
        Directory.CreateDirectory(Path.GetDirectoryName(destPath)!);
        var tempPath = destPath + ".tmp";
        try
        {
            using var response = await DownloadClient.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, ct).ConfigureAwait(false);
            response.EnsureSuccessStatusCode();
            var totalBytes = response.Content.Headers.ContentLength;

            await using (var stream = await response.Content.ReadAsStreamAsync(ct).ConfigureAwait(false))
            await using (var fileStream = new FileStream(tempPath, FileMode.Create, FileAccess.Write, FileShare.None, 81920))
            {
                var buffer = new byte[81920];
                long totalRead = 0;
                int bytesRead;
                var lastReport = DateTime.UtcNow;
                while ((bytesRead = await stream.ReadAsync(buffer, ct).ConfigureAwait(false)) > 0)
                {
                    await fileStream.WriteAsync(buffer.AsMemory(0, bytesRead), ct).ConfigureAwait(false);
                    totalRead += bytesRead;
                    var now = DateTime.UtcNow;
                    if ((now - lastReport).TotalMilliseconds >= 50)
                    {
                        lastReport = now;
                        progress.Report((totalRead, totalBytes));
                    }
                }
                // Final report to ensure 100%
                progress.Report((totalRead, totalBytes));
            } // fileStream must be disposed before the rename below.

            if (File.Exists(destPath)) File.Delete(destPath);
            File.Move(tempPath, destPath);
        }
        catch
        {
            TryDeleteFile(tempPath);
            throw;
        }
    }

    // Best-effort deletion; ignores sharing/permission errors.
    private static void TryDeleteFile(string path)
    {
        try
        {
            if (File.Exists(path)) File.Delete(path);
        }
        catch
        {
        }
    }

    /// <summary>
    /// Deletes a model file from disk, unloading it first if it is the one
    /// currently loaded. Paths are compared case-insensitively (Windows file system).
    /// </summary>
    public void DeleteModel(string fileName)
    {
        var path = Path.Combine(ModelsDir, fileName);
        if (string.Equals(_loadedModelPath, path, StringComparison.OrdinalIgnoreCase))
            UnloadModel();
        if (File.Exists(path)) File.Delete(path);
    }

    // ── Local inference ──

    /// <summary>
    /// Loads a GGUF model from <see cref="ModelsDir"/>. No-op when the same model
    /// is already loaded. First attempt uses <see cref="AiSettings.GpuLayers"/>;
    /// if that fails (e.g. not enough VRAM) it retries CPU-only (0 GPU layers).
    /// A CPU-only failure propagates to the caller.
    /// </summary>
    public async Task LoadModelAsync(string fileName)
    {
        var path = Path.Combine(ModelsDir, fileName);
        if (_loadedModelPath == path && _model != null) return;
        UnloadModel();
        await Task.Run(() =>
        {
            var gpuLayers = Settings.GpuLayers;
            for (var attempt = 0; attempt < 2; attempt++)
            {
                try
                {
                    var parameters = new ModelParams(path)
                    {
                        ContextSize = (uint)Settings.ContextSize,
                        GpuLayerCount = attempt == 0 ? gpuLayers : 0,
                    };
                    _model = LLamaWeights.LoadFromFile(parameters);
                    _modelParams = parameters;
                    _loadedModelPath = path;
                    return;
                }
                catch when (attempt == 0 && gpuLayers > 0)
                {
                    // GPU load failed: clean up and retry CPU-only.
                    _model?.Dispose();
                    _model = null;
                }
            }
        });
    }

    /// <summary>
    /// Streams tokens for one chat completion from the loaded local model.
    /// Builds the prompt with the model's own chat template (system prompt,
    /// history, then the new user message). Yields nothing when no model is
    /// loaded. Cancellable via <paramref name="ct"/> or <see cref="StopInference"/>.
    /// </summary>
    public async IAsyncEnumerable<string> ChatLocalAsync(string userMessage, List<ChatMessage> history,
        [System.Runtime.CompilerServices.EnumeratorCancellation] CancellationToken ct = default)
    {
        if (_model == null || _modelParams == null) yield break;
        var template = new LLamaTemplate(_model);
        template.Add("system", Settings.SystemPrompt);
        foreach (var msg in history)
            template.Add(msg.Role, msg.Content);
        template.Add("user", userMessage);
        var prompt = Encoding.UTF8.GetString(template.Apply().ToArray());
        var inferenceParams = new InferenceParams
        {
            MaxTokens = Settings.MaxTokens,
            // Stop sequences covering the common chat formats (Llama 3, ChatML,
            // Mistral [INST], Gemma) so generation ends at the turn boundary.
            AntiPrompts = [
                "<|eot_id|>", "<|start_header_id|>",
                "<|end|>", "<|assistant|>", "<|user|>", "<|system|>",
                "[/INST]", "</s>",
                "<|im_end|>", "<|endoftext|>",
                "<end_of_turn>",
            ],
            SamplingPipeline = new DefaultSamplingPipeline { Temperature = Settings.Temperature },
        };
        // Fix: dispose the previous CTS before replacing it (it was leaked before).
        _inferenceCts?.Dispose();
        _inferenceCts = CancellationTokenSource.CreateLinkedTokenSource(ct);
        using var context = _model.CreateContext(_modelParams);
        var executor = new StatelessExecutor(_model, _modelParams);
        await foreach (var token in executor.InferAsync(prompt, inferenceParams, _inferenceCts.Token))
        {
            yield return token;
        }
    }

    /// <summary>Requests cancellation of the in-flight local inference, if any.</summary>
    public void StopInference()
    {
        try
        {
            _inferenceCts?.Cancel();
        }
        catch (ObjectDisposedException)
        {
            // A new chat already replaced and disposed this CTS; nothing to stop.
        }
    }

    /// <summary>Releases the loaded model's native memory and clears the loaded state.</summary>
    public void UnloadModel()
    {
        _model?.Dispose();
        _model = null;
        _modelParams = null;
        _loadedModelPath = null;
    }

    /// <summary>Convenience mirror of <see cref="AiSettings.IsEnabled"/>.</summary>
    public bool IsEnabled => Settings.IsEnabled;

    /// <summary>Turns AI off: unloads the model, clears provider/model selection, persists.</summary>
    public void DisableAll()
    {
        UnloadModel();
        Settings.IsEnabled = false;
        Settings.LastProviderId = "";
        Settings.LastModelId = "";
        Settings.LastLocalModelFileName = "";
        SaveSettings();
    }

    /// <summary>True while a local model is loaded in memory.</summary>
    public bool IsModelLoaded => _model != null;

    /// <summary>File name of the loaded model, or null when none is loaded.</summary>
    public string? LoadedModelFileName => _loadedModelPath != null ? Path.GetFileName(_loadedModelPath) : null;
}