import math
import numpy as np
import os
import sys
import torch
import torch.nn.functional as F
import torch.nn as nn
from data import SpeechDataset, SpeechDataLoader, featurelen, cer_wer, cer, wer
from uyghur import uyghur_latin
from tqdm import tqdm

from GCGCResM import GCGCResM
from GCGCRes import GCGCRes
from GCGCRes1 import GCGCRes1
from GCGCRes2 import GCGCRes2
from QuartzNet import QuartzNet15x5, QuartzNet10x5, QuartzNet5x5
from UDS2W2L import UDS2W2L
from UDS2W2L3 import UDS2W2L3
from UDS2W2L5 import UDS2W2L5
from UDS2W2L50 import UDS2W2L50
from UDS2W2L8 import UDS2W2L8
from UDS2W2L80 import UDS2W2L80
#from FuncNet1 import FuncNet1
from UArilash0 import UArilash0
from UArilash1 import UArilash1
from UFormerCTC1 import UFormerCTC1
from UFormerCTC2 import UFormerCTC2
from UFormerCTC3 import UFormerCTC3
from UFormerCTC5 import UFormerCTC5
from UFormerCTC3N import UFormerCTC3N
from uformer1dgru import UFormer1DGRU
from UFormerCTC1N import UFormerCTC1N
from ConfModelN import ConfModelN
from ConfModelM import ConfModelM
from ConfModelM2D import ConfModelM2D
from tiny_wav2letter import TinyWav2Letter
from UDS2W2L050 import UDS2W2L050
from UDeepSpeech import UDeepSpeech
from Conv1D3InDS2 import Conv1D3InDS2
from UDS2W2LGLU0 import UDS2W2LGLU0
from UDS2W2LGLU import UDS2W2LGLU
from UDS2W2LGLU8 import UDS2W2LGLU8

from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, StepLR
import random
from torch.cuda.amp import GradScaler

# Fix seed
# seed = 17
# np.random.seed(seed)
# torch.manual_seed(seed)
# random.seed(seed)

class CustOpt:
    """Adam optimizer with a cosine-annealing LR schedule, stepped per batch."""
    def __init__(self, params, datalen, lr, min_lr=None):
        if min_lr is None:
            min_lr = lr

        self.optimizer = torch.optim.Adam(params, lr=lr)  #, weight_decay=0.00001
        #self.optimizer = torch.optim.Adamax(params, lr=lr, weight_decay=0.00001)
        #self.optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=0.00001)
        #self.optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.00001)
        self._step = 0
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=datalen, eta_min=min_lr)
        #self.scheduler = StepLR(self.optimizer, step_size=10, gamma=0.1)
        #self.scheduler = CyclicLR(self.optimizer, base_lr=min_lr, max_lr=lr, step_size_up=datalen)

    def step(self):
        self.optimizer.step()
        self.scheduler.step()
        rate = self.scheduler.get_last_lr()[0]
        return rate

    def zero_grad(self):
        self.optimizer.zero_grad()


# outputs format: B x F x T
def calctc_loss(outputs, targets, output_lengths, target_lengths):
    # F.ctc_loss expects T x B x C, hence the permute.
    # The pad index doubles as the CTC blank label here.
    loss = F.ctc_loss(outputs.permute(2, 0, 1).contiguous(), targets, output_lengths, target_lengths,
                      blank=uyghur_latin.pad_idx, reduction='mean', zero_infinity=True)
    return loss


def cal_loss(pred, gold):
    """
    Calculate the cross-entropy loss, ignoring padding positions.
    args:
        pred: B x T x C
        gold: B x T
    """
    gold = gold.contiguous().view(-1)                # (B*T)
    pred = pred.contiguous().view(-1, pred.size(2))  # (B*T) x C
    loss = F.cross_entropy(pred, gold, ignore_index=uyghur_latin.pad_idx, reduction="mean")
    return loss


def validate(model, valid_loader):
    chars = 0
    words = 0
    e_chars = 0
    e_words = 0
    avg_loss = 0
    iter_cnt = 0
    msg = ""
    cer_val = 0.0

    model.eval()
    with torch.no_grad():
        tlen = len(valid_loader)
        vbar = tqdm(iter(valid_loader), leave=True, total=tlen)
        for inputs, targets, input_lengths, target_lengths, _ in vbar:
            inputs = inputs.to(device)
            targets = targets.to(device)
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)

            if model_type == 'CTC':
                outputs, output_lengths = model(inputs, input_lengths)
                loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
            elif model_type == 'S2S':
                output_lengths = 0
                outputs, tgt = model(inputs, input_lengths, targets)
                loss = cal_loss(outputs, tgt)
            elif model_type == 'JOINT':
                output_lengths = 0
                outputs, tgt = model(inputs, input_lengths, targets)
                loss1 = cal_loss(outputs, tgt)
                loss_ctc = calctc_loss(model.ctcOut, targets, model.ctcLen, target_lengths)
                #loss = loss1*0.6 + loss_ctc*0.4
                loss = loss1*0.78 + loss_ctc*0.22
                #loss = loss1*0.22 + loss_ctc*0.78

            preds = model.greedydecode(outputs, output_lengths)
            targets = [uyghur_latin.decode(target) for target in targets]

            for pred, src in zip(preds, targets):
                e_char_cnt, char_cnt = cer(pred, src)
                e_word_cnt, word_cnt = wer(pred, src)
                e_chars += e_char_cnt
                e_words += e_word_cnt
                chars += char_cnt
                words += word_cnt

            iter_cnt += 1
            avg_loss += loss.item()

            msg = f" VALIDATION: [CER:{e_chars/chars:.2%} ({e_chars}/{chars} letters) WER:{e_words/words:.2%} ({e_words}/{words} words), Avg loss:{avg_loss/iter_cnt:.4f}]"
            vbar.set_description(msg)

        vbar.close()
        cer_val = e_chars / chars

        with open(log_name, 'a', encoding='utf-8') as fp:
            fp.write(msg + "\n")

        # Print up to 3 sample results from the last validation batch
        result = ""
        result_cnt = 0
        chars = 0
        words = 0
        e_chars = 0
        e_words = 0
        for pred, src in zip(preds, targets):
            e_char_cnt, char_cnt = cer(pred, src)
            e_word_cnt, word_cnt = wer(pred, src)
            e_chars += e_char_cnt
            e_words += e_word_cnt
            chars += char_cnt
            words += word_cnt
            result += f"   O:{src}\n"
            result += f"   P:{pred}\n"
            result += f"   CER: {e_char_cnt/char_cnt:.2%} ({e_char_cnt}/{char_cnt} letters), WER: {e_word_cnt/word_cnt:.2%} ({e_word_cnt}/{word_cnt} words)\n"
            result_cnt += 1
            if result_cnt >= 3:
                break
        print(result)

    return cer_val
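
# For reference only: model.greedydecode() is defined on each model class and is
# not shown in this file. The sketch below illustrates one common way greedy CTC
# decoding works, under assumptions not taken from this file: `outputs` is a
# B x C x T tensor of per-frame scores, and uyghur_latin.decode() accepts a list
# of label indices. The models' own implementation may differ.
def _greedy_ctc_decode_sketch(outputs, output_lengths):
    best = outputs.argmax(dim=1)  # B x T: most likely label per frame
    hyps = []
    for b in range(best.size(0)):
        prev = None
        idxs = []
        for idx in best[b, : output_lengths[b]].tolist():
            # Collapse repeated labels, then drop blanks
            # (pad_idx serves as the CTC blank in this codebase).
            if idx != prev and idx != uyghur_latin.pad_idx:
                idxs.append(idx)
            prev = idx
        hyps.append(uyghur_latin.decode(idxs))
    return hyps
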
def train(model, train_loader):
    total_loss = 0
    iter_cnt = 0
    msg = ''
    model.train()
    pbar = tqdm(iter(train_loader), leave=True, total=mini_epoch_length)
    for data in pbar:
        optimizer.zero_grad()
        inputs, targets, input_lengths, target_lengths, _ = data
        inputs = inputs.to(device)
        targets = targets.to(device)
        input_lengths = input_lengths.to(device)
        target_lengths = target_lengths.to(device)

        if model_type == 'CTC':
            outputs, output_lengths = model(inputs, input_lengths)
            loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
        elif model_type == 'S2S':
            output_lengths = 0
            outputs, tgt = model(inputs, input_lengths, targets)
            loss = cal_loss(outputs, tgt)
        elif model_type == 'JOINT':
            output_lengths = 0
            outputs, tgt = model(inputs, input_lengths, targets)
            loss1 = cal_loss(outputs, tgt)
            loss_ctc = calctc_loss(model.ctcOut, targets, model.ctcLen, target_lengths)
            #loss = loss1*0.6 + loss_ctc*0.4
            loss = loss1*0.78 + loss_ctc*0.22
            #loss = loss1*0.22 + loss_ctc*0.78

        loss.backward()
        lr = optimizer.step()

        total_loss += loss.item()
        iter_cnt += 1
        msg = f'[LR: {lr: .6f} Loss: {loss.item(): .5f}, Avg loss: {(total_loss/iter_cnt): .5f}]'
        pbar.set_description(msg)
        #torch.cuda.empty_cache()
        if iter_cnt > mini_epoch_length:
            break

    pbar.close()
    with open(log_name, 'a', encoding='utf-8') as fp:
        msg = f'Epoch[{(epoch+1):d}]:\t{msg}\n'
        fp.write(msg)


def GetModel():
    if model_type == 'CTC':
        #model = GCGCResM(num_features_input = featurelen)
        #model = UDS2W2L(num_features_input = featurelen)
        #model = GCGCRes2(num_features_input = featurelen)
        #model = GCGCRes(num_features_input = featurelen)    # Training elsewhere
        #model = GCGCRes1(num_features_input = featurelen)   # Training elsewhere
        #model = UDS2W2L50(num_features_input = featurelen)
        #model = UDS2W2L80(num_features_input = featurelen)
        #model = ConfModel(num_features_input = featurelen)
        #model = QuartzNet15x5(num_features_input = featurelen)
        #model = QuartzNet10x5(num_features_input = featurelen)
        #model = QuartzNet5x5(num_features_input = featurelen)
        #model = UArilash1(num_features_input = featurelen)
        #model = UDeepSpeech(num_features_input = featurelen)
        #model = UDS2W2L3(num_features_input = featurelen)
        #model = TinyWav2Letter(num_features_input = featurelen)
        #model = ConfModelM(num_features_input = featurelen)
        #model = UDS2W2L050(num_features_input = featurelen)
        #model = Conv1D3InDS2(num_features_input = featurelen)
        #model = UDS2W2LGLU(num_features_input = featurelen)
        model = UDS2W2LGLU8(num_features_input = featurelen)
    elif model_type == 'S2S':
        #model = UFormer(num_features_input = featurelen)
        #model = UFormer1DGRU(num_features_input = featurelen)
        #model = UFormerCTC(num_features_input = featurelen)
        #model = UFormerCTC3(num_features_input = featurelen)
        model = UFormerCTC3N(num_features_input = featurelen)
        #model = UFormerCTC1N(num_features_input = featurelen)
    elif model_type == 'JOINT':
        #model = UFormer(num_features_input = featurelen)
        #model = UFormer1DGRU(num_features_input = featurelen)
        #model = UFormerCTC(num_features_input = featurelen)
        #model = UFormerCTC3(num_features_input = featurelen)
        #model = UFormerCTC3N(num_features_input = featurelen)
        model = UFormerCTC1N(num_features_input = featurelen)
    return model
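
# NOTE: GradScaler is imported above but never used. For reference only, a
# mixed-precision version of the inner training step (CTC branch) could look
# like the commented sketch below. This is a sketch under assumptions, not this
# script's actual behavior; scaler.step() needs the raw torch optimizer that
# CustOpt wraps, so the LR scheduler must then be stepped separately.
#
#   scaler = GradScaler()
#   ...
#   optimizer.zero_grad()
#   with torch.cuda.amp.autocast():
#       outputs, output_lengths = model(inputs, input_lengths)
#       loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
#   scaler.scale(loss).backward()
#   scaler.step(optimizer.optimizer)  # the underlying torch.optim.Adam
#   scaler.update()
#   optimizer.scheduler.step()
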
# Models to test:
#   UFormerCTC3N
#   UDS2W2L5
#   GCGCRes1

if __name__ == "__main__":
    device = "cuda"
    os.makedirs('./results', exist_ok=True)

    model_type = 'CTC'  # 'S2S', 'JOINT', or 'CTC'

    #train_file = 'uyghur_train.csv'
    train_file = 'uyghur_thuyg20_train_small.csv'
    test_file = 'uyghur_thuyg20_test_small.csv'

    train_set = SpeechDataset(train_file, augumentation=False)
    train_loader = SpeechDataLoader(train_set, num_workers=5, pin_memory=True, shuffle=True, batch_size=24)

    validation_set = SpeechDataset(test_file, augumentation=False)
    validation_loader = SpeechDataLoader(validation_set, num_workers=5, pin_memory=True, shuffle=True, batch_size=24)

    print("=" * 50)
    msg  = f"    Training Set: {train_file}, {len(train_set)} samples" + "\n"
    msg += f"  Validation Set: {test_file}, {len(validation_set)} samples" + "\n"
    msg += f"      Vocab Size: {uyghur_latin.vocab_size}"
    print(msg)

    model = GetModel()
    print("=" * 50)

    log_name = model.checkpoint + '.log'
    with open(log_name, 'a', encoding='utf-8') as fp:
        fp.write(msg + '\n')

    train_set.Raw = model.Raw       # True if the model consumes raw waveform input
    validation_set.Raw = model.Raw  # True if the model consumes raw waveform input
    model = model.to(device)

    # Start training and validation
    testfile = ["test1.wav", "test2.wav", "test3.wav", "test4.wav", "test5.wav", "test6.wav"]
    start_epoch = model.trained_epochs
    mini_epoch_length = len(train_loader)
    if mini_epoch_length > 1000:
        mini_epoch_length = mini_epoch_length // 2

    optimizer = CustOpt(model.parameters(), mini_epoch_length // 2, lr=0.0001, min_lr=0.00001)
    for epoch in range(start_epoch, 1000):
        torch.cuda.empty_cache()

        # Print and log predictions on the fixed test files before each epoch
        model.eval()
        msg = ""
        for afile in testfile:
            text = model.predict(afile, device)
            text = f"{afile}-->{text}\n"
            print(text, end="")
            msg += text

        with open(log_name, 'a', encoding='utf-8') as fp:
            fp.write(msg + '\n')

        print("=" * 50)
        print(f"Training Epoch[{(epoch+1):d}]:")
        train(model, train_loader)

        if (epoch + 1) % 1 == 0:
            print("Validating:")
            model.save(epoch + 1)
            curcer = validate(model, validation_loader)
            if curcer < model.best_cer:
                model.best_cer = curcer
                model.save(epoch + 1, best=True)
            model.save(epoch + 1)
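
# Usage (assumed entry point; the actual file name may differ):
#   python train.py
# The script resumes from model.trained_epochs, appends to <checkpoint>.log,
# and writes regular checkpoints plus a best-CER checkpoint via model.save().
# The test*.wav files listed above must exist alongside the script for the
# per-epoch sample predictions to run.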