diff --git a/tsl/profiler/lib/profiler_factory.cc b/tsl/profiler/lib/profiler_factory.cc index a11a11b50..acb625b59 100644 --- a/tsl/profiler/lib/profiler_factory.cc +++ b/tsl/profiler/lib/profiler_factory.cc @@ -44,10 +44,27 @@ void RegisterProfilerFactory(ProfilerFactory factory) { std::vector> CreateProfilers( const tensorflow::ProfileOptions& options) { + // Create a copy of options to modify if circular buffer is enabled. + tensorflow::ProfileOptions options_to_use = options; + bool circular_buffer_enabled = false; + + // Check if tpu_circular_buffer_tracing is enabled in advanced configuration. + auto it = + options.advanced_configuration().find("tpu_circular_buffer_tracing"); + if (it != options.advanced_configuration().end()) { + circular_buffer_enabled = it->second.bool_value(); + } + + if (circular_buffer_enabled) { + // Disable other tracers by zeroing their levels in the local options copy. + options_to_use.set_host_tracer_level(0); + options_to_use.set_python_tracer_level(0); + } + std::vector> result; absl::MutexLock lock(mu); for (const auto& factory : *GetFactories()) { - auto profiler = factory(options); + auto profiler = factory(options_to_use); // A factory might return nullptr based on options. if (profiler == nullptr) continue; result.emplace_back( diff --git a/tsl/profiler/protobuf/profiler_service.proto b/tsl/profiler/protobuf/profiler_service.proto index 2419d8290..0382e763b 100644 --- a/tsl/profiler/protobuf/profiler_service.proto +++ b/tsl/profiler/protobuf/profiler_service.proto @@ -17,6 +17,11 @@ service ProfilerService { rpc Terminate(TerminateRequest) returns (TerminateResponse) {} // Collects profiling data and returns user-friendly metrics. rpc Monitor(MonitorRequest) returns (MonitorResponse) {} + // Starts a continuous profiling session. + rpc StartContinuousProfiling(ProfileRequest) + returns (ContinuousProfilingResponse) {} + // Gets a snapshot of an ongoing profiling session. + rpc GetSnapshot(GetSnapshotRequest) returns (ProfileResponse) {} } message ToolRequestOptions { @@ -99,6 +104,10 @@ message TerminateRequest { message TerminateResponse {} +message GetSnapshotRequest {} + +message ContinuousProfilingResponse {} + // Next-ID: 4 message MonitorRequest { // Duration for which to profile between each update.