import { apiSlice } from './apiSlice';
import { ModelFilesParameters, ModelParameters } from '../types/model';

// TODO: move TrainingData into a shared types module instead of defining it in this API slice.
/**
 * Training metrics keyed by metric name, each holding an array of numbers.
 *
 * NOTE(review): these look like parallel per-epoch series (`epochs[i]` pairing with
 * `loss[i]`, `valLoss[i]`, ...), with `val*` keys being the validation counterparts —
 * confirm against the backend payload.
 */
export type TrainingData = Record<
  | 'epochs'
  | 'lr'
  | 'loss'
  | 'valLoss'
  | 'diceCoefficient'
  | 'valDiceCoefficient'
  | 'iou'
  | 'valIou',
  number[]
>;

/**
 * Model-related RTK Query endpoints, injected into the shared `apiSlice`:
 * recording models, generating/transferring training scripts, reading training
 * history and model parameters, and listing/deleting/downloading trained models.
 */
export const modelsApiSlice = apiSlice.injectEndpoints({
  endpoints: (builder) => ({
    /** Records a new model for a dataset version; invalidates every cached `Datasets` entry. */
    recordNewModel: builder.mutation<any, { userId: number; datasetName: string; versionName: string; data: ModelParameters }>({
      query: ({ userId, datasetName, versionName, data }) => ({
        url: 'model/record_new_model',
        method: 'POST',
        body: {
          user_id: userId,
          dataset_name: datasetName,
          version_name: versionName,
          data,
        },
      }),
      invalidatesTags: ['Datasets'],
    }),

    /** Asks the templating service to generate scripts from `modelFilesParameters` for an instance. */
    generateModelFiles: builder.mutation<any, { modelFilesParameters: ModelFilesParameters, instance_id: string }>({
      query: ({ modelFilesParameters, instance_id }) => ({
        url: 'templating/generate-scripts',
        method: 'POST',
        body: {
          instanceId: instance_id,
          data: modelFilesParameters,
        },
      }),
    }),

    /** Asks the templating service to transfer the files of `modelId` to the given instance. */
    transferModelFiles: builder.mutation<any, { instance_id: string, modelId: string }>({
      query: ({ instance_id, modelId }) => ({
        url: 'templating/transfer_files',
        method: 'POST',
        body: {
          instanceId: instance_id,
          modelId,
        },
      }),
    }),

    /** Fetches stored training metrics for one model; resolves to `response.trainingMetricData`. */
    getTrainingHistory: builder.query<any, { user_id: number, modelId: string }>({
      query: ({ user_id, modelId }) => ({
        // Encode the id so characters such as '/' or '?' cannot alter the path.
        url: `data/past_data/${user_id}/${encodeURIComponent(modelId)}`,
        method: 'GET',
      }),
      transformResponse: (response: any) => response.trainingMetricData,
      // Build the tag from the query argument, not from `result`: RTK Query passes
      // `result === undefined` when the request fails, so reading `result.user_id`
      // threw. Tags are matched only by `type` + `id`, so no extra keys are needed.
      providesTags: (_result, _error, { user_id }) => [{ type: 'TrainedModelMetrics', id: user_id }],
    }),

    /** Fetches the stored parameters of one model; resolves to `response.model_parameters`. */
    getModelParameters: builder.query<any, { userId: number, datasetName: string, versionName: string, modelName: string, modelId: string }>({
      query: ({ userId, datasetName, versionName, modelName, modelId }) => ({
        url: 'model/get_model_parameters',
        method: 'POST',
        body: {
          user_id: userId,
          dataset_name: datasetName,
          version_name: versionName,
          model_name: modelName,
          model_id: modelId,
        },
      }),
      transformResponse: (response: any) => response.model_parameters,
      // Same reasoning as getTrainingHistory: `result` is undefined on error, so
      // the tag id comes from the (always present) query argument.
      providesTags: (_result, _error, { userId }) => [{ type: 'ModelParameters', id: userId }],
    }),

    /** Lists a user's trained models. */
    getTrainedModels: builder.query<any, { user_id: number }>({
      query: ({ user_id }) => ({
        url: 'model/get_trained_models',
        method: 'POST',
        body: { user_id },
      }),
      // Drop the cached list as soon as nothing subscribes, so each new mount refetches.
      keepUnusedDataFor: 0,
    }),

    /**
     * Deletes the given trained models for a user.
     * NOTE(review): invalidates no tags, so a currently mounted getTrainedModels list is
     * not refreshed automatically — consider a tag once one is declared in apiSlice.
     */
    deleteTrainedModels: builder.mutation<any, { user_id: number; model_ids: string[] }>({
      query: ({ user_id, model_ids }) => ({
        url: 'model/delete_trained_models',
        method: 'POST',
        body: { user_id, model_ids },
      }),
    }),

    /** Requests a model download; the caller's argument is sent wrapped as `{ data }`. */
    downloadModel: builder.mutation<any, any>({
      query: (data) => ({
        url: 'model/download_model',
        method: 'POST',
        body: { data },
      }),
    }),
  }),
});

// React hooks that RTK Query generates for each endpoint of modelsApiSlice.
export const {
  // Queries
  useGetTrainingHistoryQuery,
  useGetModelParametersQuery,
  useGetTrainedModelsQuery,
  // Mutations
  useRecordNewModelMutation,
  useGenerateModelFilesMutation,
  useTransferModelFilesMutation,
  useDeleteTrainedModelsMutation,
  useDownloadModelMutation,
} = modelsApiSlice;
