#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Mar 15 15:43:31 2026

@author: ozanari
"""

import numpy as np
import streamlit as st
import plotly.graph_objects as go


# ============================================================
# Constants
# ============================================================
D_GS = 2.87   # GHz, ground-state zero-field splitting
D_ES = 1.42   # GHz, simplified excited-state zero-field splitting
GAMMA_E = 28.0  # GHz/T, electron gyromagnetic ratio
ZPL_WL = 637.0  # nm


# ============================================================
# Helpers
# ============================================================
def deg2rad(x_deg: float) -> float:
    """Convert an angle from degrees to radians.

    Thin wrapper around the ``np.deg2rad`` ufunc, so array inputs are
    converted element-wise as well.
    """
    radians = np.deg2rad(x_deg)
    return radians


def spherical_to_unit(theta_deg: float, phi_deg: float) -> np.ndarray:
    """
    Return the 3-D unit vector for spherical angles given in degrees.

    ``theta_deg`` is the polar angle measured from +z and ``phi_deg`` the
    azimuth measured from +x in the x-y plane. The result is re-normalized;
    if its norm were ever zero, +z is returned instead.
    """
    polar = np.deg2rad(theta_deg)
    azimuth = np.deg2rad(phi_deg)
    sin_polar = np.sin(polar)

    components = [
        sin_polar * np.cos(azimuth),
        sin_polar * np.sin(azimuth),
        np.cos(polar),
    ]
    vec = np.array(components, dtype=float)

    length = np.linalg.norm(vec)
    return np.array([0.0, 0.0, 1.0]) if length == 0 else vec / length


def compute_b_parallel(B_T: float, B_theta_deg: float, B_phi_deg: float,
                       NV_theta_deg: float, NV_phi_deg: float) -> tuple[float, np.ndarray, np.ndarray]:
    """
    Project an external magnetic field onto the NV axis.

    Both directions are given as spherical angles in degrees.

    Returns
    -------
    (B_parallel_T, B_hat, NV_hat)
        The signed field component along the NV axis (same units as ``B_T``),
        the field unit vector, and the NV-axis unit vector.
    """
    field_dir = spherical_to_unit(B_theta_deg, B_phi_deg)
    axis_dir = spherical_to_unit(NV_theta_deg, NV_phi_deg)
    cos_angle = float(np.dot(field_dir, axis_dir))
    return B_T * cos_angle, field_dir, axis_dir


def ground_state_levels(B_parallel_T: float) -> dict:
    """
    Ground-state triplet energies in GHz, keyed by m_s in (-1, 0, +1).

    m_s = 0 is the energy zero; m_s = ±1 sit at D_GS ± GAMMA_E * B_parallel_T.
    """
    zeeman = GAMMA_E * B_parallel_T
    return {
        m_s: (D_GS + m_s * zeeman if m_s else 0.0)
        for m_s in (-1, 0, +1)
    }


def excited_state_levels(B_parallel_T: float) -> dict:
    """
    Simplified excited-state triplet energies in GHz, keyed by m_s.

    Uses the excited-state splitting D_ES with the same linear Zeeman shift
    GAMMA_E * B_parallel_T as the ground state; m_s = 0 is the energy zero.
    """
    zeeman = GAMMA_E * B_parallel_T
    lower = D_ES - zeeman
    upper = D_ES + zeeman
    return {-1: lower, 0: 0.0, +1: upper}


def simulate_pl_spectrum(wavelengths_nm: np.ndarray, laser_wl_nm: float) -> np.ndarray:
    """
    Phenomenological NV- photoluminescence spectrum.

    The emission profile is a narrow zero-phonon line (ZPL) at 637 nm plus
    three Gaussian phonon-sideband components. The profile is normalized to
    unit peak and then scaled by the laser's excitation efficiency relative
    to 532 nm excitation. A 532 nm laser therefore yields a peak of 1.0, and
    off-resonant lasers yield a proportionally weaker spectrum of the same shape.

    Parameters
    ----------
    wavelengths_nm : np.ndarray
        Emission wavelengths (nm) at which to evaluate the spectrum.
    laser_wl_nm : float
        Excitation laser wavelength (nm).

    Returns
    -------
    np.ndarray
        PL intensity (a.u.) with the same shape as ``wavelengths_nm``.
        Empty input gives an empty array. Wavelengths so far from the band
        that every component underflows give zeros instead of NaN.
    """
    wl = np.asarray(wavelengths_nm, dtype=float)

    def gaussian(x, center, sigma):
        return np.exp(-0.5 * ((x - center) / sigma) ** 2)

    def excitation_efficiency(laser_nm):
        # Main absorption band near 532 nm, a weaker 594 nm shoulder,
        # and a constant floor so no laser gives exactly zero.
        return (
            1.00 * gaussian(laser_nm, 532.0, 22.0)
            + 0.20 * gaussian(laser_nm, 594.0, 18.0)
            + 0.08
        )

    # ZPL at 637 nm (same value as ZPL_WL) plus phonon sidebands.
    emission = (
        1.00 * gaussian(wl, 637.0, 2.2)
        + 0.80 * gaussian(wl, 680.0, 18.0)
        + 0.55 * gaussian(wl, 715.0, 28.0)
        + 0.22 * gaussian(wl, 760.0, 35.0)
    )

    if emission.size == 0:
        return emission
    peak = np.max(emission)
    if peak <= 0:
        return np.zeros_like(emission)

    # Normalize the emission *shape* first and apply the excitation
    # efficiency afterwards. Multiplying by the scalar efficiency before
    # normalizing would cancel it out and make laser_wl_nm irrelevant.
    relative_eff = excitation_efficiency(laser_wl_nm) / excitation_efficiency(532.0)
    return (emission / peak) * relative_eff


def lorentzian(x: np.ndarray, x0: float, hwhm: float) -> np.ndarray:
    """
    Unit-peak Lorentzian centered at ``x0``.

    ``hwhm`` is the half-width at half-maximum: the value is 1.0 at ``x0``
    and 0.5 at ``x0 ± hwhm``.
    """
    scaled = (x - x0) / hwhm
    return 1.0 / (1.0 + scaled ** 2)


def simulate_odmr(freqs_GHz: np.ndarray,
                  B_parallel_T: float,
                  linewidth_MHz: float,
                  contrast: float) -> tuple[np.ndarray, float, float]:
    """
    Simulate a CW-ODMR trace: normalized PL versus microwave frequency.

    The two ground-state transitions sit at D_GS ∓ GAMMA_E * B_parallel_T.
    Each contributes a Lorentzian dip with half the total ``contrast``, so
    coincident dips (B_parallel_T = 0) reach a depth of ``contrast``.

    Returns
    -------
    (pl, f_minus, f_plus)
        PL values on ``freqs_GHz`` and the two resonance frequencies in GHz.
    """
    zeeman_GHz = GAMMA_E * B_parallel_T
    f_minus, f_plus = D_GS - zeeman_GHz, D_GS + zeeman_GHz

    # The linewidth setting is used directly as the Lorentzian HWHM.
    width_GHz = linewidth_MHz * 1e-3
    weighted_dips = (0.5 * lorentzian(freqs_GHz, f_minus, width_GHz)
                     + 0.5 * lorentzian(freqs_GHz, f_plus, width_GHz))

    return 1.0 - contrast * weighted_dips, f_minus, f_plus


def direction_text(v: np.ndarray) -> str:
    """Format the first three components of ``v`` as ``[+x.xxx, +y.yyy, +z.zzz]``."""
    parts = ", ".join(format(v[i], "+.3f") for i in range(3))
    return "[" + parts + "]"


# ============================================================
# Plot helpers
# ============================================================
def add_level(fig: go.Figure, x0: float, x1: float, y: float,
              color: str, width: float = 4.0):
    """
    Draw one horizontal energy-level bar on ``fig``, spanning x0..x1 at height y.

    The bar is added as a line trace with no legend entry and no hover text.
    The figure is modified in place; nothing is returned.
    """
    bar = go.Scatter(
        x=[x0, x1],
        y=[y] * 2,
        mode="lines",
        line={"color": color, "width": width},
        showlegend=False,
        hoverinfo="skip",
    )
    fig.add_trace(bar)


def merge_or_split_levels(center_y: float,
                          zeeman_split_GHz: float,
                          vis_scale: float,
                          merge_threshold_GHz: float,
                          merged_gap: float,
                          min_split_visible: float,
                          max_split_visible: float | None = None) -> tuple[float, float, bool]:
    """
    Returns y_minus, y_plus, merged_flag.

    If Zeeman splitting is very small, show a single visually merged upper pair.
    Otherwise split into m_s=-1 and m_s=+1 with exaggerated visible spacing.
    """
    if abs(zeeman_split_GHz) < merge_threshold_GHz:
        y = center_y + merged_gap
        return y, y, True

    dy = max(vis_scale * abs(zeeman_split_GHz), min_split_visible)
    if max_split_visible is not None:
        dy = min(dy, max_split_visible)
    return center_y + merged_gap - dy / 2, center_y + merged_gap + dy / 2, False


# ============================================================
# Main plots
# ============================================================
def make_energy_level_figure(gs: dict, es: dict, f_minus: float, f_plus: float,
                             laser_wl_nm: float, mw_start: float, mw_end: float) -> go.Figure:
    """
    Literature-style NV- schematic with cleaner alignment and formatted labels.

    The figure contains:
    - the ground (³A2) and excited (³E) spin triplets;
    - the singlet dark states (¹A1, ¹E);
    - optical pump and PL arrows;
    - the two MW transitions;
    - the intersystem-crossing / dark-state return pathway.

    Vertical positions are schematic. Only the sign and the exaggerated,
    clamped size of each Zeeman splitting come from the inputs.

    Parameters
    ----------
    gs, es : dict
        Ground- and excited-state energies (GHz) keyed by m_s in {-1, 0, +1}.
        Only E(+1) - E(-1) is used for drawing.
    f_minus, f_plus : float
        ODMR resonance frequencies (GHz), printed next to the MW arrows.
    laser_wl_nm : float
        Laser wavelength printed in the excitation legend.
    mw_start, mw_end : float
        Not used by the drawing; kept for call-site compatibility.

    Returns
    -------
    go.Figure
    """
    fig = go.Figure()

    colors = {-1: "#1f77b4", 0: "#2ca02c", +1: "#d62728"}
    gray1 = "#666666"
    gray2 = "#4d4d4d"

    # ---------- Geometry (schematic data coordinates) ----------
    xL0, xL1 = 2.45, 5.05   # horizontal extent of the triplet level bars
    y_g0 = 1.20             # ground-state m_s = 0
    y_e0 = 7.00             # excited-state m_s = 0
    g_gap = 1.30            # m_s = 0 -> centre of the ground m_s = ±1 pair
    e_gap = 1.02            # m_s = 0 -> centre of the excited m_s = ±1 pair

    xS0, xS1 = 6.95, 8.00   # horizontal extent of the singlet level bars
    y_1A1 = 4.85
    y_1E = 3.72

    # Signed splittings E(+1) - E(-1), exaggerated and clamped for visibility.
    g_zeeman = gs[+1] - gs[-1]
    e_zeeman = es[+1] - es[-1]
    y_gm, y_gp, g_merged = merge_or_split_levels(
        y_g0, g_zeeman, vis_scale=18.0, merge_threshold_GHz=0.001,
        merged_gap=g_gap, min_split_visible=0.28, max_split_visible=0.95
    )
    y_em, y_ep, e_merged = merge_or_split_levels(
        y_e0, e_zeeman, vis_scale=14.0, merge_threshold_GHz=0.001,
        merged_gap=e_gap, min_split_visible=0.22, max_split_visible=0.75
    )

    # ---------- Levels ----------
    add_level(fig, xL0, xL1, y_g0, colors[0])
    add_level(fig, xL0, xL1, y_gm, colors[-1])
    if not g_merged:
        add_level(fig, xL0, xL1, y_gp, colors[+1])

    add_level(fig, xL0, xL1, y_e0, colors[0])
    add_level(fig, xL0, xL1, y_em, colors[-1])
    if not e_merged:
        add_level(fig, xL0, xL1, y_ep, colors[+1])

    add_level(fig, xS0, xS1, y_1A1, gray1)
    add_level(fig, xS0, xS1, y_1E, gray2)

    # ---------- Titles ----------
    fig.add_annotation(x=(xL0 + xL1) / 2, y=0.30, text="Ground state (³A<sub>2</sub>)", showarrow=False, font=dict(size=16))
    fig.add_annotation(x=(xL0 + xL1) / 2, y=8.55, text="Excited state (³E)", showarrow=False, font=dict(size=16))
    fig.add_annotation(x=(xS0 + xS1) / 2, y=5.78, text="Singlet dark states", showarrow=False, font=dict(size=16))

    # ---------- State labels ----------
    x_label_left = xL0 - 0.24
    fig.add_annotation(x=x_label_left, y=y_g0, text="m<sub>s</sub> = 0", showarrow=False, xanchor="right", font=dict(size=16, color=colors[0]))
    fig.add_annotation(x=x_label_left, y=y_e0, text="m<sub>s</sub> = 0", showarrow=False, xanchor="right", font=dict(size=16, color=colors[0]))

    def add_pair_labels(y_minus: float, y_plus: float, merged: bool) -> None:
        """Label an m_s = -1 / +1 pair, or one "±1" label when merged."""
        if merged:
            fig.add_annotation(x=x_label_left, y=y_minus, text="m<sub>s</sub> = ±1", showarrow=False, xanchor="right", font=dict(size=16, color="#444444"))
            return
        # Nudge each label 0.06 away from its partner. The direction follows
        # the drawn vertical order of the two bars (which may be reversed,
        # e.g. when E(+1) < E(-1)), so labels are never pushed into each other.
        nudge = 0.06 if y_plus >= y_minus else -0.06
        fig.add_annotation(x=x_label_left, y=y_minus - nudge, text="m<sub>s</sub> = -1", showarrow=False, xanchor="right", font=dict(size=16, color=colors[-1]))
        fig.add_annotation(x=x_label_left, y=y_plus + nudge, text="m<sub>s</sub> = +1", showarrow=False, xanchor="right", font=dict(size=16, color=colors[+1]))

    add_pair_labels(y_gm, y_gp, g_merged)
    add_pair_labels(y_em, y_ep, e_merged)

    fig.add_annotation(x=xS1 + 0.14, y=y_1A1, text="<sup>1</sup>A<sub>1</sub>", showarrow=False, xanchor="left", font=dict(size=16, color=gray1))
    fig.add_annotation(x=xS1 + 0.14, y=y_1E, text="<sup>1</sup>E", showarrow=False, xanchor="left", font=dict(size=16, color=gray2))

    # ---------- Optical excitation / PL ----------
    optical_x = [2.95, 3.75, 4.55]
    g_targets = [y_gm, y_g0, y_gp if not g_merged else y_gm]
    e_targets = [y_em, y_e0, y_ep if not e_merged else y_em]

    # Green arrows for laser excitation (whatever laser_wl_nm is), red for fluorescence.
    pump_color = "#2ca02c"
    pl_color = "#cc3344"

    for x_arrow, yg, ye in zip(optical_x, g_targets, e_targets):
        fig.add_annotation(
            x=x_arrow, y=ye - 0.03, ax=x_arrow, ay=yg + 0.03,
            xref="x", yref="y", axref="x", ayref="y",
            showarrow=True, arrowhead=3, arrowsize=1.12,
            arrowwidth=2.5, arrowcolor=pump_color
        )

    pl_x = [3.18, 3.98, 4.78]
    for x_arrow, ye, yg in zip(pl_x, e_targets, g_targets):
        fig.add_annotation(
            x=x_arrow, y=yg + 0.03, ax=x_arrow, ay=ye - 0.03,
            xref="x", yref="y", axref="x", ayref="y",
            showarrow=True, arrowhead=3, arrowsize=1.12,
            arrowwidth=2.5, arrowcolor=pl_color
        )

    fig.add_annotation(
        x=0.92, y=5.02,
        text=f"Laser excitation ({laser_wl_nm:.1f} nm)",
        showarrow=False, xanchor="left",
        font=dict(size=16, color=pump_color),
        bgcolor="rgba(255,255,255,0.94)"
    )
    fig.add_annotation(
        x=0.92, y=4.40,
        text="Fluorescence / PL",
        showarrow=False, xanchor="left",
        font=dict(size=16, color=pl_color),
        bgcolor="rgba(255,255,255,0.94)"
    )

    # ---------- MW arrows ----------
    # Arrows start just above ground m_s = 0 and end on the m_s = -1 / +1 bars.
    # The two frequency labels sit on either side of the arrow pair.
    mw_specs = [
        # (arrow x, target y, label, label y, label x-anchor, label x offset)
        (1.45, y_gm, f"MW: {f_minus:.4f} GHz", 1.86, "right", -0.08),
        (1.68, y_gp if not g_merged else y_gm, f"MW: {f_plus:.4f} GHz", 1.86, "left", 0.08),
    ]

    for xmw, ytarget, label, ytext, anchor, dx in mw_specs:
        fig.add_annotation(
            x=xmw,
            y=ytarget - 0.01,
            ax=xmw,
            ay=y_g0 + 0.06,
            xref="x", yref="y", axref="x", ayref="y",
            showarrow=True,
            arrowhead=2,
            arrowsize=1.05,
            arrowwidth=2.2,
            arrowcolor="orange",
        )
        fig.add_annotation(
            x=xmw + dx,
            y=ytext,
            text=label,
            showarrow=False,
            xanchor=anchor,
            font=dict(size=14, color="orange"),
            bgcolor="rgba(255,255,255,0.96)",
        )

    # ---------- Dark-state pathway ----------
    # Three aligned arrows from excited manifold to 1A1, one arrow 1A1->1E, one return arrow 1E->ground m_s=0.
    isc_source_x = [5.12, 5.18, 5.24]
    isc_source_y = [y_em, y_e0, y_ep if not e_merged else y_em]
    isc_target_x = [xS0, xS0, xS0]
    isc_target_y = [y_1A1 + 0.10, y_1A1, y_1A1 - 0.10]

    for xs, ys, xt, yt in zip(isc_source_x, isc_source_y, isc_target_x, isc_target_y):
        fig.add_annotation(
            x=xt, y=yt,
            ax=xs, ay=ys,
            xref="x", yref="y", axref="x", ayref="y",
            showarrow=True, arrowhead=2, arrowsize=1.02,
            arrowwidth=2.2, arrowcolor=gray1
        )

    x_s_relax = 7.45
    fig.add_annotation(
        x=x_s_relax, y=y_1E,
        ax=x_s_relax, ay=y_1A1,
        xref="x", yref="y", axref="x", ayref="y",
        showarrow=True, arrowhead=2, arrowsize=1.0,
        arrowwidth=2.0, arrowcolor=gray2
    )

    fig.add_annotation(
        x=5.20, y=y_g0 + 0.02,
        ax=xS0, ay=y_1E,
        xref="x", yref="y", axref="x", ayref="y",
        showarrow=True, arrowhead=2, arrowsize=1.04,
        arrowwidth=2.2, arrowcolor=gray2
    )

    fig.update_layout(
        title="NV- Electronic Levels and ODMR Mechanism",
        xaxis=dict(visible=False, range=[0.6, 8.45]),
        yaxis=dict(visible=False, range=[0.0, 8.90]),
        height=660,
        margin=dict(l=20, r=20, t=60, b=20),
        template="plotly_white",
    )
    return fig


def make_pl_figure(wavelengths_nm: np.ndarray, pl_vals: np.ndarray, laser_wl_nm: float) -> go.Figure:
    """
    Plot the PL spectrum with vertical markers for the laser line and the ZPL.

    The laser marker is a dashed green line and the ZPL marker (at ZPL_WL) a
    dotted crimson line, each with its own text annotation.
    """
    spectrum_trace = go.Scatter(
        x=wavelengths_nm,
        y=pl_vals,
        mode="lines",
        name="NV- PL",
        line=dict(width=3),
    )
    fig = go.Figure()
    fig.add_trace(spectrum_trace)

    markers = (
        dict(
            x=laser_wl_nm,
            line_dash="dash",
            line_color="green",
            annotation_text=f"Laser {laser_wl_nm:.1f} nm",
            annotation_position="top left",
        ),
        dict(
            x=ZPL_WL,
            line_dash="dot",
            line_color="crimson",
            annotation_text="ZPL ~ 637 nm",
            annotation_position="top right",
        ),
    )
    for marker in markers:
        fig.add_vline(line_width=2, **marker)

    fig.update_layout(
        title="PL Emission from NV- Center",
        xaxis_title="Wavelength (nm)",
        yaxis_title="Normalized PL (a.u.)",
        height=420,
        template="plotly_white",
        margin=dict(l=20, r=20, t=60, b=20),
    )
    return fig


def make_odmr_figure(freqs_GHz: np.ndarray, odmr_pl: np.ndarray,
                     f_minus: float, f_plus: float) -> go.Figure:
    """
    Plot the ODMR trace (PL vs MW frequency) with dashed markers at f- and f+.
    """
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(x=freqs_GHz, y=odmr_pl, mode="lines", name="ODMR", line=dict(width=3))
    )

    resonances = (
        (f_minus, "f-", "bottom left"),
        (f_plus, "f+", "bottom right"),
    )
    for freq, tag, position in resonances:
        fig.add_vline(
            x=freq,
            line_width=2,
            line_dash="dash",
            annotation_text=f"{tag} = {freq:.4f} GHz",
            annotation_position=position,
        )

    fig.update_layout(
        title="ODMR: PL vs Microwave Frequency",
        xaxis_title="Microwave Frequency (GHz)",
        yaxis_title="Normalized PL (a.u.)",
        height=420,
        template="plotly_white",
        margin=dict(l=20, r=20, t=60, b=20),
    )
    return fig


# ============================================================
# Streamlit UI
# ============================================================
# Page setup and introduction; set_page_config is the first Streamlit call in the script.
st.set_page_config(page_title="NV Center Visualizer", layout="wide")

st.title("NV Center Interactive Visualizer")
st.write(
    "This app shows a simplified NV- model with ground-state and excited-state triplets, "
    "dark-state pathways, a phenomenological PL spectrum, and ODMR under an external magnetic field."
)

# Sidebar widgets. Each call returns the current setting; these module-level
# names feed the validation and simulation code further down.
with st.sidebar:
    st.header("Controls")

    st.subheader("Optical and Microwave")
    laser_wl_nm = st.slider("Laser wavelength (nm)", min_value=450.0, max_value=750.0, value=532.0, step=1.0)
    # MW sweep bounds in GHz; the code below rejects mw_end <= mw_start.
    mw_start = st.number_input("MW start (GHz)", min_value=0.1, max_value=20.0, value=2.60, step=0.01, format="%.3f")
    mw_end = st.number_input("MW end (GHz)", min_value=0.1, max_value=20.0, value=3.15, step=0.01, format="%.3f")
    # Sweep step in MHz; converted to GHz before the frequency grid is built.
    mw_step_MHz = st.number_input("MW step (MHz)", min_value=0.1, max_value=50.0, value=1.0, step=0.1, format="%.2f")

    st.subheader("Magnetic Field")
    B_uT = st.slider(
        "External magnetic field strength (µT)",
        min_value=0.0,
        max_value=200.0,
        value=50.0,
        step=1.0,
    )
    # µT -> T, because GAMMA_E is expressed in GHz/T.
    B_T = B_uT * 1e-6
    st.caption(f"External magnetic field strength: {B_uT:.1f} µT")
    # Field direction as spherical angles: theta from +z, phi from +x (see spherical_to_unit).
    B_theta_deg = st.slider("B-field theta (deg)", min_value=0.0, max_value=180.0, value=0.0, step=1.0)
    B_phi_deg = st.slider("B-field phi (deg)", min_value=0.0, max_value=360.0, value=0.0, step=1.0)

    st.subheader("NV Orientation")
    # NV quantization axis, same angle convention as the B-field direction.
    NV_theta_deg = st.slider("NV axis / dipole theta (deg)", min_value=0.0, max_value=180.0, value=0.0, step=1.0)
    NV_phi_deg = st.slider("NV axis / dipole phi (deg)", min_value=0.0, max_value=360.0, value=0.0, step=1.0)

    st.subheader("ODMR Appearance")
    # simulate_odmr converts this to GHz and uses it as the Lorentzian HWHM.
    linewidth_MHz = st.slider("ODMR linewidth (MHz)", min_value=0.5, max_value=30.0, value=6.0, step=0.5)
    # Total dip depth C in PL = 1 - C * (0.5 * L_minus + 0.5 * L_plus).
    contrast = st.slider("ODMR contrast", min_value=0.01, max_value=0.30, value=0.08, step=0.01)

# Validate the MW scan settings, then run all simulations for the current inputs.
if mw_end <= mw_start:
    st.error("MW end frequency must be greater than MW start frequency.")
    st.stop()

mw_step_GHz = mw_step_MHz * 1e-3
# Number of grid points from mw_start up to and including mw_end. The tiny
# tolerance absorbs floating-point round-off in span / step. For example,
# (3.15 - 2.60) / 0.001 evaluates to just under 550, and flooring it without
# the tolerance would silently drop the final frequency from the scan.
npts = int(np.floor((mw_end - mw_start) / mw_step_GHz + 1e-9)) + 1
if npts < 5:
    st.error("Microwave scan has too few points. Increase the scan range or reduce the step size.")
    st.stop()
if npts > 20000:
    st.error("Microwave scan has too many points. Increase the step size or reduce the scan range.")
    st.stop()

# Field projection and level energies (GHz) for the chosen geometry.
B_parallel_T, B_hat, NV_hat = compute_b_parallel(B_T, B_theta_deg, B_phi_deg, NV_theta_deg, NV_phi_deg)
gs = ground_state_levels(B_parallel_T)
es = excited_state_levels(B_parallel_T)

# PL spectrum on a fixed 600-800 nm grid.
wavelengths_nm = np.linspace(600.0, 800.0, 1200)
pl_spectrum = simulate_pl_spectrum(wavelengths_nm, laser_wl_nm)

# ODMR trace on an evenly spaced MW grid starting at mw_start.
freqs_GHz = mw_start + np.arange(npts) * mw_step_GHz
odmr_pl, f_minus, f_plus = simulate_odmr(freqs_GHz, B_parallel_T, linewidth_MHz, contrast)

# Render the model equations. Numeric values come from the module constants
# so the displayed text cannot drift from the simulation.
st.subheader("Model formulation")
form_col1, form_col2 = st.columns(2)
with form_col1:
    st.markdown(
        f"""
**Simplified ground-state spin Hamiltonian**

H = D S<sub>z</sub><sup>2</sup> + γ<sub>e</sub> B<sub>∥</sub> S<sub>z</sub>

where D = {D_GS:.2f} GHz, γ<sub>e</sub> ≈ {GAMMA_E:.0f} GHz/T, and B<sub>∥</sub> = <b>B · n</b><sub>NV</sub>.

**ODMR transition frequencies**

f<sub>+</sub> = D + γ<sub>e</sub>B<sub>∥</sub>

f<sub>-</sub> = D − γ<sub>e</sub>B<sub>∥</sub>
        """,
        unsafe_allow_html=True,
    )
with form_col2:
    st.markdown(
        f"""
**ODMR fluorescence model**

Δf = 2γ<sub>e</sub>B<sub>∥</sub>

PL(f) = 1 − C [½ L(f,f<sub>-</sub>) + ½ L(f,f<sub>+</sub>)]

L(f,f<sub>0</sub>) = 1 / (1 + ((f − f<sub>0</sub>)/Γ)<sup>2</sup>)

where C is the ODMR contrast and Γ is the Lorentzian half-width at half-maximum (the ODMR linewidth setting).

**Photoluminescence spectrum**

PL(λ) is modeled phenomenologically as a zero-phonon line near {ZPL_WL:.0f} nm plus a broad phonon sideband.
        """,
        unsafe_allow_html=True,
    )

# Build all three figures first, then lay them out: the level diagram spans
# the full width, and PL and ODMR share a two-column row below it.
energy_fig = make_energy_level_figure(
    gs, es, f_minus, f_plus, laser_wl_nm, mw_start, mw_end
)
pl_fig = make_pl_figure(wavelengths_nm, pl_spectrum, laser_wl_nm)
odmr_fig = make_odmr_figure(freqs_GHz, odmr_pl, f_minus, f_plus)

st.plotly_chart(energy_fig, use_container_width=True)

col1, col2 = st.columns(2)
for column, figure in ((col1, pl_fig), (col2, odmr_fig)):
    with column:
        st.plotly_chart(figure, use_container_width=True)

# Summary of constants and derived quantities for the current settings.
st.subheader("Derived Parameters")
splitting_MHz = (f_plus - f_minus) * 1e3
derived_lines = [
    f"**Ground-state ZFS:** {D_GS:.3f} GHz",
    f"**Excited-state ZFS (simplified):** {D_ES:.3f} GHz",
    f"**B magnitude:** {B_T * 1e6:.3f} µT",
    f"**B parallel to NV axis:** {B_parallel_T * 1e6:+.3f} µT",
    f"**NV axis unit vector:** {direction_text(NV_hat)}",
    f"**B-field unit vector:** {direction_text(B_hat)}",
    f"**ODMR resonance f-:** {f_minus:.6f} GHz",
    f"**ODMR resonance f+:** {f_plus:.6f} GHz",
    f"**Resonance splitting:** {splitting_MHz:.3f} MHz",
]
# Two trailing spaces before each newline force Markdown hard line breaks.
st.markdown("\n" + "  \n".join(derived_lines) + "\n    ")


