diff --git a/src/edge_proxy/environments.py b/src/edge_proxy/environments.py index 542f9fd..8127a55 100644 --- a/src/edge_proxy/environments.py +++ b/src/edge_proxy/environments.py @@ -26,6 +26,8 @@ logger = structlog.get_logger(__name__) +SERVER_API_KEY_PREFIX = "ser." + class EnvironmentService: def __init__( @@ -77,11 +79,12 @@ def get_flags_response_data( ) -> dict[str, typing.Any]: environment_document = self.get_environment(environment_key) environment = EnvironmentModel.model_validate(environment_document) + is_server_key = environment_key.startswith(SERVER_API_KEY_PREFIX) if feature: feature_state = get_environment_feature_state(environment, feature) - if not filter_out_server_key_only_feature_states( + if not is_server_key and not filter_out_server_key_only_feature_states( feature_states=[feature_state], environment=environment, ): @@ -90,10 +93,12 @@ def get_flags_response_data( data = map_feature_state_to_response_data(feature_state) else: - feature_states = filter_out_server_key_only_feature_states( - feature_states=get_environment_feature_states(environment), - environment=environment, - ) + feature_states = get_environment_feature_states(environment) + if not is_server_key: + feature_states = filter_out_server_key_only_feature_states( + feature_states=feature_states, + environment=environment, + ) data = map_feature_states_to_response_data(feature_states) return data @@ -103,6 +108,8 @@ def get_identity_response_data( ) -> dict[str, typing.Any]: environment_document = self.get_environment(environment_key) environment = EnvironmentModel.model_validate(environment_document) + is_server_key = environment_key.startswith(SERVER_API_KEY_PREFIX) + identity = IdentityModel.model_validate( self.cache.get_identity( environment_api_key=environment_key, @@ -110,14 +117,17 @@ def get_identity_response_data( ) ) trait_models = input_data.traits - flags = filter_out_server_key_only_feature_states( - feature_states=get_identity_feature_states( - environment, - identity, - override_traits=trait_models, - ), - environment=environment, + flags = get_identity_feature_states( + environment, + identity, + override_traits=trait_models, ) + + if not is_server_key: + flags = filter_out_server_key_only_feature_states( + feature_states=flags, + environment=environment, + ) data = { "traits": map_traits_to_response_data(trait_models), "flags": map_feature_states_to_response_data( diff --git a/tests/test_environments.py b/tests/test_environments.py index 21c03ad..e9bbc88 100644 --- a/tests/test_environments.py +++ b/tests/test_environments.py @@ -9,7 +9,10 @@ from pytest_mock import MockerFixture from edge_proxy.environments import EnvironmentService -from edge_proxy.exceptions import FlagsmithUnknownKeyError +from edge_proxy.exceptions import ( + FeatureNotFoundError, + FlagsmithUnknownKeyError, +) from edge_proxy.models import IdentityWithTraits from edge_proxy.settings import ( EndpointCacheSettings, @@ -230,3 +233,101 @@ async def test_get_identity_flags_response_skips_cache_for_different_identity( assert environment_service.get_identity_response_data.cache_info().currsize == 2 assert environment_service.get_identity_response_data.cache_info().misses == 2 assert environment_service.get_identity_response_data.cache_info().hits == 0 + + +@pytest.mark.asyncio +async def test_get_flags_response_data_skips_filter_for_server_key( + mocker: MockerFixture, +) -> None: + # Given + # We create a new settings object that contains a server key as a client_side_key + api_key = "ser." + environment_1_api_key + _settings = AppSettings( + environment_key_pairs=[ + {"client_side_key": api_key, "server_side_key": "ser.key"} + ] + ) + + mocked_client = mocker.AsyncMock() + mocked_client.get.return_value = mocker.MagicMock( + text=orjson.dumps(environment_1), raise_for_status=lambda: None + ) + + environment_service = EnvironmentService(settings=_settings, client=mocked_client) + await environment_service.refresh_environment_caches() + + # When + # We retrieve the flag response data + flags = environment_service.get_flags_response_data(api_key) + specific_flag = environment_service.get_flags_response_data(api_key, "feature_3") + + # Then + # we get the server-side only flag + assert len(flags) == 3 + assert flags[2].get("feature").get("name") == "feature_3" + assert specific_flag.get("feature").get("name") == "feature_3" + + +@pytest.mark.asyncio +async def test_get_flags_response_data_filters_server_side_features_for_client_key( + mocker: MockerFixture, +) -> None: + # Given + # We create a new settings object that contains a client side key + _settings = AppSettings( + environment_key_pairs=[ + {"client_side_key": environment_1_api_key, "server_side_key": "ser.key"} + ] + ) + + mocked_client = mocker.AsyncMock() + mocked_client.get.return_value = mocker.MagicMock( + text=orjson.dumps(environment_1), raise_for_status=lambda: None + ) + + environment_service = EnvironmentService(settings=_settings, client=mocked_client) + await environment_service.refresh_environment_caches() + + # When + # We retrieve the flag response data + flags = environment_service.get_flags_response_data(environment_1_api_key) + with pytest.raises(FeatureNotFoundError): + environment_service.get_flags_response_data(environment_1_api_key, "feature_3") + + # Then + # we only get the two client side flags + assert len(flags) == 2 + + +@pytest.mark.asyncio +async def test_get_identity_flags_response_skips_filter_for_server_key( + mocker: MockerFixture, +) -> None: + # Given + # We create a new settings object that contains a server key as a client_side_key + api_key = "ser." + environment_1_api_key + _settings = AppSettings( + environment_key_pairs=[ + {"client_side_key": api_key, "server_side_key": "ser.key"} + ] + ) + + mocked_client = mocker.AsyncMock() + mocked_client.get.return_value = mocker.MagicMock( + text=orjson.dumps(environment_1), raise_for_status=lambda: None + ) + + environment_service = EnvironmentService(settings=_settings, client=mocked_client) + await environment_service.refresh_environment_caches() + + # When + # We retrieve the flags for an identity + result = environment_service.get_identity_response_data( + IdentityWithTraits(identifier="foo"), api_key + ) + + # Then + # we get the server-side only flag + flags = result.get("flags") + assert len(flags) == 3 + assert flags[2].get("feature").get("name") == "feature_3"