Skip to content

Commit

Permalink
Create database private endpoint to AKS vnet rather than CoA vnet
Browse files Browse the repository at this point in the history
  • Loading branch information
jsaun committed Aug 17, 2023
1 parent d2bb21c commit 4baa749
Showing 1 changed file with 32 additions and 11 deletions.
43 changes: 32 additions & 11 deletions src/deploy-cromwell-on-azure/Deployer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ await Execute("Connecting to Azure Services...", async () =>
var keyVaultUri = string.Empty;
IIdentity managedIdentity = null;
IIdentity aksNodepoolIdentity = null;
INetwork aksVnet = null;
IPrivateDnsZone postgreSqlDnsZone = null;
IKubernetes kubernetesClient = null;

Expand Down Expand Up @@ -329,6 +330,10 @@ await Execute("Connecting to Azure Services...", async () =>
if (aksCluster is not null)
{
aksNodepoolIdentity = await GetUserManagedIdentityAsync(aksCluster.Identity.UserAssignedIdentities.First().Value.PrincipalId);

var aksSubnet = aksCluster.AgentPoolProfiles.Where(x => string.Equals(x.Mode, "System", StringComparison.OrdinalIgnoreCase)).First().VnetSubnetID;
var aksVnetString = aksSubnet.Remove(aksSubnet.IndexOf("/subnets/"));
aksVnet = await azureSubscriptionClient.Networks.GetByIdAsync(aksVnetString);
}

postgreSqlFlexServer = await ValidateAndGetExistingPostgresqlServer();
Expand Down Expand Up @@ -420,6 +425,11 @@ await Execute("Connecting to Azure Services...", async () =>
}
}

if (aksVnet is null)
{
aksVnet = vnetAndSubnet.Value.virtualNetwork;
}

if (string.IsNullOrWhiteSpace(configuration.LogAnalyticsArmId))
{
var workspaceName = SdkContext.RandomResourceName(configuration.MainIdentifierPrefix, 15);
Expand Down Expand Up @@ -459,14 +469,11 @@ await Task.Run(async () =>

if (postgreSqlFlexServer is null)
{
if (configuration.VnetResourceGroupName is not null)
{
postgreSqlDnsZone = await GetExistingPrivateDnsZoneAsync(vnetAndSubnet.Value.virtualNetwork, $"privatelink.postgres.database.azure.com");
}
postgreSqlDnsZone = await GetExistingPrivateDnsZoneAsync(aksVnet, $"privatelink.postgres.database.azure.com");

if (postgreSqlDnsZone is null)
{
postgreSqlDnsZone = await CreatePrivateDnsZoneAsync(vnetAndSubnet.Value.virtualNetwork, $"privatelink.postgres.database.azure.com", "PostgreSQL Server");
postgreSqlDnsZone = await CreatePrivateDnsZoneAsync(aksVnet, $"privatelink.postgres.database.azure.com", "PostgreSQL Server");
}
}

Expand All @@ -485,7 +492,13 @@ await Task.WhenAll(new Task[]
Task.Run(async () => {
if (configuration.UsePostgreSqlSingleServer)
{
postgreSqlSingleServer ??= await CreateSinglePostgreSqlServerAndDatabaseAsync(postgreSqlSingleManagementClient, vnetAndSubnet.Value.vmSubnet, postgreSqlDnsZone);
var peSubnet = vnetAndSubnet.Value.vmSubnet;
if (aksVnet is not null)
{
peSubnet = aksVnet.Subnets.First().Value;
}
postgreSqlSingleServer ??= await CreateSinglePostgreSqlServerAndDatabaseAsync(postgreSqlSingleManagementClient, peSubnet, postgreSqlDnsZone);
}
else
{
Expand Down Expand Up @@ -1383,12 +1396,20 @@ private string GetPostgreSQLCreateCromwellUserCommand(bool useSingleServer, stri

private async Task<IPrivateDnsZone> GetExistingPrivateDnsZoneAsync(INetwork virtualNetwork, string name)
{
var dnsZone = (await azureSubscriptionClient.PrivateDnsZones
.ListByResourceGroupAsync(configuration.VnetResourceGroupName))
.SingleOrDefault(a => a.Name.Equals(name, StringComparison.OrdinalIgnoreCase)
&& (a.VirtualNetworkLinks.List().Where(x => string.Equals(x.ReferencedVirtualNetworkId, virtualNetwork.Id, StringComparison.OrdinalIgnoreCase)).Count() > 0));
var dnsZones = (await azureSubscriptionClient.PrivateDnsZones.ListAsync()).Where(x => x.Name.Equals(name, StringComparison.OrdinalIgnoreCase));
var dnsZonesMap = new Dictionary<string, IPrivateDnsZone>();

foreach (var zone in dnsZones)
{
var pairs = zone.VirtualNetworkLinks.List().Select(x => new KeyValuePair<string, IPrivateDnsZone>(x.ReferencedVirtualNetworkId, zone));
foreach (var pair in pairs)
{
dnsZonesMap.Add(pair.Key, pair.Value);
}
}

return dnsZone;
dnsZonesMap.TryGetValue(virtualNetwork.Id, out var privateDnsZone);
return privateDnsZone;
}

private Task<IPrivateDnsZone> CreatePrivateDnsZoneAsync(INetwork virtualNetwork, string name, string title)
Expand Down

0 comments on commit 4baa749

Please sign in to comment.