生猪市场数据分析(五):迁移学习应用于咨询分类初探

标签: 数据分析  深度学习  神经网络

  • 0、项目背景
    上篇博客说到,农业领域标准的数据集较为稀少,只能寄希望与其他领域已经开放的标准数据集或者预训练模型。本项目借助于标注好的微博情绪文本构建text_CNN模型,进而结合少量标注的生猪市场新闻数据构建迁移学习模型。

  • 1、数据源
    新浪微博标注好的数据来自于
    https://github.com/murufeng/ChineseNlpCorpus,共计36 万多条,带情感标注 新浪微博,包含 4 种情感,其中喜悦约 20 万条,愤怒、厌恶、低落各约 5 万条。
    导入数据如下:
    在这里插入图片描述

  • 2、数据预处理
    数据预处理参见上一篇文章:生猪市场数据分析(四):基于无监督学习构建LSTM模型用于咨询分类。
    主要流程为:分词、引入词向量、构建词向量矩阵。

  • 3、构建text_CNN

  • 借助keras构建CNN模型:主要模块有Conv1D、MaxPool1D、Flatten、Dense;卷积层**函数使用relu,全连接层采用tanh(经本人测试文本分类使用tanh函数效果最佳,好于softmax以及relu等)。

model=Sequential()
model.add(Embedding(num_words,embedding_dim,weights=[embedding_matrix],input_length=max_tokens,trainable=False))
model.add(Conv1D(64,3,padding='same',strides=2,activation='relu'))
model.add(MaxPool1D(8))
model.add(Conv1D(128,4,padding='same',strides=1,activation='relu'))
model.add(MaxPool1D(4))
model.add(Conv1D(256,5,padding='same',strides=1,activation='relu'))
model.add(MaxPool1D(2))
model.add(Conv1D(512,6,padding='same',strides=1,activation='relu'))
model.add(MaxPool1D(1))
model.add(Flatten())
model.add(Dense(4,activation='tanh'))
#optimizer = Adam(lr=1e-3)
model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
model.summary()
  • text_CNN模型结构如下:

在这里插入图片描述

  • 训练及测试模型
#训练模型
history=model.fit(X_train,Y_train,validation_split=0.1,epochs=20,batch_size=516)
#测试模型
model.evaluate(X_test,Y_test)
#保存模型
model.save('text_cnn0425.h5')

在这里插入图片描述
模型综合表现尚可,正确率在80%以上(此处仅为探讨,可以继续调参优化模型)。

  • 5、使用LSTM以及上述构建的text_CNN模型预测生猪市场咨询;
plt.plot(history0.history['loss'])
plt.plot(history0.history['acc'])
plt.plot(history2.history['loss'])
plt.plot(history2.history['acc'])
plt.legend(['LSTM_loss','lstm_acc','text_CNN_loss','text_CNN_acc'])

结果如下(训练及测试过程省略)
在这里插入图片描述
可以看到text_CNN表现很差,LSTM模型表现尚可,主要原因是用于训练的数据集太小,只有1500+条数据。

  • 5、构建迁移学习模型

迁移学习模型主要思路为冻结训练好的text_CNN模型的卷积层训练好的参数,重新训练全连接分类层。

  • top模型层
top_model_ =Sequential()
top_model_.add(Dense(128,activation='tanh'))
top_model_.add(Dense(4,activation='tanh'))
top_model_.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
  • 冻结text_CNN卷积层
base_model = load_model('text_cnn0425.h5')
for layer in base_model.layers[:9]:
    layer.trainable = False
  • 构建新模型并训练
new_model = Model(inputs=base_model.input,outputs=top_model_(base_model.output))
new_model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
new_model.fit(sz_train,szt_train,validation_split=0.1,epochs=20,batch_size=128)

在这里插入图片描述
从模型的表现来看,提升较大,最高准确率可达85%,甚至比LSTM模型还要好一点。

  • 6、验证

  • 使用聚类分析将剩余6800+文本分类成4类,并用PCA降维可视化;

from sklearn.decomposition import PCA
model_pca = PCA(n_components=2)
model_pca.fit(sz_pad)
x=model_pca.transform(sz_pad)
x2 = []
x1=[]
for i in x:
    x1.append(i[0])
    x2.append(i[1])
# In[13]:
import matplotlib.pyplot as plt
import seaborn as sns
color = {0:'red',1:'blue',2:'green',3:'black'}
plt.scatter(x1,x2,color=[color[i] for i in doc_label.labels_],alpha=0.6)

在这里插入图片描述

  • 用迁移学习模型验证结果
new_model.evaluate(sz_test,szt_test)

得出平均loss:0.599,平均正确率83.26%;应该说是一个可用级别的模型。

版权声明:本文为weixin_43990004原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_43990004/article/details/105821530

智能推荐

26_Python基础_继承

面向对象三大特性: 封装 根据 职责 将 属性 和 方法 封装 到一个抽象的 类 中 继承 实现代码的重用, 相同的代码不需要重复的编写 多态 不同的对象调用相同的方法,  产生不同的执行结果,  增加代码的灵活度 1.  单继承 1.1 概念 继承的概念:&...

循环

与任何程序设计语言一样Java利用条件语句与循环结构确定流程控制,一下总结一下Java中的循环语句: while do while for switch 对于golang来说: switch非常灵活。从第一个expr为true的case开始执行,如果case带有fallthrough,程序会继续执行下一条case,不会再判断下一条case的expr,如果之后的case都有fallthrough,d...

1638 统计只差一个字符的子串数目(动态规划)

1. 问题描述: 给你两个字符串 s 和 t ,请你找出 s 中的非空子串的数目,这些子串满足替换一个不同字符以后,是 t 串的子串。换言之,请你找到 s 和 t 串中恰好只有一个字符不同的子字符串对的数目。比方说, "computer" 和 "computation"...

websocket基本原理

HTTP中一个request只能有一个response。而且这个response也是被动的,不能主动发起 因此过去的服务端推送信息是通过客户端不停的轮询实现的 websocket是双向通信协议,提供了服务端主动推送信息的能力 需要客户端(浏览器)和服务端同时支持 如果经过代理的话,还需要代理支持,否则有些代理在长时间无通信时会自动切断连接 因此WS为了保证连接不被断掉,会发心跳 WebSocket...

mybatis+ehcache二级缓存

导入jar包 mapper.xml文件开启二级缓存 pojo类实现序列化接口 配置ehcache.xml 测试...

猜你喜欢

python+opencv实现图像拼接

任务 拍摄两张图片去除相同部分,拼接在一起 原图 结果 步骤 读取两张图片 使用sift检测关键点及描述因子 匹配关键点 处理并保存关键点 得到变换矩阵 图像变换并拼接 代码实现 扩展 这里对右边图像进行变换,右边变得模糊,可以修改代码对左边图像变换 这里只有两张图片拼接,可以封装实现多张图片拼接 可以修改代码实现上下图片的拼接...

python_sklearn机器学习算法系列之AdaBoost------人脸识别(PCA,决策树)

          注:在读本文之前建议读一下之前的一片文章python_sklearn机器学习算法系列之PCA(主成分分析)------人脸识别(k-NearestNeighbor,KNN)         本文主要目的是通过一个简单的小...

memmove函数与memcpy函数的模拟实现

memmove函数和memcpy函数都是在内存复制任意类型的,但是它俩也有区别。当源区域和目标区域有重复的,memmove函数会复制缓冲区重叠的部分,而memcpy相反,会报出未知错误。 下面给出两个函数的实现 首先,memmove函数。 实现的基本原理如下图。 具体代码如下: memcpy函数的实现很简单,就直接给出源代码了...

SpringFramework核心 - IOC容器的实现 - 总结

1. 概述 把Spring技术内幕第一章和第二章过了一遍,也做了一些笔记, 对IOC容器的实现有了一定皮毛理解,现在跟着源码再过一遍总结一下IOC容器的初始化,Bean的初始化的过程,做一下总结 ① IOC容器和简单工厂模式 在开始之前,先想想我们平时是怎么使用IOC容器为我们管理Bean的,假设我们要把下面的User类交给IOC容器管理 我们不想关心如何创建一个User对象实例的,仅仅在需要他的...