1. 首页
  2. 数据库
  3. 其它
  4. Pytorch中torch.gather函数

Pytorch中torch.gather函数

上传者: 2020-12-23 00:30:36上传 PDF文件 35.14KB 热度 19次
在学习 CS231n中的NetworkVisualization-PyTorch任务,讲解了使用torch.gather函数,gather函数是用来根据你输入的位置索引 index,来对张量位置的数据进行合并,然后再输出。 其中 gather有两种使用方式,一种为 torch.gather 另一种为 对象.gather。 首先介绍 对象.gather import torch torch.manual_seed(2) #为CPU设置种子用于生成随机数,以使得结果是确定的 def gather_example(): N, C = 4, 5 s = torch.randn(N,
用户评论