Delete tekshur.py
This commit is contained in:
parent
8f239fdb84
commit
31257e0ad9
75
tekshur.py
75
tekshur.py
@ -1,75 +0,0 @@
|
||||
import torch
|
||||
from data import SpeechDataset, SpeechDataLoader, featurelen, uyghur_latin, cer
|
||||
from GCGCResM import GCGCResM
|
||||
from uformer import UFormer
|
||||
from UDS2W2L50 import UDS2W2L50
|
||||
from UFormerCTC2 import UFormerCTC2
|
||||
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
from tqdm import tqdm
|
||||
|
||||
def tekshurctc(model, hojjet, device):
|
||||
training_set = SpeechDataset(hojjet, augumentation=False)
|
||||
loader = SpeechDataLoader(training_set,num_workers=4, shuffle=False, batch_size=32)
|
||||
|
||||
line = []
|
||||
with torch.no_grad():
|
||||
pbar = tqdm(iter(loader), leave=True, total=len(loader))
|
||||
for inputs, targets, input_lengths, _ , paths in pbar:
|
||||
|
||||
inputs = inputs.to(device,non_blocking=True)
|
||||
outputs, output_lengths = model(inputs, input_lengths)
|
||||
preds = model.greedydecode(outputs, output_lengths)
|
||||
targets = [uyghur_latin.decode(target) for target in targets]
|
||||
|
||||
for pred, src, wavename in zip(preds, targets, paths):
|
||||
xatasani , _ = cer(pred, src)
|
||||
if xatasani >= 1:
|
||||
xata = f"{wavename}\t{src}\t{xatasani}\n"
|
||||
#xata = f"{src}\n{pred}\n\n"
|
||||
line.append(xata)
|
||||
return line
|
||||
|
||||
|
||||
def tekshurs2s(model, hojjet, device):
|
||||
training_set = SpeechDataset(hojjet, augumentation=False)
|
||||
loader = SpeechDataLoader(training_set,num_workers=4, shuffle=False, batch_size=20)
|
||||
|
||||
line = []
|
||||
with torch.no_grad():
|
||||
pbar = tqdm(iter(loader), leave=True, total=len(loader))
|
||||
for inputs, targets, input_lengths, _ , paths in pbar:
|
||||
|
||||
inputs = inputs.to(device,non_blocking=True)
|
||||
targets = targets.to(device,non_blocking=True)
|
||||
input_lengths = input_lengths.to(device,non_blocking=True)
|
||||
|
||||
outputs, _ = model(inputs, input_lengths, targets)
|
||||
preds = model.greedydecode(outputs, 0)
|
||||
targets = [uyghur_latin.decode(target) for target in targets]
|
||||
|
||||
for pred, src, wavename in zip(preds, targets, paths):
|
||||
xatasani , _ = cer(pred, src)
|
||||
if xatasani >= 5:
|
||||
xata = f"{wavename}\t{src}\t{xatasani}\n"
|
||||
#xata = f"{src}\n{pred}\n\n"
|
||||
line.append(xata)
|
||||
return line
|
||||
|
||||
if __name__ == '__main__':
|
||||
device = 'cuda'
|
||||
#model = GCGCResM(featurelen, load_best=False)
|
||||
#model = UFormer(featurelen, load_best=False)
|
||||
|
||||
model = UDS2W2L50(featurelen, load_best=False)
|
||||
#model = UFormerCTC2(featurelen, load_best=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
#'uyghur_train.csv' 'uyghur_thuyg20_train_small.csv', ''
|
||||
#netije = tekshurs2s(model, 'uyghur_train.csv', device)
|
||||
netije = tekshurctc(model, 'uyghur_thuyg20_test_small.csv', device)
|
||||
with open('tek_test.csv','w',encoding='utf_8_sig') as f:
|
||||
f.writelines(netije)
|
Loading…
Reference in New Issue
Block a user