TypeError: can‘t convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory fi

TypeError: can‘t convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory fiRuntimeError:Expectedobjectoftypetorch.cuda.FloatTensorbutfoundtypetorch.FloatTensorforargument#4’mat1’意思是:如果想把CUDAtensor格式的数据改成numpy时,需要先将其转换成cpufloat-tensor随后再转到numpy格式。numpy不能读取CU…

大家好,又见面了,我是你们的朋友全栈君。如果您正在找激活码,请点击查看最新教程,关注关注公众号 “全栈程序员社区” 获取激活教程,可能之前旧版本教程已经失效.最新Idea2022.1教程亲测有效,一键激活。

Jetbrains全系列IDE稳定放心使用

运行程序如下:

import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt


class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression,self).__init__()
        self.linear = nn.Linear(1,1)
    def forward(self, x):
        out = self.linear(x)
        return out
x_train = np.array([[3.3],[4.4],[5.5],[6.710],[6.93],[4.168],[9.779],[6.182],[7.59],[2.167],[7.042],[10.791],[5.313],[7.997],[3.1]],dtype=np.float32)
y_train = np.array([[1.7],[2.76],[2.09],[3.19],[1.694],[1.573],[3.366],[2.596],[2.53],[1.221],[2.827],[3.465],[1.65],[2.904],[1.3]],dtype=np.float32)
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)
num_epochs = 1000
if torch.cuda.is_available():
    print("GPU1")
    model = LinearRegression().cuda()
else:
    print("CPU1")
    model = LinearRegression()

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr=1e-3)


for epoch in  range(num_epochs):
    if torch.cuda.is_available():
        print('GPU2')
        inputs = Variable(x_train).cuda()
        target = Variable(y_train).cuda()
    else:
        print("CPU2")
        inputs = Variable(x_train)
        target = Variable(y_train)
    # forward
    out = model(inputs)
    loss = criterion(out,target)
    # backward
    optimizer.zero_grad()  # 梯度归零
    loss.backward()   # 反向传播
    optimizer.step()  # 更新参数
    # if (epoch+1) % 20 ==0:
    #     print('Epoch[{}/{}], loss:{:,6f}'.format(epoch+1,num_epochs,loss.data[0]))


model.eval()

predict = model(Variable(x_train.cuda()))  
predict = predict.data.numpy()
plt.plot(x_train.numpy(),y_train.numpy(),'ro',label='Original data')
plt.plot(x_train.numpy(),predict,label='Fitting Line')
plt.show()

这行报错:predict = predict.data.numpy()

TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

意思是:如果想把CUDA tensor格式的数据改成numpy时,需要先将其转换成cpu float-tensor随后再转到numpy格式。 numpy不能读取CUDA tensor 需要将它转化为 CPU tensor
predict.data.numpy() 改为predict.data.cpu().numpy()即可

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/180596.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • yui3:widget

    yui3:widgetWidget类包含什么?widget类的结构和职能widget类基本的属性渲染方法渐进增强标签结构class名和CSS默认UI事件类的结构和职责  Widget类的结构和职责 Widget类是一个用于创建widgets的基础类。Widget类可以实例化,但是一般都是用它作为基础类,扩展创建widgets,这些通过扩展创…

  • VS code安装和使用技巧

    VS code安装和使用技巧VSCode是微软提供的一款轻量级但功能十分强大的编辑器,内置了对JavaScript,TypeScript和Node.js语言的支持,并且为其他语言如C++,C#,Python,PHP等提供了丰富的扩展库和运行时。一:VSCode的安装(去下载),1.1:VSCode的当前版本为1.18,支持Windows,Ubuntu,Mac1.2:安装VS

  • proxmox集群节点崩溃处理

    proxmox集群节点崩溃处理

  • JavaScript之正则表达式的使用方法详细介绍[通俗易懂]

    JavaScript之正则表达式的使用方法详细介绍[通俗易懂]首先必须说明的是,这类文章(js正则表达式)在c站或者整个it类论坛是很多人写过的,而我认为我这篇的不同之处在于更加“小白”化,这也与我一贯的风格有关吧。关于JavaScript正则表达式,其他的文章大多一上来就太过激进,不利于初学者学习(我当粗就是这么被劝退的),这也是我为什么要坚持写这篇文章,希望小白在看了这篇文章后,不管能不能完全掌握JavaScript正则表达式,但至少对JavaScript正则表达式能有一个比较深刻的印象吧。

    2022年10月24日
  • 必读,sql加索引调优案例和explain extended说明

    做一个积极的人编码、改bug、提升自己我有一个乐园,面向编程,春暖花开!昨天分享了Mysql中的 explain 命令,使用 explain 来分析 select 语句的运行效果,如 :explain可以获得select语句使用的索引情况、排序的情况等等。链接:顺便提到了explain extended,有小伙伴留言说想知道一些explain extended,那今天就在简单讲解一下。…

  • nested exception is java.lang.NoClassDefFoundError: org/codehaus/jettison/json/JSONObject异常的解决办法

    nested exception is java.lang.NoClassDefFoundError: org/codehaus/jettison/json/JSONObject异常的解决办法nestedexceptionisjava.lang.NoClassDefFoundError:org/codehaus/jettison/json/JSONObject异常的解决办法

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号