K Means On IRIS Dataset
K Means On IRIS Dataset
#Suppress warnings
import warnings
warnings.filterwarnings('ignore')
#Importing Libraries
import numpy as np
import pandas as pd
# Load the Iris measurements from the CSV file into a DataFrame
data = pd.read_csv("IRIS.csv")
# Preview the first rows, then list each column's dtype and non-null count
data.head()
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 sepal_length 150 non-null float64
1 sepal_width 150 non-null float64
2 petal_length 150 non-null float64
3 petal_width 150 non-null float64
4 species 150 non-null object
dtypes: float64(4), object(1)
memory usage: 6.0+ KB
# Summary statistics (count, mean, std, quartiles) for the numeric columns.
# The call parentheses are required: bare `data.describe` only yields the bound method.
data.describe()
Data Visualization
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
# Histogram of sepal length: the x axis is the measured value,
# the y axis is the number of samples that fall into each bin.
plt.hist(data["sepal_length"], bins=10, color="green")
plt.xlabel("sepal_length")
plt.ylabel("Frequency")
plt.show()
# Pairwise scatter plots of every feature pair, coloured by species
sns.pairplot(data, hue="species")
plt.show()
# Correlation heatmap of the numeric features. This uses `data` (no `df1` is
# defined in this notebook) and restricts to numeric columns because
# `species` holds strings and cannot take part in a correlation.
sns.heatmap(data.select_dtypes(include="number").corr(), annot=True)
plt.show()
KMeans
from sklearn.cluster import KMeans
# Fit a 3-cluster K-Means model using only the two petal measurements.
# NOTE(review): random_state is not set, so the numbering of the clusters
# may differ from one run to the next.
petal_columns = ['petal_length', 'petal_width']
kmeans = KMeans(n_clusters=3)
kmeans.fit(data[petal_columns])
▾ KMeans
KMeans(n_clusters=3)
kmeans.cluster_centers_
array([[5.59583333, 2.0375 ],
[1.464 , 0.244 ],
[4.26923077, 1.34230769]])
# Plot the two petal features coloured by true species, then overlay the
# centroids learned by the fitted model.
# `species` holds strings, which matplotlib cannot map through a colormap,
# so encode each species as an integer category code first.
species_codes = data['species'].astype('category').cat.codes
plt.scatter(data['petal_length'], data['petal_width'], c=species_codes, cmap='rainbow')
# Read the centroids from the model rather than hard-coding values from a
# previous run; columns follow the fit order (petal_length, petal_width).
centers = kmeans.cluster_centers_
plt.scatter(centers[:, 0], centers[:, 1], s=200, c='black', marker='s')
plt.show()
# Assign every sample to its nearest centroid of the already-fitted model
# (predict does not refit or move the cluster centers).
petal_features = data[['petal_length', 'petal_width']]
pred = kmeans.predict(petal_features)
pred
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
# Classify one unseen flower given as [petal_length, petal_width] = [6, 2].
sample_test = np.array([6, 2])
# predict expects a 2-D array, so add a leading axis: shape (2,) -> (1, 2)
second_test = sample_test[np.newaxis, :]
kmeans.predict(second_test)
array([0])
Loading [MathJax]/jax/output/CommonHTML/fonts/TeX/fontdata.js