diff --git a/README.md b/README.md index d123385..5f0f87f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # TableNet -Unofficial implementation of ICDAR 2019 paper : _TableNet: Deep Learning model for end-to-end Table detection and Tabular data extraction from Scanned Document Images._ +Unofficial implementation of ICDAR 2019 paper : _TableNet: Deep Learning model for end-to-end Table detection and Tabular data extraction from Scanned Document Images._ [__Paper__](https://arxiv.org/abs/2001.01469) @@ -11,7 +11,7 @@ TableNet is a modern deep learning architecture that was proposed by a team from They proposed a solution that includes accurate detection of the tabular region within an image and subsequently detecting and extracting information from the rows and columns of the detected table. -**Architecture:** The architecture is based out of Long et al., an encoder-decoder model for semantic segmentation. The same encoder/decoder network is used as the FCN architecture for table extraction. The images are preprocessed and modified using the Tesseract OCR. +**Architecture:** The architecture is based out of Long et al., an encoder-decoder model for semantic segmentation. The same encoder/decoder network is used as the FCN architecture for table extraction. The images are preprocessed and modified using the Tesseract OCR. Source: [Nanonets](https://nanonets.com/blog/table-extraction-deep-learning/#tablenet?&utm_source=nanonets.com/blog/&utm_medium=blog&utm_content=Table%20Detection,%20Information%20Extraction%20and%20Structuring%20using%20Deep%20Learning) @@ -23,13 +23,13 @@ Source: [Nanonets](https://nanonets.com/blog/table-extraction-deep-learning/#tab pip install -r requirements.txt ``` -1. Download the Marmot Dataset from the link given in readme. -1. Run `data_preprocess/generate_mask.py` to generate Table and Column Mask of corresponding images. +1. Download the Marmot Dataset from the link given in readme and unarchive. +1. 
def sameTable(ymin_1, ymin_2, ymax_1, ymax_2,
              tolerances=((5, 5), (4, 7), (7, 4))):
    '''Check if two columns belong to the same table or not.

    Two column boxes are considered part of one table when their lower
    y bounds and their upper y bounds differ by no more than one of the
    allowed tolerance pairs.

    Parameters
    ----------
    ymin_1, ymin_2 : ymin of the first and of the second column box
    ymax_1, ymax_2 : ymax of the first and of the second column box
    tolerances : iterable of (min_tol, max_tol) pairs; the columns match
        when, for at least one pair, |ymin_1 - ymin_2| <= min_tol and
        |ymax_1 - ymax_2| <= max_tol. The default reproduces the
        thresholds that were previously hard-coded.

    Return
    ------
    True if the columns are aligned closely enough
    False otherwise
    '''
    min_diff = abs(ymin_1 - ymin_2)
    max_diff = abs(ymax_1 - ymax_2)
    return any(min_diff <= min_tol and max_diff <= max_tol
               for min_tol, max_tol in tolerances)
def _read_bndbox(obj):
    '''Return (xmin, ymin, xmax, ymax) as ints from an <object> element.'''
    bndbox = obj.find('bndbox')
    return tuple(int(bndbox.find(tag).text)
                 for tag in ('xmin', 'ymin', 'xmax', 'ymax'))


def generate_masks(xml_file, table_mask_file, column_mask_file):
    '''Build column and table masks from one annotation file and save them.

    Every <object> bounding box in the xml is painted white (255) into the
    column mask. Consecutive boxes that sameTable() reports as aligned are
    merged into one enclosing rectangle, which is painted into the table
    mask.

    Parameters
    ----------
    xml_file : path to the annotation file; it must contain a <size>
        element with <width> and <height>, and <object>/<bndbox> entries
    table_mask_file : output path for the table mask image
    column_mask_file : output path for the column mask image
    '''
    root = ET.parse(xml_file).getroot()

    # Parse dimensions
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    # Grayscale masks; uint8 is what gets written to the image files
    col_mask = np.zeros((height, width), dtype=np.uint8)
    table_mask = np.zeros((height, width), dtype=np.uint8)

    # Bounding box [xmin, ymin, xmax, ymax] of the table currently being
    # accumulated; None while no column of a table has been seen yet.
    table_box = None
    prev_ymin = prev_ymax = None

    def paint_table(box):
        left, top, right, bottom = box
        table_mask[top:bottom, left:right] = 255

    for column in root.findall('object'):
        xmin, ymin, xmax, ymax = _read_bndbox(column)
        col_mask[ymin:ymax, xmin:xmax] = 255

        # A column that is not aligned with the previous one starts a new
        # table: flush the finished table first.
        if table_box is not None and \
                not sameTable(prev_ymin, ymin, prev_ymax, ymax):
            paint_table(table_box)
            table_box = None

        if table_box is None:
            table_box = [xmin, ymin, xmax, ymax]
        else:
            table_box = [min(xmin, table_box[0]), min(ymin, table_box[1]),
                         max(xmax, table_box[2]), max(ymax, table_box[3])]

        prev_ymin, prev_ymax = ymin, ymax

    # Flush the last table (nothing to paint when there were no objects)
    if table_box is not None:
        paint_table(table_box)

    # 2-D uint8 arrays are stored as 8-bit grayscale ('L') images
    Image.fromarray(col_mask).save(column_mask_file)
    Image.fromarray(table_mask).save(table_mask_file)


@click.command()
@click.argument('source_dir', type=click.Path(exists=True, file_okay=False))
@click.argument('dest_dir', default='.')
@click.option('--quiet', '-q', is_flag=True,
              help='do not perform dataset consistency check')
def main(source_dir, dest_dir, quiet):
    '''Generate Table and Column masks from Marmot Dataset.

    Command line arguments:\n
    \b
    source_dir -- path to marmot dataset directory with *.bmp and *.xml files,
                  e.g. ./Marmot_data
    dest_dir   -- top output directory for saving generated files. Default is
                  current directory. Two subdirectories will be created under
                  dest_dir separately for column and table mask files:
                  * dest_dir/column_mask
                  * dest_dir/table_mask

    Additionally check dataset for consistency:
    every *.xml is expected to have a corresponding *.bmp file.
    '''
    # source_dir is validated by click.Path(exists=True, file_okay=False),
    # so a missing directory is reported as a usage error instead of relying
    # on an assert (which is stripped under `python -O`).
    final_col_directory = os.path.join(dest_dir, 'column_mask')
    final_table_directory = os.path.join(dest_dir, 'table_mask')

    os.makedirs(final_col_directory, exist_ok=True)
    os.makedirs(final_table_directory, exist_ok=True)

    # sorted() keeps processing order (and any printed reports) stable
    for fname in sorted(os.listdir(source_dir)):
        if not fname.endswith('.xml'):
            continue

        basename = fname[:-len('.xml')]
        outfile = basename + '.jpeg'

        if not quiet:
            check_file_pairs(source_dir, basename)

        generate_masks(
            os.path.join(source_dir, fname),
            os.path.join(final_table_directory, outfile),
            os.path.join(final_col_directory, outfile),
        )
def check_file_pairs(dname, basename) -> bool:
    '''Check that a Marmot sample has both its .bmp and its .xml file.

    In Marmot dataset, each .bmp file has a corresponding .xml file.
    When one file of the pair is missing, a short report naming both paths
    and their status is printed to stdout.

    Parameters
    ----------
    dname : directory that holds the dataset files
    basename : file name without extension, e.g. '10.1.1.8.2121_4'

    Return
    ------
    True if both .bmp and .xml exist
    False otherwise

    >>> check_file_pairs('./dataset/Marmot_data',
    ...                  '10.1.1.8.2121_4')  # doctest: +SKIP
    True
    '''
    labels = {True: 'Found', False: 'Missing'}
    bmp_file = os.path.join(dname, basename + '.bmp')
    xml_file = os.path.join(dname, basename + '.xml')
    bmp_ok = os.path.isfile(bmp_file)
    xml_ok = os.path.isfile(xml_file)
    ok = bmp_ok and xml_ok
    if not ok:
        msg = 'File pair is incomplete:\n{}\t{}\n{}\t{}\n'.format(
            labels[bmp_ok], bmp_file, labels[xml_ok], xml_file)
        print(msg)
    return ok


if __name__ == "__main__":
    # A click command run in standalone mode terminates the process itself,
    # so there is no return value to forward to the interactive-only exit().
    main()