From 8f346201e221a136adf35fd164b1b17106f72b94 Mon Sep 17 00:00:00 2001 From: KCarretto Date: Tue, 23 Jan 2024 03:00:08 +0000 Subject: [PATCH 1/5] tavern app coverage tests --- tavern/app.go | 14 ++++++- tavern/config.go | 9 ++++- tavern/env_test.go | 92 +++++++++++++++++++++++++++++++++++++++++++ tavern/main.go | 4 +- tavern/main_test.go | 19 +++++++++ tavern/status_test.go | 46 ++++++++++++++++++++++ 6 files changed, 180 insertions(+), 4 deletions(-) create mode 100644 tavern/env_test.go create mode 100644 tavern/main_test.go create mode 100644 tavern/status_test.go diff --git a/tavern/app.go b/tavern/app.go index 6d01e77ea..67df729f2 100644 --- a/tavern/app.go +++ b/tavern/app.go @@ -177,10 +177,20 @@ func NewServer(ctx context.Context, options ...func(*Config)) (*Server, error) { return nil, fmt.Errorf("failed to configure http/2: %w", err) } - return &Server{ + // Initialize Server + tSrv := &Server{ HTTP: cfg.srv, client: client, - }, nil + } + + // Shutdown for Test Run & Exit + if cfg.IsTestRunAndExitEnabled() { + go func() { + tSrv.HTTP.Shutdown(ctx) + }() + } + + return tSrv, nil } func newGraphQLHandler(client *ent.Client) http.Handler { diff --git a/tavern/config.go b/tavern/config.go index 242f7fe98..e9863813f 100644 --- a/tavern/config.go +++ b/tavern/config.go @@ -17,7 +17,9 @@ import ( var ( // EnvEnableTestData if set will populate the database with test data. - EnvEnableTestData = EnvString{"ENABLE_TEST_DATA", ""} + // EnvEnableTestRunAndExit will start the application, but exit immediately after. + EnvEnableTestData = EnvString{"ENABLE_TEST_DATA", ""} + EnvEnableTestRunAndExit = EnvString{"ENABLE_TEST_RUN_AND_EXIT", ""} // EnvOAuthClientID set to configure OAuth Client ID. // EnvOAuthClientSecret set to configure OAuth Client Secret. @@ -112,6 +114,11 @@ func (cfg *Config) IsTestDataEnabled() bool { return EnvEnableTestData.String() != "" } +// IsTestRunAndExitEnabled returns true if a value for the "ENABLE_TEST_RUN_AND_EXIT" environment variable is set. +func (cfg *Config) IsTestRunAndExitEnabled() bool { + return EnvEnableTestRunAndExit.String() != "" +} + // ConfigureHTTPServer enables the configuration of the Tavern HTTP server. The endpoint field will be // overwritten with Tavern's HTTP handler when Tavern is run. func ConfigureHTTPServer(address string, options ...func(*http.Server)) func(*Config) { diff --git a/tavern/env_test.go b/tavern/env_test.go new file mode 100644 index 000000000..53aa29c38 --- /dev/null +++ b/tavern/env_test.go @@ -0,0 +1,92 @@ +package main + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEnvString(t *testing.T) { + // Test Cases + tests := []struct { + name string + + env EnvString + osValue string + wantValue string + }{ + { + name: "Set", + env: EnvString{"TEST_ENV_STRING", ""}, + osValue: "VALUE_SET", + wantValue: "VALUE_SET", + }, + { + name: "Unset", + env: EnvString{"TEST_ENV_STRING", ""}, + osValue: "", + wantValue: "", + }, + { + name: "Default", + env: EnvString{"TEST_ENV_STRING", "BLAH_BLAH"}, + osValue: "", + wantValue: "BLAH_BLAH", + }, + } + + // Run Tests + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.osValue != "" { + os.Setenv(tc.env.Key, tc.osValue) + defer os.Unsetenv(tc.env.Key) + } + + assert.Equal(t, tc.wantValue, tc.env.String()) + }) + } +} + +func TestEnvInteger(t *testing.T) { + // Test Cases + tests := []struct { + name string + + env EnvInteger + osValue string + wantValue int + }{ + { + name: "Set", + env: EnvInteger{"TEST_ENV_INT", 0}, + osValue: "123", + wantValue: 123, + }, + { + name: "Unset", + env: EnvInteger{"TEST_ENV_INT", 0}, + osValue: "", + wantValue: 0, + }, + { + name: "Default", + env: EnvInteger{"TEST_ENV_INT", 456}, + osValue: "", + wantValue: 456, + }, + } + + // Run Tests + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if tc.osValue != "" { + os.Setenv(tc.env.Key, tc.osValue) + defer os.Unsetenv(tc.env.Key) + } + + assert.Equal(t, tc.wantValue, tc.env.Int()) + }) + } +} diff --git a/tavern/main.go b/tavern/main.go index 6ea1fad7f..79bf7d01e 100644 --- a/tavern/main.go +++ b/tavern/main.go @@ -2,7 +2,9 @@ package main import ( "context" + "errors" "log" + "net/http" "os" _ "realm.pub/tavern/internal/ent/runtime" @@ -17,7 +19,7 @@ func main() { ConfigureMySQLFromEnv(), ConfigureOAuthFromEnv("/oauth/authorize"), ) - if err := app.Run(os.Args); err != nil { + if err := app.Run(os.Args); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("fatal error: %v", err) } } diff --git a/tavern/main_test.go b/tavern/main_test.go new file mode 100644 index 000000000..866af30a3 --- /dev/null +++ b/tavern/main_test.go @@ -0,0 +1,19 @@ +package main + +import ( + "os" + "testing" +) + +// TestMainFunc runs main after configuring the application to immediately exit. +// This validates our default configurations are successful. +func TestMainFunc(t *testing.T) { + os.Setenv(EnvEnableTestRunAndExit.Key, "1") + defer func() { + if err := os.Unsetenv(EnvEnableTestRunAndExit.Key); err != nil { + t.Fatalf("failed to unset env var %s: %v", EnvEnableTestRunAndExit.Key, err) + } + }() + os.Args = []string{"tavern"} + main() +} diff --git a/tavern/status_test.go b/tavern/status_test.go new file mode 100644 index 000000000..0ec44ec68 --- /dev/null +++ b/tavern/status_test.go @@ -0,0 +1,46 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStatusHandler(t *testing.T) { + // Setup Dependencies + handler := newStatusHandler() + + // Test Cases + tests := []struct { + name string + w *httptest.ResponseRecorder + r *http.Request + + wantCode int + wantBody string + }{ + { + name: "Successful", + w: httptest.NewRecorder(), + r: httptest.NewRequest(http.MethodGet, "/status", nil), + + wantCode: http.StatusOK, + wantBody: OKStatusText, + }, + } + + // Run Tests + for _, tc := range tests { + handler(tc.w, tc.r) + + body, err := io.ReadAll(tc.w.Body) + require.NoError(t, err) + + assert.Equal(t, tc.wantCode, tc.w.Code) + assert.Equal(t, tc.wantBody, string(body)) + } +} From 52df4e0d6556fe0cc839c3c4ec3fe0504eec5bb7 Mon Sep 17 00:00:00 2001 From: KCarretto Date: Tue, 23 Jan 2024 03:12:13 +0000 Subject: [PATCH 2/5] Add HTTP_LISTEN_ADDR env, use for test --- tavern/config.go | 7 +++++-- tavern/main.go | 2 +- tavern/main_test.go | 4 ++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/tavern/config.go b/tavern/config.go index e9863813f..88f150d03 100644 --- a/tavern/config.go +++ b/tavern/config.go @@ -21,6 +21,9 @@ var ( EnvEnableTestData = EnvString{"ENABLE_TEST_DATA", ""} EnvEnableTestRunAndExit = EnvString{"ENABLE_TEST_RUN_AND_EXIT", ""} + // EnvHTTPListenAddr sets the address (ip:port) for tavern's HTTP server to bind to. + EnvHTTPListenAddr = EnvString{"HTTP_LISTEN_ADDR", "0.0.0.0:80"} + // EnvOAuthClientID set to configure OAuth Client ID. // EnvOAuthClientSecret set to configure OAuth Client Secret. // EnvOAuthDomain set to configure OAuth domain for consent flow redirect. @@ -121,9 +124,9 @@ func (cfg *Config) IsTestRunAndExitEnabled() bool { // ConfigureHTTPServer enables the configuration of the Tavern HTTP server. The endpoint field will be // overwritten with Tavern's HTTP handler when Tavern is run. -func ConfigureHTTPServer(address string, options ...func(*http.Server)) func(*Config) { +func ConfigureHTTPServerFromEnv(options ...func(*http.Server)) func(*Config) { srv := &http.Server{ - Addr: address, + Addr: EnvHTTPListenAddr.String(), } for _, opt := range options { opt(srv) diff --git a/tavern/main.go b/tavern/main.go index 79bf7d01e..a74e2dbed 100644 --- a/tavern/main.go +++ b/tavern/main.go @@ -15,7 +15,7 @@ import ( func main() { ctx := context.Background() app := newApp(ctx, - ConfigureHTTPServer("0.0.0.0:80"), + ConfigureHTTPServerFromEnv(), ConfigureMySQLFromEnv(), ConfigureOAuthFromEnv("/oauth/authorize"), ) diff --git a/tavern/main_test.go b/tavern/main_test.go index 866af30a3..8b5984114 100644 --- a/tavern/main_test.go +++ b/tavern/main_test.go @@ -9,10 +9,14 @@ import ( // This validates our default configurations are successful. func TestMainFunc(t *testing.T) { os.Setenv(EnvEnableTestRunAndExit.Key, "1") + os.Setenv(EnvHTTPListenAddr.Key, "127.0.0.1:8080") defer func() { if err := os.Unsetenv(EnvEnableTestRunAndExit.Key); err != nil { t.Fatalf("failed to unset env var %s: %v", EnvEnableTestRunAndExit.Key, err) } + if err := os.Unsetenv(EnvHTTPListenAddr.Key); err != nil { + t.Fatalf("failed to unset env var %s: %v", EnvHTTPListenAddr.Key, err) + } }() os.Args = []string{"tavern"} main() From f5a921420e897829f694839aa84191a126cfbe87 Mon Sep 17 00:00:00 2001 From: KCarretto Date: Tue, 23 Jan 2024 03:22:02 +0000 Subject: [PATCH 3/5] improve coverage --- tavern/app.go | 3 ++- tavern/main_test.go | 13 +++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/tavern/app.go b/tavern/app.go index 67df729f2..4a680e86d 100644 --- a/tavern/app.go +++ b/tavern/app.go @@ -69,6 +69,7 @@ type Server struct { // Close should always be called to clean up a Tavern server. func (srv *Server) Close() error { + srv.HTTP.Shutdown(context.Background()) return srv.client.Close() } @@ -186,7 +187,7 @@ func NewServer(ctx context.Context, options ...func(*Config)) (*Server, error) { // Shutdown for Test Run & Exit if cfg.IsTestRunAndExitEnabled() { go func() { - tSrv.HTTP.Shutdown(ctx) + tSrv.Close() }() } diff --git a/tavern/main_test.go b/tavern/main_test.go index 8b5984114..bc8e6d37d 100644 --- a/tavern/main_test.go +++ b/tavern/main_test.go @@ -10,12 +10,17 @@ import ( func TestMainFunc(t *testing.T) { os.Setenv(EnvEnableTestRunAndExit.Key, "1") os.Setenv(EnvHTTPListenAddr.Key, "127.0.0.1:8080") + os.Setenv(EnvEnablePProf.Key, "1") defer func() { - if err := os.Unsetenv(EnvEnableTestRunAndExit.Key); err != nil { - t.Fatalf("failed to unset env var %s: %v", EnvEnableTestRunAndExit.Key, err) + unsetList := []string{ + EnvEnableTestRunAndExit.Key, + EnvHTTPListenAddr.Key, + EnvEnablePProf.Key, } - if err := os.Unsetenv(EnvHTTPListenAddr.Key); err != nil { - t.Fatalf("failed to unset env var %s: %v", EnvHTTPListenAddr.Key, err) + for _, unset := range unsetList { + if err := os.Unsetenv(unset); err != nil { + t.Fatalf("failed to unset env var %s: %v", unset, err) + } } }() os.Args = []string{"tavern"} From 9e936646e2f68f3560c1577e891b1b1a4f0c7d67 Mon Sep 17 00:00:00 2001 From: KCarretto Date: Tue, 23 Jan 2024 03:48:26 +0000 Subject: [PATCH 4/5] Improved testing for report_task_output --- tavern/internal/c2/api_report_task_output.go | 59 ++++++++ .../c2/api_report_task_output_test.go | 134 ++++++++++++++++++ tavern/internal/c2/server.go | 43 ------ 3 files changed, 193 insertions(+), 43 deletions(-) create mode 100644 tavern/internal/c2/api_report_task_output.go create mode 100644 tavern/internal/c2/api_report_task_output_test.go diff --git a/tavern/internal/c2/api_report_task_output.go b/tavern/internal/c2/api_report_task_output.go new file mode 100644 index 000000000..45243011c --- /dev/null +++ b/tavern/internal/c2/api_report_task_output.go @@ -0,0 +1,59 @@ +package c2 + +import ( + "context" + "fmt" + "time" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "realm.pub/tavern/internal/c2/c2pb" + "realm.pub/tavern/internal/ent" +) + +func (srv *Server) ReportTaskOutput(ctx context.Context, req *c2pb.ReportTaskOutputRequest) (*c2pb.ReportTaskOutputResponse, error) { + // Validate Input + if req.Output == nil || req.Output.Id == 0 { + return nil, status.Errorf(codes.InvalidArgument, "must provide task id") + } + + // Parse Input + var ( + execStartedAt *time.Time + execFinishedAt *time.Time + taskErr *string + ) + if req.Output.ExecStartedAt != nil { + timestamp := req.Output.ExecStartedAt.AsTime() + execStartedAt = ×tamp + } + if req.Output.ExecFinishedAt != nil { + timestamp := req.Output.ExecFinishedAt.AsTime() + execFinishedAt = ×tamp + } + if req.Output.Error != nil { + taskErr = &req.Output.Error.Msg + } + + // Load Task + t, err := srv.graph.Task.Get(ctx, int(req.Output.Id)) + if ent.IsNotFound(err) { + return nil, status.Errorf(codes.NotFound, "no task found (id=%d): %v", req.Output.Id, err) + } + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to submit task result (id=%d): %v", req.Output.Id, err) + } + + // Update Task + _, err = t.Update(). + SetNillableExecStartedAt(execStartedAt). + SetOutput(fmt.Sprintf("%s%s", t.Output, req.Output.Output)). + SetNillableExecFinishedAt(execFinishedAt). + SetNillableError(taskErr). + Save(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to save submitted task result (id=%d): %v", t.ID, err) + } + + return &c2pb.ReportTaskOutputResponse{}, nil +} diff --git a/tavern/internal/c2/api_report_task_output_test.go b/tavern/internal/c2/api_report_task_output_test.go new file mode 100644 index 000000000..4f2138d01 --- /dev/null +++ b/tavern/internal/c2/api_report_task_output_test.go @@ -0,0 +1,134 @@ +package c2_test + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/timestamppb" + "realm.pub/tavern/internal/c2/c2pb" + "realm.pub/tavern/internal/c2/c2test" + "realm.pub/tavern/internal/ent" +) + +func TestReportTaskOutput(t *testing.T) { + // Setup Dependencies + ctx := context.Background() + client, graph, close := c2test.New(t) + defer close() + + // Test Data + now := timestamppb.Now() + finishedAt := timestamppb.New(time.Now().UTC().Add(10 * time.Minute)) + existingBeacon := c2test.NewRandomBeacon(ctx, graph) + existingTasks := []*ent.Task{ + c2test.NewRandomAssignedTask(ctx, graph, existingBeacon.Identifier), + c2test.NewRandomAssignedTask(ctx, graph, existingBeacon.Identifier), + } + + // Test Cases + tests := []struct { + name string + req *c2pb.ReportTaskOutputRequest + wantResp *c2pb.ReportTaskOutputResponse + wantCode codes.Code + wantOutput string + wantExecStartedAt *timestamppb.Timestamp + wantExecFinishedAt *timestamppb.Timestamp + }{ + { + name: "First_Output", + req: &c2pb.ReportTaskOutputRequest{ + Output: &c2pb.TaskOutput{ + Id: int64(existingTasks[0].ID), + Output: "TestOutput", + ExecStartedAt: now, + }, + }, + wantResp: &c2pb.ReportTaskOutputResponse{}, + wantCode: codes.OK, + wantOutput: "TestOutput", + wantExecStartedAt: now, + }, + { + name: "Append_Output", + req: &c2pb.ReportTaskOutputRequest{ + Output: &c2pb.TaskOutput{ + Id: int64(existingTasks[0].ID), + Output: "_AppendedOutput", + }, + }, + wantResp: &c2pb.ReportTaskOutputResponse{}, + wantCode: codes.OK, + wantOutput: "TestOutput_AppendedOutput", + wantExecStartedAt: now, + }, + { + name: "Exec_Finished", + req: &c2pb.ReportTaskOutputRequest{ + Output: &c2pb.TaskOutput{ + Id: int64(existingTasks[0].ID), + ExecFinishedAt: finishedAt, + }, + }, + wantResp: &c2pb.ReportTaskOutputResponse{}, + wantCode: codes.OK, + wantOutput: "TestOutput_AppendedOutput", + wantExecStartedAt: now, + wantExecFinishedAt: finishedAt, + }, + { + name: "Not_Found", + req: &c2pb.ReportTaskOutputRequest{ + Output: &c2pb.TaskOutput{ + Id: 999888777666, + }, + }, + wantResp: nil, + wantCode: codes.NotFound, + }, + { + name: "Invalid_Argument", + req: &c2pb.ReportTaskOutputRequest{ + Output: &c2pb.TaskOutput{}, + }, + wantResp: nil, + wantCode: codes.InvalidArgument, + }, + } + + // Run Tests + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Callback + resp, err := client.ReportTaskOutput(ctx, tc.req) + + // Assert Response Code + require.Equal(t, tc.wantCode.String(), status.Code(err).String(), err) + if status.Code(err) != codes.OK { + // Do not continue if we expected error code + return + } + + // Assert Response + if diff := cmp.Diff(tc.wantResp, resp, protocmp.Transform()); diff != "" { + t.Errorf("invalid response (-want +got): %v", diff) + } + + // Load Task + + testTask, err := graph.Task.Get(ctx, int(tc.req.Output.Id)) + require.NoError(t, err) + + // Task Assertions + assert.Equal(t, tc.wantOutput, testTask.Output) + }) + } + +} diff --git a/tavern/internal/c2/server.go b/tavern/internal/c2/server.go index adb9dffee..e414c1bcd 100644 --- a/tavern/internal/c2/server.go +++ b/tavern/internal/c2/server.go @@ -1,10 +1,6 @@ package c2 import ( - "context" - "fmt" - "time" - "realm.pub/tavern/internal/c2/c2pb" "realm.pub/tavern/internal/ent" ) @@ -21,42 +17,3 @@ func New(graph *ent.Client) *Server { graph: graph, } } - -func (srv *Server) ReportTaskOutput(ctx context.Context, req *c2pb.ReportTaskOutputRequest) (*c2pb.ReportTaskOutputResponse, error) { - // 1. Parse Input - var ( - execStartedAt *time.Time - execFinishedAt *time.Time - taskErr *string - ) - if req.Output.ExecStartedAt != nil { - timestamp := req.Output.ExecStartedAt.AsTime() - execStartedAt = ×tamp - } - if req.Output.ExecFinishedAt != nil { - timestamp := req.Output.ExecFinishedAt.AsTime() - execFinishedAt = ×tamp - } - if req.Output.Error != nil { - taskErr = &req.Output.Error.Msg - } - - // 2. Load the task - t, err := srv.graph.Task.Get(ctx, int(req.Output.Id)) - if err != nil { - return nil, fmt.Errorf("failed to submit task result (id=%d): %w", req.Output.Id, err) - } - - // 3. Update task info - _, err = t.Update(). - SetNillableExecStartedAt(execStartedAt). - SetOutput(fmt.Sprintf("%s%s", t.Output, req.Output.Output)). - SetNillableExecFinishedAt(execFinishedAt). - SetNillableError(taskErr). - Save(ctx) - if err != nil { - return nil, fmt.Errorf("failed to save submitted task result (id=%d): %w", t.ID, err) - } - - return &c2pb.ReportTaskOutputResponse{}, nil -} From 63097aa514ec92bbab4b6fcc00c2e74952504657 Mon Sep 17 00:00:00 2001 From: KCarretto Date: Tue, 23 Jan 2024 03:56:13 +0000 Subject: [PATCH 5/5] Added additional test case for report_process_list --- tavern/internal/c2/api_report_process_list_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tavern/internal/c2/api_report_process_list_test.go b/tavern/internal/c2/api_report_process_list_test.go index 716c75a25..5f15e2053 100644 --- a/tavern/internal/c2/api_report_process_list_test.go +++ b/tavern/internal/c2/api_report_process_list_test.go @@ -95,6 +95,17 @@ func TestReportProcessList(t *testing.T) { wantResp: nil, wantCode: codes.InvalidArgument, }, + { + name: "Not_Found", + req: &c2pb.ReportProcessListRequest{ + TaskId: 99888777776666, + List: []*c2pb.Process{ + {Pid: 1, Name: "systemd", Principal: "root"}, + }, + }, + wantResp: nil, + wantCode: codes.NotFound, + }, } // Run Tests