pgLabII/pgLabII.PgUtils/ConnectionStrings/Pq/LibpqCodec.cs
2025-08-30 20:10:38 +02:00

222 lines
8.3 KiB
C#

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using FluentResults;
using Npgsql;
namespace pgLabII.PgUtils.ConnectionStrings;
/// <summary>
/// Converts between libpq keyword/value connection strings
/// (for example <c>host=db port=5432 dbname=app</c>) and <see cref="ConnectionDescriptor"/>.
/// </summary>
public sealed class LibpqCodec : IConnectionStringCodec
{
    // Keywords that map onto dedicated ConnectionDescriptor members; every other keyword is kept in
    // ConnectionDescriptor.Properties. "hostaddr" is intentionally not listed: it only becomes the
    // endpoint list when no usable "host" is present, otherwise it is preserved as an extra property.
    private static readonly HashSet<string> MappedKeys = new(StringComparer.OrdinalIgnoreCase)
    {
        "host", "port", "dbname", "user", "username", "password", "sslmode", "application_name", "connect_timeout",
    };

    public ConnStringFormat Format => ConnStringFormat.Libpq;

    public string FormatName => "libpq";

    /// <summary>
    /// Parses a libpq keyword/value connection string into a <see cref="ConnectionDescriptor"/>.
    /// </summary>
    /// <param name="input">The connection string to parse.</param>
    /// <returns>
    /// The parsed descriptor, or a failed result carrying the exception message when tokenizing,
    /// parsing, or interpreting a value (e.g. an unknown <c>sslmode</c>) throws.
    /// </returns>
    public Result<ConnectionDescriptor> TryParse(string input)
    {
        try
        {
            var kv = new PqConnectionStringParser(new PqConnectionStringTokenizer(input)).Parse();

            // Keyword lookup is case-insensitive; if the parsed sequence repeats a keyword,
            // the later entry overwrites the earlier one.
            var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
            foreach (var pair in kv)
                dict[pair.Key] = pair.Value;

            var descriptor = new ConnectionDescriptorBuilder();
            bool hostaddrUsedAsHosts = AddHosts(descriptor, dict);

            if (dict.TryGetValue("dbname", out var db))
                descriptor.Database = db;
            if (dict.TryGetValue("user", out var user))
                descriptor.Username = user;
            else if (dict.TryGetValue("username", out var username))
                descriptor.Username = username;
            if (dict.TryGetValue("password", out var pass))
                descriptor.Password = pass;
            if (dict.TryGetValue("sslmode", out var sslStr))
                descriptor.SslMode = ParseSslMode(sslStr);
            if (dict.TryGetValue("application_name", out var app))
                descriptor.ApplicationName = app;
            if (dict.TryGetValue("connect_timeout", out var tout)
                && int.TryParse(tout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var seconds))
                descriptor.TimeoutSeconds = seconds;

            // Keep every unmapped keyword so it survives a parse/format round trip. hostaddr is
            // skipped only when it already supplied the endpoint list.
            foreach (var (key, value) in dict)
            {
                if (MappedKeys.Contains(key))
                    continue;
                if (hostaddrUsedAsHosts && string.Equals(key, "hostaddr", StringComparison.OrdinalIgnoreCase))
                    continue;
                descriptor.Properties[key] = value;
            }

            return Result.Ok(descriptor.Build());
        }
        catch (Exception ex)
        {
            return Result.Fail<ConnectionDescriptor>(ex.Message);
        }
    }

    /// <summary>
    /// Formats a <see cref="ConnectionDescriptor"/> as a libpq keyword/value connection string.
    /// </summary>
    /// <param name="descriptor">The descriptor to format.</param>
    /// <returns>The connection string, or a failed result carrying the exception message.</returns>
    public Result<string> TryFormat(ConnectionDescriptor descriptor)
    {
        try
        {
            var parts = new List<string>();
            var emitted = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            // Appends one key=value pair and records the key so extra properties cannot duplicate it.
            void Add(string key, string? value)
            {
                parts.Add(FormatPair(key, value));
                emitted.Add(key);
            }

            if (descriptor.Hosts is { Count: > 0 } hosts)
            {
                Add("host", string.Join(',', hosts.Select(h => h.Host)));

                // One entry per host; an empty entry means "default port" for that host.
                var portTexts = hosts
                    .Select(h => h.Port.HasValue ? h.Port.Value.ToString(CultureInfo.InvariantCulture) : string.Empty)
                    .ToList();
                if (portTexts.Any(p => p.Length > 0))
                {
                    // A single value applies to every host; otherwise emit a list the same length as the host list.
                    bool uniform = portTexts.All(p => p == portTexts[0]);
                    Add("port", uniform ? portTexts[0] : string.Join(',', portTexts));
                }
            }

            if (!string.IsNullOrEmpty(descriptor.Database))
                Add("dbname", descriptor.Database);
            if (!string.IsNullOrEmpty(descriptor.Username))
                Add("user", descriptor.Username);
            if (!string.IsNullOrEmpty(descriptor.Password))
                Add("password", descriptor.Password);
            if (descriptor.SslMode.HasValue)
                Add("sslmode", FormatSslMode(descriptor.SslMode.Value));
            if (!string.IsNullOrEmpty(descriptor.ApplicationName))
                Add("application_name", descriptor.ApplicationName);
            if (descriptor.TimeoutSeconds.HasValue)
                Add("connect_timeout", descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture));

            // Extra properties, skipping keys that were already emitted above (or earlier in this loop).
            if (descriptor.Properties is not null)
            {
                foreach (var kv in descriptor.Properties)
                {
                    if (!emitted.Contains(kv.Key))
                        Add(kv.Key, kv.Value);
                }
            }

            return Result.Ok(string.Join(' ', parts));
        }
        catch (Exception ex)
        {
            return Result.Fail<string>(ex.Message);
        }
    }

    /// <summary>
    /// Adds the endpoints described by <c>host</c> (or, if it yields no entries, <c>hostaddr</c>)
    /// and <c>port</c> to <paramref name="builder"/>.
    /// </summary>
    /// <remarks>
    /// A single port applies to every host; a port list whose length equals the host count is paired
    /// by position. Any other port-list length leaves all ports unset, as do empty or unparsable entries.
    /// </remarks>
    /// <returns><c>true</c> when the <c>hostaddr</c> list was used as the endpoint list.</returns>
    private static bool AddHosts(ConnectionDescriptorBuilder builder, Dictionary<string, string> dict)
    {
        bool fromHostaddr = false;
        string[] hosts = SplitHostList(dict, "host");
        if (hosts.Length == 0)
        {
            hosts = SplitHostList(dict, "hostaddr");
            fromHostaddr = hosts.Length > 0;
        }
        if (hosts.Length == 0)
            return false;

        ushort?[] ports = dict.TryGetValue("port", out var portStr)
            ? ParsePortList(portStr)
            : Array.Empty<ushort?>();

        for (int i = 0; i < hosts.Length; i++)
        {
            ushort? port = ports.Length == 1 ? ports[0]
                : ports.Length == hosts.Length ? ports[i]
                : null;
            builder.AddHost(hosts[i], port);
        }
        return fromHostaddr;
    }

    /// <summary>Splits a comma-separated host list, dropping blank entries; empty if the key is absent or blank.</summary>
    private static string[] SplitHostList(Dictionary<string, string> dict, string key) =>
        dict.TryGetValue(key, out var value) && !string.IsNullOrWhiteSpace(value)
            ? value.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries)
            : Array.Empty<string>();

    /// <summary>
    /// Splits a comma-separated port list, keeping empty entries so positions line up with hosts.
    /// Empty or unparsable entries become <c>null</c>.
    /// </summary>
    private static ushort?[] ParsePortList(string value) =>
        value.Split(',', StringSplitOptions.TrimEntries)
            .Select(s => ushort.TryParse(s, NumberStyles.Integer, CultureInfo.InvariantCulture, out var p)
                ? p
                : (ushort?)null)
            .ToArray();

    /// <summary>Maps a libpq <c>sslmode</c> value (case-insensitive) to <see cref="SslMode"/>.</summary>
    /// <exception cref="ArgumentException">The value is not a recognised SSL mode.</exception>
    private static SslMode ParseSslMode(string s)
    {
        return s.Trim().ToLowerInvariant() switch
        {
            "disable" => SslMode.Disable,
            "allow" => SslMode.Allow,
            "prefer" => SslMode.Prefer,
            "require" => SslMode.Require,
            "verify-ca" => SslMode.VerifyCA,
            "verify-full" => SslMode.VerifyFull,
            _ => throw new ArgumentException($"Not a valid SSL mode: {s}")
        };
    }

    /// <summary>Maps <see cref="SslMode"/> to its libpq keyword; unknown values fall back to "prefer".</summary>
    private static string FormatSslMode(SslMode mode)
    {
        return mode switch
        {
            SslMode.Disable => "disable",
            SslMode.Allow => "allow",
            SslMode.Prefer => "prefer",
            SslMode.Require => "require",
            SslMode.VerifyCA => "verify-ca",
            SslMode.VerifyFull => "verify-full",
            _ => "prefer"
        };
    }

    /// <summary>Formats <c>key=value</c>, single-quoting and escaping the value when required.</summary>
    private static string FormatPair(string key, string? value)
    {
        value ??= string.Empty;
        if (NeedsQuoting(value))
            return key + "='" + EscapeValue(value) + "'";
        return key + "=" + value;
    }

    /// <summary>True for empty values and values containing whitespace, '=', a quote, or a backslash.</summary>
    private static bool NeedsQuoting(string value)
    {
        if (value.Length == 0) return true;
        foreach (var c in value)
        {
            if (char.IsWhiteSpace(c) || c == '=' || c == '\'' || c == '\\')
                return true;
        }
        return false;
    }

    /// <summary>Backslash-escapes single quotes and backslashes for use inside a quoted value.</summary>
    private static string EscapeValue(string value)
    {
        var sb = new StringBuilder();
        foreach (var c in value)
        {
            if (c == '\'' || c == '\\') sb.Append('\\');
            sb.Append(c);
        }
        return sb.ToString();
    }

    /// <summary>Mutable accumulator used while parsing; <see cref="Build"/> produces the descriptor.</summary>
    private sealed class ConnectionDescriptorBuilder
    {
        public List<HostEndpoint> Hosts { get; } = new();
        public string? Database { get; set; }
        public string? Username { get; set; }
        public string? Password { get; set; }
        public SslMode? SslMode { get; set; }
        public string? ApplicationName { get; set; }
        public int? TimeoutSeconds { get; set; }
        public Dictionary<string, string> Properties { get; } = new(StringComparer.OrdinalIgnoreCase);

        /// <summary>Adds a trimmed host entry; blank host names are ignored.</summary>
        public void AddHost(string host, ushort? port)
        {
            if (string.IsNullOrWhiteSpace(host)) return;
            Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port });
        }

        public ConnectionDescriptor Build()
        {
            return new ConnectionDescriptor
            {
                Hosts = Hosts,
                Database = Database,
                Username = Username,
                Password = Password,
                SslMode = SslMode,
                ApplicationName = ApplicationName,
                TimeoutSeconds = TimeoutSeconds,
                Properties = Properties
            };
        }
    }
}