Add files via upload
This commit is contained in:
parent
006896af46
commit
8d6465b11c
197
UModel.py
Normal file
197
UModel.py
Normal file
@ -0,0 +1,197 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from uyghur import uyghur_latin
|
||||
from data import melfuture
|
||||
|
||||
class ResB(nn.Module):
    """Residual 1-D convolution block that keeps the channel count unchanged.

    Computes ``dropout(relu(bn(conv_bn(x) + x)))``. Because the convolution
    output is added to the input, ``pad`` must be chosen so the sequence
    length is preserved (e.g. ``(kernel - 1) // 2`` for an odd kernel).

    Args:
        num_filters: number of input and output channels.
        kernel: convolution kernel size.
        pad: zero padding applied on both sides of the sequence.
        d: dropout probability applied to the block output.
    """

    def __init__(self, num_filters, kernel, pad, d = 0.4):
        super().__init__()
        conv_layer = nn.Conv1d(
            num_filters,
            num_filters,
            kernel_size=kernel,
            stride=1,
            padding=pad,
            bias=False,
        )
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.conv = nn.Sequential(conv_layer, nn.BatchNorm1d(num_filters))
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(num_filters)
        self.drop = nn.Dropout(d)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        residual = self.conv(x)
        residual += x  # skip connection
        return self.drop(self.relu(self.bn(residual)))
|
||||
|
||||
class UModel(nn.Module):
    """Convolutional + bidirectional-GRU acoustic model emitting CTC log-probabilities.

    Pipeline (see ``forward``): three parallel stride-2 dilated convolutions
    are concatenated and fused by a 1x1 convolution, followed by a residual
    CNN stack, a bidirectional GRU whose two directions are summed, a second
    residual CNN stack and a 1x1 output convolution with log-softmax over
    the vocabulary.

    Attributes besides the layers:
        checkpoint: path prefix for ``<prefix>_best.pth`` / ``<prefix>_last.pth``.
        trained_epochs: epoch count stored in the loaded checkpoint (0 if none).
        best_cer: best character error rate recorded so far (1.0 if none).
    """

    def __init__(self, num_features_input, load_best=False):
        """Build the network and restore a checkpoint when one exists.

        Args:
            num_features_input: number of feature bins per frame (input channels).
            load_best: prefer the ``_best`` checkpoint over the ``_last`` one.
        """
        super(UModel, self).__init__()

        # Parallel front-end branches: same stride, growing kernel/dilation.
        # Padding is chosen so all three produce the same output length.
        self.in1 = nn.Conv1d(num_features_input, 256, 11, 2, 5*1, dilation=1, bias=False)
        self.in2 = nn.Conv1d(num_features_input, 256, 15, 2, 7*2, dilation=2, bias=False)
        self.in3 = nn.Conv1d(num_features_input, 256, 19, 2, 9*3, dilation=3, bias=False)
        self.concat = nn.Conv1d(256*3, 256, 1, 1, bias=True)
        self.relu = nn.ReLU()

        self.cnn1 = nn.Sequential(
            nn.Conv1d(256, 256, 11, 1, 5, bias=False),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            ResB(256, 11, 5, 0.2),
            ResB(256, 11, 5, 0.2),
            ResB(256, 11, 5, 0.2),
            ResB(256, 11, 5, 0.2)
        )
        self.rnn = nn.GRU(256, 384, num_layers=1, batch_first=True, bidirectional=True)
        self.cnn2 = nn.Sequential(
            ResB(384, 13, 6, 0.2),
            ResB(384, 13, 6, 0.2),
            ResB(384, 13, 6, 0.2),
            nn.Conv1d(384, 512, 17, 1, 8, bias=False),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.2),
            ResB(512, 17, 8, 0.3),
            ResB(512, 17, 8, 0.3),
            nn.Conv1d(512, 1024, 1, 1, bias=False),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            ResB(1024, 1, 0, 0.0),
        )
        self.outlayer = nn.Conv1d(1024, uyghur_latin.vocab_size, 1, 1)
        self.softMax = nn.LogSoftmax(dim=1)

        self.checkpoint = 'results/UModel'
        self._load(load_best)
        print(f'The model has {self.parameters_count(self):,} trainable parameters')

    # X : N x F x T
    def forward(self, x, input_lengths):
        """Run the network.

        Args:
            x: input features laid out as N x F x T.
            input_lengths: per-sample frame counts before the stride-2 front end.

        Returns:
            ``(log_probs, out_lens)`` where ``log_probs`` is N x vocab x T' and
            ``out_lens`` is ``input_lengths // 2``.
        """
        inp = torch.cat([self.in1(x), self.in2(x), self.in3(x)], dim=1)
        inp = self.relu(self.concat(inp))
        out = self.cnn1(inp)

        # NOTE(review): the stride-2 front end yields ceil(T / 2) frames, so for
        # odd T this is one frame short of the real length (a conservative bound).
        out_lens = input_lengths//2

        out, _ = self.rnn(out.permute(0, 2, 1))
        # Sum (rather than concatenate) the forward and backward GRU outputs.
        hidden = self.rnn.hidden_size
        out = (out[:, :, :hidden] + out[:, :, hidden:]).contiguous()

        out = self.cnn2(out.permute(0, 2, 1))
        out = self.softMax(self.outlayer(out))
        return out, out_lens

    def parameters_count(self, model):
        """Return the number of trainable parameters of ``model``."""
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def _load(self, load_best=False):
        """Restore weights and training state from disk, if a checkpoint exists.

        Always leaves ``trained_epochs`` and ``best_cer`` defined, so a fresh
        model (no checkpoint yet) can be trained and saved without an
        ``AttributeError``.
        """
        best_path = self.checkpoint + '_best.pth'
        last_path = self.checkpoint + '_last.pth'

        path = None
        if load_best and os.path.exists(best_path):
            path = best_path
        elif os.path.exists(last_path):
            path = last_path

        if path is None:
            # Fresh model: nothing trained yet, worst possible CER as baseline.
            self.trained_epochs = 0
            self.best_cer = 1.0
            return

        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        pack = torch.load(path, map_location='cpu')
        self.load_state_dict(pack['st_dict'])
        self.trained_epochs = pack['epoch']
        self.best_cer = pack.get('BCER', 1.0)
        print(f' Model loaded: {path}')
        print(f' Best CER: {self.best_cer:.2%}')
        print(f' Trained: {self.trained_epochs} epochs')

    def save(self, epoch, best = False):
        """Write weights, ``epoch`` and ``best_cer`` to the best/last checkpoint file."""
        pack = {
            'st_dict': self.state_dict(),
            'epoch': epoch,
            'BCER': self.best_cer
        }
        suffix = '_best.pth' if best else '_last.pth'
        torch.save(pack, self.checkpoint + suffix)

    def predict(self, path, device):
        """Transcribe the audio file at ``path`` and return the decoded text.

        Runs in eval mode without building an autograd graph, then restores
        whatever train/eval mode the model was in before the call.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                spect = melfuture(path).to(device)
                spect.unsqueeze_(0)  # add the batch dimension
                xn = torch.IntTensor([spect.size(2)])
                out, xn = self.forward(spect, xn)
                text = self.greedydecode(out, xn)
        finally:
            self.train(was_training)
        return text[0]

    #CTC greedy decode
    def greedydecode(self, yps, yps_lens):
        """Greedy CTC decoding: argmax per frame, merge repeats, drop blanks.

        Args:
            yps: network output laid out as N x vocab x T.
            yps_lens: number of valid frames for each of the N samples.

        Returns:
            A list of N decoded strings.
        """
        _, max_yps = torch.max(yps, 1)
        blank = uyghur_latin.pad_idx
        preds = []
        for x in range(len(max_yps)):
            frames = max_yps[x][:int(yps_lens[x])].tolist()
            pred = []
            last = None
            for char in frames:
                if char != blank and char != last:
                    pred.append(char)
                # Track every frame (blank included) so a blank between two equal
                # symbols separates them instead of collapsing them into one.
                last = char
            preds.append(pred)

        return [uyghur_latin.decode(pred) for pred in preds]
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: transcribe sample files, then check that differently
    # dilated convolutions yield tensors of matching size.
    from data import featurelen, melfuture

    device = "cpu"
    net = UModel(featurelen).to(device)
    #net.save(0)

    for wav_name in ("test1.wav", "test2.wav"):
        print(net.predict(wav_name, device))

    melf = melfuture("test3.wav")
    melf.unsqueeze_(0)

    conv0 = nn.Conv1d(featurelen, 256, 11, 2, 5, 1)
    # Same kernel with dilation 1, 2 and 3; padding scales with the dilation.
    branches = [nn.Conv1d(256, 256, 11, 1, 5 * rate, rate) for rate in (1, 2, 3)]

    out0 = conv0(melf)
    branch_outs = [branch(out0) for branch in branches]
    for branch_out in branch_outs:
        print(branch_out.size())

    out = branch_outs[0] * branch_outs[1] * branch_outs[2]
    print(out.size())
|
||||
|
||||
|
||||
#net = GCGCRes(featurelen).to(device)
|
||||
#net.save(1)
|
||||
|
||||
#text = net.predict("test1.wav",device)
|
||||
#print(text)
|
||||
#text = net.predict("test2.wav",device)
|
||||
#print(text)
|
197
data.py
Normal file
197
data.py
Normal file
@ -0,0 +1,197 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import librosa
|
||||
from sklearn import preprocessing
|
||||
import os
|
||||
import random
|
||||
from uyghur import uyghur_latin
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Feature-extraction settings shared by the whole data pipeline.
featurelen = 128  # mel-spectrogram bins (60 was the value used for MFCC)
sample_rate = 22050
fft_len = 1024
window_len = fft_len
window = "hann"
hop_len = 200


def _load_noise(filename):
    """Load at most the first 15 seconds of a noise clip at ``sample_rate``."""
    clip, _ = librosa.load(filename, sr=sample_rate, duration=15.0)
    return clip


# Background-noise clips mixed into training audio by addnoise().
white_noise = _load_noise('white.wav')
perlin_noise = _load_noise('perlin.wav')
cafe_noise = _load_noise('cafe.wav')
radio_noise = _load_noise('radionoise.wav')
|
||||
|
||||
def addnoise(audio, noises=None):
    """Mix one randomly chosen background-noise clip into ``audio``.

    Args:
        audio: 1-D sample array.
        noises: optional sequence of 1-D noise arrays to choose from with
            equal probability. Defaults to the module-level white, perlin,
            radio and cafe clips (in that order).

    Returns:
        ``audio + noise[:len(audio)]``, or ``audio`` itself (unchanged) when it
        is longer than the chosen clip.
    """
    if noises is None:
        noises = (white_noise, perlin_noise, radio_noise, cafe_noise)

    rnd = random.random()
    if len(noises) == 0:
        return audio

    # Equal-width buckets over [0, 1): with four clips this is the
    # <0.25 / <0.50 / <0.75 / else split.
    noise = noises[min(int(rnd * len(noises)), len(noises) - 1)]

    # Compare against the clip actually selected; the clips are not guaranteed
    # to share a length, and a shorter one cannot be broadcast onto the audio.
    if len(audio) > len(noise):
        return audio
    return audio + noise[:len(audio)]
|
||||
|
||||
def randomstretch(audio):
    """Resample ``audio`` by a random factor drawn from [0.8, 1.2].

    The result is later treated as if it were still sampled at
    ``sample_rate``, so this changes the clip's duration (and pitch) by the
    drawn factor.
    """
    factor = random.uniform(0.8, 1.2)
    # Keyword arguments: recent librosa releases no longer accept the sample
    # rates positionally.
    return librosa.resample(audio, orig_sr=sample_rate, target_sr=sample_rate * factor)
|
||||
|
||||
#def spec_augment(feat, T=70, F=15, time_mask_num=1, freq_mask_num=1):
|
||||
def _mask_time(feat, max_width, count):
    """Zero ``count`` random spans of up to ``max_width`` columns (axis 1) in place."""
    seq_len = feat.size(1)
    for _ in range(count):
        # Clamp the width so a short input never yields an empty randint range.
        t = random.randint(0, min(max_width, seq_len))
        t0 = random.randint(0, seq_len - t)
        feat[:, t0 : t0 + t] = 0


def _mask_freq(feat, max_width, count):
    """Zero ``count`` random bands of up to ``max_width`` rows (axis 0) in place."""
    feat_size = feat.size(0)
    for _ in range(count):
        f = random.randint(0, min(max_width, feat_size))
        f0 = random.randint(0, feat_size - f)
        feat[f0 : f0 + f, :] = 0


def spec_augment(feat, T=50, F=13, time_mask_num=1, freq_mask_num=1):
    """Randomly mask a spectrogram along time, frequency, or both.

    With roughly equal probability applies a time mask only, a frequency mask
    only, or both (time first). The tensor is modified in place.

    Args:
        feat: 2-D tensor; axis 0 is the feature axis and axis 1 the time axis.
        T: maximum width of one time mask.
        F: maximum width of one frequency mask.
        time_mask_num: number of time masks to draw.
        freq_mask_num: number of frequency masks to draw.

    Returns:
        The same ``feat`` object, after masking.
    """
    rnd = random.random()

    if rnd < 0.33:
        _mask_time(feat, T, time_mask_num)
    elif rnd < 0.66:
        _mask_freq(feat, F, freq_mask_num)
    else:
        _mask_time(feat, T, time_mask_num)
        _mask_freq(feat, F, freq_mask_num)

    return feat
|
||||
|
||||
|
||||
def melfuture(wav_path, augument = False):
    """Load an audio file and return its normalized log-mel spectrogram.

    Steps: load at ``sample_rate``; optionally augment the waveform (random
    resampling and additive noise, each with probability 0.5); min-max scale;
    pre-emphasis; mel spectrogram in dB; mean/variance normalization over the
    whole spectrogram; optionally SpecAugment masking (probability 0.5).

    Args:
        wav_path: path of the audio file.
        augument: enable the random training-time augmentations.

    Returns:
        A ``torch.FloatTensor`` holding the spectrogram
        (presumably ``featurelen`` x frames — confirm against librosa).
    """
    audio, s_r = librosa.load(wav_path, sr=sample_rate, res_type='polyphase')

    if augument:
        if random.random() < 0.5:
            audio = randomstretch(audio)
        if random.random() < 0.5:
            audio = addnoise(audio)

    audio = preprocessing.minmax_scale(audio, axis=0)
    audio = librosa.effects.preemphasis(audio)

    spec = librosa.feature.melspectrogram(y=audio, sr=s_r, n_fft=fft_len, hop_length=hop_len, n_mels=featurelen, fmax=8000)
    spec = librosa.power_to_db(spec)
    #spec = librosa.amplitude_to_db(spec)

    std = spec.std()
    spec = spec - spec.mean()
    # A constant spectrogram (e.g. digital silence) has zero spread; dividing
    # by it would fill the features with NaN/inf, so only center in that case.
    if std > 0:
        spec = spec / std
    spec = torch.FloatTensor(spec)

    if augument and random.random() < 0.5:
        spec = spec_augment(spec)

    return spec
|
||||
|
||||
class SpeechDataset(Dataset):
    """Dataset of (spectrogram, encoded transcript, wav path) samples.

    The index file holds one sample per line as ``<wav path>\\t<transcript>``.
    Lines whose audio file does not exist, or that have no transcript column,
    are skipped.
    """

    def __init__(self, index_path, augumentation = False):
        """Read the index file.

        Args:
            index_path: path of the tab-separated index file (UTF-8, BOM allowed).
            augumentation: apply random augmentation when samples are loaded.
        """
        self.Raw = False
        self.augument = augumentation
        self.idx = []

        with open(index_path, encoding='utf_8_sig') as f:
            for raw_line in f:
                item = raw_line.strip().split("\t")
                # A line without a tab has no transcript to encode; skip it
                # instead of failing on item[1].
                if len(item) < 2 or not os.path.exists(item[0]):
                    continue
                self.idx.append([item[0], uyghur_latin.encode(item[1])])

    def __getitem__(self, index):
        """Return ``(features, char_indices, wav_path)`` for sample ``index``."""
        wav_path, char_index = self.idx[index]
        x = melfuture(wav_path, self.augument)
        return x, char_index, wav_path

    def __len__(self):
        """Return the number of usable samples."""
        return len(self.idx)
|
||||
|
||||
def _collate_fn(batch):
    """Pad a list of dataset samples into batch tensors.

    Each sample is ``(features, char_indices, wav_path)``. Features are
    zero-padded along axis 1 to the longest sample; transcripts are padded
    with ``uyghur_latin.pad_idx``.

    Returns:
        ``(inputs, targets, input_lens, target_lens, paths)``.
    """
    batch_size = len(batch)
    input_lens = [sample[0].size(1) for sample in batch]
    target_lens = [len(sample[1]) for sample in batch]

    inputs = torch.zeros(batch_size, batch[0][0].size(0), max(input_lens), dtype=torch.float32)
    targets = torch.full((batch_size, max(target_lens)), uyghur_latin.pad_idx, dtype=torch.long)

    paths = []
    for row, sample in enumerate(batch):
        feat, target, path = sample[0], sample[1], sample[2]
        inputs[row, :, :feat.size(1)] = feat
        targets[row, :len(target)] = torch.LongTensor(target)
        paths.append(path)

    return inputs, targets, torch.IntTensor(input_lens), torch.IntTensor(target_lens), paths
|
||||
|
||||
|
||||
class SpeechDataLoader(DataLoader):
    """``DataLoader`` that always batches samples with ``_collate_fn``."""

    def __init__(self, *args, **kwargs):
        """
        Creates a data loader for AudioDatasets.
        """
        super().__init__(*args, **kwargs)
        # Assigned after construction, so it replaces any collate function the
        # caller may have passed in.
        self.collate_fn = _collate_fn
|
||||
|
||||
|
||||
# The following code is from: http://hetland.org/coding/python/levenshtein.py
|
||||
def levenshtein(a,b):
    """Calculates the Levenshtein distance between a and b.

    Works on any two sequences whose items support ``!=`` (strings, lists of
    words, ...). Only one row of the edit matrix is kept at a time, sized by
    the shorter sequence.
    """
    # Make ``a`` the shorter sequence so the rows use O(min(n, m)) space.
    if len(a) > len(b):
        a, b = b, a
    width = len(a)

    row = list(range(width + 1))
    for i, item_b in enumerate(b, start=1):
        above = row
        row = [i] + [0] * width
        for j, item_a in enumerate(a, start=1):
            insertion = above[j] + 1
            deletion = row[j - 1] + 1
            substitution = above[j - 1] + (1 if item_a != item_b else 0)
            row[j] = min(insertion, deletion, substitution)

    return row[width]
|
||||
|
||||
def wer(s1, src):
    """Return ``(word_errors, reference_word_count)`` for hypothesis ``s1`` vs ``src``.

    Both strings are split on whitespace; the error count is the Levenshtein
    distance between the two word lists.
    """
    reference_words = src.split()
    hypothesis_words = s1.split()
    return levenshtein(hypothesis_words, reference_words), len(reference_words)
|
||||
|
||||
def cer(s1, src):
    """Return ``(char_errors, reference_char_count)`` for hypothesis ``s1`` vs ``src``."""
    char_errors = levenshtein(s1, src)
    return char_errors, len(src)
|
||||
|
||||
def cer_wer(preds, targets):
    """Accumulate character and word error counts over paired predictions/targets.

    Returns:
        ``(char_errors, char_total, word_errors, word_total)`` summed over all pairs.
    """
    totals = [0, 0, 0, 0]  # char errors, char count, word errors, word count
    for hypothesis, reference in zip(preds, targets):
        char_err, char_cnt = cer(hypothesis, reference)
        word_err, word_cnt = wer(hypothesis, reference)
        totals[0] += char_err
        totals[1] += char_cnt
        totals[2] += word_err
        totals[3] += word_cnt

    return totals[0], totals[1], totals[2], totals[3]
|
BIN
perlin.wav
Normal file
BIN
perlin.wav
Normal file
Binary file not shown.
BIN
radionoise.wav
Normal file
BIN
radionoise.wav
Normal file
Binary file not shown.
BIN
silence.wav
Normal file
BIN
silence.wav
Normal file
Binary file not shown.
2142
thuyg20_test.csv
Normal file
2142
thuyg20_test.csv
Normal file
File diff suppressed because it is too large
Load Diff
9923
thuyg20_train.csv
Normal file
9923
thuyg20_train.csv
Normal file
File diff suppressed because it is too large
Load Diff
190
train.py
Normal file
190
train.py
Normal file
@ -0,0 +1,190 @@
|
||||
import numpy as np
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from data import SpeechDataset, SpeechDataLoader, featurelen, cer, wer
|
||||
from uyghur import uyghur_latin
|
||||
from tqdm import tqdm
|
||||
from UModel import UModel
|
||||
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
class CustOpt:
    """Adam optimizer bundled with a cosine-annealing learning-rate schedule.

    The schedule advances on every ``step()`` call, with ``datalen`` steps per
    half cosine period.

    Args:
        params: parameters to optimize.
        datalen: ``T_max`` of the cosine schedule.
        lr: initial (maximum) learning rate.
        min_lr: lowest learning rate of the schedule; defaults to ``lr``,
            which makes the rate constant.
    """

    def __init__(self, params, datalen, lr, min_lr = None):
        floor_lr = lr if min_lr is None else min_lr

        self.optimizer = torch.optim.Adam(params, lr=lr, weight_decay=0.000001)
        self._step = 0
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=datalen, eta_min=floor_lr)

    def step(self):
        """Apply one optimizer update, advance the schedule, return the new rate."""
        self.optimizer.step()
        self.scheduler.step()
        return self.scheduler.get_last_lr()[0]

    def zero_grad(self):
        """Reset the gradients of all optimized parameters."""
        self.optimizer.zero_grad()
|
||||
|
||||
#outputs format = B x F x T
|
||||
def calctc_loss(outputs, targets, output_lengths, target_lengths):
    """Return the mean CTC loss, using ``uyghur_latin.pad_idx`` as the blank label.

    ``outputs`` arrives as B x F x T and is permuted to the T x B x F layout
    expected by ``F.ctc_loss``. Infinite losses are zeroed.
    """
    time_major = outputs.permute(2, 0, 1).contiguous()
    return F.ctc_loss(
        time_major,
        targets,
        output_lengths,
        target_lengths,
        blank=uyghur_latin.pad_idx,
        reduction='mean',
        zero_infinity=True,
    )
|
||||
|
||||
def _error_rate(errors, total, default=0.0):
    """Return ``errors / total``, or ``default`` when ``total`` is zero."""
    return errors / total if total else default


def validate(model, valid_loader):
    """Evaluate ``model`` on ``valid_loader`` and return the character error rate.

    Shows a running CER/WER/loss summary on the progress bar, appends the
    final summary to the log file and prints up to three samples of the last
    batch. Reads the module-level ``device`` and ``log_name`` set by the
    training script.

    Returns:
        Total character errors divided by total reference characters, or 1.0
        when there were no reference characters at all.
    """
    chars = 0
    words = 0
    e_chars = 0
    e_words = 0
    avg_loss = 0
    iter_cnt = 0
    msg = ""
    # Pre-bound so the sample printout below works when the loader is empty.
    preds, targets = [], []

    model.eval()
    with torch.no_grad():
        tlen = len(valid_loader)
        vbar = tqdm(iter(valid_loader), leave=True, total=tlen)
        for inputs, targets, input_lengths, target_lengths, _ in vbar:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs, output_lengths = model(inputs, input_lengths)
            loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
            preds = model.greedydecode(outputs, output_lengths)
            targets = [uyghur_latin.decode(target) for target in targets]

            for pred, src in zip(preds, targets):
                e_char_cnt, char_cnt = cer(pred, src)
                e_word_cnt, word_cnt = wer(pred, src)
                e_chars += e_char_cnt
                e_words += e_word_cnt
                chars += char_cnt
                words += word_cnt

            iter_cnt += 1
            avg_loss += loss.item()

            msg = f" VALIDATION: [CER:{_error_rate(e_chars, chars):.2%} ({e_chars}/{chars} letters) WER:{_error_rate(e_words, words):.2%} ({e_words}/{words} words), Avg loss:{avg_loss/iter_cnt:4f}]"
            vbar.set_description(msg)

        vbar.close()

    # With nothing to score, report the worst rate so the caller never treats
    # an empty validation run as a new best model.
    cer_val = _error_rate(e_chars, chars, default=1.0)

    with open(log_name, 'a', encoding='utf-8') as fp:
        fp.write(msg+"\n")

    #Print Last 3 validation results
    lines = []
    for pred, src in zip(preds[:3], targets[:3]):
        e_char_cnt, char_cnt = cer(pred, src)
        e_word_cnt, word_cnt = wer(pred, src)
        lines.append(f" O:{src}\n")
        lines.append(f" P:{pred}\n")
        # An empty reference has zero characters/words; show 0% instead of crashing.
        lines.append(f" CER: {_error_rate(e_char_cnt, char_cnt):.2%} ({e_char_cnt}/{char_cnt} letters), WER: {_error_rate(e_word_cnt, word_cnt):.2%} ({e_word_cnt}/{word_cnt} words)\n")

    print("".join(lines))
    return cer_val
|
||||
|
||||
|
||||
def train(model, train_loader):
    """Train ``model`` for one pass over ``train_loader``.

    Uses the module-level ``optimizer``, ``device``, ``mini_epoch_length``,
    ``epoch`` and ``log_name`` set by the training script. The progress bar
    shows the current learning rate and losses; the last status line is
    appended to the log file.
    """
    running_loss = 0
    batches_done = 0
    msg = ''

    model.train()
    pbar = tqdm(iter(train_loader), leave=True, total=mini_epoch_length)
    for inputs, targets, input_lengths, target_lengths, _ in pbar:
        optimizer.zero_grad()
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs, output_lengths = model(inputs, input_lengths)
        loss = calctc_loss(outputs, targets, output_lengths, target_lengths)
        loss.backward()

        lr = optimizer.step()
        batch_loss = loss.item()
        running_loss += batch_loss
        batches_done += 1

        msg = f'[LR: {lr: .7f} Loss: {batch_loss: .5f}, Avg loss: {(running_loss/batches_done): .5f}]'
        pbar.set_description(msg)
        if batches_done > mini_epoch_length:
            break

    pbar.close()

    msg = f'Epoch[{(epoch+1):d}]:\t{msg}\n'
    with open(log_name, 'a', encoding='utf-8') as fp:
        fp.write(msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Training entry point. The names device, log_name, optimizer,
    # mini_epoch_length and epoch defined here are read as module globals by
    # train() and validate().

    # Fall back to the CPU when no CUDA device is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    use_cuda = device == "cuda"

    os.makedirs('./results',exist_ok=True)

    train_file = 'thuyg20_train.csv'
    test_file = 'thuyg20_test.csv'

    train_set = SpeechDataset(train_file, augumentation=True)
    train_loader = SpeechDataLoader(train_set, num_workers=4, pin_memory=use_cuda, shuffle=True, batch_size=20)

    validation_set = SpeechDataset(test_file, augumentation=False)
    validation_loader = SpeechDataLoader(validation_set, num_workers=4, pin_memory=use_cuda, shuffle=True, batch_size=20)

    print("="*50)
    msg = f" Training Set: {train_file}, {len(train_set)} samples" + "\n"
    msg += f" Validation Set: {test_file}, {len(validation_set)} samples" + "\n"
    msg += f" Vocab Size : {uyghur_latin.vocab_size}"

    print(msg)
    model = UModel(num_features_input = featurelen)

    print("="*50)

    log_name = model.checkpoint + '.log'
    with open(log_name,'a', encoding='utf-8') as fp:
        fp.write(msg+'\n')

    model = model.to(device)

    #Start train and validation
    testfile=["test1.wav","test2.wav", "test3.wav","test4.wav","test5.wav","test6.wav"]

    # A model created without a checkpoint may not carry any training state
    # yet; start from epoch 0 with the worst possible CER as the baseline.
    start_epoch = getattr(model, 'trained_epochs', 0)
    if not hasattr(model, 'best_cer'):
        model.best_cer = 1.0
    mini_epoch_length = len(train_loader)

    optimizer = CustOpt(model.parameters(), mini_epoch_length, lr = 0.00002, min_lr=0.00002)
    for epoch in range(start_epoch,1000):
        if use_cuda:
            torch.cuda.empty_cache()

        # Transcribe the fixed sample files before each epoch as a quick
        # qualitative progress check.
        model.eval()
        msg = ""
        for afile in testfile:
            text = model.predict(afile,device)
            text = f"{afile}-->{text}\n"
            print(text,end="")
            msg += text

        with open(log_name,'a', encoding='utf-8') as fp:
            fp.write(msg+'\n')

        print("="*50)
        print(f"Training Epoch[{(epoch+1):d}]:")
        train(model, train_loader)

        # Validate every second epoch; checkpoint first so a failure during
        # validation does not lose the epoch.
        if (epoch+1) % 2 == 0:
            print("Validating:")
            model.save((epoch+1))
            curcer = validate(model,validation_loader)
            if curcer < model.best_cer:
                model.best_cer = curcer
                model.save((epoch+1),best=True)

        model.save((epoch+1))
|
66
uyghur.py
Normal file
66
uyghur.py
Normal file
@ -0,0 +1,66 @@
|
||||
import re
|
||||
|
||||
class Uyghur():
    """Character-level codec between Uyghur Latin-script text and vocabulary indices.

    The vocabulary is ``<pad>``, ``<sos>``, ``<eos>`` followed by the
    characters of ``uyghur_latin``; ``<pad>`` therefore has index 0.
    """

    # Raw-string pattern, compiled once: '\s' in a normal string literal is an
    # invalid escape sequence.
    _WHITESPACE = re.compile(r'\s+')
    # Punctuation becomes a space; the ASCII apostrophe becomes the
    # typographic one that is part of the alphabet.
    _NORMALIZE = str.maketrans({"-": " ", ",": " ", ".": " ", "!": " ", "?": " ", "'": "’"})

    def __init__(self, ):
        self.uyghur_latin = "abcdefghijklmnopqrstuvwxyz éöü’"
        self._vocab_list = [self.pad_char, self.sos_char,self.eos_char] + list(self.uyghur_latin) # $ for padding char. index must be 0
        self._vocab2idx = {v: idx for idx, v in enumerate(self._vocab_list)}

    def encode(self, s):
        """Normalize ``s`` and return its list of vocabulary indices.

        Punctuation is replaced, whitespace runs are collapsed, the text is
        lower-cased, and characters outside the alphabet are dropped.
        """
        s = s.translate(self._NORMALIZE)
        s = self._WHITESPACE.sub(' ', s).strip().lower()
        return [self.vocab_to_idx(v) for v in s if v in self.uyghur_latin]

    def decode(self, seq):
        """Turn a sequence of indices back into text.

        Decoding stops at the first ``<pad>`` or ``<eos>``; ``<sos>`` is
        skipped. Whitespace runs in the result are collapsed.
        """
        vocabs = []
        for idx in seq:
            v = self.idx_to_vocab(idx)
            if idx == self.pad_idx or idx == self.eos_idx:
                break
            elif idx == self.sos_idx:
                pass
            else:
                vocabs.append(v)
        return self._WHITESPACE.sub(' ', "".join(vocabs)).strip()

    def vocab_to_idx(self, vocab):
        """Return the index of the vocabulary entry ``vocab``."""
        return self._vocab2idx[vocab]

    def idx_to_vocab(self, idx):
        """Return the vocabulary entry stored at ``idx``."""
        return self._vocab_list[idx]

    def vocab_list(self):
        """Return the full vocabulary list (special tokens first)."""
        return self._vocab_list

    @property
    def vocab_size(self):
        """Number of vocabulary entries, special tokens included."""
        return len(self._vocab_list)

    @property
    def pad_idx(self):
        """Index of the padding token."""
        return self.vocab_to_idx(self.pad_char)

    @property
    def sos_idx(self):
        """Index of the start-of-sequence token."""
        return self.vocab_to_idx(self.sos_char)

    @property
    def eos_idx(self):
        """Index of the end-of-sequence token."""
        return self.vocab_to_idx(self.eos_char)

    @property
    def pad_char(self):
        """Padding token."""
        return "<pad>"

    @property
    def sos_char(self):
        """Start-of-sequence token."""
        return "<sos>"

    @property
    def eos_char(self):
        """End-of-sequence token."""
        return "<eos>"


# Shared codec instance imported by the model, data and training modules.
uyghur_latin = Uyghur()
|
Loading…
Reference in New Issue
Block a user