"""
Face recognition service using DeepFace (ArcFace) and pgvector for similarity search.
"""

import logging
import tempfile
import os

import cv2
import numpy as np
from deepface import DeepFace
from django.core.files.base import ContentFile
from pgvector.django import CosineDistance

from faces.models import FaceRecord
from .preprocessing import FacePreprocessor

logger = logging.getLogger(__name__)


class FaceRecognitionService:
    """
    Provides face registration and identification using ArcFace embeddings
    and pgvector cosine similarity search.

    Images are run through ``FacePreprocessor``, embedded with DeepFace's
    ArcFace model, and stored as ``FaceRecord`` rows whose ``embedding``
    column is compared with pgvector's ``CosineDistance``.
    """

    def __init__(self):
        self.preprocessor = FacePreprocessor()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_uint8(face: np.ndarray) -> np.ndarray:
        """
        Convert a float image with values in [0, 1] to uint8 in [0, 255].

        Values are clipped to [0, 1] first: casting an out-of-range float
        (e.g. 1.01 * 255 = 257.55) with ``astype(np.uint8)`` can wrap around
        (platform-dependent) instead of saturating at 255.
        """
        return (np.clip(face, 0.0, 1.0) * 255).astype(np.uint8)

    @staticmethod
    def _unknown_result(distance: float = float("inf")) -> dict:
        """Return the "no match" result dict used by ``identify_face``."""
        return {
            "matched": False,
            "name": "Unknown",
            "confidence": 0.0,
            "id": None,
            "distance": distance,
        }

    @staticmethod
    def _photo_filename(name: str, source_name=None) -> str:
        """
        Choose the base filename under which a registered photo is stored.

        Args:
            name: Person name; ``"<name>.jpg"`` is the fallback filename.
            source_name: The ``name`` attribute of an uploaded file-like
                object, if any. Only non-empty strings are used (file objects
                may expose ``None`` or an integer descriptor here).

        Returns:
            A filename with all directory components removed. File objects
            report the path they were opened with as ``name``, and Django's
            file storage may reject absolute paths or ``..`` segments, so only
            the final component is kept. Backslashes are treated as
            separators too.
        """
        if isinstance(source_name, str):
            base = os.path.basename(source_name.replace("\\", "/"))
            if base:
                return base
        return os.path.basename(f"{name}.jpg".replace("\\", "/"))

    def _build_photo_file(
        self, name: str, image_input, preprocessed: np.ndarray
    ) -> "ContentFile":
        """
        Build the ``ContentFile`` that is stored as the record's photo.

        File-like inputs (anything with ``read``) are stored as their original
        bytes; every other input is stored as a JPEG encoding of the
        preprocessed face.

        Raises:
            RuntimeError: If JPEG encoding of the preprocessed face fails.
        """
        if hasattr(image_input, "read"):
            # The preprocessor may already have read the stream; rewind so the
            # full upload is stored rather than whatever is left after it.
            image_input.seek(0)
            return ContentFile(
                image_input.read(),
                name=self._photo_filename(name, getattr(image_input, "name", None)),
            )

        face_bgr = cv2.cvtColor(self._to_uint8(preprocessed), cv2.COLOR_RGB2BGR)
        success, buffer = cv2.imencode(".jpg", face_bgr)
        if not success:
            raise RuntimeError("Failed to encode preprocessed face to JPEG.")
        return ContentFile(buffer.tobytes(), name=self._photo_filename(name))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract_embedding(self, preprocessed_face: np.ndarray) -> list:
        """
        Extract a 512-dimensional ArcFace embedding from a preprocessed face image.

        The face is written to a temporary JPEG (DeepFace is called with a
        file path) and embedded with detection skipped, since the input is
        already a cropped face. The temporary file is always removed.

        Args:
            preprocessed_face: Preprocessed face as float32 numpy array,
                               shape (160, 160, 3), values in [0, 1], RGB order.

        Returns:
            List of float values representing the face embedding (512 are
            expected; a different length is logged as a warning, not rejected).

        Raises:
            RuntimeError: If writing the temporary image or DeepFace
                embedding extraction fails.
        """
        try:
            temp_path = None
            try:
                fd, temp_path = tempfile.mkstemp(suffix=".jpg")
                # Close our handle; cv2 and DeepFace reopen the file by path.
                os.close(fd)

                # cv2 expects BGR channel order; the preprocessed face is RGB.
                face_bgr = cv2.cvtColor(
                    self._to_uint8(preprocessed_face), cv2.COLOR_RGB2BGR
                )
                if not cv2.imwrite(temp_path, face_bgr):
                    raise RuntimeError(
                        f"Failed to write temporary face image to {temp_path}."
                    )

                # Extract embedding using ArcFace with detector skipped
                results = DeepFace.represent(
                    img_path=temp_path,
                    model_name="ArcFace",
                    detector_backend="skip",
                    enforce_detection=False,
                )

                if not results:
                    raise RuntimeError("DeepFace returned no embeddings.")

                embedding = results[0]["embedding"]

                if len(embedding) != 512:
                    logger.warning(
                        "Expected 512-dim embedding, got %d-dim. Proceeding anyway.",
                        len(embedding),
                    )

                return embedding

            finally:
                if temp_path and os.path.exists(temp_path):
                    os.unlink(temp_path)

        except RuntimeError:
            raise
        except Exception as e:
            logger.exception("DeepFace embedding extraction failed: %s", str(e))
            raise RuntimeError(
                f"Failed to extract face embedding: {str(e)}"
            ) from e

    def register_face(
        self, name: str, image_input, metadata: dict = None
    ) -> "FaceRecord":
        """
        Register a new face:
        1. Preprocess image via FacePreprocessor
        2. Extract 512-dim ArcFace embedding
        3. Build the photo file (original upload bytes, or JPEG of the face)
        4. Save FaceRecord (and its photo) to the database / storage

        Args:
            name: Name to associate with the face.
            image_input: Image as file path, PIL Image, numpy array, or file-like.
            metadata: Optional metadata dict (defaults to an empty dict).

        Returns:
            The saved FaceRecord instance.

        Raises:
            ValueError: If no face is detected.
            RuntimeError: If embedding extraction or JPEG encoding fails.
            Any database error from saving the record is re-raised after the
            already-stored photo file has been removed.
        """
        if metadata is None:
            metadata = {}

        preprocessed = self.preprocessor.preprocess(image_input)
        embedding = self.extract_embedding(preprocessed)
        photo_file = self._build_photo_file(name, image_input, preprocessed)

        record = FaceRecord(
            name=name,
            embedding=embedding,
            metadata=metadata,
        )
        # save=False stores the file without saving the model; the row is
        # inserted by record.save() below.
        record.photo.save(photo_file.name, photo_file, save=False)
        try:
            record.save()
        except Exception:
            # The photo is already in storage; remove it so a failed insert
            # does not leave an orphaned file, then surface the original error.
            try:
                record.photo.delete(save=False)
            except Exception:
                logger.exception(
                    "Failed to remove photo of unsaved face record for '%s'.", name
                )
            raise

        logger.info(
            "Registered face for '%s' (id=%s) with %d-dim embedding.",
            name,
            record.id,
            len(embedding),
        )

        return record

    def identify_face(self, image_input, threshold: float = 0.4) -> dict:
        """
        Identify a face against registered embeddings using cosine similarity.

        Args:
            image_input: Image as file path, PIL Image, numpy array, or file-like.
            threshold: Maximum cosine distance for a match (default 0.4).

        Returns:
            Dict with keys: matched, name, confidence, id, distance.
            ``confidence`` is ``1 - distance`` for a match and 0.0 otherwise;
            ``distance`` is ``inf`` when there is no registered embedding.

        Raises:
            ValueError: If no face is detected.
            RuntimeError: If embedding extraction fails.
        """
        preprocessed = self.preprocessor.preprocess(image_input)
        query_embedding = self.extract_embedding(preprocessed)

        # One ordered query covers the "no registered faces" case as well:
        # .first() returns None on an empty result, so no COUNT(*) is needed.
        # Rows without an embedding are excluded (a NULL distance cannot be
        # converted with float()), and only the fields read below are loaded.
        nearest = (
            FaceRecord.objects.filter(embedding__isnull=False)
            .only("id", "name")
            .annotate(distance=CosineDistance("embedding", query_embedding))
            .order_by("distance")
            .first()
        )

        if nearest is None:
            return self._unknown_result()

        distance = float(nearest.distance)

        # Keep the `<=` form: a NaN distance (possible for degenerate
        # embeddings) compares False and therefore stays unmatched.
        if distance <= threshold:
            confidence = round(1.0 - distance, 4)
            logger.info(
                "Face identified as '%s' (id=%s) with confidence=%.4f, distance=%.4f",
                nearest.name,
                nearest.id,
                confidence,
                distance,
            )
            return {
                "matched": True,
                "name": nearest.name,
                "confidence": confidence,
                "id": nearest.id,
                "distance": round(distance, 4),
            }

        logger.info(
            "Face not matched. Nearest: '%s' (id=%s), distance=%.4f > threshold=%.4f",
            nearest.name,
            nearest.id,
            distance,
            threshold,
        )
        return self._unknown_result(round(distance, 4))

    def list_registered_faces(self) -> list:
        """
        Return all registered faces with id, name, created_at, metadata.
        Excludes embeddings for performance.

        Returns:
            List of dicts with face record info.
        """
        return list(
            FaceRecord.objects.values("id", "name", "created_at", "metadata")
        )

    def delete_face(self, face_id: int) -> bool:
        """
        Delete a registered face by ID, along with its stored photo.

        The database row is deleted first and the photo file afterwards, so a
        failed row deletion never leaves a record pointing at a missing file.
        Photo removal is best-effort: a storage error is logged, not raised.

        Args:
            face_id: The ID of the face record to delete.

        Returns:
            True if the record was deleted, False if not found.
        """
        try:
            record = FaceRecord.objects.get(id=face_id)
        except FaceRecord.DoesNotExist:
            logger.warning("Face record id=%s not found for deletion.", face_id)
            return False

        photo = record.photo
        record.delete()

        if photo:
            try:
                photo.delete(save=False)
            except Exception:
                logger.exception(
                    "Deleted face record id=%s but failed to remove its photo.",
                    face_id,
                )

        logger.info("Deleted face record id=%s ('%s').", face_id, record.name)
        return True
