Bayesian modeling of SC reliability¶
The interactive version of this notebook is available at https://github.com/TheAxonLab/hcph-sops/tree/mkdocs/docs/analysis.
In [2]:
Copied!
import h5py
import os
import pickle
import arviz as az
import numpy as np
import pymc as pm
import seaborn as sns
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from simulate_sc import (
simulate_sc_no_bias,
simulate_sc_noisy_copies,
simulate_sc_length_bias,
simulate_sc_density_bias,
simulate_sc_fns,
simulate_sc_fps,
)
from bayesian_modeling import fit_mixture_model
atlas_path = "/data/probconnatlas/wm.connatlas.scale1.h5"
# Load consistency matrix
with h5py.File(atlas_path, "r") as f:
print(f.keys())
print(f["matrices"].keys())
consistency_matrix = np.array(f["matrices"]["consistency"])
print("Loaded consistency matrix")
print(consistency_matrix)
density_matrix = np.array(f["matrices"]["numbStlines"])
import h5py
import os
import pickle
import arviz as az
import numpy as np
import pymc as pm
import seaborn as sns
import matplotlib.pyplot as plt
from joblib import Parallel, delayed
from simulate_sc import (
simulate_sc_no_bias,
simulate_sc_noisy_copies,
simulate_sc_length_bias,
simulate_sc_density_bias,
simulate_sc_fns,
simulate_sc_fps,
)
from bayesian_modeling import fit_mixture_model
atlas_path = "/data/probconnatlas/wm.connatlas.scale1.h5"
# Load consistency matrix
with h5py.File(atlas_path, "r") as f:
print(f.keys())
print(f["matrices"].keys())
consistency_matrix = np.array(f["matrices"]["consistency"])
print("Loaded consistency matrix")
print(consistency_matrix)
density_matrix = np.array(f["matrices"]["numbStlines"])
<KeysViewHDF5 ['atlas', 'header', 'matrices']> <KeysViewHDF5 ['consistency', 'length', 'numbStlines']> Loaded consistency matrix [[66 65 65 ... 0 1 66] [65 65 65 ... 0 0 60] [65 65 60 ... 0 0 15] ... [ 0 0 0 ... 17 64 57] [ 1 0 0 ... 64 66 66] [66 60 15 ... 57 66 66]]
Determine threshold to consider connection existence¶
We can determine if a connection exists or not based on the consistency matrix from the atlas of Alemán-Gómez et al. 2022. To binarize the consistency matrix, we have to choose a threshold.
In [4]:
Copied!
threshold = np.percentile(consistency_matrix[consistency_matrix > 0], 40)
print("Threshold to binarize consistency matrix:", threshold)
print("Max consistency:", np.max(consistency_matrix))
# Plot histogram for consistency_matrix
plt.figure(figsize=(32, 8))
plt.subplot(1, 3, 1)
plt.hist(consistency_matrix.flatten(), bins=50, color="blue", alpha=0.7)
plt.axvline(
x=threshold, color="red", linestyle="--", label=f"Threshold ({threshold:.2f})"
)
plt.legend()
plt.title("Consistency Matrix")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.subplot(1, 3, 2)
sns.heatmap(consistency_matrix, cmap="viridis", cbar=True)
plt.title("Heatmap of Consistency Matrix")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.subplot(1, 3, 3)
sns.heatmap(consistency_matrix > threshold, cmap="viridis", cbar=True)
plt.title("Connection exists or not")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
threshold = np.percentile(consistency_matrix[consistency_matrix > 0], 40)
print("Threshold to binarize consistency matrix:", threshold)
print("Max consistency:", np.max(consistency_matrix))
# Plot histogram for consistency_matrix
plt.figure(figsize=(32, 8))
plt.subplot(1, 3, 1)
plt.hist(consistency_matrix.flatten(), bins=50, color="blue", alpha=0.7)
plt.axvline(
x=threshold, color="red", linestyle="--", label=f"Threshold ({threshold:.2f})"
)
plt.legend()
plt.title("Consistency Matrix")
plt.xlabel("Value")
plt.ylabel("Frequency")
plt.subplot(1, 3, 2)
sns.heatmap(consistency_matrix, cmap="viridis", cbar=True)
plt.title("Heatmap of Consistency Matrix")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.subplot(1, 3, 3)
sns.heatmap(consistency_matrix > threshold, cmap="viridis", cbar=True)
plt.title("Connection exists or not")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
Threshold to binarize consistency matrix: 47.0 Max consistency: 66
Load simulated SC matrices¶
In [20]:
Copied!
from matplotlib.colors import LogNorm
# Load simulated SC
SC_matrices, noise = simulate_sc_density_bias(
atlas_path=atlas_path, connectome_atlas_as_ref=True, num_sessions=36, atlas_dim=5
)
# SC_matrices, noise = simulate_sc_density_bias(atlas_path=atlas_path)
num_sessions = SC_matrices.shape[0]
atlas_dim = SC_matrices.shape[1]
ref_SC_matrix = density_matrix
variability = np.std(SC_matrices, axis=0)
fig, (ax_heatmap, ax_scatter) = plt.subplots(
1, 2, figsize=(20, 8), gridspec_kw={"width_ratios": [1, 1]}
)
# Heatmap of the density matrix
sns.heatmap(
ref_SC_matrix,
cmap="viridis",
cbar=True,
norm=LogNorm(
vmin=np.nanmin(ref_SC_matrix[ref_SC_matrix > 0]), vmax=np.nanmax(ref_SC_matrix)
),
ax=ax_heatmap,
)
ax_heatmap.set_title(
"A) Heatmap of the reference SC Matrix (Log Scale)", fontsize=18, loc="left"
)
ax_heatmap.set_xlabel("Region Index", fontsize=16)
ax_heatmap.set_ylabel("Region Index", fontsize=16)
cbar = ax_heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=14)
cbar.set_label(" Average number of streamlines", fontsize=14)
# Scatter plot
scatter = ax_scatter.scatter(
ref_SC_matrix.flatten(), variability.flatten(), alpha=0.5, color="grey"
)
ax_scatter.set_title("B) Variability across repeated measures", fontsize=18, loc="left")
ax_scatter.set_xlabel("SC Value", fontsize=16)
ax_scatter.set_ylabel("Standard deviation across repeated measures", fontsize=16)
ax_scatter.tick_params(axis="both", which="major", labelsize=14)
ax_scatter.grid(True)
# Add distribution of variability as a marginal histogram on the right
ax_hist = ax_scatter.inset_axes(
[1.02, 0, 0.2, 1]
) # [x, y, width, height] relative to ax_scatter
sns.histplot(
y=variability.flatten(), bins=50, color="grey", alpha=0.7, kde=True, ax=ax_hist
)
ax_hist.set_xlabel("Frequency", fontsize=16)
ax_hist.tick_params(
axis="both",
which="major",
labelsize=14,
right=True,
labelright=True,
left=False,
labelleft=False,
)
ax_hist.grid(True)
plt.tight_layout()
plt.show()
from matplotlib.colors import LogNorm
# Load simulated SC
SC_matrices, noise = simulate_sc_density_bias(
atlas_path=atlas_path, connectome_atlas_as_ref=True, num_sessions=36, atlas_dim=5
)
# SC_matrices, noise = simulate_sc_density_bias(atlas_path=atlas_path)
num_sessions = SC_matrices.shape[0]
atlas_dim = SC_matrices.shape[1]
ref_SC_matrix = density_matrix
variability = np.std(SC_matrices, axis=0)
fig, (ax_heatmap, ax_scatter) = plt.subplots(
1, 2, figsize=(20, 8), gridspec_kw={"width_ratios": [1, 1]}
)
# Heatmap of the density matrix
sns.heatmap(
ref_SC_matrix,
cmap="viridis",
cbar=True,
norm=LogNorm(
vmin=np.nanmin(ref_SC_matrix[ref_SC_matrix > 0]), vmax=np.nanmax(ref_SC_matrix)
),
ax=ax_heatmap,
)
ax_heatmap.set_title(
"A) Heatmap of the reference SC Matrix (Log Scale)", fontsize=18, loc="left"
)
ax_heatmap.set_xlabel("Region Index", fontsize=16)
ax_heatmap.set_ylabel("Region Index", fontsize=16)
cbar = ax_heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=14)
cbar.set_label(" Average number of streamlines", fontsize=14)
# Scatter plot
scatter = ax_scatter.scatter(
ref_SC_matrix.flatten(), variability.flatten(), alpha=0.5, color="grey"
)
ax_scatter.set_title("B) Variability across repeated measures", fontsize=18, loc="left")
ax_scatter.set_xlabel("SC Value", fontsize=16)
ax_scatter.set_ylabel("Standard deviation across repeated measures", fontsize=16)
ax_scatter.tick_params(axis="both", which="major", labelsize=14)
ax_scatter.grid(True)
# Add distribution of variability as a marginal histogram on the right
ax_hist = ax_scatter.inset_axes(
[1.02, 0, 0.2, 1]
) # [x, y, width, height] relative to ax_scatter
sns.histplot(
y=variability.flatten(), bins=50, color="grey", alpha=0.7, kde=True, ax=ax_hist
)
ax_hist.set_xlabel("Frequency", fontsize=16)
ax_hist.tick_params(
axis="both",
which="major",
labelsize=14,
right=True,
labelright=True,
left=False,
labelleft=False,
)
ax_hist.grid(True)
plt.tight_layout()
plt.show()
Using the connectome atlas as reference SC matrix. Simulated a series of 36 SC matrices of shape (95x95) with higher variability (std=0.5) in lower density connections.
In [4]:
Copied!
# Compute the percentage of zero connections
zero_connections_percentage = (
np.sum(consistency_matrix < threshold) / consistency_matrix.size * 100
)
print(f"Percentage of zero connections: {zero_connections_percentage:.2f}%")
# Compute the percentage of zero connections
zero_connections_percentage = (
np.sum(consistency_matrix < threshold) / consistency_matrix.size * 100
)
print(f"Percentage of zero connections: {zero_connections_percentage:.2f}%")
Percentage of zero connections: 42.58%
Bayesian model fitted to our simulated SC matrices¶
Because the code needed to let the code run in the background overnight, we move the fitting of the Bayesian model in its own standalone Python code run_bayesian_sc.py.
!! Please run that code to generate the output files saving the trace and parameters evaluation from Bayesian modeling before running the rest of this notebook. !!
In [5]:
Copied!
def load_edge(output_dir, c, SC_matrices_flat, mu_type="fixed"):
pkl_file = os.path.join(output_dir, f"edge_{c:05d}.pkl")
trace_file = os.path.join(output_dir, f"edge_{c:05d}_trace.nc")
if os.path.exists(pkl_file):
with open(pkl_file, "rb") as f:
param_values = pickle.load(f)
else:
# param_values = {var: 0 for var in ["pi0", "lambda_exp", "sigma", "mu"]}
raise FileNotFoundError(
f"File {pkl_file} not found. Please run the model fitting first using run_bayesian_sc.py."
)
if os.path.exists(trace_file):
trace = az.from_netcdf(trace_file)
else:
raise FileNotFoundError(
f"File {trace_file} not found. Please run the model fitting first using run_bayesian_sc.py."
)
var_names = ["pi0", "lambda_exp", "sigma"]
if mu_type == "fixed":
param_values["mu"] = np.mean(SC_matrices_flat[:, c])
else:
var_names.append("mu")
summary = az.summary(trace, var_names=var_names)
for var in var_names:
param_values[var] = summary.loc[var].to_dict()
return param_values
def load_edge(output_dir, c, SC_matrices_flat, mu_type="fixed"):
pkl_file = os.path.join(output_dir, f"edge_{c:05d}.pkl")
trace_file = os.path.join(output_dir, f"edge_{c:05d}_trace.nc")
if os.path.exists(pkl_file):
with open(pkl_file, "rb") as f:
param_values = pickle.load(f)
else:
# param_values = {var: 0 for var in ["pi0", "lambda_exp", "sigma", "mu"]}
raise FileNotFoundError(
f"File {pkl_file} not found. Please run the model fitting first using run_bayesian_sc.py."
)
if os.path.exists(trace_file):
trace = az.from_netcdf(trace_file)
else:
raise FileNotFoundError(
f"File {trace_file} not found. Please run the model fitting first using run_bayesian_sc.py."
)
var_names = ["pi0", "lambda_exp", "sigma"]
if mu_type == "fixed":
param_values["mu"] = np.mean(SC_matrices_flat[:, c])
else:
var_names.append("mu")
summary = az.summary(trace, var_names=var_names)
for var in var_names:
param_values[var] = summary.loc[var].to_dict()
return param_values
In [6]:
Copied!
mu_type = "fixed"
output_dir = "/home/cprovins/projects/bayesian_sc/mixture_model_1"
print(f"Shape of SC matrices: {SC_matrices.shape}")
SC_matrices_flat = SC_matrices.reshape(num_sessions, -1)
SC_matrices_flat = np.nan_to_num(SC_matrices_flat, nan=0)
param_values_file = os.path.join(output_dir, "extra_param_values.pkl")
if os.path.exists(param_values_file):
print("Loading param_values from file...")
with open(param_values_file, "rb") as f:
param_values = pickle.load(f)
else:
param_values = Parallel(n_jobs=30)(
delayed(load_edge)(output_dir, c, SC_matrices_flat)
for c in range(SC_matrices_flat.shape[1])
)
with open(param_values_file, "wb") as f:
pickle.dump(param_values, f)
print("param_values saved to file.")
mu_type = "fixed"
output_dir = "/home/cprovins/projects/bayesian_sc/mixture_model_1"
print(f"Shape of SC matrices: {SC_matrices.shape}")
SC_matrices_flat = SC_matrices.reshape(num_sessions, -1)
SC_matrices_flat = np.nan_to_num(SC_matrices_flat, nan=0)
param_values_file = os.path.join(output_dir, "extra_param_values.pkl")
if os.path.exists(param_values_file):
print("Loading param_values from file...")
with open(param_values_file, "rb") as f:
param_values = pickle.load(f)
else:
param_values = Parallel(n_jobs=30)(
delayed(load_edge)(output_dir, c, SC_matrices_flat)
for c in range(SC_matrices_flat.shape[1])
)
with open(param_values_file, "wb") as f:
pickle.dump(param_values, f)
print("param_values saved to file.")
Shape of SC matrices: (36, 95, 95) Loading param_values from file...
Plot the estimated parameters¶
In [7]:
Copied!
for p in ["pi0", "lambda_exp", "sigma"]:
param = np.array([edge[p]["mean"] for edge in param_values]).reshape(
atlas_dim, atlas_dim
)
plt.figure(figsize=(10, 8))
sns.heatmap(param, cmap="viridis", cbar=True)
plt.title(f"Heatmap of mean {p}")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
param = np.array([edge[p]["sd"] for edge in param_values]).reshape(
atlas_dim, atlas_dim
)
plt.figure(figsize=(10, 8))
sns.heatmap(param, cmap="viridis", cbar=True)
plt.title(f"Heatmap of {p} standard deviation")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
param = np.array([edge["mu"] for edge in param_values]).reshape(atlas_dim, atlas_dim)
plt.figure(figsize=(10, 8))
sns.heatmap(param, cmap="viridis", cbar=True)
plt.title(f"Heatmap of mu")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
for p in ["pi0", "lambda_exp", "sigma"]:
param = np.array([edge[p]["mean"] for edge in param_values]).reshape(
atlas_dim, atlas_dim
)
plt.figure(figsize=(10, 8))
sns.heatmap(param, cmap="viridis", cbar=True)
plt.title(f"Heatmap of mean {p}")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
param = np.array([edge[p]["sd"] for edge in param_values]).reshape(
atlas_dim, atlas_dim
)
plt.figure(figsize=(10, 8))
sns.heatmap(param, cmap="viridis", cbar=True)
plt.title(f"Heatmap of {p} standard deviation")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
param = np.array([edge["mu"] for edge in param_values]).reshape(atlas_dim, atlas_dim)
plt.figure(figsize=(10, 8))
sns.heatmap(param, cmap="viridis", cbar=True)
plt.title(f"Heatmap of mu")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
In [8]:
Copied!
def plot_hdi(param_values, param_name):
import matplotlib.pyplot as plt
import numpy as np
# Extract values
means = [res[param_name]["mean"] for res in param_values]
lower = [res[param_name]["hdi_3%"] for res in param_values]
upper = [res[param_name]["hdi_97%"] for res in param_values]
# Convert to arrays for easier math
means = np.array(means)
lower = np.array(lower)
upper = np.array(upper)
# Compute asymmetric error bars
yerr = np.vstack([means - lower, upper - means])
# Plot
plt.figure(figsize=(20, 4))
x = np.arange(len(means))
plt.errorbar(x, means, yerr=yerr, fmt="o", capsize=3, ecolor="gray", alpha=0.7)
plt.xlabel("Edge index")
plt.ylabel(f"Posterior mean of {param_name}")
plt.title(f"Estimated {param_name} with 94% HDI per edge")
plt.tight_layout()
plt.show()
plot_hdi(param_values, "pi0")
plot_hdi(param_values, "sigma")
def plot_hdi(param_values, param_name):
import matplotlib.pyplot as plt
import numpy as np
# Extract values
means = [res[param_name]["mean"] for res in param_values]
lower = [res[param_name]["hdi_3%"] for res in param_values]
upper = [res[param_name]["hdi_97%"] for res in param_values]
# Convert to arrays for easier math
means = np.array(means)
lower = np.array(lower)
upper = np.array(upper)
# Compute asymmetric error bars
yerr = np.vstack([means - lower, upper - means])
# Plot
plt.figure(figsize=(20, 4))
x = np.arange(len(means))
plt.errorbar(x, means, yerr=yerr, fmt="o", capsize=3, ecolor="gray", alpha=0.7)
plt.xlabel("Edge index")
plt.ylabel(f"Posterior mean of {param_name}")
plt.title(f"Estimated {param_name} with 94% HDI per edge")
plt.tight_layout()
plt.show()
plot_hdi(param_values, "pi0")
plot_hdi(param_values, "sigma")
In [9]:
Copied!
from scipy.stats import ttest_ind
sigma = np.array([edge["sigma"]["mean"] for edge in param_values]).reshape(
atlas_dim, atlas_dim
)
mask_h = density_matrix >= np.percentile(density_matrix[~np.isnan(density_matrix)], 40)
mask_l = density_matrix < np.percentile(density_matrix[~np.isnan(density_matrix)], 40)
alpha_mask_l = np.ones_like(sigma)
alpha_mask_l[mask_h] = (
0.4 # make the voxels that should not be largely biased semi-transparent
)
alpha_mask_h = np.ones_like(sigma)
alpha_mask_h[mask_l] = (
0.3 # make the voxels that should not be largely biased semi-transparent
)
# Perform a t-test comparing the estimated sigma values for low and high density connections
t_stat, p_value = ttest_ind(sigma[mask_l].flatten(), sigma[mask_h].flatten())
print(f"T-statistic: {t_stat:.4f}, P-value: {p_value:.12e}")
fig, axes = plt.subplots(1, 3, figsize=(20, 6))
# Heatmap with alpha_mask
sns.heatmap(sigma, cmap="viridis", cbar=True, alpha=alpha_mask_l, ax=axes[0])
axes[0].set_title(
r"A) Estimated $\sigma$ highlighted for low density connection",
fontsize=14,
loc="left",
)
axes[0].set_xlabel("Region Index", fontsize=13)
axes[0].set_ylabel("Region Index", fontsize=13)
# Heatmap with 1 - alpha_mask
sns.heatmap(sigma, cmap="viridis", cbar=True, alpha=alpha_mask_h, ax=axes[1])
axes[1].set_title(
r"B) Estimated $\sigma$ highlighted for high density connection",
fontsize=14,
loc="left",
)
axes[1].set_xlabel("Region Index", fontsize=13)
axes[1].set_ylabel("Region Index", fontsize=13)
# Boxplot for low and high density connections
# Boxplot for low and high density connections
axes[2].boxplot(
[sigma[mask_l].flatten(), sigma[mask_h].flatten()],
tick_labels=["Low Density Connections", "High Density Connections"],
)
axes[2].set_xticklabels(
["Low Density Connections", "High Density Connections"], fontsize=13
)
axes[2].set_title(
r"C) $\sigma$ estimation for low versus high density connections",
fontsize=14,
loc="left",
)
axes[2].set_ylabel(r"mean $\sigma$", fontsize=13)
# Add a line with *** to indicate statistical significance
x1, x2 = 1, 2 # positions of the two boxplots
y, h, col = (
sigma[mask_l].max() + 0.05,
0.02,
"black",
) # y position, height of the line, and color
axes[2].plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
axes[2].text((x1 + x2) * 0.5, y + h, "***", ha="center", va="bottom", color=col)
plt.tight_layout()
plt.show()
from scipy.stats import ttest_ind
sigma = np.array([edge["sigma"]["mean"] for edge in param_values]).reshape(
atlas_dim, atlas_dim
)
mask_h = density_matrix >= np.percentile(density_matrix[~np.isnan(density_matrix)], 40)
mask_l = density_matrix < np.percentile(density_matrix[~np.isnan(density_matrix)], 40)
alpha_mask_l = np.ones_like(sigma)
alpha_mask_l[mask_h] = (
0.4 # make the voxels that should not be largely biased semi-transparent
)
alpha_mask_h = np.ones_like(sigma)
alpha_mask_h[mask_l] = (
0.3 # make the voxels that should not be largely biased semi-transparent
)
# Perform a t-test comparing the estimated sigma values for low and high density connections
t_stat, p_value = ttest_ind(sigma[mask_l].flatten(), sigma[mask_h].flatten())
print(f"T-statistic: {t_stat:.4f}, P-value: {p_value:.12e}")
fig, axes = plt.subplots(1, 3, figsize=(20, 6))
# Heatmap with alpha_mask
sns.heatmap(sigma, cmap="viridis", cbar=True, alpha=alpha_mask_l, ax=axes[0])
axes[0].set_title(
r"A) Estimated $\sigma$ highlighted for low density connection",
fontsize=14,
loc="left",
)
axes[0].set_xlabel("Region Index", fontsize=13)
axes[0].set_ylabel("Region Index", fontsize=13)
# Heatmap with 1 - alpha_mask
sns.heatmap(sigma, cmap="viridis", cbar=True, alpha=alpha_mask_h, ax=axes[1])
axes[1].set_title(
r"B) Estimated $\sigma$ highlighted for high density connection",
fontsize=14,
loc="left",
)
axes[1].set_xlabel("Region Index", fontsize=13)
axes[1].set_ylabel("Region Index", fontsize=13)
# Boxplot for low and high density connections
# Boxplot for low and high density connections
axes[2].boxplot(
[sigma[mask_l].flatten(), sigma[mask_h].flatten()],
tick_labels=["Low Density Connections", "High Density Connections"],
)
axes[2].set_xticklabels(
["Low Density Connections", "High Density Connections"], fontsize=13
)
axes[2].set_title(
r"C) $\sigma$ estimation for low versus high density connections",
fontsize=14,
loc="left",
)
axes[2].set_ylabel(r"mean $\sigma$", fontsize=13)
# Add a line with *** to indicate statistical significance
x1, x2 = 1, 2 # positions of the two boxplots
y, h, col = (
sigma[mask_l].max() + 0.05,
0.02,
"black",
) # y position, height of the line, and color
axes[2].plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
axes[2].text((x1 + x2) * 0.5, y + h, "***", ha="center", va="bottom", color=col)
plt.tight_layout()
plt.show()
T-statistic: 313.0145, P-value: 0.000000000000e+00
In [10]:
Copied!
threshold = np.percentile(consistency_matrix[consistency_matrix > 0], 40)
pi0 = np.array([edge["pi0"]["mean"] for edge in param_values]).reshape(
atlas_dim, atlas_dim
)
nan_mask = np.isnan(density_matrix)
low_cons = (consistency_matrix <= threshold) & (~nan_mask)
high_cons = consistency_matrix > threshold
alpha_mask_lc = np.ones_like(pi0)
alpha_mask_hc = np.ones_like(pi0)
alpha_mask_na = np.ones_like(pi0)
alpha_mask_na[~nan_mask] = 0.4
alpha_mask_hc[~low_cons] = 0.4
alpha_mask_lc[~high_cons] = 0.4
# Perform a t-test comparing the estimated sigma values for low and high density connections
t_stat, p_value = ttest_ind(pi0[nan_mask].flatten(), pi0[high_cons].flatten())
print(
f"nan versus high consistency: T-statistic: {t_stat:.4f}, P-value: {p_value:.12e}"
)
t_stat, p_value = ttest_ind(pi0[nan_mask].flatten(), pi0[low_cons].flatten())
print(f"nan versus low consistency: T-statistic: {t_stat:.4f}, P-value: {p_value:.12e}")
t_stat, p_value = ttest_ind(pi0[low_cons].flatten(), pi0[high_cons].flatten())
print(
f"low versus high consistency: T-statistic: {t_stat:.4f}, P-value: {p_value:.12e}"
)
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
# Heatmap with alpha_mask
sns.heatmap(pi0, cmap="viridis", cbar=True, alpha=alpha_mask_na, ax=axes[1, 0])
axes[1, 0].set_title(
r"A) Estimated $\pi_0$ highlighted for truly absent connections",
fontsize=14,
loc="left",
)
axes[1, 0].set_xlabel("Region Index", fontsize=13)
axes[1, 0].set_ylabel("Region Index", fontsize=13)
sns.heatmap(pi0, cmap="viridis", cbar=True, alpha=alpha_mask_hc, ax=axes[0, 0])
axes[0, 0].set_title(
"B) Estimated $\pi_0$ highlighted for low consistency connections\npresent in at least one subject",
fontsize=14,
loc="left",
)
axes[0, 0].set_xlabel("Region Index", fontsize=13)
axes[0, 0].set_ylabel("Region Index", fontsize=13)
# Heatmap with 1 - alpha_mask
sns.heatmap(pi0, cmap="viridis", cbar=True, alpha=alpha_mask_lc, ax=axes[0, 1])
axes[0, 1].set_title(
r"C) Estimated $\pi_0$ highlighted for high consistency connections",
fontsize=14,
loc="left",
)
axes[0, 1].set_xlabel("Region Index", fontsize=13)
axes[0, 1].set_ylabel("Region Index", fontsize=13)
# Boxplot
axes[1, 1].boxplot(
[pi0[nan_mask].flatten(), pi0[low_cons].flatten(), pi0[high_cons].flatten()]
)
axes[1, 1].set_xticklabels(
[
"Connections absent\nin all subjects",
"Low-consistency\nconnections\npresent in at least\none subject",
"High-consistency\nconnections",
],
fontsize=13,
)
axes[1, 1].set_title(
r"D) $\pi_0$ estimation per connection grouping", fontsize=14, loc="left"
)
axes[1, 1].set_ylabel(r"mean $\pi_0$", fontsize=13)
# Add a line with *** to indicate statistical significance
x1, x2 = 1, 3 # positions of the two boxplots
y, h, col = (
pi0[~nan_mask].max() + 0.03,
0.01,
"black",
) # y position, height of the line, and color
axes[1, 1].plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
axes[1, 1].text((x1 + x2) * 0.5, y + h, f"***", ha="center", va="bottom", color=col)
x1, x2 = 2, 3 # positions of the two boxplots
y, h, col = (
pi0[~nan_mask].max() + 0.01,
0.01,
"black",
) # y position, height of the line, and color
axes[1, 1].plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
axes[1, 1].text((x1 + x2) * 0.5, y + h, f"***", ha="center", va="bottom", color=col)
x1, x2 = 1, 2 # positions of the two boxplots
y, h, col = (
pi0[~nan_mask].max() + 0.01,
0.01,
"black",
) # y position, height of the line, and color
axes[1, 1].plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
axes[1, 1].text((x1 + x2) * 0.5, y + h, f"***", ha="center", va="bottom", color=col)
plt.tight_layout()
plt.show()
threshold = np.percentile(consistency_matrix[consistency_matrix > 0], 40)
pi0 = np.array([edge["pi0"]["mean"] for edge in param_values]).reshape(
atlas_dim, atlas_dim
)
nan_mask = np.isnan(density_matrix)
low_cons = (consistency_matrix <= threshold) & (~nan_mask)
high_cons = consistency_matrix > threshold
alpha_mask_lc = np.ones_like(pi0)
alpha_mask_hc = np.ones_like(pi0)
alpha_mask_na = np.ones_like(pi0)
alpha_mask_na[~nan_mask] = 0.4
alpha_mask_hc[~low_cons] = 0.4
alpha_mask_lc[~high_cons] = 0.4
# Perform a t-test comparing the estimated sigma values for low and high density connections
t_stat, p_value = ttest_ind(pi0[nan_mask].flatten(), pi0[high_cons].flatten())
print(
f"nan versus high consistency: T-statistic: {t_stat:.4f}, P-value: {p_value:.12e}"
)
t_stat, p_value = ttest_ind(pi0[nan_mask].flatten(), pi0[low_cons].flatten())
print(f"nan versus low consistency: T-statistic: {t_stat:.4f}, P-value: {p_value:.12e}")
t_stat, p_value = ttest_ind(pi0[low_cons].flatten(), pi0[high_cons].flatten())
print(
f"low versus high consistency: T-statistic: {t_stat:.4f}, P-value: {p_value:.12e}"
)
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
# Heatmap with alpha_mask
sns.heatmap(pi0, cmap="viridis", cbar=True, alpha=alpha_mask_na, ax=axes[1, 0])
axes[1, 0].set_title(
r"A) Estimated $\pi_0$ highlighted for truly absent connections",
fontsize=14,
loc="left",
)
axes[1, 0].set_xlabel("Region Index", fontsize=13)
axes[1, 0].set_ylabel("Region Index", fontsize=13)
sns.heatmap(pi0, cmap="viridis", cbar=True, alpha=alpha_mask_hc, ax=axes[0, 0])
axes[0, 0].set_title(
"B) Estimated $\pi_0$ highlighted for low consistency connections\npresent in at least one subject",
fontsize=14,
loc="left",
)
axes[0, 0].set_xlabel("Region Index", fontsize=13)
axes[0, 0].set_ylabel("Region Index", fontsize=13)
# Heatmap with 1 - alpha_mask
sns.heatmap(pi0, cmap="viridis", cbar=True, alpha=alpha_mask_lc, ax=axes[0, 1])
axes[0, 1].set_title(
r"C) Estimated $\pi_0$ highlighted for high consistency connections",
fontsize=14,
loc="left",
)
axes[0, 1].set_xlabel("Region Index", fontsize=13)
axes[0, 1].set_ylabel("Region Index", fontsize=13)
# Boxplot
axes[1, 1].boxplot(
[pi0[nan_mask].flatten(), pi0[low_cons].flatten(), pi0[high_cons].flatten()]
)
axes[1, 1].set_xticklabels(
[
"Connections absent\nin all subjects",
"Low-consistency\nconnections\npresent in at least\none subject",
"High-consistency\nconnections",
],
fontsize=13,
)
axes[1, 1].set_title(
r"D) $\pi_0$ estimation per connection grouping", fontsize=14, loc="left"
)
axes[1, 1].set_ylabel(r"mean $\pi_0$", fontsize=13)
# Add a line with *** to indicate statistical significance
x1, x2 = 1, 3 # positions of the two boxplots
y, h, col = (
pi0[~nan_mask].max() + 0.03,
0.01,
"black",
) # y position, height of the line, and color
axes[1, 1].plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
axes[1, 1].text((x1 + x2) * 0.5, y + h, f"***", ha="center", va="bottom", color=col)
x1, x2 = 2, 3 # positions of the two boxplots
y, h, col = (
pi0[~nan_mask].max() + 0.01,
0.01,
"black",
) # y position, height of the line, and color
axes[1, 1].plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
axes[1, 1].text((x1 + x2) * 0.5, y + h, f"***", ha="center", va="bottom", color=col)
x1, x2 = 1, 2 # positions of the two boxplots
y, h, col = (
pi0[~nan_mask].max() + 0.01,
0.01,
"black",
) # y position, height of the line, and color
axes[1, 1].plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
axes[1, 1].text((x1 + x2) * 0.5, y + h, f"***", ha="center", va="bottom", color=col)
plt.tight_layout()
plt.show()
<>:32: SyntaxWarning: invalid escape sequence '\p'
<>:32: SyntaxWarning: invalid escape sequence '\p'
/tmp/ipykernel_132863/205204751.py:32: SyntaxWarning: invalid escape sequence '\p'
axes[0,0].set_title("B) Estimated $\pi_0$ highlighted for low consistency connections\npresent in at least one subject",fontsize=14, loc='left')
nan versus high consistency: T-statistic: 57.8565, P-value: 0.000000000000e+00 nan versus low consistency: T-statistic: -3.6577, P-value: 2.578655255259e-04 low versus high consistency: T-statistic: 80.9834, P-value: 0.000000000000e+00
In [11]:
Copied!
mu = np.array([edge["mu"] for edge in param_values]).reshape(atlas_dim, atlas_dim)
plt.figure(figsize=(10, 8))
sns.heatmap(param, cmap="viridis_r", cbar=True, vmax=10)
plt.title(f"Heatmap of mu")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
mu = np.array([edge["mu"] for edge in param_values]).reshape(atlas_dim, atlas_dim)
plt.figure(figsize=(10, 8))
sns.heatmap(param, cmap="viridis_r", cbar=True, vmax=10)
plt.title(f"Heatmap of mu")
plt.xlabel("Region Index")
plt.ylabel("Region Index")
plt.show()
In [12]:
Copied!
# Flatten mu and consistency_matrix, mask out nan values in consistency_matrix
consistency_flat = consistency_matrix.flatten()
mu_flat = mu.flatten()
# Mask out nan values (if any) in consistency_matrix
corr = np.corrcoef(mu_flat, consistency_flat)[0, 1]
print(f"Correlation between mu and consistency_matrix: {corr:.4f}")
# Flatten mu and consistency_matrix, mask out nan values in consistency_matrix
consistency_flat = consistency_matrix.flatten()
mu_flat = mu.flatten()
# Mask out nan values (if any) in consistency_matrix
corr = np.corrcoef(mu_flat, consistency_flat)[0, 1]
print(f"Correlation between mu and consistency_matrix: {corr:.4f}")
Correlation between mu and consistency_matrix: 0.2431