diff --git a/vulnerabilities/api.py b/vulnerabilities/api.py index 7b9d383ce..8eb02ca77 100644 --- a/vulnerabilities/api.py +++ b/vulnerabilities/api.py @@ -22,17 +22,27 @@ # Visit https://github.com/nexB/vulnerablecode/ for support and download. from urllib.parse import unquote +from typing import List +from django.db.models import Q from django.urls import reverse from django_filters import rest_framework as filters from packageurl import PackageURL from rest_framework import serializers from rest_framework import viewsets +from rest_framework.decorators import action +from rest_framework.response import Response +from drf_spectacular.utils import extend_schema, inline_serializer +from drf_spectacular.types import OpenApiTypes from vulnerabilities.models import Package from vulnerabilities.models import Vulnerability from vulnerabilities.models import VulnerabilityReference +# This serializer is used for the bulk apis, to prevent wrong auto documentation +# TODO: Fix the swagger documentation for bulk apis +placeholder_serializer = inline_serializer(name="Placeholder", fields={}) + class VulnerabilityReferenceSerializer(serializers.ModelSerializer): class Meta: @@ -60,8 +70,8 @@ class Meta: fields = ["url", "vulnerability_id"] -class VulnerabilitySerializer(serializers.HyperlinkedModelSerializer): - references = VulnerabilityReferenceSerializer(many=True, source="vulnerabilityreference_set") +class MinimalVulnerabilitySerializer(serializers.HyperlinkedModelSerializer): + resolved_packages = HyperLinkedPackageSerializer( many=True, source="resolved_to", read_only=True ) @@ -69,36 +79,43 @@ class VulnerabilitySerializer(serializers.HyperlinkedModelSerializer): many=True, source="vulnerable_to", read_only=True ) + class Meta: + model = Vulnerability + fields = ["url", "unresolved_packages", "resolved_packages"] + + +class VulnerabilitySerializer(MinimalVulnerabilitySerializer): + references = VulnerabilityReferenceSerializer(many=True, source="vulnerabilityreference_set") + class Meta: model = Vulnerability fields = "__all__" -class PackageSerializer(serializers.HyperlinkedModelSerializer): +class MinimalPackageSerializer(serializers.HyperlinkedModelSerializer): unresolved_vulnerabilities = HyperLinkedVulnerabilitySerializer( many=True, source="vulnerable_to", read_only=True ) resolved_vulnerabilities = HyperLinkedVulnerabilitySerializer( many=True, source="resolved_to", read_only=True ) - purl = serializers.CharField(source="package_url") class Meta: model = Package fields = [ - "url", - "type", - "namespace", - "name", - "version", - "qualifiers", - "subpath", - "purl", "resolved_vulnerabilities", "unresolved_vulnerabilities", ] +class PackageSerializer(MinimalPackageSerializer): + purl = serializers.CharField(source="package_url") + + class Meta: + model = Package + exclude = ["vulnerabilities"] + + class PackageFilterSet(filters.FilterSet): purl = filters.CharFilter(method="filter_purl") @@ -126,6 +143,38 @@ class PackageViewSet(viewsets.ReadOnlyModelViewSet): filter_backends = (filters.DjangoFilterBackend,) filterset_class = PackageFilterSet + # TODO: Fix the swagger documentation for this endpoint + @extend_schema(request=placeholder_serializer, responses=placeholder_serializer) + @action(detail=False, methods=["post"]) + def bulk_search(self, request): + """ + See https://github.com/nexB/vulnerablecode/pull/303#issuecomment-761801639 for docs + """ + filter_list = Q() + response = {} + if not isinstance(request.data.get("packages"), list): + return Response( + status=400, + data={ + "Error": "Request needs to contain a key 'packages' which has the value of a list of package urls" # nopep8 + }, + ) + for purl in request.data["packages"]: + try: + filter_list |= Q( + **{k: v for k, v in PackageURL.from_string(purl).to_dict().items() if v} + ) + except ValueError as ve: + return Response(status=400, data={"Error": str(ve)}) + + # This handles the case when the said purl doesnt exist in db + response[purl] = {} + res = Package.objects.filter(filter_list) + for p in res: + response[p.package_url] = MinimalPackageSerializer(p, context={"request": request}).data + + return Response(response) + class VulnerabilityFilterSet(filters.FilterSet): vulnerability_id = filters.CharFilter(field_name="cve_id") @@ -141,3 +190,31 @@ class VulnerabilityViewSet(viewsets.ReadOnlyModelViewSet): paginate_by = 50 filter_backends = (filters.DjangoFilterBackend,) filterset_class = VulnerabilityFilterSet + + # TODO: Fix the swagger documentation for this endpoint + @extend_schema(request=placeholder_serializer, responses=placeholder_serializer) + @action(detail=False, methods=["post"]) + def bulk_search(self, request): + """ + See https://github.com/nexB/vulnerablecode/pull/303#issuecomment-761801619 for docs + """ + filter_list = [] + response = {} + if not isinstance(request.data.get("vulnerabilities"), list): + return Response( + status=400, + data={ + "Error": "Request needs to contain a key 'vulnerabilities' which has the value of a list of vulnerability ids" # nopep8 + }, + ) + + for cve_id in request.data["vulnerabilities"]: + filter_list.append(cve_id) + # This handles the case when the said cve doesnt exist in db + response[cve_id] = {} + res = Vulnerability.objects.filter(cve_id__in=filter_list) + for vuln in res: + response[vuln.cve_id] = MinimalVulnerabilitySerializer( + vuln, context={"request": request} + ).data + return Response(response) diff --git a/vulnerabilities/fixtures/debian.json b/vulnerabilities/fixtures/debian.json index 35a160128..27fdd0e72 100644 --- a/vulnerabilities/fixtures/debian.json +++ b/vulnerabilities/fixtures/debian.json @@ -73,7 +73,7 @@ }, { "model": "vulnerabilities.packagerelatedvulnerability", - "pk": 1, + "pk": 10, "fields": { "vulnerability": 2, "package": 2, diff --git a/vulnerabilities/tests/test_api.py b/vulnerabilities/tests/test_api.py index 967830718..092349e42 100644 --- a/vulnerabilities/tests/test_api.py +++ b/vulnerabilities/tests/test_api.py @@ -22,6 +22,7 @@ # Visit https://github.com/nexB/vulnerablecode/ for support and download. import os +from collections import OrderedDict from random import choices from unittest.mock import MagicMock from urllib.parse import quote @@ -31,6 +32,8 @@ from vulnerabilities.api import PackageSerializer from vulnerabilities.models import Package +from rest_framework.test import APIRequestFactory +from rest_framework.test import APIClient BASE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -188,3 +191,148 @@ def test_package_serializer(self): purls = {r["purl"] for r in response} self.assertIn("pkg:deb/debian/mimetex@1.50-1.1?distro=jessie", purls) self.assertIn("pkg:deb/debian/mimetex@1.74-1?distro=jessie", purls) + + +class TestBulkAPIResponse(TestCase): + fixtures = ["debian.json"] + + def test_bulk_vulnerabilities_api(self): + request_body = {"vulnerabilities": ["CVE-2009-1382", "CVE-2014-8242", "RANDOM-CVE"]} + expected_response = { + "CVE-2009-1382": { + "resolved_packages": [ + OrderedDict( + [ + ("url", "http://testserver/api/packages/2/"), + ("purl", "pkg:deb/debian/mimetex@1.74-1?distro=jessie"), + ] + ), + OrderedDict( + [ + ("url", "http://testserver/api/packages/3/"), + ("purl", "pkg:deb/debian/mimetex@1.50-1.1?distro=jessie"), + ] + ), + ], + "unresolved_packages": [], + "url": "http://testserver/api/vulnerabilities/2/", + }, + "CVE-2014-8242": { + "resolved_packages": [], + "unresolved_packages": [ + OrderedDict( + [ + ("url", "http://testserver/api/packages/1/"), + ("purl", "pkg:deb/debian/librsync@0.9.7-10?distro=jessie"), + ] + ) + ], + "url": "http://testserver/api/vulnerabilities/1/", + }, + "RANDOM-CVE": {}, + } + + response = self.client.post( + "/api/vulnerabilities/bulk_search/", data=request_body, content_type="application/json" + ).data + assert response == expected_response + + def test_bulk_packages_api(self): + request_body = { + "packages": [ + "pkg:deb/debian/librsync@0.9.7-10?distro=jessie", + "pkg:deb/debian/mimetex@1.50-1.1?distro=jessie", + ] + } + response = self.client.post( + "/api/packages/bulk_search/", data=request_body, content_type="application/json" + ).data + expected_response = { + "pkg:deb/debian/librsync@0.9.7-10?distro=jessie": { + "resolved_vulnerabilities": [], + "unresolved_vulnerabilities": [ + OrderedDict( + [ + ("url", "http://testserver/api/vulnerabilities/1/"), + ("vulnerability_id", "CVE-2014-8242"), + ] + ) + ], + }, + "pkg:deb/debian/mimetex@1.50-1.1?distro=jessie": { + "resolved_vulnerabilities": [ + OrderedDict( + [ + ("url", "http://testserver/api/vulnerabilities/2/"), + ("vulnerability_id", "CVE-2009-1382"), + ] + ), + OrderedDict( + [ + ("url", "http://testserver/api/vulnerabilities/3/"), + ("vulnerability_id", "CVE-2009-2459"), + ] + ), + ], + "unresolved_vulnerabilities": [], + }, + } + + assert response == expected_response + + def test_invalid_request_bulk_packages(self): + error_response = { + "Error": "Request needs to contain a key 'packages' which has the value of a list of package urls" # nopep8 + } + invalid_key_request_data = {"pkg": []} + response = self.client.post( + "/api/packages/bulk_search/", + data=invalid_key_request_data, + content_type="application/json", + ).data + assert response == error_response + + valid_key_invalid_datatype_request_data = {"packages": {}} + response = self.client.post( + "/api/packages/bulk_search/", + data=valid_key_invalid_datatype_request_data, + content_type="application/json", + ).data + assert response == error_response + + invalid_purl_request_data = { + "packages": [ + "pkg:deb/debian/librsync@0.9.7-10?distro=jessie", + "pg:deb/debian/mimetex@1.50-1.1?distro=jessie", + ] + } + response = self.client.post( + "/api/packages/bulk_search/", + data=invalid_purl_request_data, + content_type="application/json", + ).data + purl_error_respones = { + "Error": "purl is missing the required \"pkg\" scheme component: 'pg:deb/debian/mimetex@1.50-1.1?distro=jessie'." # nopep8 + } + assert response == purl_error_respones + + def test_invalid_request_bulk_vulnerabilities(self): + error_response = { + "Error": "Request needs to contain a key 'vulnerabilities' which has the value of a list of vulnerability ids" # nopep8 + } + + wrong_key_data = {"xyz": []} + response = self.client.post( + "/api/vulnerabilities/bulk_search/", + data=wrong_key_data, + content_type="application/json", + ).data + assert response == error_response + + wrong_type_data = {"vulnerabilities": {}} + response = self.client.post( + "/api/vulnerabilities/bulk_search/", + data=wrong_key_data, + content_type="application/json", + ).data + assert response == error_response