Delete train.py
This commit is contained in:
parent 31257e0ad9
commit 006896af46

train.py (361 lines deleted)

@@ -1,361 +0,0 @@
import math
import numpy as np
import os
import sys
import torch
import torch.nn.functional as F
import torch.nn as nn

from data import SpeechDataset, SpeechDataLoader, featurelen, cer_wer, cer, wer
from uyghur import uyghur_latin
from tqdm import tqdm


from GCGCResM import GCGCResM
from GCGCRes import GCGCRes
from GCGCRes1 import GCGCRes1
from GCGCRes2 import GCGCRes2
from QuartzNet import QuartzNet15x5, QuartzNet10x5, QuartzNet5x5
from UDS2W2L import UDS2W2L
from UDS2W2L3 import UDS2W2L3
from UDS2W2L5 import UDS2W2L5
from UDS2W2L50 import UDS2W2L50
from UDS2W2L8 import UDS2W2L8
from UDS2W2L80 import UDS2W2L80
#from FuncNet1 import FuncNet1
from UArilash0 import UArilash0
from UArilash1 import UArilash1

from UFormerCTC1 import UFormerCTC1
from UFormerCTC2 import UFormerCTC2
from UFormerCTC3 import UFormerCTC3
from UFormerCTC5 import UFormerCTC5
from UFormerCTC3N import UFormerCTC3N
from uformer1dgru import UFormer1DGRU
from UFormerCTC1N import UFormerCTC1N

from ConfModelN import ConfModelN
from ConfModelM import ConfModelM
from ConfModelM2D import ConfModelM2D
from tiny_wav2letter import TinyWav2Letter
from UDS2W2L050 import UDS2W2L050

from UDeepSpeech import UDeepSpeech
from Conv1D3InDS2 import Conv1D3InDS2
from UDS2W2LGLU0 import UDS2W2LGLU0
from UDS2W2LGLU import UDS2W2LGLU
from UDS2W2LGLU8 import UDS2W2LGLU8

from torch.optim.lr_scheduler import CosineAnnealingLR, CyclicLR, StepLR
import random

from torch.cuda.amp import GradScaler

# Fix seed
# seed = 17
# np.random.seed(seed)
# torch.manual_seed(seed)
# random.seed(seed)

class CustOpt:
    def __init__(self, params, datalen, lr, min_lr=None):
        if min_lr is None:
            min_lr = lr

        self.optimizer = torch.optim.Adam(params, lr=lr)  #, weight_decay=0.00001
        #self.optimizer = torch.optim.Adamax(params, lr=lr, weight_decay=0.00001)
        #self.optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=0.00001)
        #self.optimizer = torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=0.00001)
        self._step = 0
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=datalen, eta_min=min_lr)
        #self.scheduler = StepLR(self.optimizer, step_size=10, gamma=0.1)
        #self.scheduler = CyclicLR(self.optimizer, base_lr=min_lr, max_lr=lr)

    def step(self):
        self.optimizer.step()
        self.scheduler.step()
        rate = self.scheduler.get_last_lr()[0]
        return rate

    def zero_grad(self):
        self.optimizer.zero_grad()
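
# NB: step() advances both Adam and the cosine scheduler once per batch, so
# T_max (datalen) is counted in optimizer steps, not epochs, and the current
# learning rate is returned so the training loop can log it.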

#outputs format = B x F x T
def calctc_loss(outputs, targets, output_lengths, target_lengths):
    loss = F.ctc_loss(outputs.permute(2,0,1).contiguous(), targets, output_lengths, target_lengths, blank=uyghur_latin.pad_idx, reduction='mean', zero_infinity=True)
    return loss
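
# NB: F.ctc_loss expects (T, B, C) log-probabilities, hence the permute(2,0,1)
# above; the vocabulary's pad index doubles as the CTC blank, and
# zero_infinity=True zeroes out infinite losses from inputs shorter than their targets.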

def cal_loss(pred, gold):
    """
    Calculate cross-entropy loss over non-padding positions
    args:
        pred: B x T x C
        gold: B x T
    """
    gold = gold.contiguous().view(-1)               # (B*T)
    pred = pred.contiguous().view(-1, pred.size(2)) # (B*T) x C
    loss = F.cross_entropy(pred, gold, ignore_index=uyghur_latin.pad_idx, reduction="mean")
    return loss
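
# NB: both tensors are flattened to (B*T) before cross-entropy, and
# ignore_index skips padded target positions so they contribute no loss.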


def validate(model, valid_loader):
    chars = 0
    words = 0
    e_chars = 0
    e_words = 0
    avg_loss = 0
    iter_cnt = 0
    msg = ""

    cer_val = 0.0

    model.eval()
    with torch.no_grad():
        tlen = len(valid_loader)
        vbar = tqdm(iter(valid_loader), leave=True, total=tlen)
        for inputs, targets, input_lengths, target_lengths, _ in vbar:

            inputs = inputs.to(device)
            targets = targets.to(device)
            input_lengths = input_lengths.to(device)
            target_lengths = target_lengths.to(device)

            if model_type == 'CTC':
                outputs, output_lengths = model(inputs, input_lengths)
                loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
            elif model_type == 'S2S':
                output_lengths = 0
                outputs, tgt = model(inputs, input_lengths, targets)
                loss = cal_loss(outputs, tgt)
            elif model_type == 'JOINT':
                output_lengths = 0
                outputs, tgt = model(inputs, input_lengths, targets)
                loss1 = cal_loss(outputs, tgt)
                loss_ctc = calctc_loss(model.ctcOut, targets, model.ctcLen, target_lengths)
                #loss = loss1*0.6 + loss_ctc*0.4
                loss = loss1*0.78 + loss_ctc*0.22
                #loss = loss1*0.22 + loss_ctc*0.78

            preds = model.greedydecode(outputs, output_lengths)
            targets = [uyghur_latin.decode(target) for target in targets]

            for pred, src in zip(preds, targets):
                e_char_cnt, char_cnt = cer(pred, src)
                e_word_cnt, word_cnt = wer(pred, src)
                e_chars += e_char_cnt
                e_words += e_word_cnt

                chars += char_cnt
                words += word_cnt

            iter_cnt += 1
            avg_loss += loss.item()

            msg = f" VALIDATION: [CER:{e_chars/chars:.2%} ({e_chars}/{chars} letters) WER:{e_words/words:.2%} ({e_words}/{words} words), Avg loss:{avg_loss/iter_cnt:.4f}]"
            vbar.set_description(msg)

        vbar.close()

        cer_val = e_chars/chars

        with open(log_name, 'a', encoding='utf-8') as fp:
            fp.write(msg + "\n")

        # Print the last 3 validation results
        result = ""
        result_cnt = 0
        chars = 0
        words = 0
        e_chars = 0
        e_words = 0
        for pred, src in zip(preds, targets):
            e_char_cnt, char_cnt = cer(pred, src)
            e_word_cnt, word_cnt = wer(pred, src)
            e_chars += e_char_cnt
            e_words += e_word_cnt
            chars += char_cnt
            words += word_cnt
            result += f" O:{src}\n"
            result += f" P:{pred}\n"
            result += f" CER: {e_char_cnt/char_cnt:.2%} ({e_char_cnt}/{char_cnt} letters), WER: {e_word_cnt/word_cnt:.2%} ({e_word_cnt}/{word_cnt} words)\n"
            result_cnt += 1
            if result_cnt >= 3:
                break

    print(result)
    return cer_val
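
# NB: validate() returns the corpus-level CER as a fraction; the main loop
# below compares it against model.best_cer to decide when to save a best
# checkpoint.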


def train(model, train_loader):
    total_loss = 0
    iter_cnt = 0
    msg = ''
    model.train()
    pbar = tqdm(iter(train_loader), leave=True, total=mini_epoch_length)
    for data in pbar:
        optimizer.zero_grad()
        inputs, targets, input_lengths, target_lengths, _ = data
        inputs = inputs.to(device)
        targets = targets.to(device)
        input_lengths = input_lengths.to(device)
        target_lengths = target_lengths.to(device)

        if model_type == 'CTC':
            outputs, output_lengths = model(inputs, input_lengths)
            loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
        elif model_type == 'S2S':
            output_lengths = 0
            outputs, tgt = model(inputs, input_lengths, targets)
            loss = cal_loss(outputs, tgt)
        elif model_type == 'JOINT':
            output_lengths = 0
            outputs, tgt = model(inputs, input_lengths, targets)
            loss1 = cal_loss(outputs, tgt)
            loss_ctc = calctc_loss(model.ctcOut, targets, model.ctcLen, target_lengths)
            #loss = loss1*0.6 + loss_ctc*0.4
            loss = loss1*0.78 + loss_ctc*0.22
            #loss = loss1*0.22 + loss_ctc*0.78

        loss.backward()
        lr = optimizer.step()
        total_loss += loss.item()
        iter_cnt += 1

        msg = f'[LR: {lr: .6f} Loss: {loss.item(): .5f}, Avg loss: {(total_loss/iter_cnt): .5f}]'
        pbar.set_description(msg)
        #torch.cuda.empty_cache()
        if iter_cnt > mini_epoch_length:
            break

    pbar.close()
    with open(log_name, 'a', encoding='utf-8') as fp:
        msg = f'Epoch[{(epoch+1):d}]:\t{msg}\n'
        fp.write(msg)
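
# NB: an "epoch" here is really a mini-epoch: the loop breaks after
# mini_epoch_length batches, so large datasets are trained in shorter passes
# with validation and checkpointing between them.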

def GetModel():

    if model_type == 'CTC':
        #model = GCGCResM(num_features_input = featurelen)
        #model = UDS2W2L(num_features_input = featurelen)
        #model = GCGCRes2(num_features_input = featurelen)
        #model = GCGCRes(num_features_input = featurelen)   # Running elsewhere
        #model = GCGCRes1(num_features_input = featurelen)  # Running elsewhere

        #model = UDS2W2L50(num_features_input = featurelen)
        #model = UDS2W2L80(num_features_input = featurelen)
        #model = ConfModel(num_features_input = featurelen)

        #model = QuartzNet15x5(num_features_input = featurelen)
        #model = QuartzNet10x5(num_features_input = featurelen)
        #model = QuartzNet5x5(num_features_input = featurelen)

        #model = UArilash1(num_features_input = featurelen)
        #model = UDeepSpeech(num_features_input = featurelen)
        #model = UDS2W2L3(num_features_input = featurelen)


        #model = TinyWav2Letter(num_features_input = featurelen)
        #model = ConfModelM(num_features_input = featurelen)

        #model = UDS2W2L050(num_features_input = featurelen)
        #model = Conv1D3InDS2(num_features_input = featurelen)
        #model = UDS2W2LGLU(num_features_input = featurelen)
        model = UDS2W2LGLU8(num_features_input = featurelen)

    elif model_type == 'S2S':
        #model = UFormer(num_features_input = featurelen)
        #model = UFormer1DGRU(num_features_input = featurelen)

        #model = UFormerCTC(num_features_input = featurelen)
        #model = UFormerCTC3(num_features_input = featurelen)
        model = UFormerCTC3N(num_features_input = featurelen)
        #model = UFormerCTC1N(num_features_input = featurelen)

    elif model_type == 'JOINT':
        #model = UFormer(num_features_input = featurelen)
        #model = UFormer1DGRU(num_features_input = featurelen)

        #model = UFormerCTC(num_features_input = featurelen)
        #model = UFormerCTC3(num_features_input = featurelen)
        #model = UFormerCTC3N(num_features_input = featurelen)
        model = UFormerCTC1N(num_features_input = featurelen)

    return model
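
# NB: each branch keeps exactly one model uncommented; the commented lines
# are a record of earlier architecture experiments.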


# Models to test
#UFormerCTC3N
#UDS2W2L5
#GCGCRes1

if __name__ == "__main__":
    device = "cuda"
    os.makedirs('./results', exist_ok=True)

    model_type = 'CTC'  # one of 'CTC', 'S2S', 'JOINT'

    #train_file = 'uyghur_train.csv'
    train_file = 'uyghur_thuyg20_train_small.csv'
    test_file = 'uyghur_thuyg20_test_small.csv'

    train_set = SpeechDataset(train_file, augumentation=False)
    train_loader = SpeechDataLoader(train_set, num_workers=5, pin_memory=True, shuffle=True, batch_size=24)

    validation_set = SpeechDataset(test_file, augumentation=False)
    validation_loader = SpeechDataLoader(validation_set, num_workers=5, pin_memory=True, shuffle=True, batch_size=24)

    print("="*50)
    msg = f" Training Set: {train_file}, {len(train_set)} samples" + "\n"
    msg += f" Validation Set: {test_file}, {len(validation_set)} samples" + "\n"
    msg += f" Vocab Size : {uyghur_latin.vocab_size}"

    print(msg)
    model = GetModel()
    print("="*50)

    log_name = model.checkpoint + '.log'
    with open(log_name, 'a', encoding='utf-8') as fp:
        fp.write(msg + '\n')

    train_set.Raw = model.Raw       # True if the model consumes raw waveform data
    validation_set.Raw = model.Raw  # True if the model consumes raw waveform data

    model = model.to(device)

    # Start training and validation
    testfile = ["test1.wav", "test2.wav", "test3.wav", "test4.wav", "test5.wav", "test6.wav"]
    start_epoch = model.trained_epochs
    mini_epoch_length = len(train_loader)
    if mini_epoch_length > 1000:
        mini_epoch_length = mini_epoch_length//2
        #pass
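
    # Cosine schedule: the LR anneals from 1e-4 toward 1e-5 over
    # mini_epoch_length//2 optimizer steps (and climbs back up afterwards,
    # since CosineAnnealingLR is periodic).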
    optimizer = CustOpt(model.parameters(), mini_epoch_length//2, lr=0.0001, min_lr=0.00001)
    for epoch in range(start_epoch, 1000):
        torch.cuda.empty_cache()
        model.eval()
        msg = ""
        for afile in testfile:
            text = model.predict(afile, device)
            text = f"{afile}-->{text}\n"
            print(text, end="")
            msg += text

        with open(log_name, 'a', encoding='utf-8') as fp:
            fp.write(msg + '\n')

        print("="*50)
        print(f"Training Epoch[{(epoch+1):d}]:")
        train(model, train_loader)
        if (epoch+1) % 1 == 0:
            print("Validating:")
            model.save(epoch+1)
            curcer = validate(model, validation_loader)
            if curcer < model.best_cer:
                model.best_cer = curcer
                model.save(epoch+1, best=True)

        model.save(epoch+1)