Coverage for mindsdb / integrations / handlers / twitter_handler / twitter_handler.py: 0%
281 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import re
2import os
3import datetime as dt
4import time
5from collections import defaultdict
6import io
7import requests
9import pandas as pd
10import tweepy
12from mindsdb.utilities import log
13from mindsdb.utilities.config import Config
15from mindsdb_sql_parser import ast
17from mindsdb.integrations.libs.api_handler import APIHandler, APITable, FuncParser
18from mindsdb.integrations.utilities.sql_utils import extract_comparison_conditions
19from mindsdb.integrations.utilities.date_utils import parse_utc_date
21from mindsdb.integrations.libs.response import (
22 HandlerStatusResponse as StatusResponse,
23 HandlerResponse as Response,
24 RESPONSE_TYPE
25)
27logger = log.getLogger(__name__)
class TweetsTable(APITable):
    """The `tweets` table: SELECT searches recent tweets, INSERT publishes tweets."""

    def select(self, query: ast.Select) -> pd.DataFrame:
        """Run a recent-tweets search described by a SELECT statement.

        WHERE conditions on `created_at` ('>' / '<'), `query` ('=') and
        `id` ('>' / '<') are translated into API parameters. Conditions on any
        other column are collected as filters and handed to
        `handler.call_twitter_api` together with the parameters.

        Args:
            query: Parsed SELECT statement.

        Returns:
            A DataFrame restricted to the requested columns (all columns from
            `get_columns()` for `SELECT *`); columns absent from the API
            result are filled with None.

        Raises:
            NotImplementedError: For OR conditions, unsupported operators on
                `created_at` / `query` / `id`, or non-identifier targets.
        """
        conditions = extract_comparison_conditions(query.where)

        # "tweets/search/recent" doesn't accept dates earlier than 7 days
        recent_window = dt.timedelta(days=7)

        params = {}
        filters = []
        for op, arg1, arg2 in conditions:
            if op == 'or':
                raise NotImplementedError('OR is not supported')

            if arg1 == 'created_at':
                date = parse_utc_date(arg2)
                if op == '>':
                    # Compare the full timedelta: `.days > 7` would let through
                    # dates up to 7 days 23 hours old, which the API rejects.
                    if dt.datetime.now(dt.timezone.utc) - date >= recent_window:
                        # skip this condition
                        continue
                    params['start_time'] = date
                elif op == '<':
                    params['end_time'] = date
                else:
                    raise NotImplementedError(f'Unsupported operator for created_at: {op}')

            elif arg1 == 'query':
                if op == '=':
                    params[arg1] = arg2
                else:
                    raise NotImplementedError(f'Unknown op: {op}')

            elif arg1 == 'id':
                if op == '>':
                    params['since_id'] = arg2
                elif op == '>=':
                    raise NotImplementedError("Please use 'id > value'")
                elif op == '<':
                    params['until_id'] = arg2
                elif op == '<=':
                    raise NotImplementedError("Please use 'id < value'")
                else:
                    raise NotImplementedError('Search with "id=" is not implemented')

            else:
                filters.append([op, arg1, arg2])

        if query.limit is not None:
            params['max_results'] = query.limit.value

        params['expansions'] = ['author_id', 'in_reply_to_user_id']
        params['tweet_fields'] = ['created_at', 'conversation_id', 'referenced_tweets']
        params['user_fields'] = ['name', 'username']

        if 'query' not in params:
            # search not works without query, use 'mindsdb'
            params['query'] = 'mindsdb'

        result = self.handler.call_twitter_api(
            method_name='search_recent_tweets',
            params=params,
            filters=filters
        )

        # filter targets
        columns = []
        for target in query.targets:
            if isinstance(target, ast.Star):
                columns = []
                break
            elif isinstance(target, ast.Identifier):
                columns.append(target.parts[-1])
            else:
                raise NotImplementedError

        if len(columns) == 0:
            columns = self.get_columns()

        # columns to lower case
        columns = [name.lower() for name in columns]

        if len(result) == 0:
            result = pd.DataFrame([], columns=columns)
        else:
            # add absent columns (requested but not returned by the API)
            for col in columns:
                if col not in result.columns:
                    result[col] = None

            # filter by columns
            result = result[columns]
        return result

    def get_columns(self):
        """Return the column names exposed by the tweets table."""
        return [
            'id',
            'created_at',
            'text',
            'edit_history_tweet_ids',
            'author_id',
            'author_name',
            'author_username',
            'conversation_id',
            'in_reply_to_tweet_id',
            'in_retweeted_to_tweet_id',
            'in_quote_to_tweet_id',
            'in_reply_to_user_id',
            'in_reply_to_user_name',
            'in_reply_to_user_username',
        ]

    def insert(self, query: ast.Insert):
        """Publish each inserted row as a tweet, or as a reply chain if the text is long.

        The `text` column is split on spaces into messages that leave room for
        a "..." continuation marker and an "(i/n)" counter. Each message after
        the first is posted as a reply to the previous one. If a `media_url`
        column is given, the media is downloaded, uploaded through the v1 API
        and attached to the last message of the row.

        Args:
            query: Parsed INSERT statement.

        Raises:
            Exception: If any of the OAuth 1.0a credentials is missing from
                the handler's connection arguments.
            KeyError: If a row has no `text` column.
        """
        # https://docs.tweepy.org/en/stable/client.html#tweepy.Client.create_tweet
        columns = [col.name for col in query.columns]

        insert_params = ('consumer_key', 'consumer_secret', 'access_token', 'access_token_secret')
        for p in insert_params:
            if p not in self.handler.connection_args:
                raise Exception(f'To insert data into Twitter, you need to provide the following parameters when connecting it to MindsDB: {insert_params}')  # noqa

        # split long text over 280 symbols
        max_text_len = 280
        # 3 is for ..., 7 is for (10/11)
        max_chunk_len = max_text_len - 3 - 7
        # compiled once: the same pattern is applied to every word of every row
        url_pattern = re.compile(
            r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
        )

        for row in query.values:
            params = dict(zip(columns, row))

            text = params['text']

            # Post image if column media_url is provided, only do this on last tweet
            media_ids = None
            if 'media_url' in params:
                media_url = params.pop('media_url')

                # create an in memory file; bound the wait and fail on HTTP
                # errors instead of uploading an error page as an image
                resp = requests.get(media_url, timeout=30)
                resp.raise_for_status()
                img = io.BytesIO(resp.content)

                # upload media to twitter
                api_v1 = self.handler.create_connection(api_version=1)
                content_type = resp.headers['Content-Type']
                # drop parameters such as "; charset=..." before taking the subtype
                file_type = content_type.split(';')[0].strip().split('/')[-1]
                media = api_v1.media_upload(filename="img.{file_type}".format(file_type=file_type), file=img)

                media_ids = [media.media_id]

            # the capturing group keeps the separating spaces as their own items
            words = re.split('( )', text)

            messages = []

            text2 = ''
            for word in words:
                # replace the links in word to string with the length as twitter short url (23)
                word2 = url_pattern.sub('-' * 23, word)
                if len(text2) + len(word2) > max_chunk_len:
                    # an over-long first word would otherwise queue an empty message
                    if text2.strip() != '':
                        messages.append(text2.strip())

                    text2 = ''
                text2 += word

            # the last message
            if text2.strip() != '':
                messages.append(text2.strip())

            len_messages = len(messages)
            for i, text in enumerate(messages):
                if i < len_messages - 1:
                    text += '...'
                else:
                    text += ' '
                    # publish media with the last message
                    if media_ids is not None:
                        params['media_ids'] = media_ids

                text += f'({i + 1}/{len_messages})'

                params['text'] = text
                ret = self.handler.call_twitter_api('create_tweet', params)
                # chain the next message as a reply to the one just published
                inserted_id = ret.id[0]
                params['in_reply_to_tweet_id'] = inserted_id
class TwitterHandler(APIHandler):
    """A class for handling connections and interactions with the Twitter API.

    Attributes:
        connection_args (dict): Connection options (`bearer_token`,
            `consumer_key`, `consumer_secret`, `access_token`,
            `access_token_secret`, `wait_on_rate_limit`) collected from the
            connection data, `TWITTER_*` environment variables or the
            `twitter_handler` config section, in that order of precedence.
        api: The cached client returned by `create_connection()`, or None
            until `connect()` has been called.
        is_connected (bool): Whether `api` holds a client created by `connect()`.
    """

    def __init__(self, name=None, **kwargs):
        """Collect connection arguments and register the `tweets` table.

        Args:
            name: Handler name, passed to the base class.
            **kwargs: May contain `connection_data`, a dict of connection options.
        """
        super().__init__(name)

        args = kwargs.get('connection_data', {})

        self.connection_args = {}
        handler_config = Config().get('twitter_handler', {})
        for k in ['bearer_token', 'consumer_key', 'consumer_secret',
                  'access_token', 'access_token_secret', 'wait_on_rate_limit']:
            env_name = f'TWITTER_{k.upper()}'
            if k in args:
                value = args[k]
            elif env_name in os.environ:
                value = os.environ[env_name]
            elif k in handler_config:
                value = handler_config[k]
            else:
                continue

            # Environment variables are always strings; without this "false"
            # or "0" would be passed on as a truthy value.
            if k == 'wait_on_rate_limit' and isinstance(value, str):
                value = value.strip().lower() in ('1', 'true', 't', 'yes', 'y', 'on')

            self.connection_args[k] = value

        self.api = None
        self.is_connected = False

        tweets = TweetsTable(self)
        self._register_table('tweets', tweets)

    def create_connection(self, api_version=2):
        """Create a new tweepy client from the stored connection arguments.

        Args:
            api_version: 1 builds a `tweepy.API` authenticated with the OAuth
                consumer/access keys; any other value builds a `tweepy.Client`
                from all connection arguments.

        Returns:
            The newly created client (not cached).
        """
        if api_version == 1:
            auth = tweepy.OAuthHandler(
                self.connection_args['consumer_key'],
                self.connection_args['consumer_secret']
            )
            auth.set_access_token(
                self.connection_args['access_token'],
                self.connection_args['access_token_secret']
            )
            return tweepy.API(auth)

        return tweepy.Client(**self.connection_args)

    def connect(self, api_version=2):
        """Return the cached default client, creating it on first use.

        Args:
            api_version: Not used; the cached client is always the one
                `create_connection()` builds by default. Kept for interface
                compatibility.
        """

        if self.is_connected is True:
            return self.api

        self.api = self.create_connection()

        self.is_connected = True
        return self.api

    def check_connection(self) -> StatusResponse:
        """Check that the stored credentials are accepted by the Twitter API.

        Returns:
            A StatusResponse whose `success` is True when the checks pass;
            on `tweepy.Unauthorized` it carries an error message instead.
        """

        response = StatusResponse(False)

        try:
            api = self.connect()

            # call get_user with unknown id.
            #   it raises an error in case if auth is not success and returns not-found otherwise
            # api.get_me() is not exposed for OAuth 2.0 App-only authorisation
            api.get_user(id=1)
            response.success = True

        except tweepy.Unauthorized as e:
            response.error_message = f'Error connecting to Twitter api: {e}. Check bearer_token'
            logger.error(response.error_message)

        keys = 'consumer_key', 'consumer_secret', 'access_token', 'access_token_secret'
        # Look for the user-context keys explicitly: counting the arguments
        # would also trigger on bearer_token + wait_on_rate_limit.
        has_user_keys = any(k in self.connection_args for k in keys)
        if response.success is True and has_user_keys:
            # not only bearer_token, check read-write mode (OAuth 2.0 Authorization Code with PKCE)
            try:
                api = self.connect()

                api.get_me()

            except tweepy.Unauthorized as e:
                response.error_message = f'Error connecting to Twitter api: {e}. Check ' + ', '.join(keys)
                logger.error(response.error_message)

                response.success = False

        if response.success is False and self.is_connected is True:
            # force a fresh client on the next connect()
            self.is_connected = False

        return response

    def native_query(self, query_string: str = None):
        """Parse `query_string` as a function call and run it against the API.

        Args:
            query_string: Text parsed by `FuncParser` into a method name and
                its parameters.

        Returns:
            A table Response wrapping the resulting DataFrame.
        """
        method_name, params = FuncParser().from_string(query_string)

        df = self.call_twitter_api(method_name, params)

        return Response(
            RESPONSE_TYPE.TABLE,
            data_frame=df
        )

    @staticmethod
    def _row_matches(row, filters):
        """Return True only if `row` satisfies every (op, key, value) filter."""

        def as_api_value(v):
            # twitter returns ids as string
            return str(v) if isinstance(v, int) else v

        for op, key, value in filters:
            value2 = row.get(key)

            if op in ('!=', '<>'):
                if as_api_value(value) == value2:
                    return False
            elif op in ('==', '='):
                if as_api_value(value) != value2:
                    return False
            elif op in ('in', 'not in'):
                options = value if isinstance(value, (list, tuple)) else [value]
                # normalise list members the same way as scalar values
                found = value2 in [as_api_value(v) for v in options]
                if found != (op == 'in'):
                    return False
            else:
                raise NotImplementedError(f'Unknown filter: {op}')
        return True

    def _apply_filters(self, data, filters):
        """Keep the rows of `data` that satisfy all `filters`.

        Args:
            data: List of dict rows.
            filters: List of (op, key, value); supported ops are
                '=', '==', '!=', '<>', 'in' and 'not in'.

        Returns:
            `data` itself when there are no filters, otherwise a new list.

        Raises:
            NotImplementedError: If a filter uses an unsupported operator.
        """
        if not filters:
            return data

        return [row for row in data if self._row_matches(row, filters)]

    def call_twitter_api(self, method_name: str = None, params: dict = None, filters: list = None):
        """Call a method of the connected client and return the result as a DataFrame.

        List responses are paginated until `params['max_results']` rows are
        collected (or, when it is absent, after the first page). Non-list
        responses are returned as a single row. For methods listed in the
        expansions map, included records are merged in as prefixed columns.

        Args:
            method_name: Name of the client method to call.
            params: Keyword arguments for the method; the dict is not modified.
            filters: Optional (op, key, value) filters applied to each page.

        Returns:
            A pandas DataFrame with one row per returned record.

        Raises:
            RuntimeError: If pagination runs for more than 60 seconds.
        """

        # method > table > columns
        expansions_map = {
            'search_recent_tweets': {
                'users': ['author_id', 'in_reply_to_user_id'],
            },
            'search_all_tweets': {
                'users': ['author_id'],
            },
        }

        # referenced tweet type > output column
        reference_columns = {
            'replied_to': 'in_reply_to_tweet_id',
            'retweeted': 'in_retweeted_to_tweet_id',
            'quoted': 'in_quote_to_tweet_id',
        }

        api = self.connect()
        method = getattr(api, method_name)

        # Work on a copy: paging rewrites max_results / next_token and the
        # caller may reuse its dict. Also tolerates params=None.
        params = dict(params) if params else {}

        # pagination handle

        count_results = params.get('max_results')

        data = []
        includes = defaultdict(list)

        max_page_size = 100
        min_page_size = 10
        left = None

        limit_exec_time = time.time() + 60

        if filters:
            # if we have filters: do big page requests
            params['max_results'] = max_page_size

        while True:
            if time.time() > limit_exec_time:
                raise RuntimeError('Handler request timeout error')

            if count_results is not None:
                left = count_results - len(data)
                if left == 0:
                    break
                elif left < 0:
                    # got more results that we need
                    data = data[:left]
                    break

                if filters or left > max_page_size:
                    # keep big pages while filtering: most rows of a page may
                    # be discarded, so small pages only multiply requests
                    params['max_results'] = max_page_size
                elif left < min_page_size:
                    params['max_results'] = min_page_size
                else:
                    params['max_results'] = left

            logger.debug(f'>>>twitter in: {method_name}({params})')
            resp = method(**params)

            if hasattr(resp, 'includes'):
                for table, records in resp.includes.items():
                    includes[table].extend([r.data for r in records])

            if isinstance(resp.data, list):
                chunk = [r.data for r in resp.data]
            else:
                # single-object (or empty) response: nothing to paginate
                if isinstance(resp.data, dict):
                    data.append(resp.data)
                if hasattr(resp.data, 'data') and isinstance(resp.data.data, dict):
                    data.append(resp.data.data)
                break

            # unwind columns (only the first referenced tweet is considered)
            for row in chunk:
                refs = row.get('referenced_tweets')
                if isinstance(refs, list) and len(refs) > 0:
                    column = reference_columns.get(refs[0]['type'])
                    if column is not None:
                        row[column] = refs[0]['id']

            if filters:
                chunk = self._apply_filters(chunk, filters)

            # limit output
            if left is not None:
                chunk = chunk[:left]

            data.extend(chunk)
            # next page ?
            if count_results is not None and hasattr(resp, 'meta') and 'next_token' in resp.meta:
                params['next_token'] = resp.meta['next_token']
            else:
                break

        df = pd.DataFrame(data)

        # enrich
        expansions = expansions_map.get(method_name)
        if expansions is not None:
            for table, records in includes.items():
                if table not in expansions:
                    continue

                df_ref = pd.DataFrame(records)

                for col_id in expansions[table]:
                    col = col_id[:-3]  # cut _id
                    if col_id not in df.columns:
                        continue

                    # prefix every included column, e.g. id -> author_id
                    col_map = {
                        col_ref: f'{col}_{col_ref}'
                        for col_ref in df_ref.columns
                    }
                    df_ref2 = df_ref.rename(columns=col_map)
                    if col_id not in df_ref2.columns:
                        # no included records to join on
                        continue
                    df_ref2 = df_ref2.drop_duplicates(col_id)

                    df = df.merge(df_ref2, on=col_id, how='left')

        return df