
import pycurl
import xmltodict
import time
import socket
import transform
import platform
import traceback

import logging
import thread
import pytz
from tzlocal import get_localzone
import datetime
from datetime import timedelta
from ardi.driver import resultstore

class SQLSource:
    """Read time-series rows from one SQL table into an ARDI query object.

    Every result row is (timestamp, lookup, value): ``lookup`` is the column
    matched against ``query.addresses``, and each row is handed to
    ``query.AddLine(str(lookup), timestamp, value)``.  SQL text is produced
    from the ``*query`` / ``*calc`` templates set in __init__, whose
    ``{placeholder}`` tokens are filled in by Substitute().  The default
    templates use MySQL syntax (``LIMIT``, ``DATE_ADD``, ``TIMESTAMPDIFF``);
    callers presumably override them for other servers -- confirm against
    the driver.
    """

    # 100 ns ticks per second, used by date styles "3" (FILETIME) and "4" (CLR).
    _TICKS_PER_SECOND = 10000000

    def __init__(self, db, driver):
        """Store the connection and owning driver and set MySQL defaults.

        Args:
            db: connection object exposing ``cursor()`` and ``commit()``.
            driver: owning driver; Fetch() uses ``driver.core.logger``,
                ``driver.Connect()`` and ``driver.db``.
        """
        self.db = db
        self.driver = driver
        self.core = driver
        # Extra SQL condition ANDed into data queries ("" means none).
        self.filter = ""
        # Point lists shorter than this are restricted in SQL with IN (...).
        self.narrowthreshold = 20
        # Snapshot timestamps keep this many characters from the '.' onwards
        # (4 -> the '.' plus three fractional digits).
        self.datepoints = 4
        self.server = "mysql"
        # Edge templates: newest row per lookup at/before {start}, and the
        # earliest row per lookup at/after {end}.  Both rely on the server
        # returning the first row of the ordered derived table for each
        # group, which MySQL documents as indeterminate --
        # NOTE(review): verify on the target server.
        self.upperedgequery = "SELECT {stamp}, {lookup}, {value} FROM ( SELECT * FROM {table} WHERE {stamp} <= {start}{additional} ORDER BY {stamp} DESC) AS t1 GROUP BY {lookup}"
        # ASC (was DESC): DESC picked the *latest* row after {end}, not the next one.
        self.loweredgequery = "SELECT {stamp}, {lookup}, {value} FROM ( SELECT * FROM {table} WHERE {stamp} >= {end}{additional} ORDER BY {stamp} ASC) AS t1 GROUP BY {lookup}"
        # Bucketed averages; {extra} receives datetimecalc/numericcalc.
        self.mainquery = "SELECT {extra} as ostamp, {lookup} as lookup, AVG({value}) as value FROM {table} WHERE {stamp}>={start} AND {stamp}<={end}{additional} GROUP BY ostamp, lookup ORDER BY ostamp,lookup"
        self.datetimecalc = "DATE_ADD({start},INTERVAL (CEIL(TIMESTAMPDIFF(SECOND,{start},{stamp}) / {grain}) * {grain}) SECOND)"
        self.numericcalc = "CEIL(({stamp} - {start}) / {grain}) * {grain}"
        # Prefix put before the lookup column in the IN (...) filter; it is
        # stripped again from snapshot data queries.
        self.tableprefix = ""
        # Row limiting for the snapshot timestamp lookup: appended to the
        # query / inserted after SELECT (presumably " TOP 1" style servers).
        self.postquerylimit = " LIMIT 1"
        self.postselectlimit = ""

        # Date encoding: "1" SQL datetime string, "2" UNIX epoch seconds,
        # "3" Win32 FILETIME ticks, "4" CLR ticks.
        self.datestyle = "1"
        self.tz = "local"
        self.localtz = None

    @staticmethod
    def fixdate(dt):
        """Return the date string *dt* with any '.fraction' suffix removed.

        Was declared without ``self``/``@staticmethod``, so calling it on an
        instance raised TypeError; now callable on the class or an instance.
        """
        ps = dt.find('.')
        if ps > -1:
            return dt[:ps]
        return dt

    def ConnectionParameters(self, host, username, password, database):
        """Store connection settings (not read anywhere in this class)."""
        self.host = host
        self.username = username
        self.password = password
        self.database = database

    def TableParameters(self, stamp, lookup, value, table, filter=""):
        """Describe the source table.

        Args:
            stamp: timestamp column.
            lookup: column holding the point name matched against addresses.
            value: value column.
            table: table (or FROM expression) to read.
            filter: optional SQL condition ANDed into data queries.
        """
        self.table = table
        self.stamp = stamp
        self.lookup = lookup
        self.value = value
        self.filter = filter

    def Reconnect(self):
        """Ask the driver to reconnect and return its result.

        NOTE(review): unlike the retry path used by Fetch(), this does not
        refresh ``self.db`` from ``driver.db`` -- confirm callers do so.
        """
        return self.driver.Connect()

    def _epoch_to_local(self, seconds):
        """Convert UNIX-epoch *seconds* (UTC) to a naive datetime in self.localtz."""
        utc = datetime.datetime.utcfromtimestamp(seconds).replace(tzinfo=pytz.UTC)
        return utc.astimezone(self.localtz).replace(tzinfo=None)

    def TranslateDT(self, dt):
        """Render datetime *dt* as a SQL literal for the configured date style.

        Style "1": *dt* is taken as naive UTC, converted to self.localtz and
        returned as a quoted 'YYYY-MM-DD HH:MM:SS' string.  Styles "2", "3"
        and "4" return integer strings: seconds since 1970-01-01 UTC, ticks
        since 1601-01-01 UTC, and ticks since 0001-01-01 respectively.
        Styles "2"/"3" subtract a UTC-aware epoch, so *dt* must be
        timezone-aware there; style "4" drops tzinfo first.  Any other style
        returns None.
        """
        if self.datestyle == "1":
            local = pytz.utc.localize(dt).astimezone(self.localtz).replace(tzinfo=None)
            return "'" + local.strftime("%Y-%m-%d %H:%M:%S") + "'"

        if self.datestyle == "2":
            epoch = datetime.datetime(1970, 1, 1, 0, 0, 0, 0, pytz.UTC)
            return str(int((dt - epoch).total_seconds()))

        if self.datestyle == "3":
            epoch = datetime.datetime(1601, 1, 1, 0, 0, 0, 0, pytz.UTC)
            return str(int((dt - epoch).total_seconds()) * self._TICKS_PER_SECOND)

        if self.datestyle == "4":
            diff = (dt.replace(tzinfo=None) - datetime.datetime(1, 1, 1, 0, 0, 0, 0)).total_seconds()
            return str(self._TICKS_PER_SECOND * int(diff))

    def ConvertDateBack(self, dt):
        """Convert a stored timestamp back to a naive datetime in self.localtz.

        Style "1" returns *dt* unchanged.  Styles "2"-"4" parse *dt* as an
        integer (epoch seconds, FILETIME ticks, CLR ticks) interpreted as
        UTC.  Any other style returns None.
        """
        if self.datestyle == "1":
            return dt

        if self.datestyle == "2":
            return self._epoch_to_local(int(dt))

        if self.datestyle == "3":
            seconds = int(dt) // self._TICKS_PER_SECOND
            seconds -= (datetime.datetime(1970, 1, 1, 0, 0, 0, 0, pytz.UTC)
                        - datetime.datetime(1601, 1, 1, 0, 0, 0, 0, pytz.UTC)).total_seconds()
            return self._epoch_to_local(seconds)

        if self.datestyle == "4":
            seconds = int(dt) // self._TICKS_PER_SECOND
            offset = (datetime.datetime(1970, 1, 1, 0, 0, 0)
                      - datetime.datetime(1, 1, 1, 0, 0, 0)).total_seconds()
            return self._epoch_to_local(seconds - offset)

    def Substitute(self, stx, grain, start, end, additional, extra, join):
        """Fill the ``{placeholder}`` tokens of template *stx* and return it.

        Replacement runs in a fixed order, and later tokens are also replaced
        inside text inserted earlier (e.g. ``{start}`` inside ``{extra}`` or
        the filter), so the order below is significant.
        """
        replacements = (
            ("{stamp}", self.stamp),
            ("{additional}", additional),
            ("{lookup}", self.lookup),
            ("{value}", self.value),
            ("{table}", self.table),
            ("{join}", join),
            ("{grain}", str(grain)),
            ("{extra}", extra),
            ("{start}", start),
            ("{end}", end),
        )
        for token, text in replacements:
            stx = stx.replace(token, text)
        return stx

    def _effective_grain(self, grain, secs):
        """Return the bucket size to put into the SQL templates.

        A negative grain -N means "N buckets over the range": the range
        length *secs* is divided by N.  For tick-based date styles ("3",
        "4") the grain is converted from seconds to ticks.  An unparsable
        grain is logged and returned unchanged.
        """
        try:
            requested = int(grain)
        except (TypeError, ValueError):
            self.driver.core.logger.exception("Invalid grain: " + repr(grain))
            return grain
        if requested < 0:
            grain = secs / -requested
        if self.datestyle in ("3", "4"):
            grain = int(grain) * self._TICKS_PER_SECOND
        return grain

    def _build_filter(self, addresses):
        """Return (additional WHERE SQL, JOIN SQL) for the configured filter and points.

        Point lists shorter than narrowthreshold become an IN (...) clause
        and a VALUES join; otherwise no point restriction is added and the
        join text is empty (previously a dangling ") V(itm) ON (...)").
        """
        add = ""
        if self.filter:
            add += " AND (" + self.filter + ")"
            self.driver.core.logger.info("Current Filter: " + self.filter)

        joinquery = ""
        if len(addresses) < self.narrowthreshold:
            column = self.tableprefix + self.lookup
            # NOTE(review): point names are inlined into SQL unescaped; a name
            # containing a quote breaks the query (and is an injection vector
            # if addresses are user-supplied).  Consider driver parameters.
            quoted = ["'" + p + "'" for p in addresses]
            add += " AND (" + column + " IN (" + ",".join(quoted) + "))"
            joinquery = (" JOIN ( values " + ",".join("(" + q + ")" for q in quoted)
                         + ") V(itm) ON (V.itm = " + column + ")")
        return add, joinquery

    def _fetch_all(self, sql):
        """Execute *sql* once on a new cursor and return its rows; always closes the cursor."""
        csr = self.db.cursor()
        try:
            csr.execute(sql)
            return list(csr)
        finally:
            csr.close()

    def _run_query(self, sql, reconnect=False):
        """Return all rows of *sql*.

        With *reconnect* true, a failure is logged, driver.Connect() is
        called, self.db is refreshed from driver.db and the query is retried
        once.  Exceptions from the final attempt propagate.
        """
        try:
            return self._fetch_all(sql)
        except Exception:
            if not reconnect:
                raise
            self.driver.core.logger.exception("Query failed; reconnecting and retrying: " + sql)
            self.driver.Connect()
            self.db = self.driver.db
            return self._fetch_all(sql)

    @staticmethod
    def _add_rows(query, rows):
        """Feed (stamp, lookup, value) rows into *query*."""
        for row in rows:
            query.AddLine(str(row[1]), row[0], row[2])

    def _edge_literal(self, mx):
        """Render timestamp *mx* read from the DB as a SQL literal.

        Datetime strings (style "1") are truncated to ``datepoints``
        characters after the '.' and quoted; numeric styles are used as-is.
        """
        if self.datestyle != "1":
            return str(mx)
        vd = str(mx)
        ps = vd.rfind('.')
        if ps != -1 and len(vd) > ps + self.datepoints:
            vd = vd[:ps + self.datepoints]
        return "'" + vd + "'"

    def _snapshot_edge(self, query, qdate, before, add):
        """Add every row stored at the timestamp nearest to *qdate*.

        before=True searches at or before *qdate* (newest first); otherwise
        at or after it (oldest first).  Returns False when a query failed
        (the caller aborts the fetch) and True otherwise, including when no
        such timestamp exists.
        """
        log = self.driver.core.logger
        if before:
            op, order, direction = "<=", "DESC", "past"
            stamp_label, data_label = "Snapshot Start Time Fetch: ", "Synchronised Fetch Query: "
        else:
            op, order, direction = ">=", "ASC", "future"
            stamp_label, data_label = "Post-Fetch Query: ", "Post-Fetch Query: "

        stampquery = ('SELECT' + self.postselectlimit + ' ' + self.stamp + ' as mx FROM ' + self.table
                      + ' WHERE ' + self.stamp + " " + op + " " + qdate
                      + " ORDER BY " + self.stamp + " " + order + self.postquerylimit)
        log.info(stamp_label + stampquery)
        try:
            rows = self._run_query(stampquery, reconnect=True)
        except Exception:
            log.exception("Snapshot time query failed: " + stampquery)
            return False

        mx = rows[-1][0] if rows else None
        log.info("Snapshot Date: " + str(mx))
        # `is None` rather than `== False`, so a numeric timestamp of 0 is valid.
        if mx is None or str(mx) in ("", "None"):
            log.info("Request is asking for data too far into the " + direction)
            return True

        dataquery = ('SELECT ' + self.stamp + ', ' + self.lookup + ', ' + self.value + ' FROM '
                     + self.table + ' WHERE ' + self.stamp + ' = ' + self._edge_literal(mx) + add)
        if self.tableprefix != "":
            dataquery = dataquery.replace(self.tableprefix, "")
        log.info(data_label + dataquery)
        try:
            rows = self._run_query(dataquery)
        except Exception:
            log.exception("Snapshot data query failed: " + dataquery)
            return False
        self._add_rows(query, rows)
        return True

    def Fetch(self, query):
        """Populate *query* for its time range and return ``query.Finish()``.

        Three steps: (1) the value at or before the start, (2) grain-bucketed
        averages between start and end (skipped when start == end), and
        (3) the value at or after the end so there is something to
        interpolate towards.  With ``self.options == "snapshot"`` steps 1
        and 3 look up the nearest stored timestamp and read every row at it;
        otherwise the upper/lower edge templates are used.

        Returns "" when an edge query fails; main-query failures are logged
        and skipped.
        """
        log = self.driver.core.logger

        # Numeric date styles are handled as UTC (see TranslateDT/ConvertDateBack).
        if self.datestyle != "1":
            self.tz = "UTC"

        sd = query.localstart
        ed = query.localend
        query.SetDateFormat(self.tz, self.datestyle, True)
        qsd = query.sd
        qed = query.ed

        grain = self._effective_grain(query.grain, (ed - sd).total_seconds())
        add, joinquery = self._build_filter(query.addresses)
        # options is normally assigned by the caller; default to the edge-query path
        # instead of raising AttributeError.
        snapshot = getattr(self, "options", "") == "snapshot"

        # 1. Value at or before the start.
        started = datetime.datetime.now()
        if snapshot:
            if not self._snapshot_edge(query, qsd, True, add):
                return ""
            log.info("Synchronised Fetch Complete!")
        else:
            edgequery = self.Substitute(self.upperedgequery, grain, qsd, qed, add, "<=", joinquery)
            log.info("Pre-Fetch Query: " + edgequery)
            try:
                rows = self._run_query(edgequery, reconnect=True)
            except Exception:
                log.exception("Pre-fetch failed: " + edgequery)
                return ""
            self._add_rows(query, rows)
        log.info('Prefetch Time Elapsed: ' + str((datetime.datetime.now() - started).total_seconds()))
        started = datetime.datetime.now()

        # 2. Bucketed values inside the range.
        if sd != ed:
            calc = self.datetimecalc if self.datestyle == "1" else self.numericcalc
            code = self.Substitute(calc, grain, qsd, qed, add, "", joinquery)
            template = self.mainquery
            if query.function == "raw":
                # rawquery is not defined by this class; fall back instead of crashing.
                rawquery = getattr(self, "rawquery", None)
                if rawquery:
                    template = rawquery
                else:
                    log.info("No rawquery template configured; using the main query")
            self.query = self.Substitute(template, grain, qsd, qed, add, code, joinquery)
            log.info("Main Query: " + self.query)
            try:
                self._add_rows(query, self._run_query(self.query))
            except Exception:
                log.exception("Main query failed: " + self.query)
            log.info('Main Query Time Elapsed: ' + str((datetime.datetime.now() - started).total_seconds()))
            started = datetime.datetime.now()

        # 3. Value at or after the end.
        if snapshot:
            if not self._snapshot_edge(query, qed, False, add):
                return ""
        else:
            edgequery = self.Substitute(self.loweredgequery, grain, qsd, qed, add, "<=", joinquery)
            log.info("Post-Fetch Query: " + edgequery)
            try:
                self._add_rows(query, self._run_query(edgequery))
            except Exception:
                log.exception("Failed To Add")
        log.info('Postfetch Time Elapsed: ' + str((datetime.datetime.now() - started).total_seconds()))

        # Deliberately best-effort.  NOTE(review): presumably ends the read
        # transaction so later Fetch() calls see new rows -- confirm.
        try:
            self.db.commit()
        except Exception:
            pass

        log.info("Preparing Results for Display")
        return query.Finish()