// Blog tags: idt, fun, cpp, ack, type, reading (阅读), get, stride, auto
/**
 * Back-propagates this convolutional layer's delta to the previous layer and
 * accumulates this layer's weight / bias gradients for worker `index`.
 *
 * Steps:
 *   1. Scatter curr_delta through the kernels W_ (honouring h_stride_ and
 *      w_stride_) into the previous-layer delta buffer. This is the padded
 *      buffer for padding::same, otherwise ws.prev_delta_ directly.
 *   2. Multiply that buffer element-wise by prev_->activation_function().df
 *      evaluated at the previous layer's (padded) outputs.
 *   3. Accumulate (+=) dW from the padded previous outputs and curr_delta,
 *      using the same stride mapping as step 1.
 *   4. If db is non-empty, accumulate (+=) the per-output-channel sum of
 *      curr_delta into db.
 *   5. For padding::same, copy the padded delta into ws.prev_delta_ via
 *      copy_and_unpad_delta (presumably dropping the padding border).
 *
 * @param curr_delta delta for this layer's outputs, laid out as described
 *                   by out_ (indexed with out_.get_index).
 * @param index      worker index; selects get_worker_storage(index) and
 *                   conv_layer_worker_storage_[index].
 * @return whatever prev_->back_propagation(ws.prev_delta_, index) returns.
 *
 * NOTE(review): dW and db are accumulated, never reset here; presumably the
 * caller clears them between updates -- confirm against the trainer.
 */
const vec_t& back_propagation(const vec_t& curr_delta, size_t index) override {
    auto& ws = this->get_worker_storage(index);
    conv_layer_worker_specific_storage& cws = conv_layer_worker_storage_[index];

    const vec_t& prev_out = *(cws.prev_out_padded_);
    const activation::function& prev_h = prev_->activation_function();
    // With padding::same the delta is built in the padded buffer and copied
    // into ws.prev_delta_ at the end; otherwise it is written there directly.
    vec_t* prev_delta = (pad_type_ == padding::same) ? &cws.prev_delta_padded_ : &ws.prev_delta_;
    vec_t& dW = ws.dW_;
    vec_t& db = ws.db_;

    std::fill(prev_delta->begin(), prev_delta->end(), float_t(0));

    // propagate delta to previous layer
    // Parallel over input channels: each task only writes channel `inc` of
    // *prev_delta, so the tasks do not write to the same elements.
    for_i(in_.depth_, [&](int inc) {
        for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {
            if (!tbl_.is_connected(outc, inc)) continue;
            const float_t *pw = &this->W_[weight_.get_index(0, 0, in_.depth_ * outc + inc)];
            const float_t *pdelta_src = &curr_delta[out_.get_index(0, 0, outc)];
            float_t *pdelta_dst = &(*prev_delta)[in_padded_.get_index(0, 0, inc)];
            for (cnn_size_t y = 0; y < out_.height_; y++) {
                for (cnn_size_t x = 0; x < out_.width_; x++) {
                    const float_t * ppw = pw;
                    const float_t ppdelta_src = pdelta_src[y * out_.width_ + x];
                    // top-left corner of the input window feeding output (x, y)
                    float_t * ppdelta_dst = pdelta_dst + y * h_stride_ * in_padded_.width_ + x * w_stride_;
                    for (cnn_size_t wy = 0; wy < weight_.height_; wy++) {
                        for (cnn_size_t wx = 0; wx < weight_.width_; wx++) {
                            ppdelta_dst[wy * in_padded_.width_ + wx] += *ppw++ * ppdelta_src;
                        }
                    }
                }
            }
        }
    });

    // scale by the derivative of the previous layer's activation
    for_i(parallelize_, in_padded_.size(), [&](int i) {
        (*prev_delta)[i] *= prev_h.df(prev_out[i]);
    });

    // accumulate dw
    // Output (x, y) reads the input window whose top-left corner is
    // (x * w_stride_, y * h_stride_), the same mapping used when propagating
    // the delta above. Both strides must be applied here as well, otherwise
    // dW is wrong for any stride > 1.
    for_i(in_.depth_, [&](int inc) {
        for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {
            if (!tbl_.is_connected(outc, inc)) continue;
            for (cnn_size_t wy = 0; wy < weight_.height_; wy++) {
                for (cnn_size_t wx = 0; wx < weight_.width_; wx++) {
                    float_t dst = float_t(0);
                    const float_t * prevo = &prev_out[in_padded_.get_index(wx, wy, inc)];
                    const float_t * delta = &curr_delta[out_.get_index(0, 0, outc)];
                    for (cnn_size_t y = 0; y < out_.height_; y++) {
                        const float_t * prevo_row = prevo + y * h_stride_ * in_padded_.width_;
                        const float_t * delta_row = delta + y * out_.width_;
                        if (w_stride_ == 1) {
                            // contiguous input row: use the vectorized dot product
                            dst += vectorize::dot(prevo_row, delta_row, out_.width_);
                        } else {
                            for (cnn_size_t x = 0; x < out_.width_; x++) {
                                dst += prevo_row[x * w_stride_] * delta_row[x];
                            }
                        }
                    }
                    dW[weight_.get_index(wx, wy, in_.depth_ * outc + inc)] += dst;
                }
            }
        }
    });

    // accumulate db (only when this layer has a bias vector)
    if (!db.empty()) {
        for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {
            const float_t *delta = &curr_delta[out_.get_index(0, 0, outc)];
            db[outc] += std::accumulate(delta, delta + out_.width_ * out_.height_, float_t(0));
        }
    }

    if (pad_type_ == padding::same)
        copy_and_unpad_delta(cws.prev_delta_padded_, ws.prev_delta_);

    CNN_LOG_VECTOR(curr_delta, "[pc]curr_delta");
    // Log the buffer actually handed to prev_ below. This function never
    // writes `prev_delta_[index]`, so logging that would not reflect this call.
    CNN_LOG_VECTOR(ws.prev_delta_, "[pc]prev_delta");
    CNN_LOG_VECTOR(dW, "[pc]dW");
    CNN_LOG_VECTOR(db, "[pc]db");

    return prev_->back_propagation(ws.prev_delta_, index);
}
// Blog tags: idt, fun, cpp, ack, type, reading (阅读), get, stride, auto
// Original article: http://www.cnblogs.com/zmlvyou/p/7821686.html