Coverage for mindsdb / interfaces / agents / langchain_agent.py: 11%

451 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import json 

2from concurrent.futures import as_completed, TimeoutError 

3from typing import Dict, Iterable, List, Optional 

4from uuid import uuid4 

5import queue 

6import re 

7import threading 

8import numpy as np 

9import pandas as pd 

10import logging 

11 

12from langchain.agents import AgentExecutor 

13from langchain.agents.initialize import initialize_agent 

14from langchain.chains.conversation.memory import ConversationSummaryBufferMemory 

15from langchain_community.chat_models import ChatLiteLLM, ChatOllama 

16from langchain_writer import ChatWriter 

17from langchain_google_genai import ChatGoogleGenerativeAI 

18from langchain_core.agents import AgentAction, AgentStep 

19from langchain_core.callbacks.base import BaseCallbackHandler 

20from langchain_core.messages import HumanMessage, AIMessage, SystemMessage 

21 

22from langchain_nvidia_ai_endpoints import ChatNVIDIA 

23from langchain_core.messages.base import BaseMessage 

24from langchain_core.prompts import PromptTemplate 

25from langchain_core.tools import Tool 

26 

27from mindsdb.integrations.libs.llm.utils import get_llm_config 

28from mindsdb.integrations.utilities.handler_utils import get_api_key 

29from mindsdb.integrations.utilities.rag.settings import DEFAULT_RAG_PROMPT_TEMPLATE 

30from mindsdb.interfaces.agents.event_dispatch_callback_handler import ( 

31 EventDispatchCallbackHandler, 

32) 

33from mindsdb.interfaces.agents.constants import AGENT_CHUNK_POLLING_INTERVAL_SECONDS 

34from mindsdb.utilities import log 

35from mindsdb.utilities.context_executor import ContextThreadPoolExecutor 

36from mindsdb.interfaces.storage import db 

37from mindsdb.utilities.context import context as ctx 

38 

39from .mindsdb_chat_model import ChatMindsdb 

40from .callback_handlers import LogCallbackHandler, ContextCaptureCallback 

41from .langfuse_callback_handler import LangfuseCallbackHandler, get_skills 

42from .safe_output_parser import SafeOutputParser 

43from .providers import get_bedrock_chat_model 

44 

45from mindsdb.interfaces.agents.constants import ( 

46 OPEN_AI_CHAT_MODELS, 

47 DEFAULT_AGENT_TIMEOUT_SECONDS, 

48 get_default_agent_type, 

49 DEFAULT_EMBEDDINGS_MODEL_PROVIDER, 

50 DEFAULT_MAX_ITERATIONS, 

51 DEFAULT_MAX_TOKENS, 

52 DEFAULT_TIKTOKEN_MODEL_NAME, 

53 SUPPORTED_PROVIDERS, 

54 ANTHROPIC_CHAT_MODELS, 

55 GOOGLE_GEMINI_CHAT_MODELS, 

56 OLLAMA_CHAT_MODELS, 

57 NVIDIA_NIM_CHAT_MODELS, 

58 USER_COLUMN, 

59 ASSISTANT_COLUMN, 

60 CONTEXT_COLUMN, 

61 TRACE_ID_COLUMN, 

62 DEFAULT_AGENT_SYSTEM_PROMPT, 

63 WRITER_CHAT_MODELS, 

64 MINDSDB_PREFIX, 

65 EXPLICIT_FORMAT_INSTRUCTIONS, 

66) 

67from mindsdb.interfaces.skills.skill_tool import skill_tool, SkillData 

68from langchain_anthropic import ChatAnthropic 

69from langchain_openai import ChatOpenAI 

70 

71from mindsdb.utilities.langfuse import LangfuseClientWrapper 

72 

# Error-message prefixes emitted by Langchain when the LLM reply cannot be
# parsed; used by LangchainAgent._handle_parsing_errors to salvage a response.
_PARSING_ERROR_PREFIXES = [
    "An output parsing error occurred",
    "Could not parse LLM output",
]

# Module-level logger (MindsDB wrapper around stdlib logging).
logger = log.getLogger(__name__)

79 

80 

def get_llm_provider(args: Dict) -> str:
    """Resolve the LLM provider for *args*.

    An explicit 'provider' key always wins; otherwise the provider is inferred
    by looking the model name up in the known per-provider model catalogs.

    Raises:
        ValueError: If no provider is given and the model name is unknown.
    """
    # If provider is explicitly specified, use that
    if "provider" in args:
        return args["provider"]

    # Infer the provider from the known model-name catalogs.
    model_name = args["model_name"]
    catalogs = (
        (ANTHROPIC_CHAT_MODELS, "anthropic"),
        (OPEN_AI_CHAT_MODELS, "openai"),
        (OLLAMA_CHAT_MODELS, "ollama"),
        (NVIDIA_NIM_CHAT_MODELS, "nvidia_nim"),
        (GOOGLE_GEMINI_CHAT_MODELS, "google"),
        (WRITER_CHAT_MODELS, "writer"),
    )
    for known_models, provider_name in catalogs:
        if model_name in known_models:
            return provider_name

    # For vLLM, require explicit provider specification
    raise ValueError("Invalid model name. Please define a supported llm provider")

103 

104 

105def get_embedding_model_provider(args: Dict) -> str: 

106 """Get the embedding model provider from args. 

107 

108 For VLLM, this will use our custom VLLMEmbeddings class from langchain_embedding_handler. 

109 """ 

110 # Check for explicit embedding model provider 

111 if "embedding_model_provider" in args: 

112 provider = args["embedding_model_provider"] 

113 if provider == "vllm": 

114 if not (args.get("openai_api_base") and args.get("model")): 

115 raise ValueError( 

116 "VLLM embeddings configuration error:\n" 

117 "- Missing required parameters: 'openai_api_base' and/or 'model'\n" 

118 "- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'" 

119 ) 

120 logger.info("Using custom VLLMEmbeddings class") 

121 return "vllm" 

122 return provider 

123 

124 # Check if LLM provider is vLLM 

125 llm_provider = args.get("provider", DEFAULT_EMBEDDINGS_MODEL_PROVIDER) 

126 if llm_provider == "vllm": 

127 if not (args.get("openai_api_base") and args.get("model")): 

128 raise ValueError( 

129 "VLLM embeddings configuration error:\n" 

130 "- Missing required parameters: 'openai_api_base' and/or 'model'\n" 

131 "- When using VLLM as LLM provider, you must specify the embeddings server location and model\n" 

132 "- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'" 

133 ) 

134 logger.info("Using custom VLLMEmbeddings class") 

135 return "vllm" 

136 

137 # Default to LLM provider 

138 return llm_provider 

139 

140 

def get_chat_model_params(args: Dict) -> Dict:
    """Build the constructor kwargs for a provider chat model.

    Collects API keys for every supported provider, normalizes the config via
    get_llm_config, and drops unset (None) values.
    """
    model_config = args.copy()
    # Include API keys.
    model_config["api_keys"] = {p: get_api_key(p, model_config, None, strict=False) for p in SUPPORTED_PROVIDERS}
    # Fix: dict.get evaluates its default eagerly, so get_llm_provider(args)
    # was called (and its model-name lookups executed) even when 'provider'
    # was already present in args.
    provider = args["provider"] if "provider" in args else get_llm_provider(args)
    llm_config = get_llm_config(provider, model_config)
    config_dict = llm_config.model_dump(by_alias=True)
    config_dict = {k: v for k, v in config_dict.items() if v is not None}

    # If provider is writer, ensure the API key is passed as 'api_key'
    if args.get("provider") == "writer" and "writer_api_key" in config_dict:
        config_dict["api_key"] = config_dict.pop("writer_api_key")

    return config_dict

154 

155 

def create_chat_model(args: Dict):
    """Instantiate the Langchain chat model for the configured provider.

    Raises:
        ValueError: If args['provider'] is not a supported provider.
    """
    model_kwargs = get_chat_model_params(args)
    provider = args["provider"]

    if provider in ("openai", "vllm"):
        chat_open_ai = ChatOpenAI(**model_kwargs)
        # Some newer GPT models (e.g. gpt-4o when released) don't have token counting support yet.
        # By setting this manually in ChatOpenAI, we count tokens like compatible GPT models.
        try:
            chat_open_ai.get_num_tokens_from_messages([])
        except NotImplementedError:
            chat_open_ai.tiktoken_model_name = DEFAULT_TIKTOKEN_MODEL_NAME
        return chat_open_ai

    if provider == "bedrock":
        # Imported lazily because the bedrock dependency is optional.
        ChatBedrock = get_bedrock_chat_model()
        return ChatBedrock(**model_kwargs)

    # Providers whose model class is constructed directly from the kwargs.
    simple_model_classes = {
        "anthropic": ChatAnthropic,
        "litellm": ChatLiteLLM,
        "ollama": ChatOllama,
        "nvidia_nim": ChatNVIDIA,
        "google": ChatGoogleGenerativeAI,
        "writer": ChatWriter,
        "mindsdb": ChatMindsdb,
    }
    if provider in simple_model_classes:
        return simple_model_classes[provider](**model_kwargs)

    raise ValueError(f"Unknown provider: {args['provider']}")

186 

187 

def prepare_prompts(df, base_template, input_variables, user_column=USER_COLUMN):
    """Render one prompt per DataFrame row from *base_template*.

    Rows where every template variable is NaN fall back to the raw user-column
    text (when present). Returns the rendered prompts together with the
    positional indices of those all-empty rows.
    """
    # Positional indices of rows whose template variables are all missing.
    empty_prompt_ids = np.where(df[input_variables].isna().all(axis=1).values)[0]

    # Prepend the system prompt, then convert {{var}} placeholders to {var}.
    full_template = f"{DEFAULT_AGENT_SYSTEM_PROMPT}\n\n{base_template}"
    full_template = full_template.replace("{{", "{").replace("}}", "}")

    # The template is identical for every row, so build it once.
    template = PromptTemplate(input_variables=input_variables, template=full_template)

    prompts = []
    for row_idx, row in df.iterrows():
        if row_idx not in empty_prompt_ids:
            values = {col: ("" if row[col] is None else row[col]) for col in input_variables}
            prompts.append(template.format(**values))
        elif row.get(user_column):
            # No template variables: use the raw user message as the prompt.
            prompts.append(row[user_column])

    return prompts, empty_prompt_ids

206 

207 

def prepare_callbacks(self, args):
    """Build the callback list for one agent invocation.

    Returns the callbacks plus the ContextCaptureCallback instance so the
    caller can read captured retrieval context after the run.
    """
    capture = ContextCaptureCallback()
    all_callbacks = self._get_agent_callbacks(args) + [capture]
    return all_callbacks, capture

213 

214 

def handle_agent_error(e, error_message=None):
    """Log an agent-execution exception and return a user-facing message."""
    message = f"An error occurred during agent execution: {str(e)}" if error_message is None else error_message
    logger.error(message, exc_info=True)
    return message

220 

221 

def process_chunk(chunk):
    """Recursively convert *chunk* into JSON-serializable primitives.

    Dicts and lists are walked; primitives pass through; anything else is
    stringified.
    """
    if isinstance(chunk, dict):
        return {key: process_chunk(value) for key, value in chunk.items()}
    if isinstance(chunk, list):
        return [process_chunk(element) for element in chunk]
    if isinstance(chunk, (str, int, float, bool, type(None))):
        return chunk
    return str(chunk)

231 

232 

233class LangchainAgent: 

    def __init__(self, agent: db.Agents, model: dict = None, llm_params: dict = None):
        """Wrap a stored agent record in a runnable Langchain agent.

        Args:
            agent: Database record describing the agent (skills, params, model).
            model: Optional MindsDB model metadata dict; read for the prompt
                template when the provider is 'mindsdb'.
            llm_params: Optional LLM parameter overrides (already merged with
                defaults by AgentsController).
        """
        self.agent = agent
        self.model = model

        # Langfuse span covering a single run; populated in get_completion().
        self.run_completion_span: Optional[object] = None
        self.llm: Optional[object] = None
        self.embedding_model: Optional[object] = None

        # Callback handlers are created lazily in _get_agent_callbacks().
        self.log_callback_handler: Optional[object] = None
        self.langfuse_callback_handler: Optional[object] = None  # native langfuse callback handler
        self.mdb_langfuse_callback_handler: Optional[object] = None  # custom (see langfuse_callback_handler.py)

        self.langfuse_client_wrapper = LangfuseClientWrapper()
        self.args = self._initialize_args(llm_params)

        # Back compatibility for old models
        self.provider = self.args.get("provider", get_llm_provider(self.args))

251 

252 def _initialize_args(self, llm_params: dict = None) -> dict: 

253 """ 

254 Initialize the arguments for agent execution. 

255 

256 Takes the parameters passed during execution and sets necessary defaults. 

257 The params are already merged with defaults by AgentsController.get_agent_llm_params. 

258 

259 Args: 

260 llm_params: Parameters for agent execution (already merged with defaults) 

261 

262 Returns: 

263 dict: Final parameters for agent execution 

264 """ 

265 # Use the parameters passed to the method (already merged with defaults by AgentsController) 

266 # No fallback needed as AgentsController.get_agent_llm_params already handles this 

267 args = self.agent.params.copy() 

268 if llm_params: 

269 args.update(llm_params) 

270 

271 # Set model name and provider if given in create agent otherwise use global llm defaults 

272 # AgentsController.get_agent_llm_params 

273 if self.agent.model_name is not None: 

274 args["model_name"] = self.agent.model_name 

275 if self.agent.provider is not None: 

276 args["provider"] = self.agent.provider 

277 

278 args["embedding_model_provider"] = args.get("embedding_model", get_embedding_model_provider(args)) 

279 

280 # agent is using current langchain model 

281 if self.agent.provider == "mindsdb": 

282 args["model_name"] = self.agent.model_name 

283 

284 # get prompt 

285 prompt_template = self.model["problem_definition"].get("using", {}).get("prompt_template") 

286 if prompt_template is not None: 

287 # only update prompt_template if it is set on the model 

288 args["prompt_template"] = prompt_template 

289 

290 # Set default prompt template if not provided 

291 if args.get("prompt_template") is None: 

292 # Default prompt template depends on agent mode 

293 if args.get("mode") == "retrieval": 

294 args["prompt_template"] = DEFAULT_RAG_PROMPT_TEMPLATE 

295 logger.info(f"Using default retrieval prompt template: {DEFAULT_RAG_PROMPT_TEMPLATE[:50]}...") 

296 else: 

297 # Set a default prompt template for non-retrieval mode 

298 default_prompt = "you are an assistant, answer using the tables connected" 

299 args["prompt_template"] = default_prompt 

300 logger.info(f"Using default prompt template: {default_prompt}") 

301 

302 if "prompt_template" in args: 

303 logger.info(f"Using prompt template: {args['prompt_template'][:50]}...") 

304 

305 if "model_name" not in args: 

306 raise ValueError( 

307 "No model name provided for agent. Provide it in the model parameter or in the default model setup." 

308 ) 

309 

310 return args 

311 

312 def get_metadata(self) -> Dict: 

313 return { 

314 "provider": self.provider, 

315 "model_name": self.args["model_name"], 

316 "embedding_model_provider": self.args.get( 

317 "embedding_model_provider", get_embedding_model_provider(self.args) 

318 ), 

319 "skills": get_skills(self.agent), 

320 "user_id": ctx.user_id, 

321 "session_id": ctx.session_id, 

322 "company_id": ctx.company_id, 

323 "user_class": ctx.user_class, 

324 "email_confirmed": ctx.email_confirmed, 

325 } 

326 

327 def get_tags(self) -> List: 

328 return [ 

329 self.provider, 

330 ] 

331 

332 def get_completion(self, messages, stream: bool = False, params: dict | None = None): 

333 # Get metadata and tags to be used in the trace 

334 metadata = self.get_metadata() 

335 tags = self.get_tags() 

336 

337 # Set up trace for the API completion in Langfuse 

338 self.langfuse_client_wrapper.setup_trace( 

339 name="api-completion", 

340 input=messages, 

341 tags=tags, 

342 metadata=metadata, 

343 user_id=ctx.user_id, 

344 session_id=ctx.session_id, 

345 ) 

346 

347 # Set up trace for the run completion in Langfuse 

348 self.run_completion_span = self.langfuse_client_wrapper.start_span(name="run-completion", input=messages) 

349 

350 if stream: 

351 return self._get_completion_stream(messages) 

352 

353 args = {} 

354 args.update(self.args) 

355 args.update(params or {}) 

356 

357 df = pd.DataFrame(messages) 

358 logger.info(f"LangchainAgent.get_completion: Received {len(messages)} messages") 

359 if logger.isEnabledFor(logging.DEBUG): 

360 logger.debug(f"Messages DataFrame shape: {df.shape}") 

361 logger.debug(f"Messages DataFrame columns: {df.columns.tolist()}") 

362 logger.debug(f"Messages DataFrame content: {df.to_dict('records')}") 

363 

364 # Back compatibility for old models 

365 self.provider = args.get("provider", get_llm_provider(args)) 

366 

367 df = df.reset_index(drop=True) 

368 agent = self.create_agent(df) 

369 # Keep conversation history for context - don't nullify previous messages 

370 

371 # Only use the last message as the current prompt, but preserve history for agent memory 

372 response = self.run_agent(df, agent, args) 

373 

374 # End the run completion span and update the metadata with tool usage 

375 self.langfuse_client_wrapper.end_span(span=self.run_completion_span, output=response) 

376 

377 return response 

378 

379 def _get_completion_stream(self, messages: List[dict]) -> Iterable[Dict]: 

380 """Gets a completion as a stream of chunks from given messages. 

381 

382 Args: 

383 messages (List[dict]): Messages to get completion chunks for 

384 

385 Returns: 

386 chunks (Iterable[object]): Completion chunks 

387 """ 

388 

389 args = self.args 

390 

391 df = pd.DataFrame(messages) 

392 logger.info(f"LangchainAgent._get_completion_stream: Received {len(messages)} messages") 

393 # Check if we have the expected columns for conversation history 

394 if "question" in df.columns and "answer" in df.columns: 

395 logger.debug("DataFrame has question/answer columns for conversation history") 

396 else: 

397 logger.warning("DataFrame missing question/answer columns! Available columns: {df.columns.tolist()}") 

398 

399 self.embedding_model_provider = args.get("embedding_model_provider", get_embedding_model_provider(args)) 

400 # Back compatibility for old models 

401 self.provider = args.get("provider", get_llm_provider(args)) 

402 

403 df = df.reset_index(drop=True) 

404 agent = self.create_agent(df) 

405 # Keep conversation history for context - don't nullify previous messages 

406 # Only use the last message as the current prompt, but preserve history for agent memory 

407 return self.stream_agent(df, agent, args) 

408 

409 def create_agent(self, df: pd.DataFrame) -> AgentExecutor: 

410 # Set up tools. 

411 

412 args = self.args 

413 

414 llm = create_chat_model(args) 

415 self.llm = llm 

416 

417 # Don't set embedding model for retrieval mode - let the knowledge base handle it 

418 if args.get("mode") == "retrieval": 

419 self.args.pop("mode") 

420 

421 tools = self._langchain_tools_from_skills(llm) 

422 

423 # Prefer prediction prompt template over original if provided. 

424 prompt_template = args["prompt_template"] 

425 

426 # Modern LangChain approach: Use memory but populate it correctly 

427 # Create memory and populate with conversation history 

428 memory = ConversationSummaryBufferMemory( 

429 llm=llm, 

430 input_key="input", 

431 output_key="output", 

432 max_token_limit=args.get("max_tokens", DEFAULT_MAX_TOKENS), 

433 memory_key="chat_history", 

434 ) 

435 

436 # Add system message first 

437 memory.chat_memory.messages.insert(0, SystemMessage(content=prompt_template)) 

438 

439 user_column = args.get("user_column", USER_COLUMN) 

440 assistant_column = args.get("assistant_column", ASSISTANT_COLUMN) 

441 

442 logger.info(f"Processing conversation history: {len(df)} total messages, {len(df[:-1])} history messages") 

443 logger.debug(f"User column: {user_column}, Assistant column: {assistant_column}") 

444 

445 # Process history messages (all except the last one which is current message) 

446 history_df = df[:-1] 

447 if len(history_df) == 0: 

448 logger.debug("No history rows to process - this is normal for first message") 

449 

450 history_count = 0 

451 for i, row in enumerate(history_df.to_dict("records")): 

452 question = row.get(user_column) 

453 answer = row.get(assistant_column) 

454 logger.debug(f"Converting history row {i}: question='{question}', answer='{answer}'") 

455 

456 # Add messages directly to memory's chat_memory.messages list (modern approach) 

457 if isinstance(question, str) and len(question) > 0: 

458 memory.chat_memory.messages.append(HumanMessage(content=question)) 

459 history_count += 1 

460 logger.debug(f"Added HumanMessage to memory: {question}") 

461 if isinstance(answer, str) and len(answer) > 0: 

462 memory.chat_memory.messages.append(AIMessage(content=answer)) 

463 history_count += 1 

464 logger.debug(f"Added AIMessage to memory: {answer}") 

465 

466 logger.info(f"Built conversation history with {history_count} history messages + system message") 

467 logger.debug(f"Final memory messages count: {len(memory.chat_memory.messages)}") 

468 

469 # Store memory for agent use 

470 self._conversation_memory = memory 

471 default_agent = get_default_agent_type() 

472 agent_type = args.get("agent_type", default_agent) 

473 agent_executor = initialize_agent( 

474 tools, 

475 llm, 

476 agent=agent_type, 

477 # Use custom output parser to handle flaky LLMs that don't ALWAYS conform to output format. 

478 agent_kwargs={ 

479 "output_parser": SafeOutputParser(), 

480 "prefix": MINDSDB_PREFIX, # Override default "Assistant is a large language model..." text 

481 "format_instructions": EXPLICIT_FORMAT_INSTRUCTIONS, # More explicit tool calling instructions 

482 "ai_prefix": "AI", 

483 }, 

484 # Calls the agent's LLM Chain one final time to generate a final answer based on the previous steps 

485 early_stopping_method="generate", 

486 handle_parsing_errors=self._handle_parsing_errors, 

487 # Timeout per agent invocation. 

488 max_execution_time=args.get( 

489 "timeout_seconds", 

490 args.get("timeout_seconds", DEFAULT_AGENT_TIMEOUT_SECONDS), 

491 ), 

492 max_iterations=args.get("max_iterations", args.get("max_iterations", DEFAULT_MAX_ITERATIONS)), 

493 memory=memory, 

494 verbose=args.get("verbose", args.get("verbose", False)), 

495 ) 

496 return agent_executor 

497 

498 def _langchain_tools_from_skills(self, llm): 

499 # Makes Langchain compatible tools from a skill 

500 skills_data = [ 

501 SkillData( 

502 name=rel.skill.name, 

503 type=rel.skill.type, 

504 params=rel.skill.params, 

505 project_id=rel.skill.project_id, 

506 agent_tables_list=(rel.parameters or {}).get("tables"), 

507 ) 

508 for rel in self.agent.skills_relationships 

509 ] 

510 

511 tools_groups = skill_tool.get_tools_from_skills(skills_data, llm, self.embedding_model) 

512 

513 all_tools = [] 

514 for skill_type, tools in tools_groups.items(): 

515 for tool in tools: 

516 if isinstance(tool, dict): 

517 tool = Tool( 

518 name=tool["name"], 

519 func=tool["func"], 

520 description=tool["description"], 

521 ) 

522 all_tools.append(tool) 

523 return all_tools 

524 

    def _get_agent_callbacks(self, args: Dict) -> List:
        """Assemble (and lazily cache) the callback handlers for agent runs.

        Handlers are created once per LangchainAgent instance and reused on
        subsequent calls: a log handler, the native Langfuse Langchain handler,
        and the custom MindsDB Langfuse handler. A Langfuse trace is created
        first if none exists yet. Returns a fresh list on every call.
        """
        all_callbacks = []

        if self.log_callback_handler is None:
            self.log_callback_handler = LogCallbackHandler(logger, verbose=args.get("verbose", True))

        all_callbacks.append(self.log_callback_handler)

        # A trace must exist before the Langfuse handlers can attach to it.
        if self.langfuse_client_wrapper.trace is None:
            # Get metadata and tags to be used in the trace
            metadata = self.get_metadata()
            tags = self.get_tags()

            trace_name = "NativeTrace-MindsDB-AgentExecutor"

            # Set up trace for the API completion in Langfuse
            self.langfuse_client_wrapper.setup_trace(
                name=trace_name,
                tags=tags,
                metadata=metadata,
                user_id=ctx.user_id,
                session_id=ctx.session_id,
            )

        if self.langfuse_callback_handler is None:
            self.langfuse_callback_handler = self.langfuse_client_wrapper.get_langchain_handler()

        # custom tracer
        if self.mdb_langfuse_callback_handler is None:
            trace_id = self.langfuse_client_wrapper.get_trace_id()

            span_id = None
            if self.run_completion_span is not None:
                span_id = self.run_completion_span.id

            # Prefer an explicit observation id, then the run span id, then a fresh uuid.
            observation_id = args.get("observation_id", span_id or uuid4().hex)

            self.mdb_langfuse_callback_handler = LangfuseCallbackHandler(
                langfuse=self.langfuse_client_wrapper.client,
                trace_id=trace_id,
                observation_id=observation_id,
            )

        # obs: we may want to unify these; native langfuse handler provides details as a tree on a sub-step of the overarching custom one # noqa
        if self.langfuse_callback_handler is not None:
            all_callbacks.append(self.langfuse_callback_handler)

        if self.mdb_langfuse_callback_handler:
            all_callbacks.append(self.mdb_langfuse_callback_handler)

        return all_callbacks

576 

577 def _handle_parsing_errors(self, error: Exception) -> str: 

578 response = str(error) 

579 for p in _PARSING_ERROR_PREFIXES: 

580 if response.startswith(p): 

581 # As a somewhat dirty workaround, we accept the output formatted incorrectly and use it as a response. 

582 # 

583 # Ideally, in the future, we would write a parser that is more robust and flexible than the one Langchain uses. 

584 # Response is wrapped in `` 

585 logger.info("Handling parsing error, salvaging response...") 

586 response_output = response.split("`") 

587 if len(response_output) >= 2: 

588 response = response_output[-2] 

589 

590 # Wrap response in Langchain conversational react format. 

591 langchain_react_formatted_response = f"""Thought: Do I need to use a tool? No 

592AI: {response}""" 

593 return langchain_react_formatted_response 

594 return f"Agent failed with error:\n{str(error)}..." 

595 

596 def run_agent(self, df: pd.DataFrame, agent: AgentExecutor, args: Dict) -> pd.DataFrame: 

597 base_template = args.get("prompt_template", args["prompt_template"]) 

598 return_context = args.get("return_context", True) 

599 input_variables = re.findall(r"{{(.*?)}}", base_template) 

600 

601 prompts, empty_prompt_ids = prepare_prompts( 

602 df, base_template, input_variables, args.get("user_column", USER_COLUMN) 

603 ) 

604 

605 def _invoke_agent_executor_with_prompt(agent_executor, prompt): 

606 if not prompt: 

607 return {CONTEXT_COLUMN: [], ASSISTANT_COLUMN: ""} 

608 try: 

609 callbacks, context_callback = prepare_callbacks(self, args) 

610 

611 # Modern LangChain approach: Include conversation history + current message 

612 if hasattr(self, "_conversation_messages") and self._conversation_messages: 

613 # Add current user message to conversation history 

614 full_messages = self._conversation_messages + [HumanMessage(content=prompt)] 

615 logger.critical(f"🔍 INVOKING AGENT with {len(full_messages)} messages (including history)") 

616 logger.debug( 

617 f"Full conversation messages: {[type(msg).__name__ + ': ' + msg.content[:100] + '...' for msg in full_messages]}" 

618 ) 

619 

620 # For agents, we need to pass the input in the expected format 

621 # The agent expects 'input' key with the current question, but conversation history should be in memory 

622 result = agent_executor.invoke({"input": prompt}, config={"callbacks": callbacks}) 

623 else: 

624 logger.warning("No conversation messages found - using simple prompt") 

625 result = agent_executor.invoke({"input": prompt}, config={"callbacks": callbacks}) 

626 captured_context = context_callback.get_contexts() 

627 output = result["output"] if isinstance(result, dict) and "output" in result else str(result) 

628 return {CONTEXT_COLUMN: captured_context, ASSISTANT_COLUMN: output} 

629 except Exception as e: 

630 error_message = str(e) 

631 # Special handling for API key errors 

632 if "API key" in error_message and ("not found" in error_message or "missing" in error_message): 

633 # Format API key error more clearly 

634 logger.error(f"API Key Error: {error_message}") 

635 error_message = f"API Key Error: {error_message}" 

636 return { 

637 CONTEXT_COLUMN: [], 

638 ASSISTANT_COLUMN: handle_agent_error(e, error_message), 

639 } 

640 

641 completions = [] 

642 contexts = [] 

643 

644 max_workers = args.get("max_workers", None) 

645 agent_timeout_seconds = args.get("timeout", DEFAULT_AGENT_TIMEOUT_SECONDS) 

646 

647 with ContextThreadPoolExecutor(max_workers=max_workers) as executor: 

648 # Only process the last prompt (current question), not all prompts 

649 # The previous prompts are conversation history and should only be used for context 

650 if prompts: 

651 current_prompt = prompts[-1] # Last prompt is the current question 

652 futures = [executor.submit(_invoke_agent_executor_with_prompt, agent, current_prompt)] 

653 else: 

654 logger.error("No prompts found to process") 

655 futures = [] 

656 try: 

657 for future in as_completed(futures, timeout=agent_timeout_seconds): 

658 result = future.result() 

659 if result is None: 

660 result = { 

661 CONTEXT_COLUMN: [], 

662 ASSISTANT_COLUMN: "No response generated", 

663 } 

664 

665 completions.append(result[ASSISTANT_COLUMN]) 

666 contexts.append(result[CONTEXT_COLUMN]) 

667 except TimeoutError: 

668 timeout_message = ( 

669 f"I'm sorry! I couldn't generate a response within the allotted time ({agent_timeout_seconds} seconds). " 

670 "If you need more time for processing, you can adjust the timeout settings. " 

671 "Please refer to the documentation for instructions on how to change the timeout value. " 

672 "Feel free to try your request again." 

673 ) 

674 logger.warning(f"Agent execution timed out after {agent_timeout_seconds} seconds") 

675 for _ in range(len(futures) - len(completions)): 

676 completions.append(timeout_message) 

677 contexts.append([]) 

678 

679 # Add null completion for empty prompts 

680 for i in sorted(empty_prompt_ids)[:-1]: 

681 completions.insert(i, None) 

682 contexts.insert(i, []) 

683 

684 # Create DataFrame with completions and context if required 

685 pred_df = pd.DataFrame( 

686 { 

687 ASSISTANT_COLUMN: completions, 

688 CONTEXT_COLUMN: [json.dumps(ctx) for ctx in contexts], # Serialize context to JSON string 

689 TRACE_ID_COLUMN: self.langfuse_client_wrapper.get_trace_id(), 

690 } 

691 ) 

692 

693 if not return_context: 

694 pred_df = pred_df.drop(columns=[CONTEXT_COLUMN]) 

695 

696 return pred_df 

697 

698 def add_chunk_metadata(self, chunk: Dict) -> Dict: 

699 logger.debug(f"Adding metadata to chunk: {chunk}") 

700 logger.debug(f"Trace ID: {self.langfuse_client_wrapper.get_trace_id()}") 

701 chunk["trace_id"] = self.langfuse_client_wrapper.get_trace_id() 

702 return chunk 

703 

    def _stream_agent_executor(
        self,
        agent_executor: AgentExecutor,
        prompt: str,
        callbacks: List[BaseCallbackHandler],
    ):
        """Stream executor output from a worker thread, yielding processed chunks.

        The Langchain stream is consumed on a daemon thread so that event
        chunks dispatched via EventDispatchCallbackHandler can be interleaved
        with executor output instead of being blocked behind it. Worker errors
        are converted into {"type": "error", ...} chunks rather than raised.
        """
        chunk_queue = queue.Queue()
        # Add event dispatch callback handler only to streaming completions.
        event_dispatch_callback_handler = EventDispatchCallbackHandler(chunk_queue)
        callbacks.append(event_dispatch_callback_handler)
        stream_iterator = agent_executor.stream(prompt, config={"callbacks": callbacks})

        # Set by the worker when the stream is exhausted (or failed).
        agent_executor_finished_event = threading.Event()

        def stream_worker(context: dict):
            # Runs on the daemon thread; *context* is the parent thread's
            # MindsDB context dump, restored so context lookups keep working.
            try:
                ctx.load(context)
                for chunk in stream_iterator:
                    chunk_queue.put(chunk)
            except TimeoutError as e:
                error_message = f"Request timed out: The agent took too long to respond. {str(e)}"
                logger.error(f"Timeout error during streaming: {error_message}", exc_info=True)
                error_chunk = {
                    "type": "error",
                    "content": handle_agent_error(e, error_message),
                    "error_type": "timeout",
                }
                chunk_queue.put(error_chunk)
            except ConnectionError as e:
                error_message = f"Connection error: Failed to connect to the service. {str(e)}"
                logger.error(f"Connection error during streaming: {error_message}", exc_info=True)
                error_chunk = {
                    "type": "error",
                    "content": handle_agent_error(e, error_message),
                    "error_type": "connection",
                }
                chunk_queue.put(error_chunk)
            except Exception as e:
                error_message = str(e)

                # Special handling for specific error types
                # Note: TimeoutError and ConnectionError are already handled by specific exception handlers above
                if "API key" in error_message and ("not found" in error_message or "missing" in error_message):
                    logger.error(f"API Key Error: {error_message}")
                    error_message = f"API Key Error: {error_message}"
                    error_type = "authentication"
                elif "404" in error_message or "not found" in error_message.lower():
                    logger.error(f"Model Error: {error_message}")
                    error_type = "not_found"
                elif "rate limit" in error_message.lower() or "429" in error_message:
                    logger.error(f"Rate Limit Error: {error_message}")
                    error_message = f"Rate limit exceeded: {error_message}"
                    error_type = "rate_limit"
                else:
                    logger.error(f"LLM chain encountered an error during streaming: {error_message}", exc_info=True)
                    error_type = "general"

                # Guard against exceptions that stringify to nothing.
                if not error_message or not error_message.strip():
                    error_message = f"An unknown error occurred during streaming: {type(e).__name__}"

                error_chunk = {
                    "type": "error",
                    "content": handle_agent_error(e, error_message),
                    "error_type": error_type,
                }
                chunk_queue.put(error_chunk)
            finally:
                # Wrap in try/finally to always set the thread event even if there's an exception.
                agent_executor_finished_event.set()

        # Enqueue Langchain agent streaming chunks in a separate thread to not block event chunks.
        executor_stream_thread = threading.Thread(
            target=stream_worker,
            daemon=True,
            args=(ctx.dump(),),
            name="LangchainAgent.stream_worker",
        )
        executor_stream_thread.start()

        # Drain the queue until the worker signals completion; the short
        # timeout lets the finished event be re-checked between chunks.
        while not agent_executor_finished_event.is_set():
            try:
                chunk = chunk_queue.get(block=True, timeout=AGENT_CHUNK_POLLING_INTERVAL_SECONDS)
            except queue.Empty:
                continue
            logger.debug(f"Processing streaming chunk {chunk}")
            processed_chunk = self.process_chunk(chunk)
            logger.info(f"Processed chunk: {processed_chunk}")
            yield self.add_chunk_metadata(processed_chunk)
            chunk_queue.task_done()

793 

794 def stream_agent(self, df: pd.DataFrame, agent_executor: AgentExecutor, args: Dict) -> Iterable[Dict]: 

795 base_template = args.get("prompt_template", args["prompt_template"]) 

796 input_variables = re.findall(r"{{(.*?)}}", base_template) 

797 return_context = args.get("return_context", True) 

798 

799 prompts, _ = prepare_prompts(df, base_template, input_variables, args.get("user_column", USER_COLUMN)) 

800 

801 callbacks, context_callback = prepare_callbacks(self, args) 

802 

803 # Use last prompt (current question) instead of first prompt (history) 

804 current_prompt = prompts[-1] if prompts else "" 

805 yield self.add_chunk_metadata({"type": "start", "prompt": current_prompt}) 

806 

807 if not hasattr(agent_executor, "stream") or not callable(agent_executor.stream): 

808 raise AttributeError("The agent_executor does not have a 'stream' method") 

809 

810 stream_iterator = self._stream_agent_executor(agent_executor, current_prompt, callbacks) 

811 for chunk in stream_iterator: 

812 yield chunk 

813 

814 if return_context: 

815 # Yield context if required 

816 captured_context = context_callback.get_contexts() 

817 if captured_context: 

818 yield {"type": "context", "content": captured_context} 

819 

820 if self.log_callback_handler.generated_sql: 

821 # Yield generated SQL if available 

822 yield self.add_chunk_metadata({"type": "sql", "content": self.log_callback_handler.generated_sql}) 

823 

824 # End the run completion span and update the metadata with tool usage 

825 self.langfuse_client_wrapper.end_span_stream(span=self.run_completion_span) 

826 

827 @staticmethod 

828 def process_chunk(chunk): 

829 if isinstance(chunk, dict): 

830 return {k: LangchainAgent.process_chunk(v) for k, v in chunk.items()} 

831 if isinstance(chunk, list): 

832 return [LangchainAgent.process_chunk(item) for item in chunk] 

833 if isinstance(chunk, AgentAction): 

834 # Format agent actions properly for streaming. 

835 return { 

836 "tool": LangchainAgent.process_chunk(chunk.tool), 

837 "tool_input": LangchainAgent.process_chunk(chunk.tool_input), 

838 "log": LangchainAgent.process_chunk(chunk.log), 

839 } 

840 if isinstance(chunk, AgentStep): 

841 # Format agent steps properly for streaming. 

842 return { 

843 "action": LangchainAgent.process_chunk(chunk.action), 

844 "observation": LangchainAgent.process_chunk(chunk.observation) if chunk.observation else "", 

845 } 

846 if issubclass(chunk.__class__, BaseMessage): 

847 # Extract content from message subclasses properly for streaming. 

848 return {"content": chunk.content} 

849 if isinstance(chunk, (str, int, float, bool, type(None))): 

850 return chunk 

851 return str(chunk)