diff --git a/framework/python/src/api/api.py b/framework/python/src/api/api.py index 6931df36c..8a5de24ec 100644 --- a/framework/python/src/api/api.py +++ b/framework/python/src/api/api.py @@ -119,6 +119,12 @@ def __init__(self, test_run): self.get_profiles_format) self._router.add_api_route("/profiles", self.get_profiles) + self._router.add_api_route("/profiles", + self.update_profile, + methods=["POST"]) + self._router.add_api_route("/profiles", + self.delete_profile, + methods=["DELETE"]) # Allow all origins to access the API origins = ["*"] @@ -378,7 +384,12 @@ async def delete_report(self, request: Request, response: Response): response.status_code = 400 return self._generate_msg(False, "Invalid request received") - body_json = json.loads(body_raw) + try: + body_json = json.loads(body_raw) + except JSONDecodeError as e: + response.status_code = status.HTTP_400_BAD_REQUEST + return self._generate_msg(False, + "Invalid request received") if "mac_addr" not in body_json or "timestamp" not in body_json: response.status_code = 400 @@ -626,6 +637,91 @@ def get_profiles_format(self, response: Response): def get_profiles(self): return self.get_session().get_profiles() + async def update_profile(self, request: Request, response: Response): + + LOGGER.debug("Received profile update request") + + try: + req_raw = (await request.body()).decode("UTF-8") + req_json = json.loads(req_raw) + except JSONDecodeError as e: + response.status_code = status.HTTP_400_BAD_REQUEST + return self._generate_msg(False, + "Invalid request received") + + # Check that profile is valid + valid_profile = self.get_session().validate_profile(req_json) + if not valid_profile: + response.status_code = status.HTTP_400_BAD_REQUEST + return self._generate_msg(False, "Invalid profile request received") + + profile_name = req_json.get("name") + + # Check if profile exists + profile = self.get_session().get_profile(profile_name) + + if profile is None: + # Create new profile + profile = self.get_session().update_profile(req_json) + + if profile is not None: + response.status_code = status.HTTP_201_CREATED + return self._generate_msg(True, "Successfully created a new profile") + LOGGER.error("An error occurred whilst creating a new profile") + + else: + # Update existing profile + profile = self.get_session().update_profile(req_json) + + if profile is not None: + response.status_code = status.HTTP_200_OK + return self._generate_msg(True, "Successfully updated that profile") + LOGGER.error("An error occurred whilst updating a profile") + + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + return self._generate_msg(False, "An error occurred whilst creating or updating a profile") + + async def delete_profile(self, request: Request, response: Response): + + LOGGER.debug("Received profile delete request") + + try: + req_raw = (await request.body()).decode("UTF-8") + req_json = json.loads(req_raw) + except JSONDecodeError as e: + response.status_code = status.HTTP_400_BAD_REQUEST + return self._generate_msg(False, + "Invalid request received") + + # Check name included in request + if 'name' not in req_json: + response.status_code = status.HTTP_400_BAD_REQUEST + return self._generate_msg(False, + "Invalid request received") + + # Get profile name + profile_name = req_json.get("name") + + # Fetch profile + profile = self.get_session().get_profile(profile_name) + + # Check if profile exists + if profile is None: + response.status_code = status.HTTP_404_NOT_FOUND + return self._generate_msg(False, + "A profile with that name could not be found") + + # Attempt to delete the profile + success = self.get_session().delete_profile(profile) + + if not success: + response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + return self._generate_msg(False, + "An error occurred whilst deleting that profile") + + return self._generate_msg(True, + "Successfully deleted that profile") + # Certificates def get_certs(self): LOGGER.debug("Received certs list request") diff --git a/framework/python/src/common/risk_profile.py b/framework/python/src/common/risk_profile.py index bd82c076c..06ea41a43 100644 --- a/framework/python/src/common/risk_profile.py +++ b/framework/python/src/common/risk_profile.py @@ -15,30 +15,27 @@ from datetime import datetime -SECONDS_IN_YEAR = 31536000 - class RiskProfile(): def __init__(self, json_data): self.name = json_data['name'] - self.status = json_data['status'] - self.created = json_data['created'] - self.version = json_data['version'] - self.questions = json_data['questions'] - - # Check the profile has not expired - self.check_status() - def check_status(self): - if self.status == 'Valid': + if 'status' in json_data: + self.status = json_data['status'] + else: + self.status = 'Draft' - # Check expiry - created_date = datetime.strptime( - self.created, "%Y-%m-%d %H:%M:%S").timestamp() - - today = datetime.now().timestamp() - - if created_date < (today - SECONDS_IN_YEAR): - self.status = 'Expired' + self.created = datetime.strptime(json_data['created'], + '%Y-%m-%d') + self.version = json_data['version'] + self.questions = json_data['questions'] - return self.status + def to_json(self): + json = { + 'name': self.name, + 'version': self.version, + 'created': self.created.strftime('%Y-%m-%d'), + 'status': self.status, + 'questions': self.questions + } + return json \ No newline at end of file diff --git a/framework/python/src/common/session.py b/framework/python/src/common/session.py index ca5015c52..cd138bca0 100644 --- a/framework/python/src/common/session.py +++ b/framework/python/src/common/session.py @@ -37,6 +37,7 @@ MAX_DEVICE_REPORTS_KEY = 'max_device_reports' CERTS_PATH = 'local/root_certs' CONFIG_FILE_PATH = 'local/system.json' +SECONDS_IN_YEAR = 31536000 PROFILE_FORMAT_PATH = 'resources/risk_assessment.json' PROFILES_DIR = 'local/risk_profiles' @@ -372,7 +373,9 @@ def _load_profiles(self): ), encoding='utf-8') as f: json_data = json.load(f) risk_profile = RiskProfile(json_data) + risk_profile.status = self.check_profile_status(risk_profile) self._profiles.append(risk_profile) + except Exception as e: LOGGER.error('An error occurred whilst loading risk profiles') @@ -383,6 +386,137 @@ def get_profiles_format(self): def get_profiles(self): return self._profiles + + def get_profile(self, name): + for profile in self._profiles: + if profile.name.lower() == name.lower(): + return profile + return None + + def validate_profile(self, profile_json): + + # Check name field is present + if 'name' not in profile_json: + return False + + # Check questions field is present + if 'questions' not in profile_json: + return False + + # Check all questions are present + for format_q in self.get_profiles_format(): + if self._get_profile_question(profile_json, format_q.get('question')) is None: + LOGGER.error('Missing question: ' + format_q.get('question')) + return False + + return True + + def _get_profile_question(self, profile_json, question): + + for q in profile_json.get('questions'): + if question.lower() == q.get('question').lower(): + return q + + return None + + def update_profile(self, profile_json): + + profile_name = profile_json['name'] + + # Add version, timestamp and status + profile_json['version'] = self.get_version() + profile_json['created'] = datetime.datetime.now().strftime('%Y-%m-%d') + + if 'status' in profile_json and profile_json.get('status') == 'Valid': + # Attempting to submit a risk profile, we need to check it + + # Check all questions have been answered + all_questions_answered = True + + for question in self.get_profiles_format(): + + # Check question is present + profile_question = self._get_profile_question( + profile_json, question.get('question') + ) + + if profile_question is not None: + + # Check answer is present + if 'answer' not in profile_question: + LOGGER.error("Missing answer for question: " + question.get('question')) + all_questions_answered = False + + else: + LOGGER.error("Missing question: " + question.get('question')) + all_questions_answered = False + + if not all_questions_answered: + LOGGER.error('Not all questions answered') + return None + + else: + profile_json['status'] = 'Draft' + + risk_profile = self.get_profile(profile_name) + + if risk_profile is None: + + # Create a new risk profile + risk_profile = RiskProfile(profile_json) + self._profiles.append(risk_profile) + + else: + + # Check if name has changed + if 'rename' in profile_json: + new_name = profile_json.get('rename') + + # Delete the original file + os.remove(os.path.join(PROFILES_DIR, risk_profile.name + '.json')) + + risk_profile.name = new_name + + # Update questions and answers + risk_profile.questions = profile_json.get('questions') + + # Write file to disk + with open(os.path.join(PROFILES_DIR, risk_profile.name + '.json'), 'w') as f: + f.write(json.dumps(risk_profile.to_json())) + + return risk_profile + + def check_profile_status(self, profile): + + if profile.status == 'Valid': + + # Check expiry + created_date = profile.created.timestamp() + + today = datetime.datetime.now().timestamp() + + if created_date < (today - SECONDS_IN_YEAR): + profile.status = 'Expired' + + return profile.status + + def delete_profile(self, profile): + + try: + profile_name = profile.name + file_name = profile_name + '.json' + + profile_path = os.path.join(PROFILES_DIR, file_name) + + os.remove(profile_path) + self._profiles.remove(profile) + + return True + + except Exception as e: + LOGGER.error('An error occurred whilst deleting a profile') + LOGGER.debug(e) + return False def reset(self): self.set_status('Idle')