方法python类

时间:2019-03-05 17:51:02

标签: python-3.x class methods

class NN(object):

    def __init__(...):
        [...] #some intialization of the class

    #define a recursive function to return  a vector which has atleast one non-zero element
    @staticmethod
    def generate_random_nodes(dropout_prob, size):
        temp = np.random.binomial(1, dropout_prob, size)
        return temp if not sum(temp) else generate_random_nodes(dropout_prob, size)

    def compute_dropout(self, activations, dropout_prob = 0.5):
        [...]
        mult = np.copy(activations)          
        temp = generate_random_nodes(dropout_prob, size = activations.shape[0])
        mult[:,i] = temp            
        activations*=mult
        return activations

    def fit(self, ...):
        compute_dropout(...)
  

我想在我的类中创建一个函数,该函数由   类方法。此函数是递归的,旨在返回一个   仅当向量具有至少一个非零的向量时才为0和1s   元素

     

我得到的错误是“ Nameerror:名称'generate_random_nodes'是   未定义

1 个答案:

答案 0 :(得分:2)

在类中定义的任何内容都必须使用限定名称引用,可以直接在类上或其实例上查询。因此,这里最简单的解决方法是显式调用NN.generate_random_nodes进行递归调用,并在对其的初始调用中self.generate_random_nodes(仅显示具有更改的方法):

@staticmethod
def generate_random_nodes(dropout_prob, size):
    temp = np.random.binomial(1, dropout_prob, size)
    # Must explicitly qualify recursive call
    return temp if not sum(temp) else NN.generate_random_nodes(dropout_prob, size)

def compute_dropout(self, activations, dropout_prob = 0.5):
    [...]
    mult = np.copy(activations)          
    # Can call static on self just fine, and avoids hard-coding class name
    temp = self.generate_random_nodes(dropout_prob, size=activations.shape[0])
    mult[:,i] = temp            
    activations*=mult
    return activations

请注意,作为Python 3.x上的CPython实现细节,在类中定义的方法内引用__class__会创建一个闭合作用域,该作用域可让您访问其定义的类,从而避免重复自己通过显式指定类,因此generate_random_nodes可以是:

@staticmethod
def generate_random_nodes(dropout_prob, size):
    temp = np.random.binomial(1, dropout_prob, size)
    # Must qualify recursive call
    return temp if not sum(temp) else __class__.generate_random_nodes(dropout_prob, size)

有两个优点:

  1. __class__的嵌套范围查找比NN的全局范围查找快一点,并且
  2. 如果您的NN类的名称在开发过程中发生了更改,则根本不需要更改generate_random_nodes(因为它隐式地获取了对其定义的类的引用)。
  3. li>

您也可以(不依赖CPython实现细节)将其更改为classmethod以获得相同的基本好处:

@classmethod
def generate_random_nodes(cls, dropout_prob, size):
    temp = np.random.binomial(1, dropout_prob, size)
    # Must qualify recursive call
    return temp if not sum(temp) else cls.generate_random_nodes(dropout_prob, size)

因为classmethod收到对其被调用的类的引用(如果在实例上被调用,则被调用的实例的类)。这是对classmethod的轻微滥用(classmethod的唯一用途是用于类层次结构中的备用构造函数,在该类层次结构中,需要能够使用备用构造函数构造子类而不在子类中重载子类);这是完全合法的,只是有点不合常规。

如以下注释中所述:

  1. Python不擅长递归
  2. 您的递归条件是向后的(只有当其中的tempsum时返回0,这意味着temp是全零的数组)递归的机会,并且对于足够高的dropout_prob / size参数,几乎可以确定递归错误。

因此,您想将temp if not sum(temp) else <recursive call>更改为temp if sum(temp) else <recursive call>,或者由于它是numpy数组temp if temp.any() else <recursive call>而获得更好的性能/明显性。尽管这样做可能会使递归错误的几率很小,但如果您要格外小心,只需更改为基于while循环的方法,这样就可以避免无限期递归:

@staticmethod
def generate_random_nodes(dropout_prob, size):
    while True:
        temp = np.random.binomial(1, dropout_prob, size)
        if temp.any():
            return temp
相关问题