# -*- coding: utf-8 -*-
# 文件名:ml_visualization.py
"""
作者: 梁斌
版本: 1.0
日期: 2017/10
该模块用于机器学习中的分类边界、决策树等可视化
声明:小象学院拥有完全知识产权的权利;只限于善意学习者在本课程使用,
不得在课程范围外向任何第三方散播。任何其他人或机构不得盗版、复制、仿造其中的创意,
我们将保留一切通过法律手段追究违反者的权利
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors
import graphviz
from sklearn.tree import export_graphviz
import matplotlib.patches as mpatches
def plot_decision_tree(clf, feature_names, class_names):
    r"""Render a fitted decision tree as a graphviz graph.

    Requirements:
    1. The Graphviz program (provided in the code directory). Add the ``bin``
       directory of its installation to the PATH environment variable, then
       restart Jupyter or the system for the change to take effect, e.g. add
       ``C:\Program Files (x86)\Graphviz2.38\bin`` to the system PATH.
    2. The ``graphviz`` Python module: ``pip install graphviz``.

    Parameters
    ----------
    clf : decision tree estimator accepted by ``export_graphviz``.
    feature_names : feature names used to label the split nodes.
    class_names : class names used to label the nodes' majority class.

    Returns
    -------
    graphviz.Source
        The tree's DOT description wrapped for display (e.g. inline in Jupyter).
    """
    # With out_file=None, export_graphviz returns the DOT text directly.
    # No temporary .dot file is left in the working directory, and the text is
    # never re-read with the platform's default encoding (feature and class
    # names may be non-ASCII).
    dot_graph = export_graphviz(clf, out_file=None, feature_names=feature_names,
                                class_names=class_names, filled=True, impurity=False)
    return graphviz.Source(dot_graph)
def plot_feature_importances(clf, feature_names):
    """Draw a horizontal bar chart of a classifier's feature importances.

    One bar is drawn per entry of ``feature_names``, at y positions 0..n-1.
    Bar lengths come from ``clf.feature_importances_``, and each bar is
    labelled with its feature name. The chart is drawn on the current
    matplotlib figure; ``plt.show()`` is not called.
    """
    n_features = len(feature_names)
    bar_positions = range(n_features)
    plt.barh(bar_positions, clf.feature_importances_)
    plt.xlabel('Feature importance')
    plt.ylabel('Feature name')
    tick_positions = np.arange(n_features)
    plt.yticks(tick_positions, feature_names)
def plot_class_regions_for_classifier(clf, X, y, X_test=None, y_test=None, title=None,
                                      target_names=None, plot_decision_regions=True):
    """Visualize a classifier's predicted class regions on 2-D data.

    Only usable for data with exactly two features. The classifier is
    evaluated on a grid covering the training data (columns ``X[:, 0]`` and
    ``X[:, 1]``), and the predicted regions are drawn behind the training
    points.

    Parameters
    ----------
    clf : fitted classifier providing ``predict`` and ``score``.
    X, y : training features (two columns) and class labels.
        ``np.amax(y) + 1`` is used as the number of classes, so labels are
        expected to be the integers 0..k-1. Only 4 colors are defined.
    X_test, y_test : optional test set, drawn with triangle markers. When
        both are given, train/test scores are appended to the title.
    title : optional plot title.
    target_names : optional class names. Legend entry i uses the color of
        label i.
    plot_decision_regions : if False, only the points are drawn.
    """
    num_classes = np.amax(y) + 1
    color_list_light = ['#FFFFAA', '#EFEFEF', '#AAFFAA', '#AAAAFF']
    color_list_bold = ['#EEEE00', '#000000', '#00CC00', '#0000CC']
    cmap_light = ListedColormap(color_list_light[0:num_classes])
    cmap_bold = ListedColormap(color_list_bold[0:num_classes])
    # Pin the color scale to the full label range 0..num_classes-1.
    # Otherwise each scatter rescales colors to its own min/max of `c`, so a
    # test set lacking some class would be colored inconsistently with the
    # training points and the legend.
    bold_kwargs = {'cmap': cmap_bold, 'vmin': 0, 'vmax': num_classes - 1}

    h = 0.03  # step of the prediction grid
    k = 0.5   # margin the prediction grid extends beyond the data
    x_plot_adjust = 0.1
    y_plot_adjust = 0.1
    plot_symbol_size = 50

    x_min = X[:, 0].min()
    x_max = X[:, 0].max()
    y_min = X[:, 1].min()
    y_max = X[:, 1].max()
    x2, y2 = np.meshgrid(np.arange(x_min - k, x_max + k, h),
                         np.arange(y_min - k, y_max + k, h))
    P = clf.predict(np.c_[x2.ravel(), y2.ravel()]).reshape(x2.shape)

    plt.figure()
    if plot_decision_regions:
        plt.contourf(x2, y2, P, cmap=cmap_light, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, s=plot_symbol_size, edgecolor='black',
                **bold_kwargs)
    plt.xlim(x_min - x_plot_adjust, x_max + x_plot_adjust)
    plt.ylim(y_min - y_plot_adjust, y_max + y_plot_adjust)

    if X_test is not None:
        plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, s=plot_symbol_size,
                    marker='^', edgecolor='black', **bold_kwargs)
        # Scoring needs the test labels, and the title may be None, so only
        # build the score line when it can be computed.
        if y_test is not None:
            train_score = clf.score(X, y)
            test_score = clf.score(X_test, y_test)
            score_text = "Train score = {:.2f}, Test score = {:.2f}".format(train_score, test_score)
            title = score_text if title is None else title + "\n" + score_text

    if target_names is not None:
        legend_handles = [mpatches.Patch(color=color_list_bold[i], label=name)
                          for i, name in enumerate(target_names)]
        plt.legend(loc=0, handles=legend_handles)

    if title is not None:
        plt.title(title)
    plt.show()
def plot_fruit_knn(X, y, n_neighbors):
    """Train kNN on the fruit dataset's height/width features and plot it.

    Parameters
    ----------
    X : pandas DataFrame with (at least) 'height' and 'width' columns.
    y : pandas object holding the fruit labels aligned with X.
    n_neighbors : number of neighbors for ``KNeighborsClassifier``.

    The predicted regions are drawn on a 0.01-step mesh extending 1 unit
    beyond the data on each side, and the training points are drawn on top.
    The legend is hard-coded: the four colors, in the order of the sorted
    label values, are labelled apple, mandarin, orange and lemon.
    NOTE(review): confirm this matches the dataset's label encoding.
    """
    # DataFrame/Series.as_matrix() was deprecated in pandas 0.23 and removed
    # in 1.0; .values works on both old and new pandas.
    X_mat = X[['height', 'width']].values
    y_mat = y.values

    # Create color maps
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF', '#AFAFAF'])
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF', '#AFAFAF'])

    clf = neighbors.KNeighborsClassifier(n_neighbors)
    clf.fit(X_mat, y_mat)

    # Predict a class for every mesh point to draw the decision regions.
    mesh_step_size = .01  # step size in the mesh
    plot_symbol_size = 50
    x_min, x_max = X_mat[:, 0].min() - 1, X_mat[:, 0].max() + 1
    y_min, y_max = X_mat[:, 1].min() - 1, X_mat[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                         np.arange(y_min, y_max, mesh_step_size))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # Share one value range between the mesh and the points. Otherwise, if some
    # label is never predicted on the mesh, the region colors shift relative
    # to the point colors.
    vmin, vmax = y_mat.min(), y_mat.max()

    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light, vmin=vmin, vmax=vmax)
    # Plot training points
    plt.scatter(X_mat[:, 0], X_mat[:, 1], s=plot_symbol_size, c=y_mat, cmap=cmap_bold,
                vmin=vmin, vmax=vmax, edgecolor='black')
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())

    patch0 = mpatches.Patch(color='#FF0000', label='apple')
    patch1 = mpatches.Patch(color='#00FF00', label='mandarin')
    patch2 = mpatches.Patch(color='#0000FF', label='orange')
    patch3 = mpatches.Patch(color='#AFAFAF', label='lemon')
    plt.legend(handles=[patch0, patch1, patch2, patch3])

    plt.xlabel('height (cm)')
    plt.ylabel('width (cm)')
    plt.show()