锐单电子商城 , 一站式电子元器件采购平台!
  • 电话:400-990-0325

空间转录组 STAGATE

时间:2023-12-22 00:37:01 df37nc连接器

最近,在阅读和复制各大老板的空转论文、记录、交流和学习下,如有错误,欢迎指出。

前言

首先是STAGATE,这是中国科学院提出的提出的方法中NC在上面,主要思维类似于空转的一般思维,提取基因表达、空间信息和图像特征,然后聚类以识别每一个spot类型。STAGATE,没有图像信息,已经是发表论文的最佳结果了。

总体架构

总体结构如下。

一般来说,模型是四层AutoEncode,两层编码器两层解码器,但每层都被取代GAT。输入基因表达数据X进行重构X,损失函数自然是X和X’的MSE。值得注意的是,第二层和第三层,第一层和第四层分别共享一组权重W,图中已经表明了转移关系。如果是spot模型已经完成了所有的等级数据。如果是细胞级数据,将构建SNN,也就是重建新的GAT每层的结果是由新的邻接矩阵和旧的邻接矩阵组成GAT为下一层输入加权求和。

代码

作者最初发布的是tensorflow1代码,今年3月又公布了torch但是torch未构建版本SNN,在细节上与tensorflow也略有不同,如损失函数,tensorflow中除了MSE,还增加了权重损失防止过拟合,具体代码中会提到。下面我试着根据torch让我谈谈我对这篇论文的理解。(最好在linux在系统上运行windows总会有各种奇怪的错误。

首先是数据预处理。包括数据读取,只需根据论文下载数据即可。Normalization,选择高表达基因,正则化,取对数。然后读取最终评估和可视化的真实标签

    input_dir = os.path.join('Data', section_id)     adata = sc.read_visium(path=input_dir, count_file=section_id '_filtered_feature_bc_matrix.h5')     adata.var_names_make_unique()      #Normalization     sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)     sc.pp.normalize_total(adata, target_sum=1e4)     sc.pp.log1p(adata)      Ann_df = pd.read_csv(os.path.join('Data',                                       section_id, "cluster_labels_" section_id '.csv'), sep=',', header=0, index_col=0)     adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']      plt.rcParams["figure.figsize"] = (3, 3)     sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])

然后是spot和spot距离。距离大于0小于150spot构建相邻矩阵,认为在此范围内有连接,相邻矩阵为1,否则为0。以下是符合距离范围的计算spot并保存距离adata.uns['Spatial_Net']中。

def Cal_Spatial_Net(adata, rad_cutoff=None, k_cutoff=None, model='Radius', verbose=True):     """\     Construct the spatial neighbor networks.      Parameters     ----------     adata         AnnData object of scanpy package.     rad_cutoff         radius cutoff when model='Radius'     k_cutoff         The number of nearest neighbors when model='KNN'     model         The network construction model. When model=='Radius', the spot is connected to spots whose distance is less than rad_cutoff. When model=='KNN', the spot is connected to its first k_cutoff nearest neighbors.          Returns     -------     The spatial networks are saved in adata.uns['Spatial_Net']     """      assert(model in ['Radius', 'KNN'])     if verbose:         print('------Calculating spatial graph...')     coor = pd.DataFrame(adata.obsm['spatial'])     coor.index = adata.obs.index     coor.columns = ['imagerow', 'imagecol']      if model == 'Radius':         nbrs = sklearn.neighbors.NearestNeighbors(radius=rad_cutoff).fit(coor)         distances, indices = nbrs.radius_neighbors(coor, return_distance=True)         KNN_list = []         for it in range(indices.shape[0]):             KNN_list.append(pd.DataFrame(zip([it]*indices[it].shape[0], indices[it], distances[it])))          if model == 'KNN':         nbrs = sklearn.neighbors.NearestNeighbors(n_neighbors=k_cutoff 1).fit(coor)         distances, indices = nbrs.kneighbors(coor)         KNN_list = []         for it in range(indices.shape[0]):             KNN_list.append(pd.DataFrame(zip([it]*indices.shape[1],indices[it,:], distances[it,:])))      KNN_df = pd.concat(KNN_list)     KNN_df.columns = ['Cell1', 'Cell2', 'Distance']      Spatial_Net = KNN_df.copy()     Spatial_Net = Spatial_Net.loc[Spatial_Net['Distance']>0,]     id_cell_trans = dict(zip(range(coor.shape[0]), np.array(coor.index), ))     Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)     Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)     if verbose:         print('The graph contains %d edges, %d cells.' %(Spatial_Net.shape[0], adata.n_obs))         print('%.4f neighbors per cell on average.' %(Spatial_Net.shape[0]/adata.n_obs))      adata.uns['Spatial_Net'] = Spatial_Net

然后是可视化,平均每个spot有多少邻居。

def Stats_Spatial_Net(adata):     import matplotlib.pyplot as plt     Num_edge = adata.uns['Spatial_Net']['Cell1'].shape[0]     Mean_edge = Num_edge/adata.shape[0]     plot_df = pd.value_counts(pd.value_counts(adata.uns['Spatial_Net']['Cell1']))     plot_df = plot_df/adata.shape[0]     fig, ax = plt.subplots(figsize=[3,2])     plt.ylabel('Percentage')     plt.xlabel('')     plt.title('Number of Neighbors (Mean=%.2f)'%Mean_edge)     ax.bar(plot_df.index, plot_df)

正式进入下面STAGATE训练阶段到了。

首先是数据准备,包括两部分:根据选定的邻居构建邻接矩阵和基因表达数据。

def Transfer_pytorch_Data(adata):     G_df = adata.uns['Spatial_Net'].copy()     cells = np.array(adata.obs_names)     cells_id_tran = dict(zip(cells, range(cells.shap[0])))
    G_df['Cell1'] = G_df['Cell1'].map(cells_id_tran)
    G_df['Cell2'] = G_df['Cell2'].map(cells_id_tran)

    G = sp.coo_matrix((np.ones(G_df.shape[0]), (G_df['Cell1'], G_df['Cell2'])), shape=(adata.n_obs, adata.n_obs))
    G = G + sp.eye(G.shape[0])

    edgeList = np.nonzero(G)
    if type(adata.X) == np.ndarray:
        data = Data(edge_index=torch.LongTensor(np.array(
            [edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.X))  # .todense()
    else:
        data = Data(edge_index=torch.LongTensor(np.array(
            [edgeList[0], edgeList[1]])), x=torch.FloatTensor(adata.X.todense()))  # .todense()
    return data

然后构建STAGATE模型 正如前边所说四层GAT,其中h2是最后的特征向量,h4是重建的基因表达数据。

class STAGATE(torch.nn.Module):
    def __init__(self, hidden_dims):
        super(STAGATE, self).__init__()

        [in_dim, num_hidden, out_dim] = hidden_dims
        self.conv1 = GATConv(in_dim, num_hidden, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv2 = GATConv(num_hidden, out_dim, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv3 = GATConv(out_dim, num_hidden, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)
        self.conv4 = GATConv(num_hidden, in_dim, heads=1, concat=False,
                             dropout=0, add_self_loops=False, bias=False)

    def forward(self, features, edge_index):

        h1 = F.elu(self.conv1(features, edge_index))
        h2 = self.conv2(h1, edge_index, attention=False)
        self.conv3.lin_src.data = self.conv2.lin_src.transpose(0, 1)
        self.conv3.lin_dst.data = self.conv2.lin_dst.transpose(0, 1)
        self.conv4.lin_src.data = self.conv1.lin_src.transpose(0, 1)
        self.conv4.lin_dst.data = self.conv1.lin_dst.transpose(0, 1)
        h3 = F.elu(self.conv3(h2, edge_index, attention=True,
                              tied_attention=self.conv1.attentions))
        h4 = self.conv4(h3, edge_index, attention=False)

        return h2, h4  # F.log_softmax(x, dim=-1)

具体的GAT代码不放了,详见`"Graph Attention Networks"

具体训练代码如下,不同点是加了梯度截断,最后返回h2,或者说是z,也就是特征向量用于下一步聚类分析,保存到adata中。

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    loss_list = []
    for epoch in tqdm(range(1, n_epochs+1)):
        model.train()
        optimizer.zero_grad()
        z, out = model(data.x, data.edge_index)
        loss = F.mse_loss(data.x, out) #F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss_list.append(loss)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
    
    model.eval()
    z, out = model(data.x, data.edge_index)
    
    STAGATE_rep = z.to('cpu').detach().numpy()
    adata.obsm[key_added] = STAGATE_rep

    if save_loss:
        adata.uns['STAGATE_loss'] = loss
    if save_reconstrction:
        ReX = out.to('cpu').detach().numpy()
        ReX[ReX<0] = 0
        adata.layers['STAGATE_ReX'] = ReX

最后调用了R中的mclust包进行聚类。

def mclust_R(adata, num_cluster, modelNames='EEE', used_obsm='STAGATE', random_seed=2020):
    """\
    Clustering using the mclust algorithm.
    The parameters are the same as those in the R package mclust.
    """
    
    np.random.seed(random_seed)
    import rpy2.robjects as robjects
    robjects.r.library("mclust")

    import rpy2.robjects.numpy2ri
    rpy2.robjects.numpy2ri.activate()
    r_random_seed = robjects.r['set.seed']
    r_random_seed(random_seed)
    rmclust = robjects.r['Mclust']

    res = rmclust(rpy2.robjects.numpy2ri.numpy2rpy(adata.obsm[used_obsm]), num_cluster, modelNames)
    mclust_res = np.array(res[-2])

    adata.obs['mclust'] = mclust_res
    adata.obs['mclust'] = adata.obs['mclust'].astype('int')
    adata.obs['mclust'] = adata.obs['mclust'].astype('category')
    return adata

去掉缺失值并计算ARI。tensorflow版本和后续的数据分析解析等我看明白再来记录,最后附上测试DFPFC数据库的主函数。所有代码、数据和论文可以再github上下载,欢迎交流。

import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import numpy as np
import scanpy as sc
import matplotlib.pyplot as plt
import os
import sys
from sklearn.metrics.cluster import adjusted_rand_score
# import sklearn
import STAGATE_pyG as STAGATE
os.environ['R_HOME'] = '/home/admin/anaconda3/envs/lib/R'
# os.environ['R_USER'] = '/home/admin/Anaconda3\Lib\site-packages/rpy2'



dataset = ["151507", "151508", "151509", "151510", "151669", "151670", "151671", "151672", "151673", "151674", "151675",
           "151676"]
knn = [7, 7, 7, 7, 5, 5, 5, 5, 7, 7, 7, 7]
ARIlist = []
for section_id, k in zip(dataset, knn):
    print(section_id,k)
    input_dir = os.path.join('Data', section_id)
    adata = sc.read_visium(path=input_dir, count_file=section_id+'_filtered_feature_bc_matrix.h5')
    adata.var_names_make_unique()

    #Normalization
    sc.pp.highly_variable_genes(adata, flavor="seurat_v3", n_top_genes=3000)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

    Ann_df = pd.read_csv(os.path.join('Data',
                                      section_id, "cluster_labels_"+section_id+'.csv'), sep=',', header=0, index_col=0)
    adata.obs['ground_truth'] = Ann_df.loc[adata.obs_names, 'ground_truth']

    plt.rcParams["figure.figsize"] = (3, 3)
    sc.pl.spatial(adata, img_key="hires", color=["ground_truth"])

    STAGATE.Cal_Spatial_Net(adata, rad_cutoff=150)
    STAGATE.Stats_Spatial_Net(adata)

    adata = STAGATE.train_STAGATE(adata)

    sc.pp.neighbors(adata, use_rep='STAGATE')
    sc.tl.umap(adata)
    adata = STAGATE.mclust_R(adata, used_obsm='STAGATE', num_cluster=k)

    obs_df = adata.obs.dropna()
    ARI = adjusted_rand_score(obs_df['mclust'], obs_df['ground_truth'])
    ARIlist.append(ARI)
    print('Adjusted rand index = %.2f' %ARI)
print("ari mean", np.mean(ARIlist))
print("ari median", np.median(ARIlist))

锐单商城拥有海量元器件数据手册IC替代型号,打造电子元器件IC百科大全!

相关文章