标签:shared 有用 类构造 enqueue queue public mpi blank 访问
参考: http://www.tensorfly.cn/tfdoc/how_tos/adding_an_op.html
添加新的OP需要3步(下述所有代码在here):
1. 定义 Op 的接口
// 1. 定义 Op 的接口
// REGISTER_OP()向 TensorFlow 系统注册来定义 Op 的接口,该OP就是HorovodAllreduceOp.
// 在注册时, 指定 Op 的名称: REGISTER_OP("HorovodAllreduce")
// 输入(类型和名称): Input("tensor: T")
// 输出(类型和名称): Output("sum: T")
// 和所需要任何 属性的文档说明Doc(R"doc(...)doc");
//
// 该 Op 接受一个 T 类型 tensor 作为输入, T 类型可以是{int32, int64, float32, float64}
// 输出一个 T 类型 tensor sum,sum是在所有的MPI进程中求和
REGISTER_OP("HorovodAllreduce")
.Attr("T: {int32, int64, float32, float64}")
.Input("tensor: T")
.Output("sum: T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->input(0));
return Status::OK();
})
.Doc(R"doc(
Perform an MPI Allreduce on a tensor. All other processes that do a reduction
on a tensor with the same name must have the same dimension for that tensor.
Tensors are reduced with other tensors that have the same node name for the
allreduce.
Arguments
tensor: A tensor to reduce.
Output
sum: A tensor with the same shape as `tensor`, summed across all MPI processes.
)doc");
2. 为 Op 实现 kernel
// 2. 为 Op 实现 kernel。
// 在定义接口之后, 每一个实现称之为一个 "kernel",提供一个或多个 Op 的实现,即可以存在多个 kernel。
// 为这些 kernel 的每一个创建一个对应的类, 继承 AsyncOpKernel, 覆盖 ComputeAsync 方法。
// ComputeAsync 方法提供一个类型为 OpKernelContext* 的参数 context, 用于访问一些有用的信息, 例如输入和输出的 tensor。
class HorovodAllreduceOp : public AsyncOpKernel {
public:
// 防止类构造函数的隐式自动转换,只能显示调用该构造函数
explicit HorovodAllreduceOp(OpKernelConstruction* context)
: AsyncOpKernel(context) {}
// 重写ComputeAsync()方法
void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
OP_REQUIRES_OK_ASYNC(context, ConvertStatus(common::CheckInitialized()),
done);
auto node_name = name();
auto device = GetDeviceID(context);
auto tensor = context->input(0);
Tensor* output;
OP_REQUIRES_OK_ASYNC(
context, context->allocate_output(0, tensor.shape(), &output), done);
// ReadyEvent makes sure input tensor is ready, and output is allocated.
// shared_ptr 是一个标准的共享所有权的智能指针, 允许多个指针指向同一个对象
auto ready_event = std::shared_ptr<common::ReadyEvent>(RecordReadyEvent(context));
// 模板函数 std::make_shared 可以返回一个指定类型的 std::shared_ptr
auto hvd_context = std::make_shared<TFOpContext>(context);
auto hvd_tensor = std::make_shared<TFTensor>(tensor);
auto hvd_output = std::make_shared<TFTensor>(*output);
// 将张量的Allreduce操作OP加入队列,加入谁的队列??
auto enqueue_result = EnqueueTensorAllreduce(
hvd_context, hvd_tensor, hvd_output, ready_event, node_name, device,
[context, done](const common::Status& status) {
context->SetStatus(ConvertStatus(status));
done();
});
OP_REQUIRES_OK_ASYNC(context, ConvertStatus(enqueue_result), done);
}
};
3. 注册OP到 TensorFlow 系统
// 3. 注册OP到 TensorFlow 系统
// 注册时可以指定该 kernel 运行时的多个约束条件. 例如可以指定一个 kernel 在 CPU 上运行, 另一个在 GPU 上运行
REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_CPU),
HorovodAllreduceOp);
// 如果执行了GPU
#if HOROVOD_GPU_ALLREDUCE
REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_GPU),
HorovodAllreduceOp);
#endif
以horovd的HorovodAllreduceOp为例,学习如何在tensorflow上添加一个新的操作OP
标签:shared 有用 类构造 enqueue queue public mpi blank 访问
原文地址:https://www.cnblogs.com/lixiaolun/p/9163431.html