diff --git a/go.mod b/go.mod index 9d5b522ed..888a2dedf 100644 --- a/go.mod +++ b/go.mod @@ -58,6 +58,7 @@ require ( github.com/mmcloughlin/avo v0.6.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pjbgf/sha1cd v0.3.1 // indirect + github.com/robfig/cron/v3 v3.0.1 github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/skeema/knownhosts v1.3.0 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect diff --git a/go.sum b/go.sum index e272a178d..1187ff05d 100644 --- a/go.sum +++ b/go.sum @@ -232,6 +232,8 @@ github.com/prometheus/common v0.61.0 h1:3gv/GThfX0cV2lpO7gkTUwZru38mxevy90Bj8YFS github.com/prometheus/common v0.61.0/go.mod h1:zr29OCN/2BsJRaFwG8QOBr41D6kkchKbpeNH7pAjb/s= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= diff --git a/tavern/internal/c2/api_claim_tasks.go b/tavern/internal/c2/api_claim_tasks.go index 16326ba65..7fcea3e57 100644 --- a/tavern/internal/c2/api_claim_tasks.go +++ b/tavern/internal/c2/api_claim_tasks.go @@ -9,14 +9,17 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" + "github.com/robfig/cron/v3" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "realm.pub/tavern/internal/c2/c2pb" "realm.pub/tavern/internal/c2/epb" + "realm.pub/tavern/internal/ent" "realm.pub/tavern/internal/ent/beacon" "realm.pub/tavern/internal/ent/host" 
"realm.pub/tavern/internal/ent/tag" "realm.pub/tavern/internal/ent/task" + "realm.pub/tavern/internal/ent/tome" "realm.pub/tavern/internal/namegen" ) @@ -28,10 +31,117 @@ var ( }, []string{"host_identifier", "host_groups", "host_services"}, ) + metricTomeAutomationErrors = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "tavern_tome_automation_errors_total", + Help: "The total number of errors encountered during tome automation", + }, + ) ) func init() { prometheus.MustRegister(metricHostCallbacksTotal) + prometheus.MustRegister(metricTomeAutomationErrors) +} + +func (srv *Server) handleTomeAutomation(ctx context.Context, beaconID int, hostID int, isNewBeacon bool, isNewHost bool, now time.Time) { + // Tome Automation Logic + candidateTomes, err := srv.graph.Tome.Query(). + Where(tome.Or( + tome.RunOnNewBeaconCallback(true), + tome.RunOnFirstHostCallback(true), + tome.RunOnScheduleNEQ(""), + )). + All(ctx) + + if err != nil { + slog.ErrorContext(ctx, "failed to query candidate tomes for automation", "err", err) + metricTomeAutomationErrors.Inc() + return + } + + selectedTomes := make(map[int]*ent.Tome) + parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow) + currentMinute := now.Truncate(time.Minute) + + for _, t := range candidateTomes { + shouldRun := false + + // Check RunOnNewBeaconCallback + if isNewBeacon && t.RunOnNewBeaconCallback { + shouldRun = true + } + + // Check RunOnFirstHostCallback + if !shouldRun && isNewHost && t.RunOnFirstHostCallback { + shouldRun = true + } + + // Check RunOnSchedule + if !shouldRun && t.RunOnSchedule != "" { + sched, err := parser.Parse(t.RunOnSchedule) + if err == nil { + // The schedule matches when its next activation after (currentMinute - 1s) is exactly currentMinute. + // NOTE(review): this is evaluated on every ClaimTasks call, so a beacon checking in more than once within a matching minute creates duplicate quests — consider persisting a per-tome last-run timestamp. 
+ next := sched.Next(currentMinute.Add(-1 * time.Second)) + if next.Equal(currentMinute) { + // Check scheduled_hosts constraint + hostCount, err := t.QueryScheduledHosts().Count(ctx) + if err != nil { + slog.ErrorContext(ctx, "failed to count scheduled hosts for automation", "err", err, "tome_id", t.ID) + metricTomeAutomationErrors.Inc() + continue + } + if hostCount == 0 { + shouldRun = true + } else { + hostExists, err := t.QueryScheduledHosts(). + Where(host.ID(hostID)). + Exist(ctx) + if err != nil { + slog.ErrorContext(ctx, "failed to check host existence for automation", "err", err, "tome_id", t.ID) + metricTomeAutomationErrors.Inc() + continue + } + if hostExists { + shouldRun = true + } + } + } + } else { + // Don't log cron parse errors for now, as it might be spammy if stored in DB + // metricTomeAutomationErrors.Inc() + } + } + + if shouldRun { + selectedTomes[t.ID] = t + } + } + + // Create Quest and Task for each selected Tome + for _, t := range selectedTomes { + q, err := srv.graph.Quest.Create(). + SetName(fmt.Sprintf("Automated: %s", t.Name)). + SetTome(t). + SetParamDefsAtCreation(t.ParamDefs). + SetEldritchAtCreation(t.Eldritch). + Save(ctx) + if err != nil { + slog.ErrorContext(ctx, "failed to create automated quest", "err", err, "tome_id", t.ID) + metricTomeAutomationErrors.Inc() + continue + } + + _, err = srv.graph.Task.Create(). + SetQuest(q). + SetBeaconID(beaconID). + Save(ctx) + if err != nil { + slog.ErrorContext(ctx, "failed to create automated task", "err", err, "quest_id", q.ID) + metricTomeAutomationErrors.Inc() + } + } } func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) (*c2pb.ClaimTasksResponse, error) { @@ -61,6 +171,15 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) return nil, status.Errorf(codes.InvalidArgument, "must provide agent identifier") } + // Check if host is new (before upsert) + hostExists, err := srv.graph.Host.Query(). 
+ Where(host.IdentifierEQ(req.Beacon.Host.Identifier)). + Exist(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to query host existence: %v", err) + } + isNewHost := !hostExists + // Upsert the host hostID, err := srv.graph.Host.Create(). SetIdentifier(req.Beacon.Host.Identifier). @@ -118,6 +237,8 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) if err != nil { return nil, status.Errorf(codes.Internal, "failed to query beacon entity: %v", err) } + isNewBeacon := !beaconExists + var beaconNameAddr *string = nil if !beaconExists { candidateNames := []string{ @@ -172,6 +293,9 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) return nil, status.Errorf(codes.Internal, "failed to upsert beacon entity: %v", err) } + // Run tome automation synchronously on the callback path (best effort: errors are logged and counted, and never fail the claim) + srv.handleTomeAutomation(ctx, beaconID, hostID, isNewBeacon, isNewHost, now) + // Load Tasks tasks, err := srv.graph.Task.Query(). Where(task.And( diff --git a/tavern/internal/c2/tome_automation_test.go b/tavern/internal/c2/tome_automation_test.go new file mode 100644 index 000000000..1f7860fc7 --- /dev/null +++ b/tavern/internal/c2/tome_automation_test.go @@ -0,0 +1,193 @@ +package c2 + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "realm.pub/tavern/internal/c2/c2pb" + "realm.pub/tavern/internal/ent" + "realm.pub/tavern/internal/ent/enttest" +) + +func TestHandleTomeAutomation(t *testing.T) { + ctx := context.Background() + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + srv := &Server{graph: client} + now := time.Date(2023, 10, 27, 10, 0, 0, 0, time.UTC) + + // Create a dummy host and beacon for testing + h := client.Host.Create(). + SetIdentifier("test-host"). + SetName("Test Host"). + SetPlatform(c2pb.Host_PLATFORM_LINUX). + SaveX(ctx) + b := client.Beacon.Create(). + SetIdentifier("test-beacon"). + SetHost(h). 
+ SetTransport(c2pb.ActiveTransport_TRANSPORT_HTTP1). + SaveX(ctx) + + // 1. Setup Tomes + // T1: New Beacon Only + client.Tome.Create(). + SetName("Tome New Beacon"). + SetDescription("Test"). + SetAuthor("Test Author"). + SetEldritch("print('new beacon')"). + SetRunOnNewBeaconCallback(true). + SaveX(ctx) + + // T2: New Host Only + client.Tome.Create(). + SetName("Tome New Host"). + SetDescription("Test"). + SetAuthor("Test Author"). + SetEldritch("print('new host')"). + SetRunOnFirstHostCallback(true). + SaveX(ctx) + + // T3: Schedule Matching (Every minute) + client.Tome.Create(). + SetName("Tome Schedule Match"). + SetDescription("Test"). + SetAuthor("Test Author"). + SetEldritch("print('schedule')"). + SetRunOnSchedule("* * * * *"). + SaveX(ctx) + + // T4: Schedule Matching with Host Restriction (Allowed) + client.Tome.Create(). + SetName("Tome Schedule Restricted Allowed"). + SetDescription("Test"). + SetAuthor("Test Author"). + SetEldritch("print('schedule restricted')"). + SetRunOnSchedule("* * * * *"). + AddScheduledHosts(h). + SaveX(ctx) + + // T5: Schedule Matching with Host Restriction (Denied - different host) + otherHost := client.Host.Create(). + SetIdentifier("other"). + SetPlatform(c2pb.Host_PLATFORM_LINUX). + SaveX(ctx) + + client.Tome.Create(). + SetName("Tome Schedule Restricted Denied"). + SetDescription("Test"). + SetAuthor("Test Author"). + SetEldritch("print('schedule denied')"). + SetRunOnSchedule("* * * * *"). + AddScheduledHosts(otherHost). 
+ SaveX(ctx) + + tests := []struct { + name string + isNewBeacon bool + isNewHost bool + expectedTomes []string + }{ + { + name: "New Beacon Only", + isNewBeacon: true, + isNewHost: false, + expectedTomes: []string{ + "Tome New Beacon", + "Tome Schedule Match", + "Tome Schedule Restricted Allowed", + }, + }, + { + name: "New Host Only", + isNewBeacon: false, + isNewHost: true, + expectedTomes: []string{ + "Tome New Host", + "Tome Schedule Match", + "Tome Schedule Restricted Allowed", + }, + }, + { + name: "Both New", + isNewBeacon: true, + isNewHost: true, + expectedTomes: []string{ + "Tome New Beacon", + "Tome New Host", + "Tome Schedule Match", + "Tome Schedule Restricted Allowed", + }, + }, + { + name: "Neither New", + isNewBeacon: false, + isNewHost: false, + expectedTomes: []string{ + "Tome Schedule Match", + "Tome Schedule Restricted Allowed", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clear existing quests/tasks to ensure clean slate + client.Task.Delete().ExecX(ctx) + client.Quest.Delete().ExecX(ctx) + + srv.handleTomeAutomation(ctx, b.ID, h.ID, tt.isNewBeacon, tt.isNewHost, now) + + // Verify Tasks + tasks := client.Task.Query().WithQuest(func(q *ent.QuestQuery) { + q.WithTome() + }).AllX(ctx) + + var createdTomes []string + for _, t := range tasks { + createdTomes = append(createdTomes, t.Edges.Quest.Edges.Tome.Name) + } + + assert.ElementsMatch(t, tt.expectedTomes, createdTomes) + }) + } +} + +func TestHandleTomeAutomation_Deduplication(t *testing.T) { + ctx := context.Background() + client := enttest.Open(t, "sqlite3", "file:ent?mode=memory&cache=shared&_fk=1") + defer client.Close() + + srv := &Server{graph: client} + now := time.Now() + + h := client.Host.Create(). + SetIdentifier("test"). + SetPlatform(c2pb.Host_PLATFORM_LINUX). + SaveX(ctx) + b := client.Beacon.Create(). + SetIdentifier("test"). + SetHost(h). + SetTransport(c2pb.ActiveTransport_TRANSPORT_HTTP1). 
+ SaveX(ctx) + + // Tome with ALL triggers enabled + client.Tome.Create(). + SetName("Super Tome"). + SetDescription("Test"). + SetAuthor("Test Author"). + SetEldritch("print('super')"). + SetRunOnNewBeaconCallback(true). + SetRunOnFirstHostCallback(true). + SetRunOnSchedule("* * * * *"). + SaveX(ctx) + + // Trigger all conditions + srv.handleTomeAutomation(ctx, b.ID, h.ID, true, true, now) + + // Should only have 1 task + count := client.Task.Query().CountX(ctx) + assert.Equal(t, 1, count, "Should only create one task despite multiple triggers matching") +}