Skip to content
TopicTracker
来自 gilesthomas.com查看原文
译文语言译文语言

在 Flax 中使用 Safetensors

文章介绍了如何在 Flax(基于 JAX 的神经网络库)中使用 Safetensors 保存和加载模型检查点。作者发现 Safetensors 官方提供的 Flax/JAX API 仅支持平铺的字典结构(字符串→JAX 数组),而 Flax 的 nnx.State.to_pure_dict 会产生嵌套字典,直接传入会导致报错。解决方案是先通过 nnx.to_flat_state 将状态转换为扁平结构,再拼接成点分隔键名的简单字典,即可正常使用 safetensors.flax 的 save_file 和 load_file 函数。

相关报道