Pytorch & Numpy 对照表

Pytorch 专栏收录该内容
17 篇文章 1 订阅

PyTorch 里面处理的最基本的操作对象就是 Tensor(张量),表示的是一个多维的矩阵,比如零维就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维的数组,这和 numpy 是对应的,而且 PyTorch 的 tensor 可以和 numpy 的 ndarray 相互转换,

唯一不同的是 PyTorch 可以在 GPU 上运行
而 numpy 的 ndarray 只能在 CPU 上运行

前言

我们先介绍一下一些常用的不同数据类型的Tensor( torch.Tensor默认的是 torch.FloatTensor数据类型):

  • 16位整型 torch.ShortTensor
  • 32位整型 torch.IntTensor
  • 64位整型 torch.LongTensor
  • 16位浮点型 torch.HalfTensor
  • 32位浮点型 torch.FloatTensor
  • 64位浮点型 torch.DoubleTensor
  • torch.Tensor 默认 torch.FloatTensor

类型(Types)

NumpyPyTorch
np.ndarraytorch.Tensor
np.float32torch.float32; torch.float
np.float64torch.float64; torch.double
np.float16torch.float16; torch.half
np.int8torch.int8
np.uint8torch.uint8
np.int16torch.int16; torch.short
np.int32torch.int32; torch.int
np.int64torch.int64; torch.long

索引

NumpyPyTorch
x[0]x[0]
x[:, 0]x[:, 0]
x[indices]x[indices]
np.take(x, indices)torch.take(x, torch.LongTensor(indices))
x[x != 0]x[x != 0]

ones and zeros

NumpyPyTorch
np.empty((2, 3))torch.empty(2, 3)
np.empty_like(x)torch.empty_like(x)
np.eyetorch.eye
np.identitytorch.eye
np.onestorch.ones
np.ones_liketorch.ones_like
np.zerostorch.zeros
np.zeros_liketorch.zeros_like

从已知数据构造

NumpyPyTorch
np.array([[1, 2], [3, 4]])torch.tensor([[1, 2], [3, 4]])
np.array([3.2, 4.3], dtype=np.float16) , np.float16([3.2, 4.3])torch.tensor([3.2, 4.3], dtype=torch.float16)
x.copy()x.clone()
np.fromfile(file)torch.tensor(torch.Storage(file))
np.frombuffer
np.fromfunction
np.fromiter
np.fromstring
np.loadtorch.load
np.loadtxt
np.concatenatetorch.cat

数值范围

NumpyPyTorch
np.arange(10)torch.arange(10)
np.arange(2, 3, 0.1)torch.arange(2, 3, 0.1)
np.linspacetorch.linspace
np.logspacetorch.logspace

构造矩阵

NumpyPyTorch
np.diagtorch.diag
np.triltorch.tril
np.triutorch.triu

参数

NumpyPyTorch
x.shapex.shape
x.stridesx.stride()
x.ndimx.dim()
x.datax.data
x.sizex.nelement()
x.dtypex.dtype

形状(Shape)变换

NumpyPyTorch
x.reshapex.reshape; x.view
x.resize()x.resize_
x.resize_as_
x.transposex.transpose or x.permute
x.flattenx.view(-1)
x.squeeze()x.squeeze()
x[:, np.newaxis]; np.expand_dims(x, 1)x.unsqueeze(1)

数据选择

NumpyPyTorch
np.put
x.putx.put_
x = np.array([1, 2, 3]) x.repeat(2) # [1, 1, 2, 2, 3, 3]x = torch.tensor([1, 2, 3]) x.repeat(2) # [1, 2, 3, 1, 2, 3] x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1) # [1, 1, 2, 2, 3, 3]
np.tile(x, (3, 2))x.repeat(3, 2)
np.choose
np.sort sorted, indices = torch.sort(x, [dim])
np.argsort sorted, indices = torch.sort(x, [dim])
np.nonzerotorch.nonzero
np.wheretorch.where
x[::-1]

数值计算

NumpyPyTorch
x.minx.min
x.argminx.argmin
x.maxx.max
x.argmaxx.argmax
x.clipx.clamp
x.roundx.round
np.floor(x)torch.floor(x); x.floor()
np.ceil(x)torch.ceil(x); x.ceil()
x.tracex.trace
x.sumx.sum
x.cumsumx.cumsum
x.meanx.mean
x.stdx.std
x.prodx.prod
x.cumprodx.cumprod
x.all(x == 1).sum() == x.nelement()
x.any(x == 1).sum() > 0

数值比较

NumpyPyTorch
np.lessx.lt
np.less_equalx.le
np.greaterx.gt
np.greater_equalx.ge
np.equalx.eq
np.not_equalx.ne
  • 2
    点赞
  • 0
    评论
  • 18
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

相关推荐
©️2020 CSDN 皮肤主题: 岁月 设计师:pinMode 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值