Skip to content

Commit

Permalink
Add request body test file, update topic match logic, add a dummy test
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiasun0515 committed Jun 6, 2024
1 parent c6658e2 commit 9d4b28e
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/poprox_recommender/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ def select_articles(
model_device,
token_mapping,
num_slots,
todays_article_matched_topics
) -> Dict[UUID, List[Article]]:

# Transform news to model features
Expand Down
52 changes: 52 additions & 0 deletions src/poprox_recommender/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from poprox_recommender.default import select_articles
from poprox_recommender.paths import project_root

from transformers import BertTokenizer, BertForSequenceClassification
from torch.nn.functional import sigmoid


def load_model(device_name=None):
checkpoint = None
Expand All @@ -30,6 +33,8 @@ def generate_recs(event, context):
todays_articles = [
Article.model_validate(attrs) for attrs in req_body["todays_articles"]
]
topic_matched_dict, todays_article_matched_topics = match_news_topics_to_general(todays_articles)

past_articles = [
Article.model_validate(attrs) for attrs in req_body["past_articles"]
]
Expand All @@ -46,6 +51,7 @@ def generate_recs(event, context):
DEVICE,
TOKEN_MAPPING,
num_recs,
todays_article_matched_topics
)

body = {
Expand All @@ -60,3 +66,49 @@ def generate_recs(event, context):
response = {"statusCode": 200, "body": json.dumps(body)}

return response

def classify_news_topic(model, tokenizer, general_topics, topic, threshold=0.5):
    """Return the general topics whose predicted probability exceeds ``threshold``.

    The topic text is tokenized and run through ``model``; each output logit
    is passed through its own sigmoid (multi-label scoring), so zero, one, or
    several general topics may be returned.

    Args:
        model: Callable invoked as ``model(**inputs)`` whose result exposes a
            ``logits`` tensor. Assumed to be shaped ``(batch, num_labels)``
            with a single batch row for the one input text -- TODO confirm
            if a non-Hugging Face model is ever passed in.
        tokenizer: Callable that converts ``topic`` into the keyword inputs
            passed to ``model``.
        general_topics: Sequence of labels; position ``i`` names logit ``i``.
        topic: Text to classify.
        threshold: Probability strictly above which a label is selected.

    Returns:
        List of entries from ``general_topics``, in label order.
    """
    import torch  # local import so this function does not depend on file-level imports

    inputs = tokenizer(topic, return_tensors='pt', truncation=True, padding=True, max_length=512)

    # Inference only: no need to record an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Select the batch row explicitly. A bare squeeze() would also collapse
    # the label axis when there is exactly one label, producing a 0-d value
    # that cannot be iterated.
    probabilities = torch.sigmoid(logits[0].detach()).tolist()

    return [general_topics[i] for i, prob in enumerate(probabilities) if prob > threshold]

def match_news_topics_to_general(articles):
    """Map each article's mentioned entities onto a fixed list of general topics.

    Every distinct entity name is classified at most once per call; the
    result is cached and reused for every later article mentioning the same
    entity.

    Args:
        articles: Iterable of objects exposing ``article_id`` and
            ``mentions``, where each mention exposes ``entity.name``.

    Returns:
        A ``(topic_matched_dict, article_to_new_topic)`` tuple:
        ``topic_matched_dict`` maps entity name -> list of general topics,
        ``article_to_new_topic`` maps ``article_id`` -> set of general topics.
    """
    # NOTE(review): "Tech " carries a trailing space; kept as-is because
    # consumers of these labels are not visible here -- confirm and clean up.
    general_topics = [
        "US News",
        "World News",
        "Politics",
        "Business",
        "Entertainment",
        "Sports",
        "Health",
        "Science",
        "Tech ",
        "Lifestyle",
        "Religion",
        "Climate",
        "Education",
        "Oddities",
    ]

    # Loaded lazily on the first cache miss so that empty input (or articles
    # without mentions) never pays for loading the pre-trained model.
    # NOTE(review): 'bert-base-uncased' is a base checkpoint; presumably the
    # classification head created for ``num_labels`` is freshly initialized
    # and needs fine-tuning before its predictions are meaningful -- confirm.
    tokenizer = None
    model = None

    # TODO: this per-call cache could be backed by a DB to reuse earlier matches.
    topic_matched_dict = {}
    article_to_new_topic = {}

    for article in articles:
        article_topic = set()
        for mention in article.mentions:
            topic = mention.entity.name
            matched_general_topics = topic_matched_dict.get(topic)
            if matched_general_topics is None:
                if model is None:
                    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
                    model = BertForSequenceClassification.from_pretrained(
                        'bert-base-uncased', num_labels=len(general_topics)
                    )
                matched_general_topics = classify_news_topic(model, tokenizer, general_topics, topic)
                topic_matched_dict[topic] = matched_general_topics
            # Always read from the cache result, so a topic seen in an earlier
            # article still contributes to this article's topic set.
            article_topic.update(matched_general_topics)
        article_to_new_topic[article.article_id] = article_topic

    return topic_matched_dict, article_to_new_topic
2 changes: 1 addition & 1 deletion src/poprox_recommender/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@


def project_root() -> Path:
    """Return the repository root directory.

    This module lives at ``src/poprox_recommender/paths.py``, so the root is
    three directory levels above this file (package -> ``src`` -> root).
    """
    return Path(__file__).parent.parent.parent
1 change: 1 addition & 0 deletions tests/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/request_body.json
21 changes: 21 additions & 0 deletions tests/pfar_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import json
import sys
from poprox_concepts import Article
from poprox_recommender.handler import match_news_topics_to_general

def generate_recs(event_path):
    """Run topic matching on a saved request body and print a sample of the output.

    Args:
        event_path: Path to a JSON file containing a ``"todays_articles"``
            list of article attribute dicts.

    Prints the full entity-name -> general-topics mapping, followed by the
    matched topic set of the first article only (as a quick sample).
    """
    # JSON is UTF-8 by specification; do not rely on the platform default encoding.
    with open(event_path, 'r', encoding='utf-8') as j:
        req_body = json.load(j)

    todays_articles = [
        Article.model_validate(attrs) for attrs in req_body["todays_articles"]
    ]

    topic_matched_dict, todays_article_matched_topics = match_news_topics_to_general(todays_articles)
    print(topic_matched_dict)

    # Only the first article's topics are shown; nothing is printed when empty.
    for matched_topics in todays_article_matched_topics.values():
        print(matched_topics)
        break

if __name__ == '__main__':
    from pathlib import Path

    # Default to the fixture next to this script so the run does not depend
    # on the current working directory; an explicit path may be given as the
    # first command-line argument.
    default_event_path = Path(__file__).resolve().parent / 'request_body.json'
    event_path = sys.argv[1] if len(sys.argv) > 1 else default_event_path
    generate_recs(event_path)
5 changes: 5 additions & 0 deletions tests/request_body.json.dvc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
outs:
- md5: 599081a4fb85de53d5fb695dbb0bc92e
size: 4717739
hash: md5
path: request_body.json

0 comments on commit 9d4b28e

Please sign in to comment.