#!/usr/bin/python
#-*- coding: utf-8 -*-

import graph_tool.all as gt
from numpy.random import random
import copy

class Cost(object):
    """Construction cost combining several criteria: distance from the
    origin (number of steps), terrain following, etc.

    Costs add component-wise and are ordered lexicographically: first by
    ``suivi``, then ``longueur``, then ``virage``. Rich comparison methods
    are defined (not only ``__cmp__``) so ordering also works on Python 3,
    where ``__cmp__`` and the ``cmp`` builtin no longer exist.
    """

    # Sum of the deviations from optimal terrain following
    suivi = 0

    # Number of cells traversed
    longueur = 0

    # Number of turns
    virage = 0

    def __init__(self, suivi=0, longueur=0, virage=0):
        """Create a cost; every criterion defaults to 0 (the additive identity)."""
        self.suivi = suivi
        self.longueur = longueur
        self.virage = virage

    def _cle(self):
        """Return the lexicographic comparison key ``(suivi, longueur, virage)``."""
        return (self.suivi, self.longueur, self.virage)

    def __add__(self, autre):
        """Return the component-wise sum of ``self`` and ``autre`` as a new
        cost; ``self`` is left unchanged."""
        retour = copy.copy(self)
        retour += autre
        return retour

    def __iadd__(self, autre):
        """Add ``autre`` to ``self`` component-wise, in place, and return ``self``."""
        self.suivi += autre.suivi
        self.longueur += autre.longueur
        self.virage += autre.virage

        return self

    def __cmp__(self, autre):
        """Python 2 three-way comparison: -1, 0 or 1, lexicographic on
        (suivi, longueur, virage). Written without the ``cmp`` builtin."""
        gauche = self._cle()
        droite = (autre.suivi, autre.longueur, autre.virage)
        return (gauche > droite) - (gauche < droite)

    def __eq__(self, autre):
        if not isinstance(autre, Cost):
            return NotImplemented
        return self._cle() == autre._cle()

    def __ne__(self, autre):
        # Python 2 does not derive __ne__ from __eq__, so spell it out.
        resultat = self.__eq__(autre)
        if resultat is NotImplemented:
            return resultat
        return not resultat

    def __lt__(self, autre):
        if not isinstance(autre, Cost):
            return NotImplemented
        return self._cle() < autre._cle()

    def __le__(self, autre):
        if not isinstance(autre, Cost):
            return NotImplemented
        return self._cle() <= autre._cle()

    def __gt__(self, autre):
        if not isinstance(autre, Cost):
            return NotImplemented
        return self._cle() > autre._cle()

    def __ge__(self, autre):
        if not isinstance(autre, Cost):
            return NotImplemented
        return self._cle() >= autre._cle()

    def __hash__(self):
        # Hash the same tuple that equality compares, so costs that compare
        # equal (e.g. 0 and 0.0) also hash equal. Costs are mutable through
        # __iadd__: do not mutate one that is used as a dict/set key.
        return hash(self._cle())

    def __repr__(self):
        """Compact form ``suivi|longueur|virage``."""
        return "|".join(str(valeur) for valeur in self._cle())

    def __str__(self):
        """Human-readable form listing each criterion by name."""
        return "suivi : %s longueur : %s virage : %s" % (
            self.suivi, self.longueur, self.virage)


    
# Cost with every criterion set to +infinity; passed to the A* search as its
# "infinity" and used as the initial dist/cost of freshly created vertices.
INFINITE_COST = Cost()
INFINITE_COST.suivi = INFINITE_COST.longueur = INFINITE_COST.virage = float("inf")



class HammingVisitor(gt.AStarVisitor):
    """A* visitor that grows an implicit graph of bit vectors during the search.

    Each vertex stores a bit vector in ``state``. Examining a vertex creates,
    on demand, one neighbour per bit position (the same vector with that bit
    flipped) and connects it with an edge whose ``Cost`` criteria are random
    integers in ``[0, 15)``. The search is stopped as soon as an edge leading
    to ``target`` is relaxed.
    """

    def __init__(self, g, target, state, weight, dist, cost):
        """Keep references to the graph and the property maps shared with A*.

        :param g: graph that is extended while the search runs.
        :param target: bit vector to reach (compared against ``state`` values).
        :param state: vertex property holding each vertex's bit vector.
        :param weight: edge property receiving the ``Cost`` of each new edge.
        :param dist: vertex distance property, initialised for new vertices.
        :param cost: vertex A* cost property, initialised for new vertices.
        """
        self.g = g
        self.state = state
        self.target = target
        self.weight = weight
        self.dist = dist
        self.cost = cost
        # Bit vector (as a tuple) -> the vertex created for it, so that each
        # state is represented by a single vertex.
        self.visited = {}

    def examine_vertex(self, u):
        """Materialise every neighbour of ``u`` (one flipped bit each) and
        the edges linking ``u`` to them."""
        # Copy u's state once into a plain list; u's state is not modified
        # below (only newly created vertices get a state assigned).
        bits_u = list(self.state[u])
        for i in range(len(bits_u)):
            nstate = list(bits_u)
            nstate[i] ^= 1
            key = tuple(nstate)
            if key in self.visited:
                v = self.visited[key]
            else:
                v = self.g.add_vertex()
                self.visited[key] = v
                self.state[v] = nstate
                # NOTE(review): all new vertices share the single
                # INFINITE_COST object; this is only safe while nothing
                # mutates it in place (Cost.__add__ copies, __iadd__ does not).
                self.dist[v] = self.cost[v] = INFINITE_COST
            # Add the u-v edge only if it does not exist yet.
            for e in u.out_edges():
                if e.target() == v:
                    break
            else:
                e = self.g.add_edge(u, v)
                # Draw the criteria in the same order as before
                # (suivi, longueur, virage) to keep the random stream unchanged.
                poids = Cost()
                poids.suivi = int(15 * random())
                poids.longueur = int(15 * random())
                poids.virage = int(15 * random())
                self.weight[e] = poids
        self.visited[tuple(bits_u)] = u

    def edge_relaxed(self, e):
        """Stop the search once an edge reaching the target state is relaxed,
        recording the target's vertex in ``visited`` first."""
        if self.state[e.target()] == self.target:
            self.visited[tuple(self.target)] = e.target()
            raise gt.StopSearch()



def h(v, target, state):
    """A* heuristic: estimate the remaining cost from vertex ``v`` to ``target``.

    Only the ``suivi`` criterion is estimated: half the element-wise absolute
    difference summed between the bit vector of ``v`` and ``target`` (the
    Hamming distance, assuming both hold 0/1 values). Under Python 2 the
    division by 2 truncates for integer sums. ``longueur`` and ``virage``
    are left at their default of 0.

    NOTE(review): edge ``suivi`` weights in this script are drawn as
    ``int(15 * random())`` and can be 0, so this estimate is not guaranteed
    to be a lower bound on the true remaining cost. Confirm whether an
    admissible heuristic is intended.
    """
    estimation = Cost()
    ecarts = abs(state[v].a - target)
    estimation.suivi = sum(ecarts) / 2
    return estimation



if __name__ == "__main__":
    # Implicit A* search on the 10-bit hypercube, from the all-zeros vector
    # to the all-ones vector; the visitor creates vertices and edges on demand.
    g = gt.Graph(directed=False)
    state = g.new_vertex_property("vector<bool>")
    source = g.add_vertex()
    state[source] = [0] * 10
    target = [1] * 10
    weight = g.new_edge_property("object")
    dist = g.new_vertex_property("object")
    cost = g.new_vertex_property("object")
    visitor = HammingVisitor(g, target, state, weight, dist, cost)
    # The start vertex is at zero cost from itself.
    dist[source] = Cost()
    dist, pred = gt.astar_search(
        g, source, weight, visitor,
        dist_map=dist,
        cost_map=cost,
        heuristic=lambda sommet: h(sommet, list(target), state),
        zero=Cost(),
        infinity=INFINITE_COST,
        implicit=True)

    # Default styling for everything the search materialised:
    # black edges of width 2, white vertices.
    ecolor = g.new_edge_property("string")
    vcolor = g.new_vertex_property("string")
    ewidth = g.new_edge_property("double")
    ewidth.a = 2
    for arete in g.edges():
        ecolor[arete] = "black"
    for sommet in g.vertices():
        vcolor[sommet] = "white"

    # Walk the predecessor map back from the target to the source, painting
    # path vertices black and path edges red and thicker.
    courant = visitor.visited[tuple(target)]
    while courant != source:
        vcolor[courant] = "black"
        parent = g.vertex(pred[courant])
        for arete in courant.out_edges():
            if arete.target() == parent:
                ecolor[arete] = "red"
                ewidth[arete] = 4
        courant = parent
    vcolor[courant] = "black"
    gt.graph_draw(g, size=(10, 10), vsize=0.25, vprops={"fillcolor": vcolor},
                  eprops={"color": ecolor}, penwidth=ewidth,
                  output="astar-implicit.pdf")
