Spark:更高效的聚合以连接来自不同行的字符串

时间:2015-12-19 20:57:59

标签: python apache-spark pyspark

我目前正在处理DNA序列数据,但我遇到了一些性能障碍。

我有两个查找词典/哈希(作为RDD)与DNA"单词" (短序列)作为键和索引位置列表作为值。一个用于较短的查询序列,另一个用于数据库序列。即使是非常非常大的序列,创建表也非常快。

下一步,我需要将它们配对并找到" hits" (每个常用词的索引位置对)。

我首先加入查找字典,速度相当快。但是,我现在需要这些对,所以我必须进行两次flatmap,一次是从查询中扩展索引列表,第二次是从数据库中展开索引列表。这不是理想的,但我没有看到另一种方法。至少它表现不错。

此时的输出为:(query_index, (word_length, diagonal_offset)),其中对角线偏移量为database_sequence_index减去查询序列索引。

然而,我现在需要找到具有相同对角线偏移量(db_index - query_index)的索引对,并合理地靠近并加入它们(因此我增加了单词的长度),但仅作为对(即一次我加入一个索引与另一个索引,我不想要任何其他东西与它合并)。

我使用名为Seed()的特殊对象使用aggregateByKey操作。

PARALELLISM = 16 # I have 4 cores with hyperthreading
def generateHsps(query_lookup_table_rdd, database_lookup_table_rdd):
    global broadcastSequences

    def mergeValueOp(seedlist, (query_index, seed_length)):
        return seedlist.addSeed((query_index, seed_length))

    def mergeSeedListsOp(seedlist1, seedlist2):
        return seedlist1.mergeSeedListIntoSelf(seedlist2)

    hits_rdd = (query_lookup_table_rdd.join(database_lookup_table_rdd)
                .flatMap(lambda (word, (query_indices, db_indices)): [(query_index, db_indices) for query_index in query_indices], preservesPartitioning=True)
                .flatMap(lambda (query_index, db_indices): [(db_index - query_index, (query_index, WORD_SIZE)) for db_index in db_indices], preservesPartitioning=True)
                .aggregateByKey(SeedList(), mergeValueOp, mergeSeedListsOp, PARALLELISM)
                .map(lambda (diagonal, seedlist): (diagonal, seedlist.mergedSeedList))
                .flatMap(lambda (diagonal, seedlist): [(query_index, seed_length, diagonal) for query_index, seed_length in seedlist])
                )

    return hits_rdd

种子():

class SeedList():
    def __init__(self):
        self.unmergedSeedList = []
        self.mergedSeedList = []


    #Try to find a more efficient way to do this
    def addSeed(self, (query_index1, seed_length1)):
        for i in range(0, len(self.unmergedSeedList)):
            (query_index2, seed_length2) = self.unmergedSeedList[i]
            #print "Checking ({0}, {1})".format(query_index2, seed_length2)
            if min(abs(query_index2 + seed_length2 - query_index1), abs(query_index1 + seed_length1 - query_index2)) <= WINDOW_SIZE:
                self.mergedSeedList.append((min(query_index1, query_index2), max(query_index1+seed_length1, query_index2+seed_length2)-min(query_index1, query_index2)))
                self.unmergedSeedList.pop(i)
                return self
        self.unmergedSeedList.append((query_index1, seed_length1))
        return self

    def mergeSeedListIntoSelf(self, seedlist2):
        print "merging seed"
        for (query_index2, seed_length2) in seedlist2.unmergedSeedList:
            wasmerged = False
            for i in range(0, len(self.unmergedSeedList)):
                (query_index1, seed_length1) = self.unmergedSeedList[i]
                if min(abs(query_index2 + seed_length2 - query_index1), abs(query_index1 + seed_length1 - query_index2)) <= WINDOW_SIZE:
                    self.mergedSeedList.append((min(query_index1, query_index2), max(query_index1+seed_length1, query_index2+seed_length2)-min(query_index1, query_index2)))
                    self.unmergedSeedList.pop(i)
                    wasmerged = True
                    break
            if not wasmerged:
                self.unmergedSeedList.append((query_index2, seed_length2))
        return self

即使是中等长度的序列,这也是性能真正崩溃的地方。

有没有更好的方法来进行这种聚合?我的直觉是肯定的,但我无法想出来。

我知道这是一个非常漫长的技术问题,即使没有简单的解决方案,我也非常感谢任何见解。

编辑:以下是我如何制作查找表:

def createLookupTable(sequence_rdd, sequence_name, word_length):
    global broadcastSequences
    blank_list = []

    def addItemToList(lst, val):
        lst.append(val)
        return lst

    def mergeLists(lst1, lst2):
        #print "Merging"
        return lst1+lst2

    return (sequence_rdd
            .flatMap(lambda seq_len: range(0, seq_len - word_length + 1))
            .repartition(PARALLELISM)
            #.partitionBy(PARALLELISM)
            .map(lambda index: (str(broadcastSequences.value[sequence_name][index:index + word_length]), index), preservesPartitioning=True)
            .aggregateByKey(blank_list, addItemToList, mergeLists, PARALLELISM))
            #.map(lambda (word, indices): (word, sorted(indices))))

这是运行整个操作的函数:

def run(query_seq, database_sequence, translate_query=False):
    global broadcastSequences
    scoring_matrix = 'nucleotide' if isinstance(query_seq.alphabet, DNAAlphabet) else 'blosum62'
    sequences = {'query': query_seq,
                 'database': database_sequence}

    broadcastSequences = sc.broadcast(sequences)
    query_rdd = sc.parallelize([len(query_seq)])
    query_rdd.cache()
    database_rdd = sc.parallelize([len(database_sequence)])
    database_rdd.cache()
    query_lookup_table_rdd = createLookupTable(query_rdd, 'query', WORD_SIZE)
    query_lookup_table_rdd.cache()
    database_lookup_table_rdd = createLookupTable(database_rdd, 'database', WORD_SIZE)
    seeds_rdd = generateHsps(query_lookup_table_rdd, database_lookup_table_rdd)
    return seeds_rdd

编辑2:我通过替换:

稍微调整了一些事情并略微提高了性能
                .flatMap(lambda (word, (query_indices, db_indices)): [(query_index, db_indices) for query_index in query_indices], preservesPartitioning=True)
                .flatMap(lambda (query_index, db_indices): [(db_index - query_index, (query_index, WORD_SIZE)) for db_index in db_indices], preservesPartitioning=True)

在hits_rdd中:

.flatMap(lambda (word, (query_indices, db_indices)): itertools.product(query_indices, db_indices))
                .map(lambda (query_index, db_index): (db_index - query_index, (query_index, WORD_SIZE) ))

至少现在我没有用中间数据结构烧掉存储空间。

1 个答案:

答案 0 :(得分:1)

让我们忘记你在做什么和思考的技术细节&#34;功能性&#34;关于所涉及的步骤,忘记了实施的细节。像这样的功能性思维是并行数据分析的重要组成部分;理想情况下,如果我们能够像这样解决问题,我们可以更清楚地说明所涉及的步骤,并最终更清晰,更简洁。根据表格数据模型进行思考,我认为您的问题包括以下步骤:

  1. 在序列列上加入两个数据集。
  2. 创建一个新列delta,其中包含索引之间的差异。
  3. 按(或)索引排序,以确保子序列的顺序正确。
  4. delta分组并连接序列列中的字符串,以获取数据集之间的完整匹配。
  5. 对于前3个步骤,我认为使用DataFrames是有意义的,因为这个数据模型对我们正在进行的那种处理有所了解。 (实际上我也可以在步骤4中使用DataFrames,除了pyspark当前不支持DataFrames的自定义聚合函数,尽管Scala确实如此)。

    第四步(如果我正确理解你在问题中真正提出的问题),考虑如何在功能上做这件事有点棘手,不过我认为这是一个优雅的有效的解决方案是使用减少(也称为右折);这个模式可以应用于你可以在迭代地应用关联二元函数方面表达的任何问题,这是一个&#34;分组&#34;任何3个参数都不重要(虽然顺序当然可能很重要),符号上,这是一个函数x,y -> f(x,y),其中f(x, f(y, z)) = f(f(x, y), z)。字符串(或更一般地说是列表)连接就是这样一种功能。

    以下是pyspark中这种情况的示例;希望您能够根据数据的细节进行调整:

    #setup some sample data
    query = [['abcd', 30] ,['adab', 34] ,['dbab',38]]
    reference = [['dbab', 20], ['ccdd', 24], ['abcd', 50], ['adab',54], ['dbab',58], ['dbab', 62]]
    
    #create data frames
    query_df = sqlContext.createDataFrame(query, schema = ['sequence1', 'index1'])
    reference_df = sqlContext.createDataFrame(reference, schema = ['sequence2', 'index2'])
    
    #step 1: join
    matches = query_df.join(reference_df, query_df.sequence1 == reference_df.sequence2)
    
    #step 2: calculate delta column
    matches_delta = matches.withColumn('delta', matches.index2 - matches.index1)
    
    #step 3: sort by index
    matches_sorted = matches_delta.sort('delta').sort('index2')
    
    #step 4: convert back to rdd and reduce
    #note that + is just string concatenation for strings
    r = matches_sorted['delta', 'sequence1'].rdd
    r.reduceByKey(lambda x, y : x + y).collect()
    
    #expected output:
    #[(24, u'dbab'), (-18, u'dbab'), (20, u'abcdadabdbab')]
    
相关问题