Skip to content

Commit

Permalink
Merge pull request #663 from stan-dev/fix/647-accept-np-typed-args
Browse files Browse the repository at this point in the history
Accept np.floating and np.integer for arguments
  • Loading branch information
WardBrian authored Mar 24, 2023
2 parents 10c2b46 + 20c57cc commit cd65084
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 33 deletions.
83 changes: 50 additions & 33 deletions cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def validate(self, chains: Optional[int]) -> None:
* if file(s) for metric are supplied, check contents.
* length of per-chain lists equals specified # of chains
"""
if not isinstance(chains, int) or chains < 1:
if not isinstance(chains, (int, np.integer)) or chains < 1:
raise ValueError(
'Sampler expects number of chains to be greater than 0.'
)
Expand Down Expand Up @@ -110,7 +110,9 @@ def validate(self, chains: Optional[int]) -> None:
raise ValueError(msg)

if self.iter_warmup is not None:
if self.iter_warmup < 0 or not isinstance(self.iter_warmup, int):
if self.iter_warmup < 0 or not isinstance(
self.iter_warmup, (int, np.integer)
):
raise ValueError(
'Value for iter_warmup must be a non-negative integer,'
' found {}.'.format(self.iter_warmup)
Expand All @@ -122,28 +124,30 @@ def validate(self, chains: Optional[int]) -> None:
)
if self.iter_sampling is not None:
if self.iter_sampling < 0 or not isinstance(
self.iter_sampling, int
self.iter_sampling, (int, np.integer)
):
raise ValueError(
'Argument "iter_sampling" must be a non-negative integer,'
' found {}.'.format(self.iter_sampling)
)
if self.thin is not None:
if self.thin < 1 or not isinstance(self.thin, int):
if self.thin < 1 or not isinstance(self.thin, (int, np.integer)):
raise ValueError(
'Argument "thin" must be a positive integer,'
'found {}.'.format(self.thin)
)
if self.max_treedepth is not None:
if self.max_treedepth < 1 or not isinstance(
self.max_treedepth, int
self.max_treedepth, (int, np.integer)
):
raise ValueError(
'Argument "max_treedepth" must be a positive integer,'
' found {}.'.format(self.max_treedepth)
)
if self.step_size is not None:
if isinstance(self.step_size, (float, int)):
if isinstance(
self.step_size, (float, int, np.integer, np.floating)
):
if self.step_size <= 0:
raise ValueError(
'Argument "step_size" must be > 0, '
Expand Down Expand Up @@ -178,7 +182,7 @@ def validate(self, chains: Optional[int]) -> None:
else:
self.metric_type = 'dense_e'
self.metric_file = self.metric
elif isinstance(self.metric, Dict):
elif isinstance(self.metric, dict):
if 'inv_metric' not in self.metric:
raise ValueError(
'Entry "inv_metric" not found in metric dict.'
Expand Down Expand Up @@ -289,23 +293,23 @@ def validate(self, chains: Optional[int]) -> None:
)
if self.adapt_init_phase is not None:
if self.adapt_init_phase < 0 or not isinstance(
self.adapt_init_phase, int
self.adapt_init_phase, (int, np.integer)
):
raise ValueError(
'Argument "adapt_init_phase" must be a non-negative '
'integer, found {}'.format(self.adapt_init_phase)
)
if self.adapt_metric_window is not None:
if self.adapt_metric_window < 0 or not isinstance(
self.adapt_metric_window, int
self.adapt_metric_window, (int, np.integer)
):
raise ValueError(
'Argument "adapt_metric_window" must be a non-negative '
' integer, found {}'.format(self.adapt_metric_window)
)
if self.adapt_step_size is not None:
if self.adapt_step_size < 0 or not isinstance(
self.adapt_step_size, int
self.adapt_step_size, (int, np.integer)
):
raise ValueError(
'Argument "adapt_step_size" must be a non-negative integer,'
Expand Down Expand Up @@ -426,14 +430,14 @@ def validate(
raise ValueError(
'init_alpha requires that algorithm be set to bfgs or lbfgs'
)
if isinstance(self.init_alpha, float):
if isinstance(self.init_alpha, (float, np.floating)):
if self.init_alpha <= 0:
raise ValueError('init_alpha must be greater than 0')
else:
raise ValueError('init_alpha must be type of float')

if self.iter is not None:
if isinstance(self.iter, int):
if isinstance(self.iter, (int, np.integer)):
if self.iter < 0:
raise ValueError('iter must be greater than 0')
else:
Expand All @@ -444,7 +448,7 @@ def validate(
raise ValueError(
'tol_obj requires that algorithm be set to bfgs or lbfgs'
)
if isinstance(self.tol_obj, float):
if isinstance(self.tol_obj, (float, np.floating)):
if self.tol_obj <= 0:
raise ValueError('tol_obj must be greater than 0')
else:
Expand All @@ -456,7 +460,7 @@ def validate(
'tol_rel_obj requires that algorithm be set to bfgs'
' or lbfgs'
)
if isinstance(self.tol_rel_obj, float):
if isinstance(self.tol_rel_obj, (float, np.floating)):
if self.tol_rel_obj <= 0:
raise ValueError('tol_rel_obj must be greater than 0')
else:
Expand All @@ -467,7 +471,7 @@ def validate(
raise ValueError(
'tol_grad requires that algorithm be set to bfgs or lbfgs'
)
if isinstance(self.tol_grad, float):
if isinstance(self.tol_grad, (float, np.floating)):
if self.tol_grad <= 0:
raise ValueError('tol_grad must be greater than 0')
else:
Expand All @@ -479,7 +483,7 @@ def validate(
'tol_rel_grad requires that algorithm be set to bfgs'
' or lbfgs'
)
if isinstance(self.tol_rel_grad, float):
if isinstance(self.tol_rel_grad, (float, np.floating)):
if self.tol_rel_grad <= 0:
raise ValueError('tol_rel_grad must be greater than 0')
else:
Expand All @@ -490,7 +494,7 @@ def validate(
raise ValueError(
'tol_param requires that algorithm be set to bfgs or lbfgs'
)
if isinstance(self.tol_param, float):
if isinstance(self.tol_param, (float, np.floating)):
if self.tol_param <= 0:
raise ValueError('tol_param must be greater than 0')
else:
Expand All @@ -501,7 +505,7 @@ def validate(
raise ValueError(
'history_size requires that algorithm be set to lbfgs'
)
if isinstance(self.history_size, int):
if isinstance(self.history_size, (int, np.integer)):
if self.history_size < 0:
raise ValueError('history_size must be greater than 0')
else:
Expand Down Expand Up @@ -610,52 +614,62 @@ def validate(
)
)
if self.iter is not None:
if self.iter < 1 or not isinstance(self.iter, int):
if self.iter < 1 or not isinstance(self.iter, (int, np.integer)):
raise ValueError(
'iter must be a positive integer,'
' found {}'.format(self.iter)
)
if self.grad_samples is not None:
if self.grad_samples < 1 or not isinstance(self.grad_samples, int):
if self.grad_samples < 1 or not isinstance(
self.grad_samples, (int, np.integer)
):
raise ValueError(
'grad_samples must be a positive integer,'
' found {}'.format(self.grad_samples)
)
if self.elbo_samples is not None:
if self.elbo_samples < 1 or not isinstance(self.elbo_samples, int):
if self.elbo_samples < 1 or not isinstance(
self.elbo_samples, (int, np.integer)
):
raise ValueError(
'elbo_samples must be a positive integer,'
' found {}'.format(self.elbo_samples)
)
if self.eta is not None:
if self.eta < 0 or not isinstance(self.eta, (int, float)):
if self.eta < 0 or not isinstance(
self.eta, (int, float, np.integer, np.floating)
):
raise ValueError(
'eta must be a non-negative number,'
' found {}'.format(self.eta)
)
if self.adapt_iter is not None:
if self.adapt_iter < 1 or not isinstance(self.adapt_iter, int):
if self.adapt_iter < 1 or not isinstance(
self.adapt_iter, (int, np.integer)
):
raise ValueError(
'adapt_iter must be a positive integer,'
' found {}'.format(self.adapt_iter)
)
if self.tol_rel_obj is not None:
if self.tol_rel_obj <= 0 or not isinstance(
self.tol_rel_obj, (int, float)
self.tol_rel_obj, (int, float, np.integer, np.floating)
):
raise ValueError(
'tol_rel_obj must be a positive number,'
' found {}'.format(self.tol_rel_obj)
)
if self.eval_elbo is not None:
if self.eval_elbo < 1 or not isinstance(self.eval_elbo, int):
if self.eval_elbo < 1 or not isinstance(
self.eval_elbo, (int, np.integer)
):
raise ValueError(
'eval_elbo must be a positive integer,'
' found {}'.format(self.eval_elbo)
)
if self.output_samples is not None:
if self.output_samples < 1 or not isinstance(
self.output_samples, int
self.output_samples, (int, np.integer)
):
raise ValueError(
'output_samples must be a positive integer,'
Expand Down Expand Up @@ -792,15 +806,18 @@ def validate(self) -> None:
' cannot write to dir: {}.'.format(self.output_dir)
) from exc
if self.refresh is not None:
if not isinstance(self.refresh, int) or self.refresh < 1:
if (
not isinstance(self.refresh, (int, np.integer))
or self.refresh < 1
):
raise ValueError(
'Argument "refresh" must be a positive integer value, '
'found {}.'.format(self.refresh)
)

if self.sig_figs is not None:
if (
not isinstance(self.sig_figs, int)
not isinstance(self.sig_figs, (int, np.integer))
or self.sig_figs < 1
or self.sig_figs > 18
):
Expand All @@ -822,13 +839,13 @@ def validate(self) -> None:
rng = RandomState()
self.seed = rng.randint(1, 99999 + 1)
else:
if not isinstance(self.seed, (int, list)):
if not isinstance(self.seed, (int, list, np.integer)):
raise ValueError(
'Argument "seed" must be an integer between '
'0 and 2**32-1, found {}.'.format(self.seed)
)
if isinstance(self.seed, int):
if self.seed < 0 or self.seed > 2 ** 32 - 1:
if isinstance(self.seed, (int, np.integer)):
if self.seed < 0 or self.seed > 2**32 - 1:
raise ValueError(
'Argument "seed" must be an integer between '
'0 and 2**32-1, found {}.'.format(self.seed)
Expand All @@ -847,7 +864,7 @@ def validate(self) -> None:
)
)
for seed in self.seed:
if seed < 0 or seed > 2 ** 32 - 1:
if seed < 0 or seed > 2**32 - 1:
raise ValueError(
'Argument "seed" must be an integer value'
' between 0 and 2**32-1,'
Expand All @@ -861,7 +878,7 @@ def validate(self) -> None:
raise ValueError('Argument "data" must be string or dict')

if self.inits is not None:
if isinstance(self.inits, (float, int)):
if isinstance(self.inits, (float, int, np.floating, np.integer)):
if self.inits < 0:
raise ValueError(
'Argument "inits" must be > 0, found {}'.format(
Expand Down
18 changes: 18 additions & 0 deletions test/test_cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from test import check_present
from time import time

import numpy as np
import pytest

from cmdstanpy import _TMPDIR, cmdstan_path
Expand Down Expand Up @@ -349,6 +350,23 @@ def test_args_good() -> None:
cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
assert 'id=7 random seed=' in ' '.join(cmd)

# integer type
rng = np.random.default_rng(42)
seed = rng.integers(low=0, high=int(1e7))
assert not isinstance(seed, int)
assert isinstance(seed, np.integer)

cmdstan_args = CmdStanArgs(
model_name='bernoulli',
model_exe=exe,
chain_ids=[7, 11, 18, 29],
data=jdata,
seed=seed,
method_args=sampler_args,
)
cmd = cmdstan_args.compose_command(idx=0, csv_file='bern-output-1.csv')
assert f'id=7 random seed={seed}' in ' '.join(cmd)

dirname = 'tmp' + str(time())
if os.path.exists(dirname):
os.rmdir(dirname)
Expand Down

0 comments on commit cd65084

Please sign in to comment.