286 lines
9.9 KiB
Python
286 lines
9.9 KiB
Python
"""
|
|
Database connection and operations for ML training service
|
|
"""
|
|
|
|
import os
|
|
import hashlib
|
|
from typing import List, Dict, Optional, Any
|
|
from datetime import datetime
|
|
import asyncpg
|
|
from config_loader import config_loader
|
|
from dotenv import load_dotenv
|
|
from logger_config import get_logger
|
|
|
|
# Pull variables from a .env file into the process environment first, so the
# DATABASE_URL lookup below can see them.
load_dotenv()

# Module-scoped logger obtained from the project's logging configuration.
logger = get_logger(__name__)

# Database configuration: the asyncpg connection string (None when unset).
DATABASE_URL = os.environ.get("DATABASE_URL")
|
|
|
|
class DatabaseManager:
    """Async PostgreSQL access layer (asyncpg pool) for the ML training service."""

    # Embedding widths that have a dedicated ``vec_<N>`` column in
    # ``internal.embeddings``. Also acts as the whitelist for the column name
    # that insert_embeddings interpolates into SQL.
    _VECTOR_DIMENSIONS = (2048, 1536, 1024)

    # Order in which populated vector columns are preferred when building the
    # final dataset.
    _VECTOR_PREFERENCE = ("vec_1536", "vec_2048", "vec_1024")

    def __init__(self):
        # Created by connect(); None means "not connected".
        self.pool: Optional[asyncpg.Pool] = None

    async def connect(self):
        """Initialize the database connection pool (5-20 connections) from DATABASE_URL.

        Calling connect() while a pool already exists is a no-op, so repeated
        calls do not leak the earlier pool.

        Raises:
            Exception: whatever ``asyncpg.create_pool`` raises; it is logged
                and re-raised.
        """
        if self.pool is not None:
            logger.info("Database connection pool already initialized")
            return
        try:
            self.pool = await asyncpg.create_pool(DATABASE_URL, min_size=5, max_size=20)
        except Exception as e:
            logger.error(f"Failed to connect to database: {e}")
            raise
        logger.info("Database connection pool initialized")

    async def close(self):
        """Close the connection pool, if any, and mark the manager disconnected."""
        if self.pool is None:
            return
        try:
            await self.pool.close()
        finally:
            # Reset even if close() fails, so is_connected reflects reality and a
            # later connect() builds a fresh pool.
            self.pool = None
        logger.info("Database connection pool closed")

    @property
    def is_connected(self) -> bool:
        """True while a connection pool exists (between connect() and close())."""
        return self.pool is not None

    async def get_embedding_models(self):
        """Return the available embedding models from ``config_loader``."""
        return config_loader.get_embedding_models()

    @staticmethod
    def _isoformat(value: Optional[datetime]) -> Optional[str]:
        """Return the ISO-8601 string for a timestamp column, or None for SQL NULL."""
        return value.isoformat() if value is not None else None

    async def get_video_metadata(
        self, aid_list: List[int]
    ) -> Dict[int, Dict[str, Any]]:
        """Fetch title/description/tags from ``bilibili_metadata`` for the given AIDs.

        Returns:
            Mapping aid -> {"aid", "title", "description", "tags"}. NULL text
            columns become "". AIDs without a row are absent from the mapping.
        """
        if not aid_list:
            return {}

        query = """
            SELECT aid, title, description, tags
            FROM bilibili_metadata
            WHERE aid = ANY($1::bigint[])
        """
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(query, aid_list)

        return {
            int(row["aid"]): {
                "aid": int(row["aid"]),
                "title": row["title"] or "",
                "description": row["description"] or "",
                "tags": row["tags"] or "",
            }
            for row in rows
        }

    async def get_user_labels(
        self, aid_list: List[int]
    ) -> Dict[int, List[Dict[str, Any]]]:
        """Fetch the latest label per (aid, user) from ``internal.video_type_label``.

        Returns:
            Mapping aid -> list of {"user", "label" (bool), "created_at" (ISO
            string or None)}, ordered by user. AIDs with no labels are absent.
        """
        if not aid_list:
            return {}

        # DISTINCT ON keeps the first row of each (aid, user) group, which the
        # ORDER BY ... created_at DESC makes the most recent label.
        query = """
            WITH latest_labels AS (
                SELECT DISTINCT ON (aid, "user")
                    aid, "user", label, created_at
                FROM internal.video_type_label
                WHERE aid = ANY($1::bigint[])
                ORDER BY aid, "user", created_at DESC
            )
            SELECT aid, "user", label, created_at
            FROM latest_labels
            ORDER BY aid, "user"
        """
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(query, aid_list)

        result: Dict[int, List[Dict[str, Any]]] = {}
        for row in rows:
            result.setdefault(int(row["aid"]), []).append(
                {
                    "user": row["user"],
                    "label": bool(row["label"]),
                    "created_at": self._isoformat(row["created_at"]),
                }
            )
        return result

    async def get_existing_embeddings(
        self, checksums: List[str], model_name: str
    ) -> Dict[str, Dict[str, Any]]:
        """Fetch stored embeddings for the given text checksums and model.

        Returns:
            Mapping checksum -> {"checksum", "vec_2048", "vec_1536", "vec_1024",
            "created_at"}. Each vec_* entry is a list of floats, or None when
            that column is NULL/empty.
        """
        if not checksums:
            return {}

        query = """
            SELECT data_checksum, dimensions, vec_2048, vec_1536, vec_1024, created_at
            FROM internal.embeddings
            WHERE model_name = $1 AND data_checksum = ANY($2::text[])
        """
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(query, model_name, checksums)

        result: Dict[str, Dict[str, Any]] = {}
        for row in rows:
            checksum = row["data_checksum"]
            entry: Dict[str, Any] = {"checksum": checksum}
            for column in ("vec_2048", "vec_1536", "vec_1024"):
                raw = row[column]
                # Only one column is normally populated; the others stay None.
                entry[column] = self._parse_vector_string(raw) if raw else None
            entry["created_at"] = self._isoformat(row["created_at"])
            result[checksum] = entry
        return result

    def _parse_vector_string(self, vector_str: Any) -> List[float]:
        """Parse a vector text value such as '[1.0,2.0,3.0]' into a list of floats.

        Also accepts an already-decoded list/tuple of numbers, converting each
        element to float. Returns [] for empty input, or when parsing fails (a
        warning is logged in that case).
        """
        if not vector_str:
            return []

        try:
            if isinstance(vector_str, (list, tuple)):
                return [float(x) for x in vector_str]

            text = vector_str.strip()
            if text.startswith("[") and text.endswith("]"):
                text = text[1:-1]
            return [float(part.strip()) for part in text.split(",") if part.strip()]
        except Exception as e:
            # Best effort: an unparseable vector is treated as missing. Only a
            # prefix is logged because vectors can be thousands of values long.
            logger.warning(
                f"Failed to parse vector string '{str(vector_str)[:80]}': {e}"
            )
            return []

    @classmethod
    def _vector_column(cls, dimensions: Any) -> str:
        """Map an embedding width to its ``vec_<N>`` column name.

        The returned name is interpolated into SQL, so only widths listed in
        _VECTOR_DIMENSIONS are accepted.

        Raises:
            ValueError: if ``dimensions`` is not a supported width.
        """
        try:
            dims = int(dimensions)
        except (TypeError, ValueError):
            dims = None
        if dims not in cls._VECTOR_DIMENSIONS:
            raise ValueError(f"Unsupported embedding dimensions: {dimensions!r}")
        return f"vec_{dims}"

    async def insert_embeddings(self, embeddings_data: List[Dict[str, Any]]) -> None:
        """Batch-insert embeddings into ``internal.embeddings`` in one transaction.

        Each item needs "model_name", "dimensions", "checksum" and "vector" (a
        sequence of numbers). The vector is written as '[x,y,...]' text into the
        ``vec_<dimensions>`` column. Rows whose (model_name, dimensions,
        data_checksum) already exist are left untouched (ON CONFLICT DO NOTHING).

        Raises:
            ValueError: if any item has unsupported dimensions. This is checked
                before the database is touched, so nothing is written.
        """
        if not embeddings_data:
            return

        # Validate and group rows by target column up front: a bad row aborts
        # before any write, and each group becomes a single executemany call.
        grouped: Dict[str, List[tuple]] = {}
        for data in embeddings_data:
            vec_column = self._vector_column(data["dimensions"])
            vector_str = "[" + ",".join(map(str, data["vector"])) + "]"
            grouped.setdefault(vec_column, []).append(
                (
                    data["model_name"],
                    data["dimensions"],
                    data["checksum"],
                    vector_str,
                    datetime.now(),
                )
            )

        async with self.pool.acquire() as conn:
            async with conn.transaction():
                for vec_column, rows in grouped.items():
                    query = f"""
                        INSERT INTO internal.embeddings
                        (model_name, dimensions, data_checksum, {vec_column}, created_at)
                        VALUES ($1, $2, $3, $4, $5)
                        ON CONFLICT (model_name, dimensions, data_checksum) DO NOTHING
                    """
                    await conn.executemany(query, rows)

    @classmethod
    def _pick_embedding(cls, emb_data: Dict[str, Any]) -> Optional[List[float]]:
        """Return the first non-empty vector in _VECTOR_PREFERENCE order, else None."""
        for column in cls._VECTOR_PREFERENCE:
            vector = emb_data.get(column)
            if vector:
                return vector
        return None

    @staticmethod
    def _label_consensus(aid_labels: List[Dict[str, Any]]):
        """Majority vote over per-user labels.

        Returns:
            (final_label, inconsistent). final_label is True only with a strict
            majority of positive votes (a tie gives False), and None when there
            are no labels. inconsistent is True when the users disagree.
        """
        if not aid_labels:
            return None, False
        total = len(aid_labels)
        positive_votes = sum(1 for lbl in aid_labels if lbl["label"])
        return positive_votes > total / 2, 0 < positive_votes < total

    async def get_final_dataset(
        self, aid_list: List[int], model_name: str
    ) -> List[Dict[str, Any]]:
        """Assemble training rows of (embedding, consensus label) for the given AIDs.

        An AID is included only if it has metadata, a stored embedding for
        ``model_name`` keyed by the MD5 of its combined title/description/tags
        text, and at least one user label.
        """
        if not aid_list:
            return []

        metadata = await self.get_video_metadata(aid_list)
        labels = await self.get_user_labels(aid_list)

        # Build the embedding text for every AID that has metadata. The MD5 of
        # the text is the key used in internal.embeddings.
        text_data = []
        for aid in aid_list:
            meta = metadata.get(aid)
            if meta is None:
                continue
            combined_text = " ".join(
                filter(None, (meta["title"], meta["description"], meta["tags"]))
            )
            checksum = hashlib.md5(combined_text.encode("utf-8")).hexdigest()
            text_data.append({"aid": aid, "text": combined_text, "checksum": checksum})

        existing_embeddings = await self.get_existing_embeddings(
            [item["checksum"] for item in text_data], model_name
        )

        dataset = []
        for item in text_data:
            aid = item["aid"]
            checksum = item["checksum"]

            emb_data = existing_embeddings.get(checksum)
            embedding_vector = self._pick_embedding(emb_data) if emb_data else None

            aid_labels = labels.get(aid, [])
            final_label, inconsistent = self._label_consensus(aid_labels)

            if embedding_vector and final_label is not None:
                dataset.append(
                    {
                        "aid": aid,
                        "embedding": embedding_vector,
                        "label": final_label,
                        "metadata": metadata.get(aid, {}),
                        "user_labels": aid_labels,
                        "inconsistent": inconsistent,
                        "text_checksum": checksum,
                    }
                )

        return dataset
|
|
|
|
|
|
# Shared module-level manager. Its pool is None until `await db_manager.connect()`
# has run, so callers must connect before issuing queries.
db_manager: DatabaseManager = DatabaseManager()
|