"""HTTP client for the Swara Studio API."""
from __future__ import annotations
from typing import Any, Dict, Optional, Union, List, Callable
import json
from pathlib import Path
import requests
import os
from idtap.classes.piece import Piece
from .auth import login_google, load_token
from .secure_storage import SecureTokenStorage
from .query import Query
from .query_types import (
QueryType, MultipleReturnType, CategoryType,
DesignatorType, SegmentationType, QueryAnswerType
)
from .classes.pitch import Pitch
[docs]
class SwaraClient:
"""Minimal client wrapping the public API served at https://swara.studio."""
[docs]
def __init__(
    self,
    base_url: str = "https://swara.studio/",
    token_path: str | Path | None = None,
    auto_login: bool = True,
) -> None:
    """Create a client, load stored credentials and optionally log in.

    Args:
        base_url: Root URL of the API; normalised to end with exactly one "/".
        token_path: Optional legacy token file, kept for backwards
            compatibility. When omitted, the ``SWARA_TOKEN_PATH``
            environment variable is used if it is set.
        auto_login: When True and no token could be loaded, run
            ``login_google`` and reload the token afterwards.

    Raises:
        Exception: Any error raised by the login flow is re-raised after a
            message has been printed.
    """
    self.base_url = base_url.rstrip("/") + "/"
    # Initialize secure storage
    self.secure_storage = SecureTokenStorage()
    # Keep token_path for backwards compatibility.  The previous expression
    # only looked at SWARA_TOKEN_PATH when token_path was already truthy, so
    # the environment variable could never take effect.
    legacy_source = token_path or os.environ.get("SWARA_TOKEN_PATH")
    self.token_path: Optional[Path] = (
        Path(legacy_source).expanduser() if legacy_source else None
    )
    self.auto_login = auto_login
    self.token: Optional[str] = None
    self.user: Optional[Dict[str, Any]] = None
    self.load_token()
    if self.token is None and self.auto_login:
        try:
            login_google(base_url=self.base_url, storage=self.secure_storage)
            self.load_token()
        except Exception as e:
            print(f"Failed to log in to Swara Studio: {e}")
            raise
@property
def user_id(self) -> Optional[str]:
"""Return the user ID if available, otherwise ``None``."""
if self.user:
return self.user.get("_id") or self.user.get("sub")
return None
# ---- auth utilities ----
[docs]
def load_token(self, token_path: Optional[str | Path] = None) -> None:
    """Populate ``self.token`` and ``self.user`` from stored credentials.

    Args:
        token_path: Optional legacy token file; falls back to
            ``self.token_path`` when not given.

    Expired credentials are cleared from storage and treated as absent.
    Any exception is reported with ``print`` and leaves the client
    unauthenticated instead of propagating.
    """
    try:
        # Use the new secure storage with backwards compatibility
        chosen = token_path or self.token_path
        legacy_path = Path(chosen) if chosen else None
        stored = load_token(storage=self.secure_storage, token_path=legacy_path)
        if not stored:
            self.token = None
            self.user = None
            return
        if self.secure_storage.is_token_expired(stored):
            print("⚠️ Stored tokens are expired. Please re-authenticate.")
            # Expired credentials are wiped rather than refreshed
            self.secure_storage.clear_tokens()
            self.token = None
            self.user = None
            return
        self.token = stored.get("id_token") or stored.get("token")
        self.user = stored.get("profile") or stored.get("user")
    except Exception as e:
        print(f"Failed to load tokens: {e}")
        self.token = None
        self.user = None
[docs]
def get_auth_info(self) -> Dict[str, Any]:
"""Get information about the current authentication and storage setup.
Returns:
Dict containing authentication status and storage information
"""
storage_info = self.secure_storage.get_storage_info()
return {
"authenticated": self.token is not None,
"user_id": self.user_id,
"user_email": self.user.get("email") if self.user else None,
"storage_info": storage_info,
"token_expired": False if not self.token else None
}
def _auth_headers(self) -> Dict[str, str]:
if self.token:
return {"Authorization": f"Bearer {self.token}"}
return {}
def _post_json(self, endpoint: str, payload: Dict[str, Any]) -> Any:
    """POST ``payload`` as JSON to ``base_url + endpoint``.

    Returns:
        The decoded JSON body, or ``None`` when the response body is empty.

    Raises:
        requests.HTTPError: For non-success status codes.
    """
    response = requests.post(
        self.base_url + endpoint,
        json=payload,
        headers=self._auth_headers(),
        timeout=1800,  # 30 minutes
    )
    response.raise_for_status()
    return response.json() if response.content else None
def _get(self, endpoint: str, params: Optional[Dict[str, Any]] = None) -> Any:
    """GET ``base_url + endpoint`` with optional query parameters.

    Returns:
        Decoded JSON when the response ``Content-Type`` starts with
        ``application/json``; otherwise the raw response bytes.

    Raises:
        requests.HTTPError: For non-success status codes.
    """
    response = requests.get(
        self.base_url + endpoint,
        params=params,
        headers=self._auth_headers(),
        timeout=1800,  # 30 minutes
    )
    response.raise_for_status()
    is_json = response.headers.get("Content-Type", "").startswith("application/json")
    if not is_json:
        return response.content
    return response.json()
def _delete_json(self, endpoint: str, payload: Dict[str, Any]) -> Any:
    """Send a DELETE request carrying ``payload`` as a JSON body.

    Returns:
        The decoded JSON body, or ``None`` when the response body is empty.

    Raises:
        requests.HTTPError: For non-success status codes.
    """
    response = requests.delete(
        self.base_url + endpoint,
        json=payload,
        headers=self._auth_headers(),
        timeout=1800,
    )
    response.raise_for_status()
    return response.json() if response.content else None
# ---- API methods ----
[docs]
def get_piece(self, piece_id: str, fetch_rule_set: bool = True) -> Any:
"""Return transcription JSON for the given id.
Args:
piece_id: The ID of the piece to fetch
fetch_rule_set: If True and raga has no ruleSet, fetch it from database
Returns:
Dictionary containing the piece data with ruleSet populated if needed
"""
# Check waiver and prompt if needed
self._prompt_for_waiver_if_needed()
piece_data = self._get(f"api/transcription/{piece_id}")
# If fetch_rule_set is True and there's a raga without a ruleSet, fetch it
if fetch_rule_set and 'raga' in piece_data:
raga_data = piece_data['raga']
if 'ruleSet' not in raga_data or not raga_data.get('ruleSet'):
raga_name = raga_data.get('name')
if raga_name and raga_name != 'Yaman':
try:
# Fetch the rule_set from the database
raga_rules = self.get_raga_rules(raga_name)
if 'rules' in raga_rules:
piece_data['raga']['ruleSet'] = raga_rules['rules']
except:
# If fetch fails, leave it as is
pass
return piece_data
[docs]
def excel_data(self, piece_id: str) -> bytes:
"""Export transcription data as Excel file."""
# Check waiver and prompt if needed
self._prompt_for_waiver_if_needed()
return self._get(f"api/transcription/{piece_id}/excel")
[docs]
def json_data(self, piece_id: str) -> bytes:
"""Export transcription data as JSON file."""
# Check waiver and prompt if needed
self._prompt_for_waiver_if_needed()
return self._get(f"api/transcription/{piece_id}/json")
[docs]
def save_piece(self, piece: Dict[str, Any]) -> Any:
"""Save transcription using authenticated API route."""
return self._post_json("api/transcription", piece)
[docs]
def insert_new_transcription(self, piece: Dict[str, Any]) -> Any:
"""Insert a new transcription document as the current authenticated user."""
if not self.user_id:
raise RuntimeError("Not authenticated: cannot insert new transcription")
payload = dict(piece)
payload["userID"] = self.user_id
return self._post_json("insertNewTranscription", payload)
[docs]
def clone_transcription(
self,
piece_id: str,
title: Optional[str] = None,
explicit_permissions: Optional[Dict[str, Any]] = None,
soloist: Optional[str] = None,
solo_instrument: Optional[str] = None,
) -> Any:
"""Clone a transcription, creating a new copy owned by the current user.
The server copies all transcription data (phrases, trajectories, pitches,
raga, audio association, etc.) and assigns a new ID, owner, and timestamps.
Args:
piece_id: The ID of the transcription to clone.
title: Title for the cloned transcription. Defaults to server behavior.
explicit_permissions: Permission object with 'edit', 'view' (user ID
lists) and 'publicView' (bool). Defaults to private.
soloist: Soloist name for the clone.
solo_instrument: Solo instrument for the clone.
Returns:
Server response with ``insertedId`` of the new transcription.
"""
if not self.user_id:
raise RuntimeError("Not authenticated: cannot clone transcription")
payload: Dict[str, Any] = {
"id": piece_id,
"newOwner": self.user_id,
}
if title is not None:
payload["title"] = title
if self.user:
payload["name"] = self.user.get("name", "")
payload["family_name"] = self.user.get("family_name", "")
payload["given_name"] = self.user.get("given_name", "")
if explicit_permissions is not None:
payload["explicitPermissions"] = explicit_permissions
else:
payload["explicitPermissions"] = {
"edit": [],
"view": [],
"publicView": False,
}
if soloist is not None:
payload["soloist"] = soloist
if solo_instrument is not None:
payload["soloInstrument"] = solo_instrument
return self._post_json("cloneTranscription", payload)
[docs]
def delete_transcription(self, piece_id: str) -> Any:
"""Delete a transcription from the server.
Removes the transcription document and the reference from the user's
transcriptions array.
Args:
piece_id: The ID of the transcription to delete.
Returns:
Server response with ``deletedCount``.
"""
if not self.user_id:
raise RuntimeError("Not authenticated: cannot delete transcription")
payload = {
"_id": piece_id,
"userID": self.user_id,
}
return self._delete_json("oneTranscription", payload)
def _prompt_for_waiver_if_needed(self) -> None:
"""Interactively prompt user to agree to waiver if not already agreed."""
if self.has_agreed_to_waiver():
return
print("\n" + "=" * 60)
print("📋 IDTAP RESEARCH WAIVER REQUIRED")
print("=" * 60)
print("\nBefore accessing transcription data, you must agree to the following terms:\n")
waiver_text = self.get_waiver_text()
print(waiver_text)
print("\n" + "=" * 60)
while True:
response = input("Do you agree to these terms? (yes/no): ").strip().lower()
if response == "yes":
print("\nSubmitting waiver agreement...")
try:
self.agree_to_waiver(i_agree=True)
print("✅ Waiver agreement successful! You now have access to transcription data.\n")
break
except Exception as e:
print(f"❌ Error submitting waiver agreement: {e}")
raise
elif response == "no":
print("\n👋 You must agree to the waiver to access transcription data.")
raise RuntimeError("Waiver agreement required but declined by user.")
else:
print("Please respond with 'yes' or 'no'.")
[docs]
def get_viewable_transcriptions(
self,
sort_key: str = "title",
sort_dir: str | int = 1,
new_permissions: Optional[bool] = None,
) -> Any:
"""Return transcriptions viewable by the user."""
# Check waiver and prompt if needed
self._prompt_for_waiver_if_needed()
params = {
"sortKey": sort_key,
"sortDir": sort_dir,
"newPermissions": new_permissions,
}
# remove None values
params = {k: str(v) for k, v in params.items() if v is not None}
return self._get("api/transcriptions", params=params)
[docs]
def update_visibility(
self,
artifact_type: str,
_id: str,
explicit_permissions: Dict[str, Any],
) -> Any:
payload = {
"artifactType": artifact_type,
"_id": _id,
"explicitPermissions": explicit_permissions,
}
return self._post_json("api/visibility", payload)
[docs]
def has_agreed_to_waiver(self) -> bool:
"""Check if the current user has agreed to the research waiver.
This makes a fresh API call to get the latest waiver status from the database.
Returns:
True if user has agreed to waiver, False otherwise
"""
try:
# Make a fresh API call to get current user data from database
fresh_user_data = self._get("api/user")
return fresh_user_data.get("waiverAgreed", False)
except Exception:
# Fall back to cached data if API call fails
if not self.user:
return False
return self.user.get("waiverAgreed", False)
[docs]
def get_waiver_text(self) -> str:
    """Return the full text of the research waiver users must accept."""
    sentences = (
        "I agree to only use the IDTAP for scholarly and/or pedagogical purposes.",
        "I understand that any copyrighted materials that I upload to the IDTAP "
        "are liable to be taken down in response to a DMCA takedown notice.",
    )
    return " ".join(sentences)
[docs]
def agree_to_waiver(self, i_agree: bool = False) -> Any:
"""Agree to the research waiver after reading it.
You must first read the waiver text using get_waiver_text() and then
explicitly set i_agree=True to confirm agreement.
Args:
i_agree: Must be True to confirm you have read and agree to the waiver
Returns:
Server response confirming waiver agreement
Raises:
RuntimeError: If not authenticated or if i_agree is not True
"""
if not self.user_id:
raise RuntimeError("Not authenticated: cannot agree to waiver")
if not i_agree:
waiver_text = self.get_waiver_text()
raise RuntimeError(
f"You must read and agree to the research waiver before accessing transcriptions.\n\n"
f"WAIVER TEXT:\n{waiver_text}\n\n"
f"If you agree to these terms, call: client.agree_to_waiver(i_agree=True)"
)
payload = {"userID": self.user_id}
result = self._post_json("api/agreeToWaiver", payload)
# Update local user object to reflect waiver agreement
if self.user:
self.user["waiverAgreed"] = True
return result
[docs]
def upload_audio(
self,
file_path: str,
metadata: "AudioMetadata",
audio_event: Optional["AudioEventConfig"] = None,
progress_callback: Optional[Callable[[float], None]] = None
) -> "AudioUploadResult":
"""Upload audio recording with comprehensive metadata.
Requires the `requests-toolbelt` library for multipart encoding.
Args:
file_path: Path to the audio file to upload
metadata: AudioMetadata object with recording information.
Ragas can be specified in multiple formats:
- AudioRaga objects: AudioRaga(name="Rageshree") (recommended)
- Strings: "Rageshree" (auto-converted to AudioRaga)
- Name dicts: {"name": "Rageshree"} (auto-converted to AudioRaga)
- Legacy format: {"Rageshree": {"performance_sections": {}}} (auto-converted)
audio_event: Optional AudioEventConfig for associating with audio events
progress_callback: Optional callback for upload progress (0-100)
Returns:
AudioUploadResult with upload status and file information
Raises:
FileNotFoundError: If the audio file doesn't exist
ValueError: If the file is not a supported audio format or metadata validation fails
RuntimeError: If upload fails
"""
import os
from pathlib import Path
# Validate file exists
if not os.path.exists(file_path):
raise FileNotFoundError(f"Audio file not found: {file_path}")
file_path_obj = Path(file_path)
# Check file extension
supported_extensions = {'.mp3', '.wav', '.m4a', '.flac', '.opus', '.ogg'}
if file_path_obj.suffix.lower() not in supported_extensions:
raise ValueError(f"Unsupported audio format: {file_path_obj.suffix}. "
f"Supported formats: {', '.join(supported_extensions)}")
# Validate metadata early to provide clear error messages
try:
# This will trigger raga normalization and validation
metadata.to_json()
except ValueError as e:
raise ValueError(f"Metadata validation failed: {e}")
# Prepare form data
try:
# Prepare data fields
data = {
'metadata': json.dumps(metadata.to_json()),
}
if audio_event:
data['audioEventConfig'] = json.dumps(audio_event.to_json())
# Open file and make request
with open(file_path, 'rb') as f:
files = {'audioFile': (file_path_obj.name, f, self._get_mimetype(file_path_obj.suffix))}
from requests_toolbelt.multipart.encoder import MultipartEncoder, MultipartEncoderMonitor
def progress_monitor(monitor):
progress_callback((monitor.bytes_read / monitor.len) * 100)
fields = {
'metadata': json.dumps(metadata.to_json()),
}
if audio_event:
fields['audioEventConfig'] = json.dumps(audio_event.to_json())
fields['audioFile'] = (
file_path_obj.name,
f,
self._get_mimetype(file_path_obj.suffix)
)
encoder = MultipartEncoder(fields=fields)
payload = MultipartEncoderMonitor(encoder, progress_monitor) if progress_callback else encoder
headers = {
**self._auth_headers(),
'Content-Type': payload.content_type,
'Content-Length': str(payload.len),
}
response = requests.post(
f"{self.base_url}api/uploadAudio",
data=payload,
headers=headers,
timeout=1800
)
response.raise_for_status()
result_data = response.json()
from .audio_models import AudioUploadResult
return AudioUploadResult.from_api_response(result_data)
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Upload failed: {e}")
except Exception as e:
raise RuntimeError(f"Upload error: {e}")
def _get_mimetype(self, extension: str) -> str:
"""Get MIME type for file extension."""
mime_types = {
'.mp3': 'audio/mpeg',
'.wav': 'audio/wav',
'.m4a': 'audio/m4a',
'.flac': 'audio/flac',
'.opus': 'audio/opus',
'.ogg': 'audio/ogg'
}
return mime_types.get(extension.lower(), 'audio/mpeg')
[docs]
def get_available_musicians(self) -> List[Dict[str, Any]]:
"""Get list of musicians in database."""
return self._get("api/musicians")
[docs]
def get_available_ragas(self) -> List[str]:
"""Get list of ragas in database."""
return self._get("api/ragas")
[docs]
def get_raga_rules(self, raga_name: str) -> Dict[str, Any]:
"""Get pitch rules for a specific raga.
Args:
raga_name: Name of the raga to get rules for
Returns:
Dictionary containing the raga's pitch rules and updated date
Raises:
ValueError: If raga_name is empty or None
requests.HTTPError: If raga not found or API error
"""
if not raga_name:
raise ValueError("Raga name cannot be empty")
params = {"name": raga_name}
return self._get("api/ragaRules", params)
[docs]
def get_available_instruments(self, melody_only: bool = False) -> List[str]:
"""Get list of instruments in database."""
params = {'melody': 'true'} if melody_only else {}
return self._get("api/instruments", params)
[docs]
def get_location_hierarchy(self) -> "LocationHierarchy":
    """Fetch ``api/locations`` and wrap the response in a LocationHierarchy."""
    raw_locations = self._get("api/locations")
    # Imported lazily, after the request, exactly as before
    from .audio_models import LocationHierarchy
    hierarchy = LocationHierarchy.from_api_response(raw_locations)
    return hierarchy
[docs]
def get_available_gharanas(self) -> List[Dict[str, Any]]:
"""Get list of gharanas in database."""
return self._get("api/gharanas")
[docs]
def get_event_types(self) -> List[str]:
"""Get list of audio event types."""
return self._get("api/eventTypes")
[docs]
def get_editable_audio_events(self) -> List[Dict[str, Any]]:
"""Get audio events the user can edit."""
return self._get("api/audioEvents")
[docs]
def logout(self, confirm: bool = False) -> bool:
"""Log out the current user and clear all stored authentication tokens.
This will:
- Clear tokens from OS keyring, encrypted storage, and plaintext files
- Reset the client's authentication state
- Require re-authentication for future API calls
Args:
confirm: Set to True to confirm logout without interactive prompt
Returns:
True if logout was successful, False otherwise
"""
if not confirm:
print("🚪 Logging out will clear all stored authentication tokens.")
print("You will need to re-authenticate to use the API again.")
user_input = input("Are you sure you want to log out? (yes/no): ").strip().lower()
if user_input != 'yes':
print("Logout cancelled.")
return False
try:
# Clear tokens from all storage backends
success = self.secure_storage.clear_tokens()
if success:
# Reset client authentication state
self.token = None
self.user = None
print("✅ Successfully logged out. All authentication tokens have been cleared.")
return True
else:
print("⚠️ Logout partially successful - some tokens may not have been cleared.")
return False
except Exception as e:
print(f"❌ Error during logout: {e}")
return False
[docs]
def download_audio(self, audio_id: str, format: str = "wav") -> bytes:
"""Download audio recording by audio ID.
Args:
audio_id: The audio recording ID
format: Audio format (wav, mp3, opus)
Returns:
Raw audio data as bytes
"""
if format not in ["wav", "mp3", "opus"]:
raise ValueError(f"Unsupported audio format: {format}. Use 'wav', 'mp3', or 'opus'")
endpoint = f"audio/{format}/{audio_id}.{format}"
return self._get(endpoint)
[docs]
def download_transcription_audio(self, piece: Union[Dict[str, Any], Piece], format: str = "wav") -> Optional[bytes]:
"""Download audio recording associated with a transcription.
Args:
piece: Transcription piece data (dict or Piece object)
format: Audio format (wav, mp3, opus)
Returns:
Raw audio data as bytes, or None if no audio is associated
"""
# Extract audio ID from piece
if hasattr(piece, 'audio_id'):
audio_id = piece.audio_id
elif isinstance(piece, dict):
audio_id = piece.get('audioID')
else:
raise TypeError(f"Expected Piece object or dict, got {type(piece)}")
if not audio_id:
return None
return self.download_audio(audio_id, format)
[docs]
def save_audio_file(self, audio_data: bytes, filename: str, filepath: Optional[str] = None) -> str:
"""Save audio data to a file.
Args:
audio_data: Raw audio data from download_audio()
filename: Output filename (should include extension)
filepath: Directory to save file (defaults to user's Downloads folder)
Returns:
Full path to the saved file
"""
import os
from pathlib import Path
if filepath is None:
# Cross-platform default to Downloads folder
if os.name == 'nt': # Windows
downloads_dir = Path.home() / 'Downloads'
else: # macOS, Linux, Unix
downloads_dir = Path.home() / 'Downloads'
filepath = str(downloads_dir)
# Ensure directory exists
Path(filepath).mkdir(parents=True, exist_ok=True)
# Combine path and filename
full_path = Path(filepath) / filename
with open(full_path, 'wb') as f:
f.write(audio_data)
return str(full_path)
[docs]
def download_and_save_transcription_audio(self, piece: Union[Dict[str, Any], Piece],
format: str = "wav",
filepath: Optional[str] = None,
filename: Optional[str] = None) -> Optional[str]:
"""Download and save audio recording associated with a transcription.
Args:
piece: Transcription piece data (dict or Piece object)
format: Audio format (wav, mp3, opus)
filepath: Directory to save file (defaults to Downloads folder)
filename: Custom filename (defaults to transcription title + ID)
Returns:
Full path to saved file, or None if no audio is associated
"""
# Download audio data
audio_data = self.download_transcription_audio(piece, format)
if not audio_data:
return None
# Generate filename if not provided
if filename is None:
if hasattr(piece, 'title') and hasattr(piece, '_id'):
title = piece.title
piece_id = piece._id
elif isinstance(piece, dict):
title = piece.get('title', 'untitled')
piece_id = piece.get('_id', 'unknown')
else:
title = 'untitled'
piece_id = 'unknown'
# Clean title for filename
clean_title = "".join(c for c in title if c.isalnum() or c in (' ', '-', '_')).strip()
filename = f"{clean_title}_{piece_id}.{format}"
# Save file and return path
return self.save_audio_file(audio_data, filename, filepath)
[docs]
def download_spectrogram_data(self, audio_id: str) -> bytes:
"""Download gzip-compressed spectrogram data.
Args:
audio_id: The audio recording ID
Returns:
Gzipped binary data containing uint8 spectrogram array
"""
endpoint = f"spec_data/{audio_id}/spec_data.gz"
return self._get(endpoint)
[docs]
def get_audio_recording(self, audio_id: str) -> Dict[str, Any]:
"""Get audio recording metadata by ID.
Fetches complete recording metadata including duration, musicians,
ragas, location, and permissions.
Args:
audio_id: The audio recording ID
Returns:
Dictionary with recording metadata including:
- duration: Audio duration in seconds (float)
- musicians: Dictionary of performer information
- raags: Dictionary of raga information
- title: Recording title
- etc.
Raises:
requests.HTTPError: If recording not found (404)
"""
return self._get("getAudioRecording", params={"_id": audio_id})
[docs]
def save_transcription(self, piece: Piece, fill_duration: bool = True) -> Any:
"""Save a transcription piece to the server.
Handles both new transcriptions (without _id) and existing transcriptions (with _id).
Args:
piece: The Piece object or dict to save
fill_duration: Whether to automatically fill remaining duration with silence
Returns:
Server response from the save operation
"""
# Convert Piece object to dict if needed
if hasattr(piece, 'to_json'):
payload = piece.to_json()
elif isinstance(piece, dict):
payload = dict(piece)
else:
raise TypeError(f"Expected Piece object with to_json() method or dict, got {type(piece)}")
# Fill remaining duration with silence if requested
if fill_duration and hasattr(piece, 'fill_remaining_duration') and hasattr(piece, 'dur_tot'):
piece.fill_remaining_duration(piece.dur_tot)
payload = piece.to_json()
# Set transcriber information from authenticated user if not already set
if hasattr(piece, 'given_name') and self.user:
if not getattr(piece, 'given_name', None):
piece.given_name = self.user.get("given_name", "")
if not getattr(piece, 'family_name', None):
piece.family_name = self.user.get("family_name", "")
if not getattr(piece, 'name', None):
piece.name = self.user.get("name", "")
# Set default soloist and instrument information if not already set
if hasattr(piece, 'soloist') and not getattr(piece, 'soloist', None):
piece.soloist = None
if hasattr(piece, 'solo_instrument') and not getattr(piece, 'solo_instrument', None):
instrumentation = getattr(piece, 'instrumentation', [])
piece.solo_instrument = instrumentation[0] if instrumentation else "Unknown Instrument"
# Regenerate payload after setting user info
if hasattr(piece, 'to_json'):
payload = piece.to_json()
else:
payload = dict(piece)
# Determine if this is a new or existing transcription
has_id = payload.get("_id") is not None
if has_id:
# Existing transcription - use save_piece
print(f"Updating existing transcription: {payload.get('title', 'untitled')}")
try:
response = self.save_piece(payload)
print("✅ Updated transcription:", response)
return response
except Exception as e:
print("❌ Failed to update transcription:", e)
raise
else:
# New transcription - remove any null _id and use insert_new_transcription
payload.pop("_id", None)
print(f"Inserting new transcription: {payload.get('title', 'untitled')}")
try:
response = self.insert_new_transcription(payload)
print("✅ Inserted transcription:", response)
return response
except Exception as e:
print("❌ Failed to insert transcription:", e)
raise
# ---- Query methods ----
[docs]
def single_query(
    self,
    transcription_id: str,
    segmentation: Union[SegmentationType, str] = SegmentationType.PHRASE,
    designator: Union[DesignatorType, str] = DesignatorType.INCLUDES,
    category: Union[CategoryType, str] = CategoryType.TRAJECTORY_ID,
    pitch: Optional[Pitch] = None,
    sequence_length: Optional[int] = None,
    trajectory_id: Optional[int] = None,
    vowel: Optional[str] = None,
    consonant: Optional[str] = None,
    instrument_idx: int = 0,
    **kwargs
) -> Query:
    """Fetch a transcription and run one query against it.

    Args:
        transcription_id: ID of the transcription to query
        segmentation: Segmentation type, as enum member or its string value
        designator: Query designator, as enum member or its string value
        category: Query category, as enum member or its string value
        pitch: Pitch object to search for (if category is pitch)
        sequence_length: Length of trajectory sequences (if needed)
        trajectory_id: Trajectory ID to search for (if category is trajectoryID)
        vowel: Vowel to search for (if category is vowel)
        consonant: Consonant to search for (if category is consonant)
        instrument_idx: Index of instrument track to query
        **kwargs: Extra query options; they override the named ones on a
            key clash because they are merged last.

    Returns:
        Query object built from the fetched piece and the options.
    """
    # Check waiver and prompt if needed
    self._prompt_for_waiver_if_needed()

    piece = Piece.from_json(self.get_piece(transcription_id))

    # Accept plain strings wherever an enum member is expected
    segmentation = SegmentationType(segmentation) if isinstance(segmentation, str) else segmentation
    designator = DesignatorType(designator) if isinstance(designator, str) else designator
    category = CategoryType(category) if isinstance(category, str) else category

    named_options = {
        "segmentation": segmentation,
        "designator": designator,
        "category": category,
        "pitch": pitch,
        "sequence_length": sequence_length,
        "trajectory_id": trajectory_id,
        "vowel": vowel,
        "consonant": consonant,
        "instrument_idx": instrument_idx,
    }
    return Query(piece, {**named_options, **kwargs})
[docs]
def multiple_query(
self,
queries: List[Union[QueryType, Dict[str, Any]]],
transcription_id: str,
segmentation: Union[SegmentationType, str] = SegmentationType.PHRASE,
sequence_length: Optional[int] = None,
min_dur: float = 0.0,
max_dur: float = 60.0,
every: bool = True,
instrument_idx: int = 0,
) -> MultipleReturnType:
"""Execute multiple queries on a transcription and combine results.
Args:
queries: List of query specifications
transcription_id: ID of transcription to query
segmentation: Segmentation type for all queries
sequence_length: Sequence length for trajectory sequences
min_dur: Minimum duration filter
max_dur: Maximum duration filter
every: If True, require all queries to match; if False, any query can match
instrument_idx: Index of instrument track to query
Returns:
Tuple of (trajectories, identifiers, query_answers)
"""
# Check waiver and prompt if needed
self._prompt_for_waiver_if_needed()
if not queries:
raise ValueError("No queries provided")
# Fetch the piece data
piece_data = self.get_piece(transcription_id)
piece = Piece.from_json(piece_data)
# Convert string enum to enum object if needed
if isinstance(segmentation, str):
segmentation = SegmentationType(segmentation)
# Execute multiple query logic (similar to the static method but integrated)
output_trajectories: List[List["Trajectory"]] = []
output_identifiers: List[str] = []
query_answers: List[QueryAnswerType] = []
non_stringified_output_identifiers: List[Union[int, str, Dict[str, int]]] = []
# Create query objects
query_objs = []
for query in queries:
# Handle both dict and QueryType
if isinstance(query, dict):
query_dict = query
else:
query_dict = dict(query)
# Convert string enums in query if needed
if "designator" in query_dict and isinstance(query_dict["designator"], str):
query_dict["designator"] = DesignatorType(query_dict["designator"])
if "category" in query_dict and isinstance(query_dict["category"], str):
query_dict["category"] = CategoryType(query_dict["category"])
query_options = {
"segmentation": segmentation,
"designator": query_dict.get("designator"),
"category": query_dict.get("category"),
"pitch": query_dict.get("pitch"),
"sequence_length": sequence_length,
"trajectory_id": query_dict.get("trajectory_id"),
"vowel": query_dict.get("vowel"),
"consonant": query_dict.get("consonant"),
"pitch_sequence": query_dict.get("pitch_sequence"),
"traj_id_sequence": query_dict.get("traj_id_sequence"),
"section_top_level": query_dict.get("section_top_level"),
"alap_section": query_dict.get("alap_section"),
"comp_type": query_dict.get("comp_type"),
"comp_sec_tempo": query_dict.get("comp_sec_tempo"),
"tala": query_dict.get("tala"),
"phrase_type": query_dict.get("phrase_type"),
"elaboration_type": query_dict.get("elaboration_type"),
"vocal_art_type": query_dict.get("vocal_art_type"),
"inst_art_type": query_dict.get("inst_art_type"),
"incidental": query_dict.get("incidental"),
"min_dur": min_dur,
"max_dur": max_dur,
"instrument_idx": instrument_idx,
}
query_objs.append(Query(piece, query_options))
if every:
# Only select trajectories that are in all answers
if query_objs:
output_identifiers = query_objs[0].stringified_identifier[:]
for answer in query_objs[1:]:
output_identifiers = [
id_str for id_str in output_identifiers
if id_str in answer.stringified_identifier
]
# Get corresponding trajectories and answers
idxs = [
query_objs[0].stringified_identifier.index(id_str)
for id_str in output_identifiers
]
output_trajectories = [query_objs[0].trajectories[idx] for idx in idxs]
non_stringified_output_identifiers = [query_objs[0].identifier[idx] for idx in idxs]
query_answers = [query_objs[0].query_answers[idx] for idx in idxs]
else:
# Select trajectories that are in any answer
start_times = []
seen_ids = set()
for answer in query_objs:
for s_idx, s_id in enumerate(answer.stringified_identifier):
if s_id not in seen_ids:
seen_ids.add(s_id)
output_identifiers.append(s_id)
output_trajectories.append(answer.trajectories[s_idx])
non_stringified_output_identifiers.append(answer.identifier[s_idx])
query_answers.append(answer.query_answers[s_idx])
start_times.append(answer.start_times[s_idx])
# Sort by start times
sort_idxs = sorted(range(len(start_times)), key=lambda i: start_times[i])
output_trajectories = [output_trajectories[idx] for idx in sort_idxs]
non_stringified_output_identifiers = [non_stringified_output_identifiers[idx] for idx in sort_idxs]
query_answers = [query_answers[idx] for idx in sort_idxs]
return output_trajectories, non_stringified_output_identifiers, query_answers