# Source code for cmdstanpy.stanfit.runset

"""
Container for the information used in a generic CmdStan run,
such as file locations.
"""

import os
import re
import shutil
import tempfile
from datetime import datetime
from time import time
from typing import List, Optional

from cmdstanpy import _TMPDIR
from cmdstanpy.cmdstan_args import CmdStanArgs, Method
from cmdstanpy.utils import get_logger


class RunSet:
    """
    Encapsulates the configuration and results of a call to any CmdStan
    inference method. Records the method return code and locations of
    all console, error, and output files.

    RunSet objects are instantiated by the CmdStanModel class inference
    methods which validate all inputs, therefore "__init__" method skips
    input checks.
    """

    def __init__(
        self,
        args: CmdStanArgs,
        chains: int = 1,
        *,
        chain_ids: Optional[List[int]] = None,
        time_fmt: str = "%Y%m%d%H%M%S",
        one_process_per_chain: bool = True,
    ) -> None:
        """Initialize object (no input arg checks).

        :param args: CmdStan arguments for this run.
        :param chains: number of chains.
        :param chain_ids: per-chain ids; defaults to ``1 .. chains``.
        :param time_fmt: ``strftime`` format of the timestamp embedded in
            every output file name.
        :param one_process_per_chain: when True, one CmdStan process is run
            per chain; otherwise a single process runs all chains.
        """
        self._args = args
        self._chains = chains
        self._one_process_per_chain = one_process_per_chain
        self._num_procs = chains if one_process_per_chain else 1
        # -1 until a real code is recorded via _set_retcode
        self._retcodes = [-1] * self._num_procs
        self._timeout_flags = [False] * self._num_procs
        if chain_ids is None:
            chain_ids = list(range(1, chains + 1))
        self._chain_ids = chain_ids

        if args.output_dir is not None:
            self._output_dir = args.output_dir
        else:
            # make a per-run subdirectory of our master temp directory
            self._output_dir = tempfile.mkdtemp(
                prefix=args.model_name, dir=_TMPDIR
            )

        # output files prefix: ``<model_name>-<timestamp>``, where the
        # timestamp is formatted with ``time_fmt`` (``YYYYMMDDHHMMSS`` by
        # default); per-chain files add ``_<chain_id>`` when chains > 1
        self._base_outfile = (
            f'{args.model_name}-{datetime.now().strftime(time_fmt)}'
        )

        # per-process console messages
        self._stdout_files = [''] * self._num_procs
        if one_process_per_chain:
            for i in range(chains):
                # NOTE: keyed by 0-based process index, not by chain id
                self._stdout_files[i] = self.file_path("-stdout.txt", id=i)
        else:
            self._stdout_files[0] = self.file_path("-stdout.txt")

        # per-chain output files; diagnostic and profile files are optional
        self._csv_files: List[str] = [''] * chains
        self._diagnostic_files = [''] * chains
        self._profile_files = [''] * chains
        for i in range(chains):
            # a lone chain gets no ``_<chain_id>`` suffix
            chain_id = chain_ids[i] if chains > 1 else None
            self._csv_files[i] = self.file_path(".csv", id=chain_id)
            if args.save_latent_dynamics:
                self._diagnostic_files[i] = self.file_path(
                    ".csv", extra="-diagnostic", id=chain_id
                )
            if args.save_profile:
                self._profile_files[i] = self.file_path(
                    ".csv", extra="-profile", id=chain_id
                )

    def __repr__(self) -> str:
        """Summarize the run, showing the files of the first chain only."""
        parts = [
            'RunSet: chains={}, chain_ids={}, num_processes={}'.format(
                self._chains, self._chain_ids, self._num_procs
            ),
            '\n cmd (chain 1):\n\t{}'.format(self.cmd(0)),
            '\n retcodes={}'.format(self._retcodes),
            '\n per-chain output files (showing chain 1 only):',
            '\n csv_file:\n\t{}'.format(self._csv_files[0]),
        ]
        if self._args.save_latent_dynamics:
            parts.append(
                '\n diagnostics_file:\n\t{}'.format(self._diagnostic_files[0])
            )
        if self._args.save_profile:
            parts.append(
                '\n profile_file:\n\t{}'.format(self._profile_files[0])
            )
        parts.append(
            '\n console_msgs (if any):\n\t{}'.format(self._stdout_files[0])
        )
        return ''.join(parts)

    @property
    def model(self) -> str:
        """Stan model name."""
        return self._args.model_name

    @property
    def method(self) -> Method:
        """CmdStan method used to generate this fit."""
        return self._args.method

    @property
    def num_procs(self) -> int:
        """Number of processes run."""
        return self._num_procs

    @property
    def one_process_per_chain(self) -> bool:
        """
        When True, for each chain, call CmdStan in its own subprocess.
        When False, use CmdStan's `num_chains` arg to run parallel chains.
        Always True if CmdStan < 2.28.
        For CmdStan 2.28 and up, `sample` method determines value.
        """
        return self._one_process_per_chain

    @property
    def chains(self) -> int:
        """Number of chains."""
        return self._chains

    @property
    def chain_ids(self) -> List[int]:
        """Chain ids."""
        return self._chain_ids

    def cmd(self, idx: int) -> List[str]:
        """
        Assemble CmdStan invocation.

        When running parallel chains from single process (2.28 and up),
        specify CmdStan arg `num_chains` and leave chain idx off CSV files.

        :param idx: process index passed through to ``compose_command``.
        :return: the command as a list of arguments.
        """
        if self._one_process_per_chain:
            csv_file = self.csv_files[idx]
            diagnostic_file = self.diagnostic_files[idx]
            profile_file = self.profile_files[idx]
        else:
            csv_file = self.file_path('.csv')
            diagnostic_file = self.file_path(".csv", extra="-diagnostic")
            profile_file = self.file_path(".csv", extra="-profile")
        return self._args.compose_command(
            idx,
            csv_file=csv_file,
            diagnostic_file=diagnostic_file
            if self._args.save_latent_dynamics
            else None,
            profile_file=profile_file if self._args.save_profile else None,
        )

    @property
    def csv_files(self) -> List[str]:
        """List of paths to CmdStan output files."""
        return self._csv_files

    @property
    def stdout_files(self) -> List[str]:
        """
        List of paths to transcript of CmdStan messages sent to the console.
        Transcripts include config information, progress, and error messages.
        """
        return self._stdout_files

    def _check_retcodes(self) -> bool:
        """Returns ``True`` when all chains have retcode 0."""
        return all(code == 0 for code in self._retcodes)

    @property
    def diagnostic_files(self) -> List[str]:
        """List of paths to CmdStan hamiltonian diagnostic files."""
        return self._diagnostic_files

    @property
    def profile_files(self) -> List[str]:
        """List of paths to CmdStan profiler files."""
        return self._profile_files

    # pylint: disable=invalid-name
    def file_path(
        self, suffix: str, *, extra: str = "", id: Optional[int] = None
    ) -> str:
        """
        Build a path inside the output directory.

        The file name is ``<base_outfile><extra>[_<id>]<suffix>``; the
        ``_<id>`` part is present only when ``id`` is not None.
        """
        if id is not None:
            suffix = f"_{id}{suffix}"
        return os.path.join(
            self._output_dir, f"{self._base_outfile}{extra}{suffix}"
        )

    def _retcode(self, idx: int) -> int:
        """Get retcode for process[idx]."""
        return self._retcodes[idx]

    def _set_retcode(self, idx: int, val: int) -> None:
        """Set retcode at process[idx] to val."""
        self._retcodes[idx] = val

    def _set_timeout_flag(self, idx: int, val: bool) -> None:
        """Set timeout_flag at process[idx] to val."""
        self._timeout_flags[idx] = val

    def get_err_msgs(self) -> str:
        """
        Checks console messages for each CmdStan run.

        Missing or empty console transcripts are skipped. For the OPTIMIZE
        method, each process's full transcript is included after a
        ``console log output:`` header; for other methods only lines that
        start with ``Er`` or ``Ex`` (e.g. "Error", "Exception") are kept.

        :return: the collected messages joined by newlines ('' if none).
        """
        # pattern matches initial "Exception" or "Error" msg
        pat = re.compile(r'^E[rx].*$', re.M)
        msgs = []
        for i in range(self._num_procs):
            stdout_file = self._stdout_files[i]
            if not (
                os.path.exists(stdout_file)
                and os.stat(stdout_file).st_size > 0
            ):
                continue
            # read this process's own transcript (previously OPTIMIZE always
            # re-read process 0's file)
            with open(stdout_file, 'r') as fd:
                contents = fd.read()
            if self._args.method == Method.OPTIMIZE:
                msgs.append('console log output:\n')
                msgs.append(contents)
            else:
                errors = pat.findall(contents)
                if errors:
                    msgs.append('\n\t'.join(errors))
        return '\n'.join(msgs)

    def save_csvfiles(self, dir: Optional[str] = None) -> None:
        """
        Moves CSV files to specified directory.

        All source/destination pairs are validated before any file is
        moved. A missing source or an existing destination therefore
        leaves every CSV file in place.

        :param dir: directory path; defaults to the current directory.

        :raises RuntimeError: if ``dir`` cannot be created or written to.
        :raises ValueError: if a CSV file is missing, a destination file
            already exists, or a move fails.

        See Also
        --------
        cmdstanpy.from_csv
        """
        if dir is None:
            dir = os.path.realpath('.')
        # probe writability by creating and removing a uniquely-named file
        test_path = os.path.join(dir, str(time()))
        try:
            os.makedirs(dir, exist_ok=True)
            with open(test_path, 'w'):
                pass
            os.remove(test_path)  # cleanup
        except OSError as exc:
            raise RuntimeError('Cannot save to path: {}'.format(dir)) from exc

        destinations: List[str] = []
        for i in range(self.chains):
            src = self._csv_files[i]
            if not os.path.exists(src):
                raise ValueError('Cannot access CSV file {}'.format(src))
            to_path = os.path.join(dir, os.path.basename(src))
            # also reject two chains mapping onto the same destination
            if os.path.exists(to_path) or to_path in destinations:
                raise ValueError(
                    'File exists, not overwriting: {}'.format(to_path)
                )
            destinations.append(to_path)

        for i, to_path in enumerate(destinations):
            get_logger().debug(
                'saving tmpfile: "%s" as: "%s"', self._csv_files[i], to_path
            )
            try:
                shutil.move(self._csv_files[i], to_path)
            except OSError as e:
                raise ValueError(
                    'Cannot save to file: {}'.format(to_path)
                ) from e
            self._csv_files[i] = to_path

    def raise_for_timeouts(self) -> None:
        """
        Raise if any process was flagged as timed out.

        :raises TimeoutError: naming how many processes timed out.
        """
        if any(self._timeout_flags):
            raise TimeoutError(
                f"{sum(self._timeout_flags)} of {self.num_procs} processes "
                "timed out"
            )