最近邻文本分类

我有两个文本文件,(1) 包含不良词汇的样本(2) 包含良好词汇的样本。现在我想进行最近邻分类,将新发现的词汇分类为好或坏。我希望了解如何在现有代码的基础上进行处理。谢谢

import re


class Words_Works():
    """Tiny nearest-neighbour text classifier.

    Workflow: call add_category() once per labelled sample file, then
    add_text() for each file to classify, then knn_calc() followed by knn().
    State lives in instance attributes; each method reads the results of the
    previous pipeline step (wordify -> unstopify -> unleafify).
    """

    def __init__(self):
        self.all_texts = {}     # filename -> {word: count} for texts to classify
        self.categories = {}    # category name -> {word: count} for labelled samples
        self.knn_results = {}   # filename -> category -> {'Knn Distance', 'Knn Iterations'}
        # Common words removed from every text before counting.
        # BUG FIX: the original list had a missing comma that fused
        # 'though' and 'very' into the single token 'thoughvery'.
        self.stop_words = ['and', 'the', 'i', 'am', 'he', 'she', 'his',
                           'me', 'my', 'a', 'at', 'in', 'you', 'your',
                           'of', 'to', 'this', 'that', 'him', 'her',
                           'they', 'is', 'it', 'can', 'for', 'into',
                           'as', 'with', 'we', 'us', 'them', 'a',
                           'it', 'on', 'so', 'too', 'k', 'the',
                           'but', 'are', 'though',
                           'very', 'here', 'even', 'from',
                           'then', 'than']
        # Suffixes stripped by unleafify() — a very crude stemmer.
        self.leaf_words = ['s', 'es', 'ed', 'er', 'ly', 'ing']

    def add_category(self, f, cat_name):
        """Read sample file *f*, normalise it, and store its word counts
        in self.categories under *cat_name*."""
        with open(f) as f_in:   # 'with' guarantees the file is closed
            self.text = f_in.read().lower()
        self.wordify()
        self.unstopify()
        self.unleafify()
        counts = {}
        for item in self.unleaf:
            # dict.get replaces the Python-2-only has_key() idiom
            counts[item] = counts.get(item, 0) + 1
        self.categories[cat_name] = counts

    def load_categories(self):
        """Restore self.categories from the pickle db file, if readable."""
        import pickle  # local import; original Py2 code used cPickle
        try:
            with open('tweetCategory.txt', 'rb') as cat_db:
                self.categories = pickle.load(cat_db)
            print('File successfully loaded from categories db')
        # BUG FIX: was a bare 'except:' which silently hid every error,
        # including typos.  Catch only the failures loading can produce.
        except (OSError, pickle.PickleError):
            print('File not loaded from categories_db')

    @staticmethod
    def levenshtein_distance(first, second):
        """Find the Levenshtein (edit) distance between two strings.

        BUG FIXES vs. the original: declared @staticmethod so it is actually
        callable from the other methods; the empty-string early return and the
        matrix initialisation were nested inside the length-swap branch; the
        final min() assignment only executed when the characters differed.
        """
        if len(first) > len(second):
            first, second = second, first
        if len(second) == 0:
            return len(first)
        first_length = len(first) + 1
        second_length = len(second) + 1
        distance_matrix = [[0] * second_length for _ in range(first_length)]
        for i in range(first_length):
            distance_matrix[i][0] = i
        for j in range(second_length):
            distance_matrix[0][j] = j
        for i in range(1, first_length):
            for j in range(1, second_length):
                deletion = distance_matrix[i - 1][j] + 1
                insertion = distance_matrix[i][j - 1] + 1
                substitution = distance_matrix[i - 1][j - 1]
                if first[i - 1] != second[j - 1]:
                    substitution += 1
                distance_matrix[i][j] = min(insertion, deletion, substitution)
        return distance_matrix[first_length - 1][second_length - 1]

    def add_text(self, f):
        """Read file *f* to classify, normalise it, and store its word
        counts in self.all_texts keyed by the filename."""
        with open(f) as f_in:
            self.text = f_in.read().lower()
        self.wordify()
        self.unstopify()
        self.unleafify()
        counts = {}
        for item in self.unleaf:
            counts[item] = counts.get(item, 0) + 1
        self.all_texts[f] = counts

    def save_categories(self):
        """Pickle self.categories to the db file.

        BUG FIX: the original called cPickle.dump(cat_db, self.categories, -1)
        with the file and object arguments swapped — pickle.dump expects
        (obj, file, protocol)."""
        import pickle
        with open('tweetCategory.txt', 'wb') as cat_db:
            pickle.dump(self.categories, cat_db, -1)

    def unstopify(self):
        """Drop stop words from self.words; result goes in self.unstop."""
        self.unstop = [item for item in self.words
                       if item not in self.stop_words]

    def unleafify(self):
        """Strip the suffixes in self.leaf_words from every word of
        self.unstop; result goes in self.unleaf."""
        self.unleaf = self.unstop[:]
        for leaf in self.leaf_words:
            leaf_len = len(leaf)
            leaf_pattern = re.compile('%s$' % leaf)
            for i in range(len(self.unleaf)):
                if leaf_pattern.findall(self.unleaf[i]):
                    self.unleaf[i] = self.unleaf[i][:-leaf_len]

    def wordify(self):
        """Split self.text into lowercase word tokens in self.words.

        BUG FIX: the original pattern was '//w+' (two literal slashes),
        which never matches anything; the word pattern is r'\w+'."""
        words_pattern = re.compile(r'\w+')
        self.words = words_pattern.findall(self.text)

    def knn_calc(self):
        """Score every loaded text against every category.

        Distance = sum, over the text's words, of the smallest Levenshtein
        distance to any word of the category; 'Knn Iterations' counts the
        words compared.  BUG FIXES: the original indexed
        self.categories[text] (KeyError — text is a filename, not a
        category), measured the distance between the *filename* and the
        *category name* rather than between words, and never incremented
        iterations."""
        for text in self.all_texts:
            self.knn_results[text] = {}
            for category in self.categories:
                iterations = 0
                distance = 0
                cat_words = list(self.categories[category])
                for word in self.all_texts[text]:
                    if cat_words:
                        distance += min(self.levenshtein_distance(word, cw)
                                        for cw in cat_words)
                        iterations += 1
                self.knn_results[text][category] = {
                    'Knn Distance': distance,
                    'Knn Iterations': iterations,
                }

    def knn(self):
        """Print, for each text, the nearest (smallest-distance) category.

        BUG FIX: the original assigned 'Result = None' but tested lowercase
        'result' (NameError), never updated it, and printed inside the
        category loop before the minimum was known."""
        for text in self.all_texts:
            best_category = None
            best_distance = None
            best_iterations = None
            for category in self.categories:
                d = self.knn_results[text][category]['Knn Distance']
                if best_distance is None or d < best_distance:
                    best_category = category
                    best_distance = d
                    best_iterations = (
                        self.knn_results[text][category]['Knn Iterations'])
            print('File:', text)
            print('Knn:', best_category)
            print('Distance :', best_distance)
            print('Iterations :', best_iterations)
            print('End of nearest neighbour search')

测试用例如下:

# Driver script: build the classifier from labelled sample files, then
# classify the sample texts and print the nearest-neighbour results.
# (Reformatted from collapsed Python 2 source; print statements -> print().)
mywork = Words_Works()

# One category per labelled sample file.
positive = 'positive.txt'
mywork.add_category(positive, 'Positive Tweets')
negative = 'negative.txt'
mywork.add_category(negative, 'Negative Tweets')
neutral = 'neutral.txt'
mywork.add_category(neutral, 'Neutral Tweets')

# Print the categories.
for category in mywork.categories.keys():
    print(category)
    print(mywork.categories[category])
    print()
print()

# Files to classify.
txts = ('samplegood.txt', 'samplebad.txt')
for text in txts:
    mywork.add_text(text)

# Print the loaded texts.
for text in mywork.all_texts.keys():
    print(text)
    print(mywork.all_texts[text])
    print()
    print()

# Compute the knn distances.
mywork.knn_calc()

# Print the detailed results.
for files in mywork.knn_results.keys():
    print(files)
    for category in mywork.knn_results[files].keys():
        print(category)
        print(mywork.knn_results[files][category])
    print()
print()

# Show the final classification.
mywork.knn()

回答:

两条建议:首先,正如@[隐藏人名]所指出的,你应该使用编辑距离,也称为Levenshtein距离。你可以在python-Levenshtein包中找到它。

其次,使用标准库中的unittest或doctest库来测试你的代码。使用保存在外部文件中的示例来测试代码是个坏主意,因为没有访问这些文件的第三方(例如我们)无法知道输入是什么;同样,避免打印输出并手动检查,因为这很慢,容易出错,而且其他人无法审查。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注