码迷,mamicode.com
首页 > 其他好文 > 详细

Debug --> Variable,Tensor,Numpy的转换

时间:2021-04-21 12:37:02      阅读:0      评论:0      收藏:0      [点我收藏+]

标签:on()   ble   类型   device   sof   contain   eva   convert   glob   

尝试输出keras模型参数的时候,需要解决的问题:

 1 import tensorflow.compat.v1 as tf
 2 tf.disable_v2_behavior()
 3 import numpy as np
 4 weight = tf.get_variable(name=weights,initializer=tf.random_normal([5,2], stddev=0.01))
 5 with tf.Session() as sess:
 6     sess.run(tf.global_variables_initializer())
 7     print(------------------打印出已经初始化之后的Variable的值------------------------------)
 8     print(sess.run(weight))
 9     print(----------weight的类型------------)
10     print(type(weight))
11     # Variable转换为Tensor
12     # Variable类型转换为tensor类型(无论是numpy转换为Tensor还是Variable转换为Tensor都可以使用tf.convert_to_tensor)
13     data_tensor = tf.convert_to_tensor(weight) 
14     # 打印出Tensor的值(由Variable转化而来)
15     print(------------------Variable转化为Tensor,打印出Tensor的值--------------------------)
16     print(sess.run(data_tensor))
17     # tensor转化为numpy
18     print(-------------------tensor转换为numpy,打印出numpy的值-----------------)
19     data_numpy = data_tensor.eval()
20     print(data_numpy)
21     print(------------------numpy转换为Tensor---------------------------)
22     ten = tf.convert_to_tensor(data_numpy)
23     print(ten)
24     print(sess.run(ten))
25     # tensor转化为Variable(其实是Variable继承Tensor的结构,但是没有值
26     print(---------------------tensor转换为Variable(需要重新进行初始化)----------------------)
27     v = tf.Variable(data_tensor) # 此时Variable继承的是Tensor的结构,至于Variable的值,需要重新进行initialize
28     sess.run(tf.global_variables_initializer())
29     print(sess.run(weight)) # 此时输出的weight和v的结构是相同的,但是值是不同的。
30     print(sess.run(v))
31     
32 #     tf.enable_eager_execution(
33 #     config=None,
34 #     device_policy=None,
35 #     execution_mode=None
36 #     )
37     # Variable转换为numpy(也是使用eval)
38     print(---------------Variable转换为numpy(也是使用eval)--------------------)
39     data_numpy2 = weight.eval()
40     print(data_numpy2)

 

 

1.模型保存

model.save_model()可以保存网络结构权重以及优化器的参数
model.save_weights() 仅仅保存权重

2.模型加载

from keras.models import load_model
load_model():只能load 由save_model保存的,将模型和weight全load进来

model.load_weights(self, filepath, by_name=False):在加载权重之前,model必须编译好

3.sequential 和functional

序列式模型只能有单输入单输出,函数式模型可以有多个输入输出

4.model类

因为是继承, model对象有 container和layer的所有方法,可以用model对象访问下面三个类的所有方法

Model(Container)containerlayer
fit summary get_input_at(node_index)
evaluate get_layer get_config()
predict get_weights compute_mask(x, mask)
train on batch set_weights get_input_mask_at(node_index)
test_on_batch get_config get_output_at(node_index)
predict_on_batch compute_output_shape  
evaluate_generator    
predict_generator    
 

5.打印各层权重

layer.get_weights返回的是没有名字的权重array,Model.get_weights() 是他们的拼接,也没有名字,利用layer.weights 可以访问到后台的变量

1 #打印各层名字,权重的形状
2 for layer in model.layers:
3         for weight in layer.weights:
4             print weight.name,weight.shape

上面输出的weight是Var类型,下面给出另一种方法,输出的weight是np.Array类型:

1 names = [weight.name for layer in model.layers for weight in layer.weights]
2 weights = model.get_weights()
3 for name, weight in zip(names, weights):
4     print(name, weight.shape) 

 



 

Debug --> Variable,Tensor,Numpy的转换

标签:on()   ble   类型   device   sof   contain   eva   convert   glob   

原文地址:https://www.cnblogs.com/aluomengmengda/p/14679858.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!