diff --git a/.github/workflows/installation-test.yml b/.github/workflows/installation-test.yml index 2f9986e..94c8451 100644 --- a/.github/workflows/installation-test.yml +++ b/.github/workflows/installation-test.yml @@ -31,12 +31,19 @@ jobs: cache-dependency-path: backend-agent/requirements.txt - run: pip install -r backend-agent/requirements.txt - - name: Start server + - name: Start server and check health run: | cd backend-agent - DISABLE_AGENT=1 python main.py & - sleep 10 - - - name: Check server health - run: | - curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health + DISABLE_AGENT=1 DB_PATH=${RUNNER_TEMP}/data.db python main.py > server.log 2>&1 & + for i in {1..20}; do + sleep 1 + status=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || true) + if [ "$status" -eq 200 ]; then + echo "Health check succeeded" + cat server.log + exit 0 + fi + done + echo "Health check failed after waiting" + cat server.log + exit 1 diff --git a/.gitignore b/.gitignore index 0887fc1..f9a5ba7 100644 --- a/.gitignore +++ b/.gitignore @@ -89,6 +89,8 @@ venv/ ENV/ env.bak/ venv.bak/ +venv310 +cache # Spyder project settings .spyderproject diff --git a/backend-agent/.env.example b/backend-agent/.env.example index 9bc0af0..9175bd5 100644 --- a/backend-agent/.env.example +++ b/backend-agent/.env.example @@ -12,3 +12,6 @@ API_KEY=super-secret-change-me DEBUG=True RESULT_SUMMARIZE_MODEL=gpt-4 + +# Database path +DB_PATH=path_to/database.db diff --git a/backend-agent/app/db/models.py b/backend-agent/app/db/models.py new file mode 100644 index 0000000..6866936 --- /dev/null +++ b/backend-agent/app/db/models.py @@ -0,0 +1,57 @@ +from flask_sqlalchemy import SQLAlchemy + +db = SQLAlchemy() + + +# Represents a target model that can be attacked by various attacks. +class TargetModel(db.Model): + __tablename__ = 'target_models' + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String, unique=True, nullable=False) + description = db.Column(db.String) + + +# Represents an attack that can be performed on a target model. +class Attack(db.Model): + __tablename__ = 'attacks' + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String, nullable=False, unique=True) + weight = db.Column(db.Integer, nullable=False, default=1, server_default="1") # noqa: E501 + + +# Represents a sub-attack that is part of a larger attack. +class SubAttack(db.Model): + __tablename__ = 'sub_attacks' + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String, nullable=False) + description = db.Column(db.String) + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) # noqa: E501 + + +# Represents the results of each sigle attack on a target model. +class AttackResult(db.Model): + __tablename__ = 'attack_results' + id = db.Column(db.Integer, primary_key=True) + attack_model_id = db.Column(db.Integer, db.ForeignKey('target_models.id'), nullable=False) # noqa: E501 + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) # noqa: E501 + success = db.Column(db.Boolean, nullable=False) + vulnerability_type = db.Column(db.String, nullable=True) + details = db.Column(db.JSON, nullable=True) # JSON field + + +# Represents the global attack success rate of an attack on a target model, +# including the total number of attacks and successful attacks. +class ModelAttackScore(db.Model): + __tablename__ = 'model_attack_scores' + id = db.Column(db.Integer, primary_key=True) + attack_model_id = db.Column(db.Integer, db.ForeignKey('target_models.id'), nullable=False) # noqa: E501 + attack_id = db.Column(db.Integer, db.ForeignKey('attacks.id'), nullable=False) # noqa: E501 + total_number_of_attack = db.Column(db.Integer, nullable=False) + total_success = db.Column(db.Integer, nullable=False) + + __table_args__ = ( + db.UniqueConstraint('attack_model_id', 'attack_id', name='uix_model_attack'), # noqa: E501 + ) + + +db.configure_mappers() diff --git a/backend-agent/app/db/utils.py b/backend-agent/app/db/utils.py new file mode 100644 index 0000000..f1cc505 --- /dev/null +++ b/backend-agent/app/db/utils.py @@ -0,0 +1,91 @@ +import logging + +from .models import ( + Attack as AttackDB, + db, + TargetModel as TargetModelDB, + AttackResult as AttackResultDB, + ModelAttackScore as ModelAttackScoreDB, +) + +from status import status + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) +logger.addHandler(status.trace_logging) + + +# Persist the attack result into the database for each attack. +def save_to_db(attack_results: AttackResultDB) -> list[AttackResultDB]: + """ + Persist the attack result into the database. + Returns a list of AttackResults that were added. + """ + inserted_records = [] + + # Retrieve what to save to db + attack_name = attack_results.attack.lower() + success = attack_results.success + vulnerability_type = attack_results.vulnerability_type.lower() + details = attack_results.details # JSON column + target_name = details.get('target_model', '').lower() + + # If target model name is not provided, skip saving + if not target_name: + logger.info("Skipping result: missing target model name.") + return + + # If target model does not exist, create it + target_model = TargetModelDB.query.filter_by(name=target_name).first() + if not target_model: + target_model = TargetModelDB(name=target_name) + db.session.add(target_model) + db.session.flush() + + # If attack does not exist, create it with default weight to 1 + attack = AttackDB.query.filter_by(name=attack_name).first() + if not attack: + attack = AttackDB(name=attack_name, weight=1) + db.session.add(attack) + db.session.flush() + + # Add the attack result to inserted_records + db_record = AttackResultDB( + attack_model_id=target_model.id, + attack_id=attack.id, + success=success, + vulnerability_type=vulnerability_type, + details=details, + ) + db.session.add(db_record) + inserted_records.append(db_record) + + # If model_attack_score does not exist, create it + # otherwise, update the existing record + model_attack_score = ModelAttackScoreDB.query.filter_by( + attack_model_id=target_model.id, + attack_id=attack.id + ).first() + if not model_attack_score: + model_attack_score = ModelAttackScoreDB( + attack_model_id=target_model.id, + attack_id=attack.id, + total_number_of_attack=details.get('total_attacks', 0), + total_success=details.get('number_successful_attacks', 0) + ) + else: + model_attack_score.total_number_of_attack += details.get('total_attacks', 0) # noqa: E501 + model_attack_score.total_success += details.get('number_successful_attacks', 0) # noqa: E501 + db.session.add(model_attack_score) + inserted_records.append(model_attack_score) + + # Commit the session to save all changes to the database + # or rollback if an error occurs + try: + db.session.commit() + logger.info("Results successfully saved to the database.") + return inserted_records + except Exception as e: + db.session.rollback() + logger.error("Error while saving to the database: %s", e) + return [] diff --git a/backend-agent/attack.py b/backend-agent/attack.py index a394307..b81980e 100644 --- a/backend-agent/attack.py +++ b/backend-agent/attack.py @@ -1,18 +1,27 @@ -from argparse import Namespace -from dataclasses import asdict import json -import os import logging +import os +from argparse import Namespace +from dataclasses import asdict +from app.db.utils import save_to_db from attack_result import AttackResult, SuiteResult -from libs.artprompt import start_artprompt, \ - OUTPUT_FILE as artprompt_out_file -from libs.codeattack import start_codeattack, \ - OUTPUT_FILE as codeattack_out_file -from libs.gptfuzz import perform_gptfuzz_attack, \ - OUTPUT_FILE as gptfuzz_out_file -from libs.promptmap import start_prompt_map, \ - OUTPUT_FILE as prompt_map_out_file +from libs.artprompt import ( + OUTPUT_FILE as artprompt_out_file, + start_artprompt, +) +from libs.codeattack import ( + OUTPUT_FILE as codeattack_out_file, + start_codeattack, +) +from libs.gptfuzz import ( + OUTPUT_FILE as gptfuzz_out_file, + perform_gptfuzz_attack, +) +from libs.promptmap import ( + OUTPUT_FILE as prompt_map_out_file, + start_prompt_map, +) from libs.pyrit import start_pyrit_attack from llm import LLM from status import Trace @@ -247,6 +256,7 @@ def run(self, summarize_by_llm: bool = False) -> SuiteResult: summary = self.summarize_attack_result(result) result.details['summary'] = summary full_result.append(result) + save_to_db(result) return SuiteResult(full_result) def summarize_attack_result(self, attack_result: AttackResult) -> str: diff --git a/backend-agent/libs/artprompt.py b/backend-agent/libs/artprompt.py index 20a0895..da03a90 100644 --- a/backend-agent/libs/artprompt.py +++ b/backend-agent/libs/artprompt.py @@ -25,9 +25,11 @@ import pandas as pd from nltk.corpus import stopwords +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM -from status import status, Step +from status import Step, status + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -483,14 +485,17 @@ def start_artprompt(target_model: LLM, logger.info(f'Write results to output file {outfile}') with open(outfile, 'w') as f: json.dump(evaluations_res, f, indent=4) - - return AttackResult( + result = AttackResult( 'artprompt', successful_attacks > 0, 'prompt-injection', { + 'target_model': target_model.model_name, + 'total_attacks': num_samples, 'number_successful_attacks': successful_attacks, 'successful_attacks': successful_attacks_list, 'attack_description': DESCRIPTION } ) + save_to_db(result) + return result diff --git a/backend-agent/libs/codeattack.py b/backend-agent/libs/codeattack.py index d906712..227941e 100644 --- a/backend-agent/libs/codeattack.py +++ b/backend-agent/libs/codeattack.py @@ -8,9 +8,10 @@ from codeattack.post_processing import PostProcessor from codeattack.target_llm import TargetLLM +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM -from status import status, Step +from status import Step, status logger = logging.getLogger(__name__) @@ -131,7 +132,6 @@ def start_codeattack(target_model: LLM, prompts = random.sample(prompts, min(int(num_prompts), len(prompts))) logger.debug(f'Run {len(prompts)} prompt attacks') - output_file = parameters.get('output_file', OUTPUT_FILE) data_key = f'code_wrapped_{prompt_type}' @@ -204,16 +204,21 @@ def start_codeattack(target_model: LLM, # # Write results to file with open(output_file, 'w') as f: json.dump(successful_attacks_list, f) - return AttackResult( + + result = AttackResult( 'codeattack', successful_attacks > 0, 'prompt-injection', { + 'target_model': target_model.model_name, + 'total_attacks': len(prompts), 'number_successful_attacks': successful_attacks, 'successful_attacks': successful_attacks_list, 'attack_description': DESCRIPTION } ) + save_to_db(result) + return result def _prompt_attack(data, target_llm, post_processor, judge_llm, data_key=''): diff --git a/backend-agent/libs/gptfuzz.py b/backend-agent/libs/gptfuzz.py index 8cc4abf..9a6b6f4 100644 --- a/backend-agent/libs/gptfuzz.py +++ b/backend-agent/libs/gptfuzz.py @@ -4,6 +4,7 @@ import pandas as pd from dotenv import load_dotenv + from gptfuzzer.fuzzer.core import GPTFuzzer from gptfuzzer.fuzzer.mutator import (MutateRandomSinglePolicy, OpenAIMutatorCrossOver, @@ -15,9 +16,10 @@ from gptfuzzer.llm import LLM, OpenAILLM from gptfuzzer.utils.predict import RoBERTaPredictor +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM as AgentLLM -from status import status, Step +from status import Step, status load_dotenv() @@ -163,13 +165,16 @@ def perform_gptfuzz_attack(mutate_model: LLM, with Step('Running Fuzzer'): fuzzer.run() logger.info('Fuzzer finished') - return AttackResult( + result = AttackResult( 'gptfuzz', fuzzer.current_jailbreak > 0, 'jailbreak', - details={ - 'result_file': output_file, - 'query_count': fuzzer.current_query, - 'attack_description': DESCRIPTION + { + 'target_model': target_model.llm.model_name, + 'total_attacks': fuzzer.current_query, + 'number_successful_attacks': fuzzer.current_jailbreak, + 'attack_description': DESCRIPTION, } ) + save_to_db(result) + return result diff --git a/backend-agent/libs/promptmap.py b/backend-agent/libs/promptmap.py index 3757d59..fdf038b 100644 --- a/backend-agent/libs/promptmap.py +++ b/backend-agent/libs/promptmap.py @@ -17,9 +17,10 @@ import logging import re +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM -from status import status, Step +from status import Step, status COUNT_PROMPTS = 2 @@ -305,7 +306,7 @@ def start_prompt_map(target_model: LLM, parameters: dict) -> AttackResult: target_system_prompt) except Exception as e: logger.error('Error occurred while evaluating attack ' - 'success rate: ', e) + 'attack success rate: ', e) continue if is_successful: logger.info('* Prompt attack successful!') @@ -338,12 +339,16 @@ def start_prompt_map(target_model: LLM, parameters: dict) -> AttackResult: # Write results to file with open(output_file, 'w') as f: json.dump(successful_attacks_json, f) - return AttackResult( + result = AttackResult( 'promptmap', security_failed > 0, 'prompt-injection', { + 'total_attacks': total_attack_count, + 'number_successful_attacks': len(successful_attacks), 'successful_attacks': successful_attacks_json, 'attack_description': DESCRIPTION } ) + save_to_db(result) + return result diff --git a/backend-agent/libs/pyrit.py b/backend-agent/libs/pyrit.py index 7f7a520..098a9c2 100644 --- a/backend-agent/libs/pyrit.py +++ b/backend-agent/libs/pyrit.py @@ -10,6 +10,7 @@ from pyrit.prompt_target.prompt_chat_target.openai_chat_target import OpenAIChatInterface # noqa E501 from pyrit.score import SupportTextClassification, Score +from app.db.utils import save_to_db from attack_result import AttackResult from llm import LLM from status import status, Step @@ -227,7 +228,7 @@ def start_pyrit_attack( max_turns = parameters.get('max_turns', 3) display_intermediate_results = parameters.get( 'display_intermediate_results', True) - vulnerability_type = parameters.get('vulnerability_type', '') + vulnerability_type = parameters.get('vulnerability_type', 'jailbreak') with Step('Preparing Attack'): attack_strategy = AttackStrategy( @@ -251,15 +252,20 @@ def start_pyrit_attack( initial_red_teaming_prompt='how can I help you?', verbose=True, ) as red_teaming_orchestrator: - with Step('Running Attack'): - result = red_teaming_orchestrator.apply_attack_strategy_until_completion( # noqa E501 + attack_result = red_teaming_orchestrator.apply_attack_strategy_until_completion( # noqa E501 max_turns=max_turns, display_intermediate_results=display_intermediate_results) - return AttackResult( + result = AttackResult( 'PyRIT', - success=result['success'], - details={'response': result['response'], - 'attack_description': DESCRIPTION}, - vulnerability_type=vulnerability_type + attack_result['success'], + vulnerability_type, + { + 'target_model': target_model.model_name, + 'response': attack_result['response'], + 'number_successful_attacks': 1 if attack_result['success'] else 0, # noqa: E501 + 'attack_description': DESCRIPTION, + } ) + save_to_db(result) + return result diff --git a/backend-agent/main.py b/backend-agent/main.py index 7f6e80d..40a205a 100644 --- a/backend-agent/main.py +++ b/backend-agent/main.py @@ -1,16 +1,19 @@ import json import os + from dotenv import load_dotenv from flask import Flask, abort, jsonify, request, send_file from flask_cors import CORS from flask_sock import Sock +from sqlalchemy import select -if not os.getenv('DISABLE_AGENT'): - from agent import agent -from status import status, LangchainStatusCallbackHandler +from app.db.models import TargetModel, ModelAttackScore, Attack, db from attack_result import SuiteResult +from status import LangchainStatusCallbackHandler, status +if not os.getenv('DISABLE_AGENT'): + from agent import agent ############################################################################# # Flask web server # ############################################################################# @@ -21,6 +24,16 @@ load_dotenv() +db_path = os.getenv("DB_PATH") + +if not db_path: + raise EnvironmentError( + "Missing DB_PATH environment variable. Please set DB_PATH in your \ + .env file to a valid SQLite file path." + ) + +app.config['SQLALCHEMY_DATABASE_URI'] = f"sqlite:///{db_path}" + # Langfuse can be used to analyze tracings and help in debugging. langfuse_handler = None if os.getenv('ENABLE_LANGFUSE'): @@ -40,6 +53,10 @@ } if langfuse_handler else { 'callbacks': [status_callback_handler]} +with app.app_context(): + db.init_app(app) + db.create_all() # create every SQLAlchemy tables defined in models.py + def send_intro(sock): """ @@ -129,6 +146,77 @@ def check_health(): return jsonify({'status': 'ok'}) +# Endpoint to fetch heatmap data from db +@app.route('/api/heatmap', methods=['GET']) +def get_heatmap(): + """ + Endpoint to retrieve heatmap data showing model score + against various attacks. + + Queries the database for total attacks and successes per target model and + attack combination. + Calculates attack success rate and returns structured data for + visualization. + + Returns: + JSON response with: + - models: List of target models and their attack success rate + per attack. + - attacks: List of attack names and their associated weights. + + HTTP Status Codes: + 200: Data successfully retrieved. + 500: Internal server error during query execution. + """ + try: + query = ( + select( + ModelAttackScore.total_number_of_attack, + ModelAttackScore.total_success, + TargetModel.name.label("attack_model_name"), + Attack.name.label("attack_name"), + Attack.weight.label("attack_weight") + ) + .join(TargetModel, ModelAttackScore.attack_model_id == TargetModel.id) # noqa: E501 + .join(Attack, ModelAttackScore.attack_id == Attack.id) + ) + + scores = db.session.execute(query).all() + all_models = {} + all_attacks = {} + + for score in scores: + model_name = score.attack_model_name + attack_name = score.attack_name + + if attack_name not in all_attacks: + all_attacks[attack_name] = score.attack_weight + + if model_name not in all_models: + all_models[model_name] = { + 'name': model_name, + 'scores': {}, + } + + # Compute attack success rate for this model/attack + success_ratio = ( + round((score.total_success / score.total_number_of_attack) * 100) # noqa: E501 + if score.total_number_of_attack else 0 + ) + + all_models[model_name]['scores'][attack_name] = success_ratio + + return jsonify({ + 'models': list(all_models.values()), + 'attacks': [ + {'name': name, 'weight': weight} + for name, weight in sorted(all_attacks.items()) + ] + }) + except Exception as e: + return jsonify({'error': str(e)}), 500 + + if __name__ == '__main__': if not os.getenv('API_KEY'): print('No API key is set! Access is unrestricted.') diff --git a/backend-agent/requirements.txt b/backend-agent/requirements.txt index b73c3d3..b5e7df8 100644 --- a/backend-agent/requirements.txt +++ b/backend-agent/requirements.txt @@ -23,3 +23,4 @@ pyrit==0.2.1 textattack>=0.3.10 codeattack @ git+https://github.com/marcorosa/CodeAttack gptfuzzer @ git+https://github.com/marcorosa/GPTFuzz@no-vllm +Flask-SQLAlchemy==3.1.1 \ No newline at end of file diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 11f1335..625a78a 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -18,7 +18,10 @@ "@angular/platform-browser": "^19.2.0", "@angular/platform-browser-dynamic": "^19.2.0", "@angular/router": "^19.2.0", + "apexcharts": "^4.7.0", "ngx-markdown": "^19.1.0", + "node-sass": "^9.0.0", + "react-apexcharts": "^1.7.0", "rxjs": "^7.8.2", "sass": "^1.89.0", "schematics-scss-migrate": "^2.3.17", @@ -6645,6 +6648,62 @@ "dev": true, "license": "MIT" }, + "node_modules/@svgdotjs/svg.draggable.js": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.draggable.js/-/svg.draggable.js-3.0.6.tgz", + "integrity": "sha512-7iJFm9lL3C40HQcqzEfezK2l+dW2CpoVY3b77KQGqc8GXWa6LhhmX5Ckv7alQfUXBuZbjpICZ+Dvq1czlGx7gA==", + "license": "MIT", + "peerDependencies": { + "@svgdotjs/svg.js": "^3.2.4" + } + }, + "node_modules/@svgdotjs/svg.filter.js": { + "version": "3.0.9", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.filter.js/-/svg.filter.js-3.0.9.tgz", + "integrity": "sha512-/69XMRCDoam2HgC4ldHIaDgeQf1ViHIsa0Ld4uWgiXtZ+E24DWHe/9Ib6kbNiZ7WRIdlVokUDR1Fg0kjIpkfbw==", + "license": "MIT", + "dependencies": { + "@svgdotjs/svg.js": "^3.2.4" + }, + "engines": { + "node": ">= 0.8.0" + } + }, + "node_modules/@svgdotjs/svg.js": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.js/-/svg.js-3.2.4.tgz", + "integrity": "sha512-BjJ/7vWNowlX3Z8O4ywT58DqbNRyYlkk6Yz/D13aB7hGmfQTvGX4Tkgtm/ApYlu9M7lCQi15xUEidqMUmdMYwg==", + "license": "MIT", + "funding": { + "type": "github", + "url": "https://github.com/sponsors/Fuzzyma" + } + }, + "node_modules/@svgdotjs/svg.resize.js": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.resize.js/-/svg.resize.js-2.0.5.tgz", + "integrity": "sha512-4heRW4B1QrJeENfi7326lUPYBCevj78FJs8kfeDxn5st0IYPIRXoTtOSYvTzFWgaWWXd3YCDE6ao4fmv91RthA==", + "license": "MIT", + "engines": { + "node": ">= 14.18" + }, + "peerDependencies": { + "@svgdotjs/svg.js": "^3.2.4", + "@svgdotjs/svg.select.js": "^4.0.1" + } + }, + "node_modules/@svgdotjs/svg.select.js": { + "version": "4.0.3", + "resolved": "https://registry.npmjs.org/@svgdotjs/svg.select.js/-/svg.select.js-4.0.3.tgz", + "integrity": "sha512-qkMgso1sd2hXKd1FZ1weO7ANq12sNmQJeGDjs46QwDVsxSRcHmvWKL2NDF7Yimpwf3sl5esOLkPqtV2bQ3v/Jg==", + "license": "MIT", + "engines": { + "node": ">= 14.18" + }, + "peerDependencies": { + "@svgdotjs/svg.js": "^3.2.4" + } + }, "node_modules/@tootallnate/once": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/@tootallnate/once/-/once-2.0.0.tgz", @@ -7637,6 +7696,12 @@ "dev": true, "license": "BSD-2-Clause" }, + "node_modules/@yr/monotone-cubic-spline": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/@yr/monotone-cubic-spline/-/monotone-cubic-spline-1.0.3.tgz", + "integrity": "sha512-FQXkOta0XBSUPHndIKON2Y9JeQz5ZeMqLYZVVK93FliNBFm7LNMIZmY6FrMEB9XPcDbE2bekMbZD6kzDkxwYjA==", + "license": "MIT" + }, "node_modules/abbrev": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/abbrev/-/abbrev-1.1.1.tgz", @@ -7905,6 +7970,20 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/apexcharts": { + "version": "4.7.0", + "resolved": "https://registry.npmjs.org/apexcharts/-/apexcharts-4.7.0.tgz", + "integrity": "sha512-iZSrrBGvVlL+nt2B1NpqfDuBZ9jX61X9I2+XV0hlYXHtTwhwLTHDKGXjNXAgFBDLuvSYCB/rq2nPWVPRv2DrGA==", + "license": "MIT", + "dependencies": { + "@svgdotjs/svg.draggable.js": "^3.0.4", + "@svgdotjs/svg.filter.js": "^3.0.8", + "@svgdotjs/svg.js": "^3.2.4", + "@svgdotjs/svg.resize.js": "^2.0.2", + "@svgdotjs/svg.select.js": "^4.0.1", + "@yr/monotone-cubic-spline": "^1.0.3" + } + }, "node_modules/aproba": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/aproba/-/aproba-2.0.0.tgz", @@ -13677,6 +13756,18 @@ "node": ">=8.0" } }, + "node_modules/loose-envify": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", + "integrity": "sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==", + "license": "MIT", + "dependencies": { + "js-tokens": "^3.0.0 || ^4.0.0" + }, + "bin": { + "loose-envify": "cli.js" + } + }, "node_modules/lru-cache": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-5.1.1.tgz", @@ -15338,7 +15429,6 @@ "version": "4.1.1", "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", - "dev": true, "license": "MIT", "engines": { "node": ">=0.10.0" @@ -16524,6 +16614,17 @@ "node": ">=10" } }, + "node_modules/prop-types": { + "version": "15.8.1", + "resolved": "https://registry.npmjs.org/prop-types/-/prop-types-15.8.1.tgz", + "integrity": "sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==", + "license": "MIT", + "dependencies": { + "loose-envify": "^1.4.0", + "object-assign": "^4.1.1", + "react-is": "^16.13.1" + } + }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", @@ -16675,6 +16776,35 @@ "node": ">= 0.8" } }, + "node_modules/react": { + "version": "19.1.0", + "resolved": "https://registry.npmjs.org/react/-/react-19.1.0.tgz", + "integrity": "sha512-FS+XFBNvn3GTAWq26joslQgWNoFu08F4kl0J4CgdNKADkdSGXQyTCnKteIAJy96Br6YbpEU1LSzV5dYtjMkMDg==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=0.10.0" + } + }, + "node_modules/react-apexcharts": { + "version": "1.7.0", + "resolved": "https://registry.npmjs.org/react-apexcharts/-/react-apexcharts-1.7.0.tgz", + "integrity": "sha512-03oScKJyNLRf0Oe+ihJxFZliBQM9vW3UWwomVn4YVRTN1jsIR58dLWt0v1sb8RwJVHDMbeHiKQueM0KGpn7nOA==", + "license": "MIT", + "dependencies": { + "prop-types": "^15.8.1" + }, + "peerDependencies": { + "apexcharts": ">=4.0.0", + "react": ">=0.13" + } + }, + "node_modules/react-is": { + "version": "16.13.1", + "resolved": "https://registry.npmjs.org/react-is/-/react-is-16.13.1.tgz", + "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", + "license": "MIT" + }, "node_modules/read-pkg": { "version": "5.2.0", "resolved": "https://registry.npmjs.org/read-pkg/-/read-pkg-5.2.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index 7763103..3df512a 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -21,7 +21,10 @@ "@angular/platform-browser": "^19.2.0", "@angular/platform-browser-dynamic": "^19.2.0", "@angular/router": "^19.2.0", + "apexcharts": "^4.7.0", "ngx-markdown": "^19.1.0", + "node-sass": "^9.0.0", + "react-apexcharts": "^1.7.0", "rxjs": "^7.8.2", "sass": "^1.89.0", "schematics-scss-migrate": "^2.3.17", diff --git a/frontend/src/app/app-routing.module.ts b/frontend/src/app/app-routing.module.ts index 0297262..9adec64 100755 --- a/frontend/src/app/app-routing.module.ts +++ b/frontend/src/app/app-routing.module.ts @@ -1,10 +1,16 @@ -import { NgModule } from '@angular/core'; -import { RouterModule, Routes } from '@angular/router'; +import {RouterModule, Routes} from '@angular/router'; -const routes: Routes = []; +import {ChatzoneComponent} from './chatzone/chatzone.component'; +import {HeatmapComponent} from './heatmap/heatmap.component'; +import {NgModule} from '@angular/core'; + +const routes: Routes = [ + {path: '', component: ChatzoneComponent}, + {path: 'heatmap', component: HeatmapComponent}, +]; @NgModule({ imports: [RouterModule.forRoot(routes)], - exports: [RouterModule] + exports: [RouterModule], }) -export class AppRoutingModule { } +export class AppRoutingModule {} diff --git a/frontend/src/app/app.component.html b/frontend/src/app/app.component.html index d913607..2de8798 100755 --- a/frontend/src/app/app.component.html +++ b/frontend/src/app/app.component.html @@ -1 +1,2 @@ - \ No newline at end of file + + \ No newline at end of file diff --git a/frontend/src/app/chatzone/chatzone.component.css b/frontend/src/app/chatzone/chatzone.component.css index b75e95f..8512c4a 100644 --- a/frontend/src/app/chatzone/chatzone.component.css +++ b/frontend/src/app/chatzone/chatzone.component.css @@ -19,6 +19,10 @@ } .status-report-container { + display: flex; + flex-direction: column; + /* height: 100vh; */ + /* padding: 1rem; */ max-width: 400px; width: 20%; } @@ -31,11 +35,17 @@ overflow-y: scroll; } +.buttons-wrapper { + margin-top: auto; /* pousse les boutons en bas */ + display: flex; + flex-direction: column; + gap: 1rem; /* espace entre les boutons */ +} + .title { justify-content: center; color: #3c226f; margin: auto; - font-family: AmericanTypewriter; background-color: unset; } @@ -167,6 +177,10 @@ mat-tab-group { color: gray; } +.left-panel-button { + width: 100%; +} + /** Generic classes **/ diff --git a/frontend/src/app/chatzone/chatzone.component.html b/frontend/src/app/chatzone/chatzone.component.html index 8271a1f..647f11c 100644 --- a/frontend/src/app/chatzone/chatzone.component.html +++ b/frontend/src/app/chatzone/chatzone.component.html @@ -19,8 +19,14 @@ -
- +
+
+ +
+
+ + +
diff --git a/frontend/src/app/chatzone/chatzone.component.ts b/frontend/src/app/chatzone/chatzone.component.ts index 20abee7..bf24e35 100644 --- a/frontend/src/app/chatzone/chatzone.component.ts +++ b/frontend/src/app/chatzone/chatzone.component.ts @@ -1,15 +1,16 @@ -import { Component, ViewChildren, QueryList, ElementRef, AfterViewInit, AfterViewChecked } from '@angular/core'; -import { ChatItem, Message, ReportCard, VulnerabilityReportCard } from '../types/ChatItem'; -import { WebSocketService } from '../services/web-socket.service'; -import { Step, Status } from '../types/Step'; -import { APIResponse, ReportItem } from '../types/API'; -import { VulnerabilityInfoService } from '../services/vulnerability-information.service'; +import {APIResponse, ReportItem} from '../types/API'; +import {AfterViewChecked, AfterViewInit, Component, ElementRef, QueryList, ViewChildren} from '@angular/core'; +import {ChatItem, Message, ReportCard, VulnerabilityReportCard} from '../types/ChatItem'; +import {Status, Step} from '../types/Step'; + +import {VulnerabilityInfoService} from '../services/vulnerability-information.service'; +import {WebSocketService} from '../services/web-socket.service'; @Component({ selector: 'app-chatzone', templateUrl: './chatzone.component.html', styleUrls: ['./chatzone.component.css'], - standalone: false + standalone: false, }) export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { chatItems: ChatItem[]; @@ -21,26 +22,27 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { progress: number | undefined; constructor(private ws: WebSocketService, private vis: VulnerabilityInfoService) { this.inputValue = ''; - this.apiKey = localStorage.getItem("key") || ""; + this.apiKey = localStorage.getItem('key') || ''; this.errorMessage = ''; this.chatItems = []; this.steps = []; this.progress = undefined; - this.ws.webSocket$ - .subscribe({ - next: (value: any) => { // eslint-disable-line @typescript-eslint/no-explicit-any - this.handleWSMessage(value as APIResponse); - }, - error: (error: any) => { // eslint-disable-line @typescript-eslint/no-explicit-any - console.log(error); - if (error?.type != "close") { // Close is already handled via the isConnected call - this.errorMessage = error; - } - }, - complete: () => alert("Connection to server closed.") - } - ); + this.ws.webSocket$.subscribe({ + next: (value: any) => { + // eslint-disable-line @typescript-eslint/no-explicit-any + this.handleWSMessage(value as APIResponse); + }, + error: (error: any) => { + // eslint-disable-line @typescript-eslint/no-explicit-any + console.log(error); + if (error?.type != 'close') { + // Close is already handled via the isConnected call + this.errorMessage = error; + } + }, + complete: () => alert('Connection to server closed.'), + }); this.restoreChatItems(); } @@ -48,32 +50,32 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { // Handling of the websocket connection checkInput(value: string): void { - if (value && value.trim() != "") { + if (value && value.trim() != '') { this.inputValue = ''; this.ws.postMessage(value, this.apiKey); const userMessage: Message = { type: 'message', id: 'user-message', message: value, - avatar: "person", - timestamp: Date.now() + avatar: 'person', + timestamp: Date.now(), }; this.appendMessage(userMessage); } } handleWSMessage(input: APIResponse): void { - if (input.type == "message") { + if (input.type == 'message') { const aiMessageString = input.data; const aiMessage: Message = { type: 'message', id: 'ai-message', message: aiMessageString, - avatar: "computer", - timestamp: Date.now() + avatar: 'computer', + timestamp: Date.now(), }; this.appendMessage(aiMessage); - } else if (input.type == "status") { + } else if (input.type == 'status') { const current = input.current; const total = input.total; const progress = current / total; @@ -81,7 +83,7 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { if (progress >= 1) { this.progress = undefined; } - } else if (input.type == "report") { + } else if (input.type == 'report') { if (input.reset) { this.steps = []; return; @@ -96,28 +98,28 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { this.steps.push(step); } } - } else if (input.type == "intermediate") { + } else if (input.type == 'intermediate') { const text = '### Intermediate result from attack\n' + input.data; const intermediateMessage: Message = { type: 'message', id: 'assistant-intermediate-message', message: text, - avatar: "computer", - timestamp: Date.now() + avatar: 'computer', + timestamp: Date.now(), }; this.appendMessage(intermediateMessage); - } else if (input.type == "vulnerability-report") { + } else if (input.type == 'vulnerability-report') { const vulnerabilityCards = input.data.map(vri => { - const vrc = (vri as VulnerabilityReportCard); + const vrc = vri as VulnerabilityReportCard; vrc.description = this.vis.getInfo(vri.vulnerability); return vrc; }); this.chatItems.push({ type: 'report-card', - 'reports': vulnerabilityCards, - 'name': input.name + reports: vulnerabilityCards, + name: input.name, }); - localStorage.setItem("cached-chat-items", JSON.stringify(this.chatItems)); + localStorage.setItem('cached-chat-items', JSON.stringify(this.chatItems)); } } @@ -127,21 +129,26 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { getIconForStepStatus(step: Step): string { switch (step.status) { - case Status.COMPLETED: return 'check_circle'; - case Status.FAILED: return 'error'; - case Status.SKIPPED: return 'skip_next'; - case Status.RUNNING: return 'play_circle'; - case Status.PENDING: return 'pending'; + case Status.COMPLETED: + return 'check_circle'; + case Status.FAILED: + return 'error'; + case Status.SKIPPED: + return 'skip_next'; + case Status.RUNNING: + return 'play_circle'; + case Status.PENDING: + return 'pending'; } } appendMessage(message: Message) { this.chatItems.push(message); - localStorage.setItem("cached-chat-items", JSON.stringify(this.chatItems)); + localStorage.setItem('cached-chat-items', JSON.stringify(this.chatItems)); } restoreChatItems() { - const storedMessages = localStorage.getItem("cached-chat-items"); + const storedMessages = localStorage.getItem('cached-chat-items'); if (storedMessages) { this.chatItems = JSON.parse(storedMessages); } @@ -149,22 +156,30 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { clearChatHistory() { this.chatItems = []; - localStorage.setItem("cached-chat-items", "[]"); + localStorage.setItem('cached-chat-items', '[]'); } static deserializeStep(obj: ReportItem): Step { let status = Status.RUNNING; switch (obj.status) { - case "COMPLETED": status = Status.COMPLETED; break; - case "FAILED": status = Status.FAILED; break; - case "SKIPPED": status = Status.SKIPPED; break; - case "PENDING": status = Status.PENDING; break; + case 'COMPLETED': + status = Status.COMPLETED; + break; + case 'FAILED': + status = Status.FAILED; + break; + case 'SKIPPED': + status = Status.SKIPPED; + break; + case 'PENDING': + status = Status.PENDING; + break; } return { title: obj.title, description: obj.description, status: status, - progress: obj.progress + progress: obj.progress, }; } @@ -193,11 +208,11 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { const downloadUrl = window.URL.createObjectURL(file); const a = document.createElement('a'); a.href = downloadUrl; - a.download = `${reportName}.${reportFormat}`; // Set the filename + a.download = `${reportName}.${reportFormat}`; // Set the filename document.body.appendChild(a); - a.click(); // Programmatically click the link to trigger the download - document.body.removeChild(a); // Remove the link element - window.URL.revokeObjectURL(downloadUrl); // Clean up the URL object + a.click(); // Programmatically click the link to trigger the download + document.body.removeChild(a); // Remove the link element + window.URL.revokeObjectURL(downloadUrl); // Clean up the URL object } // Scrolling to have new messages visible @@ -258,11 +273,11 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { downloadChatHistory(): void { const markdownContent = this.exportChat(); - const blob = new Blob([markdownContent], { type: 'text/markdown' }); + const blob = new Blob([markdownContent], {type: 'text/markdown'}); const link = document.createElement('a'); link.href = URL.createObjectURL(blob); - link.download = "STARS_chat_" + new Date().toISOString() + ".md"; + link.download = 'STARS_chat_' + new Date().toISOString() + '.md'; // Append the link to the body document.body.appendChild(link); @@ -280,4 +295,9 @@ export class ChatzoneComponent implements AfterViewInit, AfterViewChecked { this.apiKey = prompt('Set API Key', this.apiKey) || this.apiKey; localStorage.setItem('key', this.apiKey); } + + // openDashboard() that loads a new page with the dashboard at the route /heatmap + openDashboard(): void { + window.open('/heatmap', '_blank'); + } } diff --git a/frontend/src/app/heatmap/heatmap.component.css b/frontend/src/app/heatmap/heatmap.component.css new file mode 100644 index 0000000..8a9c30e --- /dev/null +++ b/frontend/src/app/heatmap/heatmap.component.css @@ -0,0 +1,80 @@ +h1 { + width: 500px; + margin: auto; + text-align: center; +} + +#heatmapChart { + /* max-width: 600px; + max-height: 400px; */ + display: block; + margin: 10px auto; +} + +* { + font-family: Helvetica, Arial, sans-serif; +} + +#vendorChoice { + padding: 20px 0; + width: 500px; + margin: auto; + text-align: center; +} + +.title-card { + width: 700px; + margin: auto; + height: 100px; + padding: 25px; +} + .heatmap-card { + width: 700px; + margin: auto; + height: 1000px; + padding: 25px; +} + +/* .card-header { + font-size: 2rem; + font-weight: bold; +} */ + +#overview { + font-size: medium; +} + +/* .header-icon { + width: 50px; +} */ + +.card-header { + display: flex; + align-items: center; + gap: 10px; + font-size: 2rem; + font-weight: bold; + justify-content: center; +} + +.header-icon { + width: 45px; +} + +.title { + display: flex; + align-items: center; +} + +.buttons-wrapper { + /* margin-top: auto; pousse les boutons en bas */ + display: flex; + flex-direction: column; + gap: 1rem; /* espace entre les boutons */ +} + +.centered-button { + display: flex; + justify-content: center; + align-items: center; +} diff --git a/frontend/src/app/heatmap/heatmap.component.html b/frontend/src/app/heatmap/heatmap.component.html new file mode 100644 index 0000000..eda4d0f --- /dev/null +++ b/frontend/src/app/heatmap/heatmap.component.html @@ -0,0 +1,36 @@ + + + + +
+ + STARS Results Heatmap +
+ + +
+
+ + + +
+
+ +
+ + Select a Vendor + + + All vendors + {{ vendor }} + + + overview +
+ + +
+ + diff --git a/frontend/src/app/heatmap/heatmap.component.ts b/frontend/src/app/heatmap/heatmap.component.ts new file mode 100644 index 0000000..ee36f7d --- /dev/null +++ b/frontend/src/app/heatmap/heatmap.component.ts @@ -0,0 +1,291 @@ +import {AfterViewInit, ChangeDetectorRef, Component, ElementRef, OnInit} from '@angular/core'; +import {Observable, map} from 'rxjs'; +import {capitalizeFirstLetter, generateModelName, splitModelName} from '../utils/utils'; + +import ApexCharts from 'apexcharts'; +import {CommonModule} from '@angular/common'; +import {FormsModule} from '@angular/forms'; +import {HttpClient} from '@angular/common/http'; +import {MatButtonModule} from '@angular/material/button'; +import {MatCardModule} from '@angular/material/card'; +import {MatFormFieldModule} from '@angular/material/form-field'; +import {MatSelectModule} from '@angular/material/select'; +import {environment} from '../../environments/environment'; + +@Component({ + selector: 'app-heatmap', + templateUrl: './heatmap.component.html', + styleUrls: ['./heatmap.component.css'], + standalone: true, + imports: [CommonModule, MatFormFieldModule, MatSelectModule, FormsModule, MatCardModule, MatButtonModule], +}) +export class HeatmapComponent implements AfterViewInit, OnInit { + public heatmapData: number[][] = []; + // for UI dropdown menu of vendors + public vendorsNames: string[] = []; + public selectedVendor: string = ''; + public weightedAttacks: {attackName: string; weight: string}[] = []; + + constructor(private http: HttpClient, private el: ElementRef, private changeDetector: ChangeDetectorRef) {} + + ngAfterViewInit() { + this.createHeatmap([]); // Initialisation avec des données vides + } + + ngOnInit() { + // this.loadHeatmapData('amazon'); + this.loadVendorsData(); + this.loadHeatmapData(''); + } + + onFileSelected(event: any) { + // const file = event.target.files[0]; + // if (!file) return; + // const formData = new FormData(); + // formData.append('file', file); + // this.http.post('http://localhost:3000/upload', formData).subscribe({ + // next: data => { + // console.log('📊 Données reçues via upload:', data); + // this.processData(data); + // }, + // error: error => console.error('❌ Erreur upload:', error), + // }); + } + + //load a dropdown menu from the loadModelsData result + loadVendorsData() { + // this.http.get(`http://127.0.0.1:8080/api/vendors`).subscribe({ + this.http.get(`${environment.api_url}/api/vendors`).subscribe({ + next: data => { + console.log('📡 Données brutes reçues du serveur:', data); + this.processVendors(data.map(vendor => vendor)); + }, + error: error => console.error('❌ Erreur API:', error), + }); + } + + //load the heatmap data from the server with a name in params + loadHeatmapData(vendor: string) { + let url = ''; + if (!vendor) { + url = `${environment.api_url}/api/heatmap`; + } else { + url = `${environment.api_url}/api/heatmap/${vendor}`; + } + this.http.get(url).subscribe({ + // this.http.get(`${environment.api_url}/api/${vendor}`).subscribe({ + next: scoresData => { + this.processData(scoresData, vendor); + }, + error: error => console.error('❌ Erreur API:', error), + }); + } + + // handle models name recieved from the server to a list used in frontend for a dropdown menu + processVendors(vendorsNames: string[]) { + this.vendorsNames = vendorsNames.map(capitalizeFirstLetter); + } + + processData(data: any[], vendor: string = '') { + const modelNames = generateModelName(data, vendor); + this.getWeightedAttacks().subscribe({ + next: weightedAttacks => { + this.heatmapData = data.map(row => { + const rowData = weightedAttacks.map(attack => { + const value = Number(row[attack.attackName]?.trim()); + return isNaN(value) ? 0 : value * 10; + }); + let totalWeights = 0; + // Add an extra column at the end with a custom calculation (modify as needed) + const weightedSumColumn = weightedAttacks.reduce((sum, {attackName, weight}) => { + const value = Number(row[attackName]?.trim()); + const weightedValue = isNaN(value) ? 0 : value * Number(weight); + totalWeights = totalWeights + Number(weight); + return sum + weightedValue; + }, 0); + // Append the calculated weighted sum column to the row as the last column "as an attack" even if it's a custom calculated value + return [...rowData, (weightedSumColumn / totalWeights) * 10]; + }); + const attackNames = weightedAttacks.map(attack => attack.attackName); + this.createHeatmap(this.heatmapData, modelNames, [...attackNames.map(capitalizeFirstLetter), 'Exposure score'], vendor !== ''); + }, + error: error => console.error('❌ Erreur API:', error), + }); + } + + createHeatmap(data: number[][], modelNames: Record = {}, attackNames: string[] = [], oneVendorDisplayed: boolean = false) { + const cellSize = 100; + const chartWidth = attackNames.length * cellSize + 150; // +100 to allow some space for translated labels + const chartHeight = data.length <= 3 ? data.length * cellSize + 100 : data.length * cellSize; + // const series = Object.entries(modelNames).flatMap(([vendor, models]) => + // models.map((model, modelIndex) => ({ + // name: splitModelName(vendor, model), + // data: data[modelIndex].map((value, colIndex) => ({ + // x: attackNames[colIndex], + // y: value, + // })), + // })) + // ); + + // // group by vendors + // let globalIndex = 0; + // const series = Object.entries(modelNames).flatMap(([vendor, models]) => + // models.map(model => { + // const seriesData = { + // name: splitModelName(vendor, model), + // data: data[globalIndex].map((value, colIndex) => ({ + // x: attackNames[colIndex], + // y: value, + // })), + // }; + // globalIndex++; // Increment global index for next model + // return seriesData; + // }) + // ); + + // does not group by vendor + // Flatten all models but keep vendor info + const allModels = Object.entries(modelNames).flatMap(([vendor, models]) => models.map(model => ({vendor, model}))); + + let globalIndex = 0; + + const series = allModels.map(({vendor, model}) => { + const seriesData = { + name: splitModelName(vendor, model), // Display vendor and model together + data: data[globalIndex].map((value, colIndex) => ({ + x: attackNames[colIndex], + y: value, + })), + }; + globalIndex++; // Move to next row in data + return seriesData; + }); + + const options = { + chart: { + type: 'heatmap', + height: chartHeight, + width: chartWidth, + toolbar: {show: false}, + events: { + legendClick: function () { + console.log('CLICKED'); + }, + }, + }, + series: series, + plotOptions: { + heatmap: { + shadeIntensity: 0.5, + // useFillColorAsStroke: true, // Améliore le rendu des cases + colorScale: { + ranges: [ + // {from: 0, to: 20, color: '#5aa812'}, // Light green for 0-20 + // {from: 21, to: 40, color: '#00A100'}, // Darker green for 21-40 + // {from: 41, to: 60, color: '#FFB200'}, // Light orange for 41-60 + // {from: 61, to: 80, color: '#FF7300'}, // Darker orange for 61-80 + // {from: 81, to: 100, color: '#FF0000'}, // Red for 81-100 + + {from: 0, to: 40, color: '#00A100'}, + // {from: 21, to: 40, color: '#128FD9'}, + {from: 41, to: 80, color: '#FF7300'}, + // {from: 61, to: 80, color: '#FFB200'}, + {from: 81, to: 100, color: '#FF0000'}, + ], + }, + }, + }, + grid: { + padding: {top: 0, right: 0, bottom: 0, left: 0}, + }, + dataLabels: { + style: {fontSize: '14px'}, + }, + legend: { + show: true, + // markers: { + // customHTML: function () { + // return ''; + // }, + // }, + // markers: { + // width: 12, + // height: 12, + // // Remove customHTML if you want the default + // }, + }, + xaxis: { + categories: attackNames, + title: {text: 'Attacks'}, + labels: {rotate: -45, style: {fontSize: '12px'}}, + position: 'top', + }, + yaxis: { + categories: modelNames, + title: { + text: 'Models', + offsetX: oneVendorDisplayed ? -90 : -60, + }, + labels: { + style: { + fontSize: '12px', + }, + offsetY: -10, + }, + reversed: true, + }, + tooltip: { + y: { + formatter: undefined, + title: { + formatter: (seriesName: string) => seriesName.replace(',', '-'), + }, + }, + }, + }; + const chartElement = this.el.nativeElement.querySelector('#heatmapChart'); + if (chartElement) { + chartElement.innerHTML = ''; + const chart = new ApexCharts(chartElement, options); + chart.render(); + } + } + + public onVendorChange(event: any) { + this.loadHeatmapData(this.selectedVendor); + } + + // getattacksNames() return an array of attacks names from the server from http://localhost:3000/api/attacks + getAttacksNames(): Observable { + return this.http.get(`${environment.api_url}/api/attacks`).pipe( + // return this.http.get(`http://127.0.0.1:8080/api/attacks`).pipe( + map(data => data.map(row => row['attackName'])) // Extract only attack names + ); + } + + getWeightedAttacks(): Observable<{attackName: string; weight: string}[]> { + return this.http.get(`${environment.api_url}/api/attacks`); + // return this.http.get(`http://127.0.0.1:8080/api/attacks`); + } + + getVendors(): Observable { + this.changeDetector.detectChanges(); + return this.http.get(`${environment.api_url}/api/vendors`); + // return this.http.get(`http://127.0.0.1:8080/api/vendors`); + } + + uploadCSV(event: any) { + const file = event.target.files[0]; + const formData = new FormData(); + formData.append('file', file); + + this.http.post(`${environment.api_url}/api/upload-csv`, formData).subscribe({ + next: res => { + console.log('Upload success', res); + }, + error: err => { + console.error('Upload failed', err); + }, + }); + } +} diff --git a/frontend/src/app/utils/utils.ts b/frontend/src/app/utils/utils.ts new file mode 100644 index 0000000..8c9b220 --- /dev/null +++ b/frontend/src/app/utils/utils.ts @@ -0,0 +1,40 @@ +// export function generateModelName(vendor: string, modelType: string, version: string, specialization: string, other: string, withVendor = true): string { +export function generateModelName(data: any[], vendor: string): any { + const result: Record = {}; + + data.forEach(row => { + const vendorName = vendor === '' ? row['vendor'] : vendor; // Si vendor est vide, on prend row['vendor'], sinon on utilise vendor existant + + const model = [row['modelType'], row['version'], row['specialization'], row['other']] + .filter(value => value) + .join('-') + .replace(/\s+/g, ' ') + .trim(); + + if (!result[vendorName]) { + result[vendorName] = []; + } + + result[vendorName].push(model); + }); + return result; +} + +export function splitModelName(vendor: string, model: string): string[] { + if (model.length < 18) return [vendor, model]; // No need to split + + // Find the last "-" before the 20th character + const cutoffIndex = model.lastIndexOf('-', 20); + + if (cutoffIndex === -1) { + // If no "-" found before 20, force split at 20 + return [vendor, model.slice(0, 20), model.slice(20)]; + } + + // Split at the last "-" before 20 + return [vendor, model.slice(0, cutoffIndex), model.slice(cutoffIndex + 1)].map(capitalizeFirstLetter); +} + +export function capitalizeFirstLetter(str: string): string { + return str.charAt(0).toUpperCase() + str.slice(1); +} diff --git a/frontend/src/styles.css b/frontend/src/styles.css index b5462b6..7c155bb 100644 --- a/frontend/src/styles.css +++ b/frontend/src/styles.css @@ -13,14 +13,63 @@ body { margin: 0; font-family: Roboto, "Helvetica Neue", sans-serif; } white-space: pre-wrap !important; } -@font-face { - font-family: 'AmericanTypewriter'; - src: url('assets/fonts/AmericanTypewriter.ttc') format('truetype'); - font-weight: normal; - font-style: normal; -} /* For responses from agent, which can contain very long lines */ pre { overflow-x: auto; } + +body { + font-family: Roboto, "Helvetica Neue", sans-serif; + margin: 0; + padding: 30px; + height: 100%; +} + +.apexcharts-canvas { + margin: auto; + translate: -60px; +} + +.apexcharts-inner { + translate: 40px; +} + +.apexcharts-yaxis-texts-g { + translate: 20px; +} + +.apexcharts-yaxis-title { + translate: 100px; + > text { + font-size: large; + } +} + +.apexcharts-xaxis { + translate: 0 10px; +} + +.apexcharts-xaxis-title { + > text { + font-size: large; + } +} + +html { height: 100%; } + +.card-header { + font-size: 2.5rem; /* Larger font for the header */ + font-weight: bold; /* Make the text bold */ + text-align: center; /* Center the text */ + padding: 16px; +} + +mat-select { + direction: rtl !important; + text-align: right !important; +} + +.apexcharts-legend { + translate: 50px; +} diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index ed966d4..409a471 100755 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -2,6 +2,8 @@ { "compileOnSave": false, "compilerOptions": { + "allowSyntheticDefaultImports": true, + "esModuleInterop": true, "baseUrl": "./", "outDir": "./dist/out-tsc", "forceConsistentCasingInFileNames": true,