Coverage for mindsdb / integrations / handlers / newsapi_handler / newsapi_handler.py: 0%
138 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import os
2import urllib
3from typing import Any
5import pandas as pd
6from mindsdb_sql_parser import parse_sql
7from mindsdb_sql_parser import ast
8from newsapi import NewsApiClient
10from mindsdb.api.executor.data_types.response_type import RESPONSE_TYPE
11from mindsdb.integrations.libs.api_handler import APIHandler, APITable
12from mindsdb.integrations.libs.response import HandlerResponse, HandlerStatusResponse
13from mindsdb.integrations.utilities.sql_utils import extract_comparison_conditions
14from mindsdb.utilities.config import Config
class NewsAPIArticleTable(APITable):
    """Virtual ``article`` table served by the handler's ``call_application_api``.

    ``select`` translates the WHERE conditions, LIMIT and ORDER BY of a parsed
    SELECT into keyword arguments for ``NewsApiClient.get_everything`` (via the
    handler), then projects the returned rows onto the requested columns.
    """

    # Articles requested per page; LIMITs above this are split into pages.
    _MAX_PAGE_SIZE = 100
    # Maximum number of comma-separated values accepted in ``sources``.
    _MAX_SOURCES = 20
    # Sort keys accepted by NewsAPI's ``sortBy`` parameter.
    _SORT_FIELDS = ("relevancy", "popularity", "publishedAt")
    # Comparison operators on ``publishedAt``, matched case-insensitively in
    # both symbolic (">=") and named ("GtE") spellings. NewsAPI's from/to
    # bounds are inclusive, so strict and non-strict operators map the same.
    _LOWER_BOUND_OPS = frozenset({">", ">=", "gt", "gte"})
    _UPPER_BOUND_OPS = frozenset({"<", "<=", "lt", "lte"})
    _EQUALITY_OPS = frozenset({"=", "==", "eq"})

    def __init__(self, handler):
        super().__init__(handler)

    def select(self, query: ast.Select) -> pd.DataFrame:
        """Run a SELECT against NewsAPI and return the matching articles.

        Args:
            query: Parsed SELECT statement targeting the ``article`` table.

        Returns:
            DataFrame with the selected columns, holding at most LIMIT rows
            when a LIMIT is given. An empty result still carries the selected
            column names.

        Raises:
            ValueError: too many ``sources``, multiple ORDER BY columns, or an
                unsupported SELECT target.
            NotImplementedError: unsupported ORDER BY column or ``publishedAt``
                operator.
        """
        # Resolve the projection first so invalid targets fail before any API call.
        selected_columns = self._selected_columns(query)
        params = self._build_params(query)

        result = self.handler.call_application_api(params=params)
        if result.empty:
            return pd.DataFrame(columns=selected_columns)

        if query.limit is not None:
            # Whole pages are fetched, so trim the overshoot of the last page.
            result = result.head(query.limit.value)
        return result[selected_columns]

    def _build_params(self, query: ast.Select) -> dict:
        """Translate WHERE / LIMIT / ORDER BY into ``get_everything`` kwargs."""
        params = {}

        for op, column, value in extract_comparison_conditions(query.where):
            if column == "query":
                # NewsApiClient URL-encodes request parameters itself; encoding
                # here as well would double-encode the search text.
                params["q"] = value
            elif column == "sources":
                if len(value.split(",")) > self._MAX_SOURCES:
                    raise ValueError(
                        "The number of items in sources should be 20 or less"
                    )
                params["sources"] = value
            elif column == "publishedAt":
                self._apply_date_condition(params, op, value)
            elif column == "excludedDomains":
                # Output column name -> client keyword argument.
                params["exclude_domains"] = value
            else:
                params[column] = value

        self._apply_paging(params, query)

        if query.order_by:
            if len(query.order_by) != 1:
                raise ValueError(
                    "Multiple order by condition is not supported by the API"
                )
            field = query.order_by[0].field
            # Read the column from the identifier itself: str(OrderBy) also
            # contains the sort direction (e.g. "publishedAt DESC").
            if isinstance(field, ast.Identifier):
                order_column = field.parts[-1]
            else:
                order_column = str(field).split(".")[-1]
            if order_column not in self._SORT_FIELDS:
                raise NotImplementedError("Not supported ordering by this field")
            params["sort_by"] = order_column

        return params

    def _apply_date_condition(self, params: dict, op, value) -> None:
        """Map a ``publishedAt`` comparison onto the client's from/to bounds."""
        op_name = str(op).lower()
        if op_name in self._LOWER_BOUND_OPS:
            # ``from`` is a Python keyword, so the client names it ``from_param``.
            params["from_param"] = value
        elif op_name in self._UPPER_BOUND_OPS:
            params["to"] = value
        elif op_name in self._EQUALITY_OPS:
            params["from_param"] = value
            params["to"] = value
        else:
            raise NotImplementedError(
                f"Operator '{op}' is not supported for publishedAt"
            )

    def _apply_paging(self, params: dict, query: ast.Select) -> None:
        """Set ``page`` (number of pages to fetch) and ``page_size`` from LIMIT."""
        page_size = self._MAX_PAGE_SIZE
        limit = query.limit.value if query.limit is not None else None

        if limit is None:
            params["page_size"] = page_size
            params["page"] = 1
        elif limit > page_size:
            # Full pages, rounded up so that LIMIT rows are covered.
            params["page_size"] = page_size
            params["page"] = (limit + page_size - 1) // page_size
        else:
            params["page_size"] = limit
            params["page"] = 1

    def _selected_columns(self, query: ast.Select) -> list:
        """Return the column names requested by the SELECT targets."""
        columns = []
        for target in query.targets:
            if isinstance(target, ast.Star):
                # ``*`` selects every column, regardless of other targets.
                return self.get_columns()
            if isinstance(target, ast.Identifier):
                columns.append(target.parts[-1])
            else:
                raise ValueError(f"Unknown query target {type(target)}")
        return columns

    def get_columns(self) -> list:
        """Return the column names exposed by the ``article`` table."""
        return [
            "author",
            "title",
            "description",
            "url",
            "urlToImage",
            "publishedAt",
            "content",
            "source_id",
            "source_name",
            "query",
            "searchIn",
            "domains",
            "excludedDomains",
        ]
class NewsAPIHandler(APIHandler):
    """MindsDB API handler exposing NewsAPI search results as an ``article`` table.

    The API key is resolved, in order, from ``connection_data``, the
    ``NEWSAPI_API_KEY`` environment variable, and the ``newsAPI_handler``
    section of the MindsDB config.
    """

    def __init__(self, name: str, **kwargs):
        super().__init__(name)
        self.api = None
        self._tables = {}
        self.is_connected = False

        # ``connection_data`` may be absent or explicitly None.
        args = kwargs.get("connection_data") or {}
        self.connection_args = {}
        handler_config = Config().get("newsAPI_handler", {})

        for key in ["api_key"]:
            env_var = f"NEWSAPI_{key.upper()}"
            if key in args:
                self.connection_args[key] = args[key]
            elif env_var in os.environ:
                self.connection_args[key] = os.environ[env_var]
            elif key in handler_config:
                self.connection_args[key] = handler_config[key]

        # The client is created lazily by connect(), so a missing key surfaces
        # through check_connection() instead of raising from the constructor.
        self._register_table("article", NewsAPIArticleTable(self))

    def __del__(self):
        # getattr: __init__ may have failed before the flag was set.
        if getattr(self, "is_connected", False) is True:
            self.disconnect()

    def disconnect(self):
        """Mark the handler as disconnected.

        Only the ``is_connected`` flag is reset; the client object is kept.

        Returns:
            The new ``is_connected`` value (always False).
        """
        self.is_connected = False
        return self.is_connected

    def create_connection(self):
        """Build a ``NewsApiClient`` from the resolved connection arguments."""
        return NewsApiClient(**self.connection_args)

    def _register_table(self, table_name: str, table_class: Any):
        """Register ``table_class`` under ``table_name``."""
        self._tables[table_name] = table_class

    def get_table(self, table_name: str):
        """Return the registered table named ``table_name``, or None."""
        return self._tables.get(table_name)

    def connect(self) -> HandlerStatusResponse:
        """Create the API client if needed and mark the handler connected.

        Returns:
            A successful ``HandlerStatusResponse``. Client construction errors
            propagate to the caller.
        """
        if self.is_connected is True:
            return HandlerStatusResponse(success=True)

        self.api = self.create_connection()
        self.is_connected = True
        return HandlerStatusResponse(success=True)

    def check_connection(self) -> HandlerStatusResponse:
        """Verify the credentials with a one-article top-headlines request.

        Returns:
            ``HandlerStatusResponse`` whose ``success`` reflects the outcome.
            On failure, ``error_message`` holds the exception text.
        """
        response = HandlerStatusResponse(False)

        try:
            self.connect()
            self.api.get_top_headlines(page_size=1, page=1)
            response.success = True
        except Exception as e:
            # Python 3 exceptions have no ``.message`` attribute.
            response.error_message = str(e)
            # Force a fresh client on the next connect().
            self.is_connected = False

        return response

    def native_query(self, query: Any):
        """Parse a SQL string and run it against the ``article`` table."""
        parsed_query = parse_sql(query)
        table = self.get_table("article")
        data = table.select(parsed_query)
        return HandlerResponse(RESPONSE_TYPE.TABLE, data_frame=data)

    def call_application_api(
        self, method_name: str = None, params: dict = None
    ) -> pd.DataFrame:
        """Fetch articles via ``get_everything`` across the requested pages.

        Args:
            method_name: Unused; kept for interface compatibility.
            params: Keyword arguments for ``get_everything``. Its ``page``
                entry is the number of pages to fetch (default 1). The
                caller's dict is not modified.

        Returns:
            One row per article. The ``source`` object is flattened into
            ``source_id``/``source_name``, and the request filters are echoed
            into the ``query``/``searchIn``/``domains``/``excludedDomains``
            columns.

        Raises:
            RuntimeError: if a ``get_everything`` call fails.
        """
        if self.is_connected is False:
            self.connect()

        # Copy so per-page updates don't leak into the caller's dict.
        request_params = dict(params or {})
        pages = request_params.get("page", 1)
        page_size = request_params.get("page_size")
        rows = []

        for page in range(1, pages + 1):
            request_params["page"] = page
            try:
                result = self.api.get_everything(**request_params)
            except Exception as e:
                raise RuntimeError(f"API call failed: {e}") from e

            articles = result.get("articles", [])
            rows.extend(
                self._flatten_article(article, request_params) for article in articles
            )

            # A short or empty page means the results are exhausted.
            if not articles or (page_size is not None and len(articles) < page_size):
                break

        return pd.DataFrame(data=rows)

    @staticmethod
    def _flatten_article(article: dict, params: dict) -> dict:
        """Flatten ``article['source']`` and echo request filters into the row."""
        source = article.pop("source", None) or {}
        article["source_id"] = source.get("id")
        article["source_name"] = source.get("name")
        article["query"] = params.get("q")
        article["searchIn"] = params.get("searchIn")
        article["domains"] = params.get("domains")
        article["excludedDomains"] = params.get("exclude_domains")
        return article