diff --git a/api/include/opentelemetry/context/context.h b/api/include/opentelemetry/context/context.h index dfc5c87bcc..a1379cdc64 100644 --- a/api/include/opentelemetry/context/context.h +++ b/api/include/opentelemetry/context/context.h @@ -16,6 +16,7 @@ class Context { public: + Context() = default; // Creates a context object from a map of keys and identifiers, this will // hold a shared_ptr to the head of the DataList linked list template @@ -89,9 +90,9 @@ class Context return false; } -private: - Context() = default; + bool operator==(const Context &other) { return (head_ == other.head_); } +private: // A linked list to contain the keys and values of this context node class DataList { diff --git a/api/include/opentelemetry/context/runtime_context.h b/api/include/opentelemetry/context/runtime_context.h new file mode 100644 index 0000000000..7ad3b73bdc --- /dev/null +++ b/api/include/opentelemetry/context/runtime_context.h @@ -0,0 +1,61 @@ +#pragma once + +#include "opentelemetry/context/context.h" + +OPENTELEMETRY_BEGIN_NAMESPACE +namespace context +{ +// Provides a wrapper for propagating the context object globally. In order +// to use either the threadlocal_context.h file must be included or another +// implementation which must be derived from the RuntimeContext can be +// provided. +class RuntimeContext +{ +public: + class Token + { + public: + bool operator==(const Context &other) noexcept { return context_ == other; } + + ~Token() noexcept { Detach(*this); } + + private: + friend class RuntimeContext; + + // A constructor that sets the token's Context object to the + // one that was passed in. + Token(Context context) noexcept : context_(context){}; + + Token() noexcept = default; + + Context context_; + }; + + // Return the current context. + static Context GetCurrent() noexcept { return context_handler_->InternalGetCurrent(); } + + // Sets the current 'Context' object. Returns a token + // that can be used to reset to the previous Context. + static Token Attach(Context context) noexcept + { + return context_handler_->InternalAttach(context); + } + + // Resets the context to a previous value stored in the + // passed in token. Returns true if successful, false otherwise + static bool Detach(Token &token) noexcept { return context_handler_->InternalDetach(token); } + + static RuntimeContext *context_handler_; + +protected: + // Provides a token with the passed in context + Token CreateToken(Context context) noexcept { return Token(context); } + + virtual Context InternalGetCurrent() noexcept = 0; + + virtual Token InternalAttach(Context context) noexcept = 0; + + virtual bool InternalDetach(Token &token) noexcept = 0; +}; +} // namespace context +OPENTELEMETRY_END_NAMESPACE diff --git a/api/include/opentelemetry/context/threadlocal_context.h b/api/include/opentelemetry/context/threadlocal_context.h new file mode 100644 index 0000000000..febd89e63a --- /dev/null +++ b/api/include/opentelemetry/context/threadlocal_context.h @@ -0,0 +1,116 @@ +#pragma once + +#include "opentelemetry/context/context.h" +#include "opentelemetry/context/runtime_context.h" + +OPENTELEMETRY_BEGIN_NAMESPACE +namespace context +{ + +// The ThreadLocalContext class is a derived class from RuntimeContext and +// provides a wrapper for propogating context through cpp thread locally. +// This file must be included to use the RuntimeContext class if another +// implementation has not been registered. +class ThreadLocalContext : public RuntimeContext +{ +public: + ThreadLocalContext() noexcept = default; + + // Return the current context. + Context InternalGetCurrent() noexcept override { return stack_.Top(); } + + // Resets the context to a previous value stored in the + // passed in token. Returns true if successful, false otherwise + bool InternalDetach(Token &token) noexcept override + { + if (!(token == stack_.Top())) + { + return false; + } + stack_.Pop(); + return true; + } + + // Sets the current 'Context' object. Returns a token + // that can be used to reset to the previous Context. + Token InternalAttach(Context context) noexcept override + { + stack_.Push(context); + Token old_context = CreateToken(context); + return old_context; + } + +private: + // A nested class to store the attached contexts in a stack. + class Stack + { + friend class ThreadLocalContext; + + Stack() noexcept : size_(0), capacity_(0), base_(nullptr){}; + + // Pops the top Context off the stack and returns it. + Context Pop() noexcept + { + if (size_ <= 0) + { + return Context(); + } + int index = size_ - 1; + size_--; + return base_[index]; + } + + // Returns the Context at the top of the stack. + Context Top() const noexcept + { + if (size_ <= 0) + { + return Context(); + } + return base_[size_ - 1]; + } + + // Pushes the passed in context pointer to the top of the stack + // and resizes if necessary. + void Push(Context context) noexcept + { + size_++; + if (size_ > capacity_) + { + Resize(size_ * 2); + } + base_[size_ - 1] = context; + } + + // Reallocates the storage array to the pass in new capacity size. + void Resize(int new_capacity) noexcept + { + int old_size = size_ - 1; + if (new_capacity == 0) + { + new_capacity = 2; + } + Context *temp = new Context[new_capacity]; + if (base_ != nullptr) + { + std::copy(base_, base_ + old_size, temp); + delete[] base_; + } + base_ = temp; + } + + ~Stack() noexcept { delete[] base_; } + + size_t size_; + size_t capacity_; + Context *base_; + }; + + static thread_local Stack stack_; +}; +thread_local ThreadLocalContext::Stack ThreadLocalContext::stack_ = ThreadLocalContext::Stack(); + +// Registers the ThreadLocalContext as the context handler for the RuntimeContext +RuntimeContext *RuntimeContext::context_handler_ = new ThreadLocalContext(); +} // namespace context +OPENTELEMETRY_END_NAMESPACE diff --git a/api/test/context/BUILD b/api/test/context/BUILD index b213710944..87452f4588 100644 --- a/api/test/context/BUILD +++ b/api/test/context/BUILD @@ -10,3 +10,14 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "runtime_context_test", + srcs = [ + "runtime_context_test.cc", + ], + deps = [ + "//api", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/api/test/context/context_test.cc b/api/test/context/context_test.cc index 26fc5110bf..f9a6dfbfd4 100644 --- a/api/test/context/context_test.cc +++ b/api/test/context/context_test.cc @@ -118,3 +118,22 @@ TEST(ContextTest, ContextHasKey) EXPECT_TRUE(context_test.HasKey("test_key")); EXPECT_FALSE(context_test.HasKey("foo_key")); } + +// Tests that a copied context returns true when compared +TEST(ContextTest, ContextCopyCompare) +{ + std::map map_test = {{"test_key", (int64_t)123}}; + context::Context context_test = context::Context(map_test); + context::Context copied_test = context_test; + EXPECT_TRUE(context_test == copied_test); +} + +// Tests that two differently constructed contexts return false when compared +TEST(ContextTest, ContextDiffCompare) +{ + std::map map_test = {{"test_key", (int64_t)123}}; + std::map map_foo = {{"foo_key", (int64_t)123}}; + context::Context context_test = context::Context(map_test); + context::Context foo_test = context::Context(map_foo); + EXPECT_FALSE(context_test == foo_test); +} diff --git a/api/test/context/runtime_context_test.cc b/api/test/context/runtime_context_test.cc new file mode 100644 index 0000000000..b70d69dce8 --- /dev/null +++ b/api/test/context/runtime_context_test.cc @@ -0,0 +1,61 @@ +#include "opentelemetry/context/context.h" +#include "opentelemetry/context/threadlocal_context.h" + +#include + +using namespace opentelemetry; + +// Tests that GetCurrent returns the current context +TEST(RuntimeContextTest, GetCurrent) +{ + std::map map_test = {{"test_key", (int64_t)123}}; + context::Context test_context = context::Context(map_test); + context::RuntimeContext::Token old_context = context::RuntimeContext::Attach(test_context); + EXPECT_TRUE(context::RuntimeContext::GetCurrent() == test_context); + context::RuntimeContext::Detach(old_context); +} + +// Tests that detach resets the context to the previous context +TEST(RuntimeContextTest, Detach) +{ + std::map map_test = {{"test_key", (int64_t)123}}; + context::Context test_context = context::Context(map_test); + context::Context foo_context = context::Context(map_test); + + context::RuntimeContext::Token test_context_token = context::RuntimeContext::Attach(test_context); + context::RuntimeContext::Token foo_context_token = context::RuntimeContext::Attach(foo_context); + + context::RuntimeContext::Detach(foo_context_token); + EXPECT_TRUE(context::RuntimeContext::GetCurrent() == test_context); + context::RuntimeContext::Detach(test_context_token); +} + +// Tests that detach returns false when the wrong context is provided +TEST(RuntimeContextTest, DetachWrongContext) +{ + std::map map_test = {{"test_key", (int64_t)123}}; + context::Context test_context = context::Context(map_test); + context::Context foo_context = context::Context(map_test); + context::RuntimeContext::Token test_context_token = context::RuntimeContext::Attach(test_context); + context::RuntimeContext::Token foo_context_token = context::RuntimeContext::Attach(foo_context); + EXPECT_FALSE(context::RuntimeContext::Detach(test_context_token)); + context::RuntimeContext::Detach(foo_context_token); + context::RuntimeContext::Detach(test_context_token); +} + +// Tests that the ThreadLocalContext can handle three attached contexts +TEST(RuntimeContextTest, ThreeAttachDetach) +{ + std::map map_test = {{"test_key", (int64_t)123}}; + context::Context test_context = context::Context(map_test); + context::Context foo_context = context::Context(map_test); + context::Context other_context = context::Context(map_test); + context::RuntimeContext::Token test_context_token = context::RuntimeContext::Attach(test_context); + context::RuntimeContext::Token foo_context_token = context::RuntimeContext::Attach(foo_context); + context::RuntimeContext::Token other_context_token = + context::RuntimeContext::Attach(other_context); + + EXPECT_TRUE(context::RuntimeContext::Detach(other_context_token)); + EXPECT_TRUE(context::RuntimeContext::Detach(foo_context_token)); + EXPECT_TRUE(context::RuntimeContext::Detach(test_context_token)); +}