PyTorch学习(四)--用PyTorch实现线性回归

教程视频:https://www.bilibili.com/video/BV1tE411s7QT

废话不多说,代码如下:

import  torch

x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[2.0],[4.0],[6.0]])

class LinearModel(torch.nn.Module):
    def __init__(self):#构造函数
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)#构造对象,并说明输入输出的维数,第三个参数默认为true,表示用到b
    def forward(self, x):
        y_pred = self.linear(x)#可调用对象,计算y=wx+b
        return  y_pred

model = LinearModel()#实例化模型

criterion = torch.nn.MSELoss(size_average=False)
#model.parameters()会扫描module中的所有成员,如果成员中有相应权重,那么都会将结果加到要训练的参数集合上
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)#lr为学习率

for epoch in range(1000):
    y_pred = model(x_data)
    loss = criterion(y_pred,y_data)
    print(epoch,loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)

结果:

0 56.52023696899414
1 25.170454025268555
2 11.214292526245117
3 5.001270771026611
4 2.2352840900421143
5 1.0038176774978638
6 0.45547759532928467
7 0.21124869585037231
8 0.10240332782268524
9 0.05382827669382095
10 0.03208546340465546
……
90 0.004652736708521843
91 0.004585907328873873
92 0.004519954323768616
93 0.00445501459762454
94 0.004390999674797058
95 0.004327872302383184
96 0.004265678580850363
97 0.004204379860311747
98 0.004143938422203064
99 0.00408441387116909
w= 2.042545795440674
b= -0.09671643376350403
y_pred = tensor([[8.0735]])

不同优化器,他们的性能在使用上有什么区别?直接看图
以下包含了Adagrad Adam adamax ASGD RMSprop Rprop SGD七种优化器的loss下降图。其实还有一种优化器LBFGS,使用时需要传递闭包等等,我会在之后补上,暂时没有。
在这里插入图片描述

小知识点:可调用对象

如果要使用一个可调用对象,那么在类的声明的时候要定义一个 call()函数就OK了,就像这样

class Foobar:
	def __init__(self):
		pass
	def __call__(self,*args,**kwargs):
		pass

其中参数*args代表把前面n个参数变成n元组,**kwargsd会把参数变成一个词典,举个例子:

 def func(*args,**kwargs):
 	print(args)
 	print(kwargs)

#调用一下
func(1,2,3,4,x=3,y=5)

结果:
(1,2,3,4)
{‘x’:3,‘y’:5}

版权声明:本文为weixin_44841652原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_44841652/article/details/105068509

智能推荐

Rtthread学习笔记(十三)RT-Thread Studio开启硬件看门狗Watchdog

一、开启硬件看门狗Watchdog 1、配置RT-Thread Settings 2、开启stm32f1xx_hal_conf.h中的宏定义 3.使用RT接口函数初始化硬件看门狗...

TYVJ 4864 天天去哪吃 || 清北学堂金秋杯大奖赛

题目描述: 记录一下i这个值上次出现的位置在哪里,就是pre...

java反编译

jvm 把Boolean类型的值flag当做int类型处理。​​​ Foo.java: 由 class 文件生成 jasm 文件:java -jar asmtools.jar jdis Foo.class > Foo.jasm  修改jasm文件: 执行反编译: java -jar jd-gui-1.6.6.jar File 打开Foo.class文件:b修改为2 重新执行java...

【学习笔记】03-v-html的学习和示例

v-html的认识和使用 示例: 显示结果: 注意:v-html是有复制的...

Java实现在线考试系统(系统介绍)

1.和现在有的考试系统有以下几种优势: a.和现在有的系统比较起来,本系统有科目、章节、老师、学生、班级等信息的管理,还有批阅试卷查看已批阅试卷等。传统的考试系统划分并不细,业务功能简单。 b.和学校的考试系统还有外面的考试系统比较起来,本系统是B/S结构,学校的考试系统一般为C/S结构,性能方面不如B/S结构,并且C/S接口需要安装客户端,客户端压力很大,我的系统只需要电脑具有浏览器,在同一局域...

猜你喜欢

计算机视觉--多视几何初步尝试

基础矩阵的原理 K和K’分别是两个相机的参数矩阵。p和p’是X在平面π的坐标表示。所以可以得出 具体计算过程 代码: #!/usr/bin/env python coding: utf-8 from PIL import Image from numpy import * from pylab import * import numpy as np from imp ...

java初学者怎么学习才可以快速入门

java初学者怎么学习才可以快速入门 一、了解JAVA 我们要知道:Java是由Sun Microsystems公司于1995年5月推出的Java面向对象程序设计语言。 Java之父:詹姆斯·高斯林 1.1 java的三个体系 Java SE(Java Platform Standard Edition)。Java SE 以前称为 J2SE。它允许开发和部署在桌面、服务器、嵌入式环境...

字段属性之主键&增删改查&自增长&唯一键约束

字段属性之主键&自增长&唯一键约束 主键 主键:primary key 主要的键 一张表中只有一个字段可以使用对应的键,用来唯一的约束该字段里面的数据,不能重复,这种称之为主键 一张表只能最多一个主键 增加主键 SQL操作中有多种方式增加主键大体分为三种 1.在创建表的时候直接在字段之后跟primary key关键字(主键本身不允许为空) 优点:非常直接:缺点:只能使用一个字段作为...

linux下 基于libmad的socket多用户mp3音频在线播放服务器

在众多大神的帮助下,这个在线播放流媒体服务器终于完成啦。。。。 这个mp3流媒体服务器设计的思路是,服务器程序server用多线程实现和多个客户端的通信(这是必然的),然后发送给客户端当前的音频列表公客户端选择,之后根据k客户端的选择给多个客户端传输相应mp3文件的数据,同时,客户端进行实时地音频解码并播放。 关于libmad开源mp3音频解码库的使用,见上一篇博客吧。。。。 在服务器程序这一端,...

Nginx

Nginx Nginx简介: Nginx是一个高性能的http和反向代理服务器,特点是有内存少,并发能力强,事实上Nginx的并发能力确实在同类型网页服务器中表现较好, Nginx用作web服务器:Nginx可以作为静态页面的web服务器,同时还支持CGI语言,但不支持java,java程序只能通过Tomcat配合完成。Nginx专为性能优化而开发,性能是其最重要的考量,实现上非常注重效率,能经受...