From 9d4b28e21034da5c1fd64ea826eac2de8e497092 Mon Sep 17 00:00:00 2001
From: sophiasun0515
Date: Thu, 6 Jun 2024 01:29:08 +0000
Subject: [PATCH] Add request body test file, update topic match logic, add a dummy test

---
Reviewer notes (this section is ignored when the patch is applied):
* Line breaks and indentation were restored by hand; the whitespace of the
  context lines is a best guess - TODO confirm against the target tree.
* The post-image blob ids on the `index` lines are carried over from the
  original commit and no longer match the revised content.
* select_articles() now defaults the new parameter to None so callers that
  do not pass it keep working.
* match_news_topics_to_general() loads the BERT tokenizer/model once per
  process, and entities that were already classified for an earlier article
  now contribute their topics to later articles as well.
* The "Tech " label lost its trailing space.

 src/poprox_recommender/default.py |  1 +
 src/poprox_recommender/handler.py | 79 +++++++++++++++++++++++++++++++
 src/poprox_recommender/paths.py   |  2 +-
 tests/.gitignore                  |  1 +
 tests/pfar_test.py                | 27 +++++++++++
 tests/request_body.json.dvc       |  5 ++
 6 files changed, 114 insertions(+), 1 deletion(-)
 create mode 100644 tests/.gitignore
 create mode 100644 tests/pfar_test.py
 create mode 100644 tests/request_body.json.dvc

diff --git a/src/poprox_recommender/default.py b/src/poprox_recommender/default.py
index 89877466..c509d0a2 100644
--- a/src/poprox_recommender/default.py
+++ b/src/poprox_recommender/default.py
@@ -290,6 +290,7 @@ def select_articles(
     model_device,
     token_mapping,
     num_slots,
+    todays_article_matched_topics=None,
 ) -> Dict[UUID, List[Article]]:
     # Transform news to model features
 
diff --git a/src/poprox_recommender/handler.py b/src/poprox_recommender/handler.py
index f517a808..7121c617 100644
--- a/src/poprox_recommender/handler.py
+++ b/src/poprox_recommender/handler.py
@@ -7,6 +7,9 @@
 from poprox_recommender.default import select_articles
 from poprox_recommender.paths import project_root
 
+from transformers import BertTokenizer, BertForSequenceClassification
+from torch.nn.functional import sigmoid
+
 
 def load_model(device_name=None):
     checkpoint = None
@@ -30,6 +33,8 @@ def generate_recs(event, context):
     todays_articles = [
         Article.model_validate(attrs) for attrs in req_body["todays_articles"]
     ]
+    topic_matched_dict, todays_article_matched_topics = match_news_topics_to_general(todays_articles)
+
     past_articles = [
         Article.model_validate(attrs) for attrs in req_body["past_articles"]
     ]
@@ -46,6 +51,7 @@ def generate_recs(event, context):
         DEVICE,
         TOKEN_MAPPING,
         num_recs,
+        todays_article_matched_topics,
     )
 
     body = {
@@ -60,3 +66,76 @@ def generate_recs(event, context):
 
     response = {"statusCode": 200, "body": json.dumps(body)}
     return response
+
+
+GENERAL_TOPICS = (
+    "US News",
+    "World News",
+    "Politics",
+    "Business",
+    "Entertainment",
+    "Sports",
+    "Health",
+    "Science",
+    "Tech",
+    "Lifestyle",
+    "Religion",
+    "Climate",
+    "Education",
+    "Oddities",
+)
+
+# Lazily created (tokenizer, model) pair, shared so the weights load once per process.
+_TOPIC_CLASSIFIER = None
+
+
+def _load_topic_classifier():
+    """Return the shared (tokenizer, model) pair, loading it on first use."""
+    global _TOPIC_CLASSIFIER
+    if _TOPIC_CLASSIFIER is None:
+        # NOTE(review): "bert-base-uncased" is the base checkpoint, so the
+        # classification head is presumably untrained here - confirm that a
+        # fine-tuned checkpoint will replace it before trusting the topics.
+        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+        model = BertForSequenceClassification.from_pretrained(
+            "bert-base-uncased", num_labels=len(GENERAL_TOPICS)
+        )
+        _TOPIC_CLASSIFIER = (tokenizer, model)
+    return _TOPIC_CLASSIFIER
+
+
+def classify_news_topic(model, tokenizer, general_topics, topic, threshold=0.5):
+    """Return the entries of *general_topics* predicted for *topic*.
+
+    *topic* is tokenized and passed through *model*; a general topic is kept
+    when the sigmoid of its logit is strictly greater than *threshold*.
+    """
+    inputs = tokenizer(topic, return_tensors="pt", truncation=True, padding=True, max_length=512)
+    logits = model(**inputs).logits
+    # Drop only the batch axis so a single-label model still yields a 1-D array.
+    probabilities = sigmoid(logits).squeeze(0).detach().numpy()
+    return [name for name, prob in zip(general_topics, probabilities) if prob > threshold]
+
+
+def match_news_topics_to_general(articles):
+    """Map the entities mentioned in *articles* onto ``GENERAL_TOPICS``.
+
+    Returns ``(topic_matched_dict, article_to_new_topic)``: the first maps each
+    distinct entity name to its list of matched general topics, the second maps
+    each ``article_id`` to the set of general topics matched by its mentions.
+    """
+    tokenizer, model = _load_topic_classifier()
+    topic_matched_dict = {}  # we might be able to connect this to DB to read previous matching?
+    article_to_new_topic = {}
+    for article in articles:
+        article_topic = set()
+        for mention in article.mentions:
+            topic = mention.entity.name
+            if topic not in topic_matched_dict:
+                # again, we can store this into db
+                topic_matched_dict[topic] = classify_news_topic(model, tokenizer, GENERAL_TOPICS, topic)
+            # Reuse cached matches as well, so an entity first seen in an earlier
+            # article still contributes its topics to this one.
+            article_topic.update(topic_matched_dict[topic])
+        article_to_new_topic[article.article_id] = article_topic
+    return topic_matched_dict, article_to_new_topic
diff --git a/src/poprox_recommender/paths.py b/src/poprox_recommender/paths.py
index 931047ae..bf0fb8ff 100644
--- a/src/poprox_recommender/paths.py
+++ b/src/poprox_recommender/paths.py
@@ -2,4 +2,4 @@
 
 
 def project_root() -> Path:
-    return Path(__file__).parent.parent
+    return Path(__file__).parent.parent.parent
diff --git a/tests/.gitignore b/tests/.gitignore
new file mode 100644
index 00000000..f0f82a0b
--- /dev/null
+++ b/tests/.gitignore
@@ -0,0 +1 @@
+/request_body.json
diff --git a/tests/pfar_test.py b/tests/pfar_test.py
new file mode 100644
index 00000000..60a9a8fc
--- /dev/null
+++ b/tests/pfar_test.py
@@ -0,0 +1,27 @@
+import json
+import sys
+from pathlib import Path
+
+from poprox_concepts import Article
+from poprox_recommender.handler import match_news_topics_to_general
+
+
+def generate_recs(event_path):
+    """Print the topic matches computed for the request body at *event_path*."""
+    with open(event_path, "r") as j:
+        req_body = json.load(j)
+    todays_articles = [
+        Article.model_validate(attrs) for attrs in req_body["todays_articles"]
+    ]
+
+    topic_matched_dict, todays_article_matched_topics = match_news_topics_to_general(todays_articles)
+    print(topic_matched_dict)
+    # Only show the first article's topics to keep the output short.
+    for article_topics in todays_article_matched_topics.values():
+        print(article_topics)
+        break
+
+
+if __name__ == "__main__":
+    # Resolve next to this file so the script works from any working directory.
+    generate_recs(Path(__file__).with_name("request_body.json"))
diff --git a/tests/request_body.json.dvc b/tests/request_body.json.dvc
new file mode 100644
index 00000000..cd101096
--- /dev/null
+++ b/tests/request_body.json.dvc
@@ -0,0 +1,5 @@
+outs:
+- md5: 599081a4fb85de53d5fb695dbb0bc92e
+  size: 4717739
+  hash: md5
+  path: request_body.json