pgLabII/pgLabII.PgUtils/ConnectionStrings/UrlCodec.cs

201 lines
8.2 KiB
C#
Raw Permalink Normal View History

using System.Globalization;
2025-08-30 20:32:35 +02:00
using System.Text;
using FluentResults;
namespace pgLabII.PgUtils.ConnectionStrings;
/// <summary>
/// Codec for the PostgreSQL URL format:
/// postgresql://[user[:password]@][host[:port][,hostN[:portN]]][/database][?param=value&amp;...]
/// - Accepts the postgresql:// and postgres:// schemes (case-insensitive, surrounding whitespace ignored);
///   always emits postgresql://.
/// - Supports multi-host by comma-separated host:port entries in the authority part.
/// - Supports IPv6 literals in square brackets, e.g. [::1]:5432.
/// - Splits the URL the way libpq does: the credentials end at the first '@' before any '/', and the
///   host list ends at the first '/' or '?', so postgresql://host?sslmode=require is accepted.
/// - Percent-decodes user, password, host, and database on parse; query parsing (including any value
///   decoding) is delegated to CodecCommon.ParseQuery.
/// - Percent-encodes user, password, non-IPv6 hosts (e.g. Unix-socket directories), database, and query
///   keys/values on format.
/// - Maps sslmode, application_name, connect_timeout into descriptor fields; other query parameters are
///   preserved in Properties.
/// </summary>
public sealed class UrlCodec : IConnectionStringCodec
{
    // Scheme prefixes accepted by TryParse; TryFormat always writes the first one.
    private const string PostgresqlScheme = "postgresql://";
    private const string PostgresScheme = "postgres://";

    // Characters that end the host list: the start of the database path or of the query string.
    private static readonly char[] HostListTerminators = { '/', '?' };

    public ConnStringFormat Format => ConnStringFormat.Url;

    public string FormatName => "URL";

    /// <summary>
    /// Parses a postgresql:// or postgres:// URL into a <see cref="ConnectionDescriptor"/>.
    /// </summary>
    /// <param name="input">The URL text; leading and trailing whitespace is ignored.</param>
    /// <returns>
    /// A successful result carrying the descriptor, or a failed result when the input is empty, has an
    /// unsupported scheme, or a parsing helper throws (the exception message becomes the error).
    /// </returns>
    public Result<ConnectionDescriptor> TryParse(string input)
    {
        try
        {
            if (string.IsNullOrWhiteSpace(input))
                return Result.Fail<ConnectionDescriptor>("Empty URL");

            // Pasted URLs often carry a trailing newline or spaces, which would otherwise end up in the
            // database name or in the last query value.
            input = input.Trim();

            int prefixLength;
            if (input.StartsWith(PostgresqlScheme, StringComparison.OrdinalIgnoreCase))
                prefixLength = PostgresqlScheme.Length;
            else if (input.StartsWith(PostgresScheme, StringComparison.OrdinalIgnoreCase))
                prefixLength = PostgresScheme.Length;
            else
                return Result.Fail<ConnectionDescriptor>("URL must start with postgresql:// or postgres://");

            // System.Uri cannot represent multi-host authorities, so the remainder is split by hand.
            SplitAfterScheme(input[prefixLength..], out string userInfo, out string hostList, out string pathAndQuery);

            var builder = new ConnectionDescriptorBuilder();

            // user[:password] - only the first ':' separates them, so the password may contain ':'.
            if (userInfo.Length > 0)
            {
                string[] credentials = userInfo.Split(':', 2);
                if (credentials[0].Length > 0)
                    builder.Username = Uri.UnescapeDataString(credentials[0]);
                if (credentials.Length > 1)
                    builder.Password = Uri.UnescapeDataString(credentials[1]);
            }

            foreach (string hostPart in CodecCommon.SplitHosts(hostList))
            {
                CodecCommon.ParseHostPort(hostPart, out string host, out ushort? port);
                if (string.IsNullOrEmpty(host))
                    continue;

                // libpq percent-decodes the host; that is how a Unix-socket directory is written in a URL
                // (postgresql://%2Fvar%2Frun%2Fpostgresql/db). Decoding after the host/port split keeps an
                // encoded ':' or ',' from being read as a separator.
                // NOTE(review): if CodecCommon.ParseHostPort already decodes, this pass only affects hosts
                // that contain a literal '%' - confirm.
                builder.AddHost(Uri.UnescapeDataString(host), port);
            }

            // pathAndQuery is empty, or starts with '/' (database path) or '?' (query without a database).
            int qIdx = pathAndQuery.IndexOf('?');
            string path = qIdx >= 0 ? pathAndQuery[..qIdx] : pathAndQuery;
            string query = qIdx >= 0 ? pathAndQuery[(qIdx + 1)..] : string.Empty;
            if (path.StartsWith('/'))
                path = path[1..];
            if (path.Length > 0)
                builder.Database = Uri.UnescapeDataString(path);

            var queryDict = CodecCommon.ParseQuery(query);

            // Map known params
            if (queryDict.TryGetValue("sslmode", out var sslVal))
                builder.SslMode = CodecCommon.ParseSslModeLoose(sslVal);
            if (queryDict.TryGetValue("application_name", out var app))
                builder.ApplicationName = app;
            // A non-numeric connect_timeout is ignored; IsMappedQueryKey also keeps it out of Properties.
            if (queryDict.TryGetValue("connect_timeout", out var tout)
                && int.TryParse(tout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var ts))
                builder.TimeoutSeconds = ts;

            // Preserve every parameter that is not mapped onto a descriptor field.
            foreach (var (k, v) in queryDict)
            {
                if (!IsMappedQueryKey(k))
                    builder.Properties[k] = v;
            }

            return Result.Ok(builder.Build());
        }
        catch (Exception ex)
        {
            // Any exception from the parsing helpers is reported as a failed result instead of propagating.
            return Result.Fail<ConnectionDescriptor>(ex.Message);
        }
    }

    /// <summary>
    /// Formats a descriptor as a postgresql:// URL.
    /// </summary>
    /// <param name="descriptor">The connection settings to render.</param>
    /// <returns>
    /// A successful result carrying the URL, or a failed result carrying the message of any exception
    /// thrown while formatting.
    /// </returns>
    public Result<string> TryFormat(ConnectionDescriptor descriptor)
    {
        try
        {
            var sb = new StringBuilder(PostgresqlScheme);

            // userinfo. libpq (and TryParse) accept an empty user name with a password (":secret@"),
            // so the password is written even when no user name is set.
            if (!string.IsNullOrEmpty(descriptor.Username) || !string.IsNullOrEmpty(descriptor.Password))
            {
                if (!string.IsNullOrEmpty(descriptor.Username))
                    sb.Append(Uri.EscapeDataString(descriptor.Username));
                if (!string.IsNullOrEmpty(descriptor.Password))
                    sb.Append(':').Append(Uri.EscapeDataString(descriptor.Password));
                sb.Append('@');
            }

            // hosts, comma-separated; an empty host list leaves the authority empty.
            var hostParts = new List<string>(descriptor.Hosts.Count);
            foreach (var h in descriptor.Hosts)
            {
                string host = FormatHostName(h.Host ?? string.Empty);
                if (h.Port.HasValue)
                    host += ":" + h.Port.Value.ToString(CultureInfo.InvariantCulture);
                hostParts.Add(host);
            }
            sb.Append(string.Join(',', hostParts));

            // path (database); the '/' is always written, even without a database.
            sb.Append('/');
            if (!string.IsNullOrEmpty(descriptor.Database))
                sb.Append(Uri.EscapeDataString(descriptor.Database));

            // query: mapped fields first, then the extra properties.
            var queryPairs = new List<string>();
            // Mapped keys already written; Properties entries with the same name (case-insensitive)
            // are skipped so no key is emitted twice.
            var emittedKeys = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            void AddMapped(string key, string value)
            {
                queryPairs.Add(key + "=" + Uri.EscapeDataString(value));
                emittedKeys.Add(key);
            }

            if (descriptor.SslMode.HasValue)
                AddMapped("sslmode", CodecCommon.FormatSslModeUrlLike(descriptor.SslMode.Value));
            if (!string.IsNullOrEmpty(descriptor.ApplicationName))
                AddMapped("application_name", descriptor.ApplicationName);
            if (descriptor.TimeoutSeconds.HasValue)
                AddMapped("connect_timeout", descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture));

            foreach (var kv in descriptor.Properties)
            {
                if (emittedKeys.Contains(kv.Key))
                    continue;
                queryPairs.Add(Uri.EscapeDataString(kv.Key) + "=" + Uri.EscapeDataString(kv.Value ?? string.Empty));
            }

            if (queryPairs.Count > 0)
                sb.Append('?').Append(string.Join('&', queryPairs));

            return Result.Ok(sb.ToString());
        }
        catch (Exception ex)
        {
            return Result.Fail<string>(ex.Message);
        }
    }

    /// <summary>
    /// Splits the text after the scheme into raw (still percent-encoded) user info, host list, and the
    /// trailing path/query part, which is empty or starts with '/' or '?'.
    /// </summary>
    private static void SplitAfterScheme(string rest, out string userInfo, out string hostList, out string pathAndQuery)
    {
        // Like libpq, look for the credentials designator '@' only before the first '/', so a password
        // containing an unencoded '?' still parses.
        int slashIdx = rest.IndexOf('/');
        string beforePath = slashIdx >= 0 ? rest[..slashIdx] : rest;
        int atIdx = beforePath.IndexOf('@');

        userInfo = atIdx >= 0 ? rest[..atIdx] : string.Empty;
        string afterUserInfo = rest[(atIdx + 1)..]; // atIdx == -1 keeps the whole string

        // The host list ends where the database path ('/') or the query ('?') begins.
        int endIdx = afterUserInfo.IndexOfAny(HostListTerminators);
        hostList = endIdx >= 0 ? afterUserInfo[..endIdx] : afterUserInfo;
        pathAndQuery = endIdx >= 0 ? afterUserInfo[endIdx..] : string.Empty;
    }

    /// <summary>
    /// Renders one host name for the authority part (without the port).
    /// </summary>
    private static string FormatHostName(string host)
    {
        if (host.Contains(':'))
        {
            // IPv6 literal: bracket it so its colons are not read as the port separator; brackets and
            // colons are left unescaped.
            return !host.StartsWith('[') && !host.EndsWith(']') ? "[" + host + "]" : host;
        }

        // Host names and IPv4 addresses consist of unreserved characters and pass through unchanged,
        // while a Unix-socket directory such as /var/run/postgresql becomes %2Fvar%2Frun%2Fpostgresql,
        // as libpq requires ('/' would otherwise start the database path).
        return Uri.EscapeDataString(host);
    }

    private static bool IsMappedQueryKey(string key)
        => key.Equals("sslmode", StringComparison.OrdinalIgnoreCase)
        || key.Equals("application_name", StringComparison.OrdinalIgnoreCase)
        || key.Equals("connect_timeout", StringComparison.OrdinalIgnoreCase);
}