diff --git a/integrations/spark/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/statementtest/SetTableReplicationPolicyStatementTest.java b/integrations/spark/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/statementtest/SetTableReplicationPolicyStatementTest.java
new file mode 100644
index 00000000..02138f06
--- /dev/null
+++ b/integrations/spark/openhouse-spark-itest/src/test/java/com/linkedin/openhouse/spark/statementtest/SetTableReplicationPolicyStatementTest.java
@@ -0,0 +1,260 @@
+package com.linkedin.openhouse.spark.statementtest;
+
+import com.google.gson.Gson;
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.linkedin.openhouse.spark.sql.catalyst.parser.extensions.OpenhouseParseException;
+import java.nio.file.Files;
+import lombok.SneakyThrows;
+import org.apache.hadoop.fs.Path;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.execution.ExplainMode;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInstance;
+
+@TestInstance(TestInstance.Lifecycle.PER_CLASS)
+public class SetTableReplicationPolicyStatementTest {
+  private static SparkSession spark = null;
+
+  @SneakyThrows
+  @BeforeAll
+  public void setupSpark() {
+    Path unittest = new Path(Files.createTempDirectory("unittest_settablepolicy").toString());
+    spark =
+        SparkSession.builder()
+            .master("local[2]")
+            .config(
+                "spark.sql.extensions",
+                ("org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions,"
+                    + "com.linkedin.openhouse.spark.extensions.OpenhouseSparkSessionExtensions"))
+            .config("spark.sql.catalog.openhouse", "org.apache.iceberg.spark.SparkCatalog")
+            .config("spark.sql.catalog.openhouse.type", "hadoop")
+            .config("spark.sql.catalog.openhouse.warehouse", unittest.toString())
+            .getOrCreate();
+  }
+
+  @Test
+  public void testSimpleSetReplicationPolicy() {
+    String replicationConfigJson = "[{\"destination\":\"a\", \"interval\":\"24H\"}]";
+    Dataset<Row> ds =
+        spark.sql(
+            "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+                + "({destination:'a', interval:24H}))");
+    Assertions.assertTrue(isPlanValid(ds, replicationConfigJson));
+
+    // Test support with multiple clusters
+    replicationConfigJson =
+        "[{\"destination\":\"a\", \"interval\":\"12H\"}, {\"destination\":\"aa\", \"interval\":\"12H\"}]";
+    ds =
+        spark.sql(
+            "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+                + "({destination:'a', interval:12h}, {destination:'aa', interval:12H}))");
+    Assertions.assertTrue(isPlanValid(ds, replicationConfigJson));
+  }
+
+  @Test
+  public void testSimpleSetReplicationPolicyOptionalInterval() {
+    // Test with optional interval
+    String replicationConfigJson = "[{\"destination\":\"a\"}]";
+    Dataset<Row> ds =
+        spark.sql(
+            "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = " + "({destination:'a'}))");
+    Assertions.assertTrue(isPlanValid(ds, replicationConfigJson));
+
+    // Test with optional interval for multiple clusters
+    replicationConfigJson = "[{\"destination\":\"a\"}, {\"destination\":\"b\"}]";
+    ds =
+        spark.sql(
+            "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = "
+                + "({destination:'a'}, {destination:'b'}))");
+    Assertions.assertTrue(isPlanValid(ds, replicationConfigJson));
+  }
+
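+  /**
+   * Malformed REPLICATION clauses must fail at parse time; each statement below is expected to
+   * raise an OpenhouseParseException rather than reach the analyzer.
+   */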
+  @Test
+  public void testReplicationPolicyWithoutProperSyntax() {
+    // Empty destination value
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination:}))")
+                .show());
+
+    // Empty interval value
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination: 'aa', interval:}))")
+                .show());
+
+    // Missing destination value but interval present
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination:, interval: '12h'}))")
+                .show());
+
+    // Missing interval value but keyword present
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination: 'a', interval:}))")
+                .show());
+
+    // Missing destination values for multiple clusters
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination:, interval:'12H'}, {cluster:, interval: '12H'}))")
+                .show());
+
+    // Missing destination keyword for multiple clusters
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination:'a'}, {interval: '12h'}))")
+                .show());
+
+    // Missing destination keyword
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({interval: '12h'}))")
+                .show());
+
+    // Typo in keyword interval
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination: 'aa', interv: '12h'}))")
+                .show());
+
+    // Typo in keyword destination
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destina: 'aa', interval: '12h'}))")
+                .show());
+
+    // Unbalanced quotes in destination and interval values
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination: aa', interval: '12h}))")
+                .show());
+
+    // Typo in REPLICATION keyword
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICAT = ({destination: 'aa', interval: '12h'}))")
+                .show());
+
+    // Interval input does not follow 'h/H' format
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination: 'aa', interval: '12'}))")
+                .show());
+
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination: 'aa', interval: '1D'}))")
+                .show());
+
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () ->
+            spark
+                .sql(
+                    "ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({destination: 'aa', interval: '12d'}))")
+                .show());
+
+    // Missing destination and interval values
+    Assertions.assertThrows(
+        OpenhouseParseException.class,
+        () -> spark.sql("ALTER TABLE openhouse.db.table SET POLICY (REPLICATION = ({}))").show());
+  }
+
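+  /**
+   * The 'openhouse.tableId' property marks the test tables as Openhouse-managed;
+   * SetReplicationPolicyExec only commits policy updates on tables that carry it.
+   */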
+  @BeforeEach
+  public void setup() {
+    spark.sql("CREATE TABLE openhouse.db.table (id bigint, data string) USING iceberg").show();
+    spark.sql("CREATE TABLE openhouse.0_.0_ (id bigint, data string) USING iceberg").show();
+    spark
+        .sql("ALTER TABLE openhouse.db.table SET TBLPROPERTIES ('openhouse.tableId' = 'tableid')")
+        .show();
+    spark
+        .sql("ALTER TABLE openhouse.0_.0_ SET TBLPROPERTIES ('openhouse.tableId' = 'tableid')")
+        .show();
+  }
+
+  @AfterEach
+  public void tearDown() {
+    spark.sql("DROP TABLE openhouse.db.table").show();
+    spark.sql("DROP TABLE openhouse.0_.0_").show();
+  }
+
+  @AfterAll
+  public void tearDownSpark() {
+    spark.close();
+  }
+
+  /**
+   * Checks that every destination (and interval, when present) from the expected replication
+   * config JSON appears in the plan's explain output.
+   */
+  @SneakyThrows
+  private boolean isPlanValid(Dataset<Row> dataframe, String replicationConfigJson) {
+    String queryStr = dataframe.queryExecution().explainString(ExplainMode.fromString("simple"));
+    JsonArray jsonArray = new Gson().fromJson(replicationConfigJson, JsonArray.class);
+    for (JsonElement element : jsonArray) {
+      JsonObject entry = element.getAsJsonObject();
+      String destination = entry.get("destination").getAsString();
+      if (!queryStr.contains(destination)) {
+        return false;
+      }
+      if (entry.has("interval") && !queryStr.contains(entry.get("interval").getAsString())) {
+        return false;
+      }
+    }
+    return true;
+  }
+}
diff --git a/integrations/spark/openhouse-spark-runtime/src/main/antlr/com/linkedin/openhouse/spark/sql/catalyst/parser/extensions/OpenhouseSqlExtensions.g4 b/integrations/spark/openhouse-spark-runtime/src/main/antlr/com/linkedin/openhouse/spark/sql/catalyst/parser/extensions/OpenhouseSqlExtensions.g4
index bf762af6..67fa6585 100644
--- a/integrations/spark/openhouse-spark-runtime/src/main/antlr/com/linkedin/openhouse/spark/sql/catalyst/parser/extensions/OpenhouseSqlExtensions.g4
+++ b/integrations/spark/openhouse-spark-runtime/src/main/antlr/com/linkedin/openhouse/spark/sql/catalyst/parser/extensions/OpenhouseSqlExtensions.g4
@@ -24,6 +24,7 @@ singleStatement
 
 statement
     : ALTER TABLE multipartIdentifier SET POLICY '(' retentionPolicy (columnRetentionPolicy)? ')' #setRetentionPolicy
+    | ALTER TABLE multipartIdentifier SET POLICY '(' replicationPolicy ')' #setReplicationPolicy
     | ALTER TABLE multipartIdentifier SET POLICY '(' sharingPolicy ')' #setSharingPolicy
    | ALTER TABLE multipartIdentifier MODIFY columnNameClause SET columnPolicy #setColumnPolicyTag
     | GRANT privilege ON grantableResource TO principal #grantStatement
@@ -64,7 +65,7 @@ quotedIdentifier
     ;
 
 nonReserved
-    : ALTER | TABLE | SET | POLICY | RETENTION | SHARING
+    : ALTER | TABLE | SET | POLICY | RETENTION | SHARING | REPLICATION
     | GRANT | REVOKE | ON | TO | SHOW | GRANTS | PATTERN | WHERE | COLUMN
     ;
 
@@ -83,6 +84,26 @@ columnRetentionPolicy
     : ON columnNameClause (columnRetentionPolicyPatternClause)?
     ;
 
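+// Replication policy, e.g.:
+//   ALTER TABLE db.table SET POLICY (REPLICATION = ({destination:'a', interval:12H}, {destination:'b'}))
+// Each clause names a destination cluster; the interval is optional and reuses the
+// RETENTION_HOUR token, so it must be an hour-granularity value such as 12h or 24H.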
+replicationPolicy
+    : REPLICATION '=' tableReplicationPolicy
+    ;
+
+tableReplicationPolicy
+    : '(' replicationPolicyClause (',' replicationPolicyClause)* ')'
+    ;
+
+replicationPolicyClause
+    : '{' replicationPolicyClusterClause (',' replicationPolicyIntervalClause)? '}'
+    ;
+
+replicationPolicyClusterClause
+    : DESTINATION ':' STRING
+    ;
+
+replicationPolicyIntervalClause
+    : INTERVAL ':' RETENTION_HOUR
+    ;
+
 columnRetentionPolicyPatternClause
     : WHERE retentionColumnPatternClause
     ;
@@ -136,6 +157,7 @@ TABLE: 'TABLE';
 SET: 'SET';
 POLICY: 'POLICY';
 RETENTION: 'RETENTION';
+REPLICATION: 'REPLICATION';
 SHARING: 'SHARING';
 GRANT: 'GRANT';
 REVOKE: 'REVOKE';
@@ -150,6 +172,8 @@ DATABASE: 'DATABASE';
 SHOW: 'SHOW';
 GRANTS: 'GRANTS';
 PATTERN: 'PATTERN';
+DESTINATION: 'DESTINATION';
+INTERVAL: 'INTERVAL';
 WHERE: 'WHERE';
 COLUMN: 'COLUMN';
 PII: 'PII';
diff --git a/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/catalyst/parser/extensions/OpenhouseSqlExtensionsAstBuilder.scala b/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/catalyst/parser/extensions/OpenhouseSqlExtensionsAstBuilder.scala
index 4b8fc405..408c9cf3 100644
--- a/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/catalyst/parser/extensions/OpenhouseSqlExtensionsAstBuilder.scala
+++ b/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/catalyst/parser/extensions/OpenhouseSqlExtensionsAstBuilder.scala
@@ -2,13 +2,14 @@ package com.linkedin.openhouse.spark.sql.catalyst.parser.extensions
 
 import com.linkedin.openhouse.spark.sql.catalyst.enums.GrantableResourceTypes
 import com.linkedin.openhouse.spark.sql.catalyst.parser.extensions.OpenhouseSqlExtensionsParser._
-import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetRetentionPolicy, SetSharingPolicy, SetColumnPolicyTag, ShowGrantsStatement}
+import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetColumnPolicyTag, SetReplicationPolicy, SetRetentionPolicy, SetSharingPolicy, ShowGrantsStatement}
 import com.linkedin.openhouse.spark.sql.catalyst.enums.GrantableResourceTypes.GrantableResourceType
 import com.linkedin.openhouse.gen.tables.client.model.TimePartitionSpec
 import org.antlr.v4.runtime.tree.ParseTree
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import scala.collection.JavaConversions.iterableAsScalaIterable
 import scala.collection.JavaConverters._
 
 class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends OpenhouseSqlExtensionsBaseVisitor[AnyRef] {
@@ -26,6 +27,12 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh
     SetRetentionPolicy(tableName, granularity, count, Option(colName), Option(colPattern))
   }
 
+  override def visitSetReplicationPolicy(ctx: SetReplicationPolicyContext): SetReplicationPolicy = {
+    val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
+    val replicationPolicies = typedVisit[Seq[(String, Option[String])]](ctx.replicationPolicy())
+    SetReplicationPolicy(tableName, replicationPolicies)
+  }
+
   override def visitSetSharingPolicy(ctx: SetSharingPolicyContext): SetSharingPolicy = {
     val tableName = typedVisit[Seq[String]](ctx.multipartIdentifier)
     val sharing = typedVisit[String](ctx.sharingPolicy())
@@ -86,6 +93,31 @@ class OpenhouseSqlExtensionsAstBuilder (delegate: ParserInterface) extends Openh
     typedVisit[(String, Int)](ctx.duration())
   }
 
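+  // The visitors below flatten the replication parse tree into
+  // Seq[(destination, Option[interval])]; the interval is None when a clause
+  // only names a destination cluster.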
+  override def visitReplicationPolicy(ctx: ReplicationPolicyContext): Seq[(String, Option[String])] = {
+    typedVisit[Seq[(String, Option[String])]](ctx.tableReplicationPolicy())
+  }
+
+  override def visitTableReplicationPolicy(ctx: TableReplicationPolicyContext): Seq[(String, Option[String])] = {
+    toSeq(ctx.replicationPolicyClause()).map(typedVisit[(String, Option[String])])
+  }
+
+  override def visitReplicationPolicyClause(ctx: ReplicationPolicyClauseContext): (String, Option[String]) = {
+    val destination = typedVisit[String](ctx.replicationPolicyClusterClause())
+    val interval = Option(ctx.replicationPolicyIntervalClause()).map(typedVisit[String])
+    (destination, interval)
+  }
+
+  override def visitReplicationPolicyClusterClause(ctx: ReplicationPolicyClusterClauseContext): String = {
+    ctx.STRING().getText
+  }
+
+  override def visitReplicationPolicyIntervalClause(ctx: ReplicationPolicyIntervalClauseContext): String = {
+    ctx.RETENTION_HOUR().getText.toUpperCase
+  }
+
   override def visitColumnRetentionPolicy(ctx: ColumnRetentionPolicyContext): (String, String) = {
     if (ctx.columnRetentionPolicyPatternClause() != null) {
       (ctx.columnNameClause().identifier().getText(), ctx.columnRetentionPolicyPatternClause().retentionColumnPatternClause().STRING().getText)
diff --git a/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/catalyst/plans/logical/SetReplicationPolicy.scala b/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/catalyst/plans/logical/SetReplicationPolicy.scala
new file mode 100644
index 00000000..9c7b0791
--- /dev/null
+++ b/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/catalyst/plans/logical/SetReplicationPolicy.scala
@@ -0,0 +1,9 @@
+package com.linkedin.openhouse.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.plans.logical.Command
+
+case class SetReplicationPolicy(tableName: Seq[String], replicationPolicies: Seq[(String, Option[String])]) extends Command {
+  override def simpleString(maxFields: Int): String = {
+    s"SetReplicationPolicy: ${tableName} ${replicationPolicies}"
+  }
+}
diff --git a/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/execution/datasources/v2/OpenhouseDataSourceV2Strategy.scala b/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/execution/datasources/v2/OpenhouseDataSourceV2Strategy.scala
index 595c6e06..8545a2bc 100644
--- a/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/execution/datasources/v2/OpenhouseDataSourceV2Strategy.scala
+++ b/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/execution/datasources/v2/OpenhouseDataSourceV2Strategy.scala
@@ -1,6 +1,6 @@
 package com.linkedin.openhouse.spark.sql.execution.datasources.v2
 
-import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetRetentionPolicy, SetSharingPolicy, SetColumnPolicyTag, ShowGrantsStatement}
+import com.linkedin.openhouse.spark.sql.catalyst.plans.logical.{GrantRevokeStatement, SetColumnPolicyTag, SetReplicationPolicy, SetRetentionPolicy, SetSharingPolicy, ShowGrantsStatement}
 import org.apache.iceberg.spark.{Spark3Util, SparkCatalog, SparkSessionCatalog}
 import org.apache.spark.sql.{SparkSession, Strategy}
 import org.apache.spark.sql.catalyst.expressions.PredicateHelper
@@ -15,6 +15,8 @@ case class OpenhouseDataSourceV2Strategy(spark: SparkSession) extends Strategy w
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
     case SetRetentionPolicy(CatalogAndIdentifierExtractor(catalog, ident), granularity, count, colName, colPattern) =>
       SetRetentionPolicyExec(catalog, ident, granularity, count, colName, colPattern) :: Nil
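+    // Hand the parsed (destination, optional interval) pairs straight to the
+    // physical exec node, which performs the table-property commit.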
+    case SetReplicationPolicy(CatalogAndIdentifierExtractor(catalog, ident), replicationPolicies) =>
+      SetReplicationPolicyExec(catalog, ident, replicationPolicies) :: Nil
     case SetSharingPolicy(CatalogAndIdentifierExtractor(catalog, ident), sharing) =>
       SetSharingPolicyExec(catalog, ident, sharing) :: Nil
     case SetColumnPolicyTag(CatalogAndIdentifierExtractor(catalog, ident), policyTag, cols) =>
diff --git a/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/execution/datasources/v2/SetReplicationPolicyExec.scala b/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/execution/datasources/v2/SetReplicationPolicyExec.scala
new file mode 100644
index 00000000..5fd33236
--- /dev/null
+++ b/integrations/spark/openhouse-spark-runtime/src/main/scala/com/linkedin/openhouse/spark/sql/execution/datasources/v2/SetReplicationPolicyExec.scala
@@ -0,0 +1,26 @@
+package com.linkedin.openhouse.spark.sql.execution.datasources.v2
+
+import org.apache.iceberg.spark.source.SparkTable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
+import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
+
+case class SetReplicationPolicyExec(catalog: TableCatalog, ident: Identifier, replicationPolicies: Seq[(String, Option[String])]) extends V2CommandExec {
+  override protected def run(): Seq[InternalRow] = {
+    catalog.loadTable(ident) match {
+      case iceberg: SparkTable if iceberg.table().properties().containsKey("openhouse.tableId") =>
+        val key = "updated.openhouse.policy"
+        // Serialize the policies as JSON, e.g.
+        // {"replication":{"config":[{"destination":"a","interval":"12H"}]}};
+        // a clause without an interval serializes it as an empty string.
+        val value = s"""{"replication":{"config":[${replicationPolicies.map(replication => s"""{"destination":"${replication._1}","interval":"${replication._2.getOrElse("")}"}""").mkString(",")}]}}"""
+        iceberg.table().updateProperties()
+          .set(key, value)
+          .commit()
+
+      case table =>
+        throw new UnsupportedOperationException(s"Cannot set replication policy for non-Openhouse table: $table")
+    }
+    Nil
+  }
+
+  override def output: Seq[Attribute] = Nil
+}