"""CmdStan arguments"""
import os
from enum import Enum, auto
from time import time
from typing import Any, Dict, List, Mapping, Optional, Union

import numpy as np
from numpy.random import default_rng

from cmdstanpy import _TMPDIR
from cmdstanpy.utils import (
    cmdstan_path,
    cmdstan_version_before,
    create_named_text_file,
    get_logger,
    read_metric,
    write_stan_json,
)

OptionalPath = Union[str, os.PathLike, None]


class Method(Enum):
    """Supported CmdStan method names."""

    SAMPLE = auto()
    OPTIMIZE = auto()
    GENERATE_QUANTITIES = auto()
    VARIATIONAL = auto()
    LAPLACE = auto()
    PATHFINDER = auto()

    def __repr__(self) -> str:
        return '<%s.%s>' % (self.__class__.__name__, self.name)


def positive_int(value: Any, name: str) -> None:
    """
    Check that ``value`` is either None or a strictly positive integer.

    :param value: value to check; None means "not specified" and passes.
    :param name: argument name used in the error message.
    :raises ValueError: if ``value`` is not an ``int``/``np.integer``
        (``bool`` is rejected even though it subclasses ``int``, since
        ``True`` would otherwise be emitted as ``True``), or is <= 0.
    """
    if value is None:
        return
    if isinstance(value, bool) or not isinstance(value, (int, np.integer)):
        raise ValueError(f'{name} must be of type int')
    if value <= 0:
        raise ValueError(f'{name} must be greater than 0')


def positive_float(value: Any, name: str) -> None:
    """
    Check that ``value`` is either None or a strictly positive real number.

    Integers (Python and NumPy) are accepted as well as floats, consistent
    with :func:`positive_int`.

    :param value: value to check; None means "not specified" and passes.
    :param name: argument name used in the error message.
    :raises ValueError: if ``value`` is not numeric (``bool`` is rejected),
        or is not greater than 0 (which also rejects NaN).
    """
    if value is None:
        return
    if isinstance(value, bool) or not isinstance(
        value, (int, float, np.integer, np.floating)
    ):
        raise ValueError(f'{name} must be of type float')
    # ``not value > 0`` rather than ``value <= 0`` so that NaN is rejected.
    if not value > 0:
        raise ValueError(f'{name} must be greater than 0')
class SamplerArgs:
    """Arguments for the NUTS adaptive sampler."""

    def __init__(
        self,
        iter_warmup: Optional[int] = None,
        iter_sampling: Optional[int] = None,
        save_warmup: bool = False,
        thin: Optional[int] = None,
        max_treedepth: Optional[int] = None,
        metric: Union[
            str, Dict[str, Any], List[str], List[Dict[str, Any]], None
        ] = None,
        step_size: Union[float, List[float], None] = None,
        adapt_engaged: bool = True,
        adapt_delta: Optional[float] = None,
        adapt_init_phase: Optional[int] = None,
        adapt_metric_window: Optional[int] = None,
        adapt_step_size: Optional[int] = None,
        fixed_param: bool = False,
        num_chains: int = 1,
    ) -> None:
        """Initialize object."""
        self.iter_warmup = iter_warmup
        self.iter_sampling = iter_sampling
        self.save_warmup = save_warmup
        self.thin = thin
        self.max_treedepth = max_treedepth
        self.metric = metric
        # Derived from ``metric`` by validate(); read by compose().
        self.metric_type: Optional[str] = None
        self.metric_file: Union[str, List[str], None] = None
        self.step_size = step_size
        self.adapt_engaged = adapt_engaged
        self.adapt_delta = adapt_delta
        self.adapt_init_phase = adapt_init_phase
        self.adapt_metric_window = adapt_metric_window
        self.adapt_step_size = adapt_step_size
        self.fixed_param = fixed_param
        self.diagnostic_file = None
        self.num_chains = num_chains

    @staticmethod
    def _is_non_negative_int(value: Any) -> bool:
        """Return True if ``value`` is an ``int``/``np.integer`` >= 0.

        The type is tested before the comparison so that non-numeric values
        yield False (and thus a ValueError) instead of a TypeError.
        """
        return isinstance(value, (int, np.integer)) and value >= 0

    @staticmethod
    def _is_scalar_step_size(value: Any) -> bool:
        """Return True if ``value`` is a single (not per-chain) step size."""
        return isinstance(value, (float, int, np.integer, np.floating))

    @staticmethod
    def _metric_type_for(dims: List[int]) -> str:
        """Map inverse-metric dimensions to a CmdStan metric type name."""
        return 'diag_e' if len(dims) == 1 else 'dense_e'

    def _any_adapt_arg_set(self) -> bool:
        """Return True if any of the adapt_* tuning arguments was given."""
        return not (
            self.adapt_delta is None
            and self.adapt_init_phase is None
            and self.adapt_metric_window is None
            and self.adapt_step_size is None
        )

    def validate(self, chains: Optional[int]) -> None:
        """
        Check arguments correctness and consistency.

        * adaptation and warmup args are consistent
        * if file(s) for metric are supplied, check contents.
        * length of per-chain lists equals specified # of chains

        Side effects: sets ``metric_type`` / ``metric_file``; dict metrics
        are written out via ``write_stan_json`` to files created by
        ``create_named_text_file``.

        :param chains: number of chains; must be an integer >= 1.
        :raises ValueError: on any invalid or inconsistent argument.
        """
        if not isinstance(chains, (int, np.integer)) or chains < 1:
            raise ValueError(
                'Sampler expects number of chains to be greater than 0.'
            )
        if self._any_adapt_arg_set() and self.adapt_engaged is False:
            msg = 'Conflicting arguments: adapt_engaged: False'
            for name in (
                'adapt_delta',
                'adapt_init_phase',
                'adapt_metric_window',
                'adapt_step_size',
            ):
                value = getattr(self, name)
                if value is not None:
                    msg = '{}, {}: {}'.format(msg, name, value)
            raise ValueError(msg)

        if self.iter_warmup is not None:
            if not self._is_non_negative_int(self.iter_warmup):
                raise ValueError(
                    'Value for iter_warmup must be a non-negative integer,'
                    ' found {}.'.format(self.iter_warmup)
                )
            if self.iter_warmup == 0 and self.adapt_engaged:
                raise ValueError(
                    'Must specify iter_warmup > 0 when adapt_engaged=True.'
                )
        if self.iter_sampling is not None:
            if not self._is_non_negative_int(self.iter_sampling):
                raise ValueError(
                    'Argument "iter_sampling" must be a non-negative integer,'
                    ' found {}.'.format(self.iter_sampling)
                )

        positive_int(self.thin, 'thin')
        positive_int(self.max_treedepth, 'max_treedepth')

        self._validate_step_size(chains)
        self._validate_metric(chains)

        if self.adapt_delta is not None:
            if not isinstance(
                self.adapt_delta, (int, float, np.integer, np.floating)
            ) or not 0 < self.adapt_delta < 1:
                raise ValueError(
                    'Argument "adapt_delta" must be between 0 and 1,'
                    ' found {}'.format(self.adapt_delta)
                )
        if self.adapt_init_phase is not None:
            if not self._is_non_negative_int(self.adapt_init_phase):
                raise ValueError(
                    'Argument "adapt_init_phase" must be a non-negative '
                    'integer, found {}'.format(self.adapt_init_phase)
                )
        if self.adapt_metric_window is not None:
            if not self._is_non_negative_int(self.adapt_metric_window):
                raise ValueError(
                    'Argument "adapt_metric_window" must be a non-negative '
                    ' integer, found {}'.format(self.adapt_metric_window)
                )
        if self.adapt_step_size is not None:
            if not self._is_non_negative_int(self.adapt_step_size):
                raise ValueError(
                    'Argument "adapt_step_size" must be a non-negative integer,'
                    'found {}'.format(self.adapt_step_size)
                )
        positive_int(self.num_chains, 'num_chains')

        if self.fixed_param and (
            self.max_treedepth is not None
            or self.metric is not None
            or self.step_size is not None
            or self._any_adapt_arg_set()
        ):
            raise ValueError(
                'When fixed_param=True, cannot specify adaptation parameters.'
            )

    def _validate_step_size(self, chains: int) -> None:
        """Check a scalar step size, or one step size > 0 per chain."""
        if self.step_size is None:
            return
        if self._is_scalar_step_size(self.step_size):
            if self.step_size <= 0:
                raise ValueError(
                    'Argument "step_size" must be > 0, '
                    'found {}.'.format(self.step_size)
                )
            return
        if len(self.step_size) != chains:
            raise ValueError(
                'Expecting {} per-chain step_size specifications, '
                ' found {}.'.format(chains, len(self.step_size))
            )
        for i, step_size in enumerate(self.step_size):
            # Strictly positive, matching the scalar case and the message.
            if step_size <= 0:
                raise ValueError(
                    'Argument "step_size" must be > 0, '
                    'chain {}, found {}.'.format(i + 1, step_size)
                )

    def _validate_metric(self, chains: int) -> None:
        """Resolve ``metric`` into ``metric_type`` and ``metric_file``.

        Accepts a metric type name, a metric file path, a dict with an
        ``inv_metric`` entry, or a per-chain list of paths or dicts.
        """
        if self.metric is None:
            return
        if isinstance(self.metric, str):
            if self.metric in ['diag', 'diag_e']:
                self.metric_type = 'diag_e'
            elif self.metric in ['dense', 'dense_e']:
                self.metric_type = 'dense_e'
            elif self.metric in ['unit', 'unit_e']:
                self.metric_type = 'unit_e'
            else:
                if not os.path.exists(self.metric):
                    raise ValueError('no such file {}'.format(self.metric))
                dims = read_metric(self.metric)
                self.metric_type = self._metric_type_for(dims)
                self.metric_file = self.metric
        elif isinstance(self.metric, dict):
            if 'inv_metric' not in self.metric:
                raise ValueError(
                    'Entry "inv_metric" not found in metric dict.'
                )
            dims = list(np.asarray(self.metric['inv_metric']).shape)
            self.metric_type = self._metric_type_for(dims)
            dict_file = create_named_text_file(
                dir=_TMPDIR, prefix="metric", suffix=".json"
            )
            write_stan_json(dict_file, self.metric)
            self.metric_file = dict_file
        elif isinstance(self.metric, (list, tuple)):
            self._validate_metric_list(chains)
        else:
            raise ValueError(
                'Invalid metric specified, not a recognized metric type, '
                'must be either a metric type name, a filepath, dict, '
                'or list of per-chain filepaths or dicts. Found '
                'an object of type {}.'.format(type(self.metric))
            )

    def _validate_metric_list(self, chains: int) -> None:
        """Check a per-chain list of metric dicts or metric file paths.

        All entries must be of the same kind and have identical
        ``inv_metric`` dimensions; the first entry is the reference.
        """
        metrics: Any = self.metric
        if len(metrics) != chains:
            raise ValueError(
                'Number of metric files must match number of chains,'
                ' found {} metric files for {} chains.'.format(
                    len(metrics), chains
                )
            )
        metric_files: List[str] = []
        dims: List[int] = []
        if all(isinstance(elem, dict) for elem in metrics):
            for i, metric_dict in enumerate(metrics):
                if 'inv_metric' not in metric_dict:
                    raise ValueError(
                        'Entry "inv_metric" not found in metric dict '
                        'for chain {}.'.format(i + 1)
                    )
                chain_dims = list(
                    np.asarray(metric_dict['inv_metric']).shape
                )
                if i == 0:
                    dims = chain_dims
                elif chain_dims != dims:
                    raise ValueError(
                        'Found inconsistent "inv_metric" entry '
                        'for chain {}: entry has dims '
                        '{}, expected {}.'.format(i + 1, chain_dims, dims)
                    )
                dict_file = create_named_text_file(
                    dir=_TMPDIR, prefix="metric", suffix=".json"
                )
                write_stan_json(dict_file, metric_dict)
                metric_files.append(dict_file)
        elif all(isinstance(elem, str) for elem in metrics):
            for i, metric_path in enumerate(metrics):
                if not os.path.exists(metric_path):
                    raise ValueError('no such file {}'.format(metric_path))
                chain_dims = read_metric(metric_path)
                if i == 0:
                    dims = chain_dims
                elif chain_dims != dims:
                    raise ValueError(
                        'Metrics files {}, {},'
                        ' inconsistent metrics'.format(metrics[0], metric_path)
                    )
                metric_files.append(metric_path)
        else:
            raise ValueError(
                'Argument "metric" must be a list of pathnames or '
                'Python dicts, found list of {}.'.format(type(metrics[0]))
            )
        self.metric_type = self._metric_type_for(dims)
        self.metric_file = metric_files

    def compose(self, idx: int, cmd: List[str]) -> List[str]:
        """
        Compose CmdStan command for method-specific non-default arguments.

        :param idx: zero-based chain index used to pick per-chain values.
        :param cmd: command list; appended to in place and returned.
        """
        cmd.append('method=sample')
        if self.iter_sampling is not None:
            cmd.append(f'num_samples={self.iter_sampling}')
        if self.iter_warmup is not None:
            cmd.append(f'num_warmup={self.iter_warmup}')
        if self.save_warmup:
            cmd.append('save_warmup=1')
        if self.thin is not None:
            cmd.append(f'thin={self.thin}')
        if self.fixed_param:
            cmd.append('algorithm=fixed_param')
            return cmd
        cmd.append('algorithm=hmc')
        if self.max_treedepth is not None:
            cmd.append('engine=nuts')
            cmd.append(f'max_depth={self.max_treedepth}')
        if self.step_size is not None:
            # Same scalar/sequence split as validate(), so tuples and
            # arrays of per-chain step sizes are indexed too.
            if self._is_scalar_step_size(self.step_size):
                cmd.append(f'stepsize={self.step_size}')
            else:
                cmd.append(f'stepsize={self.step_size[idx]}')
        if self.metric is not None:
            cmd.append(f'metric={self.metric_type}')
        if self.metric_file is not None:
            if isinstance(self.metric_file, list):
                cmd.append(f'metric_file={self.metric_file[idx]}')
            else:
                cmd.append(f'metric_file={self.metric_file}')
        cmd.append('adapt')
        if self.adapt_engaged:
            cmd.append('engaged=1')
        else:
            cmd.append('engaged=0')
        if self.adapt_delta is not None:
            cmd.append(f'delta={self.adapt_delta}')
        if self.adapt_init_phase is not None:
            cmd.append(f'init_buffer={self.adapt_init_phase}')
        if self.adapt_metric_window is not None:
            cmd.append(f'window={self.adapt_metric_window}')
        if self.adapt_step_size is not None:
            cmd.append('term_buffer={}'.format(self.adapt_step_size))
        if self.num_chains > 1:
            cmd.append('num_chains={}'.format(self.num_chains))
        return cmd
class OptimizeArgs:
    """Container for arguments for the optimizer."""

    OPTIMIZE_ALGOS = {'BFGS', 'bfgs', 'LBFGS', 'lbfgs', 'Newton', 'newton'}
    # Tuning arguments that CmdStan only accepts under (L-)BFGS.
    bfgs_only = {
        "init_alpha",
        "tol_obj",
        "tol_rel_obj",
        "tol_grad",
        "tol_rel_grad",
        "tol_param",
        "history_size",
    }

    def __init__(
        self,
        algorithm: Optional[str] = None,
        init_alpha: Optional[float] = None,
        iter: Optional[int] = None,
        save_iterations: bool = False,
        tol_obj: Optional[float] = None,
        tol_rel_obj: Optional[float] = None,
        tol_grad: Optional[float] = None,
        tol_rel_grad: Optional[float] = None,
        tol_param: Optional[float] = None,
        history_size: Optional[int] = None,
        jacobian: bool = False,
    ) -> None:
        self.algorithm = algorithm or ""
        self.init_alpha = init_alpha
        self.iter = iter
        self.save_iterations = save_iterations
        self.tol_obj = tol_obj
        self.tol_rel_obj = tol_rel_obj
        self.tol_grad = tol_grad
        self.tol_rel_grad = tol_rel_grad
        self.tol_param = tol_param
        self.history_size = history_size
        self.jacobian = jacobian

    def validate(self, _chains: Optional[int] = None) -> None:
        """
        Check arguments correctness and consistency.

        The algorithm name is matched case-insensitively (``compose``
        lowercases it anyway). (L-)BFGS-only arguments require the
        algorithm to be explicitly ``bfgs`` or ``lbfgs``; ``history_size``
        requires ``lbfgs``.

        :raises ValueError: on an unknown algorithm, an argument that does
            not apply to the chosen algorithm, or a non-positive value.
        """
        known = {name.lower() for name in self.OPTIMIZE_ALGOS}
        if self.algorithm and (
            not isinstance(self.algorithm, str)
            or self.algorithm.lower() not in known
        ):
            # sorted() keeps the message deterministic across runs.
            raise ValueError(
                'Please specify optimizer algorithms as one of [{}]'.format(
                    ', '.join(sorted(self.OPTIMIZE_ALGOS))
                )
            )

        algorithm = self.algorithm.lower()
        if algorithm not in {'bfgs', 'lbfgs'}:
            for arg in sorted(self.bfgs_only):
                if getattr(self, arg) is not None:
                    raise ValueError(
                        f'{arg} requires that algorithm be set to bfgs or lbfgs'
                    )
        if algorithm != 'lbfgs' and self.history_size is not None:
            raise ValueError(
                'history_size requires that algorithm be set to lbfgs'
            )

        positive_float(self.init_alpha, 'init_alpha')
        positive_int(self.iter, 'iter')
        positive_float(self.tol_obj, 'tol_obj')
        positive_float(self.tol_rel_obj, 'tol_rel_obj')
        positive_float(self.tol_grad, 'tol_grad')
        positive_float(self.tol_rel_grad, 'tol_rel_grad')
        positive_float(self.tol_param, 'tol_param')
        positive_int(self.history_size, 'history_size')

    def compose(self, _idx: int, cmd: List[str]) -> List[str]:
        """compose command string for CmdStan for non-default arg values."""
        cmd.append('method=optimize')
        if self.algorithm:
            cmd.append(f'algorithm={self.algorithm.lower()}')
        if self.init_alpha is not None:
            cmd.append(f'init_alpha={self.init_alpha}')
        if self.tol_obj is not None:
            cmd.append(f'tol_obj={self.tol_obj}')
        if self.tol_rel_obj is not None:
            cmd.append(f'tol_rel_obj={self.tol_rel_obj}')
        if self.tol_grad is not None:
            cmd.append(f'tol_grad={self.tol_grad}')
        if self.tol_rel_grad is not None:
            cmd.append(f'tol_rel_grad={self.tol_rel_grad}')
        if self.tol_param is not None:
            cmd.append(f'tol_param={self.tol_param}')
        if self.history_size is not None:
            cmd.append(f'history_size={self.history_size}')
        if self.iter is not None:
            cmd.append(f'iter={self.iter}')
        if self.save_iterations:
            cmd.append('save_iterations=1')
        if self.jacobian:
            cmd.append("jacobian=1")
        return cmd
class LaplaceArgs:
    """Arguments needed for laplace method."""

    def __init__(
        self, mode: str, draws: Optional[int] = None, jacobian: bool = True
    ) -> None:
        # ``mode``: path to the mode file, checked for existence by validate().
        self.mode = mode
        self.jacobian = jacobian
        self.draws = draws

    def validate(self, _chains: Optional[int] = None) -> None:
        """
        Check arguments correctness and consistency.

        :raises ValueError: if the mode file does not exist or ``draws``
            is not a positive integer.
        """
        mode_path = self.mode
        if os.path.exists(mode_path):
            positive_int(self.draws, 'draws')
            return
        raise ValueError(f'Invalid path for mode file: {mode_path}')

    def compose(self, _idx: int, cmd: List[str]) -> List[str]:
        """
        Append the laplace method arguments to ``cmd`` (in place).

        ``draws`` is emitted only when truthy and ``jacobian=0`` only when
        the Jacobian adjustment is turned off.
        """
        cmd.extend(['method=laplace', f'mode={self.mode}'])
        extra: List[str] = []
        if self.draws:
            extra.append(f'draws={self.draws}')
        if not self.jacobian:
            extra.append("jacobian=0")
        cmd.extend(extra)
        return cmd
class PathfinderArgs:
    """Container for arguments for Pathfinder."""

    # Optional numeric arguments, in the order they are validated and
    # emitted; each CmdStan option name equals the attribute name.
    _FLOAT_ARGS = (
        'init_alpha',
        'tol_obj',
        'tol_rel_obj',
        'tol_grad',
        'tol_rel_grad',
        'tol_param',
    )
    _INT_ARGS = (
        'history_size',
        'num_psis_draws',
        'num_paths',
        'max_lbfgs_iters',
        'num_draws',
        'num_elbo_draws',
    )

    def __init__(
        self,
        init_alpha: Optional[float] = None,
        tol_obj: Optional[float] = None,
        tol_rel_obj: Optional[float] = None,
        tol_grad: Optional[float] = None,
        tol_rel_grad: Optional[float] = None,
        tol_param: Optional[float] = None,
        history_size: Optional[int] = None,
        num_psis_draws: Optional[int] = None,
        num_paths: Optional[int] = None,
        max_lbfgs_iters: Optional[int] = None,
        num_draws: Optional[int] = None,
        num_elbo_draws: Optional[int] = None,
        save_single_paths: bool = False,
        psis_resample: bool = True,
        calculate_lp: bool = True,
    ) -> None:
        self.init_alpha = init_alpha
        self.tol_obj = tol_obj
        self.tol_rel_obj = tol_rel_obj
        self.tol_grad = tol_grad
        self.tol_rel_grad = tol_rel_grad
        self.tol_param = tol_param
        self.history_size = history_size
        self.num_psis_draws = num_psis_draws
        self.num_paths = num_paths
        self.max_lbfgs_iters = max_lbfgs_iters
        self.num_draws = num_draws
        self.num_elbo_draws = num_elbo_draws
        self.save_single_paths = save_single_paths
        self.psis_resample = psis_resample
        self.calculate_lp = calculate_lp

    def validate(self, _chains: Optional[int] = None) -> None:
        """
        Check arguments correctness and consistency.

        Every numeric argument that was supplied must be strictly positive;
        count-like arguments must additionally be integers.
        """
        for attr in self._FLOAT_ARGS:
            positive_float(getattr(self, attr), attr)
        for attr in self._INT_ARGS:
            positive_int(getattr(self, attr), attr)

    def compose(self, _idx: int, cmd: List[str]) -> List[str]:
        """
        Append the pathfinder method arguments to ``cmd`` (in place).

        Only supplied numeric arguments are emitted; boolean flags are
        emitted only when they differ from the CmdStan defaults.
        """
        cmd.append('method=pathfinder')
        for attr in self._FLOAT_ARGS + self._INT_ARGS:
            value = getattr(self, attr)
            if value is None:
                continue
            cmd.append(f'{attr}={value}')
        if self.save_single_paths:
            cmd.append('save_single_paths=1')
        if not self.psis_resample:
            cmd.append('psis_resample=0')
        if not self.calculate_lp:
            cmd.append('calculate_lp=0')
        return cmd
class GenerateQuantitiesArgs:
    """Arguments needed for generate_quantities method."""

    def __init__(self, csv_files: List[str]) -> None:
        """Initialize object.

        :param csv_files: one fitted-params CSV path per chain, indexed by
            chain position in compose().
        """
        self.sample_csv_files = csv_files

    def validate(
        self, chains: Optional[int] = None  # pylint: disable=unused-argument
    ) -> None:
        """
        Check arguments correctness and consistency.

        * every listed sample CSV file must exist

        :raises ValueError: naming the first path that does not exist.
        """
        for csv_path in self.sample_csv_files:
            if os.path.exists(csv_path):
                continue
            raise ValueError(
                'Invalid path for sample csv file: {}'.format(csv_path)
            )

    def compose(self, idx: int, cmd: List[str]) -> List[str]:
        """
        Append the generate_quantities arguments for chain ``idx`` to
        ``cmd`` (in place) and return it.
        """
        cmd.append('method=generate_quantities')
        fitted_csv = self.sample_csv_files[idx]
        cmd.append(f'fitted_params={fitted_csv}')
        return cmd
class VariationalArgs:
    """Arguments needed for variational method."""

    VARIATIONAL_ALGOS = {'meanfield', 'fullrank'}

    def __init__(
        self,
        algorithm: Optional[str] = None,
        iter: Optional[int] = None,
        grad_samples: Optional[int] = None,
        elbo_samples: Optional[int] = None,
        eta: Optional[float] = None,
        adapt_iter: Optional[int] = None,
        adapt_engaged: bool = True,
        tol_rel_obj: Optional[float] = None,
        eval_elbo: Optional[int] = None,
        output_samples: Optional[int] = None,
    ) -> None:
        self.algorithm = algorithm
        self.iter = iter
        self.grad_samples = grad_samples
        self.elbo_samples = elbo_samples
        self.eta = eta
        self.adapt_iter = adapt_iter
        self.adapt_engaged = adapt_engaged
        self.tol_rel_obj = tol_rel_obj
        self.eval_elbo = eval_elbo
        self.output_samples = output_samples

    def validate(
        self, chains: Optional[int] = None  # pylint: disable=unused-argument
    ) -> None:
        """
        Check arguments correctness and consistency.

        :raises ValueError: on an unknown algorithm, on ``adapt_iter``
            given while adaptation is disabled (compose() would otherwise
            silently drop it), or on a non-positive numeric argument.
        """
        if (
            self.algorithm is not None
            and self.algorithm not in self.VARIATIONAL_ALGOS
        ):
            # sorted() keeps the message deterministic across runs.
            raise ValueError(
                'Please specify variational algorithms as one of [{}]'.format(
                    ', '.join(sorted(self.VARIATIONAL_ALGOS))
                )
            )
        if not self.adapt_engaged and self.adapt_iter is not None:
            raise ValueError(
                'Conflicting arguments: adapt_engaged: False, '
                'adapt_iter: {}'.format(self.adapt_iter)
            )
        positive_int(self.iter, 'iter')
        positive_int(self.grad_samples, 'grad_samples')
        positive_int(self.elbo_samples, 'elbo_samples')
        positive_float(self.eta, 'eta')
        positive_int(self.adapt_iter, 'adapt_iter')
        positive_float(self.tol_rel_obj, 'tol_rel_obj')
        positive_int(self.eval_elbo, 'eval_elbo')
        positive_int(self.output_samples, 'output_samples')

    # pylint: disable=unused-argument
    def compose(self, idx: int, cmd: List[str]) -> List[str]:
        """
        Compose CmdStan command for method-specific non-default arguments.

        ``adapt iter=`` is only emitted while adaptation is engaged.
        """
        cmd.append('method=variational')
        if self.algorithm is not None:
            cmd.append(f'algorithm={self.algorithm}')
        if self.iter is not None:
            cmd.append(f'iter={self.iter}')
        if self.grad_samples is not None:
            cmd.append(f'grad_samples={self.grad_samples}')
        if self.elbo_samples is not None:
            cmd.append(f'elbo_samples={self.elbo_samples}')
        if self.eta is not None:
            cmd.append(f'eta={self.eta}')
        cmd.append('adapt')
        if self.adapt_engaged:
            cmd.append('engaged=1')
            if self.adapt_iter is not None:
                cmd.append(f'iter={self.adapt_iter}')
        else:
            cmd.append('engaged=0')
        if self.tol_rel_obj is not None:
            cmd.append(f'tol_rel_obj={self.tol_rel_obj}')
        if self.eval_elbo is not None:
            cmd.append(f'eval_elbo={self.eval_elbo}')
        if self.output_samples is not None:
            cmd.append(f'output_samples={self.output_samples}')
        return cmd
class CmdStanArgs:
    """
    Container for CmdStan command line arguments.
    Consists of arguments common to all methods and
    an object which contains the method-specific arguments.
    """

    def __init__(
        self,
        model_name: str,
        model_exe: OptionalPath,
        chain_ids: Optional[List[int]],
        method_args: Union[
            SamplerArgs,
            OptimizeArgs,
            GenerateQuantitiesArgs,
            VariationalArgs,
            LaplaceArgs,
            PathfinderArgs,
        ],
        data: Union[Mapping[str, Any], str, None] = None,
        seed: Union[int, List[int], None] = None,
        inits: Union[int, float, str, List[str], None] = None,
        output_dir: OptionalPath = None,
        sig_figs: Optional[int] = None,
        save_latent_dynamics: bool = False,
        save_profile: bool = False,
        refresh: Optional[int] = None,
    ) -> None:
        """Initialize object and validate both common and method args.

        :raises ValueError: on an unsupported ``method_args`` type or any
            invalid argument.
        """
        self.model_name = model_name
        self.model_exe = model_exe
        self.chain_ids = chain_ids
        self.data = data
        self.seed = seed
        self.inits = inits
        self.output_dir = output_dir
        self.sig_figs = sig_figs
        self.save_latent_dynamics = save_latent_dynamics
        self.save_profile = save_profile
        self.refresh = refresh
        self.method_args = method_args
        if isinstance(method_args, SamplerArgs):
            self.method = Method.SAMPLE
        elif isinstance(method_args, OptimizeArgs):
            self.method = Method.OPTIMIZE
        elif isinstance(method_args, GenerateQuantitiesArgs):
            self.method = Method.GENERATE_QUANTITIES
        elif isinstance(method_args, VariationalArgs):
            self.method = Method.VARIATIONAL
        elif isinstance(method_args, LaplaceArgs):
            self.method = Method.LAPLACE
        elif isinstance(method_args, PathfinderArgs):
            self.method = Method.PATHFINDER
        else:
            raise ValueError(
                'Unsupported method args type: {}'.format(type(method_args))
            )
        self.method_args.validate(len(chain_ids) if chain_ids else None)
        self.validate()

    def validate(self) -> None:
        """
        Check arguments correctness and consistency.

        * input files must exist
        * output files must be in a writeable directory
        * if no seed specified, set random seed.
        * length of per-chain lists equals specified # of chains

        Side effects: may normalize/create ``output_dir``, may reset
        ``sig_figs`` for old CmdStan, and fills in ``seed`` when missing.

        :raises ValueError: on any invalid argument.
        """
        if self.model_name is None:
            raise ValueError('no stan model specified')
        if self.model_exe is None:
            raise ValueError('model not compiled')

        if self.chain_ids is not None:
            for chain_id in self.chain_ids:
                if chain_id < 1:
                    raise ValueError('invalid chain_id {}'.format(chain_id))

        self._validate_output_dir()

        if self.refresh is not None:
            if (
                not isinstance(self.refresh, (int, np.integer))
                or self.refresh < 1
            ):
                raise ValueError(
                    'Argument "refresh" must be a positive integer value, '
                    'found {}.'.format(self.refresh)
                )

        if self.sig_figs is not None:
            if (
                not isinstance(self.sig_figs, (int, np.integer))
                or self.sig_figs < 1
                or self.sig_figs > 18
            ):
                raise ValueError(
                    'Argument "sig_figs" must be an integer between 1 and 18,'
                    ' found {}'.format(self.sig_figs)
                )
            # TODO: remove at some future release
            if cmdstan_version_before(2, 25):
                self.sig_figs = None
                get_logger().warning(
                    'Argument "sig_figs" invalid for CmdStan versions < 2.25, '
                    'using version %s in directory %s',
                    os.path.basename(cmdstan_path()),
                    os.path.dirname(cmdstan_path()),
                )

        self._validate_seed()

        if isinstance(self.data, str):
            if not os.path.exists(self.data):
                raise ValueError('no such file {}'.format(self.data))
        elif self.data is not None and not isinstance(self.data, (str, dict)):
            raise ValueError('Argument "data" must be string or dict')

        self._validate_inits()

    def _validate_output_dir(self) -> None:
        """Normalize ``output_dir``, create it if missing, check writable."""
        if self.output_dir is None:
            return
        self.output_dir = os.path.realpath(os.path.expanduser(self.output_dir))
        if not os.path.exists(self.output_dir):
            try:
                # exist_ok guards against another process creating the
                # directory between the exists() check and this call.
                os.makedirs(self.output_dir, exist_ok=True)
                get_logger().info(
                    'created output directory: %s', self.output_dir
                )
            except (RuntimeError, OSError) as exc:
                raise ValueError(
                    'Invalid path for output files, '
                    'no such dir: {}.'.format(self.output_dir)
                ) from exc
        if not os.path.isdir(self.output_dir):
            raise ValueError(
                'Specified output_dir is not a directory: {}.'.format(
                    self.output_dir
                )
            )
        try:
            testpath = os.path.join(self.output_dir, str(time()))
            with open(testpath, 'w+'):
                pass
            os.remove(testpath)  # cleanup
        except OSError as exc:
            raise ValueError(
                'Invalid path for output files,'
                ' cannot write to dir: {}.'.format(self.output_dir)
            ) from exc

    def _validate_seed(self) -> None:
        """Check ``seed`` (scalar or per-chain list); pick one if missing."""
        if self.seed is None:
            rng = default_rng()
            self.seed = rng.integers(low=1, high=99999, size=1).item()
            return
        if not isinstance(self.seed, (int, list, np.integer)):
            raise ValueError(
                'Argument "seed" must be an integer between '
                '0 and 2**32-1, found {}.'.format(self.seed)
            )
        if isinstance(self.seed, (int, np.integer)):
            if self.seed < 0 or self.seed > 2**32 - 1:
                raise ValueError(
                    'Argument "seed" must be an integer between '
                    '0 and 2**32-1, found {}.'.format(self.seed)
                )
            return
        if self.chain_ids is None:
            raise ValueError(
                'List of per-chain seeds cannot be evaluated without '
                'corresponding list of chain_ids.'
            )
        if len(self.seed) != len(self.chain_ids):
            raise ValueError(
                'Number of seeds must match number of chains,'
                ' found {} seed for {} chains.'.format(
                    len(self.seed), len(self.chain_ids)
                )
            )
        for seed in self.seed:
            # Type-check each element: a float would otherwise pass and a
            # string would raise TypeError from the comparison.
            if (
                not isinstance(seed, (int, np.integer))
                or seed < 0
                or seed > 2**32 - 1
            ):
                raise ValueError(
                    'Argument "seed" must be an integer value'
                    ' between 0 and 2**32-1,'
                    ' found {}'.format(seed)
                )

    def _validate_inits(self) -> None:
        """Check ``inits``: a number, a file path, or per-chain paths."""
        if self.inits is None:
            return
        if isinstance(self.inits, (float, int, np.floating, np.integer)):
            if self.inits < 0:
                raise ValueError(
                    'Argument "inits" must be > 0, found {}'.format(self.inits)
                )
        elif isinstance(self.inits, str):
            # NOTE(review): existence is not checked for multi-chain
            # sampling or pathfinder; presumably CmdStan resolves per-chain
            # init files from this name itself in those modes — confirm.
            if not (
                (
                    isinstance(self.method_args, SamplerArgs)
                    and self.method_args.num_chains > 1
                )
                or isinstance(self.method_args, PathfinderArgs)
            ):
                if not os.path.exists(self.inits):
                    raise ValueError('no such file {}'.format(self.inits))
        elif isinstance(self.inits, list):
            if self.chain_ids is None:
                raise ValueError(
                    'List of inits files cannot be evaluated without '
                    'corresponding list of chain_ids.'
                )
            if len(self.inits) != len(self.chain_ids):
                raise ValueError(
                    'Number of inits files must match number of chains,'
                    ' found {} inits files for {} chains.'.format(
                        len(self.inits), len(self.chain_ids)
                    )
                )
            for inits in self.inits:
                if not os.path.exists(inits):
                    raise ValueError('no such file {}'.format(inits))

    def compose_command(
        self,
        idx: int,
        csv_file: str,
        *,
        diagnostic_file: Optional[str] = None,
        profile_file: Optional[str] = None,
    ) -> List[str]:
        """
        Compose CmdStan command for non-default arguments.

        :param idx: zero-based chain index selecting per-chain values.
        :param csv_file: output CSV path.
        :param diagnostic_file: optional diagnostic output path.
        :param profile_file: optional profiling output path.
        :raises ValueError: if ``idx`` is outside ``chain_ids``.
        """
        cmd: List[str] = []
        if idx is not None and self.chain_ids is not None:
            if idx < 0 or idx > len(self.chain_ids) - 1:
                raise ValueError(
                    'index ({}) exceeds number of chains ({})'.format(
                        idx, len(self.chain_ids)
                    )
                )
            cmd.append(self.model_exe)  # type: ignore # guaranteed by validate
            cmd.append(f'id={self.chain_ids[idx]}')
        else:
            cmd.append(self.model_exe)  # type: ignore # guaranteed by validate

        if self.seed is not None:
            seed = self.seed[idx] if isinstance(self.seed, list) else self.seed
            cmd.append('random')
            cmd.append(f'seed={seed}')
        if self.data is not None:
            cmd.append('data')
            cmd.append(f'file={self.data}')
        if self.inits is not None:
            if isinstance(self.inits, list):
                cmd.append(f'init={self.inits[idx]}')
            else:
                cmd.append(f'init={self.inits}')
        cmd.append('output')
        cmd.append(f'file={csv_file}')
        if diagnostic_file:
            cmd.append(f'diagnostic_file={diagnostic_file}')
        if profile_file:
            cmd.append(f'profile_file={profile_file}')
        if self.refresh is not None:
            cmd.append(f'refresh={self.refresh}')
        if self.sig_figs is not None:
            cmd.append(f'sig_figs={self.sig_figs}')
        cmd = self.method_args.compose(idx, cmd)
        return cmd