Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Functions to serialize and deserialize results #123

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
34 changes: 34 additions & 0 deletions make_script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import json

# Fixed preamble of every generated evaluation script.
_SCRIPT_HEADER = """
from datetime import datetime
import shifthappens.benchmark
import shifthappens.utils
"""


def build_script(tasks_config):
    """Return the source code of the evaluation script for ``tasks_config``.

    Args:
        tasks_config: Mapping with the keys ``tasks`` (list of task keys),
            ``import_lines`` (mapping from task/model key to an import
            statement), ``model`` (key and callable name of the model),
            ``out_file_location`` and ``relative_data_folder`` (paths).

    Returns:
        The generated script as a single string.

    Raises:
        KeyError: If a required key, or the import line of a listed task or
            of the model, is missing from ``tasks_config``.
    """
    import_lines = tasks_config["import_lines"]
    model_name = tasks_config["model"]
    out_file_location = tasks_config["out_file_location"]
    relative_data_folder = tasks_config["relative_data_folder"]

    parts = [_SCRIPT_HEADER]
    # One import statement per task, each on its own line, then the model's.
    parts.extend(import_lines[task] + "\n" for task in tasks_config["tasks"])
    parts.append(import_lines[model_name])
    # ``!r`` emits the paths as valid Python string literals, so backslashes
    # and quotes inside a path cannot break or alter the generated code.
    parts.append(
        f"""
tasks = shifthappens.benchmark.get_task_registrations()
model = {model_name}()
results = shifthappens.benchmark.evaluate_model(
    model, {relative_data_folder!r}
)
results_string = shifthappens.utils.serialize_model_results(results)
out_file_location = {out_file_location!r}
with open(out_file_location, 'w') as outfile:
    outfile.write(results_string)
"""
    )
    return "".join(parts)


def main():
    """Read ``./tasks_config.json`` and write the script to ``./run_tasks.py``."""
    # Explicit UTF-8 so the result does not depend on the platform's locale.
    with open("./tasks_config.json", "r", encoding="utf-8") as f:
        tasks_config = json.load(f)

    print(tasks_config["tasks"])

    with open("./run_tasks.py", "w", encoding="utf-8") as run_script_file:
        run_script_file.write(build_script(tasks_config))


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/usr/bin/env bash
# Generate run_tasks.py via make_script.py, then execute it.
# Stop at the first failing command so that a failed generation step never
# runs a stale or partially written run_tasks.py.
set -e

python3 make_script.py
python3 run_tasks.py
25 changes: 25 additions & 0 deletions shifthappens/task_data/task_metadata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Class for storing a task's metadata."""

from dataclasses import dataclass
import json


@dataclass(frozen=True, eq=True)
Expand All @@ -22,5 +23,29 @@ class TaskMetadata:
relative_data_folder: str
standalone: bool = True

def serialize_task_metadata(self) -> str:
    """
    Return this object's metadata encoded as a JSON object string.

    The object holds the ``name``, ``relative_data_folder`` and
    ``standalone`` attributes, in that order.
    """
    exported_fields = ("name", "relative_data_folder", "standalone")
    return json.dumps({field: getattr(self, field) for field in exported_fields})

@staticmethod
def deserialize_task_metadata(metadata_str: str):
    """
    Build a TaskMetadata object from a JSON string.

    ``metadata_str`` must decode to a JSON object providing the keys
    ``name``, ``relative_data_folder`` and ``standalone``.
    """
    decoded = json.loads(metadata_str)
    return TaskMetadata(
        name=decoded["name"],
        relative_data_folder=decoded["relative_data_folder"],
        standalone=decoded["standalone"],
    )


_TASK_METADATA_FIELD = "__task_metadata__"
41 changes: 41 additions & 0 deletions shifthappens/tasks/task_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,44 @@ def __getattr__(self, item) -> float:
return self[item]
else:
return super().__getattribute__(item)

def serialize_summary_metrics(self) -> str:
    """
    Return the summary metrics as the string form of a dict.

    Each key of ``self.summary_metrics`` is replaced by its ``name``
    attribute; the values are kept unchanged.
    """
    named_metrics = {}
    for metric, value in self.summary_metrics.items():
        named_metrics[metric.name] = value
    return str(named_metrics)

def serialize_task_result(self) -> str:
    """
    Return this object as the string form of a two-entry dict.

    ``summary_metrics`` holds the output of ``serialize_summary_metrics``
    and ``metrics`` holds ``str(self._metrics)``.
    """
    summary_repr = self.serialize_summary_metrics()
    metrics_repr = str(self._metrics)
    return str({"summary_metrics": summary_repr, "metrics": metrics_repr})

@staticmethod
def deserialize_summary_metrics(
    summary_metrics_str: str,
) -> Dict[Metric, Union[str, Tuple[str, ...]]]:
    """
    Deserialize a string into a summary_metrics dictionary.

    Args:
        summary_metrics_str: String form of a dict literal whose keys are
            ``Metric`` member names.

    Returns:
        A dictionary keyed by the matching ``Metric`` members, with the
        values taken unchanged from the parsed literal.

    Raises:
        ValueError: If the string is not a plain Python literal, or if a key
            does not name a ``Metric`` member.
        SyntaxError: If the string is not syntactically valid Python.
    """
    # Imported locally so the block does not depend on a file-level import.
    import ast

    # literal_eval only accepts literals, so serialized text read from disk
    # cannot execute arbitrary code the way eval() would.
    summary_metrics = ast.literal_eval(summary_metrics_str)
    result = {}
    for key, value in summary_metrics.items():
        member = Metric.__members__.get(key)
        if member is None:
            # Fail loudly: a silent None key would merge all unknown metrics
            # into a single entry.
            raise ValueError(f"Unknown summary metric name: {key!r}")
        result[member] = value
    return result

@staticmethod
def deserialize_task_result(task_result_str: str):
    """
    Deserialize a string into a TaskResult object.

    Args:
        task_result_str: String form of a dict literal with the keys
            ``summary_metrics`` and ``metrics``, each holding the string
            form of a further dict literal.

    Returns:
        A TaskResult built from the parsed summary metrics, with the parsed
        metrics passed as keyword arguments.

    Raises:
        ValueError: If either string is not a plain Python literal.
        SyntaxError: If either string is not syntactically valid Python.
        KeyError: If ``summary_metrics`` or ``metrics`` is missing.
    """
    # Imported locally so the block does not depend on a file-level import.
    import ast

    # literal_eval only accepts literals, so serialized text read from disk
    # cannot execute arbitrary code the way eval() would.
    result_dict = ast.literal_eval(task_result_str)
    metrics = ast.literal_eval(result_dict["metrics"])
    summary_metrics = TaskResult.deserialize_summary_metrics(
        result_dict["summary_metrics"]
    )
    return TaskResult(summary_metrics=summary_metrics, **metrics)
36 changes: 35 additions & 1 deletion shifthappens/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
"""Utility functions that are needed for the entire package."""

import errno
import json
import os
import sys
import time
import urllib.error
from itertools import product
from typing import Optional
from typing import Dict, Optional, Union

from shifthappens.task_data import task_metadata
from shifthappens.tasks.task_result import TaskResult


def dict_product(d):
Expand Down Expand Up @@ -135,3 +139,33 @@ def download_and_extract_archive(
archive = os.path.join(data_folder, filename)
print(f"Extracting {archive} to {data_folder}")
tv_utils.extract_archive(archive, data_folder, remove_finished)


def serialize_model_results(
    results: Dict[task_metadata.TaskMetadata, Union[TaskResult, None]]
) -> str:
    """
    Encode the evaluation results of a model as a JSON string.

    Each entry maps the serialized task metadata to the serialized task
    result. Entries whose result is ``None`` are left out.
    """
    serialized = {}
    for metadata, task_result in results.items():
        if task_result is None:
            continue
        entry_key = metadata.serialize_task_metadata()
        serialized[entry_key] = task_result.serialize_task_result()
    return json.dumps(serialized)


def deserialize_model_results(
    results_str,
) -> Dict[task_metadata.TaskMetadata, TaskResult]:
    """
    Decode a JSON string into a dictionary of (TaskMetadata, TaskResult) pairs.

    Every key of the decoded JSON object is deserialized into a TaskMetadata
    and every value into a TaskResult.
    """
    restored = {}
    for encoded_metadata, encoded_result in json.loads(results_str).items():
        task_result = TaskResult.deserialize_task_result(encoded_result)
        metadata = task_metadata.TaskMetadata.deserialize_task_metadata(
            encoded_metadata
        )
        restored[metadata] = task_result
    return restored
40 changes: 40 additions & 0 deletions tasks_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
{
"out_file_location": "/mnt/qb/work/bethge/<<username>>/<<path>>",
"relative_data_folder": "/mnt/qb/work/bethge/<<username>>/<<path>>",
"tasks": [
"imagenet_c",
"ccc",
"imagenet_3dcc",
"imagenet_cartoon",
"imagenet_d",
"imagenet_drawing",
"imagenet_m",
"imagenet_metashift",
"imagenet_patch",
"imagenet_r",
"objectnet",
"raccoons_ood",
"siscore",
"ssb",
"worst_case"
],
"model": "ResNet18",
"import_lines": {
"imagenet_c": "from shifthappens.tasks.imagenet_c.imagenet_c import ImageNetCSeparateCorruptions",
"ccc": "from shifthappens.tasks.ccc.ccc import CCC",
"imagenet_3dcc": "from shifthappens.tasks.imagenet_3dcc.imagenet_3dcc import ImageNet3DCCSeparateCorruptions",
"imagenet_cartoon": "from shifthappens.tasks.imagenet_cartoon.imagenet_cartoon import ImageNetCartoon",
"imagenet_d": "from shifthappens.tasks.imagenet_d.imagenet_d import *",
"imagenet_drawing": "from shifthappens.tasks.imagenet_drawing.imagenet_drawing import ImageNetDrawing",
"imagenet_m": "from shifthappens.tasks.imagenet_m.imagenet_m import ImageNetM",
"imagenet_metashift": "from shifthappens.tasks.imagenet_metashift.imagenet_metashift import ImageNetMetaShift",
"imagenet_patch": "from shifthappens.tasks.imagenet_patch.imagenet_patch import ImageNetPatchCorruptions",
"imagenet_r": "from shifthappens.tasks.imagenet_r.imagenet_r import ImageNetR",
"objectnet": "from shifthappens.tasks.objectnet.objectnet import ObjectNet",
"raccoons_ood": "from shifthappens.tasks.raccoons_ood.raccoons_ood import RaccOOD",
"siscore": "from shifthappens.tasks.siscore.siscore import *",
"ssb": "from shifthappens.tasks.ssb.semantic_shift_benchmark import *",
"worst_case": "from shifthappens.tasks.worst_case.worst_case import WorstCase",
"ResNet18": "from shifthappens.models.torchvision import ResNet18"
}
}