"""Source code for lcviz.viewers: time, phase, and cube viewers."""

from glue.core.subset import Subset
from glue.config import data_translator
from glue.core import BaseData
from glue.core.exceptions import IncompatibleAttribute
from glue.core.roi import RangeROI
from glue.core.subset_group import GroupedSubset

from glue_jupyter.bqplot.scatter import BqplotScatterView

from astropy import units as u
from astropy.time import Time

from jdaviz.core.events import NewViewerMessage
from jdaviz.core.registries import viewer_registry
from jdaviz.configs.cubeviz.plugins.viewers import (CubevizImageView,
                                                    WithSliceIndicator, WithSliceSelection)
from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin
from jdaviz.configs.specviz.plugins.viewers import SpecvizProfileView

from lcviz.state import ScatterViewerState

from lightkurve import LightCurve

# Public viewer classes exported by this module (CloneViewerMixin stays internal).
__all__ = ['TimeScatterView', 'PhaseScatterView', 'CubeView']


class CloneViewerMixin:
    """Mixin that lets a viewer duplicate itself (data visibility and state).

    The host class must provide ``reference``, ``jdaviz_helper``, ``jdaviz_app``
    and ``state``; all of them are accessed below.
    """

    def _get_clone_viewer_reference(self):
        """Return the first unused viewer reference derived from ``self.reference``.

        Any ``[N]`` suffix is stripped from the current reference. The bare base
        name is returned if it is not already a key of ``jdaviz_helper.viewers``.
        Otherwise ``[1]``, ``[2]``, ... is appended until an unused name is found.
        """
        base_name = self.reference.split("[")[0]
        name = base_name
        ind = 0
        while name in self.jdaviz_helper.viewers:
            ind += 1
            name = f"{base_name}[{ind}]"
        return name

    def clone_viewer(self):
        """Create a new viewer of the same class, mirroring this viewer's data and state.

        Returns
        -------
        user_api
            The ``user_api`` of the newly created viewer.
        """
        # use this mixin's own naming helper so the logic lives in one place
        name = self._get_clone_viewer_reference()

        self.jdaviz_app._on_new_viewer(NewViewerMessage(self.__class__,
                                                        data=None,
                                                        sender=self.jdaviz_app),
                                       vid=name, name=name)

        # copy per-data visibility from this viewer into the new one
        this_viewer_item = self.jdaviz_app._get_viewer_item(self.reference)
        for data_id, visible in this_viewer_item['selected_data_items'].items():
            data_label = self.jdaviz_app._get_data_item_by_id(data_id)['name']
            self.jdaviz_app.set_data_visibility(name, data_label, visible == 'visible')
            # TODO: don't revert color when adding same data to a new viewer
            # (same happens when creating a phase-viewer from ephemeris plugin)

        new_viewer = self.jdaviz_app.get_viewer(name)
        # Phase viewers keep their selected ephemeris in the private
        # ``_ephemeris_component`` attribute (set in PhaseScatterView.__init__),
        # so that is the name we must check for and copy.
        if hasattr(self, '_ephemeris_component'):
            new_viewer._ephemeris_component = self._ephemeris_component

        # copy viewer state; 'layers' is skipped because the new viewer builds its
        # own layers from the data added above
        for k, v in self.state.as_dict().items():
            if k in ('layers',):
                continue
            setattr(new_viewer.state, k, v)

        # copy per-layer styling, pairing layers by position; the 'layer'
        # reference itself is not copied
        for this_layer_state, new_layer_state in zip(self.state.layers, new_viewer.state.layers):
            for k, v in this_layer_state.as_dict().items():
                if k in ('layer',):
                    continue
                setattr(new_layer_state, k, v)

        return new_viewer.user_api


@viewer_registry("lcviz-time-viewer", label="flux-vs-time")
class TimeScatterView(JdavizViewerMixin, CloneViewerMixin, WithSliceIndicator, BqplotScatterView):
    """Scatter viewer plotting light-curve flux (y) against time (x).

    Registered as ``"lcviz-time-viewer"`` with label ``"flux-vs-time"``. The x
    attribute is the ``dt`` component and the y attribute is the ``flux``
    component. Data layers are returned as ``LightCurve`` objects by default
    (see :meth:`data`).
    """
    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
        ['jdaviz:homezoom', 'jdaviz:prevzoom'],
        ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'],
        ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'],
        ['bqplot:xrange', 'bqplot:yrange', 'bqplot:rectangle'],
        ['jdaviz:selectslice'],
        ['lcviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export']
    ]

    default_class = LightCurve
    _state_cls = ScatterViewerState
    _native_mark_classnames = ('Image', 'ImageGL', 'Scatter', 'ScatterGL')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.display_mask = False
        # unit used for the x-axis label and for converting Time-valued ROIs;
        # defaults to days
        self.time_unit = kwargs.get('time_unit', u.d)
        self.initialize_toolbar(default_tool_priority=['jdaviz:selectslice'])
        self._subscribe_to_layers_update()

        # hack to inherit a small subset of methods from SpecvizProfileView
        # TODO: refactor jdaviz so these can be included in some mixin
        self._show_uncertainty_changed = lambda value: SpecvizProfileView._show_uncertainty_changed(self, value)  # noqa
        self._plot_uncertainties = lambda: SpecvizProfileView._plot_uncertainties(self)
        # TODO: _plot_uncertainties in specviz is hardcoded to look at spectral_axis and so crashes
        self._clean_error = lambda: SpecvizProfileView._clean_error(self)

        # whether layers (data and subsets) are drawn as a density map
        self.density_map = kwargs.get('density_map', False)

    @property
    def slice_component_label(self):
        # label of the component in the lightcurves corresponding to the slice axis
        # calling data_collection_item.get_component(slice_component_label) must work
        return 'dt'

    @property
    def slice_display_unit_name(self):
        return 'time'

    def data(self, cls=None):
        """Return the objects represented by the layers in this viewer.

        Parameters
        ----------
        cls : type, optional
            Class to translate layers into. Falls back to ``default_class``.

        Returns
        -------
        list
            One entry per data layer (translated, and cached in the app's
            object cache keyed by data label) and per subset layer (translated
            via the glue data translator, or left untranslated when no class
            applies). Subsets that cannot be translated are skipped.
        """
        # TODO: generalize upstream in jdaviz.
        # This method is generalized from
        # jdaviz/configs/specviz/plugins/viewers.py
        # to support non-spectral viewers.

        # Resolve the target class once, up front: the subset branch below needs
        # it too, and a subset layer may come before any data layer.
        _class = cls or self.default_class
        data = []
        for layer_state in self.state.layers:
            if not hasattr(layer_state, 'layer'):
                continue
            lyr = layer_state.layer

            # For raw data, just include the data itself
            if isinstance(lyr, BaseData):
                if _class is not None:
                    cache_key = lyr.label
                    if cache_key in self.jdaviz_app._get_object_cache:
                        layer_data = self.jdaviz_app._get_object_cache[cache_key]
                    else:
                        layer_data = lyr.get_object(cls=_class)
                        self.jdaviz_app._get_object_cache[cache_key] = layer_data
                    data.append(layer_data)

            # For subsets, make sure to apply the subset mask to the layer data first
            elif isinstance(lyr, (Subset, GroupedSubset)):
                layer_data = lyr
                if _class is not None:
                    handler, _ = data_translator.get_handler_for(_class)
                    try:
                        layer_data = handler.to_object(layer_data)
                    except IncompatibleAttribute:
                        continue
                data.append(layer_data)

        return data

    def _apply_layer_defaults(self, layer_state):
        # layers whose data meta marks them as produced by the Binning plugin
        if getattr(layer_state.layer, 'meta', {}).get('Plugin', None) == 'Binning':
            # increased size of binned results, by default
            layer_state.size = 5

    def set_plot_axes(self):
        """Choose x/y attributes and set axis labels and ticks from the loaded data."""
        # set which components should be plotted
        dc = self.jdaviz_app.data_collection
        component_labels = [comp.label for comp in dc[0].components]

        # Get data to be used for axes labels
        light_curve = self.data()[0]

        self._set_plot_x_axes(dc, component_labels, light_curve)
        self._set_plot_y_axes(dc, component_labels, light_curve)

    def _set_plot_x_axes(self, dc, component_labels, light_curve):
        # x is the 'dt' component; the label mentions the reference time when
        # the light curve's meta provides one
        self.state.x_att = dc[0].components[component_labels.index('dt')]

        x_unit = self.time_unit
        reference_time = light_curve.meta.get('reference_time', None)

        if reference_time is not None:
            xlabel = f'{str(x_unit.physical_type).title()} from {reference_time.iso} ({x_unit})'
        else:
            xlabel = f'{str(x_unit.physical_type).title()} ({x_unit})'

        self.figure.axes[0].label = xlabel
        self.figure.axes[0].num_ticks = 5

    def _set_plot_y_axes(self, dc, component_labels, light_curve):
        # y is the 'flux' component; the label is derived from the flux unit
        self.state.y_att = dc[0].components[component_labels.index('flux')]

        y_unit = light_curve.flux.unit
        y_unit_physical_type = str(y_unit.physical_type).title()

        # count-rate units have no named physical type; present them as flux
        common_count_rate_units = (u.electron / u.s, u.dn / u.s, u.ct / u.s)

        if y_unit_physical_type == 'Unknown':
            if y_unit.is_equivalent(common_count_rate_units):
                y_unit_physical_type = 'Flux'
        if y_unit_physical_type == 'Dimensionless':
            y_unit_physical_type = 'Relative Flux'

        ylabel = f'{y_unit_physical_type}'
        # only append the unit when it is not dimensionless
        if not y_unit.is_equivalent(u.dimensionless_unscaled):
            ylabel += f' ({y_unit})'

        self.figure.axes[1].label = ylabel

        # Make it so y axis label is not covering tick numbers (sometimes)
        self.figure.axes[1].label_offset = "-50"

        # Set (X,Y)-axis to scientific notation if necessary:
        self.figure.axes[0].tick_format = 'g'
        self.figure.axes[1].tick_format = 'g'

        self.figure.axes[1].num_ticks = 5

    def _expected_subset_layer_default(self, layer_state):
        super()._expected_subset_layer_default(layer_state)

        layer_state.linewidth = 3
        # optionally prevent subset from being rendered
        # as a density map, rather than shaded markers over data:
        layer_state.density_map = self.density_map

    def add_data(self, data, color=None, alpha=None, **layer_state):
        """
        Overrides the base class to handle subset styling defaults.

        Parameters
        ----------
        data : :class:`glue.core.data.Data`
            Data object with the light curve.
        color : obj
            Color value for plotting.
        alpha : float
            Alpha value for plotting.
        **layer_state
            Additional keyword arguments forwarded to the parent ``add_data``.

        Returns
        -------
        result : bool
            `True` if successful, `False` otherwise.
        """
        result = super().add_data(data, color, alpha, **layer_state)

        for layer in self.layers:
            # optionally render as a density map
            layer.state.density_map = self.density_map

        # Set default linewidth on any created subset layers
        for layer in self.state.layers:
            if "Subset" in layer.layer.label and layer.layer.data.label == data.label:
                layer.linewidth = 3

        # update viewer limits when data are added
        self.set_plot_axes()
        self.state.reset_limits()

        return result

    def _show_uncertainty_changed(*args, **kwargs):
        # method required by jdaviz; no-op fallback (__init__ replaces it on each
        # instance with a wrapper around SpecvizProfileView._show_uncertainty_changed)
        pass

    def apply_roi(self, roi, use_current=False):
        """Apply *roi*, converting `~astropy.time.Time` range bounds to viewer x-values.

        Parameters
        ----------
        roi : glue ROI
            For a ``RangeROI``, each of ``min``/``max`` may be either a float
            (already in ``self.time_unit`` relative to the reference time) or a
            ``Time``. Each bound is converted independently, so mixed bounds work.
        use_current : bool, optional
            Forwarded to the parent ``apply_roi``.

        Raises
        ------
        ValueError
            If a ``Time`` bound is given but the first light curve has no
            ``reference_time`` in its meta, so no conversion is possible.
        """
        if isinstance(roi, RangeROI):
            # allow ROIs describing times to be applied with min and max defined as:
            # 1. floats, representing bounds in units of ``self.time_unit``
            # 2. Time objects, which get converted to work like (1) via the reference time
            if isinstance(roi.min, Time) or isinstance(roi.max, Time):
                reference_time = self.data()[0].meta.get('reference_time', None)
                if reference_time is None:
                    raise ValueError("cannot apply a Time-valued ROI: the light curve "
                                     "has no 'reference_time' in its meta")

                def _to_viewer_x(x):
                    # floats are already in viewer coordinates; only convert Time bounds
                    if isinstance(x, Time):
                        return (x - reference_time).to_value(self.time_unit)
                    return x

                roi = roi.transformed(xfunc=_to_viewer_x)

        super().apply_roi(roi, use_current=use_current)
@viewer_registry("lcviz-phase-viewer", label="flux-vs-phase")
class PhaseScatterView(TimeScatterView):
    """Scatter viewer plotting light-curve flux against phase.

    Registered as ``"lcviz-phase-viewer"`` with label ``"flux-vs-phase"``. The x
    attribute is the ``phase:<component>`` data component, where ``<component>``
    is this viewer's ephemeris component (``'default'`` initially).
    """

    def __init__(self, *args, **kwargs):
        # assigned before the parent initializer runs; presumably so anything
        # invoked during initialization can already read it — confirm
        self._ephemeris_component = 'default'
        super().__init__(*args, **kwargs)

    def _get_ephemeris_plugin(self, action):
        """Return the helper's 'Ephemeris' plugin, or raise ValueError mentioning *action*."""
        plugin = self.jdaviz_helper.plugins.get('Ephemeris', None)
        if plugin is None:
            raise ValueError(f"must have ephemeris plugin loaded to {action}")
        return plugin

    @property
    def ephemeris(self):
        """Entry for this viewer's component in the Ephemeris plugin's ``ephemerides``.

        Raises ValueError if the Ephemeris plugin is not loaded.
        """
        plugin = self._get_ephemeris_plugin("access ephemeris")
        return plugin.ephemerides.get(self._ephemeris_component)

    def _set_plot_x_axes(self, dc, component_labels, light_curve):
        # setting of y_att will be handled by ephemeris plugin
        # NOTE(review): the inherited set_plot_axes also calls _set_plot_y_axes,
        # which sets y_att — confirm which assignment is meant to win
        phase_label = f'phase:{self._ephemeris_component}'
        self.state.x_att = dc[0].components[component_labels.index(phase_label)]

        x_axis = self.figure.axes[0]
        x_axis.label = 'phase'
        x_axis.num_ticks = 5

    def times_to_phases(self, times):
        """Convert *times* to phases via the Ephemeris plugin for this viewer's component.

        Raises ValueError if the Ephemeris plugin is not loaded.
        """
        plugin = self._get_ephemeris_plugin("convert")
        return plugin.times_to_phases(times, ephem_component=self._ephemeris_component)

    def _set_slice_indicator_value(self, value):
        # NOTE: on first call, this will initialize the indicator itself
        phase = self.times_to_phases(value)
        self.slice_indicator.value = phase
@viewer_registry("lcviz-cube-viewer", label="cube")
class CubeView(CloneViewerMixin, CubevizImageView, WithSliceSelection):
    """Image viewer for cube data, sliced along the ``dt`` component.

    Registered as ``"lcviz-cube-viewer"`` with label ``"cube"``. Axes are hidden
    on creation, and every layer that has an ``attribute`` is displayed using
    the reference data's ``flux`` component.
    """
    # categories: zoom resets, zoom, pan, subset, select tools, shortcuts
    tools_nested = [
        ['jdaviz:homezoom', 'jdaviz:prevzoom'],
        ['jdaviz:boxzoom'],
        ['jdaviz:panzoom'],
        ['bqplot:rectangle'],
        ['lcviz:viewer_clone', 'jdaviz:sidebar_plot', 'jdaviz:sidebar_export'],
    ]
    # TODO: can we vary this default_class based on Kepler vs TESS, etc?
    # see https://github.com/spacetelescope/lcviz/pull/81#discussion_r1469721009
    default_class = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.display_mask = False
        self.time_unit = kwargs.get('time_unit', u.d)
        self.initialize_toolbar()
        self._subscribe_to_layers_update()

        # axes start hidden
        self.state.show_axes = False

    # TODO: refactor upstream so lcviz can inherit cubeviewer methods/setup without
    # jdaviz-specific logic:
    # * _default_spectrum_viewer_reference_name
    # * _default_flux_viewer_reference_name
    # * _default_uncert_viewer_reference_name

    @property
    def slice_component_label(self):
        # label of the component in the cubes corresponding to the slice axis;
        # data_collection_item.get_component(slice_component_label) must succeed
        # for any input cube data
        return 'dt'

    @property
    def slice_index(self):
        # index in viewer.slices corresponding to the slice axis
        return 0

    @property
    def slice_display_unit_name(self):
        return 'time'

    def _initial_x_axis(self, *args):
        # Keep x_att/y_att on the pixel axes when data load; invoked via a
        # callback set upstream in CubevizImageView when reference_data changes
        reference = self.state.reference_data
        if reference is None:
            return
        self.state.x_att = reference.id['Pixel Axis 2 [x]']
        self.state.y_att = reference.id['Pixel Axis 1 [y]']

    def _on_layers_update(self, layers=None):
        super()._on_layers_update(layers=layers)

        reference = self.state.reference_data
        if reference is None:
            return

        # point every layer that exposes an ``attribute`` at the flux component
        flux_comp = reference.id['flux']
        for layer_state in self.state.layers:
            if hasattr(layer_state, 'attribute') and layer_state.attribute != flux_comp:
                layer_state.attribute = flux_comp

    def data(self, cls=None):
        """Return the raw glue data objects (not subsets) shown in this viewer.

        ``cls`` is accepted for interface compatibility but is not used.
        """
        # TODO: generalize upstream in jdaviz.
        # This method is generalized from
        # jdaviz/configs/cubeviz/plugins/viewers.py
        found = []
        for layer_state in self.state.layers:
            if not hasattr(layer_state, 'layer'):
                continue
            candidate = layer_state.layer
            if isinstance(candidate, BaseData):
                found.append(candidate)
        return found