#!/usr/bin/env python
# vim:ts=2:et:sw=2:ai
#
# Check configs with remote addresses
#
# Rick van der Zwet <info@rickvanderzwet.nl>
#
import argparse
import gformat
import getpass
import os
import paramiko
import socket
import struct
import subprocess
import sys
import time
import logging

# Log at INFO by default, but keep paramiko's own log output to WARNING and above.
logging.basicConfig(level=logging.INFO)
# NOTE(review): `logger` is not used in the code visible here.
logger = logging.getLogger(__name__)
logging.getLogger("paramiko").setLevel(logging.WARNING)

# Root password used by the SSH helpers below. Populated in __main__ from the
# SSHPASS environment variable, or prompted for when --ask-pass is given.
SSHPASS = None
import netsnmp

class CmdError(Exception):
  """Raised when a remote command writes to stderr or an SNMP query fails."""

class ConnectError(Exception):
  """Raised when an SSH session to a host cannot be set up (socket or auth failure)."""



def host_ssh_cmd(hostname, cmd):
  """Run *cmd* on *hostname* as root over SSH and return its stdout lines.

  Connects with the module-level SSHPASS and a 3 second connect timeout;
  unknown host keys are accepted automatically.

  Raises:
    ConnectError: a socket-level error (including timeouts) or an
      authentication failure occurred while connecting or running *cmd*.
    CmdError: the command wrote anything to stderr; the exception carries
      the tuple (stderr_lines, stdout_lines).
  """
  ssh = paramiko.SSHClient()
  ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  try:
    ssh.connect(hostname, username='root', password=SSHPASS, timeout=3)
    _stdin, stdout, stderr = ssh.exec_command(cmd)
    out_lines = stdout.readlines()
    err_lines = stderr.readlines()
  except (socket.error, paramiko.AuthenticationException) as e:
    raise ConnectError(e)
  finally:
    # Always release the transport, also when connect/exec_command fails.
    ssh.close()
  if err_lines:
    raise CmdError((err_lines, out_lines))
  return out_lines

def parse_ini(lines):
  """Parse ``key=value`` lines (e.g. ``cat /etc/board.info`` output) into a dict.

  Each line is stripped of surrounding whitespace. Only the first '=' splits
  key from value, so values may themselves contain '='. Blank lines and lines
  without any '=' are skipped instead of aborting the whole parse.
  """
  result = {}
  for raw in lines:
    line = raw.strip()
    if '=' not in line:
      continue
    key, value = line.split('=', 1)
    result[key] = value
  return result

def ubnt_probe(hostname):
  """Print the parsed key/value contents of /etc/board.info on *hostname*."""
  board_info = parse_ini(host_ssh_cmd(hostname, 'cat /etc/board.info'))
  print(board_info)


def get_bridge_type(host):
  """Return the type/product string reported by the bridge at *host*, or None.

  NS and NS M bridges (presumably UBNT NanoStation variants) report this value
  under slightly different OIDs, so all candidates are fetched in one SNMPv1
  GET and the first non-empty value (in the order listed) is returned. The
  last candidate is sysName.0 (.1.3.6.1.2.1.1.5.0).

  Raises:
    CmdError: the SNMP session reported an error.
  """
  oids = ['.1.2.840.10036.3.1.2.1.3.5', '.1.2.840.10036.3.1.2.1.3.6',
          '.1.2.840.10036.3.1.2.1.3.7', '.1.3.6.1.2.1.1.5.0']
  var_list = netsnmp.VarList(*[netsnmp.Varbind(oid) for oid in oids])
  # NOTE(review): if netsnmp's Timeout is in microseconds this is only 0.2 s;
  # get_bridge_mac uses 6 * 100000 -- confirm the intended value.
  sess = netsnmp.Session(Version=1, DestHost=host, Community='public', Timeout=2 * 100000, Retries=1)
  retval = sess.get(var_list)
  if sess.ErrorInd < 0:
    raise CmdError('SNMP Failed -- [%(ErrorInd)s] %(ErrorStr)s (%(DestHost)s)' % vars(sess))
  # First truthy value, computed once; does not depend on filter() returning a list.
  return next((value for value in (retval or ()) if value), None)


# http://countergram.com/python-group-iterator-list-function
def group_iter(iterator, n=2, strict=False):
  """Lazily yield consecutive *n*-tuples taken from *iterator*.

  Example with n == 2: [1, 2, 3, 4] -> (1, 2), (3, 4).
  An incomplete trailing group is dropped silently, unless *strict* is set,
  in which case ValueError("Leftover values") is raised once the input is
  exhausted (after every complete group has been yielded).
  """
  bucket = []
  for element in iterator:
    bucket.append(element)
    if len(bucket) != n:
      continue
    yield tuple(bucket)
    bucket = []
  if bucket and strict:
    raise ValueError("Leftover values")


def get_bridge_mac(host):
  """Return the MAC address of the ``br0`` interface on the bridge at *host*.

  Walks IF-MIB::ifDescr and IF-MIB::ifPhysAddress over SNMPv1 and pairs the
  returned values two at a time as (description, raw address).

  Returns:
    The address as lowercase colon-separated hex ('aa:bb:cc:dd:ee:ff'), or
    None when the walk returned nothing useful, there is no ``br0`` entry,
    or its raw address is not exactly 6 bytes long.

  Raises:
    CmdError: the SNMP session reported an error.
  """
  var_list = netsnmp.VarList(
    *[netsnmp.Varbind(oid) for oid in ('IF-MIB::ifDescr', 'IF-MIB::ifPhysAddress')])
  sess = netsnmp.Session(Version=1, DestHost=host, Community='public', Timeout=6 * 100000, Retries=1)
  retval = sess.walk(var_list)
  if sess.ErrorInd < 0:
    raise CmdError('SNMP Failed -- [%(ErrorInd)s] %(ErrorStr)s (%(DestHost)s)' % vars(sess))
  if not retval or not any(retval):
    return None
  # We only have bridge configurations, so looking at bridge MAC addresses.
  # NOTE(review): relies on the walk returning values interleaved as
  # descr, addr, descr, addr, ... -- confirm against the netsnmp version used.
  mac_raw = dict(group_iter(retval, 2)).get('br0')
  if not mac_raw or len(mac_raw) != 6:
    # Missing/odd entry: report "unknown" instead of aborting the caller's run.
    return None
  return ':'.join('%02x' % octet for octet in struct.unpack('6B', mac_raw))





def _report_board(datadump, dmesg_lines):
  """Echo interesting dmesg.boot lines and record the detected board in datadump['board']."""
  for line in (raw.strip() for raw in dmesg_lines):
    if line.startswith('CPU:'):
      print(line)
    elif line.startswith('Geode LX:'):
      datadump['board'] = 'ALIX2'
      print(line)
    elif line.startswith('real memory'):
      print(line)
    elif line.startswith('Elan-mmcr'):
      datadump['board'] = 'net45xx'


def _probe_bridges(datadump):
  """For every interface with an ns_ip, query its bridge and store ns_mac / bridge_type."""
  for iface_key in datadump['autogen_iface_keys']:
    ifacedump = datadump[iface_key]
    if not ifacedump.get('ns_ip'):
      continue
    addr = ifacedump['ns_ip'].split('/')[0]
    print("## Bridge IP: %(ns_ip)s at %(autogen_ifname)s" % ifacedump)
    try:
      # Quick reachability check on TCP port 80; the probe socket is not
      # needed afterwards, so close it right away instead of leaking it.
      socket.create_connection((addr, 80), 2).close()
      bridge_mac = get_bridge_mac(addr)
      if bridge_mac:
        ifacedump['ns_mac'] = bridge_mac
      bridge_type = get_bridge_type(addr)
      if bridge_type:
        ifacedump['bridge_type'] = bridge_type
    except (socket.timeout, socket.error) as e:
      print("### %s (%s)" % (e, addr))
    except CmdError as e:
      print("### Command error: %s" % e)


def _fetch_wl_release(datadump):
  """Best-effort: read the WL release number via snmpget into datadump['wl_release']."""
  # The OID index 6.119.108.45.118.101.114 encodes the 6-character string "wl-ver".
  oid = 'UCD-SNMP-MIB::ucdavis.84.4.1.2.6.119.108.45.118.101.114.1'
  try:
    wl_release = subprocess.check_output(
      ['snmpget', '-Oq', '-Ov', '-c', 'public', '-v2c', datadump['autogen_fqdn'], oid],
      universal_newlines=True)
    datadump['wl_release'] = int(wl_release.replace('"', ''))
  except (subprocess.CalledProcessError, ValueError):
    # Leave wl_release untouched when snmpget fails or returns a non-integer.
    pass


def node_check(host):
  """Gather basic health/inventory facts for *host* and store them in its YAML config.

  Steps: read /var/run/dmesg.boot over SSH to detect the board type, query
  each interface's bridge (ns_ip) for its MAC address and type over SNMP,
  and fetch the WL release number via snmpget. Progress is printed.

  Raises:
    ConnectError, CmdError: propagated from host_ssh_cmd while reading
      dmesg.boot (nothing is stored in that case).
  """
  print("# Processing host %s" % host)
  datadump = gformat.get_yaml(host)
  output = host_ssh_cmd(datadump['autogen_fqdn'], 'cat /var/run/dmesg.boot')
  _report_board(datadump, output)
  _probe_bridges(datadump)
  _fetch_wl_release(datadump)
  gformat.store_yaml(datadump)


def make_output(stdout, stderr):
  """Render remote command output for display.

  Every stdout line gets a '#STDOUT: ' prefix and every stderr line a
  '#STDERR: ' prefix; stdout lines come first. Lines are concatenated as-is,
  so they are expected to carry their own trailing newlines.
  """
  chunks = []
  for tag, stream in (('STDOUT', stdout), ('STDERR', stderr)):
    chunks.extend("#%s: %s" % (tag, line) for line in stream)
  return ''.join(chunks)

def ubnt_snmp(hostname):
  """Reset the SNMP settings of the UBNT bridge at *hostname* and activate them.

  Over SSH (as root): dumps the running config with ``mca-config get``, drops
  every line containing "snmp", appends the fixed snmp.* settings below,
  activates the result with ``mca-config activate`` and prints the remote
  stdout/stderr (the remote side echoes "ALL DONE" on success).

  Raises:
    ConnectError: a socket-level error or authentication failure occurred,
      consistent with host_ssh_cmd so the caller can report and continue.
  """
  lines = """\
snmp.community=public
snmp.contact=beheer@lijst.wirelessleiden.nl
snmp.location=WL
snmp.status=enabled\
"""
  cmd = 'mca-config get /tmp/get.cfg && grep -v snmp /tmp/get.cfg > /tmp/new.cfg && echo "%s" >> /tmp/new.cfg \
    && mca-config activate /tmp/new.cfg 1>/dev/null 2>/dev/null && echo "ALL DONE"' % lines
  ssh = paramiko.SSHClient()
  ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  try:
    ssh.connect(hostname, username='root', password=SSHPASS, timeout=3)
    _stdin, stdout, stderr = ssh.exec_command(cmd)
    out_lines = stdout.readlines()
    err_lines = stderr.readlines()
  except (socket.error, paramiko.AuthenticationException) as e:
    raise ConnectError(e)
  finally:
    # Release the SSH transport on every path, including failures.
    ssh.close()
  print(make_output(out_lines, err_lines))

def ubnt_keys(hostname):
  """Install the shared ``global_keys`` file as root's authorized_keys on *hostname*.

  Reads ``global_keys`` from gformat.NODE_DIR, streams it over SSH (as root)
  into ``.ssh/authorized_keys`` relative to the login directory, sets
  permissions on ``.ssh`` and ``.``, and runs ``cfgmtd -p /etc -w``
  (presumably to persist the change on the device -- confirm). The remote
  stdout/stderr is printed.

  Raises:
    ConnectError: a socket-level error or authentication failure occurred,
      consistent with host_ssh_cmd so the caller can report and continue.
  """
  with open(os.path.join(gformat.NODE_DIR, 'global_keys'), 'r') as keyfile:
    keys = keyfile.read()
  cmd = 'test -d .ssh || mkdir .ssh;\
    cat > .ssh/authorized_keys && \
    chmod 0700 .ssh && \
    chmod 0755 . && cfgmtd -p /etc -w'
  ssh = paramiko.SSHClient()
  ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
  try:
    ssh.connect(hostname, username='root', password=SSHPASS, timeout=3)
    stdin, stdout, stderr = ssh.exec_command(cmd)
    stdin.write(keys)
    stdin.flush()
    # Send EOF so the remote `cat` finishes and the rest of the chain runs.
    stdin.channel.shutdown_write()
    out_lines = stdout.readlines()
    err_lines = stderr.readlines()
  except (socket.error, paramiko.AuthenticationException) as e:
    raise ConnectError(e)
  finally:
    # Release the SSH transport on every path, including failures.
    ssh.close()
  print(make_output(out_lines, err_lines))

if __name__ == '__main__':
  # Top-level parser; global options must be given before the sub-command.
  parser = argparse.ArgumentParser(prog='Various WL management tools')
  parser.add_argument('--ask-pass', dest="ask_pass", action='store_true', help='Ask password if SSHPASS is not found')
  parser.add_argument('--filter', dest="use_filter", action='store_true', help='Treat the host definition as a filter')
  subparsers = parser.add_subparsers(help='sub-command help')

  parser_bridge = subparsers.add_parser('bridge', help='UBNT Bridge Management')
  parser_bridge.add_argument('action', type=str, choices=['keys', 'snmp', 'probe'])
  parser_bridge.add_argument('host', type=str)
  parser_bridge.set_defaults(func='bridge')

  parser_node = subparsers.add_parser('node', help='Proxy/Node/Hybrid Management')
  parser_node.add_argument('action', type=str, choices=['check',])
  parser_node.add_argument('host', type=str)
  parser_node.set_defaults(func='node')

  args = parser.parse_args()

  # Module-level assignment: the SSH helpers read this global.
  try:
    SSHPASS = os.environ['SSHPASS']
  except KeyError:
    print("#WARN: SSHPASS environ variable not found")
    if args.ask_pass:
      SSHPASS = getpass.getpass("WL root password: ")

  if args.use_filter:
    # Substring match of the given host against every known host name.
    hosts = [name for name in gformat.get_hostlist() if args.host in name]
  else:
    hosts = [args.host]

  for host in hosts:
    try:
      if args.func == 'bridge':
        if args.action == 'keys':
          ubnt_keys(host)
        elif args.action == 'snmp':
          ubnt_snmp(host)
        elif args.action == 'probe':
          ubnt_probe(host)
      elif args.func == 'node':
        if args.action == 'check':
          node_check(host)
    except ConnectError:
      print("#ERR: Connection failed to host %s" % host)
    except CmdError as e:
      # A failing remote command on one host must not abort a multi-host run.
      print("#ERR: Command failed on host %s: %s" % (host, e))
