import { useMutation, useQuery } from "@tanstack/react-query";
import { FewShotTaggingModelsApi } from "./Api";
import { useReducer } from "react";

import {
  initialNewShotLearningModelsComponentState,
  newShotLearningModelsComponentReducer,
} from "./Reducer";
import {
  ClassificationMethod,
  FewShotTaggingModelView,
} from "../../app_client";

/**
 * React hook that wires the few-shot tagging models screen together.
 *
 * It owns the component state (via `newShotLearningModelsComponentReducer`),
 * keeps the list of models fresh through a polling query, and exposes the
 * user-triggered actions of the screen.
 *
 * @param deps Injected collaborators:
 *   - `api`: client used to fetch, create, disable, train and export models.
 *   - `readXLSX`: reads a spreadsheet file into rows of string cells.
 *   - `notifyError`: displays an error message to the user.
 * @returns The models query (`_models`), the reducer `state` and `dispatch`,
 *   and an `actions` object holding the event handlers.
 */
export const useProjectFewShotLearningModels = (deps: {
  api: FewShotTaggingModelsApi;
  readXLSX: (file: File) => Promise<string[][]>;
  notifyError: (message: string) => void;
}) => {
  const [state, dispatch] = useReducer(
    newShotLearningModelsComponentReducer,
    initialNewShotLearningModelsComponentState
  );

  // Polls the model list every 5 seconds and mirrors each successful result
  // into the reducer state. The query only runs when an API client is given.
  // NOTE(review): the `onSuccess` option of `useQuery` was removed in
  // TanStack Query v5 — confirm the installed version before upgrading.
  const _models = useQuery({
    queryKey: ["_fewShotTaggingModels"],
    queryFn: () => deps.api.fetchModels(),
    onSuccess: (data) => {
      dispatch({ type: "WORKSPACE_MODEL_FETCHED", payload: data });
    },
    enabled: !!deps.api,
    refetchInterval: 5000,
  });

  // Creates a model, then refreshes the list right away instead of waiting
  // for the next polling tick. Failures are reported with a generic message.
  const _createModel = useMutation({
    mutationFn: (data: {
      name: string;
      model_id: string;
      training_data: string[][];
      embeding_model: string;
      classification_method: ClassificationMethod;
    }) => deps.api.createModel(data),
    onSuccess: () => {
      void _models.refetch();
    },
    onError: () => {
      deps.notifyError(`Une erreur est survenue`);
    },
  });

  /**
   * Handles a file-input selection: records the first file, reads it as a
   * spreadsheet and stores the rows. Does nothing when no file was picked.
   * A read failure is reported to the user instead of rejecting.
   */
  const selectFile = async (files: FileList | null) => {
    const file = files?.[0];
    if (!file) return;

    dispatch({ type: "FILE_SELECTED", payload: file });

    let rows: string[][];
    try {
      rows = await deps.readXLSX(file);
    } catch {
      // Without this guard an unreadable file would surface only as an
      // unhandled promise rejection, leaving the user without feedback.
      deps.notifyError("Impossible de lire le fichier sélectionné");
      return;
    }
    dispatch({ type: "FILE_READ", payload: rows });
  };

  /**
   * Opens the creation modal once the model name has at least 4 characters
   * and a file has been selected; otherwise reports the first failing check.
   */
  const onOpenCreateModel = () => {
    if (state.modelName.length < 4) {
      deps.notifyError("Le nom du modèle doit contenir au moins 4 caractères");
    } else if (state.selectedFile === null) {
      deps.notifyError("Veuillez sélectionner un fichier");
    } else {
      dispatch({ type: "CREATE_MODAL_OPENED" });
    }
  };

  /** Asks the API client to download the training data of `model`. */
  const onDownloadTrainingData = (model: FewShotTaggingModelView) => {
    deps.api.downloadTrainingData(model);
  };

  /**
   * Validates that both the label and verbatim columns are chosen, then
   * moves past the file-configuration step.
   */
  const onConfirmFileConfiguration = () => {
    if (state.columnLabel === "" || state.columnVerbatim === "") {
      deps.notifyError("Veuillez sélectionner des colonnes valides");
      return;
    }
    dispatch({ type: "CONFIGURATION_STEP_PASSED" });
  };

  /**
   * Submits a new model built from the current state under a freshly
   * generated id. `MODEL_CREATED` is dispatched right after the mutation is
   * started, i.e. without waiting for the request to succeed.
   */
  const trainModel = () => {
    const modelId = `fewshot_tagging_model_${crypto.randomUUID()}`;
    _createModel.mutate({
      model_id: modelId,
      name: state.modelName,
      training_data: state.trainingData,
      embeding_model: state.modelEmbeddingModel,
      classification_method: state.modelClassificationMethod,
    });
    dispatch({ type: "MODEL_CREATED", payload: modelId });
  };

  /**
   * Requests that `model` be disabled and records it in local state.
   * The outcome of the API call is not checked here.
   */
  const disableModel = (model: FewShotTaggingModelView) => {
    deps.api.disableModel(model);
    dispatch({ type: "MODEL_DISABLED", payload: model.model_id });
  };

  /**
   * Requests a training run for `model` and records it in local state.
   * The outcome of the API call is not checked here.
   */
  const triggerModelTraining = (model: FewShotTaggingModelView) => {
    deps.api.triggerTraining(model);
    dispatch({ type: "MODEL_TRAINING_TRIGGERED", payload: model.model_id });
  };

  return {
    _models,
    state,
    dispatch,
    actions: {
      selectFile,
      onOpenCreateModel,
      onDownloadTrainingData,
      onConfirmFileConfiguration,
      trainModel,
      disableModel,
      triggerModelTraining,
    },
  };
};

/**
 * Shape of the value returned by `useProjectFewShotLearningModels`.
 * Derived with `ReturnType` so it always follows the hook's actual return.
 */
export type ProjectFewShotLearningModels = ReturnType<
  typeof useProjectFewShotLearningModels
>;
