CSET301 LabW8L2
Decision Tree
A decision tree is one of the most popular and effective tools for classification and prediction. It is a flowchart-like tree structure in which
each internal node tests a feature against a threshold chosen to best separate the classes, each branch represents an
outcome of that test, and each leaf node holds a class label.
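As a quick intuition for what a single node does, the snippet below applies one hand-picked threshold to one feature of the iris data; the 2.45 cm cut-off is only an illustrative value, not one learned by a model.

from sklearn import datasets

iris_demo = datasets.load_iris()
petal_length = iris_demo.data[:, 2]        # third column: petal length (cm)

# One "node": test a single feature against a threshold (illustrative value)
goes_left = petal_length <= 2.45
print("Samples on the left branch :", goes_left.sum())
print("Samples on the right branch:", (~goes_left).sum())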
In [1]:
from matplotlib import pyplot as plt # For plotting
from sklearn import datasets # For loading standard datasets
from sklearn.tree import DecisionTreeClassifier # To run decision tree model
from sklearn import tree # to visualize decision trees
Quick Tip: sklearn.datasets ships with several small toy datasets; the package also has helpers to fetch larger datasets commonly used by the machine
learning community.
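For example, the toy loaders below come bundled with scikit-learn, while fetch_openml pulls larger datasets from openml.org on demand (a small sketch; the dataset names are just examples).

from sklearn import datasets

# Toy datasets shipped with scikit-learn (no download needed)
wine = datasets.load_wine()
cancer = datasets.load_breast_cancer()
print(wine.data.shape, cancer.data.shape)

# Helper that fetches a larger dataset from openml.org on first use
# (uncomment to run; the first call downloads the data)
# titanic = datasets.fetch_openml("titanic", version=1, as_frame=True)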
In [2]:
# Prepare the data
iris = datasets.load_iris()
X = iris.data
y = iris.target
In [3]:
# Initialize the model
clf = tree.DecisionTreeClassifier()
# Fit the model
clf.fit(iris.data,iris.target)
Task
Train your own decision tree and experiment with its hyper-parameters, then state your observations for at least 15 different
hyper-parameter settings. criterion, splitter, max_depth, max_leaf_nodes, min_samples_split, min_samples_leaf, and random_state (all of which appear in the cells below) are only some of the parameters you can vary.
Print the accuracy for each hyper-parameter setting used, in the following format:
1. PARAMS[random_state=1, max_depth=....] , Accuracy=0.97
2. PARAMS[random_state=42, min_samples_split=....] , Accuracy=0.94
...
Perform the same set of activities on a different dataset: https://gist.github.com/kudaliar032/b8cf65d84b73903257ed603f6c1a2508 (a loading sketch is given after this task).
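The snippet below is a minimal sketch of how the gist dataset could be loaded and evaluated; the local filename, the assumption that the last column is the label, and the example parameter settings are all placeholders, not part of the original lab.

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Placeholder path: download the CSV from the gist first and point to it here
df = pd.read_csv("iris_dataset.csv")          # hypothetical local filename
X2 = df.drop(columns=df.columns[-1])          # assume the last column is the label
y2 = df[df.columns[-1]]

X_tr, X_te, y_tr, y_te = train_test_split(X2, y2, test_size=0.25, random_state=42)

# Try a few hyper-parameter settings and report accuracy in the required format
settings = [{"random_state": 1, "max_depth": 3},
            {"random_state": 42, "min_samples_split": 10}]
for k, params in enumerate(settings, start=1):
    model = DecisionTreeClassifier(**params).fit(X_tr, y_tr)
    acc = accuracy_score(y_te, model.predict(X_te))
    print(f"{k}. PARAMS[{params}], Accuracy={acc:.2f}")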
In [4]:
# Initialise and then fit the classifier
clf = tree.DecisionTreeClassifier()
clf.fit(X, y)
In [5]:
# Get a text representation of the trained decision tree
text_representation = tree.export_text(clf)
print(text_representation)
In [6]:
# Save the above text representation to a file
with open("decision_tree.log", "w") as fout:
    fout.write(text_representation)
In [7]:
# Visualize the trained tree with sklearn's plot_tree
# See the documentation for modifying fonts: https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html
fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(clf,
                   feature_names=iris.feature_names,
                   class_names=iris.target_names,
                   filled=True)
In the above figure, the color of each node indicates the majority class at that node, and the color intensity reflects the purity of the split.
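If you want to see the numbers behind those colors, the fitted tree exposes its internal arrays through clf.tree_; the short sketch below (my own inspection loop, not part of the lab) prints the class distribution stored at every node, whose largest entry determines the fill color.

import numpy as np

t = clf.tree_  # low-level structure of the fitted tree
for node in range(t.node_count):
    dist = t.value[node][0]                          # class distribution at this node
    majority = iris.target_names[np.argmax(dist)]
    kind = "leaf" if t.children_left[node] == -1 else "split"
    print(f"node {node:2d} ({kind}): value={dist}, majority class={majority}")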
In [8]:
# TODO: Write accuracy function here
import sklearn.metrics as metrics
from sklearn.model_selection import train_test_split

X_s_train, X_s_test, y_s_train, y_s_test = train_test_split(X, y, test_size=0.25, random_state=6)
y_pred = clf.predict(X_s_test)
# Note: clf was fitted on the full dataset above, so this score is optimistic
print("Accuracy:", metrics.accuracy_score(y_s_test, y_pred))
Accuracy: 1.0
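As the comment above notes, clf was fitted on the full dataset, so the 1.0 score overstates generalization. A hedged alternative, sketched below with a helper name (evaluate_tree) of my own choosing, is to refit on the training split only, or to use cross-validation:

from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

def evaluate_tree(X, y, **params):
    """Fit a fresh tree on a train split and return its test accuracy."""
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=6)
    model = DecisionTreeClassifier(**params).fit(X_tr, y_tr)
    return accuracy_score(y_te, model.predict(X_te))

print("Hold-out accuracy :", evaluate_tree(X, y, max_depth=3, random_state=1))
print("5-fold CV accuracy:", cross_val_score(DecisionTreeClassifier(max_depth=3, random_state=1), X, y, cv=5).mean())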
In [9]:
# TODO: Print 15 hyperparam settings along with accuracy
# NOTE: the keyword arguments below were cut off in the original; the values are
# reconstructed from the PARAMS strings (i is used as max_depth in the first loop
# and as min_samples_leaf in the second), and the trees are now fitted on the
# training split only instead of the full dataset.
import sklearn.metrics as metrics
X_s_train, X_s_test, y_s_train, y_s_test = train_test_split(X, y, test_size=0.25, random_state=6)
for i in range(5, 13):
    clf1 = DecisionTreeClassifier(criterion="gini", splitter="random", max_leaf_nodes=10,
                                  min_samples_leaf=5, max_depth=i)
    clf1.fit(X_s_train, y_s_train)
    y_pred = clf1.predict(X_s_test)
    print(f'PARAMS[criterion="gini", splitter="random", max_leaf_nodes=10, '
          f'min_samples_leaf=5, max_depth={i}], Accuracy={metrics.accuracy_score(y_s_test, y_pred):.2f}')
for i in range(35, 43):
    clf1 = DecisionTreeClassifier(criterion="entropy", splitter="random", min_samples_split=4,
                                  max_leaf_nodes=5, min_samples_leaf=i)
    clf1.fit(X_s_train, y_s_train)
    y_pred = clf1.predict(X_s_test)
    print(f'PARAMS[criterion="entropy", splitter="random", min_samples_split=4, '
          f'max_leaf_nodes=5, min_samples_leaf={i}], Accuracy={metrics.accuracy_score(y_s_test, y_pred):.2f}')
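To reach 15 or more settings without writing one loop per parameter, one option is sklearn.model_selection.ParameterGrid; the grid values below are arbitrary examples, not prescribed by the lab.

from sklearn.model_selection import ParameterGrid, train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

grid = ParameterGrid({
    "criterion": ["gini", "entropy"],
    "max_depth": [2, 4, 8],
    "min_samples_split": [2, 10],
    "random_state": [1],
})  # 2 * 3 * 2 * 1 = 12 settings; extend the lists to reach 15 or more

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=6)
for k, params in enumerate(grid, start=1):
    model = DecisionTreeClassifier(**params).fit(X_tr, y_tr)
    acc = accuracy_score(y_te, model.predict(X_te))
    print(f"{k}. PARAMS[{params}], Accuracy={acc:.2f}")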
In [10]:
# Save the figure
fig.savefig("decision_tree.png")
Troubleshooting (Graphviz executables not found on the system PATH): https://stackoverflow.com/questions/35064304/runtimeerror-make-sure-the-graphviz-executables-are-on-your-systems-path-aft
Graph visualization is a way of representing structural information as diagrams of abstract graphs and networks.
In [11]:
import graphviz
# DOT data - since graphviz accepts data in DOT format, we convert our tree into a compatible format
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True)
# Draw graph
graph = graphviz.Source(dot_data, format="png")
graph
[Rendered Graphviz tree: each node shows its split condition (e.g. petal length (cm) <= 4.95), its gini impurity, the number of samples reaching it, the class distribution (value), and the predicted class.]
In [12]:
graph.render("decision_tree_graphviz")
Out[12]: 'decision_tree_graphviz.png'
Resources
https://mljar.com/blog/visualize-decision-tree/ (source code)
https://towardsdatascience.com/visualizing-decision-trees-with-python-scikit-learn-graphviz-matplotlib-1c50b4aa68dc
https://explained.ai/decision-tree-viz/
https://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html
In [ ]: