"""
engine/slide_cloner.py
----------------------
Clones slides from any source PPTX (template or infographic bank)
into an output PPTX, preserving all shapes, images, and styling.

Core mechanism — everything else builds on top of this.

Usage:
    cloner = SlideCloner(output_path, template_path="Template new.pptx")
    cloner.clone_slide("Template new.pptx", 0)          # cover
    cloner.clone_slide("infographic_bank/Linear.pptx", 10)  # bank slide
    cloner.save()
"""

import copy
import re as _re
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile, ZIP_DEFLATED
from pptx import Presentation
from pptx.oxml.ns import qn
from pptx.opc.package import Part
from pptx.opc.packuri import PackURI

# OOXML relationship namespace, plus the two relationship types whose targets
# are image data: the standard image type and Microsoft's HD Photo type.
REL_NS      = "http://schemas.openxmlformats.org/officeDocument/2006/relationships"
REL_IMAGE   = f"{REL_NS}/image"
REL_HDPHOTO = "http://schemas.microsoft.com/office/2007/relationships/hdphoto"
_IMAGE_RELTYPES = {REL_IMAGE, REL_HDPHOTO}

# MIME content type -> file extension used when naming copied image parts.
_CT_EXT = {
    f"image/{subtype}": extension
    for subtype, extension in (
        ("png",           ".png"),
        ("jpeg",          ".jpeg"),
        ("jpg",           ".jpg"),
        ("gif",           ".gif"),
        ("bmp",           ".bmp"),
        ("tiff",          ".tiff"),
        ("x-wmf",         ".wmf"),
        ("x-emf",         ".emf"),
        ("svg+xml",       ".svg"),
        ("vnd.ms-photo",  ".wdp"),
    )
}


class SlideCloner:
    """
    Builds an output PPTX by cloning slides from source files.

    Starting from a copy of the template ensures the output inherits
    the correct slide master, layouts, theme, and fonts — so template
    slides look identical to the original and brand colors are available
    for infographic slides.

    Image parts are deduplicated across every clone by content hash, so a
    picture reused by many source slides is stored only once in the output.
    """

    def __init__(self, output_path: str, template_path: str):
        self.output_path = str(output_path)
        self.template_path = str(Path(template_path).resolve())
        # Strip template slides but keep its master/theme/fonts/colors so that
        # scheme colors (accent1, accent2, bg1 …) resolve correctly in the output.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        _strip_slides_to_path(str(template_path), str(output_path))
        self._prs = Presentation(str(output_path))
        self._source_cache:   dict[str, Presentation] = {}
        self._img_blob_cache: dict[str, Part]        = {}  # md5 → deduplicated Part
        self._img_counter:    int                    = 0   # monotonic counter for unique names
        # Lower-cased partnames already present in the output package; built
        # lazily by _next_media_partname so new media never reuses a name.
        self._used_partnames = None

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def clone_slide(self, source_path: str, slide_index: int):
        """
        Clone slide at slide_index from source_path into the output PPTX.

        Shapes, image relationships, external relationships (e.g. hyperlinks)
        and the slide background are carried over. Shapes that depend on any
        other relationship (SmartArt, charts, OLE, media …) are skipped so the
        output never contains dangling rIds.

        Returns the new slide object.
        """
        source_prs   = self._open(source_path)
        source_slide = source_prs.slides[slide_index]

        # Pick the layout:
        # • When cloning from the SAME template the output was initialised with,
        #   match the source slide's actual layout (preserves the template
        #   designer's intent for cover/about/content/etc.).
        # • When cloning from a DIFFERENT pptx (bank, drill_down_bank, gold
        #   standard), the source's layout belongs to that other file's master.
        #   Don't try to match — pick the OUTPUT's decorated blank layout so
        #   the cloned content inherits the active template's chrome (borders,
        #   watermark) from the decorated master.
        src_resolved = str(Path(source_path).resolve())
        if src_resolved == self.template_path:
            target_layout = self._matching_layout(source_slide) or self._blank_layout()
        else:
            target_layout = self._blank_layout()
        new_slide = self._prs.slides.add_slide(target_layout)

        # Clear the default placeholder shapes the chosen layout added
        for ph in list(new_slide.placeholders):
            ph._element.getparent().remove(ph._element)

        # Copy relationships first so shapes and background can be remapped.
        rId_map = self._copy_rels(source_slide, new_slide)
        self._append_shapes(source_slide, new_slide, rId_map, report_skips=True)

        # Copy background fill / image (solid color, gradient, picture)
        self._copy_background(source_slide, new_slide, rId_map)

        return new_slide

    def copy_elements_from(
        self,
        source_path: str,
        slide_index: int,
        target_slide,
        filter_fn=None,
    ) -> None:
        """
        Copy shapes from a source slide into target_slide's spTree.

        filter_fn: callable(elem) -> bool — if given, only matching elements are copied.
        Relationships are carried over exactly as in clone_slide (via
        _copy_rels), so each unique image blob is stored only once in the
        output ZIP. Shapes whose relationships cannot be carried over are
        skipped silently. The source slide's background is not copied.
        """
        src_slide = self._open(source_path).slides[slide_index]
        rId_map = self._copy_rels(src_slide, target_slide)
        self._append_shapes(src_slide, target_slide, rId_map, filter_fn=filter_fn)

    def save(self):
        """Write the output PPTX to disk and return its path."""
        Path(self.output_path).parent.mkdir(parents=True, exist_ok=True)
        self._prs.save(self.output_path)
        print(f"Saved: {self.output_path}")
        return self.output_path

    def slide_count(self) -> int:
        """Number of slides currently in the output presentation."""
        return len(self._prs.slides)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _open(self, path: str) -> Presentation:
        """Cache-aware presentation loader (keyed by resolved path)."""
        key = str(Path(path).resolve())
        if key not in self._source_cache:
            self._source_cache[key] = Presentation(str(path))
        return self._source_cache[key]

    def _all_layouts(self):
        """
        All slide layouts across all masters in the output presentation.
        prs.slide_layouts only returns master 0's layouts — multi-master
        templates require iterating every master.
        """
        return [
            layout
            for master in self._prs.slide_masters
            for layout in master.slide_layouts
        ]

    def _matching_layout(self, source_slide):
        """
        Find the output layout matching the source slide's layout (by name +
        master decoration). Returns None if no match. Preserves master
        decoration when cloning from a template with multiple masters.
        """
        try:
            src_layout = source_slide.slide_layout
            src_name   = src_layout.name
            src_master_shapes = len(list(src_layout.slide_master.shapes))
        except Exception:
            return None
        candidates = [l for l in self._all_layouts() if l.name == src_name]
        if not candidates:
            return None
        if len(candidates) == 1:
            return candidates[0]
        # Multi-master templates have duplicate names — pick by master shape count.
        def _score(layout):
            try:
                return -abs(len(list(layout.slide_master.shapes)) - src_master_shapes)
            except Exception:
                return -999
        return max(candidates, key=_score)

    def _blank_layout(self):
        """
        Return a "blank" layout whose slide master has the most decoration —
        templates often have multiple masters; the decorated one carries the
        brand borders/watermark that we want on every output slide.
        Falls back to the last layout when none is named "blank".
        """
        layouts = self._all_layouts()
        blank_layouts = [l for l in layouts if l.name.lower() == "blank"]
        if not blank_layouts:
            return layouts[-1]
        def _master_shape_count(layout):
            try:
                return len(list(layout.slide_master.shapes))
            except Exception:
                return 0
        return max(blank_layouts, key=_master_shape_count)

    def _append_shapes(
        self,
        source_slide,
        target_slide,
        rId_map: dict,
        filter_fn=None,
        report_skips: bool = False,
    ) -> None:
        """
        Deep-copy source_slide's shapes into target_slide, remapping rIds.

        Shapes rejected by filter_fn are ignored. A shape that references a
        source relationship missing from rId_map is skipped (and reported when
        report_skips is true) because its copy would carry a dangling rId and
        trigger PowerPoint's repair dialog. Copies are inserted before any
        p:extLst in the target tree so the spTree child order stays valid.
        """
        source_rids = set(source_slide.part.rels.keys())
        sp_tree = target_slide.shapes._spTree
        ext_lst_tag = qn("p:extLst")
        for elem in list(source_slide.shapes._spTree)[2:]:  # skip nvGrpSpPr & grpSpPr
            if elem.tag == ext_lst_tag:
                continue  # group-level extension data, not a shape
            if filter_fn is not None and not filter_fn(elem):
                continue
            if _has_unresolved_rids(elem, rId_map, source_rids):
                if report_skips:
                    print("    [cloner] skipped shape with unresolved rId (diagram/chart)")
                continue
            cloned = copy.deepcopy(elem)
            _remap_rids(cloned, rId_map)
            sp_tree.insert_element_before(cloned, "p:extLst")

    def _next_media_partname(self, content_type: str) -> PackURI:
        """
        Return an unused /ppt/media/imageNNNN<ext> partname for content_type.

        The extension comes from _CT_EXT (".bin" when unknown). Names already
        present in the output package — including media inherited from the
        template — are skipped (case-insensitively, as OPC partnames are), so
        the output ZIP never receives two entries with the same name.
        """
        if self._used_partnames is None:
            self._used_partnames = {
                str(part.partname).lower()
                for part in self._prs.part.package.iter_parts()
            }
        ext = _CT_EXT.get(content_type, ".bin")
        while True:
            self._img_counter += 1
            partname = PackURI(f"/ppt/media/image{self._img_counter:04d}{ext}")
            if partname.lower() not in self._used_partnames:
                self._used_partnames.add(partname.lower())
                return partname

    def _copy_rels(self, source_slide, new_slide) -> dict:
        """
        Recreate the source slide's image and external relationships on new_slide.

        Internal image relationships (standard and HD Photo) point at a Part in
        the output package: an identical blob (same MD5) reuses the Part already
        stored, otherwise a new Part gets a sequential name (image0001.png,
        image0002.png …) that does not clash with any existing partname, so
        source files that internally share names (image1.png …) never produce
        duplicate ZIP entries.

        External relationships (e.g. hyperlinks) are copied as URL references
        so shapes using them are not dropped. Other internal relationships
        (layout, notes, charts, SmartArt …) are not copied; callers skip any
        shape that references them.

        relate_to reuses an existing identical relationship, so calling this
        twice for the same slide pair returns the same rIds.

        Returns {source_rId: new_rId} for XML patching.
        """
        import hashlib

        rId_map = {}
        for rId, rel in source_slide.part.rels.items():
            if not rel.is_external and rel.reltype not in _IMAGE_RELTYPES:
                continue
            try:
                if rel.is_external:
                    rId_map[rId] = new_slide.part.relate_to(
                        rel.target_ref, rel.reltype, is_external=True
                    )
                    continue
                src_part = rel.target_part
                blob = src_part.blob
                key = hashlib.md5(blob).hexdigest()
                img_part = self._img_blob_cache.get(key)
                if img_part is None:
                    # New blob — create a Part under a guaranteed-unused name.
                    # Keyword args: python-pptx versions disagree on the
                    # positional order of `package` and `blob`.
                    img_part = Part(
                        self._next_media_partname(src_part.content_type),
                        src_part.content_type,
                        package=new_slide.part.package,
                        blob=blob,
                    )
                    self._img_blob_cache[key] = img_part
                rId_map[rId] = new_slide.part.relate_to(img_part, rel.reltype)
            except Exception as exc:
                # Best effort: the rId stays unmapped, so any shape using it
                # is skipped later instead of producing a dangling reference.
                print(f"    [cloner] could not copy relationship {rId}: {exc}")
        return rId_map

    def _copy_background(self, source_slide, new_slide, rId_map=None):
        """
        Copy the source slide's p:bg element (solid, gradient, or picture fill).

        Picture fills reference an image by rId, so the copy is patched through
        rId_map (built with _copy_rels when not supplied). If the background
        references a source relationship that could not be carried over, the
        destination background is left untouched rather than emitting a
        dangling rId.
        """
        src_cSld = source_slide._element.find(qn("p:cSld"))
        dst_cSld = new_slide._element.find(qn("p:cSld"))
        if src_cSld is None or dst_cSld is None:
            return
        src_bg = src_cSld.find(qn("p:bg"))
        if src_bg is None:
            return
        if rId_map is None:
            rId_map = self._copy_rels(source_slide, new_slide)
        source_rids = set(source_slide.part.rels.keys())
        if _has_unresolved_rids(src_bg, rId_map, source_rids):
            print("    [cloner] skipped background with unresolved rId")
            return
        # Remove existing background from destination
        dst_bg = dst_cSld.find(qn("p:bg"))
        if dst_bg is not None:
            dst_cSld.remove(dst_bg)
        new_bg = copy.deepcopy(src_bg)
        _remap_rids(new_bg, rId_map)
        # p:bg must be the first child of p:cSld (before p:spTree).
        dst_cSld.insert(0, new_bg)


# ------------------------------------------------------------------ #
# Module-level helpers
# ------------------------------------------------------------------ #

# Clark-notation prefix ("{namespace-uri}") that ElementTree/lxml put on every
# attribute name in the relationships namespace (r:embed, r:id, r:link, …).
_REL_PREFIX = "{" + REL_NS + "}"


def _has_unresolved_rids(element, rId_map: dict, source_rids: set) -> bool:
    """
    Report whether *element*'s subtree points at a source relationship that
    was not carried over.

    Every relationship-namespace attribute (r:embed, r:link, r:id, r:dm, r:lo,
    r:qs, r:cs, r:cd, …) on *element* and on all of its descendants is
    inspected. A value that names an rId of the source slide but has no entry
    in *rId_map* would become a dangling reference once the subtree is copied
    (SmartArt uses r:dm/r:lo/…, charts and OLE objects use r:id), so the
    caller should skip it. Values that are not source rIds are ignored.
    """
    for node in element.iter():
        for name, value in node.attrib.items():
            if not name.startswith(_REL_PREFIX):
                continue
            if value in source_rids and value not in rId_map:
                return True
    return False


def _remap_rids(element, rId_map: dict):
    """
    Rewrite relationship references throughout *element*'s subtree in place.

    Every r:* attribute on *element* or any descendant whose non-empty value
    is a key of *rId_map* is replaced by the mapped rId, so copied XML points
    at the new slide's relationships. Other attributes are left untouched.
    """
    for node in element.iter():
        rel_attrs = [name for name in node.attrib.keys() if name.startswith(_REL_PREFIX)]
        for name in rel_attrs:
            current = node.get(name)
            if current and current in rId_map:
                node.set(name, rId_map[current])


def _strip_slides_to_path(template_path: str, output_path: str) -> None:
    """
    Copy the PPTX at *template_path* to *output_path*, minus its slides.

    Slide parts (ppt/slides/slideN.xml) and their .rels files are dropped, and
    the three package files that list slides — ppt/presentation.xml (slide-ID
    list), ppt/_rels/presentation.xml.rels (slide relationships) and
    [Content_Types].xml (slide overrides) — are rewritten so nothing points at
    them. Every other entry (masters, layouts, theme, media, fonts) is copied
    unchanged, so scheme-color references (accent1, accent2, bg1 …) resolve
    exactly as in the template.

    The whole archive is assembled in memory before *output_path* is written.

    NOTE(review): other presentation.xml structures that may reference the
    removed slides (custom shows, section lists) are left as-is — confirm
    the templates in use don't contain them.
    """
    slide_part_re = _re.compile(r"^ppt/slides/slide\d+\.xml$")
    slide_rels_re = _re.compile(r"^ppt/slides/_rels/slide\d+\.xml\.rels$")
    slide_override_re = _re.compile(r'<Override[^>]+/ppt/slides/slide\d+\.xml[^>]*/>')
    slide_relationship_re = _re.compile(
        r'<Relationship[^>]+/officeDocument/2006/relationships/slide"[^>]*/>'
    )
    slide_id_list_re = _re.compile(r"<p:sldIdLst>.*?</p:sldIdLst>", _re.DOTALL)

    # Text rewrites for the package files that enumerate slides; all other
    # entries are copied byte-for-byte.
    rewriters = {
        "ppt/presentation.xml": lambda xml: slide_id_list_re.sub("<p:sldIdLst/>", xml),
        "ppt/_rels/presentation.xml.rels": lambda xml: slide_relationship_re.sub("", xml),
        "[Content_Types].xml": lambda xml: slide_override_re.sub("", xml),
    }

    archive = BytesIO()
    with ZipFile(template_path, "r") as source, ZipFile(archive, "w", ZIP_DEFLATED) as target:
        for entry in source.namelist():
            if slide_part_re.match(entry) or slide_rels_re.match(entry):
                continue  # the slide itself, or its relationship file
            payload = source.read(entry)
            rewrite = rewriters.get(entry)
            if rewrite is not None:
                payload = rewrite(payload.decode("utf-8")).encode("utf-8")
            target.writestr(entry, payload)

    # The target ZipFile is closed above, so its central directory is complete.
    Path(output_path).write_bytes(archive.getvalue())
