Coverage for mindsdb / integrations / handlers / reddit_handler / reddit_tables.py: 0%
92 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import pandas as pd
2from mindsdb.integrations.libs.api_handler import APITable
3from mindsdb_sql_parser import ast
4from mindsdb.integrations.utilities.sql_utils import extract_comparison_conditions
class CommentTable(APITable):
    '''Table of all comments posted on a single Reddit submission.'''

    def select(self, query: ast.Select) -> pd.DataFrame:
        '''Select data from the comment table and return it as a pandas DataFrame.

        The WHERE clause must contain ``submission_id = '<id>'``. Every comment
        of that submission is fetched, then projected onto the selected columns
        and truncated to the query's LIMIT (see ``filter_columns``).

        Args:
            query (ast.Select): The SQL query to be executed.

        Returns:
            pandas.DataFrame: A pandas DataFrame containing the selected data.

        Raises:
            ValueError: If the query has no ``submission_id = ...`` condition.
        '''
        reddit = self.handler.connect()

        submission_id = None
        for condition in extract_comparison_conditions(query.where):
            if condition[0] == '=' and condition[1] == 'submission_id':
                submission_id = condition[2]
                break

        if submission_id is None:
            raise ValueError('Submission ID is missing in the SQL query')

        submission = reddit.submission(id=submission_id)
        # limit=None asks PRAW to expand every "load more comments" placeholder,
        # so the flattened .list() below covers the whole comment tree.
        submission.comments.replace_more(limit=None)

        rows = []
        for comment in submission.comments.list():
            rows.append({
                'id': comment.id,
                'body': comment.body,
                'author': comment.author.name if comment.author else None,
                'created_utc': comment.created_utc,
                'score': comment.score,
                'permalink': comment.permalink,
                'ups': comment.ups,
                'downs': comment.downs,
                'subreddit': comment.subreddit.display_name,
            })

        # filter_columns returns the projected frame; it does not modify its
        # argument in place, so its return value must be used.
        return self.filter_columns(pd.DataFrame(rows), query)

    def get_columns(self):
        '''Get the list of column names for the comment table.

        Returns:
            list: A list of column names for the comment table.
        '''
        return [
            'id',
            'body',
            'author',
            'created_utc',
            'permalink',
            'score',
            'ups',
            'downs',
            'subreddit',
        ]

    def filter_columns(self, result: pd.DataFrame, query: ast.Select = None) -> pd.DataFrame:
        '''Project ``result`` onto the query's selected columns and apply its LIMIT.

        Args:
            result (pandas.DataFrame): Rows built by ``select``.
            query (ast.Select): The query being answered. When None, or when it
                selects no plain column identifiers, every column returned by
                ``get_columns`` is kept.

        Returns:
            pandas.DataFrame: A new frame holding the selected columns in query
            order, truncated to the query's LIMIT if one is given. Selected
            columns that are absent from ``result`` (every column, when
            ``result`` is empty) are filled with None.
        '''
        columns = []
        if query is not None:
            for target in query.targets:
                if isinstance(target, ast.Star):
                    columns = self.get_columns()
                    break
                if isinstance(target, ast.Identifier):
                    # parts[-1] drops any table qualifier, e.g. ``c.body`` -> ``body``.
                    columns.append(target.parts[-1])
                # Other target kinds (e.g. function calls) are not projected here.
        if not columns:
            columns = self.get_columns()
        columns = [name.lower() for name in columns]

        if len(result) == 0:
            # A frame built from no rows has no columns; give it the expected header.
            result = pd.DataFrame([], columns=columns)
        else:
            missing = [name for name in columns if name not in result.columns]
            # assign() returns a copy, so the caller's frame is left untouched.
            result = result.assign(**dict.fromkeys(missing))
        result = result[columns]

        if query is not None and query.limit is not None:
            result = result.head(query.limit.value)
        return result
class SubmissionTable(APITable):
    '''Table of submissions (posts) listed from a single subreddit.'''

    # Subreddit listing methods that a ``sort_type`` condition may select.
    _SORT_TYPES = ('hot', 'new', 'rising', 'controversial', 'top')

    def select(self, query: ast.Select) -> pd.DataFrame:
        '''Select data from the submission table and return it as a pandas DataFrame.

        Recognised WHERE conditions (equality only):
            subreddit: name of the subreddit to list. Without it, an empty
                frame with the selected columns is returned.
            sort_type: one of 'hot', 'new', 'rising', 'controversial', 'top'.
                Missing or unrecognised values fall back to 'hot'.
            items: number of submissions to request. When absent, the query's
                LIMIT is used instead. When both are absent, no limit is passed
                and PRAW's default listing size applies.

        Args:
            query (ast.Select): The SQL query to be executed.

        Returns:
            pandas.DataFrame: A pandas DataFrame containing the selected data.
        '''
        reddit = self.handler.connect()

        subreddit_name = None
        sort_type = None
        # Must have a default: a query without an ``items`` condition is valid.
        items = None
        for condition in extract_comparison_conditions(query.where):
            if condition[0] != '=':
                continue
            if condition[1] == 'subreddit':
                subreddit_name = condition[2]
            elif condition[1] == 'sort_type':
                sort_type = condition[2]
            elif condition[1] == 'items':
                items = int(condition[2])

        if not subreddit_name:
            return self.filter_columns(pd.DataFrame(), query)

        if items is None and query.limit is not None:
            # Request no more submissions than the LIMIT will keep.
            items = query.limit.value
        listing_kwargs = {} if items is None else {'limit': items}

        if sort_type not in self._SORT_TYPES:
            sort_type = 'hot'
        # sort_type is restricted to the allow-list above, so getattr can only
        # reach one of the subreddit's listing methods.
        listing = getattr(reddit.subreddit(subreddit_name), sort_type)
        submissions = listing(**listing_kwargs)

        rows = []
        for submission in submissions:
            rows.append({
                'id': submission.id,
                'title': submission.title,
                'author': submission.author.name if submission.author else None,
                'created_utc': submission.created_utc,
                'score': submission.score,
                'num_comments': submission.num_comments,
                'permalink': submission.permalink,
                'url': submission.url,
                'ups': submission.ups,
                'downs': submission.downs,
                'num_crossposts': submission.num_crossposts,
                'subreddit': submission.subreddit.display_name,
                'selftext': submission.selftext,
            })

        # filter_columns returns the projected/limited frame; it does not
        # modify its argument in place, so its return value must be used.
        return self.filter_columns(pd.DataFrame(rows), query)

    def get_columns(self):
        '''Get the list of column names for the submission table.

        Every key of the row dicts built in ``select`` is listed here, so
        ``SELECT *`` returns all fetched fields (including ``url``).

        Returns:
            list: A list of column names for the submission table.
        '''
        return [
            'id',
            'title',
            'author',
            'created_utc',
            'permalink',
            'url',
            'num_comments',
            'score',
            'ups',
            'downs',
            'num_crossposts',
            'subreddit',
            'selftext',
        ]

    def filter_columns(self, result: pd.DataFrame, query: ast.Select = None) -> pd.DataFrame:
        '''Project ``result`` onto the query's selected columns and apply its LIMIT.

        Args:
            result (pandas.DataFrame): Rows built by ``select``.
            query (ast.Select): The query being answered. When None, every
                column returned by ``get_columns`` is kept.

        Returns:
            pandas.DataFrame: A new frame holding the selected columns in query
            order, truncated to the query's LIMIT if one is given. Selected
            columns that are absent from ``result`` (every column, when
            ``result`` is empty) are filled with None.

        Raises:
            NotImplementedError: If a select target is neither ``*`` nor a
                plain column identifier.
        '''
        columns = []
        if query is not None:
            for target in query.targets:
                if isinstance(target, ast.Star):
                    columns = self.get_columns()
                    break
                elif isinstance(target, ast.Identifier):
                    # parts[-1] drops any table qualifier, e.g. ``s.title`` -> ``title``.
                    columns.append(target.parts[-1])
                else:
                    raise NotImplementedError(f'Unsupported select target: {target}')
        else:
            columns = self.get_columns()

        columns = [name.lower() for name in columns]

        if len(result) == 0:
            # A frame built from no rows has no columns; give it the expected header.
            result = pd.DataFrame([], columns=columns)
        else:
            missing = [name for name in columns if name not in result.columns]
            # assign() returns a copy, so the caller's frame is left untouched.
            result = result.assign(**dict.fromkeys(missing))
        result = result[columns]

        if query is not None and query.limit is not None:
            result = result.head(query.limit.value)
        return result