using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using FluentResults;
using Npgsql;
namespace pgLabII.PgUtils.ConnectionStrings;
/// <summary>
/// Parser/formatter for Npgsql-style .NET connection strings. We intentionally do not
/// rely on NpgsqlConnectionStringBuilder here because:
/// - We need a lossless, format-agnostic round-trip to our ConnectionDescriptor, including
/// unknown/extension keys and per-host port lists. NpgsqlConnectionStringBuilder normalizes
/// names, may drop unknown keys or coerce values, which breaks lossless conversions.
/// - We support multi-host with per-host ports and want to preserve the original textual
/// representation across conversions. The builder flattens/rewrites these details.
/// - We aim to keep pgLabII.PgUtils independent from Npgsql's evolving parsing rules and
/// version-specific behaviors to ensure stable UX and deterministic tests.
/// - We need symmetric formatting matching our other codecs (libpq/URL/JDBC) and consistent
/// quoting rules across formats.
/// If required, we still reference Npgsql for enums and interop types, but parsing/formatting
/// is done by this small, well-tested custom codec for full control and stability.
/// </summary>
public sealed class NpgsqlCodec : IConnectionStringCodec
{
    /// <summary>
    /// Keys (including aliases) that <see cref="TryParse"/> maps onto dedicated descriptor
    /// fields. Any key not listed here is preserved verbatim in the descriptor's Properties.
    /// </summary>
    private static readonly HashSet<string> MappedKeys = new(StringComparer.OrdinalIgnoreCase)
    {
        "Host", "Server", "Servers", "Port", "Database", "Db", "Initial Catalog", "dbname",
        "Username", "User ID", "User", "UID", "Password", "PWD", "Application Name", "ApplicationName",
        "Timeout", "Connect Timeout", "Connection Timeout", "SSL Mode", "SslMode", "SSLMode"
    };

    /// <summary>The connection string format handled by this codec.</summary>
    public ConnStringFormat Format => ConnStringFormat.Npgsql;

    /// <summary>Display name of the format handled by this codec.</summary>
    public string FormatName => "Npgsql";

    /// <summary>
    /// Parses a semicolon-separated <c>key=value</c> connection string into a descriptor.
    /// </summary>
    /// <param name="input">The connection string text.</param>
    /// <returns>
    /// A successful result carrying the descriptor, or a failed result carrying the error
    /// message (for example when the SSL mode value is not recognised or the input is null).
    /// </returns>
    public Result<ConnectionDescriptor> TryParse(string input)
    {
        if (input is null)
        {
            return Result.Fail<ConnectionDescriptor>("Connection string must not be null.");
        }

        try
        {
            var dict = Tokenize(input);
            var descriptor = new ConnectionDescriptorBuilder();

            // Hosts and ports. "Host" wins over "Server", which wins over "Servers".
            if (TryGetFirst(dict, out var hostVal, "Host", "Server", "Servers"))
            {
                var hosts = SplitList(hostVal).ToList();
                var portsPerHost = dict.TryGetValue("Port", out var portVal)
                    ? ResolvePorts(SplitList(portVal).ToList(), hosts.Count)
                    : new List<ushort?>();

                for (int i = 0; i < hosts.Count; i++)
                {
                    // Hosts without a resolved port are added with a null port.
                    ushort? port = i < portsPerHost.Count ? portsPerHost[i] : null;
                    descriptor.AddHost(hosts[i], port);
                }
            }

            // Standard fields; the first alias present in the input wins.
            if (TryGetFirst(dict, out var db, "Database", "Db", "Initial Catalog", "dbname"))
            {
                descriptor.Database = db;
            }

            if (TryGetFirst(dict, out var user, "Username", "User ID", "User", "UID"))
            {
                descriptor.Username = user;
            }

            if (TryGetFirst(dict, out var pass, "Password", "PWD"))
            {
                descriptor.Password = pass;
            }

            if (TryGetFirst(dict, out var app, "Application Name", "ApplicationName"))
            {
                descriptor.ApplicationName = app;
            }

            // A timeout that is not a valid integer is silently ignored.
            if (TryGetFirst(dict, out var timeout, "Timeout", "Connect Timeout", "Connection Timeout")
                && int.TryParse(timeout, NumberStyles.Integer, CultureInfo.InvariantCulture, out var seconds))
            {
                descriptor.TimeoutSeconds = seconds;
            }

            // An unrecognised SSL mode throws and is reported as a failed result below.
            if (TryGetFirst(dict, out var ssl, "SSL Mode", "SslMode", "SSLMode"))
            {
                descriptor.SslMode = ParseSslMode(ssl);
            }

            // Preserve extras (not mapped) into Properties.
            foreach (var (key, value) in dict)
            {
                if (!MappedKeys.Contains(key))
                {
                    descriptor.Properties[key] = value;
                }
            }

            return Result.Ok<ConnectionDescriptor>(descriptor.Build());
        }
        catch (Exception ex)
        {
            return Result.Fail<ConnectionDescriptor>(ex.Message);
        }
    }

    /// <summary>
    /// Formats a descriptor as a semicolon-separated <c>key=value</c> connection string.
    /// </summary>
    /// <param name="descriptor">The descriptor to format.</param>
    /// <returns>
    /// A successful result carrying the connection string, or a failed result carrying the
    /// message of any exception raised while formatting.
    /// </returns>
    public Result<string> TryFormat(ConnectionDescriptor descriptor)
    {
        try
        {
            var parts = new List<string>();
            // Keys written for the standard fields; used to avoid re-emitting them from Properties.
            var emittedKeys = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

            void Emit(string key, string? value)
            {
                parts.Add(FormatPair(key, value));
                emittedKeys.Add(key);
            }

            if (descriptor.Hosts != null && descriptor.Hosts.Count > 0)
            {
                Emit("Host", string.Join(',', descriptor.Hosts.Select(h => h.Host)));

                var distinctPorts = descriptor.Hosts
                    .Select(h => h.Port)
                    .Where(p => p.HasValue)
                    .Select(p => p!.Value)
                    .Distinct()
                    .ToList();

                if (distinctPorts.Count == 1)
                {
                    // A single distinct port is written once and applies to every host.
                    Emit("Port", distinctPorts[0].ToString(CultureInfo.InvariantCulture));
                }
                else if (distinctPorts.Count > 1)
                {
                    // Per-host ports are only written when every host has one (1:1 list);
                    // otherwise no Port key is emitted at all.
                    var perHost = descriptor.Hosts
                        .Select(h => h.Port?.ToString(CultureInfo.InvariantCulture) ?? string.Empty)
                        .ToList();
                    if (perHost.All(s => !string.IsNullOrEmpty(s)))
                    {
                        Emit("Port", string.Join(',', perHost));
                    }
                }
            }

            if (!string.IsNullOrEmpty(descriptor.Database))
            {
                Emit("Database", descriptor.Database);
            }

            if (!string.IsNullOrEmpty(descriptor.Username))
            {
                Emit("Username", descriptor.Username);
            }

            if (!string.IsNullOrEmpty(descriptor.Password))
            {
                Emit("Password", descriptor.Password);
            }

            if (descriptor.SslMode.HasValue)
            {
                Emit("SSL Mode", FormatSslMode(descriptor.SslMode.Value));
            }

            if (!string.IsNullOrEmpty(descriptor.ApplicationName))
            {
                Emit("Application Name", descriptor.ApplicationName);
            }

            if (descriptor.TimeoutSeconds.HasValue)
            {
                Emit("Timeout", descriptor.TimeoutSeconds.Value.ToString(CultureInfo.InvariantCulture));
            }

            // Extra properties follow the standard fields; a property whose key collides
            // (case-insensitively) with an already written standard key is skipped.
            foreach (var kv in descriptor.Properties)
            {
                if (!emittedKeys.Contains(kv.Key))
                {
                    parts.Add(FormatPair(kv.Key, kv.Value));
                }
            }

            return Result.Ok<string>(string.Join(";", parts));
        }
        catch (Exception ex)
        {
            return Result.Fail<string>(ex.Message);
        }
    }

    /// <summary>Splits a comma-separated list, trimming entries and dropping empty ones.</summary>
    private static IEnumerable<string> SplitList(string s)
    {
        return s.Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries);
    }

    /// <summary>
    /// Maps the textual port list onto <paramref name="hostCount"/> hosts.
    /// A single valid port is shared by all hosts; a list with exactly one entry per host is
    /// mapped 1:1 (unparsable entries become null); any other shape yields an empty list.
    /// </summary>
    private static List<ushort?> ResolvePorts(IReadOnlyList<string> ports, int hostCount)
    {
        var resolved = new List<ushort?>(hostCount);
        if (ports.Count == 1 && TryParsePort(ports[0], out var shared))
        {
            for (int i = 0; i < hostCount; i++)
            {
                resolved.Add(shared);
            }
        }
        else if (ports.Count == hostCount)
        {
            foreach (var text in ports)
            {
                resolved.Add(TryParsePort(text, out var port) ? (ushort?)port : null);
            }
        }

        return resolved;
    }

    /// <summary>Parses a port number using the invariant culture.</summary>
    private static bool TryParsePort(string text, out ushort port)
    {
        return ushort.TryParse(text, NumberStyles.Integer, CultureInfo.InvariantCulture, out port);
    }

    /// <summary>
    /// Looks up <paramref name="keys"/> in order and returns the value of the first key present.
    /// </summary>
    /// <returns>True when a key was found; otherwise false with <paramref name="value"/> set to an empty string.</returns>
    private static bool TryGetFirst(Dictionary<string, string> dict, out string value, params string[] keys)
    {
        foreach (var key in keys)
        {
            if (dict.TryGetValue(key, out var found))
            {
                value = found;
                return true;
            }
        }

        value = string.Empty;
        return false;
    }

    /// <summary>
    /// Converts an SSL mode name to the enum value. Matching ignores case and surrounding
    /// whitespace; both hyphenated and unhyphenated "verify" spellings are accepted.
    /// </summary>
    /// <exception cref="ArgumentException">The text is not a recognised SSL mode.</exception>
    private static SslMode ParseSslMode(string s)
    {
        return s.Trim().ToLowerInvariant() switch
        {
            "disable" => SslMode.Disable,
            "allow" => SslMode.Allow,
            "prefer" => SslMode.Prefer,
            "require" => SslMode.Require,
            "verify-ca" or "verifyca" => SslMode.VerifyCA,
            "verify-full" or "verifyfull" => SslMode.VerifyFull,
            _ => throw new ArgumentException($"Not a valid SSL Mode: {s}")
        };
    }

    /// <summary>
    /// Converts an SSL mode to the name written into the connection string.
    /// Values not listed below are written as "Prefer".
    /// </summary>
    private static string FormatSslMode(SslMode mode)
    {
        return mode switch
        {
            SslMode.Disable => "Disable",
            SslMode.Allow => "Allow",
            SslMode.Prefer => "Prefer",
            SslMode.Require => "Require",
            SslMode.VerifyCA => "VerifyCA",
            SslMode.VerifyFull => "VerifyFull",
            _ => "Prefer"
        };
    }

    /// <summary>
    /// Formats one <c>key=value</c> pair. Values that need it are wrapped in double quotes
    /// with embedded double quotes doubled; a null value is written as an empty quoted value.
    /// </summary>
    private static string FormatPair(string key, string? value)
    {
        value ??= string.Empty;
        return NeedsQuoting(value)
            ? key + "=\"" + EscapeQuoted(value) + "\""
            : key + "=" + value;
    }

    /// <summary>
    /// Returns true for empty values and for values containing whitespace, ';', '=' or '"'.
    /// </summary>
    private static bool NeedsQuoting(string value)
    {
        if (value.Length == 0)
        {
            return true;
        }

        foreach (var c in value)
        {
            if (char.IsWhiteSpace(c) || c == ';' || c == '=' || c == '"')
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>Doubles every double quote so the value can be placed inside a quoted value.</summary>
    private static string EscapeQuoted(string value)
    {
        return value.Replace("\"", "\"\"");
    }

    /// <summary>
    /// Splits a connection string into key/value pairs. Pairs are separated by semicolons;
    /// a value may be wrapped in double quotes, in which case a doubled quote is a literal quote.
    /// Keys are compared case-insensitively and a later duplicate overwrites an earlier one.
    /// Empty segments, segments without '=' and pairs with an empty key are ignored.
    /// </summary>
    private static Dictionary<string, string> Tokenize(string input)
    {
        var dict = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
        int i = 0;

        while (true)
        {
            // Skip whitespace and stray separators so ";;" or a leading ';' never end up in a key.
            while (i < input.Length && (char.IsWhiteSpace(input[i]) || input[i] == ';'))
            {
                i++;
            }

            if (i >= input.Length)
            {
                break;
            }

            // Read the key; it ends at '='. Hitting ';' first means the segment has no '='.
            int keyStart = i;
            while (i < input.Length && input[i] != '=' && input[i] != ';')
            {
                i++;
            }

            if (i >= input.Length)
            {
                break; // trailing text without '=' is ignored
            }

            if (input[i] == ';')
            {
                i++; // segment without '=' is ignored
                continue;
            }

            var key = input.Substring(keyStart, i - keyStart).Trim();
            i++; // skip '='

            while (i < input.Length && char.IsWhiteSpace(input[i]))
            {
                i++;
            }

            // Read the value: quoted (taken verbatim) or unquoted up to the next ';' (trimmed).
            string value;
            if (i < input.Length && input[i] == '"')
            {
                value = ReadQuotedValue(input, ref i);
            }
            else
            {
                int valStart = i;
                while (i < input.Length && input[i] != ';')
                {
                    i++;
                }

                value = input.Substring(valStart, i - valStart).Trim();
            }

            if (key.Length > 0)
            {
                dict[key] = value;
            }

            // Discard anything between a closing quote and the separator, then consume one ';'.
            while (i < input.Length && input[i] != ';')
            {
                i++;
            }

            if (i < input.Length)
            {
                i++;
            }
        }

        return dict;
    }

    /// <summary>
    /// Reads a double-quoted value starting at the opening quote at <paramref name="pos"/>.
    /// On return <paramref name="pos"/> is just past the closing quote, or at the end of the
    /// input when the quote is never closed (the text read so far is still returned).
    /// </summary>
    private static string ReadQuotedValue(string input, ref int pos)
    {
        pos++; // skip opening quote
        var sb = new StringBuilder();
        while (pos < input.Length)
        {
            char c = input[pos++];
            if (c != '"')
            {
                sb.Append(c);
                continue;
            }

            if (pos < input.Length && input[pos] == '"')
            {
                // doubled quote -> literal quote
                sb.Append('"');
                pos++;
                continue;
            }

            break; // end of quoted value
        }

        return sb.ToString();
    }

    /// <summary>Mutable accumulator used while parsing; <see cref="Build"/> produces the descriptor.</summary>
    private sealed class ConnectionDescriptorBuilder
    {
        public List<HostEndpoint> Hosts { get; } = new();
        public string? Database { get; set; }
        public string? Username { get; set; }
        public string? Password { get; set; }
        public SslMode? SslMode { get; set; }
        public string? ApplicationName { get; set; }
        public int? TimeoutSeconds { get; set; }
        public Dictionary<string, string> Properties { get; } = new(StringComparer.OrdinalIgnoreCase);

        /// <summary>Adds a host with an optional port; null, empty or whitespace-only host names are ignored.</summary>
        public void AddHost(string host, ushort? port)
        {
            if (string.IsNullOrWhiteSpace(host))
            {
                return;
            }

            Hosts.Add(new HostEndpoint { Host = host.Trim(), Port = port });
        }

        /// <summary>Creates the descriptor from the values collected so far.</summary>
        public ConnectionDescriptor Build()
        {
            return new ConnectionDescriptor
            {
                Hosts = Hosts,
                Database = Database,
                Username = Username,
                Password = Password,
                SslMode = SslMode,
                ApplicationName = ApplicationName,
                TimeoutSeconds = TimeoutSeconds,
                Properties = Properties
            };
        }
    }
}