import { Dispatch, SetStateAction, useEffect, useRef, useState } from "react";
import * as d3 from "d3";
import {
  Box,
  Divider,
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableRow,
  Theme,
  Tooltip,
  Typography,
  useTheme,
} from "@mui/material";
import { Instance } from "@popperjs/core";
import { CvssScoreType } from "~/lib/score";
import { truncateText } from "~/lib/truncateText";
import {
  computeScore,
  Data,
  Datum,
  getEntries,
  GraphTarget,
  pathItemToId,
  pathToId,
  statusColorMap,
  StatusDomain,
  StatusValueMap,
  UseAssetUrlStats,
} from "~/hooks/useAssetUrlStats";

/**
 * Props for the `SunburstChart` component.
 */
export type SunburstChartProps = {
  /** Statistics (and the `setActivePath` setter) returned by `useAssetUrlStats`. */
  assetUrlStats: UseAssetUrlStats;
};

/**
 * Zoomable sunburst chart of asset URL statistics.
 *
 * The arcs are drawn imperatively by `DrawSunburst` into the `<svg>` below.
 * This component keeps the React-side state around that drawing:
 * - the hovered node, shown in the tooltip;
 * - whether the tooltip is open and where its anchor is;
 * - the node the chart is currently zoomed into.
 */
export function SunburstChart({ assetUrlStats }: SunburstChartProps) {
  const svgRef = useRef<SVGSVGElement>(null);
  const theme = useTheme();

  const { data, statusDomain } = assetUrlStats;

  // Last reported pointer position (client coordinates). The tooltip's virtual
  // anchor reads it on demand; it is a ref, so writing it does not re-render.
  const pointerRef = useRef<{ x: number; y: number }>({ x: 0, y: 0 });

  const [hoveredPoint, setHoveredPoint] = useState<
    d3.HierarchyRectangularNode<Datum> | undefined
  >(undefined);
  const [zoomedPoint, setZoomedPoint] = useState<
    d3.HierarchyRectangularNode<Datum> | undefined
  >(undefined);
  const [tooltipVisible, setTooltipVisible] = useState(false);

  // Maps each status to the light shade of its palette colour. It tints the
  // rows of the tooltip's score table.
  const statusLightScale = d3
    .scaleOrdinal<string>()
    .domain(statusDomain)
    .range(
      statusDomain.map((status) => theme.palette[statusColorMap[status]].light),
    );

  // One table row per entry of the hovered node's stats; empty when nothing
  // is hovered.
  const scoreRows = Object.entries(hoveredPoint?.data.stats || {}).map(
    ([grade, count]) => (
      <TableRow
        key={grade}
        sx={{
          td: {
            color: statusLightScale(grade.toLowerCase()),
            textTransform: "uppercase",
          },
        }}
      >
        <TableCell>{grade}</TableCell>
        <TableCell>{count}</TableCell>
      </TableRow>
    ),
  );

  const tooltipTitle = (
    <Box>
      <Typography fontSize={12} textAlign="center">
        <strong style={{ display: "block" }}>{hoveredPoint?.data.name}</strong>
        {hoveredPoint?.data.value}
      </Typography>
      <Divider sx={{ my: 1, borderColor: "rgba(0,0,0,0.3)" }} />
      <Table
        size="small"
        sx={{
          "th, td": {
            fontSize: 10,
            fontWeight: "bold",
            borderBottom: "unset",
            px: 0.5,
            py: 0,
            textAlign: "right",
          },
          thead: { backgroundColor: "unset", boxShadow: "unset" },
          th: {
            color: "text.primary",
            fontSize: 10,
            textTransform: "uppercase",
          },
        }}
      >
        <TableHead>
          <TableRow>
            <TableCell>Score</TableCell>
            <TableCell>Assets</TableCell>
          </TableRow>
        </TableHead>
        <TableBody>{scoreRows}</TableBody>
      </Table>
    </Box>
  );

  const popperRef = useRef<Instance>(null);

  // Pointer entered an arc: remember where, show that node in the tooltip.
  const handleMouseOver = (
    x: number,
    y: number,
    p: d3.HierarchyRectangularNode<Datum>,
  ) => {
    pointerRef.current = { x, y };
    setHoveredPoint(p);
    setTooltipVisible(true);
  };

  // Pointer moved over an already-hovered arc: move the tooltip with it.
  const handleMouseMove = (
    x: number,
    y: number,
    p: d3.HierarchyRectangularNode<Datum>,
  ) => {
    pointerRef.current = { x, y };
    setTooltipVisible(true);
    setHoveredPoint(p);
    // Ask Popper to recompute the position from the updated anchor rect.
    popperRef.current?.update();
  };

  const handleMouseOut = () => setTooltipVisible(false);

  const handleZoomStart = (_p: d3.HierarchyRectangularNode<Datum>) => {
    setTooltipVisible(false);
  };

  const handleZoomEnd = (focus: d3.HierarchyRectangularNode<Datum>) => {
    assetUrlStats.setActivePath(focus.data.path);
  };

  // NOTE(review): the chart is only redrawn when `data` changes. The theme,
  // the zoom state and the handlers are captured from the render in which
  // that happened. Confirm this is intentional before adding dependencies.
  useEffect(() => {
    const svgEl = svgRef.current;
    if (!svgEl) return;
    DrawSunburst(
      assetUrlStats.scopeId,
      svgEl,
      statusDomain,
      data,
      theme,
      handleMouseOver,
      handleMouseMove,
      handleMouseOut,
      handleZoomStart,
      handleZoomEnd,
      zoomedPoint,
      setZoomedPoint,
    );
  }, [data]);

  return (
    <Box sx={{ width: "100%" }}>
      <Tooltip
        title={tooltipTitle}
        arrow
        placement="top"
        open={tooltipVisible}
        disableInteractive
        PopperProps={{
          popperRef,
          // Virtual anchor: a zero-size rect at the last pointer position.
          anchorEl: {
            getBoundingClientRect: () => {
              const { x, y } = pointerRef.current;
              return new DOMRect(x, y, 0, 0);
            },
          },
          sx: {
            ".MuiTooltip-tooltip": {
              p: 1,
              fontSize: 14,
              lineHeight: "24px",
              fontWeight: (t) => t.typography.fontWeightRegular,
            },
            ".MuiTable-root": {
              width: "auto",
              m: "0 auto",
            },
          },
        }}
      >
        <svg
          ref={svgRef}
          role="application"
          style={{ display: "block", margin: "auto", maxWidth: "600px" }}
        />
      </Tooltip>
    </Box>
  );
}

/**
 * Imperatively (re)draws the zoomable sunburst into `containerEl` using d3.
 *
 * Steps:
 * - Build a tree from the flat `data` rows. When `data` is empty, a single
 *   "empty" placeholder row is used instead.
 * - Lay the tree out with `d3.partition`.
 * - Join it against the existing `<path>` arcs, the `g.label-group` labels
 *   and the single `g.parent-group` centre element, so repeated calls update
 *   the same SVG.
 *
 * @param scopeId - Used to build the root node's path (`root-${scopeId}`).
 * @param containerEl - The `<svg>` element to draw into.
 * @param statusDomain - Status keys; each seeds a zero count in aggregated stats.
 * @param data - Leaf rows; each row's `path` determines where it sits in the tree.
 * @param theme - MUI theme supplying fill colours.
 * @param onMouseOver - Called with client coordinates when an arc becomes hovered.
 * @param onMouseMove - Called with client coordinates while moving over a hovered arc.
 * @param onMouseOut - Called when the pointer leaves an arc.
 * @param onZoomStart - Called when a zoom (animated or not) begins.
 * @param onZoomEnd - Called after an animated zoom finishes. It is not called
 *   for the non-animated re-zoom applied on redraw.
 * @param zoomedPoint - Node zoomed into before this call, if any. It belongs to
 *   the previous hierarchy; the node with the same `pathId` in the new
 *   hierarchy is re-zoomed without animation.
 * @param setZoomedPoint - Receives the new focus whenever a zoom completes.
 */
function DrawSunburst(
  scopeId: string,
  containerEl: SVGSVGElement,
  statusDomain: StatusDomain,
  data: Data,
  theme: Theme,
  onMouseOver: (
    x: number,
    y: number,
    p: d3.HierarchyRectangularNode<Datum>,
  ) => void,
  onMouseMove: (
    x: number,
    y: number,
    p: d3.HierarchyRectangularNode<Datum>,
  ) => void,
  onMouseOut: () => void,
  onZoomStart: (p: d3.HierarchyRectangularNode<Datum>) => void,
  onZoomEnd: (p: d3.HierarchyRectangularNode<Datum>) => void,
  zoomedPoint: d3.HierarchyRectangularNode<Datum> | undefined,
  setZoomedPoint: Dispatch<
    SetStateAction<d3.HierarchyRectangularNode<Datum> | undefined>
  >,
) {
  // Specify the chart’s dimensions. The viewBox is centred on the origin.
  const width = 516;
  const height = width;
  // Inner radius of the first ring; also the radius of the centre hit circle.
  const radius = width / 4.5;
  // Interaction flags shared by the event handlers created in this call.
  let isHover = false;
  let isZooming = false;

  const svg = d3.select(containerEl);

  svg
    .attr("viewBox", [-width / 2, -height / 2, width, width])
    .style("font", "10px sans-serif");

  type TreeLayer = { [key: string]: TreeLayer };
  type TreeNode = {
    name: string;
    depth: number;
    children: TreeNode[];
    path: { key: string; value: string }[];
    pathId: string;
  };

  /**
   * Turns flat rows into a nested tree, merging rows that share a path prefix.
   * The root is named "empty" when the first row is the placeholder and
   * "total" otherwise.
   */
  function burrow(leaves: Datum[]): TreeNode {
    // Build a nested object keyed by path-item id, one level per path item.
    const obj: TreeLayer = {};
    leaves.forEach((leaf) => {
      let layer = obj;
      leaf.path.forEach((p) => {
        const key = pathItemToId(p);
        layer[key] = key in layer ? layer[key] : {};
        layer = layer[key];
      });
    });

    // Recursively convert the nested object into TreeNode children arrays.
    function descend(
      obj: TreeLayer,
      depth: number,
      parentPath: { key: string; value: string }[],
    ): TreeNode[] {
      const arr = [];
      for (const k in obj) {
        // NOTE(review): assumes path-item ids look like "key=value". A value
        // that itself contains "=" would be cut short here; confirm against
        // pathItemToId.
        const [key, value] = k.split("=");
        const path = [...parentPath, { key, value }];
        const pathId = pathToId(path);
        arr.push({
          depth: depth,
          children: descend(obj[k], depth + 1, path),
          path,
          pathId,
          // Spread last, so a matching row's own fields take precedence.
          ...nodeData(pathId),
        });
      }
      return arr;
    }

    // Returns the first row whose pathId is a suffix of this node's pathId.
    // If none matches, returns just a name: the text after the last "=".
    function nodeData(pathId: string) {
      const leaf = leaves.find((leaf) => pathId.endsWith(leaf.pathId));
      return leaf || { name: pathId.split("=").pop() || "" };
    }

    const rootPath = [{ key: `root-${scopeId}`, value: `root-${scopeId}` }];

    return {
      name: leaves[0].name === "empty" ? "empty" : "total",
      children: descend(obj, 1, rootPath),
      path: rootPath,
      depth: 0,
      pathId: pathToId(rootPath),
    };
  }

  // A zero count for every status, used to seed aggregated stats.
  const initialLeafStats = statusDomain.reduce<Partial<StatusValueMap>>(
    (stats, statusKey) => {
      stats[statusKey] = 0;
      return stats;
    },
    {},
  );
  // With no data, draw a single placeholder leaf so a ring is still rendered.
  const emptyPath = [{ key: "empty", value: "empty" }];
  const emptyPathId = pathToId(emptyPath);
  const leafData =
    data.length > 0
      ? data
      : [
          {
            path: emptyPath,
            pathId: emptyPathId,
            name: "empty",
            score: 0,
            value: 1,
            stats: { ...initialLeafStats },
          },
        ];

  // Presumably suppressed because burrow() returns TreeNode rather than Datum.
  const hierarchy = d3
    // @ts-ignore
    .hierarchy<Datum>(burrow(leafData))
    .sum((d) => d.value)
    // Larger values first; equal values ordered by name, descending.
    .sort((a, b) => {
      if (b.value === undefined || a.value === undefined) return -1;
      if (b.value === a.value) return b.data.name.localeCompare(a.data.name);
      return b.value - a.value;
    });

  // Nodes that did not pick up stats from a row aggregate their leaves' stats.
  // Their score and value are then derived from that aggregate.
  hierarchy.each((d) => {
    if (!d.data.stats) {
      d.data.stats = d.leaves().reduce(
        (acc, curr) => {
          getEntries(curr.data.stats).forEach((entry) => {
            if (!entry) return;
            const [k, v = 0] = entry;
            const a = acc[k] || 0;
            acc[k] = a + v;
          });
          return acc;
        },
        { ...initialLeafStats },
      );
      d.data.score = computeScore(d.data.stats);
      d.data.value = d3.sum(Object.values(d.data.stats));
    }
  });

  // x spans the full circle (radians); y is the ring index (node depth).
  const root = d3.partition<Datum>().size([2 * Math.PI, hierarchy.height + 1])(
    hierarchy,
  );

  // `current` holds each node's (animated) position; start at the layout.
  root.each((d) => (d.data.current = d));

  // `zoomedPoint` belongs to the hierarchy of the previous draw. Re-resolve it
  // by pathId so a re-applied zoom uses this draw's layout, not stale
  // coordinates.
  const zoomedPathId = zoomedPoint?.data.pathId;
  const zoomTarget =
    zoomedPathId === undefined
      ? undefined
      : root.descendants().find((d) => d.data.pathId === zoomedPathId);

  // Maps ring index (y) to a radius. Ring 1 is drawn much thicker than ring 2.
  const radArc = d3.scaleLinear(
    [1, 2, 3],
    [radius, 1.95 * radius, 2.13 * radius],
  );
  // Gap left between neighbouring rings.
  const radArcSpacing = 0.05 * radius;

  // Create the arc generator.
  const arc = d3
    .arc<GraphTarget>()
    .startAngle((d) => d.x0)
    .endAngle((d) => d.x1)
    .padAngle((d) => Math.min((d.x1 - d.x0) / 2, 0.005))
    .padRadius(radius * 2)
    .cornerRadius(2)
    .innerRadius((d) => radArc(d.y0))
    .outerRadius((d) => {
      return Math.max(radArc(d.y0), radArc(d.y1)) - radArcSpacing;
    });

  // Every node except the root gets an arc and a label.
  const ringData = root.descendants().slice(1);

  // Append the arcs.
  const path = svg
    .selectAll<SVGPathElement, (typeof ringData)[0]>("path")
    .data(ringData, (d) => d.data.pathId)
    .join("path")
    .attr("fill", (d) => {
      if (d.data.name === "empty") return theme.palette.background.lightest;
      if (d.data.score === -1) return theme.palette.unrated.main;
      return theme.palette[CvssScoreType(100 - d.data.score * 100)].main;
    })
    .attr("fill-opacity", arcOpacity)
    .attr("pointer-events", (d) => {
      return arcVisible(d.data.current) ? "auto" : "none";
    })
    .attr("d", (d) => (d.data.current == null ? null : arc(d.data.current)))
    .on("mouseover", mouseover)
    .on("mouseout", mouseout)
    .on("mousemove", mousemove);

  // Leaves and the placeholder get the default cursor.
  path
    .style("cursor", (d) =>
      d.data.name === "empty" || !d.children ? "default" : "pointer",
    )
    .on("click", clicked);

  const label = svg
    .selectAll<SVGGElement, (typeof ringData)[0]>("g.label-group")
    .data(ringData, (d) => d.data.pathId)
    .join(
      (enter) => {
        const enterLabels = enter.append("g");
        enterLabels
          .append("text")
          .attr("class", "label-group-name")
          .attr("dy", "0.35em")
          .attr("fill", "white")
          .attr("font-size", "12px")
          .attr("font-weight", "bold")
          .style("transform", "translateY(-0.7em)")
          .text((d) => truncateText(d.data.name, 15, true));
        enterLabels
          .append("text")
          .attr("class", "label-group-value")
          .attr("dy", "0.35em")
          .attr("fill", "white")
          .attr("font-size", "12px")
          .attr("font-weight", "normal")
          .style("transform", "translateY(0.7em)")
          .text((d) => d.value || "");
        return enterLabels;
      },
      (update) => {
        update
          .selectAll<
            SVGTextElement,
            d3.HierarchyRectangularNode<Datum>
          >("text.label-group-name")
          .text((d) => truncateText(d.data.name, 15, true));
        update
          .selectAll<
            SVGTextElement,
            d3.HierarchyRectangularNode<Datum>
          >("text.label-group-value")
          .text((d) => d.value || "");
        return update;
      },
      (exit) => {
        return exit.remove();
      },
    )
    .attr("class", "label-group")
    .attr("pointer-events", "none")
    .attr("text-anchor", "middle")
    .attr("fill-opacity", (d) => +labelVisible(d.data.current))
    .attr("transform", (d) => labelTransform(d.data.current))
    .style("opacity", (d) => (d.data.name === "empty" ? 0 : 1))
    .style("user-select", "none");

  // Centre element: a transparent hit circle plus a label naming the current
  // focus. Its datum is where a click on the centre zooms to (the focus's
  // parent, or the root).
  const parentJoin = svg
    .selectAll<SVGGElement, d3.HierarchyRectangularNode<Datum>>(
      "g.parent-group",
    )
    .data([zoomTarget?.parent || root], (d) => d.data.pathId);

  parentJoin.exit().remove();

  const parentEnter = parentJoin
    .enter()
    .append("g")
    .attr("class", "parent-group");

  parentEnter
    .append<SVGCircleElement>("circle")
    .attr("r", radius)
    .attr("fill", "transparent")
    .attr("pointer-events", "all");

  const parentLabelEnter = parentEnter
    .append<SVGGElement>("svg:g")
    .attr("class", "parent-label");

  parentLabelEnter
    .append<SVGTextElement>("text")
    .attr("class", "parent-label-name")
    .attr("dy", "0.35em")
    .attr("font-size", "12px")
    .attr("font-weight", "bold")
    .attr("text-anchor", "middle")
    .style("transform", "translateY(-0.7em)");

  parentLabelEnter
    .append<SVGTextElement>("text")
    .attr("class", "parent-label-value")
    .attr("dy", "0.35em")
    .attr("font-size", "12px")
    .attr("font-weight", "normal")
    .attr("text-anchor", "middle")
    .style("transform", "translateY(0.7em)");

  // Apply the handler and data-dependent attributes to both new and existing
  // centre groups. On a first draw the update selection is empty, and on a
  // redraw the enter selection is; either one alone misses a case.
  const parent = parentEnter
    .merge(parentJoin)
    .style("cursor", "default")
    .style("opacity", (d) => (d.data.name === "empty" ? 0 : 1))
    .on("click", clicked);

  // The label starts out showing the root; a re-applied zoom below replaces it.
  const parentLabel = parent.select("g.parent-label");
  const parentName = parentLabel
    .select("text.parent-label-name")
    .attr("fill", theme.palette.text.primary)
    .text(root.data.name);
  const parentValue = parentLabel
    .select("text.parent-label-value")
    .attr("fill", theme.palette.text.primary)
    .text(root.value || 0);

  // Restore the previous zoom level without animation.
  if (zoomTarget) {
    clicked({ clientX: 0, clientY: 0 }, zoomTarget, false);
  }

  // Opacity for an arc. Hidden arcs are transparent; while something is
  // hovered, arcs not flagged as hovered are dimmed.
  function arcOpacity(d: d3.HierarchyRectangularNode<Datum>): number {
    if (!arcVisible(d.data.current)) return 0;
    if (isHover && (d.data.current == null || d.data.hover != true)) return 0.4;
    return 1;
  }

  // Flags the hovered node, its parent and its direct children. Fades the
  // other arcs, then reports the pointer position.
  function mouseover(e: any, p: d3.HierarchyRectangularNode<Datum>) {
    if (isZooming || p.data.name === "empty") return;

    isHover = true;
    root.each((d) => delete d.data.hover);
    p.data.hover = true;
    if (p.parent != null) p.parent.data.hover = true;
    p.children?.forEach((c) => (c.data.hover = true));

    path
      .transition()
      .duration(250)
      .filter(function (d) {
        if (d.data.hover) return true;
        return (
          parseFloat(this.getAttribute("fill-opacity") || "0") > 0 ||
          arcVisible(d.data.target)
        );
      })
      .attr("fill-opacity", arcOpacity);

    onMouseOver(e.clientX, e.clientY, p);
  }

  function mousemove(e: any, p: d3.HierarchyRectangularNode<Datum>) {
    // mouseover/mouseout already ignore events while zooming. Hover flags set
    // before a click are still present during the zoom, so without this
    // guard onMouseMove would fire mid-zoom.
    if (isZooming) return;
    if (p.data.hover != true) return mouseover(e, p);
    onMouseMove(e.clientX, e.clientY, p);
  }

  // Clears the hover flags that mouseover set for `p`, restores opacities and
  // reports that the pointer left.
  function mouseout(_event: any, p: d3.HierarchyRectangularNode<Datum>) {
    if (isZooming) return;

    isHover = false;
    delete p.data.hover;
    if (p.parent != null) delete p.parent.data.hover;
    p.children?.forEach((c) => delete c.data.hover);

    path
      .transition()
      .duration(200)
      .filter(function (d) {
        return (
          parseFloat(this.getAttribute("fill-opacity") || "0") > 0 ||
          arcVisible(d.data.target)
        );
      })
      .attr("fill-opacity", arcOpacity);

    onMouseOut();
  }

  /**
   * Zooms so that `p` becomes the focus: `p` moves to the centre and its
   * children form the first ring. Clicking a node without children zooms to
   * its parent instead.
   *
   * @param event - DOM event; unused, present to match d3's listener signature.
   * @param p - Node to focus.
   * @param transition - When false the zoom is applied without animation and
   *   `onZoomEnd` is not called.
   */
  function clicked(
    event: any,
    p: d3.HierarchyRectangularNode<Datum>,
    transition = true,
  ) {
    if (isZooming) return;

    const zoomDuration = transition ? 750 : 0;

    // For the outermost nodes we only zoom into their parent, no further
    if (p.children == null || p.children.length == 0) {
      p = p.parent as d3.HierarchyRectangularNode<Datum>;

      // If we are already at the current depth, there is nothing to do
      if (p.data.current?.y0 == 0) return;
    }

    // Early exit condition: If we clicked on the root element on the
    // outermost zoom level, then there is nothing to do (can't go further up).
    // We also do this to avoid another animation transition starting which would
    // block hover effects.
    if (
      p.parent == null &&
      root.children != null &&
      root.children[0].data.current?.y0 == 1
    )
      return;

    // Change the centre cursor only once a zoom will actually happen. The
    // centre zooms out unless the new focus is the root.
    parent.style("cursor", p.depth !== 0 ? "pointer" : "default");

    isZooming = true;
    onZoomStart(p);

    // The centre's datum is where a click on it will zoom to.
    parent.datum(p.parent || root);
    if (transition) {
      parentLabel
        .transition()
        .duration(zoomDuration / 2)
        .style("opacity", "0")
        .style("transform", "scale(0.8)")
        .end()
        .then(() => {
          parentName.text(p.data.name);
          parentValue.text(p.value || 0);
          parentLabel
            .style("transform", "scale(1.2)")
            .transition()
            .duration(zoomDuration / 2)
            .style("opacity", "1")
            .style("transform", "scale(1)");
        });
    } else {
      // No animation: show the new focus in the centre label right away.
      parentName.text(p.data.name);
      parentValue.text(p.value || 0);
    }

    // Target positions: p's angular extent is stretched to the full circle
    // and everything is shifted inwards by p's depth.
    root.each(
      (d) =>
        (d.data.target = {
          height: d.height,
          depth: d.depth,
          x0:
            Math.max(0, Math.min(1, (d.x0 - p.x0) / (p.x1 - p.x0))) *
            2 *
            Math.PI,
          x1:
            Math.max(0, Math.min(1, (d.x1 - p.x0) / (p.x1 - p.x0))) *
            2 *
            Math.PI,
          y0: Math.max(0, d.y0 - p.depth),
          y1: Math.max(0, d.y1 - p.depth),
        }),
    );

    // Transition the data on all arcs, even the ones that aren’t visible,
    // so that if this transition is interrupted, entering arcs will start
    // the next transition from the desired position.
    path
      .transition()
      .duration(zoomDuration)
      .tween("data", (d) => {
        // we can forcefully override the target here because at this point
        // it must be set to non-null (see above)
        const i = d3.interpolate<GraphTarget>(d.data.current, d.data.target!);
        return (t) => (d.data.current = i(t));
      })
      .filter(function (d) {
        return (
          parseFloat(this.getAttribute("fill-opacity") || "0") > 0 ||
          arcVisible(d.data.target)
        );
      })
      .attr("fill-opacity", (d) => (arcVisible(d.data.target) ? 1 : 0))
      .attr("pointer-events", (d) =>
        arcVisible(d.data.target) ? "auto" : "none",
      )
      .attrTween(
        "d",
        (d) => () => (d.data.current ? arc(d.data.current) : null) || "",
      )
      .end()
      .then(() => {
        isZooming = false;
        setZoomedPoint(p);
        if (transition) {
          onZoomEnd(p);
        }
      });

    label
      .filter(function (d) {
        return (
          parseFloat(this.getAttribute("fill-opacity") || "0") > 0 ||
          labelVisible(d.data.target)
        );
      })
      .transition()
      .duration(zoomDuration)
      .attr("fill-opacity", (d) => +labelVisible(d.data.target))
      .attrTween("transform", (d) => () => labelTransform(d.data.current));
  }

  // An arc is drawn when it lies within y ∈ [1, 3] and has a non-zero angle.
  function arcVisible(d: GraphTarget | undefined): boolean {
    if (d == null) return false;
    return d.y1 <= 3 && d.y0 >= 1 && d.x1 > d.x0;
  }

  // Labels appear only on the first ring (y ∈ [1, 2]) and only when the arc's
  // Δy·Δx exceeds 0.15.
  function labelVisible(d: GraphTarget | undefined): boolean {
    if (d == null) return false;
    return d.y1 <= 2 && d.y0 >= 1 && (d.y1 - d.y0) * (d.x1 - d.x0) > 0.15;
  }

  // Label centre as { x: angle in degrees, y: distance from the centre }.
  // This uses a linear `radius` scale rather than `radArc`; the two agree on
  // ring 1, the only ring where labels are visible.
  function labelPosition(d: GraphTarget | undefined): { x: number; y: number } {
    if (d == null) return { x: 0, y: 0 };
    const x = (((d.x0 + d.x1) / 2) * 180) / Math.PI;
    const y = ((d.y0 + d.y1) / 2) * radius - radArcSpacing;
    return { x, y };
  }

  // Rotates the label to its angle and pushes it outwards. Labels past 180°
  // get an extra 180° flip so they are not upside down.
  function labelTransform(d: GraphTarget | undefined): string {
    const { x, y } = labelPosition(d);
    return `rotate(${x - 90}) translate(${y},0) rotate(${
      Math.round(x) <= 180 ? 0 : 180
    })`;
  }
}
