
TensorFlow in Action - VGGNet

Posted: 2017-06-29 23:49:49


# The original post elides the dataset import; the mnist feed used in the
# training loop below suggests the standard TF 1.x tutorial loader. Note that
# MNIST images are 28x28x1 with 10 classes, while this graph expects 224x224x3
# inputs and 1000 classes, so the mnist feed is only a stand-in.
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

# "MNIST_data/" is a placeholder download directory.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

def conv(name,x,w,b):
    return tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME'),b),name=name)

def max_pool(name,x,k):
    return tf.nn.max_pool(x,ksize=[1,k,k,1],strides=[1,k,k,1],padding='SAME',name=name)

def fc(name,x,w,b):
    return tf.nn.relu(tf.matmul(x,w)+b,name=name)

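# These three wrappers implement the whole network: conv applies a convolution
# with stride 1 and 'SAME' padding, so it preserves the spatial size; max_pool
# with k=2 halves the height and width; fc is a ReLU-activated dense layer.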
def vgg_net(_X,_weights,_biases,keep_prob):
    # Reshape the flat input vector back into an image batch:
    # [batch, height, width, channels], matching n_input=224*224*3 below.
    _X=tf.reshape(_X,shape=[-1,224,224,3])

    # Block 1: two 3x3 conv layers, then 2x2 max pooling (224 -> 112).
    conv1_1=conv('conv1_1',_X,_weights['wc1_1'],_biases['bc1_1'])
    conv1_2=conv('conv1_2',conv1_1,_weights['wc1_2'],_biases['bc1_2'])
    pool1=max_pool('pool1',conv1_2,k=2)

    # Block 2: 112 -> 56.
    conv2_1=conv('conv2_1',pool1,_weights['wc2_1'],_biases['bc2_1'])
    conv2_2=conv('conv2_2',conv2_1,_weights['wc2_2'],_biases['bc2_2'])
    pool2=max_pool('pool2',conv2_2,k=2)

    # Block 3: three conv layers, 56 -> 28.
    conv3_1=conv('conv3_1',pool2,_weights['wc3_1'],_biases['bc3_1'])
    conv3_2=conv('conv3_2',conv3_1,_weights['wc3_2'],_biases['bc3_2'])
    conv3_3=conv('conv3_3',conv3_2,_weights['wc3_3'],_biases['bc3_3'])
    pool3=max_pool('pool3',conv3_3,k=2)

    # Block 4: 28 -> 14.
    conv4_1=conv('conv4_1',pool3,_weights['wc4_1'],_biases['bc4_1'])
    conv4_2=conv('conv4_2',conv4_1,_weights['wc4_2'],_biases['bc4_2'])
    conv4_3=conv('conv4_3',conv4_2,_weights['wc4_3'],_biases['bc4_3'])
    pool4=max_pool('pool4',conv4_3,k=2)

    # Block 5: 14 -> 7.
    conv5_1=conv('conv5_1',pool4,_weights['wc5_1'],_biases['bc5_1'])
    conv5_2=conv('conv5_2',conv5_1,_weights['wc5_2'],_biases['bc5_2'])
    conv5_3=conv('conv5_3',conv5_2,_weights['wc5_3'],_biases['bc5_3'])
    pool5=max_pool('pool5',conv5_3,k=2)

    # Flatten the final 7x7x512 feature map for the fully connected layers.
    _shape=pool5.get_shape()
    flatten=_shape[1].value*_shape[2].value*_shape[3].value
    pool5=tf.reshape(pool5,shape=[-1,flatten])

    fc1=fc('fc1',pool5,_weights['fc1'],_biases['fb1'])
    fc1=tf.nn.dropout(fc1,keep_prob)

    fc2=fc('fc2',fc1,_weights['fc2'],_biases['fb2'])
    fc2=tf.nn.dropout(fc2,keep_prob)

    # The output layer is linear: return the raw logits, so that
    # softmax_cross_entropy_with_logits and tf.argmax can be applied below.
    out=tf.add(tf.matmul(fc2,_weights['fc3']),_biases['fb3'],name='fc3')

    return out

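# The stack above is the 16-weight-layer VGG configuration ("VGG-16"):
# thirteen 3x3 convolution layers in five blocks plus three fully connected
# layers, with 2x2 max pooling between blocks (Simonyan & Zisserman, 2014).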
learning_rate=0.001
max_iters=200000
batch_size=100
display_step=20

n_input=224*224*3
n_classes=1000
dropout=0.8

x=tf.placeholder(tf.float32,[None,n_input])
y=tf.placeholder(tf.float32,[None,n_classes])
keep_prob=tf.placeholder(tf.float32)

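# keep_prob is fed at run time: the dropout keep probability (0.8 here)
# during training, and 1.0 at evaluation so that no units are dropped.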
weights={
    'wc1_1':tf.Variable(tf.random_normal([3,3,3,64])),
    'wc1_2':tf.Variable(tf.random_normal([3,3,64,64])),
    'wc2_1':tf.Variable(tf.random_normal([3,3,64,128])),
    'wc2_2':tf.Variable(tf.random_normal([3,3,128,128])),
    'wc3_1':tf.Variable(tf.random_normal([3,3,128,256])),
    'wc3_2':tf.Variable(tf.random_normal([3,3,256,256])),
    'wc3_3':tf.Variable(tf.random_normal([3,3,256,256])),
    'wc4_1':tf.Variable(tf.random_normal([3,3,256,512])),
    'wc4_2':tf.Variable(tf.random_normal([3,3,512,512])),
    'wc4_3':tf.Variable(tf.random_normal([3,3,512,512])),
    'wc5_1':tf.Variable(tf.random_normal([3,3,512,512])),
    'wc5_2':tf.Variable(tf.random_normal([3,3,512,512])),
    'wc5_3':tf.Variable(tf.random_normal([3,3,512,512])),
    'fc1':tf.Variable(tf.random_normal([7*7*512,4096])),
    'fc2':tf.Variable(tf.random_normal([4096,4096])),
    'fc3':tf.Variable(tf.random_normal([4096,n_classes]))
}

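# NOTE: tf.random_normal defaults to stddev=1.0, which is far too large for a
# network this deep; in practice VGG weights are initialized with a small
# stddev (e.g. tf.truncated_normal(..., stddev=0.01)) or loaded from a
# pretrained model.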
biases={
    'bc1_1':tf.Variable(tf.random_normal([64])),
    'bc1_2':tf.Variable(tf.random_normal([64])),
    'bc2_1':tf.Variable(tf.random_normal([128])),
    'bc2_2':tf.Variable(tf.random_normal([128])),
    'bc3_1':tf.Variable(tf.random_normal([256])),
    'bc3_2':tf.Variable(tf.random_normal([256])),
    'bc3_3':tf.Variable(tf.random_normal([256])),
    'bc4_1':tf.Variable(tf.random_normal([512])),
    'bc4_2':tf.Variable(tf.random_normal([512])),
    'bc4_3':tf.Variable(tf.random_normal([512])),
    'bc5_1':tf.Variable(tf.random_normal([512])),
    'bc5_2':tf.Variable(tf.random_normal([512])),
    'bc5_3':tf.Variable(tf.random_normal([512])),
    # Biases for the fully connected layers, referenced in vgg_net above.
    'fb1':tf.Variable(tf.random_normal([4096])),
    'fb2':tf.Variable(tf.random_normal([4096])),
    'fb3':tf.Variable(tf.random_normal([n_classes]))
}

pred=vgg_net(x,weights,biases,keep_prob)

# pred holds raw logits; the softmax is applied inside the loss.
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred,labels=y))
optimizer=tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

correct=tf.equal(tf.argmax(pred,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.cast(correct,tf.float32))

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    step=1

    while step*batch_size<max_iters:
        batch_xs,batch_ys=mnist.train.next_batch(batch_size)
        sess.run(optimizer,feed_dict={x:batch_xs,y:batch_ys,keep_prob:dropout})

        # Report minibatch loss and accuracy every display_step steps,
        # with dropout disabled for the measurement.
        if step%display_step==0:
            loss,acc=sess.run([cost,accuracy],feed_dict={x:batch_xs,y:batch_ys,keep_prob:1.})
            print('step %d, minibatch loss %.4f, accuracy %.4f'%(step,loss,acc))

        # step must advance inside the loop, or training never terminates.
        step+=1
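
After the training loop finishes, the same graph can be scored with dropout disabled. Below is a minimal evaluation sketch meant to run at the end of the same with tf.Session() block; it assumes a held-out batch from whatever input pipeline is actually used (the mnist loader above yields 28x28x1 images, so a real VGG evaluation needs a 224x224x3 test pipeline):

    # Evaluation sketch: keep_prob=1.0 disables dropout at test time.
    # mnist.test.next_batch is the TF 1.x tutorial API; swap in the real
    # 224x224x3 test pipeline for an actual VGG run.
    test_xs,test_ys=mnist.test.next_batch(batch_size)
    test_acc=sess.run(accuracy,feed_dict={x:test_xs,y:test_ys,keep_prob:1.})
    print('test accuracy %.4f'%test_acc)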

 


Original article: http://www.cnblogs.com/fighting-lady/p/7096547.html
