torch.bmm
函数是 PyTorch 中用于执行批量矩阵乘法(Batch Matrix-Matrix Multiplication)的函数。它的名字 "bmm" 表示 "batch matrix multiplication"。
注意:torch.bmm()是不带广播机制的,也就说需按照矩阵运算机制。
比如:[B,3,4]*[B,4,5]是可以的,而[B,3,2]*[B,8,5]是不可以的。
torch.bmm(mat1, mat2)
mat1
和 mat2
是两个三维张量(或者可以被广播为三维张量)。
mat1
的形状为 (batch, n, m)
。mat2
的形状为 (batch, m, p)
。(batch, n, p)
import torch
# 创建两个三维张量
mat1 = torch.rand(3, 2, 4)
mat2 = torch.rand(3, 4, 3)
# 执行批量矩阵乘法
result = torch.bmm(mat1, mat2)
print(result.shape) # 输出: torch.Size([3, 2, 3])
在上面的示例中,mat1
的形状是 (3, 2, 4)
,mat2
的形状是 (3, 4, 3)
。执行 torch.bmm(mat1, mat2)
将得到一个形状为 (3, 2, 3)
的张量,其中 3
是 batch
的大小,2
是 mat1
的行数,3
是 mat2
的列数。
因篇幅问题不能全部显示,请点此查看更多更全内容