1. 首页
  2. 数据库
  3. 其它
  4. 使用pytorch 筛选出一定范围的值

使用pytorch 筛选出一定范围的值

上传者: 2020-12-22 10:34:56上传 PDF文件 31KB 热度 19次
我就废话不多说了,大家还是直接看代码吧~ import torch input_tensor = torch.tensor([1,2,3,4,5]) print(input_tensor>3) mask = (input_tensor>3).nonzero() print(mask) print(input_tensor.index_select(0,mask)) tensor([0, 0, 0, 1, 1], dtype=torch.uint8) tensor([3, 4]) tensor([4, 5]) 补充知识:pytorch tensor筛选满足条件的行或列(使用与或) 我就废话不
用户评论