Coverage for mindsdb / integrations / handlers / ray_serve_handler / ray_serve_handler.py: 0%
73 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import io
2import json
4import requests
5from typing import Dict, Optional
7import pandas as pd
8import pyarrow.parquet as pq
10from mindsdb.integrations.libs.base import BaseMLEngine
class RayServeException(Exception):
    """Error reported by, or while decoding the response of, a Ray Serve endpoint."""
class RayServeHandler(BaseMLEngine):
    """
    ML engine that delegates training and inference to a Ray Serve deployment over HTTP.

    The Ray Serve integration engine needs to have a working connection to Ray Serve. For this:
        - A Ray Serve server should be running
        - It must expose one HTTP endpoint for training and one for prediction

    Parameters read from the ``USING`` clause:
        - ``train_url`` (required): endpoint that receives the training dataframe.
        - ``predict_url`` (required): endpoint that receives the rows to predict.
        - ``is_parquet`` (optional, default ``False``): send dataframes as a parquet file
          upload instead of a JSON list of records.
        - ``predict_params`` (optional): parameters forwarded to the prediction endpoint
          as ``pred_args``.
    """
    name = 'ray_serve'

    # Errors `requests` raises for a malformed URL (unsupported scheme,
    # missing scheme, or otherwise unparsable URL).
    _INVALID_URL_ERRORS = (
        requests.exceptions.InvalidSchema,
        requests.exceptions.MissingSchema,
        requests.exceptions.InvalidURL,
    )

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Check that the ``USING`` clause provides both endpoint URLs.

        :param target: name of the column to predict (not used by this check).
        :param args: model definition; its ``'using'`` entry must contain non-empty
            ``train_url`` and ``predict_url`` values.
        :raises Exception: if ``using`` is missing/empty or either URL is missing/empty.
        """
        # Treat a missing definition like an empty one so the user gets the helpful message below.
        using = (args or {}).get('using')
        if not using:
            raise Exception("Error: This engine requires some parameters via the 'using' clause. Please refer to the documentation of the Ray Serve handler and try again.")  # noqa
        if not using.get('train_url'):
            raise Exception("Error: Please provide a URL for the training endpoint.")
        if not using.get('predict_url'):
            raise Exception("Error: Please provide a URL for the prediction endpoint.")

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Store the engine parameters and send the training data to ``train_url``.

        :param target: name of the column to predict; stored together with the ``using`` args.
        :param df: training data posted to the training endpoint.
        :param args: model definition; only its ``'using'`` entry is used.
        :raises Exception: if ``train_url`` is malformed.
        :raises RayServeException: if the endpoint does not answer with JSON ``{"status": "ok"}``.
        """
        # TODO: use join_learn_process to notify users when ray has finished the training process
        # Ignore the rest of the problem definition; copy so the caller's dict is not mutated.
        args = {**args['using'], 'target': target}
        self.model_storage.json_set('args', args)
        resp = self._post_dataframe(
            args['train_url'],
            df,
            is_parquet=args.get('is_parquet', False),
            form_data={"args": json.dumps(args), "target": target},
            json_fields={'target': target, 'args': args},
            endpoint_name='training',
        )
        self._check_train_response(resp)

    def predict(self, df, args=None):
        """Send ``df`` to ``predict_url`` and return the predictions as a dataframe.

        Precedence of parameters (later wins): args stored at training time, then
        ``args``, then ``args['predict_params']``. Only ``predict_params`` is forwarded
        to the endpoint (as ``pred_args``).

        :param df: rows to predict.
        :param args: optional per-call parameters.
        :returns: the endpoint's response as a dataframe. Its ``prediction`` column is
            renamed to the stored target column.
        :raises Exception: if ``predict_url`` is malformed.
        :raises RayServeException: if the response cannot be decoded or has no ``prediction``.
        """
        args = {**self.model_storage.json_get('args'), **(args or {})}  # merge incoming args
        pred_args = args.get('predict_params') or {}
        args = {**args, **pred_args}  # merge pred_args
        is_parquet = args.get('is_parquet', False)
        if is_parquet:
            # NOTE(review): this mutates the caller's dataframe attrs. It presumably ships
            # pred_args inside the parquet metadata as well — confirm the server relies on it.
            df.attrs['pred_args'] = pred_args
        resp = self._post_dataframe(
            args['predict_url'],
            df,
            is_parquet=is_parquet,
            form_data={"pred_args": json.dumps(pred_args)},
            json_fields={'pred_args': pred_args},
            endpoint_name='prediction',
        )
        response = self._decode_predict_response(resp)

        # A parquet reply is a DataFrame (membership checks columns); a JSON reply must be an object.
        if not isinstance(response, (dict, pd.DataFrame)) or 'prediction' not in response:
            # something went wrong on the server side
            raise RayServeException(f"Error: {response}")
        target = args['target']
        if target != 'prediction':
            # rename prediction to target
            response[target] = response.pop('prediction')
        return pd.DataFrame(response)

    def describe(self, key: Optional[str] = None) -> pd.DataFrame:
        """Return a one-row dataframe with the stored endpoint URLs and target column.

        :param key: currently ignored.
        """
        args = self.model_storage.json_get('args')
        description = {
            'TRAIN_URL': [args['train_url']],
            'PREDICT_URL': [args['predict_url']],
            'TARGET': [args['target']],
        }
        return pd.DataFrame.from_dict(description)

    @classmethod
    def _post_dataframe(cls, url, df, *, is_parquet, form_data, json_fields, endpoint_name):
        """POST ``df`` to ``url`` as a parquet upload or as JSON records; return the response.

        :param url: endpoint URL.
        :param df: dataframe to send.
        :param is_parquet: if truthy, upload ``df`` as a parquet file under the ``df`` field.
        :param form_data: form fields sent alongside the parquet file (parquet mode only).
        :param json_fields: keys added after ``'df'`` in the JSON body (JSON mode only).
        :param endpoint_name: label used in the invalid-URL error message.
        :raises Exception: if ``url`` is malformed.
        """
        # NOTE(review): no timeout is set. Training calls may legitimately be slow, so a
        # timeout would need to be configurable rather than hard-coded.
        try:
            if is_parquet:
                buffer = io.BytesIO()
                df.to_parquet(buffer)
                return requests.post(
                    url,
                    files={"df": ("df", buffer.getvalue(), "application/octet-stream")},
                    data=form_data,
                )
            return requests.post(
                url,
                json={'df': df.to_json(orient='records'), **json_fields},
                headers={'content-type': 'application/json; format=pandas-records'},
            )
        except cls._INVALID_URL_ERRORS as err:
            raise Exception(f"Error: The URL provided for the {endpoint_name} endpoint is invalid.") from err

    @staticmethod
    def _check_train_response(resp):
        """Raise ``RayServeException`` unless ``resp`` is JSON with ``status == 'ok'``."""
        try:
            payload = resp.json()
        except ValueError as err:
            # requests' JSONDecodeError may derive from simplejson; ValueError covers both.
            raise RayServeException(f"Error: {resp.text}") from err
        status = payload.get('status') if isinstance(payload, dict) else None
        if status != 'ok':
            # Report the status when present, otherwise the whole payload.
            raise RayServeException(f"Error: {status if status is not None else payload}")

    @staticmethod
    def _decode_predict_response(resp):
        """Decode a prediction response.

        The body is read as parquet when the Content-Type header contains
        ``application/octet-stream``, and as JSON otherwise.

        :raises RayServeException: if the body cannot be decoded.
        """
        content_type = resp.headers.get("Content-Type", "")
        if "application/octet-stream" in content_type:
            try:
                return pq.read_table(io.BytesIO(resp.content)).to_pandas()
            except Exception as err:
                raise RayServeException("Error: Could not decode parquet.") from err
        try:
            return resp.json()
        except ValueError as err:
            raise RayServeException(f"Error: {resp.text}") from err