This commit is contained in:
jwradhe 2024-11-12 21:24:09 +01:00
parent c9389f7b1e
commit ae564c0237
9 changed files with 267022 additions and 64 deletions

View File

@ -7,6 +7,7 @@
The first thing to do is to extract TMDB_tv_dataset_v3.zip in the dataset folder so that it contains TMDB_tv_dataset_v3.csv.
**Running program**
Start main.py and it will load the dataset and ask for a title to get recommendations from, as well as how many recommendations are wanted. Then press Enter and those recommendations will be presented on screen.

49964
data.csv Normal file

File diff suppressed because it is too large Load Diff

216911
dataset/TMDB_tv_dataset_v3.csv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -3,7 +3,9 @@ import os
import pandas as pd
############################## Import data ##############################
###############################################################
#### Class: ImportData
###############################################################
class ImportData:
def __init__(self):
@ -11,7 +13,9 @@ class ImportData:
self.loaded_datasets = []
# ---------------------- Function: load_dataset ----------------------
###########################################################
#### Function: load_dataset
###########################################################
def load_dataset(self, dataset_path):
# Load data from dataset CSV file
try:
@ -22,7 +26,9 @@ class ImportData:
return None
# ---------------------- Function: create_data ----------------------
###########################################################
#### Function: create_data
###########################################################
def create_data(self, filename):
try:
self.data = self.load_dataset(filename)
@ -32,7 +38,9 @@ class ImportData:
return None
# ---------------------- Function: clean_data ----------------------
###########################################################
#### Function: clean_data
###########################################################
def clean_data(self):
if self.data is not None:
# Drop unnecessary columns
@ -53,7 +61,9 @@ class ImportData:
print("No data to clean. Please load the dataset first.")
# ---------------------- Function: save_data ----------------------
###########################################################
#### Function: save_data
###########################################################
def save_data(self):
if self.data is not None:
try:

View File

@ -3,7 +3,9 @@ from trainmodel import TrainModel
from recommendations import RecommendationLoader
############################## Main ############################################
#########################################################################
#### Function: main
#########################################################################
def main():

View File

@ -2,14 +2,18 @@ import pandas as pd
from import_data import ImportData
############################## Load data ##############################
#########################################################################
#### Class: LoadData
#########################################################################
class LoadData:
def __init__(self):
self.data = None
self.filename = 'TMDB_tv_dataset_v3.csv'
# ---------------------- Function: load_data ----------------------
###########################################################
#### Function: load_data
###########################################################
def load_data(self):
self.read_data()
self.clean_data()
@ -17,7 +21,9 @@ class LoadData:
return self.data
# ---------------------- Function: read_data ----------------------
###########################################################
#### Function: read_data
###########################################################
def read_data(self):
print("Starting to read data ...")
try:
@ -38,7 +44,9 @@ class LoadData:
print(f"Error during data import process: {e}")
# ---------------------- Function: clean_data ----------------------
###########################################################
#### Function: clean_data
###########################################################
def clean_data(self):
# Function to split a string into a list, or use an empty list if no valid data
def split_to_list(value):

View File

@ -3,14 +3,18 @@ import pandas as pd
import textwrap
############################## Recommendation loader ##############################
###############################################################
#### Class: RecommendationLoader
###############################################################
class RecommendationLoader:
def __init__(self, model, title_data):
self.model = model
self.title_data = title_data
# ------------------------ Function: run ------------------------
###########################################################
#### Function: run
###########################################################
def run(self):
while True:
user_data = UserData()
@ -37,13 +41,12 @@ class RecommendationLoader:
print("\nWrite 'exit' or 'quit' to end the program.")
# ------------------------ Function: get_recommendations ------------------------
###########################################################
#### Function: get_recommendations
###########################################################
def get_recommendations(self, target_row, user_data):
recommendations = pd.DataFrame()
n_recommendations = user_data['n_rec']
recommendations = self.model.recommend(target_row, user_data['n_rec'])
# I don't want to recommend a title with Reality in it if the reference doesn't have that genre, and so on
recommendations = self.filter_genres(recommendations, target_row)
# Get more recommendations and filter until n_recommendations is reached
while len(recommendations) < n_recommendations:
@ -58,7 +61,9 @@ class RecommendationLoader:
self.display_recommendations(user_data, recommendations, n_recommendations, target_row)
# ------------------------ Function: display_recommendations ------------------------
###########################################################
#### Function: display_recommendations
###########################################################
def display_recommendations(self, user_data, recommendations, n_recommendations, target_row):
print(f'\n{n_recommendations} recommendations based on "{user_data["title"]}":\n')
@ -112,7 +117,9 @@ class RecommendationLoader:
print("No recommendations found.")
# ------------------------ Function: get_explanation ------------------------
###########################################################
#### Function: get_explanation
###########################################################
def get_explanation(self, row, target_row):
explanation = []
title = row['name']
@ -136,7 +143,9 @@ class RecommendationLoader:
return ' '.join(explanation)
# ------------------------ Function: check_genre_overlap ------------------------
###########################################################
#### Function: check_genre_overlap
###########################################################
def check_genre_overlap(self, target_row, row):
# Get genres from the target row
target_genres = set(genre.lower() for genre in target_row['genres'])
@ -149,7 +158,9 @@ class RecommendationLoader:
return overlap
# ------------------------ Function: check_created_by_overlap ------------------------
###########################################################
#### Function: check_created_by_overlap
###########################################################
def check_created_by_overlap(self, target_row, row):
# Get created_by from the target row
target_creators = set(creator.lower() for creator in target_row['created_by'])
@ -162,7 +173,9 @@ class RecommendationLoader:
return overlap
# ------------------------ Function: extract_years ------------------------
###########################################################
#### Function: extract_years
###########################################################
def extract_years(self, air_date):
# Make sure air_date is not null
if pd.isna(air_date):
@ -173,7 +186,9 @@ class RecommendationLoader:
return air_date.split('-')[0]
# ------------------------ Function: get_recommendations ------------------------
###########################################################
#### Function: filter_genres
###########################################################
def filter_genres(self, recommendations, target_row):
# Get genres from the target row
reference_genres = [genre.lower() for genre in target_row['genres']]

View File

@ -1,58 +1,67 @@
from read_data import LoadData
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import TruncatedSVD
from scipy.sparse import hstack, csr_matrix
import numpy as np
import pickle
import time
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')
############################## Train model ##############################
#########################################################################
#### Class: TrainModel
#########################################################################
class TrainModel:
def __init__(self, title_data):
self.title_data = title_data
# Settings for vectorization
# Initialize Sentence-BERT model for embeddings
self.bert_model = SentenceTransformer('all-MiniLM-L12-v2')
# Settings for TF-IDF Vectorization
self.vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), min_df=0.01, max_df=0.5)
# Settings for nearest neighbors
self.model = NearestNeighbors(metric='cosine')
# Settings for Nearest Neighbors
self.nearest_neighbors = NearestNeighbors(metric='cosine')
self.scaler = StandardScaler()
# Settings for SVD
self.svd = TruncatedSVD(n_components=300)
# ---------------------- Function: train ----------------------
###########################################################
#### Function: train
###########################################################
def train(self):
print("Starting to train model ...")
start = time.time()
# Preprocess title data
preproccessed_data = self.preprocess_title_data()
# Preprocess title data with advanced embeddings included
preprocessed_data = self.preprocess_title_data()
# Train the NearestNeighbors model
self.model.fit(preproccessed_data)
# Train Nearest Neighbors on the enhanced feature set
self.nearest_neighbors.fit(preprocessed_data)
stop = time.time()
# Count time for training
elapsed_time = stop - start
print(f'Trained model successfully in {elapsed_time:.2f} seconds.')
print(f'Trained model successfully in {time.time() - start:.2f} seconds.')
# ------------------------ Function: recommend ------------------------
###########################################################
#### Function: recommend
###########################################################
def recommend(self, target_row, num_recommendations=40):
# Preprocess target data
target_vector = self.preprocess_target_data(target_row)
# Use NearestNeighbors model as input to K-nearest neighbors
distances, indices = self.model.kneighbors(target_vector, n_neighbors=num_recommendations)
# Use Nearest Neighbors to get recommendations
distances, indices = self.nearest_neighbors.kneighbors(target_vector, n_neighbors=num_recommendations)
recommendations = self.title_data.iloc[indices[0]].copy()
recommendations['distance'] = distances[0]
@ -64,44 +73,76 @@ class TrainModel:
return recommendations.head(num_recommendations)
# ---------------------- Function: preprocess_data ----------------------
###########################################################
#### Function: preprocess_title_data
###########################################################
def preprocess_title_data(self):
# Combine text fields in a new column for vectorization
self.title_data['combined_text'] = (
self.title_data['overview'].fillna('').apply(str) + ' ' +
self.title_data['genres'].fillna('').apply(str) + ' ' +
self.title_data['created_by'].fillna('').apply(str)
)
# Process combined_text column with vectorizer
# Process text data for TF-IDF + SVD
text_features = self.vectorizer.fit_transform(self.title_data['combined_text'])
text_features = self.svd.fit_transform(text_features)
# Scale numerical features in the DataFrame using a scaler
self.numerical_data = self.title_data.select_dtypes(include=['number'])
# Generate Sentence-BERT embeddings
bert_embeddings = self.load_pickle('bert_embeddings.pkl', self.title_data['combined_text'])
# Include ratings in numerical features
# Process numerical features
self.numerical_data = self.title_data.select_dtypes(include=['number'])
if 'vote_average' in self.numerical_data.columns:
self.numerical_data = self.numerical_data[['vote_average']]
# Scale numerical features
numerical_features = self.scaler.fit_transform(self.numerical_data)
numerical_features_sparse = csr_matrix(numerical_features)
# Combine text and numerical features
combined_features = hstack([csr_matrix(text_features), numerical_features_sparse])
# Combine all features
combined_features = hstack([csr_matrix(text_features), csr_matrix(bert_embeddings), numerical_features_sparse])
return combined_features
# ---------------------- Function: preprocess_target_data ----------------------
###########################################################
#### Function: preprocess_target_data
###########################################################
def preprocess_target_data(self, target_row):
# Create feature vector for target row
# Process target text data for TF-IDF + SVD
target_text_vector = self.vectorizer.transform([target_row['combined_text']])
target_text_vector = self.svd.transform(target_text_vector)
# Process numerical features of the reference target
# Generate Sentence-BERT embedding for target
target_bert_embedding = self.embed_text(target_row['combined_text']).reshape(1, -1)
# Process numerical features
target_numerical = target_row[self.numerical_data.columns].values.reshape(1, -1)
target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(self.scaler.transform(target_numerical))])
target_numerical_scaled = self.scaler.transform(target_numerical)
# Combine all target features
target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(target_bert_embedding), csr_matrix(target_numerical_scaled)])
return target_vector
###########################################################
#### Function: load_pickle
###########################################################
def load_pickle(self, filename, title_data):
    """Load cached Sentence-BERT embeddings from *filename*, or build and cache them.

    Parameters
    ----------
    filename : str
        Path of the pickle cache file.
    title_data : pandas.Series
        Text values to embed when the cache is missing or unreadable.
        (Only consulted on a cache miss.)

    Returns
    -------
    numpy.ndarray
        Embedding matrix, one row per entry in ``title_data``.
    """
    # NOTE(security): pickle.load can execute arbitrary code from the file;
    # only load cache files this program itself wrote.
    try:
        with open(filename, 'rb') as f:
            bert_embeddings = pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        # Cache missing, truncated, or corrupt: recompute the embeddings
        # and rewrite the cache instead of crashing.
        bert_embeddings = np.vstack(title_data.apply(self.embed_text).values)
        with open(filename, 'wb') as f:
            pickle.dump(bert_embeddings, f)
    return bert_embeddings
###########################################################
#### Function: embed_text
###########################################################
def embed_text(self, text):
    """Encode *text* into a dense Sentence-BERT embedding as a numpy array."""
    embedding = self.bert_model.encode(text, convert_to_numpy=True)
    return embedding

View File

@ -1,10 +1,14 @@
############################## User input ##############################
###############################################################
#### Class: UserData
###############################################################
class UserData:
def __init__(self):
self.user_data = {}
self.n_rec = 10
# ---------------------- Function: title ----------------------
###########################################################
#### Function: title
###########################################################
def title(self):
# Ask for user input
print("#" * 100)
@ -12,7 +16,9 @@ class UserData:
self.user_data['title'] = title.strip().lower()
return self.user_data
# ---------------------- Function: n_recommendations ----------------------
###########################################################
#### Function: n_recommendations
###########################################################
def n_recommendations(self):
# Ask for number of recommendations
while True: