当前位置 主页 > 网站技术 > 代码类 >

    将tensorflow模型打包成PB文件及PB文件读取方式

    栏目:代码类 时间:2020-01-23 21:09

    1. tensorflow模型文件打包成PB文件

    import tensorflow as tf
    from tensorflow.python.tools import freeze_graph
     
    with tf.Graph().as_default():
      with tf.device("/cpu:0"):
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(config=config).as_default() as sess:
          model = Your_Model_Name()
          model.build_graph()
          sess.run(tf.initialize_all_variables())
          
          saver = tf.train.Saver()
          ckpt_path = "/your/model/path"
          saver.restore(sess, ckpt_path)
     
          graphdef = tf.get_default_graph().as_graph_def()
          tf.train.write_graph(sess.graph_def,"/your/save/path/","save_name.pb",as_text=False)
          frozen_graph = tf.graph_util.convert_variables_to_constants(sess,graphdef,['output/node/name'])
          frozen_graph_trim = tf.graph_util.remove_training_nodes(frozen_graph)
          freeze_graph.freeze_graph('/your/save/path/save_name.pb','',True, ckpt_path,'output/node/name','save/restore_all','save/Const:0','frozen_name.pb',True,"")

    2. PB文件读取使用

    output_graph_def = tf.GraphDef()
    with open("your_name.pb","rb") as f:
      output_graph_def.ParseFromString(f.read())
      _ = tf.import_graph_def(output_graph_def, name="")
     
    node_in = sess.graph.get_tensor_by_name("input_node_name")
    model_out = sess.graph.get_tensor_by_name("out_node_name")
     
    feed_dict = {node_in:in_data}
    pred = sess.run(model_out, feed_dict)

    以上这篇将tensorflow模型打包成PB文件及PB文件读取方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持IIS7站长之家。