merge: branch 'main' into gitbook (commit 4d2b002264)
.gitignore (vendored, 14 changes)
@@ -51,7 +51,6 @@ internal/
 !tests/cases/projects/projectOption/**/node_modules
 !tests/cases/projects/NodeModulesSearch/**/*
 !tests/baselines/reference/project/nodeModules*/**/*
-.idea
 yarn.lock
 yarn-error.log
 .parallelperf.*
@@ -76,14 +75,13 @@ node_modules/
 
 
 # project specific
-data/main.db
-.env
 logs/
 __pycache__
-filter/runs
-data/filter/eval*
-data/filter/train*
-filter/checkpoints
-data/filter/model_predicted*
+ml/filter/runs
+ml/pred/runs
+ml/pred/checkpoints
+ml/pred/observed
+ml/data/
+ml/filter/checkpoints
 scripts
 model/
.idea/.gitignore (vendored, new file, 9 lines)
@@ -0,0 +1,9 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
+dataSources.xml
.idea/cvsa.iml (new file, 21 lines)
@@ -0,0 +1,21 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="WEB_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/.tmp" />
+      <excludeFolder url="file://$MODULE_DIR$/temp" />
+      <excludeFolder url="file://$MODULE_DIR$/tmp" />
+      <excludeFolder url="file://$MODULE_DIR$/ml/data" />
+      <excludeFolder url="file://$MODULE_DIR$/doc" />
+      <excludeFolder url="file://$MODULE_DIR$/ml/filter/checkpoints" />
+      <excludeFolder url="file://$MODULE_DIR$/ml/filter/runs" />
+      <excludeFolder url="file://$MODULE_DIR$/ml/lab/data" />
+      <excludeFolder url="file://$MODULE_DIR$/ml/lab/temp" />
+      <excludeFolder url="file://$MODULE_DIR$/logs" />
+      <excludeFolder url="file://$MODULE_DIR$/model" />
+      <excludeFolder url="file://$MODULE_DIR$/src/db" />
+    </content>
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
.idea/inspectionProfiles/Project_Default.xml (new file, 12 lines)
@@ -0,0 +1,12 @@
+<component name="InspectionProjectProfileManager">
+  <profile version="1.0">
+    <option name="myName" value="Project Default" />
+    <inspection_tool class="GrazieInspection" enabled="false" level="GRAMMAR_ERROR" enabled_by_default="false" />
+    <inspection_tool class="LanguageDetectionInspection" enabled="false" level="WARNING" enabled_by_default="false" />
+    <inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
+      <option name="processCode" value="true" />
+      <option name="processLiterals" value="true" />
+      <option name="processComments" value="true" />
+    </inspection_tool>
+  </profile>
+</component>
.idea/modules.xml (new file, 8 lines)
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/cvsa.iml" filepath="$PROJECT_DIR$/.idea/cvsa.iml" />
+    </modules>
+  </component>
+</project>
.idea/sqldialects.xml (new file, 6 lines)
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="SqlDialectMappings">
+    <file url="PROJECT" dialect="PostgreSQL" />
+  </component>
+</project>
.idea/vcs.xml (new file, 6 lines)
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
@@ -2,4 +2,5 @@ data
 *.json
 *.svg
 *.txt
 *.md
+*config*
.zed/settings.json (new file, 35 lines)
@@ -0,0 +1,35 @@
+// Folder-specific settings
+//
+// For a full list of overridable settings, and general information on folder-specific settings,
+// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
+{
+	"lsp": {
+		"deno": {
+			"settings": {
+				"deno": {
+					"enable": true
+				}
+			}
+		}
+	},
+	"languages": {
+		"TypeScript": {
+			"language_servers": [
+				"deno",
+				"!typescript-language-server",
+				"!vtsls",
+				"!eslint"
+			],
+			"formatter": "language_server"
+		},
+		"TSX": {
+			"language_servers": [
+				"deno",
+				"!typescript-language-server",
+				"!vtsls",
+				"!eslint"
+			],
+			"formatter": "language_server"
+		}
+	}
+}
README.md (22 changes)
@@ -6,9 +6,12 @@
 
 纵观整个互联网,对于「中文歌声合成」或「中文虚拟歌手」(常简称为中V或VC)相关信息进行较为系统、全面地整理收集的主要有以下几个网站:
 
-- [萌娘百科](https://zh.moegirl.org.cn/): 收录了大量中V歌曲及歌姬的信息,呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
-- [VCPedia](https://vcpedia.cn/): 由原萌娘百科中文歌声合成编辑团队的部分成员搭建,专属于中文歌声合成相关内容的信息集成站点[^1],呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
-- [VocaDB](https://vocadb.net/): 一个围绕 Vocaloid、UTAU 和其他歌声合成器的协作数据库,其中包含艺术家、唱片、PV 等[^2],其中包含大量中文歌声合成作品。
+- [萌娘百科](https://zh.moegirl.org.cn/):
+  收录了大量中V歌曲及歌姬的信息,呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
+- [VCPedia](https://vcpedia.cn/):
+  由原萌娘百科中文歌声合成编辑团队的部分成员搭建,专属于中文歌声合成相关内容的信息集成站点[^1],呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
+- [VocaDB](https://vocadb.net/): 一个围绕 Vocaloid、UTAU 和其他歌声合成器的协作数据库,其中包含艺术家、唱片、PV
+  等[^2],其中包含大量中文歌声合成作品。
 - [天钿Daily](https://tdd.bunnyxt.com/):一个VC相关数据交流与分享的网站。致力于VC相关数据交流,定期抓取VC相关数据,选取有意义的纬度展示。[^3]
 
 上述网站中,或多或少存在一些不足,例如:
@@ -36,19 +39,22 @@
 
 ### 数据库
 
-中V档案馆使用[PostgreSQL](https://postgresql.org)作为数据库,我们承诺定期导出数据库转储 (dump) 文件并公开,其内容遵从以下协议或条款:
+中V档案馆使用[PostgreSQL](https://postgresql.org)作为数据库,我们承诺定期导出数据库转储 (dump)
+文件并公开,其内容遵从以下协议或条款:
 
 - 数据库中的事实性数据,根据适用法律,不构成受版权保护的内容。中V档案馆放弃一切可能的权利([CC0 1.0 Universal](https://creativecommons.org/publicdomain/zero/1.0/))。
 - 对于数据库中有原创性的内容(如贡献者编辑的描述性内容),如无例外,以[CC BY 4.0协议](https://creativecommons.org/licenses/by/4.0/)提供。
-- 对于引用、摘编或改编自萌娘百科、VCPedia的内容,以与原始协议(CC BY-NC-SA 3.0 CN)兼容的协议[CC BY-NC-SA 4.0协议](https://creativecommons.org/licenses/by-nc-sa/4.0/)提供,并注明原始协议 。
-> 根据原始协议第四条第2项内容,CC BY-NC-SA 4.0协议为与原始协议具有相同授权要素的后续版本(“可适用的协议”)。
+- 对于引用、摘编或改编自萌娘百科、VCPedia的内容,以与原始协议(CC BY-NC-SA 3.0
+  CN)兼容的协议[CC BY-NC-SA 4.0协议](https://creativecommons.org/licenses/by-nc-sa/4.0/)提供,并注明原始协议 。
+  > 根据原始协议第四条第2项内容,CC BY-NC-SA 4.0协议为与原始协议具有相同授权要素的后续版本(“可适用的协议”)。
 - 中V档案馆文档使用[CC BY 4.0协议](https://creativecommons.org/licenses/by/4.0/)。
 
 ### 软件代码
 
 用于构建中V档案馆的软件代码在[AGPL 3.0](https://www.gnu.org/licenses/agpl-3.0.html)许可证下公开,参见[LICENSE](./LICENSE)
 
 
 [^1]: 引用自[VCPedia](https://vcpedia.cn/%E9%A6%96%E9%A1%B5),于[知识共享 署名-非商业性使用-相同方式共享 3.0中国大陆 (CC BY-NC-SA 3.0 CN) 许可协议](https://creativecommons.org/licenses/by-nc-sa/3.0/cn/)下提供。
 
 [^2]: 翻译自[VocaDB](https://vocadb.net/),于[CC BY 4.0协议](https://creativecommons.org/licenses/by/4.0/)下提供。
-[^3]: 引用自[关于 - 天钿Daily](https://tdd.bunnyxt.com/about)
+
+[^3]: 引用自[关于 - 天钿Daily](https://tdd.bunnyxt.com/about)
(deleted file, 12 lines)
@@ -1,12 +0,0 @@
-import { JSX } from "preact";
-import { IS_BROWSER } from "$fresh/runtime.ts";
-
-export function Button(props: JSX.HTMLAttributes<HTMLButtonElement>) {
-	return (
-		<button
-			{...props}
-			disabled={!IS_BROWSER || props.disabled}
-			class="px-2 py-1 border-gray-500 border-2 rounded bg-white hover:bg-gray-200 transition-colors"
-		/>
-	);
-}
data/2025010104_c30_aids.txt (649,902 changes)
File diff suppressed because it is too large
(deleted file, 3 lines)
@@ -1,3 +0,0 @@
-# The data
-
-感谢[天钿Daily](https://tdd.bunnyxt.com/)提供的数据。
(deleted file, 55 lines)
@@ -1,55 +0,0 @@
-import json
-import random
-
-def process_data(input_file, output_file):
-    """
-    从输入文件中读取数据,找出model和human不一致的行,
-    删除"model"键,将"human"键重命名为"label",
-    然后将处理后的数据添加到输出文件中。
-    在写入之前,它会加载output_file中的所有样本,
-    并使用aid键进行去重过滤。
-
-    Args:
-        input_file (str): 输入文件的路径。
-        output_file (str): 输出文件的路径。
-    """
-
-    # 加载output_file中已有的数据,用于去重
-    existing_data = set()
-    try:
-        with open(output_file, 'r', encoding='utf-8') as f_out:
-            for line in f_out:
-                try:
-                    data = json.loads(line)
-                    existing_data.add(data['aid'])
-                except json.JSONDecodeError:
-                    pass  # 忽略JSON解码错误,继续读取下一行
-    except FileNotFoundError:
-        pass  # 如果文件不存在,则忽略
-
-    with open(input_file, 'r', encoding='utf-8') as f_in, open(output_file, 'a', encoding='utf-8') as f_out:
-        for line in f_in:
-            try:
-                data = json.loads(line)
-
-                if data['model'] != data['human'] or random.random() < 0.2:
-                    if data['aid'] not in existing_data:  # 检查aid是否已存在
-                        del data['model']
-                        data['label'] = data['human']
-                        del data['human']
-                        f_out.write(json.dumps(data, ensure_ascii=False) + '\n')
-                        existing_data.add(data['aid'])  # 将新的aid添加到集合中
-
-            except json.JSONDecodeError as e:
-                print(f"JSON解码错误: {e}")
-                print(f"错误行内容: {line.strip()}")
-            except KeyError as e:
-                print(f"KeyError: 键 '{e}' 不存在")
-                print(f"错误行内容: {line.strip()}")
-
-# 调用函数处理数据
-input_file = 'real_test.jsonl'
-output_file = 'labeled_data.jsonl'
-process_data(input_file, output_file)
-print(f"处理完成,结果已写入 {output_file}")
File diff suppressed because one or more lines are too long (applies to three files in this commit)
deno.json (75 changes)
@@ -1,57 +1,22 @@
 {
 	"lock": false,
-	"tasks": {
-		"crawl-raw-bili": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/insertAidsToDB.ts",
-		"crawl-bili-aids": "deno --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run src/db/raw/fetchAids.ts",
-		"check": "deno fmt --check && deno lint && deno check **/*.ts && deno check **/*.tsx",
-		"cli": "echo \"import '\\$fresh/src/dev/cli.ts'\" | deno run --unstable -A -",
-		"manifest": "deno task cli manifest $(pwd)",
-		"start": "deno run -A --watch=static/,routes/ dev.ts",
-		"build": "deno run -A dev.ts build",
-		"preview": "deno run -A main.ts",
-		"update": "deno run -A -r https://fresh.deno.dev/update .",
-		"worker": "deno run --env-file=.env --allow-env --allow-read --allow-ffi --allow-net --allow-write ./src/worker.ts",
-		"adder": "deno run --allow-env --allow-read --allow-ffi --allow-net ./src/jobAdder.ts",
-		"bullui": "deno run --allow-read --allow-env --allow-ffi --allow-net ./src/bullui.ts",
-		"all": "concurrently 'deno task start' 'deno task worker' 'deno task adder' 'deno task bullui'",
-		"test": "deno test ./test/ --allow-env --allow-ffi --allow-read --allow-net --allow-write --allow-run"
-	},
-	"lint": {
-		"rules": {
-			"tags": ["fresh", "recommended"]
-		}
-	},
-	"exclude": ["**/_fresh/*"],
-	"imports": {
-		"@std/assert": "jsr:@std/assert@1",
-		"$fresh/": "https://deno.land/x/fresh@1.7.3/",
-		"preact": "https://esm.sh/preact@10.22.0",
-		"preact/": "https://esm.sh/preact@10.22.0/",
-		"@preact/signals": "https://esm.sh/*@preact/signals@1.2.2",
-		"@preact/signals-core": "https://esm.sh/*@preact/signals-core@1.5.1",
-		"tailwindcss": "npm:tailwindcss@3.4.1",
-		"tailwindcss/": "npm:/tailwindcss@3.4.1/",
-		"tailwindcss/plugin": "npm:/tailwindcss@3.4.1/plugin.js",
-		"$std/": "https://deno.land/std@0.216.0/",
-		"@huggingface/transformers": "npm:@huggingface/transformers@3.0.0",
-		"bullmq": "npm:bullmq",
-		"lib/": "./lib/",
-		"ioredis": "npm:ioredis",
-		"@bull-board/api": "npm:@bull-board/api",
-		"@bull-board/express": "npm:@bull-board/express",
-		"express": "npm:express",
-		"src/": "./src/"
-	},
-	"compilerOptions": {
-		"jsx": "react-jsx",
-		"jsxImportSource": "preact"
-	},
-	"nodeModulesDir": "auto",
-	"fmt": {
-		"useTabs": true,
-		"lineWidth": 120,
-		"indentWidth": 4,
-		"semiColons": true,
-		"proseWrap": "always"
-	}
+	"workspace": ["./packages/crawler", "./packages/frontend", "./packages/backend", "./packages/core"],
+	"nodeModulesDir": "auto",
+	"tasks": {
+		"crawler": "deno task --filter 'crawler' all",
+		"backend": "deno task --filter 'backend' start"
+	},
+	"fmt": {
+		"useTabs": true,
+		"lineWidth": 120,
+		"indentWidth": 4,
+		"semiColons": true,
+		"proseWrap": "always"
+	},
+	"imports": {
+		"@astrojs/node": "npm:@astrojs/node@^9.1.3",
+		"@astrojs/svelte": "npm:@astrojs/svelte@^7.0.8",
+		"@core/db/": "./packages/core/db/",
+		"date-fns": "npm:date-fns@^4.1.0"
+	}
 }
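For context, the new `imports` map above lets workspace code import shared packages by alias rather than by relative path. A minimal sketch (the module path and imported symbol are hypothetical, not from the repository):

```ts
// "@core/db/" resolves to ./packages/core/db/, so deep imports stay stable:
import { db } from "@core/db/init.ts"; // hypothetical module inside packages/core/db/
import { format } from "date-fns"; // resolved through the npm: specifier above

console.log(format(new Date(), "yyyy-MM-dd"), typeof db);
```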
dev.ts (7 lines removed)
@@ -1,7 +0,0 @@
-#!/usr/bin/env -S deno run -A --watch=static/,routes/
-
-import dev from "$fresh/dev.ts";
-import config from "./fresh.config.ts";
-
-import "$std/dotenv/load.ts";
-await dev(import.meta.url, "./main.ts", config);
@@ -17,7 +17,8 @@ layout:
 
 Welcome to the CVSA Documentation!
 
-This doc contains various information about the CVSA project, including technical architecture, tutorials for visitors, etc.
+This doc contains various information about the CVSA project, including technical architecture, tutorials for visitors,
+etc.
 
 ### Jump right in
 
@@ -1,11 +1,11 @@
 # Table of contents
 
-* [Welcome](README.md)
+- [Welcome](README.md)
 
 ## About
 
-* [About CVSA Project](about/this-project.md)
-* [Scope of Inclusion](about/scope-of-inclusion.md)
+- [About CVSA Project](about/this-project.md)
+- [Scope of Inclusion](about/scope-of-inclusion.md)
 
 ## Architecure
 
@@ -17,5 +17,5 @@
 
 ## API Doc
 
-* [Catalog](api-doc/catalog.md)
-* [Songs](api-doc/songs.md)
+- [Catalog](api-doc/catalog.md)
+- [Songs](api-doc/songs.md)
@@ -1,6 +1,7 @@
 # Scope of Inclusion
 
-CVSA contains many aspects of Chinese Vocal Synthesis, including songs, albums, artists (publisher, manipulators, arranger, etc), singers and voice engines / voicebanks. 
+CVSA contains many aspects of Chinese Vocal Synthesis, including songs, albums, artists (publisher, manipulators,
+arranger, etc), singers and voice engines / voicebanks. 
 
 For a **song**, it must meet the following conditions to be included in CVSA:
 
@@ -26,6 +27,11 @@ We define a **Chinese virtual singer** as follows:
 
 ### Using Vocal Synthesizer
 
-To be included in CVSA, at least one line of the song must be produced by a Vocal Synthesizer (including harmony vocals).
+To be included in CVSA, at least one line of the song must be produced by a Vocal Synthesizer (including harmony
+vocals).
 
-We define a vocal synthesizer as a software or system that generates synthesized singing voices by algorithmically modeling vocal characteristics and producing audio from input parameters such as lyrics, pitch, and dynamics, encompassing both waveform-concatenation-based (e.g., VOCALOID, UTAU) and AI-based (e.g., Synthesizer V, ACE Studio) approaches, **but excluding voice conversion tools that solely alter the timbre of pre-existing recordings** (e.g., [so-vits svc](https://github.com/svc-develop-team/so-vits-svc)).
+We define a vocal synthesizer as a software or system that generates synthesized singing voices by algorithmically
+modeling vocal characteristics and producing audio from input parameters such as lyrics, pitch, and dynamics,
+encompassing both waveform-concatenation-based (e.g., VOCALOID, UTAU) and AI-based (e.g., Synthesizer V, ACE Studio)
+approaches, **but excluding voice conversion tools that solely alter the timbre of pre-existing recordings** (e.g.,
+[so-vits svc](https://github.com/svc-develop-team/so-vits-svc)).
@@ -1,11 +1,13 @@
 # About CVSA Project
 
-CVSA (Chinese Vocal Synthesis Archive) aims to collect as much content as possible about the Chinese Vocal Synthesis community in a highly automation-assisted way. 
+CVSA (Chinese Vocal Synthesis Archive) aims to collect as much content as possible about the Chinese Vocal Synthesis
+community in a highly automation-assisted way. 
 
-Unlike existing projects such as [VocaDB](https://vocadb.net), CVSA collects and displays the following content in an automated and manually edited way:
-* Metadata of songs (name, duration, publisher, singer, etc.)
-* Descriptive information of songs (content introduction, creation background, lyrics, etc.)
-* Engagement data snapshots of songs, i.e. historical snapshots of their engagement data (including views, favorites, likes, etc.) on the [Bilibili](https://en.wikipedia.org/wiki/Bilibili) website.
-* Information about artists, albums, vocal synthesizers, and voicebanks.
+Unlike existing projects such as [VocaDB](https://vocadb.net), CVSA collects and displays the following content in an
+automated and manually edited way:
 
+- Metadata of songs (name, duration, publisher, singer, etc.)
+- Descriptive information of songs (content introduction, creation background, lyrics, etc.)
+- Engagement data snapshots of songs, i.e. historical snapshots of their engagement data (including views, favorites,
+  likes, etc.) on the [Bilibili](https://en.wikipedia.org/wiki/Bilibili) website.
+- Information about artists, albums, vocal synthesizers, and voicebanks.
@@ -1,4 +1,3 @@
 # Catalog
 
-* [**Songs**](songs.md)
-
+- [**Songs**](songs.md)
@@ -6,7 +6,8 @@ The AI systems we currently use are:
 
 ### The Filter
 
-Located at `/filter/` under project root dir, it classifies a video in the [category 30](../about/scope-of-inclusion.md#category-30) into the following categories:
+Located at `/filter/` under project root dir, it classifies a video in the
+[category 30](../about/scope-of-inclusion.md#category-30) into the following categories:
 
 * 0: Not related to Chinese vocal synthesis
 * 1: A original song with Chinese vocal synthesis
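For context, a hedged sketch of the label space this classifier produces. The identifier names are illustrative, not from the repository, and the wording of category 2 follows the Chinese edition of this page rather than the truncated hunk above:

```ts
// Illustrative mapping of the Filter's output classes.
type FilterLabel = 0 | 1 | 2;

const FILTER_LABELS: Record<FilterLabel, string> = {
	0: "Not related to Chinese vocal synthesis",
	1: "An original song with Chinese vocal synthesis",
	2: "A cover or remix with Chinese vocal synthesis",
};
```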
@@ -2,7 +2,8 @@
 
 CVSA uses [PostgreSQL](https://www.postgresql.org/) as our database.
 
-All public data of CVSA (excluding users' personal data) is stored in a database named `cvsa_main`, which contains the following tables:
+All public data of CVSA (excluding users' personal data) is stored in a database named `cvsa_main`, which contains the
+following tables:
 
 * songs: stores the main information of songs
 * bili\_user: stores snapshots of Bilibili user information
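For context, a minimal sketch of reading these tables with the deno-postgres client used elsewhere in this commit. The connection parameters and the `aid` join column are assumptions; the real configuration lived in lib/db/pgConfig.ts:

```ts
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";

// Illustrative credentials only.
const client = new Client({ database: "cvsa_main", hostname: "127.0.0.1", user: "postgres" });
await client.connect();

// Count videos in all_data that already have an AI label in labelling_result.
const result = await client.queryObject<{ count: bigint }>(
	`SELECT COUNT(*) AS count FROM all_data a JOIN labelling_result l ON l.aid = a.aid`,
);
console.log(result.rows[0].count);
await client.end();
```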
@@ -1,6 +1,7 @@
 # Type of Song
 
-The **Unrelated type** refers specifically to videos that are not in our [Scope of Inclusion](../../about/scope-of-inclusion.md).
+The **Unrelated type** refers specifically to videos that are not in our
+[Scope of Inclusion](../../about/scope-of-inclusion.md).
 
 ### Table: `songs`
 
|
|||||||
# Table of contents
|
# Table of contents
|
||||||
|
|
||||||
* [欢迎](README.md)
|
- [欢迎](README.md)
|
||||||
|
|
||||||
## 关于 <a href="#about" id="about"></a>
|
## 关于 <a href="#about" id="about"></a>
|
||||||
|
|
||||||
* [关于本项目](about/this-project.md)
|
- [关于本项目](about/this-project.md)
|
||||||
* [收录范围](about/scope-of-inclusion.md)
|
- [收录范围](about/scope-of-inclusion.md)
|
||||||
|
|
||||||
## 技术架构 <a href="#architecture" id="architecture"></a>
|
## 技术架构 <a href="#architecture" id="architecture"></a>
|
||||||
|
|
||||||
* [概览](architecture/overview.md)
|
- [概览](architecture/overview.md)
|
||||||
* [数据库结构](architecture/database-structure/README.md)
|
- [数据库结构](architecture/database-structure/README.md)
|
||||||
* [歌曲类型](architecture/database-structure/type-of-song.md)
|
- [歌曲类型](architecture/database-structure/type-of-song.md)
|
||||||
* [人工智能](architecture/artificial-intelligence.md)
|
- [人工智能](architecture/artificial-intelligence.md)
|
||||||
* [消息队列](architecture/message-queue/README.md)
|
- [消息队列](architecture/message-queue/README.md)
|
||||||
* [VideoTagsQueue队列](architecture/message-queue/video-tags-queue.md)
|
- [VideoTagsQueue队列](architecture/message-queue/video-tags-queue.md)
|
||||||
|
|
||||||
## API 文档 <a href="#api-doc" id="api-doc"></a>
|
## API 文档 <a href="#api-doc" id="api-doc"></a>
|
||||||
|
|
||||||
* [目录](api-doc/catalog.md)
|
- [目录](api-doc/catalog.md)
|
||||||
* [歌曲](api-doc/songs.md)
|
- [歌曲](api-doc/songs.md)
|
||||||
|
@@ -6,7 +6,8 @@
 
 #### VOCALOID·UATU 分区
 
-原则上,中V档案馆中收录的歌曲必须包含在哔哩哔哩 VOCALOID·UTAU 分区(分区ID为30)下的视频中。在某些特殊情况下,此规则可能不是强制的。
+原则上,中V档案馆中收录的歌曲必须包含在哔哩哔哩 VOCALOID·UTAU
+分区(分区ID为30)下的视频中。在某些特殊情况下,此规则可能不是强制的。
 
 #### 至少一行中文
 
@@ -16,4 +17,6 @@
 
 歌曲的至少一行必须由歌声合成器生成(包括和声部分),才能被收录到中V档案馆中。
 
-我们将歌声合成器定义为通过算法建模声音特征并根据输入的歌词、音高等参数生成音频的软件或系统,包括基于波形拼接的(如 VOCALOID、UTAU)和基于 AI 的(如 Synthesizer V、ACE Studio)方法,**但不包括仅改变现有歌声音色的AI声音转换器**(例如 [so-vits svc](https://github.com/svc-develop-team/so-vits-svc))。
+我们将歌声合成器定义为通过算法建模声音特征并根据输入的歌词、音高等参数生成音频的软件或系统,包括基于波形拼接的(如
+VOCALOID、UTAU)和基于 AI 的(如 Synthesizer V、ACE Studio)方法,**但不包括仅改变现有歌声音色的AI声音转换器**(例如
+[so-vits svc](https://github.com/svc-develop-team/so-vits-svc))。
@@ -6,34 +6,33 @@
 
 纵观整个互联网,对于「中文歌声合成」或「中文虚拟歌手」(常简称为中V或VC)相关信息进行较为系统、全面地整理收集的主要有以下几个网站:
 
-* [萌娘百科](https://zh.moegirl.org.cn/): 收录了大量中V歌曲及歌姬的信息,呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
-* [VCPedia](https://vcpedia.cn/): 由原萌娘百科中文歌声合成编辑团队的部分成员搭建,专属于中文歌声合成相关内容的信息集成站点[^1],呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
-* [VocaDB](https://vocadb.net/): [一个围绕 Vocaloid、UTAU 和其他歌声合成器的协作数据库,其中包含艺术家、唱片、PV 等](#user-content-fn-2)[^2],其中包含大量中文歌声合成作品。
-* [天钿Daily](https://tdd.bunnyxt.com/):一个VC相关数据交流与分享的网站。致力于VC相关数据交流,定期抓取VC相关数据,选取有意义的纬度展示。
+- [萌娘百科](https://zh.moegirl.org.cn/):
+  收录了大量中V歌曲及歌姬的信息,呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
+- [VCPedia](https://vcpedia.cn/):
+  由原萌娘百科中文歌声合成编辑团队的部分成员搭建,专属于中文歌声合成相关内容的信息集成站点[^1],呈现形式为传统维基(基于[MediaWiki](https://www.mediawiki.org/))。
+- [VocaDB](https://vocadb.net/):
+  [一个围绕 Vocaloid、UTAU 和其他歌声合成器的协作数据库,其中包含艺术家、唱片、PV 等](#user-content-fn-2)[^2],其中包含大量中文歌声合成作品。
+- [天钿Daily](https://tdd.bunnyxt.com/):一个VC相关数据交流与分享的网站。致力于VC相关数据交流,定期抓取VC相关数据,选取有意义的纬度展示。
 
 上述网站中,或多或少存在一些不足,例如:
 
-* 萌娘百科、VCPedia受限于传统维基,绝大多数内容依赖人工编辑。
-* VocaDB基于结构化数据库构建,由此可以依赖程序生成一些信息,但**条目收录**仍然完全依赖人工完成。
-* VocaDB主要专注于元数据展示,少有关于歌曲、作者等的描述性的文字,也缺乏描述性的背景信息。
-* 天钿Daily只展示歌曲的统计数据及历史趋势,没有关于歌曲其它信息的收集。
+- 萌娘百科、VCPedia受限于传统维基,绝大多数内容依赖人工编辑。
+- VocaDB基于结构化数据库构建,由此可以依赖程序生成一些信息,但**条目收录**仍然完全依赖人工完成。
+- VocaDB主要专注于元数据展示,少有关于歌曲、作者等的描述性的文字,也缺乏描述性的背景信息。
+- 天钿Daily只展示歌曲的统计数据及历史趋势,没有关于歌曲其它信息的收集。
 
 因此,**中V档案馆**吸取前人经验,克服上述网站的不足,希望做到:
 
-* 歌曲收录(指发现歌曲并创建条目)的完全自动化
-* 歌曲元信息提取的高度自动化
-* 歌曲统计数据收集的完全自动化
-* 在程序辅助的同时欢迎并鼓励贡献者参与编辑(主要为描述性内容)或纠错
-* 在适当的许可声明下,引用来自上述源的数据,使内容更加全面、丰富。
+- 歌曲收录(指发现歌曲并创建条目)的完全自动化
+- 歌曲元信息提取的高度自动化
+- 歌曲统计数据收集的完全自动化
+- 在程序辅助的同时欢迎并鼓励贡献者参与编辑(主要为描述性内容)或纠错
+- 在适当的许可声明下,引用来自上述源的数据,使内容更加全面、丰富。
 
-***
+---
 
 本文在[CC BY-NC-SA 4.0协议](https://creativecommons.org/licenses/by-nc-sa/4.0/)提供。
 
 
 [^1]: 引用自[VCPedia](https://vcpedia.cn/%E9%A6%96%E9%A1%B5),于[知识共享 署名-非商业性使用-相同方式共享 3.0中国大陆 (CC BY-NC-SA 3.0 CN) 许可协议](https://creativecommons.org/licenses/by-nc-sa/3.0/cn/)下提供。
 
 [^2]: 翻译自[VocaDB](https://vocadb.net/),于[CC BY 4.0协议](https://creativecommons.org/licenses/by/4.0/)下提供。
@@ -1,3 +1,3 @@
 # 目录
 
-* [歌曲](songs.md)
+- [歌曲](songs.md)
@@ -6,8 +6,8 @@ CVSA 的自动化工作流高度依赖人工智能进行信息提取和分类。
 
 #### Filter
 
 位于项目根目录下的 `/filter/`,它将 [30 分区](../about/scope-of-inclusion.md#vocaloiduatu-fen-qu) 中的视频分为以下类别:
 
-* 0:与中文人声合成无关
-* 1:中文人声合成原创曲
-* 2:中文人声合成的翻唱/混音歌曲
+- 0:与中文人声合成无关
+- 1:中文人声合成原创曲
+- 2:中文人声合成的翻唱/混音歌曲
@@ -4,7 +4,7 @@ CVSA 使用 [PostgreSQL](https://www.postgresql.org/) 作为数据库。
 
 CVSA 的所有公开数据(不包括用户的个人数据)都存储在名为 `cvsa_main` 的数据库中,该数据库包含以下表:
 
-* songs:存储歌曲的主要信息
-* bili\_user:存储 Bilibili 用户信息快照
-* all\_data:[分区 30](../../about/scope-of-inclusion.md#vocaloiduatu-fen-qu) 中所有视频的元数据。
-* labelling\_result:包含由我们的 AI 系统 标记的 `all_data` 中视频的标签。
+- songs:存储歌曲的主要信息
+- bili\_user:存储 Bilibili 用户信息快照
+- all\_data:[分区 30](../../about/scope-of-inclusion.md#vocaloiduatu-fen-qu) 中所有视频的元数据。
+- labelling\_result:包含由我们的 AI 系统 标记的 `all_data` 中视频的标签。
@@ -7,18 +7,18 @@
 `songs` 表格中使用的 `type` 列。
 
 | 类型 | 说明 |
-| -- | ---------- |
+| ---- | ------------ |
 | 0  | 不相关       |
 | 1  | 原创         |
 | 2  | 翻唱 (Cover) |
 | 3  | 混音 (Remix) |
 | 4  | 纯音乐       |
 | 10 | 其他         |
 
 #### 表格:`labelling_result`
 
 | 标签 | 说明 |
-| -- | ----------- |
+| ---- | ------------------ |
 | 0  | AI 标记:不相关    |
 | 1  | AI 标记:原创      |
 | 2  | AI 标记:翻唱/混音 |
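For context, the `type` codes above translate directly into a lookup table; a hedged TypeScript sketch (the identifier name is hypothetical, the values mirror the table):

```ts
// Hypothetical mapping of the songs.type column, mirroring the table above.
const SONG_TYPE: Record<number, string> = {
	0: "不相关", // unrelated
	1: "原创", // original
	2: "翻唱 (Cover)",
	3: "混音 (Remix)",
	4: "纯音乐", // instrumental
	10: "其他", // other
};
```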
@@ -1,2 +1 @@
 # 消息队列
-
|
|||||||
import torch
|
|
||||||
from model2vec import StaticModel
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_batch(batch_data, device="cpu"):
|
|
||||||
"""
|
|
||||||
将输入的 batch_data 转换为模型所需的输入格式 [batch_size, num_channels, embedding_dim]。
|
|
||||||
|
|
||||||
参数:
|
|
||||||
batch_data (dict): 输入的 batch 数据,格式为 {
|
|
||||||
"title": [text1, text2, ...],
|
|
||||||
"description": [text1, text2, ...],
|
|
||||||
"tags": [text1, text2, ...],
|
|
||||||
"author_info": [text1, text2, ...]
|
|
||||||
}
|
|
||||||
device (str): 模型运行的设备(如 "cpu" 或 "cuda")。
|
|
||||||
|
|
||||||
返回:
|
|
||||||
torch.Tensor: 形状为 [batch_size, num_channels, embedding_dim] 的张量。
|
|
||||||
"""
|
|
||||||
# 1. 对每个通道的文本分别编码
|
|
||||||
channel_embeddings = []
|
|
||||||
model = StaticModel.from_pretrained("./model/embedding/")
|
|
||||||
for channel in ["title", "description", "tags", "author_info"]:
|
|
||||||
texts = batch_data[channel] # 获取当前通道的文本列表
|
|
||||||
embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device) # 编码为 [batch_size, embedding_dim]
|
|
||||||
channel_embeddings.append(embeddings)
|
|
||||||
|
|
||||||
# 2. 将编码结果堆叠为 [batch_size, num_channels, embedding_dim]
|
|
||||||
batch_tensor = torch.stack(channel_embeddings, dim=1) # 在 dim=1 上堆叠
|
|
||||||
return batch_tensor
|
|
(deleted file, 58 lines)
@@ -1,58 +0,0 @@
-import torch
-import torch.nn as nn
-
-class VideoClassifierV3(nn.Module):
-    def __init__(self, embedding_dim=1024, hidden_dim=256, output_dim=3):
-        super().__init__()
-        self.num_channels = 4
-        self.channel_names = ['title', 'description', 'tags', 'author_info']
-
-        # 改进1:带温度系数的通道权重(比原始固定权重更灵活)
-        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
-        self.temperature = 2.0  # 可调节的平滑系数
-
-        # 改进2:更稳健的全连接结构
-        self.fc = nn.Sequential(
-            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
-            nn.BatchNorm1d(hidden_dim*2),
-            nn.Dropout(0.2),
-            nn.ReLU(),
-            nn.Linear(hidden_dim*2, hidden_dim),
-            nn.LayerNorm(hidden_dim),
-            nn.Linear(hidden_dim, output_dim)
-        )
-
-        # 改进3:输出层初始化
-        nn.init.xavier_uniform_(self.fc[-1].weight)
-        nn.init.zeros_(self.fc[-1].bias)
-
-    def forward(self, input_texts, sentence_transformer):
-        # 合并所有通道文本进行批量编码
-        all_texts = [text for channel in self.channel_names for text in input_texts[channel]]
-
-        # 使用SentenceTransformer生成嵌入(保持冻结)
-        with torch.no_grad():
-            task = "classification"
-            embeddings = torch.tensor(
-                sentence_transformer.encode(all_texts, task=task),
-                device=next(self.parameters()).device
-            )
-
-        # 分割嵌入并加权
-        split_sizes = [len(input_texts[name]) for name in self.channel_names]
-        channel_features = torch.split(embeddings, split_sizes, dim=0)
-        channel_features = torch.stack(channel_features, dim=1)  # [batch, 4, 1024]
-
-        # 改进4:带温度系数的softmax加权
-        weights = torch.softmax(self.channel_weights / self.temperature, dim=0)
-        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)
-
-        # 拼接特征
-        combined = weighted_features.view(weighted_features.size(0), -1)
-
-        # 全连接层
-        return self.fc(combined)
-
-    def get_channel_weights(self):
-        """获取各通道权重(带温度调节)"""
-        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
(deleted file, 111 lines)
@@ -1,111 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class VideoClassifierV3_4(nn.Module):
-    def __init__(self, embedding_dim=1024, hidden_dim=512, output_dim=3):
-        super().__init__()
-        self.num_channels = 4
-        self.channel_names = ['title', 'description', 'tags', 'author_info']
-
-        # 可学习温度系数
-        self.temperature = nn.Parameter(torch.tensor(1.7))
-
-        # 带约束的通道权重(使用Sigmoid替代Softmax)
-        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
-
-        # 增强的非线性层
-        self.fc = nn.Sequential(
-            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
-            nn.BatchNorm1d(hidden_dim*2),
-            nn.Dropout(0.3),
-            nn.GELU(),
-            nn.Linear(hidden_dim*2, hidden_dim),
-            nn.BatchNorm1d(hidden_dim),
-            nn.Dropout(0.2),
-            nn.GELU(),
-            nn.Linear(hidden_dim, output_dim)
-        )
-
-        # 权重初始化
-        self._init_weights()
-
-    def _init_weights(self):
-        for layer in self.fc:
-            if isinstance(layer, nn.Linear):
-                # 使用ReLU的初始化参数(GELU的近似)
-                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')  # 修改这里
-
-                # 或者使用Xavier初始化(更适合通用场景)
-                # nn.init.xavier_normal_(layer.weight, gain=nn.init.calculate_gain('relu'))
-
-                nn.init.zeros_(layer.bias)
-
-    def forward(self, input_texts, sentence_transformer):
-        # 合并文本进行批量编码
-        all_texts = [text for channel in self.channel_names for text in input_texts[channel]]
-
-        # 冻结的文本编码
-        with torch.no_grad():
-            embeddings = torch.tensor(
-                sentence_transformer.encode(all_texts),
-                device=next(self.parameters()).device
-            )
-
-        # 分割并加权通道特征
-        split_sizes = [len(input_texts[name]) for name in self.channel_names]
-        channel_features = torch.split(embeddings, split_sizes, dim=0)
-        channel_features = torch.stack(channel_features, dim=1)
-
-        # 自适应通道权重(Sigmoid约束)
-        weights = torch.sigmoid(self.channel_weights)  # [0,1]范围
-        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)
-
-        # 特征拼接
-        combined = weighted_features.view(weighted_features.size(0), -1)
-
-        return self.fc(combined)
-
-    def get_channel_weights(self):
-        """获取各通道权重(带温度调节)"""
-        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
-
-
-class AdaptiveRecallLoss(nn.Module):
-    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
-        """
-        Args:
-            class_weights (torch.Tensor): 类别权重
-            alpha (float): 召回率调节因子(0-1)
-            gamma (float): Focal Loss参数
-            fp_penalty (float): 类别0假阳性惩罚强度
-        """
-        super().__init__()
-        self.class_weights = class_weights
-        self.alpha = alpha
-        self.gamma = gamma
-        self.fp_penalty = fp_penalty
-
-    def forward(self, logits, targets):
-        # 基础交叉熵损失
-        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')
-
-        # Focal Loss组件
-        pt = torch.exp(-ce_loss)
-        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
-
-        # 召回率增强(对困难样本加权)
-        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
-        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
-        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)
-
-        # 类别0假阳性惩罚
-        probs = F.softmax(logits, dim=1)
-        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
-        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()
-
-        # 总损失
-        total_loss = recall_loss.mean() + fp_loss / len(targets)
-
-        return total_loss
filter/old.py (148 lines removed)
@@ -1,148 +0,0 @@
-import os
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
-
-import torch
-import torch.nn as nn
-from sentence_transformers import SentenceTransformer
-from transformers import AutoTokenizer
-import json
-from torch.utils.data import Dataset, DataLoader
-import numpy as np
-
-class VideoDataset(Dataset):
-    def __init__(self, data_path, sentence_transformer):
-        self.data = []
-        self.sentence_transformer = sentence_transformer
-        with open(data_path, "r", encoding="utf-8") as f:
-            for line in f:
-                self.data.append(json.loads(line))
-
-    def __len__(self):
-        return len(self.data)
-
-    def __getitem__(self, idx):
-        item = self.data[idx]
-        title = item["title"]
-        description = item["description"]
-        tags = item["tags"]
-        label = item["label"]
-
-        # 获取每个特征的嵌入
-        title_embedding = self.get_embedding(title)
-        description_embedding = self.get_embedding(description)
-        tags_embedding = self.get_embedding(" ".join(tags))
-
-        # 将嵌入连接起来
-        combined_embedding = torch.cat([title_embedding, description_embedding, tags_embedding], dim=0)
-
-        return combined_embedding, label
-
-    def get_embedding(self, text):
-        # 使用SentenceTransformer生成嵌入
-        embedding = self.sentence_transformer.encode(text)
-        return torch.tensor(embedding)
-
-class VideoClassifier(nn.Module):
-    def __init__(self, embedding_dim=768, hidden_dim=256, output_dim=3):
-        super(VideoClassifier, self).__init__()
-        # 每个特征的嵌入维度是embedding_dim,总共有3个特征
-        total_embedding_dim = embedding_dim * 3
-
-        # 全连接层
-        self.fc1 = nn.Linear(total_embedding_dim, hidden_dim)
-        self.fc2 = nn.Linear(hidden_dim, output_dim)
-        self.log_softmax = nn.LogSoftmax(dim=1)
-
-    def forward(self, embedding_features):
-        # 全连接层
-        x = torch.relu(self.fc1(embedding_features))
-        output = self.fc2(x)
-        output = self.log_softmax(output)
-        return output
-
-def train(model, dataloader, criterion, optimizer, device):
-    model.train()
-    total_loss = 0
-    correct = 0
-    total = 0
-    for embedding_features, labels in dataloader:
-        embedding_features = embedding_features.to(device)
-        labels = labels.to(device)
-        optimizer.zero_grad()
-        outputs = model(embedding_features)
-        loss = criterion(outputs, labels)
-        loss.backward()
-        optimizer.step()
-        total_loss += loss.item()
-        _, predicted = torch.max(outputs, 1)
-        correct += (predicted == labels).sum().item()
-        total += labels.size(0)
-    avg_loss = total_loss / len(dataloader)
-    accuracy = correct / total
-    return avg_loss, accuracy
-
-def validate(model, dataloader, criterion, device):
-    model.eval()
-    total_loss = 0
-    correct = 0
-    total = 0
-    with torch.no_grad():
-        for embedding_features, labels in dataloader:
-            embedding_features = embedding_features.to(device)
-            labels = labels.to(device)
-            outputs = model(embedding_features)
-            loss = criterion(outputs, labels)
-            total_loss += loss.item()
-            _, predicted = torch.max(outputs, 1)
-            correct += (predicted == labels).sum().item()
-            total += labels.size(0)
-    avg_loss = total_loss / len(dataloader)
-    accuracy = correct / total
-    return avg_loss, accuracy
-
-# 超参数
-hidden_dim = 256
-output_dim = 3
-batch_size = 32
-num_epochs = 10
-learning_rate = 0.001
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-# 加载数据集
-tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3")
-sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
-dataset = VideoDataset("labeled_data.jsonl", sentence_transformer=sentence_transformer)
-dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
-# 初始化模型
-model = VideoClassifier(embedding_dim=768, hidden_dim=256, output_dim=3).to(device)
-
-# 损失函数和优化器
-criterion = nn.CrossEntropyLoss()
-optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
-num_epochs = 5
-# 训练和验证
-for epoch in range(num_epochs):
-    train_loss, train_acc = train(model, dataloader, criterion, optimizer, device)
-    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
-
-# 保存模型
-torch.save(model.state_dict(), "video_classifier.pth")
-model.eval()  # 设置为评估模式
-
-# 2. 定义推理函数
-def predict(model, sentence_transformer, title, description, tags, device):
-    # 将输入数据转换为嵌入
-    title_embedding = torch.tensor(sentence_transformer.encode(title)).to(device)
-    description_embedding = torch.tensor(sentence_transformer.encode(description)).to(device)
-    tags_embedding = torch.tensor(sentence_transformer.encode(" ".join(tags))).to(device)
-
-    # 将嵌入连接起来
-    combined_embedding = torch.cat([title_embedding, description_embedding, tags_embedding], dim=0).unsqueeze(0)
-
-    # 推理
-    with torch.no_grad():
-        output = model(combined_embedding)
-        _, predicted = torch.max(output, 1)
-
-    return predicted.item()
(deleted file, 6 lines)
@@ -1,6 +0,0 @@
-import { defineConfig } from "$fresh/server.ts";
-import tailwind from "$fresh/plugins/tailwind.ts";
-
-export default defineConfig({
-	plugins: [tailwind()],
-});
fresh.gen.ts (27 lines removed)
@@ -1,27 +0,0 @@
-// DO NOT EDIT. This file is generated by Fresh.
-// This file SHOULD be checked into source version control.
-// This file is automatically updated during development when running `dev.ts`.
-
-import * as $_404 from "./routes/_404.tsx";
-import * as $_app from "./routes/_app.tsx";
-import * as $api_joke from "./routes/api/joke.ts";
-import * as $greet_name_ from "./routes/greet/[name].tsx";
-import * as $index from "./routes/index.tsx";
-import * as $Counter from "./islands/Counter.tsx";
-import type { Manifest } from "$fresh/server.ts";
-
-const manifest = {
-	routes: {
-		"./routes/_404.tsx": $_404,
-		"./routes/_app.tsx": $_app,
-		"./routes/api/joke.ts": $api_joke,
-		"./routes/greet/[name].tsx": $greet_name_,
-		"./routes/index.tsx": $index,
-	},
-	islands: {
-		"./islands/Counter.tsx": $Counter,
-	},
-	baseUrl: import.meta.url,
-} satisfies Manifest;
-
-export default manifest;
(deleted file, 16 lines)
@@ -1,16 +0,0 @@
-import type { Signal } from "@preact/signals";
-import { Button } from "../components/Button.tsx";
-
-interface CounterProps {
-	count: Signal<number>;
-}
-
-export default function Counter(props: CounterProps) {
-	return (
-		<div class="flex gap-8 py-6">
-			<Button onClick={() => props.count.value -= 1}>-1</Button>
-			<p class="text-3xl tabular-nums">{props.count}</p>
-			<Button onClick={() => props.count.value += 1}>+1</Button>
-		</div>
-	);
-}
(deleted file, 61 lines)
@@ -1,61 +0,0 @@
-import { Client, Transaction } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
-import { AllDataType } from "lib/db/schema.d.ts";
-import logger from "lib/log/logger.ts";
-import { parseTimestampFromPsql } from "lib/utils/formatTimestampToPostgre.ts";
-
-export async function videoExistsInAllData(client: Client, aid: number) {
-	return await client.queryObject<{ exists: boolean }>(`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1)`, [aid])
-		.then((result) => result.rows[0].exists);
-}
-
-export async function insertIntoAllData(client: Client, data: AllDataType) {
-	logger.log(`inserted ${data.aid}`, "db-all_data");
-	return await client.queryObject(
-		`INSERT INTO all_data (aid, bvid, description, uid, tags, title, published_at)
-		 VALUES ($1, $2, $3, $4, $5, $6, $7)
-		 ON CONFLICT (aid) DO NOTHING`,
-		[data.aid, data.bvid, data.description, data.uid, data.tags, data.title, data.published_at],
-	);
-}
-
-export async function getLatestVideoTimestampFromAllData(client: Client) {
-	return await client.queryObject<{ published_at: string }>(
-		`SELECT published_at FROM all_data ORDER BY published_at DESC LIMIT 1`,
-	)
-		.then((result) => {
-			const date = new Date(result.rows[0].published_at);
-			if (isNaN(date.getTime())) {
-				return null;
-			}
-			return date.getTime();
-		});
-}
-
-export async function videoTagsIsNull(client: Client | Transaction, aid: number) {
-	return await client.queryObject<{ exists: boolean }>(
-		`SELECT EXISTS(SELECT 1 FROM all_data WHERE aid = $1 AND tags IS NULL)`,
-		[aid],
-	).then((result) => result.rows[0].exists);
-}
-
-export async function updateVideoTags(client: Client | Transaction, aid: number, tags: string[]) {
-	return await client.queryObject(
-		`UPDATE all_data SET tags = $1 WHERE aid = $2`,
-		[tags.join(","), aid],
-	);
-}
-
-export async function getNullVideoTagsList(client: Client) {
-	const queryResult = await client.queryObject<{ aid: number; published_at: string }>(
-		`SELECT aid, published_at FROM all_data WHERE tags IS NULL`,
-	);
-	const rows = queryResult.rows;
-	return rows.map(
-		(row) => {
-			return {
-				aid: Number(row.aid),
-				published_at: parseTimestampFromPsql(row.published_at),
-			};
-		},
-	);
-}
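For context, a minimal usage sketch of the helpers above, wired to the `db` pool from lib/db/init.ts (also removed in this commit); the placeholder tags stand in for a real tag fetch:

```ts
import { db } from "lib/db/init.ts";
import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";

const client = await db.connect();
try {
	// Backfill tags for every video that has none yet.
	for (const { aid } of await getNullVideoTagsList(client)) {
		await updateVideoTags(client, aid, ["VOCALOID"]); // placeholder tags, not a real fetch
	}
} finally {
	client.release();
}
```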
(deleted file, 6 lines)
@@ -1,6 +0,0 @@
-import { Pool } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
-import {postgresConfig} from "lib/db/pgConfig.ts";
-
-const pool = new Pool(postgresConfig, 32);
-
-export const db = pool;
(deleted file, 3 lines)
@@ -1,3 +0,0 @@
-import { Redis } from "ioredis";
-
-export const redis = new Redis({ maxRetriesPerRequest: null });
lib/db/schema.d.ts (vendored, 9 lines removed)
@@ -1,9 +0,0 @@
-export interface AllDataType {
-	aid: number;
-	bvid: string | null;
-	description: string | null;
-	uid: number | null;
-	tags: string | null;
-	title: string | null;
-	published_at: string | null;
-}
(deleted file, 19 lines)
@@ -1,19 +0,0 @@
-import { SentenceTransformer } from "./model.ts"; // Changed import path
-
-async function main() {
-	const sentenceTransformer = await SentenceTransformer.from_pretrained(
-		"mixedbread-ai/mxbai-embed-large-v1",
-	);
-	const outputs = await sentenceTransformer.encode([
-		"Hello world",
-		"How are you guys doing?",
-		"Today is Friday!",
-	]);
-
-	// @ts-ignore
-	console.log(outputs["last_hidden_state"]);
-
-	return outputs;
-}
-
-main(); // Keep main function call if you want this file to be runnable directly for testing.
(deleted file, 40 lines)
@@ -1,40 +0,0 @@
-// lib/ml/sentence_transformer_model.ts
-import { AutoModel, AutoTokenizer, PretrainedOptions } from "@huggingface/transformers";
-
-export class SentenceTransformer {
-	constructor(
-		private readonly tokenizer: AutoTokenizer,
-		private readonly model: AutoModel,
-	) {}
-
-	static async from_pretrained(
-		modelName: string,
-		options?: PretrainedOptions,
-	): Promise<SentenceTransformer> {
-		if (!options) {
-			options = {
-				progress_callback: undefined,
-				cache_dir: undefined,
-				local_files_only: false,
-				revision: "main",
-			};
-		}
-		const tokenizer = await AutoTokenizer.from_pretrained(modelName, options);
-		const model = await AutoModel.from_pretrained(modelName, options);
-
-		return new SentenceTransformer(tokenizer, model);
-	}
-
-	async encode(sentences: string[]): Promise<any> { // Changed return type to 'any' for now to match console.log output
-		//@ts-ignore
-		const modelInputs = await this.tokenizer(sentences, {
-			padding: true,
-			truncation: true,
-		});
-
-		//@ts-ignore
-		const outputs = await this.model(modelInputs);
-
-		return outputs;
-	}
-}
(deleted file, 34 lines)
@@ -1,34 +0,0 @@
-import { Tensor } from "@huggingface/transformers";
-//@ts-ignore
-import { Callable } from "@huggingface/transformers/src/utils/core.js"; // Keep as is for now, might need adjustment
-
-export interface PoolingConfig {
-	word_embedding_dimension: number;
-	pooling_mode_cls_token: boolean;
-	pooling_mode_mean_tokens: boolean;
-	pooling_mode_max_tokens: boolean;
-	pooling_mode_mean_sqrt_len_tokens: boolean;
-}
-
-export interface PoolingInput {
-	token_embeddings: Tensor;
-	attention_mask: Tensor;
-}
-
-export interface PoolingOutput {
-	sentence_embedding: Tensor;
-}
-
-export class Pooling extends Callable {
-	constructor(private readonly config: PoolingConfig) {
-		super();
-	}
-
-	// async _call(inputs: any) { // Keep if pooling functionality is needed
-	// 	return this.forward(inputs);
-	// }
-
-	// async forward(inputs: PoolingInput): PoolingOutput { // Keep if pooling functionality is needed
-
-	// }
-}
(deleted file, 32 lines)
@@ -1,32 +0,0 @@
-import { AutoModel, AutoTokenizer, Tensor } from '@huggingface/transformers';
-
-const modelName = "alikia2x/jina-embedding-v3-m2v-1024";
-
-const modelConfig = {
-	config: { model_type: 'model2vec' },
-	dtype: 'fp32',
-	revision: 'refs/pr/1',
-	cache_dir: undefined,
-	local_files_only: true,
-};
-const tokenizerConfig = {
-	revision: 'refs/pr/2'
-};
-
-const model = await AutoModel.from_pretrained(modelName, modelConfig);
-const tokenizer = await AutoTokenizer.from_pretrained(modelName, tokenizerConfig);
-
-const texts = ['hello', 'hello world'];
-const { input_ids } = await tokenizer(texts, { add_special_tokens: false, return_tensor: false });
-
-const cumsum = arr => arr.reduce((acc, num, i) => [...acc, num + (acc[i - 1] || 0)], []);
-const offsets = [0, ...cumsum(input_ids.slice(0, -1).map(x => x.length))];
-
-const flattened_input_ids = input_ids.flat();
-const modelInputs = {
-	input_ids: new Tensor('int64', flattened_input_ids, [flattened_input_ids.length]),
-	offsets: new Tensor('int64', offsets, [offsets.length])
-};
-
-const { embeddings } = await model(modelInputs);
-console.log(embeddings.tolist()); // output matches python version
(deleted file, 52 lines)
@@ -1,52 +0,0 @@
-import { Job } from "bullmq";
-import { insertLatestVideos } from "lib/task/insertLatestVideo.ts";
-import { LatestVideosQueue } from "lib/mq/index.ts";
-import { MINUTE } from "$std/datetime/constants.ts";
-import { db } from "lib/db/init.ts";
-import { truncate } from "lib/utils/truncate.ts";
-import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
-import logger from "lib/log/logger.ts";
-import { lockManager } from "lib/mq/lockManager.ts";
-
-const delayMap = [5, 10, 15, 30, 60, 60];
-
-const updateQueueInterval = async (failedCount: number, delay: number) => {
-	logger.log(`job:getLatestVideos added to queue, delay: ${(delay / MINUTE).toFixed(2)} minutes.`, "mq");
-	await LatestVideosQueue.upsertJobScheduler("getLatestVideos", {
-		every: delay,
-	}, {
-		data: {
-			failedCount: failedCount,
-		},
-	});
-	return;
-};
-
-const executeTask = async (client: Client, failedCount: number) => {
-	const result = await insertLatestVideos(client);
-	failedCount = result !== 0 ? truncate(failedCount + 1, 0, 5) : 0;
-	if (failedCount !== 0) {
-		await updateQueueInterval(failedCount, delayMap[failedCount] * MINUTE);
-	}
-	return;
-};
-
-export const getLatestVideosWorker = async (job: Job) => {
-	if (await lockManager.isLocked("getLatestVideos")) {
-		logger.log("job:getLatestVideos is locked, skipping.", "mq");
-		return;
-	}
-
-	lockManager.acquireLock("getLatestVideos");
-
-	const failedCount = (job.data.failedCount ?? 0) as number;
-	const client = await db.connect();
-
-	try {
-		await executeTask(client, failedCount);
-	} finally {
-		client.release();
-		lockManager.releaseLock("getLatestVideos");
-	}
-	return;
-};
@@ -1,99 +0,0 @@
import { Job } from "bullmq";
import { VideoTagsQueue } from "lib/mq/index.ts";
import { DAY, HOUR, MINUTE, SECOND } from "$std/datetime/constants.ts";
import { db } from "lib/db/init.ts";
import { truncate } from "lib/utils/truncate.ts";
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import logger from "lib/log/logger.ts";
import { getNullVideoTagsList, updateVideoTags } from "lib/db/allData.ts";
import { getVideoTags } from "lib/net/getVideoTags.ts";
import { NetSchedulerError } from "lib/mq/scheduler.ts";
import { WorkerError } from "src/worker.ts";

const delayMap = [0.5, 3, 5, 15, 30, 60];
const getJobPriority = (diff: number) => {
    let priority;
    if (diff > 14 * DAY) {
        priority = 10;
    } else if (diff > 7 * DAY) {
        priority = 7;
    } else if (diff > DAY) {
        priority = 5;
    } else if (diff > 6 * HOUR) {
        priority = 3;
    } else if (diff > HOUR) {
        priority = 2;
    } else {
        priority = 1;
    }
    return priority;
};

const executeTask = async (client: Client, aid: number, failedCount: number, job: Job) => {
    try {
        const result = await getVideoTags(aid);
        if (!result) {
            failedCount = truncate(failedCount + 1, 0, 5);
            const delay = delayMap[failedCount] * MINUTE;
            logger.log(
                `job:getVideoTags added to queue, delay: ${delayMap[failedCount]} minutes.`,
                "mq",
            );
            await VideoTagsQueue.add("getVideoTags", { aid, failedCount }, { delay, priority: 6 - failedCount });
            return 1;
        }
        await updateVideoTags(client, aid, result);
        logger.log(`Fetched tags for aid: ${aid}`, "task");
        return 0;
    } catch (e) {
        if (!(e instanceof NetSchedulerError)) {
            throw new WorkerError(<Error> e, "task", "getVideoTags/fn:executeTask");
        }
        const err = e as NetSchedulerError;
        if (err.code === "NO_AVAILABLE_PROXY" || err.code === "PROXY_RATE_LIMITED") {
            logger.warn(`No available proxy for fetching tags, delayed. aid: ${aid}`, "task");
            await VideoTagsQueue.add("getVideoTags", { aid, failedCount }, {
                delay: 25 * SECOND * Math.random() + 5 * SECOND,
                priority: job.priority,
            });
            return 2;
        }
        throw new WorkerError(err, "task", "getVideoTags/fn:executeTask");
    }
};

export const getVideoTagsWorker = async (job: Job) => {
    const failedCount = (job.data.failedCount ?? 0) as number;
    const client = await db.connect();
    const aid = job.data.aid;
    if (!aid) {
        return 3;
    }

    const v = await executeTask(client, aid, failedCount, job);
    client.release();
    return v;
};

export const getVideoTagsInitializer = async () => {
    const client = await db.connect();
    const videos = await getNullVideoTagsList(client);
    if (videos.length == 0) {
        return 4;
    }
    const count = await VideoTagsQueue.getJobCounts("wait", "delayed", "active");
    const total = count.delayed + count.active + count.wait;
    const max = 15;
    const rest = truncate(max - total, 0, max);

    let i = 0;
    for (const video of videos) {
        if (i > rest) return 100 + i;
        const aid = video.aid;
        const timestamp = video.published_at;
        const diff = Date.now() - timestamp;
        await VideoTagsQueue.add("getVideoTags", { aid }, { priority: getJobPriority(diff) });
        i++;
    }
    return 0;
};
@@ -1 +0,0 @@
export * from "lib/mq/exec/getLatestVideos.ts";
@@ -1,5 +0,0 @@
import { Queue } from "bullmq";

export const LatestVideosQueue = new Queue("latestVideos");

export const VideoTagsQueue = new Queue("videoTags");
@@ -1,22 +0,0 @@
import { MINUTE, SECOND } from "$std/datetime/constants.ts";
import { LatestVideosQueue, VideoTagsQueue } from "lib/mq/index.ts";
import logger from "lib/log/logger.ts";

async function configGetLatestVideos() {
    await LatestVideosQueue.upsertJobScheduler("getLatestVideos", {
        every: 1 * MINUTE,
    });
}

async function configGetVideosTags() {
    await VideoTagsQueue.upsertJobScheduler("getVideosTags", {
        every: 30 * SECOND,
        immediately: true,
    });
}

export async function initMQ() {
    await configGetLatestVideos();
    await configGetVideosTags();
    logger.log("Message queue initialized.");
}
@@ -1,164 +0,0 @@
import logger from "lib/log/logger.ts";
import { RateLimiter } from "lib/mq/rateLimiter.ts";
import { SlidingWindow } from "lib/mq/slidingWindow.ts";
import { redis } from "lib/db/redis.ts";
import Redis from "ioredis";

interface Proxy {
    type: string;
    task: string;
    limiter?: RateLimiter;
}

interface ProxiesMap {
    [name: string]: Proxy;
}

type NetSchedulerErrorCode =
    | "NO_AVAILABLE_PROXY"
    | "PROXY_RATE_LIMITED"
    | "PROXY_NOT_FOUND"
    | "FETCH_ERROR"
    | "NOT_IMPLEMENTED";

export class NetSchedulerError extends Error {
    public code: NetSchedulerErrorCode;
    public rawError: unknown | undefined;
    constructor(message: string, errorCode: NetSchedulerErrorCode, rawError?: unknown) {
        super(message);
        this.name = "NetSchedulerError";
        this.code = errorCode;
        this.rawError = rawError;
    }
}

class NetScheduler {
    private proxies: ProxiesMap = {};

    addProxy(name: string, type: string, task: string): void {
        this.proxies[name] = { type, task };
    }

    removeProxy(name: string): void {
        delete this.proxies[name];
    }

    setProxyLimiter(name: string, limiter: RateLimiter): void {
        this.proxies[name].limiter = limiter;
    }

    /*
     * Make a request to the specified URL with any available proxy
     * @param {string} url - The URL to request.
     * @param {string} task - The task name used to select a matching proxy.
     * @param {string} method - The HTTP method to use for the request. Default is "GET".
     * @returns {Promise<any>} - A promise that resolves to the response body.
     * @throws {NetSchedulerError} - The error will be thrown in the following cases:
     * - No proxy is currently available: with error code NO_AVAILABLE_PROXY
     * - The proxy is rate-limited: with error code PROXY_RATE_LIMITED
     * - The native `fetch` function threw an error: with error code FETCH_ERROR
     * - The proxy type is not supported: with error code NOT_IMPLEMENTED
     */
    async request<R>(url: string, task: string, method: string = "GET"): Promise<R> {
        // find an available proxy for this task
        const proxiesNames = Object.keys(this.proxies);
        for (const proxyName of proxiesNames) {
            const proxy = this.proxies[proxyName];
            if (proxy.task !== task) continue;
            if (await this.getProxyAvailability(proxyName)) {
                return await this.proxyRequest<R>(url, proxyName, method);
            }
        }
        throw new NetSchedulerError("No available proxy currently.", "NO_AVAILABLE_PROXY");
    }

    /*
     * Make a request to the specified URL with the specified proxy
     * @param {string} url - The URL to request.
     * @param {string} proxyName - The name of the proxy to use.
     * @param {string} method - The HTTP method to use for the request. Default is "GET".
     * @param {boolean} force - If true, the request will be made even if the proxy is rate limited. Default is false.
     * @returns {Promise<any>} - A promise that resolves to the response body.
     * @throws {NetSchedulerError} - The error will be thrown in the following cases:
     * - Proxy not found: with error code PROXY_NOT_FOUND
     * - The proxy is rate-limited: with error code PROXY_RATE_LIMITED
     * - The native `fetch` function threw an error: with error code FETCH_ERROR
     * - The proxy type is not supported: with error code NOT_IMPLEMENTED
     */
    async proxyRequest<R>(url: string, proxyName: string, method: string = "GET", force: boolean = false): Promise<R> {
        const proxy = this.proxies[proxyName];
        if (!proxy) {
            throw new NetSchedulerError(`Proxy "${proxyName}" not found`, "PROXY_NOT_FOUND");
        }

        if (!force && await this.getProxyAvailability(proxyName) === false) {
            throw new NetSchedulerError(`Proxy "${proxyName}" is rate limited`, "PROXY_RATE_LIMITED");
        }

        if (proxy.limiter) {
            try {
                await proxy.limiter!.trigger();
            } catch (e) {
                const error = e as Error;
                if (e instanceof Redis.ReplyError) {
                    logger.error(error, "redis");
                }
                logger.warn(`Unhandled error: ${error.message}`, "mq", "proxyRequest");
            }
        }

        switch (proxy.type) {
            case "native":
                return await this.nativeRequest<R>(url, method);
            default:
                throw new NetSchedulerError(`Proxy type ${proxy.type} not supported.`, "NOT_IMPLEMENTED");
        }
    }

    private async getProxyAvailability(name: string): Promise<boolean> {
        try {
            const proxyConfig = this.proxies[name];
            if (!proxyConfig || !proxyConfig.limiter) {
                return true;
            }
            return await proxyConfig.limiter.getAvailability();
        } catch (e) {
            const error = e as Error;
            if (e instanceof Redis.ReplyError) {
                logger.error(error, "redis");
                return false;
            }
            logger.warn(`Unhandled error: ${error.message}`, "mq", "getProxyAvailability");
            return false;
        }
    }

    private async nativeRequest<R>(url: string, method: string): Promise<R> {
        try {
            const response = await fetch(url, { method });
            return await response.json() as R;
        } catch (e) {
            throw new NetSchedulerError("Fetch error", "FETCH_ERROR", e);
        }
    }
}

const netScheduler = new NetScheduler();
netScheduler.addProxy("default", "native", "default");
netScheduler.addProxy("tags-native", "native", "getVideoTags");
const tagsRateLimiter = new RateLimiter("getVideoTags", [
    {
        window: new SlidingWindow(redis, 1),
        max: 3,
    },
    {
        window: new SlidingWindow(redis, 30),
        max: 30,
    },
    {
        window: new SlidingWindow(redis, 2 * 60),
        max: 50,
    },
]);
netScheduler.setProxyLimiter("tags-native", tagsRateLimiter);

export default netScheduler;
lib/net/bilibili.d.ts (vendored, 117 lines)
@@ -1,117 +0,0 @@
interface BaseResponse<T> {
    code: number;
    message: string;
    ttl: number;
    data: T;
}

export type VideoListResponse = BaseResponse<VideoListData>;
export type VideoTagsResponse = BaseResponse<VideoTagsData>;

type VideoTagsData = VideoTags[];

interface VideoTags {
    tag_id: number;
    tag_name: string;
    cover: string;
    head_cover: string;
    content: string;
    short_content: string;
    type: number;
    state: number;
    ctime: number;
    count: {
        view: number;
        use: number;
        atten: number;
    }
    is_atten: number;
    likes: number;
    hates: number;
    attribute: number;
    liked: number;
    hated: number;
    extra_attr: number;
}

interface VideoListData {
    archives: VideoListVideo[];
    page: {
        num: number;
        size: number;
        count: number;
    };
}

interface VideoListVideo {
    aid: number;
    videos: number;
    tid: number;
    tname: string;
    copyright: number;
    pic: string;
    title: string;
    pubdate: number;
    ctime: number;
    desc: string;
    state: number;
    duration: number;
    mission_id?: number;
    rights: {
        bp: number;
        elec: number;
        download: number;
        movie: number;
        pay: number;
        hd5: number;
        no_reprint: number;
        autoplay: number;
        ugc_pay: number;
        is_cooperation: number;
        ugc_pay_preview: number;
        no_background: number;
        arc_pay: number;
        pay_free_watch: number;
    },
    owner: {
        mid: number;
        name: string;
        face: string;
    },
    stat: {
        aid: number;
        view: number;
        danmaku: number;
        reply: number;
        favorite: number;
        coin: number;
        share: number;
        now_rank: number;
        his_rank: number;
        like: number;
        dislike: number;
        vt: number;
        vv: number;
    },
    dynamic: string;
    cid: number;
    dimension: {
        width: number;
        height: number;
        rotate: number;
    },
    season_id?: number;
    short_link_v2: string;
    first_frame: string;
    pub_location: string;
    cover43: string;
    tidv2: number;
    tname_v2: string;
    bvid: string;
    season_type: number;
    is_ogv: number;
    ovg_info: string | null;
    rcmd_season: string;
    enable_vt: number;
    ai_rcmd: null | string;
}
@@ -1,88 +0,0 @@
import { getLatestVideos } from "lib/net/getLatestVideos.ts";
import { AllDataType } from "lib/db/schema.d.ts";

export async function getVideoPositionInNewList(timestamp: number): Promise<number | null | AllDataType[]> {
    const virtualPageSize = 50;

    let lowPage = 1;
    let highPage = 1;
    let foundUpper = false;
    // Exponential search: double highPage until a page older than `timestamp` is found.
    while (true) {
        const ps = highPage < 2 ? 50 : 1;
        const pn = highPage < 2 ? 1 : highPage * virtualPageSize;
        const fetchTags = highPage < 2 ? true : false;
        const videos = await getLatestVideos(pn, ps, 250, fetchTags);
        if (!videos || videos.length === 0) {
            break;
        }
        const lastVideo = videos[videos.length - 1];
        if (!lastVideo || !lastVideo.published_at) {
            break;
        }
        const lastTime = Date.parse(lastVideo.published_at);
        if (lastTime <= timestamp && highPage == 1) {
            return videos;
        } else if (lastTime <= timestamp) {
            foundUpper = true;
            break;
        } else {
            lowPage = highPage;
            highPage *= 2;
        }
    }

    if (!foundUpper) {
        return null;
    }

    // Binary search between lowPage and highPage for the boundary page.
    let boundaryPage = highPage;
    let lo = lowPage;
    let hi = highPage;
    while (lo <= hi) {
        const mid = Math.floor((lo + hi) / 2);
        const videos = await getLatestVideos(mid * virtualPageSize, 1, 250, false);
        if (!videos) {
            return null;
        }
        if (videos.length === 0) {
            hi = mid - 1;
            continue;
        }
        const lastVideo = videos[videos.length - 1];
        if (!lastVideo || !lastVideo.published_at) {
            hi = mid - 1;
            continue;
        }
        const lastTime = Date.parse(lastVideo.published_at);
        if (lastTime > timestamp) {
            lo = mid + 1;
        } else {
            boundaryPage = mid;
            hi = mid - 1;
        }
    }

    const boundaryVideos = await getLatestVideos(boundaryPage, virtualPageSize, 250, false);
    let indexInPage = 0;
    if (boundaryVideos && boundaryVideos.length > 0) {
        for (let i = 0; i < boundaryVideos.length; i++) {
            const video = boundaryVideos[i];
            if (!video.published_at) {
                continue;
            }
            const videoTime = Date.parse(video.published_at);
            if (videoTime > timestamp) {
                indexInPage++;
            } else {
                break;
            }
        }
    }

    const count = (boundaryPage - 1) * virtualPageSize + indexInPage;

    const safetyMargin = 5;

    return count + safetyMargin;
}
@@ -1,45 +0,0 @@
import { VideoListResponse } from "lib/net/bilibili.d.ts";
import { formatTimestampToPsql as formatPublishedAt } from "lib/utils/formatTimestampToPostgre.ts";
import { AllDataType } from "lib/db/schema.d.ts";
import logger from "lib/log/logger.ts";
import { HOUR, SECOND } from "$std/datetime/constants.ts";

export async function getLatestVideos(
    page: number = 1,
    pageSize: number = 10,
    sleepRate: number = 250,
    fetchTags: boolean = true,
): Promise<AllDataType[] | null> {
    try {
        const response = await fetch(
            `https://api.bilibili.com/x/web-interface/newlist?rid=30&ps=${pageSize}&pn=${page}`,
        );
        const data: VideoListResponse = await response.json();

        if (data.code !== 0) {
            logger.error(`Error fetching videos: ${data.message}`, "net", "getLatestVideos");
            return null;
        }

        if (data.data.archives.length === 0) {
            logger.verbose("No more videos found", "net", "getLatestVideos");
            return [];
        }

        return data.data.archives.map((video) => {
            const published_at = formatPublishedAt(video.pubdate * SECOND + 8 * HOUR);
            return {
                aid: video.aid,
                bvid: video.bvid,
                description: video.desc,
                uid: video.owner.mid,
                tags: null,
                title: video.title,
                published_at: published_at,
            } as AllDataType;
        });
    } catch (error) {
        logger.error(error as Error, "net", "getLatestVideos");
        return null;
    }
}
@@ -1,35 +0,0 @@
import { VideoTagsResponse } from "lib/net/bilibili.d.ts";
import netScheduler, { NetSchedulerError } from "lib/mq/scheduler.ts";
import logger from "lib/log/logger.ts";

/*
 * Fetch the tags for a video
 * @param {number} aid The video's aid
 * @return {Promise<string[] | null>} A promise, which resolves to an array of tags,
 * or null if a `fetch` error occurred
 * @throws {NetSchedulerError} If the request failed.
 */
export async function getVideoTags(aid: number): Promise<string[] | null> {
    try {
        const url = `https://api.bilibili.com/x/tag/archive/tags?aid=${aid}`;
        const data = await netScheduler.request<VideoTagsResponse>(url, 'getVideoTags');
        if (data.code != 0) {
            logger.error(`Error fetching tags for video ${aid}: ${data.message}`, 'net', 'getVideoTags');
            return [];
        }
        return data.data.map((tag) => tag.tag_name);
    } catch (e) {
        const error = e as NetSchedulerError;
        if (error.code == "FETCH_ERROR") {
            const rawError = error.rawError! as Error;
            rawError.message = `Error fetching tags for video ${aid}: ` + rawError.message;
            logger.error(rawError, 'net', 'getVideoTags');
            return null;
        } else {
            // Re-throw the error
            throw e;
        }
    }
}
@@ -1,77 +0,0 @@
import { Client } from "https://deno.land/x/postgres@v0.19.3/mod.ts";
import { getLatestVideos } from "lib/net/getLatestVideos.ts";
import { getLatestVideoTimestampFromAllData, insertIntoAllData, videoExistsInAllData } from "lib/db/allData.ts";
import { sleep } from "lib/utils/sleep.ts";
import { getVideoPositionInNewList } from "lib/net/bisectVideoStartFrom.ts";
import { SECOND } from "$std/datetime/constants.ts";
import logger from "lib/log/logger.ts";

export async function insertLatestVideos(
    client: Client,
    pageSize: number = 10,
    sleepRate: number = 250,
    intervalRate: number = 4000,
): Promise<number | null> {
    const latestVideoTimestamp = await getLatestVideoTimestampFromAllData(client);
    if (latestVideoTimestamp == null) {
        logger.error("Cannot get latest video timestamp from current database.", "net", "fn:insertLatestVideos()");
        return null;
    }
    logger.log(`Latest video in the database: ${new Date(latestVideoTimestamp).toISOString()}`, "net", "fn:insertLatestVideos()");
    const videoIndex = await getVideoPositionInNewList(latestVideoTimestamp);
    if (videoIndex == null) {
        logger.error("Cannot locate the video through bisect.", "net", "fn:insertLatestVideos()");
        return null;
    }
    if (typeof videoIndex == "object") {
        for (const video of videoIndex) {
            const videoExists = await videoExistsInAllData(client, video.aid);
            if (!videoExists) {
                insertIntoAllData(client, video);
            }
        }
        return 0;
    }
    let page = Math.floor(videoIndex / pageSize) + 1;
    let failCount = 0;
    const insertedVideos = new Set();
    while (true) {
        try {
            const videos = await getLatestVideos(page, pageSize, sleepRate);
            if (videos == null) {
                failCount++;
                if (failCount > 5) {
                    return null;
                }
                continue;
            }
            failCount = 0;
            if (videos.length == 0) {
                logger.verbose("No more videos found", "net", "fn:insertLatestVideos()");
                break;
            }
            for (const video of videos) {
                const videoExists = await videoExistsInAllData(client, video.aid);
                if (!videoExists) {
                    insertIntoAllData(client, video);
                    insertedVideos.add(video.aid);
                }
            }
            logger.log(`Page ${page} crawled, total: ${insertedVideos.size} videos.`, "net", "fn:insertLatestVideos()");
            page--;
            if (page < 1) {
                return 0;
            }
        } catch (error) {
            logger.error(error as Error, "net", "fn:insertLatestVideos()");
            failCount++;
            if (failCount > 5) {
                return null;
            }
            continue;
        } finally {
            await sleep(Math.random() * intervalRate + failCount * 3 * SECOND + SECOND);
        }
    }
    return 0;
}
main.ts (13 lines)
@@ -1,13 +0,0 @@
/// <reference no-default-lib="true" />
/// <reference lib="dom" />
/// <reference lib="dom.iterable" />
/// <reference lib="dom.asynciterable" />
/// <reference lib="deno.ns" />

import "$std/dotenv/load.ts";

import { start } from "$fresh/server.ts";
import manifest from "./fresh.gen.ts";
import config from "./fresh.config.ts";

await start(manifest, config);
@@ -18,4 +18,13 @@ Note
 0324: V3.5-test3 # try going back to the FC layer from 3.2
 0331: V3.6-test3 # 3.5 isn't great; trying to tune the hyperparameters
 0335: V3.7-test3 # 3.6 is OK; tuning the hyperparameters further
 0414: V3.8-test3 # 3.7 doesn't work; re-tuning from the 3.6 baseline
+1918: V3.9
+2308: V3.11
+2243: V3.11 # 256-dim embedding
+2253: V3.11 # 1024-dim embedding (for comparison)
+2337: V3.12 # cascade classification
+2350: V3.13 # V3.12, switched to plain cross-entropy loss
+0012: V3.11 # switched to plain cross-entropy loss
+0039: V3.11 # cascade classification, but with two independent models
+0122: V3.15 # removed the author_info channel
@@ -103,8 +103,7 @@ class MultiChannelDataset(Dataset):
         texts = {
             'title': example['title'],
             'description': example['description'],
-            'tags': tags_text,
-            'author_info': example['author_info']
+            'tags': tags_text
         }
 
         return {
ml/filter/embedding.py (new file, 110 lines)
@@ -0,0 +1,110 @@
import numpy as np
import torch
from model2vec import StaticModel


def prepare_batch(batch_data, device="cpu"):
    """
    Convert the input batch_data into the model input format [batch_size, num_channels, embedding_dim].

    Args:
        batch_data (dict): input batch in the format {
            "title": [text1, text2, ...],
            "description": [text1, text2, ...],
            "tags": [text1, text2, ...]
        }
        device (str): device to run the model on (e.g. "cpu" or "cuda").

    Returns:
        torch.Tensor: tensor of shape [batch_size, num_channels, embedding_dim].
    """
    # 1. Encode the texts of each channel separately
    channel_embeddings = []
    model = StaticModel.from_pretrained("./model/embedding_1024/")
    for channel in ["title", "description", "tags"]:
        texts = batch_data[channel]  # text list of the current channel
        embeddings = torch.from_numpy(model.encode(texts)).to(torch.float32).to(device)  # encode to [batch_size, embedding_dim]
        channel_embeddings.append(embeddings)

    # 2. Stack the encodings into [batch_size, num_channels, embedding_dim]
    batch_tensor = torch.stack(channel_embeddings, dim=1)  # stack on dim=1
    return batch_tensor

import onnxruntime as ort
from transformers import AutoTokenizer
from itertools import accumulate

def prepare_batch_per_token(batch_data, max_length=1024):
    """
    Convert the input batch_data into the model input format [batch_size, num_channels, seq_length, embedding_dim].

    Args:
        batch_data (dict): input batch in the format {
            "title": [text1, text2, ...],
            "description": [text1, text2, ...],
            "tags": [text1, text2, ...],
            "author_info": [text1, text2, ...]
        }
        max_length (int): maximum sequence length.

    Returns:
        torch.Tensor: tensor of shape [batch_size, num_channels, seq_length, embedding_dim].
    """
    # Initialize the tokenizer and the ONNX model
    tokenizer = AutoTokenizer.from_pretrained("alikia2x/jina-embedding-v3-m2v-1024")
    session = ort.InferenceSession("./model/embedding_256/onnx/model.onnx")

    # 1. Encode the texts of each channel separately
    channel_embeddings = []
    for channel in ["title", "description", "tags", "author_info"]:
        texts = batch_data[channel]  # text list of the current channel

        # Step 1: generate input_ids and offsets
        # Encode each text individually, keeping its original token length
        encoded_inputs = [tokenizer(text, truncation=True, max_length=max_length, return_tensors='np') for text in texts]

        # Extract the input_ids length of each text (the actual token count)
        input_ids_lengths = [len(enc["input_ids"][0]) for enc in encoded_inputs]

        # Generate offsets: [0, len1, len1+len2, ...]
        offsets = list(accumulate([0] + input_ids_lengths[:-1]))  # cumulative sum, excluding the last length

        # Flatten all input_ids into a 1-D array
        flattened_input_ids = np.concatenate([enc["input_ids"][0] for enc in encoded_inputs], axis=0).astype(np.int64)

        # Step 2: build the ONNX inputs
        inputs = {
            "input_ids": ort.OrtValue.ortvalue_from_numpy(flattened_input_ids),
            "offsets": ort.OrtValue.ortvalue_from_numpy(np.array(offsets, dtype=np.int64))
        }

        # Step 3: run the ONNX model
        embeddings = session.run(None, inputs)[0]  # assumes the output is named "embeddings"

        # Step 4: reshape the output to [batch_size, seq_length, embedding_dim]
        # Note: this assumes the ONNX output has shape [total_tokens, embedding_dim],
        # which must be regrouped by the actual sequence lengths
        batch_size = len(texts)
        embeddings_split = np.split(embeddings, np.cumsum(input_ids_lengths[:-1]))
        padded_embeddings = []
        for emb, seq_len in zip(embeddings_split, input_ids_lengths):
            # Pad each sequence to max_length
            if seq_len > max_length:
                # Truncate if the sequence exceeds max_length
                emb = emb[:max_length]
                pad_length = 0
            else:
                # Otherwise pad up to max_length
                pad_length = max_length - seq_len

            # Pad to [max_length, embedding_dim]
            padded = np.pad(emb, ((0, pad_length), (0, 0)), mode='constant')
            padded_embeddings.append(padded)

        # Make sure all padded sequences have the same shape
        embeddings_tensor = torch.tensor(np.stack(padded_embeddings), dtype=torch.float32)
        channel_embeddings.append(embeddings_tensor)

    # 2. Stack the encodings into [batch_size, num_channels, seq_length, embedding_dim]
    batch_tensor = torch.stack(channel_embeddings, dim=1)
    return batch_tensor
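A minimal usage sketch of prepare_batch, assuming a model2vec checkpoint exists at ./model/embedding_1024/ as above; the batch values here are made-up examples:

# Hypothetical example input; requires the local ./model/embedding_1024/ checkpoint.
batch_data = {
    "title": ["【初音ミク】example title", "another title"],
    "description": ["an example description", "another description"],
    "tags": ["VOCALOID,初音ミク", "music,demo"],
}
batch = prepare_batch(batch_data, device="cpu")
print(batch.shape)  # expected: torch.Size([2, 3, 1024])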
ml/filter/embedding_range.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import json
import torch
import random
from embedding import prepare_batch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

file_path = './data/filter/model_predicted.jsonl'

class Dataset:
    def __init__(self, file_path):
        all_examples = self.load_data(file_path)
        self.examples = all_examples

    def load_data(self, file_path):
        with open(file_path, 'r', encoding='utf-8') as f:
            return [json.loads(line) for line in f]

    def __getitem__(self, idx):
        end_idx = min((idx + 1) * self.batch_size, len(self.examples))
        texts = {
            'title': [ex['title'] for ex in self.examples[idx * self.batch_size:end_idx]],
            'description': [ex['description'] for ex in self.examples[idx * self.batch_size:end_idx]],
            'tags': [",".join(ex['tags']) for ex in self.examples[idx * self.batch_size:end_idx]],
            'author_info': [ex['author_info'] for ex in self.examples[idx * self.batch_size:end_idx]]
        }
        return texts

    def __len__(self):
        return len(self.examples)

    def get_batch(self, idx, batch_size):
        self.batch_size = batch_size
        return self.__getitem__(idx)

total = 600000
batch_size = 512
batch_num = total // batch_size
dataset = Dataset(file_path)
arr_len = batch_size * 4 * 1024
sample_rate = 0.1
sample_num = int(arr_len * sample_rate)

data = np.array([])
for i in tqdm(range(batch_num)):
    batch = dataset.get_batch(i, batch_size)
    batch = prepare_batch(batch, device="cpu")
    arr = batch.flatten().numpy()
    sampled = np.random.choice(arr.shape[0], size=sample_num, replace=False)
    data = np.concatenate((data, arr[sampled]), axis=0) if data.size else arr[sampled]
    if i % 10 == 0:
        np.save('embedding_range.npy', data)
np.save('embedding_range.npy', data)
ml/filter/embedding_visualization.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import numpy as np
import matplotlib.pyplot as plt

# Load the data
data = np.load("1.npy")

# Draw a histogram and get the bin counts
n, bins, patches = plt.hist(data, bins=32, density=False, alpha=0.7, color='skyblue')

# Total number of data points
total_data = len(data)

# Convert counts to frequencies
frequencies = n / total_data

# Compute summary statistics
max_val = np.max(data)
min_val = np.min(data)
std_dev = np.std(data)

# Set figure properties
plt.title('Frequency Distribution Histogram')
plt.xlabel('Value')
plt.ylabel('Frequency')

# Redraw the histogram using frequencies as bar heights
plt.cla()  # clear the current axes
plt.bar([(bins[i] + bins[i+1])/2 for i in range(len(bins)-1)], frequencies, width=[bins[i+1]-bins[i] for i in range(len(bins)-1)], alpha=0.7, color='skyblue')

# Annotate each bar with its frequency value
for i in range(len(patches)):
    plt.text(bins[i]+(bins[i+1]-bins[i])/2, frequencies[i], f'{frequencies[i]:.2e}', ha='center', va='bottom', fontsize=6)

# Show summary statistics in a corner of the chart
stats_text = f"Max: {max_val:.6f}\nMin: {min_val:.6f}\nStd: {std_dev:.4e}"
plt.text(0.95, 0.95, stats_text, transform=plt.gca().transAxes,
         ha='right', va='top', bbox=dict(facecolor='white', edgecolor='black', alpha=0.8))

# Align the x-axis ticks with the bin edges
plt.xticks(bins, fontsize=6)

# Show the figure
plt.show()
@@ -10,7 +10,7 @@ import tty
 import termios
 from sentence_transformers import SentenceTransformer
 from db_utils import fetch_entry_data, parse_entry_data
-from modelV3_9 import VideoClassifierV3_9
+from modelV3_10 import VideoClassifierV3_10
 
 class LabelingSystem:
     def __init__(self, mode='model_testing', database_path="./data/main.db",
@@ -27,7 +27,7 @@ class LabelingSystem:
         self.model = None
         self.sentence_transformer = None
         if self.mode == 'model_testing':
-            self.model = VideoClassifierV3_9()
+            self.model = VideoClassifierV3_10()
             self.model.load_state_dict(torch.load(model_path))
             self.model.eval()
             self.sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
@@ -3,13 +3,13 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 class VideoClassifierV3_10(nn.Module):
-    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3):
+    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7):
         super().__init__()
         self.num_channels = 4
         self.channel_names = ['title', 'description', 'tags', 'author_info']
 
         # Learnable temperature coefficient
-        self.temperature = nn.Parameter(torch.tensor(1.7))
+        self.temperature = nn.Parameter(torch.tensor(temperature))
 
         # Constrained channel weights (Sigmoid instead of Softmax)
         self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
ml/filter/modelV3_12.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class VideoClassifierV3_12(nn.Module):
    def __init__(self, embedding_dim=1024, hidden_dim=648):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # Learnable temperature coefficient
        self.temperature = nn.Parameter(torch.tensor(1.7))

        # Constrained channel weights (Sigmoid instead of Softmax)
        self.channel_weights = nn.Parameter(torch.ones(self.num_channels))

        # First binary classifier: 0 vs 1/2
        self.first_classifier = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
            nn.BatchNorm1d(hidden_dim*2),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim*2, 2)  # 2 output classes: 0 vs 1/2
        )

        # Second binary classifier: 1 vs 2
        self.second_classifier = nn.Sequential(
            nn.Linear(embedding_dim * self.num_channels, hidden_dim*2),
            nn.BatchNorm1d(hidden_dim*2),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim*2, 2)  # 2 output classes: 1 vs 2
        )

        # Weight initialization
        self._init_weights()

    def _init_weights(self):
        for layer in self.first_classifier:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                nn.init.zeros_(layer.bias)

        for layer in self.second_classifier:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                nn.init.zeros_(layer.bias)

    def forward(self, channel_features: torch.Tensor):
        """
        Input shape: [batch_size, num_channels, embedding_dim]
        Output shape: [batch_size, output_dim]
        """
        # Adaptive channel weights (Sigmoid-constrained)
        weights = torch.sigmoid(self.channel_weights)  # in the [0,1] range
        weighted_features = channel_features * weights.unsqueeze(0).unsqueeze(-1)

        # Feature concatenation
        combined = weighted_features.view(weighted_features.size(0), -1)

        # First binary classifier: 0 vs 1/2
        first_output = self.first_classifier(combined)
        first_probs = F.softmax(first_output, dim=1)

        # Second binary classifier: 1 vs 2
        second_output = self.second_classifier(combined)
        second_probs = F.softmax(second_output, dim=1)

        # Combine the results
        final_probs = torch.zeros(channel_features.size(0), 3).to(channel_features.device)
        final_probs[:, 0] = first_probs[:, 0]  # probability of class 0
        final_probs[:, 1] = first_probs[:, 1] * second_probs[:, 0]  # probability of class 1
        final_probs[:, 2] = first_probs[:, 1] * second_probs[:, 1]  # probability of class 2

        return final_probs

    def get_channel_weights(self):
        """Return the per-channel weights (with temperature scaling)."""
        return torch.softmax(self.channel_weights / self.temperature, dim=0).detach().cpu().numpy()
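For clarity, a small numeric sketch of how the cascade combines the two heads; the probabilities below are illustrative only:

# Illustrative numbers: P(class 0) comes from the first head; the first head's
# "not 0" mass is split between classes 1 and 2 by the second head.
first_probs = [0.2, 0.8]      # head 1: P(0), P(1 or 2)
second_probs = [0.25, 0.75]   # head 2: P(1 | not 0), P(2 | not 0)
final = [first_probs[0], first_probs[1] * second_probs[0], first_probs[1] * second_probs[1]]
print(final)       # [0.2, 0.2, 0.6]
print(sum(final))  # 1.0

Because each head outputs a softmax, the combined three-way vector always sums to 1, so the cascade still yields a valid distribution.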
@@ -2,14 +2,14 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-class VideoClassifierV3_9(nn.Module):
-    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3):
+class VideoClassifierV3_15(nn.Module):
+    def __init__(self, embedding_dim=1024, hidden_dim=648, output_dim=3, temperature=1.7):
         super().__init__()
-        self.num_channels = 4
-        self.channel_names = ['title', 'description', 'tags', 'author_info']
+        self.num_channels = 3
+        self.channel_names = ['title', 'description', 'tags']
 
         # Learnable temperature coefficient
-        self.temperature = nn.Parameter(torch.tensor(1.7))
+        self.temperature = nn.Parameter(torch.tensor(temperature))
 
         # Constrained channel weights (Sigmoid instead of Softmax)
         self.channel_weights = nn.Parameter(torch.ones(self.num_channels))
@@ -38,21 +38,11 @@ class VideoClassifierV3_9(nn.Module):
                 nn.init.zeros_(layer.bias)
 
 
-    def forward(self, input_texts, sentence_transformer):
-        # Merge the texts for batch encoding
-        all_texts = [text for channel in self.channel_names for text in input_texts[channel]]
-
-        # Frozen text encoding
-        with torch.no_grad():
-            embeddings = torch.tensor(
-                sentence_transformer.encode(all_texts),
-                device=next(self.parameters()).device
-            )
-
-        # Split and weight the channel features
-        split_sizes = [len(input_texts[name]) for name in self.channel_names]
-        channel_features = torch.split(embeddings, split_sizes, dim=0)
-        channel_features = torch.stack(channel_features, dim=1)
-
+    def forward(self, channel_features: torch.Tensor):
+        """
+        Input shape: [batch_size, num_channels, embedding_dim]
+        Output shape: [batch_size, output_dim]
+        """
         # Adaptive channel weights (Sigmoid-constrained)
         weights = torch.sigmoid(self.channel_weights)  # in the [0,1] range
ml/filter/modelV6_0.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class VideoClassifierV6_0(nn.Module):
    def __init__(self, embedding_dim=256, seq_length=1024, hidden_dim=512, output_dim=3):
        super().__init__()
        self.num_channels = 4
        self.channel_names = ['title', 'description', 'tags', 'author_info']

        # CNN feature-extraction layers
        self.conv_layers = nn.Sequential(
            # First conv layer
            nn.Conv2d(self.num_channels, 64, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Second conv layer
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=(2, 2)),

            # Third conv layer
            nn.Conv2d(128, 256, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(256),
            nn.GELU(),

            # Global average pooling layer
            # Output shape: [batch_size, 256, 1, 1]
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # The feature dimension after global pooling is fixed at 256
        self.feature_dim = 256

        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(self.feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(0.2),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim)
        )

        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, channel_features: torch.Tensor):
        """
        Input shape: [batch_size, num_channels, seq_length, embedding_dim]
        Output shape: [batch_size, output_dim]
        """
        # CNN feature extraction
        conv_features = self.conv_layers(channel_features)

        # Flatten the features (shape after global pooling is [batch_size, 256, 1, 1])
        flat_features = conv_features.view(conv_features.size(0), -1)  # [batch_size, 256]

        # Classify with the fully connected layers
        return self.fc(flat_features)

# The loss function stays unchanged
class AdaptiveRecallLoss(nn.Module):
    def __init__(self, class_weights, alpha=0.8, gamma=2.0, fp_penalty=0.5):
        super().__init__()
        self.class_weights = class_weights
        self.alpha = alpha
        self.gamma = gamma
        self.fp_penalty = fp_penalty

    def forward(self, logits, targets):
        ce_loss = F.cross_entropy(logits, targets, weight=self.class_weights, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        class_mask = F.one_hot(targets, num_classes=len(self.class_weights))
        class_weights = (self.alpha + (1 - self.alpha) * pt.unsqueeze(-1)) * class_mask
        recall_loss = (class_weights * focal_loss.unsqueeze(-1)).sum(dim=1)

        probs = F.softmax(logits, dim=1)
        fp_mask = (targets != 0) & (torch.argmax(logits, dim=1) == 0)
        fp_loss = self.fp_penalty * probs[:, 0][fp_mask].pow(2).sum()

        total_loss = recall_loss.mean() + fp_loss / len(targets)
        return total_loss
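A minimal usage sketch of AdaptiveRecallLoss, assuming this file is importable as modelV6_0; the batch shape and class weights are made up for illustration:

import torch
from modelV6_0 import AdaptiveRecallLoss  # assumed import path for this file

class_weights = torch.tensor([0.2, 1.0, 1.0])   # illustrative weights favoring classes 1/2
criterion = AdaptiveRecallLoss(class_weights)
logits = torch.randn(8, 3, requires_grad=True)  # fake batch: 8 samples, 3 classes
targets = torch.randint(0, 3, (8,))
loss = criterion(logits, targets)               # scalar tensor
loss.backward()                                 # behaves like any other loss
print(loss.item())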
@@ -1,16 +1,16 @@
 import torch
-from modelV3_10 import VideoClassifierV3_10
+from modelV3_15 import VideoClassifierV3_15
 
 
-def export_onnx(model_path="./filter/checkpoints/best_model_V3.10.pt",
-                onnx_path="./model/video_classifier_v3_10.onnx"):
+def export_onnx(model_path="./filter/checkpoints/best_model_V3.17.pt",
+                onnx_path="./model/video_classifier_v3_17.onnx"):
     # Initialize the model
-    model = VideoClassifierV3_10()
+    model = VideoClassifierV3_15()
     model.load_state_dict(torch.load(model_path))
     model.eval()
 
     # Create a dummy input matching the expected input spec
-    dummy_input = torch.randn(1, 4, 1024)  # [batch=1, channels=4, embedding_dim=1024]
+    dummy_input = torch.randn(1, 3, 1024)  # [batch=1, channels=3, embedding_dim=1024]
 
     # Export to ONNX
     torch.onnx.export(
ml/filter/quantize.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from safetensors import safe_open
from safetensors.torch import save_file
import torch

# Paths
model_path = "./model/embedding/model.safetensors"
save_path = "./model/embedding/int8_model.safetensors"

# Load the original embedding layer
with safe_open(model_path, framework="pt") as f:
    embeddings_tensor = f.get_tensor("embeddings")

# Compute the extrema
min_val = torch.min(embeddings_tensor)
max_val = torch.max(embeddings_tensor)

# Compute the quantization parameters
scale = (max_val - min_val) / 255  # int8 covers 256 values (-128 to 127)

# Map the floats into the int8 range
int8_tensor = torch.round((embeddings_tensor - min_val) / scale).to(torch.int8) - 128

# Make sure the shape matches the original tensor
assert int8_tensor.shape == embeddings_tensor.shape

# Save the mapped int8 tensor
save_file({"embeddings": int8_tensor}, save_path)

# Print the de-quantization formula
print("int8 de-quantization formula:")
m = min_val.item()
am = abs(min_val.item())
sign = "-" if m < 0 else "+"
print(f"float_value = (int8_value + 128) × {scale.item()} {sign} {am}")

print("int8 mapping done!")
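A hedged round-trip sketch of this affine int8 mapping, run on a random tensor instead of the real embedding file; it checks that de-quantization recovers each value to within half a quantization step:

import torch

x = torch.randn(1000) * 0.05                 # stand-in for the embedding values
min_val, max_val = x.min(), x.max()
scale = (max_val - min_val) / 255
q = torch.round((x - min_val) / scale) - 128  # kept in float here to sidestep int8 cast wrap-around
x_hat = (q + 128) * scale + min_val           # the de-quantization formula printed above
print(torch.all((x - x_hat).abs() <= scale / 2 + 1e-6))  # tensor(True): error is at most half a step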
@@ -1,7 +1,7 @@
 from labeling_system import LabelingSystem
 
 DATABASE_PATH = "./data/main.db"
-MODEL_PATH = "./filter/checkpoints/best_model_V3.9.pt"
+MODEL_PATH = "./filter/checkpoints/best_model_V3.11.pt"
 OUTPUT_FILE = "./data/filter/real_test.jsonl"
 BATCH_SIZE = 50
 
@@ -4,17 +4,16 @@ import numpy as np
 from torch.utils.data import DataLoader
 import torch.optim as optim
 from dataset import MultiChannelDataset
-from filter.modelV3_10 import VideoClassifierV3_10, AdaptiveRecallLoss
-from sentence_transformers import SentenceTransformer
+from filter.modelV3_15 import AdaptiveRecallLoss, VideoClassifierV3_15
 from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score, classification_report
 import os
 import torch
-from torch.utils.tensorboard import SummaryWriter  # import TensorBoard
+from torch.utils.tensorboard import SummaryWriter
 import time
 from embedding import prepare_batch
+import torch.nn as nn
 
 
-# Dynamically generate the subdirectory name
 run_name = f"run_{time.strftime('%Y%m%d_%H%M')}"
 log_dir = os.path.join('./filter/runs', run_name)
@@ -52,9 +51,8 @@ class_weights = torch.tensor(
 )
 
 # Initialize the model and the SentenceTransformer
-sentence_transformer = SentenceTransformer("Thaweewat/jina-embedding-v3-m2v-1024")
-model = VideoClassifierV3_10()
-checkpoint_name = './filter/checkpoints/best_model_V3.11.pt'
+model = VideoClassifierV3_15()
+checkpoint_name = './filter/checkpoints/best_model_V3.17.pt'
 
 # Model save path
 os.makedirs('./filter/checkpoints', exist_ok=True)
@@ -78,7 +76,7 @@ def evaluate(model, dataloader):
 
     with torch.no_grad():
         for batch in dataloader:
-            batch_tensor = prepare_batch(batch['texts'], device="cpu")
+            batch_tensor = prepare_batch(batch['texts'])
             logits = model(batch_tensor)
             preds = torch.argmax(logits, dim=1)
             all_preds.extend(preds.cpu().numpy())
@@ -111,9 +109,8 @@ for epoch in range(num_epochs):
     for batch_idx, batch in enumerate(train_loader):
         optimizer.zero_grad()
 
-        batch_tensor = prepare_batch(batch['texts'], device="cpu")
+        batch_tensor = prepare_batch(batch['texts'])
 
-        # Pass in the text dict and the sentence_transformer
         logits = model(batch_tensor)
 
         loss = criterion(logits, batch['label'])
lab/.gitignore → ml/lab/.gitignore (renamed, vendored)

ml/pred/count.py (new file, 12 lines)
@@ -0,0 +1,12 @@
# iterate over all json files in ./data/pred

import os
import json

count = 0
for filename in os.listdir('./data/pred'):
    if filename.endswith('.json'):
        with open('./data/pred/' + filename, 'r') as f:
            data = json.load(f)
            count += len(data)
print(count)
ml/pred/crawler.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import os
import requests
import json
import time

with open("./pred/2", "r") as fp:
    raw = fp.readlines()
    aids = [int(x.strip()) for x in raw]

for aid in aids:
    if os.path.exists(f"./data/pred/{aid}.json"):
        continue
    url = f"https://api.bunnyxt.com/tdd/v2/video/{aid}/record?last_count=5000"
    r = requests.get(url)
    data = r.json()
    with open(f"./data/pred/{aid}.json", "w") as fp:
        json.dump(data, fp, ensure_ascii=False, indent=4)
    time.sleep(5)
    print(aid)
178
ml/pred/dataset.py
Normal file
178
ml/pred/dataset.py
Normal file
@ -0,0 +1,178 @@
import os
import json
import random
import bisect
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import datetime


class VideoPlayDataset(Dataset):
    def __init__(self, data_dir, publish_time_path, term='long', seed=42):
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        self.data_dir = data_dir
        self.series_dict = self._load_and_process_data(publish_time_path)
        self.valid_series = [s for s in self.series_dict.values() if len(s['abs_time']) > 1]
        self.term = term
        # Set time window based on term
        self.time_window = 1000 * 24 * 3600 if term == 'long' else 7 * 24 * 3600
        MINUTE = 60
        HOUR = 3600
        DAY = 24 * HOUR

        if term == 'long':
            self.feature_windows = [
                1 * HOUR,
                6 * HOUR,
                1 * DAY,
                3 * DAY,
                7 * DAY,
                30 * DAY,
                100 * DAY
            ]
        else:
            self.feature_windows = [
                (15 * MINUTE, 0 * MINUTE),
                (40 * MINUTE, 0 * MINUTE),
                ( 1 * HOUR, 0 * HOUR),
                ( 2 * HOUR, 1 * HOUR),
                ( 3 * HOUR, 2 * HOUR),
                ( 3 * HOUR, 0 * HOUR),
                # ( 6 * HOUR, 3 * HOUR),
                ( 6 * HOUR, 0 * HOUR),
                (18 * HOUR, 12 * HOUR),
                # ( 1 * DAY, 6 * HOUR),
                ( 1 * DAY, 0 * DAY),
                # ( 2 * DAY, 1 * DAY),
                ( 3 * DAY, 0 * DAY),
                # ( 4 * DAY, 1 * DAY),
                ( 7 * DAY, 0 * DAY)
            ]

    def _extract_features(self, series, current_idx, target_idx):
        current_time = series['abs_time'][current_idx]
        current_play = series['play_count'][current_idx]
        dt = datetime.datetime.fromtimestamp(current_time)

        if self.term == 'long':
            time_features = [
                np.log2(max(current_time - series['create_time'], 1))
            ]
        else:
            time_features = [
                (dt.hour * 3600 + dt.minute * 60 + dt.second) / 86400,
                (dt.weekday() * 24 + dt.hour) / 168,
                np.log2(max(current_time - series['create_time'], 1))
            ]

        growth_features = []
        if self.term == 'long':
            for window in self.feature_windows:
                prev_time = current_time - window
                prev_idx = self._get_nearest_value(series, prev_time, current_idx)
                if prev_idx is not None:
                    time_diff = current_time - series['abs_time'][prev_idx]
                    play_diff = current_play - series['play_count'][prev_idx]
                    scaled_diff = play_diff / (time_diff / window) if time_diff > 0 else 0.0
                else:
                    scaled_diff = 0.0
                growth_features.append(np.log2(max(scaled_diff, 1)))
        else:
            for window_start, window_end in self.feature_windows:
                prev_time_start = current_time - window_start
                prev_time_end = current_time - window_end  # window_end is typically 0
                prev_idx_start = self._get_nearest_value(series, prev_time_start, current_idx)
                prev_idx_end = self._get_nearest_value(series, prev_time_end, current_idx)
                if prev_idx_start is not None and prev_idx_end is not None:
                    time_diff = series['abs_time'][prev_idx_end] - series['abs_time'][prev_idx_start]
                    play_diff = series['play_count'][prev_idx_end] - series['play_count'][prev_idx_start]
                    scaled_diff = play_diff / (time_diff / (window_start - window_end)) if time_diff > 0 else 0.0
                else:
                    scaled_diff = 0.0
                growth_features.append(np.log2(max(scaled_diff, 1)))

        time_diff = series['abs_time'][target_idx] - current_time
        return [np.log2(max(time_diff, 1))] + [np.log2(current_play + 1)] + growth_features + time_features

    def _load_and_process_data(self, publish_time_path):
        publish_df = pd.read_csv(publish_time_path)
        publish_df['published_at'] = pd.to_datetime(publish_df['published_at'])
        publish_dict = dict(zip(publish_df['aid'], publish_df['published_at']))
        series_dict = {}
        for filename in os.listdir(self.data_dir):
            if not filename.endswith('.json'):
                continue
            with open(os.path.join(self.data_dir, filename), 'r') as f:
                data = json.load(f)
                if 'code' in data:
                    continue
                for item in data:
                    aid = item['aid']
                    published_time = pd.to_datetime(publish_dict[aid]).timestamp()
                    if aid not in series_dict:
                        series_dict[aid] = {
                            'abs_time': [],
                            'play_count': [],
                            'create_time': published_time
                        }
                    series_dict[aid]['abs_time'].append(item['added'])
                    series_dict[aid]['play_count'].append(item['view'])
        # Sort each series by absolute time
        for aid in series_dict:
            sorted_indices = sorted(range(len(series_dict[aid]['abs_time'])),
                                    key=lambda k: series_dict[aid]['abs_time'][k])
            series_dict[aid]['abs_time'] = [series_dict[aid]['abs_time'][i] for i in sorted_indices]
            series_dict[aid]['play_count'] = [series_dict[aid]['play_count'][i] for i in sorted_indices]
        return series_dict

    def __len__(self):
        return 100000  # Virtual length for sampling

    def _get_nearest_value(self, series, target_time, current_idx):
        times = series['abs_time']
        pos = bisect.bisect_right(times, target_time, 0, current_idx + 1)
        candidates = []
        if pos > 0:
            candidates.append(pos - 1)
        if pos <= current_idx:
            candidates.append(pos)
        if not candidates:
            return None
        closest_idx = min(candidates, key=lambda i: abs(times[i] - target_time))
        return closest_idx

    def __getitem__(self, _idx):
        while True:
            series = random.choice(self.valid_series)
            if len(series['abs_time']) < 2:
                continue
            current_idx = random.randint(0, len(series['abs_time']) - 2)
            current_time = series['abs_time'][current_idx]
            max_target_time = current_time + self.time_window
            candidate_indices = []
            for j in range(current_idx + 1, len(series['abs_time'])):
                if series['abs_time'][j] > max_target_time:
                    break
                candidate_indices.append(j)
            if not candidate_indices:
                continue
            target_idx = random.choice(candidate_indices)
            break
        current_play = series['play_count'][current_idx]
        target_play = series['play_count'][target_idx]
        target_delta = max(target_play - current_play, 0)
        return {
            'features': torch.FloatTensor(self._extract_features(series, current_idx, target_idx)),
            'target': torch.log2(torch.FloatTensor([target_delta]) + 1)
        }


def collate_fn(batch):
    return {
        'features': torch.stack([x['features'] for x in batch]),
        'targets': torch.stack([x['target'] for x in batch])
    }
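A quick sketch of how this dataset might be consumed; the paths and batch size are assumptions for illustration, not values pinned down by this commit:

from torch.utils.data import DataLoader

dataset = VideoPlayDataset(
    data_dir='./data/pred',                            # assumed: the crawler's output directory
    publish_time_path='./data/pred/publish_time.csv',  # assumed CSV with 'aid' and 'published_at' columns
    term='long',
)
loader = DataLoader(dataset, batch_size=256, collate_fn=collate_fn)

batch = next(iter(loader))
print(batch['features'].shape)  # (256, 10): 1 time-delta + 1 current-views + 7 growth + 1 time feature
print(batch['targets'].shape)   # (256, 1): log2(play-count delta + 1)

Note that __len__ returns a fixed virtual length, so an epoch is 100000 random (video, interval) samples rather than one pass over the files.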
28
ml/pred/export_onnx.py
Normal file
@ -0,0 +1,28 @@
import torch
import torch.onnx
from model import CompactPredictor


def export_model(input_size, checkpoint_path, onnx_path):
    model = CompactPredictor(input_size)
    model.load_state_dict(torch.load(checkpoint_path))

    dummy_input = torch.randn(1, input_size)

    model.eval()

    torch.onnx.export(model,                     # Model to be exported
                      dummy_input,               # Model input
                      onnx_path,                 # Save path
                      export_params=True,        # Whether to export model parameters
                      opset_version=11,          # ONNX opset version
                      do_constant_folding=True,  # Whether to perform constant folding optimization
                      input_names=['input'],     # Input node name
                      output_names=['output'],   # Output node name
                      dynamic_axes={'input': {0: 'batch_size'},  # Dynamic batch size
                                    'output': {0: 'batch_size'}})

    print(f"ONNX model has been exported to: {onnx_path}")


if __name__ == '__main__':
    export_model(10, './pred/checkpoints/long_term.pt', 'long_term.onnx')
    export_model(12, './pred/checkpoints/short_term.pt', 'short_term.onnx')
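To sanity-check an exported graph, something like the following could be used; onnxruntime is an assumed extra dependency, not one this commit adds:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('long_term.onnx')
x = np.random.randn(4, 10).astype(np.float32)  # 10 = long-term input size; batch of 4 exercises the dynamic axis
(y,) = sess.run(['output'], {'input': x})
print(y.shape)  # (4, 1): predicted log2(play-count delta + 1)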
32
ml/pred/inference.py
Normal file
@ -0,0 +1,32 @@
import datetime
import numpy as np
from model import CompactPredictor
import torch


def main():
    model = CompactPredictor(10).to('cpu', dtype=torch.float32)
    model.load_state_dict(torch.load('./pred/checkpoints/long_term.pt'))
    model.eval()
    # inference
    initial = 997029
    last = initial
    start_time = '2025-03-17 00:13:17'
    for i in range(1, 120):
        hour = i / 0.5  # 2-hour steps
        sec = hour * 3600
        time_d = np.log2(sec)
        data = [time_d, np.log2(initial + 1),  # time_delta, current_views
                6.111542, 8.404707, 10.071566, 11.55888, 12.457823,  # growth features
                0.009225, 0.001318, 28.001814  # time features
                ]
        np_arr = np.array([data])
        tensor = torch.from_numpy(np_arr).to('cpu', dtype=torch.float32)
        output = model(tensor)
        num = output.detach().numpy()[0][0]
        views_pred = int(np.exp2(num)) + initial
        current_time = datetime.datetime.strptime(start_time, '%Y-%m-%d %H:%M:%S') + datetime.timedelta(hours=hour)
        print(current_time.strftime('%m-%d %H:%M:%S'), views_pred, views_pred - last)
        last = views_pred


if __name__ == '__main__':
    main()
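model.py, which defines CompactPredictor, is not among the files shown in this diff. A hypothetical stand-in consistent with how the class is used above (input_size features in, one scalar out); the layer sizes are invented for illustration and are not the committed architecture:

import torch
from torch import nn

class CompactPredictor(nn.Module):
    # Hypothetical stand-in, not the committed model: a small MLP that
    # maps input_size features to a single log2(play-count delta) value.
    def __init__(self, input_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)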
Some files were not shown because too many files have changed in this diff.