Compare commits: v1-dataset...main (17 commits)

Commits (SHA1):
fdff155673
c4ca9c7d4f
65123d1b39
aeab34f84b
37d2507f10
33754146c8
ae6f10a6f0
bf2c9a393a
853d158c41
9f071ee0a0
66cf093177
237d2f5c96
01597c298d
435faa4b92
580753bb6f
6500e378be
3bb222bda1
.gitignore (vendored): 11

@@ -9,9 +9,14 @@ token_to_id.json
 __pycache__
 .env
 .env*
-translate/output
-translate/source
+translate/output*
+translate/source*
 translate/result
 *.db
 dataset/raw
 translate/special-spiders
+ugNMT/BPE/output*
+ugNMT/BPE/codes
+forced-alignment/segments
+forced-alignment/data
+forced-alignment/output.ttml
README.md: 20

@@ -1,18 +1,28 @@
 # sparkastML

-This repository houses the machine learning components for the [sparkast](https://github.com/alikia2x/sparkast) project.
+This repository contains the machine learning components for the [sparkast](https://github.com/alikia2x/sparkast) project.

-The primary objective of this project is to enhance the search functionality of sparkast, allowing users to receive real-time answers as they type their queries.
+The main goal of this project is to improve the search functionality of sparkast, enabling users to receive real-time answers as they type their queries.

 ## Intention Classification

-The model located in the `/intention-classify` directory is designed to categorize user queries into predefined classes.
+The model in the `/intention-classify` directory is designed to categorize user queries into predefined classes.

-We utilize a Convolutional Neural Network (CNN) architecture in conjunction with an Energy-based Model for open-set recognition.
+We use a Convolutional Neural Network (CNN) architecture combined with an Energy-based Model for open-set recognition.

 This model is optimized to be lightweight, ensuring it can run on a wide range of devices, including within the browser environment.

-For a detailed explanation of how it works, you can refer to [this blog post](https://blog.alikia2x.com/en/posts/sparkastml-intention/).
+For a detailed explanation of how it works, refer to [this blog post](https://blog.alikia2x.com/en/posts/sparkastml-intention/).

+## Translation
+
+Language barriers are one of the biggest obstacles to communication between civilizations. In modern times, with the development of computer science and artificial intelligence, machine translation is bridging this gap and building a modern Tower of Babel.
+
+Unfortunately, many machine translation systems are owned by commercial companies, which seriously hinders the development of freedom and innovation.
+
+Therefore, sparkastML is on a mission to challenge commercial machine translation. We decided to tackle the translation between Chinese and English first. These are two languages with a long history and a large number of users. Their writing methods and expression habits are very different, which brings challenges to the project.
+
+For more details, visit [this page](./translate/README.md).
+
 ## Dataset

@@ -8,10 +8,15 @@ This dataset features high-quality, fresh synthetic data comprising over 100,000

 ### Details

+- **Source Language:** Chinese (Simplified)
+- **Target Language:** English
 - **Version:** 1
 - **Last Update:** 2024/09/16
+- **LICENSE:** [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/)

-### Download Links
+### Download

-- **Google Drive:** [Download from Google Drive](https://drive.google.com/drive/folders/1_ADblZcB5p9BUvawkYDmp1qIUDZgkkoe?usp=sharing)
+- **Google Drive:** [Download from Google Drive](https://drive.google.com/drive/folders/1_ADblZcB5p9BUvawkYDmp1qIUDZgkkoe)
 - **IPFS:** [Download from IPFS](https://ipfs.a2x.pub/ipfs/QmYz4ew4nSzPc6TZvoWk6jXpGN82qt3J46nwfb75N2YKc4/)
+  - CID: `QmYz4ew4nSzPc6TZvoWk6jXpGN82qt3J46nwfb75N2YKc4`
+- **GitHub Release:** [Go to Release Page](https://github.com/alikia2x/sparkastML/releases/tag/v1-dataset)
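The README above mentions an Energy-based Model used on top of the CNN classifier for open-set recognition, but the diff itself does not show how that score is computed. The sketch below is only a rough illustration of the general idea; the logits, threshold, and rejection rule are placeholders, not the repository's actual implementation. A query whose logits carry little evidence for any known class gets a high free energy and is rejected as out-of-scope.

# Illustrative sketch only: energy-based open-set rejection on top of a
# classifier's logits. Values and the rejection convention are assumptions.
import torch

def free_energy(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # E(x) = -T * logsumexp(logits / T); lower energy means "more in-distribution"
    return -temperature * torch.logsumexp(logits / temperature, dim=-1)

def classify_with_rejection(logits: torch.Tensor, threshold: float):
    # Reject a query as "unknown intent" when its energy exceeds the threshold.
    energies = free_energy(logits)
    preds = logits.argmax(dim=-1)
    return [(-1 if e > threshold else int(p)) for e, p in zip(energies, preds)]

if __name__ == "__main__":
    fake_logits = torch.tensor([[4.0, 0.1, 0.2],   # confident: keeps its class
                                [0.3, 0.2, 0.1]])  # flat: rejected as unknown
    print(classify_with_rejection(fake_logits, threshold=-1.5))  # [0, -1]

The commit does store a rejection threshold (1.7) in intention-classify/NLU_meta.json later in this diff; how that value maps onto a particular energy definition is not visible here.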
forced-alignment/README.md: 7 (new file)

# Forced Alignment for Word-Level Lyric Alignment

This sub-project was created to give [AquaVox](https://github.com/alikia2x/aquavox) an AI-powered word-by-word lyrics feature.

## Plan

For the given lyrics and
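The scripts added below (split.py, split_whole.py, test_split.py) are all built around torchaudio's MMS_FA forced-alignment pipeline, romanizing Chinese lyrics with pypinyin before alignment. The following is a minimal, self-contained sketch of that pattern; the audio path and the lyric string are placeholders, and the frame-to-seconds conversion mirrors what the repository's own scripts do.

# Minimal sketch of the forced-alignment call the scripts below build on.
import torch
import torchaudio
from torchaudio.pipelines import MMS_FA as bundle
from torchaudio.transforms import Resample
from pypinyin import lazy_pinyin

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = bundle.get_model().to(device)
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

# Placeholders: a short clip and the lyric text it contains.
waveform, sample_rate = torchaudio.load("line.wav")
waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform[0:1])

# Romanize the Chinese line so the multilingual aligner can tokenize it.
text = "谷雨"
transcript = " ".join(lazy_pinyin(text)).split()

with torch.inference_mode():
    emission, _ = model(waveform.to(device))
    token_spans = aligner(emission[0], tokenizer(transcript))

# Convert frame indices back to seconds (audio was resampled to 16 kHz).
ratio = waveform.size(1) / emission.size(1)
for char, spans in zip(text, token_spans):
    start = int(ratio * spans[0].start) / 16000
    end = int(ratio * spans[-1].end) / 16000
    print(f"{char}: {start:.3f}s - {end:.3f}s")

split.py extends this per-line alignment into a TTML file with word-level timestamps.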
forced-alignment/probs_distribution.ipynb: 191 (new file)
File diff suppressed because one or more lines are too long
forced-alignment/split.py: 228 (new file)
@@ -0,0 +1,228 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample
from tqdm import tqdm


def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
    with torch.inference_mode():
        emission, _ = model(waveform.to(device))
        token_spans = aligner(emission[0], tokenizer(transcript))
    return emission, token_spans

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle

model = bundle.get_model()
model.to(device)

tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

cc_cedict.load()

add_spaces = lambda s: ' '.join(s)

from pydub import AudioSegment

def get_audio_duration(file_path):
    """
    Read an audio file and return its duration in seconds.

    :param file_path: path to the audio file
    :return: duration of the audio file in seconds
    """
    try:
        audio = AudioSegment.from_file(file_path)
        duration_in_seconds = len(audio) / 1000.0
        return duration_in_seconds
    except Exception as e:
        print(f"Error reading audio file: {e}")
        return None

def timestamp(seconds):
    """
    Convert a floating-point number of seconds into the TTML timestamp format (HH:MM:SS.sss).

    :param seconds: seconds as a float
    :return: TTML timestamp string
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60
    milliseconds = int((seconds % 1) * 1000)
    seconds = int(seconds)

    return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"

def timestamp_inverse(ttml_timestamp):
    """
    Convert a TTML timestamp string (HH:MM:SS.sss) into a floating-point number of seconds.

    :param ttml_timestamp: TTML timestamp string
    :return: seconds as a float
    """
    parts = ttml_timestamp.split(':')
    hours = int(parts[0])
    minutes = int(parts[1])
    seconds_and_milliseconds = parts[2].split('.')
    seconds = int(seconds_and_milliseconds[0])
    milliseconds = int(seconds_and_milliseconds[1])

    total_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000

    return total_seconds

import xml.etree.ElementTree as ET

import os
import re

def extract_numbers_from_files(directory):
    """
    Read the given directory, extract the numeric part of each file name,
    and return a list of those numbers.

    :param directory: directory path
    :return: list of numbers
    """
    numbers = []
    pattern = re.compile(r'line-(\d+)\.wav')

    try:
        for filename in os.listdir(directory):
            match = pattern.match(filename)
            if match:
                number = int(match.group(1))
                numbers.append(number)
    except Exception as e:
        print(f"Error reading directory: {e}")
        return None

    return numbers

class TTMLGenerator:
    def __init__(self, duration, xmlns="http://www.w3.org/ns/ttml", xmlns_ttm="http://www.w3.org/ns/ttml#metadata", xmlns_amll="http://www.example.com/ns/amll", xmlns_itunes="http://music.apple.com/lyric-ttml-internal"):
        self.tt = ET.Element("tt", attrib={
            "xmlns": xmlns,
            "xmlns:ttm": xmlns_ttm,
            "xmlns:amll": xmlns_amll,
            "xmlns:itunes": xmlns_itunes
        })
        self.head = ET.SubElement(self.tt, "head")
        self.metadata = ET.SubElement(self.head, "metadata")
        self.body = ET.SubElement(self.tt, "body", attrib={"dur": duration})
        self.div = ET.SubElement(self.body, "div")

    def add_lyrics(self, begin, end, agent, itunes_key, words):
        p = ET.SubElement(self.div, "p", attrib={
            "begin": begin,
            "end": end,
            "ttm:agent": agent,
            "itunes:key": itunes_key
        })
        for word, start, stop in words:
            span = ET.SubElement(p, "span", attrib={"begin": start, "end": stop})
            span.text = word

    def save(self, filename):
        tree = ET.ElementTree(self.tt)
        tree.write(filename, encoding="utf-8", xml_declaration=True)

duration = get_audio_duration("./data/谷雨.mp3")

# Example usage
ttml_generator = TTMLGenerator(duration=timestamp(duration))


def process_line(line_idx, start_time, total_lines):
    with open(f"./segments/line-{line_idx}.txt", "r") as f:
        text = f.read()

    waveform, sample_rate = torchaudio.load(f"./segments/line-{line_idx}.wav")

    waveform = waveform[0:1]
    resampler = Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)

    text_pinyin = lazy_pinyin(text)
    text_normalized = " ".join(text_pinyin)

    transcript = text_normalized.split()
    emission, token_spans = compute_alignments(waveform, transcript)
    num_frames = emission.size(1)
    ratio = waveform.size(1) / num_frames

    words = []
    for i in range(len(token_spans)):
        spans = token_spans[i]
        x0 = start_time + int(ratio * spans[0].start) / 16000
        x1 = start_time + int(ratio * spans[-1].end) / 16000
        words.append({
            "word": text[i],
            "start": x0,
            "end": x1
        })
    # Stretch each word's end to the next word's start so the spans are contiguous.
    idx = 0
    for item in words:
        if idx == len(words) - 1:
            break
        item["end"] = words[idx + 1]["start"]
        idx += 1
    result = []
    for word in words:
        result.append((word["word"], timestamp(word["start"]), timestamp(word["end"])))
    return result


lines_to_process = sorted(extract_numbers_from_files("segments"))

def parse_lrc(lrc_file, audio_len):
    """Parse an LRC file and return a list of lyrics with start and end timestamps."""
    with open(lrc_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    lrc_data = []
    for line in lines:
        # Match the timestamp and lyric with a regular expression
        match = re.match(r'\[(\d+):(\d+\.\d+)\](.*)', line)
        if match:
            minutes = int(match.group(1))
            seconds = float(match.group(2))
            lyric = match.group(3).strip()
            lyric = lyric.replace(" ", "")
            timestamp = minutes * 60 + seconds
            lrc_data.append((lyric, timestamp))

    for i, (lyric, start_time) in enumerate(lrc_data):
        # Skip empty line
        if lyric.strip() == "":
            continue
        if i < len(lrc_data) - 1:
            end_time = lrc_data[i + 1][1]
        else:
            end_time = audio_len
        lrc_data[i] = (lyric, start_time, end_time)

    # Filter empty lines again
    lrc_data = [line for line in lrc_data if line[0].strip() != ""]

    return lrc_data

lrc_data = parse_lrc("./data/谷雨.lrc", duration)

i = 0
for line_num in tqdm(lines_to_process):
    start_time = lrc_data[i][1]
    end_time = lrc_data[i][2]
    result = process_line(line_num, start_time, len(lines_to_process))
    ttml_generator.add_lyrics(
        begin=timestamp(start_time), end=timestamp(end_time), agent="v1", itunes_key=f"L{i+1}",
        words=result
    )
    i += 1

# Save the output file
ttml_generator.save("output.ttml")
forced-alignment/split_song.py: 57 (new file)
@@ -0,0 +1,57 @@
from pydub import AudioSegment
import re

def parse_lrc(lrc_file):
    """Parse an LRC file and return a list of (timestamp, lyric) pairs."""
    with open(lrc_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    lrc_data = []
    for line in lines:
        # Match the timestamp and lyric with a regular expression
        match = re.match(r'\[(\d+):(\d+\.\d+)\](.*)', line)
        if match:
            minutes = int(match.group(1))
            seconds = float(match.group(2))
            lyric = match.group(3).strip()
            lyric = lyric.replace(" ", "")
            timestamp = minutes * 60 + seconds
            lrc_data.append((timestamp, lyric))

    return lrc_data

def split_audio_by_lrc(audio_file, lrc_data, output_prefix):
    """Split the audio file according to the LRC data and save each line as a separate WAV file."""
    audio = AudioSegment.from_file(audio_file)

    for i, (start_time, lyric) in enumerate(lrc_data):
        # Skip empty line
        if lyric.strip() == "":
            continue
        if i < len(lrc_data) - 1:
            end_time = lrc_data[i + 1][0]
        else:
            end_time = len(audio) / 1000  # last lyric line runs to the end of the audio
        start_time = max(0, start_time - 0.1)  # pad 0.1 s on each side
        end_time = min(len(audio) / 1000, end_time + 0.1)
        start_time_ms = start_time * 1000
        end_time_ms = end_time * 1000

        segment = audio[start_time_ms:end_time_ms]
        output_file = f"{output_prefix}-{i+1}.wav"
        output_script = f"{output_prefix}-{i+1}.txt"
        output_time = f"{output_prefix}-{i+1}.time"
        segment.export(output_file, format="wav")
        with open(output_script, "w") as f:
            f.write(lyric)
        with open(output_time, "w") as f:
            f.write(str(start_time)+","+str(end_time))
        print(f"Saved {output_file}")

if __name__ == "__main__":
    lrc_file = "./data/谷雨.lrc"  # path to the LRC file
    audio_file = "./data/谷雨.mp3"  # path to the audio file
    output_prefix = "segments/line"  # prefix for output file names

    lrc_data = parse_lrc(lrc_file)
    split_audio_by_lrc(audio_file, lrc_data, output_prefix)
forced-alignment/split_whole.py: 88 (new file)
@@ -0,0 +1,88 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample
import xml.etree.ElementTree as ET
import argparse

from pydub import AudioSegment

def get_audio_duration(file_path):
    """
    Read an audio file and return its duration in seconds.

    :param file_path: path to the audio file
    :return: duration of the audio file in seconds
    """
    try:
        audio = AudioSegment.from_file(file_path)
        duration_in_seconds = len(audio) / 1000.0
        return duration_in_seconds
    except Exception as e:
        print(f"Error reading audio file: {e}")
        return None

def timestamp(seconds):
    """
    Convert a floating-point number of seconds into the SRT timestamp format (HH:MM:SS,sss).

    :param seconds: seconds as a float
    :return: SRT timestamp string
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60
    milliseconds = int((seconds % 1) * 1000)
    seconds = int(seconds)

    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"


def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
    with torch.inference_mode():
        emission, _ = model(waveform.to(device))
        token_spans = aligner(emission[0], tokenizer(transcript))
    return emission, token_spans

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle

model = bundle.get_model()
model.to(device)

tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

cc_cedict.load()

with open("./data/扬旗鸣鼓.txt", "r") as f:
    text = f.read()

text_pinyin = lazy_pinyin(text)
text_normalized = " ".join(text_pinyin)

waveform, sample_rate = torchaudio.load("./data/扬旗鸣鼓.wav")

waveform = waveform[0:1]
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)

transcript = text_normalized.split()
emission, token_spans = compute_alignments(waveform, transcript)
num_frames = emission.size(1)

ratio = waveform.size(1) / num_frames

duration = get_audio_duration("./data/扬旗鸣鼓.wav")

for i in range(len(token_spans)):
    spans = token_spans[i]
    x0 = round(int(ratio * spans[0].start) / 16000, 3)
    x1 = round(int(ratio * spans[-1].end) / 16000, 3)
    with open("1.srt", "a") as f:
        f.write(f"{i+1}\n")
        f.write(f"{timestamp(x0)} --> {timestamp(x1)}\n")
        f.write(f"{transcript[i]}\n\n")
        f.flush()
forced-alignment/srt2lrc.py: 84 (new file)
@@ -0,0 +1,84 @@
import re

def srt_to_lrc(srt_text):
    # Match timestamps and content with a regular expression
    # who the fuck knows this
    srt_text += '\n\n'
    pattern = re.compile(r'(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})\n(.+?)\n\n', re.DOTALL)
    matches = pattern.findall(srt_text)
    lrc_lines = []

    for start_time, end_time, content in matches:
        # Extract the highlighted character for the start time
        highlight_char = re.search(r'<font color="#00ff00">(.+?)</font>', content)
        if highlight_char:
            highlight_char = highlight_char.group(1)
        else:
            continue

        # Convert the timestamps to LRC format
        f, start_minutes, start_seconds, start_milliseconds = map(int, start_time.replace(',', ':').split(':'))
        f, end_minutes, end_seconds, end_milliseconds = map(int, end_time.replace(',', ':').split(':'))

        start_time_lrc = f"{start_minutes:02d}:{start_seconds:02d}.{start_milliseconds:02d}"
        end_time_lrc = f"{end_minutes:02d}:{end_seconds:02d}.{end_milliseconds:02d}"

        # Build the LRC line
        lrc_line = f"{highlight_char}|{start_time_lrc},{end_time_lrc}"
        lrc_lines.append(lrc_line)

        # Replace any newline in the content with a space
        lrc_line = lrc_line.replace('\n', ' ')

    return '\n'.join(lrc_lines)

with open('./data/谷雨.srt', 'r') as f:
    srt_text = f.read()

whole = srt_text.splitlines()[2].replace('<font color="#00ff00">','').replace('</font>','')
whole = whole.replace(' ','\n')
lines = whole.splitlines()

lyric_text = ""
raw_text = srt_to_lrc(srt_text)
raw_lines = raw_text.splitlines()
for line in raw_lines:
    lyric_text += line.split('|')[0]

raw_idx = 0
lines_start_chr_idx = []
for line in lines:
    start = line[0]
    end = line[-1]
    while raw_idx < len(raw_lines) and not line.startswith(raw_lines[raw_idx].split("|")[0]):
        raw_idx += 1
    lines_start_chr_idx.append(raw_idx)
lines_start_chr_idx.append(len(raw_lines)-1)

raw_idx = 0
lines_end_chr_idx = []
for line in lines:
    start = line[0]
    end = line[-1]
    while raw_idx < len(raw_lines) and not line.endswith(raw_lines[raw_idx].split("|")[0]):
        raw_idx += 1
    lines_end_chr_idx.append(raw_idx)
lines_end_chr_idx.append(len(raw_lines)-1)

lrc_text = ""
for i in range(len(lines_start_chr_idx)-1):
    start = lines_start_chr_idx[i]
    end = lines_end_chr_idx[i]
    time_start = raw_lines[start].split("|")[1].split(',')[0]
    time_end = raw_lines[end].split("|")[1].split(',')[0]
    lrc_text += f"[{time_start}]{lyric_text[start:end+1]}\n[{time_end}]\n"
print(lrc_text)

lyric_len = len(lyric_text)
for i in range(len(lines_start_chr_idx)-1):
    start = max(0, lines_start_chr_idx[i]-1)
    end = min(lyric_len-1, lines_end_chr_idx[i]+1)
    time_start = raw_lines[start].split("|")[1].split(',')[0]
    time_end = raw_lines[end].split("|")[1].split(',')[0]
    lrc_text += f"[{time_start}]{lyric_text[start:end+1]}\n[{time_end}]\n"
print(lrc_text)
forced-alignment/test.ipynb: 328 (new file)
File diff suppressed because one or more lines are too long
forced-alignment/test_split.py: 60 (new file)
@@ -0,0 +1,60 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample


def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
    with torch.inference_mode():
        emission, _ = model(waveform.to(device))
        token_spans = aligner(emission[0], tokenizer(transcript))
    return emission, token_spans

# Compute average score weighted by the span length
def _score(spans):
    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle

model = bundle.get_model()
model.to(device)

tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

cc_cedict.load()

add_spaces = lambda s: ' '.join(s)

with open("./segments/line-1.txt", "r") as f:
    text = f.read()

text_raw = add_spaces(text)
text_list = list(text)
text_pinyin = lazy_pinyin(text)
text_normalized = " ".join(text_pinyin)

waveform, sample_rate = torchaudio.load("./segments/line-1.wav")

waveform = waveform[0:1]
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)

transcript = text_normalized.split()
emission, token_spans = compute_alignments(waveform, transcript)
num_frames = emission.size(1)

print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)

ratio = waveform.size(1) / num_frames

for i in range(len(token_spans)):
    spans = token_spans[i]
    x0 = round(int(ratio * spans[0].start) / 16000, 3)
    x1 = round(int(ratio * spans[-1].end) / 16000, 3)
    print(f"{text[i]}: {x0}-{x1}")
forced-alignment/ttml.py: 30 (new file)
@@ -0,0 +1,30 @@
import xml.etree.ElementTree as ET

class TTMLGenerator:
    def __init__(self, duration, xmlns="http://www.w3.org/ns/ttml", xmlns_ttm="http://www.w3.org/ns/ttml#metadata", xmlns_amll="http://www.example.com/ns/amll", xmlns_itunes="http://music.apple.com/lyric-ttml-internal"):
        self.tt = ET.Element("tt", attrib={
            "xmlns": xmlns,
            "xmlns:ttm": xmlns_ttm,
            "xmlns:amll": xmlns_amll,
            "xmlns:itunes": xmlns_itunes
        })
        self.head = ET.SubElement(self.tt, "head")
        self.metadata = ET.SubElement(self.head, "metadata")
        self.body = ET.SubElement(self.tt, "body", attrib={"dur": duration})
        self.div = ET.SubElement(self.body, "div")

    def add_lyrics(self, begin, end, agent, itunes_key, words):
        p = ET.SubElement(self.div, "p", attrib={
            "begin": begin,
            "end": end,
            "ttm:agent": agent,
            "itunes:key": itunes_key
        })
        for word, start, stop in words:
            span = ET.SubElement(p, "span", attrib={"begin": start, "end": stop})
            span.text = word

    def save(self, filename):
        tree = ET.ElementTree(self.tt)
        tree.write(filename, encoding="utf-8", xml_declaration=True)
intention-classify/NLU_meta.json: 1 (new file)
@@ -0,0 +1 @@
{"idx_to_class": {"0": "weather", "1": "base64", "2": "url-encode", "3": "html-encode", "4": "ai.command", "5": "knowledge", "6": "ai.question", "7": "datetime"}, "threshold": 1.7}
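This metadata maps class indices to intent names and stores the open-set rejection threshold. A small sketch of how a client might consume the file follows; the logits come from a hypothetical classifier call, and the way the stored threshold is compared against an energy score is an assumption rather than something shown in this diff.

# Sketch of consuming NLU_meta.json; the rejection convention is assumed.
import json
import torch

with open("intention-classify/NLU_meta.json") as f:
    meta = json.load(f)

idx_to_class = {int(k): v for k, v in meta["idx_to_class"].items()}
threshold = meta["threshold"]  # 1.7 in this commit

def interpret(logits: torch.Tensor) -> str:
    # Hypothetical energy-based rejection, as sketched after the README above.
    energy = -torch.logsumexp(logits, dim=-1)
    if energy.item() > threshold:
        return "unknown"
    return idx_to_class[int(logits.argmax().item())]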
@ -36,6 +36,7 @@
|
|||||||
"室外的温度是多少",
|
"室外的温度是多少",
|
||||||
"达拉斯今天热不热",
|
"达拉斯今天热不热",
|
||||||
"苏州现在天气怎么样",
|
"苏州现在天气怎么样",
|
||||||
|
"明天悉尼会下雨吗?",
|
||||||
"how's the weather",
|
"how's the weather",
|
||||||
"What's going on with the weather?",
|
"What's going on with the weather?",
|
||||||
"Can you give me an update on the weather?",
|
"Can you give me an update on the weather?",
|
||||||
@ -48,21 +49,21 @@
|
|||||||
"What's the weather like right now?",
|
"What's the weather like right now?",
|
||||||
"Tell me the current weather conditions.",
|
"Tell me the current weather conditions.",
|
||||||
"How about the weather today?",
|
"How about the weather today?",
|
||||||
"What's the weather looking like for the next few hours?",
|
"What's the weather looking like for the next few hours",
|
||||||
"Is it going to stay this way all day?",
|
"Is it going to stay this way all day",
|
||||||
"Could you give me a brief overview of the weather?",
|
"Could you give me a brief overview of the weather",
|
||||||
"What's the general weather situation in our area?",
|
"What's the general weather situation in our area",
|
||||||
"Is it cloudy or clear outside?",
|
"Is it cloudy or clear outside",
|
||||||
"What's the forecast saying for today's weather?",
|
"What's the forecast saying for today's weather",
|
||||||
"Is it going to be a warm day?",
|
"Is it going to be a warm day",
|
||||||
"Are we expecting any storms today?",
|
"Are we expecting any storms today",
|
||||||
"What's the weather condition outside my window?",
|
"What's the weather condition outside my window",
|
||||||
"Is it a typical day for this season in terms of weather?",
|
"Is it a typical day for this season in terms of weather",
|
||||||
"how's the weather now?",
|
"how's the weather now",
|
||||||
"What's the temperature like right now?",
|
"What's the temperature like right now",
|
||||||
"Can you tell me the current temperature?",
|
"Can you tell me the current temperature",
|
||||||
"How hot is it outside?",
|
"How hot is it outside",
|
||||||
"What's the temperature supposed to be today?",
|
"What's the temperature supposed to be today",
|
||||||
"What is the current temp outside?",
|
"What is the current temp outside?",
|
||||||
"Could you tell me the outdoor temperature?",
|
"Could you tell me the outdoor temperature?",
|
||||||
"Is it cold or warm outside?",
|
"Is it cold or warm outside?",
|
||||||
@ -81,8 +82,8 @@
|
|||||||
"Can you tell me the temp in the nearby area?",
|
"Can you tell me the temp in the nearby area?",
|
||||||
"Is it below freezing outside?",
|
"Is it below freezing outside?",
|
||||||
"What's the average temperature for today?",
|
"What's the average temperature for today?",
|
||||||
"Is the temperature dropping or rising?",
|
"Is the temperature dropping or rising",
|
||||||
"What should I wear considering the temperature?"
|
"What should I wear considering the temperature"
|
||||||
],
|
],
|
||||||
"base64": [
|
"base64": [
|
||||||
"请将数据使用base64编码",
|
"请将数据使用base64编码",
|
||||||
@ -110,17 +111,16 @@
|
|||||||
"解码 base64",
|
"解码 base64",
|
||||||
"Please encode this data with base64:",
|
"Please encode this data with base64:",
|
||||||
"I need to encode the following data in base64",
|
"I need to encode the following data in base64",
|
||||||
"Could you encode this string using base64?",
|
"Could you encode this string using base64",
|
||||||
"Convert this data to b64 encoding",
|
"Convert this data to b64 encoding",
|
||||||
"I want to encode this information with base64",
|
"I want to encode this information with base64",
|
||||||
"Help me encode this in base32",
|
"Help me encode this in base32",
|
||||||
"Can you encode this data to base64 format?",
|
"Can you encode this data to base64 format",
|
||||||
"b64 encode",
|
"b64 encode",
|
||||||
"base64 encode",
|
"base64 encode",
|
||||||
"encode base64",
|
"encode base64",
|
||||||
"base 64 encode online"
|
"base 64 encode online"
|
||||||
],
|
],
|
||||||
|
|
||||||
"url-encode": [
|
"url-encode": [
|
||||||
"编码 url",
|
"编码 url",
|
||||||
"URL部分需要编码",
|
"URL部分需要编码",
|
||||||
@ -145,7 +145,6 @@
|
|||||||
"url decoder",
|
"url decoder",
|
||||||
"URL encoder"
|
"URL encoder"
|
||||||
],
|
],
|
||||||
|
|
||||||
"html-encode": [
|
"html-encode": [
|
||||||
"请编码HTML实体",
|
"请编码HTML实体",
|
||||||
"文本转为HTML实体",
|
"文本转为HTML实体",
|
||||||
@ -186,7 +185,6 @@
|
|||||||
"html   conversion",
|
"html   conversion",
|
||||||
"html nbsp meaning"
|
"html nbsp meaning"
|
||||||
],
|
],
|
||||||
|
|
||||||
"ai.command": [
|
"ai.command": [
|
||||||
"写一个TypeScript的HelloWorld代码",
|
"写一个TypeScript的HelloWorld代码",
|
||||||
"检查以下内容的语法和清晰度",
|
"检查以下内容的语法和清晰度",
|
||||||
@ -237,11 +235,11 @@
|
|||||||
"help me learn chinese",
|
"help me learn chinese",
|
||||||
"how to let the screen reader automatically focused to an newly poped up element in the web development",
|
"how to let the screen reader automatically focused to an newly poped up element in the web development",
|
||||||
"summarize following text:",
|
"summarize following text:",
|
||||||
"Is there anything wrong with this code or can it be simplified?",
|
"Is there anything wrong with this code or can it be simplified",
|
||||||
"generate a Python script that prints 'Hello, World!'",
|
"generate a Python script that prints 'Hello, World!'",
|
||||||
"Can you proofread this essay for grammar and punctuation errors?",
|
"Can you proofread this essay for grammar and punctuation errors",
|
||||||
"Create a list of ten example sentences for the word 'serendipity.'",
|
"Create a list of ten example sentences for the word 'serendipity.'",
|
||||||
"Can you reformat this JSON to be more readable?",
|
"Can you reformat this JSON to be more readable",
|
||||||
"Suggest a creative title for my blog post about healthy eating.",
|
"Suggest a creative title for my blog post about healthy eating.",
|
||||||
"Refactor this JavaScript function to make it more efficient.",
|
"Refactor this JavaScript function to make it more efficient.",
|
||||||
"Help me practice French: provide a sentence with a missing word that I can guess.",
|
"Help me practice French: provide a sentence with a missing word that I can guess.",
|
||||||
@ -249,15 +247,15 @@
|
|||||||
"Summarize this news article for me.",
|
"Summarize this news article for me.",
|
||||||
"Can you review this code snippet for potential security vulnerabilities?",
|
"Can you review this code snippet for potential security vulnerabilities?",
|
||||||
"Generate a SQL query to find all users who signed up in the last 30 days.",
|
"Generate a SQL query to find all users who signed up in the last 30 days.",
|
||||||
"Can you translate this paragraph into Spanish?",
|
"Can you translate this paragraph into Spanish",
|
||||||
"Create a flowchart based on the following process description.",
|
"Create a flowchart based on the following process description.",
|
||||||
"Write a Python function to calculate the factorial of a number.",
|
"Write a Python function to calculate the factorial of a number.",
|
||||||
"Provide a detailed explanation of how to implement OAuth2 in a web application.",
|
"Provide a detailed explanation of how to implement OAuth2 in a web application.",
|
||||||
"Can you optimize this image for faster loading on a website?",
|
"Can you optimize this image for faster loading on a website",
|
||||||
"Suggest some catchy taglines for a new mobile app focused on fitness.",
|
"Suggest some catchy taglines for a new mobile app focused on fitness.",
|
||||||
"Write a Bash script to back up my documents folder daily.",
|
"Write a Bash script to back up my documents folder daily.",
|
||||||
"Help me draft an email to request a meeting with a potential client.",
|
"Help me draft an email to request a meeting with a potential client.",
|
||||||
"Can you convert this Markdown document into HTML?",
|
"Can you convert this Markdown document into HTML",
|
||||||
"Generate a Python script that scrapes data from a specified website.",
|
"Generate a Python script that scrapes data from a specified website.",
|
||||||
"Can you find the synonyms of the word 'meticulous'?",
|
"Can you find the synonyms of the word 'meticulous'?",
|
||||||
"Write a SQL query to join two tables based on a common column.",
|
"Write a SQL query to join two tables based on a common column.",
|
||||||
@ -267,31 +265,57 @@
|
|||||||
"Can you assist me in learning Japanese?",
|
"Can you assist me in learning Japanese?",
|
||||||
"How can I make an alert box appear when a user clicks a button on a webpage?",
|
"How can I make an alert box appear when a user clicks a button on a webpage?",
|
||||||
"Summarize this research paper into bullet points.",
|
"Summarize this research paper into bullet points.",
|
||||||
"Can you check if there are any logical errors in this algorithm?"
|
"Can you check if there are any logical errors in this algorithm?",
|
||||||
|
"请一步一步计算找到函数f(x)=U^2*x/(R+x)^2的顶点坐标。",
|
||||||
|
"如何理解transformer自注意力机制中的Q,K,V?它们分别代表什么?",
|
||||||
|
"帮我写一封求职信。先询问我的教育背景、技能和经验。",
|
||||||
|
"总结这篇论文",
|
||||||
|
"写一份10人晚宴的菜单",
|
||||||
|
"写一篇博客",
|
||||||
|
"写一段演讲稿"
|
||||||
],
|
],
|
||||||
|
"knowledge": [
|
||||||
"ai.question": [
|
|
||||||
"你认为哪个框架最适合性能敏感的项目?",
|
|
||||||
"什么是后量子密码学?",
|
"什么是后量子密码学?",
|
||||||
"什么是密钥派生函数",
|
"什么是密钥派生函数",
|
||||||
"什么是线性代数?",
|
"什么是线性代数?",
|
||||||
|
"量子计算的特点是什么",
|
||||||
|
"哈希函数的作用?",
|
||||||
|
"什么是微积分?",
|
||||||
|
"什么是区块链技术",
|
||||||
|
"What is post-quantum cryptography",
|
||||||
|
"What is a key derivation function?",
|
||||||
|
"What is Linear Algebra?",
|
||||||
|
"What is the main use of linear algebra in computer science",
|
||||||
|
"What is quantum computing",
|
||||||
|
"What is a hash function",
|
||||||
|
"What is calculus",
|
||||||
|
"什么是站点隔离?",
|
||||||
|
"What is blockchain technology?",
|
||||||
|
"BLEU 是什么",
|
||||||
|
"黎巴嫩在哪",
|
||||||
|
"什么是转义字符",
|
||||||
|
"MixAlpha售价多少",
|
||||||
|
"什么是神经机器翻译",
|
||||||
|
"什么是月食",
|
||||||
|
"什么是人工智能",
|
||||||
|
"什么是F1-score"
|
||||||
|
],
|
||||||
|
"ai.question": [
|
||||||
|
"人工智能真的有智力吗",
|
||||||
|
"你认为哪个框架最适合性能敏感的项目?",
|
||||||
"线性代数在计算机科学中的主要用途是什么?",
|
"线性代数在计算机科学中的主要用途是什么?",
|
||||||
"我应该使用哪个IDE来编写Go语言?",
|
"我应该使用哪个IDE来编写Go语言?",
|
||||||
"Go vs Java vs Kotlin,哪个适合后端",
|
"Go vs Java vs Kotlin,哪个适合后端",
|
||||||
"哪种编程语言最适合数据分析",
|
"哪种编程语言最适合数据分析",
|
||||||
"什么是量子计算",
|
|
||||||
"什么是哈希函数?",
|
|
||||||
"什么是微积分?",
|
|
||||||
"机器学习在金融中的主要应用有哪些?",
|
"机器学习在金融中的主要应用有哪些?",
|
||||||
"写Python代码最好的文本编辑器是哪个?",
|
"写Python代码最好的文本编辑器是哪个?",
|
||||||
"Python vs R vs Julia,哪个更适合数据科学?",
|
"Python vs R vs Julia,哪个更适合数据科学?",
|
||||||
"监督学习和无监督学习的关键区别是什么?",
|
"监督学习和无监督学习的关键区别是什么?",
|
||||||
"数据库在Web应用程序中的作用是什么",
|
"数据库在Web应用程序中的作用是什么",
|
||||||
"什么是区块链技术",
|
|
||||||
"使用Docker进行应用程序部署的优势是什么?",
|
"使用Docker进行应用程序部署的优势是什么?",
|
||||||
"哪个云服务提供商提供最好的AI工具?",
|
"哪个云服务提供商提供最好的AI工具?",
|
||||||
"加密是如何工作的?",
|
"加密是如何工作的",
|
||||||
"负载均衡器在网络架构中的目的是什么?",
|
"负载均衡器在网络架构中的目的是什么",
|
||||||
"机器学习和深度学习有什么区别",
|
"机器学习和深度学习有什么区别",
|
||||||
"软件工程中最常见的设计模式有哪些",
|
"软件工程中最常见的设计模式有哪些",
|
||||||
"神经网络是如何学习的",
|
"神经网络是如何学习的",
|
||||||
@ -300,31 +324,22 @@
|
|||||||
"Rust编程语言的关键特性是什么?",
|
"Rust编程语言的关键特性是什么?",
|
||||||
"HTTP和HTTPS有什么区别",
|
"HTTP和HTTPS有什么区别",
|
||||||
"使用像Git这样的版本控制系统有什么优势?",
|
"使用像Git这样的版本控制系统有什么优势?",
|
||||||
"什么是'边缘计算'的概念",
|
|
||||||
"哪种编程语言最适合构建移动应用?",
|
"哪种编程语言最适合构建移动应用?",
|
||||||
"关系数据库和NoSQL数据库有什么不同?",
|
"关系数据库和NoSQL数据库有什么不同?",
|
||||||
"算法在计算机科学中的重要性是什么?",
|
"算法在计算机科学中的重要性是什么",
|
||||||
"API在软件开发中的作用是什么",
|
"API在软件开发中的作用是什么",
|
||||||
"保护Web应用程序的最佳实践是什么",
|
"保护Web应用程序的最佳实践是什么",
|
||||||
"虚拟现实和增强现实有什么区别?",
|
"虚拟现实和增强现实有什么区别?",
|
||||||
"机器翻译是如何工作的?",
|
"机器翻译是如何工作的?",
|
||||||
"Which framework do you think is the most suitable for performance sensitive projects?",
|
"Which framework do you think is the most suitable for performance sensitive projects?",
|
||||||
"What is post-quantum cryptography",
|
|
||||||
"What is a key derivation function?",
|
|
||||||
"What is Linear Algebra?",
|
|
||||||
"What is the main use of linear algebra in computer science",
|
|
||||||
"which IDE should I use for Go",
|
"which IDE should I use for Go",
|
||||||
"Go vs Java vs Koltin, which for a backend",
|
"Go vs Java vs Koltin, which for a backend",
|
||||||
"Which programming language is best suited for data analysis?",
|
"Which programming language is best suited for data analysis?",
|
||||||
"What is quantum computing?",
|
"What are the main applications of machine learning in finance",
|
||||||
"What is a hash function?",
|
"Which text editor is best for writing Python code",
|
||||||
"What is calculus?",
|
"Python vs R vs Julia, which is better for data science",
|
||||||
"What are the main applications of machine learning in finance?",
|
"What are the key differences between supervised and unsupervised learning",
|
||||||
"Which text editor is best for writing Python code?",
|
|
||||||
"Python vs R vs Julia, which is better for data science?",
|
|
||||||
"What are the key differences between supervised and unsupervised learning?",
|
|
||||||
"What is the role of a database in a web application?",
|
"What is the role of a database in a web application?",
|
||||||
"What is blockchain technology?",
|
|
||||||
"What are the advantages of using Docker for application deployment?",
|
"What are the advantages of using Docker for application deployment?",
|
||||||
"Which cloud service provider offers the best AI tools?",
|
"Which cloud service provider offers the best AI tools?",
|
||||||
"How does encryption work?",
|
"How does encryption work?",
|
||||||
@ -332,19 +347,20 @@
|
|||||||
"What is the difference between machine learning and deep learning?",
|
"What is the difference between machine learning and deep learning?",
|
||||||
"What are the most common design patterns in software engineering?",
|
"What are the most common design patterns in software engineering?",
|
||||||
"How does a neural network learn?",
|
"How does a neural network learn?",
|
||||||
"What is the main benefit of using a microservices architecture?",
|
"What is the main benefit of using a microservices architecture",
|
||||||
"What is the difference between a compiler and an interpreter?",
|
"What is the difference between a compiler and an interpreter",
|
||||||
"What are the key features of the Rust programming language?",
|
"What are the key features of the Rust programming language",
|
||||||
"What is the difference between HTTP and HTTPS?",
|
"What is the difference between HTTP and HTTPS",
|
||||||
"What are the advantages of using a version control system like Git?",
|
"What are the advantages of using a version control system like Git",
|
||||||
"What is the concept of 'edge computing'?",
|
"What is the concept of 'edge computing'",
|
||||||
"Which programming language is best for building mobile apps?",
|
"Which programming language is best for building mobile apps",
|
||||||
"How does a relational database differ from a NoSQL database?",
|
"How does a relational database differ from a NoSQL database",
|
||||||
"What is the importance of algorithms in computer science?",
|
"What is the importance of algorithms in computer science",
|
||||||
"What is the role of an API in software development?",
|
"What is the role of an API in software development",
|
||||||
"What are the best practices for securing a web application?",
|
"What are the best practices for securing a web application?",
|
||||||
"What is the difference between virtual reality and augmented reality?",
|
"What is the difference between virtual reality and augmented reality?",
|
||||||
"How does machine translation work?"
|
"How does machine translation work?",
|
||||||
|
"MBTI有科学依据吗?"
|
||||||
],
|
],
|
||||||
"datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
|
"datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
|
||||||
}
|
}
|
||||||
|
@ -28,7 +28,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model_name=\"microsoft/Phi-3-mini-4k-instruct\""
|
"model_name=\"Qwen/Qwen2.5-3B\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -37,17 +37,10 @@
|
|||||||
"id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
|
"id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"application/vnd.jupyter.widget-view+json": {
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
"model_id": "1aeb02c7c8084b1eb1b8e3178882fd60",
|
"model_id": "38137fc55ad24a9785ecbe1978bbc605",
|
||||||
"version_major": 2,
|
"version_major": 2,
|
||||||
"version_minor": 0
|
"version_minor": 0
|
||||||
},
|
},
|
||||||
@ -76,6 +69,122 @@
|
|||||||
"vocab = tokenizer.get_vocab()"
|
"vocab = tokenizer.get_vocab()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "21214ff4-018d-4230-81b9-331ebb42773b",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def bytes_to_unicode():\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control\n",
|
||||||
|
" characters the bpe code barfs on.\n",
|
||||||
|
"\n",
|
||||||
|
" The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab\n",
|
||||||
|
" if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for\n",
|
||||||
|
" decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup\n",
|
||||||
|
" tables between utf-8 bytes and unicode strings.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" bs = (\n",
|
||||||
|
" list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"¡\"), ord(\"¬\") + 1)) + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n",
|
||||||
|
" )\n",
|
||||||
|
" cs = bs[:]\n",
|
||||||
|
" n = 0\n",
|
||||||
|
" for b in range(2**8):\n",
|
||||||
|
" if b not in bs:\n",
|
||||||
|
" bs.append(b)\n",
|
||||||
|
" cs.append(2**8 + n)\n",
|
||||||
|
" n += 1\n",
|
||||||
|
" cs = [chr(n) for n in cs]\n",
|
||||||
|
" return dict(zip(bs, cs))"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "cbc23d2d-985b-443a-83ee-c2286046ad5e",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"btu=bytes_to_unicode()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"id": "4a99fa07-4922-4d8d-9c28-2275bf9cb8df",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"utb = reversed_dict = {value: key for key, value in btu.items()}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"id": "cb218ea7-50c7-4bb8-aa7f-0ee85da76147",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"result = tokenizer.convert_ids_to_tokens([104307])[0]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 27,
|
||||||
|
"id": "2dcb332a-cba9-4a14-9486-4e1ff6bd3dba",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"å\n",
|
||||||
|
"229\n",
|
||||||
|
"¤\n",
|
||||||
|
"164\n",
|
||||||
|
"©\n",
|
||||||
|
"169\n",
|
||||||
|
"æ\n",
|
||||||
|
"230\n",
|
||||||
|
"°\n",
|
||||||
|
"176\n",
|
||||||
|
"Ķ\n",
|
||||||
|
"148\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"decoded=b\"\"\n",
|
||||||
|
"for chr in result:\n",
|
||||||
|
" print(chr)\n",
|
||||||
|
" if chr in utb:\n",
|
||||||
|
" print(utb[chr])\n",
|
||||||
|
" decoded+=bytes([utb[chr]])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 29,
|
||||||
|
"id": "b1bf1289-2cab-4a97-ad21-b2d24de6d688",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"'天气'"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 29,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"decoded.decode(\"utf-8\", errors='replace')"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 5,
|
||||||
@ -95,7 +204,7 @@
|
|||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"DIMENSIONS = 128"
|
"DIMENSIONS = 96"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -168,11 +277,17 @@
|
|||||||
"import struct\n",
|
"import struct\n",
|
||||||
"with open(\"token_embeddings.bin\", \"wb\") as f:\n",
|
"with open(\"token_embeddings.bin\", \"wb\") as f:\n",
|
||||||
" for token_id in range(len(vocab)):\n",
|
" for token_id in range(len(vocab)):\n",
|
||||||
" # Write token id (2 bytes)\n",
|
" # 将向量转换为半精度浮点数并保存\n",
|
||||||
" f.write(struct.pack('H', token_id))\n",
|
" f.write(struct.pack('96e', *reduced_embeddings[token_id].astype(np.float16)))\n"
|
||||||
" # Write embedding vector (128 float numbers)\n",
|
|
||||||
" f.write(struct.pack('128f', *reduced_embeddings[token_id]))"
|
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "511a7cc4-1b8c-468c-b2a0-16dc6d74ab44",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
@ -191,7 +306,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.9.19"
|
"version": "3.10.14"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
@ -4,5 +4,350 @@
|
|||||||
"我爱你",
|
"我爱你",
|
||||||
"嘿嘿嘿诶嘿",
|
"嘿嘿嘿诶嘿",
|
||||||
"为什么",
|
"为什么",
|
||||||
"拼多多"
|
"拼多多",
|
||||||
]
|
"machine translation",
|
||||||
|
"trustrank",
|
||||||
|
"中文词典",
|
||||||
|
"bin screen linux",
|
||||||
|
"\"TinyBERT",
|
||||||
|
"iconify",
|
||||||
|
"反义词 英文",
|
||||||
|
"referer",
|
||||||
|
"watchos uiscreen",
|
||||||
|
"张鑫旭",
|
||||||
|
"google first result",
|
||||||
|
"flutter text align center",
|
||||||
|
"ASR model",
|
||||||
|
"real time whisper",
|
||||||
|
"千樱凛",
|
||||||
|
"马嘉祺",
|
||||||
|
"flutter widget catalog",
|
||||||
|
"flutter BottomNavigationBar left",
|
||||||
|
"flutter tab indent vscode",
|
||||||
|
"react native 用 expo 吗",
|
||||||
|
"latest monorepo tool",
|
||||||
|
"\"vite\" \"abortController\" is not defined",
|
||||||
|
"vim comment lines",
|
||||||
|
"Error: unable to get issuer certificate",
|
||||||
|
"uuidv4",
|
||||||
|
"npm semver",
|
||||||
|
"react polyfill vite",
|
||||||
|
"vibrance",
|
||||||
|
"I can eat glass, it doesn't hurt me \"japanese\"",
|
||||||
|
"I can swallow glass without any harm to myself",
|
||||||
|
"copilot pricing",
|
||||||
|
"vim close window",
|
||||||
|
"sensors macos command",
|
||||||
|
"智乃",
|
||||||
|
"pypi wikipedia",
|
||||||
|
"tesseract macos m1",
|
||||||
|
"rag prompt template",
|
||||||
|
"英国 破产",
|
||||||
|
"bewlybewly",
|
||||||
|
"safari-web-extension-converter",
|
||||||
|
"starcoder",
|
||||||
|
"open source web search for ai",
|
||||||
|
"gpt4o mini tokenizer",
|
||||||
|
"gpt4o tokenizer",
|
||||||
|
"reverse dns lookup linux",
|
||||||
|
"online ping",
|
||||||
|
"termux",
|
||||||
|
"802.11 table",
|
||||||
|
"optimize",
|
||||||
|
"集群",
|
||||||
|
"chrome us",
|
||||||
|
"transflective",
|
||||||
|
"ielts toefl",
|
||||||
|
"react router",
|
||||||
|
"摇曳露营 萌娘百科",
|
||||||
|
"isrc",
|
||||||
|
"apple-system",
|
||||||
|
"-apple-system",
|
||||||
|
"css clip path animation",
|
||||||
|
"can i use relative path in og image",
|
||||||
|
"GitSora",
|
||||||
|
"matrix im",
|
||||||
|
"test your vocabulary",
|
||||||
|
"boarding pass",
|
||||||
|
"函数签名",
|
||||||
|
"类型谓词",
|
||||||
|
"barcode",
|
||||||
|
"智能",
|
||||||
|
"threejs 入门",
|
||||||
|
"南亚语系",
|
||||||
|
"linux user's computer be like",
|
||||||
|
"apple a16 显微图",
|
||||||
|
"dallas",
|
||||||
|
"恶魔 英文",
|
||||||
|
"Rime meaning",
|
||||||
|
"adobe media encoder macos download",
|
||||||
|
"mp4 transparency",
|
||||||
|
"webkit",
|
||||||
|
"chromium",
|
||||||
|
"献血",
|
||||||
|
"软件强制更新",
|
||||||
|
"If you don’t agree with its politics views, Notepad+ + will add random characters in your source code.",
|
||||||
|
"Unmerged paths",
|
||||||
|
"字数统计",
|
||||||
|
"Use build.rollupOptions.output.manualChunks to improve chunking: https://rollupjs.org/configuration-options/#output-manualchunks",
|
||||||
|
"世界人权宣言",
|
||||||
|
"latex percent",
|
||||||
|
"chord in keyboard",
|
||||||
|
"Google is trying to kill the Open Web.",
|
||||||
|
"silo'd",
|
||||||
|
"swiftui 数组倒数访问",
|
||||||
|
"swiftui link to another view",
|
||||||
|
"fizzbuzz",
|
||||||
|
"AppDelegate watchos",
|
||||||
|
"Cannot find type 'UIApplicationDelegate' in scope",
|
||||||
|
"swiftui web image",
|
||||||
|
"spammer",
|
||||||
|
"swiftui text",
|
||||||
|
"钢琴",
|
||||||
|
"disable webgl chrome",
|
||||||
|
"online uuid",
|
||||||
|
"cp show progress",
|
||||||
|
"易容术",
|
||||||
|
"fulilian",
|
||||||
|
"cargo",
|
||||||
|
"wordle",
|
||||||
|
"mismatch",
|
||||||
|
"btc",
|
||||||
|
"squelch",
|
||||||
|
"psql show table structure",
|
||||||
|
"let padding don't effect when empty",
|
||||||
|
"take over the world meaning",
|
||||||
|
"brain teasers",
|
||||||
|
"Google flight API",
|
||||||
|
"square symbol",
|
||||||
|
"sill",
|
||||||
|
"nextjs layout per page",
|
||||||
|
"UA 550 umol/L",
|
||||||
|
"react production promotion page",
|
||||||
|
"jupyter notebook",
|
||||||
|
"wth meaning",
|
||||||
|
"glove词向量",
|
||||||
|
"google suggestion relevance",
|
||||||
|
"YouTube advertising income",
|
||||||
|
"PKI",
|
||||||
|
"next client only component",
|
||||||
|
"nextjs use client",
|
||||||
|
"nextjs docker tailwind not working",
|
||||||
|
"k8s",
|
||||||
|
"Logistic Regression",
|
||||||
|
"氯化钾注射死刑",
|
||||||
|
"icloud photo loss",
|
||||||
|
"芙宁娜 水上行走",
|
||||||
|
"vector design tool",
|
||||||
|
"netizen",
|
||||||
|
"framework or next js documentation",
|
||||||
|
"csync",
|
||||||
|
"next js",
|
||||||
|
"后量子正向保密",
|
||||||
|
"nip05",
|
||||||
|
"Sora技术原理",
|
||||||
|
"wasm效率",
|
||||||
|
"switch code",
|
||||||
|
"online IPA pronunciation",
|
||||||
|
"pnpm global adir",
|
||||||
|
"如何搜索",
|
||||||
|
"1999 抽卡期望",
|
||||||
|
"swiftui background blur",
|
||||||
|
"chrome macos fullscreen hide",
|
||||||
|
"中英文空格自动",
|
||||||
|
"ios 旁白 屏幕识别",
|
||||||
|
"ios 旁白 转子",
|
||||||
|
"http 404",
|
||||||
|
"yaml缩进",
|
||||||
|
"counter generator github",
|
||||||
|
"git 服务器提供远端仓库",
|
||||||
|
"ipfs companion",
|
||||||
|
"supervisor config",
|
||||||
|
"SSO",
|
||||||
|
"slot embedding",
|
||||||
|
"sql show tables",
|
||||||
|
"The request signature we calculated does not match the signature you provided. Check your Secret Access Key and signing method.",
|
||||||
|
"icloud.com,cn",
|
||||||
|
"VuePress",
|
||||||
|
"parser",
|
||||||
|
"stackoverflow statistics",
|
||||||
|
"sd xl",
|
||||||
|
"Rollup failed to resolve import \"workbox-precaching\" from",
|
||||||
|
"dep",
|
||||||
|
"Cannot find module estree-walker.js docker",
|
||||||
|
"nuxt run",
|
||||||
|
"base58解码",
|
||||||
|
"cga",
|
||||||
|
"vscode",
|
||||||
|
"vscode",
|
||||||
|
"silicon",
|
||||||
|
"macos m1 linux",
|
||||||
|
"预处理 后处理",
|
||||||
|
"is vp9 opensource",
|
||||||
|
"Alice Blu",
|
||||||
|
"失控玩家",
|
||||||
|
"kv数据库",
|
||||||
|
"redis 持久化",
|
||||||
|
"firefox disable outline",
|
||||||
|
"cd -2",
|
||||||
|
"IM application",
|
||||||
|
"2021国产电影",
|
||||||
|
"youtube chat overlay obs",
|
||||||
|
"obs add clock",
|
||||||
|
"Z is not defined nuxt",
|
||||||
|
"safari ios debug",
|
||||||
|
"safari debug",
|
||||||
|
"chat",
|
||||||
|
"nuxt plugin inject",
|
||||||
|
"twitch",
|
||||||
|
"obs 绿幕",
|
||||||
|
"gnupg",
|
||||||
|
"kde plasma wallpaper engine",
|
||||||
|
"Plasma",
|
||||||
|
"dns over https",
|
||||||
|
"localforage缺点",
|
||||||
|
"watchOS 10",
|
||||||
|
"noun of repeat",
|
||||||
|
"微信输入法",
|
||||||
|
"行业报告",
|
||||||
|
"keepass",
|
||||||
|
"platform",
|
||||||
|
"steam",
|
||||||
|
"java proxy",
|
||||||
|
"0 design",
|
||||||
|
"cefr word level list",
|
||||||
|
"precipitation meaning",
|
||||||
|
"international school of lausanne",
|
||||||
|
"Vim Uganda",
|
||||||
|
"抖音 推荐算法",
|
||||||
|
"Meta NNLO",
|
||||||
|
"windbg dump分析",
|
||||||
|
"web image fft",
|
||||||
|
"GPT-4 Pricing",
|
||||||
|
"GPT-4",
|
||||||
|
"Scala",
|
||||||
|
"tauri教程",
|
||||||
|
"asyncio.create_task用法",
|
||||||
|
"H5 滚动到底部",
|
||||||
|
"microsoft copilot",
|
||||||
|
"枫丹文字",
|
||||||
|
"brew pip",
|
||||||
|
"TS7016: Could not find a declaration file for module react .",
|
||||||
|
"fastapi websocket",
|
||||||
|
"kazv",
|
||||||
|
"The Type 孔雀计划",
|
||||||
|
"第一个图形操作系统",
|
||||||
|
"娱乐 诞生",
|
||||||
|
"ffmpeg 音频封面",
|
||||||
|
"Jean-Loup Gailly",
|
||||||
|
"Linux用户软件位置",
|
||||||
|
"\"ubuntu\" 平滑滚动",
|
||||||
|
"python range函数",
|
||||||
|
"KMP",
|
||||||
|
"sd 8gen2 GPU GFLOPS",
|
||||||
|
"mac语音输入法",
|
||||||
|
"openai translate",
|
||||||
|
"蔚蓝档案 初始抽卡",
|
||||||
|
"free custom domain email",
|
||||||
|
"洛天依",
|
||||||
|
"b站 频道页Tab 跳转",
|
||||||
|
"URL 重定向预览",
|
||||||
|
"计算机",
|
||||||
|
"sololearn",
|
||||||
|
"PoS机制 通俗解释",
|
||||||
|
"google search cost",
|
||||||
|
"bos s3",
|
||||||
|
"react 打包",
|
||||||
|
"useeffect 用法",
|
||||||
|
"ts 字典类型",
|
||||||
|
"vscode 字典单词自动补全插件",
|
||||||
|
"componentwillupdate",
|
||||||
|
"iPad Mini 2",
|
||||||
|
"use-immer",
|
||||||
|
"reducer 和 context",
|
||||||
|
"mint",
|
||||||
|
"Elementary OS",
|
||||||
|
"google科技新闻",
|
||||||
|
"iCloud mail \"\"-9002\"\"",
|
||||||
|
"氢氧化铁胶体制备",
|
||||||
|
"react native 视频处理",
|
||||||
|
"四川 2023 高考 复旦大学 分数线",
|
||||||
|
"哑铃弯举",
|
||||||
|
"m2 ultra",
|
||||||
|
"电池循环计数 site:apple.com",
|
||||||
|
"相机发明时间",
|
||||||
|
"冯诺依曼结构",
|
||||||
|
"哈佛架构",
|
||||||
|
"nodejs 后端",
|
||||||
|
"34.5M€ to CN¥",
|
||||||
|
"NLP 实体关注",
|
||||||
|
"monkey",
|
||||||
|
"react 快捷键监听",
|
||||||
|
"mac 好看的电子书阅读器",
|
||||||
|
"新闻",
|
||||||
|
"在线字体编辑器",
|
||||||
|
"ars technica",
|
||||||
|
"genshin 4.1 release time",
|
||||||
|
"swift device activity report",
|
||||||
|
"swiftui tabview background",
|
||||||
|
"swiftui text space",
|
||||||
|
"apple inc. wikipedia",
|
||||||
|
"how long does it take Google to return the results",
|
||||||
|
"云原神 web",
|
||||||
|
"支持homekit的空调",
|
||||||
|
"内核隔离",
|
||||||
|
"海祇岛解密",
|
||||||
|
"swiftui Textfield",
|
||||||
|
"xcode",
|
||||||
|
"qq 链接",
|
||||||
|
"M1 推出时间",
|
||||||
|
"USB-IF",
|
||||||
|
"nvchat",
|
||||||
|
"P1% FPS",
|
||||||
|
"react i18next 当前语言",
|
||||||
|
"js 获取语言",
|
||||||
|
"MulType",
|
||||||
|
"b站平均使用时间",
|
||||||
|
"pip 阿里源",
|
||||||
|
"ip info",
|
||||||
|
"graphjet",
|
||||||
|
"金融思维",
|
||||||
|
"C#写入文件",
|
||||||
|
"Last Day Sinista M",
|
||||||
|
"在 系统 位置 xcode select 找 不 到 SDK",
|
||||||
|
"Error: Could not find a valid Xcode app bundle at '/Library/Developer/CommandLineTools'. Please update your Apple SDK location in Visual Studio's preferences (Projects > SDK Locations > Apple > Apple SDK). (UniBattery)",
|
||||||
|
".NET能做什么",
|
||||||
|
"could i give no tip ",
|
||||||
|
"miami university of ohio",
|
||||||
|
"方正颜宋",
|
||||||
|
"中文 标题字体",
|
||||||
|
"聚典平台",
|
||||||
|
"62 basic words for a language",
|
||||||
|
"procrastination meaning",
|
||||||
|
"Lingbe",
|
||||||
|
"娱乐至死",
|
||||||
|
"macOS 外接显示器渲染",
|
||||||
|
"白玉袖",
|
||||||
|
"SwiftUI入门",
|
||||||
|
"html插入其它网页",
|
||||||
|
"捆绑 小说",
|
||||||
|
"apple music 无损下载",
|
||||||
|
"一miumiu 赐予",
|
||||||
|
"macos markdown",
|
||||||
|
"safari 开发者工具",
|
||||||
|
"\"百合\" \"武侠\" \"国漫\"",
|
||||||
|
"epub 格式详解",
|
||||||
|
"chrome 隐藏滚动条",
|
||||||
|
"发宽空格",
|
||||||
|
"U+200A",
|
||||||
|
"无性人",
|
||||||
|
"Spotify",
|
||||||
|
"禾念",
|
||||||
|
"how to pronounce Lorem ipsum",
|
||||||
|
"言和为什么不是男孩子",
|
||||||
|
"浏览器主页",
|
||||||
|
"react",
|
||||||
|
"Tailwindcss react 扩展怎么用",
|
||||||
|
"Prettier 扩展怎么用",
|
||||||
|
"linter\""
|
||||||
|
]
|
||||||
|
@ -1,575 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "a6a3195f-d099-4bf4-846f-51f403954818",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"# sparkastML: Training the Intention Classification Model\n",
|
|
||||||
"\n",
|
|
||||||
"This is the model we use for intent recognition, using a **CNN architectur** and using an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
|
|
||||||
"\n",
|
|
||||||
"In this case, **positive samples** refer to data that can be classified into existing class, while **negative samples** are those does not belong to any of the existing class."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 1,
|
|
||||||
"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import json\n",
|
|
||||||
"import torch\n",
|
|
||||||
"from torch.utils.data import Dataset, DataLoader\n",
|
|
||||||
"from torch.nn.utils.rnn import pad_sequence\n",
|
|
||||||
"from sklearn.model_selection import train_test_split\n",
|
|
||||||
"from transformers import AutoTokenizer, AutoModel\n",
|
|
||||||
"import torch\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"from scipy.spatial.distance import euclidean\n",
|
|
||||||
"from scipy.stats import weibull_min\n",
|
|
||||||
"from sklearn.preprocessing import normalize\n",
|
|
||||||
"import torch.nn.functional as F\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 2,
|
|
||||||
"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n",
|
|
||||||
"DIMENSIONS = 128\n",
|
|
||||||
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Load Data\n",
|
|
||||||
"\n",
|
|
||||||
"We load the data from `data.json`, and also get the negative sample from the `noise.json`."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 3,
|
|
||||||
"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
|
|
||||||
" embeddings = torch.tensor(embeddings)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"# Load data\n",
|
|
||||||
"with open('data.json', 'r') as f:\n",
|
|
||||||
" data = json.load(f)\n",
|
|
||||||
"\n",
|
|
||||||
"# Create map: class to index\n",
|
|
||||||
"class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n",
|
|
||||||
"idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n",
|
|
||||||
"\n",
|
|
||||||
"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
|
|
||||||
"def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n",
|
|
||||||
" dataset = []\n",
|
|
||||||
" for label, sentences in data.items():\n",
|
|
||||||
" for sentence in sentences:\n",
|
|
||||||
" # Tokenize the sentence and convert tokens to embedding vectors\n",
|
|
||||||
" tokens = tokenizer.tokenize(sentence)\n",
|
|
||||||
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
|
||||||
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
|
|
||||||
" embeddings = torch.tensor(embeddings)\n",
|
|
||||||
" dataset.append((class_to_idx[label], embeddings))\n",
|
|
||||||
" return dataset\n",
|
|
||||||
"\n",
|
|
||||||
"# Load embedding map\n",
|
|
||||||
"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
|
|
||||||
"\n",
|
|
||||||
"# Get preprocessed dataset\n",
|
|
||||||
"dataset = preprocess_data(data, embedding_map, tokenizer)\n",
|
|
||||||
"\n",
|
|
||||||
"# Train-test split\n",
|
|
||||||
"train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n",
|
|
||||||
"\n",
|
|
||||||
"class TextDataset(Dataset):\n",
|
|
||||||
" def __init__(self, data):\n",
|
|
||||||
" self.data = data\n",
|
|
||||||
"\n",
|
|
||||||
" def __len__(self):\n",
|
|
||||||
" return len(self.data)\n",
|
|
||||||
"\n",
|
|
||||||
" def __getitem__(self, idx):\n",
|
|
||||||
" return self.data[idx]\n",
|
|
||||||
"\n",
|
|
||||||
" def collate_fn(self, batch):\n",
|
|
||||||
" labels, embeddings = zip(*batch)\n",
|
|
||||||
" labels = torch.tensor(labels)\n",
|
|
||||||
" embeddings = pad_sequence(embeddings, batch_first=True)\n",
|
|
||||||
" return labels, embeddings\n",
|
|
||||||
"\n",
|
|
||||||
"train_dataset = TextDataset(train_data)\n",
|
|
||||||
"val_dataset = TextDataset(val_data)\n",
|
|
||||||
"\n",
|
|
||||||
"train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n",
|
|
||||||
"val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 4,
|
|
||||||
"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import torch\n",
|
|
||||||
"from torch.utils.data import Dataset, DataLoader\n",
|
|
||||||
"\n",
|
|
||||||
"class NegativeSampleDataset(Dataset):\n",
|
|
||||||
" def __init__(self, negative_samples):\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" negative_samples: List or array of negative sample embeddings or raw text\n",
|
|
||||||
" \"\"\"\n",
|
|
||||||
" self.samples = negative_samples\n",
|
|
||||||
" \n",
|
|
||||||
" def __len__(self):\n",
|
|
||||||
" return len(self.samples)\n",
|
|
||||||
" \n",
|
|
||||||
" def __getitem__(self, idx):\n",
|
|
||||||
" return self.samples[idx]\n",
|
|
||||||
"\n",
|
|
||||||
" def collate_fn(self, batch):\n",
|
|
||||||
" embeddings = pad_sequence(batch, batch_first=True)\n",
|
|
||||||
" return embeddings\n",
|
|
||||||
"\n",
|
|
||||||
"with open('noise.json', 'r') as f:\n",
|
|
||||||
" negative_samples_list = json.load(f)\n",
|
|
||||||
"\n",
|
|
||||||
"negative_embedding_list = []\n",
|
|
||||||
"\n",
|
|
||||||
"for sentence in negative_samples_list:\n",
|
|
||||||
" tokens = tokenizer.tokenize(sentence)\n",
|
|
||||||
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
|
||||||
" embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n",
|
|
||||||
" embeddings = torch.tensor(embeddings)\n",
|
|
||||||
" negative_embedding_list.append(embeddings)\n",
|
|
||||||
"\n",
|
|
||||||
"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
|
|
||||||
"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Implementating the Model"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 5,
|
|
||||||
"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import torch.nn as nn\n",
|
|
||||||
"import torch.nn.functional as F\n",
|
|
||||||
"\n",
|
|
||||||
"class TextCNN(nn.Module):\n",
|
|
||||||
" def __init__(self, input_dim, num_classes):\n",
|
|
||||||
" super(TextCNN, self).__init__()\n",
|
|
||||||
" self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
|
|
||||||
" self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
|
|
||||||
" self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
|
|
||||||
" \n",
|
|
||||||
" self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
|
|
||||||
" self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
|
|
||||||
" self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
|
|
||||||
" \n",
|
|
||||||
" self.dropout = nn.Dropout(0.5)\n",
|
|
||||||
" self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
|
|
||||||
"\n",
|
|
||||||
" def forward(self, x):\n",
|
|
||||||
" x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
|
|
||||||
" \n",
|
|
||||||
" x1 = F.relu(self.bn1(self.conv1(x)))\n",
|
|
||||||
" x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
|
|
||||||
" \n",
|
|
||||||
" x2 = F.relu(self.bn2(self.conv2(x)))\n",
|
|
||||||
" x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
|
|
||||||
" \n",
|
|
||||||
" x3 = F.relu(self.bn3(self.conv3(x)))\n",
|
|
||||||
" x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
|
|
||||||
" \n",
|
|
||||||
" x = torch.cat((x1, x2, x3), dim=1)\n",
|
|
||||||
" x = self.dropout(x)\n",
|
|
||||||
" x = self.fc(x)\n",
|
|
||||||
" return x\n",
|
|
||||||
"\n",
|
|
||||||
"# Initialize model\n",
|
|
||||||
"input_dim = DIMENSIONS\n",
|
|
||||||
"num_classes = len(class_to_idx)\n",
|
|
||||||
"model = TextCNN(input_dim, num_classes)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Energy-based Models"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 6,
|
|
||||||
"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def energy_score(logits):\n",
|
|
||||||
" # Energy score is minus logsumexp\n",
|
|
||||||
" return -torch.logsumexp(logits, dim=1)\n",
|
|
||||||
"\n",
|
|
||||||
"def generate_noise(batch_size, seq_length ,input_dim, device):\n",
|
|
||||||
" # Generate a Gaussian noise\n",
|
|
||||||
" return torch.randn(batch_size, seq_length, input_dim).to(device)\n"
|
|
||||||
]
|
|
||||||
},
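Restated as a formula, the energy assigned to an input $x$ with logits $f_k(x)$ over the $K$ known classes is

$$E(x) = -\log \sum_{k=1}^{K} \exp\big(f_k(x)\big),$$

and an input is treated as unknown whenever $E(x)$ exceeds the `ENERGY_THRESHOLD` chosen in the evaluation section below.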
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Training"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 7,
|
|
||||||
"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
|
|
||||||
"metadata": {
|
|
||||||
"scrolled": true
|
|
||||||
},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Epoch [1/50], Loss: 12.5108\n",
|
|
||||||
"Epoch [2/50], Loss: 10.7305\n",
|
|
||||||
"Epoch [3/50], Loss: 10.2943\n",
|
|
||||||
"Epoch [4/50], Loss: 9.9350\n",
|
|
||||||
"Epoch [5/50], Loss: 9.7991\n",
|
|
||||||
"Epoch [6/50], Loss: 9.6443\n",
|
|
||||||
"Epoch [7/50], Loss: 9.4762\n",
|
|
||||||
"Epoch [8/50], Loss: 9.4637\n",
|
|
||||||
"Epoch [9/50], Loss: 9.3025\n",
|
|
||||||
"Epoch [10/50], Loss: 9.1719\n",
|
|
||||||
"Epoch [11/50], Loss: 9.0632\n",
|
|
||||||
"Epoch [12/50], Loss: 8.9741\n",
|
|
||||||
"Epoch [13/50], Loss: 8.8487\n",
|
|
||||||
"Epoch [14/50], Loss: 8.6565\n",
|
|
||||||
"Epoch [15/50], Loss: 8.5830\n",
|
|
||||||
"Epoch [16/50], Loss: 8.4196\n",
|
|
||||||
"Epoch [17/50], Loss: 8.2319\n",
|
|
||||||
"Epoch [18/50], Loss: 8.0655\n",
|
|
||||||
"Epoch [19/50], Loss: 7.7140\n",
|
|
||||||
"Epoch [20/50], Loss: 7.6921\n",
|
|
||||||
"Epoch [21/50], Loss: 7.3375\n",
|
|
||||||
"Epoch [22/50], Loss: 7.2297\n",
|
|
||||||
"Epoch [23/50], Loss: 6.8833\n",
|
|
||||||
"Epoch [24/50], Loss: 6.8534\n",
|
|
||||||
"Epoch [25/50], Loss: 6.4557\n",
|
|
||||||
"Epoch [26/50], Loss: 6.1365\n",
|
|
||||||
"Epoch [27/50], Loss: 5.8558\n",
|
|
||||||
"Epoch [28/50], Loss: 5.5030\n",
|
|
||||||
"Epoch [29/50], Loss: 5.1604\n",
|
|
||||||
"Epoch [30/50], Loss: 4.7742\n",
|
|
||||||
"Epoch [31/50], Loss: 4.5958\n",
|
|
||||||
"Epoch [32/50], Loss: 4.0713\n",
|
|
||||||
"Epoch [33/50], Loss: 3.8872\n",
|
|
||||||
"Epoch [34/50], Loss: 3.5240\n",
|
|
||||||
"Epoch [35/50], Loss: 3.3115\n",
|
|
||||||
"Epoch [36/50], Loss: 2.5667\n",
|
|
||||||
"Epoch [37/50], Loss: 2.6709\n",
|
|
||||||
"Epoch [38/50], Loss: 1.8075\n",
|
|
||||||
"Epoch [39/50], Loss: 1.6654\n",
|
|
||||||
"Epoch [40/50], Loss: 0.4622\n",
|
|
||||||
"Epoch [41/50], Loss: 0.4719\n",
|
|
||||||
"Epoch [42/50], Loss: -0.4037\n",
|
|
||||||
"Epoch [43/50], Loss: -0.9405\n",
|
|
||||||
"Epoch [44/50], Loss: -1.7204\n",
|
|
||||||
"Epoch [45/50], Loss: -2.4124\n",
|
|
||||||
"Epoch [46/50], Loss: -3.0032\n",
|
|
||||||
"Epoch [47/50], Loss: -2.7123\n",
|
|
||||||
"Epoch [48/50], Loss: -3.6953\n",
|
|
||||||
"Epoch [49/50], Loss: -3.7212\n",
|
|
||||||
"Epoch [50/50], Loss: -3.7558\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"import torch.optim as optim\n",
|
|
||||||
"\n",
|
|
||||||
"criterion = nn.CrossEntropyLoss()\n",
|
|
||||||
"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
|
|
||||||
"\n",
|
|
||||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
|
||||||
"import tensorboard\n",
|
|
||||||
"writer = SummaryWriter()\n",
|
|
||||||
"\n",
|
|
||||||
"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n",
|
|
||||||
" model.train()\n",
|
|
||||||
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
||||||
" model.to(device)\n",
|
|
||||||
" \n",
|
|
||||||
" negative_iter = iter(negative_loader)\n",
|
|
||||||
" \n",
|
|
||||||
" for epoch in range(num_epochs):\n",
|
|
||||||
" total_loss = 0\n",
|
|
||||||
" for batch_idx, (labels, embeddings) in enumerate(train_loader):\n",
|
|
||||||
" embeddings = embeddings.to(device)\n",
|
|
||||||
" labels = labels.to(device)\n",
|
|
||||||
" \n",
|
|
||||||
" batch_size = embeddings.size(0)\n",
|
|
||||||
" \n",
|
|
||||||
" # ---------------------\n",
|
|
||||||
" # 1. Positive sample\n",
|
|
||||||
" # ---------------------\n",
|
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
" outputs = model(embeddings) # logits from the model\n",
|
|
||||||
" \n",
|
|
||||||
" class_loss = criterion(outputs, labels)\n",
|
|
||||||
" \n",
|
|
||||||
" # Energy of positive sample\n",
|
|
||||||
" known_energy = energy_score(outputs)\n",
|
|
||||||
" energy_loss_known = known_energy.mean()\n",
|
|
||||||
" \n",
|
|
||||||
" # ------------------------------------\n",
|
|
||||||
" # 2. Negative sample - Random Noise\n",
|
|
||||||
" # ------------------------------------\n",
|
|
||||||
" noise_embeddings = torch.randn_like(embeddings).to(device)\n",
|
|
||||||
" noise_outputs = model(noise_embeddings)\n",
|
|
||||||
" noise_energy = energy_score(noise_outputs)\n",
|
|
||||||
" energy_loss_noise = F.relu(1 - noise_energy).mean() # For the energy of noise, bigger is better \n",
|
|
||||||
" \n",
|
|
||||||
" # ------------------------------------\n",
|
|
||||||
" # 3. Negative sample - custom corpus\n",
|
|
||||||
" # ------------------------------------\n",
|
|
||||||
" \n",
|
|
||||||
" try:\n",
|
|
||||||
" negative_samples = next(negative_iter)\n",
|
|
||||||
" except StopIteration:\n",
|
|
||||||
" negative_iter = iter(negative_loader)\n",
|
|
||||||
" negative_samples = next(negative_iter)\n",
|
|
||||||
" negative_samples = negative_samples.to(device)\n",
|
|
||||||
" negative_outputs = model(negative_samples)\n",
|
|
||||||
" negative_energy = energy_score(negative_outputs)\n",
|
|
||||||
" energy_loss_negative = F.relu(1 - negative_energy).mean() # For the energy of noise, bigger is better \n",
|
|
||||||
" \n",
|
|
||||||
" # -----------------------------\n",
|
|
||||||
" # 4. Overall Loss calculation\n",
|
|
||||||
" # -----------------------------\n",
|
|
||||||
" total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
|
|
||||||
" total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n",
|
|
||||||
"\n",
|
|
||||||
" writer.add_scalar(\"Engergy Loss\", total_energy_loss, epoch)\n",
|
|
||||||
" writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n",
|
|
||||||
" writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10 , epoch)\n",
|
|
||||||
" \n",
|
|
||||||
" total_loss_batch.backward()\n",
|
|
||||||
" optimizer.step()\n",
|
|
||||||
" \n",
|
|
||||||
" total_loss += total_loss_batch.item()\n",
|
|
||||||
" \n",
|
|
||||||
" avg_loss = total_loss / len(train_loader)\n",
|
|
||||||
" print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
|
|
||||||
"\n",
|
|
||||||
"train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n",
|
|
||||||
"writer.flush()\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "e6d29558-f497-4033-8488-169bd25ce881",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Evalutation"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 8,
|
|
||||||
"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"ENERGY_THRESHOLD = -3"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 9,
|
|
||||||
"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Accuracy: 0.9315\n",
|
|
||||||
"Precision: 1.0000\n",
|
|
||||||
"Recall: 0.9254\n",
|
|
||||||
"F1 Score: 0.9612\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
|
|
||||||
"\n",
|
|
||||||
"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
|
|
||||||
" model.eval()\n",
|
|
||||||
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
|
||||||
" \n",
|
|
||||||
" all_preds = []\n",
|
|
||||||
" all_labels = []\n",
|
|
||||||
" \n",
|
|
||||||
" # Evaluate positive sample\n",
|
|
||||||
" with torch.no_grad():\n",
|
|
||||||
" for labels, embeddings in known_loader:\n",
|
|
||||||
" embeddings = embeddings.to(device)\n",
|
|
||||||
" logits = model(embeddings)\n",
|
|
||||||
" energy = energy_score(logits)\n",
|
|
||||||
" \n",
|
|
||||||
" preds = (energy <= energy_threshold).long()\n",
|
|
||||||
" all_preds.extend(preds.cpu().numpy())\n",
|
|
||||||
" all_labels.extend([1] * len(preds)) # Positive sample labeled as 1\n",
|
|
||||||
" \n",
|
|
||||||
" # Evaluate negative sample\n",
|
|
||||||
" with torch.no_grad():\n",
|
|
||||||
" for embeddings in unknown_loader:\n",
|
|
||||||
" embeddings = embeddings.to(device)\n",
|
|
||||||
" logits = model(embeddings)\n",
|
|
||||||
" energy = energy_score(logits)\n",
|
|
||||||
" \n",
|
|
||||||
" preds = (energy <= energy_threshold).long()\n",
|
|
||||||
" all_preds.extend(preds.cpu().numpy())\n",
|
|
||||||
" all_labels.extend([0] * len(preds)) # Negative sample labeled as 1\n",
|
|
||||||
" \n",
|
|
||||||
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
|
|
||||||
" accuracy = accuracy_score(all_labels, all_preds)\n",
|
|
||||||
"\n",
|
|
||||||
" print(f'Accuracy: {accuracy:.4f}')\n",
|
|
||||||
" print(f'Precision: {precision:.4f}')\n",
|
|
||||||
" print(f'Recall: {recall:.4f}')\n",
|
|
||||||
" print(f'F1 Score: {f1:.4f}')\n",
|
|
||||||
"\n",
|
|
||||||
"evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 20,
|
|
||||||
"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
|
|
||||||
"metadata": {},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Save the model\n",
|
|
||||||
"torch.save(model, \"model.pt\")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
|
|
||||||
"metadata": {},
|
|
||||||
"source": [
|
|
||||||
"## Inference"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 19,
|
|
||||||
"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
|
|
||||||
"metadata": {
|
|
||||||
"scrolled": true
|
|
||||||
},
|
|
||||||
"outputs": [
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
|
||||||
"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
|
|
||||||
" model.eval()\n",
|
|
||||||
" tokens = tokenizer.tokenize(sentence)\n",
|
|
||||||
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
|
||||||
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
|
|
||||||
" embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension\n",
|
|
||||||
" \n",
|
|
||||||
" with torch.no_grad():\n",
|
|
||||||
" logits = model(embeddings)\n",
|
|
||||||
" probabilities = F.softmax(logits, dim=1)\n",
|
|
||||||
" max_prob, predicted = torch.max(probabilities, 1)\n",
|
|
||||||
" \n",
|
|
||||||
" # Calculate energy score\n",
|
|
||||||
" energy = energy_score(logits)\n",
|
|
||||||
"\n",
|
|
||||||
" # If energy > threshold, consider the input as unknown class\n",
|
|
||||||
" if energy.item() > energy_threshold:\n",
|
|
||||||
" return [\"Unknown\", max_prob.item(), energy.item()]\n",
|
|
||||||
" else:\n",
|
|
||||||
" return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
|
|
||||||
"\n",
|
|
||||||
"# Example usage:\n",
|
|
||||||
"sentence = \"weather today\"\n",
|
|
||||||
"energy_threshold = ENERGY_THRESHOLD\n",
|
|
||||||
"predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n",
|
|
||||||
"print(f'Predicted: {predicted}')\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "Python 3 (ipykernel)",
|
|
||||||
"language": "python",
|
|
||||||
"name": "python3"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3",
|
|
||||||
"version": "3.9.19"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 5
|
|
||||||
}
|
|
4
intention-classify/training/config.py
Normal file
@ -0,0 +1,4 @@
# config.py

model_name = "Qwen/Qwen2.5-3B"
DIMENSIONS = 96
71
intention-classify/training/data_utils.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# data_utils.py
|
||||||
|
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(file_path):
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def create_class_mappings(data):
|
||||||
|
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
|
||||||
|
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
||||||
|
return class_to_idx, idx_to_class
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64):
|
||||||
|
dataset = []
|
||||||
|
for label, sentences in data.items():
|
||||||
|
for sentence in sentences:
|
||||||
|
tokens = tokenizer.tokenize(sentence)
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
embeddings = [
|
||||||
|
embedding_map[token_id] for token_id in token_ids[:max_length]
|
||||||
|
]
|
||||||
|
embeddings = torch.tensor(embeddings)
|
||||||
|
dataset.append((class_to_idx[label], embeddings))
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def get_sentences(data):
|
||||||
|
result = []
|
||||||
|
for _, sentences in data.items():
|
||||||
|
for sentence in sentences:
|
||||||
|
result.append(sentence)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class TextDataset(Dataset):
|
||||||
|
def __init__(self, data):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.data[idx]
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
labels, embeddings = zip(*batch)
|
||||||
|
labels = torch.tensor(labels)
|
||||||
|
embeddings = pad_sequence(embeddings, batch_first=True)
|
||||||
|
return labels, embeddings
|
||||||
|
|
||||||
|
|
||||||
|
class NegativeSampleDataset(Dataset):
|
||||||
|
def __init__(self, negative_samples):
|
||||||
|
self.samples = negative_samples
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.samples[idx]
|
||||||
|
|
||||||
|
def collate_fn(self, batch):
|
||||||
|
embeddings = pad_sequence(batch, batch_first=True)
|
||||||
|
return embeddings
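# Both collate_fns above rely on pad_sequence(batch_first=True), which zero-pads
# every sequence in a batch up to the length of the longest sequence in that batch.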
|
53
intention-classify/training/model.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# model.py
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from training.config import DIMENSIONS
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(self, input_dim, heads):
|
||||||
|
super(SelfAttention, self).__init__()
|
||||||
|
self.heads = heads
|
||||||
|
self.scale = (input_dim // heads) ** -0.5
|
||||||
|
self.qkv = nn.Linear(input_dim, input_dim * 3)
|
||||||
|
self.fc = nn.Linear(input_dim, input_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch_size, seq_length, embedding_dim = x.shape
|
||||||
|
qkv = self.qkv(x).view(
|
||||||
|
batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
|
||||||
|
)
|
||||||
|
q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
|
||||||
|
q = q.permute(0, 2, 1, 3)
|
||||||
|
k = k.permute(0, 2, 1, 3)
|
||||||
|
v = v.permute(0, 2, 1, 3)
|
||||||
|
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||||
|
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||||
|
attention_output = torch.matmul(attn_weights, v)
|
||||||
|
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
|
||||||
|
attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
|
||||||
|
return self.fc(attention_output)
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBasedModel(nn.Module):
|
||||||
|
def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512, num_layers=3):
|
||||||
|
super(AttentionBasedModel, self).__init__()
|
||||||
|
self.self_attention_layers = nn.ModuleList([
|
||||||
|
SelfAttention(input_dim, heads) for _ in range(num_layers)
|
||||||
|
])
|
||||||
|
self.fc1 = nn.Linear(input_dim, dim_feedforward)
|
||||||
|
self.fc2 = nn.Linear(dim_feedforward, num_classes)
|
||||||
|
self.dropout = nn.Dropout(0.5)
|
||||||
|
self.norm = nn.LayerNorm(input_dim)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for attn_layer in self.self_attention_layers:
|
||||||
|
attn_output = attn_layer(x)
|
||||||
|
x = self.norm(attn_output + x)
|
||||||
|
pooled_output = torch.mean(x, dim=1)
|
||||||
|
x = F.relu(self.fc1(pooled_output))
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
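# Architecture summary: the stacked SelfAttention layers are applied with residual
# connections and LayerNorm (see forward above), the sequence is mean-pooled over
# time, and a two-layer feed-forward head with dropout produces the class logits.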
|
155
intention-classify/training/train.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
# train.py
|
||||||
|
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import torch.optim as optim
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from training.data_utils import (
|
||||||
|
load_data,
|
||||||
|
create_class_mappings,
|
||||||
|
preprocess_data,
|
||||||
|
TextDataset,
|
||||||
|
NegativeSampleDataset,
|
||||||
|
)
|
||||||
|
from training.model import AttentionBasedModel
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def energy_score(logits, temperature=1.0):
|
||||||
|
return -torch.logsumexp(logits / temperature, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_noise(batch_size, seq_length, input_dim, device):
|
||||||
|
return torch.randn(batch_size, seq_length, input_dim).to(device)
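# Loss design used in train_energy_model below: cross-entropy on the known classes,
# plus an energy term that pulls known-class energy towards <= 0
# (F.relu(known_energy - positive_margin) with positive_margin = 0) and pushes both
# Gaussian noise and the corpus negatives from negative_loader above `margin`
# (F.relu(margin - energy)); energy itself is -logsumexp(logits / temperature).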
|
||||||
|
|
||||||
|
|
||||||
|
def train_energy_model(
|
||||||
|
model,
|
||||||
|
train_loader,
|
||||||
|
negative_loader,
|
||||||
|
criterion,
|
||||||
|
optimizer,
|
||||||
|
num_epochs=10,
|
||||||
|
margin=1.0,
|
||||||
|
temperature=0.4,
|
||||||
|
):
|
||||||
|
model.train()
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
model.to(device)
|
||||||
|
negative_iter = iter(negative_loader)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
total_loss = 0
|
||||||
|
for _, (labels, embeddings) in enumerate(train_loader):
|
||||||
|
embeddings = embeddings.to(device)
|
||||||
|
labels = labels.to(device)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
outputs = model(embeddings)
|
||||||
|
class_loss = criterion(outputs, labels)
|
||||||
|
known_energy = energy_score(outputs, temperature)
|
||||||
|
positive_margin = 0.0
|
||||||
|
energy_loss_known = F.relu(known_energy - positive_margin).mean()
|
||||||
|
|
||||||
|
noise_embeddings = torch.randn_like(embeddings).to(device)
|
||||||
|
noise_outputs = model(noise_embeddings)
|
||||||
|
noise_energy = energy_score(noise_outputs, temperature)
|
||||||
|
energy_loss_noise = F.relu(margin - noise_energy).mean()
|
||||||
|
|
||||||
|
try:
|
||||||
|
negative_samples = next(negative_iter)
|
||||||
|
except StopIteration:
|
||||||
|
negative_iter = iter(negative_loader)
|
||||||
|
negative_samples = next(negative_iter)
|
||||||
|
negative_samples = negative_samples.to(device)
|
||||||
|
negative_outputs = model(negative_samples)
|
||||||
|
negative_energy = energy_score(negative_outputs, temperature)
|
||||||
|
energy_loss_negative = F.relu(margin - negative_energy).mean()
|
||||||
|
|
||||||
|
total_energy_loss = (
|
||||||
|
energy_loss_known + energy_loss_noise + energy_loss_negative
|
||||||
|
)
|
||||||
|
total_loss_batch = class_loss + total_energy_loss
|
||||||
|
|
||||||
|
total_loss_batch.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
total_loss += total_loss_batch.item()
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(train_loader)
|
||||||
|
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
from training.config import model_name, DIMENSIONS
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
data = load_data("data.json")
|
||||||
|
class_to_idx, idx_to_class = create_class_mappings(data)
|
||||||
|
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
|
||||||
|
dataset = preprocess_data(data, embedding_map, tokenizer, class_to_idx)
|
||||||
|
train_data, _ = train_test_split(dataset, test_size=0.2)
|
||||||
|
|
||||||
|
train_dataset = TextDataset(train_data)
|
||||||
|
|
||||||
|
train_loader = DataLoader(
|
||||||
|
train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn
|
||||||
|
)
|
||||||
|
|
||||||
|
with open("noise.json", "r") as f:
|
||||||
|
negative_samples_list = json.load(f)
|
||||||
|
|
||||||
|
negative_embedding_list = []
|
||||||
|
for sentence in negative_samples_list:
|
||||||
|
tokens = tokenizer.tokenize(sentence)
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]
|
||||||
|
embeddings = torch.tensor(embeddings)
|
||||||
|
negative_embedding_list.append(embeddings)
|
||||||
|
|
||||||
|
negative_dataset = NegativeSampleDataset(negative_embedding_list)
|
||||||
|
negative_loader = DataLoader(
|
||||||
|
negative_dataset,
|
||||||
|
batch_size=24,
|
||||||
|
shuffle=True,
|
||||||
|
collate_fn=negative_dataset.collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_dim = DIMENSIONS
|
||||||
|
num_classes = len(class_to_idx)
|
||||||
|
model = AttentionBasedModel(input_dim, num_classes)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=7e-4)
|
||||||
|
|
||||||
|
train_energy_model(
|
||||||
|
model, train_loader, negative_loader, criterion, optimizer, num_epochs=120
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.save(model.state_dict(), "model.pt")
|
||||||
|
|
||||||
|
dummy_input = torch.randn(1, 64, DIMENSIONS)
|
||||||
|
torch.onnx.export(
|
||||||
|
model,
|
||||||
|
dummy_input,
|
||||||
|
"model.onnx",
|
||||||
|
input_names=["input"],
|
||||||
|
output_names=["output"],
|
||||||
|
dynamic_axes={
|
||||||
|
"input": {0: "batch_size", 1: "seq_length"},
|
||||||
|
"output": {0: "batch_size"},
|
||||||
|
},
|
||||||
|
opset_version=11,
|
||||||
|
)
|
||||||
|
meta = {
|
||||||
|
"idx_to_class": idx_to_class,
|
||||||
|
"threshold": 0
|
||||||
|
}
|
||||||
|
with open('NLU_meta.json', 'w') as f:
|
||||||
|
json.dump(meta, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
75
intention-classify/validation/inference.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from training.model import AttentionBasedModel
|
||||||
|
from training.config import model_name
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from training.config import DIMENSIONS
|
||||||
|
from training.model import AttentionBasedModel
|
||||||
|
|
||||||
|
|
||||||
|
def energy_score(logits):
|
||||||
|
# Energy score is minus logsumexp
|
||||||
|
return -torch.logsumexp(logits, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def predict_with_energy(
|
||||||
|
model,
|
||||||
|
sentence,
|
||||||
|
embedding_map,
|
||||||
|
tokenizer,
|
||||||
|
idx_to_class,
|
||||||
|
energy_threshold,
|
||||||
|
max_length=64,
|
||||||
|
):
|
||||||
|
model.eval()
|
||||||
|
tokens = tokenizer.tokenize(sentence)
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
print(token_ids)
|
||||||
|
embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
|
||||||
|
embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension
|
||||||
|
current_shape = embeddings.shape
|
||||||
|
|
||||||
|
if current_shape[1] < 2:
|
||||||
|
pad_size = 2 - current_shape[1]
|
||||||
|
embeddings = F.pad(
|
||||||
|
embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(embeddings)
|
||||||
|
print(logits)
|
||||||
|
probabilities = F.softmax(logits, dim=1)
|
||||||
|
max_prob, predicted = torch.max(probabilities, 1)
|
||||||
|
|
||||||
|
# Calculate energy score
|
||||||
|
energy = energy_score(logits)
|
||||||
|
|
||||||
|
# If energy > threshold, consider the input as unknown class
|
||||||
|
if energy.item() > energy_threshold:
|
||||||
|
return ["Unknown", max_prob.item(), energy.item()]
|
||||||
|
else:
|
||||||
|
return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]
|
||||||
|
|
||||||
|
|
||||||
|
with open("data.json", "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
|
||||||
|
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
||||||
|
num_classes = len(class_to_idx)
|
||||||
|
|
||||||
|
input_dim = DIMENSIONS
|
||||||
|
model = AttentionBasedModel(input_dim, num_classes)
|
||||||
|
model.load_state_dict(torch.load("./model.pt"))
|
||||||
|
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
# Example usage:
|
||||||
|
ENERGY_THRESHOLD = 2
|
||||||
|
sentence = "what on earth is the cross entropy loss"
|
||||||
|
energy_threshold = ENERGY_THRESHOLD
|
||||||
|
predicted = predict_with_energy(
|
||||||
|
model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
|
||||||
|
)
|
||||||
|
print(f"Predicted: {predicted}")
|
80
intention-classify/validation/openset_validation.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from training.model import AttentionBasedModel
|
||||||
|
from training.config import model_name
|
||||||
|
from training.config import DIMENSIONS
|
||||||
|
from training.data_utils import get_sentences
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from tqdm import tqdm
|
||||||
|
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support
|
||||||
|
|
||||||
|
def energy_score(logits):
|
||||||
|
# Energy score is minus logsumexp
|
||||||
|
return -torch.logsumexp(logits, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_energy(
|
||||||
|
model,
|
||||||
|
sentence,
|
||||||
|
embedding_map,
|
||||||
|
tokenizer,
|
||||||
|
max_length=64,
|
||||||
|
):
|
||||||
|
model.eval()
|
||||||
|
tokens = tokenizer.tokenize(sentence)
|
||||||
|
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
|
||||||
|
embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension
|
||||||
|
current_shape = embeddings.shape
|
||||||
|
|
||||||
|
if current_shape[1] < 2:
|
||||||
|
pad_size = 2 - current_shape[1]
|
||||||
|
embeddings = F.pad(
|
||||||
|
embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model(embeddings)
|
||||||
|
# Calculate energy score
|
||||||
|
energy = energy_score(logits)
|
||||||
|
|
||||||
|
return energy
|
||||||
|
|
||||||
|
|
||||||
|
with open("data.json", "r") as f:
|
||||||
|
positive_data = json.load(f)
|
||||||
|
class_to_idx = {cls: idx for idx, cls in enumerate(positive_data.keys())}
|
||||||
|
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
||||||
|
num_classes = len(class_to_idx)
|
||||||
|
|
||||||
|
with open("noise.json", "r") as f:
|
||||||
|
negative_data = json.load(f)
|
||||||
|
|
||||||
|
input_dim = DIMENSIONS
|
||||||
|
model = AttentionBasedModel(input_dim, num_classes)
|
||||||
|
model.load_state_dict(torch.load("./model.pt"))
|
||||||
|
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||||
|
|
||||||
|
|
||||||
|
all_preds = []
|
||||||
|
all_labels = []
|
||||||
|
ENERGY_THRESHOLD = 2
|
||||||
|
for item in tqdm(get_sentences(positive_data)):
|
||||||
|
result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
|
||||||
|
all_preds.append(result)
|
||||||
|
all_labels.append(1)
|
||||||
|
|
||||||
|
for item in tqdm(negative_data):
|
||||||
|
result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
|
||||||
|
all_preds.append(result)
|
||||||
|
all_labels.append(0)
|
||||||
|
|
||||||
|
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
|
||||||
|
accuracy = accuracy_score(all_labels, all_preds)
|
||||||
|
|
||||||
|
print(f'Accuracy: {accuracy:.4f}')
|
||||||
|
print(f'Precision: {precision:.4f}')
|
||||||
|
print(f'Recall: {recall:.4f}')
|
||||||
|
print(f'F1 Score: {f1:.4f}')
|
3629
text-difficulty/grammar/EGP.csv
Normal file
File diff suppressed because it is too large
11829
text-difficulty/grammar/EGP_Derivied.csv
Normal file
File diff suppressed because it is too large
103
text-difficulty/grammar/article.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import nltk
|
||||||
|
from nltk.tokenize import sent_tokenize, word_tokenize
|
||||||
|
from training.model import AttentionBasedModel
|
||||||
|
|
||||||
|
# Ensure required NLTK resources are available
|
||||||
|
nltk.download('punkt')
|
||||||
|
nltk.download('averaged_perceptron_tagger')
|
||||||
|
|
||||||
|
# Load pre-saved mappings
|
||||||
|
pos2idx = torch.load('pos2idx.pt')
|
||||||
|
class_mapping = torch.load('class_mapping.pt')
|
||||||
|
|
||||||
|
|
||||||
|
# Load the pre-trained model and state
|
||||||
|
model = AttentionBasedModel(40, 32, 6, 8, 6, 256)
|
||||||
|
model.load_state_dict(torch.load("./model.pt", weights_only=True))
|
||||||
|
|
||||||
|
# Define helper functions
|
||||||
|
def pad_sequence(seq, max_len):
|
||||||
|
return seq + [-1] * (max_len - len(seq))
|
||||||
|
|
||||||
|
def encode_pos_tags(tagged_sentence):
|
||||||
|
return [pos2idx[tag] if tag in pos2idx else -1 for _, tag in tagged_sentence]
|
||||||
|
|
||||||
|
# Split sentence into smaller chunks based on punctuation and length constraints
|
||||||
|
def split_long_sentence(sentence, max_len=128):
|
||||||
|
tokens = word_tokenize(sentence)
|
||||||
|
|
||||||
|
if len(tokens) <= max_len:
|
||||||
|
return [sentence]
|
||||||
|
|
||||||
|
# Attempt to split based on punctuation marks
|
||||||
|
punctuation_marks = [',', ';', ':', '!', '?', '.', '-']
|
||||||
|
split_chunks = []
|
||||||
|
current_chunk = []
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
current_chunk.append(token)
|
||||||
|
if token in punctuation_marks and len(current_chunk) >= max_len // 2:
|
||||||
|
split_chunks.append(' '.join(current_chunk))
|
||||||
|
current_chunk = []
|
||||||
|
|
||||||
|
if current_chunk:
|
||||||
|
split_chunks.append(' '.join(current_chunk))
|
||||||
|
|
||||||
|
# If chunks are still too long, truncate them
|
||||||
|
final_chunks = []
|
||||||
|
for chunk in split_chunks:
|
||||||
|
chunk_tokens = word_tokenize(chunk)
|
||||||
|
if len(chunk_tokens) > max_len:
|
||||||
|
final_chunks.extend([' '.join(chunk_tokens[i:i + max_len]) for i in range(0, len(chunk_tokens), max_len)])
|
||||||
|
else:
|
||||||
|
final_chunks.append(chunk)
|
||||||
|
|
||||||
|
return final_chunks
|
||||||
|
|
||||||
|
# Main function to process and score a chunk
|
||||||
|
def score_sentence(sentence, model, max_length=128):
|
||||||
|
# Tokenize and POS-tag the sentence
|
||||||
|
tagged_sentence = nltk.pos_tag(nltk.word_tokenize(sentence))
|
||||||
|
|
||||||
|
# Encode the POS tags and pad the sequence
|
||||||
|
encoded_sentences = encode_pos_tags(tagged_sentence)
|
||||||
|
padded_sentence = torch.tensor(pad_sequence(encoded_sentences, max_length))
|
||||||
|
|
||||||
|
# Set the device
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
||||||
|
|
||||||
|
# Prepare the model
|
||||||
|
model.to(device)
|
||||||
|
model.eval() # Ensure the model is in evaluation mode
|
||||||
|
|
||||||
|
# Define weights and CEFR levels
|
||||||
|
w_list = [1.04, 1.64, 2.35, 3.44, 4.92, 6.13]
|
||||||
|
|
||||||
|
# Inference without gradient calculation
|
||||||
|
with torch.no_grad():
|
||||||
|
sentence_tensor = padded_sentence.to(device)
|
||||||
|
sentence_tensor = torch.unsqueeze(sentence_tensor, 0) # Add batch dimension
|
||||||
|
|
||||||
|
# Forward pass through the model
|
||||||
|
outputs = model(sentence_tensor)
|
||||||
|
|
||||||
|
# Softmax and weighted scoring
|
||||||
|
probabilities = torch.softmax(outputs[0], dim=0)
|
||||||
|
score = sum(probabilities[i] * w_list[i] for i in range(6)).cpu().numpy()
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
# Function to process a long article and return score list for each chunk
|
||||||
|
def score_article(article, max_length=128, chunk_max_len=128):
|
||||||
|
sentences = sent_tokenize(article) # Split the article into sentences
|
||||||
|
score_list = []
|
||||||
|
|
||||||
|
for sentence in sentences:
|
||||||
|
chunks = split_long_sentence(sentence, max_len=chunk_max_len)
|
||||||
|
for chunk in chunks:
|
||||||
|
score = score_sentence(chunk, model, max_length=max_length)
|
||||||
|
score_list.append(float(score))
|
||||||
|
|
||||||
|
return score_list
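# Minimal usage sketch (the sample text is invented for illustration and assumes the
# pos2idx.pt, class_mapping.pt and model.pt files loaded above are present):
if __name__ == "__main__":
    sample_article = (
        "The cat sat on the mat. "
        "Although it was raining, the children decided to play outside anyway."
    )
    print(score_article(sample_article))  # one difficulty score per sentence chunk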
|
11819
text-difficulty/grammar/data.csv
Normal file
File diff suppressed because it is too large
82
text-difficulty/grammar/data_deriving.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import pandas as pd
|
||||||
|
from openai import OpenAI
|
||||||
|
from tqdm import tqdm
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
client = OpenAI(
|
||||||
|
api_key=os.getenv("API_KEY"),
|
||||||
|
base_url=os.getenv("BASE_URL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_AI_response(text, client, model_name, temp):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": text},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temp,
|
||||||
|
)
|
||||||
|
|
||||||
|
return response.choices[0].message.content
|
||||||
|
|
||||||
|
def get_Examples(df, row, client, model_name, temp):
|
||||||
|
exp = df["Example"][row]
|
||||||
|
cds = df["Can-do statement"][row]
|
||||||
|
gdw = df["guideword"][row]
|
||||||
|
lvl = df["Level"][row]
|
||||||
|
cat = df["SuperCategory"][row] + '/' + df["SubCategory"][row]
|
||||||
|
prompt = \
|
||||||
|
f'''Generate 10 example sentences based on the following instructions.
|
||||||
|
Pay close attention to the 'Can-do Statement' and ensure all generated sentences adhere strictly to it.
|
||||||
|
Provide only the sentences without any additional formatting or markdown.
|
||||||
|
Output the sentences in plain text, one sentence per line, and do not contain empty line.
|
||||||
|
INSTRUCTION
|
||||||
|
Level: {lvl}
|
||||||
|
Guideword: {gdw}
|
||||||
|
Can-do Statement: {cds}
|
||||||
|
Category: {cat}
|
||||||
|
Example Sentences:
|
||||||
|
{exp}
|
||||||
|
'''
|
||||||
|
return get_AI_response(prompt, client, model_name, temp)
|
||||||
|
|
||||||
|
def process_chunk(df, chunk, client, model, temp):
|
||||||
|
results = []
|
||||||
|
for row in chunk:
|
||||||
|
exps = get_Examples(df, row, client, model, temp)
|
||||||
|
results.append(exps)
|
||||||
|
return results
|
||||||
|
|
||||||
|
input_file = './EGP.csv'
|
||||||
|
df = pd.read_csv(input_file)
|
||||||
|
newdf = df.copy()
|
||||||
|
model = os.getenv("TRANSLATION_MODEL")
|
||||||
|
temp = float(os.getenv("TRANSLATION_TEMP"))
|
||||||
|
|
||||||
|
chunk_size = 64
|
||||||
|
total_rows = len(df.index)
|
||||||
|
num_chunks = (total_rows + chunk_size - 1) // chunk_size # Ceiling division
|
||||||
|
|
||||||
|
with tqdm(total=total_rows) as pbar:
|
||||||
|
for chunk_idx in range(num_chunks):
|
||||||
|
start = chunk_idx * chunk_size
|
||||||
|
end = min(start + chunk_size, total_rows)
|
||||||
|
chunk = range(start, end)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=len(chunk)) as executor:
|
||||||
|
futures = {executor.submit(get_Examples, df, row, client, model, temp): row for row in chunk} # map each future back to its row
|
||||||
|
for future in as_completed(futures):
|
||||||
|
row = futures[future] # get the corresponding row index
|
||||||
|
result = future.result() # get the result returned by the AI
|
||||||
|
newdf.at[row, "Example"] = result # 更新到正确的行
|
||||||
|
|
||||||
|
pbar.update(len(chunk))
|
||||||
|
newdf.to_csv("output.csv", index=False)
|
||||||
|
|
||||||
|
newdf.to_csv("EGP_Derivied.csv", index=False)
|
23
text-difficulty/grammar/data_postprocessing.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
df = pd.read_csv("EGP_Derivied.csv")
|
||||||
|
newdf = pd.DataFrame()
|
||||||
|
|
||||||
|
levels_list=[]
|
||||||
|
sentences_list=[]
|
||||||
|
category_list=[]
|
||||||
|
for line in range(len(df.index)):
|
||||||
|
examples = list(filter(None, df["Example"][line].split("\n")))
|
||||||
|
lvl = df["Level"][line]
|
||||||
|
cat = df["SuperCategory"][line] + '/' + df["SubCategory"][line]
|
||||||
|
for sentence in examples:
|
||||||
|
sentences_list.append(sentence)
|
||||||
|
levels_list.append(lvl)
|
||||||
|
category_list.append(cat)
|
||||||
|
|
||||||
|
|
||||||
|
newdf["Level"] = levels_list
|
||||||
|
newdf["Category"] = category_list
|
||||||
|
newdf["Sentence"] = sentences_list
|
||||||
|
|
||||||
|
newdf.to_csv("data.csv", index=False)
|
111
text-difficulty/grammar/training/model.py
Normal file
@ -0,0 +1,111 @@
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class PositionalEncoding(nn.Module):
    def __init__(self, embedding_dim, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # Create a positional encoding matrix of shape (max_len, embedding_dim)
        pe = torch.zeros(max_len, embedding_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add a batch dimension, so the shape becomes (1, max_len, embedding_dim)
        pe = pe.unsqueeze(0)

        # Register the positional encoding as a buffer so it won't be updated by the optimizer
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is expected to have shape (batch_size, seq_length, embedding_dim)
        seq_length = x.size(1)
        # Add positional encoding to input
        x = x + self.pe[:, :seq_length]
        return x


class SelfAttention(nn.Module):
    def __init__(self, input_dim, heads):
        super(SelfAttention, self).__init__()
        self.heads = heads
        self.scale = (input_dim // heads) ** -0.5
        self.qkv = nn.Linear(input_dim, input_dim * 3)
        self.fc = nn.Linear(input_dim, input_dim)

    def forward(self, x):
        batch_size, seq_length, embedding_dim = x.shape
        qkv = self.qkv(x).view(
            batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
        )
        q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_weights, dim=-1)

        attention_output = torch.matmul(attn_weights, v)
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        attention_output = attention_output.view(batch_size, seq_length, embedding_dim)

        return self.fc(attention_output)


class AttentionBasedModel(nn.Module):
    def __init__(self, pos_vocab_size, embedding_dim=128, num_classes=6, heads=8, num_attention_layers=3, dim_feedforward=512, max_len=128):
        super(AttentionBasedModel, self).__init__()
        self.embedding = nn.Embedding(pos_vocab_size, embedding_dim)  # Embedding for POS tags
        self.positional_encoding = PositionalEncoding(embedding_dim, max_len)  # Positional Encoding
        self.self_attention_layers = nn.ModuleList([
            SelfAttention(embedding_dim, heads) for _ in range(num_attention_layers)
        ])
        self.fc1 = nn.Linear(embedding_dim, dim_feedforward)
        self.fc2 = nn.Linear(dim_feedforward, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.norm = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        # Input x contains POS-tag indices, shape: (batch_size, seq_length)
        x = self.embedding(x)  # Convert POS tags to embeddings, shape: (batch_size, seq_length, embedding_dim)

        # Add positional encoding to embeddings
        x = self.positional_encoding(x)

        for attn_layer in self.self_attention_layers:
            attn_output = attn_layer(x)
            x = self.norm(attn_output + x)

        # Pool the output by taking the mean of the sequence (reduce along sequence length)
        pooled_output = torch.mean(x, dim=1)

        # Fully connected layers for classification
        x = F.relu(self.fc1(pooled_output))
        x = self.dropout(x)
        x = self.fc2(x)  # Output logits for the 6 classes

        return x


# Example Usage
# # Hyperparameters
# pos_vocab_size = 50  # Size of the POS tag vocabulary
# max_context_length = 128  # Maximum context length
# embedding_dim = 128  # Embedding size
# num_classes = 6  # Output classes
# batch_size = 32  # Example batch size

# # Model initialization
# model = AttentionBasedModel(pos_vocab_size, embedding_dim, num_classes)

# # Example input: batch of POS-tag index sequences (variable length sequences)
# input_data = torch.randint(0, pos_vocab_size, (batch_size, max_context_length))  # Random input for testing

# # Forward pass
# output = model(input_data)  # Output shape will be (batch_size, num_classes)

# print(output.shape)  # Should print torch.Size([batch_size, num_classes])
150 text-difficulty/grammar/training/train.py Normal file
@ -0,0 +1,150 @@
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import nltk
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import numpy as np

# nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger')

# Load the model classes
from model import AttentionBasedModel

# Load the data
df = pd.read_csv('data.csv')

# Step 1: Extract sentences and corresponding levels
sentences = df['Sentence'].values
levels = df['Level'].values

# Step 2: Tokenize and POS tag each sentence
pos_tags = [nltk.pos_tag(nltk.word_tokenize(sentence)) for sentence in sentences]

# Step 3: Build POS tag vocabulary
# Extract unique POS tags from the dataset
pos_vocab = set()
for tagged_sentence in pos_tags:
    for _, tag in tagged_sentence:
        pos_vocab.add(tag)

# Create a mapping from POS tag to index
pos2idx = {pos: idx for idx, pos in enumerate(pos_vocab)}
pos_vocab_size = len(pos2idx)

# Step 4: Encode sentences into POS tag indices
def encode_pos_tags(tagged_sentence):
    return [pos2idx[tag] for _, tag in tagged_sentence]

encoded_sentences = [encode_pos_tags(tagged_sentence) for tagged_sentence in pos_tags]

# Step 5: Encode levels (classes) into integers
le = LabelEncoder()
encoded_levels = le.fit_transform(levels)
num_classes = len(le.classes_)

# Save class encoding mapping
class_mapping = dict(zip(le.transform(le.classes_), le.classes_))
torch.save(class_mapping, 'class_mapping.pt')

# Save POS tag encoding mapping
torch.save(pos2idx, 'pos2idx.pt')

# Step 6: Pad sentences to a fixed length
max_length = 64

def pad_sequence(seq, max_len):
    return seq + [-1] * (max_len - len(seq))

padded_sentences = [pad_sequence(seq, max_length) for seq in encoded_sentences]

# Step 7: Create a PyTorch Dataset and DataLoader
class POSDataset(Dataset):
    def __init__(self, sentences, labels):
        self.sentences = sentences
        self.labels = labels

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = torch.tensor(self.sentences[idx], dtype=torch.long)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return sentence, label

# Split data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(padded_sentences, encoded_levels, test_size=0.2)

train_dataset = POSDataset(X_train, y_train)
val_dataset = POSDataset(X_val, y_val)

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Step 8: Initialize the model, loss function, and optimizer
embedding_dim = 32
heads = 8
num_attention_layers = 6
dim_feedforward = 256
learning_rate = 0.003

model = AttentionBasedModel(pos_vocab_size, embedding_dim, num_classes, heads, num_attention_layers, dim_feedforward)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Step 9: Training loop
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
model.to(device)

step = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for sentences, labels in train_loader:
        sentences, labels = sentences.to(device), labels.to(device)

        # Forward pass
        outputs = model(sentences)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1

        running_loss += loss.item()

    # Validation phase
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for sentences, labels in val_loader:
            sentences, labels = sentences.to(device), labels.to(device)

            outputs = model(sentences)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Print training and validation stats
    print(f'Epoch [{epoch+1}/{num_epochs}], Step {step}, Loss: {running_loss/len(train_loader):.4f}, '
          f'Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')
    torch.save(model.state_dict(), f'checkpoints/step_{step}.pt')

# Step 10: Save the trained model
torch.save(model.state_dict(), 'model.pt')
46 text-difficulty/grammar/validation/inference.py Normal file
@ -0,0 +1,46 @@
from training.model import AttentionBasedModel
import torch
import torch.nn.functional as F
import nltk


sentence = '''Smartphones have worked their way deep into our lives and have become indispensable for work and socialising.'''

pos2idx = torch.load('pos2idx.pt')
class_mapping = torch.load('class_mapping.pt')

def pad_sequence(seq, max_len):
    return seq + [-1] * (max_len - len(seq))

def encode_pos_tags(tagged_sentence):
    return [pos2idx[tag] if tag in pos2idx else -1 for _, tag in tagged_sentence]


max_length = 64

tagged_sentence = nltk.pos_tag(nltk.word_tokenize(sentence))
encoded_sentences = encode_pos_tags(tagged_sentence)
padded_sentence = torch.tensor(pad_sequence(encoded_sentences, max_length))

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
model = AttentionBasedModel(40, 32, 6, 8, 6, 256)
model.load_state_dict(torch.load("./model.pt", weights_only=True))
model.to(device)
model.eval()  # Make sure the model is in evaluation mode


w_list = [1.35, 1.63, 2.75, 3.64, 5.38, 6.32]
cefr_dict = [None, "A1", "A2", "B1", "B2", "C1", "C2"]
with torch.no_grad():
    sentence = padded_sentence.to(device)
    sentences = torch.unsqueeze(sentence, 0)

    outputs = model(sentences)
    print(torch.max(outputs, 1))
    print(outputs[0])
    print(torch.softmax(outputs[0], 0))
    s = 0
    for i in range(6):
        s += torch.softmax(outputs[0], 0)[i] * w_list[i]
    s = s.cpu().numpy()
    # s is the final output: a softmax-weighted difficulty score.
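The script computes the weighted score `s` but never uses `cefr_dict`. Below is a minimal sketch of one way `s` could be mapped back to a CEFR label; the mapping rule is an assumption for illustration, not something defined in the repository.

```python
# Illustrative only: the nearest-weight rule below is an assumption; the repo
# does not define how the blended score s is converted to a CEFR level.
def score_to_cefr(s, w_list, cefr_labels=("A1", "A2", "B1", "B2", "C1", "C2")):
    # Pick the class whose weight in w_list is closest to the blended score s.
    diffs = [abs(s - w) for w in w_list]
    return cefr_labels[diffs.index(min(diffs))]

# Example: score_to_cefr(3.1, [1.35, 1.63, 2.75, 3.64, 5.38, 6.32]) -> "B1"
```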
@ -34,63 +34,79 @@ EXAMPLE JSON OUTPUT:
 }
 """

-def translate_text(text):
+def translate_text(text, client, model_name, temp):
     messages = [
         {"role": "system", "content": system_prompt},
-        {"role": "user", "content": text}
+        {"role": "user", "content": text},
     ]

     response = client.chat.completions.create(
-        model=os.getenv("TRANSLATION_MODEL"),
+        model=model_name,
         messages=messages,
-        response_format={'type': 'json_object'},
-        temperature=float(os.getenv("TRANSLATION_TEMP"))
+        response_format={"type": "json_object"},
+        temperature=temp,
     )

     return json.loads(response.choices[0].message.content)


 def process_file(input_file, output_dir):
     try:
-        with open(input_file, 'r', encoding='utf-8') as f:
+        with open(input_file, "r", encoding="utf-8") as f:
             text = f.read()

-        translation = translate_text(text)
+        model = os.getenv("TRANSLATION_MODEL")
+        temp = float(os.getenv("TRANSLATION_TEMP"))
+        translation = translate_text(text, client, model, temp)

         output_path = os.path.join(output_dir, Path(input_file).stem + ".json")
-        with open(output_path, 'w', encoding='utf-8') as f:
+        with open(output_path, "w", encoding="utf-8") as f:
             json.dump(translation, f, ensure_ascii=False, indent=4)

         print(f"Successfully translated and saved to {output_path}")

     except Exception as e:
         print(f"Error processing {input_file}: {e}")


 def batch_process(input_dir, output_dir, num_threads=4):
     if not os.path.exists(output_dir):
         os.makedirs(output_dir)

-    input_files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
-    output_files = [f for f in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, f))]
+    input_files = [
+        f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))
+    ]
+    output_files = [
+        f for f in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, f))
+    ]

     output_stems = {Path(f).stem for f in output_files}

-    files = [os.path.join(input_dir, f) for f in input_files if Path(f).stem not in output_stems]
+    files = [
+        os.path.join(input_dir, f)
+        for f in input_files
+        if Path(f).stem not in output_stems
+    ]

     threads = []
     for file in files:
         thread = threading.Thread(target=process_file, args=(file, output_dir))
         threads.append(thread)
         thread.start()

         if len(threads) >= num_threads:
             for t in threads:
                 t.join()
             threads = []

     for t in threads:
         t.join()


 if __name__ == "__main__":
-    input_dir = "./source"
-    output_dir = "./output"
-    batch_process(input_dir, output_dir, num_threads=int(os.getenv("TRANSLATE_THREADS")))
+    input_dir = "./source-new"
+    output_dir = "./output-new"
+    batch_process(
+        input_dir, output_dir, num_threads=int(os.getenv("TRANSLATE_THREADS"))
+    )
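The batch translator above and the validation scripts later in this diff read their settings from environment variables (typically placed in a `.env` loaded with python-dotenv, as `translate/validation/LLMtrans.py` does). A minimal sketch of the variables they expect; the names come from the code, the values shown are placeholders.

```python
# Minimal sketch of the configuration the translation scripts expect.
# Variable names are taken from the scripts; the example values are placeholders.
import os

REQUIRED_ENV = {
    "API_KEY": "sk-...",                        # OpenAI-compatible API key (placeholder)
    "BASE_URL": "https://api.example.com/v1",   # API endpoint (placeholder)
    "TRANSLATION_MODEL": "model-name",          # model passed to chat.completions.create
    "TRANSLATION_TEMP": "0.7",                  # sampling temperature, parsed with float()
    "TRANSLATE_THREADS": "4",                   # number of worker threads, parsed with int()
}

missing = [k for k in REQUIRED_ENV if os.getenv(k) is None]
if missing:
    print("Missing environment variables:", ", ".join(missing))
```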
50 translate/README.md Normal file
@ -0,0 +1,50 @@
# sparkastML NMT

A set of models that aims to offer the best open-source machine translation, based on [OpenNMT](https://opennmt.net/).

## News

sparkastML's translation model has been updated!

### Details

- **Source Language:** Chinese (Simplified)
- **Target Language:** English
- **Training Time:** 11.3 hours in total, 46,500 steps (~1×10¹⁸ FLOPs)
- **Training Device:**
  - RTX 3080 (20GB): 0–20,000 steps
  - RTX 4070: 20,000–46,500 steps
- **Corpus Size:** Over 10 million sentences
- **Validation BLEU Score:** 21.28
- **Validation Loss (Cross Entropy):** 3.152

### Model Download

Available soon.

### Special thanks

Thanks to [yumechi](https://github.com/eternal-flame-AD/) for sponsoring an RTX 4070 for training.

## History

### Sep 19, 2024

sparkastML's translation model was updated.

#### Details

- **Source Language:** Chinese (Simplified)
- **Target Language:** English
- **Training Time:** 5 hours, 20,000 steps
- **Training Device:** RTX 3080 (20GB)
- **Corpus Size:** Over 10 million sentences
- **Validation BLEU Score:** 17
- **Version:** 1.0

#### Model Download

- **Google Drive:** [Download from Google Drive](https://drive.google.com/drive/folders/1-q_AKfQENW-pV6uAleUHPE9ghddfNWKF)
- **IPFS:** [Download from IPFS](http://ipfs.a2x.pub/ipfs/QmUMadzkBwvH5KTpoxfv7TgqzaPpqBzkXtkecV9TXPfZ3F/)
  - **CID:** `QmUMadzkBwvH5KTpoxfv7TgqzaPpqBzkXtkecV9TXPfZ3F`
- **GitHub Release:** [Go to Release Page](https://github.com/alikia2x/sparkastML/releases/tag/v2-model)
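Since the models are OpenNMT-based, a downloaded checkpoint can presumably be run with the standard OpenNMT-py CLI. A hedged sketch, in the same subprocess style as the validation scripts later in this diff; the file names are placeholders, and the input must be preprocessed (tokenized/BPE) the same way the training corpus was.

```python
# Hedged sketch: run a downloaded sparkastML checkpoint with the OpenNMT-py CLI.
# "zh2en.pt", "input.zh.txt" and "output.en.txt" are placeholders, not repo files.
import subprocess

subprocess.run(
    ["onmt_translate", "-model", "zh2en.pt", "-src", "input.zh.txt", "-output", "output.en.txt"],
    check=True,
)
```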
126 translate/analytics/Distribution.ipynb Normal file
File diff suppressed because one or more lines are too long

79 translate/analytics/ccmatrix/check_sim.py Normal file
@ -0,0 +1,79 @@
import torch
from transformers import AutoModel, AutoTokenizer
from numpy.linalg import norm
import sys
import random
from tqdm import tqdm

# Define the cosine similarity function
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))

# Load the model and tokenizer
model_name = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

# Check if the correct number of command-line arguments are provided
if len(sys.argv) < 4 or len(sys.argv) > 5:
    print("Usage: python script.py <file_a_path> <file_b_path> <output_file_path> [num_samples]")
    sys.exit(1)

# Define file paths from command-line arguments
file_a_path = sys.argv[1]
file_b_path = sys.argv[2]
output_file_path = sys.argv[3]

# Define the number of samples to randomly select
num_samples = int(sys.argv[4]) if len(sys.argv) == 5 else 100

# Get the total number of lines in the files without loading them fully
def count_lines(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return sum(1 for _ in f)

total_lines_a = count_lines(file_a_path)
total_lines_b = count_lines(file_b_path)

# Ensure both files have the same number of lines
if total_lines_a != total_lines_b:
    print("Files must have the same number of lines.")
    sys.exit(1)

# Select random sample indices without loading entire files
selected_indices = sorted(random.sample(range(total_lines_a), num_samples))

# Function to get all sampled lines from the file
def get_lines(file_path, line_numbers):
    result = []
    max_i = max(line_numbers)
    j = 0
    next_i = line_numbers[j]
    len_line_numbers = len(line_numbers)
    with open(file_path, 'r', encoding='utf-8') as f:
        for current_line, line in tqdm(enumerate(f)):
            if current_line < next_i:
                continue
            result.append(line.strip())
            j += 1
            if current_line >= max_i or j >= len_line_numbers:
                return result
            next_i = line_numbers[j]

    return result

lines_a = get_lines(file_a_path, selected_indices)
lines_b = get_lines(file_b_path, selected_indices)

# Open output file for writing
with open(output_file_path, 'w', encoding='utf-8') as output_file:
    for i, idx in tqdm(enumerate(selected_indices)):
        # Get the corresponding lines from both files
        line_a = lines_a[i]
        line_b = lines_b[i]

        embeddings = model.encode([line_a, line_b])
        similarity = cos_sim(embeddings[0], embeddings[1])

        # Write the similarity to the output file
        output_file.write(f"{similarity}\n")

print(f"Similarity calculation completed. Results saved to {output_file_path}")
60 translate/analytics/filter.py Normal file
@ -0,0 +1,60 @@
from transformers import AutoModel
from numpy.linalg import norm
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser(
    description="Usage: python filter.py <file_a_path> <file_b_path> <output_file_path>"
)

parser.add_argument("file_a", type=str, help="File No.1")
parser.add_argument("file_b", type=str, help="File No.2")
parser.add_argument("output", type=str, help="Output file")
parser.add_argument(
    "--resume",
    type=int,
    default=-1,
    help="Resume from specified line",
)
args = parser.parse_args()

# Define the cosine similarity function
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))

# Load the model and tokenizer
model_name = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.to('cuda')

# Define file paths from command-line arguments
file_a_path = args.file_a
file_b_path = args.file_b
output_file_path = args.output

resume_from = args.resume
resume = resume_from >= 0
output_file_mode = 'a' if resume else 'w'

# Open files
with open(file_a_path, 'r', encoding='utf-8') as file_a, \
     open(file_b_path, 'r', encoding='utf-8') as file_b, \
     open(output_file_path, output_file_mode, encoding='utf-8') as output_file:
    i = 1
    # Read file A and file B line by line
    for line_a, line_b in tqdm(zip(file_a, file_b)):
        if resume and i < resume_from:
            i += 1
            continue
        # Remove trailing newline characters
        line_a = line_a.strip()
        line_b = line_b.strip()

        embeddings = model.encode([line_a, line_b])
        similarity = cos_sim(embeddings[0], embeddings[1])

        # Write the similarity to the output file
        output_file.write(f"{similarity}\n")

        i += 1

print(f"Similarity calculation completed. Results saved to {output_file_path}")
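filter.py only writes one similarity score per line; the step that actually drops low-similarity pairs is not shown in this diff. A hedged sketch of that follow-up step, where the 0.7 threshold and the file names are assumptions for illustration:

```python
# Hedged sketch: keep only sentence pairs whose similarity passes a threshold.
# The threshold (0.7) and all file names are illustrative assumptions.
THRESHOLD = 0.7

with open("similarities.txt", encoding="utf-8") as sims, \
     open("corpus.zh.txt", encoding="utf-8") as src, \
     open("corpus.en.txt", encoding="utf-8") as tgt, \
     open("corpus.filtered.zh.txt", "w", encoding="utf-8") as src_out, \
     open("corpus.filtered.en.txt", "w", encoding="utf-8") as tgt_out:
    for sim, zh, en in zip(sims, src, tgt):
        if float(sim) >= THRESHOLD:
            src_out.write(zh)
            tgt_out.write(en)
```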
74 translate/analytics/translation2019/check_sim.py Normal file
@ -0,0 +1,74 @@
from transformers import AutoModel
from numpy.linalg import norm
import sys
import random
import json
from tqdm import tqdm

# Define the cosine similarity function
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))

# Load the model and tokenizer
model_name = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

# Check if the correct number of command-line arguments are provided
if len(sys.argv) < 3 or len(sys.argv) > 4:
    print("Usage: python script.py <file_path> <output_file_path> [num_samples]")
    sys.exit(1)

# Define file paths from command-line arguments
file_path = sys.argv[1]
output_file_path = sys.argv[2]

# Define the number of samples to randomly select
num_samples = int(sys.argv[3]) if len(sys.argv) == 4 else 100

# Get the total number of lines in the file without loading it fully
def count_lines(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return sum(1 for _ in f)

total_lines = count_lines(file_path)

# Select random sample indices without loading the entire file
selected_indices = sorted(random.sample(range(total_lines), num_samples))

# Function to get all sampled lines from the file
def get_lines(file_path, line_numbers):
    result = []
    max_i = max(line_numbers)
    j = 0
    next_i = line_numbers[j]
    len_line_numbers = len(line_numbers)
    with open(file_path, 'r', encoding='utf-8') as f:
        for current_line, line in tqdm(enumerate(f)):
            if current_line < next_i:
                continue
            result.append(line.strip())
            j += 1
            if current_line >= max_i or j >= len_line_numbers:
                return result
            next_i = line_numbers[j]

    return result

lines = get_lines(file_path, selected_indices)

# Open output file for writing
with open(output_file_path, 'w', encoding='utf-8') as output_file, open("1.txt", 'w', encoding='utf-8') as lf:
    for i, idx in tqdm(enumerate(selected_indices)):
        # Get the sampled JSON record and its two sides
        line = lines[i]
        data = json.loads(line)
        chn = data["chinese"]
        eng = data["english"]
        lf.write(str(idx) + '\n')

        embeddings = model.encode([chn, eng])
        similarity = cos_sim(embeddings[0], embeddings[1])

        # Write the similarity to the output file
        output_file.write(f"{similarity}\n")

print(f"Similarity calculation completed. Results saved to {output_file_path}")
30 translate/hf-dataset.py Normal file
@ -0,0 +1,30 @@
import pandas as pd

# Define file paths
source_files = ['./result/source.txt', './result/source-new.txt']
target_files = ['./result/target.txt', './result/target-new.txt']

# Read the contents of the source and target files
source_data = []
target_data = []

for file in source_files:
    with open(file, 'r', encoding='utf-8') as f:
        source_data.extend(f.readlines())

for file in target_files:
    with open(file, 'r', encoding='utf-8') as f:
        target_data.extend(f.readlines())

# Make sure source and target have the same number of lines
if len(source_data) != len(target_data):
    print("Warning: The number of lines in source and target files do not match.")

# Build the DataFrame
df = pd.DataFrame({
    'zh': [line.strip() for line in source_data],  # strip the trailing newline from each line
    'en': [line.strip() for line in target_data]   # strip the trailing newline from each line
})


df.to_csv('./result/data.csv', index=False, encoding='utf-8')
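Judging by the file name, the exported `data.csv` is intended for use as a Hugging Face dataset. A hedged sketch of loading it with the `datasets` library; the repo id in the commented push call is a placeholder, not something defined in this repository.

```python
# Hedged sketch: load the exported CSV with the Hugging Face datasets library.
# "your-username/sparkastml-zh-en" is a placeholder repo id.
from datasets import load_dataset

ds = load_dataset("csv", data_files="./result/data.csv", split="train")
print(ds[0])  # e.g. {'zh': '...', 'en': '...'}
# ds.push_to_hub("your-username/sparkastml-zh-en")  # optional upload step
```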
@ -27,8 +27,8 @@ def process_json_files(directory, converted_filename):
             data = json.load(json_file)
             segments = data.get('segments', [])

-        with open('./result/source.txt', 'a', encoding='utf-8') as source_file, \
-             open('./result/target.txt', 'a', encoding='utf-8') as target_file:
+        with open('./result/source-new.txt', 'a', encoding='utf-8') as source_file, \
+             open('./result/target-new.txt', 'a', encoding='utf-8') as target_file:
             for segment in segments:
                 chinese_text = segment.get('chinese', '').replace('\n', ' ')
                 english_text = segment.get('english', '').replace('\n', ' ')
@ -42,7 +42,7 @@ def process_json_files(directory, converted_filename):
         write_converted_file(converted_filename, filename)

 if __name__ == "__main__":
-    json_directory = './output'  # replace with the path to your JSON directory
+    json_directory = './output-new'  # replace with the path to your JSON directory
     converted_filename = './result/converted.txt'

     process_json_files(json_directory, converted_filename)
52 translate/split_source.py Normal file
@ -0,0 +1,52 @@
import os
import re

def split_content(content):
    sentences = re.split(r'[。!?;.!?;]', content)
    segments = []
    current_segment = []
    current_length = 0

    for sentence in sentences:
        sentence_length = len(sentence)
        if (len(current_segment) >= 25 or current_length + sentence_length > 1200):
            segments.append(''.join(current_segment))
            current_segment = []
            current_length = 0

        current_segment.append(sentence)
        current_length += sentence_length

    if current_segment:
        segments.append(''.join(current_segment))

    return segments

def process_files_in_directory(directory):
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)

        # Only process files, skip directories
        if os.path.isfile(file_path):
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()

            segments = split_content(content)

            if len(segments) > 1:
                # Delete the original file
                os.remove(file_path)

                # Save the split segments as new files
                for i, segment in enumerate(segments):
                    new_filename = f"{filename}_{i+1}"
                    new_file_path = os.path.join(directory, new_filename)

                    with open(new_file_path, 'w', encoding='utf-8') as new_file:
                        new_file.write(segment)
            else:
                print(f"File {filename} does not need to be split")

# Target directory
directory = './source-new'
process_files_in_directory(directory)
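A small illustration of what `split_content` produces. Note that `re.split` drops the sentence-final punctuation, so delimiters do not appear in the output segments; the sample text here is made up.

```python
# Illustration only; the input string is made up.
sample = "第一句。第二句!Third sentence. Fourth one?"
print(split_content(sample))
# -> ['第一句第二句Third sentence Fourth one']
# One segment, since the text is far below the 25-sentence / 1200-character
# limits, and the splitting punctuation has been removed.
```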
44 translate/validation/LLMtrans.py Normal file
@ -0,0 +1,44 @@
from openai import OpenAI
import argparse
import os
from dotenv import load_dotenv
from tqdm import tqdm

def translate_text(text, client, model_name, temp):
    messages = [
        {"role": "system", "content": "User will provide some text. You need to translate the text into English and output it WITHOUT ANY ADDITIONAL INFORMATION OR EXPLANATION."},
        {"role": "user", "content": text},
    ]

    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temp,
    )

    return response.choices[0].message.content

load_dotenv()

parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="Path to the input file")
parser.add_argument("output", type=str, help="Path to the output file")
args = parser.parse_args()

input_file = args.input
output_file = args.output
client = OpenAI(
    api_key=os.getenv("API_KEY"),
    base_url=os.getenv("BASE_URL"),
)
model = os.getenv("TRANSLATION_MODEL")
temp = float(os.getenv("TRANSLATION_TEMP"))

with open(input_file, "r") as f:
    src_lines = f.readlines()


for line in tqdm(src_lines):
    result = translate_text(line, client, model, temp)
    with open(output_file, 'a') as f:
        f.write(result + '\n')
15 translate/validation/argoTrans.py Normal file
@ -0,0 +1,15 @@
import subprocess
from tqdm import tqdm

def translate_text(text):
    command = f'argos-translate --from zh --to en "{text}"'
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    return result.stdout.strip()

with open("./data/src.txt", "r") as f:
    src_lines = f.readlines()

for line in tqdm(src_lines):
    result = translate_text(line)
    with open("./data/hyp-sk-1.2.txt", 'a') as f:
        f.write(result + '\n')
42 translate/validation/bleu_full.py Normal file
@ -0,0 +1,42 @@
import json
import subprocess
import evaluate
from nltk.tokenize import word_tokenize
from tqdm import tqdm

bleu_cal = evaluate.load("chrf")

def translate_text(text):
    command = f'argos-translate --from zh --to en "{text}"'
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    return result.stdout.strip()

def main():
    # Load the dataset
    with open('./data/1.jsonl', 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f]

    translations = []
    references = []

    # for entry in tqdm(data):
    #     chinese_sentence = entry['zh']
    #     translated_sentence = translate_text(chinese_sentence)
    #     with open("./data/1-inf.txt", "a") as f:
    #         f.write(translated_sentence + "\n")
    #     translations.append(translated_sentence)

    with open("./data/1-inf.txt", 'r') as f:
        translations = f.readlines()

    for entry in data:
        english_sentence = entry['en']
        references.append([english_sentence])

    # Compute the corpus score (note: "chrf" is loaded above, so this reports chrF rather than BLEU)
    bleu = bleu_cal.compute(predictions=translations, references=references)
    print(bleu)

if __name__ == "__main__":
    main()
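The script above names its metric `bleu_cal` but loads `chrf`. If an actual BLEU number is wanted alongside chrF, a hedged sketch using the same `evaluate` package (the metric names are the standard Hub ones; the example sentences are placeholders):

```python
# Hedged sketch: report both chrF and sacreBLEU for the same hypotheses.
import evaluate

chrf = evaluate.load("chrf")
sacrebleu = evaluate.load("sacrebleu")

predictions = ["The weather is nice today."]        # hypothesis sentences (placeholder)
references = [["The weather is very nice today."]]  # one list of references per hypothesis

print(chrf.compute(predictions=predictions, references=references))
print(sacrebleu.compute(predictions=predictions, references=references))
```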
10 translate/validation/googleTrans.py Normal file
@ -0,0 +1,10 @@
from googletrans import Translator

translator = Translator()

with open("./data/src.txt", "r") as f:
    src_lines = f.readlines()

for line in src_lines:
    result = translator.translate(line, dest='en')
    with open("./data/hyp-gg-py.txt", 'a') as f:
        f.write(result.text + '\n')
19 translate/validation/m2mTrans.py Normal file
@ -0,0 +1,19 @@
from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

def translate_text(text):
    tokenizer.src_lang = "zh"
    encoded_zh = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(**encoded_zh, forced_bos_token_id=tokenizer.get_lang_id("en"))
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]

with open("./data/src.txt", "r") as f:
    src_lines = f.readlines()

for line in tqdm(src_lines):
    result = translate_text(line)
    with open("./data/hyp-m2m.txt", 'a') as f:
        f.write(result + '\n')
56 translate/validation/preprocess.py Normal file
@ -0,0 +1,56 @@
import json
import random
import argparse
from tqdm import tqdm


# Read the jsonl file
def read_jsonl(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            yield json.loads(line)


# Randomly sample a number of lines
def sample_lines(data, sample_size):
    return random.sample(list(data), sample_size)


# Main function
def main(input_file, sample_size):
    # Read the jsonl file
    data = read_jsonl(input_file)

    # Randomly sample a number of lines
    sampled_data = sample_lines(data, sample_size)

    for item in tqdm(sampled_data):
        chinese_text = item["chinese"]
        english_text = item["english"]

        with open("./data/src.txt", 'a') as srcf, open("./data/ref.txt", 'a') as reff:
            srcf.write(chinese_text + '\n')
            reff.write(english_text + '\n')


# Example invocation
if __name__ == "__main__":
    # Create the command-line argument parser
    parser = argparse.ArgumentParser(
        description="Process a JSONL file by sampling lines and translating text."
    )

    # Add command-line arguments
    parser.add_argument("input", type=str, help="Path to the input JSONL file")
    parser.add_argument(
        "--sample_size",
        type=int,
        default=100,
        help="Number of lines to sample (default: 100)",
    )

    # Parse the command-line arguments
    args = parser.parse_args()

    # Call the main function
    main(args.input, args.sample_size)
16 ugNMT/BPE/filter_non-ug_char.py Normal file
@ -0,0 +1,16 @@
import re

# Read the file contents
with open('ug.txt', 'r', encoding='utf-8') as file:
    data = file.read()

# Define a regex that keeps Uyghur (Arabic-script) letters, Arabic numerals and common punctuation
# Uyghur letters fall in the Unicode range U+0600-U+06FF (the Arabic block)
# Arabic numerals 0-9 and the punctuation set (。!?,,;:) can be adjusted as needed
filtered_data = re.sub(r'[^\u0600-\u06FF0-9.,!?؛:\s]', '', data)

# Write the filtered data to a new file
with open('filtered_ug.txt', 'w', encoding='utf-8') as file:
    file.write(filtered_data)

print("Filtering finished; results saved to filtered_ug.txt")
13 ugNMT/BPE/filter_space.py Normal file
@ -0,0 +1,13 @@
import re

def replace_spaces_in_file(input_file_path, output_file_path):
    with open(input_file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    new_text = re.sub(r' +', ' ', text)

    with open(output_file_path, 'w', encoding='utf-8') as file:
        file.write(new_text)

# Call the function to collapse repeated spaces in the file
replace_spaces_in_file('./data/ug_texts1.txt', './data/2.txt')