Fix libpq parsing and refactors/code cleanup

This commit is contained in:
eelke 2025-08-31 13:11:59 +02:00
parent 0090f39910
commit 739d6bd65a
12 changed files with 234 additions and 543 deletions

View file

@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.Linq;
using System.Text;
@ -46,11 +47,11 @@ public sealed class JdbcCodec : IConnectionStringCodec
var builder = new ConnectionDescriptorBuilder();
// Parse hosts (comma separated)
foreach (var part in SplitHosts(authority))
foreach (string part in CodecCommon.SplitHosts(authority))
{
if (string.IsNullOrWhiteSpace(part)) continue;
ParseHostPort(part, out var host, out ushort? port);
if (!string.IsNullOrEmpty(host)) builder.AddHost(host!, port);
CodecCommon.ParseHostPort(part, out var host, out ushort? port);
if (!string.IsNullOrEmpty(host))
builder.AddHost(host!, port);
}
// Parse database and query
@ -59,8 +60,8 @@ public sealed class JdbcCodec : IConnectionStringCodec
if (!string.IsNullOrEmpty(pathAndQuery))
{
int qIdx = pathAndQuery.IndexOf('?');
var path = qIdx >= 0 ? pathAndQuery.Substring(0, qIdx) : pathAndQuery;
query = qIdx >= 0 ? pathAndQuery.Substring(qIdx + 1) : string.Empty;
string path = qIdx >= 0 ? pathAndQuery[..qIdx] : pathAndQuery;
query = qIdx >= 0 ? pathAndQuery[(qIdx + 1)..] : string.Empty;
if (path.Length > 0)
{
if (path[0] == '/') path = path.Substring(1);
@ -70,21 +71,22 @@ public sealed class JdbcCodec : IConnectionStringCodec
}
if (!string.IsNullOrEmpty(database)) builder.Database = database;
var queryDict = ParseQuery(query);
var queryDict = CodecCommon.ParseQuery(query);
// Map known properties
if (TryFirst(queryDict, out var ssl, "sslmode", "ssl"))
builder.SslMode = ParseSslMode(ssl);
if (TryFirst(queryDict, out var app, "applicationName", "application_name"))
if (TryFirst(queryDict, out string? ssl, "sslmode", "ssl"))
builder.SslMode = CodecCommon.ParseSslModeLoose(ssl);
if (TryFirst(queryDict, out string? app, "applicationName", "application_name"))
builder.ApplicationName = app;
if (TryFirst(queryDict, out var tout, "loginTimeout", "connectTimeout", "connect_timeout"))
if (TryFirst(queryDict, out string? tout, "loginTimeout", "connectTimeout", "connect_timeout"))
{
if (int.TryParse(tout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var t))
if (int.TryParse(tout, NumberStyles.Integer, CultureInfo.InvariantCulture, out int t))
builder.TimeoutSeconds = t;
}
// Preserve extras
var mapped = new HashSet<string>(new[] { "sslmode", "ssl", "applicationName", "application_name", "loginTimeout", "connectTimeout", "connect_timeout" }, StringComparer.OrdinalIgnoreCase);
var mapped = new HashSet<string>(["sslmode", "ssl", "applicationName", "application_name", "loginTimeout", "connectTimeout", "connect_timeout"
], StringComparer.OrdinalIgnoreCase);
foreach (var kv in queryDict)
{
if (!mapped.Contains(kv.Key))
@ -106,7 +108,7 @@ public sealed class JdbcCodec : IConnectionStringCodec
var sb = new StringBuilder();
sb.Append("jdbc:postgresql://");
if (descriptor.Hosts != null && descriptor.Hosts.Count > 0)
if (descriptor.Hosts.Count > 0)
{
sb.Append(string.Join(',', descriptor.Hosts.Select(FormatHost)));
}
@ -122,7 +124,7 @@ public sealed class JdbcCodec : IConnectionStringCodec
var qp = new List<(string k, string v)>();
if (descriptor.SslMode.HasValue)
{
qp.Add(("sslmode", FormatSslMode(descriptor.SslMode.Value)));
qp.Add(("sslmode", CodecCommon.FormatSslModeUrlLike(descriptor.SslMode.Value)));
}
if (!string.IsNullOrEmpty(descriptor.ApplicationName))
{
@ -154,140 +156,20 @@ public sealed class JdbcCodec : IConnectionStringCodec
return Result.Fail<string>(ex.Message);
}
}
private static string FormatHost(HostEndpoint h) => CodecCommon.FormatHost(h);
private static IEnumerable<string> SplitHosts(string authority)
private static bool TryFirst(
Dictionary<string, string> dict,
[MaybeNullWhen(false)] out string value,
params string[] keys)
{
return authority.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
}
private static string FormatHost(HostEndpoint h)
{
var host = h.Host;
if (host.Contains(':') && !host.StartsWith("["))
foreach (string k in keys)
{
// IPv6 literal must be bracketed
host = "[" + host + "]";
}
return h.Port.HasValue ? host + ":" + h.Port.Value.ToString(CultureInfo.InvariantCulture) : host;
}
private static void ParseHostPort(string hostPart, out string host, out ushort? port)
{
host = hostPart;
port = null;
if (string.IsNullOrWhiteSpace(hostPart)) return;
if (hostPart[0] == '[')
{
int end = hostPart.IndexOf(']');
if (end > 0)
{
host = hostPart.Substring(1, end - 1);
if (end + 1 < hostPart.Length && hostPart[end + 1] == ':')
{
var ps = hostPart.Substring(end + 2);
if (ushort.TryParse(ps, NumberStyles.Integer, CultureInfo.InvariantCulture, out var p))
port = p;
}
}
return;
}
int colon = hostPart.LastIndexOf(':');
if (colon > 0 && colon < hostPart.Length - 1)
{
var ps = hostPart.Substring(colon + 1);
if (ushort.TryParse(ps, NumberStyles.Integer, CultureInfo.InvariantCulture, out var p))
{
host = hostPart.Substring(0, colon);
port = p;
}
}
}
private static Dictionary<string, string> ParseQuery(string query)
{
var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
if (string.IsNullOrEmpty(query)) return dict;
foreach (var kv in query.Split('&', StringSplitOptions.RemoveEmptyEntries))
{
var parts = kv.Split('=', 2);
var key = Uri.UnescapeDataString(parts[0]);
var val = parts.Length > 1 ? Uri.UnescapeDataString(parts[1]) : string.Empty;
dict[key] = val;
}
return dict;
}
private static bool TryFirst(Dictionary<string, string> dict, out string value, params string[] keys)
{
foreach (var k in keys)
{
if (dict.TryGetValue(k, out value)) return true;
if (dict.TryGetValue(k, out value))
return true;
}
value = string.Empty;
return false;
}
private static SslMode ParseSslMode(string s)
{
switch (s.Trim().ToLowerInvariant())
{
case "disable": return SslMode.Disable;
case "allow": return SslMode.Allow;
case "prefer": return SslMode.Prefer;
case "require": return SslMode.Require;
case "verify-ca":
case "verifyca": return SslMode.VerifyCA;
case "verify-full":
case "verifyfull": return SslMode.VerifyFull;
default: throw new ArgumentException($"Not a valid SSL Mode: {s}");
}
}
private static string FormatSslMode(SslMode mode)
{
return mode switch
{
SslMode.Disable => "disable",
SslMode.Allow => "allow",
SslMode.Prefer => "prefer",
SslMode.Require => "require",
SslMode.VerifyCA => "verify-ca",
SslMode.VerifyFull => "verify-full",
_ => "prefer"
};
}
private sealed class ConnectionDescriptorBuilder
{
public List<HostEndpoint> Hosts { get; } = new();
public string? Database { get; set; }
public string? Username { get; set; }
public string? Password { get; set; }
public SslMode? SslMode { get; set; }
public string? ApplicationName { get; set; }
public int? TimeoutSeconds { get; set; }
public Dictionary<string, string> Properties { get; } = new(StringComparer.OrdinalIgnoreCase);
public void AddHost(string host, ushort? port)
{
if (string.IsNullOrWhiteSpace(host)) return;
Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port });
}
public ConnectionDescriptor Build()
{
return new ConnectionDescriptor
{
Hosts = Hosts,
Database = Database,
Username = Username,
Password = Password,
SslMode = SslMode,
ApplicationName = ApplicationName,
TimeoutSeconds = TimeoutSeconds,
Properties = Properties
};
}
}
}