码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow源码解析之common_runtime-executor-下

时间:2018-09-02 02:11:10      阅读:250      评论:0      收藏:0      [点我收藏+]

标签:bool   cer   gops   dex   简单   键值   state   ted   other   

目录

  1. 核心概念
  2. executor.h
    1. Executor
    2. NewLocalExecutor
    3. ExecutorBarrier
  3. executor.cc
    1. structs
    2. GraphView
    3. ExecutorImpl
    4. ExecutorState
    5. details

3.4 ExecutorState

在执行器的执行图计算的时候,需要一个结构来保存当前计算的即时信息,TF为此设计了类ExecutorState,它被用来保存每一个对ExecutorImpl::Run调用的状态信息。它会在一个节点已经准备好之后调度这个节点,并且保存每个节点尚未完成的输入信息。
下面让我们先来看一下这个类的结构:

class ExecutorState {
  public:
    ExecutorState(const Executor::Args& args, ExecutorImpl* impl);
    void RunAsync(Executor::DoneCallback done);
  private:
    DeviceContextMap device_context_map_;
    
    typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
    typedef gtl::InlinedVector<Entry, 4> EntryVector;
    
    const bool vlog_;
    const bool log_memory_;
    int64 step_id_;
    
    //未拥有
    Rendezvous* rendezvous;
    SessionState* session_state_;
    TensorStore* tensor_store_;
    //每个执行步级别的容器
    ScopedStepContainer* step_container_;
    StepStatesCollector* stats_collector_;
    
    checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
    FunctionCallFrame* call_frame;
    const ExecutorImpl* impl_;
    CancellationManager* cancellation_manager_;
    Executor::Args::Runner runner_;
    bool sync_on_finish_;
    
    //拥有
    bool dumped_on_error_ = false;
    //当前执行步骤开始的帧
    FrameState* root_frame_;
    //执行器结束时需要调用的回调函数
    Executor::DoneCallback done_cb_;
    std::atomic_int_fast32_t num_outstanding_ops_;
    mutex mu_;
    Status status_ GUARDED_BY(mu_);
    
    //从帧名称到实际帧的映射。在当前帧的某个迭代周期内,可能会产生一个新的帧。新的子帧的唯一键值必须由父帧的名称、迭代编号、以及由nodedef推断出来的新帧的名称组合而成
    gtl::FlatMap<string, FrameState*> outstanding_frames_ GUARDED_BY(mu_);
    
    //一个帧的名称
    inline string MakeFrameName(FrameState* frame, int64 iter_id, const string& name);
    
    //找到一个现存的帧,或者创建一个新帧,在帧frame的iter迭代周期
    void FindOrCreateChildFrame(FrameState* frame, int64 iter, const Node* node, FrameState** child);
    
    //删除一个帧,当帧调用结束时使用
    void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready);
    
    //清除那些起源于帧frame和迭代iter的帧,当一个子帧结束时调用
    void CleanupFramesIterations(FrameState* frame, int64 iter, TaggedNodeSeq* ready);
    
    //在当前的线程中处理一个已准备好的节点
    void Process(TaggedNode node, int64 scheduled_usec);
    
    //在调用item->kernel之前,先填入其输入
    Status PrepareInputs(const NodeItem& item, Entry* first_input, TensorValueVec* inputs, DeviceContextVec* input_device_contexts, AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead);
    
    //在item->kernel计算结束之后,处理输出
    Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, NodeExecStats* stats);
    
    //在处理完输出之后,将输入传递给下一个输入
    void PropagateOutputs(const TaggedNode& tagged_node, const NodeItem* item, EntryVector* outputs, TaggedNodeSeq* ready);
    
    //节点计算结束后,接管stats,如果执行完成则返回true
    bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, NodeExecStats* stats, TaggedNodeReadyQueue* inline_ready);
    
    //调度ready中的所有复杂节点,然后将ready中的非复杂节点放入inline_ready
    void ScheduleReady(const TaggedNodeSeq& ready, TaggedNodeReadyQueue* inline_ready);
    
    //仅用作调试或记录
    inline void MaybeMarkCompleted(FrameState* frame, int64 iter, int64 id);
    
    //输出一个未完成或者活跃节点的信息
    void DumpPendingNodeState(const int node_id, const Entry* input_vector, bool show_nodes_with_no_ready_inputs);
    void DumpActiveNodeState(const FrameState* frame, IterationState* iteration);
    
    //提供执行器的状态信息
    void DumpState();
    const Tensor* GetTensorValueForDump(const Entry& input);
    
    //当执行器结束执行时,清理
    void Finish();
};

从API上来看,ExecutorState几乎担当了执行器的职责,从后面的介绍也可以看出,实际上确实如此。执行器内部实际调用的就是ExecutorState内部的API。从类的结构中,我们还是看到了许多未曾相识的结构,下面我们先一一分析这些类的意义和结构。

首先来看Entry,Entry要么是一个张量指针,要么是一个张量值,为计算图中的节点的输入或输出提供了一种统一的类型。

struct Entry {
    Entry(const Entry& other);
    Entry& operator=(const Entry& other);
    
    void ClearVal();//清除val字段
    ManualConstructor<Tensor> val;//一个张量的值,如果val_filed_is_set是true的话
    Tensor* ref = nullptr;//一个张量引用
    mutext* ref_mu = nullptr;//为上述张量引用的互斥量
    bool has_value = false;//值是否存在,不论是val或者ref
    bool val_filed_is_set = false;//val字段是否被设置
    
    AllocatorAttributes alloc_attr;//为当前的张量分配内存的内存分配器的属性
    
    DeviceContext* device_context = nullptr;//包含了关于这个张量如何创建的设备相关的信息
};

接下来看看IterationState,它代表了一轮迭代的状态。

struct IterationState {
  public:
    //一轮迭代的状态,每个迭代轮次都由一个单独的拷贝。对于第k轮迭代,第i个节点的第j个输入在input_tensors[k][impl_->nodes[i].input_start+j]。注意,没有必要对input_tensors做互斥锁,其中的内容只会被边的前一个节点写入,被边的后一个节点擦除,而每条边的前后两个节点是不可能同时运行的
    Entry* input_tensors;
    
    //每一轮迭代中未完成的op数量
    size_t outstanding_ops;
    
    //每一轮迭代中未完成的帧数量
    int outstanding_frame_count;
    int pending(PendingCounts::Handle h);
    int decrement_pending(PendingCounts::Handle int v);
    
    //标记一个merge节点为live
    void mark_live(PendingCounts::Handle h);
    //标记一个节点为处理开始
    void mark_started(PendingCounts::Handle h);
    //标记一个节点为处理结束
    void mark_completed(PendingCounts::Handle h);
    //获取节点状态
    PendingCounts::NodeState node_state(PendingCounts::Handle h);
    int dead_count(PendingCounts::Handle h);
    void increment_dead_count(PendingCounts::Handle h);
    void adjust_for_activation(PendingCounts::Handle h, bool increment_dead, int* pending_result, int* dead_result);
  
  private:
    PendingCounts counts_;
};

接下来是FrameState,代表了一个帧的状态。对于帧和迭代轮次,有以下几点需要说明:

  • 对于计算图中的循环来说,每个循环都需要创建一个新的帧。执行从第0个迭代开始。当第0个迭代的某个数值通过了一个NextIteration节点时,第1轮迭代就被创建并开始运行了。注意这时第0轮迭代可能仍在进行,所以多轮迭代可能会同时在运行。帧保持了多种数据结构来保存每轮迭代的状态。当第0轮迭代结束后,我们对其对应的状态进行垃圾回收。
  • 一个帧,当它的所有输入都已经被传入,所有的迭代都被计算完成时,这个帧就被认为是完成了,可以被进行垃圾回收了。
  • 一个帧保存了其中每一轮迭代的状态。如果以下三个条件都被满足,那么第i轮迭代就会被认为是已经完成了,第一,第i轮迭代已经没有未完成的节点了,第二,所有该轮的接收操作都已经完成了,第三,第i-1轮已完成。对于第0轮迭代,当帧的所有输入都已完成,我们就认为它已经结束了。
  • 帧和迭代轮次在结束后,都会进行垃圾回收。我们需要保存的状态量,跟调度器允许的并行度高度相关。我们希望调度器能够动态的控制未完成的并行帧和迭代的数量。为了减少内存消耗,调度器可能需要优先调度内层的帧和较低的迭代轮次。
  • 帧的状态一般总是在需要的时候才会被初始化,因此我们没有引入额外的损耗。

下面我们来具体看下FrameState的结构:

struct FrameState {
    const ExecutorImpl* executor = nullptr;//帧所在的执行器
    string frame_name;//当前帧的名称,是父帧,迭代轮次,和frame_name字段拼合起来得到的
    uint64 frame_id;//当前帧的唯一标识
    int64 parent_iter = -1;//父帧的迭代轮次,frame_name和parent_iter共同标识了当前的FrameState
    FrameState* parent_frame = nullptr;//父帧的FrameState
    const int 
    
    max_parallel_iterations;//最大允许的并行迭代数量
    int num_pending_inputs = 0;//当前帧仍然在等待的输入数量
    int64 iteration_count GUARDED_BY(mu) = 0;//当前帧中到达过的最大的迭代数量
    int num_outstanding_iterations GUARDED_BY(mu) = 1;//未完成的迭代数量
    
    gtl::InlinedVecotr<IterationState*,12> iterations;//当前帧活跃的迭代状态
    std::vector<std::pair<const Node*, Entry>> next_iter_roots GUARDED_BY(mu);
    std::vector<std::pair<const Node*, Entry>> inv_values GUARDED_BY(mu);
    std::vector<const Node*> dead_exits GUARDED_BY(mu);
    
    //属于当前帧的静态信息
    PendingCounts* pending_counts = nullptr;
    int total_input_tensors = 0;
    std::vector<const Node*>* nodes = nullptr;
    
    void InitializeFrameInfo(const string& enter_name);
    inline IterationState* GetInteration(int64 iter);
    inline void SetIteration(int64 iter, IterationState* state);
    
    //减少未完成的操作数量,清理帧中的迭代信息。如果帧执行结束则返回true
    inline bool DecrementOutputstandingOps(const GraphView* gview, int64 iter, TaggedNodeSeq* ready);
    inline bool DecrementOutstandingOpsLocked(const GraphView* gview, int64 iter, TaggedNodeSeq* ready);
    
    //如果帧中的计算都已经完成则返回true
    inline bool IsFrameDone();
    //如果迭代的计算已经结束则返回true
    bool IsIterationDone(int64 iter);
    //增加迭代的编号,如果是一个新迭代,就初始化它
    void IncrementIteration(const GraphView* gview, TaggedNodeSeq* ready);
    //激活一个新的迭代轮次中所有的NextIteration节点
    void ActivateNexts(const GraphView* gview, int64 iter, TaggedNodeSeq* ready);
    void ActivateLoopInvs(const GraphView* gview, int64 iter, TaggedNodeSeq* ready);
    void AddLoopInv(const NodeItem* item, const Entry& value, TaggedNodeSeq* ready);
    void ActivateNodes(const NodeItem* item, const bool is_dead, int64 iter, EntryVector* outputs, TaggedNodeSeq* ready);
    bool CleanupIterations(const GraphView* gview, int64 iter, TaggedNodeSeq* ready);
};

最后让我们来看下最后的两个结构体,TaggedNode和TaggedNodeReadyQueue。其中TaggedNode非常简单,就是一个<frame, iter, node>的结构体,而后者就是前者的一个Queue,用来表示已经准备好的节点的队列。

struct TaggedNode {
    const Node* node = nullptr;
    FrameState* input_frame = nullptr;
    int64 input_iter = -1;
    bool is_dead = false;
    
    TaggedNode(const Node* t_node, FrameState* in_frame, int64 in_iter, bool dead);
};
class TaggedNodeReadyQueue {
  public:
    void push_back(TaggedNode node);
    void pop_front();
    bool empty();
    const TaggedNode* begin();
    const TaggedNode* end();
  private:
    gtl::InlinedVector<TaggedNode, 16> ready_;
    int front_index_;
};

关于TaggedNodeReadyQueue,我们要说明一下,本来这里很自然的可以使用std::deque

3.5 details

终于快要接近终点了。在前文中我们讲了那么多结构,最终计算图的执行过程究竟是怎样的,我们仍然不得而知。因为具体的实现细节都隐藏在函数的实现中,而我们上文中全部都在探讨接口。现在我们就来看下,具体的实现方法。
首先,执行器的入口是Run函数,先来看下ExecutorImpl中的Run函数是如何实现的吧。

void ExecutorImpl::RunAsync(const Args& args, DoneCallback done){
    (new ExecutorState(args,this))->RunAsync(std::move(done));
}

这验证了我们上文中提到的,ExecutorImpl仍然只是一个接口,真正的执行是被推到ExecutorState类中完成的。在上述函数中,我们首先定义了一个ExecutorState对象,然后调用了它的RunAsync函数。在构造函数中,首先初始化了root_frame和iteration 0,我们具体看看RunAsync是如何实现的:

void ExecutorState::RunAsync(Executor::DoneCallback done){
    const Graph* graph = impl_->graph_;//获取计算图指针
    TaggedNodeSeq ready;//构建ready节点序列
    
    //让设备填充设备上下文映射
    Device* device = impl_->params_.device;
    Status fill_status = device->FillContextMap(graph, &device_context_map_);
    if(!fill_status.ok()){
        done(fill_status);
        return;
    }
    
    //初始化ready队列
    for(const Node* n : impl_->root_nodes){
        DCHECK_EQ(n->in_edges().size(),0);
        ready.push_back(TaggedNode{n,root_frame_,0,false});
    }
    if(ready.empty()){
        done(Status::OK());
    } else {
        num_outstanding_ops = ready.size();
        root_frame_->iterations[0]->outstanding_ops = ready.size();
        done_cb_ = std::move(done);
        ScheduleReady(ready,nullptr);
    }
}

可见,主要做了两件事,第一是初始化了ready queue,第二是启动了ScheduleReady函数。
下面我们再来看一下SheduleReady函数的运行机制:

tensorflow源码解析之common_runtime-executor-下

标签:bool   cer   gops   dex   简单   键值   state   ted   other   

原文地址:https://www.cnblogs.com/jicanghai/p/9572217.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!