﻿import math

def norm_zeroone(a,channel):
    """Scale ``a`` linearly so that channel.min maps to 0 and channel.max maps to 1."""
    lo, hi = channel.min, channel.max
    return (a - lo) / (hi - lo)

def denorm_zeroone(a,channel):
    """Inverse of norm_zeroone: map 0 back to channel.min and 1 back to channel.max."""
    lo = channel.min
    span = channel.max - lo
    return a * span + lo

def norm_invzeroone(a,channel):
    """Reversed 0..1 scaling: channel.max maps to 0 and channel.min maps to 1.

    The divisor is the absolute width of the range, so if min > max the
    value at channel.min comes out as -1 instead of 1.
    """
    top = channel.max
    width = abs(channel.min - top)
    return -(a - top) / width

def denorm_invzeroone(a,channel):
    """Inverse of norm_invzeroone: 0 maps back to channel.max, 1 towards channel.min."""
    top = channel.max
    width = abs(channel.min - top)
    return -(a * width) + top

def norm_oneone(a,channel):
    """Scale ``a`` linearly so that channel.min maps to -1 and channel.max maps to 1."""
    lo = channel.min
    unit = (a - lo) / (channel.max - lo)
    return (unit * 2) - 1

def denorm_oneone(a, channel):
    """Inverse of norm_oneone: map -1 back to channel.min and 1 back to channel.max.

    Uses plain (non-augmented) arithmetic so a mutable argument, e.g. a
    NumPy array, is never modified in place. Like the other normalisation
    functions, it returns a new value, and integer arrays are accepted
    (an in-place ``/=`` on an int array cannot store the float result).
    """
    fraction = (a + 1) / 2
    return (fraction * (channel.max - channel.min)) + channel.min

def norm_tan(a,channel):
    """Apply ``math.tan`` to the 0..1-scaled value of ``a``.

    For ``a`` within [channel.min, channel.max] the result lies in [0, tan(1)].
    """
    lo = channel.min
    scaled = (a - lo) / (channel.max - lo)
    return math.tan(scaled)

def denorm_tan(a,channel):
    """Inverse of norm_tan: undo the tangent with ``math.atan``, then rescale to the channel range."""
    lo = channel.min
    angle = math.atan(a)
    return (angle * (channel.max - lo)) + lo

def norm_dig(a,channel):
    """Digital encoding: 1 when ``a`` lies strictly within 0.5 of channel.target, else 0."""
    centre = channel.target
    return 1 if centre - 0.5 < a < centre + 0.5 else 0

def denorm_dig(a,channel):
    """Inverse of norm_dig: an output exactly equal to 1 becomes channel.target; anything else gives 0."""
    return channel.target if a == 1 else 0

def InvertFunction(func):
    """Return the denormalising counterpart of a normalising function.

    Any function without a known inverse, including None, yields None.
    """
    pairs = (
        (norm_zeroone, denorm_zeroone),
        (norm_invzeroone, denorm_invzeroone),
        (norm_oneone, denorm_oneone),
        (norm_tan, denorm_tan),
        (norm_dig, denorm_dig),
    )
    for forward, inverse in pairs:
        if func == forward:
            return inverse
    return None

class IChannel:
    """A named raw input channel together with its expected value bounds.

    ``index`` starts at 0; Normaliser.AddInput overwrites it with the
    channel's position in the normaliser's input list.
    """

    def __init__(self,name,min,max):
        self.index = 0
        self.min = min
        self.max = max
        self.name = name

class OChannel:
    """One output slot derived from an input channel.

    Attributes:
      input   -- the source input channel object
      min/max -- bounds handed to ``func``; initially copied from ``input``
      func    -- normalising function ``f(value, channel)``, or None for pass-through
      invfunc -- inverse of ``func`` as found by InvertFunction (None if unknown)
      target  -- digital target value read by norm_dig / denorm_dig
      index   -- position in the output list (assigned by the normaliser)
    """

    def __init__(self,inp,func=None):
        self.min, self.max = inp.min, inp.max
        self.input = inp
        self.func = func
        self.invfunc = InvertFunction(func)
        self.target = None
        self.index = 0

    def __repr__(self):
        return "".join(("Output Channel ", self.input.name, " - ", str(self.func)))

    def GetInverse(self):
        """Recompute ``invfunc`` from the current ``func``."""
        self.invfunc = InvertFunction(self.func)

    def ToString(self):
        """Serialise this channel as ``name>kind[=extra]``.

        ``kind`` is ``*`` for pass-through channels, otherwise a label for
        ``func``. Digital channels append ``=target`` when a target is set;
        other non-pass-through channels append ``=min|max``.
        """
        labels = (
            (norm_zeroone, "0-1"),
            (norm_oneone, "-1-1"),
            (norm_invzeroone, "-1-0"),
            (norm_dig, "D"),
            (norm_tan, "tan"),
        )

        text = str(self.input.name) + ">"
        if self.func is None:
            return text + "*"

        for fn, label in labels:
            if self.func == fn:
                text += label
                break

        if self.func == norm_dig:
            if self.target is not None:
                text += "=" + str(self.target)
        else:
            text += "=" + str(self.min) + "|" + str(self.max)

        return text


class Normaliser:
    """Map raw input vectors to normalised output vectors and back.

    Inputs are registered with AddInput (a name plus min/max bounds).
    Outputs are defined by a format string passed to Format. On first
    use, Normalise/Denormalise apply the ``format`` given to the
    constructor. Accepted format strings:

    * ``"0-1"``, ``"-1-1"``, ``"-1-0"`` -- one output per input using
      norm_zeroone, norm_oneone or norm_invzeroone respectively;
    * ``"*"`` -- one pass-through output per input;
    * ``"A"`` -- one output per input with a function picked by AutoFormat;
    * anything else -- ``:``-separated channel specs, each parsed by
      ParseOutChannel.
    """

    def __init__(self,format='0-1'):
        self.defaultformat = format
        self.inputcount = 0
        self.outputcount = 0
        self.inputs = []
        self.outputs = []
        # Input used by the next channel spec that names no explicit source.
        self.maxinput = 0

    def AddInput(self,name,min,max):
        """Register an input channel; its index is its position in ``inputs``."""
        channel = IChannel(name,min,max)
        channel.index = len(self.inputs)
        self.inputs.append(channel)

    def Normalise(self,inp,map=None):
        """Normalise one sample of raw input values.

        Args:
            inp: indexable sequence of raw input values.
            map: optional sequence of input names describing the order of
                ``inp``. Each output then reads from the position whose
                name matches (case-insensitively, last match wins) the name
                of the output's source input. Outputs whose source name is
                absent from ``map`` read from their source input's index.

        Returns:
            A list holding one value per output channel.
        """
        if self.inputcount == 0:
            self.Format(self.defaultformat)

        # By default each output reads from its source input's own index.
        mapping = [o.input.index for o in self.outputs]

        if map is not None:
            positions = {name.lower(): q for q, name in enumerate(map)}
            for x, o in enumerate(self.outputs):
                # Match on the output's *source* input; outputs need not line
                # up one-to-one with inputs (custom order, digital fan-out).
                q = positions.get(o.input.name.lower())
                if q is not None:
                    mapping[x] = q

        out = [0] * self.outputcount
        for o in self.outputs:
            value = inp[mapping[o.index]]
            out[o.index] = value if o.func is None else o.func(value, o)

        return out

    def Denormalise(self,out):
        """Convert an output vector back to a list of raw input values.

        Inputs that no output restores are left at 0.
        """
        if self.inputcount == 0:
            self.Format(self.defaultformat)

        inp = [0] * self.inputcount
        for o in self.outputs:
            value = out[o.index]
            if o.invfunc is None:
                inp[o.input.index] = value
            elif o.func == norm_dig:
                # Several digital outputs may share one input; only a
                # non-zero restore is written, so outputs that did not fire
                # do not overwrite the one that did.
                restored = o.invfunc(value, o)
                if restored != 0:
                    inp[o.input.index] = restored
            else:
                inp[o.input.index] = o.invfunc(value, o)

        return inp

    def Debug(self):
        """Print each output's source, bounds and position, then the format string."""
        for n in self.outputs:
            print(n.input.name + " ( " + str(n.input.index) + " ) - " + str(n.min) + " to " + str(n.max) + " To Output " + str(n.index))

        print(self.DetailFormat())

    def FinishFormatting(self):
        """Number the outputs, refresh their inverses and cache the channel counts."""
        for x, o in enumerate(self.outputs):
            o.index = x
            o.GetInverse()
        self.inputcount = len(self.inputs)
        self.outputcount = len(self.outputs)

    def Format(self,format):
        """Rebuild the output channels from a format string (see class docstring)."""
        self.outputs = []
        # Implicit channel numbering starts afresh for every format string;
        # otherwise a second Format call would continue past the last input.
        self.maxinput = 0

        presets = {
            "0-1": norm_zeroone,
            "-1-1": norm_oneone,
            "-1-0": norm_invzeroone,
            "*": None,
        }

        if format in presets:
            func = presets[format]
            self.outputs = [OChannel(i, func) for i in self.inputs]
        elif format == "A":
            self.outputs = [self.AutoFormat(OChannel(i)) for i in self.inputs]
        else:
            for segment in format.split(':'):
                # An empty spec (e.g. Format("") or a stray ':') defines no
                # channel rather than a spurious pass-through one.
                if not segment:
                    continue
                self.outputs.extend(self.ParseOutChannel(segment))

        self.FinishFormatting()

    def AutoFormat(self,ob):
        """Pick a normalising function for output channel ``ob`` from its bounds.

        * range straddling 0: made symmetric about 0 and uses norm_oneone;
        * both bounds negative: norm_invzeroone;
        * otherwise: norm_zeroone.

        ``ob`` is modified in place (bounds, func, invfunc) and returned.
        """
        if ob.min < 0 and ob.max > 0:
            if abs(ob.min) > ob.max:
                ob.max = abs(ob.min)
            else:
                ob.min = -ob.max
            ob.func = norm_oneone
        elif ob.min < 0 and ob.max < 0:
            ob.func = norm_invzeroone
        else:
            ob.func = norm_zeroone

        # Keep invfunc consistent with the newly chosen func.
        ob.GetInverse()
        return ob

    def CopyInputs(self,other):
        """Append a copy (name and bounds) of every input channel of ``other``."""
        for inp in other.inputs:
            # AddInput takes (name, min, max) -- keep the bounds in that order.
            self.AddInput(inp.name,inp.min,inp.max)

    def ParseOutChannel(self,format):
        """Parse one channel spec and return the list of OChannels it defines.

        Spec grammar: ``[source>]kind[=values]``.

        * ``source`` is an input index (digits) or an input name, matched
          case-insensitively. Without it, or if the name is unknown, the
          input after the highest one referenced so far is used.
        * ``kind`` is ``0-1``, ``-1-0``, ``-1-1``, ``tan``, ``D`` or ``A``
          (AutoFormat). Any other kind gives a pass-through channel.
        * ``values`` is ``min|max`` to override the input's bounds. For
          ``D`` it is a ``|``-separated list of integer targets, each of
          which produces its own digital channel.
        """
        inputid = self.maxinput

        px = format.find('>')
        if px != -1:
            fromchannel = format[:px]
            format = format[px + 1:]
            if fromchannel.isdigit():
                inputid = int(fromchannel)
            else:
                wanted = fromchannel.lower()
                for n, candidate in enumerate(self.inputs):
                    if candidate.name.lower() == wanted:
                        inputid = n
                        break

        source = self.inputs[inputid]

        values = None
        px = format.find('=')
        if px != -1:
            values = format[px + 1:]
            format = format[:px]

        if format == "D" and values is not None:
            outs = []
            for h in values.split('|'):
                oc = OChannel(source)
                oc.target = int(h)
                oc.func = norm_dig
                oc.invfunc = denorm_dig
                # norm_dig only ever produces 0 or 1.
                oc.min = 0
                oc.max = 1
                outs.append(oc)
        else:
            oc = OChannel(source)  # bounds start as the source input's
            if values is not None:
                pieces = values.split('|')
                if len(pieces) > 1:
                    oc.min = float(pieces[0])
                    oc.max = float(pieces[1])

            if format == "A":
                oc = self.AutoFormat(oc)
            else:
                oc.func = {
                    "0-1": norm_zeroone,
                    "-1-0": norm_invzeroone,
                    "-1-1": norm_oneone,
                    "D": norm_dig,
                    "tan": norm_tan,
                }.get(format)

            oc.GetInverse()
            outs = [oc]

        # Subsequent specs without an explicit source continue after this input.
        self.maxinput = max(self.maxinput, inputid + 1)
        return outs

    def DetailFormat(self):
        """Return the ``:``-joined ToString() of every output channel."""
        return ":".join(o.ToString() for o in self.outputs)

    def GetNegotiatedFormat(self,other):
        """Re-express this normaliser's outputs in terms of ``other``'s inputs.

        Each output whose source input name (exact match, first occurrence)
        exists in ``other.inputs`` is emitted with that input's numeric
        index as the source. Outputs with no matching name are dropped.
        """
        first_index = {}
        for idx, ob in enumerate(other.inputs):
            first_index.setdefault(ob.name, idx)

        parts = []
        for oa in self.outputs:
            found = first_index.get(oa.input.name)
            if found is None:
                continue
            # The channel spec follows the last '>' (spec text never contains
            # one), which is robust even if the input name contains '>'.
            spec = oa.ToString().rsplit(">", 1)[1]
            parts.append(str(found) + ">" + spec)

        return ":".join(parts)

    def Negotiate(self,other):
        """Build a Normaliser over ``other``'s inputs that reproduces this one's outputs.

        Only outputs whose source input name also exists in ``other`` are kept.
        """
        fmt = self.GetNegotiatedFormat(other)
        nm = Normaliser()
        nm.CopyInputs(other)
        nm.Format(fmt)
        return nm
