#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Script for importing various stumble files in a modular fashion:
# - .ns1 (NetStumbler)
# - .gpsxml .netxml (Kismet)
# - DroidStumbler-*.csv (DroidStumbler)
#
# Rick van der Zwet <info@rickvanderzwet.nl>
#
from _mysql_exceptions import OperationalError
from django.core.management.base import BaseCommand,CommandError
from django.db import connection, transaction
from django.db.utils import IntegrityError
from gheat.models import *
from optparse import OptionParser, make_option
import datetime
import gzip
import os
import sys
import logging

from collections import defaultdict

import netstumbler
import kismet
import droidstumbler

# Module-level logger shared by every helper in this file. It starts at INFO;
# Command.handle may switch it to DEBUG depending on the verbosity option.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def open_file(file):
  """ Return a binary read handle for `file`.

  Names ending in '.gz' are opened through gzip so the caller reads the
  decompressed data; every other name is opened as a plain file.
  """
  opener = gzip.open if file.endswith('.gz') else open
  return opener(file, 'rb')


valid_prefixes = ['DroidStumbler-', 'Kismet-','ScanResult-']
def strip_prefix(filename):
  """ Remove known tool prefixes from the start of `filename`.

  The prefixes are tried once each, in list order, against the running
  result, so a later prefix can still be removed after an earlier one was.
  """
  remaining = filename
  for candidate in valid_prefixes:
    if not remaining.startswith(candidate):
      continue
    remaining = remaining[len(candidate):]
  return remaining


valid_suffixes = ['.gz', '.gpsxml', '.netxml', '.csv', '.ns1']
def strip_suffix(filename):
  """ Remove known file extensions from the end of `filename`.

  The suffixes are tried once each, in list order, against the running
  result; '.gz' comes first so 'x.netxml.gz' is reduced to 'x'.
  """
  remaining = filename
  for candidate in valid_suffixes:
    if not remaining.endswith(candidate):
      continue
    remaining = remaining[:len(remaining) - len(candidate)]
  return remaining


def strip_file(filename):
  """ Strip the known prefixes first and the known suffixes second """
  without_prefix = strip_prefix(filename)
  return strip_suffix(without_prefix)


# Date strings left over after stripping prefix and suffix, e.g.:
#   Kismet-20110805-15-37-30-1    -> '20110805-15-37-30-1'
#   ScanResult-2011-05-09-201117  -> '2011-05-09-201117'
strptime_choices = ['%Y%m%d-%H-%M-%S-1', '%Y-%m-%d-%H%M%S']
def process_date(datestr):
  """ Parse `datestr` with the first format in strptime_choices that fits.

  When no format fits, an error is logged and the current local time is
  returned instead of raising.
  """
  for fmt in strptime_choices:
    try:
      parsed = datetime.datetime.strptime(datestr, fmt)
    except ValueError:
      # Not this format, try the next one
      continue
    else:
      return parsed
  logger.error("Invalid date '%s', options: %s, using: now()", datestr, strptime_choices)
  return datetime.datetime.now()


def bulk_sql(sql_table, sql_values):
  """ Insert many pre-formatted rows with a single raw INSERT statement.

  Awful hack to ensure we can do mass imports into the Django database.

  sql_table  -- table name plus column list, pasted verbatim after INSERT INTO
  sql_values -- non-empty list of strings, each one a complete SQL tuple literal

  Returns the value returned by cursor.execute() for the statement.
  Raises ValueError when there is nothing to insert. OperationalError and
  IntegrityError are logged and then re-raised unchanged.
  """
  if not sql_values:
    raise ValueError("No data to import")

  # NOTE(review): the statement is assembled by string concatenation, so the
  # callers are fully responsible for quoting. Values that originate from
  # scanned data (e.g. SSIDs) are untrusted; a parameterised executemany()
  # would be safer but requires changing what the callers pass in.
  cursor = connection.cursor()
  try:
    # Make sure the special NULL is preserved: a quoted 'NULL' produced by the
    # callers' string formatting is turned back into the SQL keyword.
    sql = "INSERT INTO %s VALUES %s" % (sql_table, ','.join(sql_values).replace("'NULL'",'NULL'))
    count = cursor.execute(sql)
    transaction.commit_unless_managed()
  except OperationalError:
    # Only the first row is logged to keep the message readable
    logger.error("%s - %s ", sql_table, sql_values[0])
    raise
  except IntegrityError as e:
    logger.error("Unable to import - %s", e)
    raise
  return count


# Lazily filled cache mapping Organization name -> id; None means "not loaded yet"
organizations = None
def get_organization_id_by_ssid(ssid):
  """ Return the Organization id belonging to `ssid`, or the string 'NULL'.

  The name -> id mapping is fetched once and kept in the module-level
  `organizations` cache. The name itself comes from
  Organization.get_name_by_ssid(); when that yields nothing the literal
  string 'NULL' is returned instead of an id.

  XXX: Ideally this lookup would live on the Organization model itself, but
  XXX: that variant did not cache properly.
  """
  global organizations
  # Compare against None: an empty mapping is a valid, already loaded cache
  # and must not trigger a new query on every call.
  if organizations is None:
    organizations = dict(Organization.objects.all().values_list('name','id'))

  name = Organization.get_name_by_ssid(ssid)
  if not name:
    return 'NULL'
  return int(organizations[name])



def import_accespoints(ap_pool, counters):
  """ Bulk insert the access points of `ap_pool` that are not stored yet.

  ap_pool  -- mapping of BSSID -> (ssid, encryption)
  counters -- statistics dict; 'ap_added' is set to the bulk_sql() result
              when at least one access point had to be inserted

  Returns the counters dict.
  """
  # Work out which BSSIDs are still unknown to the database
  known_macs = Accespoint.objects.filter(mac__in=ap_pool.keys()).\
    values_list('mac', flat=True)
  missing_macs = set(ap_pool.keys()) - set(known_macs)
  if not missing_macs:
    return counters

  rows = []
  for mac in missing_macs:
    ssid, encryption = ap_pool[mac]
    # '%' is doubled in the SSID to avoid escaping at a later stage
    row = (mac.upper(), ssid.replace('%','%%'), encryption,
      get_organization_id_by_ssid(ssid))
    rows.append(str(row))

  table = ('gheat_accespoint (`mac`, `ssid`,'
    '      `encryptie`, `organization_id`)')
  counters['ap_added'] = bulk_sql(table, rows)
  return counters



def import_metingen(meetrondje, meting_pool, counters):
  """ Bulk insert the measurements of `meting_pool` for one meetrondje.

  meetrondje  -- object whose `id` is stored with every measurement
  meting_pool -- mapping of (bssid, lat, lon) -> collection of signal values;
                 only the maximum signal per key is stored
  counters    -- statistics dict, updated in place:
                 'meting_ignored' grows for BSSIDs known as wireless clients,
                 'meting_failed' grows for BSSIDs without an access point,
                 'meting_added' is set to the bulk_sql() result

  Returns the counters dict.
  """
  # Number of dropped signals per BSSID that has no matching access point
  bssid_failed = defaultdict(int)

  # The same BSSID shows up at many positions; query each one only once
  bssid_list = list(set(key[0] for key in meting_pool))

  # Build mapping for meting import
  mac2id = {}
  for mac, ap_id in Accespoint.objects.filter(mac__in=bssid_list).\
    values_list('mac','id'):
    mac2id[mac] = int(ap_id)

  clients = set(WirelessClient.objects.filter(mac__in=bssid_list).\
    values_list('mac', flat=True))

  sql_values = []
  for key in meting_pool:
    bssid, lat, lon = key
    signals = meting_pool[key]
    if bssid in clients:
      counters['meting_ignored'] += len(signals)
    elif bssid not in mac2id:
      counters['meting_failed'] += len(signals)
      bssid_failed[bssid] += len(signals)
    else:
      item = str((int(meetrondje.id), mac2id[bssid], float(lat),
        float(lon), max(signals)))
      sql_values.append(item)

  # Report the most frequently missing BSSIDs first
  for bssid, count in sorted(bssid_failed.items(),
      key=lambda item: item[1], reverse=True):
    logger.debug("Missing BSSID %s found %3s times", bssid, count)

  if sql_values:
    table = ('gheat_meting (`meetrondje_id`,'
      '      `accespoint_id`, `lat`, `lng`, `signaal`)')
    counters['meting_added'] = bulk_sql(table, sql_values)
  return counters


def import_clients(client_pool, counters):
  """ Bulk insert the wireless clients of `client_pool` that are not stored yet.

  client_pool -- mapping keyed by client MAC address
  counters    -- statistics dict; 'client_added' is set to the bulk_sql()
                 result when at least one client had to be inserted

  Returns the counters dict.
  """
  # Work out which MAC addresses are still unknown to the database
  known_macs = WirelessClient.objects.filter(mac__in=client_pool.keys()).values_list('mac', flat=True)
  missing_macs = set(client_pool.keys()) - set(known_macs)

  if missing_macs:
    rows = ["('%s')" % mac.upper() for mac in missing_macs]
    counters['client_added'] = bulk_sql('gheat_wirelessclient (`mac`)', rows)

  return counters





class Command(BaseCommand):
  """ Management command importing one or more stumble files.

  Every file becomes its own MeetRondje; its access points, wireless clients
  and measurements are bulk imported and a summary is logged per file.
  """
  args = '<netstumber.ns1>[.gz] [netstumber2.ns1[.gz]  netstumber3.ns1[.gz] ...]'
  option_list = BaseCommand.option_list + (
    make_option('-k', '--kaart', dest='kaart', default='onbekend',
      help="Kaart gebruikt"),
    make_option('-m', '--meetrondje', dest='meetrondje', default=None),
    make_option('-g', '--gebruiker', dest='gebruiker', default='username',
      help='Naam van de persoon die de meting uitgevoerd heeft'),
    make_option('-e', '--email', dest='email', default='foo@bar.org',
      help='Email van de persoon die de meting uitgevoerd heeft'),
    make_option('-d', '--datum', dest='datum', default=None,
      help="Provide date in following format: '%Y%m%d-%H-%M-%S-1', by "
      "      default it will be generated from the filename"),
  )

  @staticmethod
  def _order_files(filenames):
    """ Return filenames ordered netxml, gpsxml, ns1, then everything else.

    The sort is stable, so files of the same kind keep their command line
    order and no file is listed more often than it was given.
    """
    markers = ('netxml', 'gpsxml', 'ns1')
    def rank(name):
      for position, marker in enumerate(markers):
        if marker in name:
          return position
      return len(markers)
    return sorted(filenames, key=rank)

  @staticmethod
  def _select_parser(filename):
    """ Return the parser function matching `filename`, or None if unknown """
    if 'ns1' in filename:
      return netstumbler.process_ns1
    if 'gpsxml' in filename:
      return kismet.process_gpsxml
    if 'netxml' in filename:
      return kismet.process_netxml
    if 'ScanResult' in filename:
      return droidstumbler.process_csv
    return None

  def handle(self, *args, **options):
    """ Import every file in `args`.

    Raises CommandError when no file is given, when a file does not exist or
    when the format of a file is not recognized.
    """
    # The verbosity option may arrive as a string (optparse choice option),
    # so convert before comparing.
    if int(options.get('verbosity', 1)) >= 2:
      logger.setLevel(logging.DEBUG)
    if not args:
      self.print_help(sys.argv[0],sys.argv[1])
      raise CommandError("Not all arguments are provided")

    # Process first the netxml and the gpsxml files and then the rest
    args = self._order_files(args)
    logger.debug("Parsing files in the following order: %s", args)

    # Make sure all files exist before anything is imported
    for filename in args:
      if not os.path.isfile(filename):
        raise CommandError("file '%s' does not exists" % filename)

    def get_date(filename):
      """ Date of the meetrondje: --datum ('now' allowed) or from the filename """
      if options['datum'] is None:
        return process_date(strip_file(os.path.basename(filename)))
      if options['datum'] == 'now':
        return datetime.datetime.now()
      return process_date(options['datum'])

    def get_meetrondje(filename):
      """ Name of the meetrondje: --meetrondje or derived from the filename """
      if options['meetrondje'] is None:
        return strip_suffix(os.path.basename(filename))
      return options['meetrondje']

    # Get Gheat Objects, pre-req
    g, created = Gebruiker.objects.get_or_create(naam=options['gebruiker'],
      email=options['email'])
    a, created = Apparatuur.objects.get_or_create(kaart=options['kaart'])

    for filename in args:
      logger.info("Processing '%s'", filename)
      # NOTE(review): a Kismet .netxml/.gpsxml pair yields the same date and
      # name here, so the second file of the pair looks "already imported"
      # and is skipped -- confirm whether that is intended.
      mr, created = MeetRondje.objects.get_or_create(
        datum=get_date(filename), naam=get_meetrondje(filename),
        gebruiker=g, apparatuur=a)
      if not created:
        logger.error("Meetrondje '%s' already imported",  mr)
        continue

      counters = {
        'ap_added' : 0, 'ap_total' : 0,
        'ap_failed' : 0, 'ap_ignored' : 0,
        'client_added' : 0, 'client_total' : 0,
        'client_failed' : 0, 'client_ignored' : 0,
        'meting_added' : 0, 'meting_total' : 0,
        'meting_failed' : 0, 'meting_ignored' : 0
        }
      logger.info('Meetrondje: %s', mr)

      # Pick the parser before opening the file so no handle is left open
      # when the format is unknown.
      parser = self._select_parser(filename)
      if parser is None:
        raise CommandError("file '%s' format not recognized" % filename)

      fh = open_file(filename)
      try:
        (counters, ap_pool, client_pool, meting_pool) = parser(fh, counters)
      finally:
        fh.close()

      if ap_pool:
        counters = import_accespoints(ap_pool, counters)
      if client_pool:
        counters = import_clients(client_pool, counters)
      if meting_pool:
        counters = import_metingen(mr, meting_pool, counters)

      logger.info("summary accespoints: total:%(ap_total)-6s added:%(ap_added)-6s failed:%(ap_failed)-6s ignored:%(ap_ignored)-6s" % counters)
      logger.info("summary client     : total:%(client_total)-6s added:%(client_added)-6s failed:%(client_failed)-6s ignored:%(client_ignored)-6s" % counters)
      logger.info("summary metingen   : total:%(meting_total)-6s added:%(meting_added)-6s failed:%(meting_failed)-6s ignored:%(meting_ignored)-6s" % counters)
