A small service that creates embeddings of posts, profiles, and avatars and stores them in Qdrant.
1import logging
2import sys
3import uuid
4from dataclasses import dataclass
5from datetime import datetime, timezone
6from time import time
7from typing import List, Optional
8
9from qdrant_client import QdrantClient
10from qdrant_client.grpc import OptimizersConfigDiff
11from qdrant_client.http.models import BinaryQuantizationConfig
12from qdrant_client.models import (
13 BinaryQuantization,
14 Distance,
15 FieldCondition,
16 Filter,
17 HnswConfigDiff,
18 MatchValue,
19 Payload,
20 PayloadSchemaType,
21 PointStruct,
22 ScalarQuantization,
23 ScalarQuantizationConfig,
24 ScalarType,
25 VectorParams,
26)
27
28from config import CONFIG
29from metrics import prom_metrics
30
31logger = logging.getLogger(__name__)
32
33
@dataclass
class Result:
    """One point returned from a Qdrant lookup.

    Built by the service's search and get-by-did methods from the records
    Qdrant returns.
    """

    # Value of the "did" key in the point's payload.
    did: str
    # The point's payload as returned by Qdrant.
    payload: Optional[Payload]
    # Similarity score for search hits; the get-by-did lookups set it to 1.0.
    score: Optional[float]
39
40
@dataclass
class ResultWithVector(Result):
    """A :class:`Result` that also carries the point's stored vector."""

    # Filled from the ``vector`` attribute of the scrolled record, i.e. the
    # stored embedding itself rather than a PointStruct.
    # NOTE(review): assumes single unnamed vectors (a flat list of floats), as
    # the collections in this file are created — confirm if named vectors are
    # ever introduced.
    vector: List[float]
44
45
class QdrantService:
    """Wrapper around a Qdrant client for profile, avatar and post vectors.

    The instance holds no connection until :meth:`initialize` is called.
    ``initialize`` connects to ``CONFIG.qdrant_url``, reads the three
    collection names from ``CONFIG`` and creates every collection that does
    not exist yet.
    """

    def __init__(self) -> None:
        # Nothing is connected here; initialize() populates all of these.
        self._client: Optional[QdrantClient] = None
        self.profile_collection_name: Optional[str] = None
        self.avatar_collection_name: Optional[str] = None
        self.post_collection_name: Optional[str] = None

    def initialized(self) -> bool:
        """Return True once :meth:`initialize` has created the client."""
        return self._client is not None

    def get_client(self) -> Optional[QdrantClient]:
        """Return the underlying client, or None before :meth:`initialize`."""
        return self._client

    def initialize(self) -> None:
        """Connect to Qdrant and make sure all three collections exist.

        Exits the process with status 1 if a missing collection or its
        payload indexes cannot be created.
        """
        logger.info(f"Connecting to Qdrant at {CONFIG.qdrant_url}")

        self._client = QdrantClient(
            url=CONFIG.qdrant_url,
        )

        self.profile_collection_name = CONFIG.qdrant_profile_collection_name
        self.avatar_collection_name = CONFIG.qdrant_avatar_collection_name
        self.post_collection_name = CONFIG.qdrant_post_collection_name
        self._ensure_collections_exist()

    # ------------------------------------------------------------------
    # Collection bootstrap
    # ------------------------------------------------------------------

    def _ensure_collections_exist(self) -> None:
        """Create the profile, avatar and post collections if missing."""
        self._ensure_profile_collection()
        self._ensure_avatar_collection()
        self._ensure_post_collection()

    def _create_payload_indexes(self, collection_name: str, keyword_field: str) -> None:
        """Index ``keyword_field`` as a keyword and ``timestamp`` as a datetime."""
        self._client.create_payload_index(
            collection_name=collection_name,
            field_name=keyword_field,
            field_schema=PayloadSchemaType.KEYWORD,
        )
        self._client.create_payload_index(
            collection_name=collection_name,
            field_name="timestamp",
            field_schema=PayloadSchemaType.DATETIME,
        )

    def _ensure_profile_collection(self) -> None:
        """Create the profile collection and its indexes unless it exists."""
        name = self.profile_collection_name
        if self._client.collection_exists(name):
            return

        logger.info(f"Creating profile collection: {name}")
        try:
            self._client.create_collection(
                collection_name=name,
                # NOTE(review): the size is hard-coded here while the post
                # collection reads CONFIG.embedding_size — confirm intended.
                vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
                hnsw_config=HnswConfigDiff(m=32, ef_construct=200),
                quantization_config=ScalarQuantization(
                    scalar=ScalarQuantizationConfig(
                        type=ScalarType.INT8, quantile=0.99, always_ram=True
                    )
                ),
            )
        except Exception as e:
            logger.error(f"Failed to create profiles collection: {e}")
            sys.exit(1)

        try:
            self._create_payload_indexes(name, keyword_field="did")
        except Exception as e:
            logger.error(f"Failed to create profiles indexes: {e}")
            sys.exit(1)

        logger.info("Collection created successfully")

    def _ensure_avatar_collection(self) -> None:
        """Create the avatar collection and its indexes unless it exists."""
        name = self.avatar_collection_name
        if self._client.collection_exists(name):
            return

        logger.info(f"Creating avatar collection: {name}")
        try:
            self._client.create_collection(
                collection_name=name,
                vectors_config=VectorParams(
                    # PDQ vectors have a size of 256
                    size=256,
                    # Qdrant doesn't support Hamming distance, so we use
                    # Euclidean distance and look up with the square root of
                    # the selected max distance.
                    distance=Distance.EUCLID,
                ),
                hnsw_config=HnswConfigDiff(
                    m=16,  # lower m for binary-like data
                    ef_construct=100,
                ),
                quantization_config=BinaryQuantization(
                    binary=BinaryQuantizationConfig(always_ram=True)
                ),
            )
        except Exception as e:
            logger.error(f"Failed to create avatar collection: {e}")
            sys.exit(1)

        try:
            self._create_payload_indexes(name, keyword_field="did")
        except Exception as e:
            logger.error(f"Failed to create avatar indexes: {e}")
            sys.exit(1)

        # Logged for parity with the profile and post collections.
        logger.info("Collection created successfully")

    def _ensure_post_collection(self) -> None:
        """Create the post collection and its indexes unless it exists."""
        name = self.post_collection_name
        if self._client.collection_exists(name):
            return

        logger.info(f"Creating post collection: {name}")
        try:
            self._client.create_collection(
                collection_name=name,
                vectors_config=VectorParams(
                    size=CONFIG.embedding_size,
                    distance=Distance.COSINE,
                ),
                hnsw_config=HnswConfigDiff(
                    m=48,
                    ef_construct=256,
                ),
                quantization_config=ScalarQuantization(
                    scalar=ScalarQuantizationConfig(
                        type=ScalarType.INT8,
                        quantile=0.99,
                        always_ram=True,
                    ),
                ),
                # NOTE(review): this file imports OptimizersConfigDiff from
                # qdrant_client.grpc while the other config types come from
                # qdrant_client.models — confirm that is intended.
                optimizers_config=OptimizersConfigDiff(
                    indexing_threshold=50_000,
                ),
            )
        except Exception as e:
            logger.error(f"Failed to create posts collection: {e}")
            sys.exit(1)

        try:
            # Posts are indexed by "uri" rather than "did".
            self._create_payload_indexes(name, keyword_field="uri")
        except Exception as e:
            logger.error(f"Failed to create post indexes: {e}")
            sys.exit(1)

        logger.info("Collection created successfully")

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _did_filter(did: str) -> Filter:
        """Build a filter matching points whose payload ``did`` equals ``did``."""
        return Filter(must=[FieldCondition(key="did", match=MatchValue(value=did))])

    @staticmethod
    def _record_upsert(kind: str, status: str, start_time: float) -> None:
        """Record the outcome counter and elapsed time for one upsert."""
        prom_metrics.upserts.labels(kind=kind, status=status).inc()
        prom_metrics.upsert_duration.labels(kind=kind, status=status).observe(
            time() - start_time
        )

    def _find_point_id_by_did(self, collection_name: str, did: str):
        """Return the ID of the first point stored for ``did``, or None.

        Only the ID is needed, so a single record is requested and payloads
        and vectors are left out of the response.
        """
        result = self._client.scroll(
            collection_name=collection_name,
            scroll_filter=self._did_filter(did),
            limit=1,
            with_payload=False,
            with_vectors=False,
        )
        records = result[0] if result else None
        return records[0].id if records else None

    def _store_point(
        self, collection_name: str, point_id, vector: List[float], payload: dict
    ) -> None:
        """Upsert a single point into ``collection_name``."""
        self._client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=point_id, vector=vector, payload=payload)],
        )

    def _replace_point_for_did(
        self, collection_name: str, did: str, vector: List[float], payload: dict
    ) -> None:
        """Overwrite the point already stored for ``did`` or insert a new one."""
        point_id = self._find_point_id_by_did(collection_name, did)
        # Explicit None check: a falsy-but-valid ID (e.g. 0) must be reused.
        if point_id is None:
            point_id = str(uuid.uuid4())
        self._store_point(collection_name, point_id, vector, payload)

    def _get_by_did(
        self, collection_name: str, did: str
    ) -> Optional[ResultWithVector]:
        """Fetch the first point stored for ``did`` including its vector."""
        result = self._client.scroll(
            collection_name=collection_name,
            scroll_filter=self._did_filter(did),
            limit=1,  # only the first match is ever used
            with_vectors=True,
            with_payload=True,
        )
        records = result[0] if result else None
        if not records:
            return None

        point = records[0]
        return ResultWithVector(
            did=point.payload["did"],
            payload=point.payload,
            vector=point.vector,
            score=1.0,
        )

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def upsert_profile(self, did: str, description: str, vector: List[float]) -> bool:
        """Store the profile embedding for ``did``, replacing any existing one.

        Returns True on success and False (after logging) on any error. The
        outcome and duration are recorded in the upsert metrics either way.
        """
        status = "error"
        start_time = time()

        try:
            payload = {
                "did": did,
                "description": description,
                "timestamp": create_now_timestamp(),
            }
            self._replace_point_for_did(
                self.profile_collection_name, did, vector, payload
            )
            status = "ok"
            return True
        except Exception as e:
            logger.error(f"Error upserting profile: {e}")
            return False
        finally:
            self._record_upsert("profile", status, start_time)

    def upsert_avatar(self, did: str, cid: str, vector: List[float]) -> bool:
        """Store the avatar embedding for ``did``, replacing any existing one.

        Returns True on success and False (after logging) on any error. The
        outcome and duration are recorded in the upsert metrics either way.
        """
        status = "error"
        start_time = time()

        try:
            payload = {
                "did": did,
                "cid": cid,
                "timestamp": create_now_timestamp(),
            }
            self._replace_point_for_did(
                self.avatar_collection_name, did, vector, payload
            )
            status = "ok"
            return True
        except Exception as e:
            logger.error(f"Error upserting avatar: {e}")
            return False
        finally:
            self._record_upsert("avatar", status, start_time)

    def upsert_post(self, did: str, uri: str, text: str, vector: List[float]) -> bool:
        """Store a post embedding under a freshly generated point ID.

        Unlike profiles and avatars, no existing point is looked up, so every
        call adds a new point. Returns True on success and False (after
        logging) on any error; metrics are recorded either way.
        """
        status = "error"
        start_time = time()

        try:
            payload = {
                "did": did,
                "uri": uri,
                "text": text,
                # Computed inside the try so a bad ``text`` is reported as a
                # failed upsert instead of escaping to the caller.
                "word_count": len(text.split()),
                "timestamp": create_now_timestamp(),
            }
            # we don't care about upserting these
            self._store_point(
                self.post_collection_name, str(uuid.uuid4()), vector, payload
            )
            status = "ok"
            return True
        except Exception as e:
            logger.error(f"Error upserting post: {e}")
            return False
        finally:
            self._record_upsert("post", status, start_time)

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def search_similar(
        self,
        collection_name: str,
        query_vector: List[float],
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter_conditions: Optional[Filter] = None,
    ) -> Optional[List[Result]]:
        """Return up to ``limit`` points most similar to ``query_vector``.

        Args:
            collection_name: Collection to query.
            query_vector: Vector to compare against.
            limit: Maximum number of hits to return.
            score_threshold: Optional score cut-off passed through to Qdrant.
            filter_conditions: Optional payload filter passed through to Qdrant.

        Returns:
            The hits as ``Result`` objects, or None if the query failed (the
            error is logged).
        """
        try:
            hits = self._client.query_points(
                collection_name=collection_name,
                query=query_vector,
                query_filter=filter_conditions,
                limit=limit,
                score_threshold=score_threshold,
                with_payload=True,
            ).points

            return [
                Result(
                    # A hit without a payload must not discard the whole list.
                    did=(hit.payload or {}).get("did"),
                    payload=hit.payload,
                    score=hit.score,
                )
                for hit in hits
            ]
        except Exception as e:
            logger.error(f"Error searching for similar vectors: {e}")
            return None

    def get_profile_by_did(self, did: str) -> Optional[ResultWithVector]:
        """Return the stored profile point for ``did`` with its vector, or None.

        Client errors are not caught here and propagate to the caller.
        """
        return self._get_by_did(self.profile_collection_name, did)

    def get_avatar_by_did(self, did: str) -> Optional[ResultWithVector]:
        """Return the stored avatar point for ``did`` with its vector, or None.

        Client errors are not caught here and propagate to the caller.
        """
        return self._get_by_did(self.avatar_collection_name, did)
393
394
# Shared module-level service instance. Constructing it opens no connection
# (the constructor only sets the client to None); call initialize() before use.
QDRANT_SERVICE = QdrantService()
396
397
def create_now_timestamp() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.isoformat()