Coverage for mindsdb / interfaces / agents / agents_controller.py: 74%

263 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import datetime 

2from typing import Dict, Iterator, List, Union, Tuple, Optional, Any 

3import copy 

4 

5from sqlalchemy.orm.attributes import flag_modified 

6from sqlalchemy import null 

7import pandas as pd 

8 

9from mindsdb.interfaces.storage import db 

10from mindsdb.interfaces.storage.db import Predictor 

11from mindsdb.utilities.context import context as ctx 

12from mindsdb.interfaces.database.projects import ProjectController 

13from mindsdb.interfaces.model.functions import PredictorRecordNotFound 

14from mindsdb.interfaces.model.model_controller import ModelController 

15from mindsdb.interfaces.skills.skills_controller import SkillsController 

16from mindsdb.utilities.config import config 

17from mindsdb.utilities import log 

18 

19from mindsdb.utilities.exception import EntityExistsError, EntityNotExistsError 

20 

21from .constants import ASSISTANT_COLUMN, SUPPORTED_PROVIDERS, PROVIDER_TO_MODELS 

22from .provider_utils import get_llm_provider 

23 

# Module-level logger, named after this module.
logger = log.getLogger(__name__)

# Value of the "default_project" config key; used below as the default `project_name` argument.
default_project = config.get("default_project")

27 

28 

class AgentsController:
    """Handles CRUD operations at the database level for Agents"""

    assistant_column = ASSISTANT_COLUMN
    # Description prefix given to SQL skills created by `_create_default_sql_skill`.
    # `delete_agent` looks for this prefix to remove such skills together with the agent.
    autogenerated_skill_prefix = "Auto-generated SQL skill for agent"

    def __init__(
        self,
        project_controller: ProjectController = None,
        skills_controller: SkillsController = None,
        model_controller: ModelController = None,
    ):
        """
        Stores the collaborating controllers, creating a default instance for each one not supplied.

        Parameters:
            project_controller (ProjectController): Used to resolve projects by name
            skills_controller (SkillsController): Used to look up, create and delete skills
            model_controller (ModelController): Used to look up models
        """
        if project_controller is None:
            project_controller = ProjectController()
        if skills_controller is None:
            skills_controller = SkillsController()
        if model_controller is None:
            model_controller = ModelController()
        self.project_controller = project_controller
        self.skills_controller = skills_controller
        self.model_controller = model_controller

    @staticmethod
    def _split_comma_separated(value):
        """Returns a list of stripped items if `value` is a comma-separated string, otherwise `value` unchanged."""
        if isinstance(value, str):
            return [item.strip() for item in value.split(",")]
        return value

    def check_model_provider(self, model_name: str, provider: str = None) -> Tuple[Optional[dict], Optional[str]]:
        """
        Checks if a model exists, and gets the provider of the model.

        The provider is either the provider of the model or the provider given as an argument.

        Parameters:
            model_name (str): The name of the model. If None, nothing is checked.
            provider (str): The provider to check

        Returns:
            model (dict | None): The model object, or None when `model_name` is None or no such
                model record was found
            provider (str | None): The provider of the model

        Raises:
            ValueError: No model record was found, a provider was given, and that provider is neither in
                SUPPORTED_PROVIDERS nor lists `model_name` in PROVIDER_TO_MODELS.
        """
        model = None

        # Handle the case when model_name is None (using default LLM)
        if model_name is None:
            return model, provider

        try:
            model_name_no_version, model_version = Predictor.get_name_and_version(model_name)
            model = self.model_controller.get_model(model_name_no_version, version=model_version)
            provider = "mindsdb" if model.get("provider") is None else model.get("provider")
        except PredictorRecordNotFound:
            if not provider:
                # If provider is not given, get it from the model name
                provider = get_llm_provider({"model_name": model_name})

            elif provider not in SUPPORTED_PROVIDERS and model_name not in PROVIDER_TO_MODELS.get(provider, []):
                raise ValueError(f"Model with name does not exist for provider {provider}: {model_name}")

        return model, provider

    def get_agent(self, agent_name: str, project_name: str = default_project) -> Optional[db.Agents]:
        """
        Gets an agent by name.

        Only agents of the current company (ctx.company_id) that are not soft-deleted are considered.

        Parameters:
            agent_name (str): The name of the agent
            project_name (str): The name of the containing project - must exist

        Returns:
            agent (Optional[db.Agents]): The database agent object, or None if there is no match
        """

        project = self.project_controller.get(name=project_name)
        agent = db.Agents.query.filter(
            db.Agents.name == agent_name,
            db.Agents.project_id == project.id,
            db.Agents.company_id == ctx.company_id,
            db.Agents.deleted_at == null(),
        ).first()
        return agent

    def get_agent_by_id(self, id: int, project_name: str = default_project) -> Optional[db.Agents]:
        """
        Gets an agent by id.

        Only agents of the current company (ctx.company_id) that are not soft-deleted are considered.

        Parameters:
            id (int): The id of the agent
            project_name (str): The name of the containing project - must exist

        Returns:
            agent (Optional[db.Agents]): The database agent object, or None if there is no match
        """

        project = self.project_controller.get(name=project_name)
        agent = db.Agents.query.filter(
            db.Agents.id == id,
            db.Agents.project_id == project.id,
            db.Agents.company_id == ctx.company_id,
            db.Agents.deleted_at == null(),
        ).first()
        return agent

    def get_agents(self, project_name: Optional[str]) -> List[db.Agents]:
        """
        Gets all agents in a project.

        Parameters:
            project_name (str | None): The name of the containing project - must exist.
                If None, agents are not filtered by project.

        Returns:
            all-agents (List[db.Agents]): List of database agent object
        """

        all_agents = db.Agents.query.filter(db.Agents.company_id == ctx.company_id, db.Agents.deleted_at == null())

        if project_name is not None:
            project = self.project_controller.get(name=project_name)

            all_agents = all_agents.filter(db.Agents.project_id == project.id)

        return all_agents.all()

    def _create_default_sql_skill(
        self,
        name,
        project_name,
        include_tables: List[str] = None,
        include_knowledge_bases: List[str] = None,
    ):
        """
        Creates the auto-generated SQL skill of an agent, or updates it if it already exists.

        The skill is named "<name>_sql_skill". When it already exists, its params are changed in place
        so that they hold exactly the keys and values built here, and the change is committed.

        Parameters:
            name (str): The agent name the skill is generated for
            project_name (str): The containing project
            include_tables (List[str]): Tables the skill is restricted to (omitted from params when empty)
            include_knowledge_bases (List[str]): Knowledge bases the skill is restricted to
                (omitted from params when empty)

        Returns:
            skill_name (str): The name of the created or updated skill

        Raises:
            ValueError: Any failure while creating or updating the skill (original error is chained).
        """
        # Create a default SQL skill
        skill_name = f"{name}_sql_skill"
        skill_params = {
            "type": "sql",
            "description": f"{self.autogenerated_skill_prefix} {name}",
        }

        # Add restrictions provided
        if include_tables:
            skill_params["include_tables"] = include_tables
        if include_knowledge_bases:
            skill_params["include_knowledge_bases"] = include_knowledge_bases

        try:
            # Check if skill already exists
            existing_skill = self.skills_controller.get_skill(skill_name, project_name)
            if existing_skill is None:
                # Create the skill
                skill_type = skill_params.pop("type")
                self.skills_controller.add_skill(
                    name=skill_name, project_name=project_name, type=skill_type, params=skill_params
                )
            else:
                # Update the skill if parameters have changed
                params_changed = False

                # NOTE(review): unlike the create branch, "type" is still in skill_params here, so it is
                # written into the stored params on update - confirm whether that is intended.
                keys = set(skill_params.keys()) | set(existing_skill.params.keys())
                for param_key in keys:
                    if param_key in skill_params:
                        if existing_skill.params.get(param_key) != skill_params[param_key]:
                            existing_skill.params[param_key] = skill_params[param_key]
                            params_changed = True
                    else:
                        # Key is no longer requested: drop it from the stored params
                        existing_skill.params.pop(param_key)
                        params_changed = True

                # Update the skill if needed
                if params_changed:
                    # In-place changes of the params dict must be flagged explicitly
                    flag_modified(existing_skill, "params")
                    db.session.commit()

        except Exception as e:
            logger.exception("Failed to auto-create or update SQL skill:")
            raise ValueError(f"Failed to auto-create or update SQL skill: {e}") from e

        return skill_name

    def add_agent(
        self,
        name: str,
        project_name: str = None,
        model_name: Union[str, dict] = None,
        skills: List[Union[str, dict]] = None,
        provider: str = None,
        params: Dict[str, Any] = None,
    ) -> db.Agents:
        """
        Adds an agent to the database.

        Parameters:
            name (str): The name of the new agent
            project_name (str): The containing project (the default project is used when None)
            model_name (str | dict): The name of the existing ML model the agent will use.
                A dict is stored as params["model"] instead, and no model name is saved.
            skills (List[Union[str, dict]]): List of existing skill names to add to the new agent, or list of dicts
                with one of keys is "name", and other is additional parameters for relationship agent<>skill
            provider (str): The provider of the model
            params (Dict[str, str]): Parameters to use when running the agent
                data: Dict, data sources for an agent, keys:
                  - knowledge_bases: List of KBs to use
                  - tables: list of tables to use
                model: Dict, parameters for the model to use
                  - provider: The provider of the model (e.g., 'openai', 'google')
                  - Other model-specific parameters like 'api_key', 'model_name', etc.
                <provider>_api_key: API key for the provider (e.g., openai_api_key)

                # Deprecated parameters (rejected with ValueError):
                database, knowledge_base_database, include_tables, ignore_tables,
                include_knowledge_bases, ignore_knowledge_bases

        Returns:
            agent (db.Agents): The created agent

        Raises:
            EntityExistsError: Agent with given name already exists.
            ValueError: A deprecated parameter is used, a skill with given name does not exist,
                or the model/provider check fails.
        """
        if project_name is None:
            project_name = default_project
        project = self.project_controller.get(name=project_name)

        agent = self.get_agent(name, project_name)

        if agent is not None:
            raise EntityExistsError("Agent already exists", name)

        # No need to copy params since we're not preserving the original reference
        params = params or {}
        # `skills` defaults to None; normalise it so the association loop below can always iterate
        skills = skills or []

        if isinstance(model_name, dict):
            # move into params
            params["model"] = model_name
            model_name = None

        if model_name is not None:
            _, provider = self.check_model_provider(model_name, provider)
        else:
            # If model_name is not provided, we use default global llm model at runtime
            # Default parameters will be applied at runtime via get_agent_llm_params
            # This allows global default updates to apply to all agents immediately
            logger.warning("'model_name' param is not provided. Using default global llm model at runtime.")

        # API keys ("<provider>_api_key" or the generic "api_key") are intentionally left in params
        # untouched so that they are stored with the agent.

        deprecated_params = [
            "database",
            "knowledge_base_database",
            "include_tables",
            "ignore_tables",
            "include_knowledge_bases",
            "ignore_knowledge_bases",
        ]
        if any(param in params for param in deprecated_params):
            raise ValueError(
                f"Parameters {', '.join(deprecated_params)} are deprecated. "
                "Use 'data' parameter with 'tables' and 'knowledge_bases' keys instead."
            )

        include_tables = None
        include_knowledge_bases = None
        if "data" in params:
            include_knowledge_bases = params["data"].get("knowledge_bases")
            include_tables = params["data"].get("tables")

        # Convert string parameters to lists if needed
        include_tables = self._split_comma_separated(include_tables)
        include_knowledge_bases = self._split_comma_separated(include_knowledge_bases)

        # Auto-create SQL skill if no skills are provided but tables or knowledge bases are provided
        if not skills and (include_tables or include_knowledge_bases):
            skill = self._create_default_sql_skill(
                name,
                project_name,
                include_tables=include_tables,
                include_knowledge_bases=include_knowledge_bases,
            )
            skills = [skill]

        agent = db.Agents(
            name=name,
            project_id=project.id,
            company_id=ctx.company_id,
            user_class=ctx.user_class,
            model_name=model_name,
            provider=provider,
            params=params,
        )

        for skill in skills:
            if isinstance(skill, str):
                skill_name = skill
                parameters = {}
            else:
                # Everything except "name" becomes a parameter of the agent<>skill relationship
                parameters = skill.copy()
                skill_name = parameters.pop("name")

            existing_skill = self.skills_controller.get_skill(skill_name, project_name)
            if existing_skill is None:
                # Discard the associations added to the session for the previous skills
                db.session.rollback()
                raise ValueError(f"Skill with name does not exist: {skill_name}")

            if existing_skill.type == "sql":
                # Add table restrictions if this is a text2sql skill
                if include_tables:
                    parameters["tables"] = include_tables

                # Add knowledge base parameters to both the skill and the association parameters
                if include_knowledge_bases:
                    parameters["include_knowledge_bases"] = include_knowledge_bases
                    if "include_knowledge_bases" not in existing_skill.params:
                        existing_skill.params["include_knowledge_bases"] = include_knowledge_bases
                        flag_modified(existing_skill, "params")

            association = db.AgentSkillsAssociation(parameters=parameters, agent=agent, skill=existing_skill)
            db.session.add(association)

        db.session.add(agent)
        db.session.commit()

        return agent

    def update_agent(
        self,
        agent_name: str,
        project_name: str = default_project,
        name: str = None,
        model_name: Union[str, dict] = None,
        skills_to_add: List[Union[str, dict]] = None,
        skills_to_remove: List[str] = None,
        skills_to_rewrite: List[Union[str, dict]] = None,
        provider: str = None,
        params: Dict[str, str] = None,
    ):
        """
        Updates an agent in the database.

        Parameters:
            agent_name (str): The name of the existing agent to update
            project_name (str): The containing project
            name (str): The updated name of the agent
            model_name (str | dict): The name of the existing ML model the agent will use.
                A dict is stored as params["model"] instead.
            skills_to_add (List[Union[str, dict]]): List of skill names to add to the agent, or list of dicts
                with one of keys is "name", and other is additional parameters for relationship agent<>skill
            skills_to_remove (List[str]): List of skill names to remove from the agent
            skills_to_rewrite (List[Union[str, dict]]): new list of skills for the agent
            provider (str): The provider of the model
            params: (Dict[str, str]): Parameters to use when running the agent. They are merged into the
                stored params; keys whose merged value is None are removed. A "data" key (re)generates the
                agent's default SQL skill and makes it the only skill of the agent.

        Returns:
            agent (db.Agents): The updated agent

        Raises:
            EntityExistsError: if agent with new name already exists
            EntityNotExistsError: if agent with name or skill not found
            ValueError: if conflict in skills list, or a forbidden change of a demo agent is requested
        """

        skills_to_add = skills_to_add or []
        skills_to_remove = skills_to_remove or []
        skills_to_rewrite = skills_to_rewrite or []

        if len(skills_to_rewrite) > 0 and (len(skills_to_remove) > 0 or len(skills_to_add) > 0):
            raise ValueError(
                "'skills_to_rewrite' and 'skills_to_add' (or 'skills_to_remove') cannot be used at the same time"
            )

        existing_agent = self.get_agent(agent_name, project_name=project_name)
        if existing_agent is None:
            raise EntityNotExistsError(f"Agent with name not found: {agent_name}")
        # Work on a copy so the stored params are only replaced when the update goes through
        existing_params = dict(existing_agent.params or {})
        # Set when existing_params differs from what is stored and therefore has to be saved
        params_changed = False

        is_demo = existing_params.get("is_demo", False)
        if is_demo and (
            (name is not None and name != agent_name)
            or (model_name is not None and existing_agent.model_name != model_name)
            or (provider is not None and existing_agent.provider != provider)
            or (isinstance(params, dict) and len(params) > 0 and "prompt_template" not in params)
        ):
            raise ValueError("It is forbidden to change properties of the demo object")

        if name is not None and name != agent_name:
            # Check to see if updated name already exists
            agent_with_new_name = self.get_agent(name, project_name=project_name)
            if agent_with_new_name is not None:
                raise EntityExistsError(f"Agent with updated name already exists: {name}")
            existing_agent.name = name

        if model_name or provider:
            if isinstance(model_name, dict):
                # move into params
                existing_params["model"] = model_name
                params_changed = True
                model_name = None

            # check model and provider
            _, provider = self.check_model_provider(model_name, provider)
            # Update model and provider
            existing_agent.model_name = model_name
            existing_agent.provider = provider

        # `params` defaults to None, so it must be checked before looking for the "data" key
        if params is not None and "data" in params:
            if len(skills_to_add) > 0 or len(skills_to_remove) > 0:
                raise ValueError(
                    "'data' parameter cannot be used with 'skills_to_remove' or 'skills_to_add' parameters"
                )

            # Accept comma-separated strings the same way add_agent does
            include_knowledge_bases = self._split_comma_separated(params["data"].get("knowledge_bases"))
            include_tables = self._split_comma_separated(params["data"].get("tables"))

            skill = self._create_default_sql_skill(
                agent_name,
                project_name,
                include_tables=include_tables,
                include_knowledge_bases=include_knowledge_bases,
            )
            skills_to_rewrite = [{"name": skill}]

        # check that all skills exist
        skill_name_to_record_map = {}
        for skill_meta in skills_to_add + skills_to_remove + skills_to_rewrite:
            skill_name = skill_meta["name"] if isinstance(skill_meta, dict) else skill_meta
            if skill_name not in skill_name_to_record_map:
                skill_record = self.skills_controller.get_skill(skill_name, project_name)
                if skill_record is None:
                    raise EntityNotExistsError(f"Skill with name does not exist: {skill_name}")
                skill_name_to_record_map[skill_name] = skill_record

        if len(skills_to_add) > 0 or len(skills_to_remove) > 0:
            skills_to_add = [{"name": x} if isinstance(x, str) else x for x in skills_to_add]
            skills_to_add_names = [x["name"] for x in skills_to_add]
            names_to_remove = set(skills_to_remove)

            # there are no intersection between lists
            if not set(skills_to_add_names).isdisjoint(names_to_remove):
                raise ValueError("Conflict between skills to add and skills to remove.")

            already_linked_names = {rel.skill.name for rel in existing_agent.skills_relationships}

            # remove skills
            for rel in existing_agent.skills_relationships:
                if rel.skill.name in names_to_remove:
                    db.session.delete(rel)

            # add skills that are not linked yet; the first entry wins when a name is listed twice
            handled_names = set()
            for skill_meta in skills_to_add:
                skill_name = skill_meta["name"]
                if skill_name in already_linked_names or skill_name in handled_names:
                    continue
                handled_names.add(skill_name)
                skill_parameters = {k: v for k, v in skill_meta.items() if k != "name"}
                association = db.AgentSkillsAssociation(
                    parameters=skill_parameters, agent=existing_agent, skill=skill_name_to_record_map[skill_name]
                )
                db.session.add(association)

        elif len(skills_to_rewrite) > 0:
            # Plain skill names are allowed here too, same as for skills_to_add
            skills_to_rewrite = [{"name": x} if isinstance(x, str) else x for x in skills_to_rewrite]
            skill_name_to_parameters = {
                x["name"]: {k: v for k, v in x.items() if k != "name"} for x in skills_to_rewrite
            }
            existing_skill_names = set()
            for rel in existing_agent.skills_relationships:
                if rel.skill.name not in skill_name_to_parameters:
                    db.session.delete(rel)
                else:
                    # Skill stays linked: only its relationship parameters are replaced
                    existing_skill_names.add(rel.skill.name)
                    rel.parameters = skill_name_to_parameters[rel.skill.name]
                    flag_modified(rel, "parameters")
            for new_skill_name in set(skill_name_to_parameters) - existing_skill_names:
                association = db.AgentSkillsAssociation(
                    parameters=skill_name_to_parameters[new_skill_name],
                    agent=existing_agent,
                    skill=skill_name_to_record_map[new_skill_name],
                )
                db.session.add(association)

        if params is not None:
            # Merge params on update
            existing_params.update(params)
            params_changed = True

        if params_changed:
            # Remove None values entirely.
            existing_agent.params = {k: v for k, v in existing_params.items() if v is not None}
            # Some versions of SQL Alchemy won't handle JSON updates correctly without this.
            # See: https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.attributes.flag_modified
            flag_modified(existing_agent, "params")
        db.session.commit()

        return existing_agent

    def delete_agent(self, agent_name: str, project_name: str = default_project):
        """
        Deletes an agent by name.

        The agent is soft-deleted (its `deleted_at` is set). Skills linked to the agent whose description
        contains `autogenerated_skill_prefix` are deleted through the skills controller.

        Parameters:
            agent_name (str): The name of the agent to delete
            project_name (str): The name of the containing project

        Raises:
            ValueError: Agent does not exist, or it is a demo object.
        """

        agent = self.get_agent(agent_name, project_name)
        if agent is None:
            raise ValueError(f"Agent with name does not exist: {agent_name}")
        if isinstance(agent.params, dict) and agent.params.get("is_demo") is True:
            raise ValueError("Unable to delete demo object")

        # delete autogenerated skill
        for rel in agent.skills_relationships:
            # A skill may have no params or no description; treat both as "not auto-generated"
            description = (rel.skill.params or {}).get("description") or ""
            if self.autogenerated_skill_prefix in description:
                self.skills_controller.delete_skill(rel.skill.name, project_name, strict_case=True)

        agent.deleted_at = datetime.datetime.now()
        db.session.commit()

    def get_agent_llm_params(self, agent_params: dict):
        """
        Get agent LLM parameters by combining default config with user provided parameters.
        Similar to how knowledge bases handle default parameters.

        Parameters:
            agent_params (dict | None): Stored agent params. If they contain a "model" key, its value is
                used as the model params, otherwise the whole dict is used.

        Returns:
            combined_model_params (dict): A deep copy of the "default_llm" config section, updated with
                the model params
        """
        combined_model_params = copy.deepcopy(config.get("default_llm", {}))

        # Agents may be stored without params
        agent_params = agent_params or {}

        if "model" in agent_params:
            model_params = agent_params["model"]
        else:
            # params for LLM can be arbitrary
            model_params = agent_params

        if model_params:
            combined_model_params.update(model_params)

        return combined_model_params

    def _build_langchain_agent(self, agent: db.Agents):
        """
        Creates the LangchainAgent used to run completions for `agent`.

        The agent's model/provider are checked first; if the agent has no provider stored and one
        was resolved, it is saved on the agent.
        """
        # For circular dependency.
        from .langchain_agent import LangchainAgent

        model, provider = self.check_model_provider(agent.model_name, agent.provider)

        # update old agents
        if agent.provider is None and provider is not None:
            agent.provider = provider
            db.session.commit()

        # Get agent parameters and combine with default LLM parameters at runtime
        llm_params = self.get_agent_llm_params(agent.params)

        return LangchainAgent(agent, model=model, llm_params=llm_params)

    def get_completion(
        self,
        agent: db.Agents,
        messages: list[Dict[str, str]],
        project_name: str = default_project,
        tools: list = None,
        stream: bool = False,
        params: dict | None = None,
    ) -> Union[Iterator[object], pd.DataFrame]:
        """
        Queries an agent to get a completion.

        Parameters:
            agent (db.Agents): Existing agent to get completion from
            messages (list[Dict[str, str]]): Chat history to send to the agent
            project_name (str): Project the agent belongs to (default mindsdb); not used by this method
                beyond being passed on when streaming
            tools (list[BaseTool]): Tools to use while getting the completion; not used by this method
                beyond being passed on when streaming
            stream (bool): Whether to stream the response
            params (dict | None): params to redefine agent params

        Returns:
            response (Union[Iterator[object], pd.DataFrame]): Completion as a DataFrame or iterator of completion chunks

        Raises:
            ValueError: Agent's model does not exist.
        """
        if stream:
            return self._get_completion_stream(agent, messages, project_name=project_name, tools=tools, params=params)

        lang_agent = self._build_langchain_agent(agent)
        return lang_agent.get_completion(messages, params=params)

    def _get_completion_stream(
        self,
        agent: db.Agents,
        messages: list[Dict[str, str]],
        project_name: str = default_project,
        tools: list = None,
        params: dict | None = None,
    ) -> Iterator[object]:
        """
        Queries an agent to get a stream of completion chunks.

        Parameters:
            agent (db.Agents): Existing agent to get completion from
            messages (list[Dict[str, str]]): Chat history to send to the agent
            project_name (str): Project the agent belongs to (default mindsdb); not used by this method
            tools (list[BaseTool]): Tools to use while getting the completion; not used by this method
            params (dict | None): params to redefine agent params

        Returns:
            chunks (Iterator[object]): Completion chunks as an iterator

        Raises:
            ValueError: Agent's model does not exist.
        """
        lang_agent = self._build_langchain_agent(agent)
        return lang_agent.get_completion(messages, stream=True, params=params)