1. 首页
  2. 数据库
  3. 其它
  4. Pytorch的gather()和scatter()

Pytorch的gather()和scatter()

上传者: 2021-01-31 20:19:41上传 PDF文件 61.65KB 热度 28次
Pytorch的gather()和scatter() 1.gather() gather是取的意思,意为把某一tensor矩阵按照一个索引序列index取出,组成一个新的矩阵。 gather(input,dim,index) 参数: input是要取值的矩阵 dim指操作的维度,0为竖向操作即按行操作,1为横向操作即按列操作 index为索引序列 下面这个例子是按行取出第一行的’0号元素’,’0行元素’组成新的第一行; 再取出第二行的‘1号元素’,‘0号元素’组成新的第二行 a = torch.Tensor([[1,2],[3,4]]) b = torch.gather(a, 1, torch
用户评论