Coverage for mindsdb / interfaces / agents / langfuse_callback_handler.py: 11%
155 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Any, Dict, Union, Optional, List
2from uuid import uuid4
3import datetime
4import json
6from langchain_core.callbacks.base import BaseCallbackHandler
8from mindsdb.utilities import log
9from mindsdb.interfaces.storage import db
# Module-level logger obtained from mindsdb's logging utility.
logger = log.getLogger(__name__)
# NOTE(review): presumably pins this module's logger to DEBUG regardless of the application's
# configured log level — confirm this is intended outside of development.
logger.setLevel('DEBUG')
class LangfuseCallbackHandler(BaseCallbackHandler):
    """Langchain callback handler that traces tool & chain executions using Langfuse.

    Chain runs are traced as Langfuse spans keyed by their ``run_id``. Agent actions are traced
    as spans keyed by the ``run_id`` of the ``on_agent_action`` callback. Tool callbacks locate
    that span through their ``parent_run_id`` (presumably the agent executor's run; verify
    against the LangChain version in use). Per-tool and per-chain usage metrics are also
    collected and exposed through :meth:`get_metrics`.
    """

    def __init__(self, langfuse, trace_id: Optional[str] = None, observation_id: Optional[str] = None):
        """Create a handler bound to a Langfuse client.

        Args:
            langfuse: Langfuse client used to create spans. When None, chain and agent callbacks
                are no-ops, so tool callbacks find no span and return early.
            trace_id: Trace that every span is attached to. A random hex UUID is used when omitted.
            observation_id: Parent observation of every span. A random hex UUID is used when omitted.
        """
        self.langfuse = langfuse
        # Open spans, keyed by the hex run id of the chain / agent action that created them.
        self.chain_uuid_to_span = {}
        self.action_uuid_to_span = {}
        # if these are not available, we generate some UUIDs
        self.trace_id = trace_id or uuid4().hex
        self.observation_id = observation_id or uuid4().hex
        # Track metrics about tools and chains
        self.tool_metrics = {}
        self.chain_metrics = {}
        self.current_chain = None
        # Name and start timestamp of in-flight runs, keyed like the span dicts above. These are
        # kept here instead of being read back from span objects, because the span client API is
        # not relied on to expose the metadata previously sent through update().
        self._action_run_info: Dict[str, Dict[str, Any]] = {}
        self._chain_run_info: Dict[str, Dict[str, Any]] = {}

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _run_key(kwargs: Dict[str, Any], key: str) -> str:
        """Return the hex form of the run id stored under ``key`` in callback kwargs.

        A missing id and an explicit ``None`` (e.g. ``parent_run_id=None`` for a root run) both
        map to a fresh random UUID. That UUID never matches a tracked span, so callers fall
        through to their "no span" early return instead of crashing on ``None.hex``.
        """
        run_id = kwargs.get(key)
        if run_id is None:
            run_id = uuid4()
        return run_id.hex

    @staticmethod
    def _to_json(value: Any) -> str:
        """Serialize ``value`` for display in Langfuse without ever raising.

        Objects json cannot encode are stringified. If serialization still fails (e.g.
        non-string-compatible keys or circular references), the plain ``str()`` form is used.
        """
        try:
            return json.dumps(value, indent=2, default=str)
        except (TypeError, ValueError):
            return str(value)

    @staticmethod
    def _output_length(output: Any) -> int:
        """Length of a tool output, falling back to its string form for non-sized objects."""
        if not output:
            return 0
        try:
            return len(output)
        except TypeError:
            return len(str(output))

    @staticmethod
    def _describe_error(error: Union[Exception, KeyboardInterrupt]) -> str:
        """Best-effort ``str(error)``. Exceptions whose ``__str__`` fails get a placeholder."""
        try:
            return str(error)
        except Exception:
            return "Couldn't get error string."

    @staticmethod
    def _error_metadata(error: Union[Exception, KeyboardInterrupt], error_str: str) -> Dict[str, Any]:
        """Span metadata describing ``error`` at the current time."""
        return {
            'error_description': error_str,
            'error_type': error.__class__.__name__,
            'error_time': datetime.datetime.now().isoformat(),
        }

    @staticmethod
    def _update_span(span, **fields: Any) -> None:
        """Call ``span.update(**fields)``, logging (not raising) any Langfuse failure."""
        try:
            span.update(**fields)
        except Exception as e:
            logger.warning(f"Error updating Langfuse span: {str(e)}")

    @staticmethod
    def _end_span(span) -> None:
        """Call ``span.end()``, logging (not raising) any Langfuse failure."""
        try:
            span.end()
        except Exception as e:
            logger.warning(f"Error ending Langfuse span: {str(e)}")

    @staticmethod
    def _summarize_metrics(data: Dict[str, Any]) -> Dict[str, Any]:
        """Turn raw counters for one tool/chain into the public metrics shape."""
        count = data['count']
        return {
            'count': count,
            'total_time': data['total_time'],
            'avg_duration': data['total_time'] / count if count > 0 else 0,
            'errors': data['errors'],
            'last_error': data['last_error'],
            'error_rate': data['errors'] / count if count > 0 else 0,
        }

    # ------------------------------------------------------------------ tools

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        """Run when tool starts running.

        Records usage metrics for the tool and attaches start metadata to the agent-action span
        identified by ``parent_run_id``. Does nothing when no such span is tracked.
        """
        parent_run_uuid = self._run_key(kwargs, 'parent_run_id')
        action_span = self.action_uuid_to_span.get(parent_run_uuid)
        if action_span is None:
            return

        tool_name = (serialized or {}).get("name", "tool")
        start_time = datetime.datetime.now()

        # Initialize or update tool metrics
        metrics = self.tool_metrics.setdefault(tool_name, {
            'count': 0,
            'total_time': 0,
            'errors': 0,
            'last_error': None,
            'inputs': []
        })
        metrics['count'] += 1
        metrics['inputs'].append(input_str)

        self._action_run_info[parent_run_uuid] = {
            'tool_name': tool_name,
            'start_timestamp': start_time.timestamp(),
        }

        metadata = {
            'tool_name': tool_name,
            'started': start_time.isoformat(),
            'start_timestamp': start_time.timestamp(),
            'input_length': len(input_str) if input_str else 0
        }
        self._update_span(action_span, metadata=metadata)

    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Run when tool ends running.

        Adds the tool's duration to ``tool_metrics`` and writes the output and timing metadata
        to the parent agent-action span.
        """
        parent_run_uuid = self._run_key(kwargs, 'parent_run_id')
        action_span = self.action_uuid_to_span.get(parent_run_uuid)
        if action_span is None:
            return

        end_time = datetime.datetime.now()
        run_info = self._action_run_info.pop(parent_run_uuid, {})
        tool_name = run_info.get('tool_name', 'unknown')
        start_timestamp = run_info.get('start_timestamp')

        duration = None
        if start_timestamp is not None:
            duration = end_time.timestamp() - start_timestamp
            if tool_name in self.tool_metrics:
                self.tool_metrics[tool_name]['total_time'] += duration

        metadata = {
            'finished': end_time.isoformat(),
            'duration_seconds': duration,
            'output_length': self._output_length(output)
        }
        self._update_span(
            action_span,
            output=output,  # tool output is action output (unless superseded by a global action output)
            metadata=metadata
        )

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when tool errors.

        Counts the error against the tool in ``tool_metrics`` and records error details on the
        parent agent-action span. The span itself stays open and is closed by the agent callbacks.
        """
        parent_run_uuid = self._run_key(kwargs, 'parent_run_id')
        action_span = self.action_uuid_to_span.get(parent_run_uuid)
        if action_span is None:
            return

        error_str = self._describe_error(error)
        tool_name = self._action_run_info.pop(parent_run_uuid, {}).get('tool_name', 'unknown')
        if tool_name in self.tool_metrics:
            self.tool_metrics[tool_name]['errors'] += 1
            self.tool_metrics[tool_name]['last_error'] = error_str

        self._update_span(action_span, metadata=self._error_metadata(error, error_str))

    # ------------------------------------------------------------------ chains

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """Run when chain starts running.

        Records chain metrics and opens a Langfuse span for the run, keyed by ``run_id``.
        """
        if self.langfuse is None:
            return

        run_uuid = self._run_key(kwargs, 'run_id')

        if serialized is None:
            serialized = {}

        chain_name = serialized.get("name", "chain")
        start_time = datetime.datetime.now()

        # Initialize or update chain metrics
        metrics = self.chain_metrics.setdefault(chain_name, {
            'count': 0,
            'total_time': 0,
            'errors': 0,
            'last_error': None
        })
        metrics['count'] += 1
        self.current_chain = chain_name

        try:
            chain_span = self.langfuse.span(
                name=f'{chain_name}-{run_uuid}',
                trace_id=self.trace_id,
                parent_observation_id=self.observation_id,
                input=self._to_json(inputs)
            )
        except Exception as e:
            logger.warning(f"Error creating Langfuse span: {str(e)}")
            return

        # Register the span before updating it, so it is still ended later even if the
        # metadata update below fails.
        self.chain_uuid_to_span[run_uuid] = chain_span
        self._chain_run_info[run_uuid] = {
            'chain_name': chain_name,
            'start_timestamp': start_time.timestamp(),
        }

        metadata = {
            'chain_name': chain_name,
            'started': start_time.isoformat(),
            'start_timestamp': start_time.timestamp(),
            'input_keys': list(inputs.keys()) if isinstance(inputs, dict) else None,
            'input_size': len(inputs) if isinstance(inputs, dict) else len(str(inputs))
        }
        self._update_span(chain_span, metadata=metadata)

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Run when chain ends running.

        Adds the chain's duration to ``chain_metrics``, writes outputs and timing metadata to the
        run's span, and ends the span.
        """
        if self.langfuse is None:
            return

        chain_uuid = self._run_key(kwargs, 'run_id')
        chain_span = self.chain_uuid_to_span.pop(chain_uuid, None)
        run_info = self._chain_run_info.pop(chain_uuid, {})
        if chain_span is None:
            return

        end_time = datetime.datetime.now()
        chain_name = run_info.get('chain_name', 'unknown')
        start_timestamp = run_info.get('start_timestamp')

        # `duration` is always bound, even when the chain is absent from chain_metrics.
        duration = None
        if start_timestamp is not None:
            duration = end_time.timestamp() - start_timestamp
            if chain_name in self.chain_metrics:
                self.chain_metrics[chain_name]['total_time'] += duration

        metadata = {
            'finished': end_time.isoformat(),
            'duration_seconds': duration,
            'output_keys': list(outputs.keys()) if isinstance(outputs, dict) else None,
            'output_size': len(outputs) if isinstance(outputs, dict) else len(str(outputs))
        }
        self._update_span(chain_span, output=self._to_json(outputs), metadata=metadata)
        # End the span even if the update above failed, so it is not left open.
        self._end_span(chain_span)

    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
        """Run when chain errors.

        Counts the error in ``chain_metrics``, records error details on the run's span, and ends
        the span.
        """
        chain_uuid = self._run_key(kwargs, 'run_id')
        chain_span = self.chain_uuid_to_span.pop(chain_uuid, None)
        run_info = self._chain_run_info.pop(chain_uuid, {})
        if chain_span is None:
            return

        error_str = self._describe_error(error)
        chain_name = run_info.get('chain_name', 'unknown')
        if chain_name in self.chain_metrics:
            self.chain_metrics[chain_name]['errors'] += 1
            self.chain_metrics[chain_name]['last_error'] = error_str

        self._update_span(chain_span, metadata=self._error_metadata(error, error_str))
        # on_chain_end is not expected for a failed run, so close the span here. Otherwise it
        # would stay open and be retained in chain_uuid_to_span indefinitely.
        self._end_span(chain_span)

    # ------------------------------------------------------------------ agent

    def on_agent_action(self, action, **kwargs: Any) -> Any:
        """Run on agent action.

        Opens a span for the action, keyed by ``run_id``, so later tool callbacks can attach to it.
        """
        if self.langfuse is None:
            return

        run_uuid = self._run_key(kwargs, 'run_id')

        # If a span is already registered under this run id (e.g. a multi-step agent reporting
        # each action with the same run_id), end it before it is replaced and lost.
        previous_span = self.action_uuid_to_span.pop(run_uuid, None)
        self._action_run_info.pop(run_uuid, None)
        if previous_span is not None:
            self._end_span(previous_span)

        try:
            action_span = self.langfuse.span(
                name=f'{getattr(action, "type", "action")}-{getattr(action, "tool", "")}-{run_uuid}',
                trace_id=self.trace_id,
                parent_observation_id=self.observation_id,
                input=str(action)
            )
            self.action_uuid_to_span[run_uuid] = action_span
        except Exception as e:
            logger.warning(f"Error creating Langfuse span for agent action: {str(e)}")

    def on_agent_finish(self, finish, **kwargs: Any) -> Any:
        """Run on agent end.

        Sets ``finish`` as the output of the current action span (replacing the tool output) and
        ends the span.
        """
        if self.langfuse is None:
            return

        run_uuid = self._run_key(kwargs, 'run_id')
        action_span = self.action_uuid_to_span.pop(run_uuid, None)
        self._action_run_info.pop(run_uuid, None)
        if action_span is None:
            return

        if finish is not None:
            self._update_span(action_span, output=finish)  # supersedes tool output
        # End the span even if the update above failed, so it is not left open.
        self._end_span(action_span)

    def auth_check(self):
        """Return the Langfuse client's auth check result, or False when no client is set."""
        if self.langfuse is not None:
            return self.langfuse.auth_check()
        return False

    def get_metrics(self) -> Dict[str, Any]:
        """Get collected metrics about tools and chains.

        Returns:
            Dict containing:
            - tool_metrics: Statistics about tool usage, errors, and timing
            - chain_metrics: Statistics about chain execution, errors, and timing
            For each tool/chain, includes:
            - count: Number of times started
            - total_time: Total execution time in seconds (completed runs only)
            - avg_duration: total_time / count (0 when count is 0)
            - errors: Number of errors
            - last_error: Most recent error message
            - error_rate: errors / count (0 when count is 0)
        """
        return {
            'tool_metrics': {
                tool_name: self._summarize_metrics(data)
                for tool_name, data in self.tool_metrics.items()
            },
            'chain_metrics': {
                chain_name: self._summarize_metrics(data)
                for chain_name, data in self.chain_metrics.items()
            },
        }
def get_skills(agent: db.Agents) -> List:
    """Return the type of every skill linked to ``agent``. Specific to agent endpoints.

    Walks the agent's ``skills_relationships`` association rows in order and
    collects ``relationship.skill.type`` from each one.
    """
    skill_types = []
    for relationship in agent.skills_relationships:
        skill_types.append(relationship.skill.type)
    return skill_types