3 changes: 1 addition & 2 deletions examples/run_star_model.py
@@ -50,13 +50,12 @@ def aggregation_method(self, analysis_results):
total_patient_count = sum(analysis_results)
return total_patient_count

def has_converged(self, result, last_result, num_iterations):
def has_converged(self, result, last_result):
"""
Determines if the aggregation process has converged.

:param result: The current aggregated result.
:param last_result: The aggregated result from the previous iteration.
:param num_iterations: The number of iterations completed so far.
:return: True if the aggregation has converged; False to continue iterations.
"""
# TODO (optional): if the parameter 'simple_analysis' in 'StarModel' is set to False,
89 changes: 89 additions & 0 deletions examples/run_star_model_dp.py
@@ -0,0 +1,89 @@
from flame.star import StarLocalDPModel, StarAnalyzer, StarAggregator


class MyAnalyzer(StarAnalyzer):
def __init__(self, flame):
"""
Initializes the custom Analyzer node.

:param flame: Instance of FlameCoreSDK to interact with the FLAME components.
"""
super().__init__(flame) # Connects this analyzer to the FLAME components

def analysis_method(self, data, aggregator_results):
"""
Performs analysis on the retrieved data from data sources.

:param data: A list of dictionaries containing the data from each data source.
- Each dictionary corresponds to a data source.
- Keys are the queries executed, and values are the results (dict for FHIR, str for S3).
:param aggregator_results: Results from the aggregator in previous iterations.
- None in the first iteration.
- Contains the result from the aggregator's aggregation_method in subsequent iterations.
:return: Any result of your analysis on one node (e.g. patient count).
"""
# TODO: Implement your analysis method
# in this example we retrieve the first FHIR dataset, extract the patient count,
# and return the total number of patients
patient_count = float(data[0]['Patient?_summary=count']['total'])
return patient_count


class MyAggregator(StarAggregator):
def __init__(self, flame):
"""
Initializes the custom Aggregator node.

:param flame: Instance of FlameCoreSDK to interact with the FLAME components.
"""
super().__init__(flame) # Connects this aggregator to the FLAME components

def aggregation_method(self, analysis_results):
"""
Aggregates the results received from all analyzer nodes.

:param analysis_results: A list of analysis results from each analyzer node.
:return: The aggregated result (e.g., total patient count across all analyzers).
"""
# TODO: Implement your aggregation method
# in this example we sum up the patient counts across all nodes
total_patient_count = sum(analysis_results)
return total_patient_count

def has_converged(self, result, last_result):
"""
Determines if the aggregation process has converged.

:param result: The current aggregated result.
:param last_result: The aggregated result from the previous iteration.
:return: True if the aggregation has converged; False to continue iterations.
"""
# TODO (optional): if the parameter 'simple_analysis' in 'StarModel' is set to False,
# this function defines the exit criteria in a multi-iterative analysis (otherwise ignored)
return True # Return True to indicate convergence in this simple analysis


def main():
"""
Sets up and initiates the distributed analysis using the FLAME components.

- Defines the custom analyzer and aggregator classes.
- Specifies the type of data and queries to execute.
- Configures analysis parameters like iteration behavior and output format.
"""
StarLocalDPModel(
analyzer=MyAnalyzer, # Custom analyzer class (must inherit from StarAnalyzer)
aggregator=MyAggregator, # Custom aggregator class (must inherit from StarAggregator)
data_type='fhir', # Type of data source ('fhir' or 's3')
query='Patient?_summary=count', # Query or list of queries to retrieve data
simple_analysis=True, # True for single-iteration; False for multi-iterative analysis
output_type='str', # Output format for the final result ('str', 'bytes', or 'pickle')
epsilon=1.0, # Privacy budget for differential privacy
sensitivity=1.0, # Sensitivity parameter for differential privacy
analyzer_kwargs=None, # Additional keyword arguments for the custom analyzer constructor (i.e. MyAnalyzer)
aggregator_kwargs=None # Additional keyword arguments for the custom aggregator constructor (i.e. MyAggregator)
)


if __name__ == "__main__":
main()
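For intuition about the 'epsilon' and 'sensitivity' arguments above: they are the standard parameters of a Laplace mechanism. The sketch below is illustrative only (the actual noising happens inside the FLAME SDK when the local DP config reaches submit_final_result, and its exact mechanism may differ), but it shows what a noised patient count conventionally looks like.

# Illustrative sketch, not part of this PR: a conventional Laplace mechanism
# parameterized by epsilon (privacy budget) and sensitivity (max per-record influence).
import numpy as np

def add_local_dp_noise(value: float, epsilon: float, sensitivity: float) -> float:
    """Return value plus Laplace noise with scale = sensitivity / epsilon."""
    return float(value + np.random.laplace(loc=0.0, scale=sensitivity / epsilon))

# With the example's budget (epsilon=1.0, sensitivity=1.0), a true count of 42
# is released as 42 plus Laplace(0, 1) noise, e.g. 41.3 or 43.8.
noisy_count = add_local_dp_noise(42.0, epsilon=1.0, sensitivity=1.0)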
1 change: 1 addition & 0 deletions flame/star/__init__.py
@@ -1,4 +1,5 @@
from flame.star.star_model import StarModel
from flame.star.star_localdp.star_localdp_model import StarLocalDPModel
from flame.star.analyzer_client import Analyzer as StarAnalyzer
from flame.star.aggregator_client import Aggregator as StarAggregator
from flame.star.star_model_tester import StarModelTester
10 changes: 4 additions & 6 deletions flame/star/aggregator_client.py
@@ -14,21 +14,19 @@ def __init__(self, flame: Union[FlameCoreSDK, MockFlameCoreSDK]) -> None:
raise ValueError(f'Attempted to initialize aggregator node with mismatching configuration '
f'(expected: node_role="aggregator", received="{self.role}").')

def aggregate(self, node_results: list[Any], simple_analysis: bool = True) -> tuple[Any, bool]:
def aggregate(self, node_results: list[Any], simple_analysis: bool = True) -> tuple[Any, bool, bool]:
result = self.aggregation_method(node_results)

delta_criteria = self.has_converged(result, self.latest_result)
if not simple_analysis:
if self.num_iterations != 0:
converged = self.has_converged(result, self.latest_result)
else:
converged = False
converged = delta_criteria if self.num_iterations != 0 else False
else:
converged = True

self.latest_result = result
self.num_iterations += 1

return self.latest_result, converged
return self.latest_result, converged, delta_criteria

@abstractmethod
def aggregation_method(self, analysis_results: list[Any]) -> Any:
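Note on the new return shape: aggregate() now yields a three-tuple of (latest_result, converged, delta_criteria), where delta_criteria is the raw output of has_converged() and is reported even when simple_analysis short-circuits converged to True. A minimal, self-contained sketch of that gating (function name assumed, logic mirrored from the hunk above):

def resolve_convergence(delta_criteria: bool, simple_analysis: bool, num_iterations: int) -> bool:
    # Simple analyses converge immediately; multi-iterative analyses converge only
    # once the delta criterion holds after at least one completed iteration.
    if simple_analysis:
        return True
    return delta_criteria if num_iterations != 0 else False

assert resolve_convergence(True, simple_analysis=True, num_iterations=0) is True
assert resolve_convergence(True, simple_analysis=False, num_iterations=0) is False
assert resolve_convergence(True, simple_analysis=False, num_iterations=3) is True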
97 changes: 57 additions & 40 deletions flame/star/star_localdp/star_localdp_model.py
@@ -1,13 +1,17 @@
from typing import Optional, Type, Literal, Union
from typing import Optional, Type, Literal, Union, Any

from flamesdk import FlameCoreSDK
from flame.star.aggregator_client import Aggregator
from flame.star.analyzer_client import Analyzer
from flame.star.star_model import StarModel, _ERROR_MESSAGES
from flame.utils.mock_flame_core import MockFlameCoreSDK


class StarLocalDPModel(StarModel):
flame: FlameCoreSDK
flame: Union[FlameCoreSDK, MockFlameCoreSDK]

data: Optional[list[dict[str, Any]]] = None
test_mode: bool = False

epsilon: Optional[float]
sensitivity: Optional[float]
@@ -22,60 +26,73 @@ def __init__(self,
analyzer_kwargs: Optional[dict] = None,
aggregator_kwargs: Optional[dict] = None,
epsilon: Optional[float] = None,
sensitivity: Optional[float] = None) -> None:
sensitivity: Optional[float] = None,
test_mode: bool = False,
test_kwargs: Optional[dict] = None) -> None:
self.epsilon = epsilon
self.sensitivity = sensitivity
super().__init__(analyzer=analyzer,
aggregator=aggregator,
data_type=data_type,
query=query,
simple_analysis=simple_analysis,
output_type=output_type,
analyzer_kwargs=analyzer_kwargs,
aggregator_kwargs=aggregator_kwargs)
self.epsilon = epsilon
self.sensitivity = sensitivity
aggregator_kwargs=aggregator_kwargs,
test_mode=test_mode,
test_kwargs=test_kwargs)

def _start_aggregator(self,
aggregator: Type[Aggregator],
simple_analysis: bool = True,
output_type: Literal['str', 'bytes', 'pickle'] = 'str',
aggregator_kwargs: Optional[dict] = None) -> None:
if self._is_aggregator():
if issubclass(aggregator, Aggregator):
# init custom aggregator subclass
if aggregator_kwargs is None:
aggregator = aggregator(flame=self.flame)
else:
aggregator = aggregator(flame=self.flame, **aggregator_kwargs)
aggregator_kwargs: Optional[dict] = None,
test_node_kwargs: Optional[dict[str, Any]] = None) -> None:
if issubclass(aggregator, Aggregator):
# init custom aggregator subclass
if aggregator_kwargs is None:
aggregator = aggregator(flame=self.flame)
else:
aggregator = aggregator(flame=self.flame, **aggregator_kwargs)

# Ready Check
self._wait_until_partners_ready()
if test_node_kwargs is not None:
aggregator.set_num_iterations(test_node_kwargs['num_iterations'])
aggregator.set_latest_result(test_node_kwargs['latest_result'])

# Get analyzer ids
analyzers = aggregator.partner_node_ids
# Ready Check
self._wait_until_partners_ready()

while not self._converged(): # (**)
# Await intermediate results
result_dict = self.flame.await_intermediate_data(analyzers)
# Get analyzer ids
analyzers = aggregator.partner_node_ids

# Aggregate results
agg_res, converged = aggregator.aggregate(list(result_dict.values()), simple_analysis)
self.flame.flame_log(f"Aggregated results: {str(agg_res)[:100]}")
while not aggregator.finished: # (**)
# Await intermediate results
result_dict = self.flame.await_intermediate_data(analyzers)

if converged:
self.flame.flame_log("Submitting final results using differential privacy...", end='')
if self.epsilon and self.sensitivity:
localdp = {"epsilon": self.epsilon, "sensitivity": self.sensitivity}
else:
localdp = None
response = self.flame.submit_final_result(agg_res, output_type, localdp=localdp)
self.flame.flame_log(f"success (response={response})")
self.flame.analysis_finished() # LOOP BREAK
else:
# Send aggregated result to analyzers
self.flame.send_intermediate_data(analyzers, agg_res)
# Aggregate results
agg_res, converged, delta_crit = aggregator.aggregate(list(result_dict.values()), simple_analysis)
self.flame.flame_log(f"Aggregated results: {str(agg_res)[:100]}")

aggregator.node_finished()
else:
raise BrokenPipeError(_ERROR_MESSAGES.IS_INCORRECT_CLASS.value)
if converged:
if not self.test_mode:
self.flame.flame_log("Submitting final results using differential privacy...",
log_type='info',
end='')
if delta_crit and (self.epsilon is not None) and (self.sensitivity is not None):
local_dp = {"epsilon": self.epsilon, "sensitivity": self.sensitivity}
else:
local_dp = None
if self.test_mode and (local_dp is not None):
self.flame.flame_log(f"\tTest mode: Would apply local DP with epsilon={local_dp['epsilon']} "
f"and sensitivity={local_dp['sensitivity']}",
log_type='info')
response = self.flame.submit_final_result(agg_res, output_type, local_dp=local_dp)
if not self.test_mode:
self.flame.flame_log(f"success (response={response})", log_type='info')
self.flame.analysis_finished()
aggregator.node_finished() # LOOP BREAK
else:
# Send aggregated result to analyzers
self.flame.send_intermediate_data(analyzers, agg_res)
else:
raise BrokenPipeError(_ERROR_MESSAGES.IS_ANALYZER.value)
raise BrokenPipeError(_ERROR_MESSAGES.IS_INCORRECT_CLASS.value)
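In the loop above, local DP is only attached to the final submission when the delta criterion held and both privacy parameters are set; otherwise the aggregate is submitted without noise. A pure-function sketch of that decision (function name assumed, dict shape taken from this hunk):

from typing import Optional

def build_local_dp_config(delta_crit: bool, epsilon: Optional[float], sensitivity: Optional[float]) -> Optional[dict]:
    # Noise is requested only when the convergence criterion held and a full
    # privacy budget (epsilon and sensitivity) was configured on the model.
    if delta_crit and (epsilon is not None) and (sensitivity is not None):
        return {"epsilon": epsilon, "sensitivity": sensitivity}
    return None

assert build_local_dp_config(True, 1.0, 1.0) == {"epsilon": 1.0, "sensitivity": 1.0}
assert build_local_dp_config(False, 1.0, 1.0) is None  # criterion failed: submit without noise
assert build_local_dp_config(True, None, 1.0) is None  # incomplete budget: submit without noise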
14 changes: 8 additions & 6 deletions flame/star/star_model.py
@@ -16,6 +16,7 @@ class _ERROR_MESSAGES(Enum):

class StarModel:
flame: Union[FlameCoreSDK, MockFlameCoreSDK]

data: Optional[list[dict[str, Any]]] = None
test_mode: bool = False

@@ -31,12 +32,13 @@ def __init__(self,
test_mode: bool = False,
test_kwargs: Optional[dict] = None) -> None:
self.test_mode = test_mode
if not self.test_mode:
self.flame = FlameCoreSDK()
else:
if self.test_mode:
self.flame = MockFlameCoreSDK(test_kwargs=test_kwargs)
test_node_kwargs = {'num_iterations': test_kwargs['num_iterations'],
'latest_result': test_kwargs['latest_result']} if self.test_mode else None
test_node_kwargs = {'num_iterations': test_kwargs['num_iterations'],
'latest_result': test_kwargs['latest_result']}
else:
self.flame = FlameCoreSDK()
test_node_kwargs = None

if self._is_analyzer():
self.flame.flame_log(f"Analyzer {test_kwargs['node_id'] + ' ' if self.test_mode else ''}started",
@@ -95,7 +97,7 @@ def _start_aggregator(self,
result_dict = self.flame.await_intermediate_data(analyzers)

# Aggregate results
agg_res, converged = aggregator.aggregate(list(result_dict.values()), simple_analysis)
agg_res, converged, _ = aggregator.aggregate(list(result_dict.values()), simple_analysis)

if converged:
if not self.test_mode:
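For test runs, StarModel now builds a MockFlameCoreSDK from test_kwargs and forwards a test_node_kwargs subset to the node clients. Judging only from the keys read in these hunks, a test configuration looks roughly like the sketch below (MockFlameCoreSDK may expect additional keys not visible in this diff); it would be passed to the model constructor as test_mode=True, test_kwargs=test_kwargs.

test_kwargs = {
    'node_id': 'aggregator-0',   # used in the start-up log message
    'num_iterations': 0,         # seeds aggregator.set_num_iterations(...)
    'latest_result': None,       # seeds aggregator.set_latest_result(...)
    # ... any further keys consumed by MockFlameCoreSDK are not shown in this diff
}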