Compare commits
No commits in common. "main" and "v2-model" have entirely different histories.
9  .gitignore  vendored
@@ -9,14 +9,9 @@ token_to_id.json
__pycache__
.env
.env*
translate/output*
translate/source*
translate/output
translate/source
translate/result
*.db
dataset/raw
translate/special-spiders
ugNMT/BPE/output*
ugNMT/BPE/codes
forced-alignment/segments
forced-alignment/data
forced-alignment/output.ttml
@@ -1,7 +0,0 @@
# Applying Forced Alignment to Word-by-Word Lyric Alignment

This sub-project was created to provide the AI-powered word-by-word lyrics feature for [AquaVox](https://github.com/alikia2x/aquavox).

## Roadmap

Given the lyrics and
File diff suppressed because one or more lines are too long
@@ -1,228 +0,0 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample
from tqdm import tqdm


def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
    with torch.inference_mode():
        emission, _ = model(waveform.to(device))
        token_spans = aligner(emission[0], tokenizer(transcript))
        return emission, token_spans

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle

model = bundle.get_model()
model.to(device)

tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

cc_cedict.load()

add_spaces = lambda s: ' '.join(s)

from pydub import AudioSegment

def get_audio_duration(file_path):
    """
    Read an audio file and return its duration in seconds.

    :param file_path: path to the audio file
    :return: duration of the audio file in seconds
    """
    try:
        audio = AudioSegment.from_file(file_path)
        duration_in_seconds = len(audio) / 1000.0
        return duration_in_seconds
    except Exception as e:
        print(f"Error reading audio file: {e}")
        return None

def timestamp(seconds):
    """
    Convert seconds (float) to the TTML timestamp format (HH:MM:SS.sss).

    :param seconds: seconds as a float
    :return: TTML timestamp string
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60
    milliseconds = int((seconds % 1) * 1000)
    seconds = int(seconds)

    return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"

def timestamp_inverse(ttml_timestamp):
    """
    Convert a TTML timestamp string (HH:MM:SS.sss) to seconds as a float.

    :param ttml_timestamp: TTML timestamp string
    :return: seconds as a float
    """
    parts = ttml_timestamp.split(':')
    hours = int(parts[0])
    minutes = int(parts[1])
    seconds_and_milliseconds = parts[2].split('.')
    seconds = int(seconds_and_milliseconds[0])
    milliseconds = int(seconds_and_milliseconds[1])

    total_seconds = hours * 3600 + minutes * 60 + seconds + milliseconds / 1000

    return total_seconds

import xml.etree.ElementTree as ET

import os
import re

def extract_numbers_from_files(directory):
    """
    Read the given directory, extract the numeric part of each file name,
    and return a list of those numbers.

    :param directory: directory path
    :return: list of numbers
    """
    numbers = []
    pattern = re.compile(r'line-(\d+)\.wav')

    try:
        for filename in os.listdir(directory):
            match = pattern.match(filename)
            if match:
                number = int(match.group(1))
                numbers.append(number)
    except Exception as e:
        print(f"Error reading directory: {e}")
        return None

    return numbers

class TTMLGenerator:
    def __init__(self, duration, xmlns="http://www.w3.org/ns/ttml", xmlns_ttm="http://www.w3.org/ns/ttml#metadata", xmlns_amll="http://www.example.com/ns/amll", xmlns_itunes="http://music.apple.com/lyric-ttml-internal"):
        self.tt = ET.Element("tt", attrib={
            "xmlns": xmlns,
            "xmlns:ttm": xmlns_ttm,
            "xmlns:amll": xmlns_amll,
            "xmlns:itunes": xmlns_itunes
        })
        self.head = ET.SubElement(self.tt, "head")
        self.metadata = ET.SubElement(self.head, "metadata")
        self.body = ET.SubElement(self.tt, "body", attrib={"dur": duration})
        self.div = ET.SubElement(self.body, "div")

    def add_lyrics(self, begin, end, agent, itunes_key, words):
        p = ET.SubElement(self.div, "p", attrib={
            "begin": begin,
            "end": end,
            "ttm:agent": agent,
            "itunes:key": itunes_key
        })
        for word, start, stop in words:
            span = ET.SubElement(p, "span", attrib={"begin": start, "end": stop})
            span.text = word

    def save(self, filename):
        tree = ET.ElementTree(self.tt)
        tree.write(filename, encoding="utf-8", xml_declaration=True)

duration = get_audio_duration("./data/谷雨.mp3")

# Example usage
ttml_generator = TTMLGenerator(duration=timestamp(duration))


def process_line(line_idx, start_time, total_lines):
    with open(f"./segments/line-{line_idx}.txt", "r") as f:
        text = f.read()

    waveform, sample_rate = torchaudio.load(f"./segments/line-{line_idx}.wav")

    waveform = waveform[0:1]
    resampler = Resample(orig_freq=sample_rate, new_freq=16000)
    waveform = resampler(waveform)

    text_pinyin = lazy_pinyin(text)
    text_normalized = " ".join(text_pinyin)

    transcript = text_normalized.split()
    emission, token_spans = compute_alignments(waveform, transcript)
    num_frames = emission.size(1)
    ratio = waveform.size(1) / num_frames

    words = []
    for i in range(len(token_spans)):
        spans = token_spans[i]
        x0 = start_time + int(ratio * spans[0].start) / 16000
        x1 = start_time + int(ratio * spans[-1].end) / 16000
        words.append({
            "word": text[i],
            "start": x0,
            "end": x1
        })
    # Extend each word's end to the start of the next word
    idx = 0
    for item in words:
        if idx == len(words) - 1:
            break
        item["end"] = words[idx + 1]["start"]
        idx += 1
    result = []
    for word in words:
        result.append((word["word"], timestamp(word["start"]), timestamp(word["end"])))
    return result


lines_to_process = sorted(extract_numbers_from_files("segments"))

def parse_lrc(lrc_file, audio_len):
    """Parse an LRC file and return a list of timestamp and lyric entries."""
    with open(lrc_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    lrc_data = []
    for line in lines:
        # Match the timestamp and lyric with a regular expression
        match = re.match(r'\[(\d+):(\d+\.\d+)\](.*)', line)
        if match:
            minutes = int(match.group(1))
            seconds = float(match.group(2))
            lyric = match.group(3).strip()
            lyric = lyric.replace(" ", "")
            timestamp = minutes * 60 + seconds
            lrc_data.append((lyric, timestamp))

    for i, (lyric, start_time) in enumerate(lrc_data):
        # Skip empty line
        if lyric.strip() == "":
            continue
        if i < len(lrc_data) - 1:
            end_time = lrc_data[i + 1][1]
        else:
            end_time = audio_len
        lrc_data[i] = (lyric, start_time, end_time)

    # Filter empty lines again
    lrc_data = [line for line in lrc_data if line[0].strip() != ""]

    return lrc_data

lrc_data = parse_lrc("./data/谷雨.lrc", duration)

i = 0
for line_num in tqdm(lines_to_process):
    start_time = lrc_data[i][1]
    end_time = lrc_data[i][2]
    result = process_line(line_num, start_time, len(lines_to_process))
    ttml_generator.add_lyrics(
        begin=timestamp(start_time), end=timestamp(end_time), agent="v1", itunes_key=f"L{i+1}",
        words=result
    )
    i += 1

# Save the file
ttml_generator.save("output.ttml")
@@ -1,57 +0,0 @@
from pydub import AudioSegment
import re

def parse_lrc(lrc_file):
    """Parse an LRC file and return a list of (timestamp, lyric) entries."""
    with open(lrc_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    lrc_data = []
    for line in lines:
        # Match the timestamp and lyric with a regular expression
        match = re.match(r'\[(\d+):(\d+\.\d+)\](.*)', line)
        if match:
            minutes = int(match.group(1))
            seconds = float(match.group(2))
            lyric = match.group(3).strip()
            lyric = lyric.replace(" ", "")
            timestamp = minutes * 60 + seconds
            lrc_data.append((timestamp, lyric))

    return lrc_data

def split_audio_by_lrc(audio_file, lrc_data, output_prefix):
    """Split the audio file according to the LRC data and save each segment as a separate WAV file."""
    audio = AudioSegment.from_file(audio_file)

    for i, (start_time, lyric) in enumerate(lrc_data):
        # Skip empty line
        if lyric.strip() == "":
            continue
        if i < len(lrc_data) - 1:
            end_time = lrc_data[i + 1][0]
        else:
            end_time = len(audio) / 1000  # the last lyric line runs to the end of the audio
        start_time = max(0, start_time - 0.1)  # pad 0.1 s on each side
        end_time = min(len(audio) / 1000, end_time + 0.1)
        start_time_ms = start_time * 1000
        end_time_ms = end_time * 1000

        segment = audio[start_time_ms:end_time_ms]
        output_file = f"{output_prefix}-{i+1}.wav"
        output_script = f"{output_prefix}-{i+1}.txt"
        output_time = f"{output_prefix}-{i+1}.time"
        segment.export(output_file, format="wav")
        with open(output_script, "w") as f:
            f.write(lyric)
        with open(output_time, "w") as f:
            f.write(str(start_time) + "," + str(end_time))
        print(f"Saved {output_file}")

if __name__ == "__main__":
    lrc_file = "./data/谷雨.lrc"  # path to the LRC file
    audio_file = "./data/谷雨.mp3"  # path to the audio file
    output_prefix = "segments/line"  # prefix for the output file names

    lrc_data = parse_lrc(lrc_file)
    split_audio_by_lrc(audio_file, lrc_data, output_prefix)
@@ -1,88 +0,0 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample
import xml.etree.ElementTree as ET
import argparse

from pydub import AudioSegment

def get_audio_duration(file_path):
    """
    Read an audio file and return its duration in seconds.

    :param file_path: path to the audio file
    :return: duration of the audio file in seconds
    """
    try:
        audio = AudioSegment.from_file(file_path)
        duration_in_seconds = len(audio) / 1000.0
        return duration_in_seconds
    except Exception as e:
        print(f"Error reading audio file: {e}")
        return None

def timestamp(seconds):
    """
    Convert seconds (float) to the SRT timestamp format (HH:MM:SS,sss).

    :param seconds: seconds as a float
    :return: SRT timestamp string
    """
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    seconds = seconds % 60
    milliseconds = int((seconds % 1) * 1000)
    seconds = int(seconds)

    return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}"


def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
    with torch.inference_mode():
        emission, _ = model(waveform.to(device))
        token_spans = aligner(emission[0], tokenizer(transcript))
        return emission, token_spans

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle

model = bundle.get_model()
model.to(device)

tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

cc_cedict.load()

with open("./data/扬旗鸣鼓.txt", "r") as f:
    text = f.read()

text_pinyin = lazy_pinyin(text)
text_normalized = " ".join(text_pinyin)

waveform, sample_rate = torchaudio.load("./data/扬旗鸣鼓.wav")

waveform = waveform[0:1]
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)

transcript = text_normalized.split()
emission, token_spans = compute_alignments(waveform, transcript)
num_frames = emission.size(1)

ratio = waveform.size(1) / num_frames

duration = get_audio_duration("./data/扬旗鸣鼓.wav")

for i in range(len(token_spans)):
    spans = token_spans[i]
    x0 = round(int(ratio * spans[0].start) / 16000, 3)
    x1 = round(int(ratio * spans[-1].end) / 16000, 3)
    with open("1.srt", "a") as f:
        f.write(f"{i+1}\n")
        f.write(f"{timestamp(x0)} --> {timestamp(x1)}\n")
        f.write(f"{transcript[i]}\n\n")
        f.flush()
@@ -1,84 +0,0 @@
import re

def srt_to_lrc(srt_text):
    # Match the timestamps and content with a regular expression;
    # append a trailing blank line so the final subtitle block also matches the pattern
    srt_text += '\n\n'
    pattern = re.compile(r'(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})\n(.+?)\n\n', re.DOTALL)
    matches = pattern.findall(srt_text)
    lrc_lines = []

    for start_time, end_time, content in matches:
        # Extract the highlighted character of this block
        highlight_char = re.search(r'<font color="#00ff00">(.+?)</font>', content)
        if highlight_char:
            highlight_char = highlight_char.group(1)
        else:
            continue

        # Convert the timestamps to LRC format
        _, start_minutes, start_seconds, start_milliseconds = map(int, start_time.replace(',', ':').split(':'))
        _, end_minutes, end_seconds, end_milliseconds = map(int, end_time.replace(',', ':').split(':'))

        start_time_lrc = f"{start_minutes:02d}:{start_seconds:02d}.{start_milliseconds:02d}"
        end_time_lrc = f"{end_minutes:02d}:{end_seconds:02d}.{end_milliseconds:02d}"

        # Build the LRC line
        lrc_line = f"{highlight_char}|{start_time_lrc},{end_time_lrc}"
        lrc_lines.append(lrc_line)

        # Replace any newlines in the content with spaces
        lrc_line = lrc_line.replace('\n', ' ')

    return '\n'.join(lrc_lines)

with open('./data/谷雨.srt', 'r') as f:
    srt_text = f.read()

whole = srt_text.splitlines()[2].replace('<font color="#00ff00">', '').replace('</font>', '')
whole = whole.replace(' ', '\n')
lines = whole.splitlines()

lyric_text = ""
raw_text = srt_to_lrc(srt_text)
raw_lines = raw_text.splitlines()
for line in raw_lines:
    lyric_text += line.split('|')[0]

raw_idx = 0
lines_start_chr_idx = []
for line in lines:
    start = line[0]
    end = line[-1]
    while raw_idx < len(raw_lines) and not line.startswith(raw_lines[raw_idx].split("|")[0]):
        raw_idx += 1
    lines_start_chr_idx.append(raw_idx)
lines_start_chr_idx.append(len(raw_lines) - 1)

raw_idx = 0
lines_end_chr_idx = []
for line in lines:
    start = line[0]
    end = line[-1]
    while raw_idx < len(raw_lines) and not line.endswith(raw_lines[raw_idx].split("|")[0]):
        raw_idx += 1
    lines_end_chr_idx.append(raw_idx)
lines_end_chr_idx.append(len(raw_lines) - 1)

lrc_text = ""
for i in range(len(lines_start_chr_idx) - 1):
    start = lines_start_chr_idx[i]
    end = lines_end_chr_idx[i]
    time_start = raw_lines[start].split("|")[1].split(',')[0]
    time_end = raw_lines[end].split("|")[1].split(',')[0]
    lrc_text += f"[{time_start}]{lyric_text[start:end+1]}\n[{time_end}]\n"
print(lrc_text)

lyric_len = len(lyric_text)
for i in range(len(lines_start_chr_idx) - 1):
    start = max(0, lines_start_chr_idx[i] - 1)
    end = min(lyric_len - 1, lines_end_chr_idx[i] + 1)
    time_start = raw_lines[start].split("|")[1].split(',')[0]
    time_end = raw_lines[end].split("|")[1].split(',')[0]
    lrc_text += f"[{time_start}]{lyric_text[start:end+1]}\n[{time_end}]\n"
print(lrc_text)
File diff suppressed because one or more lines are too long
@@ -1,60 +0,0 @@
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample


def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
    with torch.inference_mode():
        emission, _ = model(waveform.to(device))
        token_spans = aligner(emission[0], tokenizer(transcript))
        return emission, token_spans

# Compute average score weighted by the span length
def _score(spans):
    return sum(s.score * len(s) for s in spans) / sum(len(s) for s in spans)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from torchaudio.pipelines import MMS_FA as bundle

model = bundle.get_model()
model.to(device)

tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

cc_cedict.load()

add_spaces = lambda s: ' '.join(s)

with open("./segments/line-1.txt", "r") as f:
    text = f.read()

text_raw = add_spaces(text)
text_list = list(text)
text_pinyin = lazy_pinyin(text)
text_normalized = " ".join(text_pinyin)

waveform, sample_rate = torchaudio.load("./segments/line-1.wav")

waveform = waveform[0:1]
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)

transcript = text_normalized.split()
emission, token_spans = compute_alignments(waveform, transcript)
num_frames = emission.size(1)


print("Raw Transcript: ", text_raw)
print("Normalized Transcript: ", text_normalized)

ratio = waveform.size(1) / num_frames

for i in range(len(token_spans)):
    spans = token_spans[i]
    x0 = round(int(ratio * spans[0].start) / 16000, 3)
    x1 = round(int(ratio * spans[-1].end) / 16000, 3)
    print(f"{text[i]}: {x0}-{x1}")
@@ -1,30 +0,0 @@

import xml.etree.ElementTree as ET

class TTMLGenerator:
    def __init__(self, duration, xmlns="http://www.w3.org/ns/ttml", xmlns_ttm="http://www.w3.org/ns/ttml#metadata", xmlns_amll="http://www.example.com/ns/amll", xmlns_itunes="http://music.apple.com/lyric-ttml-internal"):
        self.tt = ET.Element("tt", attrib={
            "xmlns": xmlns,
            "xmlns:ttm": xmlns_ttm,
            "xmlns:amll": xmlns_amll,
            "xmlns:itunes": xmlns_itunes
        })
        self.head = ET.SubElement(self.tt, "head")
        self.metadata = ET.SubElement(self.head, "metadata")
        self.body = ET.SubElement(self.tt, "body", attrib={"dur": duration})
        self.div = ET.SubElement(self.body, "div")

    def add_lyrics(self, begin, end, agent, itunes_key, words):
        p = ET.SubElement(self.div, "p", attrib={
            "begin": begin,
            "end": end,
            "ttm:agent": agent,
            "itunes:key": itunes_key
        })
        for word, start, stop in words:
            span = ET.SubElement(p, "span", attrib={"begin": start, "end": stop})
            span.text = word

    def save(self, filename):
        tree = ET.ElementTree(self.tt)
        tree.write(filename, encoding="utf-8", xml_declaration=True)
@@ -1 +0,0 @@
{"idx_to_class": {"0": "weather", "1": "base64", "2": "url-encode", "3": "html-encode", "4": "ai.command", "5": "knowledge", "6": "ai.question", "7": "datetime"}, "threshold": 1.7}
@@ -36,7 +36,6 @@
"室外的温度是多少",
"达拉斯今天热不热",
"苏州现在天气怎么样",
"明天悉尼会下雨吗?",
"how's the weather",
"What's going on with the weather?",
"Can you give me an update on the weather?",
@@ -49,21 +48,21 @@
"What's the weather like right now?",
"Tell me the current weather conditions.",
"How about the weather today?",
"What's the weather looking like for the next few hours",
"Is it going to stay this way all day",
"Could you give me a brief overview of the weather",
"What's the general weather situation in our area",
"Is it cloudy or clear outside",
"What's the forecast saying for today's weather",
"Is it going to be a warm day",
"Are we expecting any storms today",
"What's the weather condition outside my window",
"Is it a typical day for this season in terms of weather",
"how's the weather now",
"What's the temperature like right now",
"Can you tell me the current temperature",
"How hot is it outside",
"What's the temperature supposed to be today",
"What's the weather looking like for the next few hours?",
"Is it going to stay this way all day?",
"Could you give me a brief overview of the weather?",
"What's the general weather situation in our area?",
"Is it cloudy or clear outside?",
"What's the forecast saying for today's weather?",
"Is it going to be a warm day?",
"Are we expecting any storms today?",
"What's the weather condition outside my window?",
"Is it a typical day for this season in terms of weather?",
"how's the weather now?",
"What's the temperature like right now?",
"Can you tell me the current temperature?",
"How hot is it outside?",
"What's the temperature supposed to be today?",
"What is the current temp outside?",
"Could you tell me the outdoor temperature?",
"Is it cold or warm outside?",
@@ -82,8 +81,8 @@
"Can you tell me the temp in the nearby area?",
"Is it below freezing outside?",
"What's the average temperature for today?",
"Is the temperature dropping or rising",
"What should I wear considering the temperature"
"Is the temperature dropping or rising?",
"What should I wear considering the temperature?"
],
"base64": [
"请将数据使用base64编码",
@@ -111,16 +110,17 @@
"解码 base64",
"Please encode this data with base64:",
"I need to encode the following data in base64",
"Could you encode this string using base64",
"Could you encode this string using base64?",
"Convert this data to b64 encoding",
"I want to encode this information with base64",
"Help me encode this in base32",
"Can you encode this data to base64 format",
"Can you encode this data to base64 format?",
"b64 encode",
"base64 encode",
"encode base64",
"base 64 encode online"
],

"url-encode": [
"编码 url",
"URL部分需要编码",
@@ -145,6 +145,7 @@
"url decoder",
"URL encoder"
],

"html-encode": [
"请编码HTML实体",
"文本转为HTML实体",
@@ -185,6 +186,7 @@
"html &nbsp; conversion",
"html nbsp meaning"
],

"ai.command": [
"写一个TypeScript的HelloWorld代码",
"检查以下内容的语法和清晰度",
@@ -235,11 +237,11 @@
"help me learn chinese",
"how to let the screen reader automatically focused to an newly poped up element in the web development",
"summarize following text:",
"Is there anything wrong with this code or can it be simplified",
"Is there anything wrong with this code or can it be simplified?",
"generate a Python script that prints 'Hello, World!'",
"Can you proofread this essay for grammar and punctuation errors",
"Can you proofread this essay for grammar and punctuation errors?",
"Create a list of ten example sentences for the word 'serendipity.'",
"Can you reformat this JSON to be more readable",
"Can you reformat this JSON to be more readable?",
"Suggest a creative title for my blog post about healthy eating.",
"Refactor this JavaScript function to make it more efficient.",
"Help me practice French: provide a sentence with a missing word that I can guess.",
@@ -247,15 +249,15 @@
"Summarize this news article for me.",
"Can you review this code snippet for potential security vulnerabilities?",
"Generate a SQL query to find all users who signed up in the last 30 days.",
"Can you translate this paragraph into Spanish",
"Can you translate this paragraph into Spanish?",
"Create a flowchart based on the following process description.",
"Write a Python function to calculate the factorial of a number.",
"Provide a detailed explanation of how to implement OAuth2 in a web application.",
"Can you optimize this image for faster loading on a website",
"Can you optimize this image for faster loading on a website?",
"Suggest some catchy taglines for a new mobile app focused on fitness.",
"Write a Bash script to back up my documents folder daily.",
"Help me draft an email to request a meeting with a potential client.",
"Can you convert this Markdown document into HTML",
"Can you convert this Markdown document into HTML?",
"Generate a Python script that scrapes data from a specified website.",
"Can you find the synonyms of the word 'meticulous'?",
"Write a SQL query to join two tables based on a common column.",
@@ -265,57 +267,31 @@
"Can you assist me in learning Japanese?",
"How can I make an alert box appear when a user clicks a button on a webpage?",
"Summarize this research paper into bullet points.",
"Can you check if there are any logical errors in this algorithm?",
"请一步一步计算找到函数f(x)=U^2*x/(R+x)^2的顶点坐标。",
"如何理解transformer自注意力机制中的Q,K,V?它们分别代表什么?",
"帮我写一封求职信。先询问我的教育背景、技能和经验。",
"总结这篇论文",
"写一份10人晚宴的菜单",
"写一篇博客",
"写一段演讲稿"
"Can you check if there are any logical errors in this algorithm?"
],
"knowledge": [

"ai.question": [
"你认为哪个框架最适合性能敏感的项目?",
"什么是后量子密码学?",
"什么是密钥派生函数",
"什么是线性代数?",
"量子计算的特点是什么",
"哈希函数的作用?",
"什么是微积分?",
"什么是区块链技术",
"What is post-quantum cryptography",
"What is a key derivation function?",
"What is Linear Algebra?",
"What is the main use of linear algebra in computer science",
"What is quantum computing",
"What is a hash function",
"What is calculus",
"什么是站点隔离?",
"What is blockchain technology?",
"BLEU 是什么",
"黎巴嫩在哪",
"什么是转义字符",
"MixAlpha售价多少",
"什么是神经机器翻译",
"什么是月食",
"什么是人工智能",
"什么是F1-score"
],
"ai.question": [
"人工智能真的有智力吗",
"你认为哪个框架最适合性能敏感的项目?",
"线性代数在计算机科学中的主要用途是什么?",
"我应该使用哪个IDE来编写Go语言?",
"Go vs Java vs Kotlin,哪个适合后端",
"哪种编程语言最适合数据分析",
"什么是量子计算",
"什么是哈希函数?",
"什么是微积分?",
"机器学习在金融中的主要应用有哪些?",
"写Python代码最好的文本编辑器是哪个?",
"Python vs R vs Julia,哪个更适合数据科学?",
"监督学习和无监督学习的关键区别是什么?",
"数据库在Web应用程序中的作用是什么",
"什么是区块链技术",
"使用Docker进行应用程序部署的优势是什么?",
"哪个云服务提供商提供最好的AI工具?",
"加密是如何工作的",
"负载均衡器在网络架构中的目的是什么",
"加密是如何工作的?",
"负载均衡器在网络架构中的目的是什么?",
"机器学习和深度学习有什么区别",
"软件工程中最常见的设计模式有哪些",
"神经网络是如何学习的",
@@ -324,22 +300,31 @@
"Rust编程语言的关键特性是什么?",
"HTTP和HTTPS有什么区别",
"使用像Git这样的版本控制系统有什么优势?",
"什么是'边缘计算'的概念",
"哪种编程语言最适合构建移动应用?",
"关系数据库和NoSQL数据库有什么不同?",
"算法在计算机科学中的重要性是什么",
"算法在计算机科学中的重要性是什么?",
"API在软件开发中的作用是什么",
"保护Web应用程序的最佳实践是什么",
"虚拟现实和增强现实有什么区别?",
"机器翻译是如何工作的?",
"Which framework do you think is the most suitable for performance sensitive projects?",
"What is post-quantum cryptography",
"What is a key derivation function?",
"What is Linear Algebra?",
"What is the main use of linear algebra in computer science",
"which IDE should I use for Go",
"Go vs Java vs Koltin, which for a backend",
"Which programming language is best suited for data analysis?",
"What are the main applications of machine learning in finance",
"Which text editor is best for writing Python code",
"Python vs R vs Julia, which is better for data science",
"What are the key differences between supervised and unsupervised learning",
"What is quantum computing?",
"What is a hash function?",
"What is calculus?",
"What are the main applications of machine learning in finance?",
"Which text editor is best for writing Python code?",
"Python vs R vs Julia, which is better for data science?",
"What are the key differences between supervised and unsupervised learning?",
"What is the role of a database in a web application?",
"What is blockchain technology?",
"What are the advantages of using Docker for application deployment?",
"Which cloud service provider offers the best AI tools?",
"How does encryption work?",
@@ -347,20 +332,19 @@
"What is the difference between machine learning and deep learning?",
"What are the most common design patterns in software engineering?",
"How does a neural network learn?",
"What is the main benefit of using a microservices architecture",
"What is the difference between a compiler and an interpreter",
"What are the key features of the Rust programming language",
"What is the difference between HTTP and HTTPS",
"What are the advantages of using a version control system like Git",
"What is the concept of 'edge computing'",
"Which programming language is best for building mobile apps",
"How does a relational database differ from a NoSQL database",
"What is the importance of algorithms in computer science",
"What is the role of an API in software development",
"What is the main benefit of using a microservices architecture?",
"What is the difference between a compiler and an interpreter?",
"What are the key features of the Rust programming language?",
"What is the difference between HTTP and HTTPS?",
"What are the advantages of using a version control system like Git?",
"What is the concept of 'edge computing'?",
"Which programming language is best for building mobile apps?",
"How does a relational database differ from a NoSQL database?",
"What is the importance of algorithms in computer science?",
"What is the role of an API in software development?",
"What are the best practices for securing a web application?",
"What is the difference between virtual reality and augmented reality?",
"How does machine translation work?",
"MBTI有科学依据吗?"
"How does machine translation work?"
],
"datetime": ["明天周几", "16天后是几号", "一年前的今天是星期几"]
}
@@ -28,7 +28,7 @@
"metadata": {},
"outputs": [],
"source": [
"model_name=\"Qwen/Qwen2.5-3B\""
"model_name=\"microsoft/Phi-3-mini-4k-instruct\""
]
},
{
@@ -37,10 +37,17 @@
"id": "c1de25fc-e90a-425b-8520-3a57fa534b94",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "38137fc55ad24a9785ecbe1978bbc605",
"model_id": "1aeb02c7c8084b1eb1b8e3178882fd60",
"version_major": 2,
"version_minor": 0
},
@@ -69,122 +76,6 @@
"vocab = tokenizer.get_vocab()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "21214ff4-018d-4230-81b9-331ebb42773b",
"metadata": {},
"outputs": [],
"source": [
"def bytes_to_unicode():\n",
"    \"\"\"\n",
"    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control\n",
"    characters the bpe code barfs on.\n",
"\n",
"    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab\n",
"    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for\n",
"    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup\n",
"    tables between utf-8 bytes and unicode strings.\n",
"    \"\"\"\n",
"    bs = (\n",
"        list(range(ord(\"!\"), ord(\"~\") + 1)) + list(range(ord(\"¡\"), ord(\"¬\") + 1)) + list(range(ord(\"®\"), ord(\"ÿ\") + 1))\n",
"    )\n",
"    cs = bs[:]\n",
"    n = 0\n",
"    for b in range(2**8):\n",
"        if b not in bs:\n",
"            bs.append(b)\n",
"            cs.append(2**8 + n)\n",
"            n += 1\n",
"    cs = [chr(n) for n in cs]\n",
"    return dict(zip(bs, cs))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "cbc23d2d-985b-443a-83ee-c2286046ad5e",
"metadata": {},
"outputs": [],
"source": [
"btu=bytes_to_unicode()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "4a99fa07-4922-4d8d-9c28-2275bf9cb8df",
"metadata": {},
"outputs": [],
"source": [
"utb = reversed_dict = {value: key for key, value in btu.items()}"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "cb218ea7-50c7-4bb8-aa7f-0ee85da76147",
"metadata": {},
"outputs": [],
"source": [
"result = tokenizer.convert_ids_to_tokens([104307])[0]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "2dcb332a-cba9-4a14-9486-4e1ff6bd3dba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"å\n",
"229\n",
"¤\n",
"164\n",
"©\n",
"169\n",
"æ\n",
"230\n",
"°\n",
"176\n",
"Ķ\n",
"148\n"
]
}
],
"source": [
"decoded=b\"\"\n",
"for chr in result:\n",
"    print(chr)\n",
"    if chr in utb:\n",
"        print(utb[chr])\n",
"        decoded+=bytes([utb[chr]])"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "b1bf1289-2cab-4a97-ad21-b2d24de6d688",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'天气'"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"decoded.decode(\"utf-8\", errors='replace')"
]
},
{
"cell_type": "code",
"execution_count": 5,
@@ -204,7 +95,7 @@
"metadata": {},
"outputs": [],
"source": [
"DIMENSIONS = 96"
"DIMENSIONS = 128"
]
},
{
@@ -277,17 +168,11 @@
"import struct\n",
"with open(\"token_embeddings.bin\", \"wb\") as f:\n",
"    for token_id in range(len(vocab)):\n",
"        # Convert the vector to half-precision floats and save it\n",
"        f.write(struct.pack('96e', *reduced_embeddings[token_id].astype(np.float16)))\n"
"        # Write token id (2 bytes)\n",
"        f.write(struct.pack('H', token_id))\n",
"        # Write embedding vector (128 float numbers)\n",
"        f.write(struct.pack('128f', *reduced_embeddings[token_id]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "511a7cc4-1b8c-468c-b2a0-16dc6d74ab44",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -306,7 +191,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.9.19"
}
},
"nbformat": 4,
@@ -4,350 +4,5 @@
"我爱你",
"嘿嘿嘿诶嘿",
"为什么",
"拼多多",
"machine translation",
"trustrank",
"中文词典",
"bin screen linux",
"\"TinyBERT",
"iconify",
"反义词 英文",
"referer",
"watchos uiscreen",
"张鑫旭",
"google first result",
"flutter text align center",
"ASR model",
"real time whisper",
"千樱凛",
"马嘉祺",
"flutter widget catalog",
"flutter BottomNavigationBar left",
"flutter tab indent vscode",
"react native 用 expo 吗",
"latest monorepo tool",
"\"vite\" \"abortController\" is not defined",
"vim comment lines",
"Error: unable to get issuer certificate",
"uuidv4",
"npm semver",
"react polyfill vite",
"vibrance",
"I can eat glass, it doesn't hurt me \"japanese\"",
"I can swallow glass without any harm to myself",
"copilot pricing",
"vim close window",
"sensors macos command",
"智乃",
"pypi wikipedia",
"tesseract macos m1",
"rag prompt template",
"英国 破产",
"bewlybewly",
"safari-web-extension-converter",
"starcoder",
"open source web search for ai",
"gpt4o mini tokenizer",
"gpt4o tokenizer",
"reverse dns lookup linux",
"online ping",
"termux",
"802.11 table",
"optimize",
"集群",
"chrome us",
"transflective",
"ielts toefl",
"react router",
"摇曳露营 萌娘百科",
"isrc",
"apple-system",
"-apple-system",
"css clip path animation",
"can i use relative path in og image",
"GitSora",
"matrix im",
"test your vocabulary",
"boarding pass",
"函数签名",
"类型谓词",
"barcode",
"智能",
"threejs 入门",
"南亚语系",
"linux user's computer be like",
"apple a16 显微图",
"dallas",
"恶魔 英文",
"Rime meaning",
"adobe media encoder macos download",
"mp4 transparency",
"webkit",
"chromium",
"献血",
"软件强制更新",
"If you don’t agree with its politics views, Notepad+ + will add random characters in your source code.",
"Unmerged paths",
"字数统计",
"Use build.rollupOptions.output.manualChunks to improve chunking: https://rollupjs.org/configuration-options/#output-manualchunks",
"世界人权宣言",
"latex percent",
"chord in keyboard",
"Google is trying to kill the Open Web.",
"silo'd",
"swiftui 数组倒数访问",
"swiftui link to another view",
"fizzbuzz",
"AppDelegate watchos",
"Cannot find type 'UIApplicationDelegate' in scope",
"swiftui web image",
"spammer",
"swiftui text",
"钢琴",
"disable webgl chrome",
"online uuid",
"cp show progress",
"易容术",
"fulilian",
"cargo",
"wordle",
"mismatch",
"btc",
"squelch",
"psql show table structure",
"let padding don't effect when empty",
"take over the world meaning",
"brain teasers",
"Google flight API",
"square symbol",
"sill",
"nextjs layout per page",
"UA 550 umol/L",
"react production promotion page",
"jupyter notebook",
"wth meaning",
"glove词向量",
"google suggestion relevance",
"YouTube advertising income",
"PKI",
"next client only component",
"nextjs use client",
"nextjs docker tailwind not working",
"k8s",
"Logistic Regression",
"氯化钾注射死刑",
"icloud photo loss",
"芙宁娜 水上行走",
"vector design tool",
"netizen",
"framework or next js documentation",
"csync",
"next js",
"后量子正向保密",
"nip05",
"Sora技术原理",
"wasm效率",
"switch code",
"online IPA pronunciation",
"pnpm global adir",
"如何搜索",
"1999 抽卡期望",
"swiftui background blur",
"chrome macos fullscreen hide",
"中英文空格自动",
"ios 旁白 屏幕识别",
"ios 旁白 转子",
"http 404",
"yaml缩进",
"counter generator github",
"git 服务器提供远端仓库",
"ipfs companion",
"supervisor config",
"SSO",
"slot embedding",
"sql show tables",
"The request signature we calculated does not match the signature you provided. Check your Secret Access Key and signing method.",
"icloud.com,cn",
"VuePress",
"parser",
"stackoverflow statistics",
"sd xl",
"Rollup failed to resolve import \"workbox-precaching\" from",
"dep",
"Cannot find module estree-walker.js docker",
"nuxt run",
"base58解码",
"cga",
"vscode",
"vscode",
"silicon",
"macos m1 linux",
"预处理 后处理",
"is vp9 opensource",
"Alice Blu",
"失控玩家",
"kv数据库",
"redis 持久化",
"firefox disable outline",
"cd -2",
"IM application",
"2021国产电影",
"youtube chat overlay obs",
"obs add clock",
"Z is not defined nuxt",
"safari ios debug",
"safari debug",
"chat",
"nuxt plugin inject",
"twitch",
"obs 绿幕",
"gnupg",
"kde plasma wallpaper engine",
"Plasma",
"dns over https",
"localforage缺点",
"watchOS 10",
"noun of repeat",
"微信输入法",
"行业报告",
"keepass",
"platform",
"steam",
"java proxy",
"0 design",
"cefr word level list",
"precipitation meaning",
"international school of lausanne",
"Vim Uganda",
"抖音 推荐算法",
"Meta NNLO",
"windbg dump分析",
"web image fft",
"GPT-4 Pricing",
"GPT-4",
"Scala",
"tauri教程",
"asyncio.create_task用法",
"H5 滚动到底部",
"microsoft copilot",
"枫丹文字",
"brew pip",
"TS7016: Could not find a declaration file for module react .",
"fastapi websocket",
"kazv",
"The Type 孔雀计划",
"第一个图形操作系统",
"娱乐 诞生",
"ffmpeg 音频封面",
"Jean-Loup Gailly",
"Linux用户软件位置",
"\"ubuntu\" 平滑滚动",
"python range函数",
"KMP",
"sd 8gen2 GPU GFLOPS",
"mac语音输入法",
"openai translate",
"蔚蓝档案 初始抽卡",
"free custom domain email",
"洛天依",
"b站 频道页Tab 跳转",
"URL 重定向预览",
"计算机",
"sololearn",
"PoS机制 通俗解释",
"google search cost",
"bos s3",
"react 打包",
"useeffect 用法",
"ts 字典类型",
"vscode 字典单词自动补全插件",
"componentwillupdate",
"iPad Mini 2",
"use-immer",
"reducer 和 context",
"mint",
"Elementary OS",
"google科技新闻",
"iCloud mail \"\"-9002\"\"",
"氢氧化铁胶体制备",
"react native 视频处理",
"四川 2023 高考 复旦大学 分数线",
"哑铃弯举",
"m2 ultra",
"电池循环计数 site:apple.com",
"相机发明时间",
"冯诺依曼结构",
"哈佛架构",
"nodejs 后端",
"34.5M€ to CN¥",
"NLP 实体关注",
"monkey",
"react 快捷键监听",
"mac 好看的电子书阅读器",
"新闻",
"在线字体编辑器",
"ars technica",
"genshin 4.1 release time",
"swift device activity report",
"swiftui tabview background",
"swiftui text space",
"apple inc. wikipedia",
"how long does it take Google to return the results",
"云原神 web",
"支持homekit的空调",
"内核隔离",
"海祇岛解密",
"swiftui Textfield",
"xcode",
"qq 链接",
"M1 推出时间",
"USB-IF",
"nvchat",
"P1% FPS",
"react i18next 当前语言",
"js 获取语言",
"MulType",
"b站平均使用时间",
"pip 阿里源",
"ip info",
"graphjet",
"金融思维",
"C#写入文件",
"Last Day Sinista M",
"在 系统 位置 xcode select 找 不 到 SDK",
"Error: Could not find a valid Xcode app bundle at '/Library/Developer/CommandLineTools'. Please update your Apple SDK location in Visual Studio's preferences (Projects > SDK Locations > Apple > Apple SDK). (UniBattery)",
".NET能做什么",
"could i give no tip ",
"miami university of ohio",
"方正颜宋",
"中文 标题字体",
"聚典平台",
"62 basic words for a language",
"procrastination meaning",
"Lingbe",
"娱乐至死",
"macOS 外接显示器渲染",
"白玉袖",
"SwiftUI入门",
"html插入其它网页",
"捆绑 小说",
"apple music 无损下载",
"一miumiu 赐予",
"macos markdown",
"safari 开发者工具",
"\"百合\" \"武侠\" \"国漫\"",
"epub 格式详解",
"chrome 隐藏滚动条",
"发宽空格",
"U+200A",
"无性人",
"Spotify",
"禾念",
"how to pronounce Lorem ipsum",
"言和为什么不是男孩子",
"浏览器主页",
"react",
"Tailwindcss react 扩展怎么用",
"Prettier 扩展怎么用",
"linter\""
"拼多多"
]
575  intention-classify/train.ipynb  Normal file
@@ -0,0 +1,575 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "a6a3195f-d099-4bf4-846f-51f403954818",
"metadata": {},
"source": [
"# sparkastML: Training the Intention Classification Model\n",
"\n",
"This is the model we use for intent recognition, using a **CNN architecture** and an **Energy-based Model** to implement OSR (Open-set Recognition).\n",
"\n",
"In this case, **positive samples** refer to data that can be classified into an existing class, while **negative samples** are those that do not belong to any existing class.\n",
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "bddcdbb2-ccbc-4027-a38f-09c61ac94984",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from sklearn.model_selection import train_test_split\n",
"from transformers import AutoTokenizer, AutoModel\n",
"import torch\n",
"import numpy as np\n",
"from scipy.spatial.distance import euclidean\n",
"from scipy.stats import weibull_min\n",
"from sklearn.preprocessing import normalize\n",
"import torch.nn.functional as F\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "d3a0e10f-9bc3-44c7-a109-786dd5cd25ea",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
}
],
"source": [
"model_name=\"microsoft/Phi-3-mini-4k-instruct\"\n",
"DIMENSIONS = 128\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "markdown",
"id": "1ae14906-338d-4c99-87ed-bb1acd22b295",
"metadata": {},
"source": [
"## Load Data\n",
"\n",
"We load the data from `data.json` and the negative samples from `noise.json`.\n",
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a206071c-ce4e-4de4-b936-bfc70d13708a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/25/gdz0c30x3mg1dj9qkwz0ch4w0000gq/T/ipykernel_6446/1697839999.py:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /private/var/folders/nz/j6p8yfhx1mv_0grj5xl4650h0000gp/T/abs_69nk78ncaj/croot/pytorch_1669252638507/work/torch/csrc/utils/tensor_new.cpp:204.)\n",
"  embeddings = torch.tensor(embeddings)\n"
]
}
],
"source": [
"# Load data\n",
"with open('data.json', 'r') as f:\n",
"    data = json.load(f)\n",
"\n",
"# Create map: class to index\n",
"class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}\n",
"idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}\n",
"\n",
"# Preprocess data, convert sentences to the format of (class idx, embedding)\n",
"def preprocess_data(data, embedding_map, tokenizer, max_length=64):\n",
"    dataset = []\n",
"    for label, sentences in data.items():\n",
"        for sentence in sentences:\n",
"            # Tokenize the sentence and convert tokens to embedding vectors\n",
"            tokens = tokenizer.tokenize(sentence)\n",
"            token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
"            embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
"            embeddings = torch.tensor(embeddings)\n",
"            dataset.append((class_to_idx[label], embeddings))\n",
"    return dataset\n",
"\n",
"# Load embedding map\n",
"embedding_map = torch.load('token_id_to_reduced_embedding.pt')\n",
"\n",
"# Get preprocessed dataset\n",
"dataset = preprocess_data(data, embedding_map, tokenizer)\n",
"\n",
"# Train-test split\n",
"train_data, val_data = train_test_split(dataset, test_size=0.2, random_state=42)\n",
"\n",
"class TextDataset(Dataset):\n",
"    def __init__(self, data):\n",
"        self.data = data\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.data[idx]\n",
"\n",
"    def collate_fn(self, batch):\n",
"        labels, embeddings = zip(*batch)\n",
"        labels = torch.tensor(labels)\n",
"        embeddings = pad_sequence(embeddings, batch_first=True)\n",
"        return labels, embeddings\n",
"\n",
"train_dataset = TextDataset(train_data)\n",
"val_dataset = TextDataset(val_data)\n",
"\n",
"train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn)\n",
"val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, collate_fn=val_dataset.collate_fn)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9adbe9b8-a2d2-4e1d-8620-457ed0e02fe6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"class NegativeSampleDataset(Dataset):\n",
"    def __init__(self, negative_samples):\n",
"        \"\"\"\n",
"        negative_samples: List or array of negative sample embeddings or raw text\n",
"        \"\"\"\n",
"        self.samples = negative_samples\n",
"\n",
"    def __len__(self):\n",
"        return len(self.samples)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        return self.samples[idx]\n",
"\n",
"    def collate_fn(self, batch):\n",
"        embeddings = pad_sequence(batch, batch_first=True)\n",
"        return embeddings\n",
"\n",
"with open('noise.json', 'r') as f:\n",
"    negative_samples_list = json.load(f)\n",
"\n",
"negative_embedding_list = []\n",
"\n",
"for sentence in negative_samples_list:\n",
"    tokens = tokenizer.tokenize(sentence)\n",
"    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
"    embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]\n",
"    embeddings = torch.tensor(embeddings)\n",
"    negative_embedding_list.append(embeddings)\n",
"\n",
"negative_dataset = NegativeSampleDataset(negative_embedding_list)\n",
"negative_loader = DataLoader(negative_dataset, batch_size=24, shuffle=True, collate_fn=negative_dataset.collate_fn)\n"
]
},
{
"cell_type": "markdown",
"id": "600febe4-2484-4aad-90a1-2bc821fdce1a",
"metadata": {},
"source": [
"## Implementing the Model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "adf624ac-ad63-437b-95f6-b02b7253b91e",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class TextCNN(nn.Module):\n",
"    def __init__(self, input_dim, num_classes):\n",
"        super(TextCNN, self).__init__()\n",
"        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=DIMENSIONS, kernel_size=3, padding=1)\n",
"        self.conv2 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=4, padding=1)\n",
"        self.conv3 = nn.Conv1d(in_channels=DIMENSIONS, out_channels=DIMENSIONS, kernel_size=5, padding=2)\n",
"\n",
"        self.bn1 = nn.BatchNorm1d(DIMENSIONS)\n",
"        self.bn2 = nn.BatchNorm1d(DIMENSIONS)\n",
"        self.bn3 = nn.BatchNorm1d(DIMENSIONS)\n",
"\n",
"        self.dropout = nn.Dropout(0.5)\n",
"        self.fc = nn.Linear(DIMENSIONS * 3, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        x = x.permute(0, 2, 1)  # Change the input shape to (batch_size, embedding_dim, seq_length)\n",
"\n",
"        x1 = F.relu(self.bn1(self.conv1(x)))\n",
"        x1 = F.adaptive_max_pool1d(x1, output_size=1).squeeze(2)\n",
"\n",
"        x2 = F.relu(self.bn2(self.conv2(x)))\n",
"        x2 = F.adaptive_max_pool1d(x2, output_size=1).squeeze(2)\n",
"\n",
"        x3 = F.relu(self.bn3(self.conv3(x)))\n",
"        x3 = F.adaptive_max_pool1d(x3, output_size=1).squeeze(2)\n",
"\n",
"        x = torch.cat((x1, x2, x3), dim=1)\n",
"        x = self.dropout(x)\n",
"        x = self.fc(x)\n",
"        return x\n",
"\n",
"# Initialize model\n",
"input_dim = DIMENSIONS\n",
"num_classes = len(class_to_idx)\n",
"model = TextCNN(input_dim, num_classes)\n"
]
},
{
"cell_type": "markdown",
"id": "2750e17d-8a60-40c7-851b-1a567d0ee82b",
"metadata": {},
"source": [
"## Energy-based Models\n",
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "a2d7c920-07d2-4d14-9cef-e2101b7a2ceb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def energy_score(logits):\n",
|
||||
" # Energy score is minus logsumexp\n",
|
||||
" return -torch.logsumexp(logits, dim=1)\n",
|
||||
"\n",
|
||||
"def generate_noise(batch_size, seq_length ,input_dim, device):\n",
|
||||
" # Generate a Gaussian noise\n",
|
||||
" return torch.randn(batch_size, seq_length, input_dim).to(device)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "904a60e4-95a0-4f7b-ad45-a7d8d0ac887d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Training"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "19acb5bf-00b1-47d4-ad25-a13c6be09f65",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch [1/50], Loss: 12.5108\n",
|
||||
"Epoch [2/50], Loss: 10.7305\n",
|
||||
"Epoch [3/50], Loss: 10.2943\n",
|
||||
"Epoch [4/50], Loss: 9.9350\n",
|
||||
"Epoch [5/50], Loss: 9.7991\n",
|
||||
"Epoch [6/50], Loss: 9.6443\n",
|
||||
"Epoch [7/50], Loss: 9.4762\n",
|
||||
"Epoch [8/50], Loss: 9.4637\n",
|
||||
"Epoch [9/50], Loss: 9.3025\n",
|
||||
"Epoch [10/50], Loss: 9.1719\n",
|
||||
"Epoch [11/50], Loss: 9.0632\n",
|
||||
"Epoch [12/50], Loss: 8.9741\n",
|
||||
"Epoch [13/50], Loss: 8.8487\n",
|
||||
"Epoch [14/50], Loss: 8.6565\n",
|
||||
"Epoch [15/50], Loss: 8.5830\n",
|
||||
"Epoch [16/50], Loss: 8.4196\n",
|
||||
"Epoch [17/50], Loss: 8.2319\n",
|
||||
"Epoch [18/50], Loss: 8.0655\n",
|
||||
"Epoch [19/50], Loss: 7.7140\n",
|
||||
"Epoch [20/50], Loss: 7.6921\n",
|
||||
"Epoch [21/50], Loss: 7.3375\n",
|
||||
"Epoch [22/50], Loss: 7.2297\n",
|
||||
"Epoch [23/50], Loss: 6.8833\n",
|
||||
"Epoch [24/50], Loss: 6.8534\n",
|
||||
"Epoch [25/50], Loss: 6.4557\n",
|
||||
"Epoch [26/50], Loss: 6.1365\n",
|
||||
"Epoch [27/50], Loss: 5.8558\n",
|
||||
"Epoch [28/50], Loss: 5.5030\n",
|
||||
"Epoch [29/50], Loss: 5.1604\n",
|
||||
"Epoch [30/50], Loss: 4.7742\n",
|
||||
"Epoch [31/50], Loss: 4.5958\n",
|
||||
"Epoch [32/50], Loss: 4.0713\n",
|
||||
"Epoch [33/50], Loss: 3.8872\n",
|
||||
"Epoch [34/50], Loss: 3.5240\n",
|
||||
"Epoch [35/50], Loss: 3.3115\n",
|
||||
"Epoch [36/50], Loss: 2.5667\n",
|
||||
"Epoch [37/50], Loss: 2.6709\n",
|
||||
"Epoch [38/50], Loss: 1.8075\n",
|
||||
"Epoch [39/50], Loss: 1.6654\n",
|
||||
"Epoch [40/50], Loss: 0.4622\n",
|
||||
"Epoch [41/50], Loss: 0.4719\n",
|
||||
"Epoch [42/50], Loss: -0.4037\n",
|
||||
"Epoch [43/50], Loss: -0.9405\n",
|
||||
"Epoch [44/50], Loss: -1.7204\n",
|
||||
"Epoch [45/50], Loss: -2.4124\n",
|
||||
"Epoch [46/50], Loss: -3.0032\n",
|
||||
"Epoch [47/50], Loss: -2.7123\n",
|
||||
"Epoch [48/50], Loss: -3.6953\n",
|
||||
"Epoch [49/50], Loss: -3.7212\n",
|
||||
"Epoch [50/50], Loss: -3.7558\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import torch.optim as optim\n",
|
||||
"\n",
|
||||
"criterion = nn.CrossEntropyLoss()\n",
|
||||
"optimizer = optim.Adam(model.parameters(), lr=8e-4)\n",
|
||||
"\n",
|
||||
"from torch.utils.tensorboard import SummaryWriter\n",
|
||||
"import tensorboard\n",
|
||||
"writer = SummaryWriter()\n",
|
||||
"\n",
|
||||
"def train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=10):\n",
|
||||
" model.train()\n",
|
||||
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
" model.to(device)\n",
|
||||
" \n",
|
||||
" negative_iter = iter(negative_loader)\n",
|
||||
" \n",
|
||||
" for epoch in range(num_epochs):\n",
|
||||
" total_loss = 0\n",
|
||||
" for batch_idx, (labels, embeddings) in enumerate(train_loader):\n",
|
||||
" embeddings = embeddings.to(device)\n",
|
||||
" labels = labels.to(device)\n",
|
||||
" \n",
|
||||
" batch_size = embeddings.size(0)\n",
|
||||
" \n",
|
||||
" # ---------------------\n",
|
||||
" # 1. Positive sample\n",
|
||||
" # ---------------------\n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" outputs = model(embeddings) # logits from the model\n",
|
||||
" \n",
|
||||
" class_loss = criterion(outputs, labels)\n",
|
||||
" \n",
|
||||
" # Energy of positive sample\n",
|
||||
" known_energy = energy_score(outputs)\n",
|
||||
" energy_loss_known = known_energy.mean()\n",
|
||||
" \n",
|
||||
" # ------------------------------------\n",
|
||||
" # 2. Negative sample - Random Noise\n",
|
||||
" # ------------------------------------\n",
|
||||
" noise_embeddings = torch.randn_like(embeddings).to(device)\n",
|
||||
" noise_outputs = model(noise_embeddings)\n",
|
||||
" noise_energy = energy_score(noise_outputs)\n",
|
||||
" energy_loss_noise = F.relu(1 - noise_energy).mean() # For the energy of noise, bigger is better \n",
|
||||
" \n",
|
||||
" # ------------------------------------\n",
|
||||
" # 3. Negative sample - custom corpus\n",
|
||||
" # ------------------------------------\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" negative_samples = next(negative_iter)\n",
|
||||
" except StopIteration:\n",
|
||||
" negative_iter = iter(negative_loader)\n",
|
||||
" negative_samples = next(negative_iter)\n",
|
||||
" negative_samples = negative_samples.to(device)\n",
|
||||
" negative_outputs = model(negative_samples)\n",
|
||||
" negative_energy = energy_score(negative_outputs)\n",
|
||||
" energy_loss_negative = F.relu(1 - negative_energy).mean() # For the energy of noise, bigger is better \n",
|
||||
" \n",
|
||||
" # -----------------------------\n",
|
||||
" # 4. Overall Loss calculation\n",
|
||||
" # -----------------------------\n",
|
||||
" total_energy_loss = energy_loss_known + energy_loss_noise + energy_loss_negative\n",
|
||||
" total_loss_batch = class_loss + total_energy_loss * 0.1 + 10\n",
|
||||
"\n",
|
||||
" writer.add_scalar(\"Engergy Loss\", total_energy_loss, epoch)\n",
|
||||
" writer.add_scalar(\"Loss\", total_loss_batch, epoch)\n",
|
||||
" writer.add_scalar(\"Norm Loss\", torch.exp(total_loss_batch * 0.003) * 10 , epoch)\n",
|
||||
" \n",
|
||||
" total_loss_batch.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" \n",
|
||||
" total_loss += total_loss_batch.item()\n",
|
||||
" \n",
|
||||
" avg_loss = total_loss / len(train_loader)\n",
|
||||
" print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')\n",
|
||||
"\n",
|
||||
"train_energy_model(model, train_loader, negative_loader, criterion, optimizer, num_epochs=50)\n",
|
||||
"writer.flush()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "e6d29558-f497-4033-8488-169bd25ce881",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Evalutation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "472702e6-db4a-4faa-9e92-7510e6eacbb1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ENERGY_THRESHOLD = -3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "d3a4fef8-37ab-45c8-b2b1-9fc8bdffcffd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy: 0.9315\n",
|
||||
"Precision: 1.0000\n",
|
||||
"Recall: 0.9254\n",
|
||||
"F1 Score: 0.9612\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support\n",
|
||||
"\n",
|
||||
"def evaluate_energy_model(model, known_loader, unknown_loader, energy_threshold):\n",
|
||||
" model.eval()\n",
|
||||
" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
|
||||
" \n",
|
||||
" all_preds = []\n",
|
||||
" all_labels = []\n",
|
||||
" \n",
|
||||
" # Evaluate positive sample\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for labels, embeddings in known_loader:\n",
|
||||
" embeddings = embeddings.to(device)\n",
|
||||
" logits = model(embeddings)\n",
|
||||
" energy = energy_score(logits)\n",
|
||||
" \n",
|
||||
" preds = (energy <= energy_threshold).long()\n",
|
||||
" all_preds.extend(preds.cpu().numpy())\n",
|
||||
" all_labels.extend([1] * len(preds)) # Positive sample labeled as 1\n",
|
||||
" \n",
|
||||
" # Evaluate negative sample\n",
|
||||
" with torch.no_grad():\n",
|
||||
" for embeddings in unknown_loader:\n",
|
||||
" embeddings = embeddings.to(device)\n",
|
||||
" logits = model(embeddings)\n",
|
||||
" energy = energy_score(logits)\n",
|
||||
" \n",
|
||||
" preds = (energy <= energy_threshold).long()\n",
|
||||
" all_preds.extend(preds.cpu().numpy())\n",
|
||||
" all_labels.extend([0] * len(preds)) # Negative sample labeled as 1\n",
|
||||
" \n",
|
||||
" precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')\n",
|
||||
" accuracy = accuracy_score(all_labels, all_preds)\n",
|
||||
"\n",
|
||||
" print(f'Accuracy: {accuracy:.4f}')\n",
|
||||
" print(f'Precision: {precision:.4f}')\n",
|
||||
" print(f'Recall: {recall:.4f}')\n",
|
||||
" print(f'F1 Score: {f1:.4f}')\n",
|
||||
"\n",
|
||||
"evaluate_energy_model(model, val_loader, negative_loader, ENERGY_THRESHOLD)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"id": "ba614054-75e1-4f61-ace5-aeb11e29a222",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Save the model\n",
|
||||
"torch.save(model, \"model.pt\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "fdfa0c5e-e6d3-4db0-a142-96645c92719c",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Inference"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"id": "03928d75-81c8-4298-ab8a-d7f8a758b561",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Predicted: ['weather', 0.9989822506904602, -8.016249656677246]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold, max_length=64):\n",
|
||||
" model.eval()\n",
|
||||
" tokens = tokenizer.tokenize(sentence)\n",
|
||||
" token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
|
||||
" embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]\n",
|
||||
" embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension\n",
|
||||
" \n",
|
||||
" with torch.no_grad():\n",
|
||||
" logits = model(embeddings)\n",
|
||||
" probabilities = F.softmax(logits, dim=1)\n",
|
||||
" max_prob, predicted = torch.max(probabilities, 1)\n",
|
||||
" \n",
|
||||
" # Calculate energy score\n",
|
||||
" energy = energy_score(logits)\n",
|
||||
"\n",
|
||||
" # If energy > threshold, consider the input as unknown class\n",
|
||||
" if energy.item() > energy_threshold:\n",
|
||||
" return [\"Unknown\", max_prob.item(), energy.item()]\n",
|
||||
" else:\n",
|
||||
" return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]\n",
|
||||
"\n",
|
||||
"# Example usage:\n",
|
||||
"sentence = \"weather today\"\n",
|
||||
"energy_threshold = ENERGY_THRESHOLD\n",
|
||||
"predicted = predict_with_energy(model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold)\n",
|
||||
"print(f'Predicted: {predicted}')\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.19"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -1,4 +0,0 @@
|
||||
# config.py
|
||||
|
||||
model_name = "Qwen/Qwen2.5-3B"
|
||||
DIMENSIONS = 96
|
@ -1,71 +0,0 @@
|
||||
# data_utils.py
|
||||
|
||||
import json
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def load_data(file_path):
|
||||
with open(file_path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def create_class_mappings(data):
|
||||
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
|
||||
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
||||
return class_to_idx, idx_to_class
|
||||
|
||||
|
||||
def preprocess_data(data, embedding_map, tokenizer, class_to_idx, max_length=64):
|
||||
dataset = []
|
||||
for label, sentences in data.items():
|
||||
for sentence in sentences:
|
||||
tokens = tokenizer.tokenize(sentence)
|
||||
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
embeddings = [
|
||||
embedding_map[token_id] for token_id in token_ids[:max_length]
|
||||
]
|
||||
embeddings = torch.tensor(embeddings)
|
||||
dataset.append((class_to_idx[label], embeddings))
|
||||
return dataset
|
||||
|
||||
def get_sentences(data):
|
||||
result = []
|
||||
for _, sentences in data.items():
|
||||
for sentence in sentences:
|
||||
result.append(sentence)
|
||||
return result
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.data[idx]
|
||||
|
||||
def collate_fn(self, batch):
|
||||
labels, embeddings = zip(*batch)
|
||||
labels = torch.tensor(labels)
|
||||
embeddings = pad_sequence(embeddings, batch_first=True)
|
||||
return labels, embeddings
|
||||
|
||||
|
||||
class NegativeSampleDataset(Dataset):
|
||||
def __init__(self, negative_samples):
|
||||
self.samples = negative_samples
|
||||
|
||||
def __len__(self):
|
||||
return len(self.samples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.samples[idx]
|
||||
|
||||
def collate_fn(self, batch):
|
||||
embeddings = pad_sequence(batch, batch_first=True)
|
||||
return embeddings
|
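# Minimal usage sketch of the collate logic (illustrative shapes only; the
# embedding width 96 matches DIMENSIONS in config.py):
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    demo = TextDataset([(0, torch.randn(5, 96)), (1, torch.randn(3, 96))])
    loader = DataLoader(demo, batch_size=2, collate_fn=demo.collate_fn)
    labels, padded = next(iter(loader))
    # The shorter sequence is zero-padded to the longest in the batch:
    print(labels.shape, padded.shape)  # torch.Size([2]) torch.Size([2, 5, 96])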
@ -1,53 +0,0 @@
|
||||
# model.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from training.config import DIMENSIONS
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, input_dim, heads):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.heads = heads
|
||||
self.scale = (input_dim // heads) ** -0.5
|
||||
self.qkv = nn.Linear(input_dim, input_dim * 3)
|
||||
self.fc = nn.Linear(input_dim, input_dim)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_length, embedding_dim = x.shape
|
||||
qkv = self.qkv(x).view(
|
||||
batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
|
||||
)
|
||||
q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
attention_output = torch.matmul(attn_weights, v)
|
||||
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
|
||||
attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
|
||||
return self.fc(attention_output)
|
||||
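# Shape walkthrough for SelfAttention.forward (batch B, sequence T, embedding
# D, H heads; D must be divisible by H):
#   qkv:              (B, T, H, 3, D/H)
#   q, k, v:          (B, H, T, D/H) after the permutes
#   attn_weights:     (B, H, T, T)
#   attention_output: (B, T, D) after permute + view, then projected by fc.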
|
||||
|
||||
class AttentionBasedModel(nn.Module):
|
||||
def __init__(self, input_dim, num_classes, heads=8, dim_feedforward=512, num_layers=3):
|
||||
super(AttentionBasedModel, self).__init__()
|
||||
self.self_attention_layers = nn.ModuleList([
|
||||
SelfAttention(input_dim, heads) for _ in range(num_layers)
|
||||
])
|
||||
self.fc1 = nn.Linear(input_dim, dim_feedforward)
|
||||
self.fc2 = nn.Linear(dim_feedforward, num_classes)
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self.norm = nn.LayerNorm(input_dim)
|
||||
|
||||
def forward(self, x):
|
||||
for attn_layer in self.self_attention_layers:
|
||||
attn_output = attn_layer(x)
|
||||
x = self.norm(attn_output + x)
|
||||
pooled_output = torch.mean(x, dim=1)
|
||||
x = F.relu(self.fc1(pooled_output))
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
@ -1,155 +0,0 @@
|
||||
# train.py
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
import torch
|
||||
import json
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
import torch.nn.functional as F
|
||||
from training.data_utils import (
|
||||
load_data,
|
||||
create_class_mappings,
|
||||
preprocess_data,
|
||||
TextDataset,
|
||||
NegativeSampleDataset,
|
||||
)
|
||||
from training.model import AttentionBasedModel
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def energy_score(logits, temperature=1.0):
|
||||
return -torch.logsumexp(logits / temperature, dim=1)
|
||||
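# Illustrative sanity check (a hedged sketch with made-up logits; this helper
# is not called by the training pipeline): confident logits give a large
# logsumexp and therefore low energy, near-uniform logits give higher energy.
def _energy_score_demo():
    confident = torch.tensor([[10.0, 0.0, 0.0]])
    uniform = torch.tensor([[1.0, 1.0, 1.0]])
    print(energy_score(confident))  # ~= -10.0, in-distribution-like
    print(energy_score(uniform))    # ~= -2.1, more OOD-like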
|
||||
|
||||
def generate_noise(batch_size, seq_length, input_dim, device):
|
||||
return torch.randn(batch_size, seq_length, input_dim).to(device)
|
||||
|
||||
|
||||
def train_energy_model(
|
||||
model,
|
||||
train_loader,
|
||||
negative_loader,
|
||||
criterion,
|
||||
optimizer,
|
||||
num_epochs=10,
|
||||
margin=1.0,
|
||||
temperature=0.4,
|
||||
):
|
||||
model.train()
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
negative_iter = iter(negative_loader)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
for _, (labels, embeddings) in enumerate(train_loader):
|
||||
embeddings = embeddings.to(device)
|
||||
labels = labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
outputs = model(embeddings)
|
||||
class_loss = criterion(outputs, labels)
|
||||
known_energy = energy_score(outputs, temperature)
|
||||
positive_margin = 0.0
|
||||
energy_loss_known = F.relu(known_energy - positive_margin).mean()
|
||||
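# Hinge-style energy objective: in-distribution energy is pushed below the
# positive margin (0 here), while the noise and negative-corpus energies
# below are pushed above `margin` via F.relu(margin - energy).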
|
||||
noise_embeddings = torch.randn_like(embeddings).to(device)
|
||||
noise_outputs = model(noise_embeddings)
|
||||
noise_energy = energy_score(noise_outputs, temperature)
|
||||
energy_loss_noise = F.relu(margin - noise_energy).mean()
|
||||
|
||||
try:
|
||||
negative_samples = next(negative_iter)
|
||||
except StopIteration:
|
||||
negative_iter = iter(negative_loader)
|
||||
negative_samples = next(negative_iter)
|
||||
negative_samples = negative_samples.to(device)
|
||||
negative_outputs = model(negative_samples)
|
||||
negative_energy = energy_score(negative_outputs, temperature)
|
||||
energy_loss_negative = F.relu(margin - negative_energy).mean()
|
||||
|
||||
total_energy_loss = (
|
||||
energy_loss_known + energy_loss_noise + energy_loss_negative
|
||||
)
|
||||
total_loss_batch = class_loss + total_energy_loss
|
||||
|
||||
total_loss_batch.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += total_loss_batch.item()
|
||||
|
||||
avg_loss = total_loss / len(train_loader)
|
||||
print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
|
||||
|
||||
|
||||
def main():
|
||||
from config import model_name, DIMENSIONS
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
data = load_data("data.json")
|
||||
class_to_idx, idx_to_class = create_class_mappings(data)
|
||||
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
|
||||
dataset = preprocess_data(data, embedding_map, tokenizer, class_to_idx)
|
||||
train_data, _ = train_test_split(dataset, test_size=0.2)
|
||||
|
||||
train_dataset = TextDataset(train_data)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset, batch_size=24, shuffle=True, collate_fn=train_dataset.collate_fn
|
||||
)
|
||||
|
||||
with open("noise.json", "r") as f:
|
||||
negative_samples_list = json.load(f)
|
||||
|
||||
negative_embedding_list = []
|
||||
for sentence in negative_samples_list:
|
||||
tokens = tokenizer.tokenize(sentence)
|
||||
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
embeddings = [embedding_map[token_id] for token_id in token_ids[:64]]
|
||||
embeddings = torch.tensor(embeddings)
|
||||
negative_embedding_list.append(embeddings)
|
||||
|
||||
negative_dataset = NegativeSampleDataset(negative_embedding_list)
|
||||
negative_loader = DataLoader(
|
||||
negative_dataset,
|
||||
batch_size=24,
|
||||
shuffle=True,
|
||||
collate_fn=negative_dataset.collate_fn,
|
||||
)
|
||||
|
||||
input_dim = DIMENSIONS
|
||||
num_classes = len(class_to_idx)
|
||||
model = AttentionBasedModel(input_dim, num_classes)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=7e-4)
|
||||
|
||||
train_energy_model(
|
||||
model, train_loader, negative_loader, criterion, optimizer, num_epochs=120
|
||||
)
|
||||
|
||||
torch.save(model.state_dict(), "model.pt")
|
||||
|
||||
dummy_input = torch.randn(1, 64, DIMENSIONS)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
dummy_input,
|
||||
"model.onnx",
|
||||
input_names=["input"],
|
||||
output_names=["output"],
|
||||
dynamic_axes={
|
||||
"input": {0: "batch_size", 1: "seq_length"},
|
||||
"output": {0: "batch_size"},
|
||||
},
|
||||
opset_version=11,
|
||||
)
|
||||
meta = {
|
||||
"idx_to_class": idx_to_class,
|
||||
"threshold": 0
|
||||
}
|
||||
with open('NLU_meta.json', 'w') as f:
|
||||
json.dump(meta, f)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,75 +0,0 @@
|
||||
from training.model import AttentionBasedModel
|
||||
from training.config import model_name
|
||||
import json
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
import torch.nn.functional as F
|
||||
from training.config import DIMENSIONS
|
||||
|
||||
|
||||
def energy_score(logits):
|
||||
# Energy score is the negative logsumexp
|
||||
return -torch.logsumexp(logits, dim=1)
|
||||
|
||||
|
||||
def predict_with_energy(
|
||||
model,
|
||||
sentence,
|
||||
embedding_map,
|
||||
tokenizer,
|
||||
idx_to_class,
|
||||
energy_threshold,
|
||||
max_length=64,
|
||||
):
|
||||
model.eval()
|
||||
tokens = tokenizer.tokenize(sentence)
|
||||
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
print(token_ids)
|
||||
embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
|
||||
embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension
|
||||
current_shape = embeddings.shape
|
||||
|
||||
if current_shape[1] < 2:
|
||||
pad_size = 2 - current_shape[1]
|
||||
embeddings = F.pad(
|
||||
embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
|
||||
)
|
||||
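# F.pad reads the pad tuple from the last dimension backwards, so
# (0, 0, 0, pad_size, 0, 0) pads only the sequence dimension: one-token
# inputs are zero-padded up to length 2 before entering the model.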
|
||||
with torch.no_grad():
|
||||
logits = model(embeddings)
|
||||
print(logits)
|
||||
probabilities = F.softmax(logits, dim=1)
|
||||
max_prob, predicted = torch.max(probabilities, 1)
|
||||
|
||||
# Calculate energy score
|
||||
energy = energy_score(logits)
|
||||
|
||||
# If energy > threshold, consider the input as unknown class
|
||||
if energy.item() > energy_threshold:
|
||||
return ["Unknown", max_prob.item(), energy.item()]
|
||||
else:
|
||||
return [idx_to_class[predicted.item()], max_prob.item(), energy.item()]
|
||||
|
||||
|
||||
with open("data.json", "r") as f:
|
||||
data = json.load(f)
|
||||
class_to_idx = {cls: idx for idx, cls in enumerate(data.keys())}
|
||||
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
||||
num_classes = len(class_to_idx)
|
||||
|
||||
input_dim = DIMENSIONS
|
||||
model = AttentionBasedModel(input_dim, num_classes)
|
||||
model.load_state_dict(torch.load("./model.pt"))
|
||||
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
# Example usage:
|
||||
ENERGY_THRESHOLD = 2
|
||||
sentence = "what on earth is the cross entropy loss"
|
||||
energy_threshold = ENERGY_THRESHOLD
|
||||
predicted = predict_with_energy(
|
||||
model, sentence, embedding_map, tokenizer, idx_to_class, energy_threshold
|
||||
)
|
||||
print(f"Predicted: {predicted}")
|
@ -1,80 +0,0 @@
|
||||
from training.model import AttentionBasedModel
|
||||
from training.config import model_name
|
||||
from training.config import DIMENSIONS
|
||||
from training.data_utils import get_sentences
|
||||
import json
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import AutoTokenizer
|
||||
from tqdm import tqdm
|
||||
from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support
|
||||
|
||||
def energy_score(logits):
|
||||
# Energy score is the negative logsumexp
|
||||
return -torch.logsumexp(logits, dim=1)
|
||||
|
||||
|
||||
def get_energy(
|
||||
model,
|
||||
sentence,
|
||||
embedding_map,
|
||||
tokenizer,
|
||||
max_length=64,
|
||||
):
|
||||
model.eval()
|
||||
tokens = tokenizer.tokenize(sentence)
|
||||
token_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
embeddings = [embedding_map[token_id] for token_id in token_ids[:max_length]]
|
||||
embeddings = torch.tensor(embeddings).unsqueeze(0) # Add batch dimension
|
||||
current_shape = embeddings.shape
|
||||
|
||||
if current_shape[1] < 2:
|
||||
pad_size = 2 - current_shape[1]
|
||||
embeddings = F.pad(
|
||||
embeddings, (0, 0, 0, pad_size, 0, 0), mode="constant", value=0
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(embeddings)
|
||||
# Calculate energy score
|
||||
energy = energy_score(logits)
|
||||
|
||||
return energy
|
||||
|
||||
|
||||
with open("data.json", "r") as f:
|
||||
positive_data = json.load(f)
|
||||
class_to_idx = {cls: idx for idx, cls in enumerate(positive_data.keys())}
|
||||
idx_to_class = {idx: cls for cls, idx in class_to_idx.items()}
|
||||
num_classes = len(class_to_idx)
|
||||
|
||||
with open("noise.json", "r") as f:
|
||||
negative_data = json.load(f)
|
||||
|
||||
input_dim = DIMENSIONS
|
||||
model = AttentionBasedModel(input_dim, num_classes)
|
||||
model.load_state_dict(torch.load("./model.pt"))
|
||||
embedding_map = torch.load("token_id_to_reduced_embedding.pt")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
|
||||
all_preds = []
|
||||
all_labels = []
|
||||
ENERGY_THRESHOLD = 2
|
||||
for item in tqdm(get_sentences(positive_data)):
|
||||
result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
|
||||
all_preds.append(result)
|
||||
all_labels.append(1)
|
||||
|
||||
for item in tqdm(negative_data):
|
||||
result = get_energy(model, item, embedding_map, tokenizer) < ENERGY_THRESHOLD
|
||||
all_preds.append(result)
|
||||
all_labels.append(0)
|
||||
|
||||
precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')
|
||||
accuracy = accuracy_score(all_labels, all_preds)
|
||||
|
||||
print(f'Accuracy: {accuracy:.4f}')
|
||||
print(f'Precision: {precision:.4f}')
|
||||
print(f'Recall: {recall:.4f}')
|
||||
print(f'F1 Score: {f1:.4f}')
|
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -1,103 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import nltk
|
||||
from nltk.tokenize import sent_tokenize, word_tokenize
|
||||
from training.model import AttentionBasedModel
|
||||
|
||||
# Ensure required NLTK resources are available
|
||||
nltk.download('punkt')
|
||||
nltk.download('averaged_perceptron_tagger')
|
||||
|
||||
# Load pre-saved mappings
|
||||
pos2idx = torch.load('pos2idx.pt')
|
||||
class_mapping = torch.load('class_mapping.pt')
|
||||
|
||||
|
||||
# Load the pre-trained model and state
|
||||
model = AttentionBasedModel(40, 32, 6, 8, 6, 256)
|
||||
model.load_state_dict(torch.load("./model.pt", weights_only=True))
|
||||
|
||||
# Define helper functions
|
||||
def pad_sequence(seq, max_len):
|
||||
return seq + [-1] * (max_len - len(seq))
|
||||
|
||||
def encode_pos_tags(tagged_sentence):
|
||||
return [pos2idx[tag] if tag in pos2idx else -1 for _, tag in tagged_sentence]
|
||||
|
||||
# Split sentence into smaller chunks based on punctuation and length constraints
|
||||
def split_long_sentence(sentence, max_len=128):
|
||||
tokens = word_tokenize(sentence)
|
||||
|
||||
if len(tokens) <= max_len:
|
||||
return [sentence]
|
||||
|
||||
# Attempt to split based on punctuation marks
|
||||
punctuation_marks = [',', ';', ':', '!', '?', '.', '-']
|
||||
split_chunks = []
|
||||
current_chunk = []
|
||||
|
||||
for token in tokens:
|
||||
current_chunk.append(token)
|
||||
if token in punctuation_marks and len(current_chunk) >= max_len // 2:
|
||||
split_chunks.append(' '.join(current_chunk))
|
||||
current_chunk = []
|
||||
|
||||
if current_chunk:
|
||||
split_chunks.append(' '.join(current_chunk))
|
||||
|
||||
# If chunks are still too long, truncate them
|
||||
final_chunks = []
|
||||
for chunk in split_chunks:
|
||||
chunk_tokens = word_tokenize(chunk)
|
||||
if len(chunk_tokens) > max_len:
|
||||
final_chunks.extend([' '.join(chunk_tokens[i:i + max_len]) for i in range(0, len(chunk_tokens), max_len)])
|
||||
else:
|
||||
final_chunks.append(chunk)
|
||||
|
||||
return final_chunks
|
||||
|
||||
# Main function to process and score a chunk
|
||||
def score_sentence(sentence, model, max_length=128):
|
||||
# Tokenize and POS-tag the sentence
|
||||
tagged_sentence = nltk.pos_tag(nltk.word_tokenize(sentence))
|
||||
|
||||
# Encode the POS tags and pad the sequence
|
||||
encoded_sentences = encode_pos_tags(tagged_sentence)
|
||||
padded_sentence = torch.tensor(pad_sequence(encoded_sentences, max_length))
|
||||
|
||||
# Set the device
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
||||
|
||||
# Prepare the model
|
||||
model.to(device)
|
||||
model.eval() # Ensure the model is in evaluation mode
|
||||
|
||||
# Define weights and CEFR levels
|
||||
w_list = [1.04, 1.64, 2.35, 3.44, 4.92, 6.13]
|
||||
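# w_list assigns one scalar difficulty weight per CEFR level (A1..C2), so the
# softmax-weighted sum below collapses the class distribution into a single
# difficulty score.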
|
||||
# Inference without gradient calculation
|
||||
with torch.no_grad():
|
||||
sentence_tensor = padded_sentence.to(device)
|
||||
sentence_tensor = torch.unsqueeze(sentence_tensor, 0) # Add batch dimension
|
||||
|
||||
# Forward pass through the model
|
||||
outputs = model(sentence_tensor)
|
||||
|
||||
# Softmax and weighted scoring
|
||||
probabilities = torch.softmax(outputs[0], dim=0)
|
||||
score = sum(probabilities[i] * w_list[i] for i in range(6)).cpu().numpy()
|
||||
|
||||
return score
|
||||
|
||||
# Function to process a long article and return score list for each chunk
|
||||
def score_article(article, max_length=128, chunk_max_len=128):
|
||||
sentences = sent_tokenize(article) # Split the article into sentences
|
||||
score_list = []
|
||||
|
||||
for sentence in sentences:
|
||||
chunks = split_long_sentence(sentence, max_len=chunk_max_len)
|
||||
for chunk in chunks:
|
||||
score = score_sentence(chunk, model, max_length=max_length)
|
||||
score_list.append(float(score))
|
||||
|
||||
return score_list
|
File diff suppressed because it is too large
@ -1,82 +0,0 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
import pandas as pd
|
||||
from openai import OpenAI
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
load_dotenv()
|
||||
|
||||
client = OpenAI(
|
||||
api_key=os.getenv("API_KEY"),
|
||||
base_url=os.getenv("BASE_URL"),
|
||||
)
|
||||
|
||||
def get_AI_response(text, client, model_name, temp):
|
||||
messages = [
|
||||
{"role": "user", "content": text},
|
||||
]
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=model_name,
|
||||
messages=messages,
|
||||
temperature=temp,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
def get_Examples(df, row, client, model_name, temp):
|
||||
exp = df["Example"][row]
|
||||
cds = df["Can-do statement"][row]
|
||||
gdw = df["guideword"][row]
|
||||
lvl = df["Level"][row]
|
||||
cat = df["SuperCategory"][row] + '/' + df["SubCategory"][row]
|
||||
prompt = \
|
||||
f'''Generate 10 example sentences based on the following instructions.
|
||||
Pay close attention to the 'Can-do Statement' and ensure all generated sentences adhere strictly to it.
|
||||
Provide only the sentences without any additional formatting or markdown.
|
||||
Output the sentences in plain text, one sentence per line, with no empty lines.
|
||||
INSTRUCTION
|
||||
Level: {lvl}
|
||||
Guideword: {gdw}
|
||||
Can-do Statement: {cds}
|
||||
Category: {cat}
|
||||
Example Sentences:
|
||||
{exp}
|
||||
'''
|
||||
return get_AI_response(prompt, client, model_name, temp)
|
||||
|
||||
def process_chunk(df, chunk, client, model, temp):
|
||||
results = []
|
||||
for row in chunk:
|
||||
exps = get_Examples(df, row, client, model, temp)
|
||||
results.append(exps)
|
||||
return results
|
||||
|
||||
input_file = './EGP.csv'
|
||||
df = pd.read_csv(input_file)
|
||||
newdf = df.copy()
|
||||
model = os.getenv("TRANSLATION_MODEL")
|
||||
temp = float(os.getenv("TRANSLATION_TEMP"))
|
||||
|
||||
chunk_size = 64
|
||||
total_rows = len(df.index)
|
||||
num_chunks = (total_rows + chunk_size - 1) // chunk_size # Ceiling division
|
||||
|
||||
with tqdm(total=total_rows) as pbar:
|
||||
for chunk_idx in range(num_chunks):
|
||||
start = chunk_idx * chunk_size
|
||||
end = min(start + chunk_size, total_rows)
|
||||
chunk = range(start, end)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(chunk)) as executor:
|
||||
futures = {executor.submit(get_Examples, df, row, client, model, temp): row for row in chunk} # map each future back to its row
|
||||
for future in as_completed(futures):
|
||||
row = futures[future] # look up the row index for this future
|
||||
result = future.result() # retrieve the AI-generated examples
|
||||
newdf.at[row, "Example"] = result # 更新到正确的行
|
||||
|
||||
pbar.update(len(chunk))
|
||||
newdf.to_csv("output.csv", index=False)
|
||||
|
||||
newdf.to_csv("EGP_Derivied.csv", index=False)
|
@ -1,23 +0,0 @@
|
||||
import pandas as pd
|
||||
|
||||
df = pd.read_csv("EGP_Derivied.csv")
|
||||
newdf = pd.DataFrame()
|
||||
|
||||
levels_list=[]
|
||||
sentences_list=[]
|
||||
category_list=[]
|
||||
for line in range(len(df.index)):
|
||||
examples = list(filter(None, df["Example"][line].split("\n")))
|
||||
lvl = df["Level"][line]
|
||||
cat = df["SuperCategory"][line] + '/' + df["SubCategory"][line]
|
||||
for sentence in examples:
|
||||
sentences_list.append(sentence)
|
||||
levels_list.append(lvl)
|
||||
category_list.append(cat)
|
||||
|
||||
|
||||
newdf["Level"] = levels_list
|
||||
newdf["Category"] = category_list
|
||||
newdf["Sentence"] = sentences_list
|
||||
|
||||
newdf.to_csv("data.csv", index=False)
|
@ -1,111 +0,0 @@
|
||||
# model.py
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, embedding_dim, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
|
||||
# Create a positional encoding matrix of shape (max_len, embedding_dim)
|
||||
pe = torch.zeros(max_len, embedding_dim)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
|
||||
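# Standard sinusoidal encoding: PE[pos, 2i] = sin(pos / 10000^(2i/d)) and
# PE[pos, 2i+1] = cos(pos / 10000^(2i/d)); div_term precomputes the
# 10000^(-2i/d) factors.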
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
# Add a batch dimension, so the shape becomes (1, max_len, embedding_dim)
|
||||
pe = pe.unsqueeze(0)
|
||||
|
||||
# Register the positional encoding as a buffer so it won't be updated by the optimizer
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
# x is expected to have shape (batch_size, seq_length, embedding_dim)
|
||||
seq_length = x.size(1)
|
||||
# Add positional encoding to input
|
||||
x = x + self.pe[:, :seq_length]
|
||||
return x
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, input_dim, heads):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.heads = heads
|
||||
self.scale = (input_dim // heads) ** -0.5
|
||||
self.qkv = nn.Linear(input_dim, input_dim * 3)
|
||||
self.fc = nn.Linear(input_dim, input_dim)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_length, embedding_dim = x.shape
|
||||
qkv = self.qkv(x).view(
|
||||
batch_size, seq_length, self.heads, 3, embedding_dim // self.heads
|
||||
)
|
||||
q, k, v = qkv[..., 0, :], qkv[..., 1, :], qkv[..., 2, :]
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
attn_weights = F.softmax(attn_weights, dim=-1)
|
||||
|
||||
attention_output = torch.matmul(attn_weights, v)
|
||||
attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
|
||||
attention_output = attention_output.view(batch_size, seq_length, embedding_dim)
|
||||
|
||||
return self.fc(attention_output)
|
||||
|
||||
class AttentionBasedModel(nn.Module):
|
||||
def __init__(self, pos_vocab_size, embedding_dim=128, num_classes=6, heads=8, num_attention_layers=3, dim_feedforward=512, max_len=128):
|
||||
super(AttentionBasedModel, self).__init__()
|
||||
self.embedding = nn.Embedding(pos_vocab_size, embedding_dim) # Embedding for POS tags
|
||||
self.positional_encoding = PositionalEncoding(embedding_dim, max_len) # Positional Encoding
|
||||
self.self_attention_layers = nn.ModuleList([
|
||||
SelfAttention(embedding_dim, heads) for _ in range(num_attention_layers)
|
||||
])
|
||||
self.fc1 = nn.Linear(embedding_dim, dim_feedforward)
|
||||
self.fc2 = nn.Linear(dim_feedforward, num_classes)
|
||||
self.dropout = nn.Dropout(0.5)
|
||||
self.norm = nn.LayerNorm(embedding_dim)
|
||||
|
||||
def forward(self, x):
|
||||
# Input x is a matrix of one-hot encoded POS tags, shape: (batch_size, seq_length, pos_vocab_size)
|
||||
x = self.embedding(x) # Convert POS tags to embeddings, shape: (batch_size, seq_length, embedding_dim)
|
||||
|
||||
# Add positional encoding to embeddings
|
||||
x = self.positional_encoding(x)
|
||||
|
||||
for attn_layer in self.self_attention_layers:
|
||||
attn_output = attn_layer(x)
|
||||
x = self.norm(attn_output + x)
|
||||
|
||||
# Pool the output by taking the mean of the sequence (reduce along sequence length)
|
||||
pooled_output = torch.mean(x, dim=1)
|
||||
|
||||
# Fully connected layers for classification
|
||||
x = F.relu(self.fc1(pooled_output))
|
||||
x = self.dropout(x)
|
||||
x = self.fc2(x) # Output logits for the 6 classes
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Example Usage
|
||||
# # Hyperparameters
|
||||
# pos_vocab_size = 50 # Size of the POS tag vocabulary
|
||||
# max_context_length = 128 # Maximum context length
|
||||
# embedding_dim = 128 # Embedding size
|
||||
# num_classes = 6 # Output classes
|
||||
# batch_size = 32 # Example batch size
|
||||
|
||||
# # Model initialization
|
||||
# model = AttentionBasedModel(pos_vocab_size, embedding_dim, num_classes)
|
||||
|
||||
# # Example input: batch of one-hot encoded POS tags (variable length sequences)
|
||||
# input_data = torch.randint(0, pos_vocab_size, (batch_size, max_context_length)) # Random input for testing
|
||||
|
||||
# # Forward pass
|
||||
# output = model(input_data) # Output shape will be (batch_size, num_classes)
|
||||
|
||||
# print(output.shape) # Should print torch.Size([batch_size, num_classes])
|
@ -1,150 +0,0 @@
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
import pandas as pd
|
||||
import nltk
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
from sklearn.model_selection import train_test_split
|
||||
import numpy as np
|
||||
|
||||
#nltk.download('punkt')
|
||||
#nltk.download('averaged_perceptron_tagger')
|
||||
|
||||
# Load the model classes
|
||||
from model import AttentionBasedModel
|
||||
|
||||
# Load the data
|
||||
df = pd.read_csv('data.csv')
|
||||
|
||||
# Step 1: Extract sentences and corresponding levels
|
||||
sentences = df['Sentence'].values
|
||||
levels = df['Level'].values
|
||||
|
||||
# Step 2: Tokenize and POS tag each sentence
|
||||
pos_tags = [nltk.pos_tag(nltk.word_tokenize(sentence)) for sentence in sentences]
|
||||
|
||||
# Step 3: Build POS tag vocabulary
|
||||
# Extract unique POS tags from the dataset
|
||||
pos_vocab = set()
|
||||
for tagged_sentence in pos_tags:
|
||||
for _, tag in tagged_sentence:
|
||||
pos_vocab.add(tag)
|
||||
|
||||
# Create a mapping from POS tag to index
|
||||
pos2idx = {pos: idx for idx, pos in enumerate(pos_vocab)}
|
||||
pos_vocab_size = len(pos2idx)
|
||||
|
||||
# Step 4: Encode sentences into POS tag indices
|
||||
def encode_pos_tags(tagged_sentence):
|
||||
return [pos2idx[tag] for _, tag in tagged_sentence]
|
||||
|
||||
encoded_sentences = [encode_pos_tags(tagged_sentence) for tagged_sentence in pos_tags]
|
||||
|
||||
# Step 5: Encode levels (classes) into integers
|
||||
le = LabelEncoder()
|
||||
encoded_levels = le.fit_transform(levels)
|
||||
num_classes = len(le.classes_)
|
||||
|
||||
# Save class encoding mapping
|
||||
class_mapping = dict(zip(le.transform(le.classes_), le.classes_))
|
||||
torch.save(class_mapping, 'class_mapping.pt')
|
||||
|
||||
# Save POS tag encoding mapping
|
||||
torch.save(pos2idx, 'pos2idx.pt')
|
||||
|
||||
# Step 6: Pad sentences to a fixed length
|
||||
max_length = 64
|
||||
|
||||
def pad_sequence(seq, max_len):
|
||||
return seq + [-1] * (max_len - len(seq))
|
||||
|
||||
padded_sentences = [pad_sequence(seq, max_length) for seq in encoded_sentences]
|
||||
|
||||
# Step 7: Create a PyTorch Dataset and DataLoader
|
||||
class POSDataset(Dataset):
|
||||
def __init__(self, sentences, labels):
|
||||
self.sentences = sentences
|
||||
self.labels = labels
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sentences)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
sentence = torch.tensor(self.sentences[idx], dtype=torch.long)
|
||||
label = torch.tensor(self.labels[idx], dtype=torch.long)
|
||||
return sentence, label
|
||||
|
||||
# Split data into training and validation sets
|
||||
X_train, X_val, y_train, y_val = train_test_split(padded_sentences, encoded_levels, test_size=0.2)
|
||||
|
||||
train_dataset = POSDataset(X_train, y_train)
|
||||
val_dataset = POSDataset(X_val, y_val)
|
||||
|
||||
batch_size = 128
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
||||
|
||||
# Step 8: Initialize the model, loss function, and optimizer
|
||||
embedding_dim = 32
|
||||
heads = 8
|
||||
num_attention_layers = 6
|
||||
dim_feedforward = 256
|
||||
learning_rate = 0.003
|
||||
|
||||
model = AttentionBasedModel(pos_vocab_size, embedding_dim, num_classes, heads, num_attention_layers, dim_feedforward)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
||||
|
||||
# Step 9: Training loop
|
||||
num_epochs = 100
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
||||
model.to(device)
|
||||
|
||||
step = 0
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
model.train()
|
||||
running_loss = 0.0
|
||||
|
||||
for sentences, labels in train_loader:
|
||||
sentences, labels = sentences.to(device), labels.to(device)
|
||||
|
||||
# Forward pass
|
||||
outputs = model(sentences)
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
step += 1
|
||||
|
||||
running_loss += loss.item()
|
||||
|
||||
# Validation phase
|
||||
model.eval()
|
||||
val_loss = 0.0
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for sentences, labels in val_loader:
|
||||
sentences, labels = sentences.to(device), labels.to(device)
|
||||
|
||||
outputs = model(sentences)
|
||||
loss = criterion(outputs, labels)
|
||||
val_loss += loss.item()
|
||||
|
||||
# Calculate accuracy
|
||||
_, predicted = torch.max(outputs, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
# Print training and validation stats
|
||||
print(f'Epoch [{epoch+1}/{num_epochs}], Step {step}, Loss: {running_loss/len(train_loader):.4f}, '
|
||||
f'Validation Loss: {val_loss/len(val_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')
|
||||
torch.save(model.state_dict(), f'checkpoints/step_{step}.pt')
|
||||
|
||||
# Step 10: Save the trained model
|
||||
torch.save(model.state_dict(), 'model.pt')
|
@ -1,46 +0,0 @@
|
||||
from training.model import AttentionBasedModel
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import nltk
|
||||
|
||||
|
||||
sentence = '''Smartphones have worked their way deep into our lives and have become indispensable for work and socialising.'''
|
||||
|
||||
pos2idx = torch.load('pos2idx.pt')
|
||||
class_mapping = torch.load('class_mapping.pt')
|
||||
|
||||
def pad_sequence(seq, max_len):
|
||||
return seq + [-1] * (max_len - len(seq))
|
||||
|
||||
def encode_pos_tags(tagged_sentence):
|
||||
return [pos2idx[tag] if tag in pos2idx else -1 for _, tag in tagged_sentence]
|
||||
|
||||
|
||||
max_length = 64
|
||||
|
||||
tagged_sentence = nltk.pos_tag(nltk.word_tokenize(sentence))
|
||||
encoded_sentences = encode_pos_tags(tagged_sentence)
|
||||
padded_sentence = torch.tensor(pad_sequence(encoded_sentences, max_length))
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
|
||||
model = AttentionBasedModel(40, 32, 6, 8, 6, 256)
|
||||
model.load_state_dict(torch.load("./model.pt", weights_only=True))
|
||||
model.to(device)
|
||||
model.eval() # make sure the model is in evaluation mode
|
||||
|
||||
|
||||
w_list = [1.35, 1.63, 2.75, 3.64, 5.38, 6.32]
|
||||
cefr_dict = [None, "A1", "A2", "B1", "B2", "C1", "C2"]
|
||||
with torch.no_grad():
|
||||
sentence = padded_sentence.to(device)
|
||||
sentences = torch.unsqueeze(sentence, 0)
|
||||
|
||||
outputs = model(sentences)
|
||||
print(torch.max(outputs, 1))
|
||||
print(outputs[0])
|
||||
print(torch.softmax(outputs[0],0))
|
||||
s = 0
|
||||
for i in range(6):
|
||||
s += torch.softmax(outputs[0], 0)[i] * w_list[i]
|
||||
s = s.cpu().numpy()
|
||||
# s is the final weighted CEFR score
|
@ -105,8 +105,8 @@ def batch_process(input_dir, output_dir, num_threads=4):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
input_dir = "./source-new"
|
||||
output_dir = "./output-new"
|
||||
input_dir = "./source"
|
||||
output_dir = "./output"
|
||||
batch_process(
|
||||
input_dir, output_dir, num_threads=int(os.getenv("TRANSLATE_THREADS"))
|
||||
)
|
||||
|
@ -1,50 +1,19 @@
|
||||
# sparkastML NMT
|
||||
|
||||
A set of models aiming to offer the best open-source machine translation, based on [OpenNMT](https://opennmt.net/).
|
||||
|
||||
## News
|
||||
|
||||
sparkastML's translation model is now updated!
|
||||
sparkastML's first translation model is now available!
|
||||
|
||||
### Details
|
||||
|
||||
- **Source Language:** Chinese (Simplified)
|
||||
- **Target Language:** English
|
||||
- **Training Time:** 11.3 hours in total, 46,500 steps (~1×10¹⁸ FLOPs)
|
||||
- **Training Device:**
|
||||
- RTX 3080 (20GB): 0-20,000 steps
|
||||
- RTX 4070: 20,000-46,500 steps
|
||||
- **Corpus Size:** Over 10 million sentences
|
||||
- **Validation BLEU Score:** 21.28
|
||||
- **Validation Loss (Cross Entropy):** 3.152
|
||||
|
||||
### Model Download
|
||||
|
||||
Available soon.
|
||||
|
||||
### Special thanks
|
||||
|
||||
[yumechi](https://github.com/eternal-flame-AD/) for sponsoring an RTX 4070 for training.
|
||||
|
||||
## History
|
||||
|
||||
### Sep 19, 2024
|
||||
|
||||
sparkastML's translation model is now updated!
|
||||
|
||||
#### Details
|
||||
|
||||
- **Source Language:** Chinese (Simplified)
|
||||
- **Target Language:** English
|
||||
- **Training Time:** 5 hours, 20,000 steps
|
||||
- **Training Device:** RTX 3080 (20GB)
|
||||
- **Corpus Size:** Over 10 million sentences
|
||||
- **Validation BLEU Score:** 17
|
||||
- **Version:** 1.0
|
||||
|
||||
#### Model Download
|
||||
### Model Download
|
||||
|
||||
- **Google Drive:** [Download from Google Drive](https://drive.google.com/drive/folders/1-q_AKfQENW-pV6uAleUHPE9ghddfNWKF)
|
||||
- **IPFS:** [Download from IPFS](http://ipfs.a2x.pub/ipfs/QmUMadzkBwvH5KTpoxfv7TgqzaPpqBzkXtkecV9TXPfZ3F/)
|
||||
- **CID:** `QmUMadzkBwvH5KTpoxfv7TgqzaPpqBzkXtkecV9TXPfZ3F`
|
||||
- **Google Drive:** [Download from Google Drive](https://drive.google.com/file/d/1bJkkqQJLdwTgXFXVeP7fjPawfwelzeIB/view)
|
||||
- **IPFS:** [Download from IPFS](http://ipfs.a2x.pub/ipfs/QmNw3Mo3N31wwTQPXzNeGD8jPpkGp5VFQcC9gk44bfqW1u/)
|
||||
- **CID:** `QmNw3Mo3N31wwTQPXzNeGD8jPpkGp5VFQcC9gk44bfqW1u`
|
||||
- **GitHub Release:** [Go to Release Page](https://github.com/alikia2x/sparkastML/releases/tag/v2-model)
|
||||
|
File diff suppressed because one or more lines are too long
@ -1,79 +0,0 @@
|
||||
import torch
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from numpy.linalg import norm
|
||||
import sys
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
|
||||
# Define the cosine similarity function
|
||||
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
|
||||
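# Quick sanity check of the similarity function (illustrative vectors):
#   import numpy as np
#   cos_sim(np.array([1.0, 0.0]), np.array([1.0, 0.0]))  # -> 1.0
#   cos_sim(np.array([1.0, 0.0]), np.array([0.0, 1.0]))  # -> 0.0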
|
||||
# Load the model and tokenizer
|
||||
model_name = 'jinaai/jina-embeddings-v2-base-zh'
|
||||
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
# Check if the correct number of command-line arguments are provided
|
||||
if len(sys.argv) < 4 or len(sys.argv) > 5:
|
||||
print("Usage: python script.py <file_a_path> <file_b_path> <output_file_path> [num_samples]")
|
||||
sys.exit(1)
|
||||
|
||||
# Define file paths from command-line arguments
|
||||
file_a_path = sys.argv[1]
|
||||
file_b_path = sys.argv[2]
|
||||
output_file_path = sys.argv[3]
|
||||
|
||||
# Define the number of samples to randomly select
|
||||
num_samples = int(sys.argv[4]) if len(sys.argv) == 5 else 100
|
||||
|
||||
# Get the total number of lines in the files without loading them fully
|
||||
def count_lines(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return sum(1 for _ in f)
|
||||
|
||||
total_lines_a = count_lines(file_a_path)
|
||||
total_lines_b = count_lines(file_b_path)
|
||||
|
||||
# Ensure both files have the same number of lines
|
||||
if total_lines_a != total_lines_b:
|
||||
print("Files must have the same number of lines.")
|
||||
sys.exit(1)
|
||||
|
||||
# Select random sample indices without loading entire files
|
||||
selected_indices = sorted(random.sample(range(total_lines_a), num_samples))
|
||||
|
||||
# Function to get all sampled lines from the file
|
||||
def get_lines(file_path, line_numbers):
|
||||
result = []
|
||||
max_i = max(line_numbers)
|
||||
j=0
|
||||
next_i = line_numbers[j]
|
||||
len_line_numbers = len(line_numbers)
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for current_line, line in tqdm(enumerate(f)):
|
||||
if current_line < next_i:
|
||||
continue
|
||||
result.append(line.strip())
|
||||
j+=1
|
||||
if current_line >= max_i or j >= len_line_numbers:
|
||||
return result
|
||||
next_i = line_numbers[j]
|
||||
|
||||
return result
|
||||
|
||||
lines_a = get_lines(file_a_path, selected_indices)
|
||||
lines_b = get_lines(file_b_path, selected_indices)
|
||||
|
||||
# Open output file for writing
|
||||
with open(output_file_path, 'w', encoding='utf-8') as output_file:
|
||||
for i, idx in tqdm(enumerate(selected_indices)):
|
||||
# Get the corresponding lines from both files
|
||||
line_a = lines_a[i]
|
||||
line_b = lines_b[i]
|
||||
|
||||
embeddings = model.encode([line_a, line_b])
|
||||
similarity = cos_sim(embeddings[0], embeddings[1])
|
||||
|
||||
# Write the similarity to the output file
|
||||
output_file.write(f"{similarity}\n")
|
||||
|
||||
print(f"Similarity calculation completed. Results saved to {output_file_path}")
|
@ -1,60 +0,0 @@
|
||||
from transformers import AutoModel
|
||||
from numpy.linalg import norm
|
||||
import argparse
|
||||
from tqdm import tqdm
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Usage: python filter.py <file_a_path> <file_b_path> <output_file_path>"
|
||||
)
|
||||
|
||||
parser.add_argument("file_a", type=str, help="File No.1")
|
||||
parser.add_argument("file_b", type=str, help="File No.2")
|
||||
parser.add_argument("output", type=str, help="Output file")
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Resume from specified line",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Define the cosine similarity function
|
||||
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
|
||||
|
||||
# Load the model and tokenizer
|
||||
model_name = 'jinaai/jina-embeddings-v2-base-zh'
|
||||
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
||||
model.to('cuda')
|
||||
|
||||
# Define file paths from command-line arguments
|
||||
file_a_path = args.file_a
|
||||
file_b_path = args.file_b
|
||||
output_file_path = args.output
|
||||
|
||||
resume_from = args.resume
|
||||
resume = resume_from >= 0
|
||||
output_file_mode = 'a' if resume else 'w'
|
||||
|
||||
# Open files
|
||||
with open(file_a_path, 'r', encoding='utf-8') as file_a, \
|
||||
open(file_b_path, 'r', encoding='utf-8') as file_b, \
|
||||
open(output_file_path, output_file_mode, encoding='utf-8') as output_file:
|
||||
i=1
|
||||
# Read file A and file B line by line
|
||||
for line_a, line_b in tqdm(zip(file_a, file_b)):
|
||||
if resume and i < resume_from:
|
||||
i+=1
|
||||
continue
|
||||
# Remove trailing newline characters
|
||||
line_a = line_a.strip()
|
||||
line_b = line_b.strip()
|
||||
|
||||
embeddings = model.encode([line_a, line_b])
|
||||
similarity = cos_sim(embeddings[0], embeddings[1])
|
||||
|
||||
# Write the similarity to the output file
|
||||
output_file.write(f"{similarity}\n")
|
||||
|
||||
i+=1
|
||||
|
||||
print(f"Similarity calculation completed. Results saved to {output_file_path}")
|
@ -1,74 +0,0 @@
|
||||
from transformers import AutoModel
|
||||
from numpy.linalg import norm
|
||||
import sys
|
||||
import random
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
|
||||
# Define the cosine similarity function
|
||||
cos_sim = lambda a, b: (a @ b.T) / (norm(a) * norm(b))
|
||||
|
||||
# Load the model and tokenizer
|
||||
model_name = 'jinaai/jina-embeddings-v2-base-zh'
|
||||
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
# Check if the correct number of command-line arguments are provided
|
||||
if len(sys.argv) < 3 or len(sys.argv) > 4:
|
||||
print("Usage: python script.py <file_path> <output_file_path> [num_samples]")
|
||||
sys.exit(1)
|
||||
|
||||
# Define file paths from command-line arguments
|
||||
file_path = sys.argv[1]
|
||||
output_file_path = sys.argv[2]
|
||||
|
||||
# Define the number of samples to randomly select
|
||||
num_samples = int(sys.argv[3]) if len(sys.argv) == 4 else 100
|
||||
|
||||
# Get the total number of lines in the files without loading them fully
|
||||
def count_lines(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
return sum(1 for _ in f)
|
||||
|
||||
total_lines = count_lines(file_path)
|
||||
|
||||
# Select random sample indices without loading entire files
|
||||
selected_indices = sorted(random.sample(range(total_lines), num_samples))
|
||||
|
||||
# Function to get all sampled lines from the file
|
||||
def get_lines(file_path, line_numbers):
|
||||
result = []
|
||||
max_i = max(line_numbers)
|
||||
j=0
|
||||
next_i = line_numbers[j]
|
||||
len_line_numbers = len(line_numbers)
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for current_line, line in tqdm(enumerate(f)):
|
||||
if current_line < next_i:
|
||||
continue
|
||||
result.append(line.strip())
|
||||
j+=1
|
||||
if current_line >= max_i or j >= len_line_numbers:
|
||||
return result
|
||||
next_i = line_numbers[j]
|
||||
|
||||
return result
|
||||
|
||||
lines = get_lines(file_path, selected_indices)
|
||||
|
||||
# Open output file for writing
|
||||
with open(output_file_path, 'w', encoding='utf-8') as output_file, open("1.txt", 'w', encoding='utf-8') as lf:
|
||||
for i, idx in tqdm(enumerate(selected_indices)):
|
||||
# Get the corresponding lines from both files
|
||||
line = lines[i]
|
||||
data = json.loads(line)
|
||||
chn = data["chinese"]
|
||||
eng = data["english"]
|
||||
lf.write(str(idx)+'\n')
|
||||
|
||||
embeddings = model.encode([chn, eng])
|
||||
similarity = cos_sim(embeddings[0], embeddings[1])
|
||||
|
||||
# Write the similarity to the output file
|
||||
output_file.write(f"{similarity}\n")
|
||||
|
||||
print(f"Similarity calculation completed. Results saved to {output_file_path}")
|
@ -1,30 +0,0 @@
|
||||
import pandas as pd
|
||||
|
||||
# Define file paths
|
||||
source_files = ['./result/source.txt', './result/source-new.txt']
|
||||
target_files = ['./result/target.txt', './result/target-new.txt']
|
||||
|
||||
# Read the contents of the source and target files
|
||||
source_data = []
|
||||
target_data = []
|
||||
|
||||
for file in source_files:
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
source_data.extend(f.readlines())
|
||||
|
||||
for file in target_files:
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
target_data.extend(f.readlines())
|
||||
|
||||
# Make sure source and target have the same number of lines
|
||||
if len(source_data) != len(target_data):
|
||||
print("Warning: The number of lines in source and target files do not match.")
|
||||
|
||||
# Build the DataFrame
|
||||
df = pd.DataFrame({
|
||||
'zh': [line.strip() for line in source_data], # strip the trailing newline from each line
|
||||
'en': [line.strip() for line in target_data] # strip the trailing newline from each line
|
||||
})
|
||||
|
||||
|
||||
df.to_csv('./result/data.csv', index=False, encoding='utf-8')
|
@@ -27,8 +27,8 @@ def process_json_files(directory, converted_filename):
            data = json.load(json_file)
            segments = data.get('segments', [])

        with open('./result/source-new.txt', 'a', encoding='utf-8') as source_file, \
             open('./result/target-new.txt', 'a', encoding='utf-8') as target_file:
        with open('./result/source.txt', 'a', encoding='utf-8') as source_file, \
             open('./result/target.txt', 'a', encoding='utf-8') as target_file:
            for segment in segments:
                chinese_text = segment.get('chinese', '').replace('\n', ' ')
                english_text = segment.get('english', '').replace('\n', ' ')
@@ -42,7 +42,7 @@ def process_json_files(directory, converted_filename):
        write_converted_file(converted_filename, filename)

if __name__ == "__main__":
    json_directory = './output-new'  # Replace with the path to your directory of JSON files
    json_directory = './output'  # Replace with the path to your directory of JSON files
    converted_filename = './result/converted.txt'

    process_json_files(json_directory, converted_filename)
@@ -1,52 +0,0 @@
import os
import re

def split_content(content):
    sentences = re.split(r'[。!?;.!?;]', content)
    segments = []
    current_segment = []
    current_length = 0

    for sentence in sentences:
        sentence_length = len(sentence)
        if len(current_segment) >= 25 or current_length + sentence_length > 1200:
            segments.append(''.join(current_segment))
            current_segment = []
            current_length = 0

        current_segment.append(sentence)
        current_length += sentence_length

    if current_segment:
        segments.append(''.join(current_segment))

    return segments

def process_files_in_directory(directory):
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)

        # Only process files; skip directories
        if os.path.isfile(file_path):
            with open(file_path, 'r', encoding='utf-8') as file:
                content = file.read()

            segments = split_content(content)

            if len(segments) > 1:
                # Delete the original file
                os.remove(file_path)

                # Save the split segments as new files
                for i, segment in enumerate(segments):
                    new_filename = f"{filename}_{i+1}"
                    new_file_path = os.path.join(directory, new_filename)

                    with open(new_file_path, 'w', encoding='utf-8') as new_file:
                        new_file.write(segment)
            else:
                print(f"File {filename} does not need splitting")

# Target directory
directory = './source-new'
process_files_in_directory(directory)
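
A quick sanity check of split_content's thresholds: a new segment starts once the current one holds 25 sentences or would exceed 1200 characters, and re.split drops the delimiters, so the punctuation does not survive into the output segments. A toy run (the input is invented):

# 60 copies of a 6-character sentence end up in chunks of 25/25/10 sentences.
text = "这是一个句子。" * 60
segments = split_content(text)
print(len(segments))                # 3
print([len(s) for s in segments])   # [150, 150, 60] characters, delimiters removed
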
@@ -2,7 +2,6 @@ from openai import OpenAI
import argparse
import os
from dotenv import load_dotenv
from tqdm import tqdm

def translate_text(text, client, model_name, temp):
    messages = [
@@ -38,7 +37,7 @@ with open(input_file, "r") as f:
    src_lines = f.readlines()

for line in tqdm(src_lines):
for line in src_lines:
    result = translate_text(line, client, model, temp)
    with open(output_file, 'a') as f:
        f.write(result + '\n')
@@ -1,15 +1,14 @@
import subprocess
from tqdm import tqdm

def translate_text(text):
    command = f'argos-translate --from zh --to en "{text}"'
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    return result.stdout.strip()

with open("./data/src.txt", "r") as f:
with open("src.txt", "r") as f:
    src_lines = f.readlines()

for line in tqdm(src_lines):
for line in src_lines:
    result = translate_text(line)
    with open("./data/hyp-sk-1.2.txt", 'a') as f:
    with open("hyp-ag.txt", 'a') as f:
        f.write(result + '\n')
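
One fragility worth flagging in the subprocess call above: interpolating the line into a shell string breaks on text containing double quotes and invites shell injection. Passing an argument list sidesteps both; this sketch reuses the same CLI flags shown above.

def translate_text(text):
    # Argument-list form: no shell is involved, so quotes in `text` are harmless.
    result = subprocess.run(
        ["argos-translate", "--from", "zh", "--to", "en", text],
        capture_output=True, text=True,
    )
    return result.stdout.strip()
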
@@ -1,42 +0,0 @@
import json
import subprocess
import evaluate
from nltk.tokenize import word_tokenize
from tqdm import tqdm

bleu_cal = evaluate.load("chrf")

def translate_text(text):
    command = f'argos-translate --from zh --to en "{text}"'
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    return result.stdout.strip()

def main():
    # Read the dataset
    with open('./data/1.jsonl', 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f]

    translations = []
    references = []

    # for entry in tqdm(data):
    #     chinese_sentence = entry['zh']
    #     translated_sentence = translate_text(chinese_sentence)
    #     with open("./data/1-inf.txt", "a") as f:
    #         f.write(translated_sentence + "\n")
    #     translations.append(translated_sentence)

    with open("./data/1-inf.txt", 'r') as f:
        translations = f.readlines()

    for entry in data:
        english_sentence = entry['en']
        references.append([english_sentence])

    # Compute the chrF score (despite the variable names, the metric loaded above is chrF, not BLEU)
    bleu = bleu_cal.compute(predictions=translations, references=references)
    print(bleu)

if __name__ == "__main__":
    main()
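
For reference, a minimal standalone run of the chrF metric through the evaluate library; the sentence pair is made up.

import evaluate

chrf = evaluate.load("chrf")
result = chrf.compute(
    predictions=["The weather is nice today."],
    references=[["The weather is good today."]],  # list of reference lists
)
print(result["score"])  # chrF score on a 0-100 scale
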
@@ -1,10 +1,10 @@
from googletrans import Translator
translator = Translator()

with open("./data/src.txt", "r") as f:
with open("src.txt", "r") as f:
    src_lines = f.readlines()

for line in src_lines:
    result = translator.translate(line, dest='en')
    with open("./data/hyp-gg-py.txt", 'a') as f:
    with open("hyp-gg-py.txt", 'a') as f:
        f.write(result.text + '\n')
@@ -1,19 +0,0 @@
from tqdm import tqdm
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

def translate_text(text):
    tokenizer.src_lang = "zh"
    encoded_zh = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(**encoded_zh, forced_bos_token_id=tokenizer.get_lang_id("en"))
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return result[0]

with open("./data/src.txt", "r") as f:
    src_lines = f.readlines()

for line in tqdm(src_lines):
    result = translate_text(line)
    with open("./data/hyp-m2m.txt", 'a') as f:
        f.write(result + '\n')
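
Since this script generates one sentence per call, batching the inputs usually cuts inference time considerably. A sketch of a batched variant using the same model and tokenizer; the batch size is an arbitrary choice.

def translate_batch(texts, batch_size=16):
    tokenizer.src_lang = "zh"
    results = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        # padding=True pads the batch to a common length and supplies the attention mask
        encoded = tokenizer(batch, return_tensors="pt", padding=True)
        generated = model.generate(**encoded, forced_bos_token_id=tokenizer.get_lang_id("en"))
        results.extend(tokenizer.batch_decode(generated, skip_special_tokens=True))
    return results
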
@@ -28,7 +28,7 @@ def main(input_file, sample_size):
        chinese_text = item["chinese"]
        english_text = item["english"]

        with open("./data/src.txt", 'a') as srcf, open("./data/ref.txt", 'a') as reff:
        with open("src.txt", 'a') as srcf, open("ref.txt", 'a') as reff:
            srcf.write(chinese_text + '\n')
            reff.write(english_text + '\n')
@@ -1,16 +0,0 @@
import re

# Read the file contents
with open('ug.txt', 'r', encoding='utf-8') as file:
    data = file.read()

# Regex that keeps Uyghur (Arabic-script) letters, Arabic numerals, and common punctuation.
# Uyghur letters fall in the Unicode range U+0600-U+06FF;
# the digits 0-9 and the punctuation set (. , ! ? ؛ :) can be adjusted as needed.
filtered_data = re.sub(r'[^\u0600-\u06FF0-9.,!?؛:\s]', '', data)

# Write the filtered data to a new file
with open('filtered_ug.txt', 'w', encoding='utf-8') as file:
    file.write(filtered_data)

print("Filtering complete; results saved to filtered_ug.txt")
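
A quick illustration of what the filter keeps and drops; the sample string is invented.

import re

sample = "ياخشىمۇسىز hello 123, world!"
print(re.sub(r'[^\u0600-\u06FF0-9.,!?؛:\s]', '', sample))
# Latin letters are stripped; Arabic-script text, digits, whitespace,
# and the listed punctuation survive.
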
@@ -1,13 +0,0 @@
import re

def replace_spaces_in_file(input_file_path, output_file_path):
    with open(input_file_path, 'r', encoding='utf-8') as file:
        text = file.read()

    new_text = re.sub(r' +', ' ', text)

    with open(output_file_path, 'w', encoding='utf-8') as file:
        file.write(new_text)

# Call the function to collapse runs of spaces in the file down to single spaces
replace_spaces_in_file('./data/ug_texts1.txt', './data/2.txt')