Scalable and distributed custom feed generator, ott - on that topic

Actually insert the embeddings into the db

+85 -7
+1
crates/Cargo.lock
··· 2500 2500 name = "ott-embed" 2501 2501 version = "0.1.0" 2502 2502 dependencies = [ 2503 + "anyhow", 2503 2504 "fluvio", 2504 2505 "ott-types", 2505 2506 "pgvector",
+1
crates/ott-embed/Cargo.toml
··· 4 4 edition = "2024" 5 5 6 6 [dependencies] 7 + anyhow = "1.0.100" 7 8 fluvio = "0.50.1" 8 9 ott-types = { version = "0.1.0", path = "../ott-types" } 9 10 pgvector = { version = "0.4", features = ["sqlx"] }
+51 -6
crates/ott-embed/src/main.rs
··· 1 + use std::time::Duration; 2 + 3 + use ott_embed::pg_client::PgClient; 1 4 use ott_embed::tei_client::TextEmbedding; 2 - use serde::{Deserialize, Serialize}; 3 - use tokio::sync::mpsc::{self, Receiver, Sender}; 5 + use tokio::{ 6 + sync::mpsc::{Receiver, Sender}, 7 + time::interval, 8 + }; 4 9 5 10 use tokio_stream::StreamExt; 6 - use tracing::{error, info, warn}; 11 + use tracing::{debug, error, warn}; 7 12 use tracing_subscriber::EnvFilter; 8 13 9 14 use fluvio::{consumer::ConsumerConfigExtBuilder, Fluvio, Offset}; ··· 71 76 .expect("Failed to send embedding between tasks"); 72 77 } 73 78 Err(e) => { 74 - error!(e); 79 + error!("Failed to embed post! {} {}", post.uri, e); 75 80 } 76 81 }; 77 82 } ··· 79 84 80 85 async fn store_task(mut embeddings: Receiver<Embedding>) { 81 86 warn!("Ready to start storing embeddings"); 82 - while let Some(embedding) = embeddings.recv().await { 83 - warn!("Embedded {}", embedding.uri) 87 + let batch_size = 100; 88 + let mut flush_timer = interval(Duration::from_millis(500)); 89 + let pg_client = PgClient::new().await.expect("Failed to connect to db"); 90 + 91 + let mut batch = Vec::with_capacity(batch_size); 92 + loop { 93 + error!("Storing"); 94 + tokio::select! { 95 + Some(record) = embeddings.recv() => { 96 + batch.push(record); 97 + if batch.len() >= batch_size { 98 + if let Err(e) = pg_client.insert_embeddings(&batch).await { 99 + error!("Insert error: {}", e); 100 + } 101 + flush_timer.reset(); 102 + batch.clear(); 103 + debug!("Inserted normally"); 104 + } 105 + } 106 + 107 + // Flush periodically even if batch isn't full 108 + _ = flush_timer.tick() => { 109 + if !batch.is_empty() { 110 + if let Err(e) = pg_client.insert_embeddings(&batch).await { 111 + error!("Insert error: {}", e); 112 + } 113 + batch.clear(); 114 + debug!("Inserted flushed"); 115 + } 116 + } 117 + 118 + // Channel closed 119 + else => { 120 + // Final flush 121 + if !batch.is_empty() { 122 + if let Err(e) = pg_client.insert_embeddings(&batch).await { 123 + error!("Final insert error: {}", e); 124 + } 125 + } 126 + break; 127 + } 128 + } 84 129 } 85 130 }
+32 -1
crates/ott-embed/src/pg_client.rs
··· 1 + use anyhow::Result; 2 + use ott_types::Embedding; 3 + use pgvector::Vector; 1 4 use sqlx::PgPool; 2 5 3 - struct PgClient { 6 + pub struct PgClient { 4 7 pool: PgPool, 5 8 } 9 + 10 + impl PgClient { 11 + pub async fn new() -> Result<Self> { 12 + let database_url = std::env::var("DATABASE_URL")?; 13 + let pool = PgPool::connect(&database_url).await?; 14 + Ok(Self { pool: pool }) 15 + } 16 + 17 + pub async fn insert_embeddings( 18 + self: &Self, 19 + vectors: &Vec<Embedding>, 20 + ) -> Result<(), sqlx::Error> { 21 + let mut tx = self.pool.begin().await?; 22 + 23 + for embedding in vectors { 24 + let vector = Vector::from(embedding.vector.clone()); 25 + 26 + sqlx::query("INSERT INTO vectors (uri, vector) VALUES ($1, $2)") 27 + .bind(&embedding.uri) 28 + .bind(vector) 29 + .execute(&mut *tx) 30 + .await?; 31 + } 32 + 33 + tx.commit().await?; 34 + Ok(()) 35 + } 36 + }