Fix libpq parsing and refactors/code cleanup
This commit is contained in:
parent
0090f39910
commit
739d6bd65a
12 changed files with 234 additions and 543 deletions
|
|
@ -1,9 +1,5 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Text;
|
||||
using FluentResults;
|
||||
using Npgsql;
|
||||
|
||||
namespace pgLabII.PgUtils.ConnectionStrings;
|
||||
|
||||
|
|
@ -16,14 +12,13 @@ public sealed class LibpqCodec : IConnectionStringCodec
|
|||
{
|
||||
try
|
||||
{
|
||||
// Reject Npgsql-style strings that use ';' separators when forcing libpq
|
||||
if (input.IndexOf(';') >= 0)
|
||||
return Result.Fail<ConnectionDescriptor>("Semicolons are not valid separators in libpq connection strings");
|
||||
var kv = new PqConnectionStringParser(new PqConnectionStringTokenizer(input)).Parse();
|
||||
Result<IDictionary<string, string>> kv = new PqConnectionStringParser(new PqConnectionStringTokenizer(input)).Parse();
|
||||
if (kv.IsFailed)
|
||||
return kv.ToResult();
|
||||
|
||||
// libpq keywords are case-insensitive; normalize to lower for lookup
|
||||
var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||
foreach (var pair in kv)
|
||||
foreach (var pair in kv.Value)
|
||||
dict[pair.Key] = pair.Value;
|
||||
|
||||
var descriptor = new ConnectionDescriptorBuilder();
|
||||
|
|
@ -31,7 +26,7 @@ public sealed class LibpqCodec : IConnectionStringCodec
|
|||
if (dict.TryGetValue("host", out var host))
|
||||
{
|
||||
// libpq supports host lists separated by commas
|
||||
var hosts = host.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
|
||||
string[] hosts = CodecCommon.SplitHosts(host);
|
||||
ushort? portForAll = null;
|
||||
if (dict.TryGetValue("port", out var portStr) && ushort.TryParse(portStr, out var p))
|
||||
portForAll = p;
|
||||
|
|
@ -40,10 +35,10 @@ public sealed class LibpqCodec : IConnectionStringCodec
|
|||
descriptor.AddHost(h, portForAll);
|
||||
}
|
||||
}
|
||||
if (dict.TryGetValue("hostaddr", out var hostaddr) && !string.IsNullOrWhiteSpace(hostaddr))
|
||||
if (dict.TryGetValue("hostaddr", out string? hostaddr) && !string.IsNullOrWhiteSpace(hostaddr))
|
||||
{
|
||||
// If hostaddr is provided without host, include as host entries as well
|
||||
var hosts = hostaddr.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
|
||||
// If hostaddr is provided without a host, include as host entries as well
|
||||
string[] hosts = CodecCommon.SplitHosts(hostaddr);
|
||||
ushort? portForAll = null;
|
||||
if (dict.TryGetValue("port", out var portStr) && ushort.TryParse(portStr, out var p))
|
||||
portForAll = p;
|
||||
|
|
@ -61,7 +56,7 @@ public sealed class LibpqCodec : IConnectionStringCodec
|
|||
descriptor.Password = pass;
|
||||
|
||||
if (dict.TryGetValue("sslmode", out var sslStr))
|
||||
descriptor.SslMode = ParseSslMode(sslStr);
|
||||
descriptor.SslMode = CodecCommon.ParseSslModeLoose(sslStr);
|
||||
if (dict.TryGetValue("application_name", out var app))
|
||||
descriptor.ApplicationName = app;
|
||||
if (dict.TryGetValue("connect_timeout", out var tout) && int.TryParse(tout, out var seconds))
|
||||
|
|
@ -93,7 +88,7 @@ public sealed class LibpqCodec : IConnectionStringCodec
|
|||
var parts = new List<string>();
|
||||
|
||||
// Hosts and port
|
||||
if (descriptor.Hosts != null && descriptor.Hosts.Count > 0)
|
||||
if (descriptor.Hosts.Count > 0)
|
||||
{
|
||||
var hostList = string.Join(',', descriptor.Hosts.Select(h => h.Host));
|
||||
parts.Add(FormatPair("host", hostList));
|
||||
|
|
@ -110,7 +105,7 @@ public sealed class LibpqCodec : IConnectionStringCodec
|
|||
if (!string.IsNullOrEmpty(descriptor.Password))
|
||||
parts.Add(FormatPair("password", descriptor.Password));
|
||||
if (descriptor.SslMode.HasValue)
|
||||
parts.Add(FormatPair("sslmode", FormatSslMode(descriptor.SslMode.Value)));
|
||||
parts.Add(FormatPair("sslmode", CodecCommon.FormatSslModeUrlLike(descriptor.SslMode.Value)));
|
||||
if (!string.IsNullOrEmpty(descriptor.ApplicationName))
|
||||
parts.Add(FormatPair("application_name", descriptor.ApplicationName));
|
||||
if (descriptor.TimeoutSeconds.HasValue)
|
||||
|
|
@ -132,34 +127,6 @@ public sealed class LibpqCodec : IConnectionStringCodec
|
|||
}
|
||||
}
|
||||
|
||||
private static SslMode ParseSslMode(string s)
|
||||
{
|
||||
return s.Trim().ToLowerInvariant() switch
|
||||
{
|
||||
"disable" => SslMode.Disable,
|
||||
"allow" => SslMode.Allow,
|
||||
"prefer" => SslMode.Prefer,
|
||||
"require" => SslMode.Require,
|
||||
"verify-ca" => SslMode.VerifyCA,
|
||||
"verify-full" => SslMode.VerifyFull,
|
||||
_ => throw new ArgumentException($"Not a valid SSL mode: {s}")
|
||||
};
|
||||
}
|
||||
|
||||
private static string FormatSslMode(SslMode mode)
|
||||
{
|
||||
return mode switch
|
||||
{
|
||||
SslMode.Disable => "disable",
|
||||
SslMode.Allow => "allow",
|
||||
SslMode.Prefer => "prefer",
|
||||
SslMode.Require => "require",
|
||||
SslMode.VerifyCA => "verify-ca",
|
||||
SslMode.VerifyFull => "verify-full",
|
||||
_ => "prefer"
|
||||
};
|
||||
}
|
||||
|
||||
private static string FormatPair(string key, string? value)
|
||||
{
|
||||
value ??= string.Empty;
|
||||
|
|
@ -170,56 +137,17 @@ public sealed class LibpqCodec : IConnectionStringCodec
|
|||
|
||||
private static bool NeedsQuoting(string value)
|
||||
{
|
||||
if (value.Length == 0) return true;
|
||||
foreach (var c in value)
|
||||
{
|
||||
if (char.IsWhiteSpace(c) || c == '=' || c == '\'' || c == '\\')
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
return value.Any(c => char.IsWhiteSpace(c) || c == '=' || c == '\'' || c == '\\');
|
||||
}
|
||||
|
||||
private static string EscapeValue(string value)
|
||||
{
|
||||
var sb = new StringBuilder();
|
||||
foreach (var c in value)
|
||||
foreach (char c in value)
|
||||
{
|
||||
if (c == '\'' || c == '\\') sb.Append('\\');
|
||||
sb.Append(c);
|
||||
}
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private sealed class ConnectionDescriptorBuilder
|
||||
{
|
||||
public List<HostEndpoint> Hosts { get; } = new();
|
||||
public string? Database { get; set; }
|
||||
public string? Username { get; set; }
|
||||
public string? Password { get; set; }
|
||||
public SslMode? SslMode { get; set; }
|
||||
public string? ApplicationName { get; set; }
|
||||
public int? TimeoutSeconds { get; set; }
|
||||
public Dictionary<string,string> Properties { get; } = new(StringComparer.OrdinalIgnoreCase);
|
||||
|
||||
public void AddHost(string host, ushort? port)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(host)) return;
|
||||
Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port });
|
||||
}
|
||||
|
||||
public ConnectionDescriptor Build()
|
||||
{
|
||||
return new ConnectionDescriptor
|
||||
{
|
||||
Hosts = Hosts,
|
||||
Database = Database,
|
||||
Username = Username,
|
||||
Password = Password,
|
||||
SslMode = SslMode,
|
||||
ApplicationName = ApplicationName,
|
||||
TimeoutSeconds = TimeoutSeconds,
|
||||
Properties = Properties
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ public ref struct PqConnectionStringParser
|
|||
//service
|
||||
//target_session_attrs
|
||||
|
||||
public static IDictionary<string, string> Parse(string input)
|
||||
public static Result<IDictionary<string, string>> Parse(string input)
|
||||
{
|
||||
return new PqConnectionStringParser(
|
||||
new PqConnectionStringTokenizer(input)
|
||||
|
|
@ -63,12 +63,16 @@ public ref struct PqConnectionStringParser
|
|||
this._tokenizer = tokenizer;
|
||||
}
|
||||
|
||||
public IDictionary<string, string> Parse()
|
||||
public Result<IDictionary<string, string>> Parse()
|
||||
{
|
||||
_result.Clear();
|
||||
|
||||
while (!_tokenizer.IsEof)
|
||||
ParsePair();
|
||||
{
|
||||
var result = ParsePair();
|
||||
if (result.IsFailed)
|
||||
return result;
|
||||
}
|
||||
|
||||
return _result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,12 +67,6 @@ public class PqConnectionStringTokenizer : IPqConnectionStringTokenizer
|
|||
{
|
||||
while (position < input.Length && char.IsWhiteSpace(input[position]))
|
||||
position++;
|
||||
// If a semicolon is encountered between pairs (which is not valid in libpq),
|
||||
// treat as immediate EOF so parser stops and leaves trailing data unparsed.
|
||||
if (position < input.Length && input[position] == ';')
|
||||
{
|
||||
position = input.Length; // force EOF
|
||||
}
|
||||
}
|
||||
|
||||
private string UnquotedString(bool forKeyword)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue