Compare commits

...

32 Commits

Author SHA1 Message Date
fdff155673
update: idk 2024-11-16 21:51:19 +08:00
c4ca9c7d4f
ref: clean forced-alignment 2024-11-07 03:14:47 +08:00
65123d1b39
add: full example of forced alignment for music 2024-11-03 01:11:45 +08:00
aeab34f84b
add: forced alignment example 2024-11-02 15:46:12 +08:00
37d2507f10
update: latest synthetic data script 2024-10-07 23:15:25 +08:00
33754146c8
add: text-difficulty/grammar 2024-10-02 21:11:23 +08:00
ae6f10a6f0
add: open set validation 2024-09-28 21:53:55 +08:00
bf2c9a393a
update: add metadata export of intention classify 2024-09-26 22:57:27 +08:00
853d158c41
update: README for translate 2024-09-23 21:30:44 +08:00
9f071ee0a0
ref: the intention-classification model 2024-09-22 03:58:56 +08:00
66cf093177
add: dataset quality check 2024-09-20 00:53:51 +08:00
237d2f5c96
ref: remove unnecessary file 2024-09-19 22:05:05 +08:00
01597c298d
update: evaluation 2024-09-19 22:03:54 +08:00
435faa4b92
update: README 2024-09-17 20:44:55 +08:00
580753bb6f
update: README 2024-09-17 20:20:47 +08:00
6500e378be
add: translation evaluation 2024-09-17 20:07:47 +08:00
3bb222bda1
update: README 2024-09-16 17:40:11 +08:00
3ebeaf4655
update: readme 2024-09-16 17:34:13 +08:00
932cbd4336
add: dataset 2024-09-16 17:29:12 +08:00
a9a7430a58
update: fetching with cooldown
fix: post-process unmatch
improve: LLM-translate now requests with temperature
2024-09-16 04:08:33 +08:00
6f25183654
update: fetcher and post-process
move the max threads and fetch limit in fetcher into env
update the postprocess flow to remove duplicates
2024-09-16 00:59:58 +08:00
9eeb3de828
update: fetcher, translator
increase threshold of split in fetcher
improve prompt for LLM-translator
2024-09-16 00:48:07 +08:00
7021687e10
add: postprocess 2024-09-15 23:54:37 +08:00
4c9f411f67
add: content fetcher for translate 2024-09-15 23:43:01 +08:00
ebd1113a6e
update: llm translate 2024-09-10 21:35:25 +08:00
dcf53ca002
add: spider 2024-09-10 21:35:00 +08:00
1acc1ce703
add: LLM-based batch translation
used to improve translation dataset quality
2024-09-10 00:52:55 +08:00
dc1722ca3d
ref: use argos-translate instead 2024-09-07 23:02:50 +08:00
bb0aa5b79b
update: translate
improve speed
2024-09-07 23:00:15 +08:00
12b9b910f4
add: translation 2024-09-07 15:53:21 +08:00
86394c7f87
update: readme 2024-09-01 22:57:52 +08:00
2c88faf9c0
add: readme 2024-09-01 22:31:52 +08:00
58 changed files with 32193 additions and 656 deletions

17
.gitignore vendored
View File

@ -4,4 +4,19 @@ runs
*.pt
*.bin
token_to_id.json
.ipynb_checkpoints
.ipynb_checkpoints
**/data/**
__pycache__
.env
.env*
translate/output*
translate/source*
translate/result
*.db
dataset/raw
translate/special-spiders
ugNMT/BPE/output*
ugNMT/BPE/codes
forced-alignment/segments
forced-alignment/data
forced-alignment/output.ttml

29
README.md Normal file
View File

@ -0,0 +1,29 @@
# sparkastML
This repository contains the machine learning components for the [sparkast](https://github.com/alikia2x/sparkast) project.
The main goal of this project is to improve the search functionality of sparkast, enabling users to receive real-time answers as they type their queries.
## Intention Classification
The model in the `/intention-classify` directory is designed to categorize user queries into predefined classes.
We use a Convolutional Neural Network (CNN) architecture combined with an Energy-based Model for open-set recognition.
This model is optimized to be lightweight, ensuring it can run on a wide range of devices, including within the browser environment.
For a detailed explanation of how it works, refer to [this blog post](https://blog.alikia2x.com/en/posts/sparkastml-intention/).
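
As an illustration of the open-set recognition idea, below is a minimal, hedged sketch (not the shipped implementation; the function names and the example threshold are assumptions) of how an energy score computed from the classifier logits can gate between a known intent and rejection:

```python
# Minimal sketch of energy-based open-set recognition for intent classification.
# Lower energy means the query looks more like the known classes.
import torch
import torch.nn.functional as F

def energy_score(logits: torch.Tensor) -> torch.Tensor:
    # Energy is the negative log-sum-exp over the class logits.
    return -torch.logsumexp(logits, dim=1)

def classify_or_reject(logits: torch.Tensor, idx_to_class: dict, threshold: float):
    probs = F.softmax(logits, dim=1)
    confidence, predicted = torch.max(probs, dim=1)
    # Queries whose energy exceeds the threshold are treated as out-of-scope.
    if energy_score(logits).item() > threshold:
        return "unknown", confidence.item()
    return idx_to_class[predicted.item()], confidence.item()

# Example with dummy logits for an 8-class model and an assumed threshold of 1.7.
idx_to_class = {i: f"class_{i}" for i in range(8)}
print(classify_or_reject(torch.randn(1, 8), idx_to_class, threshold=1.7))
```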
## Translation
Language barriers are one of the biggest obstacles to communication between civilizations. In modern times, with the development of computer science and artificial intelligence, machine translation is bridging this gap and building a modern Tower of Babel.
Unfortunately, many machine translation systems are owned by commercial companies, which seriously hinders freedom and innovation.
Therefore, sparkastML is on a mission to challenge commercial machine translation. We decided to tackle translation between Chinese and English first: two languages with long histories and huge user bases, whose writing systems and expressive conventions differ greatly, which makes the task challenging.
For more details, visit [this page](./translate/README.md).
## Dataset
To support the development of Libre Intelligence, we have made a series of datasets publicly available. You can access them [here](./dataset/public/README.md).

22
dataset/public/README.md Normal file
View File

@ -0,0 +1,22 @@
# sparkastML Datasets
This repository contains datasets published by the sparkastML project.
## Translation ZH-EN
This dataset features high-quality, fresh synthetic data comprising over 100,000 sentences of Chinese-English parallel corpora.
### Details
- **Source Language:** Chinese (Simplified)
- **Target Language:** English
- **Version:** 1
- **Last Update:** 2024/09/16
- **LICENSE:** [CC-BY 4.0](https://creativecommons.org/licenses/by/4.0/)
### Download
- **Google Drive:** [Download from Google Drive](https://drive.google.com/drive/folders/1_ADblZcB5p9BUvawkYDmp1qIUDZgkkoe)
- **IPFS:** [Download from IPFS](https://ipfs.a2x.pub/ipfs/QmYz4ew4nSzPc6TZvoWk6jXpGN82qt3J46nwfb75N2YKc4/)
- CID: `QmYz4ew4nSzPc6TZvoWk6jXpGN82qt3J46nwfb75N2YKc4`
- **GitHub Release:** [Go to Release Page](https://github.com/alikia2x/sparkastML/releases/tag/v1-dataset)

View File

@ -0,0 +1,7 @@
# Applying Forced Alignment to Word-by-Word Lyric Timing
This sub-project was created to give [AquaVox](https://github.com/alikia2x/aquavox) AI-powered word-by-word lyrics.
## Plan
For the given lyrics and

File diff suppressed because one or more lines are too long
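
The forced-alignment scripts added below implement the plan sketched in the README above. For orientation, here is a condensed sketch of the per-line alignment step, following `forced-alignment/split.py` further down in this diff (the `align_line` wrapper and its name are illustrative, not part of the diff):

```python
# Sketch: align one lyric line's audio against its text with torchaudio's MMS_FA bundle.
import torch
import torchaudio
from torchaudio.transforms import Resample
from torchaudio.pipelines import MMS_FA as bundle
from pypinyin import lazy_pinyin

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = bundle.get_model().to(device)
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

def align_line(wav_path: str, text: str):
    waveform, sample_rate = torchaudio.load(wav_path)
    # Mono, 16 kHz, as expected by the MMS_FA pipeline.
    waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform[0:1])
    transcript = lazy_pinyin(text)  # romanize so the tokenizer can handle Chinese
    with torch.inference_mode():
        emission, _ = model(waveform.to(device))
        token_spans = aligner(emission[0], tokenizer(transcript))
    ratio = waveform.size(1) / emission.size(1)
    # One (character, start_seconds, end_seconds) triple per aligned token.
    return [
        (char, int(ratio * spans[0].start) / 16000, int(ratio * spans[-1].end) / 16000)
        for char, spans in zip(text, token_spans)
    ]
```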

228
forced-alignment/split.py Normal file
View File

@ -0,0 +1,228 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample
from tqdm import tqdm
def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
with torch.inference_mode():
emission, _ = model(waveform.to(device))
token_spans = aligner(emission[0], tokenizer(transcript))
return emission, token_spans
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle
model = bundle.get_model()
model.to(device)
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()
cc_cedict.load()
add_spaces = lambda s: ' '.join(s)
from pydub import AudioSegment
def get_audio_duration(file_path):
"""
Read an audio file and return its duration in seconds.
:param file_path: path to the audio file
:return: duration of the audio file in seconds
"""
try:
audio = AudioSegment.from_file(file_path)
duration_in_seconds = len(audio) / 1000.0
return duration_in_seconds
except Exception as e:
print(f"Error reading audio file: {e}")
return None
def timestamp(seconds):
"""
Convert seconds (float) to the TTML timestamp format (HH:MM:SS.sss).
:param seconds: time in seconds as a float
:return: TTML timestamp string
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
seconds = seconds % 60
milliseconds = int((seconds % 1) * 1000)
seconds = int(seconds)
return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"
def timestamp_inverse(ttml_timestamp):
"""
Convert a TTML timestamp string (HH:MM:SS.sss) to seconds as a float.
:param ttml_timestamp: TTML timestamp string
:return: time in seconds as a float
"""
parts = ttml_timestamp.split(':')
hours = int(parts[0])
minutes = int(parts[1])
seconds_and_milliseconds = parts[2].split('.')
seconds = int(seconds_and_milliseconds[0])
milliseconds = int(seconds_and_milliseconds[1])
total_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
return total_seconds
import xml.etree.ElementTree as ET
import os
import re
def extract_numbers_from_files(directory):
"""
Read the given directory, extract the numeric part of each matching file name, and return those numbers as a list.
:param directory: directory path
:return: list of numbers
"""
numbers = []
pattern = re.compile(r'line-(\d+)\.wav')
try:
for filename in os.listdir(directory):
match = pattern.match(filename)
if match:
number = int(match.group(1))
numbers.append(number)
except Exception as e:
print(f"Error reading directory: {e}")
return None
return numbers
class TTMLGenerator:
def __init__(self, duration, xmlns="http://www.w3.org/ns/ttml", xmlns_ttm="http://www.w3.org/ns/ttml#metadata", xmlns_amll="http://www.example.com/ns/amll", xmlns_itunes="http://music.apple.com/lyric-ttml-internal"):
self.tt = ET.Element("tt", attrib={
"xmlns": xmlns,
"xmlns:ttm": xmlns_ttm,
"xmlns:amll": xmlns_amll,
"xmlns:itunes": xmlns_itunes
})
self.head = ET.SubElement(self.tt, "head")
self.metadata = ET.SubElement(self.head, "metadata")
self.body = ET.SubElement(self.tt, "body", attrib={"dur": duration})
self.div = ET.SubElement(self.body, "div")
def add_lyrics(self, begin, end, agent, itunes_key, words):
p = ET.SubElement(self.div, "p", attrib={
"begin": begin,
"end": end,
"ttm:agent": agent,
"itunes:key": itunes_key
})
for word, start, stop in words:
span = ET.SubElement(p, "span", attrib={"begin": start, "end": stop})
span.text = word
def save(self, filename):
tree = ET.ElementTree(self.tt)
tree.write(filename, encoding="utf-8", xml_declaration=True)
duration = get_audio_duration("./data/谷雨.mp3")
# Example usage
ttml_generator = TTMLGenerator(duration=timestamp(duration))
def process_line(line_idx, start_time, total_lines):
with open(f"./segments/line-{line_idx}.txt", "r") as f:
text = f.read()
waveform, sample_rate = torchaudio.load(f"./segments/line-{line_idx}.wav")
waveform = waveform[0:1]
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)
text_pinyin = lazy_pinyin(text)
text_normalized = " ".join(text_pinyin)
transcript = text_normalized.split()
emission, token_spans = compute_alignments(waveform, transcript)
num_frames = emission.size(1)
ratio = waveform.size(1) / num_frames
words = []
for i in range(len(token_spans)):
spans = token_spans[i]
x0 = start_time + int(ratio * spans[0].start) / 16000
x1 = start_time + int(ratio * spans[-1].end) / 16000
words.append({
"word": text[i],
"start": x0,
"end": x1
})
idx=0
for item in words:
if idx == len(words) - 1:
break
item["end"] = words[idx + 1]["start"]
idx+=1
result = []
for word in words:
result.append((word["word"], timestamp(word["start"]), timestamp(word["end"])))
return result
lines_to_process = sorted(extract_numbers_from_files("segments"))
def parse_lrc(lrc_file, audio_len):
"""解析LRC文件返回一个包含时间戳和歌词的列表"""
with open(lrc_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
lrc_data = []
for line in lines:
# Match the timestamp and lyric with a regular expression
match = re.match(r'\[(\d+):(\d+\.\d+)\](.*)', line)
if match:
minutes = int(match.group(1))
seconds = float(match.group(2))
lyric = match.group(3).strip()
lyric = lyric.replace(" ", "")
timestamp = minutes * 60 + seconds
lrc_data.append((lyric, timestamp))
for i, (lyric, start_time) in enumerate(lrc_data):
# Skip empty line
if lyric.strip() == "":
continue
if i < len(lrc_data) - 1:
end_time = lrc_data[i + 1][1]
else:
end_time = audio_len
lrc_data[i] = (lyric, start_time, end_time)
# Filter empty lines again
lrc_data = [line for line in lrc_data if line[0].strip() != ""]
return lrc_data
lrc_data = parse_lrc("./data/谷雨.lrc", duration)
i=0
for line_num in tqdm(lines_to_process):
start_time = lrc_data[i][1]
end_time = lrc_data[i][2]
result = process_line(line_num, start_time, len(lines_to_process))
ttml_generator.add_lyrics(
begin=timestamp(start_time), end=timestamp(end_time), agent="v1", itunes_key=f"L{i+1}",
words=result
)
i+=1
# Save the file
ttml_generator.save("output.ttml")

View File

@ -0,0 +1,57 @@
from pydub import AudioSegment
import re
def parse_lrc(lrc_file):
"""解析LRC文件返回一个包含时间戳和歌词的列表"""
with open(lrc_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
lrc_data = []
for line in lines:
# Match the timestamp and lyric with a regular expression
match = re.match(r'\[(\d+):(\d+\.\d+)\](.*)', line)
if match:
minutes = int(match.group(1))
seconds = float(match.group(2))
lyric = match.group(3).strip()
lyric = lyric.replace(" ", "")
timestamp = minutes * 60 + seconds
lrc_data.append((timestamp, lyric))
return lrc_data
def split_audio_by_lrc(audio_file, lrc_data, output_prefix):
"""根据LRC数据分割音频文件并保存为单独的WAV文件"""
audio = AudioSegment.from_file(audio_file)
for i, (start_time, lyric) in enumerate(lrc_data):
# Skip empty line
if lyric.strip() == "":
continue
if i < len(lrc_data) - 1:
end_time = lrc_data[i + 1][0]
else:
end_time = len(audio) / 1000 # the last lyric line runs to the end of the audio
start_time = max(0, start_time - 0.1) # pad 0.1 s before and after the segment
end_time = min(len(audio) / 1000, end_time + 0.1)
start_time_ms = start_time * 1000
end_time_ms = end_time * 1000
segment = audio[start_time_ms:end_time_ms]
output_file = f"{output_prefix}-{i+1}.wav"
output_script = f"{output_prefix}-{i+1}.txt"
output_time = f"{output_prefix}-{i+1}.time"
segment.export(output_file, format="wav")
with open(output_script, "w") as f:
f.write(lyric)
with open(output_time, "w") as f:
f.write(str(start_time)+","+str(end_time))
print(f"Saved {output_file}")
if __name__ == "__main__":
lrc_file = "./data/谷雨.lrc" # LRC文件路径
audio_file = "./data/谷雨.mp3" # 音频文件路径
output_prefix = "segments/line" # 输出文件名的前缀
lrc_data = parse_lrc(lrc_file)
split_audio_by_lrc(audio_file, lrc_data, output_prefix)

View File

@ -0,0 +1,88 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample
import xml.etree.ElementTree as ET
import argparse
from pydub import AudioSegment
def get_audio_duration(file_path):
"""
Read an audio file and return its duration in seconds.
:param file_path: path to the audio file
:return: duration of the audio file in seconds
"""
try:
audio = AudioSegment.from_file(file_path)
duration_in_seconds = len(audio) / 1000.0
return duration_in_seconds
except Exception as e:
print(f"Error reading audio file: {e}")
return None
def timestamp(seconds):
"""
Convert seconds (float) to the SRT timestamp format (HH:MM:SS,sss).
:param seconds: time in seconds as a float
:return: SRT timestamp string
"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
seconds = seconds % 60
milliseconds = int((seconds % 1) * 1000)
seconds = int(seconds)
return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"
def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
with torch.inference_mode():
emission, _ = model(waveform.to(device))
token_spans = aligner(emission[0], tokenizer(transcript))
return emission, token_spans
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle
model = bundle.get_model()
model.to(device)
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()
cc_cedict.load()
with open("./data/扬旗鸣鼓.txt", "r") as f:
text = f.read()
text_pinyin = lazy_pinyin(text)
text_normalized = " ".join(text_pinyin)
waveform, sample_rate = torchaudio.load("./data/扬旗鸣鼓.wav")
waveform = waveform[0:1]
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)
transcript = text_normalized.split()
emission, token_spans = compute_alignments(waveform, transcript)
num_frames = emission.size(1)
ratio = waveform.size(1) / num_frames
duration = get_audio_duration("./data/扬旗鸣鼓.wav")
for i in range(len(token_spans)):
spans = token_spans[i]
x0 = round(int(ratio * spans[0].start) / 16000, 3)
x1 = round(int(ratio * spans[-1].end) / 16000, 3)
with open("1.srt", "a") as f:
f.write(f"{i+1}\n")
f.write(f"{timestamp(x0)} --> {timestamp(x1)}\n")
f.write(f"{transcript[i]}\n\n")
f.flush()

View File

@ -0,0 +1,84 @@
import re
def srt_to_lrc(srt_text):
# Match timestamps and content with a regular expression
# append trailing blank lines so the final SRT block also matches the pattern
srt_text+='\n\n'
pattern = re.compile(r'(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})\n(.+?)\n\n', re.DOTALL)
matches = pattern.findall(srt_text)
lrc_lines = []
for start_time, end_time, content in matches:
# Extract the highlighted character for this start time
highlight_char = re.search(r'<font color="#00ff00">(.+?)</font>', content)
if highlight_char:
highlight_char = highlight_char.group(1)
else:
continue
# Convert the timestamps to LRC format
f,start_minutes, start_seconds, start_milliseconds = map(int, start_time.replace(',', ':').split(':'))
f,end_minutes, end_seconds, end_milliseconds = map(int, end_time.replace(',', ':').split(':'))
start_time_lrc = f"{start_minutes:02d}:{start_seconds:02d}.{start_milliseconds:02d}"
end_time_lrc = f"{end_minutes:02d}:{end_seconds:02d}.{end_milliseconds:02d}"
# Build the LRC line
lrc_line = f"{highlight_char}|{start_time_lrc},{end_time_lrc}"
lrc_lines.append(lrc_line)
# Replace any newlines in the content with spaces
lrc_line = lrc_line.replace('\n', ' ')
return '\n'.join(lrc_lines)
with open('./data/谷雨.srt', 'r') as f:
srt_text = f.read()
whole = srt_text.splitlines()[2].replace('<font color="#00ff00">','').replace('</font>','')
whole = whole.replace(' ','\n')
lines = whole.splitlines()
lyric_text = ""
raw_text = srt_to_lrc(srt_text)
raw_lines = raw_text.splitlines()
for line in raw_lines:
lyric_text += line.split('|')[0]
raw_idx=0
lines_start_chr_idx=[]
for line in lines:
start = line[0]
end = line[-1]
while raw_idx < len(raw_lines) and not line.startswith(raw_lines[raw_idx].split("|")[0]):
raw_idx += 1
lines_start_chr_idx.append(raw_idx)
lines_start_chr_idx.append(len(raw_lines)-1)
raw_idx=0
lines_end_chr_idx=[]
for line in lines:
start = line[0]
end = line[-1]
while raw_idx < len(raw_lines) and not line.endswith(raw_lines[raw_idx].split("|")[0]):
raw_idx += 1
lines_end_chr_idx.append(raw_idx)
lines_end_chr_idx.append(len(raw_lines)-1)
lrc_text = ""
for i in range(len(lines_start_chr_idx)-1):
start = lines_start_chr_idx[i]
end = lines_end_chr_idx[i]
time_start = raw_lines[start].split("|")[1].split(',')[0]
time_end = raw_lines[end].split("|")[1].split(',')[0]
lrc_text += f"[{time_start}]{lyric_text[start:end+1]}\n[{time_end}]\n"
print(lrc_text)
lyric_len = len(lyric_text)
for i in range(len(lines_start_chr_idx)-1):
start = max(0,lines_start_chr_idx[i]-1)
end = min(lyric_len-1, lines_end_chr_idx[i]+1)
time_start = raw_lines[start].split("|")[1].split(',')[0]
time_end = raw_lines[end].split("|")[1].split(',')[0]
lrc_text += f"[{time_start}]{lyric_text[start:end+1]}\n[{time_end}]\n"
print(lrc_text)

328
forced-alignment/test.ipynb Normal file

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,60 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample
def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
with torch.inference_mode():
emission, _ = model(waveform.to(device))
token_spans = aligner(emission[0], tokenizer(transcript))
return emission, token_spans
# Compute average score weighted by the span length
def _score(spans):
return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle
model = bundle.get_model()
model.to(device)
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()
cc_cedict.load()
add_spaces = lambda s: ' '.join(s)
with open("./segments/line-1.txt", "r") as f:
text = f.read()
text_raw = add_spaces(text)
text_list = list(text)
text_pinyin = lazy_pinyin(text)
text_normalized = " ".join(text_pinyin)
waveform, sample_rate = torchaudio.load("./segments/line-1.wav")
waveform = waveform[0:1]
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)
transcript = text_normalized.split()
emission, token_spans = compute_alignments(waveform, transcript)
num_frames = emission.size(1)
print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)
ratio = waveform.size(1) / num_frames
for i in range(len(token_spans)):
spans = token_spans[i]
x0 = round(int(ratio * spans[0].start) / 16000, 3)
x1 = round(int(ratio * spans[-1].end) / 16000, 3)
print(f"{text[i]}: {x0}-{x1}")

30
forced-alignment/ttml.py Normal file
View File

@ -0,0 +1,30 @@
import xml.etree.ElementTree as ET
class TTMLGenerator:
def __init__(self, duration, xmlns="http://www.w3.org/ns/ttml", xmlns_ttm="http://www.w3.org/ns/ttml#metadata", xmlns_amll="http://www.example.com/ns/amll", xmlns_itunes="http://music.apple.com/lyric-ttml-internal"):
self.tt = ET.Element("tt", attrib={
"xmlns": xmlns,
"xmlns:ttm": xmlns_ttm,
"xmlns:amll": xmlns_amll,
"xmlns:itunes": xmlns_itunes
})
self.head = ET.SubElement(self.tt, "head")
self.metadata = ET.SubElement(self.head, "metadata")
self.body = ET.SubElement(self.tt, "body", attrib={"dur": duration})
self.div = ET.SubElement(self.body, "div")
def add_lyrics(self, begin, end, agent, itunes_key, words):
p = ET.SubElement(self.div, "p", attrib={
"begin": begin,
"end": end,
"ttm:agent": agent,
"itunes:key": itunes_key
})
for word, start, stop in words:
span = ET.SubElement(p, "span", attrib={"begin": start, "end": stop})
span.text = word
def save(self, filename):
tree = ET.ElementTree(self.tt)
tree.write(filename, encoding="utf-8", xml_declaration=True)

View File

@ -0,0 +1 @@
{"idx_to_class": {"0": "weather", "1": "base64", "2": "url-encode", "3": "html-encode", "4": "ai.command", "5": "knowledge", "6": "ai.question", "7": "datetime"}, "threshold": 1.7}

View File

@ -36,6 +36,7 @@
"室外的温度是多少",
"达拉斯今天热不热",
"苏州现在天气怎么样",
"明天悉尼会下雨吗?",
"how's the weather",
"What's going on with the weather?",
"Can you give me an update on the weather?",
@ -48,21 +49,21 @@
"What's the weather like right now?",
"Tell me the current weather conditions.",
"How about the weather today?",
"What's the weather looking like for the next few hours?",
"Is it going to stay this way all day?",
"Could you give me a brief overview of the weather?",
"What's the general weather situation in our area?",
"Is it cloudy or clear outside?",
"What's the forecast saying for today's weather?",
"Is it going to be a warm day?",
"Are we expecting any storms today?",
"What's the weather condition outside my window?",
"Is it a typical day for this season in terms of weather?",
"how's the weather now?",
"What's the temperature like right now?",
"Can you tell me the current temperature?",
"How hot is it outside?",
"What's the temperature supposed to be today?",
"What's the weather looking like for the next few hours",
"Is it going to stay this way all day",
"Could you give me a brief overview of the weather",
"What's the general weather situation in our area",
"Is it cloudy or clear outside",
"What's the forecast saying for today's weather",
"Is it going to be a warm day",
"Are we expecting any storms today",
"What's the weather condition outside my window",
"Is it a typical day for this season in terms of weather",
"how's the weather now",
"What's the temperature like right now",
"Can you tell me the current temperature",
"How hot is it outside",
"What's the temperature supposed to be today",
"What is the current temp outside?",
"Could you tell me the outdoor temperature?",
"Is it cold or warm outside?",
@ -81,8 +82,8 @@
"Can you tell me the temp in the nearby area?",
"Is it below freezing outside?",
"What's the average temperature for today?",
"Is the temperature dropping or rising?",
"What should I wear considering the temperature?"
"Is the temperature dropping or rising",
"What should I wear considering the temperature"
],
"base64": [
"请将数据使用base64编码",
@ -110,17 +111,16 @@
"解码 base64",
"Please encode this data with base64:",
"I need to encode the following data in base64",
"Could you encode this string using base64?",
"Could you encode this string using base64",
"Convert this data to b64 encoding",
"I want to encode this information with base64",
"Help me encode this in base32",
"Can you encode this data to base64 format?",
"Can you encode this data to base64 format",
"b64 encode",
"base64 encode",
"encode base64",
"base 64 encode online"
],
"url-encode": [
"编码 url",
"URL部分需要编码",
@ -145,7 +145,6 @@
"url decoder",
"URL encoder"
],
"html-encode": [
"请编码HTML实体",
"文本转为HTML实体",
@ -186,7 +185,6 @@
"html &nbsp conversion",
"html nbsp meaning"
],
"ai.command": [
"写一个TypeScript的HelloWorld代码",
"检查以下内容的语法和清晰度",
@ -237,11 +235,11 @@
"help me learn chinese",
"how to let the screen reader automatically focused to an newly poped up element in the web development",
"summarize following text:",
"Is there anything wrong with this code or can it be simplified?",
"Is there anything wrong with this code or can it be simplified",
"generate a Python script that prints 'Hello, World!'",
"Can you proofread this essay for grammar and punctuation errors?",
"Can you proofread this essay for grammar and punctuation errors",
"Create a list of ten example sentences for the word 'serendipity.'",
"Can you reformat this JSON to be more readable?",
"Can you reformat this JSON to be more readable",
"Suggest a creative title for my blog post about healthy eating.",
"Refactor this JavaScript function to make it more efficient.",
"Help me practice French: provide a sentence with a missing word that I can guess.",
@ -249,15 +247,15 @@
"Summarize this news article for me.",
"Can you review this code snippet for potential security vulnerabilities?",
"Generate a SQL query to find all users who signed up in the last 30 days.",
"Can you translate this paragraph into Spanish?",
"Can you translate this paragraph into Spanish",
"Create a flowchart based on the following process description.",
"Write a Python function to calculate the factorial of a number.",
"Provide a detailed explanation of how to implement OAuth2 in a web application.",
"Can you optimize this image for faster loading on a website?",
"Can you optimize this image for faster loading on a website",
"Suggest some catchy taglines for a new mobile app focused on fitness.",
"Write a Bash script to back up my documents folder daily.",
"Help me draft an email to request a meeting with a potential client.",
"Can you convert this Markdown document into HTML?",
"Can you convert this Markdown document into HTML",
"Generate a Python script that scrapes data from a specified website.",
"Can you find the synonyms of the word 'meticulous'?",
"Write a SQL query to join two tables based on a common column.",
@ -267,31 +265,57 @@
"Can you assist me in learning Japanese?",
"How can I make an alert box appear when a user clicks a button on a webpage?",
"Summarize this research paper into bullet points.",
"Can you check if there are any logical errors in this algorithm?"
"Can you check if there are any logical errors in this algorithm?",
"请一步一步计算找到函数f(x)=U^2*x/(R+x)^2的顶点坐标。",
"如何理解transformer自注意力机制中的Q,K,V它们分别代表什么",
"帮我写一封求职信。先询问我的教育背景、技能和经验。",
"总结这篇论文",
"写一份10人晚宴的菜单",
"写一篇博客",
"写一段演讲稿"
],
"ai.question": [
"你认为哪个框架最适合性能敏感的项目?",
"knowledge": [
"什么是后量子密码学?",
"什么是密钥派生函数",
"什么是线性代数?",
"量子计算的特点是什么",
"哈希函数的作用?",
"什么是微积分?",
"什么是区块链技术",
"What is post-quantum cryptography",
"What is a key derivation function?",
"What is Linear Algebra?",
"What is the main use of linear algebra in computer science",
"What is quantum computing",
"What is a hash function",
"What is calculus",
"什么是站点隔离?",
"What is blockchain technology?",
"BLEU 是什么",
"黎巴嫩在哪",
"什么是转义字符",
"MixAlpha售价多少",
"什么是神经机器翻译",
"什么是月食",
"什么是人工智能",
"什么是F1-score"
],
"ai.question": [
"人工智能真的有智力吗",
"你认为哪个框架最适合性能敏感的项目?",
"线性代数在计算机科学中的主要用途是什么?",
"我应该使用哪个IDE来编写Go语言",
"Go vs Java vs Kotlin哪个适合后端",
"哪种编程语言最适合数据分析",
"什么是量子计算",
"什么是哈希函数?",
"什么是微积分?",
"机器学习在金融中的主要应用有哪些?",
"写Python代码最好的文本编辑器是哪个",
"Python vs R vs Julia哪个更适合数据科学",
"监督学习和无监督学习的关键区别是什么?",
"数据库在Web应用程序中的作用是什么",
"什么是区块链技术",
"使用Docker进行应用程序部署的优势是什么",
"哪个云服务提供商提供最好的AI工具",
"加密是如何工作的?",
"负载均衡器在网络架构中的目的是什么?",
"加密是如何工作的",
"负载均衡器在网络架构中的目的是什么",
"机器学习和深度学习有什么区别",
"软件工程中最常见的设计模式有哪些",
"神经网络是如何学习的",
@ -300,31 +324,22 @@
"Rust编程语言的关键特性是什么",
"HTTP和HTTPS有什么区别",
"使用像Git这样的版本控制系统有什么优势",
"什么是'边缘计算'的概念",
"哪种编程语言最适合构建移动应用?",
"关系数据库和NoSQL数据库有什么不同",
"算法在计算机科学中的重要性是什么",
"算法在计算机科学中的重要性是什么",
"API在软件开发中的作用是什么",
"保护Web应用程序的最佳实践是什么",
"虚拟现实和增强现实有什么区别?",
"机器翻译是如何工作的?",
"Which framework do you think is the most suitable for performance sensitive projects?",
"What is post-quantum cryptography",
"What is a key derivation function?",
"What is Linear Algebra?",
"What is the main use of linear algebra in computer science",
"which IDE should I use for Go",
"Go vs Java vs Koltin, which for a backend",
"Which programming language is best suited for data analysis?",
"What is quantum computing?",
"What is a hash function?",
"What is calculus?",
"What are the main applications of machine learning in finance?",
"Which text editor is best for writing Python code?",
"Python vs R vs Julia, which is better for data science?",
"What are the key differences between supervised and unsupervised learning?",
"What are the main applications of machine learning in finance",
"Which text editor is best for writing Python code",
"Python vs R vs Julia, which is better for data science",
"What are the key differences between supervised and unsupervised learning",
"What is the role of a database in a web application?",
"What is blockchain technology?",
"What are the advantages of using Docker for application deployment?",
"Which cloud service provider offers the best AI tools?",
"How does encryption work?",
@ -332,19 +347,20 @@
"What is the difference between machine learning and deep learning?",
"What are the most common design patterns in software engineering?",
"How does a neural network learn?",
"What is the main benefit of using a microservices architecture?",
"What is the difference between a compiler and an interpreter?",
"What are the key features of the Rust programming language?",
"What is the difference between HTTP and HTTPS?",
"What are the advantages of using a version control system like Git?",
"What is the concept of 'edge computing'?",
"Which programming language is best for building mobile apps?",
"How does a relational database differ from a NoSQL database?",
"What is the importance of algorithms in computer science?",
"What is the role of an API in software development?",
"What is the main benefit of using a microservices architecture",
"What is the difference between a compiler and an interpreter",
"What are the key features of the Rust programming language",
"What is the difference between HTTP and HTTPS",
"What are the advantages of using a version control system like Git",
"What is the concept of 'edge computing'",
"Which programming language is best for building mobile apps",
"How does a relational database differ from a NoSQL database",
"What is the importance of algorithms in computer science",
"What is the role of an API in software development",
"What are the best practices for securing a web application?",
"What is the difference between virtual reality and augmented reality?",
"How does machine translation work?"
"How does machine translation work?",
"MBTI有科学依据吗"
],
"datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
}

View File

@ -28,7 +28,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_name=\"microsoft/Phi-3-mini-4k-instruct\""
"model_name=\"Qwen/Qwen2.5-3B\""
]
},
{
@ -37,17 +37,10 @@
"id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "1aeb02c7c8084b1eb1b8e3178882fd60",
"model_id": "38137fc55ad24a9785ecbe1978bbc605",
"version_major": 2,
"version_minor": 0
},
@ -76,6 +69,122 @@
"vocab = tokenizer.get_vocab()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "21214ff4-018d-4230-81b9-331ebb42773b",
"metadata": {},
"outputs": [],
"source": [
"def bytes_to_unicode():\n",
" \"\"\"\n",
" Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control\n",
" characters the bpe code barfs on.\n",
"\n",
" The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab\n",
" if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for\n",
" decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup\n",
" tables between utf-8 bytes and unicode strings.\n",
" \"\"\"\n",
" bs = (\n",
" list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"¡\"), ord(\"¬\") + 1)) + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n",
" )\n",
" cs = bs[:]\n",
" n = 0\n",
" for b in range(2**8):\n",
" if b not in bs:\n",
" bs.append(b)\n",
" cs.append(2**8 + n)\n",
" n += 1\n",
" cs = [chr(n) for n in cs]\n",
" return dict(zip(bs, cs))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "cbc23d2d-985b-443a-83ee-c2286046ad5e",
"metadata": {},
"outputs": [],
"source": [
"btu=bytes_to_unicode()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "4a99fa07-4922-4d8d-9c28-2275bf9cb8df",
"metadata": {},
"outputs": [],
"source": [
"utb = reversed_dict = {value: key for key, value in btu.items()}"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "cb218ea7-50c7-4bb8-aa7f-0ee85da76147",
"metadata": {},
"outputs": [],
"source": [
"result = tokenizer.convert_ids_to_tokens([104307])[0]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "2dcb332a-cba9-4a14-9486-4e1ff6bd3dba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"å\n",
"229\n",
"¤\n",
"164\n",
"©\n",
"169\n",
"æ\n",
"230\n",
"°\n",
"176\n",
"Ķ\n",
"148\n"
]
}
],
"source": [
"decoded=b\"\"\n",
"for chr in result:\n",
" print(chr)\n",
" if chr in utb:\n",
" print(utb[chr])\n",
" decoded+=bytes([utb[chr]])"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "b1bf1289-2cab-4a97-ad21-b2d24de6d688",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'天气'"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"decoded.decode(\"utf-8\", errors='replace')"
]
},
{
"cell_type": "code",
"execution_count": 5,
@ -95,7 +204,7 @@
"metadata": {},
"outputs": [],
"source": [
"DIMENSIONS = 128"
"DIMENSIONS = 96"
]
},
{
@ -168,11 +277,17 @@
"import struct\n",
"with open(\"token_embeddings.bin\", \"wb\") as f:\n",
" for token_id in range(len(vocab)):\n",
" # Write token id (2 bytes)\n",
" f.write(struct.pack('H', token_id))\n",
" # Write embedding vector (128 float numbers)\n",
" f.write(struct.pack('128f', *reduced_embeddings[token_id]))"
" # 将向量转换为半精度浮点数并保存\n",
" f.write(struct.pack('96e', *reduced_embeddings[token_id].astype(np.float16)))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "511a7cc4-1b8c-468c-b2a0-16dc6d74ab44",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@ -191,7 +306,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
"version": "3.10.14"
}
},
"nbformat": 4,

View File

@ -4,5 +4,350 @@
"我爱你",
"嘿嘿嘿诶嘿",
"为什么",
"拼多多"
]
"拼多多",
"machine translation",
"trustrank",
"中文词典",
"bin screen linux",
"\"TinyBERT",
"iconify",
"反义词 英文",
"referer",
"watchos uiscreen",
"张鑫旭",
"google first result",
"flutter text align center",
"ASR model",
"real time whisper",
"千樱凛",
"马嘉祺",
"flutter widget catalog",
"flutter BottomNavigationBar left",
"flutter tab indent vscode",
"react native 用 expo 吗",
"latest monorepo tool",
"\"vite\" \"abortController\" is not defined",
"vim comment lines",
"Error: unable to get issuer certificate",
"uuidv4",
"npm semver",
"react polyfill vite",
"vibrance",
"I can eat glass, it doesn't hurt me \"japanese\"",
"I can swallow glass without any harm to myself",
"copilot pricing",
"vim close window",
"sensors macos command",
"智乃",
"pypi wikipedia",
"tesseract macos m1",
"rag prompt template",
"英国 破产",
"bewlybewly",
"safari-web-extension-converter",
"starcoder",
"open source web search for ai",
"gpt4o mini tokenizer",
"gpt4o tokenizer",
"reverse dns lookup linux",
"online ping",
"termux",
"802.11 table",
"optimize",
"集群",
"chrome us",
"transflective",
"ielts toefl",
"react router",
"摇曳露营 萌娘百科",
"isrc",
"apple-system",
"-apple-system",
"css clip path animation",
"can i use relative path in og image",
"GitSora",
"matrix im",
"test your vocabulary",
"boarding pass",
"函数签名",
"类型谓词",
"barcode",
"智能",
"threejs 入门",
"南亚语系",
"linux user's computer be like",
"apple a16 显微图",
"dallas",
"恶魔 英文",
"Rime meaning",
"adobe media encoder macos download",
"mp4 transparency",
"webkit",
"chromium",
"献血",
"软件强制更新",
"If you dont agree with its politics views, Notepad+ + will add random characters in your source code.",
"Unmerged paths",
"字数统计",
"Use build.rollupOptions.output.manualChunks to improve chunking: https://rollupjs.org/configuration-options/#output-manualchunks",
"世界人权宣言",
"latex percent",
"chord in keyboard",
"Google is trying to kill the Open Web.",
"silo'd",
"swiftui 数组倒数访问",
"swiftui link to another view",
"fizzbuzz",
"AppDelegate watchos",
"Cannot find type 'UIApplicationDelegate' in scope",
"swiftui web image",
"spammer",
"swiftui text",
"钢琴",
"disable webgl chrome",
"online uuid",
"cp show progress",
"易容术",
"fulilian",
"cargo",
"wordle",
"mismatch",
"btc",
"squelch",
"psql show table structure",
"let padding don't effect when empty",
"take over the world meaning",
"brain teasers",
"Google flight API",
"square symbol",
"sill",
"nextjs layout per page",
"UA 550 umol/L",
"react production promotion page",
"jupyter notebook",
"wth meaning",
"glove词向量",
"google suggestion relevance",
"YouTube advertising income",
"PKI",
"next client only component",
"nextjs use client",
"nextjs docker tailwind not working",
"k8s",
"Logistic Regression",
"氯化钾注射死刑",
"icloud photo loss",
"芙宁娜 水上行走",
"vector design tool",
"netizen",
"framework or next js documentation",
"csync",
"next js",
"后量子正向保密",
"nip05",
"Sora技术原理",
"wasm效率",
"switch code",
"online IPA pronunciation",
"pnpm global adir",
"如何搜索",
"1999 抽卡期望",
"swiftui background blur",
"chrome macos fullscreen hide",
"中英文空格自动",
"ios 旁白 屏幕识别",
"ios 旁白 转子",
"http 404",
"yaml缩进",
"counter generator github",
"git 服务器提供远端仓库",
"ipfs companion",
"supervisor config",
"SSO",
"slot embedding",
"sql show tables",
"The request signature we calculated does not match the signature you provided. Check your Secret Access Key and signing method.",
"icloud.com,cn",
"VuePress",
"parser",
"stackoverflow statistics",
"sd xl",
"Rollup failed to resolve import \"workbox-precaching\" from",
"dep",
"Cannot find module estree-walker.js docker",
"nuxt run",
"base58解码",
"cga",
"vscode",
"vscode",
"silicon",
"macos m1 linux",
"预处理 后处理",
"is vp9 opensource",
"Alice Blu",
"失控玩家",
"kv数据库",
"redis 持久化",
"firefox disable outline",
"cd -2",
"IM application",
"2021国产电影",
"youtube chat overlay obs",
"obs add clock",
"Z is not defined nuxt",
"safari ios debug",
"safari debug",
"chat",
"nuxt plugin inject",
"twitch",
"obs 绿幕",
"gnupg",
"kde plasma wallpaper engine",
"Plasma",
"dns over https",
"localforage缺点",
"watchOS 10",
"noun of repeat",
"微信输入法",
"行业报告",
"keepass",
"platform",
"steam",
"java proxy",
"0 design",
"cefr word level list",
"precipitation meaning",
"international school of lausanne",
"Vim Uganda",
"抖音 推荐算法",
"Meta NNLO",
"windbg dump分析",
"web image fft",
"GPT-4 Pricing",
"GPT-4",
"Scala",
"tauri教程",
"asyncio.create_task用法",
"H5 滚动到底部",
"microsoft copilot",
"枫丹文字",
"brew pip",
"TS7016: Could not find a declaration file for module react .",
"fastapi websocket",
"kazv",
"The Type 孔雀计划",
"第一个图形操作系统",
"娱乐 诞生",
"ffmpeg 音频封面",
"Jean-Loup Gailly",
"Linux用户软件位置",
"\"ubuntu\" 平滑滚动",
"python range函数",
"KMP",
"sd 8gen2 GPU GFLOPS",
"mac语音输入法",
"openai translate",
"蔚蓝档案 初始抽卡",
"free custom domain email",
"洛天依",
"b站 频道页Tab 跳转",
"URL 重定向预览",
"计算机",
"sololearn",
"PoS机制 通俗解释",
"google search cost",
"bos s3",
"react 打包",
"useeffect 用法",
"ts 字典类型",
"vscode 字典单词自动补全插件",
"componentwillupdate",
"iPad Mini 2",
"use-immer",
"reducer 和 context",
"mint",
"Elementary OS",
"google科技新闻",
"iCloud mail \"\"-9002\"\"",
"氢氧化铁胶体制备",
"react native 视频处理",
"四川 2023 高考 复旦大学 分数线",
"哑铃弯举",
"m2 ultra",
"电池循环计数 site:apple.com",
"相机发明时间",
"冯诺依曼结构",
"哈佛架构",
"nodejs 后端",
"34.5M€ to CN¥",
"NLP 实体关注",
"monkey",
"react 快捷键监听",
"mac 好看的电子书阅读器",
"新闻",
"在线字体编辑器",
"ars technica",
"genshin 4.1 release time",
"swift device activity report",
"swiftui tabview background",
"swiftui text space",
"apple inc. wikipedia",
"how long does it take Google to return the results",
"云原神 web",
"支持homekit的空调",
"内核隔离",
"海祇岛解密",
"swiftui Textfield",
"xcode",
"qq 链接",
"M1 推出时间",
"USB-IF",
"nvchat",
"P1% FPS",
"react i18next 当前语言",
"js 获取语言",
"MulType",
"b站平均使用时间",
"pip 阿里源",
"ip info",
"graphjet",
"金融思维",
"C#写入文件",
"Last Day Sinista M",
"在 系统 位置 xcode select 找 不 到 SDK",
"Error: Could not find a valid Xcode app bundle at '/Library/Developer/CommandLineTools'. Please update your Apple SDK location in Visual Studio's preferences (Projects > SDK Locations > Apple > Apple SDK). (UniBattery)",
".NET能做什么",
"could i give no tip ",
"miami university of ohio",
"方正颜宋",
"中文 标题字体",
"聚典平台",
"62 basic words for a language",
"procrastination meaning",
"Lingbe",
"娱乐至死",
"macOS 外接显示器渲染",
"白玉袖",
"SwiftUI入门",
"html插入其它网页",
"捆绑 小说",
"apple music 无损下载",
"一miumiu 赐予",
"macos markdown",
"safari 开发者工具",
"\"百合\" \"武侠\" \"国漫\"",
"epub 格式详解",
"chrome 隐藏滚动条",
"发宽空格",
"U+200A",
"无性人",
"Spotify",
"禾念",
"how to pronounce Lorem ipsum",
"言和为什么不是男孩子",
"浏览器主页",
"react",
"Tailwindcss react 扩展怎么用",
"Prettier 扩展怎么用",
"linter\""
]

View File

@ -1,575 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a6a3195f-d099-4bf4-846f-51f403954818",
"metadata": {},
"source": [
"# sparkastML: Training the Intention Classification Model\n",
"\n",
"This is the model we use for intent recognition, using a **CNN architectur** and using an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
"\n",
"In this case, **positive samples** refer to data that can be classified into existing class, while **negative samples** are those does not belong to any of the existing class."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from sklearn.model_selection import train_test_split\n",
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"import numpy as np\n",
"from scipy.spatial.distance import euclidean\n",
"from scipy.stats import weibull_min\n",
"from sklearn.preprocessing import normalize\n",
"import torch.nn.functional as F\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n",
"DIMENSIONS = 128\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "markdown",
"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
"metadata": {},
"source": [
"## Load Data\n",
"\n",
"We load the data from `data.json`, and also get the negative sample from the `noise.json`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
" embeddings = torch.tensor(embeddings)\n"
]
}
],
"source": [
"# Load data\n",
"with open('data.json', 'r') as f:\n",
" data = json.load(f)\n",
"\n",
"# Create map: class to index\n",
"class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n",
"idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n",
"\n",
"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
"def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n",
" dataset = []\n",
" for label, sentences in data.items():\n",
" for sentence in sentences:\n",
" # Tokenize the sentence and convert tokens to embedding vectors\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
" embeddings = torch.tensor(embeddings)\n",
" dataset.append((class_to_idx[label], embeddings))\n",
" return dataset\n",
"\n",
"# Load embedding map\n",
"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
"\n",
"# Get preprocessed dataset\n",
"dataset = preprocess_data(data, embedding_map, tokenizer)\n",
"\n",
"# Train-test split\n",
"train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n",
"\n",
"class TextDataset(Dataset):\n",
" def __init__(self, data):\n",
" self.data = data\n",
"\n",
" def __len__(self):\n",
" return len(self.data)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.data[idx]\n",
"\n",
" def collate_fn(self, batch):\n",
" labels, embeddings = zip(*batch)\n",
" labels = torch.tensor(labels)\n",
" embeddings = pad_sequence(embeddings, batch_first=True)\n",
" return labels, embeddings\n",
"\n",
"train_dataset = TextDataset(train_data)\n",
"val_dataset = TextDataset(val_data)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n",
"val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"class NegativeSampleDataset(Dataset):\n",
" def __init__(self, negative_samples):\n",
" \"\"\"\n",
" negative_samples: List or array of negative sample embeddings or raw text\n",
" \"\"\"\n",
" self.samples = negative_samples\n",
" \n",
" def __len__(self):\n",
" return len(self.samples)\n",
" \n",
" def __getitem__(self, idx):\n",
" return self.samples[idx]\n",
"\n",
" def collate_fn(self, batch):\n",
" embeddings = pad_sequence(batch, batch_first=True)\n",
" return embeddings\n",
"\n",
"with open('noise.json', 'r') as f:\n",
" negative_samples_list = json.load(f)\n",
"\n",
"negative_embedding_list = []\n",
"\n",
"for sentence in negative_samples_list:\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n",
" embeddings = torch.tensor(embeddings)\n",
" negative_embedding_list.append(embeddings)\n",
"\n",
"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
]
},
{
"cell_type": "markdown",
"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
"metadata": {},
"source": [
"## Implementating the Model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class TextCNN(nn.Module):\n",
" def __init__(self, input_dim, num_classes):\n",
" super(TextCNN, self).__init__()\n",
" self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
" self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
" self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
" \n",
" self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
" self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
" self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
" \n",
" self.dropout = nn.Dropout(0.5)\n",
" self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
"\n",
" def forward(self, x):\n",
" x = x.permute(0, 2, 1) # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
" \n",
" x1 = F.relu(self.bn1(self.conv1(x)))\n",
" x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
" \n",
" x2 = F.relu(self.bn2(self.conv2(x)))\n",
" x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
" \n",
" x3 = F.relu(self.bn3(self.conv3(x)))\n",
" x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
" \n",
" x = torch.cat((x1, x2, x3), dim=1)\n",
" x = self.dropout(x)\n",
" x = self.fc(x)\n",
" return x\n",
"\n",
"# Initialize model\n",
"input_dim = DIMENSIONS\n",
"num_classes = len(class_to_idx)\n",
"model = TextCNN(input_dim, num_classes)\n"
]
},
{
"cell_type": "markdown",
"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
"metadata": {},
"source": [
"## Energy-based Models"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
"metadata": {},
"outputs": [],
"source": [
"def energy_score(logits):\n",
" # Energy score is minus logsumexp\n",
" return -torch.logsumexp(logits, dim=1)\n",
"\n",
"def generate_noise(batch_size, seq_length ,input_dim, device):\n",
" # Generate a Gaussian noise\n",
" return torch.randn(batch_size, seq_length, input_dim).to(device)\n"
]
},
{
"cell_type": "markdown",
"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/50], Loss: 12.5108\n",
"Epoch [2/50], Loss: 10.7305\n",
"Epoch [3/50], Loss: 10.2943\n",
"Epoch [4/50], Loss: 9.9350\n",
"Epoch [5/50], Loss: 9.7991\n",
"Epoch [6/50], Loss: 9.6443\n",
"Epoch [7/50], Loss: 9.4762\n",
"Epoch [8/50], Loss: 9.4637\n",
"Epoch [9/50], Loss: 9.3025\n",
"Epoch [10/50], Loss: 9.1719\n",
"Epoch [11/50], Loss: 9.0632\n",
"Epoch [12/50], Loss: 8.9741\n",
"Epoch [13/50], Loss: 8.8487\n",
"Epoch [14/50], Loss: 8.6565\n",
"Epoch [15/50], Loss: 8.5830\n",
"Epoch [16/50], Loss: 8.4196\n",
"Epoch [17/50], Loss: 8.2319\n",
"Epoch [18/50], Loss: 8.0655\n",
"Epoch [19/50], Loss: 7.7140\n",
"Epoch [20/50], Loss: 7.6921\n",
"Epoch [21/50], Loss: 7.3375\n",
"Epoch [22/50], Loss: 7.2297\n",
"Epoch [23/50], Loss: 6.8833\n",
"Epoch [24/50], Loss: 6.8534\n",
"Epoch [25/50], Loss: 6.4557\n",
"Epoch [26/50], Loss: 6.1365\n",
"Epoch [27/50], Loss: 5.8558\n",
"Epoch [28/50], Loss: 5.5030\n",
"Epoch [29/50], Loss: 5.1604\n",
"Epoch [30/50], Loss: 4.7742\n",
"Epoch [31/50], Loss: 4.5958\n",
"Epoch [32/50], Loss: 4.0713\n",
"Epoch [33/50], Loss: 3.8872\n",
"Epoch [34/50], Loss: 3.5240\n",
"Epoch [35/50], Loss: 3.3115\n",
"Epoch [36/50], Loss: 2.5667\n",
"Epoch [37/50], Loss: 2.6709\n",
"Epoch [38/50], Loss: 1.8075\n",
"Epoch [39/50], Loss: 1.6654\n",
"Epoch [40/50], Loss: 0.4622\n",
"Epoch [41/50], Loss: 0.4719\n",
"Epoch [42/50], Loss: -0.4037\n",
"Epoch [43/50], Loss: -0.9405\n",
"Epoch [44/50], Loss: -1.7204\n",
"Epoch [45/50], Loss: -2.4124\n",
"Epoch [46/50], Loss: -3.0032\n",
"Epoch [47/50], Loss: -2.7123\n",
"Epoch [48/50], Loss: -3.6953\n",
"Epoch [49/50], Loss: -3.7212\n",
"Epoch [50/50], Loss: -3.7558\n"
]
}
],
"source": [
"import torch.optim as optim\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
"\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"import tensorboard\n",
"writer = SummaryWriter()\n",
"\n",
"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n",
" model.train()\n",
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
" model.to(device)\n",
" \n",
" negative_iter = iter(negative_loader)\n",
" \n",
" for epoch in range(num_epochs):\n",
" total_loss = 0\n",
" for batch_idx, (labels, embeddings) in enumerate(train_loader):\n",
" embeddings = embeddings.to(device)\n",
" labels = labels.to(device)\n",
" \n",
" batch_size = embeddings.size(0)\n",
" \n",
" # ---------------------\n",
" # 1. Positive sample\n",
" # ---------------------\n",
" optimizer.zero_grad()\n",
" outputs = model(embeddings) # logits from the model\n",
" \n",
" class_loss = criterion(outputs, labels)\n",
" \n",
" # Energy of positive sample\n",
" known_energy = energy_score(outputs)\n",
" energy_loss_known = known_energy.mean()\n",
" \n",
" # ------------------------------------\n",
" # 2. Negative sample - Random Noise\n",
" # ------------------------------------\n",
" noise_embeddings = torch.randn_like(embeddings).to(device)\n",
" noise_outputs = model(noise_embeddings)\n",
" noise_energy = energy_score(noise_outputs)\n",
" energy_loss_noise = F.relu(1 - noise_energy).mean() # For the energy of noise, bigger is better \n",
" \n",
" # ------------------------------------\n",
" # 3. Negative sample - custom corpus\n",
" # ------------------------------------\n",
" \n",
" try:\n",
" negative_samples = next(negative_iter)\n",
" except StopIteration:\n",
" negative_iter = iter(negative_loader)\n",
" negative_samples = next(negative_iter)\n",
" negative_samples = negative_samples.to(device)\n",
" negative_outputs = model(negative_samples)\n",
" negative_energy = energy_score(negative_outputs)\n",
" energy_loss_negative = F.relu(1 - negative_energy).mean() # For the energy of noise, bigger is better \n",
" \n",
" # -----------------------------\n",
" # 4. Overall Loss calculation\n",
" # -----------------------------\n",
" total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
" total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n",
"\n",
" writer.add_scalar(\"Engergy Loss\", total_energy_loss, epoch)\n",
" writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n",
" writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10 , epoch)\n",
" \n",
" total_loss_batch.backward()\n",
" optimizer.step()\n",
" \n",
" total_loss += total_loss_batch.item()\n",
" \n",
" avg_loss = total_loss / len(train_loader)\n",
" print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
"\n",
"train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n",
"writer.flush()\n"
]
},
{
"cell_type": "markdown",
"id": "e6d29558-f497-4033-8488-169bd25ce881",
"metadata": {},
"source": [
"## Evalutation"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
"metadata": {},
"outputs": [],
"source": [
"ENERGY_THRESHOLD = -3"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy: 0.9315\n",
"Precision: 1.0000\n",
"Recall: 0.9254\n",
"F1 Score: 0.9612\n"
]
}
],
"source": [
"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
"\n",
"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
" model.eval()\n",
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
" \n",
" all_preds = []\n",
" all_labels = []\n",
" \n",
" # Evaluate positive sample\n",
" with torch.no_grad():\n",
" for labels, embeddings in known_loader:\n",
" embeddings = embeddings.to(device)\n",
" logits = model(embeddings)\n",
" energy = energy_score(logits)\n",
" \n",
" preds = (energy <= energy_threshold).long()\n",
" all_preds.extend(preds.cpu().numpy())\n",
" all_labels.extend([1] * len(preds)) # Positive sample labeled as 1\n",
" \n",
" # Evaluate negative sample\n",
" with torch.no_grad():\n",
" for embeddings in unknown_loader:\n",
" embeddings = embeddings.to(device)\n",
" logits = model(embeddings)\n",
" energy = energy_score(logits)\n",
" \n",
" preds = (energy <= energy_threshold).long()\n",
" all_preds.extend(preds.cpu().numpy())\n",
" all_labels.extend([0] * len(preds)) # Negative sample labeled as 1\n",
" \n",
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
" accuracy = accuracy_score(all_labels, all_preds)\n",
"\n",
" print(f'Accuracy: {accuracy:.4f}')\n",
" print(f'Precision: {precision:.4f}')\n",
" print(f'Recall: {recall:.4f}')\n",
" print(f'F1 Score: {f1:.4f}')\n",
"\n",
"evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
"metadata": {},
"outputs": [],
"source": [
"# Save the model\n",
"torch.save(model, \"model.pt\")"
]
},
{
"cell_type": "markdown",
"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
"metadata": {},
"source": [
"## Inference"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
]
}
],
"source": [
"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
" model.eval()\n",
" tokens = tokenizer.tokenize(sentence)\n",
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
" embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension\n",
" \n",
" with torch.no_grad():\n",
" logits = model(embeddings)\n",
" probabilities = F.softmax(logits, dim=1)\n",
" max_prob, predicted = torch.max(probabilities, 1)\n",
" \n",
" # Calculate energy score\n",
" energy = energy_score(logits)\n",
"\n",
" # If energy > threshold, consider the input as unknown class\n",
" if energy.item() > energy_threshold:\n",
" return [\"Unknown\", max_prob.item(), energy.item()]\n",
" else:\n",
" return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
"\n",
"# Example usage:\n",
"sentence = \"weather today\"\n",
"energy_threshold = ENERGY_THRESHOLD\n",
"predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n",
"print(f'Predicted: {predicted}')\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,4 @@
# config.py
model_name = "Qwen/Qwen2.5-3B"
DIMENSIONS = 96
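# DIMENSIONS is the size of the reduced token embeddings fed to the classifier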

View File

@ -0,0 +1,71 @@
# data_utils.py
import json
import torch
from torch.utils.data import Dataset
from torch.nn.utils.rnn import pad_sequence
def load_data(file_path):
with open(file_path, "r") as f:
data = json.load(f)
return data
def create_class_mappings(data):
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
return class_to_idx, idx_to_class
def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64):
dataset = []
for label, sentences in data.items():
for sentence in sentences:
tokens = tokenizer.tokenize(sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
embeddings = [
embedding_map[token_id] for token_id in token_ids[:max_length]
]
embeddings = torch.tensor(embeddings)
dataset.append((class_to_idx[label], embeddings))
return dataset
def get_sentences(data):
result = []
for _, sentences in data.items():
for sentence in sentences:
result.append(sentence)
return result
class TextDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def collate_fn(self, batch):
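# unpack the batch, stack the labels, and right-pad the embedding sequences to the longest one in the batch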
labels, embeddings = zip(*batch)
labels = torch.tensor(labels)
embeddings = pad_sequence(embeddings, batch_first=True)
return labels, embeddings
class NegativeSampleDataset(Dataset):
def __init__(self, negative_samples):
self.samples = negative_samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
return self.samples[idx]
def collate_fn(self, batch):
embeddings = pad_sequence(batch, batch_first=True)
return embeddings

View File

@ -0,0 +1,53 @@
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from training.config import DIMENSIONS
class SelfAttention(nn.Module):
def __init__(self, input_dim, heads):
super(SelfAttention, self).__init__()
self.heads = heads
self.scale = (input_dim // heads) ** -0.5
self.qkv = nn.Linear(input_dim, input_dim * 3)
self.fc = nn.Linear(input_dim, input_dim)
def forward(self, x):
batch_size, seq_length, embedding_dim = x.shape
qkv = self.qkv(x).view(
batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
)
q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn_weights = F.softmax(attn_weights, dim=-1)
attention_output = torch.matmul(attn_weights, v)
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
return self.fc(attention_output)
class AttentionBasedModel(nn.Module):
def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512, num_layers=3):
super(AttentionBasedModel, self).__init__()
self.self_attention_layers = nn.ModuleList([
SelfAttention(input_dim, heads) for _ in range(num_layers)
])
self.fc1 = nn.Linear(input_dim, dim_feedforward)
self.fc2 = nn.Linear(dim_feedforward, num_classes)
self.dropout = nn.Dropout(0.5)
self.norm = nn.LayerNorm(input_dim)
def forward(self, x):
for attn_layer in self.self_attention_layers:
attn_output = attn_layer(x)
x = self.norm(attn_output + x)
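# mean-pool over the sequence dimension to obtain a fixed-size sentence representation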
pooled_output = torch.mean(x, dim=1)
x = F.relu(self.fc1(pooled_output))
x = self.dropout(x)
x = self.fc2(x)
return x

View File

@ -0,0 +1,155 @@
# train.py
from sklearn.model_selection import train_test_split
import torch
import json
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from training.data_utils import (
load_data,
create_class_mappings,
preprocess_data,
TextDataset,
NegativeSampleDataset,
)
from training.model import AttentionBasedModel
import torch.nn as nn
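# Energy of a sample: the negative logsumexp of the (temperature-scaled) logits; lower energy means the input looks in-distribution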
def energy_score(logits, temperature=1.0):
return -torch.logsumexp(logits / temperature, dim=1)
def generate_noise(batch_size, seq_length, input_dim, device):
return torch.randn(batch_size, seq_length, input_dim).to(device)
def train_energy_model(
model,
train_loader,
negative_loader,
criterion,
optimizer,
num_epochs=10,
margin=1.0,
temperature=0.4,
):
model.train()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
negative_iter = iter(negative_loader)
for epoch in range(num_epochs):
total_loss = 0
for _, (labels, embeddings) in enumerate(train_loader):
embeddings = embeddings.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(embeddings)
class_loss = criterion(outputs, labels)
known_energy = energy_score(outputs, temperature)
positive_margin = 0.0
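# hinge loss: push the energy of known-class samples below the positive margin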
energy_loss_known = F.relu(known_energy - positive_margin).mean()
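# random Gaussian noise in embedding space serves as a cheap synthetic outlier whose energy should stay above the margin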
noise_embeddings = torch.randn_like(embeddings).to(device)
noise_outputs = model(noise_embeddings)
noise_energy = energy_score(noise_outputs, temperature)
energy_loss_noise = F.relu(margin - noise_energy).mean()
try:
negative_samples = next(negative_iter)
except StopIteration:
negative_iter = iter(negative_loader)
negative_samples = next(negative_iter)
negative_samples = negative_samples.to(device)
negative_outputs = model(negative_samples)
negative_energy = energy_score(negative_outputs, temperature)
energy_loss_negative = F.relu(margin - negative_energy).mean()
total_energy_loss = (
energy_loss_known + energy_loss_noise + energy_loss_negative
)
total_loss_batch = class_loss + total_energy_loss
total_loss_batch.backward()
optimizer.step()
total_loss += total_loss_batch.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
def main():
from config import model_name, DIMENSIONS
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
data = load_data("data.json")
class_to_idx, idx_to_class = create_class_mappings(data)
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
dataset = preprocess_data(data, embedding_map, tokenizer, class_to_idx)
train_data, _ = train_test_split(dataset, test_size=0.2)
train_dataset = TextDataset(train_data)
train_loader = DataLoader(
train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn
)
with open("noise.json", "r") as f:
negative_samples_list = json.load(f)
negative_embedding_list = []
for sentence in negative_samples_list:
tokens = tokenizer.tokenize(sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]
embeddings = torch.tensor(embeddings)
negative_embedding_list.append(embeddings)
negative_dataset = NegativeSampleDataset(negative_embedding_list)
negative_loader = DataLoader(
negative_dataset,
batch_size=24,
shuffle=True,
collate_fn=negative_dataset.collate_fn,
)
input_dim = DIMENSIONS
num_classes = len(class_to_idx)
model = AttentionBasedModel(input_dim, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=7e-4)
train_energy_model(
model, train_loader, negative_loader, criterion, optimizer, num_epochs=120
)
torch.save(model.state_dict(), "model.pt")
dummy_input = torch.randn(1, 64, DIMENSIONS)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=["input"],
output_names=["output"],
dynamic_axes={
"input": {0: "batch_size", 1: "seq_length"},
"output": {0: "batch_size"},
},
opset_version=11,
)
meta = {
"idx_to_class": idx_to_class,
"threshold": 0
}
with open('NLU_meta.json', 'w') as f:
json.dump(meta, f)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,75 @@
from training.model import AttentionBasedModel
from training.config import model_name
import json
import torch
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F
from training.config import DIMENSIONS
from training.model import AttentionBasedModel
def energy_score(logits):
# Energy score is the negative logsumexp of the logits
return -torch.logsumexp(logits, dim=1)
def predict_with_energy(
model,
sentence,
embedding_map,
tokenizer,
idx_to_class,
energy_threshold,
max_length=64,
):
model.eval()
tokens = tokenizer.tokenize(sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)
embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension
current_shape = embeddings.shape
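# pad very short inputs to at least two tokens so the attention layers have a sequence to attend over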
if current_shape[1] < 2:
pad_size = 2 - current_shape[1]
embeddings = F.pad(
embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
)
with torch.no_grad():
logits = model(embeddings)
print(logits)
probabilities = F.softmax(logits, dim=1)
max_prob, predicted = torch.max(probabilities, 1)
# Calculate energy score
energy = energy_score(logits)
# If energy > threshold, consider the input as unknown class
if energy.item() > energy_threshold:
return ["Unknown", max_prob.item(), energy.item()]
else:
return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]
with open("data.json", "r") as f:
data = json.load(f)
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(class_to_idx)
input_dim = DIMENSIONS
model = AttentionBasedModel(input_dim, num_classes)
model.load_state_dict(torch.load("./model.pt"))
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Example usage:
ENERGY_THRESHOLD = 2
sentence = "what on earth is the cross entropy loss"
energy_threshold = ENERGY_THRESHOLD
predicted = predict_with_energy(
model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
)
print(f"Predicted: {predicted}")

View File

@ -0,0 +1,80 @@
from training.model import AttentionBasedModel
from training.config import model_name
from training.config import DIMENSIONS
from training.data_utils import get_sentences
import json
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support
def energy_score(logits):
# Energy score is the negative logsumexp of the logits
return -torch.logsumexp(logits, dim=1)
def get_energy(
model,
sentence,
embedding_map,
tokenizer,
max_length=64,
):
model.eval()
tokens = tokenizer.tokenize(sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension
current_shape = embeddings.shape
if current_shape[1] < 2:
pad_size = 2 - current_shape[1]
embeddings = F.pad(
embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
)
with torch.no_grad():
logits = model(embeddings)
# Calculate energy score
energy = energy_score(logits)
return energy
with open("data.json", "r") as f:
positive_data = json.load(f)
class_to_idx = {cls: idx for idx, cls in enumerate(positive_data.keys())}
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
num_classes = len(class_to_idx)
with open("noise.json", "r") as f:
negative_data = json.load(f)
input_dim = DIMENSIONS
model = AttentionBasedModel(input_dim, num_classes)
model.load_state_dict(torch.load("./model.pt"))
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
tokenizer = AutoTokenizer.from_pretrained(model_name)
all_preds = []
all_labels = []
ENERGY_THRESHOLD = 2
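# sentences whose energy falls below the threshold are predicted as known intents (label 1), the rest as unknown (label 0)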
for item in tqdm(get_sentences(positive_data)):
result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
all_preds.append(result)
all_labels.append(1)
for item in tqdm(negative_data):
result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
all_preds.append(result)
all_labels.append(0)
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
accuracy = accuracy_score(all_labels, all_preds)
print(f'Accuracy: {accuracy:.4f}')
print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,103 @@
import torch
import torch.nn.functional as F
import nltk
from nltk.tokenize import sent_tokenize, word_tokenize
from training.model import AttentionBasedModel
# Ensure required NLTK resources are available
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')
# Load pre-saved mappings
pos2idx = torch.load('pos2idx.pt')
class_mapping = torch.load('class_mapping.pt')
# Load the pre-trained model and state
model = AttentionBasedModel(40, 32, 6, 8, 6, 256)
model.load_state_dict(torch.load("./model.pt", weights_only=True))
# Define helper functions
def pad_sequence(seq, max_len):
return seq + [-1] * (max_len - len(seq))
def encode_pos_tags(tagged_sentence):
return [pos2idx[tag] if tag in pos2idx else -1 for _, tag in tagged_sentence]
# Split sentence into smaller chunks based on punctuation and length constraints
def split_long_sentence(sentence, max_len=128):
tokens = word_tokenize(sentence)
if len(tokens) <= max_len:
return [sentence]
# Attempt to split based on punctuation marks
punctuation_marks = [',', ';', ':', '!', '?', '.', '-']
split_chunks = []
current_chunk = []
for token in tokens:
current_chunk.append(token)
if token in punctuation_marks and len(current_chunk) >= max_len // 2:
split_chunks.append(' '.join(current_chunk))
current_chunk = []
if current_chunk:
split_chunks.append(' '.join(current_chunk))
# If chunks are still too long, truncate them
final_chunks = []
for chunk in split_chunks:
chunk_tokens = word_tokenize(chunk)
if len(chunk_tokens) > max_len:
final_chunks.extend([' '.join(chunk_tokens[i:i + max_len]) for i in range(0, len(chunk_tokens), max_len)])
else:
final_chunks.append(chunk)
return final_chunks
# Main function to process and score a chunk
def score_sentence(sentence, model, max_length=128):
# Tokenize and POS-tag the sentence
tagged_sentence = nltk.pos_tag(nltk.word_tokenize(sentence))
# Encode the POS tags and pad the sequence
encoded_sentences = encode_pos_tags(tagged_sentence)
padded_sentence = torch.tensor(pad_sequence(encoded_sentences, max_length))
# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
# Prepare the model
model.to(device)
model.eval() # Ensure the model is in evaluation mode
# Define weights and CEFR levels
w_list = [1.04, 1.64, 2.35, 3.44, 4.92, 6.13]
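# the difficulty score is the probability-weighted sum of these per-level weights (roughly A1=1 ... C2=6)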
# Inference without gradient calculation
with torch.no_grad():
sentence_tensor = padded_sentence.to(device)
sentence_tensor = torch.unsqueeze(sentence_tensor, 0) # Add batch dimension
# Forward pass through the model
outputs = model(sentence_tensor)
# Softmax and weighted scoring
probabilities = torch.softmax(outputs[0], dim=0)
score = sum(probabilities[i] * w_list[i] for i in range(6)).cpu().numpy()
return score
# Function to process a long article and return score list for each chunk
def score_article(article, max_length=128, chunk_max_len=128):
sentences = sent_tokenize(article) # Split the article into sentences
score_list = []
for sentence in sentences:
chunks = split_long_sentence(sentence, max_len=chunk_max_len)
for chunk in chunks:
score = score_sentence(chunk, model, max_length=max_length)
score_list.append(float(score))
return score_list

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,82 @@
import os
from dotenv import load_dotenv
import pandas as pd
from openai import OpenAI
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
load_dotenv()
client = OpenAI(
api_key=os.getenv("API_KEY"),
base_url=os.getenv("BASE_URL"),
)
def get_AI_response(text, client, model_name, temp):
messages = [
{"role": "user", "content": text},
]
response = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=temp,
)
return response.choices[0].message.content
def get_Examples(df, row, client, model_name, temp):
exp = df["Example"][row]
cds = df["Can-do statement"][row]
gdw = df["guideword"][row]
lvl = df["Level"][row]
cat = df["SuperCategory"][row] + '/' + df["SubCategory"][row]
prompt = \
f'''Generate 10 example sentences based on the following instructions.
Pay close attention to the 'Can-do Statement' and ensure all generated sentences adhere strictly to it.
Provide only the sentences without any additional formatting or markdown.
Output the sentences in plain text, one sentence per line, and do not include empty lines.
INSTRUCTION
Level: {lvl}
Guideword: {gdw}
Can-do Statement: {cds}
Category: {cat}
Example Sentences:
{exp}
'''
return get_AI_response(prompt, client, model_name, temp)
def process_chunk(df, chunk, client, model, temp):
results = []
for row in chunk:
exps = get_Examples(df, row, client, model, temp)
results.append(exps)
return results
input_file = './EGP.csv'
df = pd.read_csv(input_file)
newdf = df.copy()
model = os.getenv("TRANSLATION_MODEL")
temp = float(os.getenv("TRANSLATION_TEMP"))
chunk_size = 64
total_rows = len(df.index)
num_chunks = (total_rows + chunk_size - 1) // chunk_size # Ceiling division
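# process the rows chunk by chunk; each chunk is fanned out to a thread pool and results are written back by row index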
with tqdm(total=total_rows) as pbar:
for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, total_rows)
chunk = range(start, end)
with ThreadPoolExecutor(max_workers=len(chunk)) as executor:
futures = {executor.submit(get_Examples, df, row, client, model, temp): row for row in chunk} # bind each future to its row index
for future in as_completed(futures):
row = futures[future] # look up the row this future belongs to
result = future.result() # get the AI-generated examples
newdf.at[row, "Example"] = result # write the result back to the correct row
pbar.update(len(chunk))
newdf.to_csv("output.csv", index=False)
newdf.to_csv("EGP_Derivied.csv", index=False)

View File

@ -0,0 +1,23 @@
import pandas as pd
df = pd.read_csv("EGP_Derivied.csv")
newdf = pd.DataFrame()
levels_list=[]
sentences_list=[]
category_list=[]
for line in range(len(df.index)):
examples = list(filter(None, df["Example"][line].split("\n")))
lvl = df["Level"][line]
cat = df["SuperCategory"][line] + '/' + df["SubCategory"][line]
for sentence in examples:
sentences_list.append(sentence)
levels_list.append(lvl)
category_list.append(cat)
newdf["Level"] = levels_list
newdf["Category"] = category_list
newdf["Sentence"] = sentences_list
newdf.to_csv("data.csv", index=False)

View File

@ -0,0 +1,111 @@
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class PositionalEncoding(nn.Module):
def __init__(self, embedding_dim, max_len=5000):
super(PositionalEncoding, self).__init__()
# Create a positional encoding matrix of shape (max_len, embedding_dim)
pe = torch.zeros(max_len, embedding_dim)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
# Add a batch dimension, so the shape becomes (1, max_len, embedding_dim)
pe = pe.unsqueeze(0)
# Register the positional encoding as a buffer so it won't be updated by the optimizer
self.register_buffer('pe', pe)
def forward(self, x):
# x is expected to have shape (batch_size, seq_length, embedding_dim)
seq_length = x.size(1)
# Add positional encoding to input
x = x + self.pe[:, :seq_length]
return x
class SelfAttention(nn.Module):
def __init__(self, input_dim, heads):
super(SelfAttention, self).__init__()
self.heads = heads
self.scale = (input_dim // heads) ** -0.5
self.qkv = nn.Linear(input_dim, input_dim * 3)
self.fc = nn.Linear(input_dim, input_dim)
def forward(self, x):
batch_size, seq_length, embedding_dim = x.shape
qkv = self.qkv(x).view(
batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
)
q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
attn_weights = F.softmax(attn_weights, dim=-1)
attention_output = torch.matmul(attn_weights, v)
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
return self.fc(attention_output)
class AttentionBasedModel(nn.Module):
def __init__(self, pos_vocab_size, embedding_dim=128, num_classes=6, heads=8, num_attention_layers=3, dim_feedforward=512, max_len=128):
super(AttentionBasedModel, self).__init__()
self.embedding = nn.Embedding(pos_vocab_size, embedding_dim) # Embedding for POS tags
self.positional_encoding = PositionalEncoding(embedding_dim, max_len) # Positional Encoding
self.self_attention_layers = nn.ModuleList([
SelfAttention(embedding_dim, heads) for _ in range(num_attention_layers)
])
self.fc1 = nn.Linear(embedding_dim, dim_feedforward)
self.fc2 = nn.Linear(dim_feedforward, num_classes)
self.dropout = nn.Dropout(0.5)
self.norm = nn.LayerNorm(embedding_dim)
def forward(self, x):
# Input x is a matrix of one-hot encoded POS tags, shape: (batch_size, seq_length, pos_vocab_size)
x = self.embedding(x) # Convert POS tags to embeddings, shape: (batch_size, seq_length, embedding_dim)
# Add positional encoding to embeddings
x = self.positional_encoding(x)
for attn_layer in self.self_attention_layers:
attn_output = attn_layer(x)
x = self.norm(attn_output + x)
# Pool the output by taking the mean of the sequence (reduce along sequence length)
pooled_output = torch.mean(x, dim=1)
# Fully connected layers for classification
x = F.relu(self.fc1(pooled_output))
x = self.dropout(x)
x = self.fc2(x) # Output logits for the 6 classes
return x
# Example Usage
# # Hyperparameters
# pos_vocab_size = 50 # Size of the POS tag vocabulary
# max_context_length = 128 # Maximum context length
# embedding_dim = 128 # Embedding size
# num_classes = 6 # Output classes
# batch_size = 32 # Example batch size
# # Model initialization
# model = AttentionBasedModel(pos_vocab_size, embedding_dim, num_classes)
# # Example input: batch of one-hot encoded POS tags (variable length sequences)
# input_data = torch.randint(0, pos_vocab_size, (batch_size, max_context_length)) # Random input for testing
# # Forward pass
# output = model(input_data) # Output shape will be (batch_size, num_classes)
# print(output.shape) # Should print torch.Size([batch_size, num_classes])

View File

@ -0,0 +1,150 @@
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import nltk
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import numpy as np
#nltk.download('punkt')
#nltk.download('averaged_perceptron_tagger')
# Load the model classes
from model import AttentionBasedModel
# Load the data
df = pd.read_csv('data.csv')
# Step 1: Extract sentences and corresponding levels
sentences = df['Sentence'].values
levels = df['Level'].values
# Step 2: Tokenize and POS tag each sentence
pos_tags = [nltk.pos_tag(nltk.word_tokenize(sentence)) for sentence in sentences]
# Step 3: Build POS tag vocabulary
# Extract unique POS tags from the dataset
pos_vocab = set()
for tagged_sentence in pos_tags:
for _, tag in tagged_sentence:
pos_vocab.add(tag)
# Create a mapping from POS tag to index
pos2idx = {pos: idx for idx, pos in enumerate(pos_vocab)}
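# pos2idx maps each POS tag string to the integer index used by the embedding layer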
pos_vocab_size = len(pos2idx)
# Step 4: Encode sentences into POS tag indices
def encode_pos_tags(tagged_sentence):
return [pos2idx[tag] for _, tag in tagged_sentence]
encoded_sentences = [encode_pos_tags(tagged_sentence) for tagged_sentence in pos_tags]
# Step 5: Encode levels (classes) into integers
le = LabelEncoder()
encoded_levels = le.fit_transform(levels)
num_classes = len(le.classes_)
# Save class encoding mapping
class_mapping = dict(zip(le.transform(le.classes_), le.classes_))
torch.save(class_mapping, 'class_mapping.pt')
# Save POS tag encoding mapping
torch.save(pos2idx, 'pos2idx.pt')
# Step 6: Pad sentences to a fixed length
max_length = 64
def pad_sequence(seq, max_len):
return seq + [-1] * (max_len - len(seq))
padded_sentences = [pad_sequence(seq, max_length) for seq in encoded_sentences]
# Step 7: Create a PyTorch Dataset and DataLoader
class POSDataset(Dataset):
def __init__(self, sentences, labels):
self.sentences = sentences
self.labels = labels
def __len__(self):
return len(self.sentences)
def __getitem__(self, idx):
sentence = torch.tensor(self.sentences[idx], dtype=torch.long)
label = torch.tensor(self.labels[idx], dtype=torch.long)
return sentence, label
# Split data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(padded_sentences, encoded_levels, test_size=0.2)
train_dataset = POSDataset(X_train, y_train)
val_dataset = POSDataset(X_val, y_val)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Step 8: Initialize the model, loss function, and optimizer
embedding_dim = 32
heads = 8
num_attention_layers = 6
dim_feedforward = 256
learning_rate = 0.003
model = AttentionBasedModel(pos_vocab_size, embedding_dim, num_classes, heads, num_attention_layers, dim_feedforward)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Step 9: Training loop
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
model.to(device)
step = 0
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for sentences, labels in train_loader:
sentences, labels = sentences.to(device), labels.to(device)
# Forward pass
outputs = model(sentences)
loss = criterion(outputs, labels)
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()
step += 1
running_loss += loss.item()
# Validation phase
model.eval()
val_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
for sentences, labels in val_loader:
sentences, labels = sentences.to(device), labels.to(device)
outputs = model(sentences)
loss = criterion(outputs, labels)
val_loss += loss.item()
# Calculate accuracy
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# Print training and validation stats
print(f'Epoch [{epoch+1}/{num_epochs}], Step {step}, Loss: {running_loss/len(train_loader):.4f}, '
f'Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')
torch.save(model.state_dict(), f'checkpoints/step_{step}.pt')
# Step 10: Save the trained model
torch.save(model.state_dict(), 'model.pt')

View File

@ -0,0 +1,46 @@
from training.model import AttentionBasedModel
import torch
import torch.nn.functional as F
import nltk
sentence = '''Smartphones have worked their way deep into our lives and have become indispensable for work and socialising.'''
pos2idx = torch.load('pos2idx.pt')
class_mapping = torch.load('class_mapping.pt')
def pad_sequence(seq, max_len):
return seq + [-1] * (max_len - len(seq))
def encode_pos_tags(tagged_sentence):
return [pos2idx[tag] if tag in pos2idx else -1 for _, tag in tagged_sentence]
max_length = 64
tagged_sentence = nltk.pos_tag(nltk.word_tokenize(sentence))
encoded_sentences = encode_pos_tags(tagged_sentence)
padded_sentence = torch.tensor(pad_sequence(encoded_sentences, max_length))
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
model = AttentionBasedModel(40, 32, 6, 8, 6, 256)
model.load_state_dict(torch.load("./model.pt", weights_only=True))
model.to(device)
model.eval() # make sure the model is in evaluation mode
w_list=[1.35, 1.63, 2.75, 3.64, 5.38, 6.32]
cefr_dict = [None, "A1", "A2", "B1", "B2", "C1", "C2"]
with torch.no_grad():
sentence = padded_sentence.to(device)
sentences = torch.unsqueeze(sentence, 0)
outputs = model(sentences)
print(torch.max(outputs, 1))
print(outputs[0])
print(torch.softmax(outputs[0],0))
s=0
for i in range(6):
s+=torch.softmax(outputs[0],0)[i] * w_list[i]
s=s.cpu().numpy()
# s is the final difficulty score.

View File

@ -0,0 +1,48 @@
import json, random
from torch.utils.data import Dataset
max_dataset_size = 82000
class MultiTRANS19(Dataset):
def __init__(self, data_file):
self.data = self.load_data(data_file)
def load_data(self, data_file):
Data = []
file_lines = []
with open(data_file, "rt", encoding="utf-8") as f:
file_lines = f.readlines()
combine_number_list = []
for _ in range(max_dataset_size):
num = random.randint(2, 7)
combine_number_list.append(num)
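# each synthetic sample concatenates between 2 and 7 consecutive sentence pairs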
file_lines = random.sample(file_lines, sum(combine_number_list))
total = 0
for combine_count in combine_number_list:
# combine_count is already the number of consecutive lines to merge into this sample
num_combination = combine_count
sample = {
"chinese": "",
"english": ""
}
for line in file_lines[total: total+num_combination]:
try:
line_sample = json.loads(line.strip())
sample["chinese"] += line_sample["chinese"]
sample["english"] += line_sample["english"]
except json.JSONDecodeError as e:
print(f"Error decoding line: {e}")
Data.append(sample)
total+=num_combination
return Data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]

View File

@ -0,0 +1,41 @@
import json, random
from torch.utils.data import Dataset
max_dataset_size = 220000
class TRANS19(Dataset):
def __init__(self, data_file):
self.data = self.load_data(data_file)
def load_data(self, data_file):
with open(data_file, "rt", encoding="utf-8") as f:
total_lines = sum(1 for _ in f)
# draw a list of unique random line numbers
random_line_numbers = random.sample(
range(total_lines), min(max_dataset_size, total_lines)
)
random_line_numbers.sort() # sort so the file can be read sequentially
Data = []
current_line_number = 0
with open(data_file, "rt", encoding="utf-8") as f:
for idx, line in enumerate(f):
if current_line_number >= len(random_line_numbers):
break
if idx == random_line_numbers[current_line_number]:
try:
sample = json.loads(line.strip())
Data.append(sample)
except json.JSONDecodeError:
print(f"Error decoding line {idx}")
current_line_number += 1
return Data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]

View File

@ -0,0 +1,42 @@
import random
from torch.utils.data import Dataset
max_dataset_size = 100000
class Wikititle(Dataset):
def __init__(self, data_file):
self.data = self.load_data(data_file)
def load_data(self, data_file):
with open(data_file, "rt", encoding="utf-8") as f:
total_lines = sum(1 for _ in f)
# draw a list of unique random line numbers
random_line_numbers = random.sample(
range(total_lines), min(max_dataset_size, total_lines)
)
random_line_numbers.sort() # sort so the file can be read sequentially
Data = []
current_line_number = 0
with open(data_file, "rt", encoding="utf-8") as f:
for idx, line in enumerate(f):
if current_line_number >= len(random_line_numbers):
break
if idx == random_line_numbers[current_line_number]:
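# each selected line is a tab-separated Chinese/English title pair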
zh, en = line.split("\t")
sample = {
"chinese": zh,
"english": en
}
Data.append(sample)
current_line_number += 1
return Data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]

View File

@ -0,0 +1,232 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 19,
"id": "07b697c8-5cc2-4021-9ab8-e7e3c90065ee",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
"\n",
"# 定义参数\n",
"model_checkpoint = \"Helsinki-NLP/opus-mt-zh-en\"\n",
"checkpoint_path = \"./saves/step_86500_bleu_29.87.bin\" # 假设使用训练中的checkpoint\n",
"\n",
"# 加载tokenizer和模型\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
"\n",
"# 加载checkpoint\n",
"#model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')[\"model_state_dict\"])\n",
"model.eval()\n",
"\n",
"# 将模型转移到设备\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "ccfb5004-2bdd-4d64-88a3-2af96b87092c",
"metadata": {},
"outputs": [],
"source": [
"def infer_translation_batch(input_texts, model, tokenizer, max_length=512, num_beams=1, length_penalty=1):\n",
" # 记录推理开始时间\n",
" start_time = time.time()\n",
"\n",
" # 预处理输入文本(批量处理)\n",
" inputs = tokenizer(\n",
" input_texts,\n",
" return_tensors=\"pt\",\n",
" padding=True, # 使用动态填充,对齐批量输入的长度\n",
" truncation=True,\n",
" max_length=max_length,\n",
" ).to(device)\n",
"\n",
" # 模型生成翻译\n",
" with torch.no_grad():\n",
" output_tokens = model.generate(\n",
" inputs[\"input_ids\"],\n",
" num_beams=num_beams,\n",
" length_penalty=length_penalty,\n",
" early_stopping=False,\n",
" #temperature=0.5,\n",
" #top_p=0.90,\n",
" do_sample=False\n",
" )\n",
"\n",
" # 解码生成的tokens为文本批量处理\n",
" translations = [\n",
" tokenizer.decode(output, skip_special_tokens=True) for output in output_tokens\n",
" ]\n",
"\n",
" # 记录推理结束时间\n",
" end_time = time.time()\n",
" inference_time = end_time - start_time\n",
"\n",
" return translations, inference_time\n",
"\n",
"def translate(input_text, model, tokenizer, batch_size=16):\n",
" lines = input_text.splitlines()\n",
" \n",
" # 存储每一行的翻译结果\n",
" translations = []\n",
" total_time = 0\n",
" \n",
" # 分批处理\n",
" for i in range(0, len(lines), batch_size):\n",
" batch_lines = [line for line in lines[i:i + batch_size] if line.strip()]\n",
" if not batch_lines:\n",
" translations.extend([\"\"] * len(batch_lines))\n",
" continue\n",
" batch_translations, time_cost = infer_translation_batch(batch_lines, model, tokenizer)\n",
" translations.extend(batch_translations)\n",
" total_time += time_cost\n",
" \n",
" final_translation = \"\\n\".join(translations)\n",
" \n",
" return final_translation, total_time\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "d5d35c96-3c4a-487c-ac26-d3d97f1208a6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original Text: \n",
"\n",
"为了降低Transformer翻译模型如基于Helsinki-NLP的Opus模型的推理时间并提高性能以下是一些常见且有效的优化方法\n",
"\n",
"1. 模型量化\n",
"简介量化是通过使用低精度数值表示模型权重例如将32位浮点数转换为8位整数来减少模型的计算量和内存占用从而加快推理速度。\n",
"方法:\n",
"Post-training quantization (PTQ):模型训练后对权重进行量化。\n",
"Quantization-aware training (QAT)在训练时引入量化通常效果比PTQ更好。\n",
"2. 模型剪枝\n",
"简介:剪枝通过移除模型中对推理结果影响较小的权重和节点来减小模型规模,从而加速推理。\n",
"方法:\n",
"结构化剪枝:移除整个层、注意力头或神经元。\n",
"非结构化剪枝:移除个别的低权重参数。\n",
"3. 减少模型尺寸\n",
"简介:通过使用更小的模型架构(例如减少层数、隐藏层维度或注意力头的数量),可以减少计算量和推理时间。\n",
"方法使用较小版本的模型例如opus-mt-small或手动调整Transformer的超参数。\n",
"4. 启用混合精度推理\n",
"简介混合精度推理允许部分计算使用半精度浮点数FP16从而减少内存占用并提高推理速度。\n",
"工具:\n",
"NVIDIA的TensorRT和**AMP (Automatic Mixed Precision)**是常用的工具可以自动处理FP16的计算。\n",
"5. 使用高效的解码策略\n",
"简介解码策略的选择影响推理速度。常用的解码方式如Beam Search虽然精度较高但速度较慢。\n",
"方法:\n",
"降低beam size减小beam size可以显著加快解码速度虽然可能会略微牺牲翻译质量。\n",
"Top-k sampling和Nucleus Sampling (Top-p sampling):这些方法通过限制词汇选择的范围来加快推理速度。\n",
"\n",
"\n",
"\n",
"Translated Text: \n",
"To reduce the time of reasoning and improve performance of the Transformer translation model (e.g., the Opus model based on Helsinki-NLP), the following are common and effective methods of optimization:\n",
"Model quantification\n",
"Profile: Quantification reduces model computing and memory occupancy by using low precision values to indicate model weights (e.g., converting 32-digit float points to 8-digit integer values), thereby accelerating reasoning.\n",
"Methodology:\n",
"Post-training Quantisation (PTQ): Quantifying weights after model training.\n",
"Quantification-aware trading (QAT): Quantification is introduced in training, usually with better results than PTQ.\n",
"Model cutting\n",
"Profile: Cuts reduce the size of the model by removing weights and nodes in the model that influence the reasoning results less.\n",
"Methodology:\n",
"Structured cut-off: removes the whole layer, attention head or neuron.\n",
"Unstructured cut-off: removes individual low weight parameters.\n",
"3. Reduction of model size\n",
"Profile: The calculation and reasoning time can be reduced by using smaller model structures (e.g., reducing the number of layers, hidden layers or the number of attention points).\n",
"Method: Use smaller versions of models, such as opus-mt-small, or manually adjust Transformer's hyperparameters.\n",
"4. Enable mixed precision reasoning\n",
"Introduction: The mixed precision reasoning allows for partial calculation of semi-precision floats (FP16), thereby reducing memory occupancy and increasing the speed of reasoning.\n",
"Tools:\n",
"The NVIDIA TensorRT and **AMP (Automatic Mixed Precision)** are commonly used tools that can automatically process FP16 calculations.\n",
"Use of efficient decoding strategies\n",
"Profile: The selection of the decoding strategy affects the speed of reasoning. Common decoding methods such as BeamSearch are more precise but slow.\n",
"Methodology:\n",
"Lower beam size: Reduction of beam size can significantly accelerate decoding, although it may be at the expense of translation quality.\n",
"Top-k sampling and Nucleus Sampling (Top-p sampling): These methods accelerate reasoning by limiting the range of vocabulary selections.\n",
"\n",
"Inference Time: 2.8956 seconds\n"
]
}
],
"source": [
"# 用户输入\n",
"input_text = '''\n",
"为了降低Transformer翻译模型如基于Helsinki-NLP的Opus模型的推理时间并提高性能以下是一些常见且有效的优化方法\n",
"\n",
"1. 模型量化\n",
"简介量化是通过使用低精度数值表示模型权重例如将32位浮点数转换为8位整数来减少模型的计算量和内存占用从而加快推理速度。\n",
"方法:\n",
"Post-training quantization (PTQ):模型训练后对权重进行量化。\n",
"Quantization-aware training (QAT)在训练时引入量化通常效果比PTQ更好。\n",
"2. 模型剪枝\n",
"简介:剪枝通过移除模型中对推理结果影响较小的权重和节点来减小模型规模,从而加速推理。\n",
"方法:\n",
"结构化剪枝:移除整个层、注意力头或神经元。\n",
"非结构化剪枝:移除个别的低权重参数。\n",
"3. 减少模型尺寸\n",
"简介:通过使用更小的模型架构(例如减少层数、隐藏层维度或注意力头的数量),可以减少计算量和推理时间。\n",
"方法使用较小版本的模型例如opus-mt-small或手动调整Transformer的超参数。\n",
"4. 启用混合精度推理\n",
"简介混合精度推理允许部分计算使用半精度浮点数FP16从而减少内存占用并提高推理速度。\n",
"工具:\n",
"NVIDIA的TensorRT和**AMP (Automatic Mixed Precision)**是常用的工具可以自动处理FP16的计算。\n",
"5. 使用高效的解码策略\n",
"简介解码策略的选择影响推理速度。常用的解码方式如Beam Search虽然精度较高但速度较慢。\n",
"方法:\n",
"降低beam size减小beam size可以显著加快解码速度虽然可能会略微牺牲翻译质量。\n",
"Top-k sampling和Nucleus Sampling (Top-p sampling):这些方法通过限制词汇选择的范围来加快推理速度。\n",
"'''\n",
"\n",
"# 进行推理并测量时间\n",
"translated_text, time_taken = translate(input_text, model, tokenizer)\n",
"\n",
"# 输出结果\n",
"print(f\"Original Text: \\n{input_text}\\n\\n\")\n",
"print(f\"Translated Text: \\n{translated_text}\\n\")\n",
"print(f\"Inference Time: {time_taken:.4f} seconds\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4a44b25-a8bb-4a82-964a-0811c34c256c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,338 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "e95d90ec-1f93-45d9-ab8a-ee3d0bae293d",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import os\n",
"import numpy as np\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader, random_split\n",
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
"from transformers import AdamW, get_scheduler\n",
"from sacrebleu.metrics import BLEU\n",
"from tqdm.auto import tqdm\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9b8e703a-a5b5-43bf-9b12-2220d869145a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cpu device\n"
]
}
],
"source": [
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"print(f'Using {device} device')\n",
"\n",
"max_dataset_size = 22000\n",
"train_set_size = 20000\n",
"valid_set_size = 2000\n",
"\n",
"max_input_length = 128\n",
"max_target_length = 128\n",
"\n",
"batch_size = 16\n",
"learning_rate = 1e-5\n",
"epoch_num = 3"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3db1484a-e923-44b9-a2e6-52178a8c09ee",
"metadata": {},
"outputs": [],
"source": [
"class TRANS(Dataset):\n",
" def __init__(self, data_file):\n",
" self.data = self.load_data(data_file)\n",
" \n",
" def load_data(self, data_file):\n",
" Data = {}\n",
" with open(data_file, 'rt', encoding='utf-8') as f:\n",
" for idx, line in enumerate(f):\n",
" if idx >= max_dataset_size:\n",
" break\n",
" sample = json.loads(line.strip())\n",
" Data[idx] = sample\n",
" return Data\n",
" \n",
" def __len__(self):\n",
" return len(self.data)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.data[idx]\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "0258cad4-f498-4952-ac29-e103ae8e9041",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/models/marian/tokenization_marian.py:175: UserWarning: Recommended: pip install sacremoses.\n",
" warnings.warn(\"Recommended: pip install sacremoses.\")\n",
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:1601: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
}
],
"source": [
"data = TRANS('./data/translation2019zh/translation2019zh_train.json')\n",
"train_data, valid_data = random_split(data, [train_set_size, valid_set_size])\n",
"test_data = TRANS('./data/translation2019zh/translation2019zh_valid.json')\n",
"\n",
"model_checkpoint = \"Helsinki-NLP/opus-mt-zh-en\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "437fb69c-59f6-48f0-9c74-330cf4862b22",
"metadata": {},
"outputs": [],
"source": [
"def collote_fn(batch_samples):\n",
" batch_inputs, batch_targets = [], []\n",
" for sample in batch_samples:\n",
" batch_inputs.append(sample['chinese'])\n",
" batch_targets.append(sample['english'])\n",
" batch_data = tokenizer(\n",
" batch_inputs, \n",
" padding=True, \n",
" max_length=max_input_length,\n",
" truncation=True, \n",
" return_tensors=\"pt\"\n",
" )\n",
" with tokenizer.as_target_tokenizer():\n",
" labels = tokenizer(\n",
" batch_targets, \n",
" padding=True, \n",
" max_length=max_target_length,\n",
" truncation=True, \n",
" return_tensors=\"pt\"\n",
" )[\"input_ids\"]\n",
" batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)\n",
" end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]\n",
" for idx, end_idx in enumerate(end_token_index):\n",
" labels[idx][end_idx+1:] = -100\n",
" batch_data['labels'] = labels\n",
" return batch_data\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b9f261d8-02ca-47fc-92d7-6d495ae9c6a1",
"metadata": {},
"outputs": [],
"source": [
"train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collote_fn)\n",
"valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn)\n",
"test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6fcfa14a-a81b-4a3f-a459-cc0c06f4fa70",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/optimization.py:591: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
}
],
"source": [
"def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):\n",
" progress_bar = tqdm(range(len(dataloader)))\n",
" progress_bar.set_description(f'loss: {0:>7f}')\n",
" finish_batch_num = (epoch-1) * len(dataloader)\n",
" \n",
" model.train()\n",
" for batch, batch_data in enumerate(dataloader, start=1):\n",
" batch_data = batch_data.to(device)\n",
" outputs = model(**batch_data)\n",
" loss = outputs.loss\n",
"\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" lr_scheduler.step()\n",
"\n",
" total_loss += loss.item()\n",
" progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')\n",
" progress_bar.update(1)\n",
" return total_loss\n",
"\n",
"bleu = BLEU()\n",
"\n",
"def test_loop(dataloader, model):\n",
" preds, labels = [], []\n",
" \n",
" model.eval()\n",
" for batch_data in tqdm(dataloader):\n",
" batch_data = batch_data.to(device)\n",
" with torch.no_grad():\n",
" generated_tokens = model.generate(\n",
" batch_data[\"input_ids\"],\n",
" attention_mask=batch_data[\"attention_mask\"],\n",
" max_length=max_target_length,\n",
" ).cpu().numpy()\n",
" label_tokens = batch_data[\"labels\"].cpu().numpy()\n",
" \n",
" decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
" label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)\n",
" decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)\n",
"\n",
" preds += [pred.strip() for pred in decoded_preds]\n",
" labels += [[label.strip()] for label in decoded_labels]\n",
" bleu_score = bleu.corpus_score(preds, labels).score\n",
" print(f\"BLEU: {bleu_score:>0.2f}\\n\")\n",
" return bleu_score\n",
"\n",
"optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
"lr_scheduler = get_scheduler(\n",
" \"linear\",\n",
" optimizer=optimizer,\n",
" num_warmup_steps=0,\n",
" num_training_steps=epoch_num*len(train_dataloader),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "12068522-df42-484f-97f1-13ce588bf47b",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "896ba74b-1a6a-402c-b94a-e9cf47bb0d65",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"-------------------------------\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0453b70899854c0191a93b53748ddaa0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/12500 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/transformers/tokenization_utils_base.py:4126: UserWarning: `as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your labels by using the argument `text_target` of the regular `__call__` method (either in the same call as your input texts if you use the same keyword arguments, or in a separate call.\n",
" warnings.warn(\n"
]
},
{
"ename": "RuntimeError",
"evalue": "MPS backend out of memory (MPS allocated: 9.37 GB, other allocations: 8.66 GB, max allowed: 18.13 GB). Tried to allocate 222.17 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[12], line 6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m t \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(epoch_num):\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mEpoch \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mt\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mepoch_num\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m-------------------------------\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 6\u001b[0m total_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr_scheduler\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtotal_loss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 7\u001b[0m valid_bleu \u001b[38;5;241m=\u001b[39m test_loop(valid_dataloader, model)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m valid_bleu \u001b[38;5;241m>\u001b[39m best_bleu:\n",
"Cell \u001b[0;32mIn[10], line 13\u001b[0m, in \u001b[0;36mtrain_loop\u001b[0;34m(dataloader, model, optimizer, lr_scheduler, epoch, total_loss)\u001b[0m\n\u001b[1;32m 10\u001b[0m loss \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mloss\n\u001b[1;32m 12\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n\u001b[0;32m---> 13\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 14\u001b[0m optimizer\u001b[38;5;241m.\u001b[39mstep()\n\u001b[1;32m 15\u001b[0m lr_scheduler\u001b[38;5;241m.\u001b[39mstep()\n",
"File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/_tensor.py:522\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 512\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 513\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 514\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 515\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 520\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 521\u001b[0m )\n\u001b[0;32m--> 522\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 523\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 524\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/autograd/__init__.py:347\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 342\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 344\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m 345\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 346\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 347\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 348\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 349\u001b[0m \u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 350\u001b[0m \u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 351\u001b[0m \u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 352\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 353\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 354\u001b[0m \u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 355\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m/opt/anaconda3/envs/sparkastML/lib/python3.10/site-packages/torch/autograd/graph.py:818\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 816\u001b[0m unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m 817\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 818\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 819\u001b[0m \u001b[43m \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m 820\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m 821\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 822\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
"\u001b[0;31mRuntimeError\u001b[0m: MPS backend out of memory (MPS allocated: 9.37 GB, other allocations: 8.66 GB, max allowed: 18.13 GB). Tried to allocate 222.17 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure)."
]
}
],
"source": [
"epoch_num = 3\n",
"total_loss = 0.\n",
"best_bleu = 0.\n",
"for t in range(epoch_num):\n",
" print(f\"Epoch {t+1}/{epoch_num}\\n-------------------------------\")\n",
" total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)\n",
" valid_bleu = test_loop(valid_dataloader, model)\n",
" if valid_bleu > best_bleu:\n",
" best_bleu = valid_bleu\n",
" print('saving new weights...\\n')\n",
" torch.save(\n",
" model.state_dict(), \n",
" f'epoch_{t+1}_valid_bleu_{valid_bleu:0.2f}_model_weights.bin'\n",
" )\n",
"print(\"Done!\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6fd3439a-058a-4220-9b65-b355b52f74b5",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

202
translate-old/zh-en/train.py Executable file
View File

@ -0,0 +1,202 @@
import numpy as np
import torch
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import AdamW, get_scheduler
from sacrebleu.metrics import BLEU
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from dataloader.multiTrans19 import MultiTRANS19
writer = SummaryWriter()
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
train_set_size = 80000
valid_set_size = 2000
test_data_size = 0
last_1k_loss = []
kmean_loss = 0.0
total_loss = 0.0
best_bleu = 0.0
step = 0
max_input_length = 128
max_target_length = 128
batch_size = 8
learning_rate = 1e-5
epoch_num = 1
# Checkpoint file path; defaults to None
# checkpoint_path = None
checkpoint_path = "./saves/checkpoint_76500.bin"  # set this path to resume training from a checkpoint
#data = Wikititle("./data/wikititles-v3.zh-en.tsv")
data = MultiTRANS19("./data/translation2019zh/translation2019zh_train.json")
print(len(data))
train_data, valid_data, test_data = random_split(data, [train_set_size, valid_set_size, test_data_size])
# data = TRANS("./data/translation2019zh/translation2019zh_train.json")
# train_data, valid_data = random_split(data, [train_set_size, valid_set_size])
# test_data = TRANS("./data/translation2019zh/translation2019zh_valid.json")
model_checkpoint = "Helsinki-NLP/opus-mt-zh-en"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
model = model.to(device)
# If a checkpoint path is specified, load the model state from it
if checkpoint_path is not None:
print(f"Loading checkpoint from {checkpoint_path}")
checkpoint_data = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint_data["model_state_dict"])
total_loss = checkpoint_data.get("total_loss", 0.0)
step = checkpoint_data.get("step", 0)
kmean_loss = total_loss / step
last_1k_loss = [kmean_loss] * 1000
def collote_fn(batch_samples):
batch_inputs, batch_targets = [], []
for sample in batch_samples:
batch_inputs.append(sample["chinese"])
batch_targets.append(sample["english"])
batch_data = tokenizer(
batch_inputs,
padding=True,
max_length=max_input_length,
truncation=True,
return_tensors="pt",
)
with tokenizer.as_target_tokenizer():
labels = tokenizer(
batch_targets,
padding=True,
max_length=max_target_length,
truncation=True,
return_tensors="pt",
)["input_ids"]
batch_data["decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
labels
)
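# Mask everything after the EOS token with -100 so those positions are ignored by the loss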
end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
for idx, end_idx in enumerate(end_token_index):
labels[idx][end_idx + 1 :] = -100
batch_data["labels"] = labels
batch_data = {k: v.to(device) for k, v in batch_data.items()}
return batch_data
train_dataloader = DataLoader(
train_data, batch_size=batch_size, shuffle=True, collate_fn=collote_fn
)
valid_dataloader = DataLoader(
valid_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn
)
test_dataloader = DataLoader(
test_data, batch_size=batch_size, shuffle=False, collate_fn=collote_fn
)
def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss, step):
progress_bar = tqdm(range(len(dataloader)))
progress_bar.set_description(f"loss: {0:>7f}")
model.train()
for batch, batch_data in enumerate(dataloader, start=1):
outputs = model(**batch_data)
loss = outputs.loss
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
total_loss += loss.item()
del last_1k_loss[0]
last_1k_loss.append(loss.item())
kmean_loss = sum(last_1k_loss) / len(last_1k_loss)
progress_bar.set_description(
f"loss: {kmean_loss:>7f}"
)
progress_bar.update(1)
step += 1
writer.add_scalar("Loss", kmean_loss, step)
writer.add_scalar("Overall Loss", total_loss / step, step)
if step % 250 == 0:
checkpoint = {
"model_state_dict": model.state_dict(),
"total_loss": total_loss,
"kmean_loss": kmean_loss,
"step": step,
}
torch.save(checkpoint, f"./saves/checkpoint_{step}.bin")
return total_loss, step
bleu = BLEU()
def test_loop(dataloader, model):
preds, labels = [], []
model.eval()
for batch_data in tqdm(dataloader):
with torch.no_grad():
generated_tokens = (
model.generate(
batch_data["input_ids"],
attention_mask=batch_data["attention_mask"],
max_length=max_target_length,
no_repeat_ngram_size=3,
)
.cpu()
.numpy()
)
label_tokens = batch_data["labels"].cpu().numpy()
decoded_preds = tokenizer.batch_decode(
generated_tokens, skip_special_tokens=True
)
label_tokens = np.where(
label_tokens != -100, label_tokens, tokenizer.pad_token_id
)
decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)
preds += [pred.strip() for pred in decoded_preds]
labels += [[label.strip()] for label in decoded_labels]
bleu_score = bleu.corpus_score(preds, labels).score
print(f"BLEU: {bleu_score:>0.2f}\n")
return bleu_score
optimizer = AdamW(model.parameters(), lr=learning_rate)
lr_scheduler = get_scheduler(
"linear",
optimizer=optimizer,
num_warmup_steps=int(0.1 * epoch_num * len(train_dataloader)),
num_training_steps=epoch_num * len(train_dataloader),
)
for t in range(epoch_num):
print(f"Epoch {t+1}/{epoch_num}\n {'-'*20}")
total_loss, step = train_loop(
train_dataloader, model, optimizer, lr_scheduler, t + 1, total_loss, step
)
valid_bleu = test_loop(valid_dataloader, model)
print("saving new weights...\n")
checkpoint = {
"model_state_dict": model.state_dict(),
"total_loss": total_loss,
"kmean_loss": kmean_loss,
"step": step,
}
torch.save(checkpoint, f"./saves/step_{step}_bleu_{valid_bleu:0.2f}.bin")
print("Done!")

View File

@ -0,0 +1,243 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "2d0860b5-8e4e-4596-a81f-be259e188775",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from transformers import AutoTokenizer, AutoModelForSeq2SeqLM\n",
"import json\n",
"import numpy as np\n",
"from sacrebleu.metrics import BLEU\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7c22baec-c987-46cd-af6f-5eb594a2b1e4",
"metadata": {},
"outputs": [],
"source": [
"# 定义参数\n",
"checkpoint_path = \"./step_137000_valid_bleu_25.55_model_weights.bin\" # 假设你要加载第2个epoch中的500步的checkpoint\n",
"data_file = \"./data/translation2019zh/translation2019zh_valid.json\" # 假设使用验证集来测试\n",
"model_checkpoint = \"Helsinki-NLP/opus-mt-zh-en\"\n",
"max_dataset_size = 100\n",
"max_input_length = 128\n",
"max_target_length = 128\n",
"batch_size = 8"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "fff9be9b-57d8-4203-b644-155c76baa1ff",
"metadata": {},
"outputs": [],
"source": [
"class TRANS(Dataset):\n",
" def __init__(self, data_file):\n",
" self.data = self.load_data(data_file)\n",
" \n",
" def load_data(self, data_file):\n",
" Data = {}\n",
" with open(data_file, 'rt', encoding='utf-8') as f:\n",
" for idx, line in enumerate(f):\n",
" if idx >= max_dataset_size:\n",
" break\n",
" sample = json.loads(line.strip())\n",
" Data[idx] = sample\n",
" return Data\n",
" \n",
" def __len__(self):\n",
" return len(self.data)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.data[idx]\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3be38090-c1e5-4fb5-b90d-7f18fe8dc23f",
"metadata": {},
"outputs": [],
"source": [
"def collote_fn(batch_samples):\n",
" batch_inputs, batch_targets = [], []\n",
" for sample in batch_samples:\n",
" batch_inputs.append(sample['chinese'])\n",
" batch_targets.append(sample['english'])\n",
" batch_data = tokenizer(\n",
" batch_inputs, \n",
" padding=True, \n",
" max_length=max_input_length,\n",
" truncation=True, \n",
" return_tensors=\"pt\"\n",
" )\n",
" with tokenizer.as_target_tokenizer():\n",
" labels = tokenizer(\n",
" batch_targets, \n",
" padding=True, \n",
" max_length=max_target_length,\n",
" truncation=True, \n",
" return_tensors=\"pt\"\n",
" )[\"input_ids\"]\n",
" batch_data['decoder_input_ids'] = model.prepare_decoder_input_ids_from_labels(labels)\n",
" end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]\n",
" for idx, end_idx in enumerate(end_token_index):\n",
" labels[idx][end_idx+1:] = -100\n",
" batch_data['labels'] = labels\n",
" return batch_data\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3d5a697a-9b44-4ff3-96df-24a2cb608773",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_13528/1590730426.py:6: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" model.load_state_dict(torch.load(checkpoint_path, map_location=\"cpu\"))\n"
]
}
],
"source": [
"# 加载模型和tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)\n",
"\n",
"# 加载checkpoint\n",
"model.load_state_dict(torch.load(checkpoint_path, map_location=\"cpu\"))\n",
"model.eval()\n",
"\n",
"# 将模型转移到设备\n",
"device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\"\n",
"model = model.to(device)\n",
"\n",
"# 加载测试数据\n",
"test_data = TRANS(data_file)\n",
"test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True, collate_fn=collote_fn)\n",
"\n",
"# 定义BLEU评估函数\n",
"bleu = BLEU()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "6446e3f4-b6e2-4f8a-abc4-6fee7224d517",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"def test_model(dataloader, model):\n",
" preds, labels = [], []\n",
"\n",
" model.eval()\n",
" for batch_data in tqdm(dataloader):\n",
" batch_data = batch_data.to(device)\n",
" with torch.no_grad():\n",
" generated_tokens = model.generate(\n",
" batch_data[\"input_ids\"],\n",
" attention_mask=batch_data[\"attention_mask\"],\n",
" max_length=max_target_length,\n",
" ).cpu().numpy()\n",
"\n",
" label_tokens = batch_data[\"labels\"].cpu().numpy()\n",
" \n",
"\n",
" decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
" label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)\n",
" decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)\n",
"\n",
" preds += [pred.strip() for pred in decoded_preds]\n",
" labels += [[label.strip()] for label in decoded_labels]\n",
" \n",
" bleu_score = bleu.corpus_score(preds, labels).score\n",
" print(f\"BLEU: {bleu_score:>0.2f}\")\n",
" return bleu_score"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "f196f320-7d8b-44d5-9903-5fdb0532e318",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Testing model...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████| 13/13 [00:33<00:00, 2.61s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"BLEU: 12.95\n",
"Test BLEU score: 12.95\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"print(\"Testing model...\")\n",
"bleu_score = test_model(test_dataloader, model)\n",
"print(f\"Test BLEU score: {bleu_score:0.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "96e1d097-93d6-482d-8836-167974de98bc",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

112
translate/LLMtranslator.py Normal file
View File

@ -0,0 +1,112 @@
import os
from dotenv import load_dotenv
import json
import threading
from openai import OpenAI
from pathlib import Path
load_dotenv()
client = OpenAI(
api_key=os.getenv("API_KEY"),
base_url=os.getenv("BASE_URL"),
)
system_prompt = """
The user will provide some text. Please parse the text into segments, each containing 1 to 5 sentences. Translate each segment into the corresponding language. If the input is in Chinese, return the English translation, and vice versa.
IMPORTANT:
1. Segment should not be too long, each segment should be under 100 English words or 180 Chinese characters.
2. Segments or sentences that appear multiple times in the original text are output only **once** in the returned translation.
3. **For content with obvious semantic differences, such as different components on a web page, no matter how short it is, it should be divided into a separate segment.**
4. Information such as web page headers, footers, and other fixed text, such as copyright notices, website or company names, and conventional link text (such as "About Us", "Privacy Policy", etc.), should be **ignored and not translated**.
5. If the provided text lacks proper punctuation, please add proper punctuation to both the source text and the translated text in the output.
EXAMPLE INPUT:
法律之前人人平等并有权享受法律的平等保护不受任何歧视人人有权享受平等保护以免受违反本宣言的任何歧视行为以及煽动这种歧视的任何行为之害
EXAMPLE JSON OUTPUT:
{
"segments": [
{"chinese": "法律之前人人平等,并有权享受法律的平等保护,不受任何歧视。", "english": "All are equal before the law and are entitled without any discrimination to equal protection of the law."},
{"chinese": "人人有权享受平等保护,以免受违反本宣言的任何歧视行为以及煽动这种歧视的任何行为之害。", "english": "All are entitled to equal protection against any discrimination in violation of this Declaration and against any incitement to such discrimination."}
]
}
"""
def translate_text(text, client, model_name, temp):
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
]
response = client.chat.completions.create(
model=model_name,
messages=messages,
response_format={"type": "json_object"},
temperature=temp,
)
return json.loads(response.choices[0].message.content)
def process_file(input_file, output_dir):
try:
with open(input_file, "r", encoding="utf-8") as f:
text = f.read()
model = os.getenv("TRANSLATION_MODEL")
temp = float(os.getenv("TRANSLATION_TEMP"))
translation = translate_text(text, client, model, temp)
output_path = os.path.join(output_dir, Path(input_file).stem + ".json")
with open(output_path, "w", encoding="utf-8") as f:
json.dump(translation, f, ensure_ascii=False, indent=4)
print(f"Successfully translated and saved to {output_path}")
except Exception as e:
print(f"Error processing {input_file}: {e}")
def batch_process(input_dir, output_dir, num_threads=4):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
input_files = [
f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))
]
output_files = [
f for f in os.listdir(output_dir) if os.path.isfile(os.path.join(output_dir, f))
]
output_stems = {Path(f).stem for f in output_files}
files = [
os.path.join(input_dir, f)
for f in input_files
if Path(f).stem not in output_stems
]
threads = []
for file in files:
thread = threading.Thread(target=process_file, args=(file, output_dir))
threads.append(thread)
thread.start()
if len(threads) >= num_threads:
for t in threads:
t.join()
threads = []
for t in threads:
t.join()
if __name__ == "__main__":
input_dir = "./source-new"
output_dir = "./output-new"
batch_process(
input_dir, output_dir, num_threads=int(os.getenv("TRANSLATE_THREADS"))
)

50
translate/README.md Normal file
View File

@ -0,0 +1,50 @@
# sparkastML NMT
A set of models that aims to offer the best open-source machine translation, built on [OpenNMT](https://opennmt.net/).
## News
sparkastML's translation model is now updated!
### Details
- **Source Language:** Chinese (Simplified)
- **Target Language:** English
- **Training Time:** 11.3 hours in total, 46,500 steps (~1×10¹⁸ FLOPs)
- **Training Device:**
- RTX 3080 (20GB): 0-20,000 steps
- RTX 4070: 20,000-46,500 steps
- **Corpus Size:** Over 10 million sentences
- **Validation BLEU Score:** 21.28
- **Validation Loss (Cross Entropy):** 3.152
### Model Download
Available soon.
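Once the checkpoint is published, local inference could look roughly like the sketch below. This is only an illustration, not an official inference script: it assumes the OpenNMT-py checkpoint has been converted with CTranslate2's `ct2-opennmt-py-converter` and that the SentencePiece models are shipped alongside it; every file name below is a placeholder.

```python
import ctranslate2
import sentencepiece as spm

# Placeholder paths: point these at the converted model and its SentencePiece files
translator = ctranslate2.Translator("zh_en_ct2_model", device="cpu")
sp_source = spm.SentencePieceProcessor(model_file="source.spm.model")
sp_target = spm.SentencePieceProcessor(model_file="target.spm.model")

def translate(text: str) -> str:
    # Tokenize, translate a single-sentence batch, and detokenize the best hypothesis
    tokens = sp_source.encode(text, out_type=str)
    result = translator.translate_batch([tokens])
    return sp_target.decode(result[0].hypotheses[0])

print(translate("法律之前人人平等,并有权享受法律的平等保护,不受任何歧视。"))
```

The same sketch runs on GPU by passing `device="cuda"` to `ctranslate2.Translator`.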
### Special thanks
[yumechi](https://github.com/eternal-flame-AD/) for sponsoring an RTX 4070 for training.
## History
### Sep 19, 2024
sparkastML's translation model is now updated!
#### Details
- **Source Language:** Chinese (Simplified)
- **Target Language:** English
- **Training Time:** 5 hours, 20,000 steps
- **Training Device:** RTX 3080 (20GB)
- **Corpus Size:** Over 10 million sentences
- **Validation BLEU Score:** 17
- **Version:** 1.0
#### Model Download
- **Google Drive:** [Download from Google Drive](https://drive.google.com/drive/folders/1-q_AKfQENW-pV6uAleUHPE9ghddfNWKF)
- **IPFS:** [Download from IPFS](http://ipfs.a2x.pub/ipfs/QmUMadzkBwvH5KTpoxfv7TgqzaPpqBzkXtkecV9TXPfZ3F/)
- **CID:** `QmUMadzkBwvH5KTpoxfv7TgqzaPpqBzkXtkecV9TXPfZ3F`
- **GitHub Release:** [Go to Release Page](https://github.com/alikia2x/sparkastML/releases/tag/v2-model)

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,79 @@
import torch
from transformers import AutoModel, AutoTokenizer
from numpy.linalg import norm
import sys
import random
from tqdm import tqdm
# Define the cosine similarity function
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
# Load the embedding model
model_name = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# Check if the correct number of command-line arguments are provided
if len(sys.argv) < 4 or len(sys.argv) > 5:
print("Usage: python script.py <file_a_path> <file_b_path> <output_file_path> [num_samples]")
sys.exit(1)
# Define file paths from command-line arguments
file_a_path = sys.argv[1]
file_b_path = sys.argv[2]
output_file_path = sys.argv[3]
# Define the number of samples to randomly select
num_samples = int(sys.argv[4]) if len(sys.argv) == 5 else 100
# Get the total number of lines in the files without loading them fully
def count_lines(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
return sum(1 for _ in f)
total_lines_a = count_lines(file_a_path)
total_lines_b = count_lines(file_b_path)
# Ensure both files have the same number of lines
if total_lines_a != total_lines_b:
print("Files must have the same number of lines.")
sys.exit(1)
# Select random sample indices without loading entire files
selected_indices = sorted(random.sample(range(total_lines_a), num_samples))
# Function to get all sampled lines from the file
def get_lines(file_path, line_numbers):
result = []
max_i = max(line_numbers)
j=0
next_i = line_numbers[j]
len_line_numbers = len(line_numbers)
with open(file_path, 'r', encoding='utf-8') as f:
for current_line, line in tqdm(enumerate(f)):
if current_line < next_i:
continue
result.append(line.strip())
j+=1
if current_line >= max_i or j >= len_line_numbers:
return result
next_i = line_numbers[j]
return result
lines_a = get_lines(file_a_path, selected_indices)
lines_b = get_lines(file_b_path, selected_indices)
# Open output file for writing
with open(output_file_path, 'w', encoding='utf-8') as output_file:
for i, idx in tqdm(enumerate(selected_indices)):
# Get the corresponding lines from both files
line_a = lines_a[i]
line_b = lines_b[i]
embeddings = model.encode([line_a, line_b])
similarity = cos_sim(embeddings[0], embeddings[1])
# Write the similarity to the output file
output_file.write(f"{similarity}\n")
print(f"Similarity calculation completed. Results saved to {output_file_path}")

View File

@ -0,0 +1,60 @@
from transformers import AutoModel
from numpy.linalg import norm
import argparse
from tqdm import tqdm
parser = argparse.ArgumentParser(
description="Usage: python filter.py <file_a_path> <file_b_path> <output_file_path>"
)
parser.add_argument("file_a", type=str, help="File No.1")
parser.add_argument("file_b", type=str, help="File No.2")
parser.add_argument("output", type=str, help="Output file")
parser.add_argument(
"--resume",
type=int,
default=-1,
help="Resume from specified line",
)
args = parser.parse_args()
# Define the cosine similarity function
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
# Load the embedding model
model_name = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
model.to('cuda')
# Define file paths from command-line arguments
file_a_path = args.file_a
file_b_path = args.file_b
output_file_path = args.output
resume_from = args.resume
resume = resume_from >= 0
output_file_mode = 'a' if resume else 'w'
# Open files
with open(file_a_path, 'r', encoding='utf-8') as file_a, \
open(file_b_path, 'r', encoding='utf-8') as file_b, \
open(output_file_path, output_file_mode, encoding='utf-8') as output_file:
i=1
# Read file A and file B line by line
for line_a, line_b in tqdm(zip(file_a, file_b)):
if resume and i < resume_from:
i+=1
continue
# Remove trailing newline characters
line_a = line_a.strip()
line_b = line_b.strip()
embeddings = model.encode([line_a, line_b])
similarity = cos_sim(embeddings[0], embeddings[1])
# Write the similarity to the output file
output_file.write(f"{similarity}\n")
i+=1
print(f"Similarity calculation completed. Results saved to {output_file_path}")

View File

@ -0,0 +1,74 @@
from transformers import AutoModel
from numpy.linalg import norm
import sys
import random
import json
from tqdm import tqdm
# Define the cosine similarity function
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
# Load the embedding model
model_name = 'jinaai/jina-embeddings-v2-base-zh'
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
# Check if the correct number of command-line arguments are provided
if len(sys.argv) < 3 or len(sys.argv) > 4:
print("Usage: python script.py <file_path> <output_file_path> [num_samples]")
sys.exit(1)
# Define file paths from command-line arguments
file_path = sys.argv[1]
output_file_path = sys.argv[2]
# Define the number of samples to randomly select
num_samples = int(sys.argv[3]) if len(sys.argv) == 4 else 100
# Get the total number of lines in the files without loading them fully
def count_lines(file_path):
with open(file_path, 'r', encoding='utf-8') as f:
return sum(1 for _ in f)
total_lines = count_lines(file_path)
# Select random sample indices without loading entire files
selected_indices = sorted(random.sample(range(total_lines), num_samples))
# Function to get all sampled lines from the file
def get_lines(file_path, line_numbers):
result = []
max_i = max(line_numbers)
j=0
next_i = line_numbers[j]
len_line_numbers = len(line_numbers)
with open(file_path, 'r', encoding='utf-8') as f:
for current_line, line in tqdm(enumerate(f)):
if current_line < next_i:
continue
result.append(line.strip())
j+=1
if current_line >= max_i or j >= len_line_numbers:
return result
next_i = line_numbers[j]
return result
lines = get_lines(file_path, selected_indices)
# Open output file for writing
with open(output_file_path, 'w', encoding='utf-8') as output_file, open("1.txt", 'w', encoding='utf-8') as lf:
for i, idx in tqdm(enumerate(selected_indices)):
# Get the corresponding lines from both files
line = lines[i]
data = json.loads(line)
chn = data["chinese"]
eng = data["english"]
lf.write(str(idx)+'\n')
embeddings = model.encode([chn, eng])
similarity = cos_sim(embeddings[0], embeddings[1])
# Write the similarity to the output file
output_file.write(f"{similarity}\n")
print(f"Similarity calculation completed. Results saved to {output_file_path}")

135
translate/fetcher.py Normal file
View File

@ -0,0 +1,135 @@
import sqlite3
import trafilatura
import hashlib
import re
import os
from dotenv import load_dotenv
from trafilatura.readability_lxml import is_probably_readerable
from concurrent.futures import ThreadPoolExecutor, as_completed
load_dotenv()
# Constants
MAX_FETCH_LIMIT = int(os.getenv("FETCH_LIMIT"))  # maximum number of URLs to fetch per run
# Database connection
def connect_db(db_path):
return sqlite3.connect(db_path)
# Create the fetch_list table
def create_fetch_list_table(conn):
cursor = conn.cursor()
cursor.execute("""
CREATE TABLE IF NOT EXISTS fetch_list (
url TEXT PRIMARY KEY,
fetched_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
""")
conn.commit()
# Get URLs that have not been fetched yet
def get_unfetched_urls(conn, limit):
cursor = conn.cursor()
cursor.execute("""
SELECT url FROM url_list
WHERE url NOT IN (SELECT url FROM fetch_list)
LIMIT ?
""", (limit,))
return [row[0] for row in cursor.fetchall()]
# Download a page and extract its main content
def fetch_and_extract_content(conn, url):
downloaded = trafilatura.fetch_url(url)
if not downloaded:
return None
html_string = downloaded
# Skip pages that fail the readability check unless FETCH_IGNORE_CHECK is set to TRUE
if not is_probably_readerable(html_string) and os.getenv("FETCH_IGNORE_CHECK", "").upper() != "TRUE":
print(f"URL {url} is not readable.")
record_fetched_url(conn, url)
return None
content = trafilatura.extract(html_string, output_format="txt", url=url, favor_precision=True)
print(f"Successfully extracted text for URL: {url}")
return content
# Compute the MD5 hash of a URL
def md5_hash(url):
return hashlib.md5(url.encode()).hexdigest()
# Segmentation rules: group sentences into segments of at most 12 sentences or ~1800 characters
def split_content(content):
sentences = re.split(r'[。!?;.!?;]', content)
segments = []
current_segment = []
current_length = 0
for sentence in sentences:
sentence_length = len(sentence)
if (len(current_segment) >= 12 or current_length + sentence_length > 1800):
segments.append(''.join(current_segment))
current_segment = []
current_length = 0
current_segment.append(sentence)
current_length += sentence_length
if current_segment:
segments.append(''.join(current_segment))
return segments
# Save the segmented content to files
def save_segments(url, segments, path):
url_hash = md5_hash(url)
for idx, segment in enumerate(segments):
save_path = os.path.join(path, f"{url_hash}_{idx}.txt")
with open(save_path, "w", encoding="utf-8") as f:
f.write(segment)
# Record a URL as fetched
def record_fetched_url(conn, url):
cursor = conn.cursor()
cursor.execute("""
INSERT INTO fetch_list (url, fetched_time)
VALUES (?, CURRENT_TIMESTAMP)
""", (url,))
conn.commit()
# Process a single URL
def process_url(url, db_path, save_path):
import time, random
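# Sleep a random fraction of FETCH_COOLDOWN before and after each fetch to avoid hammering the target site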
cooldown_base = float(os.getenv("FETCH_COOLDOWN"))
time.sleep(random.random() * cooldown_base)
conn = connect_db(db_path)
content = fetch_and_extract_content(conn, url)
if content:
segments = split_content(content)
save_segments(url, segments, save_path)
record_fetched_url(conn, url)
conn.close()
time.sleep(random.random() * cooldown_base)
# Main entry point
def main():
db_path = "crawler.db"
save_path = "./source"
conn = connect_db(db_path)
# Create the fetch_list table
create_fetch_list_table(conn)
unfetched_urls = get_unfetched_urls(conn, MAX_FETCH_LIMIT)
conn.close()
with ThreadPoolExecutor(max_workers=int(os.getenv("FETCH_THREADS"))) as executor:
futures = [executor.submit(process_url, url, db_path, save_path) for url in unfetched_urls]
for future in as_completed(futures):
try:
future.result()
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
main()

30
translate/hf-dataset.py Normal file
View File

@ -0,0 +1,30 @@
import pandas as pd
# Define file paths
source_files = ['./result/source.txt', './result/source-new.txt']
target_files = ['./result/target.txt', './result/target-new.txt']
# Read the contents of the source and target files
source_data = []
target_data = []
for file in source_files:
with open(file, 'r', encoding='utf-8') as f:
source_data.extend(f.readlines())
for file in target_files:
with open(file, 'r', encoding='utf-8') as f:
target_data.extend(f.readlines())
# Make sure the source and target files have the same number of lines
if len(source_data) != len(target_data):
print("Warning: The number of lines in source and target files do not match.")
# Build the DataFrame
df = pd.DataFrame({
'zh': [line.strip() for line in source_data],  # strip the trailing newline from each line
'en': [line.strip() for line in target_data]  # strip the trailing newline from each line
})
df.to_csv('./result/data.csv', index=False, encoding='utf-8')

48
translate/postprocess.py Normal file
View File

@ -0,0 +1,48 @@
import os
import json
from pybloom_live import BloomFilter
def read_converted_files(filename):
"""读取converted.txt文件返回一个包含已处理文件名的集合"""
if os.path.exists(filename):
with open(filename, 'r', encoding='utf-8') as file:
return set(file.read().splitlines())
return set()
def write_converted_file(filename, file_name):
"""将处理过的文件名写入converted.txt"""
with open(filename, 'a', encoding='utf-8') as file:
file.write(file_name + '\n')
def process_json_files(directory, converted_filename):
"""处理指定目录下的所有json文件"""
converted_files = read_converted_files(converted_filename)
bloom_filter_chinese = BloomFilter(capacity=1000000, error_rate=0.001) # 初始化Bloom Filter
bloom_filter_english = BloomFilter(capacity=1000000, error_rate=0.001) # 初始化Bloom Filter
for filename in os.listdir(directory):
if filename.endswith('.json') and filename not in converted_files:
file_path = os.path.join(directory, filename)
with open(file_path, 'r', encoding='utf-8') as json_file:
data = json.load(json_file)
segments = data.get('segments', [])
with open('./result/source-new.txt', 'a', encoding='utf-8') as source_file, \
open('./result/target-new.txt', 'a', encoding='utf-8') as target_file:
for segment in segments:
chinese_text = segment.get('chinese', '').replace('\n', ' ')
english_text = segment.get('english', '').replace('\n', ' ')
if chinese_text not in bloom_filter_chinese and english_text not in bloom_filter_english:
bloom_filter_chinese.add(chinese_text)
source_file.write(chinese_text + '\n')
bloom_filter_english.add(english_text)
target_file.write(english_text + '\n')
write_converted_file(converted_filename, filename)
if __name__ == "__main__":
json_directory = './output-new'  # replace with the path to your JSON output directory
converted_filename = './result/converted.txt'
process_json_files(json_directory, converted_filename)

139
translate/spider.py Normal file
View File

@ -0,0 +1,139 @@
import os
import requests
from bs4 import BeautifulSoup
import sqlite3
import urllib.robotparser as urobot
from urllib.parse import urljoin, urlparse
from dotenv import load_dotenv
MAX_RECURSION_DEPTH = 5
MAX_URLS = 1000
MAX_THREADS = 10
HEADERS = {
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36",
}
conn = sqlite3.connect("crawler.db")
cursor = conn.cursor()
# Initialization
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS url_list (
url TEXT PRIMARY KEY,
fetched_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
visited BOOLEAN,
parent_url TEXT,
child_url_count INTEGER
)
"""
)
conn.commit()
def fetch_url(url, headers=None):
try:
response = requests.get(url, headers=headers)
response.raise_for_status()
return response.text, response.url
except requests.RequestException as e:
print(f"Error fetching {url}: {e}")
return None, url
def extract_links(html, base_url):
soup = BeautifulSoup(html, "html.parser")
links = set()
for a_tag in soup.find_all("a", href=True):
href = a_tag["href"]
full_url = urljoin(base_url, href)
if urlparse(full_url).netloc == urlparse(base_url).netloc:
links.add(full_url)
return links
def fetch_sitemap(sitemap_url):
html, _ = fetch_url(sitemap_url)
if html:
soup = BeautifulSoup(html, "xml")
urls = {loc.text for loc in soup.find_all("loc")}
return urls
return set()
def save_url(url, parent_url=None):
cursor = conn.cursor()
cursor.execute(
"""
INSERT OR IGNORE INTO url_list (url, visited, parent_url, child_url_count)
VALUES (?, ?, ?, ?)
""",
(url, False, parent_url, 0),
)
conn.commit()
def update_url(url, child_url_count):
cursor.execute(
"""
UPDATE url_list SET child_url_count = ? WHERE url = ?
""",
(child_url_count, url),
)
conn.commit()
def crawl(url, rp=None, depth=0):
if depth > MAX_RECURSION_DEPTH:
return
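# Respect robots.txt: skip URLs that are disallowed for all of the checked user agents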
if (
rp
and rp.can_fetch("*", url) == False
and rp.can_fetch("Googlebot", url) == False
and rp.can_fetch("Baiduspider", url) == False
):
return
save_url(url)
html, fetched_url = fetch_url(url, HEADERS)
if not html:
return
cursor.execute(
"""
UPDATE url_list SET visited = TRUE, fetched_time = CURRENT_TIMESTAMP WHERE url = ?
""",
(fetched_url,),
)
conn.commit()
links = extract_links(html, fetched_url)
for link in links:
save_url(link, fetched_url)
update_url(fetched_url, len(links))
for link in links:
crawl(link, rp=rp, depth=depth + 1)  # pass rp so robots.txt is honored at every depth
def main(seed_url, rp, sitemap=None):
if sitemap:
sitemap_urls = fetch_sitemap(sitemap)
for sitemap_url in sitemap_urls:
save_url(sitemap_url)
crawl(seed_url, rp=rp)
def get_config(key):
return os.getenv(key)
# Example usage
if __name__ == "__main__":
load_dotenv()
seed_url = get_config("SEED_URL")
rp = urobot.RobotFileParser()
rp.set_url(get_config("ROBOTS_URL"))
rp.read()
main(seed_url, rp, get_config("SITEMAP_URL"))
conn.close()

52
translate/split_source.py Normal file
View File

@ -0,0 +1,52 @@
import os
import re
def split_content(content):
sentences = re.split(r'[。!?;.!?;]', content)
segments = []
current_segment = []
current_length = 0
for sentence in sentences:
sentence_length = len(sentence)
if (len(current_segment) >= 25 or current_length + sentence_length > 1200):
segments.append(''.join(current_segment))
current_segment = []
current_length = 0
current_segment.append(sentence)
current_length += sentence_length
if current_segment:
segments.append(''.join(current_segment))
return segments
def process_files_in_directory(directory):
for filename in os.listdir(directory):
file_path = os.path.join(directory, filename)
# Only process regular files; skip directories
if os.path.isfile(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read()
segments = split_content(content)
if len(segments) > 1:
# Remove the original file
os.remove(file_path)
# Save the split segments as new files
for i, segment in enumerate(segments):
new_filename = f"{filename}_{i+1}"
new_file_path = os.path.join(directory, new_filename)
with open(new_file_path, 'w', encoding='utf-8') as new_file:
new_file.write(segment)
else:
print(f"文件 {filename} 不需要分割")
# 指定目录
directory = './source-new'
process_files_in_directory(directory)

View File

@ -0,0 +1,44 @@
from openai import OpenAI
import argparse
import os
from dotenv import load_dotenv
from tqdm import tqdm
def translate_text(text, client, model_name, temp):
messages = [
{"role": "system", "content": "User will provide some text. You need to translate the text into English and output it WITHOUT ANY ADDITIONAL INFORMATION OR EXPLANATION."},
{"role": "user", "content": text},
]
response = client.chat.completions.create(
model=model_name,
messages=messages,
temperature=temp,
)
return response.choices[0].message.content
load_dotenv()
parser = argparse.ArgumentParser()
parser.add_argument("input", type=str, help="Path to the input file")
parser.add_argument("output", type=str, help="Path to the output file")
args = parser.parse_args()
input_file = args.input
output_file = args.output
client = OpenAI(
api_key=os.getenv("API_KEY"),
base_url=os.getenv("BASE_URL"),
)
model = os.getenv("TRANSLATION_MODEL")
temp = float(os.getenv("TRANSLATION_TEMP"))
with open(input_file, "r") as f:
src_lines = f.readlines()
for line in tqdm(src_lines):
result = translate_text(line, client, model, temp)
with open(output_file, 'a') as f:
f.write(result + '\n')

View File

@ -0,0 +1,15 @@
import subprocess
from tqdm import tqdm
def translate_text(text):
command = f'argos-translate --from zh --to en "{text}"'
result = subprocess.run(command, shell=True, capture_output=True, text=True)
return result.stdout.strip()
with open("./data/src.txt", "r") as f:
src_lines = f.readlines()
for line in tqdm(src_lines):
result = translate_text(line)
with open("./data/hyp-sk-1.2.txt", 'a') as f:
f.write(result + '\n')

View File

@ -0,0 +1,42 @@
import json
import subprocess
import evaluate
from nltk.tokenize import word_tokenize
from tqdm import tqdm
bleu_cal = evaluate.load("chrf")
def translate_text(text):
command = f'argos-translate --from zh --to en "{text}"'
result = subprocess.run(command, shell=True, capture_output=True, text=True)
return result.stdout.strip()
def main():
# Load the dataset
with open('./data/1.jsonl', 'r', encoding='utf-8') as f:
data = [json.loads(line) for line in f]
translations = []
references = []
# for entry in tqdm(data):
# chinese_sentence = entry['zh']
# translated_sentence = translate_text(chinese_sentence)
# with open("./data/1-inf.txt", "a") as f:
# f.write(translated_sentence + "\n")
# translations.append(translated_sentence)
with open("./data/1-inf.txt", 'r') as f:
translations = f.readlines()
for entry in data:
english_sentence = entry['en']
references.append([english_sentence])
# Compute the score (bleu_cal actually loads the chrF metric)
bleu = bleu_cal.compute(predictions=translations, references=references)
print(bleu)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,10 @@
from googletrans import Translator
translator = Translator()
with open("./data/src.txt", "r") as f:
src_lines = f.readlines()
for line in src_lines:
result = translator.translate(line, dest='en')
with open("./data/hyp-gg-py.txt", 'a') as f:
f.write(result.text + '\n')

View File

@ -0,0 +1,19 @@
from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
def translate_text(text):
tokenizer.src_lang = "zh"
encoded_zh = tokenizer(text, return_tensors="pt")
generated_tokens = model.generate(**encoded_zh, forced_bos_token_id=tokenizer.get_lang_id("en"))
result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
return result[0]
with open("./data/src.txt", "r") as f:
src_lines = f.readlines()
for line in tqdm(src_lines):
result = translate_text(line)
with open("./data/hyp-m2m.txt", 'a') as f:
f.write(result + '\n')

View File

@ -0,0 +1,56 @@
import json
import random
import argparse
from tqdm import tqdm
# Read a JSONL file
def read_jsonl(file_path):
with open(file_path, "r", encoding="utf-8") as file:
for line in file:
yield json.loads(line)
# Randomly sample a given number of lines
def sample_lines(data, sample_size):
return random.sample(list(data), sample_size)
# Main function
def main(input_file, sample_size):
# Read the JSONL file
data = read_jsonl(input_file)
# Randomly sample lines
sampled_data = sample_lines(data, sample_size)
for item in tqdm(sampled_data):
chinese_text = item["chinese"]
english_text = item["english"]
with open("./data/src.txt", 'a') as srcf, open("./data/ref.txt", 'a') as reff:
srcf.write(chinese_text + '\n')
reff.write(english_text + '\n')
# Example invocation
if __name__ == "__main__":
# Build the command-line argument parser
parser = argparse.ArgumentParser(
description="Process a JSONL file by sampling lines and translating text."
)
# Add command-line arguments
parser.add_argument("input", type=str, help="Path to the input JSONL file")
parser.add_argument(
"--sample_size",
type=int,
default=100,
help="Number of lines to sample (default: 100)",
)
# Parse command-line arguments
args = parser.parse_args()
# Call the main function
main(args.input, args.sample_size)

View File

@ -0,0 +1,16 @@
import re
# Read the file contents
with open('ug.txt', 'r', encoding='utf-8') as file:
data = file.read()
# Regular expression that keeps Uyghur letters, Arabic digits, and common punctuation
# Uyghur (Arabic-script) letters fall in the Unicode range U+0600-U+06FF
# Arabic digits 0-9 and the punctuation marks (. , ! ? ; :) can be adjusted as needed
filtered_data = re.sub(r'[^\u0600-\u06FF0-9.,!?؛:\s]', '', data)
# Write the filtered data to a new file
with open('filtered_ug.txt', 'w', encoding='utf-8') as file:
file.write(filtered_data)
print("过滤完成,结果已保存到 filtered_ug.txt")

13
ugNMT/BPE/filter_space.py Normal file
View File

@ -0,0 +1,13 @@
import re
def replace_spaces_in_file(input_file_path, output_file_path):
with open(input_file_path, 'r', encoding='utf-8') as file:
text = file.read()
new_text = re.sub(r' +', ' ', text)
with open(output_file_path, 'w', encoding='utf-8') as file:
file.write(new_text)
# Call the function to collapse repeated spaces in the file
replace_spaces_in_file('./data/ug_texts1.txt', './data/2.txt')