diff --git a/changes/44391-osv-vuln-optimizations b/changes/44391-osv-vuln-optimizations new file mode 100644 index 00000000000..ed206fe0775 --- /dev/null +++ b/changes/44391-osv-vuln-optimizations @@ -0,0 +1 @@ +* Optimized OSV vulnerability scanning to query distinct software per OS version rather than per host, reducing redundant database queries for many hosts sharing the same packages. diff --git a/server/datastore/mysql/software.go b/server/datastore/mysql/software.go index b6a59b83a40..b10f3de7492 100644 --- a/server/datastore/mysql/software.go +++ b/server/datastore/mysql/software.go @@ -3309,6 +3309,88 @@ func (ds *Datastore) ListSoftwareForVulnDetection( return result, nil } +const softwareVulnDetectionBatchSize = 10000 + +func (ds *Datastore) ListSoftwareForVulnDetectionByOSVersion( + ctx context.Context, + osVer fleet.OSVersion, +) ([]fleet.Software, error) { + var softwareIDs []uint + err := sqlx.SelectContext(ctx, ds.reader(ctx), &softwareIDs, ` + SELECT DISTINCT hs.software_id + FROM host_software hs + JOIN hosts h ON hs.host_id = h.id + WHERE h.platform = ? AND h.os_version = ? + `, osVer.Platform, osVer.Name) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "listing distinct software IDs for OS version") + } + + if len(softwareIDs) == 0 { + return nil, nil + } + + var result []fleet.Software + if err := common_mysql.BatchProcessSimple(softwareIDs, softwareVulnDetectionBatchSize, func(batch []uint) error { + placeholders := strings.TrimSuffix(strings.Repeat("?,", len(batch)), ",") + query := fmt.Sprintf(` + SELECT s.id, s.name, s.version, s.release, s.arch, COALESCE(cpe.cpe, '') AS generated_cpe + FROM software s + LEFT JOIN software_cpe cpe ON s.id = cpe.software_id + WHERE s.id IN (%s) + `, placeholders) + args := make([]any, len(batch)) + for i, id := range batch { + args[i] = id + } + var batchResult []fleet.Software + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &batchResult, query, args...); err != nil { + return ctxerr.Wrap(ctx, err, "fetching software details for vulnerability detection") + } + result = append(result, batchResult...) + return nil + }); err != nil { + return nil, err + } + + return result, nil +} + +func (ds *Datastore) ListSoftwareVulnerabilitiesBySoftwareIDs( + ctx context.Context, + softwareIDs []uint, + source fleet.VulnerabilitySource, +) ([]fleet.SoftwareVulnerability, error) { + if len(softwareIDs) == 0 { + return nil, nil + } + + var result []fleet.SoftwareVulnerability + if err := common_mysql.BatchProcessSimple(softwareIDs, softwareVulnDetectionBatchSize, func(batch []uint) error { + placeholders := strings.TrimSuffix(strings.Repeat("?,", len(batch)), ",") + query := fmt.Sprintf(` + SELECT software_id, cve, resolved_in_version + FROM software_cve + WHERE source = ? AND software_id IN (%s) + `, placeholders) + args := make([]any, 0, len(batch)+1) + args = append(args, source) + for _, id := range batch { + args = append(args, id) + } + var batchResult []fleet.SoftwareVulnerability + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &batchResult, query, args...); err != nil { + return ctxerr.Wrap(ctx, err, "fetching software vulnerabilities by software IDs") + } + result = append(result, batchResult...) + return nil + }); err != nil { + return nil, err + } + + return result, nil +} + // ListCVEs returns all cve_meta rows published after 'maxAge' func (ds *Datastore) ListCVEs(ctx context.Context, maxAge time.Duration) ([]fleet.CVEMeta, error) { var result []fleet.CVEMeta diff --git a/server/datastore/mysql/software_test.go b/server/datastore/mysql/software_test.go index 6d73a2392e5..fd8ab85f2df 100644 --- a/server/datastore/mysql/software_test.go +++ b/server/datastore/mysql/software_test.go @@ -69,6 +69,8 @@ func TestSoftware(t *testing.T) { {"InsertSoftwareVulnerabilities", testInsertSoftwareVulnerabilities}, {"ListCVEs", testListCVEs}, {"ListSoftwareForVulnDetection", testListSoftwareForVulnDetection}, + {"ListSoftwareForVulnDetectionByOSVersion", testListSoftwareForVulnDetectionByOSVersion}, + {"ListSoftwareVulnerabilitiesBySoftwareIDs", testListSoftwareVulnerabilitiesBySoftwareIDs}, {"AllSoftwareIterator", testAllSoftwareIterator}, {"AllSoftwareIteratorForCustomLinuxImages", testSoftwareIteratorForLinuxKernelCustomImages}, {"UpsertSoftwareCPEs", testUpsertSoftwareCPEs}, @@ -12011,3 +12013,143 @@ func testSoftwareLiteByID(t *testing.T, ds *Datastore) { require.Error(t, err) require.True(t, fleet.IsNotFound(err)) } + +func testListSoftwareForVulnDetectionByOSVersion(t *testing.T, ds *Datastore) { + ctx := context.Background() + + // Create two hosts with the same OS version and overlapping software. + host1 := test.NewHost(t, ds, "osv-host1", "", "osv-host1key", "osv-host1uuid", time.Now()) + host1.Platform = "ubuntu" + host1.OSVersion = "Ubuntu 22.04.1 LTS" + require.NoError(t, ds.UpdateHost(ctx, host1)) + + host2 := test.NewHost(t, ds, "osv-host2", "", "osv-host2key", "osv-host2uuid", time.Now()) + host2.Platform = "ubuntu" + host2.OSVersion = "Ubuntu 22.04.1 LTS" + require.NoError(t, ds.UpdateHost(ctx, host2)) + + // Create a host with a different OS version. + host3 := test.NewHost(t, ds, "osv-host3", "", "osv-host3key", "osv-host3uuid", time.Now()) + host3.Platform = "ubuntu" + host3.OSVersion = "Ubuntu 20.04.1 LTS" + require.NoError(t, ds.UpdateHost(ctx, host3)) + + sharedSoftware := []fleet.Software{ + {Name: "libfoo", Version: "1.2.3", Source: "deb_packages"}, + {Name: "libbar", Version: "4.5.6", Source: "deb_packages"}, + } + _, err := ds.UpdateHostSoftware(ctx, host1.ID, sharedSoftware) + require.NoError(t, err) + + host2Software := []fleet.Software{ + {Name: "libfoo", Version: "1.2.3", Source: "deb_packages"}, // shared with host1 + {Name: "libbaz", Version: "7.8.9", Source: "deb_packages"}, // unique to host2 + } + _, err = ds.UpdateHostSoftware(ctx, host2.ID, host2Software) + require.NoError(t, err) + + host3Software := []fleet.Software{ + {Name: "libother", Version: "0.0.1", Source: "deb_packages"}, + } + _, err = ds.UpdateHostSoftware(ctx, host3.ID, host3Software) + require.NoError(t, err) + + // Query for Ubuntu 22.04.1 LTS — should return 3 distinct software items. + result, err := ds.ListSoftwareForVulnDetectionByOSVersion(ctx, fleet.OSVersion{ + Platform: "ubuntu", + Name: "Ubuntu 22.04.1 LTS", + }) + require.NoError(t, err) + + names := make([]string, len(result)) + for i, sw := range result { + names[i] = sw.Name + } + sort.Strings(names) + require.Equal(t, []string{"libbar", "libbaz", "libfoo"}, names) + + // Verify no duplicates (libfoo exists on both hosts but should appear once). + require.Len(t, result, 3) + + // Query for Ubuntu 20.04.1 LTS — should return only host3's software. + result, err = ds.ListSoftwareForVulnDetectionByOSVersion(ctx, fleet.OSVersion{ + Platform: "ubuntu", + Name: "Ubuntu 20.04.1 LTS", + }) + require.NoError(t, err) + require.Len(t, result, 1) + require.Equal(t, "libother", result[0].Name) + + // Query for nonexistent OS — should return nil. + result, err = ds.ListSoftwareForVulnDetectionByOSVersion(ctx, fleet.OSVersion{ + Platform: "ubuntu", + Name: "Ubuntu 99.99 LTS", + }) + require.NoError(t, err) + require.Nil(t, result) +} + +func testListSoftwareVulnerabilitiesBySoftwareIDs(t *testing.T, ds *Datastore) { + ctx := context.Background() + + // Create some software. + host := test.NewHost(t, ds, "vuln-sw-host", "", "vuln-sw-hostkey", "vuln-sw-hostuuid", time.Now()) + software := []fleet.Software{ + {Name: "pkg-a", Version: "1.0", Source: "deb_packages"}, + {Name: "pkg-b", Version: "2.0", Source: "deb_packages"}, + {Name: "pkg-c", Version: "3.0", Source: "deb_packages"}, + } + _, err := ds.UpdateHostSoftware(ctx, host.ID, software) + require.NoError(t, err) + require.NoError(t, ds.LoadHostSoftware(ctx, host, false)) + + // Look up software by name to avoid depending on unstable ordering. + swByName := make(map[string]fleet.HostSoftwareEntry, len(host.Software)) + for _, sw := range host.Software { + swByName[sw.Name] = sw + } + swA := swByName["pkg-a"] + swB := swByName["pkg-b"] + swC := swByName["pkg-c"] + + // Insert vulns with different sources. + _, err = ds.InsertSoftwareVulnerabilities(ctx, []fleet.SoftwareVulnerability{ + {SoftwareID: swA.ID, CVE: "CVE-2024-0001"}, + {SoftwareID: swA.ID, CVE: "CVE-2024-0002"}, + {SoftwareID: swB.ID, CVE: "CVE-2024-0003"}, + }, fleet.UbuntuOSVSource) + require.NoError(t, err) + + _, err = ds.InsertSoftwareVulnerabilities(ctx, []fleet.SoftwareVulnerability{ + {SoftwareID: swA.ID, CVE: "CVE-2024-9999"}, + {SoftwareID: swC.ID, CVE: "CVE-2024-8888"}, + }, fleet.NVDSource) + require.NoError(t, err) + + // Query OSV source for swA and swB — should return 3 vulns (2 for swA, 1 for swB). + result, err := ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, []uint{swA.ID, swB.ID}, fleet.UbuntuOSVSource) + require.NoError(t, err) + require.Len(t, result, 3) + + cves := make([]string, len(result)) + for i, v := range result { + cves[i] = v.CVE + } + sort.Strings(cves) + require.Equal(t, []string{"CVE-2024-0001", "CVE-2024-0002", "CVE-2024-0003"}, cves) + + // Query OSV source for swC — should return empty (swC's vulns are NVD source). + result, err = ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, []uint{swC.ID}, fleet.UbuntuOSVSource) + require.NoError(t, err) + require.Empty(t, result) + + // Query NVD source for swA and swC — should return 2 vulns. + result, err = ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, []uint{swA.ID, swC.ID}, fleet.NVDSource) + require.NoError(t, err) + require.Len(t, result, 2) + + // Empty software IDs — should return nil. + result, err = ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, []uint{}, fleet.UbuntuOSVSource) + require.NoError(t, err) + require.Nil(t, result) +} diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 07693e4f4cd..3399be9c80a 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -676,7 +676,13 @@ type Datastore interface { // ListSoftwareForVulnDetection returns all software for the given hostID with only the fields // used for vulnerability detection populated (id, name, version, cpe_id, cpe) ListSoftwareForVulnDetection(ctx context.Context, filter VulnSoftwareFilter) ([]Software, error) + // ListSoftwareForVulnDetectionByOSVersion returns all distinct software installed on hosts + // matching the given OS version. + ListSoftwareForVulnDetectionByOSVersion(ctx context.Context, osVer OSVersion) ([]Software, error) ListSoftwareVulnerabilitiesByHostIDsSource(ctx context.Context, hostIDs []uint, source VulnerabilitySource) (map[uint][]SoftwareVulnerability, error) + // ListSoftwareVulnerabilitiesBySoftwareIDs returns vulnerabilities for the given software IDs + // filtered by source. Queries software_cve directly without joining through host_software. + ListSoftwareVulnerabilitiesBySoftwareIDs(ctx context.Context, softwareIDs []uint, source VulnerabilitySource) ([]SoftwareVulnerability, error) LoadHostSoftware(ctx context.Context, host *Host, includeCVEScores bool) error AllSoftwareIterator(ctx context.Context, query SoftwareIterQueryOptions) (SoftwareIterator, error) diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 6ffce8fe8fa..e4c8dcd607c 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -511,8 +511,12 @@ type GetDetailsForUninstallFromExecutionIDFunc func(ctx context.Context, executi type ListSoftwareForVulnDetectionFunc func(ctx context.Context, filter fleet.VulnSoftwareFilter) ([]fleet.Software, error) +type ListSoftwareForVulnDetectionByOSVersionFunc func(ctx context.Context, osVer fleet.OSVersion) ([]fleet.Software, error) + type ListSoftwareVulnerabilitiesByHostIDsSourceFunc func(ctx context.Context, hostIDs []uint, source fleet.VulnerabilitySource) (map[uint][]fleet.SoftwareVulnerability, error) +type ListSoftwareVulnerabilitiesBySoftwareIDsFunc func(ctx context.Context, softwareIDs []uint, source fleet.VulnerabilitySource) ([]fleet.SoftwareVulnerability, error) + type LoadHostSoftwareFunc func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error type AllSoftwareIteratorFunc func(ctx context.Context, query fleet.SoftwareIterQueryOptions) (fleet.SoftwareIterator, error) @@ -2640,9 +2644,15 @@ type DataStore struct { ListSoftwareForVulnDetectionFunc ListSoftwareForVulnDetectionFunc ListSoftwareForVulnDetectionFuncInvoked bool + ListSoftwareForVulnDetectionByOSVersionFunc ListSoftwareForVulnDetectionByOSVersionFunc + ListSoftwareForVulnDetectionByOSVersionFuncInvoked bool + ListSoftwareVulnerabilitiesByHostIDsSourceFunc ListSoftwareVulnerabilitiesByHostIDsSourceFunc ListSoftwareVulnerabilitiesByHostIDsSourceFuncInvoked bool + ListSoftwareVulnerabilitiesBySoftwareIDsFunc ListSoftwareVulnerabilitiesBySoftwareIDsFunc + ListSoftwareVulnerabilitiesBySoftwareIDsFuncInvoked bool + LoadHostSoftwareFunc LoadHostSoftwareFunc LoadHostSoftwareFuncInvoked bool @@ -6445,6 +6455,13 @@ func (s *DataStore) ListSoftwareForVulnDetection(ctx context.Context, filter fle return s.ListSoftwareForVulnDetectionFunc(ctx, filter) } +func (s *DataStore) ListSoftwareForVulnDetectionByOSVersion(ctx context.Context, osVer fleet.OSVersion) ([]fleet.Software, error) { + s.mu.Lock() + s.ListSoftwareForVulnDetectionByOSVersionFuncInvoked = true + s.mu.Unlock() + return s.ListSoftwareForVulnDetectionByOSVersionFunc(ctx, osVer) +} + func (s *DataStore) ListSoftwareVulnerabilitiesByHostIDsSource(ctx context.Context, hostIDs []uint, source fleet.VulnerabilitySource) (map[uint][]fleet.SoftwareVulnerability, error) { s.mu.Lock() s.ListSoftwareVulnerabilitiesByHostIDsSourceFuncInvoked = true @@ -6452,6 +6469,13 @@ func (s *DataStore) ListSoftwareVulnerabilitiesByHostIDsSource(ctx context.Conte return s.ListSoftwareVulnerabilitiesByHostIDsSourceFunc(ctx, hostIDs, source) } +func (s *DataStore) ListSoftwareVulnerabilitiesBySoftwareIDs(ctx context.Context, softwareIDs []uint, source fleet.VulnerabilitySource) ([]fleet.SoftwareVulnerability, error) { + s.mu.Lock() + s.ListSoftwareVulnerabilitiesBySoftwareIDsFuncInvoked = true + s.mu.Unlock() + return s.ListSoftwareVulnerabilitiesBySoftwareIDsFunc(ctx, softwareIDs, source) +} + func (s *DataStore) LoadHostSoftware(ctx context.Context, host *fleet.Host, includeCVEScores bool) error { s.mu.Lock() s.LoadHostSoftwareFuncInvoked = true diff --git a/server/vulnerabilities/osv/analyzer.go b/server/vulnerabilities/osv/analyzer.go index e5a1beb6719..1d9f2a80ff9 100644 --- a/server/vulnerabilities/osv/analyzer.go +++ b/server/vulnerabilities/osv/analyzer.go @@ -51,104 +51,150 @@ type OSVVulnerability struct { Versions []string `json:"versions,omitempty"` } -// Analyze scans all hosts for vulnerabilities based on the OSV artifacts for their platform -func Analyze( +type softwareMatcher func(software []fleet.Software) []fleet.SoftwareVulnerability + +const softwareBatchSize = 5000 + +func analyzeOSV( ctx context.Context, ds fleet.Datastore, ver fleet.OSVersion, - vulnPath string, + source fleet.VulnerabilitySource, + matcher softwareMatcher, collectVulns bool, logger *slog.Logger, - date time.Time, ) ([]fleet.SoftwareVulnerability, error) { - if strings.ToLower(ver.Platform) != "ubuntu" { - return nil, ErrUnsupportedPlatform + // Get distinct software for this OS version (replaces per-host ListSoftwareForVulnDetection). + softwareStart := time.Now().UTC() + software, err := ds.ListSoftwareForVulnDetectionByOSVersion(ctx, ver) + if err != nil { + return nil, fmt.Errorf("listing software for OS version: %w", err) } + softwareTime := time.Since(softwareStart) - artifact, err := loadOSVArtifact(ctx, ver, vulnPath, logger, date) - if err != nil { - return nil, fmt.Errorf("loading OSV artifact: %w", err) + if len(software) == 0 { + logger.DebugContext(ctx, "no software found for os version", + "platform", ver.Platform, "version", ver.Version) + return nil, nil } - source := fleet.UbuntuOSVSource - toInsertSet := make(map[string]fleet.SoftwareVulnerability) - toDeleteSet := make(map[string]fleet.SoftwareVulnerability) - totalHosts := 0 + var ( + totalFound int + totalExisting int + totalInsert int + totalDelete int + matchTime time.Duration + existingTime time.Duration + allNewVulns []fleet.SoftwareVulnerability + ) + + for i := 0; i < len(software); i += softwareBatchSize { + end := min(i+softwareBatchSize, len(software)) + chunk := software[i:end] + + // Match this chunk against the artifact. + matchStart := time.Now().UTC() + found := matcher(chunk) + matchTime += time.Since(matchStart) + + // Collect software IDs for this chunk. + chunkIDs := make([]uint, len(chunk)) + for j, sw := range chunk { + chunkIDs[j] = sw.ID + } - // Paginate through all hosts with this OS version - var offset int - for { - hostIDs, err := ds.HostIDsByOSVersion(ctx, ver, offset, hostsBatchSize) + // Get existing vulns scoped to this chunk's software IDs. + existingStart := time.Now().UTC() + existing, err := ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, chunkIDs, source) if err != nil { - return nil, fmt.Errorf("getting host IDs: %w", err) + return nil, fmt.Errorf("listing existing vulnerabilities: %w", err) } + existingTime += time.Since(existingStart) - if len(hostIDs) == 0 { - break - } + // Compute delta for this chunk. + toInsert, toDelete := utils.VulnsDelta(found, existing) - totalHosts += len(hostIDs) - offset += hostsBatchSize + totalFound += len(found) + totalExisting += len(existing) + totalInsert += len(toInsert) + totalDelete += len(toDelete) - foundInBatch := make(map[uint][]fleet.SoftwareVulnerability) - for _, hostID := range hostIDs { - software, err := ds.ListSoftwareForVulnDetection(ctx, fleet.VulnSoftwareFilter{ - HostID: &hostID, - }) - if err != nil { - return nil, fmt.Errorf("listing software for host %d: %w", hostID, err) + // Delete stale vulnerabilities for this chunk. + if len(toDelete) > 0 { + toDeleteMap := make(map[string]fleet.SoftwareVulnerability, len(toDelete)) + for _, v := range toDelete { + toDeleteMap[v.Key()] = v + } + if err := utils.BatchProcess(toDeleteMap, func(v []fleet.SoftwareVulnerability) error { + return ds.DeleteSoftwareVulnerabilities(ctx, v) + }, vulnBatchSize); err != nil { + return nil, fmt.Errorf("deleting stale vulnerabilities: %w", err) } - - foundInBatch[hostID] = matchSoftwareToOSV(software, artifact) } - existingInBatch, err := ds.ListSoftwareVulnerabilitiesByHostIDsSource(ctx, hostIDs, source) - if err != nil { - return nil, fmt.Errorf("listing existing vulnerabilities: %w", err) - } + // Deduplicate and insert new vulnerabilities for this chunk. + if len(toInsert) > 0 { + seen := make(map[string]struct{}, len(toInsert)) + dedupedInsert := make([]fleet.SoftwareVulnerability, 0, len(toInsert)) + for _, v := range toInsert { + if _, ok := seen[v.Key()]; !ok { + seen[v.Key()] = struct{}{} + dedupedInsert = append(dedupedInsert, v) + } + } - for _, hostID := range hostIDs { - insrt, del := utils.VulnsDelta(foundInBatch[hostID], existingInBatch[hostID]) - for _, i := range insrt { - toInsertSet[i.Key()] = i + newVulns, err := ds.InsertSoftwareVulnerabilities(ctx, dedupedInsert, source) + if err != nil { + return nil, fmt.Errorf("inserting software vulnerabilities: %w", err) } - for _, d := range del { - toDeleteSet[d.Key()] = d + if collectVulns { + allNewVulns = append(allNewVulns, newVulns...) } } } - if totalHosts == 0 { - logger.DebugContext(ctx, "no hosts found for os version", "platform", ver.Platform, "version", ver.Version) + logger.DebugContext(ctx, "osv analysis completed", + "platform", ver.Platform, + "version", ver.Version, + "distinct_software", len(software), + "software_query_time", softwareTime, + "match_time", matchTime, + "existing_query_time", existingTime, + "found_vulns", totalFound, + "existing_vulns", totalExisting, + "to_insert", totalInsert, + "to_delete", totalDelete, + ) + + if !collectVulns { return nil, nil } - logger.DebugContext(ctx, "processed hosts for osv analysis", "platform", ver.Platform, "version", ver.Version, "host_count", totalHosts) - - // Delete stale vulnerabilities - err = utils.BatchProcess(toDeleteSet, func(v []fleet.SoftwareVulnerability) error { - return ds.DeleteSoftwareVulnerabilities(ctx, v) - }, vulnBatchSize) - if err != nil { - return nil, fmt.Errorf("deleting stale vulnerabilities: %w", err) - } + return allNewVulns, nil +} - // Insert new vulnerabilities - allVulns := make([]fleet.SoftwareVulnerability, 0, len(toInsertSet)) - for _, v := range toInsertSet { - allVulns = append(allVulns, v) +// Analyze scans all hosts for vulnerabilities based on the OSV artifacts for their platform +func Analyze( + ctx context.Context, + ds fleet.Datastore, + ver fleet.OSVersion, + vulnPath string, + collectVulns bool, + logger *slog.Logger, + date time.Time, +) ([]fleet.SoftwareVulnerability, error) { + if strings.ToLower(ver.Platform) != "ubuntu" { + return nil, ErrUnsupportedPlatform } - newVulns, err := ds.InsertSoftwareVulnerabilities(ctx, allVulns, source) + artifact, err := loadOSVArtifact(ctx, ver, vulnPath, logger, date) if err != nil { - return nil, fmt.Errorf("inserting software vulnerabilities: %w", err) - } - - if !collectVulns { - return nil, nil + return nil, fmt.Errorf("loading OSV artifact: %w", err) } - return newVulns, nil + return analyzeOSV(ctx, ds, ver, fleet.UbuntuOSVSource, func(sw []fleet.Software) []fleet.SoftwareVulnerability { + return matchSoftwareToOSV(sw, artifact) + }, collectVulns, logger) } // findLatestOSVArtifactForVersion finds the most recent OSV artifact for a specific Ubuntu version @@ -520,82 +566,9 @@ func AnalyzeRHEL( return nil, fmt.Errorf("loading RHEL OSV artifact: %w", err) } - source := fleet.RHELOSVSource - toInsertSet := make(map[string]fleet.SoftwareVulnerability) - toDeleteSet := make(map[string]fleet.SoftwareVulnerability) - totalHosts := 0 - - var offset int - for { - hostIDs, err := ds.HostIDsByOSVersion(ctx, ver, offset, hostsBatchSize) - if err != nil { - return nil, fmt.Errorf("getting host IDs: %w", err) - } - - if len(hostIDs) == 0 { - break - } - - totalHosts += len(hostIDs) - offset += hostsBatchSize - - foundInBatch := make(map[uint][]fleet.SoftwareVulnerability) - for _, hostID := range hostIDs { - software, err := ds.ListSoftwareForVulnDetection(ctx, fleet.VulnSoftwareFilter{ - HostID: &hostID, - }) - if err != nil { - return nil, fmt.Errorf("listing software for host %d: %w", hostID, err) - } - - foundInBatch[hostID] = matchSoftwareToRHELOSV(software, artifact) - } - - existingInBatch, err := ds.ListSoftwareVulnerabilitiesByHostIDsSource(ctx, hostIDs, source) - if err != nil { - return nil, fmt.Errorf("listing existing vulnerabilities: %w", err) - } - - for _, hostID := range hostIDs { - insrt, del := utils.VulnsDelta(foundInBatch[hostID], existingInBatch[hostID]) - for _, i := range insrt { - toInsertSet[i.Key()] = i - } - for _, d := range del { - toDeleteSet[d.Key()] = d - } - } - } - - if totalHosts == 0 { - logger.DebugContext(ctx, "no hosts found for os version", "platform", ver.Platform, "version", ver.Version) - return nil, nil - } - - logger.DebugContext(ctx, "processed hosts for rhel osv analysis", "platform", ver.Platform, "version", ver.Version, "host_count", totalHosts) - - err = utils.BatchProcess(toDeleteSet, func(v []fleet.SoftwareVulnerability) error { - return ds.DeleteSoftwareVulnerabilities(ctx, v) - }, vulnBatchSize) - if err != nil { - return nil, fmt.Errorf("deleting stale vulnerabilities: %w", err) - } - - allVulns := make([]fleet.SoftwareVulnerability, 0, len(toInsertSet)) - for _, v := range toInsertSet { - allVulns = append(allVulns, v) - } - - newVulns, err := ds.InsertSoftwareVulnerabilities(ctx, allVulns, source) - if err != nil { - return nil, fmt.Errorf("inserting software vulnerabilities: %w", err) - } - - if !collectVulns { - return nil, nil - } - - return newVulns, nil + return analyzeOSV(ctx, ds, ver, fleet.RHELOSVSource, func(sw []fleet.Software) []fleet.SoftwareVulnerability { + return matchSoftwareToRHELOSV(sw, artifact) + }, collectVulns, logger) } // findLatestRHELOSVArtifactForVersion finds the most recent RHEL OSV artifact for a major version.