odla torchscript
shuhand committed May 6, 2022
1 parent 20ea317 commit fc608c9
Showing 3 changed files with 334 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ODLA/platforms/CMakeLists.txt
@@ -86,3 +86,8 @@ option(ODLA_BUILD_TF_Wrapper "Build ODLA Tensorflow Wrapper" OFF)
if (ODLA_BUILD_TF_Wrapper)
  add_subdirectory(tensorflow)
endif()

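# The TorchScript wrapper is optional; enable it with -DODLA_BUILD_TORCH_Wrapper=ON.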
option(ODLA_BUILD_TORCH_Wrapper "Build ODLA TorchScript Wrapper" OFF)
if (ODLA_BUILD_TORCH_Wrapper)
  add_subdirectory(torch)
endif()
65 changes: 65 additions & 0 deletions ODLA/platforms/torch/CMakeLists.txt
@@ -0,0 +1,65 @@
# ==============================================================================
# Copyright (C) 2022 Alibaba Group Holding Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
# ==============================================================================
add_odla_library(odla_torch SHARED odla_torchscript.cc)

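# Locate Torch: prefer an explicit install (CMAKE_PREFIX_PATH / TORCH_INSTALL_PREFIX);
# otherwise fall back to the Python torch package's install location.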
if (CMAKE_PREFIX_PATH)
  find_package(Torch REQUIRED)
endif()
if (TORCH_INSTALL_PREFIX)
  set(TORCH_PATH ${TORCH_INSTALL_PREFIX})
else()
  execute_process(
    COMMAND python3 -c "import torch; print(torch.__path__[0])"
    OUTPUT_VARIABLE TORCH_PATH
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE retcode)
endif()
if (NOT TORCH_PATH)
  message(FATAL_ERROR "Torch installation not found.")
else()
  message(STATUS "torch install path: ${TORCH_PATH}")
endif()

set(PYTORCH_INC_DIR
  ${TORCH_PATH}/include
  ${TORCH_PATH}/torch/csrc/api/include
  ${TORCH_PATH}/include/TH
  ${TORCH_PATH}/include/THC
)
message(STATUS "torch include dirs: ${PYTORCH_INC_DIR}")
target_include_directories(odla_torch PRIVATE ${PYTORCH_INC_DIR})

set(PYTORCH_LIBS
  -lc10
  -ltorch
  -ltorch_cpu
  -ltorch_python)

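# Build with the same C++ ABI as the installed libtorch; when TORCH_CXX_FLAGS is
# not provided, query the ABI setting from the torch Python package.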
if (NOT TORCH_CXX_FLAGS)
  execute_process(
    COMMAND python3 -c "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"
    OUTPUT_VARIABLE TORCH_CXX11_ABI
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE retcode)
  message(STATUS "Torch CXX flags: -D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
  target_compile_definitions(odla_torch PRIVATE _GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI})
else()
  message(STATUS "Torch CXX flags: ${TORCH_CXX_FLAGS}")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
endif()
target_link_options(odla_torch PUBLIC -Wl,-rpath,${TORCH_PATH}/lib -L${TORCH_PATH}/lib)
target_link_libraries(odla_torch PUBLIC ${PYTORCH_LIBS})
target_link_libraries(odla_torch PUBLIC ODLA)
264 changes: 264 additions & 0 deletions ODLA/platforms/torch/odla_torchscript.cc
@@ -0,0 +1,264 @@
//===- odla_torchscript.cc ------------------------------------------------===//
//
// Copyright (C) 2022 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <c10/util/ArrayRef.h>
#include <torch/csrc/jit/serialization/import.h>

#include <ODLA/odla.h>

#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <vector>

const uint32_t MAX_OUTPUT_TENSORS = 10;
const uint32_t MAX_INPUT_TENSORS = 20;

struct _odla_device {
  c10::DeviceType device_t_;
};

struct _odla_value {
  _odla_value(uint32_t v) : id_(v) {}
  uint32_t id_;
};

struct _odla_executable {
  torch::jit::Module module_;
  // Owns the value objects; odla_inputs_outputs_ holds non-owning pointers.
  std::vector<std::unique_ptr<_odla_value>> value_storage_;
  std::vector<odla_value> odla_inputs_outputs_;
  odla_uint32 num_inputs_;
};

struct _odla_context {
  _odla_context();
  std::vector<torch::jit::IValue> inputs_;
  std::vector<odla_value_type> input_types_;
  torch::jit::IValue output_;
  std::vector<at::Tensor> output_tensors_;
  odla_uint32 num_output_tensors_;
};

_odla_context::_odla_context() {
  inputs_.resize(MAX_INPUT_TENSORS);
  input_types_.resize(MAX_INPUT_TENSORS);
  num_output_tensors_ = 0;
}

static size_t getElementCount(const odla_value_shape& dims) {
  return dims.size == 0 ? 1
                        : std::accumulate(dims.dims, dims.dims + dims.size,
                                          static_cast<size_t>(1),
                                          std::multiplies<size_t>());
}

static c10::IntArrayRef toTensorDim(odla_value_shape& dims) {
  // Map a rank-0 (scalar) shape to a single-element dimension. The dimension
  // lives in static storage so the returned ArrayRef does not dangle.
  static const int64_t kScalarDim = 1;
  return dims.size == 0 ? c10::IntArrayRef(&kScalarDim, 1)
                        : c10::IntArrayRef(dims.dims, dims.size);
}

static c10::ScalarType toTensorDataType(odla_element_type dt) {
  static const std::unordered_map<odla_element_type, c10::ScalarType> dt_map =
      {{ODLA_FLOAT32, c10::ScalarType::Float},
       {ODLA_INT32, c10::ScalarType::Int},
       {ODLA_BOOL, c10::ScalarType::Bool}};
  auto it = dt_map.find(dt);
  return it == dt_map.end() ? c10::ScalarType::Float : it->second;
}

static odla_element_type toODLADataType(const c10::ScalarType& st) {
  static const std::unordered_map<c10::ScalarType, odla_element_type> dt_map =
      {{c10::ScalarType::Float, ODLA_FLOAT32},
       {c10::ScalarType::Int, ODLA_INT32},
       {c10::ScalarType::Bool, ODLA_BOOL}};
  auto it = dt_map.find(st);
  return it == dt_map.end() ? ODLA_FLOAT32 : it->second;
}

static odla_value_type toODLAValueType(const c10::ScalarType& dt,
                                       at::IntArrayRef dims) {
  odla_value_type ty;
  ty.element_type = toODLADataType(dt);
  ty.shape.size = dims.size();
  int i = 0;
  for (auto d : dims) {
    ty.shape.dims[i++] = d;
  }
  return ty;
}

static std::unordered_map<odla_context, std::unique_ptr<_odla_context>> g_ctxs;
static std::unordered_map<odla_executable, std::unique_ptr<_odla_executable>>
    g_executables;

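// A single global device; odla_AllocateDevice just hands this back, and the
// wrapper defaults to CUDA.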
static _odla_device g_device{c10::kCUDA};

odla_status odla_AllocateDevice(const odla_vendor vendor,
                                const odla_device_name device_name,
                                odla_device* device, const char* config) {
  *device = &g_device;
  return ODLA_SUCCESS;
}

odla_status odla_LoadExecutable(odla_resource_location location,
                                odla_device device,
                                odla_executable* computation) {
  *computation = nullptr;
  if (location.location_type != ODLA_LOCATION_MEMORY &&
      location.location_type != ODLA_LOCATION_PATH) {
    return ODLA_FAILURE;
  }
  auto comp = std::make_unique<_odla_executable>();
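  // A memory location is wrapped in a stream (without copying the buffer)
  // for torch::jit::load; a path location is passed to it directly.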
  if (location.location_type == ODLA_LOCATION_MEMORY) {
    std::istringstream s;
    s.rdbuf()->pubsetbuf(
        const_cast<char*>(reinterpret_cast<const char*>(location.location)),
        location.size);
    comp->module_ = torch::jit::load(s, c10::Device(g_device.device_t_));
  } else {
    comp->module_ = torch::jit::load(
        reinterpret_cast<const char*>(location.location),
        c10::Device(g_device.device_t_));
  }
  auto schema = comp->module_.get_method("forward").function().getSchema();
  assert(!schema.is_vararg());
  assert(!schema.is_varret());
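  // num_inputs() counts the implicit module "self" argument, so the number
  // of user-visible inputs is one less.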
  auto num_inputs = comp->module_.get_method("forward").function().num_inputs();
  comp->num_inputs_ = num_inputs - 1;
  for (uint32_t idx = 0; idx < std::max(comp->num_inputs_, MAX_OUTPUT_TENSORS);
       ++idx) {
    auto v = std::make_unique<_odla_value>(idx);
    comp->odla_inputs_outputs_.push_back(v.get());
    // Keep ownership in the executable so the pointer stays valid after this
    // loop iteration ends.
    comp->value_storage_.push_back(std::move(v));
  }
  *computation = comp.get();
  g_executables[*computation] = std::move(comp);
  return ODLA_SUCCESS;
}

odla_status odla_GetArgFromExecutableByIdx(odla_executable comp,
                                           odla_uint32 idx,
                                           odla_value* value) {
  if (idx >= comp->num_inputs_) {
    *value = nullptr;
    return ODLA_FAILURE;
  }
  *value = comp->odla_inputs_outputs_[idx];
  return ODLA_SUCCESS;
}

odla_status odla_GetOutputFromExecutableByIdx(const odla_executable comp,
                                              const odla_uint32 output_idx,
                                              odla_value* output_value) {
  if (output_idx >= comp->odla_inputs_outputs_.size()) {
    *output_value = nullptr;
    return ODLA_FAILURE;
  }
  *output_value = comp->odla_inputs_outputs_[output_idx];
  return ODLA_SUCCESS;
}

odla_status odla_CreateContext(odla_context* context) {
  *context = nullptr;
  auto ctx = std::make_unique<_odla_context>();
  *context = ctx.get();
  g_ctxs[*context] = std::move(ctx);
  return ODLA_SUCCESS;
}

odla_status odla_SetRuntimeValueType(odla_context context, odla_value v,
                                     odla_value_type ty) {
  assert(v->id_ < MAX_INPUT_TENSORS);
  context->input_types_[v->id_] = std::move(ty);
  return ODLA_SUCCESS;
}

odla_status odla_GetRuntimeValueType(odla_context context, odla_value value,
                                     odla_value_type* ty) {
  assert(value->id_ < context->num_output_tensors_);
  auto t = context->output_tensors_[value->id_];
  *ty = toODLAValueType(t.scalar_type(), t.sizes());
  return ODLA_SUCCESS;
}

odla_status odla_BindToArgument(odla_value value, const odla_void* data_ptr,
                                odla_context context) {
  assert(value->id_ < MAX_INPUT_TENSORS);
  auto ty = context->input_types_[value->id_];
  auto options = c10::TensorOptions()
                     .dtype(toTensorDataType(ty.element_type))
                     .device(c10::kCPU);
  auto t = at::from_blob(const_cast<void*>(data_ptr), toTensorDim(ty.shape),
                         options);
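  // from_blob above wraps the caller's buffer without copying; the CUDA
  // branch below then makes a device-side copy.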
  if (g_device.device_t_ == c10::kCUDA) {
    t = t.to(c10::device(c10::kCUDA));
  }
  context->inputs_[value->id_] = c10::IValue(t);
  return ODLA_SUCCESS;
}

odla_status odla_BindToOutput(odla_value value, odla_void* data_ptr,
                              odla_context context) {
  assert(value->id_ < context->num_output_tensors_);
  auto t = context->output_tensors_[value->id_];
  auto ty = toODLAValueType(t.scalar_type(), t.sizes());
  void* raw_data = t.storage().data();
  size_t len = at::elementSize(t.scalar_type()) * getElementCount(ty.shape);
  if (g_device.device_t_ == c10::kCPU) {
    memcpy(data_ptr, raw_data, len);
  } else {
    // cudaMemcpy(data_ptr, raw_data, len, cudaMemcpyDeviceToHost);
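    // Stage through a CPU copy of the tensor instead of calling CUDA APIs
    // directly here.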
    t = t.to(c10::Device(c10::kCPU));
    memcpy(data_ptr, t.storage().data(), len);
  }
  return ODLA_SUCCESS;
}

odla_status odla_GetRuntimeNumOfOutputs(odla_context context,
                                        odla_uint32* num_output_ptr) {
  *num_output_ptr = (odla_uint32)context->num_output_tensors_;
  return ODLA_SUCCESS;
}

odla_status odla_LaunchExecutable(const odla_executable computation,
                                  const odla_context context) {
  context->inputs_.resize(computation->num_inputs_);
  context->input_types_.resize(computation->num_inputs_);
  // Drop results from any previous launch before collecting new outputs.
  context->output_tensors_.clear();
  context->output_ = computation->module_.forward(context->inputs_);

  if (context->output_.isTensor()) {
    context->output_tensors_.push_back(context->output_.toTensor());
  } else {
    assert(context->output_.isTuple());
    for (const auto& item : context->output_.toTuple()->elements()) {
      assert(item.isTensor());
      context->output_tensors_.push_back(item.toTensor());
    }
  }
  context->num_output_tensors_ = context->output_tensors_.size();
  return ODLA_SUCCESS;
}

odla_status odla_DestroyContext(odla_context context) {
  auto it = g_ctxs.find(context);
  if (it == g_ctxs.end()) {
    return ODLA_FAILURE;
  }
  g_ctxs.erase(it);
  return ODLA_SUCCESS;
}

odla_status odla_DestroyExecutable(odla_executable computation) {
  auto it = g_executables.find(computation);
  if (it == g_executables.end()) {
    return ODLA_FAILURE;
  }
  g_executables.erase(it);
  return ODLA_SUCCESS;
}

odla_status odla_DestroyDevice(odla_device device) { return ODLA_SUCCESS; }
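
For reference, a minimal caller of this wrapper could look like the sketch below. Error-status checks are omitted, and the model path "model.pt", the FP32 1x3 input shape, and the buffer sizes are illustrative assumptions, not part of this commit:

#include <ODLA/odla.h>

#include <vector>

int main() {
  // The wrapper ignores vendor/device_name and returns its global device.
  odla_device device = nullptr;
  odla_AllocateDevice(nullptr, static_cast<odla_device_name>(0), &device,
                      nullptr);

  // Load a TorchScript module from a file path.
  odla_resource_location loc;
  loc.location_type = ODLA_LOCATION_PATH;
  loc.location = "model.pt"; // hypothetical model file
  loc.size = 0;
  odla_executable exec = nullptr;
  odla_LoadExecutable(loc, device, &exec);

  odla_context ctx = nullptr;
  odla_CreateContext(&ctx);

  // Describe and bind the single input (type/shape assumed for illustration).
  odla_value in = nullptr;
  odla_GetArgFromExecutableByIdx(exec, 0, &in);
  odla_value_type ty;
  ty.element_type = ODLA_FLOAT32;
  ty.shape.size = 2;
  ty.shape.dims[0] = 1;
  ty.shape.dims[1] = 3;
  odla_SetRuntimeValueType(ctx, in, ty);
  std::vector<float> input(3, 1.0f);
  odla_BindToArgument(in, input.data(), ctx);

  // Run forward() and copy each output back into a host buffer.
  odla_LaunchExecutable(exec, ctx);
  odla_uint32 n = 0;
  odla_GetRuntimeNumOfOutputs(ctx, &n);
  for (odla_uint32 i = 0; i < n; ++i) {
    odla_value out = nullptr;
    odla_GetOutputFromExecutableByIdx(exec, i, &out);
    odla_value_type oty;
    odla_GetRuntimeValueType(ctx, out, &oty);
    size_t count = 1;
    for (int d = 0; d < oty.shape.size; ++d) count *= oty.shape.dims[d];
    std::vector<float> result(count); // assumes FP32 outputs
    odla_BindToOutput(out, result.data(), ctx);
  }

  odla_DestroyContext(ctx);
  odla_DestroyExecutable(exec);
  odla_DestroyDevice(device);
  return 0;
}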
