Coverage for mindsdb / integrations / handlers / vertex_handler / vertex_client.py: 0%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from mindsdb.utilities import log 

2from google.cloud.aiplatform import init, TabularDataset, Model, Endpoint 

3import pandas as pd 

4 

5from mindsdb.integrations.utilities.handlers.auth_utilities.google import GoogleServiceAccountOAuth2Manager 

6 

# Module-level logger obtained through MindsDB's logging utilities.
logger = log.getLogger(__name__)

8 

9 

class VertexClient:
    """A thin wrapper around the Vertex AI SDK used by the Vertex handler.

    Construction authenticates with a Google service account and calls the
    SDK's ``init``. The other methods call ``Model.list``, ``Endpoint.list``,
    etc. without passing a project or location, so they rely on that
    configuration.
    """

    def __init__(self, args_json, credentials_url=None, credentials_file=None, credentials_json=None):
        """Authenticate and initialise the Vertex AI SDK.

        Args:
            args_json: mapping that must contain the keys ``project_id``,
                ``location``, ``staging_bucket``, ``experiment`` and
                ``experiment_description``.
            credentials_url: optional URL of the service-account credentials.
            credentials_file: optional path of the service-account credentials.
            credentials_json: optional inline service-account credentials.
                The three ``credentials_*`` arguments are forwarded unchanged
                to ``GoogleServiceAccountOAuth2Manager``.

        Raises:
            KeyError: if one of the required keys is missing from ``args_json``.
        """
        google_sa_oauth2_manager = GoogleServiceAccountOAuth2Manager(
            credentials_url=credentials_url,
            credentials_file=credentials_file,
            credentials_json=credentials_json,
        )
        credentials = google_sa_oauth2_manager.get_oauth2_credentials()

        # Nothing returned by init() is kept on ``self``; the SDK calls in the
        # other methods depend on the configuration it sets up.
        init(
            credentials=credentials,
            project=args_json["project_id"],
            location=args_json["location"],
            staging_bucket=args_json["staging_bucket"],
            # the name of the experiment to use to track
            # logged metrics and parameters
            experiment=args_json["experiment"],
            # description of the experiment above
            experiment_description=args_json["experiment_description"],
        )

    def print_datasets(self):
        """Log the display name and ID of every tabular dataset in the project."""
        for dataset in TabularDataset.list():
            logger.info("Dataset display name: %s, ID: %s", dataset.display_name, dataset.name)

    def print_models(self):
        """Log the display name and ID of every model in the project."""
        for model in Model.list():
            logger.info("Model display name: %s, ID: %s", model.display_name, model.name)

    def print_endpoints(self):
        """Log the display name and ID of every endpoint in the project."""
        for endpoint in Endpoint.list():
            logger.info("Endpoint display name: %s, ID: %s", endpoint.display_name, endpoint.name)

    @staticmethod
    def _display_name_filter(display_name):
        """Build a list filter that matches ``display_name`` exactly.

        The value is wrapped in double quotes, with backslashes and double
        quotes escaped. This makes display names containing spaces or other
        special characters parse as one string literal instead of as filter
        syntax.
        """
        escaped = str(display_name).replace("\\", "\\\\").replace('"', '\\"')
        return f'display_name="{escaped}"'

    @staticmethod
    def _require_endpoint(endpoint, endpoint_display_name):
        """Return ``endpoint``, or raise ValueError if the lookup produced None."""
        if endpoint is None:
            raise ValueError(f"Endpoint with display name {endpoint_display_name} not found")
        return endpoint

    def get_model_by_display_name(self, display_name):
        """Return the first model whose display name equals ``display_name``.

        Returns:
            The first matching model, or None (after logging a warning) if
            no model matches.
        """
        models = Model.list(filter=self._display_name_filter(display_name))
        if not models:
            logger.warning("Model with display name %s not found", display_name)
            return None
        return models[0]

    def get_endpoint_by_display_name(self, display_name):
        """Return the first endpoint whose display name equals ``display_name``.

        Returns:
            The first matching endpoint, or None (after logging a warning) if
            no endpoint matches.
        """
        endpoints = Endpoint.list(filter=self._display_name_filter(display_name))
        if not endpoints:
            logger.warning("Endpoint with display name %s not found", display_name)
            return None
        return endpoints[0]

    def get_model_by_id(self, model_id):
        """Return the model identified by ``model_id``.

        Any error the SDK raises while resolving ``model_id`` (for example
        because no such model exists) is propagated to the caller unchanged.
        """
        return Model(model_name=model_id)

    def deploy_model(self, model):
        """Deploy ``model`` and return the resulting endpoint (long-running)."""
        return model.deploy()

    def predict_from_df(self, endpoint_display_name, df, custom_model=False):
        """Send the rows of a DataFrame to an endpoint and return the prediction.

        Args:
            endpoint_display_name: display name of the target endpoint.
            df: pandas DataFrame, one row per prediction instance.
            custom_model: if True, each row is sent as a plain list of values
                (``df.values.tolist()``). Otherwise each row is sent as a
                ``{column: value}`` dict with every value converted to ``str``.

        Returns:
            Whatever ``endpoint.predict`` returns.

        Raises:
            ValueError: if no endpoint with that display name exists.
        """
        endpoint = self._require_endpoint(
            self.get_endpoint_by_display_name(endpoint_display_name),
            endpoint_display_name,
        )
        if custom_model:
            records = df.values.tolist()
        else:
            records = df.astype(str).to_dict(orient="records")  # list of dictionaries
        return endpoint.predict(instances=records)

    def predict_from_csv(self, endpoint_display_name, csv_to_predict, custom_model=False):
        """Read a CSV with pandas and predict on its rows.

        ``csv_to_predict`` is anything accepted by ``pandas.read_csv``.
        ``custom_model`` has the same meaning as in ``predict_from_df``.
        """
        df = pd.read_csv(csv_to_predict)
        return self.predict_from_df(endpoint_display_name, df, custom_model=custom_model)

    def predict_from_dict(self, endpoint_display_name, data):
        """Predict from column-oriented data ``{column: [v0, v1, ...], ...}``.

        The columns are transposed into one ``{column: value}`` dict per row.
        Values are sent as-is, without string conversion.

        Raises:
            ValueError: if the columns have different lengths (a plain ``zip``
                would silently drop the trailing rows), or if no endpoint with
                that display name exists.
        """
        columns = {key: list(values) for key, values in data.items()}
        lengths = {len(values) for values in columns.values()}
        if len(lengths) > 1:
            raise ValueError(f"All columns must have the same length, got lengths {sorted(lengths)}")

        # convert to list of dictionaries
        instances = [dict(zip(columns.keys(), row)) for row in zip(*columns.values())]
        endpoint = self._require_endpoint(
            self.get_endpoint_by_display_name(endpoint_display_name),
            endpoint_display_name,
        )
        return endpoint.predict(instances=instances)