NpgsqlCodec
This commit is contained in:
parent
f46ee407f2
commit
ca32fa776a
2 changed files with 403 additions and 0 deletions
87
pgLabII.PgUtils.Tests/ConnectionStrings/NpgsqlCodecTests.cs
Normal file
87
pgLabII.PgUtils.Tests/ConnectionStrings/NpgsqlCodecTests.cs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
using System.Collections.Generic;
|
||||
using Npgsql;
|
||||
using pgLabII.PgUtils.ConnectionStrings;
|
||||
|
||||
namespace pgLabII.PgUtils.Tests.ConnectionStrings;
|
||||
|
||||
public class NpgsqlCodecTests
|
||||
{
|
||||
[Fact]
|
||||
public void Parse_Basic()
|
||||
{
|
||||
var codec = new NpgsqlCodec();
|
||||
var res = codec.TryParse("Host=localhost;Port=5434;Database=testdb;Username=alice;Password=secret;SSL Mode=Require;Application Name=pgLab;Timeout=10");
|
||||
Assert.True(res.IsSuccess);
|
||||
var d = res.Value;
|
||||
Assert.Single(d.Hosts);
|
||||
Assert.Equal("localhost", d.Hosts[0].Host);
|
||||
Assert.Equal((ushort)5434, d.Hosts[0].Port);
|
||||
Assert.Equal("testdb", d.Database);
|
||||
Assert.Equal("alice", d.Username);
|
||||
Assert.Equal("secret", d.Password);
|
||||
Assert.Equal(SslMode.Require, d.SslMode);
|
||||
Assert.Equal("pgLab", d.ApplicationName);
|
||||
Assert.Equal(10, d.TimeoutSeconds);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Parse_MultiHost_WithSinglePort()
|
||||
{
|
||||
var codec = new NpgsqlCodec();
|
||||
var res = codec.TryParse("Host=host1,host2;Port=5433;Database=db;Username=u");
|
||||
Assert.True(res.IsSuccess);
|
||||
var d = res.Value;
|
||||
Assert.Equal(2, d.Hosts.Count);
|
||||
Assert.Equal("host1", d.Hosts[0].Host);
|
||||
Assert.Equal((ushort)5433, d.Hosts[0].Port);
|
||||
Assert.Equal("host2", d.Hosts[1].Host);
|
||||
Assert.Equal((ushort)5433, d.Hosts[1].Port);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Format_Basic_WithQuoting()
|
||||
{
|
||||
var codec = new NpgsqlCodec();
|
||||
var d = new ConnectionDescriptor
|
||||
{
|
||||
Hosts = new [] { new HostEndpoint{ Host = "db.example.com", Port = 5432 } },
|
||||
Database = "prod db",
|
||||
Username = "bob",
|
||||
Password = "p;ss\"word",
|
||||
SslMode = SslMode.VerifyFull,
|
||||
ApplicationName = "cli app",
|
||||
TimeoutSeconds = 9,
|
||||
Properties = new Dictionary<string,string>{{"Search Path","public"}}
|
||||
};
|
||||
var res = codec.TryFormat(d);
|
||||
Assert.True(res.IsSuccess);
|
||||
var s = res.Value;
|
||||
Assert.Contains("Host=db.example.com", s);
|
||||
Assert.Contains("Port=5432", s);
|
||||
Assert.Contains("Database=\"prod db\"", s);
|
||||
Assert.Contains("Username=bob", s);
|
||||
Assert.Contains("Password=\"p;ss\"\"word\"", s);
|
||||
Assert.Contains("SSL Mode=VerifyFull", s);
|
||||
Assert.Contains("Application Name=\"cli app\"", s);
|
||||
Assert.Contains("Timeout=9", s);
|
||||
Assert.Contains("Search Path=public", s);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Roundtrip_ParseThenFormat()
|
||||
{
|
||||
var codec = new NpgsqlCodec();
|
||||
var input = "Host=\"my host\";Database=postgres;Username=me;Password=\"with;quote\"\"\";Application Name=\"my app\";SSL Mode=Prefer";
|
||||
var parsed = codec.TryParse(input);
|
||||
Assert.True(parsed.IsSuccess);
|
||||
var formatted = codec.TryFormat(parsed.Value);
|
||||
Assert.True(formatted.IsSuccess);
|
||||
var s = formatted.Value;
|
||||
Assert.Contains("Host=\"my host\"", s);
|
||||
Assert.Contains("Database=postgres", s);
|
||||
Assert.Contains("Username=me", s);
|
||||
Assert.Contains("Password=\"with;quote\"\"\"", s);
|
||||
Assert.Contains("Application Name=\"my app\"", s);
|
||||
Assert.Contains("SSL Mode=Prefer", s);
|
||||
}
|
||||
}
|
||||
316
pgLabII.PgUtils/ConnectionStrings/NpgsqlCodec.cs
Normal file
316
pgLabII.PgUtils/ConnectionStrings/NpgsqlCodec.cs
Normal file
|
|
@ -0,0 +1,316 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Globalization;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using FluentResults;
|
||||
using Npgsql;
|
||||
|
||||
namespace pgLabII.PgUtils.ConnectionStrings;
|
||||
|
||||
public sealed class NpgsqlCodec : IConnectionStringCodec
|
||||
{
|
||||
public ConnStringFormat Format => ConnStringFormat.Npgsql;
|
||||
public string FormatName => "Npgsql";
|
||||
|
||||
public Result<ConnectionDescriptor> TryParse(string input)
|
||||
{
|
||||
try
|
||||
{
|
||||
var dict = Tokenize(input);
|
||||
var descriptor = new ConnectionDescriptorBuilder();
|
||||
|
||||
// Hosts and Ports
|
||||
if (dict.TryGetValue("Host", out var hostVal) || dict.TryGetValue("Server", out hostVal) || dict.TryGetValue("Servers", out hostVal))
|
||||
{
|
||||
var hosts = SplitList(hostVal).ToList();
|
||||
List<ushort?> portsPerHost = new();
|
||||
if (dict.TryGetValue("Port", out var portVal))
|
||||
{
|
||||
var ports = SplitList(portVal).ToList();
|
||||
if (ports.Count == 1 && ushort.TryParse(ports[0], out var singlePort))
|
||||
{
|
||||
foreach (var _ in hosts) portsPerHost.Add(singlePort);
|
||||
}
|
||||
else if (ports.Count == hosts.Count)
|
||||
{
|
||||
foreach (var p in ports)
|
||||
{
|
||||
if (ushort.TryParse(p, NumberStyles.Integer, CultureInfo.InvariantCulture, out var up))
|
||||
portsPerHost.Add(up);
|
||||
else
|
||||
portsPerHost.Add(null);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < hosts.Count; i++)
|
||||
{
|
||||
ushort? port = i < portsPerHost.Count ? portsPerHost[i] : null;
|
||||
descriptor.AddHost(hosts[i], port);
|
||||
}
|
||||
}
|
||||
|
||||
// Standard fields
|
||||
if (TryGetFirst(dict, out var db, "Database", "Db", "Initial Catalog", "dbname"))
|
||||
descriptor.Database = db;
|
||||
if (TryGetFirst(dict, out var user, "Username", "User ID", "User", "UID"))
|
||||
descriptor.Username = user;
|
||||
if (TryGetFirst(dict, out var pass, "Password", "PWD"))
|
||||
descriptor.Password = pass;
|
||||
if (TryGetFirst(dict, out var app, "Application Name", "ApplicationName"))
|
||||
descriptor.ApplicationName = app;
|
||||
if (TryGetFirst(dict, out var timeout, "Timeout", "Connect Timeout", "Connection Timeout"))
|
||||
{
|
||||
if (int.TryParse(timeout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var t))
|
||||
descriptor.TimeoutSeconds = t;
|
||||
}
|
||||
if (TryGetFirst(dict, out var ssl, "SSL Mode", "SslMode", "SSLMode"))
|
||||
descriptor.SslMode = ParseSslMode(ssl);
|
||||
|
||||
// Preserve extras (not mapped) into Properties
|
||||
var mapped = new HashSet<string>(StringComparer.OrdinalIgnoreCase)
|
||||
{
|
||||
"Host","Server","Servers","Port","Database","Db","Initial Catalog","dbname",
|
||||
"Username","User ID","User","UID","Password","PWD","Application Name","ApplicationName",
|
||||
"Timeout","Connect Timeout","Connection Timeout","SSL Mode","SslMode","SSLMode"
|
||||
};
|
||||
foreach (var (k, v) in dict)
|
||||
{
|
||||
if (!mapped.Contains(k))
|
||||
descriptor.Properties[k] = v;
|
||||
}
|
||||
|
||||
return Result.Ok(descriptor.Build());
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return Result.Fail<ConnectionDescriptor>(ex.Message);
|
||||
}
|
||||
}
|
||||
|
||||
public Result<string> TryFormat(ConnectionDescriptor descriptor)
|
||||
{
|
||||
try
|
||||
{
|
||||
var parts = new List<string>();
|
||||
|
||||
if (descriptor.Hosts != null && descriptor.Hosts.Count > 0)
|
||||
{
|
||||
var hostList = string.Join(',', descriptor.Hosts.Select(h => h.Host));
|
||||
parts.Add(FormatPair("Host", hostList));
|
||||
var ports = descriptor.Hosts.Select(h => h.Port).Where(p => p.HasValue).Select(p => p!.Value).Distinct().ToList();
|
||||
if (ports.Count == 1)
|
||||
{
|
||||
parts.Add(FormatPair("Port", ports[0].ToString(CultureInfo.InvariantCulture)));
|
||||
}
|
||||
else if (ports.Count == 0)
|
||||
{
|
||||
// nothing
|
||||
}
|
||||
else
|
||||
{
|
||||
// Per-host ports if provided 1:1
|
||||
var perHost = descriptor.Hosts.Select(h => h.Port?.ToString(CultureInfo.InvariantCulture) ?? string.Empty).ToList();
|
||||
if (perHost.All(s => !string.IsNullOrEmpty(s)))
|
||||
parts.Add(FormatPair("Port", string.Join(',', perHost)));
|
||||
}
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(descriptor.Database))
|
||||
parts.Add(FormatPair("Database", descriptor.Database));
|
||||
if (!string.IsNullOrEmpty(descriptor.Username))
|
||||
parts.Add(FormatPair("Username", descriptor.Username));
|
||||
if (!string.IsNullOrEmpty(descriptor.Password))
|
||||
parts.Add(FormatPair("Password", descriptor.Password));
|
||||
if (descriptor.SslMode.HasValue)
|
||||
parts.Add(FormatPair("SSL Mode", FormatSslMode(descriptor.SslMode.Value)));
|
||||
if (!string.IsNullOrEmpty(descriptor.ApplicationName))
|
||||
parts.Add(FormatPair("Application Name", descriptor.ApplicationName));
|
||||
if (descriptor.TimeoutSeconds.HasValue)
|
||||
parts.Add(FormatPair("Timeout", descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture)));
|
||||
|
||||
var emittedKeys = new HashSet<string>(parts.Select(p => p.Split('=')[0].Trim()), StringComparer.OrdinalIgnoreCase);
|
||||
foreach (var kv in descriptor.Properties)
|
||||
{
|
||||
if (!emittedKeys.Contains(kv.Key))
|
||||
parts.Add(FormatPair(kv.Key, kv.Value));
|
||||
}
|
||||
|
||||
return Result.Ok(string.Join(";", parts));
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return Result.Fail<string>(ex.Message);
|
||||
}
|
||||
}
|
||||
|
||||
private static IEnumerable<string> SplitList(string s)
|
||||
{
|
||||
return s.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
|
||||
}
|
||||
|
||||
private static bool TryGetFirst(Dictionary<string, string> dict, out string value, params string[] keys)
|
||||
{
|
||||
foreach (var k in keys)
|
||||
{
|
||||
if (dict.TryGetValue(k, out value)) return true;
|
||||
}
|
||||
value = string.Empty;
|
||||
return false;
|
||||
}
|
||||
|
||||
private static SslMode ParseSslMode(string s)
|
||||
{
|
||||
switch (s.Trim().ToLowerInvariant())
|
||||
{
|
||||
case "disable": return SslMode.Disable;
|
||||
case "allow": return SslMode.Allow;
|
||||
case "prefer": return SslMode.Prefer;
|
||||
case "require": return SslMode.Require;
|
||||
case "verify-ca":
|
||||
case "verifyca": return SslMode.VerifyCA;
|
||||
case "verify-full":
|
||||
case "verifyfull": return SslMode.VerifyFull;
|
||||
default: throw new ArgumentException($"Not a valid SSL Mode: {s}");
|
||||
}
|
||||
}
|
||||
|
||||
private static string FormatSslMode(SslMode mode)
|
||||
{
|
||||
return mode switch
|
||||
{
|
||||
SslMode.Disable => "Disable",
|
||||
SslMode.Allow => "Allow",
|
||||
SslMode.Prefer => "Prefer",
|
||||
SslMode.Require => "Require",
|
||||
SslMode.VerifyCA => "VerifyCA",
|
||||
SslMode.VerifyFull => "VerifyFull",
|
||||
_ => "Prefer"
|
||||
};
|
||||
}
|
||||
|
||||
// Npgsql/.NET connection string grammar: semicolon-separated key=value; values with special chars are wrapped in quotes, internal quotes doubled
|
||||
private static string FormatPair(string key, string? value)
|
||||
{
|
||||
value ??= string.Empty;
|
||||
var needsQuotes = NeedsQuoting(value);
|
||||
if (!needsQuotes) return key + "=" + value;
|
||||
return key + "=\"" + EscapeQuoted(value) + "\"";
|
||||
}
|
||||
|
||||
private static bool NeedsQuoting(string value)
|
||||
{
|
||||
if (value.Length == 0) return true;
|
||||
foreach (var c in value)
|
||||
{
|
||||
if (char.IsWhiteSpace(c) || c == ';' || c == '=' || c == '"') return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static string EscapeQuoted(string value)
|
||||
{
|
||||
// Double the quotes per standard DbConnectionString rules
|
||||
return value.Replace("\"", "\"\"");
|
||||
}
|
||||
|
||||
private static Dictionary<string, string> Tokenize(string input)
|
||||
{
|
||||
// Simple tokenizer for .NET connection strings: key=value pairs separated by semicolons; values may be quoted with double quotes
|
||||
var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
|
||||
int i = 0;
|
||||
void SkipWs() { while (i < input.Length && char.IsWhiteSpace(input[i])) i++; }
|
||||
|
||||
while (true)
|
||||
{
|
||||
SkipWs();
|
||||
if (i >= input.Length) break;
|
||||
|
||||
// read key
|
||||
int keyStart = i;
|
||||
while (i < input.Length && input[i] != '=') i++;
|
||||
if (i >= input.Length) { break; }
|
||||
var key = input.Substring(keyStart, i - keyStart).Trim();
|
||||
i++; // skip '='
|
||||
SkipWs();
|
||||
|
||||
// read value
|
||||
string value;
|
||||
if (i < input.Length && input[i] == '"')
|
||||
{
|
||||
i++; // skip opening quote
|
||||
var sb = new StringBuilder();
|
||||
while (i < input.Length)
|
||||
{
|
||||
char c = input[i++];
|
||||
if (c == '"')
|
||||
{
|
||||
if (i < input.Length && input[i] == '"')
|
||||
{
|
||||
// doubled quote -> literal quote
|
||||
sb.Append('"');
|
||||
i++;
|
||||
continue;
|
||||
}
|
||||
else
|
||||
{
|
||||
break; // end quoted value
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
sb.Append(c);
|
||||
}
|
||||
}
|
||||
value = sb.ToString();
|
||||
}
|
||||
else
|
||||
{
|
||||
int valStart = i;
|
||||
while (i < input.Length && input[i] != ';') i++;
|
||||
value = input.Substring(valStart, i - valStart).Trim();
|
||||
}
|
||||
|
||||
dict[key] = value;
|
||||
|
||||
// skip to next, if ; present, consume one
|
||||
while (i < input.Length && input[i] != ';') i++;
|
||||
if (i < input.Length && input[i] == ';') i++;
|
||||
}
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
||||
private sealed class ConnectionDescriptorBuilder
|
||||
{
|
||||
public List<HostEndpoint> Hosts { get; } = new();
|
||||
public string? Database { get; set; }
|
||||
public string? Username { get; set; }
|
||||
public string? Password { get; set; }
|
||||
public SslMode? SslMode { get; set; }
|
||||
public string? ApplicationName { get; set; }
|
||||
public int? TimeoutSeconds { get; set; }
|
||||
public Dictionary<string, string> Properties { get; } = new(StringComparer.OrdinalIgnoreCase);
|
||||
|
||||
public void AddHost(string host, ushort? port)
|
||||
{
|
||||
if (string.IsNullOrWhiteSpace(host)) return;
|
||||
Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port });
|
||||
}
|
||||
|
||||
public ConnectionDescriptor Build()
|
||||
{
|
||||
return new ConnectionDescriptor
|
||||
{
|
||||
Hosts = Hosts,
|
||||
Database = Database,
|
||||
Username = Username,
|
||||
Password = Password,
|
||||
SslMode = SslMode,
|
||||
ApplicationName = ApplicationName,
|
||||
TimeoutSeconds = TimeoutSeconds,
|
||||
Properties = Properties
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue