Coverage for mindsdb / interfaces / storage / db.py: 92%

395 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import json 

2import orjson 

3import datetime 

4import os 

5from typing import Dict, List, Optional 

6 

7import numpy as np 

8from sqlalchemy import ( 

9 JSON, 

10 Boolean, 

11 Column, 

12 DateTime, 

13 Index, 

14 Integer, 

15 LargeBinary, 

16 Numeric, 

17 String, 

18 UniqueConstraint, 

19 create_engine, 

20 text, 

21 types, 

22) 

23from sqlalchemy.exc import OperationalError 

24from sqlalchemy.orm import ( 

25 Mapped, 

26 mapped_column, 

27 declarative_base, 

28 relationship, 

29 scoped_session, 

30 sessionmaker, 

31) 

32from sqlalchemy.sql.schema import ForeignKey 

33from mind_castle.sqlalchemy_type import SecretData 

34 

35from mindsdb.utilities.json_encoder import CustomJSONEncoder 

36from mindsdb.utilities.config import config 

37 

38 

class Base:
    # Mixin applied to the declarative base below.
    # Opts out of SQLAlchemy 2.x strict annotation mapping so plain
    # (non-Mapped[]) class-level annotations on models are tolerated.
    __allow_unmapped__ = True


# Rebind the name: `Base` is now the actual declarative base class,
# built with the mixin above, that every ORM model in this file extends.
Base = declarative_base(cls=Base)

# Module-level DB handles; populated by init() before first use.
session, engine = None, None

46 

47 

def init(connection_str: str = None):
    """Create the engine and scoped session for the storage database.

    Args:
        connection_str: SQLAlchemy connection URL; when omitted, the
            ``storage_db`` value from the application config is used.

    Side effects: rebinds the module-level ``engine`` and ``session`` and
    attaches a ``query`` property to ``Base``.
    """
    global Base, session, engine
    if connection_str is None:
        connection_str = config["storage_db"]

    # orjson serializes most types natively; CustomJSONEncoder.default
    # is the fallback for everything it does not know about.
    fallback = CustomJSONEncoder().default

    def _serialize_json(value):
        options = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME
        return orjson.dumps(value, default=fallback, option=options).decode("utf-8")

    engine = create_engine(
        connection_str,
        echo=False,
        pool_size=30,
        max_overflow=200,
        json_serializer=_serialize_json,
    )
    session = scoped_session(sessionmaker(bind=engine, autoflush=True))
    Base.query = session.query_property()

70 

71 

def serializable_insert(record: Base, try_count: int = 100):
    """Insert ``record`` under SERIALIZABLE isolation, retrying on conflict.

    Args:
        record (Base): sqlalchemy record to insert
        try_count (int): number of attempts before giving up

    Raises:
        OperationalError: if the insert still fails after ``try_count`` attempts.
    """
    committed = False
    while not committed:
        session.connection(execution_options={"isolation_level": "SERIALIZABLE"})
        if engine.name == "postgresql":
            # Unquoted identifier, so Postgres folds it to the "predictor" table.
            session.execute(text("LOCK TABLE PREDICTOR IN EXCLUSIVE MODE"))
        session.add(record)
        try:
            session.commit()
        except OperationalError:
            # Treat as a serialization failure ('SerializationFailure' should be
            # in str(e), but the exact text may depend on the engine).
            session.rollback()
            try_count -= 1
            # `<= 0` (not `== 0`) so a non-positive initial try_count cannot
            # spin forever after repeated failures.
            if try_count <= 0:
                raise
        else:
            committed = True

95 

96 

# Source: https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts numpy scalar and array types to builtins."""

    def default(self, obj):
        # Map numpy values to native Python equivalents; anything unknown
        # falls through to the stock encoder, which raises TypeError.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)

109 

110 

class Array(types.TypeDecorator):
    """String-backed column type storing a list of strings.

    On insert, list values are joined with the ``,|,|,`` delimiter;
    on select, the stored string is split back into a list.
    """

    impl = types.String

    def process_bind_param(self, value, dialect):  # insert
        # Strings pass through unchanged (assumed already joined).
        if isinstance(value, str):
            return value
        elif value is None:
            return value
        else:
            # NOTE(review): raises TypeError if elements are not strings — confirm callers.
            return ",|,|,".join(value)

    def process_result_value(self, value, dialect):  # select
        return value.split(",|,|,") if value is not None else None

126 

127 

class Json(types.TypeDecorator):
    """String-backed column type that (de)serializes Python values as JSON.

    Uses :class:`NumpyEncoder` so numpy scalars/arrays can be stored.
    """

    impl = types.String

    def process_bind_param(self, value, dialect):  # insert
        return json.dumps(value, cls=NumpyEncoder) if value is not None else None

    def process_result_value(self, value, dialect):  # select
        # Some backends may hand back an already-decoded dict; pass it through.
        if isinstance(value, dict):
            return value
        return json.loads(value) if value is not None else None

140 

141 

# Use MindsDB's "Json" column type as a backend for mind-castle
class SecretDataJson(SecretData):
    # mind-castle handles encryption; persistence goes through the Json type above.
    impl = Json

145 

146 

class PREDICTOR_STATUS:
    """Namespace of allowed string values for ``Predictor.status``."""

    # No per-instance attributes; the class exists only to hold constants.
    __slots__ = ()

    COMPLETE = "complete"
    TRAINING = "training"
    FINETUNING = "finetuning"
    GENERATING = "generating"
    ERROR = "error"
    VALIDATION = "validation"
    DELETED = "deleted"  # TODO remove it?


# Shadow the class with a singleton instance so callers use attribute access
# (PREDICTOR_STATUS.COMPLETE) on a module-level value.
PREDICTOR_STATUS = PREDICTOR_STATUS()

159 

160 

class Predictor(Base):
    """ORM model for an ML model ("predictor"), including training state and versioning."""

    __tablename__ = "predictor"

    id = Column(Integer, primary_key=True)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
    deleted_at = Column(DateTime)  # soft delete marker
    name = Column(String)
    data = Column(Json)  # A JSON -- should be everything returned by `get_model_data`, I think
    to_predict = Column(Array)  # target column names
    company_id = Column(String)
    mindsdb_version = Column(String)
    native_version = Column(String)
    integration_id = Column(ForeignKey("integration.id", name="fk_integration_id"))
    data_integration_ref = Column(Json)
    fetch_data_query = Column(String)
    learn_args = Column(Json)
    update_status = Column(String, default="up_to_date")
    status = Column(String)  # one of PREDICTOR_STATUS values
    active = Column(Boolean, default=True)  # the active version among same-named predictors
    training_data_columns_count = Column(Integer)
    training_data_rows_count = Column(Integer)
    training_start_at = Column(DateTime)
    training_stop_at = Column(DateTime)
    label = Column(String, nullable=True)
    version = Column(Integer, default=1)
    code = Column(String, nullable=True)
    lightwood_version = Column(String, nullable=True)
    dtype_dict = Column(Json, nullable=True)
    project_id = Column(Integer, ForeignKey("project.id", name="fk_project_id"), nullable=False)
    training_phase_current = Column(Integer)
    training_phase_total = Column(Integer)
    training_phase_name = Column(String)
    # NOTE(review): mutable default {} is shared across calls in Python;
    # SQLAlchemy uses it as a column default — confirm a per-row dict is intended.
    training_metadata = Column(JSON, default={}, nullable=False)

    @staticmethod
    def get_name_and_version(full_name):
        """Split ``"name.N"`` into ``("name", N)``.

        Returns ``(full_name, None)`` when the last dot-separated part is
        not purely numeric (i.e. no version suffix is present).
        """
        name_no_version = full_name
        version = None
        parts = full_name.split(".")
        if len(parts) > 1 and parts[-1].isdigit():
            version = int(parts[-1])
            name_no_version = ".".join(parts[:-1])
        return name_no_version, version

205 

206 

# Composite unique index over a predictor's identity; deleted_at participates
# so soft-deleted rows do not collide with live ones.
Index(
    "predictor_index",
    Predictor.company_id,
    Predictor.name,
    Predictor.version,
    Predictor.active,
    Predictor.deleted_at,  # would be good to have here nullsfirst(Predictor.deleted_at)
    unique=True,
)

216 

217 

class Project(Base):
    """ORM model for a project: the namespace that groups models, views, etc."""

    __tablename__ = "project"

    id = Column(Integer, primary_key=True)
    created_at = Column(DateTime, default=datetime.datetime.now)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    deleted_at = Column(DateTime)  # soft delete marker
    name = Column(String, nullable=False)
    company_id = Column(String, default="0")
    # DB column is "metadata"; the attribute has a trailing underscore because
    # `metadata` is reserved on SQLAlchemy declarative classes.
    metadata_: dict = Column("metadata", JSON, nullable=True)
    __table_args__ = (UniqueConstraint("name", "company_id", name="unique_project_name_company_id"),)

229 

230 

class Integration(Base):
    """ORM model for a data-source integration (database/API connection)."""

    __tablename__ = "integration"
    id = Column(Integer, primary_key=True)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
    name = Column(String, nullable=False)
    engine = Column(String, nullable=False)  # handler name, e.g. "postgres"
    # Connection credentials; mind-castle encrypts them using the strategy named
    # by MINDSDB_DATA_ENCRYPTION_TYPE (read once at import time, default "none").
    data = Column(SecretDataJson(os.environ.get("MINDSDB_DATA_ENCRYPTION_TYPE", "none")))
    company_id = Column(String)

    __table_args__ = (UniqueConstraint("name", "company_id", name="unique_integration_name_company_id"),)

242 

243 

class File(Base):
    """ORM model for an uploaded file registered as a data source."""

    __tablename__ = "file"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    company_id = Column(String)
    source_file_path = Column(String, nullable=False)  # original upload path/name
    file_path = Column(String, nullable=False)  # path inside mindsdb storage
    row_count = Column(Integer, nullable=False)
    columns = Column(Json, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.now)
    # DB column is "metadata" (reserved attribute name on declarative models).
    metadata_: dict = Column("metadata", JSON, nullable=True)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    __table_args__ = (UniqueConstraint("name", "company_id", name="unique_file_name_company_id"),)

257 

258 

class View(Base):
    """ORM model for a saved SQL view belonging to a project."""

    __tablename__ = "view"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    company_id = Column(String)
    query = Column(String, nullable=False)  # the view's SQL text
    project_id = Column(Integer, ForeignKey("project.id", name="fk_project_id"), nullable=False)
    __table_args__ = (UniqueConstraint("name", "company_id", name="unique_view_name_company_id"),)

267 

268 

class JsonStorage(Base):
    """Generic JSON key-value storage scoped to a resource (group + id + name)."""

    __tablename__ = "json_storage"
    id = Column(Integer, primary_key=True)
    resource_group = Column(String)
    resource_id = Column(Integer)
    name = Column(String)
    content = Column(JSON)  # plaintext payload
    encrypted_content = Column(LargeBinary, nullable=True)  # optional encrypted payload
    company_id = Column(String)

    def to_dict(self) -> Dict:
        """Return a plain-dict snapshot of all columns."""
        return {
            "id": self.id,
            "resource_group": self.resource_group,
            "resource_id": self.resource_id,
            "name": self.name,
            "content": self.content,
            "encrypted_content": self.encrypted_content,
            "company_id": self.company_id,
        }

289 

290 

class Jobs(Base):
    """ORM model for a scheduled job (recurring or one-shot SQL execution)."""

    __tablename__ = "jobs"
    id = Column(Integer, primary_key=True)
    company_id = Column(String)
    user_class = Column(Integer, nullable=True)
    active = Column(Boolean, default=True)

    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)
    query_str = Column(String, nullable=False)  # SQL to run on each trigger
    if_query_str = Column(String, nullable=True)  # optional condition query gating the run
    start_at = Column(DateTime, default=datetime.datetime.now)
    end_at = Column(DateTime)  # job is not scheduled past this time
    next_run_at = Column(DateTime)
    schedule_str = Column(String)  # schedule definition, e.g. an interval expression

    deleted_at = Column(DateTime)  # soft delete marker
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)

310 

311 

class JobsHistory(Base):
    """ORM model recording one execution of a job (one row per run)."""

    __tablename__ = "jobs_history"
    id = Column(Integer, primary_key=True)
    company_id = Column(String)

    job_id = Column(Integer)  # references Jobs.id (no FK constraint declared)

    query_str = Column(String)
    start_at = Column(DateTime)
    end_at = Column(DateTime)

    error = Column(String)  # error text if the run failed
    created_at = Column(DateTime, default=datetime.datetime.now)
    # NOTE(review): no onupdate here, unlike other models — confirm whether intended.
    updated_at = Column(DateTime, default=datetime.datetime.now)

    # One history row per job per scheduled start time.
    __table_args__ = (UniqueConstraint("job_id", "start_at", name="uniq_job_history_job_id_start"),)

328 

329 

class ChatBots(Base):
    """ORM model for a chatbot bound to an agent and a chat-platform database."""

    __tablename__ = "chat_bots"
    id = Column(Integer, primary_key=True)

    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)
    agent_id = Column(ForeignKey("agents.id", name="fk_agent_id"))

    # To be removed when existing chatbots are backfilled with newly created Agents.
    model_name = Column(String)
    database_id = Column(Integer)
    params = Column(JSON)

    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
    webhook_token = Column(String)  # token for inbound webhook authentication

    def as_dict(self) -> Dict:
        """Return a plain-dict snapshot for API responses (excludes updated_at)."""
        return {
            "id": self.id,
            "name": self.name,
            "project_id": self.project_id,
            "agent_id": self.agent_id,
            "model_name": self.model_name,
            "params": self.params,
            "webhook_token": self.webhook_token,
            "created_at": self.created_at,
            "database_id": self.database_id,
        }

359 

360 

class ChatBotsHistory(Base):
    """ORM model logging messages sent/received by a chatbot."""

    __tablename__ = "chat_bots_history"
    id = Column(Integer, primary_key=True)
    chat_bot_id = Column(Integer, nullable=False)  # references ChatBots.id (no FK declared)
    type = Column(String)  # TODO replace to enum
    text = Column(String)
    user = Column(String)
    destination = Column(String)
    sent_at = Column(DateTime, default=datetime.datetime.now)
    error = Column(String)

371 

372 

class Triggers(Base):
    """ORM model for a data trigger: run a query when a watched table changes."""

    __tablename__ = "triggers"
    id = Column(Integer, primary_key=True)

    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)

    database_id = Column(Integer, nullable=False)  # integration to watch
    table_name = Column(String, nullable=False)
    query_str = Column(String, nullable=False)  # SQL executed on trigger
    columns = Column(String)  # list of columns separated by delimiter

    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)

387 

388 

class Tasks(Base):
    """ORM model for a background task wrapping another object (trigger, chatbot)."""

    __tablename__ = "tasks"
    id = Column(Integer, primary_key=True)
    company_id = Column(String)
    user_class = Column(Integer, nullable=True)

    # Type of the wrapped object: trigger, chatbot
    object_type = Column(String, nullable=False)
    object_id = Column(Integer, nullable=False)

    last_error = Column(String)
    active = Column(Boolean, default=True)
    reload = Column(Boolean, default=False)  # request the runner to restart the task

    # For running in concurrent processes: which runner owns the task and
    # its last heartbeat time.
    run_by = Column(String)
    alive_time = Column(DateTime(timezone=True))

    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)

409 

410 

class AgentSkillsAssociation(Base):
    """Association object for the many-to-many Agents <-> Skills relationship,
    carrying per-link parameters."""

    __tablename__ = "agent_skills"

    agent_id: Mapped[int] = mapped_column(ForeignKey("agents.id"), primary_key=True)
    skill_id: Mapped[int] = mapped_column(ForeignKey("skills.id"), primary_key=True)
    # NOTE(review): mutable default {} — confirm each row gets its own dict.
    parameters: Mapped[dict] = mapped_column(JSON, default={}, nullable=True)

    agent = relationship("Agents", back_populates="skills_relationships")
    skill = relationship("Skills", back_populates="agents_relationships")

420 

421 

class Skills(Base):
    """ORM model for an agent skill (e.g. a SQL or knowledge-base skill)."""

    __tablename__ = "skills"
    id = Column(Integer, primary_key=True)
    # Links to agents through the agent_skills association table.
    agents_relationships: Mapped[List["Agents"]] = relationship(AgentSkillsAssociation, back_populates="skill")
    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)
    type = Column(String, nullable=False)
    params = Column(JSON)

    created_at = Column(DateTime, default=datetime.datetime.now)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    deleted_at = Column(DateTime)  # soft delete marker

    def as_dict(self) -> Dict:
        """Serialize the skill for API responses (triggers lazy load of agents)."""
        return {
            "id": self.id,
            "name": self.name,
            "project_id": self.project_id,
            "agent_ids": [rel.agent.id for rel in self.agents_relationships],
            "type": self.type,
            "params": self.params,
            "created_at": self.created_at,
        }

445 

446 

class Agents(Base):
    """ORM model for an AI agent and its (deprecated) skill associations."""

    __tablename__ = "agents"
    id = Column(Integer, primary_key=True)
    # Links to skills through the agent_skills association table.
    skills_relationships: Mapped[List["Skills"]] = relationship(AgentSkillsAssociation, back_populates="agent")
    company_id = Column(String, nullable=True)
    user_class = Column(Integer, nullable=True)

    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)

    model_name = Column(String, nullable=True)
    provider = Column(String, nullable=True)
    params = Column(JSON)

    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
    deleted_at = Column(DateTime)  # soft delete marker

    def as_dict(self) -> Dict:
        """Serialize the agent for API responses.

        Auto-generated SQL skills are filtered out. When user-defined skills
        remain, params are returned as-is; otherwise selected keys ("data",
        "model", "prompt_template") are lifted out of params to the top level.
        """
        skills = []
        skills_extra_parameters = {}
        for rel in self.skills_relationships:
            skill = rel.skill
            # Skip auto-generated SQL skills
            # NOTE(review): assumes skill.params is a dict — raises if NULL; confirm.
            if skill.params.get("description", "").startswith("Auto-generated SQL skill for agent"):
                continue
            skills.append(skill.as_dict())
            skills_extra_parameters[skill.name] = rel.parameters or {}

        # Copy so the pops below don't mutate the mapped JSON value in place.
        # NOTE(review): assumes self.params is not None — confirm callers.
        params = self.params.copy()

        agent_dict = {
            "id": self.id,
            "name": self.name,
            "project_id": self.project_id,
            "updated_at": self.updated_at,
            "created_at": self.created_at,
        }

        if self.model_name:
            agent_dict["model_name"] = self.model_name

        if self.provider:
            agent_dict["provider"] = self.provider

        # Since skills were deprecated, they are only used with Minds.
        # Minds expects the parameters to be provided as is without breaking them down.
        if skills:
            agent_dict["skills"] = skills
            agent_dict["skills_extra_parameters"] = skills_extra_parameters
            agent_dict["params"] = params
        else:
            data = params.pop("data", {})
            model = params.pop("model", {})
            prompt_template = params.pop("prompt_template", None)
            if data:
                agent_dict["data"] = data
            if model:
                agent_dict["model"] = model
            if prompt_template:
                agent_dict["prompt_template"] = prompt_template
            # Only the leftover params (if any) are exposed.
            if params:
                agent_dict["params"] = params

        return agent_dict

512 

513 

class KnowledgeBase(Base):
    """ORM model for a knowledge base: a vector store plus embedding/reranking config."""

    __tablename__ = "knowledge_base"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)
    params = Column(JSON)

    vector_database_id = Column(
        ForeignKey("integration.id", name="fk_knowledge_base_vector_database_id"),
        doc="fk to the vector database integration",
    )
    vector_database = relationship(
        "Integration",
        foreign_keys=[vector_database_id],
        doc="vector database integration",
    )

    vector_database_table = Column(String, doc="table name in the vector database")

    embedding_model_id = Column(
        ForeignKey("predictor.id", name="fk_knowledge_base_embedding_model_id"),
        doc="fk to the embedding model",
    )

    embedding_model = relationship("Predictor", foreign_keys=[embedding_model_id], doc="embedding model")
    query_id = Column(Integer, nullable=True)

    created_at = Column(DateTime, default=datetime.datetime.now)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)

    __table_args__ = (UniqueConstraint("name", "project_id", name="unique_knowledge_base_name_project_id"),)

    def as_dict(self, with_secrets: Optional[bool] = True) -> Dict:
        """Serialize the knowledge base for API responses.

        Args:
            with_secrets: when False, "api_key"/"private_key" values inside the
                embedding and reranking model configs are masked with "******".
        """
        # Copy so the pops below don't mutate the mapped JSON value.
        # NOTE(review): assumes self.params is not None — confirm callers.
        params = self.params.copy()
        embedding_model = params.pop("embedding_model", None)
        reranking_model = params.pop("reranking_model", None)

        if not with_secrets:
            for key in ("api_key", "private_key"):
                for el in (embedding_model, reranking_model):
                    if el and key in el:
                        el[key] = "******"

        return {
            "id": self.id,
            "name": self.name,
            "project_id": self.project_id,
            "vector_database": None if self.vector_database is None else self.vector_database.name,
            "vector_database_table": self.vector_database_table,
            "updated_at": self.updated_at,
            "created_at": self.created_at,
            "query_id": self.query_id,
            "embedding_model": embedding_model,
            "reranking_model": reranking_model,
            "metadata_columns": params.pop("metadata_columns", None),
            "content_columns": params.pop("content_columns", None),
            "id_column": params.pop("id_column", None),
            "params": params,
        }

573 

574 

class QueryContext(Base):
    """ORM model storing per-query context values (keyed by query text + context name)."""

    __tablename__ = "query_context"
    id: int = Column(Integer, primary_key=True)
    # NOTE(review): annotated int but stored as String — confirm intended type.
    company_id: int = Column(String, nullable=True)

    query: str = Column(String, nullable=False)
    context_name: str = Column(String, nullable=False)
    values: dict = Column(JSON)

    updated_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)

586 

587 

class Queries(Base):
    """ORM model tracking executed user queries and their progress/outcome."""

    __tablename__ = "queries"
    id: int = Column(Integer, primary_key=True)
    # NOTE(review): annotated int but stored as String — confirm intended type.
    company_id: int = Column(String, nullable=True)

    sql: str = Column(String, nullable=False)
    database: str = Column(String, nullable=True)

    started_at: datetime.datetime = Column(DateTime)
    finished_at: datetime.datetime = Column(DateTime)

    # NOTE(review): mutable default {} — confirm each row gets its own dict.
    parameters = Column(JSON, default={})
    context = Column(JSON, default={})
    processed_rows = Column(Integer, default=0)
    error: str = Column(String, nullable=True)

    updated_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)

606 

607 

class LLMLog(Base):
    """ORM model logging a single LLM completion call (timing, tokens, cost, errors)."""

    __tablename__ = "llm_log"
    id: int = Column(Integer, primary_key=True)
    # NOTE(review): annotated int but stored as String — confirm intended type.
    company_id: int = Column(String, nullable=False)
    api_key: str = Column(String, nullable=True)
    model_id: int = Column(Integer, nullable=True)
    model_group: str = Column(String, nullable=True)
    input: str = Column(JSON, nullable=True)
    output: str = Column(JSON, nullable=True)
    # Annotations fixed: bare `datetime` referred to the module, not the class.
    start_time: datetime.datetime = Column(DateTime, nullable=False)
    end_time: datetime.datetime = Column(DateTime, nullable=True)
    cost: float = Column(Numeric(5, 2), nullable=True)
    prompt_tokens: int = Column(Integer, nullable=True)
    completion_tokens: int = Column(Integer, nullable=True)
    total_tokens: int = Column(Integer, nullable=True)
    success: bool = Column(Boolean, nullable=False, default=True)
    exception: str = Column(String, nullable=True)
    traceback: str = Column(String, nullable=True)
    stream: bool = Column(Boolean, default=False, comment="Is this completion done in 'streaming' mode")
    # DB column is "metadata" (reserved attribute name on declarative models).
    metadata_: dict = Column("metadata", JSON, nullable=True)

628 

629 

class LLMData(Base):
    """
    Stores the question/answer pairs of an LLM call so examples can be used
    for self improvement with DSPy
    """

    __tablename__ = "llm_data"
    id: int = Column(Integer, primary_key=True)
    input: str = Column(String, nullable=False)
    output: str = Column(String, nullable=False)
    model_id: int = Column(Integer, nullable=False)
    # Annotations fixed: bare `datetime` referred to the module, not the class.
    # NOTE(review): no default on updated_at, only onupdate — NULL until first update; confirm intended.
    created_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)
    updated_at: datetime.datetime = Column(DateTime, onupdate=datetime.datetime.now)