From f46ee407f2cbafbf757c9eb42271da5c081995e3 Mon Sep 17 00:00:00 2001 From: eelke Date: Sat, 30 Aug 2025 20:09:10 +0200 Subject: [PATCH] LibpqCodec --- .../ConnectionStrings/LibpqCodecTests.cs | 89 +++++++ .../ConnectionStrings/Abstractions.cs | 76 ++++++ .../ConnectionStrings/Pq/LibpqCodec.cs | 222 ++++++++++++++++++ 3 files changed, 387 insertions(+) create mode 100644 pgLabII.PgUtils.Tests/ConnectionStrings/LibpqCodecTests.cs create mode 100644 pgLabII.PgUtils/ConnectionStrings/Abstractions.cs create mode 100644 pgLabII.PgUtils/ConnectionStrings/Pq/LibpqCodec.cs diff --git a/pgLabII.PgUtils.Tests/ConnectionStrings/LibpqCodecTests.cs b/pgLabII.PgUtils.Tests/ConnectionStrings/LibpqCodecTests.cs new file mode 100644 index 0000000..9fb430a --- /dev/null +++ b/pgLabII.PgUtils.Tests/ConnectionStrings/LibpqCodecTests.cs @@ -0,0 +1,89 @@ +using pgLabII.PgUtils.ConnectionStrings; +using Npgsql; + +namespace pgLabII.PgUtils.Tests.ConnectionStrings; + +public class LibpqCodecTests +{ + [Fact] + public void Parse_Basic() + { + var codec = new LibpqCodec(); + var res = codec.TryParse("host=localhost port=5433 dbname=testdb user=alice password=secret sslmode=require connect_timeout=15 application_name='pgLab II'"); + Assert.True(res.IsSuccess); + var d = res.Value; + Assert.Single(d.Hosts); + Assert.Equal("localhost", d.Hosts[0].Host); + Assert.Equal((ushort)5433, d.Hosts[0].Port); + Assert.Equal("testdb", d.Database); + Assert.Equal("alice", d.Username); + Assert.Equal("secret", d.Password); + Assert.Equal(SslMode.Require, d.SslMode); + Assert.Equal(15, d.TimeoutSeconds); + Assert.Equal("pgLab II", d.ApplicationName); + } + + [Fact] + public void Parse_MultiHost() + { + var codec = new LibpqCodec(); + var res = codec.TryParse("host=host1,host2,host3 port=5432 dbname=db user=u"); + Assert.True(res.IsSuccess); + var d = res.Value; + Assert.Equal(3, d.Hosts.Count); + Assert.Equal("host1", d.Hosts[0].Host); + Assert.Equal((ushort)5432, d.Hosts[0].Port); + Assert.Equal("host2", d.Hosts[1].Host); + Assert.Equal((ushort)5432, d.Hosts[1].Port); + Assert.Equal("host3", d.Hosts[2].Host); + Assert.Equal((ushort)5432, d.Hosts[2].Port); + } + + [Fact] + public void Format_Basic() + { + var codec = new LibpqCodec(); + var d = new ConnectionDescriptor + { + Hosts = new [] { new HostEndpoint{ Host = "db.example.com", Port = 5432 } }, + Database = "prod db", + Username = "bob", + Password = "p@ss w'rd\\", + SslMode = SslMode.VerifyFull, + ApplicationName = "cli", + TimeoutSeconds = 7, + Properties = new Dictionary{{"search_path","public"}} + }; + var res = codec.TryFormat(d); + Assert.True(res.IsSuccess); + var s = res.Value; + // ensure critical pairs exist and are quoted when needed + Assert.Contains("host=db.example.com", s); + Assert.Contains("port=5432", s); + Assert.Contains("dbname='prod db'", s); + Assert.Contains("user=bob", s); + Assert.Contains("password='p@ss w\\'rd\\\\'", s); + Assert.Contains("sslmode=verify-full", s); + Assert.Contains("application_name=cli", s); + Assert.Contains("connect_timeout=7", s); + Assert.Contains("search_path=public", s); + } + + [Fact] + public void Roundtrip_ParseThenFormat() + { + var codec = new LibpqCodec(); + var input = "host='my host' dbname=postgres user=me password='with space' application_name='my app' sslmode=prefer"; + var parsed = codec.TryParse(input); + Assert.True(parsed.IsSuccess); + var formatted = codec.TryFormat(parsed.Value); + Assert.True(formatted.IsSuccess); + var s = formatted.Value; + Assert.Contains("host='my host'", s); + Assert.Contains("dbname=postgres", s); + Assert.Contains("user=me", s); + Assert.Contains("password='with space'", s); + Assert.Contains("application_name='my app'", s); + Assert.Contains("sslmode=prefer", s); + } +} diff --git a/pgLabII.PgUtils/ConnectionStrings/Abstractions.cs b/pgLabII.PgUtils/ConnectionStrings/Abstractions.cs new file mode 100644 index 0000000..178d257 --- /dev/null +++ b/pgLabII.PgUtils/ConnectionStrings/Abstractions.cs @@ -0,0 +1,76 @@ +using System.Collections.Generic; +using FluentResults; +using Npgsql; + +namespace pgLabII.PgUtils.ConnectionStrings; + +public enum ConnStringFormat +{ + Libpq, + Npgsql, + Url, + Jdbc +} + +public sealed class HostEndpoint +{ + public string Host { get; init; } = string.Empty; + public ushort? Port { get; init; } +} + +/// +/// Canonical, format-agnostic representation of a PostgreSQL connection. +/// Keep minimal fields for broad interoperability; store extras in Properties. +/// +public sealed class ConnectionDescriptor +{ + public string? Name { get; init; } + + // Primary hosts (support multi-host). If empty, implies localhost default. + public IReadOnlyList Hosts { get; init; } = new List(); + + public string? Database { get; init; } + public string? Username { get; init; } + public string? Password { get; init; } + + public SslMode? SslMode { get; init; } + + // Common optional fields + public string? ApplicationName { get; init; } + public int? TimeoutSeconds { get; init; } // connect_timeout + + // Additional parameters preserved across conversions + public IReadOnlyDictionary Properties { get; init; } = + new Dictionary(); +} + +/// +/// Codec for a specific connection string format (parse and format only for its own format). +/// Do not implement format specifics yet; provide interface only. +/// +public interface IConnectionStringCodec +{ + ConnStringFormat Format { get; } + string FormatName { get; } + + // Parse input in this codec's format into a descriptor. + Result TryParse(string input); + + // Format a descriptor into this codec's format. + Result TryFormat(ConnectionDescriptor descriptor); +} + +/// +/// High-level service to detect, parse, format and convert between formats. +/// Implementations will compose specific codecs. +/// +public interface IConnectionStringService +{ + Result DetectFormat(string input); + + Result ParseToDescriptor(string input); + + Result FormatFromDescriptor(ConnectionDescriptor descriptor, ConnStringFormat targetFormat); + + Result Convert(string input, ConnStringFormat targetFormat); +} diff --git a/pgLabII.PgUtils/ConnectionStrings/Pq/LibpqCodec.cs b/pgLabII.PgUtils/ConnectionStrings/Pq/LibpqCodec.cs new file mode 100644 index 0000000..7e4a3dc --- /dev/null +++ b/pgLabII.PgUtils/ConnectionStrings/Pq/LibpqCodec.cs @@ -0,0 +1,222 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using FluentResults; +using Npgsql; + +namespace pgLabII.PgUtils.ConnectionStrings; + +public sealed class LibpqCodec : IConnectionStringCodec +{ + public ConnStringFormat Format => ConnStringFormat.Libpq; + public string FormatName => "libpq"; + + public Result TryParse(string input) + { + try + { + var kv = new PqConnectionStringParser(new PqConnectionStringTokenizer(input)).Parse(); + + // libpq keywords are case-insensitive; normalize to lower for lookup + var dict = new Dictionary(StringComparer.OrdinalIgnoreCase); + foreach (var pair in kv) + dict[pair.Key] = pair.Value; + + var descriptor = new ConnectionDescriptorBuilder(); + + if (dict.TryGetValue("host", out var host)) + { + // libpq supports host lists separated by commas + var hosts = host.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + ushort? portForAll = null; + if (dict.TryGetValue("port", out var portStr) && ushort.TryParse(portStr, out var p)) + portForAll = p; + foreach (var h in hosts) + { + descriptor.AddHost(h, portForAll); + } + } + if (dict.TryGetValue("hostaddr", out var hostaddr) && !string.IsNullOrWhiteSpace(hostaddr)) + { + // If hostaddr is provided without host, include as host entries as well + var hosts = hostaddr.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries); + ushort? portForAll = null; + if (dict.TryGetValue("port", out var portStr) && ushort.TryParse(portStr, out var p)) + portForAll = p; + foreach (var h in hosts) + descriptor.AddHost(h, portForAll); + } + + if (dict.TryGetValue("dbname", out var db)) + descriptor.Database = db; + if (dict.TryGetValue("user", out var user)) + descriptor.Username = user; + else if (dict.TryGetValue("username", out var username)) + descriptor.Username = username; + if (dict.TryGetValue("password", out var pass)) + descriptor.Password = pass; + + if (dict.TryGetValue("sslmode", out var sslStr)) + descriptor.SslMode = ParseSslMode(sslStr); + if (dict.TryGetValue("application_name", out var app)) + descriptor.ApplicationName = app; + if (dict.TryGetValue("connect_timeout", out var tout) && int.TryParse(tout, out var seconds)) + descriptor.TimeoutSeconds = seconds; + + // Remaining properties: store extras excluding mapped keys + var mapped = new HashSet(StringComparer.OrdinalIgnoreCase) + { + "host","hostaddr","port","dbname","user","username","password","sslmode","application_name","connect_timeout" + }; + foreach (var (k,v) in dict) + { + if (!mapped.Contains(k)) + descriptor.Properties[k] = v; + } + + return Result.Ok(descriptor.Build()); + } + catch (Exception ex) + { + return Result.Fail(ex.Message); + } + } + + public Result TryFormat(ConnectionDescriptor descriptor) + { + try + { + var parts = new List(); + + // Hosts and port + if (descriptor.Hosts != null && descriptor.Hosts.Count > 0) + { + var hostList = string.Join(',', descriptor.Hosts.Select(h => h.Host)); + parts.Add(FormatPair("host", hostList)); + // If all ports are same and present, emit a single port + var ports = descriptor.Hosts.Select(h => h.Port).Where(p => p.HasValue).Select(p => p!.Value).Distinct().ToList(); + if (ports.Count == 1) + parts.Add(FormatPair("port", ports[0].ToString())); + } + + if (!string.IsNullOrEmpty(descriptor.Database)) + parts.Add(FormatPair("dbname", descriptor.Database)); + if (!string.IsNullOrEmpty(descriptor.Username)) + parts.Add(FormatPair("user", descriptor.Username)); + if (!string.IsNullOrEmpty(descriptor.Password)) + parts.Add(FormatPair("password", descriptor.Password)); + if (descriptor.SslMode.HasValue) + parts.Add(FormatPair("sslmode", FormatSslMode(descriptor.SslMode.Value))); + if (!string.IsNullOrEmpty(descriptor.ApplicationName)) + parts.Add(FormatPair("application_name", descriptor.ApplicationName)); + if (descriptor.TimeoutSeconds.HasValue) + parts.Add(FormatPair("connect_timeout", descriptor.TimeoutSeconds.Value.ToString())); + + // Extra properties (avoid duplicating keys we already emitted) + var emitted = new HashSet(parts.Select(p => p.Split('=')[0]), StringComparer.OrdinalIgnoreCase); + foreach (var kv in descriptor.Properties) + { + if (!emitted.Contains(kv.Key)) + parts.Add(FormatPair(kv.Key, kv.Value)); + } + + return Result.Ok(string.Join(' ', parts)); + } + catch (Exception ex) + { + return Result.Fail(ex.Message); + } + } + + private static SslMode ParseSslMode(string s) + { + return s.Trim().ToLowerInvariant() switch + { + "disable" => SslMode.Disable, + "allow" => SslMode.Allow, + "prefer" => SslMode.Prefer, + "require" => SslMode.Require, + "verify-ca" => SslMode.VerifyCA, + "verify-full" => SslMode.VerifyFull, + _ => throw new ArgumentException($"Not a valid SSL mode: {s}") + }; + } + + private static string FormatSslMode(SslMode mode) + { + return mode switch + { + SslMode.Disable => "disable", + SslMode.Allow => "allow", + SslMode.Prefer => "prefer", + SslMode.Require => "require", + SslMode.VerifyCA => "verify-ca", + SslMode.VerifyFull => "verify-full", + _ => "prefer" + }; + } + + private static string FormatPair(string key, string? value) + { + value ??= string.Empty; + if (NeedsQuoting(value)) + return key + "='" + EscapeValue(value) + "'"; + return key + "=" + value; + } + + private static bool NeedsQuoting(string value) + { + if (value.Length == 0) return true; + foreach (var c in value) + { + if (char.IsWhiteSpace(c) || c == '=' || c == '\'' || c == '\\') + return true; + } + return false; + } + + private static string EscapeValue(string value) + { + var sb = new StringBuilder(); + foreach (var c in value) + { + if (c == '\'' || c == '\\') sb.Append('\\'); + sb.Append(c); + } + return sb.ToString(); + } + + private sealed class ConnectionDescriptorBuilder + { + public List Hosts { get; } = new(); + public string? Database { get; set; } + public string? Username { get; set; } + public string? Password { get; set; } + public SslMode? SslMode { get; set; } + public string? ApplicationName { get; set; } + public int? TimeoutSeconds { get; set; } + public Dictionary Properties { get; } = new(StringComparer.OrdinalIgnoreCase); + + public void AddHost(string host, ushort? port) + { + if (string.IsNullOrWhiteSpace(host)) return; + Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port }); + } + + public ConnectionDescriptor Build() + { + return new ConnectionDescriptor + { + Hosts = Hosts, + Database = Database, + Username = Username, + Password = Password, + SslMode = SslMode, + ApplicationName = ApplicationName, + TimeoutSeconds = TimeoutSeconds, + Properties = Properties + }; + } + } +}