diff --git a/framework/python/src/api/api.py b/framework/python/src/api/api.py index 719710694..5084555ae 100644 --- a/framework/python/src/api/api.py +++ b/framework/python/src/api/api.py @@ -168,7 +168,19 @@ async def post_sys_config(self, request: Request, response: Response): try: config = (await request.body()).decode("UTF-8") config_json = json.loads(config) + + # Validate req fields + if ("network" not in config_json or + "device_intf" not in config_json.get("network") or + "internet_intf" not in config_json.get("network") or + "log_level" not in config_json): + response.status_code = status.HTTP_400_BAD_REQUEST + return self._generate_msg( + False, + "Configuration is missing required fields") + self._session.set_config(config_json) + # Catch JSON Decode error etc except JSONDecodeError: response.status_code = status.HTTP_400_BAD_REQUEST @@ -690,6 +702,11 @@ async def update_profile(self, request: Request, response: Response): response.status_code = status.HTTP_400_BAD_REQUEST return self._generate_msg(False, "Invalid request received") + # Validate json profile + if not self.get_session().validate_profile_json(req_json): + response.status_code = status.HTTP_400_BAD_REQUEST + return self._generate_msg(False, "Invalid request received") + profile_name = req_json.get("name") # Check if profile exists diff --git a/framework/python/src/common/session.py b/framework/python/src/common/session.py index 20820c506..940fbe8f0 100644 --- a/framework/python/src/common/session.py +++ b/framework/python/src/common/session.py @@ -75,7 +75,7 @@ def apply_session_tracker(cls): @apply_session_tracker class TestrunSession(): - """Represents the current session of Test Run.""" + """Represents the current session of Testrun.""" def __init__(self, root_dir): self._root_dir = root_dir @@ -450,6 +450,11 @@ def _load_profiles(self): # Parse risk profile json json_data = json.load(f) + # Validate profile JSON + if not self.validate_profile_json(json_data): + LOGGER.error('Profile failed validation') + continue + # Instantiate a new risk profile risk_profile = RiskProfile() @@ -478,25 +483,6 @@ def get_profile(self, name): return profile return None - def validate_profile(self, profile_json): - - # Check name field is present - if 'name' not in profile_json: - return False - - # Check questions field is present - if 'questions' not in profile_json: - return False - - # Check all questions are present - for format_q in self.get_profiles_format(): - if self._get_profile_question(profile_json, - format_q.get('question')) is None: - LOGGER.error('Missing question: ' + format_q.get('question')) - return False - - return True - def _get_profile_question(self, profile_json, question): for q in profile_json.get('questions'): @@ -505,7 +491,14 @@ def _get_profile_question(self, profile_json, question): return None + def get_profile_format_question(self, question): + for q in self.get_profiles_format(): + if q.get('question') == question: + return q + def update_profile(self, profile_json): + """Update the risk profile with the provided JSON. + The content has already been validated in the API""" profile_name = profile_json['name'] @@ -513,39 +506,8 @@ def update_profile(self, profile_json): profile_json['version'] = self.get_version() profile_json['created'] = datetime.datetime.now().strftime('%Y-%m-%d') - if 'status' in profile_json and profile_json.get('status') == 'Valid': - # Attempting to submit a risk profile, we need to check it - - # Check all questions have been answered - all_questions_answered = True - - for question in self.get_profiles_format(): - - # Check question is present - profile_question = self._get_profile_question(profile_json, - question.get('question')) - - if profile_question is not None: - - # Check answer is present - if 'answer' not in profile_question: - LOGGER.error('Missing answer for question: ' + - question.get('question')) - all_questions_answered = False - - else: - LOGGER.error('Missing question: ' + question.get('question')) - all_questions_answered = False - - if not all_questions_answered: - LOGGER.error('Not all questions answered') - return None - - else: - profile_json['status'] = 'Draft' - + # Check if profile already exists risk_profile = self.get_profile(profile_name) - if risk_profile is None: # Create a new risk profile @@ -574,6 +536,106 @@ def update_profile(self, profile_json): return risk_profile + def validate_profile_json(self, profile_json): + """Validate properties in profile update requests""" + + # Get the status field + valid = False + if 'status' in profile_json and profile_json.get('status') == 'Valid': + valid = True + + # Check if 'name' exists in profile + if 'name' not in profile_json: + LOGGER.error('Missing "name" in profile') + return False + + # Check if 'name' field not empty + elif len(profile_json.get('name').strip()) == 0: + LOGGER.error('Name field left empty') + return False + + # Error handling if 'questions' not in request + if 'questions' not in profile_json and valid: + LOGGER.error('Missing "questions" field in profile') + return False + + # Validating the questions section + for question in profile_json.get('questions'): + + # Check if the question field is present + if 'question' not in question: + LOGGER.error('The "question" field is missing') + return False + + # Check if 'question' field not empty + elif len(question.get('question').strip()) == 0: + LOGGER.error('A question is missing from "question" field') + return False + + # Check if question is a recognized question + format_q = self.get_profile_format_question( + question.get('question')) + + if format_q is None: + LOGGER.error(f'Unrecognized question: {question.get("question")}') + return False + + # Error handling if 'answer' is missing + if 'answer' not in question and valid: + LOGGER.error('The answer field is missing') + return False + + # If answer is present, check the validation rules + else: + + # Extract the answer out of the profile + answer = question.get('answer') + + # Get the validation rules + field_type = format_q.get('type') + + # Check if type is string or single select, answer should be a string + if ((field_type in ['string', 'select']) + and not isinstance(answer, str)): + LOGGER.error(f'''Answer for question \ +{question.get('question')} is incorrect data type''') + return False + + # Check if type is select, answer must be from list + if field_type == 'select' and valid: + possible_answers = format_q.get('options') + if answer not in possible_answers: + LOGGER.error(f'''Answer for question \ +{question.get('question')} is not valid''') + return False + + # Validate select multiple field types + if field_type == 'select-multiple': + + if not isinstance(answer, list): + LOGGER.error(f'''Answer for question \ +{question.get('question')} is incorrect data type''') + return False + + question_options_len = len(format_q.get('options')) + + # We know it is a list, now check the indexes + for index in answer: + + # Check if the index is an integer + if not isinstance(index, int): + LOGGER.error(f'''Answer for question \ +{question.get('question')} is incorrect data type''') + return False + + # Check if index is 0 or above and less than the num of options + if index < 0 or index >= question_options_len: + LOGGER.error(f'''Invalid index provided as answer for \ +question {question.get('question')}''') + return False + + return True + def delete_profile(self, profile): try: diff --git a/framework/requirements.txt b/framework/requirements.txt index 402009ef9..0484905ee 100644 --- a/framework/requirements.txt +++ b/framework/requirements.txt @@ -21,6 +21,8 @@ pydantic==2.7.1 # Requirements for testing pytest==7.4.4 pytest-timeout==2.2.0 +responses==0.25.3 + # Requirements for the report markdown==3.5.2 diff --git a/modules/test/ntp/python/src/ntp_module.py b/modules/test/ntp/python/src/ntp_module.py index 033e98974..be27abbad 100644 --- a/modules/test/ntp/python/src/ntp_module.py +++ b/modules/test/ntp/python/src/ntp_module.py @@ -14,8 +14,8 @@ """NTP test module""" from test_module import TestModule from scapy.all import rdpcap, IP, IPv6, NTP, UDP, Ether -from datetime import datetime import os +from collections import defaultdict LOG_NAME = 'test_ntp' MODULE_REPORT_FILE_NAME = 'ntp_report.html' @@ -69,6 +69,33 @@ def generate_module_report(self): total_responses = sum(1 for row in ntp_table_data if row['Type'] == 'Server') + # Initialize a dictionary to store timestamps for each unique combination + timestamps = defaultdict(list) + + # Collect timestamps for each unique combination + for row in ntp_table_data: + # Add the timestamp to the corresponding combination + key = (row['Source'], row['Destination'], row['Type'], row['Version']) + timestamps[key].append(row['Timestamp']) + + # Calculate the average time between requests for each unique combination + average_time_between_requests = {} + + for key, times in timestamps.items(): + # Sort the timestamps + times.sort() + + # Calculate the time differences between consecutive timestamps + time_diffs = [t2 - t1 for t1, t2 in zip(times[:-1], times[1:])] + + # Calculate the average of the time differences + if time_diffs: + avg_diff = sum(time_diffs) / len(time_diffs) + else: + avg_diff = 0 # one timestamp, the average difference is 0 + + average_time_between_requests[key] = avg_diff + # Add summary table html_content += (f'''
| Destination | Type | Version | -Timestamp | +Count | +Sync Request Average | ''' - for row in ntp_table_data: - - # Timestamp of the NTP packet - dt_object = datetime.utcfromtimestamp(row['Timestamp']) + # Generate the HTML table with the count column + for (src, dst, typ, + version), avg_diff in average_time_between_requests.items(): + cnt = len(timestamps[(src, dst, typ, version)]) - # Extract milliseconds from the fractional part of the timestamp - milliseconds = int((row['Timestamp'] % 1) * 1000) - - # Format the datetime object with milliseconds - formatted_time = dt_object.strftime( - '%b %d, %Y %H:%M:%S.') + f'{milliseconds:03d}' + # Sync Average only applies to client requests + if 'Client' in typ: + # Convert avg_diff to seconds and format it + avg_diff_seconds = avg_diff + avg_formatted_time = f'{avg_diff_seconds:.3f} seconds' + else: + avg_formatted_time = 'N/A' - table_content += (f''' + table_content += f'''
|---|---|---|---|---|
| {row['Source']} | -{row['Destination']} | -{row['Type']} | -{row['Version']} | -{formatted_time} | -{src} | +{dst} | +{typ} | +{version} | +{cnt} | +{avg_formatted_time} | + ''' table_content += '''