Coverage for mindsdb / integrations / handlers / cohere_handler / cohere_handler.py: 0%
43 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional, Dict
3import cohere
4import pandas as pd
6from mindsdb.integrations.libs.base import BaseMLEngine
8from mindsdb.utilities import log
10from mindsdb.integrations.utilities.handler_utils import get_api_key
13logger = log.getLogger(__name__)
class CohereHandler(BaseMLEngine):
    """
    Integration with the Cohere Python Library.

    The task is chosen with ``USING task = ...``:
      * ``text-summarization`` -- summarizes each value of the input column.
      * ``text-generation``    -- generates text from each value of the input column.
    The input column is chosen with ``USING column = ...``.
    """
    name = 'cohere'

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """
        Validate and persist the model arguments; no training takes place.

        :param target: name of the output column that ``predict`` fills.
        :param df: unused by this engine.
        :param args: model arguments; must contain a ``using`` entry.
        :raises Exception: if ``args`` is missing or has no ``using`` entry.
        """
        # ``args`` defaults to None, so guard it before the membership test;
        # otherwise ``'using' in None`` raises TypeError instead of this message.
        if not args or 'using' not in args:
            raise Exception("Cohere engine requires a USING clause! Refer to its documentation for more details.")

        # Copy to avoid mutating the caller's dict, and record the target so
        # ``predict`` can always name its output column.
        args = dict(args)
        args.setdefault('target', target)

        self.generative = True
        self.model_storage.json_set('args', args)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Run the configured Cohere task on every value of the input column.

        :param df: input rows; must contain the column named by ``USING column``.
        :param args: ignored; the arguments persisted by ``create`` are used instead.
        :return: DataFrame indexed like ``df`` with a single column named after the
                 stored ``target``, holding one API result per input row.
        :raises RuntimeError: if no input data is given or the input column is absent.
        :raises Exception: if the configured task is not supported.
        """
        args = self.model_storage.json_get('args')

        logger.info(f"Input keys: {list(args.keys())}!")

        using = args['using']
        input_column = using['column']

        if df is None:
            raise RuntimeError('No input data provided')
        if input_column not in df.columns:
            raise RuntimeError(f'Column "{input_column}" not found in input data')

        task = using['task']
        if task == 'text-summarization':
            predict_fn = self.predict_text_summary
        elif task == 'text-generation':
            predict_fn = self.predict_text_generation
        else:
            raise Exception(f"Task {task} is not supported!")

        # Resolve the API key and build the client once per call, not once per row.
        client = self._get_client(args)
        predictions = df[input_column].apply(lambda text: predict_fn(text, client=client))

        return pd.DataFrame({args['target']: predictions}, index=df.index)

    def _get_client(self, args: Optional[Dict] = None) -> 'cohere.Client':
        """
        Build a ``cohere.Client``.

        The API key is resolved by ``get_api_key`` from the ``using`` arguments
        and the engine storage.

        :param args: stored model arguments; loaded from model storage when omitted.
        """
        if args is None:
            args = self.model_storage.json_get('args')
        api_key = get_api_key('cohere', args["using"], self.engine_storage, strict=False)
        return cohere.Client(api_key)

    def predict_text_summary(self, text, client=None):
        """
        Summarize ``text`` with the Cohere summarize endpoint.

        :param text: text to summarize.
        :param client: optional pre-built ``cohere.Client``; one is created when omitted.
        :return: the ``summary`` attribute of the API response.
        """
        co = client if client is not None else self._get_client()
        # Pass by keyword so the call does not depend on the SDK's parameter order.
        response = co.summarize(text=text)
        return response.summary

    def predict_text_generation(self, text, client=None):
        """
        Generate text from the prompt ``text`` with the Cohere generate endpoint.

        :param text: prompt to send to the API.
        :param client: optional pre-built ``cohere.Client``; one is created when omitted.
        :return: the text of the first generation in the API response.
        """
        co = client if client is not None else self._get_client()
        # Pass by keyword so the call does not depend on the SDK's parameter order.
        response = co.generate(prompt=text)
        return response.generations[0].text