diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 786b9e2896c2c..3b8e8165b4bfd 100755 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -1219,6 +1219,11 @@ def test_parse_col_name(self): self.assert_eq(parse_attr_name("`a"), None) self.assert_eq(parse_attr_name("a`"), None) + self.assert_eq(parse_attr_name("`a`.b"), ["a", "b"]) + self.assert_eq(parse_attr_name("`a`.`b`"), ["a", "b"]) + self.assert_eq(parse_attr_name("`a```.b"), ["a`", "b"]) + self.assert_eq(parse_attr_name("`a``.b"), None) + self.assert_eq(parse_attr_name("a.b.c"), ["a", "b", "c"]) self.assert_eq(parse_attr_name("`a`.`b`.`c`"), ["a", "b", "c"]) self.assert_eq(parse_attr_name("a.`b`.c"), ["a", "b", "c"]) @@ -1284,7 +1289,6 @@ def test_verify_col_name(self): self.assertTrue(verify_col_name("m.`s`.id", cdf.schema)) self.assertTrue(verify_col_name("`m`.`s`.`id`", cdf.schema)) self.assertFalse(verify_col_name("m.`s.id`", cdf.schema)) - self.assertFalse(verify_col_name("m.`s.id`", cdf.schema)) self.assertTrue(verify_col_name("a", cdf.schema)) self.assertTrue(verify_col_name("`a`", cdf.schema)) @@ -1294,6 +1298,47 @@ def test_verify_col_name(self): self.assertTrue(verify_col_name("`a`.`v`", cdf.schema)) self.assertFalse(verify_col_name("`a`.`x`", cdf.schema)) + cdf = ( + self.connect.range(10) + .withColumn("v", CF.lit(123)) + .withColumn("s.s", CF.struct("id", "v")) + .withColumn("m`", CF.struct("`s.s`", "v")) + ) + + # root + # |-- id: long (nullable = false) + # |-- v: integer (nullable = false) + # |-- s.s: struct (nullable = false) + # | |-- id: long (nullable = false) + # | |-- v: integer (nullable = false) + # |-- m`: struct (nullable = false) + # | |-- s.s: struct (nullable = false) + # | | |-- id: long (nullable = false) + # | | |-- v: integer (nullable = false) + # | |-- v: integer (nullable = false) + + self.assertFalse(verify_col_name("s", 
cdf.schema)) + self.assertFalse(verify_col_name("`s`", cdf.schema)) + self.assertFalse(verify_col_name("s.s", cdf.schema)) + self.assertFalse(verify_col_name("s.`s`", cdf.schema)) + self.assertFalse(verify_col_name("`s`.s", cdf.schema)) + self.assertTrue(verify_col_name("`s.s`", cdf.schema)) + + self.assertFalse(verify_col_name("m", cdf.schema)) + self.assertFalse(verify_col_name("`m`", cdf.schema)) + self.assertTrue(verify_col_name("`m```", cdf.schema)) + + self.assertFalse(verify_col_name("`m```.s", cdf.schema)) + self.assertFalse(verify_col_name("`m```.`s`", cdf.schema)) + self.assertFalse(verify_col_name("`m```.s.s", cdf.schema)) + self.assertFalse(verify_col_name("`m```.s.`s`", cdf.schema)) + self.assertTrue(verify_col_name("`m```.`s.s`", cdf.schema)) + + self.assertFalse(verify_col_name("`m```.s.s.v", cdf.schema)) + self.assertFalse(verify_col_name("`m```.s.`s`.v", cdf.schema)) + self.assertTrue(verify_col_name("`m```.`s.s`.v", cdf.schema)) + self.assertTrue(verify_col_name("`m```.`s.s`.`v`", cdf.schema)) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401