我正在处理3D点云数据。我拥有点云图结构的稀疏矩阵表示(如scipy.sparse中的csr_matrix)。我想将那些在测地距离(由图中的路径长度近似)阈值内的点聚合在一起,并一起处理它们。为了找到这样的点,我需要运行一些最短路径查找算法,如Dijkstra算法。简而言之,我的想法是这样的
- 从N个点中抽取K个点(我可以使用最远点采样来实现)
- 为每个K个点找到最近的测地邻居(使用支持反向传播的算法)
- 使用某些神经网络处理每个点的邻居
这些将在我的前向函数中进行。有没有办法在我的功能中实现Dijkstra算法?
或者还有其他我可以实现的想法吗?
非常感谢!
回答:
我根据这里讨论的内容,创建了使用优先级队列的Dijkstra自定义实现。为此,我使用torch函数创建了一个自定义的PriorityQ
类,如下所示
class priorityQ_torch(object): """Priority Q implelmentation in PyTorch Args: object ([torch.Tensor]): [The Queue to work on] """ def __init__(self, val): self.q = torch.tensor([[val, 0]]) # self.top = self.q[0] # self.isEmpty = self.q.shape[0] == 0 def push(self, x): """Pushes x to q based on weightvalue in x. Maintains ascending order Args: q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value] x ([torch.Tensor]): [[index, weight] tensor to be inserted] Returns: [torch.Tensor]: [The queue tensor after correct insertion] """ if type(x) == np.ndarray: x = torch.tensor(x) if self.isEmpty(): self.q = x self.q = torch.unsqueeze(self.q, dim=0) return idx = torch.searchsorted(self.q.T[1], x[1]) print(idx) self.q = torch.vstack([self.q[0:idx], x, self.q[idx:]]).contiguous() def top(self): """Returns the top element from the queue Returns: [torch.Tensor]: [top element] """ return self.q[0] def pop(self): """pops(without return) the highest priority element with the minimum weight Args: q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value] Returns: [torch.Tensor]: [highest priority element] """ if self.isEmpty(): print("Can Not Pop") self.q = self.q[1:] def isEmpty(self): """Checks is the priority queue is empty Args: q ([torch.Tensor]): [The tensor queue arranged in ascending order of weight value] Returns: [Bool] : [Returns True is empty] """ return self.q.shape[0] == 0
现在是Dijkstra算法,使用邻接矩阵(以图的权重作为输入)
def dijkstra(adj): n = adj.shape[0] distance_matrix = torch.zeros([n, n]) for i in range(n): u = torch.zeros(n, dtype=torch.bool) d = np.inf * torch.ones(n) d[i] = 0 q = priorityQ_torch(i) while not q.isEmpty(): v, d_v = q.top() # point and distance v = v.int() q.pop() if d_v != d[v]: continue for j, py in enumerate(adj[v]): if py == 0 and j != v: continue else: to = j weight = py if d[v] + py < d[to]: d[to] = d[v] + py q.push(torch.Tensor([to, d[to]])) distance_matrix[i] = d return distance_matrix
返回图点的 shortest path distance 矩阵!