diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index 07040d6d..b4d0ef64 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -377,6 +377,7 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]: else: cmd.append(f'metric_file={self.metric_file[idx]}') cmd.append('adapt') + cmd.append("save_metric=1") if self.adapt_engaged: cmd.append('engaged=1') else: @@ -1001,5 +1002,7 @@ def compose_command( cmd.append(f'refresh={self.refresh}') if self.sig_figs is not None: cmd.append(f'sig_figs={self.sig_figs}') + # NB: implies a lower bound on cmdstan versions of 2.35 + cmd.append("save_cmdstan_config=true") cmd = self.method_args.compose(idx, cmd) return cmd diff --git a/cmdstanpy/stanfit/__init__.py b/cmdstanpy/stanfit/__init__.py index 50764a30..663aef4e 100644 --- a/cmdstanpy/stanfit/__init__.py +++ b/cmdstanpy/stanfit/__init__.py @@ -35,6 +35,7 @@ ] +# should this be a static method of each class? def from_csv( path: Union[str, List[str], os.PathLike, None] = None, method: Optional[str] = None, diff --git a/cmdstanpy/stanfit/mcmc.py b/cmdstanpy/stanfit/mcmc.py index f96ff023..84e06cde 100644 --- a/cmdstanpy/stanfit/mcmc.py +++ b/cmdstanpy/stanfit/mcmc.py @@ -2,6 +2,7 @@ Container for the result of running the sample (MCMC) method """ +import json import math import os from io import StringIO @@ -32,7 +33,6 @@ from cmdstanpy.utils import ( EXTENSION, build_xarray_data, - check_sampler_csv, cmdstan_path, cmdstan_version_before, create_named_text_file, @@ -45,6 +45,10 @@ from .runset import RunSet +# Eventually have a from_files and a from_runset, where from_runset does +# additional checks like that the requested number of draws is actually present +# In this world, we stop storing the runset, relying on the +# files-on-disk as the source of truth class CmdStanMCMC: """ Container for outputs from CmdStan sampler run. @@ -52,9 +56,7 @@ class CmdStanMCMC: and accessor methods to access the entire sample or individual items. Created by :meth:`CmdStanModel.sample` - The sample is lazily instantiated on first access of either - the resulting sample or the HMC tuning parameters, i.e., the - step size and metric. + """ # pylint: disable=too-many-public-methods @@ -99,9 +101,7 @@ def __init__( ) self._chain_time: List[Dict[str, float]] = [] - # info from CSV header and initial and final comment blocks - config = self._validate_csv_files() - self._metadata: InferenceMetadata = InferenceMetadata(config) + self._assemble_draws() if not self._is_fixed_param: self._check_sampler_diagnostics() @@ -129,14 +129,6 @@ def __getattr__(self, attr: str) -> np.ndarray: # pylint: disable=raise-missing-from raise AttributeError(*e.args) - def __getstate__(self) -> dict: - # This function returns the mapping of objects to serialize with pickle. - # See https://docs.python.org/3/library/pickle.html#object.__getstate__ - # for details. We call _assemble_draws to ensure posterior samples have - # been loaded prior to serialization. - self._assemble_draws() - return self.__dict__ - @property def chains(self) -> int: """Number of chains.""" @@ -177,7 +169,7 @@ def column_names(self) -> Tuple[str, ...]: and quantities of interest. Corresponds to Stan CSV file header row, with names munged to array notation, e.g. `beta[1]` not `beta.1`. """ - return self._metadata.cmdstan_config['column_names'] # type: ignore + return self._metadata.column_names # type: ignore @property def metric_type(self) -> Optional[str]: @@ -186,10 +178,14 @@ def metric_type(self) -> Optional[str]: to CmdStan arg 'metric'. When sampler algorithm 'fixed_param' is specified, metric_type is None. """ - return ( - self._metadata.cmdstan_config['metric'] - if not self._is_fixed_param - else None + + return ( # type: ignore + self._metadata.cmdstan_config.get("method", {}) + .get("sample", {}) + .get("algorithm", {}) + .get("hmc", {}) + .get("metric", {}) + .get("value", None) ) @property @@ -200,12 +196,6 @@ def metric(self) -> Optional[np.ndarray]: """ if self._is_fixed_param: return None - if self._metadata.cmdstan_config['metric'] == 'unit_e': - get_logger().info( - 'Unit diagnonal metric, inverse mass matrix size unknown.' - ) - return None - self._assemble_draws() return self._metric @property @@ -214,7 +204,6 @@ def step_size(self) -> Optional[np.ndarray]: Step size used by sampler for each chain. When sampler algorithm 'fixed_param' is specified, step size is None. """ - self._assemble_draws() return self._step_size if not self._is_fixed_param else None @property @@ -275,8 +264,6 @@ def draws( CmdStanMCMC.draws_xr CmdStanGQ.draws """ - self._assemble_draws() - if inc_warmup and not self._save_warmup: get_logger().warning( "Sample doesn't contain draws from warmup iterations," @@ -291,75 +278,11 @@ def draws( return flatten_chains(self._draws[start_idx:, :, :]) return self._draws[start_idx:, :, :] - def _validate_csv_files(self) -> Dict[str, Any]: - """ - Checks that Stan CSV output files for all chains are consistent - and returns dict containing config and column names. - - Tabulates sampling iters which are divergent or at max treedepth - Raises exception when inconsistencies detected. - """ - dzero = {} - for i in range(self.chains): - if i == 0: - dzero = check_sampler_csv( - path=self.runset.csv_files[i], - is_fixed_param=self._is_fixed_param, - iter_sampling=self._iter_sampling, - iter_warmup=self._iter_warmup, - save_warmup=self._save_warmup, - thin=self._thin, - ) - self._chain_time.append(dzero['time']) # type: ignore - if not self._is_fixed_param: - self._divergences[i] = dzero['ct_divergences'] - self._max_treedepths[i] = dzero['ct_max_treedepth'] - else: - drest = check_sampler_csv( - path=self.runset.csv_files[i], - is_fixed_param=self._is_fixed_param, - iter_sampling=self._iter_sampling, - iter_warmup=self._iter_warmup, - save_warmup=self._save_warmup, - thin=self._thin, - ) - self._chain_time.append(drest['time']) # type: ignore - for key in dzero: - # check args that matter for parsing, plus name, version - if ( - key - in [ - 'stan_version_major', - 'stan_version_minor', - 'stan_version_patch', - 'stanc_version', - 'model', - 'num_samples', - 'num_warmup', - 'save_warmup', - 'thin', - 'refresh', - ] - and dzero[key] != drest[key] - ): - raise ValueError( - 'CmdStan config mismatch in Stan CSV file {}: ' - 'arg {} is {}, expected {}'.format( - self.runset.csv_files[i], - key, - dzero[key], - drest[key], - ) - ) - if not self._is_fixed_param: - self._divergences[i] = drest['ct_divergences'] - self._max_treedepths[i] = drest['ct_max_treedepth'] - return dzero - def _check_sampler_diagnostics(self) -> None: """ Warn if any iterations ended in divergences or hit maxtreedepth. """ + # TODO re-write to just sum over these columns of draws if np.any(self._divergences) or np.any(self._max_treedepths): diagnostics = ['Some chains may have failed to converge.'] ct_iters = self._metadata.cmdstan_config['num_samples'] @@ -383,90 +306,52 @@ def _check_sampler_diagnostics(self) -> None: get_logger().warning('\n\t'.join(diagnostics)) def _assemble_draws(self) -> None: - """ - Allocates and populates the step size, metric, and sample arrays - by parsing the validated stan_csv files. - """ - if self._draws.shape != (0,): - return num_draws = self.num_draws_sampling - sampling_iter_start = 0 if self._save_warmup: num_draws += self.num_draws_warmup - sampling_iter_start = self.num_draws_warmup - self._draws = np.empty( - (num_draws, self.chains, len(self.column_names)), - dtype=float, - order='F', - ) + + draws = [] self._step_size = np.empty(self.chains, dtype=float) for chain in range(self.chains): - with open(self.runset.csv_files[chain], 'r') as fd: - line = fd.readline().strip() - # read initial comments, CSV header row - while len(line) > 0 and line.startswith('#'): - line = fd.readline().strip() - if not self._is_fixed_param: - # handle warmup draws, if any - if self._save_warmup: - for i in range(self.num_draws_warmup): - line = fd.readline().strip() - xs = line.split(',') - self._draws[i, chain, :] = [float(x) for x in xs] - line = fd.readline().strip() - if line != '# Adaptation terminated': # shouldn't happen? - while line != '# Adaptation terminated': - line = fd.readline().strip() - # step_size, metric (diag_e and dense_e only) - line = fd.readline().strip() - _, step_size = line.split('=') - self._step_size[chain] = float(step_size.strip()) - if self._metadata.cmdstan_config['metric'] != 'unit_e': - line = fd.readline().strip() # metric type - line = fd.readline().lstrip(' #\t').rstrip() - num_unconstrained_params = len(line.split(',')) - if chain == 0: # can't allocate w/o num params - if self.metric_type == 'diag_e': - self._metric = np.empty( - (self.chains, num_unconstrained_params), - dtype=float, - ) - else: - self._metric = np.empty( - ( - self.chains, - num_unconstrained_params, - num_unconstrained_params, - ), - dtype=float, - ) - if line: - if self.metric_type == 'diag_e': - xs = line.split(',') - self._metric[chain, :] = [float(x) for x in xs] - else: - xs = line.strip().split(',') - self._metric[chain, 0, :] = [ - float(x) for x in xs - ] - for i in range(1, num_unconstrained_params): - line = fd.readline().lstrip(' #\t').rstrip() - xs = line.split(',') - self._metric[chain, i, :] = [ - float(x) for x in xs - ] - else: # unit_e changed in 2.34 to have an extra line - pos = fd.tell() - line = fd.readline().strip() - if not line.startswith('#'): - fd.seek(pos) - - # process draws - for i in range(sampling_iter_start, num_draws): - line = fd.readline().strip() - xs = line.split(',') - self._draws[i, chain, :] = [float(x) for x in xs] - assert self._draws is not None + metric_file = self.runset.metric_files[chain] + sample_file = self.runset.csv_files[chain] + + with open(metric_file, 'r') as fd: + d = json.load(fd) + self._metric_type = d['metric_type'] + if chain == 0: + if self._metric_type == 'dense_e': + self._metric = np.empty( + ( + self.chains, + len(d['inv_metric']), + len(d['inv_metric']), + ), + dtype=float, + ) + else: + self._metric = np.empty( + (self.chains, len(d['inv_metric'])), dtype=float + ) + + self._step_size = np.empty(self.chains, dtype=float) + + self._metric[chain, ...] = np.array(d['inv_metric']) + self._metric_type = d['metric_type'] + self._step_size[chain] = d['stepsize'] + + param_names, sample = read_raw(sample_file) + if chain == 0: + self._param_names = param_names + draws.append(sample) + + self._draws = np.array(draws).reshape( + num_draws, self.chains, len(self._param_names) + ) + + with open(self.runset.config_file, 'r') as config_json: + config = json.load(config_json) + self._metadata = InferenceMetadata(config, self._param_names) def summary( self, @@ -618,7 +503,6 @@ def draws_pd( ' must run sampler with "save_warmup=True".' ) - self._assemble_draws() cols = [] if vars is not None: for var in dict.fromkeys(vars_list): @@ -695,8 +579,6 @@ def draws_xr( else: vars_list = vars - self._assemble_draws() - num_draws = self.num_draws_sampling meta = self._metadata.cmdstan_config attrs: MutableMapping[Hashable, Any] = { @@ -814,7 +696,6 @@ def method_variables(self) -> Dict[str, np.ndarray]: Maps each column name to a numpy.ndarray (draws x chains x 1) containing per-draw diagnostic values. """ - self._assemble_draws() return { name: var.extract_reshape(self._draws) for name, var in self._metadata.method_vars.items() @@ -835,3 +716,37 @@ def save_csvfiles(self, dir: Optional[str] = None) -> None: cmdstanpy.from_csv """ self.runset.save_csvfiles(dir) + + +def read_raw(file: str) -> Tuple[List[str], np.ndarray]: + with open(file, "rb") as f: + magic = f.read(4) + if magic not in [b"STAN", b"NATS"]: + raise ValueError(f"Invalid magic bytes {magic!r}, expected STAN") + + little_endian = magic == b"STAN" + + header_size = ( + int.from_bytes(f.read(8), "little" if little_endian else "big") + - 1 # null terminator + ) + header = f.read(header_size).decode("utf-8") + header_row = header.split(",") + columns = len(header_row) + + # next multiple of 8 + start = (12 + header_size + 7) & ~7 + f.seek(start) + buf = f.read() + + if len(buf) % 8 != 0: + raise ValueError("Invalid file size") + rows = len(buf) // (columns * 8) + + data: np.ndarray = np.ndarray( + shape=(rows, columns), + dtype="f8", + buffer=buf, + ) + + return (header_row, data) diff --git a/cmdstanpy/stanfit/metadata.py b/cmdstanpy/stanfit/metadata.py index 4869f2a0..9ad99d11 100644 --- a/cmdstanpy/stanfit/metadata.py +++ b/cmdstanpy/stanfit/metadata.py @@ -1,7 +1,7 @@ """Container for metadata parsed from the output of a CmdStan run""" import copy -from typing import Any, Dict +from typing import Any, Dict, List import stanio @@ -13,10 +13,11 @@ class InferenceMetadata: Assumes valid CSV files. """ - def __init__(self, config: Dict[str, Any]) -> None: + def __init__(self, config: Dict[str, Any], param_names: List[str]) -> None: """Initialize object from CSV headers""" self._cmdstan_config = config - vars = stanio.parse_header(config['raw_header']) + self.column_names = param_names + vars = stanio.parse_header(','.join(param_names)) self._method_vars = { k: v for (k, v) in vars.items() if k.endswith('__') diff --git a/cmdstanpy/stanfit/runset.py b/cmdstanpy/stanfit/runset.py index de11a461..e782bbba 100644 --- a/cmdstanpy/stanfit/runset.py +++ b/cmdstanpy/stanfit/runset.py @@ -81,9 +81,11 @@ def __init__( # per-chain output files self._csv_files: List[str] = [''] * chains self._diagnostic_files = [''] * chains # optional + self._metric_files = [''] * chains if chains == 1: self._csv_files[0] = self.file_path(".csv") + self._metric_files[0] = self.file_path(".json", extra="_metric") if args.save_latent_dynamics: self._diagnostic_files[0] = self.file_path( ".csv", extra="-diagnostic" @@ -91,6 +93,9 @@ def __init__( else: for i in range(chains): self._csv_files[i] = self.file_path(".csv", id=chain_ids[i]) + self._metric_files[i] = self.file_path( + ".json", extra="_metric", id=chain_ids[i] + ) if args.save_latent_dynamics: self._diagnostic_files[i] = self.file_path( ".csv", extra="-diagnostic", id=chain_ids[i] @@ -162,23 +167,29 @@ def cmd(self, idx: int) -> List[str]: return self._args.compose_command( idx, csv_file=self.csv_files[idx], - diagnostic_file=self.diagnostic_files[idx] - if self._args.save_latent_dynamics - else None, - profile_file=self.profile_files[idx] - if self._args.save_profile - else None, + diagnostic_file=( + self.diagnostic_files[idx] + if self._args.save_latent_dynamics + else None + ), + profile_file=( + self.profile_files[idx] if self._args.save_profile else None + ), ) else: return self._args.compose_command( idx, csv_file=self.file_path('.csv'), - diagnostic_file=self.file_path(".csv", extra="-diagnostic") - if self._args.save_latent_dynamics - else None, - profile_file=self.file_path(".csv", extra="-profile") - if self._args.save_profile - else None, + diagnostic_file=( + self.file_path(".csv", extra="-diagnostic") + if self._args.save_latent_dynamics + else None + ), + profile_file=( + self.file_path(".csv", extra="-profile") + if self._args.save_profile + else None + ), ) @property @@ -186,6 +197,21 @@ def csv_files(self) -> List[str]: """List of paths to CmdStan output files.""" return self._csv_files + @property + def metric_files(self) -> List[str]: + """List of paths to CmdStan output files.""" + return self._metric_files + + @property + def config_file(self) -> str: + """Path to CmdStan config file.""" + if self.one_process_per_chain: + return self.file_path( + ".json", extra="_config", id=self.chain_ids[0] + ) + else: + return self.file_path(".json", extra="_config") + @property def stdout_files(self) -> List[str]: """ @@ -216,7 +242,11 @@ def file_path( self, suffix: str, *, extra: str = "", id: Optional[int] = None ) -> str: if id is not None: - suffix = f"_{id}{suffix}" + # TODO sort this out + if self.one_process_per_chain: + extra = f"_{id}{extra}" + else: + suffix = f"_{id}{suffix}" file = os.path.join( self._output_dir, f"{self._base_outfile}{extra}{suffix}" ) diff --git a/cmdstanpy_tutorial.py b/cmdstanpy_tutorial.py index e7f052f6..d8bdccad 100644 --- a/cmdstanpy_tutorial.py +++ b/cmdstanpy_tutorial.py @@ -32,6 +32,7 @@ # ### Access the sample: the `CmdStanMCMC` object attributes and methods print(fit.draws().shape) +print({k:v.mean() for k,v in fit.stan_variables().items()}) # #### Get HMC sampler tuning parameters @@ -39,10 +40,12 @@ print(fit.metric_type) print(fit.metric) + + # #### Summarize the results -print(fit.summary()) +# print(fit.summary()) # #### Run sampler diagnostics -print(fit.diagnose()) +# print(fit.diagnose())