ecsimsw
TF_source ) matmul_op.cc 본문 (post body)
/// matmul_op.cc
#endif // GOOGLE_CUDA
template
class MatMulOp : public OpKernel {
public:
explicit MatMulOp(OpKernelConstruction* ctx)
: OpKernel(ctx), algorithms_set_already_(false) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
LaunchMatMul<Device, T, USE_CUBLAS>::GetBlasGemmAlgorithm(
ctx, &algorithms_, &algorithms_set_already_);
use_autotune_ = MatmulAutotuneEnable();
}
void Compute(OpKernelContext* ctx) override {
const Tensor& a = ctx->input(0);
const Tensor& b = ctx->input(1);
// Check that the dimensions of the two matrices are valid.
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a.shape()),
errors::InvalidArgument("In[0] is not a matrix"));
OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b.shape()),
errors::InvalidArgument("In[1] is not a matrix"));
Eigen::array<Eigen::IndexPair, 1> dim_pair;
dim_pair[0].first = transpose_a_ ? 0 : 1;
dim_pair[0].second = transpose_b_ ? 1 : 0;
OP_REQUIRES(
ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
errors::InvalidArgument(
"Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
", In[1]: ", b.shape().DebugString()));
int a_dim_remaining = 1 - dim_pair[0].first;
int b_dim_remaining = 1 - dim_pair[0].second;
TensorShape out_shape(
{a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
if (out->NumElements() == 0) {
// If a has shape [0, x] or b has shape [x, 0], the output shape
// is a 0-element matrix, so there is nothing to do.
return;
}
if (a.NumElements() == 0 || b.NumElements() == 0) {
// If a has shape [x, 0] and b has shape [0, y], the
// output shape is [x, y] where x and y are non-zero, so we fill
// the output with zeros.
functor::SetZeroFunctor<Device, T> f;
f(ctx->eigen_device(), out->flat());
return;
}
LaunchMatMul<Device, T, USE_CUBLAS>::launch(
ctx, a, b, dim_pair, &algorithms_, use_autotune_, out);
}
private:
std::vector algorithms_;
bool algorithms_set_already_;
bool use_autotune_;
bool transpose_a_;
bool transpose_b_;
};
/// OpKernelConstruction (the type of the constructor's `ctx` argument): excerpt of its private data members
private:
const DeviceType device_type_;
DeviceBase* const device_;
Allocator* allocator_;
const NodeDef* def_;
const OpDef* op_def_;
FunctionLibraryRuntime* flib_;
DataTypeSlice input_types_;
MemoryTypeSlice input_memory_types_;
DataTypeSlice output_types_;
MemoryTypeSlice output_memory_types_;
const int graph_def_version_;
Status* status_;
'Machine Learning > tf_source' 카테고리의 다른 글 (other posts in this category)
TF_source ) 2019-04-25 (0) | 2019.04.25 |
---|---|
TF_source ) kernel_builder (0) | 2019.04.23 |
What's the difference between user registers and kernel registers? (0) | 2019.04.20 |
TF_source ) KernelRegistry (0) | 2019.04.20 |
TF_source) Opkernel (0) | 2019.04.12 |