码迷,mamicode.com
首页 > Web开发 > 详细

Iris Classification Neural Network

时间:2018-05-19 15:39:46      阅读:343      评论:0      收藏:0      [点我收藏+]

标签:提高   算法   ase   print   feed   and   ica   ret   表示   

Iris Classification Neural Network

Neural Network

技术分享图片

formula derivation

\[ \begin{align} a & = x \cdot w_1 \y & = a \cdot w_2 \& = x \cdot w_1 \cdot w_2 \y & = softmax(y) \end{align} \]

code (training only)

\[ a = x \cdot w_1 \y = a \cdot w_2 \]

w1 = tf.Variable(tf.random_normal([4,5], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([5,3], stddev=1, seed=1))

x = tf.placeholder(tf.float32, shape=(None, 4), name=‘x-input‘)

a = tf.matmul(x, w1)
y = tf.matmul(a, w2)

既然是有监督学习,那就在训练阶段必须要给出 label,以此来计算交叉熵

# 用来存储数据的标签
y_ = tf.placeholder(tf.float32, shape=(None, 3), name=‘y-input‘)

隐藏层的激活函数是 sigmoid

y = tf.sigmoid(y)

softmax 与 交叉熵(corss entropy) 的组合函数,损失函数是交叉熵的均值

# softmax & corss_entropy
cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y)
# mean
cross_entropy_mean = tf.reduce_mean(cross_entropy)

为了防止神经网络过拟合,需加入正则化项,一般选取 “L2 正则化”

loss = cross_entropy_mean + \
    tf.contrib.layers.l2_regularizer(regulation_lamda)(w1) + \
    tf.contrib.layers.l2_regularizer(regulation_lamda)(w2)

为了加速神经网络的训练过程,需加入“指数衰减”技术

表示训练过程的计算图,优化方法选择了 Adam 算法,本质是反向传播算法。还可以选择“梯度下降法”(GradientDescentOptimizer)

train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

训练阶段

with tf.Session() as sess:  # Session 最好在“上下文机制”中开启,以防资源泄露
    init_op = tf.global_variables_initializer()  # 初始化网络中节点的参数,主要是 w1,w2
    sess.run(init_op)

    steps = 10000
    for i in range(steps):
        beg = (i * batch_size) % dataset_size    # 计算 batch
        end = min(beg+batch_size, dataset_size)  # 计算 batch
    
        sess.run(train_step, feed_dict={x:X[beg:end], y_:Y[beg:end]})  # 反向传播,训练网络
        if i % 1000 == 0:
            total_corss_entropy = sess.run(  # 计算交叉熵
                cross_entropy_mean,          # 计算交叉熵
                feed_dict={x:X, y_:Y}        # 计算交叉熵
            )
            print("After %d training steps, cross entropy on all data is %g" % (i, total_corss_entropy))

在训练阶段中,需要引入“滑动平均模型”来提高模型在测试数据上的健壮性(这是书上的说法,而我认为是泛化能力)

全部代码

# -*- encoding=utf8 -*-

from sklearn.datasets import load_iris
import tensorflow as tf


def label_convert(Y):
    l = list()
    for y in Y:
        if y == 0:
            l.append([1,0,0])
        elif y == 1:
            l.append([0, 1, 0])
        elif y == 2:
            l.append([0, 0, 1])
    return l


def load_data():
    iris = load_iris()
    X = iris.data
    Y = label_convert(iris.target)
    return (X,Y)

if __name__ == ‘__main__‘:
    X,Y = load_data()

    learning_rate = 0.001
    batch_size = 10
    dataset_size = 150
    regulation_lamda = 0.001

    w1 = tf.Variable(tf.random_normal([4,5], stddev=1, seed=1))
    w2 = tf.Variable(tf.random_normal([5,3], stddev=1, seed=1))

    x = tf.placeholder(tf.float32, shape=(None, 4), name=‘x-input‘)
    y_ = tf.placeholder(tf.float32, shape=(None, 3), name=‘y-input‘)

    a = tf.matmul(x, w1)
    y = tf.matmul(a, w2)

    y = tf.sigmoid(y)

    cross_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y)
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    loss = cross_entropy_mean + \
           tf.contrib.layers.l2_regularizer(regulation_lamda)(w1) + \
           tf.contrib.layers.l2_regularizer(regulation_lamda)(w2)
    train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        steps = 10000
        for i in range(steps):
            beg = (i * batch_size) % dataset_size
            end = min(beg+batch_size, dataset_size)

            sess.run(train_step, feed_dict={x:X[beg:end], y_:Y[beg:end]})
            if i % 1000 == 0:
                total_corss_entropy = sess.run(
                    cross_entropy_mean,
                    feed_dict={x:X, y_:Y}
                )
                print("After %d training steps, cross entropy on all data is %g" % (i, total_corss_entropy))

        print(sess.run(w1))
        print(sess.run(w2))

Experiment Result

random split cross validation

Iris Classification Neural Network

标签:提高   算法   ase   print   feed   and   ica   ret   表示   

原文地址:https://www.cnblogs.com/fengyubo/p/9060249.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!