Coverage for mindsdb / interfaces / chatbot / memory.py: 23%
113 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
2from typing import Union
4from mindsdb_sql_parser.ast import Identifier, Select, BinaryOperation, Constant, OrderBy
6from mindsdb.interfaces.storage import db
7from .types import ChatBotMessage
class BaseMemory:
    '''
    base class to work with chatbot memory

    Keeps per-process state (chat modes, hide marks, a history cache) and
    delegates persistent storage to `_add_to_history` / `_get_chat_history`,
    which subclasses must implement.
    '''
    # upper bound on the number of messages a backend returns for one chat
    MAX_DEPTH = 100

    def __init__(self, chat_task, chat_params):
        # in memory yet
        self._modes = {}
        self._hide_history_before = {}
        self._cache = {}
        self.chat_params = chat_params
        self.chat_task = chat_task

    @staticmethod
    def _get_key(chat_id, table_name=None):
        '''
        Build the key used for the history cache and the hide marks.

        Always a (chat_id, table_name) pair, so that storing, reading and
        invalidating use the same key, and a tuple chat_id without a table
        cannot collide with a scalar chat_id that has a table.
        '''
        return (chat_id, table_name)

    def get_chat(self, chat_id, table_name=None):
        return ChatMemory(self, chat_id, table_name=table_name)

    def hide_history(self, chat_id, left_count, table_name=None):
        '''
        set date to start hiding messages

        Only the `left_count` most recent currently-visible messages stay
        visible. If `left_count` is greater than or equal to the number of
        visible messages (or is not positive), nothing new is hidden.
        An empty history leaves the hide mark untouched.
        '''
        history = self.get_chat_history(chat_id, table_name=table_name)
        if not history:
            # nothing to anchor the hide mark on
            return

        # clamp to the available messages: history[-len] is the first message,
        # so asking to keep "more than exist" keeps everything
        left_count = max(0, min(left_count, len(history)))
        # note: history[-0] == history[0], i.e. a non-positive count hides nothing
        sent_at = history[-left_count].sent_at

        self._hide_history_before[self._get_key(chat_id, table_name)] = sent_at

    def _apply_hiding(self, chat_id, history, table_name=None):
        '''
        hide messages from history

        Returns only the messages sent at or after the hide mark of this
        chat/table, or `history` unchanged when no mark is set.
        '''
        before = self._hide_history_before.get(self._get_key(chat_id, table_name))

        if before is None or not history:
            return history

        return [
            msg
            for msg in history
            if msg.sent_at >= before
        ]

    def get_mode(self, chat_id):
        return self._modes.get(chat_id)

    def set_mode(self, chat_id, mode):
        self._modes[chat_id] = mode

    def add_to_history(self, chat_id, chat_message, table_name=None):
        '''
        Store a message via the backend and drop the cached history of
        that chat/table, so the next read sees the new message.
        '''
        self._add_to_history(
            chat_id,
            chat_message,
            table_name=table_name
        )
        # must use the same key as get_chat_history, otherwise the cache of
        # table-bound chats is never invalidated
        self._cache.pop(self._get_key(chat_id, table_name), None)

    def get_chat_history(self, chat_id, table_name=None, cached=True):
        '''
        Return the visible history of a chat.

        When `cached` is true and a cached copy exists it is reused;
        otherwise the backend is queried and the result cached. Hide marks
        are applied on every call.
        '''
        key = self._get_key(chat_id, table_name)
        if cached and key in self._cache:
            history = self._cache[key]

        else:
            history = self._get_chat_history(
                chat_id,
                table_name
            )
            self._cache[key] = history

        history = self._apply_hiding(chat_id, history, table_name)
        return history

    def _add_to_history(self, chat_id, chat_message, table_name=None):
        raise NotImplementedError

    def _get_chat_history(self, chat_id, table_name=None):
        raise NotImplementedError
class HandlerMemory(BaseMemory):
    '''
    Uses handler's database to store and retrieve messages
    '''

    def _add_to_history(self, chat_id, chat_message, table_name=None):
        # do nothing. sent message will be stored by handler db
        pass

    def _get_chat_history(self, chat_id, table_name):
        '''
        Read the most recent messages of a chat from the handler's chat table.

        Args:
            chat_id: scalar id, or a tuple with one value per `chat_id_col`
                of the chat table (matched positionally).
            table_name: name of the chat table, looked up in `self.chat_params`.

        Returns:
            list[ChatBotMessage]: up to MAX_DEPTH newest messages in
            chronological order (oldest first); empty list when the handler
            returns no data frame.

        Raises:
            ValueError: when no chat table with `table_name` is configured.
        '''
        t_params = next(
            (
                chat_params['chat_table']
                for chat_params in self.chat_params
                if chat_params['chat_table']['name'] == table_name
            ),
            None,
        )
        if t_params is None:
            raise ValueError(f'Chat table is not found in chatbot parameters: {table_name}')

        text_col = t_params['text_col']
        username_col = t_params['username_col']
        time_col = t_params['time_col']
        chat_id_cols = t_params['chat_id_col'] if isinstance(t_params['chat_id_col'], list) else [t_params['chat_id_col']]

        chat_id = chat_id if isinstance(chat_id, tuple) else (chat_id,)
        # Add a WHERE clause for each chat_id column.
        where_conditions = [
            BinaryOperation(
                op='=',
                args=[
                    Identifier(chat_id_col),
                    Constant(chat_id[idx])
                ]
            ) for idx, chat_id_col in enumerate(chat_id_cols)
        ]
        # Add a WHERE clause to ignore holding messages from the bot.
        # NOTE(review): imported locally, presumably to avoid a circular import
        # with chatbot_task — confirm before moving to module level.
        from .chatbot_task import HOLDING_MESSAGE

        where_conditions.append(
            BinaryOperation(
                op='!=',
                args=[
                    Identifier(text_col),
                    Constant(HOLDING_MESSAGE)
                ]
            )
        )

        # Combine the WHERE conditions with AND into a single expression.
        where_conditions_binary_operation = None
        for condition in where_conditions:
            if where_conditions_binary_operation is None:
                where_conditions_binary_operation = condition
            else:
                where_conditions_binary_operation = BinaryOperation('and', args=[where_conditions_binary_operation, condition])

        ast_query = Select(
            targets=[Identifier(text_col),
                     Identifier(username_col),
                     Identifier(time_col)],
            from_table=Identifier(t_params['name']),
            where=where_conditions_binary_operation,
            # newest first: combined with LIMIT this keeps the LATEST messages
            # (ascending order + LIMIT would keep the oldest ones)
            order_by=[OrderBy(Identifier(time_col), direction='DESC')],
            limit=Constant(self.MAX_DEPTH),
        )

        resp = self.chat_task.chat_handler.query(ast_query)
        if resp.data_frame is None:
            return []

        # restore chronological order; sorting here (rather than only
        # reversing) also copes with handlers that ignore ORDER BY / LIMIT
        df = resp.data_frame.sort_values(by=time_col, kind='stable')

        # get last messages
        df = df.iloc[-self.MAX_DEPTH:]

        result = []
        for _, rec in df.iterrows():
            chatbot_message = ChatBotMessage(
                ChatBotMessage.Type.DIRECT,
                rec[text_col],
                user=rec[username_col],
                sent_at=rec[time_col]
            )
            result.append(chatbot_message)

        return result
class DBMemory(BaseMemory):
    '''
    uses mindsdb database to store messages
    '''

    def _generate_chat_id_for_db(self, chat_id: Union[str, tuple], table_name: str = None) -> str:
        """
        Generate an ID for the chat to store in the database.
        The ID is a string that includes the components of the chat ID and the table name (if provided) separated by underscores.

        Args:
            chat_id (str | tuple): The ID of the chat.
            table_name (str): The name of the table the chat belongs to.

        Returns:
            str: e.g. "a_b" for chat_id ("a", "b"), or "tbl_a_b" when
            table_name is "tbl".
        """
        if isinstance(chat_id, tuple):
            chat_id_str = "_".join(str(val) for val in chat_id)
        else:
            chat_id_str = str(chat_id)

        # previously the result was only bound inside this branch, so a chat
        # without a table raised UnboundLocalError
        if table_name:
            chat_id_str = f"{table_name}_{chat_id_str}"

        return chat_id_str

    def _add_to_history(self, chat_id, message, table_name=None):
        '''
        Persist a message as a ChatBotsHistory row and commit the session.
        '''
        chat_bot_id = self.chat_task.bot_id
        destination = self._generate_chat_id_for_db(chat_id, table_name)

        # separate name so the incoming `message` is not shadowed by the ORM row
        record = db.ChatBotsHistory(
            chat_bot_id=chat_bot_id,
            type=message.type.name,
            text=message.text,
            user=message.user,
            destination=destination,
        )
        db.session.add(record)
        db.session.commit()

    def _get_chat_history(self, chat_id, table_name=None):
        '''
        Load up to MAX_DEPTH newest messages of the chat, oldest first.
        '''
        chat_bot_id = self.chat_task.bot_id
        destination = self._generate_chat_id_for_db(chat_id, table_name)

        # newest first so that LIMIT keeps the latest messages
        query = db.ChatBotsHistory.query\
            .filter(
                db.ChatBotsHistory.chat_bot_id == chat_bot_id,
                db.ChatBotsHistory.destination == destination
            )\
            .order_by(db.ChatBotsHistory.sent_at.desc())\
            .limit(self.MAX_DEPTH)

        result = [
            ChatBotMessage(
                rec.type,
                rec.text,
                rec.user,
                sent_at=rec.sent_at,
            )
            for rec in query
        ]
        # back to chronological order
        result.reverse()
        return result
class ChatMemory:
    """
    Facade over a memory backend bound to one conversation.

    Stores the chat id and optional table name so that callers can read and
    write a single chat without repeating those identifiers; every method
    forwards to the underlying memory object.
    """

    def __init__(self, memory, chat_id, table_name=None):
        self.memory = memory
        self.chat_id = chat_id
        self.table_name = table_name
        # flips to True after the first get_history(); the very first read
        # therefore always bypasses the backend cache
        self.cached = False

    def get_history(self, cached=True):
        """Return the chat history, reusing the cache only after the first read."""
        use_cache = cached and self.cached
        history = self.memory.get_chat_history(self.chat_id, self.table_name, cached=use_cache)
        self.cached = True
        return history

    def add_to_history(self, message):
        """Append `message` to this chat via the backend."""
        self.memory.add_to_history(self.chat_id, message, table_name=self.table_name)

    def get_mode(self):
        """Return the backend's current mode for this chat."""
        return self.memory.get_mode(self.chat_id)

    def set_mode(self, mode):
        """Store `mode` as the backend's mode for this chat."""
        self.memory.set_mode(self.chat_id, mode)

    def hide_history(self, left_count):
        """
        set date to start hiding messages

        Forwards `left_count` (number of recent messages to keep visible)
        to the backend for this chat/table.
        """
        self.memory.hide_history(self.chat_id, left_count, table_name=self.table_name)