Keras中的自定义损失函数应该返回该批次的单个损失值,还是返回该训练批次中每个样本的大量损失?

如何解决Keras中的自定义损失函数应该返回该批次的单个损失值,还是返回该训练批次中每个样本的大量损失?

我正在使用tensorflow(2.3)学习keras API。在tensorflow网站上的guide中,我找到了一个自定义损失函数的示例:

    def custom_mean_squared_error(y_true,y_pred):
        return tf.math.reduce_mean(tf.square(y_true - y_pred))

此自定义损失函数中的reduce_mean函数将返回标量。

像这样定义损失函数是否正确?据我所知,y_truey_pred的形状的第一维是批量大小。我认为损失函数应返回批次中每个样品的损失值。因此,损失函数应该给出形状为(batch_size,)的数组。但是上面的函数为整个批次提供了一个单一的值。

也许上面的例子是错误的?有人可以帮我解决这个问题吗?


p.s。 为什么我认为损失函数应该返回数组而不是单个值?

我阅读了Model类的源代码。当您为Model.compile()方法提供损失函数(请注意,它是一个函数,而不是损失)时,该损失函数将用于构造{{ 1}}对象,该对象存储在LossesContainer中。传递给Model.compiled_loss类的构造函数的损失函数再次用于构造LossesContainer对象,该对象存储在LossFunctionWrapper中。

根据LossFunctionWrapper类的源代码,通过LossesContainer._losses方法(从LossFunctionWrapper.__call__()类继承)计算训练批次的总损失值,即返回整个批次的单个损失值。。但是Loss首先调用LossFunctionWrapper.__call__()方法,以获取训练批次中每个样本的损失数组。然后对这些损失进行平均,以得到整个批次的单个损失值。调用LossFunctionWrapper.call()方法提供的损失函数就是在LossFunctionWrapper.call()方法中。

这就是为什么我认为自定义损失函数应该返回一系列损失,并具有单个标量值。此外,如果我们为Model.compile()方法编写了一个自定义Loss类,那么我们的自定义Model.compile()类的call()方法也应该返回一个数组,而不是一个信号值。


我在github上打开了issue。已确认需要自定义损失函数才能为每个样本返回一个损失值。该示例将需要更新以反映这一点。

解决方法

我在github上打开了issue。已确认需要自定义损失函数才能为每个样本返回一个损失值。该示例将需要更新以反映这一点。

,

实际上,据我所知,损失函数的返回值的形状并不重要,即它可以是标量张量或每个样本一个或多个值的张量。重要的是如何将其减小为标量值,以便可以将其用于优化过程或显示给用户。为此,您可以在Reduction documentation中检查缩小类型。

此外,这是compile方法documentation关于loss参数的内容,部分解决了这一点:

损失:字符串(目标函数的名称),目标函数或tf.keras.losses.Loss实例。参见tf.keras.losses。目标函数可以是带有签名loss = fn(y_true,y_pred)的任何可调用对象,其中y_true =形状为[batch_size,d0,.. dN]的地面真值,稀疏损失函数(例如,形状= {{1}的稀疏分类交叉熵)除外}。 [batch_size,.. dN-1] =形状为y_pred的预测值。它返回一个加权损失浮点张量。如果使用自定义[batch_size,.. dN]实例,并且将reduce设置为Loss,则返回值的形状为NONE,即。每个样本或每个时间步的损耗值;否则,它是一个标量。如果模型有多个输出,则可以通过传递字典或损失列表来在每个输出上使用不同的损失。该模型将使损失值最小化,将是所有单个损失的总和。

此外,值得注意的是,TF / Keras中的大多数内置损耗函数通常会在最后一个维度(即[batch_size,.. dN-1])上减小。


对于那些怀疑返回标量值的自定义损失函数是否会起作用的人:您可以运行以下代码段,然后您会看到模型可以正确训练和收敛。

axis=-1
,

<div class="container"> <table border='1' id='theTable'> <thead> <tr> <th>Name</th> <th>Role</th> </tr> </thead> <tbody> <tr> <td>Adam</td> <td>AAA</td> </tr> <tr> <td>Adam</td> <td>BBB</td> </tr> <tr> <td>Adam</td> <td>CCC</td> </tr> <tr> <td>Bert</td> <td>AAA</td> </tr> <tr> <td>Bert</td> <td>CCC</td> </tr> <tr> <td>Cesar</td> <td>BBB</td> </tr> </tbody> </table> <br> <table id='newTable' border='1'> <thead></thead> <tbody></tbody> </table> </div> <script src="http://code.jquery.com/jquery-1.11.0.min.js"></script> <script> $(document).ready(function () { var role_arr = []; $("#theTable td:nth-child(2)").each(function() { if ($.inArray($(this).text(),role_arr) == -1) role_arr.push($(this).text()); }); role_arr.sort() console.log(role_arr); // create thead row and put Roles in it var trow = "<tr>"; trow += '<th>Name</th>'; for (var i=0; i<role_arr.length; i++) { trow +='<th>'+ role_arr[i] +'</th>'; } trow += '</tr>'; $("#newTable").find("thead").append(trow); // create all names array var name_arr = []; $("#theTable td:nth-child(1)").each(function() { if ($.inArray($(this).text(),name_arr) == -1) name_arr.push($(this).text()); }); console.log(name_arr); for (var i=0; i<name_arr.length; i++) { // create an array for each name's roles var row_arr = []; $("#theTable tr:has(td:contains('"+name_arr[i]+"'))").each(function () { //console.log($(this).find('td:nth-child(2)').text()); row_arr.push($(this).find('td:nth-child(2)').text()); }); // create the table body row row var trow = "<tr>"; trow += '<td>'+name_arr[i]+'</td>'; for(var j=0; j<role_arr.length; j++) { if(row_arr.includes(role_arr[j])) { trow += '<td> X </td>'; } else { trow += '<td> - </td>'; } } trow += '</tr>'; $("#newTable").find("tbody").append(trow); } }); </script>取批次的平均值并返回。这就是为什么它是一个标量。

,

Tensorflow 网站上给出的损失函数是绝对正确的。

def custom_mean_squared_error(y_true,y_pred):
    return tf.math.reduce_mean(tf.square(y_true - y_pred))

在机器学习中,我们使用的损失是各个训练示例损失的总和,因此它应该是一个标量值。 (由于所有示例,我们使用的是单个网络,因此我们需要使用单个损耗值来更新参数。)

关于使集装箱蒙受损失:

在使用并行计算时,制作容器是一种更简单,可行的方法,因为我们使用批次而不是整个训练集来跟踪计算的损失指数。

,

我认为@Gödel发表的问题完全合法,而且是正确的。自定义损失函数应返回每个样本的损失值。并且,@ today提供的解释也是正确的。最后,这完全取决于所使用的 减少量

因此,如果使用类API创建损失函数,则减少参数会自动在自定义类中继承。使用其默认值“ sum_over_batch_size ”(这是给定批次中所有损失值的平均)。其他选项是“ 求和”,它计算总和而不是取平均值,最后一个选项是“ ”,其中返回损失值数组。

Keras文档中还提到,当人们使用model.fit()时,减少的这些差异是无可争辩的,因为减少是由TF / Keras自动处理的。

最后,还要提到的是,在创建自定义损失函数时,应返回一系列损失(单个样本损失)。它们的减少由框架处理。

链接:

,

由于有多个通道,因此可以增加维数。但是,每个通道的损耗都应只有一个标量值。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 &lt;select id=&quot;xxx&quot;&gt; SELECT di.id, di.name, di.work_type, di.updated... &lt;where&gt; &lt;if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 &lt;property name=&quot;dynamic.classpath&quot; value=&quot;tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams[&#39;font.sans-serif&#39;] = [&#39;SimHei&#39;] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -&gt; systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping(&quot;/hires&quot;) public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate&lt;String
使用vite构建项目报错 C:\Users\ychen\work&gt;npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-