diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py
index 60057c2064..3b262f9346 100644
--- a/sentry_sdk/integrations/langchain.py
+++ b/sentry_sdk/integrations/langchain.py
@@ -34,27 +34,20 @@
 
 
 DATA_FIELDS = {
-    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
-    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
-    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
+    "frequency_penalty": SPANDATA.GEN_AI_REQUEST_FREQUENCY_PENALTY,
     "function_call": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
-    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
-    "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
-    "response_format": SPANDATA.GEN_AI_RESPONSE_FORMAT,
     "logit_bias": SPANDATA.GEN_AI_REQUEST_LOGIT_BIAS,
+    "max_tokens": SPANDATA.GEN_AI_REQUEST_MAX_TOKENS,
+    "presence_penalty": SPANDATA.GEN_AI_REQUEST_PRESENCE_PENALTY,
+    "response_format": SPANDATA.GEN_AI_RESPONSE_FORMAT,
     "tags": SPANDATA.GEN_AI_REQUEST_TAGS,
+    "temperature": SPANDATA.GEN_AI_REQUEST_TEMPERATURE,
+    "tool_calls": SPANDATA.GEN_AI_RESPONSE_TOOL_CALLS,
+    "tools": SPANDATA.GEN_AI_REQUEST_AVAILABLE_TOOLS,
+    "top_k": SPANDATA.GEN_AI_REQUEST_TOP_K,
+    "top_p": SPANDATA.GEN_AI_REQUEST_TOP_P,
 }
 
-# TODO(shellmayr): is this still the case?
-# To avoid double collecting tokens, we do *not* measure
-# token counts for models for which we have an explicit integration
-NO_COLLECT_TOKEN_MODELS = [
-    # "openai-chat",
-    # "anthropic-chat",
-    "cohere-chat",
-    "huggingface_endpoint",
-]
-
 
 class LangchainIntegration(Integration):
     identifier = "langchain"
@@ -80,7 +73,6 @@ def setup_once():
 
 class WatchedSpan:
     span = None  # type: Span
-    no_collect_tokens = False  # type: bool
    children = []  # type: List[WatchedSpan]
     is_pipeline = False  # type: bool
 
@@ -270,7 +262,7 @@ def on_llm_start(
             all_params.update(serialized.get("kwargs", {}))
 
             watched_span = self._create_span(
-                run_id,
+                run_id=run_id,
                 parent_id=parent_run_id,
                 op=OP.GEN_AI_PIPELINE,
                 name=kwargs.get("name") or "Langchain LLM call",
@@ -297,25 +289,31 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
                 return
             all_params = kwargs.get("invocation_params", {})
             all_params.update(serialized.get("kwargs", {}))
+
+            model = (
+                all_params.get("model")
+                or all_params.get("model_name")
+                or all_params.get("model_id")
+                or ""
+            )
+
             watched_span = self._create_span(
-                run_id,
+                run_id=run_id,
                 parent_id=kwargs.get("parent_run_id"),
                 op=OP.GEN_AI_CHAT,
-                name=kwargs.get("name") or "Langchain Chat Model",
+                name=f"chat {model}".strip(),
                 origin=LangchainIntegration.origin,
             )
             span = watched_span.span
 
-            model = all_params.get(
-                "model", all_params.get("model_name", all_params.get("model_id"))
-            )
-            watched_span.no_collect_tokens = any(
-                x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
-            )
-            if not model and "anthropic" in all_params.get("_type"):
-                model = "claude-2"
+            span.set_data(SPANDATA.GEN_AI_OPERATION_NAME, "chat")
             if model:
                 span.set_data(SPANDATA.GEN_AI_REQUEST_MODEL, model)
+
+            for key, attribute in DATA_FIELDS.items():
+                if key in all_params:
+                    set_data_normalized(span, attribute, all_params[key], unpack=False)
+
             if should_send_default_pii() and self.include_prompts:
                 set_data_normalized(
                     span,
@@ -325,10 +323,6 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
                         for list_ in messages
                     ],
                 )
-            for k, v in DATA_FIELDS.items():
-                if k in all_params:
-                    set_data_normalized(span, v, all_params[k])
-            # no manual token counting
 
     def on_chat_model_end(self, response, *, run_id, **kwargs):
         # type: (SentryLangchainCallback, LLMResult, UUID, Any) -> Any
@@ -336,7 +330,6 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
         with capture_internal_exceptions():
             if not run_id:
                 return
-            token_usage = None
 
             # Try multiple paths to extract token usage, prioritizing streaming-aware approaches
 
@@ -367,27 +360,26 @@ def on_chat_model_end(self, response, *, run_id, **kwargs):
                     [[x.text for x in list_] for list_ in response.generations],
                 )
 
-            if not span_data.no_collect_tokens:
-                if token_usage:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage(token_usage)
-                    )
-                else:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage_from_generations(response.generations)
-                    )
+            if token_usage:
+                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
+                    token_usage
+                )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )
 
-                if (
-                    input_tokens is not None
-                    or output_tokens is not None
-                    or total_tokens is not None
-                ):
-                    record_token_usage(
-                        span_data.span,
-                        input_tokens=input_tokens,
-                        output_tokens=output_tokens,
-                        total_tokens=total_tokens,
-                    )
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
+                record_token_usage(
+                    span_data.span,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=total_tokens,
+                )
 
             self._exit_span(span_data, run_id)
 
@@ -429,27 +421,26 @@ def on_llm_end(self, response, *, run_id, **kwargs):
                     [[x.text for x in list_] for list_ in response.generations],
                 )
 
-            if not span_data.no_collect_tokens:
-                if token_usage:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage(token_usage)
-                    )
-                else:
-                    input_tokens, output_tokens, total_tokens = (
-                        self._extract_token_usage_from_generations(response.generations)
-                    )
+            if token_usage:
+                input_tokens, output_tokens, total_tokens = self._extract_token_usage(
+                    token_usage
+                )
+            else:
+                input_tokens, output_tokens, total_tokens = (
+                    self._extract_token_usage_from_generations(response.generations)
+                )
 
-                if (
-                    input_tokens is not None
-                    or output_tokens is not None
-                    or total_tokens is not None
-                ):
-                    record_token_usage(
-                        span_data.span,
-                        input_tokens=input_tokens,
-                        output_tokens=output_tokens,
-                        total_tokens=total_tokens,
-                    )
+            if (
+                input_tokens is not None
+                or output_tokens is not None
+                or total_tokens is not None
+            ):
+                record_token_usage(
+                    span_data.span,
+                    input_tokens=input_tokens,
+                    output_tokens=output_tokens,
+                    total_tokens=total_tokens,
+                )
 
             self._exit_span(span_data, run_id)
 
@@ -515,13 +506,13 @@ def on_tool_start(self, serialized, input_str, *, run_id, **kwargs):
             if not run_id:
                 return
 
-            tool_name = serialized.get("name") or kwargs.get("name")
+            tool_name = serialized.get("name") or kwargs.get("name") or ""
 
             watched_span = self._create_span(
-                run_id,
+                run_id=run_id,
                 parent_id=kwargs.get("parent_run_id"),
                 op=OP.GEN_AI_EXECUTE_TOOL,
-                name=f"execute_tool {tool_name}",
+                name=f"execute_tool {tool_name}".strip(),
                 origin=LangchainIntegration.origin,
            )
             span = watched_span.span
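
A minimal sketch (not part of the diff; the params dict is hypothetical) of the two
naming behaviors the patch introduces in on_chat_model_start and on_tool_start: the
model-name fallback chain and the "chat {model}" / "execute_tool {tool_name}" span
names, where .strip() drops the trailing space when no name is known.

    # Hypothetical invocation params; in the integration these come from
    # kwargs["invocation_params"] merged with serialized["kwargs"].
    all_params = {"model_name": "gpt-4o", "temperature": 0.2}

    # Fallback chain as in the diff: "model", then "model_name", then "model_id".
    model = (
        all_params.get("model")
        or all_params.get("model_name")
        or all_params.get("model_id")
        or ""
    )

    # "chat gpt-4o" here; a bare "chat" (not "chat ") when the model is unknown.
    span_name = f"chat {model}".strip()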