Skip to content

Commit

Permalink
Add type annotations, some documentation, tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Saba-Sabato committed Jan 25, 2024
1 parent 39e083d commit ae52196
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 41 deletions.
35 changes: 21 additions & 14 deletions ocrmac/ocrmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
MATPLOTLIB_AVAILABLE = False


def pil2buf(pil_image):
def pil2buf(pil_image: Image.Image):
"""Convert PIL image to buffer"""
buffer = io.BytesIO()
pil_image.save(buffer, format="PNG")
Expand Down Expand Up @@ -47,17 +47,23 @@ def convert_coordinates_pil(bbox, im_width, im_height):
return x1, y1, x2, y2


def text_from_image(image, recognition_level="accurate", language_preference=None):
def text_from_image(
image, recognition_level="accurate", language_preference=None
) -> list[tuple[str, float, tuple[float, float, float, float]]]:
"""
Helper function to call VNRecognizeTextRequest from Apple's vision framework.
Args:
image (str or PIL image): Path to image or PIL image.
recognition_level (str, optional): Recognition level. Defaults to 'accurate'.
language_preference (list, optional): Language preference. Defaults to None.
Returns:
list: List of tuples containing the text, the confidence and the bounding box.
:param image: Path to image (str) or PIL Image.Image.
:param recognition_level: Recognition level. Defaults to 'accurate'.
:param language_preference: Language preference. Defaults to None.
:returns: List of tuples containing the text, the confidence and the bounding box.
Each tuple looks like (text, confidence, (x, y, width, height))
The bounding box (x, y, width, height) is composed of numbers between 0 and 1,
that represent a percentage from total image (width, height) accordingly.
You can use the `convert_coordinates_*` functions to convert them to pixels.
For more info, see https://developer.apple.com/documentation/vision/vndetectedobjectobservation/2867227-boundingbox?language=objc
and https://developer.apple.com/documentation/vision/vnrectangleobservation?language=objc
"""

if isinstance(image, str):
Expand Down Expand Up @@ -138,7 +144,9 @@ def __init__(self, image, recognition_level="accurate", language_preference=None
self.language_preference = language_preference
self.res = None

def recognize(self, px=False):
def recognize(
self, px=False
) -> list[tuple[str, float, tuple[float, float, float, float]]]:
res = text_from_image(
self.image, self.recognition_level, self.language_preference
)
Expand Down Expand Up @@ -188,7 +196,7 @@ def annotate_matplotlib(

return fig

def annotate_PIL(self, color="red", fontsize=12):
def annotate_PIL(self, color="red", fontsize=12) -> Image.Image:
"""_summary_
Args:
Expand All @@ -206,9 +214,8 @@ def annotate_PIL(self, color="red", fontsize=12):

draw = ImageDraw.Draw(annotated_image)
font = ImageFont.truetype("Arial Unicode.ttf", fontsize)

for _ in self.res:
text, conf, bbox = _

for text, conf, bbox in self.res:
x1, y1, x2, y2 = convert_coordinates_pil(
bbox, annotated_image.width, annotated_image.height
)
Expand Down
80 changes: 53 additions & 27 deletions tests/test_ocrmac.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python

"""Tests for `ocrmac` package."""
from tempfile import TemporaryFile
from unittest import TestCase

import pytest

Expand Down Expand Up @@ -41,30 +43,54 @@ def test_command_line_interface():
"""


def test_ocrmac():
    """Smoke-test OCR on the test.png fixture.

    Every expected phrase must appear as a substring of at least one
    recognized annotation returned by the accurate recognizer.
    """
    expected_phrases = [
        "GitHub: Let's build from here",
        "github.com",
        "Let's build from here",
        "Harnessed for productivity. Designed for collaboration.",
        "Celebrated for built-in security. Welcome to the",
        "platform developers love.",
        "Email address",
        "Sign up for GitHub",
        "Start a free enterprise trial",
        "Trusted by the world's leading organizations",
        "Mercedes-Benz",
    ]

    # recognize() yields (text, confidence, bbox) tuples; keep only the text.
    recognized = [
        entry[0]
        for entry in ocrmac.OCR("test.png", recognition_level="accurate").recognize()
    ]

    for phrase in expected_phrases:
        assert any(phrase in text for text in recognized)
class Test(TestCase):
    """Integration tests for ocrmac.OCR against the checked-in test.png
    fixture and the pre-rendered annotation snapshot PNGs.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # Uncomment these to regenerate test output files
        # with open("test_output_fast.png", "w+b") as fast, open(
        #     "test_output_accurate.png", "w+b"
        # ) as accurate:
        #     ocrmac.OCR("test.png", recognition_level="fast").annotate_PIL().save(
        #         fast, format="png"
        #     )
        #     ocrmac.OCR("test.png", recognition_level="accurate").annotate_PIL().save(
        #         accurate, format="png"
        #     )

    def test_ocrmac(self):
        """Every expected phrase must be recognized verbatim in test.png."""
        samples = {
            "GitHub: Let's build from here • X",
            "github.com",
            "Let's build from here",
            "Harnessed for productivity. Designed for collaboration.",
            "Celebrated for built-in security. Welcome to the",
            "platform developers love.",
            "Email address",
            "Sign up for GitHub",
            "Start a free enterprise trial >",
            "Trusted by the world's leading organizations y",
            "Mercedes-Benz",
        }

        annotations = {
            str(entry[0])
            for entry in ocrmac.OCR("test.png", recognition_level="accurate").recognize()
        }
        # assertEqual on the set difference fails with the exact phrases that
        # were NOT recognized, unlike assertTrue(samples <= annotations),
        # which only reports "False is not true".
        self.assertEqual(samples - annotations, set())

    def _assert_annotation_snapshot(self, recognition_level, snapshot_path):
        """Render annotate_PIL() at *recognition_level* and compare the PNG
        bytes against the pre-rendered snapshot file at *snapshot_path*.
        """
        annotated = ocrmac.OCR(
            "test.png", recognition_level=recognition_level
        ).annotate_PIL()
        with TemporaryFile() as rendered:
            annotated.save(rendered, format="png")
            rendered.seek(0)
            with open(snapshot_path, "rb") as snapshot:
                self.assertEqual(snapshot.read(), rendered.read())

    def test_fast(self):
        """Fast-recognition annotation must match its snapshot byte-for-byte."""
        self._assert_annotation_snapshot("fast", "test_output_fast.png")

    def test_accurate(self):
        """Accurate-recognition annotation must match its snapshot byte-for-byte."""
        self._assert_annotation_snapshot("accurate", "test_output_accurate.png")

0 comments on commit ae52196

Please sign in to comment.