Update
This commit is contained in:
parent
c9389f7b1e
commit
ae564c0237
@ -7,6 +7,7 @@
|
|||||||
|
|
||||||
First, extract TMDB_tv_dataset_v3.zip into the dataset folder so that the folder contains TMDB_tv_dataset_v3.csv.
|
First, extract TMDB_tv_dataset_v3.zip into the dataset folder so that the folder contains TMDB_tv_dataset_v3.csv.
|
||||||
|
|
||||||
|
|
||||||
**Running program**
|
**Running program**
|
||||||
|
|
||||||
Run main.py. It loads the dataset, then asks for a title to base the recommendations on and how many recommendations you want. Press Enter and the recommendations are shown on screen.
|
Run main.py. It loads the dataset, then asks for a title to base the recommendations on and how many recommendations you want. Press Enter and the recommendations are shown on screen.
|
||||||
|
|||||||
216911
dataset/TMDB_tv_dataset_v3.csv
Normal file
216911
dataset/TMDB_tv_dataset_v3.csv
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,7 +3,9 @@ import os
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
############################## Import data ##############################
|
###############################################################
|
||||||
|
#### Class: ImportData
|
||||||
|
###############################################################
|
||||||
class ImportData:
|
class ImportData:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -11,7 +13,9 @@ class ImportData:
|
|||||||
self.loaded_datasets = []
|
self.loaded_datasets = []
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: load_dataset ----------------------
|
###########################################################
|
||||||
|
#### Function: load_dataset
|
||||||
|
###########################################################
|
||||||
def load_dataset(self, dataset_path):
|
def load_dataset(self, dataset_path):
|
||||||
# Load data from dataset CSV file
|
# Load data from dataset CSV file
|
||||||
try:
|
try:
|
||||||
@ -22,7 +26,9 @@ class ImportData:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: create_data ----------------------
|
###########################################################
|
||||||
|
#### Function: create_data
|
||||||
|
###########################################################
|
||||||
def create_data(self, filename):
|
def create_data(self, filename):
|
||||||
try:
|
try:
|
||||||
self.data = self.load_dataset(filename)
|
self.data = self.load_dataset(filename)
|
||||||
@ -32,7 +38,9 @@ class ImportData:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: clean_data ----------------------
|
###########################################################
|
||||||
|
#### Function: clean_data
|
||||||
|
###########################################################
|
||||||
def clean_data(self):
|
def clean_data(self):
|
||||||
if self.data is not None:
|
if self.data is not None:
|
||||||
# Drop unnecessary columns
|
# Drop unnecessary columns
|
||||||
@ -53,7 +61,9 @@ class ImportData:
|
|||||||
print("No data to clean. Please load the dataset first.")
|
print("No data to clean. Please load the dataset first.")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: save_data ----------------------
|
###########################################################
|
||||||
|
#### Function: save_data
|
||||||
|
###########################################################
|
||||||
def save_data(self):
|
def save_data(self):
|
||||||
if self.data is not None:
|
if self.data is not None:
|
||||||
try:
|
try:
|
||||||
|
|||||||
4
main.py
4
main.py
@ -3,7 +3,9 @@ from trainmodel import TrainModel
|
|||||||
from recommendations import RecommendationLoader
|
from recommendations import RecommendationLoader
|
||||||
|
|
||||||
|
|
||||||
############################## Main ############################################
|
#########################################################################
|
||||||
|
#### function: main
|
||||||
|
#########################################################################
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
|
|||||||
16
read_data.py
16
read_data.py
@ -2,14 +2,18 @@ import pandas as pd
|
|||||||
from import_data import ImportData
|
from import_data import ImportData
|
||||||
|
|
||||||
|
|
||||||
############################## Load data ##############################
|
#########################################################################
|
||||||
|
#### Class: LoadData
|
||||||
|
#########################################################################
|
||||||
class LoadData:
|
class LoadData:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.data = None
|
self.data = None
|
||||||
self.filename = 'TMDB_tv_dataset_v3.csv'
|
self.filename = 'TMDB_tv_dataset_v3.csv'
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: load_data ----------------------
|
###########################################################
|
||||||
|
#### Function: load_data
|
||||||
|
###########################################################
|
||||||
def load_data(self):
|
def load_data(self):
|
||||||
self.read_data()
|
self.read_data()
|
||||||
self.clean_data()
|
self.clean_data()
|
||||||
@ -17,7 +21,9 @@ class LoadData:
|
|||||||
return self.data
|
return self.data
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: read_data ----------------------
|
###########################################################
|
||||||
|
#### Function: read_data
|
||||||
|
###########################################################
|
||||||
def read_data(self):
|
def read_data(self):
|
||||||
print("Starting to read data ...")
|
print("Starting to read data ...")
|
||||||
try:
|
try:
|
||||||
@ -38,7 +44,9 @@ class LoadData:
|
|||||||
print(f"Error during data import process: {e}")
|
print(f"Error during data import process: {e}")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: clean_data ----------------------
|
###########################################################
|
||||||
|
#### Function: clean_data
|
||||||
|
###########################################################
|
||||||
def clean_data(self):
|
def clean_data(self):
|
||||||
# Function to split a string into a list, or use an empty list if no valid data
|
# Function to split a string into a list, or use an empty list if no valid data
|
||||||
def split_to_list(value):
|
def split_to_list(value):
|
||||||
|
|||||||
@ -3,14 +3,18 @@ import pandas as pd
|
|||||||
import textwrap
|
import textwrap
|
||||||
|
|
||||||
|
|
||||||
############################## Recommendation loader ##############################
|
###############################################################
|
||||||
|
#### Class: RecommendationLoader
|
||||||
|
###############################################################
|
||||||
class RecommendationLoader:
|
class RecommendationLoader:
|
||||||
def __init__(self, model, title_data):
|
def __init__(self, model, title_data):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.title_data = title_data
|
self.title_data = title_data
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: run ------------------------
|
###########################################################
|
||||||
|
#### Function: run
|
||||||
|
###########################################################
|
||||||
def run(self):
|
def run(self):
|
||||||
while True:
|
while True:
|
||||||
user_data = UserData()
|
user_data = UserData()
|
||||||
@ -36,14 +40,13 @@ class RecommendationLoader:
|
|||||||
print("#" * 100)
|
print("#" * 100)
|
||||||
print("\nWrite 'exit' or 'quit' to end the program.")
|
print("\nWrite 'exit' or 'quit' to end the program.")
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: get_recommendations ------------------------
|
###########################################################
|
||||||
|
#### Function: get_recommendations
|
||||||
|
###########################################################
|
||||||
def get_recommendations(self, target_row, user_data):
|
def get_recommendations(self, target_row, user_data):
|
||||||
|
recommendations = pd.DataFrame()
|
||||||
n_recommendations = user_data['n_rec']
|
n_recommendations = user_data['n_rec']
|
||||||
recommendations = self.model.recommend(target_row, user_data['n_rec'])
|
|
||||||
|
|
||||||
# I don't want to recommend a title with the Reality genre if the reference title doesn't have that genre, and so on
|
|
||||||
recommendations = self.filter_genres(recommendations, target_row)
|
|
||||||
|
|
||||||
# Get more recommendations and filter until n_recommendations is reached
|
# Get more recommendations and filter until n_recommendations is reached
|
||||||
while len(recommendations) < n_recommendations:
|
while len(recommendations) < n_recommendations:
|
||||||
@ -58,7 +61,9 @@ class RecommendationLoader:
|
|||||||
self.display_recommendations(user_data, recommendations, n_recommendations, target_row)
|
self.display_recommendations(user_data, recommendations, n_recommendations, target_row)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: display_recommendations ------------------------
|
###########################################################
|
||||||
|
#### Function: display_recommendations
|
||||||
|
###########################################################
|
||||||
def display_recommendations(self, user_data, recommendations, n_recommendations, target_row):
|
def display_recommendations(self, user_data, recommendations, n_recommendations, target_row):
|
||||||
print(f'\n{n_recommendations} recommendations based on "{user_data["title"]}":\n')
|
print(f'\n{n_recommendations} recommendations based on "{user_data["title"]}":\n')
|
||||||
|
|
||||||
@ -112,7 +117,9 @@ class RecommendationLoader:
|
|||||||
print("No recommendations found.")
|
print("No recommendations found.")
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: get_explanation ------------------------
|
###########################################################
|
||||||
|
#### Function: get_explanation
|
||||||
|
###########################################################
|
||||||
def get_explanation(self, row, target_row):
|
def get_explanation(self, row, target_row):
|
||||||
explanation = []
|
explanation = []
|
||||||
title = row['name']
|
title = row['name']
|
||||||
@ -136,7 +143,9 @@ class RecommendationLoader:
|
|||||||
return ' '.join(explanation)
|
return ' '.join(explanation)
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: check_genre_overlap ------------------------
|
###########################################################
|
||||||
|
#### Function: check_genre_overlap
|
||||||
|
###########################################################
|
||||||
def check_genre_overlap(self, target_row, row):
|
def check_genre_overlap(self, target_row, row):
|
||||||
# Get genres from the target row
|
# Get genres from the target row
|
||||||
target_genres = set(genre.lower() for genre in target_row['genres'])
|
target_genres = set(genre.lower() for genre in target_row['genres'])
|
||||||
@ -149,7 +158,9 @@ class RecommendationLoader:
|
|||||||
return overlap
|
return overlap
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: check_created_by_overlap ------------------------
|
###########################################################
|
||||||
|
#### Function: check_created_by_overlap
|
||||||
|
###########################################################
|
||||||
def check_created_by_overlap(self, target_row, row):
|
def check_created_by_overlap(self, target_row, row):
|
||||||
# Get created_by from the target row
|
# Get created_by from the target row
|
||||||
target_creators = set(creator.lower() for creator in target_row['created_by'])
|
target_creators = set(creator.lower() for creator in target_row['created_by'])
|
||||||
@ -162,7 +173,9 @@ class RecommendationLoader:
|
|||||||
return overlap
|
return overlap
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: extract_years ------------------------
|
###########################################################
|
||||||
|
#### Function: extract_years
|
||||||
|
###########################################################
|
||||||
def extract_years(self, air_date):
|
def extract_years(self, air_date):
|
||||||
# Make sure air_date is not null
|
# Make sure air_date is not null
|
||||||
if pd.isna(air_date):
|
if pd.isna(air_date):
|
||||||
@ -173,7 +186,9 @@ class RecommendationLoader:
|
|||||||
return air_date.split('-')[0]
|
return air_date.split('-')[0]
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: get_recommendations ------------------------
|
###########################################################
|
||||||
|
#### Function: filter_genres
|
||||||
|
###########################################################
|
||||||
def filter_genres(self, recommendations, target_row):
|
def filter_genres(self, recommendations, target_row):
|
||||||
# Get genres from the target row
|
# Get genres from the target row
|
||||||
reference_genres = [genre.lower() for genre in target_row['genres']]
|
reference_genres = [genre.lower() for genre in target_row['genres']]
|
||||||
|
|||||||
115
trainmodel.py
115
trainmodel.py
@ -1,58 +1,67 @@
|
|||||||
|
from read_data import LoadData
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
from sklearn.neighbors import NearestNeighbors
|
from sklearn.neighbors import NearestNeighbors
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
from sklearn.decomposition import TruncatedSVD
|
from sklearn.decomposition import TruncatedSVD
|
||||||
from scipy.sparse import hstack, csr_matrix
|
from scipy.sparse import hstack, csr_matrix
|
||||||
|
import numpy as np
|
||||||
|
import pickle
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')
|
warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')
|
||||||
|
|
||||||
|
|
||||||
############################## Train model ##############################
|
#########################################################################
|
||||||
|
#### Class: TrainModel
|
||||||
|
#########################################################################
|
||||||
class TrainModel:
|
class TrainModel:
|
||||||
def __init__(self, title_data):
|
def __init__(self, title_data):
|
||||||
self.title_data = title_data
|
self.title_data = title_data
|
||||||
|
|
||||||
# Settings for vectorization
|
# Initialize Sentence-BERT model for embeddings
|
||||||
|
self.bert_model = SentenceTransformer('all-MiniLM-L12-v2')
|
||||||
|
|
||||||
|
# Settings for TF-IDF Vectorization
|
||||||
self.vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), min_df=0.01, max_df=0.5)
|
self.vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), min_df=0.01, max_df=0.5)
|
||||||
|
|
||||||
# Settings for nearest neighbors
|
# Settings for Nearest Neighbors
|
||||||
self.model = NearestNeighbors(metric='cosine')
|
self.nearest_neighbors = NearestNeighbors(metric='cosine')
|
||||||
self.scaler = StandardScaler()
|
self.scaler = StandardScaler()
|
||||||
|
|
||||||
# Settings for SVD
|
# Settings for SVD
|
||||||
self.svd = TruncatedSVD(n_components=300)
|
self.svd = TruncatedSVD(n_components=300)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: train ----------------------
|
###########################################################
|
||||||
|
#### Function: Train
|
||||||
|
###########################################################
|
||||||
def train(self):
|
def train(self):
|
||||||
print("Starting to train model ...")
|
print("Starting to train model ...")
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
# Preprocess title data
|
# Preprocess title data with advanced embeddings included
|
||||||
preproccessed_data = self.preprocess_title_data()
|
preprocessed_data = self.preprocess_title_data()
|
||||||
|
|
||||||
# Train the NearestNeighbors model
|
# Train Nearest Neighbors on the enhanced feature set
|
||||||
self.model.fit(preproccessed_data)
|
self.nearest_neighbors.fit(preprocessed_data)
|
||||||
|
|
||||||
stop = time.time()
|
print(f'Trained model successfully in {time.time() - start:.2f} seconds.')
|
||||||
|
|
||||||
# Count time for training
|
|
||||||
elapsed_time = stop - start
|
|
||||||
|
|
||||||
print(f'Trained model successfully in {elapsed_time:.2f} seconds.')
|
|
||||||
|
|
||||||
|
|
||||||
# ------------------------ Function: recommend ------------------------
|
###########################################################
|
||||||
|
#### Function: recommend
|
||||||
|
###########################################################
|
||||||
def recommend(self, target_row, num_recommendations=40):
|
def recommend(self, target_row, num_recommendations=40):
|
||||||
|
|
||||||
# Preprocess target data
|
# Preprocess target data
|
||||||
target_vector = self.preprocess_target_data(target_row)
|
target_vector = self.preprocess_target_data(target_row)
|
||||||
|
|
||||||
# Use NearestNeighbors model as input to K-nearest neighbors
|
# Use Nearest Neighbors to get recommendations
|
||||||
distances, indices = self.model.kneighbors(target_vector, n_neighbors=num_recommendations)
|
distances, indices = self.nearest_neighbors.kneighbors(target_vector, n_neighbors=num_recommendations)
|
||||||
recommendations = self.title_data.iloc[indices[0]].copy()
|
recommendations = self.title_data.iloc[indices[0]].copy()
|
||||||
recommendations['distance'] = distances[0]
|
recommendations['distance'] = distances[0]
|
||||||
|
|
||||||
@ -64,44 +73,76 @@ class TrainModel:
|
|||||||
return recommendations.head(num_recommendations)
|
return recommendations.head(num_recommendations)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: preprocess_data ----------------------
|
###########################################################
|
||||||
|
#### Function: preprocess_title_data
|
||||||
|
###########################################################
|
||||||
def preprocess_title_data(self):
|
def preprocess_title_data(self):
|
||||||
# Combine text fields in a new column for vectorization
|
|
||||||
self.title_data['combined_text'] = (
|
self.title_data['combined_text'] = (
|
||||||
self.title_data['overview'].fillna('').apply(str) + ' ' +
|
self.title_data['overview'].fillna('').apply(str) + ' ' +
|
||||||
self.title_data['genres'].fillna('').apply(str) + ' ' +
|
self.title_data['genres'].fillna('').apply(str) + ' ' +
|
||||||
self.title_data['created_by'].fillna('').apply(str)
|
self.title_data['created_by'].fillna('').apply(str)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process combined_text column with vectorizer
|
# Process text data for TF-IDF + SVD
|
||||||
text_features = self.vectorizer.fit_transform(self.title_data['combined_text'])
|
text_features = self.vectorizer.fit_transform(self.title_data['combined_text'])
|
||||||
text_features = self.svd.fit_transform(text_features)
|
text_features = self.svd.fit_transform(text_features)
|
||||||
|
|
||||||
# Scale numerical features in the DataFrame using a scaler
|
|
||||||
self.numerical_data = self.title_data.select_dtypes(include=['number'])
|
|
||||||
|
|
||||||
# Include ratings in numerical features
|
# Generate Sentence-BERT embeddings
|
||||||
|
bert_embeddings = self.load_pickle('bert_embeddings.pkl', self.title_data['combined_text'])
|
||||||
|
|
||||||
|
# Process numerical features
|
||||||
|
self.numerical_data = self.title_data.select_dtypes(include=['number'])
|
||||||
if 'vote_average' in self.numerical_data.columns:
|
if 'vote_average' in self.numerical_data.columns:
|
||||||
self.numerical_data = self.numerical_data[['vote_average']]
|
self.numerical_data = self.numerical_data[['vote_average']]
|
||||||
|
|
||||||
# Scale numerical features
|
|
||||||
numerical_features = self.scaler.fit_transform(self.numerical_data)
|
numerical_features = self.scaler.fit_transform(self.numerical_data)
|
||||||
numerical_features_sparse = csr_matrix(numerical_features)
|
numerical_features_sparse = csr_matrix(numerical_features)
|
||||||
|
|
||||||
# Combine text and numerical features
|
# Combine all features
|
||||||
combined_features = hstack([csr_matrix(text_features), numerical_features_sparse])
|
combined_features = hstack([csr_matrix(text_features), csr_matrix(bert_embeddings), numerical_features_sparse])
|
||||||
|
|
||||||
return combined_features
|
return combined_features
|
||||||
|
|
||||||
|
|
||||||
# ---------------------- Function: preprocess_target_data ----------------------
|
###########################################################
|
||||||
|
#### Function: preprocess_target_data
|
||||||
|
###########################################################
|
||||||
def preprocess_target_data(self, target_row):
|
def preprocess_target_data(self, target_row):
|
||||||
# Create feature vector for target row
|
# Process target text data for TF-IDF + SVD
|
||||||
target_text_vector = self.vectorizer.transform([target_row['combined_text']])
|
target_text_vector = self.vectorizer.transform([target_row['combined_text']])
|
||||||
target_text_vector = self.svd.transform(target_text_vector)
|
target_text_vector = self.svd.transform(target_text_vector)
|
||||||
|
|
||||||
# Process numerical features of the reference target
|
# Generate Sentence-BERT embedding for target
|
||||||
target_numerical = target_row[self.numerical_data.columns].values.reshape(1, -1)
|
target_bert_embedding = self.embed_text(target_row['combined_text']).reshape(1, -1)
|
||||||
target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(self.scaler.transform(target_numerical))])
|
|
||||||
|
|
||||||
return target_vector
|
# Process numerical features
|
||||||
|
target_numerical = target_row[self.numerical_data.columns].values.reshape(1, -1)
|
||||||
|
target_numerical_scaled = self.scaler.transform(target_numerical)
|
||||||
|
|
||||||
|
# Combine all target features
|
||||||
|
target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(target_bert_embedding), csr_matrix(target_numerical_scaled)])
|
||||||
|
|
||||||
|
return target_vector
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
#### Function: load_pickle
|
||||||
|
###########################################################
|
||||||
|
def load_pickle(self, filename, title_data):
    """Return one embedding row per entry of ``title_data``, cached on disk.

    The embeddings are read from the pickle file ``filename`` when that
    file holds a usable cache. Otherwise every entry of ``title_data`` is
    passed through ``self.embed_text``, the results are stacked into one
    array and the cache file is (re)written.

    Args:
        filename: Path of the pickle cache file.
        title_data: Series-like object with ``apply`` and ``len`` whose
            entries are the texts to embed.

    Returns:
        The stacked embeddings, one row per entry of ``title_data``.
    """
    bert_embeddings = None
    try:
        with open(filename, 'rb') as f:
            # NOTE(review): pickle can execute arbitrary code while loading;
            # only point this at cache files written by this program.
            bert_embeddings = pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # Missing, truncated or corrupt cache: fall through and rebuild it.
        bert_embeddings = None

    # A cache built from a different dataset has the wrong number of rows
    # and would make the later feature stacking fail, so treat it as stale.
    cache_is_valid = (
        isinstance(bert_embeddings, np.ndarray)
        and bert_embeddings.ndim == 2
        and bert_embeddings.shape[0] == len(title_data)
    )

    if not cache_is_valid:
        bert_embeddings = np.vstack(title_data.apply(self.embed_text).values)
        with open(filename, 'wb') as f:
            pickle.dump(bert_embeddings, f)

    return bert_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
###########################################################
|
||||||
|
#### Function: embed_text
|
||||||
|
###########################################################
|
||||||
|
def embed_text(self, text):
    """Encode ``text`` with the model stored in ``self.bert_model``.

    The model's ``encode`` method is called with ``convert_to_numpy=True``
    and its result is returned unchanged.
    """
    encoder = self.bert_model
    embedding = encoder.encode(text, convert_to_numpy=True)
    return embedding
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
12
user_data.py
12
user_data.py
@ -1,10 +1,14 @@
|
|||||||
############################## User input ##############################
|
###############################################################
|
||||||
|
#### Class: UserData
|
||||||
|
###############################################################
|
||||||
class UserData:
|
class UserData:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.user_data = {}
|
self.user_data = {}
|
||||||
self.n_rec = 10
|
self.n_rec = 10
|
||||||
|
|
||||||
# ---------------------- Function: title ----------------------
|
###########################################################
|
||||||
|
#### Function: title
|
||||||
|
###########################################################
|
||||||
def title(self):
|
def title(self):
|
||||||
# Ask for user input
|
# Ask for user input
|
||||||
print("#" * 100)
|
print("#" * 100)
|
||||||
@ -12,7 +16,9 @@ class UserData:
|
|||||||
self.user_data['title'] = title.strip().lower()
|
self.user_data['title'] = title.strip().lower()
|
||||||
return self.user_data
|
return self.user_data
|
||||||
|
|
||||||
# ---------------------- Function: n_recommendations ----------------------
|
###########################################################
|
||||||
|
#### Function: n_recommendations
|
||||||
|
###########################################################
|
||||||
def n_recommendations(self):
|
def n_recommendations(self):
|
||||||
# Ask for number of recommendations
|
# Ask for number of recommendations
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user