# cvsa/filter/checkpoint_conversion.py
#
# Converts the saved VideoClassifierV3_9 checkpoint into a checkpoint for
# VideoClassifierV3_10.

import torch
from modelV3_10 import VideoClassifierV3_10
from modelV3_9 import VideoClassifierV3_9
def convert_checkpoint(original_model, new_model):
    """Copy all parameters and buffers of ``original_model`` into ``new_model``.

    The two architectures are expected to be state-dict compatible (same keys
    and tensor shapes), so the weights are transferred verbatim with a strict
    ``load_state_dict``; any missing/unexpected key or shape mismatch raises.

    Args:
        original_model: Source module whose weights are read.
        new_model: Destination module; its tensors are overwritten in place.

    Returns:
        ``new_model`` itself, now holding the source weights.
    """
    # load_state_dict copies values into new_model's own tensors, so the two
    # models do not end up sharing storage.
    new_model.load_state_dict(original_model.state_dict())
    return new_model
# Script section: convert the V3.9 checkpoint into a V3.10 checkpoint.
# NOTE(review): these statements run at module import time, so importing this
# module performs the conversion as a side effect. Consider an
# `if __name__ == "__main__":` guard if convert_checkpoint is ever reused.
# Both paths are relative to the current working directory.
SOURCE_CHECKPOINT = './filter/checkpoints/best_model_V3.9.pt'
TARGET_CHECKPOINT = './filter/checkpoints/best_model_V3.10.pt'

original_model = VideoClassifierV3_9()
new_model = VideoClassifierV3_10()

# Load the original checkpoint. map_location="cpu" lets a checkpoint that was
# saved from GPU tensors load on a CPU-only machine. The models above are
# constructed without a device argument (presumably on the CPU), and
# load_state_dict copies the weights into them either way.
# NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
original_model.load_state_dict(torch.load(SOURCE_CHECKPOINT, map_location="cpu"))

# Transfer the parameters into the new architecture.
converted_model = convert_checkpoint(original_model, new_model)

# Save the converted model's state dict.
torch.save(converted_model.state_dict(), TARGET_CHECKPOINT)