Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 88 additions & 14 deletions bin/mut_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import click
import json
import pandas as pd
import numpy as np

from utils import contexts_formatted
from utils import contexts_formatted, contexts_formatted_sigprofiler
from utils import filter_maf


def compute_mutation_matrix(sample_name, mutations_file, mutation_matrix, method = 'unique', pseudocount = 0):
def compute_mutation_matrix(sample_name, mutations_file, mutation_matrix, method = 'unique', pseudocount = 0,
sigprofiler = False, per_sample = False):
"""
Compute mutational profile from the input data
***Remember to add some pseudocounts to the computation***
Expand All @@ -37,20 +39,22 @@ def compute_mutation_matrix(sample_name, mutations_file, mutation_matrix, method
annotated_minimal_maf.columns = ["SAMPLE_ID", "CONTEXT_MUT", "MUT_ID"]

# count the mutations per sample and per context
counts_x_sample_context_long = annotated_minimal_maf.groupby(by = ["CONTEXT_MUT"])["MUT_ID"].count().reset_index()
counts_x_sample_matrix = counts_x_sample_context_long.set_index("CONTEXT_MUT")
counts_x_sample_matrix.columns = [sample_name]

counts_x_sample_context_long = annotated_minimal_maf.groupby(by = ["SAMPLE_ID", "CONTEXT_MUT"])["MUT_ID"].count().reset_index()
counts_x_sample_context_long.columns = ["SAMPLE_ID", "CONTEXT_MUT", "COUNT"]

elif method == 'multiple':
# make sure to count each mutation only once (avoid annotation issues)
annotated_minimal_maf = annotated_maf[["SAMPLE_ID", "CONTEXT_MUT", "MUT_ID", "ALT_DEPTH"]].drop_duplicates().reset_index(drop = True)
annotated_minimal_maf.columns = ["SAMPLE_ID", "CONTEXT_MUT", "MUT_ID", "ALT_DEPTH"]

# count the mutations per sample and per context
counts_x_sample_context_long = annotated_minimal_maf.groupby(by = ["CONTEXT_MUT"])["ALT_DEPTH"].sum().reset_index()
counts_x_sample_matrix = counts_x_sample_context_long.set_index("CONTEXT_MUT")
counts_x_sample_matrix.columns = [sample_name]
counts_x_sample_context_long = annotated_minimal_maf.groupby(by = ["SAMPLE_ID", "CONTEXT_MUT"])["ALT_DEPTH"].sum().reset_index()
counts_x_sample_context_long.columns = ["SAMPLE_ID", "CONTEXT_MUT", "COUNT"]

# here we group the counts of all the samples
counts_x_sample_matrix = counts_x_sample_context_long.groupby(by = ["CONTEXT_MUT"])["COUNT"].sum().reset_index()
counts_x_sample_matrix.columns = ["CONTEXT_MUT", sample_name]
counts_x_sample_matrix = counts_x_sample_matrix.set_index("CONTEXT_MUT")

counts_x_sample_matrix = pd.concat( (empty_matrix, counts_x_sample_matrix) , axis = 1)
counts_x_sample_matrix = counts_x_sample_matrix.fillna(0)
Expand All @@ -64,8 +68,38 @@ def compute_mutation_matrix(sample_name, mutations_file, mutation_matrix, method
index = True,
sep = "\t")


def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_counts_file, json_output, plot):
if sigprofiler:
counts_x_sample_matrix.index = contexts_formatted_sigprofiler
counts_x_sample_matrix.index.name = "CONTEXT_MUT"
counts_x_sample_matrix = counts_x_sample_matrix.reset_index().sort_values(by = "CONTEXT_MUT")
counts_x_sample_matrix.to_csv(f"{mutation_matrix}.single.sigprofiler",
header = True,
index = False,
sep = "\t")

if per_sample:
counts_x_sample_matrix = counts_x_sample_context_long.pivot(index = "CONTEXT_MUT", columns = "SAMPLE_ID", values = "COUNT")
counts_x_sample_matrix = pd.concat( (empty_matrix, counts_x_sample_matrix) , axis = 1)
counts_x_sample_matrix = counts_x_sample_matrix.fillna(0)
counts_x_sample_matrix.index.name = "CONTEXT_MUT"

counts_x_sample_matrix.to_csv(f"{mutation_matrix}.per_sample",
header = True,
index = True,
sep = "\t")

if sigprofiler:
counts_x_sample_matrix.index = contexts_formatted_sigprofiler
counts_x_sample_matrix.index.name = "CONTEXT_MUT"
counts_x_sample_matrix = counts_x_sample_matrix.reset_index().sort_values(by = "CONTEXT_MUT")
counts_x_sample_matrix.to_csv(f"{mutation_matrix}.per_sample.sigprofiler",
header = True,
index = False,
sep = "\t")


def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_counts_file, json_output, plot,
wgs = False, wgs_trinucleotide_counts = False, sigprofiler = False):
"""
Compute mutational profile from the input data

Expand All @@ -79,6 +113,7 @@ def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_co
# Load your mutation matrix
mutation_matrix = pd.read_csv(mutation_matrix_file, sep = "\t", header = 0)
mutation_matrix = mutation_matrix.set_index("CONTEXT_MUT")
total_mutations = np.sum(mutation_matrix[sample_name])

# Load the trinucleotide counts
trinucleotide_counts = pd.read_csv(trinucleotide_counts_file, sep = "\t", header = 0)
Expand Down Expand Up @@ -133,6 +168,38 @@ def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_co
ymax = max_freq,
output_f = f'{sample_name}.profile.pdf')

if wgs:
if not wgs_trinucleotide_counts:
print("Invalid wgs_trinucleotide_counts, provide a correct file: ", wgs_trinucleotide_counts)
sys.exit("Invalid wgs_trinucleotide_counts, provide a correct file.")

mut_probability["CONTEXT"] = mut_probability["CONTEXT_MUT"].apply( lambda x : x[:3])
ref_trinuc32 = pd.read_csv(wgs_trinucleotide_counts,
sep = "\t", header = 0)

profile_trinuc_merge = mut_probability.merge(ref_trinuc32, on = "CONTEXT")
profile_trinuc_merge["MUTS_WGS"] = profile_trinuc_merge[sample_name] * profile_trinuc_merge["COUNT"]
profile_trinuc_merge["SAMPLE_MUTS_WGS"] = profile_trinuc_merge["MUTS_WGS"] / np.sum(profile_trinuc_merge["MUTS_WGS"]) * total_mutations
profile_trinuc_clean = profile_trinuc_merge[["CONTEXT_MUT", "SAMPLE_MUTS_WGS"]].set_index("CONTEXT_MUT")
profile_trinuc_clean.index.name = "CONTEXT_MUT"
profile_trinuc_clean = profile_trinuc_clean.reindex(contexts_formatted)
profile_trinuc_clean.columns = [sample_name]

profile_trinuc_clean.to_csv(f"{json_output}.matrix.WGS",
header = True,
index = True,
sep = "\t")
if sigprofiler:
profile_trinuc_clean.index = contexts_formatted_sigprofiler
profile_trinuc_clean.index.name = "CONTEXT_MUT"
profile_trinuc_clean = profile_trinuc_clean.reset_index().sort_values(by = "CONTEXT_MUT")
profile_trinuc_clean.to_csv(f"{json_output}.matrix.WGS.sigprofiler",
header = True,
index = False,
sep = "\t")





@click.command()
Expand All @@ -142,24 +209,31 @@ def compute_mutation_profile(sample_name, mutation_matrix_file, trinucleotide_co
@click.option('--out_matrix', type=click.Path(), help='Output mutation matrix file')
@click.option('--method', type=click.Choice(['unique', 'multiple']), default='unique')
@click.option('--pseud', type=float, default=0.5)
@click.option('--per_sample', is_flag=True, help='Create a column for each sample in the input')

@click.option('--mutation_matrix', type=click.Path(exists=True), help='Mutation matrix file (for profile mode)')
@click.option('--trinucleotide_counts', type=click.Path(exists=True), help='Trinucleotide counts file (for profile mode)')
@click.option('--out_profile', type=click.Path(), help='JSON output file (for profile mode)')
@click.option('--plot', is_flag=True, help='Generate plot and save as PDF')
@click.option('--wgs', is_flag=True, help='Store matrix of mutation counts at WGS level')
@click.option('--wgs_trinucleotide_counts', type=click.Path(exists=True), help='Trinucleotide counts file of the WGS (for profile mode if WGS active)')


@click.option('--sigprofiler', is_flag=True, help='Store the index column using the SigProfiler format')

def main(mode, sample_name, mut_file, out_matrix, method, pseud, mutation_matrix, trinucleotide_counts, out_profile, plot):
def main(mode, sample_name, mut_file, out_matrix, method, pseud, sigprofiler, per_sample, mutation_matrix,
trinucleotide_counts, out_profile, plot, wgs, wgs_trinucleotide_counts):
# TODO
# add additional mode to normalize mutation counts for the genomic trinucleotide level
if mode == 'matrix':
click.echo(f"Running in matrix mode...")
click.echo(f"Using the method: {method}")
click.echo(f"Using the pseudocount: {pseud}")
compute_mutation_matrix(sample_name, mut_file, out_matrix, method, pseud)
compute_mutation_matrix(sample_name, mut_file, out_matrix, method, pseud, sigprofiler, per_sample)

elif mode == 'profile':
click.echo(f"Running in profile mode...")
compute_mutation_profile(sample_name, mutation_matrix, trinucleotide_counts, out_profile, plot)
compute_mutation_profile(sample_name, mutation_matrix, trinucleotide_counts, out_profile, plot, wgs, wgs_trinucleotide_counts, sigprofiler)
click.echo("Profile computation completed.")

else:
Expand Down
Loading