diff --git a/tavern/internal/c2/server.go b/tavern/internal/c2/server.go index d856f8a8b..a7c9dafb7 100644 --- a/tavern/internal/c2/server.go +++ b/tavern/internal/c2/server.go @@ -5,12 +5,13 @@ import ( "encoding/json" "fmt" "time" - "realm.pub/tavern/internal/c2/c2pb" "realm.pub/tavern/internal/ent" "realm.pub/tavern/internal/ent/beacon" "realm.pub/tavern/internal/ent/host" "realm.pub/tavern/internal/ent/task" + "realm.pub/tavern/internal/namegen" + ) type Server struct { @@ -66,12 +67,50 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) if err != nil { return nil, fmt.Errorf("failed to upsert host entity: %w", err) } + // 2. check if beacon is new + beaconExists, err := srv.graph.Beacon.Query().Where(beacon.IdentifierEQ(req.Beacon.Identifier)).Exist(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query beacon entity: %w", err) + } + var beaconnameaddr *string = nil + //3. if the beacon is new lets pick a name for it + if !beaconExists { + candidateNames := []string{ + namegen.GetRandomNameSimple(), + namegen.GetRandomNameModerate(), + namegen.GetRandomNameComplex(), + } + + collisions, err := srv.graph.Beacon.Query().Where(beacon.NameIn(candidateNames...)).All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query beacon entity: %w", err) + } + if len(collisions) == 3 { + candidateNames := []string{ + namegen.GetRandomNameSimple(), + namegen.GetRandomNameModerate(), + namegen.GetRandomNameComplex(), + } + + collisions, err = srv.graph.Beacon.Query().Where(beacon.NameIn(candidateNames...)).All(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query beacon entity: %w", err) + } + } + for _, canidate := range candidateNames { + if !namegen.IsCollision(collisions, canidate) { + beaconnameaddr = &canidate + break + } + } + } - // 2. Upsert the beacon + // 4. Upsert the beacon beaconID, err := srv.graph.Beacon.Create(). SetPrincipal(req.Beacon.Principal). SetIdentifier(req.Beacon.Identifier). SetAgentIdentifier(req.Beacon.Agent.Identifier). + SetNillableName(beaconnameaddr). SetHostID(hostID). SetLastSeenAt(now). SetInterval(req.Beacon.Interval). @@ -82,7 +121,7 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) return nil, fmt.Errorf("failed to upsert beacon entity: %w", err) } - // 3. Load Tasks + // 5. Load Tasks tasks, err := srv.graph.Task.Query(). Where(task.And( task.HasBeaconWith(beacon.ID(beaconID)), @@ -93,14 +132,14 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) return nil, fmt.Errorf("failed to query tasks: %w", err) } - // 4. Prepare Transaction for Claiming Tasks + // 6. Prepare Transaction for Claiming Tasks tx, err := srv.graph.Tx(ctx) if err != nil { return nil, fmt.Errorf("failed to initialize transaction: %w", err) } client := tx.Client() - // 5. Rollback transaction if we panic + // 7. Rollback transaction if we panic defer func() { if v := recover(); v != nil { tx.Rollback() @@ -108,7 +147,7 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) } }() - // 6. Update all ClaimedAt timestamps to claim tasks + // 8. Update all ClaimedAt timestamps to claim tasks // ** Note: If one fails to update, we roll back the transaction and return the error taskIDs := make([]int, 0, len(tasks)) for _, t := range tasks { @@ -121,12 +160,12 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) taskIDs = append(taskIDs, t.ID) } - // 7. Commit the transaction + // 9. Commit the transaction if err := tx.Commit(); err != nil { return nil, rollback(tx, fmt.Errorf("failed to commit transaction: %w", err)) } - // 8. Load the tasks with our non transactional client (cannot use transaction after commit) + // 10. Load the tasks with our non transactional client (cannot use transaction after commit) resp := c2pb.ClaimTasksResponse{} resp.Tasks = make([]*c2pb.Task, 0, len(taskIDs)) for _, taskID := range taskIDs { @@ -155,7 +194,7 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest) }) } - // 9. Return claimed tasks + // 11. Return claimed tasks return &resp, nil } diff --git a/tavern/internal/ent/schema/beacon.go b/tavern/internal/ent/schema/beacon.go index 543ae287b..14ccdb5c1 100644 --- a/tavern/internal/ent/schema/beacon.go +++ b/tavern/internal/ent/schema/beacon.go @@ -27,7 +27,7 @@ func (Beacon) Fields() []ent.Field { NotEmpty(). Unique(). Immutable(). - DefaultFunc(namegen.GetRandomName). + DefaultFunc(namegen.GetRandomNameComplex). Comment("A human readable identifier for the beacon."), field.String("principal"). Optional(). diff --git a/tavern/internal/namegen/namegen.go b/tavern/internal/namegen/namegen.go index 2d1411aa2..101f88dd8 100644 --- a/tavern/internal/namegen/namegen.go +++ b/tavern/internal/namegen/namegen.go @@ -6,6 +6,8 @@ import ( "log" "math/big" "time" + + "realm.pub/tavern/internal/ent" ) var ( @@ -886,21 +888,49 @@ var ( } ) -// GetRandomName generates a random name from the list of adjectives and surnames in this package -// formatted as "adjective-surname". For example 'focused-turing'. -func GetRandomName() string { - if time.Now().Month() == time.October && time.Now().Day() == 31 { - adj1IndexHalloween := newRandInt(int64(len(adjectives_halloween))) - adj2IndexHalloween := newRandInt(int64(len(adjectives_halloween))) - nounIndex := newRandInt(int64(len(noun_halloween))) - randNum := newRandInt(10000000) - return fmt.Sprintf("%s-%s-%s-%d", adjectives_halloween[adj1IndexHalloween], adjectives_halloween[adj2IndexHalloween], noun_halloween[nounIndex], randNum) +// getRandomNameSimple generates a random name with one adjective and one noun. +func GetRandomNameSimple() string { + adj, noun := getRandomAdjNoun() + return fmt.Sprintf("%s-%s", adj, noun) +} + +// getRandomNameModerate generates a random name with two adjectives and one noun. +func GetRandomNameModerate() string { + adj1, adj2, noun := getRandomAdjAdjNoun() + return fmt.Sprintf("%s-%s-%s", adj1, adj2, noun) +} + +// getRandomNameComplex generates a random name with two adjectives, one noun, and a number. +func GetRandomNameComplex() string { + adj1, adj2, noun := getRandomAdjAdjNoun() + num := newRandInt(10000000) + return fmt.Sprintf("%s-%s-%s-%d", adj1, adj2, noun, num) +} + +// Helper function to get a random adjective and noun. +func getRandomAdjNoun() (string, string) { + var adj, noun string + + if time.Now().Month() == time.October { + adj = adjectives_halloween[newRandInt(int64(len(adjectives_halloween)))] + noun = noun_halloween[newRandInt(int64(len(noun_halloween)))] + } else { + adj = adjectives[newRandInt(int64(len(adjectives)))] + noun = nouns[newRandInt(int64(len(nouns)))] + } + return adj, noun +} + +// Helper function to get two random adjectives and a noun. +func getRandomAdjAdjNoun() (string, string, string) { + adj1, noun := getRandomAdjNoun() + var adj2 string + if time.Now().Month() == time.October { + adj2 = adjectives_halloween[newRandInt(int64(len(adjectives_halloween)))] + } else { + adj2 = adjectives[newRandInt(int64(len(adjectives)))] } - adj1Index := newRandInt(int64(len(adjectives))) - adj2Index := newRandInt(int64(len(adjectives))) - nounIndex := newRandInt(int64(len(nouns))) - randNum := newRandInt(10000000) - return fmt.Sprintf("%s-%s-%s-%d", adjectives[adj1Index], adjectives[adj2Index], nouns[nounIndex], randNum) + return adj1, adj2, noun } // cryptoRandSecure is not always secure, if it errors we return 1337 % max @@ -912,3 +942,12 @@ func newRandInt(max int64) int64 { } return nBig.Int64() } + +func IsCollision(beacons []*ent.Beacon, str string) bool { + for _, v := range beacons { + if v.Name == str { + return true + } + } + return false +} diff --git a/tavern/internal/namegen/namegen_test.go b/tavern/internal/namegen/namegen_test.go index 9a22bb4dc..8987bba76 100644 --- a/tavern/internal/namegen/namegen_test.go +++ b/tavern/internal/namegen/namegen_test.go @@ -2,16 +2,20 @@ package namegen_test import ( "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "realm.pub/tavern/internal/namegen" + "realm.pub/tavern/internal/ent" ) func TestGetRandomName(t *testing.T) { t.Run("BasicName", func(t *testing.T) { - name := namegen.GetRandomName() - assert.NotEmpty(t, name) + name1 := namegen.GetRandomNameSimple() + assert.NotEmpty(t, name1) + name2 := namegen.GetRandomNameModerate() + assert.NotEmpty(t, name2) + name3 := namegen.GetRandomNameComplex() + assert.NotEmpty(t, name3) }) t.Run("NoDuplicates", func(t *testing.T) { @@ -19,7 +23,7 @@ func TestGetRandomName(t *testing.T) { names := make(map[string]bool, 1000000) count := 0 for i := 0; i < 1000000; i++ { - name := namegen.GetRandomName() + name := namegen.GetRandomNameComplex() exists, ok := names[name] require.False(t, ok, "Name %s already exists - after %d attempts", name, count) assert.False(t, exists) @@ -29,3 +33,41 @@ func TestGetRandomName(t *testing.T) { }) } + +// TestBeaconnameinstring tests the Beaconnameinstring function +func TestBeaconnameinstring(t *testing.T) { + testCases := []struct { + name string + beacons []*ent.Beacon + str string + expected bool + }{ + { + name: "String matches a beacon name", + beacons: []*ent.Beacon{ + {Name: "Alpha"}, + {Name: "Beta"}, + }, + str: "Beta", + expected: true, + }, + { + name: "String does not match any beacon name", + beacons: []*ent.Beacon{ + {Name: "Alpha"}, + {Name: "Beta"}, + }, + str: "Gamma", + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := namegen.IsCollision(tc.beacons, tc.str) + if result != tc.expected { + t.Errorf("Beaconnameinstring(%v, %s) = %v; expected %v", tc.beacons, tc.str, result, tc.expected) + } + }) + } +} diff --git a/tavern/test_data.go b/tavern/test_data.go index 689c79d52..9eb88453f 100644 --- a/tavern/test_data.go +++ b/tavern/test_data.go @@ -351,14 +351,14 @@ None func createQuest(ctx context.Context, client *ent.Client, beacons ...*ent.Beacon) { // Mid-Execution testTome := client.Tome.Create(). - SetName(namegen.GetRandomName()). + SetName(namegen.GetRandomNameComplex()). SetDescription("Print a message for fun!"). SetEldritch(`print(input_params['msg'])`). SetParamDefs(`[{"name":"msg","label":"Message","type":"string","placeholder":"something to print"}]`). SaveX(ctx) q := client.Quest.Create(). - SetName(namegen.GetRandomName()). + SetName(namegen.GetRandomNameComplex()). SetParameters(`{"msg":"Hello World!"}`). SetTome(testTome). SaveX(ctx)