import { Group } from '@visx/group';
import { Threshold } from '@visx/threshold';
import React, { cloneElement, FC, memo, ReactElement, useContext, useMemo } from 'react';
import { ChartContext } from '../ChartContext';
import { Interpolation, SeriesOptions } from '../chartTypes';
import { callOrValue, componentName } from '../utils/chartUtils';
import { getChildSeries, interpolatorLookup, useAccessors, ChartAccessor } from '../utils';
import { AreaSeriesProps } from './AreaSeries';

/** Fallback fill opacity for each side of the difference area. */
const DEFAULT_OPACITY = 0.4;

/**
 * Props accepted by `AreaDifferenceSeries`.
 *
 * Every series option is accepted except `data`. The plotted data is taken
 * from the two `AreaSeries` children instead.
 */
interface AreaDifferenceSeriesProps<T extends object> extends Omit<SeriesOptions<T>, 'data'> {
    /** Series id; checked against legend state and used to derive ids/keys for the child series. */
    id: string;
    /** Expected to hold exactly two `AreaSeries` elements; any other elements are rendered as-is. */
    children: ReactElement[];
    /** Curve interpolation for the difference area, also forwarded to the child series. */
    interpolation?: Interpolation;
}

/**
 * Fills the area between two `AreaSeries` children using visx `Threshold`.
 *
 * Points of the two children are paired by index. The first child's y value
 * becomes `y0` and the second child's becomes `y1`. The first child's
 * fill/opacity go to `aboveAreaProps` and the second child's to
 * `belowAreaProps`. The children themselves are then re-rendered with a
 * transparent fill, because `Threshold` draws no lines.
 *
 * Exactly two `AreaSeries` children are expected; a warning is logged
 * otherwise. Missing children or data produce an empty difference area
 * instead of a crash. Non-series children are rendered first, unchanged.
 *
 * Returns `null` when no x/y scale is available or the legend marks `id`
 * as hidden.
 *
 * @category Component
 * @group Chart
 */
let AreaDifferenceSeries = <T extends object>({
    id,
    axis,
    disableMouseEvents,
    interpolation = 'monotoneX',
    onClick,
    onMouseMove,
    onMouseLeave,
    children,
}: AreaDifferenceSeriesProps<T>) => {
    const { legend, xScales, yScales, getX: globalGetX, getY: globalGetY } = useContext(ChartContext);

    const [childSeries, restChilds] = getChildSeries<AreaSeriesProps>(children);
    const [series1, series2] = childSeries;

    if (
        childSeries.length !== 2 ||
        componentName(series1) !== 'AreaSeries' ||
        componentName(series2) !== 'AreaSeries'
    ) {
        console.warn('AreaDifferenceSeries expects exactly two AreaSeries children');
    }

    // Both children are read defensively. After the warning above we keep rendering,
    // and every hook below must still run unconditionally, so a missing child must
    // not throw here.
    const {
        data: data1,
        fill: fill1Prop,
        fillOpacity: opacity1Prop,
        xAccessor: xAccessor1,
        yAccessor: yAccessor1,
    } = series1?.props ?? {};

    const {
        data: data2,
        fill: fill2Prop,
        fillOpacity: opacity2Prop,
        xAccessor: xAccessor2,
        yAccessor: yAccessor2,
    } = series2?.props ?? {};

    const fill1Value = callOrValue(fill1Prop, data1);
    const fill2Value = callOrValue(fill2Prop, data2);

    const opacity1Value = callOrValue(opacity1Prop, data1);
    const opacity2Value = callOrValue(opacity2Prop, data2);

    // Accessors for reading the raw child data (each child's own accessor, falling back to the chart's).
    const { getX: getFirstX, getY: getFirstY } = useAccessors(xAccessor1, globalGetX, yAccessor1, globalGetY);
    const { getY: getSecondY } = useAccessors(xAccessor2, globalGetX, yAccessor2, globalGetY);

    const group = axis ?? 'undefined';
    const xScale = useMemo(() => xScales[group] ?? xScales.undefined, [group, xScales]);
    const yScale = useMemo(() => yScales[group] ?? yScales.undefined, [group, yScales]);

    // Accessors keyed by the field names of `mergedData` below ('x', 'y0', 'y1').
    // NOTE(review): presumably useAccessors turns these string keys into property lookups — confirm.
    const { getX, getY: getY0 } = useAccessors('x', globalGetX, 'y0', globalGetY);
    const { getY: getY1 } = useAccessors('x', globalGetX, 'y1', globalGetY);

    // Scale the merged point fields; fall back to 0 when a scale or accessor is absent.
    const x = useMemo(
        (): ChartAccessor<any> =>
            (...args) =>
                xScale?.(getX?.(...args)) ?? 0,
        [getX, xScale]
    );
    const y0 = useMemo(
        (): ChartAccessor<any> =>
            (...args) =>
                yScale?.(getY0?.(...args)) ?? 0,
        [getY0, yScale]
    );
    const y1 = useMemo(
        (): ChartAccessor<any> =>
            (...args) =>
                yScale?.(getY1?.(...args)) ?? 0,
        [getY1, yScale]
    );

    // All hooks have run; early exit is safe from here on.
    if (!xScale || !yScale || legend.state[id]) return null;

    const curve = interpolatorLookup[interpolation];
    const yExtent = yScale.range();

    // Points are paired by index, so only indices present in BOTH series can form the
    // difference area. Without this, a shorter `data2` would pass `undefined` points to
    // `getSecondY`, and a missing child would break `map`.
    const firstData = data1 ?? [];
    const secondData = data2 ?? [];
    const pairedLength = Math.min(firstData.length, secondData.length);
    const mergedData = firstData.slice(0, pairedLength).map((point, index) => ({
        x: getFirstX?.(point),
        y0: getFirstY?.(point),
        y1: getSecondY?.(secondData[index]),
    }));

    return (
        <Group>
            {restChilds}
            <Threshold
                id={id}
                data={mergedData}
                x={x as any}
                y0={y0 as any}
                y1={y1 as any}
                // Clip to both ends of the y scale's range, whichever direction it runs.
                clipAboveTo={Math.min(...yExtent)}
                clipBelowTo={Math.max(...yExtent)}
                curve={curve}
                // `??` (not `||`) so an explicit fillOpacity of 0 is respected.
                aboveAreaProps={{
                    fill: fill1Value,
                    fillOpacity: opacity1Value ?? DEFAULT_OPACITY,
                }}
                belowAreaProps={{
                    fill: fill2Value,
                    fillOpacity: opacity2Value ?? DEFAULT_OPACITY,
                }}
            />
            {/* Threshold series do NOT plot lines, so render the area series without fill */}
            {childSeries.map((Child, index) =>
                cloneElement(Child, {
                    id: Child.props.id ?? `${id}-child-series-${index}`,
                    key: `${id}-child-series-${index}`,
                    onClick,
                    onMouseMove,
                    onMouseLeave,
                    interpolation,
                    disableMouseEvents: Child.props.disableMouseEvents || disableMouseEvents,
                    fill: 'transparent',
                })
            )}
        </Group>
    );
};

// Name the un-memoized component before wrapping it, so the inner component carries
// the name. The `as FC` cast gives access to `displayName` on the generic arrow function.
// NOTE(review): presumably `componentName` / React DevTools rely on this — confirm.
(AreaDifferenceSeries as FC).displayName = 'AreaDifferenceSeries';
// Re-bind to the memoized component; this reassignment is why it is declared with `let`.
AreaDifferenceSeries = memo(AreaDifferenceSeries);

export { AreaDifferenceSeries };
