Coverage for mindsdb / integrations / handlers / reddit_handler / reddit_tables.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import pandas as pd 

2from mindsdb.integrations.libs.api_handler import APITable 

3from mindsdb_sql_parser import ast 

4from mindsdb.integrations.utilities.sql_utils import extract_comparison_conditions 

5 

6 

class CommentTable(APITable):
    '''Table that exposes the comments of a single Reddit submission.'''

    def select(self, query: ast.Select) -> pd.DataFrame:
        '''Select data from the comment table and return it as a pandas DataFrame.

        The WHERE clause must contain a ``submission_id = <id>`` condition;
        the first such condition decides which submission is read.

        Args:
            query (ast.Select): The SQL query to be executed.

        Returns:
            pandas.DataFrame: One row per comment, restricted to the columns
            named in the query targets.

        Raises:
            ValueError: If the query has no ``submission_id = ...`` condition.
        '''
        reddit = self.handler.connect()

        submission_id = None
        for condition in extract_comparison_conditions(query.where):
            if condition[0] == '=' and condition[1] == 'submission_id':
                submission_id = condition[2]
                break

        if submission_id is None:
            raise ValueError('Submission ID is missing in the SQL query')

        submission = reddit.submission(id=submission_id)
        # NOTE(review): per the PRAW docs, limit=None resolves every
        # "more comments" placeholder so .list() yields only real comments.
        submission.comments.replace_more(limit=None)

        rows = [
            {
                'id': comment.id,
                'body': comment.body,
                # The author object may be falsy (the original code guards for
                # it), in which case no name is available.
                'author': comment.author.name if comment.author else None,
                'created_utc': comment.created_utc,
                'score': comment.score,
                'permalink': comment.permalink,
                'ups': comment.ups,
                'downs': comment.downs,
                'subreddit': comment.subreddit.display_name,
            }
            for comment in submission.comments.list()
        ]

        # filter_columns returns a new frame, so its result must be returned.
        return self.filter_columns(pd.DataFrame(rows), query)

    def get_columns(self):
        '''Get the list of column names for the comment table.

        Returns:
            list: A list of column names for the comment table.
        '''
        return [
            'id',
            'body',
            'author',
            'created_utc',
            'permalink',
            'score',
            'ups',
            'downs',
            'subreddit',
        ]

    def filter_columns(self, result: pd.DataFrame, query: ast.Select = None) -> pd.DataFrame:
        '''Restrict a result frame to the columns requested by the query.

        Args:
            result (pandas.DataFrame): The unfiltered comment rows.
            query (ast.Select): The query whose targets choose the columns.
                When it is None, or names no star/identifier target, the
                frame is returned unchanged.

        Returns:
            pandas.DataFrame: A frame holding the requested columns in the
            requested order. Requested columns absent from ``result`` are
            filled with None. The input frame is not modified.
        '''
        columns = []
        if query is not None:
            for target in query.targets:
                if isinstance(target, ast.Star):
                    columns = self.get_columns()
                    break
                elif isinstance(target, ast.Identifier):
                    # The last part is the bare column name (the same
                    # accessor SubmissionTable uses).
                    columns.append(target.parts[-1])

        if not columns:
            return result

        columns = [name.lower() for name in columns]

        if len(result) == 0:
            # A frame built from zero rows has no columns to index into.
            return pd.DataFrame([], columns=columns)

        missing = [name for name in columns if name not in result.columns]
        if missing:
            # assign() builds a new frame instead of mutating the caller's.
            result = result.assign(**dict.fromkeys(missing))

        return result[columns]

81 

82 

class SubmissionTable(APITable):
    '''Table that exposes the submissions of a subreddit.'''

    def select(self, query: ast.Select) -> pd.DataFrame:
        '''Select data from the submission table and return it as a pandas DataFrame.

        Recognised equality conditions in the WHERE clause:

        * ``subreddit``: the subreddit to read. Without it an empty frame
          is returned.
        * ``sort_type``: one of ``new``, ``rising``, ``controversial`` or
          ``top``; any other value (or none) selects the ``hot`` listing.
        * ``items``: the number of submissions to request. When omitted,
          the listing is called without a limit argument, so the client
          library's own default applies.

        Args:
            query (ast.Select): The SQL query to be executed.

        Returns:
            pandas.DataFrame: One row per submission, restricted to the
            requested columns and to the query's LIMIT, if any.
        '''
        reddit = self.handler.connect()

        subreddit_name = None
        sort_type = None
        # Must be pre-bound: the WHERE clause is not required to mention it.
        items = None
        for condition in extract_comparison_conditions(query.where):
            if condition[0] != '=':
                continue
            if condition[1] == 'subreddit':
                subreddit_name = condition[2]
            elif condition[1] == 'sort_type':
                sort_type = condition[2]
            elif condition[1] == 'items':
                items = int(condition[2])

        if not sort_type:
            sort_type = 'hot'
        if not subreddit_name:
            return pd.DataFrame()

        subreddit = reddit.subreddit(subreddit_name)
        # Only these names are looked up dynamically; anything else is 'hot'.
        if sort_type in ('new', 'rising', 'controversial', 'top'):
            listing = getattr(subreddit, sort_type)
        else:
            listing = subreddit.hot
        submissions = listing() if items is None else listing(limit=items)

        rows = [
            {
                'id': submission.id,
                'title': submission.title,
                # The author object may be falsy (the original code guards for
                # it), in which case no name is available.
                'author': submission.author.name if submission.author else None,
                'created_utc': submission.created_utc,
                'score': submission.score,
                'num_comments': submission.num_comments,
                'permalink': submission.permalink,
                'url': submission.url,
                'ups': submission.ups,
                'downs': submission.downs,
                'num_crossposts': submission.num_crossposts,
                'subreddit': submission.subreddit.display_name,
                'selftext': submission.selftext,
            }
            for submission in submissions
        ]

        # filter_columns returns a new frame, so its result must be returned.
        return self.filter_columns(pd.DataFrame(rows), query)

    def get_columns(self):
        '''Get the list of column names for the submission table.

        Returns:
            list: A list of column names for the submission table.
        '''
        return [
            'id',
            'title',
            'author',
            'created_utc',
            'permalink',
            # 'url' is produced by select(), so it is listed here to keep
            # ``SELECT *`` from dropping it.
            'url',
            'num_comments',
            'score',
            'ups',
            'downs',
            'num_crossposts',
            'subreddit',
            'selftext',
        ]

    def filter_columns(self, result: pd.DataFrame, query: ast.Select = None):
        '''Restrict a result frame to the requested columns and row limit.

        Args:
            result (pandas.DataFrame): The unfiltered submission rows.
            query (ast.Select): The query whose targets choose the columns
                and whose LIMIT caps the rows. When None, every column from
                get_columns() is returned and no limit is applied.

        Returns:
            pandas.DataFrame: A frame holding the requested columns
            (lower-cased) in the requested order. Requested columns absent
            from ``result`` are filled with None. The input frame is not
            modified.

        Raises:
            NotImplementedError: If a query target is neither a star nor an
                identifier.
        '''
        columns = []
        if query is not None:
            for target in query.targets:
                if isinstance(target, ast.Star):
                    columns = self.get_columns()
                    break
                elif isinstance(target, ast.Identifier):
                    columns.append(target.parts[-1])
                else:
                    raise NotImplementedError
        else:
            columns = self.get_columns()

        columns = [name.lower() for name in columns]

        if len(result) == 0:
            # A frame built from zero rows has no columns to index into.
            result = pd.DataFrame([], columns=columns)
        else:
            missing = set(columns) - set(result.columns)
            if missing:
                # assign() builds a new frame instead of mutating the caller's.
                result = result.assign(**dict.fromkeys(missing))

        result = result[columns]

        if query is not None and query.limit is not None:
            return result.head(query.limit.value)

        return result