#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Script for importing various stumble files in a modular fashion:
# - .ns1 (NetStumbler)
# - .gpsxml .netxml (Kismet)
# - DroidStumbler-*.csv (DroidStumbler)
#
# Rick van der Zwet <info@rickvanderzwet.nl>
#
from _mysql_exceptions import OperationalError
from django.db import connection, transaction
from django.db.utils import IntegrityError
from gheat.models import *
import gzip
import os
import sys
import logging

from collections import defaultdict

import netstumbler
import kismet
import droidstumbler

# Module-level logger, named after this module; used by all import helpers below.
logger = logging.getLogger(__name__)

def open_file(file):
  """ Return a binary read handle for `file`.

  Names ending in '.gz' are opened through gzip, everything else with the
  builtin open().
  """
  opener = gzip.open if file.endswith('.gz') else open
  return opener(file, 'rb')




def bulk_sql(sql_table, sql_values):
  """ Awful hack to ensure we can do mass imports into the Django databases.

  Builds one multi-row INSERT statement and executes it on the Django
  database connection.

  sql_table  -- table name followed by the column list; pasted verbatim
                after 'INSERT INTO'.
  sql_values -- non-empty sequence of pre-formatted SQL value tuples
                (strings), joined with commas.

  Returns the result of cursor.execute(); callers store it as the number of
  added rows.
  Raises ValueError when sql_values is empty. OperationalError and
  IntegrityError are logged and re-raised.

  NOTE(review): the statement is assembled by plain string formatting, so
  callers are fully responsible for quoting and escaping the values.
  """
  if not sql_values:
    raise ValueError("No data to import")

  cursor = connection.cursor()
  try:
    # Make sure the special NULL is preserved: the quoted marker 'NULL' is
    # turned into a real SQL NULL.
    # NOTE(review): this also rewrites any genuine string value 'NULL'.
    joined_values = ','.join(sql_values).replace("'NULL'", 'NULL')
    sql = "INSERT INTO %s VALUES %s" % (sql_table, joined_values)
    count = cursor.execute(sql)
    transaction.commit_unless_managed()
  except OperationalError:
    # Log only the first value tuple to keep the message readable
    logger.error("%s - %s ", sql_table, sql_values[0])
    raise
  except IntegrityError as e:
    logger.error("Unable to import - %s", e)
    raise
  return count


# Lazily filled name -> id lookup table, shared by all calls below
organizations = None
def get_organization_id_by_ssid(ssid):
  """ Return the Organization id belonging to `ssid`, or the string 'NULL'.

  The name -> id table is read from the database whenever the module-level
  cache is still empty. 'NULL' is returned when
  Organization.get_name_by_ssid() yields no name for the SSID.

  XXX: This should technically be part of the Organization model itself, but
  XXX: that does not cache properly.
  """
  global organizations
  if not organizations:
    name_id_pairs = Organization.objects.all().values_list('name','id')
    organizations = dict(name_id_pairs)

  name = Organization.get_name_by_ssid(ssid)
  if name:
    return int(organizations[name])
  return 'NULL'



def import_accespoints(ap_pool, counters):
  """ Bulk-insert the access points from ap_pool that are not stored yet.

  ap_pool  -- mapping of BSSID to a (ssid, encryption) pair.
  counters -- dict of counters; 'ap_added' is set to the result of the bulk
              insert when something had to be inserted.

  Returns the counters dict.
  """
  # MAC addresses which already have an Accespoint row
  known = Accespoint.objects.filter(mac__in=ap_pool.keys()).values_list(
    'mac', flat=True)
  missing = set(ap_pool.keys()) - set(known)
  if not missing:
    return counters

  # One pre-formatted SQL value tuple per new access point
  rows = []
  for bssid in missing:
    ssid, encryption = ap_pool[bssid]
    # Doubling '%' in the SSID avoids the need for escaping at a later stage
    row = (bssid.upper(), ssid.replace('%','%%'), encryption,
      get_organization_id_by_ssid(ssid))
    rows.append(str(row))

  counters['ap_added'] = bulk_sql(
    'gheat_accespoint (`mac`, `ssid`,'
    '      `encryptie`, `organization_id`)', rows)
  return counters



def import_metingen(meetrondje, meting_pool, counters):
  """ Bulk-insert the measurements from meting_pool for `meetrondje`.

  meetrondje  -- object whose `id` is stored as meetrondje_id on every row.
  meting_pool -- mapping of (bssid, lat, lon) to a sequence of signal
                 values; only the strongest signal per key is stored.
  counters    -- dict of counters, updated in place.

  Signals of known wireless clients are counted as 'meting_ignored'. Signals
  of a BSSID without Accespoint row, or whose strongest value lies outside
  MIN_SIGNAL..MAX_SIGNAL, are counted as 'meting_failed'. 'meting_added' is
  set to the result of the bulk insert when rows were inserted.

  Returns the counters dict.
  """
  # Number of dropped signals per BSSID that has no Accespoint row
  bssid_failed = defaultdict(int)

  bssid_list = [key[0] for key in meting_pool]
  # Build mapping for meting import
  mac2id = dict(
    (mac, int(ap_id)) for mac, ap_id in
    Accespoint.objects.filter(mac__in=bssid_list).values_list('mac','id'))

  # MAC addresses known to be wireless clients instead of access points
  clients = set(WirelessClient.objects.filter(mac__in=bssid_list).
    values_list('mac', flat=True))

  sql_values = []
  for (bssid, lat, lon), signals in meting_pool.items():
    final_signal = max(signals)
    if bssid in clients:
      counters['meting_ignored'] += len(signals)
    elif bssid not in mac2id:
      counters['meting_failed'] += len(signals)
      bssid_failed[bssid] += len(signals)
    elif final_signal < MIN_SIGNAL or final_signal > MAX_SIGNAL:
      counters['meting_failed'] += len(signals)
    else:
      item = str((int(meetrondje.id), mac2id[bssid], float(lat),
        float(lon), final_signal))
      sql_values.append(item)

  # Report the unknown BSSIDs, most frequently seen first
  for bssid, count in sorted(bssid_failed.items(),
      key=lambda item: item[1], reverse=True):
    logger.debug("Missing BSSID %s found %3s times", bssid, count)

  if sql_values:
    counters['meting_added'] = bulk_sql(
      'gheat_meting (`meetrondje_id`,'
      '      `accespoint_id`, `latitude`, `longitude`, `signaal`)', sql_values)
  return counters


def import_clients(client_pool, counters):
  """ Bulk-insert the wireless clients from client_pool that are not stored yet.

  client_pool -- mapping keyed by client MAC address.
  counters    -- dict of counters; 'client_added' is set to the result of
                 the bulk insert when something had to be inserted.

  Returns the counters dict.
  """
  # MAC addresses which already have a WirelessClient row
  known = WirelessClient.objects.filter(
    mac__in=client_pool.keys()).values_list('mac', flat=True)
  missing = set(client_pool.keys()) - set(known)

  if missing:
    # One single-column SQL value tuple per new client
    rows = ["('%s')" % mac.upper() for mac in missing]
    counters['client_added'] = bulk_sql('gheat_wirelessclient (`mac`)', rows)

  return counters


def import_file(filename,meetrondje):
  """ Import a file (on disk).

  The parser is selected by a marker found anywhere in `filename`:
  'ns1', 'gpsxml', 'netxml' or 'ScanResult'. The pools returned by the
  parser are then written to the database.

  filename   -- path of the (optionally gzip-compressed) file to import.
  meetrondje -- passed on to import_metingen() for the parsed measurements.

  Returns the dict of counters. The counters are returned unchanged when the
  file is empty, and as they were before parsing when the parser raises
  IOError. An unrecognized file name is logged as an error and imports
  nothing.
  """
  counters = {
    'ap_added' : 0, 'ap_total' : 0,
    'ap_failed' : 0, 'ap_ignored' : 0,
    'client_added' : 0, 'client_total' : 0,
    'client_failed' : 0, 'client_ignored' : 0,
    'meting_added' : 0, 'meting_total' : 0,
    'meting_failed' : 0, 'meting_ignored' : 0
    }

  if os.path.getsize(filename) == 0:
    logger.error("Cannot parse empty files")
    return counters

  fh = open_file(filename)
  try:
    if 'ns1' in filename:
      (counters, ap_pool, client_pool, meting_pool) = netstumbler.process_ns1(fh, counters)
    elif 'gpsxml' in filename:
      (counters, ap_pool, client_pool, meting_pool) = kismet.process_gpsxml(fh, counters)
    elif 'netxml' in filename:
      (counters, ap_pool, client_pool, meting_pool) = kismet.process_netxml(fh, counters)
    elif 'ScanResult' in filename:
      (counters, ap_pool, client_pool, meting_pool) = droidstumbler.process_csv(fh, counters)
    else:
      (ap_pool, client_pool, meting_pool) = (None, None, None)
      logger.error("file '%s' format not recognized", filename)
  except IOError as e:
    logger.error("File invalid: %s", e)
    return counters
  finally:
    # Parsing is done (or failed); never leak the file handle
    fh.close()

  # Access points and clients go first: import_metingen() looks both up
  if ap_pool:
    counters = import_accespoints(ap_pool, counters)
  if client_pool:
    counters = import_clients(client_pool, counters)
  if meting_pool:
    counters = import_metingen(meetrondje, meting_pool, counters)

  # A single dict argument lets logging fill in the %(name)s fields lazily
  logger.debug("summary accespoints: total:%(ap_total)-6s added:%(ap_added)-6s failed:%(ap_failed)-6s ignored:%(ap_ignored)-6s", counters)
  logger.debug("summary client     : total:%(client_total)-6s added:%(client_added)-6s failed:%(client_failed)-6s ignored:%(client_ignored)-6s", counters)
  logger.debug("summary metingen   : total:%(meting_total)-6s added:%(meting_added)-6s failed:%(meting_failed)-6s ignored:%(meting_ignored)-6s", counters)

  return counters
