from ..database.database import Database
from .scatter import scatter_graph_row, ScatterMass
from ..util import xmle
from plotly.colors import DEFAULT_PLOTLY_COLORS
import itertools
COLOR_BLUE = "rgb(31, 119, 180)"
COLOR_RED = 'rgb(227, 20, 20)'
COLOR_GREEN = "rgb(44, 160, 44)"
def _pick_color(scope, x, y):
	lev = scope.get_lever_names()
	unc = scope.get_uncertainty_names()
	if y in lev:
		if x in lev:
			return COLOR_BLUE
		if x in unc:
			return COLOR_RED
		return COLOR_BLUE
	elif y in unc:
		if x in lev:
			return COLOR_BLUE
		if x in unc:
			return COLOR_RED
		return COLOR_RED
	else: # y in meas
		if x in lev:
			return COLOR_BLUE
		if x in unc:
			return COLOR_RED
		return COLOR_GREEN
[docs]def scatter_graphs(
		column,
		data,
		scope,
		db=None,
		contrast='infer',
		marker_opacity=None,
		mass=1000,
		render=None,
		use_gl=True,
):
	"""
	Generate a row of scatter plots comparing one column against others.
	Args:
		column (str):
			The name of the principal parameter or measure to analyze.
		data (pandas.DataFrame or str): The experimental results to plot.
			Can be given as a DataFrame or as the name of a design (in
			which case the results are loaded from the provided `db`).
		scope (Scope, optional): The exploratory scope.
		db (Database, optional):
			The database containing the results. This is ignored unless
			`data` is a string.
		contrast (str or list):
			The contrast columns to plot the principal parameter or
			measure against.  Can be given as a list of columns that
			appear in the data, or one of {'uncertainties', 'levers',
			'parameters', 'measures', 'infer'}. If set to 'infer', the
			contrast will be 'measures' if `column` is a parameter,
			and 'parameters' if `columns` is a measure.
		marker_opacity (float, optional):
			The opacity to use for markers. If the number of markers is
			large, the figure may appear as a solid blob; by setting opacity
			to less than 1.0, the figure can more readily show relative
			density in various regions.  If not specified, marker_opacity
			is set based on `mass` instead.
		mass (int or emat.viz.ScatterMass, default 1000):
			The target number of rendered points in each figure. Setting
			to a number less than the number of experiments will make
			each scatter point partially transparent, which will help
			visually convey relative density when there are a very large
			number of points.
		render (str or dict, optional):
			If given, the graph[s] will be rendered to a static image
			using `plotly.io.to_image`.  For default settings, pass
			'png', or give a dictionary that specifies keyword arguments
			to that function.  See `emat.util.rendering.render_plotly`
			for more details.
		use_gl (bool, default True):
			Use Plotly's `Scattergl` instead of `Scatter`, which may
			provide some performance benefit for large data sets.
	Returns:
		FigureWidget or xmle.Elem
	Raises:
		ValueError: If `contrast` is 'infer' but `column` is neither a parameter
			nor a measure.
	"""
	if contrast == 'infer':
		if column in scope.get_uncertainty_names():
			contrast = 'measures'
		elif column in scope.get_lever_names():
			contrast = 'measures'
		elif column in scope.get_measure_names():
			contrast = 'parameters'
		else:
			raise ValueError('cannot infer what to contrast against')
	if contrast == 'uncertainties':
		contrast = scope.get_uncertainty_names()
	elif contrast == 'levers':
		contrast = scope.get_lever_names()
	elif contrast == 'parameters':
		contrast = scope.get_uncertainty_names() + scope.get_lever_names()
	elif contrast == 'measures':
		contrast = scope.get_measure_names()
	if isinstance(data, str):
		if db is None:
			raise ValueError('db cannot be None if data is a design name')
		data = db.read_experiment_all(data)
	if isinstance(mass, int):
		mass = ScatterMass(mass)
	if marker_opacity is None:
		marker_opacity = mass.get_opacity(data)
	y_title = column
	if scope is not None:
		y_title = scope.shortname(y_title)
	contrast_cols = [c for c in contrast if c in data.columns]
	contrast_color = [_pick_color(scope, c, column) for c in contrast_cols]
	fig = scatter_graph_row(
		contrast_cols,
		column,
		df = data,
		marker_opacity=marker_opacity,
		y_title=y_title,
		layout=dict(
			margin=dict(l=50, r=2, t=5, b=40)
		),
		short_name_func=scope.shortname if scope is not None else None,
		use_gl=use_gl,
		C=contrast_color,
	)
	if render:
		if render == 'png':
			render = dict(format='png', width=200*len(contrast_cols), height=270, scale=2)
		if render == 'svg':
			render = dict(format='svg', width=200*len(contrast_cols), height=270)
		from ..util.rendering import render_plotly
		return render_plotly(fig, render)
	return fig 
[docs]def scatter_graphs_2(
		column,
		datas,
		scope,
		db=None,
		contrast='infer',
		render=None,
		colors=None,
		use_gl=True,
		mass=1000,
):
	"""
	Generate a row of scatter plots comparing multiple datasets.
	This function is similar to `scatter_graphs`, but accepts
	multiple data sets and plots them using different colors.
	Args:
		column (str):
			The name of the principal parameter or measure to analyze.
		datas (Collection[pandas.DataFrame or str]):
			The experimental results to plot. Can be given as a DataFrame
			or as the name of a design (in which case the results are
			loaded from the provided Database `db`).
		scope (Scope, optional): The exploratory scope.
		db (Database, optional):
			The database containing the results.  Ignored unless `data`
			is a string.
		contrast (str or list):
			The contrast columns to plot the principal parameter or
			measure against.  Can be given as a list of columns that
			appear in the data, or one of {'uncertainties', 'levers',
			'parameters', 'measures', 'infer'}. If set to 'infer',
			the contrast will be 'measures' if `column` is a parameter,
			and 'parameters' if `columns` is a measure.
		render (str or dict, optional):
			If given, the graph[s] will be rendered to a static image
			using `plotly.io.to_image`.  For default settings, pass
			'png', or give a dictionary that specifies keyword arguments
			to that function.  See `emat.util.rendering.render_plotly`
			for more details.
		mass (int or emat.viz.ScatterMass, default 1000):
			The target number of rendered points in each figure. Setting
			to a number less than the number of experiments will make
			each scatter point partially transparent, which will help
			visually convey relative density when there are a very large
			number of points.
	Returns:
		plotly.FigureWidget or xmle.Elem:
			The latter is returned if a `render` argument is used.
	Raises:
		ValueError:
			If `contrast` is 'infer' but `column` is neither a parameter
			nor a measure.
	"""
	if contrast == 'infer':
		if column in scope.get_uncertainty_names():
			contrast = 'measures'
		elif column in scope.get_lever_names():
			contrast = 'measures'
		elif column in scope.get_measure_names():
			contrast = 'parameters'
		else:
			raise ValueError('cannot infer what to contrast against')
	if contrast == 'uncertainties':
		contrast = scope.get_uncertainty_names()
	elif contrast == 'levers':
		contrast = scope.get_lever_names()
	elif contrast == 'parameters':
		contrast = scope.get_uncertainty_names() + scope.get_lever_names()
	elif contrast == 'measures':
		contrast = scope.get_measure_names()
	if isinstance(mass, int):
		mass = ScatterMass(mass)
	data_ = []
	marker_opacity_ = []
	fig = 'widget'
	if colors is None:
		colorcycle = itertools.cycle(DEFAULT_PLOTLY_COLORS)
	else:
		colorcycle = itertools.cycle(colors)
	for data in datas:
		if isinstance(data, str):
			if db is None:
				raise ValueError('db cannot be None if data is a design name')
			data = db.read_experiment_all(data)
		data_.append(data)
	for data in data_:
		marker_opacity_.append(mass.get_opacity(data))
		y_title = column
		try:
			y_title = scope[column].shortname
		except AttributeError:
			pass
		fig = scatter_graph_row(
			[(c if c in data.columns else None) for c in contrast],
			column,
			df = data,
			marker_opacity=marker_opacity_[-1],
			y_title=y_title,
			layout=dict(
				margin=dict(l=50, r=2, t=5, b=40)
			),
			output=fig,
			C=colorcycle.__next__(),
			short_name_func=scope.shortname if scope is not None else None,
			use_gl=use_gl,
		)
	if render:
		if render == 'png':
			render = dict(format='png', width=1400, height=270, scale=2)
		if render == 'svg':
			render = dict(format='svg', width=1400, height=270)
		from ..util.rendering import render_plotly
		return render_plotly(fig, render)
	return fig 
def heatmap_table(
		data, cmap='viridis', fmt='.3f', linewidths=0.7, figsize=(12,3),
		xlabel=None, ylabel=None, title=None,
		attach_metadata=True,
		scale_color_by_row=True,
		**kwargs
):
	"""
	Generate a SVG heatmap from data.
	Args:
		data (pandas.DataFrame): source data for the heatmap table.
		cmap (str): A colormap for the resulting heatmap.
		fmt (str): how to format values
		linewidths (float): Line widths for the table.
		figsize (tuple): The size of the resulting figure.
		xlabel, ylabel, title (str, optional): Captions for each.
		attach_metadata (bool, default True): Attach `data` to the
			resulting figure as metadata.
		scale_color_by_row (bool, default True):
			Color rows independently.
	Returns:
		xmle.Elem: The xml data for a svg rendering.
	"""
	import seaborn as sns
	from matplotlib import pyplot as plt
	fig, ax = plt.subplots(figsize=figsize)
	if scale_color_by_row:
		coloring = data.div(data.max(1), 0)
	else:
		coloring = data
	axes = sns.heatmap(
		coloring, ax=ax, cmap=cmap, annot=data,
		fmt=fmt, linewidths=linewidths,
		cbar=not scale_color_by_row, **kwargs,
	)
	if xlabel:
		ax.set_xlabel(xlabel, fontweight='bold')
	if ylabel is not None:
		ax.set_ylabel(ylabel, fontweight='bold')
	if title is not None:
		ax.set_title(title, fontweight='bold')
	result= xmle.Elem.from_figure(axes.get_figure())
	if attach_metadata:
		result.metadata = data
	return result