"""
HMM: Plotting fMRI Networks
===========================

For an HMM trained on fMRI data, we normally plot the networks directly from the dual-estimated means and covariances. This tutorial covers:

1. Load means/covariances
2. Spatial maps
3. Functional connectivity networks

Note, we assume a group ICA parcellation was used and the HMM was trained on the ICA time courses.
"""

#%%
# Load means/covariances
# ^^^^^^^^^^^^^^^^^^^^^^
# First we need to load the dual-estimated means and covariances.
#
# .. code-block:: python
#
#     import numpy as np
#
#     means = np.load("results/dual_estimates/means.npy")
#     covs = np.load("results/dual_estimates/covs.npy")
#     print(means.shape)
#     print(covs.shape)

#%%
# `means` is a (subjects, states, channels) array and `covs` is a (subjects, states, channels, channels) array.
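#
# As a minimal sketch of these shapes, the cell below builds synthetic arrays and averages them over the subject axis. The sizes and the `example_*` names are hypothetical and are used only for illustration; they are not outputs of a trained HMM.

import numpy as np

# Hypothetical sizes, chosen only for illustration
n_subjects, n_states, n_channels = 5, 8, 25

rng = np.random.default_rng(0)

# Stand-in for the dual-estimated means: (subjects, states, channels)
example_means = rng.standard_normal((n_subjects, n_states, n_channels))

# Stand-in for the dual-estimated covariances: symmetric matrices of shape
# (subjects, states, channels, channels)
a = rng.standard_normal((n_subjects, n_states, n_channels, n_channels))
example_covs = a @ np.swapaxes(a, -1, -2) / n_channels

print(example_means.shape)  # (5, 8, 25)
print(example_covs.shape)  # (5, 8, 25, 25)

# Averaging over the first (subject) axis gives one set of state parameters
# for the whole group
example_group_mean = np.mean(example_means, axis=0)  # (states, channels)
example_group_cov = np.mean(example_covs, axis=0)  # (states, channels, channels)

#%%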
#
# Spatial maps
# ^^^^^^^^^^^^
#
# Volumetric parcellation
# ***********************
# The spatial activity maps correspond to the `means` or, if we did not learn a mean, to the diagonal of the `covs`. How we plot the spatial maps depends on how the data were preprocessed. If we used a volumetric parcellation, we can plot the spatial maps with:
#
# .. code-block:: python
#
#     from osl_dynamics.analysis import power
#
#     # Calculate a group average
#     group_mean = np.mean(means, axis=0)
#
#     # Plot
#     fig, ax = power.save(
#         group_mean,
#         mask_file="MNI152_T1_2mm_brain.nii.gz",
#         parcellation_file="melodic_IC.nii.gz",  # this should be the group-ICA spatial maps from FSL
#         plot_kwargs={"views": ["lateral", "medial"]},
#     )

#%%
# Surface parcellation
# ********************
# Alternatively, if we used a surface parcellation, we can use Connectome Workbench to plot the spatial maps:
#
# .. code-block:: python
#
#     import shutil
#
#     from osl_dynamics.analysis import workbench
#
#     # Save a CIFTI file containing the state maps
#     cifti_file = "results/inf_params/means.dscalar.nii"
#     power.independent_components_to_surface_maps(
#         ica_spatial_maps="melodic_IC.dscalar.nii",
#         ic_values=group_mean,
#         output_file=cifti_file,
#     )
#
#     # Plot
#     workbench.setup("/path/to/workbench/bin_macosxub")
#     workbench.render(
#         img=cifti_file,
#         gui=False,
#         save_dir="tmp",
#         image_name="plots/means_.png",
#         input_is_cifti=True,
#     )
#
#     # Delete the temporary directory
#     shutil.rmtree("tmp")

#%%
# You can download workbench `here <https://www.humanconnectome.org/software/get-connectome-workbench>`_.
#
# If you did not learn the mean, then replace `means` above with the diagonal of the covariances.
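#
# A minimal sketch of extracting the diagonal is given below. It uses a synthetic `example_covs` array with hypothetical sizes in place of the loaded `covs`; with real data, the same `np.diagonal` call would be applied to `covs` instead.

import numpy as np

# Synthetic stand-in for `covs`: (subjects, states, channels, channels).
# The sizes are hypothetical and chosen only for illustration.
rng = np.random.default_rng(0)
a = rng.standard_normal((5, 8, 25, 25))
example_covs = a @ np.swapaxes(a, -1, -2) / 25

# Take the diagonal of each (channels, channels) matrix, i.e. the variance of
# each channel. The result has shape (subjects, states, channels).
example_variances = np.diagonal(example_covs, axis1=-2, axis2=-1)

# Group average with shape (states, channels), which can be used in place of
# the group-averaged means
example_group_variance = np.mean(example_variances, axis=0)

#%%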
#
# Functional Connectivity Networks
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The off-diagonal elements of the covariances correspond to the functional connectivity. In osl-dynamics, we can only plot these if we used a volumetric parcellation:
#
# .. code-block:: python
#
#     from osl_dynamics.analysis import connectivity
#
#     # Calculate a group average
#     group_cov = np.mean(covs, axis=0)
#
#     connectivity.save(
#         group_cov,
#         parcellation_file="melodic_IC.nii.gz",  # this should be the group-ICA spatial maps from FSL
#         plot_kwargs={
#             "edge_cmap": "Reds",
#             "display_mode": "xz",
#             "annotate": False,
#         },
#     )
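
#%%
# To make the off-diagonal elements explicit, the sketch below averages synthetic covariances over subjects, normalizes them to correlations and zeros the diagonal. The correlation step is an optional normalization added here for illustration and is not part of the plotting call above. The sizes and the `example_*` names are hypothetical.

import numpy as np

# Synthetic stand-in for `covs`: (subjects, states, channels, channels)
rng = np.random.default_rng(0)
a = rng.standard_normal((5, 8, 25, 25))
example_covs = a @ np.swapaxes(a, -1, -2) / 25

# Group-averaged covariances: (states, channels, channels)
example_group_cov = np.mean(example_covs, axis=0)

# Normalize each covariance matrix by the channel standard deviations to get
# correlations
std = np.sqrt(np.diagonal(example_group_cov, axis1=-2, axis2=-1))
example_corr = example_group_cov / (std[:, :, None] * std[:, None, :])

# Zero the diagonal so that only the off-diagonal (connectivity) elements
# remain
example_fc = example_corr.copy()
idx = np.arange(example_fc.shape[-1])
example_fc[:, idx, idx] = 0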
