pytorch 图像中的数据预处理和批标准化实例

目前数据预处理最常见的方法就是中心化和标准化。

中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。

标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间

批标准化:BN

在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0,1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。

所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。

batch normalization 的实现非常简单,接下来写一下对应的python代码:

import sys
sys.path.append('..')

import torch

def simple_batch_norm_1d(x,gamma,beta):
  eps = 1e-5
  x_mean = torch.mean(x,dim=0,keepdim=True) # 保留维度进行 broadcast
  x_var = torch.mean((x - x_mean) ** 2,keepdim=True)
  x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

x = torch.arange(15).view(5,3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x,beta)
print('after bn: ')
print(y)

测试的时候该使用批标准化吗?

答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替

下面我们实现以下能够区分训练状态和测试状态的批标准化方法

def batch_norm_1d(x,beta,is_training,moving_mean,moving_var,moving_momentum=0.1):
  eps = 1e-5
  x_mean = torch.mean(x,keepdim=True)
  if is_training:
    x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
    moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
    moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
  else:
    x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
  return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)

下面我们在卷积网络下试用一下批标准化看看效果

def data_tf(x):
  x = np.array(x,dtype='float32') / 255
  x = (x - 0.5) / 0.5 # 数据预处理,标准化
  x = torch.from_numpy(x)
  x = x.unsqueeze(0)
  return x

train_set = mnist.MNIST('./data',train=True,transform=data_tf,download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data',train=False,download=True)
train_data = DataLoader(train_set,batch_size=64,shuffle=True)
test_data = DataLoader(test_set,batch_size=128,shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
  def __init__(self):
    super(conv_bn_net,self).__init__()
    self.stage1 = nn.Sequential(
      nn.Conv2d(1,6,3,padding=1),nn.BatchNorm2d(6),nn.ReLU(True),nn.MaxPool2d(2,2),nn.Conv2d(6,16,5),nn.BatchNorm2d(16),2)
    )

    self.classfy = nn.Linear(400,10)
  def forward(self,x):
    x = self.stage1(x)
    x = x.view(x.shape[0],-1)
    x = self.classfy(x)
    return x

net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(),1e-1) # 使用随机梯度下降,学习率 0.1

train(net,train_data,test_data,5,optimizer,criterion)

以上这篇pytorch 图像中的数据预处理和批标准化实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


使用OpenCV实现视频去抖 整体步骤: 设置输入输出视频 寻找帧之间的移动:使用opencv的特征检测器,检测前一帧的特征,并使用Lucas-Kanade光流算法在下一帧跟踪这些特征,根据两组点,将前一个坐标系映射到当前坐标系完成刚性(欧几里得)变换,最后使用数组纪录帧之间的运动。 计算帧之间的平
前言 对中文标题使用余弦相似度算法和编辑距离相似度分析进行相似度分析。 准备数据集part1 本次使用的数据集来源于前几年的硕士学位论文,可根据实际需要更换。结构如下所示: 学位论文题名 基于卷积神经网络的人脸识别研究 P2P流媒体视频点播系统设计和研究 校园网安全体系的设计与实现 无线传感器网络中
前言 之前尝试写过一个爬虫,那时对网页请求还不够熟练,用的原理是:爬取整个html文件,然后根据标签页筛选有效信息。 现在看来这种方式无疑是吃力不讨好,因此现在重新写了一个爬取天气的程序。 准备工作 网上能轻松找到的是 101010100 北京这种编号,而查看中国气象局URL,他们使用的是北京545
前言 本文使用Python实现了PCA算法,并使用ORL人脸数据集进行了测试并输出特征脸,简单实现了人脸识别的功能。 1. 准备 ORL人脸数据集共包含40个不同人的400张图像,是在1992年4月至1994年4月期间由英国剑桥的Olivetti研究实验室创建。此数据集包含40个类,每个类含10张图
前言 使用opencv对图像进行操作,要求:(1)定位银行票据的四条边,然后旋正。(2)根据版面分析,分割出小写金额区域。 图像校正 首先是对图像的校正 读取图片 对图片二值化 进行边缘检测 对边缘的进行霍夫曼变换 将变换结果从极坐标空间投影到笛卡尔坐标得到倾斜角 根据倾斜角对主体校正 import
天气预报API 功能 从中国天气网抓取数据返回1-7天的天气数据,包括: 日期 天气 温度 风力 风向 def get_weather(city): 入参: 城市名,type为字符串,如西安、北京,因为数据引用中国气象网,因此只支持中国城市 返回: 1、列表,包括1-7的天气数据,每一天的分别为一个
数据来源:House Prices - Advanced Regression Techniques 参考文献: Comprehensive data exploration with Python 1. 导入数据 import pandas as pd import warnings warnin
同步和异步 同步和异步是指程序的执行方式。在同步执行中,程序会按顺序一个接一个地执行任务,直到当前任务完成。而在异步执行中,程序会在等待当前任务完成的同时,执行其他任务。 同步执行意味着程序会阻塞,等待任务完成,而异步执行则意味着程序不会阻塞,可以同时执行多个任务。 同步和异步的选择取决于你的程序需
实现代码 import time import pydirectinput import keyboard if __name__ == '__main__': revolve = False while True: time.sleep(0.1) if keyboard.is_pr
本文从多个角度分析了vi编辑器保存退出命令。我们介绍了保存和退出vi编辑器的命令,以及如何撤销更改、移动光标、查找和替换文本等实用命令。希望这些技巧能帮助你更好地使用vi编辑器。
Python中的回车和换行是计算机中文本处理中的两个重要概念,它们在代码编写中扮演着非常重要的角色。本文从多个角度分析了Python中的回车和换行,包括回车和换行的概念、使用方法、使用场景和注意事项。通过本文的介绍,读者可以更好地理解和掌握Python中的回车和换行,从而编写出更加高效和规范的Python代码。
SQL Server启动不了错误1067是一种比较常见的故障,主要原因是数据库服务启动失败、权限不足和数据库文件损坏等。要解决这个问题,我们需要检查服务日志、重启服务器、检查文件权限和恢复数据库文件等。在日常的数据库运维工作中,我们应该时刻关注数据库的运行状况,及时发现并解决问题,以确保数据库的正常运行。
信息模块是一种可重复使用的、可编程的、可扩展的、可维护的、可测试的、可重构的软件组件。信息模块的端接需要从接口设计、数据格式、消息传递、函数调用等方面进行考虑。信息模块的端接需要满足高内聚、低耦合的原则,以保证系统的可扩展性和可维护性。
本文从电脑配置、PyCharm版本、Java版本、配置文件以及程序冲突等多个角度分析了Win10启动不了PyCharm的可能原因,并提供了解决方法。
本文主要从多个角度分析了安装SQL Server 2012时可能出现的错误,并提供了解决方法。
Pycharm是一款非常优秀的Python集成开发环境,它可以让Python开发者更加高效地进行代码编写、调试和测试。在Pycharm中设置解释器非常简单,我们可以通过创建新项目、修改项目解释器、设置全局解释器等多种方式进行设置。
Python中有多种方法可以将字符串转换为整数,包括使用int()函数、try-except语句、正则表达式、map()函数、ord()函数和reduce()函数。在实际应用中,应根据具体情况选择最合适的方法。
本文介绍了导入CSV文件的多种方法,包括使用Excel、Python和R等工具。同时,还介绍了导入CSV文件时需要注意的一些细节和问题。CSV文件是数据处理和分析中不可或缺的一部分,希望本文能够对读者有所帮助。
mongodb是一种新型的数据库,它采用了面向文档的数据模型,具有灵活性、高性能和高可用性等优势。但是,mongodb也存在数据结构混乱、安全性和学习成本高等问题。
当Python运行不了时,我们应该从代码、Python环境、操作系统和硬件设备等多个角度来排查问题,并采取相应的解决措施。