Source code for pydtmc.plotting

# -*- coding: utf-8 -*-

# Public plotting API: the names exported by a wildcard import of this module.
__all__ = [
    'plot_comparison', 'plot_eigenvalues', 'plot_flow', 'plot_graph',
    'plot_redistributions', 'plot_sequence', 'plot_trellis'
]


###########
# IMPORTS #
###########

# Standard

import copy as _cp
import inspect as _ins
import io as _io
import math as _mt
import subprocess as _sub

# Libraries

import matplotlib.colorbar as _mplcb
import matplotlib.colors as _mplcr
import matplotlib.image as _mpli
import matplotlib.patches as _mplpc
import matplotlib.pyplot as _mplp
import matplotlib.ticker as _mplt
import networkx as _nx
import numpy as _np
import numpy.linalg as _npl
import scipy.interpolate as _spip

try:
    import pydot as _pyd
    _pydot_found = True
except ImportError:  # pragma: no cover
    _pyd = None
    _pydot_found = False

# Internal

from .custom_types import (
    oint as _oint,
    olist_str as _olist_str,
    oplot as _oplot,
    ostate as _ostate,
    ostatus as _ostatus,
    thmm as _thmm,
    tlist_model as _tlist_model,
    tmodel as _tmodel
)

from .exceptions import (
    ValidationError as _ValidationError
)

from .hidden_markov_model import (
    HiddenMarkovModel as _HiddenMarkovModel
)

from .markov_chain import (
    MarkovChain as _MarkovChain
)

from .utilities import (
    create_validation_error as _create_validation_error,
    create_models_names as _create_models_names
)

from .validation import (
    validate_boolean as _validate_boolean,
    validate_dpi as _validate_dpi,
    validate_enumerator as _validate_enumerator,
    validate_hidden_markov_model as _validate_hidden_markov_model,
    validate_integer as _validate_integer,
    validate_label as _validate_label,
    validate_model as _validate_model,
    validate_models as _validate_models,
    validate_status as _validate_status,
    validate_strings as _validate_strings
)


#############
# CONSTANTS #
#############

# Base colors.
_color_black = '#000000'
_color_gray = '#E0E0E0'
_color_white = '#FFFFFF'

# Qualitative palette; its first entry is reused as the accent color of the heatmaps.
_colors = (
    '#80B1D3', '#FFED6F', '#B3DE69',
    '#BEBADA', '#FDB462', '#8DD3C7',
    '#FB8072', '#FCCDE5', '#E5C494'
)

# Defaults shared by the graph plots.
_default_color_edge = _color_black  # edge color
_default_color_node = _color_white  # node fill when node coloring is disabled
_default_color_path = _colors[0]
_default_color_symbol = _color_gray  # fill of HMM symbol nodes
_default_node_size = 600  # node size passed to networkx


#############
# FUNCTIONS #
#############


def _decode_image(g, dpi):

    """
    Renders a graph to PNG, loads it as an image array and computes the geometry
    needed to center it inside a figure that is 10% larger along each axis.

    :param g: the graph to render (any object exposing ``create(format='png')``, such as a pydot graph).
    :param dpi: the resolution used to convert pixels into inches.
    :return: a tuple ``(img, img_x, img_xo, img_y, img_yo)`` where ``img_x`` and ``img_y`` are the
             image height and width in inches, while ``img_xo`` and ``img_yo`` are the horizontal and
             vertical offsets in pixels that center the image in a figure of size ``(img_y * 1.1, img_x * 1.1)``.
    """

    with _io.BytesIO(g.create(format='png')) as buffer:
        img = _mpli.imread(buffer)

    height_px = img.shape[0]
    width_px = img.shape[1]

    img_x = height_px / dpi
    img_y = width_px / dpi

    # The figure is 10% larger than the image, so 5% of each dimension is left on either side.
    # The horizontal offset must derive from the width and the vertical one from the height.
    margin = (1.1 - 1.0) / 2.0
    img_xo = width_px * margin
    img_yo = height_px * margin

    return img, img_x, img_xo, img_y, img_yo


def _xticks_labels(ax, size, labels_name, labels, minor_major):

    if labels_name is not None:
        ax.set_xlabel(labels_name, fontsize=13.0)

    if minor_major:
        ax.set_xticks(_np.arange(0.0, size, 1.0), minor=False)
        ax.set_xticks(_np.arange(-0.5, size, 1.0), minor=True)
    else:
        ax.set_xticks(_np.arange(0.0, size, 1.0))

    ax.set_xticklabels(labels)


def _xticks_steps(ax, length):

    ax.set_xlabel('Steps', fontsize=13.0)
    ax.set_xticks(_np.arange(0.0, length + 1.0, 1.0 if length <= 11 else 10.0), minor=False)
    ax.set_xticks(_np.arange(-0.5, length, 1.0), minor=True)
    ax.set_xticklabels(_np.arange(0, length + 1, 1 if length <= 11 else 10))
    ax.set_xlim(-0.5, length - 0.5)


def _yticks_frequency(ax, bottom, top):

    ax.set_ylabel('Frequency', fontsize=13.0)
    ax.set_yticks(_np.linspace(0.0, 1.0, 11))
    ax.set_ylim(bottom, top)


def _yticks_labels(ax, size, labels_name, labels):

    if labels_name is not None:
        ax.set_ylabel(labels_name, fontsize=13.0)

    ax.set_yticks(_np.arange(0.0, size, 1.0), minor=False)
    ax.set_yticks(_np.arange(-0.5, size, 1.0), minor=True)
    ax.set_yticklabels(labels)


def plot_comparison(models: _tlist_model, underlying_matrices: str = 'transition', names: _olist_str = None, dpi: int = 100) -> _oplot:

    """
    The function plots the underlying matrices of the given models in the form of a heatmap.

    | **Notes:**

    * If `Matplotlib <https://matplotlib.org/>`_ is in `interactive mode <https://matplotlib.org/stable/users/interactive.html>`_, the plot is immediately displayed and the function does not return the plot handles.

    :param models: the models.
    :param underlying_matrices:
     - **emission** for comparing the emission matrices;
     - **transition** for comparing the transition matrices.
    :param names: the name of each model subplot (*if omitted, a standard name is given to each subplot*).
    :param dpi: the resolution of the plot expressed in dots per inch.
    :raises ValidationError: if any input argument is not compliant.
    """

    try:

        models = _validate_models(models)
        underlying_matrices = _validate_enumerator(underlying_matrices, ['emission', 'transition'])
        names = _create_models_names(models) if names is None else _validate_strings(names, len(models))
        dpi = _validate_dpi(dpi)

    except Exception as ex:  # pragma: no cover
        raise _create_validation_error(ex, _ins.trace()) from None

    if underlying_matrices == 'emission' and not all(isinstance(model, _HiddenMarkovModel) for model in models):  # pragma: no cover
        raise _ValidationError('In order to compare emission matrices, the list must contain only hidden Markov models.')

    # Nearly-square grid: rows * columns >= number of models.
    space = len(models)
    rows = int(_mt.sqrt(space))
    columns = int(_mt.ceil(space / float(rows)))

    # squeeze=False guarantees a 2D array of axes, so ".flat" also works for a 1x1 grid.
    figure, axes = _mplp.subplots(nrows=rows, ncols=columns, squeeze=False, constrained_layout=True, dpi=dpi)
    axes = list(axes.flat)

    ax_is = None
    color_map = _mplcr.LinearSegmentedColormap.from_list('ColorMap', [_color_white, _colors[0]], 20)

    for ax, model, name in zip(axes, models, names):

        matrix = model.e if underlying_matrices == 'emission' else model.p

        ax_is = ax.imshow(matrix, aspect='auto', cmap=color_map, interpolation='none', vmin=0.0, vmax=1.0)

        ax.set_title(name, fontsize=9.0, fontweight='normal', pad=1)

        ax.set_xticks([])
        ax.set_xticks([], minor=True)
        ax.set_yticks([])
        ax.set_yticks([], minor=True)

    # Grid cells exceeding the number of models would otherwise show empty default axes.
    for ax in axes[space:]:
        ax.set_axis_off()

    color_map_ax, color_map_ax_kwargs = _mplcb.make_axes(axes, drawedges=True, orientation='horizontal', ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
    figure.colorbar(ax_is, cax=color_map_ax, **color_map_ax_kwargs)
    color_map_ax.set_xticklabels([0.0, 0.25, 0.5, 0.75, 1.0])

    figure.suptitle('Comparison Plot', fontsize=15.0, fontweight='bold')

    if _mplp.isinteractive():  # pragma: no cover
        _mplp.show(block=False)
        return None

    return figure, axes
def plot_eigenvalues(model: _tmodel, dpi: int = 100) -> _oplot:

    """
    The function plots the eigenvalues of the transition matrix of the given model on the complex plane.

    | **Notes:**

    * If `Matplotlib <https://matplotlib.org/>`_ is in `interactive mode <https://matplotlib.org/stable/users/interactive.html>`_, the plot is immediately displayed and the function does not return the plot handles.

    :param model: the model.
    :param dpi: the resolution of the plot expressed in dots per inch.
    :raises ValidationError: if any input argument is not compliant.
    """

    try:

        model = _validate_model(model)
        dpi = _validate_dpi(dpi)

    except Exception as ex:  # pragma: no cover
        raise _create_validation_error(ex, _ins.trace()) from None

    if model.__class__.__name__ == 'MarkovChain':
        mc = model
    else:
        mc = _MarkovChain(model.p, model.states)

    figure, ax = _mplp.subplots(dpi=dpi)

    handles, labels = [], []

    theta = _np.linspace(0.0, 2.0 * _np.pi, 200)

    evalues = _npl.eigvals(model.p).astype(complex)
    evalues_final = _np.unique(_np.append(evalues, _np.array([1.0]).astype(complex)))

    x_unit_circle = _np.cos(theta)
    y_unit_circle = _np.sin(theta)

    if mc.is_ergodic:

        values_abs = _np.sort(_np.abs(evalues))
        values_ct1 = _np.isclose(values_abs, 1.0)

        if not _np.all(values_ct1):

            # Largest eigenvalue modulus that is not (close to) 1.
            mu = values_abs[~values_ct1][-1]

            if not _np.isclose(mu, 0.0):

                x_slem_circle = mu * x_unit_circle
                y_slem_circle = mu * y_unit_circle

                cs = _np.linspace(-1.1, 1.1, 201)
                x_spectral_gap, y_spectral_gap = _np.meshgrid(cs, cs)
                z_spectral_gap = x_spectral_gap**2.0 + y_spectral_gap**2.0

                # Annulus between the circle of radius mu and the unit circle.
                ax.contourf(x_spectral_gap, y_spectral_gap, z_spectral_gap, alpha=0.2, colors='r', levels=[mu**2.0, 1.0])

                # Explicit legend proxy with the same RGBA as the fill above ('r' at alpha 0.2);
                # ContourSet.collections is deprecated since Matplotlib 3.8 and removed in 3.10.
                handles.append(_mplpc.Rectangle((0.0, 0.0), 1.0, 1.0, fc=(1.0, 0.0, 0.0, 0.2)))
                labels.append('Spectral Gap')

                ax.plot(x_slem_circle, y_slem_circle, color='red', linestyle='--', linewidth=1.5)

    ax.plot(x_unit_circle, y_unit_circle, color='red', linestyle='-', linewidth=3.0)

    h, = ax.plot(_np.real(evalues_final), _np.imag(evalues_final), color='blue', linestyle='None', marker='*', markersize=12.5)
    handles.append(h)
    labels.append('Eigenvalues')

    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.set_aspect('equal')

    formatter = _mplt.FormatStrFormatter('%g')
    ax.xaxis.set_major_formatter(formatter)
    ax.yaxis.set_major_formatter(formatter)

    ax.set_xticks(_np.linspace(-1.0, 1.0, 9), minor=False)
    ax.set_yticks(_np.linspace(-1.0, 1.0, 9), minor=False)

    ax.grid(which='major')

    ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(0.5, -0.1), loc='upper center', ncol=len(handles))
    ax.set_title('Eigenvalues Plot', fontsize=15.0, fontweight='bold')

    # Leave room below the axes for the legend.
    figure.subplots_adjust(bottom=0.2)

    if _mplp.isinteractive():  # pragma: no cover
        _mplp.show(block=False)
        return None

    return figure, ax
def plot_flow(model: _tmodel, steps: int, interval: int, initial_status: _ostatus = None, palette: str = 'viridis', dpi: int = 100) -> _oplot:

    """
    The function produces an alluvial diagram of the given model.

    | **Notes:**

    * If `Matplotlib <https://matplotlib.org/>`_ is in `interactive mode <https://matplotlib.org/stable/users/interactive.html>`_, the plot is immediately displayed and the function does not return the plot handles.

    :param model: the model.
    :param steps: the number of steps.
    :param interval: the interval between each step.
    :param initial_status: the initial state or the initial distribution of the states (*if omitted, the states are assumed to be uniformly distributed*).
    :param palette: the palette of the plot.
    :param dpi: the resolution of the plot expressed in dots per inch.
    :raises ValidationError: if any input argument is not compliant.
    """

    def _get_boundaries(gb_d):

        # Stacks the scaled probabilities of every column, inserting a total gap of 0.1
        # between consecutive states (no gap exists when there is a single state).
        i = gb_d.shape[0]
        k = 0.1 / (i - 1.0) if i > 1 else 0.0

        cs = _np.cumsum(gb_d + k, axis=0)

        b = _np.clip(cs - gb_d - k, 0.0, 1.0)
        t = _np.clip(cs - k, 0.0, 1.0)

        return b, t

    def _get_colors(gc_pn, gc_d):

        # One color per state, evenly spaced across the colormap and away from both extremes.
        # Sampling the colormap directly works for every registered colormap, while the
        # ".colors" attribute only exists on listed colormaps.
        i = gc_d.shape[0]
        cmap = _mplp.colormaps[gc_pn]
        positions = (1.0 + (10.0 * _np.arange(i))) / (2.0 + (10.0 * (i - 1)))

        return cmap(positions)

    def _get_curves(gc_n, gc_x1, gc_y1, gc_x2, gc_y2):

        # Cosine-eased curves from (x1, y1) to (x2, y2), one column per curve.
        tx = _np.reshape(_np.linspace(gc_x1, gc_x2, 15), (1, -1))
        cx = _np.tile(_np.transpose(tx), (1, gc_n))

        ty = (1.0 - _np.cos(_np.reshape(_np.linspace(0.0, _np.pi, 15), (1, -1)))) / 2.0
        cy = _np.tile(gc_y1, (15, 1)) + (_np.tile(gc_y2 - gc_y1, (15, 1)) * _np.tile(_np.transpose(ty), (1, gc_n)))

        return cx, cy

    def _get_legend(gl_mc, gl_c):

        handles = []
        labels = gl_mc.states

        for i, label in enumerate(labels):
            handles.append(_mplpc.Patch(color=gl_c[i, :], label=label))

        return handles, labels

    def _get_polygons_bars(gpb_d, gpb_bb, gpb_bt, gpb_c):

        i, j = gpb_d.shape
        w = j / 40.0

        polygons = []

        for oj in range(j):

            xm = oj - w
            xp = oj + w

            for oi in range(i):

                yb = gpb_bb[oi, oj]
                yt = gpb_bt[oi, oj]

                x = [xm, xp, xp, xm]
                y = [yb, yb, yt, yt]

                polygons.append(_mplpc.Polygon(list(zip(x, y)), edgecolor=None, facecolor=gpb_c[oi, :], alpha=0.8))

        return polygons

    def _get_polygons_flows(gpf_p, gpf_d, gpf_bb, gpf_c, gpf_indices):

        i, j = gpf_d.shape
        w = j / 40.0

        polygons = []

        for oj in range(j - 1):

            # Transition probabilities over the steps separating two consecutive columns.
            q = _npl.matrix_power(gpf_p, gpf_indices[oj + 1] - gpf_indices[oj])

            # Running top of the incoming flows for every destination state.
            bj = _np.copy(gpf_bb[:, oj + 1])

            x_lo = oj + w
            x_hi = oj - w + 1.0

            for oi in range(i):

                dij = gpf_d[oi, oj]
                bij = gpf_bb[oi, oj]
                qi = q[oi, :]
                qis = _np.cumsum(qi)

                tl = ((qis - qi) * dij) + bij
                bl = (qis * dij) + bij
                tr = _np.copy(bj)
                br = tr + bl - tl
                bj += bl - tl

                bottom_x, bottom_y = _get_curves(i, x_lo, bl, x_hi, br)
                top_x, top_y = _get_curves(i, x_hi, tr, x_lo, tl)

                x = _np.concatenate([bottom_x, top_x], axis=0)
                y = _np.concatenate([bottom_y, top_y], axis=0)

                for z in range(x.shape[1]):
                    polygons.append(_mplpc.Polygon(list(zip(x[:, z], y[:, z])), edgecolor=None, facecolor=gpf_c[oi, :], alpha=0.3))

        return polygons

    try:

        model = _validate_model(model)
        steps = _validate_integer(steps, lower_limit=(1, False))
        interval = _validate_integer(interval, lower_limit=(1, False))
        initial_status = None if initial_status is None else _validate_status(initial_status, model.states)
        palette = _validate_enumerator(palette, list(_mplp.colormaps))
        dpi = _validate_dpi(dpi)

    except Exception as ex:  # pragma: no cover
        raise _create_validation_error(ex, _ins.trace()) from None

    if model.__class__.__name__ == 'MarkovChain':
        mc = model
    else:
        mc = _MarkovChain(model.p, model.states)

    p = mc.p

    # Steps at which a column is drawn: 0, interval, 2 * interval, ..., steps * interval.
    indices = list(range(0, (steps * interval) + 1, interval))
    indices_set = set(indices)

    distributions = mc.redistribute(indices[-1], initial_status=initial_status, output_last=False)
    distributions = _np.transpose(_np.stack([distribution for index, distribution in enumerate(distributions) if index in indices_set]))
    columns = distributions.shape[1]

    # Probabilities fill 90% of each column, the remaining 10% is used for the gaps between states.
    d = distributions * 0.9
    bb, bt = _get_boundaries(d)
    bm = (bb + bt) / 2.0
    c = _get_colors(palette, d)
    lh, ll = _get_legend(mc, c)
    polygons_bars = _get_polygons_bars(d, bb, bt, c)
    polygons_flows = _get_polygons_flows(p, d, bb, c, indices)

    figure, ax = _mplp.subplots(dpi=dpi)

    for e in polygons_flows:
        ax.add_patch(e)

    for e in polygons_bars:
        ax.add_patch(e)

    for ai in range(distributions.shape[0]):
        for aj in range(columns):

            dv = distributions[ai, aj]

            if dv > 0.05:
                ax.text(aj, bm[ai, aj], f'{dv:.3f}', horizontalalignment='center', verticalalignment='center')

    # One cell per drawn column, so the last column is not clipped by the x-limits; the tick
    # labels show the actual step each column represents, which differs from the column
    # position whenever interval > 1.
    _xticks_steps(ax, columns)
    ax.set_xticklabels([int(round(tick)) * interval for tick in ax.get_xticks()])

    ax.set_ylim(0.0, 1.0)
    ax.invert_yaxis()

    ax.legend(lh, ll, bbox_to_anchor=(0.5, -0.1), loc='upper center', ncol=len(lh))
    ax.set_title('Flow Plot', fontsize=15.0, fontweight='bold')

    # Leave room below the axes for the legend.
    figure.subplots_adjust(bottom=0.2)

    if _mplp.isinteractive():  # pragma: no cover
        _mplp.show(block=False)
        return None

    return figure, ax
# noinspection PyBroadException
def plot_graph(model: _tmodel, nodes_color: bool = True, nodes_shape: bool = True, edges_label: bool = True, force_standard: bool = False, dpi: int = 100) -> _oplot:

    """
    The function plots the directed graph of the given model.

    | **Notes:**

    * If `Matplotlib <https://matplotlib.org/>`_ is in `interactive mode <https://matplotlib.org/stable/users/interactive.html>`_, the plot is immediately displayed and the function does not return the plot handles.
    * `Graphviz <https://graphviz.org/>`_ and `pydot <https://pypi.org/project/pydot/>`_ are not required, but they provide access to extended mode with improved rendering and additional features.
    * The rendering, especially in standard mode or for big graphs, is not granted to be high-quality.
    * For Markov chains, the color of nodes is based on communicating classes; for hidden Markov models, every state node has a different color and symbol nodes are gray.
    * For Markov chains, recurrent nodes have an elliptical shape and transient nodes have a rectangular shape; for hidden Markov models, state nodes have an elliptical shape and symbol nodes have a hexagonal shape.

    :param model: the model.
    :param nodes_color: a boolean indicating whether to use a different color for every type of node.
    :param nodes_shape: a boolean indicating whether to use a different shape for every type of node.
    :param edges_label: a boolean indicating whether to display the probability of every edge as text.
    :param force_standard: a boolean indicating whether to use standard mode even if extended mode is available.
    :param dpi: the resolution of the plot expressed in dots per inch.
    :raises ValidationError: if any input argument is not compliant.
    """

    def _calculate_magnitude(*cm_elements):

        # Number of decimals used by the edge labels, clamped to [1, 4] and driven by the smallest
        # positive probability: zero entries never produce an edge, so they must not force the
        # precision down (otherwise a probability such as 0.04 would be printed as 0.0).
        magnitudes = []

        for element in cm_elements:

            values = _np.asarray(element, dtype=float)
            values = values[values > 0.0]

            if values.size == 0:  # pragma: no cover
                element_magnitude = 0
            else:
                element_minimum = _np.min(values).item()
                element_magnitude = int(-_mt.floor(_mt.log10(element_minimum)))

            magnitudes.append(element_magnitude)

        magnitude = max(1, min(max(magnitudes), 4))

        return magnitude

    def _draw_edge_labels_curved(delc_ax, delc_positions, delc_edge_labels):

        # Places every label at the midpoint of the quadratic Bezier curve drawn by "arc3, rad=...".
        for (n1, n2), (rad, label) in delc_edge_labels.items():

            (x1, y1) = delc_positions[n1]
            (x2, y2) = delc_positions[n2]

            p1 = delc_ax.transData.transform(_np.array(delc_positions[n1]))
            p2 = delc_ax.transData.transform(_np.array(delc_positions[n2]))

            linear_mid = (0.5 * p1) + (0.5 * p2)
            cp_mid = linear_mid + (rad * _np.dot(_np.array([(0, 1), (-1, 0)]), p2 - p1))

            cp1 = (0.5 * p1) + (0.5 * cp_mid)
            cp2 = (0.5 * p2) + (0.5 * cp_mid)
            bezier_mid = (0.5 * cp1) + (0.5 * cp2)

            (x, y) = delc_ax.transData.inverted().transform(bezier_mid)
            xy = _np.array((x, y))

            # Keep the text upright.
            angle = (_np.arctan2(y2 - y1, x2 - x1) / (2.0 * _np.pi)) * 360.0

            if angle > 90.0:
                angle -= 180.0

            if angle < -90.0:
                angle += 180.0

            rotation = delc_ax.transData.transform_angles(_np.array((angle,)), xy.reshape((1, 2)))[0]
            transform = delc_ax.transData

            bbox = {
                'boxstyle': 'round',
                'ec': (1.0, 1.0, 1.0),
                'fc': (1.0, 1.0, 1.0)
            }

            delc_ax.text(
                x, y, label,
                color='k', size=10, family='sans-serif', weight='normal',
                horizontalalignment='center', verticalalignment='center',
                bbox=bbox, clip_on=True, rotation=rotation, transform=transform, zorder=1
            )

        delc_ax.tick_params(
            axis='both', which='both',
            bottom=False, left=False,
            labelbottom=False, labelleft=False
        )

    def _node_colors(nc_count):

        # Cycles through the module palette when more colors than available are requested.
        return [_colors[index % len(_colors)] for index in range(nc_count)]

    # noinspection DuplicatedCode
    def _plot_hmm_extended(phe_hmm, phe_nodes_color, phe_nodes_shape, phe_edges_label, phe_dpi):

        magnitude = _calculate_magnitude(phe_hmm.p, phe_hmm.e)
        node_colors = _node_colors(phe_hmm.n) if phe_nodes_color else []
        edge_colors = node_colors

        g = _pyd.Dot(graph_type='digraph')

        g_sub1 = _pyd.Subgraph()
        g.add_subgraph(g_sub1)

        # Symbol nodes share the same rank, so they are laid out on a single row.
        g_sub2 = _pyd.Subgraph(rank='same')
        g.add_subgraph(g_sub2)

        for i in range(phe_hmm.n):

            state = phe_hmm.states[i]
            node_attributes = {}

            if phe_nodes_color:
                node_attributes['style'] = 'filled'
                node_attributes['fillcolor'] = node_colors[i]

            if phe_nodes_shape:
                node_attributes['shape'] = 'ellipse'

            g_sub1.add_node(_pyd.Node(state, **node_attributes))

        for symbol in phe_hmm.symbols:

            node_attributes = {}

            if phe_nodes_color:
                node_attributes['style'] = 'filled'
                node_attributes['fillcolor'] = _default_color_symbol

            if phe_nodes_shape:
                node_attributes['shape'] = 'hexagon'

            g_sub2.add_node(_pyd.Node(symbol, **node_attributes))

        for i in range(phe_hmm.n):

            state_i = phe_hmm.states[i]

            for j in range(phe_hmm.n):

                tp = phe_hmm.p[i, j]

                if tp > 0.0:

                    state_j = phe_hmm.states[j]

                    edge_attributes = {
                        'style': 'filled',
                        'color': _default_color_edge
                    }

                    if phe_edges_label:
                        edge_attributes['label'] = f' {round(tp, magnitude):.{magnitude}f} '
                        edge_attributes['fontsize'] = 9

                    g.add_edge(_pyd.Edge(state_i, state_j, **edge_attributes))

            for j in range(phe_hmm.k):

                ep = phe_hmm.e[i, j]

                if ep > 0.0:

                    symbol = phe_hmm.symbols[j]

                    edge_attributes = {
                        'style': 'dashed',
                        'color': edge_colors[i] if phe_nodes_color else _default_color_edge
                    }

                    if phe_edges_label:
                        edge_attributes['label'] = f' {round(ep, magnitude):.{magnitude}f} '
                        edge_attributes['fontsize'] = 9

                    g.add_edge(_pyd.Edge(state_i, symbol, **edge_attributes))

        img, img_x, img_xo, img_y, img_yo = _decode_image(g, phe_dpi)

        f = _mplp.figure(figsize=(img_y * 1.1, img_x * 1.1), dpi=phe_dpi)
        f.figimage(img, yo=img_yo, xo=img_xo)
        a = f.gca()
        a.axis('off')

        return f, a

    def _plot_hmm_standard(phs_hmm, phs_nodes_color, phs_nodes_shape, phs_edges_label, phs_dpi):

        g = phs_hmm.to_graph()
        positions = _nx.multipartite_layout(g, align='horizontal', subset_key='layer')

        magnitude = _calculate_magnitude(phs_hmm.p, phs_hmm.e)
        node_colors = _node_colors(phs_hmm.n) if phs_nodes_color else []
        edge_colors = node_colors

        # Drawing happens with interactive mode disabled; the previous mode is restored even on failure.
        mpi = _mplp.isinteractive()
        _mplp.interactive(False)

        try:

            f, a = _mplp.subplots(dpi=phs_dpi)

            for node in g.nodes:

                node_is_state = node in phs_hmm.states

                if phs_nodes_color:
                    # Colors are looked up by state position, independently of the graph node ordering.
                    node_color = node_colors[phs_hmm.states.index(node)] if node_is_state else _default_color_symbol
                else:
                    node_color = _default_color_node

                node_shape = 'H' if phs_nodes_shape and not node_is_state else 'o'

                _nx.draw_networkx_nodes(g, positions, ax=a, nodelist=[node], node_color=node_color, node_shape=node_shape, node_size=_default_node_size, edgecolors=_color_black)

            _nx.draw_networkx_labels(g, positions, ax=a)

            edge_labels_curved, edge_labels_straight_state, edge_labels_straight_symbol = {}, {}, {}

            for i in range(phs_hmm.n):

                state_i = phs_hmm.states[i]

                for j in range(phs_hmm.n):

                    tp = phs_hmm.p[i, j]

                    if tp > 0.0:

                        state_j = phs_hmm.states[j]
                        edge = (state_i, state_j)
                        edge_color = _default_color_edge

                        # Bidirectional pairs are curved so the two directions do not overlap.
                        if i != j and edge[::-1] in g.edges:

                            edge_length = abs(i - j)
                            edge_rad = 0.15 if edge_length == 1 else edge_length * 0.25
                            edge_connection = f'arc3, rad={edge_rad:f}'

                            if phs_edges_label:
                                edge_labels_curved[edge] = (edge_rad, f' {round(tp, magnitude):.{magnitude}f} ')

                        else:

                            edge_connection = 'arc3'

                            if phs_edges_label:
                                edge_labels_straight_state[edge] = f' {round(tp, magnitude):.{magnitude}f} '

                        _nx.draw_networkx_edges(g, positions, ax=a, edgelist=[edge], edge_color=edge_color, arrows=True, connectionstyle=edge_connection)

                for j in range(phs_hmm.k):

                    ep = phs_hmm.e[i, j]

                    if ep > 0.0:

                        symbol_j = phs_hmm.symbols[j]
                        edge = (state_i, symbol_j)
                        edge_color = edge_colors[i] if phs_nodes_color else _default_color_edge

                        if phs_edges_label:
                            edge_labels_straight_symbol[edge] = f' {round(ep, magnitude):.{magnitude}f} '

                        _nx.draw_networkx_edges(g, positions, ax=a, edgelist=[edge], edge_color=edge_color, arrows=True, style='dashed')

            if len(edge_labels_straight_state) > 0:
                _nx.draw_networkx_edge_labels(g, positions, ax=a, edge_labels=edge_labels_straight_state)

            if len(edge_labels_straight_symbol) > 0:
                _nx.draw_networkx_edge_labels(g, positions, ax=a, edge_labels=edge_labels_straight_symbol, label_pos=0.25)

            if len(edge_labels_curved) > 0:
                _draw_edge_labels_curved(a, positions, edge_labels_curved)

        finally:
            _mplp.interactive(mpi)

        return f, a

    # noinspection DuplicatedCode
    def _plot_mc_extended(pme_mc, pme_nodes_color, pme_nodes_shape, pme_edges_label, pme_dpi):

        magnitude = _calculate_magnitude(pme_mc.p)
        node_colors = _node_colors(len(pme_mc.communicating_classes)) if pme_nodes_color else []

        g = _pyd.Dot(graph_type='digraph')

        for i in range(pme_mc.size):

            state_i = pme_mc.states[i]
            node_attributes = {}

            if pme_nodes_color:
                for index, cc in enumerate(pme_mc.communicating_classes):
                    if state_i in cc:
                        node_attributes['style'] = 'filled'
                        node_attributes['fillcolor'] = node_colors[index]
                        break

            if pme_nodes_shape:
                if state_i in pme_mc.transient_states:  # pragma: no cover
                    node_attributes['shape'] = 'box'
                else:
                    node_attributes['shape'] = 'ellipse'

            g.add_node(_pyd.Node(state_i, **node_attributes))

            for j in range(pme_mc.size):

                tp = pme_mc.p[i, j]

                if tp > 0.0:

                    state_j = pme_mc.states[j]

                    edge_attributes = {
                        'style': 'filled',
                        'color': _default_color_edge
                    }

                    if pme_edges_label:
                        edge_attributes['label'] = f' {round(tp, magnitude):.{magnitude}f} '
                        edge_attributes['fontsize'] = 9

                    g.add_edge(_pyd.Edge(state_i, state_j, **edge_attributes))

        img, img_x, img_xo, img_y, img_yo = _decode_image(g, pme_dpi)

        f = _mplp.figure(figsize=(img_y * 1.1, img_x * 1.1), dpi=pme_dpi)
        f.figimage(img, yo=img_yo, xo=img_xo)
        a = f.gca()
        a.axis('off')

        return f, a

    def _plot_mc_standard(pms_mc, pms_nodes_color, pms_nodes_shape, pms_edges_label, pms_dpi):

        g = pms_mc.to_graph()
        positions = _nx.spring_layout(g)

        magnitude = _calculate_magnitude(pms_mc.p)
        node_colors = _node_colors(len(pms_mc.communicating_classes)) if pms_nodes_color else []

        # Drawing happens with interactive mode disabled; the previous mode is restored even on failure.
        mpi = _mplp.isinteractive()
        _mplp.interactive(False)

        try:

            f, a = _mplp.subplots(dpi=pms_dpi)

            for node in g.nodes:

                node_color = _default_color_node

                if pms_nodes_color:
                    for index, cc in enumerate(pms_mc.communicating_classes):
                        if node in cc:
                            node_color = node_colors[index]
                            break

                node_shape = 's' if pms_nodes_shape and node in pms_mc.transient_states else 'o'

                _nx.draw_networkx_nodes(g, positions, ax=a, nodelist=[node], node_color=node_color, node_shape=node_shape, node_size=_default_node_size, edgecolors=_color_black)

            _nx.draw_networkx_labels(g, positions, ax=a)

            edge_labels_curved, edge_labels_straight = {}, {}

            for i in range(pms_mc.size):

                state_i = pms_mc.states[i]

                for j in range(pms_mc.size):

                    tp = pms_mc.p[i, j]

                    if tp > 0.0:

                        state_j = pms_mc.states[j]
                        edge = (state_i, state_j)
                        edge_color = _default_color_edge

                        # Bidirectional pairs are curved so the two directions do not overlap.
                        if i != j and edge[::-1] in g.edges:

                            edge_connection = 'arc3, rad=0.1'

                            if pms_edges_label:
                                edge_labels_curved[edge] = (0.1, f' {round(tp, magnitude):.{magnitude}f} ')

                        else:

                            edge_connection = 'arc3'

                            if pms_edges_label:
                                edge_labels_straight[edge] = f' {round(tp, magnitude):.{magnitude}f} '

                        _nx.draw_networkx_edges(g, positions, ax=a, edgelist=[edge], edge_color=edge_color, arrows=True, connectionstyle=edge_connection)

            if len(edge_labels_straight) > 0:
                _nx.draw_networkx_edge_labels(g, positions, ax=a, edge_labels=edge_labels_straight)

            if len(edge_labels_curved) > 0:
                _draw_edge_labels_curved(a, positions, edge_labels_curved)

        finally:
            _mplp.interactive(mpi)

        return f, a

    try:

        model = _validate_model(model)
        nodes_color = _validate_boolean(nodes_color)
        nodes_shape = _validate_boolean(nodes_shape)
        edges_label = _validate_boolean(edges_label)
        force_standard = _validate_boolean(force_standard)
        dpi = _validate_dpi(dpi)

    except Exception as ex:  # pragma: no cover
        raise _create_validation_error(ex, _ins.trace()) from None

    extended_graph = not force_standard and _pydot_found

    # Extended mode also requires a working Graphviz "dot" executable; any failure
    # (missing binary, non-zero exit status) falls back to standard mode.
    if extended_graph:
        try:
            _sub.run(['dot', '-V'], stdout=_sub.DEVNULL, stderr=_sub.DEVNULL, check=True)
        except Exception:  # pragma: no cover
            extended_graph = False

    model_mc = model.__class__.__name__ == 'MarkovChain'

    if extended_graph:
        func = _plot_mc_extended if model_mc else _plot_hmm_extended
    else:
        func = _plot_mc_standard if model_mc else _plot_hmm_standard

    figure, ax = func(model, nodes_color, nodes_shape, edges_label, dpi)

    if _mplp.isinteractive():  # pragma: no cover
        _mplp.show(block=False)
        return None

    return figure, ax
# noinspection DuplicatedCode
def plot_redistributions(model: _tmodel, redistributions: int, initial_status: _ostatus = None, plot_type: str = 'projection', dpi: int = 100) -> _oplot:

    """
    The function plots a redistribution of states on the given model.

    | **Notes:**

    * If `Matplotlib <https://matplotlib.org/>`_ is in `interactive mode <https://matplotlib.org/stable/users/interactive.html>`_, the plot is immediately displayed and the function does not return the plot handles.
    * If the model is not a Markov chain, the redistribution is performed on a Markov chain built from its transition matrix and states.

    :param model: the model.
    :param redistributions: the number of redistributions to perform.
    :param initial_status: the initial state or the initial distribution of the states (*if omitted, the states are assumed to be uniformly distributed*).
    :param plot_type:
     - **heatmap** for displaying a heatmap plot;
     - **projection** for displaying a projection plot.
    :param dpi: the resolution of the plot expressed in dots per inch.
    :raises ValidationError: if any input argument is not compliant.
    :raises ValueError: if the "initial_status" parameter is specified and does not match the first computed distribution.
    """

    try:

        model = _validate_model(model)
        redistributions = _validate_integer(redistributions, lower_limit=(1, False))
        initial_status = None if initial_status is None else _validate_status(initial_status, model.states)
        plot_type = _validate_enumerator(plot_type, ['heatmap', 'projection'])
        dpi = _validate_dpi(dpi)

    except Exception as ex:  # pragma: no cover
        raise _create_validation_error(ex, _ins.trace()) from None

    # Non-Markov-chain models (i.e. hidden Markov models) are redistributed through the chain defined by their transition matrix.
    if model.__class__.__name__ == 'MarkovChain':
        mc = model
    else:
        mc = _MarkovChain(model.p, model.states)

    distributions = mc.redistribute(redistributions, initial_status=initial_status, output_last=False)

    # Defensive consistency check: the first computed distribution is expected to be the initial one.
    if initial_status is not None and not _np.array_equal(distributions[0], initial_status):  # pragma: no cover
        raise ValueError('The "initial_status" parameter, if specified when the "distributions" parameter represents a sequence of redistributions, must match the first element.')

    # Normalize the output into a 2D array with one row per step.
    if isinstance(distributions, _np.ndarray):
        distributions_length = 1
        distributions = _np.array([distributions])
    else:
        distributions_length = len(distributions)
        distributions = _np.array(distributions)

    # Build the figure with interactive mode temporarily disabled (restored even on error),
    # consistently with the other plotting functions of this module.
    mpi = _mplp.isinteractive()
    _mplp.interactive(False)

    try:

        figure, ax = _mplp.subplots(dpi=dpi)

        if plot_type == 'heatmap':

            color_map = _mplcr.LinearSegmentedColormap.from_list('ColorMap', [_color_white, _colors[0]], 20)
            ax_is = ax.imshow(_np.transpose(distributions), aspect='auto', cmap=color_map, interpolation='none', vmin=0.0, vmax=1.0)

            _xticks_steps(ax, distributions_length)
            _yticks_labels(ax, mc.size, None, mc.states)

            ax.grid(which='minor', color='k')

            cb = figure.colorbar(ax_is, drawedges=True, orientation='horizontal', ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
            cb.ax.set_xticklabels([0.0, 0.25, 0.5, 0.75, 1.0])

            ax.set_title('Redistributions Plot (Heatmap)', fontsize=15.0, fontweight='bold')

        else:

            ax.set_prop_cycle('color', _colors)

            steps_axis = _np.arange(0.0, distributions_length, 1.0)

            # Lines made of only two points are drawn with markers.
            line_kwargs = {'marker': 'o'} if distributions_length == 2 else {}

            for i in range(mc.size):
                ax.plot(steps_axis, distributions[:, i], label=mc.states[i], **line_kwargs)

            # With a uniform initial distribution every line starts from the same point,
            # which is highlighted by a single "Start" marker.
            if _np.allclose(distributions[0, :], _np.ones(mc.size, dtype=float) / mc.size):
                ax.plot(0.0, distributions[0, 0], color=_color_black, label="Start", marker='o', markeredgecolor=_color_black, markerfacecolor=_color_black)
                legend_size = mc.size + 1
            else:  # pragma: no cover
                legend_size = mc.size

            _xticks_steps(ax, distributions_length)
            _yticks_frequency(ax, -0.05, 1.05)

            ax.grid()
            ax.legend(bbox_to_anchor=(0.5, -0.1), loc='upper center', ncol=legend_size)
            ax.set_title('Redistributions Plot (Projection)', fontsize=15.0, fontweight='bold')

        # Applied to this figure explicitly rather than to the pyplot "current" figure.
        figure.subplots_adjust(bottom=0.2)

    finally:
        _mplp.interactive(mpi)

    if _mplp.isinteractive():  # pragma: no cover
        _mplp.show(block=False)
        return None

    return figure, ax
# noinspection DuplicatedCode
def plot_sequence(model: _tmodel, steps: int, initial_state: _ostate = None, plot_type: str = 'histogram', seed: _oint = None, dpi: int = 100) -> _oplot:

    """
    The function plots a random walk on the given model.

    | **Notes:**

    * If `Matplotlib <https://matplotlib.org/>`_ is in `interactive mode <https://matplotlib.org/stable/users/interactive.html>`_, the plot is immediately displayed and the function does not return the plot handles.
    * For hidden Markov models, both the sequence of states and the sequence of symbols are plotted.

    :param model: the model.
    :param steps: the number of steps.
    :param initial_state: the initial state of the random walk (*if omitted, it is chosen uniformly at random*).
    :param plot_type:
     - **heatmap** for displaying heatmap plots of the observed transition frequencies;
     - **histogram** for displaying histogram plots of the occurrence frequencies;
     - **matrix** for displaying matrix plots of the sequence over the steps.
    :param seed: a seed to be used as RNG initializer for reproducibility purposes.
    :param dpi: the resolution of the plot expressed in dots per inch.
    :raises ValidationError: if any input argument is not compliant.
    """

    # noinspection DuplicatedCode
    def _plot_heatmap(phm_walk_data, phm_dpi):

        walk_steps, walks = phm_walk_data
        plots_count = len(walks)

        f, a = _mplp.subplots(nrows=plots_count, constrained_layout=True, dpi=phm_dpi)
        a = [a] if plots_count == 1 else list(a.flat)

        color_map = _mplcr.LinearSegmentedColormap.from_list('ColorMap', [_color_white, _colors[0]], 20)
        a_current_is = None

        for a_current, (size, labels_name, labels, sequence) in zip(a, walks):

            # Empirical frequencies of the transitions between consecutive elements of the sequence.
            sequence_matrix = _np.zeros((size, size), dtype=float)

            for i in range(1, walk_steps):
                sequence_matrix[sequence[i - 1], sequence[i]] += 1.0

            sequence_matrix /= _np.sum(sequence_matrix)

            a_current_is = a_current.imshow(sequence_matrix, aspect='auto', cmap=color_map, interpolation='none', vmin=0.0, vmax=1.0)

            _xticks_labels(a_current, size, labels_name, labels, True)
            _yticks_labels(a_current, size, labels_name, labels)

            a_current.grid(which='minor', color='k')

        # All heatmaps share the same color map and [0, 1] range, so a single colorbar
        # is drawn once on the shared axis (instead of once per subplot on the same axis).
        color_map_ax, color_map_ax_kwargs = _mplcb.make_axes(a, drawedges=True, orientation='vertical', ticks=[0.0, 0.25, 0.5, 0.75, 1.0])
        f.colorbar(a_current_is, cax=color_map_ax, **color_map_ax_kwargs)

        f.suptitle('Sequence Plot (Heatmap)', fontsize=15.0, fontweight='bold')

        return f, a

    # noinspection DuplicatedCode
    def _plot_histogram(ph_walk_data, ph_dpi):

        _, walks = ph_walk_data
        plots_count = len(walks)

        f, a = _mplp.subplots(nrows=plots_count, tight_layout=True, dpi=ph_dpi)
        a = [a] if plots_count == 1 else list(a.flat)

        for a_current, (size, labels_name, labels, sequence) in zip(a, walks):

            # Relative occurrence frequency of each label, counted directly
            # rather than through a dense size x steps indicator matrix.
            sequence_histogram = _np.bincount(sequence, minlength=size) / len(sequence)

            a_current.bar(_np.arange(0.0, size, 1.0), sequence_histogram, edgecolor=_color_black, facecolor=_colors[0])

            _xticks_labels(a_current, size, labels_name, labels, False)
            _yticks_frequency(a_current, 0.0, 1.0)

        f.suptitle('Sequence Plot (Histogram)', fontsize=15.0, fontweight='bold')

        return f, a

    # noinspection DuplicatedCode
    def _plot_matrix(pm_walk_data, pm_dpi):

        walk_steps, walks = pm_walk_data
        plots_count = len(walks)

        f, a = _mplp.subplots(nrows=plots_count, tight_layout=True, dpi=pm_dpi)
        a = [a] if plots_count == 1 else list(a.flat)

        color_map = _mplcr.LinearSegmentedColormap.from_list('ColorMap', [_color_white, _colors[0]], 2)

        for a_current, (size, labels_name, labels, sequence) in zip(a, walks):

            # One-hot encoding of the sequence: rows are labels, columns are steps.
            sequence_matrix = _np.zeros((size, walk_steps), dtype=float)

            for index, state in enumerate(sequence):
                sequence_matrix[state, index] = 1.0

            a_current.imshow(sequence_matrix, aspect='auto', cmap=color_map, interpolation='none', vmin=0.0, vmax=1.0)

            _xticks_steps(a_current, walk_steps)
            _yticks_labels(a_current, size, labels_name, labels)

            a_current.grid(which='minor', color='k')

        f.suptitle('Sequence Plot (Matrix)', fontsize=15.0, fontweight='bold')

        return f, a

    try:

        model = _validate_model(model)
        steps = _validate_integer(steps, lower_limit=(2, False))
        initial_state = None if initial_state is None else _validate_label(initial_state, model.states)
        plot_type = _validate_enumerator(plot_type, ['heatmap', 'histogram', 'matrix'])
        dpi = _validate_dpi(dpi)

    except Exception as ex:  # pragma: no cover
        raise _create_validation_error(ex, _ins.trace()) from None

    model_mc = model.__class__.__name__ == 'MarkovChain'
    model_sequence = model.simulate(steps, initial_state=initial_state, output_indices=True, seed=seed)

    # Each walk entry is (number of labels, axis name, labels, sequence of indices);
    # hidden Markov models yield a pair of sequences (states, symbols).
    if model_mc:
        walks = [
            (model.n, 'States', model.states, model_sequence)
        ]
    else:
        walks = [
            (model.n, 'States', model.states, model_sequence[0]),
            (model.k, 'Symbols', model.symbols, model_sequence[1])
        ]

    # The walk length accounts for the initial element plus one element per step.
    walk_data = (steps + 1, walks)

    if plot_type == 'heatmap':
        func = _plot_heatmap
    elif plot_type == 'histogram':
        func = _plot_histogram
    else:
        func = _plot_matrix

    # Build the figure with interactive mode temporarily disabled, restoring it even on error.
    mpi = _mplp.isinteractive()
    _mplp.interactive(False)

    try:
        figure, ax = func(walk_data, dpi)
    finally:
        _mplp.interactive(mpi)

    if _mplp.isinteractive():  # pragma: no cover
        _mplp.show(block=False)
        return None

    return figure, ax
# noinspection PyBroadException
def plot_trellis(hmm: _thmm, steps: int, initial_state: _ostate = None, seed: _oint = None, force_standard: bool = False, dpi: int = 100) -> _oplot:

    """
    The function plots the trellis diagrams of a random walk on the given hidden Markov model.

    | **Notes:**

    * If `Matplotlib <https://matplotlib.org/>`_ is in `interactive mode <https://matplotlib.org/stable/users/interactive.html>`_, the plot is immediately displayed and the function does not return the plot handles.
    * `Graphviz <https://graphviz.org/>`_ and `pydot <https://pypi.org/project/pydot/>`_ are not required, but they provide access to extended mode with improved rendering and additional features.
    * The rendering of large simulations is not granted to be high-quality.
    * Red nodes and edges belong to the most probable states path calculated using the Viterbi algorithm.

    :param hmm: the hidden Markov model.
    :param steps: the number of steps.
    :param initial_state: the initial state of the random walk (*if omitted, it is chosen uniformly at random*).
    :param seed: a seed to be used as RNG initializer for reproducibility purposes.
    :param force_standard: a boolean indicating whether to use standard mode even if extended mode is available.
    :param dpi: the resolution of the plot expressed in dots per inch.
    :raises ValidationError: if any input argument is not compliant.
    :raises ValueError: if the computation of backward and forward probabilities fails or if the computation of the most probable states path fails.
    """

    def _generate_trellis_extended(gte_hmm, gte_initial_distribution, gte_symbols, gte_forward, gte_matrix, gte_states_path):

        # Builds a pydot graph: one cluster of state labels, then one cluster per time step.
        n, f = gte_hmm.n, len(gte_symbols)

        g = _pyd.Dot(graph_type='digraph', compound='true', margin='0')

        sub = _pyd.Subgraph('cluster_0', color='transparent')

        for row in range(n):

            sub_label = gte_hmm.states[row]
            sub.add_node(_pyd.Node(f'state{row}', color='transparent', label=sub_label, shape='plaintext'))

            # Invisible edges keep the labels stacked in state order.
            if row > 0:
                sub.add_edge(_pyd.Edge(f'state{row - 1}', f'state{row}', style='invis'))

        g.add_subgraph(sub)

        for col in range(f):

            path_current = gte_states_path[col]
            sub_label = "<T<FONT POINT-SIZE='8'><SUB>0</SUB></FONT>>" if col == 0 else gte_hmm.symbols[gte_symbols[col]]
            sub = _pyd.Subgraph(f'cluster_{col + 1}', color='transparent', label=sub_label)

            for row in range(n):

                node_index = (row * f) + col
                node_color = _default_color_path if path_current == row else _default_color_node
                sub.add_node(_pyd.Node(f'node{node_index}', fillcolor=node_color, label=f'{round(gte_matrix[row, col], 2):.2f}', style='filled'))

                if row > 0:
                    sub.add_edge(_pyd.Edge(f'node{((row - 1) * f) + col}', f'node{node_index}', style='invis'))

            g.add_subgraph(sub)

        for row in range(n):

            row_offset = row * f

            for col in range(f - 1):

                # States with zero initial probability emit no edges from the first time step.
                if col == 0 and not gte_initial_distribution[row] > 0.0:
                    continue

                for row_next in range(n):

                    if gte_hmm.p[row][row_next] > 0.0:

                        node_current = f'node{row_offset + col}'
                        node_next = f'node{(row_next * f) + col + 1}'
                        edge_from, edge_to = (node_current, node_next) if gte_forward else (node_next, node_current)

                        on_path = gte_states_path[col] == row and gte_states_path[col + 1] == row_next
                        edge_color = _default_color_path if on_path else _color_black

                        g.add_edge(_pyd.Edge(edge_from, edge_to, color=edge_color, constraint='false'))

        return g

    def _generate_trellis_standard(gts_hmm, gts_initial_distribution, gts_symbols, gts_forward, gts_matrix, gts_states_path):

        # Builds a networkx graph plus the drawing attributes aligned with its node and edge iteration order.
        n, f = gts_hmm.n, len(gts_symbols)

        trellis = _nx.DiGraph()
        node_colors, node_edges, node_labels, node_positions = [], [], {}, {}

        # Value nodes are indexed as row * f + col; the first state is drawn on top.
        node_index = 0

        for row in range(n):

            row_offset = float(n - row)

            for col in range(f):

                trellis.add_node(node_index)
                node_colors.append(_default_color_path if gts_states_path[col] == row else _default_color_node)
                node_edges.append(_color_black)
                node_labels[node_index] = f'{round(gts_matrix[row, col], 2):.2f}'
                node_positions[node_index] = (col + 1.0, row_offset)

                node_index += 1

        for row in range(n):

            row_offset = row * f

            for col in range(f - 1):

                # States with zero initial probability emit no edges from the first time step.
                if col == 0 and not gts_initial_distribution[row] > 0.0:
                    continue

                for row_next in range(n):

                    if gts_hmm.p[row][row_next] > 0.0:

                        node_current = row_offset + col
                        node_next = (row_next * f) + col + 1

                        # The edge joins (row, col) and (row_next, col + 1) whatever its direction.
                        on_path = gts_states_path[col] == row and gts_states_path[col + 1] == row_next
                        edge_color = _default_color_path if on_path else _color_black

                        if gts_forward:
                            trellis.add_edge(node_current, node_next, color=edge_color)
                        else:
                            trellis.add_edge(node_next, node_current, color=edge_color)

        # networkx draws edges in trellis.edges() order (grouped by source node), which differs from
        # the insertion order for the backward trellis, so colors are read back in that order.
        edge_colors = [edge_color for _, _, edge_color in trellis.edges(data='color')]

        # Invisible nodes used as row labels (state names).
        for row in range(n):

            trellis.add_node(node_index)
            node_colors.append('none')
            node_edges.append('none')
            node_labels[node_index] = gts_hmm.states[row]
            node_positions[node_index] = (0.6, float(n - row))

            node_index += 1

        # Invisible nodes used as column headers (time step marker and emitted symbols).
        headers = [r'$\mathregular{T_0}$'] + [gts_hmm.symbols[symbol] for symbol in gts_symbols[1:]]

        for col, header in enumerate(headers):

            trellis.add_node(node_index)
            node_colors.append('none')
            node_edges.append('none')
            node_labels[node_index] = header
            node_positions[node_index] = (col + 1.0, n + 0.35)

            node_index += 1

        return trellis, node_colors, node_edges, node_labels, node_positions, edge_colors

    def _plot_extended(ps_hmm, ps_initial_distribution, ps_symbols, ps_backward, ps_forward, ps_states_path, ps_dpi):

        f, a = _mplp.subplots(nrows=2, tight_layout=True, dpi=ps_dpi)
        a = list(a.flat)

        panels = (
            (a[0], False, ps_backward, 'Backward Trellis'),
            (a[1], True, ps_forward, 'Forward Trellis')
        )

        for ax_current, forward_flag, matrix, title in panels:

            trellis = _generate_trellis_extended(ps_hmm, ps_initial_distribution, ps_symbols, forward_flag, matrix, ps_states_path)
            img, _, _, _, _ = _decode_image(trellis, ps_dpi)

            ax_current.imshow(img)
            ax_current.axis('off')
            ax_current.set_title(title, fontsize=15.0, fontweight='normal', pad=1)

        return f, a

    def _plot_standard(ps_hmm, ps_initial_distribution, ps_symbols, ps_backward, ps_forward, ps_states_path, ps_dpi):

        y_top = ps_hmm.n + 0.5

        f, a = _mplp.subplots(nrows=2, tight_layout=True, dpi=ps_dpi)
        a = list(a.flat)

        panels = (
            (a[0], False, ps_backward, 'Backward Trellis'),
            (a[1], True, ps_forward, 'Forward Trellis')
        )

        for ax_current, forward_flag, matrix, title in panels:

            trellis, node_colors, node_edges, node_labels, node_positions, edge_colors = _generate_trellis_standard(ps_hmm, ps_initial_distribution, ps_symbols, forward_flag, matrix, ps_states_path)

            _nx.draw_networkx(trellis, node_positions, ax=ax_current, edgecolors=node_edges, edge_color=edge_colors, font_size=9, labels=node_labels, node_color=node_colors, node_size=_default_node_size)

            ax_current.set_ylim(0.5, y_top)
            ax_current.axis('off')
            ax_current.set_title(title, fontsize=15.0, fontweight='normal', pad=1)

        return f, a

    try:

        hmm = _validate_hidden_markov_model(hmm)
        steps = _validate_integer(steps, lower_limit=(2, False))
        initial_state = None if initial_state is None else _validate_label(initial_state, hmm.states)
        force_standard = _validate_boolean(force_standard)
        dpi = _validate_dpi(dpi)

    except Exception as ex:  # pragma: no cover
        raise _create_validation_error(ex, _ins.trace()) from None

    # Uniform initial distribution, or all the mass on the requested initial state.
    if initial_state is None:
        initial_distribution = _np.full(hmm.n, 1.0 / hmm.n, dtype=float)
    else:
        initial_distribution = _np.zeros(hmm.n, dtype=float)
        initial_distribution[initial_state] = 1.0

    _, symbols = hmm.simulate(steps, initial_state=initial_state, output_indices=True, seed=seed)

    decoding = hmm.decode(symbols, initial_status=initial_distribution)

    if decoding is None:  # pragma: no cover
        raise ValueError('The computation of backward and forward probabilities failed.')

    _, _, backward, forward, _ = decoding

    prediction = hmm.predict('mle', symbols, initial_status=initial_distribution, output_indices=True)

    if prediction is None:  # pragma: no cover
        raise ValueError('The computation of the most probable states path failed.')

    _, states_path = prediction

    # Extended mode also requires the Graphviz "dot" executable; probe for it.
    extended_graph = not force_standard and _pydot_found

    if extended_graph:
        try:
            _sub.call(['dot', '-V'], stdout=_sub.PIPE, stderr=_sub.PIPE)
        except Exception:  # pragma: no cover
            extended_graph = False

    func = _plot_extended if extended_graph else _plot_standard

    # Build the figure with interactive mode temporarily disabled, restoring it even on error.
    mpi = _mplp.isinteractive()
    _mplp.interactive(False)

    try:
        figure, ax = func(hmm, initial_distribution, symbols, backward, forward, states_path, dpi)
    finally:
        _mplp.interactive(mpi)

    if _mplp.isinteractive():  # pragma: no cover
        _mplp.show(block=False)
        return None

    return figure, ax