Coverage for mindsdb / integrations / handlers / sentence_transformers_handler / sentence_transformers_handler.py: 0%

40 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Optional 

2 

3import pandas as pd 

4 

5from mindsdb.integrations.handlers.sentence_transformers_handler.settings import Parameters 

6 

7from mindsdb.integrations.handlers.rag_handler.settings import load_embeddings_model, df_to_documents 

8 

9 

10from mindsdb.integrations.libs.base import BaseMLEngine 

11from mindsdb.utilities import log 

12 

13logger = log.getLogger(__name__) 

14 

15 

class SentenceTransformersHandler(BaseMLEngine):
    """ML engine that produces embeddings for text columns of a DataFrame.

    ``create`` validates the user-supplied ``using`` arguments with
    ``Parameters`` and persists them in model storage. ``predict`` reads them
    back, converts the selected text columns to documents, and embeds each
    document with the configured embeddings model.
    """

    name = "sentence transformers"

    # Column layout of the frame returned by ``predict``.
    _OUTPUT_COLUMNS = ("content", "embeddings", "metadata")

    def __init__(self, model_storage, engine_storage, **kwargs) -> None:
        """Initialise the base engine and set the ``generative`` flag."""
        super().__init__(model_storage, engine_storage, **kwargs)
        # NOTE(review): the meaning of this flag is defined by BaseMLEngine /
        # its callers, which are not visible here — confirm before changing.
        self.generative = True

    def create(self, target, df=None, args=None, **kwargs):
        """Validate the ``using`` arguments and persist them.

        Args:
            target: Not used by this handler.
            df: Not used by this handler.
            args: Mapping that must contain a ``"using"`` entry holding the
                keyword arguments accepted by ``Parameters``.
            **kwargs: Ignored.

        Raises:
            TypeError: If ``args`` is ``None`` (it is subscripted directly).
            KeyError: If ``args`` has no ``"using"`` entry.
            Exception: Whatever ``Parameters(**args)`` raises for invalid
                values.
        """
        args = args["using"]

        valid_args = Parameters(**args)
        self.model_storage.json_set("args", valid_args.model_dump())

    def predict(self, df, args=None):
        """Embed the text columns of ``df`` using the persisted settings.

        Args:
            df: Input rows. If it has exactly one row whose ``content`` cell
                is a list, that list is expanded to one row per item.
            args: Ignored; the arguments persisted by ``create`` are used.

        Returns:
            A DataFrame with columns ``content``, ``embeddings`` and
            ``metadata``, one row per document. It has no rows when there is
            nothing to embed.

        Raises:
            ValueError: If the persisted ``text_columns`` is neither a string,
                a list, nor ``None``.
        """
        args = self.model_storage.json_get("args")

        df = self._expand_list_content(df)
        columns = self._resolve_text_columns(args["text_columns"], df)

        if df.empty:
            # Nothing to embed: skip loading the model altogether.
            return pd.DataFrame(columns=list(self._OUTPUT_COLUMNS))

        documents = df_to_documents(df=df, page_content_columns=columns)

        content = [doc.page_content for doc in documents]
        metadata = [doc.metadata for doc in documents]

        model = load_embeddings_model(args["embeddings_model_name"])
        embeddings = model.embed_documents(texts=content)

        return pd.DataFrame(
            data={"content": content, "embeddings": embeddings, "metadata": metadata}
        )

    @staticmethod
    def _expand_list_content(df):
        """Expand a single-row list-valued ``content`` cell into rows.

        Allows the user to pass a list of strings in the where clause, i.e.
        ``where content = ['hello', 'world']`` or
        ``where content = (select content from some_db.some_table)``.

        Frames without a ``content`` column, or with any number of rows other
        than one, are returned unchanged instead of raising on the lookup.
        """
        if "content" not in df.columns or len(df) != 1:
            return df

        value = df["content"].iloc[0]
        if not isinstance(value, list):
            return df

        return pd.DataFrame(data={"content": value})

    @staticmethod
    def _resolve_text_columns(text_columns, df):
        """Return the list of column names to read text from.

        A string names a single column, a list is used as given, and ``None``
        selects every column of ``df``.

        Raises:
            ValueError: For any other type of ``text_columns``.
        """
        if isinstance(text_columns, str):
            return [text_columns]

        if isinstance(text_columns, list):
            return text_columns

        if text_columns is None:
            # assume all columns are text columns
            logger.info("No text columns specified, assuming all columns are text columns")
            return df.columns.tolist()

        raise ValueError(f"Invalid value for text_columns: {text_columns}")

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """Describe the persisted model arguments.

        Args:
            attribute: When equal to ``"args"``, the persisted arguments are
                returned as a two-column (``key``, ``value``) DataFrame.

        Returns:
            The key/value DataFrame for ``attribute == "args"``.

        NOTE(review): for any other ``attribute`` (including the default
        ``None``) the method falls through and returns ``None``, which does
        not match the annotated return type — confirm what callers expect.
        """
        args = self.model_storage.json_get("args")

        if attribute == "args":
            return pd.DataFrame(args.items(), columns=["key", "value"])