Coverage for mindsdb / integrations / handlers / ollama_handler / ollama_handler.py: 0%
89 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import json
2import requests
3from typing import Dict, Optional
5import pandas as pd
7from mindsdb.integrations.libs.base import BaseMLEngine
8from mindsdb.integrations.libs.llm.utils import get_completed_prompts
class OllamaHandler(BaseMLEngine):
    """MindsDB ML engine backed by a locally running Ollama service.

    Two modes are supported, mirroring the Ollama HTTP API endpoints:
      - 'generate':   row-wise text completion
      - 'embeddings': row-wise embedding vectors
    """
    name = "ollama"
    DEFAULT_SERVE_URL = "http://localhost:11434"

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Validate CREATE MODEL arguments and check the Ollama service is reachable.

        Args:
            target: Name of the target column (unused here, kept for interface parity).
            args: Statement arguments; must contain a 'using' clause with 'model_name'.

        Raises:
            Exception: if the USING clause or `model_name` is missing, or the
                Ollama health-check endpoint does not answer with HTTP 200.
        """
        # `args` may be None when no USING clause was provided at all; the
        # original `'using' not in args` would raise TypeError in that case.
        if not args or 'using' not in args:
            raise Exception("Ollama engine requires a USING clause! Refer to its documentation for more details.")
        args = args['using']

        if 'model_name' not in args:
            raise Exception('`model_name` must be provided in the USING clause.')

        # check ollama service health via the lightweight tag-listing endpoint
        connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL)
        status = requests.get(connection + '/api/tags').status_code
        if status != 200:
            raise Exception(f"Ollama service is not working (status `{status}`). Please double check it is running and try again.")  # noqa

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """ Pull LLM artifacts with Ollama API.

        Probes the 'generate' and 'embeddings' endpoints for the requested
        model, pulling the model first if neither responds, then stores the
        resolved arguments (including the chosen 'mode') in model storage.

        Raises:
            Exception: if the model cannot be made to work on any endpoint,
                or if an explicitly requested mode is unsupported by the model.
        """
        # arg setter
        args = args['using']
        args['target'] = target
        connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL)

        def _model_check():
            """Probe each endpoint once; return {endpoint: http_status} (500 on transport error)."""
            responses = {}
            for endpoint in ['generate', 'embeddings']:
                try:
                    code = requests.post(
                        connection + f'/api/{endpoint}',
                        json={
                            'model': args['model_name'],
                            'prompt': 'Hello.',
                        }
                    ).status_code
                    responses[endpoint] = code
                except Exception:
                    # Treat any transport failure the same as a server error.
                    responses[endpoint] = 500
            return responses

        # check model for all supported endpoints
        responses = _model_check()
        if 200 not in responses.values():
            # pull model (blocking operation) and serve
            # TODO: point to the engine storage folder instead of default location
            requests.post(connection + '/api/pull', json={'name': args['model_name']})
            # try one last time
            responses = _model_check()
            if 200 not in responses.values():
                raise Exception(f"Ollama model `{args['model_name']}` is not working correctly. Please try pulling this model manually, check it works correctly and try again.")  # noqa

        # an endpoint is usable iff the probe returned HTTP 200
        supported_modes = {k: True if v == 200 else False for k, v in responses.items()}

        # check if a mode has been provided and if it is valid
        runnable_modes = [mode for mode, supported in supported_modes.items() if supported]
        if 'mode' in args:
            if args['mode'] not in runnable_modes:
                raise Exception(f"Mode `{args['mode']}` is not supported by the model `{args['model_name']}`.")

        # if a mode has not been provided, check if the model supports only one mode
        # if it does, set it as the default mode
        # if it supports multiple modes, set the default mode to 'generate'
        else:
            if len(runnable_modes) == 1:
                args['mode'] = runnable_modes[0]
            else:
                args['mode'] = 'generate'

        self.model_storage.json_set('args', args)

    def predict(self, df: pd.DataFrame, args: Optional[Dict] = None) -> pd.DataFrame:
        """
        Generate text completions with the local LLM.
        Args:
            df (pd.DataFrame): The input DataFrame containing data to predict.
            args (Optional[Dict]): Additional arguments for prediction parameters.
        Returns:
            pd.DataFrame: The DataFrame containing row-wise text completions
                (or embedding lists when the stored mode is 'embeddings').
        """
        # setup — `args` is Optional, so guard against None before .get()
        pred_args = (args or {}).get('predict_params', {})
        args = self.model_storage.json_get('args')
        model_name, target_col = args['model_name'], args['target']
        prompt_template = pred_args.get('prompt_template',
                                        args.get('prompt_template', 'Answer the following question: {{{{text}}}}'))

        # prepare prompts
        prompts, empty_prompt_ids = get_completed_prompts(prompt_template, df)
        df['__mdb_prompt'] = prompts  # NOTE: mutates the caller's frame; kept for compatibility

        # setup endpoint
        endpoint = args.get('mode', 'generate')

        # call llm
        completions = []
        connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL)
        for i, row in df.iterrows():
            if i not in empty_prompt_ids:
                raw_output = requests.post(
                    connection + f'/api/{endpoint}',
                    json={
                        'model': model_name,
                        'prompt': row['__mdb_prompt'],
                    }
                )
                lines = raw_output.content.decode().split('\n')  # stream of output tokens

                # each non-empty line is a JSON chunk carrying either a text
                # token ('response') or an embedding vector ('embedding')
                values = []
                for line in lines:
                    if line != '':
                        info = json.loads(line)
                        if 'response' in info:
                            token = info['response']
                            values.append(token)
                        elif 'embedding' in info:
                            embedding = info['embedding']
                            values.append(embedding)

                if endpoint == 'embeddings':
                    completions.append(values)
                else:
                    completions.append(''.join(values))
            else:
                # rows whose prompt could not be completed yield an empty answer
                completions.append('')

        # consolidate output into a single object-typed column; building the
        # frame from a dict keeps list-valued completions (embeddings mode) in
        # one column instead of exploding them into N columns, which would
        # make the previous `data.columns = [target_col]` raise ValueError
        data = pd.DataFrame({target_col: completions})
        return data

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Return model metadata.

        Args:
            attribute: 'features' returns the target/prompt-template mapping;
                anything else returns model info fetched from the Ollama
                `/api/show` endpoint.

        Returns:
            pd.DataFrame: single-row frame of the requested metadata.
        """
        args = self.model_storage.json_get('args')
        model_name, target_col = args['model_name'], args['target']
        prompt_template = args.get('prompt_template', 'Answer the following question: {{{{text}}}}')

        if attribute == "features":
            return pd.DataFrame([[target_col, prompt_template]], columns=['target_column', 'mindsdb_prompt_template'])

        # get model info
        else:
            connection = args.get('ollama_serve_url', OllamaHandler.DEFAULT_SERVE_URL)
            model_info = requests.post(connection + '/api/show', json={'name': model_name}).json()
            return pd.DataFrame([[
                model_name,
                model_info.get('license', 'N/A'),
                model_info.get('modelfile', 'N/A'),
                model_info.get('parameters', 'N/A'),
                model_info.get('template', 'N/A'),
            ]],
                columns=[
                    'model_type',
                    'license',
                    'modelfile',
                    'parameters',
                    'ollama_base_template',
            ])