Investigating Global Attention and Prognosis on Colorectal CODEX data

[11]:
import os
import sys
sys.path.append("../")
device = "cuda"
[12]:
import scanpy as sc
import squidpy as sq
import pandas as pd
from tqdm.notebook import tqdm
import scipy as sp
import numpy as np
from sksurv.nonparametric import kaplan_meier_estimator
[13]:
import torch
[14]:
import steamboat as sf
import steamboat.tools
[15]:
import pickle as pkl
[16]:
import matplotlib.pyplot as plt
import seaborn as sns
[17]:
import matplotlib
plt.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['mathtext.fontset'] = 'dejavuserif'
matplotlib.rcParams['font.family'] = 'arial'
[18]:
import importlib
importlib.reload(steamboat.tools)
[18]:
<module 'steamboat.tools' from 'c:\\Files\\projects\\Steamboat\\examples\\..\\steamboat\\tools.py'>
[19]:
data_path = "E:/codex/"
# fig_path = "C:/Users/lshh/OneDrive/Publications/Steamboat/pub/fig-codex-elements/"
# pltkw = dict(bbox_inches='tight', transparent=True)

Process dataset

You can download the dataset here

[20]:
patient_info = pd.read_excel(data_path + "mmc2.xlsx", sheet_name=0, skipfooter=5)
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
Cell In[20], line 1
----> 1 patient_info = pd.read_excel(data_path + "mmc2.xlsx", sheet_name=0, skipfooter=5)

File c:\Users\lshh\miniconda3\envs\py313_torch291_cuda130\Lib\site-packages\pandas\io\excel\_base.py:495, in read_excel(io, sheet_name, header, names, index_col, usecols, dtype, engine, converters, true_values, false_values, skiprows, nrows, na_values, keep_default_na, na_filter, verbose, parse_dates, date_parser, date_format, thousands, decimal, comment, skipfooter, storage_options, dtype_backend, engine_kwargs)
    493 if not isinstance(io, ExcelFile):
    494     should_close = True
--> 495     io = ExcelFile(
    496         io,
    497         storage_options=storage_options,
    498         engine=engine,
    499         engine_kwargs=engine_kwargs,
    500     )
    501 elif engine and engine != io.engine:
    502     raise ValueError(
    503         "Engine should not be specified when passing "
    504         "an ExcelFile - ExcelFile already has the engine set"
    505     )

File c:\Users\lshh\miniconda3\envs\py313_torch291_cuda130\Lib\site-packages\pandas\io\excel\_base.py:1550, in ExcelFile.__init__(self, path_or_buffer, engine, storage_options, engine_kwargs)
   1548     ext = "xls"
   1549 else:
-> 1550     ext = inspect_excel_format(
   1551         content_or_path=path_or_buffer, storage_options=storage_options
   1552     )
   1553     if ext is None:
   1554         raise ValueError(
   1555             "Excel file format cannot be determined, you must specify "
   1556             "an engine manually."
   1557         )

File c:\Users\lshh\miniconda3\envs\py313_torch291_cuda130\Lib\site-packages\pandas\io\excel\_base.py:1402, in inspect_excel_format(content_or_path, storage_options)
   1399 if isinstance(content_or_path, bytes):
   1400     content_or_path = BytesIO(content_or_path)
-> 1402 with get_handle(
   1403     content_or_path, "rb", storage_options=storage_options, is_text=False
   1404 ) as handle:
   1405     stream = handle.handle
   1406     stream.seek(0)

File c:\Users\lshh\miniconda3\envs\py313_torch291_cuda130\Lib\site-packages\pandas\io\common.py:882, in get_handle(path_or_buf, mode, encoding, compression, memory_map, is_text, errors, storage_options)
    873         handle = open(
    874             handle,
    875             ioargs.mode,
   (...)    878             newline="",
    879         )
    880     else:
    881         # Binary mode
--> 882         handle = open(handle, ioargs.mode)
    883     handles.append(handle)
    885 # Convert BytesIO or file objects passed with an encoding

FileNotFoundError: [Errno 2] No such file or directory: 'E:/codex/mmc2.xlsx'
[ ]:
data = pd.read_csv("E:/codex/CRC_clusters_neighborhoods_markers.csv", index_col=0)
features = [
       'CD44 - stroma:Cyc_2_ch_2', 'FOXP3 - regulatory T cells:Cyc_2_ch_3',
       'CD8 - cytotoxic T cells:Cyc_3_ch_2',
       'p53 - tumor suppressor:Cyc_3_ch_3',
       'GATA3 - Th2 helper T cells:Cyc_3_ch_4',
       'CD45 - hematopoietic cells:Cyc_4_ch_2', 'T-bet - Th1 cells:Cyc_4_ch_3',
       'beta-catenin - Wnt signaling:Cyc_4_ch_4', 'HLA-DR - MHC-II:Cyc_5_ch_2',
       'PD-L1 - checkpoint:Cyc_5_ch_3', 'Ki67 - proliferation:Cyc_5_ch_4',
       'CD45RA - naive T cells:Cyc_6_ch_2', 'CD4 - T helper cells:Cyc_6_ch_3',
       'CD21 - DCs:Cyc_6_ch_4', 'MUC-1 - epithelia:Cyc_7_ch_2',
       'CD30 - costimulator:Cyc_7_ch_3', 'CD2 - T cells:Cyc_7_ch_4',
       'Vimentin - cytoplasm:Cyc_8_ch_2', 'CD20 - B cells:Cyc_8_ch_3',
       'LAG-3 - checkpoint:Cyc_8_ch_4', 'Na-K-ATPase - membranes:Cyc_9_ch_2',
       'CD5 - T cells:Cyc_9_ch_3', 'IDO-1 - metabolism:Cyc_9_ch_4',
       'Cytokeratin - epithelia:Cyc_10_ch_2',
       'CD11b - macrophages:Cyc_10_ch_3', 'CD56 - NK cells:Cyc_10_ch_4',
       'aSMA - smooth muscle:Cyc_11_ch_2', 'BCL-2 - apoptosis:Cyc_11_ch_3',
       'CD25 - IL-2 Ra:Cyc_11_ch_4', 'CD11c - DCs:Cyc_12_ch_3',
       'PD-1 - checkpoint:Cyc_12_ch_4',
       'Granzyme B - cytotoxicity:Cyc_13_ch_2', 'EGFR - signaling:Cyc_13_ch_3',
       'VISTA - costimulator:Cyc_13_ch_4', 'CD15 - granulocytes:Cyc_14_ch_2',
       'ICOS - costimulator:Cyc_14_ch_4',
       'Synaptophysin - neuroendocrine:Cyc_15_ch_3',
       'GFAP - nerves:Cyc_16_ch_2', 'CD7 - T cells:Cyc_16_ch_3',
       'CD3 - T cells:Cyc_16_ch_4',
       'Chromogranin A - neuroendocrine:Cyc_17_ch_2',
       'CD163 - macrophages:Cyc_17_ch_3', 'CD45RO - memory cells:Cyc_18_ch_3',
       'CD68 - macrophages:Cyc_18_ch_4', 'CD31 - vasculature:Cyc_19_ch_3',
       'Podoplanin - lymphatics:Cyc_19_ch_4', 'CD34 - vasculature:Cyc_20_ch_3',
       'CD38 - multifunctional:Cyc_20_ch_4',
       'CD138 - plasma cells:Cyc_21_ch_3',
       'CDX2 - intestinal epithelia:Cyc_2_ch_4',
       'Collagen IV - bas. memb.:Cyc_12_ch_2',
       'CD194 - CCR4 chemokine R:Cyc_14_ch_3',
       'MMP9 - matrix metalloproteinase:Cyc_15_ch_2',
       'CD71 - transferrin R:Cyc_15_ch_4', 'CD57 - NK cells:Cyc_17_ch_4',
       'MMP12 - matrix metalloproteinase:Cyc_21_ch_4']

metadata = data.drop(features, axis=1)
data = data[features]
[ ]:
p2g_dict = {'CD44': 'CD44',
 'FOXP3': 'FOXP3',
 'CD8': 'CD8A',
 'p53': 'TP53',
 'GATA3': 'GATA3',
 'CD45': 'PTPRC',
 'T-bet': 'TBX21',
 'beta-catenin': 'CTNNB1',
 'HLA-DR': 'HLA-DRA',
 'PD-L1': 'CD274',
 'Ki67': 'MKI67',
 'CD45RA': 'PTPRC',
 'CD4': 'CD4',
 'CD21': 'CR2',
 'MUC-1': 'MUC1',
 'CD30': 'TNFRSF8',
 'CD2': 'CD2',
 'Vimentin': 'VIM',
 'CD20': 'MS4A1',
 'LAG-3': 'LAG3' ,
 'Na-K-ATPase': 'ATP1A1',
 'CD5': 'CD5',
 'IDO-1': 'IDO1',
 'Cytokeratin': 'KRT20',
 'CD11b': 'ITGAM',
 'CD56': 'NCAM',
 'aSMA': 'ACTA2',
 'BCL-2': 'BCL2',
 'CD25': 'IL2RA',
 'CD11c': 'ITGAX',
 'PD-1': 'PDCD1',
 'Granzyme B': 'GZMB',
 'EGFR': 'EGFR',
 'VISTA': 'VSIR',
 'CD15': 'FUT4',
 'ICOS': 'ICOS',
 'Synaptophysin': 'SYP',
 'GFAP': 'GFAP',
 'CD7': 'CD7',
 'CD3': 'CD3D',
 'Chromogranin A': 'CHGA',
 'CD163': 'CD163',
 'CD45RO': 'PTPRC',
 'CD68': 'CD68',
 'CD31': 'PECAM1',
 'Podoplanin': 'PDPN',
 'CD34': 'CD34',
 'CD38': 'CD38',
 'CD138': 'SDC1',
 'CDX2': 'CDX2',
 'Collagen IV': 'COL4A1',
 'CD194': 'CCR4',
 'MMP9': 'MMP9',
 'CD71': 'TFRC',
 'CD57': 'B3GAT1',
 'MMP12': 'MMP12'}
[ ]:
def feature_map(s):
    if ':' in s:
        s = s.split(':')[0]
    if ' - ' in s:
        s = s.split(' - ')[0]
    return s
short_features = list(map(feature_map, features))
# short_features
[ ]:
metadata.head().style
  CellID ClusterID EventID File Name Region TMA_AB TMA_12 Index in File groups patients spots cell_id:cell_id tile_nr:tile_nr X:X Y:Y X_withinTile:X_withinTile Y_withinTile:Y_withinTile Z:Z size:size HOECHST1:Cyc_1_ch_1 DRAQ5:Cyc_23_ch_4 Profile_Homogeneity:Fiter1 ClusterSize ClusterName neighborhood10 CD4+ICOS+ CD4+Ki67+ CD4+PD-1+ CD68+CD163+ICOS+ CD68+CD163+Ki67+ CD68+CD163+PD-1+ CD68+ICOS+ CD68+Ki67+ CD68+PD-1+ CD8+ICOS+ CD8+Ki67+ CD8+PD-1+ Treg-ICOS+ Treg-Ki67+ Treg-PD-1+ neighborhood number final neighborhood name
0 0 10668 0 reg001_A reg001 A 1 0 1 1 1_A 3 1 77 589 77 589 10 10120 472.335785 2011.402222 3.170372 22144 granulocytes 9 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 9.000000 Granulocyte enriched
1 1 10668 4 reg001_A reg001 A 1 4 1 1 1_A 9 1 106 826 106 826 10 861 761.088501 3221.254883 5.373874 22144 granulocytes 4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4.000000 Macrophage enriched
2 2 10668 5 reg001_A reg001 A 1 5 1 1 1_A 10 1 107 545 107 545 10 6206 1353.695679 5228.323730 5.239202 22144 granulocytes 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3.000000 Immune-infiltrated stroma
3 3 10668 6 reg001_A reg001 A 1 6 1 1 1_A 13 1 98 564 98 564 10 6320 844.661743 1259.016846 4.616237 22144 granulocytes 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3.000000 Immune-infiltrated stroma
4 4 10668 30 reg001_A reg001 A 1 30 1 1 1_A 65 1 217 329 217 329 10 1591 1746.380859 8041.375000 5.331494 22144 granulocytes 4 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 4.000000 Macrophage enriched
[ ]:
all_adata = sc.AnnData(data, obs=metadata)
all_adata.var_names = short_features
all_adata.obsm['spatial'] = all_adata.obs[['X:X', 'Y:Y']].to_numpy()

adatas = []
for i in all_adata.obs['File Name'].unique():
    adatas.append(all_adata[all_adata.obs['File Name'] == i].copy())
    adatas[-1].obs['global'] = 0
adatas = sf.prep_adatas(adatas)
dataset = sf.make_dataset(adatas, sparse_graph=True, regional_obs=['global'])
C:\Users\lshh\miniconda3\envs\py311_torch211_cuda121\Lib\site-packages\anndata\_core\anndata.py:183: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)
Using ['global'] as regional annotations.
[ ]:
pop_cell_types = all_adata.obs["ClusterName"].value_counts()[all_adata.obs["ClusterName"].value_counts() > all_adata.shape[0] * .005].index.tolist()
pop_cell_types
['tumor cells',
 'CD68+CD163+ macrophages',
 'smooth muscle',
 'granulocytes',
 'stroma',
 'CD8+ T cells',
 'CD4+ T cells CD45RO+',
 'B cells',
 'vasculature',
 'plasma cells',
 'dirt',
 'undefined',
 'immune cells',
 'Tregs',
 'CD4+ T cells',
 'immune cells / vasculature',
 'CD68+ macrophages',
 'adipocytes',
 'tumor cells / immune cells',
 'CD11b+CD68+ macrophages']

Train model

[ ]:
cuda_dataset = dataset.to('cuda')
[ ]:
sf.set_random_seed(0)
model = sf.Steamboat(short_features, n_heads=10, n_scales=3)
model = model.to(device)

# model.fit(cuda_dataset, entry_masking_rate=0.1, feature_masking_rate=0.0,
#           max_epoch=10000,
#           loss_fun=torch.nn.MSELoss(reduction='sum'),
#           opt=torch.optim.Adam, opt_args=dict(lr=0.1), stop_eps=1e-3, report_per=200, stop_tol=200, device=device)

# torch.save(model.state_dict(), 'saved_models/crc_codex.pth')

model.load_state_dict(torch.load('../experiments/saved_models/crc_codex.pth', weights_only=True), strict=False)
_IncompatibleKeys(missing_keys=[], unexpected_keys=['spatial_gather.w_local._scale', 'spatial_gather.w_global._scale'])
[ ]:
sf.tools.calc_obs(adatas, dataset, model, get_recon=False)
[ ]:
head_weights = sf.tools.calc_head_weights(adatas, model)
sf.tools.plot_head_weights(head_weights)
../_images/tutorial_nbs_Ex3_crc_22_0.png
[ ]:
sf.tools.plot_all_transforms(model, top=1)
['CD20' 'CD45' 'CD163' 'p53' 'LAG-3' 'aSMA' 'Ki67' 'EGFR' 'CD194' 'CD30'
 'CD15' 'Cytokeratin' 'CD45RO' 'CD2' 'CD57' 'Collagen IV' 'Chromogranin A'
 'beta-catenin' 'CD71' 'CD7' 'BCL-2']
../_images/tutorial_nbs_Ex3_crc_23_1.png
[ ]:
# T cells (CD3+)
# cytotoxic T cells (CD3+CD8+)
# T helper cells (CD3+CD4+),
# Tregs (CD3+CD4+CD25+FOXP3+),
# B cells (CD3-CD20+),
# plasma cells (CD3-CD20-CD56-CD68-CD163-CD38+)
# NK cells (CD3-CD20-CD56+),
# CD68+ macrophages (CD3-CD20-CD56-CD68+),
# CD163 +macrophages (CD3-CD20-CD56-CD163+),
# CD68+CD163+double-positive macrophages (CD3-CD20-CD56-CD68+CD163+),
# lymphatics (CD3-CD20-CD56-Podoplanin+),
# vasculature (CD3-CD20-CD56-CD31+),
# dendritic cells (CD3-CD20-CD56-CD11c+),
# granulocytes (CD3-CD20-CD56-CD15+)

chosen_features = ['CD3', 'CD4', 'CD8', 'CD25', 'FOXP3',  # T
                   'CD20', 'CD38',  # B/Plasma
                   'CD56',  # NK
                   'CD68', 'CD163', # Macro
                   'Podoplanin', #lymphatics
                   'CD31', #vasculature
                   'CD11c', #dendritic
                   'CD15', #granulocytes
                   'Cytokeratin'
                  ]

sf.tools.plot_vq(model, chosen_features)
(<Figure size 300x300 with 2 Axes>, <Axes: >)
../_images/tutorial_nbs_Ex3_crc_24_1.png

Global attention interpretation

[ ]:
sample_df = pd.crosstab(all_adata.obs['File Name'], all_adata.obs['neighborhood name'])
sample_df = sample_df.div(sample_df.sum(axis=1), axis=0)
sample_df['TLS'] = sample_df['Follicle'] > 0.00

sample_os = []
for i in patient_info['OS']:
    sample_os.extend([i] * 4)
sample_df['OS'] = sample_os

sample_osc = []
for i in patient_info['OS_Censor']:
    sample_osc.extend([i] * 4)
sample_df['OS_Censor'] = sample_osc

sample_dfs = []
for i in patient_info['DFS']:
    sample_dfs.extend([i] * 4)
sample_df['DFS'] = sample_dfs

sample_df['Patient'] = [i.obs['patients'].unique().item() for i in adatas]
sample_df['Group'] = [i.obs['groups'].unique().item() for i in adatas]

sample_df['Group'] = np.array(['', 'CLR', 'DII'])[sample_df['Group'].astype(int)]
sample_df['TLS'] = np.array(['No', 'Yes'])[sample_df['TLS'].astype(int)]
sample_df.loc[sample_df['Group'] == 'DII', 'TLS'] = 'No'
[ ]:
for i_head in [2, 8]:
    score_name = f'Global_score_{i_head}'
    sample_df[score_name] = [i.uns['global_k_0'][0, i_head] for i in adatas]

    g1 = (sample_df['Group'] == 'CLR') & (sample_df['TLS'] == 'No')
    g2 = (sample_df['Group'] == 'CLR') & (sample_df['TLS'] == 'Yes')
    s, p = sp.stats.mannwhitneyu(sample_df.loc[g1, score_name],
                                 sample_df.loc[g2, score_name])
    print('CLR Yes vs No:', s / g1.sum() / g2.sum(), p)

    g1 = (sample_df['Group'] == 'DII')
    g2 = (sample_df['Group'] == 'CLR') & (sample_df['TLS'] == 'No')
    s, p = sp.stats.mannwhitneyu(sample_df.loc[g1, score_name],
                                 sample_df.loc[g2, score_name])
    print('DII vs CLR-No:', s / g1.sum() / g2.sum(), p)

    g1 = (sample_df['Group'] == 'DII')
    g2 = (sample_df['Group'] == 'CLR')
    s, p = sp.stats.mannwhitneyu(sample_df.loc[g1, score_name],
                                 sample_df.loc[g2, score_name])
    print('DII vs CLR:', s / g1.sum() / g2.sum(), p)

    fig, ax = plt.subplots(figsize=(1.6, 2.4))
    sns.violinplot(sample_df, x='Group', hue='TLS', y=score_name, ax=ax, legend="brief")
    plt.ylabel(f'Global_score_{i_head}')
    ax.grid(axis='y')
    # ax.set_ylim([-49, 180])
    for pos in ['right', 'top']:
        ax.spines[pos].set_visible(False)
CLR Yes vs No: 0.5934065934065934 0.20022301659028963
DII vs CLR-No: 0.6633597883597884 0.0037443658462830293
DII vs CLR: 0.701593137254902 3.9057054643225304e-05
CLR Yes vs No: 0.3736263736263736 0.0826997853235793
DII vs CLR-No: 0.7106481481481481 0.00018479325733404154
DII vs CLR: 0.6595179738562091 0.00113745105113063
../_images/tutorial_nbs_Ex3_crc_27_1.png
../_images/tutorial_nbs_Ex3_crc_27_2.png
[ ]:
def plot_global_transform(model, d,
                   top: int = 3, reorder: bool = False,
                   figsize: str | tuple[float, float] = 'auto'):

    q = model.spatial_gather.q.weight[d, :].detach().cpu().numpy()
    k = model.spatial_gather.k_regionals[0].weight[d, :].detach().cpu().numpy()
    v = model.spatial_gather.v.weight[:, d].detach().cpu().numpy()

    if top > 0:
        rank_q = np.argsort(-q)[:top]
        rank_k = np.argsort(-k)[:top]
        rank_v = np.argsort(-v)[:top]
        feature_mask = {}
        for j in rank_k:
            feature_mask[j] = None
        for j in rank_q:
            feature_mask[j] = None
        for j in rank_v:
            feature_mask[j] = None
        feature_mask = list(feature_mask.keys())
        chosen_features = np.array(model.features)[feature_mask]
    else:
        feature_mask = list(range(len(model.features)))
        chosen_features = np.array(model.features)

    if figsize == 'auto':
        figsize = (.65, len(chosen_features) * 0.15 + .75)
    # print(figsize)
    fig, ax = plt.subplots(figsize=figsize)
    common_params = {'linewidths': .05, 'linecolor': 'gray', 'yticklabels': chosen_features,
                     'cmap': 'Reds'}

    to_plot = np.vstack((k[feature_mask],
                         q[feature_mask],
                         v[feature_mask])).T
    true_vmax = to_plot.max(axis=0)
    # print(true_vmax)
    to_plot /= true_vmax

    sns.heatmap(to_plot, xticklabels=['global env', 'center cell', 'reconstruction'], square=True, ax=ax, **common_params)
    ax.set_xticklabels(['global env',  'center cell', 'reconstruction'], rotation=45, ha='right', va='center', rotation_mode='anchor')
    # ax.set_xticklabels(plot_axes[i].get_xticklabels(), rotation=0)
    # ax.get_yaxis().set_visible(False)

    plt.tight_layout()

plot_global_transform(model, 2, top=8)
plot_global_transform(model, 8, top=8)
C:\Users\lshh\AppData\Local\Temp\ipykernel_67648\3390957469.py:45: UserWarning: Tight layout not applied. The left and right margins cannot be made large enough to accommodate all axes decorations.
  plt.tight_layout()
C:\Users\lshh\AppData\Local\Temp\ipykernel_67648\3390957469.py:45: UserWarning: Tight layout not applied. The left and right margins cannot be made large enough to accommodate all axes decorations.
  plt.tight_layout()
../_images/tutorial_nbs_Ex3_crc_28_1.png
../_images/tutorial_nbs_Ex3_crc_28_2.png

Survival analysis

[ ]:
def summarize_samples(df, entity, keys):
    res = {key: [] for key, _ in keys}
    res[entity] = []
    for i in df[entity].unique():
        res[entity].append(i)
        for (key, how) in keys:
            # print(df.loc[df[entity] == i, key])
            if how == 'mean':
                res[key].append(df.loc[df[entity] == i, key].mean())
            elif how == 'median':
                res[key].append(df.loc[df[entity] == i, key].median())
            elif how == 'item':
                res[key].append(df.loc[df[entity] == i, key].unique().item())
    return pd.DataFrame(res)

patient_df = summarize_samples(sample_df, 'Patient', [('OS', 'item'), ('OS_Censor', 'item'), ('Group', 'item')])
patient_df['OS_Censor'] = patient_df['OS_Censor'].astype(bool)
fig, ax = plt.subplots(figsize=(1.2, 1.2))

for group in ("CLR", "DII"):
    mask = patient_df["Group"] == group
    time_survival, survival_prob, conf_int = kaplan_meier_estimator(
        patient_df["OS_Censor"][mask],
        patient_df["OS"][mask],
        conf_type="log-log"
    )

    ax.step(time_survival, survival_prob, where="post", label=f"{group}")
    ax.fill_between(time_survival, conf_int[0], conf_int[1], alpha=0.25, step="post")

import sksurv.compare
y = np.array([(i, j) for i, j in zip(patient_df['OS_Censor'], patient_df["OS"])], dtype=[('OS_Censor', '?'), ('Survival_in_months', '<f8')])
chi2, pval = sksurv.compare.compare_survival(y, patient_df["Group"], return_stats=False)

ax.set_title(f'CLR vs DII p={pval:.3f}')
ax.set_ylim(0, 1)
ax.set_ylabel("est. prob. of survival\n$\hat{S}(t)$")
ax.set_xlabel("time $t$")
ax.legend(loc="best")
ax.spines[['right', 'top']].set_visible(False)
../_images/tutorial_nbs_Ex3_crc_30_0.png
[ ]:
regional = []
for i in range(len(adatas)):
    regional.append(np.mean(adatas[i].obsm['global_attn_0'], axis=0))

for i in [2, 7, 8]:
    sample_df['head_weight'] = np.vstack(regional)[:, i]
    patient_df = summarize_samples(sample_df, 'Patient', [('OS', 'item'), ('OS_Censor', 'item'), ('Group', 'item'), ('head_weight', 'mean')])
    patient_df['OS_Censor'] = patient_df['OS_Censor'].astype(bool)
    patient_df['head_weight_binary'] = 'mid'
    patient_df.loc[patient_df['head_weight'] > patient_df['head_weight'].quantile(.5), 'head_weight_binary'] = 'high'
    patient_df.loc[patient_df['head_weight'] < patient_df['head_weight'].quantile(.5), 'head_weight_binary'] = 'low'

    patient_df = patient_df[patient_df['head_weight_binary'] != 'mid']

    y = np.array([(i, j) for i, j in zip(patient_df['OS_Censor'], patient_df["OS"])], dtype=[('OS_Censor', '?'), ('Survival_in_months', '<f8')])
    chi2, pval = sksurv.compare.compare_survival(y, patient_df["head_weight_binary"], return_stats=False)
    print(i, pval)

    fig, ax = plt.subplots(figsize=(1.2, 1.2))
    for group in ("high", "low"):
        mask = patient_df["head_weight_binary"] == group
        time_survival, survival_prob, conf_int = kaplan_meier_estimator(
            patient_df["OS_Censor"][mask],
            patient_df["OS"][mask],
            conf_type="log-log"
        )

        ax.step(time_survival, survival_prob, where="post", label=f"{group}", lw=1.)
        ax.fill_between(time_survival, conf_int[0], conf_int[1], alpha=0.25, step="post")

    ax.set_title(f'Head #{i} p={pval:.3f}')
    ax.set_ylim(0, 1)
    ax.set_ylabel("Est. prob. of survival\n$\hat{S}(t)$")
    ax.set_xlabel("time $t$ (months)")
    ax.legend(loc="best")
    ax.spines[['right', 'top']].set_visible(False)
2 0.03900862612646816
7 0.11923579092332769
8 0.025214997228824008
../_images/tutorial_nbs_Ex3_crc_31_1.png
../_images/tutorial_nbs_Ex3_crc_31_2.png
../_images/tutorial_nbs_Ex3_crc_31_3.png
[ ]:
sf.tools.plot_cell_type_enrichment(all_adata, adatas, 2, 'ClusterName', pop_cell_types)
sf.tools.plot_cell_type_enrichment(all_adata, adatas, 2, 'neighborhood name')
(<Figure size 75x400 with 2 Axes>, <Axes: ylabel='neighborhood name'>)
../_images/tutorial_nbs_Ex3_crc_32_1.png
../_images/tutorial_nbs_Ex3_crc_32_2.png
[ ]:
sf.tools.plot_cell_type_enrichment(all_adata, adatas, 8, 'ClusterName', pop_cell_types)
sf.tools.plot_cell_type_enrichment(all_adata, adatas, 8, 'neighborhood name')
(<Figure size 75x400 with 2 Axes>, <Axes: ylabel='neighborhood name'>)
../_images/tutorial_nbs_Ex3_crc_33_1.png
../_images/tutorial_nbs_Ex3_crc_33_2.png

Misc visualization

[ ]:
i_sample = 10

major_cell_type_map = {'tumor cells': 'tumor',
                         'CD68+CD163+ macrophages': 'macrophages',
                         'smooth muscle': 'muscle',
                         'granulocytes': 'granulocytes',
                         'stroma': 'stroma',
                         'CD8+ T cells': 'T cells',
                         'CD4+ T cells CD45RO+': 'T cells',
                         'B cells': 'B cells',
                         'vasculature': 'vasculature',
                         'plasma cells': 'plasma',
                         'dirt': 'dirt',
                         'undefined': 'unclear',
                         'immune cells': 'other immune',
                         'Tregs': 'T cells',
                         'CD4+ T cells': 'T cells',
                         'immune cells / vasculature': 'unclear',
                         'CD68+ macrophages': 'macrophages',
                         'adipocytes': 'adipocytes',
                         'tumor cells / immune cells': 'unclear',
                         'CD11b+CD68+ macrophages': 'macrophages',
                         'CD11b+ monocytes': 'monocytes',
                         'nerves': 'nerves',
                         'CD11c+ DCs': 'DCs',
                         'lymphatics': 'lymphatics',
                         'NK cells': 'NK cells',
                         'CD3+ T cells': 'T cells',
                         'CD68+ macrophages GzmB+': 'macrophages',
                         'CD4+ T cells GATA3+': 'T cells',
                         'CD163+ macrophages': 'macrophages'}

coarse_cell_type_map = {'tumor cells': 'tumor',
                         'CD68+CD163+ macrophages': 'immune',
                         'smooth muscle': 'muscle',
                         'granulocytes': 'immune',
                         'stroma': 'stroma',
                         'CD8+ T cells': 'immune',
                         'CD4+ T cells CD45RO+': 'immune',
                         'B cells': 'immune',
                         'vasculature': 'vasculature',
                         'plasma cells': 'immune',
                         'dirt': 'dirt',
                         'undefined': 'unclear',
                         'immune cells': 'immune',
                         'Tregs': 'immune',
                         'CD4+ T cells': 'immune',
                         'immune cells / vasculature': 'unclear',
                         'CD68+ macrophages': 'immune',
                         'adipocytes': 'adipocytes',
                         'tumor cells / immune cells': 'unclear',
                         'CD11b+CD68+ macrophages': 'immune',
                         'CD11b+ monocytes': 'immune',
                         'nerves': 'nerves',
                         'CD11c+ DCs': 'immune',
                         'lymphatics': 'immune',
                         'NK cells': 'immune',
                         'CD3+ T cells': 'immune',
                         'CD68+ macrophages GzmB+': 'immune',
                         'CD4+ T cells GATA3+': 'immune',
                         'CD163+ macrophages': 'immune'}

all_adata.obs['MajorCellTypes'] = all_adata.obs['ClusterName'].apply(major_cell_type_map.__getitem__)
all_adata.obs['CoarseCellTypes'] = all_adata.obs['ClusterName'].apply(coarse_cell_type_map.__getitem__)

neighborhood_palette = {i: j for i, j in zip(sorted(all_adata.obs['neighborhood name'].dropna().unique().tolist()), sc.pl.palettes.vega_10_scanpy)}
adatas[i_sample].uns['neighborhood name_colors'] = [neighborhood_palette[i] for i in sorted(adatas[i_sample].obs['neighborhood name'].dropna().unique().tolist())]
adatas[i_sample].obs['CoarseCellTypes'] = adatas[i_sample].obs['ClusterName'].apply(coarse_cell_type_map.__getitem__)
sq.pl.spatial_scatter(adatas[i_sample], color=['CoarseCellTypes'], shape=None, figsize=(4, 2), size=1., cmap='Reds', frameon=False)

sq.pl.spatial_scatter(adatas[i_sample], color=['neighborhood name'], shape=None, figsize=(5, 2), size=1., cmap='Reds', frameon=False)

WARNING: Please specify a valid `library_id` or set it permanently in `adata.uns['spatial']`
WARNING: Please specify a valid `library_id` or set it permanently in `adata.uns['spatial']`
C:\Users\lshh\miniconda3\envs\py311_torch211_cuda121\Lib\site-packages\squidpy\pl\_spatial_utils.py:946: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap', 'norm' will be ignored
  _cax = scatter(
C:\Users\lshh\miniconda3\envs\py311_torch211_cuda121\Lib\site-packages\squidpy\pl\_spatial_utils.py:946: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap', 'norm' will be ignored
  _cax = scatter(
../_images/tutorial_nbs_Ex3_crc_35_2.png
../_images/tutorial_nbs_Ex3_crc_35_3.png
[ ]:
adatas[0].uns['neighborhood name_colors'] = [neighborhood_palette[i] for i in sorted(adatas[0].obs['neighborhood name'].dropna().unique().tolist())]
sq.pl.spatial_scatter(adatas[0], color=['GATA3', 'neighborhood name'], shape=None, figsize=(2, 1.5), size=1., cmap='Reds', frameon=False, vmin=2, vmax=6)

adatas[10].uns['neighborhood name_colors'] = [neighborhood_palette[i] for i in sorted(adatas[10].obs['neighborhood name'].dropna().unique().tolist())]
sq.pl.spatial_scatter(adatas[10], color=['GATA3', 'neighborhood name'], shape=None, figsize=(2, 1.5), size=1., cmap='Reds', frameon=False, vmin=2, vmax=6)

WARNING: Please specify a valid `library_id` or set it permanently in `adata.uns['spatial']`
WARNING: Please specify a valid `library_id` or set it permanently in `adata.uns['spatial']`
C:\Users\lshh\miniconda3\envs\py311_torch211_cuda121\Lib\site-packages\squidpy\pl\_spatial_utils.py:946: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap', 'norm' will be ignored
  _cax = scatter(
C:\Users\lshh\miniconda3\envs\py311_torch211_cuda121\Lib\site-packages\squidpy\pl\_spatial_utils.py:946: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap', 'norm' will be ignored
  _cax = scatter(
../_images/tutorial_nbs_Ex3_crc_36_2.png
../_images/tutorial_nbs_Ex3_crc_36_3.png
[ ]: