Coverage for mindsdb / integrations / utilities / rag / loaders / vector_store_loader / pgvector.py: 22%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import Any, List, Union, Optional, Dict 

2 

3from langchain_community.vectorstores import PGVector 

4from langchain_community.vectorstores.pgvector import Base 

5 

6from pgvector.sqlalchemy import SPARSEVEC, Vector 

7import sqlalchemy as sa 

8from sqlalchemy.dialects.postgresql import JSON 

9 

10from sqlalchemy.orm import Session 

11 

12 

# Cache of dynamically generated ORM classes, keyed by collection (table) name.
# The mapped class built for a collection is reused by every later
# PGVectorMDB instance that points at the same collection name.
_generated_sa_tables: Dict[str, Any] = {}


class PGVectorMDB(PGVector):
    """
    langchain_community.vectorstores.PGVector adapted for mindsdb vector store table structure

    The store is read-only: it maps an existing table (columns ``id``,
    ``embeddings``, ``content``, ``metadata``) and answers similarity queries
    with raw SQL. Every method that would create or modify data raises
    ``RuntimeError("Forbidden")``.
    """

    def __init__(self, *args, is_sparse: bool = False, vector_size: Optional[int] = None, **kwargs):
        """
        Args:
            *args: forwarded unchanged to ``PGVector.__init__``.
            is_sparse: True when the ``embeddings`` column holds sparse vectors.
            vector_size: dimension of the vectors; mandatory for sparse vectors.
            **kwargs: forwarded unchanged to ``PGVector.__init__``.

        Raises:
            ValueError: if ``is_sparse`` is True and ``vector_size`` is not given.
        """
        # todo get is_sparse and vector_size from kb vector table
        # These attributes are assigned before calling the parent constructor
        # because __post_init__ reads them (presumably the parent constructor
        # is what invokes __post_init__ - confirm against the installed
        # langchain_community version).
        self.is_sparse = is_sparse
        if is_sparse and vector_size is None:
            raise ValueError("vector_size is required when is_sparse=True")
        self.vector_size = vector_size
        super().__init__(*args, **kwargs)

    def __post_init__(
        self,
    ) -> None:
        """Bind ``self.EmbeddingStore`` to the ORM class mapped on the collection table.

        The class is generated once per collection name and cached in
        ``_generated_sa_tables``; later instances with the same collection
        name reuse the cached class, whatever their own ``is_sparse`` /
        ``vector_size`` values are.
        """
        collection_name = self.collection_name

        if collection_name not in _generated_sa_tables:
            # Resolve the column type up front. The previous single chained
            # conditional expression short-circuited on ``is_sparse`` and always
            # produced an unsized SPARSEVEC(), silently ignoring vector_size.
            if self.is_sparse:
                embedding_type = SPARSEVEC(self.vector_size)
            elif self.vector_size is None:
                embedding_type = Vector()
            else:
                embedding_type = Vector(self.vector_size)

            class EmbeddingStore(Base):
                """Embedding store."""

                __tablename__ = collection_name

                id = sa.Column(sa.Integer, primary_key=True)
                embedding = sa.Column("embeddings", embedding_type)
                document = sa.Column("content", sa.String, nullable=True)
                cmetadata = sa.Column("metadata", JSON, nullable=True)

            _generated_sa_tables[collection_name] = EmbeddingStore

        self.EmbeddingStore = _generated_sa_tables[collection_name]

    def __query_collection(
        self,
        embedding: Union[List[float], Dict[int, float], str],
        k: int = 4,
        filter: Optional[Dict[str, str]] = None,
    ) -> List[Any]:
        """Query the collection.

        Args:
            embedding: query vector. Sparse stores accept a ``{index: value}``
                dict or a string already in ``"{key:value,...}/size"`` form;
                dense stores accept a list/tuple of floats or a string already
                in ``"[v1,v2,...]"`` form.
            k: maximum number of rows to return.
            filter: accepted for signature compatibility only; it is not
                applied to the query.

        Returns:
            A list of objects exposing ``EmbeddingStore`` (with ``document``
            and ``cmetadata`` populated) and ``distance``, ordered from the
            smallest distance value to the largest.

        Raises:
            TypeError: if ``embedding`` has a type the store cannot serialize.
        """
        if self.is_sparse:
            # Sparse vectors: expect string in format "{key:value,...}/size" or dictionary
            if isinstance(embedding, dict):
                from pgvector.utils import SparseVector
                embedding_str = SparseVector(embedding, self.vector_size).to_text()
            elif isinstance(embedding, str):
                # Use string as is - it should already be in the correct format
                embedding_str = embedding
            else:
                raise TypeError(
                    f"Sparse embedding must be a dict or a str, got {type(embedding).__name__}"
                )
            # Inner product operator. pgvector's <#> returns the NEGATIVE inner
            # product, so the best matches have the smallest values and the
            # ascending sort below is the right one.
            distance_op = "<#>"
        else:
            # Dense vectors: expect string in JSON array format or list of floats
            if isinstance(embedding, (list, tuple)):
                embedding_str = f"[{','.join(str(x) for x in embedding)}]"
            elif isinstance(embedding, str):
                embedding_str = embedding
            else:
                raise TypeError(
                    f"Dense embedding must be a list or a str, got {type(embedding).__name__}"
                )
            # Cosine distance operator: smaller values are better matches.
            distance_op = "<=>"

        # Both operators rank the best match first in ascending order.
        order_direction = "ASC"

        # Use SQL directly for vector comparison.
        # The embedding is rendered through a literal_execute bind parameter:
        # the statement still contains a plain quoted literal (so the database
        # resolves its type from the column, as before), but the value is
        # escaped by SQLAlchemy instead of being pasted into the SQL text.
        # NOTE(review): the table name cannot be bound and is still
        # interpolated as-is; it must come from trusted configuration.
        query = sa.text(
            f"""
                SELECT t.*, t.embeddings {distance_op} :embedding as distance
                FROM {self.collection_name} t
                ORDER BY distance {order_direction}
                LIMIT :k
            """
        ).bindparams(
            sa.bindparam("embedding", embedding_str, type_=sa.String, literal_execute=True),
            sa.bindparam("k", int(k), type_=sa.Integer),
        )

        with Session(self._bind) as session:
            results = session.execute(query).all()

            # Convert results to the expected format
            formatted_results = []
            for rec in results:
                # Empty/missing metadata is replaced by the placeholder {0: 0}
                # (kept from the original implementation; presumably downstream
                # code requires a non-empty mapping - confirm before changing).
                metadata = rec.metadata if bool(rec.metadata) else {0: 0}
                embedding_store = self.EmbeddingStore()
                embedding_store.document = rec.content
                embedding_store.cmetadata = metadata
                # Lightweight ad-hoc object carrying the two attributes that
                # are read from each result: EmbeddingStore and distance.
                result = type(
                    'Result', (), {
                        'EmbeddingStore': embedding_store,
                        'distance': rec.distance
                    }
                )
                formatted_results.append(result)

            return formatted_results

    # aliases for different langchain versions
    def _PGVector__query_collection(self, *args, **kwargs):
        """Alias matching the name-mangled private method of ``PGVector``."""
        return self.__query_collection(*args, **kwargs)

    def _query_collection(self, *args, **kwargs):
        """Alias matching the single-underscore method name."""
        return self.__query_collection(*args, **kwargs)

    def create_collection(self):
        """Always raises: this store never creates collections."""
        raise RuntimeError("Forbidden")

    def delete_collection(self):
        """Always raises: this store never deletes collections."""
        raise RuntimeError("Forbidden")

    def delete(self, *args, **kwargs):
        """Always raises: this store never deletes rows."""
        raise RuntimeError("Forbidden")

    def add_embeddings(self, *args, **kwargs):
        """Always raises: this store never inserts embeddings."""
        raise RuntimeError("Forbidden")