当前位置 博文首页 > pytorch 中forward 的用法与解释说明

    pytorch 中forward 的用法与解释说明

    作者:JY丫丫 时间:2021-07-20 18:58

    前言

    最近在使用pytorch的时候,模型训练时,不需要使用forward,只要在实例化一个对象中传入对应的参数就可以自动调用 forward 函数

    即:

    forward 的使用

    class Module(nn.Module):
     def __init__(self):
      super(Module, self).__init__()
      # ......
      
     def forward(self, x):
      # ......
      return x
    data = ..... #输入数据
    # 实例化一个对象
    module = Module()
    # 前向传播
    module(data) 
    # 而不是使用下面的
    # module.forward(data) 
    

    实际上

    module(data) 

    是等价于

    module.forward(data) 

    forward 使用的解释

    等价的原因是因为 python calss 中的__call__和__init__方法.

    class A():
     def __call__(self):
      print('i can be called like a function')
     
    a = A()
    a()

    out:

    i can be called like a function

    __call__里调用其他的函数

    class A():
     def __call__(self, param):
      
      print('i can called like a function')
      print('传入参数的类型是:{} 值为: {}'.format(type(param), param))
     
      res = self.forward(param)
      return res
     
     def forward(self, input_):
      print('forward 函数被调用了')
     
      print('in forward, 传入参数类型是:{} 值为: {}'.format( type(input_), input_))
      return input_ 
    a = A() 
    input_param = a('i')
    print("对象a传入的参数是:", input_param)
    

    out:

    i can called like a function

    传入参数的类型是:<class ‘str'> 值为: i

    forward 函数被调用了

    in forward, 传入参数类型是:<class ‘str'> 值为: i

    对象a传入的参数是: i

    补充:Pytorch 模型中nn.Model 中的forward() 前向传播不调用 解释

    在pytorch 中没有调用模型的forward()前向传播,只实列化后把参数传入。

    定义模型

    class Module(nn.Module):
     def __init__(self):
      super(Module, self).__init__()
      # ......
     
     def forward(self, x):
      # ......
      return x
    data = ..... #输入数据
    # 实例化一个对象
    module = Module()
    # 前向传播 直接把输入传入实列化
    module(data) 
    #没有使用module.forward(data) 
    

    实际上module(data) 等价于module.forward(data)

    等价的原因是因为 python calss 中的__call__ 可以让类像函数一样调用

    当执行model(x)的时候,底层自动调用forward方法计算结果

    class A():
     def __call__(self):
      print('i can be called like a function')
     
    a = A()
    a()
    >>>i can be called like a function

    在__call__ 里可调用其它的函数

    class A():
     def __call__(self, param):
      
      print('我在__call__中,传入参数',param)
     
      res = self.forward(param)
      return res
     
     def forward(self, x):
      print('我在forward函数中,传入参数类型是值为: ',x)
      return x
     
    a = A()
    y = a('i')
     >>> 我在__call__中,传入参数 i
     >>>我在forward函数中,传入参数类型是值为: i
    print("传入的参数是:", y)
     >>>传入的参数是: i

    以上为个人经验,希望能给大家一个参考,也希望大家多多支持站长博客。如有错误或未考虑完全的地方,望不吝赐教。

    jsjbwy
    下一篇:没有了