Coverage for mindsdb / interfaces / chatbot / memory.py: 23%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1 

2from typing import Union 

3 

4from mindsdb_sql_parser.ast import Identifier, Select, BinaryOperation, Constant, OrderBy 

5 

6from mindsdb.interfaces.storage import db 

7from .types import ChatBotMessage 

8 

9 

class BaseMemory:
    '''
    base class to work with chatbot memory

    Keeps per-chat state in process memory (modes, hiding thresholds and a
    history cache). Subclasses implement the storage-specific
    `_add_to_history` and `_get_chat_history`.
    '''
    # maximum number of messages the subclasses load per chat
    MAX_DEPTH = 100

    def __init__(self, chat_task, chat_params):
        # in memory yet
        self._modes = {}  # chat_id -> mode
        self._hide_history_before = {}  # chat_id -> sent_at threshold; older messages are hidden
        self._cache = {}  # _cache_key(chat_id, table_name) -> loaded history
        self.chat_params = chat_params
        self.chat_task = chat_task

    @staticmethod
    def _cache_key(chat_id, table_name=None):
        '''
        Key under which a chat's history is cached.

        Shared by the reader (get_chat_history) and the invalidator
        (add_to_history) so that both always address the same cache entry.
        '''
        return (chat_id, table_name) if table_name else chat_id

    def get_chat(self, chat_id, table_name=None):
        '''
        Return a ChatMemory bound to this memory, chat_id and table_name.
        '''
        return ChatMemory(self, chat_id, table_name=table_name)

    def hide_history(self, chat_id, left_count, table_name=None):
        '''
        set date to start hiding messages

        Records the sent_at of the message `left_count` positions from the end
        of the (already hidden-filtered) history; later reads drop anything
        sent before it. left_count is capped at len(history) - 1.
        Does nothing when the history is empty or missing.
        '''
        history = self.get_chat_history(chat_id, table_name=table_name)
        if not history:
            # nothing to hide; indexing an empty history would raise IndexError
            return

        left_count = min(left_count, len(history) - 1)
        # NOTE(review): left_count == 0 resolves to history[0], i.e. nothing gets hidden — confirm intended
        sent_at = history[-left_count].sent_at

        self._hide_history_before[chat_id] = sent_at

    def _apply_hiding(self, chat_id, history):
        '''
        hide messages from history

        Returns only messages with sent_at >= the threshold set by
        hide_history; returns `history` unchanged when no threshold is set
        or when there is no history (None).
        '''
        before = self._hide_history_before.get(chat_id)

        if before is None or history is None:
            return history

        return [
            msg
            for msg in history
            if msg.sent_at >= before
        ]

    def get_mode(self, chat_id):
        '''Return the mode stored for chat_id, or None.'''
        return self._modes.get(chat_id)

    def set_mode(self, chat_id, mode):
        '''Store the mode for chat_id.'''
        self._modes[chat_id] = mode

    def add_to_history(self, chat_id, chat_message, table_name=None):
        '''
        Persist a message via the subclass and invalidate the cached history.
        '''
        self._add_to_history(
            chat_id,
            chat_message,
            table_name=table_name
        )
        # Must use the same key as get_chat_history, otherwise chats scoped to a
        # table would keep returning stale cached history.
        self._cache.pop(self._cache_key(chat_id, table_name), None)

    def get_chat_history(self, chat_id, table_name=None, cached=True):
        '''
        Return the chat history with hiding applied.

        When `cached` is true and a cached copy exists it is reused; otherwise
        the history is loaded via `_get_chat_history` and cached.
        '''
        key = self._cache_key(chat_id, table_name)
        if cached and key in self._cache:
            history = self._cache[key]
        else:
            history = self._get_chat_history(
                chat_id,
                table_name
            )
            self._cache[key] = history

        return self._apply_hiding(chat_id, history)

    def _add_to_history(self, chat_id, chat_message, table_name=None):
        raise NotImplementedError

    def _get_chat_history(self, chat_id, table_name=None):
        raise NotImplementedError

90 

91 

class HandlerMemory(BaseMemory):
    '''
    Uses handler's database to store and retrieve messages
    '''

    def _add_to_history(self, chat_id, chat_message, table_name=None):
        # do nothing. sent message will be stored by handler db
        pass

    def _get_chat_history(self, chat_id, table_name=None):
        '''
        Load the most recent messages of a chat from the handler's chat table.

        Args:
            chat_id: a single value or a tuple with one value per configured
                chat id column (in the same order as `chat_id_col`).
            table_name: name of the chat table, matched against
                `chat_params[i]['chat_table']['name']`.

        Returns:
            Up to MAX_DEPTH ChatBotMessage objects in chronological order
            (oldest first), or None when the handler response has no data frame.

        Raises:
            ValueError: if no configured chat table has the given name.
        '''
        t_params = next(
            (
                chat_params['chat_table']
                for chat_params in self.chat_params
                if chat_params['chat_table']['name'] == table_name
            ),
            None,
        )
        if t_params is None:
            # a bare next() would leak StopIteration here, which is hard to diagnose
            raise ValueError(f'Chat table is not found in chatbot params: {table_name}')

        text_col = t_params['text_col']
        username_col = t_params['username_col']
        time_col = t_params['time_col']
        chat_id_cols = t_params['chat_id_col']
        if not isinstance(chat_id_cols, list):
            chat_id_cols = [chat_id_cols]

        chat_id = chat_id if isinstance(chat_id, tuple) else (chat_id,)
        # Add a WHERE clause for each chat_id column.
        where_conditions = [
            BinaryOperation(
                op='=',
                args=[
                    Identifier(chat_id_col),
                    Constant(chat_id[idx])
                ]
            ) for idx, chat_id_col in enumerate(chat_id_cols)
        ]

        # Add a WHERE clause to ignore holding messages from the bot.
        # Imported here, presumably to avoid a circular import with chatbot_task.
        from .chatbot_task import HOLDING_MESSAGE

        where_conditions.append(
            BinaryOperation(
                op='!=',
                args=[
                    Identifier(text_col),
                    Constant(HOLDING_MESSAGE)
                ]
            )
        )

        # Combine the WHERE conditions with AND (the list is never empty:
        # the holding-message condition is always present).
        where = where_conditions[0]
        for condition in where_conditions[1:]:
            where = BinaryOperation('and', args=[where, condition])

        ast_query = Select(
            targets=[Identifier(text_col),
                     Identifier(username_col),
                     Identifier(time_col)],
            from_table=Identifier(t_params['name']),
            where=where,
            # Newest first so that LIMIT keeps the most recent messages; an
            # ascending order here would return the oldest MAX_DEPTH instead.
            order_by=[OrderBy(Identifier(time_col), direction='DESC')],
            limit=Constant(self.MAX_DEPTH),
        )

        resp = self.chat_task.chat_handler.query(ast_query)
        if resp.data_frame is None:
            return None

        # Keep at most MAX_DEPTH newest rows (in case the handler ignored LIMIT)
        # and restore chronological order, oldest first (same as DBMemory).
        df = resp.data_frame.iloc[:self.MAX_DEPTH].iloc[::-1]

        return [
            ChatBotMessage(
                ChatBotMessage.Type.DIRECT,
                rec[text_col],
                user=rec[username_col],
                sent_at=rec[time_col]
            )
            for _, rec in df.iterrows()
        ]

173 

174 

class DBMemory(BaseMemory):
    '''
    uses mindsdb database to store messages
    '''

    def _generate_chat_id_for_db(self, chat_id: Union[str, tuple], table_name: Union[str, None] = None) -> str:
        """
        Generate an ID for the chat to store in the database.
        The ID is a string that includes the components of the chat ID and the table name (if provided) separated by underscores.

        Args:
            chat_id (str | tuple): The ID of the chat.
            table_name (str): The name of the table the chat belongs to.

        Returns:
            str: e.g. ``"a_b"`` for ``("a", "b")``, or ``"tbl_a_b"`` when table_name is ``"tbl"``.
        """
        if isinstance(chat_id, tuple):
            chat_id_str = "_".join(str(val) for val in chat_id)
        else:
            chat_id_str = str(chat_id)

        # Only prefix when a table is given. Previously the result variable was
        # only assigned inside this branch, so chats without a table crashed
        # with UnboundLocalError.
        if table_name:
            chat_id_str = f"{table_name}_{chat_id_str}"

        return chat_id_str

    def _add_to_history(self, chat_id, message, table_name=None):
        '''
        Store `message` as a ChatBotsHistory row for this chat bot and commit.
        The message type is stored by its name (message.type.name).
        '''
        chat_bot_id = self.chat_task.bot_id
        destination = self._generate_chat_id_for_db(chat_id, table_name)

        record = db.ChatBotsHistory(
            chat_bot_id=chat_bot_id,
            type=message.type.name,
            text=message.text,
            user=message.user,
            destination=destination,
        )
        db.session.add(record)
        try:
            db.session.commit()
        except Exception:
            # leave the session usable for later operations, then propagate
            db.session.rollback()
            raise

    def _get_chat_history(self, chat_id, table_name=None):
        '''
        Return up to MAX_DEPTH most recent messages of the chat, oldest first.
        '''
        chat_bot_id = self.chat_task.bot_id
        destination = self._generate_chat_id_for_db(chat_id, table_name)

        query = db.ChatBotsHistory.query\
            .filter(
                db.ChatBotsHistory.chat_bot_id == chat_bot_id,
                db.ChatBotsHistory.destination == destination
            )\
            .order_by(db.ChatBotsHistory.sent_at.desc())\
            .limit(self.MAX_DEPTH)

        result = [
            # NOTE(review): rec.type is the stored type *name* (see
            # _add_to_history), not an enum member as in HandlerMemory — confirm
            # ChatBotMessage accepts both.
            ChatBotMessage(
                rec.type,
                rec.text,
                user=rec.user,
                sent_at=rec.sent_at,
            )
            for rec in query
        ]
        # query is newest-first so LIMIT keeps the latest; return oldest-first
        result.reverse()
        return result

236 

237 

class ChatMemory:
    '''
    Per-chat view over a memory backend.

    Binds one chat id (and optional table name) so callers can read and
    update that chat's history and mode without repeating them.
    '''

    def __init__(self, memory, chat_id, table_name=None):
        self.memory, self.chat_id, self.table_name = memory, chat_id, table_name
        # flips to True after the first get_history(); the first read never uses the cache
        self.cached = False

    def get_history(self, cached=True):
        '''
        Return this chat's history; the backend cache is used only when
        `cached` is requested and a previous read has already happened.
        '''
        use_cache = cached and self.cached
        history = self.memory.get_chat_history(self.chat_id, self.table_name, cached=use_cache)
        self.cached = True
        return history

    def add_to_history(self, message):
        '''Append `message` to this chat's history.'''
        backend = self.memory
        backend.add_to_history(self.chat_id, message, table_name=self.table_name)

    def get_mode(self):
        '''Return the mode stored for this chat.'''
        backend = self.memory
        return backend.get_mode(self.chat_id)

    def set_mode(self, mode):
        '''Store the mode for this chat.'''
        backend = self.memory
        backend.set_mode(self.chat_id, mode)

    def hide_history(self, left_count):
        '''
        set date to start hiding messages (delegates to the backend)
        '''
        backend = self.memory
        backend.hide_history(self.chat_id, left_count, table_name=self.table_name)