import { useState, useEffect } from 'react';
import { useNavigate } from 'react-router-dom';
import { Typography, Grid, Button, Container, Box, TextField, CircularProgress, Alert, Modal, Tooltip, FormControl, InputLabel, Select, MenuItem, Divider, ListItemText, Popover, MenuList } from '@mui/material';
import CheckCircleOutlineIcon from '@mui/icons-material/CheckCircleOutline';
import { IntegerParameter, BooleanParameter, StringParameter, DropdownParameter } from '../components/inputFields';
import { VersionInformation } from '../types/dataset';
import { SelectChangeEvent } from '@mui/material/Select';
import { styled } from '@mui/material/styles';
import { useGetUserDatasetsQuery } from '../slices/datasetApiSlice';
import { useRecordNewModelMutation } from '../slices/modelsApiSlice';
import { NestedMenuItem } from '../components/NestedMenu';
import templates from '../data/templates.json';

/**
 * Builds a fresh list of the choices offered by the hyperparameter dropdowns.
 * Every dropdown currently offers the same two sample values; each constant
 * below receives its own array instance so the lists stay independent.
 */
const exampleOptionList = (): string[] => ["example1", "example2"];

const learningRateSchedulerOptions = exampleOptionList();
const trainingTypeOptions = exampleOptionList();
const baseModelArchitectureOptions = exampleOptionList();
const attentionTypeOptions = exampleOptionList();
const decoderBlockTypeOptions = exampleOptionList();
const encoderBlockTypeOptions = exampleOptionList();

/**
 * A hyperparameter preset that can be copied into the training form.
 * Field names are snake_case, matching the preset literals in this file.
 */
interface ModelTemplate {
  /** Value for the "Learning Rate Scheduler" dropdown. */
  learning_rate_scheduler: string;
  /** Value for the "Training Type" dropdown. */
  training_type: string;
  /** Value for the "Base Model Architecture" dropdown. */
  base_model_architecture: string;
  /** Value for the "Attention Type" dropdown. */
  attention_type: string;
  /** Value for the "Decoder Block Type" dropdown. */
  decoder_block_type: string;
  /** Value for the "Encoder Block Type" dropdown. */
  encoder_block_type: string;
  /** Initial learning rate. */
  learning_rate: number;
  /** Number of training epochs. */
  num_epochs: number;
  /** Training samples per batch. */
  batch_size: number;
  /** Size of image tiles. */
  tile_size: number;
  /** Epochs to wait before early stopping. */
  patience: number;
  /** Whether data augmentation is enabled. */
  data_augmentation: boolean;
}


/**
 * Built-in hyperparameter presets, keyed by template name.
 * The dropdown-valued fields hold the sample option values "example1" /
 * "example2" used by the option lists in this file.
 */
const ModelTemplates: Record<string, ModelTemplate> = {
  Medical: {
    learning_rate_scheduler: "example1",
    training_type: "example1",
    base_model_architecture: "example1",
    attention_type: "example1",
    decoder_block_type: "example1",
    encoder_block_type: "example1",
    learning_rate: 0.001,
    num_epochs: 300,
    batch_size: 16,
    tile_size: 224,
    patience: 10,
    data_augmentation: true,
  },
  "Quantity Surveying": {
    learning_rate_scheduler: "example2",
    training_type: "example2",
    base_model_architecture: "example2",
    attention_type: "example2",
    decoder_block_type: "example2",
    encoder_block_type: "example2",
    learning_rate: 0.01,
    num_epochs: 250,
    batch_size: 32,
    tile_size: 224,
    patience: 10,
    data_augmentation: true,
  },
};

/** Border colour shared by every interaction state of CustomFormControl. */
const customFormControlBorderColor = 'blue';

/**
 * FormControl whose outlined input keeps the same border colour at rest,
 * on hover and while focused.
 *
 * Uses the static style-object form of `styled` because no theme values are
 * needed (the previous callback destructured an unused `theme`).
 *
 * NOTE(review): this component is not exported and is not referenced anywhere
 * in this file — consider removing it if it stays unused.
 */
const CustomFormControl = styled(FormControl)({
  '& .MuiOutlinedInput-root': {
    '& fieldset': {
      borderColor: customFormControlBorderColor,
    },
    '&:hover fieldset': {
      borderColor: customFormControlBorderColor,
    },
    '&.Mui-focused fieldset': {
      borderColor: customFormControlBorderColor,
    },
  },
});

/**
 * "New Model Training Job" page.
 *
 * Lets the user choose a dataset and version, set hyperparameters (optionally
 * prefilled from a template picked in the "Select Template" popover) and name
 * the model. Submitting records the model via `useRecordNewModelMutation`.
 * After a successful save, a confirmation modal is shown. After 2s it switches
 * to a "redirecting" message, and 2s after that the page navigates to
 * `/instancedashboard`.
 */
export default function ModelParameters() {
  const navigate = useNavigate();

  const [ isRedirect, setIsRedirect ] = useState(false);
  const [ isModalOpen, setIsModalOpen ] = useState(false);
  // Anchor element for the "Select Template" popover; null while it is closed.
  const [ anchorEl, setAnchorEl ] = useState<HTMLButtonElement | null>(null);

  // NOTE(review): user id 1 is hard-coded here and in handleSubmit — confirm.
  const { data: userDatasets = [], error: loadDatasetError, isLoading: isLoadingDataset } = useGetUserDatasetsQuery(1);
  const [ modelVersions, setModelVersions ] = useState<VersionInformation[]>([]);
  const [ classes, setClasses ] = useState<string[]>([]);

  // Model Parameters
  const [ selectedDatasetName, setSelectedDatasetName ] = useState("");
  const [ selectedModelVersion, setSelectedModelVersion ] = useState("");
  const [ trainingType, setTrainingType ] = useState("");
  const [ learningRateScheduler, setLearningRateScheduler ] = useState("");
  const [ learningRate, setLearningRate ] = useState(NaN);
  const [ isReduceLR, setIsReduceLR ] = useState(false);
  const [ baseModelArchitecture, setBaseModelArchitecture ] = useState("");
  const [ isAttention, setIsAttention ] = useState(true);
  const [ attentionType, setAttentionType ] = useState("");
  const [ decoderType, setDecoderType ] = useState("");
  const [ encoderType, setEncoderType ] = useState("");
  const [ numEpochs, setNumEpochs ] = useState(NaN);
  const [ batchSize, setBatchSize ] = useState(NaN);
  const [ tileSize, setTileSize ] = useState(NaN);
  const [ patience, setPatience ] = useState(NaN);
  const [ isDataAug, setIsDataAug ] = useState(true);
  const [ baseModelName, setBaseModelName ] = useState("");

  // Api slice to record new model in database
  const [ recordNewModel, { isLoading: isLoadingRecordNewModel, isSuccess: isSuccessRecordNewModel, isError: isErrorRecordNewModel } ] = useRecordNewModelMutation();

  // Keep the version and class lists in sync with the selected dataset.
  useEffect(() => {
    const selectedDataset = userDatasets.find((ds) => ds.datasetName === selectedDatasetName);

    if (selectedDataset) {
      setModelVersions(selectedDataset.versions);
      setClasses(selectedDataset.filteredClasses);
    } else {
      // Clear lists left over from a dataset that is no longer selected/available.
      // Returning `prev` when already empty avoids a render loop: `userDatasets`
      // defaults to a fresh `[]` on every render while the query has no data.
      setModelVersions((prev) => (prev.length === 0 ? prev : []));
      setClasses((prev) => (prev.length === 0 ? prev : []));
    }
  }, [selectedDatasetName, userDatasets]);

  // After a successful save: open the modal, show the redirect message after 2s,
  // then navigate 2s later. Timers are cleared if the component unmounts first.
  useEffect(() => {
    if (!isSuccessRecordNewModel) {
      return;
    }

    setIsModalOpen(true);

    let navigateTimer: ReturnType<typeof setTimeout> | undefined;
    const redirectTimer = setTimeout(() => {
      setIsRedirect(true);
      navigateTimer = setTimeout(() => {
        navigate('/instancedashboard');
      }, 2000);
    }, 2000);

    return () => {
      clearTimeout(redirectTimer);
      if (navigateTimer !== undefined) {
        clearTimeout(navigateTimer);
      }
    };
  }, [isSuccessRecordNewModel, navigate]);

  /**
   * Copies the preset matching `templateLabel` into the form fields.
   * NOTE(review): labels come from templates.json and may not match the keys of
   * ModelTemplates; unknown labels fall back to "Medical", which is what was
   * previously applied for every selection.
   */
  const applyModelTemplate = (templateLabel: string) => {
    if (!templateLabel) {
      return;
    }

    const template = ModelTemplates[templateLabel] ?? ModelTemplates["Medical"];
    if (!template) {
      return;
    }

    setLearningRateScheduler(template.learning_rate_scheduler);
    setLearningRate(template.learning_rate);
    setTrainingType(template.training_type);
    setBaseModelArchitecture(template.base_model_architecture);
    setAttentionType(template.attention_type);
    setDecoderType(template.decoder_block_type);
    setEncoderType(template.encoder_block_type);
    setNumEpochs(template.num_epochs);
    setBatchSize(template.batch_size);
    setTileSize(template.tile_size);
    setPatience(template.patience);
    setIsDataAug(template.data_augmentation);
  };

  // Applied directly (not via an effect) so re-picking the same template
  // re-applies it even after the user edited the fields.
  const handleTemplateSelect = (templateLabel: string) => {
    applyModelTemplate(templateLabel);
    setAnchorEl(null);
  };

  const handleSelectedDatasetNameChange = (newValue: string) => {
    if (newValue !== selectedDatasetName) {
      // A version picked for the previous dataset may not exist in the new one.
      setSelectedModelVersion("");
    }
    setSelectedDatasetName(newValue);
  };
  const handleSelectedModelVersionChange = (newValue: string) => {
    setSelectedModelVersion(newValue);
  };
  const handleTrainingTypeChange = (newValue: string) => {
    setTrainingType(newValue);
  };
  const handleLearningRateSchedulerChange = (newValue: string) => {
    setLearningRateScheduler(newValue);
  };
  const handleLearningRateChange = (newValue: number) => {
    setLearningRate(newValue);
  };
  const handleReduceLRChange = (newValue: boolean) => {
    setIsReduceLR(newValue);
  };
  const handleBaseModelArchitectureChange = (newValue: string) => {
    setBaseModelArchitecture(newValue);
  };
  const handleIsAttentionChange = (newValue: boolean) => {
    setIsAttention(newValue);
  };
  const handleAttentionTypeChange = (newValue: string) => {
    setAttentionType(newValue);
  };
  const handleDecoderTypeChange = (newValue: string) => {
    setDecoderType(newValue);
  };
  const handleEncoderTypeChange = (newValue: string) => {
    setEncoderType(newValue);
  };
  const handleNumEpochsChange = (newValue: number) => {
    setNumEpochs(newValue);
  };
  const handleBatchSizeChange = (newValue: number) => {
    setBatchSize(newValue);
  };
  const handleTileSizeChange = (newValue: number) => {
    setTileSize(newValue);
  };
  const handlePatienceChange = (newValue: number) => {
    setPatience(newValue);
  };
  const handleDataAugChange = (newValue: boolean) => {
    setIsDataAug(newValue);
  };
  const handleBaseModelNameChange = (newValue: string) => {
    setBaseModelName(newValue);
  };

  /** Renders the selected dataset's classes as read-only fields, or an info alert if there are none. */
  const renderClassFields = () => {
    if (classes.length === 0) {
      return <Alert severity='info'>No Classes in dataset</Alert>;
    }

    // The enclosing Box already scrolls, so each class needs no wrapper of its own.
    return classes.map((cls, index) => (
      <TextField
        key={index}
        value={cls}
        fullWidth
        size="small"
        sx={{ marginBottom: 1 }}
        InputProps={{ readOnly: true }}
      />
    ));
  };

  /** Validates the model name and records the new model for the selected dataset version. */
  const handleSubmit = async () => {
    // Ignore repeat clicks while saving, or after success while the redirect is pending.
    if (isLoadingRecordNewModel || isSuccessRecordNewModel) {
      return;
    }

    const modelName = baseModelName.trim();

    // Check for an empty model name first so a blank name is never reported as a duplicate
    if (!modelName) {
      alert("Your model name is empty!");
      return;
    }

    // Check for a duplicated model name across every version of every dataset
    const isDuplicateName = userDatasets.some((dataset) =>
      dataset.versions.some((version) =>
        version.models.some((model) => model.model_label === modelName)
      )
    );
    if (isDuplicateName) {
      alert("Model name already exists!");
      return;
    }

    // NOTE(review): learning rate, scheduler, Reduce-LR, training type, architecture
    // and attention/encoder/decoder settings are collected by the form but not
    // sent here — confirm this is intended.
    const modelParameters = {
      model_label: modelName,
      patience: patience,
      batchSize: batchSize,
      numEpochs: numEpochs,
      tileSize: tileSize,
      aug: isDataAug,
      isTrained: false,
    };

    await recordNewModel({
      userId: 1,
      datasetName: selectedDatasetName,
      versionName: selectedModelVersion,
      data: modelParameters,
    });
  };

  if (isLoadingDataset) {
    return (
      <Box
        sx={{
          display: 'flex',
          flexDirection: 'column',
          justifyContent: 'center',
          alignItems: 'center',
          height: '100vh',
          textAlign: 'center'
        }}
      >
        <CircularProgress />
        <Typography variant="h6" sx={{ marginTop: 2 }}>Loading data...</Typography>
      </Box>
    );
  }

  return (
    <Container>
      <Grid container spacing={2}>
        <Grid item xs={12}>
          <Typography variant='h3'>New Model Training Job</Typography>
        </Grid>
        <Grid item xs={12} sm={8} md={8} lg={9} sx={{display:'flex', alignItems: 'center'}}>
          <Typography variant='subtitle1'>Each job will take in a dataset to train on, hyperparameters to train with and outputs a trained model artifact</Typography>
        </Grid>
        <Grid item xs={6} sm={4} md={4} lg={3} style={{ display: 'flex', flexDirection: 'column', justifyContent: 'flex-end' }}>
          <Button size="large" variant="contained" onClick={(event) => setAnchorEl(event.currentTarget)}>Select Template</Button>
          <Popover
            open={Boolean(anchorEl)}
            anchorEl={anchorEl}
            onClose={() => setAnchorEl(null)}
            anchorOrigin={{
              vertical: 'bottom',
              horizontal: 'right',
            }}
            transformOrigin={{
              vertical: 'top',
              horizontal: 'right',
            }}
          >
            <MenuList>
              {Object.entries(templates).map(([categoryKey, category]) => (
                <NestedMenuItem key={categoryKey} label={category.label}>
                  {Object.entries(category.children).map(([templateKey, template]) => (
                    <Tooltip title={template.description} key={templateKey} disableInteractive arrow>
                      <MenuItem onClick={() => handleTemplateSelect(template.label)}>
                        <ListItemText>{template.label}</ListItemText>
                      </MenuItem>
                    </Tooltip>
                  ))}
                </NestedMenuItem>
              ))}
            </MenuList>
          </Popover>
        </Grid>
      </Grid>

      <Box mt={4} />
      <Divider />

      {loadDatasetError &&
        <Alert severity='error'>Could not load your datasets. Please refresh the page to try again.</Alert>
      }

      {isErrorRecordNewModel &&
        <Alert severity='error'>Error creating model... Please fill all required fields!</Alert>
      }

      <Box mt={4} />

      <Grid container spacing={3} >
        <Grid item xs={6}>
          <DropdownParameter
            hyperparam_name='Dataset Name'
            selectItems={userDatasets.map((dataset) => dataset.datasetName)}
            value={selectedDatasetName}
            onChange={handleSelectedDatasetNameChange}
            isRequired={true}
            helperText='Select a dataset to train on'
          />
        </Grid>

        <Grid item xs={6} >
          <DropdownParameter
            hyperparam_name='Version'
            selectItems={modelVersions.map((version) => version.versionName)}
            value={selectedModelVersion}
            onChange={handleSelectedModelVersionChange}
            isRequired={true}
            helperText='Select a version'
          />
        </Grid>

        <Grid item xs={6}>
          <DropdownParameter
            hyperparam_name='Training Type'
            selectItems={trainingTypeOptions}
            value={trainingType}
            onChange={handleTrainingTypeChange}
            isRequired={false}
            helperText='Select training type'
          />
        </Grid>

        <Grid item xs={6}>
          <Typography variant="subtitle1" gutterBottom>Classes in dataset</Typography>
          <Box
            sx={{
              maxHeight: '200px',
              overflowY: 'auto',
            }}
          >
            {renderClassFields()}
          </Box>
        </Grid>


        <Grid item xs={6} sm={5}>
          <DropdownParameter
            hyperparam_name='Learning Rate Scheduler'
            selectItems={learningRateSchedulerOptions}
            value={learningRateScheduler}
            onChange={handleLearningRateSchedulerChange}
            isRequired={false}
            helperText=''
          />
        </Grid>

        <Grid item xs={6} sm={3}>
          {/* step matches min so fractional rates such as 0.001 pass step validation.
              NOTE(review): IntegerParameter is used for a fractional value — confirm it does not round. */}
          <IntegerParameter
            hyperparam_name='Learning Rate'
            min={0.00001}
            max={10}
            step={0.00001}
            value={learningRate}
            isRequired={false}
            helperText=''
            onChange={handleLearningRateChange}
          />
        </Grid>

        <Grid item xs={6} sm={4}>
          <BooleanParameter
            hyperparam_name='Reduce LR on Plateau'
            value={isReduceLR}
            helperText='Adjust learning rate on training plateau'
            onChange={handleReduceLRChange}
          />
        </Grid>

        <Grid item xs={6} sm={4}>
          <DropdownParameter
            hyperparam_name='Base Model Architecture'
            selectItems={baseModelArchitectureOptions}
            value={baseModelArchitecture}
            onChange={handleBaseModelArchitectureChange}
            isRequired={false}
            helperText=''
          />
        </Grid>

        <Grid item xs={6} sm={2}>
          <BooleanParameter
            hyperparam_name='Use Attention'
            value={isAttention}
            helperText='Enhance focus on key features'
            onChange={handleIsAttentionChange}
          />
        </Grid>

        <Grid item xs={6} sm={2}>
          <DropdownParameter
            hyperparam_name='Attention Type'
            selectItems={attentionTypeOptions}
            value={attentionType}
            onChange={handleAttentionTypeChange}
            isRequired={false}
            helperText='Select attention type'
          />
        </Grid>
        <Grid item xs={6} sm={2}>
          <DropdownParameter
            hyperparam_name='Decoder Block Type'
            selectItems={decoderBlockTypeOptions}
            value={decoderType}
            onChange={handleDecoderTypeChange}
            isRequired={false}
            helperText='Select decoder block type'
          />
        </Grid>
        <Grid item xs={6} sm={2}>
          <DropdownParameter
            hyperparam_name='Encoder Block Type'
            selectItems={encoderBlockTypeOptions}
            value={encoderType}
            onChange={handleEncoderTypeChange}
            isRequired={false}
            helperText='Select encoder block type'
          />
        </Grid>

        {/* NOTE(review): the ranges quoted in the tooltips below (epochs 100-10000,
            batch size 2-128 in steps of 2, patience 10-20) do not match the
            min/max/step props (1-10000, step 1) — confirm which is intended. */}
        <Tooltip title={"Number of times the model will be trained. Ranging from 100 to 10000"}>
          <Grid item xs={6} sm={2}>
            <IntegerParameter
              hyperparam_name='Number of Epochs'
              min={1}
              max={10000}
              step={1}
              value={numEpochs}
              isRequired={true}
              helperText=''
              onChange={handleNumEpochsChange}
            />
          </Grid>
        </Tooltip>

        <Tooltip title={"Number of training samples per batch. Ranging from 2 to 128, in steps of 2"}>
          <Grid item xs={6} sm={3}>
            <IntegerParameter
              hyperparam_name='Batch Size'
              min={1}
              max={10000}
              step={1}
              value={batchSize}
              isRequired={true}
              helperText='Number of Training Samples'
              onChange={handleBatchSizeChange}
            />
          </Grid>
        </Tooltip>

        <Tooltip title={"Size of image tiles. Defaulted to be 224."}>
          <Grid item xs={6} sm={2}>
            <IntegerParameter
              hyperparam_name='Tile Size'
              min={1}
              max={10000}
              step={1}
              value={tileSize}
              isRequired={true}
              helperText='Size of image tiles'
              onChange={handleTileSizeChange}
            />
          </Grid>
        </Tooltip>

        <Tooltip title={"Number of epochs to wait before early stopping. Ranging from 10 to 20"}>
          <Grid item xs={6} sm={2}>
            <IntegerParameter
              hyperparam_name='Patience'
              min={1}
              max={10000}
              step={1}
              value={patience}
              isRequired={true}
              helperText='Number of Epochs'
              onChange={handlePatienceChange}
            />
          </Grid>
        </Tooltip>

        <Grid item xs={6} sm={3}>
          <BooleanParameter
            hyperparam_name='Data Augmentation'
            value={isDataAug}
            onChange={handleDataAugChange}
            helperText="Enhance training data with variations"
          />
        </Grid>

        <Grid item xs={12}>
          <StringParameter
            hyperparam_name=' Base Model Name'
            label='Model will be saved as...'
            value={baseModelName}
            isRequired={true}
            onChange={handleBaseModelNameChange}
          />
        </Grid>
      </Grid>

      <Box mt={4} />

      <Grid container justifyContent="center">
        <Button
          variant="contained"
          onClick={() => { void handleSubmit(); }}
          disabled={isLoadingRecordNewModel || isSuccessRecordNewModel}
        >
          Save job and proceed to select instance
        </Button>
      </Grid>

      <Modal
        aria-labelledby="transition-modal-title"
        aria-describedby="transition-modal-description"
        open={isModalOpen}
        onClose={() => setIsModalOpen(false)}
        sx={{
          '& .MuiBackdrop-root': {
            backgroundColor: 'rgba(0, 0, 0, 0.3)',
            opacity: '0.1 !important'
          },
          display: 'flex',
          alignItems: 'center',
          justifyContent: 'center',
        }}
      >
        <Box sx={{
            width: 400,
            bgcolor: 'background.paper',
            p: 2,
            borderRadius: 2
        }}>
          {isSuccessRecordNewModel && isRedirect ? (
            <Box
              display="flex"
              flexDirection="column"
              alignItems="center"
              justifyContent="center"
            >
              <CircularProgress />
              <Typography id="transition-modal-description" sx={{ mt: 2 }} variant="h6" component="h2">
                ...redirecting to select instance
              </Typography>
            </Box>
          ) : (
            <Box
              display="flex"
              flexDirection="column"
              alignItems="center"
              justifyContent="center"
            >
              <CheckCircleOutlineIcon color="success" fontSize="large" />
              <Typography id="transition-modal-title" variant="h6" component="h2">
                Successfully created Model
              </Typography>
            </Box>
          )}
        </Box>
      </Modal>

    </Container>
  );
}
