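# Module summary (inferred from the code in this file):
"""Training and validation driver for Uyghur ASR experiments.

Builds one of several CTC, sequence-to-sequence (S2S), or joint CTC+attention
(JOINT) models, trains it in mini-epochs, logs progress, and checkpoints the
best model by character error rate (CER).
"""
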
import math
import numpy as np
import os
import sys
import torch
import torch.nn.functional as F
import torch.nn as nn

from data import SpeechDataset, SpeechDataLoader, featurelen, cer_wer, cer, wer
from uyghur import uyghur_latin
from tqdm import tqdm

from GCGCResM import GCGCResM
from GCGCRes import GCGCRes
from GCGCRes1 import GCGCRes1
from GCGCRes2 import GCGCRes2
from QuartzNet import QuartzNet15x5, QuartzNet10x5, QuartzNet5x5
from UDS2W2L import UDS2W2L
from UDS2W2L3 import UDS2W2L3
from UDS2W2L5 import UDS2W2L5
from UDS2W2L50 import UDS2W2L50
from UDS2W2L8 import UDS2W2L8
from UDS2W2L80 import UDS2W2L80
#from FuncNet1 import FuncNet1
from UArilash0 import UArilash0
from UArilash1 import UArilash1

from UFormerCTC1 import UFormerCTC1
from UFormerCTC2 import UFormerCTC2
from UFormerCTC3 import UFormerCTC3
from UFormerCTC5 import UFormerCTC5
from UFormerCTC3N import UFormerCTC3N
from uformer1dgru import UFormer1DGRU
from UFormerCTC1N import UFormerCTC1N

from ConfModelN import ConfModelN
from ConfModelM import ConfModelM
from ConfModelM2D import ConfModelM2D
from tiny_wav2letter import TinyWav2Letter
from UDS2W2L050 import UDS2W2L050

from UDeepSpeech import UDeepSpeech
from Conv1D3InDS2 import Conv1D3InDS2
from UDS2W2LGLU0 import UDS2W2LGLU0
from UDS2W2LGLU import UDS2W2LGLU
from UDS2W2LGLU8 import UDS2W2LGLU8

from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, StepLR
import random

from torch.cuda.amp import GradScaler

# Fix seed
# seed = 17
# np.random.seed(seed)
# torch.manual_seed(seed)
# random.seed(seed)

class CustOpt:
    """Adam wrapped with a cosine-annealing LR schedule stepped per batch."""
    def __init__(self, params, datalen, lr, min_lr=None):
        if min_lr is None:
            min_lr = lr

        self.optimizer = torch.optim.Adam(params, lr=lr)  #, weight_decay=0.00001
        #self.optimizer = torch.optim.Adamax(params, lr=lr, weight_decay=0.00001)
        #self.optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=0.00001)
        #self.optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.00001)
        self._step = 0
        # Anneal from lr down to min_lr over `datalen` scheduler steps.
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=datalen, eta_min=min_lr)
        #self.scheduler = StepLR(self.optimizer, step_size=10, gamma=0.1)
        #self.scheduler = CyclicLR(self.optimizer, base_lr=min_lr, max_lr=lr, step_size_up=datalen)

    def step(self):
        self.optimizer.step()
        self.scheduler.step()
        rate = self.scheduler.get_last_lr()[0]
        return rate

    def zero_grad(self):
        self.optimizer.zero_grad()

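# Illustrative usage sketch of CustOpt (hypothetical datalen; the real call
# site in the main block below passes mini_epoch_length//2):
#
#   optimizer = CustOpt(model.parameters(), datalen=500, lr=0.0001, min_lr=0.00001)
#   optimizer.zero_grad()
#   loss.backward()
#   lr = optimizer.step()  # steps Adam and the cosine schedule, returns the current LR
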
# outputs format = B x F x T
def calctc_loss(outputs, targets, output_lengths, target_lengths):
    loss = F.ctc_loss(outputs.permute(2, 0, 1).contiguous(), targets, output_lengths, target_lengths,
                      blank=uyghur_latin.pad_idx, reduction='mean', zero_infinity=True)
    return loss

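# Shape note (a sketch with made-up sizes): F.ctc_loss expects log-probabilities
# of shape T x B x C, so the B x F x T network output is rearranged with
# permute(2, 0, 1) above. Something like:
#
#   out = torch.randn(4, uyghur_latin.vocab_size, 100).log_softmax(dim=1)  # B x C x T
#   tgt = torch.randint(1, uyghur_latin.vocab_size, (4, 20))               # B x S
#   calctc_loss(out, tgt, torch.full((4,), 100), torch.full((4,), 20))
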
def cal_loss(pred, gold):
    """
    Calculate the cross-entropy loss for S2S/attention outputs.
    args:
        pred: B x T x C
        gold: B x T
    """
    gold = gold.contiguous().view(-1)                # (B*T)
    pred = pred.contiguous().view(-1, pred.size(2))  # (B*T) x C
    loss = F.cross_entropy(pred, gold, ignore_index=uyghur_latin.pad_idx, reduction="mean")
    return loss

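# Note (inferred from the code above): batch and time are flattened so that
# F.cross_entropy scores one prediction per token, and padded positions are
# excluded from the mean via ignore_index=uyghur_latin.pad_idx.
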
def validate(model, valid_loader):
    # Relies on module-level globals set in the main block: device, model_type, log_name.
    chars = 0
    words = 0
    e_chars = 0
    e_words = 0
    avg_loss = 0
    iter_cnt = 0
    msg = ""

    cer_val = 0.0

    model.eval()
    with torch.no_grad():
        tlen = len(valid_loader)
        vbar = tqdm(iter(valid_loader), leave=True, total=tlen)
        for inputs, targets, input_lengths, target_lengths, _ in vbar:

            inputs = inputs.to(device)
            targets = targets.to(device)
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)

            if model_type == 'CTC':
                outputs, output_lengths = model(inputs, input_lengths)
                loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
            elif model_type == 'S2S':
                output_lengths = 0
                outputs, tgt = model(inputs, input_lengths, targets)
                loss = cal_loss(outputs, tgt)
            elif model_type == 'JOINT':
                output_lengths = 0
                outputs, tgt = model(inputs, input_lengths, targets)
                loss1 = cal_loss(outputs, tgt)
                loss_ctc = calctc_loss(model.ctcOut, targets, model.ctcLen, target_lengths)
                #loss = loss1*0.6 + loss_ctc*0.4
                loss = loss1*0.78 + loss_ctc*0.22
                #loss = loss1*0.22 + loss_ctc*0.78

            preds = model.greedydecode(outputs, output_lengths)
            targets = [uyghur_latin.decode(target) for target in targets]

            for pred, src in zip(preds, targets):
                e_char_cnt, char_cnt = cer(pred, src)
                e_word_cnt, word_cnt = wer(pred, src)
                e_chars += e_char_cnt
                e_words += e_word_cnt

                chars += char_cnt
                words += word_cnt

            iter_cnt += 1
            avg_loss += loss.item()

            msg = f" VALIDATION: [CER:{e_chars/chars:.2%} ({e_chars}/{chars} letters) WER:{e_words/words:.2%} ({e_words}/{words} words), Avg loss:{avg_loss/iter_cnt:.4f}]"
            vbar.set_description(msg)

        vbar.close()

        cer_val = e_chars/chars

        with open(log_name, 'a', encoding='utf-8') as fp:
            fp.write(msg + "\n")

        # Print up to 3 examples from the last validation batch
        result = ""
        result_cnt = 0
        chars = 0
        words = 0
        e_chars = 0
        e_words = 0
        for pred, src in zip(preds, targets):
            e_char_cnt, char_cnt = cer(pred, src)
            e_word_cnt, word_cnt = wer(pred, src)
            e_chars += e_char_cnt
            e_words += e_word_cnt
            chars += char_cnt
            words += word_cnt
            result += f"   O:{src}\n"
            result += f"   P:{pred}\n"
            result += f"   CER: {e_char_cnt/char_cnt:.2%} ({e_char_cnt}/{char_cnt} letters), WER: {e_word_cnt/word_cnt:.2%} ({e_word_cnt}/{word_cnt} words)\n"
            result_cnt += 1
            if result_cnt >= 3:
                break

    print(result)
    return cer_val

def train(model, train_loader):
    # Relies on module-level globals set in the main block:
    # device, model_type, optimizer, mini_epoch_length, epoch, log_name.
    total_loss = 0
    iter_cnt = 0
    msg = ''
    model.train()
    pbar = tqdm(iter(train_loader), leave=True, total=mini_epoch_length)
    for data in pbar:
        optimizer.zero_grad()
        inputs, targets, input_lengths, target_lengths, _ = data
        inputs = inputs.to(device)
        targets = targets.to(device)
        input_lengths = input_lengths.to(device)
        target_lengths = target_lengths.to(device)

        if model_type == 'CTC':
            outputs, output_lengths = model(inputs, input_lengths)
            loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
        elif model_type == 'S2S':
            output_lengths = 0
            outputs, tgt = model(inputs, input_lengths, targets)
            loss = cal_loss(outputs, tgt)
        elif model_type == 'JOINT':
            output_lengths = 0
            outputs, tgt = model(inputs, input_lengths, targets)
            loss1 = cal_loss(outputs, tgt)
            loss_ctc = calctc_loss(model.ctcOut, targets, model.ctcLen, target_lengths)
            #loss = loss1*0.6 + loss_ctc*0.4
            loss = loss1*0.78 + loss_ctc*0.22
            #loss = loss1*0.22 + loss_ctc*0.78

        loss.backward()
        lr = optimizer.step()
        total_loss += loss.item()
        iter_cnt += 1

        msg = f'[LR: {lr: .6f} Loss: {loss.item(): .5f}, Avg loss: {(total_loss/iter_cnt): .5f}]'
        pbar.set_description(msg)
        #torch.cuda.empty_cache()
        # Cut the epoch short once mini_epoch_length batches have been seen.
        if iter_cnt > mini_epoch_length:
            break

    pbar.close()
    with open(log_name, 'a', encoding='utf-8') as fp:
        msg = f'Epoch[{(epoch+1):d}]:\t{msg}\n'
        fp.write(msg)

def GetModel():

    if model_type == 'CTC':
        #model = GCGCResM(num_features_input = featurelen)
        #model = UDS2W2L(num_features_input = featurelen)
        #model = GCGCRes2(num_features_input = featurelen)
        #model = GCGCRes(num_features_input = featurelen)   # Already running elsewhere
        #model = GCGCRes1(num_features_input = featurelen)  # Already running elsewhere

        #model = UDS2W2L50(num_features_input = featurelen)
        #model = UDS2W2L80(num_features_input = featurelen)
        #model = ConfModel(num_features_input = featurelen)

        #model = QuartzNet15x5(num_features_input = featurelen)
        #model = QuartzNet10x5(num_features_input = featurelen)
        #model = QuartzNet5x5(num_features_input = featurelen)

        #model = UArilash1(num_features_input = featurelen)
        #model = UDeepSpeech(num_features_input = featurelen)
        #model = UDS2W2L3(num_features_input = featurelen)

        #model = TinyWav2Letter(num_features_input = featurelen)
        #model = ConfModelM(num_features_input = featurelen)

        #model = UDS2W2L050(num_features_input = featurelen)
        #model = Conv1D3InDS2(num_features_input = featurelen)
        #model = UDS2W2LGLU(num_features_input = featurelen)
        model = UDS2W2LGLU8(num_features_input = featurelen)

    elif model_type == 'S2S':
        #model = UFormer(num_features_input = featurelen)
        #model = UFormer1DGRU(num_features_input = featurelen)

        #model = UFormerCTC(num_features_input = featurelen)
        #model = UFormerCTC3(num_features_input = featurelen)
        model = UFormerCTC3N(num_features_input = featurelen)
        #model = UFormerCTC1N(num_features_input = featurelen)

    elif model_type == 'JOINT':
        #model = UFormer(num_features_input = featurelen)
        #model = UFormer1DGRU(num_features_input = featurelen)

        #model = UFormerCTC(num_features_input = featurelen)
        #model = UFormerCTC3(num_features_input = featurelen)
        #model = UFormerCTC3N(num_features_input = featurelen)
        model = UFormerCTC1N(num_features_input = featurelen)

    return model

# Models to try:
#   UFormerCTC3N
#   UDS2W2L5
#   GCGCRes1

if __name__ == "__main__":
    device = "cuda"
    os.makedirs('./results', exist_ok=True)

    model_type = 'CTC'  # one of 'S2S', 'JOINT', 'CTC'

    #train_file = 'uyghur_train.csv'
    train_file = 'uyghur_thuyg20_train_small.csv'
    test_file = 'uyghur_thuyg20_test_small.csv'

    train_set = SpeechDataset(train_file, augumentation=False)
    train_loader = SpeechDataLoader(train_set, num_workers=5, pin_memory=True, shuffle=True, batch_size=24)

    validation_set = SpeechDataset(test_file, augumentation=False)
    validation_loader = SpeechDataLoader(validation_set, num_workers=5, pin_memory=True, shuffle=True, batch_size=24)

    print("=" * 50)
    msg = f"  Training Set  : {train_file}, {len(train_set)} samples" + "\n"
    msg += f"  Validation Set: {test_file}, {len(validation_set)} samples" + "\n"
    msg += f"  Vocab Size    : {uyghur_latin.vocab_size}"

    print(msg)
    model = GetModel()
    print("=" * 50)

    log_name = model.checkpoint + '.log'
    with open(log_name, 'a', encoding='utf-8') as fp:
        fp.write(msg + '\n')

    train_set.Raw = model.Raw        # True if the model consumes raw waveform input
    validation_set.Raw = model.Raw   # True if the model consumes raw waveform input

    model = model.to(device)

    # Start training and validation
    testfile = ["test1.wav", "test2.wav", "test3.wav", "test4.wav", "test5.wav", "test6.wav"]
    start_epoch = model.trained_epochs
    # A "mini epoch" caps how many batches train() sees; halve it for large sets.
    mini_epoch_length = len(train_loader)
    if mini_epoch_length > 1000:
        mini_epoch_length = mini_epoch_length // 2

    optimizer = CustOpt(model.parameters(), mini_epoch_length // 2, lr=0.0001, min_lr=0.00001)
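    # Per-epoch flow below (as written): decode the fixed test WAVs for a
    # qualitative log, train for one mini-epoch, then validate, checkpoint,
    # and track the best CER.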
    for epoch in range(start_epoch, 1000):
        torch.cuda.empty_cache()
        model.eval()
        msg = ""
        for afile in testfile:
            text = model.predict(afile, device)
            text = f"{afile}-->{text}\n"
            print(text, end="")
            msg += text

        with open(log_name, 'a', encoding='utf-8') as fp:
            fp.write(msg + '\n')

        print("=" * 50)
        print(f"Training Epoch[{(epoch+1):d}]:")
        train(model, train_loader)
        if (epoch + 1) % 1 == 0:  # validate every epoch
            print("Validating:")
            model.save((epoch + 1))
            curcer = validate(model, validation_loader)
            if curcer < model.best_cer:
                model.best_cer = curcer
                model.save((epoch + 1), best=True)

        model.save((epoch + 1))