import os
import tempfile
from typing import Optional

import FreeCADGui as Gui
import FreeCAD as App

class AHB_Render:
    """FreeCAD GUI command that renders the current 3D view to a PNG image.

    The image is produced from two off-screen captures of the active view:

    * a "lines" pass (object edges black, faces green, background red), and
    * a "shapes" pass at twice the resolution, where every object is drawn in a
      distinct flat colour.

    Colour boundaries in the shapes pass become silhouette outlines. These are
    merged into the line art before the result is colourised and saved.
    """

    # Object types collected as leaves by flatten_objects_tree().
    _LEAF_TYPES = frozenset({
        'Part::Feature', 'Part::FeaturePython', 'PartDesign::Body',
        'PartDesign::CoordinateSystem', 'PartDesign::Line',
        'Part::Mirroring', 'Part::Cut',
    })
    # Container types whose Group is expanded by flatten_objects_tree().
    _CONTAINER_TYPES = frozenset({'App::Part', 'App::DocumentObjectGroup'})
    # Subset of leaf types that are actually drawn; other leaves are hidden.
    _RENDERED_TYPES = frozenset({
        'Part::Feature', 'Part::FeaturePython', 'PartDesign::Body',
        'Part::Mirroring', 'Part::Cut',
    })

    def GetResources(self):
        """Return the FreeCAD command resources (menu text, tooltip, icon)."""
        return {"MenuText": "Render",
                "ToolTip": "Render current 3D view to a PNG image",
                "Pixmap": ""
                }

    def IsActive(self):
        """The command is always enabled."""
        return True

    def flatten_objects_tree(self, obj_list):
        """Recursively expand variant links, links, parts and groups into a flat list.

        Only objects whose TypeId is in ``_LEAF_TYPES`` are collected. A collected
        leaf that has a ``Group`` is followed by its own expanded children. ``None``
        entries (e.g. an ``App::Link`` whose target is missing) are skipped.

        The result may contain an object more than once if it is reachable
        through several paths.
        """
        result = []
        for obj in obj_list:
            if obj is None:
                # broken link: nothing to render
                continue
            type_id = obj.TypeId
            if type_id == 'Part::FeaturePython' and hasattr(obj, 'LinkedObject'):  # variant link
                result.extend(self.flatten_objects_tree(obj.Group))
            elif type_id == 'App::Link':
                result.extend(self.flatten_objects_tree([obj.LinkedObject]))
            elif type_id in self._CONTAINER_TYPES:
                result.extend(self.flatten_objects_tree(obj.Group))
            elif type_id in self._LEAF_TYPES:
                result.append(obj)
                if hasattr(obj, 'Group'):
                    result.extend(self.flatten_objects_tree(obj.Group))
        return result

    def renderable_objects(self):
        """Return the flattened object tree of the active document."""
        doc = App.activeDocument()
        return self.flatten_objects_tree(doc.Objects)

    def should_render(self, obj):
        """True if *obj* is drawn in the render (other collected objects are hidden)."""
        return obj.TypeId in self._RENDERED_TYPES

    def should_mask(self, obj):
        """Hook for masking objects out of the render; currently masks nothing."""
        return False

    @staticmethod
    def _outline_colors(step=8):
        """Yield an endless sequence of distinct RGBA emissive colours (floats in [0, 1]).

        Each channel runs over multiples of *step* starting at *step* and staying
        below ``256 - step``. Red varies fastest, then green, then blue. The
        sequence wraps around once every combination has been used.
        """
        r = g = b = step
        while True:
            yield (r / 255.0, g / 255.0, b / 255.0, 0.0)
            r += step
            if r >= 256 - step:
                r = step
                g += step
                if g >= 256 - step:
                    g = step
                    b += step
                    if b >= 256 - step:
                        b = step

    def set_render_lines(self, line_color = (0.0,0.0,0.0,0.0), background_color = (1.0,1.0,1.0,0.0), mask_color=(1.0,1.0,1.0)):
        """Prepare the scene for the "lines" pass.

        Rendered objects are shown as 'Flat Lines' with *line_color* edges and faces
        lit only by an emissive *background_color*. Masked objects are shown
        'Shaded' with emissive *mask_color*. All other collected objects are hidden.
        """
        for obj in self.renderable_objects():
            if not self.should_render(obj):
                obj.ViewObject.Visibility = False
                continue
            masked = self.should_mask(obj)
            obj.ViewObject.LineColor = line_color
            obj.ViewObject.DisplayMode = 'Shaded' if masked else 'Flat Lines'
            # zero every lighting term except emission so faces come out as one flat colour
            obj.ViewObject.ShapeMaterial.AmbientColor = (0.0, 0.0, 0.0, 0.0)
            obj.ViewObject.ShapeMaterial.DiffuseColor = (0.0, 0.0, 0.0, 0.0)
            obj.ViewObject.ShapeMaterial.SpecularColor = (0.0, 0.0, 0.0, 0.0)
            obj.ViewObject.ShapeMaterial.EmissiveColor = mask_color if masked else background_color
            obj.ViewObject.Visibility = True

    def set_render_outlines(self):
        """Prepare the scene for the "shapes" pass used for outline detection.

        Each rendered object is drawn 'Shaded' with a distinct emissive colour from
        _outline_colors(), so object boundaries show up as colour changes. Masked
        objects are hidden. Objects that are not rendered are left untouched.
        """
        colors = self._outline_colors()
        for obj in self.renderable_objects():
            if not self.should_render(obj):
                continue
            masked = self.should_mask(obj)
            obj.ViewObject.DisplayMode = 'Shaded'
            obj.ViewObject.ShapeMaterial.AmbientColor = (0.0, 0.0, 0.0, 0.0)
            obj.ViewObject.ShapeMaterial.DiffuseColor = (0.0, 0.0, 0.0, 0.0)
            obj.ViewObject.ShapeMaterial.SpecularColor = (0.0, 0.0, 0.0, 0.0)
            obj.ViewObject.ShapeMaterial.EmissiveColor = next(colors)
            if masked:
                obj.ViewObject.Visibility = False

    def reset_display(self):
        """Restore rendered objects to a fixed default look ('Flat Lines', grey material).

        Visibility changes made by the render passes are not undone here.
        """
        for obj in self.renderable_objects():
            if self.should_render(obj):
                obj.ViewObject.LineColor = (0.0, 0.0, 0.0, 0.0)
                obj.ViewObject.DisplayMode = 'Flat Lines'
                obj.ViewObject.ShapeMaterial.AmbientColor = (0.3, 0.3, 0.3, 0.0)
                obj.ViewObject.ShapeMaterial.DiffuseColor = (1.0, 1.0, 1.0, 0.0)
                obj.ViewObject.ShapeMaterial.SpecularColor = (0.5, 0.5, 0.5, 0.0)
                obj.ViewObject.ShapeMaterial.EmissiveColor = (0.0, 0.0, 0.0, 0.0)

    @staticmethod
    def _detect_outlines(shapes):
        """Return an "L" image that is 0 where a pixel's colour differs from one of
        its 8 neighbours in *shapes*, and 255 elsewhere."""
        from PIL import ImageFilter

        outlines = None
        for x in range(3):
            for y in range(3):
                if x == 1 and y == 1:
                    continue
                # centre minus one neighbour, offset by 127: a band reads 127 only where both are equal
                kernel = [0, 0, 0, 0, 1, 0, 0, 0, 0]
                kernel[y * 3 + x] = -1
                partial = shapes.filter(ImageFilter.Kernel((3, 3), kernel, 1, 127))
                partial = partial.point(lambda p: 255 if p == 127 else 0)
                # greyscale is 255 only where every colour band matched
                partial = partial.convert("L")
                partial = partial.point(lambda p: 255 if p == 255 else 0)
                if outlines is None:
                    outlines = partial
                else:
                    # accumulate this direction's edge pixels (value 0)
                    outlines.paste(partial, None, partial.point(lambda p: 0 if p == 255 else 255))
        return outlines

    def render(self, resolution, filename: str, line_color=(0.0, 0.0, 0.0, 0.0), fill_color=(1.0, 1.0, 1.0, 0.0), mask_stages_below: Optional[int] = None):
        """Render the active 3D view to a transparent PNG with line art and outlines.

        :param resolution: (width, height) of the output image in pixels.
        :param filename: path of the RGBA PNG to write.
        :param line_color: colour (floats in [0, 1]) for lines and outlines; only
            the first three components are used.
        :param fill_color: colour (floats in [0, 1]) for everything else; only the
            first three components are used.
        :param mask_stages_below: currently unused; kept for API compatibility.
        """
        import shutil
        import time
        from PIL import Image

        render_start_time = time.perf_counter()

        # Private per-call directory instead of fixed names in the shared temp dir:
        # avoids clashes between renders and leaves no files behind.
        temp_dir = tempfile.mkdtemp(prefix="ahb_render_")
        try:
            temp_lines_file_name = os.path.join(temp_dir, "ahb_temp_lines.png")
            temp_shapes_file_name = os.path.join(temp_dir, "ahb_temp_shapes.png")

            # render lines in black, background in red, fill shapes in green
            # the green band contains the lines image, the red band contains the inverted alpha layer
            # masked objects are magenta, so the blue band is 255 exactly on them
            self.set_render_lines((0.0, 0.0, 0.0), (0.0, 1.0, 0.0), mask_color=(1.0, 0.0, 1.0))
            Gui.ActiveDocument.ActiveView.saveImage(temp_lines_file_name, resolution[0]+2, resolution[1]+2, "#ff0000")

            # one flat colour per object, at 2x resolution, on white
            self.set_render_outlines()
            Gui.ActiveDocument.ActiveView.saveImage(temp_shapes_file_name, (resolution[0]+2) * 2, (resolution[1]+2) * 2, "#ffffff")

            # split() and load() read all pixel data, so the files can be removed afterwards
            lines_bands = Image.open(temp_lines_file_name).split()
            shapes = Image.open(temp_shapes_file_name)
            shapes.load()
        finally:
            shutil.rmtree(temp_dir, ignore_errors=True)

        lines = lines_bands[1]
        alpha_band = lines_bands[0].point(lambda p: 255 - p)

        outlines_start_time = time.perf_counter()

        outlines = self._detect_outlines(shapes)

        # erase masked outlines
        masked_area = lines_bands[2].resize(outlines.size).point(lambda p: 255 if p == 255 else 0)
        outlines.paste(outlines.point(lambda p: 255), None, masked_area)

        # draw outline pixels (value 0) into the line art and alpha at 2x, then scale back down
        outline_mask = outlines.point(lambda p: 255 if p == 0 else 0)

        lines_fullres = lines.resize(outlines.size, Image.NEAREST)
        lines_fullres.paste(outlines, None, outline_mask)
        lines = lines_fullres.resize(lines.size, Image.BILINEAR)

        alpha_band_fullres = alpha_band.resize(outlines.size, Image.NEAREST)
        alpha_band_fullres.paste(outlines.point(lambda p: 255), None, outline_mask)
        alpha_band = alpha_band_fullres.resize(lines.size, Image.BILINEAR)

        outlines_end_time = time.perf_counter()

        # colorize: 0 in the line band -> line_color, 255 -> fill_color
        color_bands = [
            lines.point(lambda p, c=c: int(fill_color[c] * p + line_color[c] * (255.0 - p)))
            for c in range(3)
        ]
        result = Image.merge("RGBA", color_bands + [alpha_band])

        # crop 1px borders
        result = result.crop((1, 1, result.size[0] - 1, result.size[1] - 1))

        result.save(filename)

        print("Rendered " + filename + " in " + str(
            round((outlines_end_time - render_start_time) * 1000) / 1000) + "s (outlines detection in " + str(
            round((outlines_end_time - outlines_start_time) * 1000) / 1000) + "s)")

    def Activated(self):
        """Render the active document's view to ``<document path without extension>.png``.

        The transparent render is flattened onto a white background. The display
        settings are reset afterwards, even if rendering fails.

        :raises RuntimeError: if there is no active document or it was never saved.
        """
        from PIL import Image

        doc = App.activeDocument()
        if doc is None:
            raise RuntimeError("No active document to render")
        doc_file_name: str = doc.FileName
        # `not` covers both None and the empty string a never-saved document reports
        if not doc_file_name:
            raise RuntimeError("You must save your FreeCAD document before rendering images")

        Image.MAX_IMAGE_PIXELS = 9999999999 # allow very high resolution images

        Gui.Selection.clearSelection()
        Gui.activeDocument().activeView().setCameraType("Orthographic")

        filename = os.path.splitext(doc_file_name)[0] + ".png"

        resolution = (6000, 6000)
        try:
            self.render(resolution, filename)
        finally:
            # never leave the document in the render colours
            self.reset_display()

        img_full = Image.new('RGB', resolution, (255, 255, 255))
        with Image.open(filename) as img:
            img_full.paste(img, None, img.getchannel('A'))
        img_full.save(filename)

# Register a single AHB_Render instance through the project's command wrapper
# under the name 'AHB_render' (presumably the FreeCAD GUI command id — confirm in ahb_command).
from ahb_command import AHB_CommandWrapper
AHB_CommandWrapper.addGuiCommand('AHB_render', AHB_Render())