cvsa/ml/pred/export_onnx.py
alikia2x 636c5e25cb
ref: move ML stuff
add: .idea to VCS, the refactor guide
2025-03-29 14:13:15 +08:00

29 lines
1.2 KiB
Python

import torch
import torch.onnx
from model import CompactPredictor
def export_model(input_size, checkpoint_path, onnx_path, opset_version=11):
    """Load a ``CompactPredictor`` checkpoint and export it to ONNX.

    Args:
        input_size: Number of input features. Passed to ``CompactPredictor``
            and used as the width of the example input for tracing.
        checkpoint_path: Path to a file containing the model's ``state_dict``.
        onnx_path: Destination path for the exported ``.onnx`` file.
        opset_version: ONNX opset to target (defaults to 11, as before).
    """
    model = CompactPredictor(input_size)
    # Map tensors to the CPU so a checkpoint saved from a GPU run can still be
    # exported on a CPU-only machine; the model and dummy input are CPU-only.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # Put the model in inference mode before tracing (affects layers such as
    # dropout / batch-norm, if the model contains any).
    model.eval()
    # Example input used to trace the graph: a batch of 1 row of input_size
    # random features.
    dummy_input = torch.randn(1, input_size)
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,  # embed the trained weights in the ONNX file
        opset_version=opset_version,
        do_constant_folding=True,  # pre-compute constant sub-expressions
        input_names=['input'],
        output_names=['output'],
        # Keep dim 0 symbolic so the exported graph accepts any batch size.
        dynamic_axes={'input': {0: 'batch_size'},
                      'output': {0: 'batch_size'}},
    )
    print(f"ONNX model has been exported to: {onnx_path}")
if __name__ == '__main__':
    # Each entry: (input feature count, checkpoint to load, ONNX file to write).
    # Checkpoint paths are relative to the current working directory.
    _EXPORT_JOBS = (
        (10, './pred/checkpoints/long_term.pt', 'long_term.onnx'),
        (12, './pred/checkpoints/short_term.pt', 'short_term.onnx'),
    )
    for n_features, ckpt_file, out_file in _EXPORT_JOBS:
        export_model(n_features, ckpt_file, out_file)