diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 327d7a650d74ccd408b37870cec8deb95c564d74..69599cae1e2d148a8ecdc459b24dce0dc9b25fb5 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -13,7 +13,7 @@ isort:
     - pip install isort==5.12.0
     - isort --version-number
   script:
-    - isort --check-only *.py
+    - isort --check-only --diff *.py
 
 flake8:
   image: pipelinecomponents/flake8:0.12.0
diff --git a/pyproject.toml b/pyproject.toml
index f844b2325ba0ecadd15adbbb39173cac1eab8901..dceea6acf2b96dffef64de826a9d943560890dcc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,4 +4,4 @@ line-length = 120
 
 [tool.isort]
 profile = "black"
-known_third_party = "wandb"
\ No newline at end of file
+known_third_party = "wandb"
diff --git a/save-embeddings-per-bucket.py b/save-embeddings-per-bucket.py
new file mode 100644
index 0000000000000000000000000000000000000000..2695dba64ae1e104dcb247d328758661a152aa1a
--- /dev/null
+++ b/save-embeddings-per-bucket.py
@@ -0,0 +1,52 @@
+import argparse
+import logging
+from os import listdir
+
+import pandas as pd
+from tqdm import tqdm
+
+from utils import create_dir, load_pickle, save_pickle
+
+LOG = logging.getLogger(__name__)
+
+
+def load_all_embeddings(path):
+    df = pd.DataFrame([])
+    for i, emb_file in tqdm(enumerate([f for f in listdir(path) if f.endswith(".pkl")])):
+        objs = load_pickle(f"{path}/{emb_file}")
+        df = pd.concat([df, objs])
+    return df
+
+
+if __name__ == '__main__':
+    argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(description='Save embeddings per bucket')
+    parser.add_argument(
+        '--data-path', type=str, default='/embeddings/ng-granularity-10-randomized/data/', help='path to embeddings'
+    )
+    parser.add_argument(
+        '--predictions-path',
+        type=str,
+        default=(
+            '/data/proteins/predictions/overall/l0--model-MLP5--batchsize-1000000--n_chunks-25--epochs-per-chunk-20'
+            '--mem-50--n_classes-2000--sample_size-2000000--n_iterations-10--2023-10-21-21-36-27'
+        ),
+        help='path to embeddings',
+    )
+    parser.add_argument(
+        '--output-path',
+        type=str,
+        default='/embeddings/ng-granularity-10-randomized/bucket-data/',
+        help='path to embeddings',
+    )
+    args = parser.parse_args()
+    logging.basicConfig(level=logging.INFO, format='[%(asctime)s][%(levelname)-5.5s][%(name)-.20s] %(message)s')
+    LOG.info('Loading embeddings')
+    df = load_all_embeddings(args.data_path)
+
+    create_dir(args.output_path)
+
+    LOG.info('Iterating over cluster predictions')
+    for f in tqdm(listdir(args.predictions_path)):
+        data_subset = df[df.index.isin(load_pickle(f'{args.predictions_path}/{f}'))]
+        save_pickle(f'{args.output_path}/{f}', data_subset)
diff --git a/train-embeddings.py b/train-embeddings.py
index fe4a3ae9116c792c91cc1c142a495ee9d0ba3e02..27dda3ce26d16d44cc2fe22fd5d980acab6df299 100644
--- a/train-embeddings.py
+++ b/train-embeddings.py
@@ -3,6 +3,7 @@ import gc
 import logging
 import sys
 import time
+from typing import Dict
 
 import numpy as np
 import pandas as pd
@@ -10,6 +11,7 @@ import torch
 import wandb
 from torchsummary import summary
 
+from analysis_utils import load_csv_gt_file, load_json_gt_file
 from clustering import assign_labels, run_clustering
 from model import LIDataset, NeuralNetwork, data_X_to_torch
 from utils import (
@@ -17,6 +19,7 @@
     get_current_timestamp,
     load_dataset,
     load_newest_file_in_dir,
+    load_pickle,
     load_predictions,
     save_labels,
     save_model,
@@ -29,6 +32,23 @@ np.random.seed(2023)
 LOG = logging.getLogger(__name__)
 
 
+def load_ground_truth(config: Dict, protein_id='A0A346LI80') -> pd.DataFrame:
+    sim_proteins1 = load_json_gt_file(config.analyses_path, protein_id)
+    sim_proteins2 = load_csv_gt_file(config.analyses_path, protein_id)
+    sim_proteins = pd.concat([sim_proteins1, sim_proteins2])
+    sim_proteins['tm_score'] = sim_proteins['tm_score'].astype(float)
+    sim_proteins = sim_proteins.query('tm_score > 0.9')
+    LOG.info(f'Loaded ground-truth results of shape: {sim_proteins.shape}.')
+    return sim_proteins
+
+
+def get_cluster_assignments(sim_proteins_chunk, kmeans, nn):
+    LOG.info(f'Finding assignments to {sim_proteins_chunk.shape[0]} similar proteins.')
+    cluster_assignments_kmeans = assign_labels(kmeans, sim_proteins_chunk)
+    cluster_assignments_nn = nn.predict(data_X_to_torch(sim_proteins_chunk))
+    return cluster_assignments_kmeans, cluster_assignments_nn
+
+
 def run_experiment(config):
     with wandb.init(
         project='large-data-training',
@@ -47,6 +67,13 @@
         kmeans = None
         predictions_per_class = {c: [] for c in range(config.n_classes)}
 
+        analysis_protein_id = 'A0A346LI80'
+        # the embedding file of A0A346LI80
+        pickle_file_sample_protein = '0-21.pkl'
+        LOG.info(f'Loading sample protein: `{analysis_protein_id}`')
+        sample_protein = (
+            load_pickle(f'{config.path}/{pickle_file_sample_protein}').loc[analysis_protein_id].values.reshape(1, -1)
+        )
         pickle_files_used = []
         for chunk in range(config.n_chunks):
             LOG.info(f'--- Chunk: {chunk} | Loading data')
@@ -141,11 +168,53 @@
             save_model(nn.model, f'/data/proteins/models/{config.name}/chunk-{chunk}.pt')
             LOG.info(f'Model saved in {time.time() - s} seconds')
 
+            # ========== CLUSTER ASSIGNMENTS ========== #
+
+            LOG.info(f'Collecting cluster assignments for a sample protein: `{analysis_protein_id}`')
+            cluster_kmeans, cluster_nn = get_cluster_assignments(sample_protein, kmeans, nn)
+            cluster_kmeans = cluster_kmeans[0]
+            cluster_nn = cluster_nn[0]
+            LOG.info(f'Sample protein belongs to cluster (kmeans): {cluster_kmeans}')
+            LOG.info(f'Sample protein belongs to cluster (nn): {cluster_nn}')
+
+            sim_proteins = load_ground_truth(config, analysis_protein_id)
+            sim_proteins = data_pd[data_pd.index.isin(sim_proteins.protein)]
+            wandb.log({'sim_proteins_found_in_chunk': sim_proteins.shape[0]})
+            cluster_kmeans_sim, cluster_nn_sim = get_cluster_assignments(sim_proteins.values, kmeans, nn)
+            cluster_kmeans_sim, cluster_nn_sim = get_cluster_assignments(sim_proteins, kmeans, nn)
+            wandb.log(
+                {'cluster_accuracy_kmeans': (cluster_kmeans_sim == cluster_kmeans).sum() / cluster_kmeans_sim.shape[0]}
+            )
+            wandb.log({'cluster_accuracy_nn': (cluster_nn_sim == cluster_nn).sum() / cluster_nn_sim.shape[0]})
+
+            cluster_kmeans_sim = pd.DataFrame(cluster_kmeans_sim, index=sim_proteins.index, columns=['cluster'])
+            cluster_nn_sim = pd.DataFrame(cluster_nn_sim, index=sim_proteins.index, columns=['cluster'])
+
+            width = 0.4
+            ax = cluster_kmeans_sim.cluster.value_counts().plot(
+                kind='bar', width=width, position=1, label='kmeans', legend=True, grid=True, figsize=(20, 7)
+            )
+            fig = cluster_nn_sim.cluster.value_counts().plot(
+                kind='bar',
+                color='orange',
+                ax=ax,
+                width=width,
+                position=0,
+                label='NN',
+                legend=True,
+                grid=True,
+                title=(
+                    'cluster assignments of current chunk for `A0A346LI80` |'
+                    f'cluster_kmeans={cluster_kmeans} | cluster_nn={cluster_nn}'
+                ),
+            )
+            wandb.log({"plot": fig})
 
             # ========== PREDICTIONS ========== #
 
             LOG.info('Collecting predictions')
             s = time.time()
             predictions_chunk = 1_000_000 if 1_000_000 < data_torch.shape[0] else data_torch.shape[0]
+            steps = data_torch.shape[0] // predictions_chunk
+            for j in range(steps):
+                predictions = nn.predict(data_torch[j * (predictions_chunk) : (j + 1) * (predictions_chunk), :])
@@ -164,11 +233,12 @@ def run_experiment(config):
                 LOG.info(f'Predictions saved in {time.time() - s} seconds')
                 wandb.log(
                     {
-                        "predictions_accuracy": (
+                        'predictions_accuracy': (
                             predictions == labels[j * (predictions_chunk) : (j + 1) * (predictions_chunk)]
                         ).mean()
                     }
                 )
+                predictions.groupby(predictions).apply(lambda x: predictions_per_class[x.name].extend(list(x.index)))
 
             # ========== CLEANING ========== #
 
@@ -176,13 +246,17 @@ def run_experiment(config):
             s = time.time()
             del predictions
             del data
+            del data_pd
             del labels
             gc.collect()
             LOG.info('Deleted data and labels in %s seconds', time.time() - s)
 
             # ========== STORE PREDICTIONS FOR EACH CLASS ========== #
             for clazz, predictions in predictions_per_class.items():
-                save_predictions(predictions, f'/data/proteins/predictions/overall/{config.name}/class-{clazz}.pkl')
+                save_predictions(
+                    predictions,
+                    f'/data/proteins/predictions/overall/{config.name}/class-{clazz}.pkl',
+                )
 
     LOG.info('Run finished')
 
@@ -199,14 +273,14 @@ def print_label_summary(labels):
 
 
 def collect_wandb_data_stats(data, pickle_files_used):
-    wandb.run.summary["size_data_matrix_nbytes_gb"] = data.nbytes / 1024**3
-    wandb.run.summary["size_data_matrix_getsizeof_gb"] = sys.getsizeof(data) / 1024**3
-    wandb.run.summary["embedding_files_used"] = ", ".join(pickle_files_used)
+    wandb.run.summary['size_data_matrix_nbytes_gb'] = data.nbytes / 1024**3
+    wandb.run.summary['size_data_matrix_getsizeof_gb'] = sys.getsizeof(data) / 1024**3
+    wandb.run.summary['embedding_files_used'] = ', '.join(pickle_files_used)
 
 
 def collect_wandb_label_stats(labels):
-    wandb.run.summary["size_labels_nbytes_gb"] = labels.nbytes / 1024**3
-    wandb.run.summary["size_labels_getsizeof_gb"] = sys.getsizeof(labels) / 1024**3
+    wandb.run.summary['size_labels_nbytes_gb'] = labels.nbytes / 1024**3
+    wandb.run.summary['size_labels_getsizeof_gb'] = sys.getsizeof(labels) / 1024**3
 
 
 def load_model(config, dimensionality):
@@ -237,14 +311,17 @@ if __name__ == '__main__':
     parser.add_argument(
         '-p', '--path', type=str, default='/embeddings/ng-granularity-10-randomized/data/', help='Path to the dataset'
    )
+    parser.add_argument(
+        '--analyses-path', type=str, default='/data/proteins/analyses/', help='Path to the analyses folder'
+    )
     parser.add_argument('-n', '--n-chunks', type=int, default=3, help='Number of chunks to load')
     parser.add_argument('-s', '--chunk-size', type=int, default=10_000_000, help='Size of the data chunk to use')
     parser.add_argument('--memory', type=int, default=16, help='Gb of mem to use')
     parser.add_argument('--n-classes', type=int, default=1000, help='Number of classes to use')
 
     # Clustering parameters
-    parser.add_argument("--sample-size", type=int, default=500_000, help="Size of the sample")
-    parser.add_argument("--n-iterations", type=int, default=10, help="Number of k-means iterations")
-    parser.add_argument("--predictions-path", type=str, default="", nargs='?', const='', help="Path to the predictions")
+    parser.add_argument('--sample-size', type=int, default=500_000, help='Size of the sample')
+    parser.add_argument('--n-iterations', type=int, default=10, help='Number of k-means iterations')
+    parser.add_argument('--predictions-path', type=str, default='', nargs='?', const='', help='Path to the predictions')
 
     args = parser.parse_args()