#!/usr/bin/ardipython

import argparse
import platform
import shutil
import sys
#import pycurl
import os
from subprocess import call
#from StringIO import StringIO
import traceback
import xmltodict
import requests

def GetSiteList(remotehost):
    """Return the URL of every site hosted by an ARDI server.

    Sends ``GET http://<remotehost>/api/sites`` and parses the XML reply with
    xmltodict, collecting the ``url`` attribute of each ``<site>`` element
    found under the ``<sites>`` root.

    :param remotehost: server address without the ``http://`` scheme
        (e.g. ``localhost``); it is concatenated directly into the URL.
    :returns: list of site URL strings. An empty list is returned when the
        server reports no sites, and also when the request or parsing fails
        (the failure is printed, not raised).
    """
    try:
        url = 'http://' + remotehost + '/api/sites'
        resp = requests.get(url)
        # An HTTP error page is not a site list; report it as a failure
        # rather than handing it to the XML parser.
        resp.raise_for_status()

        result = xmltodict.parse(resp.text)

        sites = result['sites']
        if not sites:
            # An empty <sites/> element parses to None: no sites configured.
            return []

        entries = sites.get('site')
        if entries is None:
            return []
        # xmltodict yields a dict for a single <site> and a list for several.
        if isinstance(entries, dict):
            entries = [entries]
        return [entry['@url'] for entry in entries]
    except Exception as e:
        print("Failed To Get ARDI Site List: " + str(e))
        return []
    
parser = argparse.ArgumentParser(description="Perform an Upgrade")
parser.add_argument('--server', help='The BASE url of the ARDI instance')
parser.add_argument('--driver', help='The name of the driver being upgraded')
parser.add_argument('--host', help='The public host name for this machine')
parser.add_argument('--remove', help='Removes the drivers but doesn\'t restart them', action='store_true')

# Parse first so that --help and usage errors work without elevated rights.
args = parser.parse_args()

import ctypes

# os.geteuid() is missing on Windows (AttributeError), where the shell32
# IsUserAnAdmin() call is used instead.
try:
    is_admin = os.geteuid() == 0
except AttributeError:
    is_admin = ctypes.windll.shell32.IsUserAnAdmin() != 0

if not is_admin:
    print("Error: This must be run as root/an Administrator")
    # Non-zero status so calling scripts can detect the failure.
    sys.exit(1)

# Target the local server unless --server names a different one.
serverlist = ['localhost'] if args.server is None else [args.server]

# An empty string means "no --driver filter" when building driver commands.
driverspecific = args.driver if args.driver is not None else ""

# Platform default interpreter used to run ardidrivers.py on Windows.
if platform.system() == "Windows":
    pathtopython = "c:\\python27\\python.exe"
else:
    pathtopython = "python3"

# ARDIPythonPath overrides the default when set. os.getenv returns None for a
# missing variable (it never raises), so fall back to the default explicitly.
pathtopython = os.getenv('ARDIPythonPath') or pathtopython

# For each server: reinstall the default consolidator if it reports no sites,
# otherwise remove (and, unless --remove, reinstall) the drivers of every site.
on_windows = platform.system() == "Windows"

for srv in serverlist:
    sites = GetSiteList(srv)

    if not sites:
        print("Invalid or Unlicensed Server - Establishing Basic Consolidator")

        # No consolidator action is taken on Windows.
        if not on_windows:
            consolidator = "/opt/ardi/srv/consolidator/ardiconsolidators"
            default_site = srv + "/s/default"
            call([consolidator, default_site, "remove"])
            call([consolidator, default_site, "install"])

    # Pick the driver tool once; both platforms target the specific site.
    if on_windows:
        tool = [pathtopython, "ardidrivers.py"]
    else:
        tool = ["/opt/ardi/drivers/ardidrivers"]
    driver_opts = ["--driver", driverspecific] if driverspecific else []

    for site in sites:
        addr = srv + "/s/" + site
        print("Cycling Drivers on " + addr)

        # Argument lists (no shell): site names come from the remote server's
        # response and must not be interpreted by a shell.
        call(tool + [addr, "remove"] + driver_opts)

        if args.remove:
            continue

        call(tool + [addr, "install"] + driver_opts)

    print("Upgrade Restart Sequence Complete")
