diff --git a/bin/socks5/proxy_test.go b/bin/socks5/proxy_test.go index a4ca5f588..ca4ef071f 100644 --- a/bin/socks5/proxy_test.go +++ b/bin/socks5/proxy_test.go @@ -18,7 +18,7 @@ import ( // MockPortalServer implements the Portal service. type MockPortalServer struct { portalpb.UnimplementedPortalServer - mu sync.Mutex + mu sync.Mutex // For testing, we can hook into received motes onMote func(mote *portalpb.Mote, send func(*portalpb.Mote)) @@ -81,7 +81,9 @@ func BenchmarkProxy(b *testing.B) { sent := 0 data := make([]byte, chunkSize) // Fill with some data - for i := range data { data[i] = byte(i) } + for i := range data { + data[i] = byte(i) + } currentSeq := uint64(0) diff --git a/tavern/internal/c2/api_claim_tasks_test.go b/tavern/internal/c2/api_claim_tasks_test.go index f188fc7b2..e0091b342 100644 --- a/tavern/internal/c2/api_claim_tasks_test.go +++ b/tavern/internal/c2/api_claim_tasks_test.go @@ -20,7 +20,7 @@ import ( func TestClaimTasks(t *testing.T) { // Setup Dependencies ctx := context.Background() - client, graph, close := c2test.New(t) + client, graph, close, _ := c2test.New(t) defer close() // Test Data diff --git a/tavern/internal/c2/api_fetch_asset_test.go b/tavern/internal/c2/api_fetch_asset_test.go index 5522fd2c7..ea0bb24c2 100644 --- a/tavern/internal/c2/api_fetch_asset_test.go +++ b/tavern/internal/c2/api_fetch_asset_test.go @@ -21,7 +21,7 @@ import ( func TestFetchAsset(t *testing.T) { // Setup Dependencies ctx := context.Background() - client, graph, close := c2test.New(t) + client, graph, close, token := c2test.New(t) defer close() // Test Cases @@ -67,6 +67,13 @@ func TestFetchAsset(t *testing.T) { SetContent(data). SaveX(ctx) + // Ensure request contains JWT + if tc.req.Context == nil { + tc.req.Context = &c2pb.TaskContext{Jwt: token} + } else { + tc.req.Context.Jwt = token + } + // Send Request fileClient, err := client.FetchAsset(ctx, tc.req) require.NoError(t, err) diff --git a/tavern/internal/c2/api_report_credential_test.go b/tavern/internal/c2/api_report_credential_test.go index aca16ef92..9501f0525 100644 --- a/tavern/internal/c2/api_report_credential_test.go +++ b/tavern/internal/c2/api_report_credential_test.go @@ -20,7 +20,7 @@ import ( func TestReportCredential(t *testing.T) { // Setup Dependencies ctx := context.Background() - client, graph, close := c2test.New(t) + client, graph, close, token := c2test.New(t) defer close() // Test Data @@ -51,7 +51,7 @@ func TestReportCredential(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportCredentialRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, Credential: &epb.Credential{ Principal: existingCredential.Principal, Secret: existingCredential.Secret, @@ -78,7 +78,7 @@ func TestReportCredential(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportCredentialRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, Credential: &epb.Credential{ Principal: "root", Secret: "changeme123", @@ -125,7 +125,7 @@ func TestReportCredential(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportCredentialRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, }, wantResp: nil, wantCode: codes.InvalidArgument, @@ -133,7 +133,7 @@ func TestReportCredential(t *testing.T) { { name: "NotFound", req: &c2pb.ReportCredentialRequest{ - Context: &c2pb.TaskContext{TaskId: 99888777776666}, + Context: &c2pb.TaskContext{TaskId: 99888777776666, Jwt: token}, Credential: &epb.Credential{ Principal: "root", Secret: "oopsies", diff --git a/tavern/internal/c2/api_report_file_test.go b/tavern/internal/c2/api_report_file_test.go index 8055ddfa0..34b43ed4d 100644 --- a/tavern/internal/c2/api_report_file_test.go +++ b/tavern/internal/c2/api_report_file_test.go @@ -20,7 +20,7 @@ import ( func TestReportFile(t *testing.T) { // Setup Dependencies ctx := context.Background() - client, graph, close := c2test.New(t) + client, graph, close, token := c2test.New(t) defer close() // Test Data @@ -99,7 +99,7 @@ func TestReportFile(t *testing.T) { name: "MissingPath", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: 1234}, + Context: &c2pb.TaskContext{TaskId: 1234, Jwt: token}, }, }, wantCode: codes.InvalidArgument, @@ -108,7 +108,7 @@ func TestReportFile(t *testing.T) { name: "NewFile_Single", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, Chunk: &epb.File{ Metadata: &epb.FileMetadata{ Path: "/new/file", @@ -142,7 +142,7 @@ func TestReportFile(t *testing.T) { name: "NewFile_MultiChunk", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, Chunk: &epb.File{ Metadata: &epb.FileMetadata{ Path: "/another/new/file", @@ -174,7 +174,7 @@ func TestReportFile(t *testing.T) { name: "Replace_File", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTasks[2].ID), Jwt: token}, Chunk: &epb.File{ Metadata: &epb.FileMetadata{ Path: "/another/new/file", @@ -201,7 +201,7 @@ func TestReportFile(t *testing.T) { name: "No_Prexisting_Files", reqs: []*c2pb.ReportFileRequest{ { - Context: &c2pb.TaskContext{TaskId: int64(existingTasks[3].ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTasks[3].ID), Jwt: token}, Chunk: &epb.File{ Metadata: &epb.FileMetadata{ Path: "/no/other/files", diff --git a/tavern/internal/c2/api_report_process_list_test.go b/tavern/internal/c2/api_report_process_list_test.go index 2df5e5bb7..dbc667df9 100644 --- a/tavern/internal/c2/api_report_process_list_test.go +++ b/tavern/internal/c2/api_report_process_list_test.go @@ -19,7 +19,7 @@ import ( func TestReportProcessList(t *testing.T) { // Setup Dependencies ctx := context.Background() - client, graph, close := c2test.New(t) + client, graph, close, token := c2test.New(t) defer close() // Test Data @@ -44,7 +44,7 @@ func TestReportProcessList(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportProcessListRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, List: &epb.ProcessList{ List: []*epb.Process{ {Pid: 1, Name: "systemd", Principal: "root", Status: epb.Process_STATUS_RUN}, @@ -63,7 +63,7 @@ func TestReportProcessList(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportProcessListRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, List: &epb.ProcessList{ List: []*epb.Process{ {Pid: 1, Name: "systemd", Principal: "root"}, @@ -96,7 +96,7 @@ func TestReportProcessList(t *testing.T) { host: existingHost, task: existingTask, req: &c2pb.ReportProcessListRequest{ - Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID)}, + Context: &c2pb.TaskContext{TaskId: int64(existingTask.ID), Jwt: token}, List: &epb.ProcessList{ List: []*epb.Process{}, }, @@ -107,7 +107,7 @@ func TestReportProcessList(t *testing.T) { { name: "Not_Found", req: &c2pb.ReportProcessListRequest{ - Context: &c2pb.TaskContext{TaskId: 99888777776666}, + Context: &c2pb.TaskContext{TaskId: 99888777776666, Jwt: token}, List: &epb.ProcessList{ List: []*epb.Process{ {Pid: 1, Name: "systemd", Principal: "root"}, diff --git a/tavern/internal/c2/api_report_task_output_test.go b/tavern/internal/c2/api_report_task_output_test.go index ac0326602..43a1e98a8 100644 --- a/tavern/internal/c2/api_report_task_output_test.go +++ b/tavern/internal/c2/api_report_task_output_test.go @@ -20,7 +20,7 @@ import ( func TestReportTaskOutput(t *testing.T) { // Setup Dependencies ctx := context.Background() - client, graph, close := c2test.New(t) + client, graph, close, token := c2test.New(t) defer close() // Test Data @@ -131,6 +131,13 @@ func TestReportTaskOutput(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { // Callback + // Ensure JWT present in request context + if tc.req.Context == nil { + tc.req.Context = &c2pb.TaskContext{Jwt: token} + } else { + tc.req.Context.Jwt = token + } + resp, err := client.ReportTaskOutput(ctx, tc.req) // Assert Response Code diff --git a/tavern/internal/c2/c2test/grpc.go b/tavern/internal/c2/c2test/grpc.go index ebf31c3c1..2fbdad64d 100644 --- a/tavern/internal/c2/c2test/grpc.go +++ b/tavern/internal/c2/c2test/grpc.go @@ -7,7 +7,9 @@ import ( "errors" "net" "testing" + "time" + "github.com/golang-jwt/jwt/v5" _ "github.com/mattn/go-sqlite3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,7 +26,7 @@ import ( "realm.pub/tavern/internal/portals/mux" ) -func New(t *testing.T) (c2pb.C2Client, *ent.Client, func()) { +func New(t *testing.T) (c2pb.C2Client, *ent.Client, func(), string) { t.Helper() ctx := context.Background() @@ -55,6 +57,19 @@ func New(t *testing.T) (c2pb.C2Client, *ent.Client, func()) { testPubKey, testPrivKey, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) + // Generate a signed JWT string for tests + claims := jwt.MapClaims{ + "iat": time.Now().Unix(), + "exp": time.Now().Add(1 * time.Hour).Unix(), + } + testToken := "" + { + token := jwt.NewWithClaims(jwt.SigningMethodEdDSA, claims) + s, err := token.SignedString(testPrivKey) + require.NoError(t, err) + testToken = s + } + // gRPC Server lis := bufconn.Listen(1024 * 1024 * 10) baseSrv := grpc.NewServer() @@ -82,5 +97,5 @@ func New(t *testing.T) (c2pb.C2Client, *ent.Client, func()) { if err := <-grpcErrCh; err != nil && !errors.Is(err, grpc.ErrServerStopped) { t.Fatalf("failed to serve grpc: %v", err) } - } + }, testToken } diff --git a/tavern/internal/c2/server.go b/tavern/internal/c2/server.go index bf4e10538..624a8a225 100644 --- a/tavern/internal/c2/server.go +++ b/tavern/internal/c2/server.go @@ -10,8 +10,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" "realm.pub/tavern/internal/c2/c2pb" "realm.pub/tavern/internal/ent" "realm.pub/tavern/internal/http/stream" @@ -111,19 +113,14 @@ func (srv *Server) ValidateJWT(jwttoken string) error { token, err := jwt.Parse(jwttoken, func(token *jwt.Token) (any, error) { // 1. Verify the signing method is EdDSA if _, ok := token.Method.(*jwt.SigningMethodEd25519); !ok { - // TODO: Uncomment with imixv1 delete - // return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - slog.Warn(fmt.Sprintf("unexpected signing method: %v", token.Header["alg"])) + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } // 2. Return the PUBLIC key for verification return srv.jwtPublicKey, nil }) if err != nil || !token.Valid { - // TODO: Uncomment with imixv1 delete - // return status.Errorf(codes.PermissionDenied, "invalid token: %v", err) - slog.Warn(fmt.Sprintf("invalid token: %v", err)) - return nil + return status.Errorf(codes.PermissionDenied, "invalid token: %v", err) } slog.Info(fmt.Sprintf("received valid JWT: %s", jwttoken)) diff --git a/tavern/internal/cryptocodec/cryptocodec.go b/tavern/internal/cryptocodec/cryptocodec.go index a24e1dcff..c80eb87b1 100644 --- a/tavern/internal/cryptocodec/cryptocodec.go +++ b/tavern/internal/cryptocodec/cryptocodec.go @@ -152,7 +152,6 @@ func (csvc *CryptoSvc) Decrypt(in_arr []byte) ([]byte, []byte) { client_pub_key_bytes := make([]byte, x25519.Size) copy(client_pub_key_bytes, in_arr[:x25519.Size]) - ids, err := goAllIds() if err != nil { slog.Error("failed to get goid") diff --git a/tavern/internal/ent/schema/link_test.go b/tavern/internal/ent/schema/link_test.go index 0acf6f6cc..dc554a2bc 100644 --- a/tavern/internal/ent/schema/link_test.go +++ b/tavern/internal/ent/schema/link_test.go @@ -41,4 +41,4 @@ func TestCreateLinkWithExplicitExpiresAt(t *testing.T) { // Verify the explicit expiresAt was used assert.WithinDuration(t, futureTime, link.ExpiresAt, time.Second) assert.Equal(t, 5, link.DownloadsRemaining) -} \ No newline at end of file +}