Coverage for mindsdb / integrations / handlers / vertex_handler / vertex_handler.py: 0%
51 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import pandas as pd
2from mindsdb.integrations.libs.base import BaseMLEngine
3from mindsdb.integrations.handlers.vertex_handler.vertex_client import VertexClient
4from mindsdb.utilities import log
logger = log.getLogger(__name__)

# Engine connection-arg keys carrying service-account credentials, in the
# (url, file, json) order VertexClient receives them. create_engine requires
# at least one of them to be present.
_CREDENTIAL_KEYS = (
    "service_account_key_url",
    "service_account_key_file",
    "service_account_key_json",
)


class VertexHandler(BaseMLEngine):
    """Handler for the Vertex Google AI cloud API."""

    name = "Vertex"

    def create(self, target, args=None, **kwargs):
        """Log in to Vertex and deploy a pre-trained model to an endpoint.

        If an endpoint named ``<model_name>_endpoint`` already exists, nothing is
        deployed. Otherwise the model is deployed to a new endpoint, which is then
        renamed to ``<model_name>_endpoint``. Deployment is slow: it took
        15 minutes for a small model.

        Args:
            target: Name of the prediction column. It is stored so ``predict``
                can label custom-model output.
            args: Model arguments. Must contain a ``using`` dict with
                ``model_name``; ``custom_model`` (default False) is optional.
                Any remaining ``using`` keys override the engine args passed
                to ``VertexClient``.
            **kwargs: Unused.

        Raises:
            ValueError: If ``args`` has no ``using`` section.
            KeyError: If ``using`` has no ``model_name``.
            Exception: If no Vertex model with that display name exists.
        """
        # Raise explicitly rather than assert: asserts vanish under `python -O`,
        # and `"using" in None` would raise an unhelpful TypeError.
        if not args or "using" not in args:
            raise ValueError("Must provide USING arguments for this handler")
        # Work on a copy so popping handler-only keys does not mutate the
        # caller's USING dict.
        using = dict(args["using"])

        if "model_name" not in using:
            raise KeyError("model_name must be provided in USING arguments")
        model_name = using.pop("model_name")
        custom_model = using.pop("custom_model", False)

        credentials_url, credentials_file, credentials_json = self._get_credentials_from_engine()

        # Engine-level args are the defaults; model-level USING args override them.
        vertex_args = dict(self.engine_storage.json_get('args'))
        vertex_args.update(using)

        vertex = VertexClient(vertex_args, credentials_url, credentials_file, credentials_json)

        model = vertex.get_model_by_display_name(model_name)
        if not model:
            raise Exception(f"Vertex model {model_name} not found")

        endpoint_name = f"{model_name}_endpoint"
        if vertex.get_endpoint_by_display_name(endpoint_name):
            logger.info("Endpoint %s already exists, skipping deployment", endpoint_name)
        else:
            logger.info("Starting deployment at %s", endpoint_name)
            endpoint = vertex.deploy_model(model)
            # NOTE(review): assumes deploy_model returns a
            # google.cloud.aiplatform.Endpoint, whose `display_name` property
            # has no setter; the rename is passed to update() instead. Confirm
            # against VertexClient.deploy_model.
            endpoint.update(display_name=endpoint_name)
            logger.info("Endpoint %s deployed", endpoint_name)

        self.model_storage.json_set("predict_args", {
            "target": target,
            "endpoint_name": endpoint_name,
            "custom_model": custom_model,
        })
        # NOTE(review): vertex_args is built from the engine args, which hold
        # the service_account_key_* credentials, so those are persisted in
        # model storage as well. Confirm this is intended.
        self.model_storage.json_set("vertex_args", vertex_args)

    def predict(self, df, args=None):
        """Predict by calling the Vertex endpoint set up in ``create``.

        Args:
            df: Input rows. If ``__mindsdb_row_id`` is present, that column is
                excluded from the request. ``df`` itself is not modified.
            args: Unused.

        Returns:
            A DataFrame built from ``results.predictions``. For custom models,
            it has a single column named after the stored target.
        """
        if "__mindsdb_row_id" in df.columns:
            # Drop on a new frame rather than in place so the caller's df is
            # left untouched.
            # TODO: confirm whether the endpoint actually needs this column removed.
            df = df.drop(columns="__mindsdb_row_id")

        predict_args = self.model_storage.json_get("predict_args")
        vertex_args = self.model_storage.json_get("vertex_args")
        custom_model = predict_args["custom_model"]

        credentials_url, credentials_file, credentials_json = self._get_credentials_from_engine()
        vertex = VertexClient(vertex_args, credentials_url, credentials_file, credentials_json)
        results = vertex.predict_from_df(predict_args["endpoint_name"], df, custom_model=custom_model)

        if custom_model:
            return pd.DataFrame(results.predictions, columns=[predict_args["target"]])
        return pd.DataFrame(results.predictions)

    def create_engine(self, connection_args):
        """Validate and store the engine-level connection args.

        Raises:
            KeyError: If none of service_account_key_url,
                service_account_key_file or service_account_key_json is given.
        """
        if not any(key in connection_args for key in _CREDENTIAL_KEYS):
            raise KeyError('Either service_account_key_url, service_account_key_file, or service_account_key_json must be provided')

        self.engine_storage.json_set('args', connection_args)

    def _get_credentials_from_engine(self):
        """Return the (url, file, json) service-account credentials from the engine args.

        Any credential that is not set is returned as None.
        """
        engine_args = self.engine_storage.json_get('args')
        return tuple(engine_args.get(key) for key in _CREDENTIAL_KEYS)