目录

目录

PyTorch Tensor 工具函数技术文档:创建、计算、拼接与索引

目录
章节主题内容概要
一、总览整体架构与知识结构张量运算分类、各部分作用与关联、cat vs stack / view vs reshape 等对比、适用场景
二、张量创建与属性创建与 dtype/device/shapetensor/zeros/ones/empty/full、arange/linspace/logspace、rand/randn、from_numpy、dtype/device/ndim/shape
三、形状与视图变形与视图view/reshape、squeeze/unsqueeze、flatten、transpose/permute、contiguous、unfold/movedim/t()、view 与 copy 区别
四、拼接与堆叠cat 与 stackcat(沿已有维拼接)、stack(新维堆叠)、参数与约束、典型用法
五、拆分沿维拆分chunk、split、unbind、vsplit/hsplit/dsplit、与 cat 的互逆关系
六、逐元与广播运算四则与广播及更多逐元add/sub/mul/div、广播、log2/log10/exp2/expm1、floor_divide/div(rounding_mode)、pow/sqrt/rsqrt、minimum/maximum、lerp/frac、atan2/hypot、inplace
七、矩阵与线性代数矩阵乘、外积、叉积与 linalgmatmul/mm/bmm、dot/vdot、mv、ger/outer、cross、trace、linalg.inv/solve/det/svd/cholesky/qr/eig、linalg.norm/pinv/matrix_rank、tensordot、meshgrid
八、归约聚合、累积、排序与去重sum/mean/max/min/prod、argmax/argmin、norm/std/var、any/all/count_nonzero、cumsum/cumprod/cummax/cummin/logcumsumexp、median/nanmedian、nanmean/nansum、quantile/logsumexp/dist、sort/topk/kthvalue、unique
九、索引与切片基础与高级索引下标/切片、高级索引、masked_select/index_select/take、put_、nonzero、index_fill_、masked_fill_、masked_scatter_、index_put_
十、高级索引写回gather/scattergather、scatter/scatter_、index_add_/index_copy_、scatter_reduce_/index_reduce_、公式与示例
十一、比较与逻辑比较、where 与数值判断eq/ne/gt/ge/lt/le、where、logical_and/or/not、isfinite/isnan/isinf、isclose、equal/allclose、is_tensor/is_floating_point/is_complex
十二、其它工具clone/expand/einsum/设备转换等clone、expand/repeat、narrow、einsum、roll、clamp、addcdiv/addcmul、tril/triu、renorm、设备与 NumPy 转换、复数与 torch.special、diff/broadcast_tensors/atleast_1d/bucketize/searchsorted、resize_/resize_as_、result_type/promote_types、uniform_/normal_/exponential_/bernoulli_
十三、速查与小结速查与延伸分类速查表、常见遗漏清单、与 training_solver_guide/dataloader_guide 的衔接

阅读建议:先读总览建立「创建 → 形状 → 拼接/拆分 → 运算 → 归约 → 索引」的全局图景,再按需跳转具体 API;实现时按「先创建/再变形/再运算/再索引」顺序对照各章。


目标:在 PyTorch 中,张量(Tensor)是模型输入输出与中间计算的基本单位;工具函数负责对张量进行创建、变形、拼接/拆分、逐元/矩阵运算、归约、索引等,不涉及自动求导或优化器,但构成前向计算与数据准备的基础。

分类概览

张量工具函数
├── 创建与属性     tensor(), zeros(), arange(), dtype, device, shape
├── 形状与视图     view(), reshape(), squeeze(), unsqueeze(), transpose(), permute()
├── 拼接与堆叠     cat(), stack()
├── 拆分           chunk(), split(), unbind()
├── 逐元与广播     add(), mul(), pow(), sqrt(), minimum(), maximum(), lerp(), atan2(), 四则与广播
├── 矩阵运算       matmul(), mm(), bmm(), dot(), ger(), cross(), linalg.inv(), trace()
├── 归约           sum(), mean(), max(), min(), cumsum(), cumprod(), cummax(), cummin(), logcumsumexp(), sort(), topk(), unique(), any(), all()
├── 索引与切片     下标/切片, index_select(), masked_select(), take(), put_()
├── 高级索引       gather(), scatter(), index_add_(), index_copy_()
├── 比较与逻辑     eq(), gt(), where(), logical_and()
└── 其它           clone(), expand(), repeat(), einsum(), narrow(), clamp(), addcdiv(), tril(), triu(), renorm()
部分作用/主题与其它部分的关系
创建与属性从数值/序列/NumPy 得到张量,并控制 dtype/device/shape所有后续运算的前提;属性被 reshape/cat 等依赖
形状与视图在不改数据的前提下改变逻辑形状,或压缩/增加维度为 cat/stack、矩阵乘、归约提供正确维度
拼接与堆叠沿已有维拼接(cat)或沿新维堆叠(stack)与拆分互逆;常与 DataLoader 的 batch 维度配合
拆分沿某维切成多块(chunk/split)或拆成多个张量(unbind)与 cat 互逆;用于序列/多尺度处理
逐元与广播标量/同形/广播下的四则与函数(exp/log 等)损失、归一化、掩码运算的基础
矩阵运算向量/矩阵/批量矩阵乘与点积线性层、注意力、全连接的核心
归约沿维或全局聚合成标量或更小形状损失标量、全局池化、统计量
索引与切片按下标/切片/布尔/索引张量取值为 gather/scatter、掩码、条件选数提供基础
高级索引按索引张量聚集(gather)或写回(scatter)灵活重排、one-hot、稀疏更新
比较与逻辑比较与逻辑运算,常得到布尔张量与 where、掩码、条件分支配合
其它clone/expand/repeat/einsum/narrow/clamp复制、扩维、爱因斯坦求和、边界裁剪等
维度torch.cattorch.stack
含义沿已有维度拼接,维数不变维度上堆叠,维数 +1
形状dim 外同形,dim 上相加所有输入同形,输出多一维
典型多个 batch 拼成一个大 batch多个同形张量做成「批量」
示例(2,3) cat (2,3) → (4,3)(2,3) stack (2,3) → (2,2,3)
维度view()reshape()
连续性要求内存连续,否则先 .contiguous()不连续时可返回 copy,更宽松
返回值恒为视图(共享存储)可能视图或 copy
建议确定连续时用 view(略快)不确定或跨设备时用 reshape
维度expand()repeat()
实现仅改步长,不拷贝数据真实复制数据
约束只能把 1 扩成 n,不能 2→3任意倍数
内存不增显存随倍数增加

创建类 API 分为:从数值/列表填充常数序列随机从 NumPy/共享内存类同形状。属性包括 dtype、device、ndim、shape、stride、is_contiguous 等。

  • torch.tensor(data, dtype=None, device=None)
    是什么:从 Python 列表/标量构造新张量,数据会被拷贝。
    为什么需要:训练中把标量损失、列表标签等转为张量以便参与计算与 backward。
    注意:大列表用 tensor 会慢,可考虑 as_tensor 或 from_numpy 减少拷贝。
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([[1., 2], [3, 4]], dtype=torch.float32)
  • torch.as_tensor(data, dtype=None, device=None)
    尽量共享 data 内存(若已是 tensor/ndarray 且 dtype/device 兼容则不拷贝),否则等价于 tensor。
  • torch.from_numpy(ndarray)
    与 NumPy 共享内存,仅支持 CPU;改一边另一边会变。
  • torch.zeros(size, dtype=None, device=None)
    全 0;常用于初始化或占位。
  • torch.ones(size, …)
    全 1;如初始化偏置、mask。
  • torch.empty(size, …)
    未初始化,值未定义;适合紧接着就地写入(如 uniform_)。
  • torch.full(size, fill_value, …)
    用标量 fill_value 填满。
a = torch.zeros(2, 3)
b = torch.ones(2, 3)
c = torch.full((2, 3), 3.14)
  • torch.arange(start=0, end, step=1, …)
    一维序列,不包含 end;整数或浮点。
  • torch.linspace(start, end, steps, …)
    等间隔 steps 个点,包含两端。
  • torch.logspace(start, end, steps, base=10.0, …)
    对数值等间隔。
torch.arange(5)           # [0,1,2,3,4]
torch.linspace(0, 1, 5)   # [0, 0.25, 0.5, 0.75, 1]
  • torch.rand(size) [0,1) 均匀分布。
  • torch.randn(size)
    标准正态。
  • torch.randint(low, high, size)
    离散均匀整数,区间 [low, high)。
  • torch.randperm(n)
    0 到 n-1 的随机排列。
  • torch.eye(n, m=None, …)
    单位矩阵;m 省略时方阵。
  • torch.diag(input, diagonal=0)
    提取对角线或构造对角阵:input 1D 则构造 2D 对角阵,input 2D 则取对角线为 1D。
torch.eye(3)           # (3,3) 单位阵
torch.diag(torch.tensor([1.,2.,3.]))  # (3,3) 对角阵
  • torch.zeros_like(input), torch.ones_like(input), torch.empty_like(input)
    与 input 同 shape(默认同 dtype/device)。
  • torch.full_like(input, fill_value)
    与 input 同形并填常量 fill_value;与 full 对应。
  • torch.complex(real, imag)
    用实部张量 real 与虚部张量 imag(同形)构造复数张量,dtype 为 complex。
  • tensor.dtype:元素类型(如 torch.float32, torch.int64)。
  • tensor.device:所在设备(cpu 或 cuda:0 等)。
  • tensor.shape / tensor.size():形状。
  • tensor.ndim:维数。
  • tensor.numel():元素总数。
  • tensor.stride():各维步长。
  • tensor.is_contiguous():是否内存连续。
x = torch.randn(2, 3)
print(x.shape, x.dtype, x.device, x.numel())

变形分为:不增删维度(view/reshape)、删维(squeeze/flatten)、增维(unsqueeze)、换维顺序(transpose/permute)。视图与存储布局、contiguous 紧密相关。

  • tensor.view(shape)
    是什么:在满足元素总数不变的前提下,把张量解释为新形状,返回视图*(共享存储)。
    为什么需要:线性层输入要二维 (N, C)、卷积要四维 (N,C,H,W),需把原始维度「压平」或重组。
    约束:当前张量必须内存连续;否则先 .contiguous() 再 view。
  • tensor.reshape(shape)
    语义同 view,但若形状兼容且不连续,PyTorch 会返回 copy 而非视图,因此更安全、略慢。
x = torch.arange(6)
y = x.view(2, 3)
z = x.reshape(2, 3)
assert y.data_ptr() == x.data_ptr()
  • tensor.squeeze(dim=None)
    是什么:去掉大小为 1 的维度;若 dim 为 None 则去掉所有 1 维。
    为什么需要:如 (1, C, 1, 1) 经池化后要去掉 1 维以便后续全连接。
  • tensor.unsqueeze(dim)
    在 dim 处插入大小为 1 的维度;常用于「加 batch 维」或「加通道维」以符合广播或 API 要求。
x = torch.randn(1, 3, 1, 4)
y = x.squeeze()      # (3, 4)
z = y.unsqueeze(0)   # (1, 3, 4)
  • tensor.flatten(start_dim=0, end_dim=-1)
    从 start_dim 到 end_dim(含)拉成一维,等价于该区间内 shape 连乘;常用于全连接前把空间维压平。
x = torch.randn(2, 3, 4, 5)
y = x.flatten(2)     # (2, 3, 20)
z = x.flatten()      # (120,)
  • tensor.transpose(dim0, dim1)
    交换两维;常用于「通道最后」与「通道在前」格式互转(如 HWC ↔ CHW)。
  • tensor.permute(dims)
    按 dims 重排所有维度,可一次完成多维重排。
x = torch.randn(2, 3, 4)  # (N, C, L)
y = x.permute(0, 2, 1)    # (N, L, C)
  • tensor.contiguous()
    若已连续则返回自身,否则返回一份连续拷贝;在 view/reshape 前若报错「non-contiguous」可先调用。
  • tensor.unfold(dim, size, step)
    沿 dim 做滑动窗口,窗口大小为 size、步长为 step;输出多一维(该维长度为窗口个数)。常用于 patch 提取、局部卷积。
  • torch.movedim(input, source, destination)
    把 source 维移到 destination 位置;单维移动时比 permute 更直观。
  • tensor.t()
    仅对 2D 有效,转置的简写。
  • torch.swapdims(input, dim0, dim1)
    交换两维,与 transpose(dim0, dim1) 等价;语义更直观。

  • torch.cat(tensors, dim=0, , out=None)
    是什么:在已有*维度 dim 上把多个张量首尾相接,输出维数与输入相同。
    为什么需要:把多个小 batch 拼成一个大 batch、把多段序列拼成一条等。
    约束:除 dim 外,其它维度形状必须相同。
  • torch.stack(tensors, dim=0, , out=None)
    是什么:在*维度 dim 上堆叠,所有输入形状相同,输出多一维。
    为什么需要:把多个同形张量做成「批量」以便批量运算(如多张图 stack 成 (N,C,H,W))。
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
torch.cat([a, b], dim=0)   # (4, 2)
torch.cat([a, b], dim=1)   # (2, 4)
torch.stack([a, b], dim=0) # (2, 2, 2)

公式(概念):

  • cat:沿 dim,(,Li,)(\ldots, L_i, \ldots)(,Lj,)(\ldots, L_j, \ldots)(,Li+Lj,)(\ldots, L_i+L_j, \ldots)
  • stack:在 dim 插入新维,()(\ldots)()(\ldots)(2,)(2, \ldots)(2 为列表长度)。

  • torch.chunk(input, chunks, dim=0)
    把 dim 维均分成 chunks 块(若不能整除则前几块多 1);返回 tuple of tensors。
  • torch.split(tensor, split_size_or_sections, dim=0)
    split_size_or_sections 为整数时,按该大小切块(最后一块可更小);为 list 时,按各段长度切分。
    与 cat 互逆:torch.cat(torch.split(x, sizes, dim), dim) == x
x = torch.arange(10).view(2, 5)
list(torch.chunk(x, 2, dim=0))   # 两个 (1,5)
list(torch.split(x, [2, 3], dim=1))  # (2,2) 与 (2,3)
  • torch.unbind(input, dim=0)
    沿 dim 拆成多个张量,每个在 dim 上长度为 1 并去掉该维;返回 tuple。
x = torch.randn(3, 4)
list(torch.unbind(x, dim=0))  # 3 个 (4,)
  • torch.vsplit(input, indices_or_sections)
    沿 dim=0 切分,等价于 split(..., dim=0);2D 时按行切块。
  • torch.hsplit(input, indices_or_sections)
    沿 dim=1 切分;2D 时按列切块。
  • torch.dsplit(input, indices_or_sections)
    沿 dim=2 切分;用于 3D 及以上(如多通道特征图按通道切)。
x = torch.arange(12).view(3, 4)
torch.vsplit(x, 3)   # 3 个 (1,4)
torch.hsplit(x, [2, 3])  # (3,2), (3,1), (3,1)

  • torch.add / sub / mul / div(或 + - * /
    逐元素运算;若一个为标量则广播到另一张量形状。

为什么需要广播:向量与标量运算、通道级缩放(如 (C,) 与 (N,C,H,W))无需显式扩展,代码简洁且高效。

广播规则:从右对齐 shape,维数不足的左侧补 1;每维要么相等要么其一为 1。

a = torch.randn(2, 3)
b = torch.randn(2, 3)
c = a + b
d = a * 2
e = torch.randn(3)
f = a + e  # (2,3) + (3,) -> (2,3)
  • torch.exp / log / log1p / sqrt / abs
    指数、自然对数、log(1+x)、平方根、绝对值。
  • torch.log2(input) / torch.log10(input)
    以 2、10 为底的对数;信息论、量纲常用。
  • torch.exp2(input) / torch.expm1(input)
    2x2^xex1e^x - 1;expm1 在 x 接近 0 时数值更稳。
  • torch.sin / cos / tan / sinh / cosh / tanh
    三角函数与双曲函数。
  • torch.sigmoid / tanh
    常用激活形状。
  • torch.round / floor / ceil / trunc
    取整。

用途:激活函数、损失中的 log、归一化中的 sqrt、角度计算等。

x = torch.randn(2, 3)
y = torch.exp(x)
z = torch.sigmoid(x)
  • torch.pow(input, exponent)
    逐元素幂;exponent 可为标量或同形张量。
  • torch.sqrt(input) / torch.rsqrt(input)
    平方根、平方根倒数(1/x1/\sqrt{x}),常用于归一化。
  • torch.neg(input) / torch.reciprocal(input)
    取负、倒数;reciprocal 在 0 处为 inf。
  • torch.sign(input)
    符号:正为 1、负为 -1、零为 0。
  • torch.abs(input)
    绝对值(已有);torch.positive / torch.negative 恒等与取负。
x = torch.tensor([1., 4., 9.])
torch.sqrt(x)      # [1, 2, 3]
torch.pow(x, 0.5)  # 同 sqrt
torch.sign(torch.tensor([-0.1, 0., 0.1]))  # [-1, 0, 1]
  • torch.minimum(input, other) / torch.maximum(input, other)
    两张量逐元素取小/取大;与归约 min/max(单输入)不同。
  • torch.fmod(input, divisor) / torch.remainder(input, divisor)
    逐元素取模;fmod 符号与 input 一致,remainder 与 divisor 一致。
  • torch.lerp(start, end, weight)
    线性插值:start + weight * (end - start);weight 可为标量或与 start 同形。
  • torch.frac(input)
    逐元素取小数部分(等价于 x - trunc(x))。
  • torch.heaviside(input, values)
    阶跃:input < 0 为 0,input > 0 为 1,input == 0 取 values 对应位置。
  • torch.floor_divide(input, other) / a // b
    向下取整除法;与 trunc 向零取整不同。
  • torch.div(input, other, rounding_mode=None)
    rounding_mode 为 'floor''trunc' 时可控制舍入方式,替代 / 的默认行为。
a = torch.tensor([1., 3., 2.])
b = torch.tensor([2., 1., 4.])
torch.minimum(a, b)   # [1, 1, 2]
torch.lerp(a, b, 0.5) # [1.5, 2, 3]
  • torch.atan2(input, other)
    逐元素 atan2(y, x),返回值域 (π,π](-\pi,\pi],常用于角度/方向。
  • torch.hypot(input, other)
    逐元素 input2+other2\sqrt{\texttt{input}^2 + \texttt{other}^2}
  • torch.atan / asin / acos
    反三角函数,定义域需在 [-1,1](asin/acos)内。
  • a.add_(b)a.mul_(b) 等带下划线版本为 inplace;a += b 等价于 a.add_(b)
  • 多数逐元函数有 out 参数可写回指定张量,如 torch.add(a, b, out=c)

  • torch.matmul(input, other)
    通用矩阵乘:1d@1d 标量,2d@2d 矩阵乘,2d@1d 矩阵乘向量,(batch,2d)@(batch,2d) 批量乘等。
    为什么需要:线性层、注意力、全连接都依赖矩阵乘;matmul 统一处理多种形状。
  • torch.mm(a, b)
    严格 2D×2D,无广播。
  • torch.bmm(a, b)
    批量 3D×3D,(B,m,n) @ (B,n,p) → (B,m,p)。
  • torch.dot(a, b)
    仅 1D,内积标量。
  • torch.vdot(a, b)
    1D,先展平再共轭内积(实数为普通内积)。
  • torch.mv(matrix, vec)
    2D×1D → 1D。
  • torch.ger(vec1, vec2)
    两 1D 张量外积 → 2D:out[i,j] = vec1[i] * vec2[j]
  • torch.cross(input, other, dim=-1)
    叉积;默认最后一维为 3 维向量,沿该维做叉积。
  • torch.outer(input, other)
    与 ger 等价,两 1D 的外积。
u = torch.tensor([1., 2., 3.])
v = torch.tensor([4., 5., 6.])
torch.outer(u, v)   # (3,3) 外积
torch.cross(u, v)   # (3,) 叉积
  • torch.trace(input)
    方阵迹,标量。
  • torch.linalg.inv(A)
    方阵求逆;方程组求解用 torch.linalg.solve(A, B)(AX=B 求 X)。
  • torch.linalg.det(A)
    行列式。
  • torch.linalg.svd(A)
    奇异值分解,返回 U, S, Vh。
  • torch.linalg.cholesky(A)
    Cholesky 分解(A 需对称正定)。
  • torch.linalg.qr(A)
    QR 分解。
  • torch.linalg.eig(A) / torch.linalg.eigh(A)
    特征值与特征向量;eigh 用于厄米矩阵。
  • torch.linalg.norm(input, ord=None, dim=None)
    矩阵/向量范数:ord 为 'fro' 为 Frobenius 范数,'nuc' 为核范数,整数为对应 Lp。
  • torch.linalg.pinv(input)
    摩尔-彭罗斯伪逆;当矩阵非方或奇异时求「广义逆」。
  • torch.linalg.matrix_rank(input)
    矩阵秩(基于 SVD 的数值秩)。
  • torch.tensordot(a, b, dims=2)
    指定 dims 维做内积收缩:dims 为 int 表示两张量各取最后 dims 维收缩;为 (dims_a, dims_b) 表示 a 的 dims_a 维与 b 的 dims_b 维对应收缩。用于高阶张量乘法、注意力中的收缩等。
  • torch.meshgrid(tensors, indexing=‘ij’)
    由一维坐标张量生成网格;返回 N 个张量,形状由各维长度决定;indexing=‘xy’ 时前两维按笛卡尔(绘图常用),‘ij’ 为矩阵下标。用于生成 2D/3D 坐标网格、插值、物理场等。
A = torch.randn(3, 3)
A = A @ A.T + torch.eye(3)  # 对称正定
L = torch.linalg.cholesky(A)
torch.trace(A)

  • sum / mean / prod:和、平均、乘积。
  • max / min:最大值、最小值;可返回 (values, indices)。
  • argmax / argmin:最大/最小索引。
  • norm(p=2):范数,默认 L2。
  • std / var:标准差、方差。

dim:沿哪一维归约,该维消失(或保留见下)。
keepdim=True:保留该维为 1,便于后续广播(如对每通道求 mean 得 (1,C,1,1) 再减)。

  • torch.any(input, dim=None) / torch.all(input, dim=None)
    沿 dim 做逻辑或/与归约;dim 为 None 时对全部元素,返回 bool。
  • torch.count_nonzero(input, dim=None)
    沿 dim(或全局)统计非零个数。
  • torch.cumsum(input, dim) / torch.cumprod(input, dim)
    沿 dim 的累积和/累积积,输出形状与 input 相同;递进相加、递进相乘。
  • torch.cummax(input, dim) / torch.cummin(input, dim)
    沿 dim 的累积最大值/最小值,返回 (values, indices);递进取大、递进取小,常用于单调序列或分段最值。
  • torch.logcumsumexp(input, dim)
    log(kiexp(xk))\log(\sum_{k\le i} \exp(x_k)),数值稳定的递进 log-sum-exp;用于 CRF、序列 log 域累积。
  • torch.median(input, dim=None) / torch.nanmedian(input, dim=None)
    中位数;nanmedian 忽略 NaN。
  • torch.nanmean(input, dim=None) / torch.nansum(input, dim=None)
    忽略 NaN 的均值与和;含缺失值时常用。
  • torch.quantile(input, q, dim=None) / torch.nanquantile(input, q, dim=None)
    分位数;q 为标量或 1D,输出对应分位值;nanquantile 忽略 NaN。
  • torch.logsumexp(input, dim, keepdim=False)
    logexp(x)\log\sum\exp(x),数值稳定;常用于 log-softmax、对数域求和。
  • torch.mode(input, dim=-1)
    沿 dim 取众数,返回 (values, indices)。
  • torch.sort(input, dim=-1, descending=False)
    沿 dim 排序,返回 (values, indices)。
  • torch.topk(input, k, dim=-1)
    沿 dim 取最大的 k 个,返回 (values, indices)。
  • torch.kthvalue(input, k, dim=-1)
    沿 dim 取第 k 小值及索引。
x = torch.tensor([1., 2., 3., 4.])
torch.cumsum(x, dim=0)   # [1, 3, 6, 10] 递进相加
v, i = torch.cummax(x, dim=0)  # values [1,2,3,4], indices [0,1,2,3]
torch.logcumsumexp(torch.randn(5), dim=0)  # (5,) 递进 log-sum-exp
v, i = torch.topk(x, 2)  # values [4,3], indices [3,2]
  • torch.unique(input, sorted=True, return_inverse=False, return_counts=False)
    展平后去重;可返回逆映射、计数。
  • torch.unique_consecutive(input)
    仅去掉连续重复,保持顺序。
  • torch.dist(input, other, p=2)
    两张量间的 p-范数距离(标量);p=2 为欧氏距离。
x = torch.tensor([1, 2, 2, 3, 1])
torch.unique(x)  # [1, 2, 3]
torch.logsumexp(torch.randn(2, 3), dim=1)  # (2,) 数值稳定

与 NumPy 类似:整数、切片 start:stop:step、省略号 ...
注意:高级索引(张量索引、布尔张量)会触发复制而非视图。

x = torch.randn(3, 4, 5)
x[0]        # (4, 5)
x[:, 1:3]   # (3, 2, 5)
x[..., 0]   # (3, 4)
  • torch.index_select(input, dim, index)
    在 dim 维上按 1D 长整型 index 取子张量;输出在 dim 上长度为 len(index)。
  • torch.masked_select(input, mask)
    按与 input 同形的 bool mask 选出 True 位置元素,返回 1D(复制)。
  • torch.take(input, index)
    把 input 视为 1D,按 index 中线性下标取值,输出形状同 index。
  • tensor.put_(index, source, accumulate=False)
    与 take 互逆:把 input 视为 1D,在 index 指定的线性下标位置写入 source 中对应元素;accumulate=True 时累加而非覆盖。用于按「展平下标」批量写回。
  • torch.nonzero(input)
    返回非零元素的坐标,形状为 (N, ndim);与布尔掩码配合取下标、或做稀疏索引。
  • tensor.index_fill_(dim, index, value)
    在 dim 维的 index 指定位置填 value。
  • tensor.masked_fill_(mask, value)
    在 mask 为 True 的位置填 value;常用于把 padding 位置设为 0 或 -inf。
  • tensor.masked_scatter_(mask, source)
    在 mask 为 True 的位置依次从 source 取元素写入 self;source 中参与写入的元素数需 ≥ mask 中 True 的个数。
  • tensor.index_put_(indices, value) / torch.index_put(input, indices, value)
    按多维索引元组 indices(每维一个 1D 索引张量)写入 value;适合多维高级索引写回。
x = torch.randn(3, 4)
idx = torch.tensor([0, 2])
torch.index_select(x, 0, idx)   # (2, 4)
mask = x > 0
torch.masked_select(x, mask)    # 1D
idx = torch.nonzero(mask)       # (N, 2) 坐标
x.masked_fill_(~mask, -1e9)     # 非正位置填 -1e9
flat_idx = torch.tensor([0, 5, 11])
y = torch.zeros(12)
y.put_(flat_idx, torch.tensor([1., 2., 3.]))  # 按线性下标写入

  • torch.gather(input, dim, index, , sparse_grad=False)
    是什么:沿 dim 按 index 张量「聚拢」元素;index 与 input 维数相同,输出形状与 index 相同。
    公式(以 dim=0 为例):out[i][j][k] = input[index[i][j][k]][j][k]
    典型用途:从 logits 中按 target 取对应类别的 logit(如 CE 的 -log(p_true))、one-hot 式取值。
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
idx = torch.tensor([[0, 1, 0], [2, 0, 1]])
torch.gather(x, 0, idx)  # [[1,5,3], [7,2,6]]
  • tensor.scatter_(dim, index, src)
    是什么:gather 的逆——把 src 按 index 写到 tensor 的 dim 维对应位置。
    公式(dim=0):self[index[i][j][k]][j][k] = src[i][j][k]
    用途:one-hot 构造、稀疏更新、将类别下标转为 one-hot 嵌入。
x = torch.zeros(3, 4)
idx = torch.tensor([[0, 1, 2, 0], [1, 2, 0, 2], [2, 0, 1, 1]])
src = torch.ones(3, 4)
x.scatter_(0, idx, src)
  • index_add_(dim, index, source):按 index 在 dim 上把 source 加到 self。
  • index_copy_(dim, index, source):按 index 在 dim 上把 source 拷贝到 self。
  • tensor.scatter_reduce_(dim, index, src, reduce=‘sum’|‘mean’|‘amax’|‘amin’)
    与 scatter_ 类似,但同一下标被多次写入时按 reduce 归约(求和、平均、取大、取小);适合稀疏梯度聚合等。
  • tensor.index_reduce_(dim, index, source, reduce, include_self=True)
    沿 dim 按 index 将 source 归约写回 self;reduce 同上;include_self 控制是否先保留 self 再与 source 归约。

  • eq / ne / gt / ge / lt / le(或 == != > >= < <=
    逐元素比较,返回 bool 张量;可与 where 结合做条件赋值。
  • torch.where(condition, x, y)
    condition 为 True 取 x,否则取 y;x/y 可广播到同一形状。
    用途:裁剪、阈值、条件分支(如梯度裁剪、ReLU 式选择)。
a = torch.randn(2, 3)
b = torch.where(a > 0, a, torch.zeros_like(a))
  • torch.logical_and / logical_or / logical_not / logical_xor
    逐元素逻辑运算,常用于组合多个条件掩码。
  • torch.isfinite(input) / torch.isinf(input) / torch.isnan(input)
    逐元素判断是否有限、无穷、NaN;调试与掩码常用。
  • torch.isposinf(input) / torch.isneginf(input)
    是否为正无穷、负无穷。
  • torch.isclose(a, b, rtol=1e-5, atol=1e-8)
    逐元素判断是否「近似相等」:|a - b| ≤ atol + rtol * |b|
  • torch.isreal(input)
    逐元素是否实数(复数时虚部为 0 则为 True)。
  • torch.equal(a, b)
    若 a、b 同形且逐元相等则返回 True(单布尔);与 (a == b).all() 等价但更高效。
  • torch.allclose(a, b, rtol=1e-5, atol=1e-8)
    整体是否「近似相等」:所有元素满足 isclose 则 True;常用于数值测试。
  • torch.is_tensor(obj)
    是否为 PyTorch 张量。
  • tensor.is_floating_point() / tensor.is_complex()
    元素类型是否为浮点或复数;用于分支选择(如是否做归一化)。
x = torch.tensor([1.0, float('nan'), float('inf')])
torch.isfinite(x)  # [True, False, False]
torch.isnan(x)     # [False, True, False]
torch.isclose(torch.tensor(1.0), torch.tensor(1.0 + 1e-6))  # True
torch.equal(torch.ones(2, 3), torch.ones(2, 3))  # True
torch.allclose(torch.randn(3), torch.randn(3) + 1e-7)  # 可能 True

  • tensor.clone():拷贝一份,新存储,梯度可独立回传。
  • tensor.expand(*sizes):仅把大小为 1 的维扩展,不拷贝数据。
  • tensor.repeat(*repeats):按 repeats 各维重复,真实复制。
x = torch.randn(1, 3)
x.expand(4, 3)   # (4,3) 视图式
x.repeat(4, 2)   # (4, 6) 复制
  • tensor.narrow(dim, start, length)
    在 dim 维从 start 取 length 个,返回视图。
  • torch.einsum(equation, operands)
    是什么:用爱因斯坦求和约定写张量运算,一次表达乘法与求和维度。
    用途:矩阵乘、外积、迹、对角、batch 乘等,公式简洁。
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.einsum("ij,jk->ik", A, B)  # 矩阵乘
x = torch.randn(4)
y = torch.einsum("i,i->", x, x)     # 内积
  • torch.roll(input, shifts, dims)
    沿 dims 循环移位。
  • torch.clamp(input, min=None, max=None)
    将元素限制在 [min, max];可用于梯度裁剪、ReLU 式截断。
  • torch.clamp_min(input, min) / torch.clamp_max(input, max)
    仅设下限或上限。
  • torch.addcdiv(input, tensor1, tensor2, value=1)
    input + value * (tensor1 / tensor2),逐元素;常用于带除法的更新。
  • torch.addcmul(input, tensor1, tensor2, value=1)
    input + value * (tensor1 * tensor2),逐元素。
  • torch.tril(input, diagonal=0) / torch.triu(input, diagonal=0)
    取下三角/上三角(含对角线);diagonal 控制偏移。
  • torch.renorm(input, p, dim, maxnorm)
    沿 dim 将子向量重新归一化,使其 p-范数不超过 maxnorm;常用于梯度稳定。
x = torch.randn(3, 3)
torch.tril(x)   # 下三角,其余为 0
torch.triu(x, diagonal=1)  # 严格上三角
  • tensor.to(device=None, dtype=None)
    设备或 dtype 迁移;.cpu(), .cuda(device=None) 为常用简写。
  • tensor.numpy()
    转 NumPy 数组(与 CPU 张量共享内存时注意 inplace 修改)。
  • tensor.tolist()
    转 Python 嵌套 list。
  • tensor.item()
    单元素张量取 Python 标量(如 loss.item())。
  • tensor.fill_(value) / tensor.zero_()
    原地填常量或置零。
  • tensor.copy_(src)
    从 src 拷贝到 self(形状需可广播或一致)。
  • torch.real(input) / torch.imag(input)
    取复数的实部/虚部。
  • torch.angle(input)
    幅角。
  • torch.conj(input)
    共轭。
  • torch.special 命名空间下有 erf, erfc, gammaln, multigammaln 等;科学计算与分布时常用。
  • torch.diff(input, n=1, dim=-1)
    沿 dim 做 n 阶差分;一阶即相邻相减,输出长度减 n。
  • torch.broadcast_tensors(tensors)
    将多个张量广播到相同形状后返回元组;在手动广播前做形状对齐时有用。
  • torch.atleast_1d(input) / atleast_2d / atleast_3d
    保证至少 1/2/3 维;标量变 (1,),1D 保持,2D 保持。
  • torch.bucketize(input, boundaries, right=False)
    将 input 按 boundaries 分桶,返回每个元素所属桶的下标(整数);right=True 时右闭。
  • torch.searchsorted(sorted_sequence, value, right=False)
    在有序 1D 张量 sorted_sequence 中找 value 的插入位置,保持有序;用于分位数、插值。
x = torch.tensor([1., 2., 4., 7.])
torch.diff(x)  # [1., 2., 3.]
a, b = torch.broadcast_tensors(torch.randn(3, 1), torch.randn(1, 4))  # 均为 (3, 4)
torch.bucketize(torch.tensor([0.5, 2.5, 5.]), torch.tensor([1., 2., 3., 4.]))  # 桶下标
  • tensor.resize_(sizes) / tensor.resize_as_(other)
    原地修改形状;可改变元素总数(多则未初始化、少则截断),与 view 不同。多用于与 C++/Legacy 接口兼容,新代码优先用 view/reshape。
  • torch.result_type(tensors_and_dtypes) / torch.promote_types(type1, type2)
    根据类型提升规则得到运算结果 dtype;混合 int/float 或不同精度时,用于事先确定输出类型或做类型检查。
  • tensor.uniform_(from=0, to=1) / tensor.normal_(mean=0, std=1)
    原地用均匀分布或正态分布填充;初始化权重、生成噪声时常用。
  • tensor.exponential_(lambd=1) / tensor.bernoulli_(p=0.5)
    原地用指数分布或伯努利分布填充;采样、Dropout 掩码等。
x = torch.empty(2, 3)
x.uniform_(-0.1, 0.1)   # 原地均匀初始化
torch.result_type(torch.tensor(1), torch.tensor(1.0))  # torch.float32

类别常用 API
创建tensor, zeros, ones, arange, linspace, randn, from_numpy
形状view, reshape, squeeze, unsqueeze, flatten, transpose, permute, contiguous
拼接/拆分cat, stack, chunk, split, unbind
逐元/广播add, sub, mul, div;pow, sqrt, rsqrt, neg, reciprocal, sign;minimum, maximum;fmod, remainder;lerp, frac;atan2, hypot;exp, log, sin, cos, round, floor, ceil
矩阵matmul, mm, bmm, mv, dot, vdot;ger, outer, cross;trace, linalg.inv, solve, det, svd, cholesky, qr, eig;tensordot, meshgrid
归约sum, mean, max, min, prod, argmax, argmin, norm, std, var;any, all, count_nonzero;cumsum, cumprod, cummax, cummin, logcumsumexp;median, mode;nanmean, nansum, quantile, logsumexp, dist;sort, topk, kthvalue;unique
索引[], index_select, masked_select, take, put_;nonzero;index_fill_, masked_fill_, masked_scatter_, index_put_
高级gather, scatter_, index_add_, index_copy_, scatter_reduce_, index_reduce_
比较逻辑eq, ne, gt, where, logical_and/or/not;isfinite, isnan, isinf, isclose;equal, allclose;is_tensor, is_floating_point, is_complex
其它clone, expand, repeat, narrow, einsum, roll, clamp, addcdiv, addcmul, tril, triu, renorm;real, imag, angle, conj;to/cpu/cuda, numpy/tolist/item, fill_/zero_/copy_;diff, broadcast_tensors, atleast_1d/2d/3d, bucketize, searchsorted;resize_/resize_as_, result_type/promote_types;uniform_/normal_/exponential_/bernoulli_;torch.special

以下 API 在写代码时常用,若在正文未单独展开,可在此按类查缺补漏。

类别易遗漏的 API用途简述
创建full_like, empty_like与输入同形填常量或未初始化;torch.complex(real, imag) 构造复数张量。
矩阵/张量torch.tensordot(a, b, dims)指定维做内积收缩;高阶张量乘、注意力收缩。
torch.meshgrid(*tensors, indexing='ij')由 1D 坐标生成网格;绘图、插值、物理场。
形状/视图tensor.t()2D 转置 shorthand(仅 2D)。
tensor.unfold(dim, size, step)沿 dim 滑动窗口,得到局部块;做 patch、局部卷积时常用。
torch.movedim(input, source, destination)把 source 维移到 destination;比 permute 更直观地「挪一维」。
torch.swapdims(input, dim0, dim1)交换两维,transpose 的别名。
索引/选数tensor.put_(index, source, accumulate=False)按线性下标写回,与 take 互逆;accumulate 为 True 时累加。
torch.nonzero(input) / input.nonzero()返回非零元素的坐标张量 (N, ndim);与 mask 配合取下标。
tensor.index_fill_(dim, index, value)在 dim 维的 index 位置填 value。
tensor.masked_fill_(mask, value)在 mask 为 True 处填 value;常用于 padding 掩码。
tensor.masked_scatter_(mask, source)按 mask 从 source 依次取数填到 self。
Scatter`tensor.scatter_reduce_(dim, index, src, reduce=‘sum’‘mean’
逐元log2, log10, exp2, expm1以 2/10 为底对数、2^x、exp(x)-1。
floor_divide / //向下取整除;`div(…, rounding_mode=‘floor’
归约cummax, cummin递进取大/取小,返回 (values, indices)。
logcumsumexp递进 log-sum-exp,数值稳定。
nanmean, nansum忽略 NaN 的均值/和。
quantile, nanquantile分位数。
比较/判断torch.equal(a, b)两张量逐元相等且同形则 True(单布尔)。
torch.allclose(a, b, rtol, atol)整体近似相等,常用于数值测试。
is_tensor, is_floating_point, is_complex类型/数值类型判断。
写回tensor.index_put_(indices, value)按多维索引元组写 value。
tensor.index_reduce_(dim, index, source, reduce)沿 dim 按 index 将 source 归约写回。
设备与转换`tensor.to(devicedtype), .cpu(), .cuda()`
tensor.numpy(), tensor.tolist(), tensor.item()转 NumPy、转 Python list、标量取 Python 标量。
其它tensor.resize_(*sizes) / resize_as_(other)原地改形状(可改元素总数);慎用,优先 view/reshape。
torch.result_type(...) / torch.promote_types(t1, t2)类型提升,得到运算结果 dtype。
tensor.uniform_(from, to) / normal_(mean, std)原地均匀/正态填充;权重初始化、噪声。
tensor.exponential_(lambd) / bernoulli_(p)原地指数/伯努利填充;采样、Dropout 等。
torch.diff(input, n=1, dim=-1)沿 dim 做 n 阶差分。
torch.broadcast_tensors(*tensors)把多个张量广播到相同形状后返回。
vsplit, hsplit, dsplit沿第 0/1/2 维 split 的便捷写法(2D/3D)。
atleast_1d, atleast_2d, atleast_3d保证至少 1/2/3 维。
tensor.fill_(value), zero_(), copy_(src)原地填值、置零、从 src 拷贝。
torch.bucketize(input, boundaries)将 input 按 boundaries 分桶,返回桶下标。
torch.searchsorted(sorted, value)在有序序列中找 value 的插入位置。

正文已覆盖绝大部分「创建、形状、拼接、逐元、矩阵、归约、索引、gather/scatter、比较、einsum」;上表用于查缺补漏与速查,需要时查阅官方 torchtorch.Tensor 文档。

PyTorch 张量工具函数覆盖「创建 → 变形 → 拼接/拆分 → 运算 → 归约 → 索引」全流程;掌握 cat/stack、view/reshape、广播、matmul、归约 dim/keepdim、gather/scatter 和 where,即可覆盖绝大多数模型与数据脚本中的张量操作。需要时以本文索引与速查表按类查找具体 API 与示例。

相关内容