import styled from '@emotion/styled/macro'
import distinctColors from 'distinct-colors'
import { PlotData } from 'plotly.js'
import React, { useCallback, useMemo, useState } from 'react'
import Plot from 'react-plotly.js'
import Select from '../../../../components/select'
import { BiquadCoeffs } from '../../../../types/biquad'
import { BiquadStage } from '../../../../types/dbmc2evb2'
import { calculate, logFrequency, toFloat } from '../../../../utils/biquad'
import { baseMeta, FullPlotlyMeta } from '../../../../utils/plotly'
import useLocalStorage from '../../../../hooks/use-local-storage'

/**
 * A selectable sample rate. `value` is used as the sampling frequency (fs)
 * when a stage's response is evaluated.
 */
type Rate = { label: string; value: number }

/** Kind of frequency-response plot: magnitude (`'gain'`) or `'phase'`. */
export type ResponseType = 'gain' | 'phase'

interface Props {
  /** Sample rates the filters can be viewed at. */
  rates: Rate[]
  /** Channels whose filters can be viewed; matched by `value`. */
  channels: { label: string; value: number }[]
  /**
   * Optional third selection dimension; matched by `value`. When omitted, the
   * viewer offers one chart per (rate, channel) pair and indexes `stages` by
   * rate and channel only.
   */
  domains?: { label: string; value: string }[]
  /**
   * Biquad stages per chart, indexed as `stages[rate][channel][domain]` (each
   * leaf is the list of stages that are cascaded for that chart).
   * NOTE(review): without `domains` the viewer reads `stages[rate][channel]`
   * as the stage list, which this type does not express — confirm callers.
   */
  stages: BiquadStage[][][][]
  /** Suffix for the local-storage key that persists the selected charts. */
  persistenceKey: string
}

// Plot canvas sizes in px. CHART_WIDTH is only the default width; the viewer
// passes the measured width of its plot container when rendering.
const CHART_WIDTH = 860
const GAIN_CHART_HEIGHT = 380
const PHASE_CHART_HEIGHT = 280

// Number of frequency samples handed to logFrequency for every response curve.
const NUMBER_OF_POINTS_IN_CHART = 512

// Palette of 20 visually distinct colours; each selected chart is coloured by
// its position in the selection.
const COLORS = distinctColors({ count: 20 })

/**
 * Evaluates one biquad stage at sample rate `fs` over the chart's frequency
 * grid: NUMBER_OF_POINTS_IN_CHART points from 10 Hz to 20 kHz, produced by
 * `logFrequency`.
 *
 * The coefficients are converted with `toFloat(stage).unwrap()`; presumably
 * `unwrap` throws when the conversion fails — NOTE(review): confirm.
 *
 * @param fs    sampling frequency passed to `calculate`
 * @param stage biquad coefficients of the stage
 * @returns the sampled frequency grid with the stage's gain and phase
 */
const calculateStage = (
  fs: number,
  stage: BiquadCoeffs
): { frequency: number[]; gain: number[]; phase: number[] } => {
  const coefficients = toFloat(stage).unwrap()
  const frequencies = logFrequency(10, 20000, NUMBER_OF_POINTS_IN_CHART)
  return calculate(fs, coefficients, frequencies)
}

/** Frequency response sampled on a shared frequency grid, as produced by `calculateStage`. */
type FrequencyResponse = { frequency: number[]; gain: number[]; phase: number[] }

/** One plottable (rate, channel[, domain]) combination; `domain` is absent when the viewer has no domains. */
interface ChartSelection {
  rate: { label: string; value: number }
  channel: { label: string; value: number }
  domain?: { label: string; value: string }
}

/** Option shown in the chart selector; the selected options are what gets persisted. */
interface ChartOption {
  label: string
  value: string
  chart: ChartSelection
}

/**
 * Returns the index of the first entry whose `value` is strictly equal to
 * `value`, or -1 when there is no match or `options` is absent.
 */
function findIndexByValue<T>(
  options: ReadonlyArray<{ value: T }> | null | undefined,
  value: T
): number {
  return options ? options.findIndex(option => option.value === value) : -1
}

/**
 * Sums per-stage responses point by point to obtain the response of the whole
 * cascade. Gain and phase are both added, which presumes gain is in dB (the
 * gain axis is labelled "gain (dB)") — NOTE(review): confirm against `calculate`.
 *
 * All responses are assumed to share the first response's frequency grid.
 * An empty list yields empty arrays (nothing to plot) instead of throwing.
 */
const sumStageResponses = (
  responses: FrequencyResponse[]
): FrequencyResponse => {
  if (responses.length === 0) {
    return { frequency: [], gain: [], phase: [] }
  }

  const { frequency } = responses[0]
  // Accumulate from 0 in stage order so the floating-point sum is stable.
  const sumAt = (key: 'gain' | 'phase', i: number) =>
    responses.reduce((sum, response) => sum + response[key][i], 0)

  return {
    frequency,
    gain: frequency.map((_, i) => sumAt('gain', i)),
    phase: frequency.map((_, i) => sumAt('phase', i))
  }
}

/**
 * Colour for the chart at `index` in the selection. Cycles through COLORS so
 * selecting more charts than there are colours does not index past the end.
 */
const chartColor = (index: number): string =>
  COLORS[index % COLORS.length].hex()

/** Thin dotted line used for the individual stages of a chart. */
const stageLine = (index: number): PlotData['line'] => ({
  dash: 'dot',
  color: chartColor(index),
  width: 0.5
})

/** Thick solid line for a chart's total response; same colour as its stages. */
const totalLine = (index: number): PlotData['line'] => ({
  ...stageLine(index),
  dash: 'solid',
  width: 2
})

/**
 * Builds the gain trace and the phase trace for one response. Both traces
 * share the name, line style and legend flag. Each gets its own `hoverlabel`
 * object, and `namelength: -1` shows the full trace name on hover.
 */
const toTraces = (
  { frequency, gain, phase }: FrequencyResponse,
  name: string,
  line: PlotData['line'],
  showlegend: boolean
): { gain: Partial<PlotData>; phase: Partial<PlotData> } => {
  const trace = (y: number[]): Partial<PlotData> => ({
    x: frequency,
    y,
    mode: 'lines',
    name,
    line,
    showlegend,
    hoverlabel: {
      namelength: -1
    }
  })

  return { gain: trace(gain), phase: trace(phase) }
}

/**
 * Plotly props for the gain or phase chart: `baseMeta` with the chart's size,
 * axis titles and traces filled in.
 *
 * The x range [1, 4.301029995663981] is [log10(10), log10(20000)], i.e. the
 * 10 Hz–20 kHz grid used by `calculateStage`; presumably `baseMeta` configures
 * a log x axis — NOTE(review): confirm.
 */
const createPlotlyMeta = (
  type: 'gain' | 'phase',
  data: FullPlotlyMeta['data'],
  width = CHART_WIDTH
) =>
  ({
    ...baseMeta,
    layout: {
      ...baseMeta.layout,
      width,
      height: type === 'gain' ? GAIN_CHART_HEIGHT : PHASE_CHART_HEIGHT,
      xaxis: {
        ...baseMeta.layout.xaxis,
        range: [1, 4.301029995663981],
        title: 'frequency (Hz)',
        font: {
          size: 14
        }
      },
      yaxis: {
        ...baseMeta.layout.yaxis,
        autorange: true,
        title: type === 'gain' ? 'gain (dB)' : 'phase (rad)',
        font: {
          size: 14
        }
      }
    },
    data
  } as FullPlotlyMeta)

/**
 * Interactive viewer for the frequency response of cascaded biquad stages.
 *
 * Offers one selectable chart per (rate, channel[, domain]) combination. For
 * each selected chart it plots the total response as a solid line and every
 * individual stage as a dotted line, on a gain plot and a phase plot.
 *
 * The selection is persisted via `useLocalStorage` under
 * `filter-viewer-visible-charts-${persistenceKey}`. Persisted entries that no
 * longer match an available rate/channel/domain are hidden, not deleted.
 * Stages are read from `stages[rate][channel][domain]`, or from
 * `stages[rate][channel]` when `domains` is not provided.
 */
export default function FilterViewer({
  rates,
  channels,
  domains,
  stages,
  persistenceKey
}: Props) {
  // Label of a chart, e.g. "<rate> <channel> <domain>"; empty parts are skipped.
  const getTotalLabel = useCallback(
    (rateIndex: number, channelIndex: number, domainIndex: number) => {
      const parts = [rates[rateIndex].label, channels[channelIndex].label]

      // domainIndex is -1 when the chart has no domain.
      if (domains && domainIndex >= 0 && domainIndex < domains.length) {
        parts.push(domains[domainIndex].label)
      }

      return parts.filter(part => part != null && part.length > 0).join(' ')
    },
    [rates, channels, domains]
  )

  // Label of one stage within a chart: the chart label plus a 1-based stage number.
  const getStageLabel = useCallback(
    (
      rateIndex: number,
      channelIndex: number,
      domainIndex: number,
      stageIndex: number
    ) =>
      `${getTotalLabel(rateIndex, channelIndex, domainIndex)} (${stageIndex + 1})`,
    [getTotalLabel]
  )

  // Every selectable chart: the cartesian product of rates × channels (× domains).
  const chartOptions = useMemo(() => {
    const options: ChartOption[] = []

    rates.forEach((rate, i) => {
      channels.forEach((channel, j) => {
        if (domains) {
          domains.forEach((domain, k) => {
            const label = getTotalLabel(i, j, k)
            options.push({ label, value: label, chart: { rate, channel, domain } })
          })
        } else {
          // no domains specified
          const label = getTotalLabel(i, j, -1)
          options.push({ label, value: label, chart: { rate, channel } })
        }
      })
    })

    return options
  }, [rates, channels, domains, getTotalLabel])

  // slice(0, 1) yields [] rather than [undefined] when there are no options.
  const [visible, setVisible] = useLocalStorage(
    `filter-viewer-visible-charts-${persistenceKey}`,
    chartOptions.slice(0, 1)
  )

  // The persisted selection restricted to charts that exist for the current
  // props. Derived instead of mutating the stored state in place.
  const selectedCharts = useMemo(() => {
    const stored: ChartOption[] = Array.isArray(visible) ? visible : []

    return stored.filter(option => {
      if (!option || !option.chart) return false

      const { rate, channel, domain } = option.chart
      if (!rate || !channel) return false
      if (
        findIndexByValue(rates, rate.value) < 0 ||
        findIndexByValue(channels, channel.value) < 0
      ) {
        return false
      }

      // With domains configured, a chart must name an existing domain.
      if (!domains) return true
      return !!domain && findIndexByValue(domains, domain.value) >= 0
    })
  }, [visible, rates, channels, domains])

  // Gain and phase traces for every selected chart: the total first, then its stages.
  const data = useMemo(() => {
    const traces: {
      gain: Partial<PlotData>[]
      phase: Partial<PlotData>[]
    } = { gain: [], phase: [] }

    const addTraces = (pair: {
      gain: Partial<PlotData>
      phase: Partial<PlotData>
    }) => {
      traces.gain.push(pair.gain)
      traces.phase.push(pair.phase)
    }

    selectedCharts.forEach(({ chart: { rate, channel, domain } }, chartIndex) => {
      const rateIndex = findIndexByValue(rates, rate.value)
      const channelIndex = findIndexByValue(channels, channel.value)
      const domainIndex =
        domains && domain ? findIndexByValue(domains, domain.value) : -1

      // Without a domain, stages are indexed by rate and channel only.
      let stagesArray
      if (domainIndex >= 0) {
        stagesArray = stages[rateIndex][channelIndex][domainIndex]
      } else {
        stagesArray = stages[rateIndex][channelIndex]
      }

      // rate.value === rates[rateIndex].value by construction of rateIndex.
      const fs = rate.value
      const stagesData = stagesArray.map(stage => calculateStage(fs, stage))

      addTraces(
        toTraces(
          sumStageResponses(stagesData),
          getTotalLabel(rateIndex, channelIndex, domainIndex),
          totalLine(chartIndex),
          true
        )
      )

      stagesData.forEach((response, stageIndex) =>
        addTraces(
          toTraces(
            response,
            getStageLabel(rateIndex, channelIndex, domainIndex, stageIndex),
            stageLine(chartIndex),
            false
          )
        )
      )
    })

    return traces
  }, [
    selectedCharts,
    rates,
    channels,
    domains,
    stages,
    getTotalLabel,
    getStageLabel
  ])

  // The plot container's element; its width sizes the selector and both plots.
  const [plotSectionRef, setPlotSectionRef] = useState<HTMLDivElement | null>(
    null
  )

  return (
    <>
      {plotSectionRef && (
        <Select
          options={chartOptions}
          value={selectedCharts}
          isMulti
          // @ts-ignore
          onChange={e => {
            setVisible(e || [])
          }}
          styles={{
            container: baseStyles => ({
              ...baseStyles,
              marginLeft: 76
            }),
            control: baseStyles => ({
              ...baseStyles,
              width: plotSectionRef.clientWidth - 80,
              border: 0
            })
          }}
        />
      )}
      <PlotSection ref={setPlotSectionRef}>
        {plotSectionRef && (
          <>
            <ChartSection>
              <Plot
                {...createPlotlyMeta(
                  'gain',
                  data.gain,
                  plotSectionRef.clientWidth
                )}
              />
            </ChartSection>
            <ChartSection>
              <Plot
                {...createPlotlyMeta(
                  'phase',
                  data.phase,
                  plotSectionRef.clientWidth
                )}
              />
            </ChartSection>
          </>
        )}
      </PlotSection>
    </>
  )
}

// Wrapper around both response plots. `position: sticky` with `top: 100px`
// keeps it pinned 100px from the top of its scrolling ancestor while the page
// scrolls. FilterViewer reads this element's clientWidth to size the plots and
// the chart selector.
const PlotSection = styled.div`
  position: sticky;
  top: 100px;
`

// Wraps a single Plot and centres it horizontally by giving the Plotly-generated
// wrapper elements (.js-plotly-plot, .plot-container, .svg-container) automatic
// left/right margins.
const ChartSection = styled.div`
  .js-plotly-plot,
  .plot-container,
  .svg-container {
    margin-left: auto;
    margin-right: auto;
  }
`
