训练数据(包括训练集和验证集)大约有80
百万个样本,每个样本包含200
个密集浮点数。共有6
个标记类别,且类别分布不平衡。
在常用的机器学习库中(例如libsvm
、scikit-learn
、Spark MLlib
、random forest
、XGBoost
等),我应该使用哪一个?关于硬件配置,机器有24
个CPU核心和250
Gb的内存。
回答:
我建议使用scikit-learn的SGDClassifier,因为它支持在线学习,你可以将训练数据分成小批量(mini-batches)加载到内存中,并逐步训练分类器,这样就不需要将所有数据一次性加载到内存中。
它高度并行且易于使用。你可以将warm_start参数设置为True,并多次调用fit方法,每次加载一部分X, y数据到内存中,或者更好的选择是使用partial_fit方法。
clf = SGDClassifier(loss='hinge', alpha=1e-4, penalty='l2', l1_ratio=0.9, learning_rate='optimal', n_iter=10, shuffle=False, n_jobs=10, fit_intercept=True)# len(classes) = n_classesall_classes = np.array(set_of_all_classes)while True: #从磁盘加载一个小批量到内存中 X, y = load_next_chunk() clf.partial_fit(X, y, all_classes) X_test, y_test = load_test_data() y_pred = clf.predict(X_test)