From 5ab435f5c7e2df6fba904ab567af29a94c591881 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Thu, 7 Sep 2023 14:56:31 -0400 Subject: [PATCH] scalar -> atleast_2d (#138) Signed-off-by: nstarman --- src/stream_ml/pytorch/prior/_track.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/stream_ml/pytorch/prior/_track.py b/src/stream_ml/pytorch/prior/_track.py index 5751080..45e6d99 100644 --- a/src/stream_ml/pytorch/prior/_track.py +++ b/src/stream_ml/pytorch/prior/_track.py @@ -26,7 +26,9 @@ def _atleast_2d(x: Array) -> Array: """Ensure that x is at least 2d.""" - if x.ndim == 1: + if x.ndim == 0: + return x[None, None] + elif x.ndim == 1: return x[:, None] return x