(**
  Experimental channel assigning program using an evolutionary algorithm.

  The scoring function is simple:
    - the score of a combination of two interfaces is 1 if there are two or
      more channels between them, -30 if there is 1 channel between them, -70
      if there are no channels between them (they're on adjacent channels) and
      -100 if they're on the same channel (the values live in scoretable)
    - the score for a node is the sum of the scores of all the combinations of
      interfaces for that node, minus any node-specific penalties (eg.
      channels 7, 8 and 9 are unusable and score -1000 on node x), scaled by
      the number of interfaces to give more weight to larger nodes (the
      assumption being that larger nodes are more important nodes)
    - the total score of the network is the sum of the score of the nodes,
      minus any network-wide penalties (eg. the omnis for node x and node y
      can see each other, so they should be apart)

   special cases:
     - on node x, interfaces y and z are well separated and can live on the
       same channel
     - the link between x on y and z on w is very important, make sure it is
       well separated on both ends of the link
     - interface x on node y will bother interface z on node w and need to
       be on separate channels
     - on node x, channels y, z and w are not usable
     - node x y and z should be treated as one node
     - if at all possible, do not put accesspoints above channel 11

 - install an O'Caml compiler. /usr/ports/lang/ocaml/ in FreeBSD, ocaml in
   Debian.

 - compile with
  
    $ ocamlopt -o foo str.cmxa channelga.ml

 - run with

   $ ./foo f

 where f is the result of running prepconf.py on a file with a list of 
 paths to the wleiden.conf's to consider.

 Lodewijk Voge <lvoge@cs.vu.nl>
*)

(* a few constants suitable for twiddling *)
(** How large a population should the system maintain? *)
let population_size = 20
(** How many iterations without a new high score should the system keep
    running before it gives up? *)
and max_stagnant_iterations = 2000
(** What is the chance for an ESSID-to-channel assignment to mutate to a
    random channel? *)
and mutation_rate = 0.1
(** The most basic score table: a list of (predicate, score) rows. The rows
    are tried in order against the absolute channel distance and the first
    row whose predicate holds supplies the score. Mind the partial
    application: [(<=) 3] is [fun diff -> 3 <= diff], i.e. a distance of three
    or more (two or more channels in between). The first row used to read
    [(<=) 2], which also caught a distance of exactly 2 and made the second
    row unreachable. *)
and scoretable = [ (<=) 3,  1;
		   (=) 2, -30;
		   (=) 1, -70;
		   (=) 0, -100 ]

(* the type definitions. note that Caml has trouble with mutually recursive
   data structures. you can define them, you just can't ever instantiate them.
   this is why the fields in wi are all loose references by way of strings:
   a wi names its node and its essid instead of pointing at the node or group
   record itself *)
(** A single wireless interface. *)
type wi = {
	(* name of the interface *)
	wi_name: string;
	(* name of the node the interface was parsed for *)
	wi_nodename: string;
	(* essid of the interface; the channel of an interface is looked up by
	   this essid *)
	wi_essid: string
}
(** All interfaces that share one essid. *)
type group = {
	(* the essid shared by the members *)
	group_essid: string;
	(* the member interfaces; mutable because the parser prepends to it as
	   it meets more interfaces with the same essid *)
	mutable group_wis: wi list
}
(** A node and the interfaces that live on it. *)
type node = {
	node_name: string;
	node_wis: wi list
}
(** A configuration is an assignment of groups, identified by essid, to a
    channel, plus a score. The code should be careful not to use the score
    between mutating and re-evaluating. *)
type configuration = {
	(* score of [conf] as of the last evaluation *)
	mutable score: int;
	(* essid -> channel *)
	conf: (string, int) Hashtbl.t
}
(** Home-grown optional value: [Nothing] is the absent case, [Just x] carries
    a value. *)
type 'a maybe = Nothing | Just of 'a

(** The global nodes hash, mapping from node name to node struct. Filled by
    the parser; supernode handling removes the subnodes and adds the merged
    node. *)
let nodes = Hashtbl.create 4
(** The global groups hash, mapping from essid to group struct. *)
let groups = Hashtbl.create 4

(* Now the hashes for the special cases *)
(** Hash mapping from nodename to a list of winame's indicating the wi's that
    don't interfere with each other for the given node *)
let nointerference = Hashtbl.create 4
(** List of (nodename1, winame1, nodename2, winame2) tuples indicating a very
    important link that should be well-separated on both ends. It is filled by
    the special-case parser; nothing in the scoring code visible in this file
    reads it. *)
let importantlinks = ref []
(** Hash mapping from nodename to a list of unusable channels for that node *)
let unusable = Hashtbl.create 4
(** List of (nodename1, winame1, nodename2, winame2) tuples indicating two
    interfering interfaces on different nodes *)
let interference = ref []

(** Look up the score for the given channel distance: walk the rows of the
    table in order and hand back the score of the first row whose predicate
    accepts [diff]. A table without a matching row is a programming error and
    trips an assertion. *)
let rec runtable t diff =
	match t with
	  (accepts, points)::_ when accepts diff	-> points
	| _::rest				-> runtable rest diff
	| []					-> assert false

(* some convenience functions *)

(** Function composition: [compose f g] applies [g] first and [f] to its
    result. *)
let compose f g x = f (g x)
(** Infix spelling of [compose]. *)
let ($) f g = compose f g
(** Pair up two individual values *)
let maketuple a b = a, b
(** Shorthand for List.hd *)
let head l = List.hd l
(** Shorthand for List.tl *)
let tail l = List.tl l
(** Is the given integer even? *)
let even x = x mod 2 = 0
(*let shuffle = Array.sort (fun _ _ -> 1 - Random.int(2))*)
(** Unwrap a [Just]. Handing it [Nothing] is a programming error and trips an
    assertion. *)
let just = function
	  Just payload	-> payload
	| Nothing	-> assert false
(** All the keys of the given hashtable as a list, one entry per binding, in
    whatever order the table is traversed *)
let keys t =
	let collect key _ acc = key::acc in
	Hashtbl.fold collect t []
(** All the values of the given hashtable as a list, one entry per binding, in
    whatever order the table is traversed *)
let values t =
	let collect _ value acc = value::acc in
	Hashtbl.fold collect t []
(** Overwrite the first [Array.length src] cells of [dest] with the contents
    of [src]. [Array.blit] raises [Invalid_argument] when [dest] is shorter
    than [src]. *)
let copyarray src dest =
	let count = Array.length src in
	Array.blit src 0 dest 0 count

(** Does the list [l] contain the element [e]? Elements are matched with
    [compare], so this works on strings as well as on ints. *)
let in_list l e = List.exists (fun candidate -> compare e candidate = 0) l

(** All unordered pairs of items from the given list: every item is paired
    with each item that comes after it, so [[a; b; c]] yields
    [[(a, b); (a, c); (b, c)]]. *)
let rec combinations l =
	match l with
	  []		-> []
	| first::rest	->
		let with_first = List.map (fun other -> (first, other)) rest in
		with_first @ combinations rest

(** Score one pair of interfaces.
    [c] maps an essid to its channel, [unusable] lists the channels that may
    not be used and [nointerference] lists the names of interfaces that do
    not bother each other.
    The result is -10000 when either interface sits on an unusable channel,
    1 when both interfaces are named in [nointerference], and otherwise
    whatever [scoretable] gives for the distance between the two channels. *)
let wi_score c unusable nointerference wi1 wi2 =
	let channel1 = c wi1.wi_essid in
	let channel2 = c wi2.wi_essid in
	let forbidden channel = in_list unusable channel in
	let harmless wi = in_list nointerference wi.wi_name in
	if forbidden channel1 || forbidden channel2 then -10000
	else if harmless wi1 && harmless wi2 then 1
	else runtable scoretable (abs (channel1 - channel2))

(** Score a single node against a configuration ([c] maps essid to channel).
    Every pair of interfaces on the node is scored with [wi_score], using the
    nointerference and unusable lists registered for this node (empty when
    none are registered). The sum is multiplied by the number of interfaces
    so larger nodes weigh more. *)
let node_score c n =
	(* the special-case list stored for this node, or [] when there is none *)
	let lookup table = try Hashtbl.find table n.node_name
			   with Not_found -> [] in
	let quiet_wis = lookup nointerference in
	let bad_channels = lookup unusable in
	let pair_total =
		List.fold_left
			(fun total (wi1, wi2) ->
				total + wi_score c bad_channels quiet_wis wi1 wi2)
			0 (combinations n.node_wis) in
	pair_total * List.length n.node_wis

(** Score one user-specified pair of interfering interfaces, each given as a
    node name plus an interface name, against a configuration ([c] maps essid
    to channel). The score is what [scoretable] gives for the distance between
    the two channels. Raises [Not_found] when a node or interface name is
    unknown. *)
let score_interference c (nodename1, winame1, nodename2, winame2) =
	let named winame wi = compare wi.wi_name winame = 0 in
	let node1 = Hashtbl.find nodes nodename1 in
	let node2 = Hashtbl.find nodes nodename2 in
	let wi1 = List.find (named winame1) node1.node_wis in
	let wi2 = List.find (named winame2) node2.node_wis in
	let channel1 = c wi1.wi_essid in
	let channel2 = c wi2.wi_essid in
	runtable scoretable (abs (channel1 - channel2))

(** Score a whole configuration: [ns] is the list of nodes, [c] the
    essid-to-channel hash. The result is the sum of [node_score] over all
    nodes plus the sum of [score_interference] over all pairs in the global
    interference list. *)
let score_configuration ns c =
	let channel_of = Hashtbl.find c in
	let node_total =
		List.fold_left (fun sum n -> sum + node_score channel_of n) 0 ns in
	let interference_total =
		List.fold_left
			(fun sum pair -> sum + score_interference channel_of pair)
			0 !interference in
	node_total + interference_total

(** Build a random configuration: every key of [groups] gets a channel drawn
    uniformly from 1 to 12, and the result is scored with [evaluate]. The
    groups hash is passed in as a parameter; the original author noted that
    reading the global one directly from here yielded an empty hash. *)
let random_configuration groups evaluate =
	let assignment = Hashtbl.create 30 in
	let assign essid _ =
		let channel = 1 + Random.int 12 in
		Hashtbl.add assignment essid channel in
	Hashtbl.iter assign groups;
	{ score = evaluate assignment; conf = assignment }

(** Print the given essid-to-channel hash in readable form: one line per node
    of the global nodes hash, sorted by node name, listing every interface of
    the node together with the channel assigned to its essid. *)
let print_conf conf =
	let show_wi wi =
		wi.wi_name ^ ": " ^ string_of_int (Hashtbl.find conf wi.wi_essid) in
	let show_node n =
		let wis = String.concat ""
				(List.map (fun wi -> " " ^ show_wi wi) n.node_wis) in
		n.node_name ^ ": " ^ wis ^ "\n" in
	let by_name a b = compare a.node_name b.node_name in
	List.iter (fun n -> print_string (show_node n))
		  (List.sort by_name (values nodes))

(** n-point crossover operator. Pick [n] random points along the keys of the
    first parent plus one final point at the very end, sort the points, and
    fill the child by walking the keys in order, copying each stretch up to
    the next point from the parent attached to that point. The parents are
    attached to the points alternately before sorting.

    The child starts with score 0 and must be evaluated before its score is
    used. [c2] has to hold every key of [c1], otherwise [Hashtbl.find] raises
    [Not_found]. Empty parents produce an empty child. *)
let crossover n c1 c2 =
	(* an array gives constant-time access by position in the copy loop *)
	let keys1 = Array.of_list (keys c1.conf) in
	let numkeys1 = Array.length keys1 in
	(* Random.int rejects a bound of 0, so only draw when there are keys *)
	let pickpoint i = (if even i then c1.conf else c2.conf),
			  (if i < n && numkeys1 > 0 then Random.int numkeys1
			   else numkeys1) in
	let crosspoints = Array.init (n + 1) pickpoint in
	let result = Hashtbl.create numkeys1 in
	let i = ref 0 in
	Array.sort (fun a b -> compare (snd a) (snd b)) crosspoints;
	(* i runs over the keys exactly once across all crossover points *)
	Array.iter (fun (h, p) -> while !i < p do
					let key = keys1.(!i) in
					Hashtbl.add result key (Hashtbl.find h key);
					incr i
				  done) crosspoints;
	assert (Hashtbl.length result = numkeys1);
	{ score = 0; conf = result }

(** Generalized evolutionary algorithm driver.
      initialize: () -> configuration array
      recombine: configuration array -> configuration array
      mutate: configuration array -> configuration array
      evaluate: configuration array -> configuration array
      select: configuration array -> configuration array

    Every generation runs recombine, mutate, evaluate and select, in that
    order, and copies the outcome back over the population. [select] must
    return exactly [population_size] configurations with the best one in slot
    0. The loop stops once [max_stagnant_iterations] generations have passed
    without a new high score; progress is printed every tenth generation.

    The result is the best configuration, i.e. slot 0 of the final
    population. *)
let evolutionary_algorithm initialize recombine mutate evaluate select =
	let population = (evaluate $ initialize) () in
	let last_high_score = ref population.(0).score in
	let iterations_since_new_high_score = ref 0 in
	let generation = ref 0 in
	(*let iterate = recombine $ mutate $ evaluate $ select in*)
	(* composition applies right to left: recombine runs first, select last *)
	let iterate = select $ evaluate $ mutate $ recombine in
	while !iterations_since_new_high_score < max_stagnant_iterations do
		let newpop = iterate population in
		assert (Array.length newpop = population_size);
		copyarray newpop population;
		let high_score = population.(0).score in
		if high_score > !last_high_score then begin
			last_high_score := high_score;
			iterations_since_new_high_score := 0
		end;
		(* the best score is never expected to drop between generations *)
		assert (!last_high_score >= high_score);
		if !generation mod 10 = 0 then begin
			print_int !generation;
			print_string ": ";
			print_int !last_high_score;
			print_newline();
		end;
		incr iterations_since_new_high_score;
		incr generation
	done;
	population.(0)

(** BEGIN PARSING CODE *)

(** Given a filename, return a list of all the lines in the file with the
    given filename. Don't count on the order of the lines in the result (as it
    stands, the last line of the file comes first). The file is closed again
    before returning, also when reading fails halfway. Raises [Sys_error] when
    the file cannot be opened. *)
let snarf_lines fname =
	let infile = open_in fname in
	let result = ref [] in
	(try
		while true do
			result := (input_line infile)::!result
		done
	 with
	   End_of_file	-> close_in infile
	 | e		-> close_in_noerr infile; raise e);
	!result

(** Read the main input from the given filename. Every non-blank line holds a
    node name followed by (interface name, essid) pairs, all separated by
    spaces. Each line adds a node to the global nodes hash and registers its
    interfaces with the group of their essid in the global groups hash. Blank
    lines are skipped. A line with a dangling interface name that has no essid
    trips an assertion. *)
let parse_file fname =
	let spacere = Str.regexp " " in
	(* Given the name of the node currently being parsed, build the wi for
	   the given (wi name, essid) tuple and add it to the group of its
	   essid, creating that group when the essid is new. *)
	let parse_pair nodename (wname, essid) =
		let new_wi = { wi_name = wname;
			       wi_nodename = nodename;
			       wi_essid = essid } in
		(try
			let group = Hashtbl.find groups essid in
			group.group_wis <- new_wi::group.group_wis
		 with Not_found ->
			Hashtbl.add groups essid { group_essid = essid;
						   group_wis = [ new_wi ] });
		new_wi in
	(* chop a flat list of fields into consecutive pairs *)
	let rec makepairs l =
		match l with
		  []		-> []
		| _::[]		-> assert false
		| a::b::xs	-> (a, b)::(makepairs xs) in
	let parse_fields fields =
		match fields with
		  []			-> ()	(* blank line, nothing to do *)
		| nodename::rest	->
			let wis = List.map (parse_pair nodename) (makepairs rest) in
			let sorted_wis = List.sort compare wis in
			let node = { node_name = nodename;
				     node_wis = sorted_wis } in
			Hashtbl.add nodes nodename node in
	List.iter (parse_fields $ (Str.split spacere)) (snarf_lines fname)

(* the parsers for the special case components *)

(** Register a nointerference entry: the first field is the nodename, the
    remaining fields are the names of the interfaces on that node that do not
    bother each other. *)
let parse_nointerference fs =
	let winames = tail fs in
	let nodename = head fs in
	Hashtbl.add nointerference nodename winames
(** Take the first four fields of the given list and push them as one tuple
    onto the front of the given list reference. Extra fields are ignored;
    fewer than four fields raise [Failure]. *)
let parse_quadruplet l fs =
	match fs with
	  a::b::c::d::_	-> l := (a, b, c, d)::!l
	| _		-> failwith "nth"
(** Register the unusable channels of a node: the first field is the
    nodename, the remaining fields are channel numbers. A field that is not a
    number makes [int_of_string] raise [Failure]. *)
let parse_unusable fs =
	let channels = List.map int_of_string (tail fs) in
	let nodename = head fs in
	Hashtbl.add unusable nodename channels
(** Merge several nodes into one supernode. The first field is the supernode
    name, the rest are the names of the subnodes. The subnodes are removed
    from the nodes hash and a new node is added under the supernode name. It
    carries the interfaces of all subnodes, in the order given, each renamed
    to subnodename.interfacename. Raises [Not_found] for an unknown subnode
    name. *)
let parse_supernode fs =
	let nodename = head fs in
	let subnodenames = tail fs in
	(* every subnode is looked up before any of them is removed *)
	let subnodes =
		List.map (fun name -> Hashtbl.find nodes name) subnodenames in
	List.iter (fun name -> Hashtbl.remove nodes name) subnodenames;
	let qualify sub w = { w with wi_name = sub.node_name ^ "." ^ w.wi_name } in
	let wis = List.concat
			(List.map (fun sub -> List.map (qualify sub) sub.node_wis)
				  subnodes) in
	Hashtbl.add nodes nodename { node_name = nodename; node_wis = wis }

(** Read the special-case configuration from the given filename. Every
    non-blank line starts with a keyword (nointerference, important,
    interference, unusable or supernode) followed by space-separated
    arguments for the matching parser.

    Reading stays best-effort: a file that cannot be opened is treated as
    empty, and a line with an unknown keyword or malformed arguments is
    reported on stderr and skipped. Previously any exception was swallowed
    silently and all remaining lines were dropped with it. *)
let parse_special_conf fname =
	let spacere = Str.regexp " " in
	let functable = [ "nointerference", parse_nointerference;
			  "important", parse_quadruplet importantlinks;
			  "interference", parse_quadruplet interference;
			  "unusable", parse_unusable;
			  "supernode", parse_supernode ] in
	let do_line line =
		match Str.split spacere line with
		  []		-> ()	(* blank line *)
		| keyword::args	->
			(* Not_found: unknown keyword or unknown node name;
			   Failure: too few fields or a channel that is not a
			   number *)
			try (List.assoc keyword functable) args
			with Not_found | Failure _ ->
				prerr_endline (fname ^ ": ignoring line: " ^ line) in
	(* the special-case file is optional *)
	let lines = try snarf_lines fname with Sys_error _ -> [] in
	List.iter do_line lines

(** END PARSING CODE *)

(** Program entry point: parse the input file named on the command line and
    the optional special.conf, run the evolutionary algorithm and print the
    best configuration found. Exits with status 1 and a usage message when no
    input file is given. *)
let main =
	if Array.length Sys.argv < 2 then begin
		prerr_endline ("usage: " ^ Sys.argv.(0) ^ " <prepared-config-file>");
		exit 1
	end;
	parse_file Sys.argv.(1);
	parse_special_conf "special.conf";
	Random.self_init();
	let all_nodes = values nodes in
	let evaluate_hash = score_configuration all_nodes in
	let initialize () =
		Array.init population_size
			   (fun _ -> random_configuration groups evaluate_hash) in
	(* recombination is switched off for now; the crossover-based version
	   is kept below for reference *)
	let recombine pop = pop in
(*
		let numoffspring = Random.int population_size in
		let children = Array.init numoffspring (fun _ -> 
			let father = pop.(Random.int population_size) in
			let mother = pop.(Random.int population_size) in
			crossover 2 father mother) in
		Array.append pop children in *)
	(* highest channel a mutation may pick for the given essid: 11 for a
	   group with a single interface, 13 otherwise. presumably the
	   single-interface groups are the access points that should stay at
	   or below channel 11 -- confirm *)
	let maxchannel essid =
		let group = Hashtbl.find groups essid in
		if List.length group.group_wis = 1 then 11
		else 13 in
	(* give every essid a mutation_rate chance of moving to a random
	   channel. the walk is over a snapshot of the keys because the
	   behaviour of Hashtbl.iter is unspecified when the table is modified
	   during the iteration *)
	let mutate_conf conf =
		let mutate_essid essid =
			if Random.float 1.0 < mutation_rate then
				Hashtbl.replace conf essid
						(1 + Random.int (maxchannel essid)) in
		List.iter mutate_essid (keys conf) in
	(* keep the originals and append one mutated copy of each *)
	let mutate population =
		let mutants = Array.map (fun c -> let hash = Hashtbl.copy c.conf in
						  mutate_conf hash;
						  { score = evaluate_hash hash;
						    conf = hash}) population in
		Array.append population mutants in
	let evaluate population =
		Array.iter (fun c -> c.score <- evaluate_hash c.conf) population;
		population in
	(* best score first, then cut back to population_size *)
	let select p =
		Array.sort (fun a b -> compare b.score a.score) p;
		(*shuffle p;*)
		Array.sub p 0 population_size in
	let best = evolutionary_algorithm initialize recombine mutate evaluate select in
	print_conf best.conf;;

main
