import { useRef, useState, useEffect } from "react"
import { Grid } from "@mui/material";
import ClassIndicator from "./ClassIndicator";

/** Props accepted by the `ImageOverlay` component. */
type ImageOverlayProps = {
  /** Image rendered underneath the overlay. */
  baseImage: File;
  /**
   * Label image drawn over `baseImage` at 50% opacity; its opaque black and
   * opaque white pixels are recolored (red / blue) by the component.
   */
  labelImage: File;
  /** Class values forwarded unchanged to `ClassIndicator`. */
  classes: number[];
};


/**
 * Recolors a decoded label image in place so its two classes are easy to tell
 * apart when overlaid: opaque black pixels (class 0) become red and opaque
 * white pixels (class 255) become blue. All other pixels are left untouched.
 *
 * @param data RGBA pixel buffer, e.g. `ImageData.data` from `getImageData`.
 */
function recolorLabelPixels(data: Uint8ClampedArray): void {
  for (let i = 0; i < data.length; i += 4) {
    const r = data[i];
    const g = data[i + 1];
    const b = data[i + 2];
    const a = data[i + 3];
    if (a !== 255) continue; // only fully opaque pixels are recolored

    if (r === 0 && g === 0 && b === 0) {
      // class 0 -> red
      data[i] = 255;
      data[i + 1] = 0;
      data[i + 2] = 0;
    } else if (r === 255 && g === 255 && b === 255) {
      // class 255 -> blue
      data[i] = 0;
      data[i + 1] = 0;
      data[i + 2] = 255;
    }
  }
}

/**
 * Shows `baseImage` with `labelImage` drawn on top of it at 50% opacity.
 * The label's class-0 / class-255 pixels are recolored red / blue for contrast.
 * Once both images are present, a `ClassIndicator` for `classes` is shown above them.
 */
function ImageOverlay({ baseImage, labelImage, classes }: ImageOverlayProps) {
  const canvasRef = useRef<HTMLCanvasElement>(null);

  // Object URL for the base image. It is created and revoked inside an effect,
  // so re-renders do not leak a fresh blob URL each time (previously one was
  // created during every render and never released).
  const [baseImageUrl, setBaseImageUrl] = useState<string | null>(null);

  useEffect(() => {
    if (!baseImage) {
      setBaseImageUrl(null);
      return;
    }
    const url = URL.createObjectURL(baseImage);
    setBaseImageUrl(url);
    return () => URL.revokeObjectURL(url);
  }, [baseImage]);

  useEffect(() => {
    const canvas = canvasRef.current;
    if (!labelImage || !canvas) return;

    const ctx = canvas.getContext('2d');
    if (!ctx) {
      console.error('Failed to get canvas context');
      return;
    }

    const labelUrl = URL.createObjectURL(labelImage);
    const img = new Image();
    // Guards against a stale load painting over a newer label image when the
    // props change before the previous image has finished decoding.
    let cancelled = false;

    img.onload = () => {
      if (cancelled) return;
      canvas.width = img.width;
      canvas.height = img.height;
      ctx.drawImage(img, 0, 0);
      const imageData = ctx.getImageData(0, 0, img.width, img.height);
      recolorLabelPixels(imageData.data);
      ctx.putImageData(imageData, 0, 0);
    };
    img.onerror = () => {
      if (!cancelled) console.error('Failed to load label image');
    };
    img.src = labelUrl;

    return () => {
      cancelled = true;
      img.onload = null;
      img.onerror = null;
      URL.revokeObjectURL(labelUrl);
    };
    // baseImage is listed because the <canvas> is only mounted when both images
    // are set. The effect must re-run once that happens so canvasRef is populated.
  }, [labelImage, baseImage]);

  return (
    <Grid container>
      <Grid item xs={12} sx={{display: 'flex', flexDirection: 'column', justifyContent: 'center', alignItems: 'center',}}>
        {(labelImage && baseImage) && <ClassIndicator classes={classes} />}
      </Grid>
      <Grid item xs={12} sx={{display: 'flex', flexDirection: 'column', justifyContent: 'center', alignItems: 'center',}}>
        <div style={{ position: 'relative', width: '35%'}}>
          {baseImageUrl && <img src={baseImageUrl} alt="Base" style={{ width: '100%', display: 'block' }} />}
          {(labelImage && baseImage) && <canvas ref={canvasRef} style={{ width: '100%', position: 'absolute', top: 0, left: 0, opacity: 0.5 }} />}
        </div>
      </Grid>

    </Grid>
  );
}

export default ImageOverlay;
