Coverage for mindsdb / integrations / handlers / vertex_handler / vertex_handler.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import pandas as pd 

2from mindsdb.integrations.libs.base import BaseMLEngine 

3from mindsdb.integrations.handlers.vertex_handler.vertex_client import VertexClient 

4from mindsdb.utilities import log 

5 

6logger = log.getLogger(__name__) 

7 

8 

class VertexHandler(BaseMLEngine):
    """Handler for the Vertex Google AI cloud API"""

    name = "Vertex"

    # Engine connection arguments that may carry the service account key.
    # `create_engine` requires at least one of them to be present.
    _CREDENTIAL_KEYS = (
        "service_account_key_url",
        "service_account_key_file",
        "service_account_key_json",
    )

    def create(self, target, args=None, **kwargs):
        """Logs in to Vertex and deploys a pre-trained model to an endpoint.

        If the endpoint already exists for the model, we do nothing.

        If the endpoint does not exist, we create it and deploy the model to it.
        The runtime for this is long, it took 15 minutes for a small model.

        Args:
            target: Stored in the model's ``predict_args`` and later used by
                ``predict`` as the output column name for custom models.
            args: Must contain a ``"using"`` mapping. ``model_name`` is
                required in it; ``custom_model`` is optional (default
                ``False``). Every remaining entry is merged over the engine
                arguments and passed to ``VertexClient``.
            **kwargs: Accepted but not used by this method.

        Raises:
            ValueError: If ``args`` is missing or has no ``"using"`` entry.
            KeyError: If ``model_name`` is not given in the USING arguments.
            Exception: If no Vertex model with that display name is found.
        """
        # Explicit check instead of `assert`: asserts are stripped under `-O`,
        # and `"using" in None` would raise an unrelated TypeError.
        if not args or "using" not in args:
            raise ValueError("Must provide USING arguments for this handler")

        # Work on a copy so the pops below do not modify the caller's dict.
        using_args = dict(args["using"])

        if "model_name" not in using_args:
            raise KeyError("model_name must be provided in the USING arguments")
        model_name = using_args.pop("model_name")
        custom_model = using_args.pop("custom_model", False)

        # get credentials from engine
        credentials_url, credentials_file, credentials_json = self._get_credentials_from_engine()

        # get vertex args from handler then update args from model; merge into
        # a fresh dict so the object returned by the storage is not modified
        vertex_args = dict(self.engine_storage.json_get("args"))
        vertex_args.update(using_args)

        vertex = VertexClient(vertex_args, credentials_url, credentials_file, credentials_json)

        model = vertex.get_model_by_display_name(model_name)
        if not model:
            raise Exception(f"Vertex model {model_name} not found")

        endpoint_name = model_name + "_endpoint"
        if vertex.get_endpoint_by_display_name(endpoint_name):
            logger.info(f"Endpoint {endpoint_name} already exists, skipping deployment")
        else:
            logger.info(f"Starting deployment at {endpoint_name}")
            endpoint = vertex.deploy_model(model)
            # Rename the new endpoint so the lookup above finds it next time.
            endpoint.display_name = endpoint_name
            endpoint.update()
            logger.info(f"Endpoint {endpoint_name} deployed")

        predict_args = {
            "target": target,
            "endpoint_name": endpoint_name,
            "custom_model": custom_model,
        }
        self.model_storage.json_set("predict_args", predict_args)
        self.model_storage.json_set("vertex_args", vertex_args)

    def predict(self, df: pd.DataFrame, args=None) -> pd.DataFrame:
        """Predict using the deployed model by calling the endpoint.

        Args:
            df: Input rows. A ``__mindsdb_row_id`` column, if present, is
                removed before the rows are sent; the caller's frame is not
                modified.
            args: Accepted but not used by this method.

        Returns:
            A DataFrame built from the ``predictions`` attribute of the
            client's result. For custom models it has a single column named
            after the stored target.
        """
        if "__mindsdb_row_id" in df.columns:
            # Drop on a new frame rather than in place, so the DataFrame
            # owned by the caller keeps its columns.
            df = df.drop(columns="__mindsdb_row_id")  # TODO is this required?

        predict_args = self.model_storage.json_get("predict_args")
        vertex_args = self.model_storage.json_get("vertex_args")

        # get credentials from engine
        credentials_url, credentials_file, credentials_json = self._get_credentials_from_engine()

        vertex = VertexClient(vertex_args, credentials_url, credentials_file, credentials_json)
        results = vertex.predict_from_df(
            predict_args["endpoint_name"],
            df,
            custom_model=predict_args["custom_model"],
        )

        if predict_args["custom_model"]:
            return pd.DataFrame(results.predictions, columns=[predict_args["target"]])
        return pd.DataFrame(results.predictions)

    def create_engine(self, connection_args):
        """Validate and store the engine's connection arguments.

        Args:
            connection_args: Mapping of connection arguments; stored as-is
                under the ``args`` key of the engine storage.

        Raises:
            KeyError: If none of the service account key arguments is given.
        """
        # check if one of credentials_url, credentials_file, or credentials_json is provided
        if not any(key in connection_args for key in self._CREDENTIAL_KEYS):
            raise KeyError('Either service_account_key_url, service_account_key_file, or service_account_key_json must be provided')

        self.engine_storage.json_set("args", connection_args)

    def _get_credentials_from_engine(self):
        """Return the stored ``(key_url, key_file, key_json)`` credentials.

        Each element is ``None`` when the engine arguments lack that key.
        """
        engine_args = self.engine_storage.json_get("args")

        return tuple(engine_args.get(key) for key in self._CREDENTIAL_KEYS)