diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 327d7a650d74ccd408b37870cec8deb95c564d74..69599cae1e2d148a8ecdc459b24dce0dc9b25fb5 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -13,7 +13,7 @@ isort:
     - pip install isort==5.12.0
     - isort --version-number
   script:
-    - isort --check-only *.py
+    - isort --check-only --diff *.py
 
 flake8:
   image: pipelinecomponents/flake8:0.12.0
diff --git a/pyproject.toml b/pyproject.toml
index f844b2325ba0ecadd15adbbb39173cac1eab8901..dceea6acf2b96dffef64de826a9d943560890dcc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,4 +4,4 @@ line-length = 120
 
 [tool.isort]
 profile = "black"
-known_third_party = "wandb"
\ No newline at end of file
+known_third_party = "wandb"
diff --git a/save-embeddings-per-bucket.py b/save-embeddings-per-bucket.py
new file mode 100644
index 0000000000000000000000000000000000000000..2695dba64ae1e104dcb247d328758661a152aa1a
--- /dev/null
+++ b/save-embeddings-per-bucket.py
@@ -0,0 +1,52 @@
+import argparse
+import logging
+from os import listdir
+
+import pandas as pd
+from tqdm import tqdm
+
+from utils import create_dir, load_pickle, save_pickle
+
+LOG = logging.getLogger(__name__)
+
+
+def load_all_embeddings(path):
+    df = pd.DataFrame([])
+    for i, emb_file in tqdm(enumerate([f for f in listdir(path) if f.endswith(".pkl")])):
+        objs = load_pickle(f"{path}/{emb_file}")
+        df = pd.concat([df, objs])
+    return df
+
+
+if __name__ == '__main__':
+    argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(description='Save embeddings per bucket')
+    parser.add_argument(
+        '--data-path', type=str, default='/embeddings/ng-granularity-10-randomized/data/', help='path to embeddings'
+    )
+    parser.add_argument(
+        '--predictions-path',
+        type=str,
+        default=(
+            '/data/proteins/predictions/overall/l0--model-MLP5--batchsize-1000000--n_chunks-25--epochs-per-chunk-20'
+            '--mem-50--n_classes-2000--sample_size-2000000--n_iterations-10--2023-10-21-21-36-27'
+        ),
+        help='path to embeddings',
+    )
+    parser.add_argument(
+        '--output-path',
+        type=str,
+        default='/embeddings/ng-granularity-10-randomized/bucket-data/',
+        help='path to embeddings',
+    )
+    args = parser.parse_args()
+    logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)-5.5s][%(name)-.20s] %(message)s')
+    LOG.info('Loading embeddings')
+    df = load_all_embeddings(args.data_path)
+
+    create_dir(args.output_path)
+
+    LOG.info('Iterating over cluster predictions')
+    for f in tqdm(listdir(args.predictions_path)):
+        data_subset = df[df.index.isin(load_pickle(f'{args.predictions_path}/{f}'))]
+        save_pickle(f'{args.output_path}/{f}', data_subset)
diff --git a/train-embeddings.py b/train-embeddings.py
index fe4a3ae9116c792c91cc1c142a495ee9d0ba3e02..27dda3ce26d16d44cc2fe22fd5d980acab6df299 100644
--- a/train-embeddings.py
+++ b/train-embeddings.py
@@ -3,6 +3,7 @@ import gc
 import logging
 import sys
 import time
+from typing import Dict
 
 import numpy as np
 import pandas as pd
@@ -10,6 +11,7 @@ import torch
 import wandb
 from torchsummary import summary
 
+from analysis_utils import load_csv_gt_file, load_json_gt_file
 from clustering import assign_labels, run_clustering
 from model import LIDataset, NeuralNetwork, data_X_to_torch
 from utils import (
@@ -17,6 +19,7 @@
     get_current_timestamp,
     load_dataset,
     load_newest_file_in_dir,
+    load_pickle,
     load_predictions,
     save_labels,
     save_model,
@@ -29,6 +32,23 @@ np.random.seed(2023)
 LOG = logging.getLogger(__name__)
 
 
+def load_ground_truth(config: Dict, protein_id='A0A346LI80') -> pd.DataFrame:
+    sim_proteins1 = load_json_gt_file(config.analyses_path, protein_id)
+    sim_proteins2 = load_csv_gt_file(config.analyses_path, protein_id)
+    sim_proteins = pd.concat([sim_proteins1, sim_proteins2])
+    sim_proteins['tm_score'] = sim_proteins['tm_score'].astype(float)
+    sim_proteins = sim_proteins.query('tm_score > 0.9')
+    LOG.info(f'Loaded ground-truth results of shape: {sim_proteins.shape}.')
+    return sim_proteins
+
+
+def get_cluster_assignments(sim_proteins_chunk, kmeans, nn):
+    LOG.info(f'Finding assignments to {sim_proteins_chunk.shape[0]} similar proteins.')
+    cluster_assignments_kmeans = assign_labels(kmeans, sim_proteins_chunk)
+    cluster_assignments_nn = nn.predict(data_X_to_torch(sim_proteins_chunk))
+    return cluster_assignments_kmeans, cluster_assignments_nn
+
+
 def run_experiment(config):
     with wandb.init(
         project='large-data-training',
@@ -47,6 +67,13 @@
         kmeans = None
         predictions_per_class = {c: [] for c in range(config.n_classes)}
 
+        analysis_protein_id = 'A0A346LI80'
+        # the embedding file of A0A346LI80
+        pickle_file_sample_protein = '0-21.pkl'
+        LOG.info(f'Loading sample protein: `{analysis_protein_id}`')
+        sample_protein = (
+            load_pickle(f'{config.path}/{pickle_file_sample_protein}').loc[analysis_protein_id].values.reshape(1, -1)
+        )
         pickle_files_used = []
         for chunk in range(config.n_chunks):
             LOG.info(f'--- Chunk: {chunk} | Loading data')
@@ -141,11 +168,53 @@
             save_model(nn.model, f'/data/proteins/models/{config.name}/chunk-{chunk}.pt')
             LOG.info(f'Model saved in {time.time() - s} seconds')
 
+            # ========== CLUSTER ASSIGNMENTS ========== #
+
+            LOG.info(f'Collecting cluster assignments for a sample protein: `{analysis_protein_id}`')
+            cluster_kmeans, cluster_nn = get_cluster_assignments(sample_protein, kmeans, nn)
+            cluster_kmeans = cluster_kmeans[0]
+            cluster_nn = cluster_nn[0]
+            LOG.info(f'Sample protein belongs to cluster (kmeans): {cluster_kmeans}')
+            LOG.info(f'Sample protein belongs to cluster (nn): {cluster_nn}')
+
+            sim_proteins = load_ground_truth(config, analysis_protein_id)
+            sim_proteins = data_pd[data_pd.index.isin(sim_proteins.protein)]
+            wandb.log({'sim_proteins_found_in_chunk': sim_proteins.shape[0]})
+            cluster_kmeans_sim, cluster_nn_sim = get_cluster_assignments(sim_proteins.values, kmeans, nn)
+            cluster_kmeans_sim, cluster_nn_sim = get_cluster_assignments(sim_proteins, kmeans, nn)
+            wandb.log(
+                {'cluster_accuracy_kmeans': (cluster_kmeans_sim == cluster_kmeans).sum() / cluster_kmeans_sim.shape[0]}
+            )
+            wandb.log({'cluster_accuracy_nn': (cluster_nn_sim == cluster_nn).sum() / cluster_nn_sim.shape[0]})
+
+            cluster_kmeans_sim = pd.DataFrame(cluster_kmeans_sim, index=sim_proteins.index, columns=['cluster'])
+            cluster_nn_sim = pd.DataFrame(cluster_nn_sim, index=sim_proteins.index, columns=['cluster'])
+
+            width = 0.4
+            ax = cluster_kmeans_sim.cluster.value_counts().plot(
+                kind='bar', width=width, position=1, label='kmeans', legend=True, grid=True, figsize=(20, 7)
+            )
+            fig = cluster_nn_sim.cluster.value_counts().plot(
+                kind='bar',
+                color='orange',
+                ax=ax,
+                width=width,
+                position=0,
+                label='NN',
+                legend=True,
+                grid=True,
+                title=(
+                    'cluster assignments of current chunk for `A0A346LI80` |'
+                    f'cluster_kmeans={cluster_kmeans} | cluster_nn={cluster_nn}'
+                ),
+            )
+            wandb.log({"plot": fig})
 
             # ========== PREDICTIONS ========== #
 
             LOG.info('Collecting predictions')
             s = time.time()
             predictions_chunk = 1_000_000 if 1_000_000 < data_torch.shape[0] else data_torch.shape[0]
+            steps = data_torch.shape[0] // predictions_chunk
+            for j in range(steps):
+                predictions = nn.predict(data_torch[j * (predictions_chunk) : (j + 1) * (predictions_chunk), :])
@@ -164,11 +233,12 @@ def run_experiment(config):
                 LOG.info(f'Predictions saved in {time.time() - s} seconds')
                 wandb.log(
                     {
-                        "predictions_accuracy": (
+                        'predictions_accuracy': (
                             predictions == labels[j * (predictions_chunk) : (j + 1) * (predictions_chunk)]
                         ).mean()
                     }
                 )
+                predictions.groupby(predictions).apply(lambda x: predictions_per_class[x.name].extend(list(x.index)))
 
             # ========== CLEANING ========== #
 
@@ -176,13 +246,17 @@ def run_experiment(config):
             s = time.time()
             del predictions
             del data
+            del data_pd
             del labels
             gc.collect()
             LOG.info('Deleted data and labels in %s seconds', time.time() - s)
 
             # ========== STORE PREDICTIONS FOR EACH CLASS ========== #
             for clazz, predictions in predictions_per_class.items():
-                save_predictions(predictions, f'/data/proteins/predictions/overall/{config.name}/class-{clazz}.pkl')
+                save_predictions(
+                    predictions,
+                    f'/data/proteins/predictions/overall/{config.name}/class-{clazz}.pkl',
+                )
 
     LOG.info('Run finished')
 
@@ -199,14 +273,14 @@ def print_label_summary(labels):
 
 
 def collect_wandb_data_stats(data, pickle_files_used):
-    wandb.run.summary["size_data_matrix_nbytes_gb"] = data.nbytes / 1024**3
-    wandb.run.summary["size_data_matrix_getsizeof_gb"] = sys.getsizeof(data) / 1024**3
-    wandb.run.summary["embedding_files_used"] = ", ".join(pickle_files_used)
+    wandb.run.summary['size_data_matrix_nbytes_gb'] = data.nbytes / 1024**3
+    wandb.run.summary['size_data_matrix_getsizeof_gb'] = sys.getsizeof(data) / 1024**3
+    wandb.run.summary['embedding_files_used'] = ', '.join(pickle_files_used)
 
 
 def collect_wandb_label_stats(labels):
-    wandb.run.summary["size_labels_nbytes_gb"] = labels.nbytes / 1024**3
-    wandb.run.summary["size_labels_getsizeof_gb"] = sys.getsizeof(labels) / 1024**3
+    wandb.run.summary['size_labels_nbytes_gb'] = labels.nbytes / 1024**3
+    wandb.run.summary['size_labels_getsizeof_gb'] = sys.getsizeof(labels) / 1024**3
 
 
 def load_model(config, dimensionality):
@@ -237,14 +311,17 @@ if __name__ == '__main__':
     parser.add_argument(
         '-p', '--path', type=str, default='/embeddings/ng-granularity-10-randomized/data/', help='Path to the dataset'
    )
+    parser.add_argument(
+        '--analyses-path', type=str, default='/data/proteins/analyses/', help='Path to the analyses folder'
+    )
     parser.add_argument('-n', '--n-chunks', type=int, default=3, help='Number of chunks to load')
     parser.add_argument('-s', '--chunk-size', type=int, default=10_000_000, help='Size of the data chunk to use')
     parser.add_argument('--memory', type=int, default=16, help='Gb of mem to use')
     parser.add_argument('--n-classes', type=int, default=1000, help='Number of classes to use')
 
     # Clustering parameters
-    parser.add_argument("--sample-size", type=int, default=500_000, help="Size of the sample")
-    parser.add_argument("--n-iterations", type=int, default=10, help="Number of k-means iterations")
-    parser.add_argument("--predictions-path", type=str, default="", nargs='?', const='', help="Path to the predictions")
+    parser.add_argument('--sample-size', type=int, default=500_000, help='Size of the sample')
+    parser.add_argument('--n-iterations', type=int, default=10, help='Number of k-means iterations')
+    parser.add_argument('--predictions-path', type=str, default='', nargs='?', const='', help='Path to the predictions')
 
     args = parser.parse_args()