# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import argparse
from collections.abc import Sequence
from pathlib import Path
from typing import Optional, Union

import pandas as pd

from nsys_recipe.data_service import DataService
from nsys_recipe.lib import exceptions, summary
from nsys_recipe.lib.table_config import CompositeTable
from nsys_recipe.log import logger


def get_session_start_time(session_start_df: pd.DataFrame) -> int:
    """Return the 'utcEpochNs' value found at index label 0 of *session_start_df*."""
    epoch_column = session_start_df["utcEpochNs"]
    return epoch_column.at[0]


def filter_by_pace_name(
    range_df: pd.DataFrame, pace_col: str, pace_name: str
) -> pd.DataFrame:
    """Keep only the rows of *range_df* whose *pace_col* equals *pace_name*.

    The returned frame is re-indexed from 0; the original index is dropped.
    """
    is_match = range_df[pace_col].eq(pace_name)
    return range_df.loc[is_match].reset_index(drop=True)


def compute_pace_stats_dfs(
    range_df: pd.DataFrame, pace_col: str
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Compute per-range pacing columns and duration statistics.

    Parameters
    ----------
    range_df : DataFrame
        Ranges with 'start', 'end' and ``pace_col`` columns. Rows are processed
        in their existing order (presumably sorted by 'start' — verify against
        the caller), since the delta/accumulated columns depend on row order.
    pace_col : str
        Name of the column holding the range name. It is used for grouping the
        statistics and is dropped from the returned pace data.

    Returns
    -------
    pace_df : DataFrame
        Complete ranges (non-null 'start' and 'end') with added 'duration',
        'delta', 'delta_accum' and 'duration_accum' columns, ``pace_col``
        removed, and a fresh 0-based index.
    stats_df : DataFrame
        Result of ``summary.describe_column`` on 'duration' grouped by
        ``pace_col``.

    Raises
    ------
    exceptions.NoDataError
        If no row has both a non-null 'start' and a non-null 'end'.
    """
    # Filter out incomplete ranges. Take an explicit copy so the column
    # assignments below act on an independent frame instead of a slice of
    # 'range_df' (avoids pandas' SettingWithCopyWarning).
    pace_df = range_df[range_df["start"].notnull() & range_df["end"].notnull()].copy()

    if pace_df.empty:
        raise exceptions.NoDataError(
            "Dataframe has no rows with valid start and end times."
        )

    pace_df["duration"] = pace_df["end"] - pace_df["start"]
    pace_gdf = pace_df.groupby(pace_col)
    stats_df = summary.describe_column(pace_gdf["duration"])

    # Calculate the difference in values between the 'start' column and the
    # previous row's 'end' column. The first row has no predecessor, so it is
    # compared against 0 and its delta equals its own 'start'.
    pace_df["delta"] = pace_df["start"] - pace_df["end"].shift(fill_value=0)
    pace_df["delta_accum"] = pace_df["delta"].cumsum()
    pace_df["duration_accum"] = pace_df["duration"].cumsum()

    # Drop the name column that contains the same value for all rows and
    # reset index.
    pace_df = pace_df.drop(columns=[pace_col]).reset_index(drop=True)

    return pace_df, stats_df


def apply_time_offset(
    session_starts: Sequence[int], pace_dfs: Sequence[pd.DataFrame]
) -> None:
    """Shift each frame's 'start'/'end' columns onto a common time base.

    Each frame in ``pace_dfs`` is shifted in place by how much later its
    session started than the earliest session in ``session_starts``. Frames
    and session starts are paired positionally.

    Parameters
    ----------
    session_starts : Sequence[int]
        Session start time of each report, in the same order as ``pace_dfs``.
    pace_dfs : Sequence[DataFrame]
        Frames with 'start' and 'end' columns; modified in place.
    """
    # Nothing to align. This also avoids the ValueError that min() raises on an
    # empty sequence. len() is used rather than truthiness so array-like inputs
    # work too.
    if len(session_starts) == 0:
        return

    # Synchronize session start times.
    global_min_start = min(session_starts)
    for pace_df, session_start in zip(pace_dfs, session_starts):
        session_offset = session_start - global_min_start
        pace_df["start"] = pace_df["start"] + session_offset
        pace_df["end"] = pace_df["end"] + session_offset


def describe_delta(df: pd.DataFrame) -> pd.DataFrame:
    """Summarize every column of *df* as one row of statistics.

    The output has one row per column of *df*, with the statistics min, max,
    count, std, mean, sum, 25%, 50% and 75% as columns. The frame is then
    passed through ``summary.format_columns``.
    """
    quartile_labels = {0.25: "25%", 0.5: "50%", 0.75: "75%"}

    aggregates = df.agg(["min", "max", "count", "std", "mean", "sum"])
    quartiles = df.quantile(list(quartile_labels))
    quartiles.index = list(quartile_labels.values())  # type: ignore[assignment]

    # Stack the statistics vertically, then flip so that each input column
    # becomes a row and each statistic becomes a column.
    combined = pd.concat([aggregates, quartiles])
    return summary.format_columns(combined.transpose())


def split_columns_as_dataframes(
    pace_dfs: Sequence[pd.DataFrame],
) -> dict[str, pd.DataFrame]:
    """Regroup per-rank pace frames into one frame per pace column.

    The pace info is wanted in individual frames per column rather than per
    rank. For each pace column, the result holds a frame whose columns are
    ranks (position in ``pace_dfs``, as strings) and whose index is the
    iteration. An extra 'delta_stats' entry holds ``describe_delta`` applied
    per iteration to the 'delta' frame.
    """
    pace_columns = ("start", "end", "duration_accum", "delta_accum", "duration", "delta")

    def _by_rank(column: str) -> pd.DataFrame:
        # Parquet must have string column names.
        series_per_rank = {str(idx): frame[column] for idx, frame in enumerate(pace_dfs)}
        return pd.DataFrame(series_per_rank).rename_axis(
            index="Iteration", columns="Rank"
        )

    frames_by_column = {column: _by_rank(column) for column in pace_columns}

    # The 'delta' frame has ranks as columns and iterations as the index;
    # transposing yields statistics per iteration rather than per rank.
    frames_by_column["delta_stats"] = describe_delta(frames_by_column["delta"].T)

    return frames_by_column


def get_pacing_info(
    report_path: str,
    parsed_args: argparse.Namespace,
    table: Union[str, CompositeTable],
    name_column: str,
) -> Optional[tuple[str, pd.DataFrame, pd.DataFrame, int]]:
    """Get the pacing statistics for a single report.

    Parameters
    ----------
    report_path : str
        Path to the report file.
    parsed_args : argparse.Namespace
        Parsed arguments. Forwarded to ``DataService``; ``parsed_args.name``
        selects the range name to analyze.
    table : str or CompositeTable
        Table to read.
    name_column : str
        Name of the column containing the pacing names.

    Returns
    -------
    filename : str
        Stem of the report file name (no directory, no extension).
    pace_df : DataFrame
        Pacing data containing columns like duration, delta time, and their
        cumulative sums.
    stats_df : DataFrame
        Statistical summary of the duration column after grouping by
        ``name_column``.
    session_start : int
        Session start time, read from the 'utcEpochNs' column of the
        TARGET_INFO_SESSION_START_TIME table.

    Returns None if no valid data is found; the reason is logged.
    """
    service = DataService(report_path, parsed_args)
    service.queue_table("TARGET_INFO_SESSION_START_TIME")
    service.queue_table(table)

    df_dict = service.read_queued_tables()
    if df_dict is None:
        return None

    df = df_dict[table]
    # NOTE(review): 'df' is checked for emptiness right after this call, which
    # suggests it is filtered in place — confirm against DataService.
    err_msg = service.filter_and_adjust_time(df)
    if err_msg is not None:
        logger.error(f"{report_path}: {err_msg}")
        return None

    if df.empty:
        logger.info(
            f"{report_path}: Report was successfully processed, but no data was found."
        )
        return None

    df = filter_by_pace_name(df, name_column, parsed_args.name)
    if df.empty:
        logger.warning(f"{report_path}: Report does not contain '{parsed_args.name}'.")
        return None

    filename = Path(report_path).stem
    session_start = get_session_start_time(df_dict["TARGET_INFO_SESSION_START_TIME"])

    # Only the stats computation can raise NoDataError (no complete ranges);
    # keep the try body limited to that call.
    try:
        pace_df, stats_df = compute_pace_stats_dfs(df, name_column)
    except exceptions.NoDataError:
        logger.info(
            f"{report_path}: Report was successfully processed, but no valid data was found."
        )
        return None

    return filename, pace_df, stats_df, session_start
