cvsa/ml/lab/mmsAlignment/alignWithMMS.py
alikia2x 636c5e25cb
ref: move ML stuff
add: .idea to VCS, the refactor guide
2025-03-29 14:13:15 +08:00

167 lines
5.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
import torch
import torchaudio
from typing import List
from pypinyin import lazy_pinyin
from pypinyin_dict.phrase_pinyin_data import cc_cedict
from torchaudio.transforms import Resample
from tqdm import tqdm
from utils.ttml import TTMLGenerator
from utils.audio import get_audio_duration
# 初始化设备、模型、分词器、对齐器等
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bundle = torchaudio.pipelines.MMS_FA
model = bundle.get_model().to(device)
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()
cc_cedict.load()
def timestamp(seconds: float) -> str:
"""将浮点数秒钟转换为TTML时间戳格式"""
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
seconds = seconds % 60
milliseconds = int((seconds % 1) * 1000)
seconds = int(seconds)
return f"{hours:02}:{minutes:02}:{seconds:02}.{milliseconds:03}"
def compute_alignments(waveform: torch.Tensor, transcript: List[str]):
with torch.inference_mode():
emission, _ = model(waveform.to(device))
token_spans = aligner(emission[0], tokenizer(transcript))
return emission, token_spans
def parse_lrc(lrc_file, audio_len):
"""解析LRC文件返回一个包含时间戳和歌词的列表"""
with open(lrc_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
lrc_data = []
for line in lines:
# 使用正则表达式匹配时间戳和歌词
match = re.match(r'\[(\d+):(\d+\.\d+)\](.*)', line)
if match:
minutes = int(match.group(1))
seconds = float(match.group(2))
lyric = match.group(3).strip()
lyric = lyric.replace(" ", "")
timestamp = minutes * 60 + seconds
lrc_data.append((lyric, timestamp))
for i, (lyric, start_time) in enumerate(lrc_data):
# Skip empty line
if lyric.strip() == "":
continue
if i < len(lrc_data) - 1:
end_time = lrc_data[i + 1][1]
else:
end_time = audio_len
lrc_data[i] = (lyric, start_time, end_time)
# Filter empty lines again
lrc_data = [line for line in lrc_data if line[0].strip() != ""]
return lrc_data
def extract_numbers_from_files(directory):
"""
读取给定目录,提取文件名中的数字部分,并返回一个包含这些数字的列表。
:param directory: 目录路径
:return: 包含数字的列表
"""
numbers = []
pattern = re.compile(r'line-(\d+)\.wav')
try:
for filename in os.listdir(directory):
match = pattern.match(filename)
if match:
number = int(match.group(1))
numbers.append(number)
except Exception as e:
print(f"Error reading directory: {e}")
return None
return numbers
def process_line(line_idx, start_time):
with open(f"./temp/lines/line-{line_idx}.txt", "r") as f:
text = f.read()
waveform, sample_rate = torchaudio.load(f"./temp/lines/line-{line_idx}.wav")
waveform = waveform[0:1]
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
waveform = resampler(waveform)
text_pinyin = lazy_pinyin(text)
text_normalized = " ".join(text_pinyin)
transcript = text_normalized.split()
emission, token_spans = compute_alignments(waveform, transcript)
num_frames = emission.size(1)
ratio = waveform.size(1) / num_frames
words = []
for i in range(len(token_spans)):
spans = token_spans[i]
x0 = start_time + int(ratio * spans[0].start) / 16000
x1 = start_time + int(ratio * spans[-1].end) / 16000
words.append({
"word": text[i],
"start": x0,
"end": x1
})
idx=0
for item in words:
if idx == len(words) - 1:
break
item["end"] = words[idx + 1]["start"]
idx+=1
result = []
for word in words:
result.append((word["word"], timestamp(word["start"]), timestamp(word["end"])))
return result
def align(audio_file: str, lrc_file: str, output_ttml: str, segments_dir: str = "./temp/lines"):
"""
对齐音频和歌词并输出TTML文件。
:param audio_file: 音频文件路径
:param lrc_file: LRC歌词文件路径
:param output_ttml: 输出TTML文件路径
:param segments_dir: 存放分割后音频片段的目录,默认为"./segments"
"""
# 获取音频时长
duration = get_audio_duration(audio_file)
# 解析LRC文件
lrc_data = parse_lrc(lrc_file, duration)
# 提取要处理的行号
lines_to_process = sorted(extract_numbers_from_files(segments_dir))
# 创建TTML生成器实例
ttml_generator = TTMLGenerator(duration=timestamp(duration))
i = 0
for line_num in tqdm(lines_to_process):
start_time = lrc_data[i][1]
end_time = lrc_data[i][2]
result = process_line(line_num, start_time)
ttml_generator.add_lyrics(
begin=timestamp(start_time), end=timestamp(end_time), agent="v1", itunes_key=f"L{i+1}",
words=result
)
i += 1
# 保存TTML文件
ttml_generator.save(output_ttml)
if __name__ == "__main__":
align("./data/1.flac", "./data/1.lrc", "./data/output.ttml", "./temp/lines")