深入理解Tensorflow中的masking和padding

TensorFlow是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,图中的线(edges)则表示在节点间相互联系的多维数据数组,即张量(tensor)。它灵活的架构让你可以在多种平台上展开计算,例如台式计算机中的一个或多个CPU(或GPU),服务器,移动设备等等。TensorFlow 最初由Google大脑小组(隶属于Google机器智能研究机构)的研究员和工程师们开发出来,用于机器学习和深度神经网络方面的研究,但这个系统的通用性使其也可广泛用于其他计算领域。

声明:

需要读者对tensorflow和深度学习有一定了解

tf.boolean_mask实现类似numpy数组的mask操作

Python的numpy array可以使用boolean类型的数组作为索引,获得numpy array中对应boolean值为True的项。示例如下:

# numpy array中的boolean mask
import numpy as np
target_arr = np.arange(5)
print "numpy array before being masked:"
print target_arr
mask_arr = [True, False, True, False, False]
masked_arr = target_arr[mask_arr]
print "numpy array after being masked:"
print masked_arr

运行结果如下:

numpy array before being masked: [0 1 2 3 4] numpy array after being masked: [0 2]

tf.boolean_maks对目标tensor实现同上述numpy array一样的mask操作,该函数的参数也比较简单,如下所示:

tf.boolean_mask(
 tensor, # target tensor
 mask, # mask tensor
 axis=None,
 name='boolean_mask'
)

下面,我们来尝试一下tf.boolean_mask函数,示例如下:

import tensorflow as tf
# tensorflow中的boolean mask
target_tensor = tf.constant([[1, 2], [3, 4], [5, 6]])
mask_tensor = tf.constant([True, False, True])
masked_tensor = tf.boolean_mask(target_tensor, mask_tensor, axis=0)
sess = tf.InteractiveSession()
print masked_tensor.eval()

mask tensor中的第0和第2个元素是True,mask axis是第0维,也就是我们只选择了target tensor的第0行和第1行。

[[1 2] [5 6]]

如果把mask tensor也换成2维的tensor会怎样呢?

mask_tensor2 = tf.constant([[True, False], [False, False], [True, False]])
masked_tensor2 = tf.boolean_mask(target_tensor, mask_tensor, axis=0)
print masked_tensor2.eval()

[[1 2] [5 6]]

我们发现,结果不是[[1], [5]]。tf.boolean_mask不做元素维度的mask,tersorflow中有tf.ragged.boolean_mask实现元素维度的mask。

tf.ragged.boolean_mask
tf.ragged.boolean_mask(
 data,
 mask,
 name=None
)

tensorflow中的sparse向量和sparse mask tensorflow中的sparse tensor由三部分组成,分别是indices、values、dense_shape。对于稀疏张量SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]),转化成dense tensor的值为:

[[1, 0, 0, 0] [0, 0, 2, 0] [0, 0, 0, 0]]

使用tf.sparse.mask可以对sparse tensor执行mask操作。

tf.sparse.mask(
 a,
 mask_indices,
 name=None
)

上文定义的sparse tensor有1和2两个值,对应的indices为[[0, 0], [1, 2]],执行tf.sparsse.mask(a, [[1, 2]])后,稀疏向量转化成dense的值为:

[[1, 0, 0, 0] [0, 0, 0, 0] [0, 0, 0, 0]]

由于tf.sparse中的大多数函数都只在tensorflow2.0版本中有,所以没有实例演示。

padded_batch

tf.Dataset中的padded_batch函数,根据输入序列中的最大长度,自动的pad一个batch的序列。

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None,
 drop_remainder=False
)

这个函数与tf.Dataset中的batch函数对应,都是基于dataset构造batch,但是batch函数需要dataset中的所有样本形状相同,而padded_batch可以将不同形状的样本在构造batch时padding成一样的形状。

elements = [[1, 2], 
  [3, 4, 5], 
  [6, 7], 
  [8]] 
A = tf.data.Dataset.from_generator(lambda: iter(elements), tf.int32) 
B = A.padded_batch(2, padded_shapes=[None]) 
B_iter = B.make_one_shot_iterator()
print B_iter.get_next().eval()

[[1 2 0] [3 4 5]]

总结

到此这篇关于深入理解Tensorflow中的masking和padding的文章就介绍到这了,更多相关Tensorflow中的masking和padding内容请搜索ZaLou.Cn以前的文章或继续浏览下面的相关文章希望大家以后多多支持ZaLou.Cn!

原文地址:https://cloud.tencent.com/developer/article/1743485

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


学习编程是顺着互联网的发展潮流,是一件好事。新手如何学习编程?其实不难,不过在学习编程之前你得先了解你的目的是什么?这个很重要,因为目的决定你的发展方向、决定你的发展速度。
IT行业是什么工作做什么?IT行业的工作有:产品策划类、页面设计类、前端与移动、开发与测试、营销推广类、数据运营类、运营维护类、游戏相关类等,根据不同的分类下面有细分了不同的岗位。
女生学Java好就业吗?女生适合学Java编程吗?目前有不少女生学习Java开发,但要结合自身的情况,先了解自己适不适合去学习Java,不要盲目的选择不适合自己的Java培训班进行学习。只要肯下功夫钻研,多看、多想、多练
Can’t connect to local MySQL server through socket \'/var/lib/mysql/mysql.sock问题 1.进入mysql路径
oracle基本命令 一、登录操作 1.管理员登录 # 管理员登录 sqlplus / as sysdba 2.普通用户登录
一、背景 因为项目中需要通北京网络,所以需要连vpn,但是服务器有时候会断掉,所以写个shell脚本每五分钟去判断是否连接,于是就有下面的shell脚本。
BETWEEN 操作符选取介于两个值之间的数据范围内的值。这些值可以是数值、文本或者日期。
假如你已经使用过苹果开发者中心上架app,你肯定知道在苹果开发者中心的web界面,无法直接提交ipa文件,而是需要使用第三方工具,将ipa文件上传到构建版本,开...
下面的 SQL 语句指定了两个别名,一个是 name 列的别名,一个是 country 列的别名。**提示:**如果列名称包含空格,要求使用双引号或方括号:
在使用H5混合开发的app打包后,需要将ipa文件上传到appstore进行发布,就需要去苹果开发者中心进行发布。​
+----+--------------+---------------------------+-------+---------+
数组的声明并不是声明一个个单独的变量,比如 number0、number1、...、number99,而是声明一个数组变量,比如 numbers,然后使用 nu...
第一步:到appuploader官网下载辅助工具和iCloud驱动,使用前面创建的AppID登录。
如需删除表中的列,请使用下面的语法(请注意,某些数据库系统不允许这种在数据库表中删除列的方式):
前不久在制作win11pe,制作了一版,1.26GB,太大了,不满意,想再裁剪下,发现这次dism mount正常,commit或discard巨慢,以前都很快...
赛门铁克各个版本概览:https://knowledge.broadcom.com/external/article?legacyId=tech163829
实测Python 3.6.6用pip 21.3.1,再高就报错了,Python 3.10.7用pip 22.3.1是可以的
Broadcom Corporation (博通公司,股票代号AVGO)是全球领先的有线和无线通信半导体公司。其产品实现向家庭、 办公室和移动环境以及在这些环境...
发现个问题,server2016上安装了c4d这些版本,低版本的正常显示窗格,但红色圈出的高版本c4d打开后不显示窗格,
TAT:https://cloud.tencent.com/document/product/1340