Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion api_admin/settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import logging

from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.db import transaction
Expand All @@ -11,13 +13,18 @@
validate_go_inception_payload,
)
from common.config import SysConfig
from sql.inventory import (
INVENTORY_REFRESH_INTERVAL_CHOICES,
ensure_inventory_refresh_schedule,
)
from sql.engines.mysql_ddl import validate_binary_path
from sql.models import InstanceTag, ResourceGroup

from api_core.permissions import IsStaffOrSuperuser
from api_core.response import success_response

User = get_user_model()
logger = logging.getLogger("default")

DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"
DEFAULT_QUERY_TEMPLATE = (
Expand All @@ -38,6 +45,7 @@
STORAGE_TYPE_OPTIONS = ("local", "sftp", "s3c", "azure")
SMS_PROVIDER_OPTIONS = ("disabled", "aliyun", "tencent")
TASK_BACKEND_OPTIONS = ("django_q", "celery")
INVENTORY_REFRESH_INTERVAL_OPTIONS = INVENTORY_REFRESH_INTERVAL_CHOICES

SYSTEM_SETTINGS_SCHEMA = (
{"name": "go_inception_host", "kind": "string", "default": ""},
Expand Down Expand Up @@ -77,6 +85,12 @@
"choices": TASK_BACKEND_OPTIONS,
"default": "django_q",
},
{
"name": "inventory_refresh_interval",
"kind": "choice",
"choices": INVENTORY_REFRESH_INTERVAL_OPTIONS,
"default": "24h",
},
{"name": "celery_broker_url", "kind": "string", "default": ""},
{"name": "celery_result_backend", "kind": "string", "default": ""},
{
Expand Down Expand Up @@ -281,9 +295,22 @@ def build_system_settings_options():
}
for backend in TASK_BACKEND_OPTIONS
],
"inventory_refresh_intervals": [
{"value": interval, "label": interval}
for interval in INVENTORY_REFRESH_INTERVAL_OPTIONS
],
}


def sync_inventory_refresh_schedule(force=False):
try:
ensure_inventory_refresh_schedule(force=force)
return True
except Exception as exc:
logger.exception("Failed to synchronize the inventory refresh schedule.")
return False


class SystemSettingsSerializer(serializers.Serializer):
def get_fields(self):
fields = {}
Expand Down Expand Up @@ -480,6 +507,7 @@ class SystemSettingsView(views.APIView):
)
def get(self, request):
serializer = SystemSettingsSerializer(instance=load_system_settings())
sync_inventory_refresh_schedule()
return success_response(
data={
"settings": serializer.data,
Expand All @@ -497,12 +525,20 @@ def put(self, request):
serializer = SystemSettingsSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
save_system_settings(serializer.validated_data)
schedule_synced = sync_inventory_refresh_schedule(force=True)
response_serializer = SystemSettingsSerializer(instance=load_system_settings())
detail = "System settings updated successfully."
if not schedule_synced:
detail = (
"System settings updated, but the inventory refresh schedule "
"could not be synchronized. Check the task backend and try again."
)
return success_response(
detail="System settings updated successfully.",
detail=detail,
data={
"settings": response_serializer.data,
"options": build_system_settings_options(),
"inventory_refresh_schedule_synced": schedule_synced,
},
)

Expand Down
100 changes: 90 additions & 10 deletions api_core/legacy_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from sql.engines import ReviewSet
from sql.engines.mysql_ddl import MysqlDDLExecutorError
from sql.engines.models import ReviewResult, ResultSet
from api_admin.settings import DEFAULT_CHAT_MODEL, NOTIFY_PHASE_OPTIONS
from api_admin.settings import (
DEFAULT_CHAT_MODEL,
INVENTORY_REFRESH_INTERVAL_OPTIONS,
NOTIFY_PHASE_OPTIONS,
)
from sql.models import (
ResourceGroup,
Instance,
Expand Down Expand Up @@ -112,6 +116,48 @@ def test_get_system_settings_includes_task_backend_options(self):
{"value": "celery", "label": "Celery"},
payload["data"]["options"]["task_backends"],
)
self.assertEqual(
payload["data"]["settings"]["inventory_refresh_interval"], "24h"
)
self.assertEqual(
payload["data"]["options"]["inventory_refresh_intervals"],
[
{"value": interval, "label": interval}
for interval in INVENTORY_REFRESH_INTERVAL_OPTIONS
],
)

def test_put_system_settings_saves_inventory_refresh_interval(self):
response = self.client.put(
"/api/v1/system-settings/",
data=json.dumps({"inventory_refresh_interval": "6h"}),
content_type="application/json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.json()["data"]["settings"]["inventory_refresh_interval"], "6h"
)
self.assertTrue(response.json()["data"]["inventory_refresh_schedule_synced"])
self.assertEqual(self.sys_config.get("inventory_refresh_interval"), "6h")

@patch("api_admin.settings.sync_inventory_refresh_schedule", return_value=False)
def test_put_system_settings_surfaces_inventory_schedule_sync_warning(
self, _mock_sync
):
response = self.client.put(
"/api/v1/system-settings/",
data=json.dumps({"inventory_refresh_interval": "12h"}),
content_type="application/json",
)

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(
response.json()["detail"],
"System settings updated, but the inventory refresh schedule could not be synchronized. Check the task backend and try again.",
)
self.assertFalse(response.json()["data"]["inventory_refresh_schedule_synced"])
self.assertEqual(self.sys_config.get("inventory_refresh_interval"), "12h")

def test_put_system_settings_requires_broker_url_for_celery(self):
response = self.client.put(
Expand Down Expand Up @@ -1483,6 +1529,29 @@ def test_get_instance_list(self):
self.assertEqual(r.status_code, status.HTTP_200_OK)
self.assertEqual(response_data(r)["count"], 1)

def test_get_instance_list_includes_inventory_snapshot_fields(self):
self.ins.inventory_status = Instance.INVENTORY_STATUS_OK
self.ins.inventory_detected_hostname = "detected-host"
self.ins.inventory_detected_version = "8.0.36"
self.ins.inventory_last_success_at = datetime.now()
self.ins.save(
update_fields=[
"inventory_status",
"inventory_detected_hostname",
"inventory_detected_version",
"inventory_last_success_at",
]
)

r = self.client.get("/api/v1/instance/", format="json")

self.assertEqual(r.status_code, status.HTTP_200_OK)
payload = response_data(r)["results"][0]
self.assertEqual(payload["inventory_status"], "ok")
self.assertEqual(payload["inventory_detected_hostname"], "detected-host")
self.assertEqual(payload["inventory_detected_version"], "8.0.36")
self.assertIsNotNone(payload["inventory_last_refresh_at"])

def test_get_instance_list_with_search_and_filters(self):
"""Search and filters should match legacy inventory behavior."""
read_tag = InstanceTag.objects.create(
Expand Down Expand Up @@ -2013,7 +2082,7 @@ def test_delete_instance(self):
self.assertEqual(Instance.objects.filter(instance_name="some_ins").count(), 0)


class TestPermissionRequestAPI(CacheIsolatedAPITestCase):
class TestPermissionRequestAPI_Legacy(CacheIsolatedAPITestCase):
def setUp(self):
self.review_group = Group.objects.create(name="Permission Approvers")
self.resource_group = ResourceGroup.objects.create(group_name="permission-rg")
Expand Down Expand Up @@ -2241,38 +2310,49 @@ def test_query_instance_list_includes_temporary_instance_grant(self):

def test_test_instance_connection_requires_superuser(self):
"""Connection testing stays restricted to superusers."""
self._login(self.requester)
r = self.client.post(
f"/api/v1/instance/{self.ins.id}/test-connection/",
f"/api/v1/instance/{self.instance.id}/test-connection/",
format="json",
)
self.assertEqual(r.status_code, status.HTTP_403_FORBIDDEN)

@patch("api_instances.views.get_engine")
@patch("sql.inventory.get_engine")
def test_test_instance_connection(self, mock_get_engine):
"""Superusers can run the SPA connection test action."""
self.user.is_superuser = True
self.user.save(update_fields=["is_superuser"])
self.requester.is_superuser = True
self.requester.save(update_fields=["is_superuser"])
self._login(self.requester)

mock_engine = Mock()
mock_result = Mock(error="")
mock_engine.test_connection.return_value = mock_result
mock_engine.get_inventory_details.return_value = {
"hostname": "detected-host",
"version": "8.0.36",
}
mock_get_engine.return_value = mock_engine

r = self.client.post(
f"/api/v1/instance/{self.ins.id}/test-connection/",
f"/api/v1/instance/{self.instance.id}/test-connection/",
format="json",
)
self.assertEqual(r.status_code, status.HTTP_200_OK)
payload = response_data(r)
self.assertEqual(payload["success"], True)
self.assertEqual(payload["message"], "Connection successful.")
self.instance.refresh_from_db()
self.assertEqual(self.instance.inventory_status, "ok")
self.assertEqual(self.instance.inventory_detected_hostname, "detected-host")
self.assertEqual(self.instance.inventory_detected_version, "8.0.36")

@patch("api_instances.views.get_engine")
def test_get_instance_resource(self, mock_get_engine):
"""Test querying instance resources."""
group = ResourceGroup.objects.create(group_name="instance_resource_test")
self.user.resource_group.add(group)
self.ins.resource_group.add(group)
self.query_user.resource_group.add(group)
self.instance.resource_group.add(group)
self._login(self.query_user)

mock_engine = Mock()
mock_engine.escape_string.side_effect = lambda x: x
Expand All @@ -2285,7 +2365,7 @@ def test_get_instance_resource(self, mock_get_engine):

r = self.client.get(
"/api/v1/instance/resource/",
{"instance_id": self.ins.id, "resource_type": "database"},
{"instance_id": self.instance.id, "resource_type": "database"},
format="json",
)
self.assertEqual(r.status_code, status.HTTP_200_OK)
Expand Down
7 changes: 7 additions & 0 deletions api_instances/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,9 @@ class InstanceDiagnosticKillResultSerializer(serializers.Serializer):
class InstanceListSerializer(serializers.ModelSerializer):
resource_group_ids = serializers.SerializerMethodField()
instance_tag_ids = serializers.SerializerMethodField()
inventory_last_refresh_at = serializers.DateTimeField(
source="inventory_last_success_at", read_only=True
)

def get_resource_group_ids(self, obj):
return list(
Expand Down Expand Up @@ -305,6 +308,10 @@ class Meta:
"sid",
"resource_group_ids",
"instance_tag_ids",
"inventory_status",
"inventory_detected_hostname",
"inventory_detected_version",
"inventory_last_refresh_at",
)


Expand Down
25 changes: 21 additions & 4 deletions api_instances/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from rest_framework.response import Response

from sql.engines import ResultSet, engine_map, get_engine
from sql.inventory import (
ensure_inventory_refresh_schedule,
refresh_instance_inventory_snapshot,
)
from sql.models import (
Instance,
InstanceAccount,
Expand Down Expand Up @@ -506,6 +510,8 @@ def get_queryset(self):
Q(instance_name__icontains=search)
| Q(host__icontains=search)
| Q(user__icontains=search)
| Q(inventory_detected_hostname__icontains=search)
| Q(inventory_detected_version__icontains=search)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
)
if search.isdigit():
search_filter |= Q(id=int(search))
Expand Down Expand Up @@ -537,6 +543,14 @@ def get_queryset(self):
"-user",
"type",
"-type",
"inventory_status",
"-inventory_status",
"inventory_detected_hostname",
"-inventory_detected_hostname",
"inventory_detected_version",
"-inventory_detected_version",
"inventory_last_success_at",
"-inventory_last_success_at",
}
if ordering in allowed_ordering:
queryset = queryset.order_by(ordering, "id")
Expand All @@ -551,7 +565,7 @@ def get_queryset(self):
name="search",
type=OpenApiTypes.STR,
location=OpenApiParameter.QUERY,
description="Match instance ID, name, host, or user.",
description="Match instance ID, name, host, user, detected hostname, or detected version.",
),
OpenApiParameter(
name="type",
Expand Down Expand Up @@ -584,6 +598,10 @@ def get_queryset(self):
permission_required("sql.menu_instance_list", raise_exception=True)
)
def get(self, request):
try:
ensure_inventory_refresh_schedule()
except Exception:
logger.exception("Failed to ensure the inventory refresh schedule.")
instances = self.filter_queryset(self.get_queryset())
page_ins = self.paginate_queryset(queryset=instances)
serializer_obj = self.get_serializer(page_ins, many=True)
Expand Down Expand Up @@ -2362,8 +2380,7 @@ def post(self, request, pk):
raise Http404

try:
query_engine = get_engine(instance=instance)
test_result = query_engine.test_connection()
snapshot_result = refresh_instance_inventory_snapshot(instance=instance)
except serializers.ValidationError:
raise
except Exception:
Expand All @@ -2372,7 +2389,7 @@ def post(self, request, pk):
{"errors": "Unable to connect to instance. Check configuration."}
)

if test_result.error:
if not snapshot_result["success"]:
raise serializers.ValidationError(
{"errors": "Unable to connect to instance. Check configuration."}
)
Expand Down
Loading
Loading