diff --git a/src/deploy-cromwell-on-azure/Deployer.cs b/src/deploy-cromwell-on-azure/Deployer.cs index 5b71bf93..7e7f9c89 100644 --- a/src/deploy-cromwell-on-azure/Deployer.cs +++ b/src/deploy-cromwell-on-azure/Deployer.cs @@ -167,6 +167,7 @@ await Execute("Connecting to Azure Services...", async () => var keyVaultUri = string.Empty; IIdentity managedIdentity = null; IIdentity aksNodepoolIdentity = null; + INetwork aksVnet = null; IPrivateDnsZone postgreSqlDnsZone = null; IKubernetes kubernetesClient = null; @@ -329,6 +330,10 @@ await Execute("Connecting to Azure Services...", async () => if (aksCluster is not null) { aksNodepoolIdentity = await GetUserManagedIdentityAsync(aksCluster.Identity.UserAssignedIdentities.First().Value.PrincipalId); + + var aksSubnet = aksCluster.AgentPoolProfiles.Where(x => string.Equals(x.Mode, "System", StringComparison.OrdinalIgnoreCase)).First().VnetSubnetID; + var aksVnetString = aksSubnet.Remove(aksSubnet.IndexOf("/subnets/")); + aksVnet = await azureSubscriptionClient.Networks.GetByIdAsync(aksVnetString); } postgreSqlFlexServer = await ValidateAndGetExistingPostgresqlServer(); @@ -420,6 +425,11 @@ await Execute("Connecting to Azure Services...", async () => } } + if (aksVnet is null) + { + aksVnet = vnetAndSubnet.Value.virtualNetwork; + } + if (string.IsNullOrWhiteSpace(configuration.LogAnalyticsArmId)) { var workspaceName = SdkContext.RandomResourceName(configuration.MainIdentifierPrefix, 15); @@ -459,14 +469,11 @@ await Task.Run(async () => if (postgreSqlFlexServer is null) { - if (configuration.VnetResourceGroupName is not null) - { - postgreSqlDnsZone = await GetExistingPrivateDnsZoneAsync(vnetAndSubnet.Value.virtualNetwork, $"privatelink.postgres.database.azure.com"); - } + postgreSqlDnsZone = await GetExistingPrivateDnsZoneAsync(aksVnet, $"privatelink.postgres.database.azure.com"); if (postgreSqlDnsZone is null) { - postgreSqlDnsZone = await CreatePrivateDnsZoneAsync(vnetAndSubnet.Value.virtualNetwork, $"privatelink.postgres.database.azure.com", "PostgreSQL Server"); + postgreSqlDnsZone = await CreatePrivateDnsZoneAsync(aksVnet, $"privatelink.postgres.database.azure.com", "PostgreSQL Server"); } } @@ -485,7 +492,13 @@ await Task.WhenAll(new Task[] Task.Run(async () => { if (configuration.UsePostgreSqlSingleServer) { - postgreSqlSingleServer ??= await CreateSinglePostgreSqlServerAndDatabaseAsync(postgreSqlSingleManagementClient, vnetAndSubnet.Value.vmSubnet, postgreSqlDnsZone); + var peSubnet = vnetAndSubnet.Value.vmSubnet; + if (aksVnet is not null) + { + peSubnet = aksVnet.Subnets.First().Value; + } + + postgreSqlSingleServer ??= await CreateSinglePostgreSqlServerAndDatabaseAsync(postgreSqlSingleManagementClient, peSubnet, postgreSqlDnsZone); } else { @@ -1383,12 +1396,20 @@ private string GetPostgreSQLCreateCromwellUserCommand(bool useSingleServer, stri private async Task GetExistingPrivateDnsZoneAsync(INetwork virtualNetwork, string name) { - var dnsZone = (await azureSubscriptionClient.PrivateDnsZones - .ListByResourceGroupAsync(configuration.VnetResourceGroupName)) - .SingleOrDefault(a => a.Name.Equals(name, StringComparison.OrdinalIgnoreCase) - && (a.VirtualNetworkLinks.List().Where(x => string.Equals(x.ReferencedVirtualNetworkId, virtualNetwork.Id, StringComparison.OrdinalIgnoreCase)).Count() > 0)); + var dnsZones = (await azureSubscriptionClient.PrivateDnsZones.ListAsync()).Where(x => x.Name.Equals(name, StringComparison.OrdinalIgnoreCase)); + var dnsZonesMap = new Dictionary(); + + foreach (var zone in dnsZones) + { + var pairs = zone.VirtualNetworkLinks.List().Select(x => new KeyValuePair(x.ReferencedVirtualNetworkId, zone)); + foreach (var pair in pairs) + { + dnsZonesMap.Add(pair.Key, pair.Value); + } + } - return dnsZone; + dnsZonesMap.TryGetValue(virtualNetwork.Id, out var privateDnsZone); + return privateDnsZone; } private Task CreatePrivateDnsZoneAsync(INetwork virtualNetwork, string name, string title)