Coverage for mindsdb/integrations/handlers/ollama_handler/ollama_handler.py: 0%

89 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import json 

2import requests 

3from typing import Dict, Optional 

4 

5import pandas as pd 

6 

7from mindsdb.integrations.libs.base import BaseMLEngine 

8from mindsdb.integrations.libs.llm.utils import get_completed_prompts 

9 

10 

class OllamaHandler(BaseMLEngine):
    """MindsDB ML engine backed by a locally served Ollama instance.

    Talks to the Ollama HTTP API (default ``http://localhost:11434``) to pull
    models, generate row-wise text completions and compute embeddings.
    """

    name = "ollama"
    DEFAULT_SERVE_URL = "http://localhost:11434"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate CREATE MODEL arguments and check the Ollama service health.

        Args:
            target: target column name (required by the interface, unused here).
            args: statement arguments; must contain a 'using' dict holding at
                least 'model_name', and optionally 'ollama_serve_url'.

        Raises:
            Exception: if the USING clause or `model_name` is missing, or the
                Ollama service is unreachable / does not answer with HTTP 200.
        """
        # `args` defaults to None, so guard before membership-testing it —
        # the original `'using' not in args` raised TypeError on None.
        if args is None or 'using' not in args:
            raise Exception("Ollama engine requires a USING clause! Refer to its documentation for more details.")
        args = args['using']

        if 'model_name' not in args:
            raise Exception('`model_name` must be provided in the USING clause.')

        # check ollama service health; without the try/except an unreachable
        # server surfaced as a raw requests.ConnectionError traceback
        connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL)
        try:
            status = requests.get(connection + '/api/tags').status_code
        except requests.exceptions.RequestException as e:
            raise Exception(f"Ollama service is not reachable at `{connection}`. Please double check it is running and try again.") from e  # noqa
        if status != 200:
            raise Exception(f"Ollama service is not working (status `{status}`). Please double check it is running and try again.")  # noqa

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Pull LLM artifacts with the Ollama API and persist model args.

        Args:
            target: target column name, stored alongside the USING args.
            df: unused; Ollama models are pre-trained.
            args: statement arguments; the 'using' dict is persisted with the
                resolved 'mode' ('generate' or 'embeddings').

        Raises:
            Exception: if the model cannot be pulled / does not respond, or an
                explicitly requested mode is not supported by the model.
        """
        args = args['using']
        args['target'] = target
        connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL)

        def _model_check():
            """Probe each supported endpoint; return {endpoint: status_code}."""
            responses = {}
            for endpoint in ('generate', 'embeddings'):
                try:
                    code = requests.post(
                        connection + f'/api/{endpoint}',
                        json={
                            'model': args['model_name'],
                            'prompt': 'Hello.',
                        }
                    ).status_code
                except Exception:
                    # treat any transport-level failure as a server error
                    responses[endpoint] = 500
                else:
                    responses[endpoint] = code
            return responses

        # check model for all supported endpoints
        responses = _model_check()
        if 200 not in responses.values():
            # pull model (blocking operation) and try one last time
            # TODO: point to the engine storage folder instead of default location
            requests.post(connection + '/api/pull', json={'name': args['model_name']})
            responses = _model_check()
            if 200 not in responses.values():
                raise Exception(f"Ollama model `{args['model_name']}` is not working correctly. Please try pulling this model manually, check it works correctly and try again.")  # noqa

        supported_modes = {endpoint: code == 200 for endpoint, code in responses.items()}
        runnable_modes = [mode for mode, supported in supported_modes.items() if supported]

        # an explicitly requested mode must be one the model actually serves
        if 'mode' in args:
            if args['mode'] not in runnable_modes:
                raise Exception(f"Mode `{args['mode']}` is not supported by the model `{args['model_name']}`.")
        else:
            # default: the single supported mode, otherwise 'generate'
            args['mode'] = runnable_modes[0] if len(runnable_modes) == 1 else 'generate'

        self.model_storage.json_set('args', args)

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """Generate text completions (or embeddings) with the local LLM.

        Args:
            df (pd.DataFrame): input rows used to fill the prompt template.
            args (Optional[Dict]): optional 'predict_params' overrides
                (e.g. a per-query 'prompt_template').

        Returns:
            pd.DataFrame: single column named after the target, one completion
                (or embedding) per row; rows with an empty prompt yield ''.
        """
        # `args` may be None when no predict-time parameters were passed
        pred_args = (args or {}).get('predict_params', {})
        args = self.model_storage.json_get('args')
        model_name, target_col = args['model_name'], args['target']
        prompt_template = pred_args.get(
            'prompt_template',
            args.get('prompt_template', 'Answer the following question: {{{{text}}}}'))

        # work on a copy so the caller's frame is not polluted with the
        # internal __mdb_prompt helper column
        df = df.copy()
        prompts, empty_prompt_ids = get_completed_prompts(prompt_template, df)
        df['__mdb_prompt'] = prompts

        # setup endpoint; loop-invariant serve URL resolved once
        endpoint = args.get('mode', 'generate')
        connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL)

        # call llm row by row
        completions = []
        for i, row in df.iterrows():
            if i in empty_prompt_ids:
                completions.append('')
                continue

            raw_output = requests.post(
                connection + f'/api/{endpoint}',
                json={
                    'model': model_name,
                    'prompt': row['__mdb_prompt'],
                }
            )
            # Ollama streams one JSON object per line
            lines = raw_output.content.decode().split('\n')

            values = []
            for line in lines:
                if line != '':
                    info = json.loads(line)
                    if 'response' in info:
                        values.append(info['response'])   # one generated token
                    elif 'embedding' in info:
                        values.append(info['embedding'])  # full embedding vector

            if endpoint == 'embeddings':
                completions.append(values)
            else:
                completions.append(''.join(values))

        # consolidate output into a single target column
        data = pd.DataFrame(completions)
        data.columns = [target_col]
        return data

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Describe the model.

        Args:
            attribute: 'features' returns the target column and prompt template;
                any other value returns model metadata from Ollama's /api/show.

        Returns:
            pd.DataFrame: one-row frame with the requested information.
        """
        args = self.model_storage.json_get('args')
        model_name, target_col = args['model_name'], args['target']
        prompt_template = args.get('prompt_template', 'Answer the following question: {{{{text}}}}')

        if attribute == "features":
            return pd.DataFrame([[target_col, prompt_template]], columns=['target_column', 'mindsdb_prompt_template'])

        # get model info from the Ollama service
        connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL)
        model_info = requests.post(connection + '/api/show', json={'name': model_name}).json()
        return pd.DataFrame(
            [[
                model_name,
                model_info.get('license', 'N/A'),
                model_info.get('modelfile', 'N/A'),
                model_info.get('parameters', 'N/A'),
                model_info.get('template', 'N/A'),
            ]],
            columns=[
                'model_type',
                'license',
                'modelfile',
                'parameters',
                'ollama_base_template',
            ])