Skip to content
TopicTracker
来自 gilesthomas.com查看原文
译文语言译文语言

JAX 后端与设备

作者在将 PyTorch 的 LLM 代码移植到 JAX 时,遇到加载 19GiB 数据集时 CUDA 显存不足的问题。通过探究发现,JAX 默认将数据分配到最快的可用后端(GPU),而不会自动利用 CPU 内存。文章详细介绍了 JAX 的 backend 与 device 概念,并展示了如何使用 `jax.default_device` 上下文管理器临时切换到 CPU 设备加载大数组,以绕过 GPU 显存限制。

相关报道