torch.tensorでnanを0に置き換える

torch.tensorのnanを数値に置き換えるには、torch.nan_to_num()を使用すると簡単に置き換えられます。

x = torch.tensor([float('nan'), 1.0, 1.0, 1.0])
zero_x = torch.nan_to_num(x)
print(zero_x)

# >>> tensor([0., 1., 1., 1.])

なお、torch.nan_to_num()はnanだけでなくinfや-infも置換することができます。

x = torch.tensor([float('nan'), float('inf'), -float('inf'), 1.0])
zero_x = torch.nan_to_num(x)
print(zero_x)

# >>> tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38,  1.0000e+00])