import {
    Box,
    Checkbox,
    IconButton,
    TablePaginationProps,
    Theme,
    Typography,
    alpha,
    styled,
} from "@mui/material";
import {
    DataGrid,
    GridColDef,
    GridExpandMoreIcon,
    GridPagination,
    gridPageCountSelector,
    useGridApiContext,
    useGridSelector,
    GRID_CHECKBOX_SELECTION_COL_DEF,
    GridRowId,
    GridRowSelectionModel,
    GridSortDirection,
    DataGridProps,
    GridFilterModel,
    gridExpandedSortedRowEntriesSelector,
    useGridApiRef,
} from "@mui/x-data-grid";
import MuiPagination from "@mui/material/Pagination";
import React, {useCallback, useEffect, useMemo} from "react";

/**
 * A single row handled by GroupedDataGrid: either a group header (`parent`)
 * or a member row. Any extra keys carry the row's column values.
 */
interface iGroupedRow {
    /** Row identifier; the component looks rows up by this value. */
    id: number;
    /** `true` for a group header row that can be expanded. */
    parent: boolean;
    /**
     * Group key. `"ungrouped"` rows are always shown at the top level. With
     * `nestedGroups`, this is a separator-delimited path of group names.
     */
    group: string;
    /** Display position; the grid is sorted ascending by this field. */
    sortOrder?: number;

    /** Arbitrary column values. */
    [key: string]: any;
}

/** Styling-only props accepted by StyledDataGrid. */
interface iStyledDataGrid {
    /** Left-pad rows below the first level by 48px per level. */
    indentChildren: boolean;
    /** Styles are generated for `.level-1` through `.level-{maxDepth}`. */
    maxDepth: number;
}

/**
 * Props for GroupedDataGrid.
 *
 * @typeParam T - Row type; must carry the grouping fields of `iGroupedRow`.
 */
interface iGroupedDataGrid<T extends iGroupedRow> {
    // --- data -------------------------------------------------------------
    /** All rows, including group header (`parent`) rows. */
    groupedRows: T[];
    /** Consumer columns; selection, id, grouping and expand columns are added. */
    columns: GridColDef[];
    /** Fields to hide through the column visibility model. */
    hiddenColumns?: string[];
    /** Sort applied to top-level rows by default. */
    defaultSort: { field: string; order: GridSortDirection };

    // --- reporting ---------------------------------------------------------
    /** Item noun shown in the footer count (lower-cased, pluralized with "s"). */
    itemLabel: string;
    /** Receives the selected row ids whenever the selection is reported. */
    updateSelectedRows: (selectedRows: GridRowId[]) => void;
    /** Include parent rows in the footer count and selection report (default `false`). */
    countParents?: boolean;

    // --- grouping & layout -------------------------------------------------
    /** Treat `group` as a separator-delimited path forming nested groups. */
    nestedGroups?: boolean;
    /** Separator for nested group paths (`","` when unset). */
    groupSeparator?: string;
    /** Indent child rows according to their level (default `true`). */
    indentChildren?: boolean;
    /** Deepest level that receives level styling (default `5`). */
    maxDepth?: number;

    // --- DataGrid passthrough ----------------------------------------------
    /** Extra DataGrid props, spread last so they override the component's own. */
    dataGridPropOverrides?: Omit<Partial<DataGridProps>, "columns" | "rows">;
    /** Additional DataGrid slots; `pagination` is always replaced by the footer. */
    dataGridSlots?: any;
}

/**
 * Page buttons rendered as the `ActionsComponent` of `GridPagination`, sized
 * by the grid's page count.
 *
 * Declared at module level (not inside `Pagination`) so its component
 * identity is stable: a component type created during render is remounted by
 * React every time its parent renders.
 */
function PaginationActions({
    page,
    onPageChange,
    className,
}: Pick<TablePaginationProps, "page" | "onPageChange" | "className">) {
    const apiRef = useGridApiContext();
    const pageCount = useGridSelector(apiRef, gridPageCountSelector);

    // MuiPagination pages are 1-based; the grid's page index is 0-based.
    return (
        <MuiPagination
            color="primary"
            className={className}
            count={pageCount}
            page={page + 1}
            onChange={(event, newPage) => {
                // eslint-disable-next-line @typescript-eslint/no-explicit-any
                onPageChange(event as any, newPage - 1);
            }}
        />
    );
}

/**
 * Grid footer: an item-count summary on the left, page buttons on the right.
 *
 * @param rowCount - Number of items to report.
 * @param itemLabel - Singular item noun; lower-cased and suffixed with "s"
 *   unless `rowCount` is exactly 1 (so zero reads "0 items").
 * @param filtered - When `true` the summary reads "N items found", otherwise
 *   "N total items".
 */
function Pagination({
    rowCount,
    itemLabel,
    filtered,
}: {
    rowCount: number;
    itemLabel: string;
    filtered: boolean;
}) {
    const noun = itemLabel.toLowerCase() + (rowCount === 1 ? "" : "s");
    const summary = filtered
        ? `${rowCount} ${noun} found`
        : `${rowCount} total ${noun}`;

    return (
        <Box
            sx={{
                display: "flex",
                flexDirection: "row",
                alignItems: "center",
                justifyContent: "space-between",
                width: "100%",
                pl: 2,
            }}
        >
            <Typography variant="body2">{summary}</Typography>
            <GridPagination
                ActionsComponent={PaginationActions}
                labelDisplayedRows={() => ""}
            />
        </Box>
    );
}

/**
 * Builds the style rules for rows tagged with the `.level-{level}` class.
 *
 * Deeper levels get a stronger grey tint (and stronger hover/selected
 * tints), and, when `indentChildren` is set, 48px of left padding per level
 * below the first.
 */
const generateLevelStyles = (
    level: number,
    indentChildren: boolean,
    theme: Theme,
) => {
    const { grey, primary } = theme.palette;
    const indent = indentChildren ? `${48 * (level - 1)}px` : 0;
    // Selected rows must beat DataGrid's own selected background.
    const selectedTint = (step: number) =>
        alpha(primary.main, 0.1 + level * step) + "!important";

    return {
        [`& .level-${level}`]: {
            paddingLeft: indent,
            backgroundColor: alpha(grey[100], level * 0.05),
            "&:hover": {
                backgroundColor: alpha(grey[100], level * 0.075),
            },
            "&.Mui-selected": {
                backgroundColor: selectedTint(0.075),
                "&:hover": {
                    backgroundColor: selectedTint(0.1),
                },
            },
        },
    };
};

/**
 * DataGrid with per-level row styles (`.level-1` … `.level-{maxDepth}`), a
 * top border on `.child-row` rows, and a `.flex-cell` class that lets flex
 * columns shrink.
 *
 * `maxDepth` and `indentChildren` only drive these styles, so they are kept
 * from being forwarded to DataGrid itself.
 */
const StyledDataGrid = styled(DataGrid, {
    shouldForwardProp: (prop) =>
        prop !== "maxDepth" && prop !== "indentChildren",
})<iStyledDataGrid>(({ theme, maxDepth, indentChildren }) => {
    const levelStyles = {};
    for (let level = 1; level <= maxDepth; level++) {
        Object.assign(levelStyles, generateLevelStyles(level, indentChildren, theme));
    }

    return {
        ...levelStyles,
        "& .child-row": {
            borderTop: "1px solid",
            borderTopColor: theme.palette.grey[200],
            display: "flex",
            // The row draws the border; suppress the cells' own top border.
            "& > .MuiDataGrid-cell": {
                borderTop: "0px",
            },
        },
        ".flex-cell": {
            flexGrow: 1,
            minWidth: 0,
        },
    };
});

/**
 * Compares two group values. When either side is not an array (a plain group
 * string, or `undefined`) they are compared with `===`; otherwise the arrays
 * are compared element by element.
 */
function arraysEqual(
    a: string | string[] | undefined,
    b: string | string[] | undefined,
): boolean {
    if (!Array.isArray(a) || !Array.isArray(b)) return a === b;
    return (
        a.length === b.length &&
        a.every((segment, index) => segment === b[index])
    );
}

/**
 * Returns `true` when `prefix` equals the leading segments of `path`, i.e. the
 * group path `prefix` is `path` itself or one of its ancestors.
 */
function isGroupPrefix(prefix: string[], path: string[]): boolean {
    return (
        prefix.length <= path.length &&
        prefix.every((segment, index) => segment === path[index])
    );
}

/**
 * A DataGrid that renders `groupedRows` as expandable groups.
 *
 * Top-level parent rows and `"ungrouped"` rows are always shown; expanding a
 * parent reveals its direct children. Checking a parent checks all of its
 * descendants, and a parent stays checked while any descendant is checked
 * (showing as indeterminate when only some are). Sorting orders the top-level
 * rows and places each parent's descendants directly after it; the result is
 * handed to the grid through each displayed row's `sortOrder` field.
 *
 * With `nestedGroups`, `group` is split on `groupSeparator` (default `","`)
 * into a path, and a row belongs to every parent whose path is a prefix of
 * its own.
 */
const GroupedDataGrid = <T extends iGroupedRow>({
    columns,
    hiddenColumns,
    groupedRows,
    defaultSort,
    countParents = false,
    nestedGroups,
    indentChildren = true,
    groupSeparator,
    maxDepth = 5,
    itemLabel,
    updateSelectedRows,
    dataGridPropOverrides,
    dataGridSlots,
}: iGroupedDataGrid<T>) => {
    const gridApiRef = useGridApiRef();

    const [rowSelectionModel, setRowSelectionModel] =
        React.useState<GridRowSelectionModel>([]);
    // Ids the user has checked. This is the source of truth for selection;
    // the grid's `rowSelectionModel` mirrors it for the visible rows only.
    const [selectedItems, setSelectedItems] = React.useState(
        () => new Set<GridRowId>(),
    );
    const [filterModel, setFilterModel] = React.useState<GridFilterModel>();
    // Parents expanded automatically to reveal quick-filter matches; they are
    // collapsed again once the quick filter is cleared.
    const [postFilterExpandedRows, setPostFilterExpandedRows] = React.useState(
        () => new Set<GridRowId>(),
    );
    const [expandedRows, setExpandedRows] = React.useState(
        () => new Set<GridRowId>(),
    );
    // Active sort. Held in state (instead of being written into the caller's
    // `groupedRows`) so a header click re-renders the grid and is not reset
    // to `defaultSort` on the next render.
    const [sortState, setSortState] = React.useState<{
        field: string;
        order: GridSortDirection;
    }>(() => ({ field: defaultSort.field, order: defaultSort.order }));

    // Id -> row index. The first occurrence wins, matching `Array#find`.
    const rowsById = useMemo(() => {
        const index = new Map<GridRowId, T>();
        groupedRows.forEach((row) => {
            if (!index.has(row.id)) index.set(row.id, row);
        });
        return index;
    }, [groupedRows]);

    /** A row's group: the raw string, or a path array when `nestedGroups` is set. */
    const getRowGroups = useCallback(
        (id: GridRowId): string | string[] | undefined => {
            const group = rowsById.get(id)?.group;
            if (group === undefined) return undefined;
            return nestedGroups ? group.split(groupSeparator || ",") : group;
        },
        [rowsById, nestedGroups, groupSeparator],
    );

    /** Rows shown one level below `parentRowId` when it is expanded. */
    const getDirectChildren = useCallback(
        (parentRowId: GridRowId): T[] => {
            const parentRowGroups = getRowGroups(parentRowId);
            if (!parentRowGroups) return [];
            if (!nestedGroups) {
                return groupedRows.filter(
                    (row) => row.group === parentRowGroups && !row.parent,
                );
            }
            return groupedRows.filter((row) => {
                if (row.id === parentRowId) return false;
                const rowGroups = getRowGroups(row.id);
                // Leaves share the parent's exact path; sub-parents sit one
                // segment deeper.
                return row.parent
                    ? arraysEqual(rowGroups?.slice(0, -1), parentRowGroups)
                    : arraysEqual(rowGroups, parentRowGroups);
            });
        },
        [groupedRows, getRowGroups, nestedGroups],
    );

    /**
     * Every row belonging to `parentRowId`'s group: non-parent rows with the
     * same group, or (with `nestedGroups`) every other row whose path starts
     * with the parent's path.
     */
    const getChildRows = useCallback(
        (parentRowId: GridRowId): T[] => {
            const parentRowGroups = getRowGroups(parentRowId);
            if (!nestedGroups) {
                return groupedRows.filter(
                    (row) => row.group === parentRowGroups && !row.parent,
                );
            }
            if (!Array.isArray(parentRowGroups)) return [];
            const parentPath: string[] = parentRowGroups;
            return groupedRows.filter((row) => {
                if (row.id === parentRowId) return false;
                const rowGroups = getRowGroups(row.id);
                return (
                    Array.isArray(rowGroups) && isGroupPrefix(parentPath, rowGroups)
                );
            });
        },
        [groupedRows, getRowGroups, nestedGroups],
    );

    // Display position per row id for the active sort. Top-level rows are
    // sorted; each parent's descendants follow it in `groupedRows` order.
    const sortOrderById = useMemo(() => {
        const { field, order } = sortState;
        const direction = order === "asc" ? 1 : -1;
        const topLevelRows = groupedRows.filter(
            (row) =>
                (row.parent &&
                    (!nestedGroups || getRowGroups(row.id)?.length === 1)) ||
                row.group === "ungrouped",
        );
        const sortedTopLevelRows = [...topLevelRows].sort((a, b) => {
            if (a[field] < b[field]) return -direction;
            if (a[field] > b[field]) return direction;
            return 0;
        });

        let displayOrder: T[] = [];
        sortedTopLevelRows.forEach((row) => {
            displayOrder.push(row);
            if (row.parent) {
                //TODO: Sort children
                displayOrder = displayOrder.concat(getChildRows(row.id));
            }
        });

        // A row listed twice keeps its last position.
        const positions = new Map<GridRowId, number>();
        displayOrder.forEach((row, position) => positions.set(row.id, position));
        return positions;
    }, [sortState, groupedRows, nestedGroups, getRowGroups, getChildRows]);

    // Rows handed to the grid: top-level parents, ungrouped rows and the
    // direct children of every expanded parent. Each row is a shallow copy
    // carrying its computed `sortOrder`; the caller's objects are not mutated.
    const visibleRows = useMemo((): T[] => {
        const topLevelParentRows = groupedRows.filter((row) =>
            nestedGroups
                ? row.parent && getRowGroups(row.id)?.length === 1
                : row.parent,
        );
        const ungroupedRows = groupedRows.filter(
            (row) => row.group === "ungrouped",
        );
        const expandedChildren: T[] = [];
        expandedRows.forEach((rowId) => {
            if (rowsById.has(rowId)) {
                expandedChildren.push(...getDirectChildren(rowId));
            }
        });

        return [...topLevelParentRows, ...ungroupedRows, ...expandedChildren].map(
            (row) => {
                const sortOrder = sortOrderById.get(row.id);
                return sortOrder === undefined ? row : { ...row, sortOrder };
            },
        );
    }, [
        groupedRows,
        expandedRows,
        rowsById,
        getDirectChildren,
        getRowGroups,
        nestedGroups,
        sortOrderById,
    ]);

    /**
     * Parent rows whose checked state depends on `childRowId`: parents of the
     * same group, or (with `nestedGroups`) the row's ancestors, including the
     * row itself when it is a parent.
     */
    function getParentRows(childRowId: GridRowId): T[] {
        const childRow = rowsById.get(childRowId);
        if (!childRow) return [];
        if (!nestedGroups) {
            return groupedRows.filter(
                (row) => row.group === childRow.group && row.parent,
            );
        }
        const childRowGroups = getRowGroups(childRowId);
        if (!Array.isArray(childRowGroups)) return [];
        const childPath: string[] = childRowGroups;
        return groupedRows.filter((row) => {
            if (!row.parent) return false;
            const rowGroups = getRowGroups(row.id);
            return Array.isArray(rowGroups) && isGroupPrefix(rowGroups, childPath);
        });
    }

    /** Toggles a parent; in nested mode, collapsing also collapses its descendants. */
    function handleExpandClick(rowId: GridRowId) {
        setExpandedRows((previous) => {
            const next = new Set(previous);
            if (next.has(rowId)) {
                next.delete(rowId);
                if (nestedGroups) {
                    getChildRows(rowId).forEach((child) => next.delete(child.id));
                }
            } else {
                next.add(rowId);
            }
            return next;
        });
    }

    /**
     * Applies a checkbox change. A parent cascades to all descendants; then
     * every affected parent is re-derived (deepest first, so each level sees
     * the already-updated level below it): checked iff any child is checked.
     * Uses a functional update so rapid clicks do not read a stale selection.
     */
    function handleRowSelection(selectedRowId: GridRowId, checked: boolean) {
        setSelectedItems((previousSelection) => {
            const nextSelection = new Set(previousSelection);
            const apply = (id: GridRowId) => {
                if (checked) nextSelection.add(id);
                else nextSelection.delete(id);
            };

            if (rowsById.get(selectedRowId)?.parent) {
                getChildRows(selectedRowId).forEach((child) => apply(child.id));
            }
            apply(selectedRowId);

            getParentRows(selectedRowId)
                .sort(
                    (a, b) =>
                        (getRowGroups(b.id)?.length ?? 0) -
                        (getRowGroups(a.id)?.length ?? 0),
                )
                .forEach((parentRow) => {
                    const hasSelectedChild = getChildRows(parentRow.id).some(
                        (child) => nextSelection.has(child.id),
                    );
                    if (hasSelectedChild) nextSelection.add(parentRow.id);
                    else nextSelection.delete(parentRow.id);
                });

            return nextSelection;
        });
    }

    /** `true` for a parent with some, but not all, of its children checked. */
    function getIndeterminateState(rowId: GridRowId) {
        if (!rowsById.get(rowId)?.parent) return false;
        const children = getChildRows(rowId);
        const selectedCount = children.filter((child) =>
            selectedItems.has(child.id),
        ).length;
        return selectedCount > 0 && selectedCount < children.length;
    }

    /**
     * Expands the parents of quick-filter matches that are not currently
     * displayed, then returns the filter model unchanged for the grid to apply.
     */
    function filterRows(newFilterModel: GridFilterModel) {
        const quickFilterValues = newFilterModel.quickFilterValues ?? [];
        if (quickFilterValues.length === 0) return newFilterModel;

        const searchableFields = new Set(
            groupedColumns.map((column) => column.field),
        );
        const matchingRows = groupedRows.filter((row) =>
            Object.keys(row).some(
                (key) =>
                    searchableFields.has(key) &&
                    quickFilterValues.every((value) =>
                        row[key]
                            ?.toString()
                            .toLowerCase()
                            .includes(String(value).toLowerCase()),
                    ),
            ),
        );

        const visibleIds = new Set<GridRowId>(visibleRows.map((row) => row.id));
        const hiddenMatches = matchingRows.filter((row) => !visibleIds.has(row.id));
        if (hiddenMatches.length > 0) {
            setPostFilterExpandedRows((previous) => {
                const next = new Set(previous);
                hiddenMatches.forEach((match) => {
                    groupedRows.forEach((candidate) => {
                        if (
                            candidate.parent &&
                            getChildRows(candidate.id).some(
                                (child) => child.id === match.id,
                            )
                        ) {
                            next.add(candidate.id);
                        }
                    });
                });
                // Keep the old Set when nothing was added. Otherwise a match
                // with no parent would re-trigger this cycle forever.
                return next.size === previous.size ? previous : next;
            });
        }

        // TODO: Include parents & children in filter results
        return newFilterModel;
    }

    // Consumer columns. Flex columns get the `flex-cell` class so the grid's
    // `.flex-cell` rule can let them shrink; other columns keep their own
    // `cellClassName` untouched.
    const modifiedColumns: GridColDef[] = columns.map((column) =>
        column.flex ? { ...column, cellClassName: "flex-cell" } : column,
    );
    const groupedColumns: GridColDef[] = [
        {
            ...GRID_CHECKBOX_SELECTION_COL_DEF,
            // Checkbox driven by `selectedItems`, not the grid's own selection.
            renderCell: (params) => (
                <Checkbox
                    checked={selectedItems.has(params.id)}
                    onChange={(e, checked) => {
                        handleRowSelection(params.id, checked);
                    }}
                    indeterminate={getIndeterminateState(params.id)}
                />
            ),
        },
        {
            field: "id",
            headerName: "ID",
            type: "number",
            resizable: false,
            sortable: false,
            filterable: false,
            hideable: false,
        },
        ...modifiedColumns,
        {
            field: "parent",
            headerName: "Parent",
            type: "boolean",
            resizable: false,
            sortable: false,
            filterable: false,
            hideable: false,
        },
        {
            field: "group",
            headerName: "Group",
            type: "string",
            resizable: false,
            sortable: false,
            filterable: false,
            hideable: false,
        },
        {
            field: "sortOrder",
            type: "number",
            resizable: false,
            sortable: false,
            filterable: false,
            hideable: false,
        },
        {
            field: "expand",
            headerName: "",
            type: "actions",
            width: 48,
            resizable: false,
            sortable: false,
            filterable: false,
            hideable: false,
            renderCell: (params) => {
                const isExpanded = expandedRows.has(params.row.id);
                return (
                    <>
                        {params.row.parent && (
                            <IconButton onClick={() => handleExpandClick(params.row.id)}>
                                <GridExpandMoreIcon
                                    sx={{
                                        transform: isExpanded ? "rotate(180deg)" : "rotate(0deg)",
                                    }}
                                />
                            </IconButton>
                        )}
                    </>
                );
            },
        },
    ];

    // Internal bookkeeping columns are always hidden, plus `hiddenColumns`.
    const columnVisibilityModel = useMemo(() => {
        const model: { [field: string]: boolean } = {
            id: false,
            parent: false,
            group: false,
            sortOrder: false,
        };
        hiddenColumns?.forEach((field) => {
            model[field] = false;
        });
        return model;
    }, [hiddenColumns]);

    // The grid always sorts by the computed `sortOrder`. A stable reference
    // avoids re-applying the sort model on every render.
    const sortModel = useMemo(
        () => [{ field: "sortOrder", sort: "asc" as const }],
        [],
    );

    // Follow changes to the `defaultSort` prop.
    useEffect(() => {
        setSortState((previous) =>
            previous.field === defaultSort.field &&
            previous.order === defaultSort.order
                ? previous
                : { field: defaultSort.field, order: defaultSort.order },
        );
    }, [defaultSort.field, defaultSort.order]);

    // Mirror `selectedItems` onto the grid's selection model for visible rows.
    useEffect(() => {
        const visibleIds = new Set<GridRowId>(visibleRows.map((row) => row.id));
        const nextSelectionModel = new Set<GridRowId>(rowSelectionModel);
        visibleIds.forEach((id) => {
            if (selectedItems.has(id)) nextSelectionModel.add(id);
            else nextSelectionModel.delete(id);
        });
        selectedItems.forEach((id) => {
            if (!visibleIds.has(id)) nextSelectionModel.delete(id);
        });
        setRowSelectionModel(Array.from(nextSelectionModel));
        // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [visibleRows, selectedItems]);

    // Re-run the quick-filter expansion whenever the displayed rows change.
    useEffect(() => {
        if (filterModel) {
            setFilterModel(filterRows(filterModel));
        }
        // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [visibleRows]);

    // Report the selection, leaving parents out unless `countParents` is set.
    useEffect(() => {
        const selectedIds = Array.from(selectedItems);
        updateSelectedRows(
            countParents
                ? selectedIds
                : selectedIds.filter((id) => {
                      const row = rowsById.get(id);
                      return row !== undefined && !row.parent;
                  }),
        );
        // `groupedRows` is deliberately not a dependency: callers often pass a
        // fresh array each render, and re-reporting on every render could loop
        // when `updateSelectedRows` itself re-renders the parent.
        // eslint-disable-next-line react-hooks/exhaustive-deps
    }, [selectedItems, updateSelectedRows, countParents]);

    // Apply filter-driven expansions, and undo them when the quick filter is
    // cleared.
    useEffect(() => {
        if (postFilterExpandedRows.size === 0) return;
        const filterCleared = filterModel?.quickFilterValues?.length === 0;
        setExpandedRows((previous) => {
            const next = new Set(previous);
            postFilterExpandedRows.forEach((rowId) => {
                if (filterCleared) next.delete(rowId);
                else next.add(rowId);
            });
            // Only adds or only deletes happen, so equal size means no change.
            return next.size === previous.size ? previous : next;
        });
        if (filterCleared) {
            setPostFilterExpandedRows(new Set());
        }
    }, [postFilterExpandedRows, filterModel]);

    const isFiltered = (filterModel?.quickFilterValues?.length ?? 0) > 0;

    // Evaluated inside the pagination slot (during the grid's render) so the
    // filtered count reflects the grid's current filtered rows.
    const countDisplayedItems = () => {
        const isChildRow = (id: GridRowId) => {
            const row = rowsById.get(id);
            return row !== undefined && !row.parent;
        };
        if (isFiltered) {
            const entries = gridExpandedSortedRowEntriesSelector(gridApiRef);
            return countParents
                ? entries.length
                : entries.filter((entry) => isChildRow(entry.id)).length;
        }
        return countParents
            ? groupedRows.length
            : groupedRows.filter((row) => !row.parent).length;
    };

    return (
        <StyledDataGrid
            apiRef={gridApiRef}
            maxDepth={maxDepth}
            indentChildren={indentChildren}
            columns={groupedColumns}
            rows={visibleRows}
            checkboxSelection
            disableRowSelectionOnClick
            hideFooterSelectedRowCount
            filterModel={filterModel}
            disableColumnFilter
            onFilterModelChange={(newFilterModel) =>
                setFilterModel(filterRows(newFilterModel))
            }
            rowSelectionModel={rowSelectionModel}
            onRowSelectionModelChange={(newSelection) => {
                // Only the header checkbox reaches here: clearing it clears
                // everything; checking it selects every row, hidden ones too.
                if (newSelection.length === 0) {
                    setSelectedItems(new Set());
                } else if (newSelection.length === visibleRows.length) {
                    setSelectedItems(
                        new Set<GridRowId>(groupedRows.map((row) => row.id)),
                    );
                }
            }}
            sortModel={sortModel}
            onSortModelChange={(newSortModel) => {
                // An empty model (sort removed) falls back to `defaultSort`.
                const [primary] = newSortModel;
                setSortState(
                    primary
                        ? { field: primary.field, order: primary.sort }
                        : { field: defaultSort.field, order: defaultSort.order },
                );
            }}
            columnVisibilityModel={columnVisibilityModel}
            initialState={{
                pagination: {
                    paginationModel: {
                        page: 0,
                        pageSize: 10,
                    },
                },
            }}
            pageSizeOptions={[10]}
            slots={{
                ...dataGridSlots,
                pagination: () => (
                    <Pagination
                        rowCount={countDisplayedItems()}
                        itemLabel={itemLabel}
                        filtered={isFiltered}
                    />
                ),
            }}
            getRowClassName={(params) => {
                const rowGroups = getRowGroups(params.id);
                if (
                    !nestedGroups &&
                    params.row.group !== "ungrouped" &&
                    !params.row.parent
                ) {
                    return "child-row level-2";
                } else if (
                    rowGroups !== undefined &&
                    rowGroups.length > 1 &&
                    nestedGroups &&
                    params.row.parent
                ) {
                    return "child-row level-" + rowGroups.length;
                } else if (
                    nestedGroups &&
                    !params.row.parent &&
                    rowGroups !== undefined &&
                    params.row.group !== "ungrouped"
                ) {
                    return "child-row level-" + (rowGroups.length + 1);
                } else {
                    return "parent-row";
                }
            }}
            {...dataGridPropOverrides}
        />
    );
};

export default GroupedDataGrid;
