Datasets

In this analysis, we compare Eisenberg's dataset of interacting APRs (labeled as Template structures) with our own dataset of hexapeptide pairs, extracted from known amyloid structures in the PDB and repaired with FoldX. Eisenberg's dataset contains 181 structures; ours contains roughly 192,000 fragment pairs before filtering and tens of thousands after.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import norm
from Bio.PDB import PDBParser, PPBuilder

import math
import os

# RSA #
from Bio.PDB import PDBParser
from Bio.Data.IUPACData import protein_letters_3to1
from Bio.PDB.Polypeptide import is_aa
from Bio.PDB import SASA
from Bio.PDB.SASA import ShrakeRupley
from collections import defaultdict

from concurrent.futures import ProcessPoolExecutor, as_completed
from Bio.PDB import PDBParser, PDBExceptions
#######

# Shape Analysis #
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import AgglomerativeClustering

from sklearn.decomposition import PCA
import umap
import matplotlib.colors as mcolors

from skbio.stats.ordination import pcoa
from skbio import DistanceMatrix
##################

# Trees #
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import squareform
#########


sns.set_theme(style="whitegrid")

Some structures had to be standardized before being fed into FoldX's Repair command. We load our datasets and remove the non-standardized structures that have a standardized equivalent.

# Get rid of non-standardized files if standardized exists

def filter_non_standardized(df):
    df = df.copy()
    df['base_id'] = df['PDB_File'].apply(get_base_id)

    # Find which base IDs have both versions
    has_standardized = set(
        df[df['PDB_File'].str.endswith('_standardized_Repair.pdb')]['base_id']
    )

    is_standardized = set(
        df[df['PDB_File'].str.endswith('_standardized_Repair.pdb')]['PDB_File']
    )

    # Keep:
    # - all standardized files
    # - non-standardized files only if no standardized version exists
    filtered_df = df[
        (
            ~df['PDB_File'].str.endswith('_Repair.pdb') | 
            (df['PDB_File'].str.endswith('_Repair.pdb') & ~df['base_id'].isin(has_standardized)) |
            df['PDB_File'].isin(is_standardized)
        )
    ]

    # Drop helper column
    filtered_df = filtered_df.drop(columns=['base_id'])
    
    return filtered_df

# Extract base ID (before _Repair.pdb or _standardized_Repair.pdb)
def get_base_id(filename):
    if filename.endswith('_standardized_Repair.pdb'):
        return filename.replace('_standardized_Repair.pdb', '')
    elif filename.endswith('_Repair.pdb'):
        return filename.replace('_Repair.pdb', '')
    else:
        return filename
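
For illustration, here is how get_base_id maps the three filename variants we encounter (the pair IDs below are made up):

# Sanity check of get_base_id on hypothetical file names
assert get_base_id("1abc_pair_001_standardized_Repair.pdb") == "1abc_pair_001"
assert get_base_id("1abc_pair_001_Repair.pdb") == "1abc_pair_001"
assert get_base_id("1abc_pair_001.pdb") == "1abc_pair_001.pdb"  # anything else passes through unchanged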
# Stability datasets
print("Reading stability datasets...")

## Read template (Eisenberg's) dataset
template_df = pd.read_csv("./results/foldx_stability_results_templates.csv")
## Read combined stability results for unrepaired and repaired fragment pairs and split them
combined_stability_df = pd.read_csv("./results/foldx_stability_results_fragment_pairs.csv")
extracted_df = combined_stability_df[~combined_stability_df['PDB_File'].str.endswith('Repair.pdb')]
repaired_df = combined_stability_df[combined_stability_df['PDB_File'].str.endswith('Repair.pdb')]
print("Loaded\n")

## Remove unwanted elements
print(f"Removing unwanted elements from the datasets...")
print(f"Extracted: {len(extracted_df[extracted_df['PDB_File'].str.endswith('standardized.pdb')])} standardized extracts out of {len(extracted_df)}")
print(f"Repaired: {len(repaired_df[repaired_df['PDB_File'].str.endswith('standardized_Repair.pdb')])} standardized repairs out of {len(repaired_df)}")
extracted_df = extracted_df[~extracted_df['PDB_File'].str.endswith('standardized.pdb')] # Remove standardized unrepaired duplicates
repaired_df = filter_non_standardized(repaired_df) # remove Repairs with standardized duplicates
print(f"Extracted: {len(extracted_df)} dihexapeptide fragment pairs after initial filtering")
print(f"Repaired: {len(repaired_df)} dihexapeptide fragment pairs after initial filtering")
Reading stability datasets...
Loaded

Removing unwanted elements from the datasets...
Extracted: 1508 standardized extracts out of 193431
Repaired: 1508 standardized repairs out of 193431
Extracted: 191923 dihexapeptide fragment pairs after initial filtering
Repaired: 191923 dihexapeptide fragment pairs after initial filtering

We will work with repaired_df (or a subsample of it), which is our extracted dataset after repair.

# Repaired dataset's info
repaired_df.info()
<class 'pandas.core.frame.DataFrame'>
Index: 191923 entries, 1 to 386861
Data columns (total 8 columns):
 #   Column                 Non-Null Count   Dtype  
---  ------                 --------------   -----  
 0   PDB_File               191923 non-null  object 
 1   Total_Energy           191923 non-null  float64
 2   Backbone_HBond         191923 non-null  float64
 3   Sidechain_HBond        191923 non-null  float64
 4   Van_der_Waals          191923 non-null  float64
 5   Electrostatics         191923 non-null  float64
 6   Solvation_Polar        191923 non-null  float64
 7   Solvation_Hydrophobic  191923 non-null  float64
dtypes: float64(7), object(1)
memory usage: 13.2+ MB

Here's a quick summary of all three datasets (including an estimate of total energy):

# Print how many unique PDB ids are in each dataframe
pdb_ids_t = template_df["PDB_File"].str[:4].unique()
pdb_ids_e = extracted_df["PDB_File"].str[:4].unique()
pdb_ids_r = repaired_df["PDB_File"].str[:4].unique()

K = pd.DataFrame({
    "Dataframe": ["Template", "Extracted", "Repaired"],
    "Unique PDB IDs": [len(pdb_ids_t), len(pdb_ids_e), len(pdb_ids_r)],
    "Total Entries": [len(template_df), len(extracted_df), len(repaired_df)],
    "Mean Total Energy": [template_df["Total_Energy"].mean(), extracted_df["Total_Energy"].mean(), repaired_df["Total_Energy"].mean()],
    "Median Total Energy": [template_df["Total_Energy"].median(), extracted_df["Total_Energy"].median(), repaired_df["Total_Energy"].median()],
    "Max Total Energy": [template_df["Total_Energy"].max(), extracted_df["Total_Energy"].max(), repaired_df["Total_Energy"].max()],
    "Min Total Energy": [template_df["Total_Energy"].min(), extracted_df["Total_Energy"].min(), repaired_df["Total_Energy"].min()],
})

K.head()
Dataframe Unique PDB IDs Total Entries Mean Total Energy Median Total Energy Max Total Energy Min Total Energy
0 Template 80 181 12.377997 14.6891 57.066 -44.0693
1 Extracted 500 191923 49.973542 48.2727 538.420 -30.4614
2 Repaired 500 191923 30.177946 29.5613 270.050 -44.7015

During extraction, information about each structure was recorded in a protein_properties CSV. Here are a few sanity checks that the pipeline worked as expected.

# Protein properties dataset
print("Loading protein properties dataset...")

protein_properties = pd.read_csv("./results/protein_data_pdb_amyloid_structures.csv")
protein_properties['pdb_file'] = protein_properties['pdb_file'].str[:4]
protein_properties = protein_properties.rename(columns={"pdb_file": "PDB ID"})
print("Loaded\n")

## Simple qualities
print(f"We have {len(protein_properties[protein_properties['num_of_pairs'] == 0])}/{len(protein_properties)} proteins with zero pairs.")
print(f"{len(protein_properties[protein_properties['avg_num_of_layers'] != 5])} proteins have an average number of layers different from 5.\n")

## Correction info
corrected_succ = protein_properties[(protein_properties['num_of_layers_before'] != 5) & (protein_properties['avg_num_of_layers'] == 5)]['PDB ID'].unique()
corrected_unsucc = protein_properties[protein_properties['avg_num_of_layers'] != 5]['PDB ID'].unique()
non_five_layer_proteins = protein_properties[protein_properties['avg_num_of_layers'] != 5]['PDB ID']
num_pairs_non_five = len(repaired_df[(repaired_df['PDB_File'].str[:4]).isin(non_five_layer_proteins)])
print(f"{len(corrected_succ)+len(corrected_unsucc)} proteins needed correction, {len(corrected_succ)} were successful, {len(corrected_unsucc)} were not!")
print(f"{len(repaired_df[(repaired_df['PDB_File'].str[:4]).isin(non_five_layer_proteins)])}/{len(repaired_df)} fragments come from NON-5-layered structs!")
print(f"{len(protein_properties[(protein_properties['avg_num_of_layers'] == 5) & (protein_properties['num_of_pairs'] > 0)])} 5-layered proteins contributed with {len(repaired_df) - num_pairs_non_five}/{len(repaired_df)} pairs.")
Loading protein properties dataset...
Loaded

We have 2/502 proteins with zero pairs.
8 proteins have an average number of layers different from 5.

344 proteins needed correction, 336 were successful, 8 were not!
12592/191923 fragments come from NON-5-layered structs!
492 5-layered proteins contributed with 179331/191923 pairs.

As we can see, a few of our structures failed the layer correction, and some of them still contributed fragment pairs to our datasets. We remove those pairs.

# Protein properties dataset's info
protein_properties.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 502 entries, 0 to 501
Data columns (total 11 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   PDB ID                 502 non-null    object 
 1   num_of_polypeptides    502 non-null    int64  
 2   num_of_windows         502 non-null    int64  
 3   num_of_stacks          502 non-null    int64  
 4   avg_num_of_layers      502 non-null    float64
 5   num_of_pairs           502 non-null    int64  
 6   num_of_pps_before      502 non-null    int64  
 7   num_of_stacks_before   502 non-null    int64  
 8   num_of_layers_before   502 non-null    float64
 9   num_of_shorter_layers  502 non-null    int64  
 10  antiparallel_count     502 non-null    int64  
dtypes: float64(2), int64(8), object(1)
memory usage: 43.3+ KB
# Remove fragment pairs coming from NON-5-layered structures
print("Removing every fragment from the stability dataset's which comes from NON-5-layered structures...")

repaired_df = repaired_df[~(repaired_df['PDB_File'].str[:4]).isin(non_five_layer_proteins)]
extracted_df = extracted_df[~(extracted_df['PDB_File'].str[:4]).isin(non_five_layer_proteins)]
print(f"{len(extracted_df)} is the length of the Extracted dataset after correction")
print(f"{len(repaired_df)} is the length of the Repaired dataset after correction.")
Removing every fragment that comes from a NON-5-layered structure from the stability datasets...
179331 is the length of the Extracted dataset after correction
179331 is the length of the Repaired dataset after correction.

To keep the repaired_df dataset manageable during computationally heavy analyses, we take a 20% sample.

# Create a 20% sample of the repaired dataset for certain heavy analyses
sample_repaired = repaired_df.sample(frac=0.20, random_state=42)
print(f"Sampled {len(sample_repaired)} rows from repaired_df for later analyses.")

paths = set(sample_repaired["PDB_File"].unique())
with open("pdb_file_list_test.txt", "w") as f:
    for pdb in paths:
        f.write(pdb + "\n")
# Then use:
# rsync -av --progress --partial --append-verify --files-from=pdb_file_list_test.txt meta:/storage/praha1/home/tobiasma/switch-lab/amyloid-interactions/fragment_pairs/ ./fragment_pairs/
Sampled 35866 rows from repaired_df for later analyses.

Our catch-all extraction condition needs further filtering. Let's filter the structures on distance and keep only those which (a minimal sketch of the check follows the list):

  • have at least 5 residues in interaction (within one layer) at <= 4.5 Angstroms
  • have at least 2 interacting residues on each side of the layer
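
The filter itself was run outside this notebook (its results are read from the fragment_pairs_filtered_five_two*.txt files below); the following is only a minimal sketch of the distance check, assuming for illustration that one interacting layer is formed by two chains named A and B:

# Minimal sketch of the distance criterion (illustration only; chain IDs are an assumption)
from Bio.PDB import PDBParser

def passes_distance_filter(pdb_path, chain_1="A", chain_2="B", cutoff=4.5):
    model = PDBParser(QUIET=True).get_structure("", pdb_path)[0]
    side_1, side_2 = set(), set()
    for res1 in model[chain_1]:
        for res2 in model[chain_2]:
            # any heavy-atom pair within the cutoff counts as a contact
            if any(a1 - a2 <= cutoff
                   for a1 in res1 if a1.element != "H"
                   for a2 in res2 if a2.element != "H"):
                side_1.add(res1.id[1])
                side_2.add(res2.id[1])
    # at least 5 interacting residues within the layer, at least 2 on each side
    return len(side_1) + len(side_2) >= 5 and len(side_1) >= 2 and len(side_2) >= 2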
# Filter the sample
stricter_paths = []
with open("fragment_pairs_filtered_five_two_sample.txt", "r") as f:
    for line in f:
        name = line.rstrip('\n').split('/')[-1]
        stricter_paths.append({"PDB_File": name})

filter_helper = pd.DataFrame(stricter_paths)
filter_helper = filter_helper[filter_helper["PDB_File"].isin(set(sample_repaired["PDB_File"]))]
sample_repaired = pd.merge(sample_repaired, filter_helper, on=["PDB_File"], how='right')
sample_repaired = sample_repaired.dropna()
print(f"Correctly filtered on a condition, real sample has {len(sample_repaired)} structures")
sample_repaired.head()
Correctly filtered on a condition, real sample has 12304 structures
PDB_File Total_Energy Backbone_HBond Sidechain_HBond Van_der_Waals Electrostatics Solvation_Polar Solvation_Hydrophobic
0 7qkj_pair_765_Repair.pdb 18.45160 -30.7581 -8.22142 -42.7342 3.79509 61.2674 -53.1292
1 8olq_pair_012_Repair.pdb 50.32270 -36.4475 -31.34510 -59.5809 11.08980 101.5600 -67.5650
2 8ci8_pair_225_Repair.pdb 12.05720 -36.2584 -6.14683 -67.4686 4.42839 89.9917 -93.9525
3 7nrq_pair_279_Repair.pdb 49.53490 -31.8512 -9.81061 -45.1977 3.33110 86.0854 -51.5352
4 7qkk_pair_121_Repair.pdb 8.61727 -46.0144 -19.36770 -57.6874 6.74647 89.7038 -71.9427
# Filter the whole repaired_df as well
stricter_paths = []
with open("fragment_pairs_filtered_five_two.txt", "r") as f:
    for line in f:
        name = line.rstrip('\n').split('/')[-1]
        stricter_paths.append({"PDB_File": name})

filter_repaired = pd.DataFrame(stricter_paths)
repaired_df = pd.merge(repaired_df, filter_repaired, on=["PDB_File"], how='right')
repaired_df = repaired_df.dropna()
print(f"Correctly filtered on a condition, repaired df has {len(repaired_df)} structures now")
repaired_df.head()
Correctly filtered on a condition, repaired df has 61951 structures now
PDB_File Total_Energy Backbone_HBond Sidechain_HBond Van_der_Waals Electrostatics Solvation_Polar Solvation_Hydrophobic
0 8v1n_pair_342_Repair.pdb 75.12660 -24.4209 -10.17090 -35.7653 9.865810 81.5012 -34.3863
1 8otg_pair_073_Repair.pdb 7.37734 -48.2354 -16.49210 -53.7371 5.152580 79.6257 -68.4023
2 7xo2_pair_150_Repair.pdb 38.10280 -29.0117 -8.14856 -39.2094 2.617550 63.2756 -48.1347
3 7xo3_pair_287_Repair.pdb 30.79310 -32.1495 -10.58240 -46.0042 -0.740443 73.4309 -57.7334
4 8ot4_pair_130_Repair.pdb 75.06350 -16.4509 -20.34710 -43.9366 10.024300 95.2448 -44.8586
with open("fragment_pairs_filtered_for_clustering.txt", "w") as f:
    for file in repaired_df['PDB_File'].unique():
        f.write(file + "\n")

Stability

The first thing to check is whether our dataset's total energy estimates (from FoldX's Stability command) match Eisenberg's.

# Do a histogram of total energy for both dataframes for comparison
template_label = f"Template (n={len(template_df)})"
extracted_label = f"Extracted (n={len(extracted_df)})"
repaired_label = f"Repaired (n={len(repaired_df)})"

all_stability_data = np.concatenate([
    template_df["Total_Energy"].values,
    #extracted_df["Total_Energy"].values,
    repaired_df["Total_Energy"].values
])

# Define bin edges
bins = np.linspace(all_stability_data.min(), all_stability_data.max(), 106)

plt.figure(figsize=(10, 6))
sns.histplot(template_df["Total_Energy"], bins=bins, color="mediumorchid", label=template_label, kde=False, stat="density")
#sns.histplot(extracted_df["Total_Energy"], bins=50, color="red", label=extracted_label, kde=True, stat="density")
sns.histplot(repaired_df["Total_Energy"], bins=bins, color="cornflowerblue", label=repaired_label, kde=False, stat="density")

plt.xlim(-50, 110)
plt.xlabel("Total Energy")
plt.ylabel("Probability Density")
plt.title("Total Energy Histograms of Template and Repaired fragment pairs (Normalized)")
plt.grid(False)
plt.axvline(0, color="lightgray", linewidth=2, linestyle="-")
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.gca().set_yticks([0, 0.01, 0.02])

plt.legend()
plt.savefig("plots/total_energy_histogram_fragments.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/total_energy_histogram_fragments.pdf", bbox_inches="tight")
plt.savefig("plots/total_energy_histogram_fragments.svg", bbox_inches="tight")
plt.show()
# Best fragments
repaired_df.nsmallest(50, 'Total_Energy').head()
PDB_File Total_Energy Backbone_HBond Sidechain_HBond Van_der_Waals Electrostatics Solvation_Polar Solvation_Hydrophobic
4514 8pkg_pair_302_Repair.pdb -44.7015 -51.9016 -1.37386 -69.9902 4.036450 70.8653 -107.5340
7157 8q98_pair_008_Repair.pdb -39.7086 -51.9347 -18.13530 -70.7382 1.057940 76.6335 -103.1080
29008 8q98_pair_309_Repair.pdb -39.7062 -51.6920 -15.25990 -70.4398 1.018250 76.2852 -102.5010
44807 7r5h_pair_010_Repair.pdb -39.6906 -50.3670 -12.58230 -70.4432 0.979359 74.7961 -102.7270
13972 8q98_pair_305_Repair.pdb -38.9145 -52.5437 -20.16940 -66.3700 0.840698 73.8202 -94.4693
# Worst fragments
repaired_df.nlargest(50, 'Total_Energy').head()
PDB_File Total_Energy Backbone_HBond Sidechain_HBond Van_der_Waals Electrostatics Solvation_Polar Solvation_Hydrophobic
16336 8bg9_pair_084_Repair.pdb 215.356 -34.8066 -10.9174 -51.8010 44.2806 110.281 -59.6230
28803 8bg9_pair_091_Repair.pdb 215.356 -34.8066 -10.9174 -51.8010 44.2806 110.281 -59.6230
11842 8bg9_pair_104_Repair.pdb 210.551 -35.2233 -11.3284 -52.7205 44.2306 109.575 -61.5739
42611 8bg9_pair_097_Repair.pdb 210.551 -35.2233 -11.3284 -52.7205 44.2306 109.575 -61.5739
8658 8bg9_pair_098_Repair.pdb 201.444 -37.8533 -17.3518 -50.9960 34.9611 110.994 -55.5531

Ramachandran

To quickly check whether we found new shapes that are not present in Eisenberg's dataset, we draw Ramachandran plots.

# Create a Ramachandran plot of every hexapeptide within the repaired dataset's sample
def extract_phi_psi_angles(folder, sample=None):
    parser = PDBParser(QUIET=True)
    ppb = PPBuilder()
    phi_psi_all = []
    angle_df = []

    for pdb_file in os.listdir(folder):
        if not pdb_file.endswith(".pdb") or (sample and pdb_file not in sample):
            continue

        structure = parser.get_structure(pdb_file, os.path.join(folder, pdb_file))
        for pp in ppb.build_peptides(structure):
            phi_psi = pp.get_phi_psi_list()
            for phi, psi in phi_psi:
                if phi is not None and psi is not None:  # skip None values (termini)
                    angle_df.append({
                        "PDB_File": pdb_file,
                        "phi": math.degrees(phi),
                        "psi": math.degrees(psi)
                    })
                    phi_psi_all.append((math.degrees(phi), math.degrees(psi)))
    angle_df = pd.DataFrame(angle_df)
    angle_df.to_csv(f"./results/phi_psi_angles_{folder[:-1]}.csv", index=False)
    return phi_psi_all

# === Extract phi/psi angles if not computed already ===
if os.path.exists("./results/phi_psi_angles_sample.npz") and os.path.exists("./results/phi_psi_angles_templates.npz"):
    angles1 = np.load("./results/phi_psi_angles_sample.npz")["angles"]
    angles2 = np.load("./results/phi_psi_angles_templates.npz")["angles"]
else:
    angles1 = extract_phi_psi_angles("fragment_pairs/", set(sample_repaired["PDB_File"].unique())) # only from the sample
    angles2 = extract_phi_psi_angles("templates/")
    np.savez_compressed("./results/phi_psi_angles_sample.npz", angles=np.array(angles1))
    np.savez_compressed("./results/phi_psi_angles_templates.npz", angles=np.array(angles2))
fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)

phi1, psi1 = zip(*angles1)
axs[0].scatter(phi1, psi1, s=1, alpha=0.008, color="cornflowerblue")
axs[0].set_title(f"Ramachandran Plot - Extracted (n={len(phi1)})")
axs[0].set_xlabel("ϕ")
axs[0].set_ylabel("ψ")
axs[0].grid(False)
for spine in axs[0].spines.values():
    spine.set_color('gray')
axs[0].axvline(0, color="lightgray", linewidth=2, linestyle="-")
axs[0].axhline(0, color="lightgray", linewidth=2, linestyle="-")

phi2, psi2 = zip(*angles2)
axs[1].scatter(phi2, psi2, s=1, alpha=0.1, color="mediumorchid")
axs[1].set_title(f"Ramachandran Plot - Templates (n={len(phi2)})")
axs[1].set_xlabel("ϕ")
axs[1].grid(False)
for spine in axs[1].spines.values():
    spine.set_color('gray')
axs[1].axvline(0, color="lightgray", linewidth=2, linestyle="-")
axs[1].axhline(0, color="lightgray", linewidth=2, linestyle="-")

# Set shared limits
for ax in axs:
    ax.set_xlim(-180, 180)
    ax.set_ylim(-180, 180)

plt.savefig("plots/ramachandran.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/ramachandran.pdf", bbox_inches="tight")
plt.savefig("plots/ramachandran.svg", bbox_inches="tight")
plt.tight_layout()
plt.show()
# Convert to numpy arrays
phi1, psi1 = np.array([a[0] for a in angles1]), np.array([a[1] for a in angles1])
phi2, psi2 = np.array([a[0] for a in angles2]), np.array([a[1] for a in angles2])

# Create subplots
fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)

# Define shared histogram parameters
bins = 75  # adjust to control smoothness
range_limits = [[-180, 180], [-180, 180]]

# Left plot: Extracted
h1 = axs[0].hist2d(phi1, psi1, bins=bins, range=range_limits,
                   cmap="Blues", cmin=1)
axs[0].set_title(f"Ramachandran Density - Extracted (n={len(phi1)})")
axs[0].set_xlabel("ϕ")
axs[0].set_ylabel("ψ")
axs[0].axvline(0, color="lightgray", linewidth=2)
axs[0].axhline(0, color="lightgray", linewidth=2)
axs[0].set_xlim(-180, 180)
axs[0].set_ylim(-180, 180)
axs[0].grid(False)
for spine in axs[0].spines.values():
    spine.set_color('gray')
fig.colorbar(h1[3], ax=axs[0], label='Density')

# Right plot: Templates
h2 = axs[1].hist2d(phi2, psi2, bins=bins, range=range_limits,
                   cmap="Purples", cmin=1)
axs[1].set_title(f"Ramachandran Density - Templates (n={len(phi2)})")
axs[1].set_xlabel("ϕ")
axs[1].axvline(0, color="lightgray", linewidth=2)
axs[1].axhline(0, color="lightgray", linewidth=2)
axs[1].set_xlim(-180, 180)
axs[1].set_ylim(-180, 180)
axs[1].grid(False)
for spine in axs[1].spines.values():
    spine.set_color('gray')
fig.colorbar(h2[3], ax=axs[1], label='Density')

plt.tight_layout()
#plt.savefig("plots/ramachandran_density.png", dpi=600, bbox_inches="tight")
#plt.savefig("plots/ramachandran_density.pdf", bbox_inches="tight")
#plt.savefig("plots/ramachandran_density.svg", bbox_inches="tight")
plt.show()
# Convert to numpy arrays
phi1, psi1 = np.array([a[0] for a in angles1]), np.array([a[1] for a in angles1])
phi2, psi2 = np.array([a[0] for a in angles2]), np.array([a[1] for a in angles2])

# Create subplots
fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)

# Common axis limits
lims = (-180, 180)

# Left plot: Extracted
sns.kdeplot(
    x=phi1, y=psi1, 
    fill=True, cmap="Blues", 
    bw_adjust=0.8,  # smaller = more detailed, larger = smoother
    levels=10, 
    thresh=0.01,  # cut off very low-density noise
    gridsize=50,
    ax=axs[0]
)
axs[0].set_title(f"Ramachandran Density - Extracted (n={len(phi1)})")
axs[0].set_xlim(lims)
axs[0].set_ylim(lims)
axs[0].axvline(0, color="lightgray", linewidth=2)
axs[0].axhline(0, color="lightgray", linewidth=2)
axs[0].set_xlabel("ϕ")
axs[0].set_ylabel("ψ")
axs[0].grid(False)
for spine in axs[0].spines.values():
    spine.set_color('gray')

# Right plot: Templates
sns.kdeplot(
    x=phi2, y=psi2, 
    fill=True, cmap="Purples", 
    bw_adjust=0.8, 
    levels=10, 
    thresh=0.01,
    gridsize=50,
    ax=axs[1]
)
axs[1].set_title(f"Ramachandran Density - Templates (n={len(phi2)})")
axs[1].set_xlim(lims)
axs[1].set_ylim(lims)
axs[1].axvline(0, color="lightgray", linewidth=2)
axs[1].axhline(0, color="lightgray", linewidth=2)
axs[1].set_xlabel("ϕ")
axs[1].grid(False)
for spine in axs[1].spines.values():
    spine.set_color('gray')

plt.tight_layout()
# plt.savefig("plots/ramachandran_sns_kde.png", dpi=600, bbox_inches="tight")
# plt.savefig("plots/ramachandran_sns_kde.pdf", bbox_inches="tight")
# plt.savefig("plots/ramachandran_sns_kde.svg", bbox_inches="tight")
plt.show()
fig, ax = plt.subplots(figsize=(7, 7))

ax.axvline(0, color="lightgray", linewidth=2, linestyle="-", zorder=0)
ax.axhline(0, color="lightgray", linewidth=2, linestyle="-", zorder=0)

# Extracted dataset (background)
phi1, psi1 = zip(*angles1)
ax.scatter(phi1, psi1, s=2, alpha=0.008, color="cornflowerblue", label=f"Extracted (n={len(phi1)})", edgecolors="none")

# Templates dataset (overlay)
phi2, psi2 = zip(*angles2)
ax.scatter(phi2, psi2, s=2, alpha=1, color="violet", label=f"Templates (n={len(phi2)})", edgecolors="none")
# Changed from mediumorchid / plum

# Styling
ax.set_title("Ramachandran Plot - Extracted vs Templates", fontsize=12)
ax.set_xlabel("ϕ")
ax.set_ylabel("ψ")
ax.set_xlim(-180, 180)
ax.set_ylim(-180, 180)
ax.grid(False)

ax.set_aspect("equal", adjustable="box")

for spine in ax.spines.values():
    spine.set_color('gray')

ax.legend(frameon=False, fontsize=10, loc="upper right")

plt.tight_layout()
plt.savefig("plots/ramachandran_overlay.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/ramachandran_overlay.pdf", bbox_inches="tight")
plt.savefig("plots/ramachandran_overlay.svg", bbox_inches="tight")
plt.show()
angle_temp_df = pd.read_csv("./results/phi_psi_angles_templates.csv")
angle_temp_df.head()
PDB_File phi psi
0 4r0u_a_1.pdb -119.436120 120.686203
1 4r0u_a_1.pdb -134.309295 128.800200
2 4r0u_a_1.pdb -128.759595 127.020193
3 4r0u_a_1.pdb -132.218864 123.946289
4 4r0u_a_1.pdb -119.436119 120.686197
alpha_hel = angle_temp_df[(angle_temp_df["phi"] < -50) & (angle_temp_df["psi"] < 0) & (angle_temp_df["phi"] > -100) & (angle_temp_df["psi"] > -120)]
alpha_hel_left = angle_temp_df[(angle_temp_df["phi"] > 50) & (angle_temp_df["phi"] < 100) & (angle_temp_df["psi"] > 0) & (angle_temp_df["psi"] < 100)]
val_counts_alpha = alpha_hel['PDB_File'].value_counts()
val_counts_alpha_left = alpha_hel_left['PDB_File'].value_counts()
val_counts_alpha_left
PDB_File
4znn_a_1.pdb    10
3ftl_a_1.pdb    10
3ftk_1.pdb      10
5whp.pdb        10
4znn_b_4.pdb    10
3ftl_a_2.pdb    10
4znn_a_2.pdb    10
4znn_b_2.pdb    10
3ftk_2.pdb      10
4znn_b_3.pdb    10
4znn_a_3.pdb    10
3ftl_b_1.pdb    10
3ftl_b_2.pdb    10
4znn_b_1.pdb    10
4znn_a_4.pdb    10
3pzz.pdb         4
Name: count, dtype: int64

Sequence detail

To assess each amino acid's contribution to the stability of its structure, we use FoldX's SequenceDetail command.
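
The per-residue energy tables were produced upstream of this notebook; as a rough sketch of that step (the foldx binary name, flag spellings, and directory layout are assumptions, not taken from this analysis):

# Rough sketch of the upstream SequenceDetail run (assumed FoldX 5 CLI flags)
import subprocess

def run_sequence_detail(pdb_file, pdb_dir="./fragment_pairs", out_dir="./foldx_seqdetail"):
    subprocess.run(
        ["foldx", "--command=SequenceDetail", f"--pdb={pdb_file}",
         f"--pdb-dir={pdb_dir}", f"--output-dir={out_dir}"],
        check=True,
    )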

# Load sequence detail
print ("Loading sequence detail datasets...")
seqdetail_templates = pd.read_csv("./results/foldx_seqdetail_results_templates.csv")
seqdetail = pd.read_csv("./results/foldx_seqdetail_results_fragment_pairs.csv")
## Get rid of non-standardized files
seqdetail = filter_non_standardized(seqdetail)
print("Loaded")

print("Head of Sequence detail dataset of Repaired fragments")
seqdetail.head()
Loading sequence detail datasets...
Loaded
Head of Sequence detail dataset of Repaired fragments
PDB_File Three_Letter Chain Residue_Number Total_Energy BackHBond SideHBond
0 ./SD_2n0a_pair_2691_Repair.pdb ALA C 85 -0.032793 0.000000 0.0
1 ./SD_2n0a_pair_2691_Repair.pdb GLY C 86 0.657411 0.000000 0.0
2 ./SD_8oq4_pair_180_Repair.pdb GLN F 37 0.643152 0.000000 0.0
3 ./SD_2n0a_pair_5190_Repair.pdb H2S F 50 0.514245 -0.485031 0.0
4 ./SD_8uq7_pair_304_Repair.pdb GLY A 335 0.397621 -0.596954 0.0
# Check how many residues in seqdetail come from pairs of NON-5-layer proteins.
percentage=round(100*len(seqdetail[(seqdetail['PDB_File'].str[5:9]).isin(non_five_layer_proteins)])/len(seqdetail), 4)
print(f"{len(seqdetail[(seqdetail['PDB_File'].str[5:9]).isin(non_five_layer_proteins)])}/{len(seqdetail)} ({percentage}%) elements come from NON-5-layered structures")

# Remove unwanted elements from the dataframe
seqdetail = seqdetail[~(seqdetail['PDB_File'].str[5:9]).isin(non_five_layer_proteins)]
print(f"{len(seqdetail)} is the length of the seqdetail dataset after correction.")
151104/10910230 (1.385%) elements come from NON-5-layered structures
10759126 is the length of the seqdetail dataset after correction.
# Distance filtering
seqdetail = seqdetail[(seqdetail['PDB_File'].str[5:]).isin(set(repaired_df["PDB_File"]))]
print(f"After filtering based on length, we have {len(seqdetail)} residues")
After distance filtering, we have 3716740 residues

Let's check the dataset for redundancies in sequence.

# TODO: Use already known functions (to get the ordered list of atoms - and then residues)
print("TODO")
TODO
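
One possible approach, sketched directly from the seqdetail table (assumptions: residues of a pair can be ordered by chain and residue number, and FoldX's non-standard codes H1S/H2S/PTR map to H/H/Y):

# Sketch of a sequence-redundancy check built from the seqdetail table (not applied downstream)
three_to_one = {k.upper(): v for k, v in protein_letters_3to1.items()}
three_to_one.update({"H1S": "H", "H2S": "H", "PTR": "Y"})  # assumed mapping for FoldX histidine variants and phosphotyrosine

def fragment_sequence(group):
    # concatenate all chains of a pair in chain/residue order; a rough proxy for its hexapeptide sequences
    group = group.sort_values(["Chain", "Residue_Number"])
    return "".join(three_to_one.get(res, "X") for res in group["Three_Letter"])

pair_sequences = seqdetail.groupby("PDB_File").apply(fragment_sequence)
print(f"{(pair_sequences.value_counts() > 1).sum()} concatenated sequences occur in more than one fragment pair")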

After the initial filtering of the seqdetail dataset, we compare the residue distributions between the sequence detail datasets of the repaired hexapeptide pairs (our repaired extracted fragments) and the template structures (Eisenberg's).

seqdetail['Three_Letter'].value_counts()
Three_Letter
GLY    600140
VAL    509335
ALA    356712
LYS    318728
THR    261343
SER    250223
ILE    181325
GLN    179907
LEU    165575
ASN    157148
GLU    141643
PRO    102752
ASP     79595
TYR     75767
H1S     71267
PHE     65764
H2S     55404
MET     53825
CYS     37430
ARG     25908
HIS     18514
TRP      7930
PTR       505
Name: count, dtype: int64
# FIXME Get rid of hetatoms? e.g. ZN
print("Residues and het-atoms in the template dataset")
seqdetail_templates['Three_Letter'].value_counts()
Residues and het-atoms in the template dataset
Three_Letter
GLY    1930
VAL    1760
ALA     950
THR     920
ASN     850
SER     840
LEU     840
ILE     800
TYR     370
PHE     330
GLN     310
MET     300
LYS     240
GLU     130
H1S     111
ZN       70
ASP      40
HIS      32
PRO      30
H2S      27
TRP      20
Name: count, dtype: int64
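
A small filter along the lines of the FIXME above could drop het-groups such as ZN from the template table; a sketch (not applied here, so the downstream template counts still include these rows):

# Possible het-group filter for the template seqdetail (sketch only, not applied in this notebook)
het_codes = {"ZN"}  # assumption: ZN is the only het-group code left in this table
seqdetail_templates_no_het = seqdetail_templates[~seqdetail_templates["Three_Letter"].isin(het_codes)]
print(f"{len(seqdetail_templates) - len(seqdetail_templates_no_het)} het-group rows would be removed")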

Total Energies

Let's look at the total energy of each individual residue, again comparing our dataset with Eisenberg's.

desc = seqdetail['Total_Energy'].describe()
print(desc.apply(lambda x: f"{x:.6f}"))
count    3716740.000000
mean           0.489640
std            1.178426
min           -4.176860
25%           -0.152940
50%            0.527413
75%            1.212220
max           14.548800
Name: Total_Energy, dtype: object
desc = seqdetail_templates['Total_Energy'].describe()
print(desc.apply(lambda x: f"{x:.6f}"))
count    10906.000000
mean         0.205430
std          1.178158
min         -4.129810
25%         -0.446808
50%          0.259494
75%          0.908089
max          7.062650
Name: Total_Energy, dtype: object

Our best (lowest total energy) and worst (highest) residues:

# Best residues according to seqdetail energy estimation
best_repaired = repaired_df.nsmallest(100, "Total_Energy").head() # TODO

seqdetail.nsmallest(100, "Total_Energy").head()
PDB_File Three_Letter Chain Residue_Number Total_Energy BackHBond SideHBond
3577542 ./SD_8r47_pair_205_Repair.pdb ARG C 28 -4.17686 -0.247444 -1.69147
7354890 ./SD_8r47_pair_199_Repair.pdb ARG C 28 -4.17686 -0.247444 -1.69147
4467223 ./SD_7u11_pair_006_Repair.pdb TYR A 125 -4.08744 -1.371750 -1.56968
650860 ./SD_9fnb_pair_004_Repair.pdb TYR B 125 -3.96797 -1.542940 -1.04820
7055515 ./SD_9fnb_pair_005_Repair.pdb TYR B 125 -3.95643 -1.549240 -1.04820
# Worst residues according to seqdetail energy estimation
worst_repaired = repaired_df.nlargest(100, "Total_Energy").head() # TODO

seqdetail.nlargest(100, "Total_Energy").head()
PDB_File Three_Letter Chain Residue_Number Total_Energy BackHBond SideHBond
4182516 ./SD_8bg9_pair_085_Repair.pdb GLU I 22 14.5488 -0.687146 0.0
7962205 ./SD_8bg9_pair_092_Repair.pdb GLU I 22 14.5488 -0.687146 0.0
805688 ./SD_8bg9_pair_091_Repair.pdb GLU A 22 14.3261 -0.622249 0.0
1120332 ./SD_8bg9_pair_084_Repair.pdb GLU A 22 14.3261 -0.622249 0.0
4180836 ./SD_8bg9_pair_085_Repair.pdb GLU A 22 14.2986 -0.626909 0.0

To see which residues make our worst extracted hexapeptide pairs bad, we take a sample consisting of the fragments whose total energy (estimated with the Stability command) is among the 800 largest, and compare these bad fragments with all of our fragments.

# Label bad fragment's sequence detail elements (highest energy fragments according to Stability)
print(len(repaired_df[repaired_df['Total_Energy'] >= 80]))
print(f"{100*len(repaired_df[repaired_df['Total_Energy'] >= 80])/len(repaired_df)} %")

highest_energy_repaired = repaired_df.nlargest(800, "Total_Energy")
seqdetail_bad = seqdetail[(seqdetail["PDB_File"].str[5:]).isin(highest_energy_repaired["PDB_File"])]
# Add a category column to differentiate both datasets
seqdetail.loc[:, "Category"] = "All Fragments"
seqdetail_bad = seqdetail_bad.copy()
seqdetail_bad.loc[:, "Category"] = "Bad Fragments"
799
1.2897289793546514 %
paths = set(seqdetail_bad["PDB_File"].str[5:].unique())
with open("pdb_file_list_bad_test.txt", "w") as f:
    for pdb in paths:
        f.write(pdb + "\n")
# Then use:
# rsync -av --progress --partial --append-verify --files-from=pdb_file_list_bad_test.txt meta:/storage/praha1/home/tobiasma/switch-lab/amyloid-interactions/fragment_pairs/ ./fragment_pairs/
# Define amino acid groups (uppercase)
amino_acid_groups = {
    "Hydrophobic": ["ALA", "VAL", "LEU", "ILE", "MET", "PHE", "TRP", "PRO"],
    "Polar": ["SER", "THR", "ASN", "GLN", "TYR", "CYS", "GLY"],  # Added Glycine to Polar
    "Charged_Positive": ["LYS", "ARG", "HIS"],
    "Charged_Negative": ["ASP", "GLU"]
}

# Define the custom order for x-axis (grouped)
custom_order = (
    amino_acid_groups["Hydrophobic"] +
    amino_acid_groups["Polar"] +
    amino_acid_groups["Charged_Positive"] +
    amino_acid_groups["Charged_Negative"]
)

# Define color palette for groups
group_colors = {
    "Hydrophobic": "yellow", # #bdc6d5
    "Polar": "cyan", # #e7c65b
    "Charged_Positive": "red", # #4c99cf
    "Charged_Negative": "blue", # #df6b6a
    "Unknown": "gray"
}
# Merge both datasets
combined_data = pd.concat([seqdetail, seqdetail_bad], ignore_index=True)

# Assign the "Unknown" group for amino acids not in the predefined groups
all_amino_acids = set(combined_data["Three_Letter"].unique())  # Get unique amino acids from data
known_amino_acids = set(sum(amino_acid_groups.values(), []))  # Flatten list of known amino acids
amino_acid_groups["Unknown"] = list(all_amino_acids - known_amino_acids) # Find any unknown amino acids
custom_order += (amino_acid_groups["Unknown"])

aa_color_map = {aa: group_colors[group] for group, amino_acids in amino_acid_groups.items() for aa in amino_acids}
plt.figure(figsize=(18, 6))

# Create a violin plot with split violins
sns.violinplot(
    x="Three_Letter", y="Total_Energy", hue="Category", 
    data=combined_data, split=True, inner="quartile",
    palette={"All Fragments": "cornflowerblue", "Bad Fragments": "lightcoral"},
    order=custom_order
)

plt.xlabel("Amino Acid")
plt.ylabel("Total Energy")
plt.title("Total Energy Contribution of Each Amino Acid of All Fragments and 100 Worst Fragments")
plt.legend(title="Fragment Type")
plt.xticks(rotation=45)
plt.ylim(-6, 13.5)

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
hue_labels = {
    "All Fragments": f"All Fragments (n={len(seqdetail)} Res)",
    "Bad Fragments": f"Bad Fragments (n={len(seqdetail_bad)} Res)"
}
new_labels = [hue_labels.get(label, label) for label in labels]
ax.legend(handles, new_labels, title="Fragment Type")

for label in ax.get_xticklabels():
    aa = label.get_text()
    color = aa_color_map.get(aa)
    label.set_bbox(dict(facecolor=color, edgecolor='none', boxstyle='round,pad=0.3', alpha=0.5)) # Apply background color

plt.show()

Let's plot the estimated energy of each residue to compare the datasets.

seqdetail_templates.loc[:, "Category"] = "Template Fragments"

plt.figure(figsize=(10, 6))

# Merge both dataset (the whole seqdetail and template's seqdetail)
combined_data = pd.concat([seqdetail_templates, seqdetail], ignore_index=True)

# kde plot of seqdetail templates vs seqdetail repaired total energy (of all residues, not-type specific)
sns.kdeplot(
    data=combined_data, x="Total_Energy",
    hue="Category",
    fill=True, common_norm=False,
    palette={"Template Fragments": "mediumorchid", "All Fragments": "cornflowerblue"},
    alpha=0.5
)

plt.axvline(0, color="lightgray", linewidth=2, linestyle="--")
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.gca().set_yticks([0, 0.1, 0.2, 0.3, 0.4])

plt.xlabel("Total Energy")
plt.ylabel("Density")
plt.grid(False)
plt.xlim(-6, 6)
plt.title("Per-residue Total Energy Distribution of Extracted Fragments and Template Fragments")
plt.savefig("plots/total_energy_kde_residues.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/total_energy_kde_residues.pdf", bbox_inches="tight")
plt.savefig("plots/total_energy_kde_residues.svg", bbox_inches="tight")
plt.show()

Relative Solvent Accessibility

We compute and compare the Relative Solvent Accessibility (RSA) of our bad sample's residues and the residues from structures in our repaired_df sample.

# Biopython's solvent accessibility per residue (one number per residue)

AA_TO_MAX_SASA_MAP = {
"ALA": 129,
"ARG": 274,
"ASN": 195,
"ASP": 193,
"CYS": 167,
"GLN": 223,
"GLU": 225,
"GLY": 104,
"HIS": 224,
"ILE": 197,
"LEU": 201,
"LYS": 236,
"MET": 224,
"PHE": 240,
"PRO": 159,
"SER": 155,
"THR": 155,
"TRP": 285,
"TYR": 263,
"VAL": 174
}

def clip_value(lower, value, upper):
    return lower if value < lower else upper if value > upper else value

def remove_heteroatoms(structure):
    # Cannot delete heteroatoms (which are considered as residues in biopython) while iterating over structure.get_residues(), so we first need to get a list of their ids
    chain_to_heteroatoms_map = defaultdict(list)

    model = structure[0]

    for residue in model.get_residues():
        if not is_aa(residue, standard=True):
            chain_to_heteroatoms_map[residue.parent.id].append(residue.id)

    for chain, heteroatoms in chain_to_heteroatoms_map.items():
        try:
            chain_object = model[chain]
        except KeyError:
            print(f'Chain {chain} not found in structure.')
            print(f'Chains in structure: {[chain.id for chain in model.get_chains()]}')
            print(f'Chain to heteroatoms map: {chain_to_heteroatoms_map}')
            raise
        for heteroatom in heteroatoms:
            chain_object.detach_child(heteroatom)
    return

def calculate_residue_RSA(residue):
    residue_sasa, residue_max_sasa = residue.sasa, AA_TO_MAX_SASA_MAP[residue.resname]
    RSA = (residue_sasa / residue_max_sasa) * 100
    return round(RSA, ndigits = None) # Returns an integer type, which is what we want because the RSA probabilities table is set up using bins of size 1.

def get_residue_RSA_values(PDB_file_path, full_residue_IDs_list, chain_subset=None):
    """
    The RSA (Relative Solvent Accessibility) values of the list of full residue IDs are calculated in the context of the provided chain subset. For example, if the
    full PDB is an antibody antigen complex with chains AHL, chain_subset = 'HL' will calculate RSA values of residues in the HL complex context and ignore
    chain A. This is important to calculate the RSA values of interface residues correctly.
    """
    parser = PDBParser(QUIET = True)
    structure = parser.get_structure('', PDB_file_path)

    remove_heteroatoms(structure) # If not removed, heteroatoms would be accounted for when calculating the SASA of each residue.

    if chain_subset is not None:
        chains_to_remove = set(chain.id for chain in structure.get_chains()) - set(chain_subset)
        for chain_ID in chains_to_remove:
            structure.detach_child(chain_ID)

    sasa_calculator = SASA.ShrakeRupley()
    sasa_calculator.compute(structure, level='R')

    full_residue_IDs_set = set(full_residue_IDs_list)
    full_residue_ID_to_RSA_map = {}
    for residue in structure.get_residues():
        AA = protein_letters_3to1[residue.resname.title()]
        structure_ID, model_ID, chain_ID, (heteroatom_flag, position, icode) = residue.full_id # Biopython residue ID

        full_residue_ID = f'{AA}{chain_ID}{position}'
        if full_residue_ID not in full_residue_IDs_set:
            continue
        residue_RSA = calculate_residue_RSA(residue)
        full_residue_ID_to_RSA_map[full_residue_ID] = clip_value(lower=0, value=residue_RSA, upper=99) # The RSA bins in the probability table range between 0 and 99.

    residue_RSA_values = [full_residue_ID_to_RSA_map[full_residue_ID] for full_residue_ID in full_residue_IDs_list]
    # Check that the list does not contain None values:
    if None in residue_RSA_values:
        print(f'Full residue IDs list: {residue_RSA_values}')
        print(f'Full residue ID to RSA map: {full_residue_ID_to_RSA_map}')
        print(PDB_file_path, structure.get_residues())

    if len(residue_RSA_values) != len(full_residue_IDs_list):
        print(f'Full residue IDs list: {full_residue_IDs_list}')
        print(f'Full residue ID to RSA map: {full_residue_ID_to_RSA_map}')
        print(PDB_file_path, structure.get_residues())
        raise ValueError('Some residues do not have RSA values calculated.')

    return residue_RSA_values

def process_pdb_RSA(pdb_file, directory):
    parser = PDBParser(QUIET=True)
    try:
        structure = parser.get_structure(pdb_file, f"{directory}/{pdb_file}")
    except PDBExceptions.PDBConstructionException as e:
        print(f"Failed to parse {pdb_file}: {e}")
        return pdb_file, []

    residue_ID_list = []
    residue_key = []
    model = structure[0]  # Assuming we want the first model
    for chain in model:
        for residue in chain:
            structure_ID, model_ID, chain_ID, (heteroatom_flag, position, icode) = residue.get_full_id()
            try:
                residue_ID_list.append(f"{protein_letters_3to1[residue.resname.title()]}{chain_ID}{position}")
            except KeyError:
                continue
            residue_key.append((chain_ID, position))

    RSA_values = get_residue_RSA_values(f"{directory}/{pdb_file}", residue_ID_list)
    result_list = []
    for i, residue_ID in enumerate(residue_ID_list):
        chain_ID, position = residue_key[i]
        result_list.append((chain_ID, position, RSA_values[i]))

    return pdb_file, result_list

def get_residue_RSA_values_from_df(seqdetail_df, directory="./fragment_pairs", max_workers=8):
    """
    Parallel RSA (Relative Solvent Accessibility) calculation for residues in the provided dataframe.
    """
    seqdetail_df = seqdetail_df.copy()
    seqdetail_df['PDB_File'] = seqdetail_df['PDB_File'].str[5:]
    pdb_files = seqdetail_df['PDB_File'].unique()

    updates = []  # collect updates here

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_pdb_RSA, pdb_file, directory): pdb_file for pdb_file in pdb_files}
        for future in as_completed(futures):
            pdb_file, results = future.result()
            for chain_ID, position, rsa_value in results:
                updates.append({
                    'PDB_File': pdb_file,
                    'Chain': chain_ID,
                    'Residue_Number': position,
                    'RSA': rsa_value
                })

    # Create DataFrame from updates
    updates_df = pd.DataFrame(updates)

    # Merge in the new RSA values
    seqdetail_df = seqdetail_df.merge(
        updates_df,
        on=['PDB_File', 'Chain', 'Residue_Number'],
        how='left'
    )

    return seqdetail_df
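
As a quick spot check of the helper on a single structure (the file must be available locally; 7qkj_pair_765_Repair.pdb is one of the pairs listed in the sample above):

# Single-file spot check of the RSA helper (requires the PDB file under ./fragment_pairs)
pdb_name, rsa_rows = process_pdb_RSA("7qkj_pair_765_Repair.pdb", "./fragment_pairs")
print(rsa_rows[:5])  # (chain, residue_number, RSA) tuples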
# Create a seqdetail's sample
K = seqdetail[(seqdetail["PDB_File"].str[5:]).isin(repaired_df["PDB_File"])].copy() # FIXME: sample_repaired["PDB_File"]
## Label seqdetail categories
K.loc[:, "Category"] = "Sampled Fragments"

# Get the residue RSA for each dataset
if os.path.exists("./results/templates_with_rsa.csv") and os.path.exists("./results/sample_with_rsa.csv") and os.path.exists("./results/bad_with_rsa.csv"):
    seqdetail_templates_with_RSA = pd.read_csv("./results/templates_with_rsa.csv")
    seqdetail_sample_with_RSA = pd.read_csv("./results/sample_with_rsa.csv")
    seqdetail_bad_with_RSA = pd.read_csv("./results/bad_with_rsa.csv")
else:
    seqdetail_templates_with_RSA = get_residue_RSA_values_from_df(seqdetail_templates, directory="./templates")
    seqdetail_sample_with_RSA = get_residue_RSA_values_from_df(K)
    seqdetail_bad_with_RSA = get_residue_RSA_values_from_df(seqdetail_bad)
    seqdetail_templates_with_RSA.to_csv("./results/templates_with_rsa.csv", index=False)
    seqdetail_sample_with_RSA.to_csv("./results/sample_with_rsa.csv", index=False)
    seqdetail_bad_with_RSA.to_csv("./results/bad_with_rsa.csv", index=False)
len(seqdetail_sample_with_RSA)
738195
plt.figure(figsize=(18, 6))

# Merge both seqdetail datasets (bad and sampled fragments)
combined_data = pd.concat([seqdetail_sample_with_RSA, seqdetail_bad_with_RSA], ignore_index=True)

# Create a violin plot with split violins
sns.violinplot(
    x="Three_Letter", y="RSA", hue="Category", 
    data=combined_data, split=True, inner="quartile",
    palette={"Sampled Fragments": "cornflowerblue", "Bad Fragments": "lightcoral"},
    order=custom_order,
    cut=0 # Do not estimate KDE out of possible range (make it strictly 0-100)
)

plt.xlabel("Amino Acid")
plt.ylabel("RSA")
plt.title("RSA of Residues in Fragments With the Highest Total Energy Compared to Sampled Fragments")
plt.legend(title="Fragment Type")
plt.xticks(rotation=45)
plt.axhline(0, color="lightgray", linewidth=2)
plt.axhline(100, color="lightgray", linewidth=2)

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
hue_labels = {
    "Sampled Fragments": f"Sampled Fragments (n={len(seqdetail_sample_with_RSA)} Res)",
    "Bad Fragments": f"Bad Fragments (n={len(seqdetail_bad_with_RSA)} Res)"
}
new_labels = [hue_labels.get(label, label) for label in labels]
ax.legend(handles, new_labels, title="Fragment Type")

for label in ax.get_xticklabels():
    aa = label.get_text()
    color = aa_color_map.get(aa)
    label.set_bbox(dict(facecolor=color, edgecolor='none', boxstyle='round,pad=0.3', alpha=0.5)) # Apply background color

plt.show()

We compare the sample of our dataset to Eisenberg's as well.

plt.figure(figsize=(18, 6))

# Merge both datasets (sampled and seqdetail from templates)
combined_data = pd.concat([seqdetail_sample_with_RSA, seqdetail_templates_with_RSA], ignore_index=True)

# Create a violin plot with split violins
sns.violinplot(
    x="Three_Letter", y="RSA", hue="Category", 
    data=combined_data, split=True, inner="quartile",
    palette={"Sampled Fragments": "cornflowerblue", "Template Fragments": "mediumorchid"},
    alpha=0.8, saturation=1,
    order=custom_order, cut=0 # Do not estimate KDE out of possible range (make it strictly 0-100)
)

plt.xlabel("Amino Acid")
plt.ylabel("RSA")
plt.title("RSA Comparison of Residues From Extracted Fragments and Template Fragments")
plt.legend(title="Fragment Type")
plt.xticks(rotation=45)

plt.gca().set_yticks([0, 20, 40, 60, 80, 100])
plt.ylim(0, 100)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.grid(False)

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
hue_labels = {
    "Sampled Fragments": f"Sampled Fragments (n={len(seqdetail_sample_with_RSA)} Res)",
    "Template Fragments": f"Template Fragments (n={len(seqdetail_templates_with_RSA)} Res)"
}
new_labels = [hue_labels.get(label, label) for label in labels]
ax.legend(handles, new_labels, title="Fragment Type")

for label in ax.get_xticklabels():
    aa = label.get_text()
    color = aa_color_map.get(aa)
    label.set_bbox(dict(facecolor=color, edgecolor='none', boxstyle='round,pad=0.3', alpha=0.5)) # Apply background color


plt.savefig("plots/rsa_violin_all_residues.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/rsa_violin_all_residues.pdf", bbox_inches="tight")
plt.savefig("plots/rsa_violin_all_residues.svg", bbox_inches="tight")
plt.show()

We divide the residues into three categories: Stabilizing, Neutral, and Destabilizing. Stabilizing residues have an estimated total energy below -0.5, Destabilizing above 0.5, and Neutral everything in between. Let's see whether the distributions match across datasets.

# Create a new column with default value (optional, but good practice)
seqdetail_templates_with_RSA["Stability"] = "Unknown"
seqdetail_sample_with_RSA["Stability"] = "Unknown"

# Assign values using .loc
seqdetail_templates_with_RSA.loc[seqdetail_templates_with_RSA["Total_Energy"] < -0.5, "Stability"] = "Stabilizing"
seqdetail_templates_with_RSA.loc[seqdetail_templates_with_RSA["Total_Energy"] >= -0.5, "Stability"] = "Neutral"
seqdetail_templates_with_RSA.loc[seqdetail_templates_with_RSA["Total_Energy"] > 0.5, "Stability"] = "Destabilizing"

seqdetail_sample_with_RSA.loc[seqdetail_sample_with_RSA["Total_Energy"] < -0.5, "Stability"] = "Stabilizing"
seqdetail_sample_with_RSA.loc[seqdetail_sample_with_RSA["Total_Energy"] >= -0.5, "Stability"] = "Neutral"
seqdetail_sample_with_RSA.loc[seqdetail_sample_with_RSA["Total_Energy"] > 0.5, "Stability"] = "Destabilizing"

# Combine the DataFrames
combined_data = pd.concat([seqdetail_templates_with_RSA, seqdetail_sample_with_RSA], ignore_index=True)
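
As an aside, the same three-way binning can be written with pd.cut; a small sketch (note that a value of exactly -0.5 falls into Stabilizing here, whereas the masks above assign it to Neutral):

# Equivalent binning with pd.cut (sketch of an alternative to the .loc masks above)
import numpy as np

bins = [-np.inf, -0.5, 0.5, np.inf]
labels = ["Stabilizing", "Neutral", "Destabilizing"]
stability_alt = pd.cut(seqdetail_sample_with_RSA["Total_Energy"], bins=bins, labels=labels)
print((stability_alt.astype(str) == seqdetail_sample_with_RSA["Stability"]).mean())  # fraction of agreeing labels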
plt.figure(figsize=(10, 6))

# sns.violinplot of destabilizing and stabilizing RSA
sns.violinplot(
    x="Stability", y="RSA", hue="Category",
    data=combined_data, split=True, inner="quartile",
    palette={"Sampled Fragments": "cornflowerblue", "Template Fragments": "mediumorchid"},
    order=["Stabilizing", "Neutral", "Destabilizing"],
    alpha=0.8, saturation=1,
    cut=0
)
plt.xlabel("Stability")
plt.ylabel("RSA")
plt.title("RSA Comparison of Different Types of Residues between Sampled and Template Fragments")
plt.legend(title="Fragment Type")

plt.gca().set_yticks([0, 20, 40, 60, 80, 100])
plt.ylim(0, 100)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.grid(False)

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
new_labels = [hue_labels.get(label, label) for label in labels]
ax.legend(handles, new_labels, title="Fragment Type")

plt.savefig("plots/rsa_stabilizing_destabilizing_residues.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/rsa_stabilizing_destabilizing_residues.pdf", bbox_inches="tight")
plt.savefig("plots/rsa_stabilizing_destabilizing_residues.svg", bbox_inches="tight")
plt.show()

Hydrogen Bonds

We compare the hydrogen bond energies as well, both backbone and side-chain.

plt.figure(figsize=(10, 6))

# Merge both dataset (the whole seqdetail and template's seqdetail)
combined_data = pd.concat([seqdetail_templates, seqdetail], ignore_index=True)

# kde plot of seqdetail templates vs seqdetail repaired backhbond (of all residues, not-type specific)
sns.kdeplot(
    data=combined_data, x="BackHBond",
    hue="Category",
    fill=True, common_norm=False,
    palette={"Template Fragments": "mediumorchid", "All Fragments": "cornflowerblue"},
    alpha=0.5
)
plt.xlabel("Backbone Hydrogen Bond")
plt.ylabel("Density")
plt.title("Per-residue BackHBond Distribution of Extracted Fragments and Template Fragments")

plt.xlim(-2.5, 0.5)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.grid(False)

plt.savefig("plots/backbone_h_bond_kde.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/backbone_h_bond_kde.pdf", bbox_inches="tight")
plt.savefig("plots/backbone_h_bond_kde.svg", bbox_inches="tight")
plt.show()
plt.figure(figsize=(10, 6))

# kde plot of seqdetail templates vs seqdetail repaired sidehbond (of all residues, not-type specific)
sns.kdeplot(
    data=combined_data, x="SideHBond",
    hue="Category",
    fill=True, common_norm=False,
    palette={"Template Fragments": "mediumorchid", "All Fragments": "cornflowerblue"},
    alpha=0.5
)
plt.xlabel("Sidechain Hydrogen Bond")
plt.ylabel("Density")
plt.title("Per-residue SideHBond Distribution of Extracted Fragments and Template Fragments")

plt.gca().set_yticks([0, 4, 8, 12, 16])
plt.xlim(-2.5, 0.5)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.grid(False)

plt.savefig("plots/sidechain_h_bond_kde.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/sidechain_h_bond_kde.pdf", bbox_inches="tight")
plt.savefig("plots/sidechain_h_bond_kde.svg", bbox_inches="tight")
plt.show()
# TODO: Use utils/clash_check to check for clashes in the worst 200 fragments of the sample
' '.join([f"fragment_pairs/{file}" for file in list(seqdetail_sample_with_RSA.nlargest(200, 'Total_Energy')['PDB_File'].unique())])

# TODO: Or from ac_extracted in analyse complex (check whether WDV Clashes are high (but this can only work for packing))
'fragment_pairs/8bg9_pair_086_Repair.pdb fragment_pairs/8bg9_pair_104_Repair.pdb fragment_pairs/8bg9_pair_134_Repair.pdb fragment_pairs/8bg9_pair_135_Repair.pdb fragment_pairs/8oh2_pair_152_Repair.pdb fragment_pairs/8q8l_pair_080_Repair.pdb fragment_pairs/8q8l_pair_090_Repair.pdb fragment_pairs/8q8l_pair_097_Repair.pdb fragment_pairs/8q8l_pair_098_Repair.pdb fragment_pairs/8q8l_pair_065_Repair.pdb fragment_pairs/8q8l_pair_064_Repair.pdb fragment_pairs/2m4j_pair_251_Repair.pdb fragment_pairs/8fnz_pair_030_standardized_Repair.pdb fragment_pairs/8q2j_pair_263_Repair.pdb fragment_pairs/7vzf_pair_345_Repair.pdb fragment_pairs/8kf5_pair_070_Repair.pdb fragment_pairs/8kf5_pair_069_Repair.pdb fragment_pairs/8q8u_pair_178_Repair.pdb fragment_pairs/8q2j_pair_268_Repair.pdb fragment_pairs/8q9a_pair_279_Repair.pdb fragment_pairs/8fnz_pair_010_standardized_Repair.pdb fragment_pairs/8bg0_pair_024_Repair.pdb fragment_pairs/8bg0_pair_021_Repair.pdb fragment_pairs/8gf7_pair_014_Repair.pdb fragment_pairs/8bg0_pair_042_Repair.pdb fragment_pairs/8kf5_pair_189_Repair.pdb fragment_pairs/8bg0_pair_139_Repair.pdb fragment_pairs/8oh2_pair_126_Repair.pdb fragment_pairs/8oh2_pair_154_Repair.pdb fragment_pairs/8oi0_pair_237_Repair.pdb fragment_pairs/5o3o_pair_445_Repair.pdb fragment_pairs/8bg0_pair_047_Repair.pdb fragment_pairs/8bg0_pair_041_Repair.pdb fragment_pairs/8gf7_pair_003_Repair.pdb fragment_pairs/8bg0_pair_022_Repair.pdb fragment_pairs/5o3o_pair_152_Repair.pdb fragment_pairs/5o3o_pair_239_Repair.pdb fragment_pairs/8zwm_pair_342_Repair.pdb fragment_pairs/8fnz_pair_025_standardized_Repair.pdb fragment_pairs/8bfb_pair_171_Repair.pdb fragment_pairs/5o3o_pair_404_Repair.pdb fragment_pairs/9fyp_pair_296_Repair.pdb fragment_pairs/5o3o_pair_181_Repair.pdb fragment_pairs/8q8e_pair_467_Repair.pdb fragment_pairs/8oi0_pair_021_Repair.pdb fragment_pairs/8bfb_pair_050_Repair.pdb fragment_pairs/5w3n_pair_167_Repair.pdb fragment_pairs/8bg0_pair_058_Repair.pdb fragment_pairs/5o3o_pair_179_Repair.pdb fragment_pairs/8bfb_pair_178_Repair.pdb fragment_pairs/8bfb_pair_038_Repair.pdb fragment_pairs/5w3n_pair_149_Repair.pdb fragment_pairs/8bfa_pair_035_Repair.pdb fragment_pairs/5w3n_pair_166_Repair.pdb fragment_pairs/9fnb_pair_014_Repair.pdb fragment_pairs/8ote_pair_280_Repair.pdb fragment_pairs/8ote_pair_272_Repair.pdb fragment_pairs/8oi0_pair_097_Repair.pdb fragment_pairs/7sas_pair_537_Repair.pdb fragment_pairs/8ot4_pair_148_Repair.pdb fragment_pairs/7u13_pair_217_Repair.pdb fragment_pairs/8ote_pair_281_Repair.pdb fragment_pairs/8dja_pair_494_Repair.pdb fragment_pairs/8dja_pair_1083_Repair.pdb fragment_pairs/5w3n_pair_148_Repair.pdb fragment_pairs/8org_pair_038_Repair.pdb fragment_pairs/8dja_pair_506_Repair.pdb fragment_pairs/8dja_pair_1067_Repair.pdb fragment_pairs/7sar_pair_465_Repair.pdb fragment_pairs/6ti7_pair_041_Repair.pdb fragment_pairs/8otd_pair_020_Repair.pdb fragment_pairs/8bfa_pair_048_Repair.pdb fragment_pairs/8bfa_pair_162_Repair.pdb fragment_pairs/8bfa_pair_036_Repair.pdb fragment_pairs/8dja_pair_543_Repair.pdb fragment_pairs/5o3o_pair_238_Repair.pdb fragment_pairs/7qvc_pair_297_Repair.pdb fragment_pairs/6ti7_pair_051_Repair.pdb fragment_pairs/8oi0_pair_026_Repair.pdb fragment_pairs/8ote_pair_605_Repair.pdb fragment_pairs/5o3o_pair_224_Repair.pdb fragment_pairs/8bfb_pair_167_Repair.pdb fragment_pairs/8dja_pair_1102_Repair.pdb fragment_pairs/8dja_pair_502_Repair.pdb fragment_pairs/8oq4_pair_557_Repair.pdb fragment_pairs/8dja_pair_557_Repair.pdb fragment_pairs/8oq4_pair_559_Repair.pdb 
fragment_pairs/8q8w_pair_362_Repair.pdb fragment_pairs/8oq4_pair_446_Repair.pdb fragment_pairs/8e7j_pair_178_Repair.pdb fragment_pairs/8e7j_pair_192_Repair.pdb fragment_pairs/8fnz_pair_257_standardized_Repair.pdb fragment_pairs/8kf6_pair_395_Repair.pdb'

Buried Surface Area

We compute the Buried Surface Area (BSA) of Eisenberg's structures and of the structures from our sample: the solvent-accessible surface area of the two stacks computed separately, minus that of the bound complex.

# Buried Surface Area (one number per fragment)

import io
from collections import defaultdict
from Bio.PDB import PDBParser, PDBIO, Select

from utils.peptides import get_chains_from_stacks
from utils.peptides import ordered_stacks_split_multi_pp

class ChainSelect(Select):
    def __init__(self, chain_ids):
        self.chain_ids = set(chain_ids)
    def accept_chain(self, chain):
        return chain.id in self.chain_ids

def calculate_structure_SASA(structure):
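    # Shrake-Rupley SASA computed at residue level ('R'); the structure's total SASA is the sum over residues.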
    # Print all the chains in the structure
    # print(f"Calculating SASA for structure with chains: {[chain.id for chain in structure.get_chains()]}")
    sasa_calculator = SASA.ShrakeRupley()
    sasa_calculator.compute(structure, level='R')
    #pdb_io = PDBIO()
    #pdb_io.set_structure(structure)
    #pdb_io.save(f"temp_{structure.id}.pdb", select=ChainSelect(set([chain.id for chain in structure.get_chains()])))  # Save the structure to a temporary file
    #print(len(list(structure.get_residues())))
    return sum(residue.sasa for residue in structure.get_residues())

def get_unbound_structure(original_structure, subset_chains):
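    # Write only the selected chains to an in-memory PDB buffer and re-parse them as a standalone (unbound) structure.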
    pdb_io = PDBIO() 
    pdb_parser = PDBParser(QUIET=True)
    buf = io.StringIO()
    pdb_io.set_structure(original_structure)
    pdb_io.save(buf, select=ChainSelect(subset_chains))
    #pdb_io.save(f"unbound_{'_'.join(subset_chains)}{original_structure}.pdb", select=ChainSelect(subset_chains))
    buf.seek(0) # FIXME - using this in parallel is dangerous
    return pdb_parser.get_structure('', buf)

def calculate_BSA(structure, chain_subset_1, chain_subset_2):
    remove_heteroatoms(structure)

    # Filter to only chains of interest (both subsets)
    model = structure[0]
    combined_subset = set(chain_subset_1) | set(chain_subset_2)
    chains_to_remove = set(chain.id for chain in structure.get_chains()) - combined_subset
    #print(chain_subset_1, chain_subset_2, "END")

    for chain_ID in chains_to_remove:
        if chain_ID != ' ': # FIXME: a blank chain ID sometimes appears among the model's chains and must be skipped
            try:
                model.detach_child(chain_ID)
            except Exception as e:
                raise KeyError(f"idk, {e}, {set(chain.id for chain in structure.get_chains())}, {combined_subset}")

    #print("Calculate SASA")
    # Calculate bound SASA
    #full_struct = get_unbound_structure(structure, combined_subset)
    bound_SASA = calculate_structure_SASA(structure)

    # Calculate unbound SASAs
    unbound_1 = get_unbound_structure(structure, chain_subset_1)
    unbound_SASA_1 = calculate_structure_SASA(unbound_1)

    unbound_2 = get_unbound_structure(structure, chain_subset_2)
    unbound_SASA_2 = calculate_structure_SASA(unbound_2)

    # Total unbound SASA (sum of both)
    # print(f"Unbound SASA 1: {unbound_SASA_1} + Unbound SASA 2: {unbound_SASA_2}")
    unbound_SASA_total = unbound_SASA_1 + unbound_SASA_2

    # Buried surface area
    # print(f"Unbound SASA Total: {unbound_SASA_total} - Bound SASA: {bound_SASA}")
    buried_surface_area = unbound_SASA_total - bound_SASA

    return unbound_SASA_total, bound_SASA, buried_surface_area

def process_pdb_BSA(pdb_file, directory):
    pdb_parser = PDBParser(QUIET=True)
    try:
        print(f"Processing {pdb_file} for BSA calculation...")
        structure = pdb_parser.get_structure(pdb_file, f"{directory}/{pdb_file}")
    except Exception as e:
        print(f"Failed to parse {pdb_file}: {e}")
        return pdb_file, None, None, None

    #print(structure.id, "Chains:", [chain.id for chain in structure.get_chains()])
    ordered_stacks, peptides, split_struct = ordered_stacks_split_multi_pp(structure, expected_num_of_stacks=2, max_attempts=10)

    #print(split_struct.id, "Chains:", [chain.id for chain in split_struct.get_chains()])
    stack_chains = get_chains_from_stacks(peptides, ordered_stacks)

    if len(stack_chains) != 2:
        print(f"Expected different number of stacks, got {len(stack_chains)}, {ordered_stacks}")
        print(f"Skipping {pdb_file}")
        return pdb_file, None, None, None

    unbound_SASA_total, bound_SASA, buried_surface_area = calculate_BSA(split_struct, stack_chains[0], stack_chains[1])
    return pdb_file, buried_surface_area, bound_SASA, unbound_SASA_total



def get_BSA_from_df(fragment_df, directory="./fragment_pairs", max_workers=8):
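    # Compute the BSA of every unique PDB file in parallel and merge the results back into the dataframe.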
    fragment_df = fragment_df.copy()
    pdb_files = fragment_df['PDB_File'].unique()

    updates = []

    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_pdb_BSA, pdb_file, directory): pdb_file for pdb_file in pdb_files}
        for future in as_completed(futures):
            pdb_file, buried_surface_area, bound_SASA, unbound_SASA_total = future.result()
            if buried_surface_area is not None:
                updates.append({
                    'PDB_File': pdb_file,
                    'BSA': buried_surface_area,
                    'Bound_SASA': bound_SASA,
                    'Unbound_SASA': unbound_SASA_total
                })

    # Create updates DataFrame
    updates_df = pd.DataFrame(updates)
    
    # Merge in new values
    fragment_df = fragment_df.merge(updates_df, on='PDB_File', how='left')

    return fragment_df
# Get BSA for the template dataset
if os.path.exists("./results/templates_with_bsa.csv"):
    templates_with_BSA = pd.read_csv("./results/templates_with_bsa.csv")
else:
    templates_with_BSA = get_BSA_from_df(template_df, directory="./templates")
    templates_with_BSA.to_csv("./results/templates_with_bsa.csv", index=False)

templates_with_BSA.head()
PDB_File Total_Energy Backbone_HBond Sidechain_HBond Van_der_Waals Electrostatics Solvation_Polar Solvation_Hydrophobic BSA Bound_SASA Unbound_SASA
0 1yjo.pdb -27.20710 -41.6648 -70.69830 -80.8205 -7.77834 164.9760 -84.2058 1430.331422 3589.657794 5019.989216
1 1yjp_1.pdb 6.20392 -45.4826 -70.98320 -65.0946 6.56874 113.7480 -63.5063 1472.207771 3359.014188 4831.221959
2 1yjp_2.pdb -14.26580 -50.0210 -72.01630 -80.3370 2.57071 126.6720 -83.8945 1433.401072 3552.675946 4986.077018
3 2kib_1.pdb 16.43380 -44.0706 0.00000 -43.1552 -3.11863 68.1515 -59.0119 1073.027307 4801.665177 5874.692484
4 2kib_2.pdb 17.88910 -33.9384 -1.84035 -41.7314 1.71712 55.3471 -57.0966 929.267625 4687.238910 5616.506535

Some of the template structures failed the calculation (they have NaNs). This is mostly due to their unusual shapes, which cause our algorithm to misidentify the stacks (the two sides of the hexapeptide fragment pairs).

# Mostly unusual fragments which our algorithm doesn't know how to handle;
# only 2y3l_3 and 5e5z should be computable
# (with a looser threshold - but that is just 2 structures)
templates_with_BSA[templates_with_BSA['BSA'].isna()]
PDB_File Total_Energy Backbone_HBond Sidechain_HBond Van_der_Waals Electrostatics Solvation_Polar Solvation_Hydrophobic BSA Bound_SASA Unbound_SASA
30 2y3l_3.pdb 38.711100 -23.52490 0.00000 -27.1616 -0.103731 42.3114 -36.6882 NaN NaN NaN
111 4uby.pdb -0.626345 -31.55220 0.00000 -48.7639 -3.787230 67.9743 -68.0499 NaN NaN NaN
118 4w5p_2.pdb -8.510380 -37.39760 -6.91971 -48.4529 -3.225380 63.8204 -66.1297 NaN NaN NaN
141 5e5v_2.pdb 7.228250 -30.64120 -5.05200 -51.7335 2.772190 68.4774 -72.0727 NaN NaN NaN
143 5e5z.pdb 23.775900 -32.34680 -19.45090 -44.6157 -1.399470 77.1989 -52.8902 NaN NaN NaN
175 6cfh_1.pdb 57.066000 -13.72370 -3.42594 -22.1422 5.805420 38.5544 -29.6550 NaN NaN NaN
176 6cfh_2.pdb 15.489300 -33.83700 0.00000 -50.5679 2.129130 66.7553 -72.3968 NaN NaN NaN
179 6cfh_5.pdb 13.991000 -22.49910 -1.86111 -39.6010 1.458860 42.7467 -58.5108 NaN NaN NaN
180 6cfh_6.pdb 55.837500 -3.76357 -7.22857 -11.8741 3.006000 21.9272 -15.7097 NaN NaN NaN

A simple summary of the templates' BSA values:

desc = templates_with_BSA['BSA'].describe()
print(desc.apply(lambda x: f"{x:.6f}"))
count     172.000000
mean     1055.313185
std       448.291994
min        -0.000000
25%       841.395469
50%      1118.878224
75%      1394.837274
max      1954.759795
Name: BSA, dtype: object
# BSA is an approximation - these values are very close to zero (they should be zero or only slightly above)
print(f"{len(templates_with_BSA.loc[(templates_with_BSA['BSA'] < 0)])} template pairs are below 0.")
print(f"{len(templates_with_BSA[(templates_with_BSA['BSA'] < 0) & (templates_with_BSA['BSA'] > -1)])} are approx errors. Removing them..")

# Set approximation errors (negative BSA values) to zero
templates_with_BSA.loc[(templates_with_BSA['BSA'] < 0) & (templates_with_BSA['BSA'] > -1), 'BSA'] = 0
4 template pairs are below 0.
4 are approximation errors. Setting them to zero..

We calculate the BSA of our sample.

# Get BSA for the repaired sample dataset
if os.path.exists("./results/sample_with_bsa.csv"):
    sample_with_BSA = pd.read_csv("./results/sample_with_bsa.csv")
else:
    sample_with_BSA = get_BSA_from_df(sample_repaired, "./fragment_pairs")
    sample_with_BSA.to_csv("./results/sample_with_bsa.csv", index=False)

sample_with_BSA.head()
PDB_File Total_Energy Backbone_HBond Sidechain_HBond Van_der_Waals Electrostatics Solvation_Polar Solvation_Hydrophobic BSA Bound_SASA Unbound_SASA
0 7qkj_pair_765_Repair.pdb 18.45160 -30.7581 -8.22142 -42.7342 3.79509 61.2674 -53.1292 1075.330095 3265.915691 4341.245786
1 8olq_pair_012_Repair.pdb 50.32270 -36.4475 -31.34510 -59.5809 11.08980 101.5600 -67.5650 926.019646 3922.125508 4848.145153
2 8ci8_pair_225_Repair.pdb 12.05720 -36.2584 -6.14683 -67.4686 4.42839 89.9917 -93.9525 1567.003055 3808.659626 5375.662681
3 7nrq_pair_279_Repair.pdb 49.53490 -31.8512 -9.81061 -45.1977 3.33110 86.0854 -51.5352 1483.879077 3417.488690 4901.367767
4 7qkk_pair_121_Repair.pdb 8.61727 -46.0144 -19.36770 -57.6874 6.74647 89.7038 -71.9427 1072.009079 4046.350113 5118.359192
print(f"{len(sample_with_BSA[sample_with_BSA['BSA'].isna()])}/{len(sample_repaired)} have failed during the BSA calculation.")
sample_with_BSA[sample_with_BSA['BSA'].isna()]
0/12304 have failed during the BSA calculation.
PDB_File Total_Energy Backbone_HBond Sidechain_HBond Van_der_Waals Electrostatics Solvation_Polar Solvation_Hydrophobic BSA Bound_SASA Unbound_SASA

A simple summary of the sample's BSA values:

desc = sample_with_BSA['BSA'].describe()
print(desc.apply(lambda x: f"{x:.6f}"))
count    12304.000000
mean      1195.089564
std        209.751344
min        499.762813
25%       1035.606377
50%       1185.562046
75%       1335.289928
max       2297.255447
Name: BSA, dtype: object

Comparison of BSA across datasets:

plt.figure(figsize=(10,6))

# Merge
templates_with_BSA['Category'] = "Template Fragments"
sample_with_BSA['Category'] = "Sampled Fragments"
combined_data = pd.concat([templates_with_BSA, sample_with_BSA], ignore_index=True)

# Plot a comparison of Buried Surface Area across datasets
sns.kdeplot(
    data=combined_data, x="BSA",
    hue="Category",
    cut=0,
    fill=True, common_norm=False,
    palette={"Template Fragments": "mediumorchid", "Sampled Fragments": "cornflowerblue"},
    alpha=0.5
)
plt.xlabel("Burried Surface Accessibility")
plt.ylabel("Density")
plt.title(f"Burried Surface Accessibility of Extracted Fragments and Template Fragments\nTemplate n={len(templates_with_BSA)}, Sampled n={len(sample_with_BSA)}")

plt.gca().set_yticks([0, 0.001, 0.002])
#plt.xlim(-2.5, 0.5)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.grid(False)

plt.savefig("plots/bsa_kde.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/bsa_kde.pdf", bbox_inches="tight")
plt.savefig("plots/bsa_kde.svg", bbox_inches="tight")
plt.show()

Packing/Stacking Energies

We compare the datasets' packing energies (the interaction energy between the two stacks of each structure) and stacking energies (the interaction energy between the first two and the last three layers within each stack).

# Load results from Analyse complex for templates
ac_templates = pd.read_csv("./results/foldx_analyse_complex_results_templates.csv")
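# Template AnalyseComplex file names end with the two chain groups (e.g. 3q2x_b_JH_FDB.pdb),
# so dropping the last two underscore-separated tokens recovers the structure ID (3q2x_b).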
ac_templates['ID'] = ac_templates['PDB_File'].str.split('_').apply(lambda parts: '_'.join(parts[:-2]))
ac_templates['Category'] = "Templates"
ac_templates.head()
PDB_File Group1 Group2 IntraclashesGroup1 IntraclashesGroup2 Interaction_Energy BackHBond SideHBond VDWClashes ID Category
0 3q2x_b_JH_FDB.pdb HJ BDF 0.044701 0.095801 -2.01238 -6.071510 -0.789906 0.051573 3q2x_b Templates
1 2m5n_3_JHFDB_IGECA.pdb BDFHJ ACEGI 5.677250 4.744780 -22.21540 -0.450673 0.000000 2.607290 2m5n_3 Templates
2 3fr1_AC_EGI.pdb CA GIE 0.110695 0.405026 -7.39247 -6.188010 -0.867155 0.028225 3fr1 Templates
3 2ona_BD_FHJ.pdb BD FHJ 0.204879 0.496664 -3.11506 -5.760990 0.000000 0.022129 2ona Templates
4 3dgj_1_BD_FHJ.pdb DB HJF 0.913331 0.249727 -4.73258 -4.370030 -1.285720 0.060052 3dgj_1 Templates
# Load results from Analyse complex for repaired fragments
ac_extracted = pd.read_csv("./results/foldx_analyse_complex_results_fragment_pairs.csv") # FIXME
ac_extracted['ID'] = ac_extracted['PDB_File'].str[:-4]
ac_extracted['Category'] = "Extracted"
## Replace original files of standardized files
ac_extracted = filter_non_standardized(ac_extracted)
ac_extracted.head()
PDB_File Group1 Group2 IntraclashesGroup1 IntraclashesGroup2 Interaction_Energy BackHBond SideHBond VDWClashes ID Category
0 6ufr_pair_196_Repair.pdb CGFJB ADEHI 0.450845 0.054773 -21.882800 -0.410710 0.000000e+00 0.002200 6ufr_pair_196_Repair Extracted
1 7e0f_pair_124_Repair.pdb DEFGH ABCIJ 4.111020 4.127980 -14.520100 -1.207550 0.000000e+00 1.108400 7e0f_pair_124_Repair Extracted
2 6xfm_pair_135_Repair.pdb 1A 234 0.026684 0.028753 -0.407182 -4.296730 -2.064610e+00 0.000000 6xfm_pair_135_Repair Extracted
3 6xyq_pair_344_Repair.pdb HJ BDF 0.307053 0.433553 -4.612390 -4.757160 0.000000e+00 0.017628 6xyq_pair_344_Repair Extracted
4 8q9g_pair_232_Repair.pdb ACEIJ BDFGH 1.864520 0.124207 -17.891600 -0.508668 1.776360e-15 0.001564 8q9g_pair_232_Repair Extracted
# Filter based on distance (take only what's in repaired_df)
ac_extracted = ac_extracted[ac_extracted["PDB_File"].isin(set(repaired_df["PDB_File"]))]
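# Each structure contributes three AnalyseComplex rows (one packing run and two stacking runs), hence the division by 3 below.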
print(f"After filtering based on length, we have ~{len(ac_extracted)//3} structures ({len(ac_extracted)} rows)")
After filtering based on length, we have ~61940 structures (185820 rows)
combined_data = pd.concat([ac_templates, ac_extracted], ignore_index=True)

combined_data['Interaction Type'] = "Unknown"
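# Packing runs compare the two full stacks (equal-sized chain groups, e.g. 5 vs 5 chains),
# while stacking runs compare the first two vs the last three layers within a stack (2 vs 3 chains),
# so the group sizes distinguish the two interaction types.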
combined_data.loc[combined_data['Group1'].str.len() == combined_data['Group2'].str.len(), 'Interaction Type'] = "Packing"
combined_data.loc[combined_data['Group1'].str.len() != combined_data['Group2'].str.len(), 'Interaction Type'] = "Stacking"
combined_data.head(10)
PDB_File Group1 Group2 IntraclashesGroup1 IntraclashesGroup2 Interaction_Energy BackHBond SideHBond VDWClashes ID Category Interaction Type
0 3q2x_b_JH_FDB.pdb HJ BDF 0.044701 0.095801 -2.01238 -6.071510 -0.789906 0.051573 3q2x_b Templates Stacking
1 2m5n_3_JHFDB_IGECA.pdb BDFHJ ACEGI 5.677250 4.744780 -22.21540 -0.450673 0.000000 2.607290 2m5n_3 Templates Packing
2 3fr1_AC_EGI.pdb CA GIE 0.110695 0.405026 -7.39247 -6.188010 -0.867155 0.028225 3fr1 Templates Stacking
3 2ona_BD_FHJ.pdb BD FHJ 0.204879 0.496664 -3.11506 -5.760990 0.000000 0.022129 2ona Templates Stacking
4 3dgj_1_BD_FHJ.pdb DB HJF 0.913331 0.249727 -4.73258 -4.370030 -1.285720 0.060052 3dgj_1 Templates Stacking
5 4r0w_b_1_IG_ECA.pdb GI CAE 0.991309 0.947355 -3.19053 -6.102070 0.000000 0.008595 4r0w_b_1 Templates Stacking
6 4xfo_JH_FDB.pdb HJ BDF 0.094704 0.382677 -5.00158 -6.254170 -1.028470 0.009694 4xfo Templates Stacking
7 4ril_a_3_IG_ECA.pdb GI CAE 0.752904 0.750379 -2.34154 -5.993340 -0.522533 0.000000 4ril_a_3 Templates Stacking
8 6cfh_4_ACEGI_BDFHJ.pdb CAGEI BDHFJ 0.233039 0.345059 -28.30400 -1.715980 0.000000 0.290562 6cfh_4 Templates Packing
9 3fr1_BD_FHJ.pdb DB HJF 0.127741 0.459395 -7.71799 -5.972650 -0.597861 0.118170 3fr1 Templates Stacking
# Sanity check
print(combined_data['Group1'].str.len().value_counts())
print(combined_data['Group2'].str.len().value_counts())
Group1
2    124224
5     62112
Name: count, dtype: int64
Group2
3    124224
5     62112
Name: count, dtype: int64
plt.figure(figsize=(10,6))
K = combined_data[combined_data['Interaction Type'] == "Packing"]

# Plot a comparison of Packing Energies between Templates and Extracted
sns.kdeplot(
    data=K, x="Interaction_Energy",
    hue="Category",
    fill=True, common_norm=False,
    palette={"Templates": "mediumorchid", "Extracted": "cornflowerblue"},
    alpha=0.5
)
plt.xlabel("Packing Interaction Energy")
plt.ylabel("Density")
plt.title(f"Packing Energy Interaction Comparison\nTemplates n={len(K[K['Category']=='Templates'])}, Extracted n={len(K[K['Category']=='Extracted'])}")

plt.axvline(0, color="lightgray", linewidth=2, linestyle="--")
plt.xlim(-50, 50)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
#plt.gca().set_yticks([0, 0.1, 0.2, 0.3, 0.4])
plt.grid(False)

plt.savefig("plots/packing_energy_kde.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/packing_energy_kde.pdf", bbox_inches="tight")
plt.savefig("plots/packing_energy_kde.svg", bbox_inches="tight")
plt.show()
plt.figure(figsize=(10,6))
K = combined_data[combined_data['Interaction Type'] == "Stacking"]

# Plot a comparison of Stacking Energies between Templates and Extracted
sns.kdeplot(
    data=K, x="Interaction_Energy",
    hue="Category",
    fill=True, common_norm=False,
    palette={"Templates": "mediumorchid", "Extracted": "cornflowerblue"},
    alpha=0.5
)
plt.xlabel("Stacking Interaction Energy")
plt.ylabel("Density")
plt.title(f"Stacking Energy Interaction Comparison\nTemplates n={len(K[K['Category']=='Templates'])}, Extracted n={len(K[K['Category']=='Extracted'])}")

plt.axvline(0, color="lightgray", linewidth=2, linestyle="--")
plt.xlim(-12.5, 12.5)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.gca().set_yticks([0, 0.05, 0.1, 0.15, 0.2])
plt.grid(False)

plt.savefig("plots/stacking_energy_kde.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/stacking_energy_kde.pdf", bbox_inches="tight")
plt.savefig("plots/stacking_energy_kde.svg", bbox_inches="tight")
plt.show()

Let's see if there's any correlation between packing and stacking energy.

combined_data.head()
K = combined_data.copy()

packing = (
    K[K["Interaction Type"].str.lower() == "packing"]
      .loc[:, ["ID", "Interaction_Energy", "Category"]]
      .rename(columns={"Interaction_Energy": "Packing_Energy"})
)

stacking = (
    K[K["Interaction Type"].str.lower() == "stacking"]
      .loc[:, ["ID", "Interaction_Energy"]]
      .rename(columns={"Interaction_Energy": "Stacking_Energy"})
)

merged = stacking.merge(packing, on="ID", how="inner")
merged.head()
ID Stacking_Energy Packing_Energy Category
0 3q2x_b -2.01238 -31.501000 Templates
1 3fr1 -7.39247 -21.395700 Templates
2 2ona -3.11506 -15.931600 Templates
3 3dgj_1 -4.73258 -22.033400 Templates
4 4r0w_b_1 -3.19053 -0.840478 Templates
plt.figure(figsize=(10,6))
K = merged[merged['Category'] == "Templates"]

# Plot the correlation between stacking and packing energy
sns.scatterplot(
    data=K,
    x="Packing_Energy",
    y="Stacking_Energy",
    hue="Category",
    # fill=True, common_norm=False,
    palette={"Templates": "mediumorchid", "Extracted": "cornflowerblue"},
    alpha=0.5
)
plt.xlabel("Packing")
plt.ylabel("Stacking")
plt.title(f"Packing Interaction Energy vs Stacking Interaction Energy")
corr = np.corrcoef(K['Packing_Energy'], K['Stacking_Energy'])[0, 1]
print("Pearson correlation:", corr)

#plt.axvline(0, color="lightgray", linewidth=2, linestyle="--")
plt.xlim(-50, 50)
plt.ylim(-15, 15)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
#plt.gca().set_yticks([0, 0.05, 0.1, 0.15, 0.2])
plt.grid(False)

plt.savefig("plots/packing_stacking_templates_scatter.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/packing_stacking_templates_scatter.pdf", bbox_inches="tight")
plt.savefig("plots/packing_stacking_templates_scatter.svg", bbox_inches="tight")
plt.show()
Pearson correlation: 0.1850314465829146
plt.figure(figsize=(10,6))
K = merged[merged['Category'] == "Extracted"]

# Plot the correlation between stacking and packing energy
sns.scatterplot(
    data=K,
    x="Packing_Energy",
    y="Stacking_Energy",
    hue="Category",
    # fill=True, common_norm=False,
    palette={"Templates": "mediumorchid", "Extracted": "cornflowerblue"},
    s=10,
    alpha=0.1
)
plt.xlabel("Packing")
plt.ylabel("Stacking")
plt.title(f"Packing Interaction Energy vs Stacking Interaction Energy")
corr = np.corrcoef(K['Packing_Energy'], K['Stacking_Energy'])[0, 1]
print("Pearson correlation:", corr)

#plt.axvline(0, color="lightgray", linewidth=2, linestyle="--")
plt.xlim(-50, 50)
plt.ylim(-15, 15)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
#plt.gca().set_yticks([0, 0.05, 0.1, 0.15, 0.2])
plt.grid(False)

plt.savefig("plots/packing_stacking_extracted_scatter.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/packing_stacking_extracted_scatter.pdf", bbox_inches="tight")
plt.savefig("plots/packing_stacking_extracted_scatter.svg", bbox_inches="tight")
plt.show()
Pearson correlation: 0.31595007903333355

Shape diversity

To compare the shapes across the two datasets, we computed two RMSD matrices and clustered them hierarchically (two-level clustering).

First level: the coordinates of the hexapeptide pairs were projected onto a plane (from 3D to 2D); from these projections, an RMSD matrix was computed using a QCP superimposer. This matrix was then used as the basis for clustering with a dendrogram (using Ward linkage).

Second level: the full 3D coordinates were used to create another RMSD matrix (with a modified QCP superimposer which, after superimposing the whole structures, takes the maximum of the per-stack RMSDs of the two stacks). The clustering pipeline was as follows: take all structures belonging to the same first-level cluster, extract the corresponding rows of the second-level RMSD matrix, and cluster them (again with a dendrogram using Ward linkage).

In the 2D case (first level), we assess a cluster's shape validity by viewing the superimposed projected coordinates. In the 3D case, we superimpose all structures onto their cluster center and export the cluster to a .cif file; we also plot the 2D shapes by averaging the coordinates across layers and then projecting them onto a plane using the first two PCA components.
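
A minimal sketch of this two-level idea is given below, assuming each structure is available as an (N, 3) NumPy coordinate array with a consistent point ordering. The helper names (pairwise_rmsd, two_level_clusters) and the cut distances t1/t2 are illustrative only; the actual pipeline uses the project's projected_pairwise_rmsd_matrix and superposition modules.

import numpy as np
import scipy.cluster.hierarchy as sch
from scipy.spatial.distance import squareform
from Bio.PDB.qcprot import QCPSuperimposer

def pairwise_rmsd(coord_sets):
    # Pairwise RMSD after optimal superposition (QCP); coord_sets is a list of (N, 3) arrays.
    sup = QCPSuperimposer()
    n = len(coord_sets)
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            sup.set(coord_sets[i], coord_sets[j])
            sup.run()
            D[i, j] = D[j, i] = sup.get_rms()
    return D

def two_level_clusters(first_level_D, second_level_D, t1, t2):
    # First level: Ward clustering on the condensed form of the first RMSD matrix.
    Z1 = sch.linkage(squareform(first_level_D), method="ward")
    first = sch.fcluster(Z1, t=t1, criterion="distance")
    # Second level: re-cluster the members of each first-level cluster on the second matrix.
    second = np.zeros_like(first)
    for cl in np.unique(first):
        idx = np.where(first == cl)[0]
        if len(idx) < 2:
            second[idx] = 1
            continue
        Z2 = sch.linkage(squareform(second_level_D[np.ix_(idx, idx)]), method="ward")
        second[idx] = sch.fcluster(Z2, t=t2, criterion="distance")
    return first, second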

import warnings

warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to 'ensure_all_finite'", category=FutureWarning)
warnings.filterwarnings("ignore", message="n_jobs value .* overridden to .* by setting random_state", category=UserWarning)
warnings.filterwarnings("ignore", message="'force_all_finite' was renamed to 'ensure_all_finite'", category=FutureWarning, module=r".*sklearn.*")
warnings.filterwarnings("ignore", message=r"n_jobs value .* overridden to .* by setting random_state", category=UserWarning, module=r".*umap.*")

def print_pcoa_explained_var(pcoa_results, num_components=2):
    # Extract the raw eigenvalues, taking only the first 'num_components'
    # pcoa_results.eigvals is a pandas Series, direct slicing works
    raw_eigenvalues = pcoa_results.eigvals[:num_components]
    explained_variance_ratio = pcoa_results.proportion_explained[:num_components]
    cumulative_explained_variance = explained_variance_ratio.cumsum()

    # The corresponding metrics for the first 'num_components' components:
    print("\nRaw Eigenvalues:")
    print(raw_eigenvalues)
    print("\nExplained Variance Ratio:")
    print(explained_variance_ratio)
    print("\nCumulative Explained Variance:")
    print(cumulative_explained_variance)

# Shape analysis
def plot_pca_and_umap(pca, umap, hue, palette, dot_size=20):
    # Create wide figure with two subplots
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))  # 1 row, 2 columns
    # Get value counts
    hue_series = pd.Series(hue).astype(str)
    category_counts = hue_series.value_counts()
    hue_labels = {label: f"{label} (n={category_counts.get(label, 0)})" for label in category_counts.index}
    print(hue_labels)
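    # value_counts() sorts by frequency, so index[1] is the smaller group (the templates);
    # the templates are assumed to come first in `hue` and in the coordinate arrays.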
    num_of_templates = category_counts.get(category_counts.index[1], 0)
    print(f"Number of templates used {num_of_templates}")
    hue_group1 = hue[:num_of_templates]
    hue_group2 = hue[num_of_templates:]

    # Plot PCA
    try:
        sns.scatterplot(x=pca[:, 0], y=pca[:, 1], hue=hue, palette=palette, s=20, ax=axes[0])
    except Exception:
        print("Error, trying to plot a df")
        sns.scatterplot(data=pca[num_of_templates:], x='PCoA_1', y='PCoA_2', hue=hue_group2, palette=palette,s=dot_size, ax=axes[0], alpha=0.3)
        sns.scatterplot(data=pca[:num_of_templates], x='PCoA_1', y='PCoA_2', hue=hue_group1, palette=palette,s=dot_size, ax=axes[0], alpha=0.9)

    axes[0].set_title("PCoA of Reconstructed Fragments")
    axes[0].set_xlabel("PCoA 1")
    axes[0].set_ylabel("PCoA 2")
    handles, labels = axes[0].get_legend_handles_labels()
    new_labels = [hue_labels.get(label, label) for label in labels]
    axes[0].legend(handles, new_labels, title="Category")
    axes[0].spines['top'].set_visible(False)
    axes[0].spines['right'].set_visible(False)
    axes[0].spines['left'].set_color('dimgray')
    axes[0].spines['bottom'].set_color('dimgray')
    axes[0].grid(False)

    # Plot UMAP
    # Split the UMAP coordinates
    umap_group1 = umap[:num_of_templates]
    umap_group2 = umap[num_of_templates:]

    # Plot Group 2 (overrepresented) first
    sns.scatterplot(x=umap_group2[:, 0], y=umap_group2[:, 1], hue=hue_group2,
                    palette=palette, s=dot_size, ax=axes[1], alpha=0.3)
    sns.scatterplot(x=umap_group1[:, 0], y=umap_group1[:, 1], hue=hue_group1,
                    palette=palette, s=dot_size, ax=axes[1], alpha=0.9)

    # Adjust labels and legend
    axes[1].set_title("UMAP of Reconstructed Fragments")
    axes[1].set_xlabel("UMAP 1")
    axes[1].set_ylabel("UMAP 2")
    handles, labels = axes[1].get_legend_handles_labels()
    new_labels = [hue_labels.get(label, label) for label in labels]
    axes[1].legend(handles, new_labels, title="Category")
    axes[1].spines['top'].set_visible(False)
    axes[1].spines['right'].set_visible(False)
    axes[1].spines['left'].set_color('dimgray')
    axes[1].spines['bottom'].set_color('dimgray')
    axes[1].grid(False)
    # Adjust layout
    plt.tight_layout()
    plt.show()

def num_components_explaining(percent_of_variance, pcoa_results):
    explained_variance_sum = 0
    components_explaining = 0

    for i in range(0, pcoa_results.samples.shape[1]):
        explained_variance_sum += pcoa_results.proportion_explained.iloc[i]
        if explained_variance_sum >= percent_of_variance / 100:
            components_explaining = i + 1
            print(f"{percent_of_variance}% of variance is explained by the first {i+1} components")
            return components_explaining

def plot_umap(num_templates, pdb_files, umap_emb, save_path="plots/rmsd_umap", dot_size=20):
    umapca_df = pd.DataFrame({
    'PDB_File': pdb_files,
    'UMAP_1': umap_emb[:, 0],
    'UMAP_2': umap_emb[:, 1]
    })

    fig = plt.figure(figsize=(10, 6))
    sampled_K = umapca_df.iloc[num_templates:]
    template_K = umapca_df.iloc[:num_templates]
    n_sampled =  len(sampled_K) - sampled_K['UMAP_1'].isnull().sum()
    n_template = len(template_K) - template_K['UMAP_1'].isnull().sum()

    sns.scatterplot(data=sampled_K, x='UMAP_1', y='UMAP_2',
                    color='cornflowerblue', s=dot_size, alpha=0.3, label=f'Sampled Fragments (n={n_sampled})')
    sns.scatterplot(data=template_K, x='UMAP_1', y='UMAP_2',
                    color='mediumorchid', s=dot_size, alpha=0.9, label=f'Template Fragments (n={n_template})')

    plt.title('UMAP Projection of All-to-All RMSD Matrix')
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')
    plt.legend(title='Fragment Type')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['left'].set_color('dimgray')
    plt.gca().spines['bottom'].set_color('dimgray')
    plt.grid(False)

    # plt.xlim(-0,5, 5.5)
    # plt.gca().set_yticks([0, 0.3, 0.6])

    plt.tight_layout()

    plt.savefig(f"{save_path}.png", dpi=600, bbox_inches="tight")
    plt.savefig(f"{save_path}.pdf", bbox_inches="tight")
    plt.savefig(f"{save_path}.svg", bbox_inches="tight")
    plt.show()

def plot_dendrogram(Z, labels, num_clusters, t_distance, save_path="plots/rmsd_dendrogram", custom_palette=None):
    fig = plt.figure(figsize=(20, 8))
    ax = fig.add_subplot(111)

    # COLORING:
    # if no palette, generate one
    if custom_palette is None:
        custom_palette = plt.cm.get_cmap("tab20", num_clusters).colors
    custom_palette = [mcolors.to_hex(c) for c in custom_palette]

    # cluster → color
    cluster_colors = {cl: custom_palette[(cl - 1) % len(custom_palette)]
                      for cl in range(1, num_clusters + 1)}

    # cluster labels for leaves
    clusters = sch.fcluster(Z, t=t_distance, criterion="distance")

    # get leaf order
    leaves = sch.leaves_list(Z)

    # map node_id → leaf ids under it
    from collections import defaultdict
    node2leaves = {}
    n_leaves = len(clusters)

    # leaves themselves
    for i in range(n_leaves):
        node2leaves[i] = [i]

    # internal nodes
    for node_id in range(n_leaves, n_leaves + Z.shape[0]):
        left, right, _, _ = Z[node_id - n_leaves]
        node2leaves[node_id] = node2leaves[int(left)] + node2leaves[int(right)]

    def link_color_func(node_id):
        # leaf: just return its cluster color
        if node_id < n_leaves:
            return cluster_colors[clusters[node_id]]

        # height of merge
        dist = Z[node_id - n_leaves, 2]
        if dist >= t_distance:
            return "black"

        # otherwise: all leaves under this node
        leaves_under = node2leaves[node_id]
        cluster_ids = set(clusters[i] for i in leaves_under)

        if len(cluster_ids) == 1:
            # all in same cluster → color by that cluster
            return cluster_colors[cluster_ids.pop()]
        else:
            # mixed clusters (shouldn’t happen below cut) → black
            return "black"

    sch.dendrogram(
        Z,
        labels=labels,
        leaf_rotation=90,
        leaf_font_size=8,
        #truncate_mode='lastp',
        p=num_clusters,
        show_contracted=True,
        color_threshold=t_distance,
        above_threshold_color='black',
        ax=ax,
        link_color_func=link_color_func
    )
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.gca().spines['left'].set_color('dimgray')
    ax.spines['bottom'].set_visible(False)
    ax.set_xticklabels([]) # [str(label) for label in range(1, num_clusters + 1)]

    plt.title('Hierarchical Clustering Dendrogram with Clusters Highlighted')
    plt.xlabel('')
    plt.ylabel('Distance (Ward)')
    plt.tight_layout() # Adjust layout to prevent labels from overlapping

    plt.savefig(f"{save_path}.png", dpi=600, bbox_inches="tight")
    plt.savefig(f"{save_path}.pdf", bbox_inches="tight")
    plt.savefig(f"{save_path}.svg", bbox_inches="tight")
    plt.show()
def load_features(path):
    # Load precomputed features
    data = np.load(path, allow_pickle=True)
    return data["X"], data["files"]

def load_matrix(path):
    print("Loading an RMSD matrix...")
    X, pdb_files = load_features(path)
    print(f"Loaded. Matrix has {X.shape} shape.")
    print(f"With corresponding number of files {len(pdb_files)}")

    # Get the NaN values counts
    nan_count_per_row = np.isnan(X).sum(axis=1)
    nan_count_counts = pd.Series(nan_count_per_row).value_counts()
    print(nan_count_counts)

    # Find indices of files which failed
    nan_count_to_remove = max(nan_count_per_row)
    if nan_count_to_remove > 0:
        indices_to_remove = [idx for idx, count in enumerate(nan_count_per_row) if count == nan_count_to_remove]
        print(f"Files to remove due to high NaN count ({nan_count_to_remove} nans)")
    else:
        indices_to_remove = []

    # Remove these rows and columns
    print(f"Removing {len(indices_to_remove)} files with {nan_count_to_remove} NaN values in their respective row.")
    keep_indices = [i for i in range(len(pdb_files)) if i not in indices_to_remove]
    X = X[np.ix_(keep_indices, keep_indices)]
    print(f"Resulting matrix shape is {X.shape}")

    # Update the pdb_files list
    pdb_files = [pdb_files[i] for i in range(len(pdb_files)) if i not in indices_to_remove]
    print(f"Adjusted number of files {len(pdb_files)}")

    return X, pdb_files

def reduce_matrix(X, files, max_samples=5_000, float_32=True):
    # Get a subset so we can handle it computationally
    if X.shape[0] > max_samples:
        print(f"Reducing matrix size to the first {max_samples} samples")
        selected = np.arange(max_samples)
        X = X[np.ix_(selected, selected)]
        files = files[:max_samples]

    if float_32:
        print("Converting to float32")
        X = X.astype(np.float32) # and normalize if necessary

    return X, files

def filter_matrix(X, files, wanted_files):
    filter_indices = []
    not_found_files = []
    for filtered_file in wanted_files:
        try:
            filter_indices.append(files.index(filtered_file))
        except ValueError:
            not_found_files.append(filtered_file)
    print(f"{len(not_found_files)} were not found in the matrix")
    print(np.array(filter_indices).shape)

    X = X[np.ix_(filter_indices, filter_indices)]
    files = np.asarray(files)[filter_indices].tolist()

    return X, files
# Load and reduce projected coordinates "2D-RMSD" matrix
rmsd_X_projected, pdb_files_projected = load_matrix("svd_projected_max_stack_sample.npz") # load_matrix("svd_projected_all.npz") # TODO: svd_projected_max_stack.npz # load_matrix("rmsd_matrix_projected.npz")
# Filter on our new condition
helper_templates = pd.DataFrame([{"PDB_File": file} for file in pdb_files_projected[:170]])

#
stricter_paths_all = []
with open("fragment_pairs_filtered_five_two.txt", "r") as f:
    for line in f:
        name = line.split('/')[-1]
        name = name if name[-1] != '\n' else name[:-1] 
        stricter_paths_all.append({"PDB_File": name})
filter_helper_all = pd.DataFrame(stricter_paths_all)
filter_helper_all = filter_helper_all[filter_helper_all["PDB_File"].isin(set(repaired_df["PDB_File"]))]
#

temporary = pd.concat([helper_templates, filter_helper_all])
temporary = temporary.reset_index(drop=True)  # reset the index after concatenating templates and fragment pairs
temporary
# rmsd_X_projected, pdb_files_projected = reduce_matrix(rmsd_X_projected, pdb_files_projected, max_samples=100_000)
rmsd_X_projected, pdb_files_projected = filter_matrix(rmsd_X_projected, pdb_files_projected, temporary["PDB_File"].tolist())
Loading an RMSD matrix...
Loaded. Matrix has (10976, 10976) shape.
With corresponding number of files 10976
0    10976
Name: count, dtype: int64
Removing 0 files with 0 NaN values in their respective row.
Resulting matrix shape is (10976, 10976)
Adjusted number of files 10976
51145 were not found in the matrix
(10976,)
plt.figure(figsize=(5,3))

sns.kdeplot(
    rmsd_X_projected[np.triu_indices_from(rmsd_X_projected, k=1)].flatten(),
    color = "midnightblue",
    linewidth=2,
)

plt.xlim(-0.5, 10)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.gca().set_yticks([0, 0.3, 0.6])
plt.gca().set_xticks([0, 2, 4, 6, 8])
plt.grid(False)

plt.savefig("plots/rmsd_projected_kde_max_stack_svd.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/rmsd_projected_kde_max_stack_svd.pdf", bbox_inches="tight")
plt.savefig("plots/rmsd_projected_kde_max_stack_svd.svg", bbox_inches="tight")
plt.show()
# Load and reduce "3D-RMSD" matrix with max_stack superimposer
rmsd_X_max, pdb_files_max = load_matrix("rmsd_matrix_max_stack_svd.npz") # rmsd_matrix_max_stack.npz # TODO: rmsd_matrix_max_stack_svd.npz
# rmsd_X_max, pdb_files_max = reduce_matrix(rmsd_X_max, pdb_files_max, max_samples=100_000)
print(pdb_files_max[0], pdb_files_max[-1])
pdb_files_max = [file[26:] for file in pdb_files_max]  # strip the leading "dist_matrix_pdbs_five_two/" directory prefix

rmsd_X_max, pdb_files_max = filter_matrix(rmsd_X_max, pdb_files_max, temporary["PDB_File"].tolist())
Loading an RMSD matrix...
Loaded. Matrix has (11000, 11000) shape.
With corresponding number of files 11000
24       10976
10999       24
Name: count, dtype: int64
Files to remove due to high NaN count (10999 nans)
Removing 24 files with 10999 NaN values in their respective row.
Resulting matrix shape is (10976, 10976)
Adjusted number of files 10976
dist_matrix_pdbs_five_two/1yjo.pdb dist_matrix_pdbs_five_two/8az2_pair_113_Repair.pdb
51145 were not found in the matrix
(10976,)
plt.figure(figsize=(5,3))

sns.kdeplot(
    rmsd_X_max[np.triu_indices_from(rmsd_X_max, k=1)].flatten(),
    color = "midnightblue",
    linewidth=2,
)

plt.xlim(-0.5, 10)
plt.ylim(0, 0.6)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.gca().set_yticks([0, 0.3, 0.6])
plt.gca().set_xticks([0, 2, 4, 6, 8])
plt.grid(False)

plt.savefig("plots/rmsd_max_kde_svd.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/rmsd_max_kde_svd.pdf", bbox_inches="tight")
plt.savefig("plots/rmsd_max_kde_svd.svg", bbox_inches="tight")
plt.show()

First level

distance_matrix = DistanceMatrix(rmsd_X_projected)

# Perform PCoA
num_top_components = 2
pcoa_results = pcoa(distance_matrix, number_of_dimensions=3000)
results_df = pcoa_results.samples.iloc[:, :num_top_components].copy()
## Create column names for the DataFrame based on the number of top components
column_names = [f'PCoA_{i+1}' for i in range(num_top_components)]
results_df.columns = column_names
print_pcoa_explained_var(pcoa_results, num_top_components)

# UMAP
## Initialize UMAP with 'precomputed' metric
umap_model = umap.UMAP(n_neighbors=30, n_components=2, metric='precomputed', random_state=42, min_dist=0.05)
emb = umap_model.fit_transform(rmsd_X_projected)

num_templates_projected = sum('Repair' not in file for file in pdb_files_projected)
category = np.array(
    ["Templates"] * num_templates_projected + ["Extracted"] * (len(pdb_files_projected) - num_templates_projected)
)

plot_pca_and_umap(results_df, emb, category, {"Templates": "mediumorchid", "Extracted": "cornflowerblue"}, dot_size=10)
Raw Eigenvalues:
PC1    22017.579391
PC2    10216.876781
dtype: float64

Explained Variance Ratio:
PC1    0.141923
PC2    0.065857
dtype: float64

Cumulative Explained Variance:
PC1    0.141923
PC2    0.207780
dtype: float64
{'Extracted': 'Extracted (n=9988)', 'Templates': 'Templates (n=170)'}
Number of templates used 170
Error, trying to plot a df
components_explaining = num_components_explaining(percent_of_variance=90, pcoa_results=pcoa_results)
umapca_model = umap.UMAP(n_neighbors=15, n_components=2, random_state=42, min_dist=0.01) #, min_dist=0.1, spread=1)
umapca_emb = umapca_model.fit_transform(pcoa_results.samples.iloc[:, :components_explaining].values)

plot_umap(num_templates_projected, pdb_files_projected, umapca_emb, save_path="plots/rmsd_projected_umap", dot_size=10) # TODO: smaller dots
90% of variance is explained by the first 1595 components
# Perform hierarchical clustering
t_distance = 25
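# squareform() converts the full symmetric RMSD matrix into the condensed form expected by linkage();
# the `metric` argument is ignored when a condensed distance matrix is supplied.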
Z = sch.linkage(squareform(rmsd_X_projected), method='ward', metric='precomputed')

labels = sch.fcluster(Z, t=t_distance, criterion='distance')

num_clusters = len(np.unique(labels))
print(f"Found {num_clusters} number of clusters... labeled {len(labels)} structures")

cluster_palette = sns.color_palette("husl", num_clusters)  # or "hls"
plot_dendrogram(Z, pdb_files_projected, num_clusters, t_distance, save_path="plots/rmsd_projected_dendrogram_max_stack", custom_palette=cluster_palette)
Found 22 clusters... labeled 10158 structures
from scipy.cluster.hierarchy import to_tree

def convert_linkage_to_newick(Z, leaf_names):
    tree = to_tree(Z, rd=False)
    newick_str = get_newick(tree, Z, tree.dist, leaf_names)
    return newick_str

def get_newick(node, labels, parent_dist, leaf_names, newick=""):
    if node.is_leaf():
        return "%s:%.6f%s" % (leaf_names[node.id], parent_dist - node.dist, newick)
    else:
        if len(newick) > 0:
            newick = "):%.6f%s" % (parent_dist - node.dist, newick)
        else:
            newick = ");"
        newick = get_newick(node.get_left(), labels, node.dist, leaf_names, newick)
        newick = get_newick(node.get_right(), labels, node.dist, leaf_names, "," + newick)
        newick = "(" + newick
        return newick

newick_str = convert_linkage_to_newick(Z, labels) # FIXME

with open("dendrogram_projected.newick", "w") as f:
    f.write(newick_str)
import superpose_fragments
import projected_pairwise_rmsd_matrix
from Bio.PDB.qcprot import QCPSuperimposer

pdb_to_idx = {file: idx for idx, file in enumerate(pdb_files_projected)}

def calculate_summed_distance_within_cluster(row):
    current_pdb = row['PDB_File']
    current_cluster = row['Cluster']

    # Get all PDB files in the current cluster
    cluster_members_df = K[K['Cluster'] == current_cluster]
    cluster_member_pdbs = cluster_members_df['PDB_File'].tolist()

    # Get the index of the current and other cluster members' PDB files in the RMSD matrix
    current_pdb_idx = pdb_to_idx[current_pdb]
    cluster_member_indices = [pdb_to_idx[pdb] for pdb in cluster_member_pdbs]

    # Extract the relevant row from the RMSD matrix (distances from current PDB)
    distances_from_current_pdb = rmsd_X_projected[current_pdb_idx, :] # FIXME
    summed_distance = distances_from_current_pdb[cluster_member_indices].sum()
    # summed_distance_squared = (distances_from_current_pdb[cluster_member_indices]**2).sum() # TODO
    # if i want to know which files distort the cluster
    return summed_distance

def superpose_cluster(cluster_num, distance_matrix, projected_alignment=False, K_labeled=K):
    # Get the indices of the PDB files in the target cluster
    cluster_rows = K_labeled[K_labeled['Cluster'] == cluster_num].sort_values(by='Summed_Distances_Within_Cluster', ascending=True)
    cluster_indices = cluster_rows['PDB_File'].apply(lambda x: pdb_to_idx[x]).tolist()
    cluster_rmsds_submatrix = distance_matrix[np.ix_(cluster_indices, cluster_indices)]
    # Flatten the upper triangular part
    upper_triangle_indices = np.triu_indices_from(cluster_rmsds_submatrix, k=1) # FIXME: can be rectangular not a square
    all_pairwise_cluster_distances = cluster_rmsds_submatrix[upper_triangle_indices]
    # Get all the cluster paths and superpose
    paths = []
    for file in list(cluster_rows['PDB_File'].values):
        file = file.split('/')[-1]
        if file.split('_')[-1] == "Repair.pdb":
            paths.append(f"./fragment_pairs/{file}")
        else:
            paths.append(f"./templates/{file}")
    if not projected_alignment:
        rmsds, super_opts, aligned_coords = superpose_fragments.superpose_all(paths, output_path=f"superposed_cluster_{cluster_num}.cif")
    else:
        center_projected_coords, _ = projected_pairwise_rmsd_matrix.get_projected_coords_and_std(paths[0])
        superimposer = QCPSuperimposer()
        aligned_coords = [center_projected_coords]
        rmsds = []

        for path in paths[1:]:
            projected_coords, _ = projected_pairwise_rmsd_matrix.get_projected_coords_and_std(path)
            if projected_coords is None: # FIXME: why do some templates fail here?
                print(f"{projected_coords} of {path}")
                continue
            
            possible_aligns = projected_pairwise_rmsd_matrix.generate_possible_alignments_test(projected_coords)

            best_rmsd = float('inf')
            best_variant = None
            for align in possible_aligns:
                superimposer.set(center_projected_coords, align)
                superimposer.run()
                rmsd = superimposer.get_rms()
                
                if rmsd < best_rmsd:
                    best_rmsd = rmsd
                    best_variant = superimposer.get_transformed()

            rmsds.append(best_rmsd)
            aligned_coords.append(best_variant)

    print(f"Number of structures in cluster {cluster_num}: {len(cluster_rows)}")
    print(f"Mean RMSD within cluster: {np.mean(all_pairwise_cluster_distances):.2f} Å")
    print(f"Standard deviation of RMSDs within cluster: {np.std(all_pairwise_cluster_distances):.2f} Å")
    print(f"Standard deviation of RMSDs superposition to centroid structure: {np.std(rmsds):.2f} Å")
    return aligned_coords

def plot_projected_cluster(cluster_num, projected_alignment=True, X=rmsd_X_projected, K_labeled=K, save_path="plots/cluster_test"):
    aligned_coords = superpose_cluster(cluster_num, X, projected_alignment=projected_alignment, K_labeled=K_labeled)
    aligned_coords = np.array(aligned_coords)
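    # Each aligned structure is represented by 12 projected points (2 hexapeptides x 6 residues);
    # reshape to (n_structures, 12, 3) so the mean/std can be taken point-wise across structures.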
    avg_layer_coords = aligned_coords.reshape(-1, 12, 3)
    mean_coords = np.mean(avg_layer_coords, axis=0)
    std_coords = np.std(avg_layer_coords, axis=0)

    centered_mean_coords = mean_coords - np.mean(mean_coords, axis=0)
    pca = PCA(n_components=2)
    pca.fit(centered_mean_coords)
    projected_coords = pca.transform(centered_mean_coords)

    print(f"Total variance explained by best-fit plane: ~{np.sum(pca.explained_variance_ratio_):.2f}")
    print(f"Variance explained by PC1: {pca.explained_variance_ratio_[0]:.2f}")
    print(f"Variance explained by PC2: {pca.explained_variance_ratio_[1]:.2f}")

    projected_stdev_x = np.dot(std_coords, pca.components_[0]) # Project stdev along PC1
    projected_stdev_y = np.dot(std_coords, pca.components_[1]) # Project stdev along PC2
    projected_stdev_x_abs = np.abs(projected_stdev_x)
    projected_stdev_y_abs = np.abs(projected_stdev_y)

    fig_2d_pca = plt.figure(figsize=(6, 4))
    ax_2d_pca = fig_2d_pca.add_subplot(111)

    # Plot the mean trace on the best-fit plane
    ax_2d_pca.plot(projected_coords[:6, 0], 
                    projected_coords[:6, 1], 
                    'o-', label='Mean Trace (Hexapeptide 1)', color='black', linewidth=10, markersize=6)
    ax_2d_pca.plot(projected_coords[6:, 0], 
                    projected_coords[6:, 1], 
                    'o-', label='Mean Trace (Hexapeptide 2)', color='black', linewidth=10, markersize=6)
    # Add "blurriness" as error bars
    # Use projected_stdev_x_abs and projected_stdev_y_abs for error bars along the new axes
    ax_2d_pca.errorbar(projected_coords[:, 0], projected_coords[:, 1],
                       xerr=projected_stdev_x_abs, yerr=projected_stdev_y_abs,
                       fmt='none', capsize=3, color='grey', alpha=0.5, zorder=0)

    ax_2d_pca.axes.get_xaxis().set_visible(False)
    ax_2d_pca.axes.get_yaxis().set_visible(False)
    # remove the box
    for spine in ax_2d_pca.spines.values():
        spine.set_visible(False)
    ax_2d_pca.grid(False)
    ax_2d_pca.set_aspect('equal', adjustable='box') # Keep aspect ratio for true shape
    plt.tight_layout()

    plt.savefig(f"{save_path}_{cluster_num}.png", dpi=300, bbox_inches="tight")
    plt.savefig(f"{save_path}_{cluster_num}.pdf", bbox_inches="tight")
    plt.savefig(f"{save_path}_{cluster_num}.svg", bbox_inches="tight")
    plt.show()

K = pd.DataFrame({
    'PDB_File': pdb_files_projected,
    'Cluster': labels,
    'Category': category
})

# Apply the function to each row
K['Summed_Distances_Within_Cluster'] = K.apply(calculate_summed_distance_within_cluster, axis=1)
cluster_num = 9
# superpose_cluster(cluster_num, rmsd_X_projected)
plot_projected_cluster(cluster_num, True, rmsd_X_projected, K)
Number of structures in cluster 9: 500
Mean RMSD within cluster: 1.49 Å
Standard deviation of RMSDs within cluster: 0.44 Å
Standard deviation of RMSDs superposition to centroid structure: 0.36 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.82
Variance explained by PC2: 0.18
# If needed remove malicious cluster and recluster
#wrong_cluster_num = 20
#K = K[K['Cluster'] != wrong_cluster_num]
#rmsd_X_projected, pdb_files_projected = filter_matrix(rmsd_X_projected, pdb_files_projected, K['PDB_File'].to_list())
0 were not found in the pdb_files_projected matrix
(10158,)
for i in range(max(labels)):
    plot_projected_cluster(i+1, True, rmsd_X_projected, K) #, save_path="plots/projected_clusters/projected_cluster")
Number of structures in cluster 1: 569
Mean RMSD within cluster: 2.72 Å
Standard deviation of RMSDs within cluster: 0.73 Å
Standard deviation of RMSDs superposition to centroid structure: 0.46 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.92
Variance explained by PC2: 0.08
Number of structures in cluster 2: 331
Mean RMSD within cluster: 1.63 Å
Standard deviation of RMSDs within cluster: 0.56 Å
Standard deviation of RMSDs superposition to centroid structure: 0.37 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.94
Variance explained by PC2: 0.06
Number of structures in cluster 3: 230
Mean RMSD within cluster: 2.35 Å
Standard deviation of RMSDs within cluster: 0.76 Å
Standard deviation of RMSDs superposition to centroid structure: 0.56 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.95
Variance explained by PC2: 0.05
Number of structures in cluster 4: 620
Mean RMSD within cluster: 1.58 Å
Standard deviation of RMSDs within cluster: 0.54 Å
Standard deviation of RMSDs superposition to centroid structure: 0.32 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.83
Variance explained by PC2: 0.17
Number of structures in cluster 5: 352
Mean RMSD within cluster: 1.64 Å
Standard deviation of RMSDs within cluster: 0.59 Å
Standard deviation of RMSDs superposition to centroid structure: 0.40 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.81
Variance explained by PC2: 0.19
Number of structures in cluster 6: 486
Mean RMSD within cluster: 2.44 Å
Standard deviation of RMSDs within cluster: 0.60 Å
Standard deviation of RMSDs superposition to centroid structure: 0.41 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.88
Variance explained by PC2: 0.12
Number of structures in cluster 7: 473
Mean RMSD within cluster: 2.55 Å
Standard deviation of RMSDs within cluster: 0.71 Å
Standard deviation of RMSDs superposition to centroid structure: 0.50 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.80
Variance explained by PC2: 0.20
Number of structures in cluster 8: 732
Mean RMSD within cluster: 2.50 Å
Standard deviation of RMSDs within cluster: 0.60 Å
Standard deviation of RMSDs superposition to centroid structure: 0.49 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.85
Variance explained by PC2: 0.15
Number of structures in cluster 9: 590
Mean RMSD within cluster: 2.57 Å
Standard deviation of RMSDs within cluster: 0.65 Å
Standard deviation of RMSDs superposition to centroid structure: 0.49 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.61
Variance explained by PC2: 0.39
Number of structures in cluster 10: 382
Mean RMSD within cluster: 2.58 Å
Standard deviation of RMSDs within cluster: 0.63 Å
Standard deviation of RMSDs superposition to centroid structure: 0.44 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.67
Variance explained by PC2: 0.33
Number of structures in cluster 11: 280
Mean RMSD within cluster: 2.58 Å
Standard deviation of RMSDs within cluster: 0.63 Å
Standard deviation of RMSDs superposition to centroid structure: 0.42 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.73
Variance explained by PC2: 0.27
Number of structures in cluster 12: 423
Mean RMSD within cluster: 2.39 Å
Standard deviation of RMSDs within cluster: 0.57 Å
Standard deviation of RMSDs superposition to centroid structure: 0.43 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.62
Variance explained by PC2: 0.38
Number of structures in cluster 13: 500
Mean RMSD within cluster: 2.06 Å
Standard deviation of RMSDs within cluster: 0.62 Å
Standard deviation of RMSDs superposition to centroid structure: 0.42 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.60
Variance explained by PC2: 0.40
Number of structures in cluster 14: 294
Mean RMSD within cluster: 2.46 Å
Standard deviation of RMSDs within cluster: 0.73 Å
Standard deviation of RMSDs superposition to centroid structure: 0.52 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.57
Variance explained by PC2: 0.43
Number of structures in cluster 15: 821
Mean RMSD within cluster: 1.80 Å
Standard deviation of RMSDs within cluster: 0.59 Å
Standard deviation of RMSDs superposition to centroid structure: 0.46 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.80
Variance explained by PC2: 0.20
Number of structures in cluster 16: 271
Mean RMSD within cluster: 2.04 Å
Standard deviation of RMSDs within cluster: 0.62 Å
Standard deviation of RMSDs superposition to centroid structure: 0.44 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.80
Variance explained by PC2: 0.20
Number of structures in cluster 17: 329
Mean RMSD within cluster: 1.74 Å
Standard deviation of RMSDs within cluster: 0.66 Å
Standard deviation of RMSDs superposition to centroid structure: 0.42 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.67
Variance explained by PC2: 0.33
Number of structures in cluster 18: 294
Mean RMSD within cluster: 1.31 Å
Standard deviation of RMSDs within cluster: 0.43 Å
Standard deviation of RMSDs superposition to centroid structure: 0.29 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.70
Variance explained by PC2: 0.30
Number of structures in cluster 19: 633
Mean RMSD within cluster: 1.57 Å
Standard deviation of RMSDs within cluster: 0.53 Å
Standard deviation of RMSDs superposition to centroid structure: 0.30 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.56
Variance explained by PC2: 0.44
Number of structures in cluster 20: 561
Mean RMSD within cluster: 2.41 Å
Standard deviation of RMSDs within cluster: 0.64 Å
Standard deviation of RMSDs superposition to centroid structure: 0.46 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.67
Variance explained by PC2: 0.33
Number of structures in cluster 21: 313
Mean RMSD within cluster: 1.11 Å
Standard deviation of RMSDs within cluster: 0.41 Å
Standard deviation of RMSDs superposition to centroid structure: 0.28 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.73
Variance explained by PC2: 0.27
Number of structures in cluster 22: 674
Mean RMSD within cluster: 1.60 Å
Standard deviation of RMSDs within cluster: 0.53 Å
Standard deviation of RMSDs superposition to centroid structure: 0.31 Å
Total variance explained by best-fit plane: ~1.00
Variance explained by PC1: 0.63
Variance explained by PC2: 0.37
def plot_labeled_umap(df):
    fig = plt.figure(figsize=(10, 6))

    sns.scatterplot(data=df, x='UMAP_1', y='UMAP_2',
                    hue='Cluster', s=5, alpha=1, palette=cluster_palette)

    plt.title('UMAP Projection of Sampled and Template Fragments')
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')
    plt.legend(title='Clusters')
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.grid(False)
    #plt.ylim(-0.5, 7)
    plt.gca().spines['left'].set_color('dimgray')
    plt.gca().spines['bottom'].set_color('dimgray')
    #plt.gca().set_yticks([0, 2, 4, 6])

    plt.tight_layout()
    plt.savefig("plots/rmsd_projected_labeled_umap.png", dpi=600, bbox_inches="tight")
    plt.savefig("plots/rmsd_projected_labeled_umap.pdf", bbox_inches="tight")
    plt.savefig("plots/rmsd_projected_labeled_umap.svg", bbox_inches="tight")
    plt.show()

umapca_df = pd.DataFrame({
'PDB_File': pdb_files_projected,
'UMAP_1': umapca_emb[:, 0],
'UMAP_2': umapca_emb[:, 1]
})
K_merged = pd.merge(K, umapca_df, on="PDB_File")
plot_labeled_umap(K_merged)
unique_cluster_names = K['Cluster'].unique()
cluster_order_numeric_sorted_str = sorted(unique_cluster_names, key=int)
K['Cluster'] = pd.Categorical(K['Cluster'], categories=cluster_order_numeric_sorted_str, ordered=True)

# Calculate value counts based on the ordered categorical 'Cluster'
# This will now respect the order defined in the categorical type
value_counts_ordered = K['Cluster'].value_counts().sort_index()
#print([f"Cluster {i+1}, {val}" for i, val in enumerate(value_counts_ordered.values)])
# Prepare data for "Template Fragments"
stacked_counts = K.groupby(['Cluster', 'Category']).size().reset_index(name='counts')
stacked_counts = stacked_counts[stacked_counts['Category'] == "Templates"]
#print([f"Cluster {i+1}, {val}" for i, val in enumerate(stacked_counts['counts'].values)])

plt.figure(figsize=(25, 2))
# Plot the total cluster sizes, now ordered by cluster number
ax = sns.barplot(x=value_counts_ordered.index, y=value_counts_ordered.values, color='midnightblue')
# Plot the "Template Fragments" counts, which will also be ordered by cluster number
#ax = sns.barplot(x='Cluster', y='counts', color='mediumorchid', data=stacked_counts) # TODO: purple distribution (kde)
tmp = stacked_counts.set_index('Cluster')['counts']
x_template = np.repeat([i-1 for i in tmp.index.values], tmp.values)
x_template_num = pd.Series(x_template).astype("int64").to_numpy(dtype=float)

# Match x-limits to the clusters
xmin, xmax = value_counts_ordered.index.min(), value_counts_ordered.index.max()
ax.set_xlim(-1, xmax)

# --- KDE on a secondary y-axis ---
ax2 = ax.twinx()
sns.kdeplot(
    x=x_template_num, ax=ax2, fill=True, alpha=0.3,
    bw_adjust=0.1, color='mediumorchid', clip=(xmin - 2, xmax + 2)
)

plt.xlabel('Cluster')
plt.ylabel('Count')
ax.grid(False)
ax2.grid(False)
# remove spines
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_color('dimgray')
ax2.spines['left'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['bottom'].set_visible(False)
ax2.spines['right'].set_color('dimgray')
ax2.set_ylabel('Density')
plt.xticks(rotation=90)
plt.tight_layout()

plt.savefig("plots/cluster_counts_projected_bar.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/cluster_counts_projected_bar.pdf", bbox_inches="tight")
plt.savefig("plots/cluster_counts_projected_bar.svg", bbox_inches="tight")
plt.show()
# Plot a violin plot of total energies within the cluster
total_e_plotting_df = []

for cluster in cluster_order_numeric_sorted_str:
    cluster_rows = K[K['Cluster'] == cluster]
    
    for pdb_file in cluster_rows['PDB_File'].tolist():
        val_repaired = repaired_df.loc[repaired_df['PDB_File'] == pdb_file, 'Total_Energy']
        val_template = template_df.loc[template_df['PDB_File'] == pdb_file, 'Total_Energy']

        if not val_repaired.empty:
            total_energy = val_repaired.iloc[0]
        elif not val_template.empty:
            total_energy = val_template.iloc[0]
        else:
            print(f"⚠️ No Total_Energy found for {pdb_file}")
            total_energy = None
        total_e_plotting_df.append({'Cluster': cluster, 'Total Energy': total_energy})

total_e_plotting_df = pd.DataFrame(total_e_plotting_df)

plt.figure(figsize=(10, 6))
sns.violinplot(
    x='Cluster',
    y='Total Energy',
    data=total_e_plotting_df,
    palette=cluster_palette,
    inner='quartile',
    width=0.7,
    cut=0,
)

plt.xlabel('Cluster', fontsize=12)
plt.ylabel('Total Energy', fontsize=12)
plt.title('Distribution of Total Energies within Clusters', fontsize=14)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

#plt.ylim(-0.5, 7) # TODO
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_visible(False)
#plt.gca().set_yticks([0, 2, 4, 6]) # TODO

plt.grid(axis='y', linestyle='--', alpha=0.7) #, linewidth=2)
plt.tight_layout()

plt.savefig("plots/total_energy_within_clusters.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/total_energy_within_clusters.pdf", bbox_inches="tight")
plt.savefig("plots/total_energy_within_clusters.svg", bbox_inches="tight")
plt.show()
pdb_to_idx = {file: idx for idx, file in enumerate(pdb_files_projected)}
rmsd_plotting_df = []

for cluster in cluster_order_numeric_sorted_str:
    cluster_rows = K[K['Cluster'] == cluster]
    cluster_indices = cluster_rows['PDB_File'].apply(lambda x: pdb_to_idx[x]).tolist()
    cluster_rmsds_submatrix = rmsd_X_projected[np.ix_(cluster_indices, cluster_indices)]
    # Flatten the upper triangular part
    upper_triangle_indices = np.triu_indices_from(cluster_rmsds_submatrix, k=1)
    all_pairwise_cluster_distances = cluster_rmsds_submatrix[upper_triangle_indices]

    rmsds_test=[]
    for dist in all_pairwise_cluster_distances:
        if dist == 0:
            continue
        rmsds_test.append(dist)
        rmsd_plotting_df.append({'Cluster': cluster, 'RMSD': dist})

rmsd_plotting_df = pd.DataFrame(rmsd_plotting_df)

plt.figure(figsize=(10, 6))
sns.violinplot(
    x='Cluster',
    y='RMSD',
    data=rmsd_plotting_df,
    palette=cluster_palette,
    inner='quartile',
    width=0.7,
    cut=0,
)

plt.xlabel('Cluster', fontsize=12)
plt.ylabel('Pairwise RMSD', fontsize=12)
plt.title('Distribution of Pairwise RMSDs within Clusters', fontsize=14)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

plt.ylim(-0.5, 7)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_visible(False)
plt.gca().set_yticks([0, 2, 4, 6])

plt.grid(axis='y', linestyle='--', alpha=0.7) #, linewidth=2)
plt.tight_layout()

plt.savefig("plots/rmsd_projected_within_cluster_violin.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/rmsd_projected_within_cluster_violin.pdf", bbox_inches="tight")
plt.savefig("plots/rmsd_projected_within_cluster_violin.svg", bbox_inches="tight")
plt.show()

Second level

len(results_df)
len(pdb_files_max)
10976
distance_matrix = DistanceMatrix(rmsd_X_max) # FIXME: MAX_STACK SUPERIMPOSER DOESNT GIVE US A `DISTANCE` MATRIX

# Perform PCoA
num_top_components = 2
pcoa_results = pcoa(distance_matrix, number_of_dimensions=3000)
results_df = pcoa_results.samples.iloc[:, :num_top_components].copy()
## Create column names for the DataFrame based on the number of top components
column_names = [f'PCoA_{i+1}' for i in range(num_top_components)]
results_df.columns = column_names
print_pcoa_explained_var(pcoa_results, num_top_components)


# UMAP
## Initialize UMAP with 'precomputed' metric
umap_model = umap.UMAP(n_neighbors=30, n_components=2, metric='precomputed', random_state=42)
emb = umap_model.fit_transform(rmsd_X_max)

num_templates_max = sum('Repair' not in file for file in pdb_files_max)
category_spatial = np.array(
    ["Templates"] * num_templates_max + ["Extracted"] * (len(pdb_files_max) - num_templates_max)
)

plot_pca_and_umap(results_df, emb, category_spatial, {"Templates": "mediumorchid", "Extracted": "cornflowerblue"}, dot_size=10)
Raw Eigenvalues:
PC1    25362.739710
PC2    15010.507916
dtype: float64

Explained Variance Ratio:
PC1    0.153135
PC2    0.090630
dtype: float64

Cumulative Explained Variance:
PC1    0.153135
PC2    0.243765
dtype: float64
{'Extracted': 'Extracted (n=10806)', 'Templates': 'Templates (n=170)'}
Number of templates used 170
Error, trying to plot a df
components_explaining = num_components_explaining(percent_of_variance=90, pcoa_results=pcoa_results)
umapca_model = umap.UMAP(n_neighbors=30, n_components=2, random_state=42, min_dist=0.1, metric='precomputed') #, min_dist=0.1, spread=1)
umapca_emb = umapca_model.fit_transform(rmsd_X_max) # FIXME: pcoa_results.samples.iloc[:, :components_explaining].values)

plot_umap(num_templates_max, pdb_files_max, umapca_emb, save_path="plots/rmsd_max_umap", dot_size=10)
90% of variance is explained by the first 1025 components
pdb_to_idx = {file: idx for idx, file in enumerate(pdb_files_max)}

def calculate_summed_distance_within_cluster_spatial(row):
    current_pdb = row['PDB_File']
    current_cluster = row['Cluster']

    # Get all PDB files in the current cluster
    cluster_members_df = K_spatial[K_spatial['Cluster'] == current_cluster]
    cluster_member_pdbs = cluster_members_df['PDB_File'].tolist()

    # Get the index of the current and other cluster members' PDB files in the RMSD matrix
    current_pdb_idx = pdb_to_idx[current_pdb]
    cluster_member_indices = [pdb_to_idx[pdb] for pdb in cluster_member_pdbs]

    # Extract the relevant row from the RMSD matrix (distances from current PDB)
    distances_from_current_pdb = rmsd_X_max[current_pdb_idx, :] # FIXME (rmsd_X_max or rmsd_X_projected)
    summed_distance = distances_from_current_pdb[cluster_member_indices].sum()
    # summed_distance_squared = (distances_from_current_pdb[cluster_member_indices]**2).sum() # TODO
    # if i want to know which files distort the cluster
    return summed_distance

# Perform hierarchical clustering
t_distance_spatial = 25
Z_spatial = sch.linkage(squareform(rmsd_X_max), method='ward')  # condensed distances as input; no metric argument needed

labels_spatial = sch.fcluster(Z_spatial, t=t_distance_spatial, criterion='distance')

num_clusters_spatial = len(np.unique(labels_spatial))
print(f"Found {num_clusters_spatial} number of clusters... labeled {len(labels_spatial)} structures")

K_spatial = pd.DataFrame({
    'PDB_File': pdb_files_max,
    'Cluster': labels_spatial,
    'Category': category_spatial
})
# Apply the function to each row
K_spatial['Summed_Distances_Within_Cluster'] = K_spatial.apply(calculate_summed_distance_within_cluster_spatial, axis=1)

cluster_palette = sns.color_palette("husl", num_clusters_spatial)
plot_dendrogram(Z_spatial, pdb_files_max, num_clusters_spatial, t_distance_spatial, save_path="plots/rmsd_max_dendrogram", custom_palette=cluster_palette)
Found 29 clusters... labeled 10976 structures
def calculate_summed_distance_within_cluster_spatial_sub(row):
    current_pdb = row['PDB_File']
    current_cluster = row['Cluster']

    # Get all PDB files in the current cluster
    cluster_members_df = K_hier[K_hier['Cluster'] == current_cluster]
    cluster_member_pdbs = cluster_members_df['PDB_File'].tolist()

    # Get the index of the current and other cluster members' PDB files in the RMSD matrix
    current_pdb_idx = pdb_to_idx[current_pdb]
    cluster_member_indices = [pdb_to_idx[pdb] for pdb in cluster_member_pdbs]

    # Extract the relevant row from the RMSD matrix (distances from current PDB)
    distances_from_current_pdb = rmsd_X_projected[current_pdb_idx, :] # FIXME
    summed_distance = distances_from_current_pdb[cluster_member_indices].sum()
    # summed_distance_squared = (distances_from_current_pdb[cluster_member_indices]**2).sum() # TODO
    # if i want to know which files distort the cluster
    return summed_distance

def hierarchical_dendrogram(t_dist, cluster_num, cluster_df, pdb_files, X, plot=True):
    # Perform hierarchical clustering and return sub-cluster objects    
    if cluster_num > cluster_df['Cluster'].max() or cluster_num <= 0:
        print(f"Pick correct cluster number... not {cluster_num}")
        return
    
    sub_cluster_df = cluster_df[cluster_df['Cluster'] == cluster_num]
    query_files = sub_cluster_df['PDB_File'].to_numpy()
    pdb_files = np.asarray([file.split('/')[-1] for file in pdb_files])
    
    X_sub = X.copy()
    mask = np.isin(pdb_files, query_files)
    selected = np.flatnonzero(mask) 
    selected_files = pdb_files[selected]

    X_sub = X_sub[np.ix_(selected, selected)]

    Z_sub = sch.linkage(squareform(X_sub), method='ward')  # condensed distances as input; no metric argument needed
    labels_sub = sch.fcluster(Z_sub, t=t_dist, criterion='distance')

    num_clusters_sub = len(np.unique(labels_sub))
    print(f"Found {num_clusters_sub} number of clusters... labeled {len(labels_sub)} structures")

    cluster_palette_sub = sns.color_palette("husl", num_clusters_sub)
    if plot:
        plot_dendrogram(Z_sub, selected_files, num_clusters_sub, t_dist, save_path=f"plots/rmsd_max_of_projected_cluster_{cluster_num}_dendrogram", custom_palette=cluster_palette_sub)

    K_hier = pd.DataFrame({
        'PDB_File': selected_files,
        'Cluster': labels_sub,
    })
    return K_hier, X_sub

# Get a clustering of subcluster number 8 with ward threshold 10
subcluster_num = 8
K_hier, X_sub = hierarchical_dendrogram(10, subcluster_num, K, pdb_files_max, rmsd_X_max)
pdb_to_idx = {file: idx for idx, file in enumerate(K_hier['PDB_File'])}
# Apply the function to each row
K_hier['Summed_Distances_Within_Cluster'] = K_hier.apply(calculate_summed_distance_within_cluster_spatial_sub, axis=1)
Found 14 clusters... labeled 732 structures
# Plot each 3D-clustered subcluster in 2D
for i in range(len(K_hier['Cluster'].unique())):
    plot_projected_cluster(i+1, projected_alignment=False, X=X_sub, K_labeled=K_hier) #, save_path=f"plots/max_clusters/sub{subcluster_num}_max_cluster")

And lastly, let's look at the distribution of distances between the two stacks (the paired β-sheets), comparing templates against extracted fragments.

from pairwise_rmsd_matrix import get_atoms_from_pdb
from pathlib import Path

def get_distances_betw_positions(path):
    atoms = get_atoms_from_pdb(path)

    coords = np.array([atom.get_coord() for atom in atoms])
    if coords.shape[0] != 60:
        print(f"Expected 60 atoms, got {coords.shape[0]} atoms in {path}")
        return None

    reshaped_coords = coords.reshape(5, 12, 3)  # (layers, positions, xyz)

    # two stacks
    stack_a = reshaped_coords[:, :6, :]   # (5, 6, 3)
    stack_b = reshaped_coords[:, 6:12, :] # (5, 6, 3)

    # distances (per layer, per position)
    dists = np.linalg.norm(stack_a - stack_b, axis=-1)  # (5, 6)

    # average across layers
    avg_dists = dists.mean(axis=0)  # (6,)

    return avg_dists

def get_rows_w_dists(files, directory):
    # Collect results
    rows = []
    skipped_files = []

    for file in files:
        file_path = Path(directory) / file
        if not file_path.exists():
            print(f"Warning: File not found: {file_path}. Skipping.")
            skipped_files.append(file)
            continue

        dists = get_distances_betw_positions(file_path)
        if dists is None:
            print(f"Warning: Could not process file {file_path}. Skipping.")
            skipped_files.append(file)
            continue

        row = {"PDB_File": file}
        for i, d in enumerate(dists):
            row[f"Dist_{i}"] = d
        rows.append(row)
        
    print(f"Skipped these files: {skipped_files}")
    return rows

dist_rows = []
#dist_rows.extend(get_rows_w_dists(pdb_files_projected[:170], "templates"))
#dist_rows.extend(get_rows_w_dists(pdb_files_projected[170:], "fragment_pairs"))
dist_rows.extend(get_rows_w_dists(pdb_files_projected, "dist_matrix_pdbs_five_two"))

# Build DataFrame
df_dists = pd.DataFrame(dist_rows)
Skipped these files: []
dist_cols = [f"Dist_{i}" for i in range(6)]

df_dists["Dist_avg"] = df_dists[dist_cols].mean(axis=1)
df_dists["Dist_min"] = df_dists[dist_cols].min(axis=1)

df_dists['Category'] = 'Extracted'
df_dists.loc[:169, "Category"] = "Templates"

df_dists.head()
PDB_File Dist_0 Dist_1 Dist_2 Dist_3 Dist_4 Dist_5 Dist_avg Dist_min Category
0 1yjo.pdb 10.991614 10.187029 10.138923 9.859361 9.997387 10.763280 10.322932 9.859361 Templates
1 1yjp_1.pdb 9.255436 9.483084 8.624692 8.633715 9.315870 10.204151 9.252825 8.624692 Templates
2 1yjp_2.pdb 10.771128 10.031431 9.996222 9.763259 9.832519 10.723134 10.186282 9.763259 Templates
3 2kib_1.pdb 10.302935 10.160251 10.290537 10.448471 10.461041 10.956944 10.436696 10.160251 Templates
4 2kib_2.pdb 10.161047 10.350741 10.451340 10.540387 10.783261 10.989037 10.545968 10.161047 Templates
plt.figure(figsize=(8, 6))

sns.violinplot(
    x="Category", y="Dist_avg",
    data=df_dists, inner="quartile",
    palette={"Extracted": "cornflowerblue", "Templates": "mediumorchid"},
    alpha=0.8, saturation=1,
    cut=0
)

plt.xlabel("Fragment Type")
plt.ylabel("Average Distance")
plt.title("Comparison of Average Distance Between Template and Extracted Fragments")

# Style adjustments
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_color('dimgray')
plt.grid(False)

# Legend
ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
if handles:  # violinplot without hue might not create legend
    ax.legend(handles, labels, title="Fragment Type")

plt.savefig("plots/avg_dist_comparison.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/avg_dist_comparison.pdf", bbox_inches="tight")
plt.savefig("plots/avg_dist_comparison.svg", bbox_inches="tight")
plt.show()

Extract the chains from each original protein (where an interaction was found) and perform sequence clustering with CD-HIT.

import pandas as pd
import ast

# From each raw interacting pair, extract full sequences of those two interacting chains
file_chain_ids = pd.read_csv("results/file_chain_ids.csv")
print(f"Extracted chain IDs from a total of {len(file_chain_ids)} files.")

# Remove intrachain interactions
file_chain_ids['chain_ids'] = file_chain_ids['chain_ids'].apply(ast.literal_eval)
file_chain_ids = file_chain_ids[file_chain_ids['chain_ids'].apply(lambda x: len(x) == 10)] # Removing intrachain interactions
# Some files can even have between 5 and 10 chain IDs (they start as intrachain, and the subsequent layer expansion doesn't account for this);
# the expansion creates new layers, each with its own two chains
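# Sketch (not part of the pipeline): the chain-ID count distribution could be inspected on the raw
# table before filtering, e.g.
#   pd.read_csv("results/file_chain_ids.csv")['chain_ids'].apply(ast.literal_eval).apply(len).value_counts()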
file_chain_ids.reset_index(drop=True, inplace=True)
print(f"After rejecting intrachain pairs we have {len(file_chain_ids)} files (with interchain interactions).")
file_chain_ids.head()
Extracted chain IDs from a total of 61951 files.
After rejecting intrachain pairs we have 18829 files (with interchain interactions).
file chain_ids
0 8v1n_pair_342_Repair.pdb [C, D, G, H, K, L, O, P, Q, R]
1 8ot4_pair_130_Repair.pdb [C, D, E, F, G, H, I, J, K, L]
2 2m4j_pair_183_Repair.pdb [A, B, D, E, G, H, J, K, L, M]
3 8otj_pair_284_Repair.pdb [A, B, C, D, E, F, G, H, I, J]
4 7qkv_pair_797_Repair.pdb [A, C, E, F, H, I, L, M, N, O]
import pickle as pkl

with open("results/chain_sequences.pkl", "rb") as f:
    chain_sequences = pkl.load(f)
len(chain_sequences.keys())
482
import pandas as pd
import subprocess
from pathlib import Path
cdhit_exec = Path("./cdhit/cd-hit")

def run_cd_hit(sequences, output_prefix="cdhit_out", identity=0.9, word_size=5, min_seq_length=6):
    """
    Run CD-HIT on a list of sequences.
    Returns a dict mapping sequence -> cluster_id.
    """
    # Write sequences to fasta
    fasta_file = Path(f"{output_prefix}.fasta")
    with open(fasta_file, "w") as f:
        for i, seq in enumerate(sequences):
            f.write(f">seq{i}\n{seq}\n")
    
    # Run CD-HIT
    subprocess.run([
        str(cdhit_exec),
        "-i", str(fasta_file),
        "-o", f"{output_prefix}.clstr_out",
        "-c", str(identity),       # sequence identity threshold (0–1.0)
        "-n", str(word_size),      # word length (depends on identity cutoff)
        "-l", str(min_seq_length), # minimal sequence length (default 6)
    ], check=True)

    # Parse cluster output
    cluster_map = {}
    cluster_id = -1
    with open(f"{output_prefix}.clstr_out.clstr") as f:
        for line in f:
            if line.startswith(">Cluster"):
                cluster_id += 1
            else:
                # Extract sequence index from line
                idx = int(line.split(">seq")[1].split("...")[0])
                cluster_map[idx] = cluster_id

    return cluster_map

# Example usage
if __name__ == "__main__":
    # Prepare sequences
    all_sequences = []
    seq_index_map = []  # keep track of (seq_id, pdb_id, chain_id)

    for protein_key in chain_sequences.keys():
        for chain_id, seq in chain_sequences[protein_key].items():
            seq_id = len(all_sequences)
            all_sequences.append(seq)
            seq_index_map.append((seq_id, protein_key, chain_id))

    print(f"{len(all_sequences)} sequences collected. {len(set(all_sequences))} are unique.")

    # Run CD-HIT
    cluster_map = run_cd_hit(all_sequences, output_prefix="myclusters", identity=0.9, word_size=5, min_seq_length=6)

    # Map clusters back to dataframe
    seq_labels = []
    for seq_id, protein_key, chain_id in seq_index_map:
        cluster_id = cluster_map.get(seq_id, -1)
        seq_labels.append((protein_key, chain_id, cluster_id))

    seq_labels_df = pd.DataFrame(seq_labels, columns=["pdb_id", "chain_id", "cluster_id"])
4231 sequences collected. 263 are unique.
================================================================
Program: CD-HIT, V4.8.1 (+OpenMP), Sep 24 2025, 14:16:23
Command: cdhit/cd-hit -i myclusters.fasta -o
         myclusters.clstr_out -c 0.9 -n 5 -l 6

Started: Sat Nov  8 10:22:14 2025
================================================================
                            Output                              
----------------------------------------------------------------
total seq: 4136
longest and shortest : 139 and 7
Total letters: 240538
Sequences have been sorted

Approximated minimal memory consumption:
Sequence        : 0M
Buffer          : 1 X 10M = 10M
Table           : 1 X 65M = 65M
Miscellaneous   : 0M
Total           : 76M

Table limit with the given memory limit:
Max number of representatives: 4000000
Max number of word counting entries: 90402613

comparing sequences from          0  to       4136
....
     4136  finished         49  clusters

Approximated maximum memory consumption: 76M
writing new database
writing clustering information
program completed !

Total CPU time 0.04
seq_labels_df['cluster_id'].value_counts().sort_index()
cluster_id
-1       95
 0      126
 1      101
 2       18
 3       86
 4        9
 5       15
 6     1103
 7     1088
 8        5
 9       89
 10       6
 11       5
 12      12
 13       8
 14      24
 15      12
 16      42
 17       6
 18       5
 19       7
 20      10
 21      10
 22       5
 23      10
 24      78
 25      36
 26      24
 27      20
 28      40
 29       6
 30     529
 31      85
 32       8
 33       5
 34     297
 35       3
 36       9
 37       4
 38       6
 39       8
 40      16
 41      20
 42      12
 43      20
 44      10
 45      18
 46      27
 47      48
 48       5
Name: count, dtype: int64

For some reason the cluster labels that CD-HIT produces are not incremental. Let's fix it.
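As a minimal sketch (assuming the "fix" simply means renumbering cluster_id into consecutive integers; the remapping is done on a copy and the analysis below keeps the original labels):

# Sketch only: renumber CD-HIT cluster IDs into a dense 0..K-1 range on a copy of the table,
# keeping the -1 sentinel for sequences that were never assigned a cluster.
seq_labels_renumbered = seq_labels_df.copy()
valid_ids = sorted(c for c in seq_labels_renumbered['cluster_id'].unique() if c != -1)
id_remap = {old: new for new, old in enumerate(valid_ids)}
id_remap[-1] = -1
seq_labels_renumbered['cluster_id'] = seq_labels_renumbered['cluster_id'].map(id_remap)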

# check whether those fragment_pairs are in the same 3D cluster
all_clusters = []
skipped_files = []
skipped_protein = []
cluster_not_found = []
for idx in range(len(K['Cluster'].unique())):
    subcluster = idx + 1
    K_hier, X_sub = hierarchical_dendrogram(10, subcluster, K, pdb_files_max, rmsd_X_max, plot=False)
    pdb_to_idx = {file: idx for idx, file in enumerate(K_hier['PDB_File'])}

    for K_hier_idx, row in K_hier.iterrows():
        pdb_file = row['PDB_File']
        pdb_id = pdb_file[:4]
        chain_ids = file_chain_ids[file_chain_ids['file'] == pdb_file]['chain_ids'].values
        if len(chain_ids) == 0: #or not pdb_id in seq_labels_df['pdb_id'].values:
            skipped_files.append(pdb_file)
            skipped_protein.append(pdb_id)
            # print(f"Skipping {pdb_file}...")
            continue

        chain_ids = chain_ids[0]
        seq_clusters = []
        for chain_id in chain_ids:
            # print(f"Looking for {pdb_id} {chain_id}...")
            cluster_id_vals = seq_labels_df[(seq_labels_df['pdb_id'] == pdb_id) & (seq_labels_df['chain_id'] == chain_id)]['cluster_id'].values
            if len(cluster_id_vals) == 0:
                cluster_not_found.append((pdb_id, chain_id))
                # print(f"⚠️ No cluster found for {pdb_id} {chain_id}")
                continue
            seq_clusters.append(seq_labels_df[(seq_labels_df['pdb_id'] == pdb_id) & (seq_labels_df['chain_id'] == chain_id)]['cluster_id'].values[0])

        all_clusters.append({
            'PDB_File': pdb_file,
            '2D_Cluster': idx,
            '3D_Subcluster': row['Cluster'],
            'Seq_Clusters': seq_clusters,
        })

print(f"Skipped {len(skipped_files)} files: {skipped_files[:5]}...")
print(f"Cluster not found for {len(cluster_not_found)}: {cluster_not_found[:5]}...")
all_clusters_df = pd.DataFrame(all_clusters)
all_clusters_df.head()
Found 13 clusters... labeled 569 structures
Found 3 clusters... labeled 331 structures
Found 6 clusters... labeled 230 structures
Found 4 clusters... labeled 620 structures
Found 3 clusters... labeled 352 structures
Found 10 clusters... labeled 486 structures
Found 9 clusters... labeled 473 structures
Found 14 clusters... labeled 732 structures
Found 11 clusters... labeled 590 structures
Found 6 clusters... labeled 382 structures
Found 5 clusters... labeled 280 structures
Found 8 clusters... labeled 423 structures
Found 7 clusters... labeled 500 structures
Found 5 clusters... labeled 294 structures
Found 9 clusters... labeled 821 structures
Found 4 clusters... labeled 271 structures
Found 4 clusters... labeled 329 structures
Found 3 clusters... labeled 294 structures
Found 4 clusters... labeled 633 structures
Found 8 clusters... labeled 561 structures
Found 2 clusters... labeled 313 structures
Found 6 clusters... labeled 674 structures
Skipped 7266 files: ['2okz_a.pdb', '4znn_a_2.pdb', '8adv_pair_413_Repair.pdb', '7nci_pair_286_Repair.pdb', '7qk5_pair_504_Repair.pdb']...
Cluster not found for 6577: [('7ynf', 'G'), ('7ynf', 'H'), ('7ynf', 'I'), ('7ynf', 'J'), ('7xjx', 'G')]...
PDB_File 2D_Cluster 3D_Subcluster Seq_Clusters
0 7ynf_pair_068_Repair.pdb 0 10 [4, 4, 4, 4, 4, 4]
1 7xjx_pair_066_Repair.pdb 0 2 [6, 6, 6, 6, 6, 6]
2 6n3c_pair_193_Repair.pdb 0 6 [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
3 9bbm_pair_500_Repair.pdb 0 8 [10, 10, 10, 10, 10, 10]
4 8x7o_pair_204_Repair.pdb 0 1 [9, 9, 9, 9, 9, 9, 9, 9, 9, 9]

The skipped files correspond to the intrachain interacting pairs that were filtered out when loading the chain-ID table. The clusters missing for many chains arise because some chain IDs come from multi-layered structures with more than 5 layers; those chain IDs were never labeled (only the chain IDs from the corrected interacting pairs were).
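As a quick sanity check of that explanation (a sketch, assuming the raw results/file_chain_ids.csv is still readable at this point), the skipped files can be split by the reason they were dropped:

# Sanity-check sketch: a skipped file is either missing from the raw chain-ID table entirely
# or present with a chain count other than 10 (i.e. not a simple interchain pair).
raw_ids = pd.read_csv("results/file_chain_ids.csv")
raw_ids['n_chains'] = raw_ids['chain_ids'].apply(ast.literal_eval).apply(len)
raw_chain_counts = dict(zip(raw_ids['file'], raw_ids['n_chains']))

n_absent = sum(f not in raw_chain_counts for f in skipped_files)
n_non_interchain = sum(raw_chain_counts.get(f, 10) != 10 for f in skipped_files)
print(f"Skipped files absent from the raw chain-ID table: {n_absent}")
print(f"Skipped files present but without exactly 10 chain IDs: {n_non_interchain}")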

print(f"From the sample we have a total of {len(all_clusters_df)} correctly labeled interchain interacting pairs.")
From the sample we have a total of 2892 correctly labeled interchain interacting pairs.
all_clusters_df['Seq_Clusters'] = all_clusters_df['Seq_Clusters'].apply(lambda x: sorted(list(set(x)))) # unique and sort
all_clusters_df['Num_Seq_Clusters'] = all_clusters_df['Seq_Clusters'].apply(lambda x: len(x))
all_clusters_df[all_clusters_df['Num_Seq_Clusters'] > 1].value_counts() # TODO: If nothing - no heterotypic extracted pair
Series([], Name: count, dtype: int64)
all_clusters_df['Seq_Clusters'] = all_clusters_df['Seq_Clusters'].apply(lambda x: int(list(x)[0]))
K_2d = all_clusters_df[['PDB_File', '2D_Cluster', 'Seq_Clusters']].copy()
K_2d = (
    all_clusters_df
    .groupby("2D_Cluster")["Seq_Clusters"]
    .agg(lambda x: sorted(set(x)))
    .reset_index(name="Unique_Seq_Clusters")
)

K_2d
2D_Cluster Unique_Seq_Clusters
0 0 [0, 1, 4, 5, 6, 7, 8, 9, 10]
1 1 [0, 1, 4, 5, 6, 7, 8, 9, 10]
2 2 [0, 1, 4, 5, 6, 7, 9, 10]
3 3 [0, 1, 2, 4, 5, 6, 7, 8, 9, 10]
4 4 [1, 2, 4, 5, 6, 8, 9, 10]
5 5 [1, 2, 4, 5, 8, 9, 10]
6 6 [0, 1, 4, 5, 6, 7, 8, 9, 10]
7 7 [0, 1, 4, 5, 6, 7, 8, 9, 10]
8 8 [0, 1, 4, 5, 6, 7, 8, 9, 10]
9 9 [1, 4, 5, 6, 8, 9, 10]
10 10 [2, 4, 5, 8, 9]
11 11 [4, 5, 7, 8, 9, 10]
12 12 [1, 4, 5, 6, 8, 9, 10]
13 13 [1, 4, 5, 6, 8, 9, 10]
14 14 [0, 1, 2, 4, 5, 6, 7, 8, 9, 10]
15 15 [1, 4, 5, 8, 9, 10]
16 16 [1, 4, 5, 6, 7, 8, 9, 10]
17 17 [0, 1, 2, 4, 5, 6, 7, 8, 9, 10]
18 18 [2, 4, 5, 7, 8, 9]
19 19 [4, 5, 8, 9, 10]
20 20 [4, 5, 9]
21 21 [2, 4, 5, 8, 9, 10]
K_3d = all_clusters_df[['PDB_File', '2D_Cluster', '3D_Subcluster', 'Seq_Clusters']].copy()
K_3d = (
    all_clusters_df
    .groupby(["2D_Cluster", "3D_Subcluster"])["Seq_Clusters"]
    .agg(lambda x: sorted(set(x)))
    .reset_index(name="Unique_Seq_Clusters")
)
K_3d['num_unique_seq_clusters'] = K_3d['Unique_Seq_Clusters'].apply(lambda x: len(x))

K_3d
2D_Cluster 3D_Subcluster Unique_Seq_Clusters num_unique_seq_clusters
0 0 1 [1, 4, 5, 7, 9] 5
1 0 2 [1, 4, 5, 6, 9, 10] 6
2 0 3 [1, 4, 8, 9] 4
3 0 4 [1, 4, 5, 7, 9] 5
4 0 5 [4, 5] 2
... ... ... ... ...
126 21 1 [4, 5, 8, 9] 4
127 21 2 [2, 4, 5, 9, 10] 5
128 21 3 [4, 5] 2
129 21 4 [4, 5] 2
130 21 6 [4, 5] 2

131 rows × 4 columns

all_clusters_df["Protein_ID"] = all_clusters_df["PDB_File"].str[:4]

K_3d = (
    all_clusters_df
    .groupby(["2D_Cluster", "3D_Subcluster"])
    .agg({
        "Seq_Clusters": lambda x: sorted(set(x)),
        "Protein_ID":   lambda x: sorted(set(x)),
        "PDB_File":     lambda x: sorted(set(x)),
    })
    .reset_index()
)

K_3d["num_unique_seq_clusters"] = K_3d["Seq_Clusters"].apply(len)
K_3d["num_unique_proteins"]     = K_3d["Protein_ID"].apply(len)
K_3d["num_unique_pdb_files"]    = K_3d["PDB_File"].apply(len)
K_3d
2D_Cluster 3D_Subcluster Seq_Clusters Protein_ID PDB_File num_unique_seq_clusters num_unique_proteins num_unique_pdb_files
0 0 1 [1, 4, 5, 7, 9] [6l4s, 6n3a, 7bx7, 7q66, 7v4d, 7yat, 7yk2, 7yn... [6l4s_pair_198_Repair.pdb, 6l4s_pair_297_Repai... 5 16 21
1 0 2 [1, 4, 5, 6, 9, 10] [6l4s, 6n3a, 7e0f, 7p68, 7p6c, 7qky, 7v4d, 7wm... [6l4s_pair_282_Repair.pdb, 6n3a_pair_029_Repai... 6 22 26
2 0 3 [1, 4, 8, 9] [6l4s, 7bx7, 7e0f, 8fpt, 8x7l, 8x7o, 8x7p, 8x7... [6l4s_pair_290_Repair.pdb, 7bx7_pair_279_Repai... 4 9 13
3 0 4 [1, 4, 5, 7, 9] [7m62, 7yat, 7yk8, 7ynf, 8qpz, 8x7b, 8x7m, 8x7... [7m62_pair_157_Repair.pdb, 7yat_pair_137_Repai... 5 10 14
4 0 5 [4, 5] [2m4j, 6n3c, 7q65, 7qkf, 7qkw, 8az6, 8bg9, 8dj... [2m4j_pair_022_Repair.pdb, 2m4j_pair_032_Repai... 2 11 14
... ... ... ... ... ... ... ... ...
126 21 1 [4, 5, 8, 9] [6n37, 6xfm, 6zrr, 7bx7, 7qkf, 7ykw, 7ypg, 8az... [6n37_pair_494_Repair.pdb, 6xfm_pair_246_Repai... 4 17 31
127 21 2 [2, 4, 5, 9, 10] [2lmo, 6n3b, 6xfm, 7qkh, 7qkw, 7qky, 7ypg, 8bg... [2lmo_pair_046_Repair.pdb, 6n3b_pair_186_Repai... 5 12 15
128 21 3 [4, 5] [6n3b, 6nwp, 6uur, 7qkh, 8ot6] [6n3b_pair_175_Repair.pdb, 6nwp_pair_163_Repai... 2 5 5
129 21 4 [4, 5] [7bx7, 7qkf, 7qkh, 7r5h, 7ypg, 8az6, 8ci8, 8ec... [7bx7_pair_179_Repair.pdb, 7qkf_pair_025_Repai... 2 11 27
130 21 6 [4, 5] [6qjq, 6xyo, 7qkh, 7ypg, 8az3, 8az4, 8az5, 8az... [6qjq_pair_066_Repair.pdb, 6xyo_pair_175_Repai... 2 17 21

131 rows × 8 columns

We can already see more than one unique sequence-cluster label (from CD-HIT) within single 3D clusters. This suggests the possibility of heterotypic interactions. Let's inspect some of them in PyMOL (not shown here).
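For instance, the 3D subclusters that mix several CD-HIT sequence clusters (the candidates for heterotypic interactions) can be listed directly from K_3d, most mixed first:

# Candidate heterotypic subclusters: more than one sequence cluster within a single 3D subcluster.
hetero_candidates = (
    K_3d[K_3d["num_unique_seq_clusters"] > 1]
    .sort_values("num_unique_seq_clusters", ascending=False)
    [["2D_Cluster", "3D_Subcluster", "Seq_Clusters", "num_unique_seq_clusters"]]
)
hetero_candidates.head(10)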

K_3d["num_unique_seq_clusters"].unique()
array([ 5,  6,  4,  2,  7,  3,  9,  1,  8, 10])
K_3d[(K_3d['2D_Cluster'] == 0) & (K_3d['3D_Subcluster'] == 5)]['Protein_ID'].values
array([list(['2m4j', '6n3c', '7lna', '7q65', '7qkf', '7qkw', '8az6', '8bg9', '8dja', '8qn6', '8v1n', '9ero'])],
      dtype=object)
least_pdb_rows = K_3d[K_3d["num_unique_pdb_files"] == 2]
least_pdb_rows
2D_Cluster 3D_Subcluster Seq_Clusters Protein_ID PDB_File num_unique_seq_clusters num_unique_proteins num_unique_pdb_files
14 1 4 [7, 41] [6xfm, 8fnz] [6xfm_pair_150_Repair.pdb, 8fnz_pair_087_stand... 2 2 2
20 3 2 [6, 7] [2mvx, 7ykw] [2mvx_pair_128_Repair.pdb, 7ykw_pair_139_Repai... 2 2 2
33 6 3 [6, 7] [8q7l, 8q8y] [8q7l_pair_071_Repair.pdb, 8q8y_pair_065_Repai... 2 2 2
98 17 1 [6, 7] [8dja, 8tdo] [8dja_pair_1018_Repair.pdb, 8tdo_pair_483_Repa... 2 2 2
117 20 6 [7] [6n3c] [6n3c_pair_099_Repair.pdb, 6n3c_pair_113_Repai... 1 1 2
K_3d[K_3d['2D_Cluster'] == 10]
2D_Cluster 3D_Subcluster Seq_Clusters Protein_ID PDB_File num_unique_seq_clusters num_unique_proteins num_unique_pdb_files
71 10 1 [4, 5] [7nrt, 8ec7, 8ff2, 8hia, 8oi0, 8ppo, 8q8z, 8qn... [7nrt_pair_130_Repair.pdb, 8ec7_pair_505_Repai... 2 10 15
72 10 2 [2, 4, 5, 8] [2lmo, 7qkw, 8ci8, 8q8y, 8zwm] [2lmo_pair_109_Repair.pdb, 7qkw_pair_232_Repai... 4 5 7
73 10 3 [4, 5, 8] [6n3c, 6nwq, 6vw2, 7rl4, 8dja, 8ff3, 8oth, 8q9... [6n3c_pair_110_Repair.pdb, 6nwq_pair_471_Repai... 3 11 12
74 10 4 [4, 5] [6n3c, 6w0o, 7qkw, 7rl4, 8bfa, 8bg9, 8ff3, 8q8... [6n3c_pair_326_Repair.pdb, 6n3c_pair_506_Repai... 2 9 14
75 10 5 [4, 5, 9] [7qjw, 7qkv, 7rl4, 8ci8, 8dja, 8g2v, 8q9l, 8qn6] [7qjw_pair_458_Repair.pdb, 7qkv_pair_429_Repai... 3 8 10
K_3d[(K_3d['2D_Cluster'] == 10) & (K_3d['3D_Subcluster'] == 5)]['PDB_File'].to_list()
[['7qjw_pair_458_Repair.pdb',
  '7qkv_pair_429_Repair.pdb',
  '7qkv_pair_797_Repair.pdb',
  '7rl4_pair_118_Repair.pdb',
  '7rl4_pair_991_Repair.pdb',
  '8ci8_pair_681_Repair.pdb',
  '8dja_pair_389_Repair.pdb',
  '8g2v_pair_024_Repair.pdb',
  '8q9l_pair_255_Repair.pdb',
  '8qn6_pair_168_Repair.pdb']]
 

3D clustering only

def load_matrix_test(path):
    print("Loading an RMSD matrix...")
    X, pdb_files = load_features(path)
    print(f"Loaded. Matrix has {X.shape} shape.")
    print(f"With corresponding number of files {len(pdb_files)}")
    return X, pdb_files

def load_matrix_for_merging(path):
    print("Loading an RMSD matrix...")
    X, pdb_files = load_features(path)
    print(f"Loaded. Matrix has {X.shape} shape.")
    print(f"With corresponding number of files {len(pdb_files)}")
    return X

def merge_rect_matrices(matrix_files, output_file):
    X_total, files = load_matrix_test(matrix_files[0])

    for file in matrix_files[1:]:
        X_part = load_matrix_for_merging(file)

        # Merge the matrices
        if X_total.shape[1] != X_part.shape[1]:
            raise ValueError(f"Matrix shapes do not match: {X_total.shape} vs {X_part.shape}")
        X_total = np.concatenate((X_total, X_part), axis=0)  # Stack rows

        print(f"Merged matrix shape is now {X_total.shape}")
    
    # Save the merged matrix
    np.savez_compressed(output_file, X=X_total, files=files)
    np.save("rmsd_matrix_svd_merged_full.npy", X_total)
    print(f"Merged matrix saved to {output_file}")

def mirror_and_fill_square_matrix(X):
    # Mirror the upper triangle to the lower triangle
    i_lower = np.tril_indices(X.shape[0], -1)
    X[i_lower] = X.T[i_lower]

    # Fill diagonal with zeros (self-comparison)
    np.fill_diagonal(X, 0)

    return X

# dir_path = "../../matrix_parts"
# out_matrix_file = "rmsd_matrix_svd_merged_full.npz"
# merge_rect_matrices([f"{dir_path}/rmsd_matrix_{i}_{i+5000}.npz" for i in range(0, 60000, 5000)] + [f"{dir_path}/rmsd_matrix_60000_62132.npz"], output_file=out_matrix_file)

# rmsd_X_max, pdb_files_max = load_matrix_test(out_matrix_file)
# rmsd_X_max = mirror_and_fill_square_matrix(rmsd_X_max)
# np.savez_compressed(out_matrix_file, X=rmsd_X_max, files=pdb_files_max)
# np.save("rmsd_matrix_svd_merged_full.npy", rmsd_X_max)
rmsd_X_all, pdb_files_all = load_matrix("rmsd_matrix_svd_merged_full.npz")
print(rmsd_X_all.shape)
rmsd_X_all, pdb_files_all = reduce_matrix(rmsd_X_all, pdb_files_all, max_samples=100_000, float_32=True)
print(pdb_files_all[0], pdb_files_all[-1])
print(len(pdb_files_all), rmsd_X_all.shape)
pdb_files_all = [file[26:] for file in pdb_files_all]  # strip the leading "dist_matrix_pdbs_five_two/" directory prefix

#rmsd_X_all, pdb_files_all = filter_matrix(rmsd_X_all, pdb_files_all, temporary["PDB_File"].tolist())
Loading an RMSD matrix...
Loaded. Matrix has (62132, 62132) shape.
With corresponding number of files 62132
108      62024
62131      108
Name: count, dtype: int64
Files to remove due to high NaN count (62131 nans)
Removing 108 files with 62131 NaN values in their respective row.
Resulting matrix shape is (62024, 62024)
Adjusted number of files 62024
(62024, 62024)
Converting to float32
dist_matrix_pdbs_five_two/1yjo.pdb dist_matrix_pdbs_five_two/7q65_pair_160_Repair.pdb
62024 (62024, 62024)
print(rmsd_X_all.shape)
num_templates_all = sum('Repair' not in file for file in pdb_files_all)
print(num_templates_all)
(58081, 58081)
133
from scipy.sparse.linalg import eigsh

def fast_pcoa(dist_matrix, n_components=10):
    n = dist_matrix.shape[0]

    # Square distances
    D2 = dist_matrix ** 2

    # Double centering (done efficiently without forming J explicitly)
    row_mean = D2.mean(axis=1, keepdims=True)
    col_mean = D2.mean(axis=0, keepdims=True)
    total_mean = D2.mean()
    B = -0.5 * (D2 - row_mean - col_mean + total_mean)

    # Compute top eigenvectors using sparse eigendecomposition
    eigvals, eigvecs = eigsh(B, k=n_components, which='LA')  # 'LA' = largest algebraic

    # Sort descending by eigenvalue
    idx = np.argsort(eigvals)[::-1]
    eigvals = eigvals[idx]
    eigvecs = eigvecs[:, idx]

    # Compute coordinates
    coords = eigvecs * np.sqrt(np.maximum(eigvals, 0))
    return coords, eigvals

# Perform PCoA
num_top_components = 2
coords, eigvals = fast_pcoa(rmsd_X_all, n_components=50)
## Save to dataframes
pcoa_results_df = pd.DataFrame(coords[:, :num_top_components], columns=[f"PCoA_{i+1}" for i in range(num_top_components)])
results_df = pcoa_results_df.iloc[:, :num_top_components].copy()

## Optionally print explained variance (if you had a helper for that)
total_var = np.sum(eigvals[eigvals > 0])  # total inertia over the computed components (approximates the full inertia)
explained_ratio = eigvals / total_var
explained_var = explained_ratio[:num_top_components]
print(f"Explained variance ratio (first {num_top_components}): {explained_var}")

# Perform UMAP
umap_model = umap.UMAP(
    n_neighbors=30, n_components=2, metric='precomputed',
    random_state=42, min_dist=0.05
)
emb = umap_model.fit_transform(rmsd_X_all)

num_templates_all = sum('Repair' not in file for file in pdb_files_all)

category = np.array(
    ["Templates"] * num_templates_all + ["Extracted"] * (len(pdb_files_all) - num_templates_all)
)

plot_pca_and_umap(
    results_df,
    emb,
    category,
    {"Templates": "mediumorchid", "Extracted": "cornflowerblue"},
    dot_size=5
)
Explained variance ratio (first 2): [0.2677     0.15034701]
{'Extracted': 'Extracted (n=57948)', 'Templates': 'Templates (n=133)'}
Number of templates used 133
Error, trying to plot a df
rat_sum = 0
num_components_expl_90=0
for i, ratio in enumerate(explained_ratio):
    rat_sum += ratio
    if rat_sum >= 0.9:
        num_components_expl_90 = i + 1
        break

print(rat_sum)
num_components_expl_90
0.90137905
27
components_explaining = num_components_expl_90
#umapca_model = umap.UMAP(n_neighbors=20, n_components=2, random_state=42, min_dist=0.01, spread=0.5) #, min_dist=0.1, spread=1)
umapca_model = umap.UMAP(n_neighbors=10, n_components=2, n_jobs=14, min_dist=0.01, spread=0.5)
umapca_emb = umapca_model.fit_transform(pcoa_results_df.iloc[:, :num_components_expl_90].values)
# FIXME: ALL COMPONENTS DON'T GIVE THE SAME UMAP AS ABOVE

plot_umap(num_templates_all, pdb_files_all, umapca_emb, save_path="plots/rmsd_all_umap", dot_size=5)
components_explaining = num_components_expl_90
#umapca_model = umap.UMAP(n_neighbors=20, n_components=2, random_state=42, min_dist=0.01, spread=0.5) #, min_dist=0.1, spread=1)
umapca_model = umap.UMAP(n_neighbors=50,
                         n_components=2,
                         n_jobs=14,
                         min_dist=0.01,
                         spread=0.5,
                         local_connectivity=3,
                         repulsion_strength=1.5,
                         negative_sample_rate=15,
                         metric='euclidean')
umapca_emb = umapca_model.fit_transform(pcoa_results_df.iloc[:, :num_components_expl_90].values)

plot_umap(num_templates_all, pdb_files_all, umapca_emb, save_path="plots/rmsd_all_umap", dot_size=5)
import matplotlib.pyplot as plt
import numpy as np

def find_nearest_point(emb, x, y):
    coords = np.array([x, y])
    dists = np.linalg.norm(emb - coords, axis=1)
    nearest_index = np.argmin(dists)
    return nearest_index, emb[nearest_index]

idx, coord = find_nearest_point(emb, x=-30, y=0)
print("Nearest index:", idx)
print("Corresponding file:", pdb_files_all[idx])
Nearest index: 26559
Corresponding file: 8x7b_pair_342_Repair.pdb
# Perform hierarchical clustering using the RMSD distance matrix
t_distance = 2.0

# Use "average" linkage since we're providing a distance matrix
Z = sch.linkage(squareform(rmsd_X_all), method='average')  # or 'complete', 'single', etc.

# Get cluster labels based on distance threshold
labels = sch.fcluster(Z, t=t_distance, criterion='distance')

num_clusters = len(np.unique(labels))
print(f"Found {num_clusters} clusters (labeled {len(labels)} structures)")

# Color palette for clusters
cluster_palette = sns.color_palette("husl", num_clusters)

# Plot dendrogram
plot_dendrogram(
    Z,
    pdb_files_all,
    num_clusters,
    t_distance,
    save_path="plots/rmsd_all_dendrogram",
    custom_palette=cluster_palette
)
Found 796 clusters (labeled 58081 structures)
file_category = [
    'Repaired' if 'Repair' in file else 'Template'
    for file in pdb_files_all
]

# Create dataframe: pdb_file ↔ cluster_label ↔ category
df_clusters = pd.DataFrame({
    'pdb_file': pdb_files_all,
    'cluster_label': labels,
    'category': file_category
})

# Save to CSV
df_clusters.to_csv('file_clusters.csv', index=False)
print("✅ Saved clusters.csv")

# Compute cluster summary
df_summary = (
    df_clusters
    .groupby('cluster_label')
    .agg(num_structs=('pdb_file', 'count'))
    .reset_index()
)

# Compute number of templates per cluster
template_counts = (
    df_clusters[df_clusters['category'] == 'Template']
    .groupby('cluster_label')
    .size()
    .rename('num_templates')
    .reset_index()
)

# Merge summaries
df_summary = df_summary.merge(template_counts, on='cluster_label', how='left')
# Fill missing template counts with 0 (in case some clusters have no templates)
df_summary['num_templates'] = df_summary['num_templates'].fillna(0).astype(int)

# Save to CSV
df_summary.to_csv('cluster_summary.csv', index=False)
print("✅ Saved cluster_summary.csv")

# Display results
print(df_clusters.head())
print(df_summary.head())
✅ Saved file_clusters.csv
✅ Saved cluster_summary.csv
     pdb_file  cluster_label  category
0    1yjo.pdb            667  Template
1  1yjp_1.pdb            241  Template
2  1yjp_2.pdb            667  Template
3  2m5n_2.pdb            483  Template
4  2m5n_3.pdb            263  Template
   cluster_label  num_structs  num_templates
0              1           15              0
1              2           83              0
2              3          102              0
3              4           98              0
4              5           24              0
from scipy.cluster.hierarchy import to_tree

def convert_linkage_to_newick(Z, leaf_names):
    tree = to_tree(Z, rd=False)
    newick_str = get_newick(tree, Z, tree.dist, leaf_names) + ";"
    return newick_str


def get_newick(node, labels, parent_dist, leaf_names, newick=""):
    """Recursively generate a Newick string with internal node IDs."""
    branch_length = parent_dist - node.dist

    if node.is_leaf():
        return f"{leaf_names[node.id]}:{branch_length:.6f}{newick}"
    else:
        if len(newick) > 0:
            newick = f")NODE_{node.id}:{branch_length:.6f}{newick}"
        else:
            newick = f")NODE_{node.id}:{branch_length:.6f}"
        newick = get_newick(node.get_left(), labels, node.dist, leaf_names, newick)
        newick = get_newick(node.get_right(), labels, node.dist, leaf_names, "," + newick)
        newick = "(" + newick
        return newick

newick_str = convert_linkage_to_newick(Z, pdb_files_all)
# TODO: visualize the number of templates per cluster (or at least whether a cluster
# includes any templates) and encode the cluster size somehow
# (maybe even truncate all the branches within a cluster)

with open("dendrogram_all.newick", "w") as f:
    f.write(newick_str)
pdb_to_idx = {file: idx for idx, file in enumerate(pdb_files_all)}

def calculate_summed_distance_within_cluster_all(row):
    current_pdb = row['PDB_File']
    current_cluster = row['Cluster']

    # Get all PDB files in the current cluster
    cluster_members_df = K[K['Cluster'] == current_cluster]
    cluster_member_pdbs = cluster_members_df['PDB_File'].tolist()

    # Get the index of the current and other cluster members' PDB files in the RMSD matrix
    current_pdb_idx = pdb_to_idx[current_pdb]
    cluster_member_indices = [pdb_to_idx[pdb] for pdb in cluster_member_pdbs]

    # Extract the relevant row from the RMSD matrix (distances from current PDB)
    distances_from_current_pdb = rmsd_X_all[current_pdb_idx, :] # FIXME (rmsd_X_max or rmsd_X_projected)
    summed_distance = distances_from_current_pdb[cluster_member_indices].sum()
    # summed_distance_squared = (distances_from_current_pdb[cluster_member_indices]**2).sum() # TODO
    # if i want to know which files distort the cluster
    return summed_distance

K = pd.DataFrame({
    'PDB_File': pdb_files_all,
    'Cluster': labels,
    'Category': category
})

# Apply the function to each row
K['Summed_Distances_Within_Cluster'] = K.apply(calculate_summed_distance_within_cluster_all, axis=1)
import superpose_fragments
# FIXME
def superpose_cluster(cluster_num, distance_matrix, projected_alignment=False, K_labeled=K):
    # Get the indices of the PDB files in the target cluster
    cluster_rows = K_labeled[K_labeled['Cluster'] == cluster_num].sort_values(by='Summed_Distances_Within_Cluster', ascending=True)
    cluster_indices = cluster_rows['PDB_File'].apply(lambda x: pdb_to_idx[x]).tolist()
    cluster_rmsds_submatrix = distance_matrix[np.ix_(cluster_indices, cluster_indices)]
    # Flatten the upper triangular part
    upper_triangle_indices = np.triu_indices_from(cluster_rmsds_submatrix, k=1) # FIXME: can be rectangular not a square
    all_pairwise_cluster_distances = cluster_rmsds_submatrix[upper_triangle_indices]
    # Get all the cluster paths and superpose
    paths = []
    for file in list(cluster_rows['PDB_File'].values):
        file = file.split('/')[-1]
        if file.split('_')[-1] == "Repair.pdb":
            paths.append(f"./fragment_pairs/{file}")
        else:
            paths.append(f"./templates/{file}")
    if not projected_alignment:
        rmsds, super_opts, aligned_coords = superpose_fragments.superpose_all(paths, output_path=f"superposed_cluster_{cluster_num}.cif")
    else:
        center_projected_coords, _ = projected_pairwise_rmsd_matrix.get_projected_coords_and_std(paths[0])
        superimposer = QCPSuperimposer()
        aligned_coords = [center_projected_coords]
        rmsds = []

        for path in paths[1:]:
            projected_coords, _ = projected_pairwise_rmsd_matrix.get_projected_coords_and_std(path)
            if projected_coords is None: # FIXME: why some templates failing
                print(f"{projected_coords} of {path}")
                continue
            
            possible_aligns = projected_pairwise_rmsd_matrix.generate_possible_alignments_test(projected_coords)

            best_rmsd = float('inf')
            best_variant = None
            for align in possible_aligns:
                superimposer.set(center_projected_coords, align)
                superimposer.run()
                rmsd = superimposer.get_rms()
                
                if rmsd < best_rmsd:
                    best_rmsd = rmsd
                    best_variant = superimposer.get_transformed()

            rmsds.append(best_rmsd)
            aligned_coords.append(best_variant)

    print(f"Number of structures in cluster {cluster_num}: {len(cluster_rows)}")
    print(f"Mean RMSD within cluster: {np.mean(all_pairwise_cluster_distances):.2f} Å")
    print(f"Standard deviation of RMSDs within cluster: {np.std(all_pairwise_cluster_distances):.2f} Å")
    print(f"Standard deviation of RMSDs superposition to centroid structure: {np.std(rmsds):.2f} Å")
    return aligned_coords

def plot_projected_cluster(cluster_num, projected_alignment=True, X=rmsd_X_projected, K_labeled=K, save_path="plots/cluster_test"):
    aligned_coords = superpose_cluster(cluster_num, X, projected_alignment=projected_alignment, K_labeled=K_labeled)
    aligned_coords = np.array(aligned_coords)
    avg_layer_coords = aligned_coords.reshape(-1, 12, 3)
    mean_coords = np.mean(avg_layer_coords, axis=0)
    std_coords = np.std(avg_layer_coords, axis=0)

    centered_mean_coords = mean_coords - np.mean(mean_coords, axis=0)
    pca = PCA(n_components=2)
    pca.fit(centered_mean_coords)
    projected_coords = pca.transform(centered_mean_coords)

    print(f"Total variance explained by best-fit plane: ~{np.sum(pca.explained_variance_ratio_):.2f}")
    print(f"Variance explained by PC1: {pca.explained_variance_ratio_[0]:.2f}")
    print(f"Variance explained by PC2: {pca.explained_variance_ratio_[1]:.2f}")

    projected_stdev_x = np.dot(std_coords, pca.components_[0]) # Project stdev along PC1
    projected_stdev_y = np.dot(std_coords, pca.components_[1]) # Project stdev along PC2
    projected_stdev_x_abs = np.abs(projected_stdev_x)
    projected_stdev_y_abs = np.abs(projected_stdev_y)

    fig_2d_pca = plt.figure(figsize=(6, 4))
    ax_2d_pca = fig_2d_pca.add_subplot(111)

    # Plot the mean trace on the best-fit plane
    ax_2d_pca.plot(projected_coords[:6, 0], 
                    projected_coords[:6, 1], 
                    'o-', label='Mean Trace (Hexapeptide 1)', color='black', linewidth=10, markersize=6)
    ax_2d_pca.plot(projected_coords[6:, 0], 
                    projected_coords[6:, 1], 
                    'o-', label='Mean Trace (Hexapeptide 2)', color='black', linewidth=10, markersize=6)
    # Add "blurriness" as error bars
    # Use projected_stdev_x_abs and projected_stdev_y_abs for error bars along the new axes
    ax_2d_pca.errorbar(projected_coords[:, 0], projected_coords[:, 1],
                       xerr=projected_stdev_x_abs, yerr=projected_stdev_y_abs,
                       fmt='none', capsize=3, color='grey', alpha=0.5, zorder=0)

    ax_2d_pca.axes.get_xaxis().set_visible(False)
    ax_2d_pca.axes.get_yaxis().set_visible(False)
    # remove the box
    for spine in ax_2d_pca.spines.values():
        spine.set_visible(False)
    ax_2d_pca.grid(False)
    ax_2d_pca.set_aspect('equal', adjustable='box') # Keep aspect ratio for true shape
    plt.tight_layout()

    # plt.savefig(f"{save_path}_{cluster_num}.png", dpi=300, bbox_inches="tight")
    # plt.savefig(f"{save_path}_{cluster_num}.pdf", bbox_inches="tight")
    # plt.savefig(f"{save_path}_{cluster_num}.svg", bbox_inches="tight")
    plt.show()
cluster_num = 9
# superpose_cluster(cluster_num, rmsd_X_projected)

plot_projected_cluster(cluster_num, False, rmsd_X_all, K)
Saved superimposed structure to superposed_cluster_9.cif
Number of structures in cluster 9: 50
Mean RMSD within cluster: 0.95 Å
Standard deviation of RMSDs within cluster: 0.76 Å
Standard deviation of RMSDs superposition to centroid structure: 0.62 Å
Total variance explained by best-fit plane: ~0.99
Variance explained by PC1: 0.95
Variance explained by PC2: 0.05
# If needed remove malicious cluster and recluster
#wrong_cluster_num = 20
#K = K[K['Cluster'] != wrong_cluster_num]
#rmsd_X_projected, pdb_files_projected = filter_matrix(rmsd_X_projected, pdb_files_projected, K['PDB_File'].to_list())
for i in range(max(labels)):
    plot_projected_cluster(i+1, False, rmsd_X_all, K)
    #, save_path=f"plots/max_clusters/sub{subcluster_num}_max_cluster")
def plot_labeled_umap(df):
    fig = plt.figure(figsize=(10, 6))

    sns.scatterplot(
        data=df, x='UMAP_1', y='UMAP_2',
        hue='Cluster', s=5, alpha=1, palette=cluster_palette
    )

    plt.title('UMAP Projection of Sampled and Template Fragments')
    plt.xlabel('UMAP 1')
    plt.ylabel('UMAP 2')

    # Only show legend if there are 100 or fewer clusters
    n_clusters = df['Cluster'].nunique()
    if n_clusters <= 100:
        plt.legend(title='Clusters')
    else:
        plt.legend([], [], frameon=False)

    # Cosmetic adjustments
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.grid(False)
    ax.spines['left'].set_color('dimgray')
    ax.spines['bottom'].set_color('dimgray')

    plt.tight_layout()
    plt.savefig("plots/rmsd_projected_labeled_umap.png", dpi=600, bbox_inches="tight")
    plt.savefig("plots/rmsd_projected_labeled_umap.pdf", bbox_inches="tight")
    plt.savefig("plots/rmsd_projected_labeled_umap.svg", bbox_inches="tight")
    plt.show()


umapca_df = pd.DataFrame({
    'PDB_File': pdb_files_all,
    'UMAP_1': umapca_emb[:, 0],
    'UMAP_2': umapca_emb[:, 1]
})
K_merged = pd.merge(K, umapca_df, on="PDB_File")
plot_labeled_umap(K_merged)
unique_cluster_names = K['Cluster'].unique()
cluster_order_numeric_sorted_str = sorted(unique_cluster_names, key=int)
K['Cluster'] = pd.Categorical(K['Cluster'], categories=cluster_order_numeric_sorted_str, ordered=True)

# Calculate value counts based on the ordered categorical 'Cluster'
# This will now respect the order defined in the categorical type
value_counts_ordered = K['Cluster'].value_counts().sort_index()
#print([f"Cluster {i+1}, {val}" for i, val in enumerate(value_counts_ordered.values)])
# Prepare data for "Template Fragments"
stacked_counts = K.groupby(['Cluster', 'Category'], observed=True).size().reset_index(name='counts')
stacked_counts = stacked_counts[stacked_counts['Category'] == "Templates"]
#print([f"Cluster {i+1}, {val}" for i, val in enumerate(stacked_counts['counts'].values)])

plt.figure(figsize=(25, 2))
# Plot the total cluster sizes, now ordered by cluster number
ax = sns.barplot(x=value_counts_ordered.index, y=value_counts_ordered.values, color='midnightblue')
# Plot the "Template Fragments" counts, which will also be ordered by cluster number
#ax = sns.barplot(x='Cluster', y='counts', color='mediumorchid', data=stacked_counts) # TODO: purple distribution (kde)
tmp = stacked_counts.set_index('Cluster')['counts']
x_template = np.repeat([i-1 for i in tmp.index.values], tmp.values)
x_template_num = pd.Series(x_template).astype("int64").to_numpy(dtype=float)

# Match x-limits to the clusters
xmin, xmax = value_counts_ordered.index.min(), value_counts_ordered.index.max()
ax.set_xlim(-1, xmax)

# --- KDE on a secondary y-axis ---
ax2 = ax.twinx()
sns.kdeplot(
    x=x_template_num, ax=ax2, fill=True, alpha=0.3,
    bw_adjust=0.1, color='mediumorchid', clip=(xmin - 2, xmax + 2)
)

plt.xlabel('Cluster')
plt.ylabel('Count')
ax.grid(False)
ax2.grid(False)
# remove spines
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
ax.spines['left'].set_color('dimgray')
ax2.spines['left'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax2.spines['bottom'].set_visible(False)
ax2.spines['right'].set_color('dimgray')
ax2.set_ylabel('Density')
plt.xticks([])
plt.tight_layout()

plt.savefig("plots/cluster_counts_projected_bar.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/cluster_counts_projected_bar.pdf", bbox_inches="tight")
plt.savefig("plots/cluster_counts_projected_bar.svg", bbox_inches="tight")
plt.show()
K_cluster_counts = K['Cluster'].value_counts().rename_axis('cluster_num').reset_index(name='counts')
K_cluster_counts.describe()
counts
count 796.000000
mean 72.966080
std 176.402884
min 10.000000
25% 15.000000
50% 25.000000
75% 57.000000
max 2231.000000
len(K_cluster_counts[K_cluster_counts['counts'] < 10])
0
len(K_cluster_counts[K_cluster_counts['counts'] < 20])
314
K_sorted = K.copy()
value_counts_sorted = value_counts_ordered.copy()
stacked_counts_sorted = stacked_counts.copy()

# --- Sort by descending total counts instead of numeric order ---
clusters_by_count = value_counts_sorted.sort_values(ascending=False).index.tolist()

# Reorder the Cluster column according to count order
K_sorted['Cluster'] = pd.Categorical(K_sorted['Cluster'], categories=clusters_by_count, ordered=True)

# Recompute value counts and template subset for the new order
value_counts_sorted = K_sorted['Cluster'].value_counts().sort_index()
stacked_counts_sorted = (
    K_sorted.groupby(['Cluster', 'Category'], observed=True)
    .size()
    .reset_index(name='counts')
)
stacked_counts_sorted = stacked_counts_sorted[stacked_counts_sorted['Category'] == "Templates"]

# --- Plot sorted-by-count version ---
tmp_sorted = stacked_counts_sorted.set_index('Cluster')['counts']
x_template_sorted = np.repeat(range(len(tmp_sorted)), tmp_sorted.values)
x_template_num_sorted = pd.Series(x_template_sorted).astype("float").to_numpy()

plt.figure(figsize=(200, 10))
ax = plt.gca()
ax2 = ax.twinx()

sns.kdeplot(
    x=x_template_num_sorted, ax=ax2, fill=True, alpha=0.3, lw=1.5,
    bw_adjust=0.1, color='mediumorchid', clip=(0, len(clusters_by_count))
)

ax2.set_zorder(0)
ax.set_zorder(1)
ax.patch.set_visible(False)  # allow bars to show without white background on ax2

sns.barplot(
    x=value_counts_sorted.index,
    y=value_counts_sorted.values,
    width=0.5,
    color='midnightblue',
    order=clusters_by_count,
    ax=ax
)

ax.set_xlim(-1, len(clusters_by_count))
ax.set_xlabel('Cluster (sorted by total count)')
#ax.set_xticklabels(ax.get_xticklabels(), rotation=90, ha='center')
plt.setp(ax.get_xticklabels(), rotation=90, ha='center')
ax.set_ylabel('Count')
ax2.set_ylabel('Density')
plt.xticks(rotation=90)

ax.grid(False)
ax2.grid(False)
for spine in ['top', 'bottom']:
    ax.spines[spine].set_visible(False)
    ax2.spines[spine].set_visible(False)
ax.spines['left'].set_color('dimgray')
ax2.spines['right'].set_color('dimgray')

plt.tight_layout()
plt.savefig("plots/sorted_counts.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/sorted_counts.pdf", bbox_inches="tight")
plt.savefig("plots/sorted_counts.svg", bbox_inches="tight")
plt.show()
# Remove small clusters and recluster
wrong_clusters = K_cluster_counts[K_cluster_counts['counts'] < 10]['cluster_num']
print(f"Number of wrong clusters is {len(wrong_clusters)}/{len(K_cluster_counts)}")

K = K[~K['Cluster'].isin(wrong_clusters)]

rmsd_X_all, pdb_files_all = filter_matrix(rmsd_X_all, pdb_files_all, K['PDB_File'].to_list())
Number of wrong clusters is 1097/1893
0 were not found in the matrix
(58081,)
# Plot a violin plot of total energies within each cluster
total_e_plotting_df = []

for cluster in cluster_order_numeric_sorted_str:
    cluster_rows = K[K['Cluster'] == cluster]
    
    for pdb_file in cluster_rows['PDB_File'].tolist():
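        # Look up Total_Energy in the repaired fragment-pair results first, falling back to the
        # template (Eisenberg) results; files missing from both are recorded as None.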
        val_repaired = repaired_df.loc[repaired_df['PDB_File'] == pdb_file, 'Total_Energy']
        val_template = template_df.loc[template_df['PDB_File'] == pdb_file, 'Total_Energy']

        if not val_repaired.empty:
            total_energy = val_repaired.iloc[0]
        elif not val_template.empty:
            total_energy = val_template.iloc[0]
        else:
            print(f"⚠️ No Total_Energy found for {pdb_file}")
            total_energy = None
        total_e_plotting_df.append({'Cluster': cluster, 'Total Energy': total_energy})

total_e_plotting_df = pd.DataFrame(total_e_plotting_df)

plt.figure(figsize=(10, 6))
sns.violinplot(
    x='Cluster',
    y='Total Energy',
    data=total_e_plotting_df,
    hue='Cluster',
    palette=cluster_palette,
    legend=False,
    inner='quartile',
    width=0.7,
    cut=0,
)

plt.xlabel('Cluster', fontsize=12)
plt.ylabel('Total Energy', fontsize=12)
plt.title('Distribution of Total Energies within Clusters', fontsize=14)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

#plt.ylim(-0.5, 7) # TODO
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_visible(False)
#plt.gca().set_yticks([0, 2, 4, 6]) # TODO

plt.grid(axis='y', linestyle='--', alpha=0.7) #, linewidth=2)
plt.tight_layout()

plt.savefig("plots/total_energy_within_clusters.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/total_energy_within_clusters.pdf", bbox_inches="tight")
plt.savefig("plots/total_energy_within_clusters.svg", bbox_inches="tight")
plt.show()
pdb_to_idx = {file: idx for idx, file in enumerate(pdb_files_all)}
rmsd_plotting_df = []

for cluster in cluster_order_numeric_sorted_str:
    cluster_rows = K[K['Cluster'] == cluster]
    cluster_indices = cluster_rows['PDB_File'].apply(lambda x: pdb_to_idx[x]).tolist()
    cluster_rmsds_submatrix = rmsd_X_all[np.ix_(cluster_indices, cluster_indices)]
    # Flatten the upper triangular part
    upper_triangle_indices = np.triu_indices_from(cluster_rmsds_submatrix, k=1)
    all_pairwise_cluster_distances = cluster_rmsds_submatrix[upper_triangle_indices]

    # Skip exact zero distances and collect the remaining pairwise RMSDs for plotting
    for dist in all_pairwise_cluster_distances:
        if dist == 0:
            continue
        rmsd_plotting_df.append({'Cluster': cluster, 'RMSD': dist})

rmsd_plotting_df = pd.DataFrame(rmsd_plotting_df)

plt.figure(figsize=(10, 6))

sns.violinplot(
    x='Cluster',
    y='RMSD',
    data=rmsd_plotting_df,
    hue='Cluster',
    palette=cluster_palette,
    legend=False,
    inner='quartile',
    width=0.7,
    cut=0,
)

plt.xlabel('Cluster', fontsize=12)
plt.ylabel('Pairwise RMSD', fontsize=12)
plt.title('Distribution of Pairwise RMSDs within Clusters', fontsize=14)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

plt.ylim(-0.5, 4.5)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_visible(False)
plt.gca().set_yticks([0, 2, 4, 6])

plt.grid(axis='y', linestyle='--', alpha=0.7) #, linewidth=2)
plt.tight_layout()

plt.savefig("plots/rmsd_all_within_cluster_violin.png", dpi=600, bbox_inches="tight")
plt.savefig("plots/rmsd_all_within_cluster_violin.pdf", bbox_inches="tight")
plt.savefig("plots/rmsd_all_within_cluster_violin.svg", bbox_inches="tight")
plt.show()
pdb_to_idx = {file: idx for idx, file in enumerate(pdb_files_all)}

def get_dists_to_cluster_center(cluster_num, distance_matrix, K_labeled=K):
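    # The cluster center is the member with the smallest summed distance to all other members
    # (the precomputed Summed_Distances_Within_Cluster column); return every member's RMSD to it.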
    cluster_rows = (
        K_labeled[K_labeled['Cluster'] == cluster_num]
        .sort_values(by='Summed_Distances_Within_Cluster', ascending=True)
    )

    cluster_indices = cluster_rows['PDB_File'].apply(lambda x: pdb_to_idx[x]).tolist()
    center_idx = pdb_to_idx[cluster_rows['PDB_File'].iloc[0]]
    distances = distance_matrix[np.ix_([center_idx], cluster_indices)].flatten()

    return list(cluster_rows['PDB_File']), distances


records = []
for cluster_num in K['Cluster'].unique():
    pdb_files, distances = get_dists_to_cluster_center(cluster_num, rmsd_X_all, K)
    # print(pdb_files[0], distances[0])  # sanity check
    for pdb_file, dist in zip(pdb_files, distances):
        records.append({'Cluster': cluster_num, 'PDB_File': pdb_file, 'Distance': dist})

distances_from_central_struct_df = pd.DataFrame(records)

plt.figure(figsize=(10, 6))
sns.violinplot(
    x='Cluster',
    y='Distance',
    data=distances_from_central_struct_df,
    hue='Cluster',
    palette=cluster_palette,
    legend=False,
    inner='quartile',
    width=0.7,
    cut=0,
)

plt.xlabel('Cluster', fontsize=12)
plt.ylabel('Distance from central structure', fontsize=12)
plt.title('Distribution of Distances from central structure within Clusters', fontsize=14)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)

plt.ylim(-0.5, 3.5)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_visible(False)
plt.gca().set_yticks([0, 1, 1.5, 2, 2.5, 3])

plt.grid(axis='y', linestyle='--', alpha=0.7) #, linewidth=2)
plt.tight_layout()

# plt.savefig("plots/rmsd_all_within_cluster_violin.png", dpi=600, bbox_inches="tight")
# plt.savefig("plots/rmsd_all_within_cluster_violin.pdf", bbox_inches="tight")
# plt.savefig("plots/rmsd_all_within_cluster_violin.svg", bbox_inches="tight")
plt.show()
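# Per-cluster summary of distances to the cluster center: share of members above 2.0 Å / 2.5 Å,
# plus the 95th-percentile ("worst 5%" lower bound) and maximum distance.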
summary_worst = (
    distances_from_central_struct_df
    .groupby('Cluster', group_keys=False)
    .apply(
        lambda g: pd.Series({
            'Percent_Above_2.0': (g['Distance'] > 2.0).mean() * 100,
            'Percent_Above_2.5': (g['Distance'] > 2.5).mean() * 100,
            'Worst5_Min_Dist': g['Distance'].quantile(0.95),
            'Worst5_Max_Dist': g['Distance'].max()
        }),
        include_groups=False
    )
    .reset_index()
)

summary_worst.describe()
Cluster Percent_Above_2.0 Percent_Above_2.5 Worst5_Min_Dist Worst5_Max_Dist
count 796.000000 796.000000 796.000000 796.000000 796.000000
mean 398.500000 2.132539 0.047347 1.672276 1.841025
std 229.929699 4.152986 0.357951 0.387387 0.426857
min 1.000000 0.000000 0.000000 0.343977 0.343981
25% 199.750000 0.000000 0.000000 1.581300 1.690478
50% 398.500000 0.000000 0.000000 1.767408 1.900510
75% 597.250000 2.945527 0.000000 1.912616 2.090507
max 796.000000 26.315789 4.615385 2.316133 2.894790
plt.figure(figsize=(14, 6))

sns.barplot(
    data=summary_worst,
    x='Cluster',
    y='Percent_Above_2.5',
    color='steelblue',
    width=0.9,
    linewidth=0
)

# Add small scatter dots at the top of each bar
plt.scatter(
    x=range(len(summary_worst)),
    y=summary_worst['Percent_Above_2.5'],
    color='black',
    s=5,
    zorder=3
)

plt.title('Percentage of Structures > 2.5 Å from Cluster Center')
plt.ylabel('Percentage (%)')
plt.xlabel('Cluster')
plt.ylim(-0.5, 5.5)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_visible(False)
plt.xticks([]) 
plt.yticks(fontsize=10)

plt.gca().set_yticks([0, 2.5, 5])
plt.grid(axis='y', linestyle='--', alpha=0.7) #, linewidth=2)
plt.tight_layout()
plt.show()
plt.figure(figsize=(14, 6))

sns.barplot(
    data=summary_worst,
    x='Cluster',
    y='Percent_Above_2.0',
    color='steelblue',
    width=0.9,
    linewidth=0
)

# Add small scatter dots at the top of each bar
plt.scatter(
    x=range(len(summary_worst)),
    y=summary_worst['Percent_Above_2.0'],
    color='black',
    s=5,
    zorder=3
)

plt.title('Percentage of Structures > 2.0 Å from Cluster Center')
plt.ylabel('Percentage (%)')
plt.xlabel('Cluster')
plt.ylim(-0.5, 30.5)
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.gca().spines['left'].set_color('dimgray')
plt.gca().spines['bottom'].set_visible(False)
plt.xticks([])
plt.yticks(fontsize=10)

plt.gca().set_yticks([0, 10, 20, 30])
plt.grid(axis='y', linestyle='--', alpha=0.7) #, linewidth=2)
plt.tight_layout()
plt.show()
import ast

# Read
file_chain_ids = pd.read_csv("results/file_chain_ids.csv")
print(f"Extracted chain IDs from a total of {len(file_chain_ids)} files.")

# Convert to list
file_chain_ids['chain_ids'] = file_chain_ids['chain_ids'].apply(ast.literal_eval)

# Default label
file_chain_ids['interaction_type'] = 'unknown'

# Label based on chain count
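# (Assumption: each fragment was expanded into a 5-chain stack, so a file with 10 distinct chain
#  IDs holds two separate stacks (interchain contact), while 5 chain IDs means both fragments
#  share one stack (intrachain); anything else is left as 'unknown'.)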
file_chain_ids.loc[file_chain_ids['chain_ids'].apply(lambda x: len(x) == 10), 'interaction_type'] = 'interchain'
file_chain_ids.loc[file_chain_ids['chain_ids'].apply(lambda x: len(x) == 5), 'interaction_type'] = 'intrachain'

file_chain_ids.head()
Extracted chain IDs from a total of 61951 files.
file chain_ids interaction_type
0 8v1n_pair_342_Repair.pdb [C, D, G, H, K, L, O, P, Q, R] interchain
1 8otg_pair_073_Repair.pdb [A, B, E, F, G] intrachain
2 7xo2_pair_150_Repair.pdb [I, D, E, F, G] intrachain
3 7xo3_pair_287_Repair.pdb [A, C, D, I, J] intrachain
4 8ot4_pair_130_Repair.pdb [C, D, E, F, G, H, I, J, K, L] interchain
file_chain_ids['interaction_type'].value_counts()
interaction_type
intrachain    42643
interchain    18829
unknown         479
Name: count, dtype: int64
seq_labels_df['pdb_id'].value_counts() # TODO 
pdb_id
8fnz    48
8oh2    36
8bqw    30
7q64    30
2mpz    27
        ..
8j7n     3
9gg0     3
7u0z     3
6gx5     3
7v4a     3
Name: count, Length: 482, dtype: int64
all_clusters = []
skipped_files = []
cluster_not_found = []
# FIXME: Vectorize
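# For each fragment-pair file, look up the sequence cluster (CD-HIT label) of every chain it
# contains, producing a per-file mapping from its 3D cluster to the sequence clusters involved.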
for cluster_id in K['Cluster'].unique():
    cluster_subset = K[K['Cluster'] == cluster_id]

    for _, row in cluster_subset.iterrows():
        pdb_file = row['PDB_File']
        pdb_id = pdb_file[:4]

        # Get all chain IDs for this PDB file
        chain_ids = file_chain_ids[file_chain_ids['file'] == pdb_file]['chain_ids'].values
        if len(chain_ids) == 0:
            skipped_files.append(pdb_file)
            continue

        chain_ids = chain_ids[0]  # assuming this is a list or string of chain IDs
        seq_clusters = []

        for chain_id in chain_ids:
            mask = (
                (seq_labels_df['pdb_id'] == pdb_id)
                & (seq_labels_df['chain_id'] == chain_id)
            )
            vals = seq_labels_df[mask]['cluster_id'].values
            if len(vals) == 0:
                cluster_not_found.append((pdb_id, chain_id))
                continue
            seq_clusters.append(vals[0])

        all_clusters.append({
            'PDB_File': pdb_file,
            '3D_Cluster': cluster_id,
            'Seq_Clusters': seq_clusters,
        })

print(f"Skipped {len(skipped_files)} files: {skipped_files[:5]}...")
print(f"Cluster not found for {len(cluster_not_found)}: {cluster_not_found[:5]}...")

all_clusters_df = pd.DataFrame(all_clusters)
all_clusters_df.head()
Skipped 133 files: ['1yjo.pdb', '1yjp_2.pdb', '2omm_2.pdb', '3fva_a.pdb', '3hyd_b_2.pdb']...
Cluster not found for 84833: [('9bbl', 'N'), ('9bbl', 'O'), ('7qkj', 'O'), ('7qkj', 'P'), ('8ote', 'J')]...
PDB_File 3D_Cluster Seq_Clusters
0 9bbl_pair_558_Repair.pdb 667 [6, 6, 6]
1 7qkj_pair_498_Repair.pdb 667 [6, 6, 6]
2 8ote_pair_621_Repair.pdb 667 [3, 3, 3, 3]
3 7nrt_pair_309_Repair.pdb 667 [6, 6, 6, 6, 6]
4 8g2v_pair_049_Repair.pdb 667 [44, 44, 44, 44, 44, 44, 44, 44, 44, 44]

It seems we have lost a lot of chains for which no sequence cluster was found; however, these mostly come from artificially expanded structures: during expansion we added chains that we then try to look up in the original protein, and the original protein simply never contained them.
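As a minimal sanity check (a sketch, assuming the unmatched chains indeed come from the artificially expanded structures; missing_df is a hypothetical helper name), we can tabulate which PDB IDs the entries in cluster_not_found belong to and confirm they concentrate in a limited set of structures:

missing_df = pd.DataFrame(cluster_not_found, columns=['pdb_id', 'chain_id'])
print(f"Unmatched chains span {missing_df['pdb_id'].nunique()} PDB IDs")
print(missing_df['pdb_id'].value_counts().head(10))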

import matplotlib.pyplot as plt

# Count and sort cluster occurrences by frequency (descending)
vc = all_clusters_df['3D_Cluster'].value_counts().sort_values(ascending=False)

plt.figure(figsize=(10, 3))  # wider figure

plt.bar(range(len(vc)), vc, color='skyblue', edgecolor='black')

plt.title("Cluster Distribution")
plt.ylabel("Count")

# Remove grid
plt.grid(False)

# Remove x-axis tick labels
plt.xticks([])

# Remove top and right spines for stylistic cleanliness
ax = plt.gca()
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.show()
all_clusters_df['Seq_Clusters'] = all_clusters_df['Seq_Clusters'].apply(lambda x: sorted(list(set(x)))) # unique and sort
all_clusters_df['Num_Seq_Clusters'] = all_clusters_df['Seq_Clusters'].apply(lambda x: len(x))
base_heterotypic = all_clusters_df[all_clusters_df['Num_Seq_Clusters'] > 1]
base_heterotypic # heterotypic interactions (in our data)
PDB_File 3D_Cluster Seq_Clusters Num_Seq_Clusters
60 8bqw_pair_136_standardized_Repair.pdb 667 [-1, 7] 2
86 9for_pair_202_Repair.pdb 667 [0, 33] 2
115 9for_pair_296_Repair.pdb 667 [0, 33] 2
286 9for_pair_478_Repair.pdb 667 [0, 33] 2
336 9for_pair_204_Repair.pdb 667 [0, 33] 2
... ... ... ... ...
54659 9for_pair_354_Repair.pdb 286 [0, 33] 2
54672 9for_pair_452_Repair.pdb 286 [0, 33] 2
55382 9for_pair_346_Repair.pdb 500 [0, 33] 2
55388 9for_pair_451_Repair.pdb 500 [0, 33] 2
55389 9for_pair_338_Repair.pdb 500 [0, 33] 2

359 rows × 4 columns

base_heterotypic['PDB_File'].str[:4].value_counts()
PDB_File
9for    156
8q96     67
8a9l     48
8ci8     36
8wcp     36
8qv8      7
8bqw      6
8ttl      3
Name: count, dtype: int64
base_heterotypic[base_heterotypic['PDB_File'].str.startswith('9for')]['Seq_Clusters'].value_counts()
Seq_Clusters
[0, 33]    156
Name: count, dtype: int64
base_heterotypic[base_heterotypic['PDB_File'].str.startswith('8ci8')]['Seq_Clusters'].value_counts()
Seq_Clusters
[41, 48]    36
Name: count, dtype: int64
base_heterotypic[base_heterotypic['PDB_File'].str.startswith('8ttl')]['Seq_Clusters'].value_counts()
Seq_Clusters
[6, 14]    3
Name: count, dtype: int64

And yes! We have correctly identified the heterotypic interaction that was present in our data from the beginning: protein 9for is indeed a heteromeric amyloid filament of TDP-43 and ANXA11. Two more proteins show a potential heterotypic interaction and need further checking. The remaining candidates have one of their two unique sequence clusters labeled -1, which means CD-HIT was unable to assign a cluster to that chain.
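A quick follow-up check (a sketch; candidates and has_unassigned are hypothetical helper names) separates the genuine two-cluster pairs from those where one of the "clusters" is merely the CD-HIT unassigned label -1:

candidates = base_heterotypic.copy()
candidates['pdb_id'] = candidates['PDB_File'].str[:4]
candidates['has_unassigned'] = candidates['Seq_Clusters'].apply(lambda cl: -1 in cl)
print(candidates.groupby('pdb_id')['has_unassigned'].any())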

def flatten_and_unique(list_of_lists):
    flat = []
    for seq in list_of_lists:
        if isinstance(seq, np.ndarray):
            seq = seq.tolist()
        flat.extend(seq)
    return sorted(set(flat))

all_clusters_df["PDB_ID"] = all_clusters_df["PDB_File"].str[:4]

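# Note: K is redefined here, going from one row per fragment-pair file to one row per 3D cluster,
# aggregating the sequence clusters, PDB IDs and files contained in each cluster.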
K = (
    all_clusters_df
    .groupby("3D_Cluster")
    .agg({
        "Seq_Clusters": lambda x: flatten_and_unique(x),
        "PDB_ID":       lambda x: sorted(set(x)),
        "PDB_File":     lambda x: sorted(set(x)),
    })
    .reset_index()
)

K["num_unique_seq_clusters"] = K["Seq_Clusters"].apply(len)
K["num_unique_proteins"]     = K["PDB_ID"].apply(len)
K["num_unique_pdb_files"]    = K["PDB_File"].apply(len)

K
3D_Cluster Seq_Clusters PDB_ID PDB_File num_unique_seq_clusters num_unique_proteins num_unique_pdb_files
0 1 [3, 30] [7qvc, 7qvf, 7qwg, 7qwl, 7qwm, 7x83, 7x84, 8j7... [7qvc_pair_346_Repair.pdb, 7qvf_pair_346_Repai... 2 12 15
1 2 [0, 2, 3, 6, 7, 25, 34] [6l1t, 6l1u, 6l4s, 6qjq, 7e0f, 7qwm, 7wmm, 7x8... [6l1t_pair_164_Repair.pdb, 6l1t_pair_178_Repai... 7 40 83
2 3 [-1, 6, 7, 9, 31] [6l4s, 6rt0, 6sdz, 7e0f, 7ob4, 7ozg, 7p66, 7p6... [6l4s_pair_087_Repair.pdb, 6l4s_pair_095_Repai... 5 35 102
3 4 [0, 1, 6, 7] [6l1t, 6l1u, 6l4s, 6lrq, 6n3a, 6osj, 6osm, 6pe... [6l1t_pair_013_Repair.pdb, 6l1t_pair_328_Repai... 4 46 98
4 5 [1] [7rl4, 8dja] [7rl4_pair_112_Repair.pdb, 7rl4_pair_1192_Repa... 1 2 24
... ... ... ... ... ... ... ...
791 792 [7] [6h6b, 6osj, 7wmm, 7yk8, 7ynf, 7yng, 7ynm, 7yn... [6h6b_pair_196_Repair.pdb, 6h6b_pair_327_Repai... 1 16 32
792 793 [1] [8dja] [8dja_pair_078_Repair.pdb, 8dja_pair_095_Repai... 1 1 12
793 794 [7] [7v4b, 8cyt, 8cyy, 8cz0, 8cz3, 8cz6] [7v4b_pair_035_Repair.pdb, 7v4b_pair_044_Repai... 1 6 21
794 795 [6, 24] [6zcg, 8q7t, 8q8c, 8q8z, 8q98, 8q9a, 8q9b, 8qc... [6zcg_pair_076_Repair.pdb, 6zcg_pair_237_Repai... 2 9 13
795 796 [7] [6rt0, 6ssx, 7ozg, 8a4l, 8cyx, 8cz1] [6rt0_pair_034_Repair.pdb, 6rt0_pair_043_Repai... 1 6 26

796 rows × 7 columns

Now let's add another column, SeqCluster_File_Sets, which makes it easy to query and identify PDB files coming from different proteins.

K_expanded = all_clusters_df.copy().explode("Seq_Clusters")

def get_file_sets_for_3d_cluster(cluster_id):
    subset = K_expanded[K_expanded["3D_Cluster"] == cluster_id]
    return (
        subset.groupby("Seq_Clusters")["PDB_File"]
        .apply(lambda x: sorted(set(x)))
        .to_dict()
    )

K["SeqCluster_File_Sets"] = K["3D_Cluster"].apply(get_file_sets_for_3d_cluster)

OK, we can now easily query the first 3D cluster and list the protein fragments in potential interaction.

K.loc[0, "SeqCluster_File_Sets"]
{3: ['7qvc_pair_346_Repair.pdb',
  '7qvf_pair_346_Repair.pdb',
  '7qvf_pair_699_Repair.pdb',
  '7qwg_pair_317_Repair.pdb',
  '7qwl_pair_330_Repair.pdb',
  '7qwm_pair_245_Repair.pdb',
  '7x83_pair_207_Repair.pdb',
  '7x84_pair_157_Repair.pdb',
  '7x84_pair_324_Repair.pdb',
  '8j7n_pair_219_Repair.pdb',
  '8j7p_pair_189_Repair.pdb',
  '8otd_pair_325_Repair.pdb',
  '8x5h_pair_219_Repair.pdb'],
 30: ['8qn6_pair_339_Repair.pdb', '8qn6_pair_684_Repair.pdb']}
cluster_sizes = (
    all_clusters_df
        .groupby('3D_Cluster', dropna=True)
        .size()
        .reset_index(name='num_structures')
        .sort_values('num_structures', ascending=False)
        .reset_index(drop=True)
)
cluster_sizes['num_structures'].describe()
count     796.000000
mean       72.798995
std       175.603770
min        10.000000
25%        15.000000
50%        25.000000
75%        57.000000
max      2211.000000
Name: num_structures, dtype: float64
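# Restrict K to the 100 most populated 3D clusters, ordered by descending size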
top100 = cluster_sizes[['3D_Cluster', 'num_structures']].head(100)

K_top = (
    K.merge(top100, on='3D_Cluster', how='inner')
     .sort_values('num_structures', ascending=False, kind='mergesort', ignore_index=True)
     .drop(columns='num_structures')
)
K_top
3D_Cluster Seq_Clusters PDB_ID PDB_File num_unique_seq_clusters num_unique_proteins num_unique_pdb_files SeqCluster_File_Sets
0 242 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14,... [2mpz, 2nao, 2nnt, 5o3l, 5o3o, 6dso, 6gx5, 6hr... [2mpz_pair_011_Repair.pdb, 2mpz_pair_052_Repai... 38 316 2211 {0: ['6n37_pair_032_Repair.pdb', '6n37_pair_03...
1 267 [-1, 0, 1, 2, 3, 4, 5, 6, 7, 9, 11, 12, 14, 20... [5o3l, 5o3o, 5o3t, 6cu7, 6gx5, 6h6b, 6hre, 6hu... [5o3l_pair_069_Repair.pdb, 5o3l_pair_322_Repai... 27 304 1603 {-1: ['8a9l_pair_160_standardized_Repair.pdb',...
2 241 [-1, 0, 1, 2, 3, 5, 6, 7, 8, 9, 13, 14, 16, 17... [5o3l, 5o3o, 5o3t, 6cu7, 6gx5, 6h6b, 6hre, 6hr... [5o3l_pair_056_Repair.pdb, 5o3l_pair_063_Repai... 26 267 1385 {-1: ['8bqw_pair_134_standardized_Repair.pdb',...
3 665 [-1, 0, 1, 2, 4, 5, 6, 7, 12, 13, 14, 17, 22, ... [2lmn, 2lmo, 5o3l, 5o3o, 5o3t, 6cu7, 6gx5, 6h6... [2lmn_pair_039_Repair.pdb, 2lmn_pair_047_Repai... 23 254 1219 {-1: ['8a9l_pair_159_standardized_Repair.pdb',...
4 250 [-1, 0, 1, 2, 6, 7, 9, 12, 13, 14, 22, 23, 26,... [2lmn, 2lmo, 2nnt, 6cu7, 6h6b, 6ic3, 6l1t, 6l1... [2lmn_pair_024_Repair.pdb, 2lmn_pair_032_Repai... 23 221 1214 {-1: ['8a9l_pair_161_standardized_Repair.pdb',...
... ... ... ... ... ... ... ... ...
95 632 [6, 7] [6osj, 6osl, 7qjx, 7qjy, 7qk1, 7qkj, 7qkl, 7qk... [6osj_pair_067_Repair.pdb, 6osj_pair_205_Repai... 2 53 107 {6: ['7qjx_pair_090_Repair.pdb', '7qjx_pair_38...
96 217 [0, 2, 6, 7, 30, 41] [5o3l, 5o3o, 6hre, 7mkf, 7nrq, 7nrs, 7nrt, 7nr... [5o3l_pair_182_Repair.pdb, 5o3l_pair_435_Repai... 6 58 104 {0: ['9for_pair_130_Repair.pdb'], 2: ['8a00_pa...
97 531 [1, 2, 6, 10] [5o3l, 6hre, 7mkf, 7nrq, 7nrv, 7p65, 7qig, 7qj... [5o3l_pair_198_Repair.pdb, 5o3l_pair_384_Repai... 4 42 103 {1: ['7qig_pair_177_Repair.pdb', '7qig_pair_29...
98 3 [-1, 6, 7, 9, 31] [6l4s, 6rt0, 6sdz, 7e0f, 7ob4, 7ozg, 7p66, 7p6... [6l4s_pair_087_Repair.pdb, 6l4s_pair_095_Repai... 5 35 102 {-1: ['8a9l_pair_143_standardized_Repair.pdb',...
99 427 [6] [5o3l, 6hre, 7mkh, 7nrq, 7qjv, 7qkh, 7ql4, 7up... [5o3l_pair_129_Repair.pdb, 5o3l_pair_204_Repai... 1 28 102 {6: ['5o3l_pair_129_Repair.pdb', '5o3l_pair_20...

100 rows × 8 columns

K_top.loc[0, "SeqCluster_File_Sets"]
K_top.loc[99, "SeqCluster_File_Sets"]
{6: ['5o3l_pair_129_Repair.pdb',
  '5o3l_pair_204_Repair.pdb',
  '5o3l_pair_422_Repair.pdb',
  '5o3l_pair_472_Repair.pdb',
  '6hre_pair_204_Repair.pdb',
  '6hre_pair_213_Repair.pdb',
  '6hre_pair_477_Repair.pdb',
  '7mkh_pair_116_Repair.pdb',
  '7mkh_pair_179_Repair.pdb',
  '7mkh_pair_388_Repair.pdb',
  '7mkh_pair_438_Repair.pdb',
  '7nrq_pair_222_Repair.pdb',
  '7nrq_pair_231_Repair.pdb',
  '7nrq_pair_510_Repair.pdb',
  '7qjv_pair_1071_Repair.pdb',
  '7qjv_pair_1106_Repair.pdb',
  '7qjv_pair_132_Repair.pdb',
  '7qjv_pair_139_Repair.pdb',
  '7qjv_pair_521_Repair.pdb',
  '7qjv_pair_746_Repair.pdb',
  '7qjv_pair_753_Repair.pdb',
  '7qjv_pair_807_Repair.pdb',
  '7qjv_pair_816_Repair.pdb',
  '7qkh_pair_173_Repair.pdb',
  '7qkh_pair_180_Repair.pdb',
  '7qkh_pair_246_Repair.pdb',
  '7qkh_pair_255_Repair.pdb',
  '7qkh_pair_423_Repair.pdb',
  '7qkh_pair_442_Repair.pdb',
  '7ql4_pair_151_Repair.pdb',
  '7ql4_pair_158_Repair.pdb',
  '7ql4_pair_548_Repair.pdb',
  '7upe_pair_145_Repair.pdb',
  '7upe_pair_151_Repair.pdb',
  '7upe_pair_521_Repair.pdb',
  '7upf_pair_145_Repair.pdb',
  '7upf_pair_151_Repair.pdb',
  '7upf_pair_521_Repair.pdb',
  '7ymn_pair_141_Repair.pdb',
  '7ymn_pair_148_Repair.pdb',
  '8azu_pair_191_Repair.pdb',
  '8azu_pair_200_Repair.pdb',
  '8azu_pair_268_Repair.pdb',
  '8azu_pair_275_Repair.pdb',
  '8azu_pair_457_Repair.pdb',
  '8azu_pair_532_Repair.pdb',
  '8bgv_pair_142_Repair.pdb',
  '8bgv_pair_148_Repair.pdb',
  '8bgv_pair_254_Repair.pdb',
  '8bgv_pair_262_Repair.pdb',
  '8byn_pair_243_Repair.pdb',
  '8byn_pair_250_Repair.pdb',
  '8byn_pair_424_Repair.pdb',
  '8caq_pair_251_Repair.pdb',
  '8caq_pair_258_Repair.pdb',
  '8ot6_pair_251_Repair.pdb',
  '8ot6_pair_258_Repair.pdb',
  '8otg_pair_251_Repair.pdb',
  '8otg_pair_258_Repair.pdb',
  '8q8m_pair_185_Repair.pdb',
  '8q8m_pair_194_Repair.pdb',
  '8q8r_pair_147_Repair.pdb',
  '8q8r_pair_154_Repair.pdb',
  '8q8r_pair_537_Repair.pdb',
  '8q8s_pair_342_Repair.pdb',
  '8q8s_pair_349_Repair.pdb',
  '8q8s_pair_716_Repair.pdb',
  '8uq7_pair_214_Repair.pdb',
  '8uq7_pair_223_Repair.pdb',
  '9bbl_pair_190_Repair.pdb',
  '9bbl_pair_199_Repair.pdb',
  '9bbl_pair_687_Repair.pdb',
  '9bbm_pair_209_Repair.pdb',
  '9bbm_pair_218_Repair.pdb',
  '9bbm_pair_287_Repair.pdb',
  '9bbm_pair_294_Repair.pdb',
  '9bbm_pair_499_Repair.pdb',
  '9bbm_pair_571_Repair.pdb',
  '9bxi_pair_155_Repair.pdb',
  '9bxi_pair_161_Repair.pdb',
  '9bxi_pair_269_Repair.pdb',
  '9bxi_pair_278_Repair.pdb',
  '9bxi_pair_547_Repair.pdb',
  '9bxi_pair_597_Repair.pdb',
  '9bxq_pair_209_Repair.pdb',
  '9bxq_pair_218_Repair.pdb',
  '9bxq_pair_286_Repair.pdb',
  '9bxq_pair_293_Repair.pdb',
  '9bxq_pair_494_Repair.pdb',
  '9bxq_pair_567_Repair.pdb',
  '9eoh_pair_134_Repair.pdb',
  '9eoh_pair_140_Repair.pdb',
  '9eoh_pair_501_Repair.pdb',
  '9erm_pair_243_Repair.pdb',
  '9erm_pair_250_Repair.pdb',
  '9erm_pair_448_Repair.pdb',
  '9h5g_pair_130_Repair.pdb',
  '9h5g_pair_137_Repair.pdb',
  '9h5g_pair_502_Repair.pdb',
  '9hbb_pair_159_Repair.pdb',
  '9hbb_pair_166_Repair.pdb',
  '9hbb_pair_555_Repair.pdb']}
# Example: query 3D cluster 50 (which keeps its original index 49 after the reset_index above)
K_subset_cl_50 = K[K['3D_Cluster'] == 50]
K_subset_cl_50.loc[49, "SeqCluster_File_Sets"]
{0: ['8cgh_pair_187_Repair.pdb'],
 1: ['9dmy_pair_021_Repair.pdb', '9dmz_pair_021_Repair.pdb'],
 4: ['7vzf_pair_213_Repair.pdb'],
 7: ['6h6b_pair_249_Repair.pdb',
  '6h6b_pair_380_Repair.pdb',
  '6osj_pair_219_Repair.pdb',
  '6osj_pair_450_Repair.pdb',
  '6osl_pair_194_Repair.pdb',
  '6osl_pair_400_Repair.pdb',
  '7wmm_pair_257_Repair.pdb',
  '7wmm_pair_392_Repair.pdb',
  '7xjx_pair_246_Repair.pdb',
  '7xjx_pair_437_Repair.pdb',
  '7yk8_pair_331_Repair.pdb',
  '7yk8_pair_643_Repair.pdb',
  '7ynm_pair_217_Repair.pdb',
  '7ynm_pair_450_Repair.pdb',
  '7ynn_pair_217_Repair.pdb',
  '7ynn_pair_450_Repair.pdb',
  '7yno_pair_251_Repair.pdb',
  '7yno_pair_447_Repair.pdb',
  '7ynp_pair_278_Repair.pdb',
  '7ynp_pair_434_Repair.pdb',
  '7ynq_pair_251_Repair.pdb',
  '7ynq_pair_447_Repair.pdb',
  '7ynr_pair_285_Repair.pdb',
  '7ynr_pair_448_Repair.pdb',
  '7yns_pair_258_Repair.pdb',
  '7yns_pair_398_Repair.pdb',
  '7ynt_pair_270_Repair.pdb',
  '8adu_pair_085_Repair.pdb',
  '8adv_pair_081_Repair.pdb',
  '8adv_pair_316_Repair.pdb',
  '8adw_pair_085_Repair.pdb',
  '8adw_pair_332_Repair.pdb',
  '8hzb_pair_250_Repair.pdb',
  '8hzb_pair_390_Repair.pdb',
  '8hzc_pair_391_Repair.pdb',
  '8hzc_pair_560_Repair.pdb',
  '8oqi_pair_167_Repair.pdb',
  '8oqi_pair_346_Repair.pdb',
  '8pk4_pair_125_Repair.pdb',
  '8pk4_pair_380_Repair.pdb',
  '8qpz_pair_213_Repair.pdb',
  '8qpz_pair_319_Repair.pdb',
  '8ri9_pair_181_Repair.pdb',
  '8zlp_pair_139_Repair.pdb',
  '8zmy_pair_139_Repair.pdb',
  '8zmy_pair_411_Repair.pdb',
  '8zwi_pair_280_Repair.pdb',
  '8zwi_pair_442_Repair.pdb',
  '8zwj_pair_284_Repair.pdb',
  '8zwj_pair_450_Repair.pdb',
  '9ijp_pair_266_Repair.pdb',
  '9ijp_pair_418_Repair.pdb'],
 14: ['8ttl_pair_308_Repair.pdb'],
 21: ['7zh7_pair_009_Repair.pdb'],
 22: ['8qv8_pair_099_Repair.pdb', '8qv8_pair_296_Repair.pdb'],
 30: ['2nao_pair_053_Repair.pdb', '2nao_pair_115_Repair.pdb']}