Update
This commit is contained in:
parent
c9389f7b1e
commit
ae564c0237
@ -7,6 +7,7 @@
|
||||
|
||||
First thing to do is to extract TMDB_tv_dataset_v3.zip in dataset folder so that it contains TMDB_tv_dataset_v3.csv.
|
||||
|
||||
|
||||
**Running program**
|
||||
|
||||
Start main.py and it will load dataset and ask for a title to get recommendations from, also how many recommendations wanted. Then enter and you will have those recommendations presented on screen.
|
||||
|
||||
216911
dataset/TMDB_tv_dataset_v3.csv
Normal file
216911
dataset/TMDB_tv_dataset_v3.csv
Normal file
File diff suppressed because it is too large
Load Diff
@ -3,7 +3,9 @@ import os
|
||||
import pandas as pd
|
||||
|
||||
|
||||
############################## Import data ##############################
|
||||
###############################################################
|
||||
#### Class: ImportData
|
||||
###############################################################
|
||||
class ImportData:
|
||||
|
||||
def __init__(self):
|
||||
@ -11,7 +13,9 @@ class ImportData:
|
||||
self.loaded_datasets = []
|
||||
|
||||
|
||||
# ---------------------- Function: load_dataset ----------------------
|
||||
###########################################################
|
||||
#### Function: load_dataset
|
||||
###########################################################
|
||||
def load_dataset(self, dataset_path):
|
||||
# Load data from dataset CSV file
|
||||
try:
|
||||
@ -22,7 +26,9 @@ class ImportData:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------- Function: create_data ----------------------
|
||||
###########################################################
|
||||
#### Function: create_data
|
||||
###########################################################
|
||||
def create_data(self, filename):
|
||||
try:
|
||||
self.data = self.load_dataset(filename)
|
||||
@ -32,7 +38,9 @@ class ImportData:
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------- Function: clean_data ----------------------
|
||||
###########################################################
|
||||
#### Function: clean_data
|
||||
###########################################################
|
||||
def clean_data(self):
|
||||
if self.data is not None:
|
||||
# Drop unnecessary columns
|
||||
@ -53,7 +61,9 @@ class ImportData:
|
||||
print("No data to clean. Please load the dataset first.")
|
||||
|
||||
|
||||
# ---------------------- Function: save_data ----------------------
|
||||
###########################################################
|
||||
#### Function: save_data
|
||||
###########################################################
|
||||
def save_data(self):
|
||||
if self.data is not None:
|
||||
try:
|
||||
|
||||
4
main.py
4
main.py
@ -3,7 +3,9 @@ from trainmodel import TrainModel
|
||||
from recommendations import RecommendationLoader
|
||||
|
||||
|
||||
############################## Main ############################################
|
||||
#########################################################################
|
||||
#### function: main
|
||||
#########################################################################
|
||||
|
||||
def main():
|
||||
|
||||
|
||||
16
read_data.py
16
read_data.py
@ -2,14 +2,18 @@ import pandas as pd
|
||||
from import_data import ImportData
|
||||
|
||||
|
||||
############################## Load data ##############################
|
||||
#########################################################################
|
||||
#### Class: LoadData
|
||||
#########################################################################
|
||||
class LoadData:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
self.filename = 'TMDB_tv_dataset_v3.csv'
|
||||
|
||||
|
||||
# ---------------------- Function: load_data ----------------------
|
||||
###########################################################
|
||||
#### Function: load_data
|
||||
###########################################################
|
||||
def load_data(self):
|
||||
self.read_data()
|
||||
self.clean_data()
|
||||
@ -17,7 +21,9 @@ class LoadData:
|
||||
return self.data
|
||||
|
||||
|
||||
# ---------------------- Function: read_data ----------------------
|
||||
###########################################################
|
||||
#### Function: read_data
|
||||
###########################################################
|
||||
def read_data(self):
|
||||
print("Starting to read data ...")
|
||||
try:
|
||||
@ -38,7 +44,9 @@ class LoadData:
|
||||
print(f"Error during data import process: {e}")
|
||||
|
||||
|
||||
# ---------------------- Function: clean_data ----------------------
|
||||
###########################################################
|
||||
#### Function: clean_data
|
||||
###########################################################
|
||||
def clean_data(self):
|
||||
# Function to split a string into a list, or use an empty list if no valid data
|
||||
def split_to_list(value):
|
||||
|
||||
@ -3,14 +3,18 @@ import pandas as pd
|
||||
import textwrap
|
||||
|
||||
|
||||
############################## Recommendation loader ##############################
|
||||
###############################################################
|
||||
#### Class: RecommendationLoader
|
||||
###############################################################
|
||||
class RecommendationLoader:
|
||||
def __init__(self, model, title_data):
|
||||
self.model = model
|
||||
self.title_data = title_data
|
||||
|
||||
|
||||
# ------------------------ Function: run ------------------------
|
||||
###########################################################
|
||||
#### Function: run
|
||||
###########################################################
|
||||
def run(self):
|
||||
while True:
|
||||
user_data = UserData()
|
||||
@ -37,13 +41,12 @@ class RecommendationLoader:
|
||||
print("\nWrite 'exit' or 'quit' to end the program.")
|
||||
|
||||
|
||||
# ------------------------ Function: get_recommendations ------------------------
|
||||
###########################################################
|
||||
#### Function: get_recommendations
|
||||
###########################################################
|
||||
def get_recommendations(self, target_row, user_data):
|
||||
recommendations = pd.DataFrame()
|
||||
n_recommendations = user_data['n_rec']
|
||||
recommendations = self.model.recommend(target_row, user_data['n_rec'])
|
||||
|
||||
# I dont want to recommend a title with Reality in it if the reference doesnt have that genre and so on
|
||||
recommendations = self.filter_genres(recommendations, target_row)
|
||||
|
||||
# Get more recommendations and filter untill n_recommendations is reached
|
||||
while len(recommendations) < n_recommendations:
|
||||
@ -58,7 +61,9 @@ class RecommendationLoader:
|
||||
self.display_recommendations(user_data, recommendations, n_recommendations, target_row)
|
||||
|
||||
|
||||
# ------------------------ Function: display_recommendations ------------------------
|
||||
###########################################################
|
||||
#### Function: display_recommendations
|
||||
###########################################################
|
||||
def display_recommendations(self, user_data, recommendations, n_recommendations, target_row):
|
||||
print(f'\n{n_recommendations} recommendations based on "{user_data["title"]}":\n')
|
||||
|
||||
@ -112,7 +117,9 @@ class RecommendationLoader:
|
||||
print("No recommendations found.")
|
||||
|
||||
|
||||
# ------------------------ Function: get_explanation ------------------------
|
||||
###########################################################
|
||||
#### Function: get_explanation
|
||||
###########################################################
|
||||
def get_explanation(self, row, target_row):
|
||||
explanation = []
|
||||
title = row['name']
|
||||
@ -136,7 +143,9 @@ class RecommendationLoader:
|
||||
return ' '.join(explanation)
|
||||
|
||||
|
||||
# ------------------------ Function: check_genre_overlap ------------------------
|
||||
###########################################################
|
||||
#### Function: check_genre_overlap
|
||||
###########################################################
|
||||
def check_genre_overlap(self, target_row, row):
|
||||
# Get genres from the target row
|
||||
target_genres = set(genre.lower() for genre in target_row['genres'])
|
||||
@ -149,7 +158,9 @@ class RecommendationLoader:
|
||||
return overlap
|
||||
|
||||
|
||||
# ------------------------ Function: check_created_by_overlap ------------------------
|
||||
###########################################################
|
||||
#### Function: check_created_by_overlap
|
||||
###########################################################
|
||||
def check_created_by_overlap(self, target_row, row):
|
||||
# Get created_by from the target row
|
||||
target_creators = set(creator.lower() for creator in target_row['created_by'])
|
||||
@ -162,7 +173,9 @@ class RecommendationLoader:
|
||||
return overlap
|
||||
|
||||
|
||||
# ------------------------ Function: extract_years ------------------------
|
||||
###########################################################
|
||||
#### Function: extract_years
|
||||
###########################################################
|
||||
def extract_years(self, air_date):
|
||||
# Make sure air_date is not null
|
||||
if pd.isna(air_date):
|
||||
@ -173,7 +186,9 @@ class RecommendationLoader:
|
||||
return air_date.split('-')[0]
|
||||
|
||||
|
||||
# ------------------------ Function: get_recommendations ------------------------
|
||||
###########################################################
|
||||
#### Function: filter_genres
|
||||
###########################################################
|
||||
def filter_genres(self, recommendations, target_row):
|
||||
# Get genres from the target row
|
||||
reference_genres = [genre.lower() for genre in target_row['genres']]
|
||||
|
||||
105
trainmodel.py
105
trainmodel.py
@ -1,58 +1,67 @@
|
||||
from read_data import LoadData
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from sklearn.neighbors import NearestNeighbors
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from sklearn.decomposition import TruncatedSVD
|
||||
from scipy.sparse import hstack, csr_matrix
|
||||
import numpy as np
|
||||
import pickle
|
||||
import time
|
||||
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')
|
||||
|
||||
|
||||
############################## Train model ##############################
|
||||
#########################################################################
|
||||
#### Class: TrainModel
|
||||
#########################################################################
|
||||
class TrainModel:
|
||||
def __init__(self, title_data):
|
||||
self.title_data = title_data
|
||||
|
||||
# Settings for vectorization
|
||||
# Initialize Sentence-BERT model for embeddings
|
||||
self.bert_model = SentenceTransformer('all-MiniLM-L12-v2')
|
||||
|
||||
# Settings for TF-IDF Vectorization
|
||||
self.vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), min_df=0.01, max_df=0.5)
|
||||
|
||||
# Settings for nearest neighbors
|
||||
self.model = NearestNeighbors(metric='cosine')
|
||||
# Settings for Nearest Neighbors
|
||||
self.nearest_neighbors = NearestNeighbors(metric='cosine')
|
||||
self.scaler = StandardScaler()
|
||||
|
||||
# Settings for SVD
|
||||
self.svd = TruncatedSVD(n_components=300)
|
||||
|
||||
|
||||
# ---------------------- Function: train ----------------------
|
||||
###########################################################
|
||||
#### Function: Train
|
||||
###########################################################
|
||||
def train(self):
|
||||
print("Starting to train model ...")
|
||||
|
||||
start = time.time()
|
||||
|
||||
# Preprocess title data
|
||||
preproccessed_data = self.preprocess_title_data()
|
||||
# Preprocess title data with advanced embeddings included
|
||||
preprocessed_data = self.preprocess_title_data()
|
||||
|
||||
# Train the NearestNeighbors model
|
||||
self.model.fit(preproccessed_data)
|
||||
# Train Nearest Neighbors on the enhanced feature set
|
||||
self.nearest_neighbors.fit(preprocessed_data)
|
||||
|
||||
stop = time.time()
|
||||
|
||||
# Count time for training
|
||||
elapsed_time = stop - start
|
||||
|
||||
print(f'Trained model successfully in {elapsed_time:.2f} seconds.')
|
||||
print(f'Trained model successfully in {time.time() - start:.2f} seconds.')
|
||||
|
||||
|
||||
# ------------------------ Function: recommend ------------------------
|
||||
###########################################################
|
||||
#### Function: get_recommendations
|
||||
###########################################################
|
||||
def recommend(self, target_row, num_recommendations=40):
|
||||
|
||||
# Preprocess target data
|
||||
target_vector = self.preprocess_target_data(target_row)
|
||||
|
||||
# Use NearestNeighbors model as input to K-nearest neighbors
|
||||
distances, indices = self.model.kneighbors(target_vector, n_neighbors=num_recommendations)
|
||||
# Use Nearest Neighbors to get recommendations
|
||||
distances, indices = self.nearest_neighbors.kneighbors(target_vector, n_neighbors=num_recommendations)
|
||||
recommendations = self.title_data.iloc[indices[0]].copy()
|
||||
recommendations['distance'] = distances[0]
|
||||
|
||||
@ -64,44 +73,76 @@ class TrainModel:
|
||||
return recommendations.head(num_recommendations)
|
||||
|
||||
|
||||
# ---------------------- Function: preprocess_data ----------------------
|
||||
###########################################################
|
||||
#### Function: preprocess_title_data
|
||||
###########################################################
|
||||
def preprocess_title_data(self):
|
||||
# Combine text fields in a new column for vectorization
|
||||
self.title_data['combined_text'] = (
|
||||
self.title_data['overview'].fillna('').apply(str) + ' ' +
|
||||
self.title_data['genres'].fillna('').apply(str) + ' ' +
|
||||
self.title_data['created_by'].fillna('').apply(str)
|
||||
)
|
||||
|
||||
# Process combined_text column with vectorizer
|
||||
# Process text data for TF-IDF + SVD
|
||||
text_features = self.vectorizer.fit_transform(self.title_data['combined_text'])
|
||||
text_features = self.svd.fit_transform(text_features)
|
||||
|
||||
# Scale numerical features in the DataFrame using a scaler
|
||||
self.numerical_data = self.title_data.select_dtypes(include=['number'])
|
||||
# Generate Sentence-BERT embeddings
|
||||
bert_embeddings = self.load_pickle('bert_embeddings.pkl', self.title_data['combined_text'])
|
||||
|
||||
# Include ratings in numerical features
|
||||
# Process numerical features
|
||||
self.numerical_data = self.title_data.select_dtypes(include=['number'])
|
||||
if 'vote_average' in self.numerical_data.columns:
|
||||
self.numerical_data = self.numerical_data[['vote_average']]
|
||||
|
||||
# Scale numerical features
|
||||
numerical_features = self.scaler.fit_transform(self.numerical_data)
|
||||
numerical_features_sparse = csr_matrix(numerical_features)
|
||||
|
||||
# Combine text and numerical features
|
||||
combined_features = hstack([csr_matrix(text_features), numerical_features_sparse])
|
||||
# Combine all features
|
||||
combined_features = hstack([csr_matrix(text_features), csr_matrix(bert_embeddings), numerical_features_sparse])
|
||||
|
||||
return combined_features
|
||||
|
||||
|
||||
# ---------------------- Function: preprocess_target_data ----------------------
|
||||
###########################################################
|
||||
#### Function: preprocess_target_data
|
||||
###########################################################
|
||||
def preprocess_target_data(self, target_row):
|
||||
# Create feature vector for target row
|
||||
# Process target text data for TF-IDF + SVD
|
||||
target_text_vector = self.vectorizer.transform([target_row['combined_text']])
|
||||
target_text_vector = self.svd.transform(target_text_vector)
|
||||
|
||||
# Process numerical features of the referens target
|
||||
# Generate Sentence-BERT embedding for target
|
||||
target_bert_embedding = self.embed_text(target_row['combined_text']).reshape(1, -1)
|
||||
|
||||
# Process numerical features
|
||||
target_numerical = target_row[self.numerical_data.columns].values.reshape(1, -1)
|
||||
target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(self.scaler.transform(target_numerical))])
|
||||
target_numerical_scaled = self.scaler.transform(target_numerical)
|
||||
|
||||
# Combine all target features
|
||||
target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(target_bert_embedding), csr_matrix(target_numerical_scaled)])
|
||||
|
||||
return target_vector
|
||||
|
||||
|
||||
###########################################################
|
||||
#### Function: load_pickle
|
||||
###########################################################
|
||||
def load_pickle(self, filename, title_data):
|
||||
try:
|
||||
with open(filename, 'rb') as f:
|
||||
bert_embeddings = pickle.load(f)
|
||||
except FileNotFoundError:
|
||||
bert_embeddings = np.vstack(title_data.apply(self.embed_text).values)
|
||||
with open(filename, 'wb') as f:
|
||||
pickle.dump(bert_embeddings, f)
|
||||
return bert_embeddings
|
||||
|
||||
|
||||
###########################################################
|
||||
#### Function: embed_text
|
||||
###########################################################
|
||||
def embed_text(self, text):
|
||||
# Use Sentence-BERT to create embeddings
|
||||
return self.bert_model.encode(text, convert_to_numpy=True)
|
||||
|
||||
|
||||
|
||||
12
user_data.py
12
user_data.py
@ -1,10 +1,14 @@
|
||||
############################## User input ##############################
|
||||
###############################################################
|
||||
#### Class: UserData
|
||||
###############################################################
|
||||
class UserData:
|
||||
def __init__(self):
|
||||
self.user_data = {}
|
||||
self.n_rec = 10
|
||||
|
||||
# ---------------------- Function: title ----------------------
|
||||
###########################################################
|
||||
#### Function: title
|
||||
###########################################################
|
||||
def title(self):
|
||||
# Ask for user input
|
||||
print("#" * 100)
|
||||
@ -12,7 +16,9 @@ class UserData:
|
||||
self.user_data['title'] = title.strip().lower()
|
||||
return self.user_data
|
||||
|
||||
# ---------------------- Function: n_recommendations ----------------------
|
||||
###########################################################
|
||||
#### Function: n_recommendations
|
||||
###########################################################
|
||||
def n_recommendations(self):
|
||||
# Ask for number of recommendations
|
||||
while True:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user