第1关:朴素贝叶斯——新闻分类
任务描述
本关任务:使用sklearn完成新闻主题分类任务。
相关知识
为了完成本关任务,你需要掌握如何使用sklearn提供的MultinomialNB类与文本向量化。
数据简介
本关使用的是20newsgroups数据集,20newsgroups数据集是用于文本分类、文本挖据和信息检索研究的国际标准数据集之一。数据集收集了18846篇新闻组文档,均匀分为20个不同主题(比如电脑硬件、中东等主题)的新闻组集合。
MultinomialNB
MultinomialNB类中的fit函数实现了朴素贝叶斯分类算法训练模型的功能,predict函数实现了法模型预测的功能。
其中fit函数的参数如下:
X:大小为[样本数量,特征数量]的ndarry,存放训练样本
Y:值为整型,大小为[样本数量]的ndarray,存放训练样本的分类标签
而predict函数有一个向量输入:
X:大小为[样本数量,特征数量]的ndarry,存放预测样本
MultinomialNB的使用代码如下:
clf = tree.MultinomialNB()
clf.fit(X_train, Y_train)
result = clf.predict(X_test)
编程要求
填写news_predict(train_sample, train_label, test_sample)函数完成鸢尾花分类任务,其中:
train_sample:原始训练样本
train_label:训练标签
test_samp