"""
Face preprocessing pipeline using OpenCV.
Handles face detection, alignment, cropping, and normalization.
"""

import logging

import cv2
import numpy as np
from PIL import ExifTags, Image, ImageOps

# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)


class FacePreprocessor:
    """
    Preprocesses images for face recognition.

    Pipeline:
    1. Load input and convert to RGB (path, PIL Image, numpy array, or
       file-like object are accepted)
    2. Apply EXIF rotation correction (PIL / file-like inputs)
    3. Detect faces with a Haarcascade classifier
    4. Crop the largest detected face with padding
    5. Align the crop using eye landmarks
    6. Resize to 160x160
    7. Apply CLAHE histogram equalization on the LAB L channel
    8. Normalize pixel values to [0, 1] float32
    """

    def __init__(self):
        # The Haarcascade XML models ship with OpenCV; cv2.data.haarcascades
        # is the install-time directory containing them.
        self.face_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + "haarcascade_frontalface_default.xml"
        )
        self.eye_cascade = cv2.CascadeClassifier(
            cv2.data.haarcascades + "haarcascade_eye.xml"
        )

    def _load_image(self, image_input) -> np.ndarray:
        """
        Load image from various input types and return as RGB numpy array.

        Supports: file path (str), PIL Image, numpy array (grayscale,
        RGB, or RGBA), Django UploadedFile / any file-like object.

        Raises:
            ValueError: if the path cannot be read, the array shape is
                unsupported, or the input type is unrecognized.
        """
        if isinstance(image_input, str):
            # File path — cv2.imread returns BGR, or None on failure.
            img = cv2.imread(image_input)
            if img is None:
                raise ValueError(f"Could not load image from path: {image_input}")
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if isinstance(image_input, Image.Image):
            # PIL Image — honor the EXIF orientation before converting.
            img = self._apply_exif_rotation(image_input)
            return np.array(img.convert("RGB"))

        if isinstance(image_input, np.ndarray):
            if image_input.ndim == 2:
                # Grayscale → replicate into three channels.
                return cv2.cvtColor(image_input, cv2.COLOR_GRAY2RGB)
            if image_input.ndim == 3 and image_input.shape[2] == 4:
                # Four channels are treated as RGBA for consistency with
                # the 3-channel case below, which passes arrays through as
                # RGB. (Was COLOR_BGRA2RGB, which silently swapped the R
                # and B channels relative to that contract.)
                return cv2.cvtColor(image_input, cv2.COLOR_RGBA2RGB)
            if image_input.ndim == 3 and image_input.shape[2] == 3:
                # Assumed to already be RGB; copy so the caller's array is
                # never mutated by downstream steps.
                return image_input.copy()
            raise ValueError(f"Unsupported image shape: {image_input.shape}")

        if hasattr(image_input, "read"):
            # Django UploadedFile or any file-like object. Rewind first in
            # case a previous consumer advanced the stream.
            image_input.seek(0)
            pil_img = Image.open(image_input)
            pil_img = self._apply_exif_rotation(pil_img)
            return np.array(pil_img.convert("RGB"))

        raise ValueError(
            f"Unsupported image input type: {type(image_input).__name__}. "
            "Expected str (path), PIL Image, numpy array, or file-like object."
        )

    def _apply_exif_rotation(self, pil_image: "Image.Image") -> "Image.Image":
        """
        Apply the EXIF Orientation tag (if present) to a PIL Image.

        Delegates to PIL's official ``ImageOps.exif_transpose``, which
        handles all eight orientation values. (The previous hand-rolled
        table relied on the deprecated private ``Image._getexif()`` API.)
        """
        try:
            return ImageOps.exif_transpose(pil_image)
        except Exception:
            # Corrupt or malformed EXIF data must never abort
            # preprocessing; best-effort fall back to the image as-is.
            logger.debug("EXIF orientation handling failed; using image as-is.")
            return pil_image

    def _detect_faces(self, img: np.ndarray) -> np.ndarray:
        """
        Detect faces using Haarcascade and return (x, y, w, h) boxes.

        Raises:
            ValueError: when no face is found.
        """
        # Haarcascades operate on single-channel images.
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        faces = self.face_cascade.detectMultiScale(
            gray,
            scaleFactor=1.1,
            minNeighbors=5,
            minSize=(30, 30),
            flags=cv2.CASCADE_SCALE_IMAGE,
        )

        if len(faces) == 0:
            raise ValueError(
                "No face detected in image. Please ensure the face is clearly "
                "visible, well-lit, and facing the camera."
            )

        return faces

    def _get_largest_face(self, faces) -> tuple:
        """Select the largest detected face box (x, y, w, h) by area."""
        if len(faces) == 1:
            return tuple(faces[0])
        # max() keeps the first of equal-area boxes, matching np.argmax.
        return tuple(max(faces, key=lambda box: box[2] * box[3]))

    def _crop_face(self, img: np.ndarray, face_box: tuple, padding: int = 20) -> np.ndarray:
        """
        Crop the face region plus ``padding`` pixels on each side,
        clamped to the image bounds so the slice never goes out of range.
        """
        x, y, w, h = face_box
        height, width = img.shape[:2]

        x1 = max(0, x - padding)
        y1 = max(0, y - padding)
        x2 = min(width, x + w + padding)
        y2 = min(height, y + h + padding)

        return img[y1:y2, x1:x2]

    def _align_face(self, face_crop: np.ndarray) -> np.ndarray:
        """
        Rotate the face crop so the eyes lie on a horizontal line.

        If at least two eye candidates are detected, the two largest
        detections are taken as the eye pair (small spurious detections
        are ignored), the inter-eye angle is computed, and the crop is
        rotated about its center. If fewer than two eyes are found, or
        the angle is implausibly large (suggesting misdetection), the
        crop is returned unchanged.
        """
        gray = cv2.cvtColor(face_crop, cv2.COLOR_RGB2GRAY)
        eyes = self.eye_cascade.detectMultiScale(
            gray,
            scaleFactor=1.1,
            minNeighbors=10,
            minSize=(15, 15),
        )

        if len(eyes) >= 2:
            # Keep the two largest detections, then order them left/right.
            # (Previously the two left-most detections were paired, which
            # could match a real eye with a false positive on the same
            # side of the face.)
            best_two = sorted(eyes, key=lambda e: e[2] * e[3], reverse=True)[:2]
            left_eye, right_eye = sorted(best_two, key=lambda e: e[0])

            # Center of each eye box.
            left_center = (
                left_eye[0] + left_eye[2] // 2,
                left_eye[1] + left_eye[3] // 2,
            )
            right_center = (
                right_eye[0] + right_eye[2] // 2,
                right_eye[1] + right_eye[3] // 2,
            )

            # Angle of the eye line relative to horizontal.
            dy = right_center[1] - left_center[1]
            dx = right_center[0] - left_center[0]
            angle = np.degrees(np.arctan2(dy, dx))

            # A genuine eye pair on a roughly upright face yields a small
            # angle; a large one means the detections are unreliable.
            if abs(angle) > 45:
                logger.debug(
                    "Implausible eye angle %.1f deg; skipping alignment.", angle
                )
                return face_crop

            h, w = face_crop.shape[:2]
            center = (w // 2, h // 2)

            # Positive angle rotates counter-clockwise, leveling the eyes.
            rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale=1.0)
            return cv2.warpAffine(
                face_crop,
                rotation_matrix,
                (w, h),
                flags=cv2.INTER_CUBIC,
                borderMode=cv2.BORDER_REPLICATE,
            )

        # If eyes not reliably detected, return the crop as-is
        logger.debug("Could not detect 2 eyes for alignment, using unaligned crop.")
        return face_crop

    def _apply_histogram_equalization(self, img: np.ndarray) -> np.ndarray:
        """
        Normalize lighting with CLAHE on the L (lightness) channel in
        LAB color space, leaving a/b untouched so hues do not shift.
        """
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        l_channel, a_channel, b_channel = cv2.split(lab)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        l_equalized = clahe.apply(l_channel)

        lab_equalized = cv2.merge([l_equalized, a_channel, b_channel])
        return cv2.cvtColor(lab_equalized, cv2.COLOR_LAB2RGB)

    def preprocess(self, image_input) -> np.ndarray:
        """
        Run the full preprocessing pipeline.

        Args:
            image_input: str path, PIL Image, numpy array, or file-like
                object (e.g. Django UploadedFile).

        Returns:
            float32 numpy array of shape (160, 160, 3), values in [0, 1].

        Raises:
            ValueError: if the input cannot be loaded or no face is
                detected in the image.
        """
        # 1. Load image. _load_image always returns a 3-channel RGB
        # array, so no further channel normalization is needed here.
        img = self._load_image(image_input)

        # 2. Detect faces
        faces = self._detect_faces(img)

        # 3. Select the largest face
        largest_face = self._get_largest_face(faces)

        # 4. Crop with padding
        face_crop = self._crop_face(img, largest_face, padding=20)

        # 5. Align using eyes
        aligned_face = self._align_face(face_crop)

        # 6. Resize to the input size expected by the downstream
        # embedding model.
        resized = cv2.resize(aligned_face, (160, 160), interpolation=cv2.INTER_CUBIC)

        # 7. Histogram equalization on L channel
        equalized = self._apply_histogram_equalization(resized)

        # 8. Normalize to [0, 1] float32
        normalized = equalized.astype(np.float32) / 255.0

        logger.info(
            "Preprocessing complete: face detected, cropped, aligned, "
            "resized to 160x160, equalized, normalized."
        )

        return normalized
