Big update and changing to flask webgui instead

parent 9765677235
commit 46715bff45

README.md | 185 lines changed
@@ -1,102 +1,129 @@

# Supervised Learning - TV-Show Recommender
## Table of Contents

1. [How to Run the Program](#how-to-run-the-program)
2. [Project Overview](#project-overview)
3. [Dataset](#dataset)
4. [Model and Algorithm](#model-and-algorithm)
5. [Features](#features)
6. [Requirements](#requirements)
7. [Libraries](#libraries)
8. [Classes](#classes)
9. [References](#references)
## How to Run the Program

### Prerequisites

1. **Download and Extract the Dataset:**
   - Download the dataset from [TMDB TV Dataset](https://www.kaggle.com/datasets/asaniczka/full-tmdb-tv-shows-dataset-2023-150k-shows).
   - Extract `TMDB_tv_dataset_v3.zip` into the `dataset/` folder, so that it contains the file `TMDB_tv_dataset_v3.csv`.

2. **Install Dependencies:**
   - Install the necessary libraries listed in `requirements.txt` (see below).

3. **Run the Program:**
   - Start the program by running the following command:

   ```bash
   python main.py
   ```

   - The program will load the dataset, ask for a TV show title to base recommendations on, and prompt for the number of recommendations.

> [!NOTE]
> The first time the program is run, it will generate **Sentence-BERT embeddings**. This can take up to 5 minutes due to the large size of the dataset.
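The one-time embedding step in the note above is cached on disk, so later runs start quickly. A minimal sketch of that caching pattern (the file name and the toy embedding function are placeholders, not the project's exact code):

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np


def load_or_build_embeddings(filename, texts, embed_fn):
    """Return cached embeddings if the pickle exists; otherwise build and cache them."""
    path = Path(filename)
    if path.exists():
        with path.open("rb") as f:
            return pickle.load(f)
    # First run: embed every description, then persist for later runs.
    embeddings = np.vstack([embed_fn(t) for t in texts])
    with path.open("wb") as f:
        pickle.dump(embeddings, f)
    return embeddings


# Toy embedding function standing in for the Sentence-BERT model.
def fake_embed(text):
    return np.array([len(text), text.count(" ")], dtype=float)


cache = Path(tempfile.mkdtemp()) / "toy_embeddings.pkl"
first = load_or_build_embeddings(cache, ["Breaking Bad", "The Wire"], fake_embed)
second = load_or_build_embeddings(cache, ["Breaking Bad", "The Wire"], fake_embed)  # served from cache
print(first.shape)  # (2, 2)
```

The second call never re-embeds anything; it just unpickles the array written by the first call.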
---

## Project Overview

The **TV-Show Recommender** is a machine learning-based program that suggests TV shows to users based on their preferences. The system uses **Nearest Neighbors (NN)** and **K-Nearest Neighbors (K-NN)** algorithms with **cosine distance** to recommend TV shows. Users provide the title of a TV show they like, and the system returns personalized recommendations based on similarity to other TV shows in the dataset.

---
||||
## Dataset

The dataset used in this project is sourced from **TMDB** (The Movie Database). It contains over 150,000 TV shows and includes information such as:

- Title of TV shows
- Genres
- First/last air date
- Vote count and average rating
- Director/creator information
- Overview/description
- Networks
- Spoken languages
- Number of seasons/episodes

Download the dataset from [here](https://www.kaggle.com/datasets/asaniczka/full-tmdb-tv-shows-dataset-2023-150k-shows).
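Loading and lightly cleaning a file with these columns is straightforward with pandas. A sketch under assumptions: the column names below follow the dataset description above, and the two-row inline CSV stands in for the real ~150k-row `TMDB_tv_dataset_v3.csv`:

```python
import io

import pandas as pd

# Two-row stand-in for dataset/TMDB_tv_dataset_v3.csv.
csv_text = """name,genres,overview,vote_average,number_of_seasons,number_of_episodes
Breaking Bad,"Drama, Crime",A chemistry teacher turns to crime.,8.9,5,62
The Wire,"Drama, Crime",Baltimore seen through its institutions.,8.6,5,60
"""

shows = pd.read_csv(io.StringIO(csv_text))

# Typical cleaning steps: drop rows missing a description,
# and split comma-separated genre strings into lists.
shows = shows.dropna(subset=["overview"])
shows["genres"] = shows["genres"].str.split(", ")

print(shows.loc[0, "genres"])  # ['Drama', 'Crime']
```

For the real file you would pass the CSV path to `pd.read_csv` instead of the `StringIO` buffer.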
## Classes

1. LoadData
    * load_data
    * read_data
    * clean_data
2. ImportData
    * load_dataset
    * create_data
    * clean_data
    * save_data
3. TrainModel
    * train
    * recommend
    * preprocess_title_data
    * preprocess_target_data
4. UserData
    * input
    * n_recommendations
5. RecommendationLoader
    * run
    * get_recommendations
    * display_recommendations
    * get_explanation
    * check_genre_overlap
    * check_created_by_overlap
    * extract_years
    * filter_genres

---
## References

* https://scikit-learn.org/dev/modules/generated/sklearn.neighbors.NearestNeighbors.html
* https://scikit-learn.org/1.5/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html
* https://scikit-learn.org/dev/modules/generated/sklearn.preprocessing.StandardScaler.html
* https://scikit-learn.org/0.16/modules/generated/sklearn.decomposition.TruncatedSVD.html
* https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.hstack.html
* https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
## Model and Algorithm

The recommender system is based on **Supervised Learning** using the **NearestNeighbors** and **K-NearestNeighbors** algorithms. Here's a breakdown of the process:

1. **Data Preprocessing:**
   - The TV show descriptions are vectorized using **Sentence-BERT embeddings** to create dense vector representations of each show's description.

2. **Model Training:**
   - The **NearestNeighbors (NN)** algorithm is used with **cosine distance** to compute similarity between TV shows. The algorithm finds the most similar shows to a user-provided title.

3. **Recommendation Generation:**
   - The model generates a list of recommended TV shows by finding the nearest neighbors of the input title using cosine similarity.

---
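The three steps above can be sketched end-to-end with scikit-learn's `NearestNeighbors`. This is a toy illustration: the titles and 2-D vectors are made-up stand-ins for the real Sentence-BERT embeddings, not the project's data:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Made-up dense vectors standing in for Sentence-BERT embeddings of four shows.
titles = ["Breaking Bad", "Better Call Saul", "Planet Earth", "The Office"]
embeddings = np.array([
    [0.9, 0.1],
    [0.85, 0.2],
    [0.1, 0.95],
    [0.5, 0.5],
])

# Cosine distance, as in the model training step above.
nn = NearestNeighbors(metric="cosine")
nn.fit(embeddings)

# Query: the two nearest neighbours of "Breaking Bad" (index 0), excluding itself.
distances, indices = nn.kneighbors(embeddings[0:1], n_neighbors=3)
recommended = [titles[i] for i in indices[0] if i != 0][:2]
print(recommended)  # ['Better Call Saul', 'The Office']
```

Because the query vector is itself in the index, the closest hit is always the input show at distance 0, which is why it is filtered out before presenting results.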
## Features

1. **Data Loading & Preprocessing:**
   - Loads the TV show data from a CSV file and preprocesses it for model training.

2. **Model Training with K-NN:**
   - Trains a K-NN model using the **NearestNeighbors** algorithm for generating recommendations.

3. **User Input for Recommendations:**
   - Accepts user input for the TV show title and the number of recommendations.

4. **TV Show Recommendations:**
   - Returns a list of recommended TV shows based on similarity to the input TV show.

---
## Requirements

### Data Requirements:

The dataset should contain the following columns for each TV show:

- **Title**
- **Genres**
- **First/Last air date**
- **Vote count/average**
- **Director**
- **Overview**
- **Networks**
- **Spoken languages**
- **Number of seasons/episodes**

### User Input Requirements:

- **TV Show Title**: The name of the TV show you like.
- **Number of Recommendations**: The number of recommendations you want to receive (default is 10).

---
## Libraries

The following libraries are required to run the program:

- **pandas**: For data manipulation and analysis.
- **scikit-learn**: For machine learning algorithms and preprocessing.
- **scipy**: For scientific computing (e.g., sparse matrices).
- **time**: For working with time-related functions.
- **os**: For interacting with the operating system.
- **re**: For regular expression support.
- **textwrap**: For text wrapping and formatting.
- **flask**: For creating the web interface.

To install the dependencies, run:

```bash
pip install -r requirements.txt
```
app.py | 84 lines (new file)

@@ -0,0 +1,84 @@
from flask import Flask, render_template, request
from readdata import LoadData
from recommendations import RecommendationLoader
from training import TrainModel

app = Flask(__name__)

data_loader = LoadData()
title_data = data_loader.load_data()

model = TrainModel(title_data)
model.train()

recommender = RecommendationLoader(model, title_data)


@app.route('/')
def home():
    return render_template('index.html')


@app.route('/recommend', methods=['POST'])
def recommend():

    # Get user input
    title = request.form.get('title', '').strip()
    n_recommendations = request.form.get('n_recommendations', 10)

    # Validate user input
    if not title:
        return render_template('index.html', message="Please enter a valid TV show title.")

    try:
        n_recommendations = int(n_recommendations)
        if n_recommendations < 1 or n_recommendations > 50:
            raise ValueError("Number of recommendations must be between 1 and 50.")
    except ValueError as e:
        return render_template('index.html', message=str(e))

    # Look up the requested title in the dataset
    target_row = title_data[title_data['name'].str.lower() == title.lower()]

    # Check if a match was found
    if target_row.empty:
        return render_template('index.html', message=f"No match found for '{title}'. Try again.")

    # Get recommendations
    target_row = target_row.iloc[0]
    user_data = {'title': title, 'n_rec': n_recommendations}
    recommendations = recommender.get_recommendations("flask", target_row, user_data)

    # Check if recommendations were found
    if recommendations is None or recommendations.empty:
        return render_template('index.html', message=f"Sorry, no recommendations available for {title}.")

    # Prepare data for display on the webpage
    recommendations_data = []

    for _, row in recommendations.iterrows():

        # Extract the first and last air dates
        first_air_date = recommender.extract_years(row['first_air_date'])
        last_air_date = recommender.extract_years(row['last_air_date'])
        if last_air_date != "Ongoing" and last_air_date:
            years = f"{first_air_date} - {last_air_date}"
        else:
            years = f"{first_air_date}"

        recommendations_data.append({
            'title': row['name'],
            'genres': ', '.join(row['genres']) if isinstance(row['genres'], list) else row['genres'],
            'overview': row['overview'],
            'rating': row['vote_average'],
            'seasons': row['number_of_seasons'],
            'episodes': row['number_of_episodes'],
            'networks': ', '.join(row['networks']) if isinstance(row['networks'], list) and row['networks'] else 'N/A',
            'years': years,
        })

    return render_template('index.html', recommendations=recommendations_data, original_title=title)


if __name__ == '__main__':
    app.run(debug=True)
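The air-date formatting inside the loop above reduces to a small pure function. A sketch under assumptions: `extract_years` here is a simplified stand-in for the `RecommendationLoader` method (the real one parses the dataset's date strings), while `format_year_range` mirrors the branch in `app.py`:

```python
def extract_years(date_str):
    # Simplified stand-in: take the year from an ISO date, or "Ongoing" when absent.
    return date_str[:4] if date_str else "Ongoing"


def format_year_range(first_air_date, last_air_date):
    first = extract_years(first_air_date)
    last = extract_years(last_air_date)
    # Mirror app.py: show a year range only when the show has actually ended.
    if last != "Ongoing" and last:
        return f"{first} - {last}"
    return f"{first}"


print(format_year_range("2008-01-20", "2013-09-29"))  # 2008 - 2013
print(format_year_range("2023-01-15", ""))  # 2023
```

Keeping this logic in a helper makes it reusable by both the CLI and the Flask view, and trivially testable.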
File diff suppressed because it is too large.
main.py | 4 lines changed

@@ -1,5 +1,5 @@
-from read_data import LoadData
-from trainmodel import TrainModel
+from readdata import LoadData
+from training import TrainModel
 from recommendations import RecommendationLoader
recommendations.py

@@ -1,4 +1,4 @@
-from user_data import UserData
+from user import UserData
 import pandas as pd
 import textwrap
@@ -44,7 +44,7 @@ class RecommendationLoader:
     ###########################################################
     #### Function: get_recommendations
     ###########################################################
-    def get_recommendations(self, target_row, user_data):
+    def get_recommendations(self, type, target_row, user_data):
         recommendations = pd.DataFrame()
         n_recommendations = user_data['n_rec']
@@ -58,6 +58,9 @@ class RecommendationLoader:
         # Make sure we give n_recommendations recommendations
         recommendations = recommendations.head(n_recommendations)

+        if type == 'flask':
+            return recommendations
+        else:
-        self.display_recommendations(user_data, recommendations, n_recommendations, target_row)
+            self.display_recommendations(user_data, recommendations, n_recommendations, target_row)
@@ -106,73 +109,12 @@ class RecommendationLoader:
                 print(f"Seasons: {seasons} ({episodes} episodes)")
                 print(f'\n{overview}\n')

-                # Get explanation for recommendation
-                explanation = self.get_explanation(row, target_row)
-                print(f"{explanation}\n")
-
                 print("-" * width)

             print("\nEnd of recommendations.")
         else:
             print("No recommendations found.")


-    ###########################################################
-    #### Function: get_explanation
-    ###########################################################
-    def get_explanation(self, row, target_row):
-        explanation = []
-        title = row['name']
-
-        explanation.append(f"The title '{title}' was recommended because: \n")
-
-        # Explain genre overlap
-        genre_overlap = self.check_genre_overlap(target_row, row)
-        if genre_overlap:
-            overlapping_genres = ', '.join(genre_overlap)
-            explanation.append(f"It shares the following genres with your preferences: {overlapping_genres}.\n")
-
-        # Explain created_by overlap
-        created_by_overlap = self.check_created_by_overlap(target_row, row)
-        if created_by_overlap:
-            overlapping_created_by = ', '.join(created_by_overlap)
-            explanation.append(f"It shares the following director with your preferences: {overlapping_created_by}.\n")
-
-        # Explain the distance metric
-        explanation.append(f"The distance metric of {round(row['distance'], 2)} indicates that it is quite similar to your preferences.")
-        return ' '.join(explanation)
-
-
-    ###########################################################
-    #### Function: check_genre_overlap
-    ###########################################################
-    def check_genre_overlap(self, target_row, row):
-        # Get genres from the target row
-        target_genres = set(genre.lower() for genre in target_row['genres'])
-        # Get genres from the recommended row
-        recommended_genres = set(genre.lower() for genre in row['genres'])
-
-        # Find the intersection of the target genres and recommended genres
-        overlap = target_genres.intersection(recommended_genres)
-
-        return overlap
-
-
-    ###########################################################
-    #### Function: check_created_by_overlap
-    ###########################################################
-    def check_created_by_overlap(self, target_row, row):
-        # Get created_by from the target row
-        target_creators = set(creator.lower() for creator in target_row['created_by'])
-        # Get created_by from the recommended row
-        recommended_creators = set(creator.lower() for creator in row['created_by'])
-
-        # Find the intersection of the target creators and recommended creators
-        overlap = target_creators.intersection(recommended_creators)
-
-        return overlap
-
-
     ###########################################################
     #### Function: extract_years
     ###########################################################
@@ -190,6 +132,7 @@ class RecommendationLoader:
     ###########################################################
     #### Function: filter_genres
     ###########################################################
     def filter_genres(self, recommendations, target_row):
+
         # Get genres from the target row
         reference_genres = [genre.lower() for genre in target_row['genres']]
templates/index.html | 92 lines (new file)

@@ -0,0 +1,92 @@
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Recommendation System</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            margin: 0;
            padding: 20px;
            background-color: #f4f4f9;
        }
        .container {
            max-width: 1200px;
            margin: 0 auto;
            padding: 20px;
        }
        h1 {
            text-align: center;
            color: #333;
        }
        form {
            text-align: center;
            margin-bottom: 20px;
        }
        .recommendation {
            background: #fff;
            padding: 15px;
            border-radius: 8px;
            margin-bottom: 20px;
            box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
        }
        .recommendation h3 {
            margin-top: 0;
        }
        .recommendation p {
            margin: 5px 0;
        }
        .recommendation .overview {
            font-size: 14px;
            color: #555;
        }
        .error-message {
            color: red;
            text-align: center;
            font-weight: bold;
        }
    </style>
</head>
<body>

    <div class="container">
        <h1>TV-Show Recommendations</h1>

        <!-- Recommendation Form -->
        <form method="POST" action="/recommend">
            <label for="title">Enter a Title (TV Show):</label><br><br>
            <input type="text" id="title" name="title" required><br><br>

            <label for="n_recommendations">Number of Recommendations:</label><br><br>
            <input type="number" id="n_recommendations" name="n_recommendations" value="10" min="1" max="50"><br><br>

            <input type="submit" value="Get Recommendations">
        </form>

        <!-- Display Error Message if any -->
        {% if message %}
        <div class="error-message">{{ message }}</div>
        {% endif %}

        <!-- Display Recommendations -->
        {% if recommendations %}
        <h2>Recommendations based on "{{ original_title }}":</h2>
        <div class="recommendations">
            {% for rec in recommendations %}
            <div class="recommendation">
                <h3>{{ rec.title }} ({{ rec.years }})</h3>
                <p><strong>Genres:</strong> {{ rec.genres }}</p>
                <p><strong>Networks:</strong> {{ rec.networks }}</p>
                <p><strong>Rating:</strong> {{ rec.rating }}</p>
                <p><strong>Seasons:</strong> {{ rec.seasons }} ({{ rec.episodes }} episodes)</p>
                <p class="overview"><strong>Overview:</strong> {{ rec.overview }}</p>
            </div>
            {% endfor %}
        </div>
        {% endif %}

    </div>

</body>
</html>
training.py

@@ -1,15 +1,12 @@
-from read_data import LoadData
 from sentence_transformers import SentenceTransformer
 from sklearn.neighbors import NearestNeighbors
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.preprocessing import StandardScaler
 from sklearn.decomposition import TruncatedSVD
 from scipy.sparse import hstack, csr_matrix
 import numpy as np
 import pickle
 import time

 import warnings
 warnings.filterwarnings("ignore", category=UserWarning, module='sklearn')
@@ -24,14 +21,16 @@ class TrainModel:
         # Initialize Sentence-BERT model for embeddings
         self.bert_model = SentenceTransformer('all-MiniLM-L12-v2')

-        # Settings for TF-IDF Vectorization
+        # TF-IDF Vectorization settings
         self.vectorizer = TfidfVectorizer(max_features=5000, ngram_range=(1, 2), min_df=0.01, max_df=0.5)

-        # Settings for Nearest Neighbors
+        # Nearest Neighbors settings
         self.nearest_neighbors = NearestNeighbors(metric='cosine')

         # Scaler for numerical features
         self.scaler = StandardScaler()

-        # Settings for SVD
+        # SVD for dimensionality reduction
         self.svd = TruncatedSVD(n_components=300)
@@ -53,10 +52,9 @@ class TrainModel:

     ###########################################################
-    #### Function: get_recommendations
+    #### Function: Recommend
     ###########################################################
     def recommend(self, target_row, num_recommendations=40):

         # Preprocess target data
         target_vector = self.preprocess_target_data(target_row)
@@ -77,29 +75,28 @@ class TrainModel:
     #### Function: preprocess_title_data
     ###########################################################
     def preprocess_title_data(self):
         # Combine text fields for TF-IDF and BERT
         self.title_data['combined_text'] = (
             self.title_data['overview'].fillna('').apply(str) + ' ' +
             self.title_data['genres'].fillna('').apply(str) + ' ' +
             self.title_data['created_by'].fillna('').apply(str)
         )

-        # Process text data for TF-IDF + SVD
+        # TF-IDF + SVD
         text_features = self.vectorizer.fit_transform(self.title_data['combined_text'])
         text_features = self.svd.fit_transform(text_features)

-        # Generate Sentence-BERT embeddings
+        # Sentence-BERT embeddings
         bert_embeddings = self.load_pickle('bert_embeddings.pkl', self.title_data['combined_text'])

-        # Process numerical features
+        # Numerical features
         self.numerical_data = self.title_data.select_dtypes(include=['number'])
         if 'vote_average' in self.numerical_data.columns:
             self.numerical_data = self.numerical_data[['vote_average']]
         numerical_features = self.scaler.fit_transform(self.numerical_data)
         numerical_features_sparse = csr_matrix(numerical_features)

         # Combine all features
-        combined_features = hstack([csr_matrix(text_features), csr_matrix(bert_embeddings), numerical_features_sparse])
+        combined_features = hstack([csr_matrix(text_features), csr_matrix(bert_embeddings),
+                                    numerical_features_sparse])
         return combined_features
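The feature combination at the end of `preprocess_title_data` is a column-wise concatenation of three sparse blocks. A minimal scipy illustration (toy shapes: the real blocks would be the 300-dim SVD output, the Sentence-BERT embeddings, and one scaled numeric column):

```python
import numpy as np
from scipy.sparse import csr_matrix, hstack

n_shows = 4
text_svd = np.random.rand(n_shows, 3)  # stands in for the TF-IDF + SVD block
bert = np.random.rand(n_shows, 5)      # stands in for the Sentence-BERT block
numeric = np.random.rand(n_shows, 1)   # stands in for the scaled vote_average column

# Same pattern as the diff: wrap each block as CSR and concatenate along columns.
combined = hstack([csr_matrix(text_svd), csr_matrix(bert), csr_matrix(numeric)])
print(combined.shape)  # (4, 9)
```

`hstack` requires every block to have the same number of rows; the column counts simply add up, giving one feature matrix per show for `NearestNeighbors.fit`.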
@@ -107,37 +104,23 @@ class TrainModel:
     #### Function: preprocess_target_data
     ###########################################################
     def preprocess_target_data(self, target_row):
-        # Process target text data for TF-IDF + SVD
+        # TF-IDF + SVD
         target_text_vector = self.vectorizer.transform([target_row['combined_text']])
         target_text_vector = self.svd.transform(target_text_vector)

-        # Generate Sentence-BERT embedding for target
+        # Sentence-BERT embedding
         target_bert_embedding = self.embed_text(target_row['combined_text']).reshape(1, -1)

-        # Process numerical features
+        # Numerical features
         target_numerical = target_row[self.numerical_data.columns].values.reshape(1, -1)
         target_numerical_scaled = self.scaler.transform(target_numerical)

-        # Combine all target features
-        target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(target_bert_embedding), csr_matrix(target_numerical_scaled)])
+        # Combine all features
+        target_vector = hstack([csr_matrix(target_text_vector), csr_matrix(target_bert_embedding),
+                                csr_matrix(target_numerical_scaled)])
         return target_vector


-    ###########################################################
-    #### Function: load_pickle
-    ###########################################################
-    def load_pickle(self, filename, title_data):
-        try:
-            with open(filename, 'rb') as f:
-                bert_embeddings = pickle.load(f)
-        except FileNotFoundError:
-            bert_embeddings = np.vstack(title_data.apply(self.embed_text).values)
-            with open(filename, 'wb') as f:
-                pickle.dump(bert_embeddings, f)
-        return bert_embeddings
-
-
     ###########################################################
     #### Function: embed_text
     ###########################################################
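A detail worth noting in `preprocess_target_data`: the target row reuses the transforms fitted on the full corpus (`transform`, never `fit_transform`), so the query lands in the same feature space as the training matrix. A toy illustration of the same idea with `StandardScaler`:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

corpus_votes = np.array([[6.0], [7.0], [8.0], [9.0]])
scaler = StandardScaler()
corpus_scaled = scaler.fit_transform(corpus_votes)  # fit on the whole dataset once

target_vote = np.array([[7.5]])
target_scaled = scaler.transform(target_vote)       # reuse the fitted mean/std for the query

print(target_scaled)  # [[0.]] — 7.5 is exactly the corpus mean
```

Calling `fit_transform` on the single target row instead would re-estimate the mean and standard deviation from one sample and produce coordinates incomparable with the corpus; the same reasoning applies to the TF-IDF vectorizer and the SVD.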
@@ -146,3 +129,18 @@ class TrainModel:
         return self.bert_model.encode(text, convert_to_numpy=True)


+    ###########################################################
+    #### Function: load_pickle
+    ###########################################################
+    def load_pickle(self, filename, title_data):
+        try:
+            with open(filename, 'rb') as f:
+                bert_embeddings = pickle.load(f)
+        except FileNotFoundError:
+            print("Generating Sentence-BERT embeddings...")
+            bert_embeddings = self.bert_model.encode(title_data.tolist(), batch_size=64, convert_to_numpy=True)
+            with open(filename, 'wb') as f:
+                pickle.dump(bert_embeddings, f)
+        return bert_embeddings