From 3194a97b883aea00d10584d123c1689c5060b558 Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Tue, 13 May 2025 11:03:56 +0200 Subject: [PATCH 01/15] added risk dashboard - button at the bottom + ui in new tab - changed favicon - changed font --- .gitignore | 2 + backend-agent/main.py | 136 +++++++- frontend/src/app/app-routing.module.ts | 16 +- frontend/src/app/app.component.html | 3 +- .../src/app/chatzone/chatzone.component.css | 16 +- .../src/app/chatzone/chatzone.component.html | 46 ++- .../src/app/chatzone/chatzone.component.ts | 134 ++++---- .../src/app/heatmap/heatmap.component.css | 80 +++++ .../src/app/heatmap/heatmap.component.html | 36 +++ frontend/src/app/heatmap/heatmap.component.ts | 291 ++++++++++++++++++ frontend/src/app/utils/utils.ts | 40 +++ frontend/src/styles.css | 61 +++- frontend/tsconfig.json | 2 + 13 files changed, 776 insertions(+), 87 deletions(-) create mode 100644 frontend/src/app/heatmap/heatmap.component.css create mode 100644 frontend/src/app/heatmap/heatmap.component.html create mode 100644 frontend/src/app/heatmap/heatmap.component.ts create mode 100644 frontend/src/app/utils/utils.ts diff --git a/.gitignore b/.gitignore index 0887fc1..f9a5ba7 100644 --- a/.gitignore +++ b/.gitignore @@ -89,6 +89,8 @@ venv/ ENV/ env.bak/ venv.bak/ +venv310 +cache # Spyder project settings .spyderproject diff --git a/backend-agent/main.py b/backend-agent/main.py index 7f6e80d..44fc501 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -1,3 +1,4 @@ +import csv import json import os @@ -5,6 +6,7 @@ from flask import Flask, abort, jsonify, request, send_file from flask_cors import CORS from flask_sock import Sock +from werkzeug.utils import secure_filename if not os.getenv('DISABLE_AGENT'): from agent import agent @@ -40,15 +42,19 @@ } if langfuse_handler else { 'callbacks': [status_callback_handler]} +# Set up the upload folder dynamically +UPLOAD_FOLDER = './uploads' # You can change this to a different path if needed +ALLOWED_EXTENSIONS = {'csv'} +app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER + +# Create the upload folder if it doesn't exist +if not os.path.exists(app.config['UPLOAD_FOLDER']): + os.makedirs(app.config['UPLOAD_FOLDER']) + def send_intro(sock): """ Sends the intro via the websocket connection. - - The intro is meant as a short tutorial on how to use the agent. - Also it includes meaningful suggestions for prompts that should - result in predictable behavior for the agent, e.g. - "Start the vulnerability scan". """ with open('data/intro.txt', 'r') as f: intro = f.read() @@ -60,18 +66,8 @@ def query_agent(sock): """ Websocket route for the frontend to send prompts to the agent and receive responses as well as status updates. - - Messages received are in this JSON format: - - { - "type":"message", - "data":"Start the vulnerability scan", - "key":"secretapikey" - } - """ status.sock = sock - # Intro is sent after connecting successfully send_intro(sock) while True: data_raw = sock.receive() @@ -129,6 +125,116 @@ def check_health(): return jsonify({'status': 'ok'}) +def allowed_file(filename): + return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + + +@app.route('/api/upload-csv', methods=['POST']) +def upload_csv(): + if 'file' not in request.files: + return jsonify({'error': 'No file part'}), 400 + + file = request.files['file'] + if file.filename == '': + return jsonify({'error': 'No file selected'}), 400 + + if file and allowed_file(file.filename): + filename = secure_filename(file.filename) + path = os.path.join(app.config['UPLOAD_FOLDER'], filename) + file.save(path) + + # Optional: parse CSV immediately + try: + with open(path, newline='') as f: + reader = csv.DictReader(f) + data = list(reader) + + return jsonify({'message': 'CSV uploaded successfully', 'data': data}) + + except Exception as e: + return jsonify({'error': f'Error reading CSV: {str(e)}'}), 500 + + return jsonify({'error': 'Invalid file type'}), 400 + + +# Endpoint to fetch all the vendors from the uploaded CSV +@app.route('/api/vendors', methods=['GET']) +def get_vendors(): + # Check if CSV file exists + error_response = check_csv_exists('STARS_RESULTS.csv') + if error_response: + print("❌ CSV not found or error from check_csv_exists") + return error_response + + try: + file_path = os.path.join(app.config['UPLOAD_FOLDER'], 'STARS_RESULTS.csv') + print(f"📄 Reading CSV from: {file_path}") + with open(file_path, mode='r') as f: + reader = csv.DictReader(f) + data = list(reader) + # Extract unique vendors + vendors = list(set([model['vendor'] for model in data if 'vendor' in model])) + return jsonify(vendors) + + except Exception as e: + print(f"🔥 Exception occurred: {str(e)}") # DEBUG PRINT + return jsonify({'error': f'Error reading vendors from CSV: {str(e)}'}), 500 + + +# Endpoint to fetch heatmap data from the uploaded CSV +@app.route('/api/heatmap', methods=['GET']) +def get_heatmap(): + file_path = os.path.join(app.config['UPLOAD_FOLDER'], 'STARS_RESULTS.csv') # Use dynamic upload folder path + try: + with open(file_path, mode='r') as f: + reader = csv.DictReader(f) + data = list(reader) + + return jsonify(data) + + except Exception as e: + return jsonify({'error': f'Error reading heatmap data from CSV: {str(e)}'}), 500 + + +# Endpoint to fetch heatmap data filtered by vendor from the uploaded CSV +@app.route('/api/heatmap/', methods=['GET']) +def get_filtered_heatmap(name): + file_path = os.path.join(app.config['UPLOAD_FOLDER'], 'STARS_RESULTS.csv') # Use dynamic upload folder path + try: + with open(file_path, mode='r') as f: + reader = csv.DictReader(f) + data = list(reader) + + # Filter data by vendor name + filtered_data = [model for model in data if model['vendor'].lower() == name.lower()] + return jsonify(filtered_data) + + except Exception as e: + return jsonify({'error': f'Error reading filtered heatmap data from CSV: {str(e)}'}), 500 + + +# Endpoint to fetch all attacks from the uploaded CSV +@app.route('/api/attacks', methods=['GET']) +def get_attacks(): + file_path = os.path.join(app.config['UPLOAD_FOLDER'], 'attacks.csv') # Use dynamic upload folder path + try: + with open(file_path, mode='r') as f: + reader = csv.DictReader(f) + data = list(reader) + + return jsonify(data) + + except Exception as e: + return jsonify({'error': f'Error reading attacks data from CSV: {str(e)}'}), 500 + + +def check_csv_exists(file_name): + file_path = os.path.join(app.config['UPLOAD_FOLDER'], file_name) + if not os.path.exists(file_path): + return jsonify({'error': f'{file_name} not found. Please upload the file first.'}), 404 + return None # No error, file exists + + if __name__ == '__main__': if not os.getenv('API_KEY'): print('No API key is set! Access is unrestricted.') diff --git a/frontend/src/app/app-routing.module.ts b/frontend/src/app/app-routing.module.ts index 0297262..9adec64 100755 --- a/frontend/src/app/app-routing.module.ts +++ b/frontend/src/app/app-routing.module.ts @@ -1,10 +1,16 @@ -import { NgModule } from '@angular/core'; -import { RouterModule, Routes } from '@angular/router'; +import {RouterModule, Routes} from '@angular/router'; -const routes: Routes = []; +import {ChatzoneComponent} from './chatzone/chatzone.component'; +import {HeatmapComponent} from './heatmap/heatmap.component'; +import {NgModule} from '@angular/core'; + +const routes: Routes = [ + {path: '', component: ChatzoneComponent}, + {path: 'heatmap', component: HeatmapComponent}, +]; @NgModule({ imports: [RouterModule.forRoot(routes)], - exports: [RouterModule] + exports: [RouterModule], }) -export class AppRoutingModule { } +export class AppRoutingModule {} diff --git a/frontend/src/app/app.component.html b/frontend/src/app/app.component.html index d913607..2de8798 100755 --- a/frontend/src/app/app.component.html +++ b/frontend/src/app/app.component.html @@ -1 +1,2 @@ - \ No newline at end of file + + \ No newline at end of file diff --git a/frontend/src/app/chatzone/chatzone.component.css b/frontend/src/app/chatzone/chatzone.component.css index b75e95f..8512c4a 100644 --- a/frontend/src/app/chatzone/chatzone.component.css +++ b/frontend/src/app/chatzone/chatzone.component.css @@ -19,6 +19,10 @@ } .status-report-container { + display: flex; + flex-direction: column; + /* height: 100vh; */ + /* padding: 1rem; */ max-width: 400px; width: 20%; } @@ -31,11 +35,17 @@ overflow-y: scroll; } +.buttons-wrapper { + margin-top: auto; /* pousse les boutons en bas */ + display: flex; + flex-direction: column; + gap: 1rem; /* espace entre les boutons */ +} + .title { justify-content: center; color: #3c226f; margin: auto; - font-family: AmericanTypewriter; background-color: unset; } @@ -167,6 +177,10 @@ mat-tab-group { color: gray; } +.left-panel-button { + width: 100%; +} + /** Generic classes **/ diff --git a/frontend/src/app/chatzone/chatzone.component.html b/frontend/src/app/chatzone/chatzone.component.html index 8271a1f..f8d7063 100644 --- a/frontend/src/app/chatzone/chatzone.component.html +++ b/frontend/src/app/chatzone/chatzone.component.html @@ -19,8 +19,14 @@ -
- +
+
+ +
+
+ + +
@@ -105,3 +111,39 @@

Vulnerability Report

+ + + + + + diff --git a/frontend/src/app/chatzone/chatzone.component.ts b/frontend/src/app/chatzone/chatzone.component.ts index 20abee7..bf24e35 100644 --- a/frontend/src/app/chatzone/chatzone.component.ts +++ b/frontend/src/app/chatzone/chatzone.component.ts @@ -1,15 +1,16 @@ -import { Component, ViewChildren, QueryList, ElementRef, AfterViewInit, AfterViewChecked } from '@angular/core'; -import { ChatItem, Message, ReportCard, VulnerabilityReportCard } from '../types/ChatItem'; -import { WebSocketService } from '../services/web-socket.service'; -import { Step, Status } from '../types/Step'; -import { APIResponse, ReportItem } from '../types/API'; -import { VulnerabilityInfoService } from '../services/vulnerability-information.service'; +import {APIResponse, ReportItem} from '../types/API'; +import {AfterViewChecked, AfterViewInit, Component, ElementRef, QueryList, ViewChildren} from '@angular/core'; +import {ChatItem, Message, ReportCard, VulnerabilityReportCard} from '../types/ChatItem'; +import {Status, Step} from '../types/Step'; + +import {VulnerabilityInfoService} from '../services/vulnerability-information.service'; +import {WebSocketService} from '../services/web-socket.service'; @Component({ selector: 'app-chatzone', templateUrl: './chatzone.component.html', styleUrls: ['./chatzone.component.css'], - standalone: false + standalone: false, }) export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { chatItems: ChatItem[]; @@ -21,26 +22,27 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { progress: number | undefined; constructor(private ws: WebSocketService, private vis: VulnerabilityInfoService) { this.inputValue = ''; - this.apiKey = localStorage.getItem("key") || ""; + this.apiKey = localStorage.getItem('key') || ''; this.errorMessage = ''; this.chatItems = []; this.steps = []; this.progress = undefined; - this.ws.webSocket$ - .subscribe({ - next: (value: any) => { // eslint-disable-line @typescript-eslint/no-explicit-any - this.handleWSMessage(value as APIResponse); - }, - error: (error: any) => { // eslint-disable-line @typescript-eslint/no-explicit-any - console.log(error); - if (error?.type != "close") { // Close is already handled via the isConnected call - this.errorMessage = error; - } - }, - complete: () => alert("Connection to server closed.") - } - ); + this.ws.webSocket$.subscribe({ + next: (value: any) => { + // eslint-disable-line @typescript-eslint/no-explicit-any + this.handleWSMessage(value as APIResponse); + }, + error: (error: any) => { + // eslint-disable-line @typescript-eslint/no-explicit-any + console.log(error); + if (error?.type != 'close') { + // Close is already handled via the isConnected call + this.errorMessage = error; + } + }, + complete: () => alert('Connection to server closed.'), + }); this.restoreChatItems(); } @@ -48,32 +50,32 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { // Handling of the websocket connection checkInput(value: string): void { - if (value && value.trim() != "") { + if (value && value.trim() != '') { this.inputValue = ''; this.ws.postMessage(value, this.apiKey); const userMessage: Message = { type: 'message', id: 'user-message', message: value, - avatar: "person", - timestamp: Date.now() + avatar: 'person', + timestamp: Date.now(), }; this.appendMessage(userMessage); } } handleWSMessage(input: APIResponse): void { - if (input.type == "message") { + if (input.type == 'message') { const aiMessageString = input.data; const aiMessage: Message = { type: 'message', id: 'ai-message', message: aiMessageString, - avatar: "computer", - timestamp: Date.now() + avatar: 'computer', + timestamp: Date.now(), }; this.appendMessage(aiMessage); - } else if (input.type == "status") { + } else if (input.type == 'status') { const current = input.current; const total = input.total; const progress = current / total; @@ -81,7 +83,7 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { if (progress >= 1) { this.progress = undefined; } - } else if (input.type == "report") { + } else if (input.type == 'report') { if (input.reset) { this.steps = []; return; @@ -96,28 +98,28 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { this.steps.push(step); } } - } else if (input.type == "intermediate") { + } else if (input.type == 'intermediate') { const text = '### Intermediate result from attack\n' + input.data; const intermediateMessage: Message = { type: 'message', id: 'assistant-intermediate-message', message: text, - avatar: "computer", - timestamp: Date.now() + avatar: 'computer', + timestamp: Date.now(), }; this.appendMessage(intermediateMessage); - } else if (input.type == "vulnerability-report") { + } else if (input.type == 'vulnerability-report') { const vulnerabilityCards = input.data.map(vri => { - const vrc = (vri as VulnerabilityReportCard); + const vrc = vri as VulnerabilityReportCard; vrc.description = this.vis.getInfo(vri.vulnerability); return vrc; }); this.chatItems.push({ type: 'report-card', - 'reports': vulnerabilityCards, - 'name': input.name + reports: vulnerabilityCards, + name: input.name, }); - localStorage.setItem("cached-chat-items", JSON.stringify(this.chatItems)); + localStorage.setItem('cached-chat-items', JSON.stringify(this.chatItems)); } } @@ -127,21 +129,26 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { getIconForStepStatus(step: Step): string { switch (step.status) { - case Status.COMPLETED: return 'check_circle'; - case Status.FAILED: return 'error'; - case Status.SKIPPED: return 'skip_next'; - case Status.RUNNING: return 'play_circle'; - case Status.PENDING: return 'pending'; + case Status.COMPLETED: + return 'check_circle'; + case Status.FAILED: + return 'error'; + case Status.SKIPPED: + return 'skip_next'; + case Status.RUNNING: + return 'play_circle'; + case Status.PENDING: + return 'pending'; } } appendMessage(message: Message) { this.chatItems.push(message); - localStorage.setItem("cached-chat-items", JSON.stringify(this.chatItems)); + localStorage.setItem('cached-chat-items', JSON.stringify(this.chatItems)); } restoreChatItems() { - const storedMessages = localStorage.getItem("cached-chat-items"); + const storedMessages = localStorage.getItem('cached-chat-items'); if (storedMessages) { this.chatItems = JSON.parse(storedMessages); } @@ -149,22 +156,30 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { clearChatHistory() { this.chatItems = []; - localStorage.setItem("cached-chat-items", "[]"); + localStorage.setItem('cached-chat-items', '[]'); } static deserializeStep(obj: ReportItem): Step { let status = Status.RUNNING; switch (obj.status) { - case "COMPLETED": status = Status.COMPLETED; break; - case "FAILED": status = Status.FAILED; break; - case "SKIPPED": status = Status.SKIPPED; break; - case "PENDING": status = Status.PENDING; break; + case 'COMPLETED': + status = Status.COMPLETED; + break; + case 'FAILED': + status = Status.FAILED; + break; + case 'SKIPPED': + status = Status.SKIPPED; + break; + case 'PENDING': + status = Status.PENDING; + break; } return { title: obj.title, description: obj.description, status: status, - progress: obj.progress + progress: obj.progress, }; } @@ -193,11 +208,11 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { const downloadUrl = window.URL.createObjectURL(file); const a = document.createElement('a'); a.href = downloadUrl; - a.download = `${reportName}.${reportFormat}`; // Set the filename + a.download = `${reportName}.${reportFormat}`; // Set the filename document.body.appendChild(a); - a.click(); // Programmatically click the link to trigger the download - document.body.removeChild(a); // Remove the link element - window.URL.revokeObjectURL(downloadUrl); // Clean up the URL object + a.click(); // Programmatically click the link to trigger the download + document.body.removeChild(a); // Remove the link element + window.URL.revokeObjectURL(downloadUrl); // Clean up the URL object } // Scrolling to have new messages visible @@ -258,11 +273,11 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { downloadChatHistory(): void { const markdownContent = this.exportChat(); - const blob = new Blob([markdownContent], { type: 'text/markdown' }); + const blob = new Blob([markdownContent], {type: 'text/markdown'}); const link = document.createElement('a'); link.href = URL.createObjectURL(blob); - link.download = "STARS_chat_" + new Date().toISOString() + ".md"; + link.download = 'STARS_chat_' + new Date().toISOString() + '.md'; // Append the link to the body document.body.appendChild(link); @@ -280,4 +295,9 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { this.apiKey = prompt('Set API Key', this.apiKey) || this.apiKey; localStorage.setItem('key', this.apiKey); } + + // openDashboard() that loads a new page with the dashboard at the route /heatmap + openDashboard(): void { + window.open('/heatmap', '_blank'); + } } diff --git a/frontend/src/app/heatmap/heatmap.component.css b/frontend/src/app/heatmap/heatmap.component.css new file mode 100644 index 0000000..8a9c30e --- /dev/null +++ b/frontend/src/app/heatmap/heatmap.component.css @@ -0,0 +1,80 @@ +h1 { + width: 500px; + margin: auto; + text-align: center; +} + +#heatmapChart { + /* max-width: 600px; + max-height: 400px; */ + display: block; + margin: 10px auto; +} + +* { + font-family: Helvetica, Arial, sans-serif; +} + +#vendorChoice { + padding: 20px 0; + width: 500px; + margin: auto; + text-align: center; +} + +.title-card { + width: 700px; + margin: auto; + height: 100px; + padding: 25px; +} + .heatmap-card { + width: 700px; + margin: auto; + height: 1000px; + padding: 25px; +} + +/* .card-header { + font-size: 2rem; + font-weight: bold; +} */ + +#overview { + font-size: medium; +} + +/* .header-icon { + width: 50px; +} */ + +.card-header { + display: flex; + align-items: center; + gap: 10px; + font-size: 2rem; + font-weight: bold; + justify-content: center; +} + +.header-icon { + width: 45px; +} + +.title { + display: flex; + align-items: center; +} + +.buttons-wrapper { + /* margin-top: auto; pousse les boutons en bas */ + display: flex; + flex-direction: column; + gap: 1rem; /* espace entre les boutons */ +} + +.centered-button { + display: flex; + justify-content: center; + align-items: center; +} diff --git a/frontend/src/app/heatmap/heatmap.component.html b/frontend/src/app/heatmap/heatmap.component.html new file mode 100644 index 0000000..eda4d0f --- /dev/null +++ b/frontend/src/app/heatmap/heatmap.component.html @@ -0,0 +1,36 @@ + + + + +
+ + STARS Results Heatmap +
+ + +
+
+ + + +
+
+ +
+ + Select a Vendor + + + All vendors + {{ vendor }} + + + overview +
+ + +
+ + diff --git a/frontend/src/app/heatmap/heatmap.component.ts b/frontend/src/app/heatmap/heatmap.component.ts new file mode 100644 index 0000000..ff388d3 --- /dev/null +++ b/frontend/src/app/heatmap/heatmap.component.ts @@ -0,0 +1,291 @@ +import {AfterViewInit, ChangeDetectorRef, Component, ElementRef, OnInit} from '@angular/core'; +import {Observable, map} from 'rxjs'; +import {capitalizeFirstLetter, generateModelName, splitModelName} from '../utils/utils'; + +import ApexCharts from 'apexcharts'; +import {CommonModule} from '@angular/common'; +import {FormsModule} from '@angular/forms'; +import {HttpClient} from '@angular/common/http'; +import {MatButtonModule} from '@angular/material/button'; +import {MatCardModule} from '@angular/material/card'; +import {MatFormFieldModule} from '@angular/material/form-field'; +import {MatSelectModule} from '@angular/material/select'; +import {environment} from '../../environments/environment'; + +@Component({ + selector: 'app-heatmap', + templateUrl: './heatmap.component.html', + styleUrls: ['./heatmap.component.css'], + standalone: true, + imports: [CommonModule, MatFormFieldModule, MatSelectModule, FormsModule, MatCardModule, MatButtonModule], +}) +export class HeatmapComponent implements AfterViewInit, OnInit { + public heatmapData: number[][] = []; + // for UI dropdown menu of vendors + public vendorsNames: string[] = []; + public selectedVendor: string = ''; + public weightedAttacks: {attackName: string; weight: string}[] = []; + + constructor(private http: HttpClient, private el: ElementRef, private changeDetector: ChangeDetectorRef) {} + + ngAfterViewInit() { + this.createHeatmap([]); // Initialisation avec des données vides + } + + ngOnInit() { + // this.loadHeatmapData('amazon'); + this.loadVendorsData(); + this.loadHeatmapData(''); + } + + onFileSelected(event: any) { + // const file = event.target.files[0]; + // if (!file) return; + // const formData = new FormData(); + // formData.append('file', file); + // this.http.post('http://localhost:3000/upload', formData).subscribe({ + // next: data => { + // console.log('📊 Données reçues via upload:', data); + // this.processData(data); + // }, + // error: error => console.error('❌ Erreur upload:', error), + // }); + } + + //load a dropdown menu from the loadModelsData result + loadVendorsData() { + this.http.get(`http://127.0.0.1:8080/api/vendors`).subscribe({ + // this.http.get(`${environment.api_url}/api/vendors`).subscribe({ + next: data => { + console.log('📡 Données brutes reçues du serveur:', data); + this.processVendors(data.map(vendor => vendor)); + }, + error: error => console.error('❌ Erreur API:', error), + }); + } + + //load the heatmap data from the server with a name in params + loadHeatmapData(vendor: string) { + let url = ''; + if (!vendor) { + url = `${environment.api_url}/api/heatmap`; + } else { + url = `${environment.api_url}/api/heatmap/${vendor}`; + } + this.http.get(url).subscribe({ + // this.http.get(`${environment.api_url}/api/${vendor}`).subscribe({ + next: scoresData => { + this.processData(scoresData, vendor); + }, + error: error => console.error('❌ Erreur API:', error), + }); + } + + // handle models name recieved from the server to a list used in frontend for a dropdown menu + processVendors(vendorsNames: string[]) { + this.vendorsNames = vendorsNames.map(capitalizeFirstLetter); + } + + processData(data: any[], vendor: string = '') { + const modelNames = generateModelName(data, vendor); + this.getWeightedAttacks().subscribe({ + next: weightedAttacks => { + this.heatmapData = data.map(row => { + const rowData = weightedAttacks.map(attack => { + const value = Number(row[attack.attackName]?.trim()); + return isNaN(value) ? 0 : value * 10; + }); + let totalWeights = 0; + // Add an extra column at the end with a custom calculation (modify as needed) + const weightedSumColumn = weightedAttacks.reduce((sum, {attackName, weight}) => { + const value = Number(row[attackName]?.trim()); + const weightedValue = isNaN(value) ? 0 : value * Number(weight); + totalWeights = totalWeights + Number(weight); + return sum + weightedValue; + }, 0); + // Append the calculated weighted sum column to the row as the last column "as an attack" even if it's a custom calculated value + return [...rowData, (weightedSumColumn / totalWeights) * 10]; + }); + const attackNames = weightedAttacks.map(attack => attack.attackName); + this.createHeatmap(this.heatmapData, modelNames, [...attackNames.map(capitalizeFirstLetter), 'Exposure score'], vendor !== ''); + }, + error: error => console.error('❌ Erreur API:', error), + }); + } + + createHeatmap(data: number[][], modelNames: Record = {}, attackNames: string[] = [], oneVendorDisplayed: boolean = false) { + const cellSize = 100; + const chartWidth = attackNames.length * cellSize + 150; // +100 to allow some space for translated labels + const chartHeight = data.length <= 3 ? data.length * cellSize + 100 : data.length * cellSize; + // const series = Object.entries(modelNames).flatMap(([vendor, models]) => + // models.map((model, modelIndex) => ({ + // name: splitModelName(vendor, model), + // data: data[modelIndex].map((value, colIndex) => ({ + // x: attackNames[colIndex], + // y: value, + // })), + // })) + // ); + + // // group by vendors + // let globalIndex = 0; + // const series = Object.entries(modelNames).flatMap(([vendor, models]) => + // models.map(model => { + // const seriesData = { + // name: splitModelName(vendor, model), + // data: data[globalIndex].map((value, colIndex) => ({ + // x: attackNames[colIndex], + // y: value, + // })), + // }; + // globalIndex++; // Increment global index for next model + // return seriesData; + // }) + // ); + + // does not group by vendor + // Flatten all models but keep vendor info + const allModels = Object.entries(modelNames).flatMap(([vendor, models]) => models.map(model => ({vendor, model}))); + + let globalIndex = 0; + + const series = allModels.map(({vendor, model}) => { + const seriesData = { + name: splitModelName(vendor, model), // Display vendor and model together + data: data[globalIndex].map((value, colIndex) => ({ + x: attackNames[colIndex], + y: value, + })), + }; + globalIndex++; // Move to next row in data + return seriesData; + }); + + const options = { + chart: { + type: 'heatmap', + height: chartHeight, + width: chartWidth, + toolbar: {show: false}, + events: { + legendClick: function () { + console.log('CLICKED'); + }, + }, + }, + series: series, + plotOptions: { + heatmap: { + shadeIntensity: 0.5, + // useFillColorAsStroke: true, // Améliore le rendu des cases + colorScale: { + ranges: [ + // {from: 0, to: 20, color: '#5aa812'}, // Light green for 0-20 + // {from: 21, to: 40, color: '#00A100'}, // Darker green for 21-40 + // {from: 41, to: 60, color: '#FFB200'}, // Light orange for 41-60 + // {from: 61, to: 80, color: '#FF7300'}, // Darker orange for 61-80 + // {from: 81, to: 100, color: '#FF0000'}, // Red for 81-100 + + {from: 0, to: 40, color: '#00A100'}, + // {from: 21, to: 40, color: '#128FD9'}, + {from: 41, to: 80, color: '#FF7300'}, + // {from: 61, to: 80, color: '#FFB200'}, + {from: 81, to: 100, color: '#FF0000'}, + ], + }, + }, + }, + grid: { + padding: {top: 0, right: 0, bottom: 0, left: 0}, + }, + dataLabels: { + style: {fontSize: '14px'}, + }, + legend: { + show: true, + // markers: { + // customHTML: function () { + // return ''; + // }, + // }, + // markers: { + // width: 12, + // height: 12, + // // Remove customHTML if you want the default + // }, + }, + xaxis: { + categories: attackNames, + title: {text: 'Attacks'}, + labels: {rotate: -45, style: {fontSize: '12px'}}, + position: 'top', + }, + yaxis: { + categories: modelNames, + title: { + text: 'Models', + offsetX: oneVendorDisplayed ? -90 : -60, + }, + labels: { + style: { + fontSize: '12px', + }, + offsetY: -10, + }, + reversed: true, + }, + tooltip: { + y: { + formatter: undefined, + title: { + formatter: (seriesName: string) => seriesName.replace(',', '-'), + }, + }, + }, + }; + const chartElement = this.el.nativeElement.querySelector('#heatmapChart'); + if (chartElement) { + chartElement.innerHTML = ''; + const chart = new ApexCharts(chartElement, options); + chart.render(); + } + } + + public onVendorChange(event: any) { + this.loadHeatmapData(this.selectedVendor); + } + + // getattacksNames() return an array of attacks names from the server from http://localhost:3000/api/attacks + getAttacksNames(): Observable { + // return this.http.get(`${environment.api_url}/api/attacks`).pipe( + return this.http.get(`http://127.0.0.1:8080/api/attacks`).pipe( + map(data => data.map(row => row['attackName'])) // Extract only attack names + ); + } + + getWeightedAttacks(): Observable<{attackName: string; weight: string}[]> { + // return this.http.get(`${environment.api_url}/api/attacks`); + return this.http.get(`http://127.0.0.1:8080/api/attacks`); + } + + getVendors(): Observable { + this.changeDetector.detectChanges(); + // return this.http.get(`${environment.api_url}/api/vendors`); + return this.http.get(`${environment.api_url}/api/vendors`); + } + + uploadCSV(event: any) { + const file = event.target.files[0]; + const formData = new FormData(); + formData.append('file', file); + + this.http.post(`${environment.api_url}/api/upload-csv`, formData).subscribe({ + next: res => { + console.log('Upload success', res); + }, + error: err => { + console.error('Upload failed', err); + }, + }); + } +} diff --git a/frontend/src/app/utils/utils.ts b/frontend/src/app/utils/utils.ts new file mode 100644 index 0000000..8c9b220 --- /dev/null +++ b/frontend/src/app/utils/utils.ts @@ -0,0 +1,40 @@ +// export function generateModelName(vendor: string, modelType: string, version: string, specialization: string, other: string, withVendor = true): string { +export function generateModelName(data: any[], vendor: string): any { + const result: Record = {}; + + data.forEach(row => { + const vendorName = vendor === '' ? row['vendor'] : vendor; // Si vendor est vide, on prend row['vendor'], sinon on utilise vendor existant + + const model = [row['modelType'], row['version'], row['specialization'], row['other']] + .filter(value => value) + .join('-') + .replace(/\s+/g, ' ') + .trim(); + + if (!result[vendorName]) { + result[vendorName] = []; + } + + result[vendorName].push(model); + }); + return result; +} + +export function splitModelName(vendor: string, model: string): string[] { + if (model.length < 18) return [vendor, model]; // No need to split + + // Find the last "-" before the 20th character + const cutoffIndex = model.lastIndexOf('-', 20); + + if (cutoffIndex === -1) { + // If no "-" found before 20, force split at 20 + return [vendor, model.slice(0, 20), model.slice(20)]; + } + + // Split at the last "-" before 20 + return [vendor, model.slice(0, cutoffIndex), model.slice(cutoffIndex + 1)].map(capitalizeFirstLetter); +} + +export function capitalizeFirstLetter(str: string): string { + return str.charAt(0).toUpperCase() + str.slice(1); +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index b5462b6..7c155bb 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -13,14 +13,63 @@ body { margin: 0; font-family: Roboto, "Helvetica Neue", sans-serif; } white-space: pre-wrap !important; } -@font-face { - font-family: 'AmericanTypewriter'; - src: url('assets/fonts/AmericanTypewriter.ttc') format('truetype'); - font-weight: normal; - font-style: normal; -} /* For responses from agent, which can contain very long lines */ pre { overflow-x: auto; } + +body { + font-family: Roboto, "Helvetica Neue", sans-serif; + margin: 0; + padding: 30px; + height: 100%; +} + +.apexcharts-canvas { + margin: auto; + translate: -60px; +} + +.apexcharts-inner { + translate: 40px; +} + +.apexcharts-yaxis-texts-g { + translate: 20px; +} + +.apexcharts-yaxis-title { + translate: 100px; + > text { + font-size: large; + } +} + +.apexcharts-xaxis { + translate: 0 10px; +} + +.apexcharts-xaxis-title { + > text { + font-size: large; + } +} + +html { height: 100%; } + +.card-header { + font-size: 2.5rem; /* Larger font for the header */ + font-weight: bold; /* Make the text bold */ + text-align: center; /* Center the text */ + padding: 16px; +} + +mat-select { + direction: rtl !important; + text-align: right !important; +} + +.apexcharts-legend { + translate: 50px; +} diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index ed966d4..409a471 100755 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -2,6 +2,8 @@ { "compileOnSave": false, "compilerOptions": { + "allowSyntheticDefaultImports": true, + "esModuleInterop": true, "baseUrl": "./", "outDir": "./dist/out-tsc", "forceConsistentCasingInFileNames": true, From f6e60598de66ee28af262fd1969dff615e5e4126 Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Tue, 13 May 2025 11:08:32 +0200 Subject: [PATCH 02/15] re added deleted main.py comments deleted unused html comments --- backend-agent/main.py | 15 ++++++++ .../src/app/chatzone/chatzone.component.html | 36 ------------------- 2 files changed, 15 insertions(+), 36 deletions(-) diff --git a/backend-agent/main.py b/backend-agent/main.py index 44fc501..96d667a 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -55,6 +55,11 @@ def send_intro(sock): """ Sends the intro via the websocket connection. + + The intro is meant as a short tutorial on how to use the agent. + Also it includes meaningful suggestions for prompts that should + result in predictable behavior for the agent, e.g. + "Start the vulnerability scan". """ with open('data/intro.txt', 'r') as f: intro = f.read() @@ -66,8 +71,18 @@ def query_agent(sock): """ Websocket route for the frontend to send prompts to the agent and receive responses as well as status updates. + + Messages received are in this JSON format: + + { + "type":"message", + "data":"Start the vulnerability scan", + "key":"secretapikey" + } + """ status.sock = sock + # Intro is sent after connecting successfully send_intro(sock) while True: data_raw = sock.receive() diff --git a/frontend/src/app/chatzone/chatzone.component.html b/frontend/src/app/chatzone/chatzone.component.html index f8d7063..647f11c 100644 --- a/frontend/src/app/chatzone/chatzone.component.html +++ b/frontend/src/app/chatzone/chatzone.component.html @@ -111,39 +111,3 @@

Vulnerability Report

- - - - - - From 6dc2b39d4be0912dc7c890b17db205d9fc17e55c Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Tue, 13 May 2025 16:30:51 +0200 Subject: [PATCH 03/15] =?UTF-8?q?forgot=20some=20localhost=20api=20endpoin?= =?UTF-8?q?t=20=F0=9F=A4=A6=F0=9F=8F=BB=E2=80=8D=E2=99=80=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- frontend/src/app/heatmap/heatmap.component.ts | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/frontend/src/app/heatmap/heatmap.component.ts b/frontend/src/app/heatmap/heatmap.component.ts index ff388d3..ee36f7d 100644 --- a/frontend/src/app/heatmap/heatmap.component.ts +++ b/frontend/src/app/heatmap/heatmap.component.ts @@ -54,8 +54,8 @@ export class HeatmapComponent implements AfterViewInit, OnInit { //load a dropdown menu from the loadModelsData result loadVendorsData() { - this.http.get(`http://127.0.0.1:8080/api/vendors`).subscribe({ - // this.http.get(`${environment.api_url}/api/vendors`).subscribe({ + // this.http.get(`http://127.0.0.1:8080/api/vendors`).subscribe({ + this.http.get(`${environment.api_url}/api/vendors`).subscribe({ next: data => { console.log('📡 Données brutes reçues du serveur:', data); this.processVendors(data.map(vendor => vendor)); @@ -257,21 +257,21 @@ export class HeatmapComponent implements AfterViewInit, OnInit { // getattacksNames() return an array of attacks names from the server from http://localhost:3000/api/attacks getAttacksNames(): Observable { - // return this.http.get(`${environment.api_url}/api/attacks`).pipe( - return this.http.get(`http://127.0.0.1:8080/api/attacks`).pipe( + return this.http.get(`${environment.api_url}/api/attacks`).pipe( + // return this.http.get(`http://127.0.0.1:8080/api/attacks`).pipe( map(data => data.map(row => row['attackName'])) // Extract only attack names ); } getWeightedAttacks(): Observable<{attackName: string; weight: string}[]> { - // return this.http.get(`${environment.api_url}/api/attacks`); - return this.http.get(`http://127.0.0.1:8080/api/attacks`); + return this.http.get(`${environment.api_url}/api/attacks`); + // return this.http.get(`http://127.0.0.1:8080/api/attacks`); } getVendors(): Observable { this.changeDetector.detectChanges(); - // return this.http.get(`${environment.api_url}/api/vendors`); return this.http.get(`${environment.api_url}/api/vendors`); + // return this.http.get(`http://127.0.0.1:8080/api/vendors`); } uploadCSV(event: any) { From 123f33e878789536de3ab2b0bb97f83df82e8fce Mon Sep 17 00:00:00 2001 From: Marco Rosa Date: Wed, 14 May 2025 10:10:19 +0200 Subject: [PATCH 04/15] Fix linter and use env var for dashboard folder --- backend-agent/main.py | 67 +++++++++++++++++++++++++++---------------- 1 file changed, 43 insertions(+), 24 deletions(-) diff --git a/backend-agent/main.py b/backend-agent/main.py index 96d667a..1904eab 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -42,14 +42,14 @@ } if langfuse_handler else { 'callbacks': [status_callback_handler]} -# Set up the upload folder dynamically -UPLOAD_FOLDER = './uploads' # You can change this to a different path if needed +# Dashboard data +DASHBOARD_DATA_DIR = os.getenv('DASHBOARD_DATA_DIR', 'dashboard') ALLOWED_EXTENSIONS = {'csv'} -app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER +app.config['DASHBOARD_FOLDER'] = DASHBOARD_DATA_DIR -# Create the upload folder if it doesn't exist -if not os.path.exists(app.config['UPLOAD_FOLDER']): - os.makedirs(app.config['UPLOAD_FOLDER']) +# Create the data folder for the dashboard if it doesn't exist +if not os.path.exists(app.config['DASHBOARD_FOLDER']): + os.makedirs(app.config['DASHBOARD_FOLDER']) def send_intro(sock): @@ -141,7 +141,8 @@ def check_health(): def allowed_file(filename): - return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS + return '.' in filename and filename.rsplit('.', 1)[1].lower() in \ + ALLOWED_EXTENSIONS @app.route('/api/upload-csv', methods=['POST']) @@ -155,7 +156,7 @@ def upload_csv(): if file and allowed_file(file.filename): filename = secure_filename(file.filename) - path = os.path.join(app.config['UPLOAD_FOLDER'], filename) + path = os.path.join(app.config['DASHBOARD_FOLDER'], filename) file.save(path) # Optional: parse CSV immediately @@ -164,7 +165,9 @@ def upload_csv(): reader = csv.DictReader(f) data = list(reader) - return jsonify({'message': 'CSV uploaded successfully', 'data': data}) + return jsonify({'message': 'CSV uploaded successfully', + 'data': data} + ) except Exception as e: return jsonify({'error': f'Error reading CSV: {str(e)}'}), 500 @@ -178,28 +181,33 @@ def get_vendors(): # Check if CSV file exists error_response = check_csv_exists('STARS_RESULTS.csv') if error_response: - print("❌ CSV not found or error from check_csv_exists") + print('❌ CSV not found or error from check_csv_exists') return error_response try: - file_path = os.path.join(app.config['UPLOAD_FOLDER'], 'STARS_RESULTS.csv') - print(f"📄 Reading CSV from: {file_path}") + file_path = os.path.join( + app.config['DASHBOARD_FOLDER'], 'STARS_RESULTS.csv') + print(f'📄 Reading CSV from: {file_path}') with open(file_path, mode='r') as f: reader = csv.DictReader(f) data = list(reader) # Extract unique vendors - vendors = list(set([model['vendor'] for model in data if 'vendor' in model])) + vendors = list(set([model['vendor'] for model in data + if 'vendor' in model])) return jsonify(vendors) except Exception as e: - print(f"🔥 Exception occurred: {str(e)}") # DEBUG PRINT - return jsonify({'error': f'Error reading vendors from CSV: {str(e)}'}), 500 + print(f'🔥 Exception occurred: {str(e)}') # DEBUG PRINT + return jsonify( + {'error': f'Error reading vendors from CSV: {str(e)}'}), 500 # Endpoint to fetch heatmap data from the uploaded CSV @app.route('/api/heatmap', methods=['GET']) def get_heatmap(): - file_path = os.path.join(app.config['UPLOAD_FOLDER'], 'STARS_RESULTS.csv') # Use dynamic upload folder path + # Use dynamic upload folder path + file_path = os.path.join( + app.config['DASHBOARD_FOLDER'], 'STARS_RESULTS.csv') try: with open(file_path, mode='r') as f: reader = csv.DictReader(f) @@ -208,30 +216,38 @@ def get_heatmap(): return jsonify(data) except Exception as e: - return jsonify({'error': f'Error reading heatmap data from CSV: {str(e)}'}), 500 + return jsonify( + {'error': f'Error reading heatmap data from CSV: {str(e)}'}), 500 # Endpoint to fetch heatmap data filtered by vendor from the uploaded CSV @app.route('/api/heatmap/', methods=['GET']) def get_filtered_heatmap(name): - file_path = os.path.join(app.config['UPLOAD_FOLDER'], 'STARS_RESULTS.csv') # Use dynamic upload folder path + # Use dynamic upload folder path + file_path = os.path.join( + app.config['DASHBOARD_FOLDER'], 'STARS_RESULTS.csv') try: with open(file_path, mode='r') as f: reader = csv.DictReader(f) data = list(reader) # Filter data by vendor name - filtered_data = [model for model in data if model['vendor'].lower() == name.lower()] + filtered_data = [model for model in data + if model['vendor'].lower() == name.lower()] return jsonify(filtered_data) except Exception as e: - return jsonify({'error': f'Error reading filtered heatmap data from CSV: {str(e)}'}), 500 + return jsonify( + {'error': f'Error reading filtered heatmap data from CSV: {str(e)}' + }), 500 # Endpoint to fetch all attacks from the uploaded CSV @app.route('/api/attacks', methods=['GET']) def get_attacks(): - file_path = os.path.join(app.config['UPLOAD_FOLDER'], 'attacks.csv') # Use dynamic upload folder path + # Use dynamic upload folder path + file_path = os.path.join( + app.config['DASHBOARD_FOLDER'], 'attacks.csv') try: with open(file_path, mode='r') as f: reader = csv.DictReader(f) @@ -240,13 +256,16 @@ def get_attacks(): return jsonify(data) except Exception as e: - return jsonify({'error': f'Error reading attacks data from CSV: {str(e)}'}), 500 + return jsonify( + {'error': f'Error reading attacks data from CSV: {str(e)}'}), 500 def check_csv_exists(file_name): - file_path = os.path.join(app.config['UPLOAD_FOLDER'], file_name) + file_path = os.path.join(app.config['DASHBOARD_FOLDER'], file_name) if not os.path.exists(file_path): - return jsonify({'error': f'{file_name} not found. Please upload the file first.'}), 404 + return jsonify( + {'error': f'{file_name} not found. Please upload the file first.' + }), 404 return None # No error, file exists From 5e60bdc6fe9e37a6d58ab69a4708a55bc9290fb8 Mon Sep 17 00:00:00 2001 From: Marco Rosa Date: Wed, 14 May 2025 10:12:47 +0200 Subject: [PATCH 05/15] Add missing frontend packages --- frontend/package-lock.json | 131 ++++++++++++++++++++++++++++++++++++- frontend/package.json | 2 + 2 files changed, 132 insertions(+), 1 deletion(-) diff --git a/frontend/package-lock.json b/frontend/package-lock.json index bb89d77..f00e007 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -18,8 +18,10 @@ "@angular/platform-browser": "^19.2.0", "@angular/platform-browser-dynamic": "^19.2.0", "@angular/router": "^19.2.0", + "apexcharts": "^4.7.0", "ngx-markdown": "^19.1.0", "node-sass": "^9.0.0", + "react-apexcharts": "^1.7.0", "rxjs": "^7.8.2", "schematics-scss-migrate": "^2.3.17", "tslib": "^2.8.1", @@ -6538,6 +6540,62 @@ "dev": true, "license": "MIT" }, + "node_modules/@svgdotjs/svg.draggable.js": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.draggable.js/-/svg.draggable.js-3.0.6.tgz", + "integrity": "sha512-7iJFm9lL3C40HQcqzEfezK2l+dW2CpoVY3b77KQGqc8GXWa6LhhmX5Ckv7alQfUXBuZbjpICZ+Dvq1czlGx7gA==", + "license": "MIT", + "peerDependencies": { + "@svgdotjs/svg.js": "^3.2.4" + } + }, + "node_modules/@svgdotjs/svg.filter.js": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.filter.js/-/svg.filter.js-3.0.9.tgz", + "integrity": "sha512-/69XMRCDoam2HgC4ldHIaDgeQf1ViHIsa0Ld4uWgiXtZ+E24DWHe/9Ib6kbNiZ7WRIdlVokUDR1Fg0kjIpkfbw==", + "license": "MIT", + "dependencies": { + "@svgdotjs/svg.js": "^3.2.4" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/@svgdotjs/svg.js": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.js/-/svg.js-3.2.4.tgz", + "integrity": "sha512-BjJ/7vWNowlX3Z8O4ywT58DqbNRyYlkk6Yz/D13aB7hGmfQTvGX4Tkgtm/ApYlu9M7lCQi15xUEidqMUmdMYwg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Fuzzyma" + } + }, + "node_modules/@svgdotjs/svg.resize.js": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.resize.js/-/svg.resize.js-2.0.5.tgz", + "integrity": "sha512-4heRW4B1QrJeENfi7326lUPYBCevj78FJs8kfeDxn5st0IYPIRXoTtOSYvTzFWgaWWXd3YCDE6ao4fmv91RthA==", + "license": "MIT", + "engines": { + "node": ">= 14.18" + }, + "peerDependencies": { + "@svgdotjs/svg.js": "^3.2.4", + "@svgdotjs/svg.select.js": "^4.0.1" + } + }, + "node_modules/@svgdotjs/svg.select.js": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.select.js/-/svg.select.js-4.0.3.tgz", + "integrity": "sha512-qkMgso1sd2hXKd1FZ1weO7ANq12sNmQJeGDjs46QwDVsxSRcHmvWKL2NDF7Yimpwf3sl5esOLkPqtV2bQ3v/Jg==", + "license": "MIT", + "engines": { + "node": ">= 14.18" + }, + "peerDependencies": { + "@svgdotjs/svg.js": "^3.2.4" + } + }, "node_modules/@tootallnate/once": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/@tootallnate/once/-/once-2.0.0.tgz", @@ -7531,6 +7589,12 @@ "dev": true, "license": "BSD-2-Clause" }, + "node_modules/@yr/monotone-cubic-spline": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@yr/monotone-cubic-spline/-/monotone-cubic-spline-1.0.3.tgz", + "integrity": "sha512-FQXkOta0XBSUPHndIKON2Y9JeQz5ZeMqLYZVVK93FliNBFm7LNMIZmY6FrMEB9XPcDbE2bekMbZD6kzDkxwYjA==", + "license": "MIT" + }, "node_modules/abbrev": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/abbrev/-/abbrev-1.1.1.tgz", @@ -7790,6 +7854,20 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/apexcharts": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/apexcharts/-/apexcharts-4.7.0.tgz", + "integrity": "sha512-iZSrrBGvVlL+nt2B1NpqfDuBZ9jX61X9I2+XV0hlYXHtTwhwLTHDKGXjNXAgFBDLuvSYCB/rq2nPWVPRv2DrGA==", + "license": "MIT", + "dependencies": { + "@svgdotjs/svg.draggable.js": "^3.0.4", + "@svgdotjs/svg.filter.js": "^3.0.8", + "@svgdotjs/svg.js": "^3.2.4", + "@svgdotjs/svg.resize.js": "^2.0.2", + "@svgdotjs/svg.select.js": "^4.0.1", + "@yr/monotone-cubic-spline": "^1.0.3" + } + }, "node_modules/aproba": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/aproba/-/aproba-2.0.0.tgz", @@ -13375,6 +13453,18 @@ "node": ">=8.0" } }, + "node_modules/loose-envify": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", + "integrity": "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==", + "license": "MIT", + "dependencies": { + "js-tokens": "^3.0.0 || ^4.0.0" + }, + "bin": { + "loose-envify": "cli.js" + } + }, "node_modules/lru-cache": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-5.1.1.tgz", @@ -14928,7 +15018,6 @@ "version": "4.1.1", "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", - "dev": true, "license": "MIT", "engines": { "node": ">=0.10.0" @@ -16097,6 +16186,17 @@ "node": ">=10" } }, + "node_modules/prop-types": { + "version": "15.8.1", + "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", + "integrity": "sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.4.0", + "object-assign": "^4.1.1", + "react-is": "^16.13.1" + } + }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", @@ -16245,6 +16345,35 @@ "node": ">= 0.8" } }, + "node_modules/react": { + "version": "19.1.0", + "resolved": "https://registry.npmjs.org/react/-/react-19.1.0.tgz", + "integrity": "sha512-FS+XFBNvn3GTAWq26joslQgWNoFu08F4kl0J4CgdNKADkdSGXQyTCnKteIAJy96Br6YbpEU1LSzV5dYtjMkMDg==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-apexcharts": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/react-apexcharts/-/react-apexcharts-1.7.0.tgz", + "integrity": "sha512-03oScKJyNLRf0Oe+ihJxFZliBQM9vW3UWwomVn4YVRTN1jsIR58dLWt0v1sb8RwJVHDMbeHiKQueM0KGpn7nOA==", + "license": "MIT", + "dependencies": { + "prop-types": "^15.8.1" + }, + "peerDependencies": { + "apexcharts": ">=4.0.0", + "react": ">=0.13" + } + }, + "node_modules/react-is": { + "version": "16.13.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "license": "MIT" + }, "node_modules/read-pkg": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-5.2.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index b3607d9..3693b17 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -21,8 +21,10 @@ "@angular/platform-browser": "^19.2.0", "@angular/platform-browser-dynamic": "^19.2.0", "@angular/router": "^19.2.0", + "apexcharts": "^4.7.0", "ngx-markdown": "^19.1.0", "node-sass": "^9.0.0", + "react-apexcharts": "^1.7.0", "rxjs": "^7.8.2", "schematics-scss-migrate": "^2.3.17", "tslib": "^2.8.1", From 2eb9546e4ba37b4116abd873b538cb1d107580a5 Mon Sep 17 00:00:00 2001 From: Marco Rosa Date: Wed, 14 May 2025 10:13:11 +0200 Subject: [PATCH 06/15] Add missing env variable --- backend-agent/.env.example | 3 +++ 1 file changed, 3 insertions(+) diff --git a/backend-agent/.env.example b/backend-agent/.env.example index 9bc0af0..1b413b0 100644 --- a/backend-agent/.env.example +++ b/backend-agent/.env.example @@ -12,3 +12,6 @@ API_KEY=super-secret-change-me DEBUG=True RESULT_SUMMARIZE_MODEL=gpt-4 + +# Dashboard data +DASHBOARD_DATA_DIR=dashboard From df7bb3b7095af0577e6eca773184564fe5766076 Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Tue, 17 Jun 2025 12:31:30 +0200 Subject: [PATCH 07/15] save to db instead of using csv logic aligned attacks results TODO: textattack --- backend-agent/.env.example | 3 + backend-agent/app/db/models.py | 59 ++++++++++ backend-agent/app/db/utils.py | 71 ++++++++++++ backend-agent/attack.py | 3 + backend-agent/libs/artprompt.py | 9 +- backend-agent/libs/codeattack.py | 11 +- backend-agent/libs/gptfuzz.py | 16 ++- backend-agent/libs/promptmap.py | 7 +- backend-agent/libs/pyrit.py | 23 ++-- backend-agent/main.py | 179 +++++++++---------------------- backend-agent/requirements.txt | 1 + 11 files changed, 237 insertions(+), 145 deletions(-) create mode 100644 backend-agent/app/db/models.py create mode 100644 backend-agent/app/db/utils.py diff --git a/backend-agent/.env.example b/backend-agent/.env.example index 1b413b0..e1b669d 100644 --- a/backend-agent/.env.example +++ b/backend-agent/.env.example @@ -15,3 +15,6 @@ RESULT_SUMMARIZE_MODEL=gpt-4 # Dashboard data DASHBOARD_DATA_DIR=dashboard + +# Database path +DB_PATH=path_to/database.db diff --git a/backend-agent/app/db/models.py b/backend-agent/app/db/models.py new file mode 100644 index 0000000..49eb493 --- /dev/null +++ b/backend-agent/app/db/models.py @@ -0,0 +1,59 @@ +from flask_sqlalchemy import SQLAlchemy + +db = SQLAlchemy() + + +class AttackModel(db.Model): + __tablename__ = 'attack_models' + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String, unique=True, nullable=False) + description = db.Column(db.String) + + # attack_results = db.relationship('AttackResult', backref='attack_model') + # model_scores = db.relationship('ModelAttackScore', back_populates='attack_model') + + +class Attack(db.Model): + __tablename__ = 'attacks' + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String, nullable=False, unique=True) + weight = db.Column(db.Integer, nullable=False, default=1, server_default="1") + # subattacks = db.relationship('SubAttack', backref='attack', cascade='all, delete-orphan') + # model_scores = db.relationship('ModelAttackScore', back_populates='attack') + + +class SubAttack(db.Model): + __tablename__ = 'sub_attacks' + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String, nullable=False) + description = db.Column(db.String) + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) + + +class AttackResult(db.Model): + __tablename__ = 'attack_results' + id = db.Column(db.Integer, primary_key=True) + attack_model_id = db.Column(db.Integer, db.ForeignKey('attack_models.id'), nullable=False) + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) + success = db.Column(db.Boolean, nullable=False) + vulnerability_type = db.Column(db.String, nullable=True) + details = db.Column(db.JSON, nullable=True) # JSON field + + +class ModelAttackScore(db.Model): + __tablename__ = 'model_attack_scores' + id = db.Column(db.Integer, primary_key=True) + attack_model_id = db.Column(db.Integer, db.ForeignKey('attack_models.id'), nullable=False) + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) + total_number_of_attack = db.Column(db.Integer, nullable=False) + total_success = db.Column(db.Integer, nullable=False) + + # attack_model = db.relationship('AttackModel', back_populates='model_scores') + # attack = db.relationship('Attack', back_populates='model_scores') + + __table_args__ = ( + db.UniqueConstraint('attack_model_id', 'attack_id', name='uix_model_attack'), + ) + + +db.configure_mappers() diff --git a/backend-agent/app/db/utils.py b/backend-agent/app/db/utils.py new file mode 100644 index 0000000..c13619a --- /dev/null +++ b/backend-agent/app/db/utils.py @@ -0,0 +1,71 @@ +from .models import ( + Attack as AttackDB, + db, + AttackModel as AttackModelDB, + AttackResult as AttackResultDB, + ModelAttackScore as ModelAttackScoreDB, +) + + +def save_to_db(attack_results): + """ + Persist the SuiteResult into the database. + Returns a list of AttackResults that were added. + """ + inserted_records = [] + + attack_name = attack_results.attack.lower() + success = attack_results.success + vulnerability_type = attack_results.vulnerability_type.lower() + details = attack_results.details # JSON column + model_name = details.get('target_model').lower() if 'target_model' in details else 'unknown' + + model = AttackModelDB.query.filter_by(name=model_name).first() + if not model: + model = AttackModelDB(name=model_name) + db.session.add(model) + db.session.flush() + + attack = AttackDB.query.filter_by(name=attack_name).first() + if not attack: + attack = AttackDB(name=attack_name, weight=1) # Default weight + db.session.add(attack) + db.session.flush() + + db_record = AttackResultDB( + attack_model_id=model.id, + attack_id=attack.id, + success=success, + vulnerability_type=vulnerability_type, + details=details, + ) + db.session.add(db_record) + inserted_records.append(db_record) + + model_attack_score = ModelAttackScoreDB.query.filter_by( + attack_model_id=model.id, + attack_id=attack.id + ).first() + + if not model_attack_score: + model_attack_score = ModelAttackScoreDB( + attack_model_id=model.id, + attack_id=attack.id, + total_number_of_attack=details.get('total_attacks', 0), + total_success=details.get('number_successful_attacks', 0) + ) + else: + model_attack_score.total_number_of_attack += details.get('total_attacks', 0) + model_attack_score.total_success += details.get('number_successful_attacks', 0) + + db.session.add(model_attack_score) + inserted_records.append(model_attack_score) + + try: + db.session.commit() + print("Results successfully saved to the database.") + return inserted_records + except Exception as e: + db.session.rollback() + print(f"Error while saving to DB: {e}") + return [] diff --git a/backend-agent/attack.py b/backend-agent/attack.py index a394307..8d39112 100644 --- a/backend-agent/attack.py +++ b/backend-agent/attack.py @@ -17,6 +17,8 @@ from llm import LLM from status import Trace +from app.db.utils import save_to_db + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -247,6 +249,7 @@ def run(self, summarize_by_llm: bool = False) -> SuiteResult: summary = self.summarize_attack_result(result) result.details['summary'] = summary full_result.append(result) + save_to_db(result) return SuiteResult(full_result) def summarize_attack_result(self, attack_result: AttackResult) -> str: diff --git a/backend-agent/libs/artprompt.py b/backend-agent/libs/artprompt.py index 20a0895..43a7ddf 100644 --- a/backend-agent/libs/artprompt.py +++ b/backend-agent/libs/artprompt.py @@ -29,6 +29,8 @@ from llm import LLM from status import status, Step +from app.db.utils import save_to_db + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) logger.addHandler(status.trace_logging) @@ -483,14 +485,17 @@ def start_artprompt(target_model: LLM, logger.info(f'Write results to output file {outfile}') with open(outfile, 'w') as f: json.dump(evaluations_res, f, indent=4) - - return AttackResult( + result = AttackResult( 'artprompt', successful_attacks > 0, 'prompt-injection', { + 'target_model': target_model.model_name, + 'total_attacks': len(prompts), 'number_successful_attacks': successful_attacks, 'successful_attacks': successful_attacks_list, 'attack_description': DESCRIPTION } ) + save_to_db(result) + return result diff --git a/backend-agent/libs/codeattack.py b/backend-agent/libs/codeattack.py index d906712..5555131 100644 --- a/backend-agent/libs/codeattack.py +++ b/backend-agent/libs/codeattack.py @@ -11,7 +11,7 @@ from attack_result import AttackResult from llm import LLM from status import status, Step - +from app.db.utils import save_to_db logger = logging.getLogger(__name__) logger.addHandler(status.trace_logging) @@ -131,7 +131,7 @@ def start_codeattack(target_model: LLM, prompts = random.sample(prompts, min(int(num_prompts), len(prompts))) logger.debug(f'Run {len(prompts)} prompt attacks') - +#nomore need output_file = parameters.get('output_file', OUTPUT_FILE) data_key = f'code_wrapped_{prompt_type}' @@ -204,16 +204,21 @@ def start_codeattack(target_model: LLM, # # Write results to file with open(output_file, 'w') as f: json.dump(successful_attacks_list, f) - return AttackResult( + + result = AttackResult( 'codeattack', successful_attacks > 0, 'prompt-injection', { + 'target_model': target_model.model_name, + 'total_attacks': len(prompts), 'number_successful_attacks': successful_attacks, 'successful_attacks': successful_attacks_list, 'attack_description': DESCRIPTION } ) + save_to_db(result) + return result def _prompt_attack(data, target_llm, post_processor, judge_llm, data_key=''): diff --git a/backend-agent/libs/gptfuzz.py b/backend-agent/libs/gptfuzz.py index 8cc4abf..b15e607 100644 --- a/backend-agent/libs/gptfuzz.py +++ b/backend-agent/libs/gptfuzz.py @@ -18,6 +18,7 @@ from attack_result import AttackResult from llm import LLM as AgentLLM from status import status, Step +from app.db.utils import save_to_db load_dotenv() @@ -163,13 +164,18 @@ def perform_gptfuzz_attack(mutate_model: LLM, with Step('Running Fuzzer'): fuzzer.run() logger.info('Fuzzer finished') - return AttackResult( + result = AttackResult( 'gptfuzz', fuzzer.current_jailbreak > 0, 'jailbreak', - details={ - 'result_file': output_file, - 'query_count': fuzzer.current_query, - 'attack_description': DESCRIPTION + { + 'total_attacks': fuzzer.current_iteration, + 'number_successful_attacks': fuzzer.current_iteration, + 'successful_attacks': fuzzer.current_iteration, + 'attack_description': DESCRIPTION, + 'result_file': output_file, + 'query_count': fuzzer.current_query } ) + save_to_db(result) + return result diff --git a/backend-agent/libs/promptmap.py b/backend-agent/libs/promptmap.py index 3757d59..bbf70d5 100644 --- a/backend-agent/libs/promptmap.py +++ b/backend-agent/libs/promptmap.py @@ -20,6 +20,7 @@ from attack_result import AttackResult from llm import LLM from status import status, Step +from app.db.utils import save_to_db COUNT_PROMPTS = 2 @@ -338,12 +339,16 @@ def start_prompt_map(target_model: LLM, parameters: dict) -> AttackResult: # Write results to file with open(output_file, 'w') as f: json.dump(successful_attacks_json, f) - return AttackResult( + result = AttackResult( 'promptmap', security_failed > 0, 'prompt-injection', { + 'total_attacks': total_attack_count, + 'number_successful_attacks': len(successful_attacks), 'successful_attacks': successful_attacks_json, 'attack_description': DESCRIPTION } ) + save_to_db(result) + return result diff --git a/backend-agent/libs/pyrit.py b/backend-agent/libs/pyrit.py index 7f7a520..d61594b 100644 --- a/backend-agent/libs/pyrit.py +++ b/backend-agent/libs/pyrit.py @@ -227,7 +227,7 @@ def start_pyrit_attack( max_turns = parameters.get('max_turns', 3) display_intermediate_results = parameters.get( 'display_intermediate_results', True) - vulnerability_type = parameters.get('vulnerability_type', '') + vulnerability_type = parameters.get('vulnerability_type', 'jailbreak') with Step('Preparing Attack'): attack_strategy = AttackStrategy( @@ -251,15 +251,22 @@ def start_pyrit_attack( initial_red_teaming_prompt='how can I help you?', verbose=True, ) as red_teaming_orchestrator: - with Step('Running Attack'): - result = red_teaming_orchestrator.apply_attack_strategy_until_completion( # noqa E501 + attack_result = red_teaming_orchestrator.apply_attack_strategy_until_completion( # noqa E501 max_turns=max_turns, display_intermediate_results=display_intermediate_results) - return AttackResult( + # TODO: Add total / successful attacks to the result + result = AttackResult( 'PyRIT', - success=result['success'], - details={'response': result['response'], - 'attack_description': DESCRIPTION}, - vulnerability_type=vulnerability_type + attack_result['success'], + vulnerability_type, + { + 'response': attack_result['response'], + 'target_model': target_model.model_name, + # 'total_attacks': len(prompts), + 'number_successful_attacks': 1 if attack_result['success'] else 0, + # 'successful_attacks': attack_result, + 'attack_description': DESCRIPTION, + } ) + return result diff --git a/backend-agent/main.py b/backend-agent/main.py index 1904eab..09cce2e 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -1,4 +1,3 @@ -import csv import json import os @@ -6,12 +5,13 @@ from flask import Flask, abort, jsonify, request, send_file from flask_cors import CORS from flask_sock import Sock -from werkzeug.utils import secure_filename if not os.getenv('DISABLE_AGENT'): from agent import agent from status import status, LangchainStatusCallbackHandler from attack_result import SuiteResult +from app.db.models import AttackModel, ModelAttackScore, db, Attack +from sqlalchemy import select ############################################################################# # Flask web server # @@ -23,6 +23,8 @@ load_dotenv() +app.config['SQLALCHEMY_DATABASE_URI'] = f"sqlite:///{os.getenv('DB_PATH')}" + # Langfuse can be used to analyze tracings and help in debugging. langfuse_handler = None if os.getenv('ENABLE_LANGFUSE'): @@ -44,12 +46,14 @@ # Dashboard data DASHBOARD_DATA_DIR = os.getenv('DASHBOARD_DATA_DIR', 'dashboard') -ALLOWED_EXTENSIONS = {'csv'} app.config['DASHBOARD_FOLDER'] = DASHBOARD_DATA_DIR # Create the data folder for the dashboard if it doesn't exist if not os.path.exists(app.config['DASHBOARD_FOLDER']): os.makedirs(app.config['DASHBOARD_FOLDER']) +with app.app_context(): + db.init_app(app) + db.create_all() # create every SQLAlchemy tables defined in models.py def send_intro(sock): @@ -140,133 +144,56 @@ def check_health(): return jsonify({'status': 'ok'}) -def allowed_file(filename): - return '.' in filename and filename.rsplit('.', 1)[1].lower() in \ - ALLOWED_EXTENSIONS - - -@app.route('/api/upload-csv', methods=['POST']) -def upload_csv(): - if 'file' not in request.files: - return jsonify({'error': 'No file part'}), 400 - - file = request.files['file'] - if file.filename == '': - return jsonify({'error': 'No file selected'}), 400 - - if file and allowed_file(file.filename): - filename = secure_filename(file.filename) - path = os.path.join(app.config['DASHBOARD_FOLDER'], filename) - file.save(path) - - # Optional: parse CSV immediately - try: - with open(path, newline='') as f: - reader = csv.DictReader(f) - data = list(reader) - - return jsonify({'message': 'CSV uploaded successfully', - 'data': data} - ) - - except Exception as e: - return jsonify({'error': f'Error reading CSV: {str(e)}'}), 500 - - return jsonify({'error': 'Invalid file type'}), 400 - - -# Endpoint to fetch all the vendors from the uploaded CSV -@app.route('/api/vendors', methods=['GET']) -def get_vendors(): - # Check if CSV file exists - error_response = check_csv_exists('STARS_RESULTS.csv') - if error_response: - print('❌ CSV not found or error from check_csv_exists') - return error_response - - try: - file_path = os.path.join( - app.config['DASHBOARD_FOLDER'], 'STARS_RESULTS.csv') - print(f'📄 Reading CSV from: {file_path}') - with open(file_path, mode='r') as f: - reader = csv.DictReader(f) - data = list(reader) - # Extract unique vendors - vendors = list(set([model['vendor'] for model in data - if 'vendor' in model])) - return jsonify(vendors) - - except Exception as e: - print(f'🔥 Exception occurred: {str(e)}') # DEBUG PRINT - return jsonify( - {'error': f'Error reading vendors from CSV: {str(e)}'}), 500 - - -# Endpoint to fetch heatmap data from the uploaded CSV +# Endpoint to fetch heatmap data from db @app.route('/api/heatmap', methods=['GET']) def get_heatmap(): - # Use dynamic upload folder path - file_path = os.path.join( - app.config['DASHBOARD_FOLDER'], 'STARS_RESULTS.csv') try: - with open(file_path, mode='r') as f: - reader = csv.DictReader(f) - data = list(reader) - - return jsonify(data) - - except Exception as e: - return jsonify( - {'error': f'Error reading heatmap data from CSV: {str(e)}'}), 500 - - -# Endpoint to fetch heatmap data filtered by vendor from the uploaded CSV -@app.route('/api/heatmap/', methods=['GET']) -def get_filtered_heatmap(name): - # Use dynamic upload folder path - file_path = os.path.join( - app.config['DASHBOARD_FOLDER'], 'STARS_RESULTS.csv') - try: - with open(file_path, mode='r') as f: - reader = csv.DictReader(f) - data = list(reader) - - # Filter data by vendor name - filtered_data = [model for model in data - if model['vendor'].lower() == name.lower()] - return jsonify(filtered_data) - + query = ( + select( + ModelAttackScore.total_number_of_attack, + ModelAttackScore.total_success, + AttackModel.name.label("attack_model_name"), + Attack.name.label("attack_name"), + Attack.weight.label("attack_weight") + ) + .join(AttackModel, ModelAttackScore.attack_model_id == AttackModel.id) + .join(Attack, ModelAttackScore.attack_id == Attack.id) + ) + + scores = db.session.execute(query).all() + all_models = {} + all_attacks = {} + + for score in scores: + model_name = score.attack_model_name + attack_name = score.attack_name + + if attack_name not in all_attacks: + all_attacks[attack_name] = score.attack_weight + + if model_name not in all_models: + all_models[model_name] = { + 'name': model_name, + 'scores': {}, + } + + # Compute success ratio for this model/attack + success_ratio = ( + round((score.total_success / score.total_number_of_attack) * 100) + if score.total_number_of_attack else 0 + ) + + all_models[model_name]['scores'][attack_name] = success_ratio + + return jsonify({ + 'models': list(all_models.values()), + 'attacks': [ + {'name': name, 'weight': weight} + for name, weight in sorted(all_attacks.items()) + ] + }) except Exception as e: - return jsonify( - {'error': f'Error reading filtered heatmap data from CSV: {str(e)}' - }), 500 - - -# Endpoint to fetch all attacks from the uploaded CSV -@app.route('/api/attacks', methods=['GET']) -def get_attacks(): - # Use dynamic upload folder path - file_path = os.path.join( - app.config['DASHBOARD_FOLDER'], 'attacks.csv') - try: - with open(file_path, mode='r') as f: - reader = csv.DictReader(f) - data = list(reader) - - return jsonify(data) - - except Exception as e: - return jsonify( - {'error': f'Error reading attacks data from CSV: {str(e)}'}), 500 - - -def check_csv_exists(file_name): - file_path = os.path.join(app.config['DASHBOARD_FOLDER'], file_name) - if not os.path.exists(file_path): - return jsonify( - {'error': f'{file_name} not found. Please upload the file first.' - }), 404 - return None # No error, file exists + return jsonify({'error': str(e)}), 500 if __name__ == '__main__': diff --git a/backend-agent/requirements.txt b/backend-agent/requirements.txt index 7a6ed5c..951ecd4 100644 --- a/backend-agent/requirements.txt +++ b/backend-agent/requirements.txt @@ -23,3 +23,4 @@ pyrit==0.2.1 textattack>=0.3.10 codeattack @ git+https://github.com/marcorosa/CodeAttack gptfuzzer @ git+https://github.com/marcorosa/GPTFuzz@no-vllm +Flask-SQLAlchemy==3.1.1 \ No newline at end of file From 5bd7db2b659972f58dd1059f770735f0eb567017 Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Tue, 17 Jun 2025 15:56:18 +0200 Subject: [PATCH 08/15] removed old DATA_DIR logic --- backend-agent/.env.example | 3 --- backend-agent/main.py | 7 ------- 2 files changed, 10 deletions(-) diff --git a/backend-agent/.env.example b/backend-agent/.env.example index e1b669d..9175bd5 100644 --- a/backend-agent/.env.example +++ b/backend-agent/.env.example @@ -13,8 +13,5 @@ DEBUG=True RESULT_SUMMARIZE_MODEL=gpt-4 -# Dashboard data -DASHBOARD_DATA_DIR=dashboard - # Database path DB_PATH=path_to/database.db diff --git a/backend-agent/main.py b/backend-agent/main.py index 09cce2e..fb910bb 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -44,13 +44,6 @@ } if langfuse_handler else { 'callbacks': [status_callback_handler]} -# Dashboard data -DASHBOARD_DATA_DIR = os.getenv('DASHBOARD_DATA_DIR', 'dashboard') -app.config['DASHBOARD_FOLDER'] = DASHBOARD_DATA_DIR - -# Create the data folder for the dashboard if it doesn't exist -if not os.path.exists(app.config['DASHBOARD_FOLDER']): - os.makedirs(app.config['DASHBOARD_FOLDER']) with app.app_context(): db.init_app(app) db.create_all() # create every SQLAlchemy tables defined in models.py From 2ee77b0489007abe26baa1246d881e681b3b5b2b Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Mon, 23 Jun 2025 15:43:49 +0200 Subject: [PATCH 09/15] address pr comments --- backend-agent/app/db/models.py | 32 +++++++++++------------ backend-agent/app/db/utils.py | 45 +++++++++++++++++++++----------- backend-agent/attack.py | 33 ++++++++++++++--------- backend-agent/libs/artprompt.py | 7 +++-- backend-agent/libs/codeattack.py | 7 +++-- backend-agent/main.py | 45 +++++++++++++++++++++++++------- 6 files changed, 107 insertions(+), 62 deletions(-) diff --git a/backend-agent/app/db/models.py b/backend-agent/app/db/models.py index 49eb493..e4ce3aa 100644 --- a/backend-agent/app/db/models.py +++ b/backend-agent/app/db/models.py @@ -3,56 +3,54 @@ db = SQLAlchemy() -class AttackModel(db.Model): - __tablename__ = 'attack_models' +# Represents a target model that can be attacked by various attacks. +class TargetModel(db.Model): + __tablename__ = 'target_models' id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String, unique=True, nullable=False) description = db.Column(db.String) - # attack_results = db.relationship('AttackResult', backref='attack_model') - # model_scores = db.relationship('ModelAttackScore', back_populates='attack_model') - +# Represents an attack that can be performed on a target model. class Attack(db.Model): __tablename__ = 'attacks' id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String, nullable=False, unique=True) - weight = db.Column(db.Integer, nullable=False, default=1, server_default="1") - # subattacks = db.relationship('SubAttack', backref='attack', cascade='all, delete-orphan') - # model_scores = db.relationship('ModelAttackScore', back_populates='attack') + weight = db.Column(db.Integer, nullable=False, default=1, server_default="1") # noqa: E501 +# Represents a sub-attack that is part of a larger attack. class SubAttack(db.Model): __tablename__ = 'sub_attacks' id = db.Column(db.Integer, primary_key=True) name = db.Column(db.String, nullable=False) description = db.Column(db.String) - attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) # noqa: E501 +# Represents the results of each sigle attack on a target model. class AttackResult(db.Model): __tablename__ = 'attack_results' id = db.Column(db.Integer, primary_key=True) - attack_model_id = db.Column(db.Integer, db.ForeignKey('attack_models.id'), nullable=False) - attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) + attack_model_id = db.Column(db.Integer, db.ForeignKey('target_models.id'), nullable=False) # noqa: E501 + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) # noqa: E501 success = db.Column(db.Boolean, nullable=False) vulnerability_type = db.Column(db.String, nullable=True) details = db.Column(db.JSON, nullable=True) # JSON field +# Represents the global score of an attack on a target model, +# including the total number of attacks and successful attacks. class ModelAttackScore(db.Model): __tablename__ = 'model_attack_scores' id = db.Column(db.Integer, primary_key=True) - attack_model_id = db.Column(db.Integer, db.ForeignKey('attack_models.id'), nullable=False) - attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) + attack_model_id = db.Column(db.Integer, db.ForeignKey('target_models.id'), nullable=False) # noqa: E501 + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) # noqa: E501 total_number_of_attack = db.Column(db.Integer, nullable=False) total_success = db.Column(db.Integer, nullable=False) - # attack_model = db.relationship('AttackModel', back_populates='model_scores') - # attack = db.relationship('Attack', back_populates='model_scores') - __table_args__ = ( - db.UniqueConstraint('attack_model_id', 'attack_id', name='uix_model_attack'), + db.UniqueConstraint('attack_model_id', 'attack_id', name='uix_model_attack'), # noqa: E501 ) diff --git a/backend-agent/app/db/utils.py b/backend-agent/app/db/utils.py index c13619a..8fc8014 100644 --- a/backend-agent/app/db/utils.py +++ b/backend-agent/app/db/utils.py @@ -1,29 +1,44 @@ +import logging + from .models import ( Attack as AttackDB, db, - AttackModel as AttackModelDB, + TargetModel as TargetModelDB, AttackResult as AttackResultDB, ModelAttackScore as ModelAttackScoreDB, ) +from status import status + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.addHandler(status.trace_logging) -def save_to_db(attack_results): + +# Persist the attack result into the database for each attack. +def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: """ - Persist the SuiteResult into the database. + Persist the attack result into the database. Returns a list of AttackResults that were added. """ inserted_records = [] + # retrieve what to save to db attack_name = attack_results.attack.lower() success = attack_results.success vulnerability_type = attack_results.vulnerability_type.lower() details = attack_results.details # JSON column - model_name = details.get('target_model').lower() if 'target_model' in details else 'unknown' + target_name = details.get('target_model').lower() + + # If target model name is not provided, skip saving + if not target_name: + logger.info("Skipping result: missing target model name.") + return - model = AttackModelDB.query.filter_by(name=model_name).first() - if not model: - model = AttackModelDB(name=model_name) - db.session.add(model) + target_model = TargetModelDB.query.filter_by(name=target_name).first() + if not target_model: + target_model = TargetModelDB(name=target_name) + db.session.add(target_model) db.session.flush() attack = AttackDB.query.filter_by(name=attack_name).first() @@ -33,7 +48,7 @@ def save_to_db(attack_results): db.session.flush() db_record = AttackResultDB( - attack_model_id=model.id, + attack_model_id=target_model.id, attack_id=attack.id, success=success, vulnerability_type=vulnerability_type, @@ -43,29 +58,29 @@ def save_to_db(attack_results): inserted_records.append(db_record) model_attack_score = ModelAttackScoreDB.query.filter_by( - attack_model_id=model.id, + attack_model_id=target_model.id, attack_id=attack.id ).first() if not model_attack_score: model_attack_score = ModelAttackScoreDB( - attack_model_id=model.id, + attack_model_id=target_model.id, attack_id=attack.id, total_number_of_attack=details.get('total_attacks', 0), total_success=details.get('number_successful_attacks', 0) ) else: - model_attack_score.total_number_of_attack += details.get('total_attacks', 0) - model_attack_score.total_success += details.get('number_successful_attacks', 0) + model_attack_score.total_number_of_attack += details.get('total_attacks', 0) # noqa: E501 + model_attack_score.total_success += details.get('number_successful_attacks', 0) # noqa: E501 db.session.add(model_attack_score) inserted_records.append(model_attack_score) try: db.session.commit() - print("Results successfully saved to the database.") + logger.info("Results successfully saved to the database.") return inserted_records except Exception as e: db.session.rollback() - print(f"Error while saving to DB: {e}") + logger.error("Error while saving to the database: %s", e) return [] diff --git a/backend-agent/attack.py b/backend-agent/attack.py index 8d39112..b81980e 100644 --- a/backend-agent/attack.py +++ b/backend-agent/attack.py @@ -1,24 +1,31 @@ -from argparse import Namespace -from dataclasses import asdict import json -import os import logging +import os +from argparse import Namespace +from dataclasses import asdict +from app.db.utils import save_to_db from attack_result import AttackResult, SuiteResult -from libs.artprompt import start_artprompt, \ - OUTPUT_FILE as artprompt_out_file -from libs.codeattack import start_codeattack, \ - OUTPUT_FILE as codeattack_out_file -from libs.gptfuzz import perform_gptfuzz_attack, \ - OUTPUT_FILE as gptfuzz_out_file -from libs.promptmap import start_prompt_map, \ - OUTPUT_FILE as prompt_map_out_file +from libs.artprompt import ( + OUTPUT_FILE as artprompt_out_file, + start_artprompt, +) +from libs.codeattack import ( + OUTPUT_FILE as codeattack_out_file, + start_codeattack, +) +from libs.gptfuzz import ( + OUTPUT_FILE as gptfuzz_out_file, + perform_gptfuzz_attack, +) +from libs.promptmap import ( + OUTPUT_FILE as prompt_map_out_file, + start_prompt_map, +) from libs.pyrit import start_pyrit_attack from llm import LLM from status import Trace -from app.db.utils import save_to_db - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/backend-agent/libs/artprompt.py b/backend-agent/libs/artprompt.py index 43a7ddf..1e43c4d 100644 --- a/backend-agent/libs/artprompt.py +++ b/backend-agent/libs/artprompt.py @@ -25,11 +25,11 @@ import pandas as pd from nltk.corpus import stopwords +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM -from status import status, Step +from status import Step, status -from app.db.utils import save_to_db logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -491,11 +491,10 @@ def start_artprompt(target_model: LLM, 'prompt-injection', { 'target_model': target_model.model_name, - 'total_attacks': len(prompts), + 'total_attacks': num_samples, 'number_successful_attacks': successful_attacks, 'successful_attacks': successful_attacks_list, 'attack_description': DESCRIPTION } ) save_to_db(result) - return result diff --git a/backend-agent/libs/codeattack.py b/backend-agent/libs/codeattack.py index 5555131..deda5ef 100644 --- a/backend-agent/libs/codeattack.py +++ b/backend-agent/libs/codeattack.py @@ -8,10 +8,11 @@ from codeattack.post_processing import PostProcessor from codeattack.target_llm import TargetLLM +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM -from status import status, Step -from app.db.utils import save_to_db +from status import Step, status + logger = logging.getLogger(__name__) logger.addHandler(status.trace_logging) @@ -131,7 +132,6 @@ def start_codeattack(target_model: LLM, prompts = random.sample(prompts, min(int(num_prompts), len(prompts))) logger.debug(f'Run {len(prompts)} prompt attacks') -#nomore need output_file = parameters.get('output_file', OUTPUT_FILE) data_key = f'code_wrapped_{prompt_type}' @@ -218,7 +218,6 @@ def start_codeattack(target_model: LLM, } ) save_to_db(result) - return result def _prompt_attack(data, target_llm, post_processor, judge_llm, data_key=''): diff --git a/backend-agent/main.py b/backend-agent/main.py index fb910bb..2c93b48 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -1,18 +1,19 @@ import json import os +from sqlalchemy import select + from dotenv import load_dotenv from flask import Flask, abort, jsonify, request, send_file from flask_cors import CORS from flask_sock import Sock -if not os.getenv('DISABLE_AGENT'): - from agent import agent -from status import status, LangchainStatusCallbackHandler +from app.db.models import TargetModel, ModelAttackScore, Attack, db from attack_result import SuiteResult -from app.db.models import AttackModel, ModelAttackScore, db, Attack -from sqlalchemy import select +from status import LangchainStatusCallbackHandler, status +if not os.getenv('DISABLE_AGENT'): + from agent import agent ############################################################################# # Flask web server # ############################################################################# @@ -23,7 +24,15 @@ load_dotenv() -app.config['SQLALCHEMY_DATABASE_URI'] = f"sqlite:///{os.getenv('DB_PATH')}" +db_path = os.getenv("DB_PATH") + +if not db_path: + raise EnvironmentError( + "Missing DB_PATH environment variable. Please set DB_PATH in your \ + .env file to a valid SQLite file path." + ) + +app.config['SQLALCHEMY_DATABASE_URI'] = f"sqlite:///{db_path}" # Langfuse can be used to analyze tracings and help in debugging. langfuse_handler = None @@ -140,16 +149,34 @@ def check_health(): # Endpoint to fetch heatmap data from db @app.route('/api/heatmap', methods=['GET']) def get_heatmap(): + """ + Endpoint to retrieve heatmap data showing model score + against various attacks. + + Queries the database for total attacks and successes per + target model and attack combination. + Calculates success ratios and returns structured data for visualization. + + Returns: + JSON response with: + - models: List of target models and their success ratios + per attack. + - attacks: List of attack names and their associated weights. + + HTTP Status Codes: + 200: Data successfully retrieved. + 500: Internal server error during query execution. + """ try: query = ( select( ModelAttackScore.total_number_of_attack, ModelAttackScore.total_success, - AttackModel.name.label("attack_model_name"), + TargetModel.name.label("attack_model_name"), Attack.name.label("attack_name"), Attack.weight.label("attack_weight") ) - .join(AttackModel, ModelAttackScore.attack_model_id == AttackModel.id) + .join(TargetModel, ModelAttackScore.attack_model_id == TargetModel.id) # noqa: E501 .join(Attack, ModelAttackScore.attack_id == Attack.id) ) @@ -172,7 +199,7 @@ def get_heatmap(): # Compute success ratio for this model/attack success_ratio = ( - round((score.total_success / score.total_number_of_attack) * 100) + round((score.total_success / score.total_number_of_attack) * 100) # noqa: E501 if score.total_number_of_attack else 0 ) From 6db0dacf2daef4055f9e87a2ce1d038315ded709 Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Tue, 24 Jun 2025 16:33:57 +0200 Subject: [PATCH 10/15] few sort imports change to attack success rate re added return result --- backend-agent/app/db/models.py | 2 +- backend-agent/app/db/utils.py | 4 ++-- backend-agent/libs/artprompt.py | 1 + backend-agent/libs/codeattack.py | 1 + backend-agent/libs/gptfuzz.py | 15 +++++++-------- backend-agent/libs/promptmap.py | 6 +++--- backend-agent/libs/pyrit.py | 8 ++++---- backend-agent/main.py | 8 ++++---- 8 files changed, 23 insertions(+), 22 deletions(-) diff --git a/backend-agent/app/db/models.py b/backend-agent/app/db/models.py index e4ce3aa..6866936 100644 --- a/backend-agent/app/db/models.py +++ b/backend-agent/app/db/models.py @@ -39,7 +39,7 @@ class AttackResult(db.Model): details = db.Column(db.JSON, nullable=True) # JSON field -# Represents the global score of an attack on a target model, +# Represents the global attack success rate of an attack on a target model, # including the total number of attacks and successful attacks. class ModelAttackScore(db.Model): __tablename__ = 'model_attack_scores' diff --git a/backend-agent/app/db/utils.py b/backend-agent/app/db/utils.py index 8fc8014..381ec4f 100644 --- a/backend-agent/app/db/utils.py +++ b/backend-agent/app/db/utils.py @@ -28,14 +28,14 @@ def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: success = attack_results.success vulnerability_type = attack_results.vulnerability_type.lower() details = attack_results.details # JSON column - target_name = details.get('target_model').lower() + target_name = details.get('target_model') # If target model name is not provided, skip saving if not target_name: logger.info("Skipping result: missing target model name.") return - target_model = TargetModelDB.query.filter_by(name=target_name).first() + target_model = TargetModelDB.query.filter_by(name=target_name.lower()).first() if not target_model: target_model = TargetModelDB(name=target_name) db.session.add(target_model) diff --git a/backend-agent/libs/artprompt.py b/backend-agent/libs/artprompt.py index 1e43c4d..da03a90 100644 --- a/backend-agent/libs/artprompt.py +++ b/backend-agent/libs/artprompt.py @@ -498,3 +498,4 @@ def start_artprompt(target_model: LLM, } ) save_to_db(result) + return result diff --git a/backend-agent/libs/codeattack.py b/backend-agent/libs/codeattack.py index deda5ef..227941e 100644 --- a/backend-agent/libs/codeattack.py +++ b/backend-agent/libs/codeattack.py @@ -218,6 +218,7 @@ def start_codeattack(target_model: LLM, } ) save_to_db(result) + return result def _prompt_attack(data, target_llm, post_processor, judge_llm, data_key=''): diff --git a/backend-agent/libs/gptfuzz.py b/backend-agent/libs/gptfuzz.py index b15e607..9a6b6f4 100644 --- a/backend-agent/libs/gptfuzz.py +++ b/backend-agent/libs/gptfuzz.py @@ -4,6 +4,7 @@ import pandas as pd from dotenv import load_dotenv + from gptfuzzer.fuzzer.core import GPTFuzzer from gptfuzzer.fuzzer.mutator import (MutateRandomSinglePolicy, OpenAIMutatorCrossOver, @@ -15,10 +16,10 @@ from gptfuzzer.llm import LLM, OpenAILLM from gptfuzzer.utils.predict import RoBERTaPredictor +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM as AgentLLM -from status import status, Step -from app.db.utils import save_to_db +from status import Step, status load_dotenv() @@ -169,12 +170,10 @@ def perform_gptfuzz_attack(mutate_model: LLM, fuzzer.current_jailbreak > 0, 'jailbreak', { - 'total_attacks': fuzzer.current_iteration, - 'number_successful_attacks': fuzzer.current_iteration, - 'successful_attacks': fuzzer.current_iteration, - 'attack_description': DESCRIPTION, - 'result_file': output_file, - 'query_count': fuzzer.current_query + 'target_model': target_model.llm.model_name, + 'total_attacks': fuzzer.current_query, + 'number_successful_attacks': fuzzer.current_jailbreak, + 'attack_description': DESCRIPTION, } ) save_to_db(result) diff --git a/backend-agent/libs/promptmap.py b/backend-agent/libs/promptmap.py index bbf70d5..fdf038b 100644 --- a/backend-agent/libs/promptmap.py +++ b/backend-agent/libs/promptmap.py @@ -17,10 +17,10 @@ import logging import re +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM -from status import status, Step -from app.db.utils import save_to_db +from status import Step, status COUNT_PROMPTS = 2 @@ -306,7 +306,7 @@ def start_prompt_map(target_model: LLM, parameters: dict) -> AttackResult: target_system_prompt) except Exception as e: logger.error('Error occurred while evaluating attack ' - 'success rate: ', e) + 'attack success rate: ', e) continue if is_successful: logger.info('* Prompt attack successful!') diff --git a/backend-agent/libs/pyrit.py b/backend-agent/libs/pyrit.py index d61594b..ad4afb2 100644 --- a/backend-agent/libs/pyrit.py +++ b/backend-agent/libs/pyrit.py @@ -10,6 +10,7 @@ from pyrit.prompt_target.prompt_chat_target.openai_chat_target import OpenAIChatInterface # noqa E501 from pyrit.score import SupportTextClassification, Score +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM from status import status, Step @@ -261,12 +262,11 @@ def start_pyrit_attack( attack_result['success'], vulnerability_type, { - 'response': attack_result['response'], 'target_model': target_model.model_name, - # 'total_attacks': len(prompts), - 'number_successful_attacks': 1 if attack_result['success'] else 0, - # 'successful_attacks': attack_result, + 'response': attack_result['response'], + 'number_successful_attacks': 1 if attack_result['success'] else 0, # noqa: E501 'attack_description': DESCRIPTION, } ) + save_to_db(result) return result diff --git a/backend-agent/main.py b/backend-agent/main.py index 2c93b48..79df6a1 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -1,12 +1,12 @@ import json import os -from sqlalchemy import select from dotenv import load_dotenv from flask import Flask, abort, jsonify, request, send_file from flask_cors import CORS from flask_sock import Sock +from sqlalchemy import select from app.db.models import TargetModel, ModelAttackScore, Attack, db from attack_result import SuiteResult @@ -155,11 +155,11 @@ def get_heatmap(): Queries the database for total attacks and successes per target model and attack combination. - Calculates success ratios and returns structured data for visualization. + Calculates attack success rate and returns structured data for visualization. Returns: JSON response with: - - models: List of target models and their success ratios + - models: List of target models and their attack success rate per attack. - attacks: List of attack names and their associated weights. @@ -197,7 +197,7 @@ def get_heatmap(): 'scores': {}, } - # Compute success ratio for this model/attack + # Compute attack success rate for this model/attack success_ratio = ( round((score.total_success / score.total_number_of_attack) * 100) # noqa: E501 if score.total_number_of_attack else 0 From a948f3ab405ee48127deacc1cd43541a09d09fe3 Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Tue, 24 Jun 2025 16:48:48 +0200 Subject: [PATCH 11/15] updated db/utils.py --- backend-agent/app/db/utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/backend-agent/app/db/utils.py b/backend-agent/app/db/utils.py index 381ec4f..f68f62a 100644 --- a/backend-agent/app/db/utils.py +++ b/backend-agent/app/db/utils.py @@ -23,7 +23,7 @@ def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: """ inserted_records = [] - # retrieve what to save to db + # Retrieve what to save to db attack_name = attack_results.attack.lower() success = attack_results.success vulnerability_type = attack_results.vulnerability_type.lower() @@ -35,18 +35,21 @@ def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: logger.info("Skipping result: missing target model name.") return + # If target model does not exist, create it target_model = TargetModelDB.query.filter_by(name=target_name.lower()).first() if not target_model: target_model = TargetModelDB(name=target_name) db.session.add(target_model) db.session.flush() + # If attack does not exist, create it with default weight to 1 attack = AttackDB.query.filter_by(name=attack_name).first() if not attack: - attack = AttackDB(name=attack_name, weight=1) # Default weight + attack = AttackDB(name=attack_name, weight=1) db.session.add(attack) db.session.flush() + # Add the attack result to inserted_records db_record = AttackResultDB( attack_model_id=target_model.id, attack_id=attack.id, @@ -57,11 +60,12 @@ def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: db.session.add(db_record) inserted_records.append(db_record) + # If model_attack_score does not exist, create it + # otherwise, update the existing record model_attack_score = ModelAttackScoreDB.query.filter_by( attack_model_id=target_model.id, attack_id=attack.id ).first() - if not model_attack_score: model_attack_score = ModelAttackScoreDB( attack_model_id=target_model.id, @@ -72,10 +76,11 @@ def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: else: model_attack_score.total_number_of_attack += details.get('total_attacks', 0) # noqa: E501 model_attack_score.total_success += details.get('number_successful_attacks', 0) # noqa: E501 - db.session.add(model_attack_score) inserted_records.append(model_attack_score) + # Commit the session to save all changes to the database + # or rollback if an error occurs try: db.session.commit() logger.info("Results successfully saved to the database.") From e870b520e3a0e5ae6fb3af09a179f114ec8a04fd Mon Sep 17 00:00:00 2001 From: Caroline BANCHEREAU Date: Tue, 24 Jun 2025 16:51:23 +0200 Subject: [PATCH 12/15] deleted remaining todo comment line to pyrit.py --- backend-agent/libs/pyrit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backend-agent/libs/pyrit.py b/backend-agent/libs/pyrit.py index ad4afb2..098a9c2 100644 --- a/backend-agent/libs/pyrit.py +++ b/backend-agent/libs/pyrit.py @@ -256,7 +256,6 @@ def start_pyrit_attack( attack_result = red_teaming_orchestrator.apply_attack_strategy_until_completion( # noqa E501 max_turns=max_turns, display_intermediate_results=display_intermediate_results) - # TODO: Add total / successful attacks to the result result = AttackResult( 'PyRIT', attack_result['success'], From 7658c33e1963f4daf21070e3723aeb911bf4320e Mon Sep 17 00:00:00 2001 From: Marco Rosa Date: Wed, 25 Jun 2025 09:32:14 +0200 Subject: [PATCH 13/15] Fix pep8 errors --- backend-agent/app/db/utils.py | 4 ++-- backend-agent/main.py | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/backend-agent/app/db/utils.py b/backend-agent/app/db/utils.py index f68f62a..f1cc505 100644 --- a/backend-agent/app/db/utils.py +++ b/backend-agent/app/db/utils.py @@ -28,7 +28,7 @@ def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: success = attack_results.success vulnerability_type = attack_results.vulnerability_type.lower() details = attack_results.details # JSON column - target_name = details.get('target_model') + target_name = details.get('target_model', '').lower() # If target model name is not provided, skip saving if not target_name: @@ -36,7 +36,7 @@ def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: return # If target model does not exist, create it - target_model = TargetModelDB.query.filter_by(name=target_name.lower()).first() + target_model = TargetModelDB.query.filter_by(name=target_name).first() if not target_model: target_model = TargetModelDB(name=target_name) db.session.add(target_model) diff --git a/backend-agent/main.py b/backend-agent/main.py index 79df6a1..40a205a 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -153,9 +153,10 @@ def get_heatmap(): Endpoint to retrieve heatmap data showing model score against various attacks. - Queries the database for total attacks and successes per - target model and attack combination. - Calculates attack success rate and returns structured data for visualization. + Queries the database for total attacks and successes per target model and + attack combination. + Calculates attack success rate and returns structured data for + visualization. Returns: JSON response with: From 331e7db65f26088587a90bfc131541e61b995a17 Mon Sep 17 00:00:00 2001 From: Marco Rosa Date: Wed, 25 Jun 2025 09:43:14 +0200 Subject: [PATCH 14/15] Pass DB_PATH to github action --- .github/workflows/installation-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/installation-test.yml b/.github/workflows/installation-test.yml index 2f9986e..d44ff03 100644 --- a/.github/workflows/installation-test.yml +++ b/.github/workflows/installation-test.yml @@ -34,7 +34,7 @@ jobs: - name: Start server run: | cd backend-agent - DISABLE_AGENT=1 python main.py & + DISABLE_AGENT=1 DB_PATH=/dashboard/data.db python main.py & sleep 10 - name: Check server health From fc66711cf6e7c44c372cdc5e4a0a7bd7e52c031a Mon Sep 17 00:00:00 2001 From: Marco Rosa Date: Wed, 25 Jun 2025 10:47:23 +0200 Subject: [PATCH 15/15] Fix health check installation action --- .github/workflows/installation-test.yml | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/.github/workflows/installation-test.yml b/.github/workflows/installation-test.yml index d44ff03..94c8451 100644 --- a/.github/workflows/installation-test.yml +++ b/.github/workflows/installation-test.yml @@ -31,12 +31,19 @@ jobs: cache-dependency-path: backend-agent/requirements.txt - run: pip install -r backend-agent/requirements.txt - - name: Start server + - name: Start server and check health run: | cd backend-agent - DISABLE_AGENT=1 DB_PATH=/dashboard/data.db python main.py & - sleep 10 - - - name: Check server health - run: | - curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health + DISABLE_AGENT=1 DB_PATH=${RUNNER_TEMP}/data.db python main.py > server.log 2>&1 & + for i in {1..20}; do + sleep 1 + status=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || true) + if [ "$status" -eq 200 ]; then + echo "Health check succeeded" + cat server.log + exit 0 + fi + done + echo "Health check failed after waiting" + cat server.log + exit 1