Compare commits
8 Commits
99c7a34833
...
b4205049cb
Author | SHA1 | Date | |
---|---|---|---|
b4205049cb | |||
509c10ded0 | |||
41f8b42f1c | |||
fba56106cc | |||
81847cc090 | |||
0d18c921cb | |||
ba6b8bd5b3 | |||
2ed909268e |
@ -20,7 +20,7 @@ class VideoPlayDataset(Dataset):
|
|||||||
self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
|
self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
|
||||||
self.term = term
|
self.term = term
|
||||||
# Set time window based on term
|
# Set time window based on term
|
||||||
self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
|
self.time_window = 1000 * 24 * 3600 if term == 'long' else 3 * 24 * 3600
|
||||||
MINUTE = 60
|
MINUTE = 60
|
||||||
HOUR = 3600
|
HOUR = 3600
|
||||||
DAY = 24 * HOUR
|
DAY = 24 * HOUR
|
||||||
@ -37,6 +37,7 @@ class VideoPlayDataset(Dataset):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
self.feature_windows = [
|
self.feature_windows = [
|
||||||
|
#( 5 * MINUTE, 0 * MINUTE),
|
||||||
( 15 * MINUTE, 0 * MINUTE),
|
( 15 * MINUTE, 0 * MINUTE),
|
||||||
( 40 * MINUTE, 0 * MINUTE),
|
( 40 * MINUTE, 0 * MINUTE),
|
||||||
( 1 * HOUR, 0 * HOUR),
|
( 1 * HOUR, 0 * HOUR),
|
||||||
@ -45,7 +46,7 @@ class VideoPlayDataset(Dataset):
|
|||||||
( 3 * HOUR, 0 * HOUR),
|
( 3 * HOUR, 0 * HOUR),
|
||||||
#( 6 * HOUR, 3 * HOUR),
|
#( 6 * HOUR, 3 * HOUR),
|
||||||
( 6 * HOUR, 0 * HOUR),
|
( 6 * HOUR, 0 * HOUR),
|
||||||
(18 * HOUR, 12 * HOUR),
|
#(18 * HOUR, 12 * HOUR),
|
||||||
#( 1 * DAY, 6 * HOUR),
|
#( 1 * DAY, 6 * HOUR),
|
||||||
( 1 * DAY, 0 * DAY),
|
( 1 * DAY, 0 * DAY),
|
||||||
#( 2 * DAY, 1 * DAY),
|
#( 2 * DAY, 1 * DAY),
|
||||||
|
@ -4,20 +4,20 @@ from model import CompactPredictor
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
model = CompactPredictor(10).to('cpu', dtype=torch.float32)
|
model = CompactPredictor(15).to('cpu', dtype=torch.float32)
|
||||||
model.load_state_dict(torch.load('./pred/checkpoints/long_term.pt'))
|
model.load_state_dict(torch.load('./pred/checkpoints/model_20250320_0045.pt'))
|
||||||
model.eval()
|
model.eval()
|
||||||
# inference
|
# inference
|
||||||
initial = 997029
|
initial = 999704
|
||||||
last = initial
|
last = initial
|
||||||
start_time = '2025-03-17 00:13:17'
|
start_time = '2025-03-19 22:00:42'
|
||||||
for i in range(1, 120):
|
for i in range(1, 48):
|
||||||
hour = i / 0.5
|
hour = i / 6
|
||||||
sec = hour * 3600
|
sec = hour * 3600
|
||||||
time_d = np.log2(sec)
|
time_d = np.log2(sec)
|
||||||
data = [time_d, np.log2(initial+1), # time_delta, current_views
|
data = [time_d, np.log2(initial+1), # time_delta, current_views
|
||||||
6.111542, 8.404707, 10.071566, 11.55888, 12.457823,# grows_feat
|
4.857981, 6.29067, 6.869476, 6.58392, 6.523051, 8.242355, 8.841574, 10.203909, 11.449314, 12.659556, # grows_feat
|
||||||
0.009225, 0.001318, 28.001814# time_feat
|
0.916956, 0.416708, 28.003162 # time_feat
|
||||||
]
|
]
|
||||||
np_arr = np.array([data])
|
np_arr = np.array([data])
|
||||||
tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
|
tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
|
||||||
@ -25,7 +25,7 @@ def main():
|
|||||||
num = output.detach().numpy()[0][0]
|
num = output.detach().numpy()[0][0]
|
||||||
views_pred = int(np.exp2(num)) + initial
|
views_pred = int(np.exp2(num)) + initial
|
||||||
current_time = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S') + datetime.timedelta(hours=hour)
|
current_time = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S') + datetime.timedelta(hours=hour)
|
||||||
print(current_time.strftime('%m-%d %H:%M:%S'), views_pred, views_pred - last)
|
print(current_time.strftime('%m-%d %H:%M'), views_pred, views_pred - last)
|
||||||
last = views_pred
|
last = views_pred
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -38,7 +38,7 @@ def train(model, dataloader, device, epochs=100):
|
|||||||
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
|
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3,
|
||||||
total_steps=len(dataloader)*30)
|
total_steps=len(dataloader)*30)
|
||||||
# Huber loss
|
# Huber loss
|
||||||
criterion = asymmetricHuberLoss(delta=1.0, beta=2.1)
|
criterion = asymmetricHuberLoss(delta=1.0, beta=2.2)
|
||||||
|
|
||||||
model.train()
|
model.train()
|
||||||
global_step = 0
|
global_step = 0
|
||||||
@ -100,7 +100,7 @@ if __name__ == "__main__":
|
|||||||
device = 'mps'
|
device = 'mps'
|
||||||
|
|
||||||
# Initialize dataset and model
|
# Initialize dataset and model
|
||||||
dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', 'short')
|
dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', 'short', 712)
|
||||||
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
|
dataloader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)
|
||||||
|
|
||||||
# Get feature dimension
|
# Get feature dimension
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
import { type Client, Pool } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
import { type Client, Pool } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||||
import { postgresConfig } from "@core/db/pgConfig.ts";
|
import { postgresConfig, postgresConfigCred } from "@core/db/pgConfig.ts";
|
||||||
import { createMiddleware } from "hono/factory";
|
import { createMiddleware } from "hono/factory";
|
||||||
|
|
||||||
const pool = new Pool(postgresConfig, 4);
|
const pool = new Pool(postgresConfig, 4);
|
||||||
|
const poolCred = new Pool(postgresConfigCred, 2);
|
||||||
|
|
||||||
export const db = pool;
|
export const db = pool;
|
||||||
|
export const dbCred = poolCred;
|
||||||
|
|
||||||
export const dbMiddleware = createMiddleware(async (c, next) => {
|
export const dbMiddleware = createMiddleware(async (c, next) => {
|
||||||
const connection = await pool.connect();
|
const connection = await pool.connect();
|
||||||
@ -13,8 +15,16 @@ export const dbMiddleware = createMiddleware(async (c, next) => {
|
|||||||
connection.release();
|
connection.release();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export const dbCredMiddleware = createMiddleware(async (c, next) => {
|
||||||
|
const connection = await poolCred.connect();
|
||||||
|
c.set("dbCred", connection);
|
||||||
|
await next();
|
||||||
|
connection.release();
|
||||||
|
})
|
||||||
|
|
||||||
declare module "hono" {
|
declare module "hono" {
|
||||||
interface ContextVariableMap {
|
interface ContextVariableMap {
|
||||||
db: Client;
|
db: Client;
|
||||||
|
dbCred: Client;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
{
|
{
|
||||||
"name": "@cvsa/backend",
|
"name": "@cvsa/backend",
|
||||||
"imports": {
|
"imports": {
|
||||||
|
"@rabbit-company/argon2id": "jsr:@rabbit-company/argon2id@^2.1.0",
|
||||||
"hono": "jsr:@hono/hono@^4.7.5",
|
"hono": "jsr:@hono/hono@^4.7.5",
|
||||||
"zod": "npm:zod",
|
"zod": "npm:zod",
|
||||||
"yup": "npm:yup"
|
"yup": "npm:yup"
|
||||||
|
@ -1,15 +1,18 @@
|
|||||||
import { Hono } from "hono";
|
import { Hono } from "hono";
|
||||||
import { dbMiddleware } from "./database.ts";
|
import { dbCredMiddleware, dbMiddleware } from "./database.ts";
|
||||||
import { rootHandler } from "./root.ts";
|
import { rootHandler } from "./root.ts";
|
||||||
import { getSnapshotsHanlder } from "./snapshots.ts";
|
import { getSnapshotsHanlder } from "./snapshots.ts";
|
||||||
|
import { registerHandler } from "./register.ts";
|
||||||
|
|
||||||
export const app = new Hono();
|
export const app = new Hono();
|
||||||
|
|
||||||
app.use('/video/*', dbMiddleware);
|
app.use('/video/*', dbMiddleware);
|
||||||
|
app.use('/user', dbCredMiddleware);
|
||||||
|
|
||||||
app.get("/", ...rootHandler);
|
app.get("/", ...rootHandler);
|
||||||
|
|
||||||
app.get('/video/:id/snapshots', ...getSnapshotsHanlder);
|
app.get('/video/:id/snapshots', ...getSnapshotsHanlder);
|
||||||
|
app.post('/user', ...registerHandler);
|
||||||
|
|
||||||
const fetch = app.fetch;
|
const fetch = app.fetch;
|
||||||
|
|
||||||
@ -17,4 +20,4 @@ export default {
|
|||||||
fetch,
|
fetch,
|
||||||
} satisfies Deno.ServeDefaultExport;
|
} satisfies Deno.ServeDefaultExport;
|
||||||
|
|
||||||
export const VERSION = "0.2.4";
|
export const VERSION = "0.3.0";
|
65
packages/backend/register.ts
Normal file
65
packages/backend/register.ts
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import { createHandlers } from "./utils.ts";
|
||||||
|
import Argon2id from "@rabbit-company/argon2id";
|
||||||
|
import { object, string, ValidationError } from "yup";
|
||||||
|
import type { Context } from "hono";
|
||||||
|
import type { Bindings, BlankEnv, BlankInput } from "hono/types";
|
||||||
|
import type { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
|
||||||
|
|
||||||
|
const RegistrationBodySchema = object({
|
||||||
|
username: string().trim().required("Username is required").max(50, "Username cannot exceed 50 characters"),
|
||||||
|
password: string().required("Password is required"),
|
||||||
|
nickname: string().optional(),
|
||||||
|
});
|
||||||
|
|
||||||
|
type ContextType = Context<BlankEnv & { Bindings: Bindings }, "/user", BlankInput>;
|
||||||
|
|
||||||
|
export const userExists = async (username: string, client: Client) => {
|
||||||
|
const query = `
|
||||||
|
SELECT * FROM users WHERE username = $1
|
||||||
|
`;
|
||||||
|
const result = await client.queryObject(query, [username]);
|
||||||
|
return result.rows.length > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const registerHandler = createHandlers(async (c: ContextType) => {
|
||||||
|
const client = c.get("dbCred");
|
||||||
|
|
||||||
|
try {
|
||||||
|
const body = await RegistrationBodySchema.validate(await c.req.json());
|
||||||
|
const { username, password, nickname } = body;
|
||||||
|
|
||||||
|
if (await userExists(username, client)) {
|
||||||
|
return c.json({
|
||||||
|
message: `User "${username}" already exists.`,
|
||||||
|
}, 400);
|
||||||
|
}
|
||||||
|
|
||||||
|
const hash = await Argon2id.hashEncoded(password);
|
||||||
|
|
||||||
|
const query = `
|
||||||
|
INSERT INTO users (username, password, nickname) VALUES ($1, $2, $3)
|
||||||
|
`;
|
||||||
|
await client.queryObject(query, [username, hash, nickname || null]);
|
||||||
|
|
||||||
|
return c.json({
|
||||||
|
message: `User "${username}" registered successfully.`,
|
||||||
|
}, 201);
|
||||||
|
} catch (e) {
|
||||||
|
if (e instanceof ValidationError) {
|
||||||
|
return c.json({
|
||||||
|
message: "Invalid registration data.",
|
||||||
|
errors: e.errors,
|
||||||
|
}, 400);
|
||||||
|
} else if (e instanceof SyntaxError) {
|
||||||
|
return c.json({
|
||||||
|
message: "Invalid JSON in request body.",
|
||||||
|
}, 400);
|
||||||
|
} else {
|
||||||
|
console.error("Registration error:", e);
|
||||||
|
return c.json({
|
||||||
|
message: "An unexpected error occurred during registration.",
|
||||||
|
error: (e as Error).message,
|
||||||
|
}, 500);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
Loading…
Reference in New Issue
Block a user