我需要一些帮助来设计一个(高效的)Spark 中的马尔可夫链(通过 Python)。我尽我所能编写了代码,但我想出的代码无法扩展…基本上,对于不同的映射阶段,我编写了自定义函数,这些函数对于几千个序列的工作良好,但当我们达到 20,000+(我有一些高达 800k)的序列时,速度就会变得非常慢。
对于那些不熟悉马尔可夫模型的人来说,这是它的主要内容…
这是我的数据…此时我已经将实际数据(无标题)放入了一个 RDD 中。
ID, SEQ500, HNL, LNH, MLH, HML
我们以元组的形式查看序列,因此
(HNL, LNH), (LNH,MLH), 等等..
我需要达到这个点…我返回一个字典(对于每一行数据),然后我将它序列化并存储在一个内存数据库中。
{500: {HNLLNH : 0.333}, {LNHMLH : 0.333}, {MLHHML : 0.333}, {LNHHNL : 0.000}, 等等..}
本质上,每个序列与下一个序列结合(HNL,LNH 变成 ‘HNLLNH’),然后对于所有可能的转换(序列的组合),我们计算它们的出现次数,然后除以总转换次数(在这种情况下为3)并得到它们的出现频率。
上面有3个转换,其中一个是 HNLLNH…所以对于 HNLLNH,1/3 = 0.333
作为一个旁注,我不确定这是否相关,但序列中每个位置的值是有限的…第一个位置(H/M/L),第二个位置(M/L),第三个位置(H,M,L)。
我之前的代码所做的是收集 RDD,然后使用我编写的函数映射几次。这些函数首先将字符串转换为列表,然后将 list[1] 与 list[2] 合并,然后将 list[2] 与 list[3] 合并,然后将 list[3] 与 list[4] 合并,等等…所以我最终得到了一些像这样的东西…
[HNLLNH],[LNHMLH],[MHLHML], 等等..
然后下一个函数从该列表中创建一个字典,使用列表项作为键,然后计算该键在整个列表中的总出现次数,除以 len(list) 以获得频率。然后我将该字典包装在另一个字典中,连同它的 ID 号(结果是上面的第二个代码块)。
正如我所说,这对于小型序列工作得很好,但对于长度为 100k+ 的列表则不太好用。
另外,请记住,这只是一行数据。我必须对 10-20k 行数据执行此操作,每行数据的序列长度在 500-800,000 之间变化。
关于如何编写 pyspark 代码(使用 API map/reduce/agg 等函数)来高效地完成这项工作,有什么建议吗?
编辑代码如下…可能从底部开始更有意义。请记住,我在学习这个(Python 和 Spark),我不是以此为生的,所以我的编码标准不是很好…
def f(x): # 自定义 RDD 映射函数 # 将两个独立的事务合并成单一的转换状态 cust_id = x[0] trans = ','.join(x[1]) y = trans.split(",") s = '' for i in range(len(y)-1): s= s + str(y[i] + str(y[i+1]))+"," return str(cust_id+','+s[:-1])def g(x): # 自定义 RDD 映射函数 # 通过累加状态转换的发生次数来计算转换状态的概率 # 并除以总转换次数 cust_id=str(x.split(",")[0]) trans = x.split(",")[1:] temp_list=[] middle = int((len(trans[0])+1)/2) for i in trans: temp_list.append( (''.join(i)[:middle], ''.join(i)[middle:]) ) state_trans = {} for i in temp_list: state_trans[i] = temp_list.count(i)/(len(temp_list)) my_dict = {} my_dict[cust_id]=state_trans return my_dictdef gen_tsm_dict_spark(lines): # 接受格式为 CUST_ID(or)PROFILE_ID,SEQ,SEQ,SEQ.... 的 RDD/字符串输入 # 返回带有 CUST_ID 和每个客户的 tsm 的字典 RDD # 即 {cust_id : { ('NLN', 'LNN') : 0.33, ('HPN', 'NPN') : 0.66} # 创建一个元组 ([cust/profile_id], [SEQ,SEQ,SEQ]) cust_trans = lines.map(lambda s: (s.split(",")[0],s.split(",")[1:])) with_seq = cust_trans.map(f) full_tsm_dict = with_seq.map(g) return full_tsm_dictdef main():result = gen_tsm_spark(my_rdd)# 插入到数据库中for x in result.collect(): for k,v in x.iteritems(): db_insert(k,v)
回答:
你可以尝试以下内容。这在很大程度上依赖于 tooolz
,但如果你更喜欢避免外部依赖,你可以轻松地用一些标准的 Python 库替换它。
from __future__ import divisionfrom collections import Counterfrom itertools import productfrom toolz.curried import sliding_window, map, pipe, concatfrom toolz.dicttoolz import merge# 生成所有可能的转换 defaults = sc.broadcast(dict(map( lambda x: ("".join(concat(x)), 0.0), product(product("HNL", "NL", "HNL"), repeat=2))))rdd = sc.parallelize(["500, HNL, LNH, NLH, HNL", "600, HNN, NNN, NNN, HNN, LNH"])def process(line): """ >>> process("000, HHH, LLL, NNN") ('000', {'LLLNNN': 0.5, 'HHHLLL': 0.5}) """ bits = line.split(", ") transactions = bits[1:] n = len(transactions) - 1 frequencies = pipe( sliding_window(2, transactions), # 获取所有转换 map(lambda p: "".join(p)), # 连接字符串 Counter, # 计数 lambda cnt: {k: v / n for (k, v) in cnt.items()} # 获取频率 ) return bits[0], frequenciesdef store_partition(iter): for (k, v) in iter: db_insert(k, merge([defaults.value, v]))rdd.map(process).foreachPartition(store_partition)
由于你知道所有可能的转换,我建议使用稀疏表示并忽略零值。此外,你可以用稀疏向量替换字典以减少内存占用。