使用pytorch 筛选出一定范围的值
我就废话不多说了,大家还是直接看代码吧~ import torch input_tensor = torch.tensor([1,2,3,4,5]) print(input_tensor>3) mask = (input_tensor>3).nonzero() print(mask) print(input_tensor.index_select(0,mask)) tensor([0, 0, 0, 1, 1], dtype=torch.uint8) tensor([3, 4]) tensor([4, 5]) 补充知识:pytorch tensor筛选满足条件的行或列(使用与或) 我就废话不
用户评论