diff --git a/pgLabII.PgUtils.Tests/ConnectionStrings/NpgsqlCodecTests.cs b/pgLabII.PgUtils.Tests/ConnectionStrings/NpgsqlCodecTests.cs new file mode 100644 index 0000000..5f29a89 --- /dev/null +++ b/pgLabII.PgUtils.Tests/ConnectionStrings/NpgsqlCodecTests.cs @@ -0,0 +1,87 @@ +using System.Collections.Generic; +using Npgsql; +using pgLabII.PgUtils.ConnectionStrings; + +namespace pgLabII.PgUtils.Tests.ConnectionStrings; + +public class NpgsqlCodecTests +{ + [Fact] + public void Parse_Basic() + { + var codec = new NpgsqlCodec(); + var res = codec.TryParse("Host=localhost;Port=5434;Database=testdb;Username=alice;Password=secret;SSL Mode=Require;Application Name=pgLab;Timeout=10"); + Assert.True(res.IsSuccess); + var d = res.Value; + Assert.Single(d.Hosts); + Assert.Equal("localhost", d.Hosts[0].Host); + Assert.Equal((ushort)5434, d.Hosts[0].Port); + Assert.Equal("testdb", d.Database); + Assert.Equal("alice", d.Username); + Assert.Equal("secret", d.Password); + Assert.Equal(SslMode.Require, d.SslMode); + Assert.Equal("pgLab", d.ApplicationName); + Assert.Equal(10, d.TimeoutSeconds); + } + + [Fact] + public void Parse_MultiHost_WithSinglePort() + { + var codec = new NpgsqlCodec(); + var res = codec.TryParse("Host=host1,host2;Port=5433;Database=db;Username=u"); + Assert.True(res.IsSuccess); + var d = res.Value; + Assert.Equal(2, d.Hosts.Count); + Assert.Equal("host1", d.Hosts[0].Host); + Assert.Equal((ushort)5433, d.Hosts[0].Port); + Assert.Equal("host2", d.Hosts[1].Host); + Assert.Equal((ushort)5433, d.Hosts[1].Port); + } + + [Fact] + public void Format_Basic_WithQuoting() + { + var codec = new NpgsqlCodec(); + var d = new ConnectionDescriptor + { + Hosts = new [] { new HostEndpoint{ Host = "db.example.com", Port = 5432 } }, + Database = "prod db", + Username = "bob", + Password = "p;ss\"word", + SslMode = SslMode.VerifyFull, + ApplicationName = "cli app", + TimeoutSeconds = 9, + Properties = new Dictionary{{"Search Path","public"}} + }; + var res = codec.TryFormat(d); + Assert.True(res.IsSuccess); + var s = res.Value; + Assert.Contains("Host=db.example.com", s); + Assert.Contains("Port=5432", s); + Assert.Contains("Database=\"prod db\"", s); + Assert.Contains("Username=bob", s); + Assert.Contains("Password=\"p;ss\"\"word\"", s); + Assert.Contains("SSL Mode=VerifyFull", s); + Assert.Contains("Application Name=\"cli app\"", s); + Assert.Contains("Timeout=9", s); + Assert.Contains("Search Path=public", s); + } + + [Fact] + public void Roundtrip_ParseThenFormat() + { + var codec = new NpgsqlCodec(); + var input = "Host=\"my host\";Database=postgres;Username=me;Password=\"with;quote\"\"\";Application Name=\"my app\";SSL Mode=Prefer"; + var parsed = codec.TryParse(input); + Assert.True(parsed.IsSuccess); + var formatted = codec.TryFormat(parsed.Value); + Assert.True(formatted.IsSuccess); + var s = formatted.Value; + Assert.Contains("Host=\"my host\"", s); + Assert.Contains("Database=postgres", s); + Assert.Contains("Username=me", s); + Assert.Contains("Password=\"with;quote\"\"\"", s); + Assert.Contains("Application Name=\"my app\"", s); + Assert.Contains("SSL Mode=Prefer", s); + } +} diff --git a/pgLabII.PgUtils/ConnectionStrings/NpgsqlCodec.cs b/pgLabII.PgUtils/ConnectionStrings/NpgsqlCodec.cs new file mode 100644 index 0000000..7e2f0fa --- /dev/null +++ b/pgLabII.PgUtils/ConnectionStrings/NpgsqlCodec.cs @@ -0,0 +1,316 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Text; +using FluentResults; +using Npgsql; + +namespace pgLabII.PgUtils.ConnectionStrings; + +public sealed class NpgsqlCodec : IConnectionStringCodec +{ + public ConnStringFormat Format => ConnStringFormat.Npgsql; + public string FormatName => "Npgsql"; + + public Result TryParse(string input) + { + try + { + var dict = Tokenize(input); + var descriptor = new ConnectionDescriptorBuilder(); + + // Hosts and Ports + if (dict.TryGetValue("Host", out var hostVal) || dict.TryGetValue("Server", out hostVal) || dict.TryGetValue("Servers", out hostVal)) + { + var hosts = SplitList(hostVal).ToList(); + List portsPerHost = new(); + if (dict.TryGetValue("Port", out var portVal)) + { + var ports = SplitList(portVal).ToList(); + if (ports.Count == 1 && ushort.TryParse(ports[0], out var singlePort)) + { + foreach (var _ in hosts) portsPerHost.Add(singlePort); + } + else if (ports.Count == hosts.Count) + { + foreach (var p in ports) + { + if (ushort.TryParse(p, NumberStyles.Integer, CultureInfo.InvariantCulture, out var up)) + portsPerHost.Add(up); + else + portsPerHost.Add(null); + } + } + } + for (int i = 0; i < hosts.Count; i++) + { + ushort? port = i < portsPerHost.Count ? portsPerHost[i] : null; + descriptor.AddHost(hosts[i], port); + } + } + + // Standard fields + if (TryGetFirst(dict, out var db, "Database", "Db", "Initial Catalog", "dbname")) + descriptor.Database = db; + if (TryGetFirst(dict, out var user, "Username", "User ID", "User", "UID")) + descriptor.Username = user; + if (TryGetFirst(dict, out var pass, "Password", "PWD")) + descriptor.Password = pass; + if (TryGetFirst(dict, out var app, "Application Name", "ApplicationName")) + descriptor.ApplicationName = app; + if (TryGetFirst(dict, out var timeout, "Timeout", "Connect Timeout", "Connection Timeout")) + { + if (int.TryParse(timeout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var t)) + descriptor.TimeoutSeconds = t; + } + if (TryGetFirst(dict, out var ssl, "SSL Mode", "SslMode", "SSLMode")) + descriptor.SslMode = ParseSslMode(ssl); + + // Preserve extras (not mapped) into Properties + var mapped = new HashSet(StringComparer.OrdinalIgnoreCase) + { + "Host","Server","Servers","Port","Database","Db","Initial Catalog","dbname", + "Username","User ID","User","UID","Password","PWD","Application Name","ApplicationName", + "Timeout","Connect Timeout","Connection Timeout","SSL Mode","SslMode","SSLMode" + }; + foreach (var (k, v) in dict) + { + if (!mapped.Contains(k)) + descriptor.Properties[k] = v; + } + + return Result.Ok(descriptor.Build()); + } + catch (Exception ex) + { + return Result.Fail(ex.Message); + } + } + + public Result TryFormat(ConnectionDescriptor descriptor) + { + try + { + var parts = new List(); + + if (descriptor.Hosts != null && descriptor.Hosts.Count > 0) + { + var hostList = string.Join(',', descriptor.Hosts.Select(h => h.Host)); + parts.Add(FormatPair("Host", hostList)); + var ports = descriptor.Hosts.Select(h => h.Port).Where(p => p.HasValue).Select(p => p!.Value).Distinct().ToList(); + if (ports.Count == 1) + { + parts.Add(FormatPair("Port", ports[0].ToString(CultureInfo.InvariantCulture))); + } + else if (ports.Count == 0) + { + // nothing + } + else + { + // Per-host ports if provided 1:1 + var perHost = descriptor.Hosts.Select(h => h.Port?.ToString(CultureInfo.InvariantCulture) ?? string.Empty).ToList(); + if (perHost.All(s => !string.IsNullOrEmpty(s))) + parts.Add(FormatPair("Port", string.Join(',', perHost))); + } + } + + if (!string.IsNullOrEmpty(descriptor.Database)) + parts.Add(FormatPair("Database", descriptor.Database)); + if (!string.IsNullOrEmpty(descriptor.Username)) + parts.Add(FormatPair("Username", descriptor.Username)); + if (!string.IsNullOrEmpty(descriptor.Password)) + parts.Add(FormatPair("Password", descriptor.Password)); + if (descriptor.SslMode.HasValue) + parts.Add(FormatPair("SSL Mode", FormatSslMode(descriptor.SslMode.Value))); + if (!string.IsNullOrEmpty(descriptor.ApplicationName)) + parts.Add(FormatPair("Application Name", descriptor.ApplicationName)); + if (descriptor.TimeoutSeconds.HasValue) + parts.Add(FormatPair("Timeout", descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture))); + + var emittedKeys = new HashSet(parts.Select(p => p.Split('=')[0].Trim()), StringComparer.OrdinalIgnoreCase); + foreach (var kv in descriptor.Properties) + { + if (!emittedKeys.Contains(kv.Key)) + parts.Add(FormatPair(kv.Key, kv.Value)); + } + + return Result.Ok(string.Join(";", parts)); + } + catch (Exception ex) + { + return Result.Fail(ex.Message); + } + } + + private static IEnumerable SplitList(string s) + { + return s.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + } + + private static bool TryGetFirst(Dictionary dict, out string value, params string[] keys) + { + foreach (var k in keys) + { + if (dict.TryGetValue(k, out value)) return true; + } + value = string.Empty; + return false; + } + + private static SslMode ParseSslMode(string s) + { + switch (s.Trim().ToLowerInvariant()) + { + case "disable": return SslMode.Disable; + case "allow": return SslMode.Allow; + case "prefer": return SslMode.Prefer; + case "require": return SslMode.Require; + case "verify-ca": + case "verifyca": return SslMode.VerifyCA; + case "verify-full": + case "verifyfull": return SslMode.VerifyFull; + default: throw new ArgumentException($"Not a valid SSL Mode: {s}"); + } + } + + private static string FormatSslMode(SslMode mode) + { + return mode switch + { + SslMode.Disable => "Disable", + SslMode.Allow => "Allow", + SslMode.Prefer => "Prefer", + SslMode.Require => "Require", + SslMode.VerifyCA => "VerifyCA", + SslMode.VerifyFull => "VerifyFull", + _ => "Prefer" + }; + } + + // Npgsql/.NET connection string grammar: semicolon-separated key=value; values with special chars are wrapped in quotes, internal quotes doubled + private static string FormatPair(string key, string? value) + { + value ??= string.Empty; + var needsQuotes = NeedsQuoting(value); + if (!needsQuotes) return key + "=" + value; + return key + "=\"" + EscapeQuoted(value) + "\""; + } + + private static bool NeedsQuoting(string value) + { + if (value.Length == 0) return true; + foreach (var c in value) + { + if (char.IsWhiteSpace(c) || c == ';' || c == '=' || c == '"') return true; + } + return false; + } + + private static string EscapeQuoted(string value) + { + // Double the quotes per standard DbConnectionString rules + return value.Replace("\"", "\"\""); + } + + private static Dictionary Tokenize(string input) + { + // Simple tokenizer for .NET connection strings: key=value pairs separated by semicolons; values may be quoted with double quotes + var dict = new Dictionary(StringComparer.OrdinalIgnoreCase); + int i = 0; + void SkipWs() { while (i < input.Length && char.IsWhiteSpace(input[i])) i++; } + + while (true) + { + SkipWs(); + if (i >= input.Length) break; + + // read key + int keyStart = i; + while (i < input.Length && input[i] != '=') i++; + if (i >= input.Length) { break; } + var key = input.Substring(keyStart, i - keyStart).Trim(); + i++; // skip '=' + SkipWs(); + + // read value + string value; + if (i < input.Length && input[i] == '"') + { + i++; // skip opening quote + var sb = new StringBuilder(); + while (i < input.Length) + { + char c = input[i++]; + if (c == '"') + { + if (i < input.Length && input[i] == '"') + { + // doubled quote -> literal quote + sb.Append('"'); + i++; + continue; + } + else + { + break; // end quoted value + } + } + else + { + sb.Append(c); + } + } + value = sb.ToString(); + } + else + { + int valStart = i; + while (i < input.Length && input[i] != ';') i++; + value = input.Substring(valStart, i - valStart).Trim(); + } + + dict[key] = value; + + // skip to next, if ; present, consume one + while (i < input.Length && input[i] != ';') i++; + if (i < input.Length && input[i] == ';') i++; + } + + return dict; + } + + private sealed class ConnectionDescriptorBuilder + { + public List Hosts { get; } = new(); + public string? Database { get; set; } + public string? Username { get; set; } + public string? Password { get; set; } + public SslMode? SslMode { get; set; } + public string? ApplicationName { get; set; } + public int? TimeoutSeconds { get; set; } + public Dictionary Properties { get; } = new(StringComparer.OrdinalIgnoreCase); + + public void AddHost(string host, ushort? port) + { + if (string.IsNullOrWhiteSpace(host)) return; + Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port }); + } + + public ConnectionDescriptor Build() + { + return new ConnectionDescriptor + { + Hosts = Hosts, + Database = Database, + Username = Username, + Password = Password, + SslMode = SslMode, + ApplicationName = ApplicationName, + TimeoutSeconds = TimeoutSeconds, + Properties = Properties + }; + } + } +}