Skip to content

Commit

Permalink
fix get var data type problem
Browse files Browse the repository at this point in the history
  • Loading branch information
thisjiang committed Sep 1, 2021
1 parent c1158b2 commit 60e126a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 68 deletions.
21 changes: 20 additions & 1 deletion paddle/fluid/compiler/paddle2piano/piano_graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,27 @@ namespace paddle {
namespace piano {

using framework::ir::Node;
using framework::proto::VarType;
using GraphNodeVec = PianoGraphExecutor::GraphNodeVec;

VarType::Type GetVarDataType(const framework::VarDesc* var) {
// non-pod type list
static std::unordered_set<VarType::Type> non_pod_types = {
VarType::LOD_TENSOR, VarType::SELECTED_ROWS, VarType::LOD_TENSOR_ARRAY};

const auto& var_type = var->GetType();
if (non_pod_types.count(var_type) != 0) {
// if the value type is non-pod type
return var->GetDataType();
} else if (GetVarType2NoteTypeMap().count(var_type) != 0) {
// if value is supported type
return var_type;
}
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported value data type (%s)",
framework::DataTypeToString(var_type)));
}

void CreateInputOperand(const GraphNodeVec& cluster_inputs, PianoScope* scope,
symbolization::NoteBuilder* builder) {
for (int64_t id = 0; id < cluster_inputs.size(); ++id) {
Expand All @@ -41,7 +60,7 @@ void CreateInputOperand(const GraphNodeVec& cluster_inputs, PianoScope* scope,

// create operand shape
const auto& var_shape = node->Var()->GetShape();
const auto& var_type = node->Var()->GetDataType();
const auto& var_type = GetVarDataType(node->Var());

// convert framework vartype to piano note type
note::ElementTypeProto element_type = VarType2NoteType(var_type);
Expand Down
79 changes: 12 additions & 67 deletions paddle/fluid/compiler/paddle2piano/vartype2notetype.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@ limitations under the License. */
#include <unordered_map>

#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace piano {

static note::ElementTypeProto VarType2NoteType(
framework::proto::VarType::Type type) {
static const std::unordered_map<framework::proto::VarType::Type,
note::ElementTypeProto>&
GetVarType2NoteTypeMap() {
static std::unordered_map<framework::proto::VarType::Type,
note::ElementTypeProto>
vartype2notetype = {{framework::proto::VarType::BOOL, note::B1},
Expand All @@ -37,73 +39,16 @@ static note::ElementTypeProto VarType2NoteType(
{framework::proto::VarType::FP64, note::F64},
{framework::proto::VarType::UINT8, note::U8},
{framework::proto::VarType::SIZE_T, note::U64}};
PADDLE_ENFORCE_NE(
vartype2notetype.find(type), vartype2notetype.end(),
platform::errors::NotFound("Cannot found VarType %d.", type));
return vartype2notetype.at(type);
return vartype2notetype;
}

template <framework::proto::VarType::Type T>
constexpr note::ElementTypeProto VarType2NoteType();

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::BOOL>() {
return note::B1;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::INT8>() {
return note::S8;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::INT16>() {
return note::S16;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::INT32>() {
return note::S32;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::INT64>() {
return note::S64;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::FP16>() {
return note::F16;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::FP32>() {
return note::F32;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::FP64>() {
return note::F64;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::UINT8>() {
return note::U8;
}

template <>
constexpr note::ElementTypeProto
VarType2NoteType<framework::proto::VarType::SIZE_T>() {
return note::U64;
static note::ElementTypeProto VarType2NoteType(
framework::proto::VarType::Type type) {
const auto& vartype2notetype = GetVarType2NoteTypeMap();
PADDLE_ENFORCE_NE(vartype2notetype.find(type), vartype2notetype.end(),
"Unsupported value data type (%s)",
framework::DataTypeToString(type));
return vartype2notetype.at(type);
}

} // namespace piano
Expand Down

0 comments on commit 60e126a

Please sign in to comment.