import sys

from ardi.driver import drivercore,driverbase
from pymodbus.client.sync import ModbusTcpClient as ModbusClient
try:    
    from pymodbus.payload import PayloadDecoder as BinaryPayloadDecoder
except:
    from pymodbus.payload import BinaryPayloadDecoder

from pymodbus.pdu import ExceptionResponse as ExceptionResponse

import logging

class fetchsetpoint:
    """One configured data point served by a fetchset.

    Attributes:
        point: identifier passed to core.UpdateData when the value is read.
        offset: position of the point relative to the fetchset's start address.
        format: decode format code used when the value is converted.
    """

    def __init__(self):
        # All three fields start at zero; fetchset.Extend fills them in.
        self.point = self.offset = self.format = 0

class fetchset:
    """A contiguous range of Modbus addresses fetched with one request.

    Points that lie close together are merged into a single fetchset; each
    point stores its offset from the start of the range so its value can be
    located in the response.
    """

    def __init__(self):
        self.unit = 0       # Modbus unit id used for the request
        self.address = 0    # first address covered; -1 means "no points yet"
        self.length = 0     # number of coils/registers covered
        self.register = 0   # 0 coils, 1 discrete inputs, 3 input regs, 4 holding regs
        self.points = []    # fetchsetpoint entries served by this range

    @staticmethod
    def _RegisterType(address):
        """Return the register type code implied by a 6-digit style address."""
        if address >= 400000:
            return 4
        if address >= 300000:
            return 3
        if address >= 100000:
            return 1
        return 0

    @staticmethod
    def _RegisterWidth(mode):
        """Return how many consecutive addresses a point of format `mode` spans.

        Format 11 (64-bit float) needs four 16-bit registers; format 1
        (32-bit int) and formats >= 10 (32-bit float, string) need two;
        everything else occupies one.
        """
        if mode == 11:
            return 4
        if mode == 1 or mode >= 10:
            return 2
        return 1

    def _Merged(self, address, width):
        """Return (start, length) of the smallest range covering this fetchset
        and the `width` addresses beginning at `address`."""
        start = min(self.address, address)
        end = max(self.address + self.length, address + width)
        return start, end - start

    def Setup(self):
        """Derive self.register from the current start address."""
        self.register = self._RegisterType(int(self.address))

    def Normalise(self, addr):
        """Strip the register-type prefix (100000/300000/400000) from addr.

        NOTE(review): 400000 maps to 0 here, so whether 400001 or 400000 is
        the first holding register depends on how points are configured.
        """
        if self.register == 4:
            return addr - 400000
        if self.register == 3:
            return addr - 300000
        if self.register == 1:
            return addr - 100000
        return addr

    def Extend(self, address, mode, code):
        """Add a point at `address` with format `mode`, growing the range to fit.

        Args:
            address: 6-digit style Modbus address of the point.
            mode: decode format code (determines how many addresses it spans).
            code: identifier reported with the point's value.
        """
        width = self._RegisterWidth(mode)

        pnt = fetchsetpoint()
        pnt.point = code
        pnt.format = mode

        if self.address == -1:
            # First point: the range is exactly this point.
            self.address = address
            self.length = width
            pnt.offset = 0
            self.points.append(pnt)
            self.Setup()
            return

        start, length = self._Merged(address, width)
        shift = self.address - start
        if shift:
            # The range now starts earlier, so existing offsets move up.
            for existing in self.points:
                existing.offset += shift
        self.address = start
        self.length = length

        pnt.offset = address - self.address
        self.points.append(pnt)

class modbustcpdriver:
    """ARDI driver that reads points from a Modbus/TCP device.

    SetAddress parses the connection string, Connect opens the TCP client,
    Optimise groups the configured points into block reads (self.fetches)
    and Poll performs those reads, passing decoded values to
    self.core.UpdateData.

    NOTE(review): self.core (pointlist, UpdateData, logger) is never assigned
    in this class; presumably the ardi driver framework attaches it.
    """

    def __init__(self):
        self.connected = False
        self.polling = True
        self.digits = 6       # parsed by SetAddress; not used in this class
        self.endian = '>'     # byte order handed to the payload decoder
        self.fetches = []     # fetchset objects built by Optimise
        self.client = None    # Modbus client created by Connect

    def SetAddress(self, addr):
        """Parse a device address of the form 'ip[:port[:unit[:endian[:digits]]]]'.

        Missing or empty fields fall back to port 502 and unit 1. An empty
        endian field means '>'; an absent one keeps the current setting.
        Marks the driver as disconnected.
        """
        self.connected = False
        bits = addr.split(":")
        self.ip = bits[0]
        self.port = int(bits[1]) if len(bits) > 1 and bits[1] else 502
        self.unit = int(bits[2]) if len(bits) > 2 and bits[2] else 1
        if len(bits) > 3:
            self.endian = bits[3] or '>'
        if len(bits) > 4 and bits[4]:
            self.digits = int(bits[4])

    def _CloseClient(self):
        """Close the current Modbus client, if one has been created."""
        client = getattr(self, 'client', None)
        if client is not None:
            client.close()

    def Connect(self):
        """Open a TCP connection to the configured device.

        Returns:
            True if the client connected, False otherwise (errors are logged).
        """
        self.connected = False
        # Release a client left over from an earlier connection so repeated
        # reconnects do not leak sockets.
        self._CloseClient()
        self.client = ModbusClient(self.ip, port=self.port)

        logging.basicConfig(level=logging.DEBUG)
        try:
            success = self.client.connect()
        except Exception:
            logging.getLogger(__name__).warning(
                "Connection to %s:%s failed", self.ip, self.port, exc_info=True)
            return False
        if not success:
            return False
        self.connected = True
        return True

    def Disconnect(self):
        """Mark the driver disconnected and close the client (safe if never connected)."""
        self.connected = False
        self._CloseClient()

    def Poll(self):
        """Read every fetchset built by Optimise and publish the decoded values.

        Error responses skip the affected fetchset. An empty (None) response
        is treated as a lost connection: the driver disconnects and abandons
        the rest of this poll.
        """
        for fetch in self.fetches:
            requestbase = fetch.Normalise(fetch.address)
            if fetch.register in (0, 1):
                keepgoing = self._PollBits(fetch, requestbase)
            else:
                keepgoing = self._PollRegisters(fetch, requestbase)
            if not keepgoing:
                return

    def _ConnectionLost(self, requestbase):
        """Log an empty response and drop the connection."""
        self.core.logger.warning("Null Response Detected On Request For " + str(requestbase))
        self.Disconnect()

    def _PollBits(self, fetch, requestbase):
        """Read coils (register 0) or discrete inputs (register 1) for one fetchset.

        Returns False if the connection was lost, True otherwise.
        """
        if fetch.register == 0:
            rr = self.client.read_coils(requestbase, fetch.length, unit=fetch.unit)
        else:
            rr = self.client.read_discrete_inputs(requestbase, fetch.length, unit=fetch.unit)

        if rr is None:
            self._ConnectionLost(requestbase)
            return False
        if isinstance(rr, ExceptionResponse) or not hasattr(rr, 'bits'):
            self.core.logger.warning("Error Detected: " + str(rr))
            return True

        for pnt in fetch.points:
            # Bits are addressed by the point's offset from the range start,
            # not by the order in which points were added.
            value = 1 if rr.bits[pnt.offset] else 0
            self.core.UpdateData(pnt.point, value)
        return True

    def _PollRegisters(self, fetch, requestbase):
        """Read input (register 3) or holding registers for one fetchset.

        Returns False if the connection was lost, True otherwise.
        """
        import struct  # local: only needed to recognise decoder errors

        if fetch.register == 3:
            rr = self.client.read_input_registers(requestbase, fetch.length, unit=fetch.unit)
        else:
            rr = self.client.read_holding_registers(requestbase, fetch.length, unit=fetch.unit)

        if rr is None:
            self._ConnectionLost(requestbase)
            return False
        if isinstance(rr, ExceptionResponse) or not hasattr(rr, 'registers'):
            self.core.logger.warning("Error Detected: " + str(rr))
            return True

        for pnt in fetch.points:
            try:
                value = self._Decode(rr.registers, pnt)
            except (IndexError, struct.error):
                # One mis-sized point should not stop the others being read.
                self.core.logger.warning(
                    "Point " + str(pnt.point) + " does not fit the registers read at " + str(requestbase))
                continue
            if value is not None:
                self.core.UpdateData(pnt.point, value)
        return True

    def _Decode(self, registers, pnt):
        """Convert the registers starting at pnt.offset according to pnt.format.

        Formats: 0 raw register value, 1 32-bit int, 10 32-bit float,
        11 64-bit float, 20 4-byte string. Returns None for any other format.
        """
        if pnt.format == 0:
            return registers[pnt.offset]
        if pnt.format not in (1, 10, 11, 20):
            return None

        decoder = BinaryPayloadDecoder.fromRegisters(registers[pnt.offset:], endian=self.endian)
        if pnt.format == 1:
            return decoder.decode_32bit_int()
        if pnt.format == 10:
            return decoder.decode_32bit_float()
        if pnt.format == 11:
            return decoder.decode_64bit_float()
        # Format 20. NOTE(review): the original referenced an undefined
        # `bits[2]` as the length. 4 bytes matches the two registers reserved
        # for this format; confirm whether a configurable length was intended.
        return decoder.decode_string(4)

    @staticmethod
    def _ParsePointAddress(text):
        """Split 'address', 'unit|address' or 'unit|address|format' into (unit, address, format)."""
        bits = text.split('|')
        unit = 0
        mode = 0
        if len(bits) == 1:
            address = int(bits[0])
        else:
            unit = int(bits[0])
            address = int(bits[1])
            if len(bits) > 2:
                mode = int(bits[2])
        return unit, address, mode

    @staticmethod
    def _RegisterType(address):
        """Return the register type code implied by a 6-digit style address."""
        if address >= 400000:
            return 4
        if address >= 300000:
            return 3
        if address >= 100000:
            return 1
        return 0

    def Optimise(self):
        """Group self.core.pointlist into fetchsets, each read with one request.

        A point joins an existing fetchset when it has the same unit and
        register type and lies within `overhead` addresses of the fetchset.
        Otherwise a new fetchset is started. Prints a summary of the result.
        """
        self.fetches = []
        overhead = 16
        for item in self.core.pointlist:
            # NOTE(review): points without an explicit unit use unit 0, not
            # the self.unit parsed by SetAddress. Confirm this is intended.
            unit, address, mode = self._ParsePointAddress(item.address)
            regtype = self._RegisterType(address)

            target = None
            for fetch in self.fetches:
                if fetch.unit != unit or fetch.register != regtype:
                    continue
                if fetch.address - overhead <= address <= fetch.address + fetch.length + overhead:
                    target = fetch
                    break

            if target is None:
                target = fetchset()
                target.address = -1
                target.length = 0
                target.unit = unit
                target.Setup()
                target.Extend(address, mode, item.code)
                self.fetches.append(target)
            else:
                target.Extend(address, mode, item.code)

        for f in self.fetches:
            print("Fetching Unit " + str(f.unit) + " @ " + str(f.address) + " for " + str(f.length) + " on " + str(len(f.points)) + " points, register " + str(f.register))
        
class driverfactory:
    """Factory handed to the ARDI driver base to build driver instances."""

    def createinstance(self):
        """Return a new, not-yet-connected Modbus/TCP driver."""
        instance = modbustcpdriver()
        return instance
            
if __name__ == "__main__":
    # Script entry point: start the ARDI driver base with this module's factory.
    factory = driverfactory()
    driverbase.ardidriver().start(factory)