cvsa/filter/old.py
2025-01-24 20:36:13 +08:00

148 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np
class VideoDataset(Dataset):
    """Video metadata loaded from a JSON Lines file and embedded on access.

    Each non-blank line of ``data_path`` is parsed as one JSON record, from
    which the keys ``title``, ``description``, ``tags`` and ``label`` are read.
    Embeddings are computed inside ``__getitem__`` on every access (nothing is
    cached), so every epoch re-encodes every record.

    Args:
        data_path: Path to a UTF-8 encoded JSON Lines file.
        sentence_transformer: Encoder exposing ``encode(text)`` (e.g. a
            ``SentenceTransformer``); its output is converted with
            ``torch.tensor``.
    """

    def __init__(self, data_path, sentence_transformer):
        self.sentence_transformer = sentence_transformer
        with open(data_path, "r", encoding="utf-8") as f:
            # Skip blank / whitespace-only lines (e.g. a stray empty line at
            # the end of the file): json.loads would raise JSONDecodeError.
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(combined_embedding, label)`` for record ``idx``.

        ``combined_embedding`` is the title, description and tags embeddings
        concatenated along dim 0, in that order; ``label`` is returned exactly
        as stored in the JSON record.
        """
        item = self.data[idx]
        title = item["title"]
        description = item["description"]
        # NOTE(review): tags is presumably a list of strings; if it were a
        # plain string, " ".join would space out its characters — confirm.
        tags = item["tags"]
        label = item["label"]

        title_embedding = self.get_embedding(title)
        description_embedding = self.get_embedding(description)
        tags_embedding = self.get_embedding(" ".join(tags))
        # Concatenating along dim 0 assumes encode() returns a 1-D vector for
        # a single string — TODO confirm if another encoder is plugged in.
        combined_embedding = torch.cat(
            [title_embedding, description_embedding, tags_embedding], dim=0
        )
        return combined_embedding, label

    def get_embedding(self, text):
        """Encode ``text`` with the sentence transformer and return a tensor."""
        embedding = self.sentence_transformer.encode(text)
        return torch.tensor(embedding)
class VideoClassifier(nn.Module):
    """Two-layer MLP over three concatenated text-field embeddings.

    The input width is ``3 * embedding_dim`` (one ``embedding_dim`` vector per
    field). ``forward`` returns log-probabilities of shape
    ``(batch, output_dim)``, normalized by ``LogSoftmax`` over dim 1.

    Note: ``nn.NLLLoss`` is the conventional loss for log-probability outputs.
    Passing them to ``nn.CrossEntropyLoss`` applies log_softmax once more,
    which leaves the values unchanged because log_softmax is idempotent.
    """

    def __init__(self, embedding_dim=768, hidden_dim=256, output_dim=3):
        super().__init__()
        num_fields = 3  # one embedding per text field, concatenated
        in_features = num_fields * embedding_dim
        # Submodules are created in this order (fc1, fc2, log_softmax) so that a
        # given RNG seed produces the same initial weights and state_dict keys.
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, embedding_features):
        """Map a 2-D ``(batch, 3 * embedding_dim)`` batch to class log-probs."""
        hidden = self.fc1(embedding_features).relu()
        logits = self.fc2(hidden)
        return self.log_softmax(logits)
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return ``(avg_loss, accuracy)``.

    Args:
        model: Module mapping a feature batch to per-class scores of shape
            ``(batch, num_classes)``; the predicted class is the argmax.
        dataloader: Iterable of ``(embedding_features, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters; stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``avg_loss`` is the mean of the per-batch losses (a smaller final batch
        weighs the same as a full one); ``accuracy`` is the fraction of samples
        classified correctly. Either is ``nan`` when there is nothing to
        average, instead of raising ZeroDivisionError on an empty loader.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    for embedding_features, labels in dataloader:
        embedding_features = embedding_features.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(embedding_features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        predicted = outputs.argmax(dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

    avg_loss = total_loss / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    return avg_loss, accuracy
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` and return ``(avg_loss, accuracy)``.

    The model is switched to eval mode and no gradients are tracked, so its
    parameters are not modified.

    Args:
        model: Module mapping a feature batch to per-class scores of shape
            ``(batch, num_classes)``; the predicted class is the argmax.
        dataloader: Iterable of ``(embedding_features, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``avg_loss`` is the mean of the per-batch losses (a smaller final batch
        weighs the same as a full one); ``accuracy`` is the fraction of samples
        classified correctly. Either is ``nan`` when there is nothing to
        average, instead of raising ZeroDivisionError on an empty loader.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for embedding_features, labels in dataloader:
            embedding_features = embedding_features.to(device)
            labels = labels.to(device)

            outputs = model(embedding_features)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            num_batches += 1
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / num_batches if num_batches else float("nan")
    accuracy = correct / total if total else float("nan")
    return avg_loss, accuracy
# Hyperparameters
hidden_dim = 256
output_dim = 3
batch_size = 32
num_epochs = 5  # the epoch count the script has always effectively trained for
learning_rate = 0.001
# Fraction of the labeled data held out for per-epoch validation; the fixed
# seed below keeps the split reproducible between runs.
val_fraction = 0.1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the models and the dataset
# NOTE(review): `tokenizer` is not used anywhere in this file; it is kept only
# so the module-level name remains available — consider removing it.
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
# Size the classifier from the encoder's real output width instead of a
# hard-coded 768: the model name ends in "-1024", which suggests 1024-dim
# embeddings, and a mismatch would make fc1 reject the concatenated features.
embedding_dim = sentence_transformer.get_sentence_embedding_dimension()
if embedding_dim is None:
    # The model did not report its dimension; measure one embedding instead.
    embedding_dim = len(sentence_transformer.encode("dimension probe"))
dataset = VideoDataset("labeled_data.jsonl", sentence_transformer=sentence_transformer)
val_size = int(len(dataset) * val_fraction)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [len(dataset) - val_size, val_size],
    generator=torch.Generator().manual_seed(42),
)
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

# Build the model from the hyperparameters above
model = VideoClassifier(
    embedding_dim=embedding_dim, hidden_dim=hidden_dim, output_dim=output_dim
).to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train, and validate on the held-out split when it is non-empty
for epoch in range(num_epochs):
    train_loss, train_acc = train(model, dataloader, criterion, optimizer, device)
    message = f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}"
    if len(val_dataset) > 0:
        val_loss, val_acc = validate(model, val_dataloader, criterion, device)
        message += f", Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}"
    print(message)

# Save the trained weights
torch.save(model.state_dict(), "video_classifier.pth")
model.eval()  # switch to evaluation mode for inference
# 2. Inference helper
def predict(model, sentence_transformer, title, description, tags, device):
    """Classify one video from its title, description and tags.

    Builds the same feature vector the training dataset uses: the title,
    description and space-joined tags embeddings concatenated, in that order.
    It then adds a leading batch dimension of 1, runs ``model`` without
    gradient tracking and returns the index of the highest-scoring class as a
    Python int.

    ``predict`` does not change the model's train/eval mode; the caller is
    expected to have set it.
    """

    def _embed(text):
        # Encode one string and move the resulting tensor to the target device.
        return torch.tensor(sentence_transformer.encode(text)).to(device)

    # List elements are evaluated left to right: title, description, tags.
    features = torch.cat(
        [_embed(title), _embed(description), _embed(" ".join(tags))], dim=0
    )
    batch = features.unsqueeze(0)  # shape (1, total_embedding_dim)

    with torch.no_grad():
        scores = model(batch)
        best = torch.max(scores, 1).indices
    return best.item()