Coverage for mindsdb / integrations / handlers / cohere_handler / cohere_handler.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional, Dict 

2 

3import cohere 

4import pandas as pd 

5 

6from mindsdb.integrations.libs.base import BaseMLEngine 

7 

8from mindsdb.utilities import log 

9 

10from mindsdb.integrations.utilities.handler_utils import get_api_key 

11 

12 

logger = log.getLogger(__name__)


class CohereHandler(BaseMLEngine):
    """
    Integration with the Cohere Python Library.

    Supported tasks (selected with ``USING task = ...``):
      * ``text-summarization`` -- summarizes each value of the input column
        via ``cohere.Client.summarize``.
      * ``text-generation`` -- generates text from each value of the input
        column via ``cohere.Client.generate``.
    """
    name = 'cohere'

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """
        Validate and persist the model arguments; no training is performed.

        :param target: name of the output column predictions are returned in.
            Stored as ``args['target']`` unless ``args`` already provides one.
        :param df: unused.
        :param args: model arguments; must contain a ``using`` entry.
        :raises Exception: if ``args`` is missing or has no ``using`` entry.
        """
        # ``not args`` also covers the ``None`` default, which previously
        # crashed with a TypeError on the ``in`` test.
        if not args or 'using' not in args:
            raise Exception("Cohere engine requires a USING clause! Refer to its documentation for more details.")

        # Work on a copy so the caller's dict is not mutated, and make sure
        # ``predict`` can always find the output column name.
        args = dict(args)
        args.setdefault('target', target)

        self.generative = True
        self.model_storage.json_set('args', args)

    def predict(self, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Run the configured Cohere task on every value of the input column.

        The ``args`` parameter is ignored; the arguments persisted by
        ``create`` are used instead.

        :param df: input data; must contain the column named by ``using.column``.
        :returns: a single-column DataFrame named after the model target, with
            one prediction per input row and the same index as ``df``.
        :raises RuntimeError: if ``df`` is missing or lacks the input column.
        :raises Exception: if the configured task is not supported.
        """
        args = self.model_storage.json_get('args')
        using = args['using']

        logger.info("Input keys: %s!", list(args.keys()))

        input_column = using['column']
        if df is None or input_column not in df.columns:
            raise RuntimeError(f'Column "{input_column}" not found in input data')

        task = using['task']
        if task == 'text-summarization':
            predict_one = self.predict_text_summary
        elif task == 'text-generation':
            predict_one = self.predict_text_generation
        else:
            raise Exception(f"Task {task} is not supported!")

        # Resolve the API key and build the client once per batch rather than
        # once per row; skip it entirely when there is nothing to predict.
        client = self._get_client(args) if len(df) else None
        predictions = df[input_column].apply(lambda text: predict_one(text, client=client))

        return pd.DataFrame({args['target']: predictions})

    def _get_client(self, args: Optional[Dict] = None):
        """
        Build a ``cohere.Client`` from the API key resolved for this model.

        :param args: stored model arguments; read from model storage when omitted.
        """
        if args is None:
            args = self.model_storage.json_get('args')
        api_key = get_api_key('cohere', args["using"], self.engine_storage, strict=False)
        return cohere.Client(api_key)

    def predict_text_summary(self, text, client=None):
        """
        Summarize ``text`` with the Cohere API.

        :param text: text to summarize.
        :param client: optional pre-built ``cohere.Client``; a new one is
            created from the stored model arguments when omitted.
        :returns: the ``summary`` field of Cohere's response.
        """
        co = client if client is not None else self._get_client()
        # Keyword argument: newer SDK versions make these parameters keyword-only.
        response = co.summarize(text=text)
        return response.summary

    def predict_text_generation(self, text, client=None):
        """
        Generate text from the prompt ``text`` with the Cohere API.

        :param text: prompt to generate from.
        :param client: optional pre-built ``cohere.Client``; a new one is
            created from the stored model arguments when omitted.
        :returns: the text of the first generation in Cohere's response.
        """
        co = client if client is not None else self._get_client()
        # Keyword argument: newer SDK versions make these parameters keyword-only.
        response = co.generate(prompt=text)
        return response.generations[0].text