update: llm translate

This commit is contained in:
alikia2x (寒寒) 2024-09-10 21:35:25 +08:00
parent dcf53ca002
commit ebd1113a6e
Signed by: alikia2x
GPG Key ID: 56209E0CCD8420C6

View File

@ -6,13 +6,11 @@ from openai import OpenAI
# Load environment settings before they are read below
# (presumably from a .env file via python-dotenv; the import is not visible in this excerpt).
load_dotenv()
# Initialize the shared OpenAI client. The API key and endpoint base URL are
# read from the API_KEY and BASE_URL environment variables; os.getenv returns
# None for a variable that is not set.
client = OpenAI(
api_key=os.getenv("API_KEY"),
base_url=os.getenv("BASE_URL"),
)
# System prompt
system_prompt = """
The user will provide some text. Please parse the text into segments, each segment contains 1 to 5 sentences. Translate each sentence into the corresponding language. If the input is in Chinese, return the English translation, and vice versa.
@ -28,7 +26,6 @@ EXAMPLE JSON OUTPUT:
}
"""
# Translation function
def translate_text(text):
messages = [
{"role": "system", "content": system_prompt},
@ -43,7 +40,6 @@ def translate_text(text):
return json.loads(response.choices[0].message.content)
# Function that processes a single file
def process_file(input_file, output_dir):
try:
with open(input_file, 'r', encoding='utf-8') as f:
@ -60,7 +56,6 @@ def process_file(input_file, output_dir):
except Exception as e:
print(f"Error processing {input_file}: {e}")
# Batch-process the files in a directory
def batch_process(input_dir, output_dir, num_threads=4):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
@ -73,17 +68,14 @@ def batch_process(input_dir, output_dir, num_threads=4):
threads.append(thread)
thread.start()
# 控制线程数量
if len(threads) >= num_threads:
for t in threads:
t.join()
threads = []
# 等待剩余线程完成
for t in threads:
t.join()
# Main entry point
if __name__ == "__main__":
input_dir = "./source"
output_dir = "./output"