diff --git a/paddle/tcmpt/core/kernel_factory.h b/paddle/tcmpt/core/kernel_factory.h index 180f0ce2c6b87..db1f0df76e6ba 100644 --- a/paddle/tcmpt/core/kernel_factory.h +++ b/paddle/tcmpt/core/kernel_factory.h @@ -16,13 +16,14 @@ #include #include -#include #include #include "paddle/tcmpt/core/backend.h" #include "paddle/tcmpt/core/dtype.h" #include "paddle/tcmpt/core/kernel_def.h" #include "paddle/tcmpt/core/layout.h" +#include "paddle/utils/flat_hash_map.h" +#include "paddle/utils/small_vector.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/enforce.h" @@ -209,25 +210,30 @@ class KernelArgsDef { attribute_defs_.emplace_back(AttributeArgDef(type_index)); } - const std::vector& input_defs() const { return input_defs_; } + const paddle::SmallVector& input_defs() const { + return input_defs_; + } - const std::vector& output_defs() const { return output_defs_; } + const paddle::SmallVector& output_defs() const { + return output_defs_; + } - const std::vector& attribute_defs() const { + const paddle::SmallVector& attribute_defs() const { return attribute_defs_; } - std::vector& input_defs() { return input_defs_; } + paddle::SmallVector& input_defs() { return input_defs_; } - std::vector& output_defs() { return output_defs_; } + paddle::SmallVector& output_defs() { return output_defs_; } - std::vector& attribute_defs() { return attribute_defs_; } + paddle::SmallVector& attribute_defs() { + return attribute_defs_; + } private: - // TODO(chenweihang): replaced by paddle::small_vector - std::vector input_defs_{{}}; - std::vector output_defs_{{}}; - std::vector attribute_defs_{{}}; + paddle::SmallVector input_defs_{{}}; + paddle::SmallVector output_defs_{{}}; + paddle::SmallVector attribute_defs_{{}}; }; class Kernel { @@ -263,10 +269,10 @@ class Kernel { class KernelFactory { public: // replaced by paddle::flat_hash_map later - using KernelMap = - std::unordered_map, - KernelName::Hash>; + using KernelMap = paddle::flat_hash_map< + KernelName, + paddle::flat_hash_map, + KernelName::Hash>; static KernelFactory& Instance();