Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,21 @@ repos:
hooks:
- id: add-trailing-comma

- repo: local
- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: black
name: Format with Black
entry: poetry run black
language: system
types: [python]

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.5
hooks:
- id: ruff
name: Run ruff lints
entry: poetry run ruff
language: system
pass_filenames: false
types: [python]
args:
- "check"
- "--fix"
- "otlp_psqlpy"

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.15.0
hooks:
- id: mypy
name: Validate types with MyPy
entry: poetry run mypy
language: system
pass_filenames: false
types: [python]
args:
- ./otlp_psqlpy
additional_dependencies: ["."]
31 changes: 22 additions & 9 deletions otlp_psqlpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import typing as t
from collections.abc import Awaitable, Mapping, Sequence

import psqlpy
import wrapt # type: ignore[import-untyped]
Expand Down Expand Up @@ -76,7 +77,7 @@ def _construct_span(
if hosts:
span_attributes[SpanAttributes.SERVER_ADDRESS] = ", ".join(hosts)
span_attributes[SpanAttributes.SERVER_PORT] = ", ".join(
[str(port) for port in ports]
[str(port) for port in ports],
)
span_attributes[SpanAttributes.NETWORK_TRANSPORT] = (
NetTransportValues.IP_TCP.value
Expand All @@ -85,7 +86,7 @@ def _construct_span(
elif host_addrs:
span_attributes[SpanAttributes.SERVER_ADDRESS] = ", ".join(host_addrs)
span_attributes[SpanAttributes.SERVER_PORT] = ", ".join(
[str(port) for port in ports]
[str(port) for port in ports],
)
span_attributes[SpanAttributes.NETWORK_TRANSPORT] = (
NetTransportValues.IP_TCP.value
Expand Down Expand Up @@ -171,7 +172,13 @@ def _uninstrument(self, **__: t.Any) -> None:
for method_name in methods:
unwrap(cls, method_name)

async def _do_execute(self, func, instance, args, kwargs):
async def _do_execute(
self,
func: t.Callable[..., Awaitable[t.Any]],
instance: t.Union[psqlpy.Connection, psqlpy.Transaction, psqlpy.Cursor],
args: Sequence[t.Any],
kwargs: Mapping[str, t.Any],
) -> t.Any:
exception = None
params = getattr(instance, "_params", {})
name = args[0] if args else params.get("database", "postgresql")
Expand All @@ -182,20 +189,20 @@ async def _do_execute(self, func, instance, args, kwargs):
except IndexError:
name = ""

with self._tracer.start_as_current_span(
with self._tracer.start_as_current_span( # type: ignore[union-attr]
name,
kind=SpanKind.CLIENT,
) as span:
if span.is_recording():
span_attributes = _construct_span(
instance,
_retrieve_parameter_from_args_or_kwargs(
_retrieve_parameter_from_args_or_kwargs( # type: ignore[arg-type]
parameter_name="querystring",
parameter_index=0,
args=args,
kwargs=kwargs,
),
_retrieve_parameter_from_args_or_kwargs(
_retrieve_parameter_from_args_or_kwargs( # type: ignore[arg-type]
parameter_name="parameters",
parameter_index=1,
args=args,
Expand All @@ -222,13 +229,19 @@ async def _do_execute(self, func, instance, args, kwargs):

return result

async def _do_cursor_execute(self, func, instance, args, kwargs):
async def _do_cursor_execute(
self,
func: t.Callable[..., Awaitable[t.Any]],
instance: psqlpy.Cursor,
args: Sequence[t.Any],
kwargs: Mapping[str, t.Any],
) -> t.Any:
"""Wrap cursor based functions. For every call this will generate a new span."""
exception = None

stop = False
with self._tracer.start_as_current_span(
f"CURSOR",
with self._tracer.start_as_current_span( # type: ignore[union-attr]
"CURSOR",
kind=SpanKind.CLIENT,
) as span:
if span.is_recording():
Expand Down
Loading