Skip to main content

Pytorch Cheatsheet

PyTorch provides two high-level features :

  • Tensor computation (like numpy) with strong GPU acceleration
  • Deep Neural Networks built on an autodiff system

Basic

https://github.com/xbresson/CS5242_2025

训练相关

  • 一个MLP有两个compution graph,一个前向,一个后向
  • requires_grad_()会记住前向传播中所有的操作,然后在后向传播loss.backward()重播
  • deatch(),每个minibatch都要deatch(),用以从compution graph中deatch出来,也就是从内存中卸载,避免OOM

torch.Tensor() vs torch.tensor()

省流:为了使用浮点数,直接用大写Tensor。

  • torch.Tensor() creates a tensor with the default data type, as defined by torch.get_default_dtype().
  • torch.tensor() will infer data type from the data.

For example:

>>> torch.Tensor([1, 2, 3]).dtype
torch.float32

>>> torch.tensor([1, 2, 3]).dtype
torch.int64

Pytorch将张量的第一维视为batch_size

PyTorch 会自动将张量的第一维视为 batch_size,即批次大小。默认情况下,PyTorch 的大多数神经网络模块(如 nn.Linearnn.Conv2d 等)和操作(如 torch.meantorch.sum 等)会沿着第一维进行并行计算,通过批次维度同时传入模型。,从而让整个批次的样本同时通过模型。

  1. 批次张量的创建:在数据预处理或加载阶段,通常会将多个样本(图像、序列等)打包成一个批次张量。例如,一个形状为 (batch_size, 3, 224, 224) 的张量表示一个包含 batch_size 个样本的图像批次,每个图像的通道数为 3,大小为 224x224

  2. 并行计算:当批次张量输入到模型的 forward 方法中时,PyTorch 会自动沿批次维度进行广播,逐样本地对每个样本执行相同的计算。例如,对于 nn.Linear 层的前馈操作,它会对批次中的每个样本应用相同的权重矩阵。

  3. 自动微分和批次处理:通过批次操作,PyTorch 能高效地利用 GPU 加速。批次操作可以充分利用 GPU 的并行计算能力,并在反向传播过程中同时计算每个样本的梯度。这在训练过程中可以提高计算速度并稳定模型的更新。

以下是一个简单示例,展示了如何用批次张量输入并同时处理多个批次:

import torch
import torch.nn as nn

# 定义一个简单的全连接层
class SimpleModel(nn.Module):
def __init__(self, input_size, output_size):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(input_size, output_size)

def forward(self, x):
return self.fc(x)

# 实例化模型
model = SimpleModel(input_size=10, output_size=5)

# 创建一个包含 32 个样本的批次,每个样本的特征向量大小为 10
batch_input = torch.randn(32, 10)

# 通过模型进行前向传播,计算同时应用于 32 个样本
output = model(batch_input)

# 输出形状:torch.Size([32, 5])
print(output.shape)

在这个示例中,batch_input 是一个包含 32 个样本的张量(即 batch_size=32),每个样本具有 10 个特征。在前向传播时,模型的 fc 层会同时对这 32 个样本应用相同的权重,并生成一个形状为 (32, 5) 的输出,表示每个样本的 5 个输出特征。

这种批次处理方式让模型在 GPU 上高效运行,因为 GPU 能够一次性处理整个批次的数据,极大地提高了并行计算能力,降低了训练和推理的时间。

Reference

  1. https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#warm-up-numpy
  2. https://github.com/xbresson/CS5242_2021