import { scaleBand, ScaleConfig, scaleLinear, scaleOrdinal, scaleTime, scaleUtc } from '@visx/scale';
import { Accessor } from '@visx/shape/lib/types';
import { extent } from 'd3-array';
import { Scale, ScaleParams } from '../chartTypes';
import { DataPoint } from './getDataFromChildSeries';

/**
 * Inclusive value window used by `getScaleForAccessor` to restrict which data
 * points contribute to the computed domain: a point is kept when its
 * `filterMinAccessor` value is >= `start` and its `filterMaxAccessor` value is <= `end`.
 */
export type DomainFilter = {
    /** Inclusive lower bound, compared against `filterMinAccessor(point)`. */
    start: number;
    /** Inclusive upper bound, compared against `filterMaxAccessor(point)`. */
    end: number;
};

/**
 * Lookup table from a `ScaleParams['type']` value to the visx scale factory that
 * builds it. Each factory is cast to the shared `Scale` type so that all of them
 * can live in a single record keyed by scale type.
 */
export const scaleTypeToScale: Record<ScaleParams['type'], Scale> = {
    // 'time' and 'timeUtc' differ only in which visx factory (local vs UTC) is used.
    time: scaleTime as Scale,
    timeUtc: scaleUtc as Scale,
    linear: scaleLinear as Scale,
    band: scaleBand as Scale,
    ordinal: scaleOrdinal as Scale
};

/**
 * Options accepted by `getScaleForAccessor`: the chart-level `ScaleParams` plus
 * visx `ScaleConfig` options other than `type`, `domain`, `range` and `rangeRound`.
 * Options not consumed by `getScaleForAccessor` itself are passed through to the
 * selected scale factory.
 */
interface Params extends Omit<ScaleConfig, 'type' | 'domain' | 'range' | 'rangeRound'>, ScaleParams {
    /** Points the domain is derived from when no explicit `domain` is given. */
    data: DataPoint[];
    /** Optional inclusive window used to exclude points before the domain is computed. */
    filterDomain?: DomainFilter;
    /** Reads the per-point value compared against `filterDomain.start`. */
    filterMinAccessor?: Accessor<any, any>;
    /** Reads the per-point value compared against `filterDomain.end`. */
    filterMaxAccessor?: Accessor<any, any>;
    /** Reads each point's lower value; for band/ordinal scales it supplies the categories. */
    minAccessor: Accessor<any, any>;
    /** Reads each point's upper value; only used for continuous (linear/time) scales. */
    maxAccessor: Accessor<any, any>;
}

/**
 * Widens a `[min, max]` domain so that it always contains 0.
 *
 * Missing bounds (d3 `extent` returns `[undefined, undefined]` for empty input)
 * are treated as 0, so an empty domain becomes `[0, 0]` instead of `[NaN, NaN]`.
 *
 * @param domain - domain whose first two entries are the lower and upper bound;
 *   any further entries are ignored.
 * @returns a new `[min, max]` array; the input is not mutated.
 */
function applyIncludeZero([min, max]: readonly any[]): number[] {
    return [Math.min(0, min ?? 0), Math.max(0, max ?? 0)];
}

/**
 * Builds a visx scale whose domain is derived from `data` through the given
 * accessors, unless an explicit `domain` is supplied.
 *
 * Steps:
 * 1. If `filterDomain` is given, keep only points whose `filterMinAccessor`
 *    value is >= `start` and whose `filterMaxAccessor` value is <= `end`.
 * 2. Without an explicit `domain`:
 *    - band/ordinal scales use `minAccessor` of every remaining point as the categories;
 *    - linear/time/timeUtc scales use the overall extent of the `minAccessor`
 *      and `maxAccessor` values.
 * 3. Linear scales are widened to include 0 when `includeZero` is true (the default).
 * 4. For continuous scales, `domainPadding` `[start, end]` extends the bounds:
 *    by that many days for time scales, or by that amount for linear scales.
 *
 * The caller's `data` and `domain` are never mutated; padding produces a new domain array.
 *
 * @returns the scale produced by the factory registered for `type`, configured
 *   with the computed domain, `range`, and all remaining options.
 */
export function getScaleForAccessor({
    data = [],
    minAccessor,
    maxAccessor,
    type = 'linear',
    includeZero = true,
    range = [] as any[],
    filterDomain = undefined,
    filterMinAccessor = undefined,
    filterMaxAccessor = undefined,
    domain: domainProp,
    domainPadding,
    ...rest
}: Params): Scale {
    let filteredData = data;

    // `start` may legitimately be 0, so test for presence rather than truthiness.
    if (filterDomain?.start != null) {
        const { start, end } = filterDomain;
        filteredData = data.filter((point) => {
            const minValue = filterMinAccessor?.(point);
            const maxValue = filterMaxAccessor?.(point);

            return minValue >= start && maxValue <= end;
        });
    }

    const isContinuous = type === 'linear' || type === 'time' || type === 'timeUtc';
    let domain = domainProp;

    if (!domain && (type === 'band' || type === 'ordinal')) {
        domain = filteredData.map(minAccessor);
    }

    if (!domain && isContinuous) {
        domain = extent([...extent(filteredData, minAccessor), ...extent(filteredData, maxAccessor)]);
    }

    if (type === 'linear' && includeZero) {
        domain = applyIncludeZero(domain);
    }

    // Padding only makes sense for continuous domains; arithmetic on band/ordinal
    // categories would turn them into NaN.
    if (domainPadding && isContinuous && domain?.length === 2) {
        const [padStart, padEnd] = domainPadding;
        const [lower, upper] = domain;

        if (type === 'time' || type === 'timeUtc') {
            // Copy first: `extent` returns the Date instances stored in `data`, and a
            // `domain` prop belongs to the caller, so mutating them in place leaks out.
            const paddedLower = new Date(lower);
            const paddedUpper = new Date(upper);
            if (padStart) paddedLower.setDate(paddedLower.getDate() - padStart);
            if (padEnd) paddedUpper.setDate(paddedUpper.getDate() + padEnd);
            domain = [paddedLower, paddedUpper];
        } else {
            domain = [padStart ? lower - padStart : lower, padEnd ? upper + padEnd : upper];
        }
    }

    return scaleTypeToScale[type]({ domain, range, ...rest });
}
