Coverage for mindsdb/interfaces/storage/db.py: 92% — 395 statements (coverage.py v7.13.1, created at 2026-01-21 00:36 +0000)
1import json
2import orjson
3import datetime
4import os
5from typing import Dict, List, Optional
7import numpy as np
8from sqlalchemy import (
9 JSON,
10 Boolean,
11 Column,
12 DateTime,
13 Index,
14 Integer,
15 LargeBinary,
16 Numeric,
17 String,
18 UniqueConstraint,
19 create_engine,
20 text,
21 types,
22)
23from sqlalchemy.exc import OperationalError
24from sqlalchemy.orm import (
25 Mapped,
26 mapped_column,
27 declarative_base,
28 relationship,
29 scoped_session,
30 sessionmaker,
31)
32from sqlalchemy.sql.schema import ForeignKey
33from mind_castle.sqlalchemy_type import SecretData
35from mindsdb.utilities.json_encoder import CustomJSONEncoder
36from mindsdb.utilities.config import config
class Base:
    # Permit legacy (non-Mapped) class attribute annotations on declarative
    # models (required by SQLAlchemy 2.x for the plain `x: type = Column(...)`
    # style used throughout this module).
    __allow_unmapped__ = True


# Replace the plain class with a declarative base that mixes it in; all ORM
# models below inherit from this.
Base = declarative_base(cls=Base)

# Module-level handles; populated by init() before any DB access.
session, engine = None, None
def init(connection_str: str = None):
    """Create the engine and scoped session for the storage database.

    Args:
        connection_str (str): SQLAlchemy connection URI; when None, the
            value of config["storage_db"] is used.

    Side effects: rebinds the module-level `engine` and `session` globals
    and attaches a query property to `Base`.
    """
    global Base, session, engine
    if connection_str is None:
        connection_str = config["storage_db"]

    # JSON columns are serialized with orjson; CustomJSONEncoder.default
    # handles any types orjson does not serialize natively.
    _default_json = CustomJSONEncoder().default

    def _json_serializer(value):
        encoded = orjson.dumps(
            value,
            default=_default_json,
            option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_PASSTHROUGH_DATETIME,
        )
        return encoded.decode("utf-8")

    engine = create_engine(
        connection_str,
        echo=False,
        pool_size=30,
        max_overflow=200,
        json_serializer=_json_serializer,
    )
    session = scoped_session(sessionmaker(bind=engine, autoflush=True))
    Base.query = session.query_property()
def serializable_insert(record: Base, try_count: int = 100):
    """Insert a record under SERIALIZABLE isolation, retrying on conflict.

    Args:
        record (Base): SQLAlchemy record to insert.
        try_count (int): number of attempts before re-raising the error.
    """
    while True:
        session.connection(execution_options={"isolation_level": "SERIALIZABLE"})
        if engine.name == "postgresql":
            session.execute(text("LOCK TABLE PREDICTOR IN EXCLUSIVE MODE"))
        session.add(record)
        try:
            session.commit()
        except OperationalError:
            # catch 'SerializationFailure' (it should be in str(e), but it may depend on engine)
            session.rollback()
            try_count -= 1
            if try_count == 0:
                raise
        else:
            return
# Source: https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
class NumpyEncoder(json.JSONEncoder):
    """json.JSONEncoder that converts numpy scalars/arrays to builtins."""

    def default(self, obj):
        # Arrays become (nested) lists; scalar types become Python numbers.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # Anything else: defer to the base class (which raises TypeError).
        return super().default(obj)
class Array(types.TypeDecorator):
    """Persists a list of strings as one delimited String column.

    The delimiter ",|,|," is unlikely to occur in values, so plain commas
    inside elements survive the round trip.
    """

    impl = types.String

    def process_bind_param(self, value, dialect):  # insert
        # Strings and None pass through unchanged; other iterables are joined.
        if isinstance(value, str) or value is None:
            return value
        return ",|,|,".join(value)

    def process_result_value(self, value, dialect):  # select
        if value is None:
            return None
        return value.split(",|,|,")
class Json(types.TypeDecorator):
    """Persists JSON-serializable values in a String column via NumpyEncoder."""

    impl = types.String

    def process_bind_param(self, value, dialect):  # insert
        if value is None:
            return None
        return json.dumps(value, cls=NumpyEncoder)

    def process_result_value(self, value, dialect):  # select
        # Some drivers/dialects may hand back an already-decoded dict.
        if isinstance(value, dict):
            return value
        if value is None:
            return None
        return json.loads(value)
# Use MindsDB's "Json" column type as a backend for mind-castle
class SecretDataJson(SecretData):
    """mind-castle SecretData column type backed by the Json type above."""

    impl = Json
class PREDICTOR_STATUS:
    """Namespace of lifecycle states for Predictor.status."""

    # No instance state — this is a read-only constants holder.
    __slots__ = ()
    COMPLETE = "complete"
    TRAINING = "training"
    FINETUNING = "finetuning"
    GENERATING = "generating"
    ERROR = "error"
    VALIDATION = "validation"
    DELETED = "deleted"  # TODO remove it?


# Shadow the class with a singleton instance used as the constants namespace.
PREDICTOR_STATUS = PREDICTOR_STATUS()
class Predictor(Base):
    """An ML model record; versioned via (name, version, active)."""

    __tablename__ = "predictor"

    id = Column(Integer, primary_key=True)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
    deleted_at = Column(DateTime)  # soft-delete marker
    name = Column(String)
    data = Column(Json)  # A JSON -- should be everything returned by `get_model_data`, I think
    to_predict = Column(Array)  # target column names
    company_id = Column(String)
    mindsdb_version = Column(String)
    native_version = Column(String)
    integration_id = Column(ForeignKey("integration.id", name="fk_integration_id"))
    data_integration_ref = Column(Json)
    fetch_data_query = Column(String)
    learn_args = Column(Json)
    update_status = Column(String, default="up_to_date")
    status = Column(String)  # one of PREDICTOR_STATUS values — TODO confirm all writers use it
    active = Column(Boolean, default=True)
    training_data_columns_count = Column(Integer)
    training_data_rows_count = Column(Integer)
    training_start_at = Column(DateTime)
    training_stop_at = Column(DateTime)
    label = Column(String, nullable=True)
    version = Column(Integer, default=1)
    code = Column(String, nullable=True)
    lightwood_version = Column(String, nullable=True)
    dtype_dict = Column(Json, nullable=True)
    project_id = Column(Integer, ForeignKey("project.id", name="fk_project_id"), nullable=False)
    training_phase_current = Column(Integer)
    training_phase_total = Column(Integer)
    training_phase_name = Column(String)
    training_metadata = Column(JSON, default={}, nullable=False)

    @staticmethod
    def get_name_and_version(full_name):
        """Split a possibly version-suffixed model name.

        Args:
            full_name (str): model name, optionally suffixed with ".<digits>".

        Returns:
            tuple: (name without the version suffix, int version or None).
        """
        name_no_version = full_name
        version = None
        parts = full_name.split(".")
        if len(parts) > 1 and parts[-1].isdigit():
            version = int(parts[-1])
            name_no_version = ".".join(parts[:-1])
        return name_no_version, version
# Unique composite index over predictor identity columns so only one row per
# (company, name, version, active, deleted_at) combination exists.
Index(
    "predictor_index",
    Predictor.company_id,
    Predictor.name,
    Predictor.version,
    Predictor.active,
    Predictor.deleted_at,  # would be good to have here nullsfirst(Predictor.deleted_at)
    unique=True,
)
class Project(Base):
    """A namespace grouping models/views/jobs; name unique per company."""

    __tablename__ = "project"

    id = Column(Integer, primary_key=True)
    created_at = Column(DateTime, default=datetime.datetime.now)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    deleted_at = Column(DateTime)  # soft-delete marker
    name = Column(String, nullable=False)
    company_id = Column(String, default="0")
    # Stored under SQL column name "metadata"; the trailing underscore avoids
    # SQLAlchemy's reserved declarative attribute `metadata`.
    metadata_: dict = Column("metadata", JSON, nullable=True)

    __table_args__ = (UniqueConstraint("name", "company_id", name="unique_project_name_company_id"),)
class Integration(Base):
    """A configured data-source connection; name unique per company."""

    __tablename__ = "integration"
    id = Column(Integer, primary_key=True)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
    name = Column(String, nullable=False)
    engine = Column(String, nullable=False)  # handler/engine identifier
    # Connection parameters; mind-castle secret backend is selected by the
    # MINDSDB_DATA_ENCRYPTION_TYPE env var at import time (default "none").
    data = Column(SecretDataJson(os.environ.get("MINDSDB_DATA_ENCRYPTION_TYPE", "none")))
    company_id = Column(String)

    __table_args__ = (UniqueConstraint("name", "company_id", name="unique_integration_name_company_id"),)
class File(Base):
    """Metadata for an uploaded file; name unique per company."""

    __tablename__ = "file"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    company_id = Column(String)
    source_file_path = Column(String, nullable=False)  # original path as uploaded
    file_path = Column(String, nullable=False)  # path in storage
    row_count = Column(Integer, nullable=False)
    columns = Column(Json, nullable=False)
    created_at = Column(DateTime, default=datetime.datetime.now)
    # Stored under SQL column name "metadata" (reserved attribute on declarative Base).
    metadata_: dict = Column("metadata", JSON, nullable=True)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    __table_args__ = (UniqueConstraint("name", "company_id", name="unique_file_name_company_id"),)
class View(Base):
    """A saved SQL view belonging to a project; name unique per company."""

    __tablename__ = "view"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    company_id = Column(String)
    query = Column(String, nullable=False)  # the view's SQL text
    project_id = Column(Integer, ForeignKey("project.id", name="fk_project_id"), nullable=False)
    __table_args__ = (UniqueConstraint("name", "company_id", name="unique_view_name_company_id"),)
class JsonStorage(Base):
    """Generic JSON blob storage keyed by (resource_group, resource_id, name)."""

    __tablename__ = "json_storage"
    id = Column(Integer, primary_key=True)
    resource_group = Column(String)
    resource_id = Column(Integer)
    name = Column(String)
    content = Column(JSON)
    # Alternative encrypted payload; both columns are nullable, so either
    # content or encrypted_content may be set — TODO confirm with writers.
    encrypted_content = Column(LargeBinary, nullable=True)
    company_id = Column(String)

    def to_dict(self) -> Dict:
        """Return every column as a plain dict."""
        return {
            "id": self.id,
            "resource_group": self.resource_group,
            "resource_id": self.resource_id,
            "name": self.name,
            "content": self.content,
            "encrypted_content": self.encrypted_content,
            "company_id": self.company_id,
        }
class Jobs(Base):
    """A scheduled query job; executions are recorded in JobsHistory."""

    __tablename__ = "jobs"
    id = Column(Integer, primary_key=True)
    company_id = Column(String)
    user_class = Column(Integer, nullable=True)
    active = Column(Boolean, default=True)

    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)
    query_str = Column(String, nullable=False)  # query executed on each run
    # presumably a condition query gating each run — confirm with the scheduler
    if_query_str = Column(String, nullable=True)
    start_at = Column(DateTime, default=datetime.datetime.now)
    end_at = Column(DateTime)
    next_run_at = Column(DateTime)
    schedule_str = Column(String)  # schedule specification as text

    deleted_at = Column(DateTime)  # soft-delete marker
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
class JobsHistory(Base):
    """One execution record of a Jobs entry; unique per (job_id, start_at)."""

    __tablename__ = "jobs_history"
    id = Column(Integer, primary_key=True)
    company_id = Column(String)

    job_id = Column(Integer)

    query_str = Column(String)
    start_at = Column(DateTime)
    end_at = Column(DateTime)

    error = Column(String)
    created_at = Column(DateTime, default=datetime.datetime.now)
    # NOTE(review): unlike sibling tables there is no onupdate here, so this
    # only records creation time — confirm whether that is intentional.
    updated_at = Column(DateTime, default=datetime.datetime.now)

    __table_args__ = (UniqueConstraint("job_id", "start_at", name="uniq_job_history_job_id_start"),)
class ChatBots(Base):
    """A chatbot bound to a project and (going forward) an agent."""

    __tablename__ = "chat_bots"
    id = Column(Integer, primary_key=True)

    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)
    agent_id = Column(ForeignKey("agents.id", name="fk_agent_id"))

    # To be removed when existing chatbots are backfilled with newly created Agents.
    model_name = Column(String)
    database_id = Column(Integer)
    params = Column(JSON)

    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
    webhook_token = Column(String)

    def as_dict(self) -> Dict:
        """Serialize the chatbot for API responses (updated_at is omitted)."""
        return {
            "id": self.id,
            "name": self.name,
            "project_id": self.project_id,
            "agent_id": self.agent_id,
            "model_name": self.model_name,
            "params": self.params,
            "webhook_token": self.webhook_token,
            "created_at": self.created_at,
            "database_id": self.database_id,
        }
class ChatBotsHistory(Base):
    """Message log for a chatbot conversation."""

    __tablename__ = "chat_bots_history"
    id = Column(Integer, primary_key=True)
    chat_bot_id = Column(Integer, nullable=False)
    type = Column(String)  # TODO replace to enum
    text = Column(String)
    user = Column(String)
    destination = Column(String)
    sent_at = Column(DateTime, default=datetime.datetime.now)
    error = Column(String)
class Triggers(Base):
    """A table-watch trigger; presumably runs query_str on changes to the
    watched table — confirm with the trigger runner."""

    __tablename__ = "triggers"
    id = Column(Integer, primary_key=True)

    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)

    database_id = Column(Integer, nullable=False)
    table_name = Column(String, nullable=False)
    query_str = Column(String, nullable=False)
    columns = Column(String)  # list of columns separated by delimiter

    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
class Tasks(Base):
    """A runnable unit (trigger/chatbot) that worker processes claim."""

    __tablename__ = "tasks"
    id = Column(Integer, primary_key=True)
    company_id = Column(String)
    user_class = Column(Integer, nullable=True)

    # trigger, chatbot
    object_type = Column(String, nullable=False)
    object_id = Column(Integer, nullable=False)

    last_error = Column(String)
    active = Column(Boolean, default=True)
    reload = Column(Boolean, default=False)

    # for running in concurrent processes
    run_by = Column(String)  # identifier of the process currently running this task
    alive_time = Column(DateTime(timezone=True))  # heartbeat timestamp — confirm writer

    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
class AgentSkillsAssociation(Base):
    """Many-to-many link between Agents and Skills with per-link parameters."""

    __tablename__ = "agent_skills"

    agent_id: Mapped[int] = mapped_column(ForeignKey("agents.id"), primary_key=True)
    skill_id: Mapped[int] = mapped_column(ForeignKey("skills.id"), primary_key=True)
    parameters: Mapped[dict] = mapped_column(JSON, default={}, nullable=True)

    agent = relationship("Agents", back_populates="skills_relationships")
    skill = relationship("Skills", back_populates="agents_relationships")
class Skills(Base):
    """A capability attachable to agents (many-to-many via agent_skills)."""

    __tablename__ = "skills"
    id = Column(Integer, primary_key=True)
    agents_relationships: Mapped[List["Agents"]] = relationship(AgentSkillsAssociation, back_populates="skill")
    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)
    type = Column(String, nullable=False)
    params = Column(JSON)

    created_at = Column(DateTime, default=datetime.datetime.now)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    deleted_at = Column(DateTime)  # soft-delete marker

    def as_dict(self) -> Dict:
        """Serialize the skill, including the ids of all agents using it."""
        return {
            "id": self.id,
            "name": self.name,
            "project_id": self.project_id,
            "agent_ids": [rel.agent.id for rel in self.agents_relationships],
            "type": self.type,
            "params": self.params,
            "created_at": self.created_at,
        }
class Agents(Base):
    """An AI agent: model/provider configuration plus optional skills."""

    __tablename__ = "agents"
    id = Column(Integer, primary_key=True)
    skills_relationships: Mapped[List["Skills"]] = relationship(AgentSkillsAssociation, back_populates="agent")
    company_id = Column(String, nullable=True)
    user_class = Column(Integer, nullable=True)

    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)

    model_name = Column(String, nullable=True)
    provider = Column(String, nullable=True)
    params = Column(JSON)

    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at = Column(DateTime, default=datetime.datetime.now)
    deleted_at = Column(DateTime)  # soft-delete marker

    def as_dict(self) -> Dict:
        """Serialize the agent for API responses.

        Returns:
            Dict: base attributes plus either "skills"/"skills_extra_parameters"
                (when the agent has non-auto-generated skills) or
                "data"/"model"/"prompt_template" flattened out of params.
        """
        skills = []
        skills_extra_parameters = {}
        for rel in self.skills_relationships:
            skill = rel.skill
            # Skip auto-generated SQL skills.
            # Fix: `params` is a nullable JSON column — guard against None
            # before calling .get(), which previously raised AttributeError.
            description = (skill.params or {}).get("description", "")
            if description.startswith("Auto-generated SQL skill for agent"):
                continue
            skills.append(skill.as_dict())
            skills_extra_parameters[skill.name] = rel.parameters or {}

        # Copy (and guard against NULL) so the pops below don't mutate the
        # mapped JSON value on the ORM instance.
        params = (self.params or {}).copy()

        agent_dict = {
            "id": self.id,
            "name": self.name,
            "project_id": self.project_id,
            "updated_at": self.updated_at,
            "created_at": self.created_at,
        }

        if self.model_name:
            agent_dict["model_name"] = self.model_name

        if self.provider:
            agent_dict["provider"] = self.provider

        # Since skills were deprecated, they are only used with Minds.
        # Minds expects the parameters to be provided as is without breaking them down.
        if skills:
            agent_dict["skills"] = skills
            agent_dict["skills_extra_parameters"] = skills_extra_parameters
            agent_dict["params"] = params
        else:
            data = params.pop("data", {})
            model = params.pop("model", {})
            prompt_template = params.pop("prompt_template", None)
            if data:
                agent_dict["data"] = data
            if model:
                agent_dict["model"] = model
            if prompt_template:
                agent_dict["prompt_template"] = prompt_template
            if params:
                agent_dict["params"] = params

        return agent_dict
class KnowledgeBase(Base):
    """A knowledge base: vector store + embedding model + free-form params."""

    __tablename__ = "knowledge_base"
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    project_id = Column(Integer, nullable=False)
    params = Column(JSON)

    vector_database_id = Column(
        ForeignKey("integration.id", name="fk_knowledge_base_vector_database_id"),
        doc="fk to the vector database integration",
    )
    vector_database = relationship(
        "Integration",
        foreign_keys=[vector_database_id],
        doc="vector database integration",
    )

    vector_database_table = Column(String, doc="table name in the vector database")

    embedding_model_id = Column(
        ForeignKey("predictor.id", name="fk_knowledge_base_embedding_model_id"),
        doc="fk to the embedding model",
    )

    embedding_model = relationship("Predictor", foreign_keys=[embedding_model_id], doc="embedding model")
    query_id = Column(Integer, nullable=True)

    created_at = Column(DateTime, default=datetime.datetime.now)
    updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)

    __table_args__ = (UniqueConstraint("name", "project_id", name="unique_knowledge_base_name_project_id"),)

    def as_dict(self, with_secrets: Optional[bool] = True) -> Dict:
        """Serialize the knowledge base.

        Args:
            with_secrets (bool): when False, api_key/private_key entries in
                the embedding and reranking model configs are masked.

        Returns:
            Dict: serialized attributes; well-known keys are popped out of
                params so "params" holds only the remainder.
        """
        # Fix: `params` is a nullable JSON column — guard against None before
        # .copy(), which previously raised AttributeError. The copy keeps the
        # pops below from mutating the mapped value.
        params = (self.params or {}).copy()
        embedding_model = params.pop("embedding_model", None)
        reranking_model = params.pop("reranking_model", None)

        if not with_secrets:
            # NOTE(review): this masks in place on the nested dicts, which are
            # shared with self.params (shallow copy) — confirm callers never
            # flush the ORM object after a with_secrets=False serialization.
            for key in ("api_key", "private_key"):
                for el in (embedding_model, reranking_model):
                    if el and key in el:
                        el[key] = "******"

        return {
            "id": self.id,
            "name": self.name,
            "project_id": self.project_id,
            "vector_database": None if self.vector_database is None else self.vector_database.name,
            "vector_database_table": self.vector_database_table,
            "updated_at": self.updated_at,
            "created_at": self.created_at,
            "query_id": self.query_id,
            "embedding_model": embedding_model,
            "reranking_model": reranking_model,
            "metadata_columns": params.pop("metadata_columns", None),
            "content_columns": params.pop("content_columns", None),
            "id_column": params.pop("id_column", None),
            "params": params,
        }
class QueryContext(Base):
    """Key/value context stored per (query, context_name)."""

    __tablename__ = "query_context"
    id: int = Column(Integer, primary_key=True)
    company_id: str = Column(String, nullable=True)

    query: str = Column(String, nullable=False)
    context_name: str = Column(String, nullable=False)
    values: dict = Column(JSON)

    updated_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)
class Queries(Base):
    """A record of an executed (possibly long-running) SQL query."""

    __tablename__ = "queries"
    id: int = Column(Integer, primary_key=True)
    company_id: str = Column(String, nullable=True)

    sql: str = Column(String, nullable=False)
    database: str = Column(String, nullable=True)

    started_at: datetime.datetime = Column(DateTime)
    finished_at: datetime.datetime = Column(DateTime)

    parameters = Column(JSON, default={})
    context = Column(JSON, default={})
    processed_rows = Column(Integer, default=0)
    error: str = Column(String, nullable=True)

    updated_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    created_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)
class LLMLog(Base):
    """One LLM completion call: request, response, timing, tokens, cost."""

    __tablename__ = "llm_log"
    id: int = Column(Integer, primary_key=True)
    company_id: str = Column(String, nullable=False)
    api_key: str = Column(String, nullable=True)
    model_id: int = Column(Integer, nullable=True)
    model_group: str = Column(String, nullable=True)
    input: str = Column(JSON, nullable=True)
    output: str = Column(JSON, nullable=True)
    start_time: datetime.datetime = Column(DateTime, nullable=False)
    end_time: datetime.datetime = Column(DateTime, nullable=True)
    # Numeric(5, 2) caps cost at 999.99 — presumably in USD, confirm with writers.
    cost: float = Column(Numeric(5, 2), nullable=True)
    prompt_tokens: int = Column(Integer, nullable=True)
    completion_tokens: int = Column(Integer, nullable=True)
    total_tokens: int = Column(Integer, nullable=True)
    success: bool = Column(Boolean, nullable=False, default=True)
    exception: str = Column(String, nullable=True)
    traceback: str = Column(String, nullable=True)
    stream: bool = Column(Boolean, default=False, comment="Is this completion done in 'streaming' mode")
    # Stored under SQL column name "metadata" (reserved attribute on declarative Base).
    metadata_: dict = Column("metadata", JSON, nullable=True)
class LLMData(Base):
    """
    Stores the question/answer pairs of an LLM call so examples can be used
    for self improvement with DSPy
    """

    __tablename__ = "llm_data"
    id: int = Column(Integer, primary_key=True)
    input: str = Column(String, nullable=False)
    output: str = Column(String, nullable=False)
    model_id: int = Column(Integer, nullable=False)
    created_at: datetime.datetime = Column(DateTime, default=datetime.datetime.now)
    # NOTE(review): no default — this stays NULL until the row is first updated.
    updated_at: datetime.datetime = Column(DateTime, onupdate=datetime.datetime.now)