如何解决在TensorFlow中如何避免重复训练和预测中的OOM错误?
我在TensorFlow中有一些代码,该代码采用基本模型,使用一些数据对其进行微调(训练),然后使用该模型将其他数据用于predict()
。所有这些都封装在模块的main()
方法中,并且可以正常工作。
但是,当我在不同的基本模型上循环运行此代码时,例如在7个基本模型之后,我最终得到一个OOM。这是预期的吗?我希望Python在每次main()
调用后都会清理。 TensorFlow不这样做吗?我该如何强制呢?
编辑:这是一条MWE,显示的不是OOM崩溃,而是增加了内存消耗:
import gc
import os
import numpy as np
import psutil
import tensorflow as tf
tf.get_logger().setLevel("ERROR") # Suppress "tf.function retracing" warnings
process = psutil.Process(os.getpid())
for i in range(100):
(model := tf.keras.applications.mobilenet.MobileNet()).compile(loss="mse")
history = model.fit(
x=(x := tf.zeros((1,*model.input.shape[1:]))),y=(y := tf.zeros((1,*model.output.shape[1:]))),verbose=0,)
prediction = model.predict(x)
_ = gc.collect()
# tf.keras.backend.clear_session()
print(f"rss {i}: {process.memory_info().rss >> 20} MB")
在我的计算机(CPU)上打印
rss 0: 374 MB
rss 1: 438 MB
rss 2: 478 MB
rss 3: 517 MB
rss 4: 554 MB
rss 5: 588 MB
rss 6: 634 MB
rss 7: 669 MB
rss 8: 686 MB
rss 9: 726 MB
...
rss 30: 1386 MB
rss 31: 1413 MB
rss 32: 1445 MB
rss 33: 1476 MB
rss 34: 1506 MB
rss 35: 1536 MB
rss 36: 1568 MB
rss 37: 1597 MB
rss 38: 1630 MB
rss 39: 1662 MB
...
如果没有评论tf.keras.backend.clear_session()
,那就更好了,但还不完善:
rss 0: 374 MB
rss 1: 420 MB
rss 2: 418 MB
rss 3: 450 MB
rss 4: 447 MB
rss 5: 469 MB
rss 6: 469 MB
rss 7: 475 MB
rss 8: 487 MB
rss 9: 494 MB
...
rss 40: 519 MB
rss 41: 516 MB
rss 42: 517 MB
rss 43: 520 MB
rss 44: 519 MB
rss 45: 519 MB
rss 46: 521 MB
rss 47: 517 MB
rss 48: 521 MB
rss 49: 521 MB
...
rss 90: 531 MB
rss 91: 531 MB
rss 92: 531 MB
rss 93: 531 MB
rss 94: 532 MB
rss 95: 532 MB
rss 96: 533 MB
rss 97: 534 MB
rss 98: 533 MB
rss 99: 533 MB
切换gc.collect()
和tf.keras.backend.clear_session()
的顺序也无济于事。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。