Summary of functions in PyTorch
torch.stack(tensors, dim=0)
import torch

# Suppose this is the output at time step T1
T1 = torch.tensor([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])
# Suppose this is the output at time step T2
T2 = torch.tensor([[10, 20, 30],
                   [40, 50, 60],
                   [70, 80, 90]])
print(torch.stack((T1,T2)).shape) # torch.Size([2, 3, 3])
print(torch.stack((T1,T2),dim=0))
# tensor([[[ 1,  2,  3],
#          [ 4,  5,  6],
#          [ 7,  8,  9]],
#         [[10, 20, 30],
#          [40, 50, 60],
#          [70, 80, 90]]])
print(torch.stack((T1,T2),dim=1)) # [3, 2, 3]
# tensor([[[ 1,  2,  3],
#          [10, 20, 30]],
#         [[ 4,  5,  6],
#          [40, 50, 60]],
#         [[ 7,  8,  9],
#          [70, 80, 90]]])
print(torch.stack((T1,T2),dim=2)) # [3, 3, 2]
# tensor([[[ 1, 10],
#          [ 2, 20],
#          [ 3, 30]],
#         [[ 4, 40],
#          [ 5, 50],
#          [ 6, 60]],
#         [[ 7, 70],
#          [ 8, 80],
#          [ 9, 90]]])
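As a quick cross-check (an added sketch using the T1/T2 tensors defined above, not part of the original notes): stacking along a new dimension should match unsqueezing each tensor at that dimension and then concatenating.

# torch.stack inserts a new dimension; unsqueeze + cat should give the same result
a = torch.stack((T1, T2), dim=0)
b = torch.cat((T1.unsqueeze(0), T2.unsqueeze(0)), dim=0)
print(torch.equal(a, b))  # True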
torch.cat(tensors, dim=0)
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497],
        [ 0.6580, -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,
         -1.0969, -0.4614],
        [-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,
         -0.5790,  0.1497]])
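One difference worth noting (an added sketch, assuming standard torch.cat behavior): torch.cat joins along an existing dimension and keeps the number of dimensions, so the inputs may differ in size only along the concatenating dimension, whereas torch.stack requires identical shapes.

a = torch.randn(2, 3)
b = torch.randn(4, 3)
print(torch.cat((a, b), dim=0).shape)  # torch.Size([6, 3])
# torch.stack((a, b)) would raise an error here because the shapes differ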
torch.unsqueeze(input, dim)
data1=torch.randn((2,3))
print(data1,data1.dim(),data1.size())
# tensor([[ 1.1373,  0.1755, -0.3572],
#         [ 0.3606, -0.4550, -1.0797]])
# torch.Size([2, 3])
data2=torch.unsqueeze(data1,0)
print(data2)
print(data2.dim(),data2.size())
# tensor([[[ 1.1373,  0.1755, -0.3572],
#          [ 0.3606, -0.4550, -1.0797]]])
# torch.Size([1, 2, 3])
data3=torch.unsqueeze(data1,1)
print(data3)
print(data3.dim(),data3.size())
# tensor([[[ 1.1373,  0.1755, -0.3572]],
#         [[ 0.3606, -0.4550, -1.0797]]])
# torch.Size([2, 1, 3])
data4=torch.unsqueeze(data1,-1)
print(data4)
print(data4.dim(),data4.size())
# tensor([[[ 1.1373],
#          [ 0.1755],
#          [-0.3572]],
#         [[ 0.3606],
#          [-0.4550],
#          [-1.0797]]])
# torch.Size([2, 3, 1])
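As a side note (an added check on the data1 tensor from above): indexing with None inserts a size-1 dimension and is equivalent to unsqueeze.

print(torch.equal(torch.unsqueeze(data1, 0), data1[None, :, :]))   # True
print(torch.equal(torch.unsqueeze(data1, -1), data1[:, :, None]))  # True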
torch.repeat_interleave(input, repeats, dim=None)
input (torch.Tensor): the input tensor.
repeats (int or torch.Tensor): the number of repetitions for each element. repeats is broadcast to fit the shape of the given dimension.
dim (int): the dimension along which to repeat values. By default (dim=None), the input tensor is flattened into a vector, each element is repeated repeats times, and the flat result is returned.
# pos = pos.repeat_interleave(repeats, dim=None)
# pos = torch.repeat_interleave(input, repeats, dim)
>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
# A multi-dimensional tensor is flattened by default
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
# Specify a dimension
>>> torch.repeat_interleave(y, 3, 0)
tensor([[1, 2],
        [1, 2],
        [1, 2],
        [3, 4],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
# Repeat different elements a different number of times
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=1)
tensor([[1, 2, 2],
        [3, 4, 4]])
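To avoid confusion with Tensor.repeat (a brief comparison added here, using the same y as above): repeat_interleave repeats individual elements, while Tensor.repeat tiles the whole tensor.

>>> torch.repeat_interleave(y, 2, dim=0)  # each row repeated element-wise
tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4]])
>>> y.repeat(2, 1)                        # the whole tensor tiled twice along dim 0
tensor([[1, 2],
        [3, 4],
        [1, 2],
        [3, 4]])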
torch.nn.Flatten(start_dim=1, end_dim=-1)
Note that the default start_dim is 1, so dimension 0 (the batch dimension) is kept!
import torch.nn as nn

input = torch.randn(32, 1, 5, 5)
# With default parameters
m = nn.Flatten()
output = m(input)
output.size()
# torch.Size([32, 25])
# With non-default parameters
m = nn.Flatten(0, 2)
output = m(input)
output.size()
# torch.Size([160, 5])
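For comparison (an added note, assuming standard PyTorch defaults): the functional torch.flatten defaults to start_dim=0, unlike nn.Flatten whose default start_dim=1 preserves the batch dimension.

print(torch.flatten(input).size())               # torch.Size([800])
print(torch.flatten(input, start_dim=1).size())  # torch.Size([32, 25])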