
Introduction to PyTorch's register_hook and register_backward_hook, with an experiment

Posted: 2020-03-25 23:13:57


import torch
import torch.nn as nn


class Classifier(nn.Module):
    def __init__(self, in_size, in_ch):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_ch, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(3, 6, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(6, 3, 3, 1, 1),
            nn.ReLU(),
        )
        self.fc = nn.Linear(3 * in_size * in_size, 1)

    def forward(self, x):
        x = self.layer1(x)
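        identity = x  # saved for the residual connection
        x = self.layer2(x)
        # Out-of-place add: ReLU saves its output for backward, so an in-place
        # "+=" here would make autograd raise an error during backward.
        x = x + identity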
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


def print_grad(grad):
    # Tensor hook (Tensor.register_hook): receives dL/d(tensor) for the hooked tensor.
    print("========= register_hook output: =========")
    print(grad.size())
    print(grad)


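def grad_hook(md, grad_in, grad_out):
    # Module hook (Module.register_backward_hook, deprecated since PyTorch 1.8 in
    # favour of register_full_backward_hook): grad_out holds dL/d(module output).
    print("========= register_backward_hook output: =========")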
    print(grad_out[0].size())
    print(grad_out[0])


torch.random.manual_seed(1000)

if __name__ == "__main__":
    in_size, in_ch = 4, 1
    x = torch.randn(1, 1, 4, 4)
    model = Classifier(in_size, in_ch)
    y_hat = model(x)
    y_gt = torch.tensor([[1.5]])  # regression target
    crt = nn.MSELoss()
    print(y_hat)
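    print("=======================")
    # Re-run the forward pass one child module at a time so that a tensor hook
    # can be attached to every intermediate output.
    identity = []
    for idx, (name, md) in enumerate(model.named_children()):
        md.register_backward_hook(grad_hook)  # module backward hook
        if isinstance(md, nn.Linear):
            x = x + identity[0]  # residual connection, out-of-place as in forward()
            x = torch.flatten(x, 1)
        x = md(x)
        x.register_hook(print_grad)  # tensor hook on this module's output
        if idx == 0:
            identity.append(x)  # keep layer1's output for the residual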

    loss = crt(x, y_gt)
    loss.backward()

    print(x)
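The script above only prints the hooked gradients. As a smaller standalone sketch (the tensor `x` and the helpers `capture` and `double` are hypothetical, not taken from the original post), the following shows what `Tensor.register_hook` hands to a hook, and that a hook which returns a tensor replaces the gradient propagated further back. It assumes a reasonably recent PyTorch:

```python
import torch

captured = []


def capture(grad):
    # A tensor hook receives dL/d(tensor); returning None leaves it unchanged.
    captured.append(grad.clone())


def double(grad):
    # Returning a tensor replaces the gradient that continues to flow backward.
    return grad * 2


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Pass 1: observe dL/dy for loss = sum(y) with y = 3 * x.
y = x * 3
y.register_hook(capture)
y.sum().backward()
first_grad = x.grad.clone()
print(captured[0])  # dL/dy: tensor([1., 1., 1.])
print(first_grad)   # dL/dx: tensor([3., 3., 3.])

# Pass 2: same graph, but the hook doubles dL/dy before it reaches x.
x.grad = None
y = x * 3
y.register_hook(double)
y.sum().backward()
print(x.grad)       # dL/dx: tensor([6., 6., 6.])
```

A hook should build and return a new tensor as `double` does, rather than modify its `grad` argument in place.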

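In the script, each child module gets both a tensor hook on its output and a module backward hook, and both print a gradient of the same shape. A minimal sketch of that comparison on a single `nn.Linear`, assuming PyTorch 1.8 or newer so that the non-deprecated `Module.register_full_backward_hook` can stand in for `register_backward_hook` (the layer sizes, the squared-sum loss and the `records` dict are illustrative choices, not from the original):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(4, 2)
records = {}


def module_hook(module, grad_input, grad_output):
    # grad_output: gradients w.r.t. the module's outputs
    # grad_input: gradients w.r.t. the module's positional tensor inputs
    records["grad_out"] = grad_output[0].clone()
    records["grad_in"] = grad_input[0].clone()


def tensor_hook(grad):
    # Gradient w.r.t. the output tensor itself.
    records["tensor_grad"] = grad.clone()


fc.register_full_backward_hook(module_hook)

inp = torch.randn(3, 4, requires_grad=True)
out = fc(inp)
out.register_hook(tensor_hook)
loss = (out ** 2).sum()
loss.backward()

# Both hooks see dL/d(out), which for this loss is 2 * out.
print(torch.allclose(records["grad_out"], records["tensor_grad"]))  # True
print(torch.allclose(records["grad_out"], 2 * out.detach()))        # True
# grad_input[0] is dL/d(inp) = dL/d(out) @ W, the value also stored in inp.grad.
print(torch.allclose(records["grad_in"], inp.grad))                 # True
```

PyTorch's documentation for the older `register_backward_hook` warns that on modules performing several operations, such as the `nn.Sequential` blocks in the script, `grad_input` and `grad_output` may cover only a subset of the real inputs and outputs. The full variant, or a tensor hook placed directly on the tensor of interest, avoids that ambiguity.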
 

Original post: https://www.cnblogs.com/dxscode/p/12571148.html
