diff --git a/pyproject.toml b/pyproject.toml index a3efbc55..706b22fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ test = [ "faker", "pytest-asyncio", "pytest-cov", + "pytest-env", "pytest-timeout", "pytest-watch", "pytest", diff --git a/pytest.ini b/pytest.ini index 68e2f208..1fecc333 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,3 +2,6 @@ addopts = --durations=10 --cov-config=.coveragerc --timeout=120 --timeout_method=thread --cov=runpod --cov-report=xml --cov-report=term-missing --cov-fail-under=90 -W error -p no:cacheprovider -p no:unraisableexception python_files = tests.py test_*.py *_test.py norecursedirs = venv *.egg-info .git build +env = + D:ENV=test + D:RUNPOD_LOG_LEVEL=ERROR \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b770f89a..53a0055f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,9 @@ tomlkit >= 0.12.2 tqdm-loggable >= 0.1.4 urllib3 >= 1.26.6 watchdog >= 3.0.0 + +setuptools==65.6.3 +opentelemetry-sdk +opentelemetry-exporter-otlp +opentelemetry-instrumentation-aiohttp-client +opentelemetry-instrumentation-requests diff --git a/runpod/http_client.py b/runpod/http_client.py index 0621fccd..268d548b 100644 --- a/runpod/http_client.py +++ b/runpod/http_client.py @@ -1,15 +1,19 @@ """ -HTTP Client abstractions +HTTP Client abstractions with OpenTelemetry tracing support. 
""" import os - import requests from aiohttp import ClientSession, ClientTimeout, TCPConnector, ClientResponseError +from opentelemetry import trace +from opentelemetry.instrumentation.aiohttp_client import create_trace_config +from opentelemetry.instrumentation.requests import RequestsInstrumentor from .cli.groups.config.functions import get_credentials from .user_agent import USER_AGENT +tracer = trace.get_tracer(__name__) + class TooManyRequests(ClientResponseError): pass @@ -32,22 +36,23 @@ def get_auth_header(): } -def AsyncClientSession(*args, **kwargs): # pylint: disable=invalid-name +def AsyncClientSession(*args, **kwargs): """ - Deprecation from aiohttp.ClientSession forbids inheritance. - This is now a factory method + Factory method for an async client session with OpenTelemetry tracing. """ return ClientSession( connector=TCPConnector(limit=0), headers=get_auth_header(), timeout=ClientTimeout(600, ceil_threshold=400), + trace_configs=[create_trace_config()], *args, **kwargs, ) class SyncClientSession(requests.Session): - """ - Inherits requests.Session to override `request()` method for tracing - """ - pass \ No newline at end of file + def __init__(self): + super().__init__() + self.headers.update(get_auth_header()) + +RequestsInstrumentor().instrument() diff --git a/runpod/serverless/modules/rp_http.py b/runpod/serverless/modules/rp_http.py index 3d82d35b..3d050f66 100644 --- a/runpod/serverless/modules/rp_http.py +++ b/runpod/serverless/modules/rp_http.py @@ -7,6 +7,7 @@ from aiohttp import ClientError from aiohttp_retry import FibonacciRetry, RetryClient +from opentelemetry import trace from runpod.http_client import ClientSession from runpod.serverless.modules.rp_logger import RunPodLogger @@ -24,12 +25,17 @@ JOB_STREAM_URL = JOB_STREAM_URL_TEMPLATE.replace("$RUNPOD_POD_ID", WORKER_ID) log = RunPodLogger() +tracer = trace.get_tracer(__name__) +@tracer.start_as_current_span("transmit", kind=trace.SpanKind.CLIENT) async def _transmit(client_session: 
ClientSession, url, job_data): """ Wrapper for transmitting results via POST. """ + span = trace.get_current_span() + span.set_attribute("job_data", job_data) + retry_options = FibonacciRetry(attempts=3) retry_client = RetryClient( client_session=client_session, retry_options=retry_options @@ -48,15 +54,18 @@ async def _transmit(client_session: ClientSession, url, job_data): await client_response.text() +@tracer.start_as_current_span("handle_result", kind=trace.SpanKind.CLIENT) async def _handle_result( session: ClientSession, job_data, job, url_template, log_message, is_stream=False ): """ A helper function to handle the result, either for sending or streaming. """ - try: - session.headers["X-Request-ID"] = job["id"] + span = trace.get_current_span() + span.set_attribute("request_id", job.get("id")) + span.set_attribute("is_stream", is_stream) + try: serialized_job_data = json.dumps(job_data, ensure_ascii=False) is_stream = "true" if is_stream else "false" @@ -66,9 +75,11 @@ async def _handle_result( log.debug(f"{log_message}", job["id"]) except ClientError as err: + span.record_exception(err) log.error(f"Failed to return job results. | {err}", job["id"]) except (TypeError, RuntimeError) as err: + span.record_exception(err) log.error(f"Error while returning job result. | {err}", job["id"]) finally: @@ -80,6 +91,7 @@ async def _handle_result( log.info("Finished.", job["id"]) +@tracer.start_as_current_span("send_result") async def send_result(session, job_data, job, is_stream=False): """ Return the job results. @@ -89,6 +101,7 @@ async def send_result(session, job_data, job, is_stream=False): ) +@tracer.start_as_current_span("stream_result") async def stream_result(session, job_data, job): """ Return the stream job results. 
diff --git a/runpod/serverless/modules/rp_job.py b/runpod/serverless/modules/rp_job.py index f7f8feba..a4349c50 100644 --- a/runpod/serverless/modules/rp_job.py +++ b/runpod/serverless/modules/rp_job.py @@ -6,6 +6,7 @@ import json import os import traceback +from opentelemetry import trace from typing import Any, AsyncGenerator, Callable, Dict, Optional, Union, List import aiohttp @@ -24,6 +25,7 @@ log = RunPodLogger() job_progress = JobsProgress() +tracer = trace.get_tracer(__name__) def _job_get_url(batch_size: int = 1): @@ -117,14 +119,26 @@ async def get_job( return jobs -async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> dict: +@tracer.start_as_current_span("handle_error") +def _handle_error(err_output: Any, job: dict) -> None: + span = trace.get_current_span() + + span.set_status(trace.Status(trace.StatusCode.ERROR, str(err_output))) + log.debug(f"Handled error: {err_output}", job["id"]) + + +@tracer.start_as_current_span("handle_job") +async def handle_job(session: ClientSession, config: Dict[str, Any], job: dict) -> dict: + span = trace.get_current_span() + span.set_attribute("request_id", job.get("id")) + if is_generator(config["handler"]): is_stream = True generator_output = run_job_generator(config["handler"], job) - log.debug("Handler is a generator, streaming results.", job["id"]) job_result = {"output": []} async for stream_output in generator_output: + # temp log.debug(f"Stream output: {stream_output}", job["id"]) if type(stream_output.get("output")) == dict: @@ -164,6 +178,7 @@ async def handle_job(session: ClientSession, config: Dict[str, Any], job) -> dic await send_result(session, job_result, job, is_stream=is_stream) +@tracer.start_as_current_span("run_job") async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]: """ Run the job using the handler. @@ -175,6 +190,9 @@ async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]: Returns: Dict[str, Any]: The result of running the job. 
""" + span = trace.get_current_span() + span.set_attribute("request_id", job.get("id")) + log.info("Started.", job["id"]) run_result = {} @@ -210,6 +228,7 @@ async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]: check_return_size(run_result) # Checks the size of the return body. except Exception as err: + span.record_exception(err) error_info = { "error_type": str(type(err)), "error_message": str(err), @@ -229,6 +248,7 @@ async def run_job(handler: Callable, job: Dict[str, Any]) -> Dict[str, Any]: return run_result +@tracer.start_as_current_span("run_job_generator") async def run_job_generator( handler: Callable, job: Dict[str, Any] ) -> AsyncGenerator[Dict[str, Union[str, Any]], None]: @@ -236,6 +256,9 @@ async def run_job_generator( Run generator job used to stream output. Yields output partials from the generator. """ + span = trace.get_current_span() + span.set_attribute("request_id", job.get("id")) + is_async_gen = inspect.isasyncgenfunction(handler) log.debug( "Using Async Generator" if is_async_gen else "Using Standard Generator", @@ -255,6 +278,7 @@ async def run_job_generator( yield {"output": output_partial} except Exception as err: + span.record_exception(err) log.error(err, job["id"]) yield {"error": f"handler: {str(err)} \ntraceback: {traceback.format_exc()}"} finally: diff --git a/runpod/serverless/modules/rp_ping.py b/runpod/serverless/modules/rp_ping.py index 88fa1049..3c268468 100644 --- a/runpod/serverless/modules/rp_ping.py +++ b/runpod/serverless/modules/rp_ping.py @@ -8,6 +8,7 @@ import time import requests +from opentelemetry import trace from urllib3.util.retry import Retry from runpod.http_client import SyncClientSession @@ -16,7 +17,8 @@ from runpod.version import __version__ as runpod_version log = RunPodLogger() -jobs = JobsProgress() # Contains the list of jobs that are currently running. +job_progress = JobsProgress() # Contains the list of jobs that are currently running. 
+tracer = trace.get_tracer(__name__) class Heartbeat: @@ -83,12 +85,18 @@ def ping_loop(self, test=False): if test: return + @tracer.start_as_current_span("send_ping", kind=trace.SpanKind.CLIENT) def _send_ping(self): """ Sends a heartbeat to the Runpod server. """ - job_ids = jobs.get_job_list() - ping_params = {"job_id": job_ids, "runpod_version": runpod_version} + span = trace.get_current_span() + job_ids = [] + for job in job_progress: + span.add_event("ping", {"request_id": job.id}) + job_ids.append(job.id) + + ping_params = {"job_id": ",".join(job_ids), "runpod_version": runpod_version} try: result = self._session.get( @@ -100,4 +108,5 @@ def _send_ping(self): ) except requests.RequestException as err: + span.record_exception(err) log.error(f"Ping Request Error: {err}, attempting to restart ping.") diff --git a/runpod/serverless/modules/rp_scale.py b/runpod/serverless/modules/rp_scale.py index 7c05ef9c..0b647828 100644 --- a/runpod/serverless/modules/rp_scale.py +++ b/runpod/serverless/modules/rp_scale.py @@ -8,6 +8,13 @@ import sys import traceback from typing import Any, Dict +from uuid import uuid1 # traceable to machine's MAC address + timestamp +from opentelemetry.trace import ( + get_tracer, + set_span_in_context, + SpanKind, + NonRecordingSpan, +) from ...http_client import AsyncClientSession, ClientSession, TooManyRequests from .rp_job import get_job, handle_job @@ -16,6 +23,7 @@ log = RunPodLogger() job_progress = JobsProgress() +tracer = get_tracer(__name__) def _handle_uncaught_exception(exc_type, exc_value, exc_traceback): @@ -169,49 +177,65 @@ async def get_jobs(self, session: ClientSession): jobs_needed = self.current_concurrency - self.current_occupancy() if jobs_needed <= 0: - log.debug("JobScaler.get_jobs | Queue is full. Retrying soon.") + log.debug("Queue is full. 
Retrying soon.") await asyncio.sleep(1) # don't go rapidly continue - try: - log.debug("JobScaler.get_jobs | Starting job acquisition.") - - # Keep the connection to the blocking call with timeout - acquired_jobs = await asyncio.wait_for( - self.jobs_fetcher(session, jobs_needed), - timeout=self.jobs_fetcher_timeout, - ) - - if not acquired_jobs: - log.debug("JobScaler.get_jobs | No jobs acquired.") - continue - - for job in acquired_jobs: - await self.jobs_queue.put(job) - job_progress.add(job) - log.debug("Job Queued", job["id"]) - - log.info(f"Jobs in queue: {self.jobs_queue.qsize()}") - - except TooManyRequests: - log.debug( - f"JobScaler.get_jobs | Too many requests. Debounce for 5 seconds." - ) - await asyncio.sleep(5) # debounce for 5 seconds - except asyncio.CancelledError: - log.debug("JobScaler.get_jobs | Request was cancelled.") - raise # CancelledError is a BaseException - except asyncio.TimeoutError: - log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.") - except TypeError as error: - log.debug(f"JobScaler.get_jobs | Unexpected error: {error}.") - except Exception as error: - log.error( - f"Failed to get job. 
| Error Type: {type(error).__name__} | Error Message: {str(error)}" ) finally: # Yield control back to the event loop await asyncio.sleep(0) + with tracer.start_as_current_span( + "get_jobs", kind=SpanKind.CLIENT + ) as span: + span.set_attribute("batch_id", uuid1().hex) + + try: + log.debug("JobScaler.get_jobs | Starting job acquisition.") + + # Keep the connection to the blocking call with timeout + acquired_jobs = await asyncio.wait_for( + self.jobs_fetcher(session, jobs_needed), + timeout=self.jobs_fetcher_timeout, + ) + + if not acquired_jobs: + span.add_event("No jobs acquired") + log.debug("No jobs acquired") + continue + + span.set_attribute("jobs_acquired_count", len(acquired_jobs)) + + for job in acquired_jobs: + with tracer.start_as_current_span("queue_job", kind=SpanKind.PRODUCER) as job_span: + job_span.set_attribute("request_id", job.get("id")) + job["context"] = job_span.get_span_context() + + await self.jobs_queue.put(job) + job_progress.add(job) + log.debug("Job Queued", job["id"]) + + log.info(f"Jobs in queue: {self.jobs_queue.qsize()}") + + except TooManyRequests: + span.add_event("Too many requests. Debounce for 5 seconds.") + log.debug("JobScaler.get_jobs | Too many requests. Debounce for 5 seconds.") + await asyncio.sleep(5) # debounce for 5 seconds + except asyncio.CancelledError: + span.add_event("Request was cancelled") + log.debug("JobScaler.get_jobs | Request was cancelled.") + raise # CancelledError is a BaseException + except asyncio.TimeoutError: + span.add_event("Job acquisition timed out") + log.debug("JobScaler.get_jobs | Job acquisition timed out. Retrying.") + except TypeError as error: + # worker waking up produces a JSON error here + span.record_exception(error) + log.debug(f"JobScaler.get_jobs | Unexpected error: {error}.") + except Exception as error: + span.record_exception(error) + log.error( + f"Failed to get job. 
| Error Type: {type(error).__name__} | Error Message: {str(error)}" + ) + finally: + # Yield control back to the event loop + await asyncio.sleep(0) async def run_jobs(self, session: ClientSession): """ @@ -227,11 +251,12 @@ async def run_jobs(self, session: ClientSession): job = await self.jobs_queue.get() # Create a new task for each job and add it to the task list - task = asyncio.create_task(self.handle_job(session, job)) + task = asyncio.create_task(self.perform_job(session, job)) tasks.append(task) # Wait for any job to finish if tasks: + # TODO: metrics {"jobs.in_progress", len(tasks)} log.info(f"Jobs in progress: {len(tasks)}") done, pending = await asyncio.wait( @@ -247,27 +272,36 @@ async def run_jobs(self, session: ClientSession): # Ensure all remaining tasks finish before stopping await asyncio.gather(*tasks) - async def handle_job(self, session: ClientSession, job: dict): + async def perform_job(self, session: ClientSession, job: dict): """ Process an individual job. This function is run concurrently for multiple jobs. 
""" - try: - log.debug("Handling Job", job["id"]) + context = set_span_in_context(NonRecordingSpan(job["context"])) - await self.jobs_handler(session, self.config, job) + with tracer.start_as_current_span( + "perform_job", context=context, kind=SpanKind.CONSUMER + ) as span: - if self.config.get("refresh_worker", False): - self.kill_worker() + try: + span.set_attribute("request_id", job.get("id")) + log.debug("Handling Job", job["id"]) - except Exception as err: - log.error(f"Error handling job: {err}", job["id"]) - raise err + await self.jobs_handler(session, self.config, job) - finally: - # Inform Queue of a task completion - self.jobs_queue.task_done() + if self.config.get("refresh_worker", False): + span.add_event("refresh_worker") + self.kill_worker() + + except Exception as err: + span.record_exception(err) + log.error(f"Error handling job: {err}", job["id"]) + raise err + + finally: + # Inform Queue of a task completion + self.jobs_queue.task_done() - # Job is no longer in progress - job_progress.remove(job) + # Job is no longer in progress + job_progress.remove(job) - log.debug("Finished Job", job["id"]) + log.debug("Finished Job", job["id"]) diff --git a/runpod/serverless/modules/rp_tracer.py b/runpod/serverless/modules/rp_tracer.py new file mode 100644 index 00000000..da8a0925 --- /dev/null +++ b/runpod/serverless/modules/rp_tracer.py @@ -0,0 +1,117 @@ +import os + +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry.sdk.trace import TracerProvider, sampling +from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter +from opentelemetry.sdk.resources import ( + Resource, + DEPLOYMENT_ENVIRONMENT, + SERVICE_NAME, + SERVICE_VERSION, +) +from runpod.version import __version__ as runpod_version +from .rp_logger import RunPodLogger + + +log = RunPodLogger() + +# https://opentelemetry.io/docs/languages/sdk-configuration/otlp-exporter/ 
+OTEL_EXPORTER_OTLP_ENDPOINT = os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT", "") + +# https://opentelemetry.io/docs/languages/sdk-configuration/general/#otel_service_name +OTEL_SERVICE_NAME = os.getenv("OTEL_SERVICE_NAME", "serverless-worker") + +OTEL_SAMPLING_RATE = float(os.getenv("OTEL_SAMPLING_RATE", "0.01")) + + +def start( + service_name: str = OTEL_SERVICE_NAME, + collector: str = OTEL_EXPORTER_OTLP_ENDPOINT, + rate: float = OTEL_SAMPLING_RATE, +): + """ + Initializes the OpenTelemetry global tracer provider. + + Args: + service_name: The service name to associate with the OTEL spans. + collector: The URL of the OTEL collector to report to. Defaults to + the `OTEL_EXPORTER_OTLP_ENDPOINT` environment variable. + rate: The sampling rate between 0.0 and 1.0. Defaults to the + `OTEL_SAMPLING_RATE` env var or 0.01 (1%) + + Notes: + The env var `RUNPOD_LOG_LEVEL=trace` can be set to force mandatory tracing. + Otherwise, the sampling rate is used to control the amount of tracing. + + If a collector is provided, the traces are exported to it. + Else if the environment is "local", the traces are printed to the console. + + If neither of the above conditions are met, then tracing is disabled. 
+ """ + RUNPOD_ENV = get_deployment_env() + RUNPOD_LOG_LEVEL = os.getenv("RUNPOD_LOG_LEVEL", "").lower() + + if RUNPOD_LOG_LEVEL == "trace": + sampler = sampling.ALWAYS_ON + else: + sampler = sampling.TraceIdRatioBased(rate) + + tracer = TracerProvider( + sampler=sampler, + resource=get_resource(service_name, RUNPOD_ENV), + ) + + if collector: + tracer.add_span_processor(BatchSpanProcessor(OTLPSpanExporter())) + trace.set_tracer_provider(tracer) + log.info(f"OpenTelemetry is on: {sampler.get_description()}") + + elif RUNPOD_ENV == "local": + tracer.add_span_processor(BatchSpanProcessor(ConsoleSpanExporter())) + trace.set_tracer_provider(tracer) + log.info(f"Tracing prints to console: {sampler.get_description()}") + + else: + # Use NoOpTracerProvider to disable OTEL + trace.set_tracer_provider(trace.NoOpTracerProvider()) + + +def get_resource(service_name: str, environment: str) -> Resource: + """ + Constructs and returns a Resource object for OpenTelemetry. + + The Resource object includes essential metadata such as deployment + environment, service name, service version, and unique identifiers + for the RunPod endpoint and pod. + + Args: + service_name: The name of the service to associate with the resource. + environment: The deployment environment (e.g., dev, prod, local). + + Returns: + A Resource object containing metadata for tracing and monitoring. 
+ """ + RUNPOD_ENDPOINT_ID = "runpod.endpoint_id" + RUNPOD_ENDPOINT_ID_VALUE = os.getenv("RUNPOD_ENDPOINT_ID", "") + RUNPOD_POD_ID = "runpod.pod_id" + RUNPOD_POD_ID_VALUE = os.getenv("RUNPOD_POD_ID", "") + + return Resource.create( + { + DEPLOYMENT_ENVIRONMENT: environment, + RUNPOD_ENDPOINT_ID: RUNPOD_ENDPOINT_ID_VALUE, + RUNPOD_POD_ID: RUNPOD_POD_ID_VALUE, + SERVICE_NAME: service_name, + SERVICE_VERSION: runpod_version, + } + ) + + +def get_deployment_env() -> str: + RUNPOD_API_URL = os.getenv("RUNPOD_WEBHOOK_PING", "") + if "runpod.dev" in RUNPOD_API_URL: + return "dev" + if "runpod.ai" in RUNPOD_API_URL: + return "prod" + return "local" diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index 5e1a2f98..e7d84c14 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -5,7 +5,7 @@ import os import time import uuid -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Set from .rp_logger import RunPodLogger @@ -61,7 +61,7 @@ def __str__(self) -> str: # ---------------------------------------------------------------------------- # # Tracker # # ---------------------------------------------------------------------------- # -class JobsProgress(set): +class JobsProgress(Set[Job]): """Track the state of current jobs in progress.""" _instance = None diff --git a/runpod/serverless/worker.py b/runpod/serverless/worker.py index ec98347d..ed72e76f 100644 --- a/runpod/serverless/worker.py +++ b/runpod/serverless/worker.py @@ -7,7 +7,7 @@ import os from typing import Any, Dict -from runpod.serverless.modules import rp_logger, rp_local, rp_ping, rp_scale +from runpod.serverless.modules import rp_logger, rp_local, rp_ping, rp_scale, rp_tracer log = rp_logger.RunPodLogger() heartbeat = rp_ping.Heartbeat() @@ -35,6 +35,8 @@ def run_worker(config: Dict[str, Any]) -> None: Args: config (Dict[str, Any]): Configuration parameters for the worker. 
""" + rp_tracer.start() + # Start pinging RunPod to show that the worker is alive. heartbeat.start_ping() diff --git a/setup.py b/setup.py index 11fe7ce5..d9583e72 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ "nest_asyncio", "pytest", "pytest-cov", + "pytest-env", "pytest-timeout", "pytest-asyncio", ] diff --git a/tests/test_serverless/test_modules/run_scale.py b/tests/test_serverless/test_modules/run_scale.py index 5983c7a6..1730505b 100644 --- a/tests/test_serverless/test_modules/run_scale.py +++ b/tests/test_serverless/test_modules/run_scale.py @@ -3,6 +3,7 @@ from faker import Faker from typing import Any, Dict, Optional, List +from runpod.serverless.modules import rp_tracer from runpod.serverless.modules.rp_scale import JobScaler, RunPodLogger, JobsProgress fake = Faker() @@ -60,4 +61,5 @@ async def fake_handle_job(session, config, job) -> dict: "jobs_handler": fake_handle_job, } ) +rp_tracer.start() job_scaler.start() diff --git a/tests/test_serverless/test_modules/test_state.py b/tests/test_serverless/test_modules/test_state.py index f3bb3372..faa99028 100644 --- a/tests/test_serverless/test_modules/test_state.py +++ b/tests/test_serverless/test_modules/test_state.py @@ -158,4 +158,4 @@ async def test_get_job_list(self): async def test_get_job_count(self): # test job count contention when adding and removing jobs in parallel - pass \ No newline at end of file + pass