From aa821025061b333f707ed17bfc52d2caab1a9a01 Mon Sep 17 00:00:00 2001 From: Michal Puncochar Date: Mon, 2 Oct 2023 11:15:48 +0200 Subject: [PATCH] Minor fixes and code improvements --- metaphlan/strainphlan.py | 213 +++++++++++++------------ metaphlan/utils/database_controller.py | 22 ++- metaphlan/utils/sample2markers.py | 2 +- 3 files changed, 135 insertions(+), 102 deletions(-) diff --git a/metaphlan/strainphlan.py b/metaphlan/strainphlan.py index 6a53120..f347c82 100755 --- a/metaphlan/strainphlan.py +++ b/metaphlan/strainphlan.py @@ -17,8 +17,9 @@ import tempfile import time from shutil import copyfile, rmtree +from typing import Iterable -import numpy +import numpy as np import pandas as pd from Bio import SeqIO, Seq @@ -32,45 +33,35 @@ class Strainphlan: - def get_markers_matrix_from_samples(self, print_clades=False): - """Gets a binary matrix representing the presence/absence of the clade markers in the uploaded samples + @staticmethod + def sample_path_to_name(sample_path): + if sample_path.endswith('.bz2'): + sample_path = sample_path[:-len('.bz2')] + + sample_file = os.path.basename(sample_path) + name, ext = os.path.splitext(sample_file) # ext can be pkl, json, fna, ... + return name - Args: - print_clades (bool, optional): Whether the function has been called from the print_clades_only function. Defaults to False. + + def get_markers_matrix_from_samples(self): + """Gets a binary matrix representing the presence/absence of the clade markers in the uploaded samples Returns: list: the list containing the samples-to-markers information """ - if print_clades: - self.db_clade_markers = pd.DataFrame() - else: - if not self.clade_markers_file: - self.database_controller.extract_markers([self.clade], self.tmp_dir) - self.clade_markers_file = os.path.join(self.tmp_dir, "{}.fna".format(self.clade)) - else: - base_name = os.path.basename(self.clade_markers_file) - name, ext = os.path.splitext(base_name) - if ext == '.bz2': - decompress_bz2(self.clade_markers_file, self.tmp_dir) - self.clade_markers_file = os.path.join(self.tmp_dir, name) - else: - copyfile(self.clade_markers_file, - os.path.join(self.tmp_dir, base_name)) - self.clade_markers_file = os.path.join( - self.tmp_dir, base_name) - self.db_clade_markers = {rec.id: rec.seq for rec in SeqIO.parse( - open(self.clade_markers_file, 'r'), 'fasta')} - markers_matrix = execute_pool(((Strainphlan.get_matrix_for_sample, sample, list( - self.db_clade_markers.keys()), self.breadth_thres) for sample in self.samples), self.nprocs) + + markers_matrix = execute_pool(((Strainphlan.get_matrix_for_sample, sample, self.clade_markers_names, + self.breadth_thres) for sample in self.samples), self.nprocs) return markers_matrix + @staticmethod def get_matrix_for_sample(sample_path, clade_markers, breadth_thres): """Returns the matrix with the presence / absence of the clade markers in a samples Args: sample_path (str): the path to the sample - clade_markers (list): the list with the clade markers names + clade_markers (Iterable): the list with the clade markers names breadth_thres: Returns: @@ -89,32 +80,36 @@ def filter_markers_matrix(self, markers_matrix, messages=True): """Filters the primary samples, references and markers based on the user defined thresholds Args: - markers_matrix (list): The list with the sample-to-markers information - messages (bool, optional): Whether to be verbose. Defaults to True. + markers_matrix (list[dict]): The list with the sample-to-markers information + messages (bool): Whether to be verbose and halt execution when less than 4 samples are left. + Defaults to True. """ mm = pd.DataFrame.from_records(markers_matrix, index='sample') # df with index samples and columns markers # Step I: samples with not enough markers are treated as secondary # here the percentage is calculated from the number of markers in at least one sample, it can be less than # the total number of available markers in the database for the clade - _, n_markers = mm.shape - min_markers = max(self.sample_with_n_markers, n_markers * self.sample_with_n_markers_perc / 100) + n_samples_0, n_markers_0 = mm.shape + min_markers = max(self.sample_with_n_markers, n_markers_0 * self.sample_with_n_markers_perc / 100) mm_primary = mm.loc[mm.sum(axis=1) >= min_markers] # Step II: filter markers in not enough primary samples - n_samples, _ = mm_primary.shape - min_samples = n_samples * self.marker_in_n_samples_perc / 100 + n_samples_primary, _ = mm_primary.shape + min_samples = n_samples_primary * self.marker_in_n_samples_perc / 100 mm = mm.loc[:, mm_primary.sum(axis=0) >= min_samples] # Step III: filter samples with not enough markers # here the percentage is calculated from the remaining markers - _, n_markers = mm.shape + _, n_markers_2 = mm.shape min_markers = max(self.sample_with_n_markers_after_filt, - n_markers * self.sample_with_n_markers_after_filt_perc / 100) + n_markers_2 * self.sample_with_n_markers_after_filt_perc / 100) mm = mm.loc[mm.sum(axis=1) >= min_markers] - if len(mm) < 4: + n_samples_3 = len(mm) + if n_samples_3 < 4: if messages: - error("Phylogeny can not be inferred. Less than 4 samples remained after filtering of the markers.", + error(f"Phylogeny can not be inferred. Less than 4 samples remained after filtering.\n" + f"{n_samples_3} / {n_samples_0} samples ({n_samples_primary} primary) " + f"and {n_markers_2} / {n_markers_0 } markers remained.", exit=True) return @@ -125,22 +120,22 @@ def copy_filtered_references(self, markers_tmp_dir, filtered_samples): """Copies the FASTA files of the filtered references to be processed by PhyloPhlAn Args: - markers_tmp_dir (str): the temporal folder where to copy the reference genomes + markers_tmp_dir (str): the temporary folder where to copy the reference genomes filtered_samples (set): the set of samples after filtering """ for reference in self.references: - sample_name = os.path.splitext(os.path.basename(reference[:-4]))[0] if reference.endswith( - '.bz2') else os.path.splitext(os.path.basename(reference))[0] + reference_name = self.sample_path_to_name(reference) + + if reference_name in filtered_samples: + reference_marker = os.path.join(self.tmp_dir, "reference_markers", f'{reference_name}.fna.bz2') + copyfile(reference_marker, os.path.join(markers_tmp_dir, f"{reference_name}.fna.bz2")) - if sample_name in filtered_samples: - reference_marker = os.path.join(self.tmp_dir, "reference_markers", f'{sample_name}.fna.bz2') - copyfile(reference_marker, os.path.join(markers_tmp_dir, f"{sample_name}.fna.bz2")) def matrix_markers_to_fasta(self): """For each sample, writes the FASTA files with the sequences of the filtered markers Returns: - str: the temporal folder where the FASTA files were written + str: the temporary folder where the FASTA files were written """ markers_tmp_dir = os.path.join(self.tmp_dir, "{}.StrainPhlAn4".format(self.clade)) create_folder(markers_tmp_dir) @@ -151,8 +146,9 @@ def matrix_markers_to_fasta(self): self.trim_sequences, markers_tmp_dir) for sample in self.samples), self.nprocs) return markers_tmp_dir - @staticmethod - def sample_markers_to_fasta(sample_path, filtered_samples, filtered_markers, trim_sequences, markers_tmp_dir): + + @classmethod + def sample_markers_to_fasta(cls, sample_path, filtered_samples, filtered_markers, trim_sequences, markers_tmp_dir): """Writes a FASTA file with the filtered clade markers of a sample Args: @@ -160,11 +156,11 @@ def sample_markers_to_fasta(sample_path, filtered_samples, filtered_markers, tri filtered_markers: filtered_samples: trim_sequences: - markers_tmp_dir (str): the temporal folder were the FASTA file is written + markers_tmp_dir (str): the temporary folder were the FASTA file is written """ if sample_path in filtered_samples: - sample_name = os.path.splitext(os.path.basename(sample_path))[0].replace(".pkl", "").replace('.json', '') - marker_output_file = os.path.join(markers_tmp_dir, '{}.fna.bz2'.format(sample_name)) + sample_name = cls.sample_path_to_name(sample_path) + marker_output_file = os.path.join(markers_tmp_dir, f'{sample_name}.fna.bz2') sample = ConsensusMarkers.from_file(sample_path) sample.consensus_markers = [m for m in sample.consensus_markers if m.name in filtered_markers] sample.to_fasta(marker_output_file, trim_ends=trim_sequences) @@ -178,8 +174,16 @@ def get_markers_from_references(self): Returns: list: the list with the samples-to-markers information of the main samples and references """ - return execute_pool(((Strainphlan.process_reference, reference, self.tmp_dir, self.clade_markers_file, list( - self.db_clade_markers.keys()), self.trim_sequences) for reference in self.references), self.nprocs) + if not self.clade_markers_file: + self.database_controller.extract_markers([self.clade], self.tmp_dir) + self.clade_markers_file = os.path.join(self.tmp_dir, "{}.fna".format(self.clade)) + elif self.clade_markers_file.endswith(".bz2"): + self.clade_markers_file = decompress_bz2(self.clade_markers_file, self.tmp_dir) + + return execute_pool(((Strainphlan.process_reference, reference, self.tmp_dir, self.clade_markers_file, + self.clade_markers_names, self.trim_sequences) + for reference in self.references), self.nprocs) + @classmethod def process_reference(cls, reference_file, tmp_dir, clade_markers_file, clade_markers, trim_sequences): @@ -187,19 +191,17 @@ def process_reference(cls, reference_file, tmp_dir, clade_markers_file, clade_ma Args: reference_file (str): path to the reference file - tmp_dir (str): the temporal folder where the BLASTn results where saved + tmp_dir (str): the temporary folder where the BLASTn results where saved clade_markers_file (str): - clade_markers (list): the list with the clade markers names + clade_markers (Iterable): the list with the clade markers names trim_sequences: Returns: dict: the dictionary with the reference-to-markers information """ - name = os.path.splitext(os.path.basename(reference_file))[0] if reference_file.endswith(".bz2"): uncompressed_refernces_dir = os.path.join(tmp_dir, "uncompressed_references") os.makedirs(uncompressed_refernces_dir, exist_ok=True) - name = os.path.splitext(name)[0] reference_file = decompress_bz2(reference_file, uncompressed_refernces_dir) ext_markers = cls.extract_markers_from_genome(reference_file, clade_markers_file) @@ -208,7 +210,8 @@ def process_reference(cls, reference_file, tmp_dir, clade_markers_file, clade_ma os.makedirs(reference_markers_dir, exist_ok=True) consensus_markers = ConsensusMarkers([ConsensusMarker(m, s) for m, s in ext_markers.items()]) - consensus_markers.to_fasta(os.path.join(reference_markers_dir, f'{name}.fna.bz2'), trim_ends=trim_sequences) + reference_name = cls.sample_path_to_name(reference_file) + consensus_markers.to_fasta(os.path.join(reference_markers_dir, f'{reference_name}.fna.bz2'), trim_ends=trim_sequences) markers_matrix = {'sample': reference_file} markers_matrix.update({m: int(m in ext_markers) for m in clade_markers}) @@ -241,6 +244,7 @@ def extract_markers_from_genome(reference_file, clade_markers_file): # run blastn and pass the raw data to stdin r = run_command(cmd, input=input_file_data, text=True) + # load the blast output df = pd.read_csv(io.StringIO(r.stdout), sep='\t', names=columns.split(' ')) ext_markers = {} @@ -289,35 +293,42 @@ def extract_markers_from_genome(reference_file, clade_markers_file): def calculate_polymorphic_rates(self): """Generates a file with the polymorphic rates of the species for each sample""" - with open(os.path.join(self.output_dir, "{}.polymorphic".format(self.clade)), 'w') as polymorphic_file: - polymorphic_file.write("sample\tpercentage_of_polymorphic_sites\tavg_by_marker\tmedian_by_marker" + - "\tstd_by_marker\tmin_by_marker\tmax_by_marker\tq25_by_marker\tq75_by_marker") - for sample_path in self.samples: - sample = ConsensusMarkers.from_file(sample_path) - p_stats, p_count, m_len = [], 0, 0 - for marker in sample.consensus_markers: - if marker.name in self.db_clade_markers: - p_count += marker.get_polymorphisms() - m_len += marker.get_sequence_length() - p_stats.append(marker.get_polymorphism_perc()) - if m_len > 0: - polymorphic_file.write("\n{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".format(os.path.splitext(os.path.basename(sample_path))[0], - p_count * 100 / m_len, numpy.average(p_stats), numpy.percentile( - p_stats, 50), numpy.std(p_stats), numpy.min(p_stats), - numpy.max(p_stats), numpy.percentile(p_stats, 25), numpy.percentile(p_stats, 75))) + rows = [] + for sample_path in self.samples: + sample = ConsensusMarkers.from_file(sample_path) + p_stats, p_count, m_len = [], 0, 0 + for marker in sample.consensus_markers: + if marker.name in self.clade_markers_names: + p_count += marker.get_polymorphisms() + m_len += marker.get_sequence_length() + p_stats.append(marker.get_polymorphism_perc()) + if m_len > 0: + rows.append({ + 'sample': self.sample_path_to_name(sample_path), + 'percentage_of_polymorphic_sites': p_count * 100 / m_len, + 'avg_by_marker': np.mean(p_stats), + 'median_by_marker': np.median(p_stats), + 'std_by_marker': np.std(p_stats), + 'min_by_marker': np.min(p_stats), + 'max_by_marker': np.max(p_stats), + 'q25_by_marker': np.percentile(p_stats, 25), + 'q75_by_marker': np.percentile(p_stats, 75), + }) + + df = pd.DataFrame(rows) + df.to_csv(os.path.join(self.output_dir, f"{self.clade}.polymorphic"), sep='\t', index=False) + def write_info(self): """Writes the information file for the execution""" - filtered_names = [os.path.splitext(os.path.basename(sample))[ - 0] for sample in self.cleaned_markers_matrix.index.tolist()] + filtered_names = [self.sample_path_to_name(sample) for sample in self.cleaned_markers_matrix.index] with open(os.path.join(self.output_dir, "{}.info".format(self.clade)), 'w') as info_file: info_file.write("Clade: {}\n".format(self.clade)) info_file.write( "Number of samples: {}\n".format(len(self.samples))) info_file.write( "Number of references: {}\n".format(len(self.references))) - info_file.write("Number of available markers for the clade: {}\n".format( - len(self.db_clade_markers))) + info_file.write("Number of available markers for the clade: {}\n".format(len(self.clade_markers_names))) info_file.write("Filtering parameters:\n") info_file.write("\tNumber of bases to remove when trimming markers: {}\n".format( self.trim_sequences)) @@ -329,45 +340,43 @@ def write_info(self): info_file.write("Number of markers selected after filtering: {}\n".format( len(self.cleaned_markers_matrix.columns))) info_file.write("Number of samples after filtering: {}\n".format(len( - [sample for sample in self.samples if sample in self.cleaned_markers_matrix.index.tolist()]))) - info_file.write("Number of references after filtering: {}\n".format(len([reference for reference in self.references if (os.path.splitext( - os.path.basename(reference[:-4]))[0] if reference.endswith('.bz2') else os.path.splitext(os.path.basename(reference))[0]) in filtered_names]))) + [sample for sample in self.samples if sample in self.cleaned_markers_matrix.index]))) + info_file.write("Number of references after filtering: {}\n".format(len([reference for reference in self.references if self.sample_path_to_name(reference) in filtered_names]))) info_file.write( "PhyloPhlan phylogenetic precision mode: {}\n".format(self.phylophlan_mode)) info_file.write( "Number of processes used: {}\n".format(self.nprocs)) - def detect_clades(self, markers2species): - """Checks the clades that can be reconstructed from the pkl files - Args: - markers2species (dict): dictionary containing the clade each marker belong to + def detect_clades(self): + """Checks the clades that can be reconstructed from the pkl files Returns: dict: dictionary containing the number of samples a clade can be reconstructed from """ + markers2species = self.database_controller.get_markers2species() species2samples = {} species_to_check = set() info('Detecting clades...') for sample_path in self.samples: sample = ConsensusMarkers.from_file(sample_path) - species_to_check.update({markers2species[marker.name] for marker in sample.consensus_markers if ( - marker.name in markers2species and marker.breadth >= self.breadth_thres)}) + species_to_check.update((markers2species[marker.name] for marker in sample.consensus_markers if ( + marker.name in markers2species and marker.breadth >= self.breadth_thres))) + info(f' Will check {len(species_to_check)} species') for species in species_to_check: self.cleaned_markers_matrix = pd.DataFrame() self.clade = species - self.db_clade_markers = { - marker: marker for marker in markers2species if markers2species[marker] == self.clade} + self.clade_markers_names = self.database_controller.get_markers_for_clade(species) self.filter_markers_samples(print_clades=True) if len(self.cleaned_markers_matrix) >= 4: species2samples[species] = len(self.cleaned_markers_matrix) info('Done.') return species2samples + def print_clades(self): """Prints the clades detected in the reconstructed markers""" - markers2species = self.database_controller.get_markers2species() - species2samples = self.detect_clades(markers2species) + species2samples = self.detect_clades() info('Detected clades: ') sorted_species2samples = collections.OrderedDict(sorted(species2samples.items(), key=lambda kv: kv[1], reverse=True)) @@ -378,8 +387,9 @@ def print_clades(self): wf.write('{}\t{}\n'.format(species, sorted_species2samples[species])) info('Done.') + def interactive_clade_selection(self): - """Allows the user to interactively select the SGB-level clade when specifing the clade at the species level""" + """Allows the user to interactively select the SGB-level clade when specifying the clade at the species level""" if not self.non_interactive: info("The clade has been specified at the species level, starting interactive clade selection...") species2sgbs = self.database_controller.get_species2sgbs() @@ -421,6 +431,7 @@ def interactive_clade_selection(self): info("Done.") self.clade = selected_clade + def filter_markers_samples(self, print_clades=False): """Retrieves the filtered markers matrix with the filtered samples and references @@ -429,17 +440,19 @@ def filter_markers_samples(self, print_clades=False): """ if not print_clades: info("Getting markers from samples...") - markers_matrix = self.get_markers_matrix_from_samples(print_clades) + markers_matrix = self.get_markers_matrix_from_samples() if not print_clades: info("Done.") - info("Getting markers from references...") - markers_matrix += self.get_markers_from_references() - info("Done.") + if len(self.references) > 0: + info("Getting markers from references...") + markers_matrix += self.get_markers_from_references() + info("Done.") info("Removing bad markers / samples...") self.filter_markers_matrix(markers_matrix, messages=not print_clades) if not print_clades: info("Done.") + def run_strainphlan(self): """Runs the full StrainPhlAn pipeline""" if self.print_clades_only: @@ -448,6 +461,7 @@ def run_strainphlan(self): if self.clade.startswith('s__') and 'CHOCOPhlAnSGB' in self.database_controller.get_database_name(): self.interactive_clade_selection() + self.clade_markers_names = self.database_controller.get_markers_for_clade(self.clade) info("Creating temporary directory...") self.tmp_dir = tempfile.mkdtemp(dir=self.tmp_dir) info("Done.") @@ -471,9 +485,10 @@ def run_strainphlan(self): info("Removing temporary files...") rmtree(self.tmp_dir, ignore_errors=False, onerror=None) info("Done.") - + + def __init__(self, args): - self.db_clade_markers = None + self.clade_markers_names = None self.cleaned_markers_matrix = None self.database_controller = MetaphlanDatabaseController(args.database) self.clade_markers_file = args.clade_markers @@ -516,7 +531,7 @@ def read_params(): help="The reference genomes") p.add_argument('-c', '--clade', type=str, default=None, help="The clade to investigate") - p.add_argument('-o', '--output_dir', type=str, + p.add_argument('-o', '--output_dir', type=str, required=True, help="The output directory") p.add_argument('-n', '--nprocs', type=int, default=1, help="The number of threads to use") @@ -556,7 +571,7 @@ def read_params(): p.add_argument('--treeshrink', action='store_true', default=False, help="If specified, StrainPhlAn will execute TreeShrink after building the tree") p.add_argument('--debug', action='store_true', default=False, - help="If specified, StrainPhlAn will not remove the temporal folders") + help="If specified, StrainPhlAn will not remove the temporary folders") p.add_argument('-v', '--version', action='store_true', help="Shows this help message and exit") @@ -575,7 +590,7 @@ def check_params(args): error('-c (or --clade) must be specified', exit=True) elif not os.path.exists(args.output_dir): error('The directory {} does not exist'.format(args.output_dir), exit=True) - elif not (args.tmp is None) and not os.path.exists(args.tmp): + elif args.tmp is not None and not os.path.exists(args.tmp): error('The directory {} does not exist'.format(args.tmp), exit=True) elif args.database != 'latest' and not os.path.exists(args.database): error('The database does not exist', exit=True) diff --git a/metaphlan/utils/database_controller.py b/metaphlan/utils/database_controller.py index 957951b..d6cfed4 100644 --- a/metaphlan/utils/database_controller.py +++ b/metaphlan/utils/database_controller.py @@ -27,6 +27,7 @@ def load_database(self, verbose=True): if verbose: info('Done.') + def get_database_name(self): """Gets database name @@ -35,16 +36,33 @@ def get_database_name(self): """ return self.database.split('/')[-1][:-4] + def get_markers2species(self): """Retrieve information from the MetaPhlAn database Returns: - dict: the dictionary assigning markers to clades + dict[str, str]: the dictionary assigning markers to clades """ self.load_database() return {marker_name: marker_info['clade'] for marker_name, marker_info in self.database_pkl['markers'].items()} - def get_markers(self): + + def get_markers_for_clade(self, clade): + """ + + Args: + clade (str): + + Returns: + set: marker names for the given clade + + """ + self.load_database() + return set(marker_name for marker_name, marker_info in self.database_pkl['markers'].items() + if marker_info['clade'] == clade) + + + def get_all_markers(self): self.load_database() return list(self.database_pkl['markers'].keys()) diff --git a/metaphlan/utils/sample2markers.py b/metaphlan/utils/sample2markers.py index 49dee8a..c925fe9 100755 --- a/metaphlan/utils/sample2markers.py +++ b/metaphlan/utils/sample2markers.py @@ -305,7 +305,7 @@ def filter_mapping_line(line_fields_, markers_subset): def filter_sam_files(self): """Filters the input SAM files with the hits against markers of specific clades and low quality reads""" filtered_markers = self.database_controller.get_filtered_markers(self.clades) if len(self.clades) > 0 else None - all_markers = set(self.database_controller.get_markers()) + all_markers = set(self.database_controller.get_all_markers()) self.input = execute_pool(((SampleToMarkers.parallel_filter_sam, i, self.tmp_dir, self.input_format, self.min_mapping_quality, self.min_reads_aligning, filtered_markers, all_markers) for i in self.input), self.nprocs)