网站首页 > 技术文章 正文
机器之心整理
参与:思源、一鸣
一行代码定义计算图,So Easy,妈妈再也不用担心我的机器学习。
项目地址:https://github.com/JuliusKunze/jaxnet
JAXnet 是一个基于 JAX 的深度学习库,它的 API 提供了便利的模型搭建体验。相比 TensorFlow 2.0 或 PyTorch 等主流框架,JAXnet 拥有独特的优势。举个栗子,不论是 Keras 还是 PyTorch,它们建模就像搭积木一样。
然而,还有一种比搭积木更简单的方法,这就是 JAXnet 的模块化:
from jaxnet import * net = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), logsoftmax) creates a neural net model from predefined modules.
创建一个全连接网络可以直接用预定义的模块,可以说 JAXnet 定义计算图,只需一行代码就可以了。写一个神经网络,原来 So easy。
总体来说,JAXnet 主要关注的是模块化、可扩展性和易用性等几个方面:
- 采用了不可变权重,而不是全局计算图,从而获得更强的稳健性;
- 用于构建神经网络、训练循环、预处理、后处理等过程的 NumPy 代码经过 GPU 编译;
- 任意模块或整个网络的正则化、重参数化都只需要一行代码;
- 取消了全局随机状态,采用了更便捷的 Key 控制。
可扩展性
你可以使用 @parametrized 定义自己的模块,并复用其它的模块:
from jax import numpy as np @parametrizeddef loss(inputs, targets): return -np.mean(net(inputs) * targets)
所有的模块都是用这样的方法组合在一起的。jax.numpy (https://github.com/google/jax#whats-supported) 是 numpy 的镜像。只要你知道怎么使用 numpy,那么你就可以知道 JAXnet 大部分的用法了。
以下是 TensorFlow2/Keras 的代码,JAXnet 相比之下更为简洁:
import tensorflow as tf from tensorflow.keras import Sequential from tensorflow.keras.layers import Dense, Lambda net = Sequential([Dense(1024, 'relu'), Dense(1024, 'relu'), Dense(4), Lambda(tf.nn.log_softmax)]) def loss(inputs, targets): return -tf.reduce_mean(net(inputs) * targets)
需要注意的是,Lambda 函数在 JAXnet 中不是必要的。而 relu 和 logsoftmax 函数都是 Python 写的函数。
非可变权重
和 TensorFlow 或者 Keras 不同,JAXnet 没有全局计算图。net 和 loss 这样的模块不保存可变权重。权重则是保存在分开的不可变类中。这些权重由 init_parameters 函数初始化,用于提供随机的键和样本输入:
from jax.random import PRNGKey def next_batch(): return np.zeros((3, 784)), np.zeros((3, 4)) params = loss.init_parameters(PRNGKey(0), *next_batch()) print(params.sequential.dense2.bias) # [0.00376661 0.01038619 0.00920947 0.00792002]
目标函数不会在线变更权重,而是不断更新权重的下一个版本。它们会以新的优化状态返回,并由 get_parameters 取回。
opt = optimizers.Adam() state = opt.init_state(params)for _ in range(10): state = opt.optimize(loss.apply, state, *next_batch()) # accelerate with jit=True trained_params = opt.get_parameters(state)
当需要对网络进行评价时:
test_loss = loss.apply(trained_params, *test_batch) # accelerate with jit=True
JAXnet 的正则化也十分简单:
loss = L2egularized(loss, scale = .1)
其他特性
除了简洁的代码,JAXnet 还支持在 GPU 上进行计算。而且还可以用 jit 进行编译,摆脱 Python 运行缓慢的问题。同时,JAXnet 是单步调试的,和 Python 代码一样。
安装也十分简单,使用 pip 安装即可。如果需要使用 GPU,则需要先安装 jaxlib。
其他具体的 API 可参考:https://github.com/JuliusKunze/jaxnet/blob/master/API.md
猜你喜欢
- 2024-10-12 新手教程:在新应用中实践深度学习的最佳建议
- 2024-10-12 一个新的基于样本数量计算的的高斯 softmax 函数
- 2024-10-12 TensorFlow和PyTorch相继发布最新版,有何变化
- 2024-10-12 来自特斯拉AI总监的“秘方”:这些训练神经网络的小技巧不能忽略
- 2024-10-12 小型深度学习框架 | TinyGrad,不到1K行代码(附代码下载)
- 2024-10-12 Pytorch损失函数简明教程 pytorch loss function
- 2024-10-12 深度学习实战第四课 深度学习基础教程
- 2024-10-12 Python pytorch 深度学习神经网络 softmax线性回归分类学习笔记
- 2024-10-12 收藏!PyTorch常用代码段合集 pytorch 编程
- 2024-10-12 深度神经网络模型训练中的最新tricks总结【原理与代码汇总】
你 发表评论:
欢迎- 最近发表
-
- 在 Spring Boot 项目中使用 activiti
- 开箱即用-activiti流程引擎(active 流程引擎)
- 在springBoot项目中整合使用activiti
- activiti中的网关是干什么的?(activiti包含网关)
- SpringBoot集成工作流Activiti(完整源码和配套文档)
- Activiti工作流介绍及使用(activiti工作流会签)
- SpringBoot集成工作流Activiti(实际项目演示)
- activiti工作流引擎(activiti工作流引擎怎么用)
- 工作流Activiti初体验及在数据库中生成的表
- Activiti工作流浅析(activiti6.0工作流引擎深度解析)
- 标签列表
-
- oraclesql优化 (66)
- 类的加载机制 (75)
- feignclient (62)
- 一致性hash算法 (71)
- dockfile (66)
- 锁机制 (57)
- javaresponse (60)
- 查看hive版本 (59)
- phpworkerman (57)
- spark算子 (58)
- vue双向绑定的原理 (68)
- springbootget请求 (58)
- docker网络三种模式 (67)
- spring控制反转 (71)
- data:image/jpeg (69)
- base64 (69)
- java分页 (64)
- kibanadocker (60)
- qabstracttablemodel (62)
- java生成pdf文件 (69)
- deletelater (62)
- com.aspose.words (58)
- android.mk (62)
- qopengl (73)
- epoch_millis (61)
本文暂时没有评论,来添加一个吧(●'◡'●)