Coverage for mindsdb / integrations / handlers / vertex_handler / vertex_client.py: 0%
52 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from mindsdb.utilities import log
2from google.cloud.aiplatform import init, TabularDataset, Model, Endpoint
3import pandas as pd
5from mindsdb.integrations.utilities.handlers.auth_utilities.google import GoogleServiceAccountOAuth2Manager
# Module-level logger for this client, obtained from mindsdb's logging utilities.
logger = log.getLogger(__name__)
class VertexClient:
    """A class to interact with Vertex AI.

    Constructing an instance authenticates with a Google service account and
    calls ``google.cloud.aiplatform.init``; the remaining methods list, look up,
    deploy and query Vertex AI datasets, models and endpoints.
    """

    def __init__(self, args_json, credentials_url=None, credentials_file=None, credentials_json=None):
        """Authenticate and initialise the Vertex AI SDK.

        Args:
            args_json: mapping of connection settings. ``project_id`` and
                ``location`` are required. ``staging_bucket``, ``experiment`` and
                ``experiment_description`` are optional; when absent, ``None`` is
                passed, which is the SDK's own default for each of them.
            credentials_url: optional URL to fetch service-account credentials from.
            credentials_file: optional path to a service-account credentials file.
            credentials_json: optional service-account credentials as JSON.

        Raises:
            KeyError: if ``project_id`` or ``location`` is missing from ``args_json``.
        """
        google_sa_oauth2_manager = GoogleServiceAccountOAuth2Manager(
            credentials_url=credentials_url,
            credentials_file=credentials_file,
            credentials_json=credentials_json,
        )
        credentials = google_sa_oauth2_manager.get_oauth2_credentials()

        init(
            credentials=credentials,
            project=args_json["project_id"],
            location=args_json["location"],
            staging_bucket=args_json.get("staging_bucket"),
            # the name of the experiment to use to track
            # logged metrics and parameters
            experiment=args_json.get("experiment"),
            # description of the experiment above
            experiment_description=args_json.get("experiment_description"),
        )

    def print_datasets(self):
        """Log the display name and ID of every tabular dataset in the project."""
        for dataset in TabularDataset.list():
            logger.info("Dataset display name: %s, ID: %s", dataset.display_name, dataset.name)

    def print_models(self):
        """Log the display name and ID of every model in the project."""
        for model in Model.list():
            logger.info("Model display name: %s, ID: %s", model.display_name, model.name)

    def print_endpoints(self):
        """Log the display name and ID of every endpoint in the project."""
        for endpoint in Endpoint.list():
            logger.info("Endpoint display name: %s, ID: %s", endpoint.display_name, endpoint.name)

    @staticmethod
    def _display_name_filter(display_name):
        """Build a list filter that matches resources whose display name equals ``display_name``.

        The value is wrapped in double quotes, with backslashes and quotes escaped,
        so that names containing spaces, hyphens or other punctuation are treated
        as a single string literal instead of being parsed as filter syntax.
        """
        escaped = str(display_name).replace("\\", "\\\\").replace('"', '\\"')
        return f'display_name="{escaped}"'

    def get_model_by_display_name(self, display_name):
        """Get a model by its display name.

        Returns:
            The first matching model, or ``None`` (after logging) if none matches.
        """
        models = Model.list(filter=self._display_name_filter(display_name))
        if not models:
            logger.info("Model with display name %s not found", display_name)
            return None
        return models[0]

    def get_endpoint_by_display_name(self, display_name):
        """Get an endpoint by its display name.

        Returns:
            The first matching endpoint, or ``None`` (after logging) if none matches.
        """
        endpoints = Endpoint.list(filter=self._display_name_filter(display_name))
        if not endpoints:
            logger.info("Endpoint with display name %s not found", display_name)
            return None
        return endpoints[0]

    def get_model_by_id(self, model_id):
        """Get a model by its ID.

        Returns:
            The model, or ``None`` (after logging) if the lookup raises ``IndexError``.
        """
        try:
            return Model(model_name=model_id)
        except IndexError:
            # NOTE(review): a missing model is probably reported by the SDK as
            # google.api_core.exceptions.NotFound rather than IndexError, in which
            # case this handler never runs — confirm and catch the right exception.
            logger.info("Model with ID %s not found", model_id)
            return None

    def deploy_model(self, model):
        """Deploy ``model`` via ``model.deploy()`` and return the resulting endpoint.

        This is a long-running call.
        """
        return model.deploy()

    def _get_endpoint_or_raise(self, endpoint_display_name):
        """Look up an endpoint by display name, raising ``ValueError`` if it does not exist."""
        endpoint = self.get_endpoint_by_display_name(endpoint_display_name)
        if endpoint is None:
            # Fail with a clear message instead of an AttributeError on ``None.predict``.
            raise ValueError(f"Endpoint with display name {endpoint_display_name} not found")
        return endpoint

    def predict_from_df(self, endpoint_display_name, df, custom_model=False):
        """Make a prediction from a Pandas dataframe.

        Args:
            endpoint_display_name: display name of the endpoint to query.
            df: one instance per row.
            custom_model: if True, each row is sent as a plain list of values;
                otherwise each row is sent as a ``{column: str(value)}`` dict.

        Returns:
            Whatever ``endpoint.predict`` returns.

        Raises:
            ValueError: if no endpoint with that display name exists.
        """
        endpoint = self._get_endpoint_or_raise(endpoint_display_name)
        if custom_model:
            records = df.values.tolist()
        else:
            # Values are stringified, presumably because the (AutoML) endpoint
            # expects string-typed features — confirm against the endpoint schema.
            records = df.astype(str).to_dict(orient="records")  # list of dictionaries
        return endpoint.predict(instances=records)

    def predict_from_csv(self, endpoint_display_name, csv_to_predict, custom_model=False):
        """Make a prediction from a CSV file (anything ``pd.read_csv`` accepts).

        ``custom_model`` is forwarded to ``predict_from_df``; its default keeps the
        previous behaviour.

        Raises:
            ValueError: if no endpoint with that display name exists.
        """
        df = pd.read_csv(csv_to_predict)
        return self.predict_from_df(endpoint_display_name, df, custom_model=custom_model)

    def predict_from_dict(self, endpoint_display_name, data):
        """Make a prediction from column-oriented data.

        ``data`` maps column name -> sequence of values. It is transposed into one
        ``{column: value}`` dict per row; as with ``zip``, rows beyond the shortest
        column are dropped.

        Raises:
            ValueError: if no endpoint with that display name exists.
        """
        # convert to list of dictionaries
        instances = [dict(zip(data.keys(), values)) for values in zip(*data.values())]
        endpoint = self._get_endpoint_or_raise(endpoint_display_name)
        return endpoint.predict(instances=instances)