[机器学习-回归算法]Sklearn之线性回归实战

标签: Sklearn  机器学习

一,前言

一元线性回归的理论片请看我这个链接

二,热身例子

预测直线 y=1x1+2x2+3y = 1x_1 + 2x_2 +3

导入LinearRegression 从Sklearn.liear_model 包里

from sklearn.linear_model import LinearRegression

拟合数据也可以说是训练

reg = LinearRegression().fit(X, y)

检验正确率

print(reg.score(X, y))

训练的系数,也就是X前面的那个系数,这里打印出 [1. 2.]

print(reg.coef_)

直线的b的系数(其实就是偏置系数), 打印出3

print(reg.intercept_)

完整代码

import numpy as np
from sklearn.linear_model import LinearRegression
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
# y = 1 * x_0 + 2 * x_1 + 3
y = np.dot(X, np.array([1, 2])) + 3

print(X)
print(y)
reg = LinearRegression().fit(X, y)
print(reg.score(X, y))
print(reg.coef_)
print(reg.intercept_)
print(reg.predict(np.array([[3, 5]])))
[[1 1]
 [1 2]
 [2 2]
 [2 3]]
[ 6  8  9 11]
1.0
[1. 2.]
3.0000000000000018
[16.]

三,贸易公司的简单例子

在这里插入图片描述
可见随着广告费用的增加,公司的销售额也在增加,但是它们并非绝对的 线性关系,而是趋向于平均
我们用线性拟合一下这个数据吧。

import numpy as np
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, r2_score

data = np.array([[10, 19],[13,60],[22,71],[37,74],[45,69],[48,89],[59,146]
                 ,[65,130],[66,153],[68,144],[68,128],[71,123],[84,127]
                 ,[88,125],[89,154],[89,150]])
X = data[:,np.newaxis, 0]
y = data[:,1]

print(X)
print(y)
reg = LinearRegression().fit(X, y)
print(reg.score(X, y))
print(reg.coef_)
print(reg.intercept_)
diabetes_y_pred = reg.predict(X)

print('Mean squared error: %.2f'  % mean_squared_error(y, diabetes_y_pred))
# The coefficient of determination: 1 is perfect prediction
print('Coefficient of determination: %.2f'  % r2_score(y, diabetes_y_pred))

plt.scatter(data[:,0], data[:,1],  color='black')
print('y='+str(reg.coef_[0]) +'*x + ' + str(reg.intercept_) )
plt.plot(data[:,0], reg.coef_*data[:,0] + reg.intercept_, color='blue', linewidth=3)
plt.show()
[[10]
 [13]
 [22]
 [37]
 [45]
 [48]
 [59]
 [65]
 [66]
 [68]
 [68]
 [71]
 [84]
 [88]
 [89]
 [89]]
[ 19  60  71  74  69  89 146 130 153 144 128 123 127 125 154 150]
0.7861129941287246
[1.37939644]
30.637280329657003
Mean squared error: 333.71
Coefficient of determination: 0.79
y=1.37939643679554*x + 30.637280329657003

拟合直线的表达是y=1.37939643679554*x + 30.637280329657003,,其中x表示广告费用,y表示销 售额,通过线性回归的公式就可以预测企业的销售额了。结果还可以 R2 = 0.79
在这里插入图片描述

四,Sklearn 官网里的一个例子

diabetes这个数据集一共有442个样本,每个样本有10 个特征。
我们选两个特征用线性来拟合它。


import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score

# Load the diabetes dataset
diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)

print('diabetes_X', diabetes_X.shape)
print('diabetes_y', diabetes_y.shape)

# Use only one feature
diabetes_X = diabetes_X[:, np.newaxis, 2]

print('diabetes_X', diabetes_X.shape)

# Split the data into training/testing sets
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]

# Split the targets into training/testing sets
diabetes_y_train = diabetes_y[:-20]
diabetes_y_test = diabetes_y[-20:]

# Create linear regression object
regr = linear_model.LinearRegression()

# Train the model using the training sets
regr.fit(diabetes_X_train, diabetes_y_train)

# Make predictions using the testing set
diabetes_y_pred = regr.predict(diabetes_X_test)

# The coefficients
print('Coefficients: \n', regr.coef_)
# The mean squared error
print('Mean squared error: %.2f' % mean_squared_error(diabetes_y_test, diabetes_y_pred))
# The coefficient of determination: 1 is perfect prediction
print('Coefficient of determination: %.2f'% r2_score(diabetes_y_test, diabetes_y_pred))

# Plot outputs
plt.scatter(diabetes_X_test, diabetes_y_test,  color='black')
plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)

plt.xticks(())
plt.yticks(())

plt.show()
diabetes_X (442, 10)
diabetes_y (442,)
diabetes_X (442, 1)
Coefficients: 
 [938.23786125]
Mean squared error: 2548.07
Coefficient of determination: 0.47

R2 验证结果是0.47,结果不是很好,不过无所谓,我们就当学习而已,毕竟我们只是选择了10个里面的两个特征。
在这里插入图片描述

参考资料

[1]https://scikit-learn.org/stable/auto_examples/linear_model/plot_ols.html#sphx-glr-auto-examples-linear-model-plot-ols-py
[2] https://www.cnblogs.com/wuliytTaotao/p/10837533.html

版权声明:本文为keeppractice原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/keeppractice/article/details/106529867

智能推荐

Intellij IDEA 搭建Spring Boot项目(一)

Intellij IDEA 搭建Spring Boot项目 标签(空格分隔): SpringBoot JAVA后台 第一步 选择File –> New –> Project –>Spring Initialer –> 点击Next  第二步 自己修改 Group 和 Artif...

CentOS学习之路1-wget下载安装配置

参考1: https://blog.csdn.net/zhaoyanjun6/article/details/79108129 参考2: http://www.souvc.com/?p=1569 CentOS学习之路1-wget下载安装配置 1.wget的安装与基本使用 安装wget yum 安装软件 默认安装保存在/var/cache/yum ,用于所有用户使用。 帮助命令 基本用法 例子:下载...

深入浅出Spring的IOC容器,对Spring的IOC容器源码进行深入理解

文章目录 DispatcherServlet整体继承图 入口:DispatcherServlet.init() HttpServletBean.init() FrameworkServlet.initServletBean() 首先大家,去看Spring的源码入口,第一个就是DispatcherServlet DispatcherServlet整体继承图 入口:DispatcherServlet....

laravel框架的课堂知识点概总

1. MVC 1.1 概念理解 MVC全名是Model View Controller,是模型(model)-视图(view)-控制器(controller)的缩写,一种软件设计典范,用一种业务逻辑、数据、界面显示分离的方法组织代码,将业务逻辑聚集到一个部件里面,在改进和个性化定制界面及用户交互的同时,不需要重新编写业务逻辑 MVC 是一种使用 MVC(Model View Controller ...

Unity人物角色动画系统学习总结

使用动画系统控制人物行走、转向、翻墙、滑行、拾取木头 混合树用来混合多个动画 MatchTarget用来匹配翻墙贴合墙上的某一点,人物以此为支点翻墙跳跃 IK动画类似于MatchTarget,控制两只手上的两个点来指定手的旋转和位置,使得拾取木头时更逼真 创建AnimatorController: 首先创建一个混合树,然后双击 可以看到该混合树有五种状态机,分别是Idle、WalkForward、...

猜你喜欢

Composer 安装 ThinkPHP6 问题

Composer 安装 ThinkPHP6 问题 先说说问题 一.运行环境要求 二.配置 参考: ThinkPHP6.0完全开发手册 先说说问题 执行ThinkPHP6的安装命令 遇到问题汇总如下: 看提示是要更新版本,执行命令更新。 更新之后,再次安装ThinkPHP,之后遇到如下问题。 尝试了很多方法,依然不能解决。其中包括使用https://packagist.phpcomposer.com...

Spring Boot 整合JDBC

今天主要讲解一下SpringBoot如何整合JDBC,没啥理论好说的,直接上代码,看项目整体结构 看一下对应的pom.xml 定义User.java 定义数据源配置,这里使用druid,所以需要写一个配置类 上面指定druid的属性配置,和用户登录的账号信息以及对应的过滤规则: 下面定义数据访问接口和对应的实现: 数据访问层很简单,直接注入JdbcTemplate模板即可,下面再看对应的servi...

html鼠标悬停显示样式

1.显示小手:     在style中添加cursor:pointer 实现鼠标悬停变成小手样式     实例:         其他参数: cursor语法: cursor : auto | crosshair | default | hand | move | help | wait | tex...

Yupoo(又拍网)的系统架构

Yupoo!(又拍网) 是目前国内最大的图片服务提供商,整个网站构建于大量的开源软件之上。以下为其使用到的开源软件信息: 操作系统:CentOS、MacOSX、Ubuntu 服务器:Apache、Nginx、Squid 数据库:MySQLmochiweb、MySQLdb 服务器监控:Cacti、Nagios、 开发语言:PHP、Python、Erlang、Java、Lua 分布式计算:Hadoop...

创建一个Servlet项目流程(入门)

版本 IDEA 2020.2 JDK1.8 apache-tomcat-9.0.36 项目流程 一、IDEA中新建JaveEE项目 项目起名,选择项目存放地址,点击finish创建成功 进入项目后,右键选择项目,选择add Framework Support 选择Web Application,点击OK 此时项目文件夹 在WEB-INF下创建两个目录classes和lib 按ctrl+alt+sh...