merge: branch 'main' into gitbook

This commit is contained in:
alikia2x (寒寒) 2025-03-31 05:34:24 +08:00
commit 4d2b002264
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6
206 changed files with 5205 additions and 697621 deletions

14
.gitignore vendored
View File

@ -51,7 +51,6 @@ internal/
!tests/cases/projects/projectOption/**/node_modules
!tests/cases/projects/NodeModulesSearch/**/*
!tests/baselines/reference/project/nodeModules*/**/*
.idea
yarn.lock
yarn-error.log
.parallelperf.*
@ -76,14 +75,13 @@ node_modules/
# project specific
data/main.db
.env
logs/
__pycache__
filter/runs
data/filter/eval*
data/filter/train*
filter/checkpoints
data/filter/model_predicted*
ml/filter/runs
ml/pred/runs
ml/pred/checkpoints
ml/pred/observed
ml/data/
ml/filter/checkpoints
scripts
model/

9
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,9 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
dataSources.xml

21
.idea/cvsa.iml Normal file
View File

@ -0,0 +1,21 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.tmp" />
<excludeFolder url="file://$MODULE_DIR$/temp" />
<excludeFolder url="file://$MODULE_DIR$/tmp" />
<excludeFolder url="file://$MODULE_DIR$/ml/data" />
<excludeFolder url="file://$MODULE_DIR$/doc" />
<excludeFolder url="file://$MODULE_DIR$/ml/filter/checkpoints" />
<excludeFolder url="file://$MODULE_DIR$/ml/filter/runs" />
<excludeFolder url="file://$MODULE_DIR$/ml/lab/data" />
<excludeFolder url="file://$MODULE_DIR$/ml/lab/temp" />
<excludeFolder url="file://$MODULE_DIR$/logs" />
<excludeFolder url="file://$MODULE_DIR$/model" />
<excludeFolder url="file://$MODULE_DIR$/src/db" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

View File

@ -0,0 +1,12 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="GrazieInspection" enabled="false" level="GRAMMAR_ERROR" enabled_by_default="false" />
<inspection_tool class="LanguageDetectionInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
<option name="processCode" value="true" />
<option name="processLiterals" value="true" />
<option name="processComments" value="true" />
</inspection_tool>
</profile>
</component>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/cvsa.iml" filepath="$PROJECT_DIR$/.idea/cvsa.iml" />
</modules>
</component>
</project>

6
.idea/sqldialects.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="SqlDialectMappings">
<file url="PROJECT" dialect="PostgreSQL" />
</component>
</project>

6
.idea/vcs.xml Normal file
View File

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

View File

@ -3,3 +3,4 @@ data
*.svg
*.txt
*.md
*config*

35
.zed/settings.json Normal file
View File

@ -0,0 +1,35 @@
// Folder-specific settings
//
// For a full list of overridable settings, and general information on folder-specific settings,
// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
{
"lsp": {
"deno": {
"settings": {
"deno": {
"enable": true
}
}
}
},
"languages": {
"TypeScript": {
"language_servers": [
"deno",
"!typescript-language-server",
"!vtsls",
"!eslint"
],
"formatter": "language_server"
},
"TSX": {
"language_servers": [
"deno",
"!typescript-language-server",
"!vtsls",
"!eslint"
],
"formatter": "language_server"
}
}
}

View File

@ -6,9 +6,12 @@
纵观整个互联网对于「中文歌声合成」或「中文虚拟歌手」常简称为中V或VC相关信息进行较为系统、全面地整理收集的主要有以下几个网站
- [萌娘百科](https://zh.moegirl.org.cn/): 收录了大量中V歌曲及歌姬的信息呈现形式为传统维基基于[MediaWiki](https://www.mediawiki.org/))。
- [VCPedia](https://vcpedia.cn/): 由原萌娘百科中文歌声合成编辑团队的部分成员搭建,专属于中文歌声合成相关内容的信息集成站点[^1],呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
- [VocaDB](https://vocadb.net/): 一个围绕 Vocaloid、UTAU 和其他歌声合成器的协作数据库其中包含艺术家、唱片、PV 等[^2],其中包含大量中文歌声合成作品。
- [萌娘百科](https://zh.moegirl.org.cn/):
收录了大量中V歌曲及歌姬的信息呈现形式为传统维基基于[MediaWiki](https://www.mediawiki.org/))。
- [VCPedia](https://vcpedia.cn/):
由原萌娘百科中文歌声合成编辑团队的部分成员搭建,专属于中文歌声合成相关内容的信息集成站点[^1],呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
- [VocaDB](https://vocadb.net/): 一个围绕 Vocaloid、UTAU 和其他歌声合成器的协作数据库其中包含艺术家、唱片、PV
等[^2],其中包含大量中文歌声合成作品。
- [天钿Daily](https://tdd.bunnyxt.com/)一个VC相关数据交流与分享的网站。致力于VC相关数据交流定期抓取VC相关数据选取有意义的纬度展示。[^3]
上述网站中,或多或少存在一些不足,例如:
@ -36,11 +39,13 @@
### 数据库
中V档案馆使用[PostgreSQL](https://postgresql.org)作为数据库,我们承诺定期导出数据库转储 (dump) 文件并公开,其内容遵从以下协议或条款:
中V档案馆使用[PostgreSQL](https://postgresql.org)作为数据库,我们承诺定期导出数据库转储 (dump)
文件并公开,其内容遵从以下协议或条款:
- 数据库中的事实性数据根据适用法律不构成受版权保护的内容。中V档案馆放弃一切可能的权利[CC0 1.0 Universal](https://creativecommons.org/publicdomain/zero/1.0/))。
- 对于数据库中有原创性的内容(如贡献者编辑的描述性内容),如无例外,以[CC BY 4.0协议](https://creativecommons.org/licenses/by/4.0/)提供。
- 对于引用、摘编或改编自萌娘百科、VCPedia的内容以与原始协议(CC BY-NC-SA 3.0 CN)兼容的协议[CC BY-NC-SA 4.0协议](https://creativecommons.org/licenses/by-nc-sa/4.0/)提供,并注明原始协议 。
- 对于引用、摘编或改编自萌娘百科、VCPedia的内容以与原始协议(CC BY-NC-SA 3.0
CN)兼容的协议[CC BY-NC-SA 4.0协议](https://creativecommons.org/licenses/by-nc-sa/4.0/)提供,并注明原始协议 。
> 根据原始协议第四条第2项内容CC BY-NC-SA 4.0协议为与原始协议具有相同授权要素的后续版本(“可适用的协议”)。
- 中V档案馆文档使用[CC BY 4.0协议](https://creativecommons.org/licenses/by/4.0/)。
@ -48,7 +53,8 @@
用于构建中V档案馆的软件代码在[AGPL 3.0](https://www.gnu.org/licenses/agpl-3.0.html)许可证下公开,参见[LICENSE](./LICENSE)
[^1]: 引用自[VCPedia](https://vcpedia.cn/%E9%A6%96%E9%A1%B5),于[知识共享 署名-非商业性使用-相同方式共享 3.0中国大陆 (CC BY-NC-SA 3.0 CN) 许可协议](https://creativecommons.org/licenses/by-nc-sa/3.0/cn/)下提供。
[^2]: 翻译自[VocaDB](https://vocadb.net/),于[CC BY 4.0协议](https://creativecommons.org/licenses/by/4.0/)下提供。
[^3]: 引用自[关于 - 天钿Daily](https://tdd.bunnyxt.com/about)

View File

@ -1,12 +0,0 @@
import { JSX } from "preact";
import { IS_BROWSER } from "$fresh/runtime.ts";
export function Button(props: JSX.HTMLAttributes<HTMLButtonElement>) {
return (
<button
{...props}
disabled={!IS_BROWSER || props.disabled}
class="px-2 py-1 border-gray-500 border-2 rounded bg-white hover:bg-gray-200 transition-colors"
/>
);
}

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -1,3 +0,0 @@
# The data
感谢[天钿Daily](https://tdd.bunnyxt.com/)提供的数据。

View File

@ -1,55 +0,0 @@
import json
import random
def process_data(input_file, output_file):
"""
从输入文件中读取数据找出model和human不一致的行
删除"model""human"键重命名为"label"
然后将处理后的数据添加到输出文件中
在写入之前它会加载output_file中的所有样本
并使用aid键进行去重过滤
Args:
input_file (str): 输入文件的路径
output_file (str): 输出文件的路径
"""
# 加载output_file中已有的数据用于去重
existing_data = set()
try:
with open(output_file, 'r', encoding='utf-8') as f_out:
for line in f_out:
try:
data = json.loads(line)
existing_data.add(data['aid'])
except json.JSONDecodeError:
pass # 忽略JSON解码错误继续读取下一行
except FileNotFoundError:
pass # 如果文件不存在,则忽略
with open(input_file, 'r', encoding='utf-8') as f_in, open(output_file, 'a', encoding='utf-8') as f_out:
for line in f_in:
try:
data = json.loads(line)
if data['model'] != data['human'] or random.random() < 0.2:
if data['aid'] not in existing_data: # 检查aid是否已存在
del data['model']
data['label'] = data['human']
del data['human']
f_out.write(json.dumps(data, ensure_ascii=False) + '\n')
existing_data.add(data['aid']) # 将新的aid添加到集合中
except json.JSONDecodeError as e:
print(f"JSON解码错误: {e}")
print(f"错误行内容: {line.strip()}")
except KeyError as e:
print(f"KeyError: 键 '{e}' 不存在")
print(f"错误行内容: {line.strip()}")
# 调用函数处理数据
input_file = 'real_test.jsonl'
output_file = 'labeled_data.jsonl'
process_data(input_file, output_file)
print(f"处理完成,结果已写入 {output_file}")

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,57 +1,22 @@
{
"lock": false,
"tasks": {
"crawl-raw-bili": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/insertAidsToDB.ts",
"crawl-bili-aids": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/fetchAids.ts",
"check": "deno fmt --check && deno lint && deno check **/*.ts && deno check **/*.tsx",
"cli": "echo \"import '\\$fresh/src/dev/cli.ts'\" | deno run --unstable -A -",
"manifest": "deno task cli manifest $(pwd)",
"start": "deno run -A --watch=static/,routes/ dev.ts",
"build": "deno run -A dev.ts build",
"preview": "deno run -A main.ts",
"update": "deno run -A -r https://fresh.deno.dev/update .",
"worker": "deno run --env-file=.env --allow-env --allow-read --allow-ffi --allow-net --allow-write ./src/worker.ts",
"adder": "deno run --allow-env --allow-read --allow-ffi --allow-net ./src/jobAdder.ts",
"bullui": "deno run --allow-read --allow-env --allow-ffi --allow-net ./src/bullui.ts",
"all": "concurrently 'deno task start' 'deno task worker' 'deno task adder' 'deno task bullui'",
"test": "deno test ./test/ --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run"
},
"lint": {
"rules": {
"tags": ["fresh", "recommended"]
}
},
"exclude": ["**/_fresh/*"],
"imports": {
"@std/assert": "jsr:@std/assert@1",
"$fresh/": "https://deno.land/x/fresh@1.7.3/",
"preact": "https://esm.sh/preact@10.22.0",
"preact/": "https://esm.sh/preact@10.22.0/",
"@preact/signals": "https://esm.sh/*@preact/signals@1.2.2",
"@preact/signals-core": "https://esm.sh/*@preact/signals-core@1.5.1",
"tailwindcss": "npm:tailwindcss@3.4.1",
"tailwindcss/": "npm:/tailwindcss@3.4.1/",
"tailwindcss/plugin": "npm:/tailwindcss@3.4.1/plugin.js",
"$std/": "https://deno.land/std@0.216.0/",
"@huggingface/transformers": "npm:@huggingface/transformers@3.0.0",
"bullmq": "npm:bullmq",
"lib/": "./lib/",
"ioredis": "npm:ioredis",
"@bull-board/api": "npm:@bull-board/api",
"@bull-board/express": "npm:@bull-board/express",
"express": "npm:express",
"src/": "./src/"
},
"compilerOptions": {
"jsx": "react-jsx",
"jsxImportSource": "preact"
},
"workspace": ["./packages/crawler", "./packages/frontend", "./packages/backend", "./packages/core"],
"nodeModulesDir": "auto",
"tasks": {
"crawler": "deno task --filter 'crawler' all",
"backend": "deno task --filter 'backend' start"
},
"fmt": {
"useTabs": true,
"lineWidth": 120,
"indentWidth": 4,
"semiColons": true,
"proseWrap": "always"
},
"imports": {
"@astrojs/node": "npm:@astrojs/node@^9.1.3",
"@astrojs/svelte": "npm:@astrojs/svelte@^7.0.8",
"@core/db/": "./packages/core/db/",
"date-fns": "npm:date-fns@^4.1.0"
}
}

7
dev.ts
View File

@ -1,7 +0,0 @@
#!/usr/bin/env -S deno run -A --watch=static/,routes/
import dev from "$fresh/dev.ts";
import config from "./fresh.config.ts";
import "$std/dotenv/load.ts";
await dev(import.meta.url, "./main.ts", config);

View File

@ -17,7 +17,8 @@ layout:
Welcome to the CVSA Documentation!
This doc contains various information about the CVSA project, including technical architecture, tutorials for visitors, etc.
This doc contains various information about the CVSA project, including technical architecture, tutorials for visitors,
etc.
### Jump right in

View File

@ -1,11 +1,11 @@
# Table of contents
* [Welcome](README.md)
- [Welcome](README.md)
## About
* [About CVSA Project](about/this-project.md)
* [Scope of Inclusion](about/scope-of-inclusion.md)
- [About CVSA Project](about/this-project.md)
- [Scope of Inclusion](about/scope-of-inclusion.md)
## Architecture
@ -17,5 +17,5 @@
## API Doc
* [Catalog](api-doc/catalog.md)
* [Songs](api-doc/songs.md)
- [Catalog](api-doc/catalog.md)
- [Songs](api-doc/songs.md)

View File

@ -1,6 +1,7 @@
# Scope of Inclusion
CVSA contains many aspects of Chinese Vocal Synthesis, including songs, albums, artists (publisher, manipulators, arranger, etc), singers and voice engines / voicebanks.&#x20;
CVSA contains many aspects of Chinese Vocal Synthesis, including songs, albums, artists (publishers, manipulators,
arrangers, etc.), singers, and voice engines / voicebanks.&#x20;
For a **song**, it must meet the following conditions to be included in CVSA:
@ -26,6 +27,11 @@ We define a **Chinese virtual singer** as follows:
### Using Vocal Synthesizer
To be included in CVSA, at least one line of the song must be produced by a Vocal Synthesizer (including harmony vocals).
To be included in CVSA, at least one line of the song must be produced by a Vocal Synthesizer (including harmony
vocals).
We define a vocal synthesizer as a software or system that generates synthesized singing voices by algorithmically modeling vocal characteristics and producing audio from input parameters such as lyrics, pitch, and dynamics, encompassing both waveform-concatenation-based (e.g., VOCALOID, UTAU) and AI-based (e.g., Synthesizer V, ACE Studio) approaches, **but excluding voice conversion tools that solely alter the timbre of pre-existing recordings** (e.g., [so-vits svc](https://github.com/svc-develop-team/so-vits-svc)).
We define a vocal synthesizer as a software or system that generates synthesized singing voices by algorithmically
modeling vocal characteristics and producing audio from input parameters such as lyrics, pitch, and dynamics,
encompassing both waveform-concatenation-based (e.g., VOCALOID, UTAU) and AI-based (e.g., Synthesizer V, ACE Studio)
approaches, **but excluding voice conversion tools that solely alter the timbre of pre-existing recordings** (e.g.,
[so-vits svc](https://github.com/svc-develop-team/so-vits-svc)).

View File

@ -1,11 +1,13 @@
# About CVSA Project
CVSA (Chinese Vocal Synthesis Archive) aims to collect as much content as possible about the Chinese Vocal Synthesis community in a highly automation-assisted way.&#x20;
CVSA (Chinese Vocal Synthesis Archive) aims to collect as much content as possible about the Chinese Vocal Synthesis
community in a highly automation-assisted way.&#x20;
Unlike existing projects such as [VocaDB](https://vocadb.net), CVSA collects and displays the following content in an automated and manually edited way:
* Metadata of songs (name, duration, publisher, singer, etc.)
* Descriptive information of songs (content introduction, creation background, lyrics, etc.)
* Engagement data snapshots of songs, i.e. historical snapshots of their engagement data (including views, favorites, likes, etc.) on the [Bilibili](https://en.wikipedia.org/wiki/Bilibili) website.
* Information about artists, albums, vocal synthesizers, and voicebanks.
Unlike existing projects such as [VocaDB](https://vocadb.net), CVSA collects and displays the following content in an
automated and manually edited way:
- Metadata of songs (name, duration, publisher, singer, etc.)
- Descriptive information of songs (content introduction, creation background, lyrics, etc.)
- Engagement data snapshots of songs, i.e. historical snapshots of their engagement data (including views, favorites,
likes, etc.) on the [Bilibili](https://en.wikipedia.org/wiki/Bilibili) website.
- Information about artists, albums, vocal synthesizers, and voicebanks.
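
As a rough sketch of what one engagement data snapshot record could look like (the field names below are purely illustrative and are not the actual database schema), consider:

```typescript
// Illustrative only: one engagement snapshot of a single Bilibili video.
interface EngagementSnapshot {
	aid: number; // Bilibili video ID this snapshot belongs to
	capturedAt: string; // when the snapshot was taken (ISO 8601)
	views: number;
	favorites: number;
	likes: number;
	// further metrics ("etc." in the list above) would be added here as they are collected
}
```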

View File

@ -1,4 +1,3 @@
# Catalog
* [**Songs**](songs.md)
- [**Songs**](songs.md)

View File

@ -6,7 +6,8 @@ The AI systems we currently use are:
### The Filter
Located at `/filter/` under project root dir, it classifies a video in the [category 30](../about/scope-of-inclusion.md#category-30) into the following categories:
Located at `/filter/` under the project root directory, it classifies a video in
[category 30](../about/scope-of-inclusion.md#category-30) into one of the following categories:
* 0: Not related to Chinese vocal synthesis
* 1: An original song with Chinese vocal synthesis

View File

@ -2,7 +2,8 @@
CVSA uses [PostgreSQL](https://www.postgresql.org/) as our database.
All public data of CVSA (excluding users' personal data) is stored in a database named `cvsa_main`, which contains the following tables:
All public data of CVSA (excluding users' personal data) is stored in a database named `cvsa_main`, which contains the
following tables:
* songs: stores the main information of songs
* bili\_user: stores snapshots of Bilibili user information

View File

@ -1,6 +1,7 @@
# Type of Song
The **Unrelated type** refers specifically to videos that are not in our [Scope of Inclusion](../../about/scope-of-inclusion.md).
The **Unrelated type** refers specifically to videos that are not in our
[Scope of Inclusion](../../about/scope-of-inclusion.md).
### Table: `songs`

View File

@ -1,22 +1,22 @@
# Table of contents
* [欢迎](README.md)
- [欢迎](README.md)
## 关于 <a href="#about" id="about"></a>
* [关于本项目](about/this-project.md)
* [收录范围](about/scope-of-inclusion.md)
- [关于本项目](about/this-project.md)
- [收录范围](about/scope-of-inclusion.md)
## 技术架构 <a href="#architecture" id="architecture"></a>
* [概览](architecture/overview.md)
* [数据库结构](architecture/database-structure/README.md)
* [歌曲类型](architecture/database-structure/type-of-song.md)
* [人工智能](architecture/artificial-intelligence.md)
* [消息队列](architecture/message-queue/README.md)
* [VideoTagsQueue队列](architecture/message-queue/video-tags-queue.md)
- [概览](architecture/overview.md)
- [数据库结构](architecture/database-structure/README.md)
- [歌曲类型](architecture/database-structure/type-of-song.md)
- [人工智能](architecture/artificial-intelligence.md)
- [消息队列](architecture/message-queue/README.md)
- [VideoTagsQueue队列](architecture/message-queue/video-tags-queue.md)
## API 文档 <a href="#api-doc" id="api-doc"></a>
* [目录](api-doc/catalog.md)
* [歌曲](api-doc/songs.md)
- [目录](api-doc/catalog.md)
- [歌曲](api-doc/songs.md)

View File

@ -6,7 +6,8 @@
#### VOCALOID·UATU 分区
原则上中V档案馆中收录的歌曲必须包含在哔哩哔哩 VOCALOID·UTAU 分区分区ID为30下的视频中。在某些特殊情况下此规则可能不是强制的。
原则上中V档案馆中收录的歌曲必须包含在哔哩哔哩 VOCALOID·UTAU
分区分区ID为30下的视频中。在某些特殊情况下此规则可能不是强制的。
#### 至少一行中文
@ -16,4 +17,6 @@
歌曲的至少一行必须由歌声合成器生成包括和声部分才能被收录到中V档案馆中。
我们将歌声合成器定义为通过算法建模声音特征并根据输入的歌词、音高等参数生成音频的软件或系统,包括基于波形拼接的(如 VOCALOID、UTAU和基于 AI 的(如 Synthesizer V、ACE Studio方法**但不包括仅改变现有歌声音色的AI声音转换器**(例如 [so-vits svc](https://github.com/svc-develop-team/so-vits-svc))。
我们将歌声合成器定义为通过算法建模声音特征并根据输入的歌词、音高等参数生成音频的软件或系统,包括基于波形拼接的(如
VOCALOID、UTAU和基于 AI 的(如 Synthesizer V、ACE Studio方法**但不包括仅改变现有歌声音色的AI声音转换器**(例如
[so-vits svc](https://github.com/svc-develop-team/so-vits-svc))。

View File

@ -6,34 +6,33 @@
纵观整个互联网对于「中文歌声合成」或「中文虚拟歌手」常简称为中V或VC相关信息进行较为系统、全面地整理收集的主要有以下几个网站
* [萌娘百科](https://zh.moegirl.org.cn/): 收录了大量中V歌曲及歌姬的信息呈现形式为传统维基基于[MediaWiki](https://www.mediawiki.org/))。
* [VCPedia](https://vcpedia.cn/): 由原萌娘百科中文歌声合成编辑团队的部分成员搭建,专属于中文歌声合成相关内容的信息集成站点[^1],呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
* [VocaDB](https://vocadb.net/): [一个围绕 Vocaloid、UTAU 和其他歌声合成器的协作数据库其中包含艺术家、唱片、PV 等](#user-content-fn-2)[^2],其中包含大量中文歌声合成作品。
* [天钿Daily](https://tdd.bunnyxt.com/)一个VC相关数据交流与分享的网站。致力于VC相关数据交流定期抓取VC相关数据选取有意义的纬度展示。
- [萌娘百科](https://zh.moegirl.org.cn/):
收录了大量中V歌曲及歌姬的信息呈现形式为传统维基基于[MediaWiki](https://www.mediawiki.org/))。
- [VCPedia](https://vcpedia.cn/):
由原萌娘百科中文歌声合成编辑团队的部分成员搭建,专属于中文歌声合成相关内容的信息集成站点[^1],呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
- [VocaDB](https://vocadb.net/):
[一个围绕 Vocaloid、UTAU 和其他歌声合成器的协作数据库其中包含艺术家、唱片、PV 等](#user-content-fn-2)[^2],其中包含大量中文歌声合成作品。
- [天钿Daily](https://tdd.bunnyxt.com/)一个VC相关数据交流与分享的网站。致力于VC相关数据交流定期抓取VC相关数据选取有意义的纬度展示。
上述网站中,或多或少存在一些不足,例如:
* 萌娘百科、VCPedia受限于传统维基绝大多数内容依赖人工编辑。
* VocaDB基于结构化数据库构建由此可以依赖程序生成一些信息但**条目收录**仍然完全依赖人工完成。
* VocaDB主要专注于元数据展示少有关于歌曲、作者等的描述性的文字也缺乏描述性的背景信息。
* 天钿Daily只展示歌曲的统计数据及历史趋势没有关于歌曲其它信息的收集。
- 萌娘百科、VCPedia受限于传统维基绝大多数内容依赖人工编辑。
- VocaDB基于结构化数据库构建由此可以依赖程序生成一些信息但**条目收录**仍然完全依赖人工完成。
- VocaDB主要专注于元数据展示少有关于歌曲、作者等的描述性的文字也缺乏描述性的背景信息。
- 天钿Daily只展示歌曲的统计数据及历史趋势没有关于歌曲其它信息的收集。
因此,**中V档案馆**吸取前人经验,克服上述网站的不足,希望做到:
* 歌曲收录(指发现歌曲并创建条目)的完全自动化
* 歌曲元信息提取的高度自动化
* 歌曲统计数据收集的完全自动化
* 在程序辅助的同时欢迎并鼓励贡献者参与编辑(主要为描述性内容)或纠错
* 在适当的许可声明下,引用来自上述源的数据,使内容更加全面、丰富。
- 歌曲收录(指发现歌曲并创建条目)的完全自动化
- 歌曲元信息提取的高度自动化
- 歌曲统计数据收集的完全自动化
- 在程序辅助的同时欢迎并鼓励贡献者参与编辑(主要为描述性内容)或纠错
- 在适当的许可声明下,引用来自上述源的数据,使内容更加全面、丰富。
***
---
本文在[CC BY-NC-SA 4.0协议](https://creativecommons.org/licenses/by-nc-sa/4.0/)提供。
[^1]: 引用自[VCPedia](https://vcpedia.cn/%E9%A6%96%E9%A1%B5),于[知识共享 署名-非商业性使用-相同方式共享 3.0中国大陆 (CC BY-NC-SA 3.0 CN) 许可协议](https://creativecommons.org/licenses/by-nc-sa/3.0/cn/)下提供。
[^2]: 翻译自[VocaDB](https://vocadb.net/),于[CC BY 4.0协议](https://creativecommons.org/licenses/by/4.0/)下提供。

View File

@ -1,3 +1,3 @@
# 目录
* [歌曲](songs.md)
- [歌曲](songs.md)

View File

@ -8,6 +8,6 @@ CVSA 的自动化工作流高度依赖人工智能进行信息提取和分类。
位于项目根目录下的 `/filter/`,它将 [30 分区](../about/scope-of-inclusion.md#vocaloiduatu-fen-qu) 中的视频分为以下类别:
* 0与中文人声合成无关
* 1中文人声合成原创曲
* 2中文人声合成的翻唱/混音歌曲
- 0与中文人声合成无关
- 1中文人声合成原创曲
- 2中文人声合成的翻唱/混音歌曲

View File

@ -4,7 +4,7 @@ CVSA 使用 [PostgreSQL](https://www.postgresql.org/) 作为数据库。
CVSA 的所有公开数据(不包括用户的个人数据)都存储在名为 `cvsa_main` 的数据库中,该数据库包含以下表:
* songs存储歌曲的主要信息
* bili\_user存储 Bilibili 用户信息快照
* all\_data[分区 30](../../about/scope-of-inclusion.md#vocaloiduatu-fen-qu) 中所有视频的元数据。
* labelling\_result包含由我们的 AI 系统 标记的 `all_data` 中视频的标签。
- songs存储歌曲的主要信息
- bili\_user存储 Bilibili 用户信息快照
- all\_data[分区 30](../../about/scope-of-inclusion.md#vocaloiduatu-fen-qu) 中所有视频的元数据。
- labelling\_result包含由我们的 AI 系统 标记的 `all_data` 中视频的标签。
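
For reference, an `all_data` row corresponds to the `AllDataType` interface that this commit removes from `lib/db/schema.d.ts`; a possible companion type for `labelling_result` rows is sketched below (its field names are assumptions for illustration, not taken from the actual schema):

```typescript
// Shape of an `all_data` row, as previously declared in lib/db/schema.d.ts (removed in this commit).
interface AllDataType {
	aid: number;
	bvid: string | null;
	description: string | null;
	uid: number | null;
	tags: string | null;
	title: string | null;
	published_at: string | null;
}

// Hypothetical shape of a `labelling_result` row; field names are assumed.
interface LabellingResult {
	aid: number; // references all_data.aid
	label: 0 | 1 | 2; // see the type/label tables in "Type of Song"
}
```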

View File

@ -7,7 +7,7 @@
`songs` 表格中使用的 `type` 列。
| 类型 | 说明 |
| -- | ---------- |
| ---- | ------------ |
| 0 | 不相关 |
| 1 | 原创 |
| 2 | 翻唱 (Cover) |
@ -18,7 +18,7 @@
#### 表格:`labelling_result`
| 标签 | 说明 |
| -- | ----------- |
| ---- | ------------------ |
| 0 | AI 标记:不相关 |
| 1 | AI 标记:原创 |
| 2 | AI 标记:翻唱/混音 |
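
A minimal sketch of how these numeric codes could be mapped to readable labels in application code (illustrative only, not part of the existing codebase):

```typescript
// Illustrative mapping of `songs.type` and `labelling_result.label` codes to descriptions.
const SONG_TYPE: Record<number, string> = {
	0: "Unrelated",
	1: "Original",
	2: "Cover",
};

const AI_LABEL: Record<number, string> = {
	0: "AI: unrelated",
	1: "AI: original",
	2: "AI: cover/remix",
};
```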

View File

@ -1,2 +1 @@
# 消息队列

View File

@ -1,31 +0,0 @@
import torch
from model2vec import StaticModel
def prepare_batch(batch_data, device="cpu"):
"""
将输入的 batch_data 转换为模型所需的输入格式 [batch_size, num_channels, embedding_dim]
参数:
batch_data (dict): 输入的 batch 数据格式为 {
"title": [text1, text2, ...],
"description": [text1, text2, ...],
"tags": [text1, text2, ...],
"author_info": [text1, text2, ...]
}
device (str): 模型运行的设备 "cpu" "cuda"
返回:
torch.Tensor: 形状为 [batch_size, num_channels, embedding_dim] 的张量
"""
# 1. 对每个通道的文本分别编码
channel_embeddings = []
model = StaticModel.from_pretrained("./model/embedding/")
for channel in ["title", "description", "tags", "author_info"]:
texts = batch_data[channel] # 获取当前通道的文本列表
embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device) # 编码为 [batch_size, embedding_dim]
channel_embeddings.append(embeddings)
# 2. 将编码结果堆叠为 [batch_size, num_channels, embedding_dim]
batch_tensor = torch.stack(channel_embeddings, dim=1) # 在 dim=1 上堆叠
return batch_tensor

View File

@ -1,58 +0,0 @@
import torch
import torch.nn as nn
class VideoClassifierV3(nn.Module):
def __init__(self, embedding_dim=1024, hidden_dim=256, output_dim=3):
super().__init__()
self.num_channels = 4
self.channel_names = ['title', 'description', 'tags', 'author_info']
# 改进1带温度系数的通道权重比原始固定权重更灵活
self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
self.temperature = 2.0 # 可调节的平滑系数
# 改进2更稳健的全连接结构
self.fc = nn.Sequential(
nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
nn.BatchNorm1d(hidden_dim*2),
nn.Dropout(0.2),
nn.ReLU(),
nn.Linear(hidden_dim*2, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, output_dim)
)
# 改进3输出层初始化
nn.init.xavier_uniform_(self.fc[-1].weight)
nn.init.zeros_(self.fc[-1].bias)
def forward(self, input_texts, sentence_transformer):
# 合并所有通道文本进行批量编码
all_texts = [text for channel in self.channel_names for text in input_texts[channel]]
# 使用SentenceTransformer生成嵌入保持冻结
with torch.no_grad():
task = "classification"
embeddings = torch.tensor(
sentence_transformer.encode(all_texts, task=task),
device=next(self.parameters()).device
)
# 分割嵌入并加权
split_sizes = [len(input_texts[name]) for name in self.channel_names]
channel_features = torch.split(embeddings, split_sizes, dim=0)
channel_features = torch.stack(channel_features, dim=1) # [batch, 4, 1024]
# 改进4带温度系数的softmax加权
weights = torch.softmax(self.channel_weights / self.temperature, dim=0)
weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)
# 拼接特征
combined = weighted_features.view(weighted_features.size(0), -1)
# 全连接层
return self.fc(combined)
def get_channel_weights(self):
"""获取各通道权重(带温度调节)"""
return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()

View File

@ -1,111 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class VideoClassifierV3_4(nn.Module):
def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
super().__init__()
self.num_channels = 4
self.channel_names = ['title', 'description', 'tags', 'author_info']
# 可学习温度系数
self.temperature = nn.Parameter(torch.tensor(1.7))
# 带约束的通道权重使用Sigmoid替代Softmax
self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
# 增强的非线性层
self.fc = nn.Sequential(
nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
nn.BatchNorm1d(hidden_dim*2),
nn.Dropout(0.3),
nn.GELU(),
nn.Linear(hidden_dim*2, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.Dropout(0.2),
nn.GELU(),
nn.Linear(hidden_dim, output_dim)
)
# 权重初始化
self._init_weights()
def _init_weights(self):
for layer in self.fc:
if isinstance(layer, nn.Linear):
# 使用ReLU的初始化参数GELU的近似
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu') # 修改这里
# 或者使用Xavier初始化更适合通用场景
# nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu'))
nn.init.zeros_(layer.bias)
def forward(self, input_texts, sentence_transformer):
# 合并文本进行批量编码
all_texts = [text for channel in self.channel_names for text in input_texts[channel]]
# 冻结的文本编码
with torch.no_grad():
embeddings = torch.tensor(
sentence_transformer.encode(all_texts),
device=next(self.parameters()).device
)
# 分割并加权通道特征
split_sizes = [len(input_texts[name]) for name in self.channel_names]
channel_features = torch.split(embeddings, split_sizes, dim=0)
channel_features = torch.stack(channel_features, dim=1)
# 自适应通道权重Sigmoid约束
weights = torch.sigmoid(self.channel_weights) # [0,1]范围
weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)
# 特征拼接
combined = weighted_features.view(weighted_features.size(0), -1)
return self.fc(combined)
def get_channel_weights(self):
"""获取各通道权重(带温度调节)"""
return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
class AdaptiveRecallLoss(nn.Module):
def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
"""
Args:
class_weights (torch.Tensor): 类别权重
alpha (float): 召回率调节因子0-1
gamma (float): Focal Loss参数
fp_penalty (float): 类别0假阳性惩罚强度
"""
super().__init__()
self.class_weights = class_weights
self.alpha = alpha
self.gamma = gamma
self.fp_penalty = fp_penalty
def forward(self, logits, targets):
# 基础交叉熵损失
ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')
# Focal Loss组件
pt = torch.exp(-ce_loss)
focal_loss = ((1 - pt) ** self.gamma) * ce_loss
# 召回率增强(对困难样本加权)
class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)
# 类别0假阳性惩罚
probs = F.softmax(logits, dim=1)
fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()
# 总损失
total_loss = recall_loss.mean() + fp_loss / len(targets)
return total_loss

View File

@ -1,148 +0,0 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np
class VideoDataset(Dataset):
def __init__(self, data_path, sentence_transformer):
self.data = []
self.sentence_transformer = sentence_transformer
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
self.data.append(json.loads(line))
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
title = item["title"]
description = item["description"]
tags = item["tags"]
label = item["label"]
# 获取每个特征的嵌入
title_embedding = self.get_embedding(title)
description_embedding = self.get_embedding(description)
tags_embedding = self.get_embedding(" ".join(tags))
# 将嵌入连接起来
combined_embedding = torch.cat([title_embedding, description_embedding, tags_embedding], dim=0)
return combined_embedding, label
def get_embedding(self, text):
# 使用SentenceTransformer生成嵌入
embedding = self.sentence_transformer.encode(text)
return torch.tensor(embedding)
class VideoClassifier(nn.Module):
def __init__(self, embedding_dim=768, hidden_dim=256, output_dim=3):
super(VideoClassifier, self).__init__()
# 每个特征的嵌入维度是embedding_dim总共有3个特征
total_embedding_dim = embedding_dim * 3
# 全连接层
self.fc1 = nn.Linear(total_embedding_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
self.log_softmax = nn.LogSoftmax(dim=1)
def forward(self, embedding_features):
# 全连接层
x = torch.relu(self.fc1(embedding_features))
output = self.fc2(x)
output = self.log_softmax(output)
return output
def train(model, dataloader, criterion, optimizer, device):
model.train()
total_loss = 0
correct = 0
total = 0
for embedding_features, labels in dataloader:
embedding_features = embedding_features.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(embedding_features)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
_, predicted = torch.max(outputs, 1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
avg_loss = total_loss / len(dataloader)
accuracy = correct / total
return avg_loss, accuracy
def validate(model, dataloader, criterion, device):
model.eval()
total_loss = 0
correct = 0
total = 0
with torch.no_grad():
for embedding_features, labels in dataloader:
embedding_features = embedding_features.to(device)
labels = labels.to(device)
outputs = model(embedding_features)
loss = criterion(outputs, labels)
total_loss += loss.item()
_, predicted = torch.max(outputs, 1)
correct += (predicted == labels).sum().item()
total += labels.size(0)
avg_loss = total_loss / len(dataloader)
accuracy = correct / total
return avg_loss, accuracy
# 超参数
hidden_dim = 256
output_dim = 3
batch_size = 32
num_epochs = 10
learning_rate = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载数据集
tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
dataset = VideoDataset("labeled_data.jsonl", sentence_transformer=sentence_transformer)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 初始化模型
model = VideoClassifier(embedding_dim=768, hidden_dim=256, output_dim=3).to(device)
# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
num_epochs = 5
# 训练和验证
for epoch in range(num_epochs):
train_loss, train_acc = train(model, dataloader, criterion, optimizer, device)
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
# 保存模型
torch.save(model.state_dict(), "video_classifier.pth")
model.eval() # 设置为评估模式
# 2. 定义推理函数
def predict(model, sentence_transformer, title, description, tags, device):
# 将输入数据转换为嵌入
title_embedding = torch.tensor(sentence_transformer.encode(title)).to(device)
description_embedding = torch.tensor(sentence_transformer.encode(description)).to(device)
tags_embedding = torch.tensor(sentence_transformer.encode(" ".join(tags))).to(device)
# 将嵌入连接起来
combined_embedding = torch.cat([title_embedding, description_embedding, tags_embedding], dim=0).unsqueeze(0)
# 推理
with torch.no_grad():
output = model(combined_embedding)
_, predicted = torch.max(output, 1)
return predicted.item()

View File

@ -1,6 +0,0 @@
import { defineConfig } from "$fresh/server.ts";
import tailwind from "$fresh/plugins/tailwind.ts";
export default defineConfig({
plugins: [tailwind()],
});

View File

@ -1,27 +0,0 @@
// DO NOT EDIT. This file is generated by Fresh.
// This file SHOULD be checked into source version control.
// This file is automatically updated during development when running `dev.ts`.
import * as $_404 from "./routes/_404.tsx";
import * as $_app from "./routes/_app.tsx";
import * as $api_joke from "./routes/api/joke.ts";
import * as $greet_name_ from "./routes/greet/[name].tsx";
import * as $index from "./routes/index.tsx";
import * as $Counter from "./islands/Counter.tsx";
import type { Manifest } from "$fresh/server.ts";
const manifest = {
routes: {
"./routes/_404.tsx": $_404,
"./routes/_app.tsx": $_app,
"./routes/api/joke.ts": $api_joke,
"./routes/greet/[name].tsx": $greet_name_,
"./routes/index.tsx": $index,
},
islands: {
"./islands/Counter.tsx": $Counter,
},
baseUrl: import.meta.url,
} satisfies Manifest;
export default manifest;

View File

@ -1,16 +0,0 @@
import type { Signal } from "@preact/signals";
import { Button } from "../components/Button.tsx";
interface CounterProps {
count: Signal<number>;
}
export default function Counter(props: CounterProps) {
return (
<div class="flex gap-8 py-6">
<Button onClick={() => props.count.value -= 1}>-1</Button>
<p class="text-3xl tabular-nums">{props.count}</p>
<Button onClick={() => props.count.value += 1}>+1</Button>
</div>
);
}

View File

@ -1,61 +0,0 @@
import { Client, Transaction } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import { AllDataType } from "lib/db/schema.d.ts";
import logger from "lib/log/logger.ts";
import { parseTimestampFromPsql } from "lib/utils/formatTimestampToPostgre.ts";
export async function videoExistsInAllData(client: Client, aid: number) {
return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
.then((result) => result.rows[0].exists);
}
export async function insertIntoAllData(client: Client, data: AllDataType) {
logger.log(`inserted ${data.aid}`, "db-all_data");
return await client.queryObject(
`INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (aid) DO NOTHING`,
[data.aid, data.bvid, data.description, data.uid, data.tags, data.title, data.published_at],
);
}
export async function getLatestVideoTimestampFromAllData(client: Client) {
return await client.queryObject<{ published_at: string }>(
`SELECT published_at FROM all_data ORDER BY published_at DESC LIMIT 1`,
)
.then((result) => {
const date = new Date(result.rows[0].published_at);
if (isNaN(date.getTime())) {
return null;
}
return date.getTime();
});
}
export async function videoTagsIsNull(client: Client | Transaction, aid: number) {
return await client.queryObject<{ exists: boolean }>(
`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1 AND tags IS NULL)`,
[aid],
).then((result) => result.rows[0].exists);
}
export async function updateVideoTags(client: Client | Transaction, aid: number, tags: string[]) {
return await client.queryObject(
`UPDATE all_data SET tags = $1 WHERE aid = $2`,
[tags.join(","), aid],
);
}
export async function getNullVideoTagsList(client: Client) {
const queryResult = await client.queryObject<{ aid: number; published_at: string }>(
`SELECT aid, published_at FROM all_data WHERE tags IS NULL`,
);
const rows = queryResult.rows;
return rows.map(
(row) => {
return {
aid: Number(row.aid),
published_at: parseTimestampFromPsql(row.published_at),
};
},
);
}

View File

@ -1,6 +0,0 @@
import { Pool } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import {postgresConfig} from "lib/db/pgConfig.ts";
const pool = new Pool(postgresConfig, 32);
export const db = pool;

View File

@ -1,3 +0,0 @@
import { Redis } from "ioredis";
export const redis = new Redis({ maxRetriesPerRequest: null });

9
lib/db/schema.d.ts vendored
View File

@ -1,9 +0,0 @@
export interface AllDataType {
aid: number;
bvid: string | null;
description: string | null;
uid: number | null;
tags: string | null;
title: string | null;
published_at: string | null;
}

View File

@ -1,19 +0,0 @@
import { SentenceTransformer } from "./model.ts"; // Changed import path
async function main() {
const sentenceTransformer = await SentenceTransformer.from_pretrained(
"mixedbread-ai/mxbai-embed-large-v1",
);
const outputs = await sentenceTransformer.encode([
"Hello world",
"How are you guys doing?",
"Today is Friday!",
]);
// @ts-ignore
console.log(outputs["last_hidden_state"]);
return outputs;
}
main(); // Keep main function call if you want this file to be runnable directly for testing.

View File

@ -1,40 +0,0 @@
// lib/ml/sentence_transformer_model.ts
import { AutoModel, AutoTokenizer, PretrainedOptions } from "@huggingface/transformers";
export class SentenceTransformer {
constructor(
private readonly tokenizer: AutoTokenizer,
private readonly model: AutoModel,
) {}
static async from_pretrained(
modelName: string,
options?: PretrainedOptions,
): Promise<SentenceTransformer> {
if (!options) {
options = {
progress_callback: undefined,
cache_dir: undefined,
local_files_only: false,
revision: "main",
};
}
const tokenizer = await AutoTokenizer.from_pretrained(modelName, options);
const model = await AutoModel.from_pretrained(modelName, options);
return new SentenceTransformer(tokenizer, model);
}
async encode(sentences: string[]): Promise<any> { // Changed return type to 'any' for now to match console.log output
//@ts-ignore
const modelInputs = await this.tokenizer(sentences, {
padding: true,
truncation: true,
});
//@ts-ignore
const outputs = await this.model(modelInputs);
return outputs;
}
}

View File

@ -1,34 +0,0 @@
import { Tensor } from "@huggingface/transformers";
//@ts-ignore
import { Callable } from "@huggingface/transformers/src/utils/core.js"; // Keep as is for now, might need adjustment
export interface PoolingConfig {
word_embedding_dimension: number;
pooling_mode_cls_token: boolean;
pooling_mode_mean_tokens: boolean;
pooling_mode_max_tokens: boolean;
pooling_mode_mean_sqrt_len_tokens: boolean;
}
export interface PoolingInput {
token_embeddings: Tensor;
attention_mask: Tensor;
}
export interface PoolingOutput {
sentence_embedding: Tensor;
}
export class Pooling extends Callable {
constructor(private readonly config: PoolingConfig) {
super();
}
// async _call(inputs: any) { // Keep if pooling functionality is needed
// return this.forward(inputs);
// }
// async forward(inputs: PoolingInput): PoolingOutput { // Keep if pooling functionality is needed
// }
}

View File

@ -1,32 +0,0 @@
import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers';
const modelName = "alikia2x/jina-embedding-v3-m2v-1024";
const modelConfig = {
config: { model_type: 'model2vec' },
dtype: 'fp32',
revision: 'refs/pr/1',
cache_dir: undefined,
local_files_only: true,
};
const tokenizerConfig = {
revision: 'refs/pr/2'
};
const model = await AutoModel.from_pretrained(modelName, modelConfig);
const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig);
const texts = ['hello', 'hello world'];
const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false });
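// The model2vec model takes a flat list of token ids plus per-sentence start offsets (EmbeddingBag-style input); both are built below.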
const cumsum = arr => arr.reduce((acc, num, i) => [...acc, num + (acc[i - 1] || 0)], []);
const offsets = [0, ...cumsum(input_ids.slice(0, -1).map(x => x.length))];
const flattened_input_ids = input_ids.flat();
const modelInputs = {
input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]),
offsets: new Tensor('int64', offsets, [offsets.length])
};
const { embeddings } = await model(modelInputs);
console.log(embeddings.tolist()); // output matches python version

View File

@ -1,52 +0,0 @@
import { Job } from "bullmq";
import { insertLatestVideos } from "lib/task/insertLatestVideo.ts";
import { LatestVideosQueue } from "lib/mq/index.ts";
import { MINUTE } from "$std/datetime/constants.ts";
import { db } from "lib/db/init.ts";
import { truncate } from "lib/utils/truncate.ts";
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import logger from "lib/log/logger.ts";
import { lockManager } from "lib/mq/lockManager.ts";
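// Re-scheduling intervals in minutes for the getLatestVideos job, indexed by consecutive failure count.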
const delayMap = [5, 10, 15, 30, 60, 60];
const updateQueueInterval = async (failedCount: number, delay: number) => {
logger.log(`job:getLatestVideos added to queue, delay: ${(delay / MINUTE).toFixed(2)} minutes.`, "mq");
await LatestVideosQueue.upsertJobScheduler("getLatestVideos", {
every: delay,
}, {
data: {
failedCount: failedCount,
},
});
return;
};
const executeTask = async (client: Client, failedCount: number) => {
const result = await insertLatestVideos(client);
failedCount = result !== 0 ? truncate(failedCount + 1, 0, 5) : 0;
if (failedCount !== 0) {
await updateQueueInterval(failedCount, delayMap[failedCount] * MINUTE);
}
return;
};
export const getLatestVideosWorker = async (job: Job) => {
if (await lockManager.isLocked("getLatestVideos")) {
logger.log("job:getLatestVideos is locked, skipping.", "mq");
return;
}
lockManager.acquireLock("getLatestVideos");
const failedCount = (job.data.failedCount ?? 0) as number;
const client = await db.connect();
try {
await executeTask(client, failedCount);
} finally {
client.release();
lockManager.releaseLock("getLatestVideos");
}
return;
};

View File

@ -1,99 +0,0 @@
import { Job } from "bullmq";
import { VideoTagsQueue } from "lib/mq/index.ts";
import { DAY, HOUR, MINUTE, SECOND } from "$std/datetime/constants.ts";
import { db } from "lib/db/init.ts";
import { truncate } from "lib/utils/truncate.ts";
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import logger from "lib/log/logger.ts";
import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";
import { getVideoTags } from "lib/net/getVideoTags.ts";
import { NetSchedulerError } from "lib/mq/scheduler.ts";
import { WorkerError } from "src/worker.ts";
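// Retry delays in minutes for getVideoTags, indexed by consecutive failure count.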
const delayMap = [0.5, 3, 5, 15, 30, 60];
const getJobPriority = (diff: number) => {
let priority;
if (diff > 14 * DAY) {
priority = 10;
} else if (diff > 7 * DAY) {
priority = 7;
} else if (diff > DAY) {
priority = 5;
} else if (diff > 6 * HOUR) {
priority = 3;
} else if (diff > HOUR) {
priority = 2;
} else {
priority = 1;
}
return priority;
};
const executeTask = async (client: Client, aid: number, failedCount: number, job: Job) => {
try {
const result = await getVideoTags(aid);
if (!result) {
failedCount = truncate(failedCount + 1, 0, 5);
const delay = delayMap[failedCount] * MINUTE;
logger.log(
`job:getVideoTags added to queue, delay: ${delayMap[failedCount]} minutes.`,
"mq",
);
await VideoTagsQueue.add("getVideoTags", { aid, failedCount }, { delay, priority: 6 - failedCount });
return 1;
}
await updateVideoTags(client, aid, result);
logger.log(`Fetched tags for aid: ${aid}`, "task");
return 0;
} catch (e) {
if (!(e instanceof NetSchedulerError)) {
throw new WorkerError(<Error> e, "task", "getVideoTags/fn:executeTask");
}
const err = e as NetSchedulerError;
if (err.code === "NO_AVAILABLE_PROXY" || err.code === "PROXY_RATE_LIMITED") {
logger.warn(`No available proxy for fetching tags, delayed. aid: ${aid}`, "task");
await VideoTagsQueue.add("getVideoTags", { aid, failedCount }, {
delay: 25 * SECOND * Math.random() + 5 * SECOND,
priority: job.priority,
});
return 2;
}
throw new WorkerError(err, "task", "getVideoTags/fn:executeTask");
}
};
export const getVideoTagsWorker = async (job: Job) => {
const failedCount = (job.data.failedCount ?? 0) as number;
const client = await db.connect();
const aid = job.data.aid;
if (!aid) {
return 3;
}
const v = await executeTask(client, aid, failedCount, job);
client.release();
return v;
};
export const getVideoTagsInitializer = async () => {
const client = await db.connect();
const videos = await getNullVideoTagsList(client);
if (videos.length == 0) {
return 4;
}
const count = await VideoTagsQueue.getJobCounts("wait", "delayed", "active");
const total = count.delayed + count.active + count.wait;
const max = 15;
const rest = truncate(max - total, 0, max);
let i = 0;
for (const video of videos) {
if (i > rest) return 100 + i;
const aid = video.aid;
const timestamp = video.published_at;
const diff = Date.now() - timestamp;
await VideoTagsQueue.add("getVideoTags", { aid }, { priority: getJobPriority(diff) });
i++;
}
return 0;
};

View File

@ -1 +0,0 @@
export * from "lib/mq/exec/getLatestVideos.ts";

View File

@ -1,5 +0,0 @@
import { Queue } from "bullmq";
export const LatestVideosQueue = new Queue("latestVideos");
export const VideoTagsQueue = new Queue("videoTags");

View File

@ -1,22 +0,0 @@
import { MINUTE, SECOND } from "$std/datetime/constants.ts";
import { LatestVideosQueue, VideoTagsQueue } from "lib/mq/index.ts";
import logger from "lib/log/logger.ts";
async function configGetLatestVideos() {
await LatestVideosQueue.upsertJobScheduler("getLatestVideos", {
every: 1 * MINUTE,
});
}
async function configGetVideosTags() {
await VideoTagsQueue.upsertJobScheduler("getVideosTags", {
every: 30 * SECOND,
immediately: true,
});
}
export async function initMQ() {
await configGetLatestVideos();
await configGetVideosTags();
logger.log("Message queue initialized.");
}

View File

@ -1,164 +0,0 @@
import logger from "lib/log/logger.ts";
import {RateLimiter} from "lib/mq/rateLimiter.ts";
import {SlidingWindow} from "lib/mq/slidingWindow.ts";
import {redis} from "lib/db/redis.ts";
import Redis from "ioredis";
interface Proxy {
type: string;
task: string;
limiter?: RateLimiter;
}
interface ProxiesMap {
[name: string]: Proxy;
}
type NetSchedulerErrorCode =
| "NO_AVAILABLE_PROXY"
| "PROXY_RATE_LIMITED"
| "PROXY_NOT_FOUND"
| "FETCH_ERROR"
| "NOT_IMPLEMENTED";
export class NetSchedulerError extends Error {
public code: NetSchedulerErrorCode;
public rawError: unknown | undefined;
constructor(message: string, errorCode: NetSchedulerErrorCode, rawError?: unknown) {
super(message);
this.name = "NetSchedulerError";
this.code = errorCode;
this.rawError = rawError;
}
}
class NetScheduler {
private proxies: ProxiesMap = {};
addProxy(name: string, type: string, task: string): void {
this.proxies[name] = { type, task };
}
removeProxy(name: string): void {
delete this.proxies[name];
}
setProxyLimiter(name: string, limiter: RateLimiter): void {
this.proxies[name].limiter = limiter;
}
/*
* Make a request to the specified URL with any available proxy
* @param {string} url - The URL to request.
* @param {string} method - The HTTP method to use for the request. Default is "GET".
* @returns {Promise<any>} - A promise that resolves to the response body.
* @throws {NetSchedulerError} - The error will be thrown in following cases:
* - No available proxy currently: with error code NO_AVAILABLE_PROXY
* - Proxy is under rate limit: with error code PROXY_RATE_LIMITED
* - The native `fetch` function threw an error: with error code FETCH_ERROR
* - The proxy type is not supported: with error code NOT_IMPLEMENTED
*/
async request<R>(url: string, task: string, method: string = "GET"): Promise<R> {
// find an available proxy
const proxiesNames = Object.keys(this.proxies);
for (const proxyName of proxiesNames) {
const proxy = this.proxies[proxyName];
if (proxy.task !== task) continue;
if (await this.getProxyAvailability(proxyName)) {
return await this.proxyRequest<R>(url, proxyName, method);
}
}
throw new NetSchedulerError("No available proxy currently.", "NO_AVAILABLE_PROXY");
}
/*
* Make a request to the specified URL with the specified proxy
* @param {string} url - The URL to request.
* @param {string} proxyName - The name of the proxy to use.
* @param {string} method - The HTTP method to use for the request. Default is "GET".
* @param {boolean} force - If true, the request will be made even if the proxy is rate limited. Default is false.
* @returns {Promise<any>} - A promise that resolves to the response body.
* @throws {NetSchedulerError} - The error will be thrown in following cases:
* - Proxy not found: with error code PROXY_NOT_FOUND
* - Proxy is under rate limit: with error code PROXY_RATE_LIMITED
* - The native `fetch` function threw an error: with error code FETCH_ERROR
* - The proxy type is not supported: with error code NOT_IMPLEMENTED
*/
async proxyRequest<R>(url: string, proxyName: string, method: string = "GET", force: boolean = false): Promise<R> {
const proxy = this.proxies[proxyName];
if (!proxy) {
throw new NetSchedulerError(`Proxy "${proxyName}" not found`, "PROXY_NOT_FOUND");
}
if (!force && await this.getProxyAvailability(proxyName) === false) {
throw new NetSchedulerError(`Proxy "${proxyName}" is rate limited`, "PROXY_RATE_LIMITED");
}
if (proxy.limiter) {
try {
await proxy.limiter!.trigger();
} catch (e) {
const error = e as Error;
if (e instanceof Redis.ReplyError) {
logger.error(error, "redis");
}
logger.warn(`Unhandled error: ${error.message}`, "mq", "proxyRequest");
}
}
switch (proxy.type) {
case "native":
return await this.nativeRequest<R>(url, method);
default:
throw new NetSchedulerError(`Proxy type ${proxy.type} not supported.`, "NOT_IMPLEMENTED");
}
}
private async getProxyAvailability(name: string): Promise<boolean> {
try {
const proxyConfig = this.proxies[name];
if (!proxyConfig || !proxyConfig.limiter) {
return true;
}
return await proxyConfig.limiter.getAvailability();
} catch (e) {
const error = e as Error;
if (e instanceof Redis.ReplyError) {
logger.error(error, "redis");
return false;
}
logger.warn(`Unhandled error: ${error.message}`, "mq", "getProxyAvailability");
return false;
}
}
private async nativeRequest<R>(url: string, method: string): Promise<R> {
try {
const response = await fetch(url, { method });
return await response.json() as R;
} catch (e) {
throw new NetSchedulerError("Fetch error", "FETCH_ERROR", e);
}
}
}
const netScheduler = new NetScheduler();
netScheduler.addProxy("default", "native", "default");
netScheduler.addProxy("tags-native", "native", "getVideoTags");
const tagsRateLimiter = new RateLimiter("getVideoTags", [
{
window: new SlidingWindow(redis, 1),
max: 3,
},
{
window: new SlidingWindow(redis, 30),
max: 30,
},
{
window: new SlidingWindow(redis, 2 * 60),
max: 50,
},
]);
netScheduler.setProxyLimiter("tags-native", tagsRateLimiter);
export default netScheduler;

117
lib/net/bilibili.d.ts vendored
View File

@ -1,117 +0,0 @@
interface BaseResponse<T> {
code: number;
message: string;
ttl: number;
data: T;
}
export type VideoListResponse = BaseResponse<VideoListData>;
export type VideoTagsResponse = BaseResponse<VideoTagsData>;
type VideoTagsData = VideoTags[];
interface VideoTags {
tag_id: number;
tag_name: string;
cover: string;
head_cover: string;
content: string;
short_content: string;
type: number;
state: number;
ctime: number;
count: {
view: number;
use: number;
atten: number;
}
is_atten: number;
likes: number;
hates: number;
attribute: number;
liked: number;
hated: number;
extra_attr: number;
}
interface VideoListData {
archives: VideoListVideo[];
page: {
num: number;
size: number;
count: number;
};
}
interface VideoListVideo {
aid: number;
videos: number;
tid: number;
tname: string;
copyright: number;
pic: string;
title: string;
pubdate: number;
ctime: number;
desc: string;
state: number;
duration: number;
mission_id?: number;
rights: {
bp: number;
elec: number;
download: number;
movie: number;
pay: number;
hd5: number;
no_reprint: number;
autoplay: number;
ugc_pay: number;
is_cooperation: number;
ugc_pay_preview: number;
no_background: number;
arc_pay: number;
pay_free_watch: number;
},
owner: {
mid: number;
name: string;
face: string;
},
stat: {
aid: number;
view: number;
danmaku: number;
reply: number;
favorite: number;
coin: number;
share: number;
now_rank: number;
his_rank: number;
like: number;
dislike: number;
vt: number;
vv: number;
},
dynamic: string;
cid: number;
dimension: {
width: number;
height: number;
rotate: number;
},
season_id?: number;
short_link_v2: string;
first_frame: string;
pub_location: string;
cover43: string;
tidv2: number;
tname_v2: string;
bvid: string;
season_type: number;
is_ogv: number;
ovg_info: string | null;
rcmd_season: string;
enable_vt: number;
ai_rcmd: null | string;
}

View File

@ -1,88 +0,0 @@
import { getLatestVideos } from "lib/net/getLatestVideos.ts";
import { AllDataType } from "lib/db/schema.d.ts";
export async function getVideoPositionInNewList(timestamp: number): Promise<number | null | AllDataType[]> {
const virtualPageSize = 50;
let lowPage = 1;
let highPage = 1;
let foundUpper = false;
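// Phase 1: grow highPage exponentially until a page whose last video is no newer than `timestamp` is found.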
while (true) {
const ps = highPage < 2 ? 50 : 1
const pn = highPage < 2 ? 1 : highPage * virtualPageSize;
const fetchTags = highPage < 2 ? true : false;
const videos = await getLatestVideos(pn, ps, 250, fetchTags);
if (!videos || videos.length === 0) {
break;
}
const lastVideo = videos[videos.length - 1];
if (!lastVideo || !lastVideo.published_at) {
break;
}
const lastTime = Date.parse(lastVideo.published_at);
if (lastTime <= timestamp && highPage == 1) {
return videos;
}
else if (lastTime <= timestamp) {
foundUpper = true;
break;
} else {
lowPage = highPage;
highPage *= 2;
}
}
if (!foundUpper) {
return null;
}
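// Phase 2: binary-search between lowPage and highPage for the first page whose last video is no newer than `timestamp`.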
let boundaryPage = highPage;
let lo = lowPage;
let hi = highPage;
while (lo <= hi) {
const mid = Math.floor((lo + hi) / 2);
const videos = await getLatestVideos(mid * virtualPageSize, 1, 250, false);
if (!videos) {
return null;
}
if (videos.length === 0) {
hi = mid - 1;
continue;
}
const lastVideo = videos[videos.length - 1];
if (!lastVideo || !lastVideo.published_at) {
hi = mid - 1;
continue;
}
const lastTime = Date.parse(lastVideo.published_at);
if (lastTime > timestamp) {
lo = mid + 1;
} else {
boundaryPage = mid;
hi = mid - 1;
}
}
const boundaryVideos = await getLatestVideos(boundaryPage, virtualPageSize, 250, false);
let indexInPage = 0;
if (boundaryVideos && boundaryVideos.length > 0) {
for (let i = 0; i < boundaryVideos.length; i++) {
const video = boundaryVideos[i];
if (!video.published_at) {
continue;
}
const videoTime = Date.parse(video.published_at);
if (videoTime > timestamp) {
indexInPage++;
} else {
break;
}
}
}
const count = (boundaryPage - 1) * virtualPageSize + indexInPage;
const safetyMargin = 5;
return count + safetyMargin;
}

View File

@ -1,45 +0,0 @@
import { VideoListResponse } from "lib/net/bilibili.d.ts";
import { formatTimestampToPsql as formatPublishedAt } from "lib/utils/formatTimestampToPostgre.ts";
import { AllDataType } from "lib/db/schema.d.ts";
import logger from "lib/log/logger.ts";
import { HOUR, SECOND } from "$std/datetime/constants.ts";
export async function getLatestVideos(
page: number = 1,
pageSize: number = 10,
sleepRate: number = 250,
fetchTags: boolean = true,
): Promise<AllDataType[] | null> {
try {
const response = await fetch(
`https://api.bilibili.com/x/web-interface/newlist?rid=30&ps=${pageSize}&pn=${page}`,
);
const data: VideoListResponse = await response.json();
if (data.code !== 0) {
logger.error(`Error fetching videos: ${data.message}`, "net", "getLatestVideos");
return null;
}
if (data.data.archives.length === 0) {
logger.verbose("No more videos found", "net", "getLatestVideos");
return [];
}
return data.data.archives.map((video) => {
const published_at = formatPublishedAt(video.pubdate * SECOND + 8 * HOUR);
return {
aid: video.aid,
bvid: video.bvid,
description: video.desc,
uid: video.owner.mid,
tags: null,
title: video.title,
published_at: published_at,
} as AllDataType;
});
} catch (error) {
logger.error(error as Error, "net", "getLatestVideos");
return null;
}
}

View File

@ -1,35 +0,0 @@
import { VideoTagsResponse } from "lib/net/bilibili.d.ts";
import netScheduler, {NetSchedulerError} from "lib/mq/scheduler.ts";
import logger from "lib/log/logger.ts";
/*
* Fetch the tags for a video
* @param {number} aid The video's aid
* @return {Promise<string[] | null>} A promise, which resolves to an array of tags,
* or null if an `fetch` error occurred
* @throws {NetSchedulerError} If the request failed.
*/
export async function getVideoTags(aid: number): Promise<string[] | null> {
try {
const url = `https://api.bilibili.com/x/tag/archive/tags?aid=${aid}`;
const data = await netScheduler.request<VideoTagsResponse>(url, 'getVideoTags');
if (data.code != 0) {
logger.error(`Error fetching tags for video ${aid}: ${data.message}`, 'net', 'getVideoTags');
return [];
}
return data.data.map((tag) => tag.tag_name);
}
catch (e) {
const error = e as NetSchedulerError;
if (error.code == "FETCH_ERROR") {
const rawError = error.rawError! as Error;
rawError.message = `Error fetching tags for video ${aid}: ` + rawError.message;
logger.error(rawError, 'net', 'getVideoTags');
return null;
}
else {
// Re-throw the error
throw e;
}
}
}

View File

@ -1,77 +0,0 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import { getLatestVideos } from "lib/net/getLatestVideos.ts";
import { getLatestVideoTimestampFromAllData, insertIntoAllData, videoExistsInAllData } from "lib/db/allData.ts";
import { sleep } from "lib/utils/sleep.ts";
import { getVideoPositionInNewList } from "lib/net/bisectVideoStartFrom.ts";
import { SECOND } from "$std/datetime/constants.ts";
import logger from "lib/log/logger.ts";
export async function insertLatestVideos(
client: Client,
pageSize: number = 10,
sleepRate: number = 250,
intervalRate: number = 4000,
): Promise<number | null> {
const latestVideoTimestamp = await getLatestVideoTimestampFromAllData(client);
if (latestVideoTimestamp == null) {
logger.error("Cannot get latest video timestamp from current database.", "net", "fn:insertLatestVideos()");
return null;
}
logger.log(`Latest video in the database: ${new Date(latestVideoTimestamp).toISOString()}`, "net", "fn:insertLatestVideos()");
const videoIndex = await getVideoPositionInNewList(latestVideoTimestamp);
if (videoIndex == null) {
logger.error("Cannot locate the video through bisect.", "net", "fn:insertLatestVideos()");
return null;
}
if (typeof videoIndex == "object") {
for (const video of videoIndex) {
const videoExists = await videoExistsInAllData(client, video.aid);
if (!videoExists) {
await insertIntoAllData(client, video);
}
}
return 0;
}
let page = Math.floor(videoIndex / pageSize) + 1;
let failCount = 0;
const insertedVideos = new Set();
while (true) {
try {
const videos = await getLatestVideos(page, pageSize, sleepRate);
if (videos == null) {
failCount++;
if (failCount > 5) {
return null;
}
continue;
}
failCount = 0;
if (videos.length == 0) {
logger.verbose("No more videos found", "net", "fn:insertLatestVideos()");
break;
}
for (const video of videos) {
const videoExists = await videoExistsInAllData(client, video.aid);
if (!videoExists) {
await insertIntoAllData(client, video);
insertedVideos.add(video.aid);
}
}
logger.log(`Page ${page} crawled, total: ${insertedVideos.size} videos.`, "net", "fn:insertLatestVideos()");
page--;
if (page < 1) {
return 0;
}
} catch (error) {
logger.error(error as Error, "net", "fn:insertLatestVideos()");
failCount++;
if (failCount > 5) {
return null;
}
continue;
} finally {
await sleep(Math.random() * intervalRate + failCount * 3 * SECOND + SECOND);
}
}
return 0;
}

13
main.ts
View File

@ -1,13 +0,0 @@
/// <reference no-default-lib="true" />
/// <reference lib="dom" />
/// <reference lib="dom.iterable" />
/// <reference lib="dom.asynciterable" />
/// <reference lib="deno.ns" />
import "$std/dotenv/load.ts";
import { start } from "$fresh/server.ts";
import manifest from "./fresh.gen.ts";
import config from "./fresh.config.ts";
await start(manifest, config);

View File

@ -19,3 +19,12 @@ Note
0331: V3.6-test3 # 3.5 isn't good enough; trying to tune the hyperparameters
0335: V3.7-test3 # 3.6 is acceptable; keep tuning the hyperparameters
0414: V3.8-test3 # 3.7 doesn't work; re-tuning from the 3.6 baseline
1918: V3.9
2308: V3.11
2243: V3.11 # 256-dim embedding
2253: V3.11 # comparison with 1024-dim embedding
2337: V3.12 # cascaded classification
2350: V3.13 # V3.12, switched to plain cross-entropy loss
0012: V3.11 # switched to plain cross-entropy loss
0039: V3.11 # cascaded classification, but with two independent models
0122: V3.15 # removed the author_info channel

View File

@ -103,8 +103,7 @@ class MultiChannelDataset(Dataset):
texts = {
'title': example['title'],
'description': example['description'],
'tags': tags_text,
'author_info': example['author_info']
'tags': tags_text
}
return {

110
ml/filter/embedding.py Normal file
View File

@ -0,0 +1,110 @@
import numpy as np
import torch
from model2vec import StaticModel
def prepare_batch(batch_data, device="cpu"):
"""
将输入的 batch_data 转换为模型所需的输入格式 [batch_size, num_channels, embedding_dim]
参数:
batch_data (dict): 输入的 batch 数据格式为 {
"title": [text1, text2, ...],
"description": [text1, text2, ...],
"tags": [text1, text2, ...]
}
device (str): 模型运行的设备 "cpu" "cuda"
返回:
torch.Tensor: 形状为 [batch_size, num_channels, embedding_dim] 的张量
"""
# 1. 对每个通道的文本分别编码
channel_embeddings = []
model = StaticModel.from_pretrained("./model/embedding_1024/")
for channel in ["title", "description", "tags"]:
texts = batch_data[channel] # 获取当前通道的文本列表
embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device) # 编码为 [batch_size, embedding_dim]
channel_embeddings.append(embeddings)
# 2. 将编码结果堆叠为 [batch_size, num_channels, embedding_dim]
batch_tensor = torch.stack(channel_embeddings, dim=1) # 在 dim=1 上堆叠
return batch_tensor
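# Illustrative usage sketch (the example texts below are made up, not from the dataset):
# batch = {"title": ["some title"], "description": ["some description"], "tags": ["tag1,tag2"]}
# tensor = prepare_batch(batch, device="cpu")  # expected shape: [1, 3, 1024]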
import onnxruntime as ort
from transformers import AutoTokenizer
from itertools import accumulate
def prepare_batch_per_token(batch_data, max_length=1024):
"""
将输入的 batch_data 转换为模型所需的输入格式 [batch_size, num_channels, seq_length, embedding_dim]
参数:
batch_data (dict): 输入的 batch 数据格式为 {
"title": [text1, text2, ...],
"description": [text1, text2, ...],
"tags": [text1, text2, ...],
"author_info": [text1, text2, ...]
}
max_length (int): 最大序列长度
返回:
torch.Tensor: 形状为 [batch_size, num_channels, seq_length, embedding_dim] 的张量
"""
# Initialize the tokenizer and the ONNX model
tokenizer = AutoTokenizer.from_pretrained("alikia2x/jina-embedding-v3-m2v-1024")
session = ort.InferenceSession("./model/embedding_256/onnx/model.onnx")
# 1. Encode the texts of each channel separately
channel_embeddings = []
for channel in ["title", "description", "tags", "author_info"]:
texts = batch_data[channel] # list of texts for the current channel
# Step 1: build input_ids and offsets
# Encode each text individually, keeping its original token length
encoded_inputs = [tokenizer(text, truncation=True, max_length=max_length, return_tensors='np') for text in texts]
# Length of each text's input_ids (the actual token counts)
input_ids_lengths = [len(enc["input_ids"][0]) for enc in encoded_inputs]
# Build offsets: [0, len1, len1+len2, ...]
offsets = list(accumulate([0] + input_ids_lengths[:-1])) # cumulative sum, excluding the last length
# Flatten all input_ids into a single 1-D array
flattened_input_ids = np.concatenate([enc["input_ids"][0] for enc in encoded_inputs], axis=0).astype(np.int64)
# Step 2: build the ONNX inputs
inputs = {
"input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
"offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64))
}
# Step 3: run the ONNX model
embeddings = session.run(None, inputs)[0] # assumes the output is named "embeddings"
# Step 4: reshape the output into [batch_size, seq_length, embedding_dim]
# Note: this assumes the ONNX output has shape [total_tokens, embedding_dim]
# and needs to be regrouped by the actual sequence lengths
batch_size = len(texts)
embeddings_split = np.split(embeddings, np.cumsum(input_ids_lengths[:-1]))
padded_embeddings = []
for emb, seq_len in zip(embeddings_split, input_ids_lengths):
# Pad each sequence to max_length
if seq_len > max_length:
# Truncate if the sequence is longer than max_length
emb = emb[:max_length]
pad_length = 0
else:
# Otherwise pad up to max_length
pad_length = max_length - seq_len
# Pad to [max_length, embedding_dim]
padded = np.pad(emb, ((0, pad_length), (0, 0)), mode='constant')
padded_embeddings.append(padded)
# Make sure all padded sequences have the same shape
embeddings_tensor = torch.tensor(np.stack(padded_embeddings), dtype=torch.float32)
channel_embeddings.append(embeddings_tensor)
# 2. Stack the per-channel results into [batch_size, num_channels, seq_length, embedding_dim]
batch_tensor = torch.stack(channel_embeddings, dim=1)
return batch_tensor
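# Illustrative usage sketch (made-up texts; the 256-dim output follows the model path used above):
# batch = {"title": ["t"], "description": ["d"], "tags": ["a,b"], "author_info": ["u"]}
# tensor = prepare_batch_per_token(batch, max_length=1024)  # expected shape: [1, 4, 1024, 256]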

View File

@ -0,0 +1,54 @@
import json
import torch
import random
from embedding import prepare_batch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
file_path = './data/filter/model_predicted.jsonl'
class Dataset:
def __init__(self, file_path):
all_examples = self.load_data(file_path)
self.examples = all_examples
def load_data(self, file_path):
with open(file_path, 'r', encoding='utf-8') as f:
return [json.loads(line) for line in f]
def __getitem__(self, idx):
end_idx = min((idx + 1) * self.batch_size, len(self.examples))
texts = {
'title': [ex['title'] for ex in self.examples[idx * self.batch_size:end_idx]],
'description': [ex['description'] for ex in self.examples[idx * self.batch_size:end_idx]],
'tags': [",".join(ex['tags']) for ex in self.examples[idx * self.batch_size:end_idx]],
'author_info': [ex['author_info'] for ex in self.examples[idx * self.batch_size:end_idx]]
}
return texts
def __len__(self):
return len(self.examples)
def get_batch(self, idx, batch_size):
self.batch_size = batch_size
return self.__getitem__(idx)
total = 600000
batch_size = 512
batch_num = total // batch_size
dataset = Dataset(file_path)
arr_len = batch_size * 4 * 1024
sample_rate = 0.1
sample_num = int(arr_len * sample_rate)
data = np.array([])
for i in tqdm(range(batch_num)):
batch = dataset.get_batch(i, batch_size)
batch = prepare_batch(batch, device="cpu")
arr = batch.flatten().numpy()
sampled = np.random.choice(arr.shape[0], size=sample_num, replace=False)
data = np.concatenate((data, arr[sampled]), axis=0) if data.size else arr[sampled]
if i % 10 == 0:
np.save('embedding_range.npy', data)
np.save('embedding_range.npy', data)

View File

@ -0,0 +1,43 @@
import numpy as np
import matplotlib.pyplot as plt
# Load the data
data = np.load("1.npy")
# Draw a histogram to get the raw counts
n, bins, patches = plt.hist(data, bins=32, density=False, alpha=0.7, color='skyblue')
# Total number of data points
total_data = len(data)
# Convert counts to frequencies
frequencies = n / total_data
# Compute summary statistics
max_val = np.max(data)
min_val = np.min(data)
std_dev = np.std(data)
# Figure properties
plt.title('Frequency Distribution Histogram')
plt.xlabel('Value')
plt.ylabel('Frequency')
# Redraw the histogram using frequencies as the bar heights
plt.cla() # clear the current axes
plt.bar([(bins[i] + bins[i+1])/2 for i in range(len(bins)-1)], frequencies, width=[bins[i+1]-bins[i] for i in range(len(bins)-1)], alpha=0.7, color='skyblue')
# Annotate each bar with its frequency
for i in range(len(patches)):
plt.text(bins[i]+(bins[i+1]-bins[i])/2, frequencies[i], f'{frequencies[i]:.2e}', ha='center', va='bottom', fontsize=6)
# Show the summary statistics in a corner of the plot
stats_text = f"Max: {max_val:.6f}\nMin: {min_val:.6f}\nStd: {std_dev:.4e}"
plt.text(0.95, 0.95, stats_text, transform=plt.gca().transAxes,
ha='right', va='top', bbox=dict(facecolor='white', edgecolor='black', alpha=0.8))
# Align the x-axis ticks with the bin edges
plt.xticks(bins, fontsize = 6)
# Show the figure
plt.show()

View File

@ -10,7 +10,7 @@ import tty
import termios
from sentence_transformers import SentenceTransformer
from db_utils import fetch_entry_data, parse_entry_data
from modelV3_9 import VideoClassifierV3_9
from modelV3_10 import VideoClassifierV3_10
class LabelingSystem:
def __init__(self, mode='model_testing', database_path="./data/main.db",
@ -27,7 +27,7 @@ class LabelingSystem:
self.model = None
self.sentence_transformer = None
if self.mode == 'model_testing':
self.model = VideoClassifierV3_9()
self.model = VideoClassifierV3_10()
self.model.load_state_dict(torch.load(model_path))
self.model.eval()
self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")

View File

@ -3,13 +3,13 @@ import torch.nn as nn
import torch.nn.functional as F
class VideoClassifierV3_10(nn.Module):
def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3):
def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7):
super().__init__()
self.num_channels = 4
self.channel_names = ['title', 'description', 'tags', 'author_info']
# Learnable temperature coefficient
self.temperature = nn.Parameter(torch.tensor(1.7))
self.temperature = nn.Parameter(torch.tensor(temperature))
# Constrained channel weights (Sigmoid instead of Softmax)
self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

79
ml/filter/modelV3_12.py Normal file
View File

@ -0,0 +1,79 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class VideoClassifierV3_12(nn.Module):
def __init__(self, embedding_dim=1024, hidden_dim=648):
super().__init__()
self.num_channels = 4
self.channel_names = ['title', 'description', 'tags', 'author_info']
# Learnable temperature coefficient
self.temperature = nn.Parameter(torch.tensor(1.7))
# Constrained channel weights (Sigmoid instead of Softmax)
self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
# First binary classifier: class 0 vs classes 1/2
self.first_classifier = nn.Sequential(
nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
nn.BatchNorm1d(hidden_dim*2),
nn.Dropout(0.2),
nn.GELU(),
nn.Linear(hidden_dim*2, 2) # two outputs: 0 vs 1/2
)
# Second binary classifier: class 1 vs class 2
self.second_classifier = nn.Sequential(
nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
nn.BatchNorm1d(hidden_dim*2),
nn.Dropout(0.2),
nn.GELU(),
nn.Linear(hidden_dim*2, 2) # two outputs: 1 vs 2
)
# Weight initialization
self._init_weights()
def _init_weights(self):
for layer in self.first_classifier:
if isinstance(layer, nn.Linear):
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
nn.init.zeros_(layer.bias)
for layer in self.second_classifier:
if isinstance(layer, nn.Linear):
nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
nn.init.zeros_(layer.bias)
def forward(self, channel_features: torch.Tensor):
"""
输入格式: [batch_size, num_channels, embedding_dim]
输出格式: [batch_size, output_dim]
"""
# 自适应通道权重Sigmoid约束
weights = torch.sigmoid(self.channel_weights) # [0,1]范围
weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)
# 特征拼接
combined = weighted_features.view(weighted_features.size(0), -1)
# 第一个二分类器0 vs 1/2
first_output = self.first_classifier(combined)
first_probs = F.softmax(first_output, dim=1)
# 第二个二分类器1 vs 2
second_output = self.second_classifier(combined)
second_probs = F.softmax(second_output, dim=1)
# 合并结果
final_probs = torch.zeros(channel_features.size(0), 3).to(channel_features.device)
final_probs[:, 0] = first_probs[:, 0] # 类别0的概率
final_probs[:, 1] = first_probs[:, 1] * second_probs[:, 0] # 类别1的概率
final_probs[:, 2] = first_probs[:, 1] * second_probs[:, 1] # 类别2的概率
return final_probs
def get_channel_weights(self):
"""获取各通道权重(带温度调节)"""
return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
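# Sanity-check sketch (illustrative, with made-up shapes): the two binary heads compose into a
# 3-class distribution, so each output row should sum to 1.
# model = VideoClassifierV3_12().eval()
# probs = model(torch.randn(8, 4, 1024))   # [8, 3]
# assert torch.allclose(probs.sum(dim=1), torch.ones(8), atol=1e-5)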

View File

@ -2,14 +2,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
class VideoClassifierV3_9(nn.Module):
def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3):
class VideoClassifierV3_15(nn.Module):
def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7):
super().__init__()
self.num_channels = 4
self.channel_names = ['title', 'description', 'tags', 'author_info']
self.num_channels = 3
self.channel_names = ['title', 'description', 'tags']
# Learnable temperature coefficient
self.temperature = nn.Parameter(torch.tensor(1.7))
self.temperature = nn.Parameter(torch.tensor(temperature))
# Constrained channel weights (Sigmoid instead of Softmax)
self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
@ -38,21 +38,11 @@ class VideoClassifierV3_9(nn.Module):
nn.init.zeros_(layer.bias)
def forward(self, input_texts, sentence_transformer):
# Merge the texts for batched encoding
all_texts = [text for channel in self.channel_names for text in input_texts[channel]]
# Frozen text encoding
with torch.no_grad():
embeddings = torch.tensor(
sentence_transformer.encode(all_texts),
device=next(self.parameters()).device
)
# Split and weight the per-channel features
split_sizes = [len(input_texts[name]) for name in self.channel_names]
channel_features = torch.split(embeddings, split_sizes, dim=0)
channel_features = torch.stack(channel_features, dim=1)
def forward(self, channel_features: torch.Tensor):
"""
Input shape: [batch_size, num_channels, embedding_dim]
Output shape: [batch_size, output_dim]
"""
# Adaptive channel weights (Sigmoid-constrained)
weights = torch.sigmoid(self.channel_weights) # in the [0,1] range

93
ml/filter/modelV6_0.py Normal file
View File

@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class VideoClassifierV6_0(nn.Module):
def __init__(self, embedding_dim=256, seq_length=1024, hidden_dim=512, output_dim=3):
super().__init__()
self.num_channels = 4
self.channel_names = ['title', 'description', 'tags', 'author_info']
# CNN feature extraction layers
self.conv_layers = nn.Sequential(
# First conv block
nn.Conv2d(self.num_channels, 64, kernel_size=(3, 3), padding=1),
nn.BatchNorm2d(64),
nn.GELU(),
nn.MaxPool2d(kernel_size=(2, 2)),
# Second conv block
nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
nn.BatchNorm2d(128),
nn.GELU(),
nn.MaxPool2d(kernel_size=(2, 2)),
# Third conv block
nn.Conv2d(128, 256, kernel_size=(3, 3), padding=1),
nn.BatchNorm2d(256),
nn.GELU(),
# Global average pooling layer
# output shape: [batch_size, 256, 1, 1]
nn.AdaptiveAvgPool2d((1, 1))
)
# Feature dimension after global pooling is fixed at 256
self.feature_dim = 256
# Fully connected head
self.fc = nn.Sequential(
nn.Linear(self.feature_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.Dropout(0.2),
nn.GELU(),
nn.Linear(hidden_dim, output_dim)
)
self._init_weights()
def _init_weights(self):
for module in self.modules():
if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(self, channel_features: torch.Tensor):
"""
输入格式: [batch_size, num_channels, seq_length, embedding_dim]
输出格式: [batch_size, output_dim]
"""
# CNN特征提取
conv_features = self.conv_layers(channel_features)
# 展平特征(全局池化后形状为 [batch_size, 256, 1, 1]
flat_features = conv_features.view(conv_features.size(0), -1) # [batch_size, 256]
# 全连接层分类
return self.fc(flat_features)
# 损失函数保持不变
class AdaptiveRecallLoss(nn.Module):
def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
super().__init__()
self.class_weights = class_weights
self.alpha = alpha
self.gamma = gamma
self.fp_penalty = fp_penalty
def forward(self, logits, targets):
ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = ((1 - pt) ** self.gamma) * ce_loss
class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)
probs = F.softmax(logits, dim=1)
fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()
total_loss = recall_loss.mean() + fp_loss / len(targets)
return total_loss
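# Illustrative shape check (assumed dimensions, matching the defaults above):
# model = VideoClassifierV6_0().eval()            # embedding_dim=256, seq_length=1024
# logits = model(torch.randn(2, 4, 1024, 256))    # [batch, channels, seq_length, embedding_dim]
# print(logits.shape)                             # expected: torch.Size([2, 3])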

View File

@ -1,16 +1,16 @@
import torch
from modelV3_10 import VideoClassifierV3_10
from modelV3_15 import VideoClassifierV3_15
def export_onnx(model_path="./filter/checkpoints/best_model_V3.10.pt",
onnx_path="./model/video_classifier_v3_10.onnx"):
def export_onnx(model_path="./filter/checkpoints/best_model_V3.17.pt",
onnx_path="./model/video_classifier_v3_17.onnx"):
# Initialize the model
model = VideoClassifierV3_10()
model = VideoClassifierV3_15()
model.load_state_dict(torch.load(model_path))
model.eval()
# Create a dummy input matching the expected input shape
dummy_input = torch.randn(1, 4, 1024) # [batch=1, channels=4, embedding_dim=1024]
dummy_input = torch.randn(1, 3, 1024) # [batch=1, channels=3, embedding_dim=1024]
# Export to ONNX
torch.onnx.export(

36
ml/filter/quantize.py Normal file
View File

@ -0,0 +1,36 @@
from safetensors import safe_open
from safetensors.torch import save_file
import torch
# Paths
model_path = "./model/embedding/model.safetensors"
save_path = "./model/embedding/int8_model.safetensors"
# Load the original embedding layer
with safe_open(model_path, framework="pt") as f:
embeddings_tensor = f.get_tensor("embeddings")
# Compute the extrema
min_val = torch.min(embeddings_tensor)
max_val = torch.max(embeddings_tensor)
# Compute the quantization parameters
scale = (max_val - min_val) / 255 # int8 spans 256 values (-128 to 127)
# Map the floats into the int8 range; shift by 128 before casting so the 0-255 intermediate values do not overflow int8
int8_tensor = (torch.round((embeddings_tensor - min_val) / scale) - 128).to(torch.int8)
# Make sure the shape matches the original tensor
assert int8_tensor.shape == embeddings_tensor.shape
# Save the mapped int8 tensor
save_file({"embeddings": int8_tensor}, save_path)
# Print the dequantization formula
print("int8 dequantization formula:")
m = min_val.item()
am = abs(min_val.item())
sign = "-" if m < 0 else "+"
print(f"float_value = (int8_value + 128) × {scale.item()} {sign} {am}")
print("int8 quantization done!")

View File

@ -1,7 +1,7 @@
from labeling_system import LabelingSystem
DATABASE_PATH = "./data/main.db"
MODEL_PATH = "./filter/checkpoints/best_model_V3.9.pt"
MODEL_PATH = "./filter/checkpoints/best_model_V3.11.pt"
OUTPUT_FILE = "./data/filter/real_test.jsonl"
BATCH_SIZE = 50

View File

@ -4,17 +4,16 @@ import numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
from dataset import MultiChannelDataset
from filter.modelV3_10 import VideoClassifierV3_10, AdaptiveRecallLoss
from sentence_transformers import SentenceTransformer
from filter.modelV3_15 import AdaptiveRecallLoss, VideoClassifierV3_15
from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
import os
import torch
from torch.utils.tensorboard import SummaryWriter # import TensorBoard
from torch.utils.tensorboard import SummaryWriter
import time
from embedding import prepare_batch
import torch.nn as nn
# Dynamically generate the run subdirectory name
run_name = f"run_{time.strftime('%Y%m%d_%H%M')}"
log_dir = os.path.join('./filter/runs', run_name)
@ -52,9 +51,8 @@ class_weights = torch.tensor(
)
# Initialize the model and the SentenceTransformer
sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
model = VideoClassifierV3_10()
checkpoint_name = './filter/checkpoints/best_model_V3.11.pt'
model = VideoClassifierV3_15()
checkpoint_name = './filter/checkpoints/best_model_V3.17.pt'
# Model checkpoint save path
os.makedirs('./filter/checkpoints', exist_ok=True)
@ -78,7 +76,7 @@ def evaluate(model, dataloader):
with torch.no_grad():
for batch in dataloader:
batch_tensor = prepare_batch(batch['texts'], device="cpu")
batch_tensor = prepare_batch(batch['texts'])
logits = model(batch_tensor)
preds = torch.argmax(logits, dim=1)
all_preds.extend(preds.cpu().numpy())
@ -111,9 +109,8 @@ for epoch in range(num_epochs):
for batch_idx, batch in enumerate(train_loader):
optimizer.zero_grad()
batch_tensor = prepare_batch(batch['texts'], device="cpu")
batch_tensor = prepare_batch(batch['texts'])
# Pass in the text dict and the sentence_transformer
logits = model(batch_tensor)
loss = criterion(logits, batch['label'])

View File

12
ml/pred/count.py Normal file
View File

@ -0,0 +1,12 @@
# iterate all json files in ./data/pred
import os
import json
count = 0
for filename in os.listdir('./data/pred'):
if filename.endswith('.json'):
with open('./data/pred/' + filename, 'r') as f:
data = json.load(f)
count += len(data)
print(count)

19
ml/pred/crawler.py Normal file
View File

@ -0,0 +1,19 @@
import os
import requests
import json
import time
with open("./pred/2", "r") as fp:
raw = fp.readlines()
aids = [ int(x.strip()) for x in raw ]
for aid in aids:
if os.path.exists(f"./data/pred/{aid}.json"):
continue
url = f"https://api.bunnyxt.com/tdd/v2/video/{aid}/record?last_count=5000"
r = requests.get(url)
data = r.json()
with open (f"./data/pred/{aid}.json", "w") as fp:
json.dump(data, fp, ensure_ascii=False, indent=4)
time.sleep(5)
print(aid)

178
ml/pred/dataset.py Normal file
View File

@ -0,0 +1,178 @@
import os
import json
import random
import bisect
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import datetime
class VideoPlayDataset(Dataset):
def __init__(self, data_dir, publish_time_path, term='long', seed=42):
if seed is not None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
self.data_dir = data_dir
self.series_dict = self._load_and_process_data(publish_time_path)
self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
self.term = term
# Set time window based on term
self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
MINUTE = 60
HOUR = 3600
DAY = 24 * HOUR
if term == 'long':
self.feature_windows = [
1 * HOUR,
6 * HOUR,
1 * DAY,
3 * DAY,
7 * DAY,
30 * DAY,
100 * DAY
]
else:
self.feature_windows = [
( 15 * MINUTE, 0 * MINUTE),
( 40 * MINUTE, 0 * MINUTE),
( 1 * HOUR, 0 * HOUR),
( 2 * HOUR, 1 * HOUR),
( 3 * HOUR, 2 * HOUR),
( 3 * HOUR, 0 * HOUR),
#( 6 * HOUR, 3 * HOUR),
( 6 * HOUR, 0 * HOUR),
(18 * HOUR, 12 * HOUR),
#( 1 * DAY, 6 * HOUR),
( 1 * DAY, 0 * DAY),
#( 2 * DAY, 1 * DAY),
( 3 * DAY, 0 * DAY),
#( 4 * DAY, 1 * DAY),
( 7 * DAY, 0 * DAY)
]
def _extract_features(self, series, current_idx, target_idx):
current_time = series['abs_time'][current_idx]
current_play = series['play_count'][current_idx]
dt = datetime.datetime.fromtimestamp(current_time)
if self.term == 'long':
time_features = [
np.log2(max(current_time - series['create_time'], 1))
]
else:
time_features = [
(dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400,
(dt.weekday() * 24 + dt.hour) / 168,
np.log2(max(current_time - series['create_time'], 1))
]
growth_features = []
if self.term == 'long':
for window in self.feature_windows:
prev_time = current_time - window
prev_idx = self._get_nearest_value(series, prev_time, current_idx)
if prev_idx is not None:
time_diff = current_time - series['abs_time'][prev_idx]
play_diff = current_play - series['play_count'][prev_idx]
scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0
else:
scaled_diff = 0.0
growth_features.append(np.log2(max(scaled_diff, 1)))
else:
for window_start, window_end in self.feature_windows:
prev_time_start = current_time - window_start
prev_time_end = current_time - window_end # window_end is typically 0
prev_idx_start = self._get_nearest_value(series, prev_time_start, current_idx)
prev_idx_end = self._get_nearest_value(series, prev_time_end, current_idx)
if prev_idx_start is not None and prev_idx_end is not None:
time_diff = series['abs_time'][prev_idx_end] - series['abs_time'][prev_idx_start]
play_diff = series['play_count'][prev_idx_end] - series['play_count'][prev_idx_start]
scaled_diff = play_diff / (time_diff / (window_start - window_end)) if time_diff > 0 else 0.0
else:
scaled_diff = 0.0
growth_features.append(np.log2(max(scaled_diff, 1)))
time_diff = series['abs_time'][target_idx] - current_time
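# Final feature vector: [log2 time-to-target, log2(current plays + 1), per-window growth features, time features]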
return [np.log2(max(time_diff, 1))] + [np.log2(current_play + 1)] + growth_features + time_features
def _load_and_process_data(self, publish_time_path):
publish_df = pd.read_csv(publish_time_path)
publish_df['published_at'] = pd.to_datetime(publish_df['published_at'])
publish_dict = dict(zip(publish_df['aid'], publish_df['published_at']))
series_dict = {}
for filename in os.listdir(self.data_dir):
if not filename.endswith('.json'):
continue
with open(os.path.join(self.data_dir, filename), 'r') as f:
data = json.load(f)
if 'code' in data:
continue
for item in data:
aid = item['aid']
published_time = pd.to_datetime(publish_dict[aid]).timestamp()
if aid not in series_dict:
series_dict[aid] = {
'abs_time': [],
'play_count': [],
'create_time': published_time
}
series_dict[aid]['abs_time'].append(item['added'])
series_dict[aid]['play_count'].append(item['view'])
# Sort each series by absolute time
for aid in series_dict:
sorted_indices = sorted(range(len(series_dict[aid]['abs_time'])),
key=lambda k: series_dict[aid]['abs_time'][k])
series_dict[aid]['abs_time'] = [series_dict[aid]['abs_time'][i] for i in sorted_indices]
series_dict[aid]['play_count'] = [series_dict[aid]['play_count'][i] for i in sorted_indices]
return series_dict
def __len__(self):
return 100000 # Virtual length for sampling
def _get_nearest_value(self, series, target_time, current_idx):
times = series['abs_time']
pos = bisect.bisect_right(times, target_time, 0, current_idx + 1)
candidates = []
if pos > 0:
candidates.append(pos - 1)
if pos <= current_idx:
candidates.append(pos)
if not candidates:
return None
closest_idx = min(candidates, key=lambda i: abs(times[i] - target_time))
return closest_idx
def __getitem__(self, _idx):
while True:
series = random.choice(self.valid_series)
if len(series['abs_time']) < 2:
continue
current_idx = random.randint(0, len(series['abs_time']) - 2)
current_time = series['abs_time'][current_idx]
max_target_time = current_time + self.time_window
candidate_indices = []
for j in range(current_idx + 1, len(series['abs_time'])):
if series['abs_time'][j] > max_target_time:
break
candidate_indices.append(j)
if not candidate_indices:
continue
target_idx = random.choice(candidate_indices)
break
current_play = series['play_count'][current_idx]
target_play = series['play_count'][target_idx]
target_delta = max(target_play - current_play, 0)
return {
'features': torch.FloatTensor(self._extract_features(series, current_idx, target_idx)),
'target': torch.log2(torch.FloatTensor([target_delta]) + 1)
}
def collate_fn(batch):
return {
'features': torch.stack([x['features'] for x in batch]),
'targets': torch.stack([x['target'] for x in batch])
}
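# Illustrative usage sketch (the directory and CSV paths here are assumptions, not the repository layout):
# dataset = VideoPlayDataset('./data/pred', './data/pred/publish_time.csv', term='long')
# loader = torch.utils.data.DataLoader(dataset, batch_size=128, collate_fn=collate_fn)
# batch = next(iter(loader))   # batch['features']: [128, n_features], batch['targets']: [128, 1]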

28
ml/pred/export_onnx.py Normal file
View File

@ -0,0 +1,28 @@
import torch
import torch.onnx
from model import CompactPredictor
def export_model(input_size, checkpoint_path, onnx_path):
model = CompactPredictor(input_size)
model.load_state_dict(torch.load(checkpoint_path))
dummy_input = torch.randn(1, input_size)
model.eval()
torch.onnx.export(model, # Model to be exported
dummy_input, # Model input
onnx_path, # Save path
export_params=True, # Whether to export model parameters
opset_version=11, # ONNX opset version
do_constant_folding=True, # Whether to perform constant folding optimization
input_names=['input'], # Input node name
output_names=['output'], # Output node name
dynamic_axes={'input': {0: 'batch_size'}, # Dynamic batch size
'output': {0: 'batch_size'}})
print(f"ONNX model has been exported to: {onnx_path}")
if __name__ == '__main__':
export_model(10, './pred/checkpoints/long_term.pt', 'long_term.onnx')
export_model(12, './pred/checkpoints/short_term.pt', 'short_term.onnx')
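# Illustrative check of an exported model (assumes numpy and onnxruntime are available):
# import numpy as np, onnxruntime as ort
# sess = ort.InferenceSession('long_term.onnx')
# out = sess.run(None, {'input': np.random.randn(1, 10).astype(np.float32)})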

32
ml/pred/inference.py Normal file
View File

@ -0,0 +1,32 @@
import datetime
import numpy as np
from model import CompactPredictor
import torch
def main():
model = CompactPredictor(10).to('cpu', dtype=torch.float32)
model.load_state_dict(torch.load('./pred/checkpoints/long_term.pt'))
model.eval()
# inference
initial = 997029
last = initial
start_time = '2025-03-17 00:13:17'
for i in range(1, 120):
hour = i / 0.5
sec = hour * 3600
time_d = np.log2(sec)
data = [time_d, np.log2(initial+1), # time_delta, current_views
6.111542, 8.404707, 10.071566, 11.55888, 12.457823, # growth features
0.009225, 0.001318, 28.001814 # time features
]
np_arr = np.array([data])
tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
output = model(tensor)
num = output.detach().numpy()[0][0]
views_pred = int(np.exp2(num)) + initial
current_time = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S') + datetime.timedelta(hours=hour)
print(current_time.strftime('%m-%d %H:%M:%S'), views_pred, views_pred - last)
last = views_pred
if __name__ == '__main__':
main()

Some files were not shown because too many files have changed in this diff.