-rw-r--r--  app.py       90
-rw-r--r--  database.py  10
-rw-r--r--  db2pc.py     29
-rw-r--r--  main.py     154
-rw-r--r--  pc2rec.py    10
5 files changed, 163 insertions, 130 deletions
diff --git a/app.py b/app.py
index fbd5c7d..a843c44 100644
--- a/app.py
+++ b/app.py
@@ -19,47 +19,58 @@ index = pinecone.Index("movies")
app = Flask(__name__, template_folder="./templates")
+
def title2trakt_id(title: str, df=df):
- #Matches Exact Title, Otherwise Returns None
+ # Matches Exact Title, Otherwise Returns None
records = df[df["title"].str.lower() == title.lower()]
if len(records) == 0:
return 0, None
elif len(records) == 1:
return 1, records.trakt_id.tolist()[0]
else:
- return 2, records.trakt_id.tolist()
+ return 2, records.trakt_id.tolist()
+
def get_vector_value(trakt_id: int):
- fetch_response = index.fetch(ids=[str(trakt_id)])
- return fetch_response["vectors"][str(trakt_id)]["values"]
-
-def query_vectors(vector: list, top_k: int = 20, include_values: bool = False, include_metada: bool = True):
- query_response = index.query(
- queries=[
- (vector),
- ],
- top_k=top_k,
- include_values=include_values,
- include_metadata=include_metada
- )
- return query_response
+ fetch_response = index.fetch(ids=[str(trakt_id)])
+ return fetch_response["vectors"][str(trakt_id)]["values"]
+
+
+def query_vectors(
+ vector: list,
+ top_k: int = 20,
+ include_values: bool = False,
+ include_metada: bool = True,
+):
+ query_response = index.query(
+ queries=[
+ (vector),
+ ],
+ top_k=top_k,
+ include_values=include_values,
+ include_metadata=include_metada,
+ )
+ return query_response
+
def query2ids(query_response):
- trakt_ids = []
- for match in query_response["results"][0]["matches"]:
- trakt_ids.append(int(match["id"]))
- return trakt_ids
+ trakt_ids = []
+ for match in query_response["results"][0]["matches"]:
+ trakt_ids.append(int(match["id"]))
+ return trakt_ids
+
def get_deets_by_trakt_id(df, trakt_id: int):
- df = df[df["trakt_id"]==trakt_id]
- return {
- "title": df.title.values[0],
- "overview": df.overview.values[0],
- "runtime": int(df.runtime.values[0]),
- "year": int(df.year.values[0]),
- "trakt_id": trakt_id,
- "tagline": df.tagline.values[0]
- }
+ df = df[df["trakt_id"] == trakt_id]
+ return {
+ "title": df.title.values[0],
+ "overview": df.overview.values[0],
+ "runtime": int(df.runtime.values[0]),
+ "year": int(df.year.values[0]),
+ "trakt_id": trakt_id,
+ "tagline": df.tagline.values[0],
+ }
+
@app.route("/similar")
def get_similar_titles():
@@ -99,10 +110,10 @@ def get_similar_titles():
except TypeError:
maxRuntime = 220
vector = get_vector_value(trakt_id)
- movie_queries = query_vectors(vector, top_k = 69)
+ movie_queries = query_vectors(vector, top_k=69)
movie_ids = query2ids(movie_queries)
results = []
- #for trakt_id in movie_ids:
+ # for trakt_id in movie_ids:
# deets = get_deets_by_trakt_id(df, trakt_id)
# results.append(deets)
max_res = 30
@@ -111,12 +122,15 @@ def get_similar_titles():
if cur_res >= max_res:
break
deets = get_deets_by_trakt_id(df, trakt_id)
- if ((deets["year"]>=min_year) and (deets["year"]<=max_year)) and ((deets["runtime"]>=minRuntime) and (deets["runtime"]<=maxRuntime)):
+ if ((deets["year"] >= min_year) and (deets["year"] <= max_year)) and (
+ (deets["runtime"] >= minRuntime) and (deets["runtime"] <= maxRuntime)
+ ):
results.append(deets)
cur_res += 1
- return render_template("show_results.html",deets=results)
+ return render_template("show_results.html", deets=results)
+
-@app.route("/",methods=("GET","POST"))
+@app.route("/", methods=("GET", "POST"))
def find_similar_title():
if request.method == "GET":
return render_template("index.html")
@@ -125,7 +139,9 @@ def find_similar_title():
code, values = title2trakt_id(to_search_title)
print(f"Code {code} for {to_search_title}")
if code == 0:
- search_results = process.extract(to_search_title, movie_titles, scorer=fuzz.token_sort_ratio)
+ search_results = process.extract(
+ to_search_title, movie_titles, scorer=fuzz.token_sort_ratio
+ )
to_search_titles = []
to_search_ids = []
results = []
@@ -143,7 +159,7 @@ def find_similar_title():
deets = get_deets_by_trakt_id(df, int(trakt_id))
deets["trakt_id"] = trakt_id
results.append(deets)
- return render_template("same_titles.html",deets=results)
+ return render_template("same_titles.html", deets=results)
elif code == 1:
vector = get_vector_value(values)
@@ -153,11 +169,11 @@ def find_similar_title():
for trakt_id in movie_ids:
deets = get_deets_by_trakt_id(df, trakt_id)
results.append(deets)
- return render_template("show_results.html",deets=results)
+ return render_template("show_results.html", deets=results)
else:
results = []
for trakt_id in values:
deets = get_deets_by_trakt_id(df, int(trakt_id))
deets["trakt_id"] = trakt_id
results.append(deets)
- return render_template("same_titles.html",deets=results)
+ return render_template("same_titles.html", deets=results)
diff --git a/database.py b/database.py
index 369cf97..1a4bda0 100644
--- a/database.py
+++ b/database.py
@@ -5,7 +5,7 @@ from sqlalchemy import insert
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import IntegrityError
-#database_url = "sqlite:///jlm.db"
+# database_url = "sqlite:///jlm.db"
meta = MetaData()
@@ -25,15 +25,17 @@ movies_table = Table(
Column("votes", Integer),
Column("comment_count", Integer),
Column("tagline", String),
- Column("embeddings", PickleType)
-
+ Column("embeddings", PickleType),
)
+
def init_db_stuff(database_url: str):
engine = create_engine(database_url)
meta.create_all(engine)
Session = sessionmaker(bind=engine)
return engine, Session
+
+
"""
movie = {
"title": movie["movie"]["title"],
@@ -46,4 +48,4 @@ def init_db_stuff(database_url: str):
"runtime": movie["movie"]["runtime"],
"country": movie["movie"]["country"]
}
-""" \ No newline at end of file
+"""
diff --git a/db2pc.py b/db2pc.py
index 4acddcb..48d3c4a 100644
--- a/db2pc.py
+++ b/db2pc.py
@@ -20,18 +20,23 @@ model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
batch_size = 32
df = pd.read_sql("Select * from movies", engine)
-df["combined_text"] = df["title"] + ": " + df["overview"].fillna('') + " - " + df["tagline"].fillna('') + " Genres:- " + df["genres"].fillna('')
+df["combined_text"] = (
+ df["title"]
+ + ": "
+ + df["overview"].fillna("")
+ + " - "
+ + df["tagline"].fillna("")
+ + " Genres:- "
+ + df["genres"].fillna("")
+)
print(f'Length of Combined Text: {len(df["combined_text"].tolist())}')
-for x in tqdm(range(0,len(df),batch_size)):
- to_send = []
- trakt_ids = df["trakt_id"][x:x+batch_size].tolist()
- sentences = df["combined_text"][x:x+batch_size].tolist()
- embeddings = model.encode(sentences)
- for idx, value in enumerate(trakt_ids):
- to_send.append(
- (
- str(value), embeddings[idx].tolist()
- ))
-    index.upsert(to_send)
\ No newline at end of file
+for x in tqdm(range(0, len(df), batch_size)):
+ to_send = []
+ trakt_ids = df["trakt_id"][x : x + batch_size].tolist()
+ sentences = df["combined_text"][x : x + batch_size].tolist()
+ embeddings = model.encode(sentences)
+ for idx, value in enumerate(trakt_ids):
+ to_send.append((str(value), embeddings[idx].tolist()))
+ index.upsert(to_send)
diff --git a/main.py b/main.py
index fbb7b4c..e8e8564 100644
--- a/main.py
+++ b/main.py
@@ -10,57 +10,56 @@ import time
trakt_id = os.getenv("TRAKT_ID")
trakt_se = os.getenv("TRAKT_SE")
-max_requests = 5000 # How many requests do you want to make
+max_requests = 5000 # How many requests do you want to make
req_count = 0
years = "1900-2021"
page = 1
-extended = "full" # Required to get additional information
-limit = "10" # No of entires per request
-languages = "en" # Limit to particular language
+extended = "full" # Required to get additional information
+limit = "10" # No of entires per request
+languages = "en" # Limit to particular language
api_base = "https://api.trakt.tv"
database_url = "sqlite:///jlm.db"
headers = {
- "Content-Type": "application/json",
- "trakt-api-version": "2",
- "trakt-api-key": trakt_id
+ "Content-Type": "application/json",
+ "trakt-api-version": "2",
+ "trakt-api-key": trakt_id,
}
params = {
- "query": "",
- "years": years,
- "page": page,
- "extended": extended,
- "limit": limit,
- "languages": languages
+ "query": "",
+ "years": years,
+ "page": page,
+ "extended": extended,
+ "limit": limit,
+ "languages": languages,
}
def create_movie_dict(movie: dict):
- m = movie["movie"]
- movie_dict = {
- "title": m["title"],
- "overview": m["overview"],
- "genres": m["genres"],
- "language": m["language"],
- "year": int(m["year"]),
- "trakt_id": m["ids"]["trakt"],
- "released": m["released"],
- "runtime": int(m["runtime"]),
- "country": m["country"],
- "rating": int(m["rating"]),
- "votes": int(m["votes"]),
- "comment_count": int(m["comment_count"]),
- "tagline": m["tagline"]
- }
- return movie_dict
-
+ m = movie["movie"]
+ movie_dict = {
+ "title": m["title"],
+ "overview": m["overview"],
+ "genres": m["genres"],
+ "language": m["language"],
+ "year": int(m["year"]),
+ "trakt_id": m["ids"]["trakt"],
+ "released": m["released"],
+ "runtime": int(m["runtime"]),
+ "country": m["country"],
+ "rating": int(m["rating"]),
+ "votes": int(m["votes"]),
+ "comment_count": int(m["comment_count"]),
+ "tagline": m["tagline"],
+ }
+ return movie_dict
params["limit"] = 1
-res = requests.get(f"{api_base}/search/movie",headers=headers,params=params)
+res = requests.get(f"{api_base}/search/movie", headers=headers, params=params)
total_items = res.headers["x-pagination-item-count"]
print(f"There are {total_items} movies")
@@ -80,45 +79,54 @@ engine, Session = init_db_stuff(database_url)
start_time = datetime.now()
-for page in tqdm(range(1,max_requests+10)):
- if req_count == 999:
- seconds_to_sleep = 300 - (datetime.now() - start_time).seconds
- if seconds_to_sleep < 1:
- seconds_to_sleep = 60
- print(f"Sleeping {seconds_to_sleep}s")
- # Need to respect their rate limitting
+for page in tqdm(range(1, max_requests + 10)):
+ if req_count == 999:
+ seconds_to_sleep = 300 - (datetime.now() - start_time).seconds
+ if seconds_to_sleep < 1:
+ seconds_to_sleep = 60
+ print(f"Sleeping {seconds_to_sleep}s")
+        # Need to respect their rate limiting
# Better to use x-ratelimit header
- time.sleep(seconds_to_sleep)
- start_time = datetime.now()
- req_count = 0
-
- params["page"] = page
- params["limit"] = int(int(total_items)/max_requests)
- movies = []
- res = requests.get(f"{api_base}/search/movie",headers=headers,params=params)
-
- if res.status_code == 500:
- break
- elif res.status_code == 200:
- None
- else:
- print(f"OwO Code {res.status_code}")
-
- for movie in res.json():
- movies.append(create_movie_dict(movie))
-
- with engine.connect() as conn:
- for movie in movies:
- with conn.begin() as trans:
- stmt = insert(movies_table).values(
- trakt_id=movie["trakt_id"], title=movie["title"], genres=" ".join(movie["genres"]),
- language=movie["language"], year=movie["year"], released=movie["released"],
- runtime=movie["runtime"], country=movie["country"], overview=movie["overview"],
- rating=movie["rating"], votes=movie["votes"], comment_count=movie["comment_count"],
- tagline=movie["tagline"])
- try:
- result = conn.execute(stmt)
- trans.commit()
- except IntegrityError:
- trans.rollback()
- req_count += 1
+ time.sleep(seconds_to_sleep)
+ start_time = datetime.now()
+ req_count = 0
+
+ params["page"] = page
+ params["limit"] = int(int(total_items) / max_requests)
+ movies = []
+ res = requests.get(f"{api_base}/search/movie", headers=headers, params=params)
+
+ if res.status_code == 500:
+ break
+ elif res.status_code == 200:
+ None
+ else:
+ print(f"OwO Code {res.status_code}")
+
+ for movie in res.json():
+ movies.append(create_movie_dict(movie))
+
+ with engine.connect() as conn:
+ for movie in movies:
+ with conn.begin() as trans:
+ stmt = insert(movies_table).values(
+ trakt_id=movie["trakt_id"],
+ title=movie["title"],
+ genres=" ".join(movie["genres"]),
+ language=movie["language"],
+ year=movie["year"],
+ released=movie["released"],
+ runtime=movie["runtime"],
+ country=movie["country"],
+ overview=movie["overview"],
+ rating=movie["rating"],
+ votes=movie["votes"],
+ comment_count=movie["comment_count"],
+ tagline=movie["tagline"],
+ )
+ try:
+ result = conn.execute(stmt)
+ trans.commit()
+ except IntegrityError:
+ trans.rollback()
+ req_count += 1
diff --git a/pc2rec.py b/pc2rec.py
index d427dad..b802631 100644
--- a/pc2rec.py
+++ b/pc2rec.py
@@ -10,7 +10,9 @@ from sqlalchemy import func
movie_name = "Forrest Gump"
with engine.connect() as conn:
- movie_deets = select(movies_table).filter(func.lower(movies_table.columns.title)==func.lower(movie_name))
- result = conn.execute(movie_deets)
- for row in result:
-        print(row)
\ No newline at end of file
+ movie_deets = select(movies_table).filter(
+ func.lower(movies_table.columns.title) == func.lower(movie_name)
+ )
+ result = conn.execute(movie_deets)
+ for row in result:
+ print(row)