diff --git a/sharktank/tests/types/dataset_test.py b/sharktank/tests/types/dataset_test.py index 82c9723f0..99d176c5b 100644 --- a/sharktank/tests/types/dataset_test.py +++ b/sharktank/tests/types/dataset_test.py @@ -77,6 +77,22 @@ def testTransform(self): self.assertIsNot(pt1, pt2) torch.testing.assert_close(pt1, pt2) + def testPop(self): + t1 = Theta( + _flat_t_dict( + _t("a.b.c", 1, 2), + _t("a.c.d", 10, 11), + _t("a.b.3", 3, 4), + ) + ) + popped = t1.pop("a.b").flatten() + t1 = t1.flatten() + + self.assertIsNotNone("a.c.d", t1.keys()) + self.assertNotIn("a.b.c", t1.keys()) + self.assertNotIn("a.b.3", t1.keys()) + self.assertIn("a.b.3", popped.keys()) + class DatasetTest(unittest.TestCase): def setUp(self):