From 20c57ccdd2351474eae2194587b5a0e1b0a94f5f Mon Sep 17 00:00:00 2001 From: Brian Ward Date: Fri, 24 Mar 2023 13:28:54 -0400 Subject: [PATCH] Accept np.floating and np.integer for arguments --- cmdstanpy/cmdstan_args.py | 83 +++++++++++++++++++++++---------------- test/test_cmdstan_args.py | 18 +++++++++ 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/cmdstanpy/cmdstan_args.py b/cmdstanpy/cmdstan_args.py index 2fc75261..49b96a20 100644 --- a/cmdstanpy/cmdstan_args.py +++ b/cmdstanpy/cmdstan_args.py @@ -81,7 +81,7 @@ def validate(self, chains: Optional[int]) -> None: * if file(s) for metric are supplied, check contents. * length of per-chain lists equals specified # of chains """ - if not isinstance(chains, int) or chains < 1: + if not isinstance(chains, (int, np.integer)) or chains < 1: raise ValueError( 'Sampler expects number of chains to be greater than 0.' ) @@ -110,7 +110,9 @@ def validate(self, chains: Optional[int]) -> None: raise ValueError(msg) if self.iter_warmup is not None: - if self.iter_warmup < 0 or not isinstance(self.iter_warmup, int): + if self.iter_warmup < 0 or not isinstance( + self.iter_warmup, (int, np.integer) + ): raise ValueError( 'Value for iter_warmup must be a non-negative integer,' ' found {}.'.format(self.iter_warmup) @@ -122,28 +124,30 @@ def validate(self, chains: Optional[int]) -> None: ) if self.iter_sampling is not None: if self.iter_sampling < 0 or not isinstance( - self.iter_sampling, int + self.iter_sampling, (int, np.integer) ): raise ValueError( 'Argument "iter_sampling" must be a non-negative integer,' ' found {}.'.format(self.iter_sampling) ) if self.thin is not None: - if self.thin < 1 or not isinstance(self.thin, int): + if self.thin < 1 or not isinstance(self.thin, (int, np.integer)): raise ValueError( 'Argument "thin" must be a positive integer,' 'found {}.'.format(self.thin) ) if self.max_treedepth is not None: if self.max_treedepth < 1 or not isinstance( - self.max_treedepth, int + self.max_treedepth, (int, np.integer) ): raise ValueError( 'Argument "max_treedepth" must be a positive integer,' ' found {}.'.format(self.max_treedepth) ) if self.step_size is not None: - if isinstance(self.step_size, (float, int)): + if isinstance( + self.step_size, (float, int, np.integer, np.floating) + ): if self.step_size <= 0: raise ValueError( 'Argument "step_size" must be > 0, ' @@ -178,7 +182,7 @@ def validate(self, chains: Optional[int]) -> None: else: self.metric_type = 'dense_e' self.metric_file = self.metric - elif isinstance(self.metric, Dict): + elif isinstance(self.metric, dict): if 'inv_metric' not in self.metric: raise ValueError( 'Entry "inv_metric" not found in metric dict.' @@ -289,7 +293,7 @@ def validate(self, chains: Optional[int]) -> None: ) if self.adapt_init_phase is not None: if self.adapt_init_phase < 0 or not isinstance( - self.adapt_init_phase, int + self.adapt_init_phase, (int, np.integer) ): raise ValueError( 'Argument "adapt_init_phase" must be a non-negative ' @@ -297,7 +301,7 @@ def validate(self, chains: Optional[int]) -> None: ) if self.adapt_metric_window is not None: if self.adapt_metric_window < 0 or not isinstance( - self.adapt_metric_window, int + self.adapt_metric_window, (int, np.integer) ): raise ValueError( 'Argument "adapt_metric_window" must be a non-negative ' @@ -305,7 +309,7 @@ def validate(self, chains: Optional[int]) -> None: ) if self.adapt_step_size is not None: if self.adapt_step_size < 0 or not isinstance( - self.adapt_step_size, int + self.adapt_step_size, (int, np.integer) ): raise ValueError( 'Argument "adapt_step_size" must be a non-negative integer,' @@ -426,14 +430,14 @@ def validate( raise ValueError( 'init_alpha requires that algorithm be set to bfgs or lbfgs' ) - if isinstance(self.init_alpha, float): + if isinstance(self.init_alpha, (float, np.floating)): if self.init_alpha <= 0: raise ValueError('init_alpha must be greater than 0') else: raise ValueError('init_alpha must be type of float') if self.iter is not None: - if isinstance(self.iter, int): + if isinstance(self.iter, (int, np.integer)): if self.iter < 0: raise ValueError('iter must be greater than 0') else: @@ -444,7 +448,7 @@ def validate( raise ValueError( 'tol_obj requires that algorithm be set to bfgs or lbfgs' ) - if isinstance(self.tol_obj, float): + if isinstance(self.tol_obj, (float, np.floating)): if self.tol_obj <= 0: raise ValueError('tol_obj must be greater than 0') else: @@ -456,7 +460,7 @@ def validate( 'tol_rel_obj requires that algorithm be set to bfgs' ' or lbfgs' ) - if isinstance(self.tol_rel_obj, float): + if isinstance(self.tol_rel_obj, (float, np.floating)): if self.tol_rel_obj <= 0: raise ValueError('tol_rel_obj must be greater than 0') else: @@ -467,7 +471,7 @@ def validate( raise ValueError( 'tol_grad requires that algorithm be set to bfgs or lbfgs' ) - if isinstance(self.tol_grad, float): + if isinstance(self.tol_grad, (float, np.floating)): if self.tol_grad <= 0: raise ValueError('tol_grad must be greater than 0') else: @@ -479,7 +483,7 @@ def validate( 'tol_rel_grad requires that algorithm be set to bfgs' ' or lbfgs' ) - if isinstance(self.tol_rel_grad, float): + if isinstance(self.tol_rel_grad, (float, np.floating)): if self.tol_rel_grad <= 0: raise ValueError('tol_rel_grad must be greater than 0') else: @@ -490,7 +494,7 @@ def validate( raise ValueError( 'tol_param requires that algorithm be set to bfgs or lbfgs' ) - if isinstance(self.tol_param, float): + if isinstance(self.tol_param, (float, np.floating)): if self.tol_param <= 0: raise ValueError('tol_param must be greater than 0') else: @@ -501,7 +505,7 @@ def validate( raise ValueError( 'history_size requires that algorithm be set to lbfgs' ) - if isinstance(self.history_size, int): + if isinstance(self.history_size, (int, np.integer)): if self.history_size < 0: raise ValueError('history_size must be greater than 0') else: @@ -610,52 +614,62 @@ def validate( ) ) if self.iter is not None: - if self.iter < 1 or not isinstance(self.iter, int): + if self.iter < 1 or not isinstance(self.iter, (int, np.integer)): raise ValueError( 'iter must be a positive integer,' ' found {}'.format(self.iter) ) if self.grad_samples is not None: - if self.grad_samples < 1 or not isinstance(self.grad_samples, int): + if self.grad_samples < 1 or not isinstance( + self.grad_samples, (int, np.integer) + ): raise ValueError( 'grad_samples must be a positive integer,' ' found {}'.format(self.grad_samples) ) if self.elbo_samples is not None: - if self.elbo_samples < 1 or not isinstance(self.elbo_samples, int): + if self.elbo_samples < 1 or not isinstance( + self.elbo_samples, (int, np.integer) + ): raise ValueError( 'elbo_samples must be a positive integer,' ' found {}'.format(self.elbo_samples) ) if self.eta is not None: - if self.eta < 0 or not isinstance(self.eta, (int, float)): + if self.eta < 0 or not isinstance( + self.eta, (int, float, np.integer, np.floating) + ): raise ValueError( 'eta must be a non-negative number,' ' found {}'.format(self.eta) ) if self.adapt_iter is not None: - if self.adapt_iter < 1 or not isinstance(self.adapt_iter, int): + if self.adapt_iter < 1 or not isinstance( + self.adapt_iter, (int, np.integer) + ): raise ValueError( 'adapt_iter must be a positive integer,' ' found {}'.format(self.adapt_iter) ) if self.tol_rel_obj is not None: if self.tol_rel_obj <= 0 or not isinstance( - self.tol_rel_obj, (int, float) + self.tol_rel_obj, (int, float, np.integer, np.floating) ): raise ValueError( 'tol_rel_obj must be a positive number,' ' found {}'.format(self.tol_rel_obj) ) if self.eval_elbo is not None: - if self.eval_elbo < 1 or not isinstance(self.eval_elbo, int): + if self.eval_elbo < 1 or not isinstance( + self.eval_elbo, (int, np.integer) + ): raise ValueError( 'eval_elbo must be a positive integer,' ' found {}'.format(self.eval_elbo) ) if self.output_samples is not None: if self.output_samples < 1 or not isinstance( - self.output_samples, int + self.output_samples, (int, np.integer) ): raise ValueError( 'output_samples must be a positive integer,' @@ -792,7 +806,10 @@ def validate(self) -> None: ' cannot write to dir: {}.'.format(self.output_dir) ) from exc if self.refresh is not None: - if not isinstance(self.refresh, int) or self.refresh < 1: + if ( + not isinstance(self.refresh, (int, np.integer)) + or self.refresh < 1 + ): raise ValueError( 'Argument "refresh" must be a positive integer value, ' 'found {}.'.format(self.refresh) @@ -800,7 +817,7 @@ def validate(self) -> None: if self.sig_figs is not None: if ( - not isinstance(self.sig_figs, int) + not isinstance(self.sig_figs, (int, np.integer)) or self.sig_figs < 1 or self.sig_figs > 18 ): @@ -822,13 +839,13 @@ def validate(self) -> None: rng = RandomState() self.seed = rng.randint(1, 99999 + 1) else: - if not isinstance(self.seed, (int, list)): + if not isinstance(self.seed, (int, list, np.integer)): raise ValueError( 'Argument "seed" must be an integer between ' '0 and 2**32-1, found {}.'.format(self.seed) ) - if isinstance(self.seed, int): - if self.seed < 0 or self.seed > 2 ** 32 - 1: + if isinstance(self.seed, (int, np.integer)): + if self.seed < 0 or self.seed > 2**32 - 1: raise ValueError( 'Argument "seed" must be an integer between ' '0 and 2**32-1, found {}.'.format(self.seed) @@ -847,7 +864,7 @@ def validate(self) -> None: ) ) for seed in self.seed: - if seed < 0 or seed > 2 ** 32 - 1: + if seed < 0 or seed > 2**32 - 1: raise ValueError( 'Argument "seed" must be an integer value' ' between 0 and 2**32-1,' @@ -861,7 +878,7 @@ def validate(self) -> None: raise ValueError('Argument "data" must be string or dict') if self.inits is not None: - if isinstance(self.inits, (float, int)): + if isinstance(self.inits, (float, int, np.floating, np.integer)): if self.inits < 0: raise ValueError( 'Argument "inits" must be > 0, found {}'.format( diff --git a/test/test_cmdstan_args.py b/test/test_cmdstan_args.py index 8c6364d5..049c5928 100644 --- a/test/test_cmdstan_args.py +++ b/test/test_cmdstan_args.py @@ -6,6 +6,7 @@ from test import check_present from time import time +import numpy as np import pytest from cmdstanpy import _TMPDIR, cmdstan_path @@ -349,6 +350,23 @@ def test_args_good() -> None: cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv') assert 'id=7 random seed=' in ' '.join(cmd) + # integer type + rng = np.random.default_rng(42) + seed = rng.integers(low=0, high=int(1e7)) + assert not isinstance(seed, int) + assert isinstance(seed, np.integer) + + cmdstan_args = CmdStanArgs( + model_name='bernoulli', + model_exe=exe, + chain_ids=[7, 11, 18, 29], + data=jdata, + seed=seed, + method_args=sampler_args, + ) + cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv') + assert f'id=7 random seed={seed}' in ' '.join(cmd) + dirname = 'tmp' + str(time()) if os.path.exists(dirname): os.rmdir(dirname)