如何解决决策树很小,无法阅读
我试图用我的信用卡流失数据集绘制决策树。我得到的图表太小,难以阅读。我该如何解决这个问题?我是否犯了任何错误或以其他方式如何改进此图表的可视化。
我已经附加了图像和代码的输出
我使用的代码:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.tree import plot_tree
df=pd.read_csv(r'G:\\Edu\\My academics\\MSc in CS\\3rd sem\\Research\\Python files\\BC.csv')
##df = sns.load_dataset(r'G:\\Edu\\My academics\\MSc in CS\\3rd sem\\Research\\Python files\\BC.csv')
df.head()
df.head()
df.drop(['CLIENTNUM','Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_1','Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_2'],axis = 1,inplace = True)
print(df.columns)
sns.heatmap(df.corr())
plt.show()
df = pd.get_dummies(df,columns = ['Gender','Education_Level','Marital_Status','Income_Category','Card_Category'])
target = df['Attrition_Flag']
df1 = df.copy()
df1 = df1.drop('Attrition_Flag',axis =1)
X = df1
le = LabelEncoder()
target = le.fit_transform(target)
y = target
print(y)
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size = 0.2,random_state = 42)
print("Training split input- ",X_train.shape)
print("Testing split input- ",X_test.shape)
from sklearn import tree
dtree = tree.DecisionTreeClassifier()
dtree.fit(X_train,y_train)
# Predicting the values of test data
y_pred = dtree.predict(X_test)
print("Classification report - \n",classification_report(y_test,y_pred))
cm = confusion_matrix(y_test,y_pred)
plt.figure(figsize=(5,5))
sns.heatmap(data=cm,linewidths=.5,annot=True,square = True,cmap = 'Blues')
plt.ylabel('Actual label')
plt.xlabel('Predicted label')
all_sample_title = 'Accuracy Score: {0}'.format(dtree.score(X_test,y_test))
plt.title(all_sample_title,size = 15)
plt.show()
# Visualising the graph without the use of graphplt.figure(figsize = (25,25))
plt.figure(figsize=(12,12))
tree.plot_tree(dtree,fontsize=6)
plt.savefig('tree_high_dpi',dpi=100)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。