Coverage for mindsdb / integrations / handlers / byom_handler / proc_wrapper.py: 16%

128 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1""" 

2Utility functions used in the 'Bring Your Own Model' (BYOM) engine. 

3 

4These functions interact with interfaces (stdin, stdout), python files, and the actual BYOM engine. 

5 

6In particular, they: 

7 - Wrap and run Python code in a separate Python process. 

8 - Communicate with the parent process through stdin/stdout, using pickle to serialize objects. 

9 

10 

11The flow is as follows: 

12 

13 1. Receive the module code, the method with its parameters, and stored attributes from the parent process 

14 2. A python class object is created from the code 

15 3. The class is instantiated and filled with stored attributes 

16 4. A call to the chosen method of the class is performed with any relevant parameters that were passed 

17 5. Response is generated, appropriately packaged and sent to stdout 

18 6. Exit 

19""" 

20 

21import io 

22import sys 

23import pickle 

24import inspect 

25from enum import Enum 

26 

27import pandas as pd 

28 

29 

class BYOM_METHOD(Enum):
    """Operations the parent process can ask this wrapper to perform.

    The integer value is what arrives in ``params['method']``; ``main``
    converts it back into a member with ``BYOM_METHOD(value)``.
    """

    CHECK = 1       # validate the submitted code (see check_module)
    TRAIN = 2       # instantiate the model class and call train()
    PREDICT = 3     # restore a trained model and call predict()
    FINETUNE = 4    # restore a trained model and call finetune()
    DESCRIBE = 5    # restore a trained model and call describe()
    FUNC_CALL = 6   # call a plain module-level function by name

37 

38 

39def pd_encode(df): 

40 return df.to_parquet(engine='pyarrow') 

41 

42 

43def pd_decode(encoded): 

44 fd = io.BytesIO() 

45 fd.write(encoded) 

46 fd.seek(0) 

47 return pd.read_parquet(fd, engine='pyarrow') 

48 

49 

def encode(obj):
    """Pickle ``obj`` using protocol 5; ``decode`` is the inverse."""
    serialized = pickle.dumps(obj, protocol=5)
    return serialized

52 

53 

def decode(encoded):
    """Unpickle bytes produced by ``encode`` and return the object.

    NOTE(review): ``pickle.loads`` can execute arbitrary code. This is only
    acceptable on the assumption that the payload comes from the trusted
    parent process.
    """
    restored = pickle.loads(encoded)
    return restored

56 

57 

def return_output(obj):
    """Send ``obj`` to the parent process and terminate this process.

    ``obj`` is pickled with ``encode`` and written to file descriptor 1
    (stdout) directly. ``sys.stdout`` itself may already point at stderr
    (``main`` redirects it). Never returns: finishes with ``sys.exit(0)``,
    i.e. raises ``SystemExit``.
    """
    encoded = encode(obj)
    # write the pickled result to stdout (fd 1), bypassing sys.stdout
    with open(1, 'wb') as fd:
        fd.write(encoded)
    sys.exit(0)

64 

65 

def get_input():
    """Read the whole pickled request from stdin and return the decoded object.

    Reads file descriptor 0 (stdin) until EOF, then unpickles the bytes with
    ``decode``.
    """
    # read all bytes from stdin (fd 0)
    with open(0, 'rb') as fd:
        encoded = fd.read()
    obj = decode(encoded)
    return obj

72 

73 

def import_string(code, module_name='model'):
    """Execute the source ``code`` inside a fresh module object and return it.

    The module is intentionally NOT registered in ``sys.modules``. Classes
    created by ``code`` therefore have a ``__module__`` that
    ``inspect.getmodule`` cannot resolve, and ``find_model_class`` relies on
    that to tell them apart from imported classes.

    WARNING: ``code`` runs with full interpreter privileges.
    """
    import types

    fresh_module = types.ModuleType(module_name)
    exec(code, vars(fresh_module))
    return fresh_module

83 

84 

def find_model_class(module):
    """Return the first class defined in ``module`` with ``train`` and ``predict`` methods.

    Candidates are visited in the name-sorted order produced by
    ``inspect.getmembers``. A class whose module ``inspect.getmodule`` can
    resolve is treated as imported and skipped. Returns ``None`` when no
    class qualifies.
    """
    required = {'predict', 'train'}
    for _, candidate in inspect.getmembers(module, inspect.isclass):
        if inspect.getmodule(candidate) is not None:
            # imported class: its module is known to the interpreter
            continue
        method_names = {name for name, _ in inspect.getmembers(candidate, inspect.isfunction)}
        if required <= method_names:
            return candidate
    return None

99 

100 

def _annotation_name(annotation):
    """Return a readable type name for a signature annotation (never raises)."""
    if isinstance(annotation, str):
        # quoted annotations, or every annotation under
        # `from __future__ import annotations`, are already plain names
        return annotation
    name = getattr(annotation, '__name__', None)
    if isinstance(name, str):
        # covers classes such as int/str and inspect's `_empty` marker, which
        # keeps the historical '_empty' value for missing annotations
        return name
    # objects without a usable __name__ (e.g. `int | None`)
    return str(annotation)


def get_methods_info(module):
    """Describe every Python function reachable as an attribute of ``module``.

    Returns a dict mapping function name to::

        {'input_params': [{'name': <param name>, 'type': <type name>}, ...],
         'output_type': <return type name>}

    Missing annotations are reported as ``'_empty'``. String annotations are
    returned verbatim instead of raising ``AttributeError`` as before. Note
    that functions imported into ``module`` are included as well.
    """
    methods = {}
    for method_name, method in inspect.getmembers(module, inspect.isfunction):
        sig = inspect.signature(method)
        input_params = [
            {'name': name, 'type': _annotation_name(param.annotation)}
            for name, param in sig.parameters.items()
        ]
        methods[method_name] = {
            'input_params': input_params,
            'output_type': _annotation_name(sig.return_annotation),
        }
    return methods

116 

117 

def check_module(module, mode):
    """Validate the user's module and return ``{'methods': ...}``.

    With ``mode == 'custom_function'`` the value is ``get_methods_info(module)``.
    In any other mode the module must define a model class (found by
    ``find_model_class``) that can be constructed without arguments, and the
    returned ``methods`` mapping is empty.

    Raises:
        RuntimeError: no class with ``train`` and ``predict`` methods exists.
        Anything raised by the model class constructor propagates unchanged.
    """
    if mode != 'custom_function':
        # BYOM model: locate the class and prove it can be instantiated
        model_class = find_model_class(module)
        if model_class is None:
            raise RuntimeError('Unable to find model class (it has to have `train` and `predict` methods)')
        model_class()
        return {'methods': {}}

    return {'methods': get_methods_info(module)}

134 

135 

def _restore_model(model_class, model_state):
    """Instantiate ``model_class`` and replace its attributes with the unpickled ``model_state``."""
    model = model_class()
    model.__dict__ = decode(model_state)
    return model


def _append_args(call_args, args):
    """Append the optional user ``args`` as one extra positional argument when truthy."""
    if args:
        call_args.append(args)
    return call_args


def main():
    """Process entry point: read one request from stdin, execute it, reply on stdout.

    The request is a pickled dict holding at least ``method`` (a
    ``BYOM_METHOD`` value) and ``code`` (source of the user module). Which
    other keys are read depends on the method. Every handled branch finishes
    in ``return_output``, which writes the reply and exits the process.

    Raises:
        RuntimeError: a model operation was requested but ``code`` defines no
            class with ``train`` and ``predict`` methods.
        NotImplementedError: ``method`` is not handled.
    """
    # stdout carries the pickled reply, so any print() output from user code
    # is redirected to stderr
    sys.stdout = sys.stderr

    params = get_input()

    method = BYOM_METHOD(params['method'])
    code = params['code']

    module = import_string(code)

    if method == BYOM_METHOD.FUNC_CALL:
        func_name = params['func_name']
        args = params['args']
        func = getattr(module, func_name)
        return return_output(func(*args))

    if method == BYOM_METHOD.CHECK:
        mode = params['mode']
        return return_output(check_module(module, mode))

    model_class = find_model_class(module)
    if model_class is None:
        # fail with a clear message instead of "'NoneType' object is not callable"
        raise RuntimeError('Unable to find model class (it has to have `train` and `predict` methods)')

    if method == BYOM_METHOD.TRAIN:
        df = params['df']
        if df is not None:
            df = pd_decode(df)
        to_predict = params['to_predict']
        args = params['args']
        model = model_class()
        model.train(*_append_args([df, to_predict], args))
        # the model state is pickled here and pickled again by return_output;
        # later calls restore it with decode() in _restore_model
        return return_output(encode(model.__dict__))

    if method == BYOM_METHOD.PREDICT:
        model_state = params['model_state']
        df = pd_decode(params['df'])
        args = params['args']
        model = _restore_model(model_class, model_state)
        res = model.predict(*_append_args([df], args))
        return return_output(pd_encode(res))

    if method == BYOM_METHOD.FINETUNE:
        model_state = params['model_state']
        df = pd_decode(params['df'])
        args = params['args']
        model = _restore_model(model_class, model_state)
        model.finetune(*_append_args([df], args))
        # same double-pickled state format as TRAIN
        return return_output(encode(model.__dict__))

    if method == BYOM_METHOD.DESCRIBE:
        model = _restore_model(model_class, params['model_state'])
        try:
            df = model.describe(params.get('attribute'))
        except Exception:
            # deliberate best-effort: a failing/missing describe() yields an empty frame
            return return_output(pd_encode(pd.DataFrame()))
        return return_output(pd_encode(df))

    raise NotImplementedError(method)

226 

227 

# run the wrapper only when executed as a script (the parent launches it as a separate process)
if __name__ == '__main__':
    main()