Coverage for mindsdb / interfaces / agents / langchain_agent.py: 11%
451 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import json
2from concurrent.futures import as_completed, TimeoutError
3from typing import Dict, Iterable, List, Optional
4from uuid import uuid4
5import queue
6import re
7import threading
8import numpy as np
9import pandas as pd
10import logging
12from langchain.agents import AgentExecutor
13from langchain.agents.initialize import initialize_agent
14from langchain.chains.conversation.memory import ConversationSummaryBufferMemory
15from langchain_community.chat_models import ChatLiteLLM, ChatOllama
16from langchain_writer import ChatWriter
17from langchain_google_genai import ChatGoogleGenerativeAI
18from langchain_core.agents import AgentAction, AgentStep
19from langchain_core.callbacks.base import BaseCallbackHandler
20from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
22from langchain_nvidia_ai_endpoints import ChatNVIDIA
23from langchain_core.messages.base import BaseMessage
24from langchain_core.prompts import PromptTemplate
25from langchain_core.tools import Tool
27from mindsdb.integrations.libs.llm.utils import get_llm_config
28from mindsdb.integrations.utilities.handler_utils import get_api_key
29from mindsdb.integrations.utilities.rag.settings import DEFAULT_RAG_PROMPT_TEMPLATE
30from mindsdb.interfaces.agents.event_dispatch_callback_handler import (
31 EventDispatchCallbackHandler,
32)
33from mindsdb.interfaces.agents.constants import AGENT_CHUNK_POLLING_INTERVAL_SECONDS
34from mindsdb.utilities import log
35from mindsdb.utilities.context_executor import ContextThreadPoolExecutor
36from mindsdb.interfaces.storage import db
37from mindsdb.utilities.context import context as ctx
39from .mindsdb_chat_model import ChatMindsdb
40from .callback_handlers import LogCallbackHandler, ContextCaptureCallback
41from .langfuse_callback_handler import LangfuseCallbackHandler, get_skills
42from .safe_output_parser import SafeOutputParser
43from .providers import get_bedrock_chat_model
45from mindsdb.interfaces.agents.constants import (
46 OPEN_AI_CHAT_MODELS,
47 DEFAULT_AGENT_TIMEOUT_SECONDS,
48 get_default_agent_type,
49 DEFAULT_EMBEDDINGS_MODEL_PROVIDER,
50 DEFAULT_MAX_ITERATIONS,
51 DEFAULT_MAX_TOKENS,
52 DEFAULT_TIKTOKEN_MODEL_NAME,
53 SUPPORTED_PROVIDERS,
54 ANTHROPIC_CHAT_MODELS,
55 GOOGLE_GEMINI_CHAT_MODELS,
56 OLLAMA_CHAT_MODELS,
57 NVIDIA_NIM_CHAT_MODELS,
58 USER_COLUMN,
59 ASSISTANT_COLUMN,
60 CONTEXT_COLUMN,
61 TRACE_ID_COLUMN,
62 DEFAULT_AGENT_SYSTEM_PROMPT,
63 WRITER_CHAT_MODELS,
64 MINDSDB_PREFIX,
65 EXPLICIT_FORMAT_INSTRUCTIONS,
66)
67from mindsdb.interfaces.skills.skill_tool import skill_tool, SkillData
68from langchain_anthropic import ChatAnthropic
69from langchain_openai import ChatOpenAI
71from mindsdb.utilities.langfuse import LangfuseClientWrapper
# Prefixes of Langchain output-parsing error messages that can be salvaged:
# when a parsing error starts with one of these, the raw LLM output is
# embedded in the error text and is recovered by _handle_parsing_errors.
_PARSING_ERROR_PREFIXES = [
    "An output parsing error occurred",
    "Could not parse LLM output",
]

logger = log.getLogger(__name__)
def get_llm_provider(args: Dict) -> str:
    """Resolve the LLM provider for the given agent args.

    An explicit ``provider`` entry always wins; otherwise the provider is
    inferred from ``model_name`` by checking the known per-provider model
    lists.

    Args:
        args: Agent/model parameters; may contain ``provider`` and
            ``model_name``.

    Returns:
        Provider identifier, e.g. ``"openai"`` or ``"anthropic"``.

    Raises:
        ValueError: If the provider cannot be inferred (unknown or missing
            model name). Providers such as vLLM must be specified explicitly.
    """
    # If provider is explicitly specified, use that.
    if "provider" in args:
        return args["provider"]

    # Infer from well-known model names. Use .get() so a missing model name
    # falls through to the clear ValueError below instead of raising KeyError.
    model_name = args.get("model_name")
    if model_name in ANTHROPIC_CHAT_MODELS:
        return "anthropic"
    if model_name in OPEN_AI_CHAT_MODELS:
        return "openai"
    if model_name in OLLAMA_CHAT_MODELS:
        return "ollama"
    if model_name in NVIDIA_NIM_CHAT_MODELS:
        return "nvidia_nim"
    if model_name in GOOGLE_GEMINI_CHAT_MODELS:
        return "google"
    # Check for writer models
    if model_name in WRITER_CHAT_MODELS:
        return "writer"

    # For vLLM (and any unrecognized model), require explicit provider specification.
    raise ValueError("Invalid model name. Please define a supported llm provider")
def get_embedding_model_provider(args: Dict) -> str:
    """Get the embedding model provider from args.

    For VLLM, this will use our custom VLLMEmbeddings class from
    langchain_embedding_handler; a vLLM choice is validated to ensure the
    embeddings server location and model are configured.
    """
    # An explicitly configured embedding provider takes precedence.
    if "embedding_model_provider" in args:
        provider = args["embedding_model_provider"]
        if provider != "vllm":
            return provider
        if args.get("openai_api_base") and args.get("model"):
            logger.info("Using custom VLLMEmbeddings class")
            return "vllm"
        raise ValueError(
            "VLLM embeddings configuration error:\n"
            "- Missing required parameters: 'openai_api_base' and/or 'model'\n"
            "- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'"
        )

    # Otherwise fall back to the LLM provider (defaulting when unset).
    llm_provider = args.get("provider", DEFAULT_EMBEDDINGS_MODEL_PROVIDER)
    if llm_provider != "vllm":
        # Default to LLM provider.
        return llm_provider

    if not (args.get("openai_api_base") and args.get("model")):
        raise ValueError(
            "VLLM embeddings configuration error:\n"
            "- Missing required parameters: 'openai_api_base' and/or 'model'\n"
            "- When using VLLM as LLM provider, you must specify the embeddings server location and model\n"
            "- Example: openai_api_base='http://localhost:8003/v1', model='your-model-name'"
        )
    logger.info("Using custom VLLMEmbeddings class")
    return "vllm"
def get_chat_model_params(args: Dict) -> Dict:
    """Build the kwargs used to construct a Langchain chat model.

    Collects API keys for every supported provider (non-strict, so missing
    keys become None), runs the args through the provider-specific LLM
    config, drops unset values, and normalizes the writer API key name.
    """
    model_config = dict(args)
    # Include API keys.
    api_keys = {}
    for provider_name in SUPPORTED_PROVIDERS:
        api_keys[provider_name] = get_api_key(provider_name, model_config, None, strict=False)
    model_config["api_keys"] = api_keys

    llm_config = get_llm_config(args.get("provider", get_llm_provider(args)), model_config)
    dumped = llm_config.model_dump(by_alias=True)
    # Drop keys whose value is unset.
    config_dict = {key: value for key, value in dumped.items() if value is not None}

    # If provider is writer, ensure the API key is passed as 'api_key'.
    if args.get("provider") == "writer" and "writer_api_key" in config_dict:
        config_dict["api_key"] = config_dict.pop("writer_api_key")

    return config_dict
def create_chat_model(args: Dict):
    """Instantiate the Langchain chat model matching ``args['provider']``.

    Raises:
        ValueError: If the provider is not recognized.
    """
    model_kwargs = get_chat_model_params(args)
    provider = args["provider"]

    if provider in ("openai", "vllm"):
        chat_open_ai = ChatOpenAI(**model_kwargs)
        # Some newer GPT models (e.g. gpt-4o when released) don't have token counting support yet.
        # By setting this manually in ChatOpenAI, we count tokens like compatible GPT models.
        try:
            chat_open_ai.get_num_tokens_from_messages([])
        except NotImplementedError:
            chat_open_ai.tiktoken_model_name = DEFAULT_TIKTOKEN_MODEL_NAME
        return chat_open_ai

    if provider == "bedrock":
        # Resolved lazily because the bedrock dependency is optional.
        ChatBedrock = get_bedrock_chat_model()
        return ChatBedrock(**model_kwargs)

    # Providers whose chat model is a plain constructor call on the kwargs.
    simple_constructors = {
        "anthropic": ChatAnthropic,
        "litellm": ChatLiteLLM,
        "ollama": ChatOllama,
        "nvidia_nim": ChatNVIDIA,
        "google": ChatGoogleGenerativeAI,
        "writer": ChatWriter,
        "mindsdb": ChatMindsdb,
    }
    if provider in simple_constructors:
        return simple_constructors[provider](**model_kwargs)

    raise ValueError(f"Unknown provider: {args['provider']}")
def prepare_prompts(df, base_template, input_variables, user_column=USER_COLUMN):
    """Build one formatted prompt per DataFrame row.

    Rows where every input variable is missing fall back to the raw user
    message (if present); their positional indices are returned so the caller
    can insert null completions for them.

    Args:
        df: Messages DataFrame, one row per message.
        base_template: Prompt template using ``{{var}}`` placeholders.
        input_variables: Column names substituted into the template.
        user_column: Column holding the raw user message.

    Returns:
        Tuple of (formatted prompt strings, positional indices of rows with no
        input-variable values).
    """
    # Positional indices of rows where every input variable is missing.
    empty_prompt_ids = np.where(df[input_variables].isna().all(axis=1).values)[0]

    # Combine system prompt with user-provided template
    base_template = f"{DEFAULT_AGENT_SYSTEM_PROMPT}\n\n{base_template}"

    # Convert {{var}} placeholders into the {var} form PromptTemplate expects.
    base_template = base_template.replace("{{", "{").replace("}}", "}")
    prompts = []

    # Iterate positionally: empty_prompt_ids holds np.where() *positions*, so
    # comparing them against index labels would be wrong for any DataFrame
    # whose index is not a clean 0..n-1 RangeIndex.
    empty_ids = set(empty_prompt_ids)
    for pos, (_, row) in enumerate(df.iterrows()):
        if pos not in empty_ids:
            prompt = PromptTemplate(input_variables=input_variables, template=base_template)
            # Treat both None and NaN as empty so the literal string "nan"
            # never leaks into a formatted prompt.
            kwargs = {col: "" if row[col] is None or pd.isna(row[col]) else row[col] for col in input_variables}
            prompts.append(prompt.format(**kwargs))
        elif row.get(user_column):
            prompts.append(row[user_column])

    return prompts, empty_prompt_ids
def prepare_callbacks(self, args):
    """Assemble the agent's callbacks plus a context-capture callback.

    Returns a tuple of (all callbacks, the context-capture callback) so the
    caller can later read the retrieval contexts captured during execution.
    """
    context_callback = ContextCaptureCallback()
    callbacks = [*self._get_agent_callbacks(args), context_callback]
    return callbacks, context_callback
def handle_agent_error(e, error_message=None):
    """Log an agent-execution error and return the message to surface to the user.

    Args:
        e: The exception that occurred.
        error_message: Optional pre-formatted message; when None a generic
            wrapper around ``str(e)`` is used.

    Returns:
        The error message string.
    """
    message = error_message if error_message is not None else f"An error occurred during agent execution: {str(e)}"
    logger.error(message, exc_info=True)
    return message
def process_chunk(chunk):
    """Recursively convert a streaming chunk into JSON-serializable primitives.

    Dicts and lists are processed element-wise, primitives pass through
    unchanged, and anything else is stringified.
    """
    if isinstance(chunk, dict):
        return {key: process_chunk(value) for key, value in chunk.items()}
    if isinstance(chunk, list):
        return [process_chunk(element) for element in chunk]
    if isinstance(chunk, (str, int, float, bool, type(None))):
        return chunk
    return str(chunk)
class LangchainAgent:
    """Runs a MindsDB agent through a Langchain ``AgentExecutor``.

    Responsible for building the chat model, converting skills into Langchain
    tools, populating conversation memory from the message history, and
    executing the agent either synchronously (``get_completion``) or as a
    stream of chunks (``get_completion`` with ``stream=True``). Langfuse
    tracing spans are opened/closed around each run.
    """

    def __init__(self, agent: db.Agents, model: dict = None, llm_params: dict = None):
        """
        Args:
            agent: The agent DB record (skills, provider, model name, params).
            model: Model record dict for 'mindsdb'-provider agents; its
                problem_definition may carry a prompt_template.
            llm_params: Execution-time LLM params, already merged with
                defaults by AgentsController.
        """
        self.agent = agent
        self.model = model

        self.run_completion_span: Optional[object] = None
        self.llm: Optional[object] = None
        self.embedding_model: Optional[object] = None

        self.log_callback_handler: Optional[object] = None
        self.langfuse_callback_handler: Optional[object] = None  # native langfuse callback handler
        self.mdb_langfuse_callback_handler: Optional[object] = None  # custom (see langfuse_callback_handler.py)

        self.langfuse_client_wrapper = LangfuseClientWrapper()
        self.args = self._initialize_args(llm_params)

        # Back compatibility for old models
        self.provider = self.args.get("provider", get_llm_provider(self.args))

    def _initialize_args(self, llm_params: dict = None) -> dict:
        """
        Initialize the arguments for agent execution.

        Takes the parameters passed during execution and sets necessary defaults.
        The params are already merged with defaults by AgentsController.get_agent_llm_params.

        Args:
            llm_params: Parameters for agent execution (already merged with defaults)

        Returns:
            dict: Final parameters for agent execution

        Raises:
            ValueError: If no model name can be determined.
        """
        # Use the parameters passed to the method (already merged with defaults by AgentsController)
        # No fallback needed as AgentsController.get_agent_llm_params already handles this
        args = self.agent.params.copy()
        if llm_params:
            args.update(llm_params)

        # Set model name and provider if given in create agent otherwise use global llm defaults
        # AgentsController.get_agent_llm_params
        if self.agent.model_name is not None:
            args["model_name"] = self.agent.model_name
        if self.agent.provider is not None:
            args["provider"] = self.agent.provider

        # NOTE(review): this reads the "embedding_model" key but stores under
        # "embedding_model_provider" — looks like it may have been meant to read
        # "embedding_model_provider"; confirm before changing.
        args["embedding_model_provider"] = args.get("embedding_model", get_embedding_model_provider(args))

        # agent is using current langchain model
        if self.agent.provider == "mindsdb":
            args["model_name"] = self.agent.model_name

            # get prompt
            prompt_template = self.model["problem_definition"].get("using", {}).get("prompt_template")
            if prompt_template is not None:
                # only update prompt_template if it is set on the model
                args["prompt_template"] = prompt_template

        # Set default prompt template if not provided
        if args.get("prompt_template") is None:
            # Default prompt template depends on agent mode
            if args.get("mode") == "retrieval":
                args["prompt_template"] = DEFAULT_RAG_PROMPT_TEMPLATE
                logger.info(f"Using default retrieval prompt template: {DEFAULT_RAG_PROMPT_TEMPLATE[:50]}...")
            else:
                # Set a default prompt template for non-retrieval mode
                default_prompt = "you are an assistant, answer using the tables connected"
                args["prompt_template"] = default_prompt
                logger.info(f"Using default prompt template: {default_prompt}")

        if "prompt_template" in args:
            logger.info(f"Using prompt template: {args['prompt_template'][:50]}...")

        if "model_name" not in args:
            raise ValueError(
                "No model name provided for agent. Provide it in the model parameter or in the default model setup."
            )

        return args

    def get_metadata(self) -> Dict:
        """Return trace metadata (provider/model info plus request context)."""
        return {
            "provider": self.provider,
            "model_name": self.args["model_name"],
            "embedding_model_provider": self.args.get(
                "embedding_model_provider", get_embedding_model_provider(self.args)
            ),
            "skills": get_skills(self.agent),
            "user_id": ctx.user_id,
            "session_id": ctx.session_id,
            "company_id": ctx.company_id,
            "user_class": ctx.user_class,
            "email_confirmed": ctx.email_confirmed,
        }

    def get_tags(self) -> List:
        """Return trace tags (currently just the provider)."""
        return [
            self.provider,
        ]

    def get_completion(self, messages, stream: bool = False, params: dict | None = None):
        """Run the agent over the message history and return a completion.

        Args:
            messages: List of message dicts (conversation history; last entry
                is the current question).
            stream: When True, return an iterable of chunks instead of a
                DataFrame.
            params: Extra per-call args merged over ``self.args``.

        Returns:
            A predictions DataFrame (non-streaming) or an iterable of chunk
            dicts (streaming).
        """
        # Get metadata and tags to be used in the trace
        metadata = self.get_metadata()
        tags = self.get_tags()

        # Set up trace for the API completion in Langfuse
        self.langfuse_client_wrapper.setup_trace(
            name="api-completion",
            input=messages,
            tags=tags,
            metadata=metadata,
            user_id=ctx.user_id,
            session_id=ctx.session_id,
        )

        # Set up trace for the run completion in Langfuse
        self.run_completion_span = self.langfuse_client_wrapper.start_span(name="run-completion", input=messages)

        if stream:
            return self._get_completion_stream(messages)

        args = {}
        args.update(self.args)
        args.update(params or {})

        df = pd.DataFrame(messages)
        logger.info(f"LangchainAgent.get_completion: Received {len(messages)} messages")
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"Messages DataFrame shape: {df.shape}")
            logger.debug(f"Messages DataFrame columns: {df.columns.tolist()}")
            logger.debug(f"Messages DataFrame content: {df.to_dict('records')}")

        # Back compatibility for old models
        self.provider = args.get("provider", get_llm_provider(args))

        df = df.reset_index(drop=True)
        agent = self.create_agent(df)
        # Keep conversation history for context - don't nullify previous messages
        # Only use the last message as the current prompt, but preserve history for agent memory
        response = self.run_agent(df, agent, args)

        # End the run completion span and update the metadata with tool usage
        self.langfuse_client_wrapper.end_span(span=self.run_completion_span, output=response)

        return response

    def _get_completion_stream(self, messages: List[dict]) -> Iterable[Dict]:
        """Gets a completion as a stream of chunks from given messages.

        Args:
            messages (List[dict]): Messages to get completion chunks for

        Returns:
            chunks (Iterable[object]): Completion chunks
        """
        args = self.args

        df = pd.DataFrame(messages)
        logger.info(f"LangchainAgent._get_completion_stream: Received {len(messages)} messages")
        # Check if we have the expected columns for conversation history
        if "question" in df.columns and "answer" in df.columns:
            logger.debug("DataFrame has question/answer columns for conversation history")
        else:
            # Fixed: this warning was missing its f-prefix and logged the
            # literal placeholder text instead of the actual columns.
            logger.warning(f"DataFrame missing question/answer columns! Available columns: {df.columns.tolist()}")

        self.embedding_model_provider = args.get("embedding_model_provider", get_embedding_model_provider(args))
        # Back compatibility for old models
        self.provider = args.get("provider", get_llm_provider(args))

        df = df.reset_index(drop=True)
        agent = self.create_agent(df)
        # Keep conversation history for context - don't nullify previous messages
        # Only use the last message as the current prompt, but preserve history for agent memory
        return self.stream_agent(df, agent, args)

    def create_agent(self, df: pd.DataFrame) -> AgentExecutor:
        """Build the Langchain AgentExecutor: chat model, tools, and memory
        pre-populated with the conversation history from ``df``.
        """
        # Set up tools.
        args = self.args

        llm = create_chat_model(args)
        self.llm = llm

        # Don't set embedding model for retrieval mode - let the knowledge base handle it
        # NOTE(review): this mutates self.args, so "mode" is gone on subsequent calls.
        if args.get("mode") == "retrieval":
            self.args.pop("mode")

        tools = self._langchain_tools_from_skills(llm)

        # Prefer prediction prompt template over original if provided.
        prompt_template = args["prompt_template"]

        # Modern LangChain approach: Use memory but populate it correctly
        # Create memory and populate with conversation history
        memory = ConversationSummaryBufferMemory(
            llm=llm,
            input_key="input",
            output_key="output",
            max_token_limit=args.get("max_tokens", DEFAULT_MAX_TOKENS),
            memory_key="chat_history",
        )

        # Add system message first
        memory.chat_memory.messages.insert(0, SystemMessage(content=prompt_template))

        user_column = args.get("user_column", USER_COLUMN)
        assistant_column = args.get("assistant_column", ASSISTANT_COLUMN)

        logger.info(f"Processing conversation history: {len(df)} total messages, {len(df[:-1])} history messages")
        logger.debug(f"User column: {user_column}, Assistant column: {assistant_column}")

        # Process history messages (all except the last one which is current message)
        history_df = df[:-1]
        if len(history_df) == 0:
            logger.debug("No history rows to process - this is normal for first message")

        history_count = 0
        for i, row in enumerate(history_df.to_dict("records")):
            question = row.get(user_column)
            answer = row.get(assistant_column)
            logger.debug(f"Converting history row {i}: question='{question}', answer='{answer}'")

            # Add messages directly to memory's chat_memory.messages list (modern approach)
            if isinstance(question, str) and len(question) > 0:
                memory.chat_memory.messages.append(HumanMessage(content=question))
                history_count += 1
                logger.debug(f"Added HumanMessage to memory: {question}")
            if isinstance(answer, str) and len(answer) > 0:
                memory.chat_memory.messages.append(AIMessage(content=answer))
                history_count += 1
                logger.debug(f"Added AIMessage to memory: {answer}")

        logger.info(f"Built conversation history with {history_count} history messages + system message")
        logger.debug(f"Final memory messages count: {len(memory.chat_memory.messages)}")

        # Store memory for agent use
        self._conversation_memory = memory
        default_agent = get_default_agent_type()
        agent_type = args.get("agent_type", default_agent)
        agent_executor = initialize_agent(
            tools,
            llm,
            agent=agent_type,
            # Use custom output parser to handle flaky LLMs that don't ALWAYS conform to output format.
            agent_kwargs={
                "output_parser": SafeOutputParser(),
                "prefix": MINDSDB_PREFIX,  # Override default "Assistant is a large language model..." text
                "format_instructions": EXPLICIT_FORMAT_INSTRUCTIONS,  # More explicit tool calling instructions
                "ai_prefix": "AI",
            },
            # Calls the agent's LLM Chain one final time to generate a final answer based on the previous steps
            early_stopping_method="generate",
            handle_parsing_errors=self._handle_parsing_errors,
            # Timeout per agent invocation. (Fixed: the defaults below were
            # redundantly nested as args.get(k, args.get(k, default)).)
            max_execution_time=args.get("timeout_seconds", DEFAULT_AGENT_TIMEOUT_SECONDS),
            max_iterations=args.get("max_iterations", DEFAULT_MAX_ITERATIONS),
            memory=memory,
            verbose=args.get("verbose", False),
        )
        return agent_executor

    def _langchain_tools_from_skills(self, llm):
        """Makes Langchain compatible tools from the agent's skills."""
        skills_data = [
            SkillData(
                name=rel.skill.name,
                type=rel.skill.type,
                params=rel.skill.params,
                project_id=rel.skill.project_id,
                agent_tables_list=(rel.parameters or {}).get("tables"),
            )
            for rel in self.agent.skills_relationships
        ]

        tools_groups = skill_tool.get_tools_from_skills(skills_data, llm, self.embedding_model)

        all_tools = []
        for skill_type, tools in tools_groups.items():
            for tool in tools:
                # Some skills return plain dicts; wrap them as Langchain Tools.
                if isinstance(tool, dict):
                    tool = Tool(
                        name=tool["name"],
                        func=tool["func"],
                        description=tool["description"],
                    )
                all_tools.append(tool)
        return all_tools

    def _get_agent_callbacks(self, args: Dict) -> List:
        """Build (and lazily cache) the callback handlers for a run:
        logging, native Langfuse, and the custom MindsDB Langfuse handler.
        """
        all_callbacks = []

        if self.log_callback_handler is None:
            self.log_callback_handler = LogCallbackHandler(logger, verbose=args.get("verbose", True))

        all_callbacks.append(self.log_callback_handler)

        if self.langfuse_client_wrapper.trace is None:
            # Get metadata and tags to be used in the trace
            metadata = self.get_metadata()
            tags = self.get_tags()

            trace_name = "NativeTrace-MindsDB-AgentExecutor"

            # Set up trace for the API completion in Langfuse
            self.langfuse_client_wrapper.setup_trace(
                name=trace_name,
                tags=tags,
                metadata=metadata,
                user_id=ctx.user_id,
                session_id=ctx.session_id,
            )

        if self.langfuse_callback_handler is None:
            self.langfuse_callback_handler = self.langfuse_client_wrapper.get_langchain_handler()

        # custom tracer
        if self.mdb_langfuse_callback_handler is None:
            trace_id = self.langfuse_client_wrapper.get_trace_id()

            span_id = None
            if self.run_completion_span is not None:
                span_id = self.run_completion_span.id

            observation_id = args.get("observation_id", span_id or uuid4().hex)

            self.mdb_langfuse_callback_handler = LangfuseCallbackHandler(
                langfuse=self.langfuse_client_wrapper.client,
                trace_id=trace_id,
                observation_id=observation_id,
            )

        # obs: we may want to unify these; native langfuse handler provides details as a tree on a sub-step of the overarching custom one # noqa
        if self.langfuse_callback_handler is not None:
            all_callbacks.append(self.langfuse_callback_handler)

        if self.mdb_langfuse_callback_handler:
            all_callbacks.append(self.mdb_langfuse_callback_handler)

        return all_callbacks

    def _handle_parsing_errors(self, error: Exception) -> str:
        """Salvage a usable answer from a Langchain output-parsing error.

        If the error text starts with a known parsing-error prefix, the raw
        LLM output (wrapped in backticks inside the message) is extracted and
        re-wrapped in the conversational-react format; otherwise a generic
        failure message is returned.
        """
        response = str(error)
        for p in _PARSING_ERROR_PREFIXES:
            if response.startswith(p):
                # As a somewhat dirty workaround, we accept the output formatted incorrectly and use it as a response.
                #
                # Ideally, in the future, we would write a parser that is more robust and flexible than the one Langchain uses.
                # Response is wrapped in ``
                logger.info("Handling parsing error, salvaging response...")
                response_output = response.split("`")
                if len(response_output) >= 2:
                    response = response_output[-2]

                # Wrap response in Langchain conversational react format.
                langchain_react_formatted_response = f"""Thought: Do I need to use a tool? No
AI: {response}"""
                return langchain_react_formatted_response
        return f"Agent failed with error:\n{str(error)}..."

    def run_agent(self, df: pd.DataFrame, agent: AgentExecutor, args: Dict) -> pd.DataFrame:
        """Execute the agent on the current (last) prompt and return a
        predictions DataFrame with answer, context, and trace id columns.
        """
        # Fixed: was args.get("prompt_template", args["prompt_template"]),
        # which evaluated the "default" eagerly and was equivalent to this.
        base_template = args["prompt_template"]
        return_context = args.get("return_context", True)
        input_variables = re.findall(r"{{(.*?)}}", base_template)

        prompts, empty_prompt_ids = prepare_prompts(
            df, base_template, input_variables, args.get("user_column", USER_COLUMN)
        )

        def _invoke_agent_executor_with_prompt(agent_executor, prompt):
            # Returns a dict with the captured contexts and the agent's answer.
            if not prompt:
                return {CONTEXT_COLUMN: [], ASSISTANT_COLUMN: ""}
            try:
                callbacks, context_callback = prepare_callbacks(self, args)

                # Conversation history is carried by the executor's memory (see
                # create_agent); only the current question is passed as input.
                # (Fixed: a dead branch checked self._conversation_messages,
                # which is never assigned — only _conversation_memory is — so a
                # misleading "No conversation messages found" warning fired on
                # every call while both branches ran this identical invoke.)
                result = agent_executor.invoke({"input": prompt}, config={"callbacks": callbacks})
                captured_context = context_callback.get_contexts()
                output = result["output"] if isinstance(result, dict) and "output" in result else str(result)
                return {CONTEXT_COLUMN: captured_context, ASSISTANT_COLUMN: output}
            except Exception as e:
                error_message = str(e)
                # Special handling for API key errors
                if "API key" in error_message and ("not found" in error_message or "missing" in error_message):
                    # Format API key error more clearly
                    logger.error(f"API Key Error: {error_message}")
                    error_message = f"API Key Error: {error_message}"
                return {
                    CONTEXT_COLUMN: [],
                    ASSISTANT_COLUMN: handle_agent_error(e, error_message),
                }

        completions = []
        contexts = []

        max_workers = args.get("max_workers", None)
        # NOTE(review): this reads "timeout" while create_agent reads
        # "timeout_seconds" — confirm which key callers actually pass.
        agent_timeout_seconds = args.get("timeout", DEFAULT_AGENT_TIMEOUT_SECONDS)

        with ContextThreadPoolExecutor(max_workers=max_workers) as executor:
            # Only process the last prompt (current question), not all prompts
            # The previous prompts are conversation history and should only be used for context
            if prompts:
                current_prompt = prompts[-1]  # Last prompt is the current question
                futures = [executor.submit(_invoke_agent_executor_with_prompt, agent, current_prompt)]
            else:
                logger.error("No prompts found to process")
                futures = []
            try:
                for future in as_completed(futures, timeout=agent_timeout_seconds):
                    result = future.result()
                    if result is None:
                        result = {
                            CONTEXT_COLUMN: [],
                            ASSISTANT_COLUMN: "No response generated",
                        }

                    completions.append(result[ASSISTANT_COLUMN])
                    contexts.append(result[CONTEXT_COLUMN])
            except TimeoutError:
                timeout_message = (
                    f"I'm sorry! I couldn't generate a response within the allotted time ({agent_timeout_seconds} seconds). "
                    "If you need more time for processing, you can adjust the timeout settings. "
                    "Please refer to the documentation for instructions on how to change the timeout value. "
                    "Feel free to try your request again."
                )
                logger.warning(f"Agent execution timed out after {agent_timeout_seconds} seconds")
                for _ in range(len(futures) - len(completions)):
                    completions.append(timeout_message)
                    contexts.append([])

        # Add null completion for empty prompts
        # NOTE(review): the [:-1] slice deliberately(?) skips the last empty
        # prompt id — confirm whether the final row is meant to be excluded.
        for i in sorted(empty_prompt_ids)[:-1]:
            completions.insert(i, None)
            contexts.insert(i, [])

        # Create DataFrame with completions and context if required
        pred_df = pd.DataFrame(
            {
                ASSISTANT_COLUMN: completions,
                CONTEXT_COLUMN: [json.dumps(ctx) for ctx in contexts],  # Serialize context to JSON string
                TRACE_ID_COLUMN: self.langfuse_client_wrapper.get_trace_id(),
            }
        )

        if not return_context:
            pred_df = pred_df.drop(columns=[CONTEXT_COLUMN])

        return pred_df

    def add_chunk_metadata(self, chunk: Dict) -> Dict:
        """Attach the current Langfuse trace id to a streaming chunk."""
        logger.debug(f"Adding metadata to chunk: {chunk}")
        logger.debug(f"Trace ID: {self.langfuse_client_wrapper.get_trace_id()}")
        chunk["trace_id"] = self.langfuse_client_wrapper.get_trace_id()
        return chunk

    def _stream_agent_executor(
        self,
        agent_executor: AgentExecutor,
        prompt: str,
        callbacks: List[BaseCallbackHandler],
    ):
        """Stream the agent executor's chunks, plus dispatched event chunks.

        Runs the executor in a worker thread that feeds a queue (so event
        chunks from callbacks are not blocked by the executor), converts each
        chunk into serializable primitives, and yields it with trace metadata.
        Errors in the worker are converted into typed error chunks.
        """
        chunk_queue = queue.Queue()
        # Add event dispatch callback handler only to streaming completions.
        event_dispatch_callback_handler = EventDispatchCallbackHandler(chunk_queue)
        callbacks.append(event_dispatch_callback_handler)
        stream_iterator = agent_executor.stream(prompt, config={"callbacks": callbacks})

        agent_executor_finished_event = threading.Event()

        def stream_worker(context: dict):
            # Runs in a daemon thread; re-loads the request context first.
            try:
                ctx.load(context)
                for chunk in stream_iterator:
                    chunk_queue.put(chunk)
            except TimeoutError as e:
                error_message = f"Request timed out: The agent took too long to respond. {str(e)}"
                logger.error(f"Timeout error during streaming: {error_message}", exc_info=True)
                error_chunk = {
                    "type": "error",
                    "content": handle_agent_error(e, error_message),
                    "error_type": "timeout",
                }
                chunk_queue.put(error_chunk)
            except ConnectionError as e:
                error_message = f"Connection error: Failed to connect to the service. {str(e)}"
                logger.error(f"Connection error during streaming: {error_message}", exc_info=True)
                error_chunk = {
                    "type": "error",
                    "content": handle_agent_error(e, error_message),
                    "error_type": "connection",
                }
                chunk_queue.put(error_chunk)
            except Exception as e:
                error_message = str(e)

                # Special handling for specific error types
                # Note: TimeoutError and ConnectionError are already handled by specific exception handlers above
                if "API key" in error_message and ("not found" in error_message or "missing" in error_message):
                    logger.error(f"API Key Error: {error_message}")
                    error_message = f"API Key Error: {error_message}"
                    error_type = "authentication"
                elif "404" in error_message or "not found" in error_message.lower():
                    logger.error(f"Model Error: {error_message}")
                    error_type = "not_found"
                elif "rate limit" in error_message.lower() or "429" in error_message:
                    logger.error(f"Rate Limit Error: {error_message}")
                    error_message = f"Rate limit exceeded: {error_message}"
                    error_type = "rate_limit"
                else:
                    logger.error(f"LLM chain encountered an error during streaming: {error_message}", exc_info=True)
                    error_type = "general"

                if not error_message or not error_message.strip():
                    error_message = f"An unknown error occurred during streaming: {type(e).__name__}"

                error_chunk = {
                    "type": "error",
                    "content": handle_agent_error(e, error_message),
                    "error_type": error_type,
                }
                chunk_queue.put(error_chunk)
            finally:
                # Wrap in try/finally to always set the thread event even if there's an exception.
                agent_executor_finished_event.set()

        # Enqueue Langchain agent streaming chunks in a separate thread to not block event chunks.
        executor_stream_thread = threading.Thread(
            target=stream_worker,
            daemon=True,
            args=(ctx.dump(),),
            name="LangchainAgent.stream_worker",
        )
        executor_stream_thread.start()

        # Poll the queue until the worker signals completion.
        while not agent_executor_finished_event.is_set():
            try:
                chunk = chunk_queue.get(block=True, timeout=AGENT_CHUNK_POLLING_INTERVAL_SECONDS)
            except queue.Empty:
                continue
            logger.debug(f"Processing streaming chunk {chunk}")
            processed_chunk = self.process_chunk(chunk)
            logger.info(f"Processed chunk: {processed_chunk}")
            yield self.add_chunk_metadata(processed_chunk)
            chunk_queue.task_done()

    def stream_agent(self, df: pd.DataFrame, agent_executor: AgentExecutor, args: Dict) -> Iterable[Dict]:
        """Yield the agent's response as chunks: a start marker, executor
        chunks, optional retrieval context, and any generated SQL.
        """
        # Fixed: was args.get("prompt_template", args["prompt_template"]),
        # which evaluated the "default" eagerly and was equivalent to this.
        base_template = args["prompt_template"]
        input_variables = re.findall(r"{{(.*?)}}", base_template)
        return_context = args.get("return_context", True)

        prompts, _ = prepare_prompts(df, base_template, input_variables, args.get("user_column", USER_COLUMN))

        callbacks, context_callback = prepare_callbacks(self, args)

        # Use last prompt (current question) instead of first prompt (history)
        current_prompt = prompts[-1] if prompts else ""
        yield self.add_chunk_metadata({"type": "start", "prompt": current_prompt})

        if not hasattr(agent_executor, "stream") or not callable(agent_executor.stream):
            raise AttributeError("The agent_executor does not have a 'stream' method")

        stream_iterator = self._stream_agent_executor(agent_executor, current_prompt, callbacks)
        for chunk in stream_iterator:
            yield chunk

        if return_context:
            # Yield context if required
            captured_context = context_callback.get_contexts()
            if captured_context:
                yield {"type": "context", "content": captured_context}

        if self.log_callback_handler.generated_sql:
            # Yield generated SQL if available
            yield self.add_chunk_metadata({"type": "sql", "content": self.log_callback_handler.generated_sql})

        # End the run completion span and update the metadata with tool usage
        self.langfuse_client_wrapper.end_span_stream(span=self.run_completion_span)

    @staticmethod
    def process_chunk(chunk):
        """Recursively convert a streaming chunk into serializable primitives,
        with special formatting for AgentAction, AgentStep, and messages.
        """
        if isinstance(chunk, dict):
            return {k: LangchainAgent.process_chunk(v) for k, v in chunk.items()}
        if isinstance(chunk, list):
            return [LangchainAgent.process_chunk(item) for item in chunk]
        if isinstance(chunk, AgentAction):
            # Format agent actions properly for streaming.
            return {
                "tool": LangchainAgent.process_chunk(chunk.tool),
                "tool_input": LangchainAgent.process_chunk(chunk.tool_input),
                "log": LangchainAgent.process_chunk(chunk.log),
            }
        if isinstance(chunk, AgentStep):
            # Format agent steps properly for streaming.
            return {
                "action": LangchainAgent.process_chunk(chunk.action),
                "observation": LangchainAgent.process_chunk(chunk.observation) if chunk.observation else "",
            }
        if issubclass(chunk.__class__, BaseMessage):
            # Extract content from message subclasses properly for streaming.
            return {"content": chunk.content}
        if isinstance(chunk, (str, int, float, bool, type(None))):
            return chunk
        return str(chunk)