tf.nn.conv2d_backprop_input and tf.nn.conv2d_backprop_filter in an example

In tf.nn, there are 4 closely related 2D convolution functions:
- tf.nn.conv2d
- tf.nn.conv2d_backprop_filter
- tf.nn.conv2d_backprop_input
- tf.nn.conv2d_transpose

Given out = conv2d(x, w) and the output gradient d_out:
- Use tf.nn.conv2d_backprop_filter to compute the filter gradient d_w.
- Use tf.nn.conv2d_backprop_input to compute the input gradient d_x.
- tf.nn.conv2d_backprop_input can be implemented with tf.nn.conv2d_transpose.
- Both gradients can also be computed manually with plain tf.nn.conv2d.

Now, let's walk through a working code example that uses the 4 functions above to compute d_x and d_w given d_out. It shows how conv2d, conv2d_backprop_filter, conv2d_backprop_input, and conv2d_transpose relate to each other. Please find the full scripts here.
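For the simplest case (one image, one channel, unit stride, VALID padding; these restrictions are assumptions made here for readability, while the scripts work on NHWC tensors), the relationships in the list follow from writing conv2d as a cross-correlation and differentiating by hand. Here f is any scalar loss whose gradient with respect to out is d_out:

```latex
\begin{aligned}
\mathrm{out}[i,j] &= \sum_{a=0}^{k_h-1}\sum_{b=0}^{k_w-1} x[i+a,\,j+b]\,w[a,b] \\
d_w[a,b] &= \sum_{i,j} d_{\mathrm{out}}[i,j]\,x[i+a,\,j+b]
  &&\Rightarrow\; d_w = \mathrm{conv2d}_{\mathrm{VALID}}\bigl(x,\ d_{\mathrm{out}}\bigr) \\
d_x[p,q] &= \sum_{i,j} d_{\mathrm{out}}[i,j]\,w[p-i,\,q-j]
  &&\Rightarrow\; d_x = \mathrm{conv2d}_{\mathrm{VALID}}\bigl(\mathrm{pad}(d_{\mathrm{out}}),\ \mathrm{rot180}(w)\bigr)
\end{aligned}
```

Here w[p-i, q-j] is taken as 0 when the index falls outside the k_h × k_w kernel, pad adds k_h-1 zero rows and k_w-1 zero columns on every side, and rot180(w)[a,b] = w[k_h-1-a, k_w-1-b]. Read left to right, the d_x sum spreads each d_out[i,j] back over the window it came from, weighted by w; that scatter is what a transposed convolution computes, which is why conv2d_transpose can stand in for conv2d_backprop_input.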
Computing d_x in 4 different ways:
# Method 1: TF's autodiff
d_x = tf.gradients(f, x)[0]
# Method 2: manually using conv2d
# (zero-pad d_out for a "full" convolution and rotate w by 180 degrees;
#  this form only matches the other methods for unit strides)
d_x_manual = tf.nn.conv2d(input=tf_pad_to_full_conv2d(d_out, w_size),
                          filter=tf_rot180(w),
                          strides=strides,
                          padding='VALID')
# Method 3: conv2d_backprop_input
d_x_backprop_input = tf.nn.conv2d_backprop_input(input_sizes=x_shape,
                                                 filter=w,
                                                 out_backprop=d_out,
                                                 strides=strides,
                                                 padding='VALID')
# Method 4: conv2d_transpose
d_x_transpose = tf.nn.conv2d_transpose(value=d_out,
                                       filter=w,
                                       output_shape=x_shape,
                                       strides=strides,
                                       padding='VALID')
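To see why the d_x methods agree without running TensorFlow, here is a minimal pure-Python sketch for one image, one channel, unit stride, and VALID padding (all assumptions; the original scripts use NHWC tensors). The helpers rot180 and pad_to_full are hypothetical stand-ins for the scripts' tf_rot180 and tf_pad_to_full_conv2d, d_x_scatter plays the role of conv2d_transpose / conv2d_backprop_input, and a finite difference on f = sum(conv2d(x, w) * d_out) stands in for autodiff (this f is chosen here so that its gradient with respect to out is exactly d_out).

```python
# Minimal sketch: single image, single channel, stride 1, VALID padding.
# Nested lists stand in for tensors; all helper names are hypothetical.

def conv2d_valid(x, w):
    """VALID cross-correlation (what tf.nn.conv2d computes), stride 1."""
    kh, kw = len(w), len(w[0])
    oh, ow = len(x) - kh + 1, len(x[0]) - kw + 1
    return [[sum(x[i + a][j + b] * w[a][b] for a in range(kh) for b in range(kw))
             for j in range(ow)] for i in range(oh)]

def rot180(w):
    """Stand-in for tf_rot180: flip both spatial axes."""
    return [row[::-1] for row in w[::-1]]

def pad_to_full(t, kh, kw):
    """Stand-in for tf_pad_to_full_conv2d: kh-1 zero rows and kw-1 zero
    columns on every side, so a VALID conv becomes a 'full' conv."""
    width = len(t[0]) + 2 * (kw - 1)
    body = [[0.0] * (kw - 1) + list(row) + [0.0] * (kw - 1) for row in t]
    return ([[0.0] * width for _ in range(kh - 1)] + body
            + [[0.0] * width for _ in range(kh - 1)])

def d_x_manual(d_out, w):
    """Method 2: conv2d of the fully padded d_out with rot180(w)."""
    return conv2d_valid(pad_to_full(d_out, len(w), len(w[0])), rot180(w))

def d_x_scatter(d_out, w, x_h, x_w):
    """Transposed-convolution view: each d_out[i][j] adds d_out[i][j] * w
    into the input window it was computed from."""
    d_x = [[0.0] * x_w for _ in range(x_h)]
    for i, row in enumerate(d_out):
        for j, g in enumerate(row):
            for a, w_row in enumerate(w):
                for b, w_ab in enumerate(w_row):
                    d_x[i + a][j + b] += g * w_ab
    return d_x

def d_x_numeric(x, w, d_out, eps=1e-3):
    """Finite-difference stand-in for autodiff of f = sum(conv2d(x, w) * d_out)."""
    def f(x_):
        return sum(o * g for o_row, g_row in zip(conv2d_valid(x_, w), d_out)
                   for o, g in zip(o_row, g_row))
    base = f(x)
    grad = [[0.0] * len(x[0]) for _ in range(len(x))]
    for p in range(len(x)):
        for q in range(len(x[0])):
            bumped = [row[:] for row in x]
            bumped[p][q] += eps
            grad[p][q] = (f(bumped) - base) / eps
    return grad

def max_abs_diff(u, v):
    """Largest element-wise difference between two equally shaped 2-D lists."""
    return max(abs(a - b) for ru, rv in zip(u, v) for a, b in zip(ru, rv))

x = [[float(4 * r + c + 1) for c in range(4)] for r in range(4)]  # 4x4 input
w = [[1.0, 2.0, 0.0], [-1.0, 3.0, 1.0]]        # 2x3 kernel, deliberately asymmetric
d_out = [[1.0, -2.0], [0.5, 3.0], [2.0, 1.0]]  # 3x2, the shape of conv2d_valid(x, w)

manual = d_x_manual(d_out, w)
scatter = d_x_scatter(d_out, w, len(x), len(x[0]))
numeric = d_x_numeric(x, w, d_out)
print(max_abs_diff(manual, scatter), max_abs_diff(manual, numeric))
```

The kernel is deliberately asymmetric: with a symmetric w, forgetting the 180-degree rotation in the manual method would go unnoticed.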
Computing d_w in 3 different ways:
# Method 1: TF's autodiff
d_w = tf.gradients(f, w)[0]
# Method 2: manually using conv2d
# (d_out acts as the filter; this form only matches the other methods for unit strides)
d_w_manual = tf_NHWC_to_HWIO(tf.nn.conv2d(input=x,
                                          filter=tf_NHWC_to_HWIO(d_out),
                                          strides=strides,
                                          padding='VALID'))
# Method 3: conv2d_backprop_filter
d_w_backprop_filter = tf.nn.conv2d_backprop_filter(input=x,
                                                   filter_sizes=w_shape,
                                                   out_backprop=d_out,
                                                   strides=strides,
                                                   padding='VALID')
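The filter gradient can be sketched the same way, again assuming one image, one channel, unit stride, and VALID padding, with hypothetical helper names: d_w is the VALID cross-correlation of x with d_out, that is, d_out used as the kernel, and a finite difference on f = sum(conv2d(x, w) * d_out) stands in for autodiff. With a batch and channels, the same sum also runs over the batch axis, which is presumably why the script rearranges axes with tf_NHWC_to_HWIO; that bookkeeping is left out here.

```python
# Minimal sketch: single image, single channel, stride 1, VALID padding.
# Nested lists stand in for tensors; all helper names are hypothetical.

def conv2d_valid(x, k):
    """VALID cross-correlation (what tf.nn.conv2d computes), stride 1."""
    kh, kw = len(k), len(k[0])
    oh, ow = len(x) - kh + 1, len(x[0]) - kw + 1
    return [[sum(x[i + a][j + b] * k[a][b] for a in range(kh) for b in range(kw))
             for j in range(ow)] for i in range(oh)]

def d_w_manual(x, d_out):
    """Method 2 analogue: correlate x with d_out, using d_out as the kernel."""
    return conv2d_valid(x, d_out)

def d_w_numeric(x, w, d_out, eps=1e-3):
    """Finite-difference stand-in for autodiff of f = sum(conv2d(x, w) * d_out)."""
    def f(w_):
        return sum(o * g for o_row, g_row in zip(conv2d_valid(x, w_), d_out)
                   for o, g in zip(o_row, g_row))
    base = f(w)
    grad = [[0.0] * len(w[0]) for _ in range(len(w))]
    for a in range(len(w)):
        for b in range(len(w[0])):
            bumped = [row[:] for row in w]
            bumped[a][b] += eps
            grad[a][b] = (f(bumped) - base) / eps
    return grad

x = [[float(4 * r + c + 1) for c in range(4)] for r in range(4)]  # 4x4 input
w = [[1.0, 2.0, 0.0], [-1.0, 3.0, 1.0]]        # 2x3 kernel
d_out = [[1.0, -2.0], [0.5, 3.0], [2.0, 1.0]]  # 3x2, the shape of conv2d_valid(x, w)

manual = d_w_manual(x, d_out)   # 2x3, the same shape as w
numeric = d_w_numeric(x, w, d_out)
print(manual)
```

Because f is linear in w, the finite difference matches the closed form up to rounding, so any small tolerance is enough for the comparison.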
Please see the full scripts for the implementations of tf_rot180, tf_pad_to_full_conv2d, and tf_NHWC_to_HWIO. The scripts also check that all methods produce the same final values; a NumPy implementation is also available.
Original post: https://www.cnblogs.com/ranjiewen/p/9368359.html