Coverage for mindsdb/integrations/handlers/sentence_transformers_handler/sentence_transformers_handler.py: 0%
40 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Optional
3import pandas as pd
5from mindsdb.integrations.handlers.sentence_transformers_handler.settings import Parameters
7from mindsdb.integrations.handlers.rag_handler.settings import load_embeddings_model, df_to_documents
10from mindsdb.integrations.libs.base import BaseMLEngine
11from mindsdb.utilities import log
# Module-level logger, named after this module.
logger = log.getLogger(__name__)
class SentenceTransformersHandler(BaseMLEngine):
    """ML engine that produces embeddings for text columns.

    ``create`` validates the user-supplied arguments and stores them in model
    storage; ``predict`` reads them back, loads the embeddings model they name
    and embeds the input rows; ``describe`` exposes the stored arguments.
    """

    name = "sentence transformers"

    def __init__(self, model_storage, engine_storage, **kwargs) -> None:
        """Initialise the base engine and flag this engine as generative."""
        super().__init__(model_storage, engine_storage, **kwargs)
        self.generative = True

    def create(self, target, df=None, args=None, **kwargs):
        """Validate the ``using`` arguments and persist them.

        Only validation and storage happen here; ``target`` and ``df`` are
        accepted for interface compatibility and are not used.

        Args:
            target: Unused.
            df: Unused.
            args: Mapping whose ``"using"`` entry holds the keyword arguments
                for ``Parameters``. A missing mapping or a missing/empty
                ``"using"`` entry is treated as "no arguments", so that
                ``Parameters`` decides whether that is acceptable.
            **kwargs: Unused.
        """
        # `args` defaults to None, so indexing it directly would fail with an
        # unhelpful TypeError; fall back to an empty mapping instead.
        using_args = (args or {}).get("using") or {}

        valid_args = Parameters(**using_args)
        self.model_storage.json_set("args", valid_args.model_dump())

    def predict(self, df, args=None):
        """Embed the text columns of ``df`` with the persisted model settings.

        Args:
            df: Input rows. A single-row frame whose ``content`` cell is a
                list is expanded to one row per list item.
            args: Unused; the arguments stored by ``create`` are read from
                model storage instead.

        Returns:
            A DataFrame with the columns ``content``, ``embeddings`` and
            ``metadata``, one row per document built from ``df``.

        Raises:
            ValueError: If the stored ``text_columns`` value is neither a
                string, a list, nor ``None``.
        """
        stored_args = self.model_storage.json_get("args")

        # Allow the user to pass a list of strings in the where clause,
        # i.e. where content = ['hello', 'world']
        # or   where content = (select content from some_db.some_table).
        # The column and row-count checks come first so that inputs without
        # a "content" column (text_columns pointing elsewhere) or with no
        # rows do not fail on the lookup itself.
        if (
            "content" in df.columns
            and len(df) == 1
            and isinstance(df["content"].iloc[0], list)
        ):
            df = pd.DataFrame(data={"content": df["content"].iloc[0]})

        # Resolve which columns hold the text to embed.
        text_columns = stored_args["text_columns"]
        if isinstance(text_columns, str):
            columns = [text_columns]
        elif isinstance(text_columns, list):
            columns = text_columns
        elif text_columns is None:
            # assume all columns are text columns
            logger.info("No text columns specified, assuming all columns are text columns")
            columns = df.columns.tolist()
        else:
            raise ValueError(f"Invalid value for text_columns: {text_columns}")

        # Each returned document is read for its page_content and metadata.
        documents = df_to_documents(df=df, page_content_columns=columns)
        content = [doc.page_content for doc in documents]
        metadata = [doc.metadata for doc in documents]

        model = load_embeddings_model(stored_args["embeddings_model_name"])
        embeddings = model.embed_documents(texts=content)

        return pd.DataFrame(
            data={"content": content, "embeddings": embeddings, "metadata": metadata}
        )

    def describe(self, attribute: Optional[str] = None) -> Optional[pd.DataFrame]:
        """Return the stored model arguments when ``attribute`` is ``"args"``.

        Args:
            attribute: Name of the attribute to describe.

        Returns:
            A two-column (``key``, ``value``) DataFrame of the stored
            arguments when ``attribute == "args"``; ``None`` for any other
            value.
        """
        args = self.model_storage.json_get("args")

        if attribute == "args":
            return pd.DataFrame(args.items(), columns=["key", "value"])

        # No other attribute is supported; make the fall-through explicit.
        return None