ecsimsw

TF_source ) 2019-04-25 본문

TF_source ) 2019-04-25

JinHwan Kim 2019. 4. 25. 01:56

/// matmul_op.cc 1468 REGISTER_KERNEL_BUILDER

// Selects which CPU registration macro is applied to float:
//  - with INTEL_MKL defined, only the "eigen"-labelled MatMul kernel is
//    registered (REGISTER_CPU_EIGEN);
//  - otherwise REGISTER_CPU is used, which registers the unlabelled kernel and
//    then also invokes REGISTER_CPU_EIGEN.
#if defined(INTEL_MKL)  // math kernel library
TF_CALL_float(REGISTER_CPU_EIGEN);
#else
TF_CALL_float(REGISTER_CPU);
#endif  // defined(INTEL_MKL)

// Registers MatMulOp<CPUDevice, T, false> as a "MatMul" kernel on DEVICE_CPU
// with the label "eigen". The attribute "T" is constrained to the C++ type T
// through the templated KernelDefBuilder::TypeConstraint<T>() overload; the
// explicit <T> is required because T cannot be deduced from the string
// argument.
#define REGISTER_CPU_EIGEN(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Label("eigen"), \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);

// Registers MatMulOp<CPUDevice, T, false> as the unlabelled "MatMul" kernel on
// DEVICE_CPU and then additionally registers the "eigen"-labelled variant via
// REGISTER_CPU_EIGEN(T).
// Annotations inside the macro use /* */ placed before the continuation
// backslash: a backslash must be the very last character on its line for the
// macro definition to continue.
#define REGISTER_CPU(T)                                             \
  REGISTER_KERNEL_BUILDER(                                          \
      /* kernel_builder */                                          \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"),     \
      MatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); \
  REGISTER_CPU_EIGEN(T);

/// op_kernel.h REGISTER_KERNEL_BUILDER

// Entry point for kernel registration. `builder` is the kernel-builder
// expression and the variadic arguments name the kernel class (they may
// contain commas, e.g. template arguments). __COUNTER__ supplies a distinct
// token for every expansion.
#define REGISTER_KERNEL_BUILDER(builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, builder, __VA_ARGS__)

// Extra level of indirection so that `counter` (i.e. __COUNTER__) is fully
// expanded to its numeric value before REGISTER_KERNEL_BUILDER_UNIQ pastes it
// into identifiers with ##.
#define REGISTER_KERNEL_BUILDER_UNIQ_HELPER(counter, builder, ...) \
  REGISTER_KERNEL_BUILDER_UNIQ(counter, builder, __VA_ARGS__)

// Defines, for one kernel, a constexpr "should register" flag and a static
// OpKernelRegistrar object whose constructor performs the registration.
//   ctr            - unique token pasted into both identifiers so that several
//                    registrations in one translation unit do not collide.
//   kernel_builder - builder expression, looked up inside
//                    ::tensorflow::register_kernel; its Build() result is
//                    passed as the KernelDef, or nullptr when the flag is
//                    false.
//   ...            - the kernel class; stringified for the class name and
//                    instantiated with `new` inside the factory lambda.
// The /* */ annotations name the OpKernelRegistrar constructor parameter each
// argument binds to. They sit before the continuation backslash because the
// backslash must be the last character on the line.
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...)        \
  constexpr bool should_register_##ctr##__flag =                      \
      SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__);                        \
  static ::tensorflow::kernel_factory::OpKernelRegistrar              \
      registrar__body__##ctr##__object(                               \
          /* const KernelDef* kernel_def */                           \
          should_register_##ctr##__flag                               \
              ? ::tensorflow::register_kernel::kernel_builder.Build() \
              : nullptr,                                              \
          /* StringPiece kernel_class_name */                         \
          #__VA_ARGS__,                                               \
          /* Factory factory */                                       \
          [](::tensorflow::OpKernelConstruction* context)             \
              -> ::tensorflow::OpKernel* {                            \
            return new __VA_ARGS__(context);                          \
          });

/// op_kernel.h

namespace kernel_factory {

class OpKernelRegistrar {
 public:
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name, Factory factory) {
    if (kernel_def != nullptr) {
      InitInternal(kernel_def, kernel_class_name, factory);
    }
  }

/// op_kernel.cc

 private:
  void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
                    Factory factory);
};}

namespace kernel_factory {
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name, Factory factory) {
  if (kernel_def->op() != "_no_register") {
    const string key = Key(kernel_def->op(), DeviceType(kernel_def->device_type()), kernel_def->label());
    GlobalKernelRegistryTyped()->insert(std::make_pair (key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
  }
  delete kernel_def;
}} 

struct KernelRegistration {
  KernelRegistration(const KernelDef& d, StringPiece c,
                     kernel_factory::OpKernelRegistrar::Factory f)
      : def(d), kernel_class_name(std::string(c)), factory(f) {}
  const KernelDef def;
  const string kernel_class_name;
  const kernel_factory::OpKernelRegistrar::Factory factory;
};

typedef std::unordered_multimap<string, KernelRegistration> KernelRegistry;

void* GlobalKernelRegistry() {
  static KernelRegistry* global_kernel_registry = new KernelRegistry;
  return global_kernel_registry;
}

static KernelRegistry* GlobalKernelRegistryTyped() {
  return reinterpret_cast<KernelRegistry*>(GlobalKernelRegistry());
}

static string Key(StringPiece op_type, const DeviceType& device_type, StringPiece label) {
  return strings::StrCat(op_type, ":", DeviceTypeString(device_type), ":", label);
}

/// CONCLUSION

               KernelRegistration = kernelDef, kernel_class_name, Factory

               Key = optype : DeviceType : label

               unordered_multimap<string, KernelRegistration> KernelRegistry   <- insert( pair(key, KernelRegistration )) 

              -> What is the relationship between the key and KernelRegistration?  // Where is the KernelRegistry actually used?

 Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T").Build();  -> KernelDef

///op_kernel.h

namespace register_kernel {
class Name : public KernelDefBuilder {
 public:
  explicit Name(const char* op)
      : KernelDefBuilder(SHOULD_REGISTER_OP(op) ? op : "_no_register") {}
};

///kernel_def_builder.h

class KernelDefBuilder {
 public:
  explicit KernelDefBuilder(const char* op_name);
  ~KernelDefBuilder();

  KernelDefBuilder& Device(const char* device_type);
  
  KernelDefBuilder& TypeConstraint(const char* attr_name,
                                   gtl::ArraySlice allowed);
                                   
  KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
  
  template<class T> 
  KernelDefBuilder& TypeConstraint(const char* attr_name);

  KernelDefBuilder& HostMemory(const char* arg_name);

  KernelDefBuilder& Label(const char* label);
  
  const KernelDef* Build();
  
 private:
  KernelDef* kernel_def_;
  TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
};

///kernel_def_builder.cc

// Allocates the KernelDef that the setter calls will fill in and records
// `op_name` as its op.
KernelDefBuilder::KernelDefBuilder(const char* op_name)
    : kernel_def_(new KernelDef) {
  kernel_def_->set_op(op_name);
}

// Checks that Build() was called before the builder is destroyed: Build()
// resets kernel_def_ to nullptr, so a non-null pointer here trips the DCHECK
// with the message "Did not call Build()". Note that this destructor does not
// delete kernel_def_ itself.
KernelDefBuilder::~KernelDefBuilder() {
  DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
}

// Sets the device type on the KernelDef under construction and returns *this
// so further builder calls can be chained.
KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
  kernel_def_->set_device_type(device_type);
  return *this;
}

  
// Adds a constraint named `attr_name` to the KernelDef and appends every
// DataType in `allowed` to that constraint's list of allowed values.
// Returns *this for chaining.
KernelDefBuilder& KernelDefBuilder::TypeConstraint(
    const char* attr_name, gtl::ArraySlice<DataType> allowed) {
  auto* constraint = kernel_def_->add_constraint();
  constraint->set_name(attr_name);
  auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
  for (DataType dt : allowed) {
    allowed_values->add_type(dt);
  }
  return *this;
}


// Single-type overload: adds a constraint named `attr_name` whose allowed
// values list contains only `allowed`. Returns *this for chaining.
KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name,
                                                   DataType allowed) {
  auto* type_constraint = kernel_def_->add_constraint();
  type_constraint->set_name(attr_name);
  auto* allowed_list =
      type_constraint->mutable_allowed_values()->mutable_list();
  allowed_list->add_type(allowed);
  return *this;
}

// Appends `arg_name` to the host-memory argument list of the KernelDef under
// construction. Returns *this for chaining.
KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) {
  auto& def = *kernel_def_;
  def.add_host_memory_arg(arg_name);
  return *this;
}

} KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
  CHECK_EQ(kernel_def_->label(), "")
      << "Trying to set a kernel's label a second time: '" << label
      << "' in: " << ProtoShortDebugString(*kernel_def_);
  kernel_def_->set_label(label);
  return *this;
}

const KernelDef* KernelDefBuilder::Build() {
  KernelDef* r = kernel_def_;
  kernel_def_ = nullptr;
  return r;
}

///kernel_def_builder.h KernelDef

// Forward-declare the KernelDef proto so that kernels using this header do
// not need to depend on its full definition.
class KernelDef;                                                                                                                         

'Machine Learning > tf_source' 카테고리의 다른 글

TF_source ) matmul_op.cc  (0) 2019.04.30
TF_source ) kernel_builder  (0) 2019.04.23
What's the difference between user registers and kernel registers?  (0) 2019.04.20
TF_source ) KernelRegistry  (0) 2019.04.20
TF_source) Opkernel  (0) 2019.04.12
Comments