import { Box, Typography, useTheme } from "@mui/material";
import React, { useEffect, useMemo, useState } from "react";
import TableRow from "./TableRow";
import BottomBar from "./BottomBar";
import HeaderRow from "./HeaderRow";
import Toolbar from "./Toolbar";

/**
 * Free-form record of cell values for a row or group. Values are looked up by
 * `Column.field` / `Filter.columns` and compared directly when sorting.
 */
export interface Data {
  [key: string]: any;
}

/** Describes one grid column. */
export interface Column {
  /**
   * Key into `Data` for this column's value. Also used as the sort field, as a
   * searched column for the toolbar search, and (stringified) to match
   * entries in `hiddenColumns`.
   */
  field: keyof Data;
  /** Header label — presumably rendered by HeaderRow (not visible here). */
  headerName: string;
  /**
   * Optional custom cell renderer, given the row or group being rendered and
   * rows/groups maps. NOTE(review): invoked by TableRow, which is not visible
   * in this file — confirm exactly which maps are passed.
   */
  renderCell?: (
    rowOrGroup: Row | Group,
    rows: Map<string, Row>,
    groups: Map<string, Group>
  ) => React.ReactNode;
  /** Presumably formats a raw cell value for display (consumed outside this file — confirm). */
  valueFormatter?: (value: any) => string;
  /** Presumably enables header sorting for this column (consumed by HeaderRow — confirm). */
  sortable?: boolean;
  /** Fixed track width in pixels; columns without one get an `auto` track. */
  width?: number;
  /**
   * When true, the column is laid out after the `min-content` track that sits
   * between the two column groups (presumably the expand/collapse control).
   */
  afterExpandColumn?: boolean;
}

/** A leaf entry of the grid. */
export interface Row {
  data: Data;
}

/** An expandable entry that owns nested groups and rows. */
export interface Group {
  data: Data;
  // Nested subgroups, keyed by id.
  groups: Map<string, Group>;
  // Rows directly inside this group, keyed by id.
  rows: Map<string, Row>;
}

/**
 * A text filter. A row/group passes it when at least one of `columns` holds a
 * non-null value whose string form contains `value` (case-insensitive).
 */
export interface Filter {
  value: string;
  columns: (keyof Data)[];
  // Marks the filter created by the toolbar search box; at most one is kept.
  search?: boolean;
}

/** Props of the `GroupedGrid` component. */
interface GroupedGridProps {
  // Adds a leading `min-content` column and enables selection in header/rows (default true).
  selectable?: boolean;
  // Column definitions; those with `afterExpandColumn` are placed after the expand track.
  columns: Column[];
  // `field` values (stringified) of columns to leave out of the layout. Hidden
  // columns are still searched by the toolbar search.
  hiddenColumns?: string[];
  // Top-level groups.
  groups: Map<string, Group>;
  // Top-level rows.
  rows: Map<string, Row>;
  // Controlled selection; ids no longer present in the data are pruned through
  // `setSelectedRowIds`.
  selectedRowIds: string[];
  setSelectedRowIds: (selectedRowIds: string[]) => void;
  // Spread into the `sx` of the box that wraps the rendered rows.
  rowsContainerStyleOverrides?: React.CSSProperties;
  // Render the Toolbar with a search box (default false).
  showToolbar?: boolean;
  // Extra content passed through to the Toolbar.
  toolbarButtons?: React.ReactNode;
  // When given, `defaultPageSize` is the initial page size and the object is
  // passed to BottomBar (presumably to offer `pageSizeOptions`). When omitted,
  // all open rows are shown on a single page.
  paginationOptions?: {
    pageSizeOptions: number[];
    defaultPageSize: number;
  };
  // Initial sort; defaults to the first column, ascending.
  defaultSort?: {
    field: keyof Data;
    order: "asc" | "desc";
  };
  // Passed to BottomBar — presumably a label for the counted items (confirm).
  dataLabel?: string;
}

/**
 * Aggregate selection state of all rows, as produced by `isEverythingSelected`
 * and passed to HeaderRow together with its select-all / deselect-all callbacks.
 */
export type SelectionStatus =
  | "fullySelected"
  | "partiallySelected"
  | "notSelected";

/**
 * Summarises how many of the grid's rows (nested rows included) are selected.
 *
 * Only ids that are actual rows are counted, each at most once. Stale ids,
 * group ids or duplicates in `selectedRowIds` therefore cannot make the result
 * "fullySelected" while rows remain unselected, or "partiallySelected" when no
 * row is selected.
 *
 * @param groups - Top-level groups of the grid.
 * @param rows - Top-level rows of the grid.
 * @param selectedRowIds - Currently selected ids.
 * @returns "notSelected" when no row is selected, "fullySelected" when every
 *   row is selected, otherwise "partiallySelected".
 */
export function isEverythingSelected(
  groups: Map<string, Group>,
  rows: Map<string, Row>,
  selectedRowIds: string[]
): SelectionStatus {
  const allRows = getAllRowKeys(groups, rows);
  const selected = new Set(selectedRowIds);
  const selectedCount = allRows.filter((rowId) => selected.has(rowId)).length;
  if (selectedCount === 0) return "notSelected";
  if (selectedCount === allRows.length) return "fullySelected";
  return "partiallySelected";
}

/**
 * Locates `key` in the tree rooted at the given level and reports how deeply
 * it is nested.
 *
 * @param rows - Rows of the current level.
 * @param groups - Groups of the current level.
 * @param key - Id of the row or group to find.
 * @param depth - Depth assigned to the current level (0 for the top).
 * @returns `depth` when `key` lives directly on this level, the nested depth
 *   when it is found inside a descendant group, or -1 when it is absent.
 */
function getRowOrGroupDepth(
  rows: Map<string, Row>,
  groups: Map<string, Group>,
  key: string,
  depth: number = 0
): number {
  const onThisLevel = rows.has(key) || groups.has(key);
  if (onThisLevel) {
    return depth;
  }

  // Depth-first search; the first subtree that contains the key wins.
  for (const subGroup of Array.from(groups.values())) {
    const nestedDepth = getRowOrGroupDepth(
      subGroup.rows,
      subGroup.groups,
      key,
      depth + 1
    );
    if (nestedDepth > depth) {
      return nestedDepth;
    }
  }

  return -1;
}

/**
 * Lists the ids of every row below `group`, at any depth: the group's own rows
 * first, then the rows of each subgroup in order (depth-first).
 */
export function getAllChildRowKeys(group: Group): string[] {
  return Array.from(group.groups.values()).reduce<string[]>(
    (collected, subGroup) => collected.concat(getAllChildRowKeys(subGroup)),
    Array.from(group.rows.keys())
  );
}

/**
 * Lists `[id, row]` pairs for every row below `group`, at any depth: the
 * group's own rows first, then those of each subgroup in order (depth-first).
 */
export function getAllChildRowEntries(group: Group): [string, Row][] {
  return Array.from(group.groups.values()).reduce<[string, Row][]>(
    (collected, subGroup) => collected.concat(getAllChildRowEntries(subGroup)),
    Array.from(group.rows.entries())
  );
}

/** Ids of a group's immediate children: its rows first, then its subgroups. */
function getDirectChildrenKeys(group: Group) {
  const rowIds = Array.from(group.rows.keys());
  const groupIds = Array.from(group.groups.keys());
  return [...rowIds, ...groupIds];
}

/**
 * Ids of every row in the grid, nested rows included. Top-level rows come
 * first, then each group's rows depth-first in insertion order.
 */
function getAllRowKeys(groups: Map<string, Group>, rows: Map<string, Row>) {
  const rowIds: string[] = [];

  const collect = (
    levelGroups: Map<string, Group>,
    levelRows: Map<string, Row>
  ) => {
    levelRows.forEach((_row, rowId) => rowIds.push(rowId));
    levelGroups.forEach((group) => collect(group.groups, group.rows));
  };

  collect(groups, rows);
  return rowIds;
}

/**
 * Ids of every row and group in the tree, depth-first: at each level the rows
 * come first, then every group followed immediately by its own descendants.
 */
function getAllRowAndGroupKeys(
  groups: Map<string, Group>,
  rows: Map<string, Row>
) {
  const ids: string[] = [];

  const walk = (
    levelGroups: Map<string, Group>,
    levelRows: Map<string, Row>
  ) => {
    levelRows.forEach((_row, rowId) => ids.push(rowId));
    levelGroups.forEach((group, groupId) => {
      ids.push(groupId);
      walk(group.groups, group.rows);
    });
  };

  walk(groups, rows);
  return ids;
}

/**
 * Flattens the tree into `[id, rowOrGroup]` pairs, depth-first: at each level
 * the rows come first, then every group followed by its own descendants.
 *
 * The return type is spelled out explicitly. The previously inferred
 * `[string, Row][]` hid the fact that group entries are included.
 *
 * @param rows - Rows of the starting level.
 * @param groups - Groups of the starting level.
 */
export function getAllRowAndGroupEntries(
  rows: Map<string, Row>,
  groups: Map<string, Group>
): [string, Row | Group][] {
  const entries: [string, Row | Group][] = Array.from(rows.entries());
  Array.from(groups.entries()).forEach(([groupId, group]) => {
    entries.push(
      [groupId, group],
      ...getAllRowAndGroupEntries(group.rows, group.groups)
    );
  });
  return entries;
}

/**
 * Whether `data` passes every filter. A filter passes when at least one of its
 * columns holds a non-null value whose string form contains the filter value,
 * compared case-insensitively.
 */
function matchesAllFilters(data: Data, filters: Filter[]): boolean {
  return filters.every((filter) => {
    // Lowercase the needle once per filter rather than once per column.
    const needle = filter.value.toLowerCase();
    return filter.columns.some((column) => {
      const value = data[column];
      return value != null && value.toString().toLowerCase().includes(needle);
    });
  });
}

/**
 * Ids of the rows and groups that should remain visible under `filters`.
 *
 * With no filters, every row and group id is returned. Otherwise:
 * - a row is kept when it matches every filter;
 * - a group is kept when it matches itself or when any descendant is kept.
 *   Its kept descendants follow it in the result; non-matching descendants of
 *   a matching group are not included.
 *
 * @param filters - Active filters (all must match).
 * @param groups - Groups of the current level.
 * @param rows - Rows of the current level.
 * @returns Matching ids; at each level rows first, then each kept group
 *   followed by its kept descendants.
 */
function getFilteredRowsAndGroupsKeys(
  filters: Filter[],
  groups: Map<string, Group>,
  rows: Map<string, Row>
): string[] {
  if (filters.length === 0) {
    return getAllRowAndGroupKeys(groups, rows);
  }

  const results: string[] = [];

  rows.forEach((row, rowId) => {
    if (matchesAllFilters(row.data, filters)) {
      results.push(rowId);
    }
  });

  groups.forEach((group, groupId) => {
    const groupMatches = matchesAllFilters(group.data, filters);
    const childrenResults = getFilteredRowsAndGroupsKeys(
      filters,
      group.groups,
      group.rows
    );

    if (groupMatches || childrenResults.length > 0) {
      results.push(groupId, ...childrenResults);
    }
  });

  return results;
}

/**
 * Flattens the whole tree into a Map whose insertion order is the display
 * order, sorting every level by `sortOptions.field`.
 *
 * - The top level sorts groups and rows together.
 * - Inside a group, its subgroups (sorted) come first, then its rows (sorted).
 * - Every group is immediately followed by all of its descendants.
 *
 * Values are compared with `<` / `>`. `null` / `undefined` values always sort
 * last, in either direction. Comparing with undefined via `<` / `>` is always
 * false, so a missing value used to compare "equal" to everything. That made
 * the comparator non-transitive and the resulting order unpredictable.
 *
 * @param groups - Top-level groups.
 * @param rows - Top-level rows.
 * @param sortOptions - Field to sort by and direction.
 * @returns Map of id to row/group, iterating in display order.
 */
function getSortedRowsAndGroups(
  groups: Map<string, Group>,
  rows: Map<string, Row>,
  sortOptions: { field: keyof Data; order: "asc" | "desc" }
): Map<string, Group | Row> {
  const { field, order } = sortOptions;
  const direction = order === "asc" ? 1 : -1;

  // Consistent comparator: nullish last, otherwise ordered by < / > in `direction`.
  function compare(a: Group | Row, b: Group | Row): number {
    const aValue = a.data[field];
    const bValue = b.data[field];
    const aMissing = aValue == null;
    const bMissing = bValue == null;
    if (aMissing || bMissing) {
      if (aMissing && bMissing) return 0;
      return aMissing ? 1 : -1;
    }
    if (aValue < bValue) return -direction;
    if (aValue > bValue) return direction;
    return 0;
  }

  // Sorted copy of one level's entries (the input array is left untouched).
  function sortEntries(entries: [string, Group | Row][]) {
    return [...entries].sort(([, a], [, b]) => compare(a, b));
  }

  // Sorts one level and splices each group's sorted descendants right after it.
  function recursiveSort(
    entries: [string, Group | Row][]
  ): [string, Group | Row][] {
    return sortEntries(entries).flatMap(
      ([entryId, entry]): [string, Group | Row][] => {
        if ("groups" in entry) {
          return [
            [entryId, entry],
            ...recursiveSort(Array.from(entry.groups.entries())),
            ...recursiveSort(Array.from(entry.rows.entries())),
          ];
        }
        return [[entryId, entry]];
      }
    );
  }

  const sortedRowsAndGroups = recursiveSort([
    ...Array.from(groups.entries()),
    ...Array.from(rows.entries()),
  ]);

  return new Map(sortedRowsAndGroups);
}

/**
 * Collects the ids of every row and group reachable without passing through a
 * collapsed group. That is all top-level rows and groups, plus the direct
 * children of each expanded group whose ancestors are all expanded too.
 *
 * The walk starts at the roots instead of trusting every entry of
 * `expandedGroups`. This keeps an expanded group hidden while one of its
 * ancestors is collapsed. It also reads children from the current `groups`
 * tree rather than from the Group objects stored in `expandedGroups`, which
 * can be outdated after the data changes.
 */
function getOpenRowAndGroupKeys(
  groups: Map<string, Group>,
  rows: Map<string, Row>,
  expandedGroups: Map<string, Group>
): Set<string> {
  const open = new Set<string>();

  // Opens the direct children of an expanded group, then descends into any
  // child group that is itself expanded.
  const openChildren = (group: Group) => {
    getDirectChildrenKeys(group).forEach((key) => open.add(key));
    group.groups.forEach((childGroup, childId) => {
      if (expandedGroups.has(childId)) openChildren(childGroup);
    });
  };

  rows.forEach((_row, rowId) => open.add(rowId));
  groups.forEach((group, groupId) => {
    open.add(groupId);
    if (expandedGroups.has(groupId)) openChildren(group);
  });

  return open;
}

/**
 * Grid of `rows` plus a tree of expandable `groups`. It supports optional
 * selection, column sorting, a toolbar search box and pagination.
 *
 * - The toolbar search matches against all `columns`, including hidden ones.
 * - Selection is controlled through `selectedRowIds` / `setSelectedRowIds`.
 *   Selected ids that disappear from the data are pruned automatically.
 * - Without `paginationOptions`, every open row/group is shown on one page.
 */
const GroupedGrid = ({
  selectable = true,
  columns,
  hiddenColumns,
  rows,
  groups,
  selectedRowIds = [],
  setSelectedRowIds = () => {},
  rowsContainerStyleOverrides,
  showToolbar = false,
  toolbarButtons,
  paginationOptions,
  defaultSort,
  dataLabel,
}: GroupedGridProps) => {
  const theme = useTheme();

  const visibleColumns = columns.filter(
    (column) => !hiddenColumns?.includes(column.field.toString())
  );

  const preExpandColumns = visibleColumns.filter(
    (column) => !column.afterExpandColumn
  );
  const postExpandColumns = visibleColumns.filter(
    (column) => column.afterExpandColumn
  );

  // CSS grid track for a column: its fixed pixel width, or `auto`.
  const trackSize = (column: Column) =>
    column.width ? column.width + "px" : "auto";

  // Grid line after the last track (optional selection track + visible
  // columns + the min-content expand track), so full-width children span
  // every column.
  const fullWidthColumnEnd = visibleColumns.length + (selectable ? 3 : 2);

  const [sortOptions, setSortOptions] = useState<{
    field: string | number;
    order: "asc" | "desc";
  }>(
    // Guard against an empty `columns` array instead of crashing on columns[0].
    defaultSort ?? { field: columns[0]?.field ?? "", order: "asc" }
  );

  const [filters, setFilters] = useState<Filter[]>([]);

  // All rows and groups (including nested ones), flattened in sorted order
  const rowsAndGroups = useMemo(() => {
    return getSortedRowsAndGroups(groups, rows, sortOptions);
  }, [groups, rows, sortOptions]);

  // Expanded groups, keyed by group id
  const [expandedGroups, setExpandedGroups] = useState<Map<string, Group>>(
    new Map()
  );

  // Rows and groups that are not hidden inside a collapsed group
  const openRowsAndGroups = useMemo(
    () => getOpenRowAndGroupKeys(groups, rows, expandedGroups),
    [rows, groups, expandedGroups]
  );

  // Ids matching the active filters. A Set keeps the per-id membership
  // checks below O(1) instead of scanning an array each time.
  const filterResults = useMemo(
    () => new Set(getFilteredRowsAndGroupsKeys(filters, groups, rows)),
    [filters, groups, rows]
  );

  // Every row id in the grid (nested included), for the bottom bar's total
  const allRowKeys = useMemo(() => getAllRowKeys(groups, rows), [groups, rows]);

  // Without pagination everything fits on one page. The size is kept at >= 1
  // so an empty grid does not produce a page count of Infinity.
  const [pageSize, setPageSize] = useState(
    paginationOptions?.defaultPageSize || Math.max(1, openRowsAndGroups.size)
  );
  const [currentPage, setCurrentPage] = useState(1);

  // Number of open rows/groups that also pass the filters
  const matchingOpenCount = useMemo(() => {
    let count = 0;
    openRowsAndGroups.forEach((id) => {
      if (filterResults.has(id)) count += 1;
    });
    return count;
  }, [openRowsAndGroups, filterResults]);
  const pageCount = Math.ceil((matchingOpenCount || 1) / pageSize);

  // Keep the current page in range when the page count shrinks
  useEffect(() => {
    if (currentPage > pageCount) setCurrentPage(pageCount);
  }, [currentPage, pageCount]);

  useEffect(() => {
    if (!paginationOptions) setPageSize(Math.max(1, openRowsAndGroups.size));
  }, [paginationOptions, openRowsAndGroups]);

  // Ids to render on the current page, in sorted order
  const rowsAndGroupsToDisplay = useMemo(() => {
    const start = (currentPage - 1) * pageSize;
    return Array.from(rowsAndGroups.keys())
      .filter((id) => openRowsAndGroups.has(id) && filterResults.has(id))
      .slice(start, start + pageSize);
  }, [openRowsAndGroups, currentPage, pageSize, rowsAndGroups, filterResults]);

  // Remove selected ids that are no longer in the data. All stale ids are
  // dropped in one update. The previous per-id calls each filtered the same
  // `selectedRowIds` snapshot, so every call overwrote the one before it.
  useEffect(() => {
    const stillPresent = selectedRowIds.filter((id) => rowsAndGroups.has(id));
    if (stillPresent.length !== selectedRowIds.length) {
      setSelectedRowIds(stillPresent);
    }
  }, [rowsAndGroups, selectedRowIds, setSelectedRowIds]);

  // Full-width placeholder message rendered in place of rows
  const renderMessage = (message: string) => (
    <Box
      sx={{
        gridColumnStart: 1,
        gridColumnEnd: fullWidthColumnEnd,
        display: "flex",
        justifyContent: "center",
        alignItems: "center",
        p: 2,
      }}
    >
      <Typography variant="body2" sx={{ color: theme.palette.grey[500] }}>
        {message}
      </Typography>
    </Box>
  );

  return (
    <Box
      sx={{
        display: "flex",
        flexDirection: "column",
        borderRadius: 1,
        border: "1px solid",
        borderColor: "divider",
      }}
    >
      {showToolbar && (
        <Toolbar
          showSearch={true}
          toolbarButtons={toolbarButtons}
          addSearchFilter={(value: string) => {
            setFilters((prev: Filter[]) => {
              return prev
                .filter((filter) => !filter.search)
                .concat({
                  value: value,
                  columns: columns.map((column) => column.field),
                  search: true,
                });
            });
          }}
          removeSearchFilter={() => {
            setFilters((prev: Filter[]) => {
              return prev.filter((filter) => !filter.search);
            });
          }}
        />
      )}
      <Box
        sx={{
          display: "grid",
          gridTemplateColumns: `${selectable ? "min-content " : ""} ${preExpandColumns.map(trackSize).join(" ")} min-content ${postExpandColumns.map(trackSize).join(" ")}`,
          gridColumnGap: "16px",
        }}
      >
        <HeaderRow
          columns={visibleColumns}
          selectable={selectable}
          selectionStatus={isEverythingSelected(groups, rows, selectedRowIds)}
          selectAllRows={() => {
            setSelectedRowIds(getAllRowKeys(groups, rows));
          }}
          deselectAllRows={() => {
            setSelectedRowIds([]);
          }}
          sortOptions={sortOptions}
          setSortOptions={setSortOptions}
        />
        <Box
          sx={{
            display: "grid",
            gridTemplateColumns: "subgrid",
            gridColumnStart: 1,
            gridColumnEnd: fullWidthColumnEnd,
            ...rowsContainerStyleOverrides,
          }}
        >
          {rowsAndGroupsToDisplay.map((id, index) => {
            const rowOrGroup = rowsAndGroups.get(id);
            // `null` renders nothing and, unlike a keyless fragment, does not
            // trigger React's missing-key warning.
            if (rowOrGroup == null) return null;
            return (
              <TableRow
                key={id}
                columns={visibleColumns}
                rows={rows}
                groups={groups}
                rowOrGroup={rowOrGroup}
                id={id}
                isGroup={"groups" in rowOrGroup}
                selectable={selectable}
                selectedRowIds={selectedRowIds}
                setSelectedRowIds={setSelectedRowIds}
                depth={getRowOrGroupDepth(rows, groups, id)}
                expandedGroups={expandedGroups}
                setExpandedGroups={setExpandedGroups}
                lastRow={index === rowsAndGroupsToDisplay.length - 1}
              />
            );
          })}
          {groups.size === 0 && rows.size === 0 && renderMessage("No data")}
          {rowsAndGroupsToDisplay.length === 0 &&
            (groups.size > 0 || rows.size > 0) &&
            renderMessage("No results")}
        </Box>
      </Box>
      <BottomBar
        totalRows={allRowKeys.length}
        selectedRowIds={selectedRowIds}
        currentPage={currentPage}
        setCurrentPage={setCurrentPage}
        pageCount={pageCount}
        pageSize={pageSize}
        setPageSize={setPageSize}
        paginationOptions={paginationOptions}
        dataLabel={dataLabel}
      />
    </Box>
  );
};

export default GroupedGrid;
