"""
|
|
Database connection and operations for ML training service
|
|
"""
|
|
|
|
from collections import defaultdict
|
|
import os
|
|
from typing import List, Dict, Optional, Any
|
|
from datetime import datetime
|
|
import asyncpg
|
|
from ml_new.config.config_loader import config_loader
|
|
from ml_new.config.logger_config import get_logger
|
|
from dotenv import load_dotenv
|
|
|
|
|
|
load_dotenv()
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
# Database configuration
|
|
DATABASE_URL = os.getenv("DATABASE_URL")
|
|
|
|
|
|
class DatabaseManager:
    def __init__(self):
        self.pool: Optional[asyncpg.Pool] = None

    async def connect(self):
        """Initialize database connection pool"""
        try:
            self.pool = await asyncpg.create_pool(DATABASE_URL, min_size=5, max_size=20)
            logger.info("Database connection pool initialized")
        except Exception as e:
            logger.error(f"Failed to connect to database: {e}")
            raise

    async def close(self):
        """Close database connection pool"""
        if self.pool:
            await self.pool.close()
            logger.info("Database connection pool closed")

    @property
    def is_connected(self) -> bool:
        """Check if database connection pool is initialized"""
        return self.pool is not None

    async def get_embedding_models(self):
        """Get available embedding models from config"""
        return config_loader.get_embedding_models()

    async def get_video_metadata(
        self, aid_list: List[int]
    ) -> Dict[int, Dict[str, Any]]:
        """Get video metadata for given AIDs"""
        if not aid_list:
            return {}

        async with self.pool.acquire() as conn:
            query = """
                SELECT aid, title, description, tags
                FROM bilibili_metadata
                WHERE aid = ANY($1::bigint[])
            """
            rows = await conn.fetch(query, aid_list)

            result = {}
            for row in rows:
                result[int(row["aid"])] = {
                    "aid": int(row["aid"]),
                    "title": row["title"] or "",
                    "description": row["description"] or "",
                    "tags": row["tags"] or "",
                }

            return result

    async def get_user_labels(
        self, aid_list: List[int]
    ) -> Dict[int, List[Dict[str, Any]]]:
        """Get user labels for given AIDs, only the latest label per user"""
        if not aid_list:
            return {}

        async with self.pool.acquire() as conn:
            query = """
                WITH latest_labels AS (
                    SELECT DISTINCT ON (aid, "user")
                        aid, "user", label, created_at
                    FROM internal.video_type_label
                    WHERE aid = ANY($1::bigint[])
                    ORDER BY aid, "user", created_at DESC
                )
                SELECT aid, "user", label, created_at
                FROM latest_labels
                ORDER BY aid, "user"
            """
            rows = await conn.fetch(query, aid_list)

            result = {}
            for row in rows:
                aid = int(row["aid"])
                if aid not in result:
                    result[aid] = []

                result[aid].append(
                    {
                        "user": row["user"],
                        "label": bool(row["label"]),
                        "created_at": row["created_at"].isoformat(),
                    }
                )

            return result

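    # Illustrative example with hypothetical data: if user "alice" labeled aid
    # 170001 twice, get_user_labels returns only her most recent label:
    #   await db_manager.get_user_labels([170001])
    #   -> {170001: [{"user": "alice", "label": True, "created_at": "2024-05-01T12:00:00"}]}
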
    async def get_existing_embeddings(
        self, checksums: List[str], model_name: str
    ) -> Dict[str, Dict[str, Any]]:
        """Get existing embeddings for given checksums and model"""
        if not checksums:
            return {}

        async with self.pool.acquire() as conn:
            query = """
                SELECT data_checksum, dimensions, vec_2048, vec_1536, vec_1024, created_at
                FROM internal.embeddings
                WHERE model_name = $1 AND data_checksum = ANY($2::text[])
            """
            rows = await conn.fetch(query, model_name, checksums)

            result = {}
            for row in rows:
                checksum = row["data_checksum"]

                # Convert vector strings to lists if they exist
                vec_2048 = (
                    self._parse_vector_string(row["vec_2048"])
                    if row["vec_2048"]
                    else None
                )
                vec_1536 = (
                    self._parse_vector_string(row["vec_1536"])
                    if row["vec_1536"]
                    else None
                )
                vec_1024 = (
                    self._parse_vector_string(row["vec_1024"])
                    if row["vec_1024"]
                    else None
                )

                result[checksum] = {
                    "checksum": checksum,
                    "vec_2048": vec_2048,
                    "vec_1536": vec_1536,
                    "vec_1024": vec_1024,
                    "created_at": row["created_at"].isoformat(),
                }

            return result

    def _parse_vector_string(self, vector_str: str) -> List[float]:
        """Parse vector string format '[1.0,2.0,3.0]' back to list"""
        if not vector_str:
            return []

        try:
            # Remove brackets and split by comma
            vector_str = vector_str.strip()
            if vector_str.startswith("[") and vector_str.endswith("]"):
                vector_str = vector_str[1:-1]

            return [float(x.strip()) for x in vector_str.split(",") if x.strip()]
        except Exception as e:
            logger.warning(f"Failed to parse vector string '{vector_str}': {e}")
            return []

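    # Illustrative example: _parse_vector_string("[1.0,2.0,3.0]") returns
    # [1.0, 2.0, 3.0]; empty or malformed input falls back to [] with a warning.
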
    async def insert_embeddings(self, embeddings_data: List[Dict[str, Any]]) -> None:
        """Batch-insert embeddings, grouped by vector dimension, in one transaction"""
        if not embeddings_data:
            return

        # Group rows by dimension so each batch targets a single vec_* column
        batches = defaultdict(list)
        now = datetime.now()

        for data in embeddings_data:
            # Render the vector as its text form, e.g. "[1.0, 2.0, 3.0]"
            vector_str = str(data["vector"])
            dim = data["dimensions"]

            batches[dim].append(
                (data["model_name"], dim, data["checksum"], vector_str, now)
            )

        async with self.pool.acquire() as conn:
            async with conn.transaction():
                for dim, values in batches.items():
                    vec_column = f"vec_{dim}"

                    query = f"""
                        INSERT INTO internal.embeddings
                        (model_name, dimensions, data_checksum, {vec_column}, created_at)
                        VALUES ($1, $2, $3, $4, $5)
                        ON CONFLICT (model_name, dimensions, data_checksum) DO NOTHING
                    """

                    await conn.executemany(query, values)

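    # Illustrative payload for insert_embeddings (hypothetical values); "vector"
    # holds the raw floats and its length must match "dimensions":
    #   await db_manager.insert_embeddings([
    #       {
    #           "model_name": "text-embedding-v1",
    #           "dimensions": 1024,
    #           "checksum": "abc123...",
    #           "vector": [0.12, -0.34, ...],  # 1024 floats
    #       },
    #   ])
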
    # Sampling Methods
    async def get_all_aids(self) -> List[int]:
        """Get all available AIDs from labeled data (internal.video_type_label)"""
        async with self.pool.acquire() as conn:
            query = "SELECT DISTINCT aid FROM internal.video_type_label WHERE aid IS NOT NULL"
            rows = await conn.fetch(query)
            return [int(row["aid"]) for row in rows]

    async def get_all_aids_count(self) -> int:
        """Get the count of distinct AIDs in labeled data (internal.video_type_label)"""
        async with self.pool.acquire() as conn:
            query = "SELECT COUNT(DISTINCT aid) FROM internal.video_type_label WHERE aid IS NOT NULL"
            rows = await conn.fetch(query)
            return rows[0]["count"]

    async def get_aids_by_strategy(
        self, strategy: str, limit: Optional[int] = None
    ) -> List[int]:
        """Get AIDs based on sampling strategy"""
        if strategy == "all":
            return await self.get_all_aids()
        elif strategy == "random":
            return await self.get_random_aids(limit or 1000)
        else:
            raise ValueError(f"Unknown sampling strategy: {strategy}")

    async def get_random_aids(
        self, limit: int
    ) -> List[int]:
        """Get random AIDs from labeled data only"""
        async with self.pool.acquire() as conn:
            # Deduplicate before sampling so the result honors the limit and is
            # not biased toward videos that have multiple labels
            query = """
                SELECT aid FROM (
                    SELECT DISTINCT aid
                    FROM internal.video_type_label
                    WHERE aid IS NOT NULL
                ) AS labeled_aids
                ORDER BY RANDOM()
                LIMIT $1
            """
            rows = await conn.fetch(query, limit)
            return [int(row["aid"]) for row in rows]

    async def get_sampling_stats(self) -> Dict[str, Any]:
        """Get statistics about available labeled data for sampling"""
        async with self.pool.acquire() as conn:
            # Total labeled videos
            total_labeled_query = (
                "SELECT COUNT(DISTINCT aid) as count FROM internal.video_type_label"
            )
            total_labeled_result = await conn.fetchrow(total_labeled_query)
            total_labeled_videos = total_labeled_result["count"]

            # Positive and negative labels
            positive_query = "SELECT COUNT(DISTINCT aid) as count FROM internal.video_type_label WHERE label = true"
            negative_query = "SELECT COUNT(DISTINCT aid) as count FROM internal.video_type_label WHERE label = false"

            positive_result = await conn.fetchrow(positive_query)
            negative_result = await conn.fetchrow(negative_query)

            positive_labels = positive_result["count"]
            negative_labels = negative_result["count"]

            return {
                "total_labeled_videos": total_labeled_videos,
                "positive_labels": positive_labels,
                "negative_labels": negative_labels,
            }


# Global database manager instance
db_manager = DatabaseManager()
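

# Minimal usage sketch (assumes a reachable Postgres instance and a DATABASE_URL
# env var); illustrative only, not the service entry point.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        await db_manager.connect()
        try:
            stats = await db_manager.get_sampling_stats()
            print(f"Labeled data available for sampling: {stats}")

            # Sample a few labeled AIDs and fetch their metadata and labels
            aids = await db_manager.get_aids_by_strategy("random", limit=5)
            metadata = await db_manager.get_video_metadata(aids)
            labels = await db_manager.get_user_labels(aids)
            for aid in aids:
                print(aid, metadata.get(aid, {}).get("title"), labels.get(aid, []))
        finally:
            await db_manager.close()

    asyncio.run(_demo())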