update: llm translate

This commit is contained in:
alikia2x (寒寒) 2024-09-10 21:35:25 +08:00
parent dcf53ca002
commit ebd1113a6e
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6

View File

@@ -6,13 +6,11 @@ from openai import OpenAI
load_dotenv() load_dotenv()
# 初始化OpenAI客户端
client = OpenAI( client = OpenAI(
api_key=os.getenv("API_KEY"), api_key=os.getenv("API_KEY"),
base_url=os.getenv("BASE_URL"), base_url=os.getenv("BASE_URL"),
) )
# 系统提示词
system_prompt = """ system_prompt = """
The user will provide some text. Please parse the text into segments, each segment contains 1 to 5 sentences. Translate each sentence into the corresponding language. If the input is in Chinese, return the English translation, and vice versa. The user will provide some text. Please parse the text into segments, each segment contains 1 to 5 sentences. Translate each sentence into the corresponding language. If the input is in Chinese, return the English translation, and vice versa.
@@ -28,7 +26,6 @@ EXAMPLE JSON OUTPUT:
} }
""" """
# 翻译函数
def translate_text(text): def translate_text(text):
messages = [ messages = [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
@@ -43,7 +40,6 @@ def translate_text(text):
return json.loads(response.choices[0].message.content) return json.loads(response.choices[0].message.content)
# 处理单个文件的函数
def process_file(input_file, output_dir): def process_file(input_file, output_dir):
try: try:
with open(input_file, 'r', encoding='utf-8') as f: with open(input_file, 'r', encoding='utf-8') as f:
@@ -60,7 +56,6 @@ def process_file(input_file, output_dir):
except Exception as e: except Exception as e:
print(f"Error processing {input_file}: {e}") print(f"Error processing {input_file}: {e}")
# 批量处理目录下的文件
def batch_process(input_dir, output_dir, num_threads=4): def batch_process(input_dir, output_dir, num_threads=4):
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
@@ -73,17 +68,14 @@ def batch_process(input_dir, output_dir, num_threads=4):
threads.append(thread) threads.append(thread)
thread.start() thread.start()
# 控制线程数量
if len(threads) >= num_threads: if len(threads) >= num_threads:
for t in threads: for t in threads:
t.join() t.join()
threads = [] threads = []
# 等待剩余线程完成
for t in threads: for t in threads:
t.join() t.join()
# 主函数
if __name__ == "__main__": if __name__ == "__main__":
input_dir = "./source" input_dir = "./source"
output_dir = "./output" output_dir = "./output"