Skip to content

Commit

Permalink
VIT-OCR accuracy check
Browse files Browse the repository at this point in the history
  • Loading branch information
Silv3S committed Apr 13, 2022
1 parent d79515a commit acc5db2
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 2 deletions.
6 changes: 4 additions & 2 deletions paddle/fluid/inference/tests/api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,15 @@ inference_analysis_test(test_analyzer_transformer_profile SRCS analyzer_transfor
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI})

# VIT-OCR
# Download the VIT-OCR inference model and its accuracy-check input data once
# (guarded by the EXISTS check), then register the analyzer accuracy test.
set(VIT_OCR_URL "https://paddle-qa.bj.bcebos.com/inference_model/2.1.1/ocr")
set(VIT_OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/vit_ocr")
if (NOT EXISTS ${VIT_OCR_INSTALL_DIR}/vit_ocr.tgz)
inference_download_and_uncompress_without_verify(${VIT_OCR_INSTALL_DIR} "https://paddle-qa.bj.bcebos.com" "inference_model/2.1.1/ocr/vit_ocr.tgz")
inference_download_and_uncompress_without_verify(${VIT_OCR_INSTALL_DIR} ${VIT_OCR_URL} vit_ocr.tgz)
inference_download(${VIT_OCR_INSTALL_DIR} ${VIT_OCR_URL} datavit.txt )
endif()
inference_analysis_test(test_analyzer_vit_ocr SRCS analyzer_vit_ocr_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${VIT_OCR_INSTALL_DIR}/vit_ocr)
ARGS --infer_model=${VIT_OCR_INSTALL_DIR}/vit_ocr --infer_data=${VIT_OCR_INSTALL_DIR}/datavit.txt)

# ocr
set(OCR_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/ocr")
Expand Down
63 changes: 63 additions & 0 deletions paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,57 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tests/api/tester_helper.h"
#include <fstream>
#include <iostream>

namespace paddle {
namespace inference {
namespace analysis {

// One parsed input sample from the data file: a flat value buffer plus the
// tensor shape it should be viewed as.
struct Record {
  std::vector<float> data;     // tensor values, flattened
  std::vector<int32_t> shape;  // tensor dimensions for `data`
};

// Parses one line of the test data file into a Record.
// Expected layout: "<v0 v1 ...>\t<d0 d1 ...>" — space-separated float values,
// a tab, then space-separated int32 dimensions.
Record ProcessALine(const std::string &line) {
  std::vector<std::string> columns;
  split(line, '\t', &columns);
  CHECK_EQ(columns.size(), 2UL)
      << "data format error, should be <data>\t<shape>";

  Record record;

  // First column: the flattened tensor values.
  std::vector<std::string> value_tokens;
  split(columns[0], ' ', &value_tokens);
  record.data.reserve(value_tokens.size());
  for (const auto &token : value_tokens) {
    record.data.push_back(std::stof(token));
  }

  // Second column: the tensor dimensions.
  std::vector<std::string> dim_tokens;
  split(columns[1], ' ', &dim_tokens);
  record.shape.reserve(dim_tokens.size());
  for (const auto &token : dim_tokens) {
    record.shape.push_back(std::stoi(token));
  }

  return record;
}

void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
std::string line;
std::ifstream file(FLAGS_infer_data);
std::getline(file, line);
auto record = ProcessALine(line);

PaddleTensor input;
input.shape = record.shape;
input.dtype = PaddleDType::FLOAT32;
size_t input_size = record.data.size() * sizeof(float);
input.data.Resize(input_size);
memcpy(input.data.data(), record.data.data(), input_size);
std::vector<PaddleTensor> input_slots;
input_slots.assign({input});
(*inputs).emplace_back(input_slots);
}

void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
cfg->SetModel(FLAGS_infer_model + "/inference.pdmodel",
FLAGS_infer_model + "/inference.pdiparams");
Expand All @@ -33,6 +79,23 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
}
}

// Compare results of NativeConfig and AnalysisConfig
void compare(bool use_mkldnn = false) {
AnalysisConfig cfg;
SetConfig(&cfg, use_mkldnn);

std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all);
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
}

// Accuracy check on the default (native CPU) execution path.
TEST(Analyzer_vit_ocr, compare) { compare(); }

#ifdef PADDLE_WITH_MKLDNN
// Same accuracy check with the MKL-DNN (oneDNN) execution path enabled.
TEST(Analyzer_vit_ocr, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif

#ifdef PADDLE_WITH_MKLDNN
// Check the fuse status
TEST(Analyzer_vit_ocr, fuse_status) {
Expand Down

0 comments on commit acc5db2

Please sign in to comment.