import { useCallback, useEffect, useRef, useState } from 'react';
import { isFunction } from 'lodash';

/**
 * Callback passed to `setState`; invoked once the update has been committed.
 *
 * @param prevState - state the update was applied to.
 * @param nextState - merged state produced by the update.
 */
type SetStateCallback<T> = (prevState: T, nextState: T) => void;

/** A `setState` callback waiting for its update to be processed and committed. */
type PendingCallback<T> = {
  callback: SetStateCallback<T>;
  // Written by the state updater whenever React runs it; undefined until then.
  states?: { prevState: T; nextState: T };
};

/**
 * `useState` for object state with shallow-merge updates, in the spirit of a
 * class component's `this.setState`.
 *
 * @param initialState - initial state object; defaults to `{}`.
 * @returns A tuple `[state, setState]`.
 *   - `patch` is either a partial object or a function of the previous state
 *     returning one. It is shallow-merged into the previous state.
 *   - The optional `callback` receives `(prevState, nextState)` after the
 *     render containing the update has committed.
 *   - Callbacks still pending when the component unmounts are not called.
 *   - `setState` keeps the same identity across renders.
 */
const useSetState = <T extends Partial<Record<keyof T, unknown>>>(
  initialState: T = {} as T
): [
  T,
  (
    patch: Partial<T> | ((prevState: T) => Partial<T>),
    callback?: SetStateCallback<T>
  ) => void,
] => {
  const [state, set] = useState<T>(initialState);
  // Callbacks queued by setState, in call order, awaiting their commit.
  const pendingRef = useRef<PendingCallback<T>[]>([]);

  // No dependency list: runs after every commit of this component.
  useEffect(() => {
    const queued = pendingRef.current;
    if (queued.length === 0) {
      return;
    }
    // Entries whose updater has not run yet belong to a later render; keep them.
    const ready = queued.filter((entry) => entry.states !== undefined);
    pendingRef.current = queued.filter((entry) => entry.states === undefined);
    for (const { callback, states } of ready) {
      if (states) {
        callback(states.prevState, states.nextState);
      }
    }
  });

  const setState = useCallback(
    (
      patch: Partial<T> | ((prevState: T) => Partial<T>),
      callback?: SetStateCallback<T>
    ) => {
      const pending: PendingCallback<T> | undefined = callback
        ? { callback }
        : undefined;
      if (pending) {
        pendingRef.current.push(pending);
      }
      set((prevState) => {
        const nextState: T = {
          ...prevState,
          ...(isFunction(patch) ? patch(prevState) : patch),
        };
        // Updaters must be pure: React may call this more than once
        // (StrictMode) or throw a render away, so only record the result
        // here (last write wins). The callback itself runs from the effect
        // above, after commit.
        if (pending) {
          pending.states = { prevState, nextState };
        }
        return nextState;
      });
    },
    []
  );

  return [state, setState];
};

export default useSetState;
