import { useEffect, useState } from 'react';
import toast from 'react-hot-toast';
import { getErrorMessage } from '../common/utils';

import { PromptComposer, PromptVersionSelector, Selector, TabGroup } from '../components';
import { PV } from '../components/common/PromptVersionSelector';
import {
  calculateSentiment,
  getModels,
  gradeModels,
  runModel,
  scoreEvaluation,
  summarizeResults
} from '../services/Models';
import { Model } from '../types';
import { ModelRunResult } from '../types/Models';
import { CompletionMetric } from '../types/Performance';
import { ModelPicker, PerformanceGraph } from '../components/performance';
import { mean } from '../common/math';
import { SelectorValue } from '../components/common/Selector';
import { EvaluationType, Grades, RatedEvaluation, Sentiment, Summary } from '../types/Evaluations';
import ProgressTracker, { Progress } from '../components/performance/Progress';
import Skeleton from 'react-loading-skeleton';
import ResultsTable from '../components/performance/ResultsTable';
import { translateModelParameters } from '../common/models';
import { FontAwesomeIcon } from '@fortawesome/react-fontawesome';
import { faPersonRunning } from '@fortawesome/free-solid-svg-icons';
import { PromptVersionTypes } from '../types/Prompt';

/**
 * Run-count choices offered by the "runs" selector on the configuration tab.
 * Each entry carries the numeric count and its display label.
 */
const runsSelectorValues: SelectorValue[] = [1, 5, 10, 15].map((runs) => ({
  value: runs,
  label: `${runs} ${runs === 1 ? 'Run' : 'Runs'}`
}));

/**
 * Props accepted by the Performance page component.
 * Currently empty: the component declares no props.
 */
interface Props {}

/**
 * Evaluation results collected while grading the completions of a finished run.
 * Returned by the grading step so that follow-up steps can use them directly
 * instead of reading React state that has not been re-rendered yet.
 */
interface CompletionEvaluations {
  sentiments: Sentiment[];
  coherence: RatedEvaluation[];
  fluency: RatedEvaluation[];
}

/**
 * Returns a randomly ordered copy of `items`; the input array is not modified.
 * Each item is paired with a random key and the pairs are sorted by that key,
 * which keeps the comparator consistent (unlike `sort(() => Math.random() - 0.5)`).
 */
function shuffle<T>(items: readonly T[]): T[] {
  return items
    .map((item) => ({ item, key: Math.random() }))
    .sort((a, b) => a.key - b.key)
    .map(({ item }) => item);
}

/**
 * Performance page component.
 *
 * Lets the user pick a prompt version, a set of models and a run count, then:
 *  1. runs the generated prompt against every selected model `runCount` times,
 *  2. aggregates latency / TTFB / cost / token metrics per model,
 *  3. evaluates sentiment, coherence and fluency of the completions per model,
 *  4. requests a textual summary and per-model grades based on all of the above.
 *
 * @component
 * @param {Props} props - The component props.
 * @returns {JSX.Element} The rendered component.
 */
const Performance: React.FC<Props> = ({}: Props) => {
  const [isBusy, setIsBusy] = useState<boolean>(false);
  const [models, setModels] = useState<Model[]>([]);
  const [modelNameMap, setModelNameMap] = useState<Record<string, string>>({});
  const [apv, setApv] = useState<PV>();
  const [selectedModelIds, setSelectedModelIds] = useState<string[]>([]);
  const [selectedRunCount, setSelectedRunCount] = useState<SelectorValue>(runsSelectorValues[1]);
  const [completionMetrics, setCompletionMetrics] = useState<CompletionMetric[]>([]);
  const [rawMetrics, setRawMetrics] = useState<ModelRunResult[]>([]);
  const [progress, setProgress] = useState<Progress>();
  const [currentPayload, setCurrentPayload] = useState<any>({});
  const [generatedPrompt, setGeneratedPrompt] = useState<string>();
  const [completionSentiments, setCompletionSentiments] = useState<Sentiment[]>([]);
  const [completionCoherence, setCompletionCoherence] = useState<RatedEvaluation[]>([]);
  const [completionFluency, setCompletionFluency] = useState<RatedEvaluation[]>([]);
  const [summary, setSummary] = useState<Summary>();
  const [grades, setGrades] = useState<Grades>();
  const [selectedTab, setSelectedTab] = useState<number>(0);

  /** Clears everything produced by a run: raw results, aggregates, evaluations, progress, summary and grades. */
  const resetMetrics = () => {
    setRawMetrics([]);
    setCompletionMetrics([]);
    setCompletionSentiments([]);
    setCompletionCoherence([]);
    setCompletionFluency([]);
    setProgress(undefined);
    setSummary(undefined);
    setGrades(undefined);
  };

  /** Clears the selected prompt version, generated prompt, payload and all run output. */
  const reset = () => {
    setIsBusy(false);
    setApv(undefined);
    setGeneratedPrompt(undefined);
    setCurrentPayload(undefined);
    resetMetrics();
  };

  /**
   * Records that `completedCount` steps (model runs plus evaluation calls) have finished.
   * Does nothing when no run is being tracked, e.g. after `reset()` cleared the progress
   * while results were still arriving.
   */
  const updateProgress = (completedCount: number) => {
    setProgress((prev) => {
      if (!prev) return prev;

      return {
        ...prev,
        completedCount,
        endTime: Date.now(),
        // A run counts as failed when it is flagged as such or produced a blank/missing completion.
        failedCount: rawMetrics.filter((m) => m.failed || !m.completion?.trim().length).length,
        status: `Completed ${completedCount} of ${prev.totalCount}`,
        percentageComplete: (completedCount / prev.totalCount) * 100,
        isDone: completedCount === prev.totalCount
      };
    });
  };

  /**
   * Evaluates sentiment, coherence and fluency of each model's unique, non-empty completions.
   *
   * Every evaluation call is best-effort: a failure is logged and skipped, and progress is
   * advanced either way so the tracker can still reach completion. Results are pushed to
   * state as they arrive (for the graph) and also returned, because callers running in the
   * same render cannot observe those state updates.
   */
  const gradeCompletions = async (): Promise<CompletionEvaluations> => {
    setCompletionSentiments([]);
    setCompletionCoherence([]);
    setCompletionFluency([]);

    const evaluations: CompletionEvaluations = { sentiments: [], coherence: [], fluency: [] };

    // Without a prompt there is nothing to evaluate the completions against.
    if (!generatedPrompt) return evaluations;

    // Group unique completions by model name. Every model gets a group, even if empty,
    // so that its three evaluation steps are still counted towards progress.
    const completionsByModel = rawMetrics.reduce((acc: Record<string, Set<string>>, result: ModelRunResult) => {
      let completions = acc[result.modelName];
      if (!completions) {
        completions = new Set<string>();
        acc[result.modelName] = completions;
      }

      // Ignore missing or empty completions rather than failing on them.
      if (typeof result.completion === 'string' && result.completion.length > 0) {
        completions.add(result.completion);
      }

      return acc;
    }, {});

    let evaluatedSteps = 0;
    const advance = () => updateProgress(rawMetrics.length + ++evaluatedSteps);

    for (const [model, completionSet] of Object.entries(completionsByModel)) {
      const uniqueCompletions = Array.from(completionSet);

      try {
        const sentiment = await calculateSentiment(uniqueCompletions);
        sentiment.model = model;
        evaluations.sentiments.push(sentiment);
        setCompletionSentiments((prev) => [...prev, sentiment]);
      } catch (error) {
        console.error(error);
      }
      advance();

      try {
        const coherence = await scoreEvaluation(EvaluationType.COHERENCE, generatedPrompt, uniqueCompletions);
        coherence.model = model;
        evaluations.coherence.push(coherence);
        setCompletionCoherence((prev) => [...prev, coherence]);
      } catch (error) {
        console.error(error);
      }
      advance();

      try {
        const fluency = await scoreEvaluation(EvaluationType.FLUENCY, generatedPrompt, uniqueCompletions);
        fluency.model = model;
        evaluations.fluency.push(fluency);
        setCompletionFluency((prev) => [...prev, fluency]);
      } catch (error) {
        console.error(error);
      }
      advance();
    }

    return evaluations;
  };

  /**
   * Builds the per-model payload sent to the summary and grading services.
   *
   * @param finalMetrics - Aggregated metrics, one entry per model.
   * @param evaluations - Evaluation results for the same run, matched to metrics by model name.
   */
  const getRollupPayload = (finalMetrics: CompletionMetric[], evaluations: CompletionEvaluations): any[] => {
    return finalMetrics.map((metric) => {
      const sentiment = evaluations.sentiments.find((s) => s.model === metric.model);
      const coherence = evaluations.coherence.find((c) => c.model === metric.model);
      const fluency = evaluations.fluency.find((f) => f.model === metric.model);

      return {
        model: metric.model,
        latency: metric.avgLatency,
        ttfb: metric.avgTtfb,
        cost: metric.avgTotalCost,
        request_cost: metric.avgRequestCost,
        response_tokens: metric.avgResponseTokens,
        sentiment: sentiment?.results,
        coherence: coherence?.ratings,
        fluency: fluency?.ratings
      };
    });
  };

  /** Requests the textual analysis for the given rollup payload; failures are reported as a toast. */
  const generateSummary = async (payload: any[]) => {
    try {
      setSummary(await summarizeResults(payload));
    } catch (error) {
      toast.error(getErrorMessage(error));
    }
  };

  /** Requests per-model grades for the given rollup payload; failures are reported as a toast. */
  const generateModelGrades = async (payload: any[]) => {
    try {
      setGrades(await gradeModels(payload));
    } catch (error) {
      toast.error(getErrorMessage(error));
    }
  };

  // Load the available models once on mount and index their display names by model id.
  useEffect(() => {
    void (async () => {
      setIsBusy(true);
      try {
        const data = await getModels();
        setModels(data);
        setModelNameMap(
          data.reduce((acc: Record<string, string>, model) => {
            acc[model.mid] = model.name;
            return acc;
          }, {})
        );
      } catch (error) {
        toast.error(getErrorMessage(error));
      } finally {
        setIsBusy(false);
      }
    })();
  }, []);

  // Re-aggregate per-model metrics whenever a new raw result arrives and, once every
  // expected result is in, kick off evaluation followed by summary and grading.
  // Intentionally keyed on `rawMetrics` only: the other values read here are the ones
  // belonging to the render that produced this `rawMetrics`.
  useEffect(() => {
    if (!rawMetrics.length) return;

    const metricsByModel = rawMetrics.reduce((acc: Record<string, any>, result: ModelRunResult) => {
      if (!acc[result.modelId]) {
        acc[result.modelId] = {
          model: result.modelName,
          completion: undefined,
          latencies: [],
          ttfbs: [],
          requestCosts: [],
          responseCosts: [],
          requestTokens: [],
          responseTokens: []
        };
      }

      // Only successful runs contribute samples; failed runs still create the model entry.
      if (!result.failed) {
        const { modelId, latency, ttfb, tokens, completion } = result;
        const { requestCost, responseCost, requestTokens, responseTokens } = tokens;
        acc[modelId].completion = completion;
        acc[modelId].latencies.push(latency);
        acc[modelId].ttfbs.push(ttfb);
        acc[modelId].requestCosts.push(requestCost);
        acc[modelId].responseCosts.push(responseCost);
        acc[modelId].requestTokens.push(requestTokens);
        acc[modelId].responseTokens.push(responseTokens);
      }
      return acc;
    }, {});

    // NOTE(review): for a model with no successful runs the sample arrays are empty, so
    // Math.max/Math.min yield -Infinity/Infinity here — confirm how the graph should treat that.
    const metricsAvg: CompletionMetric[] = Object.values(metricsByModel).map((data: any) => {
      const avgRequestCost = mean(data.requestCosts);
      const avgResponseCost = mean(data.responseCosts);

      return {
        model: data.model,
        completion: data.completion,
        avgLatency: mean(data.latencies),
        maxLatency: Math.max(...data.latencies),
        minLatency: Math.min(...data.latencies),
        avgTtfb: mean(data.ttfbs),
        maxTtfb: Math.max(...data.ttfbs),
        minTtfb: Math.min(...data.ttfbs),
        avgRequestCost,
        avgResponseCost,
        avgTotalCost: avgRequestCost + avgResponseCost,
        avgRequestTokens: mean(data.requestTokens),
        avgResponseTokens: mean(data.responseTokens)
      } as CompletionMetric;
    });

    const isMetricsDone = rawMetrics.length === progress?.totalMetrics;

    setCompletionMetrics(metricsAvg);
    updateProgress(rawMetrics.length);

    if (isMetricsDone) {
      void (async () => {
        // Summary and grading need the evaluation results, so grading must finish first
        // and its results are passed along explicitly (state from this render is stale).
        const evaluations = await gradeCompletions();
        const payload = getRollupPayload(metricsAvg, evaluations);
        await Promise.all([generateSummary(payload), generateModelGrades(payload)]);
      })();
    }
  }, [rawMetrics]);

  /** Stores the selected prompt version, or resets the page when the version is cleared. */
  const onPVChange = (next: PV) => {
    if (next.version) {
      setApv(next);
    } else {
      reset();
    }
  };

  /** Adds or removes a model id from the selection; selecting an already selected id is a no-op. */
  const onModelSelect = (mid: string, checked: boolean) => {
    if (checked) {
      setSelectedModelIds((prev) => (prev.includes(mid) ? prev : [...prev, mid]));
    } else {
      setSelectedModelIds((prev) => prev.filter((id) => id !== mid));
    }
  };

  /**
   * Validates the configuration and launches `runCount` runs for every selected model.
   * Each finished run is appended to `rawMetrics`; the aggregation effect takes over from there.
   */
  const start = async () => {
    if (!apv?.version) return toast.error('Please select a prompt version');
    if (!selectedModelIds.length) return toast.error('Please select at least one model');
    if (!generatedPrompt) return toast.error('Please validate the prompt.');

    const version = apv.version;
    const runCount = Number(selectedRunCount.value);
    const totalMetrics = selectedModelIds.length * runCount;
    const totalCount = totalMetrics + selectedModelIds.length * 3; // sentiment, coherence, fluency

    resetMetrics();
    setProgress({
      startTime: Date.now(),
      endTime: Date.now(),
      completedCount: 0,
      totalMetrics,
      totalCount,
      percentageComplete: 0,
      status: 'Starting',
      failedCount: 0,
      isDone: false
    });

    setIsBusy(true);

    try {
      const jobs = selectedModelIds.flatMap((modelId) =>
        Array.from({ length: runCount }, () => ({
          modelId,
          translatedParameters: translateModelParameters(
            models,
            version.model,
            modelId,
            version.parameters,
            version.type
          )
        }))
      );

      setSelectedTab(1);

      // Requests are issued in the order the jobs are started, so the shuffle has to happen
      // before launching them. The axios interceptor will throttle the requests.
      // NOTE(review): a rejected run is only reported via toast and adds no raw result, so
      // `totalMetrics` is never reached and evaluation/summary will not start in that case.
      await Promise.all(
        shuffle(jobs).map(async ({ modelId, translatedParameters }) => {
          try {
            const result = await runModel(modelId, generatedPrompt, translatedParameters);
            result.modelName = modelNameMap[result.modelId] || result.modelId;
            setRawMetrics((prev) => [...prev, result]);
          } catch (error) {
            toast.error(getErrorMessage(error));
          }
        })
      );
    } catch (error) {
      toast.error(getErrorMessage(error));
    } finally {
      // Always release the UI, even if preparing the jobs threw.
      setIsBusy(false);
    }
  };

  /** Total cost of the run so far: model runs, evaluations, summary and grading. */
  const getCost = (): number => {
    // `tokens` is only read on non-failed results elsewhere, so tolerate its absence here.
    const runCost = rawMetrics.reduce(
      (total, m) => total + ((m.tokens?.requestCost ?? 0) + (m.tokens?.responseCost ?? 0)),
      0
    );

    return (
      runCost +
      completionSentiments.reduce((a, b) => a + b.totalCost, 0) +
      completionCoherence.reduce((a, b) => a + b.totalCost, 0) +
      completionFluency.reduce((a, b) => a + b.totalCost, 0) +
      (summary?.totalCost ?? 0) +
      (grades?.totalCost ?? 0)
    );
  };

  return (
    <TabGroup
      tabNames={['Configuration', 'Results', 'Data']}
      selectedTabIndex={selectedTab}
      onTabChange={setSelectedTab}>
      <>
        <div className="mx-auto">
          {selectedTab === 0 && (
            <div>
              <div className="mb-5"></div>
              <PromptVersionSelector
                selected={apv}
                onChange={onPVChange}
                className="mb-6"
                defaultLabels={['Select Prompt', 'Select Version']}
              />
              <div className="mt-4">
                <div className="flex gap-x-4">
                  <PromptComposer
                    template={
                      apv?.version?.type === PromptVersionTypes.MESSAGING
                        ? JSON.stringify(apv?.version?.messageTemplate)
                        : apv?.version?.template
                    }
                    payload={currentPayload}
                    disabled={isBusy}
                    onPromptGenerate={setGeneratedPrompt}
                    onUnload={setCurrentPayload}
                  />
                  <div className="w-64">
                    <ModelPicker
                      models={models}
                      selectedModels={selectedModelIds}
                      disabled={isBusy}
                      onModelClick={onModelSelect}
                    />
                    <div className="mt-4 flex">
                      <div className="flex-1">
                        <Selector
                          values={runsSelectorValues}
                          onChange={setSelectedRunCount}
                          defaultValue={selectedRunCount}
                          classNames="w-28"
                          disabled={isBusy}
                        />
                      </div>
                      <div className="text-right">
                        <button className="standard" onClick={start} disabled={isBusy}>
                          <FontAwesomeIcon icon={faPersonRunning} className="mr-1" />
                          Start
                        </button>
                      </div>
                    </div>
                  </div>
                </div>
              </div>
            </div>
          )}
          {selectedTab === 1 && (
            <div>
              {progress && <ProgressTracker progress={progress} cost={getCost()} />}

              <div className="my-6 grid grid-cols-2 gap-x-16 text-sm text-gray-700">
                <div>
                  <h2 className="text-lg font-semibold text-gray-700 mb-0">Analysis</h2>
                  {summary ? summary.results : <Skeleton count={3} />}
                </div>
                <div>
                  <h2 className="text-lg font-semibold text-gray-700 mb-0">Model Grading</h2>
                  {grades ? (
                    <div className="grid grid-cols-2 grow-0 w-72">
                      {Object.entries(grades.grades)
                        .sort(([a], [b]) => a.localeCompare(b))
                        .map(([key, value]) => (
                          <div key={key}>
                            <div className="text-gray-500">{key}</div> <div className="font-semibold">{value}</div>
                          </div>
                        ))}
                    </div>
                  ) : (
                    <Skeleton count={3} />
                  )}
                </div>
              </div>
              <PerformanceGraph
                metrics={completionMetrics}
                sentiments={completionSentiments}
                coherence={completionCoherence}
                fluency={completionFluency}
              />
            </div>
          )}
          {selectedTab === 2 && <ResultsTable results={rawMetrics} />}
        </div>
      </>
    </TabGroup>
  );
};

export default Performance;
