transformer_in_transformer_flax 源码
JAX / Flax中的Transformer中的Transformer 此存储库实现 ,像素级注意与补丁级注意配对以进行图像分类。 Pytorch实施和Vision Transformer回购在很大程度上受到了启发。 安装 $ pip install transformer-in-transformer-flax 用法 from jax import random from jax import numpy as jnp from transformer_in_transformer_flax import TransformerInTransformer , TNTConfig #example configuration for TNT-B config = TNTConfig ( num_classes = 1000 , depth = 12 , imag
下载地址
用户评论