diff --git a/cheese.ipynb b/cheese.ipynb index df8e66c..8e8de64 100644 --- a/cheese.ipynb +++ b/cheese.ipynb @@ -36,7 +36,10 @@ "from sklearn import tree\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder\n", - "from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet" + "from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "\n", + "from sklearn.metrics import accuracy_score\n" ] }, { @@ -44,7 +47,7 @@ "id": "ceb71784-b0bf-4015-b8e6-78007c368e49", "metadata": {}, "source": [ - "For this project, we chose to study cheeses. We retrieved a [dataset from Kaggle](https://www.kaggle.com/datasets/joebeachcapital/cheese) that gives several characteristics for more than $1000$ cheeses. We have information about the origin, the milk, types, texture, rind, flavor, etc. of these cheeses. " + "For this project, we chose to study cheeses. We retrieved this [dataset from Kaggle](https://www.kaggle.com/datasets/joebeachcapital/cheese) that gives several characteristics for more than $1000$ cheeses. We have information about the origin, the milk, types, texture, rind, flavor, etc. of these cheeses. 
" ] }, { @@ -378,14 +381,6 @@ " return list(c[0] for c in data_colors), list(c[1] for c in data_colors), list(c[2] for c in data_colors)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "471728e0-5543-4afd-bf54-d21bd49dda75", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -415,48 +410,12 @@ "\n" ] }, - { - "cell_type": "markdown", - "id": "da7e65cd-5324-496b-affd-246ae4cf9813", - "metadata": {}, - "source": [ - "### II.A Decision tree" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8d0848f-b844-4a08-976d-4d1370070f73", - "metadata": {}, - "outputs": [], - "source": [ - "Y=LabelEncoder().fit_transform(data_features[\"country\"])\n", - "X=data_features.drop(columns=[\"cheese\",\"country\",\"region\",\"vegetarian\",\"location\",\"latitude\",\"longitude\"])\n", - "data_train, data_test, target_train, target_test = train_test_split(\n", - " X, Y)\n", - "c=tree.DecisionTreeClassifier(max_depth=4,random_state=0)\n", - "c=c.fit(data_train,target_train)\n", - "plt.figure(figsize=(150,100))\n", - "ax=plt.subplot()\n", - "\n", - "tree.plot_tree(c,ax=ax,filled=True,feature_names=X.columns,);" - ] - }, - { - "cell_type": "markdown", - "id": "fca7080e-cb7b-4030-bafd-9036ecdb15ab", - "metadata": {}, - "source": [ - "We built a decision tree for our cheese database. \n", - "We noticed that the most relevant features, those used by the decision tree, focus on the texture of the cheese and the taste on the cheeses (rindless, bloomy, soft, tangy), rather than on the animal milk used. 
\n" - ] - }, { "cell_type": "markdown", "id": "30bf1cd5-9b95-4300-a172-f36d870c49f6", "metadata": {}, "source": [ - "### Linear regression: find location depending on the cheese characteristics\n", + "### II.A Linear regression: find location depending on the cheese characteristics\n", "\n", "We try to do a linear regression over the data to see whether, given a cheese, we can guess where it originates from. We are going to see that it does not work very well, each regression model has a $R^2$ coefficient of less than $0.3$, which is very bad. \n" ] @@ -468,6 +427,7 @@ "metadata": {}, "outputs": [], "source": [ + "old_data_features=data_features.copy()\n", "for col in [\"cheese\",\"country\",\"region\",\"location\",\"vegetarian\",\"vegan\"]:\n", " try: \n", " del data_features[col]\n", @@ -507,11 +467,97 @@ "metadata": {}, "source": [ "Not good, even quite bad. \n", - "We cannot find the region a cheese originates from given its characteristic. \n", - "\n", + "We cannot find the precise place a cheese originates from given its characteristics. \n", + "Can we get the country, at least? We are going to try to achieve this using a `DecisionTree`. 
\n", "\n" ] }, + { + "cell_type": "markdown", + "id": "da7e65cd-5324-496b-affd-246ae4cf9813", + "metadata": {}, + "source": [ + "### II.B Decision tree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8d0848f-b844-4a08-976d-4d1370070f73", + "metadata": {}, + "outputs": [], + "source": [ + "data_features=old_data_features.copy()\n", + "Y=LabelEncoder().fit_transform(data_features[\"country\"])\n", + "X=data_features.drop(columns=[\"cheese\",\"country\",\"region\",\"vegetarian\",\"location\",\"vegan\", \n", + " \"latitude\",\n", + " \"longitude\"\n", + " ])\n", + "print(f\"There are {len(set(data['country']))} different countries in the data\")\n", + "X" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "252e81e1-014c-4f73-86ae-8b149dd433e8", + "metadata": {}, + "outputs": [], + "source": [ + "X_train, X_test, Y_train, Y_test = train_test_split(\n", + " X, Y,random_state=0, test_size=1/10)\n", + "c=KNeighborsClassifier()\n", + "\n", + "c=c.fit(X_train,Y_train)\n", + "\n", + "predY_train=c.predict(X_train)\n", + "predY_test=c.predict(X_test)\n", + "ac_train=accuracy_score(Y_train, predY_train)\n", + "ac_test=accuracy_score(Y_test, predY_test)\n", + "print(f\"{ac_train=},{ac_test=}\")\n" + ] + }, + { + "cell_type": "markdown", + "id": "6414bf4d-d360-4264-a0cb-c3c25178e025", + "metadata": {}, + "source": [ + "Thus, an approach using a K-nearest-neighbor classifier does not seem to work. 
" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "65923e7a-1724-48cb-8a1f-eca64a834cdd", + "metadata": {}, + "outputs": [], + "source": [ + "X_train, X_test, Y_train, Y_test = train_test_split(\n", + " X, Y,random_state=42, test_size=1/10)\n", + "\n", + "c=tree.DecisionTreeClassifier(max_depth=5,random_state=42, criterion=\"log_loss\")\n", + "c=c.fit(X_train,Y_train)\n", + "\n", + "predY_train=c.predict(X_train)\n", + "predY_test=c.predict(X_test)\n", + "ac_train=accuracy_score(Y_train, predY_train)\n", + "ac_test=accuracy_score(Y_test, predY_test)\n", + "print(f\"{ac_train=},{ac_test=}\")\n", + "plt.figure(figsize=(150,100))\n", + "ax=plt.subplot()\n", + "\n", + "tree.plot_tree(c,ax=ax,filled=True,feature_names=list(X.columns));" + ] + }, + { + "cell_type": "markdown", + "id": "178cdfec-dd07-4b6b-983c-d9a6e6960f43", + "metadata": {}, + "source": [ + "We built a decision tree for our cheese database. \n", + "We noticed that the most relevant features, those used by the decision tree, focus on the texture of the cheese and the taste of the cheeses (rindless, bloomy, soft, tangy), rather than on the animal milk used. \n" + ] + }, { "cell_type": "markdown", "id": "038cd38e-3890-4f73-91a7-c30294b3bc5b",