diff --git a/openwisp_controller/config/apps.py b/openwisp_controller/config/apps.py index 973b46c63..fd6dfb6d0 100644 --- a/openwisp_controller/config/apps.py +++ b/openwisp_controller/config/apps.py @@ -61,6 +61,9 @@ def __setmodels__(self): self.org_limits = load_model("config", "OrganizationLimits") self.cert_model = load_model("django_x509", "Cert") self.org_model = load_model("openwisp_users", "Organization") + self.organization_config_settings_model = load_model( + "config", "OrganizationConfigSettings" + ) def connect_signals(self): """ diff --git a/openwisp_controller/config/base/multitenancy.py b/openwisp_controller/config/base/multitenancy.py index 9cdf15736..da976be96 100644 --- a/openwisp_controller/config/base/multitenancy.py +++ b/openwisp_controller/config/base/multitenancy.py @@ -2,14 +2,14 @@ from copy import deepcopy import swapper -from django.db import models +from django.db import models, transaction from django.utils.translation import gettext_lazy as _ from jsonfield import JSONField from openwisp_utils.base import KeyField, UUIDModel +from .. import tasks from ..exceptions import OrganizationDeviceLimitExceeded -from ..tasks import bulk_invalidate_config_get_cached_checksum class AbstractOrganizationConfigSettings(UUIDModel): @@ -47,6 +47,13 @@ class Meta: verbose_name_plural = verbose_name abstract = True + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "context" in self.get_deferred_fields(): + self._initial_context = models.DEFERRED + else: + self._initial_context = deepcopy(self.context) + def __str__(self): return self.organization.name @@ -57,14 +64,32 @@ def save( self, force_insert=False, force_update=False, using=None, update_fields=None ): context_changed = False - if not self._state.adding: - db_instance = self.__class__.objects.only("context").get(id=self.id) - context_changed = db_instance.context != self.context + context_in_update = update_fields is None or "context" in update_fields + if not self._state.adding and context_in_update: + initial_context = getattr(self, "_initial_context", None) + if initial_context is not None and initial_context != models.DEFERRED: + context_changed = initial_context != self.context + elif initial_context == models.DEFERRED and context_in_update: + # Conservative: if we don't know initial state and context is + # being updated, assume it changed to avoid stale cache + context_changed = True super().save(force_insert, force_update, using, update_fields) - if context_changed: - bulk_invalidate_config_get_cached_checksum.delay( - {"device__organization_id": str(self.organization_id)} + if context_changed and self.organization.is_active: + organization_id = str(self.organization_id) + transaction.on_commit( + lambda: ( + tasks.bulk_invalidate_config_get_cached_checksum.delay( + {"device__organization_id": organization_id} + ), + tasks.invalidate_organization_vpn_cache.delay(organization_id), + ), + using=using, ) + if context_in_update: + if "context" in self.get_deferred_fields(): + self._initial_context = models.DEFERRED + else: + self._initial_context = deepcopy(self.context) class AbstractOrganizationLimits(models.Model): diff --git a/openwisp_controller/config/handlers.py b/openwisp_controller/config/handlers.py index d838d1ad3..66c1f8595 100644 --- a/openwisp_controller/config/handlers.py +++ b/openwisp_controller/config/handlers.py @@ -141,6 +141,7 @@ def devicegroup_templates_change_handler(instance, **kwargs): def organization_disabled_handler(instance, **kwargs): """ Asynchronously invalidates device and VPN controller views cache + when organization becomes inactive """ if instance.is_active: return @@ -149,6 +150,5 @@ def organization_disabled_handler(instance, **kwargs): except Organization.DoesNotExist: return if instance.is_active == db_instance.is_active: - # No change in is_active return tasks.invalidate_controller_views_cache.delay(str(instance.id)) diff --git a/openwisp_controller/config/tasks.py b/openwisp_controller/config/tasks.py index 40aa0ff33..9a9363185 100644 --- a/openwisp_controller/config/tasks.py +++ b/openwisp_controller/config/tasks.py @@ -100,6 +100,28 @@ def invalidate_vpn_server_devices_cache_change(vpn_pk): VpnClient.invalidate_clients_cache(vpn) +@shared_task(soft_time_limit=7200) +def invalidate_organization_vpn_cache(organization_id): + """ + Invalidates VPN cache for all VPNs in an organization when + organization configuration variables change. + """ + Vpn = load_model("config", "Vpn") + from .controller.views import GetVpnView + + try: + for vpn in ( + Vpn.objects.filter(organization_id=organization_id).only("id").iterator() + ): + GetVpnView.invalidate_get_vpn_cache(vpn) + vpn.invalidate_checksum_cache() + except SoftTimeLimitExceeded: + logger.exception( + "soft time limit hit while executing " + f"invalidate_organization_vpn_cache for organization {organization_id}" + ) + + @shared_task(soft_time_limit=7200) def invalidate_devicegroup_cache_delete(instance_id, model_name, **kwargs): from .api.views import DeviceGroupCommonName diff --git a/openwisp_controller/config/tests/test_handlers.py b/openwisp_controller/config/tests/test_handlers.py index d55ad30bf..191ba2549 100644 --- a/openwisp_controller/config/tests/test_handlers.py +++ b/openwisp_controller/config/tests/test_handlers.py @@ -13,25 +13,83 @@ def test_organization_disabled_handler(self, mocked_task): with self.subTest("Test task not executed on creating active orgs"): org = self._create_org() mocked_task.assert_not_called() - with self.subTest("Test task executed on changing active to inactive org"): org.is_active = False org.save() mocked_task.assert_called_once() - - mocked_task.reset_mock() - with self.subTest("Test task not executed on saving inactive org"): - org.name = "Changed named" - org.save() - mocked_task.assert_not_called() - - with self.subTest("Test task not executed on creating inactive org"): - inactive_org = self._create_org( - is_active=False, name="inactive", slug="inactive" - ) - mocked_task.assert_not_called() - with self.subTest("Test task not executed on changing inactive to active org"): + mocked_task.reset_mock() + inactive_org = self._create_org(is_active=False) + mocked_task.assert_not_called() inactive_org.is_active = True inactive_org.save() mocked_task.assert_not_called() + + +class TestOrganizationConfigSettingsVpnCacheInvalidation( + TestOrganizationMixin, TestCase +): + def _get_org_config_settings(self, org=None): + if not org: + org = self._create_org() + # Import the model directly to avoid issues with related manager + from openwisp_controller.config.models import OrganizationConfigSettings + + config_settings, _ = OrganizationConfigSettings.objects.get_or_create( + organization=org, defaults={"context": {}} + ) + return config_settings + + @patch.object(tasks.invalidate_organization_vpn_cache, "delay") + @patch.object(tasks.bulk_invalidate_config_get_cached_checksum, "delay") + def test_vpn_cache_invalidated_on_context_change( + self, config_cache_mock, vpn_cache_mock + ): + """Test VPN cache invalidation when context changes""" + config_settings = self._get_org_config_settings() + config_settings.context = {"new": "context"} + with self.captureOnCommitCallbacks(execute=True): + config_settings.save() + vpn_cache_mock.assert_called_once_with(str(config_settings.organization_id)) + config_cache_mock.assert_called_once_with( + {"device__organization_id": str(config_settings.organization_id)} + ) + + @patch.object(tasks.invalidate_organization_vpn_cache, "delay") + @patch.object(tasks.bulk_invalidate_config_get_cached_checksum, "delay") + def test_no_cache_invalidation_on_create(self, config_cache_mock, vpn_cache_mock): + """Test no VPN cache invalidation on object creation""" + with self.captureOnCommitCallbacks(execute=True): + self._get_org_config_settings() + vpn_cache_mock.assert_not_called() + config_cache_mock.assert_not_called() + + @patch.object(tasks.invalidate_organization_vpn_cache, "delay") + @patch.object(tasks.bulk_invalidate_config_get_cached_checksum, "delay") + def test_no_cache_invalidation_for_inactive_org( + self, config_cache_mock, vpn_cache_mock + ): + """Test no VPN cache invalidation for inactive organizations""" + inactive_org = self._create_org(is_active=False) + config_settings = self._get_org_config_settings(inactive_org) + config_settings.context = {"new": "context"} + with self.captureOnCommitCallbacks(execute=True): + config_settings.save() + vpn_cache_mock.assert_not_called() + config_cache_mock.assert_not_called() + + @patch.object(tasks.invalidate_organization_vpn_cache, "delay") + @patch.object(tasks.bulk_invalidate_config_get_cached_checksum, "delay") + def test_no_cache_invalidation_if_context_unchanged( + self, config_cache_mock, vpn_cache_mock + ): + """Test no VPN cache invalidation when context is unchanged""" + config_settings = self._get_org_config_settings() + original_context = config_settings.context + config_settings.registration_enabled = False + with self.captureOnCommitCallbacks(execute=True): + config_settings.save() + vpn_cache_mock.assert_not_called() + config_cache_mock.assert_not_called() + # Verify context actually didn't change + self.assertEqual(config_settings.context, original_context)