Skip to content

Commit

Permalink
use flat_hash_map and small_vector in kernel factory
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Oct 15, 2021
1 parent 3f5f789 commit 2309149
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions paddle/tcmpt/core/kernel_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

#include <ostream>
#include <string>
#include <unordered_map>
#include <utility>

#include "paddle/tcmpt/core/backend.h"
#include "paddle/tcmpt/core/dtype.h"
#include "paddle/tcmpt/core/kernel_def.h"
#include "paddle/tcmpt/core/layout.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/enforce.h"
Expand Down Expand Up @@ -209,25 +210,30 @@ class KernelArgsDef {
attribute_defs_.emplace_back(AttributeArgDef(type_index));
}

const std::vector<TensorArgDef>& input_defs() const { return input_defs_; }
const paddle::SmallVector<TensorArgDef>& input_defs() const {
return input_defs_;
}

const std::vector<TensorArgDef>& output_defs() const { return output_defs_; }
const paddle::SmallVector<TensorArgDef>& output_defs() const {
return output_defs_;
}

const std::vector<AttributeArgDef>& attribute_defs() const {
const paddle::SmallVector<AttributeArgDef>& attribute_defs() const {
return attribute_defs_;
}

std::vector<TensorArgDef>& input_defs() { return input_defs_; }
paddle::SmallVector<TensorArgDef>& input_defs() { return input_defs_; }

std::vector<TensorArgDef>& output_defs() { return output_defs_; }
paddle::SmallVector<TensorArgDef>& output_defs() { return output_defs_; }

std::vector<AttributeArgDef>& attribute_defs() { return attribute_defs_; }
paddle::SmallVector<AttributeArgDef>& attribute_defs() {
return attribute_defs_;
}

private:
// TODO(chenweihang): replaced by paddle::small_vector
std::vector<TensorArgDef> input_defs_{{}};
std::vector<TensorArgDef> output_defs_{{}};
std::vector<AttributeArgDef> attribute_defs_{{}};
paddle::SmallVector<TensorArgDef> input_defs_{{}};
paddle::SmallVector<TensorArgDef> output_defs_{{}};
paddle::SmallVector<AttributeArgDef> attribute_defs_{{}};
};

class Kernel {
Expand Down Expand Up @@ -263,10 +269,10 @@ class Kernel {
class KernelFactory {
public:
// replaced by paddle::flat_hash_map later
using KernelMap =
std::unordered_map<KernelName,
std::unordered_map<KernelKey, Kernel, KernelKey::Hash>,
KernelName::Hash>;
using KernelMap = paddle::flat_hash_map<
KernelName,
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>,
KernelName::Hash>;

static KernelFactory& Instance();

Expand Down

0 comments on commit 2309149

Please sign in to comment.