Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/44391-osv-vuln-optimizations
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
* Optimized OSV vulnerability scanning to query distinct software per OS version rather than per host, reducing redundant database queries for many hosts sharing the same packages.
82 changes: 82 additions & 0 deletions server/datastore/mysql/software.go
Original file line number Diff line number Diff line change
Expand Up @@ -3309,6 +3309,88 @@ func (ds *Datastore) ListSoftwareForVulnDetection(
return result, nil
}

const softwareVulnDetectionBatchSize = 10000

func (ds *Datastore) ListSoftwareForVulnDetectionByOSVersion(
ctx context.Context,
osVer fleet.OSVersion,
) ([]fleet.Software, error) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious if we'll run into memory issues here. I believe the old pattern paginated

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this method in particular, if we take 100k softwares then...

~100K uint (8 bytes each) = 800KB for the IDs
~100K fleet.Software structs - we only populate 6 fields in our query (id, name, version, release, arch, generated_cpe)

  • name (~20 chars)
  • version (~10)
  • release (~10)
  • arch (~5)
  • cpe (~60)

~105 bytes
Struct overhead: ~200 bytes (8 strings × 16 bytes header + uint)
Total per item: ~305 bytes
100K items: ~30MB

Seemed reasonable. Am I missing something?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, just caught my eye

var softwareIDs []uint
err := sqlx.SelectContext(ctx, ds.reader(ctx), &softwareIDs, `
SELECT DISTINCT hs.software_id
FROM host_software hs
JOIN hosts h ON hs.host_id = h.id
WHERE h.platform = ? AND h.os_version = ?
`, osVer.Platform, osVer.Name)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "listing distinct software IDs for OS version")
}

Comment on lines +3313 to +3328
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation ran in 600ms locally with 14M rows — extrapolating linearly to 330M would be ~14 seconds. That seems acceptable. Especially given the cron job ran for 8+ hours.

We could potentially batch it as stated without the join:
SELECT id FROM hosts WHERE platform = ? AND os_version = ? ORDER BY id LIMIT ? OFFSET ? + SELECT DISTINCT software_id FROM host_software WHERE host_id IN (%s)
If we use 10k batches.
We would still get the same 330M reads total, but split across 15 round trips. Each batch processes 10K hosts × 2,300 = 23M entries, deduplicates locally, and merges into a Go map. It would give us smaller temp tables per query. But could actually be slower due to the round-trip overhead and larger IN clauses.

if len(softwareIDs) == 0 {
return nil, nil
}

var result []fleet.Software
if err := common_mysql.BatchProcessSimple(softwareIDs, softwareVulnDetectionBatchSize, func(batch []uint) error {
placeholders := strings.TrimSuffix(strings.Repeat("?,", len(batch)), ",")
query := fmt.Sprintf(`
SELECT s.id, s.name, s.version, s.release, s.arch, COALESCE(cpe.cpe, '') AS generated_cpe
FROM software s
LEFT JOIN software_cpe cpe ON s.id = cpe.software_id
WHERE s.id IN (%s)
`, placeholders)
args := make([]any, len(batch))
for i, id := range batch {
args[i] = id
}
var batchResult []fleet.Software
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &batchResult, query, args...); err != nil {
return ctxerr.Wrap(ctx, err, "fetching software details for vulnerability detection")
}
result = append(result, batchResult...)
return nil
}); err != nil {
return nil, err
}

return result, nil
}

func (ds *Datastore) ListSoftwareVulnerabilitiesBySoftwareIDs(
ctx context.Context,
softwareIDs []uint,
source fleet.VulnerabilitySource,
) ([]fleet.SoftwareVulnerability, error) {
if len(softwareIDs) == 0 {
return nil, nil
}

var result []fleet.SoftwareVulnerability
if err := common_mysql.BatchProcessSimple(softwareIDs, softwareVulnDetectionBatchSize, func(batch []uint) error {
placeholders := strings.TrimSuffix(strings.Repeat("?,", len(batch)), ",")
query := fmt.Sprintf(`
SELECT software_id, cve, resolved_in_version
FROM software_cve
WHERE source = ? AND software_id IN (%s)
`, placeholders)
args := make([]any, 0, len(batch)+1)
args = append(args, source)
for _, id := range batch {
args = append(args, id)
}
var batchResult []fleet.SoftwareVulnerability
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &batchResult, query, args...); err != nil {
return ctxerr.Wrap(ctx, err, "fetching software vulnerabilities by software IDs")
}
result = append(result, batchResult...)
return nil
}); err != nil {
return nil, err
}

return result, nil
}

// ListCVEs returns all cve_meta rows published after 'maxAge'
func (ds *Datastore) ListCVEs(ctx context.Context, maxAge time.Duration) ([]fleet.CVEMeta, error) {
var result []fleet.CVEMeta
Expand Down
142 changes: 142 additions & 0 deletions server/datastore/mysql/software_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ func TestSoftware(t *testing.T) {
{"InsertSoftwareVulnerabilities", testInsertSoftwareVulnerabilities},
{"ListCVEs", testListCVEs},
{"ListSoftwareForVulnDetection", testListSoftwareForVulnDetection},
{"ListSoftwareForVulnDetectionByOSVersion", testListSoftwareForVulnDetectionByOSVersion},
{"ListSoftwareVulnerabilitiesBySoftwareIDs", testListSoftwareVulnerabilitiesBySoftwareIDs},
{"AllSoftwareIterator", testAllSoftwareIterator},
{"AllSoftwareIteratorForCustomLinuxImages", testSoftwareIteratorForLinuxKernelCustomImages},
{"UpsertSoftwareCPEs", testUpsertSoftwareCPEs},
Expand Down Expand Up @@ -12011,3 +12013,143 @@ func testSoftwareLiteByID(t *testing.T, ds *Datastore) {
require.Error(t, err)
require.True(t, fleet.IsNotFound(err))
}

func testListSoftwareForVulnDetectionByOSVersion(t *testing.T, ds *Datastore) {
ctx := context.Background()

// Create two hosts with the same OS version and overlapping software.
host1 := test.NewHost(t, ds, "osv-host1", "", "osv-host1key", "osv-host1uuid", time.Now())
host1.Platform = "ubuntu"
host1.OSVersion = "Ubuntu 22.04.1 LTS"
require.NoError(t, ds.UpdateHost(ctx, host1))

host2 := test.NewHost(t, ds, "osv-host2", "", "osv-host2key", "osv-host2uuid", time.Now())
host2.Platform = "ubuntu"
host2.OSVersion = "Ubuntu 22.04.1 LTS"
require.NoError(t, ds.UpdateHost(ctx, host2))

// Create a host with a different OS version.
host3 := test.NewHost(t, ds, "osv-host3", "", "osv-host3key", "osv-host3uuid", time.Now())
host3.Platform = "ubuntu"
host3.OSVersion = "Ubuntu 20.04.1 LTS"
require.NoError(t, ds.UpdateHost(ctx, host3))

sharedSoftware := []fleet.Software{
{Name: "libfoo", Version: "1.2.3", Source: "deb_packages"},
{Name: "libbar", Version: "4.5.6", Source: "deb_packages"},
}
_, err := ds.UpdateHostSoftware(ctx, host1.ID, sharedSoftware)
require.NoError(t, err)

host2Software := []fleet.Software{
{Name: "libfoo", Version: "1.2.3", Source: "deb_packages"}, // shared with host1
{Name: "libbaz", Version: "7.8.9", Source: "deb_packages"}, // unique to host2
}
_, err = ds.UpdateHostSoftware(ctx, host2.ID, host2Software)
require.NoError(t, err)

host3Software := []fleet.Software{
{Name: "libother", Version: "0.0.1", Source: "deb_packages"},
}
_, err = ds.UpdateHostSoftware(ctx, host3.ID, host3Software)
require.NoError(t, err)

// Query for Ubuntu 22.04.1 LTS — should return 3 distinct software items.
result, err := ds.ListSoftwareForVulnDetectionByOSVersion(ctx, fleet.OSVersion{
Platform: "ubuntu",
Name: "Ubuntu 22.04.1 LTS",
})
require.NoError(t, err)

names := make([]string, len(result))
for i, sw := range result {
names[i] = sw.Name
}
sort.Strings(names)
require.Equal(t, []string{"libbar", "libbaz", "libfoo"}, names)

// Verify no duplicates (libfoo exists on both hosts but should appear once).
require.Len(t, result, 3)

// Query for Ubuntu 20.04.1 LTS — should return only host3's software.
result, err = ds.ListSoftwareForVulnDetectionByOSVersion(ctx, fleet.OSVersion{
Platform: "ubuntu",
Name: "Ubuntu 20.04.1 LTS",
})
require.NoError(t, err)
require.Len(t, result, 1)
require.Equal(t, "libother", result[0].Name)

// Query for nonexistent OS — should return nil.
result, err = ds.ListSoftwareForVulnDetectionByOSVersion(ctx, fleet.OSVersion{
Platform: "ubuntu",
Name: "Ubuntu 99.99 LTS",
})
require.NoError(t, err)
require.Nil(t, result)
}

func testListSoftwareVulnerabilitiesBySoftwareIDs(t *testing.T, ds *Datastore) {
ctx := context.Background()

// Create some software.
host := test.NewHost(t, ds, "vuln-sw-host", "", "vuln-sw-hostkey", "vuln-sw-hostuuid", time.Now())
software := []fleet.Software{
{Name: "pkg-a", Version: "1.0", Source: "deb_packages"},
{Name: "pkg-b", Version: "2.0", Source: "deb_packages"},
{Name: "pkg-c", Version: "3.0", Source: "deb_packages"},
}
_, err := ds.UpdateHostSoftware(ctx, host.ID, software)
require.NoError(t, err)
require.NoError(t, ds.LoadHostSoftware(ctx, host, false))

// Look up software by name to avoid depending on unstable ordering.
swByName := make(map[string]fleet.HostSoftwareEntry, len(host.Software))
for _, sw := range host.Software {
swByName[sw.Name] = sw
}
swA := swByName["pkg-a"]
swB := swByName["pkg-b"]
swC := swByName["pkg-c"]

// Insert vulns with different sources.
_, err = ds.InsertSoftwareVulnerabilities(ctx, []fleet.SoftwareVulnerability{
{SoftwareID: swA.ID, CVE: "CVE-2024-0001"},
{SoftwareID: swA.ID, CVE: "CVE-2024-0002"},
{SoftwareID: swB.ID, CVE: "CVE-2024-0003"},
}, fleet.UbuntuOSVSource)
require.NoError(t, err)

_, err = ds.InsertSoftwareVulnerabilities(ctx, []fleet.SoftwareVulnerability{
{SoftwareID: swA.ID, CVE: "CVE-2024-9999"},
{SoftwareID: swC.ID, CVE: "CVE-2024-8888"},
}, fleet.NVDSource)
require.NoError(t, err)

// Query OSV source for swA and swB — should return 3 vulns (2 for swA, 1 for swB).
result, err := ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, []uint{swA.ID, swB.ID}, fleet.UbuntuOSVSource)
require.NoError(t, err)
require.Len(t, result, 3)

cves := make([]string, len(result))
for i, v := range result {
cves[i] = v.CVE
}
sort.Strings(cves)
require.Equal(t, []string{"CVE-2024-0001", "CVE-2024-0002", "CVE-2024-0003"}, cves)

// Query OSV source for swC — should return empty (swC's vulns are NVD source).
result, err = ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, []uint{swC.ID}, fleet.UbuntuOSVSource)
require.NoError(t, err)
require.Empty(t, result)

// Query NVD source for swA and swC — should return 2 vulns.
result, err = ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, []uint{swA.ID, swC.ID}, fleet.NVDSource)
require.NoError(t, err)
require.Len(t, result, 2)

// Empty software IDs — should return nil.
result, err = ds.ListSoftwareVulnerabilitiesBySoftwareIDs(ctx, []uint{}, fleet.UbuntuOSVSource)
require.NoError(t, err)
require.Nil(t, result)
}
6 changes: 6 additions & 0 deletions server/fleet/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,13 @@ type Datastore interface {
// ListSoftwareForVulnDetection returns all software for the given hostID with only the fields
// used for vulnerability detection populated (id, name, version, cpe_id, cpe)
ListSoftwareForVulnDetection(ctx context.Context, filter VulnSoftwareFilter) ([]Software, error)
// ListSoftwareForVulnDetectionByOSVersion returns all distinct software installed on hosts
// matching the given OS version.
ListSoftwareForVulnDetectionByOSVersion(ctx context.Context, osVer OSVersion) ([]Software, error)
ListSoftwareVulnerabilitiesByHostIDsSource(ctx context.Context, hostIDs []uint, source VulnerabilitySource) (map[uint][]SoftwareVulnerability, error)
// ListSoftwareVulnerabilitiesBySoftwareIDs returns vulnerabilities for the given software IDs
// filtered by source. Queries software_cve directly without joining through host_software.
ListSoftwareVulnerabilitiesBySoftwareIDs(ctx context.Context, softwareIDs []uint, source VulnerabilitySource) ([]SoftwareVulnerability, error)
LoadHostSoftware(ctx context.Context, host *Host, includeCVEScores bool) error

AllSoftwareIterator(ctx context.Context, query SoftwareIterQueryOptions) (SoftwareIterator, error)
Expand Down
24 changes: 24 additions & 0 deletions server/mock/datastore_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,12 @@ type GetDetailsForUninstallFromExecutionIDFunc func(ctx context.Context, executi

type ListSoftwareForVulnDetectionFunc func(ctx context.Context, filter fleet.VulnSoftwareFilter) ([]fleet.Software, error)

type ListSoftwareForVulnDetectionByOSVersionFunc func(ctx context.Context, osVer fleet.OSVersion) ([]fleet.Software, error)

type ListSoftwareVulnerabilitiesByHostIDsSourceFunc func(ctx context.Context, hostIDs []uint, source fleet.VulnerabilitySource) (map[uint][]fleet.SoftwareVulnerability, error)

type ListSoftwareVulnerabilitiesBySoftwareIDsFunc func(ctx context.Context, softwareIDs []uint, source fleet.VulnerabilitySource) ([]fleet.SoftwareVulnerability, error)

type LoadHostSoftwareFunc func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error

type AllSoftwareIteratorFunc func(ctx context.Context, query fleet.SoftwareIterQueryOptions) (fleet.SoftwareIterator, error)
Expand Down Expand Up @@ -2640,9 +2644,15 @@ type DataStore struct {
ListSoftwareForVulnDetectionFunc ListSoftwareForVulnDetectionFunc
ListSoftwareForVulnDetectionFuncInvoked bool

ListSoftwareForVulnDetectionByOSVersionFunc ListSoftwareForVulnDetectionByOSVersionFunc
ListSoftwareForVulnDetectionByOSVersionFuncInvoked bool

ListSoftwareVulnerabilitiesByHostIDsSourceFunc ListSoftwareVulnerabilitiesByHostIDsSourceFunc
ListSoftwareVulnerabilitiesByHostIDsSourceFuncInvoked bool

ListSoftwareVulnerabilitiesBySoftwareIDsFunc ListSoftwareVulnerabilitiesBySoftwareIDsFunc
ListSoftwareVulnerabilitiesBySoftwareIDsFuncInvoked bool

LoadHostSoftwareFunc LoadHostSoftwareFunc
LoadHostSoftwareFuncInvoked bool

Expand Down Expand Up @@ -6445,13 +6455,27 @@ func (s *DataStore) ListSoftwareForVulnDetection(ctx context.Context, filter fle
return s.ListSoftwareForVulnDetectionFunc(ctx, filter)
}

func (s *DataStore) ListSoftwareForVulnDetectionByOSVersion(ctx context.Context, osVer fleet.OSVersion) ([]fleet.Software, error) {
s.mu.Lock()
s.ListSoftwareForVulnDetectionByOSVersionFuncInvoked = true
s.mu.Unlock()
return s.ListSoftwareForVulnDetectionByOSVersionFunc(ctx, osVer)
}

func (s *DataStore) ListSoftwareVulnerabilitiesByHostIDsSource(ctx context.Context, hostIDs []uint, source fleet.VulnerabilitySource) (map[uint][]fleet.SoftwareVulnerability, error) {
s.mu.Lock()
s.ListSoftwareVulnerabilitiesByHostIDsSourceFuncInvoked = true
s.mu.Unlock()
return s.ListSoftwareVulnerabilitiesByHostIDsSourceFunc(ctx, hostIDs, source)
}

func (s *DataStore) ListSoftwareVulnerabilitiesBySoftwareIDs(ctx context.Context, softwareIDs []uint, source fleet.VulnerabilitySource) ([]fleet.SoftwareVulnerability, error) {
s.mu.Lock()
s.ListSoftwareVulnerabilitiesBySoftwareIDsFuncInvoked = true
s.mu.Unlock()
return s.ListSoftwareVulnerabilitiesBySoftwareIDsFunc(ctx, softwareIDs, source)
}

func (s *DataStore) LoadHostSoftware(ctx context.Context, host *fleet.Host, includeCVEScores bool) error {
s.mu.Lock()
s.LoadHostSoftwareFuncInvoked = true
Expand Down
Loading
Loading