机器学习:线性判别分析(LDA)代码实现

标签: 机器学习  线性判别分析  LDA  代码实现  python

机器学习:线性判别分析(LDA)代码实现

相关知识

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

代码实现

使用的数据集 周志华老师<机器学习>书上数据集3a(第一列序号,第二列密度,第三列含糖量,第四列是否为好瓜)
在这里插入图片描述
通过三种方法:第一种方法调库,二三种手写

#!/usr/bin/env python
# coding: utf-8

# In[8]:


import numpy as np

'''
get the projective point(2D) of a point to a line

@param point: the coordinate of the point form as [a,b]
@param line: the line parameters form as [k, t] which means y = k*x + t
@return: the coordinate of the projective point  
'''
def GetProjectivePoint_2D(point, line):#点x,y坐标。线斜率,截距。返回点投影到线上的点的坐标
    a = point[0]
    b = point[1]
    k = line[0]
    t = line[1]

    if   k == 0:    
        return [a, t] 
    elif k == np.inf: 
        return [0, b]
    x = (a+k*b-k*t) / (k*k+1)
    y = k*x + t
    return [x, y]





from _operator import inv 
#求逆矩阵





import numpy as np





import matplotlib.pyplot as plt 





data_file = open('watermelon_3a.csv')
dataset = np.loadtxt(data_file, delimiter=",")





X = dataset[:,1:3]
y = dataset[:,3]

# draw scatter diagram to show the raw data
f1 = plt.figure(1)       
plt.title('watermelon_3a')  
plt.xlabel('density')  
plt.ylabel('ratio_sugar')  
plt.scatter(X[y == 0,0], X[y == 0,1], marker = 'o', color = 'k', s=100, label = 'bad')
plt.scatter(X[y == 1,0], X[y == 1,1], marker = 'o', color = 'g', s=100, label = 'good')
plt.legend(loc = 'upper right')  
plt.show()





from sklearn import model_selection
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn import metrics
import matplotlib.pyplot as plt

# generalization of train and test set
X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.5, random_state=0)

# model fitting
lda_model = LinearDiscriminantAnalysis(solver='lsqr', shrinkage=None).fit(X_train, y_train)

# model validation
y_pred = lda_model.predict(X_test)

# summarize the fit of the model
print(metrics.confusion_matrix(y_test, y_pred))
print(metrics.classification_report(y_test, y_pred))

# draw the classfier decision boundary
f2 = plt.figure(2) 
h = 0.001
# x0_min, x0_max = X[:, 0].min()-0.1, X[:, 0].max()+0.1
# x1_min, x1_max = X[:, 1].min()-0.1, X[:, 1].max()+0.1

x0, x1 = np.meshgrid(np.arange(-1, 1, h),
                     np.arange(-1, 1, h))

# x0, x1 = np.meshgrid(np.arange(x0_min, x0_max, h),
#                      np.arange(x1_min, x1_max, h))

z = lda_model.predict(np.c_[x0.ravel(), x1.ravel()]) 

# Put the result into a color plot
z = z.reshape(x0.shape)
plt.contourf(x0, x1, z)

# Plot also the training pointsplt.title('watermelon_3a')  
plt.title('watermelon_3a')  
plt.xlabel('density')  
plt.ylabel('ratio_sugar')  
plt.scatter(X[y == 0,0], X[y == 0,1], marker = 'o', color = 'k', s=100, label = 'bad')
plt.scatter(X[y == 1,0], X[y == 1,1], marker = 'o', color = 'g', s=100, label = 'good')
plt.show()





u = []  
for i in range(2): # two class
    u.append(np.mean(X[y==i], axis=0))  # column mean

# 2-nd. computing the within-class scatter matrix, refer on book (3.33)
m,n = np.shape(X)
Sw = np.zeros((n,n))
for i in range(m):
    x_tmp = X[i].reshape(n,1)  # row -> cloumn vector
    if y[i] == 0: u_tmp = u[0].reshape(n,1)
    if y[i] == 1: u_tmp = u[1].reshape(n,1)
    Sw += np.dot( x_tmp - u_tmp, (x_tmp - u_tmp).T )

Sw = np.mat(Sw)
U, sigma, V= np.linalg.svd(Sw) 

Sw_inv = V.T * np.linalg.inv(np.diag(sigma)) * U.T
# 3-th. computing the parameter w, refer on book (3.39)
w = np.dot( Sw_inv, (u[0] - u[1]).reshape(n,1) )  # here we use a**-1 to get the inverse of a ndarray

print(w)





f3 = plt.figure(3)
plt.xlim( -0.2, 1 )
plt.ylim( -0.5, 0.7 )

p0_x0 = -X[:, 0].max()
p0_x1 = ( w[1,0] / w[0,0] ) * p0_x0
p1_x0 = X[:, 0].max()
p1_x1 = ( w[1,0] / w[0,0] ) * p1_x0

plt.title('watermelon_3a - LDA')  
plt.xlabel('density')  
plt.ylabel('ratio_sugar')  
plt.scatter(X[y == 0,0], X[y == 0,1], marker = 'o', color = 'k', s=10, label = 'bad')
plt.scatter(X[y == 1,0], X[y == 1,1], marker = 'o', color = 'g', s=10, label = 'good')
plt.legend(loc = 'upper right')  

plt.plot([p0_x0, p1_x0], [p0_x1, p1_x1])

# draw projective point on the line


m,n = np.shape(X)
for i in range(m):
    x_p = GetProjectivePoint_2D( [X[i,0], X[i,1]], [w[1,0] / w[0,0] , 0] ) 
    if y[i] == 0: 
        plt.plot(x_p[0], x_p[1], 'ko', markersize = 5)
    if y[i] == 1: 
        plt.plot(x_p[0], x_p[1], 'go', markersize = 5)   
    plt.plot([ x_p[0], X[i,0]], [x_p[1], X[i,1] ], 'c--', linewidth = 0.3)

plt.show()





X = np.delete(X, 14, 0)
y = np.delete(y, 14, 0)

u = []  
for i in range(2): # two class
    u.append(np.mean(X[y==i], axis=0))  # column mean

# 2-nd. computing the within-class scatter matrix, refer on book (3.33)
m,n = np.shape(X)
Sw = np.zeros((n,n))
for i in range(m):
    x_tmp = X[i].reshape(n,1)  # row -> cloumn vector
    if y[i] == 0: u_tmp = u[0].reshape(n,1)
    if y[i] == 1: u_tmp = u[1].reshape(n,1)
    Sw += np.dot( x_tmp - u_tmp, (x_tmp - u_tmp).T )

Sw = np.mat(Sw)
U, sigma, V= np.linalg.svd(Sw) 

Sw_inv = V.T * np.linalg.inv(np.diag(sigma)) * U.T
# 3-th. computing the parameter w, refer on book (3.39)
w = np.dot( Sw_inv, (u[0] - u[1]).reshape(n,1) )  # here we use a**-1 to get the inverse of a ndarray

print(w)

# 4-th draw the LDA line in scatter figure

# f2 = plt.figure(2)
f4 = plt.figure(4)
plt.xlim( -0.2, 1 )
plt.ylim( -0.5, 0.7 )

p0_x0 = -X[:, 0].max()
p0_x1 = ( w[1,0] / w[0,0] ) * p0_x0
p1_x0 = X[:, 0].max()
p1_x1 = ( w[1,0] / w[0,0] ) * p1_x0

plt.title('watermelon_3a - LDA')  
plt.xlabel('density')  
plt.ylabel('ratio_sugar')  
plt.scatter(X[y == 0,0], X[y == 0,1], marker = 'o', color = 'k', s=10, label = 'bad')
plt.scatter(X[y == 1,0], X[y == 1,1], marker = 'o', color = 'g', s=10, label = 'good')
plt.legend(loc = 'upper right')  

plt.plot([p0_x0, p1_x0], [p0_x1, p1_x1])

# draw projective point on the line


m,n = np.shape(X)
for i in range(m):
    x_p = GetProjectivePoint_2D( [X[i,0], X[i,1]], [w[1,0] / w[0,0] , 0] ) 
    if y[i] == 0: 
        plt.plot(x_p[0], x_p[1], 'ko', markersize = 5)
    if y[i] == 1: 
        plt.plot(x_p[0], x_p[1], 'go', markersize = 5)   
    plt.plot([ x_p[0], X[i,0]], [x_p[1], X[i,1] ], 'c--', linewidth = 0.3)

plt.show()


运行结果

数据集散点图
在这里插入图片描述
三种方法的训练结果
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
如有侵权,联系删除: [email protected]

版权声明:本文为weixin_43846172原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_43846172/article/details/102075801

智能推荐

说说 Python Django 应用的基础目录结构

通过以下 django-admin 指令创建应用之后,就会生成应用的基础目录结构。 比如,我们建立了一个叫 ‘first’ 的应用,它的目录结构是这样的: 目录或文件 说明 最外层的 first/ 这是新应用的根目录,所有与该应用相关的内容都放在这里。 manage.py 用于管理 Django 项目的命令行工具。 里面一层的 first/ 目录 是一个...

Springboot整合rabbitMQ

依赖: 配置文件application.yml RabbitConfig 消息生产者RabbitProducer 消息消费者RabbitCustomer 通过Controller进行调用 启动项目后调用接口: 结果:...

Thread.join()方法的使用

如果一个线程A执行了thread.join()语句,代表当前线程A等待thread线程终止后才从thread.join()方法返回 并且这个方法具有超时特性,可以添加参数设置 输出结果: jdk中Thread.join()方法的源码(进行了部门调整)   每个线程终止的条件是前驱线程的终止,每个线程等待前驱线程终止后,才从join()方法返回,  当线程终止时,会调用自身的no...

linux服务器部署jenkins笔记

安装jenkins参考文档:https://blog.csdn.net/tomatocc/article/details/83930714 1. 打开jenkins官网:https://jenkins.io/download/ 将war包下载到本地 **ps:**这里要注意的是要下载左边下方的war包,不要下载右边下面的war包。左边是稳定版本,右边是最新版本,建议大家使用稳定版本(我刚开始下载的...

k8s部署elasticsearch集群

百度营销大学     环境准备 我们使用的k8s和ceph环境见: https://blog.51cto.com/leejia/2495558 https://blog.51cto.com/leejia/2499684 ECK简介 Elastic Cloud on Kubernetes,这是一款基于 Kubernetes Operator 模式的新型编排产品,用户可使用该产品在...

猜你喜欢

saas-export项目-AdminLTE介绍与入门

AdminLTE介绍 (1)AdminLTE是什么? AdminLTE是一款建立在bootstrap和jquery之上的开源的模板主题工具 (2)AdminLTE有什么特点? 提供一系列响应的、可重复使用的组件, 并内置了多个模板页面 自适应多种屏幕分辨率,兼容PC和移动端 快速的创建一个响应式的Html5网站 AdminLTE 不但美观, 而且可以免去写很大CSS与JS的工作量 AdminLTE...

MyBatis中ResultMap结果集映射

用于解决属性名和字段名不一致的情况: resultMap 元素是 MyBatis 中最重要最强大的元素。...

编写一个shell

编写shell的过程: 1.从标准输入中读入一个字符串。 2.解析字符串 3.创建一个子进程的执行程序。 4.子进程程序替换。 5.父进程等待子进程退出。...

WEB自动化测试中Xpath定位方法

前言: Xpath是在XML文档中查找信息的一种语言,使用路径表达式来选取XML文档中的节点或节点集,由于XML与HTML结构类似(前者用于传输数据,后者用于显示数据),所以Xpath也常用于查找HTML文档中的节点或节点集。 一  路径表达式: 路径以“/”开始     表示找到满足该绝对路径的元素; 路径以//”开始  ...

力扣困难难度 第4题 寻找两个正序数组的中位数

先看一眼题 我的思路: 设置下标i,j分别用于遍历两个数组,初始值均为0,直到找到两个数组中从小到大的第第length/2个数为止结束循环,length为两个数组长度之和。 ·每次比较nums[i]nums[j],如果前者小则i++,否则j++ ·循环结束时,如果count已经达到length/2,则说明已经找到了中位数,[注意:此时有可能正好其中一个数组遍历完了!所以...