我想在某种条件下生成所有可能的组合。鉴于我有一个包含所需条件的数据框。
Variable Cluster_no sub_group
GDP_M3 1 GDP
HPI_M3_lg2 1 HPI
FDI_C_lg5 1 FDI
FDI_M6 2 FDI
Export_M9 2 Export
GDP_M9 2 GDP
GDP_M12_lg7 3 GDP
Export_M12 3 Export
我发现itertools.combinations
给了我3的所有可能组合。但是,我想考虑使用cluster_no
和sub_group
的更多条件。
已经说过,如果我以GDP_M3
中的cluster 1
开头,它将不会与HPI_M3_lg2
或FDI_C_lg5
匹配,因为它来自同一集群。给定群集条件,它将查找其他群集,这些群集是cluster 2
或cluster 3
。
在cluster 2
中,有2个可能的变量,分别是FDI_M6
或Export_M9
,因为我也想考虑sub_group
的条件。如果选择FDI_M6
,它将移至下一个群集,因为再次允许每个群集中只有1个变量。
现在,我的名单是[GDP_M3, FDI_M6]
。组合的下一个变量是Export_M12
,因为它来自cluster 3
和sub_group Export
。
我想设置3种可能的组合(1到3)。对此有任何建议。
谢谢。
编辑以添加我的代码。
N=3
combination=[]
for i in range(1, N+1):
for j in itertools.combinations(a, i):
combination.append(list(j))
答案 0 :(得分:1)
我不认为内置的组合方法可以处理这种情况。您必须编写自己的回溯组合算法。我已经尝试实现一个:
all_possible_combinations = []
def get_combinations(N, data, cur_index=0, generated_el=[], cluster_tracker=set(), sub_group_tracker=set()):
if N == 0:
if generated_el:
all_possible_combinations.append(tuple(generated_el))
return
if cur_index >= len(data):
return
get_combinations(N, data, cur_index+1, generated_el, cluster_tracker, sub_group_tracker)
if data[cur_index][1] in cluster_tracker:
# I have already taken this cluster
return
if data[cur_index][2] in sub_group_tracker:
# I have already taken this sub group
return
generated_el.append(data[cur_index][0])
cluster_tracker.add(data[cur_index][1])
sub_group_tracker.add(data[cur_index][2])
get_combinations(N-1, data, cur_index+1, generated_el, cluster_tracker, sub_group_tracker)
generated_el.pop()
cluster_tracker.remove(data[cur_index][1])
sub_group_tracker.remove(data[cur_index][2])
return
if __name__ == "__main__":
data = [
("GDP_M3", "1", "GDP"),
("HPI_M3_lg2", "1", "HPI" ),
("FDI_C_lg5", "1", "FDI"),
("FDI_M6", "2", "FDI"),
("Export_M9", "2", "Export"),
("GDP_M9", "2", "GDP"),
("GDP_M12_lg7", "3", "GDP"),
("Export_M12", "3", "Export")
]
get_combinations(3, data)
print(all_possible_combinations)
您可以在此处查看输出:https://ideone.com/HwruJ7
答案 1 :(得分:0)
我的方法类似于@Ahmad Faiyaz
from collections import defaultdict
x= [[1,1,'gdp'],[2,1,'hpi'],[3,1,'fdi'],[4,2,'fdi'],[5,2,'export'],[6,2,'gdp'],[7,3,'gdp'],[8,3,'export']]
c=defaultdict(list)
for i in x:
c[i[1]]+=[i]
def rec_cal(i,clus,lis):
if i in c.keys():
for j in c[i]:
if j[2] not in clus:
clus.append(j[2])
lis.append(j[0])
rec_cal(i+1,clus,lis)
clus.pop()
lis.pop()
else:
continue
else:
print(lis)
rec_cal(1,[],[])
您将得到的输出为
[1, 4, 8]
[2, 4, 7]
[2, 4, 8]
[2, 5, 7]
[2, 6, 8]
[3, 5, 7]
[3, 6, 8]
此方法首先借助词典构建聚类集合,然后在考虑子组的情况下递归遍历这些聚类以创建最终输出。现在,我只是打印它,但是您可以轻松捕获它