在类方法中将实例变量作为参数传递?

时间:2018-07-13 21:45:30

标签: python

我正在尝试重构代码中非常重复的部分。

我有一个具有两个实例变量的类,这些实例变量得到更新:

class Alerter(object):

    'Sends email regarding information about unmapped positions and trades'

    def __init__(self, job):
        self.job = job
        self.unmappedPositions = None
        self.unmappedTrades = None

在我的代码经过某些方法之后,它会创建一个表并更新self.unmappedPositionsself.unmappedTrades

def load_positions(self, filename):
    unmapped_positions_table = etl.fromcsv(filename)
    if 'positions' in filename:
        return self.add_to_unmapped_positions(unmapped_positions_table)
    else:
        return self.add_to_unmapped_trades(unmapped_positions_table)

因此,我有两个基本上执行相同操作的函数:

def add_to_unmapped_trades(self, table):
    if self.unmappedTrades:
        Logger.info("Adding to unmapped")
        self.unmappedTrades = self.unmappedTrades.cat(
            table).cache()
    else:
        Logger.info("Making new unmapped")
        self.unmappedTrades = table
    Logger.info("Data added to unmapped")
    return self.unmappedTrades

并且:

def add_to_unmapped_positions(self, table):
    if self.unmappedPositions:
        Logger.info("Adding to unmapped")
        self.unmappedPositions = self.unmappedPositions.cat(
            table).cache()
    else:
        Logger.info("Making new unmapped")
        self.unmappedPositions = table
    Logger.info("Data added to unmapped")
    return self.unmappedPositions

我尝试将其设为一种方法,以便仅传递第三个参数,然后找出要更新的内容。第三个参数是初始化变量self.unmappedPositionsself.unmappedTrades。但是,这似乎不起作用。还有其他建议吗?

1 个答案:

答案 0 :(得分:0)

您似乎具有关键的洞察力,可以独立于任何特定的存储编写此函数:

def add_to_unmapped(unmapped, table):
    if unmapped:
        Logger.info("Adding to unmapped")
        unmapped = unmapped.cat(table).cache()
    else:
        Logger.info("Making new unmapped")
        unmapped = table
    Logger.info("Data added to unmapped")
    return unmapped

这实际上是一种很好的做法。例如,您可以为其编写单元测试,或者,如果有两个表(如您所愿),则只需为其编写一次实现。

如果您抽象地考虑两个add_to_unmapped_*函数的作用,它们将:

  1. 计算新表;
  2. 将新表保存在对象中;和
  3. 返回新表。

我们现在将步骤1分离出来,您可以重构包装器:

class Alerter:
    def add_to_unmapped_trades(self, table):
        self.unmappedTrades = add_to_unmapped(self.unmappedTrades, table)
        return self.unmappedTrades