'''
A scenario-discovery-oriented implementation of CART. It is essentially
a wrapper around scikit-learn's implementation of CART.
'''
import io
import math
from io import StringIO
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import pandas as pd
from sklearn import tree
from . import scenario_discovery_util as sdutil
from ..util import get_module_logger
from ..util.ema_exceptions import EMAError
from pygments.unistring import cats
# Created on May 22, 2015
#
# .. codeauthor:: jhkwakkel <j.h.kwakkel (at) tudelft (dot) nl>
# public names of this module (used by ``from ... import *``)
__all__ = ['setup_cart',
           'CART']
# module-level logger, obtained through the workbench's logging helper
_logger = get_module_logger(__name__)
def setup_cart(results, classify, incl_unc=None, mass_min=0.05):
    """helper function for performing cart in combination with data
    generated by the workbench.

    Parameters
    ----------
    results : tuple of DataFrame and dict with numpy arrays
              the return from :meth:`perform_experiments`.
    classify : string, function or callable
               either a string naming the outcome of interest, in which
               case CART is run in regression mode on that outcome, or a
               callable that receives the outcomes dict and returns the
               classification, in which case CART is run in binary mode.
    incl_unc : list of strings, optional
               names of the experiment columns to keep; every other column
               is dropped. If None (default), all columns are kept.
    mass_min : float, optional
               minimum fraction of data points in a terminal leaf, passed
               on to :class:`CART`.

    Returns
    -------
    CART instance

    Raises
    ------
    TypeError
        if classify is not a string or a callable.
    """
    experiments, outcomes = results

    if incl_unc is not None:
        all_columns = set(experiments.columns.values.tolist())
        to_drop = all_columns.difference(incl_unc)
        experiments = experiments.drop(to_drop, axis=1)

    if isinstance(classify, str):
        # a named outcome is used as-is, hence regression
        mode, y = sdutil.RuleInductionType.REGRESSION, outcomes[classify]
    elif callable(classify):
        # a user-supplied function produces the classes
        mode, y = sdutil.RuleInductionType.BINARY, classify(outcomes)
    else:
        raise TypeError("unknown type for classify")

    return CART(experiments, y, mass_min, mode=mode)
class CART(sdutil.OutputFormatterMixin):
    '''CART algorithm

    can be used in a manner similar to PRIM. It provides access
    to the underlying tree, but it can also show the boxes described by the
    tree in a table or graph form similar to prim.

    Parameters
    ----------
    x : DataFrame
    y : 1D ndarray
    mass_min : float, optional
               a value between 0 and 1 indicating the minimum fraction
               of data points in a terminal leaf. Defaults to 0.05,
               identical to prim.
    mode : {BINARY, CLASSIFICATION, REGRESSION}
           indicates the mode in which CART is used. Binary indicates
           binary classification, classification is multiclass, and regression
           is regression.

    Attributes
    ----------
    boxes : list
            list of DataFrame box lims
    stats : list
            list of dicts with stats
    clf : DecisionTreeClassifier or DecisionTreeRegressor
          the fitted scikit-learn tree; None until :meth:`build_tree`
          has been called.

    Notes
    -----
    This class is a wrapper around scikit-learn's CART algorithm. It provides
    an interface to CART that is more oriented towards scenario discovery, and
    shares some methods with PRIM.

    Non-numeric columns of x are one-hot encoded before fitting; the
    resulting dummy features are named ``<column><sep><category>``.

    See also
    --------
    :mod:`prim`
    '''

    # separator between column name and category in dummy feature names
    sep = '!?!'

    def __init__(self, x, y, mass_min=0.05,
                 mode=sdutil.RuleInductionType.BINARY):
        ''' init
        '''
        # a 'scenario' column, if present, is not used as a feature
        try:
            x = x.drop(["scenario"], axis=1)
        except KeyError:
            pass

        self.x = x
        self.y = y
        self.mass_min = mass_min
        self.mode = mode
        # fitted tree, set by build_tree
        self.clf = None

        # scikit-learn needs a numeric array, so categorical variables are
        # replaced by one dummy variable per category. Integers are
        # treated as floats.
        dummies = pd.get_dummies(self.x, prefix_sep=self.sep)

        # dummy feature names contain str(category); keep a lookup from that
        # string back to the original category object so boxes can be
        # expressed in terms of the actual categories
        self.dummiesmap = {}
        for column, values in x.select_dtypes(exclude=np.number).items():
            self.dummiesmap[column] = {str(entry): entry
                                       for entry in values.unique()}

        self.feature_names = dummies.columns.values.tolist()
        self._x = dummies.values
        self._boxes = None
        self._stats = None

    @property
    def boxes(self):
        '''list of DataFrame box limits, one per leaf of the fitted tree
        (ordered by leaf node id). Row 0 holds the lower limits, row 1 the
        upper limits; categorical variables hold a set of categories.'''
        if self._boxes:
            return self._boxes

        # based on
        # http://stackoverflow.com/questions/20224526/how-to-extract-the-
        # decision-rules-from-scikit-learn-decision-tree
        assert self.clf is not None, "build_tree must be called first"

        fitted = self.clf.tree_
        left = fitted.children_left
        right = fitted.children_right
        threshold = fitted.threshold
        # leaf nodes carry a negative feature index (they do not split);
        # map those to None instead of indexing feature_names from the end,
        # which picks a wrong name or raises IndexError for a single feature
        features = [self.feature_names[i] if i >= 0 else None
                    for i in fitted.feature]

        # leaf nodes are the nodes without a left child
        leafs = np.argwhere(left == -1)[:, 0]

        box_init = sdutil._make_box(self.x)
        boxes = []
        for leaf in leafs:
            box = box_init.copy()
            for parent, direction in self._path_to(leaf, left, right):
                self._restrict_box(box, box_init, direction,
                                   threshold[parent], features[parent])
            boxes.append(box)

        self._boxes = boxes
        return self._boxes

    @staticmethod
    def _path_to(node, left, right):
        '''return the (parent, direction) pairs on the path from the root to
        node, ordered from the root downwards. direction is 'l' when the path
        continues into the left child of parent and 'r' otherwise. The root
        (node 0) has an empty path, which happens for a tree without splits.
        '''
        path = []
        while node != 0:
            parents = np.flatnonzero(left == node)
            if parents.size:
                direction = 'l'
            else:
                parents = np.flatnonzero(right == node)
                direction = 'r'
            parent = parents.item()
            path.append((parent, direction))
            node = parent
        path.reverse()
        return path

    def _restrict_box(self, box, box_init, direction, value, feature):
        '''restrict box in place according to a single split on feature.

        direction 'l' means feature <= value, 'r' means feature > value.
        '''
        if feature in box_init.columns:
            # numeric variable: tighten the upper ('l') or lower ('r') limit;
            # for integer variables round to the nearest valid integer so
            # the limit keeps the column's integer type
            is_int = pd.api.types.is_integer_dtype(self.x[feature].dtype)
            if direction == 'l':
                if is_int:
                    value = math.floor(value)
                box.loc[1, feature] = value
            else:
                if is_int:
                    value = math.ceil(value)
                box.loc[0, feature] = value
            return

        # dummy variable for a single category of a categorical variable
        unc, cat = feature.split(self.sep, 1)
        category = self.dummiesmap[unc][cat]
        # copy first: box_init.copy() does not copy the set objects stored
        # in the box, so modifying them in place would also modify box_init
        # and thereby every box created afterwards
        categories = set(box.loc[0, unc])
        if direction == 'l':
            # dummy <= threshold: this category is excluded
            categories.discard(category)
        else:
            # dummy > threshold: only this category remains
            categories &= {category}
        box.loc[:, unc] = [set(categories), set(categories)]

    @property
    def stats(self):
        '''list of dicts with the statistics of each box, computed with the
        statistics method that matches the mode'''
        if self._stats:
            return self._stats

        boxes = self.boxes
        box_init = sdutil._make_box(self.x)
        stat_method = self._boxstat_methods[self.mode]
        self._stats = [stat_method(self, box, box_init) for box in boxes]
        return self._stats

    def _binary_stats(self, box, box_init):
        '''coverage, density, gini, entropy, restricted dimensions and mass
        of box; assumes y marks cases of interest with 1 (or True)'''
        indices = sdutil._in_box(self.x, box)
        y_in_box = self.y[indices]
        box_coi = np.sum(y_in_box)
        density = box_coi / y_in_box.shape[0]
        gini = 1 - ((density ** 2) + ((1 - density) ** 2))
        # for a pure box log2(0) gives -inf and 0 * -inf gives nan;
        # nan_to_num maps that nan to 0, the entropy of a pure box
        with np.errstate(all='ignore'):
            entropy = -((density * np.log2(density)) +
                        ((1 - density) * np.log2(1 - density)))
        entropy = np.nan_to_num(entropy)

        boxstats = {'coverage': box_coi / np.sum(self.y),
                    'density': density,
                    'gini': gini,
                    'entropy': entropy,
                    'res dim': sdutil._determine_nr_restricted_dims(box,
                                                                    box_init),
                    'res dim names': sdutil._determine_restricted_dims(box,
                                                                    box_init),
                    'mass': y_in_box.shape[0] / self.y.shape[0]}
        return boxstats

    def _regression_stats(self, box, box_init):
        '''mean of y, mass and restricted dimensions of box'''
        indices = sdutil._in_box(self.x, box)
        y_in_box = self.y[indices]
        boxstats = {'mean': np.mean(y_in_box),
                    'mass': y_in_box.shape[0] / self.y.shape[0],
                    'res dim': sdutil._determine_nr_restricted_dims(box,
                                                                    box_init),
                    'res dim names': sdutil._determine_restricted_dims(box,
                                                                       box_init),
                    }
        return boxstats

    def _classification_stats(self, box, box_init):
        '''gini, mass, per-class counts (in sorted class order) and
        restricted dimensions of box'''
        indices = sdutil._in_box(self.x, box)
        y_in_box = self.y[indices]
        classes = sorted(set(self.y))
        counts = [y_in_box[y_in_box == ci].shape[0] for ci in classes]

        total_gini = 0
        for count in counts:
            total_gini += (count / y_in_box.shape[0]) ** 2
        gini = 1 - total_gini

        boxstats = {'gini': gini,
                    'mass': y_in_box.shape[0] / self.y.shape[0],
                    'box_composition': counts,
                    'res dim': sdutil._determine_nr_restricted_dims(box,
                                                                    box_init),
                    'res dim names': sdutil._determine_restricted_dims(box,
                                                                    box_init),
                    }
        return boxstats

    # statistics method per mode; must follow the method definitions above
    _boxstat_methods = {sdutil.RuleInductionType.BINARY: _binary_stats,
                        sdutil.RuleInductionType.REGRESSION: _regression_stats,
                        sdutil.RuleInductionType.CLASSIFICATION: _classification_stats}

    def build_tree(
            self,
            criterion="gini",
            max_depth=None,
            mass_min=None,
            min_samples_split=2,
    ):
        '''train CART on the data

        Parameters
        ----------
        criterion : {'gini', 'entropy'}, optional
                    split quality criterion; only used in the BINARY and
                    CLASSIFICATION modes.
        max_depth : int, optional
                    maximum depth of the tree, None means unlimited.
        mass_min : float, optional
                   if given, replaces the minimum leaf mass set at init.
        min_samples_split : int, optional
                            minimum number of samples required to split an
                            internal node.
        '''
        assert criterion in ["gini", "entropy"]
        if mass_min is not None:
            self.mass_min = mass_min

        # a new tree invalidates previously derived boxes and stats
        self._boxes = None
        self._stats = None

        # scikit-learn requires at least one sample per leaf
        min_samples = max(1, int(self.mass_min * self.x.shape[0]))

        if self.mode == sdutil.RuleInductionType.REGRESSION:
            self.clf = tree.DecisionTreeRegressor(
                max_depth=max_depth,
                min_samples_split=min_samples_split,
                min_samples_leaf=min_samples,
            )
        else:
            # min_impurity_split is not passed: it was removed from
            # scikit-learn in version 1.0
            self.clf = tree.DecisionTreeClassifier(
                criterion=criterion,
                min_samples_split=min_samples_split,
                splitter="best",
                max_depth=max_depth,
                min_samples_leaf=min_samples,
                min_weight_fraction_leaf=0.,
                max_features=None,
                random_state=None,
                max_leaf_nodes=None,
                min_impurity_decrease=0.,
                class_weight=None,
                ccp_alpha=0.0,
            )
        self.clf.fit(self._x, self.y)

    def show_tree(self, mplfig=True, format='png'):
        '''return a png of the tree

        Parameters
        ----------
        mplfig : bool, optional
                 if true (default) and format is 'png', returns a matplotlib
                 figure with the tree, otherwise, it returns the output as
                 bytes
        format : {'png', 'svg'}, default 'png'
                 Gives a format of the output.

        Raises
        ------
        TypeError
            if format is not 'png' or 'svg'.
        EMAError
            if the generated dot data describes more than one graph.
        '''
        assert self.clf is not None, "build_tree must be called first"
        if format not in ('png', 'svg'):
            raise TypeError('''format must be in {'png', 'svg'}''')

        try:
            import pydotplus as pydot
        except ImportError:
            import pydot  # dirty hack for read the docs

        dot_data = StringIO()
        tree.export_graphviz(
            self.clf, out_file=dot_data,
            feature_names=self.feature_names,
            filled=True,
        )
        graphs = pydot.graph_from_dot_data(dot_data.getvalue())

        # recent pydot versions always return a list of graphs, while
        # pydotplus may return a single graph object; normalize to a list
        if not isinstance(graphs, list):
            graphs = [graphs]
        if len(graphs) > 1:
            raise EMAError("trying to visualize more than one tree")
        graph = graphs[0]

        if format == 'svg':
            return graph.create_svg()

        img = graph.create_png()
        if mplfig:
            fig, ax = plt.subplots(figsize=(16, 16))
            ax.imshow(mpimg.imread(io.BytesIO(img)))
            ax.axis('off')
            return fig
        return img

    def build_and_show_tree(self, **kwargs):
        '''rebuild the tree, passing kwargs on to :meth:`build_tree`, and
        return the default output of :meth:`show_tree`'''
        self._boxes = None
        self._stats = None
        self.build_tree(**kwargs)
        return self.show_tree()

    def tree_chooser(self):
        """
        An interactive chooser for setting decision tree hyperparameters.

        This method returns an interactive widget that allows an analyst
        to manipulate selected hyperparameters for the decision tree used
        by CART.  The analyst can set the branch splitting criteria
        (gini impurity or entropy reduction), the maximum tree depth, and
        the minimum fraction of observations in any leaf node.

        Returns
        -------
        ipywidgets.widgets.interaction.interactive
        """
        from ipywidgets import interactive, FloatSlider, Dropdown
        return interactive(
            self.build_and_show_tree,
            criterion=Dropdown(
                options=['gini', 'entropy'], value='gini',
            ),
            max_depth=Dropdown(
                options=[1, 2, 3, 4, 5, 6], value=3,
            ),
            mass_min=FloatSlider(
                min=0.001, max=0.1, step=0.001, value=0.05,
                continuous_update=False,
            ),
        )
# if __name__ == '__main__':
#     from test import test_utilities
#     import matplotlib.pyplot as plt
#
#     ema_logging.log_to_stderr(ema_logging.INFO)
#
#     def scarcity_classify(outcomes):
#         outcome = outcomes['relative market price']
#         change = np.abs(outcome[:, 1::]-outcome[:, 0:-1])
#
#         neg_change = np.min(change, axis=1)
#         pos_change = np.max(change, axis=1)
#
#         logical = (neg_change > -0.6) & (pos_change > 0.6)
#
#         classes = np.zeros(outcome.shape[0])
#         classes[logical] = 1
#
#         return classes
#
#     results = test_utilities.load_scarcity_data()
#
#     cart = setup_cart(results, scarcity_classify)
#     cart.build_tree()
#
#     print(cart.boxes_to_dataframe())
#     print(cart.stats_to_dataframe())
#     cart.display_boxes(together=True)
#
#     img = cart.show_tree()
#
#     import matplotlib.pyplot as plt
#     import matplotlib.image as mpimg
#
#     # treat the dot output string as an image file
#     sio = StringIO()
#     sio.write(img)
#     sio.seek(0)
#     img = mpimg.imread(sio)
#
#     # plot the image
#     imgplot = plt.imshow(img, aspect='equal')
#
#     plt.show()