This commit is contained in:
jwradhe 2024-11-12 21:24:09 +01:00
parent c9389f7b1e
commit ae564c0237
9 changed files with 267022 additions and 64 deletions

View File

@ -7,6 +7,7 @@
First thing to do is to extract TMDB_tv_dataset_v3.zip in dataset folder so that it contains TMDB_tv_dataset_v3.csv. First thing to do is to extract TMDB_tv_dataset_v3.zip in dataset folder so that it contains TMDB_tv_dataset_v3.csv.
**Running program** **Running program**
Start main.py and it will load dataset and ask for a title to get recommendations from, also how many recommendations wanted. Then enter and you will have those recommendations presented on screen. Start main.py and it will load dataset and ask for a title to get recommendations from, also how many recommendations wanted. Then enter and you will have those recommendations presented on screen.

49964
data.csv Normal file

File diff suppressed because it is too large Load Diff

216911
dataset/TMDB_tv_dataset_v3.csv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -3,7 +3,9 @@ import os
import pandas as pd import pandas as pd
############################## Import data ############################## ###############################################################
#### Class: ImportData
###############################################################
class ImportData: class ImportData:
def __init__(self): def __init__(self):
@ -11,7 +13,9 @@ class ImportData:
self.loaded_datasets = [] self.loaded_datasets = []
# ---------------------- Function: load_dataset ---------------------- ###########################################################
#### Function: load_dataset
###########################################################
def load_dataset(self, dataset_path): def load_dataset(self, dataset_path):
# Load data from dataset CSV file # Load data from dataset CSV file
try: try:
@ -22,7 +26,9 @@ class ImportData:
return None return None
# ---------------------- Function: create_data ---------------------- ###########################################################
#### Function: create_data
###########################################################
def create_data(self, filename): def create_data(self, filename):
try: try:
self.data = self.load_dataset(filename) self.data = self.load_dataset(filename)
@ -32,7 +38,9 @@ class ImportData:
return None return None
# ---------------------- Function: clean_data ---------------------- ###########################################################
#### Function: clean_data
###########################################################
def clean_data(self): def clean_data(self):
if self.data is not None: if self.data is not None:
# Drop unnecessary columns # Drop unnecessary columns
@ -53,7 +61,9 @@ class ImportData:
print("No data to clean. Please load the dataset first.") print("No data to clean. Please load the dataset first.")
# ---------------------- Function: save_data ---------------------- ###########################################################
#### Function: save_data
###########################################################
def save_data(self): def save_data(self):
if self.data is not None: if self.data is not None:
try: try:

View File

@ -3,7 +3,9 @@ from trainmodel import TrainModel
from recommendations import RecommendationLoader from recommendations import RecommendationLoader
############################## Main ############################################ #########################################################################
#### function: main
#########################################################################
def main(): def main():

View File

@ -2,14 +2,18 @@ import pandas as pd
from import_data import ImportData from import_data import ImportData
############################## Load data ############################## #########################################################################
#### Class: LoadData
#########################################################################
class LoadData: class LoadData:
def __init__(self): def __init__(self):
self.data = None self.data = None
self.filename = 'TMDB_tv_dataset_v3.csv' self.filename = 'TMDB_tv_dataset_v3.csv'
# ---------------------- Function: load_data ---------------------- ###########################################################
#### Function: load_data
###########################################################
def load_data(self): def load_data(self):
self.read_data() self.read_data()
self.clean_data() self.clean_data()
@ -17,7 +21,9 @@ class LoadData:
return self.data return self.data
# ---------------------- Function: read_data ---------------------- ###########################################################
#### Function: read_data
###########################################################
def read_data(self): def read_data(self):
print("Starting to read data ...") print("Starting to read data ...")
try: try:
@ -38,7 +44,9 @@ class LoadData:
print(f"Error during data import process: {e}") print(f"Error during data import process: {e}")
# ---------------------- Function: clean_data ---------------------- ###########################################################
#### Function: clean_data
###########################################################
def clean_data(self): def clean_data(self):
# Function to split a string into a list, or use an empty list if no valid data # Function to split a string into a list, or use an empty list if no valid data
def split_to_list(value): def split_to_list(value):

View File

@ -3,14 +3,18 @@ import pandas as pd
import textwrap import textwrap
############################## Recommendation loader ############################## ###############################################################
#### Class: RecommendationLoader
###############################################################
class RecommendationLoader: class RecommendationLoader:
def __init__(self, model, title_data): def __init__(self, model, title_data):
self.model = model self.model = model
self.title_data = title_data self.title_data = title_data
# ------------------------ Function: run ------------------------ ###########################################################
#### Function: run
###########################################################
def run(self): def run(self):
while True: while True:
user_data = UserData() user_data = UserData()
@ -37,13 +41,12 @@ class RecommendationLoader:
print("\nWrite 'exit' or 'quit' to end the program.") print("\nWrite 'exit' or 'quit' to end the program.")
# ------------------------ Function: get_recommendations ------------------------ ###########################################################
#### Function: get_recommendations
###########################################################
def get_recommendations(self, target_row, user_data): def get_recommendations(self, target_row, user_data):
recommendations = pd.DataFrame()
n_recommendations = user_data['n_rec'] n_recommendations = user_data['n_rec']
recommendations = self.model.recommend(target_row, user_data['n_rec'])
# I dont want to recommend a title with Reality in it if the reference doesnt have that genre and so on
recommendations = self.filter_genres(recommendations, target_row)
# Get more recommendations and filter untill n_recommendations is reached # Get more recommendations and filter untill n_recommendations is reached
while len(recommendations) < n_recommendations: while len(recommendations) < n_recommendations:
@ -58,7 +61,9 @@ class RecommendationLoader:
self.display_recommendations(user_data, recommendations, n_recommendations, target_row) self.display_recommendations(user_data, recommendations, n_recommendations, target_row)
# ------------------------ Function: display_recommendations ------------------------ ###########################################################
#### Function: display_recommendations
###########################################################
def display_recommendations(self, user_data, recommendations, n_recommendations, target_row): def display_recommendations(self, user_data, recommendations, n_recommendations, target_row):
print(f'\n{n_recommendations} recommendations based on "{user_data["title"]}":\n') print(f'\n{n_recommendations} recommendations based on "{user_data["title"]}":\n')
@ -112,7 +117,9 @@ class RecommendationLoader:
print("No recommendations found.") print("No recommendations found.")
# ------------------------ Function: get_explanation ------------------------ ###########################################################
#### Function: get_explanation
###########################################################
def get_explanation(self, row, target_row): def get_explanation(self, row, target_row):
explanation = [] explanation = []
title = row['name'] title = row['name']
@ -136,7 +143,9 @@ class RecommendationLoader:
return ' '.join(explanation) return ' '.join(explanation)
# ------------------------ Function: check_genre_overlap ------------------------ ###########################################################
#### Function: check_genre_overlap
###########################################################
def check_genre_overlap(self, target_row, row): def check_genre_overlap(self, target_row, row):
# Get genres from the target row # Get genres from the target row
target_genres = set(genre.lower() for genre in target_row['genres']) target_genres = set(genre.lower() for genre in target_row['genres'])
@ -149,7 +158,9 @@ class RecommendationLoader:
return overlap return overlap
# ------------------------ Function: check_created_by_overlap ------------------------ ###########################################################
#### Function: check_created_by_overlap
###########################################################
def check_created_by_overlap(self, target_row, row): def check_created_by_overlap(self, target_row, row):
# Get created_by from the target row # Get created_by from the target row
target_creators = set(creator.lower() for creator in target_row['created_by']) target_creators = set(creator.lower() for creator in target_row['created_by'])
@ -162,7 +173,9 @@ class RecommendationLoader:
return overlap return overlap
# ------------------------ Function: extract_years ------------------------ ###########################################################
#### Function: extract_years
###########################################################
def extract_years(self, air_date): def extract_years(self, air_date):
# Make sure air_date is not null # Make sure air_date is not null
if pd.isna(air_date): if pd.isna(air_date):
@ -173,7 +186,9 @@ class RecommendationLoader:
return air_date.split('-')[0] return air_date.split('-')[0]
# ------------------------ Function: get_recommendations ------------------------ ###########################################################
#### Function: filter_genres
###########################################################
def filter_genres(self, recommendations, target_row): def filter_genres(self, recommendations, target_row):
# Get genres from the target row # Get genres from the target row
reference_genres = [genre.lower() for genre in target_row['genres']] reference_genres = [genre.lower() for genre in target_row['genres']]

View File

@ -1,58 +1,67 @@
from read_data import LoadData
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors from sklearn.neighbors import NearestNeighbors
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import TruncatedSVD from sklearn.decomposition import TruncatedSVD
from scipy.sparse import hstack, csr_matrix from scipy.sparse import hstack, csr_matrix
import numpy as np
import pickle
import time import time
import warnings import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='sklearn') warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')
############################## Train model ############################## #########################################################################
#### Class: TrainModel
#########################################################################
class TrainModel: class TrainModel:
def __init__(self, title_data): def __init__(self, title_data):
self.title_data = title_data self.title_data = title_data
# Settings for vectorization # Initialize Sentence-BERT model for embeddings
self.bert_model = SentenceTransformer('all-MiniLM-L12-v2')
# Settings for TF-IDF Vectorization
self.vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), min_df=0.01, max_df=0.5) self.vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), min_df=0.01, max_df=0.5)
# Settings for nearest neighbors # Settings for Nearest Neighbors
self.model = NearestNeighbors(metric='cosine') self.nearest_neighbors = NearestNeighbors(metric='cosine')
self.scaler = StandardScaler() self.scaler = StandardScaler()
# Settings for SVD # Settings for SVD
self.svd = TruncatedSVD(n_components=300) self.svd = TruncatedSVD(n_components=300)
# ---------------------- Function: train ---------------------- ###########################################################
#### Function: Train
###########################################################
def train(self): def train(self):
print("Starting to train model ...") print("Starting to train model ...")
start = time.time() start = time.time()
# Preprocess title data # Preprocess title data with advanced embeddings included
preproccessed_data = self.preprocess_title_data() preprocessed_data = self.preprocess_title_data()
# Train the NearestNeighbors model # Train Nearest Neighbors on the enhanced feature set
self.model.fit(preproccessed_data) self.nearest_neighbors.fit(preprocessed_data)
stop = time.time() print(f'Trained model successfully in {time.time() - start:.2f} seconds.')
# Count time for training
elapsed_time = stop - start
print(f'Trained model successfully in {elapsed_time:.2f} seconds.')
# ------------------------ Function: recommend ------------------------ ###########################################################
#### Function: get_recommendations
###########################################################
def recommend(self, target_row, num_recommendations=40): def recommend(self, target_row, num_recommendations=40):
# Preprocess target data # Preprocess target data
target_vector = self.preprocess_target_data(target_row) target_vector = self.preprocess_target_data(target_row)
# Use NearestNeighbors model as input to K-nearest neighbors # Use Nearest Neighbors to get recommendations
distances, indices = self.model.kneighbors(target_vector, n_neighbors=num_recommendations) distances, indices = self.nearest_neighbors.kneighbors(target_vector, n_neighbors=num_recommendations)
recommendations = self.title_data.iloc[indices[0]].copy() recommendations = self.title_data.iloc[indices[0]].copy()
recommendations['distance'] = distances[0] recommendations['distance'] = distances[0]
@ -64,44 +73,76 @@ class TrainModel:
return recommendations.head(num_recommendations) return recommendations.head(num_recommendations)
# ---------------------- Function: preprocess_data ---------------------- ###########################################################
#### Function: preprocess_title_data
###########################################################
def preprocess_title_data(self): def preprocess_title_data(self):
# Combine text fields in a new column for vectorization
self.title_data['combined_text'] = ( self.title_data['combined_text'] = (
self.title_data['overview'].fillna('').apply(str) + ' ' + self.title_data['overview'].fillna('').apply(str) + ' ' +
self.title_data['genres'].fillna('').apply(str) + ' ' + self.title_data['genres'].fillna('').apply(str) + ' ' +
self.title_data['created_by'].fillna('').apply(str) self.title_data['created_by'].fillna('').apply(str)
) )
# Process combined_text column with vectorizer # Process text data for TF-IDF + SVD
text_features = self.vectorizer.fit_transform(self.title_data['combined_text']) text_features = self.vectorizer.fit_transform(self.title_data['combined_text'])
text_features = self.svd.fit_transform(text_features) text_features = self.svd.fit_transform(text_features)
# Scale numerical features in the DataFrame using a scaler # Generate Sentence-BERT embeddings
self.numerical_data = self.title_data.select_dtypes(include=['number']) bert_embeddings = self.load_pickle('bert_embeddings.pkl', self.title_data['combined_text'])
# Include ratings in numerical features # Process numerical features
self.numerical_data = self.title_data.select_dtypes(include=['number'])
if 'vote_average' in self.numerical_data.columns: if 'vote_average' in self.numerical_data.columns:
self.numerical_data = self.numerical_data[['vote_average']] self.numerical_data = self.numerical_data[['vote_average']]
# Scale numerical features
numerical_features = self.scaler.fit_transform(self.numerical_data) numerical_features = self.scaler.fit_transform(self.numerical_data)
numerical_features_sparse = csr_matrix(numerical_features) numerical_features_sparse = csr_matrix(numerical_features)
# Combine text and numerical features # Combine all features
combined_features = hstack([csr_matrix(text_features), numerical_features_sparse]) combined_features = hstack([csr_matrix(text_features), csr_matrix(bert_embeddings), numerical_features_sparse])
return combined_features return combined_features
# ---------------------- Function: preprocess_target_data ---------------------- ###########################################################
#### Function: preprocess_target_data
###########################################################
def preprocess_target_data(self, target_row): def preprocess_target_data(self, target_row):
# Create feature vector for target row # Process target text data for TF-IDF + SVD
target_text_vector = self.vectorizer.transform([target_row['combined_text']]) target_text_vector = self.vectorizer.transform([target_row['combined_text']])
target_text_vector = self.svd.transform(target_text_vector) target_text_vector = self.svd.transform(target_text_vector)
# Process numerical features of the referens target # Generate Sentence-BERT embedding for target
target_bert_embedding = self.embed_text(target_row['combined_text']).reshape(1, -1)
# Process numerical features
target_numerical = target_row[self.numerical_data.columns].values.reshape(1, -1) target_numerical = target_row[self.numerical_data.columns].values.reshape(1, -1)
target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(self.scaler.transform(target_numerical))]) target_numerical_scaled = self.scaler.transform(target_numerical)
# Combine all target features
target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(target_bert_embedding), csr_matrix(target_numerical_scaled)])
return target_vector return target_vector
###########################################################
#### Function: load_pickle
###########################################################
def load_pickle(self, filename, title_data):
try:
with open(filename, 'rb') as f:
bert_embeddings = pickle.load(f)
except FileNotFoundError:
bert_embeddings = np.vstack(title_data.apply(self.embed_text).values)
with open(filename, 'wb') as f:
pickle.dump(bert_embeddings, f)
return bert_embeddings
###########################################################
#### Function: embed_text
###########################################################
def embed_text(self, text):
# Use Sentence-BERT to create embeddings
return self.bert_model.encode(text, convert_to_numpy=True)

View File

@ -1,10 +1,14 @@
############################## User input ############################## ###############################################################
#### Class: UserData
###############################################################
class UserData: class UserData:
def __init__(self): def __init__(self):
self.user_data = {} self.user_data = {}
self.n_rec = 10 self.n_rec = 10
# ---------------------- Function: title ---------------------- ###########################################################
#### Function: title
###########################################################
def title(self): def title(self):
# Ask for user input # Ask for user input
print("#" * 100) print("#" * 100)
@ -12,7 +16,9 @@ class UserData:
self.user_data['title'] = title.strip().lower() self.user_data['title'] = title.strip().lower()
return self.user_data return self.user_data
# ---------------------- Function: n_recommendations ---------------------- ###########################################################
#### Function: n_recommendations
###########################################################
def n_recommendations(self): def n_recommendations(self):
# Ask for number of recommendations # Ask for number of recommendations
while True: while True: