D_T_C/decision tree classificatio...

1167 lines
406 KiB
Plaintext
Raw Normal View History

2022-12-22 10:10:56 +00:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import tools"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
2022-12-22 10:49:40 +00:00
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
2022-12-22 10:10:56 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-12-22 10:49:40 +00:00
"# First data analysis \n",
"---"
2022-12-22 10:10:56 +00:00
]
},
2022-12-22 10:31:40 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-12-22 10:49:40 +00:00
"## Get the data"
2022-12-22 10:31:40 +00:00
]
},
2022-12-22 10:10:56 +00:00
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sepal_length</th>\n",
" <th>sepal_width</th>\n",
" <th>petal_length</th>\n",
" <th>petal_width</th>\n",
" <th>type</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
2022-12-22 10:49:40 +00:00
" <td>Setosa</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
2022-12-22 10:49:40 +00:00
" <td>Setosa</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
2022-12-22 10:49:40 +00:00
" <td>Setosa</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
2022-12-22 10:49:40 +00:00
" <td>Setosa</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
2022-12-22 10:49:40 +00:00
" <td>Setosa</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
2022-12-22 10:49:40 +00:00
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
2022-12-22 10:49:40 +00:00
" <th>145</th>\n",
" <td>6.7</td>\n",
" <td>3.0</td>\n",
" <td>5.2</td>\n",
" <td>2.3</td>\n",
" <td>Virginica</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
2022-12-22 10:49:40 +00:00
" <th>146</th>\n",
" <td>6.3</td>\n",
" <td>2.5</td>\n",
2022-12-22 10:10:56 +00:00
" <td>5.0</td>\n",
2022-12-22 10:49:40 +00:00
" <td>1.9</td>\n",
" <td>Virginica</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
2022-12-22 10:49:40 +00:00
" <th>147</th>\n",
" <td>6.5</td>\n",
" <td>3.0</td>\n",
" <td>5.2</td>\n",
" <td>2.0</td>\n",
" <td>Virginica</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" <tr>\n",
2022-12-22 10:49:40 +00:00
" <th>148</th>\n",
" <td>6.2</td>\n",
" <td>3.4</td>\n",
" <td>5.4</td>\n",
" <td>2.3</td>\n",
" <td>Virginica</td>\n",
" </tr>\n",
" <tr>\n",
" <th>149</th>\n",
" <td>5.9</td>\n",
" <td>3.0</td>\n",
" <td>5.1</td>\n",
" <td>1.8</td>\n",
" <td>Virginica</td>\n",
2022-12-22 10:10:56 +00:00
" </tr>\n",
" </tbody>\n",
"</table>\n",
2022-12-22 10:49:40 +00:00
"<p>150 rows × 5 columns</p>\n",
2022-12-22 10:10:56 +00:00
"</div>"
],
"text/plain": [
2022-12-22 10:49:40 +00:00
" sepal_length sepal_width petal_length petal_width type\n",
"0 5.1 3.5 1.4 0.2 Setosa\n",
"1 4.9 3.0 1.4 0.2 Setosa\n",
"2 4.7 3.2 1.3 0.2 Setosa\n",
"3 4.6 3.1 1.5 0.2 Setosa\n",
"4 5.0 3.6 1.4 0.2 Setosa\n",
".. ... ... ... ... ...\n",
"145 6.7 3.0 5.2 2.3 Virginica\n",
"146 6.3 2.5 5.0 1.9 Virginica\n",
"147 6.5 3.0 5.2 2.0 Virginica\n",
"148 6.2 3.4 5.4 2.3 Virginica\n",
"149 5.9 3.0 5.1 1.8 Virginica\n",
"\n",
"[150 rows x 5 columns]"
2022-12-22 10:10:56 +00:00
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"col_names = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'type']\n",
"data = pd.read_csv(\"iris.csv\", skiprows=1, header=None, names=col_names)\n",
2022-12-22 10:49:40 +00:00
"data"
]
},
{
"attachments": {
"iris-machinelearning.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABPsAAAHdCAMAAABc94yJAAAC61BMVEX///8THRAmHxMXFA0QGQ4WIhIgGxEBAQEaJxQbGA8oJBUPEQ0yKxxRbjMkLxctJxghKRUxQBshIRQpNRk9UyZMaTEtOxo6SR8zRh5NYix8Z61vWqZFZTBDWCgJCgk3TiNHXisbHxJUdjZiTqBSZi4eLxcsLRdATSI5QRyCbbGUgr1bezcxNBpMWShzY6s1MCJTcEiFdbY2Ohs6Wis+YS6Ev0SiksaNfrpHUSRacjNsUqJkgDmYicGLeLd4XqhXdFOejcMjOBptiTo8NycoQh5DPy+nl8hfRZsyVChNaT8uSyOKqUxaai99sF2EvFKrncxVPZdLdDdXgz5EbTRmV6W1q9NDRSBRfTp3qVZ4kDyVsEpjdTRBPCFci0DLxOHEvN1jlECSqndVSZxonEKDXabX1uiNpm1+r0R5uUNIOZUuLHCQxUuEumO6tdg/Lxx5cbKtptGchb5XXis+MYozL361oM1MSire3eyPbrArJF3RzeXBsNVufzdTVTSgm8oXETlvrEOMY6lDL3ojIyRPMYtNRjYgGkvz8fWWd7XHxsJDQkVAKGh0okJFWj8VEyR9UaBvnVdlekiXk8WjvUr5+fuMiL9vRZp+l0/r5PAwMzvh4tyHioecsX6oib+Fmz7Oz8ihfLdTU0pNYFJzh01QRYlTPSPSttdmklJmYGhgYU9jVZGyk8NgTyaWrmy9usJac2SFnl5kNpCssqOSxmdhYzw2HlBsgFzfyuLR09wnEg/Hpc17dZktGTWXla1tb1jExtSmn7Tt7ep6i2NWMnM7QFxmUT2nnJpMTm2Sh5uvrMBHJDh+ck2Cm26ltI14enlpbYVqRneEg6pcRVm8tK5zYUqThm9aX324zkTb3MyWoJNyXy15W4ZoaZ5+bmtRKlLY1FeilYA5Gh2djEbCzaeOb43i2KKJdjWOfFbGv2Wfb6n8+/C+tIi1oWPFtkXXw4OvnUDk3Xr68H3177z06V368J/FlcP699WCP5K0gLXbsn+PAARPfklEQVR42uy9aWzc55XuebmqWNzX4r6zikuRFBebm0mpuEQSRbm5DReTnOHSoCxDQ0F0INgAGeuarYzjod2kPW7CCaFAYgxFgdkmYoEDIXcEXI4NAzEVI4DbnQbSwGQCJMZt3O4PA8zneZ5z3vdfVZSTeOweYC5Qb5HFYrFYrCpW/eo55zznvP/hP0RWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZEVWZD25fvz3f//3P5APfP5A1o/D1g/CFs/48Q+/2cJv/liu3H79musHJ25D6PrN/831mx/+8Dc//M//+eQVn7wwzvrNb37zQ3zyK46+6kJfa/1YD6E3MXiwH/gjcr9/c/LuBG8Hb4O5HT8OuZV2BX/41XdG7kbwAZYr/WHo9/Z3fxPyH/+f/j9b/+XPfvz/Yf0Xe/j3vz3/S4Ql/62t10+FrOKYhHi315uZk56bnp6bi6OczMw0V1lCTExiYkxMWVpKZmZKSuaDhd3lfS7neAGH5WV8hKyFhQX7dWFhF4eZnEz8vqycnPQcXDkWjuVrTr9+pssfzi3hX8/hRVNS0nALsGKwEnE7uGJiEnhWRkbC6r/+4z++/fZzzz13+/b+7z+I28j0p1Sne3weT7U/MyWjLCoqISGBvxkTG1XmyszNKk2u296711WXPDaWPVbYUpKTUhabyJ/LFeIaXVw4xd+KjTV/LzGWF8jISEvBcu5DvxxyApkBPSfF3t60KnxkphfNTO/sHBwfbfsyc3JLsrJKzF3KTPdVdu3t7O1t11Xm5W
XnFeaVZnly0vj38QjLo4HLpXuK6vOys/NKi3z4R+B6ccUZduFyuC+Vldkd28cHWHicD3amb9++ffEiv5H/y0FvqT6KKVX3g//w36yurtZjtRdhteuql4UfrK7zUM/j1Qe61tdv3LgxO7uSlJSa6nZ7/c3VEz5cuBK3PK++SH41r7JyZmamq2t7Gx/bPO4KftTx7K6Zjsq8do/f685wpyY1Nvb1zeLgrAEeNSbh+qO53Lrw91JTU5Nw8cZG/vXgckdnZMTLSjAr9HQMv492JyWtrDQmuaPlmrC8E6trvWtrXV1dvb29070dnoyY0MXryMjAhb1eP1ezH3/Z7/cmYXm9PLO5urp6YmLixvqd+48e3V/v8+KyuGV4nnn7Pn70f4S+qo51HR0d42DWIdbR8v6HsnY/5H/o+PDR3Ttcr9t15+6ZwyP5Vbn4IX7v+OTCVZ05c/++/uadO3fv3OXJu7LOHB4v61/4cH8Z13F0zL+4fHx4Ri9kfuWuPeVcx527zlWdeby7u4+X7GO8aHf3dz/c3X38GMf4Fj/Y3cctX1jYf4zzcC5/sL//+PUn1x3elUNzz8+cuRtyJ83PeTtwkXu8p7g/Z87obbQ38a75jl/v3r3Px8XeNbuufnj16sLV5d//O7CvuDgxJiED/3xBkC68ctKEAzGxMWUuvPYz01Jytnf15SVs4/G+BZ28CrmOl49x1rEDPzyMy3k5mekGd9XVitYcIV9AGNKfk6nwc7BLjJB8UVFRMTFRuAUxCkA+UxOEVO6Y1X/98h/fJvouXtx5vLAB+GXmpDRj4XmbkYELAX9y66OiQBYAJbuud2/v3nb5SHZh6/mS3My0KCCO116Gq4wTApa55BfNqyLWkNGVliZ0y8nUW82bHwClcuQNITPFos8lt5p0KgX8DpaPj07jjgN+Jbk5yr4cT33HHtd2cl42IVJa5MlJcfFPpOhDT/zl+oqUfVkeIRjpB+xFG/allxTlVSZv7+HFwEf9YGcH6Lu9c/HgovnnHMxkEX05gTD2/ZjkA+98E74JD5ZvgvC7gUM9PsE9fgB+DwC/jx98/PH6+zduvDe7spKUqlzyNlf72g386gnPImXf9gyZd+JQJwecmukYqqz3VQN+QB/IJ+gj/VZIwQHQr5HsM+iLdqe6LecEfWBf0kZSqh6ISGEfjyzwEh7KISaUfY241UpRd+rGhrd6dQbw65rpWiP7ZnxpCfpcCmdfKjiX5LcHOe0F/PzNzQSfr319HeR79Pn9W308FzcNl/Hfuh/OvpGRyUl8joyNjWCN8cuF1pbz51vL5y8/9dzFi5e5NjfrStOromL53lpcXMwXYKKr//xY+eTk5Fhha0tr6xh/c3JydHRudJJrFCdHypNbBwOBqipXXFycK47v1DzhcqXhrKr+1tGlm5dvPo9rXypvyT0/Njm/efPm0ujYoFw+Li4qzqyo0MUrMD9zxaUs35tfWro3dWnq0tbU1NTS5uZVnMSpeR5tbt68zJOXrl66NLWEz6353uWoUydXcSxuDO4w7kFry2B/lSs2sTgUNJATUXG4REvrWGvr+f5AVaA/N+CKc5ZIENxIPcHrwl1ZunnzeV03ZS3dXFpamp/7+bdmXzEfewo//Pf9fAVWp8tLJxMvyoQEAiDOxRc/+Jd+tLuwvK9KTwEI4WfgZ8in0FtecNgH+O205+CtE2+fuO5qD17Q6en96ar4VPQFoKFycqzwIflEgkUp9BRhir4ogR+e+xvFE19C+d1+my/86d1/ej0uM83rTvPKcrsziD8KOLIvzkVilGaXd/Xu3dvrKsfz63wumBOEXBRRCUiaJX9Gv3UJ9zNVqnL129uZQ9kXSEkJpAidKFJdLoWfJ69r5+Dg+PjwQi7Z58nVOwZdnVW5JuzrAEIKlX0Q2Hx8c9L1TSGQkw72kTBkX65ln9wlXnlKjicrLzl5+wi674jSD7IPwu8i/uCyvC0d7CSX6P+vKkz3AXx4EeP/gLeHZtUyNyZu2A
WV9/4N0g/go+p78P66gz6KK7w1NoOW4B0OpSIccXqoY8agb3Rb5J9BXxfQhw+oPwi/+vaJar+Ivr6BWQM++WJlX6qj
}
},
"cell_type": "markdown",
"metadata": {},
"source": [
"<div>\n",
"<img src=\"attachment:iris-machinelearning.png\" width=\"500\" style=\"float: left\"/> \n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Constructing a model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DTC (Decision Tree Classifier) is built in such a way that it contains two kinds of nodes:\n",
"1. decision nodes which contain:\n",
" 1. condition which is defined by feature_index, threshold value for the particular feature\n",
" \n",
" feature_index is the index of a feature\n",
"    threshold - a certain value of a feature that we compare other feature values against. \n",
"\n",
" 2. left, right are for accessing left and right child.\n",
"    3. information gain - variable that stores the information gained by the split \n",
"\n",
"2. leaf nodes which contain: value which is a majority class of the leaf node. - it helps us to determine the class of a data point if it ends up in this particular leaf node. \n",
"\n",
"First we will define a Node class, then we are defining a Tree class which will have all the methods that we can perform on our tree. Tree class will basically allow us to build our tree based on the splits of our data that we will perform. Those splits and their results (left and right child) will be in the form of a Node that we have defined. "
2022-12-22 10:10:56 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Node class"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class Node():\n",
"    ''' A single node of the decision tree.\n",
"\n",
"    A node is either a decision node (feature_index, threshold, left,\n",
"    right and info_gain are set) or a leaf node (only value is set).\n",
"    '''\n",
"\n",
"    def __init__(self, feature_index=None, threshold=None, left=None, right=None, info_gain=None, value=None, curr_depth=0):\n",
"        ''' constructor '''\n",
"\n",
"        # for decision node\n",
"        self.feature_index = feature_index\n",
"        self.threshold = threshold\n",
"        self.left = left\n",
"        self.right = right\n",
"        self.info_gain = info_gain  # information gained by the split denoted\n",
"                                    # by this particular decision node\n",
"\n",
"        self.curr_depth = curr_depth  # current depth of the node; used by both types of nodes\n",
"\n",
"        # for leaf node\n",
"        self.value = value\n",
"        # majority class of the leaf node...\n",
"\n",
"        # it will help us to determine the class of a new data point\n",
"        # if the data point ends up in this particular leaf node"
2022-12-22 10:10:56 +00:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tree class"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class DecisionTreeClassifier():\n",
"    ''' Binary decision-tree classifier built from scratch.\n",
"\n",
"    The tree is grown recursively out of Node objects: each decision node\n",
"    stores the (feature, threshold) split that maximises information gain,\n",
"    and each leaf stores the majority class of the samples reaching it.\n",
"    '''\n",
"\n",
"    def __init__(self, min_samples_split=2, max_depth=2):\n",
"        ''' constructor '''\n",
"\n",
"        # initialize the root of the tree\n",
"        self.root = None\n",
"\n",
"        # stopping conditions\n",
"        self.min_samples_split = min_samples_split\n",
"        self.max_depth = max_depth\n",
"\n",
"        # If in a particular node the number of samples becomes less than\n",
"        # min_samples_split we won't split that node any further and we\n",
"        # treat it as a leaf node. The same goes for max_depth.\n",
"\n",
"    # MOST IMPORTANT FUNCTION - recursive function for building the binary tree.\n",
"    # It takes a dataset, performs the best split of it - creating left and\n",
"    # right children - each of which is either a pure leaf node (data points of\n",
"    # a single class only) or a node holding the remaining data together with\n",
"    # the condition that performs further splits of the data in that node.\n",
"    def build_tree(self, dataset, curr_depth=0):\n",
"        ''' recursive function to build the tree '''\n",
"\n",
"        # split the dataset into two separate variables: features and classes\n",
"        X, Y = dataset[:,:-1], dataset[:,-1]\n",
"\n",
"        # extract the number of samples and the number of features\n",
"        num_samples, num_features = np.shape(X)\n",
"\n",
"        # split until stopping conditions are met\n",
"        # NOTE: '<=' allows splits to happen AT curr_depth == max_depth, so\n",
"        # leaves can sit one level deeper; kept as-is to preserve the trees\n",
"        # printed below.\n",
"        if num_samples>=self.min_samples_split and curr_depth<=self.max_depth:\n",
"            # find the best split\n",
"            best_split = self.get_best_split(dataset, num_samples, num_features)\n",
"            # check if information gain is positive; .get() guards against the\n",
"            # empty dict returned when no valid split exists (e.g. all rows\n",
"            # identical), in which case the node becomes a leaf\n",
"            if best_split.get(\"info_gain\", 0)>0:\n",
"                # recur left\n",
"                left_subtree = self.build_tree(best_split[\"dataset_left\"], curr_depth+1)\n",
"                # recur right\n",
"                right_subtree = self.build_tree(best_split[\"dataset_right\"], curr_depth+1)\n",
"                # return decision node\n",
"                return Node(best_split[\"feature_index\"], best_split[\"threshold\"],\n",
"                            left_subtree, right_subtree, best_split[\"info_gain\"], curr_depth=curr_depth)\n",
"\n",
"        # compute leaf node\n",
"        leaf_value = self.calculate_leaf_value(Y)\n",
"        # return leaf node\n",
"        return Node(value=leaf_value, curr_depth=curr_depth)\n",
"\n",
"    # SECOND MOST IMPORTANT FUNCTION\n",
"    def get_best_split(self, dataset, num_samples, num_features):\n",
"        ''' function to find the best split '''\n",
"\n",
"        # dictionary to store the best split\n",
"        best_split = {}\n",
"        max_info_gain = -float(\"inf\")\n",
"\n",
"        # loop over all the features\n",
"        for feature_index in range(num_features):\n",
"            feature_values = dataset[:, feature_index]\n",
"            possible_thresholds = np.unique(feature_values)\n",
"            # loop over all the unique feature values present in the data\n",
"            for threshold in possible_thresholds:\n",
"                # get current split\n",
"                dataset_left, dataset_right = self.split(dataset, feature_index, threshold)\n",
"                # check that neither child is empty\n",
"                if len(dataset_left)>0 and len(dataset_right)>0:\n",
"                    # extract the classes of the dataset before the split, as well\n",
"                    # as the classes of the left and right children after the split\n",
"                    # (these arrays are used for computing the information gain)\n",
"                    y, left_y, right_y = dataset[:, -1], dataset_left[:, -1], dataset_right[:, -1]\n",
"                    # compute information gain\n",
"                    curr_info_gain = self.information_gain(y, left_y, right_y, \"gini\")\n",
"                    # update the best split (dictionary) if the current\n",
"                    # information gain is greater than the best one so far\n",
"                    if curr_info_gain>max_info_gain:\n",
"                        best_split[\"feature_index\"] = feature_index\n",
"                        best_split[\"threshold\"] = threshold\n",
"                        best_split[\"dataset_left\"] = dataset_left\n",
"                        best_split[\"dataset_right\"] = dataset_right\n",
"                        best_split[\"info_gain\"] = curr_info_gain\n",
"                        max_info_gain = curr_info_gain\n",
"\n",
"        # return best split (empty dict if no valid split was found)\n",
"        return best_split\n",
"\n",
"    def split(self, dataset, feature_index, threshold):\n",
"        ''' function to split the data '''\n",
"\n",
"        # left side: rows for which the feature value is less than or equal to the threshold\n",
"        dataset_left = np.array([row for row in dataset if row[feature_index]<=threshold])\n",
"        # right side: rows for which the feature value is greater than the threshold\n",
"        dataset_right = np.array([row for row in dataset if row[feature_index]>threshold])\n",
"        return dataset_left, dataset_right\n",
"\n",
"    def information_gain(self, parent, l_child, r_child, mode=\"entropy\"):\n",
"        ''' function to compute information gain '''\n",
"\n",
"        # weights are the fraction of parent samples that went to each child\n",
"        weight_l = len(l_child) / len(parent)\n",
"        weight_r = len(r_child) / len(parent)\n",
"\n",
"        if mode==\"gini\":\n",
"            gain = self.gini_index(parent) - (weight_l*self.gini_index(l_child) + weight_r*self.gini_index(r_child))\n",
"        else:\n",
"            gain = self.entropy(parent) - (weight_l*self.entropy(l_child) + weight_r*self.entropy(r_child))\n",
"\n",
"        # Two ways of measuring the information contained in a system:\n",
"        #   entropy    = ∑ -p_i*log2(p_i)\n",
"        #   gini_index = 1 - ∑ p_i**2,\n",
"        # where p_i = probability of class i.\n",
"        # Why use the gini function? Unlike entropy it has no logarithmic\n",
"        # part, which saves computation time (it is easier to find the\n",
"        # square of a quantity than its logarithm).\n",
"        return gain\n",
"\n",
"    def entropy(self, y):\n",
"        ''' function to compute entropy '''\n",
"\n",
"        class_labels = np.unique(y)\n",
"        entropy = 0\n",
"        for cls in class_labels:\n",
"            p_cls = len(y[y == cls]) / len(y)\n",
"            entropy += -p_cls * np.log2(p_cls)\n",
"        return entropy\n",
"\n",
"    def gini_index(self, y):\n",
"        ''' function to compute gini index '''\n",
"\n",
"        class_labels = np.unique(y)\n",
"        gini = 0\n",
"        for cls in class_labels:\n",
"            p_cls = len(y[y == cls]) / len(y)\n",
"            gini += p_cls**2\n",
"        return 1 - gini\n",
"\n",
"    def calculate_leaf_value(self, Y):\n",
"        ''' function to compute the value of a leaf node '''\n",
"\n",
"        # the value of a leaf node is the majority class present in the node,\n",
"        # so we just need to find the most frequently occurring element in Y\n",
"        Y = list(Y)\n",
"        return max(Y, key=Y.count)\n",
"\n",
"    # THIS FUNCTION HELPS US TO VISUALIZE THE DECISION TREE\n",
"    def print_tree(self, tree=None, indent=\" \"):\n",
"        ''' function to print the tree '''\n",
"\n",
"        if not tree:\n",
"            tree = self.root\n",
"\n",
"        if tree.value is not None:\n",
"            print(tree.value)\n",
"\n",
"        else:\n",
"            print(\"X_\"+str(tree.feature_index), \"≤\", tree.threshold, \"?\", np.round(tree.info_gain,3))\n",
"            print(tree.curr_depth + 1,\":\",\"%sleft: \" % (indent), end=\"\")\n",
"            self.print_tree(tree.left, indent + \" \")\n",
"            print(tree.curr_depth + 1 ,\":\",\"%sright: \" % (indent), end=\"\")\n",
"            self.print_tree(tree.right, indent + \" \")\n",
"\n",
"    def fit(self, X, Y):\n",
"        ''' function to train the tree '''\n",
"\n",
"        # stack features and labels into a single array so that every split\n",
"        # carries its labels along with it (Y must be a column vector)\n",
"        dataset = np.concatenate((X, Y), axis=1)\n",
"        self.root = self.build_tree(dataset)\n",
"\n",
"    def predict(self, X):\n",
"        ''' function to predict a new dataset '''\n",
"\n",
"        predictions = [self.make_prediction(x, self.root) for x in X]\n",
"        return predictions\n",
"\n",
"    def make_prediction(self, x, tree):\n",
"        ''' function to predict a single data point '''\n",
"\n",
"        # leaf node reached: return its majority class\n",
"        if tree.value is not None: return tree.value\n",
"        feature_val = x[tree.feature_index]\n",
"        if feature_val<=tree.threshold:\n",
"            return self.make_prediction(x, tree.left)\n",
"        else:\n",
"            return self.make_prediction(x, tree.right)"
]
},
2022-12-22 10:49:40 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's now apply our model on the data."
]
},
2022-12-22 10:10:56 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train-Test split"
]
},
2022-12-22 10:49:40 +00:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here I am creating the train and test datasets. We will train our model on the train dataset and test it with the test dataset."
]
},
2022-12-22 10:10:56 +00:00
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"X = data.iloc[:, :-1].values\n",
"Y = data.iloc[:, -1].values.reshape(-1,1)\n",
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=.2, random_state=41)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fit the model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
2022-12-22 10:49:40 +00:00
"outputs": [],
"source": [
"classifier = DecisionTreeClassifier(min_samples_split=3, max_depth=3)\n",
"classifier.fit(X_train,Y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model visualization"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
2022-12-22 10:10:56 +00:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-12-22 10:49:40 +00:00
"X_2 ≤ 1.9 ? 0.337\n",
"1 : left: Setosa\n",
"1 : right: X_3 ≤ 1.5 ? 0.427\n",
"2 : left: X_2 ≤ 4.9 ? 0.051\n",
"3 : left: Versicolor\n",
"3 : right: Virginica\n",
"2 : right: X_2 ≤ 5.0 ? 0.02\n",
"3 : left: X_1 ≤ 2.8 ? 0.208\n",
"4 : left: Virginica\n",
"4 : right: Versicolor\n",
"3 : right: Virginica\n"
2022-12-22 10:10:56 +00:00
]
}
],
"source": [
"classifier.print_tree()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-12-22 10:49:40 +00:00
"## Testing the model"
2022-12-22 10:10:56 +00:00
]
},
{
"cell_type": "code",
2022-12-22 10:49:40 +00:00
"execution_count": 8,
2022-12-22 10:10:56 +00:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9333333333333333"
]
},
2022-12-22 10:49:40 +00:00
"execution_count": 8,
2022-12-22 10:10:56 +00:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
2022-12-22 10:49:40 +00:00
"Y_pred = classifier.predict(X_test)\n",
"\n",
2022-12-22 10:10:56 +00:00
"from sklearn.metrics import accuracy_score\n",
"accuracy_score(Y_test, Y_pred)"
]
2022-12-22 10:49:40 +00:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Second data analysis \n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Our objective here is to predict if the customer will purchase the iPhone or not given their gender, age and salary."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Get the data"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"dataset = pd.read_csv(\"iphone_purchase_records.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Gender</th>\n",
" <th>Age</th>\n",
" <th>Salary</th>\n",
" <th>Purchase Iphone</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Male</td>\n",
" <td>19</td>\n",
" <td>19000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Male</td>\n",
" <td>35</td>\n",
" <td>20000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Female</td>\n",
" <td>26</td>\n",
" <td>43000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>Female</td>\n",
" <td>27</td>\n",
" <td>57000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>Male</td>\n",
" <td>19</td>\n",
" <td>76000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>395</th>\n",
" <td>Female</td>\n",
" <td>46</td>\n",
" <td>41000</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>396</th>\n",
" <td>Male</td>\n",
" <td>51</td>\n",
" <td>23000</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>397</th>\n",
" <td>Female</td>\n",
" <td>50</td>\n",
" <td>20000</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>398</th>\n",
" <td>Male</td>\n",
" <td>36</td>\n",
" <td>33000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>399</th>\n",
" <td>Female</td>\n",
" <td>49</td>\n",
" <td>36000</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>400 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" Gender Age Salary Purchase Iphone\n",
"0 Male 19 19000 0\n",
"1 Male 35 20000 0\n",
"2 Female 26 43000 0\n",
"3 Female 27 57000 0\n",
"4 Male 19 76000 0\n",
".. ... ... ... ...\n",
"395 Female 46 41000 1\n",
"396 Male 51 23000 1\n",
"397 Female 50 20000 1\n",
"398 Male 36 33000 0\n",
"399 Female 49 36000 1\n",
"\n",
"[400 rows x 4 columns]"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Converting gender to number"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Gender_Female</th>\n",
" <th>Gender_Male</th>\n",
" <th>Age</th>\n",
" <th>Salary</th>\n",
" <th>Purchase Iphone</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>19</td>\n",
" <td>19000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>35</td>\n",
" <td>20000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>26</td>\n",
" <td>43000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>27</td>\n",
" <td>57000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>19</td>\n",
" <td>76000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>395</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>46</td>\n",
" <td>41000</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>396</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>51</td>\n",
" <td>23000</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>397</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>50</td>\n",
" <td>20000</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>398</th>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>36</td>\n",
" <td>33000</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>399</th>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>49</td>\n",
" <td>36000</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>400 rows × 5 columns</p>\n",
"</div>"
],
"text/plain": [
" Gender_Female Gender_Male Age Salary Purchase Iphone\n",
"0 0 1 19 19000 0\n",
"1 0 1 35 20000 0\n",
"2 1 0 26 43000 0\n",
"3 1 0 27 57000 0\n",
"4 0 1 19 76000 0\n",
".. ... ... ... ... ...\n",
"395 1 0 46 41000 1\n",
"396 0 1 51 23000 1\n",
"397 1 0 50 20000 1\n",
"398 0 1 36 33000 0\n",
"399 1 0 49 36000 1\n",
"\n",
"[400 rows x 5 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Convert the gender variable into dummy/indicator variables (binary variables), essentially 1's and 0's.\n",
"#I chose the variable name one_hot_data because in ML one-hot is a group of bits among which the \n",
"#legal combinations of values are only those with a single high (1) bit and all the others low (0)\n",
"\n",
"one_hot_data = pd.get_dummies(dataset)\n",
"new_cols = [\"Gender_Female\", \"Gender_Male\", \"Age\", \"Salary\",\"Purchase Iphone\"]\n",
"data2 = one_hot_data[new_cols]\n",
"data2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train-Test split"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"X2 = data2.iloc[:, :-1].values\n",
"Y2 = data2.iloc[:, -1].values.reshape(-1,1)\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"X_train2, X_test2, Y_train2, Y_test2 = train_test_split(X2, Y2, test_size=.2, random_state=41)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(320, 80)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(X_train2), len(X_test2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fit the model "
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"classifier2 = DecisionTreeClassifier(min_samples_split=3, max_depth=10)\n",
"classifier2.fit(X_train2,Y_train2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualizing the model "
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X_2 ≤ 42 ? 0.176\n",
"1 : left: X_3 ≤ 90000 ? 0.185\n",
"2 : left: X_2 ≤ 36 ? 0.005\n",
"3 : left: 0\n",
"3 : right: X_3 ≤ 80000 ? 0.049\n",
"4 : left: X_3 ≤ 74000 ? 0.018\n",
"5 : left: X_2 ≤ 37 ? 0.003\n",
"6 : left: X_3 ≤ 62000 ? 0.03\n",
"7 : left: 0\n",
"7 : right: X_3 ≤ 70000 ? 0.375\n",
"8 : left: 1\n",
"8 : right: 0\n",
"6 : right: 0\n",
"5 : right: X_2 ≤ 39 ? 0.022\n",
"6 : left: X_3 ≤ 79000 ? 0.111\n",
"7 : left: X_0 ≤ 0 ? 0.056\n",
"8 : left: X_3 ≤ 77000 ? 0.444\n",
"9 : left: 0\n",
"9 : right: 1\n",
"8 : right: X_3 ≤ 78000 ? 0.444\n",
"9 : left: 1\n",
"9 : right: 0\n",
"7 : right: 0\n",
"6 : right: X_2 ≤ 41 ? 0.052\n",
"7 : left: 0\n",
"7 : right: X_0 ≤ 0 ? 0.111\n",
"8 : left: 0\n",
"8 : right: 1\n",
"4 : right: 1\n",
"2 : right: X_2 ≤ 34 ? 0.043\n",
"3 : left: X_3 ≤ 118000 ? 0.121\n",
"4 : left: X_3 ≤ 112000 ? 0.214\n",
"5 : left: 1\n",
"5 : right: X_3 ≤ 116000 ? 0.122\n",
"6 : left: 0\n",
"6 : right: X_2 ≤ 26 ? 0.167\n",
"7 : left: 0\n",
"7 : right: X_2 ≤ 31 ? 0.111\n",
"8 : left: 1\n",
"8 : right: 0\n",
"4 : right: 1\n",
"3 : right: 1\n",
"1 : right: X_3 ≤ 38000 ? 0.014\n",
"2 : left: X_3 ≤ 22000 ? 0.016\n",
"3 : left: X_2 ≤ 46 ? 0.125\n",
"4 : left: 0\n",
"4 : right: 1\n",
"3 : right: 1\n",
"2 : right: X_2 ≤ 50 ? 0.027\n",
"3 : left: X_3 ≤ 82000 ? 0.052\n",
"4 : left: X_2 ≤ 48 ? 0.042\n",
"5 : left: X_2 ≤ 47 ? 0.069\n",
"6 : left: X_3 ≤ 43000 ? 0.125\n",
"7 : left: 0\n",
"7 : right: X_3 ≤ 79000 ? 0.112\n",
"8 : left: X_3 ≤ 47000 ? 0.037\n",
"9 : left: 1\n",
"9 : right: X_2 ≤ 46 ? 0.12\n",
"10 : left: 0\n",
"10 : right: 1\n",
"8 : right: 0\n",
"6 : right: 1\n",
"5 : right: X_3 ≤ 39000 ? 0.375\n",
"6 : left: 1\n",
"6 : right: 0\n",
"4 : right: X_3 ≤ 139000 ? 0.028\n",
"5 : left: X_2 ≤ 46 ? 0.041\n",
"6 : left: X_0 ≤ 0 ? 0.122\n",
"7 : left: 1\n",
"7 : right: X_2 ≤ 44 ? 0.167\n",
"8 : left: X_2 ≤ 43 ? 0.111\n",
"9 : left: 0\n",
"9 : right: 1\n",
"8 : right: 0\n",
"6 : right: 1\n",
"5 : right: 1\n",
"3 : right: X_0 ≤ 0 ? 0.015\n",
"4 : left: 1\n",
"4 : right: X_3 ≤ 42000 ? 0.085\n",
"5 : left: 0\n",
"5 : right: X_2 ≤ 58 ? 0.02\n",
"6 : left: X_2 ≤ 52 ? 0.027\n",
"7 : left: X_3 ≤ 114000 ? 0.125\n",
"8 : left: 1\n",
"8 : right: 1\n",
"7 : right: 1\n",
"6 : right: X_3 ≤ 76000 ? 0.444\n",
"7 : left: 1\n",
"7 : right: 0\n"
]
}
],
"source": [
"classifier2.print_tree()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Testing the model"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8625"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y_pred2 = classifier2.predict(X_test2)\n",
"\n",
"from sklearn.metrics import accuracy_score\n",
"accuracy_score(Y_test2, Y_pred2)"
]
2022-12-22 10:10:56 +00:00
}
],
"metadata": {
"kernelspec": {
2022-12-22 10:49:40 +00:00
"display_name": "Python [conda env:.conda-firstEnv]",
2022-12-22 10:10:56 +00:00
"language": "python",
2022-12-22 10:49:40 +00:00
"name": "conda-env-.conda-firstEnv-py"
2022-12-22 10:10:56 +00:00
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2022-12-22 10:49:40 +00:00
"version": "3.10.8"
2022-12-22 10:10:56 +00:00
}
},
"nbformat": 4,
"nbformat_minor": 4
}