Coverage for mindsdb / integrations / handlers / twitter_handler / twitter_handler.py: 0%

281 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import re 

2import os 

3import datetime as dt 

4import time 

5from collections import defaultdict 

6import io 

7import requests 

8 

9import pandas as pd 

10import tweepy 

11 

12from mindsdb.utilities import log 

13from mindsdb.utilities.config import Config 

14 

15from mindsdb_sql_parser import ast 

16 

17from mindsdb.integrations.libs.api_handler import APIHandler, APITable, FuncParser 

18from mindsdb.integrations.utilities.sql_utils import extract_comparison_conditions 

19from mindsdb.integrations.utilities.date_utils import parse_utc_date 

20 

21from mindsdb.integrations.libs.response import ( 

22 HandlerStatusResponse as StatusResponse, 

23 HandlerResponse as Response, 

24 RESPONSE_TYPE 

25) 

26 

27logger = log.getLogger(__name__) 

28 

29 

class TweetsTable(APITable):
    """The ``tweets`` table of the Twitter handler.

    ``SELECT`` is served by the ``search_recent_tweets`` client method;
    ``INSERT`` posts every row with ``create_tweet``, splitting long text into
    a numbered thread of replies.
    """

    # Maximum length of a single tweet, in characters.
    _MAX_TEXT_LEN = 280
    # Every link is counted as a short url of this many characters.
    _SHORT_URL_LEN = 23
    # Matches http(s) links inside tweet text.
    _URL_PATTERN = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')
    # Seconds to wait for a media download before giving up.
    _MEDIA_DOWNLOAD_TIMEOUT = 60

    def select(self, query: ast.Select) -> pd.DataFrame:
        """Search recent tweets matching the WHERE clause of ``query``.

        Conditions handled by the API call: ``query = <text>``,
        ``created_at > / <`` and ``id > / <``. Conditions on any other column
        are passed to the handler as client-side filters. Without a ``query``
        condition, tweets matching ``'mindsdb'`` are searched.

        Returns:
            pandas.DataFrame restricted to the selected (lower-cased) columns.

        Raises:
            NotImplementedError: for OR conditions, unsupported operators and
                unsupported select targets.
        """
        conditions = extract_comparison_conditions(query.where)

        params = {}
        filters = []
        for op, arg1, arg2 in conditions:
            if op == 'or':
                raise NotImplementedError('OR is not supported')

            if arg1 == 'created_at':
                date = parse_utc_date(arg2)
                if op == '>':
                    # "tweets/search/recent" doesn't accept dates earlier than 7 days
                    if (dt.datetime.now(dt.timezone.utc) - date).days > 7:
                        # skip this condition
                        continue
                    params['start_time'] = date
                elif op == '<':
                    params['end_time'] = date
                else:
                    raise NotImplementedError(f'Unknown op for created_at: {op}')

            elif arg1 == 'query':
                if op != '=':
                    # previously the exception was built but never raised
                    raise NotImplementedError(f'Unknown op: {op}')
                params['query'] = arg2

            elif arg1 == 'id':
                if op == '>':
                    params['since_id'] = arg2
                elif op == '<':
                    params['until_id'] = arg2
                elif op == '>=':
                    raise NotImplementedError("Please use 'id > value'")
                elif op == '<=':
                    raise NotImplementedError("Please use 'id < value'")
                else:
                    # previously the exception was built but never raised
                    raise NotImplementedError(f'Search with "id {op}" is not implemented')

            else:
                # applied client-side by the handler after fetching
                filters.append([op, arg1, arg2])

        # validate the select targets before spending an API request
        columns = self._requested_columns(query)

        if query.limit is not None:
            params['max_results'] = query.limit.value

        params['expansions'] = ['author_id', 'in_reply_to_user_id']
        params['tweet_fields'] = ['created_at', 'conversation_id', 'referenced_tweets']
        params['user_fields'] = ['name', 'username']

        # search does not work without a query: fall back to 'mindsdb'
        params.setdefault('query', 'mindsdb')

        result = self.handler.call_twitter_api(
            method_name='search_recent_tweets',
            params=params,
            filters=filters
        )

        if len(result) == 0:
            return pd.DataFrame([], columns=columns)

        # make sure every requested column exists, even if the API omitted it
        for col in columns:
            if col not in result.columns:
                result[col] = None

        return result[columns]

    def _requested_columns(self, query: ast.Select) -> list:
        """Return the lower-cased target column names of ``query`` (all columns for ``*``)."""
        columns = []
        for target in query.targets:
            if isinstance(target, ast.Star):
                columns = []
                break
            if isinstance(target, ast.Identifier):
                columns.append(target.parts[-1])
            else:
                raise NotImplementedError

        if not columns:
            columns = self.get_columns()

        return [name.lower() for name in columns]

    def get_columns(self):
        """Return the column names exposed by the ``tweets`` table."""
        return [
            'id',
            'created_at',
            'text',
            'edit_history_tweet_ids',
            'author_id',
            'author_name',
            'author_username',
            'conversation_id',
            'in_reply_to_tweet_id',
            'in_retweeted_to_tweet_id',
            'in_quote_to_tweet_id',
            'in_reply_to_user_id',
            'in_reply_to_user_name',
            'in_reply_to_user_username',
        ]

    def insert(self, query: ast.Insert):
        """Post every inserted row as a tweet (or a thread, for long text).

        Each row must have a ``text`` column. Other columns are passed to
        ``create_tweet`` as keyword arguments. If ``media_url`` is given, the
        file is downloaded and attached to the last tweet of the thread.
        See https://docs.tweepy.org/en/stable/client.html#tweepy.Client.create_tweet

        Raises:
            Exception: if the user-context credentials were not provided.
        """
        columns = [col.name for col in query.columns]

        insert_params = ('consumer_key', 'consumer_secret', 'access_token', 'access_token_secret')
        for p in insert_params:
            if p not in self.handler.connection_args:
                raise Exception(f'To insert data into Twitter, you need to provide the following parameters when connecting it to MindsDB: {insert_params}')  # noqa

        for row in query.values:
            params = dict(zip(columns, row))
            text = params['text']

            # media is published with the last tweet of the thread only
            media_ids = None
            if 'media_url' in params:
                media_ids = [self._upload_media(params.pop('media_url'))]

            messages = self._split_text(text)

            len_messages = len(messages)
            for i, message in enumerate(messages):
                if i < len_messages - 1:
                    message += '...'
                else:
                    message += ' '
                    if media_ids is not None:
                        params['media_ids'] = media_ids

                message += f'({i + 1}/{len_messages})'

                params['text'] = message
                ret = self.handler.call_twitter_api('create_tweet', params)
                # every following tweet replies to the previous one, forming a thread
                params['in_reply_to_tweet_id'] = ret['id'].iloc[0]

    def _upload_media(self, media_url):
        """Download ``media_url`` and upload it through the v1.1 API; return its media id."""
        resp = requests.get(media_url, timeout=self._MEDIA_DOWNLOAD_TIMEOUT)
        resp.raise_for_status()
        img = io.BytesIO(resp.content)

        api_v1 = self.handler.create_connection(api_version=1)
        # drop parameters such as "; charset=..." before taking the subtype
        content_type = resp.headers['Content-Type'].split(';')[0].strip()
        file_type = content_type.split('/')[-1]
        media = api_v1.media_upload(filename=f'img.{file_type}', file=img)
        return media.media_id

    def _split_text(self, text):
        """Split ``text`` on spaces into chunks that fit in one tweet once a suffix is added.

        Links are counted as ``_SHORT_URL_LEN`` characters, for the text already
        collected and for the next word alike. Words are never broken, so a
        single word longer than the budget still yields an over-long chunk.
        """
        # room is reserved for the "..." marker (3) and a "(10/11)"-style counter (7)
        budget = self._MAX_TEXT_LEN - 3 - 7

        messages = []
        current = ''
        current_len = 0  # length of `current` with links counted as short urls
        for word in re.split('( )', text):
            word_len = len(self._URL_PATTERN.sub('-' * self._SHORT_URL_LEN, word))
            # never emit an empty chunk when the first word alone exceeds the budget
            if current_len + word_len > budget and current.strip():
                messages.append(current.strip())
                current = ''
                current_len = 0
            current += word
            current_len += word_len

        # the last message
        if current.strip():
            messages.append(current.strip())
        return messages

207 

208 

class TwitterHandler(APIHandler):
    """A class for handling connections and interactions with the Twitter API.

    Attributes:
        connection_args (dict): Keyword arguments for ``tweepy.Client``. Each
            setting is taken from ``connection_data``, else from the
            ``TWITTER_<KEY>`` environment variable, else from the
            ``twitter_handler`` section of the MindsDB config.
        api (tweepy.Client | None): Client created lazily by :meth:`connect`.
        is_connected (bool): Whether :meth:`connect` has created ``api``.
    """

    # Settings looked up in connection_data / environment / config.
    _connection_keys = (
        'bearer_token', 'consumer_key', 'consumer_secret',
        'access_token', 'access_token_secret', 'wait_on_rate_limit',
    )
    # Credentials used for user-context (read-write) requests.
    _user_auth_keys = ('consumer_key', 'consumer_secret', 'access_token', 'access_token_secret')
    # Flag settings that may arrive as strings (e.g. from environment variables).
    _bool_keys = ('wait_on_rate_limit',)

    # method > included table > columns of the result that reference that table
    _expansions_map = {
        'search_recent_tweets': {
            'users': ['author_id', 'in_reply_to_user_id'],
        },
        'search_all_tweets': {
            'users': ['author_id'],
        },
    }
    # referenced_tweets type > column that receives the referenced tweet id
    _reference_columns = {
        'replied_to': 'in_reply_to_tweet_id',
        'retweeted': 'in_retweeted_to_tweet_id',
        'quoted': 'in_quote_to_tweet_id',
    }
    # operators understood by _apply_filters
    _filter_ops = ('!=', '<>', '==', '=', 'in', 'not in')
    # page size bounds used while paginating in call_twitter_api
    _max_page_size = 100
    _min_page_size = 10
    # seconds after which a paginated call_twitter_api request is aborted
    _request_timeout = 60

    def __init__(self, name=None, **kwargs):
        super().__init__(name)

        args = kwargs.get('connection_data') or {}

        self.connection_args = {}
        handler_config = Config().get('twitter_handler', {})
        for k in self._connection_keys:
            env_name = f'TWITTER_{k.upper()}'
            if k in args:
                value = args[k]
            elif env_name in os.environ:
                value = os.environ[env_name]
            elif k in handler_config:
                value = handler_config[k]
            else:
                continue

            if k in self._bool_keys:
                # a string such as "false" would otherwise be truthy
                value = self._to_bool(value)
            self.connection_args[k] = value

        self.api = None
        self.is_connected = False

        tweets = TweetsTable(self)
        self._register_table('tweets', tweets)

    @staticmethod
    def _to_bool(value):
        """Interpret textual flags ('1', 'true', 'yes', 'on', any case) as True; other values via bool()."""
        if isinstance(value, str):
            return value.strip().lower() in ('1', 'true', 'yes', 'y', 'on')
        return bool(value)

    def create_connection(self, api_version=2):
        """Create a new Twitter client.

        Args:
            api_version: 1 for a ``tweepy.API`` authenticated with the OAuth 1.0a
                user credentials (needed for media upload); anything else for a
                ``tweepy.Client`` built from ``connection_args``.

        Raises:
            KeyError: for ``api_version=1`` when user credentials are missing.
        """
        if api_version == 1:
            auth = tweepy.OAuthHandler(
                self.connection_args['consumer_key'],
                self.connection_args['consumer_secret']
            )
            auth.set_access_token(
                self.connection_args['access_token'],
                self.connection_args['access_token_secret']
            )
            return tweepy.API(auth)

        return tweepy.Client(**self.connection_args)

    def connect(self, api_version=2):
        """Return the cached ``tweepy.Client``, creating it on first use.

        ``api_version`` is accepted for compatibility but not used; call
        ``create_connection(api_version=1)`` to get a v1.1 ``tweepy.API``.
        """
        if self.is_connected is True:
            return self.api

        self.api = self.create_connection()

        self.is_connected = True
        return self.api

    def check_connection(self) -> StatusResponse:
        """Verify the configured credentials.

        App-only access is checked with ``get_user``. When user-context
        credentials are configured, they are checked with ``get_me``. Errors
        are reported through the returned status instead of being raised.
        """
        response = StatusResponse(False)

        try:
            api = self.connect()

            # call get_user with unknown id.
            # it raises an error in case if auth is not success and returns not-found otherwise
            # api.get_me() is not exposed for OAuth 2.0 App-only authorisation
            api.get_user(id=1)
            response.success = True

        except tweepy.Unauthorized as e:
            response.error_message = f'Error connecting to Twitter api: {e}. Check bearer_token'
            logger.error(response.error_message)
        except Exception as e:
            # e.g. missing bearer_token: report a failed check instead of crashing
            response.error_message = f'Error connecting to Twitter api: {e}'
            logger.error(response.error_message)

        # only check read-write mode when user credentials were actually supplied
        if response.success is True and any(k in self.connection_args for k in self._user_auth_keys):
            try:
                api = self.connect()
                api.get_me()

            except tweepy.Unauthorized as e:
                response.error_message = (
                    f'Error connecting to Twitter api: {e}. Check ' + ', '.join(self._user_auth_keys)
                )
                logger.error(response.error_message)
                response.success = False
            except Exception as e:
                response.error_message = f'Error connecting to Twitter api: {e}'
                logger.error(response.error_message)
                response.success = False

        if response.success is False and self.is_connected is True:
            self.is_connected = False

        return response

    def native_query(self, query_string: str = None):
        """Run a ``method_name(arg=value, ...)`` call string against the client and return a table response."""
        method_name, params = FuncParser().from_string(query_string)

        df = self.call_twitter_api(method_name, params)

        return Response(
            RESPONSE_TYPE.TABLE,
            data_frame=df
        )

    def _apply_filters(self, data, filters):
        """Return the rows of ``data`` that satisfy every ``(op, key, value)`` filter.

        Raises:
            NotImplementedError: if a filter uses an unsupported operator.
        """
        if not filters:
            return data

        # validate up front so the error does not depend on which rows came back
        for op, _, _ in filters:
            if op not in self._filter_ops:
                raise NotImplementedError(f'Unknown filter: {op}')

        # a row is kept only if it passes ALL filters
        return [
            row for row in data
            if all(self._row_matches(row, op, key, value) for op, key, value in filters)
        ]

    @staticmethod
    def _row_matches(row, op, key, value):
        """Return True if ``row.get(key)`` satisfies ``op value`` (op already validated)."""
        actual = row.get(key)

        # twitter returns ids as string: compare integers by their string form
        if isinstance(value, (list, tuple)):
            value = [str(v) if isinstance(v, int) else v for v in value]
        elif isinstance(value, int):
            value = str(value)

        if op in ('!=', '<>'):
            return actual != value
        if op in ('==', '='):
            return actual == value

        options = value if isinstance(value, list) else [value]
        if op == 'in':
            return actual in options
        return actual not in options  # 'not in'

    def call_twitter_api(self, method_name: str = None, params: dict = None, filters: list = None):
        """Call ``method_name`` on the client and return the result as a DataFrame.

        If ``params`` contains ``max_results``, pages are fetched (following
        ``next_token``) until that many rows are collected. Otherwise a single
        request is made. ``filters`` are applied client-side to every page.
        For methods in ``_expansions_map``, expanded objects returned in
        ``includes`` are left-joined as ``<prefix>_<field>`` columns.

        Args:
            method_name: name of a method of the client returned by ``connect()``.
            params: keyword arguments for that method; the caller's dict is not modified.
            filters: optional ``[op, column, value]`` filters, see ``_apply_filters``.

        Returns:
            pandas.DataFrame with one row per returned object.

        Raises:
            RuntimeError: if pagination runs longer than ``_request_timeout`` seconds.
        """
        # work on a copy: pagination below sets 'max_results' and 'next_token'
        params = dict(params) if params else {}

        api = self.connect()
        method = getattr(api, method_name)

        # requested number of rows; None means a single request without pagination
        count_results = params.get('max_results')

        data = []
        includes = defaultdict(list)
        left = None
        deadline = time.time() + self._request_timeout

        if filters:
            # filtering drops rows: always request full pages
            params['max_results'] = self._max_page_size

        while True:
            if time.time() > deadline:
                raise RuntimeError('Handler request timeout error')

            if count_results is not None:
                left = count_results - len(data)
                if left <= 0:
                    data = data[:count_results]
                    break

                if filters:
                    params['max_results'] = self._max_page_size
                else:
                    # clamp the page size to the allowed range
                    params['max_results'] = min(max(left, self._min_page_size), self._max_page_size)

            logger.debug('>>>twitter in: %s(%s)', method_name, params)
            resp = method(**params)

            for table, records in (getattr(resp, 'includes', None) or {}).items():
                includes[table].extend(r.data for r in records)

            if not isinstance(resp.data, list):
                # single-object (or empty) response: nothing to paginate
                if isinstance(resp.data, dict):
                    data.append(resp.data)
                if hasattr(resp.data, 'data') and isinstance(resp.data.data, dict):
                    data.append(resp.data.data)
                break

            chunk = [r.data for r in resp.data]
            self._unwind_referenced_tweets(chunk)

            if filters:
                chunk = self._apply_filters(chunk, filters)

            # limit output
            if left is not None:
                chunk = chunk[:left]

            data.extend(chunk)

            # next page ?
            meta = getattr(resp, 'meta', None) or {}
            if count_results is None or 'next_token' not in meta:
                break
            params['next_token'] = meta['next_token']

        df = pd.DataFrame(data)

        expansions = self._expansions_map.get(method_name)
        if expansions is not None:
            df = self._join_includes(df, includes, expansions)

        return df

    def _unwind_referenced_tweets(self, rows):
        """Copy ids from each row's ``referenced_tweets`` into the in_reply_to/retweeted/quote columns (in place)."""
        for row in rows:
            refs = row.get('referenced_tweets')
            if not isinstance(refs, list):
                continue
            # a tweet can reference several tweets (e.g. reply and quote): map every one
            for ref in refs:
                column = self._reference_columns.get(ref.get('type'))
                if column is not None:
                    row.setdefault(column, ref.get('id'))

    @staticmethod
    def _join_includes(df, includes, expansions):
        """Left-join included objects onto ``df``, prefixing their fields with the referencing column's stem."""
        for table, records in includes.items():
            if table not in expansions:
                continue

            df_ref = pd.DataFrame(records)
            for col_id in expansions[table]:
                if col_id not in df.columns:
                    continue

                prefix = col_id[:-len('_id')]  # 'author_id' -> 'author'
                df_ref2 = df_ref.rename(columns={col_ref: f'{prefix}_{col_ref}' for col_ref in df_ref.columns})
                if col_id not in df_ref2.columns:
                    # included objects without an id cannot be joined
                    continue

                df_ref2 = df_ref2.drop_duplicates(col_id)
                df = df.merge(df_ref2, on=col_id, how='left')

        return df

468 

469 return df