博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow保存和载入模型
阅读量:4575 次
发布时间:2019-06-08

本文共 454 字,大约阅读时间需要 1 分钟。

首先定义一个tf.train.Saver类:

saver = tf.train.Saver(max_to_keep=1)

其中,max_to_keep参数设定只保存最后一个参数,默认值是5,即保存最后5个模型,如果设置成0,训练过程中的所有模型都会被保存。

模型训练好以后,保存模型:

saver.save(sess, ckpt_dir + "/nn_model.ckpt", global_step=1)

其中,sess是Session,ckpt_dir + "/nn_model.ckpt"是保存的路径和名称,global_step是模型名称的后缀名,由于我们只保存最后一个模型,所以可以设置为1,如果每一个模型都想保存,可以设置成训练的epoch。

载入模型比较简单:

saver.restore(sess, model_file)

其中,sess是Session,model_file是模型的路径和名称。

转载于:https://www.cnblogs.com/mstk/p/9395589.html

你可能感兴趣的文章
android平台下使用点九PNG技术
查看>>
Python学习3,列表
查看>>
最长回文子串
查看>>
JAVA基础-JDBC(一)
查看>>
js中for和while运行速度比较
查看>>
算法第5章作业
查看>>
7.9 练习
查看>>
基于ArcGIS JS API的在线专题地图实现
查看>>
learnByWork
查看>>
lua 函数
查看>>
Git的基本命令
查看>>
四平方和
查看>>
第十八周 12.27-1.2
查看>>
C# IP地址字符串和数值转换
查看>>
TCHAR和CHAR类型的互转
查看>>
常用界面布局
查看>>
C语言—— for 循环
查看>>
IBM lotus9.0测试版即将公测
查看>>
xml常用方法
查看>>
Cube Stacking(并差集深度+结点个数)
查看>>