pytorch中squeeze()和unsqueeze()函数

squeeze:  挤压

import torch

data = torch.range(1, 6).view(2, 3)
print(data)
# tensor([[1., 2., 3.],
#         [4., 5., 6.]])

print(data.shape)
# torch.Size([2, 3])

# 新增维度
data1 = data.unsqueeze(1)
print(data1.shape)
# torch.Size([2, 1, 3])

# 删除维度,只能维度为1的才能被删除
data2 = data1.squeeze(1)
print(data2.shape)
# torch.Size([2, 3])


标签: 、面试
  • 回复
隐藏