码迷,mamicode.com
首页 > 其他好文 > 详细

GraphSAGE 代码解析 - minibatch.py

时间:2018-11-04 00:26:14      阅读:651      评论:0      收藏:0      [点我收藏+]

标签:改变   batch   ace   nbsp   new   根据   rgs   ati   tin   

class EdgeMinibatchIterator

    """ This minibatch iterator iterates over batches of sampled edges or
    random pairs of co-occuring edges.

    G -- networkx graph
    id2idx -- dict mapping node ids to index in feature tensor
    placeholders -- tensorflow placeholders object
    context_pairs -- if not none, then a list of co-occuring node pairs (from random walks)
    batch_size -- size of the minibatches
    max_degree -- maximum size of the downsampled adjacency lists
    n2v_retrain -- signals that the iterator is being used to add new embeddings to a n2v model
    fixed_n2v -- signals that the iterator is being used to retrain n2v with only existing nodes as context
    """

def __init__(self, G, id2idx, placeholders, context_pairs=None, batch_size=100, max_degree=25,

n2v_retrain=False, fixed_n2v=False, **kwargs) 中具体介绍以下:

1 self.nodes = np.random.permutation(G.nodes())
2 # 函数shuffle与permutation都是对原来的数组进行重新洗牌,即随机打乱原来的元素顺序
3 # shuffle直接在原来的数组上进行操作,改变原来数组的顺序,无返回值
4 # permutation不直接在原来的数组上进行操作,而是返回一个新的打乱顺序的数组,并不改变原来的数组。
1 self.adj, self.deg = self.construct_adj()

这里重点看construct_adj()函数。

 1 def construct_adj(self):
 2         adj = len(self.id2idx) *  3             np.ones((len(self.id2idx) + 1, self.max_degree))
 4         # 该矩阵记录训练数据中各节点的邻居节点的编号
 5         # 采样只取max_degree个邻居节点,采样方法见下
 6         # 同样进行了行数加一操作
 7 
 8         deg = np.zeros((len(self.id2idx),))
 9         # 该矩阵记录了每个节点的度数
10 
11         for nodeid in self.G.nodes():
12             if self.G.node[nodeid][test] or self.G.node[nodeid][val]:
13                 continue
14             neighbors = np.array([self.id2idx[neighbor]
15                                   for neighbor in self.G.neighbors(nodeid)                   
16                                   if (not self.G[nodeid][neighbor][train_removed])])
17             # Graph.neighbors() Return a list of the nodes connected to the node n.
18             # 在选取邻居节点时进行了筛选,对于G.neighbors(nodeid) 点node的邻居,
19             # 只取该node与neighbor相连的边的train_removed = False的neighbor
20             # 也就是只取不是val, test的节点。
21             # neighbors得到了邻居节点编号数列。
22 
23             deg[self.id2idx[nodeid]] = len(neighbors)
24             # deg各位取值为该位对应nodeid的节点的度数,
25             # 也即经过上面筛选后得到的邻居数
26 
27             if len(neighbors) == 0:
28                 continue
29             if len(neighbors) > self.max_degree:
30                 neighbors = np.random.choice(
31                     neighbors, self.max_degree, replace=False)
32             # range: neighbors; size = max_degree; replace: replace the origin matrix or not
33             # np.random.choice为选取size大小的数列
34 
35             elif len(neighbors) < self.max_degree:
36                 neighbors = np.random.choice(
37                     neighbors, self.max_degree, replace=True)
38             # 经过choice随机选取,得到了固定大小max_degree = 25的直接相连的邻居数列
39 
40             adj[self.id2idx[nodeid], :] = neighbors
41            # 把该node的邻居数列,赋值给adj矩阵中对应nodeid位的向量。
42         return adj, deg

 

construct_test_adj()  函数中,与上不同之处在于,可以直接得到邻居而无需根据val/test/train_removed筛选.

1 neighbors = np.array([self.id2idx[neighbor]
2                           for neighbor in self.G.neighbors(nodeid)])

 

GraphSAGE 代码解析 - minibatch.py

标签:改变   batch   ace   nbsp   new   根据   rgs   ati   tin   

原文地址:https://www.cnblogs.com/shiyublog/p/9902423.html

(1)
(1)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!