空间转录组 STAGATE
时间:2023-12-22 00:37:01
最近在阅读和复现各位大佬的空间转录组论文,在此记录下来,方便交流学习。如有错误,欢迎指出。
前言
首先是STAGATE,这是中国科学院团队提出、发表在NC(Nature Communications)上的方法。其主要思路与空间转录组分析的一般思路类似:提取基因表达、空间信息和图像特征,然后聚类以识别每个spot的类型。STAGATE没有使用图像信息,但在论文发表时已经取得了最好的结果。
总体架构
总体结构如下。
总体来说,模型是一个四层的自编码器(AutoEncoder):两层编码器、两层解码器,只是每一层都换成了GAT。输入基因表达数据X,重构得到X',损失函数自然是X与X'之间的MSE。值得注意的是,第二层与第三层、第一层与第四层分别共享一组权重W(解码层使用对应编码层权重的转置),图中已经标明了这种对应关系。对于spot级数据,上述结构就是完整的模型;对于单细胞级数据,还会额外构建SNN,得到一个新的邻接矩阵:每层GAT分别在新、旧两个邻接矩阵上计算,再把两个结果加权求和,作为下一层的输入。
代码
作者最初发布的是TensorFlow 1的代码,今年3月又公布了PyTorch版本,但PyTorch版本没有实现SNN,细节上也与TensorFlow版本略有不同。例如损失函数:TensorFlow版本除了MSE,还加入了权重损失以防止过拟合,具体见代码。下面我根据PyTorch版本谈谈我对这篇论文的理解。(最好在Linux系统上运行,Windows上总会出现各种奇怪的错误。)
首先是数据预处理,包括:读取数据(按论文说明下载即可)、选择高变基因(highly variable genes)、按总量归一化(normalize_total)以及取对数(log1p)。然后读取真实标签,用于最终的评估和可视化。
# Load one 10x Visium section and its manual annotations.
input_dir = os.path.join('Data', section_id)
adata = sc.read_visium(path=input_dir,
                       count_file=section_id + '_filtered_feature_bc_matrix.h5')
adata.var_names_make_unique()
# Normalization. The seurat_v3 flavor of highly_variable_genes expects raw
# counts, so HVG selection runs before normalize_total / log1p.
sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Ground-truth labels, used for evaluation and plotting.
Ann_df = pd.read_csv(os.path.join('Data', section_id, "cluster_labels_" + section_id + '.csv'),
                     sep=',', header=0, index_col=0)
adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']
plt.rcParams["figure.figsize"] = (3, 3)
sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])
然后计算spot之间的距离:对距离大于0且不超过150(rad_cutoff)的spot对建立连接,即认为它们在此范围内相邻,邻接矩阵中记为1,否则为0。下面的函数找出满足距离范围的spot对,并将其距离保存在adata.uns['Spatial_Net']中。
def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True):
    """\
    Construct the spatial neighbor networks.

    Parameters
    ----------
    adata
        AnnData object of scanpy package. Spot coordinates are read from
        ``adata.obsm['spatial']``.
    rad_cutoff
        radius cutoff when model='Radius' (required for that model).
    k_cutoff
        The number of nearest neighbors when model='KNN' (required for that model).
    model
        The network construction model. When model=='Radius', the spot is connected
        to spots whose distance is at most rad_cutoff (the boundary is included).
        When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors.
    verbose
        Whether to print progress and edge statistics.

    Returns
    -------
    None. The spatial network is saved in ``adata.uns['Spatial_Net']`` as a
    DataFrame with columns ``Cell1`` and ``Cell2`` (observation names) and
    ``Distance``. Pairs at distance 0 (each spot with itself, and spots with
    identical coordinates) are dropped.

    Raises
    ------
    ValueError
        If the cutoff required by ``model`` is None.
    """
    assert(model in ['Radius', 'KNN'])
    if model == 'Radius' and rad_cutoff is None:
        raise ValueError("rad_cutoff must be given when model='Radius'")
    if model == 'KNN' and k_cutoff is None:
        raise ValueError("k_cutoff must be given when model='KNN'")
    if verbose:
        print('------Calculating spatial graph...')
    coor = pd.DataFrame(adata.obsm['spatial'])
    coor.index = adata.obs.index
    coor.columns = ['imagerow', 'imagecol']

    if model == 'Radius':
        nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)
        distances, indices = nbrs.radius_neighbors(coor, return_distance=True)
    else:
        # One extra neighbor: each query spot is returned as its own neighbor
        # at distance 0, and that self-pair is filtered out below.
        nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff + 1).fit(coor)
        distances, indices = nbrs.kneighbors(coor)

    # Flatten the per-spot neighbor lists into one edge table in a single pass.
    # radius_neighbors returns ragged per-spot arrays; kneighbors returns a 2-D
    # array. np.concatenate flattens both forms row by row.
    n_per_spot = [len(idx) for idx in indices]
    KNN_df = pd.DataFrame({
        'Cell1': np.repeat(np.arange(len(indices)), n_per_spot),
        'Cell2': np.concatenate(indices),
        'Distance': np.concatenate(distances),
    })

    Spatial_Net = KNN_df.loc[KNN_df['Distance'] > 0].copy()
    # Translate positional indices back to observation names.
    cell_names = np.asarray(coor.index)
    Spatial_Net['Cell1'] = cell_names[Spatial_Net['Cell1'].to_numpy()]
    Spatial_Net['Cell2'] = cell_names[Spatial_Net['Cell2'].to_numpy()]
    if verbose:
        print('The graph contains %d edges, %d cells.' % (Spatial_Net.shape[0], adata.n_obs))
        print('%.4f neighbors per cell on average.' % (Spatial_Net.shape[0] / adata.n_obs))

    adata.uns['Spatial_Net'] = Spatial_Net
然后进行可视化:统计每个spot的邻居数量分布,以及平均每个spot有多少个邻居。
def Stats_Spatial_Net(adata):
    """Plot how many spatial neighbors each cell/spot has.

    Reads the edge list in ``adata.uns['Spatial_Net']`` and draws a bar chart
    of the fraction of cells having each neighbor count. The title shows the
    mean number of edges per cell.

    Note: cells with no neighbor never appear in ``Cell1``, so they get no bar
    and the bars may sum to less than 1.
    """
    import matplotlib.pyplot as plt

    cell1 = adata.uns['Spatial_Net']['Cell1']
    Num_edge = cell1.shape[0]
    Mean_edge = Num_edge / adata.shape[0]
    # Neighbors per cell, then the number of cells for each neighbor count.
    # Series.value_counts replaces the deprecated top-level pd.value_counts.
    plot_df = cell1.value_counts().value_counts()
    plot_df = plot_df / adata.shape[0]
    fig, ax = plt.subplots(figsize=[3, 2])
    plt.ylabel('Percentage')
    plt.xlabel('')
    plt.title('Number of Neighbors (Mean=%.2f)' % Mean_edge)
    ax.bar(plot_df.index, plot_df)
下面正式进入STAGATE的训练阶段。
首先是数据准备,包括两部分:根据选定的邻居构建的邻接矩阵,以及基因表达数据。
def Transfer_pytorch_Data(adata):
    """Convert an AnnData object and its spatial network into a PyG ``Data`` object.

    Parameters
    ----------
    adata
        AnnData object whose ``uns['Spatial_Net']`` holds an edge list with
        columns ``Cell1`` / ``Cell2`` containing observation names.

    Returns
    -------
    ``Data`` with:
    - ``edge_index``: a 2 x n_edges LongTensor of row positions, with a self-loop
      added for every observation;
    - ``x``: a FloatTensor built from ``adata.X``.
    """
    G_df = adata.uns['Spatial_Net'].copy()
    # Map observation names to their row positions in adata.
    cells_id_tran = {cell: i for i, cell in enumerate(adata.obs_names)}
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)

    G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])),
                      shape=(adata.n_obs, adata.n_obs))
    # Add self-loops explicitly; the model's GATConv layers are created with
    # add_self_loops=False.
    G = G + sp.eye(G.shape[0])
    rows, cols = G.nonzero()
    edge_index = torch.LongTensor(np.array([rows, cols]))

    # adata.X may be a dense array or a SciPy sparse matrix.
    X = adata.X.toarray() if sp.issparse(adata.X) else np.asarray(adata.X)
    return Data(edge_index=edge_index, x=torch.FloatTensor(X))
然后构建STAGATE模型。如前所述,模型由四层GAT组成,其中h2是最终的特征向量(embedding),h4是重建的基因表达数据。
class STAGATE(torch.nn.Module):
    """Four-layer graph-attention autoencoder: two GAT encoder layers and two GAT decoder layers.

    The decoder layers reuse the transposed weights of the encoder layers
    (conv3 from conv2, conv4 from conv1). conv3 is also given the attention
    coefficients computed by conv1.

    NOTE(review): ``GATConv`` here is the project's own implementation, not the
    stock torch_geometric one. It accepts ``attention=`` / ``tied_attention=``
    keyword arguments, exposes ``lin_src`` / ``lin_dst`` as tensors, and provides
    ``attentions`` after a forward pass. Confirm against the STAGATE_pyG source.

    Args:
        hidden_dims: ``[in_dim, num_hidden, out_dim]``, i.e. the input feature
            size, the hidden layer size and the embedding size.
    """

    def __init__(self, hidden_dims):
        super(STAGATE, self).__init__()

        [in_dim, num_hidden, out_dim] = hidden_dims
        # Encoder: in_dim -> num_hidden -> out_dim (single head, no bias, no dropout).
        self.conv1 = GATConv(in_dim, num_hidden, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        # Decoder: out_dim -> num_hidden -> in_dim (mirror of the encoder).
        self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)

    def forward(self, features, edge_index):
        """Encode ``features`` over the graph given by ``edge_index`` and reconstruct them.

        Returns:
            (h2, h4): the out_dim-sized embedding from conv2, and the in_dim-sized
            reconstruction from conv4.
        """
        h1 = F.elu(self.conv1(features, edge_index))
        # NOTE(review): attention=False presumably disables attention weighting
        # in this layer; confirm in the custom GATConv.
        h2 = self.conv2(h1, edge_index, attention=False)
        # Tie the decoder weights to the transposed encoder weights. They are
        # re-assigned on every forward pass, so the decoder always uses the
        # encoder's current weights.
        # NOTE(review): assigning .data to a transposed view presumably makes each
        # decoder weight share storage with its encoder counterpart; confirm how
        # optimizer updates interact with this.
        self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1)
        self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1)
        self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1)
        self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1)
        # conv3 reuses the attention coefficients from conv1 (read after conv1 ran above).
        h3 = F.elu(self.conv3(h2, edge_index, attention=True,
                              tied_attention=self.conv1.attentions))
        h4 = self.conv4(h3, edge_index, attention=False)

        return h2, h4
具体的GAT代码就不放了,详见论文《Graph Attention Networks》。
训练代码如下,不同之处在于加入了梯度裁剪(gradient clipping)。最后返回h2(即z),也就是用于下一步聚类分析的特征向量,并将其保存到adata中。
# Train the autoencoder by minimizing the reconstruction MSE, with gradient clipping.
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
loss_list = []
for epoch in tqdm(range(1, n_epochs + 1)):
    model.train()
    optimizer.zero_grad()
    _, out = model(data.x, data.edge_index)
    loss = F.mse_loss(data.x, out)
    # Store a plain float. Appending the tensor itself would keep a reference
    # to every epoch's autograd graph for the whole run.
    loss_list.append(loss.item())
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
    optimizer.step()

# Final inference only; no autograd graph is needed.
model.eval()
with torch.no_grad():
    z, out = model(data.x, data.edge_index)

# z (the encoder output h2) is the embedding used for downstream clustering.
STAGATE_rep = z.to('cpu').numpy()
adata.obsm[key_added] = STAGATE_rep

if save_loss:
    # Final-epoch loss as a float (a graph-attached tensor is not serializable).
    adata.uns['STAGATE_loss'] = loss.item()
if save_reconstrction:
    ReX = out.to('cpu').numpy()
    # Clamp negative reconstructed values to 0.
    ReX[ReX < 0] = 0
    adata.layers['STAGATE_ReX'] = ReX
最后调用了R中的mclust包进行聚类。
def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
    """\
    Clustering using the mclust algorithm.
    The parameters are the same as those in the R package mclust.

    Parameters
    ----------
    adata
        AnnData object. The matrix ``adata.obsm[used_obsm]`` is clustered.
    num_cluster
        Number of mixture components, passed to ``Mclust`` as its second argument.
    modelNames
        mclust model name, passed to ``Mclust`` as its third argument.
    used_obsm
        Key in ``adata.obsm`` holding the embedding to cluster.
    random_seed
        Seed for both NumPy and R (``set.seed``).

    Returns
    -------
    ``adata``, with the cluster labels stored in ``adata.obs['mclust']`` as an
    integer-valued categorical.

    Notes
    -----
    Requires rpy2 and the R package mclust. As a side effect, this call
    activates rpy2's numpy2ri conversion globally.
    """
    np.random.seed(random_seed)
    import rpy2.robjects as robjects
    robjects.r.library("mclust")

    import rpy2.robjects.numpy2ri
    rpy2.robjects.numpy2ri.activate()
    r_random_seed = robjects.r['set.seed']
    r_random_seed(random_seed)
    rmclust = robjects.r['Mclust']

    res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames)
    # Fetch the labels by name rather than by their position in the result list.
    mclust_res = np.array(res.rx2('classification'))

    adata.obs['mclust'] = mclust_res
    adata.obs['mclust'] = adata.obs['mclust'].astype('int')
    adata.obs['mclust'] = adata.obs['mclust'].astype('category')
    return adata
去掉缺失值并计算ARI。TensorFlow版本以及后续的数据分析,等我看明白了再来记录。最后附上在DLPFC数据集上测试的主函数。所有代码、数据和论文都可以在GitHub上下载,欢迎交流。
import warnings
warnings.filterwarnings("ignore")
import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
import sys
from sklearn.metrics.cluster import adjusted_rand_score
# import sklearn
import STAGATE_pyG as STAGATE
# R installation used by rpy2 for mclust. This machine-specific path is only
# applied when R_HOME is not already set, so the environment can override it.
os.environ.setdefault('R_HOME', '/home/admin/anaconda3/envs/lib/R')
dataset = ["151507", "151508", "151509", "151510", "151669", "151670", "151671", "151672", "151673", "151674", "151675",
           "151676"]
# Number of mclust clusters for each section in `dataset`
# (presumably the number of annotated layers; confirm against the labels).
n_clusters = [7, 7, 7, 7, 5, 5, 5, 5, 7, 7, 7, 7]
ARIlist = []
for section_id, k in zip(dataset, n_clusters):
    print(section_id, k)
    input_dir = os.path.join('Data', section_id)
    adata = sc.read_visium(path=input_dir, count_file=section_id + '_filtered_feature_bc_matrix.h5')
    adata.var_names_make_unique()
    # Normalization. seurat_v3 HVG selection expects raw counts, so it runs first.
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    Ann_df = pd.read_csv(os.path.join('Data', section_id, "cluster_labels_" + section_id + '.csv'),
                         sep=',', header=0, index_col=0)
    adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']
    plt.rcParams["figure.figsize"] = (3, 3)
    sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])
    STAGATE.Cal_Spatial_Net(adata, rad_cutoff=150)
    STAGATE.Stats_Spatial_Net(adata)
    adata = STAGATE.train_STAGATE(adata)
    sc.pp.neighbors(adata, use_rep='STAGATE')
    sc.tl.umap(adata)
    adata = STAGATE.mclust_R(adata, used_obsm='STAGATE', num_cluster=k)
    # Score only spots that have both a manual label and a cluster; a NaN in an
    # unrelated obs column must not drop a spot.
    obs_df = adata.obs.dropna(subset=['mclust', 'ground_truth'])
    ARI = adjusted_rand_score(obs_df['mclust'], obs_df['ground_truth'])
    ARIlist.append(ARI)
    print('Adjusted rand index = %.2f' % ARI)
    # Release this section's figures so they do not pile up across the loop.
    plt.close('all')
print("ari mean", np.mean(ARIlist))
print("ari median", np.median(ARIlist))