masked_fill【将mask中值为True的位置对应的待填充的张量设置为value值】

(31) 2024-05-06 20:01:01

masked_fill方法有两个参数,maske和value,mask是一个pytorch张量(Tensor),元素是布尔值,value是要填充的值,填充规则是mask中取值为True位置对应于待填充的相应位置用value填充。

import torch
a=torch.tensor([[[5,5,5,5], [6,6,6,6], [7,7,7,7]], [[1,1,1,1],[2,2,2,2],[3,3,3,3]]])
print(a)
print(a.size())
print("#############################################3")
mask = torch.ByteTensor([[[1],[1],[0]],[[0],[1],[1]]])
print(mask.size())
b = a.masked_fill(mask, value=torch.tensor(-1e9))
print(b)
print(b.size())

其实就是只要mask中的布尔值为True的话,就将待填充的对应的位置(此程序中的a)设置为value值 

masked_fill【将mask中值为True的位置对应的待填充的张量设置为value值】 (https://mushiming.com/)  第1张

THE END

发表回复