pytorch基础到提高(2)-Tensor(2)

torch.Tensor 是torch.FloatTensor的别名。
tensor可用Python list或sequence 使用torch.tensor()进行构造。

import torch
x=torch.tensor([[10,20],[2,4]])
print(x)
tensor([[10, 20],
        [ 2,  4]])
import torch
a=[[10.2,20.6],[2,4]]
x=torch.DoubleTensor(a)
print(x)
y=torch.IntTensor(a)
print(y)
tensor([[10.2000, 20.6000],
        [ 2.0000,  4.0000]], dtype=torch.float64)
tensor([[10, 20],
        [ 2,  4]], dtype=torch.int32)
<ipython-input-14-f9db31b67daa>:5: DeprecationWarning: an integer is required (got type float).  Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.
  y=torch.IntTensor(a)

torch.tensor()始终复制数据。如果您有一个 Tensor数据,并且只想更改它的requires_grad标志,可使用requires_grad_()或 detach() 防止拷贝。
如果您有一个numpy数组并希望避免复制 ,可使用torch.as_tensor()。

import torch
import numpy as np 
a = np.arange(8)
b = a.reshape(4,2)
print (b)
y=torch.torch.as_tensor(b)
print(y)
y[1][1]=55
print(y)
print(b)
[[0 1]
 [2 3]
 [4 5]
 [6 7]]
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
tensor([[ 0,  1],
        [ 2, 55],
        [ 4,  5],
        [ 6,  7]])
[[ 0  1]
 [ 2 55]
 [ 4  5]
 [ 6  7]]
import torch
y=torch.zeros([2, 4], dtype=torch.int32)
print(y)
tensor([[0, 0, 0, 0],
        [0, 0, 0, 0]], dtype=torch.int32)
已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 鲸 设计师:meimeiellie 返回首页