{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5f7c9658-c285-4854-96c0-e899fc55421b",
   "metadata": {},
   "source": [
    "# DM project: cheese"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f4f2b89-8257-468c-9f5e-a77e11b8b8ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys  # used to print geocoding errors to stderr\n",
    "import time\n",
    "import json\n",
    "import random\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "import tqdm.notebook as tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "from geopy.geocoders import Nominatim\n",
    "from IPython.display import display, HTML\n",
    "from mlxtend.preprocessing import TransactionEncoder\n",
    "from mlxtend.frequent_patterns import apriori, association_rules\n",
    "from sklearn import tree\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.linear_model import LinearRegression"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ceb71784-b0bf-4015-b8e6-78007c368e49",
   "metadata": {},
   "source": [
    "We use the following dataset from Kaggle: [Cheese: 248 different types of cheese with various characteristics](https://www.kaggle.com/datasets/joebeachcapital/cheese)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a0afba8-692b-4377-a2ce-5114983e3bbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(\"cheeses.csv\")\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf3b548c-5ac4-4126-9ae9-5578ad158015",
   "metadata": {},
   "source": [
    "## Cleaning and pre-processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2018aac2-6f3d-489a-b5d0-90b7c7793076",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct color values (missing colors show up as nan)\n",
    "print(set(data[\"color\"]))\n",
    "# Display the cheeses without a color (only the last expression of a cell is displayed)\n",
    "data[data[\"color\"].isna()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0a77563-518e-4808-b744-9fc0c76763fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(data[pd.isnull(data[\"calcium_content\"])]))\n",
    "print(len(data[pd.isnull(data[\"fat_content\"])]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4590cffd-d4a9-4e15-8fd5-cbb22f048300",
   "metadata": {},
   "source": [
    "Since these two columns contain too many missing values, we remove them. We also remove the other columns we are not interested in:"
   ]
  },
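  {
   "cell_type": "markdown",
   "id": "null-fraction-md",
   "metadata": {},
   "source": [
    "A quick way to quantify \"too many missing values\" (an added check, not part of the original analysis) is the fraction of missing values in each column, computed with `DataFrame.isna().mean()`. It has to run before the columns are deleted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "null-fraction-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of missing values in each column, highest first\n",
    "data.isna().mean().sort_values(ascending=False)"
   ]
  },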
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8489ffa-1067-4eb7-b65a-2fa18fdb4b04",
   "metadata": {},
   "outputs": [],
   "source": [
    "unused_columns = [\"alt_spellings\", \"producers\", \"calcium_content\", \"url\", \"fat_content\", \"synonyms\"]\n",
    "for col in unused_columns:\n",
    "    if col in data.columns:\n",
    "        del data[col]\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74044e9b-6ce4-420f-b1ad-492a4362ffb4",
   "metadata": {},
   "source": [
    "Next, we want a single column that represents the location of each cheese."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "633ed80e-e416-41f6-ae58-b86ce4c132af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the rows where both the country and the region are missing\n",
    "data = data.dropna(subset=[\"country\", \"region\"], how=\"all\")\n",
    "# Replace the remaining missing countries/regions with empty strings\n",
    "data = data.fillna(value={\"country\": \"\", \"region\": \"\"})\n",
    "print(f\"{len(data)} rows remaining\")\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd66568f-78d4-4e1a-a91c-8ec483b4b03c",
   "metadata": {},
   "source": [
    "This removes 6 rows that have neither a country nor a region, so no location can be built for them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ef7494b-ff08-40a5-890f-e0f718cf2842",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse the multi-part United Kingdom entries into a single nation\n",
    "data.loc[data.country.str.contains(\"England, Great Britain, United Kingdom\") | data.country.str.contains(\"England, United Kingdom\"), \"country\"] = \"England\"\n",
    "data.loc[data.country.str.contains(\"Scotland\"), \"country\"] = \"Scotland\"\n",
    "data.loc[data.country.str.contains(\"Great Britain, United Kingdom, Wales\") | data.country.str.contains(\"United Kingdom, Wales\"), \"country\"] = \"Wales\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c479661d-4019-4557-8c53-d4223f0f246c",
   "metadata": {},
   "source": [
    "We rewrite these United Kingdom entries as a single nation (England, Scotland or Wales) so that the locations are easier to geocode."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb044984-c33c-492c-91a2-4e9fff29ceb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the cheeses attributed to several countries (lists separated by \",\" or \" and \")\n",
    "data = data.drop(index=data[data[\"country\"].str.contains(\",\")].index)\n",
    "data = data.drop(index=data[data[\"country\"].str.contains(\" and \")].index)\n",
    "# reset_index returns a new DataFrame, so the result must be assigned\n",
    "data = data.reset_index(drop=True)\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2f42c973-247a-4f51-947e-fbd76f8f12fc",
   "metadata": {},
   "source": [
    "We removed 41 cheeses because they are attributed to several countries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59c4e6e7-d624-45a5-a9ea-eb375102b771",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single location string \"region, country\" (the region may be empty)\n",
    "data[\"location\"] = data[\"region\"] + \", \" + data[\"country\"]\n",
    "data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d42869b5-a4ea-4cd6-bd0e-1532af90f2da",
   "metadata": {},
   "source": [
    "### Converting the locations to GPS coordinates\n",
    "\n",
    "To get more numeric features for the classification algorithm, we convert the locations to GPS coordinates (latitude and longitude) and the colors to RGB values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "debb780e-ec13-4502-ac44-6001335e507d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def str_to_gps(loc):\n",
    "    \"\"\"Geocode a \"region, country\" string to (latitude, longitude) with Nominatim.\"\"\"\n",
    "    l = loc.split(\",\")\n",
    "    # Keep only the first and last parts: removing details gives fewer geocoding errors\n",
    "    loc = \",\".join([l[0], l[-1]])\n",
    "    try:\n",
    "        res = Nominatim(user_agent=\"dmProject\").geocode(loc)\n",
    "        return (res.latitude, res.longitude)\n",
    "    except AttributeError:  # geocode returned None: fall back to the country alone\n",
    "        loc = l[-1]\n",
    "        res = Nominatim(user_agent=\"dmProject\").geocode(loc)\n",
    "        return (res.latitude, res.longitude)\n",
    "\n",
    "\n",
    "def get_locations(backup_file):\n",
    "    \"\"\"Map every location of the global set `locs` to GPS coordinates, cached in backup_file.\"\"\"\n",
    "    if os.path.isfile(backup_file):\n",
    "        with open(backup_file) as f:\n",
    "            return json.load(f)\n",
    "    errors = set()\n",
    "    locations_to_gps = {}\n",
    "    for loc in tqdm.tqdm(locs):\n",
    "        time.sleep(1)  # Nominatim's usage policy allows at most 1 request per second\n",
    "        try:\n",
    "            locations_to_gps[loc] = str_to_gps(loc)\n",
    "            print(loc, locations_to_gps[loc])\n",
    "        except AttributeError:  # not found, even with the country alone\n",
    "            errors.add(loc)\n",
    "            print(loc, file=sys.stderr)\n",
    "    with open(backup_file, \"w\") as f:\n",
    "        json.dump(locations_to_gps, f)\n",
    "    return locations_to_gps"
   ]
  },
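  {
   "cell_type": "markdown",
   "id": "ratelimiter-md",
   "metadata": {},
   "source": [
    "geopy also provides `geopy.extra.rate_limiter.RateLimiter`, which wraps a geocoding function, waits `min_delay_seconds` between calls and retries failed requests. The cell below is a sketch of that alternative to the manual `time.sleep(1)` loop, not the code this notebook runs; `geocode_all` is a hypothetical helper name. The last line is commented out because it sends one request per location to the public Nominatim server."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ratelimiter-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from geopy.geocoders import Nominatim\n",
    "from geopy.extra.rate_limiter import RateLimiter\n",
    "\n",
    "\n",
    "def geocode_all(locations, user_agent=\"dmProject\"):\n",
    "    \"\"\"Hypothetical helper: geocode each location string, at most one request per second.\"\"\"\n",
    "    geolocator = Nominatim(user_agent=user_agent)\n",
    "    # Waits min_delay_seconds between calls and retries failed requests up to max_retries times;\n",
    "    # by default a request that still fails returns None instead of raising\n",
    "    geocode = RateLimiter(geolocator.geocode, min_delay_seconds=1, max_retries=2)\n",
    "    coords = {}\n",
    "    for loc in locations:\n",
    "        res = geocode(loc)  # None when nothing is found\n",
    "        if res is not None:\n",
    "            coords[loc] = (res.latitude, res.longitude)\n",
    "    return coords\n",
    "\n",
    "\n",
    "# coords = geocode_all(sorted(locs))  # one network request per location"
   ]
  },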
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "204d1446-e58f-4585-8ac0-7466930e4291",
   "metadata": {},
   "outputs": [],
   "source": [
    "locs = set(data[\"location\"])  # read by get_locations\n",
    "locations_to_gps = get_locations(\"locations_to_gps.json\")\n",
    "latitudes, longitudes = [], []\n",
    "for value in data.location:\n",
    "    # Locations that could not be geocoded get NaN coordinates instead of raising a KeyError\n",
    "    lat, lon = locations_to_gps.get(value, (float(\"nan\"), float(\"nan\")))\n",
    "    latitudes.append(lat)\n",
    "    longitudes.append(lon)\n",
    "data[\"latitude\"] = latitudes\n",
    "data[\"longitude\"] = longitudes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d41b1dc8-90df-44b8-9d83-d218f82a3637",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = px.scatter_map(data,\n",
    "                     lat=\"latitude\",\n",
    "                     lon=\"longitude\",\n",
    "                     hover_name=\"cheese\",\n",
    "                     hover_data=[\"cheese\"],\n",
    "                     color=\"milk\",\n",
    "                     zoom=1.5,\n",
    "                     height=800,\n",
    "                     width=1400)\n",
    "\n",
    "# scatter_map draws on the \"map\" layout, so its style is set with map_style (mapbox_style only affects scatter_mapbox)\n",
    "fig.update_layout(map_style=\"open-street-map\")\n",
    "fig.update_layout(margin={\"r\": 0, \"t\": 0, \"l\": 0, \"b\": 0})\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66ce4e4a-7006-411f-abd0-ee94d7cf99b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_df(df, cols=None):\n",
    "    \"\"\"One-hot encode the comma-separated attribute columns of df.\n",
    "\n",
    "    Every value found in `cols` becomes a boolean column. Values are pooled\n",
    "    across columns because some attributes appear in several of them.\n",
    "    \"\"\"\n",
    "    if cols is None:\n",
    "        cols = [\"milk\", \"color\", \"type\", \"texture\", \"flavor\", \"aroma\", \"family\", \"rind\"]\n",
    "\n",
    "    # Attributes of each row, pooled over all the columns (NaN cells are skipped)\n",
    "    row_attrs = [set() for _ in range(len(df))]\n",
    "    for col in cols:\n",
    "        for i, cell in enumerate(df[col]):\n",
    "            if isinstance(cell, str):\n",
    "                row_attrs[i].update(x.strip() for x in cell.split(\",\"))\n",
    "\n",
    "    # One boolean column per possible attribute, built in a single DataFrame\n",
    "    # (inserting many columns one at a time fragments the DataFrame)\n",
    "    attributes = sorted(set().union(*row_attrs))\n",
    "    encoded = pd.DataFrame(\n",
    "        {attr: [attr in attrs for attrs in row_attrs] for attr in attributes},\n",
    "        index=df.index,\n",
    "    )\n",
    "    return pd.concat([df.drop(columns=cols), encoded], axis=1)"
   ]
  },
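  {
   "cell_type": "markdown",
   "id": "get-dummies-md",
   "metadata": {},
   "source": [
    "pandas can build the same kind of encoding with `Series.str.get_dummies(sep=...)`. The sketch below uses a small made-up DataFrame rather than the cheese data. It encodes each column separately, then merges same-named columns with a maximum (a logical OR), which pools attributes that appear in several columns the way `filter_df` does. `one_hot_pooled` is a hypothetical name, not used elsewhere in the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "get-dummies-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Made-up toy data: comma-separated attributes and one missing value\n",
    "toy = pd.DataFrame({\n",
    "    \"milk\": [\"cow\", \"goat, sheep\", None],\n",
    "    \"texture\": [\"soft\", \"firm\", \"soft, creamy\"],\n",
    "})\n",
    "\n",
    "\n",
    "def one_hot_pooled(df, cols):\n",
    "    parts = []\n",
    "    for col in cols:\n",
    "        # Missing values become \"\", spaces around commas are removed, then one 0/1 column per attribute\n",
    "        cleaned = df[col].fillna(\"\").str.replace(r\"\\s*,\\s*\", \",\", regex=True)\n",
    "        parts.append(cleaned.str.get_dummies(sep=\",\"))\n",
    "    encoded = pd.concat(parts, axis=1)\n",
    "    # Merge same-named columns coming from different source columns (logical OR)\n",
    "    return encoded.T.groupby(level=0).max().T.astype(bool)\n",
    "\n",
    "\n",
    "print(one_hot_pooled(toy, [\"milk\", \"texture\"]))"
   ]
  },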
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda6aaad-7b1e-4daa-8d28-cd049df9cec2",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_features = filter_df(data)\n",
    "data_features"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1b022a3-a2f9-4e39-9e79-48ae9f6adca5",
   "metadata": {},
   "source": [
    "## Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "100a7c2e-2d24-4814-bd68-4b9f6433ce4d",
   "metadata": {},
   "source": [
    "Transform: the color into RGB; the location into GPS coordinates.\n",
    "\n",
    "1. Is the color enough to tell where a cheese comes from?\n",
    "2. Does it work if we add the type?\n",
    "3. What about the taste characteristics?"
   ]
  },
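  {
   "cell_type": "markdown",
   "id": "color-rgb-md",
   "metadata": {},
   "source": [
    "The plan above includes converting colors to RGB, which the notebook does not do yet. A minimal sketch, assuming the color names can be resolved with `matplotlib.colors.to_rgb` (it knows CSS color names such as \"white\" or \"ivory\"); names it does not recognise, and missing values, become NaN. `color_to_rgb` is a hypothetical helper, and keeping only the first color when several are listed is an assumption."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "color-rgb-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from matplotlib.colors import to_rgb\n",
    "\n",
    "\n",
    "def color_to_rgb(name):\n",
    "    \"\"\"Hypothetical helper: map a color name to (r, g, b) in [0, 1], NaN if unknown.\"\"\"\n",
    "    nan_rgb = (math.nan, math.nan, math.nan)\n",
    "    if not isinstance(name, str):  # missing value\n",
    "        return nan_rgb\n",
    "    try:\n",
    "        # If several colors are listed, keep the first one (assumption)\n",
    "        return to_rgb(name.split(\",\")[0].strip())\n",
    "    except ValueError:  # a name matplotlib does not know\n",
    "        return nan_rgb\n",
    "\n",
    "\n",
    "print(color_to_rgb(\"white\"))        # (1.0, 1.0, 1.0)\n",
    "print(color_to_rgb(\"pale cheese\"))  # (nan, nan, nan)\n",
    "# Applied to the notebook data (data still has its color column):\n",
    "# rgb = pd.DataFrame(data[\"color\"].map(color_to_rgb).tolist(), columns=[\"r\", \"g\", \"b\"], index=data.index)"
   ]
  },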
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24e7ff6e-c308-4cc8-aeac-eeb372f4c479",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the features without the region column (drop returns a copy: data_features is not modified)\n",
    "data_features.drop(columns=[\"region\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8d0848f-b844-4a08-976d-4d1370070f73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target: the country, encoded as integers\n",
    "Y = LabelEncoder().fit_transform(data_features[\"country\"])\n",
    "# Features: the one-hot attributes (the location columns are removed since they would reveal the country)\n",
    "X = data_features.drop(columns=[\"cheese\", \"country\", \"region\", \"vegetarian\", \"location\", \"latitude\", \"longitude\"])\n",
    "# Fixed random_state for a reproducible split\n",
    "data_train, data_test, target_train, target_test = train_test_split(X, Y, random_state=0)\n",
    "c = tree.DecisionTreeClassifier(max_depth=4)\n",
    "c = c.fit(data_train, target_train)\n",
    "plt.figure(figsize=(100, 150))\n",
    "ax = plt.subplot()\n",
    "\n",
    "tree.plot_tree(c, ax=ax, filled=True, feature_names=list(X.columns));"
   ]
  },
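  {
   "cell_type": "markdown",
   "id": "tree-eval-md",
   "metadata": {},
   "source": [
    "The cell above only plots the tree. A minimal way to check whether the attributes predict the country is to compare the tree's accuracy on the held-out split with a baseline that always predicts the most frequent country (`sklearn.dummy.DummyClassifier`). This sketch reuses `c`, `data_train`, `data_test`, `target_train` and `target_test` from the cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tree-eval-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.dummy import DummyClassifier\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Accuracy of the decision tree on the test split\n",
    "tree_acc = accuracy_score(target_test, c.predict(data_test))\n",
    "# Baseline: always predict the most frequent country of the training split\n",
    "baseline = DummyClassifier(strategy=\"most_frequent\").fit(data_train, target_train)\n",
    "baseline_acc = accuracy_score(target_test, baseline.predict(data_test))\n",
    "print(f\"decision tree: {tree_acc:.2f}, most-frequent baseline: {baseline_acc:.2f}\")"
   ]
  },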
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6f6934d-c551-4235-920a-e7fe02a75270",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73488360-5ba3-4361-aa1a-7c8764b14acd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05a6fae7-7dae-41f2-add0-86017116ea11",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a0bb4d6-dd0b-451a-b698-3cb6d0b4241d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27b78f7c-acc2-4667-a661-c65caf467a88",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f623dccc-e98d-4902-8c0b-01d3a35ac349",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cb55e0d-99af-447f-bc93-284eb0e46306",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0c6c58d-6259-4c5f-836f-5342423ece8f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0fd63346-fc3a-4074-aae7-c481c14fc009",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "038cd38e-3890-4f73-91a7-c30294b3bc5b",
   "metadata": {},
   "source": [
    "## Pattern Mining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e6b0dc1-030c-4239-803f-52736a41bcb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "unused_columns = {\"vegetarian\", \"vegan\", \"cheese\", \"region\", \"color\", \"location\", \"latitude\", \"longitude\", \"country\"}\n",
    "data_features_only = data_features.drop(columns=list(unused_columns.intersection(data_features.columns)))\n",
    "data_features_only.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b76e8b2f-2efc-43f7-9aa7-fffb960313ad",
   "metadata": {},
   "source": [
    "We have $164$ features in our data, which is a lot compared to the number of rows, and the number of candidate itemsets grows very quickly with the number of features. To keep the frequent itemsets manageable, we use a fairly high `min_support` of $0.1$ in the apriori algorithm: an itemset is kept only if it appears in at least 10% of the cheeses."
   ]
  },
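  {
   "cell_type": "markdown",
   "id": "support-md",
   "metadata": {},
   "source": [
    "For reference, the support of an itemset $X$ over $N$ rows (transactions) is the fraction of rows that contain every item of $X$:\n",
    "\n",
    "$$\\text{support}(X) = \\frac{\\lvert \\lbrace t : X \\subseteq t \\rbrace \\rvert}{N}$$\n",
    "\n",
    "On a one-hot DataFrame this is the mean of the row-wise AND of the item columns, and `apriori(..., min_support=0.1)` returns the itemsets whose support is at least $0.1$. The sketch below checks the formula on a small made-up table (not the cheese data); `support` is a hypothetical helper."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "support-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Made-up one-hot data: 4 rows (transactions), 3 items\n",
    "toy = pd.DataFrame({\n",
    "    \"soft\": [True, True, False, True],\n",
    "    \"cow\": [True, False, False, True],\n",
    "    \"nutty\": [False, True, True, True],\n",
    "})\n",
    "\n",
    "\n",
    "def support(df, itemset):\n",
    "    \"\"\"Fraction of rows in which every item of `itemset` is True.\"\"\"\n",
    "    return df[list(itemset)].all(axis=1).mean()\n",
    "\n",
    "\n",
    "print(support(toy, [\"soft\"]))         # 0.75\n",
    "print(support(toy, [\"soft\", \"cow\"]))  # 0.5"
   ]
  },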
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7113235-7546-4c71-9b34-181472466d20",
   "metadata": {},
   "outputs": [],
   "source": [
    "frequent_itemsets = apriori(data_features_only, min_support=.1, use_colnames=True)\n",
    "display(HTML(frequent_itemsets.to_html()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e13aa0af-ee35-4d39-881f-44f4db54df4b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61959c04-61bf-464a-89ca-72ec4782f927",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rules A -> C whose confidence is at least 0.5 (confidence is also the default metric)\n",
    "assoc_rules = association_rules(frequent_itemsets, metric=\"confidence\", min_threshold=.5)\n",
    "\n",
    "display(HTML(assoc_rules.to_html()))"
   ]
  },
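  {
   "cell_type": "markdown",
   "id": "lift-md",
   "metadata": {},
   "source": [
    "With `metric=\"confidence\"`, a rule $A \\to C$ is kept when $\\text{confidence}(A \\to C) = \\text{support}(A \\cup C) / \\text{support}(A) \\geq 0.5$. A confident rule can still be uninformative when $C$ is frequent on its own; $\\text{lift}(A \\to C) = \\text{confidence}(A \\to C) / \\text{support}(C)$ corrects for that, and lift $> 1$ means $A$ and $C$ occur together more often than if they were independent. A possible next step, sketched below, reuses `assoc_rules` from the cell above and the `antecedents`, `consequents`, `support`, `confidence` and `lift` columns returned by mlxtend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lift-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, HTML\n",
    "\n",
    "# Keep the rules whose sides occur together more often than under independence (lift > 1),\n",
    "# most confident first\n",
    "interesting = assoc_rules[assoc_rules[\"lift\"] > 1].sort_values([\"confidence\", \"lift\"], ascending=False)\n",
    "display(HTML(interesting[[\"antecedents\", \"consequents\", \"support\", \"confidence\", \"lift\"]].to_html()))"
   ]
  },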
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3a2a838-bc56-4de8-ac5d-f1c3327f5447",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "361abbac-54b9-4a7d-ae71-46c17ca9b570",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78ef08e7-1436-440f-b035-8b480af1cc7b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d640609a-1779-47ea-8a5c-17f2daea8700",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc73b5c6-1c6a-4dd7-9e94-94ca8a01627f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fbac311f-c5c5-4951-a7ee-98d5e1f26707",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0365269-ff4b-4bdf-96e8-4141bf7f639e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d58dedb-88ca-4b5a-9fce-299c0d591887",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04dd0063-1de0-403c-b61e-fa422a568727",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}