using System; using System.Collections.Generic; using System.Globalization; using System.Linq; using System.Text; using FluentResults; using Npgsql; namespace pgLabII.PgUtils.ConnectionStrings; /// /// Parser/formatter for Npgsql-style .NET connection strings. We intentionally do not /// rely on NpgsqlConnectionStringBuilder here because: /// - We need a lossless, format-agnostic round-trip to our ConnectionDescriptor, including /// unknown/extension keys and per-host port lists. NpgsqlConnectionStringBuilder normalizes /// names, may drop unknown keys or coerce values, which breaks lossless conversions. /// - We support multi-host with per-host ports and want to preserve the original textual /// representation across conversions. The builder flattens/rewrites these details. /// - We aim to keep pgLabII.PgUtils independent from Npgsql's evolving parsing rules and /// version-specific behaviors to ensure stable UX and deterministic tests. /// - We need symmetric formatting matching our other codecs (libpq/URL/JDBC) and consistent /// quoting rules across formats. /// If required, we still reference Npgsql for enums and interop types, but parsing/formatting /// is done by this small, well-tested custom codec for full control and stability. 
/// </summary>
public sealed class NpgsqlCodec : IConnectionStringCodec
{
    /// <summary>
    /// Every key (including aliases) that <see cref="TryParse"/> maps onto a dedicated
    /// <see cref="ConnectionDescriptor"/> member. Keys outside this set are preserved verbatim
    /// in the descriptor's <c>Properties</c>.
    /// </summary>
    private static readonly HashSet<string> MappedKeys = new(StringComparer.OrdinalIgnoreCase)
    {
        "Host", "Server", "Servers", "Port",
        "Database", "Db", "Initial Catalog", "dbname",
        "Username", "User ID", "User", "UID",
        "Password", "PWD",
        "Application Name", "ApplicationName",
        "Timeout", "Connect Timeout", "Connection Timeout",
        "SSL Mode", "SslMode", "SSLMode",
    };

    /// <inheritdoc />
    public ConnStringFormat Format => ConnStringFormat.Npgsql;

    /// <inheritdoc />
    public string FormatName => "Npgsql";

    /// <summary>
    /// Parses an Npgsql/.NET style connection string (<c>Key=Value;Key2="quoted value"</c>)
    /// into a <see cref="ConnectionDescriptor"/>.
    /// </summary>
    /// <param name="input">The connection string to parse.</param>
    /// <returns>
    /// A successful result holding the descriptor, or a failed result when <paramref name="input"/>
    /// is null or a value cannot be interpreted (e.g. an unknown SSL mode).
    /// </returns>
    /// <remarks>
    /// Hosts may carry inline ports (<c>host:5432</c>, <c>[::1]:5432</c>). A <c>Port</c> key fills in
    /// hosts that have no inline port, either as one port for all of them or as a list matched 1:1.
    /// A <c>Port</c> value that cannot be applied that way, and every unrecognized key, is kept in
    /// <c>Properties</c> so the conversion stays lossless.
    /// </remarks>
    public Result<ConnectionDescriptor> TryParse(string input)
    {
        if (input is null)
            return Result.Fail("Connection string must not be null.");

        try
        {
            var dict = Tokenize(input);
            var descriptor = new ConnectionDescriptorBuilder();

            bool portConsumed = ParseHosts(dict, descriptor);

            // Standard fields: the first alias present wins.
            if (TryGetFirst(dict, out var db, "Database", "Db", "Initial Catalog", "dbname"))
                descriptor.Database = db;
            if (TryGetFirst(dict, out var user, "Username", "User ID", "User", "UID"))
                descriptor.Username = user;
            if (TryGetFirst(dict, out var pass, "Password", "PWD"))
                descriptor.Password = pass;
            if (TryGetFirst(dict, out var app, "Application Name", "ApplicationName"))
                descriptor.ApplicationName = app;
            if (TryGetFirst(dict, out var timeout, "Timeout", "Connect Timeout", "Connection Timeout")
                && int.TryParse(timeout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var seconds))
                descriptor.TimeoutSeconds = seconds;
            if (TryGetFirst(dict, out var ssl, "SSL Mode", "SslMode", "SSLMode"))
                descriptor.SslMode = ParseSslMode(ssl);

            // Preserve every key we do not map onto a descriptor member.
            foreach (var (key, value) in dict)
            {
                if (!MappedKeys.Contains(key))
                    descriptor.Properties[key] = value;
            }

            // A Port that could not be attached to any host (no Host key, unparsable value, or a
            // port list whose length does not match the host list) would otherwise vanish.
            if (!portConsumed && dict.TryGetValue("Port", out var unusedPort))
                descriptor.Properties["Port"] = unusedPort;

            return Result.Ok(descriptor.Build());
        }
        catch (Exception ex)
        {
            return Result.Fail(ex.Message);
        }
    }

    /// <summary>
    /// Formats a <see cref="ConnectionDescriptor"/> as an Npgsql/.NET connection string.
    /// </summary>
    /// <param name="descriptor">The descriptor to format.</param>
    /// <returns>
    /// A successful result holding the semicolon-separated connection string, or a failed result
    /// carrying the exception message if formatting throws.
    /// </returns>
    /// <remarks>
    /// Ports are emitted as a single <c>Port</c> when all hosts share one, as a 1:1 <c>Port</c> list
    /// when every host has a port, and inline (<c>host:port</c>) when only some hosts have one, so
    /// that <see cref="TryParse"/> reconstructs the same endpoints. Properties whose key was already
    /// emitted are skipped.
    /// </remarks>
    public Result<string> TryFormat(ConnectionDescriptor descriptor)
    {
        try
        {
            var parts = new List<string>();
            var emittedKeys = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            void Emit(string key, string? value)
            {
                parts.Add(FormatPair(key, value));
                emittedKeys.Add(key);
            }

            if (descriptor.Hosts is { Count: > 0 })
            {
                Emit("Host", FormatHosts(descriptor.Hosts, out var portValue));
                if (portValue is not null)
                    Emit("Port", portValue);
            }

            if (!string.IsNullOrEmpty(descriptor.Database))
                Emit("Database", descriptor.Database);
            if (!string.IsNullOrEmpty(descriptor.Username))
                Emit("Username", descriptor.Username);
            if (!string.IsNullOrEmpty(descriptor.Password))
                Emit("Password", descriptor.Password);
            if (descriptor.SslMode.HasValue)
                Emit("SSL Mode", FormatSslMode(descriptor.SslMode.Value));
            if (!string.IsNullOrEmpty(descriptor.ApplicationName))
                Emit("Application Name", descriptor.ApplicationName);
            if (descriptor.TimeoutSeconds.HasValue)
                Emit("Timeout", descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture));

            foreach (var kv in descriptor.Properties)
            {
                if (!emittedKeys.Contains(kv.Key))
                    Emit(kv.Key, kv.Value);
            }

            return Result.Ok(string.Join(";", parts));
        }
        catch (Exception ex)
        {
            return Result.Fail(ex.Message);
        }
    }

    // Reads Host/Server/Servers and Port into the builder's host list.
    // Returns true when a Port key was present and applied to the host list.
    private static bool ParseHosts(Dictionary<string, string> dict, ConnectionDescriptorBuilder descriptor)
    {
        if (!TryGetFirst(dict, out var hostVal, "Host", "Server", "Servers"))
            return false;

        // First, extract inline ports from each host entry (e.g. host:5432 or [::1]:5432).
        var hosts = new List<string>();
        var portsPerHost = new List<ushort?>();
        foreach (var raw in SplitList(hostVal))
        {
            ParseHostPort(raw, out var hostOnly, out var inlinePort);
            hosts.Add(hostOnly);
            portsPerHost.Add(inlinePort);
        }

        // Then merge the Port key: a single port applies to every host still missing one; a list
        // applies 1:1. Inline ports take precedence.
        bool portConsumed = false;
        if (hosts.Count > 0 && dict.TryGetValue("Port", out var portVal))
        {
            var ports = SplitList(portVal).ToList();
            if (ports.Count == 1 && TryParsePort(ports[0], out var singlePort))
            {
                for (int i = 0; i < portsPerHost.Count; i++)
                    portsPerHost[i] ??= singlePort;
                portConsumed = true;
            }
            else if (ports.Count == hosts.Count)
            {
                for (int i = 0; i < ports.Count; i++)
                {
                    if (!portsPerHost[i].HasValue && TryParsePort(ports[i], out var listedPort))
                        portsPerHost[i] = listedPort;
                }
                portConsumed = ports.TrueForAll(entry => TryParsePort(entry, out _));
            }
        }

        for (int i = 0; i < hosts.Count; i++)
            descriptor.AddHost(hosts[i], portsPerHost[i]);

        return portConsumed;
    }

    // Builds the Host value; sets portValue to the Port value to emit, or null when ports are inline
    // or absent.
    private static string FormatHosts(IEnumerable<HostEndpoint> endpoints, out string? portValue)
    {
        var hosts = endpoints.ToList();
        portValue = null;

        int withPort = hosts.Count(h => h.Port.HasValue);
        if (withPort == 0)
            return string.Join(',', hosts.Select(h => h.Host));

        if (withPort == hosts.Count)
        {
            var portTexts = hosts.Select(h => h.Port!.Value.ToString(CultureInfo.InvariantCulture)).ToList();
            portValue = portTexts.Distinct().Count() == 1 ? portTexts[0] : string.Join(',', portTexts);
            return string.Join(',', hosts.Select(h => h.Host));
        }

        // Mixed: a shared Port key would assign a port to hosts that had none (or drop ports), so
        // attach each known port inline instead.
        return string.Join(',', hosts.Select(h => h.Port.HasValue
            ? FormatHostWithPort(h.Host, h.Port.Value.ToString(CultureInfo.InvariantCulture))
            : h.Host));
    }

    // IPv6 literals must be bracketed when followed by a port.
    private static string FormatHostWithPort(string host, string port)
        => host.Contains(':') && !host.StartsWith('[') ? $"[{host}]:{port}" : $"{host}:{port}";

    private static IEnumerable<string> SplitList(string s)
    {
        return s.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
    }

    private static bool TryParsePort(string text, out ushort port)
        => ushort.TryParse(text, NumberStyles.Integer, CultureInfo.InvariantCulture, out port);

    // Splits "host", "host:port", "[ipv6]" or "[ipv6]:port" into host and optional port.
    private static void ParseHostPort(string hostPart, out string host, out ushort? port)
    {
        host = hostPart;
        port = null;
        if (string.IsNullOrWhiteSpace(hostPart))
            return;

        // IPv6 in brackets: [::1]:5432
        if (hostPart[0] == '[')
        {
            int end = hostPart.IndexOf(']');
            if (end > 0)
            {
                host = hostPart.Substring(1, end - 1);
                if (end + 1 < hostPart.Length && hostPart[end + 1] == ':'
                    && TryParsePort(hostPart.Substring(end + 2), out var bracketPort))
                    port = bracketPort;
            }
            return;
        }

        // host:port only when there is exactly one ':'. More than one means an unbracketed IPv6
        // literal (e.g. ::1 or fe80::1), which cannot carry a port.
        int colon = hostPart.IndexOf(':');
        if (colon > 0 && colon < hostPart.Length - 1 && colon == hostPart.LastIndexOf(':')
            && TryParsePort(hostPart.Substring(colon + 1), out var plainPort))
        {
            host = hostPart.Substring(0, colon);
            port = plainPort;
        }
    }

    // Looks up the first of several alias keys that is present.
    private static bool TryGetFirst(Dictionary<string, string> dict, out string value, params string[] keys)
    {
        foreach (var k in keys)
        {
            if (dict.TryGetValue(k, out value))
                return true;
        }
        value = string.Empty;
        return false;
    }

    // Accepts both libpq-style (verify-ca) and enum-style (VerifyCA) spellings, case-insensitively.
    private static SslMode ParseSslMode(string s) => s.Trim().ToLowerInvariant() switch
    {
        "disable" => SslMode.Disable,
        "allow" => SslMode.Allow,
        "prefer" => SslMode.Prefer,
        "require" => SslMode.Require,
        "verify-ca" or "verifyca" => SslMode.VerifyCA,
        "verify-full" or "verifyfull" => SslMode.VerifyFull,
        _ => throw new ArgumentException($"Not a valid SSL Mode: {s}"),
    };

    private static string FormatSslMode(SslMode mode)
    {
        return mode switch
        {
            SslMode.Disable => "Disable",
            SslMode.Allow => "Allow",
            SslMode.Prefer => "Prefer",
            SslMode.Require => "Require",
            SslMode.VerifyCA => "VerifyCA",
            SslMode.VerifyFull => "VerifyFull",
            _ => "Prefer"
        };
    }

    // Npgsql/.NET connection string grammar: semicolon-separated key=value; values with special
    // chars are wrapped in double quotes, internal quotes doubled.
    private static string FormatPair(string key, string? value)
    {
        value ??= string.Empty;
        if (!NeedsQuoting(value))
            return key + "=" + value;
        return key + "=\"" + EscapeQuoted(value) + "\"";
    }

    private static bool NeedsQuoting(string value)
    {
        if (value.Length == 0)
            return true;
        // A leading single quote would otherwise be read back as the start of a quoted value.
        if (value[0] == '\'')
            return true;
        foreach (var c in value)
        {
            if (char.IsWhiteSpace(c) || c == ';' || c == '=' || c == '"')
                return true;
        }
        return false;
    }

    private static string EscapeQuoted(string value)
    {
        // Double the quotes per standard DbConnectionString rules.
        return value.Replace("\"", "\"\"");
    }

    // Tokenizes "key=value" pairs separated by ';'. Values may be wrapped in double or single quotes,
    // with the quote character doubled to embed it. Keys are case-insensitive; the last occurrence
    // wins. Segments without '=' and pairs with an empty key are ignored.
    private static Dictionary<string, string> Tokenize(string input)
    {
        var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        int i = 0;

        while (i < input.Length)
        {
            // Skip whitespace and empty segments such as ";;".
            while (i < input.Length && (char.IsWhiteSpace(input[i]) || input[i] == ';'))
                i++;
            if (i >= input.Length)
                break;

            // Read the key up to '='. Hitting ';' first means a segment with no value: skip it
            // rather than letting the key swallow the next pair.
            int keyStart = i;
            while (i < input.Length && input[i] != '=' && input[i] != ';')
                i++;
            if (i >= input.Length || input[i] == ';')
                continue;

            var key = input.Substring(keyStart, i - keyStart).Trim();
            i++; // skip '='
            while (i < input.Length && char.IsWhiteSpace(input[i]))
                i++;

            string value;
            if (i < input.Length && (input[i] == '"' || input[i] == '\''))
            {
                char quote = input[i++];
                var sb = new StringBuilder();
                while (i < input.Length)
                {
                    char c = input[i++];
                    if (c != quote)
                    {
                        sb.Append(c);
                        continue;
                    }
                    if (i < input.Length && input[i] == quote)
                    {
                        // Doubled quote -> literal quote character.
                        sb.Append(quote);
                        i++;
                        continue;
                    }
                    break; // closing quote
                }
                value = sb.ToString();

                // Ignore anything between the closing quote and the next ';'.
                while (i < input.Length && input[i] != ';')
                    i++;
            }
            else
            {
                int valStart = i;
                while (i < input.Length && input[i] != ';')
                    i++;
                value = input.Substring(valStart, i - valStart).Trim();
            }

            if (key.Length > 0)
                dict[key] = value;
        }

        return dict;
    }

    // Mutable accumulator used while parsing; Build() produces the ConnectionDescriptor.
    private sealed class ConnectionDescriptorBuilder
    {
        public List<HostEndpoint> Hosts { get; } = new();
        public string? Database { get; set; }
        public string? Username { get; set; }
        public string? Password { get; set; }
        public SslMode? SslMode { get; set; }
        public string? ApplicationName { get; set; }
        public int? TimeoutSeconds { get; set; }
        public Dictionary<string, string> Properties { get; } = new(StringComparer.OrdinalIgnoreCase);

        // Blank hosts are ignored; non-blank hosts are trimmed.
        public void AddHost(string host, ushort? port)
        {
            if (string.IsNullOrWhiteSpace(host))
                return;
            Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port });
        }

        public ConnectionDescriptor Build()
        {
            return new ConnectionDescriptor
            {
                Hosts = Hosts,
                Database = Database,
                Username = Username,
                Password = Password,
                SslMode = SslMode,
                ApplicationName = ApplicationName,
                TimeoutSeconds = TimeoutSeconds,
                Properties = Properties
            };
        }
    }
}