import { Box, Flex } from '@radix-ui/themes';
import { CommonCallout } from 'components/common/callouts';
import { ErrorBoundary } from 'components/common/error-boundary';
import { CommonFormGrid } from 'components/common/form/grid';
import { CommonSelectInput } from 'components/common/form/select';
import { CHART_COLORS } from 'enums/charts.enums';
import { ArrayHelper } from 'lib_ts/classes/array.helper';
import { isNumber } from 'lib_ts/classes/math.utilities';
import { RADIX } from 'lib_ts/enums/radix-ui';
import { IOption } from 'lib_ts/interfaces/common/i-option';
import { IEvalModelResult } from 'lib_ts/interfaces/modelling/i-eval-models';
import React from 'react';
import {
  CartesianGrid,
  Legend,
  ResponsiveContainer,
  Scatter,
  ScatterChart,
  Tooltip,
  XAxis,
  YAxis,
  ZAxis,
} from 'recharts';

/** Name reported to the ErrorBoundary wrapping this view. */
const COMPONENT_NAME = 'PlotModelMetrics';

/** Decimal places kept when rounding plotted y-values. */
const PRECISION_DECIMALS = 4;
/** Scale factor (10 ^ PRECISION_DECIMALS) used to round y-values. */
const PRECISION_COEFF = 10 ** PRECISION_DECIMALS;

/**
 * Multiplier used to snap the Y-axis domain for a top-level metric key.
 *
 * The plot floors/ceils its data range to multiples of `1 / coeff`, so a
 * larger coefficient gives a tighter axis. Missing, empty, or unrecognised
 * keys fall back to 1 (whole units).
 */
const getRoundingCoeff = (key?: string): number => {
  switch (key) {
    // msc keys
    case 'tilt':
    case 'yaw':
      return 100;
    case 'w1':
    case 'w2':
    case 'w3':
      return 0.1;
    case 'a1':
    case 'a2':
    case 'a3':
      return 10;

    // bsp keys
    case 'vx':
    case 'vy':
    case 'vz':
      return 10;
    case 'wx':
    case 'wy':
    case 'wz':
      return 0.01;

    // common keys
    case 'qx':
    case 'qy':
    case 'qz':
    case 'qw':
      return 1000;

    // px / py / pz, undefined, '' and anything else
    default:
      return 1;
  }
};

/** One plotted point; recharts reads its fields through the axis `dataKey`s. */
interface IScatterPoint {
  /** X-axis category (read via `dataKey="xValue"`). */
  xValue: string;
  /** Plotted metric value, rounded to PRECISION_DECIMALS decimal places. */
  yValue: number;
  /** Model name (read via the ZAxis `dataKey="model"`). */
  model: string;
}

/** One scatter series: all points belonging to a single model. */
interface IScatterGroup {
  /** Series label shown in the legend (the model name). */
  name: string;
  points: IScatterPoint[];
}

interface IProps {
  /** Evaluation results to plot; the first one supplies the selector options. */
  metrics: IEvalModelResult[];
}

interface IState {
  /** Selected top-level key of `model_performance`. */
  selection0?: string;
  /** Selected key nested under `selection0`; its value is what gets plotted. */
  selection1?: string;

  /** Metric whose `model_performance` keys populate the selectors. */
  refMetric?: IEvalModelResult;

  /** Options for the first selector (top-level `model_performance` keys). */
  options0: IOption[];
  /** Options for the second selector (keys under the chosen `selection0`). */
  options1: IOption[];

  /** Assembled series; undefined until both selections have been made. */
  data?: IScatterGroup[];
}

/**
 * Scatter plot of one numeric model-performance value across evaluation results.
 *
 * The user picks a top-level key of `model_performance` (selection0) and then a
 * key nested under it (selection1). Every metric with a numeric value at
 * `model_performance[selection0][selection1]` becomes one point, grouped into
 * one series (and colour) per `model_id`.
 *
 * Selector options are derived from the first metric in `props.metrics`. When
 * there is no such metric, or it has no `model_performance`, a callout is shown
 * instead of the selectors.
 */
export class PlotModelMetrics extends React.Component<IProps, IState> {
  /**
   * Reads `source[key]` without throwing.
   *
   * @returns undefined when `key` is undefined or `source` is not a non-null
   *   object (e.g. a metric whose `model_performance` is missing); otherwise
   *   the property value.
   */
  private static readKey(source: unknown, key: string | undefined): unknown {
    if (key === undefined || typeof source !== 'object' || source === null) {
      return undefined;
    }
    return (source as Record<string, unknown>)[key];
  }

  /** Turns the own enumerable keys of `source` into select options; [] for non-objects. */
  private static keysToOptions(source: unknown): IOption[] {
    if (typeof source !== 'object' || source === null) {
      return [];
    }
    return Object.keys(source).map((key) => ({ label: key, value: key }));
  }

  /**
   * Chooses the metric whose `model_performance` drives the selectors: the
   * first metric, provided it has `model_performance`; otherwise undefined.
   */
  private static getRefMetric(
    metrics: IEvalModelResult[]
  ): IEvalModelResult | undefined {
    const firstMetric = metrics[0];
    return firstMetric?.model_performance ? firstMetric : undefined;
  }

  constructor(props: IProps) {
    super(props);

    // A missing/unusable first metric is rendered as a callout by render()
    // rather than thrown from here, where the component's own ErrorBoundary
    // could never catch it.
    const refMetric = PlotModelMetrics.getRefMetric(props.metrics);

    this.state = {
      refMetric: refMetric,
      options0: PlotModelMetrics.keysToOptions(refMetric?.model_performance),
      options1: [],
    };

    this.assembleData = this.assembleData.bind(this);
    this.selectionsComplete = this.selectionsComplete.bind(this);
  }

  /**
   * Keeps derived state in sync with props and selections:
   * - when `props.metrics` changes, rebuilds the reference metric and options;
   * - when both selections are set and either selection (or the metrics)
   *   changed, reassembles the plotted data;
   * - when selection1 is cleared, drops stale data.
   */
  componentDidUpdate(
    prevProps: Readonly<IProps>,
    prevState: Readonly<IState>
  ): void {
    const metricsChanged = prevProps.metrics !== this.props.metrics;

    if (metricsChanged) {
      const refMetric = PlotModelMetrics.getRefMetric(this.props.metrics);
      const performance = refMetric?.model_performance;

      this.setState({
        refMetric: refMetric,
        options0: PlotModelMetrics.keysToOptions(performance),
        options1: PlotModelMetrics.keysToOptions(
          PlotModelMetrics.readKey(performance, this.state.selection0)
        ),
      });
    }

    if (this.state.selection1) {
      const selectionChanged =
        prevState.selection0 !== this.state.selection0 ||
        prevState.selection1 !== this.state.selection1;

      if (selectionChanged || metricsChanged) {
        this.assembleData();
      }
    } else if (this.state.data) {
      this.setState({ data: undefined });
    }
  }

  /**
   * Builds one scatter series per `model_id` from `props.metrics`, using each
   * metric's value at `model_performance[selection0][selection1]`.
   *
   * Metrics within a model are visited in `created` order. Metrics without a
   * numeric value at that path are skipped, as are models left with no points.
   * The result is stored in `state.data`.
   */
  private assembleData() {
    const { selection0, selection1 } = this.state;

    // NOTE(review): every point gets the same x category (the selected metric
    // path) even though the X axis is named "Date" and metrics are sorted by
    // `created`; confirm whether `created` was meant to be the x value.
    const xValue = `${selection0}.${selection1}`;

    const allModelIDs = ArrayHelper.unique(
      this.props.metrics.map((m) => m.model_id)
    );

    const groups: IScatterGroup[] = [];

    allModelIDs.forEach((model_id) => {
      // filter() returns a fresh array, so sorting it does not mutate props.
      const sortedMetrics = this.props.metrics
        .filter((metric) => metric.model_id === model_id)
        .sort((a, b) => a.created.localeCompare(b.created));

      const firstMetric = sortedMetrics[0];
      if (!firstMetric) {
        return;
      }

      const points: IScatterPoint[] = [];

      sortedMetrics.forEach((metric) => {
        // readKey tolerates metrics whose model_performance is missing.
        const object0 = PlotModelMetrics.readKey(
          metric.model_performance,
          selection0
        );
        if (!object0) {
          return;
        }

        const value1 = PlotModelMetrics.readKey(object0, selection1) as number;
        if (!isNumber(value1)) {
          return;
        }

        points.push({
          model: firstMetric.model_name,
          xValue: xValue,
          yValue: Math.round(value1 * PRECISION_COEFF) / PRECISION_COEFF,
        });
      });

      if (points.length === 0) {
        return;
      }

      groups.push({
        name: firstMetric.model_name,
        points: points,
      });
    });

    this.setState({ data: groups });
  }

  /** True once both selectors have a value. */
  private selectionsComplete(): boolean {
    return !!this.state.selection0 && !!this.state.selection1;
  }

  /**
   * Renders the scatter chart. Renders nothing until both selections are made
   * and data has been assembled. Renders a callout when no metric has a numeric
   * value for the current selection.
   */
  private renderGraph() {
    if (!this.selectionsComplete() || !this.state.data) {
      return;
    }

    const data = this.state.data;
    const allValues = data.flatMap((d) => d.points).map((p) => p.yValue);

    // Without any values the min/max reductions below would yield
    // Infinity / -Infinity and an unusable axis domain.
    if (allValues.length === 0) {
      return (
        <CommonCallout text="No numeric values were found for the current selection." />
      );
    }

    const yLabel = ArrayHelper.unique([
      this.state.selection0,
      this.state.selection1,
    ])
      .join(' > ')
      .toUpperCase();

    // Snap the Y domain outward to the nearest multiple of 1 / roundingCoeff.
    const roundingCoeff = getRoundingCoeff(this.state.selection0);

    const minValue = allValues.reduce(
      (prev, datum) => (prev < datum ? prev : datum),
      Infinity
    );
    const maxValue = allValues.reduce(
      (prev, datum) => (prev > datum ? prev : datum),
      -Infinity
    );

    const axisMin = Math.floor(minValue * roundingCoeff) / roundingCoeff;
    const axisMax = Math.ceil(maxValue * roundingCoeff) / roundingCoeff;

    return (
      <ResponsiveContainer width="100%" height={400}>
        <ScatterChart
          margin={{
            top: 20,
            bottom: 20,
          }}
        >
          <CartesianGrid />
          <XAxis name="Date" type="category" dataKey="xValue" />

          <YAxis
            name={yLabel}
            type="number"
            dataKey="yValue"
            domain={[axisMin, axisMax]}
          />

          <ZAxis name="Model" type="category" dataKey="model" />

          <Tooltip />

          <Legend layout="vertical" height={36} />

          {data.map((group, i) => (
            <Scatter
              key={`group-${i}`}
              name={group.name}
              data={group.points}
              fill={CHART_COLORS[i % CHART_COLORS.length]}
            />
          ))}
        </ScatterChart>
      </ResponsiveContainer>
    );
  }

  render() {
    return (
      <ErrorBoundary componentName={COMPONENT_NAME}>
        {!this.state.refMetric && (
          <CommonCallout text="Please provide at least one metric to use this view." />
        )}

        {this.state.refMetric && (
          <Flex direction="column" gap={RADIX.FLEX.GAP.SM}>
            <CommonFormGrid columns={2}>
              <CommonSelectInput
                id="plot-selection0"
                name="selection0"
                options={this.state.options0}
                value={this.state.selection0}
                onChange={(v) => {
                  // Null-safe: a non-object value under `v` yields no options
                  // instead of Object.keys throwing.
                  const newOptions1: IOption[] = v
                    ? PlotModelMetrics.keysToOptions(
                        PlotModelMetrics.readKey(
                          this.state.refMetric?.model_performance,
                          v
                        )
                      )
                    : [];

                  this.setState({
                    selection0: v,
                    selection1: undefined,
                    options1: newOptions1,
                  });
                }}
                optional
              />
              <CommonSelectInput
                id="plot-selection1"
                name="selection1"
                options={this.state.options1}
                value={this.state.selection1}
                onChange={(v) => this.setState({ selection1: v })}
                optional
              />
            </CommonFormGrid>

            <Box>{this.renderGraph()}</Box>
          </Flex>
        )}
      </ErrorBoundary>
    );
  }
}
