import { AccuracyMeasureType, AccuracyWeightMethod } from '@core/constants/accuracy.constants';
import { PastForecastDTO } from '@core/entities/dtos/past-forecast-dto';
import { FittedValueDTO, PlotValueDTO, SimplePlotValue } from '@core/entities/dtos/plot-value-dto';
import { ShapValueDTO } from '@core/store/stat-model/dtos/stat-model.dto';
import { StatTransformationModel } from '@core/store/stat-model/stat-model.model';
import { StatModelUtils } from '@core/store/stat-model/stat-model.utils';
import { PeriodicityType } from '@modules/lang/language-files/periodicities';
import { ModelName } from '@modules/lang/types/model-name';
import { ALGSingleSeriesModel } from '@shared/components/line-graph/alg-models/graph-data.model';
import { ValueUtils } from '../value.utils';

export namespace SummaryUtils {
  /** Shape of a consolidated (summary) series built from several weighted models. */
  export interface SummaryInfo {
    Name: string;
    Description: string;
    Values: PlotValueDTO[];
    WFittedValues?: FittedValueDTO[];
    WShapValues?: ShapValueDTO[];
    Factor?: number;
    FactorAbbr?: string;
    PastForecasts: any[];
    Color: string;
    show: boolean;
  }

  /** Contains all weights for the selected model weight measurement. */
  export class ModelWeightsInfo {
    /** Name of the model the weights belong to (matched against the model's `Name`). */
    public ModelName: string;
    /** One weight per forecast step. */
    public Stepwise: number[];
  }

  /** An accuracy measure together with the method used to turn its per-step values into weights. */
  export class WeightMeasure {
    public Measure: AccuracyMeasureType | 'EQUAL';
    public Method: AccuracyWeightMethod;
  }

  /** Map from a model-name to an object containing the type and the corresponding accuracies. */
  export class AccuracyModelMap extends Map<string, { type: WeightMeasure, accuracies: number[][]; }> { }

  /**
   * Builds the persisted string for a weight measure, e.g. `RMSE`, `MAPE Stepwise`.
   * The `First` method is the implicit default and is therefore not appended.
   */
  export function getWeightMeasureString(measure: WeightMeasure) {
    return `${measure.Measure}${measure.Method === 'First' ? '' : ' ' + measure.Method}`;
  }

  /**
   * Parses a persisted weight-measure string (see `getWeightMeasureString`).
   * Unknown strings fall back to `RMSE Stepwise`.
   */
  export function getWeightMeasure(dbString: string): WeightMeasure {
    switch (dbString) {
      case 'EQUAL': return { Measure: 'EQUAL', Method: 'First' };
      case 'RMSE': return { Measure: 'RMSE', Method: 'First' };
      case 'RMSE Stepwise': return { Measure: 'RMSE', Method: 'Stepwise' };
      case 'RMSE Average': return { Measure: 'RMSE', Method: 'Average' };
      case 'MAPE': return { Measure: 'MAPE', Method: 'First' };
      case 'MAPE Stepwise': return { Measure: 'MAPE', Method: 'Stepwise' };
      case 'MAPE Average': return { Measure: 'MAPE', Method: 'Average' };
      case 'MASE': return { Measure: 'MASE', Method: 'First' };
      case 'MASE Stepwise': return { Measure: 'MASE', Method: 'Stepwise' };
      case 'MASE Average': return { Measure: 'MASE', Method: 'Average' };
      case 'ME': return { Measure: 'ME', Method: 'First' };
      case 'ME Stepwise': return { Measure: 'ME', Method: 'Stepwise' };
      case 'ME Average': return { Measure: 'ME', Method: 'Average' };
      case 'MAE': return { Measure: 'MAE', Method: 'First' };
      case 'MAE Stepwise': return { Measure: 'MAE', Method: 'Stepwise' };
      case 'MAE Average': return { Measure: 'MAE', Method: 'Average' };
      case 'MPE': return { Measure: 'MPE', Method: 'First' };
      case 'MPE Stepwise': return { Measure: 'MPE', Method: 'Stepwise' };
      case 'MPE Average': return { Measure: 'MPE', Method: 'Average' };
      case 'HR': return { Measure: 'HR', Method: 'First' };
      case 'HR Stepwise': return { Measure: 'HR', Method: 'Stepwise' };
      case 'HR Average': return { Measure: 'HR', Method: 'Average' };
      case 'RSQ': return { Measure: 'RSQ', Method: 'First' };
      case 'RSQ Stepwise': return { Measure: 'RSQ', Method: 'Stepwise' };
      case 'RSQ Average': return { Measure: 'RSQ', Method: 'Average' };
      default:
        return { Measure: 'RMSE', Method: 'Stepwise' };
    }
  }

  /**
   * Expands a model's stepwise weights according to the weighting method:
   * - `First`: every step uses the first step's weight.
   * - `Stepwise`: the stepwise weights are returned as-is (the same array instance).
   * - anything else (`Average`): every step uses the mean of the stepwise weights.
   */
  export function getWeight(measure: WeightMeasure, info: ModelWeightsInfo) {
    if (measure.Method === 'First') {
      const first = info.Stepwise[0];
      return info.Stepwise.map(() => first);
    }
    if (measure.Method === 'Stepwise') {
      return info.Stepwise;
    }
    // The average is computed once instead of once per step; empty input yields [] without averaging.
    if (info.Stepwise.length === 0) { return []; }
    const average = info.Stepwise.avg();
    return info.Stepwise.map(() => average);
  }

  /**
   * Weights every visible model by its accuracy and sums the weighted values into one summary series.
   * The models in `input.data` are mutated: their weighted (`W*`) fields are (re)written.
   * @returns The consolidated series, or null when no model is visible.
   */
  export function summarizeAndConsolidateModels(input: {
    activeVariableIds: string[],
    dates: { D: Date, m: moment.Moment; }[];
    data: StatTransformationModel[],
    variableId: string,
    forecastedCount: number,
    exponent: number,
    weightInfo: WeightMeasure,
    setPastForecast: boolean,
    name: string,
    desc: string,
    periodicity: PeriodicityType,
    color: string,
    modelName: ModelName;
  }): ALGSingleSeriesModel {
    if (!input.data.filter(x => x.show).length) { return null; }
    const modelWeights = calculateModelWeights(input.data, input.weightInfo, input.forecastedCount, input.exponent);
    calculateWeightedModelValues(input.data, input.variableId, input.forecastedCount, input.weightInfo, modelWeights);
    return consolidateModelValues({
      activeVariableIds: input.activeVariableIds,
      dates: input.dates,
      models: input.data,
      variableId: input.variableId,
      hidePast: !input.setPastForecast,
      forecastedCount: input.forecastedCount,
      name: input.name,
      desc: input.desc,
      periodicity: input.periodicity,
      color: input.color,
      modelName: input.modelName
    });
  }

  /**
   * Calculates per-step weights for every visible model from the model's stepwise accuracy values.
   * @param models All models; only those with `show` set are weighted.
   * @param weightMeasure The accuracy measure to weight by (`EQUAL` gives every model the same weight).
   * @param forecastedCount Number of forecast steps to produce weights for.
   * @param exponent Exponent used to bias the weights towards the more accurate models.
   * @returns One entry per visible model, in the order of `models`.
   */
  export function calculateModelWeights(
    models: StatTransformationModel[],
    weightMeasure: SummaryUtils.WeightMeasure,
    forecastedCount: number,
    exponent: number
  ): ModelWeightsInfo[] {
    const steps = Array.from(Array(forecastedCount).keys());
    const modelsToUse = models.filter(x => x.show);
    // stepwise[step][modelIndex]
    const stepwise: number[][] = [];

    if (weightMeasure.Measure === 'EQUAL') {
      const equalWeight = 1 / modelsToUse.length;
      const weights = modelsToUse.map(() => equalWeight);
      steps.forEach(() => stepwise.push(weights));
    } else {
      const measure = weightMeasure.Measure;
      const inverse = isInverseMeasure(measure);
      // A missing or non-finite accuracy must rank the model last: the largest possible error for
      // "lower is better" measures, zero for "higher is better" ones (RSQ, HR). A genuine 0 is kept,
      // so a perfect error score is not mistaken for a missing one.
      const missing = inverse ? Number.MAX_VALUE : 0;
      const accuracyAt = (model: StatTransformationModel, step: number): number => {
        const value: unknown = getAccuracySteps(model, measure)?.[step]?.Value;
        return typeof value === 'number' && Number.isFinite(value) ? value : missing;
      };
      steps.forEach(step => stepwise.push(getWeights(modelsToUse.map(m => accuracyAt(m, step)), exponent, inverse)));
    }

    return modelsToUse.map((m, i) => ({
      ModelName: m.Name,
      Stepwise: steps.map(step => stepwise[step][i])
    }));
  }

  /**
   * Writes the weighted (`W*`) values onto every visible model: the selected variable's data points,
   * the past forecasts (`WValues`), the fitted values (`WFittedValues`) and the SHAP values (`WShapValues`).
   * @param modelWeights Weights per model, as produced by `calculateModelWeights`.
   */
  export function calculateWeightedModelValues(
    models: StatTransformationModel[],
    variableId: string,
    forecastedCount: number,
    weightInfo: WeightMeasure,
    modelWeights: ModelWeightsInfo[]
  ) {
    const activeModels = models.filter(mod => mod.show);
    // All weighted fitted-value series are trimmed to a common length so they can later be summed
    // index by index. Models without fitted values are skipped (they are also skipped when weighting).
    const fittedLengths = activeModels
      .filter(mod => mod.FittedValues != null)
      .map(mod => mod.FittedValues.length);
    const fittedValueCount = fittedLengths.length > 0 ? Math.min(...fittedLengths) : 0;

    for (const model of activeModels) {
      const info = modelWeights.find(x => x.ModelName === model.Name);
      applyModelWeights(model, variableId, forecastedCount, getWeight(weightInfo, info), fittedValueCount);
    }
  }

  /**
   * Sums the weighted values of all visible models into a single summary series.
   * Expects `calculateWeightedModelValues` to have populated the models' `W*` fields.
   */
  export function consolidateModelValues(input: {
    activeVariableIds: string[],
    dates: { D: Date, m: moment.Moment; }[],
    models: StatTransformationModel[],
    variableId: string,
    hidePast: boolean,
    forecastedCount: number,
    name: string,
    desc: string,
    periodicity: PeriodicityType,
    color: string,
    modelName: ModelName;
  }): ALGSingleSeriesModel {
    const activeModels = input.models.filter(x => x.show);

    /** Sums the weighted fields of the selected variable of every model into one value per day. */
    const getValues = () => {
      const res: PlotValueDTO[] = [];
      const valueFields = ['V', 'I50', 'I75', 'I95', 'A50', 'A75', 'A95'];
      for (const model of activeModels) {
        const variable = model.Variables.find(x => x.Id === input.variableId);
        for (const value of variable.Data as PlotValueDTO[]) {
          let val: PlotValueDTO = res.find(x => x.m.isSame(value.m, 'day'));
          if (!val) {
            val = new PlotValueDTO();
            val.IF = value.IF;
            val.m = value.m;
            val.D = value.D;
            valueFields.forEach(field => {
              val[field] = value['W' + field] != null ? value['W' + field] : undefined;
            });
            res.push(val);
          } else {
            valueFields.forEach(field => {
              const weighted = value['W' + field];
              if (weighted == null) { return; }
              if (val[field] === undefined) {
                val[field] = weighted;
              } else {
                val[field] += weighted;
              }
            });
          }
        }
      }
      return res;
    };

    /** Sums the weighted fitted values (`WF`) of every model, index by index. */
    const getFittedValues = (): FittedValueDTO[] => {
      const res: FittedValueDTO[] = [];
      for (const model of activeModels) {
        if (model.WFittedValues == null) { continue; }
        for (let p = 0; p < model.WFittedValues.length; p++) {
          const value = model.WFittedValues[p];
          if (res[p] != null) {
            res[p].WF += value.WF;
          } else {
            // Copy, so the running sum does not overwrite the first model's own weighted fitted value.
            res[p] = { ...value };
          }
        }
      }
      return res;
    };

    /** Sums the weighted SHAP values of every model per active variable and date. */
    const getShapValues = (): ShapValueDTO[] => {
      const shapPerModel = activeModels
        .filter(x => x.WShapValues != null && x.WShapValues.length > 0)
        .map(x => x.WShapValues);

      const res: ShapValueDTO[] = input.activeVariableIds.map(id => {
        const dto = new ShapValueDTO();
        dto.Id = id;
        dto.Values = input.dates.map(date => {
          const value = new SimplePlotValue();
          value.D = date.D;
          value.m = date.m;
          value.V = 0;
          return value;
        });
        return dto;
      });

      for (const modelShap of shapPerModel) {
        for (const shap of modelShap) {
          const resEntry = res.find(x => x.Id === shap.Id);
          if (!resEntry) { continue; }
          resEntry.IsEvent = shap.IsEvent;
          resEntry.Name = shap.Name;
          // Parenthesised: a missing contribution must add 0, not reset the running sum to 0.
          resEntry.Values.forEach((x, idx) => { x.V = x.V + (shap.Values[idx]?.V || 0); });
        }
      }

      // Variables no model reported SHAP values for never received a Name.
      return res.filter(x => x.Name != null);
    };

    const values = getValues();
    const forecastedValues = values.filter(v => v.IF);
    const valueFactor = ValueUtils.GetMostCommonValueFactor(forecastedValues.map(x => x.V));
    const numFactor = valueFactor.factor;
    const abbrFactor = valueFactor.abbr;

    return {
      Description: input.desc,
      Name: input.name,
      modelName: { ...input.modelName, Description: input.desc, Display: input.name },
      FittedValues: [],
      PastForecasts: [],
      RollingAccuracy: false,
      Periodicity: input.periodicity,
      Values: values,
      ShapValues: [],
      WFittedValues: getFittedValues(),
      WShapValues: getShapValues(),
      show: true,
      Factor: numFactor > 0 ? numFactor : undefined,
      FactorAbbr: abbrFactor,
      Color: input.color
    };
  }

  /**
   * Combines the past forecasts of several models into weighted past forecasts, where the weights for
   * each past forecast are derived from the accuracies that were known before it (rolling weights).
   * @param rollingAccuracy Dict holding info about previously calculated accuracies (read and updated).
   * @param measurementType The active accuracy measurement type.
   * @param pastForecasts Past forecasts, for an array of models.
   * @param historicData The historic data to compare against.
   * @param exponent The exponent to apply when calculating weights.
   * @returns One weighted past forecast per past step; [] when there are no past forecasts, null when
   *          there are no models.
   */
  export function getRollingWeightedPastForecasts(
    rollingAccuracy: AccuracyModelMap,
    measurementType: WeightMeasure,
    pastForecasts: { Pasts: PastForecastDTO[], Model: string; }[],
    historicData: PlotValueDTO[],
    exponent: number
  ) {
    const modelCount = pastForecasts.length;
    if (modelCount === 0) { return null; }
    const pastForecastCount = pastForecasts[0].Pasts.length;
    // Nothing to weight. Returning early also avoids `slice(-0)`, which would keep the whole history.
    if (pastForecastCount === 0) { return []; }
    const modelHorizons = pastForecasts.map(p => p.Pasts[0]?.Values.length ?? 0);
    const horizon = modelHorizons.max();
    const historicValues = historicData.slice(-pastForecastCount).map(h => h.V);

    // accuracies[model][past][h]
    const accuracies: number[][][] = [];
    for (let m = 0; m < modelCount; m++) {
      accuracies.push(getRollingAccuracy(rollingAccuracy, horizon, pastForecastCount, measurementType, pastForecasts[m], historicValues).accuracies);
    }

    const inverse = isInverseMeasure(measurementType.Measure);

    // rollingWeights[past][h][model]; index 0 holds the initial (equal) weights used before any accuracy is known.
    const rollingWeights: number[][][] = Array.from({ length: pastForecastCount + 1 }, () => Array.from({ length: horizon }, () => []));
    for (let h = 0; h < horizon; h++) {
      const initWeight = 1.0 / accuracies.filter(x => !Number.isNaN(x[0][h])).length;
      const initWeights = accuracies.map(x => !Number.isNaN(x[0][h]) ? initWeight : 0);
      for (let p = 0; p < pastForecastCount; p++) {
        rollingWeights[p][h] = [...initWeights];
      }
    }

    for (let p = 0; p < pastForecastCount; p++) {
      for (let h = 0; h < horizon; h++) {
        const hAcc = measurementType.Method === 'First' ? 0 : h;
        const acc = accuracies.map((modelAcc, m) => {
          // A model that does not forecast step h must get no weight there: the weighted sum adds 0 for
          // it without renormalising, so any weight it received would bias the combined value low.
          if (h >= modelHorizons[m]) { return NaN; }
          return measurementType.Method === 'Average'
            ? modelAcc[p].filter(v => !Number.isNaN(v)).avg()
            : modelAcc[p][hAcc];
        });

        rollingWeights[p + 1][h] = getWeights(acc, exponent, inverse);
      }
    }

    return getWeightedPast(pastForecasts.map(m => m.Pasts), rollingWeights, horizon);
  }

  /** True for measures where a lower value is better (all but RSQ and HR). */
  function isInverseMeasure(measure: AccuracyMeasureType | 'EQUAL'): boolean {
    return measure !== 'RSQ' && measure !== 'HR';
  }

  /** Returns the model's stepwise accuracy values for the measure; unknown measures fall back to RMSE. */
  function getAccuracySteps(model: StatTransformationModel, measure: AccuracyMeasureType | 'EQUAL') {
    switch (measure) {
      case 'MAPE': return model.MAPESteps;
      case 'MAE': return model.MAESteps;
      case 'ME': return model.MESteps;
      case 'MASE': return model.MASESteps;
      case 'RSQ': return model.RSQSteps;
      case 'MPE': return model.MPESteps;
      case 'HR': return model.HRSteps;
      case 'RMSE':
      default:
        return model.RMSESteps;
    }
  }

  /** Multiplies value by weight when both are numbers (per `ValueUtils.isNum`); otherwise undefined. */
  function getWeightedValue(value: number, weight: number) {
    return ValueUtils.isNum(value) && ValueUtils.isNum(weight) ? value * weight : undefined;
  }

  /**
   * Writes the weighted values of a single model.
   * @param weights One weight per forecast step.
   * @param fittedValueCount Number of trailing fitted values to keep.
   */
  function applyModelWeights(
    model: StatTransformationModel,
    variableId: string,
    forecastedCount: number,
    weights: number[],
    fittedValueCount: number
  ) {
    const variable = model.Variables.find(x => x.Id === variableId);
    // Points before the forecast part use the first step's weight; forecast point k (0-based) uses weights[k].
    const firstForecastIdx = Math.max(variable.Data.length - forecastedCount, 0);
    for (let i = 0; i < variable.Data.length; i++) {
      setWeightedValue(variable.Data[i] as PlotValueDTO, weights[Math.max(i - firstForecastIdx, 0)]);
    }

    for (let i = 0; i < model.PastForecasts.length; i++) {
      const pastValues = model.PastForecasts[i].Values;
      model.PastForecasts[i]['WValues'] = pastValues.map((v, k) => v * weights[k]);
    }

    if (model.FittedValues != null) {
      // Keep the trailing `fittedValueCount` points. The start index is explicit because
      // `slice(-0)` would keep every point instead of none.
      const start = Math.max(model.FittedValues.length - fittedValueCount, 0);
      model.WFittedValues = model.FittedValues.slice(start).map(x => ({ ...x, WF: getWeightedValue(x.F, weights[0]) }));
    }

    if (model.ShapValues != null) {
      model.WShapValues = model.ShapValues.map(x => StatModelUtils.copyShap(x));
      for (let varIdx = 0; varIdx < model.WShapValues.length; varIdx++) {
        model.WShapValues[varIdx].Values = model.WShapValues[varIdx].Values.map((x, i) => ({
          ...x, V: getWeightedValue(x.V, weights[i])
        }));
      }
    }
  }

  /**
   * Stores the weighted contribution of one model on a data point: for every present (non-null,
   * including 0) field `X` in V/I50/I75/I95/A50/A75/A95, sets `WX = X * weight`.
   * @param value The data point to update in place.
   * @param weight The weight for this model.
   */
  function setWeightedValue(value: PlotValueDTO, weight: number) {
    const propsToSet = ['V', 'I50', 'I75', 'I95', 'A50', 'A75', 'A95'];
    propsToSet.forEach(prop => {
      if (value[prop] || value[prop] === 0) {
        value['W' + prop] = value[prop] * weight;
      }
    });
  }

  /**
   * Normalised weights from accuracy numbers.
   * @param values Accuracy numbers to calculate weights from, one number per model. NaN marks a model
   *               that should not take part (it receives weight 0).
   * @param exponent The exponent to use to bias the result towards the more accurate models.
   * @param inverse Should the measure be weighted as an inverse (true for all but HR and RSQ where
   *                higher is better). Inverse measures are weighted by magnitude, so signed measures
   *                (ME, MPE) rank by their distance from zero.
   * @returns Weights-array summing to 1 over the non-NaN entries.
   */
  function getWeights(values: number[], exponent: number, inverse: boolean = true) {
    // Maximum allowed score per model so that the sum cannot overflow.
    const maxValue = Number.MAX_VALUE / values.length;
    const scores = values
      .map(x => inverse ? Math.pow(1 / (Math.abs(x) + Number.MIN_VALUE), exponent) : Math.pow(x, exponent))
      .map(v => Number.isNaN(v) ? v : (!Number.isFinite(v) || v > maxValue ? maxValue : v));
    const valid = scores.filter(v => !Number.isNaN(v));
    let sum = valid.sum();
    if (!Number.isFinite(sum)) {
      sum = Number.MAX_VALUE;
    } else if (sum === 0) {
      const equal = 1.0 / valid.length;
      return scores.map(v => Number.isNaN(v) ? 0 : equal);
    }
    return scores.map(x => Number.isNaN(x) ? 0 : x / sum);
  }

  /**
   * Sums the models' past forecasts, step by step, using the rolling weights.
   * @param pastForecasts is [model][pasts]
   * @param weights is [PastCount][Horizon][ModelIndex]
   * @param horizon Number of forecast steps to produce per past forecast.
   */
  function getWeightedPast(pastForecasts: PastForecastDTO[][], weights: number[][][], horizon: number) {
    const result: PastForecastDTO[] = [];
    const pastCount = pastForecasts[0].length;
    for (let p = 0; p < pastCount; p++) {
      // The weight index runs opposite to the `Pasts` index, mirroring getRollingAccuracy's ordering.
      const pastWeights = weights[pastCount - p - 1];
      const stepValues: number[] = [];
      for (let h = 0; h < horizon; h++) {
        // Models whose past forecast is shorter than h contribute 0 at that step.
        stepValues.push(pastForecasts.map((m, i) => m[p].Values.length > h ? m[p].Values[h] * pastWeights[h][i] : 0).sum());
      }
      result.push({ Step: p + 1, Values: stepValues, WValues: [] });
    }
    return result;
  }

  /**
   * Computes (or reuses from `rollingAccuracy`) a model's rolling accuracy per past step and horizon.
   * @returns The cache entry; `accuracies` is [past][h]. Steps beyond the model's own horizon are NaN,
   *          steps that could not be evaluated keep the default 1.
   */
  function getRollingAccuracy(
    rollingAccuracy: AccuracyModelMap,
    horizon: number,
    pastCount: number,
    weightMeasure: WeightMeasure,
    modelInfo: { Pasts: PastForecastDTO[], Model: string; },
    historicValues: number[]
  ) {
    // The accuracies only depend on the Measure (not the Method), so an entry computed for the same
    // measure is reusable — provided it was computed for the same number of past forecasts.
    const accEntry = rollingAccuracy.get(modelInfo.Model);
    if (!!accEntry && accEntry.type.Measure === weightMeasure.Measure && accEntry.accuracies.length === pastCount) {
      return accEntry;
    }

    // [horizon][past-step]
    const pastPerHorizon: number[][] = Array.from({ length: horizon }, () => Array.from({ length: pastCount }, () => NaN));
    for (let past = 0; past < pastCount; past++) {
      for (let h = 0; h < horizon; h++) {
        pastPerHorizon[h][past] = modelInfo.Pasts[pastCount - past - 1].Values[h];
      }
    }

    const modelHorizon = modelInfo.Pasts[0].Values.length;
    const defaultAcc = [...Array.from({ length: modelHorizon }, () => 1), ...Array.from({ length: horizon - modelHorizon }, () => NaN)];

    // [past-step][horizon]
    const accPerHorizon: number[][] = Array.from({ length: pastCount }, () => [...defaultAcc]);
    for (let h = 0; h < modelHorizon; h++) {
      let accumulator = 0;
      let accumulator2 = 0;
      let accumulator3 = 0;
      let naiveCount = 0;
      let work = 0;
      for (let past = 0; past < pastCount; past++) {
        // The actual value for this forecast is not known yet.
        if (past + h >= pastCount) {
          continue;
        }
        const hist = historicValues[past + h];
        const fcast = pastPerHorizon[h][past];
        switch (weightMeasure.Measure) {
          // Unknown measures are treated as RMSE.
          default:
          case 'RMSE':
            work = hist - fcast; // E
            work *= work; // S
            accumulator += work;
            accPerHorizon[past + h][h] = Math.sqrt(accumulator / (past + 1)); // RM
            break;
          case 'MAPE':
            work = Math.abs((hist - fcast) / (hist + Number.EPSILON) * 100); // APE
            accumulator += work;
            accPerHorizon[past + h][h] = accumulator / (past + 1); // M
            break;
          case 'MAE':
            work = hist - fcast; // E
            accumulator += Math.abs(work); // A
            accPerHorizon[past + h][h] = accumulator / (past + 1); // M
            break;
          case 'ME':
            work = hist - fcast; // E
            accumulator += work;
            accPerHorizon[past + h][h] = accumulator / (past + 1); // M
            break;
          case 'MPE':
            work = (hist - fcast) / (hist + Number.EPSILON) * 100; // PE
            accumulator += work;
            accPerHorizon[past + h][h] = accumulator / (past + 1); // M
            break;
          case 'EQUAL':
            accPerHorizon[past + h][h] = 1;
            break;
          case 'HR':
            // Hit when the forecast moves in the same direction as the actuals.
            if (past + h > 0 && past > 0) {
              const sign = hist - historicValues[past + h - 1] > 0;
              const fSign = fcast - pastPerHorizon[h][past - 1] > 0;
              accumulator += sign === fSign ? 1 : 0;
            }
            accPerHorizon[past + h][h] = accumulator / (past + 1);
            break;
          case 'RSQ': {
            // accumulator: total of actuals, accumulator2: total of squared actuals, accumulator3: SSE
            accumulator += hist;
            accumulator2 += hist * hist;
            const deviance = fcast - hist;
            accumulator3 += deviance * deviance; // SSE
            const sst = accumulator2 - accumulator * accumulator / (past + 1);
            accPerHorizon[past + h][h] = Math.max(1 - accumulator3 / (sst + Number.EPSILON), 0);
            break;
          }
          case 'MASE':
            accumulator += Math.abs(fcast - hist); // Sum of absolute errors
            if (past + h > 0) {
              accumulator2 += Math.abs(hist - historicValues[past + h - 1]); // Naive sum of absolute errors
              naiveCount++;
              // Mean absolute error scaled by the mean naive error. The naive term has its own count:
              // for h > 0 it is already available at past 0, where dividing by `past` would divide by zero.
              accPerHorizon[past + h][h] = (accumulator / (past + 1)) / (accumulator2 / naiveCount + Number.EPSILON);
            }
            break;
        }
      }
    }

    const newEntry = { accuracies: accPerHorizon, type: { Measure: weightMeasure.Measure, Method: weightMeasure.Method } };
    rollingAccuracy.set(modelInfo.Model, newEntry);
    return newEntry;
  }
}
