Big update and changing to flask webgui instead
This commit is contained in:
parent 9765677235
commit 46715bff45

185 README.md

@@ -1,102 +1,129 @@
-# Supervised Learning - TV-Show recommender
+# Supervised Learning - TV-Show Recommender
+
+## Table of Contents
+
+1. [How to Run the Program](#how-to-run-the-program)
+2. [Project Overview](#project-overview)
+3. [Dataset](#dataset)
+4. [Model and Algorithm](#model-and-algorithm)
+5. [Features](#features)
+6. [Requirements](#requirements)
+7. [Libraries](#libraries)
+8. [Classes](#classes)
+9. [References](#references)
-## How to run program
-
-**Before running program**
-
-First, extract TMDB_tv_dataset_v3.zip in the dataset folder so that it contains TMDB_tv_dataset_v3.csv.
-
-**Running program**
-
-Start main.py and it will load the dataset, ask for a title to base recommendations on, and ask how many recommendations you want. Press Enter and the recommendations are presented on screen.
-
-> [!NOTE]
-> The first time the program loads, it generates Sentence-BERT embeddings that help it give better recommendations; this can take up to 5 minutes due to the large data file.
+## How to Run the Program
+
+### Prerequisites
+
+1. **Download and Extract the Dataset:**
+   - Download the dataset from [TMDB TV Dataset](https://www.kaggle.com/datasets/asaniczka/full-tmdb-tv-shows-dataset-2023-150k-shows).
+   - Extract `TMDB_tv_dataset_v3.zip` into the `dataset/` folder, so it contains the file `TMDB_tv_dataset_v3.csv`.
+
+2. **Install Dependencies:**
+   - Install the necessary libraries listed in `requirements.txt` (see below).
+
+3. **Run the Program:**
+   - Start the program by running the following command:
+
+   ```bash
+   python main.py
+   ```
+
+   - The program will load the dataset, ask for a TV show title to base recommendations on, and prompt for the number of recommendations.
-## Specification
-
-**TV-Show recommender**
-
-This program recommends which TV show to watch based on what you like.
-You tell it which TV show you like and how many recommendations you want, and you get that
-number of TV-show recommendations, ranked by similarity to your search.
-
-### Data Source:
-
-I will use a dataset from TMDB:
-https://www.kaggle.com/datasets/asaniczka/full-tmdb-tv-shows-dataset-2023-150k-shows
+- **Note:** The first time the program is run, it will generate **Sentence-BERT embeddings**. This can take up to 5 minutes due to the large size of the dataset.
+
+---
+
+## Project Overview
+
+The **TV-Show Recommender** is a machine learning-based program that suggests TV shows to users based on their preferences. The system uses **Nearest Neighbors (NN)** and **K-Nearest Neighbors (K-NN)** algorithms with **cosine distance** to recommend TV shows. Users provide a title of a TV show they like, and the system returns personalized recommendations based on similarity to other TV shows in the dataset.
+
+---
-### Model:
-
-I must first preprocess the data with vectorization so that I can train it with the NearestNeighbors (NN) algorithm using cosine distance, then use NearestNeighbors (NN) in combination with the K-NearestNeighbors (K-NN) algorithm.
-
-### Features:
-
-1. Load data from dataset and preprocessing.
-2. Model training with NN & k-NN algorithm.
-3. User input
-4. Recommendations
+## Dataset
+
+The dataset used in this project is sourced from **TMDB** (The Movie Database). It contains over 150,000 TV shows and includes information such as:
-### Requirements:
-
-1. Title data:
-   * Title
-   * Genres
-   * First/last air date
-   * Vote count/average
-   * Director
-   * Description
-   * Networks
-   * Spoken languages
-   * Number of seasons/episodes
-
-2. User data:
-   * What Movie / TV-Show prefers
-   * Number of recommendations wanted
-
-### Libraries
+- Title of TV shows
+- Genres
+- First/Last air date
+- Vote count and average rating
+- Director/Creator information
+- Overview/Description
+- Networks
+- Spoken languages
+- Number of seasons/episodes
+
+Download the dataset from [here](https://www.kaggle.com/datasets/asaniczka/full-tmdb-tv-shows-dataset-2023-150k-shows).
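The extracted CSV can be sanity-checked with pandas before training. A minimal sketch, where the tiny inline frame stands in for `dataset/TMDB_tv_dataset_v3.csv` and the column names (`name`, `genres`, `vote_average`) are assumptions based on the column list above:

```python
import pandas as pd

# Stand-in for pd.read_csv("dataset/TMDB_tv_dataset_v3.csv");
# column names are assumptions based on the README's column list.
shows = pd.DataFrame({
    "name": ["Breaking Bad", "Chernobyl", "Friends"],
    "genres": ["Drama, Crime", "Drama", "Comedy"],
    "vote_average": [8.9, 8.7, 8.4],
})

# Basic checks: required columns present, ratings within TMDB's 0-10 scale
required = {"name", "genres", "vote_average"}
assert required.issubset(shows.columns)
assert shows["vote_average"].between(0, 10).all()

# Case-insensitive title lookup, as the web app does later
match = shows[shows["name"].str.lower() == "chernobyl"]
print(match.iloc[0]["genres"])
```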
-* pandas: Data manipulation and analysis
-* scikit-learn: machine learning algorithms and preprocessing
-* scipy: A scientific computing package for Python
-* time: provides various functions for working with time
-* os: functions for interacting with the operating system
-* re: provides regular expression support
-* textwrap: Text wrapping and filling
-### Classes
-
-1. LoadData
-   * load_data
-   * read_data
-   * clean_data
-2. ImportData
-   * load_dataset
-   * create_data
-   * clean_data
-   * save_data
-3. TrainModel
-   * train
-   * recommend
-   * preprocess_title_data
-   * preprocess_target_data
-4. UserData
-   * input
-   * n_recommendations
-5. RecommendationLoader
-   * run
-   * get_recommendations
-   * display_recommendations
-   * get_explanation
-   * check_genre_overlap
-   * check_created_by_overlap
-   * extract_years
-   * filter_genres
-
-### References
+
+---
-* https://scikit-learn.org/dev/modules/generated/sklearn.neighbors.NearestNeighbors.html
-* https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html
-* https://scikit-learn.org/dev/modules/generated/sklearn.preprocessing.StandardScaler.html
-* https://scikit-learn.org/0.16/modules/generated/sklearn.decomposition.TruncatedSVD.html
-* https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.hstack.html
-* https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
+
+## Model and Algorithm
+
+The recommender system is based on **Supervised Learning** using the **NearestNeighbors** and **K-NearestNeighbors** algorithms. Here's a breakdown of the process:
+
+1. **Data Preprocessing:**
+   - The TV show descriptions are vectorized using **Sentence-BERT embeddings** to create dense vector representations of each show's description.
+
+2. **Model Training:**
+   - The **NearestNeighbors (NN)** algorithm is used with **cosine distance** to compute similarity between TV shows. The algorithm finds the most similar shows to a user-provided title.
+
+3. **Recommendation Generation:**
+   - The model generates a list of recommended TV shows by finding the nearest neighbors of the input title using cosine similarity.
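The NN-with-cosine-distance step can be sketched with scikit-learn on toy data. The three overviews below are invented for illustration; the real pipeline fits on the combined TF-IDF/SBERT/numeric feature matrix instead:

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors

# Toy stand-ins for TV-show overviews (invented for illustration)
overviews = [
    "a chemistry teacher turns to crime drama",
    "a crime drama about a teacher breaking bad",
    "friends share an apartment in a comedy",
]

# Vectorize the text, then fit NearestNeighbors with cosine distance,
# mirroring NearestNeighbors(metric='cosine') used in training.py
vectors = TfidfVectorizer().fit_transform(overviews)
nn = NearestNeighbors(metric="cosine").fit(vectors)

# Query: neighbors of show 0, nearest first (itself, at distance ~0)
distances, indices = nn.kneighbors(vectors[0], n_neighbors=3)
print(indices[0])
```

Show 0 ranks itself first, then show 1 (shared "crime", "drama", "teacher" terms), and the comedy last.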
+
+---
+
+## Features
+
+1. **Data Loading & Preprocessing:**
+   - Loads the TV show data from a CSV file and preprocesses it for model training.
+
+2. **Model Training with K-NN:**
+   - Trains a K-NN model using the **NearestNeighbors** algorithm for generating recommendations.
+
+3. **User Input for Recommendations:**
+   - Accepts user input for the TV show title and the number of recommendations.
+
+4. **TV Show Recommendations:**
+   - Returns a list of recommended TV shows based on similarity to the input TV show.
+
+---
+
+## Requirements
+
+### Data Requirements:
+
+The dataset should contain the following columns for each TV show:
+
+- **Title**
+- **Genres**
+- **First/Last air date**
+- **Vote count/average**
+- **Director**
+- **Overview**
+- **Networks**
+- **Spoken languages**
+- **Number of seasons/episodes**
+
+### User Input Requirements:
+
+- **TV Show Title**: The name of the TV show you like.
+- **Number of Recommendations**: The number of recommendations you want to receive (default is 10).
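The input rules above (a non-empty title, 1-50 recommendations, default 10) can be captured in a small helper. A sketch; the function name is ours, not from the repository:

```python
def parse_user_input(title, n_recommendations="10"):
    """Validate form input per the README's rules (hypothetical helper)."""
    title = (title or "").strip()
    if not title:
        raise ValueError("Please enter a valid TV show title.")
    try:
        n = int(n_recommendations)
    except (TypeError, ValueError):
        raise ValueError("Number of recommendations must be a whole number.")
    if not 1 <= n <= 50:
        raise ValueError("Number of recommendations must be between 1 and 50.")
    return {"title": title, "n_rec": n}

print(parse_user_input("Breaking Bad"))  # defaults to 10 recommendations
```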
+
+---
+
+## Libraries
+
+The following libraries are required to run the program:
+
+- **pandas**: For data manipulation and analysis.
+- **scikit-learn**: For machine learning algorithms and preprocessing.
+- **scipy**: For scientific computing (e.g., sparse matrices).
+- **time**: For working with time-related functions.
+- **os**: For interacting with the operating system.
+- **re**: For regular expression support.
+- **textwrap**: For text wrapping and formatting.
+- **flask**: For creating the web interface.
+
+To install the dependencies, run:
+
+```bash
+pip install -r requirements.txt
+```
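The repository's `requirements.txt` is not shown in this diff; a plausible minimal version, given the libraries listed above plus the third-party packages imported in `training.py` (the exact entries and any pins are assumptions):

```text
pandas
scikit-learn
scipy
flask
sentence-transformers
```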

84 app.py Normal file

@@ -0,0 +1,84 @@
+from flask import Flask, render_template, request
+from readdata import LoadData
+from recommendations import RecommendationLoader
+from training import TrainModel
+
+app = Flask(__name__)
+
+data_loader = LoadData()
+title_data = data_loader.load_data()
+
+model = TrainModel(title_data)
+model.train()
+
+recommender = RecommendationLoader(model, title_data)
+
+
+@app.route('/')
+def home():
+    return render_template('index.html')
+
+
+@app.route('/recommend', methods=['POST'])
+def recommend():
+    # Get user input (default to '' so a missing field cannot crash .strip())
+    title = request.form.get('title', '').strip()
+    n_recommendations = request.form.get('n_recommendations', 10)
+
+    # Validate user input
+    if not title:
+        return render_template('index.html', message="Please enter a valid TV show title.")
+
+    try:
+        n_recommendations = int(n_recommendations)
+        if n_recommendations < 1 or n_recommendations > 50:
+            raise ValueError("Number of recommendations must be between 1 and 50.")
+    except ValueError as e:
+        return render_template('index.html', message=str(e))
+
+    # Look up the requested title in the dataset (case-insensitive)
+    target_row = title_data[title_data['name'].str.lower() == title.lower()]
+
+    # Check if a match was found
+    if target_row.empty:
+        return render_template('index.html', message=f"No match found for '{title}'. Try again.")
+
+    # Get recommendations
+    target_row = target_row.iloc[0]
+    user_data = {'title': title, 'n_rec': n_recommendations}
+    recommendations = recommender.get_recommendations("flask", target_row, user_data)
+
+    # Check if recommendations were found
+    if recommendations is None or recommendations.empty:
+        return render_template('index.html', message=f"Sorry, no recommendations available for {title}.")
+
+    # Prepare data for display on the webpage
+    recommendations_data = []
+
+    for _, row in recommendations.iterrows():
+        # Extract the first and last air dates
+        first_air_date = recommender.extract_years(row['first_air_date'])
+        last_air_date = recommender.extract_years(row['last_air_date'])
+        if last_air_date != "Ongoing" and last_air_date:
+            years = f"{first_air_date} - {last_air_date}"
+        else:
+            years = f"{first_air_date}"
+
+        recommendations_data.append({
+            'title': row['name'],
+            'genres': ', '.join(row['genres']) if isinstance(row['genres'], list) else row['genres'],
+            'overview': row['overview'],
+            'rating': row['vote_average'],
+            'seasons': row['number_of_seasons'],
+            'episodes': row['number_of_episodes'],
+            'networks': ', '.join(row['networks']) if isinstance(row['networks'], list) and row['networks'] else 'N/A',
+            'years': years,
+        })
+
+    return render_template('index.html', recommendations=recommendations_data, original_title=title)
+
+
+if __name__ == '__main__':
+    app.run(debug=True)
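Routes like the one above can be exercised without starting a server via Flask's built-in test client. A sketch against a stub app (not the real `app.py`, which needs the dataset and model loaded first):

```python
from flask import Flask, request

# Stub app mirroring the /recommend route's validation (illustrative only)
app = Flask(__name__)

@app.route("/recommend", methods=["POST"])
def recommend():
    title = request.form.get("title", "").strip()
    if not title:
        return "Please enter a valid TV show title.", 400
    return f"Recommendations for {title}", 200

# The test client issues requests in-process, no server needed
client = app.test_client()
resp = client.post("/recommend", data={"title": "Dark"})
print(resp.status_code, resp.get_data(as_text=True))
```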
File diff suppressed because it is too large
Load Diff

4 main.py

@@ -1,5 +1,5 @@
-from read_data import LoadData
-from trainmodel import TrainModel
+from readdata import LoadData
+from training import TrainModel
 from recommendations import RecommendationLoader

recommendations.py

@@ -1,4 +1,4 @@
-from user_data import UserData
+from user import UserData
 import pandas as pd
 import textwrap

@@ -44,7 +44,7 @@ class RecommendationLoader:
     ###########################################################
     #### Function: get_recommendations
     ###########################################################
-    def get_recommendations(self, target_row, user_data):
+    def get_recommendations(self, type, target_row, user_data):
         recommendations = pd.DataFrame()
         n_recommendations = user_data['n_rec']

@@ -58,6 +58,9 @@ class RecommendationLoader:
         # Make sure we give n_recommendations recommendations
         recommendations = recommendations.head(n_recommendations)

-        self.display_recommendations(user_data, recommendations, n_recommendations, target_row)
+        if type == 'flask':
+            return recommendations
+        else:
+            self.display_recommendations(user_data, recommendations, n_recommendations, target_row)

@@ -106,73 +109,12 @@ class RecommendationLoader:
             print(f"Seasons: {seasons} ({episodes} episodes)")
             print(f'\n{overview}\n')

-            # Get explanation for recommendation
-            explanation = self.get_explanation(row, target_row)
-            print(f"{explanation}\n")

             print("-" * width)

         print("\nEnd of recommendations.")
     else:
         print("No recommendations found.")

-    ###########################################################
-    #### Function: get_explanation
-    ###########################################################
-    def get_explanation(self, row, target_row):
-        explanation = []
-        title = row['name']
-
-        explanation.append(f"The title '{title}' was recommended because: \n")
-
-        # Explain genre overlap
-        genre_overlap = self.check_genre_overlap(target_row, row)
-        if genre_overlap:
-            overlapping_genres = ', '.join(genre_overlap)
-            explanation.append(f"It shares the following genres with your preferences: {overlapping_genres}.\n")
-
-        # Explain created_by overlap
-        created_by_overlap = self.check_created_by_overlap(target_row, row)
-        if created_by_overlap:
-            overlapping_created_by = ', '.join(created_by_overlap)
-            explanation.append(f"It shares the following director with your preferences: {overlapping_created_by}.\n")
-
-        # Explain the distance metric
-        explanation.append(f"The distance metric of {round(row['distance'], 2)} indicates that it is quite similar to your preferences.")
-        return ' '.join(explanation)
-
-    ###########################################################
-    #### Function: check_genre_overlap
-    ###########################################################
-    def check_genre_overlap(self, target_row, row):
-        # Get genres from the target row
-        target_genres = set(genre.lower() for genre in target_row['genres'])
-        # Get genres from the recommended row
-        recommended_genres = set(genre.lower() for genre in row['genres'])
-
-        # Find the intersection of the target genres and recommended genres
-        overlap = target_genres.intersection(recommended_genres)
-
-        return overlap
-
-    ###########################################################
-    #### Function: check_created_by_overlap
-    ###########################################################
-    def check_created_by_overlap(self, target_row, row):
-        # Get created_by from the target row
-        target_creators = set(creator.lower() for creator in target_row['created_by'])
-        # Get created_by from the recommended row
-        recommended_creators = set(creator.lower() for creator in row['created_by'])
-
-        # Find the intersection of the target creators and recommended creators
-        overlap = target_creators.intersection(recommended_creators)
-
-        return overlap

     ###########################################################
     #### Function: extract_years
     ###########################################################

@@ -190,6 +132,7 @@ class RecommendationLoader:
     #### Function: filter_genres
     ###########################################################
     def filter_genres(self, recommendations, target_row):

         # Get genres from the target row
         reference_genres = [genre.lower() for genre in target_row['genres']]

92 templates/index.html Normal file

@@ -0,0 +1,92 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Recommendation System</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            margin: 0;
+            padding: 20px;
+            background-color: #f4f4f9;
+        }
+        .container {
+            max-width: 1200px;
+            margin: 0 auto;
+            padding: 20px;
+        }
+        h1 {
+            text-align: center;
+            color: #333;
+        }
+        form {
+            text-align: center;
+            margin-bottom: 20px;
+        }
+        .recommendation {
+            background: #fff;
+            padding: 15px;
+            border-radius: 8px;
+            margin-bottom: 20px;
+            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+        }
+        .recommendation h3 {
+            margin-top: 0;
+        }
+        .recommendation p {
+            margin: 5px 0;
+        }
+        .recommendation .overview {
+            font-size: 14px;
+            color: #555;
+        }
+        .error-message {
+            color: red;
+            text-align: center;
+            font-weight: bold;
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1>TV-Show Recommendations</h1>
+
+        <!-- Recommendation Form -->
+        <form method="POST" action="/recommend">
+            <label for="title">Enter a Title (TV Show):</label><br><br>
+            <input type="text" id="title" name="title" required><br><br>
+
+            <label for="n_recommendations">Number of Recommendations:</label><br><br>
+            <input type="number" id="n_recommendations" name="n_recommendations" value="10" min="1" max="50"><br><br>
+
+            <input type="submit" value="Get Recommendations">
+        </form>
+
+        <!-- Display Error Message if any -->
+        {% if message %}
+        <div class="error-message">{{ message }}</div>
+        {% endif %}
+
+        <!-- Display Recommendations -->
+        {% if recommendations %}
+        <h2>Recommendations based on "{{ original_title }}":</h2>
+        <div class="recommendations">
+            {% for rec in recommendations %}
+            <div class="recommendation">
+                <h3>{{ rec.title }} ({{ rec.years }})</h3>
+                <p><strong>Genres:</strong> {{ rec.genres }}</p>
+                <p><strong>Networks:</strong> {{ rec.networks }}</p>
+                <p><strong>Rating:</strong> {{ rec.rating }}</p>
+                <p><strong>Seasons:</strong> {{ rec.seasons }} ({{ rec.episodes }} episodes)</p>
+                <p class="overview"><strong>Overview:</strong> {{ rec.overview }}</p>
+            </div>
+            {% endfor %}
+        </div>
+        {% endif %}
+    </div>
+</body>
+</html>
training.py

@@ -1,15 +1,12 @@
-from read_data import LoadData
 from sentence_transformers import SentenceTransformer
 from sklearn.neighbors import NearestNeighbors
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.preprocessing import StandardScaler
 from sklearn.decomposition import TruncatedSVD
 from scipy.sparse import hstack, csr_matrix
-import numpy as np
 import pickle
 import time

 import warnings
 warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')

@@ -24,14 +21,16 @@ class TrainModel:
         # Initialize Sentence-BERT model for embeddings
         self.bert_model = SentenceTransformer('all-MiniLM-L12-v2')

-        # Settings for TF-IDF Vectorization
+        # TF-IDF Vectorization settings
         self.vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), min_df=0.01, max_df=0.5)

-        # Settings for Nearest Neighbors
+        # Nearest Neighbors settings
         self.nearest_neighbors = NearestNeighbors(metric='cosine')

+        # Scaler for numerical features
         self.scaler = StandardScaler()

-        # Settings for SVD
+        # SVD for dimensionality reduction
         self.svd = TruncatedSVD(n_components=300)

@@ -39,7 +38,7 @@ class TrainModel:
     #### Function: Train
     ###########################################################
     def train(self):
-        print("Starting to train model ...")
+        print("Starting to train model...")

         start = time.time()

@@ -53,10 +52,9 @@ class TrainModel:
     ###########################################################
-    #### Function: get_recommendations
+    #### Function: Recommend
     ###########################################################
     def recommend(self, target_row, num_recommendations=40):

         # Preprocess target data
         target_vector = self.preprocess_target_data(target_row)

@@ -77,29 +75,28 @@ class TrainModel:
     #### Function: preprocess_title_data
     ###########################################################
     def preprocess_title_data(self):
+        # Combine text fields for TF-IDF and BERT
         self.title_data['combined_text'] = (
             self.title_data['overview'].fillna('').apply(str) + ' ' +
             self.title_data['genres'].fillna('').apply(str) + ' ' +
             self.title_data['created_by'].fillna('').apply(str)
         )

-        # Process text data for TF-IDF + SVD
+        # TF-IDF + SVD
         text_features = self.vectorizer.fit_transform(self.title_data['combined_text'])
         text_features = self.svd.fit_transform(text_features)

-        # Generate Sentence-BERT embeddings
+        # Sentence-BERT embeddings
         bert_embeddings = self.load_pickle('bert_embeddings.pkl', self.title_data['combined_text'])

-        # Process numerical features
+        # Numerical features
         self.numerical_data = self.title_data.select_dtypes(include=['number'])
-        if 'vote_average' in self.numerical_data.columns:
-            self.numerical_data = self.numerical_data[['vote_average']]
         numerical_features = self.scaler.fit_transform(self.numerical_data)
         numerical_features_sparse = csr_matrix(numerical_features)

         # Combine all features
-        combined_features = hstack([csr_matrix(text_features), csr_matrix(bert_embeddings), numerical_features_sparse])
+        combined_features = hstack([csr_matrix(text_features), csr_matrix(bert_embeddings),
+                                    numerical_features_sparse])
         return combined_features

@@ -107,37 +104,23 @@ class TrainModel:
     #### Function: preprocess_target_data
     ###########################################################
     def preprocess_target_data(self, target_row):
-        # Process target text data for TF-IDF + SVD
+        # TF-IDF + SVD
         target_text_vector = self.vectorizer.transform([target_row['combined_text']])
         target_text_vector = self.svd.transform(target_text_vector)

-        # Generate Sentence-BERT embedding for target
+        # Sentence-BERT embedding
         target_bert_embedding = self.embed_text(target_row['combined_text']).reshape(1, -1)

-        # Process numerical features
+        # Numerical features
         target_numerical = target_row[self.numerical_data.columns].values.reshape(1, -1)
         target_numerical_scaled = self.scaler.transform(target_numerical)

-        # Combine all target features
-        target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(target_bert_embedding), csr_matrix(target_numerical_scaled)])
+        # Combine all features
+        target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(target_bert_embedding),
+                                csr_matrix(target_numerical_scaled)])
         return target_vector

-    ###########################################################
-    #### Function: load_pickle
-    ###########################################################
-    def load_pickle(self, filename, title_data):
-        try:
-            with open(filename, 'rb') as f:
-                bert_embeddings = pickle.load(f)
-        except FileNotFoundError:
-            bert_embeddings = np.vstack(title_data.apply(self.embed_text).values)
-            with open(filename, 'wb') as f:
-                pickle.dump(bert_embeddings, f)
-        return bert_embeddings

     ###########################################################
     #### Function: embed_text
     ###########################################################

@@ -146,3 +129,18 @@ class TrainModel:
         return self.bert_model.encode(text, convert_to_numpy=True)

+    ###########################################################
+    #### Function: load_pickle
+    ###########################################################
+    def load_pickle(self, filename, title_data):
+        try:
+            with open(filename, 'rb') as f:
+                bert_embeddings = pickle.load(f)
+        except FileNotFoundError:
+            print("Generating Sentence-BERT embeddings...")
+            bert_embeddings = self.bert_model.encode(title_data.tolist(), batch_size=64, convert_to_numpy=True)
+            with open(filename, 'wb') as f:
+                pickle.dump(bert_embeddings, f)
+        return bert_embeddings