import {useCallback, useEffect, useRef} from 'react';
import {isFocusable, tabbable} from 'tabbable';
import type {FocusableElement} from 'tabbable';
import {v4 as uuid} from 'uuid';

/**
 * Module-level registry of focus-trap ids, kept in registration order.
 *
 * The most recently added id that has not been removed is the one that
 * `isLast` reports as current. Traps use this to decide which of them
 * handles keyboard events.
 */
const manager = (() => {
    const registered: string[] = [];

    // Removes the first occurrence of `item`; unknown ids are ignored.
    function remove(item: string): void {
        const position = registered.indexOf(item);
        if (position !== -1) {
            registered.splice(position, 1);
        }
    }

    // Appends `item` and returns a function that unregisters it again.
    function add(item: string): () => void {
        registered.push(item);
        return () => remove(item);
    }

    // True only when `item` is the most recently registered id.
    function isLast(item: string): boolean {
        if (registered.length === 0) {
            return false;
        }
        return registered[registered.length - 1] === item;
    }

    return {add, isLast, remove};
})();

/**
 * Finds the closest focus-trap root that contains `element`, or the element itself.
 *
 * A focus-trap root is an element whose `data-focus-root` attribute equals `"1"`.
 * The search starts at `element` and follows `parentElement` upwards.
 *
 * @param element - Element to start searching from; `null` yields `undefined`.
 * @returns The nearest marked element, or `undefined` if none is found.
 */
export const getRootOfElement = (element: Element | null) => {
    let current = element as HTMLElement | null;
    for (; current; current = current.parentElement) {
        if (current.dataset?.focusRoot === '1') {
            return current;
        }
    }
    return undefined;
};

/**
 * Moves focus into `rootElement`.
 *
 * Focuses the first tabbable descendant reported by `tabbable` (the container
 * itself is excluded), using `rootElement` as the fallback target. If there are
 * no tabbable descendants, it tries to focus `rootElement` itself.
 *
 * @param rootElement - Container that should receive focus.
 */
export const initializeFocus = (rootElement: HTMLElement) => {
    const [firstTabbable] = tabbable(rootElement, {includeContainer: false});
    if (firstTabbable) {
        focusElement(firstTabbable, rootElement);
    } else {
        focusElement(rootElement);
    }
};

/**
 * Moves focus to `element`, or to `fallback` when `element` is missing or has
 * no callable `focus` method.
 *
 * It does nothing when the chosen target already holds focus, or when neither
 * argument can be focused.
 *
 * @param element - Preferred element to focus.
 * @param fallback - Element to focus instead when `element` cannot be focused.
 */
const focusElement = (
    element?: FocusableElement,
    fallback?: HTMLElement,
): void => {
    if (element === document.activeElement) {
        return;
    }
    if (element && typeof element.focus === 'function') {
        element.focus();
        return;
    }
    // Try the fallback exactly once, with no further fallback. This keeps the
    // recursion depth bounded when the fallback is missing or unfocusable.
    if (fallback) {
        focusElement(fallback);
    }
};

/**
 * Builds a document-level Tab / Shift+Tab handler that keeps keyboard focus
 * cycling among the tabbable descendants of `parentElement`.
 *
 * The handler only acts while `id` is the most recently registered id in
 * `manager`, so only the innermost active trap reacts to Tab.
 *
 * @param parentElement - Container whose tabbable descendants form the cycle.
 * @param id - Identifier of this trap as registered with `manager`.
 * @returns A `[start, stop]` tuple: `start` adds the `keydown` listener on
 *   `document`, and `stop` removes it.
 */
export const focusTrap = (
    parentElement: HTMLElement,
    id: string,
): [() => void, () => void] => {
    const handleTabEvents = (e: KeyboardEvent) => {
        if (e.key?.toLowerCase() !== 'tab') {
            return;
        }
        // Only the most recently registered trap handles Tab.
        if (!manager.isLast(id)) {
            return;
        }

        const tabbableElements = tabbable(parentElement, {
            includeContainer: false,
        });

        if (tabbableElements.length === 0) {
            // Nothing to cycle through. Keep focus on the container and cancel
            // the browser's default Tab navigation, which would otherwise move
            // focus to the next tabbable element outside the trap.
            focusElement(parentElement);
            e.preventDefault();
            return;
        }

        const firstElement = tabbableElements[0];
        const lastElement = tabbableElements[tabbableElements.length - 1];

        if (e.shiftKey) {
            // Shift+Tab from the first tabbable element, or from the container
            // itself when it holds focus, wraps around to the last element.
            if (e.target === firstElement || e.target === parentElement) {
                focusElement(lastElement);
                e.preventDefault();
            }
        } else if (e.target === lastElement) {
            // Tab from the last tabbable element wraps around to the first.
            focusElement(firstElement);
            e.preventDefault();
        }
    };

    const start = () => {
        document.addEventListener('keydown', handleTabEvents);
    };
    const stop = () => {
        document.removeEventListener('keydown', handleTabEvents);
    };
    return [start, stop];
};

/**
 * React hook that traps keyboard focus inside a container while `active` is true.
 *
 * While the trap is active:
 * - `setFocusTrapRef`, used as the container's `ref`, marks the container with
 *   `data-focus-root="1"`.
 * - That ref also moves focus into the container if it is not already inside.
 * - An effect registers this trap with `manager` and installs the Tab handler
 *   from `focusTrap`.
 *
 * When `active` turns false or the component unmounts, the handler is removed
 * and focus is returned to the element that was focused when the trap became
 * active. This restore is skipped if focus is outside this trap's root (a
 * nested trap root counts as outside) and is on a focusable element.
 *
 * @param active - Whether the trap is enabled; defaults to `true`.
 * @returns `focusTrapRef`, the ref object holding the current container, and
 *   `setFocusTrapRef`, the callback ref that does the marking and initial focus.
 */
export const useFocusTrap = (active = true) => {
    // Stable id for this trap in `manager`. It is created lazily so uuid() is
    // not called, and its result discarded, on every render.
    const idRef = useRef('');
    if (!idRef.current) {
        idRef.current = uuid();
    }
    const elementToRestoreFocusRef = useRef<Element | null>(null);
    const rootRef = useRef<HTMLElement | null>();

    // Holds `active` as of the last committed render. Comparing it with the
    // current prop during render detects the inactive -> active transition.
    // Render runs before the ref callback below moves focus, so the previously
    // focused element can still be captured here.
    const prevOpenRef = useRef(false);
    useEffect(() => {
        prevOpenRef.current = active;
    }, [active]);
    if (!prevOpenRef.current && active) {
        elementToRestoreFocusRef.current = document.activeElement;
    }

    // Callback ref. It gets a new identity whenever `active` changes, so React
    // calls it again with the container on every activation change.
    const setRef = useCallback(
        (root: HTMLElement | null) => {
            rootRef.current = root;
            if (!active || !root) {
                return;
            }

            // Marker used by getRootOfElement to find the nearest trap root.
            root.setAttribute('data-focus-root', '1');

            // move the focus within this trap
            if (!root.contains(document.activeElement)) {
                initializeFocus(root);
            }
        },
        [active],
    );

    useEffect(() => {
        const root = rootRef.current;
        const id = idRef.current;

        if (!active || !root) {
            return undefined;
        }

        const [startFocusTrap, stopFocusTrap] = focusTrap(root, id);
        manager.add(id);
        startFocusTrap?.();

        // Captured now, because the ref is cleared during cleanup.
        const elementToRestoreFocus =
            elementToRestoreFocusRef.current as HTMLElement;

        return () => {
            manager.remove(id);
            stopFocusTrap?.();
            const activeElement = document.activeElement;
            // Nearest marked root of the current focus. If focus is inside a
            // nested trap root, this is that nested root, not `root`.
            const activeRoot = getRootOfElement(activeElement);
            const isFocusInRoot = activeRoot && activeRoot === root;
            const isActiveElementFocusable =
                activeElement && isFocusable(activeElement);

            elementToRestoreFocusRef.current = null;
            // Leave focus alone when there is nothing to restore, or when focus
            // already sits on a focusable element outside this trap's root.
            if (
                !elementToRestoreFocus ||
                (!isFocusInRoot && isActiveElementFocusable)
            ) {
                return;
            }
            elementToRestoreFocus.focus?.();
        };
    }, [active]);

    return {focusTrapRef: rootRef, setFocusTrapRef: setRef};
};
