如何解决使用jax时,XLA的Jit编译速度非常慢
我正在使用Jax做一些机器学习工作。 Jax使用XLA进行即时编译来加速,但是编译本身在CPU上太慢。我的情况是,CPU将仅使用单个内核来进行编译,这根本没有效率。
我找到了一些答案,如果我可以使用GPU进行编译,那将很快。谁能告诉我如何使用GPU进行编译?由于我没有对编译进行任何配置。谢谢!
该问题的一些补充:我正在使用Jax计算grad和hessian,这会使编译非常慢。代码如下:
## get results from model ##
def get_model_value(images):
return jnp.sum(model(images))
def get_model_grad(images):
images = jnp.expand_dims(images,axis=0)
image_grad = jacfwd(get_model_value)(images)
return image_grad
def get_model_hessian(images):
images = jnp.expand_dims(images,axis=0)
image_hess = jacfwd(jacrev(get_model_value))(images)
return image_hess
# get value
model_value = model(dis_img)
FR_value = jnp.expand_dims(FR_value,axis=1)
value_loss = crit_mse(model_value,FR_value)
# get grad
vmap_model_grad = jax.vmap(get_model_grad)
model_grad = vmap_model_grad(dis_img)
# get hessian
vmap_model_hessian = vmap(get_model_hessian)
model_hessian = vmap_model_hessian(dis_img)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。