当前位置 博文首页 > 浅谈pytorch中的nn.Sequential(*net[3: 5])是啥意思

    浅谈pytorch中的nn.Sequential(*net[3: 5])是啥意思

    作者:alittlebai1 时间:2021-06-19 18:45

    看到代码里面有这个

    在这里插入图片描述

    1 class ResNeXt101(nn.Module):
        2 def __init__(self):
            3 super(ResNeXt101, self).__init__()
            4 net = resnext101()
            # print(os.getcwd(), net)
    
            5 net = list(net.children())  # net.children()得到resneXt 的表层网络
            # for i, value in enumerate(net):
            #     print(i, value)
            6 self.layer0 = nn.Sequential(net[:3])  # 将前三层打包0, 1, 2两层
            print(self.layer0)
            7 self.layer1 = nn.Sequential(*net[3: 5])  # 将3, 4两层打包
            8 self.layer2 = net[5]
            9 self.layer3 = net[6]
    

    可以看到代码中的第六行(序号自己去掉,我打上去的) self.layer0 = nn.Sequential(net[:3])
    第七行self.layer1 = nn.Sequential(*net[3: 5])
    有一个nn.Sequential(net[:3])
    nn.Sequential(*net[3: 5])
    今天不讲nn.Sequential()用法,意义,作用因为我也不咋明白。惊天就说*net[3: 5]这个东西为啥要带“ * ”
    当代码中不带*的时候,运行会出现以下问题

    在这里插入图片描述

    意思就是列表不是子类,就是说参数不对

    net = list(net.children())

    这一行代码是将模型的每一层取出来构建一个列表,自己试着打印就可以。大概的输出就是[conv(),BatchNorm2d(), ReLU,MaxPool2d]等等

    在这里插入图片描述

    总共是是个元素,和一般的列表不太一样。

    当我们取net[:3]的时候,传进去的参数是一个列表,但是我们用*net[:3]的时候传进去的是单个元素

    list1 = ["conv", ("relu", "maxing"), ("relu", "maxing", 3), 3]
    list2 = [list1[:1]]
    list3 = [*list1[:1]]
    print("list2:{}, *list1[:2]:{}".format(list1[:1], *list1[:1]))
    

    在这里插入图片描述

    结果不带✳的是列表,带✳的是元素,所以nn.Sequential(*net[3: 5])中的*net[3: 5]就是给nn.Sequential()这个容器中传入多个层。

    js
    下一篇:没有了