#!/usr/bin/env python3
"""
ibidi_tracks_plot.py

ibidi-style chemotaxis trajectory plot using explicit ReservoirLocation mapping.

Copyright 2025 MetaVi Labs Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Image coordinate system (input):
  - Origin: top-left of image
  - +X right
  - +Y down

ibidi coordinate system (output):
  - +Y = toward chemoattractant (parallel to gradient)
  - +X = perpendicular to gradient, 90° clockwise from +Y as drawn in the plot (i.e. to the right)

ReservoirLocation matches the C# enum semantics exactly.
"""

from __future__ import annotations
import argparse
from enum import Enum
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


# -----------------------------
# ReservoirLocation enum (Python mirror of C#)
# -----------------------------
class ReservoirLocation(Enum):
    """Image side on which the chemoattractant reservoir lies.

    Mirrors the C# enum of the same name. Every member's value equals its name,
    so both ``ReservoirLocation["Top"]`` and ``ReservoirLocation("Top")`` resolve
    to the same member. Member order is significant: it defines the order of the
    CLI ``--reservoir`` choices.
    """

    # Comments give the image-space direction toward the reservoir
    # (image +X = right, +Y = down).
    Bottom = "Bottom"  # +Y
    Top = "Top"  # -Y
    Right = "Right"  # +X
    Left = "Left"  # -X


# -----------------------------
# File parsing
# -----------------------------
# Columns read from the ibidi Tracks export; any other columns are ignored.
REQUIRED_COLS = ["Track n", "Slice n", "x", "y"]


def load_ibidi_tracks(path: Path) -> pd.DataFrame:
    """Read an ibidi Tracks export and keep only the columns this script needs.

    ibidi exports are tab-separated and their header names contain spaces, so
    surrounding whitespace is stripped from every header before matching it
    against REQUIRED_COLS.

    Returns:
        DataFrame with exactly the REQUIRED_COLS columns (in that order). Rows
        with a missing value in any of them are dropped. "Track n" and
        "Slice n" are cast to int, and "x" and "y" to float.

    Raises:
        ValueError: if any required column is absent from the header.
    """
    table = pd.read_csv(path, sep="\t", engine="python")
    table.columns = [header.strip() for header in table.columns]

    absent = [name for name in REQUIRED_COLS if name not in table.columns]
    if absent:
        raise ValueError(f"Missing columns {absent}. Found: {list(table.columns)}")

    table = table[REQUIRED_COLS].dropna()
    for column, target_type in (
        ("Track n", int),
        ("Slice n", int),
        ("x", float),
        ("y", float),
    ):
        table[column] = table[column].astype(target_type)

    return table


def build_tracks(df: pd.DataFrame) -> Dict[int, np.ndarray]:
    """Group rows into per-track point arrays ordered by frame.

    Returns:
        Mapping of int track id -> (N, 2) float array of (x, y) rows. Rows are
        sorted by "Slice n". Tracks with fewer than two points are omitted
        because they carry no displacement.
    """
    return {
        int(track_id): group.sort_values("Slice n")[["x", "y"]].to_numpy(float)
        for track_id, group in df.groupby("Track n")
        if len(group) >= 2
    }


# -----------------------------
# Normalization
# -----------------------------
def normalize_tracks(tracks: Dict[int, np.ndarray], mode: str) -> Dict[int, np.ndarray]:
    """Translate tracks so that the plot origin is meaningful.

    Args:
        tracks: track id -> (N, 2) array of points.
        mode: ``"per_track"`` shifts each track so its own first point is (0, 0).
            ``"com_start"`` subtracts the mean of all tracks' first points, so
            the centre of mass of the start points lands on (0, 0) while the
            tracks keep their relative positions.

    Returns:
        New dict with the same keys; the input arrays are not modified. An
        empty input yields an empty dict in either mode.

    Raises:
        ValueError: if ``mode`` is not one of the two supported values.
    """
    if mode == "per_track":
        return {tid: pts - pts[0] for tid, pts in tracks.items()}

    if mode == "com_start":
        if not tracks:
            # np.vstack([]) raises; with no tracks there is nothing to shift.
            return {}
        starts = np.vstack([pts[0] for pts in tracks.values()])
        com = starts.mean(axis=0)
        return {tid: pts - com for tid, pts in tracks.items()}

    raise ValueError(f"origin must be per_track or com_start, got {mode!r}")


# -----------------------------
# Coordinate transform (IMAGE -> IBIDI)
# -----------------------------
def forward_vector_image(res: ReservoirLocation) -> np.ndarray:
    """Return the unit vector, in image coordinates, pointing toward the reservoir.

    Image coordinates have +X to the right and +Y downward, so ``Bottom`` maps
    to (0, +1) and ``Top`` to (0, -1).

    Raises:
        ValueError: if ``res`` compares equal to none of the four members.
    """
    # (location, (dx, dy)) pairs, checked in the same order as the enum members.
    directions = (
        (ReservoirLocation.Bottom, (0.0, 1.0)),
        (ReservoirLocation.Top, (0.0, -1.0)),
        (ReservoirLocation.Right, (1.0, 0.0)),
        (ReservoirLocation.Left, (-1.0, 0.0)),
    )
    for location, (dx, dy) in directions:
        if res == location:
            # A fresh array per call, so callers may mutate the result safely.
            return np.array([dx, dy])
    raise ValueError(res)


def image_to_ibidi_matrix(res: ReservoirLocation) -> np.ndarray:
    """
    Build the 2x2 matrix that maps image displacements to ibidi coordinates.

    For a displacement ``d`` in image coordinates (+X right, +Y down):
      Y_ibidi = dot(d, forward)   # forward points toward the reservoir
      X_ibidi = dot(d, perp)

    ``perp`` is ``forward`` rotated 90° clockwise *as seen on screen*. Image +Y
    points down, so an on-screen clockwise rotation of (fx, fy) is (-fy, fx).
    For example, "up" (0, -1) becomes "right" (1, 0).

    The ibidi plot is drawn with +Y up and +X to the right, so +X is 90°
    clockwise from +Y there too. This makes the plot a pure rotation of the
    image, never a mirror image. For example, with a Top reservoir the mapping
    is X = dx, Y = -dy.

    Returns:
        2x2 array whose rows are ``perp`` and ``forward``. Apply it as ``T @ d``
        with ``d`` as a column vector.
    """
    f = forward_vector_image(res)
    # (-fy, fx), not (fy, -fx): in y-down image coordinates (fy, -fx) is a
    # counter-clockwise turn on screen and would mirror the plot left/right.
    perp = np.array([-f[1], f[0]])
    return np.vstack([perp, f])  # rows


def apply_transform(tracks: Dict[int, np.ndarray], T: np.ndarray) -> Dict[int, np.ndarray]:
    """Apply the 2x2 matrix ``T`` to every point of every track.

    Each track is expected to be an (N, 2) array of row vectors, as produced by
    build_tracks. Points are transformed as column vectors (``T @ p``) and
    returned as (N, 2) rows. Track ids are preserved, and a new dict is
    returned.
    """
    transformed: Dict[int, np.ndarray] = {}
    for track_id, xy in tracks.items():
        # Transpose to columns, multiply, then transpose back to rows.
        transformed[track_id] = np.matmul(T, xy.T).T
    return transformed


# -----------------------------
# Plotting
# -----------------------------
def compute_M_end(tracks: Dict[int, np.ndarray]) -> np.ndarray:
    """Return M_end, the component-wise mean of the tracks' last points.

    Raises ValueError (from ``np.vstack``) when ``tracks`` is empty.
    """
    final_points = np.vstack(tuple(xy[-1] for xy in tracks.values()))
    return np.mean(final_points, axis=0)


def set_square_limits(ax, pts: np.ndarray, pad=0.15):
    """Give ``ax`` equal, origin-centred x and y limits enclosing every point.

    The half-width is the largest absolute coordinate in ``pts``, enlarged by
    the fractional ``pad``. When that value is not positive (all points at the
    origin, or NaN), a half-width of 1.0 is used instead.
    """
    extent = np.max(np.abs(pts))
    if not extent > 0:
        half = 1.0
    else:
        half = extent * (1 + pad)
    for set_limits in (ax.set_xlim, ax.set_ylim):
        set_limits(-half, half)


def plot_ibidi(tracks: Dict[int, np.ndarray], title: str | None, out: Path, show: bool):
    """
    Draw an ibidi-style trajectory plot and save it to ``out``.

    Each track is drawn as a line, with its endpoint marked. M_end (the mean
    endpoint) is joined to the origin M_start by a dashed line, and dotted lines
    project it onto both axes.

    Args:
        tracks: track id -> (N, 2) array of points in ibidi coordinates
            (column 0 = X perpendicular, column 1 = Y toward chemoattractant).
        title: optional plot title.
        out: output image path (saved at 200 dpi).
        show: if True, also display the figure interactively before closing it.

    Raises:
        ValueError: if ``tracks`` is empty.
    """
    if not tracks:
        # Without this check, compute_M_end fails with an opaque NumPy error.
        raise ValueError("plot_ibidi: no tracks to plot")

    M_end = compute_M_end(tracks)

    fig, ax = plt.subplots(figsize=(6.2, 5.2))
    try:
        # Axes through the origin.
        ax.axhline(0, color="black", lw=1)
        ax.axvline(0, color="black", lw=1)

        for pts in tracks.values():
            ax.plot(pts[:, 0], pts[:, 1], color="crimson", lw=1.7)

        endpoints = np.vstack([pts[-1] for pts in tracks.values()])
        all_pts = np.vstack(list(tracks.values()))

        ax.scatter(endpoints[:, 0], endpoints[:, 1], s=22, color="0.35")
        ax.scatter(0, 0, marker="x", s=70, color="0.1")
        ax.scatter(M_end[0], M_end[1], marker="+", s=180, color="dodgerblue", lw=2.2)

        # Dashed line from M_start (the origin) to M_end.
        ax.plot([0, M_end[0]], [0, M_end[1]],
                linestyle="--", lw=2.2, color="dodgerblue")

        # Dotted projections of M_end onto the X and Y axes.
        ax.plot([M_end[0], M_end[0]], [0, M_end[1]], ":", color="0.25")
        ax.plot([0, M_end[0]], [M_end[1], M_end[1]], ":", color="0.25")

        ax.set_xlabel("X (perpendicular)")
        ax.set_ylabel("Y (toward chemoattractant)")
        ax.set_aspect("equal")
        if title:
            ax.set_title(title)

        # Symmetric limits that include every track point, M_end and the origin.
        set_square_limits(ax, np.vstack([all_pts, M_end, [0, 0]]))

        legend = [
            Line2D([0], [0], color="crimson", lw=2),
            Line2D([0], [0], color="dodgerblue", lw=2, ls="--"),
            Line2D([0], [0], marker="o", color="0.35", lw=0),
            Line2D([0], [0], marker="+", color="dodgerblue", lw=0),
            Line2D([0], [0], marker="x", color="0.1", lw=0),
        ]
        labels = [
            "Accumulated distance (cell path)",
            "Euclidean distance",
            "Endpoint",
            "M_end (center of mass)",
            "M_start = (0,0)",
        ]
        ax.legend(legend, labels, loc="lower left", fontsize=8, frameon=False)

        # Operate on this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
        fig.savefig(out, dpi=200)
        if show:
            plt.show()
    finally:
        # Release the figure even if drawing or saving raised.
        plt.close(fig)


# -----------------------------
# Main
# -----------------------------
def main():
    """CLI entry point: load tracks, rotate them into ibidi coordinates, and plot them."""
    ap = argparse.ArgumentParser(
        description="Plot ibidi-style chemotaxis trajectories from an ibidi Tracks TSV file."
    )
    ap.add_argument("input", help="ibidi Tracks TSV file")
    ap.add_argument("--reservoir", required=True,
                    choices=[e.name for e in ReservoirLocation],
                    help="ReservoirLocation enum value")
    ap.add_argument("--origin", default="per_track",
                    choices=["per_track", "com_start"],
                    help="per_track: shift every track to start at (0,0); "
                         "com_start: subtract the mean start point of all tracks")
    ap.add_argument("--out", default="ibidi_plot.png", help="output image path")
    ap.add_argument("--title", default=None, help="optional plot title")
    ap.add_argument("--no-show", action="store_true",
                    help="only save the plot; do not open an interactive window")
    args = ap.parse_args()

    res = ReservoirLocation[args.reservoir]

    df = load_ibidi_tracks(Path(args.input))
    tracks = build_tracks(df)
    if not tracks:
        # build_tracks drops tracks with fewer than two points. An empty result
        # would otherwise fail later with an opaque NumPy concatenate error.
        raise SystemExit(f"No tracks with at least 2 points found in {args.input}")
    tracks = normalize_tracks(tracks, args.origin)

    T = image_to_ibidi_matrix(res)
    tracks_ibidi = apply_transform(tracks, T)

    plot_ibidi(tracks_ibidi, args.title, Path(args.out), not args.no_show)


if __name__ == "__main__":
    main()
