import argparse
import os
import pandas as pd
import numpy as np
from collections import Counter
import pybedtools
from scipy.stats import hypergeom, binomtest, binom
import plotly.express as px
import plotly.graph_objects as go
from datetime import datetime
import uuid
import logging
import sys
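
# Module-level logger; reconfigured by setup_logging() when the script runs,
# so helper functions can log even if main() has not rebound it yet.
logger = logging.getLogger(__name__)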

# ---------- Setup Output Directory and Logging ----------

def setup_output_dir(prefix):
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    unique_id = str(uuid.uuid4())[:8]
    output_dir = f"{prefix}_{timestamp}_{unique_id}"
    os.makedirs(output_dir, exist_ok=True)
    return output_dir

def setup_logging(output_dir, verbose):
    log_file = os.path.join(output_dir, "analysis.log")
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        handlers=[
            logging.FileHandler(log_file),
            logging.StreamHandler(sys.stdout)
        ]
    )
    return logging.getLogger()

# ---------- Utility Functions ----------

def validate_file(file_path, description):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{description} not found: {file_path}")
    return file_path

def read_and_sort_bed(bed_file):
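    """Read a peak BED file, sort it by chromosome and start, and return a BedTool.

    Expects a tab-separated, header-less BED file with at least chrom, start,
    and end columns; any additional columns are ignored.
    """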
    logger.info(f"Reading and sorting BED file: {bed_file}")
    try:
        # Read the peak file; only the first three BED columns (chrom, start, end)
        # are used downstream, so extra columns (name, score, ...) are ignored.
        bed_df = pd.read_csv(bed_file, sep='\t', header=None, comment='#')
        if bed_df.shape[1] < 3:
            raise ValueError(f"BED file must have at least 3 columns, found {bed_df.shape[1]}")
        bed_df = bed_df.iloc[:, :3]
        bed_df.columns = ['chrom', 'start', 'end']

        # Ensure start and end are integers
        bed_df['start'] = bed_df['start'].astype(int)
        bed_df['end'] = bed_df['end'].astype(int)

        # Sort by chromosome and start position
        bed_df.sort_values(by=['chrom', 'start'], inplace=True)

        return pybedtools.BedTool.from_dataframe(bed_df)
    except Exception as e:
        logger.error(f"Error reading BED file {bed_file}: {e}")
        raise

def intersect_with_enhancers(bed_bt, enhancer_path):
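    """Intersect peak regions with the enhancer database.

    Reads the enhancer table, normalizes its coordinate columns, intersects it
    with the peak BedTool (wa/wb), and returns a DataFrame with one row per
    overlapping peak-enhancer pair.
    """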
    logger.info(f"Intersecting with enhancer data: {enhancer_path}")
    validate_file(enhancer_path, "Enhancer file")
    try:
        # Read the enhancer file
        enhancer_df = pd.read_csv(enhancer_path, sep='\t', low_memory=False)
        
        logger.info(f"Read enhancer file with {len(enhancer_df)} rows and {len(enhancer_df.columns)} columns")
        logger.info(f"Columns: {list(enhancer_df.columns)}")
        
        # Map column names to expected format
        column_mapping = {
            'Chr': 'chrom',
            'Start_38': 'start',
            'End_38': 'end',
            'Enhancer-id': 'enhancer_id',
            'gene_chromosome': 'gene_chrom',
            'nearest_gene_name': 'gene_name'
        }
        
        enhancer_df.rename(columns=column_mapping, inplace=True)
        
        # Convert coordinates to numeric
        enhancer_df['start'] = pd.to_numeric(enhancer_df['start'], errors='coerce')
        enhancer_df['end'] = pd.to_numeric(enhancer_df['end'], errors='coerce')
        
        # Drop rows with missing coordinates
        before = len(enhancer_df)
        enhancer_df = enhancer_df.dropna(subset=['start', 'end'])
        logger.info(f"Dropped {before - len(enhancer_df)} rows with missing coordinates")
        
        if len(enhancer_df) == 0:
            raise ValueError("No valid enhancer coordinates found")
        
        enhancer_df['start'] = enhancer_df['start'].astype(int)
        enhancer_df['end'] = enhancer_df['end'].astype(int)
        
        # Handle optional gene coordinates
        if 'gene_start' in enhancer_df.columns:
            enhancer_df['gene_start'] = pd.to_numeric(enhancer_df['gene_start'], errors='coerce')
        if 'gene_end' in enhancer_df.columns:
            enhancer_df['gene_end'] = pd.to_numeric(enhancer_df['gene_end'], errors='coerce')
        
        # Select columns for BedTool
        bed_columns = ['chrom', 'start', 'end', 'enhancer_id', 'gene_chrom',
                       'gene_start', 'gene_end', 'gene_name', 'distance_from_enhancer',
                       'vista_id', 'EnhancerDB_id', 'Fantom_id', 'Enhancer_atlas_id']
        
        available_columns = [col for col in bed_columns if col in enhancer_df.columns]
        enhancer_df_bed = enhancer_df[available_columns].copy()
        
        logger.info(f"Creating BedTool with {len(enhancer_df_bed)} enhancers")
        
        enhancer_bt = pybedtools.BedTool.from_dataframe(enhancer_df_bed)
        intersected = bed_bt.intersect(enhancer_bt, wa=True, wb=True)
        
        # Build output column names: peak columns + enhancer columns (with enh_ prefix to avoid duplicates)
        output_names = ['peak_chrom', 'peak_start', 'peak_end']
        for col in available_columns:
            output_names.append(f'enh_{col}')
        
        result_df = intersected.to_dataframe(names=output_names)
        
        # Rename enhancer columns back to original names for consistency
        rename_map = {f'enh_{col}': col for col in available_columns}
        result_df.rename(columns=rename_map, inplace=True)
        result_df.rename(columns={'peak_chrom': 'chrom', 'peak_start': 'start', 'peak_end': 'end'}, inplace=True)
        
        # Convert distance_from_enhancer to numeric if it exists
        if 'distance_from_enhancer' in result_df.columns:
            result_df['distance_from_enhancer'] = pd.to_numeric(result_df['distance_from_enhancer'], errors='coerce')
        
        logger.info(f"Intersection found {len(result_df)} overlapping regions")
        return result_df
        
    except Exception as e:
        logger.error(f"Error intersecting with enhancer data: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise

def filter_by_distance(df, threshold):
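    """Keep only peak-enhancer pairs within the distance threshold (bp) of the enhancer."""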
    logger.info(f"Filtering by distance threshold: {threshold} bp")
    if threshold < 0:
        raise ValueError("Distance threshold must be non-negative")
    if 'distance_from_enhancer' not in df.columns:
        raise KeyError("Column 'distance_from_enhancer' not found in the intersected data")
    filtered_df = df[df['distance_from_enhancer'] <= threshold]
    if filtered_df.empty:
        logger.warning("No peaks found within the distance threshold")
    return filtered_df

def process_background(background_file, enhancer_bt, enhancer_path):
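    """Build the background gene set.

    If a background BED file is supplied, intersect it with the enhancer BedTool;
    otherwise fall back to the full enhancer database. Returns the background
    DataFrame and a Counter of gene occurrences.
    """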
    logger.info("Processing background data")
    if background_file:
        logger.info(f"Using background file: {background_file}")
        validate_file(background_file, "Background file")
        bg_df = pd.read_csv(background_file, sep='\t', header=None)
        bg_df = bg_df.iloc[:, :3]
        bg_df.columns = ['chrom', 'start', 'end']
        background_bt = pybedtools.BedTool.from_dataframe(bg_df)
        intersected = background_bt.intersect(enhancer_bt, wa=True, wb=True)
        
        # Convert the intersection to a DataFrame. to_dataframe() assigns generic
        # BED column names, so the gene name must be recovered by position.
        bg_df = intersected.to_dataframe()

        # Assign names positionally, assuming the enhancer BedTool columns follow
        # the default enhancer database layout (chrom, start, end, enhancer_id,
        # gene_chrom, gene_start, gene_end, gene_name, ...).
        if 'gene_name' not in bg_df.columns and len(bg_df.columns) > 10:
            col_names = list(bg_df.columns)
            bg_df.columns = ['chrom', 'start', 'end', 'enh_chrom', 'enh_start', 'enh_end',
                             'enhancer_id', 'gene_chrom', 'gene_start', 'gene_end', 'gene_name'] + col_names[11:]
    else:
        logger.info(f"Using default enhancer database as background: {enhancer_path}")
        bg_df = pd.read_csv(enhancer_path, sep='\t')
        
        # Map column names
        column_mapping = {
            'Chr': 'chrom',
            'Start_38': 'start',
            'End_38': 'end',
            'nearest_gene_name': 'gene_name'
        }
        bg_df.rename(columns=column_mapping, inplace=True)

    # Make sure a gene_name column exists before counting genes
    if 'gene_name' not in bg_df.columns and 'nearest_gene_name' in bg_df.columns:
        bg_df['gene_name'] = bg_df['nearest_gene_name']
    if 'gene_name' not in bg_df.columns:
        raise ValueError("Could not identify a gene name column in the background data")
    
    gene_counts = Counter(bg_df['gene_name'].dropna())
    logger.info(f"Background contains {len(gene_counts)} unique genes")
    return bg_df, gene_counts

def safe_upper(gene):
    if pd.isnull(gene) or str(gene).strip() == '' or str(gene).upper() == 'NAN':
        return None
    return str(gene).strip().upper()

def gene_enrichment(fg_genes, bg_genes):
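    """Per-gene enrichment of foreground vs. background gene occurrences.

    For each foreground gene, a hypergeometric test (drawing len(fg_genes)
    occurrences from the background population) and a binomial test (success
    probability = background frequency of the gene) are computed.
    """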
    logger.info("Performing gene-level enrichment analysis")
    fg_counts = Counter(fg_genes)
    bg_counts = Counter(bg_genes)

    # Population size = total background occurrences; draws = foreground occurrences.
    N = len(bg_genes)
    n = len(fg_genes)

    genes = np.array(list(fg_counts.keys()))
    k_vals = np.array([fg_counts[gene] for gene in genes])
    K_vals = np.array([bg_counts.get(gene, 0) for gene in genes])
    p_hyper = hypergeom.sf(k_vals - 1, N, K_vals, n)
    p_binom = np.array([binomtest(k, n, K/N, alternative='greater').pvalue for k, K in zip(k_vals, K_vals)])

    df = pd.DataFrame({
        "Gene": genes,
        "Foreground Count": k_vals,
        "Background Count": K_vals,
        "Hypergeometric P-value": p_hyper,
        "Binomial P-value": p_binom
    })
    df = df.sort_values(by="Hypergeometric P-value")
    logger.info(f"Gene enrichment completed for {len(df)} genes")
    return df

def pathway_enrichment(fg_genes, gmt_paths, bg_counts):
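    """Pathway enrichment of the foreground gene set against GMT gene sets.

    For each pathway, a hypergeometric test over the background gene universe
    and a binomial test using background gene frequencies are computed.
    """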
    logger.info("Performing pathway enrichment analysis")
    gene_set = {safe_upper(g) for g in fg_genes if safe_upper(g)}

    # Normalize background gene symbols to upper case so they can be matched
    # against GMT entries, which list human gene symbols in upper case.
    bg_upper = Counter()
    for g, count in bg_counts.items():
        key = safe_upper(g)
        if key:
            bg_upper[key] += count
    total_bg = sum(bg_upper.values())

    results = []

    for db, gmt in gmt_paths.items():
        logger.info(f"Processing GMT file for {db}: {gmt}")
        if not os.path.exists(gmt):
            logger.warning(f"Missing GMT file: {gmt}")
            continue

        try:
            with open(gmt) as f:
                lines = f.readlines()
        except Exception as e:
            logger.error(f"Error reading GMT file {gmt}: {e}")
            continue

        for line in lines:
            parts = line.strip().split('\t')
            if len(parts) < 3:
                continue
            pathway, genes = parts[0], {safe_upper(g) for g in parts[2:] if safe_upper(g)}

            overlap = len(genes & gene_set)
            if overlap == 0:
                continue

            k = overlap
            n = len(genes)
            N = len(gene_set)
            # Background probability of the pathway = share of background gene
            # occurrences that fall on pathway genes.
            sum_path = sum(bg_upper.get(g, 0) for g in genes)
            p_bg = sum_path / total_bg if total_bg > 0 else 0

            # sf(k - 1) == P(X >= k); numerically safer than 1 - cdf(k - 1).
            p_hyper = hypergeom.sf(k - 1, len(bg_upper), n, N)
            p_binom = binom.sf(k - 1, N, p_bg)

            results.append({
                "Database": db,
                "Pathway": pathway,
                "Overlap": f"{k}/{n}",
                "P-value (Hypergeometric)": p_hyper,
                "P-value (Binomial)": p_binom,
                "Genes": ', '.join(genes & gene_set)
            })
    df = pd.DataFrame(results)
    logger.info(f"Pathway enrichment completed with {len(df)} pathways")
    return df

# ---------- Visualization Functions ----------

def plot_distance_distribution(foreground_df, output_path):
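    """Plot the share of foreground peak-enhancer pairs per distance bin as a pie chart."""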
    logger.info(f"Generating distance distribution plot: {output_path}")
    # Create a copy to avoid SettingWithCopyWarning
    plot_df = foreground_df.copy()
    max_distance = plot_df['distance_from_enhancer'].max()
    if pd.isna(max_distance):
        logger.warning("No distance values available; skipping distance distribution plot")
        return
    bin_size = max(100, int(max_distance // 10))
    # Extend the last edge by one bin so the maximum distance is not dropped
    # by the half-open intervals used below (right=False).
    bins = list(range(0, int(max_distance) + 2 * bin_size, bin_size))
    labels = [f"{b}-{b + bin_size} bp" for b in bins[:-1]]

    plot_df['distance_bin'] = pd.cut(plot_df['distance_from_enhancer'],
                                     bins=bins, labels=labels, right=False)
    bin_counts = plot_df['distance_bin'].value_counts(normalize=True).sort_index() * 100
    
    fig = px.pie(
        names=bin_counts.index,
        values=bin_counts.values,
        title="Gene Distribution Across Distance Bins",
        color_discrete_sequence=px.colors.qualitative.Set3
    )
    fig.update_traces(textposition='inside', textinfo='percent+label')
    fig.write_html(output_path)
    logger.info(f"Saved distance distribution plot to: {output_path}")

def plot_top_enriched_genes(enrichment_df, output_path, top_n=20):
    logger.info(f"Generating gene enrichment plot: {output_path}")
    df = enrichment_df.copy()
    # Floor zero p-values instead of dropping them, so the most significant
    # genes are not excluded from the plot (log(0) is undefined).
    df["-log10(p-value)"] = -np.log10(df["Hypergeometric P-value"].replace(0, 1e-300))
    df = df.sort_values("-log10(p-value)", ascending=False).head(top_n)

    fig = px.bar(
        df,
        y="Gene",
        x="-log10(p-value)",
        orientation='h',
        title=f"Top {top_n} Enriched Genes",
        color="-log10(p-value)",
        color_continuous_scale="Blues",
        labels={"-log10(p-value)": "-log10(Hypergeometric P-value)"},
        height=600
    )
    fig.update_layout(
        yaxis={'categoryorder': 'total ascending'},
        xaxis_title="-log10(Hypergeometric P-value)",
        yaxis_title="Gene"
    )
    fig.write_html(output_path)
    logger.info(f"Saved gene enrichment plot to: {output_path}")

def plot_top_pathways(pathway_df, output_dir, top_n=10):
    logger.info(f"Generating pathway enrichment plots in: {output_dir}")
    if pathway_df.empty:
        logger.warning("No pathways enriched; skipping pathway plots")
        return

    grouped = pathway_df.groupby('Database')
    
    for database, group_df in grouped:
        # Hypergeometric plot
        df_hyper = group_df.sort_values('P-value (Hypergeometric)').head(top_n).copy()
        df_hyper['-log10(p-value)'] = -np.log10(df_hyper['P-value (Hypergeometric)'].replace(0, 1e-10))
        
        fig_hyper = px.bar(
            df_hyper,
            x='Pathway',
            y='-log10(p-value)',
            title=f"{database} - Top {top_n} Pathways (Hypergeometric)",
            color='-log10(p-value)',
            color_continuous_scale='Viridis',
            labels={'-log10(p-value)': '-log10(Hypergeometric P-value)'},
            height=500
        )
        fig_hyper.update_layout(xaxis_tickangle=-30)
        output_path = os.path.join(output_dir, f"{database}_hypergeometric.html")
        fig_hyper.write_html(output_path)
        logger.info(f"Saved {database} hypergeometric plot to: {output_path}")

        # Binomial plot
        df_binom = group_df.sort_values('P-value (Binomial)').head(top_n).copy()
        df_binom['-log10(p-value)'] = -np.log10(df_binom['P-value (Binomial)'].replace(0, 1e-10))
        
        fig_binom = px.bar(
            df_binom,
            x='Pathway',
            y='-log10(p-value)',
            title=f"{database} - Top {top_n} Pathways (Binomial)",
            color='-log10(p-value)',
            color_continuous_scale='Plasma',
            labels={'-log10(p-value)': '-log10(Binomial P-value)'},
            height=500
        )
        fig_binom.update_layout(xaxis_tickangle=-30)
        output_path = os.path.join(output_dir, f"{database}_binomial.html")
        fig_binom.write_html(output_path)
        logger.info(f"Saved {database} binomial plot to: {output_path}")

# ---------- Summary File ----------

def write_summary(output_dir, args, status, output_files):
    summary_file = os.path.join(output_dir, "summary.txt")
    with open(summary_file, 'w') as f:
        f.write(f"Enhancer Pathway Enrichment Analysis Summary\n")
        f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Status: {status}\n\n")
        f.write(f"Input Parameters:\n")
        f.write(f"  BED File: {args.bed}\n")
        f.write(f"  Background File: {args.background or 'None (used default enhancer database)'}\n")
        f.write(f"  Enhancer File: {args.enhancer}\n")
        f.write(f"  GMT Directory: {args.gmt_dir}\n")
        f.write(f"  Distance Threshold: {args.distance} bp\n")
        f.write(f"  Top N Genes: {args.top_n_genes}\n")
        f.write(f"  Top N Pathways: {args.top_n_pathways}\n")
        f.write(f"  Verbose: {args.verbose}\n")
        f.write(f"  Save Intermediate: {args.save_intermediate}\n")
        f.write(f"  Skip Pathway Enrichment: {args.skip_pathway}\n\n")
        f.write(f"Output Files:\n")
        for file in output_files:
            f.write(f"  {file}\n")
    logger.info(f"Saved summary to: {summary_file}")

# ---------- Main CLI Function ----------

def main():
    parser = argparse.ArgumentParser(description="Enhanced Enhancer Pathway Enrichment CLI")
    parser.add_argument("--bed", required=True, help="Input BED file")
    parser.add_argument("--background", default=None, help="Optional background BED file")
    parser.add_argument("--distance", type=int, default=5000, help="Distance threshold (default: 5000)")
    parser.add_argument("--enhancer", default="enhancer_gene.bed", help="Enhancer BED file")
    parser.add_argument("--gmt_dir", default="intersect", help="Directory containing GMT files")
    parser.add_argument("--output_prefix", default="results", help="Prefix for output directory")
    parser.add_argument("--top_n_genes", type=int, default=20, help="Top N genes to plot")
    parser.add_argument("--top_n_pathways", type=int, default=10, help="Top N pathways to plot")
    parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
    parser.add_argument("--save_intermediate", action="store_true", help="Save intermediate data as CSV")
    parser.add_argument("--skip_pathway", action="store_true", help="Skip pathway enrichment if GMT files are missing")

    args = parser.parse_args()

    global logger
    output_dir = setup_output_dir(args.output_prefix)
    logger = setup_logging(output_dir, args.verbose)
    logger.info("Starting enhancer pathway enrichment analysis")

    gmt_paths = {
        'Wiki': os.path.join(args.gmt_dir, "WikiPathways_2024_Human.gmt"),
        'GO_BP': os.path.join(args.gmt_dir, "GO_Biological_Process_2025.gmt"),
        'KEGG': os.path.join(args.gmt_dir, "KEGG_2021_Human.gmt"),
        'Reactome': os.path.join(args.gmt_dir, "Reactome_Pathways_2024.gmt"),
        'HDSigDB': os.path.join(args.gmt_dir, "HDSigDB_Human_2021.gmt")
    }

    # Check GMT files
    valid_gmt_paths = {db: path for db, path in gmt_paths.items() if os.path.exists(path)}
    if not valid_gmt_paths and not args.skip_pathway:
        logger.error("No valid GMT files found in the specified directory. Please check the --gmt_dir path or use --skip_pathway.")
        raise FileNotFoundError(f"No valid GMT files found in {args.gmt_dir}")

    output_files = []

    try:
        # Read and process BED file
        bed_bt = read_and_sort_bed(args.bed)
        # Build the enhancer BedTool for the background step. from_dataframe relies
        # on column order, so move the coordinate columns to the front first.
        enhancer_raw = pd.read_csv(args.enhancer, sep='\t', low_memory=False).rename(
            columns={'Chr': 'chrom', 'Start_38': 'start', 'End_38': 'end'})
        lead_cols = ['chrom', 'start', 'end']
        enhancer_bt = pybedtools.BedTool.from_dataframe(
            enhancer_raw[lead_cols + [c for c in enhancer_raw.columns if c not in lead_cols]])
        intersected_df = intersect_with_enhancers(bed_bt, args.enhancer)

        # Save intermediate data if requested
        if args.save_intermediate:
            intersected_df.to_csv(os.path.join(output_dir, "intersected_peaks.csv"), index=False)
            output_files.append("intersected_peaks.csv")
            logger.info("Saved intersected peaks to intersected_peaks.csv")

        # Filter by distance
        fg_df = filter_by_distance(intersected_df, args.distance)
        if args.save_intermediate:
            fg_df.to_csv(os.path.join(output_dir, "foreground_peaks.csv"), index=False)
            output_files.append("foreground_peaks.csv")
            logger.info("Saved foreground peaks to foreground_peaks.csv")

        # Plot distance distribution
        plot_distance_distribution(fg_df, os.path.join(output_dir, "distance_distribution.html"))
        output_files.append("distance_distribution.html")

        # Drop missing gene names so they do not enter the enrichment tests.
        fg_genes = fg_df['gene_name'].dropna().tolist()
        bg_df, bg_gene_counts = process_background(args.background, enhancer_bt, args.enhancer)
        bg_genes = bg_df['gene_name'].dropna().tolist()

        # Gene-level enrichment
        gene_df = gene_enrichment(fg_genes, bg_genes)
        gene_df.to_csv(os.path.join(output_dir, "gene_enrichment.csv"), index=False)
        output_files.append("gene_enrichment.csv")
        plot_top_enriched_genes(gene_df, os.path.join(output_dir, "gene_enrichment.html"), top_n=args.top_n_genes)
        output_files.append("gene_enrichment.html")

        # Pathway enrichment
        if valid_gmt_paths and not args.skip_pathway:
            pathway_df = pathway_enrichment(fg_genes, valid_gmt_paths, bg_gene_counts)
            pathway_df.to_csv(os.path.join(output_dir, "pathway_enrichment.csv"), index=False)
            output_files.append("pathway_enrichment.csv")
            plot_top_pathways(pathway_df, output_dir, top_n=args.top_n_pathways)
            for db in valid_gmt_paths.keys():
                if os.path.exists(os.path.join(output_dir, f"{db}_hypergeometric.html")):
                    output_files.append(f"{db}_hypergeometric.html")
                if os.path.exists(os.path.join(output_dir, f"{db}_binomial.html")):
                    output_files.append(f"{db}_binomial.html")
        else:
            logger.info("Skipping pathway enrichment due to missing GMT files or --skip_pathway flag")

        logger.info("Analysis complete")
        logger.info(f"Output directory: {output_dir}")
        for file in output_files:
            logger.info(f"Output file: {file}")

        # Write summary
        write_summary(output_dir, args, "Success", output_files)

    except Exception as e:
        logger.error(f"Analysis failed: {e}")
        write_summary(output_dir, args, f"Failed: {str(e)}", output_files)
        raise

if __name__ == "__main__":
    main()