29 lines
1.2 KiB
Python
29 lines
1.2 KiB
Python
import torch
|
|
import torch.onnx
|
|
from model import CompactPredictor
|
|
|
|
def export_model(input_size, checkpoint_path, onnx_path):
    """Export a trained CompactPredictor checkpoint to ONNX.

    Args:
        input_size: Number of input features the model was trained with.
        checkpoint_path: Path to the saved ``state_dict`` (.pt file).
        onnx_path: Destination path for the exported ONNX file.

    Side effects:
        Writes the ONNX model to ``onnx_path`` and prints a confirmation line.
    """
    model = CompactPredictor(input_size)
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
    # host — export utilities commonly run without CUDA available.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

    # Switch to inference mode BEFORE tracing so dropout / batch-norm layers
    # are captured with their evaluation-time behavior.
    model.eval()

    # Single dummy sample just to drive the trace; the batch dimension is
    # declared dynamic below, so real batch size is unconstrained.
    dummy_input = torch.randn(1, input_size)

    torch.onnx.export(model,                     # Model to be exported
                      dummy_input,               # Model input
                      onnx_path,                 # Save path
                      export_params=True,        # Store trained weights in the file
                      opset_version=11,          # ONNX opset version
                      do_constant_folding=True,  # Fold constant subgraphs at export
                      input_names=['input'],     # Input node name
                      output_names=['output'],   # Output node name
                      dynamic_axes={'input': {0: 'batch_size'},   # Dynamic batch size
                                    'output': {0: 'batch_size'}})

    print(f"ONNX model has been exported to: {onnx_path}")
|
|
|
|
if __name__ == '__main__':
    # Export both predictors: (input_size, checkpoint path, output path).
    export_jobs = [
        (10, './pred/checkpoints/long_term.pt', 'long_term.onnx'),
        (12, './pred/checkpoints/short_term.pt', 'short_term.onnx'),
    ]
    for size, checkpoint, onnx_file in export_jobs:
        export_model(size, checkpoint, onnx_file)
|