PyTorch-模型保存与加载
保存:
model = LinearRegression()
# ......各种操作
model.eval()
#训练完成,保存状态字典到linear.pkl
torch.save(model.state_dict(), \'./linear.pkl\')
加载:
model = LinearRegression()
model.load_state_dict(torch.load(\'linear.pth\'))
#...各种使用,比如预测...
x_test=np.arrar([..............])
x_test = torch.from_numpy(x_test)
predict_y = model(Variable(x_test))
版权声明:本文为onenoteone原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。