diff --git a/tsl/profiler/lib/BUILD b/tsl/profiler/lib/BUILD index 214b3bbf4..56e245ea2 100644 --- a/tsl/profiler/lib/BUILD +++ b/tsl/profiler/lib/BUILD @@ -358,6 +358,7 @@ cc_library( ":traceme_encode", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", + "@xla//xla/tsl/platform:env", "@xla//xla/tsl/platform:types", ], ) diff --git a/tsl/profiler/lib/connected_traceme.h b/tsl/profiler/lib/connected_traceme.h index c9e4e520c..34d5d8957 100644 --- a/tsl/profiler/lib/connected_traceme.h +++ b/tsl/profiler/lib/connected_traceme.h @@ -15,12 +15,14 @@ limitations under the License. #ifndef TENSORFLOW_TSL_PROFILER_LIB_CONNECTED_TRACEME_H_ #define TENSORFLOW_TSL_PROFILER_LIB_CONNECTED_TRACEME_H_ +#include #include #include #include #include "absl/strings/string_view.h" #include "absl/types/optional.h" +#include "xla/tsl/platform/env.h" #include "xla/tsl/platform/types.h" #include "tsl/profiler/lib/context_types.h" #include "tsl/profiler/lib/traceme.h" @@ -82,12 +84,16 @@ class TraceMeProducer : public TraceMe { explicit TraceMeProducer(NameT&& name, ContextType context_type = ContextType::kGeneric, std::optional context_id = std::nullopt, - int level = tsl::profiler::TraceMeLevel::kCritical) + int level = tsl::profiler::TraceMeLevel::kCritical, + std::optional pid = std::nullopt) : TraceMe(std::forward(name), level), context_id_(context_id.has_value() ? context_id.value() : TraceMe::NewActivityId()) { AppendMetadata([&] { - return TraceMeEncode({{"_pt", context_type}, {"_p", context_id_}}); + return TraceMeEncode( + {{"_pt", context_type}, + {"_p", context_id_}, + {"_ppid", pid.value_or(tsl::Env::Default()->GetProcessId())}}); }); } @@ -101,10 +107,14 @@ class TraceMeConsumer : public TraceMe { public: template TraceMeConsumer(NameT&& name, ContextType context_type, uint64 context_id, - int level = tsl::profiler::TraceMeLevel::kCritical) + int level = tsl::profiler::TraceMeLevel::kCritical, + std::optional pid = std::nullopt) : TraceMe(std::forward(name), level) { AppendMetadata([&] { - return TraceMeEncode({{"_ct", context_type}, {"_c", context_id}}); + return TraceMeEncode( + {{"_ct", context_type}, + {"_c", context_id}, + {"_cpid", pid.value_or(tsl::Env::Default()->GetProcessId())}}); }); }