"""Temporal Interval Network Density Analysis (TINDA).
This module contains functions for calculating the density profile (i.e.,
fractional occupancy over) in any interval between events it is originally
intended to use it on an HMM state time course to ask questions like what is
the density of state :math:`j` in the first and second part of the interval
between visits to state :math:`i`.
See Also
--------
`Example scripts <https://github.com/OHBA-analysis/osl-dynamics/blob/main\
/examples/tinda>`_ for applying TINDA.
"""
import logging
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import matplotlib.pyplot as plt
_logger = logging.getLogger("osl-dynamics")
[docs]
def find_intervals(tc_hot: np.ndarray) -> Tuple[List[Tuple[int, int]], np.ndarray]:
"""Find intervals (periods where :code:`tc_hot` is zero) in a hot vector.
Parameters
----------
tc_hot : array_like
Hot vector (i.e., binary vector) of shape (n_samples,) or
(n_samples, 1). For example, a hot vector of a state time course
of shape (n_samples, n_states).
Returns
-------
intervals : list
List of tuples of start and end indices of intervals.
durations : array_like
Array of durations of intervals (in samples).
"""
intervals = []
durations = []
tc_tmp = np.insert(np.insert(tc_hot, 0, 0, axis=0), -1, 0, axis=0)
start = np.where(np.diff(tc_tmp) == 1)[0]
end = np.where(np.diff(tc_tmp) == -1)[0]
intervals = list(zip(end[:-1], start[1:]))
durations = np.diff(intervals, axis=1).squeeze()
return intervals, durations
[docs]
def split_intervals(
intervals: List[Tuple[int, int]], n_bins: int = 2
) -> Tuple[List, List, np.ndarray]:
"""Splits each interval into :code:`nbin` equally sized bins.
Parameters
----------
intervals : list
List of tuples of start and end indices of intervals.
n_bins : int, optional
Number of bins to split each interval into.
Returns
-------
divided_intervals : list
List the same length as intervals (minus dropped intervals, see below),
with each element being a list of tuples of start and end indices of
bins.
bin_sizes : list
List of bin sizes (in samples), one per interval.
drop_mask : array_like
Array of zeros and ones indicating whether the interval was dropped
because it was smaller than :code:`n_bins`.
"""
divided_intervals = []
bin_sizes = []
drop_mask = np.zeros(len(intervals))
for i, (interval_start, interval_end) in enumerate(intervals):
n_samples = interval_end - interval_start
if n_samples < n_bins:
drop_mask[i] = 1
continue
bin_size = n_samples // n_bins
remainder = n_samples % n_bins
bins = []
bin_start = interval_start
for i in range(n_bins):
bin_end = bin_start + bin_size
bins.append((bin_start, bin_end))
if remainder > 0:
if remainder == n_bins - 1:
bin_end += 1
elif i == (n_bins // 2 - 1):
bin_end += remainder
bin_start = bin_end
divided_intervals.append(bins)
bin_sizes.append(bin_size)
return divided_intervals, bin_sizes, drop_mask
[docs]
def split_interval_duration(
durations: np.ndarray,
interval_range: Optional[np.ndarray] = None,
mode: str = "sample",
sampling_frequency: Optional[float] = None,
) -> Tuple[List, Optional[np.ndarray]]:
"""Split interval durations into bins based on their duration.
Parameters
----------
durations : array_like
Array of durations of intervals (in samples).
interval_range : array_like, optional
Array of bin edges (in samples, seconds, or percentiles) to split
durations into bins are defined as :code:`[>=interval_range[i],
<interval_range[i+1])`. If :code:`None`, all durations are in the
same bin.
mode : str, optional
Mode of interval_range, either :code:`"sample"` (e.g.,
:code:`[4, 20, 100]`), :code:`"perc"` (e.g., :code:`range(20,100,20)`),
or :code:`"sec"` (e.g., :code:`[0, 0.01, 0.1, 1, np.inf]`). If
:code:`"sec"`, :code:`sfreq` must be provided.
sampling_frequency : float, optional
Sampling frequency (in Hz) of the data, only used if :code:`mode` is
:code:`"sec"`.
Returns
-------
mask : list
List of arrays of zeros and ones indicating whether the interval was
in the bin.
interval_range : array_like
Array of bin edges (in samples) used to split durations into bins.
"""
if interval_range is None:
mask = [np.ones_like(durations)]
interval_range = [np.min(durations), np.max(durations)]
else:
if mode == "sec" and sampling_frequency is None:
raise ValueError(
"Sampling frequency (sfreq) must be specified when mode is 'sec'"
)
return [np.ones_like(durations)], None
else:
if mode == "sec":
interval_range = [r * sampling_frequency for r in interval_range]
elif mode == "perc":
interval_range = np.percentile(durations, interval_range)
mask = []
for i, start in enumerate(interval_range):
if i < len(interval_range) - 1:
hot_vector = np.logical_and(
durations >= start, durations < interval_range[i + 1]
)
mask.append(hot_vector.astype(int))
return mask, interval_range
[docs]
def compute_fo_stats(
tc_sec: np.ndarray,
divided_intervals: List,
interval_mask: Optional[np.ndarray] = None,
return_all_intervals: bool = False,
) -> Tuple[np.ndarray, np.ndarray, Optional[List], Optional[List]]:
"""Compute sums and weighted averages of time courses in each interval.
Parameters
----------
tc_sec : array_like
Time course of shape (n_samples, n_states).
divided_intervals : list
List with each element corresponding to an interval, each itself being a
list of tuples of start and end indices of interval bins.
interval_mask : array_like, optional
Array of zeros and ones indicating whether the interval was in the bin.
return_all_intervals : bool, optional
Whether to return the density/sum of all intervals in addition to the
interval averages/sums.
Returns
-------
interval_weighted_avg : array_like
Array of weighted averages of time courses in each interval of shape
(n_states, n_bins, n_interval_ranges).
interval_sum : array_like
Array of sums of time courses in each interval of shape (n_states,
n_bins, n_interval_ranges).
interval_weighted_avg_all : list
List of length n_interval_ranges with each element an array of weighted
averages of time courses in each interval of shape (n_states, n_bins,
n_intervals). :code:`None` if :code:`return_all_intervals=False`
(default).
interval_sum_all : list
List of length :code:`n_interval_ranges` with each element an array of
sums of time courses in each interval of shape (n_states, n_bins,
n_intervals). :code:`None` if :code:`return_all_intervals=False`.
"""
if interval_mask is None:
interval_mask = [np.ones(len(divided_intervals))]
if (
len(divided_intervals[0]) == 2
): # this corresponds to the matlab code but only works for two bins
# and I think it's less principled than the code below
intervals = []
for interval in divided_intervals:
intervals.append([interval[0][0], interval[-1][-1]])
interval_weighted_avg = np.zeros((tc_sec.shape[1], 2, len(interval_mask)))
interval_sum = np.zeros((tc_sec.shape[1], 2, len(interval_mask)))
temp_to = []
temp_away = []
for i in intervals:
d = int(np.floor((np.diff(i) - 1) / 2))
temp_away.append(tc_sec[i[0] : i[0] + d + 1, :])
temp_to.append(tc_sec[i[1] - d - 1 : i[1], :])
interval_weighted_avg_all = []
interval_sum_all = []
for j, mask in enumerate(interval_mask):
temp_to_flat = np.concatenate([temp_to[k] for k in np.where(mask == 1)[0]])
temp_away_flat = np.concatenate(
[temp_away[k] for k in np.where(mask == 1)[0]]
)
interval_weighted_avg[:, 0, j] = np.mean(temp_away_flat, axis=0)
interval_weighted_avg[:, 1, j] = np.mean(temp_to_flat, axis=0)
interval_sum[:, 0, j] = np.sum(temp_away_flat, axis=0)
interval_sum[:, 1, j] = np.sum(temp_to_flat, axis=0)
interval_weighted_avg_all.append(
np.transpose(
np.stack(
[
np.stack(
[
temp_away[k].mean(axis=0)
for k in np.where(mask == 1)[0]
],
axis=-1,
),
np.stack(
[
temp_to[k].mean(axis=0)
for k in np.where(mask == 1)[0]
],
axis=-1,
),
]
),
axes=[1, 0, 2],
)
)
interval_sum_all.append(
np.transpose(
np.stack(
[
np.stack(
[
temp_away[k].sum(axis=0)
for k in np.where(mask == 1)[0]
],
axis=-1,
),
np.stack(
[
temp_to[k].sum(axis=0)
for k in np.where(mask == 1)[0]
],
axis=-1,
),
]
),
axes=[1, 0, 2],
)
)
else:
# TODO: I think this is more principled than the matlab code,
# but I need to check
interval_sum = np.zeros(
(tc_sec.shape[1], len(divided_intervals[0]), len(divided_intervals))
)
interval_weighted_avg = np.zeros(
(tc_sec.shape[1], len(divided_intervals[0]), len(divided_intervals))
)
for i, interval in enumerate(divided_intervals):
for j, (start, end) in enumerate(interval):
interval_sum[:, j, i] = np.sum(tc_sec[start:end, :], axis=0)
interval_weighted_avg[:, :, i] = interval_sum[:, :, i] / (end - start)
interval_weighted_avg_all = [
interval_weighted_avg[:, :, interval_selection == 1]
for interval_selection in interval_mask
]
interval_sum_all = [
interval_sum[:, :, interval_selection == 1]
for interval_selection in interval_mask
]
interval_weighted_avg = np.stack(
[weighted_avg.mean(axis=-1) for weighted_avg in interval_weighted_avg_all],
axis=-1,
)
interval_sum = np.stack(
[int_sum.mean(axis=-1) for int_sum in interval_sum_all], axis=-1
)
if return_all_intervals:
return (
interval_weighted_avg,
interval_sum,
interval_weighted_avg_all,
interval_sum_all,
)
else:
return interval_weighted_avg, interval_sum, None, None
[docs]
def collate_stats(
stats: List[Dict],
field: str,
all_to_all: bool = False,
ignore_elements: Optional[List[int]] = None,
) -> np.ndarray:
"""Collate list of stats (e.g., of different states) into a single array.
Parameters
----------
stats : list
List of stats (:code:`dict`) for each state. Each element is a
dictionary with keys that at least should include "field" (e.g.,
:code:`interval_wavg`), that is the output of :code:`compute_fo_stats`.
field : str
Field of stats to collate, e.g., :code:`"interval_wavg"`,
:code:`"interval_sum"`.
all_to_all : bool, optional
Whether the density_of was used to compute the stats (in which case
the first 2 dimensions are not :code:`n_states` x :code:`n_states`).
Default is :code:`False`.
ignore_elements : list, optional
List of indices in stats to ignore (i.e. because they don't contain
binary events).
Returns
-------
collated_stat : array_like
The collated stat (:code:`n_interval_states`, :code:`n_density_states`,
:code:`n_bins`, :code:`n_interval_ranges`). If :code:`all_to_all=False`
(default) (i.e., when the density is computed for all states using all
states' intervals), then the first two dimensions are :code:`n_states`
and the diagonal is :code:`np.nan`.
"""
if ignore_elements is None:
ignore_elements = []
num_states = len(stats)
shp = stats[0][field].shape # (n_states, n_bins, n_interval_ranges)
if num_states > 1:
if all_to_all:
collated_stat = np.full((2 * [num_states] + list(shp[1:])), np.nan)
else:
collated_stat = np.full(([num_states] + list(shp)), np.nan)
for i in range(num_states):
if i in ignore_elements:
# intervals are not binary, keep a row of nans
continue
if all_to_all:
collated_stat[i, np.arange(num_states) != i, ...] = stats[i][field]
else:
collated_stat[i] = stats[i][field]
else:
collated_stat = stats[0][field]
return collated_stat
[docs]
def tinda(
tc: Union[np.ndarray, List[np.ndarray]],
density_of: Optional[Union[np.ndarray, List[np.ndarray]]] = None,
n_bins: int = 2,
interval_mode: Optional[str] = None,
interval_range: Optional[np.ndarray] = None,
sampling_frequency: Optional[float] = None,
return_all_intervals: bool = False,
) -> Tuple[np.ndarray, np.ndarray, Union[List[Dict], Dict]]:
"""Compute time-in-state density and sum for each interval.
Parameters
----------
tc : array_like
Time courses of shape (n_samples, n_states) define intervals from will
use the same time courses to compute density of when :code:`density_of`
is :code:`None`. Can be a list of time courses (e.g. state time courses
for each individual).
density_of : array_like, optional
Time course of shape (n_samples, n_states) to compute density of if
:code:`None` (default), density is computed for all columns of tc.
n_bins : int, optional
Number of bins to divide each interval into (default 2).
interval_mode : str, optional
Mode of :code:`interval_range`, either :code:`"sample"` (default),
"sec" (seconds) or "perc" (percentile). To interpret the interval range
as seconds, :code:`sfreq` must be provided.
interval_range : array_like, optional
Array of bin edges (in samples, seconds, or percentiles) used to split
durations into bins (default :code:`None`), e.g.
:code:`np.arange(0, 1, 0.1)` for 100 ms bins.
sampling_frequency : float, optional
Sampling frequency of tc (in Hz), only used if
:code:`interval_mode="sec"`.
return_all_intervals : bool, optional
Whether to return the density/sum of all intervals in addition to the
interval averages/sums. If :code:`True`, will return a list of arrays in
:code:`stats[i]['all_interval_wavg'/'all_interval_sum']`, each
corresponding to an interval range.
Returns
-------
fo_density : array_like
Time-in-state densities array of shape (n_interval_states,
n_density_states, n_bins, n_interval_ranges). :code:`n_interval_states`
is the number of states in the interval time courses (i.e., tc);
:code:`n_density_states` is the number of states in the density time
courses (i.e., :code:`density_of`). If :code:`density_of` is
:code:`None`, :code:`n_density_states` is the same as
:code:`n_interval_states`. If tc is a list of time courses (e.g., state
time courses for multiple individuals), then an extra dimension is appended
for the individuals.
fo_sum : array_like
Same as :code:`fo_density`, but with time-in-state sums instead of
densities.
stats : dict
Dictionary of stats, including
- :code:`durations`: interval durations in samples.
- :code:`intervals`: start/end samples for each interval (intervals).
- :code:`interval_wavg`: the weighted average (i.e, time-in-state
density) over all interval.
- :code:`interval_sum`: the sum (i.e., time-in-state) over all
intervals.
- :code:`divided_intervals`: the bin edges for each interval.
- :code:`bin_sizes`: the bin sizes for each interval.
- :code:`interval_range`: the interval range (in samples).
- :code:`all_interval_wavg`: unaveraged interval densities (only if
:code:`return_all_intervals=True`).
- :code:`all_interval_sum`: unaveraged interval sums (only if
:code:`return_all_intervals=True`).
"""
if isinstance(
tc, list
): # list of time courses (e.g., individuals' HMM state time courses)
if density_of is None:
fo_density_tmp, fo_sum_tmp, stats = zip(
*[
tinda(
itc,
None,
n_bins,
interval_mode,
interval_range,
sampling_frequency,
return_all_intervals,
)
for itc in tc
]
)
elif len(density_of) == len(tc):
fo_density_tmp, fo_sum_tmp, stats = zip(
*[
tinda(
itc,
density_of[ix],
n_bins,
interval_mode,
interval_range,
sampling_frequency,
return_all_intervals,
)
for ix, itc in enumerate(tc)
]
)
fo_density = np.stack(fo_density_tmp, axis=-1)
fo_sum = np.stack(fo_sum_tmp, axis=-1)
else:
stats = []
dim = tc.shape
ignore_elements = []
for i in range(dim[1]):
itc_prim = tc[:, i]
if np.all(itc_prim == 0):
_logger.info(f"Skipping state {i}: no activations detected.")
stats.append(
{
"durations": [],
"intervals": [],
"interval_wavg": np.full((tc.shape[1] - 1, n_bins, 1), 0),
"interval_sum": np.full((tc.shape[1] - 1, n_bins, 1), 0),
"divided_intervals": [],
"bin_sizes": [],
"interval_range": [],
"all_interval_wavg": [],
"all_interval_sum": [],
}
)
continue
if not np.array_equal(
itc_prim, itc_prim.astype(int)
): # if not binary (i.e., intervals are not well defined)
stats.append(None)
ignore_elements.append(i)
else:
if density_of is None:
# we're doing density of all states in all states' intervals
itc_sec = tc[:, np.setdiff1d(range(dim[1]), i)]
else:
itc_sec = density_of
# get interval info
intervals, durations = find_intervals(itc_prim)
divided_intervals, bin_sizes, dropped_intervals = split_intervals(
intervals, n_bins
) # split intervals into nbin
durations = durations[
dropped_intervals == 0
] # drop intervals that are too short to be split into nbin
interval_mask, interval_range_samples = split_interval_duration(
durations,
interval_range=interval_range,
mode=interval_mode,
sampling_frequency=sampling_frequency,
) # split intervals into interval_range (i.e.,
# to compute statistics of intervals with durations in a
# certain range)
# Compute time-in-state densities and sums in all intervals
(
interval_wavg,
interval_sum,
all_interval_wavg,
all_interval_sum,
) = compute_fo_stats(
itc_sec,
divided_intervals,
interval_mask,
return_all_intervals=return_all_intervals,
)
# Append stats
stats.append(
{
"durations": durations,
"intervals": intervals,
"interval_wavg": interval_wavg,
"interval_sum": interval_sum,
"divided_intervals": divided_intervals,
"bin_sizes": bin_sizes,
"interval_range": interval_range_samples,
"all_interval_wavg": all_interval_wavg,
"all_interval_sum": all_interval_sum,
}
)
# Get a full matrix of FO densities and sums
fo_density = collate_stats(
stats,
"interval_wavg",
all_to_all=density_of is None,
ignore_elements=ignore_elements,
)
fo_sum = collate_stats(
stats,
"interval_sum",
all_to_all=density_of is None,
ignore_elements=ignore_elements,
)
return fo_density, fo_sum, stats
[docs]
def circle_angles(order: List[int]) -> np.ndarray:
"""Compute the phase differences between states in a circular plot.
Parameters
----------
order : list
List of state orders (in order of counterclockwise
rotation).
Returns
-------
angleplot : array_like
Array of phase differences between states in a circular plot.
"""
K = len(order)
disttoplot_manual = np.zeros(K, dtype=complex)
for i3 in range(K):
disttoplot_manual[order[i3]] = np.exp(1j * (i3 + 1) / K * 2 * np.pi)
angleplot = np.exp(
1j
* (
np.angle(disttoplot_manual[:, np.newaxis]).T
- np.angle(disttoplot_manual[:, np.newaxis])
)
)
return angleplot
[docs]
def optimise_sequence(
fo_density: np.ndarray, metric_to_use: int = 0, n_perms: int = 10**6
) -> np.ndarray:
"""Optimise the sequence to maximal circularity.
This function reads in the mean pattern of differential fractional
occupancy and computes the optimal display for a sequential circular plot
visualization.
Parameters
----------
fo_density : array_like
Time-in-state densities array of shape (n_interval_states,
n_density_states, 2, n_sessions).
metric : int, optional
Metric to use for optimisation:
- :code:`0`: mean FO asymmetry.
- :code:`1`: proportional FO asymmetry (i.e. asymmetry as a proportion
of a baseline - which time spend in the state).
- :code:`2`: proportional FO asymmetry using global baseline FO, rather
than a individual-specific baseline.
Returns
-------
best_sequence : list
List of best sequence of states to plot (in order of counterclockwise
rotation).
"""
if len(fo_density.shape) == 5:
fo_density = np.squeeze(fo_density)
# make sure there are no nans:
fo_density[np.isnan(fo_density)] = 0
# Compute different metrics to optimise
metric = []
metric.append(np.mean(fo_density[:, :, 0, :] - fo_density[:, :, 1, :], axis=2))
temp = (fo_density[:, :, 0, :] - fo_density[:, :, 1, :]) / np.mean(
fo_density, axis=2
)
temp[np.isnan(temp)] = 0
metric.append(np.mean(temp, axis=2))
metric.append(
np.mean(fo_density[:, :, 0, :] - fo_density[:, :, 1, :], axis=2)
/ np.mean(fo_density, axis=(2, 3))
)
n_metrics = len(metric)
K = fo_density.shape[0]
best_sequence = []
for i in range(n_metrics):
ix = np.arange(K)
v = np.imag(np.sum(circle_angles(ix) * metric[i]))
cnt = 0
while cnt < n_perms:
cnt += 1
swaps = np.random.permutation(K)
swaps = swaps[:2]
tmpix = ix.copy()
tmpix[swaps[0]] = ix[swaps[1]]
tmpix[swaps[1]] = ix[swaps[0]]
tmpv = np.imag(np.sum(circle_angles(tmpix) * metric[i]))
if tmpv < v:
v = tmpv
ix = tmpix
best_sequence.append(np.roll(ix, -np.where([iix == 0 for iix in ix])[0][0]))
# Return the best sequence for the chosen metric (in order of counterclockwise
# rotation)
return best_sequence[metric_to_use]
[docs]
def compute_cycle_strength(
angleplot: np.ndarray,
asym: np.ndarray,
relative: bool = True,
whichstate: Optional[int] = None,
) -> np.ndarray:
"""Compute cycle strength."""
if len(asym.shape) == 3:
tmp = np.stack(
[angleplot * asym[:, :, i] for i in range(asym.shape[2])], axis=-1
)
else:
tmp = angleplot * asym
if whichstate is not None:
# Note that we are counting each (i,j) double because for the rotational
# momentum per state we take into account (i,j) and (j,i) for all j and one
# particular i.
tmp = np.squeeze(
tmp[
whichstate,
:,
]
) + np.squeeze(tmp[:, whichstate])
cycle_strength = np.imag(np.nansum(tmp, axis=0))
else:
cycle_strength = np.imag(np.nansum(tmp, axis=(0, 1)))
# positive rotational momentum should indicate clockwise cycle
cycle_strength = -cycle_strength
if relative: # normalise by the theoretical maximum
cycle_strength = cycle_strength / np.abs(
compute_cycle_strength(
angleplot,
np.sign(np.imag(angleplot)),
relative=False,
whichstate=whichstate,
)
)
return cycle_strength
[docs]
def plot_cycle(
ordering: List[int],
fo_density: np.ndarray,
edges: np.ndarray,
new_figure: bool = False,
color_scheme: Optional[np.ndarray] = None,
) -> None:
"""Plot state network as circular diagram with arrows.
Parameters
----------
ordering : list
List of best sequence of states to plot (in order of counterclockwise
rotation).
fo_density : array_like
Time-in-state densities array of shape (n_interval_states,
n_density_states, 2, (n_interval_ranges,) n_sessions).
edges : array_like
Array of zeros and ones indicating whether the connection should be
plotted.
new_figure : bool, optional
Whether to create a new figure (default is :code:`False`).
color_scheme : array_like, optional
Array of size (K,3) color scheme to use for plotting (default is
:code:`None`). If :code:`None`, will use the default color scheme from
the matlab code.
"""
# Plot state network as circular diagram with arrows
if color_scheme is None:
color_scheme = np.array(
[
[0, 0, 1.0000],
[1.0000, 0.3333, 0],
[
1.0000,
0.6667,
0,
],
[
0.6667,
1.0000,
0.3333,
],
[
0.3333,
1.0000,
0.6667,
],
[
0,
1.0000,
1.0000,
],
[
0.5529,
0.8275,
0.7804,
],
[
1.0000,
0.5000,
0.5000,
],
[
0,
0.6667,
1.0000,
],
[
1.0000,
1.0000,
0,
],
[
0.7451,
0.7294,
0.8549,
],
[
0.6667,
0,
0,
],
]
)
if new_figure:
plt.figure(figsize=(6.02, 4.52), dpi=100)
else:
plt.gca()
K = len(ordering)
if len(fo_density.shape) == 5:
fo_density = np.squeeze(
fo_density
) # squeeze in case there is still a interval_ranges dimension
# compute mean direction of arrows
mean_direction = np.squeeze(
(fo_density[:, :, 0, :] - fo_density[:, :, 1, :]).mean(axis=2)
)
# reorder the states to match the ordering:
ordering = np.roll(
ordering[::-1], 1
) # rotate ordering from clockwise to counter clockwise
edges = edges[ordering][:, ordering]
mean_direction = mean_direction[ordering][:, ordering]
# get the locations on the unit circle
theta = np.arange(0, 2 * np.pi, 2 * np.pi / K)
x = np.roll(np.cos(theta), int(K / 4)) # start from 12 o'clock
y = np.roll(np.sin(theta), int(K - (K / 4)))
distance_to_plot_manual = np.stack([x, y]).T
# plot the scatter points with state identities
for i in range(K):
plt.scatter(
distance_to_plot_manual[i, 0],
distance_to_plot_manual[i, 1],
s=400,
color=color_scheme[ordering[i], :],
)
plt.text(
distance_to_plot_manual[i, 0],
distance_to_plot_manual[i, 1],
str(ordering[i] + 1),
horizontalalignment="center",
verticalalignment="center",
fontsize=16,
)
# plot the arrows
for ik1 in range(K):
for k2 in range(K):
if edges[ik1, k2]:
# arrow lengths have to be proportional to the distance
# between the states. Use Pythagoras:
line_scale = np.sqrt(
np.sum(
(
distance_to_plot_manual[k2, :]
- distance_to_plot_manual[ik1, :]
)
** 2
)
)
arrow_start = (
distance_to_plot_manual[ik1, :]
+ 0.1
* (distance_to_plot_manual[k2, :] - distance_to_plot_manual[ik1, :])
/ line_scale
)
arrow_end = (
distance_to_plot_manual[k2, :]
- 0.1
* (distance_to_plot_manual[k2, :] - distance_to_plot_manual[ik1, :])
/ line_scale
)
if mean_direction[ik1, k2] > 0: # arrow from k1 to k2:
plt.arrow(
arrow_start[0],
arrow_start[1],
arrow_end[0] - arrow_start[0],
arrow_end[1] - arrow_start[1],
head_width=0.05,
head_length=0.1,
length_includes_head=True,
color="k",
)
elif mean_direction[ik1, k2] < 0: # arrow from k2 to k1:
plt.arrow(
arrow_end[0],
arrow_end[1],
arrow_start[0] - arrow_end[0],
arrow_start[1] - arrow_end[1],
head_width=0.05,
head_length=0.1,
length_includes_head=True,
color="k",
)
plt.axis("off")
plt.axis("equal")