
import { useState } from 'react'

import { useCallbackRef } from 'hooks/callback-ref'


/**
 * Scrolls the closest `.overflow-y-scroll` ancestor of `rootElem.current`
 * back to the top, falling back to the window when no such ancestor exists.
 * Does nothing when the target is already at the top.
 *
 * @param {{ current?: Element | null } | undefined} rootElem - Ref-like object.
 */
function scrollToTopOfContainer(rootElem) {
  const scrollable = rootElem?.current?.closest('.overflow-y-scroll')
  const isScrolled = scrollable ? scrollable.scrollTop > 0 : window.scrollY > 0
  if (!isScrolled) {
    return
  }
  const target = scrollable || window
  target.scrollTo({
    top: 0,
    left: 0,
    behavior: 'smooth',
  })
}

/**
 * Token-based pagination hook that caches every row fetched so far, so that
 * moving backwards (and forwards again over already-fetched pages) needs no
 * additional fetch.
 *
 * @param {Function} fetcher - Called as `fetcher({ pageToken, pageSize })`;
 *   its (awaited) result is read as `{ rows, nextPageToken }`. An empty or
 *   missing `nextPageToken` means there is nothing more to fetch.
 *   It is wrapped with `useCallbackRef` and invoked through `.current`
 *   (presumably so the latest fetcher is always used - confirm in
 *   `hooks/callback-ref`).
 * @param {object} [options]
 * @param {number} [options.pageSize] - Rows per page; any falsy value means 50.
 *   NOTE(review): page boundaries are computed as `page * pageSize`, which
 *   assumes the fetcher returns full pages except possibly the last one.
 * @param {{ current?: Element | null }} [options.rootElem] - Ref used to find
 *   the scroll container to reset after a page change.
 * @returns {{
 *   nextPage: Function, prevPage: Function, reset: Function,
 *   initialLoad: Function, hasPrevPage: boolean, hasNextPage: boolean,
 *   loading: boolean, rows: Array, range: { start: number, end: number }
 * }} `rows` is the current page, `range` its [start, end) offsets within all
 *   fetched rows. The action functions start the load and return nothing.
 */
export function useInfinitePagination(fetcher, { pageSize, rootElem } = {}) {
  if (!pageSize) {
    pageSize = 50
  }
  fetcher = useCallbackRef(fetcher)

  let [allRows, setAllRows] = useState([])
  let [currentRows, setCurrentRows] = useState([])
  let [page, setPage] = useState(0)
  let [loading, setLoading] = useState(true)
  let [range, setRange] = useState({
    start: 0,
    end: 0,
  })
  let [nextPageToken, setNextPageToken] = useState('')
  let [initial, setInitial] = useState(true)

  // True when the page after the current one is already (at least partly)
  // in the cache and can be shown without fetching.
  const hasCachedNextPage = (page + 1) * pageSize < allRows.length

  /**
   * Shows page `newPage`, fetching more rows with `pageToken` when the page
   * starts beyond the cached rows.
   *
   * `cachedRows` is passed explicitly by `reset()`: `allRows` captured by
   * this closure is still the pre-reset value right after `setAllRows([])`.
   */
  async function loadPage(newPage, pageToken, cachedRows = allRows) {
    setLoading(true)
    try {
      const start = newPage * pageSize

      // Fetch new rows only if we are moving past what is already cached.
      let rows = cachedRows
      if (start >= rows.length) {
        const reply = await fetcher.current({ pageToken, pageSize })
        setNextPageToken(reply?.nextPageToken || '')
        rows = rows.concat(reply?.rows || [])
        setAllRows(rows)
      }

      // Publish the new page and its range of rows.
      const end = Math.min((newPage + 1) * pageSize, rows.length)
      setCurrentRows(rows.slice(start, end))
      setPage(newPage)
      setRange({
        start,
        end,
      })

      // Scroll to the top of the nearest scrollable ancestor (or the window).
      scrollToTopOfContainer(rootElem)
    } finally {
      // Always clear the flag, otherwise a failed fetch would block
      // nextPage()/prevPage() forever.
      setLoading(false)
    }
  }

  /** Advances one page if a cached or fetchable next page exists. */
  function nextPage() {
    if (loading) {
      return
    }
    if (!hasCachedNextPage && !nextPageToken) {
      return
    }
    loadPage(page + 1, nextPageToken)
  }

  /** Goes back one page; served from the cache, never fetches. */
  function prevPage() {
    if (loading) {
      return
    }
    if (page > 0) {
      loadPage(page - 1, nextPageToken)
    }
  }

  /** Drops all cached rows and reloads the first page from scratch. */
  function reset() {
    setAllRows([])
    setNextPageToken('')
    loadPage(0, '', [])
  }

  /**
   * Loads the first page. Must be called exactly once.
   *
   * @throws {Error} When called a second time.
   */
  function initialLoad() {
    if (!initial) {
      throw new Error(`initialLoad() must be called once only`)
    }
    setInitial(false)
    loadPage(0, '')
  }

  return {
    nextPage,
    prevPage,
    hasPrevPage: page > 0,
    hasNextPage: hasCachedNextPage || !!nextPageToken,
    loading,
    reset,
    rows: currentRows,
    range,
    initialLoad,
  }
}
