Coverage for mindsdb / interfaces / chatbot / polling.py: 18%

115 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import secrets 

2import threading 

3import time 

4 

5from mindsdb_sql_parser.ast import Identifier, Select, Insert 

6 

7from mindsdb.utilities import log 

8from mindsdb.utilities.context import context as ctx 

9 

10from .types import ChatBotMessage, BotException 

11 

12logger = log.getLogger(__name__) 

13 

14 

class BasePolling:
    """Base class for the chatbot polling strategies defined in this module.

    Holds the owning chat task and the chat configuration, and provides a
    default ``send_message`` that inserts the outgoing message into the
    configured chat table through the task's chat handler.
    """

    def __init__(self, chat_task, chat_params):
        """Store the chat configuration and the owning chat task.

        Args:
            chat_task: Object exposing ``chat_handler`` (used to run queries).
            chat_params: Chat configuration. The subclasses in this module
                iterate it as a list of per-table dicts, each with a
                ``"chat_table"`` entry; a single such dict is also accepted
                by ``send_message``.
        """
        self.params = chat_params
        self.chat_task = chat_task

    def start(self, stop_event):
        """Kept for backward compatibility; always raises NotImplementedError."""
        raise NotImplementedError

    def run(self, stop_event):
        """Run the polling until ``stop_event`` is set.

        The subclasses in this module implement ``run`` (not ``start``), so
        the base class declares it as well.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def _resolve_chat_table(self, table_name=None):
        """Return the ``"chat_table"`` config to use for ``table_name``.

        ``self.params`` may be a single config dict or a list of config
        dicts. With ``table_name=None`` the first (or only) config is used;
        otherwise the config whose ``chat_table["name"]`` matches is returned.

        Raises:
            BotException: If nothing is configured or the table is unknown.
        """
        params = self.params
        configs = [params] if isinstance(params, dict) else list(params)

        if table_name is None:
            if not configs:
                raise BotException("No chat table is configured")
            return configs[0]["chat_table"]

        for config in configs:
            chat_table = config["chat_table"]
            if chat_table["name"] == table_name:
                return chat_table
        raise BotException(f"Chat table is not configured: {table_name}")

    def send_message(self, message: ChatBotMessage, table_name=None):
        """Insert ``message`` into the chat table via the chat handler.

        Args:
            message (ChatBotMessage): Message to send; ``destination`` holds
                the chat id (a tuple for multi-column ids) and ``text`` the
                body.
            table_name: Name of the chat table to write to. Defaults to None,
                which selects the first configured chat table.

        Raises:
            BotException: If the requested chat table is not configured.
        """
        destination = message.destination
        chat_id = destination if isinstance(destination, tuple) else (destination,)
        text = message.text

        t_params = self._resolve_chat_table(table_name)

        # The id column(s) may be configured as a single name or a list.
        id_cols = t_params["chat_id_col"]
        chat_id_cols = id_cols if isinstance(id_cols, list) else [id_cols]

        ast_query = Insert(
            table=Identifier(t_params["name"]),
            columns=[*chat_id_cols, t_params["text_col"]],
            values=[
                [*chat_id, text],
            ],
        )

        self.chat_task.chat_handler.query(ast_query)

45 

46 

class MessageCountPolling(BasePolling):
    """Polling strategy that watches per-chat message counters.

    Each round queries the configured polling table for every entry in
    ``self.params``, compares the per-chat counters with the previous
    snapshot of the same table, and hands the last message of every changed
    chat to ``chat_task.on_message``.
    """

    # Seconds to wait between two polling rounds.
    POLL_INTERVAL = 7

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self._to_stop = False
        # Most recent snapshot ({chat_id tuple: count}); kept as a public
        # attribute for backward compatibility.
        self.chats_prev = None
        # Previous snapshot per polling source, so that several tables do not
        # overwrite each other's counters.
        self._chats_prev_by_source = {}

    def run(self, stop_event):
        """Poll all configured tables until asked to stop.

        Args:
            stop_event: Event checked after every round; also used to wait
                between rounds so that setting it ends the loop promptly.
        """
        while True:
            for chat_params in self.params:
                # A failure in one table must not prevent polling the others.
                try:
                    self._poll_table(chat_params)
                except Exception:
                    logger.exception("Unexpected error")

            if self._to_stop or stop_event.is_set():
                return
            logger.debug(f"running {self.chat_task.bot_id}")
            # Event.wait returns True as soon as the event is set.
            if stop_event.wait(self.POLL_INTERVAL):
                return

    def _poll_table(self, chat_params):
        """Dispatch the last message of every changed chat of one table."""
        chat_ids = self.check_message_count(chat_params)
        logger.debug(f"number of chat ids found: {len(chat_ids)}")

        for chat_id in chat_ids:
            table_name = chat_params["chat_table"]["name"]
            try:
                chat_memory = self.chat_task.memory.get_chat(
                    chat_id,
                    table_name=table_name,
                )
            except Exception:
                logger.exception("Problem retrieving chat memory:")
                # Without its memory this chat cannot be processed; skipping
                # avoids reusing the memory object of a previous chat.
                continue

            try:
                message = self.get_last_message(chat_memory)
            except Exception:
                logger.exception("Problem getting last message:")
                message = None

            if message:
                self.chat_task.on_message(message, chat_memory=chat_memory, table_name=table_name)

    def get_last_message(self, chat_memory):
        """Return the last message of the chat unless the bot wrote it.

        Args:
            chat_memory: Object whose ``get_history()`` returns the chat
                messages (indexed as a sequence; the last one is the newest).

        Returns:
            The last message, or None when the history is empty, could not
            be retrieved, or the last message was sent by the bot itself.
        """
        # retrieve from history
        try:
            history = chat_memory.get_history()
        except Exception:
            logger.exception("Problem retrieving history:")
            return None

        if not history:
            return None

        last_message = history[-1]
        if last_message.user == self.chat_task.bot_params["bot_username"]:
            # the last message is from bot
            return None
        return last_message

    def check_message_count(self, chat_params):
        """Return the ids of chats whose message counter changed.

        Args:
            chat_params: Per-table config; its ``"polling"`` entry provides
                ``table``, ``chat_id_col`` (name or list of names) and
                ``count_col``.

        Returns:
            list: Chat ids (tuples of id-column values) whose counter differs
            from the previous snapshot of the same polling source. Empty on
            the first check of a source.

        Raises:
            BotException: If the handler response carries no data frame.
        """
        p_params = chat_params["polling"]

        raw_id_cols = p_params["chat_id_col"]
        id_cols = raw_id_cols if isinstance(raw_id_cols, list) else [raw_id_cols]
        msgs_col = p_params["count_col"]

        # get chats status info
        ast_query = Select(
            targets=[*[Identifier(id_col) for id_col in id_cols], Identifier(msgs_col)],
            from_table=Identifier(p_params["table"]),
        )

        resp = self.chat_task.chat_handler.query(query=ast_query)
        if resp.data_frame is None:
            raise BotException("Error to get count of messages")

        chats = {
            tuple(row[id_col] for id_col in id_cols): row[msgs_col]
            for row in resp.data_frame.to_dict("records")
        }

        # Snapshots are tracked per polling source; a single shared snapshot
        # would make every chat look changed when several tables are polled.
        source_key = (str(p_params["table"]), tuple(id_cols), msgs_col)
        previous = self._chats_prev_by_source.get(source_key)

        if previous is None:
            # first run for this source: nothing to compare against
            chat_ids = []
        else:
            # new chats and chats with a different counter
            chat_ids = [
                chat_id
                for chat_id, count_msgs in chats.items()
                if previous.get(chat_id) != count_msgs
            ]

        self._chats_prev_by_source[source_key] = chats
        self.chats_prev = chats
        return chat_ids

    def stop(self):
        """Ask the polling loop to end after the current round."""
        self._to_stop = True

141 

142 

class RealtimePolling(BasePolling):
    """Polling strategy driven by events pushed from the chat handler.

    ``run`` subscribes ``_callback`` to the handler; every event is turned
    into a ``ChatBotMessage`` and passed to ``chat_task.on_message``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # call back can be without context
        self._ctx_dump = ctx.dump()

    def _find_chat_table(self, key):
        """Return the chat-table config matching the event ``key``.

        With several configured tables the one whose chat id column(s) equal
        the key's column names is chosen; returns None if none matches. With
        a single config its ``"chat_table"`` entry is used (falling back to
        the config itself when that entry is absent).
        """
        if len(self.params) > 1:
            # Identify the table relevant to this event based on the key.
            event_keys = sorted(key.keys())
            for param in self.params:
                chat_table = param["chat_table"]
                id_cols = chat_table["chat_id_col"]
                table_keys = [id_cols] if isinstance(id_cols, str) else id_cols
                if event_keys == sorted(table_keys):
                    return chat_table
            return None

        # Only a single table is supported: use the first set of parameters.
        # The column settings live under "chat_table", as in the multi-table
        # branch above.
        first = self.params[0]
        return first.get("chat_table", first)

    def _callback(self, row, key):
        """Handle one event from the chat handler subscription.

        Args:
            row (dict): Event payload; updated in place with ``key``.
            key (dict): Key column(s) of the event, used as the chat id.
        """
        ctx.load(self._ctx_dump)

        row.update(key)

        t_params = self._find_chat_table(key)
        if t_params is None:
            logger.warning(f"No chat table matches event key columns: {sorted(key.keys())}")
            return

        # Get the chat ID from the row based on the chat ID column(s).
        id_cols = t_params["chat_id_col"]
        if isinstance(id_cols, list):
            chat_id = tuple(row[id_col] for id_col in id_cols)
        else:
            chat_id = row[id_cols]

        message = ChatBotMessage(
            ChatBotMessage.Type.DIRECT,
            row[t_params["text_col"]],
            # In Slack direct messages are treated as channels themselves.
            row[t_params["username_col"]],
            chat_id,
        )

        self.chat_task.on_message(
            message,
            chat_id=chat_id,
            table_name=t_params["name"],
        )

    def run(self, stop_event):
        """Subscribe ``_callback`` to the chat handler, passing ``stop_event`` through."""
        self.chat_task.chat_handler.subscribe(stop_event, self._callback)

197 

198 # def send_message(self, message: ChatBotMessage): 

199 # 

200 # self.chat_task.chat_handler.realtime_send(message) 

201 

202 

class WebhookPolling(BasePolling):
    """
    Polling strategy for chatbots whose messages arrive through webhooks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def run(self, stop_event: threading.Event) -> None:
        """
        Make sure the chatbot has a webhook token, then idle until stopped.

        The incoming webhook requests themselves are handled by a task
        instantiated for each request, so this method only blocks.

        Args:
            stop_event (threading.Event): Event that ends the wait when set.
        """
        # Imported inside the method, as in the original code.
        from mindsdb.interfaces.chatbot.chatbot_controller import ChatBotController

        controller = ChatBotController()
        bot_record = controller.get_chatbot_by_id(self.chat_task.object_id)

        token_missing = not bot_record["webhook_token"]
        if token_missing:
            # No token yet: generate one and store it on the chatbot.
            update_args = {
                "chatbot_name": bot_record["name"],
                "project_name": bot_record["project"],
                "webhook_token": secrets.token_urlsafe(32),
            }
            controller.update_chatbot(**update_args)

        # Nothing else to do here; block until the stop event is set.
        stop_event.wait()

    def send_message(self, message: ChatBotMessage, table_name: str = None) -> None:
        """
        Send a message (response) by handing it to the chat handler.

        Args:
            message (ChatBotMessage): The message to send.
            table_name (str): Accepted for interface compatibility; not used
                here. Defaults to None.
        """
        handler = self.chat_task.chat_handler
        handler.respond(message)