diff --git a/flamesdk/flame_core.py b/flamesdk/flame_core.py index 3df35b3..90807d5 100644 --- a/flamesdk/flame_core.py +++ b/flamesdk/flame_core.py @@ -10,75 +10,92 @@ from flamesdk.resources.client_apis.data_api import DataAPI from flamesdk.resources.client_apis.message_broker_api import MessageBrokerAPI, Message from flamesdk.resources.client_apis.storage_api import StorageAPI, LocalDifferentialPrivacyParams +from flamesdk.resources.client_apis.po_api import POAPI from flamesdk.resources.node_config import NodeConfig from flamesdk.resources.rest_api import FlameAPI from flamesdk.resources.utils.fhir import fhir_to_csv from flamesdk.resources.utils.utils import wait_until_nginx_online -from flamesdk.resources.utils.logging import flame_log, declare_log_types - +from flamesdk.resources.utils.logging import FlameLogger class FlameCoreSDK: def __init__(self, aggregator_requires_data: bool = False, silent: bool = False): - self.silent = silent - self.flame_log("Starting FlameCoreSDK", self.silent) + self._flame_logger = FlameLogger(silent=silent) + self.flame_log("Starting FlameCoreSDK") - self.flame_log("\tExtracting node config", self.silent) + self.flame_log("\tExtracting node config") # Extract node config self.config = NodeConfig() # Wait until nginx is online try: - wait_until_nginx_online(self.config.nginx_name, self.silent) + wait_until_nginx_online(self.config.nginx_name, self._flame_logger) except Exception as e: - self.flame_log(f"Nginx connection failure (error_msg='{e}')", False, 'error') + self.flame_log(f"Nginx connection failure (error_msg='{repr(e)}')", log_type='error') # Set up the connection to all the services needed ## Connect to message broker - self.flame_log("\tConnecting to MessageBroker...", self.silent, end='', suppress_tail=True) + self.flame_log("\tConnecting to MessageBroker...", end='', suppress_tail=True) try: - self._message_broker_api = MessageBrokerAPI(self.config) - self.flame_log("success", self.silent, suppress_head=True) + self._message_broker_api = MessageBrokerAPI(self.config, self._flame_logger) + self.flame_log("success", suppress_head=True) except Exception as e: self._message_broker_api = None - self.flame_log(f"failed (error_msg='{e}')", False, 'error', suppress_head=True) + self.flame_log(f"failed (error_msg='{repr(e)}')", log_type='error', suppress_head=True) try: ### Update config with self_config from Messagebroker self.config = self._message_broker_api.config except Exception as e: - self.flame_log(f"Unable to retrieve node config from message broker (error_msg='{e}')", - False, - 'error') + self.flame_log(f"Unable to retrieve node config from message broker (error_msg='{repr(e)}')", + log_type='error') + + ## Connect to po service + self.flame_log("\tConnecting to PO service...", end='', suppress_tail=True) + try: + self._po_api = POAPI(self.config, self._flame_logger) + self._flame_logger.add_po_api(self._po_api) + self.flame_log("success", suppress_head=True) + except Exception as e: + self._po_api = None + self.flame_log(f"failed (error_msg='{repr(e)}')", log_type='error', suppress_head=True) ## Connect to result service - self.flame_log("\tConnecting to ResultService...", self.silent, end='', suppress_tail=True) + self.flame_log("\tConnecting to ResultService...", end='', suppress_tail=True) try: - self._storage_api = StorageAPI(self.config) - self.flame_log("success", self.silent, suppress_head=True) + self._storage_api = StorageAPI(self.config, self._flame_logger) + self.flame_log("success", suppress_head=True) except Exception as e: 
self._storage_api = None - self.flame_log(f"failed (error_msg='{e}')", False, 'error', suppress_head=True) + self.flame_log(f"failed (error_msg='{repr(e)}')", log_type='error', suppress_head=True) if (self.config.node_role == 'default') or aggregator_requires_data: ## Connection to data service - self.flame_log("\tConnecting to DataApi...", self.silent, end='', suppress_tail=True) + self.flame_log("\tConnecting to DataApi...", end='', suppress_tail=True) try: - self._data_api = DataAPI(self.config) - self.flame_log("success", self.silent, suppress_head=True) + self._data_api = DataAPI(self.config, self._flame_logger) + self.flame_log("success", suppress_head=True) except Exception as e: self._data_api = None - self.flame_log(f"failed (error_msg='{e}')", False, 'error') + self.flame_log(f"failed (error_msg='{repr(e)}')", log_type='error', suppress_head=True) + else: + self._data_api = True # Start the flame api thread used for incoming messages and health checks - self.flame_log("\tStarting FlameApi thread...", self.silent, end='', suppress_tail=True) + self.flame_log("\tStarting FlameApi thread...", end='', suppress_tail=True) try: self._flame_api_thread = Thread(target=self._start_flame_api) self._flame_api_thread.start() - self.flame_log("success", self.silent, suppress_head=True) - except: - raise Exception("Analysis hard crashed when attempting to start api.") + self.flame_log("success", suppress_head=True) + except Exception as e: + self._flame_api_thread = None + self.flame_log(f"failed (error_msg='{repr(e)}')", log_type='error', suppress_head=True) + + if all([self._message_broker_api, self._po_api, self._storage_api, self._data_api, self._flame_api_thread]): + self._flame_logger.set_runstatus('running') + self.flame_log("FlameCoreSDK ready") + else: + self.flame_log("FlameCoreSDK startup failed", log_type='error') - self.flame_log("FlameCoreSDK ready", self.silent) ########################################General################################################## def get_aggregator_id(self) -> Optional[str]: @@ -153,8 +170,7 @@ def analysis_finished(self) -> bool: "analysis_finished", {}, max_attempts=5, - attempt_timeout=30, - silent=self.silent) + attempt_timeout=30) return self._node_finished() @@ -205,8 +221,7 @@ def ready_check(self, acknowledged_list, _ = self.send_message(receivers=nodes, message_category='ready_check', message={}, - timeout=attempt_interval, - silent=self.silent) + timeout=attempt_interval) for node in acknowledged_list: received[node] = True nodes.remove(node) @@ -218,50 +233,42 @@ def ready_check(self, def flame_log(self, msg: Union[str, bytes], - silent: Optional[bool] = None, sep: str = ' ', end: str = '\n', file: object = None, - flush: bool = False, log_type: str = 'normal', suppress_head: bool = False, suppress_tail: bool = False) -> None: """ Print logs to console. 
:param msg: - :param silent: :param sep: :param end: :param file: - :param flush: :param log_type: :param suppress_head: :param suppress_tail: :return: """ - if silent is None: - silent = self.silent - flame_log(msg=msg, - silent=silent, - sep=sep, - end=end, - file=file, - flush=flush, - log_type=log_type, - suppress_head=suppress_head, - suppress_tail=suppress_tail) + if log_type != 'error': + self._flame_logger.new_log(msg=msg, + sep=sep, + end=end, + file=file, + log_type=log_type, + suppress_head=suppress_head, + suppress_tail=suppress_tail) + else: + self._flame_logger.raise_error(msg) - def declare_log_types(self, new_log_types: dict[str, str], silent: Optional[bool] = None) -> None: + def declare_log_types(self, new_log_types: dict[str, str]) -> None: """ Declare new log_types to be added to log_type literals, and how/as what they should be interpreted by Flame (the latter have to be known values from HUB_LOG_LITERALS for existing log status fields). :param new_log_types: - :param silent: :return: """ - if silent is None: - silent = self.silent - declare_log_types(new_log_types, silent) + self._flame_logger.declare_log_types(new_log_types) def fhir_to_csv(self, fhir_data: dict[str, Any], @@ -273,8 +280,8 @@ def fhir_to_csv(self, col_id_filters: Optional[list[str]] = None, row_col_name: str = '', separator: str = ',', - output_type: Literal["file", "dict"] = "file", - silent: Optional[bool] = None) -> Union[StringIO, dict[Any, dict[Any, Any]]]: + output_type: Literal["file", "dict"] = "file" + ) -> Optional[Union[StringIO, dict[Any, dict[Any, Any]]]]: """ Convert a FHIR Bundle (or other FHIR-formatted dict) to CSV, pivoting on specified keys. @@ -282,7 +289,6 @@ def fhir_to_csv(self, applies optional filters, and produces either a CSV‐formatted file-like object or a nested dictionary representation - :param fhir_data: FHIR data to convert :param col_key_seq: :param value_key_seq: @@ -293,23 +299,25 @@ def fhir_to_csv(self, :param row_col_name: :param separator: :param output_type: - :param silent: :return: CSV formatted data as StringIO or dict """ - if silent is None: - silent = self.silent - return fhir_to_csv(fhir_data=fhir_data, - col_key_seq=col_key_seq, - value_key_seq=value_key_seq, - input_resource=input_resource, - row_key_seq=row_key_seq, - row_id_filters=row_id_filters, - col_id_filters=col_id_filters, - row_col_name=row_col_name, - separator=separator, - output_type=output_type, - data_client=self._data_api, - silent=silent) + if type(self._data_api) == DataAPI: + return fhir_to_csv(fhir_data=fhir_data, + col_key_seq=col_key_seq, + value_key_seq=value_key_seq, + input_resource=input_resource, + flame_logger=self._flame_logger, + row_key_seq=row_key_seq, + row_id_filters=row_id_filters, + col_id_filters=col_id_filters, + row_col_name=row_col_name, + separator=separator, + output_type=output_type, + data_client=self._data_api) + else: + self.flame_log("Data API is not available, cannot convert FHIR to CSV", + log_type='warning') + return None ########################################Message Broker Client#################################### def send_message(self, @@ -318,8 +326,7 @@ def send_message(self, message: dict, max_attempts: int = 1, timeout: Optional[int] = None, - attempt_timeout: int = 10, - silent: Optional[bool] = None) -> tuple[list[str], list[str]]: + attempt_timeout: int = 10) -> tuple[list[str], list[str]]: """ Send a message to the specified nodes :param receivers: list of node ids to send the message to @@ -328,18 +335,14 @@ def send_message(self, :param 
max_attempts: the maximum number of attempts to send the message :param timeout: time in seconds to wait for the message acknowledgement, if None waits indefinitely :param attempt_timeout: timeout of each attempt, if timeout is None (the last attempt will be indefinite though) - :param silent: if True, the response will not be logged :return: a tuple of nodes ids that acknowledged and not acknowledged the message """ - if silent is None: - silent = self.silent return asyncio.run(self._message_broker_api.send_message(receivers, message_category, message, max_attempts, timeout, - attempt_timeout, - silent)) + attempt_timeout)) def await_messages(self, senders: list[str], @@ -390,8 +393,7 @@ def send_message_and_wait_for_responses(self, message: dict, max_attempts: int = 1, timeout: Optional[int] = None, - attempt_timeout: int = 10, - silent: Optional[bool] = None) -> dict[str, Optional[list[Message]]]: + attempt_timeout: int = 10) -> dict[str, Optional[list[Message]]]: """ Sends a message to all specified nodes and waits for responses, (combines send_message and await_responses) :param receivers: list of node ids to send the message to @@ -402,59 +404,46 @@ def send_message_and_wait_for_responses(self, :param attempt_timeout: timeout of each attempt, if timeout is None (the last attempt will be indefinite though) :return: the responses """ - if silent is None: - silent = self.silent return self._message_broker_api.send_message_and_wait_for_responses(receivers, message_category, message, max_attempts, timeout, - attempt_timeout, - silent) + attempt_timeout) ########################################Storage Client########################################### def submit_final_result(self, result: Any, output_type: Literal['str', 'bytes', 'pickle'] = 'str', - local_dp: Optional[LocalDifferentialPrivacyParams] = None, #TODO:localdp - silent: Optional[bool] = None) -> dict[str, str]: + local_dp: Optional[LocalDifferentialPrivacyParams] = None) -> dict[str, str]: """ sends the final result to the hub. Making it available for analysts to download. This method is only available for nodes for which the method `get_role(self)` returns "aggregator”. :param result: the final result :param output_type: output type of final results (default: string) - :param local_dp: tba #TODO:localdp - :param silent: if True, the response will not be logged + :param local_dp: parameters for local differential privacy (optional) :return: the request status code """ - if silent is None: - silent = self.silent return self._storage_api.submit_final_result(result, output_type, - local_dp, #TODO:localdp - silent=silent) + local_dp) def save_intermediate_data(self, data: Any, location: Literal["local", "global"], remote_node_ids: Optional[list[str]] = None, - tag: Optional[str] = None, - silent: Optional[bool] = None) -> Union[dict[str, dict[str, str]], dict[str, str]]: + tag: Optional[str] = None) -> Union[dict[str, dict[str, str]], dict[str, str]]: """ saves intermediate results/data either on the hub (location="global"), or locally :param data: the result to save :param location: the location to save the result, local saves in the node, global saves in central instance of MinIO :param remote_node_ids: optional remote node ids (used for accessing remote node's public key for encryption) :param tag: optional storage tag - :param silent: if True, the response will not be logged :return: the request status code{"status": ,"url":, "id": }, or dict of said dicts if encrypted mode is used, i.e. 
remote_node_ids are set """ - if silent is None: - silent = self.silent return self._storage_api.save_intermediate_data(data, location=location, remote_node_ids=remote_node_ids, - tag=tag, - silent=silent) + tag=tag) def get_intermediate_data(self, location: Literal["local", "global"], @@ -484,8 +473,7 @@ def send_intermediate_data(self, max_attempts: int = 1, timeout: Optional[int] = None, attempt_timeout: int = 10, - encrypted: bool = False, - silent: Optional[bool] = None) -> tuple[list[str], list[str]]: + encrypted: bool = False) -> tuple[list[str], list[str]]: """ Sends intermediate data to specified receivers using the Result Service and Message Broker. @@ -502,7 +490,6 @@ def send_intermediate_data(self, timeout (int, optional): time in seconds to wait for the message acknowledgement, if None waits indefinitely attempt_timeout (int): timeout of each attempt, if timeout is None (the last attempt will be indefinite though) encrypted (bool): bool whether data should be encrypted or not - silent (bool): if True, the response will not be logged Returns: tuple[list[str], list[str]]: @@ -520,8 +507,6 @@ def send_intermediate_data(self, print("Failed nodes:", failed) # e.g., ["node3"] ``` """ - if silent is None: - silent = self.silent if encrypted: result_id_body = {k: v['id'] for k, v in self.save_intermediate_data(data, @@ -535,8 +520,7 @@ def send_intermediate_data(self, {"result_id": result_id_body}, max_attempts, timeout, - attempt_timeout, - silent) + attempt_timeout) def await_intermediate_data(self, senders: list[str], @@ -596,36 +580,57 @@ def get_local_tags(self, filter: Optional[str] = None) -> list[str]: return self._storage_api.get_local_tags(filter) ########################################Data Client####################################### - def get_data_client(self, data_id: str) -> AsyncClient: + def get_data_client(self, data_id: str) -> Optional[AsyncClient]: """ Returns the data client for a specific fhir or S3 store used for this project. :param data_id: the id of the data source :return: the data client """ - return self._data_api.get_data_client(data_id) + if type(self._data_api) == DataAPI: + return self._data_api.get_data_client(data_id) + else: + self.flame_log("Data API is not available, cannot retrieve data client", + log_type='warning') + return None - def get_data_sources(self) -> list[str]: + def get_data_sources(self) -> Optional[list[str]]: """ Returns a list of all data sources available for this project. :return: the list of data sources """ - return self._data_api.get_data_sources() + if type(self._data_api) == DataAPI: + return self._data_api.get_data_sources() + else: + self.flame_log("Data API is not available, cannot retrieve data sources", + log_type='warning') + return None - def get_fhir_data(self, fhir_queries: Optional[list[str]] = None) -> list[Union[dict[str, dict], dict]]: + def get_fhir_data(self, fhir_queries: Optional[list[str]] = None) -> Optional[list[Union[dict[str, dict], dict]]]: """ Returns the data from the FHIR store for each of the specified queries. 
:param fhir_queries: list of queries to get the data :return: """ - return self._data_api.get_fhir_data(fhir_queries) + if type(self._data_api) == DataAPI: + return self._data_api.get_fhir_data(fhir_queries) + else: + self.flame_log("Data API is not available, cannot retrieve FHIR data", + log_type='warning') + return None - def get_s3_data(self, s3_keys: Optional[list[str]] = None) -> list[Union[dict[str, str], str]]: + def get_s3_data(self, s3_keys: Optional[list[str]] = None) -> Optional[list[Union[dict[str, str], str]]]: """ Returns the data from the S3 store associated with the given key. :param s3_keys:f :return: """ - return self._data_api.get_s3_data(s3_keys) + if type(self._data_api) == DataAPI: + return self._data_api.get_s3_data(s3_keys) + else: + self.flame_log("Data API is not available, cannot retrieve S3 data", + log_type='warning') + return None + ########################################Internal############################################### def _start_flame_api(self) -> None: @@ -634,10 +639,11 @@ def _start_flame_api(self) -> None: :return: """ self.flame_api = FlameAPI(self._message_broker_api.message_broker_client, - self._data_api.data_client if hasattr(self, '_data_api') else 'ignore', + self._data_api.data_client if hasattr(self._data_api, 'data_client') else 'ignore', self._storage_api.result_client, + self._po_api.po_client, + self._flame_logger, self.config.keycloak_token, - self.silent, finished_check=self._has_finished, finishing_call=self._node_finished) diff --git a/flamesdk/resources/client_apis/clients/data_api_client.py b/flamesdk/resources/client_apis/clients/data_api_client.py index a801bee..3b5d7bd 100644 --- a/flamesdk/resources/client_apis/clients/data_api_client.py +++ b/flamesdk/resources/client_apis/clients/data_api_client.py @@ -1,11 +1,14 @@ -from typing import Any, Optional, Union +from typing import Optional, Union import asyncio -from httpx import AsyncClient +from httpx import AsyncClient, HTTPStatusError import re +from flamesdk.resources.utils.logging import FlameLogger + class DataApiClient: - def __init__(self, project_id: str, nginx_name: str, data_source_token: str, keycloak_token: str) -> None: + def __init__(self, project_id: str, nginx_name: str, data_source_token: str, keycloak_token: str, flame_logger: FlameLogger) -> None: self.nginx_name = nginx_name + self.flame_logger = flame_logger self.client = AsyncClient(base_url=f"http://{nginx_name}/kong", headers={"apikey": data_source_token, "Content-Type": "application/json"}, @@ -26,14 +29,20 @@ def refresh_token(self, keycloak_token: str): async def _retrieve_available_sources(self) -> list[dict[str, str]]: response = await self.hub_client.get(f"/kong/datastore/{self.project_id}") - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to retrieve available data sources for project {self.project_id}:" + f" {repr(e)}") + return [{'name': source['name']} for source in response.json()['data']] def get_available_sources(self): return self.available_sources - def get_data(self, s3_keys: Optional[list[str]] = None, fhir_queries: Optional[list[str]] = None) \ - -> list[Union[dict[str, Union[dict, str]], str]]: + def get_data(self, + s3_keys: Optional[list[str]] = None, + fhir_queries: Optional[list[str]] = None) -> list[Union[dict[str, Union[dict, str]], str]]: dataset_sources = [] for source in self.available_sources: datasets = {} @@ -41,7 +50,12 @@ def get_data(self, s3_keys: Optional[list[str]] = None, 
fhir_queries: Optional[l for fhir_query in fhir_queries: # premise: retrieves data for each fhir_query from each data source response = asyncio.run(self.client.get(f"{source['name']}/fhir/{fhir_query}", headers=[('Connection', 'close')])) - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.new_log(f"Failed to retrieve fhir data for query {fhir_query} " + f"from source {source['name']}: {repr(e)}", log_type='warning') + continue datasets[fhir_query] = response.json() else: response_names = asyncio.run(self._get_s3_dataset_names(source['name'])) @@ -49,15 +63,21 @@ def get_data(self, s3_keys: Optional[list[str]] = None, fhir_queries: Optional[l if (s3_keys is None) or (res_name in s3_keys): response = asyncio.run(self.client.get(f"{source['name']}/s3/{res_name}", headers=[('Connection', 'close')])) - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to retrieve s3 data for key {res_name} " + f"from source {source['name']}: {repr(e)}") datasets[res_name] = response.content dataset_sources.append(datasets) return dataset_sources async def _get_s3_dataset_names(self, source_name: str) -> list[str]: response = await self.client.get(f"{source_name}/s3", headers=[('Connection', 'close')]) - response.raise_for_status() - + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to retrieve S3 dataset names from source {source_name}: {repr(e)}") responses = re.findall(r'<Key>(.*?)</Key>', str(response.text)) return responses @@ -72,8 +92,8 @@ def get_data_source_client(self, data_id: str) -> AsyncClient: if sources["id"] == data_id: path = sources["paths"][0] if path is None: - raise ValueError(f"Data source with id {data_id} not found") - client = AsyncClient(base_url=f"{path}",) + self.flame_logger.raise_error(f"Data source with id {data_id} not found") + client = AsyncClient(base_url=f"{path}") return client diff --git a/flamesdk/resources/client_apis/clients/message_broker_client.py b/flamesdk/resources/client_apis/clients/message_broker_client.py index a422e88..a344289 100644 --- a/flamesdk/resources/client_apis/clients/message_broker_client.py +++ b/flamesdk/resources/client_apis/clients/message_broker_client.py @@ -3,10 +3,10 @@ import asyncio import datetime from typing import Optional, Literal -from httpx import AsyncClient, HTTPError +from httpx import AsyncClient, HTTPStatusError from flamesdk.resources.node_config import NodeConfig -from flamesdk.resources.utils.logging import flame_log +from flamesdk.resources.utils.logging import FlameLogger class Message: @@ -14,6 +14,7 @@ def __init__(self, message: dict, config: NodeConfig, outgoing: bool, + flame_logger: FlameLogger, message_number: Optional[int] = None, category: Optional[str] = None, recipients: Optional[list[str]] = None) -> None: @@ -26,25 +27,25 @@ def __init__(self, :param category: the message category :param recipients: the list of recipients """ + self.flame_logger = flame_logger if outgoing: if "meta" in message.keys(): - raise ValueError("Cannot use field 'meta' in message body. " - "This field is reserved for meta data used by the message broker.") + self.flame_logger.raise_error("Cannot use field 'meta' in message body. 
" + "This field is reserved for meta data used by the message broker.") elif type(message_number) != int: - raise ValueError(f"Specified outgoing message, but did not specify integer value for message_number " - f"(received: {type(message_number)}).") + self.flame_logger.raise_error(f"Specified outgoing message, but did not specify integer value for " + f"message_number (received: {type(message_number)}).") elif type(category) != str: - raise ValueError("Specified outgoing message, but did not specify string value for category " - f"(received: {type(category)}).") - + self.flame_logger.raise_error(f"Specified outgoing message, but did not specify string value for " + f"category (received: {type(category)}).") elif (type(recipients) != list) or (any([type(recipient) != str for recipient in recipients])): if hasattr(recipients, '__iter__'): - raise ValueError(f"Specified outgoing message, but did not specify list of strings value for " - f"recipients (received: {type(recipients)} containing " - f"{set([type(recipient) for recipient in recipients])}).") + self.flame_logger.raise_error(f"Specified outgoing message, but did not specify list of strings " + f"value for recipients (received: {type(recipients)} containing " + f"{set([type(recipient) for recipient in recipients])}).") else: - raise ValueError(f"Specified outgoing message, but did not specify list of strings value for " - f"recipients (received: {type(recipients)}).") + self.flame_logger.raise_error(f"Specified outgoing message, but did not specify list of strings " + f"value for recipients (received: {type(recipients)}).") self.recipients = recipients self.body = message @@ -92,14 +93,15 @@ def _update_meta_data(self, class MessageBrokerClient: - def __init__(self, config: NodeConfig, silent: bool = False) -> None: + def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None: self.nodeConfig = config + self.flame_logger = flame_logger self._message_broker = AsyncClient( base_url=f"http://{self.nodeConfig.nginx_name}/message-broker", headers={"Authorization": f"Bearer {config.keycloak_token}", "Accept": "application/json"}, follow_redirects=True ) - asyncio.run(self._connect(silent=silent)) + asyncio.run(self._connect()) self.list_of_incoming_messages: list[Message] = [] self.list_of_outgoing_messages: list[Message] = [] self.message_number = 0 @@ -117,15 +119,20 @@ def refresh_token(self, keycloak_token: str): async def get_self_config(self, analysis_id: str) -> dict[str, str]: response = await self._message_broker.get(f'/analyses/{analysis_id}/participants/self', headers=[('Connection', 'close')]) - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to retrieve self configuration for analysis {analysis_id}: " + f"{repr(e)}") return response.json() async def get_partner_nodes(self, self_node_id: str, analysis_id: str) -> list[dict[str, str]]: response = await self._message_broker.get(f'/analyses/{analysis_id}/participants', headers=[('Connection', 'close')]) - - response.raise_for_status() - + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to retrieve partner nodes for analysis {analysis_id} : {repr(e)}") response = [node_conf for node_conf in response.json() if node_conf['nodeId'] != self_node_id] return response @@ -135,28 +142,28 @@ async def test_connection(self) -> bool: try: response.raise_for_status() return True - except HTTPError: + except HTTPStatusError as e: + 
self.flame_logger.new_log(f"Failed to connect to message broker: {repr(e)}", log_type='warning') return False - async def _connect(self, silent: bool = False) -> None: + async def _connect(self) -> None: response = await self._message_broker.post( f'/analyses/{os.getenv("ANALYSIS_ID")}/messages/subscriptions', json={'webhookUrl': f'http://{self.nodeConfig.nginx_name}/analysis/webhook'} ) try: response.raise_for_status() - except HTTPError as e: - flame_log("Failed to subscribe to message broker", silent) - flame_log(repr(e), silent) + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to subscribe to message broker: {repr(e)}") response = await self._message_broker.get(f'/analyses/{os.getenv("ANALYSIS_ID")}/participants/self', headers=[('Connection', 'close')]) try: response.raise_for_status() - except HTTPError as e: - flame_log("Successfully subscribed to message broker, but failed to retrieve participants", silent) - flame_log(repr(e), silent) + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Successfully subscribed to message broker, " + f"but failed to retrieve participants: {repr(e)}") - async def send_message(self, message: Message, silent: bool = False) -> None: + async def send_message(self, message: Message) -> None: self.message_number += 1 body = { "recipients": message.recipients, @@ -168,18 +175,18 @@ async def send_message(self, message: Message, silent: bool = False) -> None: headers=[('Connection', 'close'), ("Content-Type", "application/json")]) if message.body["meta"]["sender"] == self.nodeConfig.node_id: - flame_log(f"send message: {body}", silent) + self.flame_logger.new_log(f"send message: {body}", log_type='info') self.list_of_outgoing_messages.append(message) - def receive_message(self, body: dict, silent: bool = False) -> None: + def receive_message(self, body: dict) -> None: needs_acknowledgment = body["meta"]["akn_id"] is None - message = Message(message=body, config=self.nodeConfig, outgoing=False) + message = Message(message=body, config=self.nodeConfig, outgoing=False, flame_logger=self.flame_logger) self.list_of_incoming_messages.append(message) if needs_acknowledgment: - flame_log("acknowledging ready check" if body["meta"]["category"] == "ready_check" else "incoming message", - silent) + self.flame_logger.new_log("acknowledging ready check" if body["meta"]["category"] == "ready_check" else "incoming message", + log_type='info') asyncio.run(self.acknowledge_message(message)) def delete_message_by_id(self, message_id: str, type: Literal["outgoing", "incoming"]) -> int: @@ -196,14 +203,18 @@ def delete_message_by_id(self, message_id: str, type: Literal["outgoing", "incom self.list_of_outgoing_messages.remove(message) number_of_deleted_messages += 1 if number_of_deleted_messages == 0: - raise ValueError(f"Could not find message with id={message_id} in outgoing messages.") + self.flame_logger.new_log(f"Could not find message with id={message_id} in outgoing messages.", + log_type='warning') + return 0 if type == "incoming": for message in self.list_of_outgoing_messages: if message.body["meta"]["id"] == message_id: self.list_of_outgoing_messages.remove(message) number_of_deleted_messages += 1 if number_of_deleted_messages == 0: - raise ValueError(f"Could not find message with id={message_id} in outgoing messages.") + self.flame_logger.new_log(f"Could not find message with id={message_id} in incoming messages.", + log_type='warning') + return 0 return number_of_deleted_messages async def await_message(self,
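FlameLogger itself lives in flamesdk/resources/utils/logging.py and is not part of this patch. A minimal sketch of the interface the call sites above assume — names and parameters are taken from usage in this diff, while the method bodies and the exception type raised by raise_error are illustrative guesses:

```python
from typing import Union


class FlameLogger:
    """Sketch of the logger threaded through the clients in this patch (not the real implementation)."""

    def __init__(self, silent: bool = False) -> None:
        self.silent = silent
        self.runstatus = 'starting'  # FlameAPI._finished reports 'failed' when this is set to 'failed'
        self._po_api = None          # attached later via add_po_api

    def add_po_api(self, po_api) -> None:
        # once attached, every log line can also be streamed to the PO service
        self._po_api = po_api

    def set_runstatus(self, status: str) -> None:
        self.runstatus = status

    def new_log(self,
                msg: Union[str, bytes],
                sep: str = ' ',
                end: str = '\n',
                file: object = None,
                log_type: str = 'normal',
                suppress_head: bool = False,  # head/tail suppression mirrors the old flame_log flags
                suppress_tail: bool = False) -> None:
        if not self.silent:
            print(msg, sep=sep, end=end, file=file)
        if self._po_api is not None:
            self._po_api.stream_logs(str(msg), log_type, self.runstatus)

    def raise_error(self, msg: str) -> None:
        # log, mark the run as failed, then abort; the concrete exception type is an assumption
        self.new_log(msg, log_type='error')
        self.set_runstatus('failed')
        raise RuntimeError(msg)

    def declare_log_types(self, new_log_types: dict[str, str]) -> None:
        # registers additional log_type literals (see FlameCoreSDK.declare_log_types)
        pass
```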
diff --git a/flamesdk/resources/client_apis/clients/po_client.py b/flamesdk/resources/client_apis/clients/po_client.py new file mode 100644 index 0000000..d675996 --- /dev/null +++ b/flamesdk/resources/client_apis/clients/po_client.py @@ -0,0 +1,38 @@ +from httpx import Client, HTTPError + +from flamesdk.resources.utils.logging import FlameLogger + +class POClient: + def __init__(self, nginx_name: str, keycloak_token: str, flame_logger: FlameLogger) -> None: + self.nginx_name = nginx_name + self.client = Client(base_url=f"http://{nginx_name}/po", + headers={"Authorization": f"Bearer {keycloak_token}", + "accept": "application/json"}, + follow_redirects=True) + self.flame_logger = flame_logger + + def refresh_token(self, keycloak_token: str): + self.client = Client(base_url=f"http://{self.nginx_name}/po", + headers={"Authorization": f"Bearer {keycloak_token}", + "accept": "application/json"}, + follow_redirects=True) + + def stream_logs(self, log: str, log_type: str, analysis_id: str, status: str) -> None: + log_dict = { + "log": log, + "log_type": log_type, + "analysis_id": analysis_id, + "status": status + } + print("Sending logs to PO:", log_dict) + response = self.client.post("/stream_logs", + json=log_dict, + headers={"Content-Type": "application/json"}) + try: + response.raise_for_status() + print("Successfully streamed logs to PO") + except HTTPError as e: + # plain print on purpose: reporting this failure through FlameLogger would re-enter stream_logs + print("HTTP Error in po api:", repr(e)) + except Exception as e: + print("Unforeseen Error:", repr(e))
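For context, the new client can be exercised directly like so; the nginx name, token, and analysis id are placeholders, and the /po/stream_logs route is the one used above (POAPI, added further down, fills in analysis_id from the node config):

```python
from flamesdk.resources.client_apis.clients.po_client import POClient
from flamesdk.resources.utils.logging import FlameLogger

logger = FlameLogger(silent=False)
po = POClient(nginx_name="flame-nginx",           # placeholder gateway name
              keycloak_token="<keycloak-token>",  # placeholder token
              flame_logger=logger)
po.stream_logs(log="FlameCoreSDK ready",
               log_type="info",
               analysis_id="<analysis-id>",
               status="running")
```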
diff --git a/flamesdk/resources/client_apis/clients/result_client.py b/flamesdk/resources/client_apis/clients/result_client.py index 8b5b87b..7316a43 100644 --- a/flamesdk/resources/client_apis/clients/result_client.py +++ b/flamesdk/resources/client_apis/clients/result_client.py @@ -1,14 +1,13 @@ import math -from httpx import Client +from httpx import Client, HTTPStatusError import pickle -from _pickle import PicklingError import re import uuid from io import BytesIO from typing import Any, Literal, Optional from typing_extensions import TypedDict -from flamesdk.resources.utils.logging import flame_log +from flamesdk.resources.utils.logging import FlameLogger class LocalDifferentialPrivacyParams(TypedDict, total=True): @@ -18,12 +17,12 @@ class ResultClient: - def __init__(self, nginx_name, keycloak_token) -> None: + def __init__(self, nginx_name, keycloak_token, flame_logger: FlameLogger) -> None: self.nginx_name = nginx_name self.client = Client(base_url=f"http://{nginx_name}/storage", headers={"Authorization": f"Bearer {keycloak_token}"}, follow_redirects=True) - + self.flame_logger = flame_logger def refresh_token(self, keycloak_token: str): self.client = Client(base_url=f"http://{self.nginx_name}/storage", headers={"Authorization": f"Bearer {keycloak_token}"}, follow_redirects=True) @@ -36,7 +35,7 @@ def push_result(self, type: Literal["final", "global", "local"] = "final", output_type: Literal['str', 'bytes', 'pickle'] = 'pickle', local_dp: Optional[LocalDifferentialPrivacyParams] = None, #TODO:localdp - silent: bool = False) -> dict[str, str]: + ) -> dict[str, str]: """ Pushes the result to the hub. Making it available for analysts to download. @@ -46,18 +45,17 @@ def push_result(self, :param type: location to save the result, final saves in the hub to be downloaded, global saves in central instance of MinIO, local saves in the node :param output_type: the type of the result, str, bytes or pickle only for final results :param local_dp: parameters for local differential privacy, only for final floating-point type results #TODO:localdp - :param silent: if True, the response will not be logged :return: """ if tag and (type != "local"): - raise ValueError("Tag can only be used with local type, in current implementation") + self.flame_logger.raise_error("Tag can only be used with local type, in current implementation") elif remote_node_id and (type != "global"): - raise ValueError("Remote_node_id can only be used with global type, in current implementation") - + self.flame_logger.raise_error("Remote_node_id can only be used with global type, in current implementation") type = "intermediate" if type == "global" else type if tag and not re.match(r'^[a-z0-9]+(-[a-z0-9]+)*$', tag): - raise ValueError("Tag must consist only of lowercase letters, numbers, and hyphens") + self.flame_logger.raise_error(f"Invalid tag format: {tag}. " + f"Tag must consist only of lowercase letters, numbers, and hyphens") # TODO:localdp (start) # check if local dp parameters have been supplied @@ -67,23 +65,23 @@ def push_result(self, if use_local_dp: # check if result is a numeric value if not isinstance(result, (float, int)): - raise ValueError("Local differential privacy can only be applied on numeric values") + self.flame_logger.raise_error("Local differential privacy can only be applied on numeric values") # check if result is finite if not math.isfinite(result): - raise ValueError("Result is not finite") + self.flame_logger.raise_error("Result is not finite") # check if final result submission is requested if type != "final": - raise ValueError("Local differential privacy is only supported for submission of final results") + self.flame_logger.raise_error("Local differential privacy is only supported for " + "submission of final results") # print warning if output_type other than str is specified if output_type != "str": - flame_log( + self.flame_logger.new_log( f"Result submission with local differential privacy requested but output type is set to `{output_type}`." "`str` is enforced but this may change in a future version.", - silent - ) + log_type='warning') # write as string to request body file_body = str(result).encode("utf-8") @@ -93,16 +91,16 @@ def push_result(self, file_body = bytes(result) else: file_body = pickle.dumps(result) - except (TypeError, ValueError, UnicodeEncodeError, PicklingError) as e: + except (TypeError, ValueError, UnicodeEncodeError, pickle.PicklingError) as e: if output_type != 'pickle': - flame_log(f"Failed to translate result data to type={output_type}: {e}", silent) - flame_log("Attempting 'pickle' instead...", silent) + self.flame_logger.new_log(f"Failed to translate result data to type={output_type}: {repr(e)}", log_type='warning') + self.flame_logger.new_log("Attempting 'pickle' instead...", log_type='warning') try: file_body = pickle.dumps(result) - except PicklingError as e: - raise ValueError(f"Failed to pickle result data: {e}") + except pickle.PicklingError as e: + self.flame_logger.raise_error(f"Failed to pickle result data: {repr(e)}") else: - raise ValueError(f"Failed to pickle result data: {e}") + self.flame_logger.raise_error(f"Failed to pickle result data: {repr(e)}") if remote_node_id: data = {"remote_node_id": remote_node_id} @@ -118,18 +116,19 @@ def push_result(self, request_path += "localdp" # local_dp is guaranteed to not be None, so remap values to string and update request data mapping data.update({k: str(v) for k, v in local_dp.items()}) - #TODO:localdp (end) response = self.client.put(request_path, files={"file": (str(uuid.uuid4()), BytesIO(file_body))}, data=data, headers=[('Connection', 'close')]) - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to push results: {repr(e)}") if type != "final": - flame_log(f"response push_results: {response.json()}", silent) + self.flame_logger.new_log(f"response push_results: {response.json()}", log_type='info') else: return {"status": "success"} - return {"status": "success", "url": response.json()["url"], "id": response.json()["url"].split("/")[-1]}
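The encode-with-fallback behaviour above is the subtle part of push_result; isolated as a self-contained sketch (this mirrors the patched logic, it is not the SDK itself):

```python
import pickle
from typing import Any


def encode_result(result: Any, output_type: str = 'pickle') -> bytes:
    """Mirror of push_result's body encoding: str/bytes on request, pickle as fallback."""
    try:
        if output_type == 'str':
            return str(result).encode('utf-8')
        elif output_type == 'bytes':
            return bytes(result)
        return pickle.dumps(result)
    except (TypeError, ValueError, UnicodeEncodeError, pickle.PicklingError):
        if output_type == 'pickle':
            raise  # push_result reports this via flame_logger.raise_error
        # push_result logs a warning here and retries with pickle
        return pickle.dumps(result)


assert encode_result(42, output_type='str') == b'42'
assert pickle.loads(encode_result({'a': 1})) == {'a': 1}
```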
"`str` is enforced but this may change in a future version.", - silent - ) + log_type='warning') # write as string to request body file_body = str(result).encode("utf-8") @@ -93,16 +91,16 @@ def push_result(self, file_body = bytes(result) else: file_body = pickle.dumps(result) - except (TypeError, ValueError, UnicodeEncodeError, PicklingError) as e: + except (TypeError, ValueError, UnicodeEncodeError, pickle.PicklingError) as e: if output_type != 'pickle': - flame_log(f"Failed to translate result data to type={output_type}: {e}", silent) - flame_log("Attempting 'pickle' instead...", silent) + self.flame_logger.new_log(f"Failed to translate result data to type={output_type}: {repr(e)}", log_type='warning') + self.flame_logger.new_log("Attempting 'pickle' instead...", log_type='warning') try: file_body = pickle.dumps(result) - except PicklingError as e: - raise ValueError(f"Failed to pickle result data: {e}") + except pickle.PicklingError as e: + self.flame_logger.raise_error(f"Failed to pickle result data: {repr(e)}") else: - raise ValueError(f"Failed to pickle result data: {e}") + self.flame_logger.raise_error(f"Failed to pickle result data: {repr(e)}") if remote_node_id: data = {"remote_node_id": remote_node_id} @@ -118,18 +116,19 @@ def push_result(self, request_path += "localdp" # local_dp is guaranteed to not be None, so remap values to string and update request data mapping data.update({k: str(v) for k, v in local_dp.items()}) - #TODO:localdp (end) response = self.client.put(request_path, files={"file": (str(uuid.uuid4()), BytesIO(file_body))}, data=data, headers=[('Connection', 'close')]) - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to push results: {repr(e)}") if type != "final": - flame_log(f"response push_results: {response.json()}", silent) + self.flame_logger.new_log(f"response push_results: {response.json()}", log_type='info') else: return {"status": "success"} - return {"status": "success", "url": response.json()["url"], "id": response.json()["url"].split("/")[-1]} @@ -150,13 +149,11 @@ def get_intermediate_data(self, :return: """ if (tag is not None) and (type != "local"): - raise ValueError("Tag can only be used with local type") + self.flame_logger.raise_error("Tag can only be used with local type") if (id is None) and (tag is None): - raise ValueError("Either id or tag should be provided") - + self.flame_logger.raise_error("Tag can only be used with local type") if tag and not re.match(r'^[a-z0-9]{1,2}|[a-z0-9][a-z0-9-]{,30}[a-z0-9]+$', tag): - raise ValueError("Tag must consist only of lowercase letters, numbers, and hyphens") - + self.flame_logger.raise_error(f"Tag must consist only of lowercase letters, numbers, and hyphens") type = "intermediate" if type == "global" else type if tag: @@ -182,7 +179,10 @@ def _get_location_url_for_tag(self, tag: str) -> str: :return: """ response = self.client.get(f"/local/tags/{tag}") - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to Retrieves the URL associated with the specified tag.: {repr(e)}") urls = [] for item in response.json()["results"]: item["url"] = item["url"].split("/local/")[1] @@ -196,7 +196,10 @@ def _get_file(self, url: str) -> Any: :return: """ response = self.client.get(url) - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to retrieve file from 
URL: {repr(e)}") return pickle.loads(BytesIO(response.content).read()) def get_local_tags(self, filter: Optional[str] = None) -> list[str]: @@ -232,7 +235,11 @@ def get_local_tags(self, filter: Optional[str] = None) -> list[str]: HTTPError: If the request to fetch tags fails. """ response = self.client.get("/local/tags") - response.raise_for_status() + try: + response.raise_for_status() + except HTTPStatusError as e: + self.flame_logger.raise_error(f"Failed to retrieve local tags: {repr(e)}") + tag_name_list = [tag["name"] for tag in response.json()["tags"]] if filter is not None: diff --git a/flamesdk/resources/client_apis/data_api.py b/flamesdk/resources/client_apis/data_api.py index 538deed..6b8161e 100644 --- a/flamesdk/resources/client_apis/data_api.py +++ b/flamesdk/resources/client_apis/data_api.py @@ -3,14 +3,15 @@ from flamesdk.resources.client_apis.clients.data_api_client import DataApiClient from flamesdk.resources.node_config import NodeConfig - +from flamesdk.resources.utils.logging import FlameLogger class DataAPI: - def __init__(self, config: NodeConfig): + def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None: self.data_client = DataApiClient(config.project_id, config.nginx_name, config.data_source_token, - config.keycloak_token) + config.keycloak_token, + flame_logger= flame_logger) def get_data_client(self, data_id: str) -> AsyncClient: """ diff --git a/flamesdk/resources/client_apis/message_broker_api.py b/flamesdk/resources/client_apis/message_broker_api.py index 959f0d0..895adcd 100644 --- a/flamesdk/resources/client_apis/message_broker_api.py +++ b/flamesdk/resources/client_apis/message_broker_api.py @@ -4,11 +4,13 @@ from flamesdk.resources.node_config import NodeConfig from flamesdk.resources.client_apis.clients.message_broker_client import MessageBrokerClient, Message +from flamesdk.resources.utils.logging import FlameLogger class MessageBrokerAPI: - def __init__(self, config: NodeConfig, silent: bool = False): - self.message_broker_client = MessageBrokerClient(config, silent) + def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None: + self.flame_logger = flame_logger + self.message_broker_client = MessageBrokerClient(config, flame_logger) self.config = self.message_broker_client.nodeConfig self.participants = asyncio.run(self.message_broker_client.get_partner_nodes(self.config.node_id, self.config.analysis_id)) @@ -19,8 +21,7 @@ async def send_message(self, message: dict, max_attempts: int = 1, timeout: Optional[int] = None, - attempt_timeout: int = 10, - silent: bool = False) -> tuple[list[str], list[str]]: + attempt_timeout: int = 10) -> tuple[list[str], list[str]]: """ Sends a message to specified nodes with support for multiple attempts and timeout handling. 
@@ -35,16 +36,16 @@ async def send_message(self, :param max_attempts: the maximum number of attempts to send the message :param timeout: time in seconds to wait for the message acknowledgement, if None waits indefinitely :param attempt_timeout: timeout of each attempt, if timeout is None (the last attempt will be indefinite though) - :param silent: if True, the response will not be logged + :raises TimeoutError: if the message is not acknowledged within the specified timeout :return: a tuple of nodes ids that acknowledged and not acknowledged the message """ - # Create a message object - message = Message(recipients=receivers, - message=message, - category=message_category, + message = Message(message=message, config=self.config, + outgoing=True, + flame_logger=self.flame_logger, message_number=self.message_broker_client.message_number, - outgoing=True) + category=message_category, + recipients=receivers) start_time = datetime.now() acknowledged = [] not_acknowledged = receivers @@ -58,7 +59,7 @@ async def send_message(self, message.recipients = not_acknowledged # Send the message - await self.message_broker_client.send_message(message, silent) + await self.message_broker_client.send_message(message) # await the message acknowledgement await_list = [] @@ -93,15 +94,13 @@ async def await_messages(self, node_ids: list[str], message_category: str, message_id: Optional[str] = None, - timeout: Optional[int] = None, - silent: bool = False) -> dict[str, Optional[list[Message]]]: + timeout: Optional[int] = None) -> dict[str, Optional[list[Message]]]: """ Wait for responses from the specified nodes :param node_ids: list of node ids to wait for :param message_category: the message category to wait for :param message_id: optional message id to wait for :param timeout: time in seconds to wait for the message, if None waits indefinitely - :param silent: if True, the response will not be logged :return: """ await_list = [] @@ -146,7 +145,8 @@ def delete_messages_by_id(self, message_ids: list[str]) -> int: number_of_deleted_messages += self.message_broker_client.delete_message_by_id(message_id, type="outgoing") return number_of_deleted_messages - def clear_messages(self, status: Literal["read", "unread", "all"] = "read", + def clear_messages(self, + status: Literal["read", "unread", "all"] = "read", min_age: Optional[int] = None) -> int: """ Deletes all messages by status (default: read messages) and if they are older than the specified min_age. 
It @@ -161,13 +161,13 @@ def clear_messages(self, status: Literal["read", "unread", "all"] = "read", number_of_deleted_messages += self.message_broker_client.clear_messages("outgoing", status, min_age) return number_of_deleted_messages - def send_message_and_wait_for_responses(self, receivers: list[str], + def send_message_and_wait_for_responses(self, + receivers: list[str], message_category: str, message: dict, max_attempts: int = 1, timeout: Optional[int] = None, - attempt_timeout: int = 10, - silent: bool = False) -> dict[str, Optional[list[Message]]]: + attempt_timeout: int = 10) -> dict[str, Optional[list[Message]]]: """ Sends a message to all specified nodes and waits for responses, (combines send_message and await_responses) :param receivers: list of node ids to send the message to @@ -176,7 +176,6 @@ def send_message_and_wait_for_responses(self, receivers: list[str], :param max_attempts: the maximum number of attempts to send the message :param timeout: time in seconds to wait for the message acknowledgement, if None waits indefinitely :param attempt_timeout: timeout of each attempt, if timeout is None (the last attempt will be indefinite though) - :param silent: if True, the response will not be logged :return: the responses """ time_start = datetime.now() @@ -187,7 +186,7 @@ def send_message_and_wait_for_responses(self, receivers: list[str], max_attempts, timeout, attempt_timeout, - silent)) + )) timeout = timeout - (datetime.now() - time_start).seconds if timeout < 0: timeout = 1 diff --git a/flamesdk/resources/client_apis/po_api.py b/flamesdk/resources/client_apis/po_api.py new file mode 100644 index 0000000..63664bf --- /dev/null +++ b/flamesdk/resources/client_apis/po_api.py @@ -0,0 +1,17 @@ +from flamesdk.resources.client_apis.clients.po_client import POClient +from flamesdk.resources.node_config import NodeConfig +from flamesdk.resources.utils.logging import FlameLogger + +class POAPI: + def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None: + self.po_client = POClient(config.nginx_name, config.keycloak_token, flame_logger) + self.analysis_id = config.analysis_id + + def stream_logs(self, log: str, log_type: str, status: str) -> None: + """ + Streams logs to the PO service. + :param log: the log message + :param log_type: type of the log (e.g., 'info', 'error') + :param status: status of the log + """ + self.po_client.stream_logs(log, log_type, self.analysis_id, status)
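Wiring the new POAPI into the logger mirrors FlameCoreSDK.__init__ earlier in this diff (NodeConfig() reads the node environment, so this sketch only runs inside a node container):

```python
from flamesdk.resources.node_config import NodeConfig
from flamesdk.resources.client_apis.po_api import POAPI
from flamesdk.resources.utils.logging import FlameLogger

logger = FlameLogger(silent=False)
config = NodeConfig()            # reads the node environment, as in FlameCoreSDK.__init__
po_api = POAPI(config, logger)
logger.add_po_api(po_api)        # from here on, log lines can also reach the PO service

logger.new_log("node joined analysis", log_type='info')
logger.set_runstatus('running')  # FlameCoreSDK does this once all clients are up
```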
diff --git a/flamesdk/resources/client_apis/storage_api.py b/flamesdk/resources/client_apis/storage_api.py index 8a36d9c..981d981 100644 --- a/flamesdk/resources/client_apis/storage_api.py +++ b/flamesdk/resources/client_apis/storage_api.py @@ -2,45 +2,40 @@ from flamesdk.resources.client_apis.clients.result_client import ResultClient, LocalDifferentialPrivacyParams from flamesdk.resources.node_config import NodeConfig - +from flamesdk.resources.utils.logging import FlameLogger class StorageAPI: - def __init__(self, config: NodeConfig): - self.result_client = ResultClient(config.nginx_name, config.keycloak_token) + def __init__(self, config: NodeConfig, flame_logger: FlameLogger) -> None: + self.result_client = ResultClient(config.nginx_name, config.keycloak_token, flame_logger) def submit_final_result(self, result: Any, output_type: Literal['str', 'bytes', 'pickle'] = 'str', - local_dp: Optional[LocalDifferentialPrivacyParams] = None, #TODO:localdp - silent: bool = False) -> dict[str, str]: + local_dp: Optional[LocalDifferentialPrivacyParams] = None) -> dict[str, str]: """ sends the final result to the hub. Making it available for analysts to download. This method is only available for nodes for which the method `get_role(self)` returns "aggregator”. :param result: the final result :param output_type: output type of final results (default: string) :param local_dp: tba - :param silent: if True, the response will not be logged :return: the request status code """ return self.result_client.push_result(result, type="final", output_type=output_type, - local_dp = local_dp, #TODO:localdp - silent=silent) + local_dp=local_dp) def save_intermediate_data(self, data: Any, location: Literal["global", "local"], remote_node_ids: Optional[list[str]] = None, - tag: Optional[str] = None, - silent: bool = False) -> Union[dict[str, dict[str, str]], dict[str, str]]: + tag: Optional[str] = None) -> Union[dict[str, dict[str, str]], dict[str, str]]: """ saves intermediate results/data either on the hub (location="global"), or locally :param data: the result to save :param location: the location to save the result, local saves in the node, global saves in central instance of MinIO :param remote_node_ids: optional remote node ids (used for accessing remote node's public key for encryption) :param tag: optional storage tag - :param silent: if True, the response will not be logged :return: list of the request status codes and url access and ids """ returns = {} @@ -48,11 +43,10 @@ def save_intermediate_data(self, for remote_node_id in remote_node_ids: returns[remote_node_id] = self.result_client.push_result(data, remote_node_id=remote_node_id, - type=location, - silent=silent) + type=location) return returns else: - return self.result_client.push_result(data, tag=tag, type=location, silent=silent) + return self.result_client.push_result(data, tag=tag, type=location) def get_intermediate_data(self, location: Literal["local", "global"], diff --git a/flamesdk/resources/rest_api.py b/flamesdk/resources/rest_api.py index 496396e..48031eb 100644 --- a/flamesdk/resources/rest_api.py +++ b/flamesdk/resources/rest_api.py @@ -10,17 +10,19 @@ from flamesdk.resources.client_apis.clients.message_broker_client import MessageBrokerClient from flamesdk.resources.client_apis.clients.data_api_client import DataApiClient from 
flamesdk.resources.client_apis.clients.result_client import ResultClient +from flamesdk.resources.client_apis.clients.po_client import POClient from flamesdk.resources.utils.utils import extract_remaining_time_from_token -from flamesdk.resources.utils.logging import flame_log +from flamesdk.resources.utils.logging import FlameLogger class FlameAPI: def __init__(self, message_broker: MessageBrokerClient, - data_client: DataApiClient | str, + data_client: Union[DataApiClient, str], result_client: ResultClient, + po_client: POClient, + flame_logger: FlameLogger, keycloak_token: str, - silent: bool, finished_check: Callable, finishing_call: Callable) -> None: app = FastAPI(title=f"FLAME node", @@ -40,6 +42,7 @@ def __init__(self, ) router = APIRouter() + self.flame_logger = flame_logger self.keycloak_token = keycloak_token self.finished = False self.finished_check = finished_check @@ -52,9 +55,13 @@ async def token_refresh(request: Request) -> JSONResponse: body = await request.json() new_token = body.get("token") if not new_token: - flame_log("No token, raising HTTPException", silent) - raise HTTPException(status_code=400, detail="Token is required") + try: + raise HTTPException(status_code=400, detail="Token is required") + except HTTPException as e: + self.flame_logger.raise_error(f"No token, raising HTTPException: {repr(e)}") + # refresh token in po client + po_client.refresh_token(new_token) # refresh token in message-broker message_broker.refresh_token(new_token) if type(data_client) is DataApiClient: @@ -66,13 +73,15 @@ async def token_refresh(request: Request) -> JSONResponse: self.keycloak_token = new_token return JSONResponse(content={"message": "Token refreshed successfully"}) except Exception as e: - flame_log(f"stack trace {e}", silent) - raise HTTPException(status_code=500, detail=str(e)) + try: + raise HTTPException(status_code=500, detail=str(e)) + except HTTPException as e: + self.flame_logger.raise_error(f"stack trace {repr(e)}") @router.get("/healthz", response_class=JSONResponse) def health() -> dict[str, Union[str, int]]: return {"status": self._finished([message_broker, data_client, result_client]), - "token_remaining_time": extract_remaining_time_from_token(self.keycloak_token)} + "token_remaining_time": extract_remaining_time_from_token(self.keycloak_token, self.flame_logger)} async def get_body(request: Request) -> dict[str, Any]: return await request.json() @@ -80,9 +89,9 @@ async def get_body(request: Request) -> dict[str, Any]: @router.post("/webhook", response_class=JSONResponse) def get_message(msg: dict = Depends(get_body)) -> None: if msg['meta']['sender'] != message_broker.nodeConfig.node_id: - flame_log(f"received message webhook: {msg}", silent) + self.flame_logger.new_log(f"received message webhook: {msg}", log_type='info') - message_broker.receive_message(msg, silent) + message_broker.receive_message(msg) # check message category for finished if msg['meta']['category'] == "analysis_finished": @@ -104,6 +113,9 @@ def _finished(self, clients: list[Any]) -> str: return "stuck" elif (not main_alive) and (not self.finished_check()): return "failed" + elif self.flame_logger.runstatus == "failed": + return "failed" + try: if self.finished: return "finished" diff --git a/flamesdk/resources/utils/fhir.py b/flamesdk/resources/utils/fhir.py index 6358918..80d851f 100644 --- a/flamesdk/resources/utils/fhir.py +++ b/flamesdk/resources/utils/fhir.py @@ -2,7 +2,7 @@ from typing import Optional, Any, Literal, Union from flamesdk.resources.client_apis.data_api import 
diff --git a/flamesdk/resources/utils/fhir.py b/flamesdk/resources/utils/fhir.py
index 6358918..80d851f 100644
--- a/flamesdk/resources/utils/fhir.py
+++ b/flamesdk/resources/utils/fhir.py
@@ -2,7 +2,7 @@
 from typing import Optional, Any, Literal, Union

 from flamesdk.resources.client_apis.data_api import DataAPI
-from flamesdk.resources.utils.logging import flame_log
+from flamesdk.resources.utils.logging import FlameLogger

 _KNOWN_RESOURCES = ['Observation', 'QuestionnaireResponse']

@@ -12,29 +12,30 @@ def fhir_to_csv(fhir_data: dict[str, Any],
                 col_key_seq: str,
                 value_key_seq: str,
                 input_resource: str,
+                flame_logger: FlameLogger,
                 row_key_seq: Optional[str] = None,
                 row_id_filters: Optional[list[str]] = None,
                 col_id_filters: Optional[list[str]] = None,
                 row_col_name: str = '',
                 separator: str = ',',
                 output_type: Literal["file", "dict"] = "file",
-                data_client: Optional[DataAPI] = None,
-                silent: bool = True) -> Union[StringIO, dict[Any, dict[Any, Any]]]:
+                data_client: Optional[DataAPI] = None) -> Union[StringIO, dict[Any, dict[Any, Any]]]:
     if input_resource not in _KNOWN_RESOURCES:
-        raise IOError(f"Unknown resource specified (given={input_resource}, known={_KNOWN_RESOURCES})")
+        flame_logger.raise_error(f"Unknown resource specified (given={input_resource}, known={_KNOWN_RESOURCES})")
     if input_resource == 'Observation' and not row_key_seq:
-        raise IOError(f"Resource 'Observation' specified, but no valid row key sequence was given (given={row_key_seq})")
+        flame_logger.raise_error(f"Resource 'Observation' specified, but no valid row key sequence was given "
+                                 f"(given={row_key_seq})")

     df_dict = {}
-    flame_log(f"Converting fhir data resource of type={input_resource} to csv", silent)
+    flame_logger.new_log(f"Converting fhir data resource of type={input_resource} to csv")
     while True:
         # extract from resource
         if input_resource == 'Observation':
             for i, entry in enumerate(fhir_data['entry']):
-                flame_log(f"Parsing fhir data entry no={i + 1} of {len(fhir_data['entry'])}", silent)
-                col_id = _search_fhir_resource(entry, key_sequence=col_key_seq)
-                row_id = _search_fhir_resource(entry, key_sequence=row_key_seq)
-                value = _search_fhir_resource(entry, key_sequence=value_key_seq)
+                flame_logger.new_log(f"Parsing fhir data entry no={i + 1} of {len(fhir_data['entry'])}")
+                col_id = _search_fhir_resource(entry, flame_logger, key_sequence=col_key_seq)
+                row_id = _search_fhir_resource(entry, flame_logger, key_sequence=row_key_seq)
+                value = _search_fhir_resource(entry, flame_logger, key_sequence=value_key_seq)
                 if row_id_filters is not None:
                     if (row_id is None) or (not any([row_id_filter in row_id for row_id_filter in row_id_filters])):
                         continue
@@ -48,10 +49,10 @@ def fhir_to_csv(fhir_data: dict[str, Any],
                     df_dict[col_id][row_id] = value
         elif input_resource == 'QuestionnaireResponse':
             for i, entry in enumerate(fhir_data['entry']):
-                flame_log(f"Parsing fhir data entry no={i + 1} of {len(fhir_data['entry'])}", silent)
+                flame_logger.new_log(f"Parsing fhir data entry no={i + 1} of {len(fhir_data['entry'])}")
                 for item in entry['resource']['item']:
-                    col_id = _search_fhir_resource(item, key_sequence=col_key_seq, current=2)
-                    value = _search_fhir_resource(item, key_sequence=value_key_seq, current=2)
+                    col_id = _search_fhir_resource(item, flame_logger, key_sequence=col_key_seq, current=2)
+                    value = _search_fhir_resource(item, flame_logger, key_sequence=value_key_seq, current=2)
                     if col_id_filters is not None:
                         if (col_id is None) or (not any([col_id_filter in col_id for col_id_filter in col_id_filters])):
                             continue
@@ -59,7 +60,10 @@ def fhir_to_csv(fhir_data: dict[str, Any],
                         df_dict[col_id] = {}
                     df_dict[col_id][str(i)] = value
         else:
-            raise IOError(f"Unknown resource specified (given={input_resource}, known={_KNOWN_RESOURCES})")
+            try:
+                raise IOError(f"Unknown resource specified (given={input_resource}, known={_KNOWN_RESOURCES})")
+            except IOError as e:
+                flame_logger.raise_error(f"Error while parsing fhir data: {repr(e)}")

         # get next data
         if data_client is None:
@@ -70,33 +74,36 @@ def fhir_to_csv(fhir_data: dict[str, Any],
                 link_relation, link_url = str(e['relation']), str(e['url'])
                 if link_relation == 'next':
                     next_query = link_url.split('/fhir/')[-1]
-                    flame_log(f"Parsing next batch query={next_query}", silent)
+                    flame_logger.new_log(f"Parsing next batch query={next_query}")
         if next_query:
             fhir_data = [r for r in data_client.get_fhir_data([next_query]) if r][0][next_query]
         else:
-            flame_log("Fhir data parsing finished", silent)
+            flame_logger.new_log("Fhir data parsing finished")
             break

     # set output format
     if output_type == "file":
-        output = _dict_to_csv(df_dict, row_col_name=row_col_name, separator=separator, silent=silent)
+        output = _dict_to_csv(df_dict, row_col_name=row_col_name, separator=separator, flame_logger=flame_logger)
     else:
         output = df_dict
     return output


-def _dict_to_csv(data: dict[Any, dict[Any, Any]], row_col_name: str, separator: str, silent: bool = True) -> StringIO:
+def _dict_to_csv(data: dict[Any, dict[Any, Any]],
+                 row_col_name: str,
+                 separator: str,
+                 flame_logger: FlameLogger) -> StringIO:
     io = StringIO()

     headers = [f"{row_col_name}"]
     headers.extend(list(data.keys()))
     headers = [f"{header}" for header in headers]
     file_content = separator.join(headers)
-    flame_log("Writing fhir data dict to csv...", silent)
+    flame_logger.new_log("Writing fhir data dict to csv...")
     visited_rows = []
     for i, rows in enumerate(data.values()):
-        flame_log(f"Writing row {i + 1} of {len(data.values())}", silent)
+        flame_logger.new_log(f"Writing row {i + 1} of {len(data.values())}")
         for row_id in rows.keys():
             if row_id in visited_rows:
                 continue
@@ -112,11 +119,12 @@ def _dict_to_csv(data: dict[Any, dict[Any, Any]], row_col_name: str, separator:
     io.write(file_content)
     io.seek(0)

-    flame_log("Fhir data converted to csv", silent)
+    flame_logger.new_log("Fhir data converted to csv")
     return io


 def _search_fhir_resource(fhir_entry: Union[dict[str, Any], list[Any]],
+                          flame_logger: FlameLogger,
                           key_sequence: str,
                           current: int = 0) -> Optional[Any]:
     keys = key_sequence.split('.')
@@ -127,18 +135,21 @@ def _search_fhir_resource(fhir_entry: Union[dict[str, Any], list[Any]],
             try:
                 if field == key:
                     fhir_entry = fhir_entry[key]
-                    next_value = _search_fhir_resource(fhir_entry, key_sequence, current + 1)
+                    next_value = _search_fhir_resource(fhir_entry, flame_logger, key_sequence, current + 1)
                     if next_value is not None:
                         return next_value
             except KeyError:
-                print(f"KeyError: Unable to find field '{key}' in fhir data at level={current + 1} "
-                      f"(keys found: {fhir_entry.keys()})")
+                flame_logger.new_log(f"Unable to find field '{key}' in fhir data at level={current + 1} "
+                                     f"(keys found: {fhir_entry.keys()})",
+                                     log_type='warning')
                 return None
     elif type(fhir_entry) == list:
         for e in fhir_entry:
-            next_value = _search_fhir_resource(e, key_sequence, current)
+            next_value = _search_fhir_resource(e, flame_logger, key_sequence, current)
             if next_value is not None:
                 return next_value
+        else:
+            return None
     else:
         if current == (len(keys) - 1):
             if type(fhir_entry) == dict:
@@ -149,9 +160,14 @@ def _search_fhir_resource(fhir_entry: Union[dict[str, Any], list[Any]],
                 if key:
                     value = fhir_entry[key]
                 else:
-                    print(f"KeyError: Unable to find field '{key}' in fhir data at level={current + 1} "
-                          f"(keys found: {fhir_entry.keys()})")
+                    flame_logger.new_log(f"Unable to find field '{key}' in fhir data at level={current + 1} "
+                                         f"(keys found: {fhir_entry.keys()})",
+                                         log_type='warning')
                     return None
                 return value
+            else:
+                return None
         else:
-            print(f"Unexpected data type found (found type={type(fhir_entry)})")
+            flame_logger.new_log(f"Unexpected data type found (found type={type(fhir_entry)})",
+                                 log_type='warning')
+            return None
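
For reference, a minimal invocation of the new fhir_to_csv signature (hypothetical single-page bundle; assumes an empty 'link' list ends the pagination loop after one pass):

from flamesdk.resources.utils.fhir import fhir_to_csv
from flamesdk.resources.utils.logging import FlameLogger

bundle = {
    "entry": [{"resource": {"subject": {"reference": "Patient/C00001"},
                            "component": [{"valueCodeableConcept": {"coding": [{"code": "ENSG00000005156"}]},
                                           "valueQuantity": {"value": 42}}]}}],
    "link": [],  # no 'next' relation, so parsing finishes after this page
}
table = fhir_to_csv(bundle,
                    col_key_seq="resource.subject.reference",
                    row_key_seq="resource.component.valueCodeableConcept.coding.code",
                    value_key_seq="resource.component.valueQuantity.value",
                    input_resource="Observation",
                    flame_logger=FlameLogger(silent=True),
                    output_type="dict")
# expected (roughly): {"Patient/C00001": {"ENSG00000005156": 42}}
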
+ f"(keys found: fhir_entry.keys())", + log_type='warning') return None return value + else: + return None else: - print(f"Unexpected data type found (found type={type(fhir_entry)})") + flame_logger.new_log(f"Unexpected data type found (found type={type(fhir_entry)})", + log_type='warning') + return None diff --git a/flamesdk/resources/utils/logging.py b/flamesdk/resources/utils/logging.py index 5b2c9ab..f034b90 100644 --- a/flamesdk/resources/utils/logging.py +++ b/flamesdk/resources/utils/logging.py @@ -2,6 +2,7 @@ import time from enum import Enum from typing import Union +import queue class HUB_LOG_LITERALS(Enum): @@ -25,76 +26,160 @@ class HUB_LOG_LITERALS(Enum): 'error': HUB_LOG_LITERALS.error_code.value, 'critical-error': HUB_LOG_LITERALS.critical_error_code.value} +class FlameLogger: -def flame_log(msg: Union[str, bytes], - silent: bool, - sep: str = ' ', - end: str = '\n', - file = None, - flush: bool = False, - log_type: str = 'normal', - suppress_head: bool = False, - suppress_tail: bool = False) -> None: - """ - Print logs to console, if silent is set to False. May raise IOError, if suppress_head=False and log_type receives - an invalid value. - :param msg: - :param silent: - :param sep: - :param end: - :param file: - :param flush: - :param log_type: - :param suppress_head: - :param suppress_tail: - :return: - """ - if log_type not in _LOG_TYPE_LITERALS.keys(): - try: - raise IOError(f"Invalid log type given to logging function " - f"(known log_types={_LOG_TYPE_LITERALS.keys()}, received log_type={log_type}).") - except IOError as e: - flame_log(f"When attempting to use logging function, this error occurred: {e}", - False, - log_type='error') - - if not silent: - if isinstance(msg, bytes): - msg = msg.decode('utf-8', errors='replace') - msg_cleaned = ''.join(filter(lambda x: x in string.printable, msg)) - if suppress_head: # suppressing head (ignore log_type) - head = '' - elif log_type == 'normal': # if log_type=='normal', add nothing to head - head = f"[flame {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] " - else: # else, add uppercase log_type - head = f"[flame -- {log_type.upper()} -- {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] " - if suppress_tail: - tail = '' - else: - tail = f"!suff!{_LOG_TYPE_LITERALS[log_type]}" - print(f"{head}{msg_cleaned}{tail}", sep=sep, end=end, file=file, flush=flush) - - -def declare_log_types(new_log_types: dict[str, str], silent: bool) -> None: - """ - Declare new log_types to be added to log_type literals, and how/as what they should be interpreted by Flame - (the latter have to be known values from HUB_LOG_LITERALS for existing log status fields). - :param new_log_types: - :param silent: - :return: - """ - for k, v in new_log_types.items(): - if v in [e.value for e in HUB_LOG_LITERALS]: - if k not in _LOG_TYPE_LITERALS.keys(): - _LOG_TYPE_LITERALS[k] = v - flame_log(f"Successfully declared new log_type={k} with Hub literal '{v}'.", - silent, - log_type='info') + def __init__(self, silent: bool = False): + """ + Initialize the FlameLog class with a silent mode. + :param silent: If True, logs will not be printed to console. + """ + self.queue = queue.Queue() + self.po_api = None # Placeholder for PO_API instance + self.silent = silent + self.runstatus = 'starting' # Default status for logs + self.log_queue = "" + + def add_po_api(self, po_api) -> None: + """ + Add a POAPI instance to the FlameLogger. + :param po_api: An instance of POAPI. 
+ """ + self.po_api = po_api + + def set_runstatus(self, status: str) -> None: + """ + Set the run status for the logger. + :param status: The status to set (e.g., 'running', 'completed', 'failed'). + """ + if status not in ['starting', 'running', 'finished', 'failed']: + status = 'failed' # Default to 'running' if an invalid status is provided + self.runstatus = status + + def send_logs_from_queue(self) -> None: + """ + Send all logs from the queue to the POAPI. + """ + if self.po_api is None: + try: + raise ValueError("POAPI instance is not set. Use add_po_api() to set it.") + except ValueError as e: + self.raise_error(repr(e)) + if not self.queue.empty(): + print("Sending queued logs to POAPI...") + while not self.queue.empty(): + print(self.queue.qsize(), "logs left in queue.") + print(self.queue.empty()) + log_dict = self.queue.get() + self.po_api.stream_logs(log_dict['msg'], log_dict['log_type'], log_dict['status']) + print(self.queue.empty()) + self.queue.task_done() + + print("All queued logs sent to POAPI.") + + def new_log(self, + msg: Union[str, bytes], + sep: str = ' ', + end: str = '\n', + file = None, + log_type: str = 'normal', + suppress_head: bool = False, + suppress_tail: bool = False) -> None: + """ + Print logs to console, if silent is set to False. May raise IOError, if suppress_head=False and log_type receives + an invalid value. + :param msg: + :param sep: + :param end: + :param file: + :param log_type: + :param suppress_head: + :param suppress_tail: + :return: + """ + if log_type not in _LOG_TYPE_LITERALS.keys(): + try: + raise IOError(f"Invalid log type given to logging function " + f"(known log_types={_LOG_TYPE_LITERALS.keys()}, received log_type={log_type}).") + except IOError as e: + self.raise_error(f"When attempting to use logging function, this error occurred: {repr(e)}") + + log = None + if not self.silent: + if isinstance(msg, bytes): + msg = msg.decode('utf-8', errors='replace') + msg_cleaned = ''.join(filter(lambda x: x in string.printable, msg)) + + if suppress_head: + head = '' else: - flame_log(f"Attempting to declare new log_type failed since log_type={k} " - f"already exists and cannot be overwritten.", silent, log_type='warning') + log_type_fill = "" if log_type == 'normal' else f"-- {log_type.upper()} -- " + head = f"[flame {log_type_fill}{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}] " + tail = "" if suppress_tail else f"!suff!{log_type}" + + log = f"{head}{msg_cleaned}{tail}" + print(log, sep=sep, end=end, file=file) + + if suppress_tail: + self.log_queue = log + else: + if suppress_head: + log = self.log_queue + log + self.log_queue = "" + self._submit_logs(log, _LOG_TYPE_LITERALS[log_type], self.runstatus) + + def waiting_for_health_check(self, seconds: int = 100) -> None: + time.sleep(seconds) + + def raise_error(self, message: str) -> None: + self.set_runstatus("failed") + self.new_log(message, log_type="error") + self.waiting_for_health_check() + + def declare_log_types(self, new_log_types: dict[str, str]) -> None: + """ + Declare new log_types to be added to log_type literals, and how/as what they should be interpreted by Flame + (the latter have to be known values from HUB_LOG_LITERALS for existing log status fields). 
+ :param new_log_types: + :return: + """ + for k, v in new_log_types.items(): + if v in [e.value for e in HUB_LOG_LITERALS]: + if k not in _LOG_TYPE_LITERALS.keys(): + _LOG_TYPE_LITERALS[k] = v + self.new_log(f"Successfully declared new log_type={k} with Hub literal '{v}'.", + log_type='info') + else: + self.new_log(f"Attempting to declare new log_type failed since log_type={k} " + f"already exists and cannot be overwritten.", log_type='warning') + else: + self.raise_error(f"Attempting to declare new log_type failed. Attempted to declare new log_type for " + f"invalid Hub log field = {v} (known field values: " + f"{[e.value for e in HUB_LOG_LITERALS]}).") + + def _submit_logs(self, log: str, log_type: str, status: str): + if self.po_api is None: + log_dict = { + "msg": log, + "log_type": log_type, + "status": status + } + self.queue.put(log_dict) else: - flame_log(f"Attempting to declare new log_type failed. Attempted to declare new log_type for " - f"invalid Hub log field = {v} (known field values: {[e.value for e in HUB_LOG_LITERALS]}).", - False, - log_type='error') + try: + self.send_logs_from_queue() + self.po_api.stream_logs(log, log_type, status) + except Exception as e: + # If sending fails, we can still queue the log + log_dict = { + "msg": log, + "log_type": log_type, + "status": status + } + self.queue.put(log_dict) + # But also create new error log for queue + error_log_dict = { + "msg": f"Failed to send log to POAPI: {repr(e)}", + "log_type": 'warning', + "status": status + } + self.queue.put(error_log_dict) diff --git a/flamesdk/resources/utils/utils.py b/flamesdk/resources/utils/utils.py index 7b115b0..df76efa 100644 --- a/flamesdk/resources/utils/utils.py +++ b/flamesdk/resources/utils/utils.py @@ -1,30 +1,35 @@ -from httpx import AsyncClient, ConnectError +from httpx import AsyncClient, TransportError, HTTPStatusError import asyncio import time import base64 import json -from flamesdk.resources.utils.logging import flame_log +from flamesdk.resources.utils.logging import FlameLogger -def wait_until_nginx_online(nginx_name: str, silent: bool) -> None: - flame_log("\tConnecting to nginx...", silent, end='', suppress_tail=True) +def wait_until_nginx_online(nginx_name: str, flame_logger: FlameLogger) -> None: + flame_logger.new_log("\tConnecting to nginx...", end='', suppress_tail=True) nginx_is_online = False while not nginx_is_online: try: client = AsyncClient(base_url=f"http://{nginx_name}") response = asyncio.run(client.get("/healthz")) - response.raise_for_status() - nginx_is_online = True - except ConnectError: + try: + response.raise_for_status() + nginx_is_online = True + except HTTPStatusError as e: + flame_logger.new_log(f"{repr(e)}", log_type="warning") + #nginx_is_online = True + except TransportError: time.sleep(1) - flame_log("success", silent, suppress_head=True) + flame_logger.new_log("success", suppress_head=True) -def extract_remaining_time_from_token(token: str) -> int: +def extract_remaining_time_from_token(token: str, flame_logger: FlameLogger) -> int: """ Extracts the remaining time until the expiration of the token. 
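
The queue-then-flush behaviour above can be exercised without a running PO service; a sketch with a stand-in object (MagicMock plays the POAPI role, mirroring the unit tests added below):

from unittest.mock import MagicMock
from flamesdk.resources.utils.logging import FlameLogger

logger = FlameLogger(silent=True)
logger.new_log("starting up", log_type='info')  # no PO API attached yet -> entry is queued

po_api = MagicMock()  # stands in for the real POAPI
logger.add_po_api(po_api)
logger.send_logs_from_queue()  # drains the queue through po_api.stream_logs(...)
po_api.stream_logs.assert_called_once()
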
diff --git a/flamesdk/resources/utils/utils.py b/flamesdk/resources/utils/utils.py
index 7b115b0..df76efa 100644
--- a/flamesdk/resources/utils/utils.py
+++ b/flamesdk/resources/utils/utils.py
@@ -1,30 +1,35 @@
-from httpx import AsyncClient, ConnectError
+from httpx import AsyncClient, TransportError, HTTPStatusError
 import asyncio
 import time
 import base64
 import json

-from flamesdk.resources.utils.logging import flame_log
+from flamesdk.resources.utils.logging import FlameLogger


-def wait_until_nginx_online(nginx_name: str, silent: bool) -> None:
-    flame_log("\tConnecting to nginx...", silent, end='', suppress_tail=True)
+def wait_until_nginx_online(nginx_name: str, flame_logger: FlameLogger) -> None:
+    flame_logger.new_log("\tConnecting to nginx...", end='', suppress_tail=True)
     nginx_is_online = False
     while not nginx_is_online:
         try:
             client = AsyncClient(base_url=f"http://{nginx_name}")
             response = asyncio.run(client.get("/healthz"))
-            response.raise_for_status()
-            nginx_is_online = True
-        except ConnectError:
+            try:
+                response.raise_for_status()
+                nginx_is_online = True
+            except HTTPStatusError as e:
+                flame_logger.new_log(f"{repr(e)}", log_type="warning")
+                time.sleep(1)  # avoid hammering nginx while it reports an error status
+        except TransportError:
             time.sleep(1)
-    flame_log("success", silent, suppress_head=True)
+    flame_logger.new_log("success", suppress_head=True)


-def extract_remaining_time_from_token(token: str) -> int:
+def extract_remaining_time_from_token(token: str, flame_logger: FlameLogger) -> int:
     """
     Extracts the remaining time until the expiration of the token.
     :param token:
+    :param flame_logger:
     :return: int in seconds until the expiration of the token
     """
     try:
@@ -36,11 +41,14 @@ def extract_remaining_time_from_token(token: str) -> int:
         payload = json.loads(payload)
         exp_time = payload.get("exp")
         if exp_time is None:
-            raise ValueError("Token does not contain expiration ('exp') claim.")
+            try:
+                raise ValueError("Token does not contain expiration ('exp') claim.")
+            except ValueError as e:
+                flame_logger.raise_error(f"Error extracting expiration time from token: {repr(e)}")
+            return 0  # raise_error() does not raise, so bail out explicitly

         # Calculate the time remaining until the expiration
         current_time = int(time.time())
         remaining_time = exp_time - current_time
         return remaining_time if remaining_time > 0 else 0
     except Exception as e:
-        raise ValueError(f"Invalid token: {str(e)}")
+        flame_logger.raise_error(f"Invalid token: {repr(e)}")
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..9c46705
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+addopts = --ignore=tests/test_images
+
diff --git a/tests/test_images/test_query.py b/tests/test_images/test_query.py
new file mode 100644
index 0000000..321392d
--- /dev/null
+++ b/tests/test_images/test_query.py
@@ -0,0 +1,11 @@
+from flamesdk import FlameCoreSDK
+
+
+def main():
+    flame = FlameCoreSDK()
+    flame.get_fhir_data("fhir/Observation?_count=500")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
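
extract_remaining_time_from_token() only inspects the payload segment, so a hand-rolled token is enough to try it out (sketch; assumes the elided base64 decoding accepts standard urlsafe padding):

import base64
import json
import time

from flamesdk.resources.utils.logging import FlameLogger
from flamesdk.resources.utils.utils import extract_remaining_time_from_token

payload = base64.urlsafe_b64encode(json.dumps({"exp": int(time.time()) + 3600}).encode()).decode()
token = f"header.{payload}.signature"  # header and signature are never decoded

remaining = extract_remaining_time_from_token(token, FlameLogger(silent=True))
assert 0 < remaining <= 3600
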
+ with patch("flamesdk.resources.client_apis.clients.data_api_client.AsyncClient.get", new=AsyncMock(side_effect=dummy_get)): + flame_logger = FlameLogger() + # Create DataApiClient with initial keycloak_token "key_token" + client = DataApiClient("proj_id", "nginx", "data_token", "key_token", flame_logger) + # Check initial Authorization header in hub_client + initial_auth = client.hub_client.headers.get("Authorization") + assert initial_auth == "Bearer key_token" + # Refresh token with a new keycloak token + new_token = "new_key_token" + client.refresh_token(new_token) + # Verify that the hub_client has been updated with the new token + updated_auth = client.hub_client.headers.get("Authorization") + assert updated_auth == f"Bearer {new_token}" + +def test_get_data_fhir(): + fhir_queries = ["query1", "query2"] + with patch("flamesdk.resources.client_apis.clients.data_api_client.AsyncClient.get", new=AsyncMock(side_effect=dummy_get)): + flame_logger = FlameLogger() + client = DataApiClient("proj_id", "nginx", "data_token", "key_token", flame_logger) + # Call get_data with fhir_queries provided + results = client.get_data(fhir_queries=fhir_queries) + # Expect one source with fhir data responses returned for each query. + expected = {"query1": {"result": "fhir-data"}, "query2": {"result": "fhir-data"}} + assert results == [expected] + +def test_get_data_s3(): + s3_keys = ["key1"] + with patch("flamesdk.resources.client_apis.clients.data_api_client.AsyncClient.get", new=AsyncMock(side_effect=dummy_get)): + flame_logger = FlameLogger() + client = DataApiClient("proj_id", "nginx", "data_token", "key_token", flame_logger) + # Call get_data with s3_keys provided and without fhir_queries. + results = client.get_data(s3_keys=s3_keys) + # Expected one source with S3 data for key1. 
+ expected = {"key1": b"s3-data"} + assert results == [expected] + +@pytest.fixture +def dummy_sources(monkeypatch): + sources = [ + {"id": "test_id", "paths": ["http://test_path"]}, + {"id": "other_id", "paths": ["http://other_path"]} + ] + async def dummy_retrieve_available_sources(self): + return sources + monkeypatch.setattr(DataApiClient, "_retrieve_available_sources", dummy_retrieve_available_sources) + return sources + +@pytest.fixture +def client(dummy_sources): + flame_logger = FlameLogger() + return DataApiClient("proj_id", "nginx", "data_token", "key_token",flame_logger) + +def test_get_data_source_client_success(client): + client_obj = client.get_data_source_client("test_id") + assert isinstance(client_obj, AsyncClient) + assert str(client_obj.base_url) == "http://test_path" + +def test_get_data_source_client_not_found(client): + with pytest.raises(ValueError) as exc_info: + client.get_data_source_client("invalid_id") + assert "Data source with id invalid_id not found" in str(exc_info.value) \ No newline at end of file diff --git a/tests/unit_test/test_flame_logger.py b/tests/unit_test/test_flame_logger.py new file mode 100644 index 0000000..5480ad4 --- /dev/null +++ b/tests/unit_test/test_flame_logger.py @@ -0,0 +1,38 @@ +import pytest +from unittest.mock import MagicMock +from flamesdk.resources.utils.logging import FlameLogger, _LOG_TYPE_LITERALS + +@pytest.fixture +def mock_po_client(): + return MagicMock() + +@pytest.fixture +def flame_logger(): + return FlameLogger() + +def test_add_po_client(flame_logger, mock_po_client): + flame_logger.add_po_api(mock_po_client) + assert flame_logger.po_api == mock_po_client + +def test_set_runstatus(flame_logger): + flame_logger.set_runstatus("running") + assert flame_logger.runstatus == "running" + +def test_new_log_without_po_client(flame_logger): + flame_logger.new_log("Test log message", log_type="info") + assert flame_logger.queue.empty() == False + + +def test_send_logs_from_queue(flame_logger, mock_po_client): + + flame_logger.new_log("Test log message", log_type="info") + flame_logger.add_po_client(mock_po_client) + + flame_logger.send_logs_from_queue() + assert flame_logger.queue.empty() == True + +def test_declare_log_types(flame_logger): + new_log_types = {"custom": "info"} + flame_logger.declare_log_types(new_log_types) + assert "custom" in _LOG_TYPE_LITERALS + assert _LOG_TYPE_LITERALS["custom"] == "info" diff --git a/tests/unit_test/test_message.py b/tests/unit_test/test_message.py new file mode 100644 index 0000000..d491198 --- /dev/null +++ b/tests/unit_test/test_message.py @@ -0,0 +1,115 @@ +import pytest +import uuid +import datetime +from flamesdk.resources.client_apis.clients.message_broker_client import Message + +# Dummy stub for NodeConfig used by Message. +class DummyNodeConfig: + def __init__(self): + self.node_id = "node_1" + self.analysis_id = "analysis_dummy" + self.nginx_name = "localhost" + self.keycloak_token = "dummy_token" + + def set_role(self, role: str): + self.role = role + + def set_node_id(self, node_id: str): + self.node_id = node_id + +# Helper function to generate the current time string +def current_time_str(): + return str(datetime.datetime.now()) + +# Test for a valid outgoing message. 
diff --git a/tests/unit_test/test_message.py b/tests/unit_test/test_message.py
new file mode 100644
index 0000000..d491198
--- /dev/null
+++ b/tests/unit_test/test_message.py
@@ -0,0 +1,115 @@
+import pytest
+import datetime
+from flamesdk.resources.client_apis.clients.message_broker_client import Message
+
+# Dummy stub for NodeConfig used by Message.
+class DummyNodeConfig:
+    def __init__(self):
+        self.node_id = "node_1"
+        self.analysis_id = "analysis_dummy"
+        self.nginx_name = "localhost"
+        self.keycloak_token = "dummy_token"
+
+    def set_role(self, role: str):
+        self.role = role
+
+    def set_node_id(self, node_id: str):
+        self.node_id = node_id
+
+# Helper function to generate the current time string
+def current_time_str():
+    return str(datetime.datetime.now())
+
+# Test for a valid outgoing message.
+def test_outgoing_message_valid():
+    node_config = DummyNodeConfig()
+    # Outgoing message must not include "meta"
+    body = {"data": "test outgoing message"}
+    message_number = 1
+    category = "notification"
+    recipients = ["recipient1", "recipient2"]
+    msg = Message(message=body,
+                  config=node_config,
+                  outgoing=True,
+                  message_number=message_number,
+                  category=category,
+                  recipients=recipients)
+    # Check if meta was created and contains expected fields.
+    meta = msg.body.get("meta")
+    assert meta is not None
+    assert meta["type"] == "outgoing"
+    assert meta["category"] == category
+    assert meta["number"] == message_number
+    assert meta["sender"] == node_config.node_id
+    # Recipients should be preserved
+    assert msg.recipients == recipients
+
+# Test for error when outgoing message includes a "meta" field.
+def test_outgoing_message_with_meta_error():
+    node_config = DummyNodeConfig()
+    body = {"meta": {"dummy": "field"}, "data": "test"}
+    with pytest.raises(ValueError, match=r"Cannot use field 'meta' in message body"):
+        Message(message=body,
+                config=node_config,
+                outgoing=True,
+                message_number=1,
+                category="notification",
+                recipients=["recipient1"])
+
+# Test for error when message_number is not an integer.
+def test_outgoing_message_invalid_message_number():
+    node_config = DummyNodeConfig()
+    body = {"data": "test"}
+    with pytest.raises(ValueError, match=r"did not specify integer value for message_number"):
+        Message(message=body,
+                config=node_config,
+                outgoing=True,
+                message_number="not_an_int",
+                category="notification",
+                recipients=["recipient1"])
+
+# Test for error when category is not a string.
+def test_outgoing_message_invalid_category():
+    node_config = DummyNodeConfig()
+    body = {"data": "test"}
+    with pytest.raises(ValueError, match=r"did not specify string value for category"):
+        Message(message=body,
+                config=node_config,
+                outgoing=True,
+                message_number=1,
+                category=123,
+                recipients=["recipient1"])
+
+# Test for error when recipients is not a list of strings.
+def test_outgoing_message_invalid_recipients():
+    node_config = DummyNodeConfig()
+    body = {"data": "test"}
+    with pytest.raises(ValueError, match=r"did not specify list of strings"):
+        Message(message=body,
+                config=node_config,
+                outgoing=True,
+                message_number=1,
+                category="notification",
+                recipients="not_a_list")
+
+# Test for an incoming message where meta exists.
+def test_incoming_message():
+    node_config = DummyNodeConfig()
+    # Simulate an incoming message with pre-existing meta data.
+    meta = {"sender": "node_2", "status": "unread", "type": "incoming", "akn_id": None, "created_at": current_time_str()}
+    body = {"data": "incoming message", "meta": meta.copy()}
+    msg = Message(message=body, config=node_config, outgoing=False)
+    # Incoming messages set recipients to the sender.
+    assert msg.recipients == [meta["sender"]]
+    # The meta type should be updated to 'incoming'
+    assert msg.body["meta"]["type"] == "incoming"
+
+# Test set_read method.
+def test_set_read():
+    node_config = DummyNodeConfig()
+    body = {"data": "test", "meta": {"sender": "node_2", "status": "unread", "type": "incoming", "akn_id": None, "created_at": current_time_str()}}
+    msg = Message(message=body, config=node_config, outgoing=False)
+    # meta is already present in the body; set_read() should flip its status
+    msg.set_read()
+    assert msg.body["meta"]["status"] == "read"
\ No newline at end of file
diff --git a/tests/unit_test/test_message_broker_api.py b/tests/unit_test/test_message_broker_api.py
new file mode 100644
index 0000000..a11cdde
--- /dev/null
+++ b/tests/unit_test/test_message_broker_api.py
@@ -0,0 +1,59 @@
+import pytest
+import asyncio
+from flamesdk.resources.client_apis.message_broker_api import MessageBrokerAPI
+from flamesdk.resources.node_config import NodeConfig
+from flamesdk.resources.utils.logging import FlameLogger
+
+class DummyMessageBrokerClient:
+    def __init__(self, config, flame_logger):
+        self.nodeConfig = config
+        self.message_number = 1
+    async def get_partner_nodes(self, node_id, analysis_id):
+        return ["nodeA", "nodeB"]
+    async def send_message(self, *args, **kwargs):
+        return (["nodeA"], ["nodeB"])
+    async def await_message_acknowledgement(self, *args, **kwargs):
+        # The dummy only acknowledges the ["nodeA"] recipient list.
+        if args[0] == ["nodeA"]:
+            return "nodeA"
+        else:
+            return None
+
+
+@pytest.fixture
+def dummy_config():
+    config = NodeConfig()
+    config.node_id = "dummy_node"
+    config.analysis_id = "dummy_analysis"
+    config.nginx_name = "dummy_nginx"
+    config.keycloak_token = "dummy_token"
+    return config
+
+@pytest.fixture
+def dummy_logger():
+    return FlameLogger()
+
+@pytest.fixture
+def patch_message_broker_client(monkeypatch):
+    monkeypatch.setattr(
+        "flamesdk.resources.client_apis.message_broker_api.MessageBrokerClient",
+        DummyMessageBrokerClient
+    )
+
+
+def test_message_broker_api_init(dummy_config, dummy_logger, patch_message_broker_client):
+    api = MessageBrokerAPI(dummy_config, dummy_logger)
+    assert api.config == dummy_config
+    assert api.participants == ["nodeA", "nodeB"]
+
+
+def test_send_message(dummy_config, dummy_logger, patch_message_broker_client):
+    api = MessageBrokerAPI(dummy_config, dummy_logger)
+    receivers = ["nodeA", "nodeB"]
+    message_category = "test"
+    message = {"data": "hello"}
+    acknowledged, not_acknowledged = asyncio.run(api.send_message(receivers, message_category, message))
+    # Smoke check: the call completes and returns the (acknowledged, not_acknowledged) pair.
+    print(acknowledged, not_acknowledged)
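
The client tests below all build message bodies by hand; for orientation, this is the meta envelope they exercise (shape taken from the tests themselves, not a formal schema):

message_meta = {
    "id": "msg-1",            # unique message id
    "sender": "node_a",       # node id of the sending node
    "status": "unread",       # 'unread' until set_read() flips it to 'read'
    "type": "incoming",       # 'incoming' or 'outgoing'
    "category": "test",       # free-form category, e.g. 'analysis_finished'
    "number": 1,              # per-sender message counter
    "created_at": "...",      # timestamp string set on creation
    "arrived_at": None,       # filled in on receipt
    "akn_id": None,           # acknowledgement id, None until acknowledged
}
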
diff --git a/tests/unit_test/test_message_broker_client.py b/tests/unit_test/test_message_broker_client.py
new file mode 100644
index 0000000..2e67618
--- /dev/null
+++ b/tests/unit_test/test_message_broker_client.py
@@ -0,0 +1,127 @@
+import pytest
+import datetime
+from flamesdk.resources.client_apis.clients.message_broker_client import MessageBrokerClient, Message
+from flamesdk.resources.utils.logging import FlameLogger
+
+# Dummy stub for NodeConfig used by MessageBrokerClient.
+class DummyNodeConfig:
+    def __init__(self):
+        self.node_id = "dummy_node"
+        self.analysis_id = "dummy_analysis"
+        self.nginx_name = "dummy_nginx"
+        self.keycloak_token = "dummy_token"
+
+    def set_role(self, role: str):
+        self.role = role
+
+    def set_node_id(self, node_id: str):
+        self.node_id = node_id
+
+# Set environment variable needed by MessageBrokerClient.
+@pytest.fixture(autouse=True)
+def set_analysis_id(monkeypatch):
+    monkeypatch.setenv("ANALYSIS_ID", "test_analysis")
+
+# Patch get_self_config so that __init__ does not perform real network calls.
+@pytest.fixture
+def dummy_get_self_config():
+    async def fake_get_self_config(self, analysis_id: str):
+        return {"nodeType": "test_role", "nodeId": "test_node"}
+    return fake_get_self_config
+
+# Create a test client with patched network methods.
+@pytest.fixture
+def client(monkeypatch, dummy_get_self_config):
+    monkeypatch.setattr(MessageBrokerClient, "get_self_config", dummy_get_self_config)
+    async def fake_connect(self):
+        pass
+    monkeypatch.setattr(MessageBrokerClient, "_connect", fake_connect)
+    flame_logger = FlameLogger()
+    return MessageBrokerClient(DummyNodeConfig(), flame_logger)
+
+def test_refresh_token(client):
+    new_token = "new_dummy_token"
+    client.refresh_token(new_token)
+    updated_auth = client._message_broker.headers.get("Authorization")
+    assert updated_auth == f"Bearer {new_token}"
+
+def test_delete_message_by_id(client):
+    # Create a dummy outgoing message without 'meta' field.
+    msg_body = {
+        "data": "test message"
+    }
+    dummy_message = Message(message=msg_body, config=client.nodeConfig, outgoing=True,
+                            message_number=1, category="test", recipients=["rec1"])
+    client.list_of_outgoing_messages.append(dummy_message)
+    deleted_count = client.delete_message_by_id(dummy_message.body["meta"]["id"], "outgoing")
+    assert deleted_count == 1
+    # Verify the message is removed from the outgoing list
+    assert all(m.body["meta"]["id"] != dummy_message.body["meta"]["id"] for m in client.list_of_outgoing_messages)
+
+def test_clear_messages(client):
+    # Create dummy incoming messages with different status.
+    current_time = str(datetime.datetime.now())
+    msg_body_read = {
+        "data": "message1",
+        "meta": {
+            "id": "msg-1",
+            "sender": "nodeA",
+            "status": "read",
+            "type": "incoming",
+            "category": "test",
+            "number": 1,
+            "created_at": current_time,
+            "arrived_at": None,
+            "akn_id": "nodeX",
+        }
+    }
+    msg_body_unread = {
+        "data": "message2",
+        "meta": {
+            "id": "msg-2",
+            "sender": "nodeA",
+            "status": "unread",
+            "type": "incoming",
+            "category": "test",
+            "number": 2,
+            "created_at": current_time,
+            "arrived_at": None,
+            "akn_id": "nodeX",
+        }
+    }
+    msg1 = Message(message=msg_body_read, config=client.nodeConfig, outgoing=False)
+    msg2 = Message(message=msg_body_unread, config=client.nodeConfig, outgoing=False)
+    client.list_of_incoming_messages.extend([msg1, msg2])
+    deleted_count = client.clear_messages("incoming", status="read")
+    assert deleted_count == 1
+    # Verify only the message with status "unread" remains.
+    assert client.list_of_incoming_messages[0].body["meta"]["status"] == "unread"
+
+def test_receive_message(client, monkeypatch):
+    # Create an incoming message with missing akn_id.
+    current_time = str(datetime.datetime.now())
+    msg_body = {
+        "data": "incoming test",
+        "meta": {
+            "id": "incoming-1",
+            "sender": "nodeB",
+            "status": "unread",
+            "type": "incoming",
+            "category": "test",
+            "number": 1,
+            "created_at": current_time,
+            "arrived_at": None,
+            "akn_id": None,
+        }
+    }
+    async def dummy_ack(self, message):
+        return
+    monkeypatch.setattr(MessageBrokerClient, "acknowledge_message", dummy_ack)
+    client.receive_message(msg_body)
+    received_msg = client.list_of_incoming_messages[-1]
+    # Verify that recipients have been set to the sender.
+    assert received_msg.recipients == [msg_body["meta"]["sender"]]
\ No newline at end of file
diff --git a/tests/unit_test/test_result_client.py b/tests/unit_test/test_result_client.py
new file mode 100644
index 0000000..a58c494
--- /dev/null
+++ b/tests/unit_test/test_result_client.py
@@ -0,0 +1,49 @@
+import pytest
+from flamesdk.resources.client_apis.clients.result_client import ResultClient
+from flamesdk.resources.utils.logging import FlameLogger
+
+class DummyClient:
+    def __init__(self, *args, **kwargs):
+        self.base_url = kwargs.get('base_url')
+        self.headers = kwargs.get('headers')
+        self.follow_redirects = kwargs.get('follow_redirects')
+        self.last_request = None
+    def post(self, *args, **kwargs):
+        self.last_request = (args, kwargs)
+        class DummyResponse:
+            def json(self):
+                return {'status': 'success'}
+            def raise_for_status(self):
+                pass
+        return DummyResponse()
+    def put(self, *args, **kwargs):
+        self.last_request = (args, kwargs)
+        class DummyResponse:
+            def json(self):
+                return {'status': 'success'}
+            def raise_for_status(self):
+                pass
+        return DummyResponse()
+
+def test_result_client_init(monkeypatch):
+    monkeypatch.setattr('flamesdk.resources.client_apis.clients.result_client.Client', DummyClient)
+    flame_logger = FlameLogger()
+    client = ResultClient('nginx', 'token', flame_logger)
+    assert client.nginx_name == 'nginx'
+    assert isinstance(client.client, DummyClient)
+
+def test_refresh_token(monkeypatch):
+    monkeypatch.setattr('flamesdk.resources.client_apis.clients.result_client.Client', DummyClient)
+    flame_logger = FlameLogger()
+    client = ResultClient('nginx', 'token', flame_logger)
+    client.refresh_token('newtoken')
+    assert client.client.headers['Authorization'] == 'Bearer newtoken'
+
+def test_push_result(monkeypatch):
+    monkeypatch.setattr('flamesdk.resources.client_apis.clients.result_client.Client', DummyClient)
+    flame_logger = FlameLogger()
+    client = ResultClient('nginx', 'token', flame_logger)
+    # Patch client.post to simulate a response
+    client.client.post = lambda *a, **k: type('R', (), {'json': lambda self: {'status': 'success'}})()
+    result = client.push_result(result={'foo': 'bar'})
+    assert result['status'] == 'success'
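
DummyClient also records last_request, which the tests above never assert on; a possible follow-up test (sketch; the exact URL/payload layout depends on ResultClient internals):

def test_push_result_records_request(monkeypatch):
    monkeypatch.setattr('flamesdk.resources.client_apis.clients.result_client.Client', DummyClient)
    client = ResultClient('nginx', 'token', FlameLogger())
    client.push_result(result={'foo': 'bar'})
    # A request went through the stubbed client (post or put).
    assert client.client.last_request is not None
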
"eyJhbGciOiJSUzI1NiIsInR5cCIgOiAiSldUIiwia2lkIiA6ICJBWWdqdWV5T09pRVZjM2pOYVdweWtZWWptQjhOaEVkTTRlbFhzUHN5SDhvIn0.eyJleHAiOjE3NDE2ODYxMzcsImlhdCI6MTc0MTY4NDMzNywianRpIjoiMzUwZWM2ZWUtMWYwMi00NDdhLTgyOWYtOTE5MzMxODNhNGY1IiwiaXNzIjoiaHR0cDovL2ZsYW1lLW5vZGUta2V5Y2xvYWsvcmVhbG1zL2ZsYW1lIiwiYXVkIjoiYWNjb3VudCIsInN1YiI6IjIzYzI5ODQxLThhNzUtNGFjNi1hMjM4LTAyN2QyMDJjN2FjYyIsInR5cCI6IkJlYXJlciIsImF6cCI6ImJiOWFhMTY1LTkwOTYtNGFhZS1iNTE5LTgxN2Y4NTdlNjNiYSIsImFjciI6IjEiLCJyZWFsbV9hY2Nlc3MiOnsicm9sZXMiOlsib2ZmbGluZV9hY2Nlc3MiLCJkZWZhdWx0LXJvbGVzLWZsYW1lIiwidW1hX2F1dGhvcml6YXRpb24iXX0sInJlc291cmNlX2FjY2VzcyI6eyJhY2NvdW50Ijp7InJvbGVzIjpbIm1hbmFnZS1hY2NvdW50IiwibWFuYWdlLWFjY291bnQtbGlua3MiLCJ2aWV3LXByb2ZpbGUiXX19LCJzY29wZSI6ImVtYWlsIHByb2ZpbGUiLCJlbWFpbF92ZXJpZmllZCI6ZmFsc2UsImNsaWVudEhvc3QiOiIxMC4yNDQuMTc0LjI0NiIsInByZWZlcnJlZF91c2VybmFtZSI6InNlcnZpY2UtYWNjb3VudC1iYjlhYTE2NS05MDk2LTRhYWUtYjUxOS04MTdmODU3ZTYzYmEiLCJjbGllbnRBZGRyZXNzIjoiMTAuMjQ0LjE3NC4yNDYiLCJjbGllbnRfaWQiOiJiYjlhYTE2NS05MDk2LTRhYWUtYjUxOS04MTdmODU3ZTYzYmEifQ.o-jq4eMASfwigw83k5XWpaGrl1_omUNP9onkGqa1LWhY_j8Ziv45A4c1IUjcCdSBBXMwylFoNxvA97lHKHsOFH5Bv3EeVDeIUA3YyCPFPyVAH8Woi26E0iGTmUoFyW8Vn6_Xk_jRfK280BHORL6SxjH5nvGQuVIkXHCgaTo2YTRN4ze4i1xpCnNwBcdC7y94y5MrVT9xDGalgB7qfho0lIGzdgXNJjwBwnDXRjrszShsvkW2TCphql0kS7pEMDptWd2WHavIHAQqQritFfe5VylEdhkH2u_FeNksESAZJlYHPxNSz1XWYtDLymnFw_oQbOF_kf0PI_d4-gJ96W8h2g" - remaining_time = extract_remaining_time_from_token(token) + remaining_time = extract_remaining_time_from_token(token, flame_logger) assert (remaining_time == 0) def test_flame_log(): - with open('stream.tar', 'rb') as file: + flame_logger = FlameLogger() + with open('tests/unit_test/stream.tar', 'rb') as file: file_content = file.read() - flame_log(file_content, False) - flame_log("file_content", True, suppress_head=True) + flame_logger.new_log(file_content) + flame_logger.new_log("file_content", suppress_head=True) -def test_multi_param(in_type): - if in_type == 'Observation': - filename = "stream_observation.json" - elif in_type == 'QuestionnaireResponse': - filename = "stream_questionnaire.json" - else: - raise IOError('Unable to recognize input resource type') + +def test_multi_param_observation(): + flame_logger = FlameLogger() + filename = "tests/unit_test/stream_observation.json" with open(filename, "r") as f: content = f.read() fhir_data = ast.literal_eval(content) + output = fhir_to_csv(fhir_data, + col_key_seq="resource.subject.reference", + row_key_seq="resource.component.valueCodeableConcept.coding.code", + value_key_seq="resource.component.valueQuantity.value", + row_id_filters=["ENSG"], + input_resource="Observation", + flame_logger=flame_logger) + assert output is not None - if in_type == 'Observation': - return fhir_to_csv(fhir_data, - col_key_seq="resource.subject.reference", - row_key_seq="resource.component.valueCodeableConcept.coding.code", - value_key_seq="resource.component.valueQuantity.value", - row_id_filters=["ENSG"], - input_resource=in_type) - elif in_type == 'QuestionnaireResponse': - return fhir_to_csv(fhir_data, - col_key_seq="resource.item.linkId", - value_key_seq="resource.item.answer.value", - input_resource=in_type) - -start_time = time.time() -#print(_search_in_fhir_entry({'fullUrl': 'http://nginx-analysis-671b3985-f901-48c5-9cba-ee62bc6f393d-1/fhir/Observation/gene-observation-C00039-ENSG00000005156', 'resource': {'category': [{'coding': [{'code': 'laboratory', 'display': 'Laboratory', 'system': 'http://terminology.hl7.org/CodeSystem/observation-category'}]}], 'code': {'coding': [{'code': '69548-6', 'display': 
'Genetic variant assessment', 'system': 'http://loinc.org'}]}, 'component': [{'code': {'coding': [{'code': '48018-6', 'display': 'Gene studied ' '[ID]', 'system': 'http://loinc.org'}]}, 'valueCodeableConcept': {'coding': [{'code': 'ENSG00000005156', 'system': 'http://ensembl.org'}]}}, {'code': {'coding': [{'code': '48003-8', 'display': 'DNA sequence ' 'variation ' 'identifier ' '[Identifier]', 'system': 'http://loinc.org'}]}, 'valueQuantity': {'code': 'count', 'system': 'http://unitsofmeasure.org', 'unit': 'count', 'value': 2452}}], 'id': 'gene-observation-C00039-ENSG00000005156', 'meta': {'lastUpdated': '2025-06-12T11:24:08.247Z', 'versionId': '3'}, 'resourceType': 'Observation', 'status': 'final', 'subject': {'reference': 'Patient/C00039'}}, 'search': {'mode': 'match'}}, 'resource.component.valueQuantity.value' )) -print("output: " + test_multi_param('Observation').read()) -print(f"Elapsed time: {time.time() - start_time} secs") -print(f"Estimated time: {(time.time() - start_time) * ((41886 * 118) / 500) / 60} minutes\n\n") -start_time = time.time() -print("output: " + test_multi_param('QuestionnaireResponse').read()) -print(f"Elapsed time: {time.time() - start_time} secs") +def test_multi_param_questionnaire(): + flame_logger = FlameLogger() + filename = "tests/unit_test/stream_ques.json" + with open(filename, "r") as f: + content = f.read() + fhir_data = ast.literal_eval(content) + output = fhir_to_csv(fhir_data, + col_key_seq="resource.item.linkId", + value_key_seq="resource.item.answer.value", + input_resource="QuestionnaireResponse", + flame_logger=flame_logger) + assert output is not None
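
The tests above only assert that fhir_to_csv returns something; in the default output_type="file" mode the result is a StringIO that has already been rewound, so the CSV text can be read directly (sketch, reusing a bundle loaded as in the tests):

from flamesdk.resources.utils.fhir import fhir_to_csv
from flamesdk.resources.utils.logging import FlameLogger

output = fhir_to_csv(fhir_data,
                     col_key_seq="resource.item.linkId",
                     value_key_seq="resource.item.answer.value",
                     input_resource="QuestionnaireResponse",
                     flame_logger=FlameLogger(silent=True))  # output_type defaults to "file"
print(output.read())  # _dict_to_csv() calls seek(0) before returning the buffer
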