【Scikit-Learn】Naive Bayes Document Classification

Tags: Naive Bayes  Confusion Matrix  Model Summary

The dataset used in this article is 20news-18828 from mlcomp.org. It can be downloaded from: mlcomp.org/datasets/379

1. Loading the data

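The corpus is stored under datasets/mlcomp/379/train, which contains 20 subdirectories. Each subdirectory is named after a document category and holds all documents of that category.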

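The load_files() function reads every document under this directory into memory and automatically labels each document according to the subdirectory it lives in.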

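Attributes of the returned object used in the code below:
1. news_train.data: a list holding the contents of every document (as raw bytes, since no encoding is passed to load_files()).
2. news_train.target: an array holding the category index of every document.
3. news_train.target_names: the category names; the values in news_train.target index into this list.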

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

from time import time
from sklearn.datasets import load_files

print("loading train dataset ...")
t = time()
news_train = load_files('datasets/mlcomp/379/train')
print("summary: {0} documents in {1} categories.".format(
    len(news_train.data), len(news_train.target_names)))
print("done in {0} seconds".format(time() - t))
'''
loading train dataset ...
summary: 13180 documents in 20 categories.
done in 4.350337743759155 seconds
'''

news_train.target
'''
    array([18, 13,  1, ..., 14, 15,  4])
'''

news_train.target_names
'''
['alt.atheism',
 'comp.graphics',
 'comp.os.ms-windows.misc',
 'comp.sys.ibm.pc.hardware',
 'comp.sys.mac.hardware',
 'comp.windows.x',
 'misc.forsale',
 'rec.autos',
 'rec.motorcycles',
 'rec.sport.baseball',
 'rec.sport.hockey',
 'sci.crypt',
 'sci.electronics',
 'sci.med',
 'sci.space',
 'soc.religion.christian',
 'talk.politics.guns',
 'talk.politics.mideast',
 'talk.politics.misc',
 'talk.religion.misc']
'''

# Look up the category name of the first document
news_train.target_names[news_train.target[0]]
'''
    'talk.politics.misc'
'''
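To see what load_files() does with a directory layout like the one described above, here is a minimal sketch that builds a tiny hypothetical corpus in a temporary directory (the category folders and texts are made up for illustration) and loads it. Category names come from the sorted subdirectory names, target holds indices into target_names, and because no encoding is passed, data holds raw bytes. load_files() also shuffles the documents by default (shuffle=True, random_state=0), which is why news_train.target above is not grouped by category.

```python
import os
import tempfile
from sklearn.datasets import load_files

# Hypothetical toy corpus: one subdirectory per category, one file per document
toy = {
    "rec.autos": ["my car needs new brakes", "engine oil change"],
    "sci.space": ["the shuttle launch was delayed"],
}

with tempfile.TemporaryDirectory() as root:
    for category, docs in toy.items():
        os.makedirs(os.path.join(root, category))
        for i, text in enumerate(docs):
            path = os.path.join(root, category, "{0}.txt".format(i))
            with open(path, "w", encoding="latin-1") as fh:
                fh.write(text)
    # Labels come from the sorted subdirectory names; documents are shuffled by default
    bunch = load_files(root)

print(bunch.target_names)                  # ['rec.autos', 'sci.space']
for raw, label in zip(bunch.data, bunch.target):
    print(bunch.target_names[label], raw)  # raw is bytes because no encoding was given
```

The contents are read into memory, so bunch.data stays usable after the temporary directory is deleted; only bunch.filenames then points at files that no longer exist.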

2. Converting documents to TF-IDF vectors

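After fitting, the vocabulary contains 130,274 terms, so each document is represented as a 1×130274 vector.

The 13,180 documents in news_train.data are therefore vectorized into X_train, a 13180×130274 sparse matrix.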

from sklearn.feature_extraction.text import TfidfVectorizer

print("vectorizing train dataset ...")
t = time()
vectorizer = TfidfVectorizer(encoding='latin-1')
X_train = vectorizer.fit_transform((d for d in news_train.data))
print("n_samples: %d, n_features: %d" % X_train.shape)
print("number of non-zero features in sample [{0}]: {1}".format(
    news_train.filenames[0], X_train[0].getnnz()))  # number of non-zero entries in the TF-IDF vector of document 0
print("done in {0} seconds".format(time() - t))
'''
vectorizing train dataset ...
n_samples: 13180, n_features: 130274
number of non-zero features in sample [datasets/mlcomp/379/train\talk.politics.misc\17860-178992]: 108
done in 3.431797981262207 seconds
'''
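The weights produced by TfidfVectorizer can be reproduced by hand, which makes it clearer what each entry of X_train means. The sketch below uses a made-up three-sentence corpus (not the 20news data) and recomputes the matrix assuming scikit-learn's default settings: raw term counts, smoothed IDF idf(t) = ln((1 + n) / (1 + df(t))) + 1 where n is the number of documents and df(t) the number of documents containing t, and L2 normalization of each row.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

docs = ["the car engine", "the space shuttle", "the car and the shuttle"]  # toy corpus

vec = TfidfVectorizer()
X = vec.fit_transform(docs)        # sparse matrix, shape (n_documents, n_vocabulary_terms)
print(X.shape)                     # (3, 6)
print(X[0].getnnz())               # 3 distinct terms in the first document

# Recompute the weights by hand with the same vocabulary (same column order)
counts = CountVectorizer(vocabulary=vec.vocabulary_).fit_transform(docs).toarray()
n = counts.shape[0]
df = (counts > 0).sum(axis=0)                 # document frequency of each term
idf = np.log((1 + n) / (1 + df)) + 1          # smooth_idf=True
manual = counts * idf                         # raw term count times idf
manual = manual / np.linalg.norm(manual, axis=1, keepdims=True)  # norm='l2'

print(np.allclose(X.toarray(), manual))       # True
```

The same getnnz() call used above on X_train[0] counts the distinct vocabulary terms that occur in a document, since only those columns receive a non-zero weight.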

3. Training the Naive Bayes model

from sklearn.naive_bayes import MultinomialNB

print("training models ...")
t = time()
y_train = news_train.target
clf = MultinomialNB(alpha=0.0001)
clf.fit(X_train, y_train)
train_score = clf.score(X_train, y_train)
print("train score: {0}".format(train_score))
print("done in {0} seconds".format(time() - t))
'''
training models ...
train score: 0.9978755690440061
done in 0.2513284683227539 seconds
'''
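MultinomialNB scores each class c by log P(c) + Σ_w x_w · log P(w | c), where P(w | c) is estimated from the per-class feature totals with additive smoothing controlled by alpha (alpha=0.0001 above means very little smoothing). The sketch below recomputes this by hand on random toy count data (an assumption, not the news corpus) and checks the result against the fitted estimator's feature_log_prob_ and its predictions.

```python
import numpy as np
from sklearn.naive_bayes import MultinomialNB

rng = np.random.default_rng(0)
X = rng.integers(0, 5, size=(30, 8)).astype(float)  # toy non-negative "term counts"
y = np.arange(30) % 3                               # three balanced toy classes
alpha = 0.0001

clf = MultinomialNB(alpha=alpha).fit(X, y)

classes = np.unique(y)
# log P(c): empirical class frequencies (the default fit_prior=True)
log_prior = np.log(np.array([np.mean(y == c) for c in classes]))
# log P(w | c): per-class feature totals plus alpha, normalized over all features
smoothed = np.array([X[y == c].sum(axis=0) for c in classes]) + alpha
log_likelihood = np.log(smoothed / smoothed.sum(axis=1, keepdims=True))
# Joint log-likelihood of every sample under every class, then pick the best class
jll = X @ log_likelihood.T + log_prior
manual_pred = classes[np.argmax(jll, axis=1)]

print(np.allclose(clf.feature_log_prob_, log_likelihood))
print(np.array_equal(clf.predict(X), manual_pred))
```

The TF-IDF matrix used in the article plays the role of X here; MultinomialNB only requires the features to be non-negative.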

4. Testing the trained model

# Load the test data
print("loading test dataset ...")
t = time()
news_test = load_files('datasets/mlcomp/379/test')
print("summary: {0} documents in {1} categories.".format(
    len(news_test.data), len(news_test.target_names)))
print("done in {0} seconds".format(time() - t))
'''
loading test dataset ...
summary: 5648 documents in 20 categories.
done in 0.117918014526 seconds
'''

# Vectorize the test data with the vocabulary learned from the training set
print("vectorizing test dataset ...")
t = time()
X_test = vectorizer.transform((d for d in news_test.data))
y_test = news_test.target
print("n_samples: %d, n_features: %d" % X_test.shape)
print("number of non-zero features in sample [{0}]: {1}".format(
    news_test.filenames[0], X_test[0].getnnz()))
print("done in %fs" % (time() - t))
'''
vectorizing test dataset ...
n_samples: 5648, n_features: 130274
number of non-zero features in sample [datasets/mlcomp/379/test/rec.autos/7429-103268]: 61
done in 2.915759s
'''
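The test set is encoded with vectorizer.transform(), not fit_transform(): the vocabulary and IDF weights learned from the training set are reused, so X_test has the same 130274 columns as X_train, and words that never occurred in training are simply ignored. A minimal toy illustration (the documents are made up):

```python
from sklearn.feature_extraction.text import TfidfVectorizer

train_docs = ["the car engine", "the space shuttle"]
test_docs = ["the rocket engine"]  # "rocket" does not occur in train_docs

vec = TfidfVectorizer().fit(train_docs)   # learn vocabulary and IDF from training data only
X_test = vec.transform(test_docs)         # reuse them for unseen documents

print(X_test.shape)                 # one row, one column per training term
print("rocket" in vec.vocabulary_)  # the unseen word has no column
print(X_test.getnnz())              # only "the" and "engine" get non-zero weights
```

Calling fit_transform() on the test set instead would build a different vocabulary, and the resulting columns would no longer line up with the features the classifier was trained on.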

# Predict document 0 of the test set and print its predicted and actual categories
pred = clf.predict(X_test[0])
print("predict: {0} is in category {1}".format(
    news_test.filenames[0], news_test.target_names[pred[0]]))
print("actually: {0} is in category {1}".format(
    news_test.filenames[0], news_test.target_names[news_test.target[0]]))
'''
predict: datasets/mlcomp/379/test/rec.autos/7429-103268 is in category rec.autos
actually: datasets/mlcomp/379/test/rec.autos/7429-103268 is in category rec.autos
'''
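predict() returns only the most likely category. MultinomialNB also provides predict_proba(), which returns one probability per class; sorting it gives a ranked list of candidate categories for a document. The sketch below uses a made-up two-class corpus and hypothetical category names. Naive Bayes probabilities tend to be pushed close to 0 or 1, especially with a small alpha, so they are better read as a ranking than as calibrated confidences.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

train_docs = ["car engine brakes", "engine oil car", "space shuttle launch", "orbit shuttle space"]
train_labels = np.array([0, 0, 1, 1])
target_names = ["rec.autos", "sci.space"]   # hypothetical toy categories

vec = TfidfVectorizer()
clf = MultinomialNB(alpha=0.0001).fit(vec.fit_transform(train_docs), train_labels)

x = vec.transform(["new brakes for my car"])
proba = clf.predict_proba(x)[0]             # one probability per class, sums to 1
for idx in np.argsort(proba)[::-1]:         # most likely category first
    print("{0}: {1:.3f}".format(target_names[idx], proba[idx]))
```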

# Predict the whole test set
print("predicting test dataset ...")
t = time()
pred = clf.predict(X_test)
print("done in %fs" % (time() - t))
'''
predicting test dataset ...
done in 0.090978s
'''

# Summarize the Naive Bayes model's performance on the test set
from sklearn.metrics import classification_report

print("classification report on test set for classifier:")
print(clf)
print(classification_report(y_test, pred,
                            target_names=news_test.target_names))
'''
classification report on test set for classifier:
MultinomialNB(alpha=0.0001, class_prior=None, fit_prior=True)
                          precision    recall  f1-score   support

             alt.atheism       0.90      0.91      0.91       245
           comp.graphics       0.80      0.90      0.85       298
 comp.os.ms-windows.misc       0.82      0.79      0.80       292
comp.sys.ibm.pc.hardware       0.81      0.80      0.81       301
   comp.sys.mac.hardware       0.90      0.91      0.91       256
          comp.windows.x       0.88      0.88      0.88       297
            misc.forsale       0.87      0.81      0.84       290
               rec.autos       0.92      0.93      0.92       324
         rec.motorcycles       0.96      0.96      0.96       294
      rec.sport.baseball       0.97      0.94      0.96       315
        rec.sport.hockey       0.96      0.99      0.98       302
               sci.crypt       0.95      0.96      0.95       297
         sci.electronics       0.91      0.85      0.88       313
                 sci.med       0.96      0.96      0.96       277
               sci.space       0.94      0.97      0.96       305
  soc.religion.christian       0.93      0.96      0.94       293
      talk.politics.guns       0.91      0.96      0.93       246
   talk.politics.mideast       0.96      0.98      0.97       296
      talk.politics.misc       0.90      0.90      0.90       236
      talk.religion.misc       0.89      0.78      0.83       171

             avg / total       0.91      0.91      0.91      5648
'''
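Each row of the report can be derived from a confusion matrix (rows are true classes, columns are predicted classes): precision for class k is the diagonal count divided by the column sum (documents predicted as k), recall is the diagonal count divided by the row sum (documents truly in k, the support), and F1 is their harmonic mean. For example, in the confusion matrix printed in the next section, alt.atheism has 224 correct predictions out of 245 true documents (recall ≈ 0.91) and 250 documents predicted as alt.atheism (precision ≈ 0.90). The last row, avg / total, is the support-weighted average; recent scikit-learn versions print accuracy, macro avg and weighted avg rows instead. The sketch below checks these formulas on small toy labels against precision_recall_fscore_support.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])   # toy labels, not the 20news results
y_pred = np.array([0, 1, 0, 1, 1, 2, 0, 2, 1])

toy_cm = confusion_matrix(y_true, y_pred)   # rows: true class, columns: predicted class
tp = np.diag(toy_cm)
precision = tp / toy_cm.sum(axis=0)         # column sums: documents predicted as each class
recall = tp / toy_cm.sum(axis=1)            # row sums: documents truly in each class
f1 = 2 * precision * recall / (precision + recall)
support = toy_cm.sum(axis=1)
weighted_f1 = np.average(f1, weights=support)   # like the old "avg / total" row

p, r, f, s = precision_recall_fscore_support(y_true, y_pred)
print(np.allclose(precision, p), np.allclose(recall, r), np.allclose(f1, f))
print(weighted_f1)
```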

5. Plotting the confusion matrix on the test set

The figure below illustrates what a confusion matrix looks like:

[Figure: example of a confusion matrix]

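The code below visualizes the confusion matrix. Entry (i, j) counts the test documents whose true category is i and whose predicted category is j, so correct predictions lie on the diagonal. With the gray colormap, larger counts are drawn lighter; off the diagonal, a lighter cell therefore means more misclassifications between that pair of categories.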

from sklearn.metrics import confusion_matrix

cm = confusion_matrix(y_test, pred)
print("confusion matrix:")
print(cm)
'''
confusion matrix:
[[224   0   0   0   0   0   0   0   0   0   0   0   0   0   2   5   0   0   1  13]
 [  1 267   5   5   2   8   1   1   0   0   0   2   3   2   1   0   0   0   0   0]
 [  1  13 230  24   4  10   5   0   0   0   0   1   2   1   0   0   0   0   1   0]
 [  0   9  21 242   7   2  10   1   0   0   1   1   7   0   0   0   0   0   0   0]
 [  0   1   5   5 233   2   2   2   1   0   0   3   1   0   1   0   0   0   0   0]
 [  0  20   6   3   1 260   0   0   0   2   0   1   0   0   2   0   2   0   0   0]
 [  0   2   5  12   3   1 235  10   2   3   1   0   7   0   2   0   2   1   4   0]
 [  0   1   0   0   1   0   8 300   4   1   0   0   1   2   3   0   2   0   1   0]
 [  0   1   0   0   0   2   2   3 283   0   0   0   1   0   0   0   0   0   1   1]
 [  0   1   1   0   1   2   1   2   0 297   8   1   0   1   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   2   2 298   0   0   0   0   0   0   0   0   0]
 [  0   1   2   0   0   1   1   0   0   0   0 284   2   1   0   0   2   1   2   0]
 [  0  11   3   5   4   2   4   5   1   1   0   4 266   1   4   0   1   0   1   0]
 [  1   1   0   1   0   2   1   0   0   0   0   0   1 266   2   1   0   0   1   0]
 [  0   3   0   0   1   1   0   0   0   0   0   1   0   1 296   0   1   0   1   0]
 [  3   1   0   1   0   0   0   0   0   0   1   0   0   2   1 280   0   1   1   2]
 [  1   0   2   0   0   0   0   0   1   0   0   0   0   0   0   0 236   1   4   1]
 [  1   0   0   0   0   1   0   0   0   0   0   0   0   0   0   3   0 290   1   0]
 [  2   1   0   0   1   1   0   1   0   0   0   0   0   0   0   1  10   7 212   0]
 [ 16   0   0   0   0   0   0   0   0   0   0   0   0   0   0  12   4   1   4 134]]
'''
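Besides reading the plot, the largest off-diagonal cells can be listed programmatically. top_confusions below is a hypothetical helper (not a scikit-learn function) that zeroes the diagonal of a copy of the matrix and returns the k largest remaining entries. In the matrix printed above, the largest off-diagonal count is 24 (comp.os.ms-windows.misc documents predicted as comp.sys.ibm.pc.hardware), so calling top_confusions(cm, news_test.target_names) in the same session should list that pair first. The example here runs on a toy 3×3 matrix so that it is self-contained.

```python
import numpy as np

def top_confusions(cm, labels, k=3):
    """Hypothetical helper: the k largest off-diagonal cells as (true, predicted, count)."""
    off = np.array(cm, copy=True)
    np.fill_diagonal(off, 0)                          # ignore correct predictions
    flat = np.argsort(off, axis=None)[::-1][:k]       # flat indices, largest first
    rows, cols = np.unravel_index(flat, off.shape)
    return [(labels[i], labels[j], int(off[i, j])) for i, j in zip(rows, cols)]

toy_cm = np.array([[50, 3, 0],
                   [8, 40, 2],
                   [1, 0, 45]])
print(top_confusions(toy_cm, ["a", "b", "c"], k=2))
# [('b', 'a', 8), ('a', 'b', 3)]
```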

# Show confusion matrix
fig, ax = plt.subplots(figsize=(8, 8), dpi=144)
im = ax.matshow(cm, cmap='gray')       # draw into this axes; lighter = larger count
ax.set_title('Confusion matrix of the classifier')
for spine in ax.spines.values():       # hide the frame around the matrix
    spine.set_visible(False)
ax.set_xticks([])                      # hide ticks and tick labels; done after
ax.set_yticks([])                      # matshow, which sets its own ticks
fig.colorbar(im, ax=ax);

[Figure: confusion matrix of the classifier on the test set]
