diff --git a/instrumentation/opencensus/database/hypersql/sql_test.go b/instrumentation/opencensus/database/hypersql/sql_test.go index e3217d14..368dd1b0 100644 --- a/instrumentation/opencensus/database/hypersql/sql_test.go +++ b/instrumentation/opencensus/database/hypersql/sql_test.go @@ -1,6 +1,6 @@ package hypersql -// highly inspired by https://github.com/openzipkin-contrib/zipkin-go-sql/blob/master/driver_test.go +// highly inspired in https://github.com/openzipkin-contrib/zipkin-go-sql/blob/master/driver_test.go import ( "context" diff --git a/instrumentation/opentelemetry/database/hypersql/sql.go b/instrumentation/opentelemetry/database/hypersql/sql.go index 3b6ece86..45b987bf 100644 --- a/instrumentation/opentelemetry/database/hypersql/sql.go +++ b/instrumentation/opentelemetry/database/hypersql/sql.go @@ -111,6 +111,7 @@ func (in *interceptor) ConnPrepareContext(ctx context.Context, conn driver.ConnP span.SetAttribute(key, value) } + span.SetAttribute("db.statement", query) defer span.End() tx, err := conn.PrepareContext(ctx, query) diff --git a/sdk/google.golang.org/grpc/attributes_test.go b/sdk/google.golang.org/grpc/attributes_test.go index cd9331cf..b535fb02 100644 --- a/sdk/google.golang.org/grpc/attributes_test.go +++ b/sdk/google.golang.org/grpc/attributes_test.go @@ -13,7 +13,10 @@ func TestSetScalarAttributeSuccess(t *testing.T) { span := mock.NewSpan() setAttributesFromMetadata("request", md, span) - assert.Equal(t, "value_1", span.Attributes["rpc.request.metadata.key_1"].(string)) + assert.Equal(t, "value_1", span.ReadAttribute("rpc.request.metadata.key_1").(string)) + + _ = span.ReadAttribute("container_id") // needed in containarized envs + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) } func TestSetMultivalueAttributeSuccess(t *testing.T) { @@ -21,6 +24,9 @@ func TestSetMultivalueAttributeSuccess(t *testing.T) { span := mock.NewSpan() setAttributesFromMetadata("request", md, span) - assert.Equal(t, "value_1", span.Attributes["rpc.request.metadata.key_1[0]"].(string)) - assert.Equal(t, "value_2", span.Attributes["rpc.request.metadata.key_1[1]"].(string)) + assert.Equal(t, "value_1", span.ReadAttribute("rpc.request.metadata.key_1[0]").(string)) + assert.Equal(t, "value_2", span.ReadAttribute("rpc.request.metadata.key_1[1]").(string)) + + _ = span.ReadAttribute("container_id") // needed in containarized envs + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) } diff --git a/sdk/google.golang.org/grpc/client.go b/sdk/google.golang.org/grpc/client.go index baff9922..c04c0727 100644 --- a/sdk/google.golang.org/grpc/client.go +++ b/sdk/google.golang.org/grpc/client.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/hypertrace/goagent/sdk" + internalconfig "github.com/hypertrace/goagent/sdk/internal/config" "github.com/hypertrace/goagent/sdk/internal/container" "google.golang.org/grpc" "google.golang.org/grpc/metadata" @@ -20,6 +21,8 @@ func WrapUnaryClientInterceptor(delegateInterceptor grpc.UnaryClientInterceptor, defaultAttributes["container_id"] = containerID } + dataCaptureConfig := internalconfig.GetConfig().GetDataCapture() + return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { var header metadata.MD var trailer metadata.MD @@ -45,22 +48,26 @@ func WrapUnaryClientInterceptor(delegateInterceptor grpc.UnaryClientInterceptor, span.SetAttribute("rpc.method", pieces[1]) reqBody, err := marshalMessageableJSON(req) - if len(reqBody) > 0 && err == nil { + if dataCaptureConfig.RpcBody.Request.Value && len(reqBody) > 0 && err == nil { span.SetAttribute("rpc.request.body", string(reqBody)) } - setAttributesFromRequestOutgoingMetadata(ctx, span) + if dataCaptureConfig.RpcMetadata.Request.Value { + setAttributesFromRequestOutgoingMetadata(ctx, span) + } err = invoker(ctx, method, req, reply, cc, opts...) if err != nil { return err } - setAttributesFromMetadata("response", header, span) - setAttributesFromMetadata("response", trailer, span) + if dataCaptureConfig.RpcMetadata.Response.Value { + setAttributesFromMetadata("response", header, span) + setAttributesFromMetadata("response", trailer, span) + } resBody, err := marshalMessageableJSON(reply) - if len(resBody) > 0 && err == nil { + if dataCaptureConfig.RpcBody.Response.Value && len(resBody) > 0 && err == nil { span.SetAttribute("rpc.response.body", string(resBody)) } diff --git a/sdk/google.golang.org/grpc/client_test.go b/sdk/google.golang.org/grpc/client_test.go index 898aa26e..be6d4447 100644 --- a/sdk/google.golang.org/grpc/client_test.go +++ b/sdk/google.golang.org/grpc/client_test.go @@ -68,15 +68,16 @@ func TestUnaryClientHelloWorldSuccess(t *testing.T) { span := spans[0] - assert.Equal(t, "grpc", span.Attributes["rpc.system"].(string)) - assert.Equal(t, "helloworld.Greeter", span.Attributes["rpc.service"].(string)) - assert.Equal(t, "SayHello", span.Attributes["rpc.method"].(string)) - assert.Equal(t, "test_value_1", span.Attributes["rpc.request.metadata.test_key_1"].(string)) - assert.Equal(t, "test_header_value", span.Attributes["rpc.response.metadata.test_header_key"].(string)) - assert.Equal(t, "test_trailer_value", span.Attributes["rpc.response.metadata.test_trailer_key"].(string)) + assert.Equal(t, "grpc", span.ReadAttribute("rpc.system").(string)) + assert.Equal(t, "helloworld.Greeter", span.ReadAttribute("rpc.service").(string)) + assert.Equal(t, "SayHello", span.ReadAttribute("rpc.method").(string)) + assert.Equal(t, "test_value_1", span.ReadAttribute("rpc.request.metadata.test_key_1").(string)) + assert.Equal(t, "test_header_value", span.ReadAttribute("rpc.response.metadata.test_header_key").(string)) + assert.Equal(t, "test_trailer_value", span.ReadAttribute("rpc.response.metadata.test_trailer_key").(string)) + assert.Equal(t, "application/grpc", span.ReadAttribute("rpc.response.metadata.content-type").(string)) expectedBody := "{\"name\":\"Cuchi\"}" - actualBody := span.Attributes["rpc.request.body"].(string) + actualBody := span.ReadAttribute("rpc.request.body").(string) if ok, err := jsonEqual(expectedBody, actualBody); err == nil { assert.True(t, ok, "incorrect request body:\nwant %s,\nhave %s", expectedBody, actualBody) } else { @@ -84,12 +85,15 @@ func TestUnaryClientHelloWorldSuccess(t *testing.T) { } expectedBody = "{\"message\":\"Hello Cuchi\"}" - actualBody = span.Attributes["rpc.response.body"].(string) + actualBody = span.ReadAttribute("rpc.response.body").(string) if ok, err := jsonEqual(expectedBody, actualBody); err == nil { assert.True(t, ok, "incorrect response body:\nwant %s,\nhave %s", expectedBody, actualBody) } else { t.Fatalf("unexpected error: %v", err) } + + _ = span.ReadAttribute("container_id") // needed in containarized envs + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) } func TestClientHandlerHelloWorldSuccess(t *testing.T) { @@ -135,15 +139,15 @@ func TestClientHandlerHelloWorldSuccess(t *testing.T) { span := mockHandler.Spans[0] - assert.Equal(t, "grpc", span.Attributes["rpc.system"].(string)) - assert.Equal(t, "helloworld.Greeter", span.Attributes["rpc.service"].(string)) - assert.Equal(t, "SayHello", span.Attributes["rpc.method"].(string)) - assert.Equal(t, "test_value_1", span.Attributes["rpc.request.metadata.test_key_1"].(string)) - assert.Equal(t, "test_header_value", span.Attributes["rpc.response.metadata.test_header_key"].(string)) - assert.Equal(t, "test_trailer_value", span.Attributes["rpc.response.metadata.test_trailer_key"].(string)) + assert.Equal(t, "grpc", span.ReadAttribute("rpc.system").(string)) + assert.Equal(t, "helloworld.Greeter", span.ReadAttribute("rpc.service").(string)) + assert.Equal(t, "SayHello", span.ReadAttribute("rpc.method").(string)) + assert.Equal(t, "test_value_1", span.ReadAttribute("rpc.request.metadata.test_key_1").(string)) + assert.Equal(t, "test_header_value", span.ReadAttribute("rpc.response.metadata.test_header_key").(string)) + assert.Equal(t, "test_trailer_value", span.ReadAttribute("rpc.response.metadata.test_trailer_key").(string)) expectedBody := "{\"name\":\"Cuchi\"}" - actualBody := span.Attributes["rpc.request.body"].(string) + actualBody := span.ReadAttribute("rpc.request.body").(string) if ok, err := jsonEqual(expectedBody, actualBody); err == nil { assert.True(t, ok, "incorrect request body:\nwant %s,\nhave %s", expectedBody, actualBody) } else { @@ -151,7 +155,7 @@ func TestClientHandlerHelloWorldSuccess(t *testing.T) { } expectedBody = "{\"message\":\"Hello Cuchi\"}" - actualBody = span.Attributes["rpc.response.body"].(string) + actualBody = span.ReadAttribute("rpc.response.body").(string) if ok, err := jsonEqual(expectedBody, actualBody); err == nil { assert.True(t, ok, "incorrect response body:\nwant %s,\nhave %s", expectedBody, actualBody) } else { diff --git a/sdk/google.golang.org/grpc/server.go b/sdk/google.golang.org/grpc/server.go index 0b51fc09..46d4b156 100644 --- a/sdk/google.golang.org/grpc/server.go +++ b/sdk/google.golang.org/grpc/server.go @@ -4,7 +4,9 @@ import ( "context" "strings" + "github.com/hypertrace/goagent/config" "github.com/hypertrace/goagent/sdk" + internalconfig "github.com/hypertrace/goagent/sdk/internal/config" "github.com/hypertrace/goagent/sdk/internal/container" "google.golang.org/grpc" "google.golang.org/grpc/stats" @@ -36,7 +38,7 @@ func WrapUnaryServerInterceptor( ctx, req, info, - wrapHandler(info.FullMethod, handler, spanFromContext, defaultAttributes), + wrapHandler(info.FullMethod, handler, spanFromContext, defaultAttributes, internalconfig.GetConfig().GetDataCapture()), ) } } @@ -46,6 +48,7 @@ func wrapHandler( delegateHandler grpc.UnaryHandler, spanFromContext sdk.SpanFromContext, defaultAttributes map[string]string, + dataCaptureConfig *config.DataCapture, ) grpc.UnaryHandler { return func(ctx context.Context, req interface{}) (interface{}, error) { span := spanFromContext(ctx) @@ -65,11 +68,14 @@ func wrapHandler( span.SetAttribute("rpc.method", pieces[1]) reqBody, err := marshalMessageableJSON(req) - if len(reqBody) > 0 && err == nil { + if dataCaptureConfig.RpcBody.Request.Value && + len(reqBody) > 0 && err == nil { span.SetAttribute("rpc.request.body", string(reqBody)) } - setAttributesFromRequestIncomingMetadata(ctx, span) + if dataCaptureConfig.RpcMetadata.Request.Value { + setAttributesFromRequestIncomingMetadata(ctx, span) + } res, err := delegateHandler(ctx, req) if err != nil { @@ -77,7 +83,8 @@ func wrapHandler( } resBody, err := marshalMessageableJSON(res) - if len(resBody) > 0 && err == nil { + if dataCaptureConfig.RpcBody.Response.Value && + len(resBody) > 0 && err == nil { span.SetAttribute("rpc.response.body", string(resBody)) } @@ -91,6 +98,7 @@ type handler struct { stats.Handler spanFromContext sdk.SpanFromContext defaultAttributes map[string]string + dataCaptureConfig *config.DataCapture } // HandleRPC implements per-RPC tracing and stats instrumentation. @@ -102,7 +110,7 @@ func (s *handler) HandleRPC(ctx context.Context, rs stats.RPCStats) { // isNoop means either the span is not sampled or there was no span // in the request context which means this Handler is not used // inside an instrumented Handler, hence we just invoke the delegate - // round tripper. + // handler. return } @@ -117,21 +125,21 @@ func (s *handler) HandleRPC(ctx context.Context, rs stats.RPCStats) { return } - if rs.IsClient() { + if rs.IsClient() && s.dataCaptureConfig.RpcBody.Response.Value { span.SetAttribute("rpc.response.body", string(body)) - } else { + } else if !rs.IsClient() && s.dataCaptureConfig.RpcBody.Request.Value { span.SetAttribute("rpc.request.body", string(body)) } case *stats.InHeader: - if rs.IsClient() { + if rs.IsClient() && s.dataCaptureConfig.RpcMetadata.Response.Value { setAttributesFromMetadata("response", rs.Header, span) - } else { + } else if !rs.IsClient() && s.dataCaptureConfig.RpcMetadata.Request.Value { setAttributesFromMetadata("request", rs.Header, span) } case *stats.InTrailer: - if rs.IsClient() { + if rs.IsClient() && s.dataCaptureConfig.RpcMetadata.Response.Value { setAttributesFromMetadata("response", rs.Trailer, span) - } else { + } else if !rs.IsClient() && s.dataCaptureConfig.RpcMetadata.Request.Value { setAttributesFromMetadata("request", rs.Trailer, span) } case *stats.OutPayload: @@ -140,21 +148,21 @@ func (s *handler) HandleRPC(ctx context.Context, rs stats.RPCStats) { return } - if rs.IsClient() { + if rs.IsClient() && s.dataCaptureConfig.RpcBody.Request.Value { span.SetAttribute("rpc.request.body", string(body)) - } else { + } else if !rs.IsClient() && s.dataCaptureConfig.RpcBody.Response.Value { span.SetAttribute("rpc.response.body", string(body)) } case *stats.OutHeader: - if rs.IsClient() { + if rs.IsClient() && s.dataCaptureConfig.RpcMetadata.Request.Value { setAttributesFromMetadata("request", rs.Header, span) - } else { + } else if !rs.IsClient() && s.dataCaptureConfig.RpcMetadata.Response.Value { setAttributesFromMetadata("response", rs.Header, span) } case *stats.OutTrailer: - if rs.IsClient() { + if rs.IsClient() && s.dataCaptureConfig.RpcMetadata.Request.Value { setAttributesFromMetadata("request", rs.Trailer, span) - } else { + } else if !rs.IsClient() && s.dataCaptureConfig.RpcMetadata.Response.Value { setAttributesFromMetadata("response", rs.Trailer, span) } } @@ -188,5 +196,10 @@ func WrapStatsHandler(delegate stats.Handler, spanFromContext sdk.SpanFromContex defaultAttributes["container_id"] = containerID } - return &handler{Handler: delegate, spanFromContext: spanFromContext, defaultAttributes: defaultAttributes} + return &handler{ + Handler: delegate, + spanFromContext: spanFromContext, + defaultAttributes: defaultAttributes, + dataCaptureConfig: internalconfig.GetConfig().GetDataCapture(), + } } diff --git a/sdk/google.golang.org/grpc/server_test.go b/sdk/google.golang.org/grpc/server_test.go index 5a2d470f..d6720df0 100644 --- a/sdk/google.golang.org/grpc/server_test.go +++ b/sdk/google.golang.org/grpc/server_test.go @@ -61,13 +61,13 @@ func TestServerInterceptorHelloWorldSuccess(t *testing.T) { span := spans[0] - assert.Equal(t, "grpc", span.Attributes["rpc.system"].(string)) - assert.Equal(t, "helloworld.Greeter", span.Attributes["rpc.service"].(string)) - assert.Equal(t, "SayHello", span.Attributes["rpc.method"].(string)) - assert.Equal(t, "test_value", span.Attributes["rpc.request.metadata.test_key"].(string)) + assert.Equal(t, "grpc", span.ReadAttribute("rpc.system").(string)) + assert.Equal(t, "helloworld.Greeter", span.ReadAttribute("rpc.service").(string)) + assert.Equal(t, "SayHello", span.ReadAttribute("rpc.method").(string)) + assert.Equal(t, "test_value", span.ReadAttribute("rpc.request.metadata.test_key").(string)) expectedBody := "{\"name\":\"Pupo\"}" - actualBody := span.Attributes["rpc.request.body"].(string) + actualBody := span.ReadAttribute("rpc.request.body").(string) if ok, err := jsonEqual(expectedBody, actualBody); err == nil { assert.True(t, ok, "incorrect request body:\nwant %s,\nhave %s", expectedBody, actualBody) } else { @@ -75,7 +75,7 @@ func TestServerInterceptorHelloWorldSuccess(t *testing.T) { } expectedBody = "{\"message\":\"Hello Pupo\"}" - actualBody = span.Attributes["rpc.response.body"].(string) + actualBody = span.ReadAttribute("rpc.response.body").(string) if ok, err := jsonEqual(expectedBody, actualBody); err == nil { assert.True(t, ok, "incorrect response body:\nwant %s,\nhave %s", expectedBody, actualBody) } else { @@ -101,6 +101,7 @@ func TestServerHandlerHelloWorldSuccess(t *testing.T) { "bufnet", grpc.WithContextDialer(dialer), grpc.WithInsecure(), + grpc.WithUserAgent("test_agent"), ) if err != nil { t.Fatalf("failed to dial bufnet: %v", err) @@ -110,6 +111,7 @@ func TestServerHandlerHelloWorldSuccess(t *testing.T) { client := helloworld.NewGreeterClient(conn) ctx = metadata.NewOutgoingContext(ctx, metadata.Pairs("test_key", "test_value")) + _, err = client.SayHello(ctx, &helloworld.HelloRequest{ Name: "Pupo", }) @@ -121,14 +123,17 @@ func TestServerHandlerHelloWorldSuccess(t *testing.T) { span := mockHandler.Spans[0] - assert.Equal(t, "grpc", span.Attributes["rpc.system"].(string)) - assert.Equal(t, "helloworld.Greeter", span.Attributes["rpc.service"].(string)) - assert.Equal(t, "SayHello", span.Attributes["rpc.method"].(string)) - assert.Equal(t, "test_value", span.Attributes["rpc.request.metadata.test_key"].(string)) - assert.Equal(t, "test_value", span.Attributes["rpc.request.metadata.test_key"].(string)) + assert.Equal(t, "grpc", span.ReadAttribute("rpc.system").(string)) + assert.Equal(t, "helloworld.Greeter", span.ReadAttribute("rpc.service").(string)) + assert.Equal(t, "SayHello", span.ReadAttribute("rpc.method").(string)) + assert.Equal(t, "test_value", span.ReadAttribute("rpc.request.metadata.test_key").(string)) + + assert.Equal(t, "bufnet", span.ReadAttribute("rpc.request.metadata.:authority").(string)) + assert.Equal(t, "application/grpc", span.ReadAttribute("rpc.request.metadata.content-type").(string)) + assert.Contains(t, span.ReadAttribute("rpc.request.metadata.user-agent").(string), "test_agent") expectedBody := "{\"name\":\"Pupo\"}" - actualBody := span.Attributes["rpc.request.body"].(string) + actualBody := span.ReadAttribute("rpc.request.body").(string) if ok, err := jsonEqual(expectedBody, actualBody); err == nil { assert.True(t, ok, "incorrect request body:\nwant %s,\nhave %s", expectedBody, actualBody) } else { @@ -136,10 +141,13 @@ func TestServerHandlerHelloWorldSuccess(t *testing.T) { } expectedBody = "{\"message\":\"Hello Pupo\"}" - actualBody = span.Attributes["rpc.response.body"].(string) + actualBody = span.ReadAttribute("rpc.response.body").(string) if ok, err := jsonEqual(expectedBody, actualBody); err == nil { assert.True(t, ok, "incorrect response body:\nwant %s,\nhave %s", expectedBody, actualBody) } else { t.Fatalf("unexpected error: %v", err) } + + _ = span.ReadAttribute("container_id") // needed in containarized envs + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) } diff --git a/sdk/internal/config/config.go b/sdk/internal/config/config.go index 623617c2..2920a481 100644 --- a/sdk/internal/config/config.go +++ b/sdk/internal/config/config.go @@ -13,14 +13,14 @@ var cfgMux = &sync.Mutex{} // InitConfig initializes the config with default values func InitConfig(c *config.AgentConfig) { + cfgMux.Lock() + defer cfgMux.Unlock() + if cfg != nil { log.Println("config already initialized, ignoring new config.") return } - cfgMux.Lock() - defer cfgMux.Unlock() - // The reason why we clone the message instead of reusing the one passed by the user // is because user might decide to change values in runtime and that is undesirable // without a proper API. diff --git a/sdk/internal/mock/span.go b/sdk/internal/mock/span.go index 5e749179..d942a8a3 100644 --- a/sdk/internal/mock/span.go +++ b/sdk/internal/mock/span.go @@ -29,6 +29,23 @@ func (s *Span) SetAttribute(key string, value interface{}) { s.Attributes[key] = value } +func (s *Span) ReadAttribute(key string) interface{} { + s.mux.Lock() // avoids race conditions + defer s.mux.Unlock() + + val, ok := s.Attributes[key] + if ok { + delete(s.Attributes, key) + return val + } + + return nil +} + +func (s *Span) RemainingAttributes() int { + return len(s.Attributes) +} + func (s *Span) IsNoop() bool { return s.Noop } diff --git a/sdk/net/http/attributes_test.go b/sdk/net/http/attributes_test.go index 54eb62b1..a4a4da4f 100644 --- a/sdk/net/http/attributes_test.go +++ b/sdk/net/http/attributes_test.go @@ -13,7 +13,10 @@ func TestSetScalarAttributeSuccess(t *testing.T) { h.Set("key_1", "value_1") span := mock.NewSpan() setAttributesFromHeaders("request", h, span) - assert.Equal(t, "value_1", span.Attributes["http.request.header.Key_1"].(string)) + assert.Equal(t, "value_1", span.ReadAttribute("http.request.header.Key_1").(string)) + + _ = span.ReadAttribute("container_id") // needed in containarized envs + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) } func TestSetMultivalueAttributeSuccess(t *testing.T) { @@ -24,6 +27,9 @@ func TestSetMultivalueAttributeSuccess(t *testing.T) { span := mock.NewSpan() setAttributesFromHeaders("response", h, span) - assert.Equal(t, "value_1", span.Attributes["http.response.header.Key_1[0]"].(string)) - assert.Equal(t, "value_2", span.Attributes["http.response.header.Key_1[1]"].(string)) + assert.Equal(t, "value_1", span.ReadAttribute("http.response.header.Key_1[0]").(string)) + assert.Equal(t, "value_2", span.ReadAttribute("http.response.header.Key_1[1]").(string)) + + _ = span.ReadAttribute("container_id") // needed in containarized envs + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) } diff --git a/sdk/net/http/handler.go b/sdk/net/http/handler.go index 12727130..9b639376 100644 --- a/sdk/net/http/handler.go +++ b/sdk/net/http/handler.go @@ -48,11 +48,11 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { span.SetAttribute("http.url", r.URL.String()) // Sets an attribute per each request header. - if h.dataCaptureConfig.GetHttpHeaders().GetRequest().GetValue() { + if h.dataCaptureConfig.GetHttpHeaders().Request.Value { setAttributesFromHeaders("request", r.Header, span) } - if h.dataCaptureConfig.GetHttpBody().GetRequest().GetValue() && shouldRecordBodyOfContentType(r.Header) { + if h.dataCaptureConfig.HttpBody.Request.Value && shouldRecordBodyOfContentType(r.Header) { body, err := ioutil.ReadAll(r.Body) if err != nil { return @@ -73,13 +73,13 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // tag found status code on exit defer func() { - if h.dataCaptureConfig.GetHttpBody().GetResponse().GetValue() && + if h.dataCaptureConfig.HttpBody.Response.Value && len(wi.body) > 0 && shouldRecordBodyOfContentType(wi.Header()) { span.SetAttribute("http.response.body", string(wi.body)) } - if h.dataCaptureConfig.GetHttpHeaders().GetResponse().GetValue() { + if h.dataCaptureConfig.HttpHeaders.Response.Value { // Sets an attribute per each response header. setAttributesFromHeaders("response", wi.Header(), span) } diff --git a/sdk/net/http/handler_test.go b/sdk/net/http/handler_test.go index 3fb2745d..5e5f286c 100644 --- a/sdk/net/http/handler_test.go +++ b/sdk/net/http/handler_test.go @@ -8,7 +8,6 @@ import ( "testing" "github.com/hypertrace/goagent/config" - sdkconfig "github.com/hypertrace/goagent/sdk/config" "github.com/hypertrace/goagent/sdk/internal/mock" "github.com/stretchr/testify/assert" ) @@ -27,21 +26,6 @@ func (h *mockHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { h.baseHandler.ServeHTTP(rw, r.WithContext(ctx)) } -func TestMain(m *testing.M) { - sdkconfig.InitConfig(&config.AgentConfig{ - DataCapture: &config.DataCapture{ - HttpHeaders: &config.Message{ - Request: config.Bool(true), - Response: config.Bool(true), - }, - HttpBody: &config.Message{ - Request: config.Bool(true), - Response: config.Bool(true), - }, - }, - }) -} - func TestServerRequestIsSuccessfullyTraced(t *testing.T) { h := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { rw.Header().Add("request_id", "abc123xyz") @@ -53,10 +37,15 @@ func TestServerRequestIsSuccessfullyTraced(t *testing.T) { wh, _ := WrapHandler(h, mock.SpanFromContext).(*handler) wh.dataCaptureConfig = &config.DataCapture{ HttpHeaders: &config.Message{ - Request: config.Bool(true), - Response: config.Bool(true), + Request: config.Bool(false), + Response: config.Bool(false), + }, + HttpBody: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), }, } + ih := &mockHandler{baseHandler: wh} r, _ := http.NewRequest("GET", "http://traceable.ai/foo?user_id=1", strings.NewReader("test_request_body")) @@ -66,12 +55,66 @@ func TestServerRequestIsSuccessfullyTraced(t *testing.T) { ih.ServeHTTP(w, r) assert.Equal(t, "test_response_body", w.Body.String()) - spans := ih.spans - assert.Equal(t, 1, len(spans)) + assert.Equal(t, 1, len(ih.spans)) + + span := ih.spans[0] + assert.Equal(t, "http://traceable.ai/foo?user_id=1", span.ReadAttribute("http.url").(string)) + + _ = span.ReadAttribute("container_id") // needed in containarized envs + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) +} + +func TestServerRequestHeadersAreSuccessfullyRecorded(t *testing.T) { + tCases := []struct { + captureHTTPHeadersRequestConfig bool + captureHTTPHeadersResponseConfig bool + }{ + {true, true}, + {true, false}, + {false, true}, + {false, false}, + } + for _, tCase := range tCases { + h := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + rw.Header().Add("request_id", "abc123xyz") + rw.WriteHeader(202) + }) + + wh, _ := WrapHandler(h, mock.SpanFromContext).(*handler) + ih := &mockHandler{baseHandler: wh} + wh.dataCaptureConfig = &config.DataCapture{ + HttpHeaders: &config.Message{ + Request: config.Bool(tCase.captureHTTPHeadersRequestConfig), + Response: config.Bool(tCase.captureHTTPHeadersResponseConfig), + }, + HttpBody: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), + }, + } + + r, _ := http.NewRequest("GET", "http://traceable.ai/foo?user_id=1", strings.NewReader("test_request_body")) + r.Header.Add("api_key", "xyz123abc") + w := httptest.NewRecorder() + + ih.ServeHTTP(w, r) + + spans := ih.spans + assert.Equal(t, 1, len(spans)) - assert.Equal(t, "http://traceable.ai/foo?user_id=1", spans[0].Attributes["http.url"].(string)) - assert.Equal(t, "xyz123abc", spans[0].Attributes["http.request.header.Api_key"].(string)) - assert.Equal(t, "abc123xyz", spans[0].Attributes["http.response.header.Request_id"].(string)) + span := spans[0] + if tCase.captureHTTPHeadersRequestConfig { + assert.Equal(t, "xyz123abc", span.ReadAttribute("http.request.header.Api_key").(string)) + } else { + assert.Nil(t, span.ReadAttribute("http.request.header.Api_key")) + } + + if tCase.captureHTTPHeadersResponseConfig { + assert.Equal(t, "abc123xyz", span.ReadAttribute("http.response.header.Request_id").(string)) + } else { + assert.Nil(t, span.ReadAttribute("http.response.header.Request_id")) + } + } } func TestServerRecordsRequestAndResponseBodyAccordingly(t *testing.T) { @@ -137,6 +180,10 @@ func TestServerRecordsRequestAndResponseBodyAccordingly(t *testing.T) { Request: config.Bool(tCase.captureHTTPBodyConfig), Response: config.Bool(tCase.captureHTTPBodyConfig), }, + HttpHeaders: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), + }, } ih := &mockHandler{baseHandler: wh} @@ -150,11 +197,15 @@ func TestServerRecordsRequestAndResponseBodyAccordingly(t *testing.T) { span := ih.spans[0] if tCase.shouldHaveRecordedRequestBody { - assert.Equal(t, tCase.requestBody, span.Attributes["http.request.body"].(string)) + assert.Equal(t, tCase.requestBody, span.ReadAttribute("http.request.body").(string)) + } else { + assert.Nil(t, span.ReadAttribute("http.request.body")) } if tCase.shouldHaveRecordedResponseBody { - assert.Equal(t, tCase.responseBody, span.Attributes["http.response.body"].(string)) + assert.Equal(t, tCase.responseBody, span.ReadAttribute("http.response.body").(string)) + } else { + assert.Nil(t, span.ReadAttribute("http.response.body")) } }) } diff --git a/sdk/net/http/transport.go b/sdk/net/http/transport.go index abc5028f..e92a72c3 100644 --- a/sdk/net/http/transport.go +++ b/sdk/net/http/transport.go @@ -34,7 +34,7 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { span.SetAttribute(key, value) } - if rt.dataCaptureConfig.GetHttpHeaders().GetRequest().GetValue() { + if rt.dataCaptureConfig.HttpHeaders.Request.Value { setAttributesFromHeaders("request", req.Header, span) } @@ -42,7 +42,7 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // is in the recording accept list. Notice in here we rely on the fact that // the content type is not streamable, otherwise we could end up in a very // expensive parsing of a big body in memory. - if rt.dataCaptureConfig.GetHttpBody().GetRequest().GetValue() && shouldRecordBodyOfContentType(req.Header) { + if rt.dataCaptureConfig.HttpBody.Request.Value && shouldRecordBodyOfContentType(req.Header) { body, err := ioutil.ReadAll(req.Body) if err != nil { return rt.delegate.RoundTrip(req) @@ -62,7 +62,7 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { } // Notice, parsing a streamed content in memory can be expensive. - if rt.dataCaptureConfig.GetHttpBody().GetResponse().GetValue() && shouldRecordBodyOfContentType(res.Header) { + if rt.dataCaptureConfig.HttpBody.Response.Value && shouldRecordBodyOfContentType(res.Header) { body, err := ioutil.ReadAll(res.Body) if err != nil { return res, nil @@ -76,7 +76,7 @@ func (rt *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { res.Body = ioutil.NopCloser(bytes.NewBuffer(body)) } - if rt.dataCaptureConfig.GetHttpHeaders().GetResponse().GetValue() { + if rt.dataCaptureConfig.HttpHeaders.Response.Value { // Sets an attribute per each response header. setAttributesFromHeaders("response", res.Header, span) } diff --git a/sdk/net/http/transport_test.go b/sdk/net/http/transport_test.go index 9ee3d399..0f1270fd 100644 --- a/sdk/net/http/transport_test.go +++ b/sdk/net/http/transport_test.go @@ -10,6 +10,7 @@ import ( "net/http/httptest" "testing" + "github.com/hypertrace/goagent/config" "github.com/hypertrace/goagent/sdk/internal/mock" "github.com/stretchr/testify/assert" ) @@ -28,23 +29,31 @@ func (t *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) { func TestClientRequestIsSuccessfullyTraced(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { - rw.Header().Set("content-type", "application/json") - rw.Header().Set("request_id", "xyz123abc") rw.WriteHeader(202) rw.Write([]byte(`{"id":123}`)) })) defer srv.Close() + rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext).(*roundTripper) + rt.dataCaptureConfig = &config.DataCapture{ + HttpHeaders: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), + }, + HttpBody: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), + }, + } + tr := &mockTransport{ - baseRoundTripper: WrapTransport(http.DefaultTransport, mock.SpanFromContext), + baseRoundTripper: rt, } client := &http.Client{ Transport: tr, } req, _ := http.NewRequest("POST", srv.URL, bytes.NewBufferString(`{"name":"Jacinto"}`)) - req.Header.Set("api_key", "abc123xyz") - req.Header.Set("content-type", "application/json") res, err := client.Do(req) if err != nil { t.Errorf("unexpected error: %v", err) @@ -60,10 +69,81 @@ func TestClientRequestIsSuccessfullyTraced(t *testing.T) { assert.Equal(t, 1, len(spans), "unexpected number of spans") span := spans[0] - assert.Equal(t, "abc123xyz", span.Attributes["http.request.header.Api_key"].(string)) - assert.Equal(t, `{"name":"Jacinto"}`, span.Attributes["http.request.body"].(string)) - assert.Equal(t, "xyz123abc", span.Attributes["http.response.header.Request_id"].(string)) - assert.Equal(t, `{"id":123}`, span.Attributes["http.response.body"].(string)) + + _ = span.ReadAttribute("container_id") // needed in containarized envs + // We make sure we read all attributes and covered them with tests + assert.Zero(t, span.RemainingAttributes(), "unexpected remaining attribute: %v", span.Attributes) +} + +func TestClientRequestHeadersAreCapturedAccordingly(t *testing.T) { + tCases := []struct { + captureHTTPHeadersRequestConfig bool + captureHTTPHeadersResponseConfig bool + }{ + {true, true}, + {true, false}, + {false, true}, + {false, false}, + } + for _, tCase := range tCases { + + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Header().Set("content-type", "application/json") + rw.Header().Set("request_id", "xyz123abc") + rw.WriteHeader(202) + rw.Write([]byte(`{"id":123}`)) + })) + defer srv.Close() + + rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext).(*roundTripper) + rt.dataCaptureConfig = &config.DataCapture{ + HttpHeaders: &config.Message{ + Request: config.Bool(tCase.captureHTTPHeadersRequestConfig), + Response: config.Bool(tCase.captureHTTPHeadersResponseConfig), + }, + HttpBody: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), + }, + } + + tr := &mockTransport{ + baseRoundTripper: rt, + } + client := &http.Client{ + Transport: tr, + } + + req, _ := http.NewRequest("POST", srv.URL, bytes.NewBufferString(`{"name":"Jacinto"}`)) + req.Header.Set("api_key", "abc123xyz") + req.Header.Set("content-type", "application/json") + res, err := client.Do(req) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + assert.Equal(t, 202, res.StatusCode) + + resBody, err := ioutil.ReadAll(res.Body) + assert.Nil(t, err) + assert.Equal(t, `{"id":123}`, string(resBody)) + + spans := tr.spans + assert.Equal(t, 1, len(spans), "unexpected number of spans") + + span := spans[0] + if tCase.captureHTTPHeadersRequestConfig { + assert.Equal(t, "abc123xyz", span.ReadAttribute("http.request.header.Api_key").(string)) + } else { + assert.Nil(t, span.ReadAttribute("http.request.header.Api_key")) + } + + if tCase.captureHTTPHeadersResponseConfig { + assert.Equal(t, "xyz123abc", span.ReadAttribute("http.response.header.Request_id").(string)) + } else { + assert.Nil(t, span.ReadAttribute("http.response.header.Request_id")) + } + } } type failingTransport struct { @@ -94,6 +174,7 @@ func TestClientFailureRequestIsSuccessfullyTraced(t *testing.T) { func TestClientRecordsRequestAndResponseBodyAccordingly(t *testing.T) { tCases := map[string]struct { + captureHTTPBodyConfig bool requestBody string requestContentType string shouldHaveRecordedRequestBody bool @@ -102,22 +183,29 @@ func TestClientRecordsRequestAndResponseBodyAccordingly(t *testing.T) { shouldHaveRecordedResponseBody bool }{ "no content type headers and empty body": { + captureHTTPBodyConfig: true, + shouldHaveRecordedRequestBody: false, shouldHaveRecordedResponseBody: false, }, "no content type headers and non empty body": { + captureHTTPBodyConfig: true, + requestBody: "{}", responseBody: "{}", shouldHaveRecordedRequestBody: false, shouldHaveRecordedResponseBody: false, }, "content type headers but empty body": { + captureHTTPBodyConfig: true, + requestContentType: "application/json", responseContentType: "application/x-www-form-urlencoded", shouldHaveRecordedRequestBody: false, shouldHaveRecordedResponseBody: false, }, - "content type and body": { + "content type and body with config enabled": { + captureHTTPBodyConfig: true, requestBody: "test_request_body", responseBody: "test_response_body", requestContentType: "application/x-www-form-urlencoded", @@ -125,6 +213,15 @@ func TestClientRecordsRequestAndResponseBodyAccordingly(t *testing.T) { shouldHaveRecordedRequestBody: true, shouldHaveRecordedResponseBody: true, }, + "content type and body but config disabled": { + captureHTTPBodyConfig: false, + requestBody: "test_request_body", + responseBody: "test_response_body", + requestContentType: "application/x-www-form-urlencoded", + responseContentType: "Application/JSON", + shouldHaveRecordedRequestBody: false, + shouldHaveRecordedResponseBody: false, + }, } for name, tCase := range tCases { @@ -137,9 +234,22 @@ func TestClientRecordsRequestAndResponseBodyAccordingly(t *testing.T) { })) defer srv.Close() + rt := WrapTransport(http.DefaultTransport, mock.SpanFromContext).(*roundTripper) + rt.dataCaptureConfig = &config.DataCapture{ + HttpBody: &config.Message{ + Request: config.Bool(tCase.captureHTTPBodyConfig), + Response: config.Bool(tCase.captureHTTPBodyConfig), + }, + HttpHeaders: &config.Message{ + Request: config.Bool(false), + Response: config.Bool(false), + }, + } + tr := &mockTransport{ - baseRoundTripper: WrapTransport(http.DefaultTransport, mock.SpanFromContext), + baseRoundTripper: rt, } + client := &http.Client{ Transport: tr, } @@ -158,19 +268,16 @@ func TestClientRecordsRequestAndResponseBodyAccordingly(t *testing.T) { assert.Nil(t, err) span := tr.spans[0] - if tCase.shouldHaveRecordedRequestBody { - assert.Equal(t, tCase.requestBody, span.Attributes["http.request.body"].(string)) + assert.Equal(t, tCase.requestBody, span.ReadAttribute("http.request.body").(string)) } else { - _, ok := span.Attributes["http.request.body"] - assert.False(t, ok) + assert.Nil(t, span.ReadAttribute("http.request.body")) } if tCase.shouldHaveRecordedResponseBody { - assert.Equal(t, tCase.responseBody, span.Attributes["http.response.body"].(string)) + assert.Equal(t, tCase.responseBody, span.ReadAttribute("http.response.body").(string)) } else { - _, ok := span.Attributes["http.response.body"] - assert.False(t, ok) + assert.Nil(t, span.ReadAttribute("http.response.body")) } }) }