Coverage for mindsdb / integrations / handlers / byom_handler / proc_wrapper.py: 16%
128 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
"""
Utility functions used in the 'Bring Your Own Model' (BYOM) engine.

These functions interact with interfaces (stdin, stdout), python files, and the actual BYOM engine.

In particular, they:
 - Wrap and run python code in a separate python process.
 - Communicate with the parent process through stdin/stdout, using pickle to serialize objects.

The flow is as follows:

 1. Receive module code, the method with its parameters, and stored attributes from the parent process
 2. A python class object is created from the code
 3. The class is instantiated and filled with the stored attributes
 4. A call to the chosen method of the class is performed with any relevant parameters that were passed
 5. The response is generated, appropriately packaged and sent to stdout
 6. Exit
"""
21import io
22import sys
23import pickle
24import inspect
25from enum import Enum
27import pandas as pd
class BYOM_METHOD(Enum):
    """Operations the parent process can request from this wrapper.

    ``main`` converts the request's ``'method'`` entry with ``BYOM_METHOD(...)``,
    so these integer values are part of the parent/child protocol and must not change.
    """

    CHECK = 1  # validate the uploaded code (model class or custom functions)
    TRAIN = 2  # build and train a new model instance
    PREDICT = 3  # restore a trained model and call predict()
    FINETUNE = 4  # restore a trained model and call finetune()
    DESCRIBE = 5  # restore a trained model and call describe()
    FUNC_CALL = 6  # call a module-level function by name
def pd_encode(df):
    """Serialize the DataFrame *df* into in-memory parquet bytes (pyarrow engine)."""
    # With no path argument, DataFrame.to_parquet returns the file contents as bytes.
    parquet_bytes = df.to_parquet(engine='pyarrow')
    return parquet_bytes
def pd_decode(encoded):
    """Rebuild a DataFrame from parquet bytes produced by ``pd_encode``."""
    # BytesIO(initial_bytes) holds a copy of the data with the read position at 0.
    parquet_buffer = io.BytesIO(encoded)
    return pd.read_parquet(parquet_buffer, engine='pyarrow')
def encode(obj):
    """Pickle *obj* into bytes using pickle protocol 5."""
    serialized = pickle.dumps(obj, protocol=5)
    return serialized
def decode(encoded):
    """Unpickle *encoded* bytes back into the original Python object.

    NOTE(review): pickle.loads can execute arbitrary code. This is presumably
    acceptable only because the bytes come from the parent process (or from model
    state this wrapper produced earlier); never feed it untrusted data.
    """
    restored = pickle.loads(encoded)
    return restored
def return_output(obj):
    """Send *obj* back to the parent process and terminate with exit status 0.

    The object is pickled with ``encode`` and written directly to file descriptor 1
    (the real stdout). ``main`` rebinds ``sys.stdout`` to stderr, so ordinary prints
    cannot end up mixed into this payload. The ``with`` block closes descriptor 1
    after writing. This function never returns normally: it calls ``sys.exit(0)``.
    """
    payload = encode(obj)
    with open(1, 'wb') as raw_stdout:
        raw_stdout.write(payload)
    sys.exit(0)
def get_input():
    """Read the whole pickled request from stdin (file descriptor 0) and return it decoded.

    The descriptor is read to EOF and then closed by the ``with`` block.
    """
    with open(0, 'rb') as raw_stdin:
        request_bytes = raw_stdin.read()
        request = decode(request_bytes)
    return request
def import_string(code, module_name='model'):
    """Create a fresh module object and execute *code* inside its namespace.

    Args:
        code: python source text received from the parent process.
        module_name: value used as the new module's ``__name__``.

    Returns:
        The populated ``types.ModuleType`` instance.

    The module is deliberately NOT registered in ``sys.modules``. Classes defined
    by *code* therefore carry a ``__module__`` that ``inspect.getmodule`` cannot
    resolve (unless a module with the same name is already registered).
    ``find_model_class`` relies on this to tell them apart from imported classes.
    """
    import types

    # exec of caller-supplied source is the whole point of BYOM; do not use on untrusted input.
    fresh_module = types.ModuleType(module_name)
    exec(code, vars(fresh_module))
    return fresh_module
def find_model_class(module):
    """Return the first class defined by *module*'s own code that has ``train`` and ``predict``.

    Classes are scanned in ``inspect.getmembers`` order (sorted by name). A class
    counts as "imported", and is skipped, when ``inspect.getmodule`` can resolve it
    to a loaded module. Methods are looked up with ``inspect.getmembers`` on the
    class, so inherited ``train``/``predict`` functions also count.

    Returns:
        The matching class, or ``None`` when there is none.
    """
    required = {'predict', 'train'}
    for _member_name, candidate in inspect.getmembers(module, inspect.isclass):
        if inspect.getmodule(candidate) is not None:
            # resolvable module -> class came from an import, not from the user code
            continue
        function_names = {name for name, _fn in inspect.getmembers(candidate, inspect.isfunction)}
        if required <= function_names:
            return candidate
    return None
def _annotation_name(annotation):
    """Return a printable name for a parameter or return annotation."""
    try:
        # classes such as int/str, and inspect's `_empty` marker for missing annotations
        return annotation.__name__
    except AttributeError:
        # string annotations (quoted, or postponed via `from __future__ import annotations`)
        # and typing constructs without __name__ previously crashed here
        return str(annotation)


def get_methods_info(module):
    """Describe every function found in *module*: parameter names/types and return type.

    Args:
        module: module object whose function members are inspected
            (``inspect.getmembers(module, inspect.isfunction)``).

    Returns:
        Dict mapping function name to
        ``{'input_params': [{'name': ..., 'type': ...}, ...], 'output_type': ...}``.
        Types are annotation names. Unannotated parameters/returns are reported as
        ``'_empty'`` (the ``__name__`` of inspect's empty marker), as before.
    """
    methods = {}
    for method_name, method in inspect.getmembers(module, inspect.isfunction):
        sig = inspect.signature(method)
        input_params = [
            {'name': name, 'type': _annotation_name(param.annotation)}
            for name, param in sig.parameters.items()
        ]
        methods[method_name] = {
            'input_params': input_params,
            'output_type': _annotation_name(sig.return_annotation),
        }
    return methods
def check_module(module, mode):
    """Validate uploaded code and report what it exposes.

    Args:
        module: module built from the uploaded code.
        mode: ``'custom_function'`` to list the module's functions; any other value
            treats the code as a BYOM model.

    Returns:
        ``{'methods': ...}``: function info from ``get_methods_info`` in
        custom-function mode, otherwise an empty dict.

    Raises:
        RuntimeError: in model mode, when no class with ``train`` and ``predict`` exists.
        Any exception raised by the model class constructor propagates unchanged.
    """
    if mode == 'custom_function':
        return {'methods': get_methods_info(module)}

    # BYOM mode: the code must provide a model class that can be constructed.
    model_class = find_model_class(module)
    if model_class is None:
        raise RuntimeError('Unable to find model class (it has to have `train` and `predict` methods)')
    model_class()  # smoke-test the constructor; the instance itself is discarded
    return {'methods': {}}
def _restore_model(model_class, model_state):
    """Instantiate *model_class*, then replace its attributes with the unpickled *model_state*."""
    model = model_class()
    # the constructor runs first; the saved attribute dict then overrides whatever it set
    model.__dict__ = decode(model_state)
    return model


def main():
    """Handle one request from the parent process: read it from stdin, run it, reply on stdout.

    The request is the dict returned by ``get_input()``. Keys read here:
    ``'method'`` (a ``BYOM_METHOD`` value) and ``'code'`` (python source), plus
    per-method keys: ``'func_name'``/``'args'`` (FUNC_CALL), ``'mode'`` (CHECK),
    ``'df'``/``'to_predict'``/``'args'`` (TRAIN), ``'model_state'``/``'df'``/``'args'``
    (PREDICT, FINETUNE), and ``'model_state'``/``'attribute'`` (DESCRIBE).
    Every successful branch ends in ``return_output()``, which exits the process.

    Raises:
        RuntimeError: for TRAIN/PREDICT/FINETUNE/DESCRIBE when the code defines no
            class with both ``train`` and ``predict`` methods.
        NotImplementedError: for a method value without a handler.
    """
    import traceback

    # anything printed by the user code goes to stderr, keeping fd 1 free for the pickled reply
    sys.stdout = sys.stderr

    params = get_input()

    method = BYOM_METHOD(params['method'])
    code = params['code']

    module = import_string(code)

    if method == BYOM_METHOD.FUNC_CALL:
        func_name = params['func_name']
        args = params['args']

        func = getattr(module, func_name)
        return return_output(func(*args))

    if method == BYOM_METHOD.CHECK:
        mode = params['mode']
        info = check_module(module, mode)

        return return_output(info)

    model_class = find_model_class(module)
    if model_class is None:
        # previously this surfaced later as "'NoneType' object is not callable"
        raise RuntimeError('Unable to find model class (it has to have `train` and `predict` methods)')

    if method == BYOM_METHOD.TRAIN:
        df = params['df']
        if df is not None:
            df = pd_decode(df)
        to_predict = params['to_predict']
        args = params['args']
        model = model_class()

        call_args = [df, to_predict]
        if args:
            call_args.append(args)
        model.train(*call_args)

        # the model state is the pickled attribute dict; return_output pickles it once more
        return_output(encode(model.__dict__))

    elif method == BYOM_METHOD.PREDICT:
        model_state = params['model_state']
        df = pd_decode(params['df'])
        args = params['args']

        model = _restore_model(model_class, model_state)

        call_args = [df]
        if args:
            call_args.append(args)
        res = model.predict(*call_args)
        return_output(pd_encode(res))

    elif method == BYOM_METHOD.FINETUNE:
        model_state = params['model_state']
        df = pd_decode(params['df'])
        args = params['args']

        model = _restore_model(model_class, model_state)

        call_args = [df]
        if args:
            call_args.append(args)
        model.finetune(*call_args)

        # same state format as TRAIN
        return_output(encode(model.__dict__))

    elif method == BYOM_METHOD.DESCRIBE:
        model = _restore_model(model_class, params['model_state'])
        try:
            description = model.describe(params.get('attribute'))
        except Exception:
            # best-effort (e.g. the model has no describe()): keep the empty-frame
            # fallback, but report the cause on stderr instead of swallowing it silently
            traceback.print_exc()
            description = pd.DataFrame()
        return_output(pd_encode(description))

    raise NotImplementedError(method)
if __name__ == '__main__':
    # run as a script by the parent process: handle exactly one request, then exit
    main()