DNN:基于Keras对手写数字的识别

标签: 神经网络  tensorflow  深度学习  python  机器学习

Keras和TensorFlow的安装及常见故障处理

Keras和TensorFlow的安装

1、在anaconda prompt中输入conda create -n keras创建keras环境
2、输入conda activate keras**创建的环境
3、输入conda install tensorflow 安装TensorFlow,询问处输入y即可安装
4、创建的keras环境中若无Spyder,需要在keras环境中输入 conda install spyder进行安装

常见故障处理

1、若conda install 中出现PackagesNotFoundError,可采用pip install 进行替代,不过pip不支持断点续传,如果网络不好,可能出现下载到一半突然中断的现象,如下图所示:

在这里插入图片描述找个好点的网络即可解决,也可以使用国内的豆瓣镜像源进行下载。

2、若在Spyder中运行程序时出现Keras need TensorFlow 2.2 or higher,但是明明已经下了2.2或更高的版本,说明可能电脑里装了两个或以上TensorFlow,可以将所有的TensorFlow都卸载干净再装需要的版本。

3、若在Spyder中运行程序时出现AttributeError: module ‘tensorflow.python.framework.ops’ has no attribute ‘_TensorLike’,是Keras和TensorFlow的版本不匹配,可参照下面的链接选择Keras、TensorFlow和Python三者均匹配的版本。
List of Avaliable Environments

编程实现

导入需要的库和数据。

from keras.datasets import mnist
from keras import models
from keras.layers import Dense
from keras.utils import np_utils

(X_train,y_train),(X_test,y_test) = mnist.load_data()

由于每一个手写数字是一个二维的矩阵,在建立神经网络之前需要将其转化为一维的向量,并将其归一化。

num_pix = X_train.shape[1] * X_train.shape[2]
X_train = X_train.reshape(X_train.shape[0], num_pix)
X_train = X_train /255

X_test = X_test.reshape(X_test.shape[0], num_pix)
X_test = X_test /255

将训练目标集和测试目标集离散化。

y_train = np_utils.to_categorical(y_train)
y_test = np_utils.to_categorical(y_test)

搭建三层神经网络进行预测,共进行20轮计算。

net = models.Sequential()
net.add(Dense(input_dim = num_pix, output_dim = 500, activation = 'relu'))
net.add(Dense(output_dim = 500, activation = 'relu'))
net.add(Dense(output_dim = 10, activation = 'softmax'))

net.compile(optimizer = 'adam', loss = 'categorical_crossentropy', metrics = ['accuracy'])

net.fit(X_train, y_train, batch_size = 16, epochs = 20)

输出预测的损失和准确度。

score = net.evaluate(X_test, y_test, )
print('loss is:\t ', score[0])
print('accuracy is:\t ', score[1])

运行结果

在这里插入图片描述

版权声明:本文为weixin_42315235原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_42315235/article/details/109503654

智能推荐

前端小练习:jQuery酷炫照片墙

jQuery酷炫照片墙 效果展示: HTML代码: css代码: jQuery代码: 方法 解释 transform transform 属性向元素应用 2D 或 3D 转换。该属性允许我们对元素进行旋转、缩放、移动或倾斜。W3scool Math.random() 产生随机数。编程狮 translate 绘图函数编程狮 attr attr() 方法设置或返回被选元素的属性和值。编程狮 anima...

springMVC拦截器

一、     SpringMVC拦截器实现原理 用户请求到DispatherServlet中,DispatherServlet调用HandlerMapping查找Handler,HandlerMapping返回一个拦截器链(HandlerExecutionChain),springmvc中的拦截器是通过HandlerMapping发起的。 &nbs...

Unity Json反序列化

Json反序列化 结果:...

[机器学习-回归算法]Sklearn之线性回归实战

Sklearn之线性回归实战 一,前言 二,热身例子 三,贸易公司的简单例子 四,Sklearn 官网里的一个例子 参考资料 一,前言 一元线性回归的理论片请看我这个链接 二,热身例子 预测直线 y=1x1+2x2+3y = 1x_1 + 2x_2 +3y=1x1​+2x2​+3 导入LinearRegression 从Sklearn.liear_model 包里 拟合数据也可以说是训练 检验正确...

Android 开发者,你真的懂 Context 吗?

Android Context 详解 前言 一、Context是什么 二、Context结构 1、ContextImpl类介绍 2、ContextWrapper类介绍 3、ContextThemeWrapper 三、Context的数量 四、Context注意事项 五、如何正确回复以上面试题? 前言 Context 相信所有的 Android 开发人员基本上每天都在接触,因为它太常见了。但是这并不...

猜你喜欢

SpringMVC ----Json的简单交互处理

SpringMVC--Json Json的介绍 什么是JSON? JSON 和 JavaScript 对象互转 Controller返回JSON数据 Jackson 乱码 乱码的解决方法一 代码优化 乱码统一解决方法 返回json字符串统一解决 测试多个对象的集合输出 输出时间对象 抽取为工具类 FastJson fastjson 三个主要的类: JSONObject JSONArray JSON...

微信小程序自定义组件简单实现

本文将教你如何实现一个自定义的toast提示框,实现后的基本效果图如下: 小程序中一个自定义组件由 json wxml wxss js 4个文件组成的。下面我们一步一步地来创建文件及完成其中的配置: step1:创建自定义组件 首先创建一个components文件夹,用于放置所有自定义的组件,创建之后的目录结构为 其中的toastedit是我们本次...

PyTorch学习(四)--用PyTorch实现线性回归

教程视频:https://www.bilibili.com/video/BV1tE411s7QT 废话不多说,代码如下: 结果: 0 56.52023696899414 1 25.170454025268555 2 11.214292526245117 3 5.001270771026611 4 2.2352840900421143 5 1.0038176774978638 6 0.4554775...

1、Qt 的窗口组件和窗口类型

1、窗口组件 图形用户界面由不同的窗口和窗口组件组成 组件的类型 — 容器类(父组件):用于包含其它的界面组件 — 功能类(子组件):用于实现特定的交互功能 Qt 中没有父组件的顶级组件叫做窗口 QWidget QWidget 继承于 QObject 和 QPaintDevice — QObject 是所有支持 Qt 对象模型的基类 — QPaint...

从APP跳转到微信指定联系人聊天页面功能的实现与采坑之旅

从APP跳转到微信指定联系人聊天页面功能的实现与采坑之旅 起因 实现逻辑 效果图 实现过程 跳转微信按钮点击事件 无障碍监听主要方法 一些必要的参数 监听主要方法 遇到的坑 1. 搜索内容无法赋值给搜索框 2. 如何停止监听? 3. 没查询到结果如何停止监听? 4. 如果在微信其他页面怎么办? 5. 页面改变UI加载太慢 6. 聊天界面和主页面是同一个活动 7. 搜索不到结果时,发现他在搜索结果页...