From bf2c9a393a096f9c9913e070fdd68d10779a3e84 Mon Sep 17 00:00:00 2001 From: alikia2x Date: Thu, 26 Sep 2024 22:57:27 +0800 Subject: [PATCH] update: add metadata export of intention classify --- intention-classify/training/train.py | 8 +++++++- translate/validation/LLMtrans.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/intention-classify/training/train.py b/intention-classify/training/train.py index 422e16c..dd7d946 100644 --- a/intention-classify/training/train.py +++ b/intention-classify/training/train.py @@ -88,7 +88,7 @@ def main(): tokenizer = AutoTokenizer.from_pretrained(model_name) data = load_data("data.json") - class_to_idx, _ = create_class_mappings(data) + class_to_idx, idx_to_class = create_class_mappings(data) embedding_map = torch.load("token_id_to_reduced_embedding.pt") dataset = preprocess_data(data, embedding_map, tokenizer, class_to_idx) train_data, _ = train_test_split(dataset, test_size=0.2) @@ -143,6 +143,12 @@ def main(): }, opset_version=11, ) + meta = { + "idx_to_class": idx_to_class, + "threshold": 0 + } + with open('NLU_meta.json', 'w') as f: + json.dump(meta, f) if __name__ == "__main__": diff --git a/translate/validation/LLMtrans.py b/translate/validation/LLMtrans.py index 9d06a1a..03c979b 100644 --- a/translate/validation/LLMtrans.py +++ b/translate/validation/LLMtrans.py @@ -2,6 +2,7 @@ from openai import OpenAI import argparse import os from dotenv import load_dotenv +from tqdm import tqdm def translate_text(text, client, model_name, temp): messages = [ @@ -37,7 +38,7 @@ with open(input_file, "r") as f: src_lines = f.readlines() -for line in src_lines: +for line in tqdm(src_lines): result = translate_text(line, client, model, temp) with open(output_file, 'a') as f: f.write(result + '\n')