pgLabII/pgLabII.PgUtils/ConnectionStrings/UrlCodec.cs

355 lines
13 KiB
C#
Raw Normal View History

2025-08-30 20:32:35 +02:00
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Net;
using System.Text;
using FluentResults;
using Npgsql;
namespace pgLabII.PgUtils.ConnectionStrings;
/// <summary>
/// Codec for PostgreSQL URL format: postgresql://[user[:password]@]host[:port][,hostN[:portN]]/[database]?param=value&amp;...
/// - Accepts the postgresql:// and postgres:// schemes (case-insensitive) on parse; always emits postgresql:// on format.
/// - Supports multi-host by comma-separated host:port entries in the authority part.
/// - Supports IPv6 literals in square brackets, e.g. [::1]:5432.
/// - Accepts a query string directly after the host list when no database path is present,
///   e.g. postgresql://host?sslmode=require.
/// - Percent-decodes user, password, hosts, database, and query keys/values on parse.
/// - Percent-encodes user, password, host names, database, and query keys/values on format, so that a host
///   such as a Unix-socket directory ("/var/run/postgresql") survives a round trip.
/// - Maps sslmode, application_name, connect_timeout into descriptor fields; others preserved in Properties.
/// </summary>
public sealed class UrlCodec : IConnectionStringCodec
{
    // Scheme prefixes accepted by TryParse; the first entry is the one TryFormat emits.
    private static readonly string[] SupportedSchemes = { "postgresql://", "postgres://" };

    public ConnStringFormat Format => ConnStringFormat.Url;

    public string FormatName => "URL";

    /// <summary>
    /// Parses a PostgreSQL connection URL into a <see cref="ConnectionDescriptor"/>.
    /// </summary>
    /// <param name="input">URL text; surrounding whitespace is ignored.</param>
    /// <returns>
    /// An ok result carrying the parsed descriptor, or a failed result when the input is empty, uses an
    /// unsupported scheme, or contains an unrecognized <c>sslmode</c>. Exceptions raised while parsing are
    /// converted into failed results rather than propagated.
    /// </returns>
    public Result<ConnectionDescriptor> TryParse(string input)
    {
        try
        {
            if (string.IsNullOrWhiteSpace(input))
                return Result.Fail<ConnectionDescriptor>("Empty URL");

            // Pasted URLs often carry a trailing newline; untrimmed it would end up in the database
            // name or in the last query value.
            string url = input.Trim();

            int schemeLength = GetSchemeLength(url);
            if (schemeLength < 0)
                return Result.Fail<ConnectionDescriptor>("URL must start with postgresql:// or postgres://");

            // System.Uri cannot represent multi-host authorities, so the URL is split by hand.
            var (userInfo, hostList, path, query) = SplitUrlBody(url.Substring(schemeLength));

            var builder = new ConnectionDescriptorBuilder();

            if (userInfo.Length > 0)
            {
                // Only the first ':' separates user from password; any later ':' belongs to the password.
                string[] credentials = userInfo.Split(':', 2);
                if (credentials[0].Length > 0)
                    builder.Username = Uri.UnescapeDataString(credentials[0]);
                if (credentials.Length > 1)
                    builder.Password = Uri.UnescapeDataString(credentials[1]);
            }

            foreach (string hostPart in SplitHosts(hostList))
            {
                ParseHostPort(hostPart, out string host, out ushort? port);
                // Decode only after the port has been split off, so an encoded host such as
                // %2Fvar%2Frun%2Fpostgresql is recovered intact.
                if (host.Length > 0)
                    builder.AddHost(Uri.UnescapeDataString(host), port);
            }

            if (path.Length > 0)
                builder.Database = Uri.UnescapeDataString(path);

            ApplyQueryParameters(builder, ParseQuery(query));

            return Result.Ok(builder.Build());
        }
        catch (Exception ex)
        {
            return Result.Fail<ConnectionDescriptor>(ex.Message);
        }
    }

    /// <summary>
    /// Formats a <see cref="ConnectionDescriptor"/> as a <c>postgresql://</c> URL.
    /// </summary>
    /// <param name="descriptor">The connection settings to format.</param>
    /// <returns>
    /// An ok result carrying the URL, or a failed result carrying the message of any exception raised
    /// while formatting.
    /// </returns>
    public Result<string> TryFormat(ConnectionDescriptor descriptor)
    {
        try
        {
            var sb = new StringBuilder(SupportedSchemes[0]);

            // userinfo: written whenever a user or a password is present. A password without a user is
            // written as ":password@", which TryParse reads back as password-only.
            string? user = descriptor.Username;
            string? password = descriptor.Password;
            if (!string.IsNullOrEmpty(user) || !string.IsNullOrEmpty(password))
            {
                if (!string.IsNullOrEmpty(user))
                    sb.Append(Uri.EscapeDataString(user));
                if (!string.IsNullOrEmpty(password))
                    sb.Append(':').Append(Uri.EscapeDataString(password));
                sb.Append('@');
            }

            // hosts
            if (descriptor.Hosts != null && descriptor.Hosts.Count > 0)
            {
                var hostParts = new List<string>(descriptor.Hosts.Count);
                foreach (var endpoint in descriptor.Hosts)
                    hostParts.Add(FormatHostEndpoint(endpoint));
                sb.Append(string.Join(',', hostParts));
            }

            // path (database); the '/' is always written, even when no database is set.
            sb.Append('/');
            if (!string.IsNullOrEmpty(descriptor.Database))
                sb.Append(Uri.EscapeDataString(descriptor.Database));

            // query: typed fields first, then extra properties that do not duplicate a typed field.
            var queryPairs = new List<string>();
            var typedKeys = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            void AddTyped(string key, string value)
            {
                queryPairs.Add(key + "=" + Uri.EscapeDataString(value));
                typedKeys.Add(key);
            }

            if (descriptor.SslMode.HasValue)
                AddTyped("sslmode", FormatSslMode(descriptor.SslMode.Value));
            if (!string.IsNullOrEmpty(descriptor.ApplicationName))
                AddTyped("application_name", descriptor.ApplicationName);
            if (descriptor.TimeoutSeconds.HasValue)
                AddTyped("connect_timeout", descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture));

            if (descriptor.Properties != null)
            {
                foreach (var kv in descriptor.Properties)
                {
                    if (typedKeys.Contains(kv.Key))
                        continue;
                    queryPairs.Add(Uri.EscapeDataString(kv.Key) + "=" + Uri.EscapeDataString(kv.Value ?? string.Empty));
                }
            }

            if (queryPairs.Count > 0)
                sb.Append('?').Append(string.Join('&', queryPairs));

            return Result.Ok(sb.ToString());
        }
        catch (Exception ex)
        {
            return Result.Fail<string>(ex.Message);
        }
    }

    /// <summary>
    /// Returns the length of the supported scheme prefix <paramref name="url"/> starts with
    /// (case-insensitive), or -1 when it starts with none of them.
    /// </summary>
    private static int GetSchemeLength(string url)
    {
        foreach (string scheme in SupportedSchemes)
        {
            if (url.StartsWith(scheme, StringComparison.OrdinalIgnoreCase))
                return scheme.Length;
        }
        return -1;
    }

    /// <summary>
    /// Splits the part of the URL after the scheme into its raw (still percent-encoded) components.
    /// </summary>
    /// <param name="body">Everything after "scheme://".</param>
    /// <returns>
    /// UserInfo: text before the first '@' that precedes the first '/' (empty when absent).
    /// HostList: the comma-separated host entries.
    /// Path: the database path without its leading '/' (empty when absent).
    /// Query: the text after the '?' (empty when absent).
    /// </returns>
    private static (string UserInfo, string HostList, string Path, string Query) SplitUrlBody(string body)
    {
        int slashIdx = body.IndexOf('/');
        string authority = slashIdx >= 0 ? body.Substring(0, slashIdx) : body;
        string pathAndQuery = slashIdx >= 0 ? body.Substring(slashIdx) : string.Empty;

        string userInfo = string.Empty;
        int atIdx = authority.IndexOf('@');
        if (atIdx >= 0)
        {
            userInfo = authority.Substring(0, atIdx);
            authority = authority.Substring(atIdx + 1);
        }

        // With no database path the query may follow the host list directly
        // ("postgresql://host?sslmode=require"); everything from that '?' on is query text,
        // including any '/' that appeared inside a query value.
        int hostQueryIdx = authority.IndexOf('?');
        if (hostQueryIdx >= 0)
        {
            pathAndQuery = authority.Substring(hostQueryIdx) + pathAndQuery;
            authority = authority.Substring(0, hostQueryIdx);
        }

        int queryIdx = pathAndQuery.IndexOf('?');
        string path = queryIdx >= 0 ? pathAndQuery.Substring(0, queryIdx) : pathAndQuery;
        string query = queryIdx >= 0 ? pathAndQuery.Substring(queryIdx + 1) : string.Empty;
        if (path.StartsWith('/'))
            path = path.Substring(1);

        return (userInfo, authority, path, query);
    }

    /// <summary>
    /// Copies parsed query parameters into <paramref name="builder"/>. sslmode, application_name and an
    /// integer connect_timeout go to their typed properties. Everything else, including a connect_timeout
    /// that is not an integer, is kept verbatim in Properties rather than dropped.
    /// </summary>
    /// <exception cref="ArgumentException">The sslmode value is not recognized.</exception>
    private static void ApplyQueryParameters(ConnectionDescriptorBuilder builder, Dictionary<string, string> parameters)
    {
        foreach (var (key, value) in parameters)
        {
            if (key.Equals("sslmode", StringComparison.OrdinalIgnoreCase))
                builder.SslMode = ParseSslMode(value);
            else if (key.Equals("application_name", StringComparison.OrdinalIgnoreCase))
                builder.ApplicationName = value;
            else if (key.Equals("connect_timeout", StringComparison.OrdinalIgnoreCase)
                     && int.TryParse(value, NumberStyles.Integer, CultureInfo.InvariantCulture, out int seconds))
                builder.TimeoutSeconds = seconds;
            else
                builder.Properties[key] = value;
        }
    }

    /// <summary>
    /// Formats one host entry for the authority part. A bare IPv6 literal is wrapped in brackets, with the
    /// '%' of a zone id written as %25. An already-bracketed value is written as-is. Any other host is
    /// percent-encoded. The port is appended when present.
    /// </summary>
    private static string FormatHostEndpoint(HostEndpoint endpoint)
    {
        string host = endpoint.Host ?? string.Empty;

        string formatted;
        if (host.StartsWith('[') || host.EndsWith(']'))
            formatted = host;
        else if (host.Contains(':'))
            formatted = "[" + host.Replace("%", "%25") + "]";
        else
            formatted = Uri.EscapeDataString(host); // plain names and IPv4 are unchanged; '/' becomes %2F

        if (endpoint.Port.HasValue)
            formatted += ":" + endpoint.Port.Value.ToString(CultureInfo.InvariantCulture);
        return formatted;
    }

    /// <summary>
    /// Splits a host list on commas that are not inside IPv6 brackets. Entries are trimmed and empty
    /// entries are dropped.
    /// </summary>
    private static IEnumerable<string> SplitHosts(string hostList)
    {
        var parts = new List<string>();
        int depth = 0;
        int start = 0;
        for (int i = 0; i < hostList.Length; i++)
        {
            switch (hostList[i])
            {
                case '[':
                    depth++;
                    break;
                case ']':
                    depth = Math.Max(0, depth - 1);
                    break;
                case ',' when depth == 0:
                    parts.Add(hostList.Substring(start, i - start));
                    start = i + 1;
                    break;
            }
        }
        parts.Add(hostList.Substring(start));
        return parts.Select(p => p.Trim()).Where(p => p.Length > 0);
    }

    /// <summary>
    /// Splits one raw (still percent-encoded) host entry into host and optional port.
    /// </summary>
    /// <remarks>
    /// - "[addr]" or "[addr]:port" is read as a bracketed IPv6 literal; an invalid port after it is ignored.
    /// - Otherwise a single ':' separates host and port. An empty port ("host:") is treated as absent.
    /// - An entry with several ':' is an unbracketed IPv6 address and is returned whole, because its last
    ///   group cannot be told apart from a port.
    /// - An entry starting with ':' or with a port that is not a valid 16-bit number is returned whole
    ///   with no port.
    /// </remarks>
    private static void ParseHostPort(string hostPart, out string host, out ushort? port)
    {
        host = string.Empty;
        port = null;
        if (string.IsNullOrWhiteSpace(hostPart))
            return;

        if (hostPart[0] == '[')
        {
            int end = hostPart.IndexOf(']');
            if (end < 0)
            {
                host = hostPart; // unterminated literal: let it pass raw
                return;
            }
            host = hostPart.Substring(1, end - 1);
            if (end + 1 < hostPart.Length && hostPart[end + 1] == ':'
                && TryParsePort(hostPart.Substring(end + 2), out ushort bracketPort))
                port = bracketPort;
            return;
        }

        int colon = hostPart.IndexOf(':');
        if (colon <= 0 || colon != hostPart.LastIndexOf(':'))
        {
            host = hostPart;
            return;
        }

        string portText = hostPart.Substring(colon + 1);
        if (portText.Length == 0)
        {
            host = hostPart.Substring(0, colon);
            return;
        }
        if (TryParsePort(portText, out ushort parsedPort))
        {
            host = hostPart.Substring(0, colon);
            port = parsedPort;
            return;
        }
        host = hostPart;
    }

    /// <summary>Parses a port number using invariant culture; false when it is not a valid 16-bit value.</summary>
    private static bool TryParsePort(string text, out ushort port)
        => ushort.TryParse(text, NumberStyles.Integer, CultureInfo.InvariantCulture, out port);

    /// <summary>
    /// Maps an sslmode query value (case-insensitive, surrounding whitespace ignored; hyphenless
    /// "verifyca"/"verifyfull" also accepted) to <see cref="SslMode"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The value is not a recognized sslmode.</exception>
    private static SslMode ParseSslMode(string s) => s.Trim().ToLowerInvariant() switch
    {
        "disable" => SslMode.Disable,
        "allow" => SslMode.Allow,
        "prefer" => SslMode.Prefer,
        "require" => SslMode.Require,
        "verify-ca" or "verifyca" => SslMode.VerifyCA,
        "verify-full" or "verifyfull" => SslMode.VerifyFull,
        _ => throw new ArgumentException($"Not a valid sslmode: {s}")
    };

    /// <summary>
    /// Maps <see cref="SslMode"/> to its URL spelling; values outside the known set are written as "prefer".
    /// </summary>
    private static string FormatSslMode(SslMode mode) => mode switch
    {
        SslMode.Disable => "disable",
        SslMode.Allow => "allow",
        SslMode.Prefer => "prefer",
        SslMode.Require => "require",
        SslMode.VerifyCA => "verify-ca",
        SslMode.VerifyFull => "verify-full",
        _ => "prefer"
    };

    /// <summary>
    /// Mutable accumulator used while parsing; <see cref="Build"/> produces the immutable descriptor.
    /// </summary>
    private sealed class ConnectionDescriptorBuilder
    {
        public List<HostEndpoint> Hosts { get; } = new();
        public string? Database { get; set; }
        public string? Username { get; set; }
        public string? Password { get; set; }
        public SslMode? SslMode { get; set; }
        public string? ApplicationName { get; set; }
        public int? TimeoutSeconds { get; set; }

        // Extra query parameters; keys compare case-insensitively.
        public Dictionary<string, string> Properties { get; } = new(StringComparer.OrdinalIgnoreCase);

        /// <summary>Adds a trimmed host entry; blank hosts are ignored.</summary>
        public void AddHost(string host, ushort? port)
        {
            if (string.IsNullOrWhiteSpace(host))
                return;
            Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port });
        }

        public ConnectionDescriptor Build()
        {
            return new ConnectionDescriptor
            {
                Hosts = Hosts,
                Database = Database,
                Username = Username,
                Password = Password,
                SslMode = SslMode,
                ApplicationName = ApplicationName,
                TimeoutSeconds = TimeoutSeconds,
                Properties = Properties
            };
        }
    }

    /// <summary>
    /// Parses an '&amp;'-separated query string into a case-insensitive dictionary. Keys and values are
    /// percent-decoded. A key without '=' gets an empty value. Pairs with an empty key are skipped. For
    /// duplicate keys the last occurrence wins.
    /// </summary>
    private static Dictionary<string, string> ParseQuery(string query)
    {
        var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        if (string.IsNullOrEmpty(query))
            return dict;

        foreach (string pair in query.Split('&', StringSplitOptions.RemoveEmptyEntries))
        {
            int eq = pair.IndexOf('=');
            string key = Uri.UnescapeDataString(eq < 0 ? pair : pair.Substring(0, eq));
            if (key.Length == 0)
                continue; // "=value" names no parameter
            dict[key] = eq < 0 ? string.Empty : Uri.UnescapeDataString(pair.Substring(eq + 1));
        }
        return dict;
    }
}