import * as d3 from "d3";
import moment from "moment";

/**
 * Axis orientation understood by `D3ChartBase`'s axis helpers.
 *
 * `XAxis` axes are rendered with `d3.axisBottom` along the bottom edge of the
 * plot area; `YAxis` axes with `d3.axisLeft` along the left edge.
 * Numeric values (0 and 1) are implicit and must stay stable for callers.
 */
export enum AxisType {
    XAxis,
    YAxis,
}

/**
 * Small drawing toolkit over a d3 selection (`root`) that renders common chart
 * pieces — background, band/linear axes, grid lines and labelled reference
 * lines — into a plot area of `width` x `height` SVG user units.
 *
 * Every `draw*` method appends a fresh `<g>` to `root` and returns it (or the
 * axis together with its `<g>`), so callers can style or remove it afterwards.
 */
export default class D3ChartBase {
    private root:
        | d3.Selection<SVGSVGElement, unknown, null, undefined>
        | d3.Selection<SVGGElement, unknown, null, undefined>;
    private width: number;
    private height: number;

    /**
     * @param root - SVG (or group) selection that every draw call appends into.
     * @param width - Plot-area width; x axes map onto `[0, width]`.
     * @param height - Plot-area height; y axes map onto `[height, 0]`.
     */
    constructor(
        root:
            | d3.Selection<SVGSVGElement, unknown, null, undefined>
            | d3.Selection<SVGGElement, unknown, null, undefined>,
        width: number,
        height: number
    ) {
        // Assigned directly rather than via `reset()`: TypeScript's strict
        // property-initialisation check only sees assignments made here.
        this.root = root;
        this.width = width;
        this.height = height;
    }

    /**
     * Re-targets the helper at a new root selection and plot size. Elements
     * already drawn into the previous root are left untouched.
     */
    public reset(
        root:
            | d3.Selection<SVGSVGElement, unknown, null, undefined>
            | d3.Selection<SVGGElement, unknown, null, undefined>,
        width: number,
        height: number
    ): void {
        this.root = root;
        this.width = width;
        this.height = height;
    }

    /**
     * Fills the plot area with `color` by appending a `<g>` containing a
     * `width` x `height` `<rect>` at the origin.
     */
    public drawBackground(color: string): void {
        this.root
            .append("g")
            .append("rect")
            .attr("width", this.width)
            .attr("height", this.height)
            .attr("fill", color);
    }

    /**
     * Draws a categorical (band-scale) axis.
     *
     * @param type - `XAxis` draws along the bottom edge, `YAxis` along the left.
     * @param axisConfig - Colour of the axis domain path; when omitted the path
     *   is hidden.
     * @param tickConfig - Tick label styling and overrides; when omitted tick
     *   lines are hidden and labels get `font-size` 0.
     * @param domain - Band categories, in order.
     * @param paddingInner - Band-scale inner padding (fraction of the step).
     * @param paddingOuter - Band-scale outer padding (fraction of the step).
     * @returns The configured d3 axis and the `<g>` it was rendered into.
     */
    public drawBandAxis(
        type: AxisType,
        axisConfig?: { color: string },
        tickConfig?: {
            color: string;
            fontSize: number;
            fontFamily?: string;
            tickValues?: string[];
            tickFormat?: (value: string, index: number) => string;
            tickSize?: number;
        },
        domain: Iterable<string> = [],
        paddingInner: number = 0,
        paddingOuter: number = 0
    ): {
        axis: d3.Axis<string>;
        g: d3.Selection<SVGGElement, unknown, null, undefined>;
    } {
        const isXAxis = type === AxisType.XAxis;
        const scale = d3
            .scaleBand()
            .domain(domain)
            .range(isXAxis ? [0, this.width] : [this.height, 0])
            .paddingInner(paddingInner)
            .paddingOuter(paddingOuter);

        const axisG = this.root.append("g");
        // Near-zero (0.01) tick size by default; tickConfig.tickSize overrides it.
        const axis = (
            isXAxis ? d3.axisBottom(scale) : d3.axisLeft(scale)
        ).tickSize(0.01);
        if (isXAxis) {
            axisG.attr("transform", `translate(0, ${this.height})`);
        }

        if (tickConfig != null) {
            if (tickConfig.tickValues != null) {
                axis.tickValues(tickConfig.tickValues);
            }
            if (tickConfig.tickFormat != null) {
                axis.tickFormat(tickConfig.tickFormat);
            }
            if (tickConfig.tickSize != null) {
                axis.tickSize(tickConfig.tickSize);
            }
        }

        // Render once, after every tick option has been applied.
        axisG.call(axis);

        if (axisConfig != null) {
            axisG.selectAll("path").attr("stroke", axisConfig.color);
        } else {
            axisG.selectAll("path").attr("visibility", "hidden");
        }

        if (tickConfig != null) {
            axisG
                .selectAll("text")
                .attr("fill", tickConfig.color)
                .style("font-size", `${tickConfig.fontSize}px`);
            if (tickConfig.fontFamily != null) {
                axisG.style("font-family", tickConfig.fontFamily);
            }
        } else {
            axisG.selectAll("line").attr("visibility", "hidden");
            axisG.selectAll("text").attr("font-size", 0);
        }

        return {
            axis: axis,
            g: axisG,
        };
    }

    /**
     * Returns `start, start + step, start + 2 * step, ...` up to `end`.
     *
     * Values are computed as `start + i * step` (no accumulated rounding) and
     * the step count tolerates a tiny rounding error, so e.g. `(0, 0.3, 0.1)`
     * yields four values instead of silently dropping the last one. The last
     * value may therefore exceed `end` by a rounding-sized amount.
     * Callers must pass `step > 0`; `end < start` yields `[]`.
     */
    private static stepTicks(
        start: number,
        end: number,
        step: number
    ): number[] {
        const steps = Math.floor((end - start) / step + 1e-9);
        const ticks: number[] = [];
        for (let i = 0; i <= steps; i++) {
            ticks.push(start + i * step);
        }
        return ticks;
    }

    /**
     * Computes tick positions for a linear axis spanning `[minVal, maxVal]`.
     *
     * - `userSpecifiedInterval == null`: five evenly spaced ticks including both
     *   ends. When `decimals === 0`, only the integers inside the range are
     *   considered, with an integer interval.
     * - numeric interval: ticks every `interval` starting at `minVal`.
     * - `{ amount, unit }`: `minVal`/`maxVal` are read as unix seconds and
     *   ticks are stepped with moment.js in that unit. Fractional amounts are
     *   converted to whole milliseconds. The returned `interval` is 0.
     * - a non-positive (or NaN) interval/amount: just `[minVal, maxVal]`.
     *
     * `maxVal < minVal` is treated as `maxVal = minVal`, giving a single tick.
     *
     * @param decimals - Label precision. Kept when given, except for time
     *   intervals, which always report 0. Otherwise derived from the interval.
     * @returns The ticks, the spacing between them, and suggested label decimals.
     */
    public static calculateLinearTicks(
        minVal: number,
        maxVal: number,
        userSpecifiedInterval?:
            | number
            | { amount: number; unit: moment.unitOfTime.DurationConstructor }
            | null,
        decimals?: number
    ): {
        ticks: number[];
        interval: number;
        decimals: number;
    } {
        const defaultNumberOfTicks = 5;

        let ticks: number[];
        let interval: number;

        if (maxVal < minVal) {
            maxVal = minVal;
        }

        // The numeric size the caller asked for, regardless of interval kind.
        const requestedAmount =
            userSpecifiedInterval == null
                ? null
                : typeof userSpecifiedInterval === "number"
                ? userSpecifiedInterval
                : userSpecifiedInterval.amount;

        if (minVal === maxVal) {
            ticks = [minVal];
            interval = 0;
        } else if (requestedAmount != null && !(requestedAmount > 0)) {
            // Unusable interval (<= 0 or NaN): fall back to the two endpoints.
            ticks = [minVal, maxVal];
            interval = maxVal - minVal;
        } else if (userSpecifiedInterval == null) {
            if (decimals === 0) {
                // Integer labels requested: restrict the range to the
                // integers it contains and use an integer spacing.
                minVal = Math.ceil(minVal);
                maxVal = Math.floor(maxVal);
                interval = Math.floor(
                    (maxVal - minVal) / (defaultNumberOfTicks - 1)
                );
                if (interval > 0) {
                    ticks = D3ChartBase.stepTicks(minVal, maxVal, interval);
                } else if (minVal < maxVal) {
                    ticks = [minVal, maxVal];
                    interval = maxVal - minVal;
                } else if (minVal === maxVal) {
                    // Exactly one integer in range.
                    ticks = [minVal];
                    interval = 0;
                } else {
                    // No integer lies inside the original range.
                    ticks = [];
                    interval = 0;
                }
            } else {
                interval = (maxVal - minVal) / (defaultNumberOfTicks - 1);
                ticks = [];
                for (let i = 0; i < defaultNumberOfTicks - 1; i++) {
                    ticks.push(minVal + i * interval);
                }
                // Push maxVal itself so the last tick is exact.
                ticks.push(maxVal);
            }
        } else if (typeof userSpecifiedInterval === "number") {
            interval = userSpecifiedInterval;
            ticks = D3ChartBase.stepTicks(minVal, maxVal, interval);
        } else {
            interval = 0;
            ticks = [];

            let amount = userSpecifiedInterval.amount;
            let unit = userSpecifiedInterval.unit;
            // moment.js does not support fractional amounts.
            if (!Number.isInteger(amount)) {
                amount = Math.ceil(
                    moment.duration(1, unit).asMilliseconds() * amount
                );
                unit = "ms";
            }

            if (amount <= 0) {
                ticks = [minVal, maxVal];
                interval = maxVal - minVal;
            } else {
                const maxTick = moment.unix(maxVal);
                for (
                    let tick = moment.unix(minVal);
                    tick.isSameOrBefore(maxTick);
                    tick.add(amount, unit)
                ) {
                    ticks.push(tick.unix());
                }
            }
        }

        if (
            userSpecifiedInterval != null &&
            typeof userSpecifiedInterval === "object"
        ) {
            // Time-based ticks are whole unix-second timestamps.
            decimals = 0;
        } else if (decimals == null) {
            if (interval === 0) {
                decimals = Number.isInteger(minVal) ? 0 : 1;
            } else if (interval >= 1) {
                decimals = Number.isInteger(interval) ? 0 : 1;
            } else {
                // Position of the first non-zero digit of the interval.
                decimals = -Math.floor(Math.log10(interval));
            }
        }

        return {
            ticks: ticks,
            interval: interval,
            decimals: decimals,
        };
    }

    /**
     * Draws a numeric (linear-scale) axis.
     *
     * @param type - `XAxis` draws along the bottom edge, `YAxis` along the left.
     * @param axisConfig - Colour of the domain path and tick lines; when
     *   omitted the domain path is hidden.
     * @param tickConfig - Tick styling and overrides. `tickRange` keeps only
     *   ticks inside `[lower, upper]` (either bound may be undefined). When
     *   omitted, tick lines are hidden and labels get `font-size` 0.
     * @param domain - Only the first two values are used. Missing values fall
     *   back to d3's default domain `[0, 1]`. Read exactly once.
     * @param ticksIntervalConfig - When `interval > 0`, tick values are generated
     *   every `interval` from `minVal` to `maxVal`. Non-integer intervals are
     *   rounded to 2 decimals. These override `tickValues`/`tickCount`.
     * @param padding - Extra space in pixels added before the start (`start`)
     *   and after the end (`end`) of the domain.
     * @returns The configured d3 axis and the `<g>` it was rendered into.
     */
    public drawLinearAxis(
        type: AxisType,
        axisConfig?: { color: string },
        tickConfig?: {
            color?: string;
            fontSize?: number;
            fontFamily?: string;
            tickCount?: number;
            tickValues?: number[];
            tickFormat?: (value: number, index: number) => string;
            tickSize?: number;
            tickRange?: [number | undefined, number | undefined];
        },
        domain: Iterable<d3.NumberValue> = [],
        ticksIntervalConfig?: {
            interval?: number | null;
            minVal: number;
            maxVal: number;
        },
        padding?: {
            start?: number;
            end?: number;
        }
    ): {
        axis: d3.Axis<number>;
        g: d3.Selection<SVGGElement, unknown, null, undefined>;
    } {
        const isXAxis = type === AxisType.XAxis;

        // Explicit tick interval -> generated tick values.
        const requestedInterval = ticksIntervalConfig?.interval;
        const intervalTickValues =
            ticksIntervalConfig != null &&
            requestedInterval != null &&
            requestedInterval > 0
                ? D3ChartBase.stepTicks(
                      ticksIntervalConfig.minVal,
                      ticksIntervalConfig.maxVal,
                      requestedInterval
                  ).map((value) =>
                      Number.isInteger(requestedInterval)
                          ? value
                          : Number(value.toFixed(2))
                  )
                : null;
        const resolvedTickConfig: typeof tickConfig =
            intervalTickValues == null
                ? tickConfig
                : {
                      ...tickConfig,
                      tickValues: intervalTickValues,
                      tickCount: intervalTickValues.length,
                  };

        // Materialise the domain once: it may be a single-use iterator.
        const [domainStart = 0, domainEnd = 1] = Array.from(domain, Number);
        const scale = d3.scaleLinear(
            [domainStart, domainEnd],
            isXAxis ? [0, this.width] : [this.height, 0]
        );

        // Convert pixel padding to domain units on the *unpadded* scale, so
        // the start padding cannot leak into the end padding.
        const toDomainOffset = (pixels: number): number =>
            scale.invert(isXAxis ? pixels : this.height - pixels) - domainStart;
        const paddingStart = padding?.start ?? 0;
        const paddingEnd = padding?.end ?? 0;
        scale.domain([
            domainStart - toDomainOffset(paddingStart),
            domainEnd + toDomainOffset(paddingEnd),
        ]);

        const axisG = this.root.append("g");
        const axis = isXAxis
            ? d3.axisBottom<number>(scale)
            : d3.axisLeft<number>(scale);
        if (isXAxis) {
            axisG.attr("transform", `translate(0, ${this.height})`);
        }

        if (resolvedTickConfig != null) {
            if (resolvedTickConfig.tickCount != null) {
                axis.ticks(resolvedTickConfig.tickCount);
            }
            if (resolvedTickConfig.tickValues != null) {
                axis.tickValues(resolvedTickConfig.tickValues);
            }
            if (resolvedTickConfig.tickFormat != null) {
                axis.tickFormat(resolvedTickConfig.tickFormat);
            }
            if (resolvedTickConfig.tickSize != null) {
                axis.tickSize(resolvedTickConfig.tickSize);
            }
            if (resolvedTickConfig.tickRange != null) {
                const [lower, upper] = resolvedTickConfig.tickRange;
                // Honour tickCount when falling back to the scale's own ticks.
                const candidates =
                    axis.tickValues() ??
                    scale.ticks(resolvedTickConfig.tickCount);
                axis.tickValues(
                    candidates.filter(
                        (value) =>
                            (lower == null || value >= lower) &&
                            (upper == null || value <= upper)
                    )
                );
            }
        }

        axisG.call(axis);

        if (axisConfig != null) {
            axisG.selectAll("path").attr("stroke", axisConfig.color);
            axisG.selectAll("line").attr("stroke", axisConfig.color);
        } else {
            axisG.selectAll("path").attr("visibility", "hidden");
        }

        if (resolvedTickConfig != null) {
            const text = axisG.selectAll("text");
            if (resolvedTickConfig.color != null) {
                text.attr("fill", resolvedTickConfig.color);
            }
            if (resolvedTickConfig.fontSize) {
                text.style("font-size", `${resolvedTickConfig.fontSize}px`);
            }
            // Nudge x-axis labels 2 units down, away from the axis line.
            text.attr("transform", `translate(0, ${isXAxis ? 2 : 0})`);
            if (resolvedTickConfig.fontFamily != null) {
                axisG.style("font-family", resolvedTickConfig.fontFamily);
            }
        } else {
            axisG.selectAll("line").attr("visibility", "hidden");
            axisG.selectAll("text").attr("font-size", 0);
        }

        return {
            axis: axis,
            g: axisG,
        };
    }

    /**
     * Draws full-width horizontal lines at `yScale(yTicks[i])` and full-height
     * vertical lines at `xScale(xTicks[i])`, all 1 unit wide in `color`.
     *
     * @param dashed - When true, both line groups get `stroke-dasharray: 3, 3`.
     * @returns A `<g>` holding one group of horizontal and one of vertical lines.
     */
    public drawGrid(
        xTicks: number[],
        yTicks: number[],
        color: string,
        xScale: (tick: number) => number,
        yScale: (tick: number) => number,
        dashed: boolean = false
    ): d3.Selection<SVGGElement, unknown, null, undefined> {
        const grid = this.root.append("g");
        const horizontalLines = grid.append("g");
        horizontalLines
            .selectAll<SVGLineElement, number>("line")
            .data(yTicks)
            .join("line")
            .attr("x1", 0)
            .attr("x2", this.width)
            .attr("y1", yScale)
            .attr("y2", yScale)
            .attr("fill", "none")
            .attr("stroke", color)
            .attr("stroke-width", 1);
        const verticalLines = grid.append("g");
        verticalLines
            .selectAll<SVGLineElement, number>("line")
            .data(xTicks)
            .join("line")
            .attr("x1", xScale)
            .attr("x2", xScale)
            .attr("y1", 0)
            .attr("y2", this.height)
            .attr("fill", "none")
            .attr("stroke", color)
            .attr("stroke-width", 1);
        if (dashed) {
            horizontalLines.style("stroke-dasharray", "3, 3");
            verticalLines.style("stroke-dasharray", "3, 3");
        }
        return grid;
    }

    /**
     * Draws one dashed, full-width horizontal line per datum at `y(d)`. Each line
     * is labelled with bold 12px `text(d)` at `x = width - 20`, 10 units below it.
     *
     * @param color - Label colour.
     * @param lineColor - Line stroke colour (defaults to `"red"`).
     * @returns The `<g>` containing one `<g>` (line + label) per datum.
     */
    public drawHorizontalLines<LineDatum>(
        data: LineDatum[],
        color: string,
        text: (datum: LineDatum) => string,
        y: (datum: LineDatum) => number,
        lineColor: string = "red"
    ): d3.Selection<SVGGElement, unknown, null, undefined> {
        const g = this.root.append("g");
        const selection = g.selectAll("g").data(data).join("g");

        selection
            .append("line")
            .attr("x1", 0)
            .attr("x2", this.width)
            .attr("y1", y)
            .attr("y2", y)
            .attr("fill", "none")
            .attr("stroke", lineColor)
            .attr("stroke-width", 1)
            .style("stroke-dasharray", "3, 3");
        selection
            .append("text")
            .attr("x", this.width - 20)
            .attr("y", (d) => y(d) + 10)
            .attr("text-anchor", "start")
            .style("fill", color)
            .style("font-weight", 700)
            // CSS lengths set through `style` need an explicit unit.
            .style("font-size", "12px")
            .text(text);

        return g;
    }

    /**
     * Draws one dashed, full-height vertical line per datum at `x(d)`. Each line
     * is labelled with bold 12px `text(d)` 10 units to its right, at
     * `y = height - 20`.
     *
     * @param color - Label colour.
     * @param lineColor - Line stroke colour (defaults to `"red"`).
     * @returns The `<g>` containing one `<g>` (line + label) per datum.
     */
    public drawVerticalLines<LineDatum>(
        data: LineDatum[],
        color: string,
        text: (datum: LineDatum) => string,
        x: (datum: LineDatum) => number,
        lineColor: string = "red"
    ): d3.Selection<SVGGElement, unknown, null, undefined> {
        const g = this.root.append("g");
        const selection = g.selectAll("g").data(data).join("g");

        selection
            .append("line")
            .attr("x1", x)
            .attr("x2", x)
            .attr("y1", 0)
            .attr("y2", this.height)
            .attr("fill", "none")
            .attr("stroke", lineColor)
            .attr("stroke-width", 1)
            .style("stroke-dasharray", "3, 3");
        selection
            .append("text")
            .attr("x", (d) => x(d) + 10)
            .attr("y", this.height - 20)
            .attr("text-anchor", "start")
            .style("fill", color)
            .style("font-weight", 700)
            // CSS lengths set through `style` need an explicit unit.
            .style("font-size", "12px")
            .text(text);

        return g;
    }
}
