using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using FluentResults;
using Npgsql;

namespace pgLabII.PgUtils.ConnectionStrings;

/// <summary>
/// Converts between libpq keyword/value connection strings
/// (e.g. <c>host=localhost port=5432 dbname=app</c>) and <see cref="ConnectionDescriptor"/>.
/// </summary>
public sealed class LibpqCodec : IConnectionStringCodec
{
    // Keywords that map onto dedicated ConnectionDescriptor members. Every other keyword
    // is copied verbatim into ConnectionDescriptor.Properties when parsing.
    private static readonly HashSet<string> MappedKeywords = new(StringComparer.OrdinalIgnoreCase)
    {
        "host", "hostaddr", "port", "dbname", "user", "username",
        "password", "sslmode", "application_name", "connect_timeout",
    };

    public ConnStringFormat Format => ConnStringFormat.Libpq;

    public string FormatName => "libpq";

    /// <summary>
    /// Parses a libpq keyword/value connection string into a <see cref="ConnectionDescriptor"/>.
    /// </summary>
    /// <param name="input">The connection string to parse.</param>
    /// <returns>
    /// A successful result holding the descriptor, or a failed result when the input contains a
    /// semicolon outside a quoted value (Npgsql-style separator) or any exception is raised while
    /// parsing (for example an unknown <c>sslmode</c>).
    /// </returns>
    public Result<ConnectionDescriptor> TryParse(string input)
    {
        try
        {
            // Npgsql separates keywords with ';' while libpq uses whitespace. Only semicolons
            // outside quoted values are rejected, so a quoted value such as password='a;b'
            // (which TryFormat itself produces) still parses.
            if (ContainsUnquotedSemicolon(input))
            {
                return Result.Fail<ConnectionDescriptor>("Semicolons are not valid separators in libpq connection strings");
            }

            var kv = new PqConnectionStringParser(new PqConnectionStringTokenizer(input)).Parse();

            // libpq keywords are case-insensitive; a repeated keyword keeps the last value seen.
            var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
            foreach (var pair in kv)
            {
                dict[pair.Key] = pair.Value;
            }

            var descriptor = new ConnectionDescriptorBuilder();

            dict.TryGetValue("port", out var portSpec);
            dict.TryGetValue("hostaddr", out var hostaddr);
            if (dict.TryGetValue("host", out var host) && !string.IsNullOrWhiteSpace(host))
            {
                AddHosts(descriptor, host, portSpec);

                // With both keywords present, host names the server and hostaddr supplies its
                // address. Keep hostaddr as an extra instead of turning it into additional hosts.
                if (!string.IsNullOrWhiteSpace(hostaddr))
                {
                    descriptor.Properties["hostaddr"] = hostaddr;
                }
            }
            else if (!string.IsNullOrWhiteSpace(hostaddr))
            {
                // hostaddr without host: use the addresses as the host list.
                AddHosts(descriptor, hostaddr, portSpec);
            }

            if (dict.TryGetValue("dbname", out var db))
            {
                descriptor.Database = db;
            }

            if (dict.TryGetValue("user", out var user))
            {
                descriptor.Username = user;
            }
            else if (dict.TryGetValue("username", out var username))
            {
                descriptor.Username = username;
            }

            if (dict.TryGetValue("password", out var pass))
            {
                descriptor.Password = pass;
            }

            if (dict.TryGetValue("sslmode", out var sslStr))
            {
                descriptor.SslMode = ParseSslMode(sslStr);
            }

            if (dict.TryGetValue("application_name", out var app))
            {
                descriptor.ApplicationName = app;
            }

            if (dict.TryGetValue("connect_timeout", out var tout)
                && int.TryParse(tout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var seconds))
            {
                descriptor.TimeoutSeconds = seconds;
            }

            // Everything not mapped above is carried through unchanged.
            foreach (var (k, v) in dict)
            {
                if (!MappedKeywords.Contains(k))
                {
                    descriptor.Properties[k] = v;
                }
            }

            return Result.Ok(descriptor.Build());
        }
        catch (Exception ex)
        {
            // Any exception (e.g. the ArgumentException from ParseSslMode) becomes a failed result.
            return Result.Fail<ConnectionDescriptor>(ex.Message);
        }
    }

    /// <summary>
    /// Formats a <see cref="ConnectionDescriptor"/> as a space-separated libpq keyword/value string.
    /// </summary>
    /// <param name="descriptor">The descriptor to format.</param>
    /// <returns>
    /// A successful result holding the connection string, or a failed result carrying the
    /// message of any exception raised while formatting.
    /// </returns>
    public Result<string> TryFormat(ConnectionDescriptor descriptor)
    {
        try
        {
            var parts = new List<string>();
            var emitted = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            // Records the keyword so a same-named entry in Properties is not emitted twice.
            void Emit(string key, string? value)
            {
                parts.Add(FormatPair(key, value));
                emitted.Add(key);
            }

            if (descriptor.Hosts is { Count: > 0 } hosts)
            {
                Emit("host", string.Join(',', hosts.Select(h => h.Host)));

                var portSpec = FormatPortSpec(
                    hosts.Select(h => h.Port?.ToString(CultureInfo.InvariantCulture)).ToList());
                if (portSpec is not null)
                {
                    Emit("port", portSpec);
                }
            }

            if (!string.IsNullOrEmpty(descriptor.Database))
            {
                Emit("dbname", descriptor.Database);
            }

            if (!string.IsNullOrEmpty(descriptor.Username))
            {
                Emit("user", descriptor.Username);
            }

            if (!string.IsNullOrEmpty(descriptor.Password))
            {
                Emit("password", descriptor.Password);
            }

            if (descriptor.SslMode.HasValue)
            {
                Emit("sslmode", FormatSslMode(descriptor.SslMode.Value));
            }

            if (!string.IsNullOrEmpty(descriptor.ApplicationName))
            {
                Emit("application_name", descriptor.ApplicationName);
            }

            if (descriptor.TimeoutSeconds.HasValue)
            {
                Emit("connect_timeout", descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture));
            }

            // Extra properties, skipping any keyword already written above.
            foreach (var kv in descriptor.Properties)
            {
                if (!emitted.Contains(kv.Key))
                {
                    parts.Add(FormatPair(kv.Key, kv.Value));
                }
            }

            return Result.Ok(string.Join(' ', parts));
        }
        catch (Exception ex)
        {
            return Result.Fail<string>(ex.Message);
        }
    }

    /// <summary>
    /// Returns true when <paramref name="input"/> contains a ';' outside a single-quoted value.
    /// </summary>
    /// <remarks>
    /// The character after a backslash is skipped (libpq escape). Quote tracking is a simple
    /// toggle on each unescaped single quote.
    /// </remarks>
    private static bool ContainsUnquotedSemicolon(string input)
    {
        var inQuotes = false;
        for (var i = 0; i < input.Length; i++)
        {
            var c = input[i];
            if (c == '\\')
            {
                i++;
                continue;
            }

            if (c == '\'')
            {
                inQuotes = !inQuotes;
            }
            else if (c == ';' && !inQuotes)
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Splits a comma-separated host spec, pairs each entry with its port and adds it to
    /// <paramref name="builder"/>. Empty entries are skipped by
    /// <see cref="ConnectionDescriptorBuilder.AddHost"/>, but they still occupy a position so
    /// that an aligned port list stays in step.
    /// </summary>
    private static void AddHosts(ConnectionDescriptorBuilder builder, string hostSpec, string? portSpec)
    {
        var hosts = hostSpec.Split(',', StringSplitOptions.TrimEntries);
        var ports = ResolvePorts(portSpec, hosts.Length);
        for (var i = 0; i < hosts.Length; i++)
        {
            builder.AddHost(hosts[i], ports[i]);
        }
    }

    /// <summary>
    /// Resolves a libpq <c>port</c> value against <paramref name="hostCount"/> hosts.
    /// </summary>
    /// <remarks>
    /// A single port applies to every host. A list with one entry per host maps by position,
    /// and an empty or unparseable entry leaves that host's port unset (null). A list whose
    /// length differs from the host count is ignored, so every port stays unset.
    /// </remarks>
    private static ushort?[] ResolvePorts(string? portSpec, int hostCount)
    {
        var result = new ushort?[hostCount];
        if (hostCount == 0 || string.IsNullOrWhiteSpace(portSpec))
        {
            return result;
        }

        var items = portSpec.Split(',', StringSplitOptions.TrimEntries);
        if (items.Length == 1)
        {
            Array.Fill(result, ParsePort(items[0]));
        }
        else if (items.Length == hostCount)
        {
            for (var i = 0; i < hostCount; i++)
            {
                result[i] = ParsePort(items[i]);
            }
        }

        return result;
    }

    /// <summary>Parses one port number, returning null when it is empty or not a valid ushort.</summary>
    private static ushort? ParsePort(string text) =>
        ushort.TryParse(text, NumberStyles.Integer, CultureInfo.InvariantCulture, out var port)
            ? port
            : (ushort?)null;

    /// <summary>
    /// Builds the <c>port</c> value for a host list, given each host's port as text (null when
    /// unset).
    /// </summary>
    /// <returns>
    /// Null when no host has a port. A single port when every host shares the same one.
    /// Otherwise a comma-separated list aligned with the hosts, where an empty element means
    /// the default port.
    /// </returns>
    private static string? FormatPortSpec(IReadOnlyList<string?> ports)
    {
        if (ports.All(p => p is null))
        {
            return null;
        }

        if (ports.All(p => p is not null) && ports.Distinct(StringComparer.Ordinal).Count() == 1)
        {
            return ports[0];
        }

        return string.Join(',', ports.Select(p => p ?? string.Empty));
    }

    private static SslMode ParseSslMode(string s)
    {
        return s.Trim().ToLowerInvariant() switch
        {
            "disable" => SslMode.Disable,
            "allow" => SslMode.Allow,
            "prefer" => SslMode.Prefer,
            "require" => SslMode.Require,
            "verify-ca" => SslMode.VerifyCA,
            "verify-full" => SslMode.VerifyFull,
            _ => throw new ArgumentException($"Not a valid SSL mode: {s}")
        };
    }

    private static string FormatSslMode(SslMode mode)
    {
        return mode switch
        {
            SslMode.Disable => "disable",
            SslMode.Allow => "allow",
            SslMode.Prefer => "prefer",
            SslMode.Require => "require",
            SslMode.VerifyCA => "verify-ca",
            SslMode.VerifyFull => "verify-full",
            _ => "prefer"
        };
    }

    /// <summary>Formats <c>key=value</c>, single-quoting and escaping the value when required.</summary>
    private static string FormatPair(string key, string? value)
    {
        value ??= string.Empty;
        if (NeedsQuoting(value))
        {
            return key + "='" + EscapeValue(value) + "'";
        }

        return key + "=" + value;
    }

    /// <summary>
    /// True for empty values and for values containing whitespace, '=', a quote, a backslash
    /// or ';'. ';' is quoted so the output passes this codec's own unquoted-semicolon check
    /// in <see cref="TryParse"/>.
    /// </summary>
    private static bool NeedsQuoting(string value)
    {
        if (value.Length == 0)
        {
            return true;
        }

        foreach (var c in value)
        {
            if (char.IsWhiteSpace(c) || c == '=' || c == '\'' || c == '\\' || c == ';')
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>Backslash-escapes single quotes and backslashes for use inside a quoted value.</summary>
    private static string EscapeValue(string value)
    {
        var sb = new StringBuilder(value.Length);
        foreach (var c in value)
        {
            if (c == '\'' || c == '\\')
            {
                sb.Append('\\');
            }

            sb.Append(c);
        }

        return sb.ToString();
    }

    /// <summary>Mutable accumulator used while parsing; <see cref="Build"/> produces the descriptor.</summary>
    private sealed class ConnectionDescriptorBuilder
    {
        public List<HostEndpoint> Hosts { get; } = new();
        public string? Database { get; set; }
        public string? Username { get; set; }
        public string? Password { get; set; }
        public SslMode? SslMode { get; set; }
        public string? ApplicationName { get; set; }
        public int? TimeoutSeconds { get; set; }
        public Dictionary<string, string> Properties { get; } = new(StringComparer.OrdinalIgnoreCase);

        /// <summary>Adds a trimmed host entry; null, empty or whitespace-only hosts are ignored.</summary>
        public void AddHost(string host, ushort? port)
        {
            if (string.IsNullOrWhiteSpace(host))
            {
                return;
            }

            Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port });
        }

        public ConnectionDescriptor Build()
        {
            return new ConnectionDescriptor
            {
                Hosts = Hosts,
                Database = Database,
                Username = Username,
                Password = Password,
                SslMode = SslMode,
                ApplicationName = ApplicationName,
                TimeoutSeconds = TimeoutSeconds,
                Properties = Properties
            };
        }
    }
}