From 8f9ab3a0483258ab521b9972bed11cf49057c714 Mon Sep 17 00:00:00 2001 From: Nightknight3000 Date: Thu, 22 Jan 2026 12:30:36 +0100 Subject: [PATCH 1/6] feat: update local dp pattern Co-authored-by: antidodo --- flame/star/star_localdp/star_localdp_model.py | 87 ++++---- flame/star/star_model.py | 12 +- flame/utils/mock_flame_core.py | 17 +- poetry.lock | 197 ++++++++++++++++-- pyproject.toml | 1 + 5 files changed, 258 insertions(+), 56 deletions(-) diff --git a/flame/star/star_localdp/star_localdp_model.py b/flame/star/star_localdp/star_localdp_model.py index 0e05d50..4863024 100644 --- a/flame/star/star_localdp/star_localdp_model.py +++ b/flame/star/star_localdp/star_localdp_model.py @@ -1,13 +1,17 @@ -from typing import Optional, Type, Literal, Union +from typing import Optional, Type, Literal, Union, Any from flamesdk import FlameCoreSDK from flame.star.aggregator_client import Aggregator from flame.star.analyzer_client import Analyzer from flame.star.star_model import StarModel, _ERROR_MESSAGES +from flame.utils.mock_flame_core import MockFlameCoreSDK class StarLocalDPModel(StarModel): - flame: FlameCoreSDK + flame: Union[FlameCoreSDK, MockFlameCoreSDK] + + data: Optional[list[dict[str, Any]]] = None + test_mode: bool = False epsilon: Optional[float] sensitivity: Optional[float] @@ -21,6 +25,8 @@ def __init__(self, output_type: Literal['str', 'bytes', 'pickle'] = 'str', analyzer_kwargs: Optional[dict] = None, aggregator_kwargs: Optional[dict] = None, + test_mode: bool = False, + test_kwargs: Optional[dict] = None, epsilon: Optional[float] = None, sensitivity: Optional[float] = None) -> None: super().__init__(analyzer=analyzer, @@ -30,7 +36,9 @@ def __init__(self, simple_analysis=simple_analysis, output_type=output_type, analyzer_kwargs=analyzer_kwargs, - aggregator_kwargs=aggregator_kwargs) + aggregator_kwargs=aggregator_kwargs, + test_mode=test_mode, + test_kwargs=test_kwargs) self.epsilon = epsilon self.sensitivity = sensitivity @@ -38,44 +46,49 @@ def _start_aggregator(self, aggregator: Type[Aggregator], simple_analysis: bool = True, output_type: Literal['str', 'bytes', 'pickle'] = 'str', - aggregator_kwargs: Optional[dict] = None) -> None: - if self._is_aggregator(): - if issubclass(aggregator, Aggregator): - # init custom aggregator subclass - if aggregator_kwargs is None: - aggregator = aggregator(flame=self.flame) - else: - aggregator = aggregator(flame=self.flame, **aggregator_kwargs) + aggregator_kwargs: Optional[dict] = None, + test_node_kwargs: Optional[dict[str, Any]] = None) -> None: + if issubclass(aggregator, Aggregator): + # init custom aggregator subclass + if aggregator_kwargs is None: + aggregator = aggregator(flame=self.flame) + else: + aggregator = aggregator(flame=self.flame, **aggregator_kwargs) - # Ready Check - self._wait_until_partners_ready() + if test_node_kwargs is not None: + aggregator.set_num_iterations(test_node_kwargs['num_iterations']) + aggregator.set_latest_result(test_node_kwargs['latest_result']) - # Get analyzer ids - analyzers = aggregator.partner_node_ids + # Ready Check + self._wait_until_partners_ready() - while not self._converged(): # (**) - # Await intermediate results - result_dict = self.flame.await_intermediate_data(analyzers) + # Get analyzer ids + analyzers = aggregator.partner_node_ids - # Aggregate results - agg_res, converged = aggregator.aggregate(list(result_dict.values()), simple_analysis) - self.flame.flame_log(f"Aggregated results: {str(agg_res)[:100]}") + while not aggregator.finished: # (**) + # Await intermediate results + result_dict = self.flame.await_intermediate_data(analyzers) - if converged: - self.flame.flame_log("Submitting final results using differential privacy...", end='') - if self.epsilon and self.sensitivity: - localdp = {"epsilon": self.epsilon, "sensitivity": self.sensitivity} - else: - localdp = None - response = self.flame.submit_final_result(agg_res, output_type, localdp=localdp) - self.flame.flame_log(f"success (response={response})") - self.flame.analysis_finished() # LOOP BREAK - else: - # Send aggregated result to analyzers - self.flame.send_intermediate_data(analyzers, agg_res) + # Aggregate results + agg_res, converged = aggregator.aggregate(list(result_dict.values()), simple_analysis) + self.flame.flame_log(f"Aggregated results: {str(agg_res)[:100]}") - aggregator.node_finished() - else: - raise BrokenPipeError(_ERROR_MESSAGES.IS_INCORRECT_CLASS.value) + if converged: + if not self.test_mode: + self.flame.flame_log("Submitting final results using differential privacy...", + log_type='info', + end='') + if (self.epsilon is not None) and (self.sensitivity is not None): + local_dp = {"epsilon": self.epsilon, "sensitivity": self.sensitivity} + else: + local_dp = None + response = self.flame.submit_final_result(agg_res, output_type, local_dp=local_dp) + if not self.test_mode: + self.flame.flame_log(f"success (response={response})", log_type='info') + self.flame.analysis_finished() + aggregator.node_finished() # LOOP BREAK + else: + # Send aggregated result to analyzers + self.flame.send_intermediate_data(analyzers, agg_res) else: - raise BrokenPipeError(_ERROR_MESSAGES.IS_ANALYZER.value) \ No newline at end of file + raise BrokenPipeError(_ERROR_MESSAGES.IS_INCORRECT_CLASS.value) diff --git a/flame/star/star_model.py b/flame/star/star_model.py index 1660f66..37ea5cf 100644 --- a/flame/star/star_model.py +++ b/flame/star/star_model.py @@ -16,6 +16,7 @@ class _ERROR_MESSAGES(Enum): class StarModel: flame: Union[FlameCoreSDK, MockFlameCoreSDK] + data: Optional[list[dict[str, Any]]] = None test_mode: bool = False @@ -31,12 +32,13 @@ def __init__(self, test_mode: bool = False, test_kwargs: Optional[dict] = None) -> None: self.test_mode = test_mode - if not self.test_mode: - self.flame = FlameCoreSDK() - else: + if self.test_mode: self.flame = MockFlameCoreSDK(test_kwargs=test_kwargs) - test_node_kwargs = {'num_iterations': test_kwargs['num_iterations'], - 'latest_result': test_kwargs['latest_result']} if self.test_mode else None + test_node_kwargs = {'num_iterations': test_kwargs['num_iterations'], + 'latest_result': test_kwargs['latest_result']} + else: + self.flame = FlameCoreSDK() + test_node_kwargs = None if self._is_analyzer(): self.flame.flame_log(f"Analyzer {test_kwargs['node_id'] + ' ' if self.test_mode else ''}started", diff --git a/flame/utils/mock_flame_core.py b/flame/utils/mock_flame_core.py index 1fb62ef..8f4c865 100644 --- a/flame/utils/mock_flame_core.py +++ b/flame/utils/mock_flame_core.py @@ -3,6 +3,10 @@ from io import StringIO from typing import Any, Literal, Optional, Union +from opendp.domains import atom_domain +from opendp.measurements import make_laplace +from opendp.metrics import absolute_distance + _REQUIRED_KWARGS = ['node_id', 'participant_ids', 'role'] @@ -175,8 +179,19 @@ def send_message_and_wait_for_responses(self, ########################################Storage Client########################################### def submit_final_result(self, - result: Any, output_type: Literal['str', 'bytes', 'pickle'] = 'str', + result: Any, + output_type: Literal['str', 'bytes', 'pickle'] = 'str', local_dp: Optional[dict] = None) -> dict[str, str]: + if local_dp is not None: + if type(result) in [int, float]: + scale = local_dp['sensitivity'] / local_dp['epsilon'] # Laplace scale parameter + laplace_mech = make_laplace(input_domain=atom_domain(T=float), + input_metric=absolute_distance(T=float), + scale=scale) + result = laplace_mech(float(result)) + else: + self.flame_log("Given result type is not supported for local DP -> DP step will be skipped.", + log_type='warning') self.final_results_storage = result def save_intermediate_data(self, diff --git a/poetry.lock b/poetry.lock index a1c4b5b..3d1ab73 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -51,7 +51,7 @@ description = "Validate configuration and produce human readable error messages. optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version == \"3.9\"" +markers = "python_version < \"3.10\"" files = [ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, @@ -77,7 +77,7 @@ description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "python_version == \"3.9\"" +markers = "python_version < \"3.10\"" files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -115,6 +115,24 @@ files = [ ] markers = {main = "platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} +[[package]] +name = "deprecated" +version = "1.3.1" +description = "Python @deprecated decorator to deprecate old python classes, functions or methods." +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +groups = ["dev"] +files = [ + {file = "deprecated-1.3.1-py2.py3-none-any.whl", hash = "sha256:597bfef186b6f60181535a29fbe44865ce137a5079f295b479886c82729d5f3f"}, + {file = "deprecated-1.3.1.tar.gz", hash = "sha256:b1b50e0ff0c1fddaa5708a2c6b0a6588bb09b892825ab2b214ac9ea9d92a5223"}, +] + +[package.dependencies] +wrapt = ">=1.10,<3" + +[package.extras] +dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "setuptools ; python_version >= \"3.12\"", "tox"] + [[package]] name = "distlib" version = "0.4.0" @@ -173,7 +191,7 @@ description = "A platform independent file lock." optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version == \"3.9\"" +markers = "python_version < \"3.10\"" files = [ {file = "filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d"}, {file = "filelock-3.19.1.tar.gz", hash = "sha256:66eda1888b0171c998b35be2bcc0f6d75c388a7ce20c3f3f37aa8e96c2dddf58"}, @@ -280,7 +298,7 @@ description = "File identification library for Python" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version == \"3.9\"" +markers = "python_version < \"3.10\"" files = [ {file = "identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757"}, {file = "identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf"}, @@ -327,7 +345,7 @@ description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version == \"3.9\"" +markers = "python_version < \"3.10\"" files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -358,16 +376,42 @@ files = [ {file = "nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb"}, ] +[[package]] +name = "opendp" +version = "0.12.1" +description = "Python bindings for the OpenDP Library" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "opendp-0.12.1-cp39-abi3-macosx_10_13_x86_64.whl", hash = "sha256:72edcd516e606a983ceaf828663655e46ed7d2a712e6335845413672ce10b89a"}, + {file = "opendp-0.12.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:6315380316fada9fd051ac0d0e46d323da1c9a509f0a808908a9c980be4f448a"}, + {file = "opendp-0.12.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c95c3ab7b8e61f94a76f08aafb9dc099d761845ad85da4f217c153fb612ab545"}, + {file = "opendp-0.12.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b66cc7bcacbacd5a911a536b0fb463b2084964fed38b1466ca266e54d717ba67"}, + {file = "opendp-0.12.1-cp39-abi3-win32.whl", hash = "sha256:b714a4776dfe8af01f6d480d56b5d7595b645da9379acdaf98cd29111df88725"}, + {file = "opendp-0.12.1-cp39-abi3-win_amd64.whl", hash = "sha256:4590b5297c3572456ecaab451e5560d4bd5fec645f509b49df52625220b2a286"}, + {file = "opendp-0.12.1-py3-none-manylinux_2_24_aarch64.whl", hash = "sha256:2258ef20a959f5b6acd97a2b611a26a68a69a34647fbf72d85d72420d41c9e4e"}, + {file = "opendp-0.12.1.tar.gz", hash = "sha256:5b17a83733c903958a49ef2fd72e9620169bc0f7bab7c03a20aba66bfdc3fa2e"}, +] + +[package.dependencies] +deprecated = "*" + +[package.extras] +numpy = ["numpy", "randomgen (>=2.0.0)"] +polars = ["numpy", "polars (==1.12.0)", "pyarrow", "randomgen (>=2.0.0)", "scikit-learn"] +scikit-learn = ["numpy", "randomgen (>=2.0.0)", "scikit-learn"] + [[package]] name = "packaging" -version = "25.0" +version = "26.0" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" groups = ["dev"] files = [ - {file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"}, - {file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"}, + {file = "packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529"}, + {file = "packaging-26.0.tar.gz", hash = "sha256:00243ae351a257117b6a241061796684b084ed1c516a08c48a3f7e147a9d80b4"}, ] [[package]] @@ -377,7 +421,7 @@ description = "A small Python package for determining appropriate platform-speci optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version == \"3.9\"" +markers = "python_version < \"3.10\"" files = [ {file = "platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85"}, {file = "platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf"}, @@ -429,7 +473,7 @@ description = "A framework for managing and maintaining multi-language pre-commi optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version == \"3.9\"" +markers = "python_version < \"3.10\"" files = [ {file = "pre_commit-4.3.0-py2.py3-none-any.whl", hash = "sha256:2b0747ad7e6e967169136edffee14c16e148a778a54e4f967921aa1ebf2308d8"}, {file = "pre_commit-4.3.0.tar.gz", hash = "sha256:499fe450cc9d42e9d58e606262795ecb64dd05438943c62b66f6a8673da30b16"}, @@ -640,7 +684,7 @@ description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version == \"3.9\"" +markers = "python_version < \"3.10\"" files = [ {file = "pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79"}, {file = "pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01"}, @@ -691,6 +735,13 @@ optional = false python-versions = ">=3.8" groups = ["dev"] files = [ + {file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"}, + {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"}, + {file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"}, + {file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"}, + {file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"}, {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"}, @@ -950,7 +1001,127 @@ typing-extensions = {version = ">=4.13.2", markers = "python_version < \"3.11\"" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"GraalVM\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] +[[package]] +name = "wrapt" +version = "2.0.1" +description = "Module for decorators, wrappers and monkey patching." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "wrapt-2.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:64b103acdaa53b7caf409e8d45d39a8442fe6dcfec6ba3f3d141e0cc2b5b4dbd"}, + {file = "wrapt-2.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:91bcc576260a274b169c3098e9a3519fb01f2989f6d3d386ef9cbf8653de1374"}, + {file = "wrapt-2.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ab594f346517010050126fcd822697b25a7031d815bb4fbc238ccbe568216489"}, + {file = "wrapt-2.0.1-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:36982b26f190f4d737f04a492a68accbfc6fa042c3f42326fdfbb6c5b7a20a31"}, + {file = "wrapt-2.0.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:23097ed8bc4c93b7bf36fa2113c6c733c976316ce0ee2c816f64ca06102034ef"}, + {file = "wrapt-2.0.1-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8bacfe6e001749a3b64db47bcf0341da757c95959f592823a93931a422395013"}, + {file = "wrapt-2.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:8ec3303e8a81932171f455f792f8df500fc1a09f20069e5c16bd7049ab4e8e38"}, + {file = "wrapt-2.0.1-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:3f373a4ab5dbc528a94334f9fe444395b23c2f5332adab9ff4ea82f5a9e33bc1"}, + {file = "wrapt-2.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f49027b0b9503bf6c8cdc297ca55006b80c2f5dd36cecc72c6835ab6e10e8a25"}, + {file = "wrapt-2.0.1-cp310-cp310-win32.whl", hash = "sha256:8330b42d769965e96e01fa14034b28a2a7600fbf7e8f0cc90ebb36d492c993e4"}, + {file = "wrapt-2.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:1218573502a8235bb8a7ecaed12736213b22dcde9feab115fa2989d42b5ded45"}, + {file = "wrapt-2.0.1-cp310-cp310-win_arm64.whl", hash = "sha256:eda8e4ecd662d48c28bb86be9e837c13e45c58b8300e43ba3c9b4fa9900302f7"}, + {file = "wrapt-2.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0e17283f533a0d24d6e5429a7d11f250a58d28b4ae5186f8f47853e3e70d2590"}, + {file = "wrapt-2.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:85df8d92158cb8f3965aecc27cf821461bb5f40b450b03facc5d9f0d4d6ddec6"}, + {file = "wrapt-2.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c1be685ac7700c966b8610ccc63c3187a72e33cab53526a27b2a285a662cd4f7"}, + {file = "wrapt-2.0.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:df0b6d3b95932809c5b3fecc18fda0f1e07452d05e2662a0b35548985f256e28"}, + {file = "wrapt-2.0.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4da7384b0e5d4cae05c97cd6f94faaf78cc8b0f791fc63af43436d98c4ab37bb"}, + {file = "wrapt-2.0.1-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ec65a78fbd9d6f083a15d7613b2800d5663dbb6bb96003899c834beaa68b242c"}, + {file = "wrapt-2.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7de3cc939be0e1174969f943f3b44e0d79b6f9a82198133a5b7fc6cc92882f16"}, + {file = "wrapt-2.0.1-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:fb1a5b72cbd751813adc02ef01ada0b0d05d3dcbc32976ce189a1279d80ad4a2"}, + {file = "wrapt-2.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3fa272ca34332581e00bf7773e993d4f632594eb2d1b0b162a9038df0fd971dd"}, + {file = "wrapt-2.0.1-cp311-cp311-win32.whl", hash = "sha256:fc007fdf480c77301ab1afdbb6ab22a5deee8885f3b1ed7afcb7e5e84a0e27be"}, + {file = "wrapt-2.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:47434236c396d04875180171ee1f3815ca1eada05e24a1ee99546320d54d1d1b"}, + {file = "wrapt-2.0.1-cp311-cp311-win_arm64.whl", hash = "sha256:837e31620e06b16030b1d126ed78e9383815cbac914693f54926d816d35d8edf"}, + {file = "wrapt-2.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1fdbb34da15450f2b1d735a0e969c24bdb8d8924892380126e2a293d9902078c"}, + {file = "wrapt-2.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3d32794fe940b7000f0519904e247f902f0149edbe6316c710a8562fb6738841"}, + {file = "wrapt-2.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:386fb54d9cd903ee0012c09291336469eb7b244f7183d40dc3e86a16a4bace62"}, + {file = "wrapt-2.0.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7b219cb2182f230676308cdcacd428fa837987b89e4b7c5c9025088b8a6c9faf"}, + {file = "wrapt-2.0.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:641e94e789b5f6b4822bb8d8ebbdfc10f4e4eae7756d648b717d980f657a9eb9"}, + {file = "wrapt-2.0.1-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:fe21b118b9f58859b5ebaa4b130dee18669df4bd111daad082b7beb8799ad16b"}, + {file = "wrapt-2.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:17fb85fa4abc26a5184d93b3efd2dcc14deb4b09edcdb3535a536ad34f0b4dba"}, + {file = "wrapt-2.0.1-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:b89ef9223d665ab255ae42cc282d27d69704d94be0deffc8b9d919179a609684"}, + {file = "wrapt-2.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a453257f19c31b31ba593c30d997d6e5be39e3b5ad9148c2af5a7314061c63eb"}, + {file = "wrapt-2.0.1-cp312-cp312-win32.whl", hash = "sha256:3e271346f01e9c8b1130a6a3b0e11908049fe5be2d365a5f402778049147e7e9"}, + {file = "wrapt-2.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:2da620b31a90cdefa9cd0c2b661882329e2e19d1d7b9b920189956b76c564d75"}, + {file = "wrapt-2.0.1-cp312-cp312-win_arm64.whl", hash = "sha256:aea9c7224c302bc8bfc892b908537f56c430802560e827b75ecbde81b604598b"}, + {file = "wrapt-2.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:47b0f8bafe90f7736151f61482c583c86b0693d80f075a58701dd1549b0010a9"}, + {file = "wrapt-2.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:cbeb0971e13b4bd81d34169ed57a6dda017328d1a22b62fda45e1d21dd06148f"}, + {file = "wrapt-2.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:eb7cffe572ad0a141a7886a1d2efa5bef0bf7fe021deeea76b3ab334d2c38218"}, + {file = "wrapt-2.0.1-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c8d60527d1ecfc131426b10d93ab5d53e08a09c5fa0175f6b21b3252080c70a9"}, + {file = "wrapt-2.0.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c654eafb01afac55246053d67a4b9a984a3567c3808bb7df2f8de1c1caba2e1c"}, + {file = "wrapt-2.0.1-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:98d873ed6c8b4ee2418f7afce666751854d6d03e3c0ec2a399bb039cd2ae89db"}, + {file = "wrapt-2.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c9e850f5b7fc67af856ff054c71690d54fa940c3ef74209ad9f935b4f66a0233"}, + {file = "wrapt-2.0.1-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:e505629359cb5f751e16e30cf3f91a1d3ddb4552480c205947da415d597f7ac2"}, + {file = "wrapt-2.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:2879af909312d0baf35f08edeea918ee3af7ab57c37fe47cb6a373c9f2749c7b"}, + {file = "wrapt-2.0.1-cp313-cp313-win32.whl", hash = "sha256:d67956c676be5a24102c7407a71f4126d30de2a569a1c7871c9f3cabc94225d7"}, + {file = "wrapt-2.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:9ca66b38dd642bf90c59b6738af8070747b610115a39af2498535f62b5cdc1c3"}, + {file = "wrapt-2.0.1-cp313-cp313-win_arm64.whl", hash = "sha256:5a4939eae35db6b6cec8e7aa0e833dcca0acad8231672c26c2a9ab7a0f8ac9c8"}, + {file = "wrapt-2.0.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:a52f93d95c8d38fed0669da2ebdb0b0376e895d84596a976c15a9eb45e3eccb3"}, + {file = "wrapt-2.0.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4e54bbf554ee29fcceee24fa41c4d091398b911da6e7f5d7bffda963c9aed2e1"}, + {file = "wrapt-2.0.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:908f8c6c71557f4deaa280f55d0728c3bca0960e8c3dd5ceeeafb3c19942719d"}, + {file = "wrapt-2.0.1-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e2f84e9af2060e3904a32cea9bb6db23ce3f91cfd90c6b426757cf7cc01c45c7"}, + {file = "wrapt-2.0.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3612dc06b436968dfb9142c62e5dfa9eb5924f91120b3c8ff501ad878f90eb3"}, + {file = "wrapt-2.0.1-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6d2d947d266d99a1477cd005b23cbd09465276e302515e122df56bb9511aca1b"}, + {file = "wrapt-2.0.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:7d539241e87b650cbc4c3ac9f32c8d1ac8a54e510f6dca3f6ab60dcfd48c9b10"}, + {file = "wrapt-2.0.1-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:4811e15d88ee62dbf5c77f2c3ff3932b1e3ac92323ba3912f51fc4016ce81ecf"}, + {file = "wrapt-2.0.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c1c91405fcf1d501fa5d55df21e58ea49e6b879ae829f1039faaf7e5e509b41e"}, + {file = "wrapt-2.0.1-cp313-cp313t-win32.whl", hash = "sha256:e76e3f91f864e89db8b8d2a8311d57df93f01ad6bb1e9b9976d1f2e83e18315c"}, + {file = "wrapt-2.0.1-cp313-cp313t-win_amd64.whl", hash = "sha256:83ce30937f0ba0d28818807b303a412440c4b63e39d3d8fc036a94764b728c92"}, + {file = "wrapt-2.0.1-cp313-cp313t-win_arm64.whl", hash = "sha256:4b55cacc57e1dc2d0991dbe74c6419ffd415fb66474a02335cb10efd1aa3f84f"}, + {file = "wrapt-2.0.1-cp314-cp314-macosx_10_13_universal2.whl", hash = "sha256:5e53b428f65ece6d9dad23cb87e64506392b720a0b45076c05354d27a13351a1"}, + {file = "wrapt-2.0.1-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:ad3ee9d0f254851c71780966eb417ef8e72117155cff04821ab9b60549694a55"}, + {file = "wrapt-2.0.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d7b822c61ed04ee6ad64bc90d13368ad6eb094db54883b5dde2182f67a7f22c0"}, + {file = "wrapt-2.0.1-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7164a55f5e83a9a0b031d3ffab4d4e36bbec42e7025db560f225489fa929e509"}, + {file = "wrapt-2.0.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e60690ba71a57424c8d9ff28f8d006b7ad7772c22a4af432188572cd7fa004a1"}, + {file = "wrapt-2.0.1-cp314-cp314-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:3cd1a4bd9a7a619922a8557e1318232e7269b5fb69d4ba97b04d20450a6bf970"}, + {file = "wrapt-2.0.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b4c2e3d777e38e913b8ce3a6257af72fb608f86a1df471cb1d4339755d0a807c"}, + {file = "wrapt-2.0.1-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:3d366aa598d69416b5afedf1faa539fac40c1d80a42f6b236c88c73a3c8f2d41"}, + {file = "wrapt-2.0.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c235095d6d090aa903f1db61f892fffb779c1eaeb2a50e566b52001f7a0f66ed"}, + {file = "wrapt-2.0.1-cp314-cp314-win32.whl", hash = "sha256:bfb5539005259f8127ea9c885bdc231978c06b7a980e63a8a61c8c4c979719d0"}, + {file = "wrapt-2.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:4ae879acc449caa9ed43fc36ba08392b9412ee67941748d31d94e3cedb36628c"}, + {file = "wrapt-2.0.1-cp314-cp314-win_arm64.whl", hash = "sha256:8639b843c9efd84675f1e100ed9e99538ebea7297b62c4b45a7042edb84db03e"}, + {file = "wrapt-2.0.1-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:9219a1d946a9b32bb23ccae66bdb61e35c62773ce7ca6509ceea70f344656b7b"}, + {file = "wrapt-2.0.1-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:fa4184e74197af3adad3c889a1af95b53bb0466bced92ea99a0c014e48323eec"}, + {file = "wrapt-2.0.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c5ef2f2b8a53b7caee2f797ef166a390fef73979b15778a4a153e4b5fedce8fa"}, + {file = "wrapt-2.0.1-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e042d653a4745be832d5aa190ff80ee4f02c34b21f4b785745eceacd0907b815"}, + {file = "wrapt-2.0.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2afa23318136709c4b23d87d543b425c399887b4057936cd20386d5b1422b6fa"}, + {file = "wrapt-2.0.1-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:6c72328f668cf4c503ffcf9434c2b71fdd624345ced7941bc6693e61bbe36bef"}, + {file = "wrapt-2.0.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:3793ac154afb0e5b45d1233cb94d354ef7a983708cc3bb12563853b1d8d53747"}, + {file = "wrapt-2.0.1-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:fec0d993ecba3991645b4857837277469c8cc4c554a7e24d064d1ca291cfb81f"}, + {file = "wrapt-2.0.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:949520bccc1fa227274da7d03bf238be15389cd94e32e4297b92337df9b7a349"}, + {file = "wrapt-2.0.1-cp314-cp314t-win32.whl", hash = "sha256:be9e84e91d6497ba62594158d3d31ec0486c60055c49179edc51ee43d095f79c"}, + {file = "wrapt-2.0.1-cp314-cp314t-win_amd64.whl", hash = "sha256:61c4956171c7434634401db448371277d07032a81cc21c599c22953374781395"}, + {file = "wrapt-2.0.1-cp314-cp314t-win_arm64.whl", hash = "sha256:35cdbd478607036fee40273be8ed54a451f5f23121bd9d4be515158f9498f7ad"}, + {file = "wrapt-2.0.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:90897ea1cf0679763b62e79657958cd54eae5659f6360fc7d2ccc6f906342183"}, + {file = "wrapt-2.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:50844efc8cdf63b2d90cd3d62d4947a28311e6266ce5235a219d21b195b4ec2c"}, + {file = "wrapt-2.0.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49989061a9977a8cbd6d20f2efa813f24bf657c6990a42967019ce779a878dbf"}, + {file = "wrapt-2.0.1-cp38-cp38-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:09c7476ab884b74dce081ad9bfd07fe5822d8600abade571cb1f66d5fc915af6"}, + {file = "wrapt-2.0.1-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1a8a09a004ef100e614beec82862d11fc17d601092c3599afd22b1f36e4137e"}, + {file = "wrapt-2.0.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:89a82053b193837bf93c0f8a57ded6e4b6d88033a499dadff5067e912c2a41e9"}, + {file = "wrapt-2.0.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f26f8e2ca19564e2e1fdbb6a0e47f36e0efbab1acc31e15471fad88f828c75f6"}, + {file = "wrapt-2.0.1-cp38-cp38-win32.whl", hash = "sha256:115cae4beed3542e37866469a8a1f2b9ec549b4463572b000611e9946b86e6f6"}, + {file = "wrapt-2.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:c4012a2bd37059d04f8209916aa771dfb564cccb86079072bdcd48a308b6a5c5"}, + {file = "wrapt-2.0.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:68424221a2dc00d634b54f92441914929c5ffb1c30b3b837343978343a3512a3"}, + {file = "wrapt-2.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6bd1a18f5a797fe740cb3d7a0e853a8ce6461cc62023b630caec80171a6b8097"}, + {file = "wrapt-2.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fb3a86e703868561c5cad155a15c36c716e1ab513b7065bd2ac8ed353c503333"}, + {file = "wrapt-2.0.1-cp39-cp39-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5dc1b852337c6792aa111ca8becff5bacf576bf4a0255b0f05eb749da6a1643e"}, + {file = "wrapt-2.0.1-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c046781d422f0830de6329fa4b16796096f28a92c8aef3850674442cdcb87b7f"}, + {file = "wrapt-2.0.1-cp39-cp39-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f73f9f7a0ebd0db139253d27e5fc8d2866ceaeef19c30ab5d69dcbe35e1a6981"}, + {file = "wrapt-2.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:b667189cf8efe008f55bbda321890bef628a67ab4147ebf90d182f2dadc78790"}, + {file = "wrapt-2.0.1-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:a9a83618c4f0757557c077ef71d708ddd9847ed66b7cc63416632af70d3e2308"}, + {file = "wrapt-2.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1e9b121e9aeb15df416c2c960b8255a49d44b4038016ee17af03975992d03931"}, + {file = "wrapt-2.0.1-cp39-cp39-win32.whl", hash = "sha256:1f186e26ea0a55f809f232e92cc8556a0977e00183c3ebda039a807a42be1494"}, + {file = "wrapt-2.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:bf4cb76f36be5de950ce13e22e7fdf462b35b04665a12b64f3ac5c1bbbcf3728"}, + {file = "wrapt-2.0.1-cp39-cp39-win_arm64.whl", hash = "sha256:d6cc985b9c8b235bd933990cdbf0f891f8e010b65a3911f7a55179cd7b0fc57b"}, + {file = "wrapt-2.0.1-py3-none-any.whl", hash = "sha256:4d2ce1bf1a48c5277d7969259232b57645aae5686dba1eaeade39442277afbca"}, + {file = "wrapt-2.0.1.tar.gz", hash = "sha256:9c9c635e78497cacb81e84f8b11b23e0aacac7a136e73b8e5b2109a1d9fc468f"}, +] + +[package.extras] +dev = ["pytest", "setuptools"] + [metadata] lock-version = "2.1" python-versions = "^3.9" -content-hash = "e89c01282a61cc81b44df28aada7058145f4bc4afdc672770932db92928bfc99" +content-hash = "acd8c5581356e1dcc7ef41fd72238154838f97a24c53f845adf7c801f6eae20d" diff --git a/pyproject.toml b/pyproject.toml index 7be3a6a..da17d15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ flamesdk = {git = "https://github.com/PrivateAIM/python-sdk.git", tag = "0.3.1" pytest = ">=8.3.0" ruff = ">=0.9.0" pre-commit = ">=4.0.0" +opendp = ">=0.12.1,<0.13.0" [build-system] requires = ["poetry-core"] From 2914987cc74271138b31a94413f111d361266277 Mon Sep 17 00:00:00 2001 From: Nightknight3000 Date: Thu, 22 Jan 2026 13:42:37 +0100 Subject: [PATCH 2/6] fix: add local dp variables to mock tester Co-authored-by: antidodo --- flame/star/__init__.py | 1 + flame/star/star_localdp/star_localdp_model.py | 6 ++-- flame/star/star_model_tester.py | 33 +++++++++++-------- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/flame/star/__init__.py b/flame/star/__init__.py index 6cd6e0a..38859ea 100644 --- a/flame/star/__init__.py +++ b/flame/star/__init__.py @@ -1,4 +1,5 @@ from flame.star.star_model import StarModel +from flame.star.star_localdp.star_localdp_model import StarLocalDPModel from flame.star.analyzer_client import Analyzer as StarAnalyzer from flame.star.aggregator_client import Aggregator as StarAggregator from flame.star.star_model_tester import StarModelTester diff --git a/flame/star/star_localdp/star_localdp_model.py b/flame/star/star_localdp/star_localdp_model.py index 4863024..f16fe1d 100644 --- a/flame/star/star_localdp/star_localdp_model.py +++ b/flame/star/star_localdp/star_localdp_model.py @@ -25,10 +25,10 @@ def __init__(self, output_type: Literal['str', 'bytes', 'pickle'] = 'str', analyzer_kwargs: Optional[dict] = None, aggregator_kwargs: Optional[dict] = None, - test_mode: bool = False, - test_kwargs: Optional[dict] = None, epsilon: Optional[float] = None, - sensitivity: Optional[float] = None) -> None: + sensitivity: Optional[float] = None, + test_mode: bool = False, + test_kwargs: Optional[dict] = None) -> None: super().__init__(analyzer=analyzer, aggregator=aggregator, data_type=data_type, diff --git a/flame/star/star_model_tester.py b/flame/star/star_model_tester.py index 163e620..fa462a1 100644 --- a/flame/star/star_model_tester.py +++ b/flame/star/star_model_tester.py @@ -1,7 +1,7 @@ import pickle from typing import Any, Type, Literal, Optional, Union -from flame.star import StarModel, StarAnalyzer, StarAggregator +from flame.star import StarModel, StarLocalDPModel, StarAnalyzer, StarAggregator from flame.utils.mock_flame_core import MockFlameCoreSDK @@ -22,6 +22,8 @@ def __init__(self, output_type: Literal['str', 'bytes', 'pickle'] = 'str', analyzer_kwargs: Optional[dict] = None, aggregator_kwargs: Optional[dict] = None, + epsilon: Optional[float] = None, + sensitivity: Optional[float] = None, result_filepath: str = None) -> None: self.agg_index = len(data_splits) while not self.converged: @@ -34,7 +36,9 @@ def __init__(self, query, output_type, analyzer_kwargs, - aggregator_kwargs) + aggregator_kwargs, + epsilon, + sensitivity) if simple_analysis: self.write_result(result, output_type, result_filepath) self.converged = True @@ -56,7 +60,9 @@ def simulate_nodes(self, query: Optional[Union[str, list[str]]], output_type: Literal['str', 'bytes', 'pickle'], analyzer_kwargs: Optional[dict] = None, - aggregator_kwargs: Optional[dict] = None) -> tuple[Any, dict[str, Any]]: + aggregator_kwargs: Optional[dict] = None, + epsilon: Optional[float] = None, + sensitivity: Optional[float] = None,) -> tuple[Any, dict[str, Any]]: sim_nodes = {} agg_kwargs = None for i in range(len(data_splits) + 1): @@ -73,16 +79,17 @@ def simulate_nodes(self, if i == self.agg_index: agg_kwargs = test_kwargs - sim_nodes[node_id] = StarModel(analyzer, - aggregator, - data_type, - query, - True, - output_type, - analyzer_kwargs, - aggregator_kwargs, - test_mode=True, - test_kwargs=test_kwargs) + pattern = StarModel if (epsilon is None) or (sensitivity is None) else StarLocalDPModel + sim_nodes[node_id] = pattern(analyzer, + aggregator, + data_type, + query, + True, + output_type, + analyzer_kwargs, + aggregator_kwargs, + test_mode=True, + test_kwargs=test_kwargs) return sim_nodes[f"node_{self.agg_index}"].flame.final_results_storage, agg_kwargs @staticmethod From 54ebad64df1ea7e4b723b96ecc0f853331e0d8f6 Mon Sep 17 00:00:00 2001 From: davidhieber Date: Thu, 22 Jan 2026 15:17:22 +0100 Subject: [PATCH 3/6] feat: enhance local differential privacy support for local testing Co-authored-by: Nightknight3000 --- flame/star/aggregator_client.py | 10 ++--- flame/star/star_localdp/star_localdp_model.py | 13 +++++-- flame/star/star_model.py | 2 +- flame/star/star_model_tester.py | 37 +++++++++++++------ flame/utils/mock_flame_core.py | 2 + poetry.lock | 25 +++++-------- pyproject.toml | 2 +- 7 files changed, 51 insertions(+), 40 deletions(-) diff --git a/flame/star/aggregator_client.py b/flame/star/aggregator_client.py index 4daf8c0..173891d 100644 --- a/flame/star/aggregator_client.py +++ b/flame/star/aggregator_client.py @@ -14,21 +14,19 @@ def __init__(self, flame: Union[FlameCoreSDK, MockFlameCoreSDK]) -> None: raise ValueError(f'Attempted to initialize aggregator node with mismatching configuration ' f'(expected: node_role="aggregator", received="{self.role}").') - def aggregate(self, node_results: list[Any], simple_analysis: bool = True) -> tuple[Any, bool]: + def aggregate(self, node_results: list[Any], simple_analysis: bool = True) -> tuple[Any, bool, bool]: result = self.aggregation_method(node_results) + delta_criteria = self.has_converged(result, self.latest_result) if self.num_iterations != 0 else False if not simple_analysis: - if self.num_iterations != 0: - converged = self.has_converged(result, self.latest_result) - else: - converged = False + converged = delta_criteria else: converged = True self.latest_result = result self.num_iterations += 1 - return self.latest_result, converged + return self.latest_result, converged, delta_criteria @abstractmethod def aggregation_method(self, analysis_results: list[Any]) -> Any: diff --git a/flame/star/star_localdp/star_localdp_model.py b/flame/star/star_localdp/star_localdp_model.py index f16fe1d..f8a6e67 100644 --- a/flame/star/star_localdp/star_localdp_model.py +++ b/flame/star/star_localdp/star_localdp_model.py @@ -29,6 +29,8 @@ def __init__(self, sensitivity: Optional[float] = None, test_mode: bool = False, test_kwargs: Optional[dict] = None) -> None: + self.epsilon = epsilon + self.sensitivity = sensitivity super().__init__(analyzer=analyzer, aggregator=aggregator, data_type=data_type, @@ -39,8 +41,6 @@ def __init__(self, aggregator_kwargs=aggregator_kwargs, test_mode=test_mode, test_kwargs=test_kwargs) - self.epsilon = epsilon - self.sensitivity = sensitivity def _start_aggregator(self, aggregator: Type[Aggregator], @@ -70,7 +70,7 @@ def _start_aggregator(self, result_dict = self.flame.await_intermediate_data(analyzers) # Aggregate results - agg_res, converged = aggregator.aggregate(list(result_dict.values()), simple_analysis) + agg_res, converged, delta_crit = aggregator.aggregate(list(result_dict.values()), simple_analysis) self.flame.flame_log(f"Aggregated results: {str(agg_res)[:100]}") if converged: @@ -78,12 +78,17 @@ def _start_aggregator(self, self.flame.flame_log("Submitting final results using differential privacy...", log_type='info', end='') - if (self.epsilon is not None) and (self.sensitivity is not None): + if delta_crit and (self.epsilon is not None) and (self.sensitivity is not None): local_dp = {"epsilon": self.epsilon, "sensitivity": self.sensitivity} else: local_dp = None + if self.test_mode and (local_dp is not None): + self.flame.flame_log(f"\tTest mode: Would apply local DP with epsilon={local_dp['epsilon']} " + f"and sensitivity={local_dp['sensitivity']}", + log_type='info') response = self.flame.submit_final_result(agg_res, output_type, local_dp=local_dp) if not self.test_mode: + self.has_converged(agg_res, aggregator.latest_result) self.flame.flame_log(f"success (response={response})", log_type='info') self.flame.analysis_finished() aggregator.node_finished() # LOOP BREAK diff --git a/flame/star/star_model.py b/flame/star/star_model.py index 37ea5cf..5fcb12c 100644 --- a/flame/star/star_model.py +++ b/flame/star/star_model.py @@ -97,7 +97,7 @@ def _start_aggregator(self, result_dict = self.flame.await_intermediate_data(analyzers) # Aggregate results - agg_res, converged = aggregator.aggregate(list(result_dict.values()), simple_analysis) + agg_res, converged, _ = aggregator.aggregate(list(result_dict.values()), simple_analysis) if converged: if not self.test_mode: diff --git a/flame/star/star_model_tester.py b/flame/star/star_model_tester.py index fa462a1..b74d8df 100644 --- a/flame/star/star_model_tester.py +++ b/flame/star/star_model_tester.py @@ -62,7 +62,7 @@ def simulate_nodes(self, analyzer_kwargs: Optional[dict] = None, aggregator_kwargs: Optional[dict] = None, epsilon: Optional[float] = None, - sensitivity: Optional[float] = None,) -> tuple[Any, dict[str, Any]]: + sensitivity: Optional[float] = None) -> tuple[Any, dict[str, Any]]: sim_nodes = {} agg_kwargs = None for i in range(len(data_splits) + 1): @@ -79,17 +79,30 @@ def simulate_nodes(self, if i == self.agg_index: agg_kwargs = test_kwargs - pattern = StarModel if (epsilon is None) or (sensitivity is None) else StarLocalDPModel - sim_nodes[node_id] = pattern(analyzer, - aggregator, - data_type, - query, - True, - output_type, - analyzer_kwargs, - aggregator_kwargs, - test_mode=True, - test_kwargs=test_kwargs) + if (epsilon is None) or (sensitivity is None): + sim_nodes[node_id] = StarModel(analyzer, + aggregator, + data_type, + query, + True, + output_type, + analyzer_kwargs, + aggregator_kwargs, + test_mode=True, + test_kwargs=test_kwargs) + else: + sim_nodes[node_id] = StarLocalDPModel(analyzer, + aggregator, + data_type, + query, + True, + output_type, + analyzer_kwargs, + aggregator_kwargs, + epsilon=epsilon, + sensitivity=sensitivity, + test_mode=True, + test_kwargs=test_kwargs) return sim_nodes[f"node_{self.agg_index}"].flame.final_results_storage, agg_kwargs @staticmethod diff --git a/flame/utils/mock_flame_core.py b/flame/utils/mock_flame_core.py index 8f4c865..d852ec9 100644 --- a/flame/utils/mock_flame_core.py +++ b/flame/utils/mock_flame_core.py @@ -3,6 +3,7 @@ from io import StringIO from typing import Any, Literal, Optional, Union +from opendp.mod import enable_features from opendp.domains import atom_domain from opendp.measurements import make_laplace from opendp.metrics import absolute_distance @@ -184,6 +185,7 @@ def submit_final_result(self, local_dp: Optional[dict] = None) -> dict[str, str]: if local_dp is not None: if type(result) in [int, float]: + enable_features("contrib") scale = local_dp['sensitivity'] / local_dp['epsilon'] # Laplace scale parameter laplace_mech = make_laplace(input_domain=atom_domain(T=float), input_metric=absolute_distance(T=float), diff --git a/poetry.lock b/poetry.lock index 3d1ab73..44b0f95 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -51,7 +51,7 @@ description = "Validate configuration and produce human readable error messages. optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, @@ -77,7 +77,7 @@ description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" groups = ["main"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -191,7 +191,7 @@ description = "A platform independent file lock." optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "filelock-3.19.1-py3-none-any.whl", hash = "sha256:d38e30481def20772f5baf097c122c3babc4fcdb7e14e57049eb9d88c6dc017d"}, {file = "filelock-3.19.1.tar.gz", hash = "sha256:66eda1888b0171c998b35be2bcc0f6d75c388a7ce20c3f3f37aa8e96c2dddf58"}, @@ -298,7 +298,7 @@ description = "File identification library for Python" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757"}, {file = "identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf"}, @@ -345,7 +345,7 @@ description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" groups = ["dev"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -421,7 +421,7 @@ description = "A small Python package for determining appropriate platform-speci optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "platformdirs-4.4.0-py3-none-any.whl", hash = "sha256:abd01743f24e5287cd7a5db3752faf1a2d65353f38ec26d98e25a6db65958c85"}, {file = "platformdirs-4.4.0.tar.gz", hash = "sha256:ca753cf4d81dc309bc67b0ea38fd15dc97bc30ce419a7f58d13eb3bf14c4febf"}, @@ -473,7 +473,7 @@ description = "A framework for managing and maintaining multi-language pre-commi optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "pre_commit-4.3.0-py2.py3-none-any.whl", hash = "sha256:2b0747ad7e6e967169136edffee14c16e148a778a54e4f967921aa1ebf2308d8"}, {file = "pre_commit-4.3.0.tar.gz", hash = "sha256:499fe450cc9d42e9d58e606262795ecb64dd05438943c62b66f6a8673da30b16"}, @@ -684,7 +684,7 @@ description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.9" groups = ["dev"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79"}, {file = "pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01"}, @@ -735,13 +735,6 @@ optional = false python-versions = ">=3.8" groups = ["dev"] files = [ - {file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"}, - {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"}, - {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"}, - {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"}, - {file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"}, - {file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"}, - {file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"}, {file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"}, {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"}, diff --git a/pyproject.toml b/pyproject.toml index da17d15..75f89a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [{ include = "flame" }] [tool.poetry.dependencies] python = "^3.9" -flamesdk = {git = "https://github.com/PrivateAIM/python-sdk.git", tag = "0.3.1" } +flamesdk = {git = "https://github.com/PrivateAIM/python-sdk.git", tag = "0.3.1"} [tool.poetry.group.dev.dependencies] pytest = ">=8.3.0" From 8baaf809df90a244f0766ea3e484aff86a26b6c0 Mon Sep 17 00:00:00 2001 From: Nightknight3000 Date: Fri, 23 Jan 2026 10:52:17 +0100 Subject: [PATCH 4/6] feat: add custom kwargs support for aggregator instance in convergence test Co-authored-by: antidodo --- flame/star/star_model_tester.py | 39 +++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/flame/star/star_model_tester.py b/flame/star/star_model_tester.py index b74d8df..3ffd0fc 100644 --- a/flame/star/star_model_tester.py +++ b/flame/star/star_model_tester.py @@ -29,21 +29,21 @@ def __init__(self, while not self.converged: print(f"--- Starting Iteration {self.num_iterations} ---") - result, agg_kwargs = self.simulate_nodes(data_splits, - analyzer, - aggregator, - data_type, - query, - output_type, - analyzer_kwargs, - aggregator_kwargs, - epsilon, - sensitivity) + result, test_agg_kwargs = self.simulate_nodes(data_splits, + analyzer, + aggregator, + data_type, + query, + output_type, + analyzer_kwargs, + aggregator_kwargs, + epsilon, + sensitivity) if simple_analysis: self.write_result(result, output_type, result_filepath) self.converged = True else: - self.converged = self.check_convergence(aggregator, agg_kwargs, result) + self.converged = self.check_convergence(aggregator, test_agg_kwargs, result, aggregator_kwargs) if self.converged: self.write_result(result, output_type, result_filepath) else: @@ -107,13 +107,18 @@ def simulate_nodes(self, @staticmethod def check_convergence(aggregator: Type[StarAggregator], - agg_kwargs: dict[str, Any], - result: Any) -> bool: - if all(k in agg_kwargs.keys() for k in ('num_iterations', 'latest_result')): - agg = aggregator(MockFlameCoreSDK(test_kwargs=agg_kwargs)) - agg.set_num_iterations(agg_kwargs['num_iterations']) + test_agg_kwargs: dict[str, Any], + result: Any, + aggregator_kwargs: Optional[dict] = None) -> bool: + if all(k in test_agg_kwargs.keys() for k in ('num_iterations', 'latest_result')): + if aggregator_kwargs is None: + agg = aggregator(MockFlameCoreSDK(test_kwargs=test_agg_kwargs)) + else: + agg = aggregator(MockFlameCoreSDK(test_kwargs=test_agg_kwargs), **aggregator_kwargs) + agg.set_num_iterations(test_agg_kwargs['num_iterations']) + if agg.num_iterations != 0: - return agg.has_converged(result=result, last_result=agg_kwargs['latest_result']) + return agg.has_converged(result=result, last_result=test_agg_kwargs['latest_result']) else: return False else: From 5b1933a9f12de7bacb1e10c4c9bab02bbb1f030e Mon Sep 17 00:00:00 2001 From: davidhieber Date: Fri, 23 Jan 2026 11:42:53 +0100 Subject: [PATCH 5/6] fix: refine convergence criteria handling in aggregation process Co-authored-by: Nightknight3000 --- flame/star/aggregator_client.py | 4 ++-- flame/star/star_localdp/star_localdp_model.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/flame/star/aggregator_client.py b/flame/star/aggregator_client.py index 173891d..5177dae 100644 --- a/flame/star/aggregator_client.py +++ b/flame/star/aggregator_client.py @@ -17,9 +17,9 @@ def __init__(self, flame: Union[FlameCoreSDK, MockFlameCoreSDK]) -> None: def aggregate(self, node_results: list[Any], simple_analysis: bool = True) -> tuple[Any, bool, bool]: result = self.aggregation_method(node_results) - delta_criteria = self.has_converged(result, self.latest_result) if self.num_iterations != 0 else False + delta_criteria = self.has_converged(result, self.latest_result) if not simple_analysis: - converged = delta_criteria + converged = delta_criteria if self.num_iterations != 0 else False else: converged = True diff --git a/flame/star/star_localdp/star_localdp_model.py b/flame/star/star_localdp/star_localdp_model.py index f8a6e67..597175a 100644 --- a/flame/star/star_localdp/star_localdp_model.py +++ b/flame/star/star_localdp/star_localdp_model.py @@ -88,7 +88,6 @@ def _start_aggregator(self, log_type='info') response = self.flame.submit_final_result(agg_res, output_type, local_dp=local_dp) if not self.test_mode: - self.has_converged(agg_res, aggregator.latest_result) self.flame.flame_log(f"success (response={response})", log_type='info') self.flame.analysis_finished() aggregator.node_finished() # LOOP BREAK From 8fffeacecb5378b939b6ec9237458f2a252c6b8e Mon Sep 17 00:00:00 2001 From: davidhieber Date: Fri, 23 Jan 2026 11:58:54 +0100 Subject: [PATCH 6/6] feat: implement custom analyzer and aggregator for local differential privacy model Co-authored-by: Nightknight3000 --- examples/run_star_model.py | 3 +- examples/run_star_model_dp.py | 89 +++++++++++++++++++++++++++++++++++ test/test_star_pattern_dp.py | 46 ++++++++++++++++++ 3 files changed, 136 insertions(+), 2 deletions(-) create mode 100644 examples/run_star_model_dp.py create mode 100644 test/test_star_pattern_dp.py diff --git a/examples/run_star_model.py b/examples/run_star_model.py index 9e26073..1d724fd 100644 --- a/examples/run_star_model.py +++ b/examples/run_star_model.py @@ -50,13 +50,12 @@ def aggregation_method(self, analysis_results): total_patient_count = sum(analysis_results) return total_patient_count - def has_converged(self, result, last_result, num_iterations): + def has_converged(self, result, last_result): """ Determines if the aggregation process has converged. :param result: The current aggregated result. :param last_result: The aggregated result from the previous iteration. - :param num_iterations: The number of iterations completed so far. :return: True if the aggregation has converged; False to continue iterations. """ # TODO (optional): if the parameter 'simple_analysis' in 'StarModel' is set to False, diff --git a/examples/run_star_model_dp.py b/examples/run_star_model_dp.py new file mode 100644 index 0000000..9f318a3 --- /dev/null +++ b/examples/run_star_model_dp.py @@ -0,0 +1,89 @@ +from flame.star import StarLocalDPModel, StarAnalyzer, StarAggregator + + +class MyAnalyzer(StarAnalyzer): + def __init__(self, flame): + """ + Initializes the custom Analyzer node. + + :param flame: Instance of FlameCoreSDK to interact with the FLAME components. + """ + super().__init__(flame) # Connects this analyzer to the FLAME components + + def analysis_method(self, data, aggregator_results): + """ + Performs analysis on the retrieved data from data sources. + + :param data: A list of dictionaries containing the data from each data source. + - Each dictionary corresponds to a data source. + - Keys are the queries executed, and values are the results (dict for FHIR, str for S3). + :param aggregator_results: Results from the aggregator in previous iterations. + - None in the first iteration. + - Contains the result from the aggregator's aggregation_method in subsequent iterations. + :return: Any result of your analysis on one node (ex. patient count). + """ + # TODO: Implement your analysis method + # in this example we retrieving first fhir dataset, extract patient counts, + # take total number of patients + patient_count = float(data[0]['Patient?_summary=count']['total']) + return patient_count + + +class MyAggregator(StarAggregator): + def __init__(self, flame): + """ + Initializes the custom Aggregator node. + + :param flame: Instance of FlameCoreSDK to interact with the FLAME components. + """ + super().__init__(flame) # Connects this aggregator to the FLAME components + + def aggregation_method(self, analysis_results): + """ + Aggregates the results received from all analyzer nodes. + + :param analysis_results: A list of analysis results from each analyzer node. + :return: The aggregated result (e.g., total patient count across all analyzers). + """ + # TODO: Implement your aggregation method + # in this example we retrieving sum up total patient counts across all nodes + total_patient_count = sum(analysis_results) + return total_patient_count + + def has_converged(self, result, last_result): + """ + Determines if the aggregation process has converged. + + :param result: The current aggregated result. + :param last_result: The aggregated result from the previous iteration. + :return: True if the aggregation has converged; False to continue iterations. + """ + # TODO (optional): if the parameter 'simple_analysis' in 'StarModel' is set to False, + # this function defines the exit criteria in a multi-iterative analysis (otherwise ignored) + return True # Return True to indicate convergence in this simple analysis + + +def main(): + """ + Sets up and initiates the distributed analysis using the FLAME components. + + - Defines the custom analyzer and aggregator classes. + - Specifies the type of data and queries to execute. + - Configures analysis parameters like iteration behavior and output format. + """ + StarLocalDPModel( + analyzer=MyAnalyzer, # Custom analyzer class (must inherit from StarAnalyzer) + aggregator=MyAggregator, # Custom aggregator class (must inherit from StarAggregator) + data_type='fhir', # Type of data source ('fhir' or 's3') + query='Patient?_summary=count', # Query or list of queries to retrieve data + simple_analysis=True, # True for single-iteration; False for multi-iterative analysis + output_type='str', # Output format for the final result ('str', 'bytes', or 'pickle') + epsilon=1.0, # Privacy budget for differential privacy + sensitivity=1.0, # Sensitivity parameter for differential privacy + analyzer_kwargs=None, # Additional keyword arguments for the custom analyzer constructor (i.e. MyAnalyzer) + aggregator_kwargs=None # Additional keyword arguments for the custom aggregator constructor (i.e. MyAggregator) + ) + + +if __name__ == "__main__": + main() diff --git a/test/test_star_pattern_dp.py b/test/test_star_pattern_dp.py new file mode 100644 index 0000000..a2a5aa5 --- /dev/null +++ b/test/test_star_pattern_dp.py @@ -0,0 +1,46 @@ +from typing import Any, Optional +from flame.star import StarAnalyzer, StarAggregator +from flame.star.star_model_tester import StarModelTester + + +class MyAnalyzer(StarAnalyzer): + def __init__(self, flame): + super().__init__(flame) + + def analysis_method(self, data, aggregator_results): + self.flame.flame_log(f"\tAggregator results in MyAnalyzer: {aggregator_results}", log_type='debug') + analysis_result = sum(data) / len(data) \ + if aggregator_results is None \ + else (sum(data) / len(data) + aggregator_results) + 1 / 2 + self.flame.flame_log(f"MyAnalysis result ({self.id}): {analysis_result}", log_type='notice') + return analysis_result + + +class MyAggregator(StarAggregator): + def __init__(self, flame): + super().__init__(flame) + + def aggregation_method(self, analysis_results: list[Any]) -> Any: + self.flame.flame_log(f"\tAnalysis results in MyAggregator: {analysis_results}", log_type='notice') + result = sum(analysis_results) / len(analysis_results) + self.flame.flame_log(f"MyAggregator result ({self.id}): {result}", log_type='notice') + return result + + def has_converged(self, result: Any, last_result: Optional[Any]) -> bool: + self.flame.flame_log(f"\tLast result: {last_result}, Current result: {result}", log_type="notice") + self.flame.flame_log(f"\tChecking convergence at iteration {self.num_iterations}", log_type="notice") + return self.num_iterations >= 5 # Limit to 5 iterations for testing + + +if __name__ == "__main__": + data_1 = [1, 2, 3, 4] + data_2 = [5, 6, 7, 8] + data_splits = [data_1, data_2] + + StarModelTester(data_splits=data_splits, # TODO: Insert your data fragments in a list + analyzer=MyAnalyzer, # TODO: Replace with your custom Analyzer class + aggregator=MyAggregator, # TODO: Replace with your custom Aggregator class + data_type='s3', # TODO: Specify data type ('fhir' or 's3') + simple_analysis=False, + epsilon=1, + sensitivity=10**0)