Coverage for mindsdb / integrations / utilities / rag / loaders / vector_store_loader / pgvector.py: 22%
66 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from typing import Any, List, Union, Optional, Dict
3from langchain_community.vectorstores import PGVector
4from langchain_community.vectorstores.pgvector import Base
6from pgvector.sqlalchemy import SPARSEVEC, Vector
7import sqlalchemy as sa
8from sqlalchemy.dialects.postgresql import JSON
10from sqlalchemy.orm import Session
# Module-level cache of the dynamically generated SQLAlchemy models, keyed by
# collection (table) name. PGVectorMDB.__post_init__ fills and reads it, so the
# model for a given table name is declared on the shared ``Base`` only once.
# NOTE(review): presumably because re-declaring the same table on ``Base``
# would fail - confirm.
_generated_sa_tables: Dict[str, Any] = {}
class PGVectorMDB(PGVector):
    """
    langchain_community.vectorstores.PGVector adapted for mindsdb vector store table structure.

    The backing table is mapped with the columns ``id``, ``embeddings``,
    ``content`` and ``metadata`` (see ``__post_init__``). The store is
    read-only: every method that would create, delete or write data raises
    ``RuntimeError("Forbidden")``.
    """

    def __init__(self, *args, is_sparse: bool = False, vector_size: Optional[int] = None, **kwargs):
        """
        :param is_sparse: True when the ``embeddings`` column holds pgvector
            ``sparsevec`` values (searched with ``<#>``); False for dense
            ``vector`` values (searched with ``<=>``).
        :param vector_size: vector dimension. Required for sparse stores,
            because dict embeddings are converted to a fixed-size sparse vector.
        :raises ValueError: if ``is_sparse`` is True and ``vector_size`` is None.

        All other arguments are forwarded to ``PGVector.__init__``.
        """
        # todo get is_sparse and vector_size from kb vector table
        # __post_init__ reads these attributes. It is presumably invoked from
        # PGVector.__init__, so they are set before calling super().
        self.is_sparse = is_sparse
        if is_sparse and vector_size is None:
            raise ValueError("vector_size is required when is_sparse=True")
        self.vector_size = vector_size
        super().__init__(*args, **kwargs)

    def _embedding_column_type(self):
        """Return the pgvector SQLAlchemy type for the ``embeddings`` column."""
        vector_cls = SPARSEVEC if self.is_sparse else Vector
        if self.vector_size is None:
            return vector_cls()
        return vector_cls(self.vector_size)

    def __post_init__(
        self,
    ) -> None:
        """
        Bind ``self.EmbeddingStore`` to an ORM model of the mindsdb vector table.

        Replaces the parent's hook. This version only builds (or reuses) the
        model class and does not create anything in the database. Models are
        cached per collection name in the module-level ``_generated_sa_tables``.
        """
        collection_name = self.collection_name

        if collection_name not in _generated_sa_tables:
            vector_type = self._embedding_column_type()

            class EmbeddingStore(Base):
                """Embedding store."""

                __tablename__ = collection_name

                id = sa.Column(sa.Integer, primary_key=True)
                embedding = sa.Column("embeddings", vector_type)
                document = sa.Column("content", sa.String, nullable=True)
                cmetadata = sa.Column("metadata", JSON, nullable=True)

            _generated_sa_tables[collection_name] = EmbeddingStore

        # NOTE(review): the cache key ignores is_sparse/vector_size. A later
        # instance for the same table reuses the first instance's column type.
        # The raw-SQL query below does not depend on that type.
        self.EmbeddingStore = _generated_sa_tables[collection_name]

    def _to_pgvector_text(self, embedding: Union[List[float], Dict[int, float], str]) -> str:
        """
        Convert a query embedding to pgvector's text representation.

        Strings are passed through unchanged. They must already be in the
        store's text format: ``"[x,y,...]"`` for dense stores, or
        ``"{key:value,...}/size"`` for sparse stores.

        Sparse stores also accept a ``{index: value}`` dict, converted with
        ``pgvector.utils.SparseVector`` using ``self.vector_size``. Dense stores
        also accept any iterable of numbers (list, tuple, array).

        :raises TypeError: if the embedding type is not supported by the store.
        """
        if isinstance(embedding, str):
            return embedding

        if self.is_sparse:
            if isinstance(embedding, dict):
                from pgvector.utils import SparseVector
                return SparseVector(embedding, self.vector_size).to_text()
            raise TypeError(
                f"Sparse embedding must be a dict or str, got {type(embedding).__name__}"
            )

        if isinstance(embedding, dict):
            raise TypeError("Dense embedding must be a sequence of numbers or str, got dict")
        return f"[{','.join(str(x) for x in embedding)}]"

    def __query_collection(
        self,
        embedding: Union[List[float], Dict[int, float], str],
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
    ) -> List[Any]:
        """
        Query the collection for the ``k`` rows closest to ``embedding``.

        :param embedding: query vector; see ``_to_pgvector_text`` for the
            accepted forms.
        :param k: maximum number of rows to return.
        :param filter: accepted for signature compatibility with PGVector but
            currently ignored.
        :return: best match first, a list of objects exposing an
            ``EmbeddingStore`` attribute (with ``document`` and ``cmetadata``
            set) and a ``distance`` attribute.
        :raises TypeError: if the embedding type is not supported.
        """
        from types import SimpleNamespace

        embedding_str = self._to_pgvector_text(embedding)

        if self.is_sparse:
            # <#> returns the NEGATIVE inner product, so ascending order puts
            # the largest inner products (best matches) first.
            distance_op = "<#>"
        else:
            # <=> is cosine distance: smaller values are better matches.
            distance_op = "<=>"

        # The embedding and the limit are bound parameters, so the driver
        # escapes them. Only the fixed operator literal and the table name are
        # interpolated. A table name cannot be bound, so collection_name is
        # assumed to come from trusted configuration.
        query = sa.text(
            f"""
            SELECT t.*, t.embeddings {distance_op} :embedding AS distance
            FROM {self.collection_name} t
            ORDER BY distance ASC
            LIMIT :k
            """
        )

        with Session(self._bind) as session:
            rows = session.execute(query, {"embedding": embedding_str, "k": int(k)}).all()

            # Wrap rows in the shape PGVector's result handling reads:
            # result.EmbeddingStore.{document, cmetadata} and result.distance.
            formatted_results = []
            for rec in rows:
                # Empty/NULL metadata is replaced by a non-empty placeholder.
                # NOTE(review): presumably a truthy dict is needed downstream -
                # confirm why {0: 0} specifically.
                metadata = rec.metadata if rec.metadata else {0: 0}
                embedding_store = self.EmbeddingStore()
                embedding_store.document = rec.content
                embedding_store.cmetadata = metadata
                formatted_results.append(
                    SimpleNamespace(EmbeddingStore=embedding_store, distance=rec.distance)
                )

        return formatted_results

    # Aliases for different langchain versions. Depending on the version,
    # PGVector dispatches to the name-mangled private ``__query_collection``
    # (i.e. ``_PGVector__query_collection``) or to ``_query_collection``.
    # Both are routed to this class's implementation.
    def _PGVector__query_collection(self, *args, **kwargs):
        return self.__query_collection(*args, **kwargs)

    def _query_collection(self, *args, **kwargs):
        return self.__query_collection(*args, **kwargs)

    # Write/DDL operations are disabled: this adapter is read-only.
    def create_collection(self):
        raise RuntimeError("Forbidden")

    def delete_collection(self):
        raise RuntimeError("Forbidden")

    def delete(self, *args, **kwargs):
        raise RuntimeError("Forbidden")

    def add_embeddings(self, *args, **kwargs):
        raise RuntimeError("Forbidden")