Coverage for mindsdb / interfaces / agents / langfuse_callback_handler.py: 11%

155 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Any, Dict, Union, Optional, List 

2from uuid import uuid4 

3import datetime 

4import json 

5 

6from langchain_core.callbacks.base import BaseCallbackHandler 

7 

8from mindsdb.utilities import log 

9from mindsdb.interfaces.storage import db 

10 

# Module-level logger, named after this module via the project's `log` utility.
logger = log.getLogger(__name__)
# NOTE(review): this sets the logger's level to DEBUG at import time instead of
# leaving the level to the application's logging configuration -- confirm this
# is intended.
logger.setLevel('DEBUG')

13 

14 

class LangfuseCallbackHandler(BaseCallbackHandler):
    """Langchain callback handler that traces tool & chain executions using Langfuse.

    Chain runs and agent actions are each mirrored by a Langfuse span created
    through ``self.langfuse.span(...)``. Open spans are tracked in
    ``chain_uuid_to_span`` / ``action_uuid_to_span`` keyed by the hex form of the
    run id, and simple usage statistics are accumulated in ``tool_metrics`` /
    ``chain_metrics`` (see :meth:`get_metrics`).

    All tracing is best-effort: failures while talking to Langfuse are logged
    as warnings and never propagated to the caller.
    """

    def __init__(self, langfuse, trace_id: Optional[str] = None, observation_id: Optional[str] = None):
        """Create a handler bound to a Langfuse client.

        Args:
            langfuse: Langfuse client used to create spans. May be ``None``, in
                which case the span-creating callbacks return immediately.
            trace_id: Trace the spans are attached to; a random hex id is
                generated when omitted.
            observation_id: Parent observation for the spans; a random hex id
                is generated when omitted.
        """
        self.langfuse = langfuse
        self.chain_uuid_to_span = {}
        self.action_uuid_to_span = {}
        # if these are not available, we generate some UUIDs
        self.trace_id = trace_id or uuid4().hex
        self.observation_id = observation_id or uuid4().hex
        # Track metrics about tools and chains
        self.tool_metrics = {}
        self.chain_metrics = {}
        self.current_chain = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _run_hex(kwargs: Dict[str, Any], key: str) -> str:
        """Return the hex form of the run id stored under ``key`` in ``kwargs``.

        A missing id *or an explicit ``None``* yields a fresh random hex value.
        ``kwargs.get(key, uuid4())`` alone is not enough: the default is only
        used when the key is absent, so ``key=None`` would crash on ``.hex``.
        """
        run_id = kwargs.get(key)
        if run_id is None:
            run_id = uuid4()
        return run_id.hex

    @staticmethod
    def _span_metadata(span) -> Dict[str, Any]:
        """Return ``span.metadata`` or an empty dict when it is missing/``None``.

        NOTE(review): the timing/metrics logic relies on the span object
        exposing the metadata set at start time as ``.metadata`` -- confirm
        against the Langfuse client version in use.
        """
        metadata = getattr(span, 'metadata', None)
        return metadata if metadata is not None else {}

    @staticmethod
    def _update_span(span, **fields: Any) -> None:
        """Call ``span.update(**fields)``, logging instead of raising on failure."""
        try:
            span.update(**fields)
        except Exception as e:
            logger.warning(f"Error updating Langfuse span: {str(e)}")

    @staticmethod
    def _error_to_str(error: Any) -> str:
        """Return ``str(error)``, or a placeholder if stringification itself fails."""
        try:
            return str(error)
        except Exception:
            return "Couldn't get error string."

    @staticmethod
    def _summarize(data: Dict[str, Any]) -> Dict[str, Any]:
        """Build the public summary for one tool/chain metrics record."""
        count = data['count']
        return {
            'count': count,
            'total_time': data['total_time'],
            'avg_duration': data['total_time'] / count if count > 0 else 0,
            'errors': data['errors'],
            'last_error': data['last_error'],
            'error_rate': data['errors'] / count if count > 0 else 0
        }

    # ------------------------------------------------------------ tool callbacks

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> Any:
        """Run when tool starts running.

        Does nothing unless an action span is registered for ``parent_run_id``.
        Otherwise counts the call, records the input, and stores start
        metadata on the action span.
        """
        parent_run_uuid = self._run_hex(kwargs, 'parent_run_id')
        action_span = self.action_uuid_to_span.get(parent_run_uuid)
        if action_span is None:
            return

        # Same guard as on_chain_start: tolerate a missing serialized payload.
        if serialized is None:
            serialized = {}

        tool_name = serialized.get("name", "tool")
        start_time = datetime.datetime.now()

        # Initialize or update tool metrics
        if tool_name not in self.tool_metrics:
            self.tool_metrics[tool_name] = {
                'count': 0,
                'total_time': 0,
                'errors': 0,
                'last_error': None,
                'inputs': []
            }

        self.tool_metrics[tool_name]['count'] += 1
        self.tool_metrics[tool_name]['inputs'].append(input_str)

        metadata = {
            'tool_name': tool_name,
            'started': start_time.isoformat(),
            'start_timestamp': start_time.timestamp(),
            'input_length': len(input_str) if input_str else 0
        }
        self._update_span(action_span, metadata=metadata)

    def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Run when tool ends running.

        Adds the elapsed time to the tool's metrics (when a start timestamp is
        available) and stores the output plus end metadata on the action span.
        """
        parent_run_uuid = self._run_hex(kwargs, 'parent_run_id')
        action_span = self.action_uuid_to_span.get(parent_run_uuid)
        if action_span is None:
            return

        end_time = datetime.datetime.now()
        span_metadata = self._span_metadata(action_span)
        tool_name = span_metadata.get('tool_name', 'unknown')
        start_timestamp = span_metadata.get('start_timestamp')

        duration = None
        if start_timestamp:
            duration = end_time.timestamp() - start_timestamp
            if tool_name in self.tool_metrics:
                self.tool_metrics[tool_name]['total_time'] += duration

        metadata = {
            'finished': end_time.isoformat(),
            'duration_seconds': duration,
            'output_length': len(output) if output else 0
        }

        self._update_span(
            action_span,
            output=output,  # tool output is action output (unless superseded by a global action output)
            metadata=metadata
        )

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> Any:
        """Run when tool errors.

        Counts the error against the tool and stores error metadata on the
        action span registered for ``parent_run_id`` (if any).
        """
        parent_run_uuid = self._run_hex(kwargs, 'parent_run_id')
        action_span = self.action_uuid_to_span.get(parent_run_uuid)
        if action_span is None:
            return

        error_str = self._error_to_str(error)

        tool_name = self._span_metadata(action_span).get('tool_name', 'unknown')
        if tool_name in self.tool_metrics:
            self.tool_metrics[tool_name]['errors'] += 1
            self.tool_metrics[tool_name]['last_error'] = error_str

        metadata = {
            'error_description': error_str,
            'error_type': error.__class__.__name__,
            'error_time': datetime.datetime.now().isoformat()
        }
        self._update_span(action_span, metadata=metadata)

    # ----------------------------------------------------------- chain callbacks

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> Any:
        """Run when chain starts running.

        Counts the chain invocation, creates a Langfuse span for it and
        registers the span under the run id so the end/error callbacks can
        find it.
        """
        if self.langfuse is None:
            return

        run_uuid = self._run_hex(kwargs, 'run_id')

        if serialized is None:
            serialized = {}

        chain_name = serialized.get("name", "chain")
        start_time = datetime.datetime.now()

        # Initialize or update chain metrics
        if chain_name not in self.chain_metrics:
            self.chain_metrics[chain_name] = {
                'count': 0,
                'total_time': 0,
                'errors': 0,
                'last_error': None
            }

        self.chain_metrics[chain_name]['count'] += 1
        self.current_chain = chain_name

        try:
            chain_span = self.langfuse.span(
                name=f'{chain_name}-{run_uuid}',
                trace_id=self.trace_id,
                parent_observation_id=self.observation_id,
                # default=str is deliberate: for tracing, a stringified value is
                # better than losing the whole span to a TypeError on a
                # non-JSON-serializable input.
                input=json.dumps(inputs, indent=2, default=str)
            )

            metadata = {
                'chain_name': chain_name,
                'started': start_time.isoformat(),
                'start_timestamp': start_time.timestamp(),
                'input_keys': list(inputs.keys()) if isinstance(inputs, dict) else None,
                'input_size': len(inputs) if isinstance(inputs, dict) else len(str(inputs))
            }
            chain_span.update(metadata=metadata)
            self.chain_uuid_to_span[run_uuid] = chain_span
        except Exception as e:
            logger.warning(f"Error creating Langfuse span: {str(e)}")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
        """Run when chain ends running.

        Removes the chain's span from the registry, records duration/output
        metadata on it and ends it.
        """
        if self.langfuse is None:
            return

        chain_uuid = self._run_hex(kwargs, 'run_id')
        chain_span = self.chain_uuid_to_span.pop(chain_uuid, None)
        if chain_span is None:
            return

        try:
            end_time = datetime.datetime.now()
            span_metadata = self._span_metadata(chain_span)
            chain_name = span_metadata.get('chain_name', 'unknown')
            start_timestamp = span_metadata.get('start_timestamp')

            # Compute the duration whenever a start timestamp exists; only the
            # metrics update depends on the chain being known. (Previously the
            # duration was left unbound for unknown chains, which aborted the
            # update and left the span un-ended.)
            duration = None
            if start_timestamp:
                duration = end_time.timestamp() - start_timestamp
                if chain_name in self.chain_metrics:
                    self.chain_metrics[chain_name]['total_time'] += duration

            metadata = {
                'finished': end_time.isoformat(),
                'duration_seconds': duration,
                'output_keys': list(outputs.keys()) if isinstance(outputs, dict) else None,
                'output_size': len(outputs) if isinstance(outputs, dict) else len(str(outputs))
            }
            # default=str: see on_chain_start.
            chain_span.update(output=json.dumps(outputs, indent=2, default=str), metadata=metadata)
            chain_span.end()
        except Exception as e:
            logger.warning(f"Error updating Langfuse span: {str(e)}")

    def on_chain_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
        """Run when chain errors.

        Counts the error against the chain, stores error metadata on the
        chain's span, then ends the span and removes it from the registry so
        failed runs do not accumulate open spans.
        """
        chain_uuid = self._run_hex(kwargs, 'run_id')
        chain_span = self.chain_uuid_to_span.pop(chain_uuid, None)
        if chain_span is None:
            return

        error_str = self._error_to_str(error)

        chain_name = self._span_metadata(chain_span).get('chain_name', 'unknown')
        if chain_name in self.chain_metrics:
            self.chain_metrics[chain_name]['errors'] += 1
            self.chain_metrics[chain_name]['last_error'] = error_str

        metadata = {
            'error_description': error_str,
            'error_type': error.__class__.__name__,
            'error_time': datetime.datetime.now().isoformat()
        }
        try:
            chain_span.update(metadata=metadata)
            chain_span.end()
        except Exception as e:
            logger.warning(f"Error updating Langfuse span: {str(e)}")

    # ----------------------------------------------------------- agent callbacks

    def on_agent_action(self, action, **kwargs: Any) -> Any:
        """Run on agent action.

        Creates a span for the action and registers it under the run id; the
        tool callbacks look it up through their ``parent_run_id``.
        """
        if self.langfuse is None:
            return

        run_uuid = self._run_hex(kwargs, 'run_id')

        # A span already registered for this run would otherwise be silently
        # overwritten and never ended, so close it before replacing it.
        previous_span = self.action_uuid_to_span.pop(run_uuid, None)
        if previous_span is not None:
            try:
                previous_span.end()
            except Exception as e:
                logger.warning(f"Error updating Langfuse span: {str(e)}")

        try:
            action_span = self.langfuse.span(
                name=f'{getattr(action, "type", "action")}-{getattr(action, "tool", "")}-{run_uuid}',
                trace_id=self.trace_id,
                parent_observation_id=self.observation_id,
                input=str(action)
            )
            self.action_uuid_to_span[run_uuid] = action_span
        except Exception as e:
            logger.warning(f"Error creating Langfuse span for agent action: {str(e)}")

    def on_agent_finish(self, finish, **kwargs: Any) -> Any:
        """Run on agent end.

        Removes the action span registered for the run, stores ``finish`` as
        its output (when not ``None``) and ends it.
        """
        if self.langfuse is None:
            return

        run_uuid = self._run_hex(kwargs, 'run_id')
        action_span = self.action_uuid_to_span.pop(run_uuid, None)
        if action_span is None:
            return

        try:
            if finish is not None:
                action_span.update(output=finish)  # supersedes tool output
            action_span.end()
        except Exception as e:
            logger.warning(f"Error updating Langfuse span: {str(e)}")

    # -------------------------------------------------------------------- misc

    def auth_check(self):
        """Return the client's ``auth_check()`` result, or ``False`` without a client."""
        if self.langfuse is not None:
            return self.langfuse.auth_check()
        return False

    def get_metrics(self) -> Dict[str, Any]:
        """Get collected metrics about tools and chains.

        Returns:
            Dict containing:
                - tool_metrics: Statistics about tool usage, errors, and timing
                - chain_metrics: Statistics about chain execution, errors, and timing
            For each tool/chain, includes:
                - count: Number of times used
                - total_time: Total execution time
                - avg_duration: Average execution time (0 when count is 0)
                - errors: Number of errors
                - last_error: Most recent error message
                - error_rate: errors / count (0 when count is 0)
        """
        return {
            'tool_metrics': {
                tool_name: self._summarize(data)
                for tool_name, data in self.tool_metrics.items()
            },
            'chain_metrics': {
                chain_name: self._summarize(data)
                for chain_name, data in self.chain_metrics.items()
            }
        }

304 

305 

def get_skills(agent: db.Agents) -> List:
    """Return the skill type of each entry in the agent's `skills_relationships`.

    The result keeps the iteration order of `agent.skills_relationships`.
    Specific to agent endpoints.
    """
    skill_types = []
    for relationship in agent.skills_relationships:
        skill_types.append(relationship.skill.type)
    return skill_types