模型保存与加载
方法一(推荐):
# 保存
torch.save(model.state_dict(), path)
# 加载
model = ModelClass()
model.load_state_dict(torch.load(path))
方法二:
# 保存,实际使用时容易报错
torch.save(model, path)
# 加载
model = torch.load(path)
参考材料:https://pytorch.org/docs/stable/notes/serialization.html