Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions tavern/internal/c2/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import (
"encoding/json"
"fmt"
"time"

"realm.pub/tavern/internal/c2/c2pb"
"realm.pub/tavern/internal/ent"
"realm.pub/tavern/internal/ent/beacon"
"realm.pub/tavern/internal/ent/host"
"realm.pub/tavern/internal/ent/task"
"realm.pub/tavern/internal/namegen"

)

type Server struct {
Expand Down Expand Up @@ -66,12 +67,50 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
if err != nil {
return nil, fmt.Errorf("failed to upsert host entity: %w", err)
}
// 2. check if beacon is new
beaconExists, err := srv.graph.Beacon.Query().Where(beacon.IdentifierEQ(req.Beacon.Identifier)).Exist(ctx)
if err != nil {
return nil, fmt.Errorf("failed to query beacon entity: %w", err)
}
var beaconnameaddr *string = nil
//3. if the beacon is new lets pick a name for it
if !beaconExists {
candidateNames := []string{
namegen.GetRandomNameSimple(),
namegen.GetRandomNameModerate(),
namegen.GetRandomNameComplex(),
}

collisions, err := srv.graph.Beacon.Query().Where(beacon.NameIn(candidateNames...)).All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to query beacon entity: %w", err)
}
if len(collisions) == 3 {
candidateNames := []string{
namegen.GetRandomNameSimple(),
namegen.GetRandomNameModerate(),
namegen.GetRandomNameComplex(),
}

collisions, err = srv.graph.Beacon.Query().Where(beacon.NameIn(candidateNames...)).All(ctx)
if err != nil {
return nil, fmt.Errorf("failed to query beacon entity: %w", err)
}
}
for _, canidate := range candidateNames {
if !namegen.IsCollision(collisions, canidate) {
beaconnameaddr = &canidate
break
}
}
}

// 2. Upsert the beacon
// 4. Upsert the beacon
beaconID, err := srv.graph.Beacon.Create().
SetPrincipal(req.Beacon.Principal).
SetIdentifier(req.Beacon.Identifier).
SetAgentIdentifier(req.Beacon.Agent.Identifier).
SetNillableName(beaconnameaddr).
SetHostID(hostID).
SetLastSeenAt(now).
SetInterval(req.Beacon.Interval).
Expand All @@ -82,7 +121,7 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
return nil, fmt.Errorf("failed to upsert beacon entity: %w", err)
}

// 3. Load Tasks
// 5. Load Tasks
tasks, err := srv.graph.Task.Query().
Where(task.And(
task.HasBeaconWith(beacon.ID(beaconID)),
Expand All @@ -93,22 +132,22 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
return nil, fmt.Errorf("failed to query tasks: %w", err)
}

// 4. Prepare Transaction for Claiming Tasks
// 6. Prepare Transaction for Claiming Tasks
tx, err := srv.graph.Tx(ctx)
if err != nil {
return nil, fmt.Errorf("failed to initialize transaction: %w", err)
}
client := tx.Client()

// 5. Rollback transaction if we panic
// 7. Rollback transaction if we panic
defer func() {
if v := recover(); v != nil {
tx.Rollback()
panic(v)
}
}()

// 6. Update all ClaimedAt timestamps to claim tasks
// 8. Update all ClaimedAt timestamps to claim tasks
// ** Note: If one fails to update, we roll back the transaction and return the error
taskIDs := make([]int, 0, len(tasks))
for _, t := range tasks {
Expand All @@ -121,12 +160,12 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
taskIDs = append(taskIDs, t.ID)
}

// 7. Commit the transaction
// 9. Commit the transaction
if err := tx.Commit(); err != nil {
return nil, rollback(tx, fmt.Errorf("failed to commit transaction: %w", err))
}

// 8. Load the tasks with our non transactional client (cannot use transaction after commit)
// 10. Load the tasks with our non transactional client (cannot use transaction after commit)
resp := c2pb.ClaimTasksResponse{}
resp.Tasks = make([]*c2pb.Task, 0, len(taskIDs))
for _, taskID := range taskIDs {
Expand Down Expand Up @@ -155,7 +194,7 @@ func (srv *Server) ClaimTasks(ctx context.Context, req *c2pb.ClaimTasksRequest)
})
}

// 9. Return claimed tasks
// 11. Return claimed tasks
return &resp, nil
}

Expand Down
2 changes: 1 addition & 1 deletion tavern/internal/ent/schema/beacon.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (Beacon) Fields() []ent.Field {
NotEmpty().
Unique().
Immutable().
DefaultFunc(namegen.GetRandomName).
DefaultFunc(namegen.GetRandomNameComplex).
Comment("A human readable identifier for the beacon."),
field.String("principal").
Optional().
Expand Down
67 changes: 53 additions & 14 deletions tavern/internal/namegen/namegen.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"log"
"math/big"
"time"

"realm.pub/tavern/internal/ent"
)

var (
Expand Down Expand Up @@ -886,21 +888,49 @@ var (
}
)

// GetRandomName generates a random name from the list of adjectives and surnames in this package
// formatted as "adjective-surname". For example 'focused-turing'.
func GetRandomName() string {
if time.Now().Month() == time.October && time.Now().Day() == 31 {
adj1IndexHalloween := newRandInt(int64(len(adjectives_halloween)))
adj2IndexHalloween := newRandInt(int64(len(adjectives_halloween)))
nounIndex := newRandInt(int64(len(noun_halloween)))
randNum := newRandInt(10000000)
return fmt.Sprintf("%s-%s-%s-%d", adjectives_halloween[adj1IndexHalloween], adjectives_halloween[adj2IndexHalloween], noun_halloween[nounIndex], randNum)
// getRandomNameSimple generates a random name with one adjective and one noun.
func GetRandomNameSimple() string {
adj, noun := getRandomAdjNoun()
return fmt.Sprintf("%s-%s", adj, noun)
}

// getRandomNameModerate generates a random name with two adjectives and one noun.
func GetRandomNameModerate() string {
adj1, adj2, noun := getRandomAdjAdjNoun()
return fmt.Sprintf("%s-%s-%s", adj1, adj2, noun)
}

// getRandomNameComplex generates a random name with two adjectives, one noun, and a number.
func GetRandomNameComplex() string {
adj1, adj2, noun := getRandomAdjAdjNoun()
num := newRandInt(10000000)
return fmt.Sprintf("%s-%s-%s-%d", adj1, adj2, noun, num)
}

// Helper function to get a random adjective and noun.
func getRandomAdjNoun() (string, string) {
var adj, noun string

if time.Now().Month() == time.October {
adj = adjectives_halloween[newRandInt(int64(len(adjectives_halloween)))]
noun = noun_halloween[newRandInt(int64(len(noun_halloween)))]
} else {
adj = adjectives[newRandInt(int64(len(adjectives)))]
noun = nouns[newRandInt(int64(len(nouns)))]
}
return adj, noun
}

// Helper function to get two random adjectives and a noun.
func getRandomAdjAdjNoun() (string, string, string) {
adj1, noun := getRandomAdjNoun()
var adj2 string
if time.Now().Month() == time.October {
adj2 = adjectives_halloween[newRandInt(int64(len(adjectives_halloween)))]
} else {
adj2 = adjectives[newRandInt(int64(len(adjectives)))]
}
adj1Index := newRandInt(int64(len(adjectives)))
adj2Index := newRandInt(int64(len(adjectives)))
nounIndex := newRandInt(int64(len(nouns)))
randNum := newRandInt(10000000)
return fmt.Sprintf("%s-%s-%s-%d", adjectives[adj1Index], adjectives[adj2Index], nouns[nounIndex], randNum)
return adj1, adj2, noun
}

// cryptoRandSecure is not always secure, if it errors we return 1337 % max
Expand All @@ -912,3 +942,12 @@ func newRandInt(max int64) int64 {
}
return nBig.Int64()
}

func IsCollision(beacons []*ent.Beacon, str string) bool {
for _, v := range beacons {
if v.Name == str {
return true
}
}
return false
}
50 changes: 46 additions & 4 deletions tavern/internal/namegen/namegen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,28 @@ package namegen_test

import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"realm.pub/tavern/internal/namegen"
"realm.pub/tavern/internal/ent"
)

func TestGetRandomName(t *testing.T) {
t.Run("BasicName", func(t *testing.T) {
name := namegen.GetRandomName()
assert.NotEmpty(t, name)
name1 := namegen.GetRandomNameSimple()
assert.NotEmpty(t, name1)
name2 := namegen.GetRandomNameModerate()
assert.NotEmpty(t, name2)
name3 := namegen.GetRandomNameComplex()
assert.NotEmpty(t, name3)
})

t.Run("NoDuplicates", func(t *testing.T) {
// Ensure we don't duplicate names over the course of many trials
names := make(map[string]bool, 1000000)
count := 0
for i := 0; i < 1000000; i++ {
name := namegen.GetRandomName()
name := namegen.GetRandomNameComplex()
exists, ok := names[name]
require.False(t, ok, "Name %s already exists - after %d attempts", name, count)
assert.False(t, exists)
Expand All @@ -29,3 +33,41 @@ func TestGetRandomName(t *testing.T) {
})

}

// TestBeaconnameinstring tests the Beaconnameinstring function
func TestBeaconnameinstring(t *testing.T) {
testCases := []struct {
name string
beacons []*ent.Beacon
str string
expected bool
}{
{
name: "String matches a beacon name",
beacons: []*ent.Beacon{
{Name: "Alpha"},
{Name: "Beta"},
},
str: "Beta",
expected: true,
},
{
name: "String does not match any beacon name",
beacons: []*ent.Beacon{
{Name: "Alpha"},
{Name: "Beta"},
},
str: "Gamma",
expected: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := namegen.IsCollision(tc.beacons, tc.str)
if result != tc.expected {
t.Errorf("Beaconnameinstring(%v, %s) = %v; expected %v", tc.beacons, tc.str, result, tc.expected)
}
})
}
}
4 changes: 2 additions & 2 deletions tavern/test_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,14 +351,14 @@ None
func createQuest(ctx context.Context, client *ent.Client, beacons ...*ent.Beacon) {
// Mid-Execution
testTome := client.Tome.Create().
SetName(namegen.GetRandomName()).
SetName(namegen.GetRandomNameComplex()).
SetDescription("Print a message for fun!").
SetEldritch(`print(input_params['msg'])`).
SetParamDefs(`[{"name":"msg","label":"Message","type":"string","placeholder":"something to print"}]`).
SaveX(ctx)

q := client.Quest.Create().
SetName(namegen.GetRandomName()).
SetName(namegen.GetRandomNameComplex()).
SetParameters(`{"msg":"Hello World!"}`).
SetTome(testTome).
SaveX(ctx)
Expand Down