PyTorch学习笔记

模型保存与加载

Posted by JackyCJ on October 22, 2019

模型保存与加载

方法一(推荐):

# 保存
torch.save(model.state_dict(), path)

# 加载
model = ModelClass()
model.load_state_dict(torch.load(path))

方法二:

# 保存,实际使用时容易报错
torch.save(model, path)

# 加载
model = torch.load(path)

参考材料:https://pytorch.org/docs/stable/notes/serialization.html