Coverage for mindsdb / integrations / handlers / ray_serve_handler / ray_serve_handler.py: 0%

73 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import io 

2import json 

3 

4import requests 

5from typing import Dict, Optional 

6 

7import pandas as pd 

8import pyarrow.parquet as pq 

9 

10from mindsdb.integrations.libs.base import BaseMLEngine 

11 

12 

class RayServeException(Exception):
    """Error raised by the Ray Serve handler when an endpoint response signals a failure."""

15 

16 

class RayServeHandler(BaseMLEngine):
    """
    ML engine that forwards training and prediction requests to Ray Serve endpoints.

    The Ray Serve integration engine needs to have a working connection to Ray Serve. For this:
        - A Ray Serve server should be running

    The model is configured through the `using` clause, which must provide
    `train_url` and `predict_url`. When `is_parquet` is truthy the dataframe is
    uploaded as a parquet file; otherwise it is sent as JSON records.
    """
    name = 'ray_serve'

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        """Check that the `using` clause carries both endpoint URLs.

        Args:
            target: name of the column to predict (not inspected here).
            args: model definition; only its 'using' entry is read.

        Raises:
            Exception: if 'using' is missing/empty or lacks 'train_url' or 'predict_url'.
        """
        # `args` defaults to None; treat that the same as a missing 'using' clause
        # so the caller gets the descriptive message instead of an AttributeError.
        using = (args or {}).get('using')
        if not using:
            raise Exception("Error: This engine requires some parameters via the 'using' clause. Please refer to the documentation of the Ray Serve handler and try again.")  # noqa
        if not using.get('train_url'):
            raise Exception("Error: Please provide a URL for the training endpoint.")
        if not using.get('predict_url'):
            raise Exception("Error: Please provide a URL for the prediction endpoint.")

    def create(self, target: str, df: Optional[pd.DataFrame] = None, args: Optional[Dict] = None) -> None:
        """Store the model arguments and send the training data to `train_url`.

        Args:
            target: name of the column to predict; stored under args['target'].
            df: training data posted to the endpoint.
            args: model definition; only its 'using' entry is kept.

        Raises:
            Exception: if `train_url` has a schema `requests` cannot handle.
            RayServeException: if the endpoint does not answer with a JSON
                object whose 'status' is 'ok'.
        """
        # TODO: use join_learn_process to notify users when ray has finished the training process
        args = args['using']  # ignore the rest of the problem definition
        args['target'] = target
        self.model_storage.json_set('args', args)

        train_url = args['train_url']
        try:
            if args.get('is_parquet', False):
                buffer = io.BytesIO()
                df.to_parquet(buffer)
                resp = requests.post(
                    train_url,
                    files={"df": ("df", buffer.getvalue(), "application/octet-stream")},
                    data={"args": json.dumps(args), "target": target},
                )
            else:
                resp = requests.post(
                    train_url,
                    json={'df': df.to_json(orient='records'), 'target': target, 'args': args},
                    headers={'content-type': 'application/json; format=pandas-records'},
                )
        except requests.exceptions.InvalidSchema as exc:
            raise Exception("Error: The URL provided for the training endpoint is invalid.") from exc

        try:
            payload = resp.json()
        except ValueError:
            # ValueError is the common base of the JSON decode errors that
            # `Response.json()` can raise (stdlib json or simplejson flavoured).
            error = resp.text or 'The training endpoint returned an empty, non-JSON response.'
        else:
            status = payload.get('status') if isinstance(payload, dict) else None
            if status == 'ok':
                return
            # Fall back to the whole payload when there is no usable 'status',
            # rather than failing with a KeyError or silently succeeding.
            error = status or payload

        raise RayServeException(f"Error: {error}")

    def predict(self, df, args=None):
        """Send `df` to `predict_url` and return the predictions as a dataframe.

        Args:
            df: input rows to predict on.
            args: call-time arguments; merged over the stored model arguments.
                Its optional 'predict_params' entry is forwarded to the
                endpoint as `pred_args` and also merged into the arguments.

        Returns:
            pd.DataFrame built from the endpoint response, with the
            'prediction' field renamed to the stored target name.

        Raises:
            RayServeException: if the response cannot be decoded or does not
                contain a 'prediction' field.
        """
        args = {**(self.model_storage.json_get('args')), **(args or {})}  # merge incoming args
        pred_args = args.get('predict_params') or {}
        args = {**args, **pred_args}  # merge pred_args

        predict_url = args['predict_url']
        if args.get('is_parquet', False):
            buffer = io.BytesIO()
            df.attrs['pred_args'] = pred_args
            df.to_parquet(buffer)
            resp = requests.post(
                predict_url,
                files={"df": ("df", buffer.getvalue(), "application/octet-stream")},
                data={"pred_args": json.dumps(pred_args)},
            )
        else:
            resp = requests.post(
                predict_url,
                json={'df': df.to_json(orient='records'), 'pred_args': pred_args},
                headers={'content-type': 'application/json; format=pandas-records'},
            )

        # Decode failures raise immediately: continuing would reference an
        # unbound `response` and mask the real problem with an UnboundLocalError.
        content_type = resp.headers.get("Content-Type", "")
        if "application/octet-stream" in content_type:
            try:
                response = pq.read_table(io.BytesIO(resp.content)).to_pandas()
            except Exception as exc:
                raise RayServeException("Error: Could not decode parquet.") from exc
        else:
            try:
                response = resp.json()
            except ValueError as exc:
                raise RayServeException(f"Error: {resp.text}") from exc

        # Only mappings/dataframes can carry a 'prediction' field; a JSON string
        # or list could match `in` by accident and then fail on `.pop`.
        if isinstance(response, (dict, pd.DataFrame)) and 'prediction' in response:
            target = args['target']
            if target != 'prediction':
                # rename prediction to target
                response[target] = response.pop('prediction')
            return pd.DataFrame(response)

        # something wrong: surface whatever the endpoint sent back
        raise RayServeException(f"Error: {response}")

    def describe(self, key: Optional[str] = None) -> pd.DataFrame:
        """Return a one-row dataframe with the stored train URL, predict URL and target.

        Args:
            key: accepted for interface compatibility; not used.
        """
        args = self.model_storage.json_get('args')
        description = {
            'TRAIN_URL': [args['train_url']],
            'PREDICT_URL': [args['predict_url']],
            'TARGET': [args['target']],
        }
        return pd.DataFrame.from_dict(description)