# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Use arm_metrics.py from installed package for type checking.
    from nsys_recipe.lib import arm_metrics as am
else:
    # At runtime, import the arm_metrics.py that sits in the same directory
    # as post_process.py. Both files are copied from the installed package
    # into the recipe output, next to nvtx_cpu_topdown.ipynb, which imports
    # post_process.py.
    import arm_metrics as am

import pandas as pd
from IPython.display import Markdown


def style_nvtx_summary(nvtx_summary_df):
    nvtx_summary_df = nvtx_summary_df.drop(columns=["index"])

    cols_to_align = ["NVTX Range"]
    if "Notes" in nvtx_summary_df.columns:
        cols_to_align.append("Notes")

    styled_df = nvtx_summary_df.style.apply(
        lambda x: ["text-align: left" for _ in x], subset=cols_to_align
    )
    styled_df = styled_df.set_table_styles(
        [dict(selector="th", props=[("text-align", "center")])]
    )

    if "Notes" in nvtx_summary_df.columns:
        styled_df = styled_df.set_properties(subset=["Notes"], **{"width": "650px"})

    def gray_out_rows(row):
        if "Notes" in row and "filtered" in row["Notes"]:
            return ["opacity: 0.6" for _ in row]
        return ["" for _ in row]

    styled_df = styled_df.apply(gray_out_rows, axis=1)
    return styled_df


def style_report_summary(report_summary_df):
    styled_df = report_summary_df.style.apply(lambda x: ["text-align: left" for _ in x])
    styled_df = styled_df.set_table_styles(
        [dict(selector="th", props=[("text-align", "left")])]
    )
    return styled_df


def create_no_metrics_msg():
    """Return a centered Markdown notice used when there are no metrics to show."""
    notice_html = "<div style='text-align: center'>No metrics to display.</div>"
    return Markdown(notice_html)


def _split_metric_column(column_name, idx):
    """Return the ``idx``-th ``"#"``-separated field of ``column_name``, or ""."""
    fields = column_name.split("#")
    return fields[idx] if idx < len(fields) else ""


def _select_metrics(t_df, metric_id_col, metric_ids):
    """
    Return the rows of ``t_df`` whose ``metric_id_col`` value is in ``metric_ids``.

    The rows are ordered like ``metric_ids``. The sort is stable, so rows
    sharing a metric ID (e.g. the same metric from different reports) keep
    their relative order from ``t_df``. The result has a fresh RangeIndex.
    """
    selected = t_df[t_df[metric_id_col].isin(metric_ids)].reset_index(drop=True)
    if selected.empty:
        return selected
    selected[metric_id_col] = pd.Categorical(
        selected[metric_id_col], categories=metric_ids, ordered=True
    )
    return selected.sort_values(metric_id_col, kind="stable").reset_index(drop=True)


def _style_metrics_table(df):
    """Reduce a metrics table to its display columns and return a left-aligned Styler."""
    if "Category" in df.columns:
        # Make Category the outer row-index level so the table is visually
        # grouped by L2 category. Each concatenated piece carries a single
        # Category value, so rows are already contiguous per category and no
        # regrouping is needed. (groupby().apply(identity) was used before:
        # its result varies across pandas versions, it is deprecated on
        # grouping columns, and it re-sorted categories alphabetically.)
        df = df.set_index(["Category", df.index])

    df = df[["Metric", "Value", "Description", "Equation", "Report"]]

    styled_df = df.style.apply(
        lambda column: ["text-align: left" for _ in column],
        subset=["Metric", "Description", "Equation", "Report"],
    )
    return styled_df.set_table_styles(
        [dict(selector="th", props=[("text-align", "center")])]
    )


def generate_topdown_view(cpu_metrics_df, reports_df):
    """
    Generate a view of the top-down metrics for the selected NVTX range.

    Tables are generated in this order, each only if the required data is
    available:
    1. Topdown Level 1 metrics, with IPC shown first.
    2. One table per Topdown Level 1 category (e.g. Frontend Bound, Backend
       Bound, Bad Speculation, Retiring). Its rows are grouped by the
       category's Level 2 categories, in the order they are defined.
    3. Miscellaneous metrics (IPC excluded, since it is in table 1).

    Args:
        cpu_metrics_df: single-row DataFrame. Its first two columns are
            skipped. The remaining columns are named "<MetricID>#<report
            position>" and hold numeric metric values.
        reports_df: DataFrame with a "Report" column. Rows are looked up by
            position using the "<report position>" suffix of each metric
            column name.

    Returns:
        A list of display objects, where each Styler table is preceded by its
        Markdown section header (the first table has no header). If no table
        can be built, returns the Markdown "No metrics to display." message.
    """
    metric_id_col = "MetricID"

    # Transpose the single row into one (Metric, Value) row per column.
    t_df = cpu_metrics_df.T.reset_index()
    t_df.columns = ["Metric", "Value"]
    # The first two columns of cpu_metrics_df are not metric values.
    # NOTE(review): presumably range-identifying columns -- confirm upstream.
    t_df = t_df.drop(index=[0, 1]).reset_index(drop=True)

    t_df["Value"] = t_df["Value"].apply(lambda x: f"{round(x, 2):.2f}")

    # Convert the report position explicitly instead of relying on iloc/numpy
    # implicitly casting digit strings to integer positions.
    report_nums = t_df["Metric"].apply(lambda x: int(_split_metric_column(x, 1)))
    t_df["Report"] = reports_df.iloc[report_nums]["Report"].reset_index(drop=True)

    metric_ids = t_df["Metric"].apply(lambda x: _split_metric_column(x, 0))
    t_df[metric_id_col] = metric_ids

    arm_metrics = am.get_arm_metrics()

    def get_metric_attr(metric_id, attr):
        return getattr(arm_metrics[am.PerfMetricType.from_name(metric_id)], attr)

    t_df["Metric"] = metric_ids.apply(lambda x: get_metric_attr(x, "name"))
    t_df["Description"] = metric_ids.apply(lambda x: get_metric_attr(x, "description"))
    t_df["Equation"] = metric_ids.apply(lambda x: get_metric_attr(x, "equation"))

    l1_categories = am.get_topdown_l1_categories()
    l2_categories = am.get_topdown_l2_categories()

    # (Markdown header or None, DataFrame) pairs, in display order.
    display_obj_pairs = []

    # IPC is shown in the top table with the Topdown L1 categories and is
    # therefore excluded from the Miscellaneous table below.
    ipc_metric_id = am.PerfMetricType.IPC.name
    l1_metric_ids = [ipc_metric_id] + [x.metric_type.name for x in l1_categories]

    l1_df = _select_metrics(t_df, metric_id_col, l1_metric_ids)
    if not l1_df.empty:
        display_obj_pairs.append((None, l1_df))

    for l1_category in l1_categories:
        l2_dfs = []
        for l2_type in l1_category.l2_types:
            l2_category = l2_categories[l2_type]
            l2_metric_ids = [mt.name for mt in l2_category.metric_types]
            l2_category_df = _select_metrics(t_df, metric_id_col, l2_metric_ids)
            if l2_category_df.empty:
                continue
            l2_category_df["Category"] = l2_category.name
            l2_dfs.append(l2_category_df)

        if l2_dfs:
            display_obj_pairs.append(
                (
                    Markdown(f"### {l1_category.name}"),
                    pd.concat(l2_dfs).reset_index(drop=True),
                )
            )

    misc_metric_ids = [
        x.name for x in am.get_misc_metric_types() if x.name != ipc_metric_id
    ]
    misc_df = t_df[t_df[metric_id_col].isin(misc_metric_ids)].reset_index(drop=True)
    if not misc_df.empty:
        display_obj_pairs.append((Markdown("### Miscellaneous"), misc_df))

    if not display_obj_pairs:
        return create_no_metrics_msg()

    display_objects = []
    for markdown, df in display_obj_pairs:
        if markdown is not None:
            display_objects.append(markdown)
        display_objects.append(_style_metrics_table(df))

    return display_objects


def generate_warnings_view(warnings_df):
    """
    Render ``warnings_df["Message"]`` as a Markdown "Warnings" section.

    Each message becomes its own paragraph, prefixed with a warning sign.
    """
    messages = warnings_df["Message"].tolist()
    paragraphs = "\n\n".join(map("\u26a0 {}".format, messages))
    return Markdown(f"## Warnings\n{paragraphs}")
