diff --git a/pgLabII.PgUtils.Tests/ConnectionStrings/UrlCodecTests.cs b/pgLabII.PgUtils.Tests/ConnectionStrings/UrlCodecTests.cs new file mode 100644 index 0000000..8d1c39b --- /dev/null +++ b/pgLabII.PgUtils.Tests/ConnectionStrings/UrlCodecTests.cs @@ -0,0 +1,85 @@ +using System.Collections.Generic; +using Npgsql; +using pgLabII.PgUtils.ConnectionStrings; + +namespace pgLabII.PgUtils.Tests.ConnectionStrings; + +public class UrlCodecTests +{ + [Fact] + public void Parse_Basic() + { + var codec = new UrlCodec(); + var res = codec.TryParse("postgresql://alice:secret@localhost:5433/testdb?sslmode=require&application_name=pgLab&connect_timeout=12&search_path=public"); + Assert.True(res.IsSuccess); + var d = res.Value; + Assert.Single(d.Hosts); + Assert.Equal("localhost", d.Hosts[0].Host); + Assert.Equal((ushort)5433, d.Hosts[0].Port); + Assert.Equal("testdb", d.Database); + Assert.Equal("alice", d.Username); + Assert.Equal("secret", d.Password); + Assert.Equal(SslMode.Require, d.SslMode); + Assert.Equal("pgLab", d.ApplicationName); + Assert.Equal(12, d.TimeoutSeconds); + Assert.True(d.Properties.ContainsKey("search_path")); + Assert.Equal("public", d.Properties["search_path"]); + } + + [Fact] + public void Parse_MultiHost_WithIPv6() + { + var codec = new UrlCodec(); + var res = codec.TryParse("postgresql://user@[::1]:5432,db.example.com:5433/db"); + Assert.True(res.IsSuccess); + var d = res.Value; + Assert.Equal(2, d.Hosts.Count); + Assert.Equal("::1", d.Hosts[0].Host); + Assert.Equal((ushort)5432, d.Hosts[0].Port); + Assert.Equal("db.example.com", d.Hosts[1].Host); + Assert.Equal((ushort)5433, d.Hosts[1].Port); + Assert.Equal("db", d.Database); + } + + [Fact] + public void Format_Basic_WithEncoding() + { + var codec = new UrlCodec(); + var d = new ConnectionDescriptor + { + Hosts = new [] { new HostEndpoint{ Host = "db.example.com", Port = 5432 } }, + Database = "prod db", + Username = "bob", + Password = "p@ss w?rd", + SslMode = SslMode.VerifyFull, + ApplicationName = "cli app", + TimeoutSeconds = 7, + Properties = new Dictionary{{"search_path","public"}} + }; + var res = codec.TryFormat(d); + Assert.True(res.IsSuccess); + var s = res.Value; + Assert.StartsWith("postgresql://", s); + Assert.Contains("bob:p%40ss%20w%3Frd@db.example.com:5432", s); + Assert.Contains("/prod%20db", s); + Assert.Contains("sslmode=verify-full", s); + Assert.Contains("application_name=cli%20app", s); + Assert.Contains("connect_timeout=7", s); + Assert.Contains("search_path=public", s); + } + + [Fact] + public void Roundtrip_ParseThenFormat() + { + var codec = new UrlCodec(); + var input = "postgresql://me:pa%3Ass@host1,host2:5433/postgres?application_name=my%20app&sslmode=prefer"; + var parsed = codec.TryParse(input); + Assert.True(parsed.IsSuccess); + var formatted = codec.TryFormat(parsed.Value); + Assert.True(formatted.IsSuccess); + var s = formatted.Value; + Assert.StartsWith("postgresql://me:pa%3Ass@host1,host2:5433/postgres?", s); + Assert.Contains("application_name=my%20app", s); + Assert.Contains("sslmode=prefer", s); + } +} diff --git a/pgLabII.PgUtils/ConnectionStrings/UrlCodec.cs b/pgLabII.PgUtils/ConnectionStrings/UrlCodec.cs new file mode 100644 index 0000000..3714d94 --- /dev/null +++ b/pgLabII.PgUtils/ConnectionStrings/UrlCodec.cs @@ -0,0 +1,354 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Net; +using System.Text; +using FluentResults; +using Npgsql; + +namespace pgLabII.PgUtils.ConnectionStrings; + +/// +/// Codec for PostgreSQL URL format: postgresql://[user[:password]@]host[:port][,hostN[:portN]]/[database]?param=value&... +/// - Supports multi-host by comma-separated host:port entries in authority part. +/// - Supports IPv6 literals in square brackets, e.g. [::1]:5432. +/// - Percent-decodes user, password, database, and query param values on parse. +/// - Percent-encodes user, password, database, and query values on format. +/// - Maps sslmode, application_name, connect_timeout into descriptor fields; others preserved in Properties. +/// +public sealed class UrlCodec : IConnectionStringCodec +{ + public ConnStringFormat Format => ConnStringFormat.Url; + public string FormatName => "URL"; + + public Result TryParse(string input) + { + try + { + if (string.IsNullOrWhiteSpace(input)) + return Result.Fail("Empty URL"); + + // Accept schemes postgresql:// or postgres:// + if (!input.StartsWith("postgresql://", StringComparison.OrdinalIgnoreCase) && + !input.StartsWith("postgres://", StringComparison.OrdinalIgnoreCase)) + return Result.Fail("URL must start with postgresql:// or postgres://"); + + // We cannot rely entirely on System.Uri for multi-host. We'll manually split scheme and the rest. + var span = input.AsSpan(); + int schemeSep = input.IndexOf("//", StringComparison.Ordinal); + if (schemeSep < 0) + return Result.Fail("Invalid URL: missing //"); + var scheme = input.Substring(0, schemeSep - 1); // includes ':' + var rest = input.Substring(schemeSep + 2); + + // Split off path and query/fragment + string authorityAndMaybeMore = rest; + string pathAndQuery = string.Empty; + int slashIdx = rest.IndexOf('/'); + if (slashIdx >= 0) + { + authorityAndMaybeMore = rest.Substring(0, slashIdx); + pathAndQuery = rest.Substring(slashIdx); // starts with '/' + } + + string userInfo = string.Empty; + string authority = authorityAndMaybeMore; + int atIdx = authorityAndMaybeMore.IndexOf('@'); + if (atIdx >= 0) + { + userInfo = authorityAndMaybeMore.Substring(0, atIdx); + authority = authorityAndMaybeMore.Substring(atIdx + 1); + } + + var builder = new ConnectionDescriptorBuilder(); + + // Parse userinfo + if (!string.IsNullOrEmpty(userInfo)) + { + var up = userInfo.Split(':', 2); + if (up.Length > 0 && up[0].Length > 0) + builder.Username = Uri.UnescapeDataString(up[0]); + if (up.Length > 1) + builder.Password = Uri.UnescapeDataString(up[1]); + } + + // Parse hosts (may be comma-separated) + foreach (var hostPart in SplitHosts(authority)) + { + if (string.IsNullOrWhiteSpace(hostPart)) continue; + ParseHostPort(hostPart, out var host, out ushort? port); + if (!string.IsNullOrEmpty(host)) + builder.AddHost(host!, port); + } + + // Parse path (database) and query + string database = string.Empty; + string query = string.Empty; + if (!string.IsNullOrEmpty(pathAndQuery)) + { + // pathAndQuery like /db?x=y + var qIdx = pathAndQuery.IndexOf('?'); + string path = qIdx >= 0 ? pathAndQuery.Substring(0, qIdx) : pathAndQuery; + query = qIdx >= 0 ? pathAndQuery.Substring(qIdx + 1) : string.Empty; + if (path.Length > 0) + { + // strip leading '/' + if (path[0] == '/') path = path.Substring(1); + if (path.Length > 0) + database = Uri.UnescapeDataString(path); + } + } + if (!string.IsNullOrEmpty(database)) builder.Database = database; + + var queryDict = ParseQuery(query); + + // Map known params + if (queryDict.TryGetValue("sslmode", out var sslVal)) + builder.SslMode = ParseSslMode(sslVal); + if (queryDict.TryGetValue("application_name", out var app)) + builder.ApplicationName = app; + if (queryDict.TryGetValue("connect_timeout", out var tout) && int.TryParse(tout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var ts)) + builder.TimeoutSeconds = ts; + + // Preserve extras + foreach (var (k, v) in queryDict) + { + if (!IsMappedQueryKey(k)) + builder.Properties[k] = v; + } + + return Result.Ok(builder.Build()); + } + catch (Exception ex) + { + return Result.Fail(ex.Message); + } + } + + public Result TryFormat(ConnectionDescriptor descriptor) + { + try + { + var sb = new StringBuilder(); + sb.Append("postgresql://"); + + // userinfo + if (!string.IsNullOrEmpty(descriptor.Username)) + { + sb.Append(Uri.EscapeDataString(descriptor.Username)); + if (!string.IsNullOrEmpty(descriptor.Password)) + { + sb.Append(':'); + sb.Append(Uri.EscapeDataString(descriptor.Password)); + } + sb.Append('@'); + } + + // hosts + if (descriptor.Hosts != null && descriptor.Hosts.Count > 0) + { + var hostParts = new List(descriptor.Hosts.Count); + foreach (var h in descriptor.Hosts) + { + var host = h.Host ?? string.Empty; + bool isIpv6 = host.Contains(':') && !host.StartsWith("[") && !host.EndsWith("]"); + if (isIpv6) + host = "[" + host + "]"; + if (h.Port.HasValue) + host += ":" + h.Port.Value.ToString(CultureInfo.InvariantCulture); + hostParts.Add(host); + } + sb.Append(string.Join(',', hostParts)); + } + + // path (database) + sb.Append('/'); + if (!string.IsNullOrEmpty(descriptor.Database)) + sb.Append(Uri.EscapeDataString(descriptor.Database)); + + // query + var queryPairs = new List(); + if (descriptor.SslMode.HasValue) + queryPairs.Add("sslmode=" + Uri.EscapeDataString(FormatSslMode(descriptor.SslMode.Value))); + if (!string.IsNullOrEmpty(descriptor.ApplicationName)) + queryPairs.Add("application_name=" + Uri.EscapeDataString(descriptor.ApplicationName)); + if (descriptor.TimeoutSeconds.HasValue) + queryPairs.Add("connect_timeout=" + Uri.EscapeDataString(descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture))); + + // Add extra properties not already emitted + var emitted = new HashSet(queryPairs.Select(p => p.Split('=')[0]), StringComparer.OrdinalIgnoreCase); + foreach (var kv in descriptor.Properties) + { + if (emitted.Contains(kv.Key)) continue; + queryPairs.Add(Uri.EscapeDataString(kv.Key) + "=" + Uri.EscapeDataString(kv.Value ?? string.Empty)); + } + + if (queryPairs.Count > 0) + { + sb.Append('?'); + sb.Append(string.Join('&', queryPairs)); + } + + return Result.Ok(sb.ToString()); + } + catch (Exception ex) + { + return Result.Fail(ex.Message); + } + } + + private static bool IsMappedQueryKey(string key) + => key.Equals("sslmode", StringComparison.OrdinalIgnoreCase) + || key.Equals("application_name", StringComparison.OrdinalIgnoreCase) + || key.Equals("connect_timeout", StringComparison.OrdinalIgnoreCase); + + private static IEnumerable SplitHosts(string authority) + { + // authority may contain comma-separated hosts, each may be IPv6 [..] with optional :port + // We split on commas that are not inside brackets + var parts = new List(); + int depth = 0; + int start = 0; + for (int i = 0; i < authority.Length; i++) + { + char c = authority[i]; + if (c == '[') depth++; + else if (c == ']') depth = Math.Max(0, depth - 1); + else if (c == ',' && depth == 0) + { + parts.Add(authority.Substring(start, i - start)); + start = i + 1; + } + } + // last + if (start <= authority.Length) + parts.Add(authority.Substring(start)); + return parts.Select(p => p.Trim()).Where(p => p.Length > 0); + } + + private static void ParseHostPort(string hostPart, out string host, out ushort? port) + { + host = string.Empty; port = null; + if (string.IsNullOrWhiteSpace(hostPart)) return; + + if (hostPart[0] == '[') + { + // IPv6 literal [....]:port? + int end = hostPart.IndexOf(']'); + if (end < 0) + { + host = hostPart; // let it pass raw + return; + } + var h = hostPart.Substring(1, end - 1); + host = h; + if (end + 1 < hostPart.Length && hostPart[end + 1] == ':') + { + var ps = hostPart.Substring(end + 2); + if (ushort.TryParse(ps, NumberStyles.Integer, CultureInfo.InvariantCulture, out var up)) + port = up; + } + return; + } + // non-IPv6, split last ':' as port if numeric + int colon = hostPart.LastIndexOf(':'); + if (colon > 0 && colon < hostPart.Length - 1) + { + var ps = hostPart.Substring(colon + 1); + if (ushort.TryParse(ps, NumberStyles.Integer, CultureInfo.InvariantCulture, out var up)) + { + port = up; + host = hostPart.Substring(0, colon); + return; + } + } + host = hostPart; + } + + private static SslMode ParseSslMode(string s) + { + switch (s.Trim().ToLowerInvariant()) + { + case "disable": return SslMode.Disable; + case "allow": return SslMode.Allow; + case "prefer": return SslMode.Prefer; + case "require": return SslMode.Require; + case "verify-ca": + case "verifyca": return SslMode.VerifyCA; + case "verify-full": + case "verifyfull": return SslMode.VerifyFull; + default: throw new ArgumentException($"Not a valid sslmode: {s}"); + } + } + + private static string FormatSslMode(SslMode mode) + { + return mode switch + { + SslMode.Disable => "disable", + SslMode.Allow => "allow", + SslMode.Prefer => "prefer", + SslMode.Require => "require", + SslMode.VerifyCA => "verify-ca", + SslMode.VerifyFull => "verify-full", + _ => "prefer" + }; + } + + private sealed class ConnectionDescriptorBuilder + { + public List Hosts { get; } = new(); + public string? Database { get; set; } + public string? Username { get; set; } + public string? Password { get; set; } + public SslMode? SslMode { get; set; } + public string? ApplicationName { get; set; } + public int? TimeoutSeconds { get; set; } + public Dictionary Properties { get; } = new(StringComparer.OrdinalIgnoreCase); + + public void AddHost(string host, ushort? port) + { + if (string.IsNullOrWhiteSpace(host)) return; + Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port }); + } + + public ConnectionDescriptor Build() + { + return new ConnectionDescriptor + { + Hosts = Hosts, + Database = Database, + Username = Username, + Password = Password, + SslMode = SslMode, + ApplicationName = ApplicationName, + TimeoutSeconds = TimeoutSeconds, + Properties = Properties + }; + } + } + + private static Dictionary ParseQuery(string query) + { + var dict = new Dictionary(StringComparer.OrdinalIgnoreCase); + if (string.IsNullOrEmpty(query)) return dict; + var pairs = query.Split('&', StringSplitOptions.RemoveEmptyEntries); + foreach (var pair in pairs) + { + var idx = pair.IndexOf('='); + if (idx < 0) + { + var k = Uri.UnescapeDataString(pair); + dict[k] = string.Empty; + } + else + { + var k = Uri.UnescapeDataString(pair.Substring(0, idx)); + var v = Uri.UnescapeDataString(pair.Substring(idx + 1)); + dict[k] = v; + } + } + return dict; + } +}