# Checkpoint migration script: VideoClassifierV3_9 -> VideoClassifierV3_10.
# (Original source listing: Python, 26 lines, 766 B.)
import warnings
from pathlib import Path

import torch

from modelV3_9 import VideoClassifierV3_9
from modelV3_10 import VideoClassifierV3_10


def convert_checkpoint(original_model, new_model, *, strict=True):
    """Copy the parameters of ``original_model`` into ``new_model``.

    The two architectures are expected to stay compatible (same parameter
    names and shapes), so the conversion is a direct ``state_dict`` copy.

    Args:
        original_model: Source ``torch.nn.Module`` whose weights are copied.
        new_model: Destination ``torch.nn.Module``; it is modified in place.
        strict: Forwarded to ``new_model.load_state_dict``. With the default
            ``True`` any missing or unexpected key raises ``RuntimeError``
            (the original behavior). With ``False`` only the matching keys
            are copied and the mismatched keys are reported via
            ``warnings.warn``.

    Returns:
        ``new_model``, now holding the copied parameters.

    Raises:
        RuntimeError: If the state dicts are incompatible (a key mismatch
            with ``strict=True``, or a tensor shape mismatch in either mode).
    """
    state_dict = original_model.state_dict()

    # Direct copy: the new structure is meant to stay compatible with the old.
    result = new_model.load_state_dict(state_dict, strict=strict)

    # In non-strict mode load_state_dict silently skips mismatched keys;
    # surface them so a partial transfer is never mistaken for a full one.
    if not strict and (result.missing_keys or result.unexpected_keys):
        warnings.warn(
            "convert_checkpoint: partial load; "
            f"missing keys in new model: {result.missing_keys}, "
            f"unexpected keys from original model: {result.unexpected_keys}"
        )

    return new_model


# Usage: migrate the best V3.9 checkpoint into the V3.10 architecture.
SRC_CHECKPOINT = './filter/checkpoints/best_model_V3.9.pt'
DST_CHECKPOINT = './filter/checkpoints/best_model_V3.10.pt'


def main(src_path=SRC_CHECKPOINT, dst_path=DST_CHECKPOINT):
    """Load a V3.9 checkpoint, copy it into a V3.10 model, and save it.

    Args:
        src_path: Path of the V3.9 ``state_dict`` checkpoint to read.
        dst_path: Path where the converted V3.10 ``state_dict`` is written.
    """
    original_model = VideoClassifierV3_9()
    new_model = VideoClassifierV3_10()

    # Load the original checkpoint. map_location="cpu" lets a checkpoint that
    # contains CUDA tensors be deserialized on a machine without a GPU.
    original_model.load_state_dict(torch.load(src_path, map_location="cpu"))

    # Transfer the parameters into the new architecture.
    converted_model = convert_checkpoint(original_model, new_model)

    # Save the converted weights, creating the target directory if needed.
    Path(dst_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(converted_model.state_dict(), dst_path)


if __name__ == "__main__":
    main()