Source code for idtap.spectrogram

"""Spectrogram data access and visualization for IDTAP audio recordings."""

from __future__ import annotations

import gzip
import numpy as np
from typing import Optional, Tuple, Dict, Any, List, TYPE_CHECKING
from pathlib import Path

if TYPE_CHECKING:
    from .client import SwaraClient
    from .classes.piece import Piece
    from PIL import Image
    from matplotlib.figure import Figure
    from matplotlib.axes import Axes
    from matplotlib.image import AxesImage


# Supported matplotlib colormaps (matches web app functionality)
SUPPORTED_COLORMAPS = [
    # Perceptually uniform
    'viridis', 'plasma', 'inferno', 'magma', 'cividis', 'turbo',
    # Sequential
    'Blues', 'Greens', 'Reds', 'Oranges', 'Purples', 'Greys',
    'BuGn', 'BuPu', 'GnBu', 'OrRd', 'PuBu', 'PuBuGn', 'PuRd', 'RdPu',
    'YlGn', 'YlGnBu', 'YlOrBr', 'YlOrRd',
    # Diverging
    'RdBu', 'BrBG', 'PRGn', 'PiYG', 'PuOr', 'RdGy', 'RdYlBu', 'RdYlGn', 'Spectral',
    # Cyclical
    'rainbow', 'hsv',
    # Temperature
    'cool', 'warm', 'coolwarm'
]


[docs] class SpectrogramData: """Constant-Q spectrogram data for IDTAP audio recordings. This class provides access to the same high-quality constant-Q transform spectrograms used in the IDTAP web application, with tools for visualization, manipulation, and integration with matplotlib-based research workflows. Attributes: audio_id: IDTAP audio recording ID freq_range: Tuple of (min_hz, max_hz) for the frequency range bins_per_octave: Number of frequency bins per octave """ # Constants matching web app implementation DEFAULT_FREQ_RANGE = (75.0, 2400.0) # Hz DEFAULT_BINS_PER_OCTAVE = 72 DEFAULT_TIME_RESOLUTION = 0.015080 # seconds per frame (fallback when DB unavailable)
[docs] def __init__(self, data: np.ndarray, audio_id: str, freq_range: Tuple[float, float] = DEFAULT_FREQ_RANGE, bins_per_octave: int = DEFAULT_BINS_PER_OCTAVE, time_resolution: Optional[float] = None): """Initialize SpectrogramData with raw data. Args: data: Raw uint8 spectrogram array [freq_bins, time_frames] audio_id: Audio recording ID freq_range: Frequency range (min_hz, max_hz) bins_per_octave: Number of frequency bins per octave time_resolution: Time resolution in seconds per frame (optional) If None, uses DEFAULT_TIME_RESOLUTION fallback """ if not isinstance(data, np.ndarray): raise TypeError(f"data must be numpy array, got {type(data)}") if data.dtype != np.uint8: raise TypeError(f"data must be uint8 array, got {data.dtype}") if data.ndim != 2: raise ValueError(f"data must be 2D array, got {data.ndim}D") self._data = data self.audio_id = audio_id self.freq_range = freq_range self.bins_per_octave = bins_per_octave self._time_resolution = time_resolution if time_resolution is not None else self.DEFAULT_TIME_RESOLUTION
[docs] @classmethod def from_audio_id(cls, audio_id: str, client: Optional['SwaraClient'] = None) -> 'SpectrogramData': """Download and load spectrogram data from audio ID. Fetches compressed spectrogram data from https://swara.studio/spec_data/{audio_id}/ and calculates accurate time_resolution from the audio recording duration in the database. Args: audio_id: IDTAP audio recording ID client: Optional SwaraClient instance (creates one if not provided) Returns: SpectrogramData instance with accurate time_resolution Raises: requests.HTTPError: If spectrogram data doesn't exist or download fails """ # Create client if not provided if client is None: from .client import SwaraClient client = SwaraClient() # Download compressed data and metadata compressed_data = client.download_spectrogram_data(audio_id) metadata = client.download_spectrogram_metadata(audio_id) # Decompress data decompressed = gzip.decompress(compressed_data) # Reshape to numpy array shape = tuple(metadata['shape']) # [freq_bins, time_frames] data = np.frombuffer(decompressed, dtype=np.uint8).reshape(shape) # Flip frequency axis so row 0 = lowest frequency (matches freq_bins ordering) # Server data has row 0 = highest frequency, but we want row 0 = lowest data = np.flipud(data) # Get exact audio duration from recording database time_resolution = None try: recording = client.get_audio_recording(audio_id) audio_duration = recording['duration'] time_frames = shape[1] time_resolution = audio_duration / time_frames except Exception: # Fallback to DEFAULT_TIME_RESOLUTION if recording not found # This will be handled by __init__ pass return cls(data, audio_id, time_resolution=time_resolution)
[docs] @classmethod def from_piece(cls, piece: 'Piece', client: Optional['SwaraClient'] = None) -> Optional['SpectrogramData']: """Load spectrogram data from a Piece object. Args: piece: Piece object with audio_id attribute client: Optional SwaraClient instance Returns: SpectrogramData instance, or None if piece has no audio_id """ if not hasattr(piece, 'audio_id') or piece.audio_id is None: return None return cls.from_audio_id(piece.audio_id, client)
[docs] def apply_intensity(self, power: float = 1.0) -> np.ndarray: """Apply power-law intensity transformation (matches web app behavior). This transformation enhances visual contrast in the spectrogram. Formula: output = (input^power / 255^power) * 255 Args: power: Exponent for power transform (1.0-5.0) 1.0 = linear (no change) >1.0 = increased contrast Returns: Transformed uint8 array with same shape as input Raises: ValueError: If power is outside valid range [1.0, 5.0] """ if not 1.0 <= power <= 5.0: raise ValueError(f"Power must be between 1.0 and 5.0, got {power}") if power == 1.0: return self._data.copy() # Vectorized power transform # Convert to float for precision, apply transform, convert back data_float = self._data.astype(np.float32) transformed = np.power(data_float / 255.0, power) * 255.0 return np.clip(transformed, 0, 255).astype(np.uint8)
[docs] def apply_colormap(self, data: Optional[np.ndarray] = None, cmap: str = 'viridis') -> np.ndarray: """Apply matplotlib colormap to spectrogram data. Args: data: Input spectrogram data (if None, uses self._data) cmap: Matplotlib colormap name (see SUPPORTED_COLORMAPS) Returns: RGB array of shape [height, width, 3] with uint8 values Raises: ValueError: If colormap name is not recognized """ import matplotlib.pyplot as plt if data is None: data = self._data # Get matplotlib colormap try: colormap = plt.get_cmap(cmap) except ValueError: raise ValueError( f"Unknown colormap: '{cmap}'. " f"See SUPPORTED_COLORMAPS for valid options." ) # Apply colormap (handles normalization automatically) # Returns RGBA array, we take only RGB channels colored = colormap(data / 255.0) # Normalize to [0, 1] rgb = (colored[:, :, :3] * 255).astype(np.uint8) return rgb
[docs] def crop_frequency(self, min_hz: Optional[float] = None, max_hz: Optional[float] = None) -> 'SpectrogramData': """Crop spectrogram to a specific frequency range. Args: min_hz: Minimum frequency (Hz), defaults to original min max_hz: Maximum frequency (Hz), defaults to original max Returns: New SpectrogramData instance with cropped data """ if min_hz is None: min_hz = self.freq_range[0] if max_hz is None: max_hz = self.freq_range[1] # Calculate bin indices for frequency range freq_bins = self.freq_bins # Find closest bin indices min_idx = np.searchsorted(freq_bins, min_hz) max_idx = np.searchsorted(freq_bins, max_hz) # Ensure valid range min_idx = max(0, min_idx) max_idx = min(len(freq_bins), max_idx) # Crop data cropped_data = self._data[min_idx:max_idx, :] # Create new instance with updated frequency range return SpectrogramData( cropped_data, self.audio_id, freq_range=(freq_bins[min_idx], freq_bins[max_idx - 1] if max_idx > min_idx else freq_bins[min_idx]), bins_per_octave=self.bins_per_octave, time_resolution=self._time_resolution )
[docs] def crop_time(self, start_time: Optional[float] = None, end_time: Optional[float] = None) -> 'SpectrogramData': """Crop spectrogram to a specific time range. Args: start_time: Start time in seconds (defaults to 0) end_time: End time in seconds (defaults to duration) Returns: New SpectrogramData instance with cropped data """ if start_time is None: start_time = 0.0 if end_time is None: end_time = self.duration # Convert times to frame indices start_frame = int(start_time / self.time_resolution) end_frame = int(end_time / self.time_resolution) # Ensure valid range start_frame = max(0, start_frame) end_frame = min(self.shape[1], end_frame) # Crop data cropped_data = self._data[:, start_frame:end_frame] return SpectrogramData( cropped_data, self.audio_id, freq_range=self.freq_range, bins_per_octave=self.bins_per_octave, time_resolution=self._time_resolution )
[docs] def get_extent(self) -> List[float]: """Get matplotlib extent for this spectrogram. Returns: [left, right, bottom, top] = [0, duration, min_freq, max_freq] This is the format matplotlib imshow() expects for extent parameter. """ return [0, self.duration, self.freq_range[0], self.freq_range[1]]
[docs] def get_plot_data(self, power: float = 1.0, apply_cmap: bool = False, cmap: str = 'viridis') -> Tuple[np.ndarray, List[float]]: """Get processed spectrogram data and extent for matplotlib plotting. Use this when you need direct control over the plotting process, or when you want to manipulate the data before plotting. Args: power: Intensity power transform (1.0-5.0) apply_cmap: If True, returns RGB array; if False, returns grayscale uint8 cmap: Colormap name (only used if apply_cmap=True) Returns: Tuple of (data, extent): - data: Processed spectrogram array If apply_cmap=False: uint8 array [freq_bins, time_frames] If apply_cmap=True: RGB uint8 array [freq_bins, time_frames, 3] - extent: [left, right, bottom, top] for matplotlib imshow() Example: >>> # Low-level control >>> data, extent = spec.get_plot_data(power=2.5) >>> fig, ax = plt.subplots() >>> im = ax.imshow(data, extent=extent, aspect='auto', ... origin='lower', cmap='magma') """ # Apply intensity transform transformed = self.apply_intensity(power) # Optionally apply colormap if apply_cmap: data = self.apply_colormap(transformed, cmap) else: data = transformed # Get extent extent = self.get_extent() return data, extent
[docs] def plot_on_axis(self, ax: 'Axes', power: float = 1.0, cmap: str = 'viridis', alpha: float = 1.0, zorder: int = 0, log_freq: bool = True, **imshow_kwargs) -> 'AxesImage': """Plot spectrogram on an existing matplotlib axis (for overlays). This is the primary method for using spectrograms as underlays in custom matplotlib visualizations. Args: ax: Matplotlib axis to plot on power: Intensity power transform (1.0-5.0) cmap: Matplotlib colormap name alpha: Transparency (0.0-1.0), useful for subtle underlays zorder: Drawing order (0 = background, higher = foreground) log_freq: Whether to use logarithmic frequency scale (default: True) **imshow_kwargs: Additional arguments passed to ax.imshow() Returns: AxesImage object (useful for adding colorbars) Example: >>> fig, ax = plt.subplots(figsize=(12, 6)) >>> im = spec.plot_on_axis(ax, power=2.0, cmap='viridis', alpha=0.7) >>> ax.plot(times, pitch_contour, 'r-', linewidth=2) # Overlay pitch >>> ax.set_xlabel('Time (s)') >>> ax.set_ylabel('Frequency (Hz)') >>> plt.colorbar(im, ax=ax, label='Intensity') """ # Get processed data and extent data, extent = self.get_plot_data(power=power, apply_cmap=False) # Plot on provided axis im = ax.imshow( data, extent=extent, aspect='auto', origin='lower', cmap=cmap, alpha=alpha, zorder=zorder, **imshow_kwargs ) # Set log scale for frequency axis (CQT is log-spaced) if log_freq: ax.set_yscale('log') # Set reasonable y-axis limits ax.set_ylim(self.freq_range[0], self.freq_range[1]) return im
[docs] def to_image(self, width: Optional[int] = None, height: Optional[int] = None, power: float = 1.0, cmap: str = 'viridis', interpolation: str = 'bilinear') -> 'Image': """Generate PIL Image with full processing pipeline. Args: width: Output width in pixels (default: original width) height: Output height in pixels (default: original height) power: Intensity power transform (1.0-5.0) cmap: Matplotlib colormap name interpolation: Resampling method ('bilinear', 'nearest', 'lanczos', etc.) See PIL.Image.Resampling for all options Returns: PIL Image in RGB mode """ from PIL import Image # Apply intensity transform transformed = self.apply_intensity(power) # Apply colormap rgb = self.apply_colormap(transformed, cmap) # Create PIL Image img = Image.fromarray(rgb, mode='RGB') # Resize if requested if width or height: # Determine final size orig_height, orig_width = rgb.shape[:2] if width and height: new_size = (width, height) elif width: # Keep aspect ratio ratio = width / orig_width new_size = (width, int(orig_height * ratio)) else: # height only ratio = height / orig_height new_size = (int(orig_width * ratio), height) # Map interpolation string to PIL constant from PIL import Image as PILImage resample_map = { 'nearest': PILImage.Resampling.NEAREST, 'bilinear': PILImage.Resampling.BILINEAR, 'bicubic': PILImage.Resampling.BICUBIC, 'lanczos': PILImage.Resampling.LANCZOS, } resample = resample_map.get(interpolation.lower(), PILImage.Resampling.BILINEAR) img = img.resize(new_size, resample=resample) return img
[docs] def to_matplotlib(self, figsize: Tuple[float, float] = (12, 6), power: float = 1.0, cmap: str = 'viridis', show_colorbar: bool = True, show_axes: bool = True, log_freq: bool = True) -> 'Figure': """Generate standalone matplotlib Figure for publication. Use this for quick visualization. For overlays and custom plots, use plot_on_axis() instead. Args: figsize: Figure size (width, height) in inches power: Intensity power transform (1.0-5.0) cmap: Matplotlib colormap name show_colorbar: Whether to show colorbar show_axes: Whether to show frequency/time axis labels log_freq: Whether to use logarithmic frequency scale (default: True) Returns: Matplotlib Figure object """ import matplotlib.pyplot as plt fig, ax = plt.subplots(figsize=figsize) # Use plot_on_axis internally im = self.plot_on_axis(ax, power=power, cmap=cmap, log_freq=log_freq) if show_axes: ax.set_xlabel('Time (s)') ax.set_ylabel('Frequency (Hz)') else: ax.axis('off') if show_colorbar: plt.colorbar(im, ax=ax, label='Intensity') return fig
[docs] def save(self, filepath: str, width: Optional[int] = None, height: Optional[int] = None, power: float = 1.0, cmap: str = 'viridis', format: Optional[str] = None, **kwargs): """Save spectrogram as image file. Args: filepath: Output file path width: Output width in pixels (default: original) height: Output height in pixels (default: original) power: Intensity power transform (1.0-5.0) cmap: Matplotlib colormap name format: Image format ('png', 'jpg', 'webp', etc.) Auto-detected from filepath extension if not provided **kwargs: Additional arguments passed to PIL Image.save() """ img = self.to_image(width, height, power, cmap) # Auto-detect format from extension if not provided if format is None: suffix = Path(filepath).suffix if suffix: format = suffix[1:] # Remove leading dot img.save(filepath, format=format, **kwargs)
@property def shape(self) -> Tuple[int, int]: """Data shape: (frequency_bins, time_frames).""" return self._data.shape @property def duration(self) -> float: """Audio duration in seconds (estimated from time frames).""" return self.shape[1] * self.time_resolution @property def time_resolution(self) -> float: """Time resolution in seconds per frame. Calculated from audio recording duration in database (when available). Falls back to DEFAULT_TIME_RESOLUTION if recording metadata unavailable. Note: Spectrograms always cover the full audio recording, even when the associated Piece transcribes only an excerpt. """ return self._time_resolution @property def freq_bins(self) -> np.ndarray: """Array of frequency values (Hz) for each bin. Calculated from bins_per_octave and freq_range using log spacing. """ n_bins = self.shape[0] # Calculate frequencies using constant-Q log spacing # freq = min_freq * 2^(bin / bins_per_octave) min_freq = self.freq_range[0] bin_indices = np.arange(n_bins) frequencies = min_freq * np.power(2, bin_indices / self.bins_per_octave) return frequencies