"""
slicerender.py

Autor: Mahesh Venkitachalam

Ten moduł ma klasy i metody związane renderowaniem wycinków X Y Z 
zbioru danych wolumetrycznych.
"""

import OpenGL
from OpenGL.GL import *
from OpenGL.GL.shaders import *
import numpy, math, sys 

import volreader, glutils

strVS = """
# version 330 core

in vec3 aVert;

uniform mat4 uMVMatrix;
uniform mat4 uPMatrix;

uniform float uSliceFrac;
uniform int uSliceMode;

out vec3 texcoord;

void main() {

  // wycinek x
  if (uSliceMode == 0) {
    texcoord = vec3(uSliceFrac, aVert.x, 1.0-aVert.y);
  }
  // wycinek y
  else if (uSliceMode == 1) {
    texcoord = vec3(aVert.x, uSliceFrac, 1.0-aVert.y);
  }
  // wycinek z
  else {
    texcoord = vec3(aVert.x, 1.0-aVert.y, uSliceFrac);
  }

  // obliczanie przetransformowanego wierzchołka
  gl_Position = uPMatrix * uMVMatrix * vec4(aVert, 1.0); 
}
"""
strFS = """
# version 330 core

in vec3 texcoord;

uniform sampler3D tex;

out vec4 fragColor;

void main() {
  // wyszukanie koloru w teksturze
  vec4 col = texture(tex, texcoord);
  fragColor = col.rrra;
}

"""

class SliceRender:
    # tryby wycinków
    XSLICE, YSLICE, ZSLICE = 0, 1, 2

    def __init__(self, width, height, volume):
        """Konstruktor klasy SliceRender"""
        self.width = width
        self.height = height
        self.aspect = width/float(height)

        # tryb wycinka
        self.mode = SliceRender.ZSLICE

        # utworzenie shadera
        self.program = glutils.loadShaders(strVS, strFS)

        glUseProgram(self.program)

        self.pMatrixUniform = glGetUniformLocation(self.program, b'uPMatrix')
        self.mvMatrixUniform = glGetUniformLocation(self.program, 
                                                  b"uMVMatrix")

        # atrybuty
        self.vertIndex = glGetAttribLocation(self.program, b"aVert")
 
        # konfigurowanie obiektu tablicy wierzchołków (VAO)
        self.vao = glGenVertexArrays(1)
        glBindVertexArray(self.vao)

        # definiowanie wierzchołków kwadratu 
        vertexData = numpy.array([ 0.0, 1.0, 0.0, 
                                   0.0, 0.0, 0.0, 
                                   1.0, 1.0, 0.0,
                                   1.0, 0.0, 0.0], numpy.float32)
        # bufor wierzchołków
        self.vertexBuffer = glGenBuffers(1)
        glBindBuffer(GL_ARRAY_BUFFER, self.vertexBuffer)
        glBufferData(GL_ARRAY_BUFFER, 4*len(vertexData), vertexData, 
                     GL_STATIC_DRAW)
        # włączenie tablic
        glEnableVertexAttribArray(self.vertIndex)
        # ustawienie buforów
        glBindBuffer(GL_ARRAY_BUFFER, self.vertexBuffer)
        glVertexAttribPointer(self.vertIndex, 3, GL_FLOAT, GL_FALSE, 0, None)

        # usunięcie dowiązania do VAO
        glBindVertexArray(0)

        # ładowanie tekstury
        self.texture, self.Nx, self.Ny, self.Nz = volume

        # indeks bieżącego wycinka
        self.currSliceIndex = int(self.Nz/2);
        self.currSliceMax = self.Nz;


    def reshape(self, width, height):
        self.width = width
        self.height = height
        self.aspect = width/float(height)
        
    def draw(self):
        # czyszczenie buforów
        glClear(GL_COLOR_BUFFER_BIT | GL_DEPTH_BUFFER_BIT)
        # budowanie macierzy rzutowania
        pMatrix = glutils.ortho(-0.6, 0.6, -0.6, 0.6, 0.1, 100.0)
        # macierz widoku modelu
        mvMatrix = numpy.array([1.0, 0.0, 0.0, 0.0, 
                                0.0, 1.0, 0.0, 0.0, 
                                0.0, 0.0, 1.0, 0.0, 
                                -0.5, -0.5, -1.0, 1.0], numpy.float32)        
        # użycie shadera
        glUseProgram(self.program)
        
        # ustawienie macierzy rzutowania
        glUniformMatrix4fv(self.pMatrixUniform, 1, GL_FALSE, pMatrix)

        # ustawienie macierzy widoku modelu
        glUniformMatrix4fv(self.mvMatrixUniform, 1, GL_FALSE, mvMatrix)

        # ustawienie ułamka bieżącego wycinka 
        glUniform1f(glGetUniformLocation(self.program, b"uSliceFrac"), 
                    float(self.currSliceIndex)/float(self.currSliceMax))
        # ustawienie trybu bieżącego wycinka
        glUniform1i(glGetUniformLocation(self.program, b"uSliceMode"), 
                    self.mode)
        
        # włączenie tekstury
        glActiveTexture(GL_TEXTURE0)
        glBindTexture(GL_TEXTURE_3D, self.texture)
        glUniform1i(glGetUniformLocation(self.program, b"tex"), 0)

        # wiązanie VAO
        glBindVertexArray(self.vao)
        # rysowanie
        glDrawArrays(GL_TRIANGLE_STRIP, 0, 4)
        # usunięcie dowiązania do VAO
        glBindVertexArray(0)

    def keyPressed(self, key):
        """Procedura obsługi zdarzeń klawiatury"""
        if key == 'x':
            self.mode = SliceRender.XSLICE
            # resetowanie indeksu wycinków
            self.currSliceIndex = int(self.Nx/2)
            self.currSliceMax = self.Nx
        elif key == 'y':
            self.mode = SliceRender.YSLICE
            # resetowanie indeksu wycinków
            self.currSliceIndex = int(self.Ny/2)
            self.currSliceMax = self.Ny
        elif key == 'z':
            self.mode = SliceRender.ZSLICE
            # resetowanie indeksu wycinków
            self.currSliceIndex = int(self.Nz/2)
            self.currSliceMax = self.Nz
        elif key == 'l':
            self.currSliceIndex = (self.currSliceIndex + 1) % self.currSliceMax
        elif key == 'r':
            self.currSliceIndex = (self.currSliceIndex - 1) % self.currSliceMax
            
    def close(self):
        pass
