您当前位置:资讯中心 >大数据 >浏览文章

时空图神经网络原理及Pytorch实现

来源:CTO 日期:2024/4/30 14:54:10 阅读量:(0)

在我们生活的这个充满联系的世界中,从微观的分子结构到宏观的社交网络,再到复杂的城市设计结构,都隐藏着一张张相互关联的图数据。这些图数据仿佛一张张神秘的网,将世界万物紧密相连。而图神经网络(GNN)作为一种革命性的技术,正以其强大的能力,逐渐揭开这些图数据的面纱,让我们能够更深入地理解和利用它们。

图神经网络的出现,为我们提供了一种全新的建模和学习方式。它不仅能够捕捉数据的空间结构,还能够揭示图结构中的复杂关系。无论是在生物学领域,如蛋白质结构分析和药物发现,还是在社会学领域,如社交网络模拟和舆情分析,图神经网络都展现出了惊人的应用潜力。

更令人兴奋的是,图神经网络还可以与其他机器学习模型进行融合,形成更加强大的模型。例如,将图神经网络与序列模型结合,形成时空图神经网络(Spatail-Temporal Graph),不仅能够捕捉数据的时间和空间依赖性,还能够更全面地揭示数据的内在规律和趋势。这种融合模型的出现,为各个领域的研究和应用带来了更多的可能性。

在时空图神经网络中,时间维度被巧妙地引入到了图结构中。这意味着,原本静止的节点特征现在会随着时间的推移而发生变化。这种变化不仅反映了节点之间的动态关系,还为我们提供了更丰富的信息,使我们能够更准确地预测和分析各种复杂现象。

不过,GNN模型和序列模型(如简单RNN、LSTM或GRU)本身就复杂。结合这些模型以处理空间和时间依赖性是强大的,但也很复杂:难以理解,也难以实现。

所以在这篇文章中,我们将深入探讨这些模型的原理,并实现一个相对简单的示例,以更深入地理解它们的能力和应用。

图神经网络(GNN)

我们先介绍一些入门的知识简要讨论GNN。

图G可以定义为G = (V, E),其中V是节点集,E是它们之间的边。

一个包含n个节点的图的特征矩阵,每个节点具有f个特征,是所有特征的连接:

GNN的关键问题是所有连接节点之间的消息传递,这种邻居特征转换和聚合可以写成:

A是图的邻接矩阵,I是允许自连接的单位矩阵。虽然这不是完整的方程,但这已经可以说明可以学习不同节点之间空间依赖性的图卷积网络的基础。一个经典的图神经网络如下图所示:

时空图神经网络 (ST-GNN)

ST-GNN中每个时间步都是一个图,并通过GCN/GAT网络传递,以获得嵌入数据空间相互依赖性的结果编码图。然后这些编码图可以像时间序列数据一样进行建模,只要保留每个时间步骤的数据的图结构的完整性。下图演示了这两个步骤,时间模型可以是从ARIMA或简单的循环神经网络或者是transformers的任何序列模型。

我们下面使用简单的循环神经网络来绘制ST-GNN的组件

上面就是ST-GNN的基本原理,将GNN和序列模型(如RNN、LSTM、GRU、Transformers 等)结合。如果你已经熟悉这些序列和GNN模型,那么理论来说是非常简单的,但是实际操作的时候就会有一些复杂,所以我们下面将直接使用Pytorch实现一个简单的ST-GNN。

ST-GNN的Pytorch实现

首先要说明:为了用于演示我将使用大型科技公司的股市数据。虽然这些数据本质上不是图数据,但这种网络可能会捕捉到这些公司之间的相互依赖性,例如一个公司的表现(好或坏)可能反过来影响市场中其他公司的价值。但这只是一个演示,我们并不建议在股市预测中使用ST-GNN。

加载数据,直接使用yfinance里面什么都有

import yfinance as yf
 import datetime as dt
 import pandas as pd
 from sklearn.preprocessing import StandardScaler
 
 import plotly.graph_objs as go
 from plotly.offline import iplot
 import matplotlib.pyplot as plt
 
 ############ Dataset download #################
 start_date = dt.datetime(2013,1,1)
 end_date = dt.datetime(2024,3,7)
 #loading from yahoo finance
 google = yf.download("GOOGL",start_date, end_date)
 apple = yf.download("AAPL",start_date, end_date)
 Microsoft = yf.download("MSFT", start_date, end_date)
 Amazon = yf.download("AMZN", start_date, end_date)
 meta = yf.download("META", start_date, end_date)
 Nvidia = yf.download("NVDA", start_date, end_date)
 data = pd.DataFrame({'google': google['Open'],'microsoft': Microsoft['Open'],'amazon': Amazon['Open'],
                      'Nvidia': Nvidia['Open'],'meta': meta['Open'], 'apple': apple['Open']})
 ############## Scaling data ######################
 scaler = StandardScaler()
 data_scaled = pd.DataFrame(scaler.fit_transform(data), columns=data.columns)
关键字:
声明:我公司网站部分信息和资讯来自于网络,若涉及版权相关问题请致电(63937922)或在线提交留言告知,我们会第一时间屏蔽删除。
有价值
0% (0)
无价值
0% (10)

分享转发:

发表评论请先登录后发表评论。愿您的每句评论,都能给大家的生活添色彩,带来共鸣,带来思索,带来快乐。