在本文中,我们将看看使用 PyTorch torch.max() 函数。
正如你所期望的,这是一个非常简单的功能,但有趣的是,它比你想象的要多。
让我们看看使用这个功能,使用一些简单的例子。
注:在写作时,使用的 PyTorch 版本为 ** PyTorch 1.5.0**
PyTorch torch.max() - 基本语法
要使用 PyTorch torch.max()
,先导入 torch
。
1import torch
现在,这个函数返回Tensor中的元素中的最大值。
PyTorch torch.max() 的默认行为
默认行为是返回一个单个元素和一个索引,相当于全球最大元素。
1max_element = torch.max(input_tensor)
这里有一个例子:
1p = torch.randn([2, 3])
2print(p)
3max_element = torch.max(p)
4print(max_element)
出发点( )
1tensor([[-0.0665, 2.7976, 0.9753],
2 [ 0.0688, -1.0376, 1.4443]])
3tensor(2.7976)
事实上,这给了我们Tensor中的全球最大元素!
使用 torch.max() 沿一个维度
但是,您可能希望在特定维度中获得最大值,例如 [Tensor]( / 社区 / 教程 / Pytorch-tensor),而不是单个元素。
要指定尺寸(轴 - 在numpy
中),还有另一个可选的关键字参数,称为dim
这代表了我们采取的最大方向。
这会返回一个 tuple,‘max_elements’和‘max_indices’。
max_elements
-> Tensor 的所有最大元素max_indices
-> 对应最大元素的索引
1max_elements, max_indices = torch.max(input_tensor, dim)
这将返回一个Tensor,该元素具有维度dim
沿的最大元素。
现在让我们看看一些例子。
1p = torch.randn([2, 3])
2print(p)
3
4# Get the maximum along dim = 0 (axis = 0)
5max_elements, max_idxs = torch.max(p, dim=0)
6print(max_elements)
7print(max_idxs)
出发点( )
1tensor([[-0.0665, 2.7976, 0.9753],
2 [ 0.0688, -1.0376, 1.4443]])
3tensor([0.0688, 2.7976, 1.4443])
4tensor([1, 0, 1])
正如你可以看到的那样,我们在维度0(沿列的最大)中找到最大值。
此外,我们可以获得对应元素的索引,例如,0.0688
在列 0 沿着索引 1
。
同样,如果要在行中找到最大值,请使用dim=1
。
1# Get the maximum along dim = 1 (axis = 1)
2max_elements, max_idxs = torch.max(p, dim=1)
3print(max_elements)
4print(max_idxs)
出发点( )
1tensor([2.7976, 1.4443])
2tensor([1, 2])
事实上,我们得到沿行的最大元素,以及相应的索引(沿行)。
使用 torch.max() 进行比较
我们还可以使用torch.max()来获取两个传感器之间的最大值。
1output_tensor = torch.max(a, b)
在这里,a
和b
必须具有相同的尺寸,或者必须是广播的
传感器。
下面是一个简单的例子来比较两个具有相同尺寸的传感器。
1p = torch.randn([2, 3])
2q = torch.randn([2, 3])
3
4print("p =", p)
5print("q =",q)
6
7# Compare elements of p and q and get the maximum
8max_elements = torch.max(p, q)
9
10print(max_elements)
出发点( )
1p = tensor([[-0.0665, 2.7976, 0.9753],
2 [ 0.0688, -1.0376, 1.4443]])
3q = tensor([[-0.0678, 0.2042, 0.8254],
4 [-0.1530, 0.0581, -0.3694]])
5tensor([[-0.0665, 2.7976, 0.9753],
6 [ 0.0688, 0.0581, 1.4443]])
事实上,我们得出输出 tensor 的最大元素在p
和q
之间。
结论
在本文中,我们了解了如何使用 torch.max() 函数来找出一个 Tensor 的最大元素。
我们还使用了这个函数来比较两个压缩器,并获得它们之间的最大值。
对于类似的文章,请浏览我们的内容在我们的 [PyTorch 教程]( / 社区 / 教程 / pytorch)! 留意更多!
参考
- PyTorch 官方文件 在 torch.max()