Google Metrax 为 JAX 带来预定义模型评估指标

最近由 Google 开源Metrax 是一个 JAX 库,为分类、回归、NLP、视觉和音频模型提供标准化、高性能的度量实现。 梅特拉克斯 Google 解释说,解决了 JAX 生态系统中的一个缺口,该缺口迫使许多团队从 TensorFlow 迁移到 JAX,以实现他们自己版本的常见评估指标,例如准确性、F1、RMS 误差等: 虽然对某些人来说,创建指标似乎是一个相当简单明了的主题,但当考虑跨数据中心大小的分布式计算环境进行大规模训练和评估时,它就变得不那么简单了。 谷歌指出,Metrax 的目标之一是确保所有指标得到良好实施并遵守最佳实践。在指标定义的支持下,Metrax 使用 vmap 和 jit 等高级 JAX 功能来提高性能。例如,这些功能用于实施新的“at K”度量,以能够并行计算多个 K 值。这使得能够更全面、更快速地评估模型。 您可以使用 PrecisionAtK 确定多个 K 值(例如,K=1、K=8 和 K=20 时)的模型精度,所有这些都在一次前向传递模型中进行,而不需要对每个参数多次调用 PrecisionAtK。 开发运营工程师 在 Substack 上以名称书写 神经铸造厂 写道: 事实上,Metrax 支持在一次传递中计算多个 K 值,这对于排名系统来说是一个巨大的胜利。每次切换项目时,我都在重写指标实用程序,而这种标准化早就应该实现了。 API 看起来也很干净。很好奇他们是否针对特定用例(例如大规模推荐管道)的自定义实现对其进行了基准测试。 下面的代码片段展示了如何计算 精确度量 给定的预测和标签。可以指定一个可选阈值来将概率预测转换为二进制预测: import metrax # […]