- VisualStudio2022插件的安装及使用-编程手把手系列文章
- pprof-在现网场景怎么用
- C#实现的下拉多选框,下拉多选树,多级节点
- 【学习笔记】基础数据结构:猫树
本篇是一次内部分享,给项目开发的同事分享什么是深度学习。用最简单的手写数字识别做例子,讲解了大概的原理.
展示首先数字识别项目的使用。项目实现过程
通俗解释 机器学习的关键内涵之一在于利用计算机的运算能力从大量的数据中发现一个规律,用这个规律实现预测或判断的功能.
以算法区分深度学习应用,算法类别可分成三大类:
卷积神经网络
递归神经网络
对抗神经网络
卷积神经网络(CNN)主要应用可分为图像分类、目标检测、语义分割 。
图片在计算机中以数字矩阵的形式存储。 https://h.markbuild.com/doc/binary-viewer-cn.html 。
图片的保存:
模型训练的思想:
损失函数:衡量训练结果和实际偏差的函数。数值越大代表差距越大 优化器:优化模型的算法,让损失函数减小的方法 。
Q&A 。
模型训练使用pytorch框架,同样可以实现的框架还由tensorflow、keras.
手写识别使用的是MNIST数据集,手写数字图片。MNIST数据集由像素是28 × 28 的0~9的手写数字图片组成,一共有7万张图片,其中6万张是训练集,1万张是测试集。每个图片是黑底白字的形式.
pytorch 中提供了torchvision 包,可以通过该包可以下载数据集 。
import torchvision
import matplotlib.pyplot as plt
# 训练数据集
train_data = torchvision.datasets.MNIST(
root="data", # 表示把MINST保存在data文件夹下
download=True, # 表示需要从网络上下载。下载过一次后,下一次就不会再重复下载了
train=True, # 表示这是训练数据集
transform=torchvision.transforms.ToTensor()
# 要把数据集中的数据转换为pytorch能够使用的Tensor类型
)
# 测试数据集
test_data = torchvision.datasets.MNIST(
root="data", # 表示把MINST保存在data文件夹下
download=True, # 表示需要从网络上下载。下载过一次后,下一次就不会再重复下载了
train=False, # 表示这是测试数据集
transform=torchvision.transforms.ToTensor()
# 要把数据集中的数据转换为pytorch能够使用的Tensor类型
)
演示 。
模型使用的是卷积神经网络模型。定义的神经网络模型如下:
import torch.nn as nn
# 定义卷积神经网络类
class RLS_CNN(nn.Module):
def __init__(self):
super(RLS_CNN, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, # 输入、输出通道数,输出通道数可以理解为提取了几种特征
kernel_size=(3, 3), # 卷积核尺寸
stride=(1, 1), # 卷积核每次移动多少个像素
padding=1), # 原图片边缘加几个空白像素
# 输入图片尺寸为 1×28×28
# 第一次卷积,尺寸为 16×28×28
nn.MaxPool2d(kernel_size=2), # 第一次池化,尺寸为 16×14×14
nn.Conv2d(16, 32, 3, 1, 1), # 第二次卷积,尺寸为 32×14×14
nn.MaxPool2d(2), # 第二次池化,尺寸为 32×7 ×7
nn.Flatten(), # 将三维数组变成一维数组
nn.Linear(32*7*7, 16), # 变成16个卷积核,每一个卷积核是1*1,最后输出16个数字
nn.ReLU(), # 激活函数 x<0 y=0 x>0 y=x,用在反向反向传导
nn.Linear(16, 10) # 将16变成10,预测0-9之间概率值
)
def forward(self, x):
return self.net(x)
卷积神经网络通常由3个部分构成:卷积层,池化层,全连接层。各部分的功能:
美颜相机的原理就是提取图片的特征,如下图片第二张模糊轮廓,第三张是突出轮廓.
卷积的功能:提取图片的多种特征信息 卷积的原理:用一个卷积核和图片的矩阵相乘,得到一个新的矩阵。新矩阵就是一个新的特征。 卷积核 卷积核也是一个矩阵,通常是33的矩阵,或者是55的矩阵。卷积运算的过程如下:
图像边缘提取 使用如下的卷积核就可以提取图像的边缘轮廓特征 。
调参: 卷积核矩阵由3*3一共9个参数组成,这些参数都是模型自动生成的,所谓的调参,其中一部分就是指调整卷积核矩阵的参数,让其提取的特征能够使预测更加准确 。
池化的功能:池化就是缩小矩阵的尺寸,从而减少后续操作的参数数量。通常会在相邻的卷积层之间加入一个池化层。 池化的原理:池化的运算过程:将一个44的矩阵最大池化成22的矩阵,就是取4*4矩阵中对应区域中最大的一个数值.
池化通常有两种:
全连接功能: 全连接的作用是组合特征和分类。 在前面两个步骤中从一张图片提取多种特征,并将特征矩阵进行了压缩。当数据到达全连接层时得到是一张图片的多种特征。 某一个特征并不能说整个图片是什么,否则就是盲人摸象。那么全连接层就是将多种特征组合起来形成一个完整的特征,并根据特征计算出图片是某一个类型的概率。 全连接层最终输出就是概率。比如手写数字识别,最终全连接层输出就是某一个手写数字在0~9上的概率.
tensor([[ 0.949, 3.032, 0.771, -2.173, -0.038, -0.236, 0.013, 0.614, -1.125, -2.6991]])
全连接的原理 全连接层实现的是特征组合,原理和卷积类似,也就是用一个卷积核对矩阵做运算,最后得到一个一维的数组,也就是0-9的概率.
调参:全连接的实现也需要卷积核的参与,所以卷积核矩阵也是参数的一部分,调参就包括该部分的参数.
手写数字识别的卷积神经网络,下面分析卷积+池化+全连接的过程:
Q&A 。
损失函数功能:衡量训练结果和实际偏差的函数。数值越大代表差距越大 优化器功能:让模型不断优化,让损失函数减小的方法 。
手写数字识别中使用的损失函数和优化器如下:
# 交叉熵损失函数,选择一种方法计算误差值
loss_func = torch.nn.CrossEntropyLoss()
# 优化器,随机梯度下降算法
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)
手写识别中选择了交叉熵损失函数,pytorch一共有19中损失函数可以使用,比较好理解的是平方差损失函数 。
手写识别中选了随机梯度下降算法,用来实现反向传播参数的修改。pytorch中一共有11中优化器可以使用.
模型训练的流程:
# 定义训练次数
cnt_epochs = 5 # 训练5个循环
# 循环训练
for cnt in range(cnt_epochs):
# 把训练集中的数据训练一遍
for imgs, labels in train_dataloader:
outputs = model(imgs) # 输出0~9预测的结果概率
loss = loss_func(outputs, labels) # 和输入做一个比较,得到一个误差
optimizer.zero_grad() # 初始化梯度,清空梯度。注意清空优化器的梯度,防止累计
loss.backward() # 方向传播计算
optimizer.step() # 累加1,执行一次
# 保存训练的结果(包括模型和参数)
torch.save(model, "my_cnn.nn")
需要注意的点:
Q&A
卷积
+ 池化
+ 全连接
+ 损失函数
+ 优化器
最后此篇关于Pytorch手写数字识别深度学习基础分享的文章就讲到这里了,如果你想了解更多关于Pytorch手写数字识别深度学习基础分享的内容请搜索CFSDN的文章或继续浏览相关文章,希望大家以后支持我的博客! 。
直接上代码,可以写在公共文件common和继承的基础类中,方便调用 ?
1、php服务端环境搭建 1.php 服务端环境 安装套件 xampp(apach+mysql+php解释器) f:\mydoc文件(重要)\dl_学习\download重要资源\apache
如下所示: Eclipse快捷键 Ctrl+1 快速修复 Ctrl+D: 删除当前行 Ctrl+Alt+↓ 复制当前行到下一行(复制增加) Ctrl+Alt+↑ 复制当前行到上一行(复制增加)
第一步:conn.PHP文件,用于连接数据库并定义接口格式,代码如下: php" id="highlighter_808731">
本篇文章整理了几道Linux下C语言的经典面试题,相信对大家更好的理解Linux下的C语言会有很大的帮助,欢迎大家探讨指正。 1、如果在Linux下使用GCC编译器执行下列程序,输出结果是什么?
安装完最新的Boost库 官方说明中有一句话: Finally, $ ./b2 install will leave Boost binaries in the lib/ subdirecto
为了梳理前面学习的《spring整合mybatis(maven+mysql)一》与《spring整合mybatis(maven+mysql)二》中的内容,准备做一个完整的示例完成一个简单的图书管理功
网站内容质量仅仅是页面综合得分里面的一项.不管算法如何改变调整,搜索引擎都不会丢弃网站页面的综合得分。 一般情况下我们把页面的综合得分为8个点: 1、标题的设置 (标题的设置要有独特性)
最近事情很忙,一个新项目赶着出来,但是很多功能都要重新做,一直在编写代码、debug。今天因为一个新程序要使用fragment来做,虽然以前也使用过fragment,不过没有仔细研究,今天顺道写篇文
Android资源命名规范 最近几个月,大量涉及android资源的相关工作。对于复杂的应用而言,资源命名的规范很有必要。除了开发人员之外,UI设计人员(或者切图相关人员)也需要对资源使用的位置非常
以前一直使用Hibernate,基本上没用过Mybatis,工作中需要做映射关系,简单的了解下Mybatis的映射。 两者相差不多都支持一对一,一对多,多对多,本章简单介绍一对一的使用以及注意点。
如下所示: ? 1
如果想在自定义的View上面显示Button 等View组件需要完成如下任务 1.在自定义View的类中覆盖父类的构造(注意是2个参数的) 复制代码 代码如下: publ
实现功能:实现表格tr拖动,并保存因为拖动改变的等级. jsp代码 ?
代码:测试类 java" id="highlighter_819000"> ?
红黑树是一种二叉平衡查找树,每个结点上有一个存储位来表示结点的颜色,可以是red或black。 红黑树具有以下性质: (1) 每个结点是红色或是黑色 (2) 根结点是黑色的 (3) 如果一个
废话不多说,直接上代码 ? 1
码代码时,有时候需要根据比较大小分别赋值: ? 1
实际项目开发中,我们经常会用一些版本控制器来托管自己的代码,今天就来总结下Git的相关用法,废话不多说,直接开写。 目的:通过Git管理github托管项目代码 1、下载安装Git 1、下载
直接上代码: 复制代码 代码如下: //验证码类 class ValidateCode { private $charset = 'abcdefghkmnprstuvwxyzABC
我是一名优秀的程序员,十分优秀!