diff --git a/.build/ca.crt b/.build/ca.crt new file mode 100644 index 0000000000..e5a4081a02 --- /dev/null +++ b/.build/ca.crt @@ -0,0 +1,31 @@ +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIUB/AJgMX+fmeXvBOUWW7WR+XKZ6AwDQYJKoZIhvcNAQEL +BQAwRTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDAeFw0yNDExMjAwNDExMjFaFw0zNDEx +MTgwNDExMjFaMEUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIDApTb21lLVN0YXRlMSEw +HwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQwggIiMA0GCSqGSIb3DQEB +AQUAA4ICDwAwggIKAoICAQC8A5+//15VxRCxpHzl7srYx6uWQi1/7q5VFWFZab+7 +82PLr3pV/zMjSMEPBZdq46NWWNnXIFoFHd5MFnN4fNIQ1GIEsTF0kYy142qllnp3 +vLBVBu24n4dsmI8ygl8+1PuGwk45Mz+vOL+RjNIo6ra9yJzYFnZOGCqlt0kWkCau +HR/43ms0vhKq8FaDXPdVXn9Z3EZScxRKQwlfAKOUxLQ8dVkzvRuAm0PF74afRYfg +xiGIX8msFYKzGnWb7ezcag125iEqg+xSplo6QK6vaNURlKwYQ8ZRKz1Hk1oIB4t1 +iEJL2d4nzgTkh/jlVjtTXo6cw96WT9NBT0Rg6JR4PJySlhY+ZwLi6VAxQZ8GyJo4 +YTvx1K3vhXeokKjFTxUtZdx1blX5vCBXv9LCxnjAsBCTRzE425x6UP1gp721gHGW +sqopvkUgN9vk8oigyWLeGvwsBwFTFnY672iCYXhFHs2oKTIX8yo+A2xRr8tewb9C +IsqJSC6JkLs5zbVwKdgVx1H21Uwvi7XjKir9pPp/ks12r9GNMmWc265PK1kCqCHa +oHfgzYMVVFQ3CfYbeeA8/aVf770AfC/1v+VtMse8DEqyep5q0OzOXtWIQlahYiyA +FLTzCBqcHUuRZtS4gEhOk6/Pk1HP3faUC1xGgxO5c/pd7SVMfs+Z58WJbYGFcAlC ++QIDAQABo1MwUTAdBgNVHQ4EFgQUBeKaoc7AMURxdajJ+CF8YrUsdFgwHwYDVR0j +BBgwFoAUBeKaoc7AMURxdajJ+CF8YrUsdFgwDwYDVR0TAQH/BAUwAwEB/zANBgkq +hkiG9w0BAQsFAAOCAgEAGGpFZm0c36Eh5E8QiAg8+8U22Ao+YoF6nJnIlc/ri1pt +J5zXRM2DbCCR9uN5yckmCNIJ4PZO49QBflYGPAkF+Vd0RJYoA4k1Cq+eYcJBWtXl +ESJxeg1QAKAZ4XSasOIijebWlPIZxPGOy8HquKNMDQIm8a7g5zSE4UNJPVY3y9on +zJT7ZhntIwuM8IP6h6gotJfxBHJRWNe/g0zVITQ7vHnxSpobLbuKfY21GLl6clgI +WsePKWWo/mZYquqZz72KBUJ66YX4X7nJCvZs1sLgMnXh87n9hsxAdFlRgLuQ4ztp +mwQbDZ90mJFQLprI4rfyamuloIgOcn05yXfklRAI2P8L2/yf5xNAy+ii0OHRiMVv +jnYUet8Bca1orh7OQ9ol1XTBoCI1gknrdG5Y2IQvQhWLiS5AjIwwQYwjkSFXELtF +X8v9Fv758RA9CFlQDnsp9awNjdLss/TdH6+dNYQfTNGigIPM6oCk5nrcQqF/533W +z2WM0LNHAiQlEn0X38D0wCuRwIVzPG/AFyfsf50vSlH81/uzpyR5q3SJA8OKiCV1 
+/OiW7Jv7pOtwqFjxR+m31TqaPM6PLrdasP/CNKSvGuJmtaHK4Wkc3YU9dbtQffzB +MUFwhi233gvE+nSEixse2KlzsrBVZIdz16bZXaAd20JQdq9Hceku2uVgfN1fycI= +-----END CERTIFICATE----- diff --git a/.build/server.crt b/.build/server.crt index d161ab2652..5a2bfc7b01 100644 --- a/.build/server.crt +++ b/.build/server.crt @@ -1,20 +1,30 @@ -----BEGIN CERTIFICATE----- -MIIDUjCCAjoCFAwuj6RwuZSjCGYHja8m9tbr3nFeMA0GCSqGSIb3DQEBCwUAMGgx -EzARBgNVBAoTCk15IENvbXBhbnkxCzAJBgNVBAsTAklUMRAwDgYDVQQHEwdNeSBU -b3duMQ8wDQYDVQQIEwZNb3Njb3cxCzAJBgNVBAYTAlJVMRQwEgYDVQQDEwtsb2Nh -bGhvc3RDQTAeFw0yMTA0MTAxMzA0MDBaFw0yMjA0MTAxMzA0MDBaMGMxEzARBgNV -BAoTCk15IENvbXBhbnkxCzAJBgNVBAsTAklUMRAwDgYDVQQHEwdNeSBUb3duMQ8w -DQYDVQQIEwZNb3Njb3cxCzAJBgNVBAYTAlJVMQ8wDQYDVQQDEwZzZXJ2ZXIwggEi -MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC8LoQbo2DFwC17gZwJ8xrPKHGX -UKxoo5UcyZ3/2zZ006TYkswssejKksuiICTMI89OD8n55pNTZkXPUH7oR2oIyxTY -SiWPiNzbEh0FOxH9Kh5gmajqM/4X44OaprmyQ56m4Y2LZO2nZ9hHoe+ZRoan3+pa -g8weOM/n/wYuXZtdElOxNsB8pg09K4gevHVaLaSBCEeQfHev51vClFdN3+orBi/r -hnQF3vdw7oMT1JSH75Ray51wRaypLIslAc2DcPFTCQJMmXXMTcAcxmjAVUGrfY+d -sSCdXnOZtd7yk+0X0bVGKLBkCTOP7QpmfOVu9bOhscDiK5EoAaDKqdHSMUfhAgMB -AAEwDQYJKoZIhvcNAQELBQADggEBAKCo2Y1uKbudA8JpV6yo35tc7Z6n03++BAdq -egUBKOiE4ze7xQ7lmlt572ptqXlU/8JuPWa2Qb/wGksR0HpVPTAeU3pbXz1dcCXC -A9wCtSxapjyCYbkDrDl2FQuK0OfJi0q71JZU66D58Qu0l45nWON30to9dSiw3zPw -Rdk7X86GHYIBHKsj7mjiy1v8jH1sXeWvThOmU6+rv8UY8VuJiu4MQDdYa0Y5KFh/ -OL3tVsi7zoNu2OXY1cTKuUpKMQPbO+WSdelYromYK2OAXaNqnC27GegPqvCFWJ2I -9NZuXYj3X+j0ydZSKVjDgCda8H68olBnO0zh44XirCBef7uTVLw= +MIIFJTCCAw0CFAKjNOhsMTYUuQngy2k291XuKOGGMA0GCSqGSIb3DQEBCwUAMEUx +CzAJBgNVBAYTAlVTMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRl +cm5ldCBXaWRnaXRzIFB0eSBMdGQwHhcNMjQxMTIwMDQxOTE0WhcNMjkxMTE5MDQx +OTE0WjBZMQswCQYDVQQGEwJVUzETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UE +CgwYSW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMRIwEAYDVQQDDAlsb2NhbGhvc3Qw +ggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoICAQDlGT9vXb93yoM1YT0GAxJI +B6/2ExUrdprd049oMVZa4Km0nqwN/xjVvQRIWozmbpvps0mCkFM1ZyL1iqZFwiJG 
+WcQvvIffFM1qKRMOSTLNPCbM9mfvRKsCU9gjgatdhy8xUZhz7uFGMGADnZdlNMYW +GgzMVZo0EyW7Z2QJ+ZCl8wW5IT4iswZWrJsNZU/g7HaNBrXiidDihkmQ8Kt32R0U +nqJeXMHwkQLxddmcGdDmVCKsAEUu3NcvPeAlSJsNHfGDRsf9fImRqZCsgwI8dJtA +ke/luMTttQ34aADFTmTbVk4ngVhCxgBkJ6FUDFJcp3t3nFssiisNon9k5FwtJ3hl +e/QGM9IRdBvGVcOnZZuXXK2lLtakj5UWUik2xWA0hjX+DsFo7TPwKgZy4zmWCRob +W1e1NX52bqYFWZUKYLqbizllOd98o3yed58PhbF1/IuVEuOoiKu7rNdNgzr8vgRP +pWHQNXp3maCcZq2kWybADU2LQNUKAZLSw3nClcX8QVRAfvf8IyDZ/280EYRGu99V +qLqDPLa1+3CNAb93J1ONvVjKgJwQQWy4dYFLHTYdBzXV5SOpH8YHL/1IHs9W5k28 +BdwbeMtJnOaV8rqiA6Xd4Xem111AMAigHExxG3kpSnAq6jiOX0+2V++f7qAunuC6 +B/oJATXLCbBQILr0ARtKuQIDAQABMA0GCSqGSIb3DQEBCwUAA4ICAQCn6R2fvxfs +R7nN9g6bVNJXkJrDJ+O1suVD0tkZzxZAIAFdhKnSFocJph1bC6bSZEQkhG+0WtfU +DU7m19VDHpZWZ+8LygIVikIkvj47v1/yl7TgwkhNAKXXxl6bF/AEevMUZoxT3r8S +UBFURp8QduSQ7sbDRB9qR1EWPjAXgnedzLSGkt5E6VKuVRwsTjv7QUTV8RCbOl9b +YHtTX3dtvr3PeAB5M3B6qrbpniqJfPxUt658UKrDGFr1MuZZ8ONYpdiGH8uGXZhs +9BBjp0g0xWha9LYDYRpqzlC1hqV0J/9jz9QdS9HHPsqa8PvB/YwaDGQm/RSRMUbU +x0wip0me45WU5pLD1djEGQBlxCGgQXIJsebzipdUsayA4MgY3s2lBj2qsPOqyNoP +dFohMm2+Ypi8UAjEbeGY4XsCODLeCvPx24HyjJUORm9uuPCunSBhtgiEBTJrNwHL +F7T1+/g9gVSwCsz4MqceO7IooJ2omSpwk7xrzocccFb1HGR/tE9GxRLNHiyTfx9s +FN9SNOih5DCcOFOiw0vF1qKHk6CAJ0UCBzVWl3YO9OgnFX4FbRYHd3PduWR+fSkd +icBs2AiOKPbOU8yXR8CE6uZiDoN6A27KOE07adZEWBMwd4us7uBHGgnqqYuwPI3d +nqC8srMQ07fw8HyXn7ojPxXyCk+2d6zVgA== -----END CERTIFICATE----- diff --git a/.build/server.key b/.build/server.key index b6dd15913f..f2a7e607b2 100644 --- a/.build/server.key +++ b/.build/server.key @@ -1,27 +1,52 @@ ------BEGIN RSA PRIVATE KEY----- -MIIEowIBAAKCAQEAvC6EG6NgxcAte4GcCfMazyhxl1CsaKOVHMmd/9s2dNOk2JLM -LLHoypLLoiAkzCPPTg/J+eaTU2ZFz1B+6EdqCMsU2Eolj4jc2xIdBTsR/SoeYJmo -6jP+F+ODmqa5skOepuGNi2Ttp2fYR6HvmUaGp9/qWoPMHjjP5/8GLl2bXRJTsTbA -fKYNPSuIHrx1Wi2kgQhHkHx3r+dbwpRXTd/qKwYv64Z0Bd73cO6DE9SUh++UWsud -cEWsqSyLJQHNg3DxUwkCTJl1zE3AHMZowFVBq32PnbEgnV5zmbXe8pPtF9G1Riiw -ZAkzj+0KZnzlbvWzobHA4iuRKAGgyqnR0jFH4QIDAQABAoIBADnMS7U1dAao5Q9X 
-GrcPnP9dm63vEFU/URA7eLTZ/prZWntOczmTFz4I4lSUbNjqcsS2IsIHqN5nvi9T -uPbc4Ft9DJT2CR1R2wvKP3GY2AibBCOFbpUojPWHYqeAZ+6xyCvXgSL8R+YwBgTS -XwYD3F35b0CH1Iy/xFOsR5i8FXj7He8lOBA76fPrH64DEBTB2zUGztu4qpfv57v5 -sfTISi2ZOqPpXc+8Fw0RPeVWQgSRUh7U3lzL8bNBod6lYcjkhF5Yqet4MdHSyWMT -aKdZ2GRHHdWjpyx6J0cD/bjjaTSDqTD8r265mPzY6bq4t6UQMq4KeDnbeiextDf4 -ELT90YUCgYEA6insCSDJddhFZ51guPPyYE9GL8QQfnzLvFOA4qWsi0u9SAbJ9aS0 -vABaEuot0PyYPwMYq7st07z3DSKno4tisPJ2X7v2nEWxv8MjgczWpltPTPaEdmZE -WGIwG3pyh5wJk1b3VpBJB5jkjtJfGmUJaezU10bzm4QhPiEawemCjucCgYEAzbri -/6EZPbJJa9hGtkJEEVLwbQ2U/CE7mZXL+AcPlS3qMSwyz/1OArPxdTRR4S3sYRRO -fsRDBL8LED/kKUDWNni/zkzmFf/hVkmGd9zc6eif4Zr1gmtHlsHQdaMGxsomzxGL -qydBqDN+4TMmHmUmp2jR/0LIF5UMlNoCvHcxgfcCgYEAnOBNE6h1j4++n7Yd0IsO -PFufx+xwqGzvCVJgLHeV6xRo0NJLh1g7BSCvN7DP1Q0E6mImqxaRkyMr2A75hGWj -TqyBhY2ln/hJJxGSvij/PSA7NnKJN9E3xIazeBVGmXd+Ksm+lq2/X2mc5domgMZj -0iUqSrdsCSoyIy+Gf5bzMs0CgYBcquG044vLDpOj0DeJwS+H3iQN+yAwsYd3FtJZ -VlTejV//5ji9Fwwci5EnifmXxGfFErCIyT6m1KbXGvBa5KmYv6sl8d1x62BEzbmU -JBgeBHp/1JzhshD9BzAuzNAwmr4AZ5bR8UzRxuBP8AorhsRyg/STVjFq7ehM5CZ3 -Xfke4QKBgHCPo3R/oi/E2E7OIM/ELlDpvPQTMrV+rYlMFsy3JRvataIqEGnVbhOR -4dQHEM3u2bJxN79wUYYmZuymVB78wKxTn6hGWcGoM6Y8mrJjVv9D8V0Gc0sWw5pF -KZxuCgzjaN2T7i1LsXEV3gaQrKItToEpGPzSI23egFaG6g5SFqBt ------END RSA PRIVATE KEY----- +-----BEGIN PRIVATE KEY----- +MIIJQQIBADANBgkqhkiG9w0BAQEFAASCCSswggknAgEAAoICAQDlGT9vXb93yoM1 +YT0GAxJIB6/2ExUrdprd049oMVZa4Km0nqwN/xjVvQRIWozmbpvps0mCkFM1ZyL1 +iqZFwiJGWcQvvIffFM1qKRMOSTLNPCbM9mfvRKsCU9gjgatdhy8xUZhz7uFGMGAD +nZdlNMYWGgzMVZo0EyW7Z2QJ+ZCl8wW5IT4iswZWrJsNZU/g7HaNBrXiidDihkmQ +8Kt32R0UnqJeXMHwkQLxddmcGdDmVCKsAEUu3NcvPeAlSJsNHfGDRsf9fImRqZCs +gwI8dJtAke/luMTttQ34aADFTmTbVk4ngVhCxgBkJ6FUDFJcp3t3nFssiisNon9k +5FwtJ3hle/QGM9IRdBvGVcOnZZuXXK2lLtakj5UWUik2xWA0hjX+DsFo7TPwKgZy +4zmWCRobW1e1NX52bqYFWZUKYLqbizllOd98o3yed58PhbF1/IuVEuOoiKu7rNdN +gzr8vgRPpWHQNXp3maCcZq2kWybADU2LQNUKAZLSw3nClcX8QVRAfvf8IyDZ/280 +EYRGu99VqLqDPLa1+3CNAb93J1ONvVjKgJwQQWy4dYFLHTYdBzXV5SOpH8YHL/1I 
+Hs9W5k28BdwbeMtJnOaV8rqiA6Xd4Xem111AMAigHExxG3kpSnAq6jiOX0+2V++f +7qAunuC6B/oJATXLCbBQILr0ARtKuQIDAQABAoICAAP97y6VPnPLjgLVJxKbfssa +afz0IxG+9ZH11xrpUl6itjpNBUte8LN97jaF8DLhf9FJtZ2mWHJtODBfzw4wnldf +X/O2Y1MZbvHeXA3LHznXX9ROJ9krg/2DCsu/MIZgh5hvQLEmdK6Iw1q7LH5Pz6YA +Pea/YbPUfWGsVC0rUaBFB/C/oEnk/v0g8VIbFZIvAWrRw6oT0JWESJrGr5b9RYxm +Ljo0Mt0dyorjP/YAUI6u4R+VOp9g+Dvpv7909vfg/j2u5k20e/lgI1xdXqGnvrIx ++/4V/KwPeob9TIqJ/bTOGaFtF5j3dirImP8Yq6rsvSuqodkSSELeAor2XEsDumby +PqJY1MIO9DuZSdqf+Cofgzbd6mpeMAwueb+hfBw8AIMG3M9Xj1uDuU+tjsVA79Er +H9acPxLukGjYP5SY2Mo8hLFLLurpjtcDpYdOP2Wh7PBDwHR8anmPQru2rZXT80NY +j3fXNqnTTFbHuntmZ2qWJovmOuKocU5GEm/QCW/f6miqR9Hzc2vbWaIoEO54vcF6 +eS4iLEkAOfmakz3Sno2AXS1jJI6+2v1899cBINvgpATCMkmXnwFDwr9gNYujwlpF +Yl3QM8Vh9dnVt04oyum5x4sz/mTKj5e9O988iqlOkgID4HBVpy/dwYHsHE+XgDDY +yiFetJ/n0+45QHhyvSwBAoIBAQDnrPz2xCbR03KQwZN2DnZClLVFkZe3tZxR6UsY +63yDTrA0ZMJ8AtE/tX79/Iu7gPidNTCrVmOuelf5q3y3AMo6nlKMCc3tIKr6QtaC +99RtHq5p0T3/TS9tWbGjmxEzyx00R3wz5fSypX76qnQLHs6EmrLxFUNmsHIQS2nH +jWvT1+TdmfmogZ/9RaHyBjHGkDfTmlfEKc7/TleE9XsW+G0cGli3fIO0iY0hJTLd +b65X5Gm0URCqsZgIzD99enIvee13Gw8aUJUt8tJZXQHtOWBu491MLd2AVPQ/7eZa +tl/HtjdMj2E3n5NXTie3laRCX+p9mK6087nE7u3JqPqUXU4BAoIBAQD9Jv3hZeii +0pDgLYgiFVds5n2S4CEB4WOT9wn2vUIrYTSjgjAPfsgeJs6N6+WArwaIJrl4tTK4 +m0VjUG394plvyExU8hNZ7hw0E/33rwsKySnkwUFZtOgbsOgUjajRDfFYsqsDhLK0 +o3dY1M+mdYvU9OBo3EhgFy3fYBhtdGIq/4/3kSM6CARQIjddW2pdbB7pyv3qz0mH +6fpzPXWLIex+WBzRVEz7VPPD4coV3LEhmtdPju4RqFPbHS+OpECun8pyaNt14DRr +t216MiyJGNV74zTLELioVHlhlaPvsWnnIeI+2uhhCgQ8UvHn69x2wiAgLlx/e+RD +qPiINhm/xey5AoIBACCASjSsK+3/xfC8110Whkys5AlQdYJWPgnXuqtSTfN11I5l +HEudcZGIerpS9Z9mZnpXfe5rfix6CWGDR0m9GKHEmDwBHByKGrJlMgbJkcmFJl69 +9f6c62xhyuPy2yTy97Pf23LEbeGqCfhMdV8iAULlGPltTDlZw4a5ratLEbd0cC0O +btHO7YzwedmkONNsZAiRfIKOgvWaHfkPHyeHznbE03FaTHfFXEEsIMij5Ed8Sb/8 +J2Rq6bNCRB3sUZyLdF7jMuk0KNl7WTskKyMGi5rC6MbJIGvifymAzHIpZ6Jy06sv +6imNf3QeCMBeg96z6geYpdnI32TbSAykYhLyTAECggEAOowrCVcdX5LdaMt/AYr4 +BjqkbjShzaKH+i+XQVZyGEBKAUrZvKuwsrB88vvMv187Xn++Q3l8uo9Gk/qFBcPD 
+gsPLS5YU/aaBJVY+VWtJXXw60SoU6B9b0xOuCRreIUNdPwtLW+vzvK1Vq9jEEZZ7 ++YuM3xObNYYG2POLkrzo+1LRxArwH7q87J+NOG0tA2A/IgkNgqHgOqvVfZOIPN5i +qLHOMGeTykjSe8obh8Tbvo7mHwNKchEBG9r7Jb09LGXOV3mC0BdDaGoqyqkR/b8d +mKJqklBStLOcwwHtwUDB4m/GuIy+U7sSUbVJNz8oZNruvSKbx+wqVa+dkzsX529q +GQKCAQBVzafsrfp3yZKa62R7EMtQh6pHDIKvUzZRwxsj4QzQ1y4Rrb6ceXKxI3EQ +ZK6f1Lte+/ifRn8ZsxQOnjNzO9meOco/7CSNGCCcqO/XVN9ixDdF8lzjIsuRqfkT +lsYy7Zo+ZRDUj73UROBvBJtX4jP5It1B/ISKxHxyBFQiB+UtldLl1H+dmGN9LVnF +583i/vTEcLsj9+8yUU8L46sLKfOhNiSBY8D8oKD9Yht0p9SeDxB/r4Rq8Te5Xp1o +FobswNohYBj2rj9+d24uMcpI5nx33JoRkW7VyAXsq8t4b7ei5/sbwuL25NUXhIxf +mMKDxHebdrFY2ADhWLkWus0ik7JA +-----END PRIVATE KEY----- diff --git a/.devcontainer/db/Dockerfile b/.devcontainer/db/Dockerfile index 64cc3febb1..85efd91832 100644 --- a/.devcontainer/db/Dockerfile +++ b/.devcontainer/db/Dockerfile @@ -1,3 +1,3 @@ -FROM postgres +FROM postgres:17 RUN apt-get update && \ - apt-get install -y --no-install-recommends openssl postgresql-16-postgis-3 + apt-get install -y --no-install-recommends openssl postgresql-17-postgis-3 diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 69115b9ad3..1efa473614 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -5,30 +5,37 @@ "workspaceFolder": "/workspace", - "settings": { - "terminal.integrated.profiles.linux": { - "bash": { - "path": "/bin/bash" - } - }, - "terminal.integrated.defaultProfile.linux": "bash", - "remote.extensionKind": { - "ms-azuretools.vscode-docker": "workspace" + "features": { + "ghcr.io/devcontainers/features/dotnet:2": { + "version": "10.0", + "dotnetRuntimeVersions": "8.0,9.0" } }, - "extensions": [ - "ms-dotnettools.csharp", - "formulahendry.dotnet-test-explorer", - "ms-azuretools.vscode-docker", - "mutantdino.resourcemonitor" - ], + "customizations": { + "vscode": { + "settings": { + "terminal.integrated.profiles.linux": { + "bash": { + "path": "/bin/bash" + } + }, + "terminal.integrated.defaultProfile.linux": "bash", + 
"remote.extensionKind": { + "ms-azuretools.vscode-docker": "workspace" + } + }, - "forwardPorts": [5432, 5050], - - "remoteEnv": { - "DeveloperBuild": "True" + "extensions": [ + "ms-dotnettools.csharp", + "ms-dotnettools.csdevkit", + "ms-azuretools.vscode-docker", + "mutantdino.resourcemonitor" + ] + } }, - "postCreateCommand": "dotnet restore Npgsql.sln" + "forwardPorts": [5432, 5050], + + "postCreateCommand": "dotnet restore Npgsql.slnx" } \ No newline at end of file diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index e956e66c24..3926f919de 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -2,8 +2,7 @@ version: '3' services: npgsql-dev: - # Source for tags: https://mcr.microsoft.com/v2/dotnet/sdk/tags/list - image: mcr.microsoft.com/dotnet/sdk:8.0.100 + image: mcr.microsoft.com/devcontainers/base:ubuntu volumes: - ..:/workspace:cached tty: true diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 78cf887e9d..29158fb563 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -9,13 +9,15 @@ on: - '*' pull_request: +permissions: + contents: read + # Cancel previous PR branch commits (head_ref is only defined on PRs) concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true env: - dotnet_sdk_version: '8.0.100' postgis_version: 3 DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true # Windows comes with PG pre-installed, and defines the PGPASSWORD environment variable. 
Remove it as it interferes @@ -29,35 +31,53 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-22.04, windows-2022] - pg_major: [16, 15, 14, 13, 12] + os: [ubuntu-24.04] + pg_major: [18, 17, 16, 15, 14] config: [Release] - test_tfm: [net8.0] + test_tfm: [net10.0] include: - - os: ubuntu-22.04 - pg_major: 16 + - os: ubuntu-24.04 + pg_major: 18 config: Debug - test_tfm: net8.0 - - os: macos-12 - pg_major: 14 + test_tfm: net10.0 + - os: macos-15 + pg_major: 16 config: Release - test_tfm: net8.0 -# - os: ubuntu-22.04 -# pg_major: 17 + test_tfm: net10.0 + - os: windows-2022 + pg_major: 18 + config: Release + test_tfm: net10.0 + + # Minimal support TFM build +# - os: ubuntu-24.04 +# pg_major: 18 # config: Release # test_tfm: net8.0 + + # PG prerelease build +# - os: ubuntu-24.04 +# pg_major: 19 +# config: Release +# test_tfm: net10.0 # pg_prerelease: 'PG Prerelease' outputs: is_release: ${{ steps.analyze_tag.outputs.is_release }} is_prerelease: ${{ steps.analyze_tag.outputs.is_prerelease }} + # Installing PostGIS on Windows is complicated/unreliable, so we don't test on it. + # The NPGSQL_TEST_POSTGIS environment variable ensures that if PostGIS isn't installed, + # the PostGIS tests fail and therefore fail the build. 
+ env: + NPGSQL_TEST_POSTGIS: ${{ !startsWith(matrix.os, 'windows') }} + steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: NuGet Cache - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: ~/.nuget/packages key: ${{ runner.os }}-nuget-${{ hashFiles('**/Directory.Build.targets') }} @@ -65,10 +85,7 @@ jobs: ${{ runner.os }}-nuget- - name: Setup .NET Core SDK - uses: actions/setup-dotnet@v4.0.0 - with: - dotnet-version: | - ${{ env.dotnet_sdk_version }} + uses: actions/setup-dotnet@v5.1.0 - name: Build run: dotnet build -c ${{ matrix.config }} @@ -80,17 +97,16 @@ jobs: # First uninstall any PostgreSQL installed on the image dpkg-query -W --showformat='${Package}\n' 'postgresql-*' | xargs sudo dpkg -P postgresql - # Import the repository signing key - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - - - sudo sh -c 'echo "deb http://apt.postgresql.org/pub/repos/apt/ jammy-pgdg main ${{ matrix.pg_major }}" >> /etc/apt/sources.list.d/pgdg.list' + # Automated repository configuration + sudo apt install -y postgresql-common + sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh -v ${{ matrix.pg_major }} -y sudo apt-get update -qq sudo apt-get install -qq postgresql-${{ matrix.pg_major }} export PGDATA=/etc/postgresql/${{ matrix.pg_major }}/main - sudo cp $GITHUB_WORKSPACE/.build/{server.crt,server.key} $PGDATA - sudo chmod 600 $PGDATA/{server.crt,server.key} - sudo chown postgres $PGDATA/{server.crt,server.key} + sudo cp $GITHUB_WORKSPACE/.build/{server.crt,server.key,ca.crt} $PGDATA + sudo chmod 600 $PGDATA/{server.crt,server.key,ca.crt} + sudo chown postgres $PGDATA/{server.crt,server.key,ca.crt} # Create npgsql_tests user with md5 password 'npgsql_tests' sudo -u postgres psql -c "CREATE USER npgsql_tests SUPERUSER PASSWORD 'md5adf74603a5772843f53e812f03dacb02'" @@ -99,9 +115,9 @@ jobs: sudo -u postgres psql -c "CREATE USER npgsql_tests_nossl SUPERUSER PASSWORD 
'npgsql_tests_nossl'" # To disable PostGIS for prereleases (because it usually isn't available until late), surround with the following: - if [ -z "${{ matrix.pg_prerelease }}" ]; then + #if [ -z "${{ matrix.pg_prerelease }}" ]; then sudo apt-get install -qq postgresql-${{ matrix.pg_major }}-postgis-${{ env.postgis_version }} - fi + #fi if [ ${{ matrix.pg_major }} -ge 14 ]; then sudo sed -i "s|unix_socket_directories = '/var/run/postgresql'|unix_socket_directories = '/var/run/postgresql, @/npgsql_unix'|" $PGDATA/postgresql.conf @@ -109,6 +125,7 @@ jobs: sudo sed -i 's/max_connections = 100/max_connections = 500/' $PGDATA/postgresql.conf sudo sed -i 's/#ssl = off/ssl = on/' $PGDATA/postgresql.conf + sudo sed -i "s|ssl_ca_file =|ssl_ca_file = '$PGDATA/ca.crt' #|" $PGDATA/postgresql.conf sudo sed -i "s|ssl_cert_file =|ssl_cert_file = '$PGDATA/server.crt' #|" $PGDATA/postgresql.conf sudo sed -i "s|ssl_key_file =|ssl_key_file = '$PGDATA/server.key' #|" $PGDATA/postgresql.conf sudo sed -i 's/#password_encryption = md5/password_encryption = scram-sha-256/' $PGDATA/postgresql.conf @@ -136,7 +153,7 @@ jobs: sudo -u postgres psql -c "CREATE USER npgsql_tests_scram SUPERUSER PASSWORD 'npgsql_tests_scram'" # Uncomment the following to SSH into the agent running the build (https://github.com/mxschmitt/action-tmate) - #- uses: actions/checkout@v4 + #- uses: actions/checkout@v6 #- name: Setup tmate session # uses: mxschmitt/action-tmate@v3 @@ -159,29 +176,7 @@ jobs: unzip pgsql.zip -x 'pgsql/include/**' 'pgsql/doc/**' 'pgsql/pgAdmin 4/**' 'pgsql/StackBuilder/**' # Match Npgsql CI Docker image and stash one level up - cp $GITHUB_WORKSPACE/.build/{server.crt,server.key} pgsql - - # Find OSGEO version number - OSGEO_VERSION=$(\ - curl -Ls https://download.osgeo.org/postgis/windows/pg${{ matrix.pg_major }} | - sed -n 's/.*>postgis-bundle-pg${{ matrix.pg_major }}-\(${{ env.postgis_version }}.[0-9]*.[0-9]*\)x64.zip<.*/\1/p' | - tail -n 1) - if [ -z "$OSGEO_VERSION" ]; then - 
OSGEO_VERSION=$(\ - curl -Ls https://download.osgeo.org/postgis/windows/pg${{ matrix.pg_major }}/archive | - sed -n 's/.*>postgis-bundle-pg${{ matrix.pg_major }}-\(${{ env.postgis_version }}.[0-9]*.[0-9]*\)x64.zip<.*/\1/p' | - tail -n 1) - POSTGIS_PATH="archive/" - else - POSTGIS_PATH="" - fi - - # Install PostGIS - echo "Installing PostGIS (version: ${OSGEO_VERSION})" - POSTGIS_FILE="postgis-bundle-pg${{ matrix.pg_major }}-${OSGEO_VERSION}x64" - curl -o postgis.zip -L https://download.osgeo.org/postgis/windows/pg${{ matrix.pg_major }}/${POSTGIS_PATH}${POSTGIS_FILE}.zip - unzip postgis.zip -d postgis - cp -a postgis/$POSTGIS_FILE/. pgsql/ + cp $GITHUB_WORKSPACE/.build/{server.crt,server.key,ca.crt} pgsql # Start PostgreSQL pgsql/bin/initdb -D pgsql/PGDATA -E UTF8 -U postgres @@ -195,7 +190,7 @@ jobs: sed -i "s|#synchronous_standby_names =|synchronous_standby_names = 'npgsql_test_sync_standby' #|" pgsql/PGDATA/postgresql.conf sed -i "s|#synchronous_commit =|synchronous_commit = local #|" pgsql/PGDATA/postgresql.conf sed -i "s|#max_prepared_transactions = 0|max_prepared_transactions = 100|" pgsql/PGDATA/postgresql.conf - pgsql/bin/pg_ctl -D pgsql/PGDATA -l logfile -o '-c ssl=true -c ssl_cert_file=../server.crt -c ssl_key_file=../server.key' start + pgsql/bin/pg_ctl -D pgsql/PGDATA -l logfile -o '-c ssl=true -c ssl_cert_file=../server.crt -c ssl_key_file=../server.key -c ssl_ca_file=../ca.crt' start # Create npgsql_tests user with md5 password 'npgsql_tests' pgsql/bin/psql -U postgres -c "CREATE ROLE npgsql_tests SUPERUSER LOGIN PASSWORD 'md5adf74603a5772843f53e812f03dacb02'" @@ -210,7 +205,7 @@ jobs: sed -i "s|#password_encryption = md5|password_encryption = scram-sha-256|" pgsql/PGDATA/postgresql.conf fi - pgsql/bin/pg_ctl -D pgsql/PGDATA -l logfile -o '-c ssl=true -c ssl_cert_file=../server.crt -c ssl_key_file=../server.key' restart + pgsql/bin/pg_ctl -D pgsql/PGDATA -l logfile -o '-c ssl=true -c ssl_cert_file=../server.crt -c ssl_key_file=../server.key -c 
ssl_ca_file=../ca.crt' restart pgsql/bin/psql -U postgres -c "CREATE ROLE npgsql_tests_scram SUPERUSER LOGIN PASSWORD 'npgsql_tests_scram'" @@ -232,15 +227,20 @@ jobs: - name: Start PostgreSQL ${{ matrix.pg_major }} (MacOS) if: startsWith(matrix.os, 'macos') run: | - PGDATA=/usr/local/var/postgresql@${{ matrix.pg_major }} + brew update + brew install postgresql@${{ matrix.pg_major }} + + PGDATA=/opt/homebrew/var/postgresql@${{ matrix.pg_major }} sudo sed -i '' 's/#ssl = off/ssl = on/' $PGDATA/postgresql.conf - cp $GITHUB_WORKSPACE/.build/{server.crt,server.key} $PGDATA - chmod 600 $PGDATA/{server.crt,server.key} + sudo sed -i '' "s/#ssl_ca_file =/ssl_ca_file = 'ca.crt' #/" $PGDATA/postgresql.conf + cp $GITHUB_WORKSPACE/.build/{server.crt,server.key,ca.crt} $PGDATA + chmod 600 $PGDATA/{server.crt,server.key,ca.crt} - postgreService=$(brew services list | grep -oe "postgresql\S*") + postgreService=$(brew services list | grep -oe "postgresql@${{ matrix.pg_major }}\S*") brew services start $postgreService + export PATH="/opt/homebrew/opt/postgresql@${{ matrix.pg_major }}/bin:$PATH" echo "Check PostgreSQL service is running" i=5 COMMAND='pg_isready' @@ -301,10 +301,19 @@ jobs: # TODO: Once test/Npgsql.Specification.Tests work, switch to just testing on the solution - name: Test run: | - dotnet test -c ${{ matrix.config }} -f ${{ matrix.test_tfm }} test/Npgsql.Tests --logger "GitHubActions;report-warnings=false" + dotnet test -c ${{ matrix.config }} -f ${{ matrix.test_tfm }} test/Npgsql.Tests --logger "GitHubActions;report-warnings=false" --blame-crash --blame-hang-timeout 30s dotnet test -c ${{ matrix.config }} -f ${{ matrix.test_tfm }} test/Npgsql.DependencyInjection.Tests --logger "GitHubActions;report-warnings=false" shell: bash + - name: Upload Test Hang Dumps + uses: actions/upload-artifact@v7 + if: failure() + with: + name: test-hang-dumps + path: | + **/*.dmp + **/*_Sequence.xml + - name: Test Plugins if: "!startsWith(matrix.os, 'macos')" run: | @@ -328,16 
+337,16 @@ jobs: publish-ci: needs: build - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 if: github.event_name == 'push' && github.repository == 'npgsql/npgsql' environment: myget steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: NuGet Cache - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: ~/.nuget/packages key: ${{ runner.os }}-nuget-${{ hashFiles('**/Directory.Build.targets') }} @@ -345,15 +354,13 @@ jobs: ${{ runner.os }}-nuget- - name: Setup .NET Core SDK - uses: actions/setup-dotnet@v4.0.0 - with: - dotnet-version: ${{ env.dotnet_sdk_version }} + uses: actions/setup-dotnet@v5.1.0 - name: Pack - run: dotnet pack Npgsql.sln --configuration Release --property:PackageOutputPath="$PWD/nupkgs" --version-suffix "ci.$(date -u +%Y%m%dT%H%M%S)+sha.${GITHUB_SHA:0:9}" -p:ContinuousIntegrationBuild=true + run: dotnet pack --configuration Release --property:PackageOutputPath="$PWD/nupkgs" --version-suffix "ci.$(date -u +%Y%m%dT%H%M%S)+sha.${GITHUB_SHA:0:9}" -p:ContinuousIntegrationBuild=true - name: Upload artifacts (nupkg) - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: Npgsql.CI path: nupkgs @@ -370,24 +377,22 @@ jobs: release: needs: build - runs-on: ubuntu-22.04 + runs-on: ubuntu-24.04 if: github.event_name == 'push' && startsWith(github.repository, 'npgsql/') && needs.build.outputs.is_release == 'true' environment: nuget.org steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Setup .NET Core SDK - uses: actions/setup-dotnet@v4.0.0 - with: - dotnet-version: ${{ env.dotnet_sdk_version }} + uses: actions/setup-dotnet@v5.1.0 - name: Pack - run: dotnet pack Npgsql.sln --configuration Release --property:PackageOutputPath="$PWD/nupkgs" -p:ContinuousIntegrationBuild=true + run: dotnet pack --configuration Release --property:PackageOutputPath="$PWD/nupkgs" -p:ContinuousIntegrationBuild=true - name: Upload artifacts - uses: actions/upload-artifact@v4 + uses: 
actions/upload-artifact@v7 with: name: Npgsql.Release path: nupkgs diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml deleted file mode 100644 index 868ea2418b..0000000000 --- a/.github/workflows/codeql-analysis.yml +++ /dev/null @@ -1,93 +0,0 @@ -# For most projects, this workflow file will not need changing; you simply need -# to commit it to your repository. -# -# You may wish to alter this file to override the set of languages analyzed, -# or to provide custom queries or build logic. -# -# ******** NOTE ******** -# We have attempted to detect the languages in your repository. Please check -# the `language` matrix defined below to confirm you have the correct set of -# supported CodeQL languages. -# -name: "CodeQL" - -on: - push: - branches: - - main - - 'hotfix/**' - - 'release/**' - pull_request: - # The branches below must be a subset of the branches above - branches: - - main - - 'hotfix/**' - - 'release/**' - schedule: - - cron: '21 0 * * 4' - -# Cancel previous PR branch commits (head_ref is only defined on PRs) -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -env: - dotnet_sdk_version: '8.0.100' - DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - permissions: - actions: read - contents: read - security-events: write - - strategy: - fail-fast: false - matrix: - language: [ 'csharp' ] - # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] - # Learn more about CodeQL language support at https://git.io/codeql-language-support - - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - # Initializes the CodeQL tools for scanning. - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - languages: ${{ matrix.language }} - # If you wish to specify custom queries, you can do so here or in a config file. 
- # By default, queries listed here will override any specified in a config file. - # Prefix the list here with "+" to use these queries and those in the config file. - # queries: ./path/to/local/query, your-org/your-repo/queries@main - - - name: Setup .NET Core SDK - uses: actions/setup-dotnet@v4.0.0 - with: - dotnet-version: ${{ env.dotnet_sdk_version }} - - - name: Build - run: dotnet build -c Release - - # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). - # If this step fails, then you should remove it and run the build manually (see below) - #- name: Autobuild - # uses: github/codeql-action/autobuild@v2 - - # ℹ️ Command-line programs to run using the OS shell. - # 📚 https://git.io/JvXDl - - # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines - # and modify them (or add more) to build your code if your project - # uses a compiled language - - #- run: | - # make bootstrap - # make release - - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 diff --git a/.github/workflows/native-aot.yml b/.github/workflows/native-aot.yml index 97cf2878e7..cdc4d77ab5 100644 --- a/.github/workflows/native-aot.yml +++ b/.github/workflows/native-aot.yml @@ -9,54 +9,100 @@ on: - '*' pull_request: +permissions: + contents: read + # Cancel previous PR branch commits (head_ref is only defined on PRs) concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true env: - dotnet_sdk_version: '8.0.100' DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + AOT_Compat: | + param([string]$targetFramework) + + $publishOutput = dotnet publish test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj -r linux-x64 -c Release -f $targetFramework -p:RootNpgsql=true + + $actualWarningCount = 0 + + foreach ($line in $($publishOutput -split "`r`n")) + { + if ($line -like "*analysis warning IL*") + { + Write-Host $line + + $actualWarningCount += 1 + } + } + + $testPassed = 0 + + $binaryPath = 
"test/Npgsql.NativeAotTests/bin/Release/$targetFramework/linux-x64/native/" + if (-not (Test-Path -LiteralPath $binaryPath)) + { + $testPassed = 1 + Write-Host "Could not publish app, output was:" + foreach ($line in $($publishOutput -split "`r`n")) + { + Write-Host $line + } + } + + Write-Host "Actual warning count is:", $actualWarningCount + $expectedWarningCount = 0 + + if ($actualWarningCount -ne $expectedWarningCount) + { + $testPassed = 2 + Write-Host "Actual warning count:", $actualWarningCount, "is not as expected. Expected warning count is:", $expectedWarningCount + } + + Exit $testPassed # Uncomment and edit the following to use nightly/preview builds -# nuget_config: | -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# -# + # nuget_config: | + # + # + # + # + # + # + # + # + # + # + # + # + # + # + # + # + # + # + # + # + # + # jobs: - build: + full: runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: - os: [ubuntu-22.04] - pg_major: [15] + os: [ ubuntu-24.04 ] + pg_major: [ 18 ] + tfm: [ net10.0 ] steps: - name: Checkout - uses: actions/checkout@v4 - + uses: actions/checkout@v6 + + # - name: Setup nuget config + # run: echo "$nuget_config" > NuGet.config + - name: NuGet Cache - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: ~/.nuget/packages key: ${{ runner.os }}-nuget-${{ hashFiles('**/Directory.Build.targets') }} @@ -64,71 +110,95 @@ jobs: ${{ runner.os }}-nuget- - name: Setup .NET Core SDK - uses: actions/setup-dotnet@v4.0.0 - with: - dotnet-version: | - ${{ env.dotnet_sdk_version }} + uses: actions/setup-dotnet@v5.1.0 + + - name: Write script + run: echo "$AOT_Compat" > test-aot-compatibility.ps1 + + - name: Publish and check for trimmer warnings + run: ./test-aot-compatibility.ps1 ${{ matrix.tfm }} + shell: pwsh + trimmed: + runs-on: ${{ matrix.os }} -# - name: Setup nuget config -# run: echo "$nuget_config" > NuGet.config + strategy: + fail-fast: false + matrix: + os: [ubuntu-24.04] + pg_major: [ 18 ] + tfm: [ 
net10.0 ] - - name: Setup Native AOT prerequisites - run: sudo apt-get install clang zlib1g-dev - shell: bash + steps: + - name: Checkout + uses: actions/checkout@v6 - - name: Build - run: dotnet publish test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj -r linux-x64 -c Release -f net8.0 -p:OptimizationPreference=Size - shell: bash + # - name: Setup nuget config + # run: echo "$nuget_config" > NuGet.config - # Uncomment the following to SSH into the agent running the build (https://github.com/mxschmitt/action-tmate) - #- uses: actions/checkout@v4 - #- name: Setup tmate session - # uses: mxschmitt/action-tmate@v3 + - name: NuGet Cache + uses: actions/cache@v5 + with: + path: ~/.nuget/packages + key: ${{ runner.os }}-nuget-${{ hashFiles('**/Directory.Build.targets') }} + restore-keys: | + ${{ runner.os }}-nuget- + + - name: Setup .NET Core SDK + uses: actions/setup-dotnet@v5.1.0 - name: Start PostgreSQL run: | sudo systemctl start postgresql.service sudo -u postgres psql -c "CREATE USER npgsql_tests SUPERUSER PASSWORD 'npgsql_tests'" sudo -u postgres psql -c "CREATE DATABASE npgsql_tests OWNER npgsql_tests" + + - name: Build + run: dotnet publish test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj -r linux-x64 -c Release -f ${{ matrix.tfm }} -p:OptimizationPreference=Size + shell: bash + + # Uncomment the following to SSH into the agent running the build (https://github.com/mxschmitt/action-tmate) + #- uses: actions/checkout@v6 + #- name: Setup tmate session + # uses: mxschmitt/action-tmate@v3 - name: Run - run: test/Npgsql.NativeAotTests/bin/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests + run: test/Npgsql.NativeAotTests/bin/Release/${{ matrix.tfm }}/linux-x64/native/Npgsql.NativeAotTests - name: Write binary size to summary run: | - size="$(ls -l test/Npgsql.NativeAotTests/bin/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests | cut -d ' ' -f 5)" + size="$(ls -l test/Npgsql.NativeAotTests/bin/Release/${{ matrix.tfm 
}}/linux-x64/native/Npgsql.NativeAotTests | cut -d ' ' -f 5)" echo "Binary size is $size bytes ($((size / (1024 * 1024))) mb)" >> $GITHUB_STEP_SUMMARY - name: Dump mstat - run: dotnet run --project test/MStatDumper/MStatDumper.csproj -c release -f net8.0 -- "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.mstat" md >> $GITHUB_STEP_SUMMARY + run: dotnet run --project test/MStatDumper/MStatDumper.csproj -c release -f ${{ matrix.tfm }} -- "test/Npgsql.NativeAotTests/obj/Release/${{ matrix.tfm }}/linux-x64/native/Npgsql.NativeAotTests.mstat" md >> $GITHUB_STEP_SUMMARY - name: Upload mstat - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: npgsql.mstat - path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.mstat" + path: "test/Npgsql.NativeAotTests/obj/Release/${{ matrix.tfm }}/linux-x64/native/Npgsql.NativeAotTests.mstat" retention-days: 3 - name: Upload codedgen dgml - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: npgsql.codegen.dgml.xml - path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.codegen.dgml.xml" + path: "test/Npgsql.NativeAotTests/obj/Release/${{ matrix.tfm }}/linux-x64/native/Npgsql.NativeAotTests.codegen.dgml.xml" retention-days: 3 - name: Upload scan dgml - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@v7 with: name: npgsql.scan.dgml.xml - path: "test/Npgsql.NativeAotTests/obj/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests.scan.dgml.xml" + path: "test/Npgsql.NativeAotTests/obj/Release/${{ matrix.tfm }}/linux-x64/native/Npgsql.NativeAotTests.scan.dgml.xml" retention-days: 3 - name: Assert binary size run: | - size="$(ls -l test/Npgsql.NativeAotTests/bin/Release/net8.0/linux-x64/native/Npgsql.NativeAotTests | cut -d ' ' -f 5)" + size="$(ls -l test/Npgsql.NativeAotTests/bin/Release/${{ matrix.tfm }}/linux-x64/native/Npgsql.NativeAotTests | cut -d ' ' 
-f 5)" echo "Binary size is $size bytes ($((size / (1024 * 1024))) mb)" - if (( size > 7340032 )); then - echo "Binary size exceeds 7mb threshold" + if (( size > 5242880 )); then + echo "Binary size exceeds 5MB threshold" exit 1 fi diff --git a/.github/workflows/rich-code-nav.yml b/.github/workflows/rich-code-nav.yml deleted file mode 100644 index 7ee82bfeb9..0000000000 --- a/.github/workflows/rich-code-nav.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: Rich Code Navigation - -on: - workflow_dispatch: - -env: - dotnet_sdk_version: '8.0.100' - DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true - -jobs: - build: - runs-on: windows-latest - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: NuGet Cache - uses: actions/cache@v4 - with: - path: ~/.nuget/packages - key: ${{ runner.os }}-nuget-${{ hashFiles('**/Directory.Build.targets') }} - restore-keys: | - ${{ runner.os }}-nuget- - - - name: Setup .NET Core SDK - uses: actions/setup-dotnet@v4.0.0 - with: - dotnet-version: ${{ env.dotnet_sdk_version }} - - - name: Build - run: dotnet build Npgsql.sln --configuration Debug - shell: bash - - - name: Rich Navigation Indexing - uses: microsoft/RichCodeNavIndexer@v0 - with: - languages: csharp - repo-token: ${{ github.token }} diff --git a/.github/workflows/trigger-doc-build.yml b/.github/workflows/trigger-doc-build.yml index dfbe89601e..e8783c9e16 100644 --- a/.github/workflows/trigger-doc-build.yml +++ b/.github/workflows/trigger-doc-build.yml @@ -8,9 +8,12 @@ on: branches: - docs +permissions: + contents: read + jobs: build: - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest steps: - name: Trigger documentation build run: | diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 6bf5ff40c5..a505eb8cfc 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -2,10 +2,6 @@ // List of extensions which should be recommended for users of this workspace. 
"recommendations": [ "ms-dotnettools.csharp", - "formulahendry.dotnet-test-explorer", - ], - // List of extensions recommended by VS Code that should not be recommended for users of this workspace. - "unwantedRecommendations": [ - + "ms-dotnettools.csdevkit" ] } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 3f641af41f..22993a3100 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,3 @@ { - "omnisharp.defaultLaunchSolution": "Npgsql.sln", - "dotnet-test-explorer.testProjectPath": "**/*.Tests.csproj" + "dotnet.defaultSolution": "Npgsql.slnx" } \ No newline at end of file diff --git a/Directory.Build.props b/Directory.Build.props index 57494750c7..482dbaf297 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -1,6 +1,6 @@  - 9.0.0-preview.1 + 11.0.0-preview.1 latest true enable @@ -10,7 +10,7 @@ true true - Copyright 2023 © The Npgsql Development Team + Copyright 2026 © The Npgsql Development Team Npgsql PostgreSQL https://github.com/npgsql/npgsql diff --git a/Directory.Packages.props b/Directory.Packages.props index 132bbd43e8..a197b8c29e 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -1,52 +1,53 @@ - - 8.0.0 - $(SystemVersion) + + 10.0.3 + 10.0.3 + + + 10.0.3 + 10.0.3 - - - - - - - - - + + + + - + - - - + + + - + - - - + + + - - - - - - - - - + + + + + + + + + + - + + - - - + + + - + diff --git a/LICENSE b/LICENSE index efec310cda..5f0d26b868 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright (c) 2002-2023, Npgsql +Copyright (c) 2002-2026, Npgsql Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement diff --git a/Npgsql.sln b/Npgsql.sln deleted file mode 100644 index 80ef02c3a8..0000000000 --- a/Npgsql.sln +++ /dev/null @@ -1,204 +0,0 @@ - -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 16 -VisualStudioVersion = 
16.0.28822.285 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "src", "src", "{8537E50E-CF7F-49CB-B4EF-3E2A1B11F050}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "test", "test", "{ED612DB1-AB32-4603-95E7-891BACA71C39}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql", "src\Npgsql\Npgsql.csproj", "{9D13B739-62B1-4190-B386-7A9547304EB3}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.Tests", "test\Npgsql.Tests\Npgsql.Tests.csproj", "{E9C258D7-0D8E-4E6A-9857-5C6438591755}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.Benchmarks", "test\Npgsql.Benchmarks\Npgsql.Benchmarks.csproj", "{8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.Json.NET", "src\Npgsql.Json.NET\Npgsql.Json.NET.csproj", "{9CBE603F-6746-411D-A5FD-CB2C948CD7D0}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.NodaTime", "src\Npgsql.NodaTime\Npgsql.NodaTime.csproj", "{D8DF12D6-FA70-4653-BD8F-C188944836DE}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.PluginTests", "test\Npgsql.PluginTests\Npgsql.PluginTests.csproj", "{9BD7FC3D-6956-42A8-A586-2558C499EBA2}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.NetTopologySuite", "src\Npgsql.NetTopologySuite\Npgsql.NetTopologySuite.csproj", "{6CB12050-DC9B-4155-BADD-BFDD54CDD70F}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.GeoJSON", "src\Npgsql.GeoJSON\Npgsql.GeoJSON.csproj", "{F7C53EBD-0075-474F-A083-419257D04080}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Npgsql.Specification.Tests", "test\Npgsql.Specification.Tests\Npgsql.Specification.Tests.csproj", "{A77E5FAF-D775-4AB4-8846-8965C2104E60}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution Items", 
"{004A2E0F-D34A-44D4-8DF0-D2BC63B57073}" - ProjectSection(SolutionItems) = preProject - .editorconfig = .editorconfig - Directory.Build.props = Directory.Build.props - Directory.Packages.props = Directory.Packages.props - README.md = README.md - global.json = global.json - NuGet.config = NuGet.config - EndProjectSection -EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.SourceGenerators", "src\Npgsql.SourceGenerators\Npgsql.SourceGenerators.csproj", "{63026A19-60B8-4906-81CB-216F30E8094B}" -EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.OpenTelemetry", "src\Npgsql.OpenTelemetry\Npgsql.OpenTelemetry.csproj", "{DA29F063-1828-47D8-B051-800AF7C9A0BE}" -EndProject -Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Github", "Github", "{BA7B6F53-D24D-45AC-927A-266857EA8D1E}" - ProjectSection(SolutionItems) = preProject - .github\workflows\build.yml = .github\workflows\build.yml - .github\dependabot.yml = .github\dependabot.yml - .github\workflows\codeql-analysis.yml = .github\workflows\codeql-analysis.yml - .github\workflows\rich-code-nav.yml = .github\workflows\rich-code-nav.yml - .github\workflows\native-aot.yml = .github\workflows\native-aot.yml - EndProjectSection -EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.DependencyInjection", "src\Npgsql.DependencyInjection\Npgsql.DependencyInjection.csproj", "{B58E12EB-E43D-4D77-894E-5157D2269836}" -EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.DependencyInjection.Tests", "test\Npgsql.DependencyInjection.Tests\Npgsql.DependencyInjection.Tests.csproj", "{EB2530FC-69F7-4DCB-A8B3-3671A157ED32}" -EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Npgsql.NativeAotTests", "test\Npgsql.NativeAotTests\Npgsql.NativeAotTests.csproj", "{20F2E9D6-A69E-4BAE-9236-574B0AA59139}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|Any CPU = Debug|Any CPU - Debug|x86 = Debug|x86 - Release|Any 
CPU = Release|Any CPU - Release|x86 = Release|x86 - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {9D13B739-62B1-4190-B386-7A9547304EB3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {9D13B739-62B1-4190-B386-7A9547304EB3}.Debug|Any CPU.Build.0 = Debug|Any CPU - {9D13B739-62B1-4190-B386-7A9547304EB3}.Debug|x86.ActiveCfg = Debug|Any CPU - {9D13B739-62B1-4190-B386-7A9547304EB3}.Debug|x86.Build.0 = Debug|Any CPU - {9D13B739-62B1-4190-B386-7A9547304EB3}.Release|Any CPU.ActiveCfg = Release|Any CPU - {9D13B739-62B1-4190-B386-7A9547304EB3}.Release|Any CPU.Build.0 = Release|Any CPU - {9D13B739-62B1-4190-B386-7A9547304EB3}.Release|x86.ActiveCfg = Release|Any CPU - {9D13B739-62B1-4190-B386-7A9547304EB3}.Release|x86.Build.0 = Release|Any CPU - {E9C258D7-0D8E-4E6A-9857-5C6438591755}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {E9C258D7-0D8E-4E6A-9857-5C6438591755}.Debug|Any CPU.Build.0 = Debug|Any CPU - {E9C258D7-0D8E-4E6A-9857-5C6438591755}.Debug|x86.ActiveCfg = Debug|Any CPU - {E9C258D7-0D8E-4E6A-9857-5C6438591755}.Debug|x86.Build.0 = Debug|Any CPU - {E9C258D7-0D8E-4E6A-9857-5C6438591755}.Release|Any CPU.ActiveCfg = Release|Any CPU - {E9C258D7-0D8E-4E6A-9857-5C6438591755}.Release|Any CPU.Build.0 = Release|Any CPU - {E9C258D7-0D8E-4E6A-9857-5C6438591755}.Release|x86.ActiveCfg = Release|Any CPU - {E9C258D7-0D8E-4E6A-9857-5C6438591755}.Release|x86.Build.0 = Release|Any CPU - {8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}.Debug|Any CPU.Build.0 = Debug|Any CPU - {8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}.Debug|x86.ActiveCfg = Debug|Any CPU - {8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}.Debug|x86.Build.0 = Debug|Any CPU - {8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}.Release|Any CPU.ActiveCfg = Release|Any CPU - {8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}.Release|Any CPU.Build.0 = Release|Any CPU - {8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}.Release|x86.ActiveCfg = Release|Any CPU - 
{8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8}.Release|x86.Build.0 = Release|Any CPU - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0}.Debug|Any CPU.Build.0 = Debug|Any CPU - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0}.Debug|x86.ActiveCfg = Debug|Any CPU - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0}.Debug|x86.Build.0 = Debug|Any CPU - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0}.Release|Any CPU.ActiveCfg = Release|Any CPU - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0}.Release|Any CPU.Build.0 = Release|Any CPU - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0}.Release|x86.ActiveCfg = Release|Any CPU - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0}.Release|x86.Build.0 = Release|Any CPU - {D8DF12D6-FA70-4653-BD8F-C188944836DE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {D8DF12D6-FA70-4653-BD8F-C188944836DE}.Debug|Any CPU.Build.0 = Debug|Any CPU - {D8DF12D6-FA70-4653-BD8F-C188944836DE}.Debug|x86.ActiveCfg = Debug|Any CPU - {D8DF12D6-FA70-4653-BD8F-C188944836DE}.Debug|x86.Build.0 = Debug|Any CPU - {D8DF12D6-FA70-4653-BD8F-C188944836DE}.Release|Any CPU.ActiveCfg = Release|Any CPU - {D8DF12D6-FA70-4653-BD8F-C188944836DE}.Release|Any CPU.Build.0 = Release|Any CPU - {D8DF12D6-FA70-4653-BD8F-C188944836DE}.Release|x86.ActiveCfg = Release|Any CPU - {D8DF12D6-FA70-4653-BD8F-C188944836DE}.Release|x86.Build.0 = Release|Any CPU - {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Debug|Any CPU.Build.0 = Debug|Any CPU - {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Debug|x86.ActiveCfg = Debug|Any CPU - {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Debug|x86.Build.0 = Debug|Any CPU - {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Release|Any CPU.ActiveCfg = Release|Any CPU - {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Release|Any CPU.Build.0 = Release|Any CPU - {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Release|x86.ActiveCfg = Release|Any CPU - {9BD7FC3D-6956-42A8-A586-2558C499EBA2}.Release|x86.Build.0 = Release|Any CPU - 
{6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Debug|Any CPU.Build.0 = Debug|Any CPU - {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Debug|x86.ActiveCfg = Debug|Any CPU - {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Debug|x86.Build.0 = Debug|Any CPU - {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Release|Any CPU.ActiveCfg = Release|Any CPU - {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Release|Any CPU.Build.0 = Release|Any CPU - {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Release|x86.ActiveCfg = Release|Any CPU - {6CB12050-DC9B-4155-BADD-BFDD54CDD70F}.Release|x86.Build.0 = Release|Any CPU - {F7C53EBD-0075-474F-A083-419257D04080}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {F7C53EBD-0075-474F-A083-419257D04080}.Debug|Any CPU.Build.0 = Debug|Any CPU - {F7C53EBD-0075-474F-A083-419257D04080}.Debug|x86.ActiveCfg = Debug|Any CPU - {F7C53EBD-0075-474F-A083-419257D04080}.Debug|x86.Build.0 = Debug|Any CPU - {F7C53EBD-0075-474F-A083-419257D04080}.Release|Any CPU.ActiveCfg = Release|Any CPU - {F7C53EBD-0075-474F-A083-419257D04080}.Release|Any CPU.Build.0 = Release|Any CPU - {F7C53EBD-0075-474F-A083-419257D04080}.Release|x86.ActiveCfg = Release|Any CPU - {F7C53EBD-0075-474F-A083-419257D04080}.Release|x86.Build.0 = Release|Any CPU - {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Debug|Any CPU.Build.0 = Debug|Any CPU - {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Debug|x86.ActiveCfg = Debug|Any CPU - {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Debug|x86.Build.0 = Debug|Any CPU - {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Release|Any CPU.ActiveCfg = Release|Any CPU - {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Release|Any CPU.Build.0 = Release|Any CPU - {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Release|x86.ActiveCfg = Release|Any CPU - {A77E5FAF-D775-4AB4-8846-8965C2104E60}.Release|x86.Build.0 = Release|Any CPU - {63026A19-60B8-4906-81CB-216F30E8094B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU 
- {63026A19-60B8-4906-81CB-216F30E8094B}.Debug|Any CPU.Build.0 = Debug|Any CPU - {63026A19-60B8-4906-81CB-216F30E8094B}.Debug|x86.ActiveCfg = Debug|Any CPU - {63026A19-60B8-4906-81CB-216F30E8094B}.Debug|x86.Build.0 = Debug|Any CPU - {63026A19-60B8-4906-81CB-216F30E8094B}.Release|Any CPU.ActiveCfg = Release|Any CPU - {63026A19-60B8-4906-81CB-216F30E8094B}.Release|Any CPU.Build.0 = Release|Any CPU - {63026A19-60B8-4906-81CB-216F30E8094B}.Release|x86.ActiveCfg = Release|Any CPU - {63026A19-60B8-4906-81CB-216F30E8094B}.Release|x86.Build.0 = Release|Any CPU - {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|Any CPU.Build.0 = Debug|Any CPU - {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|x86.ActiveCfg = Debug|Any CPU - {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Debug|x86.Build.0 = Debug|Any CPU - {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Release|Any CPU.ActiveCfg = Release|Any CPU - {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Release|Any CPU.Build.0 = Release|Any CPU - {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Release|x86.ActiveCfg = Release|Any CPU - {DA29F063-1828-47D8-B051-800AF7C9A0BE}.Release|x86.Build.0 = Release|Any CPU - {B58E12EB-E43D-4D77-894E-5157D2269836}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {B58E12EB-E43D-4D77-894E-5157D2269836}.Debug|Any CPU.Build.0 = Debug|Any CPU - {B58E12EB-E43D-4D77-894E-5157D2269836}.Debug|x86.ActiveCfg = Debug|Any CPU - {B58E12EB-E43D-4D77-894E-5157D2269836}.Debug|x86.Build.0 = Debug|Any CPU - {B58E12EB-E43D-4D77-894E-5157D2269836}.Release|Any CPU.ActiveCfg = Release|Any CPU - {B58E12EB-E43D-4D77-894E-5157D2269836}.Release|Any CPU.Build.0 = Release|Any CPU - {B58E12EB-E43D-4D77-894E-5157D2269836}.Release|x86.ActiveCfg = Release|Any CPU - {B58E12EB-E43D-4D77-894E-5157D2269836}.Release|x86.Build.0 = Release|Any CPU - {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Debug|Any CPU.Build.0 = Debug|Any CPU 
- {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Debug|x86.ActiveCfg = Debug|Any CPU - {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Debug|x86.Build.0 = Debug|Any CPU - {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Release|Any CPU.ActiveCfg = Release|Any CPU - {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Release|Any CPU.Build.0 = Release|Any CPU - {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Release|x86.ActiveCfg = Release|Any CPU - {EB2530FC-69F7-4DCB-A8B3-3671A157ED32}.Release|x86.Build.0 = Release|Any CPU - {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Debug|Any CPU.Build.0 = Debug|Any CPU - {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Debug|x86.ActiveCfg = Debug|Any CPU - {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Debug|x86.Build.0 = Debug|Any CPU - {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Release|Any CPU.ActiveCfg = Release|Any CPU - {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Release|Any CPU.Build.0 = Release|Any CPU - {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Release|x86.ActiveCfg = Release|Any CPU - {20F2E9D6-A69E-4BAE-9236-574B0AA59139}.Release|x86.Build.0 = Release|Any CPU - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection - GlobalSection(NestedProjects) = preSolution - {9D13B739-62B1-4190-B386-7A9547304EB3} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {E9C258D7-0D8E-4E6A-9857-5C6438591755} = {ED612DB1-AB32-4603-95E7-891BACA71C39} - {8B4AE9B6-CDAC-44DD-A5CD-28A470D363B8} = {ED612DB1-AB32-4603-95E7-891BACA71C39} - {9CBE603F-6746-411D-A5FD-CB2C948CD7D0} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {D8DF12D6-FA70-4653-BD8F-C188944836DE} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {9BD7FC3D-6956-42A8-A586-2558C499EBA2} = {ED612DB1-AB32-4603-95E7-891BACA71C39} - {6CB12050-DC9B-4155-BADD-BFDD54CDD70F} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {F7C53EBD-0075-474F-A083-419257D04080} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {A77E5FAF-D775-4AB4-8846-8965C2104E60} = 
{ED612DB1-AB32-4603-95E7-891BACA71C39} - {63026A19-60B8-4906-81CB-216F30E8094B} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {DA29F063-1828-47D8-B051-800AF7C9A0BE} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {BA7B6F53-D24D-45AC-927A-266857EA8D1E} = {004A2E0F-D34A-44D4-8DF0-D2BC63B57073} - {B58E12EB-E43D-4D77-894E-5157D2269836} = {8537E50E-CF7F-49CB-B4EF-3E2A1B11F050} - {EB2530FC-69F7-4DCB-A8B3-3671A157ED32} = {ED612DB1-AB32-4603-95E7-891BACA71C39} - {20F2E9D6-A69E-4BAE-9236-574B0AA59139} = {ED612DB1-AB32-4603-95E7-891BACA71C39} - EndGlobalSection - GlobalSection(ExtensibilityGlobals) = postSolution - SolutionGuid = {C90AEECD-DB4C-4BE6-B506-16A449852FB8} - EndGlobalSection - GlobalSection(MonoDevelopProperties) = preSolution - StartupItem = Npgsql.csproj - EndGlobalSection -EndGlobal diff --git a/Npgsql.slnx b/Npgsql.slnx new file mode 100644 index 0000000000..e69a6728c8 --- /dev/null +++ b/Npgsql.slnx @@ -0,0 +1,36 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/Npgsql.sln.DotSettings b/Npgsql.slnx.DotSettings similarity index 100% rename from Npgsql.sln.DotSettings rename to Npgsql.slnx.DotSettings diff --git a/global.json b/global.json index c4fc1c4611..6a288505a1 100644 --- a/global.json +++ b/global.json @@ -1,7 +1,7 @@ { "sdk": { - "version": "8.0.100", + "version": "10.0.100", "rollForward": "latestMajor", - "allowPrerelease": "false" + "allowPrerelease": false } } diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 169a5988a2..6e8c5bb19f 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -2,15 +2,10 @@ + true true - - true - - - - diff --git a/src/Npgsql.DependencyInjection/Npgsql.DependencyInjection.csproj b/src/Npgsql.DependencyInjection/Npgsql.DependencyInjection.csproj index 357003cf07..3c10503037 100644 --- a/src/Npgsql.DependencyInjection/Npgsql.DependencyInjection.csproj +++ b/src/Npgsql.DependencyInjection/Npgsql.DependencyInjection.csproj @@ -2,10 +2,7 @@ Shay Rojansky 
- - - net6.0;net8.0 - net8.0 + net10.0 npgsql;postgresql;postgres;ado;ado.net;database;sql;di;dependency injection README.md diff --git a/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.Obsolete.cs b/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.Obsolete.cs new file mode 100644 index 0000000000..6e2b4e7d4f --- /dev/null +++ b/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.Obsolete.cs @@ -0,0 +1,220 @@ +using System; +using System.ComponentModel; +using Npgsql; + +namespace Microsoft.Extensions.DependencyInjection; + +public static partial class NpgsqlServiceCollectionExtensions +{ + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. + [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddNpgsqlDataSourceCore( + serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. 
+ [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddNpgsqlDataSourceCore(serviceCollection, serviceKey: null, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. + [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. 
+ [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddNpgsqlSlimDataSourceCore(serviceCollection, serviceKey: null, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the + /// . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. + [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddMultiHostNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddMultiHostNpgsqlDataSourceCore( + serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. 
+ [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddMultiHostNpgsqlDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddMultiHostNpgsqlDataSourceCore( + serviceCollection, serviceKey: null, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); + + /// + /// Registers an and an in the + /// . + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. + [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddMultiHostNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); + + /// + /// Registers an and an in the + /// + /// The to add services to. + /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// + /// + /// The lifetime with which to register the in the container. + /// Defaults to . + /// + /// + /// The lifetime with which to register the service in the container. + /// Defaults to . + /// + /// The same service collection so that multiple calls can be chained. 
+ [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] + public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( + this IServiceCollection serviceCollection, + string connectionString, + Action dataSourceBuilderAction, + ServiceLifetime connectionLifetime, + ServiceLifetime dataSourceLifetime) + => AddMultiHostNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey: null, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); +} diff --git a/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.cs b/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.cs index 755d6b1357..7e22029a40 100644 --- a/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.cs +++ b/src/Npgsql.DependencyInjection/NpgsqlServiceCollectionExtensions.cs @@ -1,5 +1,4 @@ using System; -using System.ComponentModel; using System.Data.Common; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Logging; @@ -11,16 +10,13 @@ namespace Microsoft.Extensions.DependencyInjection; /// /// Extension method for setting up Npgsql services in an . /// -public static class NpgsqlServiceCollectionExtensions +public static partial class NpgsqlServiceCollectionExtensions { /// /// Registers an and an in the . /// /// The to add services to. /// An Npgsql connection string. - /// - /// An action to configure the for further customizations of the . - /// /// /// The lifetime with which to register the in the container. /// Defaults to . @@ -34,11 +30,12 @@ public static class NpgsqlServiceCollectionExtensions public static IServiceCollection AddNpgsqlDataSource( this IServiceCollection serviceCollection, string connectionString, - Action dataSourceBuilderAction, ServiceLifetime connectionLifetime = ServiceLifetime.Transient, ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, object? 
serviceKey = null) - => AddNpgsqlDataSourceCore(serviceCollection, serviceKey, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); + => AddNpgsqlDataSourceCore( + serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); /// /// Registers an and an in the . @@ -56,21 +53,27 @@ public static IServiceCollection AddNpgsqlDataSource( /// The lifetime with which to register the service in the container. /// Defaults to . /// + /// The of the data source. /// The same service collection so that multiple calls can be chained. - [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] public static IServiceCollection AddNpgsqlDataSource( this IServiceCollection serviceCollection, string connectionString, Action dataSourceBuilderAction, - ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) - => AddNpgsqlDataSourceCore(serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddNpgsqlDataSourceCore(serviceCollection, serviceKey, connectionString, + static (_, builder, state) => ((Action)state!)(builder) + , connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); /// /// Registers an and an in the . /// /// The to add services to. /// An Npgsql connection string. + /// + /// An action to configure the for further customizations of the . + /// /// /// The lifetime with which to register the in the container. /// Defaults to . 
@@ -84,11 +87,13 @@ public static IServiceCollection AddNpgsqlDataSource( public static IServiceCollection AddNpgsqlDataSource( this IServiceCollection serviceCollection, string connectionString, + Action dataSourceBuilderAction, ServiceLifetime connectionLifetime = ServiceLifetime.Transient, ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, object? serviceKey = null) - => AddNpgsqlDataSourceCore( - serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); + => AddNpgsqlDataSourceCore(serviceCollection, serviceKey, connectionString, + static (sp, builder, state) => ((Action)state!)(sp, builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); /// /// Registers an and an in the . @@ -103,15 +108,17 @@ public static IServiceCollection AddNpgsqlDataSource( /// The lifetime with which to register the service in the container. /// Defaults to . /// + /// The of the data source. /// The same service collection so that multiple calls can be chained. - [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] - public static IServiceCollection AddNpgsqlDataSource( + public static IServiceCollection AddNpgsqlSlimDataSource( this IServiceCollection serviceCollection, string connectionString, - ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) - => AddNpgsqlDataSourceCore( - serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) + => AddNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); /// /// Registers an and an in the . 
@@ -138,7 +145,9 @@ public static IServiceCollection AddNpgsqlSlimDataSource( ServiceLifetime connectionLifetime = ServiceLifetime.Transient, ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, object? serviceKey = null) - => AddNpgsqlSlimDataSourceCore(serviceCollection, serviceKey, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); + => AddNpgsqlSlimDataSourceCore(serviceCollection, serviceKey, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); /// /// Registers an and an in the . @@ -156,71 +165,25 @@ public static IServiceCollection AddNpgsqlSlimDataSource( /// The lifetime with which to register the service in the container. /// Defaults to . /// - /// The same service collection so that multiple calls can be chained. - [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] - public static IServiceCollection AddNpgsqlSlimDataSource( - this IServiceCollection serviceCollection, - string connectionString, - Action dataSourceBuilderAction, - ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) - => AddNpgsqlSlimDataSourceCore(serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); - - /// - /// Registers an and an in the . - /// - /// The to add services to. - /// An Npgsql connection string. - /// - /// The lifetime with which to register the in the container. - /// Defaults to . - /// - /// - /// The lifetime with which to register the service in the container. - /// Defaults to . - /// /// The of the data source. /// The same service collection so that multiple calls can be chained. 
public static IServiceCollection AddNpgsqlSlimDataSource( this IServiceCollection serviceCollection, string connectionString, + Action dataSourceBuilderAction, ServiceLifetime connectionLifetime = ServiceLifetime.Transient, ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, object? serviceKey = null) - => AddNpgsqlSlimDataSourceCore( - serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); - - /// - /// Registers an and an in the . - /// - /// The to add services to. - /// An Npgsql connection string. - /// - /// The lifetime with which to register the in the container. - /// Defaults to . - /// - /// - /// The lifetime with which to register the service in the container. - /// Defaults to . - /// - /// The same service collection so that multiple calls can be chained. - [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] - public static IServiceCollection AddNpgsqlSlimDataSource( - this IServiceCollection serviceCollection, - string connectionString, - ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) - => AddNpgsqlSlimDataSourceCore( - serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); + => AddNpgsqlSlimDataSourceCore(serviceCollection, serviceKey, connectionString, + static (sp, builder, state) => ((Action)state!)(sp, builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); /// /// Registers an and an in the + /// . /// /// The to add services to. /// An Npgsql connection string. - /// - /// An action to configure the for further customizations of the . - /// /// /// The lifetime with which to register the in the container. /// Defaults to . 
@@ -234,12 +197,12 @@ public static IServiceCollection AddNpgsqlSlimDataSource( public static IServiceCollection AddMultiHostNpgsqlDataSource( this IServiceCollection serviceCollection, string connectionString, - Action dataSourceBuilderAction, ServiceLifetime connectionLifetime = ServiceLifetime.Transient, ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, object? serviceKey = null) => AddMultiHostNpgsqlDataSourceCore( - serviceCollection, serviceKey, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); + serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); /// /// Registers an and an in the @@ -257,23 +220,28 @@ public static IServiceCollection AddMultiHostNpgsqlDataSource( /// The lifetime with which to register the service in the container. /// Defaults to . /// + /// The of the data source. /// The same service collection so that multiple calls can be chained. - [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] public static IServiceCollection AddMultiHostNpgsqlDataSource( this IServiceCollection serviceCollection, string connectionString, Action dataSourceBuilderAction, - ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? serviceKey = null) => AddMultiHostNpgsqlDataSourceCore( - serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); + serviceCollection, serviceKey, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); /// /// Registers an and an in the - /// . /// /// The to add services to. /// An Npgsql connection string. 
+ /// + /// An action to configure the for further customizations of the . + /// /// /// The lifetime with which to register the in the container. /// Defaults to . @@ -287,11 +255,14 @@ public static IServiceCollection AddMultiHostNpgsqlDataSource( public static IServiceCollection AddMultiHostNpgsqlDataSource( this IServiceCollection serviceCollection, string connectionString, + Action dataSourceBuilderAction, ServiceLifetime connectionLifetime = ServiceLifetime.Transient, ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, object? serviceKey = null) => AddMultiHostNpgsqlDataSourceCore( - serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); + serviceCollection, serviceKey, connectionString, + static (sp, builder, state) => ((Action)state!)(sp, builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); /// /// Registers an and an in the @@ -307,15 +278,17 @@ public static IServiceCollection AddMultiHostNpgsqlDataSource( /// The lifetime with which to register the service in the container. /// Defaults to . /// + /// The of the data source. /// The same service collection so that multiple calls can be chained. - [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] - public static IServiceCollection AddMultiHostNpgsqlDataSource( + public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( this IServiceCollection serviceCollection, string connectionString, - ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) - => AddMultiHostNpgsqlDataSourceCore( - serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); + ServiceLifetime connectionLifetime = ServiceLifetime.Transient, + ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, + object? 
serviceKey = null) + => AddMultiHostNpgsqlSlimDataSourceCore( + serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, + connectionLifetime, dataSourceLifetime, state: null); /// /// Registers an and an in the @@ -343,7 +316,9 @@ public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, object? serviceKey = null) => AddMultiHostNpgsqlSlimDataSourceCore( - serviceCollection, serviceKey, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); + serviceCollection, serviceKey, connectionString, + static (_, builder, state) => ((Action)state!)(builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); /// /// Registers an and an in the @@ -361,73 +336,28 @@ public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( /// The lifetime with which to register the service in the container. /// Defaults to . /// - /// The same service collection so that multiple calls can be chained. - [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] - public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( - this IServiceCollection serviceCollection, - string connectionString, - Action dataSourceBuilderAction, - ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) - => AddMultiHostNpgsqlSlimDataSourceCore( - serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction, connectionLifetime, dataSourceLifetime); - - /// - /// Registers an and an in the - /// . - /// - /// The to add services to. - /// An Npgsql connection string. - /// - /// The lifetime with which to register the in the container. - /// Defaults to . - /// - /// - /// The lifetime with which to register the service in the container. - /// Defaults to . - /// /// The of the data source. /// The same service collection so that multiple calls can be chained. 
public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( this IServiceCollection serviceCollection, string connectionString, + Action dataSourceBuilderAction, ServiceLifetime connectionLifetime = ServiceLifetime.Transient, ServiceLifetime dataSourceLifetime = ServiceLifetime.Singleton, object? serviceKey = null) => AddMultiHostNpgsqlSlimDataSourceCore( - serviceCollection, serviceKey, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); - - /// - /// Registers an and an in the - /// . - /// - /// The to add services to. - /// An Npgsql connection string. - /// - /// The lifetime with which to register the in the container. - /// Defaults to . - /// - /// - /// The lifetime with which to register the service in the container. - /// Defaults to . - /// - /// The same service collection so that multiple calls can be chained. - [EditorBrowsable(EditorBrowsableState.Never), Obsolete("Defined for binary compatibility with 7.0")] - public static IServiceCollection AddMultiHostNpgsqlSlimDataSource( - this IServiceCollection serviceCollection, - string connectionString, - ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) - => AddMultiHostNpgsqlSlimDataSourceCore( - serviceCollection, serviceKey: null, connectionString, dataSourceBuilderAction: null, connectionLifetime, dataSourceLifetime); + serviceCollection, serviceKey, connectionString, + static (sp, builder, state) => ((Action)state!)(sp, builder), + connectionLifetime, dataSourceLifetime, state: dataSourceBuilderAction); static IServiceCollection AddNpgsqlDataSourceCore( this IServiceCollection serviceCollection, object? serviceKey, string connectionString, - Action? dataSourceBuilderAction, + Action? dataSourceBuilderAction, ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) + ServiceLifetime dataSourceLifetime, + object? 
state) { serviceCollection.TryAdd( new ServiceDescriptor( @@ -437,7 +367,7 @@ static IServiceCollection AddNpgsqlDataSourceCore( { var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionString); dataSourceBuilder.UseLoggerFactory(sp.GetService()); - dataSourceBuilderAction?.Invoke(dataSourceBuilder); + dataSourceBuilderAction?.Invoke(sp, dataSourceBuilder, state); return dataSourceBuilder.Build(); }, dataSourceLifetime)); @@ -451,9 +381,10 @@ static IServiceCollection AddNpgsqlSlimDataSourceCore( this IServiceCollection serviceCollection, object? serviceKey, string connectionString, - Action? dataSourceBuilderAction, + Action? dataSourceBuilderAction, ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) + ServiceLifetime dataSourceLifetime, + object? state) { serviceCollection.TryAdd( new ServiceDescriptor( @@ -463,7 +394,7 @@ static IServiceCollection AddNpgsqlSlimDataSourceCore( { var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(connectionString); dataSourceBuilder.UseLoggerFactory(sp.GetService()); - dataSourceBuilderAction?.Invoke(dataSourceBuilder); + dataSourceBuilderAction?.Invoke(sp, dataSourceBuilder, state); return dataSourceBuilder.Build(); }, dataSourceLifetime)); @@ -477,9 +408,10 @@ static IServiceCollection AddMultiHostNpgsqlDataSourceCore( this IServiceCollection serviceCollection, object? serviceKey, string connectionString, - Action? dataSourceBuilderAction, + Action? dataSourceBuilderAction, ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) + ServiceLifetime dataSourceLifetime, + object? 
state) { serviceCollection.TryAdd( new ServiceDescriptor( @@ -489,7 +421,7 @@ static IServiceCollection AddMultiHostNpgsqlDataSourceCore( { var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionString); dataSourceBuilder.UseLoggerFactory(sp.GetService()); - dataSourceBuilderAction?.Invoke(dataSourceBuilder); + dataSourceBuilderAction?.Invoke(sp, dataSourceBuilder, state); return dataSourceBuilder.BuildMultiHost(); }, dataSourceLifetime)); @@ -522,9 +454,10 @@ static IServiceCollection AddMultiHostNpgsqlSlimDataSourceCore( this IServiceCollection serviceCollection, object? serviceKey, string connectionString, - Action? dataSourceBuilderAction, + Action? dataSourceBuilderAction, ServiceLifetime connectionLifetime, - ServiceLifetime dataSourceLifetime) + ServiceLifetime dataSourceLifetime, + object? state) { serviceCollection.TryAdd( new ServiceDescriptor( @@ -534,7 +467,7 @@ static IServiceCollection AddMultiHostNpgsqlSlimDataSourceCore( { var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(connectionString); dataSourceBuilder.UseLoggerFactory(sp.GetService()); - dataSourceBuilderAction?.Invoke(dataSourceBuilder); + dataSourceBuilderAction?.Invoke(sp, dataSourceBuilder, state); return dataSourceBuilder.BuildMultiHost(); }, dataSourceLifetime)); diff --git a/src/Npgsql.DependencyInjection/PublicAPI.Unshipped.txt b/src/Npgsql.DependencyInjection/PublicAPI.Unshipped.txt index ab058de62d..34f2d889e9 100644 --- a/src/Npgsql.DependencyInjection/PublicAPI.Unshipped.txt +++ b/src/Npgsql.DependencyInjection/PublicAPI.Unshipped.txt @@ -1 +1,5 @@ #nullable enable +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! 
dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddMultiHostNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +static Microsoft.Extensions.DependencyInjection.NpgsqlServiceCollectionExtensions.AddNpgsqlSlimDataSource(this Microsoft.Extensions.DependencyInjection.IServiceCollection! serviceCollection, string! connectionString, System.Action! 
dataSourceBuilderAction, Microsoft.Extensions.DependencyInjection.ServiceLifetime connectionLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Transient, Microsoft.Extensions.DependencyInjection.ServiceLifetime dataSourceLifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Singleton, object? serviceKey = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! diff --git a/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs b/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs index 14da2f893e..dda11bd1d7 100644 --- a/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs +++ b/src/Npgsql.GeoJSON/CrsMap.WellKnown.cs @@ -5,10 +5,10 @@ public partial class CrsMap /// /// These entries came from spatial_res_sys. They are used to elide memory allocations /// if they are identical to the entries for the current connection. Otherwise, - /// memory allocated for overrided entries only (added, removed, or modified). + /// memory allocated for overridden entries only (added, removed, or modified). /// internal static readonly CrsMapEntry[] WellKnown = - { + [ new(2000, 2180, "EPSG"), new(2188, 2217, "EPSG"), new(2219, 2220, "EPSG"), @@ -584,6 +584,6 @@ public partial class CrsMap new(32601, 32667, "EPSG"), new(32701, 32761, "EPSG"), new(32766, 32766, "EPSG"), - new(900913, 900913, "spatialreferencing.org"), - }; + new(900913, 900913, "spatialreferencing.org") + ]; } diff --git a/src/Npgsql.GeoJSON/CrsMap.cs b/src/Npgsql.GeoJSON/CrsMap.cs index dd556d9b33..602387a911 100644 --- a/src/Npgsql.GeoJSON/CrsMap.cs +++ b/src/Npgsql.GeoJSON/CrsMap.cs @@ -6,13 +6,13 @@ namespace Npgsql.GeoJSON; /// public partial class CrsMap { - readonly CrsMapEntry[]? _overriden; + readonly CrsMapEntry[]? _overridden; - internal CrsMap(CrsMapEntry[]? overriden) - => _overriden = overriden; + internal CrsMap(CrsMapEntry[]? overridden) + => _overridden = overridden; internal string? GetAuthority(int srid) - => GetAuthority(_overriden, srid) ?? 
GetAuthority(WellKnown, srid); + => GetAuthority(_overridden, srid) ?? GetAuthority(WellKnown, srid); static string? GetAuthority(CrsMapEntry[]? entries, int srid) { diff --git a/src/Npgsql.GeoJSON/Internal/BoundingBoxBuilder.cs b/src/Npgsql.GeoJSON/Internal/BoundingBoxBuilder.cs index 7702a7e0b3..c3ea8f271f 100644 --- a/src/Npgsql.GeoJSON/Internal/BoundingBoxBuilder.cs +++ b/src/Npgsql.GeoJSON/Internal/BoundingBoxBuilder.cs @@ -48,6 +48,6 @@ internal void Accumulate(Position position) internal double[] Build() => _hasAltitude - ? new[] { _minLongitude, _minLatitude, _minAltitude, _maxLongitude, _maxLatitude, _maxAltitude } - : new[] { _minLongitude, _minLatitude, _maxLongitude, _maxLatitude }; + ? [_minLongitude, _minLatitude, _minAltitude, _maxLongitude, _maxLatitude, _maxAltitude] + : [_minLongitude, _minLatitude, _maxLongitude, _maxLatitude]; } \ No newline at end of file diff --git a/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs b/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs index 44829761c9..95f45d5db3 100644 --- a/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs +++ b/src/Npgsql.GeoJSON/Internal/CrsMapBuilder.cs @@ -5,7 +5,7 @@ namespace Npgsql.GeoJSON.Internal; struct CrsMapBuilder { CrsMapEntry[] _overrides; - int _overridenIndex; + int _overriddenIndex; int _wellKnownIndex; internal void Add(in CrsMapEntry entry) @@ -33,21 +33,21 @@ internal void Add(in CrsMapEntry entry) void AddCore(in CrsMapEntry entry) { - var index = _overridenIndex + 1; + var index = _overriddenIndex + 1; if (_overrides == null) _overrides = new CrsMapEntry[4]; else if (_overrides.Length == index) Array.Resize(ref _overrides, _overrides.Length << 1); - _overrides[_overridenIndex] = entry; - _overridenIndex = index; + _overrides[_overriddenIndex] = entry; + _overriddenIndex = index; } internal CrsMap Build() { - if (_overrides != null && _overrides.Length < _overridenIndex) - Array.Resize(ref _overrides, _overridenIndex); + if (_overrides != null && _overrides.Length < _overriddenIndex) + 
Array.Resize(ref _overrides, _overriddenIndex); return new CrsMap(_overrides); } diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs index 755c8acc19..5d54d16194 100644 --- a/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs +++ b/src/Npgsql.GeoJSON/Internal/GeoJSONConverter.cs @@ -278,16 +278,9 @@ static Position ReadPosition(PgReader reader, EwkbGeometryType type, bool little return position; double ReadDouble(bool littleEndian) - { - if (littleEndian) - { - var doubleValue = reader.ReadDouble(); - var value = BinaryPrimitives.ReverseEndianness(Unsafe.As(ref doubleValue)); - return Unsafe.As(ref value); - } - - return reader.ReadDouble(); - } + => littleEndian + ? BitConverter.Int64BitsToDouble(BinaryPrimitives.ReverseEndianness(BitConverter.DoubleToInt64Bits(reader.ReadDouble()))) + : reader.ReadDouble(); } } @@ -330,7 +323,7 @@ static Size GetSize(LineString value) { var coordinates = value.Coordinates; if (NotValid(coordinates, out var hasZ)) - throw AllOrNoneCoordiantesMustHaveZ(nameof(LineString)); + throw AllOrNoneCoordinatesMustHaveZ(nameof(LineString)); var length = Size.Create(SizeOfHeaderWithLength + coordinates.Count * SizeOfPoint(hasZ)); if (GetSrid(value.CRS) != 0) @@ -351,12 +344,12 @@ static Size GetSize(Polygon value) { var coordinates = lines[i].Coordinates; if (NotValid(coordinates, out var lineHasZ)) - throw AllOrNoneCoordiantesMustHaveZ(nameof(Polygon)); + throw AllOrNoneCoordinatesMustHaveZ(nameof(Polygon)); if (hasZ != lineHasZ) { if (i == 0) hasZ = lineHasZ; - else throw AllOrNoneCoordiantesMustHaveZ(nameof(LineString)); + else throw AllOrNoneCoordinatesMustHaveZ(nameof(LineString)); } length = length.Combine(coordinates.Count * SizeOfPoint(hasZ)); @@ -685,7 +678,7 @@ static int SizeOfPoint(EwkbGeometryType type) static Exception UnknownPostGisType() => throw new InvalidOperationException("Invalid PostGIS type"); - static Exception AllOrNoneCoordiantesMustHaveZ(string typeName) + 
static Exception AllOrNoneCoordinatesMustHaveZ(string typeName) => new ArgumentException($"The Z coordinate must be specified for all or none elements of {typeName}"); static int GetSrid(ICRSObject crs) diff --git a/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolverFactory.cs b/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolverFactory.cs index c25118f1d7..f1b56000f2 100644 --- a/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolverFactory.cs +++ b/src/Npgsql.GeoJSON/Internal/GeoJSONTypeInfoResolverFactory.cs @@ -6,37 +6,17 @@ namespace Npgsql.GeoJSON.Internal; -sealed class GeoJSONTypeInfoResolverFactory : PgTypeInfoResolverFactory +sealed class GeoJSONTypeInfoResolverFactory(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) + : PgTypeInfoResolverFactory { - readonly GeoJSONOptions _options; - readonly bool _geographyAsDefault; - readonly CrsMap? _crsMap; + public override IPgTypeInfoResolver CreateResolver() => new Resolver(options, geographyAsDefault, crsMap); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(options, geographyAsDefault, crsMap); - public GeoJSONTypeInfoResolverFactory(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) + class Resolver(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) + : IPgTypeInfoResolver { - _options = options; - _geographyAsDefault = geographyAsDefault; - _crsMap = crsMap; - } - - public override IPgTypeInfoResolver CreateResolver() => new Resolver(_options, _geographyAsDefault, _crsMap); - public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(_options, _geographyAsDefault, _crsMap); - - class Resolver : IPgTypeInfoResolver - { - readonly GeoJSONOptions _options; - readonly bool _geographyAsDefault; - readonly CrsMap? _crsMap; - TypeInfoMappingCollection? 
_mappings; - protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _options, _geographyAsDefault, _crsMap); - - public Resolver(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) - { - _options = options; - _geographyAsDefault = geographyAsDefault; - _crsMap = crsMap; - } + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), options, geographyAsDefault, crsMap); public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options); @@ -83,16 +63,12 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, } } - sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + sealed class ArrayResolver(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) + : Resolver(options, geographyAsDefault, crsMap), IPgTypeInfoResolver { TypeInfoMappingCollection? _mappings; new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); - public ArrayResolver(GeoJSONOptions options, bool geographyAsDefault, CrsMap? crsMap = null) - : base(options, geographyAsDefault, crsMap) - { - } - public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options); diff --git a/src/Npgsql.GeoJSON/Npgsql.GeoJSON.csproj b/src/Npgsql.GeoJSON/Npgsql.GeoJSON.csproj index a802ca5653..824c5b79e6 100644 --- a/src/Npgsql.GeoJSON/Npgsql.GeoJSON.csproj +++ b/src/Npgsql.GeoJSON/Npgsql.GeoJSON.csproj @@ -3,8 +3,8 @@ Yoh Deadfall;Shay Rojansky GeoJSON plugin for Npgsql, allowing mapping of PostGIS geometry types to GeoJSON types. 
npgsql;postgresql;postgres;postgis;geojson;spatial;ado;ado.net;database;sql - net6.0 - net8.0 + net10.0 + $(NoWarn);NPG9001 diff --git a/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs b/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs index b47a9b211f..9651004a86 100644 --- a/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs +++ b/src/Npgsql.GeoJSON/NpgsqlGeoJSONExtensions.cs @@ -10,6 +10,7 @@ namespace Npgsql; /// public static class NpgsqlGeoJSONExtensions { + // Note: defined for binary compatibility and NpgsqlConnection.GlobalTypeMapper. /// /// Sets up GeoJSON mappings for the PostGIS types. /// @@ -22,6 +23,7 @@ public static INpgsqlTypeMapper UseGeoJson(this INpgsqlTypeMapper mapper, GeoJSO return mapper; } + // Note: defined for binary compatibility and NpgsqlConnection.GlobalTypeMapper. /// /// Sets up GeoJSON mappings for the PostGIS types. /// @@ -34,4 +36,31 @@ public static INpgsqlTypeMapper UseGeoJson(this INpgsqlTypeMapper mapper, CrsMap mapper.AddTypeInfoResolverFactory(new GeoJSONTypeInfoResolverFactory(options, geographyAsDefault, crsMap)); return mapper; } + + /// + /// Sets up GeoJSON mappings for the PostGIS types. + /// + /// The type mapper to set up (global or connection-specific) + /// Options to use when constructing objects. + /// Specifies that the geography type is used for mapping by default. + public static TMapper UseGeoJson(this TMapper mapper, GeoJSONOptions options = GeoJSONOptions.None, bool geographyAsDefault = false) + where TMapper : INpgsqlTypeMapper + { + mapper.AddTypeInfoResolverFactory(new GeoJSONTypeInfoResolverFactory(options, geographyAsDefault, crsMap: null)); + return mapper; + } + + /// + /// Sets up GeoJSON mappings for the PostGIS types. + /// + /// The type mapper to set up (global or connection-specific) + /// A custom crs map that might contain more or less entries than the default well-known crs map. + /// Options to use when constructing objects. + /// Specifies that the geography type is used for mapping by default. 
+ public static TMapper UseGeoJson(this TMapper mapper, CrsMap crsMap, GeoJSONOptions options = GeoJSONOptions.None, bool geographyAsDefault = false) + where TMapper : INpgsqlTypeMapper + { + mapper.AddTypeInfoResolverFactory(new GeoJSONTypeInfoResolverFactory(options, geographyAsDefault, crsMap)); + return mapper; + } } diff --git a/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt b/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt index ab058de62d..34de07f0d7 100644 --- a/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt +++ b/src/Npgsql.GeoJSON/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +static Npgsql.NpgsqlGeoJSONExtensions.UseGeoJson(this TMapper mapper, Npgsql.GeoJSON.CrsMap! crsMap, Npgsql.GeoJSONOptions options = Npgsql.GeoJSONOptions.None, bool geographyAsDefault = false) -> TMapper +static Npgsql.NpgsqlGeoJSONExtensions.UseGeoJson(this TMapper mapper, Npgsql.GeoJSONOptions options = Npgsql.GeoJSONOptions.None, bool geographyAsDefault = false) -> TMapper diff --git a/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs b/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs index 42b7c88e0d..10126d25f9 100644 --- a/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs +++ b/src/Npgsql.Json.NET/Internal/JsonNetJsonConverter.cs @@ -10,35 +10,24 @@ namespace Npgsql.Json.NET.Internal; -sealed class JsonNetJsonConverter : PgStreamingConverter +sealed class JsonNetJsonConverter(bool jsonb, Encoding textEncoding, JsonSerializerSettings settings) : PgStreamingConverter { - readonly bool _jsonb; - readonly Encoding _textEncoding; - readonly JsonSerializerSettings _settings; - - public JsonNetJsonConverter(bool jsonb, Encoding textEncoding, JsonSerializerSettings settings) - { - _jsonb = jsonb; - _textEncoding = textEncoding; - _settings = settings; - } - public override T? 
Read(PgReader reader) - => (T?)JsonNetJsonConverter.Read(async: false, _jsonb, reader, typeof(T), _settings, _textEncoding, CancellationToken.None).GetAwaiter().GetResult(); + => (T?)JsonNetJsonConverter.Read(async: false, jsonb, reader, typeof(T), settings, textEncoding, CancellationToken.None).GetAwaiter().GetResult(); public override async ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) - => (T?)await JsonNetJsonConverter.Read(async: true, _jsonb, reader, typeof(T), _settings, _textEncoding, cancellationToken).ConfigureAwait(false); + => (T?)await JsonNetJsonConverter.Read(async: true, jsonb, reader, typeof(T), settings, textEncoding, cancellationToken).ConfigureAwait(false); public override Size GetSize(SizeContext context, T? value, ref object? writeState) - => JsonNetJsonConverter.GetSize(_jsonb, context, typeof(T), _settings, _textEncoding, value, ref writeState); + => JsonNetJsonConverter.GetSize(jsonb, context, typeof(T), settings, textEncoding, value, ref writeState); public override void Write(PgWriter writer, T? value) - => JsonNetJsonConverter.Write(_jsonb, async: false, writer, CancellationToken.None).GetAwaiter().GetResult(); + => JsonNetJsonConverter.Write(jsonb, async: false, writer, CancellationToken.None).GetAwaiter().GetResult(); public override ValueTask WriteAsync(PgWriter writer, T? value, CancellationToken cancellationToken = default) - => JsonNetJsonConverter.Write(_jsonb, async: true, writer, cancellationToken); + => JsonNetJsonConverter.Write(jsonb, async: true, writer, cancellationToken); } -// Split out to avoid unneccesary code duplication. +// Split out to avoid unnecessary code duplication. 
static class JsonNetJsonConverter { public const byte JsonbProtocolVersion = 1; @@ -62,7 +51,7 @@ static class JsonNetJsonConverter using var stream = reader.GetStream(); var mem = new MemoryStream(); if (async) - await stream.CopyToAsync(mem, Math.Min((int)mem.Length, 81920), cancellationToken).ConfigureAwait(false); + await stream.CopyToAsync(mem, Math.Min((int)stream.Length, 81920), cancellationToken).ConfigureAwait(false); else stream.CopyTo(mem); mem.Position = 0; diff --git a/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolverFactory.cs b/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolverFactory.cs index 27f719deca..c038f17aab 100644 --- a/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolverFactory.cs +++ b/src/Npgsql.Json.NET/Internal/JsonNetPocoTypeInfoResolverFactory.cs @@ -9,43 +9,29 @@ namespace Npgsql.Json.NET.Internal; [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] -sealed class JsonNetPocoTypeInfoResolverFactory : PgTypeInfoResolverFactory +sealed class JsonNetPocoTypeInfoResolverFactory( + Type[]? jsonbClrTypes = null, + Type[]? jsonClrTypes = null, + JsonSerializerSettings? serializerSettings = null) + : PgTypeInfoResolverFactory { - readonly Type[]? _jsonbClrTypes; - readonly Type[]? _jsonClrTypes; - readonly JsonSerializerSettings? _serializerSettings; - - public JsonNetPocoTypeInfoResolverFactory(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) - { - _jsonbClrTypes = jsonbClrTypes; - _jsonClrTypes = jsonClrTypes; - _serializerSettings = serializerSettings; - } - - public override IPgTypeInfoResolver CreateResolver() => new Resolver(_jsonbClrTypes, _jsonClrTypes, _serializerSettings); - public override IPgTypeInfoResolver? 
CreateArrayResolver() => new ArrayResolver(_jsonbClrTypes, _jsonClrTypes, _serializerSettings); + public override IPgTypeInfoResolver CreateResolver() => new Resolver(jsonbClrTypes, jsonClrTypes, serializerSettings); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(jsonbClrTypes, jsonClrTypes, serializerSettings); [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] - class Resolver : DynamicTypeInfoResolver, IPgTypeInfoResolver + class Resolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) + : DynamicTypeInfoResolver, IPgTypeInfoResolver { - readonly Type[]? _jsonbClrTypes; - readonly Type[]? _jsonClrTypes; - readonly JsonSerializerSettings _serializerSettings; + readonly JsonSerializerSettings _serializerSettings = serializerSettings ?? JsonConvert.DefaultSettings?.Invoke() ?? new JsonSerializerSettings(); TypeInfoMappingCollection? _mappings; - protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _jsonbClrTypes ?? Array.Empty(), _jsonClrTypes ?? Array.Empty(), _serializerSettings); + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), jsonbClrTypes ?? [], jsonClrTypes ?? [], _serializerSettings); const string JsonDataTypeName = "pg_catalog.json"; const string JsonbDataTypeName = "pg_catalog.jsonb"; - public Resolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) - { - _jsonbClrTypes = jsonbClrTypes; - _jsonClrTypes = jsonClrTypes; - // Capture default settings during construction. - _serializerSettings = serializerSettings ?? JsonConvert.DefaultSettings?.Invoke() ?? 
new JsonSerializerSettings(); - } + // Capture default settings during construction. TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, Type[] jsonbClrTypes, Type[] jsonClrTypes, JsonSerializerSettings serializerSettings) { @@ -96,16 +82,12 @@ static PgConverter CreateConverter(Type valueType, bool jsonb, Encoding textEnco [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] - sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + sealed class ArrayResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) + : Resolver(jsonbClrTypes, jsonClrTypes, serializerSettings), IPgTypeInfoResolver { TypeInfoMappingCollection? _mappings; new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings), base.Mappings); - public ArrayResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerSettings? serializerSettings = null) - : base(jsonbClrTypes, jsonClrTypes, serializerSettings) - { - } - public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); diff --git a/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolverFactory.cs b/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolverFactory.cs index 1f07bf0252..be2e0a3ba7 100644 --- a/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolverFactory.cs +++ b/src/Npgsql.Json.NET/Internal/JsonNetTypeInfoResolverFactory.cs @@ -6,26 +6,18 @@ namespace Npgsql.Json.NET.Internal; -sealed class JsonNetTypeInfoResolverFactory : PgTypeInfoResolverFactory +sealed class JsonNetTypeInfoResolverFactory(JsonSerializerSettings? 
settings = null) : PgTypeInfoResolverFactory { - readonly JsonSerializerSettings? _settings; + public override IPgTypeInfoResolver CreateResolver() => new Resolver(settings); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(settings); - public JsonNetTypeInfoResolverFactory(JsonSerializerSettings? settings = null) => _settings = settings; - - public override IPgTypeInfoResolver CreateResolver() => new Resolver(_settings); - public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(_settings); - - class Resolver : IPgTypeInfoResolver + class Resolver(JsonSerializerSettings? settings = null) : IPgTypeInfoResolver { TypeInfoMappingCollection? _mappings; - readonly JsonSerializerSettings _serializerSettings; + readonly JsonSerializerSettings _serializerSettings = settings ?? JsonConvert.DefaultSettings?.Invoke() ?? new JsonSerializerSettings(); protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _serializerSettings); - public Resolver(JsonSerializerSettings? settings = null) - { - // Capture default settings during construction. - _serializerSettings = settings ?? JsonConvert.DefaultSettings?.Invoke() ?? new JsonSerializerSettings(); - } + // Capture default settings during construction. 
static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, JsonSerializerSettings settings) { @@ -34,7 +26,7 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, { var jsonb = dataTypeName == "jsonb"; mappings.AddType(dataTypeName, (options, mapping, _) => - mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings)), isDefault: true); + mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings))); mappings.AddType(dataTypeName, (options, mapping, _) => mapping.CreateInfo(options, new JsonNetJsonConverter(jsonb, options.TextEncoding, settings))); mappings.AddType(dataTypeName, (options, mapping, _) => @@ -50,13 +42,11 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, => Mappings.Find(type, dataTypeName, options); } - sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + sealed class ArrayResolver(JsonSerializerSettings? settings = null) : Resolver(settings), IPgTypeInfoResolver { TypeInfoMappingCollection? _mappings; new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); - public ArrayResolver(JsonSerializerSettings? settings = null) : base(settings) {} - public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options); diff --git a/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj b/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj index 49707eb02f..df6b2ccef6 100644 --- a/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj +++ b/src/Npgsql.Json.NET/Npgsql.Json.NET.csproj @@ -3,9 +3,10 @@ Shay Rojansky Json.NET plugin for Npgsql, allowing transparent serialization/deserialization of JSON objects directly to and from the database. 
npgsql;postgresql;json;postgres;ado;ado.net;database;sql - net6.0 - net8.0 + net10.0 enable + false + $(NoWarn);NPG9001 diff --git a/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs b/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs index f2b33933b8..89c8d21603 100644 --- a/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs +++ b/src/Npgsql.Json.NET/NpgsqlJsonNetExtensions.cs @@ -13,6 +13,7 @@ namespace Npgsql; /// public static class NpgsqlJsonNetExtensions { + // Note: defined for binary compatibility and NpgsqlConnection.GlobalTypeMapper. /// /// Sets up JSON.NET mappings for the PostgreSQL json and jsonb types. /// @@ -37,4 +38,30 @@ public static INpgsqlTypeMapper UseJsonNet( mapper.AddTypeInfoResolverFactory(new JsonNetTypeInfoResolverFactory(settings)); return mapper; } + + /// + /// Sets up JSON.NET mappings for the PostgreSQL json and jsonb types. + /// + /// The type mapper to set up. + /// Optional settings to customize JSON serialization. + /// + /// A list of CLR types to map to PostgreSQL jsonb (no need to specify ). + /// + /// + /// A list of CLR types to map to PostgreSQL json (no need to specify ). + /// + [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] + [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] + public static TMapper UseJsonNet( + this TMapper mapper, + JsonSerializerSettings? settings = null, + Type[]? jsonbClrTypes = null, + Type[]? 
jsonClrTypes = null) + where TMapper : INpgsqlTypeMapper + { + // Reverse order + mapper.AddTypeInfoResolverFactory(new JsonNetPocoTypeInfoResolverFactory(jsonbClrTypes, jsonClrTypes, settings)); + mapper.AddTypeInfoResolverFactory(new JsonNetTypeInfoResolverFactory(settings)); + return mapper; + } } diff --git a/src/Npgsql.Json.NET/PublicAPI.Unshipped.txt b/src/Npgsql.Json.NET/PublicAPI.Unshipped.txt index ab058de62d..f4557570e1 100644 --- a/src/Npgsql.Json.NET/PublicAPI.Unshipped.txt +++ b/src/Npgsql.Json.NET/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ #nullable enable +static Npgsql.NpgsqlJsonNetExtensions.UseJsonNet(this TMapper mapper, Newtonsoft.Json.JsonSerializerSettings? settings = null, System.Type![]? jsonbClrTypes = null, System.Type![]? jsonClrTypes = null) -> TMapper diff --git a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolverFactory.cs b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolverFactory.cs index b9a559c12f..e533d62207 100644 --- a/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolverFactory.cs +++ b/src/Npgsql.NetTopologySuite/Internal/NetTopologySuiteTypeInfoResolverFactory.cs @@ -7,24 +7,15 @@ namespace Npgsql.NetTopologySuite.Internal; -sealed class NetTopologySuiteTypeInfoResolverFactory : PgTypeInfoResolverFactory +sealed class NetTopologySuiteTypeInfoResolverFactory( + CoordinateSequenceFactory? coordinateSequenceFactory, + PrecisionModel? precisionModel, + Ordinates handleOrdinates, + bool geographyAsDefault) + : PgTypeInfoResolverFactory { - readonly CoordinateSequenceFactory? _coordinateSequenceFactory; - readonly PrecisionModel? _precisionModel; - readonly Ordinates _handleOrdinates; - readonly bool _geographyAsDefault; - - public NetTopologySuiteTypeInfoResolverFactory(CoordinateSequenceFactory? coordinateSequenceFactory, PrecisionModel? 
precisionModel, - Ordinates handleOrdinates, bool geographyAsDefault) - { - _coordinateSequenceFactory = coordinateSequenceFactory; - _precisionModel = precisionModel; - _handleOrdinates = handleOrdinates; - _geographyAsDefault = geographyAsDefault; - } - - public override IPgTypeInfoResolver CreateResolver() => new Resolver(_coordinateSequenceFactory, _precisionModel, _handleOrdinates, _geographyAsDefault); - public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(_coordinateSequenceFactory, _precisionModel, _handleOrdinates, _geographyAsDefault); + public override IPgTypeInfoResolver CreateResolver() => new Resolver(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault); + public override IPgTypeInfoResolver? CreateArrayResolver() => new ArrayResolver(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault); class Resolver : IPgTypeInfoResolver { @@ -54,7 +45,7 @@ public Resolver( static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, PostGisReader reader, PostGisWriter writer, bool geographyAsDefault) { - foreach (var dataTypeName in geographyAsDefault ? new[] {"geography", "geometry"} : new[] { "geometry", "geography" }) + foreach (var dataTypeName in geographyAsDefault ? ["geography", "geometry"] : new[] { "geometry", "geography" }) { mappings.AddType(dataTypeName, (options, mapping, _) => mapping.CreateInfo(options, new NetTopologySuiteConverter(reader, writer)), @@ -82,23 +73,22 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, } } - sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + sealed class ArrayResolver( + CoordinateSequenceFactory? coordinateSequenceFactory, + PrecisionModel? precisionModel, + Ordinates handleOrdinates, + bool geographyAsDefault) + : Resolver(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault), IPgTypeInfoResolver { TypeInfoMappingCollection? 
_mappings; new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings), _geographyAsDefault); - public ArrayResolver(CoordinateSequenceFactory? coordinateSequenceFactory, PrecisionModel? precisionModel, - Ordinates handleOrdinates, bool geographyAsDefault) - : base(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault) - { - } - public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options); static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, bool geographyAsDefault) { - foreach (var dataTypeName in geographyAsDefault ? new[] { "geography", "geometry" } : new[] { "geometry", "geography" }) + foreach (var dataTypeName in geographyAsDefault ? ["geography", "geometry"] : new[] { "geometry", "geography" }) { mappings.AddArrayType(dataTypeName); mappings.AddArrayType(dataTypeName); diff --git a/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj b/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj index fd2342614b..91a4c268a0 100644 --- a/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj +++ b/src/Npgsql.NetTopologySuite/Npgsql.NetTopologySuite.csproj @@ -4,9 +4,9 @@ NetTopologySuite plugin for Npgsql, allowing mapping of PostGIS geometry types to NetTopologySuite types. 
npgsql;postgresql;postgres;postgis;spatial;nettopologysuite;nts;ado;ado.net;database;sql README.md - net6.0 - net8.0 + net10.0 $(NoWarn);NU5104 + $(NoWarn);NPG9001 diff --git a/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs b/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs index a30d023891..76afcf886c 100644 --- a/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs +++ b/src/Npgsql.NetTopologySuite/NpgsqlNetTopologySuiteExtensions.cs @@ -10,6 +10,7 @@ namespace Npgsql; /// public static class NpgsqlNetTopologySuiteExtensions { + // Note: defined for binary compatibility and NpgsqlConnection.GlobalTypeMapper. /// /// Sets up NetTopologySuite mappings for the PostGIS types. /// @@ -30,4 +31,26 @@ public static INpgsqlTypeMapper UseNetTopologySuite( mapper.AddTypeInfoResolverFactory(new NetTopologySuiteTypeInfoResolverFactory(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault)); return mapper; } + + /// + /// Sets up NetTopologySuite mappings for the PostGIS types. + /// + /// The type mapper to set up (global or connection-specific). + /// The factory which knows how to build a particular implementation of ICoordinateSequence from an array of Coordinates. + /// Specifies the grid of allowable points. + /// Specifies the ordinates which will be handled. Not specified ordinates will be ignored. + /// If is specified, an actual value will be taken from + /// the property of . + /// Specifies that the geography type is used for mapping by default. + public static TMapper UseNetTopologySuite( + this TMapper mapper, + CoordinateSequenceFactory? coordinateSequenceFactory = null, + PrecisionModel? 
precisionModel = null, + Ordinates handleOrdinates = Ordinates.None, + bool geographyAsDefault = false) + where TMapper : INpgsqlTypeMapper + { + mapper.AddTypeInfoResolverFactory(new NetTopologySuiteTypeInfoResolverFactory(coordinateSequenceFactory, precisionModel, handleOrdinates, geographyAsDefault)); + return mapper; + } } diff --git a/src/Npgsql.NetTopologySuite/PublicAPI.Unshipped.txt b/src/Npgsql.NetTopologySuite/PublicAPI.Unshipped.txt index ab058de62d..ab78bca1af 100644 --- a/src/Npgsql.NetTopologySuite/PublicAPI.Unshipped.txt +++ b/src/Npgsql.NetTopologySuite/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ #nullable enable +static Npgsql.NpgsqlNetTopologySuiteExtensions.UseNetTopologySuite(this TMapper mapper, NetTopologySuite.Geometries.CoordinateSequenceFactory? coordinateSequenceFactory = null, NetTopologySuite.Geometries.PrecisionModel? precisionModel = null, NetTopologySuite.Geometries.Ordinates handleOrdinates = NetTopologySuite.Geometries.Ordinates.None, bool geographyAsDefault = false) -> TMapper diff --git a/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs b/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs index 5e25d8bfcc..1bf2d027df 100644 --- a/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs +++ b/src/Npgsql.NodaTime/Internal/DateIntervalConverter.cs @@ -6,17 +6,9 @@ namespace Npgsql.NodaTime.Internal; -public class DateIntervalConverter : PgStreamingConverter +public class DateIntervalConverter(PgConverter> rangeConverter, bool dateTimeInfinityConversions) + : PgStreamingConverter { - readonly bool _dateTimeInfinityConversions; - readonly PgConverter> _rangeConverter; - - public DateIntervalConverter(PgConverter> rangeConverter, bool dateTimeInfinityConversions) - { - _rangeConverter = rangeConverter; - _dateTimeInfinityConversions = dateTimeInfinityConversions; - } - public override DateInterval Read(PgReader reader) => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); @@ -26,24 +18,24 @@ public override 
ValueTask ReadAsync(PgReader reader, CancellationT async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) { var range = async - ? await _rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + ? await rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) // ReSharper disable once MethodHasAsyncOverloadWithCancellation - : _rangeConverter.Read(reader); + : rangeConverter.Read(reader); var upperBound = range.UpperBound; - if (upperBound != LocalDate.MaxIsoValue || !_dateTimeInfinityConversions) + if (upperBound != LocalDate.MaxIsoValue || !dateTimeInfinityConversions) upperBound -= Period.FromDays(1); return new(range.LowerBound, upperBound); } public override Size GetSize(SizeContext context, DateInterval value, ref object? writeState) - => _rangeConverter.GetSize(context, new NpgsqlRange(value.Start, value.End), ref writeState); + => rangeConverter.GetSize(context, new NpgsqlRange(value.Start, value.End), ref writeState); public override void Write(PgWriter writer, DateInterval value) - => _rangeConverter.Write(writer, new NpgsqlRange(value.Start, value.End)); + => rangeConverter.Write(writer, new NpgsqlRange(value.Start, value.End)); public override ValueTask WriteAsync(PgWriter writer, DateInterval value, CancellationToken cancellationToken = default) - => _rangeConverter.WriteAsync(writer, new NpgsqlRange(value.Start, value.End), cancellationToken); + => rangeConverter.WriteAsync(writer, new NpgsqlRange(value.Start, value.End), cancellationToken); } diff --git a/src/Npgsql.NodaTime/Internal/IntervalConverter.cs b/src/Npgsql.NodaTime/Internal/IntervalConverter.cs index 3ca9ca9ab0..f062079a4a 100644 --- a/src/Npgsql.NodaTime/Internal/IntervalConverter.cs +++ b/src/Npgsql.NodaTime/Internal/IntervalConverter.cs @@ -6,13 +6,8 @@ namespace Npgsql.NodaTime.Internal; -public class IntervalConverter : PgStreamingConverter +sealed class IntervalConverter(PgConverter> rangeConverter, bool 
dateTimeInfinityConversions) : PgStreamingConverter { - readonly PgConverter> _rangeConverter; - - public IntervalConverter(PgConverter> rangeConverter) - => _rangeConverter = rangeConverter; - public override Interval Read(PgReader reader) => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); @@ -22,9 +17,9 @@ public override ValueTask ReadAsync(PgReader reader, CancellationToken async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) { var range = async - ? await _rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) + ? await rangeConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) // ReSharper disable once MethodHasAsyncOverloadWithCancellation - : _rangeConverter.Read(reader); + : rangeConverter.Read(reader); // NodaTime Interval includes the start instant and excludes the end instant. Instant? start = range.LowerBoundInfinite @@ -32,7 +27,12 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken ca : range.LowerBoundIsInclusive ? range.LowerBound : range.LowerBound + Duration.Epsilon; - Instant? end = range.UpperBoundInfinite + // For ranges with element types with infinity values (datetime, date etc.) an + // inclusive lower/upper bound causes their -/+ infinity (respectively) to fall within the range. + // If those values are returned for such a range postgres will not mark the affected bound as infinite accordingly. + // This is documented in https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-INFINITE + // As NodaTime uses an exclusive upper bound we must consider this case as being another form of infinity (null). + Instant? end = range.UpperBoundInfinite || (dateTimeInfinityConversions && range.UpperBoundIsInclusive && range.UpperBound == Instant.MaxValue) ? null : range.UpperBoundIsInclusive ? 
range.UpperBound + Duration.Epsilon @@ -42,13 +42,13 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken ca } public override Size GetSize(SizeContext context, Interval value, ref object? writeState) - => _rangeConverter.GetSize(context, IntervalToNpgsqlRange(value), ref writeState); + => rangeConverter.GetSize(context, IntervalToNpgsqlRange(value), ref writeState); public override void Write(PgWriter writer, Interval value) - => _rangeConverter.Write(writer, IntervalToNpgsqlRange(value)); + => rangeConverter.Write(writer, IntervalToNpgsqlRange(value)); public override ValueTask WriteAsync(PgWriter writer, Interval value, CancellationToken cancellationToken = default) - => _rangeConverter.WriteAsync(writer, IntervalToNpgsqlRange(value), cancellationToken); + => rangeConverter.WriteAsync(writer, IntervalToNpgsqlRange(value), cancellationToken); static NpgsqlRange IntervalToNpgsqlRange(Interval interval) => new( diff --git a/src/Npgsql.NodaTime/Internal/LegacyConverters.cs b/src/Npgsql.NodaTime/Internal/LegacyConverters.cs index 54393a4821..c0b4b82268 100644 --- a/src/Npgsql.NodaTime/Internal/LegacyConverters.cs +++ b/src/Npgsql.NodaTime/Internal/LegacyConverters.cs @@ -5,17 +5,9 @@ namespace Npgsql.NodaTime.Internal; -sealed class LegacyTimestampTzZonedDateTimeConverter : PgBufferedConverter +sealed class LegacyTimestampTzZonedDateTimeConverter(DateTimeZone dateTimeZone, bool dateTimeInfinityConversions) + : PgBufferedConverter { - readonly DateTimeZone _dateTimeZone; - readonly bool _dateTimeInfinityConversions; - - public LegacyTimestampTzZonedDateTimeConverter(DateTimeZone dateTimeZone, bool dateTimeInfinityConversions) - { - _dateTimeZone = dateTimeZone; - _dateTimeInfinityConversions = dateTimeInfinityConversions; - } - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -24,34 +16,26 @@ public override bool 
CanConvert(DataFormat format, out BufferRequirements buffer protected override ZonedDateTime ReadCore(PgReader reader) { - var instant = DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); - if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + var instant = DecodeInstant(reader.ReadInt64(), dateTimeInfinityConversions); + if (dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); - return instant.InZone(_dateTimeZone); + return instant.InZone(dateTimeZone); } protected override void WriteCore(PgWriter writer, ZonedDateTime value) { var instant = value.ToInstant(); - if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + if (dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) throw new ArgumentException("Infinity values not supported for timestamp with time zone"); - writer.WriteInt64(EncodeInstant(instant, _dateTimeInfinityConversions)); + writer.WriteInt64(EncodeInstant(instant, dateTimeInfinityConversions)); } } -sealed class LegacyTimestampTzOffsetDateTimeConverter : PgBufferedConverter +sealed class LegacyTimestampTzOffsetDateTimeConverter(DateTimeZone dateTimeZone, bool dateTimeInfinityConversions) + : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - readonly DateTimeZone _dateTimeZone; - - public LegacyTimestampTzOffsetDateTimeConverter(DateTimeZone dateTimeZone, bool dateTimeInfinityConversions) - { - _dateTimeInfinityConversions = dateTimeInfinityConversions; - _dateTimeZone = dateTimeZone; - } - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -60,17 +44,17 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer 
protected override OffsetDateTime ReadCore(PgReader reader) { - var instant = DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); - if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + var instant = DecodeInstant(reader.ReadInt64(), dateTimeInfinityConversions); + if (dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) throw new InvalidCastException("Infinity values not supported for timestamp with time zone"); - return instant.InZone(_dateTimeZone).ToOffsetDateTime(); + return instant.InZone(dateTimeZone).ToOffsetDateTime(); } protected override void WriteCore(PgWriter writer, OffsetDateTime value) { var instant = value.ToInstant(); - if (_dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) + if (dateTimeInfinityConversions && (instant == Instant.MaxValue || instant == Instant.MinValue)) throw new ArgumentException("Infinity values not supported for timestamp with time zone"); writer.WriteInt64(EncodeInstant(instant, true)); diff --git a/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs b/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs index e6be7fe69b..ffaa6e8d45 100644 --- a/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs +++ b/src/Npgsql.NodaTime/Internal/LocalDateConverter.cs @@ -5,13 +5,8 @@ namespace Npgsql.NodaTime.Internal; -sealed class LocalDateConverter : PgBufferedConverter +sealed class LocalDateConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - - public LocalDateConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int)); @@ -21,10 +16,10 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer 
protected override LocalDate ReadCore(PgReader reader) => reader.ReadInt32() switch { - int.MaxValue => _dateTimeInfinityConversions + int.MaxValue => dateTimeInfinityConversions ? LocalDate.MaxIsoValue : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), - int.MinValue => _dateTimeInfinityConversions + int.MinValue => dateTimeInfinityConversions ? LocalDate.MinIsoValue : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue), var value => new LocalDate().PlusDays(value + 730119) @@ -32,7 +27,7 @@ protected override LocalDate ReadCore(PgReader reader) protected override void WriteCore(PgWriter writer, LocalDate value) { - if (_dateTimeInfinityConversions) + if (dateTimeInfinityConversions) { if (value == LocalDate.MaxIsoValue) { diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Multirange.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Multirange.cs index 42c6360dad..fdd8d4c78f 100644 --- a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Multirange.cs +++ b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Multirange.cs @@ -31,12 +31,12 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) mappings.AddType(TimestampTzMultirangeDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, CreateArrayMultirangeConverter(new IntervalConverter( - CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options)), options)), + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options), options.EnableDateTimeInfinityConversions), options)), isDefault: true); mappings.AddType>(TimestampTzMultirangeDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, CreateListMultirangeConverter(new IntervalConverter( - CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options)), options))); + 
CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options), options.EnableDateTimeInfinityConversions), options))); mappings.AddType[]>(TimestampTzMultirangeDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Range.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Range.cs index f62669333c..8958f88846 100644 --- a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Range.cs +++ b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.Range.cs @@ -31,7 +31,7 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) static (options, mapping, _) => mapping.CreateInfo(options, new IntervalConverter( - CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options))), + CreateRangeConverter(new InstantConverter(options.EnableDateTimeInfinityConversions), options), options.EnableDateTimeInfinityConversions)), isDefault: true); mappings.AddStructType>(TimestampTzRangeDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, diff --git a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.cs b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.cs index dce258b453..b010ce58a6 100644 --- a/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.cs +++ b/src/Npgsql.NodaTime/Internal/NodaTimeTypeInfoResolverFactory.cs @@ -31,47 +31,47 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) // timestamp and timestamptz, legacy and non-legacy modes if (LegacyTimestampBehavior) { + // timestamp is the default for writing an Instant. 
+ + // timestamp + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + mappings.AddStructType(TimestampDataTypeName, + static (options, mapping, _) => + mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions))); + // timestamptz - mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + mappings.AddStructType(TimestampTzDataTypeName, static (options, mapping, _) => - mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: false); - mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); + mappings.AddStructType(TimestampTzDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, new LegacyTimestampTzZonedDateTimeConverter( DateTimeZoneProviders.Tzdb[options.TimeZone], options.EnableDateTimeInfinityConversions))); - mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + mappings.AddStructType(TimestampTzDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, new LegacyTimestampTzOffsetDateTimeConverter( DateTimeZoneProviders.Tzdb[options.TimeZone], options.EnableDateTimeInfinityConversions))); - + } + else + { // timestamp - mappings.AddStructType(TimestampDataTypeName, - static (options, mapping, _) => - mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), - isDefault: true); mappings.AddStructType(TimestampDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), - isDefault: false); - } - else - { + isDefault: true); + // timestamptz mappings.AddStructType(TimestampTzDataTypeName, static (options, mapping, _) => 
mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); - mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + mappings.AddStructType(TimestampTzDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, new ZonedDateTimeConverter(options.EnableDateTimeInfinityConversions))); - mappings.AddStructType(new DataTypeName("pg_catalog.timestamptz"), + mappings.AddStructType(TimestampTzDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, new OffsetDateTimeConverter(options.EnableDateTimeInfinityConversions))); - - // timestamp - mappings.AddStructType(TimestampDataTypeName, - static (options, mapping, _) => - mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), - isDefault: true); } // date @@ -89,7 +89,7 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) // interval mappings.AddType(IntervalDataTypeName, - static (options, mapping, _) => mapping.CreateInfo(options, new PeriodConverter()), isDefault: true); + static (options, mapping, _) => mapping.CreateInfo(options, new PeriodConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); mappings.AddStructType(IntervalDataTypeName, static (options, mapping, _) => mapping.CreateInfo(options, new DurationConverter())); @@ -107,34 +107,27 @@ sealed class ArrayResolver : Resolver, IPgTypeInfoResolver static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) { - // timestamptz - mappings.AddStructArrayType(TimestampTzDataTypeName); - mappings.AddStructArrayType(TimestampTzDataTypeName); - mappings.AddStructArrayType(TimestampTzDataTypeName); - - // timestamp if (LegacyTimestampBehavior) { + // timestamp mappings.AddStructArrayType(TimestampDataTypeName); + mappings.AddStructArrayType(TimestampDataTypeName); - mappings.AddStructType(TimestampDataTypeName, - static (options, mapping, _) => - 
mapping.CreateInfo(options, new InstantConverter(options.EnableDateTimeInfinityConversions)), - isDefault: true); - mappings.AddStructType(TimestampDataTypeName, - static (options, mapping, _) => - mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), - isDefault: false); + // timestamptz + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); } else { - mappings.AddStructType(TimestampDataTypeName, - static (options, mapping, _) => - mapping.CreateInfo(options, new LocalDateTimeConverter(options.EnableDateTimeInfinityConversions)), - isDefault: true); - } + // timestamp + mappings.AddStructArrayType(TimestampDataTypeName); - mappings.AddStructArrayType(TimestampDataTypeName); + // timestamptz + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + mappings.AddStructArrayType(TimestampTzDataTypeName); + } // other mappings.AddStructArrayType(DateDataTypeName); diff --git a/src/Npgsql.NodaTime/Internal/PeriodConverter.cs b/src/Npgsql.NodaTime/Internal/PeriodConverter.cs index 4dbde48dbc..1d768109c4 100644 --- a/src/Npgsql.NodaTime/Internal/PeriodConverter.cs +++ b/src/Npgsql.NodaTime/Internal/PeriodConverter.cs @@ -1,9 +1,11 @@ +using System; using NodaTime; using Npgsql.Internal; +using Npgsql.NodaTime.Properties; namespace Npgsql.NodaTime.Internal; -sealed class PeriodConverter : PgBufferedConverter +sealed class PeriodConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -17,6 +19,15 @@ protected override Period ReadCore(PgReader reader) var days = reader.ReadInt32(); var totalMonths = reader.ReadInt32(); + if (microsecondsInDay == long.MaxValue && days == int.MaxValue && totalMonths == int.MaxValue) + return dateTimeInfinityConversions + ? 
Period.MaxValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue); + if (microsecondsInDay == long.MinValue && days == int.MinValue && totalMonths == int.MinValue) + return dateTimeInfinityConversions + ? Period.MinValue + : throw new InvalidCastException(NpgsqlNodaTimeStrings.CannotReadInfinityValue); + // NodaTime will normalize most things (i.e. nanoseconds to milliseconds, seconds...) // but it will not normalize months to years. var months = totalMonths % 12; @@ -33,14 +44,45 @@ protected override Period ReadCore(PgReader reader) protected override void WriteCore(PgWriter writer, Period value) { - // Note that the end result must be long - // see #3438 - var microsecondsInDay = - (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * NodaConstants.MillisecondsPerSecond + value.Milliseconds) * 1000 + - value.Nanoseconds / 1000; // Take the microseconds, discard the nanosecond remainder - - writer.WriteInt64(microsecondsInDay); - writer.WriteInt32(value.Weeks * 7 + value.Days); // days - writer.WriteInt32(value.Years * 12 + value.Months); // months + if (dateTimeInfinityConversions) + { + if (value == Period.MaxValue) + { + writer.WriteInt64(long.MaxValue); // microseconds + writer.WriteInt32(int.MaxValue); // days + writer.WriteInt32(int.MaxValue); // months + return; + } + + if (value == Period.MinValue) + { + writer.WriteInt64(long.MinValue); // microseconds + writer.WriteInt32(int.MinValue); // days + writer.WriteInt32(int.MinValue); // months + return; + } + } + + // We have to normalize the value as otherwise we might get a value with 0 everything except for ticks, which we ignore + value = value.Normalize(); + + try + { + checked + { + // Note that the end result must be long + // see #3438 + var microsecondsInDay = + (((value.Hours * NodaConstants.MinutesPerHour + value.Minutes) * NodaConstants.SecondsPerMinute + value.Seconds) * 
NodaConstants.MillisecondsPerSecond + value.Milliseconds) * 1000 + + value.Nanoseconds / 1000; // Take the microseconds, discard the nanosecond remainder + writer.WriteInt64(microsecondsInDay); + writer.WriteInt32(value.Weeks * 7 + value.Days); // days + writer.WriteInt32(value.Years * 12 + value.Months); // months + } + } + catch (OverflowException ex) + { + throw new ArgumentException(NpgsqlNodaTimeStrings.CannotWritePeriodDueToOverflow, ex); + } } } diff --git a/src/Npgsql.NodaTime/Internal/TimestampConverters.cs b/src/Npgsql.NodaTime/Internal/TimestampConverters.cs index 6808503638..4ac841c80e 100644 --- a/src/Npgsql.NodaTime/Internal/TimestampConverters.cs +++ b/src/Npgsql.NodaTime/Internal/TimestampConverters.cs @@ -5,13 +5,8 @@ namespace Npgsql.NodaTime.Internal; -sealed class InstantConverter : PgBufferedConverter +sealed class InstantConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - - public InstantConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -19,19 +14,14 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer } protected override Instant ReadCore(PgReader reader) - => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions); + => DecodeInstant(reader.ReadInt64(), dateTimeInfinityConversions); protected override void WriteCore(PgWriter writer, Instant value) - => writer.WriteInt64(EncodeInstant(value, _dateTimeInfinityConversions)); + => writer.WriteInt64(EncodeInstant(value, dateTimeInfinityConversions)); } -sealed class ZonedDateTimeConverter : PgBufferedConverter +sealed class ZonedDateTimeConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - - public 
ZonedDateTimeConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -39,7 +29,7 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer } protected override ZonedDateTime ReadCore(PgReader reader) - => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).InUtc(); + => DecodeInstant(reader.ReadInt64(), dateTimeInfinityConversions).InUtc(); protected override void WriteCore(PgWriter writer, ZonedDateTime value) { @@ -51,17 +41,12 @@ protected override void WriteCore(PgWriter writer, ZonedDateTime value) "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); } - writer.WriteInt64(EncodeInstant(value.ToInstant(), _dateTimeInfinityConversions)); + writer.WriteInt64(EncodeInstant(value.ToInstant(), dateTimeInfinityConversions)); } } -sealed class OffsetDateTimeConverter : PgBufferedConverter +sealed class OffsetDateTimeConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - - public OffsetDateTimeConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -69,7 +54,7 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer } protected override OffsetDateTime ReadCore(PgReader reader) - => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).WithOffset(Offset.Zero); + => DecodeInstant(reader.ReadInt64(), dateTimeInfinityConversions).WithOffset(Offset.Zero); protected override void WriteCore(PgWriter writer, OffsetDateTime value) { @@ -81,17 +66,12 @@ protected 
override void WriteCore(PgWriter writer, OffsetDateTime value) "See the Npgsql.EnableLegacyTimestampBehavior AppContext switch to enable legacy behavior."); } - writer.WriteInt64(EncodeInstant(value.ToInstant(), _dateTimeInfinityConversions)); + writer.WriteInt64(EncodeInstant(value.ToInstant(), dateTimeInfinityConversions)); } } -sealed class LocalDateTimeConverter : PgBufferedConverter +sealed class LocalDateTimeConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - - public LocalDateTimeConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -99,8 +79,8 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer } protected override LocalDateTime ReadCore(PgReader reader) - => DecodeInstant(reader.ReadInt64(), _dateTimeInfinityConversions).InUtc().LocalDateTime; + => DecodeInstant(reader.ReadInt64(), dateTimeInfinityConversions).InUtc().LocalDateTime; protected override void WriteCore(PgWriter writer, LocalDateTime value) - => writer.WriteInt64(EncodeInstant(value.InUtc().ToInstant(), _dateTimeInfinityConversions)); + => writer.WriteInt64(EncodeInstant(value.InUtc().ToInstant(), dateTimeInfinityConversions)); } diff --git a/src/Npgsql.NodaTime/Npgsql.NodaTime.csproj b/src/Npgsql.NodaTime/Npgsql.NodaTime.csproj index 4ac9e068fa..1fd5d4b767 100644 --- a/src/Npgsql.NodaTime/Npgsql.NodaTime.csproj +++ b/src/Npgsql.NodaTime/Npgsql.NodaTime.csproj @@ -4,8 +4,8 @@ NodaTime plugin for Npgsql, allowing mapping of PostgreSQL date/time types to NodaTime types. 
npgsql;postgresql;postgres;nodatime;date;time;ado;ado;net;database;sql README.md - net6.0 - net8.0 + net10.0 + $(NoWarn);NPG9001 diff --git a/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs b/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs index 9ebf42e83f..585143f3fe 100644 --- a/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs +++ b/src/Npgsql.NodaTime/NpgsqlNodaTimeExtensions.cs @@ -9,6 +9,7 @@ namespace Npgsql; /// public static class NpgsqlNodaTimeExtensions { + // Note: defined for binary compatibility and NpgsqlConnection.GlobalTypeMapper. /// /// Sets up NodaTime mappings for the PostgreSQL date/time types. /// @@ -18,4 +19,14 @@ public static INpgsqlTypeMapper UseNodaTime(this INpgsqlTypeMapper mapper) mapper.AddTypeInfoResolverFactory(new NodaTimeTypeInfoResolverFactory()); return mapper; } + + /// + /// Sets up NodaTime mappings for the PostgreSQL date/time types. + /// + /// The type mapper to set up (global or connection-specific) + public static TMapper UseNodaTime(this TMapper mapper) where TMapper : INpgsqlTypeMapper + { + mapper.AddTypeInfoResolverFactory(new NodaTimeTypeInfoResolverFactory()); + return mapper; + } } diff --git a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs index bc6511ea9a..ab29289106 100644 --- a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs +++ b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.Designer.cs @@ -11,32 +11,46 @@ namespace Npgsql.NodaTime.Properties { using System; - [System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] - [System.Diagnostics.DebuggerNonUserCodeAttribute()] - [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. 
+ // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] internal class NpgsqlNodaTimeStrings { - private static System.Resources.ResourceManager resourceMan; + private static global::System.Resources.ResourceManager resourceMan; - private static System.Globalization.CultureInfo resourceCulture; + private static global::System.Globalization.CultureInfo resourceCulture; - [System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] internal NpgsqlNodaTimeStrings() { } - [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] - internal static System.Resources.ResourceManager ResourceManager { + /// + /// Returns the cached ResourceManager instance used by this class. 
+ /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { get { - if (object.Equals(null, resourceMan)) { - System.Resources.ResourceManager temp = new System.Resources.ResourceManager("Npgsql.NodaTime.Properties.NpgsqlNodaTimeStrings", typeof(NpgsqlNodaTimeStrings).Assembly); + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Npgsql.NodaTime.Properties.NpgsqlNodaTimeStrings", typeof(NpgsqlNodaTimeStrings).Assembly); resourceMan = temp; } return resourceMan; } } - [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] - internal static System.Globalization.CultureInfo Culture { + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { get { return resourceCulture; } @@ -45,16 +59,31 @@ internal static System.Globalization.CultureInfo Culture { } } + /// + /// Looks up a localized string similar to Cannot read infinity value since Npgsql.DisableDateTimeInfinityConversions is enabled.. + /// internal static string CannotReadInfinityValue { get { return ResourceManager.GetString("CannotReadInfinityValue", resourceCulture); } } + /// + /// Looks up a localized string similar to Cannot read PostgreSQL interval with non-zero months to NodaTime Duration. Try reading as a NodaTime Period instead.. 
+ /// internal static string CannotReadIntervalWithMonthsAsDuration { get { return ResourceManager.GetString("CannotReadIntervalWithMonthsAsDuration", resourceCulture); } } + + /// + /// Looks up a localized string similar to Cannot write NodaTime's Period because it's out of range for the PG interval type.. + /// + internal static string CannotWritePeriodDueToOverflow { + get { + return ResourceManager.GetString("CannotWritePeriodDueToOverflow", resourceCulture); + } + } } } diff --git a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx index d3329f2a80..f0090afb83 100644 --- a/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx +++ b/src/Npgsql.NodaTime/Properties/NpgsqlNodaTimeStrings.resx @@ -24,4 +24,7 @@ Cannot read PostgreSQL interval with non-zero months to NodaTime Duration. Try reading as a NodaTime Period instead. + + Cannot write NodaTime's Period because it's out of range for the PG interval type. + diff --git a/src/Npgsql.NodaTime/PublicAPI.Unshipped.txt b/src/Npgsql.NodaTime/PublicAPI.Unshipped.txt index ab058de62d..f1ab4e3c0c 100644 --- a/src/Npgsql.NodaTime/PublicAPI.Unshipped.txt +++ b/src/Npgsql.NodaTime/PublicAPI.Unshipped.txt @@ -1 +1,2 @@ #nullable enable +static Npgsql.NpgsqlNodaTimeExtensions.UseNodaTime(this TMapper mapper) -> TMapper diff --git a/src/Npgsql.OpenTelemetry/MeterProviderBuilderExtensions.cs b/src/Npgsql.OpenTelemetry/MeterProviderBuilderExtensions.cs new file mode 100644 index 0000000000..90f81c4cc3 --- /dev/null +++ b/src/Npgsql.OpenTelemetry/MeterProviderBuilderExtensions.cs @@ -0,0 +1,19 @@ +using System; +using OpenTelemetry.Metrics; + +// ReSharper disable once CheckNamespace +namespace Npgsql; + +/// +/// Extension method for setting up Npgsql OpenTelemetry metrics. +/// +public static class MeterProviderBuilderExtensions +{ + /// + /// Subscribes to the Npgsql metrics reporter to enable OpenTelemetry metrics. 
+ /// + public static MeterProviderBuilder AddNpgsqlInstrumentation( + this MeterProviderBuilder builder, + Action? options = null) + => builder.AddMeter("Npgsql"); +} diff --git a/src/Npgsql.OpenTelemetry/Npgsql.OpenTelemetry.csproj b/src/Npgsql.OpenTelemetry/Npgsql.OpenTelemetry.csproj index 7aff759251..18592f8a5f 100644 --- a/src/Npgsql.OpenTelemetry/Npgsql.OpenTelemetry.csproj +++ b/src/Npgsql.OpenTelemetry/Npgsql.OpenTelemetry.csproj @@ -2,8 +2,7 @@ Shay Rojansky - net6.0 - net8.0 + net10.0 npgsql;postgresql;postgres;ado;ado.net;database;sql;opentelemetry;tracing;diagnostics;instrumentation README.md diff --git a/src/Npgsql.OpenTelemetry/TracerProviderBuilderExtensions.cs b/src/Npgsql.OpenTelemetry/TracerProviderBuilderExtensions.cs index 0c34138278..1568d2d080 100644 --- a/src/Npgsql.OpenTelemetry/TracerProviderBuilderExtensions.cs +++ b/src/Npgsql.OpenTelemetry/TracerProviderBuilderExtensions.cs @@ -12,8 +12,6 @@ public static class TracerProviderBuilderExtensions /// /// Subscribes to the Npgsql activity source to enable OpenTelemetry tracing. /// - public static TracerProviderBuilder AddNpgsql( - this TracerProviderBuilder builder, - Action? 
options = null) + public static TracerProviderBuilder AddNpgsql(this TracerProviderBuilder builder) => builder.AddSource("Npgsql"); -} \ No newline at end of file +} diff --git a/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj b/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj index bc0f37e9bb..4f5c1eb42d 100644 --- a/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj +++ b/src/Npgsql.SourceGenerators/Npgsql.SourceGenerators.csproj @@ -2,6 +2,7 @@ netstandard2.0 + false 1591 true diff --git a/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs b/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs index 665789e74e..c7c7228321 100644 --- a/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs +++ b/src/Npgsql.SourceGenerators/NpgsqlConnectionStringBuilderSourceGenerator.cs @@ -9,7 +9,7 @@ namespace Npgsql.SourceGenerators; [Generator] -public class NpgsqlConnectionStringBuilderSourceGenerator : ISourceGenerator +public class NpgsqlConnectionStringBuilderSourceGenerator : IIncrementalGenerator { static readonly DiagnosticDescriptor InternalError = new DiagnosticDescriptor( id: "PGXXXX", @@ -19,106 +19,107 @@ public class NpgsqlConnectionStringBuilderSourceGenerator : ISourceGenerator DiagnosticSeverity.Error, isEnabledByDefault: true); - public void Initialize(GeneratorInitializationContext context) {} - - public void Execute(GeneratorExecutionContext context) + public void Initialize(IncrementalGeneratorInitializationContext context) { - if (context.Compilation.Assembly.GetTypeByMetadataName("Npgsql.NpgsqlConnectionStringBuilder") is not { } type) - return; - - if (context.Compilation.Assembly.GetTypeByMetadataName("Npgsql.NpgsqlConnectionStringPropertyAttribute") is not - { } connectionStringPropertyAttribute) - { - context.ReportDiagnostic(Diagnostic.Create( - InternalError, - location: null, - "Could not find Npgsql.NpgsqlConnectionStringPropertyAttribute")); - 
return; - } - - var obsoleteAttribute = context.Compilation.GetTypeByMetadataName("System.ObsoleteAttribute"); - var displayNameAttribute = context.Compilation.GetTypeByMetadataName("System.ComponentModel.DisplayNameAttribute"); - var defaultValueAttribute = context.Compilation.GetTypeByMetadataName("System.ComponentModel.DefaultValueAttribute"); - - if (obsoleteAttribute is null || displayNameAttribute is null || defaultValueAttribute is null) + var compilationProvider = context.CompilationProvider; + context.RegisterSourceOutput(compilationProvider, (spc, compilation) => { - context.ReportDiagnostic(Diagnostic.Create( - InternalError, - location: null, - "Could not find ObsoleteAttribute, DisplayNameAttribute or DefaultValueAttribute")); - return; - } - - var properties = new List(); - var propertiesByKeyword = new Dictionary(); - foreach (var member in type.GetMembers()) - { - if (member is not IPropertySymbol property || - property.GetAttributes().FirstOrDefault(a => connectionStringPropertyAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)) is not { } propertyAttribute || - property.GetAttributes() - .FirstOrDefault(a => displayNameAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)) - ?.ConstructorArguments[0].Value is not string displayName) + var type = compilation.Assembly.GetTypeByMetadataName("Npgsql.NpgsqlConnectionStringBuilder"); + if (type is null) + return; + + var connectionStringPropertyAttribute = compilation.Assembly.GetTypeByMetadataName("Npgsql.NpgsqlConnectionStringPropertyAttribute"); + if (connectionStringPropertyAttribute is null) { - continue; + spc.ReportDiagnostic(Diagnostic.Create( + InternalError, + location: null, + "Could not find Npgsql.NpgsqlConnectionStringPropertyAttribute")); + return; } - var explicitDefaultValue = property.GetAttributes() - .FirstOrDefault(a => defaultValueAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)) - ?.ConstructorArguments[0].Value; - - if 
(explicitDefaultValue is string s) - explicitDefaultValue = '"' + s.Replace("\"", "\"\"") + '"'; + var obsoleteAttribute = compilation.GetTypeByMetadataName("System.ObsoleteAttribute"); + var displayNameAttribute = compilation.GetTypeByMetadataName("System.ComponentModel.DisplayNameAttribute"); + var defaultValueAttribute = compilation.GetTypeByMetadataName("System.ComponentModel.DefaultValueAttribute"); - if (explicitDefaultValue is not null && property.Type.TypeKind == TypeKind.Enum) + if (obsoleteAttribute is null || displayNameAttribute is null || defaultValueAttribute is null) { - explicitDefaultValue = $"({property.Type.Name}){explicitDefaultValue}"; - // var foo = property.Type.Name; - // explicitDefaultValue += $"/* {foo} */"; + spc.ReportDiagnostic(Diagnostic.Create( + InternalError, + location: null, + "Could not find ObsoleteAttribute, DisplayNameAttribute or DefaultValueAttribute")); + return; } - var propertyDetails = new PropertyDetails + var properties = new List(); + var propertiesByKeyword = new Dictionary(); + foreach (var member in type.GetMembers()) { - Name = property.Name, - CanonicalName = displayName, - TypeName = property.Type.Name, - IsEnum = property.Type.TypeKind == TypeKind.Enum, - IsObsolete = property.GetAttributes().Any(a => obsoleteAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)), - DefaultValue = explicitDefaultValue - }; + if (member is not IPropertySymbol property || + property.GetAttributes().FirstOrDefault(a => connectionStringPropertyAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)) is not { } propertyAttribute || + property.GetAttributes() + .FirstOrDefault(a => displayNameAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)) + ?.ConstructorArguments[0].Value is not string displayName) + { + continue; + } - properties.Add(propertyDetails); + var explicitDefaultValue = property.GetAttributes() + .FirstOrDefault(a => defaultValueAttribute.Equals(a.AttributeClass, 
SymbolEqualityComparer.Default)) + ?.ConstructorArguments[0].Value; - propertiesByKeyword[displayName.ToUpperInvariant()] = propertyDetails; - if (property.Name != displayName) - { - var propertyName = property.Name.ToUpperInvariant(); - if (!propertiesByKeyword.ContainsKey(propertyName)) - propertyDetails.Alternatives.Add(propertyName); - } + if (explicitDefaultValue is string s) + explicitDefaultValue = '"' + s.Replace("\"", "\"\"") + '"'; - if (propertyAttribute.ConstructorArguments.Length == 1) - { - foreach (var synonymArg in propertyAttribute.ConstructorArguments[0].Values) + if (explicitDefaultValue is not null && property.Type.TypeKind == TypeKind.Enum) { - if (synonymArg.Value is string synonym) + explicitDefaultValue = $"({property.Type.Name}){explicitDefaultValue}"; + } + + var propertyDetails = new PropertyDetails + { + Name = property.Name, + CanonicalName = displayName, + TypeName = property.Type.Name, + IsEnum = property.Type.TypeKind == TypeKind.Enum, + IsObsolete = property.GetAttributes().Any(a => obsoleteAttribute.Equals(a.AttributeClass, SymbolEqualityComparer.Default)), + DefaultValue = explicitDefaultValue + }; + + properties.Add(propertyDetails); + + propertiesByKeyword[displayName.ToUpperInvariant()] = propertyDetails; + if (property.Name != displayName) + { + var propertyName = property.Name.ToUpperInvariant(); + if (!propertiesByKeyword.ContainsKey(propertyName)) + propertyDetails.Alternatives.Add(propertyName); + } + + if (propertyAttribute.ConstructorArguments.Length == 1) + { + foreach (var synonymArg in propertyAttribute.ConstructorArguments[0].Values) { - var synonymName = synonym.ToUpperInvariant(); - if (!propertiesByKeyword.ContainsKey(synonymName)) - propertyDetails.Alternatives.Add(synonymName); + if (synonymArg.Value is string synonym) + { + var synonymName = synonym.ToUpperInvariant(); + if (!propertiesByKeyword.ContainsKey(synonymName)) + propertyDetails.Alternatives.Add(synonymName); + } } } } - } - var template = 
Template.Parse(EmbeddedResource.GetContent("NpgsqlConnectionStringBuilder.snbtxt"), "NpgsqlConnectionStringBuilder.snbtxt"); + var template = Template.Parse(EmbeddedResource.GetContent("NpgsqlConnectionStringBuilder.snbtxt"), "NpgsqlConnectionStringBuilder.snbtxt"); - var output = template.Render(new - { - Properties = properties, - PropertiesByKeyword = propertiesByKeyword - }); + var output = template.Render(new + { + Properties = properties, + PropertiesByKeyword = propertiesByKeyword + }); - context.AddSource(type.Name + ".Generated.cs", SourceText.From(output, Encoding.UTF8)); + spc.AddSource(type.Name + ".Generated.cs", SourceText.From(output, Encoding.UTF8)); + }); } sealed class PropertyDetails diff --git a/src/Npgsql/BackendMessages/AuthenticationMessages.cs b/src/Npgsql/BackendMessages/AuthenticationMessages.cs index b6320e87b8..c52da80d33 100644 --- a/src/Npgsql/BackendMessages/AuthenticationMessages.cs +++ b/src/Npgsql/BackendMessages/AuthenticationMessages.cs @@ -13,23 +13,15 @@ abstract class AuthenticationRequestMessage : IBackendMessage sealed class AuthenticationOkMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationOk; + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.Ok; internal static readonly AuthenticationOkMessage Instance = new(); AuthenticationOkMessage() { } } -sealed class AuthenticationKerberosV5Message : AuthenticationRequestMessage -{ - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationKerberosV5; - - internal static readonly AuthenticationKerberosV5Message Instance = new(); - AuthenticationKerberosV5Message() { } -} - sealed class AuthenticationCleartextPasswordMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationCleartextPassword; + internal override 
AuthenticationRequestType AuthRequestType => AuthenticationRequestType.CleartextPassword; internal static readonly AuthenticationCleartextPasswordMessage Instance = new(); AuthenticationCleartextPasswordMessage() { } @@ -37,7 +29,7 @@ sealed class AuthenticationCleartextPasswordMessage : AuthenticationRequestMess sealed class AuthenticationMD5PasswordMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationMD5Password; + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.MD5Password; internal byte[] Salt { get; } @@ -49,22 +41,12 @@ internal static AuthenticationMD5PasswordMessage Load(NpgsqlReadBuffer buf) } AuthenticationMD5PasswordMessage(byte[] salt) - { - Salt = salt; - } -} - -sealed class AuthenticationSCMCredentialMessage : AuthenticationRequestMessage -{ - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSCMCredential; - - internal static readonly AuthenticationSCMCredentialMessage Instance = new(); - AuthenticationSCMCredentialMessage() { } + => Salt = salt; } sealed class AuthenticationGSSMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationGSS; + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.GSS; internal static readonly AuthenticationGSSMessage Instance = new(); AuthenticationGSSMessage() { } @@ -72,7 +54,7 @@ sealed class AuthenticationGSSMessage : AuthenticationRequestMessage sealed class AuthenticationGSSContinueMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationGSSContinue; + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.GSSContinue; internal byte[] AuthenticationData { get; } @@ -85,14 +67,12 @@ 
internal static AuthenticationGSSContinueMessage Load(NpgsqlReadBuffer buf, int } AuthenticationGSSContinueMessage(byte[] authenticationData) - { - AuthenticationData = authenticationData; - } + => AuthenticationData = authenticationData; } sealed class AuthenticationSSPIMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSSPI; + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.SSPI; internal static readonly AuthenticationSSPIMessage Instance = new(); AuthenticationSSPIMessage() { } @@ -102,8 +82,8 @@ sealed class AuthenticationSSPIMessage : AuthenticationRequestMessage sealed class AuthenticationSASLMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASL; - internal List Mechanisms { get; } = new(); + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.SASL; + internal List Mechanisms { get; } = []; internal AuthenticationSASLMessage(NpgsqlReadBuffer buf) { @@ -117,7 +97,7 @@ internal AuthenticationSASLMessage(NpgsqlReadBuffer buf) sealed class AuthenticationSASLContinueMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASLContinue; + internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.SASLContinue; internal byte[] Payload { get; } internal AuthenticationSASLContinueMessage(NpgsqlReadBuffer buf, int len) @@ -171,7 +151,7 @@ internal static AuthenticationSCRAMServerFirstMessage Load(byte[] bytes, ILogger sealed class AuthenticationSASLFinalMessage : AuthenticationRequestMessage { - internal override AuthenticationRequestType AuthRequestType => AuthenticationRequestType.AuthenticationSASLFinal; + internal override AuthenticationRequestType AuthRequestType => 
AuthenticationRequestType.SASLFinal; internal byte[] Payload { get; } internal AuthenticationSASLFinalMessage(NpgsqlReadBuffer buf, int len) @@ -210,20 +190,15 @@ internal AuthenticationSCRAMServerFinalMessage(string serverSignature) #endregion SASL -// TODO: Remove Authentication prefix from everything enum AuthenticationRequestType { - AuthenticationOk = 0, - AuthenticationKerberosV4 = 1, - AuthenticationKerberosV5 = 2, - AuthenticationCleartextPassword = 3, - AuthenticationCryptPassword = 4, - AuthenticationMD5Password = 5, - AuthenticationSCMCredential = 6, - AuthenticationGSS = 7, - AuthenticationGSSContinue = 8, - AuthenticationSSPI = 9, - AuthenticationSASL = 10, - AuthenticationSASLContinue = 11, - AuthenticationSASLFinal = 12 + Ok = 0, + CleartextPassword = 3, + MD5Password = 5, + GSS = 7, + GSSContinue = 8, + SSPI = 9, + SASL = 10, + SASLContinue = 11, + SASLFinal = 12 } diff --git a/src/Npgsql/BackendMessages/CopyMessages.cs b/src/Npgsql/BackendMessages/CopyMessages.cs index 1aa8aec0c2..e7d4d6935c 100644 --- a/src/Npgsql/BackendMessages/CopyMessages.cs +++ b/src/Npgsql/BackendMessages/CopyMessages.cs @@ -13,9 +13,7 @@ abstract class CopyResponseMessageBase : IBackendMessage internal List ColumnFormatCodes { get; } internal CopyResponseMessageBase() - { - ColumnFormatCodes = new List(); - } + => ColumnFormatCodes = []; internal void Load(NpgsqlReadBuffer buf) { diff --git a/src/Npgsql/BackendMessages/ParameterDescriptionMessage.cs b/src/Npgsql/BackendMessages/ParameterDescriptionMessage.cs index ebda485331..16c4687da5 100644 --- a/src/Npgsql/BackendMessages/ParameterDescriptionMessage.cs +++ b/src/Npgsql/BackendMessages/ParameterDescriptionMessage.cs @@ -9,9 +9,7 @@ sealed class ParameterDescriptionMessage : IBackendMessage internal List TypeOIDs { get; } internal ParameterDescriptionMessage() - { - TypeOIDs = new List(); - } + => TypeOIDs = []; internal ParameterDescriptionMessage Load(NpgsqlReadBuffer buf) { @@ -23,4 +21,4 @@ internal 
ParameterDescriptionMessage Load(NpgsqlReadBuffer buf) } public BackendMessageCode Code => BackendMessageCode.ParameterDescription; -} \ No newline at end of file +} diff --git a/src/Npgsql/BackendMessages/RowDescriptionMessage.cs b/src/Npgsql/BackendMessages/RowDescriptionMessage.cs index 1dd1045e21..fd04ddfdaf 100644 --- a/src/Npgsql/BackendMessages/RowDescriptionMessage.cs +++ b/src/Npgsql/BackendMessages/RowDescriptionMessage.cs @@ -11,18 +11,11 @@ namespace Npgsql.BackendMessages; -readonly struct ColumnInfo +readonly struct ColumnInfo(PgConverterInfo converterInfo, DataFormat dataFormat, bool asObject) { - public ColumnInfo(PgConverterInfo converterInfo, DataFormat dataFormat, bool asObject) - { - ConverterInfo = converterInfo; - DataFormat = dataFormat; - AsObject = asObject; - } - - public PgConverterInfo ConverterInfo { get; } - public DataFormat DataFormat { get; } - public bool AsObject { get; } + public PgConverterInfo ConverterInfo { get; } = converterInfo; + public DataFormat DataFormat { get; } = dataFormat; + public bool AsObject { get; } = asObject; } /// @@ -126,15 +119,19 @@ internal static RowDescriptionMessage CreateForReplication( return msg; } - public FieldDescription this[int index] + public FieldDescription this[int ordinal] { [MethodImpl(MethodImplOptions.AggressiveInlining)] get { - Debug.Assert(index < Count); - Debug.Assert(_fields[index] != null); + if ((uint)ordinal < (uint)Count) + { + Debug.Assert(_fields[ordinal] != null); + return _fields[ordinal]!; + } - return _fields[index]!; + ThrowHelper.ThrowIndexOutOfRangeException("Ordinal must be between 0 and " + (Count - 1)); + return default!; } } @@ -235,7 +232,7 @@ internal FieldDescription(FieldDescription source) DataFormat = source.DataFormat; PostgresType = source.PostgresType; Field = source.Field; - _objectOrDefaultInfo = source._objectOrDefaultInfo; + _objectInfo = source._objectInfo; } internal void Populate( @@ -253,7 +250,7 @@ internal void Populate( DataFormat = 
dataFormat; PostgresType = _serializerOptions.DatabaseInfo.FindPostgresType((Oid)TypeOID)?.GetRepresentationalType() ?? UnknownBackendType.Instance; Field = new(Name, _serializerOptions.ToCanonicalTypeId(PostgresType), TypeModifier); - _objectOrDefaultInfo = default; + _objectInfo = default; } /// @@ -299,18 +296,18 @@ internal void Populate( internal PostgresType PostgresType { get; private set; } - internal Type FieldType => ObjectOrDefaultInfo.TypeToConvert; + internal Type FieldType => ObjectInfo.TypeToConvert; - ColumnInfo _objectOrDefaultInfo; - internal PgConverterInfo ObjectOrDefaultInfo + ColumnInfo _objectInfo; + internal PgConverterInfo ObjectInfo { get { - if (!_objectOrDefaultInfo.ConverterInfo.IsDefault) - return _objectOrDefaultInfo.ConverterInfo; + if (!_objectInfo.ConverterInfo.IsDefault) + return _objectInfo.ConverterInfo; - ref var info = ref _objectOrDefaultInfo; - GetInfo(null, ref _objectOrDefaultInfo); + ref var info = ref _objectInfo; + GetInfoCore(null, ref _objectInfo); return info.ConverterInfo; } } @@ -323,29 +320,33 @@ internal FieldDescription Clone() return field; } - internal void GetInfo(Type? type, ref ColumnInfo lastColumnInfo) + internal void GetInfo(Type type, ref ColumnInfo lastColumnInfo) => GetInfoCore(type, ref lastColumnInfo); + void GetInfoCore(Type? 
type, ref ColumnInfo lastColumnInfo) { Debug.Assert(lastColumnInfo.ConverterInfo.IsDefault || ( - ReferenceEquals(_serializerOptions, lastColumnInfo.ConverterInfo.TypeInfo.Options) && - lastColumnInfo.ConverterInfo.TypeInfo.PgTypeId == _serializerOptions.ToCanonicalTypeId(PostgresType)), "Cache is bleeding over"); + ReferenceEquals(_serializerOptions, lastColumnInfo.ConverterInfo.TypeInfo.Options) && ( + IsUnknownResultType() && lastColumnInfo.ConverterInfo.TypeInfo.PgTypeId == _serializerOptions.TextPgTypeId || + // Normal resolution + lastColumnInfo.ConverterInfo.TypeInfo.PgTypeId == _serializerOptions.ToCanonicalTypeId(PostgresType)) + ), "Cache is bleeding over"); if (!lastColumnInfo.ConverterInfo.IsDefault && lastColumnInfo.ConverterInfo.TypeToConvert == type) return; - var odfInfo = DataFormat is DataFormat.Text && type is not null ? ObjectOrDefaultInfo : _objectOrDefaultInfo.ConverterInfo; - if (odfInfo is { IsDefault: false }) + var objectInfo = DataFormat is DataFormat.Text && type is not null ? ObjectInfo : _objectInfo.ConverterInfo; + if (objectInfo is { IsDefault: false }) { if (typeof(object) == type) { - lastColumnInfo = new(odfInfo, DataFormat, true); + lastColumnInfo = new(objectInfo, DataFormat, true); return; } - if (odfInfo.TypeToConvert == type) + if (objectInfo.TypeToConvert == type) { // As TypeInfoMappingCollection is always adding object mappings for // default/datatypename mappings, we'll also check Converter.TypeToConvert. // If we have an exact match we are still able to use e.g. a converter for ints in an unboxed fashion. - lastColumnInfo = new(odfInfo, DataFormat, odfInfo.IsBoxingConverter && odfInfo.Converter.TypeToConvert != type); + lastColumnInfo = new(objectInfo, DataFormat, objectInfo.IsBoxingConverter && objectInfo.Converter.TypeToConvert != type); return; } } @@ -355,33 +356,48 @@ internal void GetInfo(Type? type, ref ColumnInfo lastColumnInfo) [MethodImpl(MethodImplOptions.NoInlining)] void GetInfoSlow(Type? 
type, out ColumnInfo lastColumnInfo) { - var typeInfo = AdoSerializerHelpers.GetTypeInfoForReading(type ?? typeof(object), PostgresType, _serializerOptions); PgConverterInfo converterInfo; switch (DataFormat) { - case DataFormat.Binary: - // If we don't support binary we'll just throw. + case DataFormat.Text when IsUnknownResultType(): + { + // Try to resolve some 'pg_catalog.text' type info for the expected clr type. + var typeInfo = AdoSerializerHelpers.GetTypeInfoForReading(type ?? typeof(string), _serializerOptions.TextPgTypeId, _serializerOptions); + + // We start binding to DataFormat.Binary as it's the broadest supported format. + // The format however is irrelevant as 'pg_catalog.text' data is identical across either. + // Given we did a resolution against 'pg_catalog.text' and not the actual field type we're in reinterpretation territory anyway. + if (!typeInfo.TryBind(Field, DataFormat.Binary, out converterInfo)) + converterInfo = typeInfo.Bind(Field, DataFormat.Text); + + lastColumnInfo = new(converterInfo, DataFormat, type != converterInfo.TypeToConvert || converterInfo.IsBoxingConverter); + + break; + } + case DataFormat.Binary or DataFormat.Text: + { + var typeInfo = AdoSerializerHelpers.GetTypeInfoForReading(type ?? typeof(object), _serializerOptions.ToCanonicalTypeId(PostgresType), _serializerOptions); + + // If we don't support the DataFormat we'll just throw. converterInfo = typeInfo.Bind(Field, DataFormat); - lastColumnInfo = new(converterInfo, DataFormat.Binary, typeof(object) == type || converterInfo.IsBoxingConverter); + lastColumnInfo = new(converterInfo, DataFormat, typeof(object) == type || converterInfo.IsBoxingConverter); break; + } default: - // For text we'll fall back to any available text converter for the expected clr type or throw. - if (!typeInfo.TryBind(Field, DataFormat, out converterInfo)) - { - typeInfo = AdoSerializerHelpers.GetTypeInfoForReading(type ?? 
typeof(string), _serializerOptions.TextPgType, _serializerOptions); - converterInfo = typeInfo.Bind(Field, DataFormat); - lastColumnInfo = new(converterInfo, DataFormat, type != converterInfo.TypeToConvert || converterInfo.IsBoxingConverter); - } - else - lastColumnInfo = new(converterInfo, DataFormat, typeof(object) == type || converterInfo.IsBoxingConverter); + ThrowHelper.ThrowUnreachableException("Unknown data format {0}", DataFormat); + lastColumnInfo = default; break; } // We delay initializing ObjectOrDefaultInfo until after the first lookup (unless it is itself the first lookup). // When passed in an unsupported type it allows the error to be more specific, instead of just having object/null to deal with. - if (_objectOrDefaultInfo.ConverterInfo.IsDefault && type is not null) - _ = ObjectOrDefaultInfo; + if (_objectInfo.ConverterInfo.IsDefault && type is not null) + _ = ObjectInfo; } + + // DataFormat.Text today exclusively signals that we executed with an UnknownResultTypeList. + // If we ever want to fully support DataFormat.Text we'll need to flow UnknownResultType status separately. + bool IsUnknownResultType() => DataFormat is DataFormat.Text; } /// diff --git a/src/Npgsql/Internal/AdoSerializerHelpers.cs b/src/Npgsql/Internal/AdoSerializerHelpers.cs index d0ea19c7a8..21010b3f99 100644 --- a/src/Npgsql/Internal/AdoSerializerHelpers.cs +++ b/src/Npgsql/Internal/AdoSerializerHelpers.cs @@ -2,31 +2,36 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using Npgsql.Internal.Postgres; -using Npgsql.PostgresTypes; using NpgsqlTypes; namespace Npgsql.Internal; static class AdoSerializerHelpers { - public static PgTypeInfo GetTypeInfoForReading(Type type, PostgresType postgresType, PgSerializerOptions options) + public static PgTypeInfo GetTypeInfoForReading(Type type, PgTypeId pgTypeId, PgSerializerOptions options) { PgTypeInfo? typeInfo = null; Exception? inner = null; try { - typeInfo = type == typeof(object) ? 
options.GetObjectOrDefaultTypeInfo(postgresType) : options.GetTypeInfo(type, postgresType); + typeInfo = options.GetTypeInfoInternal(type, pgTypeId); + if (typeInfo is { SupportsReading: false }) + typeInfo = null; } catch (Exception ex) { inner = ex; } - return typeInfo ?? ThrowReadingNotSupported(type, postgresType.DisplayName, inner); + return typeInfo ?? ThrowReadingNotSupported(type, options, pgTypeId, inner); // InvalidCastException thrown to align with ADO.NET convention. [DoesNotReturn] - static PgTypeInfo ThrowReadingNotSupported(Type? type, string displayName, Exception? inner = null) - => throw new InvalidCastException($"Reading{(type is null ? "" : $" as '{type.FullName}'")} is not supported for fields having DataTypeName '{displayName}'", inner); + static PgTypeInfo ThrowReadingNotSupported(Type? type, PgSerializerOptions options, PgTypeId pgTypeId, Exception? inner = null) + { + throw new InvalidCastException( + $"Reading{(type is null ? "" : $" as '{type.FullName}'")} is not supported for fields having DataTypeName '{options.DatabaseInfo.FindPostgresType(pgTypeId)?.DisplayName ?? "unknown"}'", + inner); + } } public static PgTypeInfo GetTypeInfoForWriting(Type? type, PgTypeId? pgTypeId, PgSerializerOptions options, NpgsqlDbType? npgsqlDbType = null) @@ -37,7 +42,9 @@ public static PgTypeInfo GetTypeInfoForWriting(Type? type, PgTypeId? pgTypeId, P Exception? inner = null; try { - typeInfo = type is null ? 
options.GetDefaultTypeInfo(pgTypeId!.Value) : options.GetTypeInfo(type, pgTypeId); + typeInfo = options.GetTypeInfoInternal(type, pgTypeId); + if (typeInfo is { SupportsWriting: false }) + typeInfo = null; } catch (Exception ex) { diff --git a/src/Npgsql/Internal/BufferRequirements.cs b/src/Npgsql/Internal/BufferRequirements.cs index cd32c0cbd1..14ffabc52b 100644 --- a/src/Npgsql/Internal/BufferRequirements.cs +++ b/src/Npgsql/Internal/BufferRequirements.cs @@ -1,7 +1,9 @@ using System; +using System.Diagnostics.CodeAnalysis; namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public readonly struct BufferRequirements : IEquatable { readonly Size _read; diff --git a/src/Npgsql/Internal/ChainDbTypeResolver.cs b/src/Npgsql/Internal/ChainDbTypeResolver.cs new file mode 100644 index 0000000000..16f3c229ee --- /dev/null +++ b/src/Npgsql/Internal/ChainDbTypeResolver.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Data; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +sealed class ChainDbTypeResolver(IEnumerable resolvers) : IDbTypeResolver +{ + readonly IDbTypeResolver[] _resolvers = new List(resolvers).ToArray(); + + public string? GetDataTypeName(DbType dbType, Type? type) + { + foreach (var resolver in _resolvers) + { + if (resolver.GetDataTypeName(dbType, type) is { } dataTypeName) + return dataTypeName; + } + + return null; + } + + public DbType? 
GetDbType(DataTypeName dataTypeName) + { + foreach (var resolver in _resolvers) + { + if (resolver.GetDbType(dataTypeName) is { } dbType) + return dbType; + } + + return null; + } +} diff --git a/src/Npgsql/Internal/ChainTypeInfoResolver.cs b/src/Npgsql/Internal/ChainTypeInfoResolver.cs index 18c39d80b6..4c7f56e454 100644 --- a/src/Npgsql/Internal/ChainTypeInfoResolver.cs +++ b/src/Npgsql/Internal/ChainTypeInfoResolver.cs @@ -4,12 +4,9 @@ namespace Npgsql.Internal; -sealed class ChainTypeInfoResolver : IPgTypeInfoResolver +sealed class ChainTypeInfoResolver(IEnumerable resolvers) : IPgTypeInfoResolver { - readonly IPgTypeInfoResolver[] _resolvers; - - public ChainTypeInfoResolver(IEnumerable resolvers) - => _resolvers = new List(resolvers).ToArray(); + readonly IPgTypeInfoResolver[] _resolvers = new List(resolvers).ToArray(); public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) { diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs index c51c0dafa0..0917dfd834 100644 --- a/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeBuilder.cs @@ -1,18 +1,18 @@ using System; using System.Buffers; +using System.Collections.Generic; using Npgsql.Util; namespace Npgsql.Internal.Composites; -abstract class CompositeBuilder +abstract class CompositeBuilder(StrongBox[] tempBoxes, IReadOnlyList fields) { - protected StrongBox[] _tempBoxes; + protected readonly StrongBox[] _tempBoxes = tempBoxes; + protected readonly IReadOnlyList _fields = fields; protected int _currentField; - - protected CompositeBuilder(StrongBox[] tempBoxes) => _tempBoxes = tempBoxes; + protected object? 
_boxedInstance; protected abstract void Construct(); - protected abstract void SetField(TValue value); public void AddValue(TValue value) { @@ -32,78 +32,72 @@ public void AddValue(TValue value) } _currentField++; + + void SetField(TValue value) + { + if (_boxedInstance is null) + ThrowHelper.ThrowInvalidOperationException("Not constructed yet, or no more fields were expected."); + + var currentField = _currentField; + var fields = _fields; + if (currentField > fields.Count - 1) + ThrowHelper.ThrowIndexOutOfRangeException($"Cannot set field {value} at position {currentField} - all fields have already been set"); + + ((CompositeFieldInfo)fields[currentField]).Set(_boxedInstance, value); + } } } -sealed class CompositeBuilder : CompositeBuilder, IDisposable +sealed class CompositeBuilder(CompositeInfo compositeInfo) : CompositeBuilder(compositeInfo.CreateTempBoxes(), compositeInfo.Fields), IDisposable { - readonly CompositeInfo _compositeInfo; T _instance = default!; - object? _boxedInstance; - - public CompositeBuilder(CompositeInfo compositeInfo) - : base(compositeInfo.CreateTempBoxes()) - => _compositeInfo = compositeInfo; public T Complete() { - if (_currentField < _compositeInfo.Fields.Count) - throw new InvalidOperationException($"Missing values, expected: {_compositeInfo.Fields.Count} got: {_currentField}"); + if (_currentField < compositeInfo.Fields.Count) + throw new InvalidOperationException($"Missing values, expected: {compositeInfo.Fields.Count} got: {_currentField}"); return (T)(_boxedInstance ?? 
_instance!); } - public void Reset() - { - _instance = default!; - _boxedInstance = null; - _currentField = 0; - foreach (var box in _tempBoxes) - box.Clear(); - } - - public void Dispose() => Reset(); - protected override void Construct() { var tempBoxes = _tempBoxes; if (_currentField < tempBoxes.Length - 1) throw new InvalidOperationException($"Missing values, expected: {tempBoxes.Length} got: {_currentField + 1}"); - var fields = _compositeInfo.Fields; - var args = ArrayPool.Shared.Rent(_compositeInfo.ConstructorParameters); + var fields = compositeInfo.Fields; + var args = ArrayPool.Shared.Rent(compositeInfo.ConstructorParameters); for (var i = 0; i < tempBoxes.Length; i++) { var field = fields[i]; if (field.ConstructorParameterIndex is { } argIndex) args[argIndex] = tempBoxes[i]; } - _instance = _compositeInfo.Constructor(args)!; - ArrayPool.Shared.Return(args); + _instance = compositeInfo.Constructor(args)!; + ArrayPool.Shared.Return(args, clearArray: true); - if (tempBoxes.Length == _compositeInfo.Fields.Count) + if (tempBoxes.Length == compositeInfo.Fields.Count) return; // We're expecting or already have stored more fields, so box the instance once here. 
_boxedInstance = _instance; for (var i = 0; i < tempBoxes.Length; i++) { - var field = _compositeInfo.Fields[i]; + var field = compositeInfo.Fields[i]; if (field.ConstructorParameterIndex is null) field.Set(_boxedInstance, tempBoxes[i]); } } - protected override void SetField(TValue value) + public void Reset() { - if (_boxedInstance is null) - ThrowHelper.ThrowInvalidOperationException("Not constructed yet, or no more fields were expected."); - - var currentField = _currentField; - var fields = _compositeInfo.Fields; - if (currentField > fields.Count - 1) - ThrowHelper.ThrowIndexOutOfRangeException($"Cannot set field {value} at position {currentField} - all fields have already been set"); - - ((CompositeFieldInfo)fields[currentField]).Set(_boxedInstance, value); + _instance = default!; + _boxedInstance = null; + _currentField = 0; + foreach (var box in _tempBoxes) + box.Clear(); } + + public void Dispose() { } } diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs index a6cc79e4e9..080d31ea68 100644 --- a/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeFieldInfo.cs @@ -143,10 +143,12 @@ sealed class CompositeFieldInfo : CompositeFieldInfo _getter = getter; } + // Accessed through reflection (ReflectionCompositeInfoFactory) public CompositeFieldInfo(string name, PgTypeInfo typeInfo, PgTypeId nominalPgTypeId, Func getter, int parameterIndex) : this(name, typeInfo, nominalPgTypeId, getter) => _parameterIndex = parameterIndex; + // Accessed through reflection (ReflectionCompositeInfoFactory) public CompositeFieldInfo(string name, PgTypeInfo typeInfo, PgTypeId nominalPgTypeId, Func getter, Action setter) : this(name, typeInfo, nominalPgTypeId, getter) => _setter = setter; diff --git a/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs b/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs index 
1db91b2052..f1e291cf53 100644 --- a/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs +++ b/src/Npgsql/Internal/Composites/Metadata/CompositeInfo.cs @@ -12,35 +12,20 @@ sealed class CompositeInfo public CompositeInfo(CompositeFieldInfo[] fields, int constructorParameters, Func constructor) { _lastConstructorFieldIndex = -1; - for (var i = fields.Length - 1; i >= 0; i--) + var constructorFields = 0; + for (var i = 0; i < fields.Length; i++) + { if (fields[i].ConstructorParameterIndex is not null) { _lastConstructorFieldIndex = i; - break; + constructorFields++; } - - var parameterSum = 0; - for (var i = constructorParameters - 1; i > 0; i--) - parameterSum += i; - - var argumentsSum = 0; - if (parameterSum > 0) - { - foreach (var field in fields) - if (field.ConstructorParameterIndex is { } index) - argumentsSum += index; } - if (parameterSum != argumentsSum) + if (constructorParameters != constructorFields) throw new InvalidOperationException($"Missing composite fields to map to the required {constructorParameters} constructor parameters."); _fields = fields; - var arguments = constructorParameters is 0 ? Array.Empty() : new CompositeFieldInfo[constructorParameters]; - foreach (var field in fields) - { - if (field.ConstructorParameterIndex is { } index) - arguments[index] = field; - } Constructor = constructor; ConstructorParameters = constructorParameters; } @@ -56,12 +41,14 @@ public CompositeInfo(CompositeFieldInfo[] fields, int constructorParameters, Fun /// public StrongBox[] CreateTempBoxes() { - var valueCache = _lastConstructorFieldIndex + 1 is 0 ? 
Array.Empty() : new StrongBox[_lastConstructorFieldIndex + 1]; - var fields = _fields; + if (_lastConstructorFieldIndex is -1) + return []; - for (var i = 0; i < valueCache.Length; i++) - valueCache[i] = fields[i].CreateBox(); + var boxes = new StrongBox[_lastConstructorFieldIndex + 1]; + var fields = _fields; + for (var i = 0; i < boxes.Length; i++) + boxes[i] = fields[i].CreateBox(); - return valueCache; + return boxes; } } diff --git a/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs b/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs index d6c51b8344..c520c4fdf9 100644 --- a/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs +++ b/src/Npgsql/Internal/Composites/ReflectionCompositeInfoFactory.cs @@ -27,13 +27,14 @@ static class ReflectionCompositeInfoFactory throw new AmbiguousMatchException($"Property {propertyMap[duplicates[0]].Name} and field {fieldMap[duplicates[0]].Name} map to the same '{pgFields[duplicates[0]].Name}' composite field name."); var (constructorInfo, parameterFieldMap) = MapBestMatchingConstructor(pgFields, nameTranslator); - var constructorParameters = constructorInfo?.GetParameters() ?? Array.Empty(); + var constructorParameters = constructorInfo?.GetParameters() ?? 
[]; var compositeFields = new CompositeFieldInfo?[pgFields.Count]; for (var i = 0; i < parameterFieldMap.Length; i++) { var fieldIndex = parameterFieldMap[i]; var pgField = pgFields[fieldIndex]; var parameter = constructorParameters[i]; + var reprTypeId = options.ToCanonicalTypeId(pgField.Type.GetRepresentationalType()); PgTypeInfo pgTypeInfo; Delegate getter; if (propertyMap.TryGetValue(fieldIndex, out var property) && property.GetMethod is not null) @@ -41,7 +42,7 @@ static class ReflectionCompositeInfoFactory if (property.PropertyType != parameter.ParameterType) throw new InvalidOperationException($"Could not find a matching getter for constructor parameter {parameter.Name} and type {parameter.ParameterType} mapped to composite field {pgFields[fieldIndex].Name}."); - pgTypeInfo = options.GetTypeInfo(property.PropertyType, pgField.Type.GetRepresentationalType()) ?? throw NotSupportedField(pgType, pgField, isField: false, property.Name, property.PropertyType); + pgTypeInfo = options.GetTypeInfoInternal(property.PropertyType, reprTypeId) ?? throw NotSupportedField(pgType, pgField, isField: false, property.Name, property.PropertyType); getter = CreateGetter(property); } else if (fieldMap.TryGetValue(fieldIndex, out var field)) @@ -49,7 +50,7 @@ static class ReflectionCompositeInfoFactory if (field.FieldType != parameter.ParameterType) throw new InvalidOperationException($"Could not find a matching getter for constructor parameter {parameter.Name} and type {parameter.ParameterType} mapped to composite field {pgFields[fieldIndex].Name}."); - pgTypeInfo = options.GetTypeInfo(field.FieldType, pgField.Type.GetRepresentationalType()) ?? throw NotSupportedField(pgType, pgField, isField: true, field.Name, field.FieldType); + pgTypeInfo = options.GetTypeInfoInternal(field.FieldType, reprTypeId) ?? 
throw NotSupportedField(pgType, pgField, isField: true, field.Name, field.FieldType); getter = CreateGetter(field); } else @@ -65,19 +66,20 @@ static class ReflectionCompositeInfoFactory continue; var pgField = pgFields[fieldIndex]; + var reprTypeId = options.ToCanonicalTypeId(pgField.Type.GetRepresentationalType()); PgTypeInfo pgTypeInfo; Delegate getter; Delegate setter; if (propertyMap.TryGetValue(fieldIndex, out var property)) { - pgTypeInfo = options.GetTypeInfo(property.PropertyType, pgField.Type.GetRepresentationalType()) + pgTypeInfo = options.GetTypeInfoInternal(property.PropertyType, reprTypeId) ?? throw NotSupportedField(pgType, pgField, isField: false, property.Name, property.PropertyType); getter = CreateGetter(property); setter = CreateSetter(property); } else if (fieldMap.TryGetValue(fieldIndex, out var field)) { - pgTypeInfo = options.GetTypeInfo(field.FieldType, pgField.Type.GetRepresentationalType()) + pgTypeInfo = options.GetTypeInfoInternal(field.FieldType, reprTypeId) ?? throw NotSupportedField(pgType, pgField, isField: true, field.Name, field.FieldType); getter = CreateGetter(field); setter = CreateSetter(field); @@ -120,7 +122,7 @@ static Delegate CreateSetter(FieldInfo info) static Delegate CreateGetter(PropertyInfo info) { - var invalidOpExceptionMessageConstructor = typeof(InvalidOperationException).GetConstructor(new []{ typeof(string) })!; + var invalidOpExceptionMessageConstructor = typeof(InvalidOperationException).GetConstructor([typeof(string)])!; var instance = Expression.Parameter(typeof(object), "instance"); var body = info.GetMethod is null || !info.GetMethod.IsPublic ? 
(Expression)Expression.Throw(Expression.New(invalidOpExceptionMessageConstructor, @@ -137,7 +139,7 @@ static Delegate CreateSetter(PropertyInfo info) var instance = Expression.Parameter(typeof(object), "instance"); var value = Expression.Parameter(info.PropertyType, "value"); - var invalidOpExceptionMessageConstructor = typeof(InvalidOperationException).GetConstructor(new []{ typeof(string) })!; + var invalidOpExceptionMessageConstructor = typeof(InvalidOperationException).GetConstructor([typeof(string)])!; var body = info.SetMethod is null || !info.SetMethod.IsPublic ? (Expression)Expression.Throw(Expression.New(invalidOpExceptionMessageConstructor, Expression.Constant($"No (public) setter for '{info}' on type {typeof(T)}")), info.PropertyType) @@ -151,8 +153,8 @@ static Delegate CreateSetter(PropertyInfo info) static Expression UnboxAny(Expression expression, Type type) => type.IsValueType ? Expression.Unbox(expression, type) : Expression.Convert(expression, type, null); - [DynamicDependency("TypedValue", typeof(StrongBox<>))] - [DynamicDependency("Length", typeof(StrongBox[]))] + [DynamicDependency(nameof(StrongBox.TypedValue), typeof(StrongBox<>))] + [DynamicDependency(DynamicallyAccessedMemberTypes.PublicProperties, typeof(StrongBox[]))] [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "DynamicDependencies in place for the System.Linq.Expression.Property calls")] static Func CreateStrongBoxConstructor(ConstructorInfo constructorInfo) { @@ -160,12 +162,12 @@ static Func CreateStrongBoxConstructor(ConstructorInfo constr var parameters = constructorInfo.GetParameters(); var parameterCount = Expression.Constant(parameters.Length); - var argumentExceptionNameMessageConstructor = typeof(ArgumentException).GetConstructor(new []{ typeof(string), typeof(string) })!; + var argumentExceptionNameMessageConstructor = typeof(ArgumentException).GetConstructor([typeof(string), typeof(string)])!; return Expression .Lambda>( Expression.Block( 
Expression.IfThen( - Expression.LessThan(Expression.Property(values, "Length"), parameterCount), + Expression.LessThan(Expression.Property(values, nameof(Array.Length)), parameterCount), Expression.Throw(Expression.New(argumentExceptionNameMessageConstructor, Expression.Constant("Passed fewer arguments than there are constructor parameters."), Expression.Constant(values.Name))) @@ -176,7 +178,7 @@ static Func CreateStrongBoxConstructor(ConstructorInfo constr Expression.ArrayIndex(values, Expression.Constant(i)), typeof(StrongBox<>).MakeGenericType(parameter.ParameterType) ), - "TypedValue" + nameof(StrongBox.TypedValue) ) )) ), values) diff --git a/src/Npgsql/Internal/Converters/ArrayConverter.cs b/src/Npgsql/Internal/Converters/ArrayConverter.cs index 5c2ff9133f..2d6d443329 100644 --- a/src/Npgsql/Internal/Converters/ArrayConverter.cs +++ b/src/Npgsql/Internal/Converters/ArrayConverter.cs @@ -10,71 +10,100 @@ namespace Npgsql.Internal.Converters; +struct Indices +{ + // Public field to be able to return it by ref in GetItem. + public int One; + public int[]? Many { get; private init; } + public int Count { get; private init; } + + public static Indices Create(int dimensions) + => dimensions switch + { + 0 => new() { Count = dimensions, One = -1 }, + 1 => new() { Count = dimensions }, + _ => new() { Count = dimensions, Many = new int[dimensions] } + }; +} + +static class IndicesExtensions +{ + // Workaround for lack of ref returns on struct fields. 
+ public static ref int GetItem(this ref Indices indices, int index) + { + switch (indices.Count) + { + case 0: + ThrowHelper.ThrowIndexOutOfRangeException("Cannot index into a 0-dimensional array."); + return ref Unsafe.NullRef(); + case 1: + Debug.Assert(index is 0); + Debug.Assert(indices.Many is null); + return ref indices.One; + default: + return ref indices.Many![index]; + } + } +} + interface IElementOperations { - object CreateCollection(int[] lengths); + object CreateCollection(ReadOnlySpan lengths); int GetCollectionCount(object collection, out int[]? lengths); - Size? GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? writeState); - ValueTask Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken = default); - ValueTask Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken = default); + Size? GetSizeOrDbNull(SizeContext context, object collection, Indices indices, ref object? writeState); + ValueTask Read(bool async, PgReader reader, bool isDbNull, object collection, Indices indices, CancellationToken cancellationToken = default); + ValueTask Write(bool async, PgWriter writer, object collection, Indices indices, CancellationToken cancellationToken = default); } -readonly struct PgArrayConverter +readonly struct PgArrayConverter( + IElementOperations elemOps, + bool elemTypeDbNullable, + int? expectedDimensions, + BufferRequirements bufferRequirements, + PgTypeId elemTypeId, + int pgLowerBound = 1) { - internal const string ReadNonNullableCollectionWithNullsExceptionMessage = "Cannot read a non-nullable collection of elements because the returned array contains nulls. Call GetFieldValue with a nullable collection type instead."; - - readonly IElementOperations _elemOps; - readonly int? 
_expectedDimensions; - readonly BufferRequirements _bufferRequirements; - public bool ElemTypeDbNullable { get; } - readonly int _pgLowerBound; - readonly PgTypeId _elemTypeId; + public const string ReadNonNullableCollectionWithNullsExceptionMessage = + "Cannot read a non-nullable collection of elements because the returned array contains nulls. Call GetFieldValue with a nullable collection type instead."; + public const int MaxDimensions = 8; - public PgArrayConverter(IElementOperations elemOps, bool elemTypeDbNullable, int? expectedDimensions, BufferRequirements bufferRequirements, PgTypeId elemTypeId, int pgLowerBound = 1) - { - _elemTypeId = elemTypeId; - ElemTypeDbNullable = elemTypeDbNullable; - _pgLowerBound = pgLowerBound; - _elemOps = elemOps; - _expectedDimensions = expectedDimensions; - _bufferRequirements = bufferRequirements; - } + public bool ElemTypeDbNullable { get; } = elemTypeDbNullable; - bool IsDbNull(object values, int[] indices) + bool IsDbNull(object values, Indices indices) { object? state = null; - return _elemOps.GetSizeOrDbNull(new(DataFormat.Binary, _bufferRequirements.Write), values, indices, ref state) is null; + return elemOps.GetSizeOrDbNull(new(DataFormat.Binary, bufferRequirements.Write), values, indices, ref state) is null; } - Size GetElemsSize(object values, (Size, object?)[] elemStates, out bool anyElementState, DataFormat format, int count, int[] indices, int[]? lengths = null) + Size GetElemsSize(object values, (Size, object?)[] elemStates, out bool anyElementState, DataFormat format, int count, Indices indices, int[]? lengths = null) { Debug.Assert(elemStates.Length >= count); var totalSize = Size.Zero; - var context = new SizeContext(format, _bufferRequirements.Write); + var context = new SizeContext(format, bufferRequirements.Write); anyElementState = false; - var lastLength = lengths?[lengths.Length - 1] ?? count; - ref var lastIndex = ref indices[indices.Length - 1]; + var lastLength = lengths?[^1] ?? 
count; + ref var lastIndex = ref indices.GetItem(indices.Count - 1); var i = 0; do { ref var elemItem = ref elemStates[i++]; var elemState = (object?)null; - var size = _elemOps.GetSizeOrDbNull(context, values, indices, ref elemState); + var size = elemOps.GetSizeOrDbNull(context, values, indices, ref elemState); anyElementState = anyElementState || elemState is not null; elemItem = (size ?? -1, elemState); totalSize = totalSize.Combine(size ?? 0); } // We can immediately continue if we didn't reach the end of the last dimension. - while (++lastIndex < lastLength || (indices.Length > 1 && CarryIndices(lengths!, indices))); + while (++lastIndex < lastLength || (indices.Count > 1 && CarryIndices(lengths!, indices))); return totalSize; } - Size GetFixedElemsSize(Size elemSize, object values, int count, int[] indices, int[]? lengths = null) + Size GetFixedElemsSize(Size elemSize, object values, int count, Indices indices, int[]? lengths = null) { var nulls = 0; - var lastLength = lengths?[lengths.Length - 1] ?? count; - ref var lastIndex = ref indices[indices.Length - 1]; + var lastLength = lengths?[^1] ?? count; + ref var lastIndex = ref indices.GetItem(indices.Count - 1); if (ElemTypeDbNullable) do { @@ -82,7 +111,7 @@ Size GetFixedElemsSize(Size elemSize, object values, int count, int[] indices, i nulls++; } // We can immediately continue if we didn't reach the end of the last dimension. - while (++lastIndex < lastLength || (indices.Length > 1 && CarryIndices(lengths!, indices))); + while (++lastIndex < lastLength || (indices.Count > 1 && CarryIndices(lengths!, indices))); return (count - nulls) * elemSize.Value; } @@ -96,18 +125,18 @@ int GetFormatSize(int count, int dimensions) public Size GetSize(SizeContext context, object values, ref object? writeState) { - var count = _elemOps.GetCollectionCount(values, out var lengths); + var count = elemOps.GetCollectionCount(values, out var lengths); var dimensions = lengths?.Length ?? 
1; - if (dimensions > 8) - throw new ArgumentException(nameof(values), "Postgres arrays can have at most 8 dimensions."); + if (dimensions > MaxDimensions) + ThrowHelper.ThrowArgumentException($"Postgres arrays can have at most {MaxDimensions} dimensions.", nameof(values)); var formatSize = Size.Create(GetFormatSize(count, dimensions)); if (count is 0) return formatSize; Size elemsSize; - var indices = new int[dimensions]; - if (_bufferRequirements.Write is { Kind: SizeKind.Exact } req) + var indices = Indices.Create(dimensions); + if (bufferRequirements.Write is { Kind: SizeKind.Exact } req) { elemsSize = GetFixedElemsSize(req, values, count, indices, lengths); writeState = new WriteState { Count = count, Indices = indices, Lengths = lengths, ArrayPool = null, Data = default, AnyWriteState = false }; @@ -128,23 +157,44 @@ public Size GetSize(SizeContext context, object values, ref object? writeState) sealed class WriteState : MultiWriteState { public required int Count { get; init; } - public required int[] Indices { get; init; } + public required Indices Indices { get; init; } public required int[]? 
Lengths { get; init; } } + object ReadDimsAndCreateCollection(PgReader reader, int dimensions, out int lastDimLength) + { + Debug.Assert(!reader.ShouldBuffer((sizeof(int) + sizeof(int)) * dimensions)); + + Span dimLengths = stackalloc int[MaxDimensions]; + lastDimLength = 0; + for (var i = 0; i < dimensions; i++) + { + lastDimLength = reader.ReadInt32(); + _ = reader.ReadInt32(); // Lower bound + dimLengths[i] = lastDimLength; + } + + var collection = elemOps.CreateCollection(dimLengths.Slice(0, dimensions)); + Debug.Assert(dimensions <= 1 || collection is Array a && a.Rank == dimensions); + return collection; + } + public async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken = default) { if (reader.ShouldBuffer(sizeof(int) + sizeof(int) + sizeof(uint))) await reader.Buffer(async, sizeof(int) + sizeof(int) + sizeof(uint), cancellationToken).ConfigureAwait(false); var dimensions = reader.ReadInt32(); + if (dimensions > MaxDimensions) + ThrowHelper.ThrowInvalidOperationException($"Postgres arrays can have at most {MaxDimensions} dimensions."); + var containsNulls = reader.ReadInt32() is 1; _ = reader.ReadUInt32(); // Element OID. - if (dimensions is not 0 && _expectedDimensions is not null && dimensions != _expectedDimensions) + if (dimensions is not 0 && expectedDimensions is not null && dimensions != expectedDimensions) ThrowHelper.ThrowInvalidCastException( $"Cannot read an array value with {dimensions} dimension{(dimensions == 1 ? "" : "s")} into a " - + $"collection type with {_expectedDimensions} dimension{(_expectedDimensions == 1 ? "" : "s")}. " + + $"collection type with {expectedDimensions} dimension{(expectedDimensions == 1 ? "" : "s")}. 
" + $"Call GetValue or a version of GetFieldValue with the commas being the expected amount of dimensions."); if (containsNulls && !ElemTypeDbNullable) @@ -154,32 +204,13 @@ public async ValueTask Read(bool async, PgReader reader, CancellationTok if (reader.ShouldBuffer((sizeof(int) + sizeof(int)) * dimensions)) await reader.Buffer(async, (sizeof(int) + sizeof(int)) * dimensions, cancellationToken).ConfigureAwait(false); - var dimLengths = new int[_expectedDimensions ?? dimensions]; - var lastDimLength = 0; - for (var i = 0; i < dimensions; i++) - { - lastDimLength = reader.ReadInt32(); - reader.ReadInt32(); // Lower bound - if (dimLengths.Length is 0) - break; - dimLengths[i] = lastDimLength; - } - - var collection = _elemOps.CreateCollection(dimLengths); - Debug.Assert(dimensions <= 1 || collection is Array a && a.Rank == dimensions); - + var collection = ReadDimsAndCreateCollection(reader, dimensions, out var lastDimLength); if (dimensions is 0 || lastDimLength is 0) return collection; - int[] indices; - // Reuse array for dim <= 1 - if (dimensions == 1) - { - dimLengths[0] = 0; - indices = dimLengths; - } - else - indices = new int[dimensions]; + _ = elemOps.GetCollectionCount(collection, out var dimLengths); + var indices = Indices.Create(dimensions); + do { if (reader.ShouldBuffer(sizeof(int))) @@ -189,10 +220,10 @@ public async ValueTask Read(bool async, PgReader reader, CancellationTok var isDbNull = length == -1; if (!isDbNull) { - var scope = await reader.BeginNestedRead(async, length, _bufferRequirements.Read, cancellationToken).ConfigureAwait(false); + var scope = await reader.BeginNestedRead(async, length, bufferRequirements.Read, cancellationToken).ConfigureAwait(false); try { - await _elemOps.Read(async, reader, isDbNull, collection, indices, cancellationToken).ConfigureAwait(false); + await elemOps.Read(async, reader, isDbNull, collection, indices, cancellationToken).ConfigureAwait(false); } finally { @@ -203,26 +234,27 @@ public async ValueTask 
Read(bool async, PgReader reader, CancellationTok } } else - await _elemOps.Read(async, reader, isDbNull, collection, indices, cancellationToken).ConfigureAwait(false); + await elemOps.Read(async, reader, isDbNull, collection, indices, cancellationToken).ConfigureAwait(false); } // We can immediately continue if we didn't reach the end of the last dimension. - while (++indices[indices.Length - 1] < lastDimLength || (dimensions > 1 && CarryIndices(dimLengths, indices))); + while (++indices.GetItem(indices.Count - 1) < lastDimLength || (dimLengths is not null && CarryIndices(dimLengths, indices))); return collection; } - static bool CarryIndices(int[] lengths, int[] indices) + static bool CarryIndices(int[] lengths, Indices indices) { Debug.Assert(lengths.Length > 1); + Debug.Assert(indices.Count > 1); // Find the first dimension from the end that isn't at or past its length, increment it and bring all previous dimensions to zero. - for (var dim = indices.Length - 1; dim >= 0; dim--) + for (var dim = indices.Count - 1; dim >= 0; dim--) { - if (indices[dim] >= lengths[dim] - 1) + if (indices.GetItem(dim) >= lengths[dim] - 1) continue; - indices.AsSpan().Slice(dim + 1).Clear(); - indices[dim]++; + indices.Many.AsSpan().Slice(dim + 1).Clear(); + indices.GetItem(dim)++; return true; } @@ -244,11 +276,11 @@ public async ValueTask Write(bool async, PgWriter writer, object values, Cancell writer.WriteInt32(dims); // Dimensions writer.WriteInt32(0); // Flags (not really used) - writer.WriteAsOid(_elemTypeId); + writer.WriteAsOid(elemTypeId); for (var dim = 0; dim < dims; dim++) { writer.WriteInt32(state?.Lengths?[dim] ?? count); - writer.WriteInt32(_pgLowerBound); // Lower bound + writer.WriteInt32(pgLowerBound); // Lower bound } // We can stop here for empty collections. 
@@ -259,8 +291,9 @@ public async ValueTask Write(bool async, PgWriter writer, object values, Cancell var elemData = state.Data.Array; var indices = state.Indices; - Array.Clear(indices, 0 , indices.Length); - var lastLength = state.Lengths?[state.Lengths.Length - 1] ?? state.Count; + if (indices.Many is not null) + Array.Clear(indices.Many, 0 , indices.Many.Length); + var lastLength = state.Lengths?[^1] ?? state.Count; var i = state.Data.Offset; do { @@ -268,7 +301,7 @@ public async ValueTask Write(bool async, PgWriter writer, object values, Cancell await writer.Flush(async, cancellationToken).ConfigureAwait(false); var elem = elemData?[i++]; - var size = elem?.Size ?? (elemTypeDbNullable && IsDbNull(values, indices) ? -1 : _bufferRequirements.Write); + var size = elem?.Size ?? (elemTypeDbNullable && IsDbNull(values, indices) ? -1 : bufferRequirements.Write); if (size.Kind is SizeKind.Unknown) throw new NotImplementedException(); @@ -276,21 +309,48 @@ public async ValueTask Write(bool async, PgWriter writer, object values, Cancell writer.WriteInt32(length); if (length != -1) { - using var _ = await writer.BeginNestedWrite(async, _bufferRequirements.Write, length, elem?.WriteState, cancellationToken).ConfigureAwait(false); - await _elemOps.Write(async, writer, values, indices, cancellationToken).ConfigureAwait(false); + using var _ = await writer.BeginNestedWrite(async, bufferRequirements.Write, length, elem?.WriteState, cancellationToken).ConfigureAwait(false); + await elemOps.Write(async, writer, values, indices, cancellationToken).ConfigureAwait(false); } } // We can immediately continue if we didn't reach the end of the last dimension. 
- while (++indices[indices.Length - 1] < lastLength || (indices.Length > 1 && CarryIndices(state.Lengths!, indices))); + while (++indices.GetItem(indices.Count - 1) < lastLength || (state.Lengths is not null && CarryIndices(state.Lengths, indices))); + } + + // Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is passed along. + // As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're done. + // The alternatives are: + // 1. Add a virtual method and make AwaitTask call into it (bloating the vtable of all derived types). + // 2. Using a delegate, meaning we add a static field + an alloc per T + metadata, slightly slower dispatch perf so overall strictly worse as well. + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] + public static async ValueTask AwaitTask(Task task, Continuation continuation, object collection, Indices indices) + { + await task.ConfigureAwait(false); + continuation.Invoke(task, collection, indices); + // Guarantee the type stays loaded until the function pointer call is done. + GC.KeepAlive(continuation.Handle); + } + + // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent mistakes. + public readonly unsafe struct Continuation + { + public object Handle { get; } + readonly delegate* _continuation; + + /// A reference to the type that houses the static method points to. + /// The continuation + public Continuation(object handle, delegate* continuation) + { + Handle = handle; + _continuation = continuation; + } + + public void Invoke(Task task, object collection, Indices indices) => _continuation(task, collection, indices); } } -// Class constraint exists to make Unsafe.As, ValueTask> safe, don't remove unless that unsafe cast is also removed. 
-abstract class ArrayConverter : PgStreamingConverter where T : class +abstract class ArrayConverter : PgStreamingConverter where T : notnull { - protected PgConverterResolution ElemResolution { get; } - protected Type ElemTypeToConvert { get; } - readonly PgArrayConverter _pgArrayConverter; private protected ArrayConverter(int? expectedDimensions, PgConverterResolution elemResolution, int pgLowerBound = 1) @@ -298,18 +358,37 @@ private protected ArrayConverter(int? expectedDimensions, PgConverterResolution if (!elemResolution.Converter.CanConvert(DataFormat.Binary, out var bufferRequirements)) throw new NotSupportedException("Element converter has to support the binary format to be compatible."); - ElemResolution = elemResolution; - ElemTypeToConvert = elemResolution.Converter.TypeToConvert; _pgArrayConverter = new((IElementOperations)this, elemResolution.Converter.IsDbNullable, expectedDimensions, bufferRequirements, elemResolution.PgTypeId, pgLowerBound); } public override T Read(PgReader reader) => (T)_pgArrayConverter.Read(async: false, reader).Result; - public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + public override unsafe ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) { - var value = _pgArrayConverter.Read(async: true, reader, cancellationToken); - return Unsafe.As, ValueTask>(ref value); + // Cheap if we have all the data. + var task = _pgArrayConverter.Read(async: true, reader, cancellationToken); + if (task.IsCompletedSuccessfully) + return new((T)task.Result); + + // Otherwise do these additional allocations (source and task) to allow us to share state machine codegen for all Ts. + // We don't use the PoolingCompletionSource here as it would be backed by an IValueTaskSource. + // Any ReadAsObjectAsync caller would call AsTask() on it immediately, causing another allocation and indirection. 
+ var source = new AsyncHelpers.CompletionSource(); + AsyncHelpers.OnCompletedWithSource(task.AsTask(), source, new(this, &UnboxAndComplete)); + return source.Task; + + static void UnboxAndComplete(Task task, AsyncHelpers.CompletionSource completionSource) + { + // Justification: exact type Unsafe.As used to reduce generic duplication cost when T is a value type (like ReadOnlyMemory). + Debug.Assert(task is Task); + // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. + var result = (T)new ValueTask(Unsafe.As>(task)).Result; + + // Justification: exact type Unsafe.As used to reduce generic duplication cost. + Debug.Assert(completionSource is AsyncHelpers.CompletionSource); + Unsafe.As>(completionSource).SetResult(result); + } } public override Size GetSize(SizeContext context, T values, ref object? writeState) @@ -321,118 +400,92 @@ public override void Write(PgWriter writer, T values) public override ValueTask WriteAsync(PgWriter writer, T values, CancellationToken cancellationToken = default) => _pgArrayConverter.Write(async: true, writer, values, cancellationToken); - // Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is passed along. - // As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're done. - // The alternatives are: - // 1. Add a virtual method and make AwaitTask call into it (bloating the vtable of all derived types). - // 2. Using a delegate, meaning we add a static field + an alloc per T + metadata, slightly slower dispatch perf so overall strictly worse as well. 
- [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] - private protected static async ValueTask AwaitTask(Task task, Continuation continuation, object collection, int[] indices) - { - await task.ConfigureAwait(false); - continuation.Invoke(task, collection, indices); - // Guarantee the type stays loaded until the function pointer call is done. - GC.KeepAlive(continuation.Handle); - } - - // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent mistakes. - protected readonly unsafe struct Continuation + protected static int GetLengths(Array array, out int[]? lengths) { - public object Handle { get; } - readonly delegate* _continuation; + var dimensions = array.Rank; - /// A reference to the type that houses the static method points to. - /// The continuation - public Continuation(object handle, delegate* continuation) + if (dimensions is 1) { - Handle = handle; - _continuation = continuation; + lengths = null; + return array.Length; } - public void Invoke(Task task, object collection, int[] indices) => _continuation(task, collection, indices); - } - - protected static int[]? GetLengths(Array array) - { - if (array.Rank == 1) - return null; - - var lengths = new int[array.Rank]; + lengths = new int[dimensions]; for (var i = 0; i < lengths.Length; i++) lengths[i] = array.GetLength(i); - return lengths; + // If we have a multidim array it may throw an overflow exception for large arrays (LongLength exists for these cases) + // however anything over int.MaxValue wouldn't fit in a parameter anyway so easier to throw here than deal with a long. + return array.Length; } } -sealed class ArrayBasedArrayConverter : ArrayConverter, IElementOperations where T : class +sealed class ArrayBasedArrayConverter(PgConverterResolution elemResolution, Type? effectiveType = null, int pgLowerBound = 1) + : ArrayConverter(expectedDimensions: effectiveType is null ? 1 : effectiveType.IsArray ? 
effectiveType.GetArrayRank() : null, + elemResolution, pgLowerBound), IElementOperations + where T : class { - readonly PgConverter _elemConverter; - - public ArrayBasedArrayConverter(PgConverterResolution elemResolution, Type? effectiveType = null, int pgLowerBound = 1) - : base( - expectedDimensions: effectiveType is null ? 1 : effectiveType.IsArray ? effectiveType.GetArrayRank() : null, - elemResolution, pgLowerBound) - => _elemConverter = elemResolution.GetConverter(); + readonly PgConverter _elemConverter = elemResolution.GetConverter(); [MethodImpl(MethodImplOptions.AggressiveInlining)] - static TElement? GetValue(object collection, int[] indices) + static TElement? GetValue(object collection, Indices indices) { - switch (indices.Length) + Debug.Assert(indices.Count > 0); + switch (indices.Count) { case 1: + // Justification: exact type Unsafe.As used to avoid the cast overhead for per element calls. Debug.Assert(collection is TElement?[]); - return Unsafe.As(collection)[indices[0]]; + return Unsafe.As(collection)[indices.One]; default: + // Justification: exact type Unsafe.As used to avoid the cast overhead for per element calls. Debug.Assert(collection is Array); - return (TElement?)Unsafe.As(collection).GetValue(indices); + return (TElement?)Unsafe.As(collection).GetValue(indices.Many!); } } [MethodImpl(MethodImplOptions.AggressiveInlining)] - static void SetValue(object collection, int[] indices, TElement? value) + static void SetValue(object collection, Indices indices, TElement? value) { - switch (indices.Length) + Debug.Assert(indices.Count > 0); + switch (indices.Count) { case 1: + // Justification: exact type Unsafe.As used to avoid the cast overhead for per element calls. Debug.Assert(collection is TElement?[]); - Unsafe.As(collection)[indices[0]] = value; + Unsafe.As(collection)[indices.One] = value; break; default: + // Justification: exact type Unsafe.As used to avoid the cast overhead for per element calls. 
Debug.Assert(collection is Array); - Unsafe.As(collection).SetValue(value, indices); + Unsafe.As(collection).SetValue(value, indices.Many!); break; } } - object IElementOperations.CreateCollection(int[] lengths) + object IElementOperations.CreateCollection(ReadOnlySpan lengths) => lengths.Length switch { 0 => Array.Empty(), 1 when lengths[0] == 0 => Array.Empty(), 1 => new TElement?[lengths[0]], - 2 => new TElement?[lengths[0],lengths[1]], - 3 => new TElement?[lengths[0],lengths[1], lengths[2]], - 4 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3]], - 5 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3], lengths[4]], - 6 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3], lengths[4], lengths[5]], - 7 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3], lengths[4], lengths[5], lengths[6]], - 8 => new TElement?[lengths[0],lengths[1], lengths[2], lengths[3], lengths[4], lengths[5], lengths[6], lengths[7]], + 2 => new TElement?[lengths[0], lengths[1]], + 3 => new TElement?[lengths[0], lengths[1], lengths[2]], + 4 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3]], + 5 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3], lengths[4]], + 6 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3], lengths[4], lengths[5]], + 7 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3], lengths[4], lengths[5], lengths[6]], + 8 => new TElement?[lengths[0], lengths[1], lengths[2], lengths[3], lengths[4], lengths[5], lengths[6], lengths[7]], _ => throw new InvalidOperationException("Postgres arrays can have at most 8 dimensions.") }; int IElementOperations.GetCollectionCount(object collection, out int[]? lengths) - { - Debug.Assert(collection is Array); - var array = Unsafe.As(collection); - lengths = GetLengths(array); - return array.Length; - } + => GetLengths((Array)collection, out lengths); - Size? 
IElementOperations.GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? writeState) + Size? IElementOperations.GetSizeOrDbNull(SizeContext context, object collection, Indices indices, ref object? writeState) => _elemConverter.GetSizeOrDbNull(context.Format, context.BufferRequirement, GetValue(collection, indices), ref writeState); - ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken) + ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, object collection, Indices indices, CancellationToken cancellationToken) { if (!isDbNull && async && _elemConverter is PgStreamingConverter streamingConverter) return ReadAsync(streamingConverter, reader, collection, indices, cancellationToken); @@ -441,23 +494,24 @@ ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, ob return new(); } - unsafe ValueTask ReadAsync(PgStreamingConverter converter, PgReader reader, object collection, int[] indices, CancellationToken cancellationToken) + unsafe ValueTask ReadAsync(PgStreamingConverter converter, PgReader reader, object collection, Indices indices, CancellationToken cancellationToken) { if (converter.ReadAsyncAsTask(reader, cancellationToken, out var result) is { } task) - return AwaitTask(task, new(this, &SetResult), collection, indices); + return PgArrayConverter.AwaitTask(task, new(this, &SetResult), collection, indices); SetValue(collection, indices, result); return new(); - // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. - static void SetResult(Task task, object collection, int[] indices) + static void SetResult(Task task, object collection, Indices indices) { + // Justification: exact type Unsafe.As used to reduce generic duplication cost. 
Debug.Assert(task is Task); - SetValue(collection, indices, new ValueTask(Unsafe.As>(task)).Result); + // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. + SetValue(collection, indices, new ValueTask(task: Unsafe.As>(task)).Result); } } - ValueTask IElementOperations.Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken) + ValueTask IElementOperations.Write(bool async, PgWriter writer, object collection, Indices indices, CancellationToken cancellationToken) { if (async) return _elemConverter.WriteAsync(writer, GetValue(collection, indices)!, cancellationToken); @@ -467,17 +521,16 @@ ValueTask IElementOperations.Write(bool async, PgWriter writer, object collectio } } -sealed class ListBasedArrayConverter : ArrayConverter, IElementOperations where T : class +sealed class ListBasedArrayConverter(PgConverterResolution elemResolution, int pgLowerBound = 1) + : ArrayConverter(expectedDimensions: 1, elemResolution, pgLowerBound), IElementOperations + where T : class { - readonly PgConverter _elemConverter; - - public ListBasedArrayConverter(PgConverterResolution elemResolution, int pgLowerBound = 1) - : base(expectedDimensions: 1, elemResolution, pgLowerBound) - => _elemConverter = elemResolution.GetConverter(); + readonly PgConverter _elemConverter = elemResolution.GetConverter(); [MethodImpl(MethodImplOptions.AggressiveInlining)] static TElement? GetValue(object collection, int index) { + // Justification: avoid the cast overhead for per element calls. Debug.Assert(collection is IList); return Unsafe.As>(collection)[index]; } @@ -485,69 +538,68 @@ public ListBasedArrayConverter(PgConverterResolution elemResolution, int pgLower [MethodImpl(MethodImplOptions.AggressiveInlining)] static void SetValue(object collection, int index, TElement? value) { + // Justification: avoid the cast overhead for per element calls. 
Debug.Assert(collection is IList); var list = Unsafe.As>(collection); list.Insert(index, value); } - object IElementOperations.CreateCollection(int[] lengths) + object IElementOperations.CreateCollection(ReadOnlySpan lengths) => new List(lengths.Length is 0 ? 0 : lengths[0]); int IElementOperations.GetCollectionCount(object collection, out int[]? lengths) { - Debug.Assert(collection is IList); lengths = null; - return Unsafe.As>(collection).Count; + return ((IList)collection).Count; } - Size? IElementOperations.GetSizeOrDbNull(SizeContext context, object collection, int[] indices, ref object? writeState) - => _elemConverter.GetSizeOrDbNull(context.Format, context.BufferRequirement, GetValue(collection, indices[0]), ref writeState); + Size? IElementOperations.GetSizeOrDbNull(SizeContext context, object collection, Indices indices, ref object? writeState) + => _elemConverter.GetSizeOrDbNull(context.Format, context.BufferRequirement, GetValue(collection, indices.One), ref writeState); - ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, object collection, int[] indices, CancellationToken cancellationToken) + ValueTask IElementOperations.Read(bool async, PgReader reader, bool isDbNull, object collection, Indices indices, CancellationToken cancellationToken) { - Debug.Assert(indices.Length is 1); + Debug.Assert(indices.Count is 1); if (!isDbNull && async && _elemConverter is PgStreamingConverter streamingConverter) return ReadAsync(streamingConverter, reader, collection, indices, cancellationToken); - SetValue(collection, indices[0], isDbNull ? default : _elemConverter.Read(reader)); + SetValue(collection, indices.One, isDbNull ? 
default : _elemConverter.Read(reader)); return new(); } - unsafe ValueTask ReadAsync(PgStreamingConverter converter, PgReader reader, object collection, int[] indices, CancellationToken cancellationToken) + unsafe ValueTask ReadAsync(PgStreamingConverter converter, PgReader reader, object collection, Indices indices, CancellationToken cancellationToken) { + Debug.Assert(indices.Count is 1); if (converter.ReadAsyncAsTask(reader, cancellationToken, out var result) is { } task) - return AwaitTask(task, new(this, &SetResult), collection, indices); + return PgArrayConverter.AwaitTask(task, new(this, &SetResult), collection, indices); - SetValue(collection, indices[0], result); + SetValue(collection, indices.One, result); return new(); - // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. - static void SetResult(Task task, object collection, int[] indices) + static void SetResult(Task task, object collection, Indices indices) { + // Justification: exact type Unsafe.As used to reduce generic duplication cost. Debug.Assert(task is Task); - SetValue(collection, indices[0], new ValueTask(Unsafe.As>(task)).Result); + // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. 
+ SetValue(collection, indices.One, new ValueTask(task: Unsafe.As>(task)).Result); } } - ValueTask IElementOperations.Write(bool async, PgWriter writer, object collection, int[] indices, CancellationToken cancellationToken) + ValueTask IElementOperations.Write(bool async, PgWriter writer, object collection, Indices indices, CancellationToken cancellationToken) { - Debug.Assert(indices.Length is 1); + Debug.Assert(indices.Count is 1); if (async) - return _elemConverter.WriteAsync(writer, GetValue(collection, indices[0])!, cancellationToken); + return _elemConverter.WriteAsync(writer, GetValue(collection, indices.One)!, cancellationToken); - _elemConverter.Write(writer, GetValue(collection, indices[0])!); + _elemConverter.Write(writer, GetValue(collection, indices.One)!); return new(); } } -sealed class ArrayConverterResolver : PgComposingConverterResolver where T : class +sealed class ArrayConverterResolver(PgResolverTypeInfo elementTypeInfo, Type effectiveType) + : PgComposingConverterResolver(elementTypeInfo.PgTypeId is { } id ? elementTypeInfo.Options.GetArrayTypeId(id) : null, + elementTypeInfo) + where T : class { - readonly Type _effectiveType; - - public ArrayConverterResolver(PgResolverTypeInfo elementTypeInfo, Type effectiveType) - : base(elementTypeInfo.PgTypeId is { } id ? 
elementTypeInfo.Options.GetArrayTypeId(id) : null, elementTypeInfo) - => _effectiveType = effectiveType; - PgSerializerOptions Options => EffectiveTypeInfo.Options; protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => Options.GetArrayElementTypeId(pgTypeId); @@ -556,7 +608,7 @@ public ArrayConverterResolver(PgResolverTypeInfo elementTypeInfo, Type effective protected override PgConverter CreateConverter(PgConverterResolution effectiveResolution) { if (typeof(T) == typeof(Array) || typeof(T).IsArray) - return new ArrayBasedArrayConverter(effectiveResolution, _effectiveType); + return new ArrayBasedArrayConverter(effectiveResolution, effectiveType); if (typeof(T).IsConstructedGenericType && typeof(T).GetGenericTypeDefinition() == typeof(IList<>)) return new ListBasedArrayConverter(effectiveResolution); @@ -613,15 +665,15 @@ protected override PgConverter CreateConverter(PgConverterResolution effectiv } // T is Array as we only know what type it will be after reading 'contains nulls'. 
-sealed class PolymorphicArrayConverter : PgStreamingConverter +sealed class PolymorphicArrayConverter( + PgConverter structElementCollectionConverter, + PgConverter nullableElementCollectionConverter) + : PgStreamingConverter { - readonly PgConverter _structElementCollectionConverter; - readonly PgConverter _nullableElementCollectionConverter; - - public PolymorphicArrayConverter(PgConverter structElementCollectionConverter, PgConverter nullableElementCollectionConverter) + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { - _structElementCollectionConverter = structElementCollectionConverter; - _nullableElementCollectionConverter = nullableElementCollectionConverter; + bufferRequirements = BufferRequirements.Create(read: sizeof(int) + sizeof(int), write: Size.Unknown); + return format is DataFormat.Binary; } public override TBase Read(PgReader reader) @@ -630,8 +682,8 @@ public override TBase Read(PgReader reader) var containsNulls = reader.ReadInt32() is 1; reader.Rewind(sizeof(int) + sizeof(int)); return containsNulls - ? _nullableElementCollectionConverter.Read(reader) - : _structElementCollectionConverter.Read(reader); + ? nullableElementCollectionConverter.Read(reader) + : structElementCollectionConverter.Read(reader); } public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) @@ -640,8 +692,8 @@ public override ValueTask ReadAsync(PgReader reader, CancellationToken ca var containsNulls = reader.ReadInt32() is 1; reader.Rewind(sizeof(int) + sizeof(int)); return containsNulls - ? _nullableElementCollectionConverter.ReadAsync(reader, cancellationToken) - : _structElementCollectionConverter.ReadAsync(reader, cancellationToken); + ? nullableElementCollectionConverter.ReadAsync(reader, cancellationToken) + : structElementCollectionConverter.ReadAsync(reader, cancellationToken); } public override Size GetSize(SizeContext context, TBase value, ref object? 
writeState) diff --git a/src/Npgsql/Internal/Converters/AsyncHelpers.cs b/src/Npgsql/Internal/Converters/AsyncHelpers.cs index ccf8780ca0..bf85a06a9f 100644 --- a/src/Npgsql/Internal/Converters/AsyncHelpers.cs +++ b/src/Npgsql/Internal/Converters/AsyncHelpers.cs @@ -6,35 +6,48 @@ namespace Npgsql.Internal.Converters; -// Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is passed along. -// As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're done. static class AsyncHelpers { - static async void AwaitTask(Task task, CompletionSource tcs, Continuation continuation) + public static void OnCompletedWithSource(Task task, CompletionSource source, CompletionSourceContinuation continuation) { - try - { - await task.ConfigureAwait(false); - continuation.Invoke(task, tcs); - } - catch (Exception ex) + _ = Core(task, source, continuation); + + // Have our state machine be pooled, but don't return the task, source.Task should be used instead. + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder))] + async ValueTask Core(Task task, CompletionSource source, CompletionSourceContinuation continuation) { - tcs.SetException(ex); + try + { + await task.ConfigureAwait(false); + continuation.Invoke(task, source); + } + catch (Exception ex) + { + source.SetException(ex); + } + // Guarantee the type stays loaded until the function pointer call is done. + continuation.KeepAlive(); } - // Guarantee the type stays loaded until the function pointer call is done. 
- GC.KeepAlive(continuation.Handle); } - abstract class CompletionSource + public abstract class CompletionSource { public abstract void SetException(Exception exception); } - sealed class CompletionSource : CompletionSource + public sealed class CompletionSource : CompletionSource { - PoolingAsyncValueTaskMethodBuilder _amb = PoolingAsyncValueTaskMethodBuilder.Create(); + AsyncValueTaskMethodBuilder _amb; - public ValueTask Task => _amb.Task; + public ValueTask Task { get; } + + public CompletionSource() + { + _amb = AsyncValueTaskMethodBuilder.Create(); + // AsyncValueTaskMethodBuilder's Task and SetResult aren't thread safe in regard to each other + // Which is why we access it prematurely + Task = _amb.Task; + } public void SetResult(T value) => _amb.SetResult(value); @@ -43,68 +56,88 @@ public override void SetException(Exception exception) => _amb.SetException(exception); } + public sealed class PoolingCompletionSource : CompletionSource + { + PoolingAsyncValueTaskMethodBuilder _amb; + + public ValueTask Task { get; } + + public PoolingCompletionSource() + { + _amb = PoolingAsyncValueTaskMethodBuilder.Create(); + // PoolingAsyncValueTaskMethodBuilder's Task and SetResult aren't thread safe in regard to each other + // Which is why we access it prematurely + Task = _amb.Task; + } + + public void SetResult(T value) + => _amb.SetResult(value); + + public override void SetException(Exception exception) + => _amb.SetException(exception); + } + + // Using a function pointer here is safe against assembly unloading as the instance reference that the static pointer method lives on is passed along. + // As such the instance cannot be collected by the gc which means the entire assembly is prevented from unloading until we're done. // Split out into a struct as unsafe and async don't mix, while we do want a nicely typed function pointer signature to prevent mistakes. 
- readonly unsafe struct Continuation + public readonly unsafe struct CompletionSourceContinuation { - public object Handle { get; } + readonly object _handle; readonly delegate* _continuation; /// A reference to the type that houses the static method points to. /// The continuation - public Continuation(object handle, delegate* continuation) + public CompletionSourceContinuation(object handle, delegate* continuation) { - Handle = handle; + _handle = handle; _continuation = continuation; } + public void KeepAlive() => GC.KeepAlive(_handle); + public void Invoke(Task task, CompletionSource tcs) => _continuation(task, tcs); } public static unsafe ValueTask ReadAsyncAsNullable(this PgConverter instance, PgConverter effectiveConverter, PgReader reader, CancellationToken cancellationToken) where T : struct { - // Easy if we have all the data. + // Cheap if we have all the data. var task = effectiveConverter.ReadAsync(reader, cancellationToken); if (task.IsCompletedSuccessfully) return new(new T?(task.Result)); - // Otherwise we do one additional allocation, this allow us to share state machine codegen for all Ts. - var source = new CompletionSource(); - AwaitTask(task.AsTask(), source, new(instance, &UnboxAndComplete)); + // Otherwise we do one additional allocation, this allows us to share state machine codegen for all Ts. + var source = new PoolingCompletionSource(); + OnCompletedWithSource(task.AsTask(), source, new(instance, &UnboxAndComplete)); return source.Task; static void UnboxAndComplete(Task task, CompletionSource completionSource) { - // Justification: unsafe exact cast used to reduce generic duplication cost. + // Justification: exact type Unsafe.As used to reduce generic duplication cost. 
Debug.Assert(task is Task); - Debug.Assert(completionSource is CompletionSource); - Unsafe.As>(completionSource).SetResult(new T?(new ValueTask(Unsafe.As>(task)).Result)); + Debug.Assert(completionSource is PoolingCompletionSource); + Unsafe.As>(completionSource).SetResult(new T?(new ValueTask(Unsafe.As>(task)).Result)); } } public static unsafe ValueTask ReadAsObjectAsyncAsT(this PgConverter instance, PgConverter effectiveConverter, PgReader reader, CancellationToken cancellationToken) { - if (!typeof(T).IsValueType) - { - var value = effectiveConverter.ReadAsObjectAsync(reader, cancellationToken); - return Unsafe.As, ValueTask>(ref value); - } - - // Easy if we have all the data. + // Cheap if we have all the data. var task = effectiveConverter.ReadAsObjectAsync(reader, cancellationToken); if (task.IsCompletedSuccessfully) return new((T)task.Result); - // Otherwise we do one additional allocation, this allow us to share state machine codegen for all Ts. - var source = new CompletionSource(); - AwaitTask(task.AsTask(), source, new(instance, &UnboxAndComplete)); + // Otherwise we do one additional allocation, this allows us to share state machine codegen for all Ts. + var source = new PoolingCompletionSource(); + OnCompletedWithSource(task.AsTask(), source, new(instance, &UnboxAndComplete)); return source.Task; static void UnboxAndComplete(Task task, CompletionSource completionSource) { + // Justification: exact type Unsafe.As used to reduce generic duplication cost. 
Debug.Assert(task is Task); - Debug.Assert(completionSource is CompletionSource); - Unsafe.As>(completionSource).SetResult((T)new ValueTask(Unsafe.As>(task)).Result); + Debug.Assert(completionSource is PoolingCompletionSource); + Unsafe.As>(completionSource).SetResult((T)new ValueTask(Unsafe.As>(task)).Result); } } } diff --git a/src/Npgsql/Internal/Converters/BitStringConverters.cs b/src/Npgsql/Internal/Converters/BitStringConverters.cs index b7597f96d9..d0d6327a20 100644 --- a/src/Npgsql/Internal/Converters/BitStringConverters.cs +++ b/src/Npgsql/Internal/Converters/BitStringConverters.cs @@ -11,19 +11,16 @@ namespace Npgsql.Internal.Converters; -static class BitStringHelpers +file static class BitStringHelpers { - public static int GetByteLengthFromBits(int n) + public static int GetByteCountFromBitCount(int n) { const int BitShiftPerByte = 3; Debug.Assert(n >= 0); // Due to sign extension, we don't need to special case for n == 0, since ((n - 1) >> 3) + 1 = 0 // This doesn't hold true for ((n - 1) / 8) + 1, which equals 1. 
- return (int)((uint)(n - 1 + (1 << BitShiftPerByte)) >> BitShiftPerByte); + return (n - 1 + (1 << BitShiftPerByte)) >>> BitShiftPerByte; } - - // http://graphics.stanford.edu/~seander/bithacks.html#ReverseByteWith64Bits - public static byte ReverseBits(byte b) => (byte)(((b * 0x80200802UL) & 0x0884422110UL) * 0x0101010101UL >> 32); } sealed class BitArrayBitStringConverter : PgStreamingConverter @@ -34,7 +31,7 @@ public override BitArray Read(PgReader reader) reader.Buffer(sizeof(int)); var bits = reader.ReadInt32(); - var bytes = new byte[GetByteLengthFromBits(bits)]; + var bytes = new byte[GetByteCountFromBitCount(bits)]; reader.ReadBytes(bytes); return ReadValue(bytes, bits); } @@ -44,7 +41,7 @@ public override async ValueTask ReadAsync(PgReader reader, Cancellatio await reader.BufferAsync(sizeof(int), cancellationToken).ConfigureAwait(false); var bits = reader.ReadInt32(); - var bytes = new byte[GetByteLengthFromBits(bits)]; + var bytes = new byte[GetByteCountFromBitCount(bits)]; await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false); return ReadValue(bytes, bits); } @@ -58,10 +55,13 @@ internal static BitArray ReadValue(byte[] bytes, int bits) } return new(bytes) { Length = bits }; + + // https://graphics.stanford.edu/~seander/bithacks.html#ReverseByteWith64Bits + static byte ReverseBits(byte b) => (byte)(((b * 0x80200802UL) & 0x0884422110UL) * 0x0101010101UL >> 32); } public override Size GetSize(SizeContext context, BitArray value, ref object? 
writeState) - => sizeof(int) + GetByteLengthFromBits(value.Length); + => sizeof(int) + GetByteCountFromBitCount(value.Length); public override void Write(PgWriter writer, BitArray value) => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); @@ -97,11 +97,9 @@ async ValueTask Write(bool async, PgWriter writer, BitArray value, CancellationT sealed class BitVector32BitStringConverter : PgBufferedConverter { - static int MaxSize => sizeof(int) + sizeof(int); - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { - bufferRequirements = BufferRequirements.Create(Size.CreateUpperBound(MaxSize)); + bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(int) + sizeof(int)); return format is DataFormat.Binary; } @@ -111,7 +109,7 @@ protected override BitVector32 ReadCore(PgReader reader) throw new InvalidCastException("Can't read a BIT(N) with more than 32 bits to BitVector32, only up to BIT(32)."); var bits = reader.ReadInt32(); - return GetByteLengthFromBits(bits) switch + return GetByteCountFromBitCount(bits) switch { 4 => new(reader.ReadInt32()), 3 => new((reader.ReadInt16() << 8) + reader.ReadByte()), @@ -121,18 +119,10 @@ protected override BitVector32 ReadCore(PgReader reader) }; } - public override Size GetSize(SizeContext context, BitVector32 value, ref object? writeState) - => value.Data is 0 ? 
4 : MaxSize; - protected override void WriteCore(PgWriter writer, BitVector32 value) { - if (value.Data == 0) - writer.WriteInt32(0); - else - { - writer.WriteInt32(32); - writer.WriteInt32(value.Data); - } + writer.WriteInt32(32); + writer.WriteInt32(value.Data); } } @@ -179,7 +169,7 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken canc await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); var bits = reader.ReadInt32(); - var bytes = new byte[GetByteLengthFromBits(bits)]; + var bytes = new byte[GetByteCountFromBitCount(bits)]; if (async) await reader.ReadBytesAsync(bytes, cancellationToken).ConfigureAwait(false); else @@ -198,7 +188,7 @@ public override Size GetSize(SizeContext context, string value, ref object? writ if (value.AsSpan().IndexOfAnyExcept('0', '1') is not -1 and var index) throw new ArgumentException($"Invalid bitstring character '{value[index]}' at index: {index}", nameof(value)); - return sizeof(int) + GetByteLengthFromBits(value.Length); + return sizeof(int) + GetByteCountFromBitCount(value.Length); } public override void Write(PgWriter writer, string value) @@ -235,13 +225,11 @@ async ValueTask Write(bool async, PgWriter writer, string value, CancellationTok /// Note that for BIT(1), this resolver will return a bool by default, to align with SqlClient /// (see discussion https://github.com/npgsql/npgsql/pull/362#issuecomment-59622101). -sealed class PolymorphicBitStringConverterResolver : PolymorphicConverterResolver +sealed class PolymorphicBitStringConverterResolver(PgTypeId bitString) : PolymorphicConverterResolver(bitString) { BoolBitStringConverter? _boolConverter; BitArrayBitStringConverter? _bitArrayConverter; - public PolymorphicBitStringConverterResolver(PgTypeId bitString) : base(bitString) { } - protected override PgConverter Get(Field? field) => field?.TypeModifier is 1 ? 
_boolConverter ??= new BoolBitStringConverter() diff --git a/src/Npgsql/Internal/Converters/CastingConverter.cs b/src/Npgsql/Internal/Converters/CastingConverter.cs index 3fbfc5059d..a2b83fd94c 100644 --- a/src/Npgsql/Internal/Converters/CastingConverter.cs +++ b/src/Npgsql/Internal/Converters/CastingConverter.cs @@ -7,53 +7,47 @@ namespace Npgsql.Internal.Converters; /// A converter to map strongly typed apis onto boxed converter results to produce a strongly typed converter over T. -sealed class CastingConverter : PgConverter +sealed class CastingConverter(PgConverter effectiveConverter) + : PgConverter(effectiveConverter.DbNullPredicateKind is DbNullPredicate.Custom) { - readonly PgConverter _effectiveConverter; - public CastingConverter(PgConverter effectiveConverter) - : base(effectiveConverter.DbNullPredicateKind is DbNullPredicate.Custom) - => _effectiveConverter = effectiveConverter; - - protected override bool IsDbNullValue(T? value, ref object? writeState) => _effectiveConverter.IsDbNullAsObject(value, ref writeState); + protected override bool IsDbNullValue(T? value, ref object? writeState) => effectiveConverter.IsDbNullAsObject(value, ref writeState); public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) - => _effectiveConverter.CanConvert(format, out bufferRequirements); + => effectiveConverter.CanConvert(format, out bufferRequirements); - public override T Read(PgReader reader) => (T)_effectiveConverter.ReadAsObject(reader); + public override T Read(PgReader reader) => (T)effectiveConverter.ReadAsObject(reader); public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) - => this.ReadAsObjectAsyncAsT(_effectiveConverter, reader, cancellationToken); + => this.ReadAsObjectAsyncAsT(effectiveConverter, reader, cancellationToken); public override Size GetSize(SizeContext context, T value, ref object? 
writeState) - => _effectiveConverter.GetSizeAsObject(context, value!, ref writeState); + => effectiveConverter.GetSizeAsObject(context, value!, ref writeState); public override void Write(PgWriter writer, T value) - => _effectiveConverter.WriteAsObject(writer, value!); + => effectiveConverter.WriteAsObject(writer, value!); public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) - => _effectiveConverter.WriteAsObjectAsync(writer, value!, cancellationToken); + => effectiveConverter.WriteAsObjectAsync(writer, value!, cancellationToken); internal override ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken) => async - ? _effectiveConverter.ReadAsObjectAsync(reader, cancellationToken) - : new(_effectiveConverter.ReadAsObject(reader)); + ? effectiveConverter.ReadAsObjectAsync(reader, cancellationToken) + : new(effectiveConverter.ReadAsObject(reader)); internal override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) { if (async) - return _effectiveConverter.WriteAsObjectAsync(writer, value, cancellationToken); + return effectiveConverter.WriteAsObjectAsync(writer, value, cancellationToken); - _effectiveConverter.WriteAsObject(writer, value); + effectiveConverter.WriteAsObject(writer, value); return new(); } } // Given there aren't many instantiations of converter resolvers (and it's fairly involved to write a fast one) we use the composing base class. 
-sealed class CastingConverterResolver : PgComposingConverterResolver +sealed class CastingConverterResolver(PgResolverTypeInfo effectiveResolverTypeInfo) + : PgComposingConverterResolver(effectiveResolverTypeInfo.PgTypeId, effectiveResolverTypeInfo) { - public CastingConverterResolver(PgResolverTypeInfo effectiveResolverTypeInfo) - : base(effectiveResolverTypeInfo.PgTypeId, effectiveResolverTypeInfo) { } - protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => pgTypeId; protected override PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId) => effectivePgTypeId; diff --git a/src/Npgsql/Internal/Converters/CompositeConverter.cs b/src/Npgsql/Internal/Converters/CompositeConverter.cs index 24f3d36329..2c985b647c 100644 --- a/src/Npgsql/Internal/Converters/CompositeConverter.cs +++ b/src/Npgsql/Internal/Converters/CompositeConverter.cs @@ -21,14 +21,17 @@ public CompositeConverter(CompositeInfo composite) var readReq = field.BinaryReadRequirement; var writeReq = field.BinaryWriteRequirement; - // If so we cannot depend on its buffer size being fixed. + // If field is nullable we cannot depend on its buffer size being fixed. if (field.IsDbNullable) { readReq = readReq.Combine(Size.CreateUpperBound(0)); writeReq = writeReq.Combine(Size.CreateUpperBound(0)); } - req = req.Combine(readReq, writeReq); + var readSuccess = req.Read.TryCombine(readReq, out readReq); + var writeSuccess = req.Write.TryCombine(writeReq, out writeReq); + // If we fail to combine due to overflow return unknown. + req = BufferRequirements.Create(readSuccess ? readReq : Size.Unknown, writeSuccess ? writeReq : Size.Unknown); } // We have to put a limit on the requirements we report otherwise smaller buffer sizes won't work. @@ -37,7 +40,7 @@ public CompositeConverter(CompositeInfo composite) _bufferRequirements = req; // Return unknown if we hit the limit. 
- Size Limit(Size requirement) + static Size Limit(Size requirement) { const int maxByteCount = 1024; return requirement.GetValueOrDefault() > maxByteCount ? requirement.Combine(Size.Unknown) : requirement; diff --git a/src/Npgsql/Internal/Converters/FullTextSearch/TsQueryConverter.cs b/src/Npgsql/Internal/Converters/FullTextSearch/TsQueryConverter.cs index 220cc88894..9e88fbe8f1 100644 --- a/src/Npgsql/Internal/Converters/FullTextSearch/TsQueryConverter.cs +++ b/src/Npgsql/Internal/Converters/FullTextSearch/TsQueryConverter.cs @@ -10,14 +10,9 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class TsQueryConverter : PgStreamingConverter +sealed class TsQueryConverter(Encoding encoding) : PgStreamingConverter where T : NpgsqlTsQuery { - readonly Encoding _encoding; - - public TsQueryConverter(Encoding encoding) - => _encoding = encoding; - public override T Read(PgReader reader) => (T)Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); @@ -49,8 +44,8 @@ async ValueTask Read(bool async, PgReader reader, CancellationTok var prefix = reader.ReadByte() != 0; var str = async - ? await reader.ReadNullTerminatedStringAsync(_encoding, cancellationToken).ConfigureAwait(false) - : reader.ReadNullTerminatedString(_encoding); + ? await reader.ReadNullTerminatedStringAsync(encoding, cancellationToken).ConfigureAwait(false) + : reader.ReadNullTerminatedString(encoding); InsertInTree(new NpgsqlTsQueryLexeme(str, weight, prefix), nodes, ref value); continue; @@ -134,7 +129,7 @@ public override Size GetSize(SizeContext context, T value, ref object? writeStat int GetNodeLength(NpgsqlTsQuery node) => node.Kind switch { - Lexeme when _encoding.GetByteCount(((NpgsqlTsQueryLexeme)node).Text) is var strLen + Lexeme when encoding.GetByteCount(((NpgsqlTsQueryLexeme)node).Text) is var strLen => strLen > 2046 ? throw new InvalidCastException("Lexeme text too long. 
Must be at most 2046 encoded bytes.") : 4 + strLen, @@ -185,9 +180,9 @@ async Task WriteCore(NpgsqlTsQuery node) writer.WriteByte(lexemeNode.IsPrefixSearch ? (byte)1 : (byte)0); if (async) - await writer.WriteCharsAsync(lexemeNode.Text.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + await writer.WriteCharsAsync(lexemeNode.Text.AsMemory(), encoding, cancellationToken).ConfigureAwait(false); else - writer.WriteChars(lexemeNode.Text.AsMemory().Span, _encoding); + writer.WriteChars(lexemeNode.Text.AsMemory().Span, encoding); if (writer.ShouldFlush(sizeof(byte))) await writer.Flush(async, cancellationToken).ConfigureAwait(false); diff --git a/src/Npgsql/Internal/Converters/FullTextSearch/TsVectorConverter.cs b/src/Npgsql/Internal/Converters/FullTextSearch/TsVectorConverter.cs index 2c431fd35b..04b16b80f5 100644 --- a/src/Npgsql/Internal/Converters/FullTextSearch/TsVectorConverter.cs +++ b/src/Npgsql/Internal/Converters/FullTextSearch/TsVectorConverter.cs @@ -8,13 +8,8 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class TsVectorConverter : PgStreamingConverter +sealed class TsVectorConverter(Encoding encoding) : PgStreamingConverter { - readonly Encoding _encoding; - - public TsVectorConverter(Encoding encoding) - => _encoding = encoding; - public override NpgsqlTsVector Read(PgReader reader) => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); @@ -32,8 +27,8 @@ async ValueTask Read(bool async, PgReader reader, CancellationTo for (var i = 0; i < numLexemes; i++) { var lexemeString = async - ? await reader.ReadNullTerminatedStringAsync(_encoding, cancellationToken).ConfigureAwait(false) - : reader.ReadNullTerminatedString(_encoding); + ? 
await reader.ReadNullTerminatedStringAsync(encoding, cancellationToken).ConfigureAwait(false) + : reader.ReadNullTerminatedString(encoding); if (reader.ShouldBuffer(sizeof(short))) await reader.Buffer(async, sizeof(short), cancellationToken).ConfigureAwait(false); @@ -70,7 +65,7 @@ public override Size GetSize(SizeContext context, NpgsqlTsVector value, ref obje { var size = 4; foreach (var l in value) - size += _encoding.GetByteCount(l.Text) + 1 + 2 + l.Count * 2; + size += encoding.GetByteCount(l.Text) + 1 + 2 + l.Count * 2; return size; } @@ -90,9 +85,9 @@ async ValueTask Write(bool async, PgWriter writer, NpgsqlTsVector value, Cancell foreach (var lexeme in value) { if (async) - await writer.WriteCharsAsync(lexeme.Text.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + await writer.WriteCharsAsync(lexeme.Text.AsMemory(), encoding, cancellationToken).ConfigureAwait(false); else - writer.WriteChars(lexeme.Text.AsMemory().Span, _encoding); + writer.WriteChars(lexeme.Text.AsMemory().Span, encoding); if (writer.ShouldFlush(sizeof(byte) + sizeof(short))) await writer.Flush(async, cancellationToken).ConfigureAwait(false); diff --git a/src/Npgsql/Internal/Converters/Geometric/CubeConverter.cs b/src/Npgsql/Internal/Converters/Geometric/CubeConverter.cs new file mode 100644 index 0000000000..05b539cf12 --- /dev/null +++ b/src/Npgsql/Internal/Converters/Geometric/CubeConverter.cs @@ -0,0 +1,87 @@ +using System.Threading; +using System.Threading.Tasks; +using NpgsqlTypes; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class CubeConverter : PgStreamingConverter +{ + const uint PointBit = 0x80000000; + const int DimMask = 0x7fffffff; + + public override NpgsqlCube Read(PgReader reader) + => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) + => Read(async: true, reader, 
cancellationToken); + + async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) + { + if (reader.ShouldBuffer(sizeof(int))) + await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); + + var header = reader.ReadInt32(); + var dim = header & DimMask; + var point = (header & PointBit) != 0; + + var lowerLeft = new double[dim]; + for (var i = 0; i < dim; i++) + { + if (reader.ShouldBuffer(sizeof(double))) + await reader.Buffer(async, sizeof(double), cancellationToken).ConfigureAwait(false); + lowerLeft[i] = reader.ReadDouble(); + } + + if (point) + return new NpgsqlCube(lowerLeft); + + var upperRight = new double[dim]; + for (var i = 0; i < dim; i++) + { + if (reader.ShouldBuffer(sizeof(double))) + await reader.Buffer(async, sizeof(double), cancellationToken).ConfigureAwait(false); + upperRight[i] = reader.ReadDouble(); + } + + return new NpgsqlCube(lowerLeft, upperRight); + } + + public override Size GetSize(SizeContext context, NpgsqlCube value, ref object? writeState) + => sizeof(int) + sizeof(double) * (value.IsPoint ? 
value.Dimensions : value.Dimensions * 2); + + public override void Write(PgWriter writer, NpgsqlCube value) + => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); + + public override ValueTask WriteAsync(PgWriter writer, NpgsqlCube value, CancellationToken cancellationToken = default) + => Write(async: true, writer, value, cancellationToken); + + async ValueTask Write(bool async, PgWriter writer, NpgsqlCube value, CancellationToken cancellationToken) + { + if (writer.ShouldFlush(sizeof(int))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + + var header = value.Dimensions; + if (value.IsPoint) + header |= 1 << 31; + + writer.WriteInt32(header); + + for (var i = 0; i < value.Dimensions; i++) + { + if (writer.ShouldFlush(sizeof(double))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteDouble(value.LowerLeft[i]); + } + + if (value.IsPoint) + return; + + for (var i = 0; i < value.Dimensions; i++) + { + if (writer.ShouldFlush(sizeof(double))) + await writer.Flush(async, cancellationToken).ConfigureAwait(false); + writer.WriteDouble(value.UpperRight[i]); + } + } +} diff --git a/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs b/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs index c78ba84013..0481037254 100644 --- a/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs +++ b/src/Npgsql/Internal/Converters/Geometric/PathConverter.cs @@ -32,7 +32,7 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken for (var i = 0; i < numPoints; i++) { if (reader.ShouldBuffer(sizeof(double) * 2)) - await reader.Buffer(async, sizeof(byte) + sizeof(int), cancellationToken).ConfigureAwait(false); + await reader.Buffer(async, sizeof(double) * 2, cancellationToken).ConfigureAwait(false); result.Add(new NpgsqlPoint(reader.ReadDouble(), reader.ReadDouble())); } diff --git a/src/Npgsql/Internal/Converters/HstoreConverter.cs 
b/src/Npgsql/Internal/Converters/HstoreConverter.cs index e2e8762d8e..f9514450f7 100644 --- a/src/Npgsql/Internal/Converters/HstoreConverter.cs +++ b/src/Npgsql/Internal/Converters/HstoreConverter.cs @@ -7,17 +7,10 @@ namespace Npgsql.Internal.Converters; -sealed class HstoreConverter : PgStreamingConverter where T : ICollection> +sealed class HstoreConverter(Encoding encoding, Func>, T>? convert = null) + : PgStreamingConverter + where T : ICollection> { - readonly Encoding _encoding; - readonly Func>, T>? _convert; - - public HstoreConverter(Encoding encoding, Func>, T>? convert = null) - { - _encoding = encoding; - _convert = convert; - } - public override T Read(PgReader reader) => Read(async: false, reader, CancellationToken.None).Result; @@ -40,8 +33,8 @@ public override Size GetSize(SizeContext context, T value, ref object? writeStat if (kv.Key is null) throw new ArgumentException("Hstore doesn't support null keys", nameof(value)); - var keySize = _encoding.GetByteCount(kv.Key); - var valueSize = kv.Value is null ? -1 : _encoding.GetByteCount(kv.Value); + var keySize = encoding.GetByteCount(kv.Key); + var valueSize = kv.Value is null ? -1 : encoding.GetByteCount(kv.Value); totalSize += keySize + (valueSize is -1 ? 0 : valueSize); data[i] = (keySize, null); data[i + 1] = (valueSize, null); @@ -78,7 +71,7 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken cancellat if (reader.ShouldBuffer(sizeof(int))) await reader.Buffer(async, sizeof(int), cancellationToken).ConfigureAwait(false); var keySize = reader.ReadInt32(); - var key = _encoding.GetString(async + var key = encoding.GetString(async ? await reader.ReadBytesAsync(keySize, cancellationToken).ConfigureAwait(false) : reader.ReadBytes(keySize) ); @@ -88,7 +81,7 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken cancellat var valueSize = reader.ReadInt32(); string? 
value = null; if (valueSize is not -1) - value = _encoding.GetString(async + value = encoding.GetString(async ? await reader.ReadBytesAsync(valueSize, cancellationToken).ConfigureAwait(false) : reader.ReadBytes(valueSize) ); @@ -99,7 +92,7 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken cancellat if (typeof(T) == typeof(Dictionary) || typeof(T) == typeof(IDictionary)) return (T)result; - return _convert is null ? throw new NotSupportedException() : _convert(result); + return convert is null ? throw new NotSupportedException() : convert(result); } async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken cancellationToken) @@ -129,9 +122,9 @@ async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken ca var length = size.Value; writer.WriteInt32(length); if (async) - await writer.WriteCharsAsync(kv.Key.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + await writer.WriteCharsAsync(kv.Key.AsMemory(), encoding, cancellationToken).ConfigureAwait(false); else - writer.WriteChars(kv.Key.AsSpan(), _encoding); + writer.WriteChars(kv.Key.AsSpan(), encoding); if (writer.ShouldFlush(sizeof(int))) await writer.Flush(async, cancellationToken).ConfigureAwait(false); @@ -145,9 +138,9 @@ async ValueTask Write(bool async, PgWriter writer, T value, CancellationToken ca if (valueLength is not -1) { if (async) - await writer.WriteCharsAsync(kv.Value.AsMemory(), _encoding, cancellationToken).ConfigureAwait(false); + await writer.WriteCharsAsync(kv.Value.AsMemory(), encoding, cancellationToken).ConfigureAwait(false); else - writer.WriteChars(kv.Value.AsSpan(), _encoding); + writer.WriteChars(kv.Value.AsSpan(), encoding); } i += 2; } diff --git a/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs b/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs index 5d00a26dcb..881d454d3a 100644 --- a/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs +++ 
b/src/Npgsql/Internal/Converters/Internal/InternalCharConverter.cs @@ -4,10 +4,7 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class InternalCharConverter : PgBufferedConverter -#if NET7_0_OR_GREATER - where T : INumberBase -#endif +sealed class InternalCharConverter : PgBufferedConverter where T : INumberBase { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -15,29 +12,6 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer return format is DataFormat.Binary; } -#if NET7_0_OR_GREATER protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadByte()); protected override void WriteCore(PgWriter writer, T value) => writer.WriteByte(byte.CreateChecked(value)); -#else - protected override T ReadCore(PgReader reader) - { - var value = reader.ReadByte(); - if (typeof(byte) == typeof(T)) - return (T)(object)value; - if (typeof(char) == typeof(T)) - return (T)(object)(char)value; - - throw new NotSupportedException(); - } - - protected override void WriteCore(PgWriter writer, T value) - { - if (typeof(byte) == typeof(T)) - writer.WriteByte((byte)(object)value!); - else if (typeof(char) == typeof(T)) - writer.WriteByte(checked((byte)(char)(object)value!)); - else - throw new NotSupportedException(); - } -#endif } diff --git a/src/Npgsql/Internal/Converters/JsonConverter.cs b/src/Npgsql/Internal/Converters/JsonConverter.cs index 77157875b3..074575e4e1 100644 --- a/src/Npgsql/Internal/Converters/JsonConverter.cs +++ b/src/Npgsql/Internal/Converters/JsonConverter.cs @@ -107,8 +107,8 @@ public override ValueTask WriteAsync(PgWriter writer, T? value, CancellationToke static class JsonConverter { public const byte JsonbProtocolVersion = 1; - // We pick a value that is the largest multiple of 4096 that is still smaller than the large object heap threshold (85K). 
- const int StreamingThreshold = 81920; + // Largest value that is a power of 2 and a multiple of 4096 while staying under the large object heap threshold (85K). + const int StreamingThreshold = 65536; public static bool TryReadStream(bool jsonb, Encoding encoding, PgReader reader, out int byteCount, [NotNullWhen(true)]out Stream? stream) { diff --git a/src/Npgsql/Internal/Converters/MoneyConverter.cs b/src/Npgsql/Internal/Converters/MoneyConverter.cs index 8443acedc3..2b6c078a84 100644 --- a/src/Npgsql/Internal/Converters/MoneyConverter.cs +++ b/src/Npgsql/Internal/Converters/MoneyConverter.cs @@ -3,72 +3,17 @@ namespace Npgsql.Internal.Converters; -sealed class MoneyConverter : PgBufferedConverter -#if NET7_0_OR_GREATER - where T : INumberBase -#endif +sealed class MoneyConverter : PgBufferedConverter where T : INumberBase { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); return format is DataFormat.Binary; } + protected override T ReadCore(PgReader reader) => ConvertTo(new PgMoney(reader.ReadInt64())); protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt64(ConvertFrom(value).GetValue()); - static PgMoney ConvertFrom(T value) - { -#if !NET7_0_OR_GREATER - if (typeof(short) == typeof(T)) - return new PgMoney((decimal)(short)(object)value!); - if (typeof(int) == typeof(T)) - return new PgMoney((decimal)(int)(object)value!); - if (typeof(long) == typeof(T)) - return new PgMoney((decimal)(long)(object)value!); - - if (typeof(byte) == typeof(T)) - return new PgMoney((decimal)(byte)(object)value!); - if (typeof(sbyte) == typeof(T)) - return new PgMoney((decimal)(sbyte)(object)value!); - - if (typeof(float) == typeof(T)) - return new PgMoney((decimal)(float)(object)value!); - if (typeof(double) == typeof(T)) - return new PgMoney((decimal)(double)(object)value!); - if (typeof(decimal) == typeof(T)) - return new 
PgMoney((decimal)(object)value!); - - throw new NotSupportedException(); -#else - return new PgMoney(decimal.CreateChecked(value)); -#endif - } - - static T ConvertTo(PgMoney money) - { -#if !NET7_0_OR_GREATER - if (typeof(short) == typeof(T)) - return (T)(object)(short)money.ToDecimal(); - if (typeof(int) == typeof(T)) - return (T)(object)(int)money.ToDecimal(); - if (typeof(long) == typeof(T)) - return (T)(object)(long)money.ToDecimal(); - - if (typeof(byte) == typeof(T)) - return (T)(object)(byte)money.ToDecimal(); - if (typeof(sbyte) == typeof(T)) - return (T)(object)(sbyte)money.ToDecimal(); - - if (typeof(float) == typeof(T)) - return (T)(object)(float)money.ToDecimal(); - if (typeof(double) == typeof(T)) - return (T)(object)(double)money.ToDecimal(); - if (typeof(decimal) == typeof(T)) - return (T)(object)money.ToDecimal(); - - throw new NotSupportedException(); -#else - return T.CreateChecked(money.ToDecimal()); -#endif - } + static PgMoney ConvertFrom(T value) => new(decimal.CreateChecked(value)); + static T ConvertTo(PgMoney money) => T.CreateChecked(money.ToDecimal()); } diff --git a/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs b/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs index 9050f36f16..707bcd016b 100644 --- a/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs +++ b/src/Npgsql/Internal/Converters/Networking/IPAddressConverter.cs @@ -7,7 +7,7 @@ namespace Npgsql.Internal.Converters; sealed class IPAddressConverter : PgBufferedConverter { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) - => CanConvertBufferedDefault(format, out bufferRequirements); + => NpgsqlInetConverter.CanConvertImpl(format, out bufferRequirements); public override Size GetSize(SizeContext context, IPAddress value, ref object? 
writeState) => NpgsqlInetConverter.GetSizeImpl(context, value, ref writeState); diff --git a/src/Npgsql/Internal/Converters/Networking/IPNetworkConverter.cs b/src/Npgsql/Internal/Converters/Networking/IPNetworkConverter.cs new file mode 100644 index 0000000000..6fc7b5401e --- /dev/null +++ b/src/Npgsql/Internal/Converters/Networking/IPNetworkConverter.cs @@ -0,0 +1,31 @@ +using System; +using System.Net; + +// ReSharper disable once CheckNamespace +namespace Npgsql.Internal.Converters; + +sealed class IPNetworkConverter : PgBufferedConverter +{ + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + => CanConvertBufferedDefault(format, out bufferRequirements); + + public override Size GetSize(SizeContext context, IPNetwork value, ref object? writeState) + => NpgsqlInetConverter.GetSizeImpl(context, value.BaseAddress, ref writeState); + + protected override IPNetwork ReadCore(PgReader reader) + { + var (ip, netmask) = NpgsqlInetConverter.ReadImpl(reader, shouldBeCidr: true); + return new(ip, netmask); + } + + protected override void WriteCore(PgWriter writer, IPNetwork value) + => NpgsqlInetConverter.WriteImpl( + writer, + ( + value.BaseAddress, + value.PrefixLength <= byte.MaxValue + ? 
(byte)value.PrefixLength + : throw new ArgumentOutOfRangeException(nameof(value), "IPNetwork.PrefixLength is too large to fit in a byte") + ), + isCidr: true); +} diff --git a/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs b/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs index dd8aac78bc..d9c2aa46e8 100644 --- a/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs +++ b/src/Npgsql/Internal/Converters/Networking/MacaddrConverter.cs @@ -5,15 +5,11 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class MacaddrConverter : PgBufferedConverter +sealed class MacaddrConverter(bool macaddr8) : PgBufferedConverter { - readonly bool _macaddr8; - - public MacaddrConverter(bool macaddr8) => _macaddr8 = macaddr8; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { - bufferRequirements = _macaddr8 ? BufferRequirements.Create(Size.CreateUpperBound(8)) : BufferRequirements.CreateFixedSize(6); + bufferRequirements = macaddr8 ? 
BufferRequirements.Create(Size.CreateUpperBound(8)) : BufferRequirements.CreateFixedSize(6); return format is DataFormat.Binary; } @@ -33,7 +29,7 @@ protected override PhysicalAddress ReadCore(PgReader reader) protected override void WriteCore(PgWriter writer, PhysicalAddress value) { var bytes = value.GetAddressBytes(); - if (!_macaddr8 && bytes.Length is not 6) + if (!macaddr8 && bytes.Length is not 6) throw new ArgumentException("A macaddr value must be 6 bytes long."); writer.WriteBytes(bytes); } diff --git a/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs b/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs index c6d0ab8d88..451fab4959 100644 --- a/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs +++ b/src/Npgsql/Internal/Converters/Networking/NpgsqlCidrConverter.cs @@ -3,10 +3,11 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; +#pragma warning disable CS0618 // NpgsqlCidr is obsolete sealed class NpgsqlCidrConverter : PgBufferedConverter { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) - => CanConvertBufferedDefault(format, out bufferRequirements); + => NpgsqlInetConverter.CanConvertImpl(format, out bufferRequirements); public override Size GetSize(SizeContext context, NpgsqlCidr value, ref object? 
writeState) => NpgsqlInetConverter.GetSizeImpl(context, value.Address, ref writeState); @@ -20,3 +21,4 @@ protected override NpgsqlCidr ReadCore(PgReader reader) protected override void WriteCore(PgWriter writer, NpgsqlCidr value) => NpgsqlInetConverter.WriteImpl(writer, (value.Address, value.Netmask), isCidr: true); } +#pragma warning restore CS0618 diff --git a/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs b/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs index 26ce7cfa96..ea0066c9de 100644 --- a/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs +++ b/src/Npgsql/Internal/Converters/Networking/NpgsqlInetConverter.cs @@ -13,7 +13,13 @@ sealed class NpgsqlInetConverter : PgBufferedConverter const byte IPv6 = 3; public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) - => CanConvertBufferedDefault(format, out bufferRequirements); + => CanConvertImpl(format, out bufferRequirements); + + internal static bool CanConvertImpl(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.Create(Size.CreateUpperBound(20)); + return format == DataFormat.Binary; + } public override Size GetSize(SizeContext context, NpgsqlInet value, ref object? writeState) => GetSizeImpl(context, value.Address, ref writeState); diff --git a/src/Npgsql/Internal/Converters/NullableConverter.cs b/src/Npgsql/Internal/Converters/NullableConverter.cs index 292def140a..57a12e005f 100644 --- a/src/Npgsql/Internal/Converters/NullableConverter.cs +++ b/src/Npgsql/Internal/Converters/NullableConverter.cs @@ -7,46 +7,42 @@ namespace Npgsql.Internal.Converters; // NULL writing is always responsibility of the caller writing the length, so there is not much we do here. /// Special value converter to be able to use struct converters as System.Nullable converters, it delegates all behavior to the effective converter. 
-sealed class NullableConverter : PgConverter where T : struct +sealed class NullableConverter(PgConverter effectiveConverter) + : PgConverter(effectiveConverter.DbNullPredicateKind is DbNullPredicate.Custom) + where T : struct { - readonly PgConverter _effectiveConverter; - public NullableConverter(PgConverter effectiveConverter) - : base(effectiveConverter.DbNullPredicateKind is DbNullPredicate.Custom) - => _effectiveConverter = effectiveConverter; - protected override bool IsDbNullValue(T? value, ref object? writeState) - => value is null || _effectiveConverter.IsDbNull(value.GetValueOrDefault(), ref writeState); + => value is null || effectiveConverter.IsDbNull(value.GetValueOrDefault(), ref writeState); public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) - => _effectiveConverter.CanConvert(format, out bufferRequirements); + => effectiveConverter.CanConvert(format, out bufferRequirements); public override T? Read(PgReader reader) - => _effectiveConverter.Read(reader); + => effectiveConverter.Read(reader); public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) - => this.ReadAsyncAsNullable(_effectiveConverter, reader, cancellationToken); + => this.ReadAsyncAsNullable(effectiveConverter, reader, cancellationToken); public override Size GetSize(SizeContext context, [DisallowNull]T? value, ref object? writeState) - => _effectiveConverter.GetSize(context, value.GetValueOrDefault(), ref writeState); + => effectiveConverter.GetSize(context, value.GetValueOrDefault(), ref writeState); public override void Write(PgWriter writer, T? value) - => _effectiveConverter.Write(writer, value.GetValueOrDefault()); + => effectiveConverter.Write(writer, value.GetValueOrDefault()); public override ValueTask WriteAsync(PgWriter writer, T? 
value, CancellationToken cancellationToken = default) - => _effectiveConverter.WriteAsync(writer, value.GetValueOrDefault(), cancellationToken); + => effectiveConverter.WriteAsync(writer, value.GetValueOrDefault(), cancellationToken); internal override ValueTask ReadAsObject(bool async, PgReader reader, CancellationToken cancellationToken) - => _effectiveConverter.ReadAsObject(async, reader, cancellationToken); + => effectiveConverter.ReadAsObject(async, reader, cancellationToken); internal override ValueTask WriteAsObject(bool async, PgWriter writer, object value, CancellationToken cancellationToken) - => _effectiveConverter.WriteAsObject(async, writer, value, cancellationToken); + => effectiveConverter.WriteAsObject(async, writer, value, cancellationToken); } -sealed class NullableConverterResolver : PgComposingConverterResolver where T : struct +sealed class NullableConverterResolver(PgResolverTypeInfo effectiveTypeInfo) + : PgComposingConverterResolver(effectiveTypeInfo.PgTypeId, effectiveTypeInfo) + where T : struct { - public NullableConverterResolver(PgResolverTypeInfo effectiveTypeInfo) - : base(effectiveTypeInfo.PgTypeId, effectiveTypeInfo) { } - protected override PgTypeId GetEffectivePgTypeId(PgTypeId pgTypeId) => pgTypeId; protected override PgTypeId GetPgTypeId(PgTypeId effectivePgTypeId) => effectivePgTypeId; diff --git a/src/Npgsql/Internal/Converters/ObjectConverter.cs b/src/Npgsql/Internal/Converters/ObjectConverter.cs index 568fc32c2b..4889c60fad 100644 --- a/src/Npgsql/Internal/Converters/ObjectConverter.cs +++ b/src/Npgsql/Internal/Converters/ObjectConverter.cs @@ -1,22 +1,13 @@ using System; +using System.Diagnostics; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal.Postgres; namespace Npgsql.Internal; -sealed class ObjectConverter : PgStreamingConverter +sealed class ObjectConverter(PgSerializerOptions options, PgTypeId pgTypeId) : PgStreamingConverter(customDbNullPredicate: true) { - readonly PgSerializerOptions 
_options; - readonly PgTypeId _pgTypeId; - - public ObjectConverter(PgSerializerOptions options, PgTypeId pgTypeId) - : base(customDbNullPredicate: true) - { - _options = options; - _pgTypeId = pgTypeId; - } - protected override bool IsDbNullValue(object? value, ref object? writeState) { if (value is null or DBNull) @@ -47,6 +38,7 @@ public override Size GetSize(SizeContext context, object value, ref object? writ // We can call GetDefaultResolution here as validation has already happened in IsDbNullValue. // And we know it was called due to the writeState being filled. + Debug.Assert(typeInfo.PgTypeId is not null); var converter = typeInfo is PgResolverTypeInfo resolverTypeInfo ? resolverTypeInfo.GetDefaultResolution(null).Converter : typeInfo.GetResolution().Converter; @@ -89,6 +81,7 @@ async ValueTask Write(bool async, PgWriter writer, object value, CancellationTok // We can call GetDefaultResolution here as validation has already happened in IsDbNullValue. // And we know it was called due to the writeState being filled. + Debug.Assert(typeInfo.PgTypeId is not null); var converter = typeInfo is PgResolverTypeInfo resolverTypeInfo ? resolverTypeInfo.GetDefaultResolution(null).Converter : typeInfo.GetResolution().Converter; @@ -98,8 +91,8 @@ async ValueTask Write(bool async, PgWriter writer, object value, CancellationTok } PgTypeInfo GetTypeInfo(Type type) - => _options.GetTypeInfo(type, _pgTypeId) - ?? throw new NotSupportedException($"Writing values of '{type.FullName}' having DataTypeName '{_options.DatabaseInfo.GetPostgresType(_pgTypeId).DisplayName}' is not supported."); + => options.GetTypeInfoInternal(type, pgTypeId) + ?? 
throw new NotSupportedException($"Writing values of '{type.FullName}' having DataTypeName '{options.DatabaseInfo.GetPostgresType(pgTypeId).DisplayName}' is not supported."); sealed class WriteState { diff --git a/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs b/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs index 7c78e34a24..7cf355d103 100644 --- a/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs +++ b/src/Npgsql/Internal/Converters/PolymorphicConverterResolver.cs @@ -5,11 +5,9 @@ namespace Npgsql.Internal.Converters; -abstract class PolymorphicConverterResolver : PgConverterResolver +abstract class PolymorphicConverterResolver(PgTypeId pgTypeId) : PgConverterResolver { - protected PolymorphicConverterResolver(PgTypeId pgTypeId) => PgTypeId = pgTypeId; - - protected PgTypeId PgTypeId { get; } + protected PgTypeId PgTypeId { get; } = pgTypeId; protected abstract PgConverter Get(Field? field); diff --git a/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs b/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs index 74a56d06ae..8bc9caaf67 100644 --- a/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs +++ b/src/Npgsql/Internal/Converters/Primitive/DoubleConverter.cs @@ -4,10 +4,7 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class DoubleConverter : PgBufferedConverter -#if NET7_0_OR_GREATER - where T : INumberBase -#endif +sealed class DoubleConverter : PgBufferedConverter where T : INumberBase { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -15,29 +12,6 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer return format is DataFormat.Binary; } -#if NET7_0_OR_GREATER protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadDouble()); protected override void WriteCore(PgWriter writer, T value) => writer.WriteDouble(double.CreateChecked(value)); -#else - 
protected override T ReadCore(PgReader reader) - { - var value = reader.ReadDouble(); - if (typeof(float) == typeof(T)) - return (T)(object)value; - if (typeof(double) == typeof(T)) - return (T)(object)value; - - throw new NotSupportedException(); - } - - protected override void WriteCore(PgWriter writer, T value) - { - if (typeof(float) == typeof(T)) - writer.WriteDouble((float)(object)value!); - else if (typeof(double) == typeof(T)) - writer.WriteDouble((double)(object)value!); - else - throw new NotSupportedException(); - } -#endif } diff --git a/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs b/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs index 596deedfce..18e6b0edc5 100644 --- a/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs +++ b/src/Npgsql/Internal/Converters/Primitive/GuidUuidConverter.cs @@ -11,60 +11,14 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer bufferRequirements = BufferRequirements.CreateFixedSize(16 * sizeof(byte)); return format is DataFormat.Binary; } + protected override Guid ReadCore(PgReader reader) - { -#if NET8_0_OR_GREATER - return new Guid(reader.ReadBytes(16).FirstSpan, bigEndian: true); -#else - return new GuidRaw - { - Data1 = reader.ReadInt32(), - Data2 = reader.ReadInt16(), - Data3 = reader.ReadInt16(), - Data4 = BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(reader.ReadInt64()) : reader.ReadInt64() - }.Value; -#endif - } + => new(reader.ReadBytes(16).FirstSpan, bigEndian: true); protected override void WriteCore(PgWriter writer, Guid value) { -#if NET8_0_OR_GREATER Span bytes = stackalloc byte[16]; value.TryWriteBytes(bytes, bigEndian: true, out _); writer.WriteBytes(bytes); -#else - var raw = new GuidRaw(value); - - writer.WriteInt32(raw.Data1); - writer.WriteInt16(raw.Data2); - writer.WriteInt16(raw.Data3); - writer.WriteInt64(BitConverter.IsLittleEndian ? 
BinaryPrimitives.ReverseEndianness(raw.Data4) : raw.Data4); -#endif - } - -#if !NET8_0_OR_GREATER - // The following table shows .NET GUID vs Postgres UUID (RFC 4122) layouts. - // - // Note that the first fields are converted from/to native endianness (handled by the Read* - // and Write* methods), while the last field is always read/written in big-endian format. - // - // We're reverting endianness on little endian systems to get it into big endian format. - // - // | Bits | Bytes | Name | Endianness (GUID) | Endianness (RFC 4122) | - // | ---- | ----- | ----- | ----------------- | --------------------- | - // | 32 | 4 | Data1 | Native | Big | - // | 16 | 2 | Data2 | Native | Big | - // | 16 | 2 | Data3 | Native | Big | - // | 64 | 8 | Data4 | Big | Big | - [StructLayout(LayoutKind.Explicit)] - struct GuidRaw - { - [FieldOffset(0)] public Guid Value; - [FieldOffset(0)] public int Data1; - [FieldOffset(4)] public short Data2; - [FieldOffset(6)] public short Data3; - [FieldOffset(8)] public long Data4; - public GuidRaw(Guid value) : this() => Value = value; } -#endif } diff --git a/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs b/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs index e54658d925..741af9a75e 100644 --- a/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs +++ b/src/Npgsql/Internal/Converters/Primitive/Int2Converter.cs @@ -4,67 +4,14 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class Int2Converter : PgBufferedConverter -#if NET7_0_OR_GREATER - where T : INumberBase -#endif +sealed class Int2Converter : PgBufferedConverter where T : INumberBase { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(short)); return format is DataFormat.Binary; } -#if NET7_0_OR_GREATER + protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt16()); protected override void 
WriteCore(PgWriter writer, T value) => writer.WriteInt16(short.CreateChecked(value)); -#else - protected override T ReadCore(PgReader reader) - { - var value = reader.ReadInt16(); - if (typeof(short) == typeof(T)) - return (T)(object)value; - if (typeof(int) == typeof(T)) - return (T)(object)(int)value; - if (typeof(long) == typeof(T)) - return (T)(object)(long)value; - - if (typeof(byte) == typeof(T)) - return (T)(object)checked((byte)value); - if (typeof(sbyte) == typeof(T)) - return (T)(object)checked((sbyte)value); - - if (typeof(float) == typeof(T)) - return (T)(object)(float)value; - if (typeof(double) == typeof(T)) - return (T)(object)(double)value; - if (typeof(decimal) == typeof(T)) - return (T)(object)(decimal)value; - - throw new NotSupportedException(); - } - - protected override void WriteCore(PgWriter writer, T value) - { - if (typeof(short) == typeof(T)) - writer.WriteInt16((short)(object)value!); - else if (typeof(int) == typeof(T)) - writer.WriteInt16(checked((short)(int)(object)value!)); - else if (typeof(long) == typeof(T)) - writer.WriteInt16(checked((short)(long)(object)value!)); - - else if (typeof(byte) == typeof(T)) - writer.WriteInt16((byte)(object)value!); - else if (typeof(sbyte) == typeof(T)) - writer.WriteInt16((sbyte)(object)value!); - - else if (typeof(float) == typeof(T)) - writer.WriteInt16(checked((short)(float)(object)value!)); - else if (typeof(double) == typeof(T)) - writer.WriteInt16(checked((short)(double)(object)value!)); - else if (typeof(decimal) == typeof(T)) - writer.WriteInt16((short)(decimal)(object)value!); - else - throw new NotSupportedException(); - } -#endif } diff --git a/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs b/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs index 1831ca9b1e..4327d2f2e7 100644 --- a/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs +++ b/src/Npgsql/Internal/Converters/Primitive/Int4Converter.cs @@ -4,10 +4,7 @@ // ReSharper disable once CheckNamespace 
namespace Npgsql.Internal.Converters; -sealed class Int4Converter : PgBufferedConverter -#if NET7_0_OR_GREATER - where T : INumberBase -#endif +sealed class Int4Converter : PgBufferedConverter where T : INumberBase { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -15,57 +12,6 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer return format is DataFormat.Binary; } -#if NET7_0_OR_GREATER protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt32()); protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt32(int.CreateChecked(value)); -#else - protected override T ReadCore(PgReader reader) - { - var value = reader.ReadInt32(); - if (typeof(short) == typeof(T)) - return (T)(object)checked((short)value); - if (typeof(int) == typeof(T)) - return (T)(object)value; - if (typeof(long) == typeof(T)) - return (T)(object)(long)value; - - if (typeof(byte) == typeof(T)) - return (T)(object)checked((byte)value); - if (typeof(sbyte) == typeof(T)) - return (T)(object)checked((sbyte)value); - - if (typeof(float) == typeof(T)) - return (T)(object)(float)value; - if (typeof(double) == typeof(T)) - return (T)(object)(double)value; - if (typeof(decimal) == typeof(T)) - return (T)(object)(decimal)value; - - throw new NotSupportedException(); - } - - protected override void WriteCore(PgWriter writer, T value) - { - if (typeof(short) == typeof(T)) - writer.WriteInt32((short)(object)value!); - else if (typeof(int) == typeof(T)) - writer.WriteInt32((int)(object)value!); - else if (typeof(long) == typeof(T)) - writer.WriteInt32(checked((int)(long)(object)value!)); - - else if (typeof(byte) == typeof(T)) - writer.WriteInt32((byte)(object)value!); - else if (typeof(sbyte) == typeof(T)) - writer.WriteInt32((sbyte)(object)value!); - - else if (typeof(float) == typeof(T)) - writer.WriteInt32(checked((int)(float)(object)value!)); - else if (typeof(double) == typeof(T)) 
- writer.WriteInt32(checked((int)(double)(object)value!)); - else if (typeof(decimal) == typeof(T)) - writer.WriteInt32((int)(decimal)(object)value!); - else - throw new NotSupportedException(); - } -#endif } diff --git a/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs b/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs index b422816244..09a54cf265 100644 --- a/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs +++ b/src/Npgsql/Internal/Converters/Primitive/Int8Converter.cs @@ -4,10 +4,7 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class Int8Converter : PgBufferedConverter -#if NET7_0_OR_GREATER - where T : INumberBase -#endif +sealed class Int8Converter : PgBufferedConverter where T : INumberBase { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -15,58 +12,6 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer return format is DataFormat.Binary; } -#if NET7_0_OR_GREATER protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadInt64()); protected override void WriteCore(PgWriter writer, T value) => writer.WriteInt64(long.CreateChecked(value)); -#else - protected override T ReadCore(PgReader reader) - { - var value = reader.ReadInt64(); - if (typeof(long) == typeof(T)) - return (T)(object)value; - - if (typeof(short) == typeof(T)) - return (T)(object)checked((short)value); - if (typeof(int) == typeof(T)) - return (T)(object)checked((int)value); - - if (typeof(byte) == typeof(T)) - return (T)(object)checked((byte)value); - if (typeof(sbyte) == typeof(T)) - return (T)(object)checked((sbyte)value); - - if (typeof(float) == typeof(T)) - return (T)(object)(float)value; - if (typeof(double) == typeof(T)) - return (T)(object)(double)value; - if (typeof(decimal) == typeof(T)) - return (T)(object)(decimal)value; - - throw new NotSupportedException(); - } - - protected override void WriteCore(PgWriter 
writer, T value) - { - if (typeof(short) == typeof(T)) - writer.WriteInt64((short)(object)value!); - else if (typeof(int) == typeof(T)) - writer.WriteInt64((int)(object)value!); - else if (typeof(long) == typeof(T)) - writer.WriteInt64((long)(object)value!); - - else if (typeof(byte) == typeof(T)) - writer.WriteInt64((byte)(object)value!); - else if (typeof(sbyte) == typeof(T)) - writer.WriteInt64((sbyte)(object)value!); - - else if (typeof(float) == typeof(T)) - writer.WriteInt64(checked((long)(float)(object)value!)); - else if (typeof(double) == typeof(T)) - writer.WriteInt64(checked((long)(double)(object)value!)); - else if (typeof(decimal) == typeof(T)) - writer.WriteInt64((long)(decimal)(object)value!); - else - throw new NotSupportedException(); - } -#endif } diff --git a/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs b/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs index c43e90a1f7..79a82a1bfa 100644 --- a/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs +++ b/src/Npgsql/Internal/Converters/Primitive/NumericConverters.cs @@ -13,6 +13,9 @@ sealed class BigIntegerNumericConverter : PgStreamingConverter public override BigInteger Read(PgReader reader) { + if (reader.ShouldBuffer(sizeof(short))) + reader.Buffer(sizeof(short)); + var digitCount = reader.ReadInt16(); short[]? digitsFromPool = null; var digits = (digitCount <= StackAllocByteThreshold / sizeof(short) @@ -31,13 +34,15 @@ public override ValueTask ReadAsync(PgReader reader, CancellationTok { // If we don't need a read and can read buffered we delegate to our sync read method which won't do IO in such a case. 
if (!reader.ShouldBuffer(reader.CurrentRemaining)) - Read(reader); + return new(Read(reader)); return AsyncCore(reader, cancellationToken); static async ValueTask AsyncCore(PgReader reader, CancellationToken cancellationToken) { - await reader.BufferAsync(PgNumeric.GetByteCount(0), cancellationToken).ConfigureAwait(false); + if (reader.ShouldBuffer(sizeof(short))) + await reader.BufferAsync(sizeof(short), cancellationToken).ConfigureAwait(false); + var digitCount = reader.ReadInt16(); var digits = new ArraySegment(ArrayPool.Shared.Rent(digitCount), 0, digitCount); var value = ConvertTo(await NumericConverter.ReadAsync(reader, digits, cancellationToken).ConfigureAwait(false)); @@ -82,12 +87,7 @@ static async ValueTask AsyncCore(PgWriter writer, BigInteger value, Cancellation static BigInteger ConvertTo(in PgNumeric numeric) => numeric.ToBigInteger(); } -sealed class DecimalNumericConverter : PgBufferedConverter -#if NET7_0_OR_GREATER - where T : INumberBase -#else - where T : notnull -#endif +sealed class DecimalNumericConverter : PgBufferedConverter where T : INumberBase { const int StackAllocByteThreshold = 64 * sizeof(uint); @@ -129,65 +129,15 @@ protected override void WriteCore(PgWriter writer, T value) } static PgNumeric.Builder ConvertFrom(T value, Span destination) - { -#if !NET7_0_OR_GREATER - if (typeof(short) == typeof(T)) - return new PgNumeric.Builder((decimal)(short)(object)value!, destination); - if (typeof(int) == typeof(T)) - return new PgNumeric.Builder((decimal)(int)(object)value!, destination); - if (typeof(long) == typeof(T)) - return new PgNumeric.Builder((decimal)(long)(object)value!, destination); - - if (typeof(byte) == typeof(T)) - return new PgNumeric.Builder((decimal)(byte)(object)value!, destination); - if (typeof(sbyte) == typeof(T)) - return new PgNumeric.Builder((decimal)(sbyte)(object)value!, destination); - - if (typeof(float) == typeof(T)) - return new PgNumeric.Builder((decimal)(float)(object)value!, destination); - if 
(typeof(double) == typeof(T)) - return new PgNumeric.Builder((decimal)(double)(object)value!, destination); - if (typeof(decimal) == typeof(T)) - return new PgNumeric.Builder((decimal)(object)value!, destination); - - throw new NotSupportedException(); -#else - return new PgNumeric.Builder(decimal.CreateChecked(value), destination); -#endif - } + => new(decimal.CreateChecked(value), destination); static T ConvertTo(in PgNumeric.Builder numeric) - { -#if !NET7_0_OR_GREATER - if (typeof(short) == typeof(T)) - return (T)(object)(short)numeric.ToDecimal(); - if (typeof(int) == typeof(T)) - return (T)(object)(int)numeric.ToDecimal(); - if (typeof(long) == typeof(T)) - return (T)(object)(long)numeric.ToDecimal(); - - if (typeof(byte) == typeof(T)) - return (T)(object)(byte)numeric.ToDecimal(); - if (typeof(sbyte) == typeof(T)) - return (T)(object)(sbyte)numeric.ToDecimal(); - - if (typeof(float) == typeof(T)) - return (T)(object)(float)numeric.ToDecimal(); - if (typeof(double) == typeof(T)) - return (T)(object)(double)numeric.ToDecimal(); - if (typeof(decimal) == typeof(T)) - return (T)(object)numeric.ToDecimal(); - - throw new NotSupportedException(); -#else - return T.CreateChecked(numeric.ToDecimal()); -#endif - } + => T.CreateChecked(numeric.ToDecimal()); } static class NumericConverter { - public static int DecimalBasedMaxByteCount = PgNumeric.GetByteCount(PgNumeric.Builder.MaxDecimalNumericDigits); + public static readonly int DecimalBasedMaxByteCount = PgNumeric.GetByteCount(PgNumeric.Builder.MaxDecimalNumericDigits); public static PgNumeric.Builder Read(PgReader reader, Span digits) { @@ -216,7 +166,7 @@ public static async ValueTask ReadAsync(PgReader reader, ArraySegment var sign = reader.ReadInt16(); var scale = reader.ReadInt16(); var array = digits.Array!; - for (var i = digits.Offset; i < array.Length; i++) + for (var i = digits.Offset; i < digits.Offset + digits.Count; i++) { if (reader.ShouldBuffer(sizeof(short))) await reader.BufferAsync(sizeof(short), 
cancellationToken).ConfigureAwait(false); diff --git a/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs b/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs index dc8755de1f..bddbbda648 100644 --- a/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs +++ b/src/Npgsql/Internal/Converters/Primitive/PgMoney.cs @@ -51,10 +51,6 @@ static void GetDecimalBits(decimal value, Span destination, out short scal Debug.Assert(destination.Length >= DecimalBits); decimal.GetBits(value, MemoryMarshal.Cast(destination)); -#if NET7_0_OR_GREATER scale = value.Scale; -#else - scale = (byte)(destination[3] >> 16); -#endif } } diff --git a/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs b/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs index 19266cda1f..c90036d381 100644 --- a/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs +++ b/src/Npgsql/Internal/Converters/Primitive/PgNumeric.cs @@ -8,28 +8,21 @@ namespace Npgsql.Internal.Converters; -readonly struct PgNumeric +readonly struct PgNumeric(ArraySegment digits, short weight, short sign, short scale) { // numeric digit count + weight + sign + scale const int StructureByteCount = 4 * sizeof(short); const int DecimalBits = 4; const int StackAllocByteThreshold = 64 * sizeof(uint); - readonly ushort _sign; - - public PgNumeric(ArraySegment digits, short weight, short sign, short scale) - { - Digits = digits; - Weight = weight; - _sign = (ushort)sign; - Scale = scale; - } + readonly ushort _sign = (ushort)sign; /// Big endian array of numeric digits - public ArraySegment Digits { get; } - public short Weight { get; } + public ArraySegment Digits { get; } = digits; + + public short Weight { get; } = weight; public short Sign => (short)_sign; - public short Scale { get; } + public short Scale { get; } = scale; public int GetByteCount() => GetByteCount(Digits.Count); public static int GetByteCount(int digitCount) => StructureByteCount + digitCount * sizeof(short); @@ -39,12 +32,7 @@ static void GetDecimalBits(decimal value, 
Span destination, out short scal Debug.Assert(destination.Length >= DecimalBits); decimal.GetBits(value, MemoryMarshal.Cast(destination)); - -#if NET7_0_OR_GREATER scale = value.Scale; -#else - scale = (byte)(destination[3] >> 16); -#endif } public static int GetDigitCount(decimal value) @@ -101,7 +89,8 @@ public readonly ref struct Builder internal const int MaxDecimalNumericDigits = 8; // Fast access for 10^n where n is 0-9 - static ReadOnlySpan UIntPowers10 => new uint[] { + static ReadOnlySpan UIntPowers10 => + [ 1, 10, 100, @@ -112,7 +101,7 @@ public readonly ref struct Builder 10000000, 100000000, 1000000000 - }; + ]; const int MaxUInt32Scale = 9; const int MaxUInt16Scale = 4; @@ -351,14 +340,24 @@ internal static decimal ToDecimal(short scale, short weight, ushort sign, Span 0) { var scaleChunk = Math.Min(MaxUIntScale, scaleDifference); - result *= UIntPowers10[scaleChunk]; + scaleFactor *= UIntPowers10[scaleChunk]; scaleDifference -= scaleChunk; } + } } result *= scaleFactor; diff --git a/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs b/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs index b47e641aa5..89eeebb7fe 100644 --- a/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs +++ b/src/Npgsql/Internal/Converters/Primitive/RealConverter.cs @@ -4,10 +4,7 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class RealConverter : PgBufferedConverter -#if NET7_0_OR_GREATER - where T : INumberBase -#endif +sealed class RealConverter : PgBufferedConverter where T : INumberBase { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -15,29 +12,6 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer return format is DataFormat.Binary; } -#if NET7_0_OR_GREATER protected override T ReadCore(PgReader reader) => T.CreateChecked(reader.ReadFloat()); protected override void WriteCore(PgWriter writer, T value) => 
writer.WriteFloat(float.CreateChecked(value)); -#else - protected override T ReadCore(PgReader reader) - { - var value = reader.ReadFloat(); - if (typeof(float) == typeof(T)) - return (T)(object)value; - if (typeof(double) == typeof(T)) - return (T)(object)(double)value; - - throw new NotSupportedException(); - } - - protected override void WriteCore(PgWriter writer, T value) - { - if (typeof(float) == typeof(T)) - writer.WriteFloat((float)(object)value!); - else if (typeof(double) == typeof(T)) - writer.WriteFloat((float)(double)(object)value!); - else - throw new NotSupportedException(); - } -#endif } diff --git a/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs b/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs index b0e3a1b5bd..e1ef7f714a 100644 --- a/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs +++ b/src/Npgsql/Internal/Converters/Primitive/TextConverters.cs @@ -11,25 +11,22 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -abstract class StringBasedTextConverter : PgStreamingConverter +abstract class StringBasedTextConverter(Encoding encoding) : PgStreamingConverter { - readonly Encoding _encoding; - protected StringBasedTextConverter(Encoding encoding) => _encoding = encoding; - public override T Read(PgReader reader) - => Read(async: false, reader, _encoding).GetAwaiter().GetResult(); + => Read(async: false, reader, encoding).GetAwaiter().GetResult(); public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) - => Read(async: true, reader, _encoding, cancellationToken); + => Read(async: true, reader, encoding, cancellationToken); public override Size GetSize(SizeContext context, T value, ref object? 
writeState) - => TextConverter.GetSize(ref context, ConvertTo(value), _encoding); + => TextConverter.GetSize(ref context, ConvertTo(value), encoding); public override void Write(PgWriter writer, T value) - => writer.WriteChars(ConvertTo(value).Span, _encoding); + => writer.WriteChars(ConvertTo(value).Span, encoding); public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) - => writer.WriteCharsAsync(ConvertTo(value), _encoding, cancellationToken); + => writer.WriteCharsAsync(ConvertTo(value), encoding, cancellationToken); public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -52,38 +49,33 @@ async ValueTask ReadAsync(PgReader reader, Encoding encoding, CancellationTok } } -sealed class ReadOnlyMemoryTextConverter : StringBasedTextConverter> +sealed class ReadOnlyMemoryTextConverter(Encoding encoding) : StringBasedTextConverter>(encoding) { - public ReadOnlyMemoryTextConverter(Encoding encoding) : base(encoding) { } protected override ReadOnlyMemory ConvertTo(ReadOnlyMemory value) => value; protected override ReadOnlyMemory ConvertFrom(string value) => value.AsMemory(); } -sealed class StringTextConverter : StringBasedTextConverter +sealed class StringTextConverter(Encoding encoding) : StringBasedTextConverter(encoding) { - public StringTextConverter(Encoding encoding) : base(encoding) { } protected override ReadOnlyMemory ConvertTo(string value) => value.AsMemory(); protected override string ConvertFrom(string value) => value; } -abstract class ArrayBasedTextConverter : PgStreamingConverter +abstract class ArrayBasedTextConverter(Encoding encoding) : PgStreamingConverter { - readonly Encoding _encoding; - protected ArrayBasedTextConverter(Encoding encoding) => _encoding = encoding; - public override T Read(PgReader reader) - => Read(async: false, reader, _encoding).GetAwaiter().GetResult(); + => Read(async: false, reader, encoding).GetAwaiter().GetResult(); public 
override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) - => Read(async: true, reader, _encoding); + => Read(async: true, reader, encoding); public override Size GetSize(SizeContext context, T value, ref object? writeState) - => TextConverter.GetSize(ref context, ConvertTo(value), _encoding); + => TextConverter.GetSize(ref context, ConvertTo(value), encoding); public override void Write(PgWriter writer, T value) - => writer.WriteChars(ConvertTo(value).AsSpan(), _encoding); + => writer.WriteChars(ConvertTo(value).AsSpan(), encoding); public override ValueTask WriteAsync(PgWriter writer, T value, CancellationToken cancellationToken = default) - => writer.WriteCharsAsync(ConvertTo(value), _encoding, cancellationToken); + => writer.WriteCharsAsync(ConvertTo(value), encoding, cancellationToken); public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -110,16 +102,14 @@ static ArraySegment GetSegment(ReadOnlySequence bytes, Encoding enco } } -sealed class CharArraySegmentTextConverter : ArrayBasedTextConverter> +sealed class CharArraySegmentTextConverter(Encoding encoding) : ArrayBasedTextConverter>(encoding) { - public CharArraySegmentTextConverter(Encoding encoding) : base(encoding) { } protected override ArraySegment ConvertTo(ArraySegment value) => value; protected override ArraySegment ConvertFrom(ArraySegment value) => value; } -sealed class CharArrayTextConverter : ArrayBasedTextConverter +sealed class CharArrayTextConverter(Encoding encoding) : ArrayBasedTextConverter(encoding) { - public CharArrayTextConverter(Encoding encoding) : base(encoding) { } protected override ArraySegment ConvertTo(char[] value) => new(value, 0, value.Length); protected override char[] ConvertFrom(ArraySegment value) { @@ -132,16 +122,9 @@ protected override char[] ConvertFrom(ArraySegment value) } } -sealed class CharTextConverter : PgBufferedConverter +sealed class CharTextConverter(Encoding encoding) : 
PgBufferedConverter { - readonly Encoding _encoding; - readonly Size _oneCharMaxByteCount; - - public CharTextConverter(Encoding encoding) - { - _encoding = encoding; - _oneCharMaxByteCount = Size.CreateUpperBound(encoding.GetMaxByteCount(1)); - } + readonly Size _oneCharMaxByteCount = Size.CreateUpperBound(encoding.GetMaxByteCount(1)); public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -155,33 +138,30 @@ protected override char ReadCore(PgReader reader) Debug.Assert(byteSeq.IsSingleSegment); var bytes = byteSeq.FirstSpan; - var chars = _encoding.GetCharCount(bytes); + var chars = encoding.GetCharCount(bytes); if (chars < 1) throw new NpgsqlException("Could not read char - string was empty"); Span destination = stackalloc char[chars]; - _encoding.GetChars(bytes, destination); + encoding.GetChars(bytes, destination); return destination[0]; } public override Size GetSize(SizeContext context, char value, ref object? writeState) { - Span spanValue = stackalloc char[] { value }; - return _encoding.GetByteCount(spanValue); + Span spanValue = [value]; + return encoding.GetByteCount(spanValue); } protected override void WriteCore(PgWriter writer, char value) { - Span spanValue = stackalloc char[] { value }; - writer.WriteChars(spanValue, _encoding); + Span spanValue = [value]; + writer.WriteChars(spanValue, encoding); } } -sealed class TextReaderTextConverter : PgStreamingConverter +sealed class TextReaderTextConverter(Encoding encoding) : PgStreamingConverter { - readonly Encoding _encoding; - public TextReaderTextConverter(Encoding encoding) => _encoding = encoding; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.None; @@ -189,10 +169,10 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer } public override TextReader Read(PgReader reader) - => reader.GetTextReader(_encoding); + => 
reader.GetTextReader(encoding); public override ValueTask ReadAsync(PgReader reader, CancellationToken cancellationToken = default) - => reader.GetTextReaderAsync(_encoding, cancellationToken); + => reader.GetTextReaderAsync(encoding, cancellationToken); public override Size GetSize(SizeContext context, TextReader value, ref object? writeState) => throw new NotImplementedException(); public override void Write(PgWriter writer, TextReader value) => throw new NotImplementedException(); @@ -200,19 +180,15 @@ public override ValueTask ReadAsync(PgReader reader, CancellationTok } -readonly struct GetChars +readonly struct GetChars(int read) { - public int Read { get; } - public GetChars(int read) => Read = read; + public int Read { get; } = read; } -sealed class GetCharsTextConverter : PgStreamingConverter, IResumableRead +sealed class GetCharsTextConverter(Encoding encoding) : PgStreamingConverter { - readonly Encoding _encoding; - public GetCharsTextConverter(Encoding encoding) => _encoding = encoding; - public override GetChars Read(PgReader reader) - => reader.IsCharsRead + => reader.CharsReadActive ? ResumableRead(reader) : throw new NotSupportedException(); @@ -225,28 +201,26 @@ public override ValueTask ReadAsync(PgReader reader, CancellationToken GetChars ResumableRead(PgReader reader) { - reader.GetCharsReadInfo(_encoding, out var charsRead, out var textReader, out var charsOffset, out var buffer); - if (charsOffset < charsRead || (buffer is null && charsRead > 0)) + reader.GetCharsReadInfo(encoding, out var charsRead, out var textReader, out var charsOffset, out var buffer); + + // With variable length encodings, moving backwards based on bytes means we have to start over. + if (charsRead > charsOffset) { - // With variable length encodings, moving backwards based on bytes means we have to start over. - reader.ResetCharsRead(out charsRead); + reader.RestartCharsRead(); + charsRead = 0; } // First seek towards the charsOffset. 
// If buffer is null read the entire thing and report the length, see sql client remarks. // https://learn.microsoft.com/en-us/dotnet/api/system.data.sqlclient.sqldatareader.getchars - int read; + var read = ConsumeChars(textReader, buffer is null ? null : charsOffset - charsRead); + Debug.Assert(buffer is null || read == charsOffset - charsRead); + reader.AdvanceCharsRead(read); if (buffer is null) - { - read = ConsumeChars(textReader, null); - } - else - { - var consumed = ConsumeChars(textReader, charsOffset - charsRead); - Debug.Assert(consumed == charsOffset - charsRead); - read = textReader.ReadBlock(buffer.GetValueOrDefault().Array!, buffer.GetValueOrDefault().Offset, buffer.GetValueOrDefault().Count); - } + return new(read); + read = textReader.ReadBlock(buffer.GetValueOrDefault().Array!, buffer.GetValueOrDefault().Offset, buffer.GetValueOrDefault().Count); + reader.AdvanceCharsRead(read); return new(read); static int ConsumeChars(TextReader reader, int? count) @@ -271,8 +245,6 @@ static int ConsumeChars(TextReader reader, int? count) return totalRead; } } - - bool IResumableRead.Supported => true; } // Moved out for code size/sharing. diff --git a/src/Npgsql/Internal/Converters/RecordConverter.cs b/src/Npgsql/Internal/Converters/RecordConverter.cs index aabd914b49..05eabcf7cd 100644 --- a/src/Npgsql/Internal/Converters/RecordConverter.cs +++ b/src/Npgsql/Internal/Converters/RecordConverter.cs @@ -5,16 +5,9 @@ namespace Npgsql.Internal.Converters; -sealed class RecordConverter : PgStreamingConverter +sealed class RecordConverter(PgSerializerOptions options, Func? factory = null) : PgStreamingConverter { - readonly PgSerializerOptions _options; - readonly Func? _factory; - - public RecordConverter(PgSerializerOptions options, Func? 
factory = null) - { - _options = options; - _factory = factory; - } + static bool IsObjectArrayRecord => typeof(T) == typeof(object[]); public override T Read(PgReader reader) => Read(async: false, reader, CancellationToken.None).GetAwaiter().GetResult(); @@ -41,14 +34,17 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken cancellat continue; var postgresType = - _options.DatabaseInfo.GetPostgresType(typeOid).GetRepresentationalType() + options.DatabaseInfo.GetPostgresType(typeOid).GetRepresentationalType() ?? throw new NotSupportedException($"Reading isn't supported for record field {i} (unknown type OID {typeOid}"); + var pgTypeId = options.ToCanonicalTypeId(postgresType); - var typeInfo = _options.GetObjectOrDefaultTypeInfo(postgresType) + // TODO resolve based on types expected by _factory (pass in a Type[] during construcion) + // Only allow object polymorphism for object[] records, valuetuple records are always strongly typed. + var typeInfo = (IsObjectArrayRecord ? options.GetTypeInfo(typeof(object), pgTypeId) : options.GetDefaultTypeInfo(pgTypeId)) ?? throw new NotSupportedException( $"Reading isn't supported for record field {i} (PG type '{postgresType.DisplayName}'"); - var converterInfo = typeInfo.Bind(new Field("?", _options.ToCanonicalTypeId(postgresType), -1), DataFormat.Binary); + var converterInfo = typeInfo.Bind(new Field("?", pgTypeId, -1), DataFormat.Binary); var scope = await reader.BeginNestedRead(async, length, converterInfo.BufferRequirement, cancellationToken).ConfigureAwait(false); try { @@ -63,7 +59,7 @@ async ValueTask Read(bool async, PgReader reader, CancellationToken cancellat } } - return _factory is null ? (T)(object)result : _factory(result); + return factory is null ? (T)(object)result : factory(result); } public override Size GetSize(SizeContext context, T value, ref object? 
writeState) diff --git a/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs b/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs index 79aabf1d58..41e2cb83da 100644 --- a/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs +++ b/src/Npgsql/Internal/Converters/Temporal/DateConverters.cs @@ -4,14 +4,9 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class DateTimeDateConverter : PgBufferedConverter +sealed class DateOnlyDateConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - - static readonly DateTime BaseValue = new(2000, 1, 1, 0, 0, 0); - - public DateTimeDateConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; + static readonly DateOnly BaseValue = new(2000, 1, 1); public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -19,47 +14,42 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer return format is DataFormat.Binary; } - protected override DateTime ReadCore(PgReader reader) + protected override DateOnly ReadCore(PgReader reader) => reader.ReadInt32() switch { - int.MaxValue => _dateTimeInfinityConversions - ? DateTime.MaxValue + int.MaxValue => dateTimeInfinityConversions + ? DateOnly.MaxValue : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), - int.MinValue => _dateTimeInfinityConversions - ? DateTime.MinValue + int.MinValue => dateTimeInfinityConversions + ? 
DateOnly.MinValue : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), - var value => BaseValue + TimeSpan.FromDays(value) + var value => BaseValue.AddDays(value) }; - protected override void WriteCore(PgWriter writer, DateTime value) + protected override void WriteCore(PgWriter writer, DateOnly value) { - if (_dateTimeInfinityConversions) + if (dateTimeInfinityConversions) { - if (value == DateTime.MaxValue) + if (value == DateOnly.MaxValue) { writer.WriteInt32(int.MaxValue); return; } - if (value == DateTime.MinValue) + if (value == DateOnly.MinValue) { writer.WriteInt32(int.MinValue); return; } } - writer.WriteInt32((value.Date - BaseValue).Days); + writer.WriteInt32(value.DayNumber - BaseValue.DayNumber); } } -sealed class DateOnlyDateConverter : PgBufferedConverter +sealed class DateTimeDateConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - - static readonly DateOnly BaseValue = new(2000, 1, 1); - - public DateOnlyDateConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; + static readonly DateTime BaseValue = new(2000, 1, 1, 0, 0, 0); public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { @@ -67,35 +57,35 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer return format is DataFormat.Binary; } - protected override DateOnly ReadCore(PgReader reader) + protected override DateTime ReadCore(PgReader reader) => reader.ReadInt32() switch { - int.MaxValue => _dateTimeInfinityConversions - ? DateOnly.MaxValue + int.MaxValue => dateTimeInfinityConversions + ? DateTime.MaxValue : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), - int.MinValue => _dateTimeInfinityConversions - ? DateOnly.MinValue + int.MinValue => dateTimeInfinityConversions + ? 
DateTime.MinValue : throw new InvalidCastException(NpgsqlStrings.CannotReadInfinityValue), - var value => BaseValue.AddDays(value) + var value => BaseValue + TimeSpan.FromDays(value) }; - protected override void WriteCore(PgWriter writer, DateOnly value) + protected override void WriteCore(PgWriter writer, DateTime value) { - if (_dateTimeInfinityConversions) + if (dateTimeInfinityConversions) { - if (value == DateOnly.MaxValue) + if (value == DateTime.MaxValue) { writer.WriteInt32(int.MaxValue); return; } - if (value == DateOnly.MinValue) + if (value == DateTime.MinValue) { writer.WriteInt32(int.MinValue); return; } } - writer.WriteInt32(value.DayNumber - BaseValue.DayNumber); + writer.WriteInt32((value.Date - BaseValue).Days); } } diff --git a/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs index ed744bb099..389c2ec021 100644 --- a/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs +++ b/src/Npgsql/Internal/Converters/Temporal/DateTimeConverters.cs @@ -3,17 +3,8 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class DateTimeConverter : PgBufferedConverter +sealed class DateTimeConverter(bool dateTimeInfinityConversions, DateTimeKind kind) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - readonly DateTimeKind _kind; - - public DateTimeConverter(bool dateTimeInfinityConversions, DateTimeKind kind) - { - _dateTimeInfinityConversions = dateTimeInfinityConversions; - _kind = kind; - } - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -21,18 +12,14 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer } protected override DateTime ReadCore(PgReader reader) - => PgTimestamp.Decode(reader.ReadInt64(), _kind, _dateTimeInfinityConversions); + => 
PgTimestamp.Decode(reader.ReadInt64(), kind, dateTimeInfinityConversions); protected override void WriteCore(PgWriter writer, DateTime value) - => writer.WriteInt64(PgTimestamp.Encode(value, _dateTimeInfinityConversions)); + => writer.WriteInt64(PgTimestamp.Encode(value, dateTimeInfinityConversions)); } -sealed class DateTimeOffsetConverter : PgBufferedConverter +sealed class DateTimeOffsetConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - public DateTimeOffsetConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -40,14 +27,14 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer } protected override DateTimeOffset ReadCore(PgReader reader) - => new(PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions), TimeSpan.Zero); + => new(PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, dateTimeInfinityConversions), TimeSpan.Zero); protected override void WriteCore(PgWriter writer, DateTimeOffset value) { if (value.Offset != TimeSpan.Zero) throw new ArgumentException($"Cannot write DateTimeOffset with Offset={value.Offset} to PostgreSQL type 'timestamp with time zone', only offset 0 (UTC) is supported. 
", nameof(value)); - writer.WriteInt64(PgTimestamp.Encode(value.DateTime, _dateTimeInfinityConversions)); + writer.WriteInt64(PgTimestamp.Encode(value.DateTime, dateTimeInfinityConversions)); } } diff --git a/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs b/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs index 99ad4ed599..8bcca02db1 100644 --- a/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs +++ b/src/Npgsql/Internal/Converters/Temporal/LegacyDateTimeConverter.cs @@ -2,17 +2,8 @@ namespace Npgsql.Internal.Converters; -sealed class LegacyDateTimeConverter : PgBufferedConverter +sealed class LegacyDateTimeConverter(bool dateTimeInfinityConversions, bool timestamp) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - readonly bool _timestamp; - - public LegacyDateTimeConverter(bool dateTimeInfinityConversions, bool timestamp) - { - _dateTimeInfinityConversions = dateTimeInfinityConversions; - _timestamp = timestamp; - } - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -21,33 +12,28 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer protected override DateTime ReadCore(PgReader reader) { - if (_timestamp) + if (timestamp) { - return PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Unspecified, _dateTimeInfinityConversions); + return PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Unspecified, dateTimeInfinityConversions); } - var dateTime = PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions); - return (dateTime == DateTime.MinValue || dateTime == DateTime.MaxValue) && _dateTimeInfinityConversions + var dateTime = PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, dateTimeInfinityConversions); + return (dateTime == DateTime.MinValue || dateTime == DateTime.MaxValue) && 
dateTimeInfinityConversions ? dateTime : dateTime.ToLocalTime(); } protected override void WriteCore(PgWriter writer, DateTime value) { - if (!_timestamp && value.Kind is DateTimeKind.Local) + if (!timestamp && value.Kind is DateTimeKind.Local) value = value.ToUniversalTime(); - writer.WriteInt64(PgTimestamp.Encode(value, _dateTimeInfinityConversions)); + writer.WriteInt64(PgTimestamp.Encode(value, dateTimeInfinityConversions)); } } -sealed class LegacyDateTimeOffsetConverter : PgBufferedConverter +sealed class LegacyDateTimeOffsetConverter(bool dateTimeInfinityConversions) : PgBufferedConverter { - readonly bool _dateTimeInfinityConversions; - - public LegacyDateTimeOffsetConverter(bool dateTimeInfinityConversions) - => _dateTimeInfinityConversions = dateTimeInfinityConversions; - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); @@ -56,9 +42,9 @@ public override bool CanConvert(DataFormat format, out BufferRequirements buffer protected override DateTimeOffset ReadCore(PgReader reader) { - var dateTime = PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, _dateTimeInfinityConversions); + var dateTime = PgTimestamp.Decode(reader.ReadInt64(), DateTimeKind.Utc, dateTimeInfinityConversions); - if (_dateTimeInfinityConversions) + if (dateTimeInfinityConversions) { if (dateTime == DateTime.MinValue) return DateTimeOffset.MinValue; @@ -70,5 +56,5 @@ protected override DateTimeOffset ReadCore(PgReader reader) } protected override void WriteCore(PgWriter writer, DateTimeOffset value) - => writer.WriteInt64(PgTimestamp.Encode(value.UtcDateTime, _dateTimeInfinityConversions)); + => writer.WriteInt64(PgTimestamp.Encode(value.UtcDateTime, dateTimeInfinityConversions)); } diff --git a/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs b/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs index b93a878032..09385712bf 100644 --- 
a/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs +++ b/src/Npgsql/Internal/Converters/Temporal/TimeConverters.cs @@ -3,26 +3,26 @@ // ReSharper disable once CheckNamespace namespace Npgsql.Internal.Converters; -sealed class TimeSpanTimeConverter : PgBufferedConverter +sealed class TimeOnlyTimeConverter : PgBufferedConverter { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); return format is DataFormat.Binary; } - protected override TimeSpan ReadCore(PgReader reader) => new(reader.ReadInt64() * 10); - protected override void WriteCore(PgWriter writer, TimeSpan value) => writer.WriteInt64(value.Ticks / 10); + protected override TimeOnly ReadCore(PgReader reader) => new(reader.ReadInt64() * 10); + protected override void WriteCore(PgWriter writer, TimeOnly value) => writer.WriteInt64(value.Ticks / 10); } -sealed class TimeOnlyTimeConverter : PgBufferedConverter +sealed class TimeSpanTimeConverter : PgBufferedConverter { public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.CreateFixedSize(sizeof(long)); return format is DataFormat.Binary; } - protected override TimeOnly ReadCore(PgReader reader) => new(reader.ReadInt64() * 10); - protected override void WriteCore(PgWriter writer, TimeOnly value) => writer.WriteInt64(value.Ticks / 10); + protected override TimeSpan ReadCore(PgReader reader) => new(reader.ReadInt64() * 10); + protected override void WriteCore(PgWriter writer, TimeSpan value) => writer.WriteInt64(value.Ticks / 10); } sealed class DateTimeOffsetTimeTzConverter : PgBufferedConverter diff --git a/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs b/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs index ff3a985a66..8dc981a47e 100644 --- a/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs +++ 
b/src/Npgsql/Internal/Converters/VersionPrefixedTextConverter.cs @@ -5,23 +5,15 @@ namespace Npgsql.Internal.Converters; -sealed class VersionPrefixedTextConverter : PgStreamingConverter, IResumableRead +sealed class VersionPrefixedTextConverter(byte versionPrefix, PgConverter textConverter) + : PgStreamingConverter(textConverter.DbNullPredicateKind is DbNullPredicate.Custom) { - readonly byte _versionPrefix; - readonly PgConverter _textConverter; BufferRequirements _innerRequirements; - public VersionPrefixedTextConverter(byte versionPrefix, PgConverter textConverter) - : base(textConverter.DbNullPredicateKind is DbNullPredicate.Custom) - { - _versionPrefix = versionPrefix; - _textConverter = textConverter; - } - - protected override bool IsDbNullValue(T? value, ref object? writeState) => _textConverter.IsDbNull(value, ref writeState); + protected override bool IsDbNullValue(T? value, ref object? writeState) => textConverter.IsDbNull(value, ref writeState); public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) - => VersionPrefixedTextConverter.CanConvert(_textConverter, format, out _innerRequirements, out bufferRequirements); + => VersionPrefixedTextConverter.CanConvert(textConverter, format, out _innerRequirements, out bufferRequirements); public override T Read(PgReader reader) => Read(async: false, reader, CancellationToken.None).Result; @@ -30,7 +22,7 @@ public override ValueTask ReadAsync(PgReader reader, CancellationToken cancel => Read(async: true, reader, cancellationToken); public override Size GetSize(SizeContext context, [DisallowNull]T value, ref object? writeState) - => _textConverter.GetSize(context, value, ref writeState).Combine(context.Format is DataFormat.Binary ? sizeof(byte) : 0); + => textConverter.GetSize(context, value, ref writeState).Combine(context.Format is DataFormat.Binary ? 
sizeof(byte) : 0); public override void Write(PgWriter writer, [DisallowNull]T value) => Write(async: false, writer, value, CancellationToken.None).GetAwaiter().GetResult(); @@ -40,20 +32,18 @@ public override ValueTask WriteAsync(PgWriter writer, [DisallowNull]T value, Can async ValueTask Read(bool async, PgReader reader, CancellationToken cancellationToken) { - await VersionPrefixedTextConverter.ReadVersion(async, _versionPrefix, reader, _innerRequirements.Read, cancellationToken).ConfigureAwait(false); - return async ? await _textConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) : _textConverter.Read(reader); + await VersionPrefixedTextConverter.ReadVersion(async, versionPrefix, reader, _innerRequirements.Read, cancellationToken).ConfigureAwait(false); + return async ? await textConverter.ReadAsync(reader, cancellationToken).ConfigureAwait(false) : textConverter.Read(reader); } async ValueTask Write(bool async, PgWriter writer, [DisallowNull]T value, CancellationToken cancellationToken) { - await VersionPrefixedTextConverter.WriteVersion(async, _versionPrefix, writer, cancellationToken).ConfigureAwait(false); + await VersionPrefixedTextConverter.WriteVersion(async, versionPrefix, writer, cancellationToken).ConfigureAwait(false); if (async) - await _textConverter.WriteAsync(writer, value, cancellationToken).ConfigureAwait(false); + await textConverter.WriteAsync(writer, value, cancellationToken).ConfigureAwait(false); else - _textConverter.Write(writer, value); + textConverter.Write(writer, value); } - - bool IResumableRead.Supported => _textConverter is IResumableRead { Supported: true }; } static class VersionPrefixedTextConverter diff --git a/src/Npgsql/Internal/DataFormat.cs b/src/Npgsql/Internal/DataFormat.cs index c9950ea417..c52b418b7d 100644 --- a/src/Npgsql/Internal/DataFormat.cs +++ b/src/Npgsql/Internal/DataFormat.cs @@ -1,8 +1,10 @@ using System; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; namespace 
Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public enum DataFormat : byte { Binary, diff --git a/src/Npgsql/Internal/DbTypeResolverFactory.cs b/src/Npgsql/Internal/DbTypeResolverFactory.cs new file mode 100644 index 0000000000..55b3b71235 --- /dev/null +++ b/src/Npgsql/Internal/DbTypeResolverFactory.cs @@ -0,0 +1,9 @@ +using System.Diagnostics.CodeAnalysis; + +namespace Npgsql.Internal; + +[Experimental(NpgsqlDiagnostics.DbTypeResolverExperimental)] +public abstract class DbTypeResolverFactory +{ + public abstract IDbTypeResolver CreateDbTypeResolver(NpgsqlDatabaseInfo databaseInfo); +} diff --git a/src/Npgsql/Internal/DynamicTypeInfoResolver.cs b/src/Npgsql/Internal/DynamicTypeInfoResolver.cs index 637c337321..dfdb5a79e7 100644 --- a/src/Npgsql/Internal/DynamicTypeInfoResolver.cs +++ b/src/Npgsql/Internal/DynamicTypeInfoResolver.cs @@ -1,11 +1,11 @@ using System; using System.Diagnostics.CodeAnalysis; -using System.Reflection; using Npgsql.Internal.Postgres; using Npgsql.PostgresTypes; namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] [RequiresDynamicCode("A dynamic type info resolver may need to construct a generic converter for a statically unknown type.")] public abstract class DynamicTypeInfoResolver : IPgTypeInfoResolver { @@ -47,79 +47,87 @@ protected class DynamicMappingCollection { TypeInfoMappingCollection? _mappings; - static readonly MethodInfo AddTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod(nameof(TypeInfoMappingCollection.AddType), - new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); - - static readonly MethodInfo AddArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) - .GetMethod(nameof(TypeInfoMappingCollection.AddArrayType), new[] { typeof(string) }) ?? 
throw new NullReferenceException(); - - static readonly MethodInfo AddStructTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod(nameof(TypeInfoMappingCollection.AddStructType), - new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); - - static readonly MethodInfo AddStructArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) - .GetMethod(nameof(TypeInfoMappingCollection.AddStructArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); - - static readonly MethodInfo AddResolverTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod( - nameof(TypeInfoMappingCollection.AddResolverType), - new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); - - static readonly MethodInfo AddResolverArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) - .GetMethod(nameof(TypeInfoMappingCollection.AddResolverArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); - - static readonly MethodInfo AddResolverStructTypeMethodInfo = typeof(TypeInfoMappingCollection).GetMethod( - nameof(TypeInfoMappingCollection.AddResolverStructType), - new[] { typeof(string), typeof(TypeInfoFactory), typeof(Func) }) ?? throw new NullReferenceException(); - - static readonly MethodInfo AddResolverStructArrayTypeMethodInfo = typeof(TypeInfoMappingCollection) - .GetMethod(nameof(TypeInfoMappingCollection.AddResolverStructArrayType), new[] { typeof(string) }) ?? throw new NullReferenceException(); - internal DynamicMappingCollection(TypeInfoMappingCollection? baseCollection = null) { if (baseCollection is not null) _mappings = new(baseCollection); } - public DynamicMappingCollection AddMapping(Type type, string dataTypeName, TypeInfoFactory factory, Func? 
configureMapping = null) + public DynamicMappingCollection AddMapping([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]Type type, string dataTypeName, TypeInfoFactory factory, Func? configureMapping = null) { if (type.IsValueType && Nullable.GetUnderlyingType(type) is not null) throw new NotSupportedException("Mapping nullable types is not supported, map its underlying type instead to get both."); - (type.IsValueType ? AddStructTypeMethodInfo : AddTypeMethodInfo) - .MakeGenericMethod(type).Invoke(_mappings ??= new(), new object?[] - { - dataTypeName, - factory, - configureMapping - }); + if (type.IsValueType) + typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddStructType), [typeof(string), typeof(TypeInfoFactory), typeof(Func)])! + .MakeGenericMethod(type).Invoke(_mappings ??= new(), + [ + dataTypeName, + factory, + configureMapping + ]); + else + typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddType), [typeof(string), typeof(TypeInfoFactory), typeof(Func)])! + .MakeGenericMethod(type).Invoke(_mappings ??= new(), + [ + dataTypeName, + factory, + configureMapping + ]); return this; } - public DynamicMappingCollection AddArrayMapping(Type elementType, string dataTypeName) + public DynamicMappingCollection AddArrayMapping([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]Type elementType, string dataTypeName) { - (elementType.IsValueType ? AddStructArrayTypeMethodInfo : AddArrayTypeMethodInfo) - .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), new object?[] { dataTypeName }); + if (elementType.IsValueType) + typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddStructArrayType), [typeof(string)])! 
+ .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), [dataTypeName]); + else + typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddArrayType), [typeof(string)])! + .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), [dataTypeName]); return this; } - public DynamicMappingCollection AddResolverMapping(Type type, string dataTypeName, TypeInfoFactory factory, Func? configureMapping = null) + public DynamicMappingCollection AddResolverMapping([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]Type type, string dataTypeName, TypeInfoFactory factory, Func? configureMapping = null) { if (type.IsValueType && Nullable.GetUnderlyingType(type) is not null) throw new NotSupportedException("Mapping nullable types is not supported"); - (type.IsValueType ? AddResolverStructTypeMethodInfo : AddResolverTypeMethodInfo) - .MakeGenericMethod(type).Invoke(_mappings ??= new(), new object?[] - { - dataTypeName, - factory, - configureMapping - }); + if (type.IsValueType) + typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddResolverStructType), [typeof(string), typeof(TypeInfoFactory), typeof(Func)])! + .MakeGenericMethod(type).Invoke(_mappings ??= new(), + [ + dataTypeName, + factory, + configureMapping + ]); + else + typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddResolverType), [typeof(string), typeof(TypeInfoFactory), typeof(Func)])! + .MakeGenericMethod(type).Invoke(_mappings ??= new(), + [ + dataTypeName, + factory, + configureMapping + ]); return this; } - public DynamicMappingCollection AddResolverArrayMapping(Type elementType, string dataTypeName) + public DynamicMappingCollection AddResolverArrayMapping([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)]Type elementType, string dataTypeName) { - (elementType.IsValueType ? 
AddResolverStructArrayTypeMethodInfo : AddResolverArrayTypeMethodInfo) - .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), new object?[] { dataTypeName }); + if (elementType.IsValueType) + typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddResolverStructArrayType), [typeof(string)])! + .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), [dataTypeName]); + else + typeof(TypeInfoMappingCollection) + .GetMethod(nameof(TypeInfoMappingCollection.AddResolverArrayType), [typeof(string)])! + .MakeGenericMethod(elementType).Invoke(_mappings ??= new(), [dataTypeName]); return this; } diff --git a/src/Npgsql/Internal/HackyEnumTypeMapping.cs b/src/Npgsql/Internal/HackyEnumTypeMapping.cs deleted file mode 100644 index 1aa4b27554..0000000000 --- a/src/Npgsql/Internal/HackyEnumTypeMapping.cs +++ /dev/null @@ -1,27 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; -using System.Reflection; -using Npgsql.Internal; -using Npgsql.PostgresTypes; -using NpgsqlTypes; - -namespace Npgsql.Internal; - - -/// -/// Hacky temporary measure used by EFCore.PG to extract user-configured enum mappings. Accessed via reflection only. -/// -public sealed class HackyEnumTypeMapping -{ - public HackyEnumTypeMapping(Type enumClrType, string pgTypeName, INpgsqlNameTranslator nameTranslator) - { - EnumClrType = enumClrType; - PgTypeName = pgTypeName; - NameTranslator = nameTranslator; - } - - public string PgTypeName { get; } - public Type EnumClrType { get; } - public INpgsqlNameTranslator NameTranslator { get; } -} diff --git a/src/Npgsql/Internal/IDbTypeResolver.cs b/src/Npgsql/Internal/IDbTypeResolver.cs new file mode 100644 index 0000000000..c4586a2bee --- /dev/null +++ b/src/Npgsql/Internal/IDbTypeResolver.cs @@ -0,0 +1,28 @@ +using System; +using System.Data; +using System.Diagnostics.CodeAnalysis; +using Npgsql.Internal.Postgres; + +namespace Npgsql.Internal; + +/// +/// An Npgsql resolver for DbType. 
Used by Npgsql to resolve a DbType to DataTypeName and back. +/// +[Experimental(NpgsqlDiagnostics.DbTypeResolverExperimental)] +public interface IDbTypeResolver +{ + /// + /// Attempts to resolve a DbType to a data type name. + /// + /// The DbType name to resolve. + /// The type of the value to resolve a data type name for. + /// The data type name if it could be mapped, the name can be non-normalized and without schema. + string? GetDataTypeName(DbType dbType, Type? type); + + /// + /// Attempts to resolve a data type name to a DbType. + /// + /// The data type name to map, in a normalized form but possibly without schema. + /// The DbType if it could be mapped, null otherwise. + DbType? GetDbType(DataTypeName dataTypeName); +} diff --git a/src/Npgsql/Internal/INpgsqlDatabaseInfoFactory.cs b/src/Npgsql/Internal/INpgsqlDatabaseInfoFactory.cs index ccdb7a8477..ea3f0ad525 100644 --- a/src/Npgsql/Internal/INpgsqlDatabaseInfoFactory.cs +++ b/src/Npgsql/Internal/INpgsqlDatabaseInfoFactory.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using System.Diagnostics.CodeAnalysis; +using System.Threading.Tasks; using Npgsql.Util; namespace Npgsql.Internal; @@ -8,6 +9,7 @@ namespace Npgsql.Internal; /// and the types it contains. When first connecting to a database, Npgsql will attempt to load information /// about it via this factory. /// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public interface INpgsqlDatabaseInfoFactory { /// @@ -19,4 +21,4 @@ public interface INpgsqlDatabaseInfoFactory /// database isn't of the correct type and isn't handled by this factory. 
/// Task Load(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async); -} \ No newline at end of file +} diff --git a/src/Npgsql/Internal/IPgTypeInfoResolver.cs b/src/Npgsql/Internal/IPgTypeInfoResolver.cs index 62955446eb..b7b3ddc9ec 100644 --- a/src/Npgsql/Internal/IPgTypeInfoResolver.cs +++ b/src/Npgsql/Internal/IPgTypeInfoResolver.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics.CodeAnalysis; using Npgsql.Internal.Postgres; namespace Npgsql.Internal; @@ -6,6 +7,7 @@ namespace Npgsql.Internal; /// /// An Npgsql resolver for type info. Used by Npgsql to read and write values to PostgreSQL. /// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public interface IPgTypeInfoResolver { /// diff --git a/src/Npgsql/Internal/IntegratedSecurityHandler.cs b/src/Npgsql/Internal/IntegratedSecurityHandler.cs index 2b2f2f1bb9..5edb826497 100644 --- a/src/Npgsql/Internal/IntegratedSecurityHandler.cs +++ b/src/Npgsql/Internal/IntegratedSecurityHandler.cs @@ -16,7 +16,10 @@ class IntegratedSecurityHandler return new(); } - public virtual ValueTask NegotiateAuthentication(bool async, NpgsqlConnector connector) + public virtual ValueTask NegotiateAuthentication(bool async, NpgsqlConnector connector, CancellationToken cancellationToken) + => throw new NotSupportedException(string.Format(NpgsqlStrings.IntegratedSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableIntegratedSecurity))); + + public virtual ValueTask GSSEncrypt(bool async, bool isRequired, NpgsqlConnector connector, CancellationToken cancellationToken) => throw new NotSupportedException(string.Format(NpgsqlStrings.IntegratedSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableIntegratedSecurity))); } @@ -27,6 +30,9 @@ sealed class RealIntegratedSecurityHandler : IntegratedSecurityHandler public override ValueTask GetUsername(bool async, bool includeRealm, ILogger connectionLogger, CancellationToken cancellationToken) => KerberosUsernameProvider.GetUsername(async, includeRealm, 
connectionLogger, cancellationToken); - public override ValueTask NegotiateAuthentication(bool async, NpgsqlConnector connector) - => new(connector.AuthenticateGSS(async)); + public override ValueTask NegotiateAuthentication(bool async, NpgsqlConnector connector, CancellationToken cancellationToken) + => connector.AuthenticateGSS(async, cancellationToken); + + public override ValueTask GSSEncrypt(bool async, bool isRequired, NpgsqlConnector connector, CancellationToken cancellationToken) + => connector.GSSEncrypt(async, isRequired, cancellationToken); } diff --git a/src/Npgsql/Internal/NpgsqlConnector.Auth.cs b/src/Npgsql/Internal/NpgsqlConnector.Auth.cs index 8fe1bfe402..f837f08026 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.Auth.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.Auth.cs @@ -18,46 +18,73 @@ partial class NpgsqlConnector { async Task Authenticate(string username, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) { + var requiredAuthModes = Settings.RequireAuthModes; + if (requiredAuthModes == default) + requiredAuthModes = NpgsqlConnectionStringBuilder.ParseAuthMode(PostgresEnvironment.RequireAuth); + + var authenticated = false; + while (true) { timeout.CheckAndApply(this); var msg = ExpectAny(await ReadMessage(async).ConfigureAwait(false), this); switch (msg.AuthRequestType) { - case AuthenticationRequestType.AuthenticationOk: + case AuthenticationRequestType.Ok: + // If we didn't complete authentication, check whether it's allowed + if (!authenticated) + { + // User requested GSS authentication, but server said that no auth is required + // If and only if our connection is gss encrypted, we consider us already authenticated + if (requiredAuthModes.HasFlag(RequireAuthMode.GSS) && IsGssEncrypted) + return; + ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.None); + } return; - case AuthenticationRequestType.AuthenticationCleartextPassword: + case AuthenticationRequestType.CleartextPassword: + 
ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.Password); await AuthenticateCleartext(username, async, cancellationToken).ConfigureAwait(false); break; - case AuthenticationRequestType.AuthenticationMD5Password: + case AuthenticationRequestType.MD5Password: + ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.MD5); await AuthenticateMD5(username, ((AuthenticationMD5PasswordMessage)msg).Salt, async, cancellationToken).ConfigureAwait(false); break; - case AuthenticationRequestType.AuthenticationSASL: + case AuthenticationRequestType.SASL: + ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.ScramSHA256); await AuthenticateSASL(((AuthenticationSASLMessage)msg).Mechanisms, username, async, cancellationToken).ConfigureAwait(false); break; - case AuthenticationRequestType.AuthenticationGSS: - case AuthenticationRequestType.AuthenticationSSPI: - await DataSource.IntegratedSecurityHandler.NegotiateAuthentication(async, this).ConfigureAwait(false); + case AuthenticationRequestType.GSS: + case AuthenticationRequestType.SSPI: + ThrowIfNotAllowed(requiredAuthModes, msg.AuthRequestType == AuthenticationRequestType.GSS ? RequireAuthMode.GSS : RequireAuthMode.SSPI); + await DataSource.IntegratedSecurityHandler.NegotiateAuthentication(async, this, cancellationToken).ConfigureAwait(false); return; - case AuthenticationRequestType.AuthenticationGSSContinue: + case AuthenticationRequestType.GSSContinue: throw new NpgsqlException("Can't start auth cycle with AuthenticationGSSContinue"); default: throw new NotSupportedException($"Authentication method not supported (Received: {msg.AuthRequestType})"); } + + authenticated = true; + } + + static void ThrowIfNotAllowed(RequireAuthMode requiredAuthModes, RequireAuthMode requestedAuthMode) + { + if (!requiredAuthModes.HasFlag(requestedAuthMode)) + throw new NpgsqlException($"\"{requestedAuthMode}\" authentication method is not allowed. 
Allowed methods: {requiredAuthModes}"); } } async Task AuthenticateCleartext(string username, bool async, CancellationToken cancellationToken = default) { var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false); - if (passwd == null) + if (string.IsNullOrEmpty(passwd)) throw new NpgsqlException("No password has been provided but the backend requires one (in cleartext)"); var encoded = new byte[Encoding.UTF8.GetByteCount(passwd) + 1]; @@ -71,10 +98,10 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async { // At the time of writing PostgreSQL only supports SCRAM-SHA-256 and SCRAM-SHA-256-PLUS var serverSupportsSha256 = mechanisms.Contains("SCRAM-SHA-256"); - var clientSupportsSha256 = serverSupportsSha256 && Settings.ChannelBinding != ChannelBinding.Require; + var allowSha256 = serverSupportsSha256 && Settings.ChannelBinding != ChannelBinding.Require; var serverSupportsSha256Plus = mechanisms.Contains("SCRAM-SHA-256-PLUS"); - var clientSupportsSha256Plus = serverSupportsSha256Plus && Settings.ChannelBinding != ChannelBinding.Disable; - if (!clientSupportsSha256 && !clientSupportsSha256Plus) + var allowSha256Plus = serverSupportsSha256Plus && Settings.ChannelBinding != ChannelBinding.Disable; + if (!allowSha256 && !allowSha256Plus) { if (serverSupportsSha256 && Settings.ChannelBinding == ChannelBinding.Require) throw new NpgsqlException($"Couldn't connect because {nameof(ChannelBinding)} is set to {nameof(ChannelBinding.Require)} " + @@ -92,10 +119,10 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async var cbind = string.Empty; var successfulBind = false; - if (clientSupportsSha256Plus) + if (allowSha256Plus) DataSource.TransportSecurityHandler.AuthenticateSASLSha256Plus(this, ref mechanism, ref cbindFlag, ref cbind, ref successfulBind); - if (!successfulBind && serverSupportsSha256) + if (!successfulBind && allowSha256) { mechanism = "SCRAM-SHA-256"; // We can get here if PostgreSQL 
supports only SCRAM-SHA-256 or there was an error while binding to SCRAM-SHA-256-PLUS @@ -114,8 +141,9 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async throw new NpgsqlException("Unable to bind to SCRAM-SHA-256-PLUS, check logs for more information"); } - var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false) ?? - throw new NpgsqlException($"No password has been provided but the backend requires one (in SASL/{mechanism})"); + var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false); + if (string.IsNullOrEmpty(passwd)) + throw new NpgsqlException($"No password has been provided but the backend requires one (in SASL/{mechanism})"); // Assumption: the write buffer is big enough to contain all our outgoing messages var clientNonce = GetNonce(); @@ -124,7 +152,7 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async await Flush(async, cancellationToken).ConfigureAwait(false); var saslContinueMsg = Expect(await ReadMessage(async).ConfigureAwait(false), this); - if (saslContinueMsg.AuthRequestType != AuthenticationRequestType.AuthenticationSASLContinue) + if (saslContinueMsg.AuthRequestType != AuthenticationRequestType.SASLContinue) throw new NpgsqlException("[SASL] AuthenticationSASLContinue message expected"); var firstServerMsg = AuthenticationSCRAMServerFirstMessage.Load(saslContinueMsg.Payload, ConnectionLogger); if (!firstServerMsg.Nonce.StartsWith(clientNonce, StringComparison.Ordinal)) @@ -134,13 +162,7 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async var saltedPassword = Hi(passwd.Normalize(NormalizationForm.FormKC), saltBytes, firstServerMsg.Iteration); var clientKey = HMAC(saltedPassword, "Client Key"); - byte[] storedKey; -#if NET7_0_OR_GREATER - storedKey = SHA256.HashData(clientKey); -#else - using (var sha256 = SHA256.Create()) - storedKey = sha256.ComputeHash(clientKey); -#endif + var storedKey = 
SHA256.HashData(clientKey); var clientFirstMessageBare = $"n=*,r={clientNonce}"; var serverFirstMessage = $"r={firstServerMsg.Nonce},s={firstServerMsg.Salt},i={firstServerMsg.Iteration}"; var clientFinalMessageWithoutProof = $"c={cbind},r={firstServerMsg.Nonce}"; @@ -160,7 +182,7 @@ async Task AuthenticateSASL(List mechanisms, string username, bool async await Flush(async, cancellationToken).ConfigureAwait(false); var saslFinalServerMsg = Expect(await ReadMessage(async).ConfigureAwait(false), this); - if (saslFinalServerMsg.AuthRequestType != AuthenticationRequestType.AuthenticationSASLFinal) + if (saslFinalServerMsg.AuthRequestType != AuthenticationRequestType.SASLFinal) throw new NpgsqlException("[SASL] AuthenticationSASLFinal message expected"); var scramFinalServerMsg = AuthenticationSCRAMServerFinalMessage.Load(saslFinalServerMsg.Payload, ConnectionLogger); @@ -191,7 +213,7 @@ internal void AuthenticateSASLSha256Plus(ref string mechanism, ref string cbindF // try authenticate without channel binding even though both // the client and server supported it. The SCRAM exchange // checks for that, to prevent downgrade attacks. - if (!IsSecure) + if (!IsSslEncrypted) throw new NpgsqlException("Server offered SCRAM-SHA-256-PLUS authentication over a non-SSL connection"); var sslStream = (SslStream)_stream; @@ -201,53 +223,51 @@ internal void AuthenticateSASLSha256Plus(ref string mechanism, ref string cbindF return; } + // While SslStream.RemoteCertificate is X509Certificate2, it actually returns X509Certificate2 + // But to be on the safe side we'll just create a new instance of it using var remoteCertificate = new X509Certificate2(sslStream.RemoteCertificate); // Checking for hashing algorithms - HashAlgorithm? 
hashAlgorithm = null; var algorithmName = remoteCertificate.SignatureAlgorithm.FriendlyName; - if (algorithmName is null) - { - ConnectionLogger.LogWarning("Signature algorithm was null, falling back to SCRAM-SHA-256"); - } - else if (algorithmName.StartsWith("sha1", StringComparison.OrdinalIgnoreCase) || - algorithmName.StartsWith("md5", StringComparison.OrdinalIgnoreCase) || - algorithmName.StartsWith("sha256", StringComparison.OrdinalIgnoreCase)) - { - hashAlgorithm = SHA256.Create(); - } - else if (algorithmName.StartsWith("sha384", StringComparison.OrdinalIgnoreCase)) - { - hashAlgorithm = SHA384.Create(); - } - else if (algorithmName.StartsWith("sha512", StringComparison.OrdinalIgnoreCase)) + + HashAlgorithm? hashAlgorithm = algorithmName switch { - hashAlgorithm = SHA512.Create(); - } - else + not null when algorithmName.StartsWith("sha1", StringComparison.OrdinalIgnoreCase) => SHA256.Create(), + not null when algorithmName.StartsWith("md5", StringComparison.OrdinalIgnoreCase) => SHA256.Create(), + not null when algorithmName.StartsWith("sha256", StringComparison.OrdinalIgnoreCase) => SHA256.Create(), + not null when algorithmName.StartsWith("sha384", StringComparison.OrdinalIgnoreCase) => SHA384.Create(), + not null when algorithmName.StartsWith("sha512", StringComparison.OrdinalIgnoreCase) => SHA512.Create(), + not null when algorithmName.StartsWith("sha3-256", StringComparison.OrdinalIgnoreCase) => SHA3_256.Create(), + not null when algorithmName.StartsWith("sha3-384", StringComparison.OrdinalIgnoreCase) => SHA3_384.Create(), + not null when algorithmName.StartsWith("sha3-512", StringComparison.OrdinalIgnoreCase) => SHA3_512.Create(), + + _ => null + }; + + if (hashAlgorithm is null) { ConnectionLogger.LogWarning( - $"Support for signature algorithm {algorithmName} is not yet implemented, falling back to SCRAM-SHA-256"); + algorithmName is null + ? 
"Signature algorithm was null, falling back to SCRAM-SHA-256" + : $"Support for signature algorithm {algorithmName} is not yet implemented, falling back to SCRAM-SHA-256"); + return; } - if (hashAlgorithm != null) - { - using var _ = hashAlgorithm; - - // RFC 5929 - mechanism = "SCRAM-SHA-256-PLUS"; - // PostgreSQL only supports tls-server-end-point binding - cbindFlag = "p=tls-server-end-point"; - // SCRAM-SHA-256-PLUS depends on using ssl stream, so it's fine - var cbindFlagBytes = Encoding.UTF8.GetBytes($"{cbindFlag},,"); - - var certificateHash = hashAlgorithm.ComputeHash(remoteCertificate.GetRawCertData()); - var cbindBytes = new byte[cbindFlagBytes.Length + certificateHash.Length]; - cbindFlagBytes.CopyTo(cbindBytes, 0); - certificateHash.CopyTo(cbindBytes, cbindFlagBytes.Length); - cbind = Convert.ToBase64String(cbindBytes); - successfulBind = true; - IsScramPlus = true; - } + using var _ = hashAlgorithm; + + // RFC 5929 + mechanism = "SCRAM-SHA-256-PLUS"; + // PostgreSQL only supports tls-server-end-point binding + cbindFlag = "p=tls-server-end-point"; + // SCRAM-SHA-256-PLUS depends on using ssl stream, so it's fine + var cbindFlagBytes = Encoding.UTF8.GetBytes($"{cbindFlag},,"); + + var certificateHash = hashAlgorithm.ComputeHash(remoteCertificate.GetRawCertData()); + var cbindBytes = new byte[cbindFlagBytes.Length + certificateHash.Length]; + cbindFlagBytes.CopyTo(cbindBytes, 0); + certificateHash.CopyTo(cbindBytes, cbindFlagBytes.Length); + cbind = Convert.ToBase64String(cbindBytes); + successfulBind = true; + IsScramPlus = true; } static byte[] Hi(string str, byte[] salt, int count) @@ -260,28 +280,15 @@ static byte[] Xor(byte[] buffer1, byte[] buffer2) return buffer1; } - static byte[] HMAC(byte[] key, string data) - { - var dataBytes = Encoding.UTF8.GetBytes(data); -#if NET7_0_OR_GREATER - return HMACSHA256.HashData(key, dataBytes); -#else - using var ih = IncrementalHash.CreateHMAC(HashAlgorithmName.SHA256, key); - ih.AppendData(dataBytes); - return 
ih.GetHashAndReset(); -#endif - } + static byte[] HMAC(byte[] key, string data) => HMACSHA256.HashData(key, Encoding.UTF8.GetBytes(data)); async Task AuthenticateMD5(string username, byte[] salt, bool async, CancellationToken cancellationToken = default) { var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false); - if (passwd == null) + if (string.IsNullOrEmpty(passwd)) throw new NpgsqlException("No password has been provided but the backend requires one (in MD5)"); byte[] result; -#if !NET7_0_OR_GREATER - using (var md5 = MD5.Create()) -#endif { // First phase var passwordBytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(passwd); @@ -291,11 +298,7 @@ async Task AuthenticateMD5(string username, byte[] salt, bool async, Cancellatio usernameBytes.CopyTo(cryptBuf, passwordBytes.Length); var sb = new StringBuilder(); -#if NET7_0_OR_GREATER var hashResult = MD5.HashData(cryptBuf); -#else - var hashResult = md5.ComputeHash(cryptBuf); -#endif foreach (var b in hashResult) sb.Append(b.ToString("x2")); @@ -310,53 +313,54 @@ async Task AuthenticateMD5(string username, byte[] salt, bool async, Cancellatio prehashbytes.CopyTo(cryptBuf, 0); sb = new StringBuilder("md5"); -#if NET7_0_OR_GREATER hashResult = MD5.HashData(cryptBuf); -#else - hashResult = md5.ComputeHash(cryptBuf); -#endif foreach (var b in hashResult) sb.Append(b.ToString("x2")); var resultString = sb.ToString(); result = new byte[Encoding.UTF8.GetByteCount(resultString) + 1]; Encoding.UTF8.GetBytes(resultString, 0, resultString.Length, result, 0); - result[result.Length - 1] = 0; + result[^1] = 0; } await WritePassword(result, async, cancellationToken).ConfigureAwait(false); await Flush(async, cancellationToken).ConfigureAwait(false); } -#if NET7_0_OR_GREATER - internal async Task AuthenticateGSS(bool async) + internal async ValueTask AuthenticateGSS(bool async, CancellationToken cancellationToken) { var targetName = $"{KerberosServiceName}/{Host}"; - using var authContext = new 
NegotiateAuthentication(new NegotiateAuthenticationClientOptions{ TargetName = targetName}); + var clientOptions = new NegotiateAuthenticationClientOptions { TargetName = targetName }; + NegotiateOptionsCallback?.Invoke(clientOptions); + + using var authContext = new NegotiateAuthentication(clientOptions); var data = authContext.GetOutgoingBlob(ReadOnlySpan.Empty, out var statusCode)!; - Debug.Assert(statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded); - await WritePassword(data, 0, data.Length, async, UserCancellationToken).ConfigureAwait(false); - await Flush(async, UserCancellationToken).ConfigureAwait(false); + if (statusCode != NegotiateAuthenticationStatusCode.ContinueNeeded) + { + // Unable to retrieve credentials or some other issue + throw new NpgsqlException($"Unable to authenticate with GSS: received {statusCode} instead of the expected ContinueNeeded"); + } + await WritePassword(data, 0, data.Length, async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); while (true) { var response = ExpectAny(await ReadMessage(async).ConfigureAwait(false), this); - if (response.AuthRequestType == AuthenticationRequestType.AuthenticationOk) + if (response.AuthRequestType == AuthenticationRequestType.Ok) break; if (response is not AuthenticationGSSContinueMessage gssMsg) throw new NpgsqlException($"Received unexpected authentication request message {response.AuthRequestType}"); - data = authContext.GetOutgoingBlob(gssMsg.AuthenticationData.AsSpan(), out statusCode)!; + data = authContext.GetOutgoingBlob(gssMsg.AuthenticationData.AsSpan(), out statusCode); if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded) throw new NpgsqlException($"Error while authenticating GSS/SSPI: {statusCode}"); // We might get NegotiateAuthenticationStatusCode.Completed but the data will not be null // This can happen if it's the first cycle, in which case 
we have to send that data to complete handshake (#4888) if (data is null) continue; - await WritePassword(data, 0, data.Length, async, UserCancellationToken).ConfigureAwait(false); - await Flush(async, UserCancellationToken).ConfigureAwait(false); + await WritePassword(data, 0, data.Length, async, cancellationToken).ConfigureAwait(false); + await Flush(async, cancellationToken).ConfigureAwait(false); } } -#endif async ValueTask GetPassword(string username, bool async, CancellationToken cancellationToken = default) { diff --git a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs index f3e3173124..b801b11b84 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.FrontendMessages.cs @@ -19,6 +19,7 @@ internal Task WriteDescribe(StatementOrPortal statementOrPortal, byte[] asciiNam (asciiName.Length + 1); // Statement/portal name var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(len); if (writeBuffer.WriteSpaceLeft < len) return FlushAndWrite(len, statementOrPortal, asciiName, async, cancellationToken); @@ -48,6 +49,7 @@ internal Task WriteSync(bool async, CancellationToken cancellationToken = defaul sizeof(int); // Length var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(len); if (writeBuffer.WriteSpaceLeft < len) return FlushAndWrite(async, cancellationToken); @@ -79,6 +81,7 @@ internal Task WriteExecute(int maxRows, bool async, CancellationToken cancellati sizeof(int); // Max number of rows var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(len); if (writeBuffer.WriteSpaceLeft < len) return FlushAndWrite(maxRows, async, cancellationToken); @@ -118,9 +121,6 @@ internal async Task WriteParse(string sql, byte[] asciiName, List= headerLength, "Write buffer too small for Bind header"); - await Flush(async, cancellationToken).ConfigureAwait(false); - } - var formatCodesSum = 0; var paramsLength = 0; for (var paramIndex = 0; 
paramIndex < parameters.Count; paramIndex++) @@ -197,8 +196,15 @@ internal async Task WriteBind( sizeof(short) + // Number of result format codes sizeof(short) * (unknownResultTypeList?.Length ?? 1); // Result format codes - writeBuffer.WriteByte(FrontendMessageCode.Bind); - writeBuffer.WriteInt32(messageLength - 1); + WriteBuffer.StartMessage(messageLength); + if (WriteBuffer.WriteSpaceLeft < headerLength) + { + Debug.Assert(WriteBuffer.Size >= headerLength, "Write buffer too small for Bind header"); + await Flush(async, cancellationToken).ConfigureAwait(false); + } + + WriteBuffer.WriteByte(FrontendMessageCode.Bind); + WriteBuffer.WriteInt32(messageLength - 1); Debug.Assert(portal == string.Empty); writeBuffer.WriteByte(0); // Portal is always empty @@ -269,6 +275,7 @@ internal Task WriteClose(StatementOrPortal type, byte[] asciiName, bool async, C asciiName.Length + sizeof(byte); // Statement or portal name plus null terminator var writeBuffer = WriteBuffer; + writeBuffer.StartMessage(len); if (writeBuffer.WriteSpaceLeft < len) return FlushAndWrite(len, type, asciiName, async, cancellationToken); @@ -296,14 +303,17 @@ internal async Task WriteQuery(string sql, bool async, CancellationToken cancell { var queryByteLen = TextEncoding.GetByteCount(sql); + var len = sizeof(byte) + + sizeof(int) + // Message length (including self excluding code) + queryByteLen + // Query byte length + sizeof(byte); + + WriteBuffer.StartMessage(len); if (WriteBuffer.WriteSpaceLeft < 1 + 4) await Flush(async, cancellationToken).ConfigureAwait(false); WriteBuffer.WriteByte(FrontendMessageCode.Query); - WriteBuffer.WriteInt32( - sizeof(int) + // Message length (including self excluding code) - queryByteLen + // Query byte length - sizeof(byte)); // Null terminator + WriteBuffer.WriteInt32(len - 1); await WriteBuffer.WriteString(sql, queryByteLen, async, cancellationToken).ConfigureAwait(false); if (WriteBuffer.WriteSpaceLeft < 1) @@ -316,6 +326,7 @@ internal async Task WriteCopyDone(bool 
async, CancellationToken cancellationToke const int len = sizeof(byte) + // Message code sizeof(int); // Length + WriteBuffer.StartMessage(len); if (WriteBuffer.WriteSpaceLeft < len) await Flush(async, cancellationToken).ConfigureAwait(false); @@ -331,6 +342,7 @@ internal async Task WriteCopyFail(bool async, CancellationToken cancellationToke sizeof(int) + // Length sizeof(byte); // Error message is always empty (only a null terminator) + WriteBuffer.StartMessage(len); if (WriteBuffer.WriteSpaceLeft < len) await Flush(async, cancellationToken).ConfigureAwait(false); @@ -348,6 +360,7 @@ internal void WriteCancelRequest(int backendProcessId, int backendSecretKey) Debug.Assert(backendProcessId != 0); + WriteBuffer.StartMessage(len); if (WriteBuffer.WriteSpaceLeft < len) Flush(false).GetAwaiter().GetResult(); @@ -362,6 +375,7 @@ internal void WriteTerminate() const int len = sizeof(byte) + // Message code sizeof(int); // Length + WriteBuffer.StartMessage(len); if (WriteBuffer.WriteSpaceLeft < len) Flush(false).GetAwaiter().GetResult(); @@ -374,6 +388,7 @@ internal void WriteSslRequest() const int len = sizeof(int) + // Length sizeof(int); // SSL request code + WriteBuffer.StartMessage(len); if (WriteBuffer.WriteSpaceLeft < len) Flush(false).GetAwaiter().GetResult(); @@ -381,6 +396,19 @@ internal void WriteSslRequest() WriteBuffer.WriteInt32(80877103); } + internal void WriteGSSEncryptRequest() + { + const int len = sizeof(int) + // Length + sizeof(int); // GSSEnc request code + + WriteBuffer.StartMessage(len); + if (WriteBuffer.WriteSpaceLeft < len) + Flush(false).GetAwaiter().GetResult(); + + WriteBuffer.WriteInt32(len); + WriteBuffer.WriteInt32(80877104); + } + internal void WriteStartup(Dictionary parameters) { const int protocolVersion3 = 3 << 16; // 196608 @@ -394,6 +422,7 @@ internal void WriteStartup(Dictionary parameters) NpgsqlWriteBuffer.UTF8Encoding.GetByteCount(kvp.Value) + 1; // Should really never happen, just in case + WriteBuffer.StartMessage(len); if 
(len > WriteBuffer.Size) throw new Exception("Startup message bigger than buffer"); @@ -417,8 +446,10 @@ internal void WriteStartup(Dictionary parameters) internal async Task WritePassword(byte[] payload, int offset, int count, bool async, CancellationToken cancellationToken = default) { + WriteBuffer.StartMessage(sizeof(byte) + sizeof(int) + count); if (WriteBuffer.WriteSpaceLeft < sizeof(byte) + sizeof(int)) await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false); + WriteBuffer.WriteByte(FrontendMessageCode.Password); WriteBuffer.WriteInt32(sizeof(int) + count); @@ -441,6 +472,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes sizeof(int) + // Initial response length (initialResponse?.Length ?? 0); // Initial response payload + WriteBuffer.StartMessage(len); if (WriteBuffer.WriteSpaceLeft < len) await WriteBuffer.Flush(async, cancellationToken).ConfigureAwait(false); @@ -464,6 +496,7 @@ internal async Task WriteSASLInitialResponse(string mechanism, byte[] initialRes internal Task WritePregenerated(byte[] data, bool async = false, CancellationToken cancellationToken = default) { + WriteBuffer.StartMessage(data.Length); if (WriteBuffer.WriteSpaceLeft < data.Length) return FlushAndWrite(data, async, cancellationToken); diff --git a/src/Npgsql/Internal/NpgsqlConnector.OldAuth.cs b/src/Npgsql/Internal/NpgsqlConnector.OldAuth.cs deleted file mode 100644 index 6d60251773..0000000000 --- a/src/Npgsql/Internal/NpgsqlConnector.OldAuth.cs +++ /dev/null @@ -1,153 +0,0 @@ -using System; -using System.IO; -using System.Net; -using System.Net.Security; -using System.Security.Cryptography; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.BackendMessages; -using static Npgsql.Util.Statics; - -namespace Npgsql.Internal; - - -partial class NpgsqlConnector -{ -#if !NET7_0_OR_GREATER - internal async Task AuthenticateGSS(bool async) - { - var targetName = $"{KerberosServiceName}/{Host}"; - - 
using var negotiateStream = new NegotiateStream(new GSSPasswordMessageStream(this), true); - try - { - if (async) - await negotiateStream.AuthenticateAsClientAsync(CredentialCache.DefaultNetworkCredentials, targetName).ConfigureAwait(false); - else - negotiateStream.AuthenticateAsClient(CredentialCache.DefaultNetworkCredentials, targetName); - } - catch (AuthenticationCompleteException) - { - return; - } - catch (IOException e) when (e.InnerException is AuthenticationCompleteException) - { - return; - } - catch (IOException e) when (e.InnerException is PostgresException) - { - throw e.InnerException; - } - - throw new NpgsqlException("NegotiateStream.AuthenticateAsClient completed unexpectedly without signaling success"); - } - - /// - /// This Stream is placed between NegotiateStream and the socket's NetworkStream (or SSLStream). It intercepts - /// traffic and performs the following operations: - /// * Outgoing messages are framed in PostgreSQL's PasswordMessage, and incoming are stripped of it. - /// * NegotiateStream frames payloads with a 5-byte header, which PostgreSQL doesn't understand. This header is - /// stripped from outgoing messages and added to incoming ones. - /// - /// - /// See https://referencesource.microsoft.com/#System/net/System/Net/_StreamFramer.cs,16417e735f0e9530,references - /// - sealed class GSSPasswordMessageStream : Stream - { - readonly NpgsqlConnector _connector; - int _leftToWrite; - int _leftToRead, _readPos; - byte[]? 
_readBuf; - - internal GSSPasswordMessageStream(NpgsqlConnector connector) - => _connector = connector; - - public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - => Write(buffer, offset, count, true, cancellationToken); - - public override void Write(byte[] buffer, int offset, int count) - => Write(buffer, offset, count, false).GetAwaiter().GetResult(); - - async Task Write(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - if (_leftToWrite == 0) - { - // We're writing the frame header, which contains the payload size. - _leftToWrite = (buffer[3] << 8) | buffer[4]; - - buffer[0] = 22; - if (buffer[1] != 1) - throw new NotSupportedException($"Received frame header major v {buffer[1]} (different from 1)"); - if (buffer[2] != 0) - throw new NotSupportedException($"Received frame header minor v {buffer[2]} (different from 0)"); - - // In case of payload data in the same buffer just after the frame header - if (count == 5) - return; - count -= 5; - offset += 5; - } - - if (count > _leftToWrite) - throw new NpgsqlException($"NegotiateStream trying to write {count} bytes but according to frame header we only have {_leftToWrite} left!"); - await _connector.WritePassword(buffer, offset, count, async, cancellationToken).ConfigureAwait(false); - await _connector.Flush(async, cancellationToken).ConfigureAwait(false); - _leftToWrite -= count; - } - - public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) - => Read(buffer, offset, count, true, cancellationToken); - - public override int Read(byte[] buffer, int offset, int count) - => Read(buffer, offset, count, false).GetAwaiter().GetResult(); - - async Task Read(byte[] buffer, int offset, int count, bool async, CancellationToken cancellationToken = default) - { - if (_leftToRead == 0) - { - var response = ExpectAny(await 
_connector.ReadMessage(async).ConfigureAwait(false), _connector); - if (response.AuthRequestType == AuthenticationRequestType.AuthenticationOk) - throw new AuthenticationCompleteException(); - var gssMsg = response as AuthenticationGSSContinueMessage; - if (gssMsg == null) - throw new NpgsqlException($"Received unexpected authentication request message {response.AuthRequestType}"); - _readBuf = gssMsg.AuthenticationData; - _leftToRead = gssMsg.AuthenticationData.Length; - _readPos = 0; - buffer[0] = 22; - buffer[1] = 1; - buffer[2] = 0; - buffer[3] = (byte)((_leftToRead >> 8) & 0xFF); - buffer[4] = (byte)(_leftToRead & 0xFF); - return 5; - } - - if (count > _leftToRead) - throw new NpgsqlException($"NegotiateStream trying to read {count} bytes but according to frame header we only have {_leftToRead} left!"); - count = Math.Min(count, _leftToRead); - Array.Copy(_readBuf!, _readPos, buffer, offset, count); - _leftToRead -= count; - return count; - } - - public override void Flush() { } - - public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void SetLength(long value) => throw new NotSupportedException(); - - public override bool CanRead => true; - public override bool CanWrite => true; - public override bool CanSeek => false; - public override long Length => throw new NotSupportedException(); - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - } - - sealed class AuthenticationCompleteException : Exception { } -#endif -} diff --git a/src/Npgsql/Internal/NpgsqlConnector.cs b/src/Npgsql/Internal/NpgsqlConnector.cs index c3726180a1..63b26f3878 100644 --- a/src/Npgsql/Internal/NpgsqlConnector.cs +++ b/src/Npgsql/Internal/NpgsqlConnector.cs @@ -1,8 +1,10 @@ using System; using System.Buffers; +using System.Buffers.Binary; using System.Collections.Generic; using System.Data; using System.Diagnostics; +using 
System.Diagnostics.CodeAnalysis; using System.IO; using System.Net; using System.Net.Security; @@ -14,21 +16,21 @@ using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Npgsql.BackendMessages; using Npgsql.Util; -using static Npgsql.Util.Statics; -using System.Transactions; using Microsoft.Extensions.Logging; using Npgsql.Properties; +using static Npgsql.Util.Statics; + namespace Npgsql.Internal; /// /// Represents a connection to a PostgreSQL backend. Unlike NpgsqlConnection objects, which are /// exposed to users, connectors are internal to Npgsql and are recycled by the connection pool. /// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public sealed partial class NpgsqlConnector { #region Fields and Properties @@ -53,12 +55,14 @@ public sealed partial class NpgsqlConnector /// public NpgsqlConnectionStringBuilder Settings { get; } - Action? ClientCertificatesCallback { get; } - RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; } + Action? SslClientAuthenticationOptionsCallback { get; } + #pragma warning disable CS0618 // ProvidePasswordCallback is obsolete ProvidePasswordCallback? ProvidePasswordCallback { get; } #pragma warning restore CS0618 + Action? NegotiateOptionsCallback { get; } + public Encoding TextEncoding { get; private set; } = default!; /// @@ -112,12 +116,14 @@ internal string InferredUserName /// internal int Id => BackendProcessId; - internal PgSerializerOptions SerializerOptions { get; set; } = default!; + internal NpgsqlDataSource.ReloadableState ReloadableState = null!; /// /// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...). 
/// - public NpgsqlDatabaseInfo DatabaseInfo { get; internal set; } = default!; + public NpgsqlDatabaseInfo DatabaseInfo => ReloadableState.DatabaseInfo; + internal PgSerializerOptions SerializerOptions => ReloadableState.SerializerOptions; + internal IDbTypeResolver? DbTypeResolver => ReloadableState.DbTypeResolver; /// /// The current transaction status for this connector. @@ -172,48 +178,13 @@ internal string InferredUserName /// /// Holds all run-time parameters in raw, binary format for efficient handling without allocations. /// - readonly List<(byte[] Name, byte[] Value)> _rawParameters = new(); + readonly List<(byte[] Name, byte[] Value)> _rawParameters = []; /// /// If this connector was broken, this contains the exception that caused the break. /// volatile Exception? _breakReason; - // Used by replication to change our cancellation behaviour on ColumnStreams. - internal bool LongRunningConnection { get; set; } - - /// - /// - /// Used by the pool to indicate that I/O is currently in progress on this connector, so that another write - /// isn't started concurrently. Note that since we have only one write loop, this is only ever usedto - /// protect against an over-capacity writes into a connector that's currently *asynchronously* writing. - /// - /// - /// It is guaranteed that the currently-executing - /// Specifically, reading may occur - and the connector may even be returned to the pool - before this is - /// released. 
- /// - /// - internal volatile int MultiplexAsyncWritingLock; - - /// - internal void FlagAsNotWritableForMultiplexing() - { - Debug.Assert(Settings.Multiplexing); - Debug.Assert(CommandsInFlightCount > 0 || IsBroken || IsClosed, - $"About to mark multiplexing connector as non-writable, but {nameof(CommandsInFlightCount)} is {CommandsInFlightCount}"); - - Interlocked.Exchange(ref MultiplexAsyncWritingLock, 1); - } - - /// - internal void FlagAsWritableForMultiplexing() - { - Debug.Assert(Settings.Multiplexing); - if (Interlocked.CompareExchange(ref MultiplexAsyncWritingLock, 0, 1) != 1) - throw new Exception("Multiplexing lock was not taken when releasing. Please report a bug."); - } - /// /// A lock that's taken while a cancellation is being delivered; new queries are blocked until the /// cancellation is delivered. This reduces the chance that a cancellation meant for a previous @@ -274,9 +245,15 @@ internal bool PostgresCancellationPerformed internal bool UserCancellationRequested => _userCancellationRequested; internal CancellationToken UserCancellationToken { get; set; } internal bool AttemptPostgresCancellation { get; private set; } - static readonly TimeSpan _cancelImmediatelyTimeout = TimeSpan.FromMilliseconds(-1); + static readonly TimeSpan _cancelImmediatelyTimeout = TimeSpan.Zero; + + static readonly SslApplicationProtocol _alpnProtocol = new("postgresql"); - IDisposable? _certificate; +#pragma warning disable CA1859 + // We're casting to IDisposable to not explicitly reference X509Certificate2 for NativeAOT + // TODO: probably pointless now, needs to be rechecked + List? 
_certificates; +#pragma warning restore CA1859 internal NpgsqlLoggingConfiguration LoggingConfiguration { get; } @@ -331,12 +308,34 @@ internal bool PostgresCancellationPerformed internal NpgsqlConnector(NpgsqlDataSource dataSource, NpgsqlConnection conn) : this(dataSource) { - if (conn.ProvideClientCertificatesCallback is not null) - ClientCertificatesCallback = certs => conn.ProvideClientCertificatesCallback(certs); - if (conn.UserCertificateValidationCallback is not null) - UserCertificateValidationCallback = conn.UserCertificateValidationCallback; - + var sslClientAuthenticationOptionsCallback = conn.SslClientAuthenticationOptionsCallback; #pragma warning disable CS0618 // Obsolete + var provideClientCertificatesCallback = conn.ProvideClientCertificatesCallback; + var userCertificateValidationCallback = conn.UserCertificateValidationCallback; + if (provideClientCertificatesCallback is not null || + userCertificateValidationCallback is not null) + { + if (sslClientAuthenticationOptionsCallback is not null) + throw new NotSupportedException(NpgsqlStrings.SslClientAuthenticationOptionsCallbackWithOtherCallbacksNotSupported); + + sslClientAuthenticationOptionsCallback = options => + { + if (provideClientCertificatesCallback is not null) + { + options.ClientCertificates ??= new X509Certificate2Collection(); + provideClientCertificatesCallback.Invoke(options.ClientCertificates); + } + + if (userCertificateValidationCallback is not null) + { + options.RemoteCertificateValidationCallback = userCertificateValidationCallback; + } + }; + } + + if (sslClientAuthenticationOptionsCallback is not null) + SslClientAuthenticationOptionsCallback = sslClientAuthenticationOptionsCallback; + ProvidePasswordCallback = conn.ProvidePasswordCallback; #pragma warning restore CS0618 } @@ -344,8 +343,7 @@ internal NpgsqlConnector(NpgsqlDataSource dataSource, NpgsqlConnection conn) NpgsqlConnector(NpgsqlConnector connector) : this(connector.DataSource) { - ClientCertificatesCallback = 
connector.ClientCertificatesCallback; - UserCertificateValidationCallback = connector.UserCertificateValidationCallback; + SslClientAuthenticationOptionsCallback = connector.SslClientAuthenticationOptionsCallback; ProvidePasswordCallback = connector.ProvidePasswordCallback; } @@ -361,8 +359,8 @@ internal NpgsqlConnector(NpgsqlDataSource dataSource, NpgsqlConnection conn) TransactionLogger = LoggingConfiguration.TransactionLogger; CopyLogger = LoggingConfiguration.CopyLogger; - ClientCertificatesCallback = dataSource.ClientCertificatesCallback; - UserCertificateValidationCallback = dataSource.UserCertificateValidationCallback; + SslClientAuthenticationOptionsCallback = dataSource.SslClientAuthenticationOptionsCallback; + NegotiateOptionsCallback = dataSource.Configuration.NegotiateOptionsCallback; State = ConnectorState.Closed; TransactionStatus = TransactionStatus.Idle; @@ -371,30 +369,15 @@ internal NpgsqlConnector(NpgsqlDataSource dataSource, NpgsqlConnection conn) _isKeepAliveEnabled = Settings.KeepAlive > 0; if (_isKeepAliveEnabled) - _keepAliveTimer = new Timer(PerformKeepAlive, null, Timeout.Infinite, Timeout.Infinite); + { + using (ExecutionContext.SuppressFlow()) // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever + _keepAliveTimer = new Timer(PerformKeepAlive, null, Timeout.Infinite, Timeout.Infinite); + } DataReader = new NpgsqlDataReader(this); // TODO: Not just for automatic preparation anymore... PreparedStatementManager = new PreparedStatementManager(this); - - if (Settings.Multiplexing) - { - // Note: It's OK for this channel to be unbounded: each command enqueued to it is accompanied by sending - // it to PostgreSQL. If we overload it, a TCP zero window will make us block on the networking side - // anyway. - // Note: the in-flight channel can probably be single-writer, but that doesn't actually do anything - // at this point. 
And we currently rely on being able to complete the channel at any point (from - // Break). We may want to revisit this if an optimized, SingleWriter implementation is introduced. - var commandsInFlightChannel = Channel.CreateUnbounded( - new UnboundedChannelOptions { SingleReader = true }); - CommandsInFlightReader = commandsInFlightChannel.Reader; - CommandsInFlightWriter = commandsInFlightChannel.Writer; - - // TODO: Properly implement this - if (_isKeepAliveEnabled) - throw new NotImplementedException("Keepalive not yet implemented for multiplexing"); - } } #endregion @@ -435,7 +418,7 @@ internal ConnectorState State /// /// Returns whether the connector is open, regardless of any task it is currently performing /// - bool IsConnected => State is not (ConnectorState.Closed or ConnectorState.Connecting or ConnectorState.Broken); + internal bool IsConnected => State is not (ConnectorState.Closed or ConnectorState.Connecting or ConnectorState.Broken); internal bool IsReady => State == ConnectorState.Ready; internal bool IsClosed => State == ConnectorState.Closed; @@ -456,20 +439,30 @@ internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken ca State = ConnectorState.Connecting; LogMessages.OpeningPhysicalConnection(ConnectionLogger, Host, Port, Database, UserFacingConnectionString); - var stopwatch = Stopwatch.StartNew(); + var startOpenTimestamp = Stopwatch.GetTimestamp(); + + Activity? 
activity = null; try { - await OpenCore(this, Settings.SslMode, timeout, async, cancellationToken).ConfigureAwait(false); + var username = await GetUsernameAsync(async, cancellationToken).ConfigureAwait(false); + + activity = NpgsqlActivitySource.PhysicalConnectionOpen(this); + + var gssEncMode = GetGssEncMode(Settings); + + await OpenCore(this, username, Settings.SslMode, gssEncMode, timeout, async, cancellationToken).ConfigureAwait(false); + + if (activity is not null) + NpgsqlActivitySource.Enrich(activity, this); await DataSource.Bootstrap(this, timeout, forceReload: false, async, cancellationToken).ConfigureAwait(false); - Debug.Assert(DataSource.SerializerOptions is not null); - Debug.Assert(DataSource.DatabaseInfo is not null); - SerializerOptions = DataSource.SerializerOptions; - DatabaseInfo = DataSource.DatabaseInfo; + // The connector directly references the current reloadable state reference, to protect it against changes by a concurrent + // ReloadTypes. We update them here before returning the connector from the pool. + ReloadableState = DataSource.CurrentReloadableState; - if (Settings.Pooling && !Settings.Multiplexing && !Settings.NoResetOnClose && DatabaseInfo.SupportsDiscard) + if (Settings.Pooling && Settings is { NoResetOnClose: false } && DatabaseInfo.SupportsDiscard) { _sendResetOnClose = true; GenerateResetMessage(); @@ -477,20 +470,6 @@ internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken ca OpenTimestamp = DateTime.UtcNow; - if (Settings.Multiplexing) - { - // Start an infinite async loop, which processes incoming multiplexing traffic. - // It is intentionally not awaited and will run as long as the connector is alive. - // The CommandsInFlightWriter channel is completed in Cleanup, which should cause this task - // to complete. - _ = Task.Run(MultiplexingReadLoop, CancellationToken.None) - .ContinueWith(t => - { - // Note that we *must* observe the exception if the task is faulted. 
- ConnectionLogger.LogError(t.Exception!, "Exception bubbled out of multiplexing read loop", Id); - }, TaskContinuationOptions.OnlyOnFaulted); - } - if (_isKeepAliveEnabled) { // Start the keep alive mechanism to work by scheduling the timer. @@ -513,7 +492,7 @@ internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken ca { if (async) await DataSource.ConnectionInitializerAsync(tempConnection).ConfigureAwait(false); - else if (!async) + else DataSource.ConnectionInitializer(tempConnection); } finally @@ -526,41 +505,66 @@ internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken ca } } + activity?.Dispose(); + LogMessages.OpenedPhysicalConnection( - ConnectionLogger, Host, Port, Database, UserFacingConnectionString, stopwatch.ElapsedMilliseconds, Id); + ConnectionLogger, Host, Port, Database, UserFacingConnectionString, + (long)Stopwatch.GetElapsedTime(startOpenTimestamp).TotalMilliseconds, Id); } catch (Exception e) { - Break(e); + if (activity is not null) + NpgsqlActivitySource.SetException(activity, e); + Break(e, markHostAsOfflineOnConnecting: true); throw; } static async Task OpenCore( NpgsqlConnector conn, + string username, SslMode sslMode, + GssEncryptionMode gssEncMode, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) { - await conn.RawOpen(sslMode, timeout, async, cancellationToken).ConfigureAwait(false); - - var username = await conn.GetUsernameAsync(async, cancellationToken).ConfigureAwait(false); - - timeout.CheckAndApply(conn); - conn.WriteStartupMessage(username); - await conn.Flush(async, cancellationToken).ConfigureAwait(false); + // If we fail to connect to the socket, there is no reason to retry even if SslMode/GssEncryption allows it + await conn.RawOpen(timeout, async, cancellationToken).ConfigureAwait(false); - using var cancellationRegistration = conn.StartCancellableOperation(cancellationToken, attemptPgCancellation: false); try { + await conn.SetupEncryption(sslMode, 
gssEncMode, timeout, async, cancellationToken).ConfigureAwait(false); + timeout.CheckAndApply(conn); + conn.WriteStartupMessage(username); + await conn.Flush(async, cancellationToken).ConfigureAwait(false); + + using var cancellationRegistration = conn.StartCancellableOperation(cancellationToken, attemptPgCancellation: false); await conn.Authenticate(username, timeout, async, cancellationToken).ConfigureAwait(false); } - catch (PostgresException e) - when (e.SqlState == PostgresErrorCodes.InvalidAuthorizationSpecification && - (sslMode == SslMode.Prefer && conn.IsSecure || sslMode == SslMode.Allow && !conn.IsSecure)) + catch (OperationCanceledException) { - cancellationRegistration.Dispose(); - Debug.Assert(!conn.IsBroken); + throw; + } + // We handle any exception here because on Windows while receiving a response from Postgres + // We might hit connection reset, in which case the actual error will be lost + // And we only read some IO error + // In addition, this behavior mimics libpq, where it retries as long as GssEncryptionMode and SslMode allows it + catch (Exception e) when + // We might also get here OperationCancelledException/TimeoutException + // But it's fine to fall down and retry because we'll immediately exit with the exact same exception + // + // Any error after trying with GSS encryption + (gssEncMode == GssEncryptionMode.Prefer || + // Auth error with/without SSL + (sslMode == SslMode.Prefer && conn.IsSslEncrypted || sslMode == SslMode.Allow && !conn.IsSslEncrypted)) + { + if (gssEncMode == GssEncryptionMode.Prefer) + { + conn.ConnectionLogger.LogTrace(e, "Error while opening physical connection with GSS encryption, retrying without it"); + gssEncMode = GssEncryptionMode.Disable; + } + else + sslMode = sslMode == SslMode.Prefer ? SslMode.Disable : SslMode.Require; conn.Cleanup(); @@ -568,7 +572,9 @@ static async Task OpenCore( // If Allow was specified and we failed (without SSL), retry with SSL await OpenCore( conn, - sslMode == SslMode.Prefer ? 
SslMode.Disable : SslMode.Require, + username, + sslMode, + gssEncMode, timeout, async, cancellationToken).ConfigureAwait(false); @@ -594,6 +600,148 @@ await OpenCore( } } + internal async ValueTask GSSEncrypt(bool async, bool isRequired, CancellationToken cancellationToken) + { + ConnectionLogger.LogTrace("Negotiating GSS encryption"); + + var targetName = $"{KerberosServiceName}/{Host}"; + var clientOptions = new NegotiateAuthenticationClientOptions { TargetName = targetName }; + + NegotiateOptionsCallback?.Invoke(clientOptions); + + var authentication = new NegotiateAuthentication(clientOptions); + + try + { + byte[]? data; + NegotiateAuthenticationStatusCode statusCode; + + try + { + data = authentication.GetOutgoingBlob(ReadOnlySpan.Empty, out statusCode)!; + } + catch (TypeInitializationException) + { + // On UNIX .NET throws TypeInitializationException if it's unable to load the native library + if (isRequired) + throw new NpgsqlException("Unable to load native library to negotiate GSS encryption"); + + return GssEncryptionResult.GetCredentialFailure; + } + + if (statusCode != NegotiateAuthenticationStatusCode.ContinueNeeded) + { + // Unable to retrieve credentials + // If it's required, throw an appropriate exception + if (isRequired) + throw new NpgsqlException($"Unable to negotiate GSS encryption: {statusCode}"); + + return GssEncryptionResult.GetCredentialFailure; + } + + WriteGSSEncryptRequest(); + await Flush(async, cancellationToken).ConfigureAwait(false); + + await ReadBuffer.Ensure(1, async).ConfigureAwait(false); + var response = (char)ReadBuffer.ReadByte(); + + // TODO: Server can respond with an error here + // but according to documentation we shouldn't display this error to the user/application + // since the server has not been authenticated (CVE-2024-10977) + // See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-GSSAPI + switch (response) + { + default: + throw new NpgsqlException($"Received unknown response 
{response} for GSSEncRequest (expecting G or N)"); + case 'N': + if (isRequired) + throw new NpgsqlException("GGS encryption requested. No GSS encryption enabled connection from this host is configured."); + return GssEncryptionResult.NegotiateFailure; + case 'G': + break; + } + + if (ReadBuffer.ReadBytesLeft > 0) + throw new NpgsqlException( + "Additional unencrypted data received after GSS encryption negotiation - this should never happen, and may be an indication of a man-in-the-middle attack."); + + var lengthBuffer = new byte[4]; + + await WriteGssEncryptMessage(async, data, lengthBuffer, cancellationToken).ConfigureAwait(false); + + while (true) + { + if (async) + await _stream.ReadExactlyAsync(lengthBuffer, cancellationToken).ConfigureAwait(false); + else + _stream.ReadExactly(lengthBuffer); + + var messageLength = BitConverter.IsLittleEndian + ? BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref lengthBuffer[0])) + : Unsafe.ReadUnaligned(ref lengthBuffer[0]); + + var buffer = ArrayPool.Shared.Rent(messageLength); + if (async) + await _stream.ReadExactlyAsync(buffer.AsMemory(0, messageLength), cancellationToken).ConfigureAwait(false); + else + _stream.ReadExactly(buffer.AsSpan(0, messageLength)); + + data = authentication.GetOutgoingBlob(buffer.AsSpan(0, messageLength), out statusCode); + ArrayPool.Shared.Return(buffer, clearArray: true); + if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded) + throw new NpgsqlException($"Error while negotiating GSS encryption: {statusCode}"); + + // TODO: the code below is the copy from GSS/SSPI auth + // It's unknown whether it holds true here or not + + // We might get NegotiateAuthenticationStatusCode.Completed but the data will not be null + // This can happen if it's the first cycle, in which case we have to send that data to complete handshake (#4888) + if (data is null) + { + Debug.Assert(statusCode == 
NegotiateAuthenticationStatusCode.Completed); + break; + } + + await WriteGssEncryptMessage(async, data, lengthBuffer, cancellationToken).ConfigureAwait(false); + } + + _stream = new GSSStream(_stream, authentication); + ReadBuffer.Underlying = _stream; + WriteBuffer.Underlying = _stream; + IsGssEncrypted = true; + authentication = null; + + ConnectionLogger.LogTrace("GSS encryption successful"); + return GssEncryptionResult.Success; + + async ValueTask WriteGssEncryptMessage(bool async, byte[] data, byte[] lengthBuffer, CancellationToken cancellationToken) + { + BinaryPrimitives.WriteInt32BigEndian(lengthBuffer, data.Length); + + if (async) + { + await _stream.WriteAsync(lengthBuffer, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(data, cancellationToken).ConfigureAwait(false); + await _stream.FlushAsync(cancellationToken).ConfigureAwait(false); + } + else + { + _stream.Write(lengthBuffer); + _stream.Write(data); + _stream.Flush(); + } + } + } + catch (Exception e) when (e is not OperationCanceledException) + { + throw new NpgsqlException("Exception while performing GSS encryption", e); + } + finally + { + authentication?.Dispose(); + } + } + internal async ValueTask QueryDatabaseState( NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken = default) { @@ -648,8 +796,9 @@ void WriteStartupMessage(string username) if (Settings.Database is not null) startupParams["database"] = Settings.Database; - if (Settings.ApplicationName?.Length > 0) - startupParams["application_name"] = Settings.ApplicationName; + var applicationName = Settings.ApplicationName ?? 
PostgresEnvironment.AppName; + if (applicationName?.Length > 0) + startupParams["application_name"] = applicationName; if (Settings.SearchPath?.Length > 0) startupParams["search_path"] = Settings.SearchPath; @@ -718,7 +867,7 @@ async ValueTask GetUsernameAsyncInternal() } } - async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + async Task RawOpen(NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) { try { @@ -727,6 +876,8 @@ async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, Cancellat else Connect(timeout); + ConnectionLogger.LogTrace("Socket connected to {Host}:{Port}", Host, Port); + _baseStream = new NetworkStream(_socket, true); _stream = _baseStream; @@ -746,36 +897,8 @@ async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, Cancellat timeout.CheckAndApply(this); - IsSecure = false; - - if ((sslMode is SslMode.Prefer && DataSource.TransportSecurityHandler.SupportEncryption) || - sslMode is SslMode.Require or SslMode.VerifyCA or SslMode.VerifyFull) - { - WriteSslRequest(); - await Flush(async, cancellationToken).ConfigureAwait(false); - - await ReadBuffer.Ensure(1, async).ConfigureAwait(false); - var response = (char)ReadBuffer.ReadByte(); - timeout.CheckAndApply(this); - - switch (response) - { - default: - throw new NpgsqlException($"Received unknown response {response} for SSLRequest (expecting S or N)"); - case 'N': - if (sslMode != SslMode.Prefer) - throw new NpgsqlException("SSL connection requested. 
No SSL enabled connection from this host is configured."); - break; - case 'S': - await DataSource.TransportSecurityHandler.NegotiateEncryption(async, this, sslMode, timeout).ConfigureAwait(false); - break; - } - - if (ReadBuffer.ReadBytesLeft > 0) - throw new NpgsqlException("Additional unencrypted data received after SSL negotiation - this should never happen, and may be an indication of a man-in-the-middle attack."); - } - - ConnectionLogger.LogTrace("Socket connected to {Host}:{Port}", Host, Port); + IsSslEncrypted = false; + IsGssEncrypted = false; } catch { @@ -792,8 +915,129 @@ async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, Cancellat } } - internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout, bool async) + async Task SetupEncryption(SslMode sslMode, GssEncryptionMode gssEncryptionMode, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) { + var gssEncryptResult = await TryNegotiateGssEncryption(gssEncryptionMode, async, cancellationToken).ConfigureAwait(false); + if (gssEncryptResult == GssEncryptionResult.Success) + return; + + // TryNegotiateGssEncryption should already throw a much more meaningful exception + // if GSS encryption is required but for some reason we can't negotiate it. + // But since we have to return a specific result instead of generic true/false + // To make absolutely sure we didn't miss anything, recheck again + if (gssEncryptionMode == GssEncryptionMode.Require) + throw new NpgsqlException($"Unable to negotiate GSS encryption: {gssEncryptResult}"); + + timeout.CheckAndApply(this); + + if (GetSslNegotiation(Settings) == SslNegotiation.Direct) + { + // We already check that in NpgsqlConnectionStringBuilder.PostProcessAndValidate, but since we also allow environment variables... 
+ if (Settings.SslMode is not SslMode.Require and not SslMode.VerifyCA and not SslMode.VerifyFull) + throw new ArgumentException("SSL Mode has to be Require or higher to be used with direct SSL Negotiation"); + if (gssEncryptResult == GssEncryptionResult.NegotiateFailure) + { + // We can be here only if it's fallback from preferred (but failed) gss encryption + // In this case, direct encryption isn't going to work anymore, so we throw a bogus exception to retry again without gss + // Alternatively, we can instead just go with the usual route of writing SslRequest, ignoring direct ssl + // But this is how libpq works + Debug.Assert(gssEncryptionMode == GssEncryptionMode.Prefer); + // The exception message doesn't matter since we're going to retry again + throw new NpgsqlException(); + } + + await DataSource.TransportSecurityHandler.NegotiateEncryption(async, this, sslMode, timeout, cancellationToken).ConfigureAwait(false); + if (ReadBuffer.ReadBytesLeft > 0) + throw new NpgsqlException("Additional unencrypted data received after SSL negotiation - this should never happen, and may be an indication of a man-in-the-middle attack."); + } + else if ((sslMode is SslMode.Prefer && DataSource.TransportSecurityHandler.SupportEncryption) || + sslMode is SslMode.Require or SslMode.VerifyCA or SslMode.VerifyFull) + { + WriteSslRequest(); + await Flush(async, cancellationToken).ConfigureAwait(false); + + await ReadBuffer.Ensure(1, async).ConfigureAwait(false); + var response = (char)ReadBuffer.ReadByte(); + timeout.CheckAndApply(this); + + switch (response) + { + default: + throw new NpgsqlException($"Received unknown response {response} for SSLRequest (expecting S or N)"); + case 'N': + if (sslMode != SslMode.Prefer) + throw new NpgsqlException("SSL connection requested. 
No SSL enabled connection from this host is configured."); + break; + case 'S': + await DataSource.TransportSecurityHandler.NegotiateEncryption(async, this, sslMode, timeout, cancellationToken).ConfigureAwait(false); + break; + } + + if (ReadBuffer.ReadBytesLeft > 0) + throw new NpgsqlException("Additional unencrypted data received after SSL negotiation - this should never happen, and may be an indication of a man-in-the-middle attack."); + } + } + + async ValueTask TryNegotiateGssEncryption(GssEncryptionMode gssEncryptionMode, bool async, CancellationToken cancellationToken) + { + // GetCredentialFailure is essentially a nop (since we didn't send anything over the wire) + // So we can proceed further as if gss encryption wasn't even attempted + if (gssEncryptionMode == GssEncryptionMode.Disable) return GssEncryptionResult.GetCredentialFailure; + + // Same thing as above, though in this case user doesn't require GSS encryption but didn't enable encryption + // Most of the time they're using the default value, in which case also exit without throwing an error + if (gssEncryptionMode == GssEncryptionMode.Prefer && !DataSource.TransportSecurityHandler.SupportEncryption) + return GssEncryptionResult.GetCredentialFailure; + + if (ConnectedEndPoint!.AddressFamily == AddressFamily.Unix) + { + if (gssEncryptionMode == GssEncryptionMode.Prefer) + return GssEncryptionResult.GetCredentialFailure; + + Debug.Assert(gssEncryptionMode == GssEncryptionMode.Require); + throw new NpgsqlException("GSS encryption isn't supported over unix socket"); + } + + return await DataSource.IntegratedSecurityHandler.GSSEncrypt(async, gssEncryptionMode == GssEncryptionMode.Require, this, cancellationToken) + .ConfigureAwait(false); + } + + static SslNegotiation GetSslNegotiation(NpgsqlConnectionStringBuilder settings) + { + if (settings.UserProvidedSslNegotiation is { } userProvidedSslNegotiation) + return userProvidedSslNegotiation; + + if (PostgresEnvironment.SslNegotiation is { } 
sslNegotiationEnv) + { + if (Enum.TryParse(sslNegotiationEnv, ignoreCase: true, out var sslNegotiation)) + return sslNegotiation; + } + + // If user hasn't provided the value via connection string or environment variable + // Retrieve the default value from property + return settings.SslNegotiation; + } + + static GssEncryptionMode GetGssEncMode(NpgsqlConnectionStringBuilder settings) + { + if (settings.UserProvidedGssEncMode is { } userProvidedGssEncMode) + return userProvidedGssEncMode; + + if (PostgresEnvironment.GssEncryptionMode is { } gssEncModeEnv) + { + if (Enum.TryParse(gssEncModeEnv, ignoreCase: true, out var gssEncMode)) + return gssEncMode; + } + + // If user hasn't provided the value via connection string or environment variable + // Retrieve the default value from property + return settings.GssEncryptionMode; + } + + internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) + { + ConnectionLogger.LogTrace("Negotiating SSL encryption"); + var clientCertificates = new X509Certificate2Collection(); var certPath = Settings.SslCertificate ?? PostgresEnvironment.SslCert ?? PostgresEnvironment.SslCertDefault; @@ -801,62 +1045,72 @@ internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout, { var password = Settings.SslPassword; - X509Certificate2? cert = null; - if (Path.GetExtension(certPath).ToUpperInvariant() != ".PFX") + if (!string.Equals(Path.GetExtension(certPath), ".pfx", StringComparison.OrdinalIgnoreCase)) { // It's PEM time var keyPath = Settings.SslKey ?? PostgresEnvironment.SslKey ?? 
PostgresEnvironment.SslKeyDefault; - cert = string.IsNullOrEmpty(password) + + // With PEM certificates we might have multiple certificates in a single file + // Where the first one is a leaf (and it has to have a private key) + // And others are intermediate between it and CA cert + // To support this, we first load the leaf certificate with private key + // And then we load everything else including the leaf, but without private key + // And afterwards we just get rid of the duplicate + var firstClientCert = string.IsNullOrEmpty(password) ? X509Certificate2.CreateFromPemFile(certPath, keyPath) : X509Certificate2.CreateFromEncryptedPemFile(certPath, password, keyPath); + clientCertificates.Add(firstClientCert); + + clientCertificates.ImportFromPemFile(certPath); + clientCertificates[1].Dispose(); + clientCertificates.RemoveAt(1); + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) { - // Windows crypto API has a bug with pem certs - // See #3650 - using var previousCert = cert; - cert = new X509Certificate2(cert.Export(X509ContentType.Pkcs12)); + for (var i = 0; i < clientCertificates.Count; i++) + { + var cert = clientCertificates[i]; + + // Windows crypto API has a bug with pem certs + // See #3650 + using var previousCert = cert; + cert = X509CertificateLoader.LoadPkcs12(cert.Export(X509ContentType.Pkcs12), null); + clientCertificates[i] = cert; + } } } - cert ??= new X509Certificate2(certPath, password); - clientCertificates.Add(cert); + // If it's empty, it's probably PFX + if (clientCertificates.Count == 0) + { + var certs = X509CertificateLoader.LoadPkcs12CollectionFromFile(certPath, password); + clientCertificates.AddRange(certs); + } - _certificate = cert; + var certificates = new List(); + foreach (var certificate in clientCertificates) + certificates.Add(certificate); + _certificates = certificates; } try { - ClientCertificatesCallback?.Invoke(clientCertificates); - var checkCertificateRevocation = Settings.CheckCertificateRevocation; 
RemoteCertificateValidationCallback? certificateValidationCallback; - X509Certificate2? caCert; + X509Certificate2Collection? caCerts; string? certRootPath = null; - if (UserCertificateValidationCallback is not null) - { - if (sslMode is SslMode.VerifyCA or SslMode.VerifyFull) - throw new ArgumentException(string.Format(NpgsqlStrings.CannotUseSslVerifyWithUserCallback, sslMode)); - - if (Settings.RootCertificate is not null) - throw new ArgumentException(NpgsqlStrings.CannotUseSslRootCertificateWithUserCallback); - - if (DataSource.TransportSecurityHandler.RootCertificateCallback is not null) - throw new ArgumentException(NpgsqlStrings.CannotUseValidationRootCertificateCallbackWithUserCallback); - - certificateValidationCallback = UserCertificateValidationCallback; - } - else if (sslMode is SslMode.Prefer or SslMode.Require) + if (sslMode is SslMode.Prefer or SslMode.Require) { certificateValidationCallback = SslTrustServerValidation; checkCertificateRevocation = false; } - else if ((caCert = DataSource.TransportSecurityHandler.RootCertificateCallback?.Invoke()) is not null || + else if (((caCerts = DataSource.TransportSecurityHandler.RootCertificatesCallback?.Invoke()) is not null && caCerts.Count > 0) || (certRootPath = Settings.RootCertificate ?? PostgresEnvironment.SslCertRoot ?? PostgresEnvironment.SslCertRootDefault) is not null) { - certificateValidationCallback = SslRootValidation(sslMode == SslMode.VerifyFull, certRootPath, caCert); + certificateValidationCallback = SslRootValidation(sslMode == SslMode.VerifyFull, certRootPath, caCerts); } else if (sslMode == SslMode.VerifyCA) { @@ -868,45 +1122,83 @@ internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout, certificateValidationCallback = SslVerifyFullValidation; } - var host = Host; + SslStreamCertificateContext? 
clientCertificateContext = null; + if (clientCertificates.Count > 0) + { + // SslClientAuthenticationOptions.ClientCertificates only sends trusted certificates or if they have private key + // Which makes us unable to send intermediate certificates + // Work around this by specifying the first certificate as target + // And others as additional + // See https://github.com/dotnet/runtime/issues/26323 + var clientCertificate = clientCertificates[0]; + clientCertificates.RemoveAt(0); + + clientCertificateContext = SslStreamCertificateContext.Create(clientCertificate, clientCertificates); + } -#if !NET8_0_OR_GREATER - // If the host is a valid IP address - replace it with an empty string - // We do that because .NET uses targetHost argument to send SNI to the server - // RFC explicitly prohibits sending an IP address so some servers might fail - // This was already fixed for .NET 8 - // See #5543 for discussion - if (IPAddress.TryParse(host, out _)) - host = string.Empty; -#endif + var host = Host; timeout.CheckAndApply(this); - try + var sslStream = new SslStream(_stream, leaveInnerStreamOpen: false); + + var sslStreamOptions = new SslClientAuthenticationOptions { - var sslStream = new SslStream(_stream, leaveInnerStreamOpen: false, certificateValidationCallback); + TargetHost = host, + ClientCertificateContext = clientCertificateContext, + EnabledSslProtocols = SslProtocols.None, + CertificateRevocationCheckMode = checkCertificateRevocation ? 
X509RevocationMode.Online : X509RevocationMode.NoCheck, + RemoteCertificateValidationCallback = certificateValidationCallback, + ApplicationProtocols = [_alpnProtocol] + }; + + if (SslClientAuthenticationOptionsCallback is not null) + { + SslClientAuthenticationOptionsCallback.Invoke(sslStreamOptions); + + // User changed remote certificate validation callback + // Check whether the change doesn't lead to unexpected behavior + if (sslStreamOptions.RemoteCertificateValidationCallback != certificateValidationCallback) + { + if (sslMode is SslMode.VerifyCA or SslMode.VerifyFull) + throw new ArgumentException(string.Format(NpgsqlStrings.CannotUseSslVerifyWithCustomValidationCallback, sslMode)); + + if (Settings.RootCertificate is not null) + throw new ArgumentException(NpgsqlStrings.CannotUseSslRootCertificateWithCustomValidationCallback); + + if (DataSource.TransportSecurityHandler.RootCertificatesCallback is not null) + throw new ArgumentException(NpgsqlStrings.CannotUseValidationRootCertificateCallbackWithCustomValidationCallback); + } + } + try + { if (async) - await sslStream.AuthenticateAsClientAsync(host, clientCertificates, SslProtocols.None, checkCertificateRevocation).ConfigureAwait(false); + await sslStream.AuthenticateAsClientAsync(sslStreamOptions, cancellationToken).ConfigureAwait(false); else - sslStream.AuthenticateAsClient(host, clientCertificates, SslProtocols.None, checkCertificateRevocation); + sslStream.AuthenticateAsClient(sslStreamOptions); _stream = sslStream; + sslStream = null; } - catch (Exception e) + catch (Exception e) when (e is not OperationCanceledException) { throw new NpgsqlException("Exception while performing SSL handshake", e); } + finally + { + sslStream?.Dispose(); + } ReadBuffer.Underlying = _stream; WriteBuffer.Underlying = _stream; - IsSecure = true; + IsSslEncrypted = true; ConnectionLogger.LogTrace("SSL negotiation successful"); } catch { - _certificate?.Dispose(); - _certificate = null; + _certificates?.ForEach(x => 
x.Dispose()); + _certificates = null; throw; } @@ -914,11 +1206,23 @@ internal async Task NegotiateEncryption(SslMode sslMode, NpgsqlTimeout timeout, void Connect(NpgsqlTimeout timeout) { - // Note that there aren't any timeout-able or cancellable DNS methods - var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath) - ? new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) } - : IPAddressesToEndpoints(Dns.GetHostAddresses(Host), Port); - timeout.Check(); + EndPoint[]? endpoints; + if (NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath)) + { + endpoints = [new UnixDomainSocketEndPoint(socketPath!)]; + } + else + { + // Note that there aren't any timeout-able or cancellable DNS methods + try + { + endpoints = IPAddressesToEndpoints(Dns.GetHostAddresses(Host), Port); + } + catch (SocketException ex) + { + throw new NpgsqlException(ex.Message, ex); + } + } // Give each endpoint an equal share of the remaining time var perEndpointTimeout = -1; // Default to infinity @@ -941,6 +1245,9 @@ void Connect(NpgsqlTimeout timeout) try { + // Some options are not applied after the socket is open, see #6013 + SetSocketOptions(socket); + try { socket.Connect(endpoint); @@ -959,7 +1266,6 @@ void Connect(NpgsqlTimeout timeout) if (write.Count is 0) throw new TimeoutException("Timeout during connection attempt"); socket.Blocking = true; - SetSocketOptions(socket); _socket = socket; ConnectedEndPoint = endpoint; return; @@ -982,28 +1288,45 @@ void Connect(NpgsqlTimeout timeout) async Task ConnectAsync(NpgsqlTimeout timeout, CancellationToken cancellationToken) { - Task GetHostAddressesAsync(CancellationToken ct) => - Dns.GetHostAddressesAsync(Host, ct); - - // Whether the framework and/or the OS platform support Dns.GetHostAddressesAsync cancellation API or they do not, - // we always fake-cancel the operation with the help of TaskTimeoutAndCancellation.ExecuteAsync. 
It stops waiting - // and raises the exception, while the actual task may be left running. - var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath) - ? new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) } - : IPAddressesToEndpoints(await TaskTimeoutAndCancellation.ExecuteAsync(GetHostAddressesAsync, timeout, cancellationToken).ConfigureAwait(false), - Port); - - // Give each IP an equal share of the remaining time - var perIpTimespan = default(TimeSpan); - var perIpTimeout = timeout; - if (timeout.IsSet) + EndPoint[] endpoints; + if (NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath)) + { + endpoints = [new UnixDomainSocketEndPoint(socketPath)]; + } + else { - perIpTimespan = new TimeSpan(timeout.CheckAndGetTimeLeft().Ticks / endpoints.Length); - perIpTimeout = new NpgsqlTimeout(perIpTimespan); + IPAddress[] ipAddresses = []; + using var combinedCts = timeout.IsSet ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken) : null; + combinedCts?.CancelAfter(timeout.CheckAndGetTimeLeft()); + var combinedToken = combinedCts?.Token ?? cancellationToken; + try + { + ipAddresses = await Dns.GetHostAddressesAsync(Host, combinedToken).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + cancellationToken.ThrowIfCancellationRequested(); + Debug.Assert(timeout.HasExpired); + ThrowHelper.ThrowNpgsqlExceptionWithInnerTimeoutException("The operation has timed out"); + } + catch (SocketException ex) + { + throw new NpgsqlException(ex.Message, ex); + } + + endpoints = IPAddressesToEndpoints(ipAddresses, Port); } + // Give each endpoint an equal share of the remaining time + var perEndpointTimeout = default(TimeSpan); + if (timeout.IsSet) + perEndpointTimeout = timeout.CheckAndGetTimeLeft() / endpoints.Length; + for (var i = 0; i < endpoints.Length; i++) { + var endpointTimeout = timeout.IsSet ? 
new NpgsqlTimeout(perEndpointTimeout) : timeout; + Debug.Assert(timeout.IsSet == endpointTimeout.IsSet); + var endpoint = endpoints[i]; ConnectionLogger.LogTrace("Attempting to connect to {Endpoint}", endpoint); var protocolType = @@ -1014,8 +1337,14 @@ Task GetHostAddressesAsync(CancellationToken ct) => var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType); try { - await OpenSocketConnectionAsync(socket, endpoint, perIpTimeout, cancellationToken).ConfigureAwait(false); + // Some options are not applied after the socket is open, see #6013 SetSocketOptions(socket); + + using var combinedCts = endpointTimeout.IsSet ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken) : null; + combinedCts?.CancelAfter(endpointTimeout.CheckAndGetTimeLeft()); + var combinedToken = combinedCts?.Token ?? cancellationToken; + await socket.ConnectAsync(endpoint, combinedToken).ConfigureAwait(false); + _socket = socket; ConnectedEndPoint = endpoint; return; @@ -1035,6 +1364,8 @@ Task GetHostAddressesAsync(CancellationToken ct) => if (e is OperationCanceledException) e = new TimeoutException("Timeout during connection attempt"); + else if (e is NpgsqlException) + e = e.InnerException!; // We throw NpgsqlException for timeouts, wrapping TimeoutException ConnectionLogger.LogTrace(e, "Failed to connect to {Endpoint}", endpoint); @@ -1042,21 +1373,11 @@ Task GetHostAddressesAsync(CancellationToken ct) => throw new NpgsqlException($"Failed to connect to {endpoint}", e); } } - - static Task OpenSocketConnectionAsync(Socket socket, EndPoint endpoint, NpgsqlTimeout perIpTimeout, CancellationToken cancellationToken) - { - // Whether the OS platform supports Socket.ConnectAsync cancellation API or not, - // we always fake-cancel the operation with the help of TaskTimeoutAndCancellation.ExecuteAsync. It stops waiting - // and raises the exception, while the actual task may be left running. 
- Task ConnectAsync(CancellationToken ct) => - socket.ConnectAsync(endpoint, ct).AsTask(); - return TaskTimeoutAndCancellation.ExecuteAsync(ConnectAsync, perIpTimeout, cancellationToken); - } } - IPEndPoint[] IPAddressesToEndpoints(IPAddress[] ipAddresses, int port) + EndPoint[] IPAddressesToEndpoints(IPAddress[] ipAddresses, int port) { - var result = new IPEndPoint[ipAddresses.Length]; + var result = new EndPoint[ipAddresses.Length]; for (var i = 0; i < ipAddresses.Length; i++) result[i] = new IPEndPoint(ipAddresses[i], port); return result; @@ -1073,7 +1394,7 @@ void SetSocketOptions(Socket socket) if (Settings.TcpKeepAlive) socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true); - if (Settings.TcpKeepAliveInterval > 0 && Settings.TcpKeepAliveTime == 0) + if (Settings is { TcpKeepAliveInterval: > 0, TcpKeepAliveTime: 0 }) throw new ArgumentException("If TcpKeepAliveInterval is defined, TcpKeepAliveTime must be defined as well"); if (Settings.TcpKeepAliveTime > 0) { @@ -1090,111 +1411,6 @@ void SetSocketOptions(Socket socket) #endregion - #region I/O - - readonly ChannelReader? CommandsInFlightReader; - internal readonly ChannelWriter? CommandsInFlightWriter; - - internal volatile int CommandsInFlightCount; - - internal ManualResetValueTaskSource ReaderCompleted { get; } = - new() { RunContinuationsAsynchronously = true }; - - async Task MultiplexingReadLoop() - { - Debug.Assert(Settings.Multiplexing); - Debug.Assert(CommandsInFlightReader != null); - - NpgsqlCommand? 
command = null; - var commandsRead = 0; - - try - { - while (await CommandsInFlightReader.WaitToReadAsync().ConfigureAwait(false)) - { - commandsRead = 0; - Debug.Assert(!InTransaction); - - while (CommandsInFlightReader.TryRead(out command)) - { - commandsRead++; - - await ReadBuffer.Ensure(5, true).ConfigureAwait(false); - - // We have a resultset for the command - hand back control to the command (which will - // return it to the user) - command.TraceReceivedFirstResponse(); - ReaderCompleted.Reset(); - command.ExecutionCompletion.SetResult(this); - - // Now wait until that command's reader is disposed. Note that RunContinuationsAsynchronously is - // true, so that the user code calling NpgsqlDataReader.Dispose will not continue executing - // synchronously here. The prevents issues if the code after the next command's execution - // completion blocks. - await new ValueTask(ReaderCompleted, ReaderCompleted.Version).ConfigureAwait(false); - Debug.Assert(!InTransaction); - } - - // Atomically update the commands in-flight counter, and check if it reached 0. If so, the - // connector is idle and can be returned. - // Note that this is racing with over-capacity writing, which can select any connector at any - // time (see MultiplexingWriteLoop), and we must make absolutely sure that if a connector is - // returned to the pool, it is *never* written to unless properly dequeued from the Idle channel. - if (Interlocked.Add(ref CommandsInFlightCount, -commandsRead) == 0) - { - // There's a race condition where the continuation of an asynchronous multiplexing write may not - // have executed yet, and the flush may still be in progress. We know all I/O has already - // been sent - because the reader has already consumed the entire resultset. So we wait until - // the connector's write lock has been released (long waiting will never occur here). 
- SpinWait.SpinUntil(() => MultiplexAsyncWritingLock == 0 || IsBroken); - - ResetReadBuffer(); - DataSource.Return(this); - } - } - - ConnectionLogger.LogTrace("Exiting multiplexing read loop", Id); - } - catch (Exception e) - { - Debug.Assert(IsBroken); - - // Decrement the commands already dequeued from the in-flight counter - Interlocked.Add(ref CommandsInFlightCount, -commandsRead); - - // When a connector is broken, the causing exception is stored on it. We fail commands with - // that exception - rather than the one thrown here - since the break may have happened during - // writing, and we want to bubble that one up. - - // Drain any pending in-flight commands and fail them. Note that some have only been written - // to the buffer, and not sent to the server. - command?.ExecutionCompletion.SetException(_breakReason!); - try - { - while (true) - { - var pendingCommand = await CommandsInFlightReader.ReadAsync().ConfigureAwait(false); - - // TODO: the exception we have here is sometimes just the result of the write loop breaking - // the connector, so it doesn't represent the actual root cause. - pendingCommand.ExecutionCompletion.SetException(new NpgsqlException("A previous command on this connection caused an error requiring all pending commands on this connection to be aborted", _breakReason!)); - } - } - catch (ChannelClosedException) - { - // All good, drained to the channel and failed all commands - } - - // "Return" the connector to the pool to for cleanup (e.g. update total connector count) - DataSource.Return(this); - - ConnectionLogger.LogError(e, "Exception in multiplexing read loop", Id); - } - - Debug.Assert(CommandsInFlightCount == 0); - } - - #endregion #region Frontend message processing @@ -1267,6 +1483,12 @@ internal ValueTask ReadMessage( // We've read all the prepended response. // Allow cancellation to proceed. ReadingPrependedMessagesMRE.Set(); + + // User requested cancellation but it hasn't been performed yet. 
+ // This might happen if the cancellation is requested while we're reading prepended responses + // because we shouldn't cancel them and otherwise might deadlock. + if (UserCancellationRequested && !PostgresCancellationPerformed) + PerformDelayedUserCancellation(); } catch (Exception e) { @@ -1294,7 +1516,7 @@ internal ValueTask ReadMessage( { if (dataRowLoadingMode == DataRowLoadingMode.Skip) { - await ReadBuffer.Skip(len, async).ConfigureAwait(false); + await ReadBuffer.Skip(async, len).ConfigureAwait(false); continue; } } @@ -1461,15 +1683,15 @@ internal ValueTask ReadMessage( var authType = (AuthenticationRequestType)buf.ReadInt32(); return authType switch { - AuthenticationRequestType.AuthenticationOk => AuthenticationOkMessage.Instance, - AuthenticationRequestType.AuthenticationCleartextPassword => AuthenticationCleartextPasswordMessage.Instance, - AuthenticationRequestType.AuthenticationMD5Password => AuthenticationMD5PasswordMessage.Load(buf), - AuthenticationRequestType.AuthenticationGSS => AuthenticationGSSMessage.Instance, - AuthenticationRequestType.AuthenticationSSPI => AuthenticationSSPIMessage.Instance, - AuthenticationRequestType.AuthenticationGSSContinue => AuthenticationGSSContinueMessage.Load(buf, len), - AuthenticationRequestType.AuthenticationSASL => new AuthenticationSASLMessage(buf), - AuthenticationRequestType.AuthenticationSASLContinue => new AuthenticationSASLContinueMessage(buf, len - 4), - AuthenticationRequestType.AuthenticationSASLFinal => new AuthenticationSASLFinalMessage(buf, len - 4), + AuthenticationRequestType.Ok => AuthenticationOkMessage.Instance, + AuthenticationRequestType.CleartextPassword => AuthenticationCleartextPasswordMessage.Instance, + AuthenticationRequestType.MD5Password => AuthenticationMD5PasswordMessage.Load(buf), + AuthenticationRequestType.GSS => AuthenticationGSSMessage.Instance, + AuthenticationRequestType.SSPI => AuthenticationSSPIMessage.Instance, + AuthenticationRequestType.GSSContinue => 
AuthenticationGSSContinueMessage.Load(buf, len), + AuthenticationRequestType.SASL => new AuthenticationSASLMessage(buf), + AuthenticationRequestType.SASLContinue => new AuthenticationSASLContinueMessage(buf, len - 4), + AuthenticationRequestType.SASLFinal => new AuthenticationSASLFinalMessage(buf, len - 4), _ => throw new NotSupportedException($"Authentication method not supported (Received: {authType})") }; @@ -1561,17 +1783,8 @@ void ProcessNewTransactionStatus(TransactionStatus newStatus) switch (newStatus) { case TransactionStatus.Idle: - return; case TransactionStatus.InTransactionBlock: case TransactionStatus.InFailedTransactionBlock: - // In multiplexing mode, we can't support transaction in SQL: the connector must be removed from the - // writable connectors list, otherwise other commands may get written to it. So the user must tell us - // about the transaction via BeginTransaction. - if (Connection is null) - { - Debug.Assert(Settings.Multiplexing); - ThrowHelper.ThrowNotSupportedException("In multiplexing mode, transactions must be started with BeginTransaction"); - } return; case TransactionStatus.Pending: ThrowHelper.ThrowInvalidOperationException($"Internal Npgsql bug: invalid TransactionStatus {nameof(TransactionStatus.Pending)} received, should be frontend-only"); @@ -1595,15 +1808,20 @@ internal void ClearTransaction(Exception? 
disposeReason = null) /// /// Returns whether SSL is being used for the connection /// - internal bool IsSecure { get; private set; } + internal bool IsSslEncrypted { get; private set; } + + /// + /// Returns whether GSS is being used for the connection + /// + internal bool IsGssEncrypted { get; private set; } /// - /// Returns whether SCRAM-SHA256 is being user for the connection + /// Returns whether SCRAM-SHA256 is being used for the connection /// internal bool IsScram { get; private set; } /// - /// Returns whether SCRAM-SHA256-PLUS is being user for the connection + /// Returns whether SCRAM-SHA256-PLUS is being used for the connection /// internal bool IsScramPlus { get; private set; } @@ -1619,19 +1837,14 @@ internal void ClearTransaction(Exception? disposeReason = null) (sender, certificate, chain, sslPolicyErrors) => true; - static RemoteCertificateValidationCallback SslRootValidation(bool verifyFull, string? certRootPath, X509Certificate2? caCertificate) + static RemoteCertificateValidationCallback SslRootValidation(bool verifyFull, string? certRootPath, X509Certificate2Collection? 
caCertificates) => (_, certificate, chain, sslPolicyErrors) => { if (certificate is null || chain is null) return false; - // No errors here - no reason to check further - if (sslPolicyErrors == SslPolicyErrors.None) - return true; - - // That's VerifyCA check and the only error is name mismatch - no reason to check further - if (!verifyFull && sslPolicyErrors == SslPolicyErrors.RemoteCertificateNameMismatch) - return true; + // Even if there was no error while validating, we have to check one more time with the provided certificate + // As this is the exact same behavior as libpq // That's VerifyFull check and we have name mismatch - no reason to check further if (verifyFull && sslPolicyErrors.HasFlag(SslPolicyErrors.RemoteCertificateNameMismatch)) @@ -1641,17 +1854,20 @@ static RemoteCertificateValidationCallback SslRootValidation(bool verifyFull, st if (certRootPath is null) { - Debug.Assert(caCertificate is not null); - certs.Add(caCertificate); + Debug.Assert(caCertificates is { Count: > 0 }); + certs.AddRange(caCertificates); } else { - Debug.Assert(caCertificate is null); + Debug.Assert(caCertificates is null or { Count: > 0 }); if (Path.GetExtension(certRootPath).ToUpperInvariant() != ".PFX") certs.ImportFromPemFile(certRootPath); if (certs.Count == 0) - certs.Add(new X509Certificate2(certRootPath)); + { + // This is not a PEM certificate, probably PFX + certs.Add(X509CertificateLoader.LoadPkcs12FromFile(certRootPath, null)); + } } chain.ChainPolicy.CustomTrustStore.AddRange(certs); @@ -1677,10 +1893,10 @@ internal void ResetCancellation() } } - internal void PerformUserCancellation() + internal void PerformImmediateUserCancellation() { var connection = Connection; - if (connection is null || connection.ConnectorBindingScope == ConnectorBindingScope.Reader || UserCancellationRequested) + if (connection is null || UserCancellationRequested) return; // Take the lock first to make sure there is no concurrent Break. 
@@ -1697,34 +1913,43 @@ internal void PerformUserCancellation() try { - // Wait before we've read all responses for the prepended queries - // as we can't gracefully handle their cancellation. - // Break makes sure that it's going to be set even if we fail while reading them. + // Set the flag first before waiting on ReadingPrependedMessagesMRE. + // That way we're making sure that in case we're racing with ReadingPrependedMessagesMRE.Set + // that it's going to read the new value of the flag and request cancellation + _userCancellationRequested = true; + // Check whether we've read all responses for the prepended queries + // as we can't gracefully handle their cancellation. // We don't wait indefinitely to avoid deadlocks from synchronous CancellationToken.Register // See #5032 if (!ReadingPrependedMessagesMRE.Wait(0)) return; - _userCancellationRequested = true; - - if (AttemptPostgresCancellation && SupportsPostgresCancellation) - { - var cancellationTimeout = Settings.CancellationTimeout; - if (PerformPostgresCancellation() && cancellationTimeout >= 0) - { - if (cancellationTimeout > 0) - { - ReadBuffer.Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); - ReadBuffer.Cts.CancelAfter(cancellationTimeout); - } + PerformUserCancellationUnsynchronized(); + } + finally + { + Monitor.Exit(CancelLock); + } + } - return; - } - } + void PerformDelayedUserCancellation() + { + // Take the lock first to make sure there is no concurrent Break. + // We should be safe to take it as Break only take it to set the state. + lock (SyncObj) + { + // The connector is dead, exit gracefully. + if (!IsConnected) + return; + // The connector is still alive, take the CancelLock before exiting SingleUseLock. + // If a break will happen after, it's going to wait for the cancellation to complete. 
+ Monitor.Enter(CancelLock); + } - ReadBuffer.Timeout = _cancelImmediatelyTimeout; - ReadBuffer.Cts.Cancel(); + try + { + PerformUserCancellationUnsynchronized(); } finally { @@ -1732,6 +1957,29 @@ internal void PerformUserCancellation() } } + void PerformUserCancellationUnsynchronized() + { + if (AttemptPostgresCancellation && SupportsPostgresCancellation) + { + var cancellationTimeout = Settings.CancellationTimeout; + if (PerformPostgresCancellation() && cancellationTimeout >= 0) + { + // TODO: according to docs, we treat 0 timeout as infinite, yet we do not change the actual value + // We should revisit this here and in NpgsqlReadBuffer + if (cancellationTimeout > 0) + { + ReadBuffer.Timeout = TimeSpan.FromMilliseconds(cancellationTimeout); + ReadBuffer.Cts.CancelAfter(cancellationTimeout); + } + + return; + } + } + + ReadBuffer.Timeout = _cancelImmediatelyTimeout; + ReadBuffer.Cts.Cancel(); + } + /// /// Creates another connector and sends a cancel request through it for this connector. This method never throws, but returns /// whether the cancellation attempt failed. @@ -1781,11 +2029,36 @@ internal bool PerformPostgresCancellation() void DoCancelRequest(int backendProcessId, int backendSecretKey) { Debug.Assert(State == ConnectorState.Closed); + var gssEncMode = GetGssEncMode(Settings); try { - RawOpen(Settings.SslMode, new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)), false, CancellationToken.None) - .GetAwaiter().GetResult(); + try + { + var timeout = new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)); + RawOpen(timeout, false, + CancellationToken.None) + .GetAwaiter().GetResult(); + SetupEncryption(Settings.SslMode, gssEncMode, timeout, false, + CancellationToken.None). 
+ GetAwaiter().GetResult(); + } + catch (Exception e) when (gssEncMode == GssEncryptionMode.Prefer) + { + ConnectionLogger.LogTrace(e, "Error while opening physical connection with GSS encryption, retrying without it"); + Cleanup(); + + // If we hit an error with gss encryption + // Retry again without it + var timeout = new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)); + RawOpen(timeout, false, + CancellationToken.None) + .GetAwaiter().GetResult(); + SetupEncryption(Settings.SslMode, GssEncryptionMode.Disable, timeout, false, + CancellationToken.None). + GetAwaiter().GetResult(); + } + WriteCancelRequest(backendProcessId, backendSecretKey); Flush(); @@ -1814,7 +2087,7 @@ internal CancellationTokenRegistration StartCancellableOperation( AttemptPostgresCancellation = attemptPgCancellation; return _cancellationTokenRegistration = - cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformUserCancellation(), this); + cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformImmediateUserCancellation(), this); } /// @@ -1846,34 +2119,26 @@ internal NestedCancellableScope StartNestedCancellableOperation( var currentAttemptPostgresCancellation = AttemptPostgresCancellation; AttemptPostgresCancellation = attemptPgCancellation; - var registration = cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformUserCancellation(), this); + var registration = cancellationToken.Register(static c => ((NpgsqlConnector)c!).PerformImmediateUserCancellation(), this); return new(this, registration, currentUserCancellationToken, currentAttemptPostgresCancellation); } - internal readonly struct NestedCancellableScope : IDisposable + internal readonly struct NestedCancellableScope( + NpgsqlConnector connector, + CancellationTokenRegistration registration, + CancellationToken previousCancellationToken, + bool previousAttemptPostgresCancellation) + : IDisposable { - readonly NpgsqlConnector _connector; - readonly CancellationTokenRegistration 
_registration; - readonly CancellationToken _previousCancellationToken; - readonly bool _previousAttemptPostgresCancellation; - - public NestedCancellableScope(NpgsqlConnector connector, CancellationTokenRegistration registration, CancellationToken previousCancellationToken, bool previousAttemptPostgresCancellation) - { - _connector = connector; - _registration = registration; - _previousCancellationToken = previousCancellationToken; - _previousAttemptPostgresCancellation = previousAttemptPostgresCancellation; - } - public void Dispose() { - if (_connector is null) + if (connector is null) return; - _connector.UserCancellationToken = _previousCancellationToken; - _connector.AttemptPostgresCancellation = _previousAttemptPostgresCancellation; - _registration.Dispose(); + connector.UserCancellationToken = previousCancellationToken; + connector.AttemptPostgresCancellation = previousAttemptPostgresCancellation; + registration.Dispose(); } } @@ -1901,7 +2166,7 @@ internal async Task CloseOngoingOperations(bool async) // therefore vulnerable to the race condition in #615. if (copyOperation is NpgsqlBinaryImporter || copyOperation is NpgsqlCopyTextWriter || - copyOperation is NpgsqlRawCopyStream rawCopyStream && rawCopyStream.CanWrite) + copyOperation is NpgsqlRawCopyStream { CanWrite: true }) { try { @@ -1969,9 +2234,6 @@ internal void Close() LogMessages.ClosedPhysicalConnection(ConnectionLogger, Host, Port, Database, UserFacingConnectionString, Id); } - internal bool TryRemovePendingEnlistedConnector(Transaction transaction) - => DataSource.TryRemovePendingEnlistedConnector(this, transaction); - internal void Return() => DataSource.Return(this); /// @@ -1987,14 +2249,16 @@ internal Exception UnexpectedMessageReceived(BackendMessageCode received) /// Note that fatal errors during the Open phase do *not* pass through here. /// /// The exception that caused the break. + /// Whether we treat host as down, even if we're still connecting to PostgreSQL instance. 
/// The exception given in for chaining calls. - internal Exception Break(Exception reason) + internal Exception Break(Exception reason, bool markHostAsOfflineOnConnecting = false) { Debug.Assert(!IsClosed); Monitor.Enter(SyncObj); - if (State == ConnectorState.Broken) + var state = State; + if (state == ConnectorState.Broken) { // We're already broken. // Exit SingleUseLock to unblock other threads (like cancellation). @@ -2006,11 +2270,6 @@ internal Exception Break(Exception reason) try { - // If we're broken while reading prepended messages - // the cancellation request might still be waiting on the MRE. - // Unblock it. - ReadingPrependedMessagesMRE.Set(); - LogMessages.BreakingConnection(ConnectionLogger, Id, reason); // Note that we may be reading and writing from the same connector concurrently, so safely set @@ -2034,7 +2293,9 @@ internal Exception Break(Exception reason) // Note we only set the cluster to offline and clear the pool if the connection is being broken (we're in this method), // *and* the exception indicates that the PG cluster really is down; the latter includes any IO/timeout issue, // but does not include e.g. authentication failure or timeouts with disabled cancellation. + // We also do not treat host as down if we're still connecting, as we might retry without GSS/TLS if (reason is NpgsqlException { IsTransient: true } ne && + (state != ConnectorState.Connecting || markHostAsOfflineOnConnecting) && (ne.InnerException is not TimeoutException || Settings.CancellationTimeout != -1) || reason is PostgresException pe && PostgresErrorCodes.IsCriticalFailure(pe)) { @@ -2058,11 +2319,9 @@ internal Exception Break(Exception reason) // On the other hand leaving the state Open could indicate to the user that the connection is functional. 
// (see https://github.com/npgsql/npgsql/issues/3705#issuecomment-839908772) Connection = null; - if (connection.ConnectorBindingScope != ConnectorBindingScope.None) - Return(); + Return(); connection.EnlistedTransaction = null; connection.Connector = null; - connection.ConnectorBindingScope = ConnectorBindingScope.None; } connection.FullState = ConnectionState.Broken; @@ -2082,19 +2341,6 @@ void FullCleanup() { lock (CleanupLock) { - if (Settings.Multiplexing) - { - FlagAsNotWritableForMultiplexing(); - - // Note that in multiplexing, this could be called from the read loop, while the write loop is - // writing into the channel. To make sure this race condition isn't a problem, the channel currently - // isn't set up with SingleWriter (since at this point it doesn't do anything). - CommandsInFlightWriter!.Complete(); - - // The connector's read loop has a continuation to observe and log any exception coming out - // (see Open) - } - ConnectionLogger.LogTrace("Cleaning up connector", Id); Cleanup(); @@ -2112,10 +2358,39 @@ void FullCleanup() /// Closes the socket and cleans up client-side resources associated with this connector. /// /// - /// This method doesn't actually perform any meaningful I/O, and therefore is sync-only. + /// This method doesn't actually perform any meaningful I/O (except sending TLS alert), and therefore is sync-only. 
/// void Cleanup() { + var sslStream = _stream as SslStream; + if (sslStream is not null) + { + try + { + // Send close_notify TLS alert to correctly close connection on postgres's side + sslStream.ShutdownAsync().GetAwaiter().GetResult(); + // Theoretically we should do a 0 read here to receive server's close_notify alert + // But overall it doesn't look like it makes much of a difference + } + catch + { + // ignored + } + } + + // After we access SslStream.RemoteCertificate (like for SASLSha256Plus) + // SslStream will no longer dispose it for us automatically + // Which is why we have to do it ourselves before disposing the stream + // As otherwise accessing RemoteCertificate will throw an exception + try + { + sslStream?.RemoteCertificate?.Dispose(); + } + catch + { + // ignored + } + try { _stream?.Dispose(); @@ -2170,15 +2445,18 @@ void Cleanup() PostgresParameters.Clear(); _currentCommand = null; - if (_certificate is not null) - { - _certificate.Dispose(); - _certificate = null; - } + _certificates?.ForEach(x => x.Dispose()); + _certificates = null; } + [MemberNotNull(nameof(_resetWithoutDeallocateMessage))] void GenerateResetMessage() { + // Generate a reset message that resets connection state without using DISCARD ALL. + // This is used in two scenarios: + // 1. When closing a pooled connection that has prepared statements (DISCARD ALL would deallocate them) + // 2. 
When closing a connection within an enlisted System.Transactions transaction (DISCARD ALL cannot + // run inside a transaction block, but its component commands can) var sb = new StringBuilder("SET SESSION AUTHORIZATION DEFAULT;RESET ALL;"); _resetWithoutDeallocateResponseCount = 2; if (DatabaseInfo.SupportsCloseAll) @@ -2220,8 +2498,6 @@ void GenerateResetMessage() /// internal async Task Reset(bool async) { - bool endBindingScope; - // We start user action in case a keeplive happens concurrently, or a concurrent user command (bug) using (StartUserAction(attemptPgCancellation: false)) { @@ -2238,21 +2514,17 @@ internal async Task Reset(bool async) switch (TransactionStatus) { case TransactionStatus.Idle: - // There is an undisposed transaction on multiplexing connection - endBindingScope = Connection?.ConnectorBindingScope == ConnectorBindingScope.Transaction; break; case TransactionStatus.Pending: // BeginTransaction() was called, but was left in the write buffer and not yet sent to server. // Just clear the transaction state. ProcessNewTransactionStatus(TransactionStatus.Idle); ClearTransaction(); - endBindingScope = true; break; case TransactionStatus.InTransactionBlock: case TransactionStatus.InFailedTransactionBlock: await Rollback(async).ConfigureAwait(false); ClearTransaction(); - endBindingScope = true; break; default: ThrowHelper.ThrowInvalidOperationException($"Internal Npgsql bug: unexpected value {TransactionStatus} of enum {nameof(TransactionStatus)}. Please file a bug."); @@ -2277,13 +2549,35 @@ internal async Task Reset(bool async) DataReader.UnbindIfNecessary(); } + } - if (endBindingScope) + /// + /// Called when a pooled connection with an enlisted System.Transactions transaction is closed. + /// Since we're inside a transaction block, we cannot send DISCARD ALL; + /// we prepend a reset message that only includes commands that can safely run within a transaction. 
+ /// + internal void ResetWithinEnlistedTransaction() + { + // We start user action in case a keeplive happens concurrently, or a concurrent user command (bug) + using var _ = StartUserAction(attemptPgCancellation: false); + + // Our buffer may contain unsent prepended messages, so clear it out. + WriteBuffer.Clear(); + PendingPrependedResponses = 0; + + ResetReadBuffer(); + + if (_sendResetOnClose) { - // Connection is null if a connection enlisted in a TransactionScope was closed before the - // TransactionScope completed - the connector is still enlisted, but has no connection. - Connection?.EndBindingScope(ConnectorBindingScope.Transaction); + if (_resetWithoutDeallocateMessage is null) + { + GenerateResetMessage(); + } + + PrependInternalMessage(_resetWithoutDeallocateMessage, _resetWithoutDeallocateResponseCount); } + + DataReader.UnbindIfNecessary(); } /// @@ -2293,7 +2587,6 @@ internal async Task Reset(bool async) [MethodImpl(MethodImplOptions.AggressiveInlining)] void ResetReadBuffer() { - LongRunningConnection = false; if (_origReadBuffer != null) { Debug.Assert(_origReadBuffer.ReadBytesLeft == 0); @@ -2407,10 +2700,11 @@ UserAction DoStartUserAction(ConnectorState newState, NpgsqlCommand? command, StartCancellableOperation(cancellationToken, attemptPgCancellation); - // We reset the ReadBuffer.Timeout for every user action, so it wouldn't leak from the previous query or action + // We reset the ReadBuffer.Timeout and WriteBuffer.Timeout for every user action, so it wouldn't leak from the previous query or action // For example, we might have successfully cancelled the previous query (so the connection is not broken) - // But the next time, we call the Prepare, which doesn't set it's own timeout - ReadBuffer.Timeout = TimeSpan.FromSeconds(command?.CommandTimeout ?? Settings.CommandTimeout); + // But the next time, we call the Prepare, which doesn't set its own timeout + var timeoutSeconds = command?.CommandTimeout ?? 
Settings.CommandTimeout; + ReadBuffer.Timeout = WriteBuffer.Timeout = timeoutSeconds > 0 ? TimeSpan.FromSeconds(timeoutSeconds) : Timeout.InfiniteTimeSpan; return new UserAction(this); } @@ -2545,12 +2839,15 @@ internal async Task Wait(bool async, int timeout, CancellationToken cancel await Flush(async, cancellationToken).ConfigureAwait(false); var keepaliveMs = Settings.KeepAlive * 1000; + var isTimeoutInfinite = timeout <= 0; while (true) { cancellationToken.ThrowIfCancellationRequested(); - var timeoutForKeepalive = _isKeepAliveEnabled && (timeout <= 0 || keepaliveMs < timeout); - ReadBuffer.Timeout = TimeSpan.FromMilliseconds(timeoutForKeepalive ? keepaliveMs : timeout); + var timeoutForKeepalive = _isKeepAliveEnabled && (isTimeoutInfinite || keepaliveMs < timeout); + ReadBuffer.Timeout = timeoutForKeepalive + ? TimeSpan.FromMilliseconds(keepaliveMs) + : isTimeoutInfinite ? Timeout.InfiniteTimeSpan : TimeSpan.FromMilliseconds(timeout); try { var msg = await ReadMessageWithNotifications(async).ConfigureAwait(false); @@ -2569,7 +2866,7 @@ internal async Task Wait(bool async, int timeout, CancellationToken cancel LogMessages.SendingKeepalive(ConnectionLogger, Id); - var keepaliveTime = Stopwatch.StartNew(); + var keepaliveStartTimestamp = Stopwatch.GetTimestamp(); await WriteSync(async, cancellationToken).ConfigureAwait(false); await Flush(async, cancellationToken).ConfigureAwait(false); @@ -2584,7 +2881,7 @@ internal async Task Wait(bool async, int timeout, CancellationToken cancel { msg = await ReadMessageWithNotifications(async).ConfigureAwait(false); } - catch (Exception e) when (e is OperationCanceledException || e is NpgsqlException npgEx && npgEx.InnerException is TimeoutException) + catch (Exception e) when (e is OperationCanceledException || e is NpgsqlException { InnerException: TimeoutException }) { // We're somewhere in the middle of a reading keepalive messages // Breaking the connection, as we've lost protocol sync @@ -2610,7 +2907,11 @@ internal 
async Task Wait(bool async, int timeout, CancellationToken cancel } if (timeout > 0) - timeout -= (keepaliveMs + (int)keepaliveTime.ElapsedMilliseconds); + { + timeout -= (keepaliveMs + (int)Stopwatch.GetElapsedTime(keepaliveStartTimestamp).TotalMilliseconds); + // Make sure we don't accidentally set -1 as a timeout (because it's infinite) + timeout = Math.Max(timeout, 0); + } } } @@ -2680,7 +2981,7 @@ void ReadParameterStatus(ReadOnlySpan incomingName, ReadOnlySpan inc for (var i = 0; i < _rawParameters.Count; i++) { - (var currentName, var currentValue) = _rawParameters[i]; + var (currentName, currentValue) = _rawParameters[i]; if (incomingName.SequenceEqual(currentName)) { if (incomingValue.SequenceEqual(currentValue)) @@ -2707,8 +3008,6 @@ void ReadParameterStatus(ReadOnlySpan incomingName, ReadOnlySpan inc switch (name) { case "standard_conforming_strings": - if (value != "on" && Settings.Multiplexing) - throw Break(new NotSupportedException("standard_conforming_strings must be on with multiplexing")); UseConformingStrings = value == "on"; return; @@ -2743,6 +3042,27 @@ void ReadParameterStatus(ReadOnlySpan incomingName, ReadOnlySpan inc return null; } + internal Activity? TraceCopyStart(string copyCommand, string operation) + { + Activity? activity = null; + if (NpgsqlActivitySource.IsEnabled) + { + var tracingOptions = DataSource.Configuration.TracingOptions; + + if (tracingOptions.CopyOperationFilter?.Invoke(copyCommand) ?? 
true) + { + var spanName = tracingOptions.CopyOperationSpanNameProvider?.Invoke(copyCommand); + activity = NpgsqlActivitySource.CopyStart(copyCommand, this, spanName, operation); + + if (activity != null) + { + tracingOptions.CopyOperationEnrichmentCallback?.Invoke(activity, copyCommand); + } + } + } + return activity; + } + #endregion Misc } @@ -2846,4 +3166,11 @@ enum DataRowLoadingMode Skip } +enum GssEncryptionResult +{ + GetCredentialFailure, + NegotiateFailure, + Success +} + #endregion diff --git a/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs b/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs index fed3f8c165..5c700ac7e3 100644 --- a/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs +++ b/src/Npgsql/Internal/NpgsqlDatabaseInfo.cs @@ -12,14 +12,16 @@ namespace Npgsql.Internal; /// Base class for implementations which provide information about PostgreSQL and PostgreSQL-like databases /// (e.g. type definitions, capabilities...). /// +[Experimental(NpgsqlDiagnostics.DatabaseInfoExperimental)] public abstract class NpgsqlDatabaseInfo { #region Fields - static volatile INpgsqlDatabaseInfoFactory[] Factories = { + static volatile INpgsqlDatabaseInfoFactory[] Factories = + [ new PostgresMinimalDatabaseInfoFactory(), new PostgresDatabaseInfoFactory() - }; + ]; #endregion Fields @@ -114,13 +116,13 @@ public abstract class NpgsqlDatabaseInfo #region Types - readonly List _baseTypesMutable = new(); - readonly List _arrayTypesMutable = new(); - readonly List _rangeTypesMutable = new(); - readonly List _multirangeTypesMutable = new(); - readonly List _enumTypesMutable = new(); - readonly List _compositeTypesMutable = new(); - readonly List _domainTypesMutable = new(); + readonly List _baseTypesMutable = []; + readonly List _arrayTypesMutable = []; + readonly List _rangeTypesMutable = []; + readonly List _multirangeTypesMutable = []; + readonly List _enumTypesMutable = []; + readonly List _compositeTypesMutable = []; + readonly List _domainTypesMutable = []; internal IReadOnlyList 
BaseTypes => _baseTypesMutable; internal IReadOnlyList ArrayTypes => _arrayTypesMutable; @@ -239,9 +241,9 @@ internal void ProcessTypes() ByFullName[type.DataTypeName.Value] = type; // If more than one type exists with the same partial name, we place a null value. // This allows us to detect this case later and force the user to use full names only. - ByName[type.InternalName] = ByName.ContainsKey(type.InternalName) - ? null - : type; + var typeInternalName = type.InternalName; + if (!ByName.TryAdd(typeInternalName, type)) + ByName[typeInternalName] = null; switch (type) { @@ -311,8 +313,7 @@ protected static Version ParseServerVersion(string value) /// public static void RegisterFactory(INpgsqlDatabaseInfoFactory factory) { - if (factory == null) - throw new ArgumentNullException(nameof(factory)); + ArgumentNullException.ThrowIfNull(factory); var factories = new INpgsqlDatabaseInfoFactory[Factories.Length + 1]; factories[0] = factory; @@ -338,11 +339,11 @@ internal static async Task Load(NpgsqlConnector conn, Npgsql // For tests internal static void ResetFactories() - => Factories = new INpgsqlDatabaseInfoFactory[] - { + => Factories = + [ new PostgresMinimalDatabaseInfoFactory(), new PostgresDatabaseInfoFactory() - }; + ]; #endregion Factory management diff --git a/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs b/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs index 78e17d4a82..66f53503ed 100644 --- a/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs +++ b/src/Npgsql/Internal/NpgsqlReadBuffer.Stream.cs @@ -17,6 +17,7 @@ internal sealed class ColumnStream : Stream int _read; bool _canSeek; bool _commandScoped; + bool _consumeOnDispose; /// Does not throw ODE. 
internal int CurrentLength { get; private set; } internal bool IsDisposed { get; private set; } @@ -28,7 +29,7 @@ internal ColumnStream(NpgsqlConnector connector) IsDisposed = true; } - internal void Init(int len, bool canSeek, bool commandScoped) + internal void Init(int len, bool canSeek, bool commandScoped, bool consumeOnDispose = true) { Debug.Assert(!canSeek || _buf.ReadBytesLeft >= len, "Seekable stream constructed but not all data is in buffer (sequential)"); @@ -41,6 +42,7 @@ internal void Init(int len, bool canSeek, bool commandScoped) _read = 0; _commandScoped = commandScoped; + _consumeOnDispose = consumeOnDispose; IsDisposed = false; } @@ -71,8 +73,7 @@ public override long Position } set { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), "Non - negative number required."); + ArgumentOutOfRangeException.ThrowIfNegative(value); Seek(value, SeekOrigin.Begin); } } @@ -83,8 +84,7 @@ public override long Seek(long offset, SeekOrigin origin) if (!_canSeek) throw new NotSupportedException(); - if (offset > int.MaxValue) - throw new ArgumentOutOfRangeException(nameof(offset), "Stream length must be non-negative and less than 2^31 - 1 - origin."); + ArgumentOutOfRangeException.ThrowIfGreaterThan(offset, int.MaxValue); const string seekBeforeBegin = "An attempt was made to move the position before the beginning of the stream."; @@ -189,28 +189,28 @@ public override void Write(byte[] buffer, int offset, int count) => throw new NotSupportedException(); void CheckDisposed() - { - if (IsDisposed) - ThrowHelper.ThrowObjectDisposedException(nameof(ColumnStream)); - } + => ObjectDisposedException.ThrowIf(IsDisposed, this); protected override void Dispose(bool disposing) - => DisposeAsync(disposing, async: false).GetAwaiter().GetResult(); + { + if (disposing) + DisposeCore(async: false).GetAwaiter().GetResult(); + } public override ValueTask DisposeAsync() - => DisposeAsync(disposing: true, async: true); + => DisposeCore(async: true); - async 
ValueTask DisposeAsync(bool disposing, bool async) + async ValueTask DisposeCore(bool async) { - if (IsDisposed || !disposing) + if (IsDisposed) return; - if (!_connector.IsBroken) + if (_consumeOnDispose && !_connector.IsBroken) { var pos = _buf.CumulativeReadPosition - _startPos; var remaining = checked((int)(CurrentLength - pos)); if (remaining > 0) - await _buf.Skip(remaining, async).ConfigureAwait(false); + await _buf.Skip(async, remaining).ConfigureAwait(false); } IsDisposed = true; @@ -219,13 +219,10 @@ async ValueTask DisposeAsync(bool disposing, bool async) static void ValidateArguments(byte[] buffer, int offset, int count) { - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentOutOfRangeException(nameof(offset)); - if (count < 0) - throw new ArgumentOutOfRangeException(nameof(count)); + ArgumentNullException.ThrowIfNull(buffer); + ArgumentOutOfRangeException.ThrowIfNegative(offset); + ArgumentOutOfRangeException.ThrowIfNegative(count); if (buffer.Length - offset < count) - throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + ThrowHelper.ThrowArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); } } diff --git a/src/Npgsql/Internal/NpgsqlReadBuffer.cs b/src/Npgsql/Internal/NpgsqlReadBuffer.cs index e9dee8dfc4..0f91bad9d4 100644 --- a/src/Npgsql/Internal/NpgsqlReadBuffer.cs +++ b/src/Npgsql/Internal/NpgsqlReadBuffer.cs @@ -2,6 +2,7 @@ using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.Sockets; using System.Runtime.CompilerServices; @@ -17,6 +18,7 @@ namespace Npgsql.Internal; /// A buffer used by Npgsql to read data from the socket efficiently. 
/// Provides methods which decode different values types and tracks the current position. /// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] sealed partial class NpgsqlReadBuffer : IDisposable { #region Fields and Properties @@ -34,25 +36,16 @@ sealed partial class NpgsqlReadBuffer : IDisposable internal ResettableCancellationTokenSource Cts { get; } readonly MetricsReporter? _metricsReporter; - TimeSpan _preTranslatedTimeout = TimeSpan.Zero; - /// /// Timeout for sync and async reads /// internal TimeSpan Timeout { - get => _preTranslatedTimeout; + get => Cts.Timeout; set { - if (_preTranslatedTimeout != value) + if (Cts.Timeout != value) { - _preTranslatedTimeout = value; - - if (value == TimeSpan.Zero) - value = InfiniteTimeSpan; - else if (value < TimeSpan.Zero) - value = TimeSpan.Zero; - Debug.Assert(_underlyingSocket != null); _underlyingSocket.ReceiveTimeout = (int)value.TotalMilliseconds; @@ -111,10 +104,7 @@ internal NpgsqlReadBuffer( Encoding relaxedTextEncoding, bool usePool = false) { - if (size < MinimumSize) - { - throw new ArgumentOutOfRangeException(nameof(size), size, "Buffer size must be at least " + MinimumSize); - } + ArgumentOutOfRangeException.ThrowIfLessThan(size, MinimumSize); Connector = connector!; // TODO: Clean this up Underlying = stream; @@ -134,6 +124,9 @@ internal NpgsqlReadBuffer( #region I/O + public void Ensure(int count) + => Ensure(count, async: false, readingNotifications: false).GetAwaiter().GetResult(); + public ValueTask Ensure(int count, bool async) => Ensure(count, async, readingNotifications: false); @@ -155,18 +148,13 @@ int ReadWithTimeout(Span buffer) catch (Exception ex) { var connector = Connector; - switch (ex) - { - // Note that mono throws SocketException with the wrong error (see #1330) - case IOException e when (e.InnerException as SocketException)?.SocketErrorCode == - (Type.GetType("Mono.Runtime") == null ? 
SocketError.TimedOut : SocketError.WouldBlock): + if (ex is IOException { InnerException: SocketException { SocketErrorCode: SocketError.TimedOut } }) { // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. // TODO: As an optimization, we can still attempt to send a cancellation request, but after // that immediately break the connection - if (connector.AttemptPostgresCancellation && - !connector.PostgresCancellationPerformed && - connector.PerformPostgresCancellation()) + if (connector is { AttemptPostgresCancellation: true, PostgresCancellationPerformed: false } + && connector.PerformPostgresCancellation()) { // Note that if the cancellation timeout is negative, we flow down and break the // connection immediately. @@ -184,16 +172,15 @@ int ReadWithTimeout(Span buffer) // Break the connection, bubbling up the correct exception type (cancellation or timeout) throw connector.Break(CreateCancelException(connector)); } - default: - throw connector.Break(new NpgsqlException("Exception while reading from stream", ex)); - } + + throw connector.Break(new NpgsqlException("Exception while reading from stream", ex)); } } } async ValueTask ReadWithTimeoutAsync(Memory buffer, CancellationToken cancellationToken) { - var finalCt = Timeout != TimeSpan.Zero + var finalCt = Timeout != InfiniteTimeSpan ? Cts.Start(cancellationToken) : Cts.Reset(); @@ -224,8 +211,7 @@ async ValueTask ReadWithTimeoutAsync(Memory buffer, CancellationToken // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. 
// TODO: As an optimization, we can still attempt to send a cancellation request, but after // that immediately break the connection - if (connector.AttemptPostgresCancellation && - !connector.PostgresCancellationPerformed && + if (connector is { AttemptPostgresCancellation: true, PostgresCancellationPerformed: false } && connector.PerformPostgresCancellation()) { // Note that if the cancellation timeout is negative, we flow down and break the @@ -294,7 +280,7 @@ static async ValueTask EnsureLong( buffer.ReadPosition = 0; } - var finalCt = async && buffer.Timeout != TimeSpan.Zero + var finalCt = async && buffer.Timeout != InfiniteTimeSpan ? buffer.Cts.Start() : buffer.Cts.Reset(); @@ -351,8 +337,7 @@ static async ValueTask EnsureLong( // If we should attempt PostgreSQL cancellation, do it the first time we get a timeout. // TODO: As an optimization, we can still attempt to send a cancellation request, but after // that immediately break the connection - if (connector.AttemptPostgresCancellation && - !connector.PostgresCancellationPerformed && + if (connector is { AttemptPostgresCancellation: true, PostgresCancellationPerformed: false } && connector.PerformPostgresCancellation()) { // Note that if the cancellation timeout is negative, we flow down and break the @@ -412,8 +397,29 @@ internal NpgsqlReadBuffer AllocateOversize(int count) } /// - /// Does not perform any I/O - assuming that the bytes to be skipped are in the memory buffer. + /// Skip a given number of bytes. /// + internal void Skip(int len, bool allowIO) + { + Debug.Assert(len >= 0); + + if (allowIO && len > ReadBytesLeft) + { + len -= ReadBytesLeft; + while (len > Size) + { + ResetPosition(); + Ensure(Size); + len -= Size; + } + ResetPosition(); + Ensure(len); + } + + Debug.Assert(ReadBytesLeft >= len); + ReadPosition += len; + } + internal void Skip(int len) { Debug.Assert(ReadBytesLeft >= len); @@ -423,7 +429,7 @@ internal void Skip(int len) /// /// Skip a given number of bytes. 
/// - public async Task Skip(int len, bool async) + public async Task Skip(bool async, int len) { Debug.Assert(len >= 0); @@ -522,19 +528,13 @@ public ulong ReadUInt64() return result; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public float ReadSingle() { CheckBounds(sizeof(float)); - float result; - if (BitConverter.IsLittleEndian) - { - var value = BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])); - result = Unsafe.As(ref value); - } - else - result = Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + var result = BitConverter.IsLittleEndian + ? BitConverter.Int32BitsToSingle(BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition]))) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); ReadPosition += sizeof(float); return result; } @@ -543,14 +543,9 @@ public float ReadSingle() public double ReadDouble() { CheckBounds(sizeof(double)); - double result; - if (BitConverter.IsLittleEndian) - { - var value = BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition])); - result = Unsafe.As(ref value); - } - else - result = Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); + var result = BitConverter.IsLittleEndian + ? BitConverter.Int64BitsToDouble(BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref Buffer[ReadPosition]))) + : Unsafe.ReadUnaligned(ref Buffer[ReadPosition]); ReadPosition += sizeof(double); return result; } @@ -670,11 +665,11 @@ static async ValueTask ReadAsyncLong(NpgsqlReadBuffer buffer, bool commandS } ColumnStream? 
_lastStream; - public ColumnStream CreateStream(int len, bool canSeek) + public ColumnStream CreateStream(int len, bool canSeek, bool consumeOnDispose = true) { if (_lastStream is not { IsDisposed: true }) _lastStream = new ColumnStream(Connector); - _lastStream.Init(len, canSeek, !Connector.LongRunningConnection); + _lastStream.Init(len, canSeek, Connector.Settings.ReplicationMode == ReplicationMode.Off, consumeOnDispose); return _lastStream; } diff --git a/src/Npgsql/Internal/NpgsqlWriteBuffer.cs b/src/Npgsql/Internal/NpgsqlWriteBuffer.cs index 316495eaa8..ea0b4b265a 100644 --- a/src/Npgsql/Internal/NpgsqlWriteBuffer.cs +++ b/src/Npgsql/Internal/NpgsqlWriteBuffer.cs @@ -28,6 +28,8 @@ sealed class NpgsqlWriteBuffer : IDisposable internal Stream Underlying { private get; set; } readonly Socket? _underlyingSocket; + internal bool MessageLengthValidation { get; set; } = true; + readonly ResettableCancellationTokenSource _timeoutCts; readonly MetricsReporter? _metricsReporter; @@ -76,9 +78,14 @@ internal PgWriter GetWriter(NpgsqlDatabaseInfo typeCatalog, FlushMode flushMode internal int WritePosition; + int _messageBytesFlushed; + int? _messageLength; + bool _disposed; readonly PgWriter _pgWriter; + Span Span => Buffer.AsSpan(WritePosition, WriteSpaceLeft); + /// /// The minimum buffer size possible. /// @@ -96,8 +103,7 @@ internal NpgsqlWriteBuffer( int size, Encoding textEncoding) { - if (size < MinimumSize) - throw new ArgumentOutOfRangeException(nameof(size), size, "Buffer size must be at least " + MinimumSize); + ArgumentOutOfRangeException.ThrowIfLessThan(size, MinimumSize); Connector = connector!; // TODO: Clean this up; only null when used from PregeneratedMessages, where we don't care. 
Underlying = stream; @@ -131,6 +137,8 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul WritePosition = pos; } else if (WritePosition == 0) return; + else + AdvanceMessageBytesFlushed(WritePosition); var finalCt = async && Timeout > TimeSpan.Zero ? _timeoutCts.Start(cancellationToken) @@ -151,28 +159,26 @@ public async Task Flush(bool async, CancellationToken cancellationToken = defaul Underlying.Flush(); } } - catch (Exception e) + catch (Exception ex) { // Stopping twice (in case the previous Stop() call succeeded) doesn't hurt. // Not stopping will cause an assertion failure in debug mode when we call Start() the next time. // We can't stop in a finally block because Connector.Break() will dispose the buffer and the contained // _timeoutCts _timeoutCts.Stop(); - switch (e) + switch (ex) { // User requested the cancellation - case OperationCanceledException _ when (cancellationToken.IsCancellationRequested): - throw Connector.Break(e); + case OperationCanceledException when cancellationToken.IsCancellationRequested: + throw Connector.Break(ex); // Read timeout - case OperationCanceledException _: - // Note that mono throws SocketException with the wrong error (see #1330) - case IOException _ when (e.InnerException as SocketException)?.SocketErrorCode == - (Type.GetType("Mono.Runtime") == null ? SocketError.TimedOut : SocketError.WouldBlock): - Debug.Assert(e is OperationCanceledException ? async : !async); + case OperationCanceledException: + case IOException { InnerException: SocketException { SocketErrorCode: SocketError.TimedOut } }: + Debug.Assert(ex is OperationCanceledException ? 
async : !async); throw Connector.Break(new NpgsqlException("Exception while writing to stream", new TimeoutException("Timeout during writing attempt"))); } - throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); + throw Connector.Break(new NpgsqlException("Exception while writing to stream", ex)); } NpgsqlEventSource.Log.BytesWritten(WritePosition); _metricsReporter?.ReportBytesWritten(WritePosition); @@ -199,15 +205,19 @@ internal void DirectWrite(ReadOnlySpan buffer) Debug.Assert(WritePosition == 5); WritePosition = 1; - WriteInt32(buffer.Length + 4); + WriteInt32(checked(buffer.Length + 4)); WritePosition = 5; _copyMode = false; + StartMessage(5); Flush(); _copyMode = true; WriteCopyDataHeader(); // And ready the buffer after the direct write completes } else + { Debug.Assert(WritePosition == 0); + AdvanceMessageBytesFlushed(buffer.Length); + } try { @@ -230,15 +240,19 @@ internal async Task DirectWrite(ReadOnlyMemory memory, bool async, Cancell Debug.Assert(WritePosition == 5); WritePosition = 1; - WriteInt32(memory.Length + 4); + WriteInt32(checked(memory.Length + 4)); WritePosition = 5; _copyMode = false; + StartMessage(5); await Flush(async, cancellationToken).ConfigureAwait(false); _copyMode = true; WriteCopyDataHeader(); // And ready the buffer after the direct write completes } else + { Debug.Assert(WritePosition == 0); + AdvanceMessageBytesFlushed(memory.Length); + } try { @@ -306,37 +320,6 @@ public void WriteInt64(long value) WritePosition += sizeof(long); } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteUInt64(ulong value) - { - CheckBounds(); - Unsafe.WriteUnaligned(ref Buffer[WritePosition], BitConverter.IsLittleEndian ? 
BinaryPrimitives.ReverseEndianness(value) : value); - WritePosition += sizeof(ulong); - } - - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteSingle(float value) - { - CheckBounds(); - if (BitConverter.IsLittleEndian) - Unsafe.WriteUnaligned(ref Buffer[WritePosition], BinaryPrimitives.ReverseEndianness(Unsafe.As(ref value))); - else - Unsafe.WriteUnaligned(ref Buffer[WritePosition], value); - WritePosition += sizeof(float); - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public void WriteDouble(double value) - { - CheckBounds(); - if (BitConverter.IsLittleEndian) - Unsafe.WriteUnaligned(ref Buffer[WritePosition], BinaryPrimitives.ReverseEndianness(Unsafe.As(ref value))); - else - Unsafe.WriteUnaligned(ref Buffer[WritePosition], value); - WritePosition += sizeof(double); - } - [Conditional("DEBUG")] unsafe void CheckBounds() where T : unmanaged { @@ -348,46 +331,48 @@ static void ThrowNotSpaceLeft() => ThrowHelper.ThrowInvalidOperationException("There is not enough space left in the buffer."); public Task WriteString(string s, int byteLen, bool async, CancellationToken cancellationToken = default) - => WriteString(s, s.Length, byteLen, async, cancellationToken); - - public Task WriteString(string s, int charLen, int byteLen, bool async, CancellationToken cancellationToken = default) { if (byteLen <= WriteSpaceLeft) { - WriteString(s, charLen); + WriteString(s); return Task.CompletedTask; } - return WriteStringLong(this, async, s, charLen, byteLen, cancellationToken); + return WriteStringLong(this, async, s, byteLen, cancellationToken); - static async Task WriteStringLong(NpgsqlWriteBuffer buffer, bool async, string s, int charLen, int byteLen, CancellationToken cancellationToken) + static async Task WriteStringLong(NpgsqlWriteBuffer buffer, bool async, string s, int byteLen, CancellationToken cancellationToken) { Debug.Assert(byteLen > buffer.WriteSpaceLeft); if (byteLen <= buffer.Size) { // String can fit entirely in an empty 
buffer. Flush and retry rather than - // going into the partial writing flow below (which requires ToCharArray()) + // going into the partial writing flow below await buffer.Flush(async, cancellationToken).ConfigureAwait(false); - buffer.WriteString(s, charLen); + buffer.WriteString(s); } else { - var charPos = 0; - while (true) + var encoder = buffer._textEncoder; + encoder.Reset(); + var data = s.AsMemory(); + var minBufferSize = buffer.TextEncoding.GetMaxByteCount(1); + + bool completed; + do { - buffer.WriteStringChunked(s, charPos, charLen - charPos, true, out var charsUsed, out var completed); - if (completed) - break; - await buffer.Flush(async, cancellationToken).ConfigureAwait(false); - charPos += charsUsed; - } + if (buffer.WriteSpaceLeft < minBufferSize) + await buffer.Flush(async, cancellationToken).ConfigureAwait(false); + encoder.Convert(data.Span, buffer.Span, flush: true, out var charsUsed, out var bytesUsed, out completed); + data = data.Slice(charsUsed); + buffer.WritePosition += bytesUsed; + } while (!completed); } } } - public void WriteString(string s, int len = 0) + public void WriteString(string s) { Debug.Assert(TextEncoding.GetByteCount(s) <= WriteSpaceLeft); - WritePosition += TextEncoding.GetBytes(s, 0, len == 0 ? s.Length : len, Buffer, WritePosition); + WritePosition += TextEncoding.GetBytes(s, 0, s.Length, Buffer, WritePosition); } public void WriteBytes(ReadOnlySpan buf) @@ -440,30 +425,6 @@ static async Task WriteBytesLong(NpgsqlWriteBuffer buffer, bool async, ReadOnlyM } } - public async Task WriteStreamRaw(Stream stream, int count, bool async, CancellationToken cancellationToken = default) - { - while (count > 0) - { - if (WriteSpaceLeft == 0) - await Flush(async, cancellationToken).ConfigureAwait(false); - try - { - var read = async - ? 
await stream.ReadAsync(Buffer, WritePosition, Math.Min(WriteSpaceLeft, count), cancellationToken).ConfigureAwait(false) - : stream.Read(Buffer, WritePosition, Math.Min(WriteSpaceLeft, count)); - if (read == 0) - throw new EndOfStreamException(); - WritePosition += read; - count -= read; - } - catch (Exception e) - { - throw Connector.Break(new NpgsqlException("Exception while writing to stream", e)); - } - } - Debug.Assert(count == 0); - } - public void WriteNullTerminatedString(string s) { AssertASCIIOnly(s); @@ -482,47 +443,6 @@ public void WriteNullTerminatedString(byte[] s) #endregion - #region Write Complex - - internal void WriteStringChunked(char[] chars, int charIndex, int charCount, - bool flush, out int charsUsed, out bool completed) - { - if (WriteSpaceLeft < _textEncoder.GetByteCount(chars, charIndex, char.IsHighSurrogate(chars[charIndex]) ? 2 : 1, flush: false)) - { - charsUsed = 0; - completed = false; - return; - } - - _textEncoder.Convert(chars, charIndex, charCount, Buffer, WritePosition, WriteSpaceLeft, - flush, out charsUsed, out var bytesUsed, out completed); - WritePosition += bytesUsed; - } - - internal unsafe void WriteStringChunked(string s, int charIndex, int charCount, - bool flush, out int charsUsed, out bool completed) - { - int bytesUsed; - - fixed (char* sPtr = s) - fixed (byte* bufPtr = Buffer) - { - if (WriteSpaceLeft < _textEncoder.GetByteCount(sPtr + charIndex, char.IsHighSurrogate(*(sPtr + charIndex)) ? 
2 : 1, flush: false)) - { - charsUsed = 0; - completed = false; - return; - } - - _textEncoder.Convert(sPtr + charIndex, charCount, bufPtr + WritePosition, WriteSpaceLeft, - flush, out charsUsed, out bytesUsed, out completed); - } - - WritePosition += bytesUsed; - } - - #endregion - #region Copy internal void StartCopyMode() @@ -567,9 +487,50 @@ public void Dispose() #region Misc + internal void StartMessage(int messageLength) + { + if (!MessageLengthValidation) + return; + + if (_messageLength is not null && _messageBytesFlushed != _messageLength && WritePosition != -_messageBytesFlushed + _messageLength) + Throw(); + + // Add negative WritePosition to compensate for previous message(s) written without flushing. + _messageBytesFlushed = -WritePosition; + _messageLength = messageLength; + + void Throw() + { + throw Connector.Break(new OverflowException("Did not write the amount of bytes the message length specified")); + } + } + + void AdvanceMessageBytesFlushed(int count) + { + if (!MessageLengthValidation) + return; + + if (count < 0 || _messageLength is null || (long)_messageBytesFlushed + count > _messageLength) + Throw(); + + _messageBytesFlushed += count; + + void Throw() + { + ArgumentOutOfRangeException.ThrowIfNegative(count); + + if (_messageLength is null) + throw Connector.Break(new InvalidOperationException("No message was started")); + + if ((long)_messageBytesFlushed + count > _messageLength) + throw Connector.Break(new OverflowException("Tried to write more bytes than the message length specified")); + } + } + internal void Clear() { WritePosition = 0; + _messageLength = null; } /// diff --git a/src/Npgsql/Internal/PgBufferedConverter.cs b/src/Npgsql/Internal/PgBufferedConverter.cs index 2bed7ffa3c..beced6d589 100644 --- a/src/Npgsql/Internal/PgBufferedConverter.cs +++ b/src/Npgsql/Internal/PgBufferedConverter.cs @@ -5,10 +5,9 @@ namespace Npgsql.Internal; -public abstract class PgBufferedConverter : PgConverter 
+[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public abstract class PgBufferedConverter(bool customDbNullPredicate = false) : PgConverter(customDbNullPredicate) { - protected PgBufferedConverter(bool customDbNullPredicate = false) : base(customDbNullPredicate) { } - protected abstract T ReadCore(PgReader reader); protected abstract void WriteCore(PgWriter writer, T value); @@ -17,8 +16,8 @@ public override Size GetSize(SizeContext context, T value, ref object? writeStat public sealed override T Read(PgReader reader) { - // We check IsAtStart first to speed up primitive reads. - if (!reader.IsAtStart && reader.ShouldBufferCurrent()) + // We check FieldAtStart to speed up simple value reads, as field level buffering was handled by reader.StartRead() already. + if (!reader.FieldAtStart && reader.ShouldBufferCurrent()) ThrowIORequired(reader.CurrentBufferRequirement); return ReadCore(reader); diff --git a/src/Npgsql/Internal/PgConverter.cs b/src/Npgsql/Internal/PgConverter.cs index 3317361516..627c4dc979 100644 --- a/src/Npgsql/Internal/PgConverter.cs +++ b/src/Npgsql/Internal/PgConverter.cs @@ -2,11 +2,13 @@ using System.Buffers; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public abstract class PgConverter { internal DbNullPredicate DbNullPredicateKind { get; } @@ -176,24 +178,20 @@ static class PgConverterExtensions return size; } -} - -interface IResumableRead -{ - bool Supported { get; } -} -public readonly struct SizeContext -{ - [SetsRequiredMembers] - public SizeContext(DataFormat format, Size bufferRequirement) + internal static PgConverter UnsafeDowncast(this PgConverter converter) { - Format = format; - BufferRequirement = bufferRequirement; + // Justification: avoid perf cost of casting to a known base class type per read/write, see callers. 
+ Debug.Assert(converter is PgConverter); + return Unsafe.As>(converter); } +} - public required Size BufferRequirement { get; init; } - public DataFormat Format { get; } +[method: SetsRequiredMembers] +public readonly struct SizeContext(DataFormat format, Size bufferRequirement) +{ + public required Size BufferRequirement { get; init; } = bufferRequirement; + public DataFormat Format { get; } = format; } class MultiWriteState : IDisposable diff --git a/src/Npgsql/Internal/PgConverterResolver.cs b/src/Npgsql/Internal/PgConverterResolver.cs index baee09d58e..5fbe699017 100644 --- a/src/Npgsql/Internal/PgConverterResolver.cs +++ b/src/Npgsql/Internal/PgConverterResolver.cs @@ -1,8 +1,10 @@ using System; +using System.Diagnostics.CodeAnalysis; using Npgsql.Internal.Postgres; namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public abstract class PgConverterResolver { private protected PgConverterResolver() { } diff --git a/src/Npgsql/Internal/PgReader.cs b/src/Npgsql/Internal/PgReader.cs index 54672f92e8..5da3ea7681 100644 --- a/src/Npgsql/Internal/PgReader.cs +++ b/src/Npgsql/Internal/PgReader.cs @@ -10,8 +10,11 @@ namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public class PgReader { + const int UninitializedSentinel = -1; + // We don't want to add a ton of memory pressure for large strings. 
internal const int MaxPreparedTextReaderSize = 1024 * 64; @@ -49,34 +52,39 @@ public class PgReader internal PgReader(NpgsqlReadBuffer buffer) { _buffer = buffer; - _fieldStartPos = -1; - _currentSize = -1; + _fieldStartPos = UninitializedSentinel; + _currentSize = UninitializedSentinel; } - internal long FieldStartPos => _fieldStartPos; - internal int FieldSize => _fieldSize; - internal bool Initialized => _fieldStartPos is not -1; - internal int FieldOffset => (int)(_buffer.CumulativeReadPosition - _fieldStartPos); - internal int FieldRemaining => FieldSize - FieldOffset; + internal bool Initialized => _fieldStartPos is not UninitializedSentinel; + int FieldOffset => (int)(_buffer.CumulativeReadPosition - _fieldStartPos); + int FieldSize => _fieldSize; + int FieldRemaining => FieldSize - FieldOffset; + + internal bool FieldIsDbNull => FieldSize is -1; + internal bool FieldAtStart => FieldOffset is 0; + + internal bool IsFieldConsumed(int offset) => FieldOffset > offset; + + // TODO refactor out + internal long GetFieldStartPos(NpgsqlNestedDataReader nestedDataReader) => _fieldStartPos; + // TODO refactor out + internal int GetFieldOffset(NpgsqlNestedDataReader nestedDataReader) => FieldOffset; - bool HasCurrent => _currentSize is not -1; - int CurrentSize => HasCurrent ? _currentSize : _fieldSize; + internal bool NestedInitialized => _currentSize is not UninitializedSentinel; + int CurrentSize => NestedInitialized ? _currentSize : _fieldSize; public ValueMetadata Current => new() { Size = CurrentSize, Format = _fieldFormat, BufferRequirement = CurrentBufferRequirement }; - public int CurrentRemaining => HasCurrent ? _currentSize - CurrentOffset : FieldRemaining; + public int CurrentRemaining => NestedInitialized ? _currentSize - CurrentOffset : FieldRemaining; - internal Size CurrentBufferRequirement => HasCurrent ? _currentBufferRequirement : _fieldBufferRequirement; + internal Size CurrentBufferRequirement => NestedInitialized ? 
_currentBufferRequirement : _fieldBufferRequirement; int CurrentOffset => FieldOffset - _currentStartPos; - internal bool IsAtStart => FieldOffset is 0; internal bool Resumable => _resumable; - public bool IsResumed => Resumable && CurrentSize != CurrentRemaining; + public bool IsResumed => Resumable && CurrentOffset > 0; ArrayPool ArrayPool => ArrayPool.Shared; - [MemberNotNullWhen(true, nameof(_charsReadReader))] - internal bool IsCharsRead => _charsReadOffset is not null; - // Here for testing purposes internal void BreakConnection() => throw _buffer.Connector.Break(new Exception("Broken")); @@ -98,8 +106,8 @@ void CheckBounds(int count) [MethodImpl(MethodImplOptions.NoInlining)] void Core(int count) { - if (count > FieldRemaining) - ThrowHelper.ThrowInvalidOperationException("Attempt to read past the end of the field."); + if (count > CurrentRemaining) + ThrowHelper.ThrowIndexOutOfRangeException("Attempt to read past the end of the current field size."); } } @@ -193,7 +201,7 @@ public string ReadNullTerminatedString(Encoding encoding) NpgsqlReadBuffer.ColumnStream GetColumnStream(bool canSeek = false, int? length = null) { if (length > CurrentRemaining) - throw new ArgumentOutOfRangeException(nameof(length), "Length is larger than the current remaining value size"); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(length), "Length is larger than the current remaining value size"); _requiresCleanup = true; // This will cause any previously handed out StreamReaders etc to throw, as intended. @@ -202,7 +210,7 @@ NpgsqlReadBuffer.ColumnStream GetColumnStream(bool canSeek = false, int? 
length length ??= CurrentRemaining; CheckBounds(length.GetValueOrDefault()); - return _userActiveStream = _buffer.CreateStream(length.GetValueOrDefault(), canSeek && length <= _buffer.ReadBytesLeft); + return _userActiveStream = _buffer.CreateStream(length.GetValueOrDefault(), canSeek && length <= _buffer.ReadBytesLeft, consumeOnDispose: false); } public TextReader GetTextReader(Encoding encoding) @@ -344,14 +352,15 @@ public async ValueTask> ReadBytesAsync(int count, Cancell public void Rewind(int count) { - // Shut down any streaming going on on the column - DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); + if (CurrentOffset < count) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count), "Attempt to rewind past the current field start."); if (_buffer.ReadPosition < count) - throw new ArgumentOutOfRangeException("Cannot rewind further than the buffer start"); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count), "Attempt to rewind past the buffer start, some of this data is no longer part of the underlying buffer."); - if (CurrentOffset < count) - throw new ArgumentOutOfRangeException("Cannot rewind further than the current field offset"); + // Shut down any streaming going on on the column + if (StreamActive) + DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); _buffer.ReadPosition -= count; } @@ -363,43 +372,31 @@ public void Rewind(int count) /// The stream length, if any async ValueTask DisposeUserActiveStream(bool async) { - if (StreamActive) - { - if (async) - await _userActiveStream.DisposeAsync().ConfigureAwait(false); - else - _userActiveStream.Dispose(); - } - + if (async) + await (_userActiveStream?.DisposeAsync() ?? new()).ConfigureAwait(false); + else + _userActiveStream?.Dispose(); _userActiveStream = null; } - internal bool GetCharsReadInfo(Encoding encoding, out int charsRead, out TextReader reader, out int charsOffset, out ArraySegment? 
buffer) - { - if (!IsCharsRead) - throw new InvalidOperationException("No active chars read"); + internal int CharsRead => _charsRead; + internal bool CharsReadActive => _charsReadOffset is not null; - if (_charsReadReader is null) - { - charsRead = 0; - reader = _charsReadReader = GetTextReader(encoding); - charsOffset = _charsReadOffset ??= 0; - buffer = _charsReadBuffer; - return true; - } + internal void GetCharsReadInfo(Encoding encoding, out int charsRead, out TextReader reader, out int charsOffset, out ArraySegment? buffer) + { + if (!CharsReadActive) + ThrowHelper.ThrowInvalidOperationException("No active chars read"); charsRead = _charsRead; - reader = _charsReadReader; - charsOffset = _charsReadOffset!.Value; + reader = _charsReadReader ??= GetTextReader(encoding); + charsOffset = _charsReadOffset ?? 0; buffer = _charsReadBuffer; - - return false; } - internal void ResetCharsRead(out int charsRead) + internal void RestartCharsRead() { - if (!IsCharsRead) - throw new InvalidOperationException("No active chars read"); + if (!CharsReadActive) + ThrowHelper.ThrowInvalidOperationException("No active chars read"); switch (_charsReadReader) { @@ -411,52 +408,42 @@ internal void ResetCharsRead(out int charsRead) reader.DiscardBufferedData(); break; } - _charsRead = charsRead = 0; + _charsRead = 0; } - internal void AdvanceCharsRead(int charsRead) - { - _charsRead += charsRead; - _charsReadOffset = null; - _charsReadBuffer = null; - } + internal void AdvanceCharsRead(int charsRead) => _charsRead += charsRead; - internal void InitCharsRead(int dataOffset, ArraySegment? buffer, out int? charsRead) + internal void StartCharsRead(int dataOffset, ArraySegment? buffer) { if (!Resumable) - throw new InvalidOperationException("Wasn't initialized as resumed"); + ThrowHelper.ThrowInvalidOperationException("Reader was not initialized as resumable"); - charsRead = _charsReadReader is null ? 
null : _charsRead; _charsReadOffset = dataOffset; _charsReadBuffer = buffer; } - internal PgReader Init(int fieldLength, DataFormat format, bool resumable = false) + internal void EndCharsRead() { - if (Initialized) - { - if (resumable) - { - if (Resumable) - return this; - _resumable = true; - } - else - { - if (!IsAtStart) - ThrowHelper.ThrowInvalidOperationException("Cannot be initialized to be non-resumable until a commit is issued."); - _resumable = false; - } - } + if (!Resumable) + ThrowHelper.ThrowInvalidOperationException("Wasn't initialized as resumed"); - Debug.Assert(!_requiresCleanup, "Reader wasn't properly committed before next init"); + if (!CharsReadActive) + ThrowHelper.ThrowInvalidOperationException("No active chars read"); + + _charsReadOffset = null; + _charsReadBuffer = null; + } + + internal void Init(int fieldSize, DataFormat fieldFormat, bool resumable = false) + { + if (Initialized) + ThrowHelper.ThrowInvalidOperationException("Already initialized"); _fieldStartPos = _buffer.CumulativeReadPosition; - _fieldFormat = format; - _fieldSize = fieldLength; - _resumable = resumable; _fieldConsumed = false; - return this; + _fieldSize = fieldSize; + _fieldFormat = fieldFormat; + _resumable = resumable; } internal void StartRead(Size bufferRequirement) @@ -464,7 +451,11 @@ internal void StartRead(Size bufferRequirement) Debug.Assert(FieldSize >= 0); _fieldBufferRequirement = bufferRequirement; if (ShouldBuffer(bufferRequirement)) - Buffer(bufferRequirement); + BufferNoInlined(bufferRequirement); + + [MethodImpl(MethodImplOptions.NoInlining)] + void BufferNoInlined(Size bufferRequirement) + => Buffer(bufferRequirement); } internal ValueTask StartReadAsync(Size bufferRequirement, CancellationToken cancellationToken) @@ -511,10 +502,10 @@ internal ValueTask EndReadAsync() internal async ValueTask BeginNestedRead(bool async, int size, Size bufferRequirement, CancellationToken cancellationToken = default) { if (size > CurrentRemaining) - throw new 
ArgumentOutOfRangeException(nameof(size), "Cannot begin a read for a larger size than the current remaining size."); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(size), "Cannot begin a read for a larger size than the current remaining size."); if (size < 0) - throw new ArgumentOutOfRangeException(nameof(size), "Cannot be negative"); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(size), "Cannot be negative"); var previousSize = CurrentSize; var previousStartPos = _currentStartPos; @@ -533,32 +524,60 @@ public NestedReadScope BeginNestedRead(int size, Size bufferRequirement) public ValueTask BeginNestedReadAsync(int size, Size bufferRequirement, CancellationToken cancellationToken = default) => BeginNestedRead(async: true, size, bufferRequirement, cancellationToken); - internal void Seek(int offset) + /// Seek origin is the start of Current, e.g. Seek(0) rewinds to the start. + internal int Seek(int offset) { if (CurrentOffset > offset) Rewind(CurrentOffset - offset); else if (CurrentOffset < offset) Consume(offset - CurrentOffset); + + return FieldRemaining; } - internal async ValueTask Consume(bool async, int? count = null, CancellationToken cancellationToken = default) + public void Consume(int? count = null) { if (count <= 0 || FieldSize < 0 || FieldRemaining == 0) return; - var remaining = count ?? CurrentRemaining; - CheckBounds(remaining); + var currentRemaining = CurrentRemaining; + var remaining = count ?? currentRemaining; + + if (count > currentRemaining) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count), "Attempt to read past the end of the current field size."); + + if (StreamActive) + DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); var origOffset = FieldOffset; // A breaking exception unwind from a nested scope should not try to consume its remaining data. 
if (!_buffer.Connector.IsBroken) - await _buffer.Skip(remaining, async).ConfigureAwait(false); + _buffer.Skip(remaining, allowIO: true); Debug.Assert(FieldRemaining == FieldSize - origOffset - remaining); } - public void Consume(int? count = null) => Consume(async: false, count).GetAwaiter().GetResult(); - public ValueTask ConsumeAsync(int? count = null, CancellationToken cancellationToken = default) => Consume(async: true, count, cancellationToken); + public async ValueTask ConsumeAsync(int? count = null, CancellationToken cancellationToken = default) + { + if (count <= 0 || FieldSize < 0 || FieldRemaining == 0) + return; + + var currentRemaining = CurrentRemaining; + var remaining = count ?? currentRemaining; + + if (count > currentRemaining) + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(count), "Attempt to read past the end of the current field size."); + + if (StreamActive) + await DisposeUserActiveStream(async: true).ConfigureAwait(false); + + var origOffset = FieldOffset; + // A breaking exception unwind from a nested scope should not try to consume its remaining data. 
+ if (!_buffer.Connector.IsBroken) + await _buffer.Skip(async:true, remaining).ConfigureAwait(false); + + Debug.Assert(FieldRemaining == FieldSize - origOffset - remaining); + } [MemberNotNullWhen(true, nameof(_userActiveStream))] bool StreamActive => _userActiveStream is { IsDisposed: false }; @@ -568,169 +587,134 @@ internal void ThrowIfStreamActive() ThrowHelper.ThrowInvalidOperationException("A stream is already open for this reader"); } - internal bool CommitHasIO(bool resuming) => Initialized && !resuming && FieldRemaining > 0; + [MethodImpl(MethodImplOptions.NoInlining)] + void Cleanup() + { + if (StreamActive) + DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal void Commit(bool resuming) + if (_pooledArray is not null) + { + ArrayPool.Return(_pooledArray); + _pooledArray = null; + } + + if (_charsReadReader is not null) + { + _charsReadReader.Dispose(); + _charsReadReader = null; + _charsRead = default; + } + + _requiresCleanup = false; + } + + void ResetCurrent() + { + _currentStartPos = 0; + _currentBufferRequirement = default; + _currentSize = UninitializedSentinel; + } + + internal int Restart(bool resumable) { if (!Initialized) - return; + ThrowHelper.ThrowInvalidOperationException("Cannot restart a non-initialized reader."); - if (resuming) + // We resume if the reader was initialized as resumable and we're not explicitly restarting as non-resumable. + // When the field size is DbNullFieldSize (i.e. -1) we're always restarting as resumable, to allow rereading null values endlessly. + if ((Resumable && resumable) || FieldIsDbNull) { - if (!Resumable) - ThrowHelper.ThrowInvalidOperationException("Cannot resume a non-resumable read."); - return; + _resumable = resumable || FieldIsDbNull; + return FieldSize; } - // We don't rely on CurrentRemaining, just to make sure we consume fully in the event of a nested scope not being disposed. 
- // Also shut down any streaming, pooled arrays etc. - if (_requiresCleanup || (!_fieldConsumed && FieldRemaining > 0)) - { - CommitSlow(); + // From this point on we're not resuming, we're resetting any remaining state and rewinding our position. + + // Shut down any streaming and pooling going on on the column. + if (_requiresCleanup) + Cleanup(); + + if (NestedInitialized) + ResetCurrent(); + + _fieldConsumed = false; + _resumable = resumable; + Seek(0); + + Debug.Assert(Initialized); + return FieldSize; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal void Commit() + { + if (!Initialized) return; - } - _fieldStartPos = -1; + // Shut down any streaming and pooling going on on the column. + if (_requiresCleanup) + Cleanup(); + + if (NestedInitialized) + ResetCurrent(); + + // We make sure to fuly consume any FieldRemaining in the event of an exception or a nested scope not being disposed. + Debug.Assert(!NestedInitialized); + if (!_fieldConsumed && FieldRemaining > 0) + Consume(); + + _fieldStartPos = UninitializedSentinel; Debug.Assert(!Initialized); // These will always be re-initialized by Init() // _fieldSize = default; // _fieldFormat = default; // _resumable = default; - // _fieldCompleted = default; - - if (HasCurrent) - { - _currentStartPos = 0; - _currentBufferRequirement = default; - _currentSize = -1; - Debug.Assert(!HasCurrent); - } - - [MethodImpl(MethodImplOptions.NoInlining)] - void CommitSlow() - { - // Shut down any streaming and pooling going on on the column. 
- if (_requiresCleanup) - { - if (StreamActive) - DisposeUserActiveStream(async: false).GetAwaiter().GetResult(); - - if (_pooledArray is not null) - { - ArrayPool.Return(_pooledArray); - _pooledArray = null; - } - - if (_charsReadReader is not null) - { - _charsReadReader.Dispose(); - _charsReadReader = null; - _charsRead = default; - } - _requiresCleanup = false; - } - - Consume(async: false, count: FieldRemaining).GetAwaiter().GetResult(); - - _fieldStartPos = -1; - Debug.Assert(!Initialized); - - // These will always be re-initialized by Init() - // _fieldSize = default; - // _fieldFormat = default; - // _resumable = default; - // _fieldCompleted = default; - - if (HasCurrent) - { - _currentStartPos = 0; - _currentBufferRequirement = default; - _currentSize = -1; - Debug.Assert(!HasCurrent); - } - } + // _fieldConsumed = default; } [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal ValueTask CommitAsync(bool resuming) + internal ValueTask CommitAsync() { if (!Initialized) return new(); - if (resuming) - { - if (!Resumable) - ThrowHelper.ThrowInvalidOperationException("Cannot resume a non-resumable read."); - return new(); - } + // Shut down any streaming and pooling going on on the column. + if (_requiresCleanup) + Cleanup(); + + if (NestedInitialized) + ResetCurrent(); - // We don't rely on CurrentRemaining, just to make sure we consume fully in the event of a nested scope not being disposed. - // Also shut down any streaming, pooled arrays etc. - if (_requiresCleanup || (!_fieldConsumed && FieldRemaining > 0)) - return CommitSlow(); + // We make sure to fuly consume any FieldRemaining in the event of an exception or a nested scope not being disposed. 
+ Debug.Assert(!NestedInitialized); + if (!_fieldConsumed && FieldRemaining > 0) + return CommitAsync(); - _fieldStartPos = -1; + _fieldStartPos = UninitializedSentinel; Debug.Assert(!Initialized); // These will always be re-initialized by Init() // _fieldSize = default; // _fieldFormat = default; // _resumable = default; - // _fieldCompleted = default; - - if (HasCurrent) - { - _currentStartPos = 0; - _currentBufferRequirement = default; - _currentSize = -1; - Debug.Assert(!HasCurrent); - } + // _fieldConsumed = default; return new(); - async ValueTask CommitSlow() + async ValueTask CommitAsync() { - // Shut down any streaming and pooling going on on the column. - if (_requiresCleanup) - { - if (StreamActive) - await DisposeUserActiveStream(async: true).ConfigureAwait(false); - - if (_pooledArray is not null) - { - ArrayPool.Return(_pooledArray); - _pooledArray = null; - } - - if (_charsReadReader is not null) - { - _charsReadReader.Dispose(); - _charsReadReader = null; - _charsRead = default; - } - _requiresCleanup = false; - } - - await Consume(async: true, count: FieldRemaining).ConfigureAwait(false); - - _fieldStartPos = -1; + await ConsumeAsync().ConfigureAwait(false); + + _fieldStartPos = UninitializedSentinel; Debug.Assert(!Initialized); // These will always be re-initialized by Init() // _fieldSize = default; // _fieldFormat = default; // _resumable = default; - // _fieldCompleted = default; - - if (HasCurrent) - { - _currentStartPos = 0; - _currentBufferRequirement = default; - _currentSize = -1; - Debug.Assert(!HasCurrent); - } + // _fieldConsumed = default; } } @@ -760,30 +744,25 @@ public bool ShouldBuffer(Size bufferRequirement) => ShouldBuffer(GetBufferRequirementByteCount(bufferRequirement)); public bool ShouldBuffer(int byteCount) { - return _buffer.ReadBytesLeft < byteCount && ShouldBufferSlow(); + return _buffer.ReadBytesLeft < byteCount && ShouldBufferSlow(byteCount); [MethodImpl(MethodImplOptions.NoInlining)] - bool ShouldBufferSlow() + bool 
ShouldBufferSlow(int byteCount) { if (byteCount > _buffer.Size) - ThrowArgumentOutOfRange(); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(byteCount), + "Buffer requirement is larger than the buffer size, this can never succeed by buffering data but requires a larger buffer size instead."); if (byteCount > CurrentRemaining) - ThrowArgumentOutOfRangeOfValue(); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(byteCount), + "Buffer requirement is larger than the remaining length of the value, make sure the value is always at least this size or use an upper bound requirement instead."); return true; } - - static void ThrowArgumentOutOfRange() - => throw new ArgumentOutOfRangeException(nameof(byteCount), - "Buffer requirement is larger than the buffer size, this can never succeed by buffering data but requires a larger buffer size instead."); - static void ThrowArgumentOutOfRangeOfValue() - => throw new ArgumentOutOfRangeException(nameof(byteCount), - "Buffer requirement is larger than the remaining length of the value, make sure the value is always at least this size or use an upper bound requirement instead."); } public void Buffer(Size bufferRequirement) => Buffer(GetBufferRequirementByteCount(bufferRequirement)); - public void Buffer(int byteCount) => _buffer.Ensure(byteCount, async: false).GetAwaiter().GetResult(); + public void Buffer(int byteCount) => _buffer.Ensure(byteCount); public ValueTask BufferAsync(Size bufferRequirement, CancellationToken cancellationToken) => BufferAsync(GetBufferRequirementByteCount(bufferRequirement), cancellationToken); @@ -828,7 +807,7 @@ internal NestedReadScope(bool async, PgReader reader, int previousSize, int prev public void Dispose() { if (_async) - throw new InvalidOperationException("Cannot synchronously dispose async scopes, call DisposeAsync instead."); + ThrowHelper.ThrowInvalidOperationException("Cannot synchronously dispose async scopes, call DisposeAsync instead."); DisposeAsync().GetAwaiter().GetResult(); 
} diff --git a/src/Npgsql/Internal/PgSerializerOptions.cs b/src/Npgsql/Internal/PgSerializerOptions.cs index 405d1d11da..052404da5c 100644 --- a/src/Npgsql/Internal/PgSerializerOptions.cs +++ b/src/Npgsql/Internal/PgSerializerOptions.cs @@ -1,6 +1,6 @@ using System; +using System.Diagnostics.CodeAnalysis; using System.IO; -using System.Runtime.CompilerServices; using System.Text; using Npgsql.Internal.Postgres; using Npgsql.NameTranslation; @@ -8,6 +8,7 @@ namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public sealed class PgSerializerOptions { /// @@ -32,7 +33,7 @@ internal PgSerializerOptions(NpgsqlDatabaseInfo databaseInfo, PgTypeInfoResolver internal PgTypeInfo UnspecifiedDBNullTypeInfo { get; } PostgresType? _textPgType; - internal PostgresType TextPgType => _textPgType ??= DatabaseInfo.GetPostgresType(DataTypeNames.Text); + internal PgTypeId TextPgTypeId => ToCanonicalTypeId(_textPgType ??= DatabaseInfo.GetPostgresType(DataTypeNames.Text)); // Used purely for type mapping, where we don't have a full set of types but resolvers might know enough. readonly bool _introspectionInstance; @@ -77,28 +78,22 @@ public static bool IsWellKnownTextType(Type type) // This also makes it easier to realize it should be a cached value if infos for different CLR types are requested for the same // pgTypeId. Effectively it should be 'impossible' to get the wrong kind via any PgConverterOptions api which is what this is mainly // for. - PgTypeInfo? GetTypeInfoCore(Type? type, PgTypeId? pgTypeId, bool defaultTypeFallback) + PgTypeInfo? GetTypeInfoCore(Type? type, PgTypeId? pgTypeId) => PortableTypeIds - ? Unsafe.As>(_typeInfoCache ??= new TypeInfoCache(this)).GetOrAddInfo(type, pgTypeId?.DataTypeName, defaultTypeFallback) - : Unsafe.As>(_typeInfoCache ??= new TypeInfoCache(this)).GetOrAddInfo(type, pgTypeId?.Oid, defaultTypeFallback); + ? 
((TypeInfoCache)(_typeInfoCache ??= new TypeInfoCache(this))).GetOrAddInfo(type, pgTypeId?.DataTypeName) + : ((TypeInfoCache)(_typeInfoCache ??= new TypeInfoCache(this))).GetOrAddInfo(type, pgTypeId?.Oid); - public PgTypeInfo? GetDefaultTypeInfo(PostgresType pgType) - => GetTypeInfoCore(null, ToCanonicalTypeId(pgType), false); + internal PgTypeInfo? GetTypeInfoInternal(Type? type, PgTypeId? pgTypeId) + => GetTypeInfoCore(type, pgTypeId); - public PgTypeInfo? GetDefaultTypeInfo(PgTypeId pgTypeId) - => GetTypeInfoCore(null, pgTypeId, false); - - public PgTypeInfo? GetTypeInfo(Type type, PostgresType pgType) - => GetTypeInfoCore(type, ToCanonicalTypeId(pgType), false); + public PgTypeInfo? GetDefaultTypeInfo(Type type) + => GetTypeInfoCore(type, null); - public PgTypeInfo? GetTypeInfo(Type type, PgTypeId? pgTypeId = null) - => GetTypeInfoCore(type, pgTypeId, false); - - public PgTypeInfo? GetObjectOrDefaultTypeInfo(PostgresType pgType) - => GetTypeInfoCore(typeof(object), ToCanonicalTypeId(pgType), true); + public PgTypeInfo? GetDefaultTypeInfo(PgTypeId pgTypeId) + => GetTypeInfoCore(null, GetCanonicalTypeId(pgTypeId)); - public PgTypeInfo? GetObjectOrDefaultTypeInfo(PgTypeId pgTypeId) - => GetTypeInfoCore(typeof(object), pgTypeId, true); + public PgTypeInfo? GetTypeInfo(Type type, PgTypeId pgTypeId) + => GetTypeInfoCore(type, GetCanonicalTypeId(pgTypeId)); // If a given type id is in the opposite form than what was expected it will be mapped according to the requirement. 
internal PgTypeId GetCanonicalTypeId(PgTypeId pgTypeId) diff --git a/src/Npgsql/Internal/PgStreamingConverter.cs b/src/Npgsql/Internal/PgStreamingConverter.cs index ff9c6b5eb2..951e940fd8 100644 --- a/src/Npgsql/Internal/PgStreamingConverter.cs +++ b/src/Npgsql/Internal/PgStreamingConverter.cs @@ -1,15 +1,15 @@ using System; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; namespace Npgsql.Internal; -public abstract class PgStreamingConverter : PgConverter +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public abstract class PgStreamingConverter(bool customDbNullPredicate = false) : PgConverter(customDbNullPredicate) { - protected PgStreamingConverter(bool customDbNullPredicate = false) : base(customDbNullPredicate) { } - public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) { bufferRequirements = BufferRequirements.None; @@ -42,9 +42,9 @@ internal sealed override unsafe ValueTask ReadAsObject( static object BoxResult(Task task) { + // Justification: exact type Unsafe.As used to reduce generic duplication cost. Debug.Assert(task is Task); - // We're using ValueTask.Result here to avoid rooting any TaskAwaiter or ValueTaskAwaiter types. - // On ValueTask calling .Result is equivalent to GetAwaiter().GetResult() w.r.t. exception wrapping. + // Using .Result on ValueTask is equivalent to GetAwaiter().GetResult(), this removes TaskAwaiter rooting. 
return new ValueTask(task: Unsafe.As>(task)).Result!; } } diff --git a/src/Npgsql/Internal/PgTypeInfo.cs b/src/Npgsql/Internal/PgTypeInfo.cs index 0c1f2f4ede..93b90b3a70 100644 --- a/src/Npgsql/Internal/PgTypeInfo.cs +++ b/src/Npgsql/Internal/PgTypeInfo.cs @@ -4,6 +4,7 @@ namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public class PgTypeInfo { readonly bool _canBinaryConvert; @@ -20,6 +21,7 @@ public class PgTypeInfo Options = options; IsBoxing = unboxedType is not null; Type = unboxedType ?? type; + SupportsReading = GetDefaultSupportsReading(type, unboxedType); SupportsWriting = true; } @@ -53,6 +55,7 @@ private protected PgTypeInfo(PgSerializerOptions options, Type type, PgConverter public Type Type { get; } public PgSerializerOptions Options { get; } + public bool SupportsReading { get; init; } public bool SupportsWriting { get; init; } public DataFormat? PreferredFormat { get; init; } @@ -114,24 +117,27 @@ internal PgConverterResolution GetResolution() return new(Converter, PgTypeId.GetValueOrDefault()); } - bool CachedCanConvert(DataFormat format, out BufferRequirements bufferRequirements) + bool CanConvert(PgConverter converter, DataFormat format, out BufferRequirements bufferRequirements) { - if (format is DataFormat.Binary) + if (HasCachedInfo(converter)) { - bufferRequirements = _binaryBufferRequirements; - return _canBinaryConvert; + switch (format) + { + case DataFormat.Binary: + bufferRequirements = _binaryBufferRequirements; + return _canBinaryConvert; + case DataFormat.Text: + bufferRequirements = _textBufferRequirements; + return _canTextConvert; + } } - bufferRequirements = _textBufferRequirements; - return _canTextConvert; + return converter.CanConvert(format, out bufferRequirements); } public BufferRequirements? GetBufferRequirements(PgConverter converter, DataFormat format) { - var success = HasCachedInfo(converter) - ? 
CachedCanConvert(format, out var bufferRequirements) - : converter.CanConvert(format, out bufferRequirements); - + var success = CanConvert(converter, format, out var bufferRequirements); return success ? bufferRequirements : null; } @@ -141,7 +147,7 @@ internal bool TryBind(Field field, DataFormat format, out PgConverterInfo info) switch (this) { case { IsResolverInfo: false }: - if (!CachedCanConvert(format, out var bufferRequirements)) + if (!CanConvert(Converter, format, out var bufferRequirements)) { info = default; return false; @@ -150,9 +156,7 @@ internal bool TryBind(Field field, DataFormat format, out PgConverterInfo info) return true; case PgResolverTypeInfo resolverInfo: var resolution = resolverInfo.GetResolution(field); - if (HasCachedInfo(resolution.Converter) - ? !CachedCanConvert(format, out bufferRequirements) - : !resolution.Converter.CanConvert(format, out bufferRequirements)) + if (!CanConvert(resolution.Converter, format, out bufferRequirements)) { info = default; return false; @@ -217,42 +221,45 @@ internal PgConverterInfo Bind(Field field, DataFormat format) return new(this, converter, bufferRequirements.Write); } - // If we don't have a converter stored we must ask the retrieved one. DataFormat ResolveFormat(PgConverter converter, out BufferRequirements bufferRequirements, DataFormat? formatPreference = null) { + // First try to check for preferred support. switch (formatPreference) { - // The common case, no preference means we default to binary if supported. - case null or DataFormat.Binary when HasCachedInfo(converter) ? CachedCanConvert(DataFormat.Binary, out bufferRequirements) : converter.CanConvert(DataFormat.Binary, out bufferRequirements): + case DataFormat.Binary when CanConvert(converter, DataFormat.Binary, out bufferRequirements): return DataFormat.Binary; - // In this case we either prefer text or we have no preference and our converter doesn't support binary. 
- case null or DataFormat.Text: - var canTextConvert = HasCachedInfo(converter) ? CachedCanConvert(DataFormat.Text, out bufferRequirements) : converter.CanConvert(DataFormat.Text, out bufferRequirements); - if (!canTextConvert) - { - if (formatPreference is null) - throw new InvalidOperationException("Converter doesn't support any data format."); - // Rerun without preference. - return ResolveFormat(converter, out bufferRequirements); - } + case DataFormat.Text when CanConvert(converter, DataFormat.Text, out bufferRequirements): return DataFormat.Text; default: - throw new ArgumentOutOfRangeException(); + // The common case, no preference given (or no match) means we default to binary if supported. + if (CanConvert(converter, DataFormat.Binary, out bufferRequirements)) + return DataFormat.Binary; + if (CanConvert(converter, DataFormat.Text, out bufferRequirements)) + return DataFormat.Text; + + ThrowHelper.ThrowInvalidOperationException("Converter doesn't support any data format."); + bufferRequirements = default; + return default; } } + + // We assume a boxing type info does not support reading as the converter won't be able to produce the derived type statically. + // Cases like Array converters unboxing to int[], int[,] etc. are the exception and the reason why SupportsReading is a settable property. + internal static bool GetDefaultSupportsReading(Type type, Type? unboxedType) + => unboxedType is null || unboxedType == type; } -public sealed class PgResolverTypeInfo : PgTypeInfo +public sealed class PgResolverTypeInfo( + PgSerializerOptions options, + PgConverterResolver converterResolver, + PgTypeId? pgTypeId, + Type? unboxedType = null) + : PgTypeInfo(options, + converterResolver.TypeToConvert, + pgTypeId is { } typeId ? ResolveDefaultId(options, converterResolver, typeId) : null, + unboxedType ?? (converterResolver.TypeToConvert == typeof(object) ? 
typeof(object) : null)) { - readonly PgConverterResolver _converterResolver; - - public PgResolverTypeInfo(PgSerializerOptions options, PgConverterResolver converterResolver, PgTypeId? pgTypeId, Type? unboxedType = null) - : base(options, - converterResolver.TypeToConvert, - pgTypeId is { } typeId ? ResolveDefaultId(options, converterResolver, typeId) : null, - // We always mark resolvers with type object as boxing, as they may freely return converters for any type (see PgConverterResolver.Validate). - unboxedType ?? (converterResolver.TypeToConvert == typeof(object) ? typeof(object) : null)) - => _converterResolver = converterResolver; + // We always mark resolvers with type object as boxing, as they may freely return converters for any type (see PgConverterResolver.Validate). // We'll always validate the default resolution, the info will be re-used so there is no real downside. static PgConverterResolution ResolveDefaultId(PgSerializerOptions options, PgConverterResolver converterResolver, PgTypeId typeId) @@ -260,7 +267,7 @@ static PgConverterResolution ResolveDefaultId(PgSerializerOptions options, PgCon public PgConverterResolution? GetResolution(T? value, PgTypeId? expectedPgTypeId) { - return _converterResolver is PgConverterResolver resolverT + return converterResolver is PgConverterResolver resolverT ? resolverT.GetInternal(this, value, expectedPgTypeId ?? PgTypeId) : ThrowNotSupportedType(typeof(T)); @@ -271,27 +278,21 @@ PgConverterResolution ThrowNotSupportedType(Type? type) } public PgConverterResolution? GetResolutionAsObject(object? value, PgTypeId? expectedPgTypeId) - => _converterResolver.GetAsObjectInternal(this, value, expectedPgTypeId ?? PgTypeId); + => converterResolver.GetAsObjectInternal(this, value, expectedPgTypeId ?? PgTypeId); public PgConverterResolution GetResolution(Field field) - => _converterResolver.GetInternal(this, field); + => converterResolver.GetInternal(this, field); public PgConverterResolution GetDefaultResolution(PgTypeId? 
expectedPgTypeId) - => _converterResolver.GetDefaultInternal(ValidateResolution, Options.PortableTypeIds, expectedPgTypeId ?? PgTypeId); + => converterResolver.GetDefaultInternal(ValidateResolution, Options.PortableTypeIds, expectedPgTypeId ?? PgTypeId); - public PgConverterResolver GetConverterResolver() => _converterResolver; + public PgConverterResolver GetConverterResolver() => converterResolver; } -public readonly struct PgConverterResolution +public readonly struct PgConverterResolution(PgConverter converter, PgTypeId pgTypeId) { - public PgConverterResolution(PgConverter converter, PgTypeId pgTypeId) - { - Converter = converter; - PgTypeId = pgTypeId; - } - - public PgConverter Converter { get; } - public PgTypeId PgTypeId { get; } + public PgConverter Converter { get; } = converter; + public PgTypeId PgTypeId { get; } = pgTypeId; public PgConverter GetConverter() => (PgConverter)Converter; } @@ -325,6 +326,4 @@ public PgConverterInfo(PgTypeInfo pgTypeInfo, PgConverter converter, Size buffer /// Whether Converter.TypeToConvert matches PgTypeInfo.Type, if it doesn't object apis should be used. public bool IsBoxingConverter => _typeInfo.IsBoxing; - - public PgConverter GetConverter() => (PgConverter)Converter; } diff --git a/src/Npgsql/Internal/PgTypeInfoResolverChainBuilder.cs b/src/Npgsql/Internal/PgTypeInfoResolverChainBuilder.cs index 548d236096..86c96231a0 100644 --- a/src/Npgsql/Internal/PgTypeInfoResolverChainBuilder.cs +++ b/src/Npgsql/Internal/PgTypeInfoResolverChainBuilder.cs @@ -6,7 +6,7 @@ namespace Npgsql.Internal; struct PgTypeInfoResolverChainBuilder { - readonly List<(Type ImplementationType, object)> _factories = new(); + readonly List<(Type ImplementationType, object)> _factories = []; Action>? _addRangeResolvers; Action>? _addMultirangeResolvers; RangeArrayHandler _rangeArrayHandler = RangeArrayHandler.Instance; @@ -115,6 +115,7 @@ public PgTypeInfoResolverChain Build(Action>? 
configur _addRangeResolvers?.Invoke(instance, resolvers); _addMultirangeResolvers?.Invoke(instance, resolvers); _addArrayResolvers?.Invoke(instance, resolvers); + configure?.Invoke(resolvers); return new( resolvers, diff --git a/src/Npgsql/Internal/PgTypeInfoResolverFactory.cs b/src/Npgsql/Internal/PgTypeInfoResolverFactory.cs index f30059c7ec..9392e2c840 100644 --- a/src/Npgsql/Internal/PgTypeInfoResolverFactory.cs +++ b/src/Npgsql/Internal/PgTypeInfoResolverFactory.cs @@ -1,5 +1,8 @@ +using System.Diagnostics.CodeAnalysis; + namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public abstract class PgTypeInfoResolverFactory { public abstract IPgTypeInfoResolver CreateResolver(); diff --git a/src/Npgsql/Internal/PgWriter.cs b/src/Npgsql/Internal/PgWriter.cs index 3c5064386c..2d08a38e53 100644 --- a/src/Npgsql/Internal/PgWriter.cs +++ b/src/Npgsql/Internal/PgWriter.cs @@ -2,6 +2,7 @@ using System.Buffers; using System.Buffers.Binary; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; @@ -26,40 +27,38 @@ interface IStreamingWriter: IBufferWriter ValueTask FlushAsync(CancellationToken cancellationToken = default); } -sealed class NpgsqlBufferWriter : IStreamingWriter +sealed class NpgsqlBufferWriter(NpgsqlWriteBuffer buffer) : IStreamingWriter { - readonly NpgsqlWriteBuffer _buffer; int? 
_lastBufferSize; - public NpgsqlBufferWriter(NpgsqlWriteBuffer buffer) => _buffer = buffer; public void Advance(int count) { - if (_lastBufferSize < count || _buffer.WriteSpaceLeft < count) + if (_lastBufferSize < count || buffer.WriteSpaceLeft < count) ThrowHelper.ThrowInvalidOperationException("Cannot advance past the end of the current buffer."); _lastBufferSize = null; - _buffer.WritePosition += count; + buffer.WritePosition += count; } public Memory GetMemory(int sizeHint = 0) { - var writePosition = _buffer.WritePosition; - var bufferSize = _buffer.Size - writePosition; + var writePosition = buffer.WritePosition; + var bufferSize = buffer.Size - writePosition; if (sizeHint > bufferSize) ThrowOutOfMemoryException(); _lastBufferSize = bufferSize; - return _buffer.Buffer.AsMemory(writePosition, bufferSize); + return buffer.Buffer.AsMemory(writePosition, bufferSize); } public Span GetSpan(int sizeHint = 0) { - var writePosition = _buffer.WritePosition; - var bufferSize = _buffer.Size - writePosition; + var writePosition = buffer.WritePosition; + var bufferSize = buffer.Size - writePosition; if (sizeHint > bufferSize) ThrowOutOfMemoryException(); _lastBufferSize = bufferSize; - return _buffer.Buffer.AsSpan(writePosition, bufferSize); + return buffer.Buffer.AsSpan(writePosition, bufferSize); } static void ThrowOutOfMemoryException() => throw new OutOfMemoryException("Not enough space left in buffer."); @@ -67,7 +66,7 @@ public Span GetSpan(int sizeHint = 0) public void Flush(TimeSpan timeout = default) { if (timeout == TimeSpan.Zero) - _buffer.Flush(); + buffer.Flush(); else { TimeSpan? 
originalTimeout = null; @@ -75,23 +74,24 @@ public void Flush(TimeSpan timeout = default) { if (timeout != TimeSpan.Zero) { - originalTimeout = _buffer.Timeout; - _buffer.Timeout = timeout; + originalTimeout = buffer.Timeout; + buffer.Timeout = timeout; } - _buffer.Flush(); + buffer.Flush(); } finally { if (originalTimeout is { } value) - _buffer.Timeout = value; + buffer.Timeout = value; } } } public ValueTask FlushAsync(CancellationToken cancellationToken = default) - => new(_buffer.Flush(async: true, cancellationToken)); + => new(buffer.Flush(async: true, cancellationToken)); } +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public sealed class PgWriter { readonly IBufferWriter _writer; @@ -298,7 +298,7 @@ void Core(ReadOnlySpan data, Encoding encoding) if (ShouldFlush(minBufferSize)) Flush(); Ensure(minBufferSize); - encoder.Convert(data, Span, flush: data.Length <= Span.Length, out var charsUsed, out var bytesUsed, out completed); + encoder.Convert(data, Span, flush: true, out var charsUsed, out var bytesUsed, out completed); data = data.Slice(charsUsed); Advance(bytesUsed); } while (!completed); @@ -334,7 +334,7 @@ async ValueTask Core(ReadOnlyMemory data, Encoding encoding, CancellationT if (ShouldFlush(minBufferSize)) await FlushAsync(cancellationToken).ConfigureAwait(false); Ensure(minBufferSize); - encoder.Convert(data.Span, Span, flush: data.Length <= Span.Length, out var charsUsed, out var bytesUsed, out completed); + encoder.Convert(data.Span, Span, flush: true, out var charsUsed, out var bytesUsed, out completed); data = data.Slice(charsUsed); Advance(bytesUsed); } while (!completed); @@ -495,14 +495,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati Task Write(bool async, byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - if (buffer is null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentNullException(nameof(offset)); - if (count < 0) - 
throw new ArgumentNullException(nameof(count)); + ArgumentNullException.ThrowIfNull(buffer); + ArgumentOutOfRangeException.ThrowIfNegative(offset); + ArgumentOutOfRangeException.ThrowIfNegative(count); if (buffer.Length - offset < count) - throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + ThrowHelper.ThrowArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); if (async) { @@ -557,6 +554,7 @@ public override long Seek(long offset, SeekOrigin origin) } // No-op for now. +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public struct NestedWriteScope : IDisposable { public void Dispose() diff --git a/src/Npgsql/Internal/Postgres/DataTypeName.cs b/src/Npgsql/Internal/Postgres/DataTypeName.cs index d20e479f85..9c9f43e41a 100644 --- a/src/Npgsql/Internal/Postgres/DataTypeName.cs +++ b/src/Npgsql/Internal/Postgres/DataTypeName.cs @@ -1,14 +1,18 @@ using System; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; namespace Npgsql.Internal.Postgres; /// -/// Represents the fully-qualified name of a PostgreSQL type. +/// Represents the normalized name of a PostgreSQL data type. /// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] [DebuggerDisplay("{DisplayName,nq}")] public readonly struct DataTypeName : IEquatable { + const char InvalidIdentifier = '-'; + /// /// The maximum length of names in an unmodified PostgreSQL installation. 
/// @@ -25,7 +29,7 @@ namespace Npgsql.Internal.Postgres; if (!validated) { var schemaEndIndex = fullyQualifiedDataTypeName.IndexOf('.'); - if (schemaEndIndex == -1) + if (schemaEndIndex is -1 or 0) throw new ArgumentException("Given value does not contain a schema.", nameof(fullyQualifiedDataTypeName)); // Friendly array syntax is the only fully qualified name quirk that's allowed by postgres (see FromDisplayName). @@ -48,9 +52,9 @@ public DataTypeName(string fullyQualifiedDataTypeName) internal static DataTypeName ValidatedName(string fullyQualifiedDataTypeName) => new(fullyQualifiedDataTypeName, validated: true); - // Includes schema unless it's pg_catalog or the name is unspecified. + // Includes schema unless it's pg_catalog or the schema is an invalid character used to represent an unspecified schema. public string DisplayName => - Value.StartsWith("pg_catalog", StringComparison.Ordinal) || Value == Unspecified + Value.StartsWith("pg_catalog", StringComparison.Ordinal) || IsUnqualified ? UnqualifiedDisplayName : Schema + "." + UnqualifiedDisplayName; @@ -69,12 +73,19 @@ static string ThrowDefaultException() => // This contains two invalid sql identifiers (schema and name are both separate identifiers, and would both have to be quoted to be valid). // Given this is an invalid name it's fine for us to represent a fully qualified 'unspecified' name with it. - public static DataTypeName Unspecified => new("-.-", validated: true); + static string UnspecifiedName => $"{InvalidIdentifier}.{InvalidIdentifier}"; + public static DataTypeName Unspecified => ValidatedName(UnspecifiedName); + + public static string GetUnqualifiedName(string dataTypeName) + => dataTypeName.IndexOf('.') is not -1 and var index + ? 
dataTypeName.Substring(index + 1) : dataTypeName; + + public bool IsUnqualified => Value.StartsWith(InvalidIdentifier) && Value != UnspecifiedName; public bool IsArray => UnqualifiedNameSpan.StartsWith("_".AsSpan(), StringComparison.Ordinal); internal static DataTypeName CreateFullyQualifiedName(string dataTypeName) - => dataTypeName.IndexOf('.') != -1 ? new(dataTypeName) : new("pg_catalog." + dataTypeName); + => dataTypeName.IndexOf('.') != -1 ? new(dataTypeName) : new("-." + dataTypeName); // Static transform as defined by https://www.postgresql.org/docs/current/sql-createtype.html#SQL-CREATETYPE-ARRAY // We don't have to deal with [] as we're always starting from a normalized fully qualified name. @@ -84,106 +95,111 @@ public DataTypeName ToArrayName() if (unqualifiedNameSpan.StartsWith("_".AsSpan(), StringComparison.Ordinal)) return this; - var unqualifiedName = unqualifiedNameSpan.ToString(); - if (unqualifiedName.Length + "_".Length > NAMEDATALEN) - unqualifiedName = unqualifiedName.Substring(0, NAMEDATALEN - "_".Length); + if (unqualifiedNameSpan.Length + "_".Length > NAMEDATALEN) + unqualifiedNameSpan = unqualifiedNameSpan.Slice(0, NAMEDATALEN - "_".Length); - return new(Schema + "._" + unqualifiedName); + return new(string.Concat(Schema, "._", unqualifiedNameSpan)); } // Static transform as defined by https://www.postgresql.org/docs/current/sql-createtype.html#SQL-CREATETYPE-RANGE // Manual testing on PG confirmed it's only the first occurence of 'range' that gets replaced. 
public DataTypeName ToDefaultMultirangeName() { - var unqualifiedNameSpan = UnqualifiedNameSpan; - if (UnqualifiedNameSpan.IndexOf("multirange".AsSpan(), StringComparison.Ordinal) != -1) + var nameSpan = UnqualifiedNameSpan; + if (nameSpan.IndexOf("multirange".AsSpan(), StringComparison.Ordinal) is not -1) return this; - var unqualifiedName = unqualifiedNameSpan.ToString(); - var rangeIndex = unqualifiedName.IndexOf("range", StringComparison.Ordinal); - if (rangeIndex != -1) + if (nameSpan.IndexOf("range", StringComparison.Ordinal) is var rangeIndex and not -1) { - var str = unqualifiedName.Substring(0, rangeIndex) + "multirange" + unqualifiedName.Substring(rangeIndex + "range".Length); - - return new($"{Schema}." + (unqualifiedName.Length + "multi".Length > NAMEDATALEN - ? str.Substring(0, NAMEDATALEN - "multi".Length) - : str)); + nameSpan = string.Concat(nameSpan.Slice(0, rangeIndex), "multirange", nameSpan.Slice(rangeIndex + "range".Length)); + return new(string.Concat(SchemaSpan, ".", + nameSpan.Length > NAMEDATALEN ? nameSpan.Slice(0, NAMEDATALEN) : nameSpan)); } - return new($"{Schema}." + (unqualifiedName.Length + "multi".Length > NAMEDATALEN - ? unqualifiedName.Substring(0, NAMEDATALEN - "_multirange".Length) + "_multirange" - : unqualifiedName + "_multirange")); + if (nameSpan.Length + "_multirange".Length > NAMEDATALEN) + nameSpan = nameSpan.Slice(0, NAMEDATALEN - "_multirange".Length); + + return new(string.Concat(SchemaSpan, ".", nameSpan, "_multirange")); } // Create a DataTypeName from a broader range of valid names. // including SQL aliases like 'timestamp without time zone', trailing facet info etc. public static DataTypeName FromDisplayName(string displayName, string? schema = null) + => FromDisplayName(displayName, schema, assumeUnqualified: false); // user strings may come fully qualified. + + // This method is used during type loading, it allows us to accept friendly names in constructors, without having to preconcatenate the schema. 
+ internal static DataTypeName FromDisplayName(string displayName, string? schema, bool assumeUnqualified) { var displayNameSpan = displayName.AsSpan().Trim(); - // If we have a schema we're done, Postgres doesn't do display name conversions on fully qualified names. - // There is one exception and that's array syntax, which is always resolvable in both ways, while we want the canonical name. var schemaEndIndex = displayNameSpan.IndexOf('.'); - if (schemaEndIndex is not -1 && - !displayNameSpan.Slice(schemaEndIndex).StartsWith("_".AsSpan(), StringComparison.Ordinal) && - !displayNameSpan.EndsWith("[]".AsSpan(), StringComparison.Ordinal)) - return new(displayName); - - // First we strip the schema to get the type name. - if (schemaEndIndex is not -1) + ReadOnlySpan schemaSpan; + if (schemaEndIndex is not -1 && !assumeUnqualified) { - schema = displayNameSpan.Slice(0, schemaEndIndex).ToString(); + if (schema is not null) + throw new ArgumentException("Schema provided for a fully qualified name."); + + schemaSpan = displayNameSpan.Slice(0, schemaEndIndex); displayNameSpan = displayNameSpan.Slice(schemaEndIndex + 1); } + else + { + schemaSpan = schema is null ? $"{InvalidIdentifier}" : schema.AsSpan(); + } // Then we strip either of the two valid array representations to get the base type name (with or without facets). var isArray = false; - if (displayNameSpan.StartsWith("_".AsSpan())) + if (displayNameSpan.StartsWith("_", StringComparison.Ordinal)) { isArray = true; displayNameSpan = displayNameSpan.Slice(1); } - else if (displayNameSpan.EndsWith("[]".AsSpan())) + else if (displayNameSpan.EndsWith("[]", StringComparison.Ordinal)) { isArray = true; displayNameSpan = displayNameSpan.Slice(0, displayNameSpan.Length - 2); } - string mapped; - if (schemaEndIndex is -1) + if (schemaEndIndex is not -1) { - // Finally we strip the facet info. 
- var parenIndex = displayNameSpan.IndexOf('('); - if (parenIndex > -1) - displayNameSpan = displayNameSpan.Slice(0, parenIndex); - - // Map any aliases to the internal type name. - mapped = displayNameSpan.ToString() switch - { - "boolean" => "bool", - "character" => "bpchar", - "decimal" => "numeric", - "real" => "float4", - "double precision" => "float8", - "smallint" => "int2", - "integer" => "int4", - "bigint" => "int8", - "time without time zone" => "time", - "timestamp without time zone" => "timestamp", - "time with time zone" => "timetz", - "timestamp with time zone" => "timestamptz", - "bit varying" => "varbit", - "character varying" => "varchar", - var value => value - }; + // If we have a schema we're done, Postgres doesn't do display name conversions on fully qualified names. + // There is one exception and that's array syntax, which is always resolvable in both ways, while we want the canonical name. + return !isArray + ? new(displayName.Length == schemaEndIndex + displayNameSpan.Length + ? displayName + : string.Concat(schemaSpan, ".", displayNameSpan)) + : new(string.Concat(schemaSpan, ".", "_", displayNameSpan)); } - else + + // Finally we strip the facet info. + var parenIndex = displayNameSpan.IndexOf('('); + if (parenIndex > -1) + displayNameSpan = displayNameSpan.Slice(0, parenIndex); + + // Map any aliases to the internal type name. + var mapped = displayNameSpan switch { - // If we had a schema originally we stop here, see comment at schemaEndIndex. 
- mapped = displayNameSpan.ToString(); - } + "boolean" => "bool", + "character" => "bpchar", + "decimal" => "numeric", + "real" => "float4", + "double precision" => "float8", + "smallint" => "int2", + "integer" => "int4", + "bigint" => "int8", + "time without time zone" => "time", + "timestamp without time zone" => "timestamp", + "time with time zone" => "timetz", + "timestamp with time zone" => "timestamptz", + "bit varying" => "varbit", + "character varying" => "varchar", + var value => value + }; - return new((schema ?? "pg_catalog") + "." + (isArray ? "_" : "") + mapped); + if (schema is null && DataTypeNames.IsWellKnownUnqualifiedName(mapped)) + schemaSpan = "pg_catalog".AsSpan(); + + return new(string.Concat(schemaSpan, ".", isArray ? "_" : "", mapped)); } // The type names stored in a DataTypeName are usually the actual typname from the pg_type column. @@ -193,8 +209,8 @@ public static DataTypeName FromDisplayName(string displayName, string? schema = // Alternatively some of the source lives at https://github.com/postgres/postgres/blob/c8e1ba736b2b9e8c98d37a5b77c4ed31baf94147/src/backend/utils/adt/format_type.c#L186 static string ToDisplayName(ReadOnlySpan unqualifiedName) { - var isArray = unqualifiedName.IndexOf('_') == 0; - var baseTypeName = isArray ? unqualifiedName.Slice(1).ToString() : unqualifiedName.ToString(); + var isArray = unqualifiedName.IndexOf('_') is 0; + var baseTypeName = isArray ? unqualifiedName.Slice(1) : unqualifiedName; var mappedBaseType = baseTypeName switch { @@ -212,13 +228,12 @@ static string ToDisplayName(ReadOnlySpan unqualifiedName) "timestamptz" => "timestamp with time zone", "varbit" => "bit varying", "varchar" => "character varying", - _ => baseTypeName + _ => null }; - if (isArray) - return mappedBaseType + "[]"; - - return mappedBaseType; + return isArray + ? string.Concat(mappedBaseType ?? baseTypeName, "[]") + : mappedBaseType ?? 
baseTypeName.ToString(); } internal static bool IsFullyQualified(ReadOnlySpan dataTypeName) => dataTypeName.Contains(".".AsSpan(), StringComparison.Ordinal); diff --git a/src/Npgsql/Internal/Postgres/DataTypeNames.cs b/src/Npgsql/Internal/Postgres/DataTypeNames.cs index 275bcb9937..6c4ca73b2f 100644 --- a/src/Npgsql/Internal/Postgres/DataTypeNames.cs +++ b/src/Npgsql/Internal/Postgres/DataTypeNames.cs @@ -1,3 +1,4 @@ +using System; using static Npgsql.Internal.Postgres.DataTypeName; namespace Npgsql.Internal.Postgres; @@ -7,6 +8,27 @@ namespace Npgsql.Internal.Postgres; /// static class DataTypeNames { + // Generated from the following query: + // SELECT '"' || string_agg(typname, '" or "') || '"' FROM ( + // SELECT typname FROM pg_catalog.pg_type WHERE typtype = 'b' AND typcategory <> 'A' + // AND typnamespace = (SELECT oid FROM pg_namespace WHERE nspname = 'pg_catalog') ORDER BY typname); + public static bool IsWellKnownUnqualifiedName(ReadOnlySpan name) => name switch + { + "aclitem" or "bit" or "bool" or "box" or "bpchar" or "bytea" or "char" or "cid" or + "cidr" or "circle" or "date" or "float4" or "float8" or "gtsvector" or "inet" or + "int2" or "int4" or "int8" or "interval" or "json" or "jsonb" or "jsonpath" or + "line" or "lseg" or "macaddr" or "macaddr8" or "money" or "name" or "numeric" or + "oid" or "path" or "pg_brin_bloom_summary" or "pg_brin_minmax_multi_summary" or + "pg_dependencies" or "pg_lsn" or "pg_mcv_list" or "pg_ndistinct" or "pg_node_tree" or + "pg_snapshot" or "point" or "polygon" or "refcursor" or "regclass" or "regcollation" or + "regconfig" or "regdictionary" or "regnamespace" or "regoper" or "regoperator" or + "regproc" or "regprocedure" or "regrole" or "regtype" or "text" or "tid" or "time" or + "timestamp" or "timestamptz" or "timetz" or "tsquery" or "tsvector" or "txid_snapshot" or + "uuid" or "varbit" or "varchar" or "xid" or "xid8" or "xml" + => true, + _ => false + }; + // Note: The names are fully qualified in source so the 
strings are constants and instances will be interned after the first call. // Uses an internal constructor bypassing the public DataTypeName constructor validation, as we don't want to store all these names on // fields either. diff --git a/src/Npgsql/Internal/Postgres/Field.cs b/src/Npgsql/Internal/Postgres/Field.cs index f6a261c103..abd74a0bc7 100644 --- a/src/Npgsql/Internal/Postgres/Field.cs +++ b/src/Npgsql/Internal/Postgres/Field.cs @@ -1,16 +1,12 @@ +using System.Diagnostics.CodeAnalysis; + namespace Npgsql.Internal.Postgres; /// Base field type shared between tables and composites. -public readonly struct Field +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public readonly struct Field(string name, PgTypeId pgTypeId, int typeModifier) { - public Field(string name, PgTypeId pgTypeId, int typeModifier) - { - Name = name; - PgTypeId = pgTypeId; - TypeModifier = typeModifier; - } - - public string Name { get; init; } - public PgTypeId PgTypeId { get; init; } - public int TypeModifier { get; init; } + public string Name { get; init; } = name; + public PgTypeId PgTypeId { get; init; } = pgTypeId; + public int TypeModifier { get; init; } = typeModifier; } diff --git a/src/Npgsql/Internal/Postgres/Oid.cs b/src/Npgsql/Internal/Postgres/Oid.cs index e6fcad6f4a..8c01e65ff7 100644 --- a/src/Npgsql/Internal/Postgres/Oid.cs +++ b/src/Npgsql/Internal/Postgres/Oid.cs @@ -1,14 +1,14 @@ using System; +using System.Diagnostics.CodeAnalysis; namespace Npgsql.Internal.Postgres; -public readonly struct Oid: IEquatable +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public readonly struct Oid(uint value) : IEquatable { - public Oid(uint value) => Value = value; - public static explicit operator uint(Oid oid) => oid.Value; public static implicit operator Oid(uint oid) => new(oid); - public uint Value { get; init; } + public uint Value { get; init; } = value; public static Oid Unspecified => new(0); public override string ToString() => Value.ToString(); diff 
--git a/src/Npgsql/Internal/Postgres/PgTypeId.cs b/src/Npgsql/Internal/Postgres/PgTypeId.cs index c5a40d22ca..ee5ffb9d41 100644 --- a/src/Npgsql/Internal/Postgres/PgTypeId.cs +++ b/src/Npgsql/Internal/Postgres/PgTypeId.cs @@ -6,6 +6,7 @@ namespace Npgsql.Internal.Postgres; /// /// A discriminated union of and . /// +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public readonly struct PgTypeId: IEquatable { readonly DataTypeName _dataTypeName; diff --git a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Multirange.cs b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Multirange.cs index 873e6b9874..3d82ab03f1 100644 --- a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Multirange.cs +++ b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Multirange.cs @@ -74,18 +74,18 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) else { mappings.AddResolverType[]>(DataTypeNames.TsMultirange, - static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + static (options, mapping, requiresDataTypeName) => mapping.CreateInfo(options, DateTimeConverterResolver.CreateMultirangeResolver[], NpgsqlRange>(options, options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), options.GetCanonicalTypeId(DataTypeNames.TsMultirange), - options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + options.EnableDateTimeInfinityConversions), requiresDataTypeName), isDefault: true); mappings.AddResolverType>>(DataTypeNames.TsMultirange, - static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + static (options, mapping, requiresDataTypeName) => mapping.CreateInfo(options, DateTimeConverterResolver.CreateMultirangeResolver>, NpgsqlRange>(options, options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), options.GetCanonicalTypeId(DataTypeNames.TsMultirange), - options.EnableDateTimeInfinityConversions), dataTypeNameMatch)); + 
options.EnableDateTimeInfinityConversions), requiresDataTypeName)); } mappings.AddType[]>(DataTypeNames.TsMultirange, @@ -126,18 +126,18 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) else { mappings.AddResolverType[]>(DataTypeNames.TsTzMultirange, - static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + static (options, mapping, requiresDataTypeName) => mapping.CreateInfo(options, DateTimeConverterResolver.CreateMultirangeResolver[], NpgsqlRange>(options, options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), options.GetCanonicalTypeId(DataTypeNames.TsMultirange), - options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + options.EnableDateTimeInfinityConversions), requiresDataTypeName), isDefault: true); mappings.AddResolverType>>(DataTypeNames.TsTzMultirange, - static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + static (options, mapping, requiresDataTypeName) => mapping.CreateInfo(options, DateTimeConverterResolver.CreateMultirangeResolver>, NpgsqlRange>(options, options.GetCanonicalTypeId(DataTypeNames.TsTzMultirange), options.GetCanonicalTypeId(DataTypeNames.TsMultirange), - options.EnableDateTimeInfinityConversions), dataTypeNameMatch)); + options.EnableDateTimeInfinityConversions), requiresDataTypeName)); mappings.AddType[]>(DataTypeNames.TsTzMultirange, static (options, mapping, _) => mapping.CreateInfo(options, CreateArrayMultirangeConverter( @@ -159,24 +159,23 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) CreateListMultirangeConverter(CreateRangeConverter(new Int8Converter(), options), options))); // datemultirange - mappings.AddType[]>(DataTypeNames.DateMultirange, - static (options, mapping, _) => - mapping.CreateInfo(options, CreateArrayMultirangeConverter( - CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options), options)), - isDefault: true); - 
mappings.AddType>>(DataTypeNames.DateMultirange, - static (options, mapping, _) => - mapping.CreateInfo(options, CreateListMultirangeConverter( - CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options), options))); mappings.AddType[]>(DataTypeNames.DateMultirange, static (options, mapping, _) => mapping.CreateInfo(options, CreateArrayMultirangeConverter( CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options), options)), isDefault: true); + mappings.AddType[]>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateArrayMultirangeConverter( + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options), options))); mappings.AddType>>(DataTypeNames.DateMultirange, static (options, mapping, _) => mapping.CreateInfo(options, CreateListMultirangeConverter( CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options), options))); + mappings.AddType>>(DataTypeNames.DateMultirange, + static (options, mapping, _) => + mapping.CreateInfo(options, CreateListMultirangeConverter( + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options), options))); return mappings; } diff --git a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Range.cs b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Range.cs index 54ca555cdd..17ba8c3c33 100644 --- a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Range.cs +++ b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.Range.cs @@ -48,11 +48,11 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) else { mappings.AddResolverStructType>(DataTypeNames.TsRange, - static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + static (options, mapping, requiresDataTypeName) => 
mapping.CreateInfo(options, DateTimeConverterResolver.CreateRangeResolver(options, options.GetCanonicalTypeId(DataTypeNames.TsTzRange), options.GetCanonicalTypeId(DataTypeNames.TsRange), - options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + options.EnableDateTimeInfinityConversions), requiresDataTypeName), isDefault: true); } mappings.AddStructType>(DataTypeNames.TsRange, @@ -73,11 +73,11 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) else { mappings.AddResolverStructType>(DataTypeNames.TsTzRange, - static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + static (options, mapping, requiresDataTypeName) => mapping.CreateInfo(options, DateTimeConverterResolver.CreateRangeResolver(options, options.GetCanonicalTypeId(DataTypeNames.TsTzRange), options.GetCanonicalTypeId(DataTypeNames.TsRange), - options.EnableDateTimeInfinityConversions), dataTypeNameMatch), + options.EnableDateTimeInfinityConversions), requiresDataTypeName), isDefault: true); mappings.AddStructType>(DataTypeNames.TsTzRange, static (options, mapping, _) => mapping.CreateInfo(options, @@ -87,15 +87,16 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) static (options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int8Converter(), options))); // daterange + mappings.AddStructType>(DataTypeNames.DateRange, + static (options, mapping, _) => + mapping.CreateInfo(options, + CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options)), + isDefault: true); mappings.AddStructType>(DataTypeNames.DateRange, static (options, mapping, _) => mapping.CreateInfo(options, - CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options)), - isDefault: true); + CreateRangeConverter(new DateTimeDateConverter(options.EnableDateTimeInfinityConversions), options))); mappings.AddStructType>(DataTypeNames.DateRange, static 
(options, mapping, _) => mapping.CreateInfo(options, CreateRangeConverter(new Int4Converter(), options))); - mappings.AddStructType>(DataTypeNames.DateRange, - static (options, mapping, _) => - mapping.CreateInfo(options, CreateRangeConverter(new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions), options))); return mappings; } diff --git a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.cs index 61f1bbc2f3..8db547315f 100644 --- a/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.cs +++ b/src/Npgsql/Internal/ResolverFactories/AdoTypeInfoResolverFactory.cs @@ -38,13 +38,12 @@ class Resolver : IPgTypeInfoResolver static PgTypeInfo? GetEnumTypeInfo(Type? type, DataTypeName dataTypeName, PgSerializerOptions options) { - if (type is not null && type != typeof(string)) + if (type is not null && type != typeof(object) && type != typeof(string) + || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresEnumType) return null; - if (options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresEnumType) - return null; - - return new PgTypeInfo(options, new StringTextConverter(options.TextEncoding), dataTypeName); + return new PgTypeInfo(options, new StringTextConverter(options.TextEncoding), dataTypeName, + unboxedType: type == typeof(object) ? typeof(string) : null); } static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) @@ -92,10 +91,10 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName, TypeMatchPredicate = type => typeof(Stream).IsAssignableFrom(type) }); //Special mappings, these have no corresponding array mapping. 
mappings.AddType(DataTypeNames.Text, - static (options, mapping, _) => mapping.CreateInfo(options, new TextReaderTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + static (options, mapping, _) => mapping.CreateInfo(options, new TextReaderTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text, supportsWriting: false), MatchRequirement.DataTypeName); mappings.AddStructType(DataTypeNames.Text, - static (options, mapping, _) => mapping.CreateInfo(options, new GetCharsTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + static (options, mapping, _) => mapping.CreateInfo(options, new GetCharsTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text, supportsWriting: false), MatchRequirement.DataTypeName); // Alternative text types @@ -119,10 +118,10 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName, TypeMatchPredicate = type => typeof(Stream).IsAssignableFrom(type) }); //Special mappings, these have no corresponding array mapping. 
mappings.AddType(dataTypeName, - static (options, mapping, _) => mapping.CreateInfo(options, new TextReaderTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + static (options, mapping, _) => mapping.CreateInfo(options, new TextReaderTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text, supportsWriting: false), MatchRequirement.DataTypeName); mappings.AddStructType(dataTypeName, - static (options, mapping, _) => mapping.CreateInfo(options, new GetCharsTextConverter(options.TextEncoding), supportsWriting: false, preferredFormat: DataFormat.Text), + static (options, mapping, _) => mapping.CreateInfo(options, new GetCharsTextConverter(options.TextEncoding), preferredFormat: DataFormat.Text, supportsWriting: false), MatchRequirement.DataTypeName); } @@ -143,10 +142,10 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) mapping => mapping with { MatchRequirement = MatchRequirement.DataTypeName, TypeMatchPredicate = type => typeof(Stream).IsAssignableFrom(type) }); //Special mappings, these have no corresponding array mapping. 
mappings.AddType(DataTypeNames.Jsonb, - static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new TextReaderTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new TextReaderTextConverter(options.TextEncoding)), preferredFormat: DataFormat.Text, supportsWriting: false), MatchRequirement.DataTypeName); mappings.AddStructType(DataTypeNames.Jsonb, - static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new GetCharsTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonbVersion, new GetCharsTextConverter(options.TextEncoding)), preferredFormat: DataFormat.Text, supportsWriting: false), MatchRequirement.DataTypeName); // Jsonpath @@ -155,10 +154,10 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new StringTextConverter(options.TextEncoding))), isDefault: true); //Special mappings, these have no corresponding array mapping. 
mappings.AddType(DataTypeNames.Jsonpath, - static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new TextReaderTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new TextReaderTextConverter(options.TextEncoding)), preferredFormat: DataFormat.Text, supportsWriting: false), MatchRequirement.DataTypeName); mappings.AddStructType(DataTypeNames.Jsonpath, - static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new GetCharsTextConverter(options.TextEncoding)), supportsWriting: false, preferredFormat: DataFormat.Text), + static (options, mapping, _) => mapping.CreateInfo(options, new VersionPrefixedTextConverter(jsonpathVersion, new GetCharsTextConverter(options.TextEncoding)), preferredFormat: DataFormat.Text, supportsWriting: false), MatchRequirement.DataTypeName); // Bytea @@ -173,7 +172,7 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) // Varbit mappings.AddType(DataTypeNames.Varbit, static (options, mapping, _) => mapping.CreateInfo(options, - new PolymorphicBitStringConverterResolver(options.GetCanonicalTypeId(DataTypeNames.Varbit)), supportsWriting: false)); + new PolymorphicBitStringConverterResolver(options.GetCanonicalTypeId(DataTypeNames.Varbit)), includeDataTypeName: true, supportsWriting: false)); mappings.AddType(DataTypeNames.Varbit, static (options, mapping, _) => mapping.CreateInfo(options, new BitArrayBitStringConverter()), isDefault: true); mappings.AddStructType(DataTypeNames.Varbit, @@ -184,7 +183,7 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) // Bit mappings.AddType(DataTypeNames.Bit, static (options, mapping, _) => mapping.CreateInfo(options, - new 
PolymorphicBitStringConverterResolver(options.GetCanonicalTypeId(DataTypeNames.Bit)), supportsWriting: false)); + new PolymorphicBitStringConverterResolver(options.GetCanonicalTypeId(DataTypeNames.Bit)), includeDataTypeName: true, supportsWriting: false)); mappings.AddType(DataTypeNames.Bit, static (options, mapping, _) => mapping.CreateInfo(options, new BitArrayBitStringConverter()), isDefault: true); mappings.AddStructType(DataTypeNames.Bit, @@ -202,9 +201,9 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) else { mappings.AddResolverStructType(DataTypeNames.Timestamp, - static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + static (options, mapping, requiresDataTypeName) => mapping.CreateInfo(options, DateTimeConverterResolver.CreateResolver(options, options.GetCanonicalTypeId(DataTypeNames.TimestampTz), options.GetCanonicalTypeId(DataTypeNames.Timestamp), - options.EnableDateTimeInfinityConversions), dataTypeNameMatch), isDefault: true); + options.EnableDateTimeInfinityConversions), requiresDataTypeName), isDefault: true); } mappings.AddStructType(DataTypeNames.Timestamp, static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); @@ -221,9 +220,9 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) else { mappings.AddResolverStructType(DataTypeNames.TimestampTz, - static (options, mapping, dataTypeNameMatch) => mapping.CreateInfo(options, + static (options, mapping, requiresDataTypeName) => mapping.CreateInfo(options, DateTimeConverterResolver.CreateResolver(options, options.GetCanonicalTypeId(DataTypeNames.TimestampTz), options.GetCanonicalTypeId(DataTypeNames.Timestamp), - options.EnableDateTimeInfinityConversions), dataTypeNameMatch), isDefault: true); + options.EnableDateTimeInfinityConversions), requiresDataTypeName), isDefault: true); mappings.AddStructType(DataTypeNames.TimestampTz, static (options, mapping, _) => mapping.CreateInfo(options, 
new DateTimeOffsetConverter(options.EnableDateTimeInfinityConversions))); } @@ -231,14 +230,15 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); // Date + mappings.AddStructType(DataTypeNames.Date, + static (options, mapping, _) => + mapping.CreateInfo(options, new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions)), isDefault: true); mappings.AddStructType(DataTypeNames.Date, static (options, mapping, _) => mapping.CreateInfo(options, new DateTimeDateConverter(options.EnableDateTimeInfinityConversions)), MatchRequirement.DataTypeName); mappings.AddStructType(DataTypeNames.Date, static (options, mapping, _) => mapping.CreateInfo(options, new Int4Converter())); - mappings.AddStructType(DataTypeNames.Date, - static (options, mapping, _) => mapping.CreateInfo(options, new DateOnlyDateConverter(options.EnableDateTimeInfinityConversions))); // Interval mappings.AddStructType(DataTypeNames.Interval, @@ -247,12 +247,12 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlIntervalConverter())); // Time + mappings.AddStructType(DataTypeNames.Time, + static (options, mapping, _) => mapping.CreateInfo(options, new TimeOnlyTimeConverter()), isDefault: true); mappings.AddStructType(DataTypeNames.Time, - static (options, mapping, _) => mapping.CreateInfo(options, new TimeSpanTimeConverter()), isDefault: true); + static (options, mapping, _) => mapping.CreateInfo(options, new TimeSpanTimeConverter())); mappings.AddStructType(DataTypeNames.Time, static (options, mapping, _) => mapping.CreateInfo(options, new Int8Converter())); - mappings.AddStructType(DataTypeNames.Time, - static (options, mapping, _) => mapping.CreateInfo(options, new TimeOnlyTimeConverter())); // TimeTz mappings.AddStructType(DataTypeNames.TimeTz, @@ -447,9 +447,9 @@ static 
TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) mappings.AddStructArrayType(DataTypeNames.TimestampTz); // Date + mappings.AddStructArrayType(DataTypeNames.Date); mappings.AddStructArrayType(DataTypeNames.Date); mappings.AddStructArrayType(DataTypeNames.Date); - mappings.AddStructArrayType(DataTypeNames.Date); // Interval mappings.AddStructArrayType(DataTypeNames.Interval); @@ -499,7 +499,7 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) // Probe if there is any mapping at all for this element type. var elementId = options.ToCanonicalTypeId(pgElementType); - if (options.GetDefaultTypeInfo(elementId) is null) + if (options.GetTypeInfoInternal(null, elementId) is null) return null; var mappings = new TypeInfoMappingCollection(); @@ -511,7 +511,8 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) static PgTypeInfo? GetEnumArrayTypeInfo(Type? elementType, PostgresType pgElementType, Type? type, DataTypeName dataTypeName, PgSerializerOptions options) { - if ((type != typeof(object) && elementType is not null && elementType != typeof(string)) || pgElementType is not PostgresEnumType enumType) + if ((type is not null && type != typeof(object) && elementType != typeof(string)) + || pgElementType is not PostgresEnumType enumType) return null; var mappings = new TypeInfoMappingCollection(); diff --git a/src/Npgsql/Internal/ResolverFactories/CubeTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/CubeTypeInfoResolverFactory.cs new file mode 100644 index 0000000000..90b872f458 --- /dev/null +++ b/src/Npgsql/Internal/ResolverFactories/CubeTypeInfoResolverFactory.cs @@ -0,0 +1,56 @@ +using System; +using Npgsql.Internal.Converters; +using Npgsql.Internal.Postgres; +using Npgsql.Properties; +using NpgsqlTypes; + +namespace Npgsql.Internal.ResolverFactories; + +sealed class CubeTypeInfoResolverFactory : PgTypeInfoResolverFactory +{ + const string CubeTypeName = "cube"; + + 
public override IPgTypeInfoResolver CreateResolver() => new Resolver(); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(); + + public static void ThrowIfUnsupported(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (dataTypeName is { UnqualifiedNameSpan: "cube" or "_cube" } || type == typeof(NpgsqlCube)) + throw new NotSupportedException( + string.Format(NpgsqlStrings.CubeNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableCube), + typeof(TBuilder).Name)); + } + + class Resolver : IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); + + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddStructType(CubeTypeName, + static (options, mapping, _) => mapping.CreateInfo(options, new CubeConverter()), isDefault: true); + + return mappings; + } + } + + sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + { + TypeInfoMappingCollection? _mappings; + new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); + + public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? 
dataTypeName, PgSerializerOptions options) + => Mappings.Find(type, dataTypeName, options); + + static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) + { + mappings.AddStructArrayType(CubeTypeName); + + return mappings; + } + } +} diff --git a/src/Npgsql/Internal/ResolverFactories/JsonDynamicTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/JsonDynamicTypeInfoResolverFactory.cs index 2515cf9a5b..aca5484b77 100644 --- a/src/Npgsql/Internal/ResolverFactories/JsonDynamicTypeInfoResolverFactory.cs +++ b/src/Npgsql/Internal/ResolverFactories/JsonDynamicTypeInfoResolverFactory.cs @@ -2,7 +2,6 @@ using System.Diagnostics.CodeAnalysis; using System.Text; using System.Text.Json; -using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; using Npgsql.Internal.Converters; using Npgsql.Internal.Postgres; @@ -12,21 +11,14 @@ namespace Npgsql.Internal.ResolverFactories; [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] -sealed class JsonDynamicTypeInfoResolverFactory : PgTypeInfoResolverFactory +sealed class JsonDynamicTypeInfoResolverFactory( + Type[]? jsonbClrTypes = null, + Type[]? jsonClrTypes = null, + JsonSerializerOptions? serializerOptions = null) + : PgTypeInfoResolverFactory { - readonly Type[]? _jsonbClrTypes; - readonly Type[]? _jsonClrTypes; - readonly JsonSerializerOptions? _serializerOptions; - - public JsonDynamicTypeInfoResolverFactory(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? 
serializerOptions = null) - { - _jsonbClrTypes = jsonbClrTypes; - _jsonClrTypes = jsonClrTypes; - _serializerOptions = serializerOptions; - } - - public override IPgTypeInfoResolver CreateResolver() => new Resolver(_jsonbClrTypes, _jsonClrTypes, _serializerOptions); - public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(_jsonbClrTypes, _jsonClrTypes, _serializerOptions); + public override IPgTypeInfoResolver CreateResolver() => new Resolver(jsonbClrTypes, jsonClrTypes, serializerOptions); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(jsonbClrTypes, jsonClrTypes, serializerOptions); // Split into a nested class to avoid erroneous trimming/AOT warnings because the JsonDynamicTypeInfoResolverFactory is marked as incompatible. internal static class Support @@ -45,29 +37,18 @@ public static void ThrowIfUnsupported(Type? type, DataTypeName? dataTy [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] - class Resolver : DynamicTypeInfoResolver, IPgTypeInfoResolver + class Resolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? serializerOptions = null) + : DynamicTypeInfoResolver, IPgTypeInfoResolver { - JsonSerializerOptions? _serializerOptions; - JsonSerializerOptions SerializerOptions - #if NET7_0_OR_GREATER - => _serializerOptions ??= JsonSerializerOptions.Default; - #else - => _serializerOptions ??= new(); - #endif - - readonly Type[] _jsonbClrTypes; - readonly Type[] _jsonClrTypes; + JsonSerializerOptions? _serializerOptions = serializerOptions; + JsonSerializerOptions SerializerOptions => _serializerOptions ??= JsonSerializerOptions.Default; + + readonly Type[] _jsonbClrTypes = jsonbClrTypes ?? []; + readonly Type[] _jsonClrTypes = jsonClrTypes ?? 
[]; TypeInfoMappingCollection? _mappings; protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(), _jsonbClrTypes, _jsonClrTypes, SerializerOptions); - public Resolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? serializerOptions = null) - { - _jsonbClrTypes = jsonbClrTypes ?? Array.Empty(); - _jsonClrTypes = jsonClrTypes ?? Array.Empty(); - _serializerOptions = serializerOptions; - } - public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options) ?? base.GetTypeInfo(type, dataTypeName, options); @@ -76,20 +57,6 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, // We do GetTypeInfo calls directly so we need a resolver. serializerOptions.TypeInfoResolver ??= new DefaultJsonTypeInfoResolver(); - // These live in the RUC/RDC part as JsonValues can contain any .NET type. - foreach (var dataTypeName in new[] { DataTypeNames.Jsonb, DataTypeNames.Json }) - { - var jsonb = dataTypeName == DataTypeNames.Jsonb; - mappings.AddType(dataTypeName, (options, mapping, _) => - mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); - mappings.AddType(dataTypeName, (options, mapping, _) => - mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); - mappings.AddType(dataTypeName, (options, mapping, _) => - mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); - mappings.AddType(dataTypeName, (options, mapping, _) => - mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); - } - AddUserMappings(jsonb: true, jsonbClrTypes); AddUserMappings(jsonb: false, jsonClrTypes); @@ -107,9 +74,15 @@ void AddUserMappings(bool jsonb, Type[] clrTypes) if (!jsonType.IsValueType && jsonTypeInfo.PolymorphismOptions is not null) { foreach (var derived in 
jsonTypeInfo.PolymorphismOptions.DerivedTypes) + { + // For jsonb we can't properly support polymorphic serialization unless the SerializerOptions.AllowOutOfOrderMetadataProperties is `true`. + // If `jsonb` AND `AllowOutOfOrderMetadataProperties` is `false`, use `derived.DerivedType` as the base type for the converter, + // this causes STJ to stop serializing the "$type" field; essentially disabling the feature. + var baseType = jsonb && !serializerOptions.AllowOutOfOrderMetadataProperties ? derived.DerivedType : jsonType; dynamicMappings.AddMapping(derived.DerivedType, dataTypeName, factory: (options, mapping, _) => mapping.CreateInfo(options, - CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, serializerOptions, jsonType))); + CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, serializerOptions, baseType))); + } } } mappings.AddRange(dynamicMappings.ToTypeInfoMappingCollection()); @@ -129,9 +102,10 @@ void AddUserMappings(bool jsonb, Type[] clrTypes) { var jsonb = dataTypeName == DataTypeNames.Jsonb; - // For jsonb we can't properly support polymorphic serialization unless we do quite some additional work - // so we default to mapping.Type instead (exact types will never serialize their "$type" fields, essentially disabling the feature). - var baseType = jsonb ? mapping.Type : typeof(object); + // For jsonb we can't properly support polymorphic serialization unless the SerializerOptions.AllowOutOfOrderMetadataProperties is `true`. + // If `jsonb` AND `AllowOutOfOrderMetadataProperties` is `false`, use `mapping.Type` as the base type for the converter, + // this causes STJ to stop serializing the "$type" field; essentially disabling the feature. + var baseType = jsonb && !SerializerOptions.AllowOutOfOrderMetadataProperties ? 
mapping.Type : typeof(object); return mapping.CreateInfo(options, CreateSystemTextJsonConverter(mapping.Type, jsonb, options.TextEncoding, SerializerOptions, baseType)); @@ -148,14 +122,12 @@ static PgConverter CreateSystemTextJsonConverter(Type valueType, bool jsonb, Enc [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. This may not work when AOT compiling.")] - sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + sealed class ArrayResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? serializerOptions = null) + : Resolver(jsonbClrTypes, jsonClrTypes, serializerOptions), IPgTypeInfoResolver { TypeInfoMappingCollection? _mappings; new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings), base.Mappings); - public ArrayResolver(Type[]? jsonbClrTypes = null, Type[]? jsonClrTypes = null, JsonSerializerOptions? serializerOptions = null) - : base(jsonbClrTypes, jsonClrTypes, serializerOptions) { } - public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options) ?? 
base.GetTypeInfo(type, dataTypeName, options); @@ -169,14 +141,6 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, if (baseMappings.Items.Count == 0) return mappings; - foreach (var dataTypeName in new[] { DataTypeNames.Jsonb, DataTypeNames.Json }) - { - mappings.AddArrayType(dataTypeName); - mappings.AddArrayType(dataTypeName); - mappings.AddArrayType(dataTypeName); - mappings.AddArrayType(dataTypeName); - } - var dynamicMappings = CreateCollection(baseMappings); foreach (var mapping in baseMappings.Items) dynamicMappings.AddArrayMapping(mapping.Type, mapping.DataTypeName); diff --git a/src/Npgsql/Internal/ResolverFactories/JsonTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/JsonTypeInfoResolverFactory.cs index a94d5d36f8..f778bea186 100644 --- a/src/Npgsql/Internal/ResolverFactories/JsonTypeInfoResolverFactory.cs +++ b/src/Npgsql/Internal/ResolverFactories/JsonTypeInfoResolverFactory.cs @@ -1,19 +1,16 @@ using System; using System.Text.Json; +using System.Text.Json.Nodes; using System.Text.Json.Serialization.Metadata; using Npgsql.Internal.Converters; using Npgsql.Internal.Postgres; namespace Npgsql.Internal.ResolverFactories; -sealed class JsonTypeInfoResolverFactory : PgTypeInfoResolverFactory +sealed class JsonTypeInfoResolverFactory(JsonSerializerOptions? serializerOptions = null) : PgTypeInfoResolverFactory { - readonly JsonSerializerOptions? _serializerOptions; - - public JsonTypeInfoResolverFactory(JsonSerializerOptions? 
serializerOptions = null) => _serializerOptions = serializerOptions; - - public override IPgTypeInfoResolver CreateResolver() => new Resolver(_serializerOptions); - public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(_serializerOptions); + public override IPgTypeInfoResolver CreateResolver() => new Resolver(serializerOptions); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(serializerOptions); class Resolver : IPgTypeInfoResolver { @@ -47,11 +44,19 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings, var jsonb = dataTypeName == DataTypeNames.Jsonb; mappings.AddType(dataTypeName, (options, mapping, _) => mapping.CreateInfo(options, - new JsonConverter(jsonb, options.TextEncoding, serializerOptions)), - isDefault: true); + new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); mappings.AddStructType(dataTypeName, (options, mapping, _) => mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); + mappings.AddType(dataTypeName, (options, mapping, _) => + mapping.CreateInfo(options, new JsonConverter(jsonb, options.TextEncoding, serializerOptions))); } return mappings; @@ -68,21 +73,22 @@ sealed class BasicJsonTypeInfoResolver : IJsonTypeInfoResolver return JsonMetadataServices.CreateValueInfo(options, JsonMetadataServices.JsonDocumentConverter); if (type == typeof(JsonElement)) return JsonMetadataServices.CreateValueInfo(options, 
JsonMetadataServices.JsonElementConverter); + if (type == typeof(JsonObject)) + return JsonMetadataServices.CreateValueInfo(options, JsonMetadataServices.JsonObjectConverter); + if (type == typeof(JsonArray)) + return JsonMetadataServices.CreateValueInfo(options, JsonMetadataServices.JsonArrayConverter); + if (type == typeof(JsonValue)) + return JsonMetadataServices.CreateValueInfo(options, JsonMetadataServices.JsonValueConverter); return null; } } } - sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + sealed class ArrayResolver(JsonSerializerOptions? serializerOptions = null) : Resolver(serializerOptions), IPgTypeInfoResolver { TypeInfoMappingCollection? _mappings; new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); - public ArrayResolver(JsonSerializerOptions? serializerOptions = null) - : base(serializerOptions) - { - } - public new PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options); @@ -92,6 +98,9 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) { mappings.AddArrayType(dataTypeName); mappings.AddStructArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); + mappings.AddArrayType(dataTypeName); } return mappings; diff --git a/src/Npgsql/Internal/ResolverFactories/NetworkTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/NetworkTypeInfoResolverFactory.cs index da738f54d0..6a2af4453f 100644 --- a/src/Npgsql/Internal/ResolverFactories/NetworkTypeInfoResolverFactory.cs +++ b/src/Npgsql/Internal/ResolverFactories/NetworkTypeInfoResolverFactory.cs @@ -1,5 +1,4 @@ using System; -using System.Diagnostics.CodeAnalysis; using System.Net; using System.Net.NetworkInformation; using Npgsql.Internal.Converters; @@ -31,13 +30,10 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) mapping => mapping with { 
MatchRequirement = MatchRequirement.DataTypeName }); // inet - // This is one of the rare mappings that force us to use reflection for a lack of any alternative. // There are certain IPAddress values like Loopback or Any that return a *private* derived type (see https://github.com/dotnet/runtime/issues/27870). - // However we still need to be able to resolve an exactly typed converter for those values. - // We do so by wrapping our converter in a casting converter constructed over the derived type. - // Finally we add a custom predicate to be able to match any type which values are assignable to IPAddress. mappings.AddType(DataTypeNames.Inet, - CreateInfo, + static (options, mapping, _) => new PgTypeInfo(options, new IPAddressConverter(), new DataTypeName(mapping.DataTypeName), + unboxedType: mapping.Type != typeof(IPAddress) ? mapping.Type : null), mapping => mapping with { MatchRequirement = MatchRequirement.Single, @@ -47,23 +43,13 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlInetConverter())); // cidr - mappings.AddStructType(DataTypeNames.Cidr, - static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlCidrConverter()), isDefault: true); - - // Code is split out to a local method as suppression attributes on lambdas aren't properly handled by the ILLink analyzer yet. - [UnconditionalSuppressMessage("AotAnalysis", "IL3050", - Justification = "MakeGenericType is safe because the target will only ever be a reference type.")] - static PgTypeInfo CreateInfo(PgSerializerOptions options, TypeInfoMapping resolvedMapping, bool _) - { - var derivedType = resolvedMapping.Type != typeof(IPAddress); - PgConverter converter = new IPAddressConverter(); - if (derivedType) - // There is not much more we can do, the deriving type IPAddress+ReadOnlyIPAddress isn't public. 
- converter = (PgConverter)Activator.CreateInstance(typeof(CastingConverter<>).MakeGenericType(resolvedMapping.Type), - converter)!; + mappings.AddStructType(DataTypeNames.Cidr, + static (options, mapping, _) => mapping.CreateInfo(options, new IPNetworkConverter()), isDefault: true); - return resolvedMapping.CreateInfo(options, converter); - } +#pragma warning disable CS0618 // NpgsqlCidr is obsolete + mappings.AddStructType(DataTypeNames.Cidr, + static (options, mapping, _) => mapping.CreateInfo(options, new NpgsqlCidrConverter())); +#pragma warning restore CS0618 return mappings; } @@ -88,7 +74,10 @@ static TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) mappings.AddStructArrayType(DataTypeNames.Inet); // cidr + mappings.AddStructArrayType(DataTypeNames.Cidr); +#pragma warning disable CS0618 // NpgsqlCidr is obsolete mappings.AddStructArrayType(DataTypeNames.Cidr); +#pragma warning restore CS0618 return mappings; } diff --git a/src/Npgsql/Internal/ResolverFactories/TupledRecordTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/TupledRecordTypeInfoResolverFactory.cs index 189f84a868..7ee00d37a7 100644 --- a/src/Npgsql/Internal/ResolverFactories/TupledRecordTypeInfoResolverFactory.cs +++ b/src/Npgsql/Internal/ResolverFactories/TupledRecordTypeInfoResolverFactory.cs @@ -46,7 +46,7 @@ class Resolver : DynamicTypeInfoResolver var factory = typeof(Resolver).GetMethod(nameof(CreateFactory), BindingFlags.Static | BindingFlags.NonPublic)! 
.MakeGenericMethod(mapping.Type) - .Invoke(null, new object[] { constructor, constructor.GetParameters().Length }); + .Invoke(null, [constructor, constructor.GetParameters().Length]); var converterType = typeof(RecordConverter<>).MakeGenericType(mapping.Type); var converter = (PgConverter)Activator.CreateInstance(converterType, options, factory)!; diff --git a/src/Npgsql/Internal/ResolverFactories/UnmappedTypeInfoResolverFactory.cs b/src/Npgsql/Internal/ResolverFactories/UnmappedTypeInfoResolverFactory.cs index a04c3cc111..d3dcabb467 100644 --- a/src/Npgsql/Internal/ResolverFactories/UnmappedTypeInfoResolverFactory.cs +++ b/src/Npgsql/Internal/ResolverFactories/UnmappedTypeInfoResolverFactory.cs @@ -75,11 +75,10 @@ class RangeResolver : DynamicTypeInfoResolver || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresRangeType rangeType) return null; - var subInfo = - matchedType is null - ? options.GetDefaultTypeInfo(rangeType.Subtype) - // Input matchedType here as we don't want an NpgsqlRange over Nullable (it has its own nullability tracking, for better or worse) - : options.GetTypeInfo(matchedType == typeof(object) ? matchedType : matchedType.GetGenericArguments()[0], rangeType.Subtype); + // Input matchedType here as we don't want an NpgsqlRange over Nullable (it has its own nullability tracking, for better or worse) + var subInfo = options.GetTypeInfoInternal( + matchedType is null ? null : matchedType == typeof(object) ? matchedType : matchedType.GetGenericArguments()[0], + options.ToCanonicalTypeId(rangeType.Subtype.GetRepresentationalType())); // We have no generic RangeConverterResolver so we would not know how to compose a range mapping for such infos. // See https://github.com/npgsql/npgsql/issues/5268 @@ -133,10 +132,7 @@ class MultirangeResolver : DynamicTypeInfoResolver || options.DatabaseInfo.GetPostgresType(dataTypeName) is not PostgresMultirangeType multirangeType) return null; - var subInfo = - type is null - ? 
options.GetDefaultTypeInfo(multirangeType.Subrange) - : options.GetTypeInfo(elementType ?? typeof(object), multirangeType.Subrange); + var subInfo = options.GetTypeInfoInternal(type is null ? null : elementType ?? typeof(object), options.ToCanonicalTypeId(multirangeType.Subrange)); // We have no generic MultirangeConverterResolver so we would not know how to compose a range mapping for such infos. // See https://github.com/npgsql/npgsql/issues/5268 diff --git a/src/Npgsql/Internal/ResolverFactories/UnsupportedTypeInfoResolver.cs b/src/Npgsql/Internal/ResolverFactories/UnsupportedTypeInfoResolver.cs index 2d47f86807..efcc4633ba 100644 --- a/src/Npgsql/Internal/ResolverFactories/UnsupportedTypeInfoResolver.cs +++ b/src/Npgsql/Internal/ResolverFactories/UnsupportedTypeInfoResolver.cs @@ -16,6 +16,7 @@ sealed class UnsupportedTypeInfoResolver : IPgTypeInfoResolver RecordTypeInfoResolverFactory.ThrowIfUnsupported(type, dataTypeName, options); FullTextSearchTypeInfoResolverFactory.ThrowIfUnsupported(type, dataTypeName, options); LTreeTypeInfoResolverFactory.ThrowIfUnsupported(type, dataTypeName, options); + CubeTypeInfoResolverFactory.ThrowIfUnsupported(type, dataTypeName, options); JsonDynamicTypeInfoResolverFactory.Support.ThrowIfUnsupported(type, dataTypeName); diff --git a/src/Npgsql/Internal/Size.cs b/src/Npgsql/Internal/Size.cs index 7f5e52a1f1..299f2bb229 100644 --- a/src/Npgsql/Internal/Size.cs +++ b/src/Npgsql/Internal/Size.cs @@ -1,8 +1,10 @@ using System; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public enum SizeKind { Unknown = 0, @@ -10,6 +12,7 @@ public enum SizeKind UpperBound } +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] [DebuggerDisplay("{DebuggerDisplay,nq}")] public readonly struct Size : IEquatable { @@ -41,15 +44,40 @@ public int Value public static Size Unknown { get; } = new(SizeKind.Unknown, 0); public static Size Zero { get; 
} = new(SizeKind.Exact, 0); - public Size Combine(Size result) + public bool TryCombine(Size other, out Size result) { - if (_kind is SizeKind.Unknown || result._kind is SizeKind.Unknown) + if (_kind is SizeKind.Unknown || other._kind is SizeKind.Unknown) + { + result = Unknown; + return true; + } + + var sum = unchecked(_value + other._value); + if ((_value >= 0 && sum < other._value) || (_value < 0 && sum > other._value)) + { + result = default; + return false; + } + + if (_kind is SizeKind.UpperBound || other._kind is SizeKind.UpperBound) + { + result = CreateUpperBound(sum); + return true; + } + + result = Create(sum); + return true; + } + + public Size Combine(Size other) + { + if (_kind is SizeKind.Unknown || other._kind is SizeKind.Unknown) return Unknown; - if (_kind is SizeKind.UpperBound || result._kind is SizeKind.UpperBound) - return CreateUpperBound((int)Math.Min((long)(_value + result._value), int.MaxValue)); + if (_kind is SizeKind.UpperBound || other._kind is SizeKind.UpperBound) + return CreateUpperBound(checked(_value + other._value)); - return Create((int)Math.Min((long)(_value + result._value), int.MaxValue)); + return Create(checked(_value + other._value)); } public static implicit operator Size(int value) => Create(value); diff --git a/src/Npgsql/Internal/TransportSecurityHandler.cs b/src/Npgsql/Internal/TransportSecurityHandler.cs index ecb447c6da..fbe8cad72e 100644 --- a/src/Npgsql/Internal/TransportSecurityHandler.cs +++ b/src/Npgsql/Internal/TransportSecurityHandler.cs @@ -1,5 +1,6 @@ using System; using System.Security.Cryptography.X509Certificates; +using System.Threading; using System.Threading.Tasks; using Npgsql.Properties; using Npgsql.Util; @@ -10,13 +11,13 @@ class TransportSecurityHandler { public virtual bool SupportEncryption => false; - public virtual Func? RootCertificateCallback + public virtual Func? 
RootCertificatesCallback { get => throw new NotSupportedException(string.Format(NpgsqlStrings.TransportSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableTransportSecurity))); set => throw new NotSupportedException(string.Format(NpgsqlStrings.TransportSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableTransportSecurity))); } - public virtual Task NegotiateEncryption(bool async, NpgsqlConnector connector, SslMode sslMode, NpgsqlTimeout timeout) + public virtual Task NegotiateEncryption(bool async, NpgsqlConnector connector, SslMode sslMode, NpgsqlTimeout timeout, CancellationToken cancellationToken) => throw new NotSupportedException(string.Format(NpgsqlStrings.TransportSecurityDisabled, nameof(NpgsqlSlimDataSourceBuilder.EnableTransportSecurity))); public virtual void AuthenticateSASLSha256Plus(NpgsqlConnector connector, ref string mechanism, ref string cbindFlag, ref string cbind, @@ -28,10 +29,10 @@ sealed class RealTransportSecurityHandler : TransportSecurityHandler { public override bool SupportEncryption => true; - public override Func? RootCertificateCallback { get; set; } + public override Func? 
RootCertificatesCallback { get; set; } - public override Task NegotiateEncryption(bool async, NpgsqlConnector connector, SslMode sslMode, NpgsqlTimeout timeout) - => connector.NegotiateEncryption(sslMode, timeout, async); + public override Task NegotiateEncryption(bool async, NpgsqlConnector connector, SslMode sslMode, NpgsqlTimeout timeout, CancellationToken cancellationToken) + => connector.NegotiateEncryption(sslMode, timeout, async, cancellationToken); public override void AuthenticateSASLSha256Plus(NpgsqlConnector connector, ref string mechanism, ref string cbindFlag, ref string cbind, ref bool successfulBind) diff --git a/src/Npgsql/Internal/TypeInfoCache.cs b/src/Npgsql/Internal/TypeInfoCache.cs index 5c72463d03..91a6de9295 100644 --- a/src/Npgsql/Internal/TypeInfoCache.cs +++ b/src/Npgsql/Internal/TypeInfoCache.cs @@ -1,21 +1,18 @@ using System; using System.Collections.Concurrent; -using System.Runtime.CompilerServices; using Npgsql.Internal.Postgres; namespace Npgsql.Internal; -sealed class TypeInfoCache where TPgTypeId : struct +sealed class TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) + where TPgTypeId : struct { - readonly PgSerializerOptions _options; - readonly bool _validatePgTypeIds; - // Mostly used for parameter writing, 8ns readonly ConcurrentDictionary _cacheByClrType = new(); // Used for reading, occasionally for parameter writing where a db type was given. 
// 8ns, about 10ns total to scan an array with 6, 7 different clr types under one pg type - readonly ConcurrentDictionary _cacheByPgTypeId = new(); + readonly ConcurrentDictionary _cacheByPgTypeId = new(); static TypeInfoCache() { @@ -23,32 +20,22 @@ static TypeInfoCache() throw new InvalidOperationException("Cannot use this type argument."); } - public TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) - { - _options = options; - _validatePgTypeIds = validatePgTypeIds; - } - /// /// /// /// /// - /// - /// When this flag is true, and both type and pgTypeId are non null, a default info for the pgTypeId can be returned if an exact match - /// can't be found. - /// /// /// - public PgTypeInfo? GetOrAddInfo(Type? type, TPgTypeId? pgTypeId, bool defaultTypeFallback = false) + public PgTypeInfo? GetOrAddInfo(Type? type, TPgTypeId? pgTypeId) { if (pgTypeId is { } id) { if (_cacheByPgTypeId.TryGetValue(id, out var infos)) - if (FindMatch(type, infos, defaultTypeFallback) is { } info) + if (FindMatch(type, infos) is { } info) return info; - return AddEntryById(type, id, infos, defaultTypeFallback); + return AddEntryById(type, id, infos); } if (type is not null) @@ -56,33 +43,22 @@ public TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) return null; - PgTypeInfo? FindMatch(Type? type, (Type? Type, PgTypeInfo? Info)[] infos, bool defaultTypeFallback) + PgTypeInfo? FindMatch(Type? type, (Type? Type, PgTypeInfo Info)[] infos) { - PgTypeInfo? 
defaultInfo = null; - var negativeExactMatch = false; for (var i = 0; i < infos.Length; i++) { ref var item = ref infos[i]; if (item.Type == type) - { - if (item.Info is not null || !defaultTypeFallback) - return item.Info; - negativeExactMatch = true; - } - - if (defaultTypeFallback && item.Type is null) - defaultInfo = item.Info; + return item.Info; } - // We can only return default info if we've seen a negative match (type: typeof(object), info: null) - // Otherwise we might return a previously requested default while the resolvers could produce the exact match. - return negativeExactMatch ? defaultInfo : null; + return null; } PgTypeInfo? AddByType(Type type) { // We don't pass PgTypeId as we're interested in default converters here. - var info = CreateInfo(type, null, _options, defaultTypeFallback: false, _validatePgTypeIds); + var info = CreateInfo(type, null, options, validatePgTypeIds); return info is null ? null @@ -91,17 +67,17 @@ public TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) : _cacheByClrType[type]; } - PgTypeInfo? AddEntryById(Type? type, TPgTypeId pgTypeId, (Type? Type, PgTypeInfo? Info)[]? infos, bool defaultTypeFallback) + PgTypeInfo? AddEntryById(Type? type, TPgTypeId pgTypeId, (Type? Type, PgTypeInfo Info)[]? infos) { - // We cache negatives (null info) to allow 'object or default' checks to never hit the resolvers after the first lookup. - var info = CreateInfo(type, pgTypeId, _options, defaultTypeFallback, _validatePgTypeIds); + if (CreateInfo(type, pgTypeId, options, validatePgTypeIds) is not { } info) + return null; - var isDefaultInfo = type is null && info is not null; + var isDefaultInfo = type is null; if (infos is null) { // Also add defaults by their info type to save a future resolver lookup + resize. infos = isDefaultInfo - ? new [] { (type, info), (info!.Type, info) } + ? 
new [] { (type, info), (info.Type, info) } : new [] { (type, info) }; if (_cacheByPgTypeId.TryAdd(pgTypeId, infos)) @@ -112,7 +88,7 @@ public TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) while (true) { infos = _cacheByPgTypeId[pgTypeId]; - if (FindMatch(type, infos, defaultTypeFallback) is { } racedInfo) + if (FindMatch(type, infos) is { } racedInfo) return racedInfo; // Also add defaults by their info type to save a future resolver lookup + resize. @@ -121,39 +97,37 @@ public TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) if (isDefaultInfo) { foreach (var oldInfo in oldInfos) - if (oldInfo.Type == info!.Type) + if (oldInfo.Type == info.Type) hasExactType = true; } Array.Resize(ref infos, oldInfos.Length + (isDefaultInfo && !hasExactType ? 2 : 1)); infos[oldInfos.Length] = (type, info); if (isDefaultInfo && !hasExactType) - infos[oldInfos.Length + 1] = (info!.Type, info); + infos[oldInfos.Length + 1] = (info.Type, info); if (_cacheByPgTypeId.TryUpdate(pgTypeId, infos, oldInfos)) return info; } } - static PgTypeInfo? CreateInfo(Type? type, TPgTypeId? typeId, PgSerializerOptions options, bool defaultTypeFallback, bool validatePgTypeIds) + static PgTypeInfo? CreateInfo(Type? type, TPgTypeId? typeId, PgSerializerOptions options, bool validatePgTypeIds) { var pgTypeId = AsPgTypeId(typeId); // Validate that we only pass data types that are supported by the backend. var dataTypeName = pgTypeId is { } id ? 
(DataTypeName?)options.DatabaseInfo.GetDataTypeName(id, validate: validatePgTypeIds) : null; var info = options.TypeInfoResolver.GetTypeInfo(type, dataTypeName, options); - if (info is null && defaultTypeFallback) - { - type = null; - info = options.TypeInfoResolver.GetTypeInfo(type, dataTypeName, options); - } - if (info is null) return null; if (pgTypeId is not null && info.PgTypeId != pgTypeId) throw new InvalidOperationException("A Postgres type was passed but the resolved PgTypeInfo does not have an equal PgTypeId."); - if (type is not null && !info.IsBoxing && info.Type != type) - throw new InvalidOperationException($"A CLR type '{type}' was passed but the resolved PgTypeInfo does not have an equal Type: {info.Type}."); + if (type is not null && info.Type != type) + { + // Types were not equal, throw for IsBoxing = false, otherwise we throw when the returned type isn't assignable to the requested type (after unboxing). + if (!info.IsBoxing || !info.Type.IsAssignableTo(type)) + throw new InvalidOperationException($"A CLR type '{type}' was passed but the resolved PgTypeInfo does not have an equal Type: {info.Type}."); + } return info; } @@ -161,8 +135,8 @@ public TypeInfoCache(PgSerializerOptions options, bool validatePgTypeIds = true) static PgTypeId? AsPgTypeId(TPgTypeId? 
pgTypeId) => pgTypeId switch { - { } id when typeof(TPgTypeId) == typeof(DataTypeName) => new PgTypeId(Unsafe.As(ref id)), - { } id => new PgTypeId(Unsafe.As(ref id)), + { } id when typeof(TPgTypeId) == typeof(DataTypeName) => new((DataTypeName)(object)id), + { } id => new((Oid)(object)id), null => null }; } diff --git a/src/Npgsql/Internal/TypeInfoMapping.cs b/src/Npgsql/Internal/TypeInfoMapping.cs index 00b9ba18ee..1fc028153f 100644 --- a/src/Npgsql/Internal/TypeInfoMapping.cs +++ b/src/Npgsql/Internal/TypeInfoMapping.cs @@ -16,11 +16,13 @@ namespace Npgsql.Internal; /// /// /// -/// -/// Signals whether a resolver based TypeInfo can keep its PgTypeId undecided or whether it should follow mapping.DataTypeName. +/// +/// Relevant for `PgResolverTypeInfo` only: whether the instance can be constructed without passing mapping.DataTypeName, an exception occurs otherwise. /// -public delegate PgTypeInfo TypeInfoFactory(PgSerializerOptions options, TypeInfoMapping mapping, bool resolvedDataTypeName); +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] +public delegate PgTypeInfo TypeInfoFactory(PgSerializerOptions options, TypeInfoMapping mapping, bool requiresDataTypeName); +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public enum MatchRequirement { /// Match when the clr type and datatype name both match. @@ -33,6 +35,7 @@ public enum MatchRequirement } /// A factory for well-known PgConverters. +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public static class PgConverterFactory { public static PgConverter CreateArrayMultirangeConverter(PgConverter rangeConverter, PgSerializerOptions options) where T : notnull @@ -55,26 +58,21 @@ public static PgConverter CreatePolymorphicArrayConverter(Func? TypeMatchPredicate { get; init; } public bool TypeEquals(Type type) => TypeMatchPredicate?.Invoke(type) ?? 
Type == type; - public bool DataTypeNameEquals(string dataTypeName) + + bool DataTypeNameEqualsCore(string dataTypeName) { var span = DataTypeName.AsSpan(); return Postgres.DataTypeName.IsFullyQualified(span) @@ -82,6 +80,18 @@ public bool DataTypeNameEquals(string dataTypeName) : span.Equals(Postgres.DataTypeName.ValidatedName(dataTypeName).UnqualifiedNameSpan, StringComparison.Ordinal); } + internal bool DataTypeNameEquals(DataTypeName dataTypeName) + { + var value = dataTypeName.Value; + return DataTypeNameEqualsCore(value); + } + + public bool DataTypeNameEquals(string dataTypeName) + { + var normalized = Postgres.DataTypeName.NormalizeName(dataTypeName); + return DataTypeNameEqualsCore(normalized); + } + string DebuggerDisplay { get @@ -99,6 +109,7 @@ string DebuggerDisplay } } +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public sealed class TypeInfoMappingCollection { readonly TypeInfoMappingCollection? _baseCollection; @@ -114,7 +125,7 @@ public TypeInfoMappingCollection(TypeInfoMappingCollection baseCollection) : thi => _baseCollection = baseCollection; public TypeInfoMappingCollection(IEnumerable items) - => _items = new(items); + => _items = [..items]; public IReadOnlyList Items => _items; @@ -127,7 +138,7 @@ public TypeInfoMappingCollection(IEnumerable items) { var looseTypeMatch = mapping.TypeMatchPredicate is { } pred ? pred(type) : type is null || mapping.Type == type; var typeMatch = type is not null && looseTypeMatch; - var dataTypeMatch = dataTypeName is not null && mapping.DataTypeNameEquals(dataTypeName.Value.Value); + var dataTypeMatch = dataTypeName is not null && mapping.DataTypeNameEquals(dataTypeName.Value); var matchRequirement = mapping.MatchRequirement; if (dataTypeMatch && typeMatch @@ -185,39 +196,42 @@ TypeInfoMapping GetMapping(Type type, string dataTypeName) => TryGetMapping(type, dataTypeName, out var info) ? 
info : throw new InvalidOperationException($"Could not find mapping for {type} <-> {dataTypeName}"); // Helper to eliminate generic display class duplication. - static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping innerMapping, Func mapper, bool copyPreferredFormat = false, bool supportsWriting = true) - => (options, mapping, dataTypeNameMatch) => + static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping innerMapping, Func mapper, bool copyPreferredFormat = false, bool? supportsReading = null, bool? supportsWriting = null) + => (options, mapping, requiresDataTypeName) => { var resolvedInnerMapping = innerMapping; if (!DataTypeName.IsFullyQualified(innerMapping.DataTypeName.AsSpan())) resolvedInnerMapping = innerMapping with { DataTypeName = new DataTypeName(mapping.DataTypeName).Schema + "." + innerMapping.DataTypeName }; - var innerInfo = innerMapping.Factory(options, resolvedInnerMapping, dataTypeNameMatch); + var innerInfo = innerMapping.Factory(options, resolvedInnerMapping, requiresDataTypeName); var converter = mapper(mapping, innerInfo); var preferredFormat = copyPreferredFormat ? innerInfo.PreferredFormat : null; - var writingSupported = supportsWriting && innerInfo.SupportsWriting; var unboxedType = ComputeUnboxedType(defaultType: mappingType, converter.TypeToConvert, mapping.Type); + var readingSupported = innerInfo.SupportsReading && (supportsReading ?? PgTypeInfo.GetDefaultSupportsReading(converter.TypeToConvert, unboxedType)); + var writingSupported = innerInfo.SupportsWriting && (supportsWriting ?? true); return new PgTypeInfo(options, converter, options.GetCanonicalTypeId(new DataTypeName(mapping.DataTypeName)), unboxedType) { PreferredFormat = preferredFormat, + SupportsReading = readingSupported, SupportsWriting = writingSupported }; }; // Helper to eliminate generic display class duplication. 
- static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping innerMapping, Func mapper, bool copyPreferredFormat = false, bool supportsWriting = true) - => (options, mapping, dataTypeNameMatch) => + static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping innerMapping, Func mapper, bool copyPreferredFormat = false, bool? supportsReading = null, bool? supportsWriting = null) + => (options, mapping, requiresDataTypeName) => { var resolvedInnerMapping = innerMapping; if (!DataTypeName.IsFullyQualified(innerMapping.DataTypeName.AsSpan())) resolvedInnerMapping = innerMapping with { DataTypeName = new DataTypeName(mapping.DataTypeName).Schema + "." + innerMapping.DataTypeName }; - var innerInfo = (PgResolverTypeInfo)innerMapping.Factory(options, resolvedInnerMapping, dataTypeNameMatch); + var innerInfo = (PgResolverTypeInfo)innerMapping.Factory(options, resolvedInnerMapping, requiresDataTypeName); var resolver = mapper(mapping, innerInfo); var preferredFormat = copyPreferredFormat ? innerInfo.PreferredFormat : null; - var writingSupported = supportsWriting && innerInfo.SupportsWriting; var unboxedType = ComputeUnboxedType(defaultType: mappingType, resolver.TypeToConvert, mapping.Type); + var readingSupported = innerInfo.SupportsReading && (supportsReading ?? PgTypeInfo.GetDefaultSupportsReading(resolver.TypeToConvert, unboxedType)); + var writingSupported = innerInfo.SupportsWriting && (supportsWriting ?? true); // We include the data type name if the inner info did so as well. // This way we can rely on its logic around resolvedDataTypeName, including when it ignores that flag. PgTypeId? 
pgTypeId = innerInfo.PgTypeId is not null @@ -226,6 +240,7 @@ static TypeInfoFactory CreateComposedFactory(Type mappingType, TypeInfoMapping i return new PgResolverTypeInfo(options, resolver, pgTypeId, unboxedType) { PreferredFormat = preferredFormat, + SupportsReading = readingSupported, SupportsWriting = writingSupported }; }; @@ -340,7 +355,7 @@ public void AddArrayType(TypeInfoMapping elementMapping, bool suppress void AddArrayType(TypeInfoMapping elementMapping, Type type, Func converter, Func? typeMatchPredicate = null, bool suppressObjectMapping = false) { - var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter, supportsReading: true)) { MatchRequirement = elementMapping.MatchRequirement, TypeMatchPredicate = typeMatchPredicate @@ -348,12 +363,12 @@ void AddArrayType(TypeInfoMapping elementMapping, Type type, Func + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, requiresDataTypeName) => { - if (!dataTypeNameMatch) + if (!requiresDataTypeName) throw new InvalidOperationException("Should not happen, please file a bug."); - return arrayMapping.Factory(options, mapping, dataTypeNameMatch); + return arrayMapping.Factory(options, mapping, requiresDataTypeName); })); } } @@ -380,7 +395,7 @@ public void AddResolverArrayType(TypeInfoMapping elementMapping, bool void AddResolverArrayType(TypeInfoMapping elementMapping, Type type, Func converter, Func? 
typeMatchPredicate = null, bool suppressObjectMapping = false) { - var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter, supportsReading: true)) { MatchRequirement = elementMapping.MatchRequirement, TypeMatchPredicate = typeMatchPredicate @@ -388,12 +403,12 @@ void AddResolverArrayType(TypeInfoMapping elementMapping, Type type, Func + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, requiresDataTypeName) => { - if (!dataTypeNameMatch) + if (!requiresDataTypeName) throw new InvalidOperationException("Should not happen, please file a bug."); - return arrayMapping.Factory(options, mapping, dataTypeNameMatch); + return arrayMapping.Factory(options, mapping, requiresDataTypeName); })); } } @@ -472,12 +487,12 @@ void AddStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullable Func? typeMatchPredicate, Func? 
nullableTypeMatchPredicate, bool suppressObjectMapping) { var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); - var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter, supportsReading: true)) { MatchRequirement = elementMapping.MatchRequirement, TypeMatchPredicate = typeMatchPredicate }; - var nullableArrayMapping = new TypeInfoMapping(nullableType, arrayDataTypeName, CreateComposedFactory(nullableType, nullableElementMapping, nullableConverter)) + var nullableArrayMapping = new TypeInfoMapping(nullableType, arrayDataTypeName, CreateComposedFactory(nullableType, nullableElementMapping, nullableConverter, supportsReading: true)) { MatchRequirement = arrayMapping.MatchRequirement, TypeMatchPredicate = nullableTypeMatchPredicate @@ -487,16 +502,16 @@ void AddStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping nullable _items.Add(nullableArrayMapping); suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) - _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, dataTypeNameMatch) => + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, requiresDataTypeName) => { return options.ArrayNullabilityMode switch { - _ when !dataTypeNameMatch => throw new InvalidOperationException("Should not happen, please file a bug."), - ArrayNullabilityMode.Never => arrayMapping.Factory(options, mapping, dataTypeNameMatch), - ArrayNullabilityMode.Always => nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + _ when !requiresDataTypeName => throw new InvalidOperationException("Should not happen, please file a bug."), + 
ArrayNullabilityMode.Never => arrayMapping.Factory(options, mapping, requiresDataTypeName), + ArrayNullabilityMode.Always => nullableArrayMapping.Factory(options, mapping, requiresDataTypeName), ArrayNullabilityMode.PerInstance => CreateComposedPerInstance( - arrayMapping.Factory(options, mapping, dataTypeNameMatch), - nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + arrayMapping.Factory(options, mapping, requiresDataTypeName), + nullableArrayMapping.Factory(options, mapping, requiresDataTypeName), mapping.DataTypeName ), _ => throw new ArgumentOutOfRangeException() @@ -590,12 +605,12 @@ void AddResolverStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping { var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); - var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter)) + var arrayMapping = new TypeInfoMapping(type, arrayDataTypeName, CreateComposedFactory(type, elementMapping, converter, supportsReading: true)) { MatchRequirement = elementMapping.MatchRequirement, TypeMatchPredicate = typeMatchPredicate }; - var nullableArrayMapping = new TypeInfoMapping(nullableType, arrayDataTypeName, CreateComposedFactory(nullableType, nullableElementMapping, nullableConverter)) + var nullableArrayMapping = new TypeInfoMapping(nullableType, arrayDataTypeName, CreateComposedFactory(nullableType, nullableElementMapping, nullableConverter, supportsReading: true)) { MatchRequirement = elementMapping.MatchRequirement, TypeMatchPredicate = nullableTypeMatchPredicate @@ -605,14 +620,14 @@ void AddResolverStructArrayType(TypeInfoMapping elementMapping, TypeInfoMapping _items.Add(nullableArrayMapping); suppressObjectMapping = suppressObjectMapping || arrayMapping.TypeEquals(typeof(object)); if (!suppressObjectMapping && arrayMapping.MatchRequirement is MatchRequirement.DataTypeName or MatchRequirement.Single) - _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, 
(options, mapping, dataTypeNameMatch) => options.ArrayNullabilityMode switch + _items.Add(new TypeInfoMapping(typeof(object), arrayDataTypeName, (options, mapping, requiresDataTypeName) => options.ArrayNullabilityMode switch { - _ when !dataTypeNameMatch => throw new InvalidOperationException("Should not happen, please file a bug."), - ArrayNullabilityMode.Never => arrayMapping.Factory(options, mapping, dataTypeNameMatch), - ArrayNullabilityMode.Always => nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + _ when !requiresDataTypeName => throw new InvalidOperationException("Should not happen, please file a bug."), + ArrayNullabilityMode.Never => arrayMapping.Factory(options, mapping, requiresDataTypeName), + ArrayNullabilityMode.Always => nullableArrayMapping.Factory(options, mapping, requiresDataTypeName), ArrayNullabilityMode.PerInstance => CreateComposedPerInstance( - arrayMapping.Factory(options, mapping, dataTypeNameMatch), - nullableArrayMapping.Factory(options, mapping, dataTypeNameMatch), + arrayMapping.Factory(options, mapping, requiresDataTypeName), + nullableArrayMapping.Factory(options, mapping, requiresDataTypeName), mapping.DataTypeName ), _ => throw new ArgumentOutOfRangeException() @@ -625,7 +640,7 @@ PgTypeInfo CreateComposedPerInstance(PgTypeInfo innerTypeInfo, PgTypeInfo nullab (PgResolverTypeInfo)nullableInnerTypeInfo); return new PgResolverTypeInfo(innerTypeInfo.Options, resolver, - innerTypeInfo.Options.GetCanonicalTypeId(new DataTypeName(dataTypeName))) { SupportsWriting = false }; + innerTypeInfo.Options.GetCanonicalTypeId(new DataTypeName(dataTypeName)), unboxedType: typeof(Array)) { SupportsWriting = false }; } } @@ -643,7 +658,7 @@ void AddPolymorphicResolverArrayType(TypeInfoMapping elementMapping, Type type, { var arrayDataTypeName = GetArrayDataTypeName(elementMapping.DataTypeName); var mapping = new TypeInfoMapping(type, arrayDataTypeName, - CreateComposedFactory(typeof(Array), elementMapping, converter, 
supportsWriting: false)) + CreateComposedFactory(typeof(Array), elementMapping, converter, supportsReading: true, supportsWriting: false)) { MatchRequirement = elementMapping.MatchRequirement, TypeMatchPredicate = typeMatchPredicate @@ -731,6 +746,7 @@ static void ThrowBoxingNotSupported(bool resolver) => throw new InvalidOperationException($"Boxing converters are not supported, manually construct a mapping over a casting converter{(resolver ? " resolver" : "")} instead."); } +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public static class TypeInfoMappingHelpers { internal static bool TryResolveFullyQualifiedName(PgSerializerOptions options, string dataTypeName, out DataTypeName fqDataTypeName) @@ -754,6 +770,31 @@ internal static bool TryResolveFullyQualifiedName(PgSerializerOptions options, s internal static PostgresType GetPgType(this TypeInfoMapping mapping, PgSerializerOptions options) => options.DatabaseInfo.GetPostgresType(new DataTypeName(mapping.DataTypeName)); + // NOTE: This method exists since 9.0 to be able to deprecate the method below that has optional arguments in 10.0 (potentially removing it directly or in 11.0). + // It reduces how binary breaking that change will be if this method would not be there to be picked for the most common invocations. + /// + /// Creates a PgTypeInfo from a mapping, optins, and a converter. + /// + /// The mapping to create an info for. + /// The options to use. + /// The converter to create a PgTypeInfo for. + /// The created info instance. + public static PgTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverter converter) + => new(options, converter, new DataTypeName(mapping.DataTypeName)) + { + PreferredFormat = null, + SupportsWriting = true + }; + + /// + /// Creates a PgTypeInfo from a mapping, options, and a converter. + /// + /// The mapping to create an info for. + /// The options to use. + /// The converter to create a PgTypeInfo for. 
+ /// Whether to prefer a specific data format for this info, when null it defaults to the most suitable format. + /// Whether the converters returned from the given converter resolver support writing. + /// The created info instance. public static PgTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverter converter, DataFormat? preferredFormat = null, bool supportsWriting = true) => new(options, converter, new DataTypeName(mapping.DataTypeName)) { @@ -761,7 +802,33 @@ public static PgTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOp SupportsWriting = supportsWriting }; - public static PgResolverTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverterResolver resolver, bool includeDataTypeName = true, DataFormat? preferredFormat = null, bool supportsWriting = true) + // NOTE: This method exists since 9.0 to be able to deprecate the method below that has optional arguments in 10.0 (potentially removing it directly or in 11.0). + // It reduces how binary breaking that change will be if this method would not be there to be picked for the most common invocations. + /// + /// Creates a PgResolverTypeInfo from a mapping, options, and a converter resolver. + /// + /// The mapping to create an info for. + /// The options to use. + /// The resolver to create a PgResolverTypeInfo for. + /// Whether to pass mapping.DataTypeName to the PgResolverTypeInfo constructor, mandatory when TypeInfoFactory(..., requiresDataTypeName: true). + /// The created info instance. + public static PgResolverTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverterResolver resolver, bool includeDataTypeName) + => new(options, resolver, includeDataTypeName ? new DataTypeName(mapping.DataTypeName) : null) + { + PreferredFormat = null + }; + + /// + /// Creates a PgResolverTypeInfo from a mapping, options, and a converter resolver. + /// + /// The mapping to create an info for. 
+ /// The options to use. + /// The converter resolver to create a PgResolverTypeInfo for. + /// Whether to pass mapping.DataTypeName to the PgResolverTypeInfo constructor, mandatory when TypeInfoFactory(..., requiresDataTypeName: true). + /// Whether to prefer a specific data format for this info, when null it defaults to the most suitable format. + /// Whether the converters returned from the given converter resolver support writing. + /// The created info instance. + public static PgResolverTypeInfo CreateInfo(this TypeInfoMapping mapping, PgSerializerOptions options, PgConverterResolver resolver, bool includeDataTypeName, DataFormat? preferredFormat = null, bool supportsWriting = true) => new(options, resolver, includeDataTypeName ? new DataTypeName(mapping.DataTypeName) : null) { PreferredFormat = preferredFormat, diff --git a/src/Npgsql/Internal/ValueMetadata.cs b/src/Npgsql/Internal/ValueMetadata.cs index ff041a3060..b71028c4a1 100644 --- a/src/Npgsql/Internal/ValueMetadata.cs +++ b/src/Npgsql/Internal/ValueMetadata.cs @@ -1,5 +1,8 @@ +using System.Diagnostics.CodeAnalysis; + namespace Npgsql.Internal; +[Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public readonly struct ValueMetadata { public required DataFormat Format { get; init; } diff --git a/src/Npgsql/KerberosUsernameProvider.cs b/src/Npgsql/KerberosUsernameProvider.cs index 0395bca337..6963e139f0 100644 --- a/src/Npgsql/KerberosUsernameProvider.cs +++ b/src/Npgsql/KerberosUsernameProvider.cs @@ -11,9 +11,9 @@ namespace Npgsql; /// Launches MIT Kerberos klist and parses out the default principal from it. /// Caches the result. /// -sealed class KerberosUsernameProvider +static class KerberosUsernameProvider { - static bool _performedDetection; + static volatile bool _performedDetection; static string? _principalWithRealm; static string? 
_principalWithoutRealm; @@ -61,11 +61,7 @@ sealed class KerberosUsernameProvider var line = default(string); for (var i = 0; i < 2; i++) // ReSharper disable once MethodHasAsyncOverload -#if NET7_0_OR_GREATER if ((line = async ? await process.StandardOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) : process.StandardOutput.ReadLine()) == null) -#else - if ((line = async ? await process.StandardOutput.ReadLineAsync().ConfigureAwait(false) : process.StandardOutput.ReadLine()) == null) -#endif { connectionLogger.LogDebug("Unexpected output from klist, aborting Kerberos username detection"); return null; @@ -104,7 +100,7 @@ sealed class KerberosUsernameProvider static string? FindInPath(string name) { - foreach (var p in Environment.GetEnvironmentVariable("PATH")?.Split(Path.PathSeparator) ?? Array.Empty()) + foreach (var p in Environment.GetEnvironmentVariable("PATH")?.Split(Path.PathSeparator) ?? []) { var path = Path.Combine(p, name); if (File.Exists(path)) diff --git a/src/Npgsql/LogMessages.cs b/src/Npgsql/LogMessages.cs index 8d5f471c27..757f972764 100644 --- a/src/Npgsql/LogMessages.cs +++ b/src/Npgsql/LogMessages.cs @@ -26,12 +26,6 @@ static partial class LogMessages Message = "Opened connection to {Host}:{Port}/{Database}")] internal static partial void OpenedConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString, int ConnectorId); - [LoggerMessage( - EventId = NpgsqlEventId.OpenedConnection, - Level = LogLevel.Debug, - Message = "Opened multiplexing connection to {Host}:{Port}/{Database}")] - internal static partial void OpenedMultiplexingConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString); - [LoggerMessage( EventId = NpgsqlEventId.ClosingConnection, Level = LogLevel.Trace, @@ -44,12 +38,6 @@ static partial class LogMessages Message = "Closed connection to {Host}:{Port}/{Database}")] internal static partial void ClosedConnection(ILogger logger, string Host, int Port, 
string Database, string ConnectionString, int ConnectorId); - [LoggerMessage( - EventId = NpgsqlEventId.ClosedConnection, - Level = LogLevel.Debug, - Message = "Closed multiplexing connection to {Host}:{Port}/{Database}")] - internal static partial void ClosedMultiplexingConnection(ILogger logger, string Host, int Port, string Database, string ConnectionString); - [LoggerMessage( EventId = NpgsqlEventId.OpeningPhysicalConnection, Level = LogLevel.Trace, @@ -134,12 +122,6 @@ static partial class LogMessages Message = "Exception while closing connector")] internal static partial void ExceptionWhenClosingPhysicalConnection(ILogger logger, int ConnectorId, Exception exception); - [LoggerMessage( - EventId = NpgsqlEventId.ExceptionWhenOpeningConnectionForMultiplexing, - Level = LogLevel.Error, - Message = "Exception opening a connection for multiplexing")] - internal static partial void ExceptionWhenOpeningConnectionForMultiplexing(ILogger logger, Exception exception); - [LoggerMessage( Level = LogLevel.Trace, Message = "Start user action")] @@ -180,7 +162,7 @@ static partial class LogMessages Level = LogLevel.Debug, Message = "Executing batch: {BatchCommands}", SkipEnabledCheck = true)] - internal static partial void ExecutingBatchWithParameters(ILogger logger, (string CommandText, object[] Parameters)[] BatchCommands, int ConnectorId); + internal static partial void ExecutingBatchWithParameters(ILogger logger, (string CommandText, IEnumerable Parameters)[] BatchCommands, int ConnectorId); [LoggerMessage( EventId = NpgsqlEventId.CommandExecutionCompleted, @@ -209,7 +191,7 @@ static partial class LogMessages Message = "Batch execution completed (duration={DurationMs}ms): {BatchCommands}", SkipEnabledCheck = true)] internal static partial void BatchExecutionCompletedWithParameters( - ILogger logger, (string CommandText, object[] Parameters)[] BatchCommands, long DurationMs, int ConnectorId); + ILogger logger, (string CommandText, IEnumerable Parameters)[] BatchCommands, 
long DurationMs, int ConnectorId); [LoggerMessage( EventId = NpgsqlEventId.CancellingCommand, @@ -254,12 +236,6 @@ internal static partial void BatchExecutionCompletedWithParameters( Message = "Deriving Parameters for query: {CommandText}")] internal static partial void DerivingParameters(ILogger logger, string CommandText, int ConnectorId); - [LoggerMessage( - EventId = NpgsqlEventId.ExceptionWhenWritingMultiplexedCommands, - Level = LogLevel.Error, - Message = "Exception while writing multiplexed commands")] - internal static partial void ExceptionWhenWritingMultiplexedCommands(ILogger logger, int ConnectorId, Exception exception); - [LoggerMessage( Level = LogLevel.Trace, Message = "Cleaning up reader")] diff --git a/src/Npgsql/MetricsReporter.cs b/src/Npgsql/MetricsReporter.cs index f806e44852..431c0ea734 100644 --- a/src/Npgsql/MetricsReporter.cs +++ b/src/Npgsql/MetricsReporter.cs @@ -1,7 +1,6 @@ -using System; - namespace Npgsql; +using System; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.Metrics; @@ -9,7 +8,7 @@ namespace Npgsql; using System.Threading; // .NET docs on metric instrumentation: https://learn.microsoft.com/en-us/dotnet/core/diagnostics/metrics-instrumentation -// OpenTelemetry semantic conventions for database metric: https://opentelemetry.io/docs/specs/otel/metrics/semantic_conventions/database-metrics +// OpenTelemetry semantic conventions for database metric: https://opentelemetry.io/docs/specs/semconv/database/database-metrics sealed class MetricsReporter : IDisposable { const string Version = "0.1.0"; @@ -29,9 +28,16 @@ sealed class MetricsReporter : IDisposable static readonly ObservableGauge PreparedRatio; readonly NpgsqlDataSource _dataSource; + readonly KeyValuePair _poolNameTag; + readonly TagList _durationMetricTags; + + static readonly List Reporters = []; - static readonly List Reporters = new(); + static readonly InstrumentAdvice ShortHistogramAdvice = new() + { + HistogramBucketBoundaries 
= [0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10] + }; CommandCounters _commandCounters; @@ -47,64 +53,68 @@ static MetricsReporter() { Meter = new("Npgsql", Version); + // db.client.operation.duration is stable in the OpenTelemetry spec + CommandDuration = Meter.CreateHistogram( + "db.client.operation.duration", + unit: "s", + description: "Duration of database client operations.", + advice: ShortHistogramAdvice); + + // From here, metrics have "development" status (not stable) + Meter.CreateObservableUpDownCounter( + "db.client.connection.count", + GetConnectionCount, + unit: "{connection}", + description: "The number of connections that are currently in state described by the state attribute."); + + // It's a bit ridiculous to manage "max connections" as an observable counter, given that it never changes for a given pool. + // However, we can't simply report it once at startup, since clients who connect later wouldn't have it. And since reporting it + // repeatedly isn't possible because we need to provide incremental figures, we just manage it as an observable counter. + Meter.CreateObservableUpDownCounter( + "db.client.connection.max", + GetConnectionMax, + unit: "{connection}", + description: "The maximum number of open connections allowed."); + + // From here, metrics are entirely Npgsql-specific and not covered by the OpenTelemetry spec. 
CommandsExecuting = Meter.CreateUpDownCounter( - "db.client.commands.executing", + "db.client.operation.npgsql.executing", unit: "{command}", description: "The number of currently executing database commands."); CommandsFailed = Meter.CreateCounter( - "db.client.commands.failed", + "db.client.operation.failed", unit: "{command}", description: "The number of database commands which have failed."); - CommandDuration = Meter.CreateHistogram( - "db.client.commands.duration", - unit: "s", - description: "The duration of database commands, in seconds."); - BytesWritten = Meter.CreateCounter( - "db.client.commands.bytes_written", + "db.client.operation.npgsql.bytes_written", unit: "By", description: "The number of bytes written."); BytesRead = Meter.CreateCounter( - "db.client.commands.bytes_read", + "db.client.operation.npgsql.bytes_read", unit: "By", description: "The number of bytes read."); PendingConnectionRequests = Meter.CreateUpDownCounter( - "db.client.connections.pending_requests", + "db.client.connection.npgsql.pending_requests", unit: "{request}", description: "The number of pending requests for an open connection, cumulative for the entire pool."); ConnectionTimeouts = Meter.CreateCounter( - "db.client.connections.timeouts", + "db.client.connection.npgsql.timeouts", unit: "{timeout}", description: "The number of connection timeouts that have occurred trying to obtain a connection from the pool."); ConnectionCreateTime = Meter.CreateHistogram( - "db.client.connections.create_time", + "db.client.connection.npgsql.create_time", unit: "s", - description: "The time it took to create a new connection."); - - // Observable metrics; these are for values we already track internally (and efficiently) inside the connection pool implementation. 
- Meter.CreateObservableUpDownCounter( - "db.client.connections.usage", - GetConnectionUsage, - unit: "{connection}", - description: "The number of connections that are currently in state described by the state attribute."); - - // It's a bit ridiculous to manage "max connections" as an observable counter, given that it never changes for a given pool. - // However, we can't simply report it once at startup, since clients who connect later wouldn't have it. And since reporting it - // repeatedly isn't possible because we need to provide incremental figures, we just manage it as an observable counter. - Meter.CreateObservableUpDownCounter( - "db.client.connections.max", - GetMaxConnections, - unit: "{connection}", - description: "The maximum number of open connections allowed."); + description: "The time it took to create a new connection.", + advice: ShortHistogramAdvice); PreparedRatio = Meter.CreateObservableGauge( - "db.client.commands.prepared_ratio", + "db.client.operation.npgsql.prepared_ratio", GetPreparedCommandsRatio, description: "The ratio of prepared command executions."); } @@ -112,7 +122,16 @@ static MetricsReporter() public MetricsReporter(NpgsqlDataSource dataSource) { _dataSource = dataSource; - _poolNameTag = new KeyValuePair("pool.name", dataSource.Name); + _poolNameTag = new KeyValuePair("db.client.connection.pool.name", dataSource.Name); + + _durationMetricTags = new TagList + { + // TODO: Vary this for PG-like databases (e.g. CockroachDB)? 
+ { "db.system.name", "postgresql" }, + { "db.client.connection.pool.name", _dataSource.Name }, + { "server.address", _dataSource.Settings.Host }, + { "server.port", _dataSource.Settings.Port } + }; lock (Reporters) { @@ -136,12 +155,7 @@ internal void ReportCommandStop(long startTimestamp) if (CommandDuration.Enabled && startTimestamp > 0) { -#if NET7_0_OR_GREATER - var duration = Stopwatch.GetElapsedTime(startTimestamp); -#else - var duration = new TimeSpan((long)((Stopwatch.GetTimestamp() - startTimestamp) * StopWatchTickFrequency)); -#endif - CommandDuration.Record(duration.TotalSeconds, _poolNameTag); + CommandDuration.Record(Stopwatch.GetElapsedTime(startTimestamp).TotalSeconds, _durationMetricTags); } } @@ -167,7 +181,7 @@ internal void ReportPendingConnectionRequestStop() internal void ReportConnectionCreateTime(TimeSpan duration) => ConnectionCreateTime.Record(duration.TotalSeconds, _poolNameTag); - static IEnumerable> GetConnectionUsage() + static IEnumerable> GetConnectionCount() { lock (Reporters) { @@ -177,27 +191,23 @@ static IEnumerable> GetConnectionUsage() { var reporter = Reporters[i]; - if (reporter._dataSource is PoolingDataSource poolingDataSource) - { - var stats = poolingDataSource.Statistics; - - measurements.Add(new Measurement( - stats.Idle, - reporter._poolNameTag, - new KeyValuePair("state", "idle"))); + var connectionStats = reporter._dataSource.Statistics; + measurements.Add(new Measurement( + connectionStats.Idle, + reporter._poolNameTag, + new KeyValuePair("db.client.connection.state", "idle"))); - measurements.Add(new Measurement( - stats.Busy, - reporter._poolNameTag, - new KeyValuePair("state", "used"))); - } + measurements.Add(new Measurement( + connectionStats.Busy, + reporter._poolNameTag, + new KeyValuePair("db.client.connection.state", "used"))); } return measurements; } } - static IEnumerable> GetMaxConnections() + static IEnumerable> GetConnectionMax() { lock (Reporters) { @@ -247,11 +257,4 @@ public void Dispose() 
Reporters.Remove(this); } } - -#if !NET7_0_OR_GREATER - const long TicksPerMicrosecond = 10; - const long TicksPerMillisecond = TicksPerMicrosecond * 1000; - const long TicksPerSecond = TicksPerMillisecond * 1000; // 10,000,000 - static readonly double StopWatchTickFrequency = (double)TicksPerSecond / Stopwatch.Frequency; -#endif } diff --git a/src/Npgsql/MultiHostDataSourceWrapper.cs b/src/Npgsql/MultiHostDataSourceWrapper.cs index 4dcded98cc..432875ae67 100644 --- a/src/Npgsql/MultiHostDataSourceWrapper.cs +++ b/src/Npgsql/MultiHostDataSourceWrapper.cs @@ -7,15 +7,14 @@ namespace Npgsql; -sealed class MultiHostDataSourceWrapper : NpgsqlDataSource +sealed class MultiHostDataSourceWrapper(NpgsqlMultiHostDataSource wrappedSource, TargetSessionAttributes targetSessionAttributes) + : NpgsqlDataSource(CloneSettingsForTargetSessionAttributes(wrappedSource.Settings, targetSessionAttributes), wrappedSource.Configuration, reportMetrics: false) { - internal override bool OwnsConnectors => false; + internal NpgsqlMultiHostDataSource WrappedSource { get; } = wrappedSource; - readonly NpgsqlMultiHostDataSource _wrappedSource; + internal override bool OwnsConnectors => false; - public MultiHostDataSourceWrapper(NpgsqlMultiHostDataSource source, TargetSessionAttributes targetSessionAttributes) - : base(CloneSettingsForTargetSessionAttributes(source.Settings, targetSessionAttributes), source.Configuration) - => _wrappedSource = source; + public override void Clear() => WrappedSource.Clear(); static NpgsqlConnectionStringBuilder CloneSettingsForTargetSessionAttributes( NpgsqlConnectionStringBuilder settings, @@ -26,23 +25,22 @@ static NpgsqlConnectionStringBuilder CloneSettingsForTargetSessionAttributes( return clonedSettings; } - internal override (int Total, int Idle, int Busy) Statistics => _wrappedSource.Statistics; + internal override (int Total, int Idle, int Busy) Statistics => WrappedSource.Statistics; - internal override void Clear() => _wrappedSource.Clear(); internal 
override ValueTask Get(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) - => _wrappedSource.Get(conn, timeout, async, cancellationToken); + => WrappedSource.Get(conn, timeout, async, cancellationToken); internal override bool TryGetIdleConnector([NotNullWhen(true)] out NpgsqlConnector? connector) => throw new NpgsqlException("Npgsql bug: trying to get an idle connector from " + nameof(MultiHostDataSourceWrapper)); internal override ValueTask OpenNewConnector(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) => throw new NpgsqlException("Npgsql bug: trying to open a new connector from " + nameof(MultiHostDataSourceWrapper)); internal override void Return(NpgsqlConnector connector) - => _wrappedSource.Return(connector); + => WrappedSource.Return(connector); internal override void AddPendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) - => _wrappedSource.AddPendingEnlistedConnector(connector, transaction); + => WrappedSource.AddPendingEnlistedConnector(connector, transaction); internal override bool TryRemovePendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) - => _wrappedSource.TryRemovePendingEnlistedConnector(connector, transaction); + => WrappedSource.TryRemovePendingEnlistedConnector(connector, transaction); internal override bool TryRentEnlistedPending(Transaction transaction, NpgsqlConnection connection, [NotNullWhen(true)] out NpgsqlConnector? 
connector) - => _wrappedSource.TryRentEnlistedPending(transaction, connection, out connector); -} \ No newline at end of file + => WrappedSource.TryRentEnlistedPending(transaction, connection, out connector); +} diff --git a/src/Npgsql/MultiplexingDataSource.cs b/src/Npgsql/MultiplexingDataSource.cs deleted file mode 100644 index 277bc4e835..0000000000 --- a/src/Npgsql/MultiplexingDataSource.cs +++ /dev/null @@ -1,400 +0,0 @@ -using System; -using System.Diagnostics; -using System.Runtime.CompilerServices; -using System.Threading; -using System.Threading.Channels; -using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Npgsql.Internal; -using Npgsql.Util; - -namespace Npgsql; - -sealed class MultiplexingDataSource : PoolingDataSource -{ - readonly ILogger _connectionLogger; - readonly ILogger _commandLogger; - - readonly bool _autoPrepare; - - readonly ChannelReader _multiplexCommandReader; - internal ChannelWriter MultiplexCommandWriter { get; } - - readonly Task _multiplexWriteLoop; - - /// - /// When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before - /// flushing to the network. - /// - readonly int _writeCoalescingBufferThresholdBytes; - - // TODO: Make this configurable - const int MultiplexingCommandChannelBound = 4096; - - internal MultiplexingDataSource( - NpgsqlConnectionStringBuilder settings, - NpgsqlDataSourceConfiguration dataSourceConfig, - NpgsqlMultiHostDataSource? 
parentPool = null) - : base(settings, dataSourceConfig, parentPool) - { - Debug.Assert(Settings.Multiplexing); - - // TODO: Validate multiplexing options are set only when Multiplexing is on - - _autoPrepare = settings.MaxAutoPrepare > 0; - - _writeCoalescingBufferThresholdBytes = Settings.WriteCoalescingBufferThresholdBytes; - - var multiplexCommandChannel = Channel.CreateBounded( - new BoundedChannelOptions(MultiplexingCommandChannelBound) - { - FullMode = BoundedChannelFullMode.Wait, - SingleReader = true - }); - _multiplexCommandReader = multiplexCommandChannel.Reader; - MultiplexCommandWriter = multiplexCommandChannel.Writer; - - _connectionLogger = dataSourceConfig.LoggingConfiguration.ConnectionLogger; - _commandLogger = dataSourceConfig.LoggingConfiguration.CommandLogger; - - _multiplexWriteLoop = Task.Run(MultiplexingWriteLoop, CancellationToken.None) - .ContinueWith(t => - { - if (t.IsFaulted) - { - // Note that MultiplexingWriteLoop should never throw an exception - everything should be caught and handled internally. - _connectionLogger.LogError(t.Exception, "Exception in multiplexing write loop, this is an Npgsql bug, please file an issue."); - } - }); - } - - async Task MultiplexingWriteLoop() - { - // This method is async, but only ever yields when there are no pending commands in the command channel. - // No I/O should ever be performed asynchronously, as that would block further writing for the entire - // application; whenever I/O cannot complete immediately, we chain a callback with ContinueWith and move - // on to the next connector. - Debug.Assert(_multiplexCommandReader != null); - - var stats = new MultiplexingStats { Stopwatch = new Stopwatch() }; - - while (true) - { - NpgsqlConnector? connector; - NpgsqlCommand? command; - - try - { - // Get a first command out. 
- if (!_multiplexCommandReader.TryRead(out command)) - command = await _multiplexCommandReader.ReadAsync().ConfigureAwait(false); - } - catch (ChannelClosedException) - { - return; - } - - try - { - // First step is to get a connector on which to execute - var spinwait = new SpinWait(); - while (true) - { - if (TryGetIdleConnector(out connector)) - { - // See increment under over-capacity mode below - Interlocked.Increment(ref connector.CommandsInFlightCount); - break; - } - - connector = await OpenNewConnector( - command.InternalConnection!, - new NpgsqlTimeout(TimeSpan.FromSeconds(Settings.Timeout)), - async: true, - CancellationToken.None).ConfigureAwait(false); - - if (connector != null) - { - // Managed to created a new connector - connector.Connection = null; - - // See increment under over-capacity mode below - Interlocked.Increment(ref connector.CommandsInFlightCount); - - break; - } - - // There were no idle connectors and we're at max capacity, so we can't open a new one. - // Enter over-capacity mode - find an unlocked connector with the least currently in-flight - // commands and sent on it, even though there are already pending commands. - var minInFlight = int.MaxValue; - foreach (var c in Connectors) - { - if (c?.MultiplexAsyncWritingLock == 0 && c.CommandsInFlightCount < minInFlight) - { - minInFlight = c.CommandsInFlightCount; - connector = c; - } - } - - // There could be no writable connectors (all stuck in transaction or flushing). - if (connector == null) - { - // TODO: This is problematic - when absolutely all connectors are both busy *and* currently - // performing (async) I/O, this will spin-wait. - // We could call WaitAsync, but that would wait for an idle connector, whereas we want any - // writeable (non-writing) connector even if it has in-flight commands. Maybe something - // with better back-off. - // On the other hand, this is exactly *one* thread doing spin-wait, maybe not that bad. 
- spinwait.SpinOnce(); - continue; - } - - // We may be in a race condition with the connector read loop, which may be currently returning - // the connector to the Idle channel (because it has completed all commands). - // Increment the in-flight count to make sure the connector isn't returned as idle. - var newInFlight = Interlocked.Increment(ref connector.CommandsInFlightCount); - if (newInFlight == 1) - { - // The connector's in-flight was 0, so it was idle - abort over-capacity read - // and retry the normal flow. - Interlocked.Decrement(ref connector.CommandsInFlightCount); - spinwait.SpinOnce(); - continue; - } - - break; - } - } - catch (Exception exception) - { - LogMessages.ExceptionWhenOpeningConnectionForMultiplexing(_connectionLogger, exception); - - // Fail the first command in the channel as a way of bubbling the exception up to the user - command.ExecutionCompletion.SetException(exception); - - continue; - } - - // We now have a ready connector, and can start writing commands to it. - Debug.Assert(connector != null); - - try - { - stats.Reset(); - connector.FlagAsNotWritableForMultiplexing(); - command.TraceCommandStart(connector); - - // Read queued commands and write them to the connector's buffer, for as long as we're - // under our write threshold and timer delay. - // Note we already have one command we read above, and have already updated the connector's - // CommandsInFlightCount. Now write that command. - var first = true; - bool writtenSynchronously; - do - { - if (first) - first = false; - else - Interlocked.Increment(ref connector.CommandsInFlightCount); - writtenSynchronously = WriteCommand(connector, command, ref stats); - } while (connector.WriteBuffer.WritePosition < _writeCoalescingBufferThresholdBytes && - writtenSynchronously && - _multiplexCommandReader.TryRead(out command)); - - // If all commands were written synchronously (good path), complete the write here, flushing - // and updating statistics. 
If not, CompleteRewrite is scheduled to run later, when the async - // operations complete, so skip it and continue. - if (writtenSynchronously) - Flush(connector, ref stats); - } - catch (Exception ex) - { - FailWrite(connector, ex); - } - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - bool WriteCommand(NpgsqlConnector connector, NpgsqlCommand command, ref MultiplexingStats stats) - { - // Note: this method *never* awaits on I/O - doing so would suspend all outgoing multiplexing commands - // for the entire pool. In the normal/fast case, writing the command is purely synchronous (serialize - // to buffer in memory), and the actual flush will occur at the level above. For cases where the - // command overflows the buffer, async I/O is done, and we schedule continuations separately - - // but the main thread continues to handle other commands on other connectors. - if (_autoPrepare) - { - // TODO: Need to log based on numPrepared like in non-multiplexing mode... - for (var i = 0; i < command.InternalBatchCommands.Count; i++) - command.InternalBatchCommands[i].TryAutoPrepare(connector); - } - - var written = connector.CommandsInFlightWriter!.TryWrite(command); - Debug.Assert(written, $"Failed to enqueue command to {connector.CommandsInFlightWriter}"); - - // Purposefully don't wait for I/O to complete - var task = command.Write(connector, async: true, flush: false); - stats.NumCommands++; - - switch (task.Status) - { - case TaskStatus.RanToCompletion: - return true; - - case TaskStatus.Faulted: - task.GetAwaiter().GetResult(); // Throw the exception - return true; - - case TaskStatus.WaitingForActivation: - case TaskStatus.Running: - { - // Asynchronous completion, which means the writing is flushing to network and there's actual I/O - // (i.e. a big command which overflowed our buffer). 
- // We don't (ever) await in the write loop, so remove the connector from the writable list (as it's - // still flushing) and schedule a continuation to continue taking care of this connector. - // The write loop continues to the next connector. - - // Create a copy of the statistics and purposefully box it via the closure. We need a separate - // copy of the stats for the async writing that will continue in parallel with this loop. - var clonedStats = stats.Clone(); - - // ReSharper disable once MethodSupportsCancellation - task.ContinueWith((t, o) => - { - var conn = (NpgsqlConnector)o!; - - if (t.IsFaulted) - { - FailWrite(conn, t.Exception!.InnerException!); - return; - } - - // There's almost certainly more buffered outgoing data for the command, after the flush - // occured. Complete the write, which will flush again (and update statistics). - try - { - Flush(conn, ref clonedStats); - } - catch (Exception e) - { - FailWrite(conn, e); - } - }, connector); - - return false; - } - - default: - Debug.Fail("When writing command to connector, task is in invalid state " + task.Status); - ThrowHelper.ThrowNpgsqlException("When writing command to connector, task is in invalid state " + task.Status); - return false; - } - } - - void Flush(NpgsqlConnector connector, ref MultiplexingStats stats) - { - var task = connector.Flush(async: true); - switch (task.Status) - { - case TaskStatus.RanToCompletion: - CompleteWrite(connector, ref stats); - return; - - case TaskStatus.Faulted: - task.GetAwaiter().GetResult(); // Throw the exception - return; - - case TaskStatus.WaitingForActivation: - case TaskStatus.Running: - { - // Asynchronous completion - the flush didn't complete immediately (e.g. TCP zero window). - - // Create a copy of the statistics and purposefully box it via the closure. We need a separate - // copy of the stats for the async writing that will continue in parallel with this loop. 
- var clonedStats = stats.Clone(); - - task.ContinueWith((t, o) => - { - var conn = (NpgsqlConnector)o!; - if (t.IsFaulted) - { - FailWrite(conn, t.Exception!.InnerException!); - return; - } - - CompleteWrite(conn, ref clonedStats); - }, connector); - - return; - } - - default: - Debug.Fail("When flushing, task is in invalid state " + task.Status); - ThrowHelper.ThrowNpgsqlException("When flushing, task is in invalid state " + task.Status); - return; - } - } - - void FailWrite(NpgsqlConnector connector, Exception exception) - { - // Note that all commands already passed validation. This means any error here is either an unrecoverable network issue - // (in which case we're already broken), or some other issue while writing (e.g. invalid UTF8 characters in the SQL query) - - // unrecoverable in any case. - - // All commands enqueued in CommandsInFlightWriter will be drained by the reader and failed. - // Note that some of these commands where only written to the connector's buffer, but never - // actually sent - because of a later exception. - // In theory, we could track commands that were only enqueued and not sent, and retry those - // (on another connector), but that would add some book-keeping and complexity, and in any case - // if one connector was broken, chances are that all are (networking). - Debug.Assert(connector.IsBroken); - - LogMessages.ExceptionWhenWritingMultiplexedCommands(_commandLogger, connector.Id, exception); - } - - static void CompleteWrite(NpgsqlConnector connector, ref MultiplexingStats stats) - { - // All I/O has completed, mark this connector as safe for writing again. - // This will allow the connector to be returned to the pool by its read loop, and also to be selected - // for over-capacity write. 
- connector.FlagAsWritableForMultiplexing(); - - NpgsqlEventSource.Log.MultiplexingBatchSent(stats.NumCommands, stats.Stopwatch); - } - - // ReSharper disable once FunctionNeverReturns - } - - protected override void DisposeBase() - { - MultiplexCommandWriter.Complete(new ObjectDisposedException(nameof(MultiplexingDataSource))); - _multiplexWriteLoop.GetAwaiter().GetResult(); - base.DisposeBase(); - } - - protected override async ValueTask DisposeAsyncBase() - { - MultiplexCommandWriter.Complete(new ObjectDisposedException(nameof(MultiplexingDataSource))); - await _multiplexWriteLoop.ConfigureAwait(false); - await base.DisposeAsyncBase().ConfigureAwait(false); - } - - struct MultiplexingStats - { - internal Stopwatch Stopwatch; - internal int NumCommands; - - internal void Reset() - { - NumCommands = 0; - Stopwatch.Reset(); - } - - internal MultiplexingStats Clone() - { - var clone = new MultiplexingStats { Stopwatch = Stopwatch, NumCommands = NumCommands }; - Stopwatch = new Stopwatch(); - return clone; - } - } -} diff --git a/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs b/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs index 760ddb1e5a..998b5f6420 100644 --- a/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs +++ b/src/Npgsql/NameTranslation/NpgsqlSnakeCaseNameTranslator.cs @@ -55,8 +55,7 @@ public NpgsqlSnakeCaseNameTranslator(bool legacyMode, CultureInfo? culture = nul /// public string TranslateMemberName(string clrName) { - if (clrName == null) - throw new ArgumentNullException(nameof(clrName)); + ArgumentNullException.ThrowIfNull(clrName); return LegacyMode ? string.Concat(LegacyModeMap(clrName)).ToLower(_culture) diff --git a/src/Npgsql/Npgsql.csproj b/src/Npgsql/Npgsql.csproj index ecae24940a..a4e47f12cf 100644 --- a/src/Npgsql/Npgsql.csproj +++ b/src/Npgsql/Npgsql.csproj @@ -5,30 +5,22 @@ Npgsql is the open source .NET data provider for PostgreSQL. 
npgsql;postgresql;postgres;ado;ado.net;database;sql README.md - net6.0;net8.0 - net8.0 + net10.0 $(NoWarn);CA2017 + $(NoWarn);NPG9001 + $(NoWarn);NPG9002 + $(NoWarn);NPG9003 + - - - - - - - - - - - @@ -41,6 +33,7 @@ + True True diff --git a/src/Npgsql/NpgsqlActivitySource.cs b/src/Npgsql/NpgsqlActivitySource.cs index 224bb2e658..be4e257c48 100644 --- a/src/Npgsql/NpgsqlActivitySource.cs +++ b/src/Npgsql/NpgsqlActivitySource.cs @@ -4,117 +4,185 @@ using System.Diagnostics; using System.Net; using System.Net.Sockets; +using System.Reflection; namespace Npgsql; +// Semantic conventions for database client spans: https://opentelemetry.io/docs/specs/semconv/database/database-spans/ +// Semantic conventions for PostgreSQL client operations: https://opentelemetry.io/docs/specs/semconv/database/postgresql/ static class NpgsqlActivitySource { - static readonly ActivitySource Source = new("Npgsql", "0.1.0"); + static readonly ActivitySource Source = new("Npgsql", GetLibraryVersion()); internal static bool IsEnabled => Source.HasListeners(); - internal static Activity? CommandStart(NpgsqlConnector connector, string commandText, CommandType commandType) + internal static Activity? CommandStart(string commandText, CommandType commandType, bool? prepared, string? spanName) { - var settings = connector.Settings; + string? operationName = null; - var dbName = settings.Database ?? connector.InferredUserName; - string? dbOperation = null; - string? dbSqlTable = null; - string activityName; switch (commandType) { case CommandType.StoredProcedure: - dbOperation = NpgsqlCommand.EnableStoredProcedureCompatMode ? "SELECT" : "CALL"; - // In this case our activity name follows the concept of the CommandType.TableDirect case - // (" .") but replaces db.sql.table with the procedure name - // which seems to match the spec's intent without being explicitly specified that way (it suggests - // using the procedure name but doesn't mention using db.operation or db.name in that case). 
- activityName = $"{dbOperation} {dbName}.{commandText}"; + // We follow the {db.operation.name} {target} pattern of the spec, with the operation being SELECT/CALL and + // the target being the stored procedure name. + operationName = NpgsqlCommand.EnableStoredProcedureCompatMode ? "SELECT" : "CALL"; + spanName ??= $"{operationName} {commandText}"; break; case CommandType.TableDirect: - dbOperation = "SELECT"; - // The OpenTelemetry spec actually asks to include the database name into db.sql.table - // but then again mixes the concept of database and schema. - // As I interpret it, it actually wants db.sql.table to include the schema name and not the - // database name if the concept of schemas exists in the database system. - // This also makes sense in the context of the activity name which otherwise would include the - // database name twice. - dbSqlTable = commandText; - activityName = $"{dbOperation} {dbName}.{dbSqlTable}"; + // We follow the {db.operation.name} {target} pattern of the spec, with the operation being SELECT and + // the target being the table (collection) name. + operationName = "SELECT"; + spanName ??= $"{operationName} {commandText}"; break; case CommandType.Text: - activityName = dbName; + // We don't have db.query.summary, db.operation.name or target (without parsing SQL), + // so we fall back to db.system.name as per the specs. 
+ spanName ??= "postgresql"; break; default: throw new ArgumentOutOfRangeException(nameof(commandType), commandType, null); } - var activity = Source.StartActivity(activityName, ActivityKind.Client); + var activity = Source.StartActivity(spanName, ActivityKind.Client); if (activity is not { IsAllDataRequested: true }) return activity; - activity.SetTag("db.system", "postgresql"); - activity.SetTag("db.connection_string", connector.UserFacingConnectionString); - activity.SetTag("db.user", connector.InferredUserName); - // We trace the actual (maybe inferred) database name we're connected to, even if it - // wasn't specified in the connection string - activity.SetTag("db.name", dbName); - activity.SetTag("db.statement", commandText); - activity.SetTag("db.connection_id", connector.Id); - if (dbOperation != null) - activity.SetTag("db.operation", dbOperation); - if (dbSqlTable != null) - activity.SetTag("db.sql.table", dbSqlTable); + activity.SetTag("db.query.text", commandText); + + if (prepared is true) + activity.SetTag("db.npgsql.prepared", true); + + switch (commandType) + { + case CommandType.StoredProcedure: + Debug.Assert(operationName is not null); + activity.SetTag("db.operation.name", operationName); + activity.SetTag("db.stored_procedure.name", commandText); + break; + case CommandType.TableDirect: + Debug.Assert(operationName is not null); + activity.SetTag("db.operation.name", operationName); + activity.SetTag("db.collection.name", commandText); + break; + } + + return activity; + } + + internal static Activity? PhysicalConnectionOpen(NpgsqlConnector connector) + { + if (!connector.DataSource.Configuration.TracingOptions.EnablePhysicalOpenTracing) + return null; + + // Note that physical connection open is not part of the OpenTelemetry spec. + // We emit it if enabled, following the general name/tags guidelines. + var dbName = connector.Settings.Database ?? 
connector.InferredUserName; + var activity = Source.StartActivity("CONNECT " + dbName, ActivityKind.Client); + if (activity is not { IsAllDataRequested: true }) + return activity; + + // We set these basic tags on the activity so that they're populated even when the physical open fails. + activity.SetTag("db.system.name", "postgresql"); + activity.SetTag("db.npgsql.data_source", connector.DataSource.Name); + + return activity; + } + + internal static void Enrich(Activity activity, NpgsqlConnector connector) + { + if (!activity.IsAllDataRequested) + return; + + activity.SetTag("db.system.name", "postgresql"); + + // TODO: For now, we only set the database name, without adding the first schema in the search_path + // as per the PG tracing specs (https://opentelemetry.io/docs/specs/semconv/database/postgresql/). + // See #6336 + activity.SetTag("db.namespace", connector.Settings.Database ?? connector.InferredUserName); var endPoint = connector.ConnectedEndPoint; Debug.Assert(endPoint is not null); + activity.SetTag("server.address", connector.Host); switch (endPoint) { case IPEndPoint ipEndPoint: - activity.SetTag("net.transport", "ip_tcp"); - activity.SetTag("net.peer.ip", ipEndPoint.Address.ToString()); if (ipEndPoint.Port != 5432) - activity.SetTag("net.peer.port", ipEndPoint.Port); - activity.SetTag("net.peer.name", settings.Host); + activity.SetTag("server.port", ipEndPoint.Port); break; case UnixDomainSocketEndPoint: - activity.SetTag("net.transport", "unix"); - activity.SetTag("net.peer.name", settings.Host); break; default: - throw new ArgumentOutOfRangeException("Invalid endpoint type: " + endPoint.GetType()); + throw new UnreachableException("Invalid endpoint type: " + endPoint.GetType()); } - return activity; + // Npgsql-specific tags + activity.SetTag("db.npgsql.data_source", connector.DataSource.Name); + activity.SetTag("db.npgsql.connection_id", connector.Id); } - internal static void ReceivedFirstResponse(Activity activity) + internal static void 
ReceivedFirstResponse(Activity activity, NpgsqlTracingOptions tracingOptions) { + if (!activity.IsAllDataRequested || !tracingOptions.EnableFirstResponseEvent) + return; + var activityEvent = new ActivityEvent("received-first-response"); activity.AddEvent(activityEvent); } - internal static void CommandStop(Activity activity) + internal static void SetException(Activity activity, Exception exception, bool escaped = true) { - activity.SetTag("otel.status_code", "OK"); + activity.AddException(exception); + + if (exception is PostgresException { SqlState: var sqlState }) + { + activity.SetTag("db.response.status_code", sqlState); + + // error.type SHOULD match the db.response.status_code returned by the database or the client library, or the canonical name of exception that occurred. + // Since we don't have a table to map the error code to a textual description, the SQL state is the best we can do. + activity.SetTag("error.type", sqlState); + } + else + { + if (exception is NpgsqlException { InnerException: Exception innerException }) + exception = innerException; + + activity.SetTag("error.type", exception.GetType().FullName); + } + + var statusDescription = exception is PostgresException pgEx ? pgEx.SqlState : exception.Message; + activity.SetStatus(ActivityStatusCode.Error, statusDescription); activity.Dispose(); } - internal static void SetException(Activity activity, Exception ex, bool escaped = true) + internal static Activity? CopyStart(string command, NpgsqlConnector connector, string? spanName, string operation) { - var tags = new ActivityTagsCollection - { - { "exception.type", ex.GetType().FullName }, - { "exception.message", ex.Message }, - { "exception.stacktrace", ex.ToString() }, - { "exception.escaped", escaped } - }; - var activityEvent = new ActivityEvent("exception", tags: tags); - activity.AddEvent(activityEvent); - activity.SetTag("otel.status_code", "ERROR"); - activity.SetTag("otel.status_description", ex is PostgresException pgEx ? 
pgEx.SqlState : ex.Message); + var activity = Source.StartActivity(spanName ?? operation, ActivityKind.Client); + if (activity is not { IsAllDataRequested: true }) + return activity; + activity.SetTag("db.query.text", command); + activity.SetTag("db.operation.name", operation); + Enrich(activity, connector); + return activity; + } + + internal static void SetOperation(Activity activity, string operation) + { + if (!activity.IsAllDataRequested) + return; + activity.SetTag("db.operation.name", operation); + } + + internal static void CopyStop(Activity activity, ulong? rows = null) + { + if (rows.HasValue) + activity.SetTag("db.npgsql.rows", rows.Value); activity.Dispose(); } + + static string GetLibraryVersion() + => typeof(NpgsqlDataSource).Assembly + .GetCustomAttribute()? + .InformationalVersion ?? "UNKNOWN"; } diff --git a/src/Npgsql/NpgsqlBatch.cs b/src/Npgsql/NpgsqlBatch.cs index 446cb4746f..e692199e2b 100644 --- a/src/Npgsql/NpgsqlBatch.cs +++ b/src/Npgsql/NpgsqlBatch.cs @@ -100,7 +100,7 @@ internal bool AllResultTypesAreUnknown public NpgsqlBatch(NpgsqlConnection? connection = null, NpgsqlTransaction? transaction = null) { GC.SuppressFinalize(this); - Command = new(DefaultBatchCommandsSize); + Command = new(this, DefaultBatchCommandsSize); BatchCommands = new NpgsqlBatchCommandCollection(Command.InternalBatchCommands); Connection = connection; @@ -110,14 +110,14 @@ public NpgsqlBatch(NpgsqlConnection? connection = null, NpgsqlTransaction? 
trans internal NpgsqlBatch(NpgsqlConnector connector) { GC.SuppressFinalize(this); - Command = new(connector, DefaultBatchCommandsSize); + Command = new(this, connector, DefaultBatchCommandsSize); BatchCommands = new NpgsqlBatchCommandCollection(Command.InternalBatchCommands); } - private protected NpgsqlBatch(NpgsqlDataSourceCommand command) + private protected NpgsqlBatch(Func commandFactory, NpgsqlConnection connection) { GC.SuppressFinalize(this); - Command = command; + Command = commandFactory(connection, this); BatchCommands = new NpgsqlBatchCommandCollection(Command.InternalBatchCommands); } diff --git a/src/Npgsql/NpgsqlBatchCommand.cs b/src/Npgsql/NpgsqlBatchCommand.cs index 8175afa614..2812fdbead 100644 --- a/src/Npgsql/NpgsqlBatchCommand.cs +++ b/src/Npgsql/NpgsqlBatchCommand.cs @@ -1,10 +1,12 @@ using System; +using System.Buffers; using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; using Npgsql.Internal; @@ -13,7 +15,7 @@ namespace Npgsql; /// public sealed class NpgsqlBatchCommand : DbBatchCommand { - internal static readonly List EmptyParameters = new(); + internal static readonly List EmptyParameters = []; string _commandText; @@ -39,31 +41,15 @@ public override string CommandText internal NpgsqlParameterCollection? _parameters; /// - public new NpgsqlParameterCollection Parameters => _parameters ??= new(); + public new NpgsqlParameterCollection Parameters => _parameters ??= []; + internal bool HasOutputParameters => _parameters?.HasOutputParameters == true; -#if NET8_0_OR_GREATER /// - public override NpgsqlParameter CreateParameter() -#else - /// - /// Creates a new instance of a object. - /// - /// An object. 
- public NpgsqlParameter CreateParameter() -#endif - => new(); + public override NpgsqlParameter CreateParameter() => new(); -#if NET8_0_OR_GREATER /// - public override bool CanCreateParameter -#else - /// - /// Returns whether the method is implemented. - /// - public bool CanCreateParameter -#endif - => true; + public override bool CanCreateParameter => true; /// @@ -149,7 +135,7 @@ public override int RecordsAffected /// internal List PositionalParameters { - get => _inputParameters ??= _ownedInputParameters ??= new(); + get => _inputParameters ??= _ownedInputParameters ??= []; set => _inputParameters = value; } @@ -183,10 +169,10 @@ internal RowDescriptionMessage? Description /// internal PreparedStatement? PreparedStatement { - get => _preparedStatement != null && _preparedStatement.State == PreparedState.Unprepared + get => _preparedStatement is { State: PreparedState.Unprepared } ? _preparedStatement = null : _preparedStatement; - set => _preparedStatement = value; + private set => _preparedStatement = value; } PreparedStatement? _preparedStatement; @@ -198,7 +184,7 @@ internal PreparedStatement? PreparedStatement /// /// Holds the server-side (prepared) ASCII statement name. Empty string for non-prepared statements. /// - internal byte[] StatementName => PreparedStatement?.Name ?? Array.Empty(); + internal byte[] StatementName => PreparedStatement?.Name ?? []; /// /// Whether this statement has already been prepared (including automatic preparation). 
@@ -290,8 +276,114 @@ internal void ApplyCommandComplete(CommandCompleteMessage msg) internal void ResetPreparation() { - PreparedStatement = null; ConnectorPreparedOn = null; + PreparedStatement = null; + } + + internal void PopulateOutputParameters(NpgsqlDataReader reader, ILogger logger) + { + Debug.Assert(_parameters is not null); + var parameters = _parameters; + var fieldCount = reader.FieldCount; + switch (parameters.PlaceholderType) + { + case PlaceholderType.Mixed: + case PlaceholderType.Named: + { + // In the case of named and mixed parameters we first try to populate all parameters with a named column match. + // For backwards compat we allow populating named parameters as long as they haven't been filled yet. + // So for every column that we couldn't match by name we fill the first output direction parameter that wasn't filled previously. + // This means a row like {"a" => 1, "some_field" => 2} will populate the following output db params {"a" => 1, "b" => 2}. + // And a row like {"some_field" => 1, "a" => 2} will populate them as follows {"a" => 2, "b" => 1}. + + var parameterIndices = new ArraySegment(ArrayPool.Shared.Rent(fieldCount), 0, fieldCount); + var secondPassOrdinal = -1; + for (var ordinal = 0; ordinal < fieldCount; ordinal++) + { + var name = reader.GetName(ordinal); + var i = parameters.IndexOf(name); + if (i is not -1 && parameters[i] is { IsOutputDirection: true } parameter) + { + SetValue(reader, logger, parameter, ordinal, i); + parameterIndices[ordinal] = i; + } + else + { + parameterIndices[ordinal] = -1; + if (secondPassOrdinal is -1) + secondPassOrdinal = ordinal; + } + } + + if (secondPassOrdinal is -1) + { + ArrayPool.Shared.Return(parameterIndices.Array!); + break; + } + + // This set will also contain -1, but that's not a valid index so we can ignore that it is included. 
+ var matchedParameters = new HashSet(parameterIndices); + var parameterList = parameters.InternalList; + for (var i = 0; i < parameterList.Count; i++) + { + // Find an output parameter that wasn't matched by name. + if (parameterList[i] is not { IsOutputDirection: true } parameter || matchedParameters.Contains(i)) + continue; + + SetValue(reader, logger, parameter, secondPassOrdinal, i); + + // And find the next unhandled ordinal. + secondPassOrdinal = NextSecondPassOrdinal(parameterIndices, secondPassOrdinal); + if (secondPassOrdinal is -1) + break; + } + + ArrayPool.Shared.Return(parameterIndices.Array!); + break; + + static int NextSecondPassOrdinal(ArraySegment indices, int offset) + { + for (var i = offset + 1; i < indices.Count; i++) + { + if (indices[i] is -1) + return i; + } + + return -1; + } + } + case PlaceholderType.Positional: + { + var parameterList = parameters.InternalList; + var ordinal = 0; + for (var i = 0; i < parameterList.Count; i++) + { + if (parameterList[i] is not { IsOutputDirection: true } parameter) + continue; + + SetValue(reader, logger, parameter, ordinal, i); + + ordinal++; + if (ordinal == fieldCount) + break; + } + break; + } + } + + static void SetValue(NpgsqlDataReader reader, ILogger logger, NpgsqlParameter p, int ordinal, int index) + { + try + { + p.SetOutputValue(reader, ordinal); + } + catch (Exception ex) + { + logger.LogDebug(ex, "Failed to set value on output parameter instance '{ParameterNameOrIndex}' for output parameter {OutputName}", + p.ParameterName is NpgsqlParameter.PositionalName ? 
index : p.ParameterName, reader.GetName(ordinal)); + throw; + } + } } /// diff --git a/src/Npgsql/NpgsqlBinaryExporter.cs b/src/Npgsql/NpgsqlBinaryExporter.cs index 35dea6985d..033517c79d 100644 --- a/src/Npgsql/NpgsqlBinaryExporter.cs +++ b/src/Npgsql/NpgsqlBinaryExporter.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Diagnostics; using System.Threading; using System.Threading.Tasks; @@ -7,6 +7,7 @@ using Npgsql.Internal; using Npgsql.Internal.Postgres; using NpgsqlTypes; +using InfiniteTimeout = System.Threading.Timeout; using static Npgsql.Util.Statics; namespace Npgsql; @@ -24,7 +25,7 @@ public sealed class NpgsqlBinaryExporter : ICancelable NpgsqlConnector _connector; NpgsqlReadBuffer _buf; - bool _isConsumed, _isDisposed; + ExporterState _state = ExporterState.Uninitialized; long _endOfMessagePos; short _column; @@ -46,9 +47,11 @@ public sealed class NpgsqlBinaryExporter : ICancelable /// public TimeSpan Timeout { - set => _buf.Timeout = value; + set => _buf.Timeout = value > TimeSpan.Zero ? value : InfiniteTimeout.InfiniteTimeSpan; } + Activity? 
_activity; + #endregion #region Construction / Initialization @@ -64,38 +67,50 @@ internal NpgsqlBinaryExporter(NpgsqlConnector connector) internal async Task Init(string copyToCommand, bool async, CancellationToken cancellationToken = default) { - await _connector.WriteQuery(copyToCommand, async, cancellationToken).ConfigureAwait(false); - await _connector.Flush(async, cancellationToken).ConfigureAwait(false); - - using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + Debug.Assert(_activity is null); + _activity = _connector.TraceCopyStart(copyToCommand, "COPY TO"); - CopyOutResponseMessage copyOutResponse; - var msg = await _connector.ReadMessage(async).ConfigureAwait(false); - switch (msg.Code) + try { - case BackendMessageCode.CopyOutResponse: - copyOutResponse = (CopyOutResponseMessage)msg; - if (!copyOutResponse.IsBinary) + await _connector.WriteQuery(copyToCommand, async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); + + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + + CopyOutResponseMessage copyOutResponse; + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + switch (msg.Code) { - throw _connector.Break( - new ArgumentException("copyToCommand triggered a text transfer, only binary is allowed", - nameof(copyToCommand))); + case BackendMessageCode.CopyOutResponse: + copyOutResponse = (CopyOutResponseMessage)msg; + if (!copyOutResponse.IsBinary) + { + throw _connector.Break( + new ArgumentException("copyToCommand triggered a text transfer, only binary is allowed", + nameof(copyToCommand))); + } + break; + case BackendMessageCode.CommandComplete: + throw new InvalidOperationException( + "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. 
" + + "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + + "Note that your data has been successfully imported/exported."); + default: + throw _connector.UnexpectedMessageReceived(msg.Code); } - break; - case BackendMessageCode.CommandComplete: - throw new InvalidOperationException( - "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + - "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + - "Note that your data has been successfully imported/exported."); - default: - throw _connector.UnexpectedMessageReceived(msg.Code); - } - NumColumns = copyOutResponse.NumColumns; - _columnInfoCache = new PgConverterInfo[NumColumns]; - _rowsExported = 0; - _endOfMessagePos = _buf.CumulativeReadPosition; - await ReadHeader(async).ConfigureAwait(false); + _state = ExporterState.Ready; + NumColumns = copyOutResponse.NumColumns; + _columnInfoCache = new PgConverterInfo[NumColumns]; + _rowsExported = 0; + _endOfMessagePos = _buf.CumulativeReadPosition; + await ReadHeader(async).ConfigureAwait(false); + } + catch (Exception e) + { + TraceSetException(e); + throw; + } } async Task ReadHeader(bool async) @@ -141,7 +156,7 @@ async Task ReadHeader(bool async) async ValueTask StartRow(bool async, CancellationToken cancellationToken = default) { ThrowIfDisposed(); - if (_isConsumed) + if (_state == ExporterState.Consumed) return -1; using var registration = _connector.StartNestedCancellableOperation(cancellationToken); @@ -149,7 +164,10 @@ async ValueTask StartRow(bool async, CancellationToken cancellationToken = // Consume and advance any active column. 
if (_column >= 0) { - await Commit(async).ConfigureAwait(false); + if (async) + await PgReader.CommitAsync().ConfigureAwait(false); + else + PgReader.Commit(); _column++; } @@ -173,7 +191,7 @@ async ValueTask StartRow(bool async, CancellationToken cancellationToken = Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); _column = BeforeRow; - _isConsumed = true; + _state = ExporterState.Consumed; return -1; } @@ -194,7 +212,8 @@ async ValueTask StartRow(bool async, CancellationToken cancellationToken = /// specify the type. /// /// The value of the column - public T Read() => Read(async: false).GetAwaiter().GetResult(); + public T Read() + => Read(null); /// /// Reads the current column, returns its value and moves ahead to the next column. @@ -207,10 +226,7 @@ async ValueTask StartRow(bool async, CancellationToken cancellationToken = /// /// The value of the column public ValueTask ReadAsync(CancellationToken cancellationToken = default) - => Read(async: true, cancellationToken); - - ValueTask Read(bool async, CancellationToken cancellationToken = default) - => Read(async, null, cancellationToken); + => ReadAsync(null, cancellationToken); /// /// Reads the current column, returns its value according to and @@ -225,7 +241,8 @@ ValueTask Read(bool async, CancellationToken cancellationToken = default) /// /// The .NET type of the column to be read. /// The value of the column - public T Read(NpgsqlDbType type) => Read(async: false, type, CancellationToken.None).GetAwaiter().GetResult(); + public T Read(NpgsqlDbType type) + => Read((NpgsqlDbType?)type); /// /// Reads the current column, returns its value according to and @@ -244,42 +261,28 @@ ValueTask Read(bool async, CancellationToken cancellationToken = default) /// The .NET type of the column to be read. 
/// The value of the column public ValueTask ReadAsync(NpgsqlDbType type, CancellationToken cancellationToken = default) - => Read(async: true, type, cancellationToken); + => ReadAsync((NpgsqlDbType?)type, cancellationToken); - async ValueTask Read(bool async, NpgsqlDbType? type, CancellationToken cancellationToken) + T Read(NpgsqlDbType? type) { ThrowIfNotOnRow(); - using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - if (!IsInitializedAndAtStart) - await MoveNextColumn(async, resumableOp: false).ConfigureAwait(false); + MoveNextColumn(resumableOp: false); + var reader = PgReader; try { - var reader = PgReader; - if (reader.FieldSize is -1) - return DbNullOrThrow(); + if (reader.FieldIsDbNull) + return DbNullOrThrow(); - var info = GetInfo(type, out var asObject); + var info = GetInfo(typeof(T), type, out var asObject); - T result; - if (async) - { - await reader.StartReadAsync(info.BufferRequirement, cancellationToken).ConfigureAwait(false); - result = asObject - ? (T)await info.Converter.ReadAsObjectAsync(reader, cancellationToken).ConfigureAwait(false) - : await info.GetConverter().ReadAsync(reader, cancellationToken).ConfigureAwait(false); - await reader.EndReadAsync().ConfigureAwait(false); - } - else - { - reader.StartRead(info.BufferRequirement); - result = asObject - ? (T)info.Converter.ReadAsObject(reader) - : info.GetConverter().Read(reader); - reader.EndRead(); - } + reader.StartRead(info.BufferRequirement); + var result = asObject + ? (T)info.Converter.ReadAsObject(reader) + : info.Converter.UnsafeDowncast().Read(reader); + reader.EndRead(); return result; } @@ -288,48 +291,82 @@ async ValueTask Read(bool async, NpgsqlDbType? type, CancellationToken can // Don't delay committing the current column, just do it immediately (as opposed to on the next action: Read, IsNull, Skip). 
// Zero length columns would otherwise create an edge-case where we'd have to immediately commit as we won't know whether we're at the end. // To guarantee the commit happens in that case we would still need this try finally, at which point it's just better to be consistent. - await Commit(async).ConfigureAwait(false); + reader.Commit(); } + } - static T DbNullOrThrow() + async ValueTask ReadAsync(NpgsqlDbType? type, CancellationToken cancellationToken) + { + ThrowIfNotOnRow(); + + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + + if (!IsInitializedAndAtStart) + await MoveNextColumnAsync(resumableOp: false).ConfigureAwait(false); + + var reader = PgReader; + try { - // When T is a Nullable, we support returning null - if (default(T) is null && typeof(T).IsValueType) - return default!; - throw new InvalidCastException("Column is null"); - } + if (reader.FieldIsDbNull) + return DbNullOrThrow(); + + var info = GetInfo(typeof(T), type, out var asObject); + + await reader.StartReadAsync(info.BufferRequirement, cancellationToken).ConfigureAwait(false); + var result = asObject + ? (T)await info.Converter.ReadAsObjectAsync(reader, cancellationToken).ConfigureAwait(false) + : await info.Converter.UnsafeDowncast().ReadAsync(reader, cancellationToken).ConfigureAwait(false); + await reader.EndReadAsync().ConfigureAwait(false); - PgConverterInfo GetInfo(NpgsqlDbType? type, out bool asObject) + return result; + } + finally { - ref var cachedInfo = ref _columnInfoCache[_column]; - var converterInfo = cachedInfo.IsDefault ? cachedInfo = CreateConverterInfo(typeof(T), type) : cachedInfo; - asObject = converterInfo.IsBoxingConverter; - return converterInfo; + // Don't delay committing the current column, just do it immediately (as opposed to on the next action: Read, IsNull, Skip). 
+ // Zero length columns would otherwise create an edge-case where we'd have to immediately commit as we won't know whether we're at the end. + // To guarantee the commit happens in that case we would still need this try finally, at which point it's just better to be consistent. + await reader.CommitAsync().ConfigureAwait(false); } + } + + static T DbNullOrThrow() + { + // When T is a Nullable, we support returning null + if (default(T) is null && typeof(T).IsValueType) + return default!; + throw new InvalidCastException("Column is null"); + } + + PgConverterInfo GetInfo(Type type, NpgsqlDbType? npgsqlDbType, out bool asObject) + { + ref var cachedInfo = ref _columnInfoCache[_column]; + var converterInfo = cachedInfo.IsDefault ? cachedInfo = CreateConverterInfo(type, npgsqlDbType) : cachedInfo; + asObject = converterInfo.IsBoxingConverter; + return converterInfo; + } - PgConverterInfo CreateConverterInfo(Type type, NpgsqlDbType? npgsqlDbType = null) + PgConverterInfo CreateConverterInfo(Type type, NpgsqlDbType? npgsqlDbType = null) + { + var options = _connector.SerializerOptions; + PgTypeId? pgTypeId = null; + if (npgsqlDbType.HasValue) { - var options = _connector.SerializerOptions; - PgTypeId? pgTypeId = null; - if (npgsqlDbType.HasValue) - { - pgTypeId = npgsqlDbType.Value.ToDataTypeName() is { } name - ? options.GetCanonicalTypeId(name) - // Handle plugin types via lookup. - : GetRepresentationalOrDefault(npgsqlDbType.Value.ToUnqualifiedDataTypeNameOrThrow()); - } - var info = options.GetTypeInfo(type, pgTypeId) - ?? throw new NotSupportedException($"Reading is not supported for type '{type}'{(npgsqlDbType is null ? "" : $" and NpgsqlDbType '{npgsqlDbType}'")}"); + pgTypeId = npgsqlDbType.Value.ToDataTypeName() is { } name + ? options.GetCanonicalTypeId(name) + // Handle plugin types via lookup. + : GetRepresentationalOrDefault(npgsqlDbType.Value.ToUnqualifiedDataTypeNameOrThrow()); + } + var info = options.GetTypeInfoInternal(type, pgTypeId) + ?? 
throw new NotSupportedException($"Reading is not supported for type '{type}'{(npgsqlDbType is null ? "" : $" and NpgsqlDbType '{npgsqlDbType}'")}"); - // Binary export has no type info so we only do caller-directed interpretation of data. - return info.Bind(new Field("?", - info.PgTypeId ?? ((PgResolverTypeInfo)info).GetDefaultResolution(null).PgTypeId, -1), DataFormat.Binary); + // Binary export has no type info so we only do caller-directed interpretation of data. + return info.Bind(new Field("?", + info.PgTypeId ?? ((PgResolverTypeInfo)info).GetDefaultResolution(null).PgTypeId, -1), DataFormat.Binary); - PgTypeId GetRepresentationalOrDefault(string dataTypeName) - { - var type = options.DatabaseInfo.GetPostgresType(dataTypeName); - return options.ToCanonicalTypeId(type.GetRepresentationalType()); - } + PgTypeId GetRepresentationalOrDefault(string dataTypeName) + { + var type = options.DatabaseInfo.GetPostgresType(dataTypeName); + return options.ToCanonicalTypeId(type.GetRepresentationalType()); } } @@ -342,64 +379,68 @@ public bool IsNull { ThrowIfNotOnRow(); if (!IsInitializedAndAtStart) - return MoveNextColumn(async: false, resumableOp: true).GetAwaiter().GetResult() is -1; + MoveNextColumn(resumableOp: true); - return PgReader.FieldSize is - 1; + return PgReader.FieldIsDbNull; } } /// /// Skips the current column without interpreting its value. /// - public void Skip() => Skip(async: false).GetAwaiter().GetResult(); + public void Skip() + { + ThrowIfNotOnRow(); + + if (!IsInitializedAndAtStart) + MoveNextColumn(resumableOp: false); + + PgReader.Commit(); + } /// /// Skips the current column without interpreting its value. 
/// - public Task SkipAsync(CancellationToken cancellationToken = default) - => Skip(true, cancellationToken); - - async Task Skip(bool async, CancellationToken cancellationToken = default) + public async Task SkipAsync(CancellationToken cancellationToken = default) { ThrowIfNotOnRow(); using var registration = _connector.StartNestedCancellableOperation(cancellationToken); if (!IsInitializedAndAtStart) - await MoveNextColumn(async, resumableOp: false).ConfigureAwait(false); + await MoveNextColumnAsync(resumableOp: false).ConfigureAwait(false); - await Commit(async).ConfigureAwait(false); + await PgReader.CommitAsync().ConfigureAwait(false); } #endregion #region Utilities - bool IsInitializedAndAtStart => PgReader.Initialized && (PgReader.FieldSize is -1 || PgReader.FieldOffset is 0); + bool IsInitializedAndAtStart => PgReader.Initialized && (PgReader.FieldIsDbNull || PgReader.FieldAtStart); - ValueTask Commit(bool async) + void MoveNextColumn(bool resumableOp) { - if (async) - return PgReader.CommitAsync(resuming: false); + PgReader.Commit(); - PgReader.Commit(resuming: false); - return new(); + if (_column + 1 == NumColumns) + ThrowHelper.ThrowInvalidOperationException("No more columns left in the current row"); + _column++; + _buf.Ensure(sizeof(int)); + var columnLen = _buf.ReadInt32(); + PgReader.Init(columnLen, DataFormat.Binary, resumableOp); } - async ValueTask MoveNextColumn(bool async, bool resumableOp) + async ValueTask MoveNextColumnAsync(bool resumableOp) { - if (async) - await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); - else - PgReader.Commit(resuming: false); + await PgReader.CommitAsync().ConfigureAwait(false); if (_column + 1 == NumColumns) ThrowHelper.ThrowInvalidOperationException("No more columns left in the current row"); _column++; - await _buf.Ensure(4, async).ConfigureAwait(false); + await _buf.Ensure(sizeof(int), async: true).ConfigureAwait(false); var columnLen = _buf.ReadInt32(); PgReader.Init(columnLen, 
DataFormat.Binary, resumableOp); - return PgReader.FieldSize; } void ThrowIfNotOnRow() @@ -411,7 +452,7 @@ void ThrowIfNotOnRow() void ThrowIfDisposed() { - if (_isDisposed) + if (_state == ExporterState.Disposed) ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlBinaryExporter), "The COPY operation has already ended."); } @@ -422,7 +463,7 @@ void ThrowIfDisposed() /// /// Cancels an ongoing export. /// - public void Cancel() => _connector.PerformUserCancellation(); + public void Cancel() => _connector.PerformImmediateUserCancellation(); /// /// Async cancels an ongoing export. @@ -446,62 +487,103 @@ public Task CancelAsync() async ValueTask DisposeAsync(bool async) { - if (_isDisposed) + if (_state == ExporterState.Disposed) return; - if (_isConsumed) - { - LogMessages.BinaryCopyOperationCompleted(_copyLogger, _rowsExported, _connector.Id); - } - else if (!_connector.IsBroken) + try { - try + if (_state is ExporterState.Consumed or ExporterState.Uninitialized) { - using var registration = _connector.StartNestedCancellableOperation(attemptPgCancellation: false); - // Be sure to commit the reader. 
- if (async) - await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); - else - PgReader.Commit(resuming: false); - // Finish the current CopyData message - await _buf.Skip(checked((int)(_endOfMessagePos - _buf.CumulativeReadPosition)), async).ConfigureAwait(false); - // Read to the end - _connector.SkipUntil(BackendMessageCode.CopyDone); - // We intentionally do not pass a CancellationToken since we don't want to cancel cleanup - Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); - Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + LogMessages.BinaryCopyOperationCompleted(_copyLogger, _rowsExported, _connector.Id); + TraceExportStop(); } - catch (OperationCanceledException e) when (e.InnerException is PostgresException pg && pg.SqlState == PostgresErrorCodes.QueryCanceled) + else if (!_connector.IsBroken) { - LogMessages.CopyOperationCancelled(_copyLogger, _connector.Id); + try + { + using var registration = _connector.StartNestedCancellableOperation(attemptPgCancellation: false); + // Be sure to commit the reader. 
+ if (async) + await PgReader.CommitAsync().ConfigureAwait(false); + else + PgReader.Commit(); + // Finish the current CopyData message + await _buf.Skip(async, checked((int)(_endOfMessagePos - _buf.CumulativeReadPosition))).ConfigureAwait(false); + // Read to the end + _connector.SkipUntil(BackendMessageCode.CopyDone); + // We intentionally do not pass a CancellationToken since we don't want to cancel cleanup + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + + TraceExportStop(); + } + catch (OperationCanceledException e) when (e.InnerException is PostgresException { SqlState: PostgresErrorCodes.QueryCanceled }) + { + LogMessages.CopyOperationCancelled(_copyLogger, _connector.Id); + TraceExportStop(); + } + catch (Exception e) + { + LogMessages.ExceptionWhenDisposingCopyOperation(_copyLogger, _connector.Id, e); + TraceSetException(e); + } } - catch (Exception e) + } + finally + { + _connector.EndUserAction(); + Cleanup(); + } + + void Cleanup() + { + Debug.Assert(_state != ExporterState.Disposed); + var connector = _connector; + + if (!ReferenceEquals(connector, null)) { - LogMessages.ExceptionWhenDisposingCopyOperation(_copyLogger, _connector.Id, e); + connector.CurrentCopyOperation = null; + _connector = null!; } - } - _connector.EndUserAction(); - Cleanup(); + _buf = null!; + _state = ExporterState.Disposed; + } } -#pragma warning disable CS8625 - void Cleanup() + #endregion + + #region Tracing + + void TraceExportStop() { - Debug.Assert(!_isDisposed); - var connector = _connector; + if (_activity is not null) + { + NpgsqlActivitySource.CopyStop(_activity, _rowsExported); + _activity = null; + } + } - if (connector != null) + void TraceSetException(Exception exception) + { + if (_activity is not null) { - connector.CurrentCopyOperation = null; - _connector.Connection?.EndBindingScope(ConnectorBindingScope.Copy); - _connector = null; + 
NpgsqlActivitySource.SetException(_activity, exception); + _activity = null; } + } + + #endregion Tracing + + #region Enums - _buf = null; - _isDisposed = true; + enum ExporterState + { + Uninitialized, + Ready, + Consumed, + Disposed } -#pragma warning restore CS8625 - #endregion + #endregion Enums } diff --git a/src/Npgsql/NpgsqlBinaryImporter.cs b/src/Npgsql/NpgsqlBinaryImporter.cs index 7a9caa5595..08f2a90844 100644 --- a/src/Npgsql/NpgsqlBinaryImporter.cs +++ b/src/Npgsql/NpgsqlBinaryImporter.cs @@ -1,4 +1,5 @@ using System; +using System.Diagnostics; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -7,6 +8,7 @@ using Npgsql.Internal; using Npgsql.Internal.Postgres; using NpgsqlTypes; +using InfiniteTimeout = System.Threading.Timeout; using static Npgsql.Util.Statics; namespace Npgsql; @@ -25,7 +27,7 @@ public sealed class NpgsqlBinaryImporter : ICancelable NpgsqlConnector _connector; NpgsqlWriteBuffer _buf; - ImporterState _state; + ImporterState _state = ImporterState.Uninitialized; /// /// The number of columns in the current (not-yet-written) row. @@ -45,6 +47,8 @@ public sealed class NpgsqlBinaryImporter : ICancelable readonly ILogger _copyLogger; PgWriter _pgWriter = null!; // Setup in Init + Activity? _activity; + /// /// Current timeout /// @@ -52,8 +56,9 @@ public TimeSpan Timeout { set { - _buf.Timeout = value; - _connector.ReadBuffer.Timeout = value; + var timeout = value > TimeSpan.Zero ? 
value : InfiniteTimeout.InfiniteTimeSpan; + _buf.Timeout = timeout; + _connector.ReadBuffer.Timeout = timeout; } } @@ -72,39 +77,51 @@ internal NpgsqlBinaryImporter(NpgsqlConnector connector) internal async Task Init(string copyFromCommand, bool async, CancellationToken cancellationToken = default) { - await _connector.WriteQuery(copyFromCommand, async, cancellationToken).ConfigureAwait(false); - await _connector.Flush(async, cancellationToken).ConfigureAwait(false); + Debug.Assert(_activity is null); + _activity = _connector.TraceCopyStart(copyFromCommand, "COPY FROM"); - using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - - CopyInResponseMessage copyInResponse; - var msg = await _connector.ReadMessage(async).ConfigureAwait(false); - switch (msg.Code) + try { - case BackendMessageCode.CopyInResponse: - copyInResponse = (CopyInResponseMessage)msg; - if (!copyInResponse.IsBinary) + await _connector.WriteQuery(copyFromCommand, async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); + + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + + CopyInResponseMessage copyInResponse; + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + switch (msg.Code) { - throw _connector.Break( - new ArgumentException("copyFromCommand triggered a text transfer, only binary is allowed", - nameof(copyFromCommand))); + case BackendMessageCode.CopyInResponse: + copyInResponse = (CopyInResponseMessage)msg; + if (!copyInResponse.IsBinary) + { + throw _connector.Break( + new ArgumentException("copyFromCommand triggered a text transfer, only binary is allowed", + nameof(copyFromCommand))); + } + break; + case BackendMessageCode.CommandComplete: + throw new InvalidOperationException( + "This API only supports import/export from the client, i.e. 
COPY commands containing TO/FROM STDIN. " + + "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + + "Note that your data has been successfully imported/exported."); + default: + throw _connector.UnexpectedMessageReceived(msg.Code); } - break; - case BackendMessageCode.CommandComplete: - throw new InvalidOperationException( - "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + - "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + - "Note that your data has been successfully imported/exported."); - default: - throw _connector.UnexpectedMessageReceived(msg.Code); - } - _params = new NpgsqlParameter[copyInResponse.NumColumns]; - _rowsImported = 0; - _buf.StartCopyMode(); - WriteHeader(); - // Only init after header. - _pgWriter = _buf.GetWriter(_connector.DatabaseInfo); + _state = ImporterState.Ready; + _params = new NpgsqlParameter[copyInResponse.NumColumns]; + _rowsImported = 0; + _buf.StartCopyMode(); + WriteHeader(); + // Only init after header. + _pgWriter = _buf.GetWriter(_connector.DatabaseInfo); + } + catch (Exception e) + { + TraceSetException(e); + throw; + } } void WriteHeader() @@ -231,9 +248,11 @@ public Task WriteAsync(T value, string dataTypeName, CancellationToken cancel Task Write(bool async, T value, NpgsqlDbType? npgsqlDbType, string? dataTypeName, CancellationToken cancellationToken = default) { - // Statically handle DBNull for backwards compatibility, generic parameters where T = DBNull normally won't find a mapping. - // Also handle null values for object typed parameters, as parameters only accept DBNull.Value when T = object. - if (typeof(T) == typeof(DBNull) || (typeof(T) == typeof(object) && value is null)) + // Handle DBNull: + // 1. when T = DBNull for backwards compatibility, DBNull as a type normally won't find a mapping. + // 2. 
when T = object we resolve oid 0 if DBNull is the first value, later column value oids would needlessly be limited to oid 0. + // Also handle null values for object typed parameters, these parameters require non null values to be seen as set. + if (typeof(T) == typeof(DBNull) || (typeof(T) == typeof(object) && value is null or DBNull)) return WriteNull(async, cancellationToken); return Core(async, value, npgsqlDbType, dataTypeName, cancellationToken); @@ -279,7 +298,7 @@ async Task Core(bool async, T value, NpgsqlDbType? npgsqlDbType, string? dataTyp // These actions can reset or change the type info, we'll check afterwards whether we're still consistent with the original values. param.TypedValue = value; - param.ResolveTypeInfo(_connector.SerializerOptions); + param.ResolveTypeInfo(_connector.SerializerOptions, _connector.DbTypeResolver); if (previousTypeInfo is not null && previousConverter is not null && param.PgTypeId != previousTypeId) { @@ -296,7 +315,7 @@ async Task Core(bool async, T value, NpgsqlDbType? npgsqlDbType, string? dataTyp if (newParam) _params[_column] = param; - param.Bind(out _, out _); + param.Bind(out _, out _, requiredFormat: DataFormat.Binary); try { @@ -305,6 +324,7 @@ await param.Write(async, _pgWriter.WithFlushMode(async ? 
FlushMode.NonBlocking : } catch (Exception ex) { + TraceSetException(ex); _connector.Break(ex); throw; } @@ -423,8 +443,9 @@ async ValueTask Complete(bool async, CancellationToken cancellationToken _state = ImporterState.Committed; return cmdComplete.Rows; } - catch + catch (Exception e) { + TraceSetException(e); Cleanup(); throw; } @@ -510,6 +531,7 @@ async ValueTask CloseAsync(bool async, CancellationToken cancellationToken = def case ImporterState.Ready: await Cancel(async, cancellationToken).ConfigureAwait(false); break; + case ImporterState.Uninitialized: case ImporterState.Cancelled: case ImporterState.Committed: break; @@ -517,6 +539,7 @@ async ValueTask CloseAsync(bool async, CancellationToken cancellationToken = def throw new Exception("Invalid state: " + _state); } + TraceImportStop(); Cleanup(); } @@ -533,7 +556,6 @@ void Cleanup() { connector.EndUserAction(); connector.CurrentCopyOperation = null; - connector.Connection?.EndBindingScope(ConnectorBindingScope.Copy); _connector = null; } @@ -551,6 +573,7 @@ void CheckReady() static void Throw(ImporterState state) => throw (state switch { + ImporterState.Uninitialized => throw new InvalidOperationException("The COPY operation has not been initialized."), ImporterState.Disposed => new ObjectDisposedException(typeof(NpgsqlBinaryImporter).FullName, "The COPY operation has already ended."), ImporterState.Cancelled => new InvalidOperationException("The COPY operation has already been cancelled."), @@ -565,6 +588,7 @@ static void Throw(ImporterState state) enum ImporterState { + Uninitialized, Ready, Committed, Cancelled, @@ -575,4 +599,38 @@ enum ImporterState void ThrowColumnMismatch() => throw new InvalidOperationException($"The binary import operation was started with {NumColumns} column(s), but {_column + 1} value(s) were provided."); + + #region Tracing + + void TraceImportStop() + { + if (_activity is not null) + { + switch (_state) + { + case ImporterState.Committed: + 
NpgsqlActivitySource.CopyStop(_activity, _rowsImported); + break; + case ImporterState.Cancelled: + NpgsqlActivitySource.CopyStop(_activity, rows: 0); + break; + default: + Debug.Fail("Invalid state: " + _state); + break; + } + + _activity = null; + } + } + + void TraceSetException(Exception exception) + { + if (_activity is not null) + { + NpgsqlActivitySource.SetException(_activity, exception); + _activity = null; + } + } + + #endregion Tracing } diff --git a/src/Npgsql/NpgsqlCommand.cs b/src/Npgsql/NpgsqlCommand.cs index eaf11d51ff..1e3f4a1f04 100644 --- a/src/Npgsql/NpgsqlCommand.cs +++ b/src/Npgsql/NpgsqlCommand.cs @@ -13,10 +13,10 @@ using NpgsqlTypes; using static Npgsql.Util.Statics; using System.Diagnostics.CodeAnalysis; -using System.Threading.Channels; using Microsoft.Extensions.Logging; using Npgsql.Internal; using Npgsql.Properties; +using System.Collections; namespace Npgsql; @@ -46,14 +46,11 @@ public class NpgsqlCommand : DbCommand, ICloneable, IComponent int? _timeout; internal NpgsqlParameterCollection? _parameters; - /// - /// Whether this is wrapped by an . - /// - internal bool IsWrappedByBatch { get; } + internal NpgsqlBatch? WrappingBatch { get; } internal List InternalBatchCommands { get; } - Activity? CurrentActivity; + internal Activity? CurrentActivity { get; private set; } /// /// Returns details about each statement that this command has executed. @@ -142,13 +139,13 @@ public NpgsqlCommand(string? cmdText, NpgsqlConnection? connection, NpgsqlTransa /// /// Used when this instance is wrapped inside an . /// - internal NpgsqlCommand(int batchCommandCapacity, NpgsqlConnection? connection = null) + internal NpgsqlCommand(NpgsqlBatch batch, int batchCommandCapacity, NpgsqlConnection? 
connection = null) { GC.SuppressFinalize(this); InternalBatchCommands = new List(batchCommandCapacity); InternalConnection = connection; CommandType = CommandType.Text; - IsWrappedByBatch = true; + WrappingBatch = batch; // These can/should never be used in this mode _commandText = null!; @@ -161,8 +158,8 @@ internal NpgsqlCommand(string? cmdText, NpgsqlConnector connector) : this(cmdTex /// /// Used when this instance is wrapped inside an . /// - internal NpgsqlCommand(NpgsqlConnector connector, int batchCommandCapacity) - : this(batchCommandCapacity) + internal NpgsqlCommand(NpgsqlBatch batch, NpgsqlConnector connector, int batchCommandCapacity) + : this(batch, batchCommandCapacity) => _connector = connector; internal static NpgsqlCommand CreateCachedCommand(NpgsqlConnection connection) @@ -183,10 +180,20 @@ public override string CommandText get => _commandText; set { - Debug.Assert(!IsWrappedByBatch); + Debug.Assert(WrappingBatch is null); - if (State != CommandState.Idle) - ThrowHelper.ThrowInvalidOperationException("An open data reader exists for this command."); + switch (State) + { + case CommandState.Idle: + break; + case CommandState.Disposed: + ThrowHelper.ThrowObjectDisposedException(typeof(NpgsqlCommand).FullName); + break; + case CommandState.InProgress: + default: + ThrowHelper.ThrowInvalidOperationException("An open data reader exists for this command."); + break; + } _commandText = value ?? string.Empty; @@ -195,6 +202,26 @@ public override string CommandText } } + string GetBatchFullCommandText() + { + Debug.Assert(WrappingBatch is not null); + if (InternalBatchCommands.Count == 0) + return string.Empty; + if (InternalBatchCommands.Count == 1) + return InternalBatchCommands[0].CommandText; + // TODO: Potentially cache on connector/command? 
+ var sb = new StringBuilder(); + sb.Append(InternalBatchCommands[0].CommandText); + for (var i = 1; i < InternalBatchCommands.Count; i++) + { + sb + .Append(';') + .AppendLine() + .Append(InternalBatchCommands[i].CommandText); + } + return sb.ToString(); + } + /// /// Gets or sets the wait time (in seconds) before terminating the attempt to execute a command and generating an error. /// @@ -205,9 +232,7 @@ public override int CommandTimeout get => _timeout ?? (InternalConnection?.CommandTimeout ?? DefaultTimeout); set { - if (value < 0) { - throw new ArgumentOutOfRangeException(nameof(value), value, "CommandTimeout can't be less than zero."); - } + ArgumentOutOfRangeException.ThrowIfNegative(value); _timeout = value; } @@ -236,9 +261,12 @@ protected override DbConnection? DbConnection if (InternalConnection == value) return; - InternalConnection = State == CommandState.Idle - ? (NpgsqlConnection?)value - : throw new InvalidOperationException("An open data reader exists for this command."); + InternalConnection = State switch + { + CommandState.Idle => (NpgsqlConnection?)value, + CommandState.Disposed => throw new ObjectDisposedException(typeof(NpgsqlCommand).FullName), + _ => throw new InvalidOperationException("An open data reader exists for this command."), + }; Transaction = null; } @@ -377,7 +405,12 @@ internal CommandState State } } - internal void ResetPreparation() => _connectorPreparedOn = null; + internal void ResetPreparation() + { + _connectorPreparedOn = null; + foreach (var s in InternalBatchCommands) + s.ResetPreparation(); + } #endregion State management @@ -404,7 +437,7 @@ internal CommandState State /// Gets the . /// /// The parameters of the SQL statement or function (stored procedure). The default is an empty collection. 
- public new NpgsqlParameterCollection Parameters => _parameters ??= new(); + public new NpgsqlParameterCollection Parameters => _parameters ??= []; #endregion @@ -436,12 +469,11 @@ internal void DeriveParameters() { var conn = CheckAndGetConnection(); Debug.Assert(conn is not null); + var connector = conn.Connector!; if (string.IsNullOrEmpty(CommandText)) throw new InvalidOperationException("CommandText property has not been initialized"); - using var _ = conn.StartTemporaryBindingScope(out var connector); - foreach (var s in InternalBatchCommands) if (s.PreparedStatement?.IsExplicit == true) throw new NpgsqlException("Deriving parameters isn't supported for commands that are already prepared."); @@ -538,7 +570,7 @@ void DeriveParametersForQuery(NpgsqlConnector connector) { LogMessages.DerivingParameters(connector.CommandLogger, CommandText, connector.Id); - if (IsWrappedByBatch) + if (WrappingBatch is not null) foreach (var batchCommand in InternalBatchCommands) connector.SqlQueryParser.ParseRawQuery(batchCommand, connector.UseConformingStrings, deriveParameters: true); else @@ -644,18 +676,16 @@ Task Prepare(bool async, CancellationToken cancellationToken = default) { var connection = CheckAndGetConnection(); Debug.Assert(connection is not null); - if (connection.Settings.Multiplexing) - throw new NotSupportedException("Explicit preparation not supported with multiplexing"); var connector = connection.Connector!; var logger = connector.CommandLogger; var needToPrepare = false; - if (IsWrappedByBatch) + if (WrappingBatch is not null) { foreach (var batchCommand in InternalBatchCommands) { - batchCommand._parameters?.ProcessParameters(connector.SerializerOptions, validateValues: false, CommandType); + batchCommand._parameters?.ProcessParameters(connector.ReloadableState, validateValues: false, batchCommand.CommandType); ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand); needToPrepare = batchCommand.ExplicitPrepare(connector) 
|| needToPrepare; @@ -673,7 +703,7 @@ IEnumerable CommandTexts() } else { - _parameters?.ProcessParameters(connector.SerializerOptions, validateValues: false, CommandType); + _parameters?.ProcessParameters(connector.ReloadableState, validateValues: false, CommandType); ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand: null); foreach (var batchCommand in InternalBatchCommands) @@ -711,11 +741,16 @@ static async Task PrepareLong(NpgsqlCommand command, bool async, NpgsqlConnector continue; var pStatement = batchCommand.PreparedStatement!; + var replacedStatement = pStatement.StatementBeingReplaced; - if (pStatement.StatementBeingReplaced != null) + if (replacedStatement != null) { Expect(await connector.ReadMessage(async).ConfigureAwait(false), connector); - pStatement.StatementBeingReplaced.CompleteUnprepare(); + replacedStatement.CompleteUnprepare(); + + if (!replacedStatement.IsExplicit) + connector.PreparedStatementManager.AutoPrepared[replacedStatement.AutoPreparedSlotIndex] = null; + pStatement.StatementBeingReplaced = null; } @@ -805,8 +840,6 @@ async Task Unprepare(bool async, CancellationToken cancellationToken = default) { var connection = CheckAndGetConnection(); Debug.Assert(connection is not null); - if (connection.Settings.Multiplexing) - throw new NotSupportedException("Explicit preparation not supported with multiplexing"); var forall = true; foreach (var statement in InternalBatchCommands) @@ -839,7 +872,7 @@ async Task Unprepare(bool async, CancellationToken cancellationToken = default) if (!pStatement.IsExplicit) connector.PreparedStatementManager.AutoPrepared[pStatement.AutoPreparedSlotIndex] = null; - batchCommand.PreparedStatement = null; + batchCommand.ResetPreparation(); } } @@ -851,7 +884,7 @@ async Task Unprepare(bool async, CancellationToken cancellationToken = default) #region Query analysis - internal void ProcessRawQuery(SqlQueryParser? parser, bool standardConformingStrings, NpgsqlBatchCommand? 
batchCommand) + internal void ProcessRawQuery(SqlQueryParser parser, bool standardConformingStrings, NpgsqlBatchCommand? batchCommand) { var (commandText, commandType, parameters) = batchCommand is null ? (CommandText, CommandType, _parameters) @@ -874,7 +907,10 @@ internal void ProcessRawQuery(SqlQueryParser? parser, bool standardConformingStr batchCommand = TruncateStatementsToOne(); batchCommand.FinalCommandText = CommandText; if (parameters is not null) + { batchCommand.PositionalParameters = parameters.InternalList; + batchCommand._parameters = parameters; + } } else { @@ -887,6 +923,11 @@ internal void ProcessRawQuery(SqlQueryParser? parser, bool standardConformingStr break; case PlaceholderType.NoParameters: + if (batchCommand is not null) + { + batchCommand.FinalCommandText = batchCommand.CommandText; + break; + } // Unless the EnableSqlRewriting AppContext switch is explicitly disabled, queries with no parameters are parsed just // like queries with named parameters, since they may contain a semicolon (legacy batching). if (EnableSqlRewriting) @@ -897,9 +938,6 @@ internal void ProcessRawQuery(SqlQueryParser? parser, bool standardConformingStr if (!EnableSqlRewriting) ThrowHelper.ThrowNotSupportedException($"Named parameters are not supported when Npgsql.{nameof(EnableSqlRewriting)} is disabled"); - // The parser is cached on NpgsqlConnector - unless we're in multiplexing mode. - parser ??= new SqlQueryParser(); - if (batchCommand is null) { parser.ParseRawQuery(this, standardConformingStrings); @@ -911,8 +949,6 @@ internal void ProcessRawQuery(SqlQueryParser? parser, bool standardConformingStr else { parser.ParseRawQuery(batchCommand, standardConformingStrings); - if (batchCommand._parameters?.HasOutputParameters == true) - ThrowHelper.ThrowNotSupportedException("Batches cannot cannot have out parameters"); ValidateParameterCount(batchCommand); } @@ -955,6 +991,9 @@ internal void ProcessRawQuery(SqlQueryParser? 
parser, bool standardConformingStr if (EnableStoredProcedureCompatMode && parameter.Direction == ParameterDirection.Output) continue; + if (parameter.Direction == ParameterDirection.ReturnValue) + continue; + if (isFirstParam) isFirstParam = false; else @@ -989,6 +1028,7 @@ internal void ProcessRawQuery(SqlQueryParser? parser, bool standardConformingStr batchCommand ??= TruncateStatementsToOne(); batchCommand.FinalCommandText = sqlBuilder.ToString(); + batchCommand._parameters = parameters; batchCommand.PositionalParameters.AddRange(inputParameters); ValidateParameterCount(batchCommand); @@ -1001,7 +1041,7 @@ internal void ProcessRawQuery(SqlQueryParser? parser, bool standardConformingStr static void ValidateParameterCount(NpgsqlBatchCommand batchCommand) { - if (batchCommand.HasParameters && batchCommand.PositionalParameters.Count > ushort.MaxValue) + if (batchCommand is { HasParameters: true, PositionalParameters.Count: > ushort.MaxValue }) ThrowHelper.ThrowNpgsqlException("A statement cannot have more than 65535 parameters"); } } @@ -1010,9 +1050,6 @@ static void ValidateParameterCount(NpgsqlBatchCommand batchCommand) #region Message Creation / Population - void BeginSend(NpgsqlConnector connector) - => connector.WriteBuffer.Timeout = TimeSpan.FromSeconds(CommandTimeout); - internal Task Write(NpgsqlConnector connector, bool async, bool flush, CancellationToken cancellationToken = default) { return (_behavior & CommandBehavior.SchemaOnly) == 0 @@ -1026,7 +1063,7 @@ async Task WriteExecute(NpgsqlConnector connector, bool async, bool flush, Cance var syncCaller = !async; for (var i = 0; i < InternalBatchCommands.Count; i++) { - // The following is only for deadlock avoidance when doing sync I/O (so never in multiplexing) + // The following is only for deadlock avoidance when doing sync I/O if (syncCaller && ShouldSchedule(ref async, i)) await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); @@ -1053,7 +1090,7 @@ await connector.WriteBind( i == 0 ? 
UnknownResultTypeList : null, async, cancellationToken).ConfigureAwait(false); - await connector.WriteDescribe(StatementOrPortal.Portal, Array.Empty(), async, cancellationToken).ConfigureAwait(false); + await connector.WriteDescribe(StatementOrPortal.Portal, [], async, cancellationToken).ConfigureAwait(false); } else { @@ -1092,11 +1129,24 @@ async Task WriteExecuteSchemaOnly(NpgsqlConnector connector, bool async, bool fl await new TaskSchedulerAwaitable(ConstrainedConcurrencyScheduler); var batchCommand = InternalBatchCommands[i]; + var pStatement = batchCommand.PreparedStatement; - if (batchCommand.PreparedStatement?.State == PreparedState.Prepared) - continue; // Prepared, we already have the RowDescription + pStatement?.RefreshLastUsed(); + + Debug.Assert(batchCommand.FinalCommandText is not null); - await connector.WriteParse(batchCommand.FinalCommandText!, batchCommand.StatementName, + if (pStatement != null && !batchCommand.IsPreparing) + { + // Prepared, we already have the RowDescription + Debug.Assert(pStatement.IsPrepared); + continue; + } + + // We may have a prepared statement that replaces an existing statement - close the latter first. 
+ if (pStatement?.StatementBeingReplaced != null) + await connector.WriteClose(StatementOrPortal.Statement, pStatement.StatementBeingReplaced.Name!, async, cancellationToken).ConfigureAwait(false); + + await connector.WriteParse(batchCommand.FinalCommandText, batchCommand.StatementName, batchCommand.CurrentParametersReadOnly, async, cancellationToken).ConfigureAwait(false); await connector.WriteDescribe(StatementOrPortal.Statement, batchCommand.StatementName, async, cancellationToken).ConfigureAwait(false); @@ -1114,8 +1164,6 @@ await connector.WriteParse(batchCommand.FinalCommandText!, batchCommand.Statemen async Task SendDeriveParameters(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) { - BeginSend(connector); - var syncCaller = !async; for (var i = 0; i < InternalBatchCommands.Count; i++) { @@ -1124,8 +1172,8 @@ async Task SendDeriveParameters(NpgsqlConnector connector, bool async, Cancellat var batchCommand = InternalBatchCommands[i]; - await connector.WriteParse(batchCommand.FinalCommandText!, Array.Empty(), NpgsqlBatchCommand.EmptyParameters, async, cancellationToken).ConfigureAwait(false); - await connector.WriteDescribe(StatementOrPortal.Statement, Array.Empty(), async, cancellationToken).ConfigureAwait(false); + await connector.WriteParse(batchCommand.FinalCommandText!, [], NpgsqlBatchCommand.EmptyParameters, async, cancellationToken).ConfigureAwait(false); + await connector.WriteDescribe(StatementOrPortal.Statement, [], async, cancellationToken).ConfigureAwait(false); } await connector.WriteSync(async, cancellationToken).ConfigureAwait(false); @@ -1134,8 +1182,6 @@ async Task SendDeriveParameters(NpgsqlConnector connector, bool async, Cancellat async Task SendPrepare(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) { - BeginSend(connector); - var syncCaller = !async; for (var i = 0; i < InternalBatchCommands.Count; i++) { @@ -1183,8 +1229,6 @@ bool ShouldSchedule(ref bool async, 
int indexOfStatementInBatch) async Task SendClose(NpgsqlConnector connector, bool async, CancellationToken cancellationToken = default) { - BeginSend(connector); - foreach (var batchCommand in InternalBatchCommands) { if (!batchCommand.IsPrepared) @@ -1264,7 +1308,7 @@ async Task ExecuteNonQuery(bool async, CancellationToken cancellationToken) async ValueTask ExecuteScalar(bool async, CancellationToken cancellationToken) { var behavior = CommandBehavior.SingleRow; - if (IsWrappedByBatch || _parameters?.HasOutputParameters != true) + if (WrappingBatch is not null || _parameters?.HasOutputParameters != true) behavior |= CommandBehavior.SequentialAccess; var reader = await ExecuteReader(async, behavior, cancellationToken).ConfigureAwait(false); @@ -1338,16 +1382,12 @@ protected override async Task ExecuteDbDataReaderAsync(CommandBeha public new Task ExecuteReaderAsync(CommandBehavior behavior, CancellationToken cancellationToken = default) => ExecuteReader(async: true, behavior, cancellationToken).AsTask(); - // TODO: Maybe pool these? - internal ManualResetValueTaskSource ExecutionCompletion { get; } - = new(); - internal virtual async ValueTask ExecuteReader(bool async, CommandBehavior behavior, CancellationToken cancellationToken) { var conn = CheckAndGetConnection(); _behavior = behavior; - NpgsqlConnector? 
connector; + NpgsqlConnector connector; if (_connector is not null) { Debug.Assert(conn is null); @@ -1358,221 +1398,162 @@ internal virtual async ValueTask ExecuteReader(bool async, Com else { Debug.Assert(conn is not null); - conn.TryGetBoundConnector(out connector); + connector = conn.Connector!; } try { - if (connector is not null) - { - var logger = connector.CommandLogger; + var logger = connector.CommandLogger; + var reloadableState = connector.ReloadableState; - cancellationToken.ThrowIfCancellationRequested(); - // We cannot pass a token here, as we'll cancel a non-send query - // Also, we don't pass the cancellation token to StartUserAction, since that would make it scope to the entire action (command execution) - // whereas it should only be scoped to the Execute method. - connector.StartUserAction(ConnectorState.Executing, this, CancellationToken.None); + cancellationToken.ThrowIfCancellationRequested(); + // We cannot pass a token here, as we'll cancel a non-send query + // Also, we don't pass the cancellation token to StartUserAction, since that would make it scope to the entire action (command execution) + // whereas it should only be scoped to the Execute method. + connector.StartUserAction(ConnectorState.Executing, this, CancellationToken.None); - Task? sendTask; + Task? 
sendTask; - var validateParameterValues = !behavior.HasFlag(CommandBehavior.SchemaOnly); - long startTimestamp; + var validateParameterValues = !behavior.HasFlag(CommandBehavior.SchemaOnly); + long startTimestamp; - try + try + { + var fullyPrepared = false; + + switch (IsExplicitlyPrepared) { - switch (IsExplicitlyPrepared) + case true: + Debug.Assert(_connectorPreparedOn != null); + if (WrappingBatch is not null) { - case true: - Debug.Assert(_connectorPreparedOn != null); - if (IsWrappedByBatch) - { - foreach (var batchCommand in InternalBatchCommands) - { - if (batchCommand.ConnectorPreparedOn != connector) - { - foreach (var s in InternalBatchCommands) - s.ResetPreparation(); - ResetPreparation(); - goto case false; - } - - batchCommand._parameters?.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); - } - } - else + foreach (var batchCommand in InternalBatchCommands) { - if (_connectorPreparedOn != connector) + if (batchCommand.ConnectorPreparedOn != connector) { - // The command was prepared, but since then the connector has changed. Detach all prepared statements. - foreach (var s in InternalBatchCommands) - s.PreparedStatement = null; ResetPreparation(); goto case false; } - _parameters?.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); - } - - NpgsqlEventSource.Log.CommandStartPrepared(); - connector.DataSource.MetricsReporter.CommandStartPrepared(); - break; - case false: - var numPrepared = 0; - - if (IsWrappedByBatch) + batchCommand._parameters?.ProcessParameters(reloadableState, validateParameterValues, batchCommand.CommandType); + } + } + else + { + if (_connectorPreparedOn != connector) { - for (var i = 0; i < InternalBatchCommands.Count; i++) - { - var batchCommand = InternalBatchCommands[i]; + // The command was prepared, but since then the connector has changed. Detach all prepared statements. 
+ ResetPreparation(); + goto case false; + } + _parameters?.ProcessParameters(reloadableState, validateParameterValues, CommandType); + } - batchCommand._parameters?.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); - ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand); + NpgsqlEventSource.Log.CommandStartPrepared(); + connector.DataSource.MetricsReporter.CommandStartPrepared(); + fullyPrepared = true; + break; - if (connector.Settings.MaxAutoPrepare > 0 && batchCommand.TryAutoPrepare(connector)) - { - batchCommand.ConnectorPreparedOn = connector; - numPrepared++; - } - } - } - else + case false: + var numPrepared = 0; + + if (WrappingBatch is not null) + { + for (var i = 0; i < InternalBatchCommands.Count; i++) { - _parameters?.ProcessParameters(connector.SerializerOptions, validateParameterValues, CommandType); - ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand: null); + var batchCommand = InternalBatchCommands[i]; - if (connector.Settings.MaxAutoPrepare > 0) - for (var i = 0; i < InternalBatchCommands.Count; i++) - if (InternalBatchCommands[i].TryAutoPrepare(connector)) - numPrepared++; - } + batchCommand._parameters?.ProcessParameters(reloadableState, validateParameterValues, batchCommand.CommandType); + ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, batchCommand); - if (numPrepared > 0) - { - _connectorPreparedOn = connector; - if (numPrepared == InternalBatchCommands.Count) + if (connector.Settings.MaxAutoPrepare > 0 && batchCommand.TryAutoPrepare(connector)) { - NpgsqlEventSource.Log.CommandStartPrepared(); - connector.DataSource.MetricsReporter.CommandStartPrepared(); + batchCommand.ConnectorPreparedOn = connector; + numPrepared++; } } - - break; } + else + { + _parameters?.ProcessParameters(reloadableState, validateParameterValues, CommandType); + ProcessRawQuery(connector.SqlQueryParser, connector.UseConformingStrings, 
batchCommand: null); - State = CommandState.InProgress; + if (connector.Settings.MaxAutoPrepare > 0) + for (var i = 0; i < InternalBatchCommands.Count; i++) + if (InternalBatchCommands[i].TryAutoPrepare(connector)) + numPrepared++; + } - if (logger.IsEnabled(LogLevel.Information)) + if (numPrepared > 0) { - connector.QueryLogStopWatch.Restart(); - - if (logger.IsEnabled(LogLevel.Debug)) - LogExecutingCompleted(connector, executing: true); + _connectorPreparedOn = connector; + if (numPrepared == InternalBatchCommands.Count) + { + NpgsqlEventSource.Log.CommandStartPrepared(); + connector.DataSource.MetricsReporter.CommandStartPrepared(); + fullyPrepared = true; + } } - NpgsqlEventSource.Log.CommandStart(CommandText); - startTimestamp = connector.DataSource.MetricsReporter.ReportCommandStart(); - TraceCommandStart(connector); - - // If a cancellation is in progress, wait for it to "complete" before proceeding (#615) - connector.ResetCancellation(); - - // We do not wait for the entire send to complete before proceeding to reading - - // the sending continues in parallel with the user's reading. Waiting for the - // entire send to complete would trigger a deadlock for multi-statement commands, - // where PostgreSQL sends large results for the first statement, while we're sending large - // parameter data for the second. See #641. - // Instead, all sends for non-first statements are performed asynchronously (even if the user requested sync), - // in a special synchronization context to prevents a dependency on the thread pool (which would also trigger - // deadlocks). - BeginSend(connector); - sendTask = Write(connector, async, flush: true, CancellationToken.None); - - // The following is a hack. It raises an exception if one was thrown in the first phases - // of the send (i.e. in parts of the send that executed synchronously). Exceptions may - // still happen later and aren't properly handled. See #1323. 
- if (sendTask.IsFaulted) - sendTask.GetAwaiter().GetResult(); - } - catch - { - connector.EndUserAction(); - throw; + break; } - // TODO: DRY the following with multiplexing, but be careful with the cancellation registration... - var reader = connector.DataReader; - reader.Init(this, behavior, InternalBatchCommands, startTimestamp, sendTask); - connector.CurrentReader = reader; - if (async) - await reader.NextResultAsync(cancellationToken).ConfigureAwait(false); - else - reader.NextResult(); + // If a cancellation is in progress, wait for it to "complete" before proceeding (#615) + // We do it before changing the state because we only allow sending cancellation request if State == InProgress + connector.ResetCancellation(); - TraceReceivedFirstResponse(); - - return reader; - } - else - { - Debug.Assert(conn is not null); - Debug.Assert(conn.Settings.Multiplexing); - - // The connection isn't bound to a connector - it's multiplexing time. - var dataSource = (MultiplexingDataSource)conn.NpgsqlDataSource; + State = CommandState.InProgress; - if (!async) + if (logger.IsEnabled(LogLevel.Information)) { - // The waiting on the ExecutionCompletion ManualResetValueTaskSource is necessarily - // asynchronous, so allowing sync would mean sync-over-async. 
- ThrowHelper.ThrowNotSupportedException("Synchronous command execution is not supported when multiplexing is on"); - } + connector.QueryLogStopWatch.Restart(); - if (IsWrappedByBatch) - { - foreach (var batchCommand in InternalBatchCommands) - { - batchCommand._parameters?.ProcessParameters(dataSource.SerializerOptions, validateValues: true, CommandType); - ProcessRawQuery(null, standardConformingStrings: true, batchCommand); - } - } - else - { - _parameters?.ProcessParameters(dataSource.SerializerOptions, validateValues: true, CommandType); - ProcessRawQuery(null, standardConformingStrings: true, batchCommand: null); + if (logger.IsEnabled(LogLevel.Debug)) + LogExecutingCompleted(connector, executing: true); } - State = CommandState.InProgress; + NpgsqlEventSource.Log.CommandStart(CommandText); + startTimestamp = connector.DataSource.MetricsReporter.ReportCommandStart(); + TraceCommandStart(connector.DataSource.Configuration.TracingOptions, fullyPrepared); + TraceCommandEnrich(connector); + + // We do not wait for the entire send to complete before proceeding to reading - + // the sending continues in parallel with the user's reading. Waiting for the + // entire send to complete would trigger a deadlock for multi-statement commands, + // where PostgreSQL sends large results for the first statement, while we're sending large + // parameter data for the second. See #641. + // Instead, all sends for non-first statements are performed asynchronously (even if the user requested sync), + // in a special synchronization context to prevents a dependency on the thread pool (which would also trigger + // deadlocks). + sendTask = Write(connector, async, flush: true, CancellationToken.None); + + // The following is a hack. It raises an exception if one was thrown in the first phases + // of the send (i.e. in parts of the send that executed synchronously). Exceptions may + // still happen later and aren't properly handled. See #1323. 
+ if (sendTask.IsFaulted) + sendTask.GetAwaiter().GetResult(); + } + catch + { + connector.EndUserAction(); + throw; + } - // TODO: Experiment: do we want to wait on *writing* here, or on *reading*? - // Previous behavior was to wait on reading, which throw the exception from ExecuteReader (and not from - // the first read). But waiting on writing would allow us to do sync writing and async reading. - ExecutionCompletion.Reset(); - try - { - await dataSource.MultiplexCommandWriter.WriteAsync(this, cancellationToken).ConfigureAwait(false); - } - catch (ChannelClosedException ex) - { - Debug.Assert(ex.InnerException is not null); - throw ex.InnerException; - } - connector = await new ValueTask(ExecutionCompletion, ExecutionCompletion.Version).ConfigureAwait(false); - // TODO: Overload of StartBindingScope? - conn.Connector = connector; - connector.Connection = conn; - conn.ConnectorBindingScope = ConnectorBindingScope.Reader; - - var reader = connector.DataReader; - reader.Init(this, behavior, InternalBatchCommands); - connector.CurrentReader = reader; + var reader = connector.DataReader; + reader.Init(this, behavior, InternalBatchCommands, startTimestamp, sendTask); + connector.CurrentReader = reader; + if (async) await reader.NextResultAsync(cancellationToken).ConfigureAwait(false); + else + reader.NextResult(); - return reader; - } + TraceReceivedFirstResponse(connector.DataSource.Configuration.TracingOptions); + + return reader; } catch (Exception e) { - var reader = connector?.CurrentReader; + var reader = connector.CurrentReader; if (e is not NpgsqlOperationInProgressException && reader is not null) await reader.Cleanup(async).ConfigureAwait(false); @@ -1603,7 +1584,13 @@ internal virtual async ValueTask ExecuteReader(bool async, Com protected override DbTransaction? 
DbTransaction { get => _transaction; - set => _transaction = (NpgsqlTransaction?)value; + set + { + var tx = (NpgsqlTransaction?)value; + if (tx is { IsCompleted: true }) + throw new InvalidOperationException("Transaction is already completed"); + _transaction = tx; + } } /// @@ -1634,7 +1621,7 @@ public override void Cancel() if (connector is null) return; - connector.PerformUserCancellation(); + connector.PerformImmediateUserCancellation(); } #endregion Cancel @@ -1667,7 +1654,8 @@ internal void Reset() // Can be null if it's owned by batch _parameters?.Clear(); _timeout = null; - _allResultTypesAreUnknown = false; + AllResultTypesAreUnknown = false; + Debug.Assert(_unknownResultTypeList is null); EnableErrorBarriers = false; } @@ -1677,26 +1665,55 @@ internal void Reset() #region Tracing - internal void TraceCommandStart(NpgsqlConnector connector) + internal void TraceCommandStart(NpgsqlTracingOptions tracingOptions, bool? prepared) { Debug.Assert(CurrentActivity is null); + if (NpgsqlActivitySource.IsEnabled) - CurrentActivity = NpgsqlActivitySource.CommandStart(connector, CommandText, CommandType); + { + var enableTracing = WrappingBatch is not null + ? tracingOptions.BatchFilter?.Invoke(WrappingBatch) ?? true + : tracingOptions.CommandFilter?.Invoke(this) ?? true; + + if (enableTracing) + { + var spanName = WrappingBatch is not null + ? tracingOptions.BatchSpanNameProvider?.Invoke(WrappingBatch) + : tracingOptions.CommandSpanNameProvider?.Invoke(this); + + CurrentActivity = NpgsqlActivitySource.CommandStart( + WrappingBatch is not null ? 
GetBatchFullCommandText() : CommandText, + CommandType, + prepared, + spanName); + } + } } - internal void TraceReceivedFirstResponse() + internal void TraceCommandEnrich(NpgsqlConnector connector) { if (CurrentActivity is not null) { - NpgsqlActivitySource.ReceivedFirstResponse(CurrentActivity); + NpgsqlActivitySource.Enrich(CurrentActivity, connector); + var tracingOptions = connector.DataSource.Configuration.TracingOptions; + if (WrappingBatch is not null) + tracingOptions.BatchEnrichmentCallback?.Invoke(CurrentActivity, WrappingBatch); + else + tracingOptions.CommandEnrichmentCallback?.Invoke(CurrentActivity, this); } } + internal void TraceReceivedFirstResponse(NpgsqlTracingOptions tracingOptions) + { + if (CurrentActivity is not null) + NpgsqlActivitySource.ReceivedFirstResponse(CurrentActivity, tracingOptions); + } + internal void TraceCommandStop() { if (CurrentActivity is not null) { - NpgsqlActivitySource.CommandStop(CurrentActivity); + CurrentActivity.Dispose(); CurrentActivity = null; } } @@ -1761,6 +1778,7 @@ internal void LogExecutingCompleted(NpgsqlConnector connector, bool executing) { var logParameters = connector.LoggingConfiguration.IsParameterLoggingEnabled || connector.Settings.LogParameters; var logger = connector.LoggingConfiguration.CommandLogger; + Debug.Assert(executing ? 
logger.IsEnabled(LogLevel.Debug) : logger.IsEnabled(LogLevel.Information)); if (InternalBatchCommands.Count == 1) { @@ -1773,7 +1791,7 @@ internal void LogExecutingCompleted(NpgsqlConnector connector, bool executing) LogMessages.ExecutingCommandWithParameters( logger, singleCommand.FinalCommandText!, - ParametersDbNullAsString(singleCommand), + GetParametersForLogging(singleCommand), connector.Id); } else @@ -1781,7 +1799,7 @@ internal void LogExecutingCompleted(NpgsqlConnector connector, bool executing) LogMessages.CommandExecutionCompletedWithParameters( logger, singleCommand.FinalCommandText!, - ParametersDbNullAsString(singleCommand), + GetParametersForLogging(singleCommand), connector.QueryLogStopWatch.ElapsedMilliseconds, connector.Id); } @@ -1798,9 +1816,9 @@ internal void LogExecutingCompleted(NpgsqlConnector connector, bool executing) { if (logParameters) { - var commands = new (string, object[])[InternalBatchCommands.Count]; + var commands = new (string, IEnumerable)[InternalBatchCommands.Count]; for (var i = 0; i < InternalBatchCommands.Count; i++) - commands[i] = (InternalBatchCommands[i].FinalCommandText!, ParametersDbNullAsString(InternalBatchCommands[i])); + commands[i] = (InternalBatchCommands[i].FinalCommandText!, new LoggingEnumerable(GetParametersForLogging(InternalBatchCommands[i]))); if (executing) LogMessages.ExecutingBatchWithParameters(logger, commands, connector.Id); @@ -1819,15 +1837,53 @@ internal void LogExecutingCompleted(NpgsqlConnector connector, bool executing) } } - object[] ParametersDbNullAsString(NpgsqlBatchCommand c) + static object[] GetParametersForLogging(NpgsqlBatchCommand c) { var positionalParameters = c.CurrentParametersReadOnly; var parameters = new object[positionalParameters.Count]; for (var i = 0; i < positionalParameters.Count; i++) - parameters[i] = positionalParameters[i].Value == DBNull.Value ? 
"NULL" : positionalParameters[i].Value!; + { + parameters[i] = GetParameterForLogging(positionalParameters[i].Value); + } return parameters; + + object GetParameterForLogging(object? value) + { + return value switch + { + DBNull or null => "NULL", + IEnumerable enumerable and not string => GetEnumerableForLogging(enumerable), + _ => value + }; + + string GetEnumerableForLogging(IEnumerable enumerable) + { + var vsb = new StringBuilder(256); + var count = 0; + vsb.Append('['); + foreach (var e in enumerable) + { + if (count > 9) + { + vsb.Append(", ..."); + break; + } + + if (count > 0) + { + vsb.Append(", "); + } + + vsb.Append(GetParameterForLogging(e)); + count++; + } + + vsb.Append(']'); + return vsb.ToString(); + } + } } - } + } /// /// Create a new command based on this one. @@ -1855,8 +1911,7 @@ public virtual NpgsqlCommand Clone() NpgsqlConnection? CheckAndGetConnection() { - if (State is CommandState.Disposed) - ThrowHelper.ThrowObjectDisposedException(GetType().FullName); + ObjectDisposedException.ThrowIf(State is CommandState.Disposed, this); var conn = InternalConnection; if (conn is null) diff --git a/src/Npgsql/NpgsqlCommandBuilder.cs b/src/Npgsql/NpgsqlCommandBuilder.cs index 9665b8356c..d9a698c2ef 100644 --- a/src/Npgsql/NpgsqlCommandBuilder.cs +++ b/src/Npgsql/NpgsqlCommandBuilder.cs @@ -212,7 +212,11 @@ private static void SetParameterValuesFromRow(NpgsqlCommand command, DataRow row protected override void ApplyParameterInfo(DbParameter p, DataRow row, System.Data.StatementType statementType, bool whereClause) { var param = (NpgsqlParameter)p; - param.NpgsqlDbType = (NpgsqlDbType)row[SchemaTableColumn.ProviderType]; + // DbCommandBuilder is going to set DbType.Int32 onto an existing parameter, reset other db type fields. 
+ if (param.SourceColumnNullMapping) + param.ResetDbType(); + else + param.NpgsqlDbType = (NpgsqlDbType)row[SchemaTableColumn.ProviderType]; } /// diff --git a/src/Npgsql/NpgsqlConnection.cs b/src/Npgsql/NpgsqlConnection.cs index da8262636c..9619e938bd 100644 --- a/src/Npgsql/NpgsqlConnection.cs +++ b/src/Npgsql/NpgsqlConnection.cs @@ -5,7 +5,6 @@ using System.Data.Common; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.IO; using System.Net.Security; using System.Net.Sockets; using System.Runtime.CompilerServices; @@ -46,8 +45,7 @@ public sealed class NpgsqlConnection : DbConnection, ICloneable, IComponent ConnectionState _fullState; /// - /// The physical connection to the database. This is when the connection is closed, - /// and also when it is open in multiplexing mode and unbound (e.g. not in a transaction). + /// The physical connection to the database. This is when the connection is closed. /// internal NpgsqlConnector? Connector { get; set; } @@ -102,12 +100,6 @@ public INpgsqlTypeMapper TypeMapper /// internal const int TimeoutLimit = 1024; - /// - /// Tracks when this connection was bound to a physical connector (e.g. at open-time, when a transaction - /// was started...). 
- /// - internal ConnectorBindingScope ConnectorBindingScope { get; set; } - ILogger _connectionLogger = default!; // Initialized in Open, shouldn't be used otherwise static readonly StateChangeEventArgs ClosedToOpenEventArgs = new(ConnectionState.Closed, ConnectionState.Open); @@ -139,7 +131,6 @@ internal NpgsqlConnection(NpgsqlDataSource dataSource, NpgsqlConnector connector Connector = connector; connector.Connection = this; - ConnectorBindingScope = ConnectorBindingScope.Connection; FullState = ConnectionState.Open; } @@ -220,20 +211,7 @@ void SetupDataSource() _cloningInstantiator = s => new NpgsqlConnection(s); _dataSource = PoolManager.Pools.GetOrAdd(canonical, newDataSource); - if (_dataSource == newDataSource) - { - Debug.Assert(_dataSource is not MultiHostDataSourceWrapper); - // If the pool we created was the one that ended up being stored we need to increment the appropriate counter. - // Avoids a race condition where multiple threads will create a pool but only one will be stored. - if (_dataSource is NpgsqlMultiHostDataSource multiHostConnectorPool) - foreach (var hostPool in multiHostConnectorPool.Pools) - NpgsqlEventSource.Log.DataSourceCreated(hostPool); - else - { - NpgsqlEventSource.Log.DataSourceCreated(newDataSource); - } - } - else + if (_dataSource != newDataSource) newDataSource.Dispose(); // If this is a multi-host data source and the user specified a TargetSessionAttributes, create a wrapper in front of the @@ -260,37 +238,10 @@ internal Task Open(bool async, CancellationToken cancellationToken) if (_connectionLogger.IsEnabled(LogLevel.Trace)) LogMessages.OpeningConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); - if (Settings.Multiplexing) - { - if (Settings.Enlist && Transaction.Current != null) - { - // TODO: Keep in mind that the TransactionScope can be disposed - ThrowHelper.ThrowNotSupportedException(); - } - - // We're opening in multiplexing mode, without a transaction. 
We don't actually do anything. - - // If we've never connected with this connection string, open a physical connector in order to generate - // any exception (bad user/password, IP address...). This reproduces the standard error behavior. - if (!_dataSource.IsBootstrapped) - { - FullState = ConnectionState.Connecting; - return PerformMultiplexingStartupCheck(async, cancellationToken); - } - - if (_connectionLogger.IsEnabled(LogLevel.Debug)) - LogMessages.OpenedMultiplexingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); - FullState = ConnectionState.Open; - - return Task.CompletedTask; - } - return OpenAsync(async, cancellationToken); async Task OpenAsync(bool async, CancellationToken cancellationToken) { - Debug.Assert(!Settings.Multiplexing); - FullState = ConnectionState.Connecting; NpgsqlConnector? connector = null; try @@ -315,7 +266,6 @@ async Task OpenAsync(bool async, CancellationToken cancellationToken) Debug.Assert(connector.Connection is null, $"Connection for opened connector '{Connector?.Id.ToString() ?? 
"???"}' is bound to another connection"); - ConnectorBindingScope = ConnectorBindingScope.Connection; connector.Connection = this; Connector = connector; @@ -328,7 +278,6 @@ async Task OpenAsync(bool async, CancellationToken cancellationToken) catch { FullState = ConnectionState.Closed; - ConnectorBindingScope = ConnectorBindingScope.None; Connector = null; EnlistedTransaction = null; @@ -342,25 +291,6 @@ async Task OpenAsync(bool async, CancellationToken cancellationToken) } } - async Task PerformMultiplexingStartupCheck(bool async, CancellationToken cancellationToken) - { - try - { - var timeout = new NpgsqlTimeout(TimeSpan.FromSeconds(ConnectionTimeout)); - - _ = await StartBindingScope(ConnectorBindingScope.Connection, timeout, async, cancellationToken).ConfigureAwait(false); - EndBindingScope(ConnectorBindingScope.Connection); - - LogMessages.OpenedMultiplexingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); - - FullState = ConnectionState.Open; - } - catch - { - FullState = ConnectionState.Closed; - throw; - } - } } #endregion Open / Init @@ -638,31 +568,21 @@ async ValueTask BeginTransaction(bool async, IsolationLevel l ThrowHelper.ThrowNotSupportedException($"Unsupported IsolationLevel: {nameof(IsolationLevel.Chaos)}"); CheckReady(); - if (Connector is { InTransaction: true }) + var connector = Connector; + if (connector is { InTransaction: true }) ThrowHelper.ThrowInvalidOperationException("A transaction is already in progress; nested/concurrent transactions aren't supported."); // There was a committed/rolled back transaction, but it was not disposed - var connector = ConnectorBindingScope == ConnectorBindingScope.Transaction - ? 
Connector - : await StartBindingScope(ConnectorBindingScope.Transaction, NpgsqlTimeout.Infinite, async, cancellationToken).ConfigureAwait(false); Debug.Assert(connector != null); - try - { - // Note that beginning a transaction doesn't actually send anything to the backend (only prepends). - // But we start a user action to check the cancellation token and generate exceptions - using var _ = connector.StartUserAction(cancellationToken); + // Note that beginning a transaction doesn't actually send anything to the backend (only prepends). + // But we start a user action to check the cancellation token and generate exceptions + using var _ = connector.StartUserAction(cancellationToken); - connector.Transaction ??= new NpgsqlTransaction(connector); - connector.Transaction.Init(level); - return connector.Transaction; - } - catch - { - EndBindingScope(ConnectorBindingScope.Transaction); - throw; - } + connector.Transaction ??= new NpgsqlTransaction(connector); + connector.Transaction.Init(level); + return connector.Transaction; } /// @@ -712,9 +632,6 @@ protected override async ValueTask BeginDbTransactionAsync(Isolat /// public override void EnlistTransaction(Transaction? transaction) { - if (Settings.Multiplexing) - throw new NotSupportedException("Ambient transactions aren't yet implemented for multiplexing"); - if (EnlistedTransaction != null) { if (EnlistedTransaction.Equals(transaction)) @@ -733,14 +650,11 @@ public override void EnlistTransaction(Transaction? transaction) } CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Transaction); + var connector = Connector!; EnlistedTransaction = transaction; if (transaction == null) - { - EndBindingScope(ConnectorBindingScope.Transaction); return; - } // Until #1378 is implemented, we have no recovery, and so no need to enlist as a durable resource manager // (or as promotable single phase). @@ -754,7 +668,7 @@ public override void EnlistTransaction(Transaction? 
transaction) EnlistedTransaction = transaction; LogMessages.EnlistedVolatileResourceManager( - Connector!.LoggingConfiguration.TransactionLogger, + connector.LoggingConfiguration.TransactionLogger, transaction.TransactionInformation.LocalIdentifier, connector.Id); } @@ -807,28 +721,12 @@ internal Task Close(bool async) throw new ArgumentOutOfRangeException("Unknown connection state: " + FullState); } - // TODO: The following shouldn't exist - we need to flow down the regular path to close any - // open reader / COPY. See test CloseDuringRead with multiplexing. - if (Settings.Multiplexing && ConnectorBindingScope == ConnectorBindingScope.None) - { - // TODO: Consider falling through to the regular reset logic. This adds some unneeded conditions - // and assignment but actual perf impact should be negligible (measure). - Debug.Assert(Connector == null); - ReleaseCloseLock(); - - FullState = ConnectionState.Closed; - LogMessages.ClosedMultiplexingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); - - return Task.CompletedTask; - } - return CloseAsync(async); } async Task CloseAsync(bool async) { Debug.Assert(Connector != null); - Debug.Assert(ConnectorBindingScope != ConnectorBindingScope.None); try { @@ -839,16 +737,6 @@ async Task CloseAsync(bool async) { // This method could re-enter connection.Close() due to an underlying connection failure. 
await connector.CloseOngoingOperations(async).ConfigureAwait(false); - - if (ConnectorBindingScope == ConnectorBindingScope.None) - { - Debug.Assert(Settings.Multiplexing); - Debug.Assert(Connector is null); - - FullState = ConnectionState.Closed; - LogMessages.ClosedMultiplexingConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString); - return; - } } Debug.Assert(connector.IsReady || connector.IsBroken, $"Connector is not ready or broken during close, it's {connector.State}"); @@ -857,14 +745,13 @@ async Task CloseAsync(bool async) if (EnlistedTransaction != null) { - // A System.Transactions transaction is still in progress - - connector.Connection = null; - - // Close the connection and disconnect it from the resource manager but leave the + // A System.Transactions transaction is still in progress. + // Close the connection and disconnect it from the resource manager and reset the connector, but leave the // connector in an enlisted pending list in the data source. If another connection is opened within // the same transaction scope, we will reuse this connector to avoid escalating to a distributed - // transaction + // transaction. 
+ connector.ResetWithinEnlistedTransaction(); + connector.Connection = null; _dataSource?.AddPendingEnlistedConnector(connector, EnlistedTransaction); EnlistedTransaction = null; @@ -874,7 +761,6 @@ async Task CloseAsync(bool async) if (Settings.Pooling) { // Clear the buffer, roll back any pending transaction and prepend a reset message if needed - // Also returns the connector to the pool, if there is an open transaction and multiplexing is on // Note that we're doing this only for pooled connections await connector.Reset(async).ConfigureAwait(false); } @@ -885,23 +771,12 @@ async Task CloseAsync(bool async) connector.Transaction?.UnbindIfNecessary(); } - if (Settings.Multiplexing) - { - // We've already closed ongoing operations rolled back any transaction and the connector is already in the pool, - // so we must be unbound. Nothing to do. - Debug.Assert(ConnectorBindingScope == ConnectorBindingScope.None, - $"When closing a multiplexed connection, the connection was supposed to be unbound, but {nameof(ConnectorBindingScope)} was {ConnectorBindingScope}"); - } - else - { - connector.Connection = null; - connector.Return(); - } + connector.Connection = null; + connector.Return(); } LogMessages.ClosedConnection(_connectionLogger, Settings.Host!, Settings.Port, Settings.Database!, _userFacingConnectionString, connector.Id); Connector = null; - ConnectorBindingScope = ConnectorBindingScope.None; FullState = ConnectionState.Closed; } finally @@ -999,17 +874,50 @@ internal void OnNotification(NpgsqlNotificationEventArgs e) /// /// Returns whether SSL is being used for the connection. /// - internal bool IsSecure => CheckOpenAndRunInTemporaryScope(c => c.IsSecure); + internal bool IsSslEncrypted + { + get + { + CheckOpen(); + return Connector!.IsSslEncrypted; + } + } + + /// + /// Returns whether GSS encryption is being used for the connection. 
+ /// + internal bool IsGssEncrypted + { + get + { + CheckOpen(); + return Connector!.IsGssEncrypted; + } + } /// /// Returns whether SCRAM-SHA256 is being user for the connection /// - internal bool IsScram => CheckOpenAndRunInTemporaryScope(c => c.IsScram); + internal bool IsScram + { + get + { + CheckOpen(); + return Connector!.IsScram; + } + } /// /// Returns whether SCRAM-SHA256-PLUS is being user for the connection /// - internal bool IsScramPlus => CheckOpenAndRunInTemporaryScope(c => c.IsScramPlus); + internal bool IsScramPlus + { + get + { + CheckOpen(); + return Connector!.IsScramPlus; + } + } /// /// Selects the local Secure Sockets Layer (SSL) certificate used for authentication. @@ -1017,6 +925,7 @@ internal void OnNotification(NpgsqlNotificationEventArgs e) /// /// See /// + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public ProvideClientCertificatesCallback? ProvideClientCertificatesCallback { get; set; } /// @@ -1032,8 +941,19 @@ internal void OnNotification(NpgsqlNotificationEventArgs e) /// See . /// /// + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; set; } + /// + /// When using SSL/TLS, this is a callback that allows customizing SslStream's authentication options. + /// + /// + /// + /// See . + /// + /// + public Action? SslClientAuthenticationOptionsCallback { get; set; } + #endregion SSL #region Backend version, capabilities, settings @@ -1053,7 +973,14 @@ internal void OnNotification(NpgsqlNotificationEventArgs e) /// /// [Browsable(false)] - public Version PostgreSqlVersion => CheckOpenAndRunInTemporaryScope(c => c.DatabaseInfo.Version); + public Version PostgreSqlVersion + { + get + { + CheckOpen(); + return Connector!.DatabaseInfo.Version; + } + } /// /// The PostgreSQL server version as returned by the server_version option. 
@@ -1061,8 +988,14 @@ internal void OnNotification(NpgsqlNotificationEventArgs e) /// This can only be called when the connection is open. /// /// - public override string ServerVersion => CheckOpenAndRunInTemporaryScope( - c => c.DatabaseInfo.ServerVersion); + public override string ServerVersion + { + get + { + CheckOpen(); + return Connector!.DatabaseInfo.ServerVersion; + } + } /// /// Process id of backend server. @@ -1075,10 +1008,7 @@ public int ProcessID get { CheckOpen(); - - return TryGetBoundConnector(out var connector) - ? connector.BackendProcessId - : throw new InvalidOperationException("No bound physical connection (using multiplexing)"); + return Connector!.BackendProcessId; } } @@ -1088,13 +1018,27 @@ public int ProcessID /// Meant for use by type plugins (e.g. NodaTime) /// [Browsable(false)] - public bool HasIntegerDateTimes => CheckOpenAndRunInTemporaryScope(c => c.DatabaseInfo.HasIntegerDateTimes); + public bool HasIntegerDateTimes + { + get + { + CheckOpen(); + return Connector!.DatabaseInfo.HasIntegerDateTimes; + } + } /// /// The connection's timezone as reported by PostgreSQL, in the IANA/Olson database format. /// [Browsable(false)] - public string Timezone => CheckOpenAndRunInTemporaryScope(c => c.Timezone); + public string Timezone + { + get + { + CheckOpen(); + return Connector!.Timezone; + } + } /// /// Holds all PostgreSQL parameters received for this connection. 
Is updated if the values change @@ -1102,7 +1046,13 @@ public int ProcessID /// [Browsable(false)] public IReadOnlyDictionary PostgresParameters - => CheckOpenAndRunInTemporaryScope(c => c.PostgresParameters); + { + get + { + CheckOpen(); + return Connector!.PostgresParameters; + } + } #endregion Backend version, capabilities, settings @@ -1133,28 +1083,36 @@ public Task BeginBinaryImportAsync(string copyFromCommand, async Task BeginBinaryImport(bool async, string copyFromCommand, CancellationToken cancellationToken = default) { - if (copyFromCommand == null) - throw new ArgumentNullException(nameof(copyFromCommand)); + ArgumentNullException.ThrowIfNull(copyFromCommand); if (!IsValidCopyCommand(copyFromCommand)) throw new ArgumentException("Must contain a COPY FROM STDIN command!", nameof(copyFromCommand)); CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + var connector = Connector!; LogMessages.StartingBinaryImport(connector.LoggingConfiguration.CopyLogger, connector.Id); // no point in passing a cancellationToken here, as we register the cancellation in the Init method connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + var importer = new NpgsqlBinaryImporter(connector); try { - var importer = new NpgsqlBinaryImporter(connector); await importer.Init(copyFromCommand, async, cancellationToken).ConfigureAwait(false); connector.CurrentCopyOperation = importer; return importer; } catch { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); + try + { + if (async) + await importer.DisposeAsync().ConfigureAwait(false); + else + importer.Dispose(); + } + catch + { + // ignored + } throw; } } @@ -1184,28 +1142,36 @@ public Task BeginBinaryExportAsync(string copyToCommand, C async Task BeginBinaryExport(bool async, string copyToCommand, CancellationToken cancellationToken = default) { - if (copyToCommand == null) - throw new ArgumentNullException(nameof(copyToCommand)); + 
ArgumentNullException.ThrowIfNull(copyToCommand); if (!IsValidCopyCommand(copyToCommand)) throw new ArgumentException("Must contain a COPY TO STDOUT command!", nameof(copyToCommand)); CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + var connector = Connector!; LogMessages.StartingBinaryExport(connector.LoggingConfiguration.CopyLogger, connector.Id); // no point in passing a cancellationToken here, as we register the cancellation in the Init method connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + var exporter = new NpgsqlBinaryExporter(connector); try { - var exporter = new NpgsqlBinaryExporter(connector); await exporter.Init(copyToCommand, async, cancellationToken).ConfigureAwait(false); connector.CurrentCopyOperation = exporter; return exporter; } catch { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); + try + { + if (async) + await exporter.DisposeAsync().ConfigureAwait(false); + else + exporter.Dispose(); + } + catch + { + // ignored + } throw; } } @@ -1221,7 +1187,7 @@ async Task BeginBinaryExport(bool async, string copyToComm /// /// See https://www.postgresql.org/docs/current/static/sql-copy.html. /// - public TextWriter BeginTextImport(string copyFromCommand) + public NpgsqlCopyTextWriter BeginTextImport(string copyFromCommand) => BeginTextImport(async: false, copyFromCommand, CancellationToken.None).GetAwaiter().GetResult(); /// @@ -1236,34 +1202,42 @@ public TextWriter BeginTextImport(string copyFromCommand) /// /// See https://www.postgresql.org/docs/current/static/sql-copy.html. 
/// - public Task BeginTextImportAsync(string copyFromCommand, CancellationToken cancellationToken = default) + public Task BeginTextImportAsync(string copyFromCommand, CancellationToken cancellationToken = default) => BeginTextImport(async: true, copyFromCommand, cancellationToken); - async Task BeginTextImport(bool async, string copyFromCommand, CancellationToken cancellationToken = default) + async Task BeginTextImport(bool async, string copyFromCommand, CancellationToken cancellationToken = default) { - if (copyFromCommand == null) - throw new ArgumentNullException(nameof(copyFromCommand)); + ArgumentNullException.ThrowIfNull(copyFromCommand); if (!IsValidCopyCommand(copyFromCommand)) throw new ArgumentException("Must contain a COPY FROM STDIN command!", nameof(copyFromCommand)); CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + var connector = Connector!; LogMessages.StartingTextImport(connector.LoggingConfiguration.CopyLogger, connector.Id); // no point in passing a cancellationToken here, as we register the cancellation in the Init method connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + var copyStream = new NpgsqlRawCopyStream(connector); try { - var copyStream = new NpgsqlRawCopyStream(connector); - await copyStream.Init(copyFromCommand, async, cancellationToken).ConfigureAwait(false); + await copyStream.Init(copyFromCommand, async, forExport: false, cancellationToken).ConfigureAwait(false); var writer = new NpgsqlCopyTextWriter(connector, copyStream); connector.CurrentCopyOperation = writer; return writer; } catch { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); + try + { + if (async) + await copyStream.DisposeAsync().ConfigureAwait(false); + else + copyStream.Dispose(); + } + catch + { + // ignored + } throw; } } @@ -1279,7 +1253,7 @@ async Task BeginTextImport(bool async, string copyFromCommand, Cance /// /// See 
https://www.postgresql.org/docs/current/static/sql-copy.html. /// - public TextReader BeginTextExport(string copyToCommand) + public NpgsqlCopyTextReader BeginTextExport(string copyToCommand) => BeginTextExport(async: false, copyToCommand, CancellationToken.None).GetAwaiter().GetResult(); /// @@ -1294,34 +1268,42 @@ public TextReader BeginTextExport(string copyToCommand) /// /// See https://www.postgresql.org/docs/current/static/sql-copy.html. /// - public Task BeginTextExportAsync(string copyToCommand, CancellationToken cancellationToken = default) + public Task BeginTextExportAsync(string copyToCommand, CancellationToken cancellationToken = default) => BeginTextExport(async: true, copyToCommand, cancellationToken); - async Task BeginTextExport(bool async, string copyToCommand, CancellationToken cancellationToken = default) + async Task BeginTextExport(bool async, string copyToCommand, CancellationToken cancellationToken = default) { - if (copyToCommand == null) - throw new ArgumentNullException(nameof(copyToCommand)); + ArgumentNullException.ThrowIfNull(copyToCommand); if (!IsValidCopyCommand(copyToCommand)) throw new ArgumentException("Must contain a COPY TO STDOUT command!", nameof(copyToCommand)); CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + var connector = Connector!; LogMessages.StartingTextExport(connector.LoggingConfiguration.CopyLogger, connector.Id); // no point in passing a cancellationToken here, as we register the cancellation in the Init method connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + var copyStream = new NpgsqlRawCopyStream(connector); try { - var copyStream = new NpgsqlRawCopyStream(connector); - await copyStream.Init(copyToCommand, async, cancellationToken).ConfigureAwait(false); + await copyStream.Init(copyToCommand, async, forExport: true, cancellationToken).ConfigureAwait(false); var reader = new NpgsqlCopyTextReader(connector, copyStream); connector.CurrentCopyOperation 
= reader; return reader; } catch { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); + try + { + if (async) + await copyStream.DisposeAsync().ConfigureAwait(false); + else + copyStream.Dispose(); + } + catch + { + // ignored + } throw; } } @@ -1357,21 +1339,20 @@ public Task BeginRawBinaryCopyAsync(string copyCommand, Can async Task BeginRawBinaryCopy(bool async, string copyCommand, CancellationToken cancellationToken = default) { - if (copyCommand == null) - throw new ArgumentNullException(nameof(copyCommand)); + ArgumentNullException.ThrowIfNull(copyCommand); if (!IsValidCopyCommand(copyCommand)) throw new ArgumentException("Must contain a COPY TO STDOUT OR COPY FROM STDIN command!", nameof(copyCommand)); CheckReady(); - var connector = StartBindingScope(ConnectorBindingScope.Copy); + var connector = Connector!; LogMessages.StartingRawCopy(connector.LoggingConfiguration.CopyLogger, connector.Id); // no point in passing a cancellationToken here, as we register the cancellation in the Init method connector.StartUserAction(ConnectorState.Copy, attemptPgCancellation: false); + var stream = new NpgsqlRawCopyStream(connector); try { - var stream = new NpgsqlRawCopyStream(connector); - await stream.Init(copyCommand, async, cancellationToken).ConfigureAwait(false); + await stream.Init(copyCommand, async, forExport: null, cancellationToken).ConfigureAwait(false); if (!stream.IsBinary) { // TODO: Stop the COPY operation gracefully, no breaking @@ -1383,8 +1364,17 @@ async Task BeginRawBinaryCopy(bool async, string copyComman } catch { - connector.EndUserAction(); - EndBindingScope(ConnectorBindingScope.Copy); + try + { + if (async) + await stream.DisposeAsync().ConfigureAwait(false); + else + stream.Dispose(); + } + catch + { + // ignored + } throw; } } @@ -1410,8 +1400,6 @@ public bool Wait(int timeout) { if (timeout != -1 && timeout < 0) throw new ArgumentException("Argument must be -1, 0 or positive", nameof(timeout)); - if 
(Settings.Multiplexing) - throw new NotSupportedException($"{nameof(Wait)} isn't supported in multiplexing mode"); CheckReady(); @@ -1453,9 +1441,6 @@ public bool Wait(int timeout) /// true if an asynchronous message was received, false if timed out. public Task WaitAsync(int timeout, CancellationToken cancellationToken = default) { - if (Settings.Multiplexing) - throw new NotSupportedException($"{nameof(Wait)} isn't supported in multiplexing mode"); - CheckReady(); LogMessages.StartingWait(_connectionLogger, timeout, Connector!.Id); @@ -1522,10 +1507,7 @@ void CheckClosed() } void CheckDisposed() - { - if (_disposed) - ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlConnection)); - } + => ObjectDisposedException.ThrowIf(_disposed, this); internal void CheckReady() { @@ -1552,115 +1534,6 @@ internal void CheckReady() #endregion State checks - #region Connector binding - - /// - /// Checks whether the connection is currently bound to a connector, and if so, returns it via - /// . - /// - internal bool TryGetBoundConnector([NotNullWhen(true)] out NpgsqlConnector? connector) - { - if (ConnectorBindingScope == ConnectorBindingScope.None) - { - Debug.Assert(Connector == null, $"Binding scope is None but {Connector} exists"); - connector = null; - return false; - } - Debug.Assert(Connector != null, $"Binding scope is {ConnectorBindingScope} but {Connector} is null"); - Debug.Assert(Connector.Connection == this, $"Bound connector {Connector} does not reference this connection"); - connector = Connector; - return true; - } - - /// - /// Binds this connection to a physical connector. This happens when opening a non-multiplexing connection, - /// or when starting a transaction on a multiplexed connection. - /// - internal ValueTask StartBindingScope( - ConnectorBindingScope scope, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) - { - // If the connection is around bound at a higher scope, we do nothing (e.g. 
copy operation started - // within a transaction on a multiplexing connection). - // Note that if we're in an ambient transaction, that means we're already bound and so we do nothing here. - if (ConnectorBindingScope != ConnectorBindingScope.None) - { - Debug.Assert(Connector != null, $"Connection bound with scope {ConnectorBindingScope} but has no connector"); - Debug.Assert(scope != ConnectorBindingScope, $"Binding scopes aren't reentrant ({ConnectorBindingScope})"); - return new ValueTask(Connector); - } - - return StartBindingScopeAsync(); - - async ValueTask StartBindingScopeAsync() - { - try - { - Debug.Assert(Settings.Multiplexing); - Debug.Assert(_dataSource != null); - - var connector = await _dataSource.Get(this, timeout, async, cancellationToken).ConfigureAwait(false); - Connector = connector; - connector.Connection = this; - ConnectorBindingScope = scope; - return connector; - } - catch - { - FullState = ConnectionState.Broken; - throw; - } - } - } - - internal NpgsqlConnector StartBindingScope(ConnectorBindingScope scope) - => StartBindingScope(scope, NpgsqlTimeout.Infinite, async: false, CancellationToken.None) - .GetAwaiter().GetResult(); - - internal EndScopeDisposable StartTemporaryBindingScope(out NpgsqlConnector connector) - { - connector = StartBindingScope(ConnectorBindingScope.Temporary); - return new EndScopeDisposable(this); - } - - internal T CheckOpenAndRunInTemporaryScope(Func func) - { - CheckOpen(); - - using var _ = StartTemporaryBindingScope(out var connector); - var result = func(connector); - return result; - } - - /// - /// Ends binding scope to the physical connection and returns it to the pool. Only useful with multiplexing on. - /// - /// - /// After this method is called, under no circumstances the physical connection (connector) should ever be used if multiplexing is on. - /// See #3249. 
- /// - internal void EndBindingScope(ConnectorBindingScope scope) - { - Debug.Assert(ConnectorBindingScope != ConnectorBindingScope.None || FullState == ConnectionState.Broken, - $"Ending binding scope {scope} but connection's scope is null"); - - if (scope != ConnectorBindingScope) - return; - - Debug.Assert(Connector != null, $"Ending binding scope {scope} but connector is null"); - Debug.Assert(_dataSource != null, $"Ending binding scope {scope} but _pool is null"); - Debug.Assert(Settings.Multiplexing, $"Ending binding scope {scope} but multiplexing is disabled"); - - // TODO: If enlisted transaction scope is still active, need to AddPendingEnlistedConnector, just like Close - var connector = Connector; - Connector = null; - connector.Connection = null; - connector.Transaction?.UnbindIfNecessary(); - connector.Return(); - ConnectorBindingScope = ConnectorBindingScope.None; - } - - #endregion Connector binding - #region Schema operations /// @@ -1722,9 +1595,7 @@ public override Task GetSchemaAsync(string collectionName, Cancellati /// /// The collection specified. public override Task GetSchemaAsync(string collectionName, string?[]? restrictions, CancellationToken cancellationToken = default) - { - return NpgsqlSchema.GetSchema(async: true, this, collectionName, restrictions, cancellationToken); - } + => NpgsqlSchema.GetSchema(async: true, this, collectionName, restrictions, cancellationToken); #endregion Schema operations @@ -1747,9 +1618,10 @@ object ICloneable.Clone() ? 
_cloningInstantiator!(_connectionString) : _dataSource.CreateConnection(); + conn.SslClientAuthenticationOptionsCallback = SslClientAuthenticationOptionsCallback; +#pragma warning disable CS0618 // Obsolete conn.ProvideClientCertificatesCallback = ProvideClientCertificatesCallback; conn.UserCertificateValidationCallback = UserCertificateValidationCallback; -#pragma warning disable CS0618 // Obsolete conn.ProvidePasswordCallback = ProvidePasswordCallback; #pragma warning restore CS0618 conn._userFacingConnectionString = _userFacingConnectionString; @@ -1773,13 +1645,35 @@ public NpgsqlConnection CloneWith(string connectionString) return new NpgsqlConnection(csb.ToString()) { - ProvideClientCertificatesCallback = - ProvideClientCertificatesCallback ?? - (_dataSource?.ClientCertificatesCallback is { } clientCertificatesCallback - ? (ProvideClientCertificatesCallback)(certs => clientCertificatesCallback(certs)) - : null), - UserCertificateValidationCallback = UserCertificateValidationCallback ?? _dataSource?.UserCertificateValidationCallback, + SslClientAuthenticationOptionsCallback = SslClientAuthenticationOptionsCallback ?? _dataSource?.SslClientAuthenticationOptionsCallback, #pragma warning disable CS0618 // Obsolete + ProvideClientCertificatesCallback = ProvideClientCertificatesCallback, + UserCertificateValidationCallback = UserCertificateValidationCallback, + ProvidePasswordCallback = ProvidePasswordCallback, +#pragma warning restore CS0618 + }; + } + + /// + /// Clones this connection, replacing its connection string with the given one. + /// This allows creating a new connection with the same security information + /// (password, SSL callbacks) while changing other connection parameters (e.g. + /// database or pooling) + /// + public async ValueTask CloneWithAsync(string connectionString, CancellationToken cancellationToken = default) + { + CheckDisposed(); + var csb = new NpgsqlConnectionStringBuilder(connectionString); + csb.Password ??= _dataSource is null ? 
null : await _dataSource.GetPassword(async: true, cancellationToken).ConfigureAwait(false); + if (csb.PersistSecurityInfo && !Settings.PersistSecurityInfo) + csb.PersistSecurityInfo = false; + + return new NpgsqlConnection(csb.ToString()) + { + SslClientAuthenticationOptionsCallback = SslClientAuthenticationOptionsCallback ?? _dataSource?.SslClientAuthenticationOptionsCallback, +#pragma warning disable CS0618 // Obsolete + ProvideClientCertificatesCallback = ProvideClientCertificatesCallback, + UserCertificateValidationCallback = UserCertificateValidationCallback, ProvidePasswordCallback = ProvidePasswordCallback, #pragma warning restore CS0618 }; @@ -1792,8 +1686,7 @@ public NpgsqlConnection CloneWith(string connectionString) /// The name of the database to use in place of the current database. public override void ChangeDatabase(string dbName) { - if (dbName == null) - throw new ArgumentNullException(nameof(dbName)); + ArgumentNullException.ThrowIfNull(dbName); if (string.IsNullOrEmpty(dbName)) throw new ArgumentOutOfRangeException(nameof(dbName), dbName, $"Invalid database name: {dbName}"); @@ -1832,9 +1725,6 @@ public override void ChangeDatabase(string dbName) /// public void UnprepareAll() { - if (Settings.Multiplexing) - throw new NotSupportedException("Explicit preparation not supported with multiplexing"); - CheckReady(); using (Connector!.StartUserAction()) @@ -1849,10 +1739,8 @@ public void ReloadTypes() { CheckReady(); - using var scope = StartTemporaryBindingScope(out var connector); - _dataSource!.Bootstrap( - connector, + Connector!, NpgsqlTimeout.Infinite, forceReload: true, async: false, @@ -1864,18 +1752,16 @@ public void ReloadTypes() /// Flushes the type cache for this connection's connection string and reloads the types for this connection only. /// Type changes will appear for other connections only after they are re-opened from the pool. 
/// - public async Task ReloadTypesAsync() + public async Task ReloadTypesAsync(CancellationToken cancellationToken = default) { CheckReady(); - using var scope = StartTemporaryBindingScope(out var connector); - await _dataSource!.Bootstrap( - connector, + Connector!, NpgsqlTimeout.Infinite, forceReload: true, async: true, - CancellationToken.None).ConfigureAwait(false); + cancellationToken).ConfigureAwait(false); } /// @@ -1897,48 +1783,6 @@ event EventHandler? IComponent.Disposed #endregion Misc } -enum ConnectorBindingScope -{ - /// - /// The connection is currently not bound to a connector. - /// - None, - - /// - /// The connection is bound to its connector for the scope of the entire connection - /// (i.e. non-multiplexed connection). - /// - Connection, - - /// - /// The connection is bound to its connector for the scope of a transaction. - /// - Transaction, - - /// - /// The connection is bound to its connector for the scope of a COPY operation. - /// - Copy, - - /// - /// The connection is bound to its connector for the scope of a single reader. - /// - Reader, - - /// - /// The connection is bound to its connector for an unspecified, temporary scope; the code that initiated - /// the binding is also responsible to unbind it. 
- /// - Temporary -} - -readonly struct EndScopeDisposable : IDisposable -{ - readonly NpgsqlConnection _connection; - public EndScopeDisposable(NpgsqlConnection connection) => _connection = connection; - public void Dispose() => _connection.EndBindingScope(ConnectorBindingScope.Temporary); -} - #region Delegates /// diff --git a/src/Npgsql/NpgsqlConnectionStringBuilder.cs b/src/Npgsql/NpgsqlConnectionStringBuilder.cs index 88f3043fc6..9b79c9f064 100644 --- a/src/Npgsql/NpgsqlConnectionStringBuilder.cs +++ b/src/Npgsql/NpgsqlConnectionStringBuilder.cs @@ -244,8 +244,7 @@ public int Port get => _port; set { - if (value <= 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "Invalid port: " + value); + ArgumentOutOfRangeException.ThrowIfNegativeOrZero(value); _port = value; SetValue(nameof(Port), value); @@ -461,6 +460,44 @@ public SslMode SslMode } SslMode _sslMode; + /// + /// Controls how SSL encryption is negotiated with the server, if SSL is used. + /// + [Category("Security")] + [Description("Controls how SSL encryption is negotiated with the server, if SSL is used.")] + [DisplayName("SSL Negotiation")] + [NpgsqlConnectionStringProperty] + public SslNegotiation SslNegotiation + { + get => UserProvidedSslNegotiation ?? SslNegotiation.Postgres; + set + { + UserProvidedSslNegotiation = value; + SetValue(nameof(SslNegotiation), value); + } + } + + internal SslNegotiation? UserProvidedSslNegotiation { get; private set; } + + /// + /// Controls whether GSS encryption is required, disabled or preferred, depending on server support. + /// + [Category("Security")] + [Description("Controls whether GSS encryption is required, disabled or preferred, depending on server support.")] + [DisplayName("GSS Encryption Mode")] + [NpgsqlConnectionStringProperty] + public GssEncryptionMode GssEncryptionMode + { + get => UserProvidedGssEncMode ?? 
GssEncryptionMode.Prefer; + set + { + UserProvidedGssEncMode = value; + SetValue(nameof(GssEncryptionMode), value); + } + } + + internal GssEncryptionMode? UserProvidedGssEncMode { get; private set; } + /// /// Location of a client certificate to be sent to the server. /// @@ -577,6 +614,7 @@ public string KerberosServiceName [Category("Security")] [Description("The Kerberos realm to be used for authentication.")] [DisplayName("Include Realm")] + [DefaultValue(true)] [NpgsqlConnectionStringProperty] public bool IncludeRealm { @@ -646,6 +684,24 @@ public bool IncludeErrorDetail } bool _includeErrorDetail; + /// + /// When enabled, failed statements are included on . + /// + [Category("Security")] + [Description("When enabled, failed batched commands are included on NpgsqlException.BatchCommand.")] + [DisplayName("Include Failed Batched Command")] + [NpgsqlConnectionStringProperty] + public bool IncludeFailedBatchedCommand + { + get => _includeFailedBatchedCommand; + set + { + _includeFailedBatchedCommand = value; + SetValue(nameof(IncludeFailedBatchedCommand), value); + } + } + bool _includeFailedBatchedCommand; + /// /// Controls whether channel binding is required, disabled or preferred, depending on server support. /// @@ -665,6 +721,70 @@ public ChannelBinding ChannelBinding } ChannelBinding _channelBinding; + /// + /// Controls the available authentication methods. + /// + [Category("Security")] + [Description("Controls the available authentication methods.")] + [DisplayName("Require Auth")] + [NpgsqlConnectionStringProperty] + public string? RequireAuth + { + get => _requireAuth; + set + { + RequireAuthModes = ParseAuthMode(value); + _requireAuth = value; + SetValue(nameof(RequireAuth), value); + } + } + string? _requireAuth; + + internal RequireAuthMode RequireAuthModes { get; private set; } + + internal static RequireAuthMode ParseAuthMode(string? 
value) + { + var modes = value?.Split(',', StringSplitOptions.TrimEntries | StringSplitOptions.RemoveEmptyEntries); + if (modes is not { Length: > 0 }) + return RequireAuthMode.All; + + var isNegative = false; + RequireAuthMode parsedModes = default; + for (var i = 0; i < modes.Length; i++) + { + var mode = modes[i]; + var modeToParse = mode.AsSpan(); + if (mode.StartsWith('!')) + { + if (i > 0 && !isNegative) + throw new ArgumentException("Mixing both positive and negative authentication methods is not supported"); + + modeToParse = modeToParse.Slice(1); + isNegative = true; + } + else + { + if (i > 0 && isNegative) + throw new ArgumentException("Mixing both positive and negative authentication methods is not supported"); + } + + // Explicitly disallow 'All' as libpq doesn't have it + if (!Enum.TryParse(modeToParse, out var parsedMode) || parsedMode == RequireAuthMode.All) + throw new ArgumentException($"Unable to parse authentication method \"{modeToParse}\""); + + parsedModes |= parsedMode; + } + + var allowedModes = isNegative + ? (RequireAuthMode)(RequireAuthMode.All - parsedModes) + : parsedModes; + + if (allowedModes == default) + throw new ArgumentException($"No authentication method is allowed. 
Check \"{nameof(RequireAuth)}\" in connection string."); + + return allowedModes; + } + #endregion #region Properties - Pooling @@ -701,8 +821,7 @@ public int MinPoolSize get => _minPoolSize; set { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "MinPoolSize can't be negative"); + ArgumentOutOfRangeException.ThrowIfNegative(value); _minPoolSize = value; SetValue(nameof(MinPoolSize), value); @@ -723,8 +842,7 @@ public int MaxPoolSize get => _maxPoolSize; set { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "MaxPoolSize can't be negative"); + ArgumentOutOfRangeException.ThrowIfNegative(value); _maxPoolSize = value; SetValue(nameof(MaxPoolSize), value); @@ -777,13 +895,17 @@ public int ConnectionPruningInterval /// /// The total maximum lifetime of connections (in seconds). Connections which have exceeded this value will be /// destroyed instead of returned from the pool. This is useful in clustered configurations to force load - /// balancing between a running server and a server just brought online. + /// balancing between a running server and a server just brought online. It can also be useful to prevent + /// runaway memory growth of connections at the PostgreSQL server side, because in some cases very long lived + /// connections slowly consume more and more memory over time. + /// Defaults to 3600 seconds (1 hour). /// - /// The time (in seconds) to wait, or 0 to to make connections last indefinitely (the default). + /// The time (in seconds) to wait, or 0 to to make connections last indefinitely. 
[Category("Pooling")] [Description("The total maximum lifetime of connections (in seconds).")] [DisplayName("Connection Lifetime")] [NpgsqlConnectionStringProperty("Load Balance Timeout")] + [DefaultValue(3600)] public int ConnectionLifetime { get => _connectionLifetime; @@ -813,8 +935,8 @@ public int Timeout get => _timeout; set { - if (value < 0 || value > NpgsqlConnection.TimeoutLimit) - throw new ArgumentOutOfRangeException(nameof(value), value, "Timeout must be between 0 and " + NpgsqlConnection.TimeoutLimit); + ArgumentOutOfRangeException.ThrowIfNegative(value); + ArgumentOutOfRangeException.ThrowIfGreaterThan(value, NpgsqlConnection.TimeoutLimit); _timeout = value; SetValue(nameof(Timeout), value); @@ -838,8 +960,7 @@ public int CommandTimeout get => _commandTimeout; set { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "CommandTimeout can't be negative"); + ArgumentOutOfRangeException.ThrowIfNegative(value); _commandTimeout = value; SetValue(nameof(CommandTimeout), value); @@ -862,8 +983,7 @@ public int CancellationTimeout get => _cancellationTimeout; set { - if (value < -1) - throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(CancellationTimeout)} can't less than -1"); + ArgumentOutOfRangeException.ThrowIfLessThan(value, -1); _cancellationTimeout = value; SetValue(nameof(CancellationTimeout), value); @@ -900,7 +1020,7 @@ public string? TargetSessionAttributes set { - TargetSessionAttributesParsed = value is null ? null : ParseTargetSessionAttributes(value); + TargetSessionAttributesParsed = value is null ? 
null : ParseTargetSessionAttributes(value.ToLowerInvariant()); SetValue(nameof(TargetSessionAttributes), value); } } @@ -952,8 +1072,7 @@ public int HostRecheckSeconds get => _hostRecheckSeconds; set { - if (value < 0) - throw new ArgumentException($"{HostRecheckSeconds} cannot be negative", nameof(HostRecheckSeconds)); + ArgumentOutOfRangeException.ThrowIfNegative(value); _hostRecheckSeconds = value; SetValue(nameof(HostRecheckSeconds), value); } @@ -977,8 +1096,7 @@ public int KeepAlive get => _keepAlive; set { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "KeepAlive can't be negative"); + ArgumentOutOfRangeException.ThrowIfNegative(value); _keepAlive = value; SetValue(nameof(KeepAlive), value); @@ -1018,8 +1136,7 @@ public int TcpKeepAliveTime get => _tcpKeepAliveTime; set { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "TcpKeepAliveTime can't be negative"); + ArgumentOutOfRangeException.ThrowIfNegative(value); _tcpKeepAliveTime = value; SetValue(nameof(TcpKeepAliveTime), value); @@ -1040,8 +1157,7 @@ public int TcpKeepAliveInterval get => _tcpKeepAliveInterval; set { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, "TcpKeepAliveInterval can't be negative"); + ArgumentOutOfRangeException.ThrowIfNegative(value); _tcpKeepAliveInterval = value; SetValue(nameof(TcpKeepAliveInterval), value); @@ -1137,8 +1253,7 @@ public int MaxAutoPrepare get => _maxAutoPrepare; set { - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(MaxAutoPrepare)} cannot be negative"); + ArgumentOutOfRangeException.ThrowIfNegative(value); _maxAutoPrepare = value; SetValue(nameof(MaxAutoPrepare), value); @@ -1160,8 +1275,7 @@ public int AutoPrepareMinUsages get => _autoPrepareMinUsages; set { - if (value < 1) - throw new ArgumentOutOfRangeException(nameof(value), value, $"{nameof(AutoPrepareMinUsages)} must be 1 or greater"); + 
ArgumentOutOfRangeException.ThrowIfNegativeOrZero(value); _autoPrepareMinUsages = value; SetValue(nameof(AutoPrepareMinUsages), value); @@ -1188,24 +1302,6 @@ public bool NoResetOnClose } bool _noResetOnClose; - /// - /// Load table composite type definitions, and not just free-standing composite types. - /// - [Category("Advanced")] - [Description("Load table composite type definitions, and not just free-standing composite types.")] - [DisplayName("Load Table Composites")] - [NpgsqlConnectionStringProperty] - public bool LoadTableComposites - { - get => _loadTableComposites; - set - { - _loadTableComposites = value; - SetValue(nameof(LoadTableComposites), value); - } - } - bool _loadTableComposites; - /// /// Set the replication mode of the connection /// @@ -1269,51 +1365,26 @@ public ArrayNullabilityMode ArrayNullabilityMode #endregion - #region Multiplexing - - /// - /// Enables multiplexing, which allows more efficient use of connections. - /// - [Category("Multiplexing")] - [Description("Enables multiplexing, which allows more efficient use of connections.")] - [DisplayName("Multiplexing")] - [NpgsqlConnectionStringProperty] - [DefaultValue(false)] - public bool Multiplexing - { - get => _multiplexing; - set - { - _multiplexing = value; - SetValue(nameof(Multiplexing), value); - } - } - bool _multiplexing; + #region Properties - Obsolete /// - /// When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before - /// flushing to the network. + /// Load table composite type definitions, and not just free-standing composite types. 
/// - [Category("Multiplexing")] - [Description("When multiplexing is enabled, determines the maximum number of outgoing bytes to buffer before " + - "flushing to the network.")] - [DisplayName("Write Coalescing Buffer Threshold Bytes")] + [Category("Advanced")] + [Description("Load table composite type definitions, and not just free-standing composite types.")] + [DisplayName("Load Table Composites")] [NpgsqlConnectionStringProperty] - [DefaultValue(1000)] - public int WriteCoalescingBufferThresholdBytes + [Obsolete("Specifying type loading options through the connection string is obsolete, use the DataSource builder instead. See the 9.0 release notes for more information.")] + public bool LoadTableComposites { - get => _writeCoalescingBufferThresholdBytes; + get => _loadTableComposites; set { - _writeCoalescingBufferThresholdBytes = value; - SetValue(nameof(WriteCoalescingBufferThresholdBytes), value); + _loadTableComposites = value; + SetValue(nameof(LoadTableComposites), value); } } - int _writeCoalescingBufferThresholdBytes; - - #endregion - - #region Properties - Compatibility + bool _loadTableComposites; /// /// A compatibility mode for special PostgreSQL server types. @@ -1322,9 +1393,11 @@ public int WriteCoalescingBufferThresholdBytes [Description("A compatibility mode for special PostgreSQL server types.")] [DisplayName("Server Compatibility Mode")] [NpgsqlConnectionStringProperty] + [Obsolete("Specifying type loading options through the connection string is obsolete, use the DataSource builder instead. See the 9.0 release notes for more information.")] public ServerCompatibilityMode ServerCompatibilityMode { - get => _serverCompatibilityMode; + // Physical replication connections don't allow regular queries, so we can't load types from PG + get => ReplicationMode is ReplicationMode.Physical ? 
ServerCompatibilityMode.NoTypeLoading : _serverCompatibilityMode; set { _serverCompatibilityMode = value; @@ -1333,10 +1406,6 @@ public ServerCompatibilityMode ServerCompatibilityMode } ServerCompatibilityMode _serverCompatibilityMode; - #endregion - - #region Properties - Obsolete - /// /// Whether to trust the server certificate without validating it. /// @@ -1386,12 +1455,11 @@ public int InternalCommandTimeout internal void PostProcessAndValidate() { - if (string.IsNullOrWhiteSpace(Host)) - throw new ArgumentException("Host can't be null"); - if (Multiplexing && !Pooling) - throw new ArgumentException("Pooling must be on to use multiplexing"); + ArgumentException.ThrowIfNullOrWhiteSpace(Host); + if (SslNegotiation == SslNegotiation.Direct && SslMode is not SslMode.Require and not SslMode.VerifyCA and not SslMode.VerifyFull) + throw new ArgumentException("SSL Mode has to be Require or higher to be used with direct SSL Negotiation"); - if (!Host.Contains(",")) + if (!Host.Contains(',')) { if (TargetSessionAttributesParsed is not null && TargetSessionAttributesParsed != Npgsql.TargetSessionAttributes.Any) @@ -1560,9 +1628,22 @@ protected override void GetProperties(Hashtable propertyDescriptors) foreach (var value in propertyDescriptors.Values) { var d = (PropertyDescriptor)value; + var isConnectionStringProperty = false; + var isObsolete = false; foreach (var attribute in d.Attributes) - if (attribute is NpgsqlConnectionStringPropertyAttribute or ObsoleteAttribute) - toRemove.Add(d); + { + if (attribute is NpgsqlConnectionStringPropertyAttribute) + { + isConnectionStringProperty = true; + } + else if (attribute is ObsoleteAttribute) + { + isObsolete = true; + } + } + + if (!isConnectionStringProperty || isObsolete) + toRemove.Add(d); } foreach (var o in toRemove) @@ -1590,7 +1671,7 @@ sealed class NpgsqlConnectionStringPropertyAttribute : Attribute /// Creates a . 
/// public NpgsqlConnectionStringPropertyAttribute() - => Synonyms = Array.Empty(); + => Synonyms = []; /// /// Creates a . @@ -1603,26 +1684,6 @@ public NpgsqlConnectionStringPropertyAttribute(params string[] synonyms) #region Enums -/// -/// An option specified in the connection string that activates special compatibility features. -/// -public enum ServerCompatibilityMode -{ - /// - /// No special server compatibility mode is active - /// - None, - /// - /// The server is an Amazon Redshift instance. - /// - Redshift, - /// - /// The server is doesn't support full type loading from the PostgreSQL catalogs, support the basic set - /// of types via information hardcoded inside Npgsql. - /// - NoTypeLoading, -} - /// /// Specifies how to manage SSL. /// @@ -1654,6 +1715,40 @@ public enum SslMode VerifyFull } +/// +/// Specifies how to initialize SSL session. +/// +public enum SslNegotiation +{ + /// + /// Perform PostgreSQL protocol negotiation. + /// + Postgres, + /// + /// Start SSL handshake directly after establishing the TCP/IP connection. + /// + Direct +} + +/// +/// Specifies how to manage GSS encryption. +/// +public enum GssEncryptionMode +{ + /// + /// GSS encryption is disabled. If the server requires GSS encryption, the connection will fail. + /// + Disable, + /// + /// Prefer GSS encrypted connections if the server allows them, but allow connections without GSS encryption. + /// + Prefer, + /// + /// Fail the connection if the server doesn't support GSS encryption. + /// + Require +} + /// /// Specifies how to manage channel binding. /// @@ -1728,4 +1823,40 @@ enum ReplicationMode Logical } +/// +/// Specifies which authentication methods are supported. +/// +[Flags] +enum RequireAuthMode +{ + /// + /// Plaintext password. + /// + Password = 1, + /// + /// MD5 hashed password. + /// + MD5 = 2, + /// + /// Kerberos. + /// + GSS = 4, + /// + /// Windows SSPI. + /// + SSPI = 8, + /// + /// SASL. 
+ /// + ScramSHA256 = 16, + /// + /// No authentication exchange. + /// + None = 32, + /// + /// All authentication methods. For internal use. + /// + All = Password | MD5 | GSS | SSPI | ScramSHA256 | None +} + #endregion diff --git a/src/Npgsql/NpgsqlDataAdapter.cs b/src/Npgsql/NpgsqlDataAdapter.cs index c18773b2d6..f98f4cca61 100644 --- a/src/Npgsql/NpgsqlDataAdapter.cs +++ b/src/Npgsql/NpgsqlDataAdapter.cs @@ -213,18 +213,18 @@ async Task Fill(DataTable dataTable, NpgsqlDataReader dataReader, bool asyn #pragma warning disable 1591 -public class NpgsqlRowUpdatingEventArgs : RowUpdatingEventArgs -{ - public NpgsqlRowUpdatingEventArgs(DataRow dataRow, IDbCommand? command, System.Data.StatementType statementType, - DataTableMapping tableMapping) - : base(dataRow, command, statementType, tableMapping) {} -} - -public class NpgsqlRowUpdatedEventArgs : RowUpdatedEventArgs -{ - public NpgsqlRowUpdatedEventArgs(DataRow dataRow, IDbCommand? command, System.Data.StatementType statementType, - DataTableMapping tableMapping) - : base(dataRow, command, statementType, tableMapping) {} -} +public class NpgsqlRowUpdatingEventArgs( + DataRow dataRow, + IDbCommand? command, + System.Data.StatementType statementType, + DataTableMapping tableMapping) + : RowUpdatingEventArgs(dataRow, command, statementType, tableMapping); + +public class NpgsqlRowUpdatedEventArgs( + DataRow dataRow, + IDbCommand? 
command, + System.Data.StatementType statementType, + DataTableMapping tableMapping) + : RowUpdatedEventArgs(dataRow, command, statementType, tableMapping); #pragma warning restore 1591 diff --git a/src/Npgsql/NpgsqlDataReader.cs b/src/Npgsql/NpgsqlDataReader.cs index 59293e989b..27bc6675c7 100644 --- a/src/Npgsql/NpgsqlDataReader.cs +++ b/src/Npgsql/NpgsqlDataReader.cs @@ -43,8 +43,7 @@ public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator CommandBehavior _behavior; /// - /// In multiplexing, this is as the sending is managed in the write multiplexing loop, - /// and does not need to be awaited by the reader. + /// The task for writing this command's messages. Awaited on reader cleanup. /// Task? _sendTask; @@ -67,7 +66,7 @@ public sealed class NpgsqlDataReader : DbDataReader, IDbColumnSchemaGenerator /// Records, for each column, its starting offset and length in the current row. /// Used only in non-sequential mode. /// - readonly List<(int Offset, int Length)> _columns = new(); + readonly List<(int Offset, int Length)> _columns = []; int _columnsStartPos; /// @@ -160,7 +159,7 @@ internal void Init( /// public override bool Read() { - CheckClosedOrDisposed(); + ThrowIfClosedOrDisposed(); return TryRead()?.Result ?? Read(false).GetAwaiter().GetResult(); } @@ -173,7 +172,7 @@ public override bool Read() /// A task representing the asynchronous operation. public override Task ReadAsync(CancellationToken cancellationToken) { - CheckClosedOrDisposed(); + ThrowIfClosedOrDisposed(); return TryRead() ?? 
Read(async: true, cancellationToken); } @@ -197,7 +196,7 @@ public override Task ReadAsync(CancellationToken cancellationToken) if (_behavior.HasFlag(CommandBehavior.SingleRow) || !_isRowBuffered) return null; - ConsumeRowNonSequential(); + ConsumeBufferedRow(); const int headerSize = sizeof(byte) + sizeof(int); var buffer = Buffer; @@ -307,8 +306,12 @@ static async ValueTask ReadMessageSequential(NpgsqlConnector co /// Advances the reader to the next result when reading the results of a batch of statements. /// /// - public override bool NextResult() => (_isSchemaOnly ? NextResultSchemaOnly(false) : NextResult(false)) - .GetAwaiter().GetResult(); + public override bool NextResult() + { + ThrowIfClosedOrDisposed(); + return (_isSchemaOnly ? NextResultSchemaOnly(false) : NextResult(false)) + .GetAwaiter().GetResult(); + } /// /// This is the asynchronous version of NextResult. @@ -318,9 +321,12 @@ public override bool NextResult() => (_isSchemaOnly ? NextResultSchemaOnly(false /// /// A task representing the asynchronous operation. public override Task NextResultAsync(CancellationToken cancellationToken) - => _isSchemaOnly + { + ThrowIfClosedOrDisposed(); + return _isSchemaOnly ? 
NextResultSchemaOnly(async: true, cancellationToken: cancellationToken) : NextResult(async: true, cancellationToken: cancellationToken); + } /// /// Internal implementation of NextResult @@ -328,8 +334,6 @@ public override Task NextResultAsync(CancellationToken cancellationToken) async Task NextResult(bool async, bool isConsuming = false, CancellationToken cancellationToken = default) { Debug.Assert(!_isSchemaOnly); - CheckClosedOrDisposed(); - if (State is ReaderState.Consumed) return false; @@ -454,15 +458,52 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo continue; } - if (!Command.IsWrappedByBatch && StatementIndex == 0 && Command._parameters?.HasOutputParameters == true) + if ((Command.WrappingBatch is not null || StatementIndex is 0) && Command.InternalBatchCommands[StatementIndex] is { HasOutputParameters: true } command) { - // If output parameters are present and this is the first row of the first resultset, + // If output parameters are present and this is the first row of the resultset, // we must always read it in non-sequential mode because it will be traversed twice (once // here for the parameters, then as a regular row). - msg = await Connector.ReadMessage(async).ConfigureAwait(false); + msg = await Connector.ReadMessage(async, dataRowLoadingMode: DataRowLoadingMode.NonSequential).ConfigureAwait(false); ProcessMessage(msg); if (msg.Code == BackendMessageCode.DataRow) - PopulateOutputParameters(); + { + Debug.Assert(RowDescription != null); + Debug.Assert(State == ReaderState.BeforeResult); + + try + { + // Temporarily set our state to InResult and non-sequential to allow us to read the values, and in any order. 
+ var isSequential = _isSequential; + var currentPosition = Buffer.ReadPosition; + State = ReaderState.InResult; + _isSequential = false; + try + { + command.PopulateOutputParameters(this, _commandLogger); + + // On success we want to revert any row and column state for the user to be able to read the same row again. + if (async) + await PgReader.CommitAsync().ConfigureAwait(false); + else + PgReader.Commit(); + + State = ReaderState.BeforeResult; // Set the state back + Buffer.ReadPosition = currentPosition; // Restore position + _column = -1; + } + finally + { + // To be on the safe side we always revert this CommandBehavior state change, including on failure. + _isSequential = isSequential; + } + } + catch (Exception e) + { + // TODO: ideally we should flow down to global exception filter and consume there + await Consume(async, firstException: e).ConfigureAwait(false); + throw; + } + } } else { @@ -486,7 +527,7 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo } // There are no more queries, we're done. Read the RFQ. - if (_statements.Count is 0 || !(_statements[_statements.Count - 1].AppendErrorBarrier ?? Command.EnableErrorBarriers)) + if (_statements.Count is 0 || !(_statements[^1].AppendErrorBarrier ?? Command.EnableErrorBarriers)) Expect(await Connector.ReadMessage(async).ConfigureAwait(false), Connector); State = ReaderState.Consumed; @@ -500,11 +541,13 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo var statement = _statements[StatementIndex]; // Reference the triggering statement from the exception - postgresException.BatchCommand = statement; + if (Connector.Settings.IncludeFailedBatchedCommand) + postgresException.BatchCommand = statement; // Prevent the command or batch from being recycled (by the connection) when it's disposed. 
This is important since // the exception is very likely to escape the using statement of the command, and by that time some other user may // already be using the recycled instance. + // TODO: we probably should do than even if it's not PostgresException (error from PopulateOutputParameters) Command.IsCacheable = false; // If the schema of a table changes after a statement is prepared on that table, PostgreSQL errors with @@ -515,8 +558,6 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo { preparedStatement.State = PreparedState.Invalidated; Command.ResetPreparation(); - foreach (var s in Command.InternalBatchCommands) - s.ResetPreparation(); } } @@ -535,6 +576,8 @@ async Task NextResult(bool async, bool isConsuming = false, CancellationTo // However, if the command has error barrier, we now have to consume results from the commands after it (unless it's the // last one). // Note that Consume calls NextResult (this method) recursively, the isConsuming flag tells us we're in this mode. + // TODO: We might as well call Consume on every command (even the last one) to make sure we do read every single message until RFQ + // in case we get an exception in the middle of NextResult if ((statement.AppendErrorBarrier ?? Command.EnableErrorBarriers) && StatementIndex < _statements.Count - 1) { if (isConsuming) @@ -592,54 +635,6 @@ async ValueTask ConsumeResultSet(bool async) } } - - void PopulateOutputParameters() - { - // The first row in a stored procedure command that has output parameters needs to be traversed twice - - // once for populating the output parameters and once for the actual result set traversal. So in this - // case we can't be sequential. 
- Debug.Assert(StatementIndex == 0); - Debug.Assert(RowDescription != null); - Debug.Assert(State == ReaderState.BeforeResult); - - var currentPosition = Buffer.ReadPosition; - - // Temporarily set our state to InResult to allow us to read the values - State = ReaderState.InResult; - - var pending = new Queue(); - var taken = new List(); - for (var i = 0; i < FieldCount; i++) - { - if (Command.Parameters.TryGetValue(GetName(i), out var p) && p.IsOutputDirection) - { - p.Value = GetValue(i); - taken.Add(p); - } - else - pending.Enqueue(GetValue(i)); - } - - // Not sure where this odd behavior comes from: all output parameters which did not get matched by - // name now get populated with column values which weren't matched. Keeping this for backwards compat, - // opened #2252 for investigation. - foreach (var p in (IEnumerable)Command.Parameters) - { - if (!p.IsOutputDirection || taken.Contains(p)) - continue; - - if (pending.Count == 0) - break; - p.Value = pending.Dequeue(); - } - - PgReader.Commit(resuming: false); - State = ReaderState.BeforeResult; // Set the state back - Buffer.ReadPosition = currentPosition; // Restore position - - _column = -1; - } - /// /// Note that in SchemaOnly mode there are no resultsets, and we read nothing from the backend (all /// RowDescriptions have already been processed and are available) @@ -647,26 +642,13 @@ void PopulateOutputParameters() async Task NextResultSchemaOnly(bool async, bool isConsuming = false, CancellationToken cancellationToken = default) { Debug.Assert(_isSchemaOnly); + if (State is ReaderState.Consumed) + return false; using var registration = isConsuming ? 
default : Connector.StartNestedCancellableOperation(cancellationToken); try { - switch (State) - { - case ReaderState.BeforeResult: - case ReaderState.InResult: - case ReaderState.BetweenResults: - break; - case ReaderState.Consumed: - case ReaderState.Closed: - case ReaderState.Disposed: - return false; - default: - ThrowHelper.ThrowArgumentOutOfRangeException(); - return false; - } - for (StatementIndex++; StatementIndex < _statements.Count; StatementIndex++) { var statement = _statements[StatementIndex]; @@ -708,7 +690,11 @@ async Task NextResultSchemaOnly(bool async, bool isConsuming = false, Canc break; case BackendMessageCode.RowDescription: // We have a resultset - RowDescription = _statements[StatementIndex].Description = (RowDescriptionMessage)msg; + // RowDescription messages are cached on the connector, but if we're auto-preparing, we need to + // clone our own copy which will last beyond the lifetime of this invocation. + RowDescription = _statements[StatementIndex].Description = preparedStatement == null + ? 
(RowDescriptionMessage)msg + : ((RowDescriptionMessage)msg).Clone(); Command.FixupRowDescription(RowDescription, StatementIndex == 0); break; default: @@ -729,17 +715,7 @@ async Task NextResultSchemaOnly(bool async, bool isConsuming = false, Canc // Found a resultset if (RowDescription is not null) - { - if (ColumnInfoCache?.Length >= ColumnCount) - Array.Clear(ColumnInfoCache, 0, ColumnCount); - else - { - if (ColumnInfoCache is { } cache) - ArrayPool.Shared.Return(cache, clearArray: true); - ColumnInfoCache = ArrayPool.Shared.Rent(ColumnCount); - } return true; - } } State = ReaderState.Consumed; @@ -755,7 +731,9 @@ async Task NextResultSchemaOnly(bool async, bool isConsuming = false, Canc // Reference the triggering statement from the exception if (e is PostgresException postgresException && StatementIndex >= 0 && StatementIndex < _statements.Count) { - postgresException.BatchCommand = _statements[StatementIndex]; + // Reference the triggering statement from the exception + if (Connector.Settings.IncludeFailedBatchedCommand) + postgresException.BatchCommand = _statements[StatementIndex]; // Prevent the command or batch from being recycled (by the connection) when it's disposed. This is important since // the exception is very likely to escape the using statement of the command, and by that time some other user may @@ -876,7 +854,7 @@ void HandleUncommon(IBackendMessage msg) /// /// Gets a value indicating whether the data reader is closed. /// - public override bool IsClosed => State == ReaderState.Closed || State == ReaderState.Disposed; + public override bool IsClosed => State is ReaderState.Closed or ReaderState.Disposed; /// /// Gets the number of rows changed, inserted, or deleted by execution of the SQL statement. @@ -912,18 +890,26 @@ public override int RecordsAffected /// which exposes an aggregation across all statements. 
/// [Obsolete("Use the new DbBatch API")] - public IReadOnlyList Statements => _statements.AsReadOnly(); + public IReadOnlyList Statements + { + get + { + ThrowIfClosedOrDisposed(); + return _statements.AsReadOnly(); + } + } /// /// Gets a value that indicates whether this DbDataReader contains one or more rows. /// public override bool HasRows - => State switch + { + get { - ReaderState.Closed => throw new InvalidOperationException("Invalid attempt to call HasRows when reader is closed."), - ReaderState.Disposed => throw new ObjectDisposedException(nameof(NpgsqlDataReader)), - _ => _hasRows - }; + ThrowIfClosedOrDisposed(); + return _hasRows; + } + } /// /// Indicates whether the reader is currently positioned on a row, i.e. whether reading a @@ -932,7 +918,14 @@ public override bool HasRows /// return true even if attempting to read a column will fail, e.g. before /// has been called /// - public bool IsOnRow => State == ReaderState.InResult; + public bool IsOnRow + { + get + { + ThrowIfClosedOrDisposed(); + return State is ReaderState.InResult; + } + } /// /// Gets the name of the column, given the zero-based column ordinal. @@ -948,7 +941,7 @@ public override int FieldCount { get { - CheckClosedOrDisposed(); + ThrowIfClosedOrDisposed(); return RowDescription?.Count ?? 0; } } @@ -965,7 +958,16 @@ async Task Consume(bool async, Exception? firstException = null) // Skip over the other result sets. Note that this does tally records affected from CommandComplete messages, and properly sets // state for auto-prepared statements - while (true) + // + // The only exception is when the connector is broken (which can happen in the middle of consuming) + // As then there is no point in going forward. + // An exception to the exception above is when connector is concurrently closed while + // the reader is still going over the result set. + // While this is undefined behavior and user error, we should try to at least do our best to not loop indefinitely. 
+ // + // While we can also check our local state (State == Closed) + // It's probably better to rely on connector since it's private and its state can't be changed + while (Connector.IsConnected) { try { @@ -978,7 +980,7 @@ async Task Consume(bool async, Exception? firstException = null) } catch (Exception e) { - exceptions ??= new(); + exceptions ??= []; exceptions.Add(e); } } @@ -1011,8 +1013,7 @@ protected override void Dispose(bool disposing) catch (Exception ex) { // In the case of a PostgresException (or multiple ones, if we have error barriers), the reader's state has already been set - // to Disposed in Close above; in multiplexing, we also unbind the connector (with its reader), and at that point it can be used - // by other consumers. Therefore, we only set the state fo Disposed if the exception *wasn't* a PostgresException. + // to Disposed in Close above. Therefore, we only set the state to Disposed if the exception *wasn't* a PostgresException. if (!(ex is PostgresException || ex is NpgsqlException { InnerException: AggregateException aggregateException } && AllPostgresExceptions(aggregateException.InnerExceptions))) @@ -1040,8 +1041,7 @@ public override async ValueTask DisposeAsync() catch (Exception ex) { // In the case of a PostgresException (or multiple ones, if we have error barriers), the reader's state has already been set - // to Disposed in Close above; in multiplexing, we also unbind the connector (with its reader), and at that point it can be used - // by other consumers. Therefore, we only set the state to Disposed if the exception *wasn't* a PostgresException. + // to Disposed in Close above. Therefore, we only set the state to Disposed if the exception *wasn't* a PostgresException. 
if (!(ex is PostgresException || ex is NpgsqlException { InnerException: AggregateException aggregateException } && AllPostgresExceptions(aggregateException.InnerExceptions))) @@ -1139,7 +1139,7 @@ internal async Task Cleanup(bool async, bool connectionClosing = false, bool isD { LogMessages.ReaderCleanup(_commandLogger, Connector.Id); - // If multiplexing isn't on, _sendTask contains the task for the writing of this command. + // _sendTask contains the task for the writing of this command. // Make sure that this task, which may have executed asynchronously and in parallel with the reading, // has completed, throwing any exceptions it generated. If we don't do this, there's the possibility of a race condition where the // user executes a new command after reader.Dispose() returns, but some additional write stuff is still finishing up from the last @@ -1186,27 +1186,10 @@ internal async Task Cleanup(bool async, bool connectionClosing = false, bool isD Connector.DataSource.MetricsReporter.ReportCommandStop(_startTimestamp); Connector.EndUserAction(); - // The reader shouldn't be unbound, if we're disposing - so the state is set prematurely if (isDisposing) State = ReaderState.Disposed; - if (_connection?.ConnectorBindingScope == ConnectorBindingScope.Reader) - { - UnbindIfNecessary(); - - // TODO: Refactor... 
Use proper scope - _connection.Connector = null; - Connector.Connection = null; - _connection.ConnectorBindingScope = ConnectorBindingScope.None; - - // If the reader is being closed as part of the connection closing, we don't apply - // the reader's CommandBehavior.CloseConnection - if (_behavior.HasFlag(CommandBehavior.CloseConnection) && !connectionClosing) - _connection.Close(); - - Connector.ReaderCompleted.SetResult(null); - } - else if (_behavior.HasFlag(CommandBehavior.CloseConnection) && !connectionClosing) + if (_behavior.HasFlag(CommandBehavior.CloseConnection) && !connectionClosing) { Debug.Assert(_connection is not null); _connection.Close(); @@ -1314,11 +1297,10 @@ internal async Task Cleanup(bool async, bool connectionClosing = false, bool isD /// The number of instances of in the array. public override int GetValues(object[] values) { - if (values == null) - throw new ArgumentNullException(nameof(values)); - CheckResultSet(); + ThrowIfNotInResult(); + ArgumentNullException.ThrowIfNull(values); - var count = Math.Min(FieldCount, values.Length); + var count = Math.Min(ColumnCount, values.Length); for (var i = 0; i < count; i++) values[i] = GetValue(i); return count; @@ -1360,23 +1342,23 @@ public override int GetValues(object[] values) /// A data reader. public new NpgsqlNestedDataReader GetData(int ordinal) { + ThrowIfNotInResult(); + var field = RowDescription[ordinal]; if (_isSequential) - throw new NotSupportedException("GetData() not supported in sequential mode."); + ThrowHelper.ThrowNotSupportedException("GetData() not supported in sequential mode."); - var field = CheckRowAndGetField(ordinal); var type = field.PostgresType; var isArray = type is PostgresArrayType; var elementType = isArray ? 
((PostgresArrayType)type).Element : type; var compositeType = elementType as PostgresCompositeType; if (field.DataFormat is DataFormat.Text || (elementType.InternalName != "record" && compositeType == null)) - throw new InvalidCastException("GetData() not supported for type " + field.TypeDisplayName); + ThrowHelper.ThrowInvalidCastException("GetData() not supported for type " + field.TypeDisplayName); - var columnLength = SeekToColumn(async: false, ordinal, field.DataFormat, resumableOp: true).GetAwaiter().GetResult(); - if (columnLength is -1) + if (SeekToColumn(ordinal, field.DataFormat, resumableOp: true) is -1) ThrowHelper.ThrowInvalidCastException_NoValue(field); - if (PgReader.FieldOffset > 0) - PgReader.Rewind(PgReader.FieldOffset); + Debug.Assert(!PgReader.NestedInitialized, "Unexpected nested read active, Seek(0) would seek to the start of the nested data."); + PgReader.Seek(0); var reader = CachedFreeNestedDataReader; if (reader != null) @@ -1410,29 +1392,32 @@ public override int GetValues(object[] values) /// The actual number of bytes read. public override long GetBytes(int ordinal, long dataOffset, byte[]? 
buffer, int bufferOffset, int length) { + ThrowIfNotInResult(); + var field = RowDescription[ordinal]; + if (dataOffset is < 0 or > int.MaxValue) - throw new ArgumentOutOfRangeException(nameof(dataOffset), dataOffset, $"dataOffset must be between {0} and {int.MaxValue}"); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(dataOffset), "dataOffset must be between 0 and {0}", int.MaxValue); if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length + 1)) - throw new IndexOutOfRangeException($"bufferOffset must be between 0 and {buffer.Length}"); + ThrowHelper.ThrowIndexOutOfRangeException("bufferOffset must be between 0 and {0}", buffer.Length); if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) - throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length - bufferOffset}"); + ThrowHelper.ThrowIndexOutOfRangeException("bufferOffset must be between 0 and {0}", buffer.Length - bufferOffset); - var field = CheckRowAndGetField(ordinal); - var columnLength = SeekToColumn(async: false, ordinal, field.DataFormat, resumableOp: true).GetAwaiter().GetResult(); - if (columnLength == -1) + if (SeekToColumn(ordinal, field.DataFormat, resumableOp: true) is var columnLength && columnLength is -1) ThrowHelper.ThrowInvalidCastException_NoValue(field); if (buffer is null) return columnLength; - // Move to offset - if (_isSequential && PgReader.FieldOffset > dataOffset) + // Check whether any sequential seek is contractually sound (even though we might be able to satisfy rewinds we make sure we won't). 
+ if (_isSequential && PgReader.IsFieldConsumed((int)dataOffset)) ThrowHelper.ThrowInvalidOperationException("Attempt to read a position in the column which has already been read"); - PgReader.Seek((int)dataOffset); + // Move to offset + Debug.Assert(!PgReader.NestedInitialized, "Unexpected nested read active, Seek(0) would seek to the start of the nested data."); + var remaining = PgReader.Seek((int)dataOffset); // At offset, read into buffer. - length = Math.Min(length, PgReader.FieldRemaining); + length = Math.Min(length, remaining); PgReader.ReadBytes(new Span(buffer, bufferOffset, length)); return length; } @@ -1471,36 +1456,36 @@ public Task GetStreamAsync(int ordinal, CancellationToken cancellationTo /// The actual number of characters read. public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int bufferOffset, int length) { + ThrowIfNotInResult(); + + // Check whether we have a GetChars implementation for this column type. + var field = GetInfo(ordinal, typeof(GetChars), out var converter, out var bufferRequirement, out var asObject); + if (dataOffset is < 0 or > int.MaxValue) - throw new ArgumentOutOfRangeException(nameof(dataOffset), dataOffset, $"dataOffset must be between 0 and {int.MaxValue}"); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(dataOffset), "dataOffset must be between 0 and {0}", int.MaxValue); if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length + 1)) - throw new IndexOutOfRangeException($"bufferOffset must be between 0 and {buffer.Length}"); + ThrowHelper.ThrowIndexOutOfRangeException("bufferOffset must be between 0 and {0}", buffer.Length); if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) - throw new IndexOutOfRangeException($"length must be between 0 and {buffer.Length - bufferOffset}"); + ThrowHelper.ThrowIndexOutOfRangeException("bufferOffset must be between 0 and {0}", buffer.Length - bufferOffset); - // Check whether we can do resumable reads. 
- var field = GetInfo(ordinal, typeof(GetChars), out var converter, out var bufferRequirement, out var asObject); - if (converter is not IResumableRead { Supported: true }) - throw new NotSupportedException("The GetChars method is not supported for this column type"); - - var columnLength = SeekToColumn(async: false, ordinal, field, resumableOp: true).GetAwaiter().GetResult(); - if (columnLength == -1) - ThrowHelper.ThrowInvalidCastException_NoValue(CheckRowAndGetField(ordinal)); + if (SeekToColumn(ordinal, field, resumableOp: true) is -1) + ThrowHelper.ThrowInvalidCastException_NoValue(RowDescription[ordinal]); + var reader = PgReader; dataOffset = buffer is null ? 0 : dataOffset; - PgReader.InitCharsRead(checked((int)dataOffset), - buffer is not null ? new ArraySegment(buffer, bufferOffset, length) : (ArraySegment?)null, - out var previousDataOffset); - - if (_isSequential && previousDataOffset > dataOffset) + if (_isSequential && reader.CharsRead > dataOffset) ThrowHelper.ThrowInvalidOperationException("Attempt to read a position in the column which has already been read"); - PgReader.StartRead(bufferRequirement); + reader.StartCharsRead(checked((int)dataOffset), + buffer is not null ? new ArraySegment(buffer, bufferOffset, length) : (ArraySegment?)null); + + reader.StartRead(bufferRequirement); var result = asObject - ? (GetChars)converter.ReadAsObject(PgReader) - : ((PgConverter)converter).Read(PgReader); - PgReader.AdvanceCharsRead(result.Read); - PgReader.EndRead(); + ? 
(GetChars)converter.ReadAsObject(reader) + : ((PgConverter)converter).Read(reader); + reader.EndRead(); + + reader.EndCharsRead(); return result.Read; } @@ -1550,12 +1535,11 @@ public override Task GetFieldValueAsync(int ordinal, CancellationToken can async ValueTask Core(int ordinal, CancellationToken cancellationToken) { - using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - + ThrowIfNotInResult(); var field = GetInfo(ordinal, typeof(T), out var converter, out var bufferRequirement, out var asObject); - var columnLength = await SeekToColumn(async: true, ordinal, field).ConfigureAwait(false); - if (columnLength is -1) + using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + if (await SeekToColumnAsync(ordinal, field).ConfigureAwait(false) is -1) return DbNullValueOrThrow(ordinal); if (typeof(T) == typeof(TextReader)) @@ -1565,7 +1549,7 @@ async ValueTask Core(int ordinal, CancellationToken cancellationToken) await PgReader.StartReadAsync(bufferRequirement, cancellationToken).ConfigureAwait(false); var result = asObject ? 
(T)await converter.ReadAsObjectAsync(PgReader, cancellationToken).ConfigureAwait(false) - : await Unsafe.As>(converter).ReadAsync(PgReader, cancellationToken).ConfigureAwait(false); + : await converter.UnsafeDowncast().ReadAsync(PgReader, cancellationToken).ConfigureAwait(false); await PgReader.EndReadAsync().ConfigureAwait(false); return result; } @@ -1577,9 +1561,7 @@ async Task GetStream(int ordinal, CancellationToken cancellationToken) var field = GetDefaultInfo(ordinal, out _, out _); PgReader.ThrowIfStreamActive(); - var columnLength = await SeekToColumn(async: true, ordinal, field).ConfigureAwait(false); - - if (columnLength == -1) + if (await SeekToColumnAsync(ordinal, field).ConfigureAwait(false) is -1) return DbNullValueOrThrow(ordinal); return (T)(object)PgReader.GetStream(canSeek: !_isSequential); @@ -1596,6 +1578,8 @@ async Task GetStream(int ordinal, CancellationToken cancellationToken) T GetFieldValueCore(int ordinal) { + ThrowIfNotInResult(); + // The only statically mapped converter, it always exists. if (typeof(T) == typeof(Stream)) return GetStream(ordinal); @@ -1605,18 +1589,14 @@ T GetFieldValueCore(int ordinal) if (typeof(T) == typeof(TextReader)) PgReader.ThrowIfStreamActive(); - var columnLength = - _isSequential - ? SeekToColumnSequential(async: false, ordinal, field).GetAwaiter().GetResult() - : SeekToColumnNonSequential(ordinal, field); - if (columnLength is -1) + if (SeekToColumn(ordinal, field) is -1) return DbNullValueOrThrow(ordinal); Debug.Assert(asObject || converter is PgConverter); PgReader.StartRead(bufferRequirement); var result = asObject ? (T)converter.ReadAsObject(PgReader) - : Unsafe.As>(converter).Read(PgReader); + : converter.UnsafeDowncast().Read(PgReader); PgReader.EndRead(); return result; @@ -1626,12 +1606,7 @@ T GetStream(int ordinal) var field = GetDefaultInfo(ordinal, out _, out _); PgReader.ThrowIfStreamActive(); - var columnLength = - _isSequential - ? 
SeekToColumnSequential(async: false, ordinal, field).GetAwaiter().GetResult() - : SeekToColumnNonSequential(ordinal, field); - - if (columnLength == -1) + if (SeekToColumn(ordinal, field) is -1) return DbNullValueOrThrow(ordinal); return (T)(object)PgReader.GetStream(canSeek: !_isSequential); @@ -1649,12 +1624,9 @@ T GetStream(int ordinal) /// The value of the specified column. public override object GetValue(int ordinal) { + ThrowIfNotInResult(); var field = GetDefaultInfo(ordinal, out var converter, out var bufferRequirement); - var columnLength = - _isSequential - ? SeekToColumnSequential(async: false, ordinal, field).GetAwaiter().GetResult() - : SeekToColumnNonSequential(ordinal, field); - if (columnLength == -1) + if (SeekToColumn(ordinal, field) is -1) return DBNull.Value; PgReader.StartRead(bufferRequirement); @@ -1681,7 +1653,10 @@ public override object GetValue(int ordinal) /// The zero-based column ordinal. /// true if the specified column is equivalent to ; otherwise false. public override bool IsDBNull(int ordinal) - => SeekToColumn(async: false, ordinal, CheckRowAndGetField(ordinal).DataFormat, resumableOp: true).GetAwaiter().GetResult() is -1; + { + ThrowIfNotInResult(); + return SeekToColumn(ordinal, RowDescription[ordinal].DataFormat, resumableOp: true) is -1; + } /// /// An asynchronous version of , which gets a value that indicates whether the column contains non-existent or missing values. 
@@ -1701,8 +1676,9 @@ public override Task IsDBNullAsync(int ordinal, CancellationToken cancella async Task Core(int ordinal, CancellationToken cancellationToken) { + ThrowIfNotInResult(); using var registration = Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - return await SeekToColumn(async: true, ordinal, CheckRowAndGetField(ordinal).DataFormat, resumableOp: true).ConfigureAwait(false) is -1; + return await SeekToColumnAsync(ordinal, RowDescription[ordinal].DataFormat, resumableOp: true).ConfigureAwait(false) is -1; } } @@ -1717,9 +1693,9 @@ async Task Core(int ordinal, CancellationToken cancellationToken) /// The zero-based column ordinal. public override int GetOrdinal(string name) { + ThrowIfClosedOrDisposed(); if (string.IsNullOrEmpty(name)) ThrowHelper.ThrowArgumentException($"{nameof(name)} cannot be empty", nameof(name)); - CheckClosedOrDisposed(); if (RowDescription is null) ThrowHelper.ThrowInvalidOperationException("No resultset is currently being traversed"); return RowDescription.GetFieldIndex(name); @@ -1773,7 +1749,7 @@ public override IEnumerator GetEnumerator() /// /// public ReadOnlyCollection GetColumnSchema() - => GetColumnSchema(async: false).GetAwaiter().GetResult(); + => GetColumnSchema(async: false).GetAwaiter().GetResult(); ReadOnlyCollection IDbColumnSchemaGenerator.GetColumnSchema() { @@ -1790,14 +1766,14 @@ ReadOnlyCollection IDbColumnSchemaGenerator.GetColumnSchema() /// Asynchronously returns schema information for the columns in the current resultset. 
/// /// - public new Task> GetColumnSchemaAsync(CancellationToken cancellationToken = default) - => GetColumnSchema(async: true, cancellationToken); + public override Task> GetColumnSchemaAsync(CancellationToken cancellationToken = default) + => GetColumnSchema(async: true, cancellationToken); - Task> GetColumnSchema(bool async, CancellationToken cancellationToken = default) + Task> GetColumnSchema(bool async, CancellationToken cancellationToken = default) where T : DbColumn => RowDescription == null || ColumnCount == 0 - ? Task.FromResult(new List().AsReadOnly()) + ? Task.FromResult(new List().AsReadOnly()) : new DbColumnSchemaGenerator(_connection!, RowDescription, _behavior.HasFlag(CommandBehavior.KeyInfo)) - .GetColumnSchema(async, cancellationToken); + .GetColumnSchema(async, cancellationToken); #endregion @@ -1855,7 +1831,7 @@ Task> GetColumnSchema(bool async, Cancellatio table.Columns.Add("ProviderSpecificDataType", typeof(Type)); table.Columns.Add("DataTypeName", typeof(string)); - foreach (var column in await GetColumnSchema(async, cancellationToken).ConfigureAwait(false)) + foreach (var column in await GetColumnSchema(async, cancellationToken).ConfigureAwait(false)) { var row = table.NewRow(); @@ -1894,203 +1870,127 @@ Task> GetColumnSchema(bool async, Cancellatio #region Seeking - /// - /// Seeks to the given column. The 4-byte length is read and returned. - /// [MethodImpl(MethodImplOptions.AggressiveInlining)] - ValueTask SeekToColumn(bool async, int ordinal, DataFormat dataFormat, bool resumableOp = false) - => _isSequential - ? 
SeekToColumnSequential(async, ordinal, dataFormat, resumableOp) - : new(SeekToColumnNonSequential(ordinal, dataFormat, resumableOp)); - - int SeekToColumnNonSequential(int ordinal, DataFormat dataFormat, bool resumableOp = false) + int SeekToColumn(int ordinal, DataFormat dataFormat, bool resumableOp = false) { - var currentColumn = _column; - var buffer = Buffer; - var pgReader = PgReader; - - // Deals with current column commit and rereads - int columnLength; - if (currentColumn >= 0) - { - if (currentColumn == ordinal) - return HandleReread(pgReader.Resumable && resumableOp); - pgReader.Commit(resuming: false); - } + Debug.Assert(_isRowBuffered || _isSequential); + var reader = PgReader; + var column = _column; - // Deals with forward movement - Debug.Assert(ordinal != currentColumn); - if (ordinal > currentColumn) - { - // Written as a while to be able to increment _column directly after reading into it. - while (_column < ordinal - 1) - { - columnLength = buffer.ReadInt32(); - _column++; - Debug.Assert(columnLength >= -1); - if (columnLength > 0) - buffer.Skip(columnLength); - } - columnLength = buffer.ReadInt32(); - } - else - columnLength = SeekBackwards(); + // Column rereading rules for sequential mode: + // * We never allow rereading if the column didn't get initialized as resumable the previous time + // * If it did get initialized as resumable we only allow rereading when either of the following is true: + // - The op is a resumable one again + // - The op isn't resumable but the field is still entirely unconsumed + if (_isSequential && (column > ordinal || (column == ordinal && (!reader.Resumable || (!resumableOp && !reader.FieldAtStart))))) + ThrowInvalidSequentialSeek(column, ordinal); - pgReader.Init(columnLength, dataFormat, resumableOp); - _column = ordinal; + if (column == ordinal) + return reader.Restart(resumableOp); + reader.Commit(); + var columnLength = BufferSeekToColumn(column, ordinal, !_isRowBuffered); + reader.Init(columnLength, 
dataFormat, resumableOp); return columnLength; - int HandleReread(bool resuming) - { - Debug.Assert(pgReader.Initialized); - var columnLength = pgReader.FieldSize; - pgReader.Commit(resuming); - if (!resuming && columnLength > 0) - buffer.ReadPosition -= columnLength; - pgReader.Init(columnLength, dataFormat, resumableOp); - return columnLength; - } + static void ThrowInvalidSequentialSeek(int column, int ordinal) + => ThrowHelper.ThrowInvalidOperationException( + $"Invalid attempt to read from column ordinal '{ordinal}'. With CommandBehavior.SequentialAccess, " + + $"you may only read from column ordinal '{column}' or greater."); + } - // On the first call to SeekBackwards we'll fill up the columns list as we may need seek positions more than once. - [MethodImpl(MethodImplOptions.NoInlining)] - int SeekBackwards() + ValueTask SeekToColumnAsync(int ordinal, DataFormat dataFormat, bool resumableOp = false) + { + // When the row is buffered or we're rereading previous data no IO will be done. + if (_isRowBuffered || _column >= ordinal) + return new(SeekToColumn(ordinal, dataFormat, resumableOp)); + + return Core(ordinal, dataFormat, resumableOp); + + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] + async ValueTask Core(int ordinal, DataFormat dataFormat, bool resumableOp) { - // Backfill the first column. 
- if (_columns.Count is 0) - { - buffer.ReadPosition = _columnsStartPos; - var len = buffer.ReadInt32(); - _columns.Add((buffer.ReadPosition, len)); - } - for (var lastColumnRead = _columns.Count; ordinal >= lastColumnRead; lastColumnRead++) - { - (Buffer.ReadPosition, var lastLen) = _columns[lastColumnRead - 1]; - if (lastLen > 0) - buffer.Skip(lastLen); - var len = Buffer.ReadInt32(); - _columns.Add((Buffer.ReadPosition, len)); - } - (Buffer.ReadPosition, var columnLength) = _columns[ordinal]; + Debug.Assert(!_isRowBuffered && _column < ordinal); + + var reader = PgReader; + await reader.CommitAsync().ConfigureAwait(false); + var columnLength = await BufferSeekToColumnAsync(_column, ordinal, !_isRowBuffered).ConfigureAwait(false); + reader.Init(columnLength, dataFormat, resumableOp); return columnLength; } } - ValueTask SeekToColumnSequential(bool async, int ordinal, DataFormat dataFormat, bool resumableOp = false) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + int BufferSeekToColumn(int column, int ordinal, bool allowIO) { - var reread = _column == ordinal; - // Column rereading rules for sequential mode: - // * We never allow rereading if the column didn't get initialized as resumable the previous time - // * If it did get initialized as resumable we only allow rereading when either of the following is true: - // - The op is a resumable one again - // - The op isn't resumable but the field is still entirely unconsumed - if (ordinal < _column || (reread && (!PgReader.Resumable || (!resumableOp && !PgReader.IsAtStart)))) - ThrowHelper.ThrowInvalidOperationException( - $"Invalid attempt to read from column ordinal '{ordinal}'. 
With CommandBehavior.SequentialAccess, " + - $"you may only read from column ordinal '{_column}' or greater."); + Debug.Assert(column < ordinal || !allowIO); - var committed = false; - if (!PgReader.CommitHasIO(reread)) + if (column >= ordinal) { - var columnLength = PgReader.FieldSize; - PgReader.Commit(reread); - committed = true; - if (reread) - { - PgReader.Init(columnLength, dataFormat, columnLength is -1 || resumableOp); - return new(columnLength); - } + _column = ordinal; + return SeekBackwards(ordinal); + } - if (TrySeekBuffered(ordinal, out columnLength)) - { - PgReader.Init(columnLength, dataFormat, columnLength is -1 || resumableOp); - return new(columnLength); - } + // We know we need at least one iteration, a do while also helps with optimal codegen. + var buffer = Buffer; + var columnLength = 0; + do + { + if (columnLength > 0) + buffer.Skip(columnLength, allowIO); - // If we couldn't consume the column TrySeekBuffered had to stop at, do so now. - if (columnLength > -1) - { - // Resumable: true causes commit to consume without error. - PgReader.Init(columnLength, dataFormat, resumable: true); - committed = false; - } - } + if (allowIO) + buffer.Ensure(sizeof(int)); + columnLength = buffer.ReadInt32(); + Debug.Assert(columnLength >= -1); + } while (++_column < ordinal); - return Core(async, reread, !committed, ordinal, dataFormat, resumableOp); + return columnLength; - [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] - async ValueTask Core(bool async, bool reread, bool commit, int ordinal, DataFormat dataFormat, bool resumableOp) + // On the first call to SeekBackwards we'll fill up the columns list as we may need seek positions more than once. 
+ [MethodImpl(MethodImplOptions.NoInlining)] + int SeekBackwards(int ordinal) { - if (commit) - { - Debug.Assert(ordinal != _column); - if (async) - await PgReader.CommitAsync(reread).ConfigureAwait(false); - else - PgReader.Commit(reread); - } + var buffer = Buffer; + var columns = _columns; - if (reread) - { - PgReader.Init(PgReader.FieldSize, dataFormat, PgReader.FieldSize is -1 || resumableOp); - return PgReader.FieldSize; - } + (buffer.ReadPosition, var columnLength) = columns.Count is 0 + ? (_columnsStartPos, 0) + : columns[Math.Min(columns.Count -1, ordinal)]; - // Seek to the requested column - int columnLength; - var buffer = Buffer; - // Written as a while to be able to increment _column directly after reading into it. - while (_column < ordinal - 1) + while (columns.Count <= ordinal) { - await buffer.Ensure(4, async).ConfigureAwait(false); - columnLength = buffer.ReadInt32(); - _column++; - Debug.Assert(columnLength >= -1); if (columnLength > 0) - await buffer.Skip(columnLength, async).ConfigureAwait(false); + buffer.Skip(columnLength); + columnLength = buffer.ReadInt32(); + columns.Add((buffer.ReadPosition, columnLength)); } - await buffer.Ensure(4, async).ConfigureAwait(false); - columnLength = buffer.ReadInt32(); - _column = ordinal; - - PgReader.Init(columnLength, dataFormat, resumableOp); return columnLength; } + } - bool TrySeekBuffered(int ordinal, out int columnLength) + ValueTask BufferSeekToColumnAsync(int column, int ordinal, bool allowIO) + { + return !allowIO || column >= ordinal ? new(BufferSeekToColumn(column, ordinal, allowIO)) : Core(ordinal); + + [AsyncMethodBuilder(typeof(PoolingAsyncValueTaskMethodBuilder<>))] + async ValueTask Core(int ordinal) { - // Skip over unwanted fields - columnLength = -1; + // We know we need at least one iteration, a do while also helps with optimal codegen. var buffer = Buffer; - // Written as a while to be able to increment _column directly after reading into it. 
- while (_column < ordinal - 1) + var columnLength = 0; + do { - if (buffer.ReadBytesLeft < 4) - { - columnLength = -1; - return false; - } - columnLength = buffer.ReadInt32(); - _column++; - Debug.Assert(columnLength >= -1); if (columnLength > 0) - { - if (buffer.ReadBytesLeft < columnLength) - return false; - buffer.Skip(columnLength); - } - } + await buffer.Skip(async: true, columnLength).ConfigureAwait(false); - if (buffer.ReadBytesLeft < 4) - { - columnLength = -1; - return false; - } + await buffer.EnsureAsync(sizeof(int)).ConfigureAwait(false); + columnLength = buffer.ReadInt32(); + Debug.Assert(columnLength >= -1); + } while (++_column < ordinal); - columnLength = buffer.ReadInt32(); - _column = ordinal; - return true; + return columnLength; } } @@ -2106,15 +2006,15 @@ Task ConsumeRow(bool async) return ConsumeRowSequential(async); // We get here, if we're in a non-sequential mode (or the row is already in the buffer) - ConsumeRowNonSequential(); + ConsumeBufferedRow(); return Task.CompletedTask; async Task ConsumeRowSequential(bool async) { if (async) - await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); + await PgReader.CommitAsync().ConfigureAwait(false); else - PgReader.Commit(resuming: false); + PgReader.Commit(); // Skip over the remaining columns in the row var buffer = Buffer; @@ -2126,16 +2026,16 @@ async Task ConsumeRowSequential(bool async) _column++; Debug.Assert(columnLength >= -1); if (columnLength > 0) - await buffer.Skip(columnLength, async).ConfigureAwait(false); + await buffer.Skip(async, columnLength).ConfigureAwait(false); } } } [MethodImpl(MethodImplOptions.AggressiveInlining)] - void ConsumeRowNonSequential() + void ConsumeBufferedRow() { Debug.Assert(State is ReaderState.InResult or ReaderState.BeforeResult); - PgReader.Commit(resuming: false); + PgReader.Commit(); Buffer.ReadPosition = _dataMsgEnd; } @@ -2143,25 +2043,6 @@ void ConsumeRowNonSequential() #region Checks - void CheckResultSet() - { - switch (State) - { - 
case ReaderState.BeforeResult: - case ReaderState.InResult: - return; - case ReaderState.Closed: - ThrowHelper.ThrowInvalidOperationException("The reader is closed"); - return; - case ReaderState.Disposed: - ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlDataReader)); - return; - default: - ThrowHelper.ThrowInvalidOperationException("No resultset is currently being traversed"); - return; - } - } - [MethodImpl(MethodImplOptions.NoInlining)] T DbNullValueOrThrow(int ordinal) { @@ -2172,22 +2053,15 @@ T DbNullValueOrThrow(int ordinal) if (typeof(T) == typeof(object)) return (T)(object)DBNull.Value; - ThrowHelper.ThrowInvalidCastException_NoValue(CheckRowAndGetField(ordinal)); + ThrowHelper.ThrowInvalidCastException_NoValue(RowDescription![ordinal]); return default; } [MethodImpl(MethodImplOptions.AggressiveInlining)] DataFormat GetInfo(int ordinal, Type type, out PgConverter converter, out Size bufferRequirement, out bool asObject) { - var state = State; - if (state is not ReaderState.InResult || (uint)ordinal > (uint)ColumnCount) - { - Unsafe.SkipInit(out converter); - Unsafe.SkipInit(out bufferRequirement); - Unsafe.SkipInit(out asObject); - HandleInvalidState(state, ColumnCount); - Debug.Fail("Should never get here"); - } + if ((uint)ordinal > (uint)ColumnCount) + ThrowHelper.ThrowIndexOutOfRangeException("Ordinal must be between 0 and " + (ColumnCount - 1)); ref var info = ref ColumnInfoCache![ordinal]; @@ -2206,7 +2080,7 @@ DataFormat GetInfo(int ordinal, Type type, out PgConverter converter, out Size b [MethodImpl(MethodImplOptions.NoInlining)] DataFormat Slow(ref ColumnInfo info, out PgConverter converter, out Size bufferRequirement, out bool asObject) { - var field = CheckRowAndGetField(ordinal); + var field = RowDescription![ordinal]; field.GetInfo(type, ref info); converter = info.ConverterInfo.Converter; bufferRequirement = info.ConverterInfo.BufferRequirement; @@ -2218,33 +2092,47 @@ DataFormat Slow(ref ColumnInfo info, out PgConverter converter, 
out Size bufferR [MethodImpl(MethodImplOptions.AggressiveInlining)] DataFormat GetDefaultInfo(int ordinal, out PgConverter converter, out Size bufferRequirement) { - var field = CheckRowAndGetField(ordinal); + var field = RowDescription![ordinal]; - converter = field.ObjectOrDefaultInfo.Converter; - bufferRequirement = field.ObjectOrDefaultInfo.BufferRequirement; + converter = field.ObjectInfo.Converter; + bufferRequirement = field.ObjectInfo.BufferRequirement; return field.DataFormat; } - [MethodImpl(MethodImplOptions.AggressiveInlining)] - FieldDescription CheckRowAndGetField(int column) + /// + /// Checks that we have a RowDescription, but not necessary an actual resultset + /// (for operations which work in SchemaOnly mode. + /// + FieldDescription GetField(int ordinal) { - var columns = RowDescription; - var state = State; - if (state is ReaderState.InResult && (uint)column < (uint)columns!.Count) - return columns[column]; + ThrowIfClosedOrDisposed(); + if (RowDescription is { } columns) + return columns[ordinal]; - return HandleInvalidState(state, columns?.Count ?? 
0); + ThrowHelper.ThrowInvalidOperationException("No resultset is currently being traversed"); + return default!; + } + + void ThrowIfClosedOrDisposed() + { + if (State is (ReaderState.Closed or ReaderState.Disposed) and var state) + ThrowInvalidState(state); + } + + [MemberNotNull(nameof(RowDescription))] + void ThrowIfNotInResult() + { + if (State is not ReaderState.InResult and var state) + ThrowInvalidState(state); + + Debug.Assert(RowDescription is not null); } - [DoesNotReturn] [MethodImpl(MethodImplOptions.NoInlining)] - static FieldDescription HandleInvalidState(ReaderState state, int maxColumns) + static void ThrowInvalidState(ReaderState state) { switch (state) { - case ReaderState.InResult: - ThrowColumnOutOfRange(maxColumns); - break; case ReaderState.Closed: ThrowHelper.ThrowInvalidOperationException("The reader is closed"); break; @@ -2252,52 +2140,11 @@ static FieldDescription HandleInvalidState(ReaderState state, int maxColumns) ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlDataReader)); break; default: - ThrowHelper.ThrowInvalidOperationException("No row is available"); - break; - } - return default!; - } - - /// - /// Checks that we have a RowDescription, but not necessary an actual resultset - /// (for operations which work in SchemaOnly mode. 
- /// - FieldDescription GetField(int column) - { - if (RowDescription is null) ThrowHelper.ThrowInvalidOperationException("No resultset is currently being traversed"); - - var columns = RowDescription; - if (column < 0 || column >= columns.Count) - ThrowColumnOutOfRange(columns.Count); - - return columns[column]; - } - - void CheckClosedOrDisposed() - { - if (State is (ReaderState.Closed or ReaderState.Disposed) and var state) - Throw(state); - - [MethodImpl(MethodImplOptions.NoInlining)] - static void Throw(ReaderState state) - { - switch (state) - { - case ReaderState.Closed: - ThrowHelper.ThrowInvalidOperationException("The reader is closed"); - return; - case ReaderState.Disposed: - ThrowHelper.ThrowObjectDisposedException(nameof(NpgsqlDataReader)); - return; - } + break; } } - [DoesNotReturn] - static void ThrowColumnOutOfRange(int maxIndex) => - throw new IndexOutOfRangeException($"Column must be between {0} and {maxIndex - 1}"); - #endregion #region Misc diff --git a/src/Npgsql/NpgsqlDataSource.cs b/src/Npgsql/NpgsqlDataSource.cs index 9415296585..e9311a16c4 100644 --- a/src/Npgsql/NpgsqlDataSource.cs +++ b/src/Npgsql/NpgsqlDataSource.cs @@ -4,7 +4,6 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Net.Security; -using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using System.Transactions; @@ -32,16 +31,26 @@ public abstract class NpgsqlDataSource : DbDataSource internal NpgsqlLoggingConfiguration LoggingConfiguration { get; } readonly PgTypeInfoResolverChain _resolverChain; - internal PgSerializerOptions SerializerOptions { get; private set; } = null!; // Initialized at bootstrapping + readonly IEnumerable _dbTypeResolverFactories; - /// - /// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...). 
- /// - internal NpgsqlDatabaseInfo DatabaseInfo { get; private set; } = null!; // Initialized at bootstrapping + internal ReloadableState CurrentReloadableState = null!; // Initialized during bootstrapping. + + // Initialized at bootstrapping + internal sealed class ReloadableState(NpgsqlDatabaseInfo databaseInfo, PgSerializerOptions serializerOptions, IDbTypeResolver? dbTypeResolver) + { + /// + /// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...). + /// + public NpgsqlDatabaseInfo DatabaseInfo { get; } = databaseInfo; + + public PgSerializerOptions SerializerOptions { get; } = serializerOptions; + + public IDbTypeResolver? DbTypeResolver { get; } = dbTypeResolver; + } internal TransportSecurityHandler TransportSecurityHandler { get; } - internal RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; } - internal Action? ClientCertificatesCallback { get; } + + internal Action? SslClientAuthenticationOptionsCallback { get; } readonly Func? _passwordProvider; readonly Func>? _passwordProviderAsync; @@ -82,42 +91,50 @@ private protected readonly Dictionary> _pendi readonly SemaphoreSlim _setupMappingsSemaphore = new(1); readonly INpgsqlNameTranslator _defaultNameTranslator; + readonly IDisposable? _eventSourceEvents; - internal List? _hackyEnumTypeMappings; - - internal NpgsqlDataSource( - NpgsqlConnectionStringBuilder settings, - NpgsqlDataSourceConfiguration dataSourceConfig) + internal NpgsqlDataSource(NpgsqlConnectionStringBuilder settings, NpgsqlDataSourceConfiguration dataSourceConfig, bool reportMetrics) { - Settings = settings; - ConnectionString = settings.PersistSecurityInfo - ? 
settings.ToString() - : settings.ToStringWithoutPassword(); - Configuration = dataSourceConfig; (var name, LoggingConfiguration, + _, + _, TransportSecurityHandler, IntegratedSecurityHandler, - UserCertificateValidationCallback, - ClientCertificatesCallback, + SslClientAuthenticationOptionsCallback, _passwordProvider, _passwordProviderAsync, _periodicPasswordProvider, _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval, - var resolverChain, - _hackyEnumTypeMappings, + _resolverChain, + _dbTypeResolverFactories, _defaultNameTranslator, ConnectionInitializer, - ConnectionInitializerAsync) + ConnectionInitializerAsync, + _) = dataSourceConfig; _connectionLogger = LoggingConfiguration.ConnectionLogger; Debug.Assert(_passwordProvider is null || _passwordProviderAsync is not null); - _resolverChain = resolverChain; + Settings = settings; + + if (settings.PersistSecurityInfo) + { + ConnectionString = settings.ToString(); + + // The data source name is reported in tracing/metrics, so avoid leaking the password through there. + Name = name ?? settings.ToStringWithoutPassword(); + } + else + { + ConnectionString = settings.ToStringWithoutPassword(); + Name = name ?? ConnectionString; + } + _password = settings.Password; if (_periodicPasswordSuccessRefreshInterval != default) @@ -126,15 +143,28 @@ internal NpgsqlDataSource( _timerPasswordProviderCancellationTokenSource = new(); - // Create the timer, but don't start it; the manual run below will will schedule the first refresh. - _periodicPasswordProviderTimer = new Timer(state => _ = RefreshPassword(), null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan); + // Create the timer, but don't start it; the manual run below will schedule the first refresh. 
+ using (ExecutionContext.SuppressFlow()) // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever + _periodicPasswordProviderTimer = new Timer(state => _ = RefreshPassword(), null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan); // Trigger the first refresh attempt right now, outside the timer; this allows us to capture the Task so it can be observed // in GetPasswordAsync. _passwordRefreshTask = Task.Run(RefreshPassword); } - Name = name ?? ConnectionString; - MetricsReporter = new MetricsReporter(this); + // TODO this needs a rework, but for now we just avoid tracking multi-host data sources directly. + if (reportMetrics) + { + MetricsReporter = new MetricsReporter(this); + if (!NpgsqlEventSource.Log.TryTrackDataSource(Name, this, out _eventSourceEvents)) + _connectionLogger.LogDebug("NpgsqlEventSource could not start tracking a DataSource, " + + "this can happen if more than one data source uses the same connection string."); + } + else + { + // This is not accessed anywhere currently for multi-host data sources. + // Connectors which handle the metrics always access their nonpooling/pooling data source instead. + MetricsReporter = null!; + } } /// @@ -208,6 +238,12 @@ protected override DbBatch CreateDbBatch() public new NpgsqlBatch CreateBatch() => new NpgsqlDataSourceBatch(CreateConnection()); + /// + /// If the data source pools connections, clears any idle connections and flags any busy connections to be closed as soon as they're + /// returned to the pool. + /// + public abstract void Clear(); + /// /// Creates a new for the given . /// @@ -220,6 +256,29 @@ public static NpgsqlDataSource Create(string connectionString) public static NpgsqlDataSource Create(NpgsqlConnectionStringBuilder connectionStringBuilder) => Create(connectionStringBuilder.ToString()); + /// + /// Flushes the type cache for this data source. 
+ /// Type changes will appear for connections only after they are re-opened from the pool. + /// + public void ReloadTypes() + { + using var connection = OpenConnection(); + connection.ReloadTypes(); + } + + /// + /// Flushes the type cache for this data source. + /// Type changes will appear for connections only after they are re-opened from the pool. + /// + public async Task ReloadTypesAsync(CancellationToken cancellationToken = default) + { + var connection = await OpenConnectionAsync(cancellationToken).ConfigureAwait(false); + await using (connection.ConfigureAwait(false)) + { + await connection.ReloadTypesAsync(cancellationToken).ConfigureAwait(false); + } + } + internal async Task Bootstrap( NpgsqlConnector connector, NpgsqlTimeout timeout, @@ -244,28 +303,36 @@ internal async Task Bootstrap( // The type loading below will need to send queries to the database, and that depends on a type mapper being set up (even if its // empty). So we set up a minimal version here, and then later inject the actual DatabaseInfo. 
- connector.SerializerOptions = - new(PostgresMinimalDatabaseInfo.DefaultTypeCatalog) + connector.ReloadableState = new( + databaseInfo: PostgresMinimalDatabaseInfo.DefaultTypeCatalog, + serializerOptions: new(PostgresMinimalDatabaseInfo.DefaultTypeCatalog) { TextEncoding = connector.TextEncoding, TypeInfoResolver = AdoTypeInfoResolverFactory.Instance.CreateResolver(), - }; + }, + dbTypeResolver: null); NpgsqlDatabaseInfo databaseInfo; using (connector.StartUserAction(ConnectorState.Executing, cancellationToken)) databaseInfo = await NpgsqlDatabaseInfo.Load(connector, timeout, async).ConfigureAwait(false); - connector.DatabaseInfo = DatabaseInfo = databaseInfo; - connector.SerializerOptions = SerializerOptions = - new(databaseInfo, _resolverChain, CreateTimeZoneProvider(connector.Timezone)) - { - ArrayNullabilityMode = Settings.ArrayNullabilityMode, - EnableDateTimeInfinityConversions = !Statics.DisableDateTimeInfinityConversions, - TextEncoding = connector.TextEncoding, - DefaultNameTranslator = _defaultNameTranslator, + var serializerOptions = new PgSerializerOptions(databaseInfo, _resolverChain, CreateTimeZoneProvider(connector.Timezone)) + { + ArrayNullabilityMode = Settings.ArrayNullabilityMode, + EnableDateTimeInfinityConversions = !Statics.DisableDateTimeInfinityConversions, + TextEncoding = connector.TextEncoding, + DefaultNameTranslator = _defaultNameTranslator + }; + + var resolvers = new List(); + foreach (var dbTypeResolverFactory in _dbTypeResolverFactories) + resolvers.Add(dbTypeResolverFactory.CreateDbTypeResolver(databaseInfo)); - }; + connector.ReloadableState = CurrentReloadableState = new ReloadableState( + databaseInfo: databaseInfo, + serializerOptions: serializerOptions, + dbTypeResolver: new ChainDbTypeResolver(resolvers)); IsBootstrapped = true; } @@ -371,8 +438,6 @@ internal abstract ValueTask Get( internal abstract void Return(NpgsqlConnector connector); - internal abstract void Clear(); - internal abstract bool OwnsConnectors { get; } 
#region Database state management @@ -399,7 +464,7 @@ internal DatabaseState UpdateDatabaseState( var databaseStateInfo = _databaseStateInfo; if (!ignoreTimeStamp && timeStamp <= databaseStateInfo.TimeStamp) - return _databaseStateInfo.State; + return databaseStateInfo.State; _databaseStateInfo = new(newState, new NpgsqlTimeout(stateExpiration), timeStamp); @@ -443,7 +508,7 @@ internal virtual bool TryRentEnlistedPending(Transaction transaction, NpgsqlConn connector = null; return false; } - connector = list[list.Count - 1]; + connector = list[^1]; list.RemoveAt(list.Count - 1); if (list.Count == 0) _pendingEnlistedConnectors.Remove(transaction); @@ -473,8 +538,15 @@ protected virtual void DisposeBase() } _periodicPasswordProviderTimer?.Dispose(); - _setupMappingsSemaphore.Dispose(); - MetricsReporter.Dispose(); + if (MetricsReporter is not null) + { + MetricsReporter.Dispose(); + _eventSourceEvents?.Dispose(); + } + + // We do not dispose _setupMappingsSemaphore explicitly, leaving it to finalizer + // Due to possible concurrent access, which might lead to deadlock + // See issue #6115 Clear(); } @@ -501,31 +573,31 @@ protected virtual async ValueTask DisposeAsyncBase() if (_periodicPasswordProviderTimer is not null) await _periodicPasswordProviderTimer.DisposeAsync().ConfigureAwait(false); - _setupMappingsSemaphore.Dispose(); - MetricsReporter.Dispose(); + if (MetricsReporter is not null) + { + MetricsReporter.Dispose(); + _eventSourceEvents?.Dispose(); + } + // We do not dispose _setupMappingsSemaphore explicitly, leaving it to finalizer + // Due to possible concurrent access, which might lead to deadlock + // See issue #6115 // TODO: async Clear, #4499 Clear(); } private protected void CheckDisposed() - { - if (_isDisposed == 1) - ThrowHelper.ThrowObjectDisposedException(GetType().FullName); - } + => ObjectDisposedException.ThrowIf(_isDisposed == 1, this); #endregion - sealed class DatabaseStateInfo + sealed class DatabaseStateInfo(DatabaseState state, 
NpgsqlTimeout timeout, DateTime timeStamp) { - internal readonly DatabaseState State; - internal readonly NpgsqlTimeout Timeout; + internal readonly DatabaseState State = state; + internal readonly NpgsqlTimeout Timeout = timeout; // While the TimeStamp is not strictly required, it does lower the risk of overwriting the current state with an old value - internal readonly DateTime TimeStamp; + internal readonly DateTime TimeStamp = timeStamp; public DatabaseStateInfo() : this(default, default, default) { } - - public DatabaseStateInfo(DatabaseState state, NpgsqlTimeout timeout, DateTime timeStamp) - => (State, Timeout, TimeStamp) = (state, timeout, timeStamp); } } diff --git a/src/Npgsql/NpgsqlDataSourceBatch.cs b/src/Npgsql/NpgsqlDataSourceBatch.cs index fa239ee8e6..c5b44e9ff6 100644 --- a/src/Npgsql/NpgsqlDataSourceBatch.cs +++ b/src/Npgsql/NpgsqlDataSourceBatch.cs @@ -9,7 +9,7 @@ namespace Npgsql; sealed class NpgsqlDataSourceBatch : NpgsqlBatch { internal NpgsqlDataSourceBatch(NpgsqlConnection connection) - : base(new NpgsqlDataSourceCommand(DefaultBatchCommandsSize, connection)) + : base(static (conn, batch) => new NpgsqlDataSourceCommand(batch, DefaultBatchCommandsSize, conn), connection) { } diff --git a/src/Npgsql/NpgsqlDataSourceBuilder.cs b/src/Npgsql/NpgsqlDataSourceBuilder.cs index e304a559cc..156885d04e 100644 --- a/src/Npgsql/NpgsqlDataSourceBuilder.cs +++ b/src/Npgsql/NpgsqlDataSourceBuilder.cs @@ -9,6 +9,7 @@ using Microsoft.Extensions.Logging; using Npgsql.Internal; using Npgsql.Internal.ResolverFactories; +using Npgsql.NameTranslation; using Npgsql.TypeMapping; using NpgsqlTypes; @@ -40,7 +41,7 @@ public INpgsqlNameTranslator DefaultNameTranslator } /// - /// A connection string builder that can be used to configured the connection string on the builder. + /// A connection string builder that can be used to configure the connection string on the builder. 
/// public NpgsqlConnectionStringBuilder ConnectionStringBuilder => _internalBuilder.ConnectionStringBuilder; @@ -50,8 +51,7 @@ public INpgsqlNameTranslator DefaultNameTranslator public string ConnectionString => _internalBuilder.ConnectionString; internal static void ResetGlobalMappings(bool overwrite) - => GlobalTypeMapper.Instance.AddGlobalTypeMappingResolvers(new PgTypeInfoResolverFactory[] - { + => GlobalTypeMapper.Instance.AddGlobalTypeMappingResolvers([ overwrite ? new AdoTypeInfoResolverFactory() : AdoTypeInfoResolverFactory.Instance, new ExtraConversionResolverFactory(), new JsonTypeInfoResolverFactory(), @@ -60,7 +60,8 @@ internal static void ResetGlobalMappings(bool overwrite) new NetworkTypeInfoResolverFactory(), new GeometricTypeInfoResolverFactory(), new LTreeTypeInfoResolverFactory(), - }, static () => + new CubeTypeInfoResolverFactory() + ], static () => { var builder = new PgTypeInfoResolverChainBuilder(); builder.EnableRanges(); @@ -88,6 +89,7 @@ public NpgsqlDataSourceBuilder(string? connectionString = null) instance.AppendResolverFactory(new NetworkTypeInfoResolverFactory()); instance.AppendResolverFactory(new GeometricTypeInfoResolverFactory()); instance.AppendResolverFactory(new LTreeTypeInfoResolverFactory()); + instance.AppendResolverFactory(new CubeTypeInfoResolverFactory()); }; _internalBuilder.ConfigureResolverChain = static chain => chain.Add(UnsupportedTypeInfoResolver); _internalBuilder.EnableTransportSecurity(); @@ -121,11 +123,30 @@ public NpgsqlDataSourceBuilder EnableParameterLogging(bool parameterLoggingEnabl return this; } + /// + /// Configures type loading options for the DataSource. + /// + public NpgsqlDataSourceBuilder ConfigureTypeLoading(Action configureAction) + { + _internalBuilder.ConfigureTypeLoading(configureAction); + return this; + } + + /// + /// Configures OpenTelemetry tracing options. + /// + /// The same builder instance so that multiple calls can be chained. 
+ public NpgsqlDataSourceBuilder ConfigureTracing(Action configureAction) + { + _internalBuilder.ConfigureTracing(configureAction); + return this; + } + /// /// Configures the JSON serializer options used when reading and writing all System.Text.Json data. /// /// Options to customize JSON serialization and deserialization. - /// + /// The same builder instance so that multiple calls can be chained. public NpgsqlDataSourceBuilder ConfigureJsonOptions(JsonSerializerOptions serializerOptions) { _internalBuilder.ConfigureJsonOptions(serializerOptions); @@ -194,6 +215,7 @@ public NpgsqlDataSourceBuilder EnableUnmappedTypes() /// /// /// The same builder instance so that multiple calls can be chained. + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public NpgsqlDataSourceBuilder UseUserCertificateValidationCallback(RemoteCertificateValidationCallback userCertificateValidationCallback) { _internalBuilder.UseUserCertificateValidationCallback(userCertificateValidationCallback); @@ -205,6 +227,7 @@ public NpgsqlDataSourceBuilder UseUserCertificateValidationCallback(RemoteCertif /// /// The client certificate to be sent to PostgreSQL when opening a connection. /// The same builder instance so that multiple calls can be chained. + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public NpgsqlDataSourceBuilder UseClientCertificate(X509Certificate? clientCertificate) { _internalBuilder.UseClientCertificate(clientCertificate); @@ -216,12 +239,29 @@ public NpgsqlDataSourceBuilder UseClientCertificate(X509Certificate? clientCerti /// /// The client certificate collection to be sent to PostgreSQL when opening a connection. /// The same builder instance so that multiple calls can be chained. + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public NpgsqlDataSourceBuilder UseClientCertificates(X509CertificateCollection? 
clientCertificates) { _internalBuilder.UseClientCertificates(clientCertificates); return this; } + /// + /// When using SSL/TLS, this is a callback that allows customizing SslStream's authentication options. + /// + /// The callback to customize SslStream's authentication options. + /// + /// + /// See . + /// + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseSslClientAuthenticationOptionsCallback(Action? sslClientAuthenticationOptionsCallback) + { + _internalBuilder.UseSslClientAuthenticationOptionsCallback(sslClientAuthenticationOptionsCallback); + return this; + } + /// /// Specifies a callback to modify the collection of SSL/TLS client certificates which Npgsql will send to PostgreSQL for /// certificate-based authentication. This is an advanced API, consider using or @@ -239,6 +279,7 @@ public NpgsqlDataSourceBuilder UseClientCertificates(X509CertificateCollection? /// /// /// The same builder instance so that multiple calls can be chained. + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public NpgsqlDataSourceBuilder UseClientCertificatesCallback(Action? clientCertificatesCallback) { _internalBuilder.UseClientCertificatesCallback(clientCertificatesCallback); @@ -256,6 +297,17 @@ public NpgsqlDataSourceBuilder UseRootCertificate(X509Certificate2? rootCertific return this; } + /// + /// Sets the that will be used validate SSL certificate, received from the server. + /// + /// The CA certificates. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseRootCertificates(X509Certificate2Collection? rootCertificates) + { + _internalBuilder.UseRootCertificates(rootCertificates); + return this; + } + /// /// Specifies a callback that will be used to validate SSL certificate, received from the server. 
/// @@ -272,6 +324,23 @@ public NpgsqlDataSourceBuilder UseRootCertificateCallback(Func return this; } + /// + /// Specifies a callback that will be used to validate SSL certificate, received from the server. + /// + /// The callback to get CA certificates. + /// The same builder instance so that multiple calls can be chained. + /// + /// This overload, which accepts a callback, is suitable for scenarios where the certificate rotates + /// and might change during the lifetime of the application. + /// When that's not the case, use the overload which directly accepts the certificate. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseRootCertificatesCallback(Func? rootCertificateCallback) + { + _internalBuilder.UseRootCertificatesCallback(rootCertificateCallback); + return this; + } + /// /// Configures a periodic password provider, which is automatically called by the data source at some regular interval. This is the /// recommended way to fetch a rotating access token. @@ -325,19 +394,59 @@ public NpgsqlDataSourceBuilder UsePasswordProvider( return this; } + /// + /// When using Kerberos, this is a callback that allows customizing default settings for Kerberos authentication. + /// + /// The callback containing logic to customize Kerberos authentication settings. + /// + /// + /// See . + /// + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlDataSourceBuilder UseNegotiateOptionsCallback(Action? 
negotiateOptionsCallback) + { + _internalBuilder.UseNegotiateOptionsCallback(negotiateOptionsCallback); + return this; + } + #endregion Authentication #region Type mapping /// + void INpgsqlTypeMapper.AddDbTypeResolverFactory(DbTypeResolverFactory factory) + => ((INpgsqlTypeMapper)_internalBuilder).AddDbTypeResolverFactory(factory); + + /// + [Experimental(NpgsqlDiagnostics.ConvertersExperimental)] public void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) => _internalBuilder.AddTypeInfoResolverFactory(factory); /// void INpgsqlTypeMapper.Reset() => ((INpgsqlTypeMapper)_internalBuilder).Reset(); - /// - public INpgsqlTypeMapper MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + /// + /// Maps a CLR enum to a PostgreSQL enum type. + /// + /// + /// CLR enum labels are mapped by name to PostgreSQL enum labels. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. + /// If there is a discrepancy between the .NET and database labels while an enum is read or written, + /// an exception will be raised. + /// + /// + /// A PostgreSQL type name for the corresponding enum type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + /// The .NET enum type to be mapped + public NpgsqlDataSourceBuilder MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) where TEnum : struct, Enum { _internalBuilder.MapEnum(pgName, nameTranslator); @@ -349,41 +458,100 @@ public void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) where TEnum : struct, Enum => _internalBuilder.UnmapEnum(pgName, nameTranslator); - /// + /// + /// Maps a CLR enum to a PostgreSQL enum type. + /// + /// + /// CLR enum labels are mapped by name to PostgreSQL enum labels. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. + /// If there is a discrepancy between the .NET and database labels while an enum is read or written, + /// an exception will be raised. + /// + /// The .NET enum type to be mapped + /// + /// A PostgreSQL type name for the corresponding enum type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. This may not work when AOT compiling.")] - public INpgsqlTypeMapper MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + public NpgsqlDataSourceBuilder MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - => _internalBuilder.MapEnum(clrType, pgName, nameTranslator); + { + _internalBuilder.MapEnum(clrType, pgName, nameTranslator); + return this; + } /// public bool UnmapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] Type clrType, string? 
pgName = null, INpgsqlNameTranslator? nameTranslator = null) => _internalBuilder.UnmapEnum(clrType, pgName, nameTranslator); - /// + /// + /// Maps a CLR type to a PostgreSQL composite type. + /// + /// + /// CLR fields and properties by string to PostgreSQL names. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your members to manually specify a PostgreSQL name. + /// If there is a discrepancy between the .NET type and database type while a composite is read or written, + /// an exception will be raised. + /// + /// + /// A PostgreSQL type name for the corresponding composite type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + /// The .NET type to be mapped [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] - public INpgsqlTypeMapper MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + public NpgsqlDataSourceBuilder MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) { - _internalBuilder.MapComposite(pgName, nameTranslator); + _internalBuilder.MapComposite(typeof(T), pgName, nameTranslator); return this; } /// [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. 
This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] - public INpgsqlTypeMapper MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + public bool UnmapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + => _internalBuilder.UnmapComposite(typeof(T), pgName, nameTranslator); + + /// + /// Maps a CLR type to a composite type. + /// + /// + /// Maps CLR fields and properties by string to PostgreSQL names. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// If there is a discrepancy between the .NET type and database type while a composite is read or written, + /// an exception will be raised. + /// + /// The .NET type to be mapped. + /// + /// A PostgreSQL type name for the corresponding composite type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + public NpgsqlDataSourceBuilder MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) { _internalBuilder.MapComposite(clrType, pgName, nameTranslator); return this; } - /// - [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] - public bool UnmapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( - string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) - => _internalBuilder.UnmapComposite(pgName, nameTranslator); - /// [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] public bool UnmapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] @@ -455,4 +623,38 @@ INpgsqlTypeMapper INpgsqlTypeMapper.EnableRecordsAsTuples() "The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] INpgsqlTypeMapper INpgsqlTypeMapper.EnableUnmappedTypes() => EnableUnmappedTypes(); + + /// + INpgsqlTypeMapper INpgsqlTypeMapper.MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName, INpgsqlNameTranslator? nameTranslator) + { + _internalBuilder.MapEnum(pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. 
This may not work when AOT compiling.")] + INpgsqlTypeMapper INpgsqlTypeMapper.MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + Type clrType, string? pgName, INpgsqlNameTranslator? nameTranslator) + { + _internalBuilder.MapEnum(clrType, pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + INpgsqlTypeMapper INpgsqlTypeMapper.MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName, INpgsqlNameTranslator? nameTranslator) + { + _internalBuilder.MapComposite(typeof(T), pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + INpgsqlTypeMapper INpgsqlTypeMapper.MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName, INpgsqlNameTranslator? 
nameTranslator) + { + _internalBuilder.MapComposite(clrType, pgName, nameTranslator); + return this; + } } diff --git a/src/Npgsql/NpgsqlDataSourceCommand.cs b/src/Npgsql/NpgsqlDataSourceCommand.cs index 3ff565de66..d293194f66 100644 --- a/src/Npgsql/NpgsqlDataSourceCommand.cs +++ b/src/Npgsql/NpgsqlDataSourceCommand.cs @@ -15,8 +15,8 @@ internal NpgsqlDataSourceCommand(NpgsqlConnection connection) } // For NpgsqlBatch only - internal NpgsqlDataSourceCommand(int batchCommandCapacity, NpgsqlConnection connection) - : base(batchCommandCapacity, connection) + internal NpgsqlDataSourceCommand(NpgsqlBatch batch, int batchCommandCapacity, NpgsqlConnection connection) + : base(batch, batchCommandCapacity, connection) { } diff --git a/src/Npgsql/NpgsqlDataSourceConfiguration.cs b/src/Npgsql/NpgsqlDataSourceConfiguration.cs index ec3e5e4611..f3cdd4b513 100644 --- a/src/Npgsql/NpgsqlDataSourceConfiguration.cs +++ b/src/Npgsql/NpgsqlDataSourceConfiguration.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Net.Security; -using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using Npgsql.Internal; @@ -10,17 +9,19 @@ namespace Npgsql; sealed record NpgsqlDataSourceConfiguration(string? Name, NpgsqlLoggingConfiguration LoggingConfiguration, + NpgsqlTracingOptions TracingOptions, + NpgsqlTypeLoadingOptions TypeLoading, TransportSecurityHandler TransportSecurityHandler, - IntegratedSecurityHandler userCertificateValidationCallback, - RemoteCertificateValidationCallback? UserCertificateValidationCallback, - Action? ClientCertificatesCallback, + IntegratedSecurityHandler IntegratedSecurityHandler, + Action? SslClientAuthenticationOptionsCallback, Func? PasswordProvider, Func>? PasswordProviderAsync, Func>? 
PeriodicPasswordProvider, TimeSpan PeriodicPasswordSuccessRefreshInterval, TimeSpan PeriodicPasswordFailureRefreshInterval, PgTypeInfoResolverChain ResolverChain, - List HackyEnumMappings, + IEnumerable DbTypeResolverFactories, INpgsqlNameTranslator DefaultNameTranslator, Action? ConnectionInitializer, - Func? ConnectionInitializerAsync); + Func? ConnectionInitializerAsync, + Action? NegotiateOptionsCallback); diff --git a/src/Npgsql/NpgsqlDiagnostics.cs b/src/Npgsql/NpgsqlDiagnostics.cs new file mode 100644 index 0000000000..0d9ff5f846 --- /dev/null +++ b/src/Npgsql/NpgsqlDiagnostics.cs @@ -0,0 +1,8 @@ +namespace Npgsql; + +static class NpgsqlDiagnostics +{ + public const string ConvertersExperimental = "NPG9001"; + public const string DatabaseInfoExperimental = "NPG9002"; + public const string DbTypeResolverExperimental = "NPG9003"; +} diff --git a/src/Npgsql/NpgsqlEventId.cs b/src/Npgsql/NpgsqlEventId.cs index cf82ea063d..a0bf0bf30c 100644 --- a/src/Npgsql/NpgsqlEventId.cs +++ b/src/Npgsql/NpgsqlEventId.cs @@ -30,7 +30,7 @@ public static class NpgsqlEventId public const int CaughtUserExceptionInNoticeEventHandler = 1901; public const int CaughtUserExceptionInNotificationEventHandler = 1902; public const int ExceptionWhenClosingPhysicalConnection = 1903; - public const int ExceptionWhenOpeningConnectionForMultiplexing = 1904; + public const int ExceptionWhenOpeningConnectionForMultiplexing = 1904; // Multiplexing has been removed #endregion Connection @@ -48,7 +48,7 @@ public static class NpgsqlEventId public const int DerivingParameters = 2500; - public const int ExceptionWhenWritingMultiplexedCommands = 2600; + public const int ExceptionWhenWritingMultiplexedCommands = 2600; // Multiplexing has been removed #endregion Command diff --git a/src/Npgsql/NpgsqlEventSource.cs b/src/Npgsql/NpgsqlEventSource.cs index d50979bb64..4122bbd8d5 100644 --- a/src/Npgsql/NpgsqlEventSource.cs +++ b/src/Npgsql/NpgsqlEventSource.cs @@ -1,14 +1,19 @@ using System; -using 
System.Collections.Generic; +using System.Collections.Concurrent; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Threading; using System.Diagnostics.Tracing; +using System.Runtime.CompilerServices; namespace Npgsql; sealed class NpgsqlEventSource : EventSource { public static readonly NpgsqlEventSource Log = new(); + // A static to keep the CWT values from making themselves uncollectable if they would have a reference through the + // NpgsqlEventSource instance to the CWT table, which they would if this was an instance field. + static readonly NpgsqlEventSourceDataSources DataSourceEvents = new(Log); const string EventSourceName = "Npgsql"; @@ -25,11 +30,6 @@ sealed class NpgsqlEventSource : EventSource PollingCounter? _preparedCommandsRatioCounter; PollingCounter? _poolsCounter; - readonly object _dataSourcesLock = new(); - readonly Dictionary _dataSources = new(); - - PollingCounter? _multiplexingAverageCommandsPerBatchCounter; - PollingCounter? _multiplexingAverageWriteTimePerBatchCounter; long _bytesWritten; long _bytesRead; @@ -39,10 +39,6 @@ sealed class NpgsqlEventSource : EventSource long _currentCommands; long _failedCommands; - long _multiplexingBatchesSent; - long _multiplexingCommandsSent; - long _multiplexingTicksWritten; - internal NpgsqlEventSource() : base(EventSourceName) {} // NOTE @@ -64,7 +60,7 @@ internal void BytesRead(long bytesRead) Interlocked.Add(ref _bytesRead, bytesRead); } - public void CommandStart(string sql) + internal void CommandStart(string sql) { if (IsEnabled()) { @@ -74,7 +70,7 @@ public void CommandStart(string sql) NpgsqlSqlEventSource.Log.CommandStart(sql); } - public void CommandStop() + internal void CommandStop() { if (IsEnabled()) Interlocked.Decrement(ref _currentCommands); @@ -93,56 +89,14 @@ internal void CommandFailed() Interlocked.Increment(ref _failedCommands); } - internal void DataSourceCreated(NpgsqlDataSource dataSource) - { - lock (_dataSourcesLock) - { - 
_dataSources.Add(dataSource, null); - } - } + internal bool TryTrackDataSource(string name, NpgsqlDataSource dataSource, [NotNullWhen(true)]out IDisposable? untrack) + => DataSourceEvents.TryTrack(name, dataSource, out untrack); - internal void MultiplexingBatchSent(int numCommands, Stopwatch stopwatch) - { - // TODO: CAS loop instead of 3 separate interlocked operations? - if (IsEnabled()) - { - Interlocked.Increment(ref _multiplexingBatchesSent); - Interlocked.Add(ref _multiplexingCommandsSent, numCommands); - Interlocked.Add(ref _multiplexingTicksWritten, stopwatch.ElapsedTicks); - } - } - - double GetDataSourceCount() - { - lock (_dataSourcesLock) - { - return _dataSources.Count; - } - } - - double GetMultiplexingAverageCommandsPerBatch() - { - var batchesSent = Interlocked.Read(ref _multiplexingBatchesSent); - if (batchesSent == 0) - return -1; - - var commandsSent = (double)Interlocked.Read(ref _multiplexingCommandsSent); - return commandsSent / batchesSent; - } - - double GetMultiplexingAverageWriteTimePerBatch() - { - var batchesSent = Interlocked.Read(ref _multiplexingBatchesSent); - if (batchesSent == 0) - return -1; - - var ticksWritten = (double)Interlocked.Read(ref _multiplexingTicksWritten); - return ticksWritten / batchesSent / 1000; - } + double GetDataSourceCount() => DataSourceEvents.GetDataSourceCount(); protected override void OnEventCommand(EventCommandEventArgs command) { - if (command.Command == EventCommand.Enable) + if (command.Command is EventCommand.Enable) { // Comment taken from RuntimeEventSource in CoreCLR // NOTE: These counters will NOT be disposed on disable command because we may be introducing @@ -197,28 +151,93 @@ protected override void OnEventCommand(EventCommandEventArgs command) DisplayName = "Connection Pools" }; - _multiplexingAverageCommandsPerBatchCounter = new PollingCounter("multiplexing-average-commands-per-batch", this, GetMultiplexingAverageCommandsPerBatch) - { - DisplayName = "Average commands per multiplexing 
batch" - }; + DataSourceEvents.EnableAll(); + } + } +} - _multiplexingAverageWriteTimePerBatchCounter = new PollingCounter("multiplexing-average-write-time-per-batch", this, GetMultiplexingAverageWriteTimePerBatch) +// This is a separate class to avoid accidentally making the CWT instance reachable through the value. +// The EventSource is stored in the counters, part of the value, so the EventSource *must not* reference this instance on an instance field. +// This goes for any state captured by the value, which is why the other state has its own object for the value to reference. +// See https://github.com/dotnet/runtime/issues/12255. +sealed class NpgsqlEventSourceDataSources(EventSource eventSource) +{ + readonly ConditionalWeakTable> _dataSources = new(); + readonly StrongBox<(int DataSourceCount, ConcurrentDictionary DataSourceNames)> _nonCwtState = new((0, new())); + + internal double GetDataSourceCount() => _nonCwtState.Value.DataSourceCount; + + internal bool TryTrack(string name, NpgsqlDataSource dataSource, [NotNullWhen(true)]out IDisposable? untrack) + { + untrack = null; + if (!_nonCwtState.Value.DataSourceNames.TryAdd(name, default)) + return false; + + var lazy = new Lazy( + () => new DataSourceEvents(name: name, dataSource, eventSource, _nonCwtState), + LazyThreadSafetyMode.ExecutionAndPublication); + var tracked = _dataSources.TryAdd(dataSource, lazy); + + if (tracked) + { + Interlocked.Increment(ref _nonCwtState.Value.DataSourceCount); + // We must initialize directly when the event source is already enabled. 
+ if (eventSource.IsEnabled()) + untrack = lazy.Value; + else + untrack = new DataSourceEventsDisposable(lazy); + } + + return tracked; + } + + internal void EnableAll() + { + foreach (var dataSourceKv in _dataSources) + { + _ = dataSourceKv.Value.Value; + } + } + + sealed class DataSourceEventsDisposable(Lazy events) : IDisposable + { + public void Dispose() => events.Value.Dispose(); + } + + sealed class DataSourceEvents : IDisposable + { + readonly string _name; + readonly StrongBox<(int Count, ConcurrentDictionary Names)> _state; + readonly PollingCounter _idleConnections; + readonly PollingCounter _busyConnections; + + int _disposed; + + public DataSourceEvents(string name, NpgsqlDataSource dataSource, EventSource eventSource, StrongBox<(int, ConcurrentDictionary)> state) + { + _name = name; + _state = state; + _idleConnections = new($"idle-connections-{name}", eventSource, () => dataSource.Statistics.Idle) { - DisplayName = "Average write time per multiplexing batch", - DisplayUnits = "us" + DisplayName = $"Idle Connections [{name}]" }; - lock (_dataSourcesLock) + _busyConnections = new($"busy-connections-{name}", eventSource, () => dataSource.Statistics.Busy) { - foreach (var dataSource in _dataSources.Keys) - { - if (!_dataSources[dataSource].HasValue) - { - _dataSources[dataSource] = ( - new PollingCounter($"Idle Connections ({dataSource.Settings.ToStringWithoutPassword()}])", this, () => dataSource.Statistics.Idle), - new PollingCounter($"Busy Connections ({dataSource.Settings.ToStringWithoutPassword()}])", this, () => dataSource.Statistics.Busy)); - } - } - } + DisplayName = $"Busy Connections [{name}]" + }; + } + + public void Dispose() + { + if (Interlocked.Exchange(ref _disposed, 1) is 1) + return; + + _idleConnections.Dispose(); + _busyConnections.Dispose(); + + Interlocked.Decrement(ref _state.Value.Count); + var success = _state.Value.Names.TryRemove(_name, out _); + Debug.Assert(success); } } } diff --git a/src/Npgsql/NpgsqlException.cs 
b/src/Npgsql/NpgsqlException.cs index 91eb84adef..9e2dfe9ee0 100644 --- a/src/Npgsql/NpgsqlException.cs +++ b/src/Npgsql/NpgsqlException.cs @@ -46,6 +46,7 @@ public override bool IsTransient => InnerException is IOException or SocketException or TimeoutException or NpgsqlException { IsTransient: true }; /// + /// This property is null unless in connection string is set to true. public new NpgsqlBatchCommand? BatchCommand { get; set; } /// @@ -58,9 +59,7 @@ public override bool IsTransient /// /// The SerializationInfo that holds the serialized object data about the exception being thrown. /// The StreamingContext that contains contextual information about the source or destination. -#if NET8_0_OR_GREATER [Obsolete("This API supports obsolete formatter-based serialization. It should not be called or extended by application code.")] -#endif protected internal NpgsqlException(SerializationInfo info, StreamingContext context) : base(info, context) {} #endregion diff --git a/src/Npgsql/NpgsqlFactory.cs b/src/Npgsql/NpgsqlFactory.cs index 15a1cd431e..d95e645f70 100644 --- a/src/Npgsql/NpgsqlFactory.cs +++ b/src/Npgsql/NpgsqlFactory.cs @@ -66,11 +66,9 @@ public sealed class NpgsqlFactory : DbProviderFactory, IServiceProvider /// public override DbBatchCommand CreateBatchCommand() => new NpgsqlBatchCommand(); -#if NET7_0_OR_GREATER /// public override DbDataSource CreateDataSource(string connectionString) => NpgsqlDataSource.Create(connectionString); -#endif #region IServiceProvider Members diff --git a/src/Npgsql/NpgsqlLargeObjectStream.cs b/src/Npgsql/NpgsqlLargeObjectStream.cs index 2f3c8b19b0..09d90b164a 100644 --- a/src/Npgsql/NpgsqlLargeObjectStream.cs +++ b/src/Npgsql/NpgsqlLargeObjectStream.cs @@ -64,14 +64,11 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel async Task Read(bool async, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) { - if (buffer == null) - throw new 
ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentOutOfRangeException(nameof(offset)); - if (count < 0) - throw new ArgumentOutOfRangeException(nameof(count)); + ArgumentNullException.ThrowIfNull(buffer); + ArgumentOutOfRangeException.ThrowIfNegative(offset); + ArgumentOutOfRangeException.ThrowIfNegative(count); if (buffer.Length - offset < count) - throw new ArgumentException("Invalid offset or count for this buffer"); + ThrowHelper.ThrowArgumentException("Invalid offset or count for this buffer"); CheckDisposed(); @@ -115,14 +112,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati async Task Write(bool async, byte[] buffer, int offset, int count, CancellationToken cancellationToken = default) { - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentOutOfRangeException(nameof(offset)); - if (count < 0) - throw new ArgumentOutOfRangeException(nameof(count)); + ArgumentNullException.ThrowIfNull(buffer); + ArgumentOutOfRangeException.ThrowIfNegative(offset); + ArgumentOutOfRangeException.ThrowIfNegative(count); if (buffer.Length - offset < count) - throw new ArgumentException("Invalid offset or count for this buffer"); + ThrowHelper.ThrowArgumentException("Invalid offset or count for this buffer"); CheckDisposed(); @@ -262,8 +256,7 @@ async Task SetLength(bool async, long value, CancellationToken cancellationToken { cancellationToken.ThrowIfCancellationRequested(); - if (value < 0) - throw new ArgumentOutOfRangeException(nameof(value)); + ArgumentOutOfRangeException.ThrowIfNegative(value); if (!Has64BitSupport && value != (int)value) throw new ArgumentOutOfRangeException(nameof(value), "offset must fit in 32 bits for PostgreSQL versions older than 9.3"); diff --git a/src/Npgsql/NpgsqlMetricsOptions.cs b/src/Npgsql/NpgsqlMetricsOptions.cs new file mode 100644 index 0000000000..b4da63dc7a --- /dev/null +++ b/src/Npgsql/NpgsqlMetricsOptions.cs @@ -0,0 
+1,9 @@ +namespace Npgsql; + +/// +/// Options to configure Npgsql's support for OpenTelemetry metrics. +/// Currently no options are available. +/// +public class NpgsqlMetricsOptions +{ +} diff --git a/src/Npgsql/NpgsqlMultiHostDataSource.cs b/src/Npgsql/NpgsqlMultiHostDataSource.cs index 4b8731e5b6..4e7d63bddb 100644 --- a/src/Npgsql/NpgsqlMultiHostDataSource.cs +++ b/src/Npgsql/NpgsqlMultiHostDataSource.cs @@ -4,7 +4,6 @@ using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Threading; using System.Threading.Tasks; using System.Transactions; @@ -31,14 +30,13 @@ public sealed class NpgsqlMultiHostDataSource : NpgsqlDataSource volatile int _roundRobinIndex = -1; internal NpgsqlMultiHostDataSource(NpgsqlConnectionStringBuilder settings, NpgsqlDataSourceConfiguration dataSourceConfig) - : base(settings, dataSourceConfig) + : base(settings, dataSourceConfig, reportMetrics: false) { var hosts = settings.Host!.Split(','); _pools = new NpgsqlDataSource[hosts.Length]; for (var i = 0; i < hosts.Length; i++) { var poolSettings = settings.Clone(); - Debug.Assert(!poolSettings.Multiplexing); var host = hosts[i].AsSpan().Trim(); if (NpgsqlConnectionStringBuilder.TrySplitHostPort(host, out var newHost, out var newPort)) { @@ -49,11 +47,11 @@ internal NpgsqlMultiHostDataSource(NpgsqlConnectionStringBuilder settings, Npgsq poolSettings.Host = host.ToString(); _pools[i] = settings.Pooling - ? new PoolingDataSource(poolSettings, dataSourceConfig, this) + ? 
new PoolingDataSource(poolSettings, dataSourceConfig) : new UnpooledDataSource(poolSettings, dataSourceConfig); } - var targetSessionAttributeValues = Enum.GetValues().ToArray(); + var targetSessionAttributeValues = Enum.GetValues(); var highestValue = 0; foreach (var value in targetSessionAttributeValues) if ((int)value > highestValue) @@ -219,6 +217,12 @@ static bool IsOnline(DatabaseState state, TargetSessionAttributes preferredType) } } } + catch (OperationCanceledException oce) when (cancellationToken.IsCancellationRequested && oce.CancellationToken == cancellationToken) + { + if (connector is not null) + pool.Return(connector); + throw; + } catch (Exception ex) { exceptions.Add(ex); @@ -363,7 +367,8 @@ internal override bool TryGetIdleConnector([NotNullWhen(true)] out NpgsqlConnect internal override ValueTask OpenNewConnector(NpgsqlConnection conn, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) => throw new NpgsqlException("Npgsql bug: trying to open a new connector from " + nameof(NpgsqlMultiHostDataSource)); - internal override void Clear() + /// + public override void Clear() { foreach (var pool in _pools) pool.Clear(); @@ -455,6 +460,6 @@ bool TryGetValidConnector(List list, TargetSessionAttributes pr static TargetSessionAttributes GetTargetSessionAttributes(NpgsqlConnection connection) => connection.Settings.TargetSessionAttributesParsed ?? (PostgresEnvironment.TargetSessionAttributes is { } s - ? NpgsqlConnectionStringBuilder.ParseTargetSessionAttributes(s) + ? 
NpgsqlConnectionStringBuilder.ParseTargetSessionAttributes(s.ToLowerInvariant()) : TargetSessionAttributes.Any); } diff --git a/src/Npgsql/NpgsqlNestedDataReader.cs b/src/Npgsql/NpgsqlNestedDataReader.cs index 1d499585f8..cda412d1a5 100644 --- a/src/Npgsql/NpgsqlNestedDataReader.cs +++ b/src/Npgsql/NpgsqlNestedDataReader.cs @@ -29,32 +29,23 @@ public sealed class NpgsqlNestedDataReader : DbDataReader int _nextRowBufferPos; ReaderState _readerState; - readonly List _columns = new(); + readonly List _columns = []; long _startPos; DataFormat Format => DataFormat.Binary; - readonly struct ColumnInfo + readonly struct ColumnInfo(PostgresType postgresType, int bufferPos, PgTypeInfo objectOrDefaultTypeInfo, DataFormat format) { - readonly DataFormat _format; - public PostgresType PostgresType { get; } - public int BufferPos { get; } + public PostgresType PostgresType { get; } = postgresType; + public int BufferPos { get; } = bufferPos; public PgConverterInfo LastConverterInfo { get; init; } - public PgTypeInfo ObjectOrDefaultTypeInfo { get; } - public PgConverterInfo GetObjectOrDefaultInfo() => ObjectOrDefaultTypeInfo.Bind(Field, _format); + public PgTypeInfo ObjectOrDefaultTypeInfo { get; } = objectOrDefaultTypeInfo; + public PgConverterInfo GetObjectOrDefaultInfo() => ObjectOrDefaultTypeInfo.Bind(Field, format); Field Field => new("?", ObjectOrDefaultTypeInfo.Options.PortableTypeIds ? 
PostgresType.DataTypeName : (Oid)PostgresType.OID, -1); - public PgConverterInfo Bind(PgTypeInfo typeInfo) => typeInfo.Bind(Field, _format); - - public ColumnInfo(PostgresType postgresType, int bufferPos, PgTypeInfo objectOrDefaultTypeInfo, DataFormat format) - { - _format = format; - PostgresType = postgresType; - BufferPos = bufferPos; - ObjectOrDefaultTypeInfo = objectOrDefaultTypeInfo; - } + public PgConverterInfo Bind(PgTypeInfo typeInfo) => typeInfo.Bind(Field, format); } PgReader PgReader => _outermostReader.Buffer.PgReader; @@ -67,12 +58,12 @@ internal NpgsqlNestedDataReader(NpgsqlDataReader outermostReader, NpgsqlNestedDa _outerNestedReader = outerNestedReader; _depth = depth; _compositeType = compositeType; - _startPos = PgReader.FieldStartPos; + _startPos = PgReader.GetFieldStartPos(this); } internal void Init(PostgresCompositeType? compositeType) { - _startPos = PgReader.FieldStartPos; + _startPos = PgReader.GetFieldStartPos(this); _columns.Clear(); _numRows = 0; _nextRowIndex = 0; @@ -102,13 +93,13 @@ internal void InitArray() if (_numRows > 0) PgReader.ReadInt32(); // Length of first row - _nextRowBufferPos = PgReader.FieldOffset; + _nextRowBufferPos = PgReader.GetFieldOffset(this); } internal void InitSingleRow() { _numRows = 1; - _nextRowBufferPos = PgReader.FieldOffset; + _nextRowBufferPos = PgReader.GetFieldOffset(this); } /// @@ -150,7 +141,7 @@ public override bool HasRows /// public override bool IsClosed => _readerState == ReaderState.Closed || _readerState == ReaderState.Disposed - || _outermostReader.IsClosed || PgReader.FieldStartPos != _startPos; + || _outermostReader.IsClosed || PgReader.GetFieldStartPos(this) != _startPos; /// public override int RecordsAffected => -1; @@ -183,8 +174,8 @@ public override bool IsClosed /// public override long GetBytes(int ordinal, long dataOffset, byte[]? 
buffer, int bufferOffset, int length) { - if (dataOffset is < 0 or > int.MaxValue) - throw new ArgumentOutOfRangeException(nameof(dataOffset), dataOffset, $"dataOffset must be between 0 and {int.MaxValue}"); + ArgumentOutOfRangeException.ThrowIfNegative(dataOffset); + ArgumentOutOfRangeException.ThrowIfGreaterThan(dataOffset, int.MaxValue); if (buffer != null && (bufferOffset < 0 || bufferOffset >= buffer.Length + 1)) throw new IndexOutOfRangeException($"bufferOffset must be between 0 and {buffer.Length}"); if (buffer != null && (length < 0 || length > buffer.Length - bufferOffset)) @@ -312,8 +303,7 @@ public override object GetValue(int ordinal) /// public override int GetValues(object[] values) { - if (values == null) - throw new ArgumentNullException(nameof(values)); + ArgumentNullException.ThrowIfNull(values); CheckOnRow(); var count = Math.Min(FieldCount, values.Length); @@ -353,7 +343,7 @@ public override T GetFieldValue(int ordinal) using var _ = PgReader.BeginNestedRead(columnLength, info.BufferRequirement); return asObject ? (T)info.Converter.ReadAsObject(PgReader)! - : info.GetConverter().Read(PgReader); + : info.Converter.UnsafeDowncast().Read(PgReader); } /// @@ -376,18 +366,20 @@ public override bool Read() for (var i = 0; i < numColumns; i++) { var typeOid = PgReader.ReadUInt32(); - var bufferPos = PgReader.FieldOffset; + var bufferPos = PgReader.GetFieldOffset(this); if (i >= _columns.Count) { var pgType = SerializerOptions.DatabaseInfo.GetPostgresType(typeOid); - _columns.Add(new ColumnInfo(pgType, bufferPos, AdoSerializerHelpers.GetTypeInfoForReading(typeof(object), pgType, SerializerOptions), Format)); + var pgTypeId = SerializerOptions.ToCanonicalTypeId(pgType); + _columns.Add(new ColumnInfo(pgType, bufferPos, AdoSerializerHelpers.GetTypeInfoForReading(typeof(object), pgTypeId, SerializerOptions), Format)); } else { var pgType = _columns[i].PostgresType.OID == typeOid ? 
_columns[i].PostgresType : SerializerOptions.DatabaseInfo.GetPostgresType(typeOid); - _columns[i] = new ColumnInfo(pgType, bufferPos, AdoSerializerHelpers.GetTypeInfoForReading(typeof(object), pgType, SerializerOptions), Format); + var pgTypeId = SerializerOptions.ToCanonicalTypeId(pgType); + _columns[i] = new ColumnInfo(pgType, bufferPos, AdoSerializerHelpers.GetTypeInfoForReading(typeof(object), pgTypeId, SerializerOptions), Format); } var columnLen = PgReader.ReadInt32(); @@ -396,7 +388,7 @@ public override bool Read() } _columns.RemoveRange(numColumns, _columns.Count - numColumns); - _nextRowBufferPos = PgReader.FieldOffset; + _nextRowBufferPos = PgReader.GetFieldOffset(this); _readerState = ReaderState.OnRow; return true; @@ -517,7 +509,7 @@ PgConverterInfo GetOrAddConverterInfo(Type type, ColumnInfo column, int ordinal, } } - var converterInfo = column.Bind(AdoSerializerHelpers.GetTypeInfoForReading(type, column.PostgresType, SerializerOptions)); + var converterInfo = column.Bind(AdoSerializerHelpers.GetTypeInfoForReading(type, SerializerOptions.ToCanonicalTypeId(column.PostgresType), SerializerOptions)); _columns[ordinal] = column with { LastConverterInfo = converterInfo }; asObject = converterInfo.IsBoxingConverter; return converterInfo; diff --git a/src/Npgsql/NpgsqlOperationInProgressException.cs b/src/Npgsql/NpgsqlOperationInProgressException.cs index eb7377afcd..74e7e646ff 100644 --- a/src/Npgsql/NpgsqlOperationInProgressException.cs +++ b/src/Npgsql/NpgsqlOperationInProgressException.cs @@ -16,9 +16,7 @@ public sealed class NpgsqlOperationInProgressException : NpgsqlException /// public NpgsqlOperationInProgressException(NpgsqlCommand command) : base("A command is already in progress: " + command.CommandText) - { - CommandInProgress = command; - } + => CommandInProgress = command; internal NpgsqlOperationInProgressException(ConnectorState state) : base($"The connection is already in state '{state}'") @@ -31,4 +29,4 @@ internal 
NpgsqlOperationInProgressException(ConnectorState state) /// . /// public NpgsqlCommand? CommandInProgress { get; } -} \ No newline at end of file +} diff --git a/src/Npgsql/NpgsqlParameter.cs b/src/Npgsql/NpgsqlParameter.cs index d1dba6af5d..8930724c92 100644 --- a/src/Npgsql/NpgsqlParameter.cs +++ b/src/Npgsql/NpgsqlParameter.cs @@ -30,6 +30,7 @@ public class NpgsqlParameter : DbParameter, IDbDataParameter, ICloneable internal NpgsqlDbType? _npgsqlDbType; internal string? _dataTypeName; + internal DbType? _dbType; private protected string _name = string.Empty; object? _value; @@ -40,6 +41,7 @@ public class NpgsqlParameter : DbParameter, IDbDataParameter, ICloneable internal string TrimmedName { get; private protected set; } = PositionalName; internal const string PositionalName = ""; + IDbTypeResolver? _dbTypeResolver; private protected PgTypeInfo? TypeInfo { get; private set; } internal PgTypeId PgTypeId { get; private set; } @@ -315,26 +317,32 @@ public sealed override DbType DbType { get { - if (_npgsqlDbType is { } npgsqlDbType) - return npgsqlDbType.ToDbType(); + if (_dbType is { } dbType) + return dbType; if (_dataTypeName is not null) - return Internal.Postgres.DataTypeName.FromDisplayName(_dataTypeName).ToNpgsqlDbType()?.ToDbType() ?? DbType.Object; + { + var dataTypeName = Internal.Postgres.DataTypeName.FromDisplayName(_dataTypeName); + if (TryResolveDbType(dataTypeName, out var resolvedDbType)) + return resolvedDbType; + + return dataTypeName.ToNpgsqlDbType()?.ToDbType() ?? DbType.Object; + } + + if (_npgsqlDbType is { } npgsqlDbType) + return npgsqlDbType.ToDbType(); // Infer from value but don't cache - if (Value is not null) - // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. - return GlobalTypeMapper.Instance.FindDataTypeName(GetValueType(StaticValueType)!, Value)?.ToNpgsqlDbType()?.ToDbType() ?? 
DbType.Object; + // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. + if (GetValueType(StaticValueType) is { } valueType) + return GlobalTypeMapper.Instance.FindDataTypeName(valueType, Value)?.ToNpgsqlDbType()?.ToDbType() ?? DbType.Object; return DbType.Object; } set { ResetTypeInfo(); - _npgsqlDbType = value == DbType.Object - ? null - : value.ToNpgsqlDbType() - ?? throw new NotSupportedException($"The parameter type DbType.{value} isn't supported by PostgreSQL or Npgsql"); + _dbType = value; } } @@ -355,19 +363,28 @@ public NpgsqlDbType NpgsqlDbType if (_dataTypeName is not null) return Internal.Postgres.DataTypeName.FromDisplayName(_dataTypeName).ToNpgsqlDbType() ?? NpgsqlDbType.Unknown; + var valueType = GetValueType(StaticValueType); + if (_dbType is { } dbType) + { + if (TryResolveDbTypeDataTypeName(dbType, valueType, out var dataTypeName)) + return NpgsqlDbTypeExtensions.ToNpgsqlDbType(dataTypeName) ?? NpgsqlDbType.Unknown; + + return dbType.ToNpgsqlDbType() ?? NpgsqlDbType.Unknown; + } + // Infer from value but don't cache - if (Value is not null) - // We pass ValueType here for the generic derived type (NpgsqlParameter) where we should respect T and not the runtime type. - return GlobalTypeMapper.Instance.FindDataTypeName(GetValueType(StaticValueType)!, Value)?.ToNpgsqlDbType() ?? NpgsqlDbType.Unknown; + // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. + if (valueType is not null) + return GlobalTypeMapper.Instance.FindDataTypeName(valueType, Value)?.ToNpgsqlDbType() ?? NpgsqlDbType.Unknown; return NpgsqlDbType.Unknown; } set { if (value == NpgsqlDbType.Array) - throw new ArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Array, Binary-Or with the element type (e.g. 
Array of Box is NpgsqlDbType.Array | NpgsqlDbType.Box)."); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Array, Binary-Or with the element type (e.g. Array of Box is NpgsqlDbType.Array | NpgsqlDbType.Box)."); if (value == NpgsqlDbType.Range) - throw new ArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Range, Binary-Or with the element type (e.g. Range of integer is NpgsqlDbType.Range | NpgsqlDbType.Integer)"); + ThrowHelper.ThrowArgumentOutOfRangeException(nameof(value), "Cannot set NpgsqlDbType to just Range, Binary-Or with the element type (e.g. Range of integer is NpgsqlDbType.Range | NpgsqlDbType.Integer)"); ResetTypeInfo(); _npgsqlDbType = value; @@ -392,10 +409,21 @@ public string? DataTypeName "pg_catalog." + unqualifiedName).UnqualifiedDisplayName; } + var valueType = GetValueType(StaticValueType); + if (_dbType is { } dbType) + { + if (TryResolveDbTypeDataTypeName(dbType, valueType, out var dataTypeName)) + return dataTypeName; + + var unqualifiedName = dbType.ToNpgsqlDbType()?.ToUnqualifiedDataTypeName(); + return unqualifiedName is null ? null : Internal.Postgres.DataTypeName.ValidatedName( + "pg_catalog." + unqualifiedName).UnqualifiedDisplayName; + } + // Infer from value but don't cache - if (Value is not null) - // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. - return GlobalTypeMapper.Instance.FindDataTypeName(GetValueType(StaticValueType)!, Value)?.DisplayName; + // We pass ValueType here for the generic derived type, where we should respect T and not the runtime type. + if (valueType is not null) + return GlobalTypeMapper.Instance.FindDataTypeName(valueType, Value)?.DisplayName; return null; } @@ -453,7 +481,7 @@ public sealed override int Size set { if (value < -1) - throw new ArgumentException($"Invalid parameter Size value '{value}'. 
The value must be greater than or equal to 0."); + ThrowHelper.ThrowArgumentException($"Invalid parameter Size value '{value}'. The value must be greater than or equal to 0."); ResetBindingInfo(); _size = value; @@ -497,6 +525,40 @@ public sealed override string SourceColumn Type? GetValueType(Type staticValueType) => staticValueType != typeof(object) ? staticValueType : Value?.GetType(); + bool TryResolveDbType(DataTypeName dataTypeName, out DbType dbType) + { + if (_dbTypeResolver?.GetDbType(dataTypeName) is { } result) + { + dbType = result; + return true; + } + + dbType = default; + return false; + } + + bool TryResolveDbTypeDataTypeName(DbType dbType, Type? type, [NotNullWhen(true)]out string? normalizedDataTypeName) + { + if (_dbTypeResolver?.GetDataTypeName(dbType, type) is { } result) + { + normalizedDataTypeName = Internal.Postgres.DataTypeName.NormalizeName(result); + return true; + } + + normalizedDataTypeName = null; + return false; + } + + internal void SetOutputValue(NpgsqlDataReader reader, int ordinal) + { + if (GetType() == typeof(NpgsqlParameter)) + Value = reader.GetValue(ordinal); + else + SetOutputValueCore(reader, ordinal); + } + + private protected virtual void SetOutputValueCore(NpgsqlDataReader reader, int ordinal) {} + internal bool ShouldResetObjectTypeInfo(object? value) { var currentType = TypeInfo?.Type; @@ -526,18 +588,44 @@ internal void SetResolutionInfo(PgTypeInfo typeInfo, PgConverter converter, PgTy } /// Attempt to resolve a type info based on available (postgres) type information on the parameter. - internal void ResolveTypeInfo(PgSerializerOptions options) + internal void ResolveTypeInfo(PgSerializerOptions options, IDbTypeResolver? dbTypeResolver) { var typeInfo = TypeInfo; var previouslyResolved = ReferenceEquals(typeInfo?.Options, options); if (!previouslyResolved) { - var dataTypeName = - _npgsqlDbType is { } npgsqlDbType - ? npgsqlDbType.ToDataTypeName() ?? 
npgsqlDbType.ToUnqualifiedDataTypeNameOrThrow() - : _dataTypeName is not null - ? Internal.Postgres.DataTypeName.NormalizeName(_dataTypeName) - : null; + var staticValueType = StaticValueType; + var valueType = GetValueType(staticValueType); + + string? dataTypeName = null; + if (_dataTypeName is not null) + { + dataTypeName = Internal.Postgres.DataTypeName.NormalizeName(_dataTypeName); + } + else if (_npgsqlDbType is { } npgsqlDbType) + { + dataTypeName = npgsqlDbType.ToDataTypeName() ?? npgsqlDbType.ToUnqualifiedDataTypeNameOrThrow(); + } + else if (_dbType is { } dbType) + { + if (dbTypeResolver is not null) + { + _dbTypeResolver = dbTypeResolver; + if (dbTypeResolver.GetDataTypeName(dbType, valueType) is { } result) + { + dataTypeName = Internal.Postgres.DataTypeName.NormalizeName(result); + } + } + + // Fall back to builtin mappings if there was no resolver, or it didn't produce a result. + if (dataTypeName is null) + { + dataTypeName = dbType.ToNpgsqlDbType()?.ToDataTypeName(); + // If DbType.Object was specified we will only throw (see ThrowNoTypeInfo) if valueType is also null. + if (dataTypeName is null && dbType is not DbType.Object) + ThrowDbTypeNotSupported(); + } + } PgTypeId? pgTypeId = null; if (dataTypeName is not null) @@ -551,35 +639,24 @@ _npgsqlDbType is { } npgsqlDbType pgTypeId = options.ToCanonicalTypeId(pgType.GetRepresentationalType()); } - var unspecifiedDBNull = false; - var valueType = StaticValueType; - if (valueType == typeof(object)) + if (pgTypeId is null && valueType is null) { - valueType = Value?.GetType(); - if (valueType is null && pgTypeId is null) - { - ThrowNoTypeInfo(); - return; - } - - // We treat object typed DBNull values as default info. - // Unless we don't have a pgTypeId either, at which point we'll use an 'unspecified' PgTypeInfo to help us write a NULL. 
- if (valueType == typeof(DBNull)) - { - if (pgTypeId is null) - { - unspecifiedDBNull = true; - typeInfo = options.UnspecifiedDBNullTypeInfo; - } - else - valueType = null; - } + ThrowNoTypeInfo(); + return; } - if (!unspecifiedDBNull) - typeInfo = AdoSerializerHelpers.GetTypeInfoForWriting(valueType, pgTypeId, options, _npgsqlDbType); - - TypeInfo = typeInfo; + // We treat object typed DBNull values as default info (we don't supply a type). + // Unless we don't have a pgTypeId either, at which point we'll use an 'unspecified' PgTypeInfo to help us write a NULL. + if (valueType == typeof(DBNull) && staticValueType == typeof(object)) + { + TypeInfo = typeInfo = pgTypeId is null + ? options.UnspecifiedDBNullTypeInfo + : AdoSerializerHelpers.GetTypeInfoForWriting(type: null, pgTypeId, options, _npgsqlDbType); + } + else + { + TypeInfo = typeInfo = AdoSerializerHelpers.GetTypeInfoForWriting(valueType, pgTypeId, options, _npgsqlDbType); + } } // This step isn't part of BindValue because we need to know the PgTypeId beforehand for things like SchemaOnly with null values. @@ -595,14 +672,16 @@ _npgsqlDbType is { } npgsqlDbType void ThrowNoTypeInfo() => ThrowHelper.ThrowInvalidOperationException( - $"Parameter '{(!string.IsNullOrEmpty(ParameterName) ? ParameterName : $"${Collection?.IndexOf(this) + 1}")}' must have either its NpgsqlDbType or its DataTypeName or its Value set."); + $"Parameter '{(!string.IsNullOrEmpty(ParameterName) ? ParameterName : $"${Collection?.IndexOf(this) + 1}")}' must have either its DbType, NpgsqlDbType, DataTypeName or its Value set."); + + void ThrowDbTypeNotSupported() + => ThrowHelper.ThrowNotSupportedException( + $"The DbType '{_dbType}' isn't supported by Npgsql. There might be an Npgsql plugin with support for this DbType."); void ThrowNotSupported(string dataTypeName) - { - throw new NotSupportedException(_npgsqlDbType is not null - ? $"The NpgsqlDbType '{_npgsqlDbType}' isn't present in your database. 
You may need to install an extension or upgrade to a newer version." - : $"The data type name '{dataTypeName}' isn't present in your database. You may need to install an extension or upgrade to a newer version."); - } + => ThrowHelper.ThrowNotSupportedException( + $"The data type name '{dataTypeName}'{(_npgsqlDbType is not null ? $", provided as NpgsqlDbType '{_npgsqlDbType}'," : null)} could not be found in the types that were loaded by Npgsql. " + + $"Your database details or Npgsql type loading configuration may be incorrect. Alternatively your PostgreSQL installation might need to be upgraded, or an extension adding the missing data type might not have been installed."); } // Pull from Value so we also support object typed generic params. @@ -613,28 +692,24 @@ private protected virtual PgConverterResolution ResolveConverter(PgTypeInfo type } /// Bind the current value to the type info, truncate (if applicable), take its size, and do any final validation before writing. - internal void Bind(out DataFormat format, out Size size) + internal void Bind(out DataFormat format, out Size size, DataFormat? requiredFormat = null) { if (TypeInfo is null) ThrowHelper.ThrowInvalidOperationException($"Missing type info, {nameof(ResolveTypeInfo)} needs to be called before {nameof(Bind)}."); - if (!TypeInfo.SupportsWriting) - ThrowHelper.ThrowNotSupportedException($"Cannot write values for parameters of type '{TypeInfo.Type}' and postgres type '{TypeInfo.Options.DatabaseInfo.GetDataTypeName(PgTypeId).DisplayName}'."); - // We might call this twice, once during validation and once during WriteBind, only compute things once. 
- if (WriteSize is not null) + if (WriteSize is null) { - format = Format; - size = WriteSize.Value; - return; - } + if (_size > 0) + HandleSizeTruncation(); - if (_size > 0) - HandleSizeTruncation(); + BindCore(requiredFormat); + } - BindCore(); format = Format; size = WriteSize!.Value; + if (requiredFormat is not null && format != requiredFormat) + ThrowHelper.ThrowNotSupportedException($"Parameter '{ParameterName}' must be written in {requiredFormat} format, but does not support this format."); // Handle Size truncate behavior for a predetermined set of types and pg types. // Doesn't matter if we 'box' Value, all supported types are reference types. @@ -674,7 +749,7 @@ void HandleSizeTruncation() } } - private protected virtual void BindCore(bool allowNullReference = false) + private protected virtual void BindCore(DataFormat? formatPreference, bool allowNullReference = false) { // Pull from Value so we also support object typed generic params. var value = Value; @@ -684,7 +759,7 @@ private protected virtual void BindCore(bool allowNullReference = false) if (_useSubStream && value is not null) value = _subStream = new SubReadStream((Stream)value, _size); - if (TypeInfo!.BindObject(Converter!, value, out var size, out _writeState, out var dataFormat) is { } info) + if (TypeInfo!.BindObject(Converter!, value, out var size, out _writeState, out var dataFormat, formatPreference) is { } info) { WriteSize = size; _bufferRequirement = info.BufferRequirement; @@ -694,6 +769,7 @@ private protected virtual void BindCore(bool allowNullReference = false) WriteSize = -1; _bufferRequirement = default; } + Format = dataFormat; } @@ -748,6 +824,7 @@ private protected virtual ValueTask WriteValue(bool async, PgWriter writer, Canc /// public override void ResetDbType() { + _dbType = null; _npgsqlDbType = null; _dataTypeName = null; ResetTypeInfo(); @@ -808,6 +885,7 @@ private protected virtual NpgsqlParameter CloneCore() => _precision = _precision, _scale = _scale, _size = _size, 
+ _dbType = _dbType, _npgsqlDbType = _npgsqlDbType, _dataTypeName = _dataTypeName, Direction = Direction, diff --git a/src/Npgsql/NpgsqlParameterCollection.cs b/src/Npgsql/NpgsqlParameterCollection.cs index a10f9dceb0..51a40e6648 100644 --- a/src/Npgsql/NpgsqlParameterCollection.cs +++ b/src/Npgsql/NpgsqlParameterCollection.cs @@ -5,7 +5,6 @@ using System.Data.Common; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; -using Npgsql.Internal; using NpgsqlTypes; namespace Npgsql; @@ -143,7 +142,7 @@ internal void ChangeParameterName(NpgsqlParameter parameter, string? value) var oldTrimmedName = parameter.TrimmedName; parameter.ChangeParameterName(value); - if (_caseInsensitiveLookup is null || _caseInsensitiveLookup.Count == 0) + if (_caseInsensitiveLookup is null) return; var index = IndexOf(parameter); @@ -166,28 +165,25 @@ internal void ChangeParameterName(NpgsqlParameter parameter, string? value) { get { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); + ArgumentNullException.ThrowIfNull(parameterName); var index = IndexOf(parameterName); if (index == -1) - throw new ArgumentException("Parameter not found"); + ThrowHelper.ThrowArgumentException("Parameter not found"); return InternalList[index]; } set { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); - if (value is null) - throw new ArgumentNullException(nameof(value)); + ArgumentNullException.ThrowIfNull(parameterName); + ArgumentNullException.ThrowIfNull(value); var index = IndexOf(parameterName); if (index == -1) - throw new ArgumentException("Parameter not found"); + ThrowHelper.ThrowArgumentException("Parameter not found"); if (!string.Equals(parameterName, value.TrimmedName, StringComparison.OrdinalIgnoreCase)) - throw new ArgumentException("Parameter name must be a case-insensitive match with the property 'ParameterName' on the given NpgsqlParameter", nameof(parameterName)); + 
ThrowHelper.ThrowArgumentException("Parameter name must be a case-insensitive match with the property 'ParameterName' on the given NpgsqlParameter", nameof(parameterName)); var oldValue = InternalList[index]; LookupChangeName(value, oldValue.ParameterName, oldValue.TrimmedName, index); @@ -206,8 +202,7 @@ internal void ChangeParameterName(NpgsqlParameter parameter, string? value) get => InternalList[index]; set { - if (value is null) - ThrowHelper.ThrowArgumentNullException(nameof(value)); + ArgumentNullException.ThrowIfNull(value); if (value.Collection is not null) ThrowHelper.ThrowInvalidOperationException("The parameter already belongs to a collection"); @@ -228,11 +223,10 @@ internal void ChangeParameterName(NpgsqlParameter parameter, string? value) /// Adds the specified object to the . /// /// The to add to the collection. - /// The index of the new object. + /// The parameter that was added. public NpgsqlParameter Add(NpgsqlParameter value) { - if (value is null) - ThrowHelper.ThrowArgumentNullException(nameof(value)); + ArgumentNullException.ThrowIfNull(value); if (value.Collection is not null) ThrowHelper.ThrowInvalidOperationException("The parameter already belongs to a collection"); @@ -315,7 +309,7 @@ public NpgsqlParameter AddWithValue(NpgsqlDbType parameterType, object value) /// /// The name of the parameter. /// One of the values. - /// The index of the new object. + /// The parameter that was added. public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType) => Add(new NpgsqlParameter(parameterName, parameterType)); @@ -326,7 +320,7 @@ public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType) /// The name of the parameter. /// One of the values. /// The length of the column. - /// The index of the new object. + /// The parameter that was added. 
public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType, int size) => Add(new NpgsqlParameter(parameterName, parameterType, size)); @@ -338,7 +332,7 @@ public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType, int /// One of the values. /// The length of the column. /// The name of the source column. - /// The index of the new object. + /// The parameter that was added. public NpgsqlParameter Add(string parameterName, NpgsqlDbType parameterType, int size, string sourceColumn) => Add(new NpgsqlParameter(parameterName, parameterType, size, sourceColumn)); @@ -430,24 +424,30 @@ void BuildLookup() /// The zero-based index of the parameter. public override void RemoveAt(int index) { - if (InternalList.Count - 1 < index) - throw new ArgumentOutOfRangeException(nameof(index)); + ArgumentOutOfRangeException.ThrowIfGreaterThanOrEqual(index, InternalList.Count); Remove(InternalList[index]); } - /// + /// + /// Inserts a parameter into the at the specified index. + /// + /// The zero-based index at which to insert the parameter. + /// The parameter to insert. + /// + /// Although this method accepts , only instances of are supported. + /// Passing any other type will result in an . + /// public override void Insert(int index, object value) => Insert(index, Cast(value)); /// - /// Removes the specified from the collection. + /// Removes the with the specified name from the collection. /// /// The name of the to remove from the collection. public void Remove(string parameterName) { - if (parameterName is null) - ThrowHelper.ThrowArgumentNullException(nameof(parameterName)); + ArgumentNullException.ThrowIfNull(parameterName); var index = IndexOf(parameterName); if (index < 0) @@ -460,6 +460,10 @@ public void Remove(string parameterName) /// Removes the specified from the collection. /// /// The to remove from the collection. + /// + /// Although this method accepts , only instances of are supported. 
+ /// Passing any other type will result in an . + /// public override void Remove(object value) => Remove(Cast(value)); @@ -481,8 +485,7 @@ public override bool Contains(object value) /// public bool TryGetValue(string parameterName, [NotNullWhen(true)] out NpgsqlParameter? parameter) { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); + ArgumentNullException.ThrowIfNull(parameterName); var index = IndexOf(parameterName); @@ -509,11 +512,29 @@ public override void Clear() LookupClear(); } - /// + /// + /// Returns the index of the specified parameter in the . + /// + /// The parameter to find. + /// The index of the parameter if found; otherwise, -1. + /// + /// Although this method accepts , only instances of are supported. + /// Passing any other type will result in an . + /// public override int IndexOf(object value) => IndexOf(Cast(value)); - /// + /// + /// Adds a parameter to the . + /// + /// The parameter to add. + /// The zero-based index at which the parameter was added. + /// + /// Although this method accepts , only instances of are supported. + /// Passing any other type will result in an . + /// To add a parameter by value, use , , + /// or one of the typed overloads. + /// public override int Add(object value) { Add(Cast(value)); @@ -558,11 +579,16 @@ IEnumerator IEnumerable.GetEnumerator() #endregion - /// + /// + /// Adds the elements of the specified array to the end of the . + /// + /// + /// An array of s to add. Each item must be an instance of . + /// Passing any other type will result in an . + /// public override void AddRange(Array values) { - if (values is null) - throw new ArgumentNullException(nameof(values)); + ArgumentNullException.ThrowIfNull(values); foreach (var parameter in values) Add(Cast(parameter)); @@ -599,8 +625,7 @@ public int IndexOf(NpgsqlParameter item) /// Parameter to insert. 
public void Insert(int index, NpgsqlParameter item) { - if (item is null) - throw new ArgumentNullException(nameof(item)); + ArgumentNullException.ThrowIfNull(item); if (item.Collection != null) throw new Exception("The parameter already belongs to a collection"); @@ -624,8 +649,7 @@ public void Insert(int index, NpgsqlParameter item) /// True if the parameter was found and removed, otherwise false. public bool Remove(NpgsqlParameter item) { - if (item == null) - ThrowHelper.ThrowArgumentNullException(nameof(item)); + ArgumentNullException.ThrowIfNull(item); if (item.Collection != this) ThrowHelper.ThrowInvalidOperationException("The item does not belong to this collection"); @@ -664,7 +688,7 @@ internal void CloneTo(NpgsqlParameterCollection other) foreach (var param in InternalList) { var newParam = param.Clone(); - newParam.Collection = this; + newParam.Collection = other; other.InternalList.Add(newParam); } @@ -679,7 +703,7 @@ internal void CloneTo(NpgsqlParameterCollection other) } } - internal void ProcessParameters(PgSerializerOptions options, bool validateValues, CommandType commandType) + internal void ProcessParameters(NpgsqlDataSource.ReloadableState reloadableState, bool validateValues, CommandType commandType) { HasOutputParameters = false; PlaceholderType = PlaceholderType.NoParameters; @@ -736,7 +760,7 @@ internal void ProcessParameters(PgSerializerOptions options, bool validateValues break; } - p.ResolveTypeInfo(options); + p.ResolveTypeInfo(reloadableState.SerializerOptions, reloadableState.DbTypeResolver); if (validateValues) { diff --git a/src/Npgsql/NpgsqlParameter`.cs b/src/Npgsql/NpgsqlParameter`.cs index a749734643..2f1e1b24bc 100644 --- a/src/Npgsql/NpgsqlParameter`.cs +++ b/src/Npgsql/NpgsqlParameter`.cs @@ -81,6 +81,9 @@ public NpgsqlParameter(string parameterName, DbType dbType) #endregion Constructors + private protected override void SetOutputValueCore(NpgsqlDataReader reader, int ordinal) + => TypedValue = 
reader.GetFieldValue(ordinal); + private protected override PgConverterResolution ResolveConverter(PgTypeInfo typeInfo) { if (typeof(T) == typeof(object) || TypeInfo!.IsBoxing) @@ -91,18 +94,17 @@ private protected override PgConverterResolution ResolveConverter(PgTypeInfo typ } // We ignore allowNullReference, it's just there to control the base implementation. - private protected override void BindCore(bool allowNullReference = false) + private protected override void BindCore(DataFormat? formatPreference, bool allowNullReference = false) { if (_asObject) { // If we're object typed we should not support null. - base.BindCore(typeof(T) != typeof(object)); + base.BindCore(formatPreference, typeof(T) != typeof(object)); return; } var value = TypedValue; - Debug.Assert(Converter is PgConverter); - if (TypeInfo!.Bind(Unsafe.As>(Converter), value, out var size, out _writeState, out var dataFormat) is { } info) + if (TypeInfo!.Bind(Converter!.UnsafeDowncast(), value, out var size, out _writeState, out var dataFormat, formatPreference) is { } info) { WriteSize = size; _bufferRequirement = info.BufferRequirement; @@ -112,6 +114,7 @@ private protected override void BindCore(bool allowNullReference = false) WriteSize = -1; _bufferRequirement = default; } + Format = dataFormat; } @@ -120,11 +123,10 @@ private protected override ValueTask WriteValue(bool async, PgWriter writer, Can if (_asObject) return base.WriteValue(async, writer, cancellationToken); - Debug.Assert(Converter is PgConverter); if (async) - return Unsafe.As>(Converter!).WriteAsync(writer, TypedValue!, cancellationToken); + return Converter!.UnsafeDowncast().WriteAsync(writer, TypedValue!, cancellationToken); - Unsafe.As>(Converter!).Write(writer, TypedValue!); + Converter!.UnsafeDowncast().Write(writer, TypedValue!); return new(); } @@ -136,6 +138,7 @@ private protected override NpgsqlParameter CloneCore() => _precision = _precision, _scale = _scale, _size = _size, + _dbType = _dbType, _npgsqlDbType = 
_npgsqlDbType, _dataTypeName = _dataTypeName, Direction = Direction, diff --git a/src/Npgsql/NpgsqlRawCopyStream.cs b/src/Npgsql/NpgsqlRawCopyStream.cs index ffae8e9fc4..981065b813 100644 --- a/src/Npgsql/NpgsqlRawCopyStream.cs +++ b/src/Npgsql/NpgsqlRawCopyStream.cs @@ -1,4 +1,4 @@ -using System; +using System; using System.Diagnostics; using System.IO; using System.Threading; @@ -6,6 +6,7 @@ using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; using Npgsql.Internal; +using InfiniteTimeout = System.Threading.Timeout; using static Npgsql.Util.Statics; #pragma warning disable 1591 @@ -28,7 +29,7 @@ public sealed class NpgsqlRawCopyStream : Stream, ICancelable NpgsqlWriteBuffer _writeBuf; int _leftToReadInDataMsg; - bool _isDisposed, _isConsumed; + CopyStreamState _state = CopyStreamState.Uninitialized; bool _canRead; bool _canWrite; @@ -42,24 +43,25 @@ public sealed class NpgsqlRawCopyStream : Stream, ICancelable public override int WriteTimeout { get => (int) _writeBuf.Timeout.TotalMilliseconds; - set => _writeBuf.Timeout = TimeSpan.FromMilliseconds(value); + set => _writeBuf.Timeout = value > 0 ? TimeSpan.FromMilliseconds(value) : InfiniteTimeout.InfiniteTimeSpan; } public override int ReadTimeout { get => (int) _readBuf.Timeout.TotalMilliseconds; - set => _readBuf.Timeout = TimeSpan.FromMilliseconds(value); + set => _readBuf.Timeout = value > 0 ? TimeSpan.FromMilliseconds(value) : InfiniteTimeout.InfiniteTimeSpan; } /// /// The copy binary format header signature /// internal static readonly byte[] BinarySignature = - { + [ (byte)'P',(byte)'G',(byte)'C',(byte)'O',(byte)'P',(byte)'Y', (byte)'\n', 255, (byte)'\r', (byte)'\n', 0 - }; + ]; readonly ILogger _copyLogger; + Activity? 
_activity; #endregion @@ -73,34 +75,54 @@ internal NpgsqlRawCopyStream(NpgsqlConnector connector) _copyLogger = connector.LoggingConfiguration.CopyLogger; } - internal async Task Init(string copyCommand, bool async, CancellationToken cancellationToken = default) + internal async Task Init(string copyCommand, bool async, bool? forExport, CancellationToken cancellationToken = default) { - await _connector.WriteQuery(copyCommand, async, cancellationToken).ConfigureAwait(false); - await _connector.Flush(async, cancellationToken).ConfigureAwait(false); + Debug.Assert(_activity is null); + _activity = _connector.TraceCopyStart(copyCommand, forExport switch + { + true => "COPY TO", + false => "COPY FROM", + null => "COPY", + }); - using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); + try + { + await _connector.WriteQuery(copyCommand, async, cancellationToken).ConfigureAwait(false); + await _connector.Flush(async, cancellationToken).ConfigureAwait(false); + + using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - var msg = await _connector.ReadMessage(async).ConfigureAwait(false); - switch (msg.Code) + var msg = await _connector.ReadMessage(async).ConfigureAwait(false); + switch (msg.Code) + { + case BackendMessageCode.CopyInResponse: + _state = CopyStreamState.Ready; + var copyInResponse = (CopyInResponseMessage)msg; + IsBinary = copyInResponse.IsBinary; + _canWrite = true; + _writeBuf.StartCopyMode(); + TraceSetImport(); + break; + case BackendMessageCode.CopyOutResponse: + _state = CopyStreamState.Ready; + var copyOutResponse = (CopyOutResponseMessage)msg; + IsBinary = copyOutResponse.IsBinary; + _canRead = true; + TraceSetExport(); + break; + case BackendMessageCode.CommandComplete: + throw new InvalidOperationException( + "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. 
" + + "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + + "Note that your data has been successfully imported/exported."); + default: + throw _connector.UnexpectedMessageReceived(msg.Code); + } + } + catch (Exception e) { - case BackendMessageCode.CopyInResponse: - var copyInResponse = (CopyInResponseMessage) msg; - IsBinary = copyInResponse.IsBinary; - _canWrite = true; - _writeBuf.StartCopyMode(); - break; - case BackendMessageCode.CopyOutResponse: - var copyOutResponse = (CopyOutResponseMessage) msg; - IsBinary = copyOutResponse.IsBinary; - _canRead = true; - break; - case BackendMessageCode.CommandComplete: - throw new InvalidOperationException( - "This API only supports import/export from the client, i.e. COPY commands containing TO/FROM STDIN. " + - "To import/export with files on your PostgreSQL machine, simply execute the command with ExecuteNonQuery. " + - "Note that your data has been successfully imported/exported."); - default: - throw _connector.UnexpectedMessageReceived(msg.Code); + TraceSetException(e); + throw; } } @@ -245,7 +267,7 @@ async ValueTask ReadAsyncInternal() async ValueTask ReadCore(int count, bool async, CancellationToken cancellationToken = default) { - if (_isConsumed) + if (_state == CopyStreamState.Consumed) return 0; using var registration = _connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); @@ -259,10 +281,13 @@ async ValueTask ReadCore(int count, bool async, CancellationToken cancellat // read the next message msg = await _connector.ReadMessage(async).ConfigureAwait(false); } - catch + catch (Exception e) { - if (!_isDisposed) + if (_state != CopyStreamState.Disposed) + { + TraceSetException(e); Cleanup(); + } throw; } @@ -274,7 +299,7 @@ async ValueTask ReadCore(int count, bool async, CancellationToken cancellat case BackendMessageCode.CopyDone: Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); 
Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); - _isConsumed = true; + _state = CopyStreamState.Consumed; return 0; default: throw _connector.UnexpectedMessageReceived(msg.Code); @@ -331,10 +356,18 @@ async Task Cancel(bool async) } catch (PostgresException e) { + // TODO: NpgsqlBinaryImporter doesn't cleanup on cancellation + // And instead relies on users disposing the object + // We probably should do the same here Cleanup(); if (e.SqlState != PostgresErrorCodes.QueryCanceled) + { + TraceSetException(e); throw; + } + + TraceStop(); } } else @@ -352,10 +385,9 @@ async Task Cancel(bool async) public override ValueTask DisposeAsync() => DisposeAsync(disposing: true, async: true); - async ValueTask DisposeAsync(bool disposing, bool async) { - if (_isDisposed || !disposing) + if (_state == CopyStreamState.Disposed || !disposing) return; try @@ -364,33 +396,46 @@ async ValueTask DisposeAsync(bool disposing, bool async) if (CanWrite) { - await FlushAsync(async).ConfigureAwait(false); - _writeBuf.EndCopyMode(); - await _connector.WriteCopyDone(async).ConfigureAwait(false); - await _connector.Flush(async).ConfigureAwait(false); - Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); - Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + try + { + await FlushAsync(async).ConfigureAwait(false); + _writeBuf.EndCopyMode(); + await _connector.WriteCopyDone(async).ConfigureAwait(false); + await _connector.Flush(async).ConfigureAwait(false); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + Expect(await _connector.ReadMessage(async).ConfigureAwait(false), _connector); + TraceStop(); + } + catch (Exception e) + { + TraceSetException(e); + throw; + } } else { - if (!_isConsumed) + try { - try + if (_state != CopyStreamState.Consumed && _state != CopyStreamState.Uninitialized) { if (_leftToReadInDataMsg > 0) { - await _readBuf.Skip(_leftToReadInDataMsg, 
async).ConfigureAwait(false); + await _readBuf.Skip(async, _leftToReadInDataMsg).ConfigureAwait(false); } _connector.SkipUntil(BackendMessageCode.ReadyForQuery); } - catch (OperationCanceledException e) when (e.InnerException is PostgresException pg && pg.SqlState == PostgresErrorCodes.QueryCanceled) - { - LogMessages.CopyOperationCancelled(_copyLogger, _connector.Id); - } - catch (Exception e) - { - LogMessages.ExceptionWhenDisposingCopyOperation(_copyLogger, _connector.Id, e); - } + + TraceStop(); + } + catch (OperationCanceledException e) when (e.InnerException is PostgresException { SqlState: PostgresErrorCodes.QueryCanceled }) + { + LogMessages.CopyOperationCancelled(_copyLogger, _connector.Id); + TraceStop(); + } + catch (Exception e) + { + LogMessages.ExceptionWhenDisposingCopyOperation(_copyLogger, _connector.Id, e); + TraceSetException(e); } } } @@ -403,21 +448,20 @@ async ValueTask DisposeAsync(bool disposing, bool async) #pragma warning disable CS8625 void Cleanup() { - Debug.Assert(!_isDisposed); + Debug.Assert(_state != CopyStreamState.Disposed); LogMessages.CopyOperationCompleted(_copyLogger, _connector.Id); _connector.EndUserAction(); _connector.CurrentCopyOperation = null; - _connector.Connection?.EndBindingScope(ConnectorBindingScope.Copy); _connector = null; _readBuf = null; _writeBuf = null; - _isDisposed = true; + _state = CopyStreamState.Disposed; } #pragma warning restore CS8625 void CheckDisposed() { - if (_isDisposed) { + if (_state == CopyStreamState.Disposed) { throw new ObjectDisposedException(nameof(NpgsqlRawCopyStream), "The COPY operation has already ended."); } } @@ -428,15 +472,9 @@ void CheckDisposed() public override bool CanSeek => false; - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); - public override void SetLength(long value) - { - throw new NotSupportedException(); - } + 
public override void SetLength(long value) => throw new NotSupportedException(); public override long Length => throw new NotSupportedException(); @@ -451,16 +489,63 @@ public override long Position #region Input validation static void ValidateArguments(byte[] buffer, int offset, int count) { - if (buffer == null) - throw new ArgumentNullException(nameof(buffer)); - if (offset < 0) - throw new ArgumentNullException(nameof(offset)); - if (count < 0) - throw new ArgumentNullException(nameof(count)); + ArgumentNullException.ThrowIfNull(buffer); + ArgumentOutOfRangeException.ThrowIfNegative(offset); + ArgumentOutOfRangeException.ThrowIfNegative(count); if (buffer.Length - offset < count) - throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + ThrowHelper.ThrowArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + } + #endregion + + #region Tracing + + private void TraceSetImport() + { + if (_activity is not null) + { + NpgsqlActivitySource.SetOperation(_activity, "COPY FROM"); + } + } + + private void TraceSetExport() + { + if (_activity is not null) + { + NpgsqlActivitySource.SetOperation(_activity, "COPY TO"); + } + } + + private void TraceStop() + { + if (_activity is not null) + { + NpgsqlActivitySource.CopyStop(_activity); + _activity = null; + } + } + + private void TraceSetException(Exception e) + { + if (_activity is not null) + { + NpgsqlActivitySource.SetException(_activity, e); + _activity = null; + } } + #endregion + + #region Enums + + enum CopyStreamState + { + Uninitialized, + Ready, + Consumed, + Disposed + } + + #endregion Enums } /// @@ -477,6 +562,20 @@ internal NpgsqlCopyTextWriter(NpgsqlConnector connector, NpgsqlRawCopyStream und throw connector.Break(new Exception("Can't use a binary copy stream for text 
writing")); } + /// + /// Gets or sets a value, in milliseconds, that determines how long the text writer will attempt to write before timing out. + /// + public int Timeout + { + get => ((NpgsqlRawCopyStream)BaseStream).WriteTimeout; + set + { + var stream = (NpgsqlRawCopyStream)BaseStream; + stream.ReadTimeout = value; + stream.WriteTimeout = value; + } + } + /// /// Cancels and terminates an ongoing import. Any data already written will be discarded. /// @@ -503,6 +602,20 @@ internal NpgsqlCopyTextReader(NpgsqlConnector connector, NpgsqlRawCopyStream und throw connector.Break(new Exception("Can't use a binary copy stream for text reading")); } + /// + /// Gets or sets a value, in milliseconds, that determines how long the text reader will attempt to read before timing out. + /// + public int Timeout + { + get => ((NpgsqlRawCopyStream)BaseStream).ReadTimeout; + set + { + var stream = (NpgsqlRawCopyStream)BaseStream; + stream.ReadTimeout = value; + stream.WriteTimeout = value; + } + } + /// /// Cancels and terminates an ongoing export. /// diff --git a/src/Npgsql/NpgsqlSchema.cs b/src/Npgsql/NpgsqlSchema.cs index ba18c0acc7..aea2f6e925 100644 --- a/src/Npgsql/NpgsqlSchema.cs +++ b/src/Npgsql/NpgsqlSchema.cs @@ -19,8 +19,7 @@ static class NpgsqlSchema { public static Task GetSchema(bool async, NpgsqlConnection conn, string? collectionName, string?[]? 
restrictions, CancellationToken cancellationToken = default) { - if (collectionName is null) - throw new ArgumentNullException(nameof(collectionName)); + ArgumentNullException.ThrowIfNull(collectionName); if (collectionName.Length == 0) throw new ArgumentException("Collection name cannot be empty.", nameof(collectionName)); @@ -759,7 +758,8 @@ static DataTable GetDataSourceInformation(NpgsqlConnection conn) static DataTable GetDataTypes(NpgsqlConnection conn) { - using var _ = conn.StartTemporaryBindingScope(out var connector); + conn.CheckReady(); + var connector = conn.Connector!; var table = new DataTable("DataTypes"); @@ -790,10 +790,10 @@ static DataTable GetDataTypes(NpgsqlConnection conn) // Npgsql-specific table.Columns.Add("OID", typeof(uint)); - // TODO: Support type name restriction try { + var serializerOptions = connector.SerializerOptions; PgSerializerOptions.IntrospectionCaller = true; var types = new List(); @@ -802,7 +802,7 @@ static DataTable GetDataTypes(NpgsqlConnection conn) types.AddRange(connector.DatabaseInfo.CompositeTypes); foreach (var baseType in types) { - if (connector.SerializerOptions.GetDefaultTypeInfo(baseType) is not { } info) + if (serializerOptions.GetTypeInfoInternal(null, serializerOptions.ToCanonicalTypeId(baseType)) is not { } info) continue; var row = table.Rows.Add(); @@ -817,7 +817,7 @@ static DataTable GetDataTypes(NpgsqlConnection conn) foreach (var arrayType in connector.DatabaseInfo.ArrayTypes) { - if (connector.SerializerOptions.GetDefaultTypeInfo(arrayType) is not { } info) + if (serializerOptions.GetTypeInfoInternal(null, serializerOptions.ToCanonicalTypeId(arrayType)) is not { } info) continue; var row = table.Rows.Add(); @@ -836,7 +836,7 @@ static DataTable GetDataTypes(NpgsqlConnection conn) foreach (var rangeType in connector.DatabaseInfo.RangeTypes) { - if (connector.SerializerOptions.GetDefaultTypeInfo(rangeType) is not { } info) + if (serializerOptions.GetTypeInfoInternal(null, 
serializerOptions.ToCanonicalTypeId(rangeType)) is not { } info) continue; var row = table.Rows.Add(); @@ -856,7 +856,7 @@ static DataTable GetDataTypes(NpgsqlConnection conn) foreach (var multirangeType in connector.DatabaseInfo.MultirangeTypes) { var subtypeType = multirangeType.Subrange.Subtype; - if (connector.SerializerOptions.GetDefaultTypeInfo(multirangeType) is not { } info) + if (serializerOptions.GetTypeInfoInternal(null, serializerOptions.ToCanonicalTypeId(multirangeType)) is not { } info) continue; var row = table.Rows.Add(); @@ -876,7 +876,7 @@ static DataTable GetDataTypes(NpgsqlConnection conn) foreach (var domainType in connector.DatabaseInfo.DomainTypes) { var representationalType = domainType.GetRepresentationalType(); - if (connector.SerializerOptions.GetDefaultTypeInfo(representationalType) is not { } info) + if (serializerOptions.GetTypeInfoInternal(null, serializerOptions.ToCanonicalTypeId(representationalType)) is not { } info) continue; var row = table.Rows.Add(); @@ -1006,7 +1006,7 @@ static DataTable GetReservedWords() /// List of keywords taken from PostgreSQL 9.0 reserved words documentation. /// static readonly string[] ReservedKeywords = - { + [ "ALL", "ANALYSE", "ANALYZE", @@ -1106,7 +1106,7 @@ static DataTable GetReservedWords() "WHERE", "WINDOW", "WITH" - }; + ]; #endregion Reserved Keywords diff --git a/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs b/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs index 72cfeb4949..ebe7fd9163 100644 --- a/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs +++ b/src/Npgsql/NpgsqlSlimDataSourceBuilder.cs @@ -10,6 +10,7 @@ using Microsoft.Extensions.Logging; using Npgsql.Internal; using Npgsql.Internal.ResolverFactories; +using Npgsql.NameTranslation; using Npgsql.Properties; using Npgsql.TypeMapping; using NpgsqlTypes; @@ -29,10 +30,15 @@ public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper ILoggerFactory? _loggerFactory; bool _sensitiveDataLoggingEnabled; + List>? 
_tracingOptionsBuilderCallbacks; + List>? _typeLoadingOptionsBuilderCallbacks; TransportSecurityHandler _transportSecurityHandler = new(); RemoteCertificateValidationCallback? _userCertificateValidationCallback; Action? _clientCertificatesCallback; + Action? _sslClientAuthenticationOptionsCallback; + + Action? _negotiateOptionsCallback; IntegratedSecurityHandler _integratedSecurityHandler = new(); @@ -42,6 +48,7 @@ public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper Func>? _periodicPasswordProvider; TimeSpan _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval; + List? _dbTypeResolverFactories; PgTypeInfoResolverChainBuilder _resolverChainBuilder = new(); // mutable struct, don't make readonly. readonly UserTypeMapper _userTypeMapper; @@ -54,7 +61,7 @@ public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper internal Action ConfigureDefaultFactories { get; set; } /// - /// A connection string builder that can be used to configured the connection string on the builder. + /// A connection string builder that can be used to configure the connection string on the builder. /// public NpgsqlConnectionStringBuilder ConnectionStringBuilder { get; } @@ -64,7 +71,7 @@ public sealed class NpgsqlSlimDataSourceBuilder : INpgsqlTypeMapper public string ConnectionString => ConnectionStringBuilder.ToString(); static NpgsqlSlimDataSourceBuilder() - => GlobalTypeMapper.Instance.AddGlobalTypeMappingResolvers(new PgTypeInfoResolverFactory[] { new AdoTypeInfoResolverFactory() }); + => GlobalTypeMapper.Instance.AddGlobalTypeMappingResolvers([new AdoTypeInfoResolverFactory()]); /// /// A diagnostics name used by Npgsql when generating tracing, logging and metrics. @@ -111,13 +118,38 @@ public NpgsqlSlimDataSourceBuilder EnableParameterLogging(bool parameterLoggingE return this; } + /// + /// Configure type loading options for the DataSource. Calling this again will replace + /// the prior action. 
+ /// + public NpgsqlSlimDataSourceBuilder ConfigureTypeLoading(Action configureAction) + { + ArgumentNullException.ThrowIfNull(configureAction); + _typeLoadingOptionsBuilderCallbacks ??= new(); + _typeLoadingOptionsBuilderCallbacks.Add(configureAction); + return this; + } + + /// + /// Configures OpenTelemetry tracing options. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder ConfigureTracing(Action configureAction) + { + ArgumentNullException.ThrowIfNull(configureAction); + _tracingOptionsBuilderCallbacks ??= new(); + _tracingOptionsBuilderCallbacks.Add(configureAction); + return this; + } + /// /// Configures the JSON serializer options used when reading and writing all System.Text.Json data. /// /// Options to customize JSON serialization and deserialization. - /// + /// The same builder instance so that multiple calls can be chained. public NpgsqlSlimDataSourceBuilder ConfigureJsonOptions(JsonSerializerOptions serializerOptions) { + ArgumentNullException.ThrowIfNull(serializerOptions); JsonSerializerOptions = serializerOptions; return this; } @@ -139,6 +171,7 @@ public NpgsqlSlimDataSourceBuilder ConfigureJsonOptions(JsonSerializerOptions se /// /// /// The same builder instance so that multiple calls can be chained. + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public NpgsqlSlimDataSourceBuilder UseUserCertificateValidationCallback( RemoteCertificateValidationCallback userCertificateValidationCallback) { @@ -152,6 +185,7 @@ public NpgsqlSlimDataSourceBuilder UseUserCertificateValidationCallback( /// /// The client certificate to be sent to PostgreSQL when opening a connection. /// The same builder instance so that multiple calls can be chained. + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public NpgsqlSlimDataSourceBuilder UseClientCertificate(X509Certificate? 
clientCertificate) { if (clientCertificate is null) @@ -166,9 +200,27 @@ public NpgsqlSlimDataSourceBuilder UseClientCertificate(X509Certificate? clientC /// /// The client certificate collection to be sent to PostgreSQL when opening a connection. /// The same builder instance so that multiple calls can be chained. + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public NpgsqlSlimDataSourceBuilder UseClientCertificates(X509CertificateCollection? clientCertificates) => UseClientCertificatesCallback(clientCertificates is null ? null : certs => certs.AddRange(clientCertificates)); + /// + /// When using SSL/TLS, this is a callback that allows customizing SslStream's authentication options. + /// + /// The callback to customize SslStream's authentication options. + /// + /// + /// See . + /// + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseSslClientAuthenticationOptionsCallback(Action? sslClientAuthenticationOptionsCallback) + { + _sslClientAuthenticationOptionsCallback = sslClientAuthenticationOptionsCallback; + + return this; + } + /// /// Specifies a callback to modify the collection of SSL/TLS client certificates which Npgsql will send to PostgreSQL for /// certificate-based authentication. This is an advanced API, consider using or @@ -186,6 +238,7 @@ public NpgsqlSlimDataSourceBuilder UseClientCertificates(X509CertificateCollecti /// /// /// The same builder instance so that multiple calls can be chained. + [Obsolete("Use UseSslClientAuthenticationOptionsCallback")] public NpgsqlSlimDataSourceBuilder UseClientCertificatesCallback(Action? clientCertificatesCallback) { _clientCertificatesCallback = clientCertificatesCallback; @@ -200,9 +253,19 @@ public NpgsqlSlimDataSourceBuilder UseClientCertificatesCallback(ActionThe same builder instance so that multiple calls can be chained. public NpgsqlSlimDataSourceBuilder UseRootCertificate(X509Certificate2? 
rootCertificate) => rootCertificate is null - ? UseRootCertificateCallback(null) + ? UseRootCertificatesCallback((Func?)null) : UseRootCertificateCallback(() => rootCertificate); + /// + /// Sets the that will be used validate SSL certificate, received from the server. + /// + /// The CA certificates. + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseRootCertificates(X509Certificate2Collection? rootCertificates) + => rootCertificates is null + ? UseRootCertificatesCallback((Func?)null) + : UseRootCertificatesCallback(() => rootCertificates); + /// /// Specifies a callback that will be used to validate SSL certificate, received from the server. /// @@ -216,7 +279,27 @@ public NpgsqlSlimDataSourceBuilder UseRootCertificate(X509Certificate2? rootCert /// The same builder instance so that multiple calls can be chained. public NpgsqlSlimDataSourceBuilder UseRootCertificateCallback(Func? rootCertificateCallback) { - _transportSecurityHandler.RootCertificateCallback = rootCertificateCallback; + _transportSecurityHandler.RootCertificatesCallback = () => rootCertificateCallback is not null + ? new X509Certificate2Collection(rootCertificateCallback()) + : null; + + return this; + } + + /// + /// Specifies a callback that will be used to validate SSL certificate, received from the server. + /// + /// The callback to get CA certificates. + /// The same builder instance so that multiple calls can be chained. + /// + /// This overload, which accepts a callback, is suitable for scenarios where the certificate rotates + /// and might change during the lifetime of the application. + /// When that's not the case, use the overload which directly accepts the certificate. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseRootCertificatesCallback(Func? 
rootCertificateCallback) + { + _transportSecurityHandler.RootCertificatesCallback = rootCertificateCallback; return this; } @@ -289,6 +372,23 @@ public NpgsqlSlimDataSourceBuilder UsePasswordProvider( return this; } + /// + /// When using Kerberos, this is a callback that allows customizing default settings for Kerberos authentication. + /// + /// The callback containing logic to customize Kerberos authentication settings. + /// + /// + /// See . + /// + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder UseNegotiateOptionsCallback(Action? negotiateOptionsCallback) + { + _negotiateOptionsCallback = negotiateOptionsCallback; + + return this; + } + #endregion Authentication #region Type mapping @@ -300,8 +400,27 @@ public INpgsqlNameTranslator DefaultNameTranslator set => _userTypeMapper.DefaultNameTranslator = value; } - /// - public INpgsqlTypeMapper MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) + /// + /// Maps a CLR enum to a PostgreSQL enum type. + /// + /// + /// CLR enum labels are mapped by name to PostgreSQL enum labels. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. + /// If there is a discrepancy between the .NET and database labels while an enum is read or written, + /// an exception will be raised. + /// + /// + /// A PostgreSQL type name for the corresponding enum type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . 
+ /// + /// The .NET enum type to be mapped + public NpgsqlSlimDataSourceBuilder MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) where TEnum : struct, Enum { _userTypeMapper.MapEnum(pgName, nameTranslator); @@ -313,9 +432,28 @@ public INpgsqlNameTranslator DefaultNameTranslator where TEnum : struct, Enum => _userTypeMapper.UnmapEnum(pgName, nameTranslator); - /// + /// + /// Maps a CLR enum to a PostgreSQL enum type. + /// + /// + /// CLR enum labels are mapped by name to PostgreSQL enum labels. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your enum fields to manually specify a PostgreSQL enum label. + /// If there is a discrepancy between the .NET and database labels while an enum is read or written, + /// an exception will be raised. + /// + /// The .NET enum type to be mapped + /// + /// A PostgreSQL type name for the corresponding enum type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. This may not work when AOT compiling.")] - public INpgsqlTypeMapper MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + public NpgsqlSlimDataSourceBuilder MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] Type clrType, string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) { _userTypeMapper.MapEnum(clrType, pgName, nameTranslator); @@ -327,9 +465,28 @@ public bool UnmapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) => _userTypeMapper.UnmapEnum(clrType, pgName, nameTranslator); - /// + /// + /// Maps a CLR type to a PostgreSQL composite type. + /// + /// + /// CLR fields and properties by string to PostgreSQL names. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// You can also use the on your members to manually specify a PostgreSQL name. + /// If there is a discrepancy between the .NET type and database type while a composite is read or written, + /// an exception will be raised. + /// + /// + /// A PostgreSQL type name for the corresponding composite type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// + /// The .NET type to be mapped [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] - public INpgsqlTypeMapper MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + public NpgsqlSlimDataSourceBuilder MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( string? pgName = null, INpgsqlNameTranslator? 
nameTranslator = null) { _userTypeMapper.MapComposite(typeof(T), pgName, nameTranslator); @@ -342,9 +499,27 @@ public bool UnmapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) => _userTypeMapper.UnmapComposite(typeof(T), pgName, nameTranslator); - /// + /// + /// Maps a CLR type to a composite type. + /// + /// + /// Maps CLR fields and properties by string to PostgreSQL names. + /// The translation strategy can be controlled by the parameter, + /// which defaults to . + /// If there is a discrepancy between the .NET type and database type while a composite is read or written, + /// an exception will be raised. + /// + /// The .NET type to be mapped. + /// + /// A PostgreSQL type name for the corresponding composite type in the database. + /// If null, the name translator given in will be used. + /// + /// + /// A component which will be used to translate CLR names (e.g. SomeClass) into database names (e.g. some_class). + /// Defaults to . + /// [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] - public INpgsqlTypeMapper MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + public NpgsqlSlimDataSourceBuilder MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] Type clrType, string? pgName = null, INpgsqlNameTranslator? nameTranslator = null) { _userTypeMapper.MapComposite(clrType, pgName, nameTranslator); @@ -357,9 +532,14 @@ public bool UnmapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMember Type clrType, string? 
pgName = null, INpgsqlNameTranslator? nameTranslator = null) => _userTypeMapper.UnmapComposite(clrType, pgName, nameTranslator); + /// + void INpgsqlTypeMapper.AddDbTypeResolverFactory(DbTypeResolverFactory factory) + => (_dbTypeResolverFactories ??= new()).Add(factory); /// - public void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) => _resolverChainBuilder.PrependResolverFactory(factory); + [Experimental(NpgsqlDiagnostics.ConvertersExperimental)] + public void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) + => _resolverChainBuilder.PrependResolverFactory(factory); /// void INpgsqlTypeMapper.Reset() => _resolverChainBuilder.Clear(); @@ -445,6 +625,16 @@ public NpgsqlSlimDataSourceBuilder EnableLTree() return this; } + /// + /// Sets up mappings for the PostgreSQL cube extension type. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableCube() + { + AddTypeInfoResolverFactory(new CubeTypeInfoResolverFactory()); + return this; + } + /// /// Sets up mappings for extra conversions from PostgreSQL to .NET types. /// @@ -467,7 +657,7 @@ public NpgsqlSlimDataSourceBuilder EnableTransportSecurity() } /// - /// Enables the possibility to use GSS/SSPI authentication for connections to PostgreSQL. This does not guarantee that it will + /// Enables the possibility to use GSS/SSPI authentication and encryption for connections to PostgreSQL. This does not guarantee that it will /// actually be used; see for more details. /// /// The same builder instance so that multiple calls can be chained. @@ -477,6 +667,39 @@ public NpgsqlSlimDataSourceBuilder EnableIntegratedSecurity() return this; } + /// + /// Sets up network mappings. This allows mapping PhysicalAddress, IPAddress, NpgsqlInet and NpgsqlCidr types + /// to PostgreSQL macaddr, macaddr8, inet and cidr types. + /// + /// The same builder instance so that multiple calls can be chained. 
+ public NpgsqlSlimDataSourceBuilder EnableNetworkTypes() + { + _resolverChainBuilder.AppendResolverFactory(new NetworkTypeInfoResolverFactory()); + return this; + } + + /// + /// Sets up network mappings. This allows mapping types like NpgsqlPoint and NpgsqlPath + /// to PostgreSQL point, path and so on types. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableGeometricTypes() + { + _resolverChainBuilder.AppendResolverFactory(new GeometricTypeInfoResolverFactory()); + return this; + } + + /// + /// Sets up System.Text.Json mappings. This allows mapping JsonDocument and JsonElement types to PostgreSQL json and jsonb + /// types. + /// + /// The same builder instance so that multiple calls can be chained. + public NpgsqlSlimDataSourceBuilder EnableJsonTypes() + { + _resolverChainBuilder.AppendResolverFactory(() => new JsonTypeInfoResolverFactory(JsonSerializerOptions)); + return this; + } + /// /// Sets up dynamic System.Text.Json mappings. This allows mapping arbitrary .NET types to PostgreSQL json and jsonb /// types, as well as and its derived types. @@ -490,6 +713,7 @@ public NpgsqlSlimDataSourceBuilder EnableIntegratedSecurity() /// /// Due to the dynamic nature of these mappings, they are not compatible with NativeAOT or trimming. /// + /// The same builder instance so that multiple calls can be chained. [RequiresUnreferencedCode("Json serializer may perform reflection on trimmed types.")] [RequiresDynamicCode("Serializing arbitrary types to json can require creating new generic types or methods, which requires creating code at runtime. 
This may not work when AOT compiling.")] public NpgsqlSlimDataSourceBuilder EnableDynamicJson( @@ -565,21 +789,18 @@ public NpgsqlSlimDataSourceBuilder UsePhysicalConnectionInitializer( /// public NpgsqlDataSource Build() { - var config = PrepareConfiguration(); - var connectionStringBuilder = ConnectionStringBuilder.Clone(); + var (connectionStringBuilder, config) = PrepareConfiguration(); - if (ConnectionStringBuilder.Host!.Contains(",")) + if (ConnectionStringBuilder.Host!.Contains(',')) { ValidateMultiHost(); return new NpgsqlMultiHostDataSource(connectionStringBuilder, config); } - return ConnectionStringBuilder.Multiplexing - ? new MultiplexingDataSource(connectionStringBuilder, config) - : ConnectionStringBuilder.Pooling - ? new PoolingDataSource(connectionStringBuilder, config) - : new UnpooledDataSource(connectionStringBuilder, config); + return ConnectionStringBuilder.Pooling + ? new PoolingDataSource(connectionStringBuilder, config) + : new UnpooledDataSource(connectionStringBuilder, config); } /// @@ -587,18 +808,43 @@ public NpgsqlDataSource Build() /// public NpgsqlMultiHostDataSource BuildMultiHost() { - var config = PrepareConfiguration(); + var (connectionStringBuilder, config) = PrepareConfiguration(); ValidateMultiHost(); - return new(ConnectionStringBuilder.Clone(), config); + return new(connectionStringBuilder, config); } - NpgsqlDataSourceConfiguration PrepareConfiguration() + (NpgsqlConnectionStringBuilder, NpgsqlDataSourceConfiguration) PrepareConfiguration() { ConnectionStringBuilder.PostProcessAndValidate(); + var connectionStringBuilder = ConnectionStringBuilder.Clone(); - if (!_transportSecurityHandler.SupportEncryption && (_userCertificateValidationCallback is not null || _clientCertificatesCallback is not null)) + var sslClientAuthenticationOptionsCallback = _sslClientAuthenticationOptionsCallback; + var hasCertificateCallbacks = _userCertificateValidationCallback is not null || _clientCertificatesCallback is not null; + if 
(sslClientAuthenticationOptionsCallback is not null && hasCertificateCallbacks) + { + throw new NotSupportedException(NpgsqlStrings.SslClientAuthenticationOptionsCallbackWithOtherCallbacksNotSupported); + } + + if (sslClientAuthenticationOptionsCallback is null && hasCertificateCallbacks) + { + sslClientAuthenticationOptionsCallback = options => + { + if (_clientCertificatesCallback is not null) + { + options.ClientCertificates ??= new X509Certificate2Collection(); + _clientCertificatesCallback.Invoke(options.ClientCertificates); + } + + if (_userCertificateValidationCallback is not null) + { + options.RemoteCertificateValidationCallback = _userCertificateValidationCallback; + } + }; + } + + if (!_transportSecurityHandler.SupportEncryption && sslClientAuthenticationOptionsCallback is not null) { throw new InvalidOperationException(NpgsqlStrings.TransportSecurityDisabled); } @@ -616,48 +862,45 @@ NpgsqlDataSourceConfiguration PrepareConfiguration() ConfigureDefaultFactories(this); - return new( + var typeLoadingOptionsBuilder = new NpgsqlTypeLoadingOptionsBuilder(); +#pragma warning disable CS0618 // Type or member is obsolete + typeLoadingOptionsBuilder.EnableTableCompositesLoading(connectionStringBuilder.LoadTableComposites); + typeLoadingOptionsBuilder.EnableTypeLoading(connectionStringBuilder.ServerCompatibilityMode is not ServerCompatibilityMode.NoTypeLoading); +#pragma warning restore CS0618 // Type or member is obsolete + foreach (var callback in _typeLoadingOptionsBuilderCallbacks ?? (IEnumerable>)[]) + callback.Invoke(typeLoadingOptionsBuilder); + var typeLoadingOptions = typeLoadingOptionsBuilder.Build(); + + var tracingOptionsBuilder = new NpgsqlTracingOptionsBuilder(); + foreach (var callback in _tracingOptionsBuilderCallbacks ?? (IEnumerable>)[]) + callback.Invoke(tracingOptionsBuilder); + var tracingOptions = tracingOptionsBuilder.Build(); + + return (connectionStringBuilder, new( Name, _loggerFactory is null ? 
NpgsqlLoggingConfiguration.NullConfiguration : new NpgsqlLoggingConfiguration(_loggerFactory, _sensitiveDataLoggingEnabled), + tracingOptions, + typeLoadingOptions, _transportSecurityHandler, _integratedSecurityHandler, - _userCertificateValidationCallback, - _clientCertificatesCallback, + sslClientAuthenticationOptionsCallback, _passwordProvider, _passwordProviderAsync, _periodicPasswordProvider, _periodicPasswordSuccessRefreshInterval, _periodicPasswordFailureRefreshInterval, _resolverChainBuilder.Build(ConfigureResolverChain), - HackyEnumMappings(), + _dbTypeResolverFactories ?? [], DefaultNameTranslator, _connectionInitializer, - _connectionInitializerAsync); - - List HackyEnumMappings() - { - var mappings = new List(); - - if (_userTypeMapper.Items.Count > 0) - foreach (var userTypeMapping in _userTypeMapper.Items) - if (userTypeMapping is UserTypeMapper.EnumMapping enumMapping) - mappings.Add(new(enumMapping.ClrType, enumMapping.PgTypeName, enumMapping.NameTranslator)); - - if (GlobalTypeMapper.Instance.HackyEnumTypeMappings.Count > 0) - mappings.AddRange(GlobalTypeMapper.Instance.HackyEnumTypeMappings); - - return mappings; - } + _connectionInitializerAsync, + _negotiateOptionsCallback)); } void ValidateMultiHost() { - if (ConnectionStringBuilder.TargetSessionAttributes is not null) - throw new InvalidOperationException(NpgsqlStrings.CannotSpecifyTargetSessionAttributes); - if (ConnectionStringBuilder.Multiplexing) - throw new NotSupportedException("Multiplexing is not supported with multiple hosts"); if (ConnectionStringBuilder.ReplicationMode != ReplicationMode.Off) throw new NotSupportedException("Replication is not supported with multiple hosts"); } @@ -684,4 +927,38 @@ INpgsqlTypeMapper INpgsqlTypeMapper.EnableRecordsAsTuples() "The use of unmapped enums, ranges or multiranges requires dynamic code usage which is incompatible with NativeAOT.")] INpgsqlTypeMapper INpgsqlTypeMapper.EnableUnmappedTypes() => EnableUnmappedTypes(); + + /// + 
INpgsqlTypeMapper INpgsqlTypeMapper.MapEnum<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum>(string? pgName, INpgsqlNameTranslator? nameTranslator) + { + _userTypeMapper.MapEnum(pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Calling MapEnum with a Type can require creating new generic types or methods. This may not work when AOT compiling.")] + INpgsqlTypeMapper INpgsqlTypeMapper.MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicParameterlessConstructor)] + Type clrType, string? pgName, INpgsqlNameTranslator? nameTranslator) + { + _userTypeMapper.MapEnum(clrType, pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + INpgsqlTypeMapper INpgsqlTypeMapper.MapComposite<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] T>( + string? pgName, INpgsqlNameTranslator? nameTranslator) + { + _userTypeMapper.MapComposite(typeof(T), pgName, nameTranslator); + return this; + } + + /// + [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] + INpgsqlTypeMapper INpgsqlTypeMapper.MapComposite([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicProperties | DynamicallyAccessedMemberTypes.PublicFields)] + Type clrType, string? pgName, INpgsqlNameTranslator? 
nameTranslator) + { + _userTypeMapper.MapComposite(clrType, pgName, nameTranslator); + return this; + } } diff --git a/src/Npgsql/NpgsqlTracingOptions.cs b/src/Npgsql/NpgsqlTracingOptions.cs deleted file mode 100644 index 4aa61beec6..0000000000 --- a/src/Npgsql/NpgsqlTracingOptions.cs +++ /dev/null @@ -1,9 +0,0 @@ -namespace Npgsql; - -/// -/// Options to configure Npgsql's support for OpenTelemetry tracing. -/// Currently no options are available. -/// -public class NpgsqlTracingOptions -{ -} \ No newline at end of file diff --git a/src/Npgsql/NpgsqlTracingOptionsBuilder.cs b/src/Npgsql/NpgsqlTracingOptionsBuilder.cs new file mode 100644 index 0000000000..d38cf1d32d --- /dev/null +++ b/src/Npgsql/NpgsqlTracingOptionsBuilder.cs @@ -0,0 +1,164 @@ +using System; +using System.Diagnostics; + +namespace Npgsql; + +/// +/// A builder to configure Npgsql's support for OpenTelemetry tracing. +/// +public sealed class NpgsqlTracingOptionsBuilder +{ + Func? _commandFilter; + Func? _batchFilter; + Action? _commandEnrichmentCallback; + Action? _batchEnrichmentCallback; + Func? _commandSpanNameProvider; + Func? _batchSpanNameProvider; + bool _enableFirstResponseEvent = true; + bool _enablePhysicalOpenTracing = true; + + Func? _copyOperationFilter; + Action? _copyOperationEnrichmentCallback; + Func? _copyOperationSpanNameProvider; + + internal NpgsqlTracingOptionsBuilder() + { + } + + /// + /// Configures a filter function that determines whether to emit tracing information for an . + /// By default, tracing information is emitted for all commands. + /// + public NpgsqlTracingOptionsBuilder ConfigureCommandFilter(Func? commandFilter) + { + _commandFilter = commandFilter; + return this; + } + + /// + /// Configures a filter function that determines whether to emit tracing information for an . + /// By default, tracing information is emitted for all batches. + /// + public NpgsqlTracingOptionsBuilder ConfigureBatchFilter(Func? 
batchFilter) + { + _batchFilter = batchFilter; + return this; + } + + /// + /// Configures a callback that can enrich the emitted for the given . + /// + public NpgsqlTracingOptionsBuilder ConfigureCommandEnrichmentCallback(Action? commandEnrichmentCallback) + { + _commandEnrichmentCallback = commandEnrichmentCallback; + return this; + } + + /// + /// Configures a callback that can enrich the emitted for the given . + /// + public NpgsqlTracingOptionsBuilder ConfigureBatchEnrichmentCallback(Action? batchEnrichmentCallback) + { + _batchEnrichmentCallback = batchEnrichmentCallback; + return this; + } + + /// + /// Configures a callback that provides the tracing span's name for an . If null, the default standard + /// span name is used, which is the database name. + /// + public NpgsqlTracingOptionsBuilder ConfigureCommandSpanNameProvider(Func? commandSpanNameProvider) + { + _commandSpanNameProvider = commandSpanNameProvider; + return this; + } + + /// + /// Configures a callback that provides the tracing span's name for an . If null, the default standard + /// span name is used, which is the database name. + /// + public NpgsqlTracingOptionsBuilder ConfigureBatchSpanNameProvider(Func? batchSpanNameProvider) + { + _batchSpanNameProvider = batchSpanNameProvider; + return this; + } + + /// + /// Gets or sets a value indicating whether to enable the "time-to-first-read" event. + /// Default is true to preserve existing behavior. + /// + public NpgsqlTracingOptionsBuilder EnableFirstResponseEvent(bool enable = true) + { + _enableFirstResponseEvent = enable; + return this; + } + + /// + /// Gets or sets a value indicating whether to trace physical connection open. + /// Default is true to preserve existing behavior. 
+ /// + public NpgsqlTracingOptionsBuilder EnablePhysicalOpenTracing(bool enable = true) + { + _enablePhysicalOpenTracing = enable; + return this; + } + + /// + /// Configures a filter function that determines whether to emit tracing information for a copy operation. + /// By default, tracing information is emitted for all copy operations. + /// + public NpgsqlTracingOptionsBuilder ConfigureCopyOperationFilter(Func? copyOperationFilter) + { + _copyOperationFilter = copyOperationFilter; + return this; + } + + /// + /// Configures a callback that can enrich the emitted for a given copy operation. + /// + public NpgsqlTracingOptionsBuilder ConfigureCopyOperationEnrichmentCallback(Action? copyOperationEnrichmentCallback) + { + _copyOperationEnrichmentCallback = copyOperationEnrichmentCallback; + return this; + } + + /// + /// Configures a callback that provides the tracing span's name for a copy operation. If null, the default standard + /// span name is used, which is the database name. + /// + public NpgsqlTracingOptionsBuilder ConfigureCopyOperationSpanNameProvider(Func? copyOperationSpanNameProvider) + { + _copyOperationSpanNameProvider = copyOperationSpanNameProvider; + return this; + } + + internal NpgsqlTracingOptions Build() => new() + { + CommandFilter = _commandFilter, + BatchFilter = _batchFilter, + CommandEnrichmentCallback = _commandEnrichmentCallback, + BatchEnrichmentCallback = _batchEnrichmentCallback, + CommandSpanNameProvider = _commandSpanNameProvider, + BatchSpanNameProvider = _batchSpanNameProvider, + EnableFirstResponseEvent = _enableFirstResponseEvent, + EnablePhysicalOpenTracing = _enablePhysicalOpenTracing, + CopyOperationFilter = _copyOperationFilter, + CopyOperationEnrichmentCallback = _copyOperationEnrichmentCallback, + CopyOperationSpanNameProvider = _copyOperationSpanNameProvider + }; +} + +sealed class NpgsqlTracingOptions +{ + internal Func? CommandFilter { get; init; } + internal Func? BatchFilter { get; init; } + internal Action? 
CommandEnrichmentCallback { get; init; } + internal Action? BatchEnrichmentCallback { get; init; } + internal Func? CommandSpanNameProvider { get; init; } + internal Func? BatchSpanNameProvider { get; init; } + internal bool EnableFirstResponseEvent { get; init; } + internal bool EnablePhysicalOpenTracing { get; init; } + internal Func? CopyOperationFilter { get; init; } + internal Action? CopyOperationEnrichmentCallback { get; init; } + internal Func? CopyOperationSpanNameProvider { get; init; } +} diff --git a/src/Npgsql/NpgsqlTransaction.cs b/src/Npgsql/NpgsqlTransaction.cs index 6481e185af..14254bdccc 100644 --- a/src/Npgsql/NpgsqlTransaction.cs +++ b/src/Npgsql/NpgsqlTransaction.cs @@ -192,10 +192,7 @@ public override Task RollbackAsync(CancellationToken cancellationToken = default /// public override void Save(string name) { - if (name == null) - throw new ArgumentNullException(nameof(name)); - if (string.IsNullOrWhiteSpace(name)) - throw new ArgumentException("name can't be empty", nameof(name)); + ArgumentException.ThrowIfNullOrWhiteSpace(name); CheckReady(); if (!_connector.DatabaseInfo.SupportsTransactions) @@ -212,16 +209,7 @@ public override void Save(string name) // Note: savepoint names are PostgreSQL identifiers, and so limited by default to 63 characters. // Since we are prepending, we assume below that the statement will always fit in the buffer. 
- _connector.WriteBuffer.WriteByte(FrontendMessageCode.Query); - _connector.WriteBuffer.WriteInt32( - sizeof(int) + // Message length (including self excluding code) - _connector.TextEncoding.GetByteCount("SAVEPOINT ") + - _connector.TextEncoding.GetByteCount(name) + - sizeof(byte)); // Null terminator - - _connector.WriteBuffer.WriteString("SAVEPOINT "); - _connector.WriteBuffer.WriteString(name); - _connector.WriteBuffer.WriteByte(0); + _connector.WriteQuery("SAVEPOINT " + name, async: false).GetAwaiter().GetResult(); _connector.PendingPrependedResponses += 2; } @@ -245,10 +233,7 @@ public override Task SaveAsync(string name, CancellationToken cancellationToken async Task Rollback(bool async, string name, CancellationToken cancellationToken = default) { - if (name == null) - throw new ArgumentNullException(nameof(name)); - if (string.IsNullOrWhiteSpace(name)) - throw new ArgumentException("name can't be empty", nameof(name)); + ArgumentException.ThrowIfNullOrWhiteSpace(name); CheckReady(); if (!_connector.DatabaseInfo.SupportsTransactions) @@ -280,10 +265,7 @@ public override Task RollbackAsync(string name, CancellationToken cancellationTo async Task Release(bool async, string name, CancellationToken cancellationToken = default) { - if (name == null) - throw new ArgumentNullException(nameof(name)); - if (string.IsNullOrWhiteSpace(name)) - throw new ArgumentException("name can't be empty", nameof(name)); + ArgumentException.ThrowIfNullOrWhiteSpace(name); CheckReady(); if (!_connector.DatabaseInfo.SupportsTransactions) @@ -316,10 +298,7 @@ public override Task ReleaseAsync(string name, CancellationToken cancellationTok /// /// Indicates whether this transaction supports database savepoints. 
/// - public override bool SupportsSavepoints - { - get => _connector.DatabaseInfo.SupportsTransactions; - } + public override bool SupportsSavepoints => _connector.DatabaseInfo.SupportsTransactions; #endregion @@ -349,7 +328,6 @@ protected override void Dispose(bool disposing) } IsDisposed = true; - _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); } } @@ -366,8 +344,8 @@ public override ValueTask DisposeAsync() } IsDisposed = true; - _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); } + return default; async ValueTask DisposeAsyncInternal() @@ -385,7 +363,6 @@ async ValueTask DisposeAsyncInternal() } IsDisposed = true; - _connector?.Connection?.EndBindingScope(ConnectorBindingScope.Transaction); } } diff --git a/src/Npgsql/NpgsqlTypeLoadingOptions.cs b/src/Npgsql/NpgsqlTypeLoadingOptions.cs new file mode 100644 index 0000000000..c031826675 --- /dev/null +++ b/src/Npgsql/NpgsqlTypeLoadingOptions.cs @@ -0,0 +1,114 @@ +using System; +using System.Collections.Generic; + +namespace Npgsql; + +/// +/// Options for configuring Npgsql type loading. +/// +sealed class NpgsqlTypeLoadingOptions +{ + /// + /// Load table composite type definitions, and not just free-standing composite types. + /// + public required bool LoadTableComposites { get; init; } + + /// + /// When false, if the server doesn't support full type loading from the PostgreSQL catalogs, + /// support the basic set of types via information hardcoded inside Npgsql. + /// + public required bool LoadTypes { get; init; } = true; + + /// + /// Load type definitions from the given schemas. + /// + public required string[]? TypeLoadingSchemas { get; init; } +} + +/// +/// Options builder for configuring Npgsql type loading. +/// +public sealed class NpgsqlTypeLoadingOptionsBuilder +{ + bool _loadTableComposites; + bool _loadTypes = true; + List? 
_typeLoadingSchemas; + + internal NpgsqlTypeLoadingOptionsBuilder() {} + + /// + /// Enable loading table composite type definitions, and not just free-standing composite types. + /// + public NpgsqlTypeLoadingOptionsBuilder EnableTableCompositesLoading(bool enable = true) + { + _loadTableComposites = enable; + return this; + } + + /// + /// Enable loading of types, when disabled Npgsql falls back to a small, builtin, set of known types and type ids. + /// + public NpgsqlTypeLoadingOptionsBuilder EnableTypeLoading(bool enable = true) + { + _loadTypes = enable; + return this; + } + + /// + /// Set the schemas to load types from, this can be used to reduce the work done during type loading. + /// + /// Npgsql will always load types from the following schemas: pg_catalog, information_schema, pg_toast. + /// Any user-defined types (typcategory 'U') will also be loaded regardless of their schema. + /// Schemas to load types from. + public NpgsqlTypeLoadingOptionsBuilder SetTypeLoadingSchemas(params IEnumerable? schemas) + { + if (schemas is null) + { + _typeLoadingSchemas = null; + return this; + } + + _typeLoadingSchemas = new(); + foreach (var schema in schemas) + { + if (schema is not { Length: > 0 }) + { + _typeLoadingSchemas = null; + throw new ArgumentException("Schema cannot be null or empty."); + } + _typeLoadingSchemas.Add(schema); + } + + return this; + } + + internal NpgsqlTypeLoadingOptions Build() => new() + { + LoadTableComposites = _loadTableComposites, + LoadTypes = _loadTypes, + TypeLoadingSchemas = _typeLoadingSchemas?.ToArray() + }; +} + +/// +/// An option specified in the connection string that activates special compatibility features. +/// +public enum ServerCompatibilityMode +{ + /// + /// No special server compatibility mode is active + /// + None, + + /// + /// The server is an Amazon Redshift instance. 
+ /// + [Obsolete("ServerCompatibilityMode.Redshift no longer does anything and can be safely removed.")] + Redshift, + + /// + /// The server is doesn't support full type loading from the PostgreSQL catalogs, support the basic set + /// of types via information hardcoded inside Npgsql. + /// + NoTypeLoading, +} diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlCube.cs b/src/Npgsql/NpgsqlTypes/NpgsqlCube.cs new file mode 100644 index 0000000000..b84953c483 --- /dev/null +++ b/src/Npgsql/NpgsqlTypes/NpgsqlCube.cs @@ -0,0 +1,251 @@ +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Text; + +// ReSharper disable once CheckNamespace +namespace NpgsqlTypes; + +/// +/// Represents a PostgreSQL cube data type. +/// +/// +/// See https://www.postgresql.org/docs/current/cube.html +/// +public readonly struct NpgsqlCube : IEquatable +{ + // Store the coordinates as a value tuple array + readonly double[] _lowerLeft; + readonly double[] _upperRight; + + /// + /// The lower left coordinates of the cube. + /// + public IReadOnlyList LowerLeft => _lowerLeft; + + /// + /// The upper right coordinates of the cube. + /// + public IReadOnlyList UpperRight => _upperRight; + + /// + /// The number of dimensions of the cube. + /// + public int Dimensions => _lowerLeft.Length; + + /// + /// True if the cube is a point, that is, the two defining corners are the same. + /// + public bool IsPoint { get; } + + /// + /// Makes a cube with upper right and lower left coordinates as defined by the two arrays, which must be of the same length. + /// + /// This is an internal constructor to optimize the number of allocations. + /// The lower left values. + /// The upper right values. + /// + /// Thrown if the number of dimensions in the upper left and lower right values do not match. 
+ /// + internal NpgsqlCube(double[] lowerLeft, double[] upperRight) + { + if (lowerLeft.Length != upperRight.Length) + throw new ArgumentException($"Not a valid cube: Different point dimensions in {lowerLeft} and {upperRight}."); + + IsPoint = lowerLeft.SequenceEqual(upperRight); + _lowerLeft = lowerLeft; + _upperRight = upperRight; + } + + /// + /// Makes a one dimensional cube with both coordinates the same. + /// + /// The point coordinate. + public NpgsqlCube(double coord) + { + IsPoint = true; + _lowerLeft = [coord]; + _upperRight = _lowerLeft; + } + + /// + /// Makes a one dimensional cube. + /// + /// The lower left value. + /// The upper right value. + public NpgsqlCube(double lowerLeft, double upperRight) + { + IsPoint = lowerLeft.CompareTo(upperRight) == 0; + _lowerLeft = [lowerLeft]; + _upperRight = IsPoint ? _lowerLeft : [upperRight]; + } + + /// + /// Makes a zero-volume cube using the coordinates defined by the array. + /// + /// The coordinates. + public NpgsqlCube(IEnumerable coords) + { + // Always create a defensive copy to prevent external mutation + _lowerLeft = coords.ToArray(); + IsPoint = true; + _upperRight = _lowerLeft; + } + + /// + /// Makes a cube with upper right and lower left coordinates as defined by the two arrays, which must be of the same length. + /// + /// The lower left values. + /// The upper right values. + /// + /// Thrown if the number of dimensions in the upper left and lower right values do not match + /// or if the cube exceeds the maximum dimensions (100). + /// + public NpgsqlCube(IEnumerable lowerLeft, IEnumerable upperRight) : + this(lowerLeft.ToArray(), upperRight.ToArray()) + { } + + /// + /// Makes a new cube by adding a dimension on to an existing cube, with the same values for both endpoints of the new coordinate. + /// This is useful for building cubes piece by piece from calculated values. + /// + /// The existing cube. + /// The coordinate to add. 
+ public NpgsqlCube(NpgsqlCube cube, double coord) + { + IsPoint = cube.IsPoint; + if (IsPoint) + { + _lowerLeft = cube._lowerLeft.Append(coord).ToArray(); + _upperRight = _lowerLeft; + } + else + { + _lowerLeft = cube._lowerLeft.Append(coord).ToArray(); + _upperRight = cube._upperRight.Append(coord).ToArray(); + } + } + + /// + /// Makes a new cube by adding a dimension on to an existing cube. + /// This is useful for building cubes piece by piece from calculated values. + /// + /// The existing cube. + /// The lower left value. + /// The upper right value. + public NpgsqlCube(NpgsqlCube cube, double lowerLeft, double upperRight) + { + IsPoint = cube.IsPoint && lowerLeft.CompareTo(upperRight) == 0; + if (IsPoint) + { + _lowerLeft = cube._lowerLeft.Append(lowerLeft).ToArray(); + _upperRight = _lowerLeft; + } + else + { + _lowerLeft = cube._lowerLeft.Append(lowerLeft).ToArray(); + _upperRight = cube._upperRight.Append(upperRight).ToArray(); + } + } + + /// + /// Makes a new cube from an existing cube, using a list of dimension indexes from an array. + /// Can be used to extract the endpoints of a single dimension, or to drop dimensions, or to reorder them as desired. + /// + /// The list of dimension indexes. + /// A new cube. 
+ /// + /// + /// var cube = new NpgsqlCube(new[] { 1, 3, 5 }, new[] { 6, 7, 8 }); // '(1,3,5),(6,7,8)' + /// cube.ToSubset(1); // '(3),(7)' + /// cube.ToSubset(2, 1, 0, 0); // '(5,3,1,1),(8,7,6,6)' + /// + /// + public NpgsqlCube ToSubset(params int[] indexes) + { + var lowerLeft = new double[indexes.Length]; + var upperRight = new double[indexes.Length]; + + for (var i = 0; i < indexes.Length; i++) + { + lowerLeft[i] = _lowerLeft[indexes[i]]; + upperRight[i] = _upperRight[indexes[i]]; + } + + return new NpgsqlCube(lowerLeft, upperRight); + } + + /// + public bool Equals(NpgsqlCube other) => Dimensions == other.Dimensions + && _lowerLeft.SequenceEqual(other._lowerLeft) + && _upperRight.SequenceEqual(other._upperRight); + + /// + public override bool Equals(object? obj) => obj is NpgsqlCube other && Equals(other); + + /// + public static bool operator ==(NpgsqlCube x, NpgsqlCube y) => x.Equals(y); + + /// + public static bool operator !=(NpgsqlCube x, NpgsqlCube y) => !(x == y); + + /// + public override int GetHashCode() + { + var hashCode = new HashCode(); + for (var i = 0; i < Dimensions; i++) + { + hashCode.Add(_lowerLeft[i]); + hashCode.Add(_upperRight[i]); + } + return hashCode.ToHashCode(); + } + + /// + /// Writes the cube in PostgreSQL's text format. 
+ /// + void Write(StringBuilder stringBuilder) + { + var leftBuilder = new StringBuilder(); + var rightBuilder = new StringBuilder(); + + leftBuilder.Append('('); + rightBuilder.Append('('); + + for (var i = 0; i < Dimensions; i++) + { + leftBuilder.Append(CultureInfo.InvariantCulture, $"{_lowerLeft[i]:G17}"); + rightBuilder.Append(CultureInfo.InvariantCulture, $"{_upperRight[i]:G17}"); + + if (i >= Dimensions - 1) continue; + + leftBuilder.Append(", "); + rightBuilder.Append(", "); + } + + leftBuilder.Append(')'); + rightBuilder.Append(')'); + + if (IsPoint) + { + stringBuilder.Append(leftBuilder); + } + else + { + stringBuilder.Append(leftBuilder); + stringBuilder.Append(','); + stringBuilder.Append(rightBuilder); + } + } + + /// + /// Writes the cube in PostgreSQL's text format. + /// + public override string ToString() + { + var sb = new StringBuilder(); + Write(sb); + return sb.ToString(); + } +} diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs b/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs index 687ebf16b7..ab8d1480f8 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlDbType.cs @@ -123,6 +123,12 @@ public enum NpgsqlDbType /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html Polygon = 16, + /// + /// Corresponds to the PostgreSQL "cube" type, a geometric type representing multi-dimensional cubes. + /// + /// See https://www.postgresql.org/docs/current/cube.html + Cube = 63, // Extension type + #endregion #region Character Types @@ -627,10 +633,6 @@ public static DbType ToDbType(this NpgsqlDbType npgsqlDbType) NpgsqlDbType.Char => DbType.String, NpgsqlDbType.Name => DbType.String, NpgsqlDbType.Citext => DbType.String, - NpgsqlDbType.Refcursor => DbType.Object, - NpgsqlDbType.Jsonb => DbType.Object, - NpgsqlDbType.Json => DbType.Object, - NpgsqlDbType.JsonPath => DbType.Object, // Date/time types NpgsqlDbType.Timestamp => LegacyTimestampBehavior ? 
DbType.DateTime : DbType.DateTime2, @@ -643,8 +645,6 @@ public static DbType ToDbType(this NpgsqlDbType npgsqlDbType) NpgsqlDbType.Boolean => DbType.Boolean, NpgsqlDbType.Uuid => DbType.Guid, - NpgsqlDbType.Unknown => DbType.Object, - _ => DbType.Object }; @@ -740,6 +740,7 @@ public static DbType ToDbType(this NpgsqlDbType npgsqlDbType) // Plugin types NpgsqlDbType.Citext => "citext", + NpgsqlDbType.Cube => "cube", NpgsqlDbType.LQuery => "lquery", NpgsqlDbType.LTree => "ltree", NpgsqlDbType.LTxtQuery => "ltxtquery", @@ -869,11 +870,11 @@ _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) internal static NpgsqlDbType? ToNpgsqlDbType(this DataTypeName dataTypeName) => ToNpgsqlDbType(dataTypeName.UnqualifiedName); /// Should not be used with display names, first normalize it instead. - internal static NpgsqlDbType? ToNpgsqlDbType(string dataTypeName) + internal static NpgsqlDbType? ToNpgsqlDbType(string normalizedDataTypeName) { - var unqualifiedName = dataTypeName; - if (dataTypeName.IndexOf(".", StringComparison.Ordinal) is not -1 and var index) - unqualifiedName = dataTypeName.Substring(0, index); + var unqualifiedName = normalizedDataTypeName.AsSpan(); + if (unqualifiedName.IndexOf('.') is not -1 and var index) + unqualifiedName = unqualifiedName.Slice(index + 1); return unqualifiedName switch { @@ -964,6 +965,7 @@ _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) // Plugin types "citext" => NpgsqlDbType.Citext, + "cube" => NpgsqlDbType.Cube, "lquery" => NpgsqlDbType.LQuery, "ltree" => NpgsqlDbType.LTree, "ltxtquery" => NpgsqlDbType.LTxtQuery, @@ -971,12 +973,12 @@ _ when npgsqlDbType.HasFlag(NpgsqlDbType.Multirange) "geometry" => NpgsqlDbType.Geometry, "geography" => NpgsqlDbType.Geography, - _ when unqualifiedName.Contains("unknown") + _ when unqualifiedName.IndexOf("unknown") != -1 => !unqualifiedName.StartsWith("_", StringComparison.Ordinal) ? 
NpgsqlDbType.Unknown : null, _ when unqualifiedName.StartsWith("_", StringComparison.Ordinal) - => ToNpgsqlDbType(unqualifiedName.Substring(1)) is { } elementNpgsqlDbType + => ToNpgsqlDbType(unqualifiedName.Slice(1).ToString()) is { } elementNpgsqlDbType ? elementNpgsqlDbType | NpgsqlDbType.Array : null, // e.g. custom ranges, plugin types etc. diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlRange.cs b/src/Npgsql/NpgsqlTypes/NpgsqlRange.cs index c260202ce9..23b2578c13 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlRange.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlRange.cs @@ -378,8 +378,7 @@ public override string ToString() [RequiresUnreferencedCode("Parse implementations for certain types of T may require members that have been trimmed.")] public static NpgsqlRange Parse(string value) { - if (value is null) - throw new ArgumentNullException(nameof(value)); + ArgumentNullException.ThrowIfNull(value); value = value.Trim(); @@ -395,8 +394,8 @@ public static NpgsqlRange Parse(string value) if (!lowerInclusive && !lowerExclusive) throw new FormatException("Malformed range literal. Missing left parenthesis or bracket."); - var upperInclusive = value[value.Length - 1] == UpperInclusiveBound; - var upperExclusive = value[value.Length - 1] == UpperExclusiveBound; + var upperInclusive = value[^1] == UpperInclusiveBound; + var upperExclusive = value[^1] == UpperExclusiveBound; if (!upperInclusive && !upperExclusive) throw new FormatException("Malformed range literal. Missing right parenthesis or bracket."); diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs index eb1dfbb86b..96585832f3 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlTsQuery.cs @@ -79,8 +79,7 @@ public override string ToString() [Obsolete("Client-side parsing of NpgsqlTsQuery is unreliable and cannot fully duplicate the PostgreSQL logic. Use PG functions instead (e.g. 
to_tsquery)")] public static NpgsqlTsQuery Parse(string value) { - if (value == null) - throw new ArgumentNullException(nameof(value)); + ArgumentNullException.ThrowIfNull(value); var valStack = new Stack(); var opStack = new Stack(); @@ -380,16 +379,10 @@ public override bool Equals(object? obj) => left is null ? right is not null : !left.Equals(right); } -readonly struct NpgsqlTsQueryOperator +readonly struct NpgsqlTsQueryOperator(char character, short followedByDistance) { - public readonly char Char; - public readonly short FollowedByDistance; - - public NpgsqlTsQueryOperator(char character, short followedByDistance) - { - Char = character; - FollowedByDistance = followedByDistance; - } + public readonly char Char = character; + public readonly short FollowedByDistance = followedByDistance; public static implicit operator NpgsqlTsQueryOperator(char c) => new(c, 0); public static implicit operator char(NpgsqlTsQueryOperator o) => o.Char; @@ -410,8 +403,7 @@ public string Text get => _text; set { - if (string.IsNullOrEmpty(value)) - throw new ArgumentException("Text is null or empty string", nameof(value)); + ArgumentException.ThrowIfNullOrEmpty(value); _text = value; } @@ -539,9 +531,7 @@ public sealed class NpgsqlTsQueryNot : NpgsqlTsQuery /// public NpgsqlTsQueryNot(NpgsqlTsQuery child) : base(NodeKind.Not) - { - Child = child; - } + => Child = child; internal override void WriteCore(StringBuilder sb, bool first = false) { @@ -683,8 +673,7 @@ public NpgsqlTsQueryFollowedBy( NpgsqlTsQuery right) : base(NodeKind.Phrase, left, right) { - if (distance < 0) - throw new ArgumentOutOfRangeException(nameof(distance)); + ArgumentOutOfRangeException.ThrowIfNegative(distance); Distance = distance; } @@ -721,7 +710,7 @@ public override int GetHashCode() } /// -/// Represents an empty tsquery. Shold only be used as top node. +/// Represents an empty tsquery. Should only be used as top node. 
/// public sealed class NpgsqlTsQueryEmpty : NpgsqlTsQuery { diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTsVector.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTsVector.cs index 2ec4c66afe..4dd1e28b08 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTsVector.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlTsVector.cs @@ -11,6 +11,11 @@ namespace NpgsqlTypes; /// public sealed class NpgsqlTsVector : IEnumerable, IEquatable { + /// + /// Represents an empty tsvector. + /// + public static readonly NpgsqlTsVector Empty = new NpgsqlTsVector([], noCheck: true); + readonly List _lexemes; internal NpgsqlTsVector(List lexemes, bool noCheck = false) @@ -21,7 +26,7 @@ internal NpgsqlTsVector(List lexemes, bool noCheck = false) return; } - _lexemes = new List(lexemes); + _lexemes = [..lexemes]; if (_lexemes.Count == 0) return; @@ -76,8 +81,7 @@ internal NpgsqlTsVector(List lexemes, bool noCheck = false) [Obsolete("Client-side parsing of NpgsqlTsVector is unreliable and cannot fully duplicate the PostgreSQL logic. Use PG functions instead (e.g. to_tsvector)")] public static NpgsqlTsVector Parse(string value) { - if (value == null) - throw new ArgumentNullException(nameof(value)); + ArgumentNullException.ThrowIfNull(value); var lexemes = new List(); var pos = 0; @@ -167,7 +171,7 @@ public static NpgsqlTsVector Parse(string value) goto WaitWord; StartPosInfo: - wordEntryPositions = new List(); + wordEntryPositions = []; InPosInfo: var digitPos = pos; @@ -189,7 +193,7 @@ public static NpgsqlTsVector Parse(string value) if (value[pos] >= 'B' && value[pos] <= 'D' || value[pos] >= 'b' && value[pos] <= 'd') { var weight = value[pos]; - if (weight >= 'b' && weight <= 'd') + if (weight is >= 'b' and <= 'd') weight = (char)(weight - ('b' - 'B')); wordEntryPositions.Add(new Lexeme.WordEntryPos(wordPos, Lexeme.Weight.D + ('D' - weight))); pos++; @@ -321,7 +325,7 @@ internal Lexeme(string text, List? wordEntryPositions, bool noCopy { Text = text; if (wordEntryPositions != null) - WordEntryPositions = noCopy ? 
wordEntryPositions : new List(wordEntryPositions); + WordEntryPositions = noCopy ? wordEntryPositions : [..wordEntryPositions]; else WordEntryPositions = null; } @@ -343,7 +347,7 @@ internal Lexeme(string text, List? wordEntryPositions, bool noCopy return list; // Don't change the original list, as the user might inspect it later if he holds a reference to the lexeme's list - list = new List(list); + list = [..list]; list.Sort((x, y) => x.Pos.CompareTo(y.Pos)); @@ -414,9 +418,7 @@ public struct WordEntryPos : IEquatable internal short Value { get; } internal WordEntryPos(short value) - { - Value = value; - } + => Value = value; /// /// Creates a WordEntryPos with a given position and weight. diff --git a/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs b/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs index b16fe8ccea..4f63a9defb 100644 --- a/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs +++ b/src/Npgsql/NpgsqlTypes/NpgsqlTypes.cs @@ -17,17 +17,10 @@ namespace NpgsqlTypes; /// /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html /// -public struct NpgsqlPoint : IEquatable +public struct NpgsqlPoint(double x, double y) : IEquatable { - public double X { get; set; } - public double Y { get; set; } - - public NpgsqlPoint(double x, double y) - : this() - { - X = x; - Y = y; - } + public double X { get; set; } = x; + public double Y { get; set; } = y; // ReSharper disable CompareOfFloatsByEqualityOperator public bool Equals(NpgsqlPoint other) => X == other.X && Y == other.Y; @@ -45,6 +38,8 @@ public override int GetHashCode() public override string ToString() => string.Format(CultureInfo.InvariantCulture, "({0},{1})", X, Y); + + public void Deconstruct(out double x, out double y) => (x, y) = (X, Y); } /// @@ -53,19 +48,11 @@ public override string ToString() /// /// See https://www.postgresql.org/docs/current/static/datatype-geometric.html /// -public struct NpgsqlLine : IEquatable +public struct NpgsqlLine(double a, double b, double c) : IEquatable { - public double A { 
get; set; } - public double B { get; set; } - public double C { get; set; } - - public NpgsqlLine(double a, double b, double c) - : this() - { - A = a; - B = b; - C = c; - } + public double A { get; set; } = a; + public double B { get; set; } = b; + public double C { get; set; } = c; public override string ToString() => string.Format(CultureInfo.InvariantCulture, "{{{0},{1},{2}}}", A, B, C); @@ -81,6 +68,8 @@ public override bool Equals(object? obj) public static bool operator ==(NpgsqlLine x, NpgsqlLine y) => x.Equals(y); public static bool operator !=(NpgsqlLine x, NpgsqlLine y) => !(x == y); + + public void Deconstruct(out double a, out double b, out double c) => (a, b, c) = (A, B, C); } /// @@ -118,6 +107,8 @@ public override bool Equals(object? obj) public static bool operator ==(NpgsqlLSeg x, NpgsqlLSeg y) => x.Equals(y); public static bool operator !=(NpgsqlLSeg x, NpgsqlLSeg y) => !(x == y); + + public void Deconstruct(out NpgsqlPoint start, out NpgsqlPoint end) => (start, end) = (Start, End); } /// @@ -193,6 +184,30 @@ void NormalizeBox() if (_upperRight.Y < _lowerLeft.Y) (_upperRight.Y, _lowerLeft.Y) = (_lowerLeft.Y, _upperRight.Y); } + + public void Deconstruct(out NpgsqlPoint lowerLeft, out NpgsqlPoint upperRight) + { + lowerLeft = LowerLeft; + upperRight = UpperRight; + } + + public void Deconstruct(out double left, out double right, out double bottom, out double top) + { + left = Left; + right = Right; + bottom = Bottom; + top = Top; + } + + public void Deconstruct(out double left, out double right, out double bottom, out double top, out double width, out double height) + { + left = Left; + right = Right; + bottom = Bottom; + top = Top; + width = Width; + height = Height; + } } /// @@ -200,15 +215,18 @@ void NormalizeBox() /// public struct NpgsqlPath : IList, IEquatable { - readonly List _points; + List _points; + + List Points => _points ??= []; + public bool Open { get; set; } public NpgsqlPath() - => _points = new(); + => _points = []; public 
NpgsqlPath(IEnumerable points, bool open) { - _points = new List(points); + _points = [..points]; Open = open; } @@ -217,7 +235,7 @@ public NpgsqlPath(params NpgsqlPoint[] points) : this(points, false) {} public NpgsqlPath(bool open) : this() { - _points = new List(); + _points = []; Open = open; } @@ -231,23 +249,23 @@ public NpgsqlPath(int capacity) : this(capacity, false) {} public NpgsqlPoint this[int index] { - get => _points[index]; - set => _points[index] = value; + get => Points[index]; + set => Points[index] = value; } - public int Capacity => _points.Capacity; - public int Count => _points.Count; + public int Capacity => Points.Capacity; + public int Count => _points?.Count ?? 0; public bool IsReadOnly => false; - public int IndexOf(NpgsqlPoint item) => _points.IndexOf(item); - public void Insert(int index, NpgsqlPoint item) => _points.Insert(index, item); - public void RemoveAt(int index) => _points.RemoveAt(index); - public void Add(NpgsqlPoint item) => _points.Add(item); - public void Clear() => _points.Clear(); - public bool Contains(NpgsqlPoint item) => _points.Contains(item); - public void CopyTo(NpgsqlPoint[] array, int arrayIndex) => _points.CopyTo(array, arrayIndex); - public bool Remove(NpgsqlPoint item) => _points.Remove(item); - public IEnumerator GetEnumerator() => _points.GetEnumerator(); + public int IndexOf(NpgsqlPoint item) => Points.IndexOf(item); + public void Insert(int index, NpgsqlPoint item) => Points.Insert(index, item); + public void RemoveAt(int index) => Points.RemoveAt(index); + public void Add(NpgsqlPoint item) => Points.Add(item); + public void Clear() => Points.Clear(); + public bool Contains(NpgsqlPoint item) => Points.Contains(item); + public void CopyTo(NpgsqlPoint[] array, int arrayIndex) => Points.CopyTo(array, arrayIndex); + public bool Remove(NpgsqlPoint item) => Points.Remove(item); + public IEnumerator GetEnumerator() => Points.GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); public bool 
Equals(NpgsqlPath other) @@ -287,12 +305,12 @@ public override string ToString() var sb = new StringBuilder(); sb.Append(Open ? '[' : '('); int i; - for (i = 0; i < _points.Count; i++) + for (i = 0; i < Count; i++) { var p = _points[i]; sb.AppendFormat(CultureInfo.InvariantCulture, "({0},{1})", p.X, p.Y); if (i < _points.Count - 1) - sb.Append(","); + sb.Append(','); } sb.Append(Open ? ']' : ')'); return sb.ToString(); @@ -302,15 +320,17 @@ public override string ToString() /// /// Represents a PostgreSQL Polygon type. /// -public readonly struct NpgsqlPolygon : IList, IEquatable +public struct NpgsqlPolygon : IList, IEquatable { - readonly List _points; + List _points; + + List Points => _points ??= []; public NpgsqlPolygon() - => _points = new(); + => _points = []; public NpgsqlPolygon(IEnumerable points) - => _points = new List(points); + => _points = [..points]; public NpgsqlPolygon(params NpgsqlPoint[] points) : this((IEnumerable) points) {} @@ -319,23 +339,23 @@ public NpgsqlPolygon(int capacity) public NpgsqlPoint this[int index] { - get => _points[index]; - set => _points[index] = value; + get => Points[index]; + set => Points[index] = value; } - public int Capacity => _points.Capacity; - public int Count => _points.Count; + public int Capacity => Points.Capacity; + public int Count => _points?.Count ?? 
0; public bool IsReadOnly => false; - public int IndexOf(NpgsqlPoint item) => _points.IndexOf(item); - public void Insert(int index, NpgsqlPoint item) => _points.Insert(index, item); - public void RemoveAt(int index) => _points.RemoveAt(index); - public void Add(NpgsqlPoint item) => _points.Add(item); - public void Clear() => _points.Clear(); - public bool Contains(NpgsqlPoint item) => _points.Contains(item); - public void CopyTo(NpgsqlPoint[] array, int arrayIndex) => _points.CopyTo(array, arrayIndex); - public bool Remove(NpgsqlPoint item) => _points.Remove(item); - public IEnumerator GetEnumerator() => _points.GetEnumerator(); + public int IndexOf(NpgsqlPoint item) => Points.IndexOf(item); + public void Insert(int index, NpgsqlPoint item) => Points.Insert(index, item); + public void RemoveAt(int index) => Points.RemoveAt(index); + public void Add(NpgsqlPoint item) => Points.Add(item); + public void Clear() => Points.Clear(); + public bool Contains(NpgsqlPoint item) => Points.Contains(item); + public void CopyTo(NpgsqlPoint[] array, int arrayIndex) => Points.CopyTo(array, arrayIndex); + public bool Remove(NpgsqlPoint item) => Points.Remove(item); + public IEnumerator GetEnumerator() => Points.GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); public bool Equals(NpgsqlPolygon other) @@ -374,7 +394,7 @@ public override string ToString() var sb = new StringBuilder(); sb.Append('('); int i; - for (i = 0; i < _points.Count; i++) + for (i = 0; i < Count; i++) { var p = _points[i]; sb.AppendFormat(CultureInfo.InvariantCulture, "({0},{1})", p.X, p.Y); @@ -390,25 +410,15 @@ public override string ToString() /// /// Represents a PostgreSQL Circle type. 
/// -public struct NpgsqlCircle : IEquatable +public struct NpgsqlCircle(double x, double y, double radius) : IEquatable { - public double X { get; set; } - public double Y { get; set; } - public double Radius { get; set; } + public double X { get; set; } = x; + public double Y { get; set; } = y; + public double Radius { get; set; } = radius; public NpgsqlCircle(NpgsqlPoint center, double radius) - : this() - { - X = center.X; - Y = center.Y; - Radius = radius; - } - - public NpgsqlCircle(double x, double y, double radius) : this() + : this(center.X, center.Y, radius) { - X = x; - Y = y; - Radius = radius; } public NpgsqlPoint Center @@ -433,6 +443,19 @@ public override string ToString() public override int GetHashCode() => HashCode.Combine(X, Y, Radius); + + public void Deconstruct(out double x, out double y, out double radius) + { + x = X; + y = Y; + radius = Radius; + } + + public void Deconstruct(out NpgsqlPoint center, out double radius) + { + center = Center; + radius = Radius; + } } /// @@ -448,9 +471,7 @@ public readonly record struct NpgsqlInet public NpgsqlInet(IPAddress address, byte netmask) { - if (address.AddressFamily != AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) - throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); - + CheckAddressFamily(address); Address = address; Netmask = netmask; } @@ -461,16 +482,27 @@ public NpgsqlInet(IPAddress address) } public NpgsqlInet(string addr) - => (Address, Netmask) = addr.Split('/') switch + { + switch (addr.Split('/')) { - { Length: 2 } segments => (IPAddress.Parse(segments[0]), byte.Parse(segments[1])), - { Length: 1 } segments => (IPAddress.Parse(segments[0]), (byte)32), - _ => throw new FormatException("Invalid number of parts in CIDR specification") - }; + case { Length: 2 } segments: + (Address, Netmask) = (IPAddress.Parse(segments[0]), byte.Parse(segments[1])); + break; + case { Length: 1 } 
segments: + var ipAddr = IPAddress.Parse(segments[0]); + CheckAddressFamily(ipAddr); + (Address, Netmask) = ( + ipAddr, + ipAddr.AddressFamily == AddressFamily.InterNetworkV6 ? (byte)128 : (byte)32); + break; + default: + throw new FormatException("Invalid number of parts in CIDR specification"); + } + } public override string ToString() - => (Address.AddressFamily == AddressFamily.InterNetwork && Netmask == 32) || - (Address.AddressFamily == AddressFamily.InterNetworkV6 && Netmask == 128) + => (Address?.AddressFamily == AddressFamily.InterNetwork && Netmask == 32) || + (Address?.AddressFamily == AddressFamily.InterNetworkV6 && Netmask == 128) ? Address.ToString() : $"{Address}/{Netmask}"; @@ -480,11 +512,24 @@ public static explicit operator IPAddress(NpgsqlInet inet) public static implicit operator NpgsqlInet(IPAddress ip) => new(ip); + public static implicit operator NpgsqlInet(IPNetwork cidr) + => new( + cidr.BaseAddress, + cidr.PrefixLength <= byte.MaxValue + ? (byte)cidr.PrefixLength + : throw new ArgumentOutOfRangeException(nameof(cidr), "IPNetwork.PrefixLength is too large to fit in a byte")); + public void Deconstruct(out IPAddress address, out byte netmask) { address = Address; netmask = Netmask; } + + static void CheckAddressFamily(IPAddress address) + { + if (address.AddressFamily != AddressFamily.InterNetwork && address.AddressFamily != AddressFamily.InterNetworkV6) + throw new ArgumentException("Only IPAddress of InterNetwork or InterNetworkV6 address families are accepted", nameof(address)); + } } /// @@ -493,6 +538,7 @@ public void Deconstruct(out IPAddress address, out byte netmask) /// /// https://www.postgresql.org/docs/current/static/datatype-net-types.html /// +[Obsolete("Use .NET IPNetwork instead of NpgsqlCidr to map to PostgreSQL cidr")] public readonly record struct NpgsqlCidr { public IPAddress Address { get; } @@ -537,23 +583,17 @@ public void Deconstruct(out IPAddress address, out byte netmask) /// /// 
https://www.postgresql.org/docs/current/static/datatype-oid.html /// -public readonly struct NpgsqlTid : IEquatable +public readonly struct NpgsqlTid(uint blockNumber, ushort offsetNumber) : IEquatable { /// /// Block number /// - public uint BlockNumber { get; } + public uint BlockNumber { get; } = blockNumber; /// /// Tuple index within block /// - public ushort OffsetNumber { get; } - - public NpgsqlTid(uint blockNumber, ushort offsetNumber) - { - BlockNumber = blockNumber; - OffsetNumber = offsetNumber; - } + public ushort OffsetNumber { get; } = offsetNumber; public bool Equals(NpgsqlTid other) => BlockNumber == other.BlockNumber && OffsetNumber == other.OffsetNumber; @@ -565,6 +605,12 @@ public override bool Equals(object? o) public static bool operator ==(NpgsqlTid left, NpgsqlTid right) => left.Equals(right); public static bool operator !=(NpgsqlTid left, NpgsqlTid right) => !(left == right); public override string ToString() => $"({BlockNumber},{OffsetNumber})"; + + public void Deconstruct(out uint blockNumber, out ushort offsetNumber) + { + blockNumber = BlockNumber; + offsetNumber = OffsetNumber; + } } #pragma warning restore 1591 diff --git a/src/Npgsql/PoolingDataSource.cs b/src/Npgsql/PoolingDataSource.cs index 192a86c052..4de3cfd928 100644 --- a/src/Npgsql/PoolingDataSource.cs +++ b/src/Npgsql/PoolingDataSource.cs @@ -5,7 +5,6 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; -using System.Transactions; using Microsoft.Extensions.Logging; using Npgsql.Internal; using Npgsql.Util; @@ -31,8 +30,6 @@ class PoolingDataSource : NpgsqlDataSource /// private protected readonly NpgsqlConnector?[] Connectors; - readonly NpgsqlMultiHostDataSource? _parentPool; - /// /// Reader side for the idle connector channel. Contains nulls in order to release waiting attempts after /// a connector has been physically closed/broken. 
@@ -76,18 +73,15 @@ internal sealed override (int Total, int Idle, int Busy) Statistics internal PoolingDataSource( NpgsqlConnectionStringBuilder settings, - NpgsqlDataSourceConfiguration dataSourceConfig, - NpgsqlMultiHostDataSource? parentPool = null) - : base(settings, dataSourceConfig) + NpgsqlDataSourceConfiguration dataSourceConfig) + : base(settings, dataSourceConfig, reportMetrics: true) { if (settings.MaxPoolSize < settings.MinPoolSize) throw new ArgumentException($"Connection can't have 'Max Pool Size' {settings.MaxPoolSize} under 'Min Pool Size' {settings.MinPoolSize}"); - _parentPool = parentPool; - // We enforce Max Pool Size, so no need to to create a bounded channel (which is less efficient) - // On the consuming side, we have the multiplexing write loop but also non-multiplexing Rents - // On the producing side, we have connections being released back into the pool (both multiplexing and not) + // On the consuming side, we have Rents + // On the producing side, we have connections being released back into the pool var idleChannel = Channel.CreateUnbounded(); _idleConnectorReader = idleChannel.Reader; IdleConnectorWriter = idleChannel.Writer; @@ -102,7 +96,8 @@ internal PoolingDataSource( if (connectionIdleLifetime < pruningSamplingInterval) throw new ArgumentException($"Connection can't have {nameof(settings.ConnectionIdleLifetime)} {connectionIdleLifetime} under {nameof(settings.ConnectionPruningInterval)} {pruningSamplingInterval}"); - _pruningTimer = new Timer(PruningTimerCallback, this, Timeout.Infinite, Timeout.Infinite); + using (ExecutionContext.SuppressFlow()) // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever + _pruningTimer = new Timer(PruningTimerCallback, this, Timeout.Infinite, Timeout.Infinite); _pruningSampleSize = DivideRoundingUp(settings.ConnectionIdleLifetime, settings.ConnectionPruningInterval); _pruningMedianIndex = DivideRoundingUp(_pruningSampleSize, 2) - 1; // - 1 to 
go from length to index _pruningSamplingInterval = pruningSamplingInterval; @@ -243,19 +238,12 @@ bool CheckIdleConnector([NotNullWhen(true)] NpgsqlConnector? connector) return false; } - // The connector directly references the data source type mapper into the connector, to protect it against changes by a concurrent + // The connector directly references the current reloadable state reference, to protect it against changes by a concurrent // ReloadTypes. We update them here before returning the connector from the pool. - Debug.Assert(SerializerOptions is not null); - Debug.Assert(DatabaseInfo is not null); - connector.SerializerOptions = SerializerOptions; - connector.DatabaseInfo = DatabaseInfo; + connector.ReloadableState = CurrentReloadableState; Debug.Assert(connector.State == ConnectorState.Ready, $"Got idle connector but {nameof(connector.State)} is {connector.State}"); - Debug.Assert(connector.CommandsInFlightCount == 0, - $"Got idle connector but {nameof(connector.CommandsInFlightCount)} is {connector.CommandsInFlightCount}"); - Debug.Assert(connector.MultiplexAsyncWritingLock == 0, - $"Got idle connector but {nameof(connector.MultiplexAsyncWritingLock)} is 1"); return true; } @@ -273,14 +261,10 @@ bool CheckIdleConnector([NotNullWhen(true)] NpgsqlConnector? connector) try { // We've managed to increase the open counter, open a physical connections. -#if NET7_0_OR_GREATER var startTime = Stopwatch.GetTimestamp(); -#endif var connector = new NpgsqlConnector(this, conn) { ClearCounter = _clearCounter }; await connector.Open(timeout, async, cancellationToken).ConfigureAwait(false); -#if NET7_0_OR_GREATER MetricsReporter.ReportConnectionCreateTime(Stopwatch.GetElapsedTime(startTime)); -#endif var i = 0; for (; i < MaxConnections; i++) @@ -322,8 +306,6 @@ bool CheckIdleConnector([NotNullWhen(true)] NpgsqlConnector? 
connector) internal sealed override void Return(NpgsqlConnector connector) { Debug.Assert(!connector.InTransaction); - Debug.Assert(connector.MultiplexAsyncWritingLock == 0 || connector.IsBroken || connector.IsClosed, - $"About to return multiplexing connector to the pool, but {nameof(connector.MultiplexAsyncWritingLock)} is {connector.MultiplexAsyncWritingLock}"); // If Clear/ClearAll has been been called since this connector was first opened, // throw it away. The same if it's broken (in which case CloseConnector is only @@ -340,7 +322,7 @@ internal sealed override void Return(NpgsqlConnector connector) Debug.Assert(written); } - internal override void Clear() + public override void Clear() { Interlocked.Increment(ref _clearCounter); @@ -400,11 +382,6 @@ void CloseConnector(NpgsqlConnector connector) UpdatePruningTimer(); } - internal override bool TryRemovePendingEnlistedConnector(NpgsqlConnector connector, Transaction transaction) - => _parentPool is null - ? base.TryRemovePendingEnlistedConnector(connector, transaction) - : _parentPool.TryRemovePendingEnlistedConnector(connector, transaction); - #region Pruning void UpdatePruningTimer() diff --git a/src/Npgsql/PostgresDatabaseInfo.cs b/src/Npgsql/PostgresDatabaseInfo.cs index 0d2397eac3..1c1b518a3f 100644 --- a/src/Npgsql/PostgresDatabaseInfo.cs +++ b/src/Npgsql/PostgresDatabaseInfo.cs @@ -44,19 +44,23 @@ class PostgresDatabaseInfo : NpgsqlDatabaseInfo /// List? _types; + bool? _isRedshift; + /// protected override IEnumerable GetTypes() => _types ?? (IEnumerable)Array.Empty(); /// /// The PostgreSQL version string as returned by the version() function. Populated during loading. /// - public string LongVersion { get; set; } = default!; + public string LongVersion { get; set; } = ""; /// /// True if the backend is Amazon Redshift; otherwise, false. 
/// - public bool IsRedshift { get; private set; } + public bool IsRedshift => _isRedshift ??= LongVersion.Contains("redshift", StringComparison.OrdinalIgnoreCase); + // Note that UNLISTEN is only needed for the reset message, but those don't get generated for Redshift anyway because e.g. DISCARD + // isn't supported there either. So the IsRedshift check isn't actually used, but is here for completeness. /// public override bool SupportsUnlisten => Version.IsGreaterOrEqual(6, 4) && !IsRedshift; @@ -97,10 +101,11 @@ internal async Task LoadPostgresInfo(NpgsqlConnector conn, NpgsqlTimeout timeout conn.PostgresParameters.TryGetValue("integer_datetimes", out var intDateTimes) && intDateTimes == "on"; - IsRedshift = conn.Settings.ServerCompatibilityMode == ServerCompatibilityMode.Redshift; _types = await LoadBackendTypes(conn, timeout, async).ConfigureAwait(false); } + const string BuiltinSchemaListSqlFragment = "'pg_catalog', 'information_schema', 'pg_toast'"; + /// /// Generates a raw SQL query string to select type information. /// @@ -111,7 +116,7 @@ internal async Task LoadPostgresInfo(NpgsqlConnector conn, NpgsqlTimeout timeout /// For arrays and ranges, join in the element OID and type (to filter out arrays of unhandled /// types). /// - static string GenerateLoadTypesQuery(bool withRange, bool withMultirange, bool loadTableComposites) + static string GenerateLoadTypesQuery(bool withRange, bool withMultirange, bool loadTableComposites, string? 
schemaListSqlFragment, bool hasTypeCategory) => $@" SELECT ns.nspname, t.oid, t.typname, t.typtype, t.typnotnull, t.elemtypoid FROM ( @@ -122,6 +127,7 @@ static string GenerateLoadTypesQuery(bool withRange, bool withMultirange, bool l typ.oid, typ.typnamespace, typ.typname, typ.typtype, typ.typrelid, typ.typnotnull, typ.relkind, elemtyp.oid AS elemtypoid, elemtyp.typname AS elemtypname, elemcls.relkind AS elemrelkind, CASE WHEN elemproc.proname='array_recv' THEN 'a' ELSE elemtyp.typtype END AS elemtyptype + {(hasTypeCategory ? ", typ.typcategory" : "")} FROM ( SELECT typ.oid, typnamespace, typname, typrelid, typnotnull, relkind, typelem AS elemoid, CASE WHEN proc.proname='array_recv' THEN 'a' ELSE typ.typtype END AS typtype, @@ -131,6 +137,7 @@ static string GenerateLoadTypesQuery(bool withRange, bool withMultirange, bool l {(withMultirange ? "WHEN typ.typtype='m' THEN (SELECT rngtypid FROM pg_range WHERE rngmultitypid = typ.oid)" : "")} WHEN typ.typtype='d' THEN typ.typbasetype END AS elemtypoid + {(hasTypeCategory ? ", typ.typcategory" : "")} FROM pg_type AS typ LEFT JOIN pg_class AS cls ON (cls.oid = typ.typrelid) LEFT JOIN pg_proc AS proc ON proc.oid = typ.typreceive @@ -142,25 +149,26 @@ LEFT JOIN pg_class AS elemcls ON (elemcls.oid = elemtyp.typrelid) ) AS t JOIN pg_namespace AS ns ON (ns.oid = typnamespace) WHERE - typtype IN ('b', 'r', 'm', 'e', 'd') OR -- Base, range, multirange, enum, domain - (typtype = 'c' AND {(loadTableComposites ? "ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')" : "relkind='c'")}) OR -- User-defined free-standing composites (not table composites) by default + {(schemaListSqlFragment is not null ? $"(ns.nspname IN ({schemaListSqlFragment}){(hasTypeCategory ? " OR typcategory = 'U'" : "" )}) AND " : "")} + (typtype IN ('b', 'r', 'm', 'e', 'd') OR -- Base, range, multirange, enum, domain + (typtype = 'c' AND {(loadTableComposites ? 
$"ns.nspname NOT IN ({BuiltinSchemaListSqlFragment})" : "relkind='c'")}) OR -- User-defined free-standing composites (not table composites) by default (typtype = 'p' AND typname IN ('record', 'void', 'unknown')) OR -- Some special supported pseudo-types (typtype = 'a' AND ( -- Array of... elemtyptype IN ('b', 'r', 'm', 'e', 'd') OR -- Array of base, range, multirange, enum, domain (elemtyptype = 'p' AND elemtypname IN ('record', 'void')) OR -- Arrays of special supported pseudo-types - (elemtyptype = 'c' AND {(loadTableComposites ? "ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')" : "elemrelkind='c'")}) -- Array of user-defined free-standing composites (not table composites) by default - )) + (elemtyptype = 'c' AND {(loadTableComposites ? $"ns.nspname NOT IN ({BuiltinSchemaListSqlFragment})" : "elemrelkind='c'")}) -- Array of user-defined free-standing composites (not table composites) by default + ))) ORDER BY CASE WHEN typtype IN ('b', 'e', 'p') THEN 0 -- First base types, enums, pseudo-types - WHEN typtype = 'r' THEN 1 -- Ranges after - WHEN typtype = 'm' THEN 2 -- Multiranges after - WHEN typtype = 'c' THEN 3 -- Composites after + WHEN typtype = 'c' THEN 1 -- Composites after (fields loaded later in 2nd pass) + WHEN typtype = 'r' THEN 2 -- Ranges after + WHEN typtype = 'm' THEN 3 -- Multiranges after WHEN typtype = 'd' AND elemtyptype <> 'a' THEN 4 -- Domains over non-arrays after WHEN typtype = 'a' THEN 5 -- Arrays after WHEN typtype = 'd' AND elemtyptype = 'a' THEN 6 -- Domains over arrays last END;"; - static string GenerateLoadCompositeTypesQuery(bool loadTableComposites) + static string GenerateLoadCompositeTypesQuery(bool loadTableComposites, string? 
schemaListSqlFragment) => $@" -- Load field definitions for (free-standing) composite types SELECT typ.oid, att.attname, att.atttypid @@ -169,17 +177,20 @@ JOIN pg_namespace AS ns ON (ns.oid = typ.typnamespace) JOIN pg_class AS cls ON (cls.oid = typ.typrelid) JOIN pg_attribute AS att ON (att.attrelid = typ.typrelid) WHERE - (typ.typtype = 'c' AND {(loadTableComposites ? "ns.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')" : "cls.relkind='c'")}) AND + (typ.typtype = 'c' AND {(loadTableComposites ? $"ns.nspname NOT IN ({BuiltinSchemaListSqlFragment})" : "cls.relkind='c'")}) AND + {(schemaListSqlFragment is not null ? $"(ns.nspname IN ({schemaListSqlFragment})) AND " : "")} attnum > 0 AND -- Don't load system attributes NOT attisdropped ORDER BY typ.oid, att.attnum;"; - static string GenerateLoadEnumFieldsQuery(bool withEnumSortOrder) + static string GenerateLoadEnumFieldsQuery(bool withEnumSortOrder, string? schemaListSqlFragment) => $@" -- Load enum fields -SELECT pg_type.oid, enumlabel +SELECT typ.oid, enumlabel FROM pg_enum -JOIN pg_type ON pg_type.oid=enumtypid +JOIN pg_type AS typ ON typ.oid = enumtypid +JOIN pg_namespace AS ns ON ns.oid = typ.typnamespace +{(schemaListSqlFragment is not null ? $"WHERE (ns.nspname IN ({schemaListSqlFragment}))" : "")} ORDER BY oid{(withEnumSortOrder ? 
", enumsortorder" : "")};"; /// @@ -196,10 +207,31 @@ FROM pg_enum internal async Task> LoadBackendTypes(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) { var versionQuery = "SELECT version();"; - var loadTypesQuery = GenerateLoadTypesQuery(SupportsRangeTypes, SupportsMultirangeTypes, conn.Settings.LoadTableComposites); - var loadCompositeTypesQuery = GenerateLoadCompositeTypesQuery(conn.Settings.LoadTableComposites); + var typeLoading = conn.DataSource.Configuration.TypeLoading; + var loadTableComposites = typeLoading.LoadTableComposites; + + // Escape the schemas configured by the user, we need these as literals to be used in an IN() operator, and we cannot use parameters. + // Add an opening quote, escape any quotes in the schema, and add a closing quote. + string? schemaListSqlFragment = null; + if (typeLoading.TypeLoadingSchemas is not null) + { + var builder = new StringBuilder(BuiltinSchemaListSqlFragment); + for (var i = 0; i < typeLoading.TypeLoadingSchemas.Length; i++) + { + builder.Append(", "); + var schema = typeLoading.TypeLoadingSchemas[i]; + builder.Append('\''); + builder.Append(EscapeLiteral(schema)); + builder.Append('\''); + } + + schemaListSqlFragment = builder.ToString(); + } + + var loadTypesQuery = GenerateLoadTypesQuery(SupportsRangeTypes, SupportsMultirangeTypes, loadTableComposites, schemaListSqlFragment, HasTypeCategory); + var loadCompositeTypesQuery = GenerateLoadCompositeTypesQuery(loadTableComposites, schemaListSqlFragment); var loadEnumFieldsQuery = SupportsEnumTypes - ? GenerateLoadEnumFieldsQuery(HasEnumSortOrder) + ? 
GenerateLoadEnumFieldsQuery(HasEnumSortOrder, schemaListSqlFragment) : string.Empty; timeout.CheckAndApply(conn); @@ -319,6 +351,7 @@ static string SanitizeForReplicationConnection(string str) // Then load the types Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); IBackendMessage msg; + var unknownPostgresTypes = new List(); while (true) { msg = await conn.ReadMessage(async).ConfigureAwait(false); @@ -335,93 +368,32 @@ static string SanitizeForReplicationConnection(string str) var len = conn.ReadBuffer.ReadInt32(); var elemtypoid = len == -1 ? 0 : uint.Parse(conn.ReadBuffer.ReadString(len), NumberFormatInfo.InvariantInfo); - switch (typtype) - { - case 'b': // Normal base type - var baseType = new PostgresBaseType(nspname, typname, oid); - byOID[baseType.OID] = baseType; - continue; + var postgresTypeDefinition = new PostgresTypeDefinition(nspname, oid, typname, typtype, typnotnull, elemtypoid); + if (!TryAddPostgresType(postgresTypeDefinition, byOID)) + unknownPostgresTypes.Add(postgresTypeDefinition); + } - case 'a': // Array + while (unknownPostgresTypes.Count > 0) + { + var hasChanges = false; + for (var i = unknownPostgresTypes.Count - 1; i >= 0; i--) { - Debug.Assert(elemtypoid > 0); - if (!byOID.TryGetValue(elemtypoid, out var elementPostgresType)) + var unknownPostgresType = unknownPostgresTypes[i]; + if (TryAddPostgresType(unknownPostgresType, byOID)) { - _connectionLogger.LogTrace("Array type '{ArrayTypeName}' refers to unknown element with OID {ElementTypeOID}, skipping", - typname, elemtypoid); - continue; + unknownPostgresTypes.RemoveAt(i); + hasChanges = true; } - - var arrayType = new PostgresArrayType(nspname, typname, oid, elementPostgresType); - byOID[arrayType.OID] = arrayType; - continue; } - case 'r': // Range + if (!hasChanges) { - Debug.Assert(elemtypoid > 0); - if (!byOID.TryGetValue(elemtypoid, out var subtypePostgresType)) - { - _connectionLogger.LogTrace("Range type '{RangeTypeName}' refers to unknown subtype with OID 
{ElementTypeOID}, skipping", - typname, elemtypoid); - continue; - } - - var rangeType = new PostgresRangeType(nspname, typname, oid, subtypePostgresType); - byOID[rangeType.OID] = rangeType; - continue; - } - - case 'm': // Multirange - Debug.Assert(elemtypoid > 0); - if (!byOID.TryGetValue(elemtypoid, out var type)) - { - _connectionLogger.LogTrace("Multirange type '{MultirangeTypeName}' refers to unknown range with OID {ElementTypeOID}, skipping", - typname, elemtypoid); - continue; - } - - if (type is not PostgresRangeType rangePostgresType) - { - _connectionLogger.LogTrace("Multirange type '{MultirangeTypeName}' refers to non-range type '{TypeName}', skipping", - typname, type.Name); - continue; - } - - var multirangeType = new PostgresMultirangeType(nspname, typname, oid, rangePostgresType); - byOID[multirangeType.OID] = multirangeType; - continue; - - case 'e': // Enum - var enumType = new PostgresEnumType(nspname, typname, oid); - byOID[enumType.OID] = enumType; - continue; - - case 'c': // Composite - var compositeType = new PostgresCompositeType(nspname, typname, oid); - byOID[compositeType.OID] = compositeType; - continue; - - case 'd': // Domain - Debug.Assert(elemtypoid > 0); - if (!byOID.TryGetValue(elemtypoid, out var basePostgresType)) - { - _connectionLogger.LogTrace("Domain type '{DomainTypeName}' refers to unknown base type with OID {ElementTypeOID}, skipping", - typname, elemtypoid); - continue; - } - - var domainType = new PostgresDomainType(nspname, typname, oid, basePostgresType, typnotnull); - byOID[domainType.OID] = domainType; - continue; - - case 'p': // pseudo-type (record, void) - goto case 'b'; // Hack this as a base type - - default: - throw new ArgumentOutOfRangeException($"Unknown typtype for type '{typname}' in pg_type: {typtype}"); + _connectionLogger.LogWarning("Unable to load '{UnknownTypeCount}' Postgres types while loading database info.", + unknownPostgresTypes.Count); + break; } } + Expect(msg, conn); if 
(isReplicationConnection) Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); @@ -541,9 +513,101 @@ static string SanitizeForReplicationConnection(string str) if (!isReplicationConnection) Expect(await conn.ReadMessage(async).ConfigureAwait(false), conn); - return new(byOID.Values); + return [..byOID.Values]; static string ReadNonNullableString(NpgsqlReadBuffer buffer) => buffer.ReadString(buffer.ReadInt32()); + + bool TryAddPostgresType(PostgresTypeDefinition postgresTypeDefinition, Dictionary byOID) + { + switch (postgresTypeDefinition.Type) + { + case 'b': // Normal base type + var baseType = new PostgresBaseType(postgresTypeDefinition.Namespace, postgresTypeDefinition.Name, postgresTypeDefinition.OID); + byOID[baseType.OID] = baseType; + return true; + + case 'a': // Array + { + Debug.Assert(postgresTypeDefinition.ElemTypeOID > 0); + if (!byOID.TryGetValue(postgresTypeDefinition.ElemTypeOID, out var elementPostgresType)) + { + _connectionLogger.LogTrace("Array type '{ArrayTypeName}' refers to unknown element with OID {ElementTypeOID}, skipping", + postgresTypeDefinition.Name, postgresTypeDefinition.ElemTypeOID); + return false; + } + + var arrayType = new PostgresArrayType(postgresTypeDefinition.Namespace, postgresTypeDefinition.Name, postgresTypeDefinition.OID, elementPostgresType); + byOID[arrayType.OID] = arrayType; + return true; + } + + case 'r': // Range + { + Debug.Assert(postgresTypeDefinition.ElemTypeOID > 0); + if (!byOID.TryGetValue(postgresTypeDefinition.ElemTypeOID, out var subtypePostgresType)) + { + _connectionLogger.LogTrace("Range type '{RangeTypeName}' refers to unknown subtype with OID {ElementTypeOID}, skipping", + postgresTypeDefinition.Name, postgresTypeDefinition.ElemTypeOID); + return false; + } + + var rangeType = new PostgresRangeType(postgresTypeDefinition.Namespace, postgresTypeDefinition.Name, postgresTypeDefinition.OID, subtypePostgresType); + byOID[rangeType.OID] = rangeType; + return true; + } + + case 'm': // 
Multirange + Debug.Assert(postgresTypeDefinition.ElemTypeOID > 0); + if (!byOID.TryGetValue(postgresTypeDefinition.ElemTypeOID, out var type)) + { + _connectionLogger.LogTrace("Multirange type '{MultirangeTypeName}' refers to unknown range with OID {ElementTypeOID}, skipping", + postgresTypeDefinition.Name, postgresTypeDefinition.ElemTypeOID); + return false; + } + + if (type is not PostgresRangeType rangePostgresType) + { + _connectionLogger.LogTrace("Multirange type '{MultirangeTypeName}' refers to non-range type '{TypeName}', skipping", + postgresTypeDefinition.Name, type.Name); + return false; + } + + var multirangeType = new PostgresMultirangeType(postgresTypeDefinition.Namespace, postgresTypeDefinition.Name, postgresTypeDefinition.OID, rangePostgresType); + byOID[multirangeType.OID] = multirangeType; + return true; + + case 'e': // Enum + var enumType = new PostgresEnumType(postgresTypeDefinition.Namespace, postgresTypeDefinition.Name, postgresTypeDefinition.OID); + byOID[enumType.OID] = enumType; + return true; + + case 'c': // Composite + var compositeType = new PostgresCompositeType(postgresTypeDefinition.Namespace, postgresTypeDefinition.Name, postgresTypeDefinition.OID); + byOID[compositeType.OID] = compositeType; + return true; + + case 'd': // Domain + Debug.Assert(postgresTypeDefinition.ElemTypeOID > 0); + if (!byOID.TryGetValue(postgresTypeDefinition.ElemTypeOID, out var basePostgresType)) + { + _connectionLogger.LogTrace("Domain type '{DomainTypeName}' refers to unknown base type with OID {ElementTypeOID}, skipping", + postgresTypeDefinition.Name, postgresTypeDefinition.ElemTypeOID); + return false; + } + + var domainType = new PostgresDomainType(postgresTypeDefinition.Namespace, postgresTypeDefinition.Name, postgresTypeDefinition.OID, basePostgresType, postgresTypeDefinition.NotNull); + byOID[domainType.OID] = domainType; + return true; + + case 'p': // pseudo-type (record, void) + goto case 'b'; // Hack this as a base type + + default: + throw new 
ArgumentOutOfRangeException($"Unknown typtype for type '{postgresTypeDefinition.Name}' in pg_type: {postgresTypeDefinition.Type}"); + } + } } } + +readonly record struct PostgresTypeDefinition(string Namespace, uint OID, string Name, char Type, bool NotNull, uint ElemTypeOID); diff --git a/src/Npgsql/PostgresEnvironment.cs b/src/Npgsql/PostgresEnvironment.cs index 69036601e5..3ba874ae4c 100644 --- a/src/Npgsql/PostgresEnvironment.cs +++ b/src/Npgsql/PostgresEnvironment.cs @@ -48,6 +48,14 @@ internal static string? SslCertRootDefault internal static string? TargetSessionAttributes => Environment.GetEnvironmentVariable("PGTARGETSESSIONATTRS"); + internal static string? SslNegotiation => Environment.GetEnvironmentVariable("PGSSLNEGOTIATION"); + + internal static string? GssEncryptionMode => Environment.GetEnvironmentVariable("PGGSSENCMODE"); + + internal static string? RequireAuth => Environment.GetEnvironmentVariable("PGREQUIREAUTH"); + + internal static string? AppName => Environment.GetEnvironmentVariable("PGAPPNAME"); + static string? GetHomeDir() => Environment.GetEnvironmentVariable(RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "APPDATA" : "HOME"); @@ -55,4 +63,4 @@ internal static string? SslCertRootDefault => GetHomeDir() is string homedir ? Path.Combine(homedir, RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? 
"postgresql" : ".postgresql") : null; -} \ No newline at end of file +} diff --git a/src/Npgsql/PostgresErrorCodes.cs b/src/Npgsql/PostgresErrorCodes.cs index afeadbf2c6..98d878e12b 100644 --- a/src/Npgsql/PostgresErrorCodes.cs +++ b/src/Npgsql/PostgresErrorCodes.cs @@ -466,15 +466,15 @@ public static class PostgresErrorCodes #endregion Class XX - Internal Error static readonly string[] CriticalFailureCodes = - { + [ "53", // Insufficient resources AdminShutdown, // Self explanatory CrashShutdown, // Self explanatory CannotConnectNow, // Database is starting up "58", // System errors, external to PG (server is dying) "F0", // Configuration file error - "XX", // Internal error (database is dying) - }; + "XX" // Internal error (database is dying) + ]; internal static bool IsCriticalFailure(PostgresException e, bool clusterError = true) { diff --git a/src/Npgsql/PostgresException.cs b/src/Npgsql/PostgresException.cs index a157f0ab87..c4ee7ec691 100644 --- a/src/Npgsql/PostgresException.cs +++ b/src/Npgsql/PostgresException.cs @@ -110,9 +110,7 @@ static string GetMessage(string sqlState, string messageText, int position, stri internal static PostgresException Load(NpgsqlReadBuffer buf, bool includeDetail, ILogger exceptionLogger) => new(ErrorOrNoticeMessage.Load(buf, includeDetail, exceptionLogger)); -#if NET8_0_OR_GREATER [Obsolete("This API supports obsolete formatter-based serialization. It should not be called or extended by application code.")] -#endif internal PostgresException(SerializationInfo info, StreamingContext context) : base(info, context) { @@ -143,9 +141,7 @@ internal PostgresException(SerializationInfo info, StreamingContext context) /// /// The to populate with data. /// The destination (see ) for this serialization. -#if NET8_0_OR_GREATER [Obsolete("This API supports obsolete formatter-based serialization. 
It should not be called or extended by application code.")] -#endif public override void GetObjectData(SerializationInfo info, StreamingContext context) { base.GetObjectData(info, context); diff --git a/src/Npgsql/PostgresMinimalDatabaseInfo.cs b/src/Npgsql/PostgresMinimalDatabaseInfo.cs index eb90453062..ed2ef15e81 100644 --- a/src/Npgsql/PostgresMinimalDatabaseInfo.cs +++ b/src/Npgsql/PostgresMinimalDatabaseInfo.cs @@ -11,7 +11,7 @@ sealed class PostgresMinimalDatabaseInfoFactory : INpgsqlDatabaseInfoFactory { public Task Load(NpgsqlConnector conn, NpgsqlTimeout timeout, bool async) => Task.FromResult( - conn.Settings.ServerCompatibilityMode == ServerCompatibilityMode.NoTypeLoading + !conn.DataSource.Configuration.TypeLoading.LoadTypes ? (NpgsqlDatabaseInfo)new PostgresMinimalDatabaseInfo(conn) : null); } @@ -118,10 +118,8 @@ protected override IEnumerable GetTypes() internal PostgresMinimalDatabaseInfo(NpgsqlConnector conn) : base(conn) - { - HasIntegerDateTimes = !conn.PostgresParameters.TryGetValue("integer_datetimes", out var intDateTimes) || - intDateTimes == "on"; - } + => HasIntegerDateTimes = !conn.PostgresParameters.TryGetValue("integer_datetimes", out var intDateTimes) || + intDateTimes == "on"; // TODO, split database info and type catalog. 
internal PostgresMinimalDatabaseInfo() diff --git a/src/Npgsql/PostgresNotice.cs b/src/Npgsql/PostgresNotice.cs index ef55ad4e13..6ed9c7f98d 100644 --- a/src/Npgsql/PostgresNotice.cs +++ b/src/Npgsql/PostgresNotice.cs @@ -198,7 +198,5 @@ public sealed class NpgsqlNoticeEventArgs : EventArgs public PostgresNotice Notice { get; } internal NpgsqlNoticeEventArgs(PostgresNotice notice) - { - Notice = notice; - } + => Notice = notice; } diff --git a/src/Npgsql/PostgresTypes/PostgresCompositeType.cs b/src/Npgsql/PostgresTypes/PostgresCompositeType.cs index 2d53199e6f..1663b01ebd 100644 --- a/src/Npgsql/PostgresTypes/PostgresCompositeType.cs +++ b/src/Npgsql/PostgresTypes/PostgresCompositeType.cs @@ -16,7 +16,7 @@ public class PostgresCompositeType : PostgresType /// public IReadOnlyList Fields => MutableFields; - internal List MutableFields { get; } = new(); + internal List MutableFields { get; } = []; /// /// Constructs a representation of a PostgreSQL array data type. diff --git a/src/Npgsql/PostgresTypes/PostgresEnumType.cs b/src/Npgsql/PostgresTypes/PostgresEnumType.cs index 7e4440252e..2422cb07a2 100644 --- a/src/Npgsql/PostgresTypes/PostgresEnumType.cs +++ b/src/Npgsql/PostgresTypes/PostgresEnumType.cs @@ -16,7 +16,7 @@ public class PostgresEnumType : PostgresType /// public IReadOnlyList Labels => MutableLabels; - internal List MutableLabels { get; } = new(); + internal List MutableLabels { get; } = []; /// /// Constructs a representation of a PostgreSQL enum data type. diff --git a/src/Npgsql/PostgresTypes/PostgresType.cs b/src/Npgsql/PostgresTypes/PostgresType.cs index 1182588c8c..842d1f3eea 100644 --- a/src/Npgsql/PostgresTypes/PostgresType.cs +++ b/src/Npgsql/PostgresTypes/PostgresType.cs @@ -24,7 +24,7 @@ public abstract class PostgresType /// The data type's OID. 
private protected PostgresType(string ns, string name, uint oid) { - DataTypeName = DataTypeName.FromDisplayName(name, ns); + DataTypeName = DataTypeName.FromDisplayName(name, ns, assumeUnqualified: true); OID = oid; FullName = Namespace + "." + Name; } diff --git a/src/Npgsql/PostgresTypes/PostgresUnknownType.cs b/src/Npgsql/PostgresTypes/PostgresUnknownType.cs index 2955295000..d7cfc983e9 100644 --- a/src/Npgsql/PostgresTypes/PostgresUnknownType.cs +++ b/src/Npgsql/PostgresTypes/PostgresUnknownType.cs @@ -1,4 +1,6 @@ -namespace Npgsql.PostgresTypes; +using Npgsql.Internal.Postgres; + +namespace Npgsql.PostgresTypes; /// /// Represents a PostgreSQL data type that isn't known to Npgsql and cannot be handled. @@ -10,5 +12,5 @@ public sealed class UnknownBackendType : PostgresType /// /// Constructs a the unknown backend type. /// - UnknownBackendType() : base("", "", 0) { } + UnknownBackendType() : base(DataTypeName.Unspecified,0) { } } diff --git a/src/Npgsql/PregeneratedMessages.cs b/src/Npgsql/PregeneratedMessages.cs index 54c736b64c..4e6434e9c1 100644 --- a/src/Npgsql/PregeneratedMessages.cs +++ b/src/Npgsql/PregeneratedMessages.cs @@ -26,7 +26,7 @@ internal static byte[] Generate(NpgsqlWriteBuffer buf, string query) { NpgsqlWriteBuffer.AssertASCIIOnly(query); - var queryByteLen = Encoding.ASCII.GetByteCount(query); + var queryByteLen = buf.TextEncoding.GetByteCount(query); buf.WriteByte(FrontendMessageCode.Query); buf.WriteInt32(4 + // Message length (including self excluding code) diff --git a/src/Npgsql/PreparedStatement.cs b/src/Npgsql/PreparedStatement.cs index f24905eb41..5a5a877eb2 100644 --- a/src/Npgsql/PreparedStatement.cs +++ b/src/Npgsql/PreparedStatement.cs @@ -26,7 +26,8 @@ sealed class PreparedStatement internal PreparedState State { get; set; } - internal bool IsPrepared => State == PreparedState.Prepared; + // Invalidated statement is still prepared and allocated on PG's side + internal bool IsPrepared => State is PreparedState.Prepared or 
PreparedState.Invalidated; /// /// If true, the user explicitly requested this statement be prepared. It does not get closed as part of @@ -83,7 +84,7 @@ internal void SetParamTypes(List parameters) { if (parameters.Count == 0) { - ConverterParamTypes = Array.Empty(); + ConverterParamTypes = []; return; } diff --git a/src/Npgsql/PreparedStatementManager.cs b/src/Npgsql/PreparedStatementManager.cs index ef72879c6d..8f80223753 100644 --- a/src/Npgsql/PreparedStatementManager.cs +++ b/src/Npgsql/PreparedStatementManager.cs @@ -17,8 +17,6 @@ sealed class PreparedStatementManager readonly PreparedStatement?[] _candidates; - static readonly List EmptyParameters = new(); - /// /// Total number of current prepared statements (whether explicit or automatic). /// @@ -63,7 +61,8 @@ internal PreparedStatementManager(NpgsqlConnector connector) if (BySql.TryGetValue(sql, out var pStatement)) { Debug.Assert(pStatement.State != PreparedState.Unprepared); - if (pStatement.IsExplicit) + // If statement is invalidated, fall through below where we replace it with another + if (pStatement.IsExplicit && pStatement.State != PreparedState.Invalidated) { // Great, we've found an explicit prepared statement. // We just need to check that the parameter types correspond, since prepared statements are @@ -80,8 +79,10 @@ internal PreparedStatementManager(NpgsqlConnector connector) // Found a candidate for autopreparation. Remove it and prepare explicitly. RemoveCandidate(pStatement); break; + // The statement is invalidated. Just replace it with a new one. + case PreparedState.Invalidated: + // The statement has already been autoprepared. We need to "promote" it to explicit. case PreparedState.Prepared: - // The statement has already been autoprepared. We need to "promote" it to explicit. statementBeingReplaced = pStatement; break; case PreparedState.Unprepared: @@ -98,161 +99,180 @@ internal PreparedStatementManager(NpgsqlConnector connector) internal PreparedStatement? 
TryGetAutoPrepared(NpgsqlBatchCommand batchCommand) { var sql = batchCommand.FinalCommandText!; - if (!BySql.TryGetValue(sql, out var pStatement)) + // We could also test for PreparedState.BeingPrepared as it's handled the exact same way as PreparedState.Prepared + // But since it's so rare we'll just go through the slow path + if (!BySql.TryGetValue(sql, out var pStatement) || pStatement.State != PreparedState.Prepared) + return TryGetAutoPreparedSlow(batchCommand, pStatement); + + // The statement has already been prepared (explicitly or automatically) + // We just need to check that the parameter types correspond, since prepared statements are + // only keyed by SQL (to prevent pointless allocations). If we have a mismatch, simply run unprepared. + if (!pStatement.DoParametersMatch(batchCommand.CurrentParametersReadOnly)) + return null; + // Prevent this statement from being replaced within this batch + pStatement.LastUsed = long.MaxValue; + return pStatement; + + PreparedStatement? TryGetAutoPreparedSlow(NpgsqlBatchCommand batchCommand, PreparedStatement? pStatement) { - // New candidate. Find an empty candidate slot or eject a least-used one. - int slotIndex = -1, leastUsages = int.MaxValue; - var lastUsed = long.MaxValue; - for (var i = 0; i < _candidates.Length; i++) + var sql = batchCommand.FinalCommandText!; + if (pStatement is null) { - var candidate = _candidates[i]; - // ReSharper disable once ConditionIsAlwaysTrueOrFalse - // ReSharper disable HeuristicUnreachableCode - if (candidate == null) // Found an unused candidate slot, return immediately - { - slotIndex = i; - break; - } - // ReSharper restore HeuristicUnreachableCode - if (candidate.Usages < leastUsages) + // New candidate. Find an empty candidate slot or eject a least-used one. 
+ int slotIndex = -1, leastUsages = int.MaxValue; + var lastUsed = long.MaxValue; + for (var i = 0; i < _candidates.Length; i++) { - leastUsages = candidate.Usages; - slotIndex = i; - lastUsed = candidate.LastUsed; - } - else if (candidate.Usages == leastUsages && candidate.LastUsed < lastUsed) - { - slotIndex = i; - lastUsed = candidate.LastUsed; + var candidate = _candidates[i]; + // ReSharper disable once ConditionIsAlwaysTrueOrFalse + // ReSharper disable HeuristicUnreachableCode + if (candidate == null) // Found an unused candidate slot, return immediately + { + slotIndex = i; + break; + } + // ReSharper restore HeuristicUnreachableCode + if (candidate.Usages < leastUsages) + { + leastUsages = candidate.Usages; + slotIndex = i; + lastUsed = candidate.LastUsed; + } + else if (candidate.Usages == leastUsages && candidate.LastUsed < lastUsed) + { + slotIndex = i; + lastUsed = candidate.LastUsed; + } } + + var leastUsed = _candidates[slotIndex]; + // ReSharper disable once ConditionIsAlwaysTrueOrFalse + if (leastUsed != null) + BySql.Remove(leastUsed.Sql); + pStatement = BySql[sql] = _candidates[slotIndex] = PreparedStatement.CreateAutoPrepareCandidate(this, sql); } - var leastUsed = _candidates[slotIndex]; - // ReSharper disable once ConditionIsAlwaysTrueOrFalse - if (leastUsed != null) - BySql.Remove(leastUsed.Sql); - pStatement = BySql[sql] = _candidates[slotIndex] = PreparedStatement.CreateAutoPrepareCandidate(this, sql); - } + switch (pStatement.State) + { + case PreparedState.NotPrepared: + case PreparedState.Invalidated: + break; - switch (pStatement.State) - { - case PreparedState.NotPrepared: - case PreparedState.Invalidated: - break; - - case PreparedState.Prepared: - case PreparedState.BeingPrepared: - // The statement has already been prepared (explicitly or automatically), or has been selected - // for preparation (earlier identical statement in the same command). 
- // We just need to check that the parameter types correspond, since prepared statements are - // only keyed by SQL (to prevent pointless allocations). If we have a mismatch, simply run unprepared. - if (!pStatement.DoParametersMatch(batchCommand.CurrentParametersReadOnly)) + // We shouldn't ever get PreparedState.Prepared since it's handled above but handle it here just in case + case PreparedState.Prepared: + case PreparedState.BeingPrepared: + // The statement has already been prepared (explicitly or automatically), or has been selected + // for preparation (earlier identical statement in the same command). + // We just need to check that the parameter types correspond, since prepared statements are + // only keyed by SQL (to prevent pointless allocations). If we have a mismatch, simply run unprepared. + if (!pStatement.DoParametersMatch(batchCommand.CurrentParametersReadOnly)) + return null; + // Prevent this statement from being replaced within this batch + pStatement.LastUsed = long.MaxValue; + return pStatement; + + case PreparedState.BeingUnprepared: + // The statement is being replaced by an earlier statement in this same batch. return null; - // Prevent this statement from being replaced within this batch - pStatement.LastUsed = long.MaxValue; - return pStatement; - - case PreparedState.BeingUnprepared: - // The statement is being replaced by an earlier statement in this same batch. - return null; - - default: - Debug.Fail($"Unexpected {nameof(PreparedState)} in auto-preparation: {pStatement.State}"); - break; - } - if (++pStatement.Usages < UsagesBeforePrepare) - { - // Statement still hasn't passed the usage threshold, no automatic preparation. - // Return null for unprepared execution. 
- pStatement.RefreshLastUsed(); - return null; - } - - // Bingo, we've just passed the usage threshold, statement should get prepared - LogMessages.AutoPreparingStatement(_commandLogger, sql, _connector.Id); - - // Look for either an empty autoprepare slot, or the least recently used prepared statement which we'll replace it. - var oldestLastUsed = long.MaxValue; - var selectedIndex = -1; - for (var i = 0; i < AutoPrepared.Length; i++) - { - var slot = AutoPrepared[i]; + default: + Debug.Fail($"Unexpected {nameof(PreparedState)} in auto-preparation: {pStatement.State}"); + break; + } - if (slot is null or { State: PreparedState.Invalidated }) + if (++pStatement.Usages < UsagesBeforePrepare) { - // We found a free or invalidated slot, exit the loop immediately - selectedIndex = i; - break; + // Statement still hasn't passed the usage threshold, no automatic preparation. + // Return null for unprepared execution. + pStatement.RefreshLastUsed(); + return null; } - switch (slot.State) + // Bingo, we've just passed the usage threshold, statement should get prepared + LogMessages.AutoPreparingStatement(_commandLogger, sql, _connector.Id); + + // Look for either an empty autoprepare slot, or the least recently used prepared statement which we'll replace it. + var oldestLastUsed = long.MaxValue; + var selectedIndex = -1; + for (var i = 0; i < AutoPrepared.Length; i++) { - case PreparedState.Prepared: - if (slot.LastUsed < oldestLastUsed) + var slot = AutoPrepared[i]; + + if (slot is null or { State: PreparedState.Invalidated }) { + // We found a free or invalidated slot, exit the loop immediately selectedIndex = i; - oldestLastUsed = slot.LastUsed; + break; } - break; - case PreparedState.BeingPrepared: - // Slot has already been selected for preparation by an earlier statement in this batch. Skip it. 
- continue; + switch (slot.State) + { + case PreparedState.Prepared: + if (slot.LastUsed < oldestLastUsed) + { + selectedIndex = i; + oldestLastUsed = slot.LastUsed; + } + break; - default: - ThrowHelper.ThrowInvalidOperationException($"Invalid {nameof(PreparedState)} state {slot.State} encountered when scanning prepared statement slots"); - return null; + case PreparedState.BeingPrepared: + // Slot has already been selected for preparation by an earlier statement in this batch. Skip it. + continue; + + default: + ThrowHelper.ThrowInvalidOperationException($"Invalid {nameof(PreparedState)} state {slot.State} encountered when scanning prepared statement slots"); + return null; + } } - } - if (selectedIndex == -1) - { - // We're here if we couldn't find a free slot or a prepared statement to replace - this means all slots are taken by - // statements being prepared in this batch. - return null; - } + if (selectedIndex < 0) + { + // We're here if we couldn't find a free slot or a prepared statement to replace - this means all slots are taken by + // statements being prepared in this batch. + return null; + } - if (pStatement.State != PreparedState.Invalidated) - RemoveCandidate(pStatement); + if (pStatement.State != PreparedState.Invalidated) + RemoveCandidate(pStatement); - var oldPreparedStatement = AutoPrepared[selectedIndex]; + var oldPreparedStatement = AutoPrepared[selectedIndex]; - if (oldPreparedStatement is null) - { - pStatement.Name = Encoding.ASCII.GetBytes("_auto" + selectedIndex); - } - else - { - // When executing an invalidated prepared statement, the old and the new statements are the same instance. - // Create a copy so that we have two distinct instances with their own states. 
- if (oldPreparedStatement == pStatement) + if (oldPreparedStatement is null) { - oldPreparedStatement = new PreparedStatement(this, oldPreparedStatement.Sql, isExplicit: false) - { - Name = oldPreparedStatement.Name - }; + pStatement.Name = Encoding.ASCII.GetBytes("_auto" + selectedIndex); } + else + { + // When executing an invalidated prepared statement, the old and the new statements are the same instance. + // Create a copy so that we have two distinct instances with their own states. + if (oldPreparedStatement == pStatement) + { + oldPreparedStatement = new PreparedStatement(this, oldPreparedStatement.Sql, isExplicit: false) + { + Name = oldPreparedStatement.Name + }; + } - pStatement.Name = oldPreparedStatement.Name; - pStatement.State = PreparedState.NotPrepared; - pStatement.StatementBeingReplaced = oldPreparedStatement; - oldPreparedStatement.State = PreparedState.BeingUnprepared; - } + pStatement.Name = oldPreparedStatement.Name; + pStatement.State = PreparedState.NotPrepared; + pStatement.StatementBeingReplaced = oldPreparedStatement; + oldPreparedStatement.State = PreparedState.BeingUnprepared; + } - pStatement.AutoPreparedSlotIndex = selectedIndex; - AutoPrepared[selectedIndex] = pStatement; + pStatement.AutoPreparedSlotIndex = selectedIndex; + AutoPrepared[selectedIndex] = pStatement; - // Make sure this statement isn't replaced by a later statement in the same batch. - pStatement.LastUsed = long.MaxValue; + // Make sure this statement isn't replaced by a later statement in the same batch. + pStatement.LastUsed = long.MaxValue; - // Note that the parameter types are only set at the moment of preparation - in the candidate phase - // there's no differentiation between overloaded statements, which are a pretty rare case, saving - // allocations. 
- pStatement.SetParamTypes(batchCommand.CurrentParametersReadOnly); + // Note that the parameter types are only set at the moment of preparation - in the candidate phase + // there's no differentiation between overloaded statements, which are a pretty rare case, saving + // allocations. + pStatement.SetParamTypes(batchCommand.CurrentParametersReadOnly); - return pStatement; + return pStatement; + } } void RemoveCandidate(PreparedStatement candidate) diff --git a/src/Npgsql/PreparedTextReader.cs b/src/Npgsql/PreparedTextReader.cs index 4831850684..80ee543d9b 100644 --- a/src/Npgsql/PreparedTextReader.cs +++ b/src/Npgsql/PreparedTextReader.cs @@ -57,17 +57,12 @@ public override int Read(Span buffer) public override int Read(char[] buffer, int index, int count) { - if (buffer == null) - { - throw new ArgumentNullException(nameof(buffer)); - } - if (index < 0 || count < 0) - { - throw new ArgumentOutOfRangeException(index < 0 ? nameof(index) : nameof(count)); - } + ArgumentNullException.ThrowIfNull(buffer); + ArgumentOutOfRangeException.ThrowIfNegative(index); + ArgumentOutOfRangeException.ThrowIfNegative(count); if (buffer.Length - index < count) { - throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); + ThrowHelper.ThrowArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection."); } return Read(buffer.AsSpan(index, count)); @@ -95,10 +90,7 @@ public override string ReadToEnd() public override Task ReadToEndAsync() => Task.FromResult(ReadToEnd()); void CheckDisposed() - { - if (_disposed || _stream.IsDisposed) - ThrowHelper.ThrowObjectDisposedException(nameof(PreparedTextReader)); - } + => ObjectDisposedException.ThrowIf(_disposed || _stream.IsDisposed, this); public void Restart() { diff --git 
a/src/Npgsql/Properties/NpgsqlStrings.Designer.cs b/src/Npgsql/Properties/NpgsqlStrings.Designer.cs index f00370da48..d0b7839d6c 100644 --- a/src/Npgsql/Properties/NpgsqlStrings.Designer.cs +++ b/src/Npgsql/Properties/NpgsqlStrings.Designer.cs @@ -11,32 +11,46 @@ namespace Npgsql.Properties { using System; - [System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] - [System.Diagnostics.DebuggerNonUserCodeAttribute()] - [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "4.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] internal class NpgsqlStrings { - private static System.Resources.ResourceManager resourceMan; + private static global::System.Resources.ResourceManager resourceMan; - private static System.Globalization.CultureInfo resourceCulture; + private static global::System.Globalization.CultureInfo resourceCulture; - [System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] internal NpgsqlStrings() { } - [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] - internal static System.Resources.ResourceManager ResourceManager { + /// + /// Returns the cached ResourceManager instance used by this class. 
+ /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { get { - if (object.Equals(null, resourceMan)) { - System.Resources.ResourceManager temp = new System.Resources.ResourceManager("Npgsql.Properties.NpgsqlStrings", typeof(NpgsqlStrings).Assembly); + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Npgsql.Properties.NpgsqlStrings", typeof(NpgsqlStrings).Assembly); resourceMan = temp; } return resourceMan; } } - [System.ComponentModel.EditorBrowsableAttribute(System.ComponentModel.EditorBrowsableState.Advanced)] - internal static System.Globalization.CultureInfo Culture { + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { get { return resourceCulture; } @@ -45,174 +59,272 @@ internal static System.Globalization.CultureInfo Culture { } } - internal static string CannotUseSslVerifyWithUserCallback { + /// + /// Looks up a localized string similar to '{0}' must be positive.. + /// + internal static string ArgumentMustBePositive { get { - return ResourceManager.GetString("CannotUseSslVerifyWithUserCallback", resourceCulture); + return ResourceManager.GetString("ArgumentMustBePositive", resourceCulture); } } - internal static string CannotUseSslRootCertificateWithUserCallback { + /// + /// Looks up a localized string similar to Arrays aren't enabled; please call {0} on {1} to enable arrays.. 
+ /// + internal static string ArraysNotEnabled { get { - return ResourceManager.GetString("CannotUseSslRootCertificateWithUserCallback", resourceCulture); + return ResourceManager.GetString("ArraysNotEnabled", resourceCulture); } } - internal static string TransportSecurityDisabled { + /// + /// Looks up a localized string similar to Cannot read infinity value since Npgsql.DisableDateTimeInfinityConversions is enabled.. + /// + internal static string CannotReadInfinityValue { get { - return ResourceManager.GetString("TransportSecurityDisabled", resourceCulture); + return ResourceManager.GetString("CannotReadInfinityValue", resourceCulture); } } - internal static string IntegratedSecurityDisabled { + /// + /// Looks up a localized string similar to Cannot read interval values with non-zero months as TimeSpan, since that type doesn't support months. Consider using NodaTime Period which better corresponds to PostgreSQL interval, or read the value as NpgsqlInterval, or transform the interval to not contain months or years in PostgreSQL before reading it.. + /// + internal static string CannotReadIntervalWithMonthsAsTimeSpan { get { - return ResourceManager.GetString("IntegratedSecurityDisabled", resourceCulture); + return ResourceManager.GetString("CannotReadIntervalWithMonthsAsTimeSpan", resourceCulture); } } - internal static string NoMultirangeTypeFound { + /// + /// Looks up a localized string similar to When registering a password provider, a password or password file may not be set.. + /// + internal static string CannotSetBothPasswordProviderAndPassword { get { - return ResourceManager.GetString("NoMultirangeTypeFound", resourceCulture); + return ResourceManager.GetString("CannotSetBothPasswordProviderAndPassword", resourceCulture); } } - internal static string NotSupportedOnDataSourceCommand { + /// + /// Looks up a localized string similar to Multiple kinds of password providers were found, only one kind may be configured per DbDataSource.. 
+ /// + internal static string CannotSetMultiplePasswordProviderKinds { get { - return ResourceManager.GetString("NotSupportedOnDataSourceCommand", resourceCulture); + return ResourceManager.GetString("CannotSetMultiplePasswordProviderKinds", resourceCulture); } } - internal static string NotSupportedOnDataSourceBatch { + /// + /// Looks up a localized string similar to RootCertificate cannot be used in conjunction with SslClientAuthenticationOptionsCallback overwriting RemoteCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback.. + /// + internal static string CannotUseSslRootCertificateWithCustomValidationCallback { get { - return ResourceManager.GetString("NotSupportedOnDataSourceBatch", resourceCulture); + return ResourceManager.GetString("CannotUseSslRootCertificateWithCustomValidationCallback", resourceCulture); } } - internal static string CannotSetBothPasswordProviderAndPassword { + /// + /// Looks up a localized string similar to SslMode.{0} cannot be used in conjunction with SslClientAuthenticationOptionsCallback overwriting RemoteCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback.. + /// + internal static string CannotUseSslVerifyWithCustomValidationCallback { get { - return ResourceManager.GetString("CannotSetBothPasswordProviderAndPassword", resourceCulture); + return ResourceManager.GetString("CannotUseSslVerifyWithCustomValidationCallback", resourceCulture); } } - internal static string CannotSetMultiplePasswordProviderKinds { + /// + /// Looks up a localized string similar to ValidationRootCertificateCallback cannot be used in conjunction with SslClientAuthenticationOptionsCallback overwriting RemoteCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback.. 
+ /// + internal static string CannotUseValidationRootCertificateCallbackWithCustomValidationCallback { get { - return ResourceManager.GetString("CannotSetMultiplePasswordProviderKinds", resourceCulture); + return ResourceManager.GetString("CannotUseValidationRootCertificateCallbackWithCustomValidationCallback", resourceCulture); } } - - internal static string SyncAndAsyncPasswordProvidersRequired { + + /// + /// Looks up a localized string similar to Cube isn't enabled; please call {0} on {1} to enable Cube.. + /// + internal static string CubeNotEnabled { get { - return ResourceManager.GetString("SyncAndAsyncPasswordProvidersRequired", resourceCulture); + return ResourceManager.GetString("CubeNotEnabled", resourceCulture); } } - - internal static string PasswordProviderMissing { + + /// + /// Looks up a localized string similar to Type '{0}' required dynamic JSON serialization, which requires an explicit opt-in; call '{1}' on '{2}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/json.html and the 8.0 release notes for more details). Alternatively, if you meant to use Newtonsoft JSON.NET instead of System.Text.Json, call UseJsonNet() instead. + ///. + /// + internal static string DynamicJsonNotEnabled { get { - return ResourceManager.GetString("PasswordProviderMissing", resourceCulture); + return ResourceManager.GetString("DynamicJsonNotEnabled", resourceCulture); } } - internal static string ArgumentMustBePositive { + /// + /// Looks up a localized string similar to Full-text search isn't enabled; please call {0} on {1} to enable full-text search.. 
+ /// + internal static string FullTextSearchNotEnabled { get { - return ResourceManager.GetString("ArgumentMustBePositive", resourceCulture); + return ResourceManager.GetString("FullTextSearchNotEnabled", resourceCulture); } } - internal static string CannotSpecifyTargetSessionAttributes { + /// + /// Looks up a localized string similar to Integrated security hasn't been enabled; please call {0} on NpgsqlSlimDataSourceBuilder to enable it.. + /// + internal static string IntegratedSecurityDisabled { get { - return ResourceManager.GetString("CannotSpecifyTargetSessionAttributes", resourceCulture); + return ResourceManager.GetString("IntegratedSecurityDisabled", resourceCulture); } } - internal static string CannotReadIntervalWithMonthsAsTimeSpan { + /// + /// Looks up a localized string similar to Ltree isn't enabled; please call {0} on {1} to enable LTree.. + /// + internal static string LTreeNotEnabled { get { - return ResourceManager.GetString("CannotReadIntervalWithMonthsAsTimeSpan", resourceCulture); + return ResourceManager.GetString("LTreeNotEnabled", resourceCulture); } } - internal static string PositionalParameterAfterNamed { + /// + /// Looks up a localized string similar to Multiranges aren't enabled; please call {0} on {1} to enable multiranges.. + /// + internal static string MultirangesNotEnabled { get { - return ResourceManager.GetString("PositionalParameterAfterNamed", resourceCulture); + return ResourceManager.GetString("MultirangesNotEnabled", resourceCulture); } } - internal static string CannotReadInfinityValue { + /// + /// Looks up a localized string similar to No multirange type could be found in the database for subtype {0}.. 
+ /// + internal static string NoMultirangeTypeFound { get { - return ResourceManager.GetString("CannotReadInfinityValue", resourceCulture); + return ResourceManager.GetString("NoMultirangeTypeFound", resourceCulture); } } - internal static string SyncAndAsyncConnectionInitializersRequired { + /// + /// Looks up a localized string similar to Connection and transaction access is not supported on batches created from DbDataSource.. + /// + internal static string NotSupportedOnDataSourceBatch { get { - return ResourceManager.GetString("SyncAndAsyncConnectionInitializersRequired", resourceCulture); + return ResourceManager.GetString("NotSupportedOnDataSourceBatch", resourceCulture); } } - internal static string CannotUseValidationRootCertificateCallbackWithUserCallback { + /// + /// Looks up a localized string similar to Connection and transaction access is not supported on commands created from DbDataSource.. + /// + internal static string NotSupportedOnDataSourceCommand { get { - return ResourceManager.GetString("CannotUseValidationRootCertificateCallbackWithUserCallback", resourceCulture); + return ResourceManager.GetString("NotSupportedOnDataSourceCommand", resourceCulture); } } - internal static string RecordsNotEnabled { + /// + /// Looks up a localized string similar to The right type of password provider (sync or async) was not found.. + /// + internal static string PasswordProviderMissing { get { - return ResourceManager.GetString("RecordsNotEnabled", resourceCulture); + return ResourceManager.GetString("PasswordProviderMissing", resourceCulture); } } - internal static string FullTextSearchNotEnabled { + /// + /// Looks up a localized string similar to When using CommandType.StoredProcedure, all positional parameters must come before named parameters.. 
+ /// + internal static string PositionalParameterAfterNamed { get { - return ResourceManager.GetString("FullTextSearchNotEnabled", resourceCulture); + return ResourceManager.GetString("PositionalParameterAfterNamed", resourceCulture); } } - internal static string LTreeNotEnabled { + /// + /// Looks up a localized string similar to Ranges aren't enabled; please call {0} on {1} to enable ranges.. + /// + internal static string RangesNotEnabled { get { - return ResourceManager.GetString("LTreeNotEnabled", resourceCulture); + return ResourceManager.GetString("RangesNotEnabled", resourceCulture); } } - internal static string RangesNotEnabled { + /// + /// Looks up a localized string similar to Could not read a PostgreSQL record. If you're attempting to read a record as a .NET tuple, call '{0}' on '{1}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/basic.html and the 8.0 release notes for more details). If you're reading a record as a .NET object array using NpgsqlSlimDataSourceBuilder, call '{2}'. + ///. + /// + internal static string RecordsNotEnabled { get { - return ResourceManager.GetString("RangesNotEnabled", resourceCulture); + return ResourceManager.GetString("RecordsNotEnabled", resourceCulture); } } - internal static string MultirangesNotEnabled { + /// + /// Looks up a localized string similar to SslClientAuthenticationOptionsCallback is not supported together with UserCertificateValidationCallback and ClientCertificatesCallback. + /// + internal static string SslClientAuthenticationOptionsCallbackWithOtherCallbacksNotSupported { get { - return ResourceManager.GetString("MultirangesNotEnabled", resourceCulture); + return ResourceManager.GetString("SslClientAuthenticationOptionsCallbackWithOtherCallbacksNotSupported", resourceCulture); } } - internal static string ArraysNotEnabled { + /// + /// Looks up a localized string similar to Both sync and async connection initializers must be provided.. 
+ /// + internal static string SyncAndAsyncConnectionInitializersRequired { get { - return ResourceManager.GetString("ArraysNotEnabled", resourceCulture); + return ResourceManager.GetString("SyncAndAsyncConnectionInitializersRequired", resourceCulture); } } - internal static string TimestampTzNoDateTimeUnspecified { + /// + /// Looks up a localized string similar to Both sync and async password providers must be provided.. + /// + internal static string SyncAndAsyncPasswordProvidersRequired { get { - return ResourceManager.GetString("TimestampTzNoDateTimeUnspecified", resourceCulture); + return ResourceManager.GetString("SyncAndAsyncPasswordProvidersRequired", resourceCulture); } } + /// + /// Looks up a localized string similar to Cannot write DateTime with Kind=UTC to PostgreSQL type '{0}', consider using '{1}'. Note that it's not possible to mix DateTimes with different Kinds in an array, range, or multirange.. + /// internal static string TimestampNoDateTimeUtc { get { return ResourceManager.GetString("TimestampNoDateTimeUtc", resourceCulture); } } - internal static string DynamicJsonNotEnabled { + /// + /// Looks up a localized string similar to Cannot write DateTime with Kind={0} to PostgreSQL type '{1}', only UTC is supported. Note that it's not possible to mix DateTimes with different Kinds in an array, range, or multirange.. + /// + internal static string TimestampTzNoDateTimeUnspecified { get { - return ResourceManager.GetString("DynamicJsonNotEnabled", resourceCulture); + return ResourceManager.GetString("TimestampTzNoDateTimeUnspecified", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Transport security hasn't been enabled; please call {0} on NpgsqlSlimDataSourceBuilder to enable it.. 
+ /// + internal static string TransportSecurityDisabled { + get { + return ResourceManager.GetString("TransportSecurityDisabled", resourceCulture); } } + /// + /// Looks up a localized string similar to Reading and writing unmapped enums requires an explicit opt-in; call '{0}' on '{1}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/enums_and_composites.html and the 8.0 release notes for more details).. + /// internal static string UnmappedEnumsNotEnabled { get { return ResourceManager.GetString("UnmappedEnumsNotEnabled", resourceCulture); } } + /// + /// Looks up a localized string similar to Reading and writing unmapped ranges and multiranges requires an explicit opt-in; call '{0}' on '{1}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/ranges.html and the 8.0 release notes for more details).. + /// internal static string UnmappedRangesNotEnabled { get { return ResourceManager.GetString("UnmappedRangesNotEnabled", resourceCulture); diff --git a/src/Npgsql/Properties/NpgsqlStrings.resx b/src/Npgsql/Properties/NpgsqlStrings.resx index 5dbc58acdf..c39af4abc4 100644 --- a/src/Npgsql/Properties/NpgsqlStrings.resx +++ b/src/Npgsql/Properties/NpgsqlStrings.resx @@ -18,11 +18,11 @@ System.Resources.ResXResourceWriter, System.Windows.Forms, Version=2.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 - - SslMode.{0} cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. + + SslMode.{0} cannot be used in conjunction with SslClientAuthenticationOptionsCallback overwriting RemoteCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. - - RootCertificate cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. 
+ + RootCertificate cannot be used in conjunction with SslClientAuthenticationOptionsCallback overwriting RemoteCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. Transport security hasn't been enabled; please call {0} on NpgsqlSlimDataSourceBuilder to enable it. @@ -54,9 +54,6 @@ '{0}' must be positive. - - When creating a multi-host data source, TargetSessionAttributes cannot be specified. Create without TargetSessionAttributes, and then obtain DataSource wrappers from it. Consult the docs for more information. - Cannot read interval values with non-zero months as TimeSpan, since that type doesn't support months. Consider using NodaTime Period which better corresponds to PostgreSQL interval, or read the value as NpgsqlInterval, or transform the interval to not contain months or years in PostgreSQL before reading it. @@ -69,8 +66,8 @@ Both sync and async connection initializers must be provided. - - ValidationRootCertificateCallback cannot be used in conjunction with UserCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. + + ValidationRootCertificateCallback cannot be used in conjunction with SslClientAuthenticationOptionsCallback overwriting RemoteCertificateValidationCallback; when registering a validation callback, perform whatever validation you require in that callback. Could not read a PostgreSQL record. If you're attempting to read a record as a .NET tuple, call '{0}' on '{1}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/basic.html and the 8.0 release notes for more details). If you're reading a record as a .NET object array using NpgsqlSlimDataSourceBuilder, call '{2}'. @@ -82,6 +79,9 @@ Ltree isn't enabled; please call {0} on {1} to enable LTree. + + Cube isn't enabled; please call {0} on {1} to enable Cube. + Ranges aren't enabled; please call {0} on {1} to enable ranges. 
@@ -107,4 +107,7 @@ Reading and writing unmapped ranges and multiranges requires an explicit opt-in; call '{0}' on '{1}' or NpgsqlConnection.GlobalTypeMapper (see https://www.npgsql.org/doc/types/ranges.html and the 8.0 release notes for more details). + + SslClientAuthenticationOptionsCallback is not supported together with UserCertificateValidationCallback and ClientCertificatesCallback + diff --git a/src/Npgsql/PublicAPI.Unshipped.txt b/src/Npgsql/PublicAPI.Unshipped.txt index ab058de62d..b52e2d68fb 100644 --- a/src/Npgsql/PublicAPI.Unshipped.txt +++ b/src/Npgsql/PublicAPI.Unshipped.txt @@ -1 +1,146 @@ #nullable enable +*REMOVED*Npgsql.NpgsqlConnectionStringBuilder.Multiplexing.get -> bool +*REMOVED*Npgsql.NpgsqlConnectionStringBuilder.Multiplexing.set -> void +*REMOVED*Npgsql.NpgsqlConnectionStringBuilder.WriteCoalescingBufferThresholdBytes.get -> int +*REMOVED*Npgsql.NpgsqlConnectionStringBuilder.WriteCoalescingBufferThresholdBytes.set -> void +abstract Npgsql.NpgsqlDataSource.Clear() -> void +Npgsql.GssEncryptionMode +Npgsql.GssEncryptionMode.Disable = 0 -> Npgsql.GssEncryptionMode +Npgsql.GssEncryptionMode.Prefer = 1 -> Npgsql.GssEncryptionMode +Npgsql.GssEncryptionMode.Require = 2 -> Npgsql.GssEncryptionMode +Npgsql.TypeMapping.INpgsqlTypeMapper.AddDbTypeResolverFactory(Npgsql.Internal.DbTypeResolverFactory! factory) -> void +Npgsql.NpgsqlConnection.BeginTextExport(string! copyToCommand) -> Npgsql.NpgsqlCopyTextReader! +Npgsql.NpgsqlConnection.BeginTextExportAsync(string! copyToCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +Npgsql.NpgsqlConnection.BeginTextImport(string! copyFromCommand) -> Npgsql.NpgsqlCopyTextWriter! +Npgsql.NpgsqlConnection.BeginTextImportAsync(string! copyFromCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.NpgsqlConnection.CloneWithAsync(string! connectionString, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask +Npgsql.NpgsqlConnection.SslClientAuthenticationOptionsCallback.get -> System.Action? +Npgsql.NpgsqlConnection.SslClientAuthenticationOptionsCallback.set -> void +Npgsql.NpgsqlConnectionStringBuilder.GssEncryptionMode.get -> Npgsql.GssEncryptionMode +Npgsql.NpgsqlConnectionStringBuilder.GssEncryptionMode.set -> void +Npgsql.NpgsqlConnectionStringBuilder.IncludeFailedBatchedCommand.get -> bool +Npgsql.NpgsqlConnectionStringBuilder.IncludeFailedBatchedCommand.set -> void +Npgsql.NpgsqlConnectionStringBuilder.RequireAuth.get -> string? +Npgsql.NpgsqlConnectionStringBuilder.RequireAuth.set -> void +Npgsql.NpgsqlConnectionStringBuilder.SslNegotiation.get -> Npgsql.SslNegotiation +Npgsql.NpgsqlConnectionStringBuilder.SslNegotiation.set -> void +Npgsql.NpgsqlCopyTextReader.Timeout.get -> int +Npgsql.NpgsqlCopyTextReader.Timeout.set -> void +Npgsql.NpgsqlCopyTextWriter.Timeout.get -> int +Npgsql.NpgsqlCopyTextWriter.Timeout.set -> void +Npgsql.NpgsqlDataSourceBuilder.ConfigureTypeLoading(System.Action! configureAction) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.MapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.MapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.MapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.MapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.ConfigureTracing(System.Action! 
configureAction) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseNegotiateOptionsCallback(System.Action? negotiateOptionsCallback) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseRootCertificates(System.Security.Cryptography.X509Certificates.X509Certificate2Collection? rootCertificates) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseRootCertificatesCallback(System.Func? rootCertificateCallback) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlDataSourceBuilder.UseSslClientAuthenticationOptionsCallback(System.Action? sslClientAuthenticationOptionsCallback) -> Npgsql.NpgsqlDataSourceBuilder! +Npgsql.NpgsqlMetricsOptions +Npgsql.NpgsqlMetricsOptions.NpgsqlMetricsOptions() -> void +Npgsql.NpgsqlSlimDataSourceBuilder.ConfigureTracing(System.Action! configureAction) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.ConfigureTypeLoading(System.Action! configureAction) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableCube() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableGeometricTypes() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableJsonTypes() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.EnableNetworkTypes() -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.MapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.MapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.MapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.MapEnum(string? 
pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseNegotiateOptionsCallback(System.Action? negotiateOptionsCallback) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseRootCertificates(System.Security.Cryptography.X509Certificates.X509Certificate2Collection? rootCertificates) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseRootCertificatesCallback(System.Func? rootCertificateCallback) -> Npgsql.NpgsqlSlimDataSourceBuilder! +Npgsql.NpgsqlSlimDataSourceBuilder.UseSslClientAuthenticationOptionsCallback(System.Action? sslClientAuthenticationOptionsCallback) -> Npgsql.NpgsqlSlimDataSourceBuilder! +*REMOVED*Npgsql.NpgsqlTracingOptions +*REMOVED*Npgsql.NpgsqlTracingOptions.NpgsqlTracingOptions() -> void +Npgsql.NpgsqlTracingOptionsBuilder +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureBatchEnrichmentCallback(System.Action? batchEnrichmentCallback) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureBatchFilter(System.Func? batchFilter) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureBatchSpanNameProvider(System.Func? batchSpanNameProvider) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureCommandEnrichmentCallback(System.Action? commandEnrichmentCallback) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureCommandFilter(System.Func? commandFilter) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureCommandSpanNameProvider(System.Func? commandSpanNameProvider) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureCopyOperationEnrichmentCallback(System.Action? copyOperationEnrichmentCallback) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureCopyOperationFilter(System.Func? 
copyOperationFilter) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.ConfigureCopyOperationSpanNameProvider(System.Func? copyOperationSpanNameProvider) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.EnableFirstResponseEvent(bool enable = true) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTracingOptionsBuilder.EnablePhysicalOpenTracing(bool enable = true) -> Npgsql.NpgsqlTracingOptionsBuilder! +Npgsql.NpgsqlTypeLoadingOptionsBuilder +Npgsql.NpgsqlTypeLoadingOptionsBuilder.EnableTableCompositesLoading(bool enable = true) -> Npgsql.NpgsqlTypeLoadingOptionsBuilder! +Npgsql.NpgsqlTypeLoadingOptionsBuilder.EnableTypeLoading(bool enable = true) -> Npgsql.NpgsqlTypeLoadingOptionsBuilder! +Npgsql.NpgsqlTypeLoadingOptionsBuilder.SetTypeLoadingSchemas(params System.Collections.Generic.IEnumerable? schemas) -> Npgsql.NpgsqlTypeLoadingOptionsBuilder! +Npgsql.Replication.PgOutput.ReplicationValue.GetFieldName() -> string! +Npgsql.Replication.PgOutput.Messages.ParallelStreamAbortMessage +Npgsql.Replication.PgOutput.Messages.ParallelStreamAbortMessage.AbortLsn.get -> NpgsqlTypes.NpgsqlLogSequenceNumber +Npgsql.Replication.PgOutput.Messages.ParallelStreamAbortMessage.AbortTimestamp.get -> System.DateTime +Npgsql.Replication.PgOutput.PgOutputProtocolVersion +Npgsql.Replication.PgOutput.PgOutputProtocolVersion.V1 = 1 -> Npgsql.Replication.PgOutput.PgOutputProtocolVersion +Npgsql.Replication.PgOutput.PgOutputProtocolVersion.V2 = 2 -> Npgsql.Replication.PgOutput.PgOutputProtocolVersion +Npgsql.Replication.PgOutput.PgOutputProtocolVersion.V3 = 3 -> Npgsql.Replication.PgOutput.PgOutputProtocolVersion +Npgsql.Replication.PgOutput.PgOutputProtocolVersion.V4 = 4 -> Npgsql.Replication.PgOutput.PgOutputProtocolVersion +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.PgOutputReplicationOptions(string! publicationName, Npgsql.Replication.PgOutput.PgOutputProtocolVersion protocolVersion, bool? 
binary = null, Npgsql.Replication.PgOutput.PgOutputStreamingMode? streamingMode = null, bool? messages = null, bool? twoPhase = null) -> void +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.PgOutputReplicationOptions(System.Collections.Generic.IEnumerable! publicationNames, Npgsql.Replication.PgOutput.PgOutputProtocolVersion protocolVersion, bool? binary = null, Npgsql.Replication.PgOutput.PgOutputStreamingMode? streamingMode = null, bool? messages = null, bool? twoPhase = null) -> void +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.ProtocolVersion.get -> Npgsql.Replication.PgOutput.PgOutputProtocolVersion +*REMOVED*Npgsql.Replication.PgOutput.PgOutputReplicationOptions.ProtocolVersion.get -> ulong +Npgsql.Replication.PgOutput.PgOutputReplicationOptions.StreamingMode.get -> Npgsql.Replication.PgOutput.PgOutputStreamingMode? +*REMOVED*Npgsql.Replication.PgOutput.PgOutputReplicationOptions.Streaming.get -> bool? +Npgsql.Replication.PgOutput.PgOutputStreamingMode +Npgsql.Replication.PgOutput.PgOutputStreamingMode.Off = 0 -> Npgsql.Replication.PgOutput.PgOutputStreamingMode +Npgsql.Replication.PgOutput.PgOutputStreamingMode.On = 1 -> Npgsql.Replication.PgOutput.PgOutputStreamingMode +Npgsql.Replication.PgOutput.PgOutputStreamingMode.Parallel = 2 -> Npgsql.Replication.PgOutput.PgOutputStreamingMode +Npgsql.SslNegotiation +Npgsql.SslNegotiation.Direct = 1 -> Npgsql.SslNegotiation +Npgsql.SslNegotiation.Postgres = 0 -> Npgsql.SslNegotiation +override Npgsql.NpgsqlDataReader.GetColumnSchemaAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!>! +override Npgsql.NpgsqlMultiHostDataSource.Clear() -> void +Npgsql.NpgsqlDataSource.ReloadTypes() -> void +Npgsql.NpgsqlDataSource.ReloadTypesAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! 
+Npgsql.NpgsqlConnection.ReloadTypesAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +*REMOVED*Npgsql.NpgsqlConnection.ReloadTypesAsync() -> System.Threading.Tasks.Task! +*REMOVED*Npgsql.NpgsqlDataSourceBuilder.MapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +*REMOVED*Npgsql.NpgsqlDataSourceBuilder.MapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +*REMOVED*Npgsql.NpgsqlDataSourceBuilder.MapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +*REMOVED*Npgsql.NpgsqlDataSourceBuilder.MapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +*REMOVED*Npgsql.NpgsqlSlimDataSourceBuilder.MapComposite(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +*REMOVED*Npgsql.NpgsqlSlimDataSourceBuilder.MapComposite(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +*REMOVED*Npgsql.NpgsqlSlimDataSourceBuilder.MapEnum(System.Type! clrType, string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +*REMOVED*Npgsql.NpgsqlSlimDataSourceBuilder.MapEnum(string? pgName = null, Npgsql.INpgsqlNameTranslator? nameTranslator = null) -> Npgsql.TypeMapping.INpgsqlTypeMapper! +static NpgsqlTypes.NpgsqlInet.implicit operator NpgsqlTypes.NpgsqlInet(System.Net.IPNetwork cidr) -> NpgsqlTypes.NpgsqlInet +static readonly NpgsqlTypes.NpgsqlTsVector.Empty -> NpgsqlTypes.NpgsqlTsVector! 
+NpgsqlTypes.NpgsqlBox.Deconstruct(out double left, out double right, out double bottom, out double top) -> void +NpgsqlTypes.NpgsqlBox.Deconstruct(out double left, out double right, out double bottom, out double top, out double width, out double height) -> void +NpgsqlTypes.NpgsqlBox.Deconstruct(out NpgsqlTypes.NpgsqlPoint lowerLeft, out NpgsqlTypes.NpgsqlPoint upperRight) -> void +NpgsqlTypes.NpgsqlCircle.Deconstruct(out double x, out double y, out double radius) -> void +NpgsqlTypes.NpgsqlCircle.Deconstruct(out NpgsqlTypes.NpgsqlPoint center, out double radius) -> void +NpgsqlTypes.NpgsqlCube +NpgsqlTypes.NpgsqlCube.NpgsqlCube() -> void +NpgsqlTypes.NpgsqlCube.Dimensions.get -> int +NpgsqlTypes.NpgsqlCube.Equals(NpgsqlTypes.NpgsqlCube other) -> bool +NpgsqlTypes.NpgsqlCube.LowerLeft.get -> System.Collections.Generic.IReadOnlyList! +NpgsqlTypes.NpgsqlCube.NpgsqlCube(double coord) -> void +NpgsqlTypes.NpgsqlCube.NpgsqlCube(double lowerLeft, double upperRight) -> void +NpgsqlTypes.NpgsqlCube.NpgsqlCube(NpgsqlTypes.NpgsqlCube cube, double coord) -> void +NpgsqlTypes.NpgsqlCube.NpgsqlCube(NpgsqlTypes.NpgsqlCube cube, double lowerLeft, double upperRight) -> void +NpgsqlTypes.NpgsqlCube.NpgsqlCube(System.Collections.Generic.IEnumerable! coords) -> void +NpgsqlTypes.NpgsqlCube.NpgsqlCube(System.Collections.Generic.IEnumerable! lowerLeft, System.Collections.Generic.IEnumerable! upperRight) -> void +NpgsqlTypes.NpgsqlCube.IsPoint.get -> bool +NpgsqlTypes.NpgsqlCube.ToSubset(params int[]! indexes) -> NpgsqlTypes.NpgsqlCube +NpgsqlTypes.NpgsqlCube.UpperRight.get -> System.Collections.Generic.IReadOnlyList! +NpgsqlTypes.NpgsqlDbType.Cube = 63 -> NpgsqlTypes.NpgsqlDbType +override NpgsqlTypes.NpgsqlCube.Equals(object? obj) -> bool +override NpgsqlTypes.NpgsqlCube.GetHashCode() -> int +override NpgsqlTypes.NpgsqlCube.ToString() -> string! 
+static NpgsqlTypes.NpgsqlCube.operator !=(NpgsqlTypes.NpgsqlCube x, NpgsqlTypes.NpgsqlCube y) -> bool +static NpgsqlTypes.NpgsqlCube.operator ==(NpgsqlTypes.NpgsqlCube x, NpgsqlTypes.NpgsqlCube y) -> bool +NpgsqlTypes.NpgsqlLine.Deconstruct(out double a, out double b, out double c) -> void +NpgsqlTypes.NpgsqlLSeg.Deconstruct(out NpgsqlTypes.NpgsqlPoint start, out NpgsqlTypes.NpgsqlPoint end) -> void +NpgsqlTypes.NpgsqlPoint.Deconstruct(out double x, out double y) -> void +NpgsqlTypes.NpgsqlTid.Deconstruct(out uint blockNumber, out ushort offsetNumber) -> void +*REMOVED*Npgsql.NpgsqlConnection.BeginTextExport(string! copyToCommand) -> System.IO.TextReader! +*REMOVED*Npgsql.NpgsqlConnection.BeginTextExportAsync(string! copyToCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +*REMOVED*Npgsql.NpgsqlConnection.BeginTextImport(string! copyFromCommand) -> System.IO.TextWriter! +*REMOVED*Npgsql.NpgsqlConnection.BeginTextImportAsync(string! copyFromCommand, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! +*REMOVED*Npgsql.NpgsqlDataReader.GetColumnSchemaAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!>! 
diff --git a/src/Npgsql/Replication/Internal/LogicalReplicationConnectionExtensions.cs b/src/Npgsql/Replication/Internal/LogicalReplicationConnectionExtensions.cs index 6f703970de..d66a9e55d1 100644 --- a/src/Npgsql/Replication/Internal/LogicalReplicationConnectionExtensions.cs +++ b/src/Npgsql/Replication/Internal/LogicalReplicationConnectionExtensions.cs @@ -61,10 +61,8 @@ public static Task CreateLogicalReplicationSlot( CancellationToken cancellationToken = default) { connection.CheckDisposed(); - if (slotName is null) - throw new ArgumentNullException(nameof(slotName)); - if (outputPlugin is null) - throw new ArgumentNullException(nameof(outputPlugin)); + ArgumentNullException.ThrowIfNull(slotName); + ArgumentNullException.ThrowIfNull(outputPlugin); cancellationToken.ThrowIfCancellationRequested(); diff --git a/src/Npgsql/Replication/PgOutput/Messages/StreamAbortMessage.cs b/src/Npgsql/Replication/PgOutput/Messages/StreamAbortMessage.cs index 23fc2c5a24..20e5c4d2e3 100644 --- a/src/Npgsql/Replication/PgOutput/Messages/StreamAbortMessage.cs +++ b/src/Npgsql/Replication/PgOutput/Messages/StreamAbortMessage.cs @@ -4,9 +4,9 @@ namespace Npgsql.Replication.PgOutput.Messages; /// -/// Logical Replication Protocol stream abort message +/// Logical Replication Protocol stream abort message for Logical Streaming Replication Protocol versions 2-3 /// -public sealed class StreamAbortMessage : TransactionControlMessage +public class StreamAbortMessage : TransactionControlMessage { /// /// Xid of the subtransaction (will be same as xid of the transaction for top-level transactions). 
@@ -22,4 +22,31 @@ internal StreamAbortMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLog SubtransactionXid = subtransactionXid; return this; } -} \ No newline at end of file +} + +/// +/// Logical Replication Protocol stream abort message for Logical Streaming Replication Protocol versions 4+ +/// +public sealed class ParallelStreamAbortMessage : StreamAbortMessage +{ + /// + /// The LSN of the abort. + /// + public NpgsqlLogSequenceNumber AbortLsn { get; private set; } + + /// + /// Abort timestamp of the transaction. + /// + public DateTime AbortTimestamp { get; private set; } + + internal ParallelStreamAbortMessage() {} + + internal ParallelStreamAbortMessage Populate(NpgsqlLogSequenceNumber walStart, NpgsqlLogSequenceNumber walEnd, DateTime serverClock, + uint transactionXid, uint subtransactionXid, NpgsqlLogSequenceNumber abortLsn, DateTime abortTimestamp) + { + base.Populate(walStart, walEnd, serverClock, transactionXid, subtransactionXid); + AbortLsn = abortLsn; + AbortTimestamp = abortTimestamp; + return this; + } +} diff --git a/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs b/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs index ae26d229f6..d200c780a1 100644 --- a/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs +++ b/src/Npgsql/Replication/PgOutput/PgOutputAsyncEnumerable.cs @@ -13,6 +13,7 @@ namespace Npgsql.Replication.PgOutput; sealed class PgOutputAsyncEnumerable : IAsyncEnumerable { + readonly PgOutputProtocolVersion _protocolVersion; readonly LogicalReplicationConnection _connection; readonly PgOutputReplicationSlot _slot; readonly PgOutputReplicationOptions _options; @@ -38,17 +39,20 @@ sealed class PgOutputAsyncEnumerable : IAsyncEnumerable _truncateMessageRelations = new(); // V2 - readonly StreamStartMessage _streamStartMessage = new(); - readonly StreamStopMessage _streamStopMessage = new(); - readonly StreamCommitMessage _streamCommitMessage = new(); - readonly StreamAbortMessage _streamAbortMessage 
= new(); + readonly StreamStartMessage _streamStartMessage = null!; + readonly StreamStopMessage _streamStopMessage = null!; + readonly StreamCommitMessage _streamCommitMessage = null!; + readonly StreamAbortMessage _streamAbortMessage = null!; // V3 - readonly BeginPrepareMessage _beginPrepareMessage = new(); - readonly PrepareMessage _prepareMessage = new(); - readonly CommitPreparedMessage _commitPreparedMessage = new(); - readonly RollbackPreparedMessage _rollbackPreparedMessage = new(); - readonly StreamPrepareMessage _streamPrepareMessage = new(); + readonly BeginPrepareMessage _beginPrepareMessage = null!; + readonly PrepareMessage _prepareMessage = null!; + readonly CommitPreparedMessage _commitPreparedMessage = null!; + readonly RollbackPreparedMessage _rollbackPreparedMessage = null!; + readonly StreamPrepareMessage _streamPrepareMessage = null!; + + // V4 + readonly ParallelStreamAbortMessage _parallelStreamAbortMessage = null!; #endregion @@ -59,12 +63,38 @@ internal PgOutputAsyncEnumerable( CancellationToken cancellationToken, NpgsqlLogSequenceNumber? 
walLocation = null) { + _protocolVersion = options.ProtocolVersion; _connection = connection; _slot = slot; _options = options; _baseCancellationToken = cancellationToken; _walLocation = walLocation; + + if (_protocolVersion >= PgOutputProtocolVersion.V2) + { + _streamStartMessage = new(); + _streamStopMessage = new(); + _streamCommitMessage = new(); + } + if (_protocolVersion >= PgOutputProtocolVersion.V3) + { + _beginPrepareMessage = new(); + _prepareMessage = new(); + _commitPreparedMessage = new(); + _rollbackPreparedMessage = new(); + _streamPrepareMessage = new(); + } + + if (_protocolVersion >= PgOutputProtocolVersion.V4) + { + _parallelStreamAbortMessage = new(); + } + else if (_protocolVersion >= PgOutputProtocolVersion.V2) + { + _streamAbortMessage = new(); + } + var connector = _connection.Connector; _insertMessage = new(connector); _defaultUpdateMessage = new(connector); @@ -395,9 +425,23 @@ async IAsyncEnumerator StartReplicationInternal(Canc } case BackendReplicationMessageCode.StreamAbort: { - await buf.EnsureAsync(8).ConfigureAwait(false); - yield return _streamAbortMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, - transactionXid: buf.ReadUInt32(), subtransactionXid: buf.ReadUInt32()); + if (_protocolVersion >= PgOutputProtocolVersion.V4) + { + await buf.EnsureAsync(24).ConfigureAwait(false); + yield return _parallelStreamAbortMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + transactionXid: buf.ReadUInt32(), + subtransactionXid: buf.ReadUInt32(), + abortLsn: new(buf.ReadUInt64()), + abortTimestamp: PgDateTime.DecodeTimestamp(buf.ReadInt64(), DateTimeKind.Utc)); + + } + else + { + await buf.EnsureAsync(8).ConfigureAwait(false); + yield return _streamAbortMessage.Populate(xLogData.WalStart, xLogData.WalEnd, xLogData.ServerClock, + transactionXid: buf.ReadUInt32(), subtransactionXid: buf.ReadUInt32()); + + } continue; } case BackendReplicationMessageCode.BeginPrepare: diff --git 
a/src/Npgsql/Replication/PgOutput/PgOutputProtocolVersion.cs b/src/Npgsql/Replication/PgOutput/PgOutputProtocolVersion.cs new file mode 100644 index 0000000000..fd717b6791 --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/PgOutputProtocolVersion.cs @@ -0,0 +1,30 @@ +namespace Npgsql.Replication.PgOutput; + +/// +/// The Logical Streaming Replication Protocol version. +/// +public enum PgOutputProtocolVersion : ulong +{ + /// + /// Version 1 is supported for server version 10 and above. + /// + V1 = 1UL, + + /// + /// Version 2 is supported only for server version 14 and above, and it allows + /// streaming of large in-progress transactions. + /// + V2 = 2UL, + + /// + /// Version 3 is supported only for server version 15 and above, and it allows + /// streaming of two-phase commits. + /// + V3 = 3UL, + + /// + /// Version 4 is supported only for server version 16 and above, and it allows + /// streams of large in-progress transactions to be applied in parallel. + /// + V4 = 4UL +} diff --git a/src/Npgsql/Replication/PgOutput/PgOutputReplicationOptions.cs b/src/Npgsql/Replication/PgOutput/PgOutputReplicationOptions.cs index 5835b88ee2..b6aedba5d3 100644 --- a/src/Npgsql/Replication/PgOutput/PgOutputReplicationOptions.cs +++ b/src/Npgsql/Replication/PgOutput/PgOutputReplicationOptions.cs @@ -9,18 +9,61 @@ namespace Npgsql.Replication.PgOutput; /// public class PgOutputReplicationOptions : IEquatable { + /// + /// Creates a new instance of . + /// + /// The publication names to include into the stream + /// The version of the logical streaming replication protocol. + /// Passing in unsupported protocol version numbers may lead to runtime errors. + /// Send values in binary representation + /// Enable streaming of in-progress transactions. + /// Setting this to sets + /// to . 
+ /// Write logical decoding messages into the replication stream + /// Enable streaming of prepared transactions + [Obsolete("Please switch to the overloads that take PgOutputProtocolVersion and PgOutputStreamingMode values instead.")] + public PgOutputReplicationOptions(string publicationName, ulong protocolVersion, bool? binary = null, bool? streaming = null, + bool? messages = null, bool? twoPhase = null) + : this([publicationName ?? throw new ArgumentNullException(nameof(publicationName))], (PgOutputProtocolVersion)protocolVersion, + binary, streaming.HasValue ? streaming.Value ? PgOutputStreamingMode.On : PgOutputStreamingMode.Off : null, messages, twoPhase) + { + } + /// /// Creates a new instance of . /// /// The publication names to include into the stream /// The version of the logical streaming replication protocol /// Send values in binary representation - /// Enable streaming of in-progress transactions + /// Enable streaming of in-progress transactions + /// Write logical decoding messages into the replication stream + /// Enable streaming of prepared transactions + public PgOutputReplicationOptions(string publicationName, PgOutputProtocolVersion protocolVersion, bool? binary = null, + PgOutputStreamingMode? streamingMode = null, bool? messages = null, bool? twoPhase = null) + : this([publicationName ?? throw new ArgumentNullException(nameof(publicationName))], protocolVersion, binary, streamingMode, + messages, twoPhase) + { + } + + /// + /// Creates a new instance of . + /// + /// The publication names to include into the stream + /// The version of the logical streaming replication protocol. + /// Passing in unsupported protocol version numbers may lead to runtime errors. + /// Send values in binary representation + /// Enable streaming of in-progress transactions. + /// Setting this to sets + /// to . 
/// Write logical decoding messages into the replication stream /// Enable streaming of prepared transactions - public PgOutputReplicationOptions(string publicationName, ulong protocolVersion, bool? binary = null, bool? streaming = null, bool? messages = null, bool? twoPhase = null) - : this(new List { publicationName ?? throw new ArgumentNullException(nameof(publicationName)) }, protocolVersion, binary, streaming, messages, twoPhase) - { } + [Obsolete("Please switch to the overloads that take PgOutputProtocolVersion and PgOutputStreamingMode values instead.")] + public PgOutputReplicationOptions(IEnumerable publicationNames, ulong protocolVersion, bool? binary = null, + bool? streaming = null, bool? messages = null, bool? twoPhase = null) + : this(publicationNames, (PgOutputProtocolVersion)protocolVersion, binary, + streaming.HasValue ? streaming.Value ? PgOutputStreamingMode.On : PgOutputStreamingMode.Off : null, messages, twoPhase) + { + } /// /// Creates a new instance of . @@ -28,10 +71,11 @@ public PgOutputReplicationOptions(string publicationName, ulong protocolVersion, /// The publication names to include into the stream /// The version of the logical streaming replication protocol /// Send values in binary representation - /// Enable streaming of in-progress transactions + /// Enable streaming of in-progress transactions /// Write logical decoding messages into the replication stream /// Enable streaming of prepared transactions - public PgOutputReplicationOptions(IEnumerable publicationNames, ulong protocolVersion, bool? binary = null, bool? streaming = null, bool? messages = null, bool? twoPhase = null) + public PgOutputReplicationOptions(IEnumerable publicationNames, PgOutputProtocolVersion protocolVersion, bool? binary = null, + PgOutputStreamingMode? streamingMode = null, bool? messages = null, bool? 
twoPhase = null) { var publicationNamesList = new List(publicationNames); if (publicationNamesList.Count < 1) @@ -46,7 +90,7 @@ public PgOutputReplicationOptions(IEnumerable publicationNames, ulong pr PublicationNames = publicationNamesList; ProtocolVersion = protocolVersion; Binary = binary; - Streaming = streaming; + StreamingMode = streamingMode; Messages = messages; TwoPhase = twoPhase; } @@ -54,7 +98,7 @@ public PgOutputReplicationOptions(IEnumerable publicationNames, ulong pr /// /// The version of the Logical Streaming Replication Protocol /// - public ulong ProtocolVersion { get; } + public PgOutputProtocolVersion ProtocolVersion { get; } /// /// The publication names to stream @@ -74,10 +118,12 @@ public PgOutputReplicationOptions(IEnumerable publicationNames, ulong pr /// Enable streaming of in-progress transactions /// /// - /// This works as of logical streaming replication protocol version 2 (PostgreSQL 14+) + /// works as of logical streaming replication protocol version 2 (PostgreSQL 14+), + /// works as of logical streaming replication protocol version 4 (PostgreSQL 16+), /// // See: https://github.com/postgres/postgres/commit/464824323e57dc4b397e8b05854d779908b55304 - public bool? Streaming { get; } + // and https://github.com/postgres/postgres/commit/216a784829c2c5f03ab0c43e009126cbb819e9b2 + public PgOutputStreamingMode? StreamingMode { get; } /// /// Write logical decoding messages into the replication stream @@ -100,13 +146,21 @@ public PgOutputReplicationOptions(IEnumerable publicationNames, ulong pr internal IEnumerable> GetOptionPairs() { - yield return new KeyValuePair("proto_version", ProtocolVersion.ToString(CultureInfo.InvariantCulture)); + yield return new KeyValuePair("proto_version", ((ulong)ProtocolVersion).ToString(CultureInfo.InvariantCulture)); yield return new KeyValuePair("publication_names", "\"" + string.Join("\",\"", PublicationNames) + "\""); if (Binary != null) yield return new KeyValuePair("binary", Binary.Value ? 
"on" : "off"); - if (Streaming != null) - yield return new KeyValuePair("streaming", Streaming.Value ? "on" : "off"); + if (StreamingMode != null) + { + yield return new KeyValuePair("streaming", StreamingMode.Value switch + { + PgOutputStreamingMode.Off => "off", + PgOutputStreamingMode.On => "on", + PgOutputStreamingMode.Parallel => "parallel", + _ => throw new ArgumentOutOfRangeException($"Unknown {nameof(PgOutputStreamingMode)} value: {StreamingMode.Value}") + }); + } if (Messages != null) yield return new KeyValuePair("messages", Messages.Value ? "on" : "off"); if (TwoPhase != null) @@ -118,12 +172,12 @@ public bool Equals(PgOutputReplicationOptions? other) => other != null && ( ReferenceEquals(this, other) || ProtocolVersion == other.ProtocolVersion && PublicationNames.Equals(other.PublicationNames) && Binary == other.Binary && - Streaming == other.Streaming && Messages == other.Messages && TwoPhase == other.TwoPhase); + StreamingMode == other.StreamingMode && Messages == other.Messages && TwoPhase == other.TwoPhase); /// public override bool Equals(object? obj) => obj is PgOutputReplicationOptions other && other.Equals(this); /// - public override int GetHashCode() => HashCode.Combine(ProtocolVersion, PublicationNames, Binary, Streaming, Messages, TwoPhase); + public override int GetHashCode() => HashCode.Combine(ProtocolVersion, PublicationNames, Binary, StreamingMode, Messages, TwoPhase); } diff --git a/src/Npgsql/Replication/PgOutput/PgOutputStreamingMode.cs b/src/Npgsql/Replication/PgOutput/PgOutputStreamingMode.cs new file mode 100644 index 0000000000..935ad2792c --- /dev/null +++ b/src/Npgsql/Replication/PgOutput/PgOutputStreamingMode.cs @@ -0,0 +1,23 @@ +namespace Npgsql.Replication.PgOutput; + +/// +/// Option to enable streaming of in-progress transactions. +/// Minimum protocol version 2 is required to turn it on. Minimum protocol version 4 is required for the "parallel" option. 
+/// +public enum PgOutputStreamingMode +{ + /// + /// Disable streaming of in-progress transactions + /// + Off, + + /// + /// Enable streaming of in-progress transactions + /// + On, + + /// + /// Enable streaming of in-progress transactions and enable sending extra information with some messages to be used for parallelisation + /// + Parallel +} diff --git a/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs b/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs index df910af4d2..3d22b5f5f6 100644 --- a/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs +++ b/src/Npgsql/Replication/PgOutput/ReadonlyArrayBuffer.cs @@ -11,7 +11,7 @@ sealed class ReadOnlyArrayBuffer : IReadOnlyList int _size; public ReadOnlyArrayBuffer() - => _items = Array.Empty(); + => _items = []; ReadOnlyArrayBuffer(T[] items) { diff --git a/src/Npgsql/Replication/PgOutput/ReplicationValue.cs b/src/Npgsql/Replication/PgOutput/ReplicationValue.cs index c918325840..5f7d76b418 100644 --- a/src/Npgsql/Replication/PgOutput/ReplicationValue.cs +++ b/src/Npgsql/Replication/PgOutput/ReplicationValue.cs @@ -76,6 +76,12 @@ public bool IsUnchangedToastedValue /// The data type of the specified column. public Type GetFieldType() => _fieldDescription.FieldType; + /// + /// Gets the name of the specified column. + /// + /// The name of the specified column. + public string GetFieldName() => _fieldDescription.Name; + /// /// Gets the value of the specified column as a type. /// @@ -111,11 +117,12 @@ public async ValueTask Get(CancellationToken cancellationToken = default) using var registration = _readBuffer.Connector.StartNestedCancellableOperation(cancellationToken, attemptPgCancellation: false); - var reader = PgReader.Init(Length, _fieldDescription.DataFormat); + var reader = PgReader; + reader.Init(Length, _fieldDescription.DataFormat); await reader.StartReadAsync(info.ConverterInfo.BufferRequirement, cancellationToken).ConfigureAwait(false); var result = info.AsObject ? 
(T)await info.ConverterInfo.Converter.ReadAsObjectAsync(reader, cancellationToken).ConfigureAwait(false) - : await info.ConverterInfo.GetConverter().ReadAsync(reader, cancellationToken).ConfigureAwait(false); + : await info.ConverterInfo.Converter.UnsafeDowncast().ReadAsync(reader, cancellationToken).ConfigureAwait(false); await reader.EndReadAsync().ConfigureAwait(false); return result; } @@ -146,7 +153,8 @@ public Stream GetStream() throw new InvalidCastException($"Column '{_fieldDescription.Name}' is an unchanged TOASTed value (actual value not sent)."); } - var reader = _readBuffer.PgReader.Init(Length, _fieldDescription.DataFormat); + var reader = PgReader; + reader.Init(Length, _fieldDescription.DataFormat); return reader.GetStream(canSeek: false); } @@ -170,7 +178,8 @@ public TextReader GetTextReader() throw new InvalidCastException($"Column '{_fieldDescription.Name}' is an unchanged TOASTed value (actual value not sent)."); } - var reader = PgReader.Init(Length, _fieldDescription.DataFormat); + var reader = PgReader; + reader.Init(Length, _fieldDescription.DataFormat); reader.StartRead(info.ConverterInfo.BufferRequirement); var result = (TextReader)info.ConverterInfo.Converter.ReadAsObject(reader); reader.EndRead(); @@ -182,10 +191,11 @@ internal async Task Consume(CancellationToken cancellationToken) if (_isConsumed) return; - if (!PgReader.Initialized) - PgReader.Init(Length, _fieldDescription.DataFormat); - await PgReader.ConsumeAsync(cancellationToken: cancellationToken).ConfigureAwait(false); - await PgReader.CommitAsync(resuming: false).ConfigureAwait(false); + var reader = PgReader; + if (!reader.Initialized) + reader.Init(Length, _fieldDescription.DataFormat); + await reader.ConsumeAsync(cancellationToken: cancellationToken).ConfigureAwait(false); + await reader.CommitAsync().ConfigureAwait(false); _isConsumed = true; } diff --git a/src/Npgsql/Replication/ReplicationConnection.cs b/src/Npgsql/Replication/ReplicationConnection.cs index 
283ffd6d3f..3877254521 100644 --- a/src/Npgsql/Replication/ReplicationConnection.cs +++ b/src/Npgsql/Replication/ReplicationConnection.cs @@ -81,8 +81,8 @@ private protected ReplicationConnection(string? connectionString) : this() /// /// /// Since replication connections are a special kind of connection, - /// , , - /// and + /// , + /// and /// are always disabled no matter what you set them to in your connection string. /// [AllowNull] @@ -95,15 +95,10 @@ public string ConnectionString { { Pooling = false, Enlist = false, - Multiplexing = false, KeepAlive = 0, ReplicationMode = ReplicationMode }; - // Physical replication connections don't allow regular queries, so we can't load types from PG - if (ReplicationMode == ReplicationMode.Physical) - cs.ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading; - _npgsqlConnection.ConnectionString = cs.ToString(); } } @@ -241,8 +236,6 @@ public async Task Open(CancellationToken cancellationToken = default) SetTimeouts(CommandTimeout, CommandTimeout); - _npgsqlConnection.Connector!.LongRunningConnection = true; - ReplicationLogger = _npgsqlConnection.Connector!.LoggingConfiguration.ReplicationLogger; } @@ -326,8 +319,7 @@ public async Task IdentifySystem(CancellationTo /// The current setting of the run-time parameter specified in as . 
public Task Show(string parameterName, CancellationToken cancellationToken = default) { - if (parameterName is null) - throw new ArgumentNullException(nameof(parameterName)); + ArgumentNullException.ThrowIfNull(parameterName); return ShowInternal(parameterName, cancellationToken); @@ -459,8 +451,11 @@ internal async IAsyncEnumerator StartReplicationInternal( SetTimeouts(_walReceiverTimeout, CommandTimeout); - _sendFeedbackTimer = new Timer(TimerSendFeedback, state: null, WalReceiverStatusInterval, Timeout.InfiniteTimeSpan); - _requestFeedbackTimer = new Timer(TimerRequestFeedback, state: null, _requestFeedbackInterval, Timeout.InfiniteTimeSpan); + using (ExecutionContext.SuppressFlow()) // Don't capture the current ExecutionContext and its AsyncLocals onto the timer causing them to live forever + { + _sendFeedbackTimer = new Timer(TimerSendFeedback, state: null, WalReceiverStatusInterval, Timeout.InfiniteTimeSpan); + _requestFeedbackTimer = new Timer(TimerRequestFeedback, state: null, _requestFeedbackInterval, Timeout.InfiniteTimeSpan); + } while (true) { @@ -499,7 +494,7 @@ internal async IAsyncEnumerator StartReplicationInternal( // Our consumer may not have read the stream to the end, but it might as well have been us // ourselves bypassing the stream and reading directly from the buffer in StartReplication() if (!columnStream.IsDisposed && columnStream.Position < columnStream.Length && !bypassingStream) - await buf.Skip(checked((int)(columnStream.Length - columnStream.Position)), true).ConfigureAwait(false); + await buf.Skip(async: true, checked((int)(columnStream.Length - columnStream.Position))).ConfigureAwait(false); continue; } @@ -621,6 +616,7 @@ async Task SendFeedback(bool waitOnSemaphore = false, bool requestReply = false, if (buf.WriteSpaceLeft < len) await connector.Flush(async: true, cancellationToken).ConfigureAwait(false); + buf.StartMessage(len); buf.WriteByte(FrontendMessageCode.CopyData); buf.WriteInt32(len - 1); buf.WriteByte((byte)'r'); // 
TODO: enum/const? @@ -713,8 +709,7 @@ async void TimerSendFeedback(object? obj) /// A task representing the asynchronous drop operation. public Task DropReplicationSlot(string slotName, bool wait = false, CancellationToken cancellationToken = default) { - if (slotName is null) - throw new ArgumentNullException(nameof(slotName)); + ArgumentNullException.ThrowIfNull(slotName); CheckDisposed(); @@ -896,7 +891,7 @@ void SetTimeouts(TimeSpan readTimeout, TimeSpan writeTimeout) var connector = Connector; var readBuffer = connector.ReadBuffer; if (readBuffer != null) - readBuffer.Timeout = readTimeout > TimeSpan.Zero ? readTimeout : TimeSpan.Zero; + readBuffer.Timeout = readTimeout > TimeSpan.Zero ? readTimeout : Timeout.InfiniteTimeSpan; var writeBuffer = connector.WriteBuffer; if (writeBuffer != null) diff --git a/src/Npgsql/Replication/ReplicationSlot.cs b/src/Npgsql/Replication/ReplicationSlot.cs index 8790303444..1e9b3473b6 100644 --- a/src/Npgsql/Replication/ReplicationSlot.cs +++ b/src/Npgsql/Replication/ReplicationSlot.cs @@ -6,12 +6,10 @@ public abstract class ReplicationSlot { internal ReplicationSlot(string name) - { - Name = name; - } + => Name = name; /// /// The name of the newly-created replication slot. 
/// public string Name { get; } -} \ No newline at end of file +} diff --git a/src/Npgsql/Schema/DbColumnSchemaGenerator.cs b/src/Npgsql/Schema/DbColumnSchemaGenerator.cs index 300001e72d..3abda51fae 100644 --- a/src/Npgsql/Schema/DbColumnSchemaGenerator.cs +++ b/src/Npgsql/Schema/DbColumnSchemaGenerator.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Collections.ObjectModel; using System.Data; +using System.Data.Common; using System.Threading; using System.Threading.Tasks; using System.Transactions; @@ -30,77 +31,79 @@ internal DbColumnSchemaGenerator(NpgsqlConnection connection, RowDescriptionMess #region Columns queries static string GenerateColumnsQuery(Version pgVersion, string columnFieldFilter) => - $@"SELECT - typ.oid AS typoid, nspname, relname, attname, attrelid, attnum, attnotnull, - {(pgVersion.IsGreaterOrEqual(10) ? "attidentity != ''" : "FALSE")} AS isidentity, - CASE WHEN typ.typtype = 'd' THEN typ.typtypmod ELSE atttypmod END AS typmod, - CASE WHEN atthasdef THEN (SELECT pg_get_expr(adbin, cls.oid) FROM pg_attrdef WHERE adrelid = cls.oid AND adnum = attr.attnum) ELSE NULL END AS default, - CASE WHEN col.is_updatable = 'YES' THEN true ELSE false END AS is_updatable, - EXISTS ( - SELECT * FROM pg_index - WHERE pg_index.indrelid = cls.oid AND - pg_index.indisprimary AND - attnum = ANY (indkey) - ) AS isprimarykey, - EXISTS ( - SELECT * FROM pg_index - WHERE pg_index.indrelid = cls.oid AND - pg_index.indisunique AND - pg_index.{(pgVersion.IsGreaterOrEqual(11) ? "indnkeyatts" : "indnatts")} = 1 AND - attnum = pg_index.indkey[0] - ) AS isunique + $""" +SELECT + typ.oid AS typoid, nspname, relname, attname, attrelid, attnum, attnotnull, + {(pgVersion.IsGreaterOrEqual(10) ? 
"attidentity != ''" : "FALSE")} AS isidentity, + CASE WHEN typ.typtype = 'd' THEN typ.typtypmod ELSE atttypmod END AS typmod, + CASE WHEN atthasdef THEN (SELECT pg_get_expr(adbin, cls.oid) FROM pg_attrdef WHERE adrelid = cls.oid AND adnum = attr.attnum) ELSE NULL END AS default, + ((cls.relkind = ANY (ARRAY['r'::"char", 'p'::"char"])) + OR ((cls.relkind = ANY (ARRAY['v'::"char", 'f'::"char"])) + AND pg_column_is_updatable((cls.oid)::regclass, attr.attnum, false))) + {(pgVersion.IsGreaterOrEqual(10) ? "AND attr.attidentity NOT IN ('a')" : "")} + AS is_updatable, + EXISTS ( + SELECT * FROM pg_index + WHERE pg_index.indrelid = cls.oid AND + pg_index.indisprimary AND + attnum = ANY (indkey) + ) AS isprimarykey, + EXISTS ( + SELECT * FROM pg_index + WHERE pg_index.indrelid = cls.oid AND + pg_index.indisunique AND + pg_index.{(pgVersion.IsGreaterOrEqual(11) ? "indnkeyatts" : "indnatts")} = 1 AND + attnum = pg_index.indkey[0] + ) AS isunique FROM pg_attribute AS attr JOIN pg_type AS typ ON attr.atttypid = typ.oid JOIN pg_class AS cls ON cls.oid = attr.attrelid JOIN pg_namespace AS ns ON ns.oid = cls.relnamespace -LEFT OUTER JOIN information_schema.columns AS col ON col.table_schema = nspname AND - col.table_name = relname AND - col.column_name = attname WHERE - atttypid <> 0 AND - relkind IN ('r', 'v', 'm') AND - NOT attisdropped AND - nspname NOT IN ('pg_catalog', 'information_schema') AND - attnum > 0 AND - ({columnFieldFilter}) -ORDER BY attnum"; + atttypid <> 0 AND + relkind IN ('r', 'v', 'm') AND + NOT attisdropped AND + nspname NOT IN ('pg_catalog', 'information_schema') AND + attnum > 0 AND + ({columnFieldFilter}) +ORDER BY attnum +"""; /// /// Stripped-down version of , mainly to support Amazon Redshift. 
/// static string GenerateOldColumnsQuery(string columnFieldFilter) => - $@"SELECT - typ.oid AS typoid, nspname, relname, attname, attrelid, attnum, attnotnull, - CASE WHEN typ.typtype = 'd' THEN typ.typtypmod ELSE atttypmod END AS typmod, - CASE WHEN atthasdef THEN (SELECT pg_get_expr(adbin, cls.oid) FROM pg_attrdef WHERE adrelid = cls.oid AND adnum = attr.attnum) ELSE NULL END AS default, - TRUE AS is_updatable, /* Supported only since PG 8.2 */ - FALSE AS isprimarykey, /* Can't do ANY() on pg_index.indkey which is int2vector */ - FALSE AS isunique /* Can't do ANY() on pg_index.indkey which is int2vector */ + $""" +SELECT + typ.oid AS typoid, nspname, relname, attname, attrelid, attnum, attnotnull, + CASE WHEN typ.typtype = 'd' THEN typ.typtypmod ELSE atttypmod END AS typmod, + CASE WHEN atthasdef THEN (SELECT pg_get_expr(adbin, cls.oid) FROM pg_attrdef WHERE adrelid = cls.oid AND adnum = attr.attnum) ELSE NULL END AS default, + TRUE AS is_updatable, /* Supported only since PG 8.2 */ + FALSE AS isprimarykey, /* Can't do ANY() on pg_index.indkey which is int2vector */ + FALSE AS isunique /* Can't do ANY() on pg_index.indkey which is int2vector */ FROM pg_attribute AS attr JOIN pg_type AS typ ON attr.atttypid = typ.oid JOIN pg_class AS cls ON cls.oid = attr.attrelid JOIN pg_namespace AS ns ON ns.oid = cls.relnamespace -LEFT OUTER JOIN information_schema.columns AS col ON col.table_schema = nspname AND - col.table_name = relname AND - col.column_name = attname WHERE - atttypid <> 0 AND - relkind IN ('r', 'v', 'm') AND - NOT attisdropped AND - nspname NOT IN ('pg_catalog', 'information_schema') AND - attnum > 0 AND - ({columnFieldFilter}) -ORDER BY attnum"; + atttypid <> 0 AND + relkind IN ('r', 'v', 'm') AND + NOT attisdropped AND + nspname NOT IN ('pg_catalog', 'information_schema') AND + attnum > 0 AND + ({columnFieldFilter}) +ORDER BY attnum +"""; #endregion Column queries - internal async Task> GetColumnSchema(bool async, CancellationToken cancellationToken = 
default) + internal async Task> GetColumnSchema(bool async, CancellationToken cancellationToken = default) where T : DbColumn { // This is mainly for Amazon Redshift var oldQueryMode = _connection.PostgreSqlVersion < new Version(8, 2); var numFields = _rowDescription.Count; - var result = new List(numFields); + var result = new List(numFields); for (var i = 0; i < numFields; i++) result.Add(null); var populatedColumns = 0; @@ -154,7 +157,7 @@ internal async Task> GetColumnSchema(bool asy // The column's ordinal is with respect to the resultset, not its table column.ColumnOrdinal = ordinal; - result[ordinal] = column; + result[ordinal] = (T?)(object)column; } } } @@ -173,14 +176,14 @@ internal async Task> GetColumnSchema(bool asy // Fill in whatever info we have from the RowDescription itself for (var i = 0; i < numFields; i++) { - var column = result[i]; + var column = (NpgsqlDbColumn?)(object?)result[i]; var field = _rowDescription[i]; if (column is null) { column = SetUpNonColumnField(field); column.ColumnOrdinal = i; - result[i] = column; + result[i] = (T?)(object)column; populatedColumns++; } @@ -260,8 +263,11 @@ void ColumnPostConfig(NpgsqlDbColumn column, int typeModifier) { var serializerOptions = _connection.Connector!.SerializerOptions; - column.NpgsqlDbType = column.PostgresType.DataTypeName.ToNpgsqlDbType(); - if (serializerOptions.GetObjectOrDefaultTypeInfo(column.PostgresType) is { } typeInfo) + // Call GetRepresentationalType to also handle domain types + // Because NpgsqlCommandBuilder relies on NpgsqlDbType for correct type mapping + // And otherwise we'll get NpgsqlDbType.Unknown + column.NpgsqlDbType = column.PostgresType.GetRepresentationalType().DataTypeName.ToNpgsqlDbType(); + if (serializerOptions.GetTypeInfo(typeof(object), serializerOptions.ToCanonicalTypeId(column.PostgresType)) is { } typeInfo) { column.DataType = typeInfo.Type; column.IsLong = column.PostgresType.DataTypeName == DataTypeNames.Bytea; diff --git 
a/src/Npgsql/Schema/NpgsqlDbColumn.cs b/src/Npgsql/Schema/NpgsqlDbColumn.cs index 4b118e97f6..e4597e3d86 100644 --- a/src/Npgsql/Schema/NpgsqlDbColumn.cs +++ b/src/Npgsql/Schema/NpgsqlDbColumn.cs @@ -1,6 +1,5 @@ using System; using System.Data.Common; -using System.Runtime.CompilerServices; using Npgsql.PostgresTypes; using NpgsqlTypes; @@ -32,7 +31,7 @@ public NpgsqlDbColumn() } internal NpgsqlDbColumn Clone() => - Unsafe.As(MemberwiseClone()); + (NpgsqlDbColumn)MemberwiseClone(); #region Standard fields // ReSharper disable once InconsistentNaming @@ -232,4 +231,4 @@ public override object? this[string propertyName] }; #endregion Npgsql-specific fields -} \ No newline at end of file +} diff --git a/src/Npgsql/Shims/DbDataSource.cs b/src/Npgsql/Shims/DbDataSource.cs deleted file mode 100644 index 6951d427fb..0000000000 --- a/src/Npgsql/Shims/DbDataSource.cs +++ /dev/null @@ -1,70 +0,0 @@ -#if !NET7_0_OR_GREATER - -using System.Threading; -using System.Threading.Tasks; - -#pragma warning disable CS1591 // Missing XML comment for publicly visible type or member (compatibility shim for old TFMs) - -// ReSharper disable once CheckNamespace -namespace System.Data.Common; - -public abstract class DbDataSource : IDisposable, IAsyncDisposable -{ - public abstract string ConnectionString { get; } - - protected abstract DbConnection CreateDbConnection(); - - // No need for an actual implementation in this compat shim - it's only implementation will be NpgsqlDataSource, which overrides this. - protected virtual DbConnection OpenDbConnection() - => throw new NotSupportedException(); - - // No need for an actual implementation in this compat shim - it's only implementation will be NpgsqlDataSource, which overrides this. 
- protected virtual ValueTask OpenDbConnectionAsync(CancellationToken cancellationToken = default) - => throw new NotSupportedException(); - - // No need for an actual implementation in this compat shim - it's only implementation will be NpgsqlDataSource, which overrides this. - protected virtual DbCommand CreateDbCommand(string? commandText = null) - => throw new NotSupportedException(); - - // No need for an actual implementation in this compat shim - it's only implementation will be NpgsqlDataSource, which overrides this. - protected virtual DbBatch CreateDbBatch() - => throw new NotSupportedException(); - - public DbConnection CreateConnection() - => CreateDbConnection(); - - public DbConnection OpenConnection() - => OpenDbConnection(); - - public ValueTask OpenConnectionAsync(CancellationToken cancellationToken = default) - => OpenDbConnectionAsync(cancellationToken); - - public DbCommand CreateCommand(string? commandText = null) - => CreateDbCommand(commandText); - - public DbBatch CreateBatch() - => CreateDbBatch(); - - public void Dispose() - { - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - public async ValueTask DisposeAsync() - { - await DisposeAsyncCore().ConfigureAwait(false); - - Dispose(disposing: false); - GC.SuppressFinalize(this); - } - - protected virtual void Dispose(bool disposing) - { - } - - protected virtual ValueTask DisposeAsyncCore() - => default; -} - -#endif \ No newline at end of file diff --git a/src/Npgsql/Shims/MemoryExtensions.cs b/src/Npgsql/Shims/MemoryExtensions.cs deleted file mode 100644 index 0da143f3c4..0000000000 --- a/src/Npgsql/Shims/MemoryExtensions.cs +++ /dev/null @@ -1,20 +0,0 @@ -#if !NET7_0_OR_GREATER -using System; - -namespace Npgsql; - -static class MemoryExtensions -{ - public static int IndexOfAnyExcept(this ReadOnlySpan span, T value0, T value1) where T : IEquatable - { - for (var i = 0; i < span.Length; i++) - { - var v = span[i]; - if (!v.Equals(value0) && !v.Equals(value1)) - return i; - } 
- - return -1; - } -} -#endif diff --git a/src/Npgsql/Shims/StreamExtensions.cs b/src/Npgsql/Shims/StreamExtensions.cs deleted file mode 100644 index 60dbf9ca3b..0000000000 --- a/src/Npgsql/Shims/StreamExtensions.cs +++ /dev/null @@ -1,38 +0,0 @@ -#if !NET7_0_OR_GREATER -using System.Threading; -using System.Threading.Tasks; - -// ReSharper disable once CheckNamespace -namespace System.IO -{ - // Helpers to read/write Span/Memory to Stream before netstandard 2.1 - static class StreamExtensions - { - public static void ReadExactly(this Stream stream, Span buffer) - { - var totalRead = 0; - while (totalRead < buffer.Length) - { - var read = stream.Read(buffer.Slice(totalRead)); - if (read is 0) - throw new EndOfStreamException(); - - totalRead += read; - } - } - - public static async ValueTask ReadExactlyAsync(this Stream stream, Memory buffer, CancellationToken cancellationToken = default) - { - var totalRead = 0; - while (totalRead < buffer.Length) - { - var read = await stream.ReadAsync(buffer.Slice(totalRead), cancellationToken).ConfigureAwait(false); - if (read is 0) - throw new EndOfStreamException(); - - totalRead += read; - } - } - } -} -#endif diff --git a/src/Npgsql/Shims/UnreachableException.cs b/src/Npgsql/Shims/UnreachableException.cs deleted file mode 100644 index f75989df13..0000000000 --- a/src/Npgsql/Shims/UnreachableException.cs +++ /dev/null @@ -1,39 +0,0 @@ -#if !NET7_0_OR_GREATER -namespace System.Diagnostics; - -/// -/// Exception thrown when the program executes an instruction that was thought to be unreachable. -/// -sealed class UnreachableException : Exception -{ - /// - /// Initializes a new instance of the class with the default error message. - /// - public UnreachableException() - : base("The program executed an instruction that was thought to be unreachable.") - { - } - - /// - /// Initializes a new instance of the - /// class with a specified error message. - /// - /// The error message that explains the reason for the exception. 
- public UnreachableException(string? message) - : base(message) - { - } - - /// - /// Initializes a new instance of the - /// class with a specified error message and a reference to the inner exception that is the cause of - /// this exception. - /// - /// The error message that explains the reason for the exception. - /// The exception that is the cause of the current exception. - public UnreachableException(string? message, Exception? innerException) - : base(message, innerException) - { - } -} -#endif diff --git a/src/Npgsql/SqlQueryParser.cs b/src/Npgsql/SqlQueryParser.cs index 2a76755f0b..c037a51342 100644 --- a/src/Npgsql/SqlQueryParser.cs +++ b/src/Npgsql/SqlQueryParser.cs @@ -7,7 +7,7 @@ namespace Npgsql; sealed class SqlQueryParser { - static NpgsqlParameterCollection EmptyParameters { get; } = new(); + static NpgsqlParameterCollection EmptyParameters { get; } = []; readonly Dictionary _paramIndexMap = new(StringComparer.OrdinalIgnoreCase); readonly StringBuilder _rewrittenSql = new(); @@ -501,10 +501,12 @@ void MoveToNextBatchCommand() { batchCommand = batchCommands[statementIndex]; batchCommand.Reset(); + batchCommand._parameters = parameters; } else { - batchCommand = new NpgsqlBatchCommand(); + batchCommand = new NpgsqlBatchCommand { _parameters = parameters }; + batchCommand.CommandText = sql; batchCommands.Add(batchCommand); } } diff --git a/src/Npgsql/TaskTimeoutAndCancellation.cs b/src/Npgsql/TaskTimeoutAndCancellation.cs deleted file mode 100644 index ceed87ba94..0000000000 --- a/src/Npgsql/TaskTimeoutAndCancellation.cs +++ /dev/null @@ -1,66 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using Npgsql.Util; - -namespace Npgsql; - -/// -/// Utility class to execute a potentially non-cancellable while allowing to timeout and/or cancel awaiting for it and at the same time prevent event if the original fails later. 
-/// -static class TaskTimeoutAndCancellation -{ - /// - /// Executes a potentially non-cancellable while allowing to timeout and/or cancel awaiting for it. - /// If the given task does not complete within , a is thrown. - /// The executed may be left in an incomplete state after the that this method returns completes dues to timeout and/or cancellation request. - /// The method guarantees that the abandoned, incomplete is not going to produce event if it fails later. - /// - /// Gets the for execution with a combined that attempts to cancel the in an event of the timeout or external cancellation request. - /// The timeout after which the should be faulted with a if it hasn't otherwise completed. - /// The to monitor for a cancellation request. - /// The result . - /// The representing the asynchronous wait. - internal static async Task ExecuteAsync(Func> getTaskFunc, NpgsqlTimeout timeout, CancellationToken cancellationToken) - { - Task? task = default; - await ExecuteAsync(ct => (Task)(task = getTaskFunc(ct)), timeout, cancellationToken).ConfigureAwait(false); - return await task!.ConfigureAwait(false); - } - - /// - /// Executes a potentially non-cancellable while allowing to timeout and/or cancel awaiting for it. - /// If the given task does not complete within , a is thrown. - /// The executed may be left in an incomplete state after the that this method returns completes dues to timeout and/or cancellation request. - /// The method guarantees that the abandoned, incomplete is not going to produce event if it fails later. - /// - /// Gets the for execution with a combined that attempts to cancel the in an event of the timeout or external cancellation request. - /// The timeout after which the should be faulted with a if it hasn't otherwise completed. - /// The to monitor for a cancellation request. - /// The representing the asynchronous wait. 
- internal static async Task ExecuteAsync(Func getTaskFunc, NpgsqlTimeout timeout, CancellationToken cancellationToken) - { - using var combinedCts = timeout.IsSet ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken) : null; - var task = getTaskFunc(combinedCts?.Token ?? cancellationToken); - try - { - try - { - await task.WaitAsync(timeout.CheckAndGetTimeLeft(), cancellationToken).ConfigureAwait(false); - } - catch (TimeoutException) when (!task!.IsCompleted) - { - // Attempt to stop the Task in progress. - combinedCts?.Cancel(); - throw; - } - } - catch - { - // Prevent unobserved Task notifications by observing the failed Task exception. - // To test: comment the next line out and re-run TaskExtensionsTest.DelayedFaultedTaskCancellation. - _ = task.ContinueWith(t => _ = t.Exception, CancellationToken.None, TaskContinuationOptions.OnlyOnFaulted, TaskScheduler.Current); - throw; - } - } -} diff --git a/src/Npgsql/ThrowHelper.cs b/src/Npgsql/ThrowHelper.cs index f20dac780c..dc79128537 100644 --- a/src/Npgsql/ThrowHelper.cs +++ b/src/Npgsql/ThrowHelper.cs @@ -1,5 +1,6 @@ using Npgsql.BackendMessages; using System; +using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using Npgsql.Internal; @@ -19,6 +20,10 @@ internal static void ThrowArgumentOutOfRangeException(string paramName, string m internal static void ThrowArgumentOutOfRangeException(string paramName, string message, object argument) => throw new ArgumentOutOfRangeException(paramName, string.Format(message, argument)); + [DoesNotReturn] + internal static void ThrowUnreachableException(string message, object argument) + => throw new UnreachableException(string.Format(message, argument)); + [DoesNotReturn] internal static void ThrowInvalidOperationException() => throw new InvalidOperationException(); @@ -83,10 +88,6 @@ internal static void ThrowArgumentException(string message) internal static void ThrowArgumentException(string message, string paramName) => throw new 
ArgumentException(message, paramName); - [DoesNotReturn] - internal static void ThrowArgumentNullException(string paramName) - => throw new ArgumentNullException(paramName); - [DoesNotReturn] internal static void ThrowArgumentNullException(string message, string paramName) => throw new ArgumentNullException(paramName, message); @@ -95,6 +96,10 @@ internal static void ThrowArgumentNullException(string message, string paramName internal static void ThrowIndexOutOfRangeException(string message) => throw new IndexOutOfRangeException(message); + [DoesNotReturn] + internal static void ThrowIndexOutOfRangeException(string message, object argument) + => throw new IndexOutOfRangeException(string.Format(message, argument)); + [DoesNotReturn] internal static void ThrowNotSupportedException(string? message = null) => throw new NotSupportedException(message); diff --git a/src/Npgsql/TypeMapping/GlobalTypeMapper.cs b/src/Npgsql/TypeMapping/GlobalTypeMapper.cs index caffde5fc0..ef3981d22f 100644 --- a/src/Npgsql/TypeMapping/GlobalTypeMapper.cs +++ b/src/Npgsql/TypeMapping/GlobalTypeMapper.cs @@ -1,14 +1,11 @@ using System; using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; -using System.Linq; using System.Text.Json; -using System.Text.Json.Nodes; using System.Threading; using Npgsql.Internal; using Npgsql.Internal.Postgres; using Npgsql.Internal.ResolverFactories; -using NpgsqlTypes; namespace Npgsql.TypeMapping; @@ -16,11 +13,9 @@ namespace Npgsql.TypeMapping; sealed class GlobalTypeMapper : INpgsqlTypeMapper { readonly UserTypeMapper _userTypeMapper = new(); - readonly List _pluginResolverFactories = new(); + readonly List _pluginResolverFactories = []; readonly ReaderWriterLockSlim _lock = new(); - PgTypeInfoResolverFactory[] _typeMappingResolvers = Array.Empty(); - - internal List HackyEnumTypeMappings { get; } = new(); + PgTypeInfoResolverFactory[] _typeMappingResolvers = []; internal IEnumerable GetPluginResolverFactories() { @@ -100,12 +95,12 @@ 
PgSerializerOptions TypeMappingOptions } } - internal DataTypeName? FindDataTypeName(Type type, object value) + internal DataTypeName? FindDataTypeName(Type type, object? value) { DataTypeName? dataTypeName; try { - var typeInfo = TypeMappingOptions.GetTypeInfo(type); + var typeInfo = TypeMappingOptions.GetTypeInfoInternal(type, null); if (typeInfo is PgResolverTypeInfo info) dataTypeName = info.GetObjectResolution(value).PgTypeId.DataTypeName; else @@ -153,6 +148,9 @@ public void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) } } + public void AddDbTypeResolverFactory(DbTypeResolverFactory factory) + => throw new NotSupportedException("The global type mapper does not support DbTypeResolverFactories. Call this method on a data source builder instead."); + void ReplaceTypeInfoResolverFactory(PgTypeInfoResolverFactory factory) { _lock.EnterWriteLock(); @@ -185,7 +183,6 @@ public void Reset() { _pluginResolverFactories.Clear(); _userTypeMapper.Items.Clear(); - HackyEnumTypeMappings.Clear(); } finally { @@ -245,13 +242,7 @@ public INpgsqlTypeMapper EnableUnmappedTypes() try { _userTypeMapper.MapEnum(pgName, nameTranslator); - - // Temporary hack for EFCore.PG enum mapping compat - if (_userTypeMapper.Items.FirstOrDefault(i => i.ClrType == typeof(TEnum)) is UserTypeMapping userTypeMapping) - HackyEnumTypeMappings.Add(new(typeof(TEnum), userTypeMapping.PgTypeName, nameTranslator ?? 
DefaultNameTranslator)); - ResetTypeMappingCache(); - return this; } finally @@ -267,13 +258,7 @@ public INpgsqlTypeMapper EnableUnmappedTypes() try { var removed = _userTypeMapper.UnmapEnum(pgName, nameTranslator); - - // Temporary hack for EFCore.PG enum mapping compat - if (removed && ((List)_userTypeMapper.Items).FindIndex(m => m.ClrType == typeof(TEnum)) is > -1 and var index) - HackyEnumTypeMappings.RemoveAt(index); - ResetTypeMappingCache(); - return removed; } finally @@ -307,9 +292,9 @@ public bool UnmapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes _lock.EnterWriteLock(); try { - var result = _userTypeMapper.UnmapEnum(clrType, pgName, nameTranslator); + var removed = _userTypeMapper.UnmapEnum(clrType, pgName, nameTranslator); ResetTypeMappingCache(); - return result; + return removed; } finally { diff --git a/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs b/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs index 83728785d6..3fc5d0cbf1 100644 --- a/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs +++ b/src/Npgsql/TypeMapping/INpgsqlTypeMapper.cs @@ -3,7 +3,6 @@ using System.Text.Json; using System.Text.Json.Nodes; using Npgsql.Internal; -using Npgsql.Internal.ResolverFactories; using Npgsql.NameTranslation; using NpgsqlTypes; @@ -197,8 +196,17 @@ bool UnmapComposite( /// Typically used by plugins. /// /// The type resolver factory to be added. + [Experimental(NpgsqlDiagnostics.ConvertersExperimental)] void AddTypeInfoResolverFactory(PgTypeInfoResolverFactory factory); + /// + /// Adds a DbType resolver factory which can change how DbType cases are mapped to PostgreSQL data types. + /// Typically used by plugins. + /// + /// The resolver factory to be added. + [Experimental(NpgsqlDiagnostics.DbTypeResolverExperimental)] + public void AddDbTypeResolverFactory(DbTypeResolverFactory factory); + /// /// Configures the JSON serializer options used when reading and writing all System.Text.Json data. 
/// diff --git a/src/Npgsql/TypeMapping/UserTypeMapper.cs b/src/Npgsql/TypeMapping/UserTypeMapper.cs index 35fabb90fe..3b7928bbd2 100644 --- a/src/Npgsql/TypeMapping/UserTypeMapper.cs +++ b/src/Npgsql/TypeMapping/UserTypeMapper.cs @@ -38,10 +38,19 @@ sealed class UserTypeMapper : PgTypeInfoResolverFactory readonly List _mappings; public IList Items => _mappings; - public INpgsqlNameTranslator DefaultNameTranslator { get; set; } = NpgsqlSnakeCaseNameTranslator.Instance; + INpgsqlNameTranslator _defaultNameTranslator = NpgsqlSnakeCaseNameTranslator.Instance; + public INpgsqlNameTranslator DefaultNameTranslator + { + get => _defaultNameTranslator; + set + { + ArgumentNullException.ThrowIfNull(value); + _defaultNameTranslator = value; + } + } - UserTypeMapper(IEnumerable mappings) => _mappings = new List(mappings); - public UserTypeMapper() => _mappings = new(); + UserTypeMapper(IEnumerable mappings) => _mappings = [..mappings]; + public UserTypeMapper() => _mappings = []; public UserTypeMapper Clone() => new(_mappings) { DefaultNameTranslator = DefaultNameTranslator }; @@ -65,9 +74,9 @@ public UserTypeMapper MapEnum([DynamicallyAccessedMembers(DynamicallyAccessedMem if (!clrType.IsEnum || !clrType.IsValueType) throw new ArgumentException("Type must be a concrete Enum", nameof(clrType)); - var openMethod = typeof(UserTypeMapper).GetMethod(nameof(MapEnum), new[] { typeof(string), typeof(INpgsqlNameTranslator) })!; + var openMethod = typeof(UserTypeMapper).GetMethod(nameof(MapEnum), [typeof(string), typeof(INpgsqlNameTranslator)])!; var method = openMethod.MakeGenericMethod(clrType); - method.Invoke(this, new object?[] { pgName, nameTranslator }); + method.Invoke(this, [pgName, nameTranslator]); return this; } @@ -107,11 +116,11 @@ public UserTypeMapper MapComposite([DynamicallyAccessedMembers(DynamicallyAccess var openMethod = typeof(UserTypeMapper).GetMethod( clrType.IsValueType ? 
nameof(MapStructComposite) : nameof(MapComposite), - new[] { typeof(string), typeof(INpgsqlNameTranslator) })!; + [typeof(string), typeof(INpgsqlNameTranslator)])!; var method = openMethod.MakeGenericMethod(clrType); - method.Invoke(this, new object?[] { pgName, nameTranslator }); + method.Invoke(this, [pgName, nameTranslator]); return this; } @@ -145,17 +154,15 @@ static string GetPgName(Type type, INpgsqlNameTranslator nameTranslator) => type.GetCustomAttribute()?.PgName ?? nameTranslator.TranslateTypeName(type.Name); - public override IPgTypeInfoResolver CreateResolver() => new Resolver(new(_mappings)); - public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver(new(_mappings)); + public override IPgTypeInfoResolver CreateResolver() => new Resolver([.._mappings]); + public override IPgTypeInfoResolver CreateArrayResolver() => new ArrayResolver([.._mappings]); - class Resolver : IPgTypeInfoResolver + class Resolver(List userTypeMappings) : IPgTypeInfoResolver { - protected readonly List _userTypeMappings; + protected readonly List _userTypeMappings = userTypeMappings; TypeInfoMappingCollection? _mappings; protected TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new()); - public Resolver(List userTypeMappings) => _userTypeMappings = userTypeMappings; - PgTypeInfo? IPgTypeInfoResolver.GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options); @@ -168,13 +175,11 @@ TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) } } - sealed class ArrayResolver : Resolver, IPgTypeInfoResolver + sealed class ArrayResolver(List userTypeMappings) : Resolver(userTypeMappings), IPgTypeInfoResolver { TypeInfoMappingCollection? _mappings; new TypeInfoMappingCollection Mappings => _mappings ??= AddMappings(new(base.Mappings)); - public ArrayResolver(List userTypeMappings) : base(userTypeMappings) { } - PgTypeInfo? IPgTypeInfoResolver.GetTypeInfo(Type? 
type, DataTypeName? dataTypeName, PgSerializerOptions options) => Mappings.Find(type, dataTypeName, options); @@ -188,62 +193,54 @@ TypeInfoMappingCollection AddMappings(TypeInfoMappingCollection mappings) } [RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] - sealed class CompositeMapping<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicProperties)] T> : UserTypeMapping where T : class + sealed class CompositeMapping< + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | + DynamicallyAccessedMemberTypes.PublicProperties)] + T>(string pgTypeName, INpgsqlNameTranslator nameTranslator) : UserTypeMapping(pgTypeName, typeof(T)) + where T : class { - readonly INpgsqlNameTranslator _nameTranslator; - - public CompositeMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) - : base(pgTypeName, typeof(T)) - => _nameTranslator = nameTranslator; - internal override void AddMapping(TypeInfoMappingCollection mappings) - { - mappings.AddType(PgTypeName, (options, mapping, _) => + => mappings.AddType(PgTypeName, (options, mapping, _) => { var pgType = mapping.GetPgType(options); if (pgType is not PostgresCompositeType compositeType) throw new InvalidOperationException("Composite mapping must be to a composite type"); return mapping.CreateInfo(options, new CompositeConverter( - ReflectionCompositeInfoFactory.CreateCompositeInfo(compositeType, _nameTranslator, options))); + ReflectionCompositeInfoFactory.CreateCompositeInfo(compositeType, nameTranslator, options))); }, isDefault: true); - } internal override void AddArrayMapping(TypeInfoMappingCollection mappings) => mappings.AddArrayType(PgTypeName); } 
[RequiresDynamicCode("Mapping composite types involves serializing arbitrary types which can require creating new generic types or methods. This is currently unsupported with NativeAOT, vote on issue #5303 if this is important to you.")] - sealed class StructCompositeMapping<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | DynamicallyAccessedMemberTypes.PublicProperties)] T> : UserTypeMapping where T : struct + sealed class StructCompositeMapping< + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors | DynamicallyAccessedMemberTypes.PublicFields | + DynamicallyAccessedMemberTypes.PublicProperties)] + T>(string pgTypeName, INpgsqlNameTranslator nameTranslator) : UserTypeMapping(pgTypeName, typeof(T)) + where T : struct { - readonly INpgsqlNameTranslator _nameTranslator; - - public StructCompositeMapping(string pgTypeName, INpgsqlNameTranslator nameTranslator) - : base(pgTypeName, typeof(T)) - => _nameTranslator = nameTranslator; - internal override void AddMapping(TypeInfoMappingCollection mappings) - { - mappings.AddStructType(PgTypeName, (options, mapping, dataTypeNameMatch) => + => mappings.AddStructType(PgTypeName, (options, mapping, requiresDataTypeName) => { var pgType = mapping.GetPgType(options); if (pgType is not PostgresCompositeType compositeType) throw new InvalidOperationException("Composite mapping must be to a composite type"); return mapping.CreateInfo(options, new CompositeConverter( - ReflectionCompositeInfoFactory.CreateCompositeInfo(compositeType, _nameTranslator, options))); + ReflectionCompositeInfoFactory.CreateCompositeInfo(compositeType, nameTranslator, options))); }, isDefault: true); - } internal override void AddArrayMapping(TypeInfoMappingCollection mappings) => mappings.AddStructArrayType(PgTypeName); } - internal abstract class EnumMapping : UserTypeMapping + internal abstract class EnumMapping( + string pgTypeName, + 
[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] Type enumClrType, + INpgsqlNameTranslator nameTranslator) + : UserTypeMapping(pgTypeName, enumClrType) { - internal INpgsqlNameTranslator NameTranslator { get; } - - public EnumMapping(string pgTypeName, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)]Type enumClrType, INpgsqlNameTranslator nameTranslator) - : base(pgTypeName, enumClrType) - => NameTranslator = nameTranslator; + internal INpgsqlNameTranslator NameTranslator { get; } = nameTranslator; } sealed class EnumMapping<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicFields)] TEnum> : EnumMapping diff --git a/src/Npgsql/UnpooledDataSource.cs b/src/Npgsql/UnpooledDataSource.cs index 549a45f9b8..55ce5d65af 100644 --- a/src/Npgsql/UnpooledDataSource.cs +++ b/src/Npgsql/UnpooledDataSource.cs @@ -6,13 +6,9 @@ namespace Npgsql; -sealed class UnpooledDataSource : NpgsqlDataSource +sealed class UnpooledDataSource(NpgsqlConnectionStringBuilder settings, NpgsqlDataSourceConfiguration dataSourceConfig) + : NpgsqlDataSource(settings, dataSourceConfig, reportMetrics: true) { - public UnpooledDataSource(NpgsqlConnectionStringBuilder settings, NpgsqlDataSourceConfiguration dataSourceConfig) - : base(settings, dataSourceConfig) - { - } - volatile int _numConnectors; internal override (int Total, int Idle, int Busy) Statistics => (_numConnectors, 0, _numConnectors); @@ -46,5 +42,7 @@ internal override void Return(NpgsqlConnector connector) connector.Close(); } - internal override void Clear() {} + public override void Clear() + { + } } diff --git a/src/Npgsql/Util/GSSStream.cs b/src/Npgsql/Util/GSSStream.cs new file mode 100644 index 0000000000..4f98a1d1fa --- /dev/null +++ b/src/Npgsql/Util/GSSStream.cs @@ -0,0 +1,177 @@ +using System; +using System.Buffers; +using System.Buffers.Binary; +using System.IO; +using System.Net.Security; +using System.Runtime.CompilerServices; +using System.Threading; +using 
System.Threading.Tasks; + +namespace Npgsql.Util; + +// For more detailed explanation of communication protocol +// See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-GSSAPI +sealed class GSSStream : Stream +{ + // At most, postgres supports GSS messages up to 16kb + // We use the recommended value of 8kb for the write buffer + // Which will result in messages of slightly larger than 8kb + const int MaxWriteMessageSizeLimit = 8 * 1024; + const int MaxReadMessageSizeLimit = 16 * 1024; + + readonly Stream _stream; + readonly NegotiateAuthentication _authentication; + + readonly ArrayBufferWriter _writeBuffer; + readonly byte[] _writeLengthBuffer; + + readonly byte[] _readBuffer; + int _readPosition; + int _leftToRead; + + internal GSSStream(Stream stream, NegotiateAuthentication authentication) + { + _stream = stream; + _authentication = authentication; + // While we guarantee that unencrypted messages are at most 8kb + // Encrypting them will result in messages slightly larger than the original size + // Which is why the initial capacity has an additional 2kb of free space + _writeBuffer = new ArrayBufferWriter(MaxWriteMessageSizeLimit + 2048); + _writeLengthBuffer = new byte[4]; + _readBuffer = new byte[MaxReadMessageSizeLimit]; + } + + public override void Write(ReadOnlySpan buffer) + { + var start = 0; + while (start != buffer.Length) + { + var lengthToWrite = Math.Min(buffer.Length - start, MaxWriteMessageSizeLimit); + var result = _authentication.Wrap( + buffer.Slice(start, lengthToWrite), + _writeBuffer, + _authentication.IsEncrypted, + out _); + if (result != NegotiateAuthenticationStatusCode.Completed) + throw new NpgsqlException($"Error while encrypting buffer: {result}"); + + var written = _writeBuffer.WrittenMemory; + Unsafe.WriteUnaligned(ref _writeLengthBuffer[0], BitConverter.IsLittleEndian ? 
BinaryPrimitives.ReverseEndianness(written.Length) : written.Length); + + _stream.Write(_writeLengthBuffer); + _stream.Write(_writeBuffer.WrittenMemory.Span); + + _writeBuffer.ResetWrittenCount(); + start += lengthToWrite; + } + } + + public override void Write(byte[] buffer, int offset, int count) + => Write(buffer.AsSpan(offset, count)); + + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + var start = 0; + while (start != buffer.Length) + { + var lengthToWrite = Math.Min(buffer.Length - start, MaxWriteMessageSizeLimit); + var result = _authentication.Wrap( + buffer.Slice(start, lengthToWrite).Span, + _writeBuffer, + _authentication.IsEncrypted, + out _); + if (result != NegotiateAuthenticationStatusCode.Completed) + throw new NpgsqlException($"Error while encrypting buffer: {result}"); + + var written = _writeBuffer.WrittenMemory; + Unsafe.WriteUnaligned(ref _writeLengthBuffer[0], BitConverter.IsLittleEndian ? BinaryPrimitives.ReverseEndianness(written.Length) : written.Length); + + await _stream.WriteAsync(_writeLengthBuffer, cancellationToken).ConfigureAwait(false); + await _stream.WriteAsync(_writeBuffer.WrittenMemory, cancellationToken).ConfigureAwait(false); + + _writeBuffer.ResetWrittenCount(); + start += lengthToWrite; + } + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => await WriteAsync(buffer.AsMemory(offset, count), cancellationToken).ConfigureAwait(false); + + public override void Flush() => _stream.Flush(); + + public override Task FlushAsync(CancellationToken cancellationToken) => _stream.FlushAsync(cancellationToken); + + public override int Read(Span buffer) + { + if (_leftToRead == 0) + { + _stream.ReadExactly(_readBuffer.AsSpan(0, 4)); + var messageLength = BitConverter.IsLittleEndian + ? 
BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref _readBuffer[0])) + : Unsafe.ReadUnaligned(ref _readBuffer[0]); + var messageBuffer = _readBuffer.AsSpan(0, messageLength); + _stream.ReadExactly(messageBuffer); + var result = _authentication.UnwrapInPlace(messageBuffer, out _readPosition, out _leftToRead, out _); + if (result != NegotiateAuthenticationStatusCode.Completed) + throw new NpgsqlException($"Error while decrypting buffer: {result}"); + } + + var maxRead = Math.Min(_leftToRead, buffer.Length); + _readBuffer.AsSpan(_readPosition, maxRead).CopyTo(buffer); + _readPosition += maxRead; + _leftToRead -= maxRead; + return maxRead; + } + + public override int Read(byte[] buffer, int offset, int count) + => Read(buffer.AsSpan(offset, count)); + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (_leftToRead == 0) + { + await _stream.ReadExactlyAsync(_readBuffer.AsMemory(0, 4), cancellationToken).ConfigureAwait(false); + var messageLength = BitConverter.IsLittleEndian + ? 
BinaryPrimitives.ReverseEndianness(Unsafe.ReadUnaligned(ref _readBuffer[0])) + : Unsafe.ReadUnaligned(ref _readBuffer[0]); + var messageBuffer = _readBuffer.AsMemory(0, messageLength); + await _stream.ReadExactlyAsync(messageBuffer, cancellationToken).ConfigureAwait(false); + var result = _authentication.UnwrapInPlace(messageBuffer.Span, out _readPosition, out _leftToRead, out _); + if (result != NegotiateAuthenticationStatusCode.Completed) + throw new NpgsqlException($"Error while decrypting buffer: {result}"); + } + + var maxRead = Math.Min(_leftToRead, buffer.Length); + _readBuffer.AsMemory(_readPosition, maxRead).CopyTo(buffer); + _readPosition += maxRead; + _leftToRead -= maxRead; + return maxRead; + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + => await ReadAsync(buffer.AsMemory(offset, count), cancellationToken).ConfigureAwait(false); + + public override void Close() => _stream.Close(); + + protected override void Dispose(bool disposing) + { + _authentication.Dispose(); + _stream.Dispose(); + } + + public override ValueTask DisposeAsync() => _stream.DisposeAsync(); + + public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException(); + public override void SetLength(long value) => throw new NotSupportedException(); + + public override bool CanRead => true; + public override bool CanWrite => true; + public override bool CanSeek => false; + public override long Length => throw new NotSupportedException(); + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } +} diff --git a/src/Npgsql/Util/LoggingEnumerable.cs b/src/Npgsql/Util/LoggingEnumerable.cs new file mode 100644 index 0000000000..eabc7ebdd5 --- /dev/null +++ b/src/Npgsql/Util/LoggingEnumerable.cs @@ -0,0 +1,36 @@ +using System.Collections; +using System.Collections.Generic; +using System.Text; + +namespace Npgsql.Util; + +// 
For logging batches we have to use a wrapper for parameters, otherwise they're logged as object[]. See https://github.com/npgsql/npgsql/issues/6078. +sealed class LoggingEnumerable(IEnumerable wrappedEnumerable) : IEnumerable +{ + public IEnumerator GetEnumerator() => wrappedEnumerable.GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)wrappedEnumerable).GetEnumerator(); + + public override string ToString() + { + var sb = new StringBuilder(); + + sb.Append('['); + + var appended = false; + + foreach (var o in wrappedEnumerable) + { + if (appended) + sb.Append(", "); + else + appended = true; + + sb.Append(o); + } + + sb.Append(']'); + + return sb.ToString(); + } +} diff --git a/src/Npgsql/Util/ManualResetValueTaskSource.cs b/src/Npgsql/Util/ManualResetValueTaskSource.cs deleted file mode 100644 index 55e45aa225..0000000000 --- a/src/Npgsql/Util/ManualResetValueTaskSource.cs +++ /dev/null @@ -1,21 +0,0 @@ -using System; -using System.Threading.Tasks.Sources; - -namespace Npgsql.Util; - -sealed class ManualResetValueTaskSource : IValueTaskSource, IValueTaskSource -{ - ManualResetValueTaskSourceCore _core; // mutable struct; do not make this readonly - - public bool RunContinuationsAsynchronously { get => _core.RunContinuationsAsynchronously; set => _core.RunContinuationsAsynchronously = value; } - public short Version => _core.Version; - public void Reset() => _core.Reset(); - public void SetResult(T result) => _core.SetResult(result); - public void SetException(Exception error) => _core.SetException(error); - - public T GetResult(short token) => _core.GetResult(token); - void IValueTaskSource.GetResult(short token) => _core.GetResult(token); - public ValueTaskSourceStatus GetStatus(short token) => _core.GetStatus(token); - public void OnCompleted(Action continuation, object? 
state, short token, ValueTaskSourceOnCompletedFlags flags) - => _core.OnCompleted(continuation, state, token, flags); -} \ No newline at end of file diff --git a/src/Npgsql/Util/ResettableCancellationTokenSource.cs b/src/Npgsql/Util/ResettableCancellationTokenSource.cs index 874d7a40f8..f4b1652e2a 100644 --- a/src/Npgsql/Util/ResettableCancellationTokenSource.cs +++ b/src/Npgsql/Util/ResettableCancellationTokenSource.cs @@ -13,17 +13,17 @@ namespace Npgsql.Util; /// we need to make sure that an existing cancellation token source hasn't been cancelled, /// every time we start it (see https://github.com/dotnet/runtime/issues/4694). /// -sealed class ResettableCancellationTokenSource : IDisposable +sealed class ResettableCancellationTokenSource(TimeSpan timeout) : IDisposable { bool isDisposed; - public TimeSpan Timeout { get; set; } + public TimeSpan Timeout { get; set; } = timeout; CancellationTokenSource _cts = new(); CancellationTokenRegistration? _registration; /// - /// Used, so we wouldn't concurently use the cts for the cancellation, while it's being disposed + /// Used, so we wouldn't concurrently use the cts for the cancellation, while it's being disposed /// readonly object lockObject = new(); @@ -31,9 +31,9 @@ sealed class ResettableCancellationTokenSource : IDisposable bool _isRunning; #endif - public ResettableCancellationTokenSource() => Timeout = InfiniteTimeSpan; - - public ResettableCancellationTokenSource(TimeSpan timeout) => Timeout = timeout; + public ResettableCancellationTokenSource() : this(InfiniteTimeSpan) + { + } /// /// Set the timeout on the wrapped diff --git a/src/Npgsql/Util/Statics.cs b/src/Npgsql/Util/Statics.cs index 2b1101171b..c21d10fbe5 100644 --- a/src/Npgsql/Util/Statics.cs +++ b/src/Npgsql/Util/Statics.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using System.Text; namespace Npgsql.Util; @@ -25,6 +26,15 @@ static Statics() 
DisableDateTimeInfinityConversions = AppContext.TryGetSwitch("Npgsql.DisableDateTimeInfinityConversions", out enabled) && enabled; } + /// Returns the escaped SQL representation of a string literal. + /// The identifier to be escaped. + internal static string EscapeLiteral(string literal) + { + // There is no support for escape sequences in quoted values for PostgreSQL, so replacing ' is enough. + // (to be able to use escaped characters an alternative syntax exists, it requires E to appear directly before the opening quote) + return literal.Replace("'", "''"); + } + internal static T Expect(IBackendMessage msg, NpgsqlConnector connector) { if (msg.GetType() != typeof(T)) @@ -88,9 +98,3 @@ static void ThrowUnknownMessageCode(BackendMessageCode code) => ThrowHelper.ThrowNpgsqlException($"Unknown message code: {code}"); } } - -static class EnumerableExtensions -{ - internal static string Join(this IEnumerable values, string separator) - => string.Join(separator, values); -} diff --git a/src/Npgsql/Util/SubReadStream.cs b/src/Npgsql/Util/SubReadStream.cs index 9f0176b631..8d9d1b1ec5 100644 --- a/src/Npgsql/Util/SubReadStream.cs +++ b/src/Npgsql/Util/SubReadStream.cs @@ -75,10 +75,7 @@ public override long Position public override bool CanWrite => false; void ThrowIfDisposed() - { - if (_isDisposed) - throw new ObjectDisposedException(GetType().ToString()); - } + => ObjectDisposedException.ThrowIf(_isDisposed, this); void ThrowIfCantRead() { diff --git a/src/Npgsql/Util/TaskSchedulerAwaitable.cs b/src/Npgsql/Util/TaskSchedulerAwaitable.cs index be16d8fa55..1b6d2c5647 100644 --- a/src/Npgsql/Util/TaskSchedulerAwaitable.cs +++ b/src/Npgsql/Util/TaskSchedulerAwaitable.cs @@ -6,11 +6,8 @@ namespace Npgsql.Util; -readonly struct TaskSchedulerAwaitable : ICriticalNotifyCompletion +readonly struct TaskSchedulerAwaitable(TaskScheduler scheduler) : ICriticalNotifyCompletion { - readonly TaskScheduler _scheduler; - public TaskSchedulerAwaitable(TaskScheduler scheduler) => 
_scheduler = scheduler; - public void GetResult() {} public bool IsCompleted => false; @@ -18,7 +15,7 @@ public void OnCompleted(Action continuation) { var task = Task.Factory.StartNew(continuation, CancellationToken.None, TaskCreationOptions.DenyChildAttach, - scheduler: _scheduler); + scheduler: scheduler); // Exceptions should never happen as the continuation should be the async statemachine. // It normally does its own error handling through the returned task unless it's an async void returning method. diff --git a/src/Npgsql/VolatileResourceManager.cs b/src/Npgsql/VolatileResourceManager.cs index 239b62fe8e..92a716f2e2 100644 --- a/src/Npgsql/VolatileResourceManager.cs +++ b/src/Npgsql/VolatileResourceManager.cs @@ -17,6 +17,7 @@ namespace Npgsql; sealed class VolatileResourceManager : ISinglePhaseNotification { NpgsqlConnector _connector; + NpgsqlDataSource _dataSource; Transaction _transaction; readonly string _txId; NpgsqlTransaction _localTx = null!; @@ -31,6 +32,7 @@ sealed class VolatileResourceManager : ISinglePhaseNotification internal VolatileResourceManager(NpgsqlConnection connection, Transaction transaction) { _connector = connection.Connector!; + _dataSource = connection.NpgsqlDataSource; _transaction = transaction; // _tx gets disposed by System.Transactions at some point, but we want to be able to log its local ID _txId = transaction.TransactionInformation.LocalIdentifier; @@ -277,8 +279,10 @@ void Dispose() { // We're here for connections which were closed before their TransactionScope completes. // These need to be closed now. - // We should return the connector to the pool only if we've successfully removed it from the pending list - if (_connector.TryRemovePendingEnlistedConnector(_transaction)) + // We should return the connector to the pool only if we've successfully removed it from the pending list. 
+ // Note that we remove it from the NpgsqlDataSource bound to connection and not to connector + // because of NpgsqlMultiHostDataSource which has its own list to which connection adds connectors. + if (_dataSource.TryRemovePendingEnlistedConnector(_connector, _transaction)) _connector.Return(); } @@ -289,10 +293,7 @@ void Dispose() #pragma warning restore CS8625 void CheckDisposed() - { - if (_isDisposed) - throw new ObjectDisposedException(nameof(VolatileResourceManager)); - } + => ObjectDisposedException.ThrowIf(_isDisposed, this); #endregion diff --git a/src/Shared/CodeAnalysis.cs b/src/Shared/CodeAnalysis.cs deleted file mode 100644 index 0e98a03210..0000000000 --- a/src/Shared/CodeAnalysis.cs +++ /dev/null @@ -1,90 +0,0 @@ - -namespace System.Diagnostics.CodeAnalysis -{ -#if !NET7_0_OR_GREATER - /// - /// Indicates that the specified method requires the ability to generate new code at runtime, - /// for example through . - /// - /// - /// This allows tools to understand which methods are unsafe to call when compiling ahead of time. - /// - [AttributeUsage(AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Class, Inherited = false)] - sealed class RequiresDynamicCodeAttribute : Attribute - { - /// - /// Initializes a new instance of the class - /// with the specified message. - /// - /// - /// A message that contains information about the usage of dynamic code. - /// - public RequiresDynamicCodeAttribute(string message) - { - Message = message; - } - - /// - /// Gets a message that contains information about the usage of dynamic code. - /// - public string Message { get; } - - /// - /// Gets or sets an optional URL that contains more information about the method, - /// why it requires dynamic code, and what options a consumer has to deal with it. - /// - public string? 
Url { get; set; } - } - - [AttributeUsage(AttributeTargets.Constructor, AllowMultiple = false, Inherited = false)] - sealed class SetsRequiredMembersAttribute : Attribute - { - } - [AttributeUsageAttribute(AttributeTargets.Method | AttributeTargets.Property | AttributeTargets.Parameter, AllowMultiple = false, Inherited = false)] - sealed class UnscopedRefAttribute : Attribute - { - /// - /// Initializes a new instance of the class. - /// - public UnscopedRefAttribute() { } - } -#endif -} - -namespace System.Runtime.CompilerServices -{ -#if !NET7_0_OR_GREATER - [AttributeUsage(AttributeTargets.Class | AttributeTargets.Struct | AttributeTargets.Field | AttributeTargets.Property, AllowMultiple = false, Inherited = false)] - sealed class RequiredMemberAttribute : Attribute - { } - - [AttributeUsage(AttributeTargets.All, AllowMultiple = true, Inherited = false)] - sealed class CompilerFeatureRequiredAttribute : Attribute - { - public CompilerFeatureRequiredAttribute(string featureName) - { - FeatureName = featureName; - } - - /// - /// The name of the compiler feature. - /// - public string FeatureName { get; } - - /// - /// If true, the compiler can choose to allow access to the location where this attribute is applied if it does not understand . - /// - public bool IsOptional { get; set; } - - /// - /// The used for the ref structs C# feature. - /// - public const string RefStructs = nameof(RefStructs); - - /// - /// The used for the required members C# feature. 
- /// - public const string RequiredMembers = nameof(RequiredMembers); - } -#endif -} diff --git a/test/Directory.Build.props b/test/Directory.Build.props index b51b1c04ba..6af6edc496 100644 --- a/test/Directory.Build.props +++ b/test/Directory.Build.props @@ -2,7 +2,7 @@ - net8.0 + net10.0 false diff --git a/test/MStatDumper/MStatDumper.csproj b/test/MStatDumper/MStatDumper.csproj index 3cab4d57fd..6405431678 100644 --- a/test/MStatDumper/MStatDumper.csproj +++ b/test/MStatDumper/MStatDumper.csproj @@ -2,8 +2,6 @@ Exe - - net8.0 enable disable diff --git a/test/Npgsql.Benchmarks/CommandExecuteBenchmarks.cs b/test/Npgsql.Benchmarks/CommandExecuteBenchmarks.cs index c75febe708..e2e6d4706a 100644 --- a/test/Npgsql.Benchmarks/CommandExecuteBenchmarks.cs +++ b/test/Npgsql.Benchmarks/CommandExecuteBenchmarks.cs @@ -55,8 +55,6 @@ public object ExecuteReader() class Config : ManualConfig { public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - } + => AddColumn(StatisticColumn.OperationsPerSecond); } -} \ No newline at end of file +} diff --git a/test/Npgsql.Benchmarks/Commit.cs b/test/Npgsql.Benchmarks/Commit.cs index 96e04ade96..9ab03c11db 100644 --- a/test/Npgsql.Benchmarks/Commit.cs +++ b/test/Npgsql.Benchmarks/Commit.cs @@ -29,8 +29,6 @@ public void Basic() class Config : ManualConfig { public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - } + => AddColumn(StatisticColumn.OperationsPerSecond); } -} \ No newline at end of file +} diff --git a/test/Npgsql.Benchmarks/ConnectionCreationBenchmarks.cs b/test/Npgsql.Benchmarks/ConnectionCreationBenchmarks.cs index e63bbba7c6..633445ae0a 100644 --- a/test/Npgsql.Benchmarks/ConnectionCreationBenchmarks.cs +++ b/test/Npgsql.Benchmarks/ConnectionCreationBenchmarks.cs @@ -22,8 +22,6 @@ public class ConnectionCreationBenchmarks class Config : ManualConfig { public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - } + => AddColumn(StatisticColumn.OperationsPerSecond); } -} \ No 
newline at end of file +} diff --git a/test/Npgsql.Benchmarks/ConnectionOpenCloseBenchmarks.cs b/test/Npgsql.Benchmarks/ConnectionOpenCloseBenchmarks.cs index d733ff9c11..ef5e69f62e 100644 --- a/test/Npgsql.Benchmarks/ConnectionOpenCloseBenchmarks.cs +++ b/test/Npgsql.Benchmarks/ConnectionOpenCloseBenchmarks.cs @@ -168,8 +168,6 @@ public void NonPooled() class Config : ManualConfig { public Config() - { - AddColumn(StatisticColumn.OperationsPerSecond); - } + => AddColumn(StatisticColumn.OperationsPerSecond); } -} \ No newline at end of file +} diff --git a/test/Npgsql.Benchmarks/Npgsql.Benchmarks.csproj b/test/Npgsql.Benchmarks/Npgsql.Benchmarks.csproj index bc51b25561..013bfb8a9d 100644 --- a/test/Npgsql.Benchmarks/Npgsql.Benchmarks.csproj +++ b/test/Npgsql.Benchmarks/Npgsql.Benchmarks.csproj @@ -4,6 +4,10 @@ portable Npgsql.Benchmarks Exe + $(NoWarn);NPG9001 + + + NU1901;NU1902;NU1903;NU1904 diff --git a/test/Npgsql.Benchmarks/Prepare.cs b/test/Npgsql.Benchmarks/Prepare.cs index 6b8d9b06bc..5648e75f98 100644 --- a/test/Npgsql.Benchmarks/Prepare.cs +++ b/test/Npgsql.Benchmarks/Prepare.cs @@ -54,9 +54,7 @@ public void GlobalSetup() [GlobalCleanup] public void GlobalCleanup() - { - _conn.Dispose(); - } + => _conn.Dispose(); public Prepare() { @@ -119,4 +117,4 @@ static string GenerateQuery(int tablesToJoin) .Values .Cast() .ToArray(); -} \ No newline at end of file +} diff --git a/test/Npgsql.Benchmarks/ResolveHandler.cs b/test/Npgsql.Benchmarks/ResolveHandler.cs index 86e5d20fbb..ead3a547ed 100644 --- a/test/Npgsql.Benchmarks/ResolveHandler.cs +++ b/test/Npgsql.Benchmarks/ResolveHandler.cs @@ -22,7 +22,7 @@ public void Setup() if (NumPlugins > 1) dataSourceBuilder.UseNetTopologySuite(); _dataSource = dataSourceBuilder.Build(); - _serializerOptions = _dataSource.SerializerOptions; + _serializerOptions = _dataSource.CurrentReloadableState.SerializerOptions; } [GlobalCleanup] @@ -30,13 +30,13 @@ public void Setup() [Benchmark] public PgTypeInfo? 
ResolveDefault() - => _serializerOptions.GetDefaultTypeInfo(new Oid(23)); // int4 + => _serializerOptions.GetTypeInfoInternal(null, new Oid(23)); // int4 [Benchmark] public PgTypeInfo? ResolveType() - => _serializerOptions.GetTypeInfo(typeof(int)); + => _serializerOptions.GetTypeInfoInternal(typeof(int), null); [Benchmark] public PgTypeInfo? ResolveBoth() - => _serializerOptions.GetTypeInfo(typeof(int), new Oid(23)); // int4 + => _serializerOptions.GetTypeInfoInternal(typeof(int), new Oid(23)); // int4 } diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs b/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs index 42f5f3936a..ae5dbfe0d9 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Numeric.cs @@ -5,42 +5,25 @@ namespace Npgsql.Benchmarks.TypeHandlers; [Config(typeof(Config))] -public class Int16 : TypeHandlerBenchmarks -{ - public Int16() : base(new Int2Converter()) { } -} +public class Int16() : TypeHandlerBenchmarks(new Int2Converter()); [Config(typeof(Config))] -public class Int32 : TypeHandlerBenchmarks -{ - public Int32() : base(new Int4Converter()) { } -} +public class Int32() : TypeHandlerBenchmarks(new Int4Converter()); [Config(typeof(Config))] -public class Int64 : TypeHandlerBenchmarks -{ - public Int64() : base(new Int8Converter()) { } -} +public class Int64() : TypeHandlerBenchmarks(new Int8Converter()); [Config(typeof(Config))] -public class Single : TypeHandlerBenchmarks -{ - public Single() : base(new RealConverter()) { } -} +public class Single() : TypeHandlerBenchmarks(new RealConverter()); [Config(typeof(Config))] -public class Double : TypeHandlerBenchmarks -{ - public Double() : base(new DoubleConverter()) { } -} +public class Double() : TypeHandlerBenchmarks(new DoubleConverter()); [Config(typeof(Config))] -public class Numeric : TypeHandlerBenchmarks +public class Numeric() : TypeHandlerBenchmarks(new DecimalNumericConverter()) { - public Numeric() : base(new DecimalNumericConverter()) 
{ } - - protected override IEnumerable ValuesOverride() => new[] - { + protected override IEnumerable ValuesOverride() => + [ 0.0000000000000000000000000001M, 0.000000000000000000000001M, 0.00000000000000000001M, @@ -55,12 +38,9 @@ protected override IEnumerable ValuesOverride() => new[] 10000000000000000M, 100000000000000000000M, 1000000000000000000000000M, - 10000000000000000000000000000M, - }; + 10000000000000000000000000000M + ]; } [Config(typeof(Config))] -public class Money : TypeHandlerBenchmarks -{ - public Money() : base(new MoneyConverter()) { } -} +public class Money() : TypeHandlerBenchmarks(new MoneyConverter()); diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Text.cs b/test/Npgsql.Benchmarks/TypeHandlers/Text.cs index 80d5f6ce0c..6216cdc5de 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/Text.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Text.cs @@ -6,10 +6,8 @@ namespace Npgsql.Benchmarks.TypeHandlers; [Config(typeof(Config))] -public class Text : TypeHandlerBenchmarks +public class Text() : TypeHandlerBenchmarks(new StringTextConverter(Encoding.UTF8)) { - public Text() : base(new StringTextConverter(Encoding.UTF8)) { } - protected override IEnumerable ValuesOverride() { for (var i = 1; i <= 10000; i *= 10) diff --git a/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs b/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs index 994839c219..9bc09dac99 100644 --- a/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/TypeHandlerBenchmarks.cs @@ -61,7 +61,7 @@ protected TypeHandlerBenchmarks(PgConverter handler) public IEnumerable Values() => ValuesOverride(); - protected virtual IEnumerable ValuesOverride() => new[] { default(T) }; + protected virtual IEnumerable ValuesOverride() => [default(T)]; [ParamsSource(nameof(Values))] public T Value diff --git a/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs b/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs index 7c229a3b57..a497a0c509 100644 --- 
a/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs +++ b/test/Npgsql.Benchmarks/TypeHandlers/Uuid.cs @@ -5,7 +5,4 @@ namespace Npgsql.Benchmarks.TypeHandlers; [Config(typeof(Config))] -public class Uuid : TypeHandlerBenchmarks -{ - public Uuid() : base(new GuidUuidConverter()) { } -} +public class Uuid() : TypeHandlerBenchmarks(new GuidUuidConverter()); diff --git a/test/Npgsql.DependencyInjection.Tests/Npgsql.DependencyInjection.Tests.csproj b/test/Npgsql.DependencyInjection.Tests/Npgsql.DependencyInjection.Tests.csproj index 5f0006d79c..2f1f442547 100644 --- a/test/Npgsql.DependencyInjection.Tests/Npgsql.DependencyInjection.Tests.csproj +++ b/test/Npgsql.DependencyInjection.Tests/Npgsql.DependencyInjection.Tests.csproj @@ -4,6 +4,10 @@ + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git a/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj b/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj index f384594eb3..7f9ce607ca 100644 --- a/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj +++ b/test/Npgsql.NativeAotTests/Npgsql.NativeAotTests.csproj @@ -9,12 +9,19 @@ true false true - Size + false + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + diff --git a/test/Npgsql.PluginTests/GeoJSONTests.cs b/test/Npgsql.PluginTests/GeoJSONTests.cs index 0630eebc8d..287c1277bc 100644 --- a/test/Npgsql.PluginTests/GeoJSONTests.cs +++ b/test/Npgsql.PluginTests/GeoJSONTests.cs @@ -23,89 +23,89 @@ public struct TestData } public static readonly TestData[] Tests = - { + [ new() { Geometry = new Point( new Position(longitude: 1d, latitude: 2d)) - { BoundingBoxes = new[] { 1d, 2d, 1d, 2d } }, + { BoundingBoxes = [1d, 2d, 1d, 2d] }, CommandText = "st_makepoint(1,2)" }, new() { - Geometry = new LineString(new[] { + Geometry = new LineString([ new Position(longitude: 1d, latitude: 1d), new Position(longitude: 1d, latitude: 2d) - }) - { BoundingBoxes = new[] { 1d, 1d, 1d, 2d } }, + ]) + { BoundingBoxes = [1d, 
1d, 1d, 2d] }, CommandText = "st_makeline(st_makepoint(1,1), st_makepoint(1,2))" }, new() { - Geometry = new Polygon(new[] { - new LineString(new[] { + Geometry = new Polygon([ + new LineString([ new Position(longitude: 1d, latitude: 1d), new Position(longitude: 2d, latitude: 2d), new Position(longitude: 3d, latitude: 3d), new Position(longitude: 1d, latitude: 1d) - }) - }) - { BoundingBoxes = new[] { 1d, 1d, 3d, 3d } }, + ]) + ]) + { BoundingBoxes = [1d, 1d, 3d, 3d] }, CommandText = "st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)]))" }, new() { - Geometry = new MultiPoint(new[] { + Geometry = new MultiPoint([ new Point(new Position(longitude: 1d, latitude: 1d)) - }) - { BoundingBoxes = new[] { 1d, 1d, 1d, 1d } }, + ]) + { BoundingBoxes = [1d, 1d, 1d, 1d] }, CommandText = "st_multi(st_makepoint(1, 1))" }, new() { - Geometry = new MultiLineString(new[] { - new LineString(new[] { + Geometry = new MultiLineString([ + new LineString([ new Position(longitude: 1d, latitude: 1d), new Position(longitude: 1d, latitude: 2d) - }) - }) - { BoundingBoxes = new[] { 1d, 1d, 1d, 2d } }, + ]) + ]) + { BoundingBoxes = [1d, 1d, 1d, 2d] }, CommandText = "st_multi(st_makeline(st_makepoint(1,1), st_makepoint(1,2)))" }, new() { - Geometry = new MultiPolygon(new[] { - new Polygon(new[] { - new LineString(new[] { + Geometry = new MultiPolygon([ + new Polygon([ + new LineString([ new Position(longitude: 1d, latitude: 1d), new Position(longitude: 2d, latitude: 2d), new Position(longitude: 3d, latitude: 3d), new Position(longitude: 1d, latitude: 1d) - }) - }) - }) - { BoundingBoxes = new[] { 1d, 1d, 3d, 3d } }, + ]) + ]) + ]) + { BoundingBoxes = [1d, 1d, 3d, 3d] }, CommandText = "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)])))" }, new() { - Geometry = new GeometryCollection(new IGeometryObject[] { + Geometry = new GeometryCollection([ new Point(new 
Position(longitude: 1d, latitude: 1d)), - new MultiPolygon(new[] { - new Polygon(new[] { - new LineString(new[] { + new MultiPolygon([ + new Polygon([ + new LineString([ new Position(longitude: 1d, latitude: 1d), new Position(longitude: 2d, latitude: 2d), new Position(longitude: 3d, latitude: 3d), new Position(longitude: 1d, latitude: 1d) - }) - }) - }) - }) - { BoundingBoxes = new[] { 1d, 1d, 3d, 3d } }, + ]) + ]) + ]) + ]) + { BoundingBoxes = [1d, 1d, 3d, 3d] }, CommandText = "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1), st_makepoint(2,2), st_makepoint(3,3), st_makepoint(1,1)]))))" - }, - }; + } + ]; [Test, TestCaseSource(nameof(Tests))] public async Task Read(TestData data) @@ -138,24 +138,24 @@ public async Task IgnoreM() } public static readonly TestData[] NotAllZSpecifiedTests = - { + [ new() { - Geometry = new LineString(new[] { + Geometry = new LineString([ new Position(1d, 1d, 0d), new Position(2d, 2d) - }) + ]) }, new() { - Geometry = new LineString(new[] { + Geometry = new LineString([ new Position(1d, 1d, 0d), new Position(2d, 2d), new Position(3d, 3d), new Position(4d, 4d) - }) + ]) } - }; + ]; [Test, TestCaseSource(nameof(NotAllZSpecifiedTests))] public async Task Not_all_Z_specified(TestData data) @@ -304,7 +304,7 @@ public async Task Import_geometry(TestData data) await using var cmd = conn.CreateCommand(); cmd.CommandText = $"SELECT field FROM {table}"; await using var reader = await cmd.ExecuteReaderAsync(); - Assert.IsTrue(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); var actual = reader.GetValue(0); Assert.That(actual, Is.EqualTo(data.Geometry)); } @@ -315,18 +315,18 @@ public async Task Import_big_geometry() await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "id text, field geometry"); - var geometry = new MultiLineString(new[] { + var geometry = new MultiLineString([ new LineString( Enumerable.Range(1, 507) .Select(i => new 
Position(longitude: i, latitude: i)) .Append(new Position(longitude: 1d, latitude: 1d))), - new LineString(new[] { + new LineString([ new Position(longitude: 1d, latitude: 1d), new Position(longitude: 1d, latitude: 2d), new Position(longitude: 1d, latitude: 3d), - new Position(longitude: 1d, latitude: 1d), - }) - }); + new Position(longitude: 1d, latitude: 1d) + ]) + ]); await using (var writer = await conn.BeginBinaryImportAsync($"COPY {table} (id, field) FROM STDIN BINARY")) { @@ -341,7 +341,7 @@ public async Task Import_big_geometry() await using var cmd = conn.CreateCommand(); cmd.CommandText = $"SELECT field FROM {table}"; await using var reader = await cmd.ExecuteReaderAsync(); - Assert.IsTrue(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); var actual = reader.GetValue(0); Assert.That(actual, Is.EqualTo(geometry)); } @@ -375,18 +375,18 @@ public async Task Export_big_geometry() await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "id text, field geometry"); - var geometry = new Polygon(new[] { + var geometry = new Polygon([ new LineString( Enumerable.Range(1, 507) .Select(i => new Position(longitude: i, latitude: i)) .Append(new Position(longitude: 1d, latitude: 1d))), - new LineString(new[] { + new LineString([ new Position(longitude: 1d, latitude: 1d), new Position(longitude: 1d, latitude: 2d), new Position(longitude: 1d, latitude: 3d), - new Position(longitude: 1d, latitude: 1d), - }) - }); + new Position(longitude: 1d, latitude: 1d) + ]) + ]); await using (var writer = await conn.BeginBinaryImportAsync($"COPY {table} (id, field) FROM STDIN BINARY")) { diff --git a/test/Npgsql.PluginTests/JsonNetTests.cs b/test/Npgsql.PluginTests/JsonNetTests.cs index b3fb1e26bb..f20704e52f 100644 --- a/test/Npgsql.PluginTests/JsonNetTests.cs +++ b/test/Npgsql.PluginTests/JsonNetTests.cs @@ -1,10 +1,9 @@ using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Npgsql.Tests; -using NpgsqlTypes; using 
NUnit.Framework; using System; -using System.Text; +using System.Data; using System.Threading.Tasks; // ReSharper disable AccessToModifiedClosure @@ -15,9 +14,9 @@ namespace Npgsql.PluginTests; /// /// Tests for the Npgsql.Json.NET mapping plugin /// -[TestFixture(NpgsqlDbType.Jsonb)] -[TestFixture(NpgsqlDbType.Json)] -public class JsonNetTests : TestBase +[TestFixture("jsonb")] +[TestFixture("json")] +public class JsonNetTests(string dataTypeName) : TestBase { [Test] public Task Roundtrip_object() @@ -25,10 +24,8 @@ public Task Roundtrip_object() JsonDataSource, new Foo { Bar = 8 }, IsJsonb ? @"{""Bar"": 8}" : @"{""Bar"":8}", - _pgTypeName, - _npgsqlDbType, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeName, dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3085")] public Task Roundtrip_string() @@ -36,10 +33,8 @@ public Task Roundtrip_string() JsonDataSource, @"{""p"": 1}", @"{""p"": 1}", - _pgTypeName, - _npgsqlDbType, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeName, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3085")] public Task Roundtrip_char_array() @@ -47,21 +42,17 @@ public Task Roundtrip_char_array() JsonDataSource, @"{""p"": 1}".ToCharArray(), @"{""p"": 1}", - _pgTypeName, - _npgsqlDbType, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeName, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), valueTypeEqualsFieldType: false); [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3085")] public Task Roundtrip_byte_array() => AssertType( JsonDataSource, - Encoding.ASCII.GetBytes(@"{""p"": 1}"), + @"{""p"": 1}"u8.ToArray(), @"{""p"": 1}", - _pgTypeName, - _npgsqlDbType, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: false); 
+ dataTypeName, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.Binary), valueTypeEqualsFieldType: false); [Test] public Task Roundtrip_JObject() @@ -69,12 +60,8 @@ public Task Roundtrip_JObject() JsonDataSource, new JObject { ["Bar"] = 8 }, IsJsonb ? @"{""Bar"": 8}" : @"{""Bar"":8}", - _pgTypeName, - _npgsqlDbType, - // By default we map JObject to jsonb - isDefaultForWriting: IsJsonb, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeName, dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); [Test] public Task Roundtrip_JArray() @@ -82,18 +69,14 @@ public Task Roundtrip_JArray() JsonDataSource, new JArray(new[] { 1, 2, 3 }), IsJsonb ? "[1, 2, 3]" : "[1,2,3]", - _pgTypeName, - _npgsqlDbType, - // By default we map JArray to jsonb - isDefaultForWriting: IsJsonb, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeName, dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); [Test] public async Task Deserialize_failure() { await using var conn = await JsonDataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand($@"SELECT '[1, 2, 3]'::{_pgTypeName}", conn); + await using var cmd = new NpgsqlCommand($@"SELECT '[1, 2, 3]'::{dataTypeName}", conn); await using var reader = await cmd.ExecuteReaderAsync(); await reader.ReadAsync(); // Attempt to deserialize JSON array into object @@ -108,19 +91,17 @@ public async Task Clr_type_mapping() { var dataSourceBuilder = CreateDataSourceBuilder(); if (IsJsonb) - dataSourceBuilder.UseJsonNet(jsonbClrTypes: new[] { typeof(Foo) }); + dataSourceBuilder.UseJsonNet(jsonbClrTypes: [typeof(Foo)]); else - dataSourceBuilder.UseJsonNet(jsonClrTypes: new[] { typeof(Foo) }); + dataSourceBuilder.UseJsonNet(jsonClrTypes: [typeof(Foo)]); await using var dataSource = dataSourceBuilder.Build(); await AssertType( dataSource, new Foo { Bar = 8 }, IsJsonb ? 
@"{""Bar"": 8}" : @"{""Bar"":8}", - _pgTypeName, - _npgsqlDbType, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeName, + dataTypeInference: DataTypeInference.Nothing, valueTypeEqualsFieldType: false); } [Test] @@ -128,19 +109,17 @@ public async Task Roundtrip_clr_array() { var dataSourceBuilder = CreateDataSourceBuilder(); if (IsJsonb) - dataSourceBuilder.UseJsonNet(jsonbClrTypes: new[] { typeof(int[]) }); + dataSourceBuilder.UseJsonNet(jsonbClrTypes: [typeof(int[])]); else - dataSourceBuilder.UseJsonNet(jsonClrTypes: new[] { typeof(int[]) }); + dataSourceBuilder.UseJsonNet(jsonClrTypes: [typeof(int[])]); await using var dataSource = dataSourceBuilder.Build(); await AssertType( dataSource, new[] { 1, 2, 3 }, IsJsonb ? "[1, 2, 3]" : "[1,2,3]", - _pgTypeName, - _npgsqlDbType, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeName, dataTypeInference: DataTypeInference.Mismatch, + valueTypeEqualsFieldType: false, skipArrayCheck: true); // there is no value only mapping for int[][] } class DateWrapper @@ -157,34 +136,32 @@ public async Task Custom_serializer_settings() var dataSourceBuilder = CreateDataSourceBuilder(); if (IsJsonb) - dataSourceBuilder.UseJsonNet(jsonbClrTypes: new[] { typeof(DateWrapper) }, settings: settings); + dataSourceBuilder.UseJsonNet(jsonbClrTypes: [typeof(DateWrapper)], settings: settings); else - dataSourceBuilder.UseJsonNet(jsonClrTypes: new[] { typeof(DateWrapper) }, settings: settings); + dataSourceBuilder.UseJsonNet(jsonClrTypes: [typeof(DateWrapper)], settings: settings); await using var dataSource = dataSourceBuilder.Build(); await AssertType( dataSource, new DateWrapper { Date = new DateTime(2018, 04, 20) }, IsJsonb ? 
"{\"Date\": \"The 20th of April, 2018\"}" : "{\"Date\":\"The 20th of April, 2018\"}", - _pgTypeName, - _npgsqlDbType, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeName, + dataTypeInference: DataTypeInference.Nothing, valueTypeEqualsFieldType: false); } [Test] public async Task Bug3464() { var dataSourceBuilder = CreateDataSourceBuilder(); - dataSourceBuilder.UseJsonNet(jsonbClrTypes: new[] { typeof(Bug3464Class) }); + dataSourceBuilder.UseJsonNet(jsonbClrTypes: [typeof(Bug3464Class)]); await using var dataSource = dataSourceBuilder.Build(); var expected = new Bug3464Class { SomeString = new string('5', 8174) }; await using var conn = await dataSource.OpenConnectionAsync(); await using var cmd = new NpgsqlCommand(@"SELECT @p1, @p2", conn); - cmd.Parameters.AddWithValue("p1", expected).NpgsqlDbType = _npgsqlDbType; - cmd.Parameters.AddWithValue("p2", expected).NpgsqlDbType = _npgsqlDbType; + cmd.Parameters.AddWithValue("p1", expected).DataTypeName = dataTypeName; + cmd.Parameters.AddWithValue("p2", expected).DataTypeName = dataTypeName; await using var reader = cmd.ExecuteReader(); } @@ -261,9 +238,6 @@ class Foo public override int GetHashCode() => Bar.GetHashCode(); } - readonly NpgsqlDbType _npgsqlDbType; - readonly string _pgTypeName; - [OneTimeSetUp] public void SetUp() { @@ -276,13 +250,7 @@ public void SetUp() public async Task Teardown() => await JsonDataSource.DisposeAsync(); - public JsonNetTests(NpgsqlDbType npgsqlDbType) - { - _npgsqlDbType = npgsqlDbType; - _pgTypeName = npgsqlDbType.ToString().ToLower(); - } - - bool IsJsonb => _npgsqlDbType == NpgsqlDbType.Jsonb; + bool IsJsonb => dataTypeName == "jsonb"; NpgsqlDataSource JsonDataSource = default!; } diff --git a/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs b/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs index 6af0afec24..ff177b38a4 100644 --- a/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs +++ b/test/Npgsql.PluginTests/LegacyNodaTimeTests.cs @@ -2,10 +2,9 @@ using 
System.Data; using System.Threading.Tasks; using NodaTime; +using Npgsql.NodaTime.Internal; using Npgsql.Tests; -using NpgsqlTypes; using NUnit.Framework; -using Npgsql.NodaTime.Internal; namespace Npgsql.PluginTests; @@ -16,47 +15,35 @@ public class LegacyNodaTimeTests : TestBase, IDisposable [Test] public async Task Timestamp_as_ZonedDateTime() - { - await AssertType( + => await AssertType( new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InZoneLeniently(DateTimeZoneProviders.Tzdb[TimeZone]), "1998-04-12 13:26:38.789+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTimeOffset, - isNpgsqlDbTypeInferredFromClrType: false, isDefault: false); - } + "timestamp with time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.DateTimeOffset, DbType.Object), valueTypeEqualsFieldType: false); [Test] public Task Timestamp_as_Instant() => AssertType( new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), "1998-04-12 13:26:38.789", - "timestamp without time zone", - NpgsqlDbType.Timestamp, - DbType.DateTime, - isNpgsqlDbTypeInferredFromClrType: false); + "timestamp without time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.DateTime, DbType.Object)); [Test] public Task Timestamp_as_LocalDateTime() => AssertType( new LocalDateTime(1998, 4, 12, 13, 26, 38, 789), "1998-04-12 13:26:38.789", - "timestamp without time zone", - NpgsqlDbType.Timestamp, - DbType.DateTime, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + "timestamp without time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.DateTime, DbType.Object), valueTypeEqualsFieldType: false); [Test] public Task Timestamptz_as_Instant() => AssertType( new LocalDateTime(1998, 4, 12, 13, 26, 38, 789).InUtc().ToInstant(), "1998-04-12 15:26:38.789+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTimeOffset, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: 
false); + "timestamp with time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.DateTimeOffset, DbType.Object)); [Test] public async Task Timestamptz_ZonedDateTime_infinite_values_are_not_supported() diff --git a/test/Npgsql.PluginTests/NetTopologySuiteTests.cs b/test/Npgsql.PluginTests/NetTopologySuiteTests.cs index 4e225d121c..54a1a91026 100644 --- a/test/Npgsql.PluginTests/NetTopologySuiteTests.cs +++ b/test/Npgsql.PluginTests/NetTopologySuiteTests.cs @@ -14,60 +14,55 @@ namespace Npgsql.PluginTests; public class NetTopologySuiteTests : TestBase { static readonly TestCaseData[] TestCases = - { + [ new TestCaseData(Ordinates.None, new Point(1d, 2500d), "st_makepoint(1,2500)") .SetName("Point"), - new TestCaseData(Ordinates.None, new MultiPoint(new[] { new Point(new Coordinate(1d, 1d)) }), "st_multi(st_makepoint(1, 1))") + new TestCaseData(Ordinates.None, new MultiPoint([new Point(new Coordinate(1d, 1d))]), "st_multi(st_makepoint(1, 1))") .SetName("MultiPoint"), new TestCaseData( Ordinates.None, - new LineString(new[] { new Coordinate(1d, 1d), new Coordinate(1d, 2500d) }), + new LineString([new Coordinate(1d, 1d), new Coordinate(1d, 2500d)]), "st_makeline(st_makepoint(1,1),st_makepoint(1,2500))") .SetName("LineString"), new TestCaseData( Ordinates.None, - new MultiLineString(new[] - { - new LineString(new[] - { + new MultiLineString([ + new LineString([ new Coordinate(1d, 1d), new Coordinate(1d, 2500d) - }) - }), + ]) + ]), "st_multi(st_makeline(st_makepoint(1,1),st_makepoint(1,2500)))") .SetName("MultiLineString"), new TestCaseData( Ordinates.None, new Polygon( - new LinearRing(new[] - { + new LinearRing([ new Coordinate(1d, 1d), new Coordinate(2d, 2d), new Coordinate(3d, 3d), new Coordinate(1d, 1d) - }) + ]) ), "st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))") .SetName("Polygon"), new TestCaseData( Ordinates.None, - new MultiPolygon(new[] - { + new MultiPolygon([ new Polygon( - 
new LinearRing(new[] - { + new LinearRing([ new Coordinate(1d, 1d), new Coordinate(2d, 2d), new Coordinate(3d, 3d), new Coordinate(1d, 1d) - }) + ]) ) - }), + ]), "st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))") .SetName("MultiPolygon"), @@ -76,47 +71,40 @@ public class NetTopologySuiteTests : TestBase new TestCaseData( Ordinates.None, - new GeometryCollection(new Geometry[] - { + new GeometryCollection([ new Point(new Coordinate(1d, 1d)), - new MultiPolygon(new[] - { + new MultiPolygon([ new Polygon( - new LinearRing(new[] - { + new LinearRing([ new Coordinate(1d, 1d), new Coordinate(2d, 2d), new Coordinate(3d, 3d), new Coordinate(1d, 1d) - }) + ]) ) - }) - }), + ]) + ]), "st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)]))))") .SetName("Collection"), new TestCaseData( Ordinates.None, - new GeometryCollection(new Geometry[] - { + new GeometryCollection([ new Point(new Coordinate(1d, 1d)), - new GeometryCollection(new Geometry[] - { + new GeometryCollection([ new Point(new Coordinate(1d, 1d)), - new MultiPolygon(new[] - { + new MultiPolygon([ new Polygon( - new LinearRing(new[] - { + new LinearRing([ new Coordinate(1d, 1d), new Coordinate(2d, 2d), new Coordinate(3d, 3d), new Coordinate(1d, 1d) - }) + ]) ) - }) - }) - }), + ]) + ]) + ]), "st_collect(st_makepoint(1,1),st_collect(st_makepoint(1,1),st_multi(st_makepolygon(st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])))))") .SetName("CollectionNested"), @@ -126,23 +114,22 @@ public class NetTopologySuiteTests : TestBase new TestCaseData( Ordinates.XYZM, new Point( - new DotSpatialAffineCoordinateSequence(new[] { 1d, 2d }, new[] { 3d }, new[] { 4d }), + new DotSpatialAffineCoordinateSequence([1d, 2d], [3d], [4d]), GeometryFactory.Default), "st_makepoint(1,2,3,4)") .SetName("PointXYZM"), new TestCaseData( 
Ordinates.None, - new LinearRing(new[] - { - new Coordinate(1d, 1d), + new LinearRing([ + new Coordinate(1d, 1d), new Coordinate(2d, 2d), new Coordinate(3d, 3d), new Coordinate(1d, 1d) - }), + ]), "st_makeline(ARRAY[st_makepoint(1,1),st_makepoint(2,2),st_makepoint(3,3),st_makepoint(1,1)])") .SetName("LinearRing") - }; + ]; [Test, TestCaseSource(nameof(TestCases))] public async Task Read(Ordinates ordinates, Geometry geometry, string sqlRepresentation) @@ -173,8 +160,7 @@ await AssertType( new Geometry[] { point }, '{' + GetSqlLiteral(point) + '}', "geometry[]", - NpgsqlDbType.Geometry | NpgsqlDbType.Array, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -221,53 +207,45 @@ public async Task Concurrency_test() await adminConnection.ExecuteNonQueryAsync($"INSERT INTO {table} DEFAULT VALUES"); var point = new Point(new Coordinate(1d, 1d)); - var lineString = new LineString(new[] { new Coordinate(1d, 1d), new Coordinate(1d, 2500d) }); + var lineString = new LineString([new Coordinate(1d, 1d), new Coordinate(1d, 2500d)]); var polygon = new Polygon( - new LinearRing(new[] - { + new LinearRing([ new Coordinate(1d, 1d), new Coordinate(2d, 2d), new Coordinate(3d, 3d), new Coordinate(1d, 1d) - }) + ]) ); - var multiPoint = new MultiPoint(new[] { new Point(new Coordinate(1d, 1d)) }); - var multiLineString = new MultiLineString(new[] - { - new LineString(new[] - { + var multiPoint = new MultiPoint([new Point(new Coordinate(1d, 1d))]); + var multiLineString = new MultiLineString([ + new LineString([ new Coordinate(1d, 1d), new Coordinate(1d, 2500d) - }) - }); - var multiPolygon = new MultiPolygon(new[] - { + ]) + ]); + var multiPolygon = new MultiPolygon([ new Polygon( - new LinearRing(new[] - { + new LinearRing([ new Coordinate(1d, 1d), new Coordinate(2d, 2d), new Coordinate(3d, 3d), new Coordinate(1d, 1d) - }) + ]) ) - }); - var collection = new GeometryCollection(new Geometry[] - { + ]); + var collection = new 
GeometryCollection([ new Point(new Coordinate(1d, 1d)), - new MultiPolygon(new[] - { + new MultiPolygon([ new Polygon( - new LinearRing(new[] - { + new LinearRing([ new Coordinate(1d, 1d), new Coordinate(2d, 2d), new Coordinate(3d, 3d), new Coordinate(1d, 1d) - }) + ]) ) - }) - }); + ]) + ]); await Task.WhenAll(Enumerable.Range(0, 30).Select(i => Task.Run(async () => { diff --git a/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs b/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs index 59f581e7de..52068898d2 100644 --- a/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs +++ b/test/Npgsql.PluginTests/NodaTimeInfinityTests.cs @@ -1,11 +1,11 @@ using System; +using System.Data; using System.Threading.Tasks; using NodaTime; using Npgsql.Tests; using Npgsql.Util; using NpgsqlTypes; using NUnit.Framework; -using static Npgsql.NodaTime.Internal.NodaTimeUtils; namespace Npgsql.PluginTests; @@ -26,15 +26,12 @@ await AssertType( new DateInterval(LocalDate.MinIsoValue, LocalDate.MaxIsoValue), "[-infinity,infinity]", "daterange", - NpgsqlDbType.DateRange, - isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately await AssertType( new [] {new DateInterval(LocalDate.MinIsoValue, LocalDate.MaxIsoValue)}, """{"[-infinity,infinity]"}""", - "daterange[]", - NpgsqlDbType.DateRange | NpgsqlDbType.Array, - isDefault: false, skipArrayCheck: true); + "daterange[]", dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); await using var conn = await OpenConnectionAsync(); if (conn.PostgreSqlVersion < new Version(14, 0)) @@ -43,8 +40,7 @@ await AssertType( await AssertType( new [] {new DateInterval(LocalDate.MinIsoValue, LocalDate.MaxIsoValue)}, """{[-infinity,infinity]}""", - "datemultirange", - NpgsqlDbType.DateMultirange, 
isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + "datemultirange", dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); } [Test] @@ -281,6 +277,106 @@ public async Task DateConvertInfinity() } } + [Test] + public async Task Interval_write() + { + await using var conn = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(conn, "17.0", "Infinity values for intervals were introduced in PostgreSQL 17"); + await using var cmd = new NpgsqlCommand("SELECT $1::text", conn) + { + Parameters = { new() { Value = Period.MinValue, NpgsqlDbType = NpgsqlDbType.Interval } } + }; + + // While Period.MinValue technically isn't outside of supported values by postgres, we can't reasonably convert it + if (Statics.DisableDateTimeInfinityConversions) + { + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + await conn.OpenAsync(); + } + else + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("-infinity")); + + cmd.Parameters[0].Value = Period.MaxValue; + + // While Period.MaxValue technically isn't outside of supported values by postgres, we can't reasonably convert it + if (Statics.DisableDateTimeInfinityConversions) + Assert.That(async () => await cmd.ExecuteScalarAsync(), Throws.Exception.TypeOf()); + else + Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo("infinity")); + } + + [Test] + public async Task Interval_read() + { + await using var conn = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(conn, "17.0", "Infinity values for intervals were introduced in PostgreSQL 17"); + + await using var cmd = new NpgsqlCommand("SELECT '-infinity'::interval, 'infinity'::interval", conn); + + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + if (Statics.DisableDateTimeInfinityConversions) + { + Assert.That(() => reader[0], Throws.Exception.TypeOf()); + Assert.That(() => reader[1], 
Throws.Exception.TypeOf()); + } + else + { + Assert.That(reader[0], Is.EqualTo(Period.MinValue)); + Assert.That(reader[1], Is.EqualTo(Period.MaxValue)); + } + } + + [Test, Description("Makes sure that when ConvertInfinityDateTime is true, infinity values are properly converted")] + public async Task Interval_convert_infinity() + { + if (Statics.DisableDateTimeInfinityConversions) + return; + + await using var conn = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(conn, "17.0", "Infinity values for intervals were introduced in PostgreSQL 17"); + await conn.ExecuteNonQueryAsync("CREATE TEMP TABLE data (i1 INTERVAL, i2 INTERVAL)"); + + using (var cmd = new NpgsqlCommand("INSERT INTO data VALUES (@p1, @p2)", conn)) + { + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Interval, Period.MaxValue); + cmd.Parameters.AddWithValue("p2", NpgsqlDbType.Interval, Period.MinValue); + await cmd.ExecuteNonQueryAsync(); + } + + using (var cmd = new NpgsqlCommand("SELECT i1::TEXT, i2::TEXT, i1, i2 FROM data", conn)) + using (var reader = await cmd.ExecuteReaderAsync()) + { + await reader.ReadAsync(); + Assert.That(reader.GetValue(0), Is.EqualTo("infinity")); + Assert.That(reader.GetValue(1), Is.EqualTo("-infinity")); + Assert.That(reader.GetFieldValue(2), Is.EqualTo(Period.MaxValue)); + Assert.That(reader.GetFieldValue(3), Is.EqualTo(Period.MinValue)); + } + } + + [Test] + public async Task Inclusive_End_Range_Infinity_read() + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand( + "SELECT tstzrange('-infinity', 'infinity','[]') as val", conn); + + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + if (Statics.DisableDateTimeInfinityConversions) + { + Assert.That(() => reader[0], Throws.Exception.TypeOf()); + } + else + { + Assert.That(reader[0], Is.EqualTo(new Interval(Instant.MinValue, null))); + } + } + protected override NpgsqlDataSource DataSource { get; } public 
NodaTimeInfinityTests(bool disableDateTimeInfinityConversions) diff --git a/test/Npgsql.PluginTests/NodaTimeTests.cs b/test/Npgsql.PluginTests/NodaTimeTests.cs index a6af632723..bec0b46c9b 100644 --- a/test/Npgsql.PluginTests/NodaTimeTests.cs +++ b/test/Npgsql.PluginTests/NodaTimeTests.cs @@ -13,24 +13,25 @@ namespace Npgsql.PluginTests; -public class NodaTimeTests : MultiplexingTestBase, IDisposable +public class NodaTimeTests : TestBase, IDisposable { #region Timestamp without time zone static readonly TestCaseData[] TimestampValues = - { + [ new TestCaseData(new LocalDateTime(1998, 4, 12, 13, 26, 38, 789), "1998-04-12 13:26:38.789") .SetName("Timestamp_pre2000"), new TestCaseData(new LocalDateTime(2015, 1, 27, 8, 45, 12, 345), "2015-01-27 08:45:12.345") .SetName("Timestamp_post2000"), new TestCaseData(new LocalDateTime(1999, 12, 31, 23, 59, 59, 999).PlusNanoseconds(456000), "1999-12-31 23:59:59.999456") .SetName("Timestamp_with_microseconds") - }; + ]; [Test, TestCaseSource(nameof(TimestampValues))] public Task Timestamp_as_LocalDateTime(LocalDateTime localDateTime, string sqlLiteral) - => AssertType(localDateTime, sqlLiteral, "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2, - isNpgsqlDbTypeInferredFromClrType: false); + => AssertType(localDateTime, sqlLiteral, + "timestamp without time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.DateTime2, DbType.Object)); [Test] public Task Timestamp_as_unspecified_DateTime() @@ -38,19 +39,15 @@ public Task Timestamp_as_unspecified_DateTime() new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "1998-04-12 13:26:38", "timestamp without time zone", - NpgsqlDbType.Timestamp, - DbType.DateTime2, - isDefaultForReading: false); + dbType: DbType.DateTime2, valueTypeEqualsFieldType: false); [Test] public Task Timestamp_as_long() => AssertType( -54297202000000, "1998-04-12 13:26:38", - "timestamp without time zone", - NpgsqlDbType.Timestamp, - DbType.DateTime2, - 
isDefault: false); + "timestamp without time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTime2, DbType.Int64), valueTypeEqualsFieldType: false); [Test] public Task Timestamp_cannot_use_as_Instant() @@ -93,8 +90,7 @@ await AssertType( new(1998, 4, 12, 15, 26, 38)), """["1998-04-12 13:26:38","1998-04-12 15:26:38"]""", "tsrange", - NpgsqlDbType.TimestampRange, - isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately await AssertType( new [] { new NpgsqlRange( @@ -102,8 +98,8 @@ await AssertType( new(1998, 4, 12, 15, 26, 38)), }, """{"[\"1998-04-12 13:26:38\",\"1998-04-12 15:26:38\"]"}""", "tsrange[]", - NpgsqlDbType.TimestampRange | NpgsqlDbType.Array, - isDefault: false, skipArrayCheck: true); + dataTypeInference: DataTypeInference.Nothing, + skipArrayCheck: true); await using var conn = await OpenConnectionAsync(); if (conn.PostgreSqlVersion < new Version(14, 0)) @@ -114,8 +110,7 @@ await AssertType( new(1998, 4, 12, 13, 26, 38), new(1998, 4, 12, 15, 26, 38)), }, """{["1998-04-12 13:26:38","1998-04-12 15:26:38"]}""", - "tsmultirange", - NpgsqlDbType.TimestampMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + "tsmultirange", dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); } [Test] @@ -136,8 +131,7 @@ await AssertType( }, """{["1998-04-12 13:26:38","1998-04-12 15:26:38"],["1998-04-13 13:26:38","1998-04-13 15:26:38"]}""", "tsmultirange", - NpgsqlDbType.TimestampMultirange, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: DataTypeInference.Nothing); } #endregion Timestamp without time zone @@ -145,7 +139,7 @@ await AssertType( #region Timestamp with time zone static readonly TestCaseData[] TimestamptzValues = - { + [ new 
TestCaseData(new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), "1998-04-12 15:26:38+02") .SetName("Timestamptz_pre2000"), new TestCaseData(new LocalDateTime(2015, 1, 27, 8, 45, 12, 345).InUtc().ToInstant(), "2015-01-27 09:45:12.345+01") @@ -154,64 +148,53 @@ await AssertType( .SetName("Timestamptz_write_date_only"), new TestCaseData(new LocalDateTime(1999, 12, 31, 23, 59, 59, 999).PlusNanoseconds(456000).InUtc().ToInstant(), "2000-01-01 00:59:59.999456+01") .SetName("Timestamptz_with_microseconds") - }; + ]; [Test, TestCaseSource(nameof(TimestamptzValues))] public Task Timestamptz_as_Instant(Instant instant, string sqlLiteral) - => AssertType(instant, sqlLiteral, "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, - isNpgsqlDbTypeInferredFromClrType: false); + => AssertType(instant, sqlLiteral, + "timestamp with time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.DateTime, DbType.Object)); [Test] public Task Timestamptz_as_ZonedDateTime() => AssertType( new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc(), "1998-04-12 15:26:38+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTime, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false); + "timestamp with time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.DateTime, DbType.Object), valueTypeEqualsFieldType: false); [Test] public Task Timestamptz_as_OffsetDateTime() => AssertType( new LocalDateTime(1998, 4, 12, 13, 26, 38).WithOffset(Offset.Zero), "1998-04-12 15:26:38+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTime, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false); + "timestamp with time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.DateTime, DbType.Object), valueTypeEqualsFieldType: false); [Test] public Task Timestamptz_as_utc_DateTime() => AssertType( new DateTime(1998, 4, 12, 13, 26, 38, 
DateTimeKind.Utc), "1998-04-12 15:26:38+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTime, - isDefaultForReading: false); + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: DbType.DateTime, valueTypeEqualsFieldType: false); [Test] public Task Timestamptz_as_DateTimeOffset() => AssertType( new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), "1998-04-12 15:26:38+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTime, - isDefaultForReading: false); + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: DbType.DateTime, valueTypeEqualsFieldType: false); [Test] public Task Timestamptz_as_long() => AssertType( -54297202000000, "1998-04-12 15:26:38+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTime, - isDefault: false); + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTime, DbType.Int64), valueTypeEqualsFieldType: false); [Test] public Task Timestamptz_cannot_use_as_LocalDateTime() @@ -243,17 +226,14 @@ await AssertType( new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02")""", "tstzrange", - NpgsqlDbType.TimestampTzRange, - isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately await AssertType( new [] { new Interval( new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), }, """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\")"}""", - "tstzrange[]", - NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, - isDefault: false, skipArrayCheck: true); + "tstzrange[]", 
dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); await using var conn = await OpenConnectionAsync(); if (conn.PostgreSqlVersion < new Version(14, 0)) @@ -264,8 +244,7 @@ await AssertType( new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc().ToInstant()), }, """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02")}""", - "tstzmultirange", - NpgsqlDbType.TimestampTzMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + "tstzmultirange", dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); } [Test] @@ -273,18 +252,14 @@ public Task Tstzrange_with_no_end_as_Interval() => AssertType( new Interval(new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), null), """["1998-04-12 15:26:38+02",)""", - "tstzrange", - NpgsqlDbType.TimestampTzRange, - isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + "tstzrange", dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); [Test] public Task Tstzrange_with_no_start_as_Interval() => AssertType( new Interval(null, new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant()), """(,"1998-04-12 15:26:38+02")""", - "tstzrange", - NpgsqlDbType.TimestampTzRange, - isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + "tstzrange", dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); [Test] public Task Tstzrange_with_no_start_or_end_as_Interval() @@ -292,8 +267,7 @@ public Task Tstzrange_with_no_start_or_end_as_Interval() new Interval(null, null), """(,)""", "tstzrange", - NpgsqlDbType.TimestampTzRange, - isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); [Test] public Task Tstzrange_as_NpgsqlRange_of_Instant() @@ -302,10 +276,9 @@ public Task Tstzrange_as_NpgsqlRange_of_Instant() new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc().ToInstant(), new LocalDateTime(1998, 4, 12, 
15, 26, 38).InUtc().ToInstant()), """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", - "tstzrange", - NpgsqlDbType.TimestampTzRange, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false, skipArrayCheck: true); + "tstzrange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false, + skipArrayCheck: true); [Test] public Task Tstzrange_as_NpgsqlRange_of_ZonedDateTime() @@ -314,10 +287,9 @@ public Task Tstzrange_as_NpgsqlRange_of_ZonedDateTime() new LocalDateTime(1998, 4, 12, 13, 26, 38).InUtc(), new LocalDateTime(1998, 4, 12, 15, 26, 38).InUtc()), """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", - "tstzrange", - NpgsqlDbType.TimestampTzRange, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false, skipArrayCheck: true); + "tstzrange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false, + skipArrayCheck: true); [Test] public Task Tstzrange_as_NpgsqlRange_of_OffsetDateTime() @@ -326,10 +298,9 @@ public Task Tstzrange_as_NpgsqlRange_of_OffsetDateTime() new LocalDateTime(1998, 4, 12, 13, 26, 38).WithOffset(Offset.Zero), new LocalDateTime(1998, 4, 12, 15, 26, 38).WithOffset(Offset.Zero)), """["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"]""", - "tstzrange", - NpgsqlDbType.TimestampTzRange, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false, skipArrayCheck: true); + "tstzrange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false, + skipArrayCheck: true); [Test] public async Task Tstzmultirange_as_array_of_Interval() @@ -348,9 +319,7 @@ await AssertType( new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), }, """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"),["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02")}""", - "tstzmultirange", - NpgsqlDbType.TimestampTzMultirange, - isNpgsqlDbTypeInferredFromClrType: false); + "tstzmultirange", dataTypeInference: DataTypeInference.Nothing); } [Test] @@ 
-370,10 +339,8 @@ await AssertType( new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), }, """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", - "tstzmultirange", - NpgsqlDbType.TimestampTzMultirange, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false); + "tstzmultirange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); } [Test] @@ -393,10 +360,8 @@ await AssertType( new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc()), }, """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", - "tstzmultirange", - NpgsqlDbType.TimestampTzMultirange, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false); + "tstzmultirange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); } [Test] @@ -416,10 +381,8 @@ await AssertType( new LocalDateTime(1998, 4, 13, 15, 26, 38).WithOffset(Offset.Zero)), }, """{["1998-04-12 15:26:38+02","1998-04-12 17:26:38+02"],["1998-04-13 15:26:38+02","1998-04-13 17:26:38+02"]}""", - "tstzmultirange", - NpgsqlDbType.TimestampTzMultirange, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false); + "tstzmultirange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); } [Test] @@ -447,10 +410,7 @@ await AssertType( null) }, """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\")","[\"1998-04-13 15:26:38+02\",\"1998-04-13 17:26:38+02\")","[\"1998-04-13 15:26:38+02\",)","(,\"1998-04-13 15:26:38+02\")","(,)"}""", - "tstzrange[]", - NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForWriting: false); + "tstzrange[]", dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -469,10 +429,8 @@ await AssertType( new LocalDateTime(1998, 4, 13, 15, 26, 38).InUtc().ToInstant()), }, """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 
17:26:38+02\"]","[\"1998-04-13 15:26:38+02\",\"1998-04-13 17:26:38+02\"]"}""", - "tstzrange[]", - NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, - isNpgsqlDbTypeInferredFromClrType: false, - isDefault: false); + "tstzrange[]", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); } #endregion Timestamp with time zone @@ -481,16 +439,21 @@ await AssertType( [Test] public Task Date_as_LocalDate() - => AssertType(new LocalDate(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, - isNpgsqlDbTypeInferredFromClrType: false); + => AssertType(new LocalDate(2020, 10, 1), "2020-10-01", + "date", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.Date, DbType.Object)); [Test] public Task Date_as_DateTime() - => AssertType(new DateTime(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); + => AssertType(new DateTime(2020, 10, 1), "2020-10-01", + "date", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Date, DbType.DateTime2), valueTypeEqualsFieldType: false); [Test] public Task Date_as_int() - => AssertType(7579, "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); + => AssertType(7579, "2020-10-01", + "date", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Date, DbType.Int32), valueTypeEqualsFieldType: false); [Test] public async Task Daterange_as_DateInterval() @@ -499,15 +462,13 @@ await AssertType( new DateInterval(new(2002, 3, 4), new(2002, 3, 6)), "[2002-03-04,2002-03-07)", "daterange", - NpgsqlDbType.DateRange, - isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); // DateInterval[] is mapped to multirange by default, not array; test separately + dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); // DateInterval[] is mapped to multirange by default, not array; test separately await AssertType( new [] {new DateInterval(new(2002, 3, 4), new(2002, 3, 6))}, 
"""{"[2002-03-04,2002-03-07)"}""", - "daterange[]", - NpgsqlDbType.DateRange | NpgsqlDbType.Array, - isDefault: false, skipArrayCheck: true); + "daterange[]", dataTypeInference: DataTypeInference.Nothing, + skipArrayCheck: true); await using var conn = await OpenConnectionAsync(); if (conn.PostgreSqlVersion < new Version(14, 0)) @@ -516,8 +477,8 @@ await AssertType( await AssertType( new [] {new DateInterval(new(2002, 3, 4), new(2002, 3, 6))}, """{[2002-03-04,2002-03-07)}""", - "datemultirange", - NpgsqlDbType.DateMultirange, isNpgsqlDbTypeInferredFromClrType: false, skipArrayCheck: true); + "datemultirange", dataTypeInference: DataTypeInference.Nothing, + skipArrayCheck: true); } [Test] @@ -526,17 +487,16 @@ public async Task Daterange_as_NpgsqlRange_of_LocalDate() await AssertType( new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), "[2002-03-04,2002-03-06)", - "daterange", - NpgsqlDbType.DateRange, - isNpgsqlDbTypeInferredFromClrType: false, - isDefaultForReading: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately + "daterange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false, + skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately await AssertType( new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, """{"[2002-03-04,2002-03-06)"}""", - "daterange[]", - NpgsqlDbType.DateRange | NpgsqlDbType.Array, - isDefault: false, skipArrayCheck: true); + "daterange[]", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false, + skipArrayCheck: true); await using var conn = await OpenConnectionAsync(); if (conn.PostgreSqlVersion < new Version(14, 0)) @@ -545,8 +505,9 @@ await AssertType( await AssertType( new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, """{[2002-03-04,2002-03-06)}""", - "datemultirange", - NpgsqlDbType.DateMultirange, isDefault: false, 
skipArrayCheck: true); + "datemultirange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false, + skipArrayCheck: true); } [Test] @@ -563,8 +524,7 @@ await AssertType( }, "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", "datemultirange", - NpgsqlDbType.DateMultirange, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -580,15 +540,13 @@ await AssertType( new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) }, "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", - "datemultirange", - NpgsqlDbType.DateMultirange, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + "datemultirange", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); } [Test] public Task Date_as_DateOnly() - => AssertType(new DateOnly(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefaultForReading: false); + => AssertType(new DateOnly(2020, 10, 1), "2020-10-01", "date", dbType: DbType.Date, valueTypeEqualsFieldType: false); [Test] public async Task Daterange_as_NpgsqlRange_of_DateOnly() @@ -597,15 +555,15 @@ await AssertType( new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), "[2002-03-04,2002-03-06)", "daterange", - NpgsqlDbType.DateRange, - isDefaultForReading: false, skipArrayCheck: true); + valueTypeEqualsFieldType: false, + skipArrayCheck: true); await AssertType( new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false) }, """{"[2002-03-04,2002-03-06)"}""", - "daterange[]", - NpgsqlDbType.DateRange | NpgsqlDbType.Array, - isDefault: false, skipArrayCheck: true); + "daterange[]", dataTypeInference: DataTypeInference.Mismatch, + valueTypeEqualsFieldType: false, + skipArrayCheck: true); await using var conn = await OpenConnectionAsync(); if (conn.PostgreSqlVersion < new Version(14, 0)) @@ -614,8 +572,9 @@ await AssertType( await AssertType( new [] { new NpgsqlRange(new(2002, 3, 4), true, new(2002, 
3, 6), false) }, """{[2002-03-04,2002-03-06)}""", - "datemultirange", - NpgsqlDbType.DateMultirange, isDefault: false, skipArrayCheck: true); + "datemultirange", dataTypeInference: DataTypeInference.Mismatch, + valueTypeEqualsFieldType: false, + skipArrayCheck: true); } [Test] @@ -630,9 +589,7 @@ await AssertType( new DateInterval(new(2002, 3, 8), new(2002, 3, 10)) }, """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-11)"}""", - "daterange[]", - NpgsqlDbType.DateRange | NpgsqlDbType.Array, - isDefaultForWriting: false); + "daterange[]", dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -647,9 +604,7 @@ await AssertType( new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) }, """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-11)"}""", - "daterange[]", - NpgsqlDbType.DateRange | NpgsqlDbType.Array, - isDefault: false); + "daterange[]", dataTypeInference: DataTypeInference.Nothing, valueTypeEqualsFieldType: false); } #endregion Date @@ -658,28 +613,25 @@ await AssertType( [Test] public Task Time_as_LocalTime() - => AssertType(new LocalTime(10, 45, 34, 500), "10:45:34.5", "time without time zone", NpgsqlDbType.Time, DbType.Time, - isNpgsqlDbTypeInferredFromClrType: false); + => AssertType(new LocalTime(10, 45, 34, 500), "10:45:34.5", + "time without time zone", dataTypeInference: DataTypeInference.Nothing, + dbType: new(DbType.Time, DbType.Object)); [Test] public Task Time_as_TimeSpan() => AssertType( new TimeSpan(0, 10, 45, 34, 500), "10:45:34.5", - "time without time zone", - NpgsqlDbType.Time, - DbType.Time, - isDefault: false); + "time without time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Time, DbType.Object), valueTypeEqualsFieldType: false); [Test] public Task Time_as_TimeOnly() => AssertType( new TimeOnly(10, 45, 34, 500), "10:45:34.5", - "time without time zone", - NpgsqlDbType.Time, - DbType.Time, - isDefaultForReading: false); + "time without time zone", dataTypeInference: DataTypeInference.Mismatch, + 
dbType: DbType.Time, valueTypeEqualsFieldType: false); #endregion Time @@ -691,8 +643,7 @@ public Task TimeTz_as_OffsetTime() new OffsetTime(new LocalTime(1, 2, 3, 4).PlusNanoseconds(5000), Offset.FromHoursAndMinutes(3, 30) + Offset.FromSeconds(5)), "01:02:03.004005+03:30:05", "time with time zone", - NpgsqlDbType.TimeTz, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: DataTypeInference.Nothing); [Test] public async Task TimeTz_as_DateTimeOffset() @@ -700,14 +651,13 @@ public async Task TimeTz_as_DateTimeOffset() await AssertTypeRead( "13:03:45.51+02", "time with time zone", - new DateTimeOffset(1, 1, 2, 13, 3, 45, 510, TimeSpan.FromHours(2)), isDefault: false); + new DateTimeOffset(1, 1, 2, 13, 3, 45, 510, TimeSpan.FromHours(2)), valueTypeEqualsFieldType: false); await AssertTypeWrite( new DateTimeOffset(1, 1, 1, 13, 3, 45, 510, TimeSpan.FromHours(2)), "13:03:45.51+02", - "time with time zone", - NpgsqlDbType.TimeTz, - isDefault: false); + "time with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.DateTime)); } #endregion Time with time zone @@ -731,8 +681,7 @@ public Task Interval_as_Period() }.Build().Normalize(), "1 year 2 mons 25 days 05:06:07.008009", "interval", - NpgsqlDbType.Interval, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: DataTypeInference.Nothing); [Test] public Task Interval_as_Duration() @@ -740,10 +689,8 @@ public Task Interval_as_Duration() Duration.FromDays(5) + Duration.FromMinutes(4) + Duration.FromSeconds(3) + Duration.FromMilliseconds(2) + Duration.FromNanoseconds(1000), "5 days 00:04:03.002001", - "interval", - NpgsqlDbType.Interval, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + "interval", dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); [Test] public async Task Interval_as_Duration_with_months_fails() @@ -770,14 +717,50 @@ public async Task Bug3438() } } + [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/5867")] + public async Task Normalize_period_on_write() + { + var value = Period.FromTicks(-3675048768766); + var expected = value.Normalize(); + var expectedAfterRoundtripBuilder = expected.ToBuilder(); + // Postgres doesn't support nanoseconds, trim them to microseconds + expectedAfterRoundtripBuilder.Nanoseconds -= expected.Nanoseconds % 1000; + var expectedAfterRoundtrip = expectedAfterRoundtripBuilder.Build(); + + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1, $2", conn); + cmd.Parameters.AddWithValue(value); + cmd.Parameters.AddWithValue(expected); + + await using var reader = await cmd.ExecuteReaderAsync(); + await reader.ReadAsync(); + + var dbValue = reader.GetFieldValue(0); + var dbExpected = reader.GetFieldValue(1); + + Assert.That(dbValue, Is.EqualTo(dbExpected)); + Assert.That(dbValue, Is.EqualTo(expectedAfterRoundtrip)); + } + + [Test] + public async Task Period_write_throw_on_overflow() + { + var periodBuilder = new PeriodBuilder + { + Years = int.MaxValue + }; + var ex = await AssertTypeUnsupportedWrite(periodBuilder.Build(), "interval"); + Assert.That(ex.Message, Is.EqualTo(NpgsqlNodaTimeStrings.CannotWritePeriodDueToOverflow)); + Assert.That(ex.InnerException, Is.TypeOf()); + } + #endregion Interval #region Support protected override NpgsqlDataSource DataSource { get; } - public NodaTimeTests(MultiplexingMode multiplexingMode) - : base(multiplexingMode) + public NodaTimeTests() { var builder = CreateDataSourceBuilder(); builder.UseNodaTime(); diff --git a/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj b/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj index 30dfb8ea16..499373bc63 100644 --- a/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj +++ b/test/Npgsql.PluginTests/Npgsql.PluginTests.csproj @@ -5,6 +5,10 @@ + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + diff --git 
a/test/Npgsql.Specification.Tests/NpgsqlCommandTests.cs b/test/Npgsql.Specification.Tests/NpgsqlCommandTests.cs index c92cd069f9..8318435aa9 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlCommandTests.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlCommandTests.cs @@ -2,13 +2,8 @@ namespace Npgsql.Specification.Tests; -public sealed class NpgsqlCommandTests : CommandTestBase +public sealed class NpgsqlCommandTests(NpgsqlDbFactoryFixture fixture) : CommandTestBase(fixture) { - public NpgsqlCommandTests(NpgsqlDbFactoryFixture fixture) - : base(fixture) - { - } - // PostgreSQL only supports a single transaction on a given connection at a given time. As a result, // Npgsql completely ignores DbCommand.Transaction. public override void ExecuteReader_throws_when_transaction_required() {} diff --git a/test/Npgsql.Specification.Tests/NpgsqlConnectionTests.cs b/test/Npgsql.Specification.Tests/NpgsqlConnectionTests.cs index fa71ea0f2f..20f5bc2547 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlConnectionTests.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlConnectionTests.cs @@ -2,10 +2,4 @@ namespace Npgsql.Specification.Tests; -public sealed class NpgsqlConnectionTests : ConnectionTestBase -{ - public NpgsqlConnectionTests(NpgsqlDbFactoryFixture fixture) - : base(fixture) - { - } -} \ No newline at end of file +public sealed class NpgsqlConnectionTests(NpgsqlDbFactoryFixture fixture) : ConnectionTestBase(fixture); \ No newline at end of file diff --git a/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs b/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs index 356d1da966..3f3c9021aa 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlDataReaderTests.cs @@ -2,8 +2,4 @@ namespace Npgsql.Specification.Tests; -public sealed class NpgsqlDataReaderTests : DataReaderTestBase -{ - public NpgsqlDataReaderTests(NpgsqlSelectValueFixture fixture) - : base(fixture) {} -} \ No newline at end of file +public 
sealed class NpgsqlDataReaderTests(NpgsqlSelectValueFixture fixture) : DataReaderTestBase(fixture); \ No newline at end of file diff --git a/test/Npgsql.Specification.Tests/NpgsqlSelectValueFixture.cs b/test/Npgsql.Specification.Tests/NpgsqlSelectValueFixture.cs index 67f1d9f1b4..06bdb837f2 100644 --- a/test/Npgsql.Specification.Tests/NpgsqlSelectValueFixture.cs +++ b/test/Npgsql.Specification.Tests/NpgsqlSelectValueFixture.cs @@ -10,8 +10,7 @@ namespace Npgsql.Specification.Tests; public class NpgsqlSelectValueFixture : NpgsqlDbFactoryFixture, ISelectValueFixture, IDeleteFixture, IDisposable { public NpgsqlSelectValueFixture() - { - Utility.ExecuteNonQuery(this, @" + => Utility.ExecuteNonQuery(this, @" DROP TABLE IF EXISTS select_value; CREATE TABLE select_value ( @@ -39,7 +38,6 @@ INSERT INTO select_value VALUES (4, NULL, false, '0001-01-01', '0001-01-01', '0001-01-01', 0.000000000000001, 2.23e-308, '33221100-5544-7766-9988-aabbccddeeff', -32768, -2147483648, -9223372036854775808, 1.18e-38, NULL, '00:00:00'), (5, NULL, true, '9999-12-31', '9999-12-31 23:59:59.999', '9999-12-31 23:59:59.999 +14:00', 99999999999999999999.999999999999999, 1.79e308, 'ccddeeff-aabb-8899-7766-554433221100', 32767, 2147483647, 9223372036854775807, 3.40e38, NULL, '23:59:59.999'); "); - } public void Dispose() => Utility.ExecuteNonQuery(this, "DROP TABLE IF EXISTS select_value;"); @@ -51,8 +49,7 @@ public string CreateSelectSql(byte[] value) => public string SelectNoRows => "SELECT 1 WHERE 0 = 1;"; - public IReadOnlyCollection SupportedDbTypes { get; } = new ReadOnlyCollection(new[] - { + public IReadOnlyCollection SupportedDbTypes { get; } = new ReadOnlyCollection([ DbType.Binary, DbType.Boolean, DbType.Date, @@ -67,9 +64,9 @@ public string CreateSelectSql(byte[] value) => DbType.Single, DbType.String, DbType.Time - }); + ]); public Type NullValueExceptionType => typeof(InvalidCastException); public string DeleteNoRows => "DELETE FROM select_value WHERE 1 = 0"; -} \ No newline at end of 
file +} diff --git a/test/Npgsql.Tests/AuthenticationTests.cs b/test/Npgsql.Tests/AuthenticationTests.cs index 5a041a7aca..a3765d41ae 100644 --- a/test/Npgsql.Tests/AuthenticationTests.cs +++ b/test/Npgsql.Tests/AuthenticationTests.cs @@ -11,7 +11,7 @@ namespace Npgsql.Tests; -public class AuthenticationTests : MultiplexingTestBase +public class AuthenticationTests : TestBase { [Test] [NonParallelizable] // Sets environment variable @@ -66,7 +66,7 @@ public async Task Password_provider([Values]bool async) using var dataSource = dataSourceBuilder.Build(); using var conn = async ? await dataSource.OpenConnectionAsync() : dataSource.OpenConnection(); - Assert.True(async ? asyncProviderCalled : syncProviderCalled, "Password_provider not used"); + Assert.That(async ? asyncProviderCalled : syncProviderCalled, "Password_provider not used"); } [Test] @@ -338,11 +338,9 @@ public void Password_source_precedence() static DeferDisposable Defer(Action action) => new(action); } - readonly struct DeferDisposable : IDisposable + readonly struct DeferDisposable(Action action) : IDisposable { - readonly Action _action; - public DeferDisposable(Action action) => _action = action; - public void Dispose() => _action(); + public void Dispose() => action(); } [Test, Description("Connects with a bad password to ensure the proper error is thrown")] @@ -370,9 +368,8 @@ public async Task Timeout_during_authentication() // request. 
This should trigger a timeout await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var connection = dataSource.CreateConnection(); - Assert.That(async () => await connection.OpenAsync(), - Throws.Exception.TypeOf() - .With.InnerException.TypeOf()); + var ex = Assert.ThrowsAsync(async () => await connection.OpenAsync()); + Assert.That(ex.InnerException, Is.TypeOf()); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1180")] @@ -420,16 +417,14 @@ public async Task ProvidePasswordCallback_is_used() using (var conn = new NpgsqlConnection(builder.ConnectionString) { ProvidePasswordCallback = ProvidePasswordCallback }) { conn.Open(); - Assert.True(getPasswordDelegateWasCalled, "ProvidePasswordCallback delegate not used"); + Assert.That(getPasswordDelegateWasCalled, "ProvidePasswordCallback delegate not used"); - // Do this again, since with multiplexing the very first connection attempt is done via - // the non-multiplexing path, to surface any exceptions. NpgsqlConnection.ClearPool(conn); conn.Close(); getPasswordDelegateWasCalled = false; conn.Open(); Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - Assert.True(getPasswordDelegateWasCalled, "ProvidePasswordCallback delegate not used"); + Assert.That(getPasswordDelegateWasCalled, "ProvidePasswordCallback delegate not used"); } string ProvidePasswordCallback(string host, int port, string database, string username) @@ -448,8 +443,6 @@ public void ProvidePasswordCallback_is_not_used() { conn.Open(); - // Do this again, since with multiplexing the very first connection attempt is done via - // the non-multiplexing path, to surface any exceptions. 
NpgsqlConnection.ClearPool(conn); conn.Close(); conn.Open(); @@ -502,10 +495,10 @@ public void ProvidePasswordCallback_gets_correct_arguments() using (var conn = new NpgsqlConnection(builder.ConnectionString) { ProvidePasswordCallback = ProvidePasswordCallback }) { conn.Open(); - Assert.AreEqual(builder.Host, receivedHost); - Assert.AreEqual(builder.Port, receivedPort); - Assert.AreEqual(builder.Database, receivedDatabase); - Assert.AreEqual(builder.Username, receivedUsername); + Assert.That(receivedHost, Is.EqualTo(builder.Host)); + Assert.That(receivedPort, Is.EqualTo(builder.Port)); + Assert.That(receivedDatabase, Is.EqualTo(builder.Database)); + Assert.That(receivedUsername, Is.EqualTo(builder.Username)); } string ProvidePasswordCallback(string host, int port, string database, string username) @@ -531,6 +524,4 @@ NpgsqlDataSourceBuilder GetPasswordlessDataSourceBuilder() Password = null } }; - - public AuthenticationTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/AutoPrepareTests.cs b/test/Npgsql.Tests/AutoPrepareTests.cs index 14d6997230..97a46ad277 100644 --- a/test/Npgsql.Tests/AutoPrepareTests.cs +++ b/test/Npgsql.Tests/AutoPrepareTests.cs @@ -168,6 +168,10 @@ public void Promote_auto_to_explicit() // cmd1's statement is no longer valid (has been closed), make sure it still works (will run unprepared) cmd2.ExecuteScalar(); + + // Trigger autoprepare on a different query to confirm we didn't leave replaced statement in a bad state + using var cmd3 = new NpgsqlCommand("SELECT 2", conn); + cmd3.ExecuteNonQuery(); cmd3.ExecuteNonQuery(); } [Test] @@ -534,6 +538,63 @@ public async Task SchemaOnly() await cmd.ExecuteScalarAsync(); } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6038")] + public async Task Auto_prepared_schema_only_correct_schema() + { + await using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 1; + csb.AutoPrepareMinUsages = 5; + }); + await using var 
connection = await dataSource.OpenConnectionAsync(); + var table1 = await CreateTempTable(connection, "foo int"); + var table2 = await CreateTempTable(connection, "bar int"); + + await using var cmd = connection.CreateCommand(); + cmd.CommandText = $"SELECT * FROM {table1}"; + for (var i = 0; i < 5; i++) + { + // Make sure we prepare the first query + await using (await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) { } + } + + cmd.CommandText = $"SELECT * FROM {table2}"; + // The second query will load RowDescription, which is a singleton on NpgsqlConnector + // This shouldn't affect the first query, because we create a copy of RowDescription on prepare + await using (await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) { } + + cmd.CommandText = $"SELECT * FROM {table1}"; + // If we indeed made a copy of RowDescription on prepare, we should get the column for the first query and not for the second + await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly | CommandBehavior.KeyInfo); + var columns = await reader.GetColumnSchemaAsync(); + Assert.That(columns.Count, Is.EqualTo(1)); + Assert.That(columns[0].ColumnName, Is.EqualTo("foo")); + } + + [Test] + public async Task Auto_prepared_schema_only_replace() + { + await using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 1; + csb.AutoPrepareMinUsages = 5; + }); + await using var connection = await dataSource.OpenConnectionAsync(); + + await using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT 1"; + for (var i = 0; i < 5; i++) + { + await using (await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) { } + } + + cmd.CommandText = "SELECT 2"; + for (var i = 0; i < 5; i++) + { + await using (await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly)) { } + } + } + [Test] public async Task Auto_prepared_statement_invalidation() { @@ -559,6 +620,33 @@ public async Task Auto_prepared_statement_invalidation() Assert.DoesNotThrowAsync(() => 
command.ExecuteNonQueryAsync()); } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6432")] + public async Task Reuse_batch_with_different_connectors() + { + await using var dataSource = CreateDataSource(csb => + { + csb.MaxAutoPrepare = 10; + csb.AutoPrepareMinUsages = 2; + }); + await using var batch = new NpgsqlBatch(); + batch.BatchCommands.Add(new NpgsqlBatchCommand("SELECT 1")); + await using (var connection = await dataSource.OpenConnectionAsync()) + { + batch.Connection = connection; + + for (var i = 0; i < 2; i++) + await batch.ExecuteNonQueryAsync(); + } + + dataSource.Clear(); + + await using (var connection = await dataSource.OpenConnectionAsync()) + { + batch.Connection = connection; + await batch.ExecuteNonQueryAsync(); + } + } + void DumpPreparedStatements(NpgsqlConnection conn) { using var cmd = new NpgsqlCommand("SELECT name,statement FROM pg_prepared_statements", conn); diff --git a/test/Npgsql.Tests/BatchTests.cs b/test/Npgsql.Tests/BatchTests.cs index 977bb5f98f..0a8daccac7 100644 --- a/test/Npgsql.Tests/BatchTests.cs +++ b/test/Npgsql.Tests/BatchTests.cs @@ -8,11 +8,9 @@ namespace Npgsql.Tests; -[TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.Default)] -[TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.Default)] -[TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.SequentialAccess)] -[TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.SequentialAccess)] -public class BatchTests : MultiplexingTestBase +[TestFixture(CommandBehavior.Default)] +[TestFixture(CommandBehavior.SequentialAccess)] +public class BatchTests : TestBase, IDisposable { #region Parameters @@ -70,24 +68,6 @@ public async Task Positional_parameters() Assert.That(await reader.NextResultAsync(), Is.False); } - [Test] - public async Task Out_parameters_are_not_allowed() - { - await using var conn = await OpenConnectionAsync(); - await using var batch = new NpgsqlBatch(conn) - { - BatchCommands = - { - new("SELECT @p1") - { - 
Parameters = { new("p", 8) { Direction = ParameterDirection.InputOutput } } - } - } - }; - - Assert.That(() => batch.ExecuteReaderAsync(Behavior), Throws.Exception.TypeOf()); - } - #endregion Parameters #region NpgsqlBatchCommand @@ -312,10 +292,10 @@ public async Task StatementOID() } [Test] - public void CanCreateParameter() => Assert.True(new NpgsqlBatchCommand().CanCreateParameter); + public void CanCreateParameter() => Assert.That(new NpgsqlBatchCommand().CanCreateParameter); [Test] - public void CreateParameter() => Assert.NotNull(new NpgsqlBatchCommand().CreateParameter()); + public void CreateParameter() => Assert.That(new NpgsqlBatchCommand().CreateParameter(), Is.Not.Null); #endregion NpgsqlBatchCommand @@ -495,7 +475,7 @@ public async Task Batch_with_multiple_errors([Values] bool withErrorBarriers) public async Task Batch_close_dispose_reader_with_multiple_errors([Values] bool withErrorBarriers, [Values] bool dispose) { // Create a temp pool since we dispose the reader (and check the state afterwards) and it can be reused by another connection - await using var dataSource = CreateDataSource(); + await using var dataSource = CreateDataSource(x => x.IncludeFailedBatchedCommand = true); await using var conn = await dataSource.OpenConnectionAsync(); var table = await CreateTempTable(conn, "id INT"); @@ -687,7 +667,7 @@ public async Task Empty_batch() } [Test] - public async Task Semicolon_is_not_allowed() + public async Task Semicolon_is_not_allowed_with_no_parameters() { await using var conn = await OpenConnectionAsync(); await using var batch = new NpgsqlBatch(conn) @@ -695,6 +675,24 @@ public async Task Semicolon_is_not_allowed() BatchCommands = { new("SELECT 1; SELECT 2") } }; + Assert.That(() => batch.ExecuteReaderAsync(Behavior), Throws.Exception.TypeOf()); + } + + [Test] + public async Task Semicolon_is_not_allowed_with_named_parameters() + { + await using var conn = await OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + 
BatchCommands = + { + new("SELECT @p1; SELECT 2") + { + Parameters = { new("p1", 1) } + } + } + }; + Assert.That(() => batch.ExecuteReaderAsync(Behavior), Throws.Exception.TypeOf()); } @@ -720,7 +718,7 @@ await conn.ExecuteNonQueryAsync($@" // resources are referenced by the exception above, which is very likely to escape the using statement of the command. batch.Dispose(); var cmd2 = conn.CreateBatch(); - Assert.AreNotSame(cmd2, batch); + Assert.That(batch, Is.Not.SameAs(cmd2)); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/967")] @@ -741,7 +739,6 @@ await conn.ExecuteNonQueryAsync($@" await using (var reader = await batch.ExecuteReaderAsync(Behavior)) { - var e = Assert.ThrowsAsync(async () => await reader.NextResultAsync())!; Assert.That(e.BatchCommand, Is.SameAs(batch.BatchCommands[1])); } @@ -750,7 +747,7 @@ await conn.ExecuteNonQueryAsync($@" // resources are referenced by the exception above, which is very likely to escape the using statement of the command. batch.Dispose(); var cmd2 = conn.CreateBatch(); - Assert.AreNotSame(cmd2, batch); + Assert.That(batch, Is.Not.SameAs(cmd2)); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4202")] @@ -816,100 +813,6 @@ public async Task Batch_dispose_reuse() #endregion Miscellaneous - #region Logging - - [Test] - public async Task Log_ExecuteScalar_single_statement_without_parameters() - { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlBatch(conn) - { - BatchCommands = { new("SELECT 1") } - }; - - using (listLoggerProvider.Record()) - { - await cmd.ExecuteScalarAsync(); - } - - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - - Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains("SELECT 1")); - AssertLoggingStateContains(executingCommandEvent, 
"CommandText", "SELECT 1"); - AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); - - if (!IsMultiplexing) - AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); - } - - [Test] - public async Task Log_ExecuteScalar_multiple_statements_with_parameters() - { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var batch = new NpgsqlBatch(conn) - { - BatchCommands = - { - new("SELECT $1") { Parameters = { new() { Value = 8 } } }, - new("SELECT $1, 9") { Parameters = { new() { Value = 9 } } } - } - }; - - using (listLoggerProvider.Record()) - { - await batch.ExecuteScalarAsync(); - } - - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - - // Note: the message formatter of Microsoft.Extensions.Logging doesn't seem to handle arrays inside tuples, so we get the - // following ugliness (https://github.com/dotnet/runtime/issues/63165). Serilog handles this fine. 
- Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT $1, System.Object[]), (SELECT $1, 9, System.Object[])]")); - AssertLoggingStateDoesNotContain(executingCommandEvent, "CommandText"); - AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); - - if (!IsMultiplexing) - AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); - - var batchCommands = (IList<(string CommandText, object[] Parameters)>)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); - Assert.That(batchCommands.Count, Is.EqualTo(2)); - Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT $1")); - Assert.That(batchCommands[0].Parameters[0], Is.EqualTo(8)); - Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT $1, 9")); - Assert.That(batchCommands[1].Parameters[0], Is.EqualTo(9)); - } - - [Test] - public async Task Log_ExecuteScalar_single_statement_with_parameter_logging_off() - { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, sensitiveDataLoggingEnabled: false); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var batch = new NpgsqlBatch(conn) - { - BatchCommands = - { - new("SELECT $1") { Parameters = { new() { Value = 8 } } }, - new("SELECT $1, 9") { Parameters = { new() { Value = 9 } } } - } - }; - - using (listLoggerProvider.Record()) - { - await batch.ExecuteScalarAsync(); - } - - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[SELECT $1, SELECT $1, 9]")); - var batchCommands = (IList)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); - Assert.That(batchCommands.Count, Is.EqualTo(2)); - Assert.That(batchCommands[0], Is.EqualTo("SELECT $1")); - Assert.That(batchCommands[1], Is.EqualTo("SELECT $1, 9")); - } - - #endregion 
Logging - #region Initialization / setup / teardown // ReSharper disable InconsistentNaming @@ -917,11 +820,16 @@ public async Task Log_ExecuteScalar_single_statement_with_parameter_logging_off( readonly CommandBehavior Behavior; // ReSharper restore InconsistentNaming - public BatchTests(MultiplexingMode multiplexingMode, CommandBehavior behavior) : base(multiplexingMode) + NpgsqlDataSource? _dataSource; + protected override NpgsqlDataSource DataSource => _dataSource ??= CreateDataSource(csb => csb.IncludeFailedBatchedCommand = true); + + public BatchTests(CommandBehavior behavior) { Behavior = behavior; IsSequential = (Behavior & CommandBehavior.SequentialAccess) != 0; } + public void Dispose() => DataSource.Dispose(); + #endregion } diff --git a/test/Npgsql.Tests/BugTests.cs b/test/Npgsql.Tests/BugTests.cs index e3c05dd5fb..5c0b77b1dd 100644 --- a/test/Npgsql.Tests/BugTests.cs +++ b/test/Npgsql.Tests/BugTests.cs @@ -4,6 +4,7 @@ using NUnit.Framework; using System; using System.Data; +using System.Numerics; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -174,8 +175,7 @@ public void Bug1695() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1700")] public void Bug1700() - { - Assert.That(() => + => Assert.That(() => { using var conn = OpenConnection(); using var tx = conn.BeginTransaction(); @@ -197,7 +197,6 @@ public void Bug1700() // Note, we never get here tx.Commit(); }, Throws.InvalidOperationException.With.Message.EqualTo("Some problem parsing the returned data")); - } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1964")] public void Bug1964() @@ -1263,12 +1262,12 @@ public async Task Bug3649() using (var exporter = await conn.BeginBinaryExportAsync($"COPY {table} (value) TO STDIN (FORMAT binary)")) { await exporter.StartRowAsync(); - Assert.IsTrue(exporter.IsNull); + Assert.That(exporter.IsNull); await exporter.SkipAsync(); await exporter.StartRowAsync(); - Assert.AreEqual(1, await exporter.ReadAsync()); + 
Assert.That(await exporter.ReadAsync(), Is.EqualTo(1)); await exporter.StartRowAsync(); - Assert.AreEqual(2, await exporter.ReadAsync()); + Assert.That(await exporter.ReadAsync(), Is.EqualTo(2)); } } @@ -1334,51 +1333,6 @@ public async Task Bug3924() } } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/4099")] - public async Task Bug4099() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Multiplexing = true, - MaxPoolSize = 1 - }; - await using var postmaster = PgPostmasterMock.Start(csb.ConnectionString); - await using var dataSource = CreateDataSource(postmaster.ConnectionString); - await using var firstConn = await dataSource.OpenConnectionAsync(); - await using var secondConn = await dataSource.OpenConnectionAsync(); - - var firstQuery = firstConn.ExecuteScalarAsync("SELECT data"); - - var server = await postmaster.WaitForServerConnection(); - await server.ExpectExtendedQuery(); - - var secondQuery = secondConn.ExecuteScalarAsync("SELECT other_data"); - await server.ExpectExtendedQuery(); - - var data = new byte[10000]; - await server - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(ByteaOid)) - .WriteDataRowWithFlush(data); - - var otherData = new byte[10]; - await server - .WriteCommandComplete() - .WriteReadyForQuery() - .WriteParseComplete() - .WriteBindComplete() - .WriteRowDescription(new FieldDescription(ByteaOid)) - .WriteDataRow(otherData) - .WriteCommandComplete() - .WriteReadyForQuery() - .FlushAsync(); - - Assert.That(data, Is.EquivalentTo((byte[])(await firstQuery)!)); - Assert.That(otherData, Is.EquivalentTo((byte[])(await secondQuery)!)); - } - [Test] [IssueLink("https://github.com/npgsql/npgsql/issues/4123")] public async Task Bug4123() @@ -1393,4 +1347,30 @@ public async Task Bug4123() Assert.DoesNotThrowAsync(stream.FlushAsync); Assert.DoesNotThrow(stream.Flush); } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6389")] + public async Task 
Composite_with_BigInteger([Values(CommandBehavior.Default, CommandBehavior.SequentialAccess)] CommandBehavior behavior) + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} as (value numeric)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await using var cmd = connection.CreateCommand(); + cmd.CommandText = $"SELECT ROW(1234567890::numeric)::{type} FROM generate_series(1, 8000)"; + await using var reader = await cmd.ExecuteReaderAsync(behavior); + while (await reader.ReadAsync()) + { + Assert.DoesNotThrowAsync(async () => await reader.GetFieldValueAsync(0)); + } + } + + class Composite_with_BigInteger_Composite + { + public BigInteger Value { get; set; } + } } diff --git a/test/Npgsql.Tests/CommandBuilderTests.cs b/test/Npgsql.Tests/CommandBuilderTests.cs index e917b7f6b3..f9643adfd5 100644 --- a/test/Npgsql.Tests/CommandBuilderTests.cs +++ b/test/Npgsql.Tests/CommandBuilderTests.cs @@ -341,7 +341,7 @@ PRIMARY KEY (Cod) Assert.That(row[0], Is.EqualTo("key1")); Assert.That(row[1], Is.EqualTo("description")); - Assert.That(row[2], Is.EqualTo(new DateTime(2018, 7, 3))); + Assert.That(row[2], Is.EqualTo(new DateOnly(2018, 7, 3))); Assert.That(row[3], Is.EqualTo(new DateTime(2018, 7, 3, 7, 2, 0))); Assert.That(row[4], Is.EqualTo(123)); Assert.That(row[5], Is.EqualTo(123.4)); @@ -364,7 +364,7 @@ public async Task Get_update_command_with_column_aliases() using var cbCommandBuilder = new NpgsqlCommandBuilder(daDataAdapter); daDataAdapter.UpdateCommand = cbCommandBuilder.GetUpdateCommand(); - Assert.True(daDataAdapter.UpdateCommand.CommandText.Contains("SET \"cod\" = @p1, \"descr\" = @p2, \"data\" = @p3 WHERE ((\"cod\" = @p4) AND ((@p5 = 1 AND \"descr\" IS 
NULL) OR (\"descr\" = @p6)) AND ((@p7 = 1 AND \"data\" IS NULL) OR (\"data\" = @p8)))")); + Assert.That(daDataAdapter.UpdateCommand.CommandText.Contains("SET \"cod\" = @p1, \"descr\" = @p2, \"data\" = @p3 WHERE ((\"cod\" = @p4) AND ((@p5 = 1 AND \"descr\" IS NULL) OR (\"descr\" = @p6)) AND ((@p7 = 1 AND \"data\" IS NULL) OR (\"data\" = @p8)))")); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2846")] @@ -387,4 +387,60 @@ public async Task Get_update_command_with_array_column_type() daDataAdapter.Update(dtTable); } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6240")] + public async Task Get_update_command_with_domain_column_type() + { + await using var adminConnection = await OpenConnectionAsync(); + var domainTypeName = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE DOMAIN {domainTypeName} AS smallint"); + + var tableName = await CreateTempTable(adminConnection, $"id serial PRIMARY KEY, domtest {domainTypeName}"); + + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + + using var adapter = new NpgsqlDataAdapter($"select * from {tableName}", conn); + + var builder = new NpgsqlCommandBuilder(adapter) + { + ConflictOption = ConflictOption.CompareAllSearchableValues, + SetAllValues = true + }; + + adapter.InsertCommand = builder.GetInsertCommand(); + adapter.UpdateCommand = builder.GetUpdateCommand(); + adapter.DeleteCommand = builder.GetDeleteCommand(); + + using var dataTable = new DataTable(); + + adapter.Fill(dataTable); + + const short sval = 5; + + var newRow = dataTable.NewRow(); + newRow[1] = sval; + dataTable.Rows.Add(newRow); + + adapter.Update(dataTable); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6240")] + public async Task Fill_datatable_with_array_column_type() + { + await using var connection = await OpenConnectionAsync(); + + var tableName = await CreateTempTable(connection, "id serial 
PRIMARY KEY, textarr text[] COLLATE pg_catalog.\"default\""); + + using var adapter = new NpgsqlDataAdapter($"select * from {tableName}", connection); + + using var dataTable = new DataTable(); + + adapter.FillSchema(dataTable, SchemaType.Source); + + adapter.MissingSchemaAction = MissingSchemaAction.Ignore; + + adapter.Fill(dataTable); + } } diff --git a/test/Npgsql.Tests/CommandParameterTests.cs b/test/Npgsql.Tests/CommandParameterTests.cs index 1e4355df4b..adc5d311a5 100644 --- a/test/Npgsql.Tests/CommandParameterTests.cs +++ b/test/Npgsql.Tests/CommandParameterTests.cs @@ -6,7 +6,7 @@ namespace Npgsql.Tests; -public class CommandParameterTests : MultiplexingTestBase +public class CommandParameterTests : TestBase { [Test] [TestCase(CommandBehavior.Default)] @@ -22,17 +22,14 @@ public async Task Input_and_output_parameters(CommandBehavior behavior) cmd.Parameters.Add(c); using (await cmd.ExecuteReaderAsync(behavior)) { - Assert.AreEqual(5, b.Value); - Assert.AreEqual(3, c.Value); + Assert.That(b.Value, Is.EqualTo(5)); + Assert.That(c.Value, Is.EqualTo(3)); } } [Test] public async Task Send_NpgsqlDbType_Unknown([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); using var cmd = new NpgsqlCommand("SELECT @p::TIMESTAMP", conn); cmd.CommandText = "SELECT @p::TIMESTAMP"; @@ -128,20 +125,20 @@ public void Parameters_get_name() command.Parameters.Add(new NpgsqlParameter("Parameter4", DbType.DateTime)); var idbPrmtr = command.Parameters["Parameter1"]; - Assert.IsNotNull(idbPrmtr); + Assert.That(idbPrmtr, Is.Not.Null); command.Parameters[0].Value = 1; // Get by indexers. 
- Assert.AreEqual(":Parameter1", command.Parameters["Parameter1"].ParameterName); - Assert.AreEqual(":Parameter2", command.Parameters["Parameter2"].ParameterName); - Assert.AreEqual(":Parameter3", command.Parameters["Parameter3"].ParameterName); - Assert.AreEqual("Parameter4", command.Parameters["Parameter4"].ParameterName); //Should this work? + Assert.That(command.Parameters["Parameter1"].ParameterName, Is.EqualTo(":Parameter1")); + Assert.That(command.Parameters["Parameter2"].ParameterName, Is.EqualTo(":Parameter2")); + Assert.That(command.Parameters["Parameter3"].ParameterName, Is.EqualTo(":Parameter3")); + Assert.That(command.Parameters["Parameter4"].ParameterName, Is.EqualTo("Parameter4")); //Should this work? - Assert.AreEqual(":Parameter1", command.Parameters[0].ParameterName); - Assert.AreEqual(":Parameter2", command.Parameters[1].ParameterName); - Assert.AreEqual(":Parameter3", command.Parameters[2].ParameterName); - Assert.AreEqual("Parameter4", command.Parameters[3].ParameterName); + Assert.That(command.Parameters[0].ParameterName, Is.EqualTo(":Parameter1")); + Assert.That(command.Parameters[1].ParameterName, Is.EqualTo(":Parameter2")); + Assert.That(command.Parameters[2].ParameterName, Is.EqualTo(":Parameter3")); + Assert.That(command.Parameters[3].ParameterName, Is.EqualTo("Parameter4")); } [Test] @@ -164,7 +161,7 @@ public async Task Generic_parameter() cmd.Parameters.Add(new NpgsqlParameter("p1", 8)); cmd.Parameters.Add(new NpgsqlParameter("p2", 8) { NpgsqlDbType = NpgsqlDbType.Integer }); cmd.Parameters.Add(new NpgsqlParameter("p3", "hello")); - cmd.Parameters.Add(new NpgsqlParameter("p4", new[] { 'f', 'o', 'o' })); + cmd.Parameters.Add(new NpgsqlParameter("p4", ['f', 'o', 'o'])); using var reader = await cmd.ExecuteReaderAsync(); reader.Read(); Assert.That(reader.GetInt32(0), Is.EqualTo(8)); @@ -189,16 +186,14 @@ public async Task Parameter_must_be_set(bool genericParam) Assert.That(async () => await cmd.ExecuteReaderAsync(), Throws.Exception 
.TypeOf() - .With.Message.EqualTo("Parameter 'p1' must have either its NpgsqlDbType or its DataTypeName or its Value set.")); + .With.Message.EqualTo("Parameter 'p1' must have either its DbType, NpgsqlDbType, DataTypeName or its Value set.")); } [Test] public async Task Object_generic_param_does_runtime_lookup() { - await AssertTypeWrite(1, "1", "integer", NpgsqlDbType.Integer, DbType.Int32, DbType.Int32, isDefault: false, - isNpgsqlDbTypeInferredFromClrType: true, skipArrayCheck: true); - await AssertTypeWrite(new[] {1, 1}, "{1,1}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array, isDefault: false, - isNpgsqlDbTypeInferredFromClrType: true, skipArrayCheck: true); + await AssertTypeWrite(1, "1", "integer", dbType: DbType.Int32, skipArrayCheck: true); + await AssertTypeWrite(new[] {1, 1}, "{1,1}", "integer[]", skipArrayCheck: true); } [Test] @@ -209,8 +204,4 @@ public async Task Object_generic_parameter_works() cmd.Parameters.Add(new NpgsqlParameter { NpgsqlDbType = NpgsqlDbType.Integer, Value = 8 }); Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(8)); } - - public CommandParameterTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) - { - } } diff --git a/test/Npgsql.Tests/CommandTests.cs b/test/Npgsql.Tests/CommandTests.cs index 9c42da9c0c..7c6888c9a7 100644 --- a/test/Npgsql.Tests/CommandTests.cs +++ b/test/Npgsql.Tests/CommandTests.cs @@ -5,7 +5,6 @@ using NUnit.Framework; using System; using System.Buffers.Binary; -using System.Collections.Generic; using System.Data; using System.Linq; using System.Text; @@ -16,7 +15,7 @@ namespace Npgsql.Tests; -public class CommandTests : MultiplexingTestBase +public class CommandTests : TestBase { static uint Int4Oid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Int4).Value; static uint TextOid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Text).Value; @@ -42,7 +41,7 @@ public async Task Multiple_statements(bool[] queries) { await using var cmd = 
conn.CreateCommand(); cmd.CommandText = sql; - if (prepare && !IsMultiplexing) + if (prepare) await cmd.PrepareAsync(); await using var reader = await cmd.ExecuteReaderAsync(); var numResultSets = queries.Count(q => q); @@ -58,9 +57,6 @@ public async Task Multiple_statements(bool[] queries) [Test] public async Task Multiple_statements_with_parameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); await using var cmd = conn.CreateCommand(); cmd.CommandText = "SELECT @p1; SELECT @p2"; @@ -84,9 +80,6 @@ public async Task Multiple_statements_with_parameters([Values(PrepareOrNot.NotPr [Test] public async Task SingleRow_legacy_batching([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); if (prepare == PrepareOrNot.Prepared) @@ -170,9 +163,6 @@ public async Task Named_parameters_are_not_supported_when_EnableSqlParsing_is_di [IssueLink("https://github.com/npgsql/npgsql/issues/327")] public async Task Timeout() { - if (IsMultiplexing) - return; // Multiplexing, Timeout - await using var dataSource = CreateDataSource(csb => csb.CommandTimeout = 1); await using var conn = await dataSource.OpenConnectionAsync(); await using var cmd = CreateSleepCommand(conn, 10); @@ -187,9 +177,6 @@ public async Task Timeout() [IssueLink("https://github.com/npgsql/npgsql/issues/607")] public async Task Timeout_async_soft() { - if (IsMultiplexing) - return; // Multiplexing, Timeout - await using var dataSource = CreateDataSource(csb => csb.CommandTimeout = 1); await using var conn = await dataSource.OpenConnectionAsync(); await using var cmd = CreateSleepCommand(conn, 10); @@ -204,9 +191,6 @@ public async Task Timeout_async_soft() 
[IssueLink("https://github.com/npgsql/npgsql/issues/607")] public async Task Timeout_async_hard() { - if (IsMultiplexing) - return; // Multiplexing, Timeout - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { CommandTimeout = 1 }; await using var postmasterMock = PgPostmasterMock.Start(builder.ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); @@ -243,7 +227,7 @@ public async Task Timeout_from_connection_string() public async Task Timeout_switch_connection() { var csb = new NpgsqlConnectionStringBuilder(ConnectionString); - if (csb.CommandTimeout >= 100 && csb.CommandTimeout < 105) + if (csb.CommandTimeout is >= 100 and < 105) IgnoreExceptOnBuildServer("Bad default command timeout"); await using var dataSource1 = CreateDataSource(ConnectionString + ";CommandTimeout=100"); @@ -267,9 +251,6 @@ public async Task Timeout_switch_connection() [Test] public async Task Prepare_timeout_hard([Values] SyncOrAsync async) { - if (IsMultiplexing) - return; // Multiplexing, Timeout - var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { CommandTimeout = 1 }; await using var postmasterMock = PgPostmasterMock.Start(builder.ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); @@ -300,12 +281,8 @@ public async Task Prepare_timeout_hard([Values] SyncOrAsync async) #region Cancel [Test, Description("Basic cancellation scenario")] - [Ignore("Flaky, see https://github.com/npgsql/npgsql/issues/5070")] public async Task Cancel() { - if (IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); await using var cmd = CreateSleepCommand(conn, 5); @@ -323,9 +300,6 @@ public async Task Cancel() [Test] public async Task Cancel_async_immediately() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var conn = await OpenConnectionAsync(); await using var cmd = conn.CreateCommand(); cmd.CommandText = "SELECT 1"; @@ -340,12 
+314,8 @@ public async Task Cancel_async_immediately() } [Test, Description("Cancels an async query with the cancellation token, with successful PG cancellation")] - [Explicit("Flaky due to #5033")] public async Task Cancel_async_soft() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var conn = await OpenConnectionAsync(); await using var cmd = CreateSleepCommand(conn); using var cancellationSource = new CancellationTokenSource(); @@ -361,12 +331,48 @@ public async Task Cancel_async_soft() Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } + [Test, Description("Cancels an async query with the cancellation token and prepended query, with successful PG cancellation")] + [IssueLink("https://github.com/npgsql/npgsql/issues/5191")] + public async Task Cancel_async_soft_with_prepended_query() + { + await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + var server = await postmasterMock.WaitForServerConnection(); + + var processId = conn.ProcessID; + + await using var tx = await conn.BeginTransactionAsync(); + await using var cmd = CreateSleepCommand(conn); + using var cancellationSource = new CancellationTokenSource(); + var t = cmd.ExecuteNonQueryAsync(cancellationSource.Token); + + await server.ExpectSimpleQuery("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"); + cancellationSource.Cancel(); + await server + .WriteCommandComplete() + .WriteReadyForQuery(TransactionStatus.InTransactionBlock) + .FlushAsync(); + + Assert.That((await postmasterMock.WaitForCancellationRequest()).ProcessId, + Is.EqualTo(processId)); + + await server + .WriteErrorResponse(PostgresErrorCodes.QueryCanceled) + .WriteReadyForQuery() + .FlushAsync(); + + var exception = Assert.ThrowsAsync(async () => await t)!; + Assert.That(exception.InnerException, + 
Is.TypeOf().With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.QueryCanceled)); + Assert.That(exception.CancellationToken, Is.EqualTo(cancellationSource.Token)); + + Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); + } + [Test, Description("Cancels an async query with the cancellation token, with unsuccessful PG cancellation (socket break)")] public async Task Cancel_async_hard() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -392,9 +398,6 @@ public async Task Cancel_async_hard() [Ignore("https://github.com/npgsql/npgsql/issues/4668")] public async Task Bug3466([Values(false, true)] bool isBroken) { - if (IsMultiplexing) - return; // Multiplexing, cancellation - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Pooling = false @@ -483,7 +486,7 @@ public async Task Cursor_statement() while (dr.Read()) i++; - Assert.AreEqual(3, i); + Assert.That(i, Is.EqualTo(3)); dr.Close(); i = 0; @@ -491,7 +494,7 @@ public async Task Cursor_statement() var dr2 = command.ExecuteReader(); while (dr2.Read()) i++; - Assert.AreEqual(1, i); + Assert.That(i, Is.EqualTo(1)); dr2.Close(); command.CommandText = "close te;"; @@ -507,7 +510,7 @@ public async Task Cursor_move_RecordsAffected() command.ExecuteNonQuery(); command.CommandText = "MOVE FORWARD ALL IN curs"; var count = command.ExecuteNonQuery(); - Assert.AreEqual(3, count); + Assert.That(count, Is.EqualTo(3)); } #endregion @@ -550,9 +553,6 @@ public async Task CloseConnection_with_exception() [Test] public async Task SingleRow([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); await using var 
cmd = new NpgsqlCommand("SELECT 1, 2 UNION SELECT 3, 4", conn); if (prepare == PrepareOrNot.Prepared) @@ -645,17 +645,9 @@ public async Task Non_standards_conforming_strings() await using var dataSource = CreateDataSource(); await using var conn = await dataSource.OpenConnectionAsync(); - if (IsMultiplexing) - { - Assert.That(async () => await conn.ExecuteNonQueryAsync("set standard_conforming_strings=off"), - Throws.Exception.TypeOf()); - } - else - { - await conn.ExecuteNonQueryAsync("set standard_conforming_strings=off"); - Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); - await conn.ExecuteNonQueryAsync("set standard_conforming_strings=on"); - } + await conn.ExecuteNonQueryAsync("set standard_conforming_strings=off"); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + await conn.ExecuteNonQueryAsync("set standard_conforming_strings=on"); } [Test] @@ -668,7 +660,7 @@ public async Task Parameter_and_operator_unclear() command.Parameters.AddWithValue(":arr", new int[] {5, 4, 3, 2, 1}); await using var rdr = await command.ExecuteReaderAsync(); rdr.Read(); - Assert.AreEqual(rdr.GetInt32(0), 4); + Assert.That(rdr.GetInt32(0), Is.EqualTo(4)); } [Test] @@ -702,14 +694,14 @@ public async Task Cached_command_clears_parameters_placeholder_type() public async Task Statement_mapped_output_parameters(CommandBehavior behavior) { await using var conn = await OpenConnectionAsync(); - var command = new NpgsqlCommand("select 3, 4 as param1, 5 as param2, 6;", conn); + var command = new NpgsqlCommand("select 3 as unknown, 4 as param1, 5 as param2, 6;", conn); - var p = new NpgsqlParameter("param2", NpgsqlDbType.Integer); + var p = new NpgsqlParameter("param1", NpgsqlDbType.Integer); p.Direction = ParameterDirection.Output; p.Value = -1; command.Parameters.Add(p); - p = new NpgsqlParameter("param1", NpgsqlDbType.Integer); + p = new NpgsqlParameter("param2", NpgsqlDbType.Integer); p.Direction = ParameterDirection.Output; p.Value = -1; 
command.Parameters.Add(p); @@ -721,15 +713,55 @@ public async Task Statement_mapped_output_parameters(CommandBehavior behavior) await using var reader = await command.ExecuteReaderAsync(behavior); - Assert.AreEqual(4, command.Parameters["param1"].Value); - Assert.AreEqual(5, command.Parameters["param2"].Value); + Assert.That(command.Parameters["p"].Value, Is.EqualTo(3)); + Assert.That(command.Parameters["param1"].Value, Is.EqualTo(4)); + Assert.That(command.Parameters["param2"].Value, Is.EqualTo(5)); + + reader.Read(); + + Assert.That(reader.GetInt32(0), Is.EqualTo(3)); + Assert.That(reader.GetInt32(1), Is.EqualTo(4)); + Assert.That(reader.GetInt32(2), Is.EqualTo(5)); + Assert.That(reader.GetInt32(3), Is.EqualTo(6)); + } + + + [Test] + [TestCase(CommandBehavior.Default)] + [TestCase(CommandBehavior.SequentialAccess)] + public async Task Statement_mapped_generic_output_parameters(CommandBehavior behavior) + { + await using var conn = await OpenConnectionAsync(); + var command = new NpgsqlCommand("select '' as unknown, 4 as param1, 5 as param2, 6;", conn); + + var p = new NpgsqlParameter("param1", NpgsqlDbType.Integer); + p.Direction = ParameterDirection.Output; + p.Value = -1; + command.Parameters.Add(p); + + p = new NpgsqlParameter("param2", NpgsqlDbType.Integer); + p.Direction = ParameterDirection.Output; + p.Value = -1; + command.Parameters.Add(p); + + // char[] is one alternative mapping for text. 
+ var textP = new NpgsqlParameter("p", NpgsqlDbType.Text); + textP.Direction = ParameterDirection.Output; + textP.Value = "text".ToCharArray(); + command.Parameters.Add(textP); + + await using var reader = await command.ExecuteReaderAsync(behavior); + + Assert.That(command.Parameters["p"].Value, Is.EquivalentTo(Array.Empty())); + Assert.That(command.Parameters["param1"].Value, Is.EqualTo(4)); + Assert.That(command.Parameters["param2"].Value, Is.EqualTo(5)); reader.Read(); - Assert.AreEqual(3, reader.GetInt32(0)); - Assert.AreEqual(4, reader.GetInt32(1)); - Assert.AreEqual(5, reader.GetInt32(2)); - Assert.AreEqual(6, reader.GetInt32(3)); + Assert.That(reader.GetFieldValue(0), Is.EquivalentTo(Array.Empty())); + Assert.That(reader.GetInt32(1), Is.EqualTo(4)); + Assert.That(reader.GetInt32(2), Is.EqualTo(5)); + Assert.That(reader.GetInt32(3), Is.EqualTo(6)); } [Test] @@ -760,30 +792,27 @@ public async Task Bug1006158_output_parameters() _ = await command.ExecuteScalarAsync(); - Assert.AreEqual(3, command.Parameters[0].Value); - Assert.AreEqual(true, command.Parameters[1].Value); + Assert.That(command.Parameters[0].Value, Is.EqualTo(3)); + Assert.That(command.Parameters[1].Value, Is.EqualTo(true)); } [Test] public async Task Bug1010788_UpdateRowSource() { - if (IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "id SERIAL PRIMARY KEY, name TEXT"); var command = new NpgsqlCommand($"SELECT * FROM {table}", conn); - Assert.AreEqual(UpdateRowSource.Both, command.UpdatedRowSource); + Assert.That(command.UpdatedRowSource, Is.EqualTo(UpdateRowSource.Both)); var cmdBuilder = new NpgsqlCommandBuilder(); var da = new NpgsqlDataAdapter(command); cmdBuilder.DataAdapter = da; - Assert.IsNotNull(da.SelectCommand); - Assert.IsNotNull(cmdBuilder.DataAdapter); + Assert.That(da.SelectCommand, Is.Not.Null); + Assert.That(cmdBuilder.DataAdapter, Is.Not.Null); var updateCommand = cmdBuilder.GetUpdateCommand(); - 
Assert.AreEqual(UpdateRowSource.None, updateCommand.UpdatedRowSource); + Assert.That(updateCommand.UpdatedRowSource, Is.EqualTo(UpdateRowSource.None)); } [Test] @@ -812,9 +841,6 @@ public async Task Invalid_UTF8() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/395")] public async Task Use_across_connection_change([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn1 = await OpenConnectionAsync(); using var conn2 = await OpenConnectionAsync(); using var cmd = new NpgsqlCommand("SELECT 1", conn1); @@ -831,9 +857,6 @@ public async Task Use_across_connection_change([Values(PrepareOrNot.Prepared, Pr [Test] public async Task Use_after_reload_types_invalidates_cached_infos() { - if (IsMultiplexing) - return; - using var conn1 = await OpenConnectionAsync(); using var cmd = new NpgsqlCommand("SELECT 1", conn1); cmd.Prepare(); @@ -852,6 +875,180 @@ public async Task Use_after_reload_types_invalidates_cached_infos() } } + [Test] + public async Task Parameter_overflow_message_length_throws() + { + // Create a separate data source to avoid breaking unrelated queries + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @a, @b, @c, @d, @e, @f, @g, @h", conn); + + var largeParam = new string('A', 1 << 29); + cmd.Parameters.AddWithValue("a", largeParam); + cmd.Parameters.AddWithValue("b", largeParam); + cmd.Parameters.AddWithValue("c", largeParam); + cmd.Parameters.AddWithValue("d", largeParam); + cmd.Parameters.AddWithValue("e", largeParam); + cmd.Parameters.AddWithValue("f", largeParam); + cmd.Parameters.AddWithValue("g", largeParam); + cmd.Parameters.AddWithValue("h", largeParam); + + Assert.ThrowsAsync(() => cmd.ExecuteReaderAsync()); + } + + [Test] + public async Task Composite_overflow_message_length_throws() + { + await using var 
adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync( + $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var largeString = new string('A', 1 << 29); + + await using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT @a"; + cmd.Parameters.AddWithValue("a", new BigComposite + { + A = largeString, + B = largeString, + C = largeString, + D = largeString, + E = largeString, + F = largeString, + G = largeString, + H = largeString + }); + + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } + + record BigComposite + { + public string A { get; set; } = null!; + public string B { get; set; } = null!; + public string C { get; set; } = null!; + public string D { get; set; } = null!; + public string E { get; set; } = null!; + public string F { get; set; } = null!; + public string G { get; set; } = null!; + public string H { get; set; } = null!; + } + + [Test] + public async Task Array_overflow_message_length_throws() + { + // Create a separate data source to avoid breaking unrelated queries + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + + var largeString = new string('A', 1 << 29); + + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT @a"; + var array = new[] + { + largeString, + largeString, + largeString, + largeString, + largeString, + largeString, + largeString, + largeString + }; + cmd.Parameters.AddWithValue("a", array); + + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } + + [Test] + public async Task Range_overflow_message_length_throws() + { + await using var 
adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + var rangeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync( + $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text);CREATE TYPE {rangeType} AS RANGE(subtype={type})"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + dataSourceBuilder.EnableUnmappedTypes(); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var largeString = new string('A', (1 << 28) + 2000000); + + await using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT @a"; + var composite = new BigComposite + { + A = largeString, + B = largeString, + C = largeString, + D = largeString + }; + var range = new NpgsqlRange(composite, composite); + cmd.Parameters.Add(new NpgsqlParameter + { + Value = range, + ParameterName = "a", + DataTypeName = rangeType + }); + + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } + + [Test] + public async Task Multirange_overflow_message_length_throws() + { + await using var adminConnection = await OpenConnectionAsync(); + MinimumPgVersion(adminConnection, "14.0", "Multirange types were introduced in PostgreSQL 14"); + var type = await GetTempTypeName(adminConnection); + var rangeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync( + $"CREATE TYPE {type} AS (a text, b text, c text, d text, e text, f text, g text, h text);CREATE TYPE {rangeType} AS RANGE(subtype={type})"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + dataSourceBuilder.EnableUnmappedTypes(); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + var largeString = new string('A', (1 << 28) + 2000000); + 
+ await using var cmd = connection.CreateCommand(); + cmd.CommandText = "SELECT @a"; + var composite = new BigComposite + { + A = largeString + }; + var range = new NpgsqlRange(composite, composite); + var multirange = new[] + { + range, + range, + range, + range + }; + cmd.Parameters.Add(new NpgsqlParameter + { + Value = multirange, + ParameterName = "a", + DataTypeName = rangeType + "_multirange" + }); + + Assert.ThrowsAsync(async () => await cmd.ExecuteNonQueryAsync()); + } + [Test, Description("CreateCommand before connection open")] [IssueLink("https://github.com/npgsql/npgsql/issues/565")] public async Task Create_command_before_connection_open() @@ -880,9 +1077,6 @@ public void Connection_not_open_throws() [Test] public async Task ExecuteNonQuery_Throws_PostgresException([Values] bool async) { - if (!async && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); var table1 = await CreateTempTable(conn, "id integer PRIMARY key, t varchar(40)"); @@ -899,9 +1093,6 @@ public async Task ExecuteNonQuery_Throws_PostgresException([Values] bool async) [Test] public async Task ExecuteScalar_Throws_PostgresException([Values] bool async) { - if (!async && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); var table1 = await CreateTempTable(conn, "id integer PRIMARY key, t varchar(40)"); @@ -918,9 +1109,6 @@ public async Task ExecuteScalar_Throws_PostgresException([Values] bool async) [Test] public async Task ExecuteReader_Throws_PostgresException([Values] bool async) { - if (!async && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); var table1 = await CreateTempTable(conn, "id integer PRIMARY key, t varchar(40)"); @@ -933,10 +1121,10 @@ public async Task ExecuteReader_Throws_PostgresException([Values] bool async) ? await cmd.ExecuteReaderAsync() : cmd.ExecuteReader(); - Assert.IsTrue(async ? await reader.ReadAsync() : reader.Read()); + Assert.That(async ? 
await reader.ReadAsync() : reader.Read()); var value = reader.GetInt32(0); Assert.That(value, Is.EqualTo(1)); - Assert.IsFalse(async ? await reader.ReadAsync() : reader.Read()); + Assert.That(async ? await reader.ReadAsync() : reader.Read(), Is.False); var ex = async ? Assert.ThrowsAsync(async () => await reader.NextResultAsync()) : Assert.Throws(() => reader.NextResult()); @@ -944,11 +1132,15 @@ public async Task ExecuteReader_Throws_PostgresException([Values] bool async) } [Test] - public void Command_is_recycled() + public void Command_is_recycled([Values] bool allResultTypesAreUnknown) { using var conn = OpenConnection(); var cmd1 = conn.CreateCommand(); cmd1.CommandText = "SELECT @p1"; + if (allResultTypesAreUnknown) + cmd1.AllResultTypesAreUnknown = true; + else + cmd1.UnknownResultTypeList = [true]; var tx = conn.BeginTransaction(); cmd1.Transaction = tx; cmd1.Parameters.AddWithValue("p1", 8); @@ -961,6 +1153,8 @@ public void Command_is_recycled() Assert.That(cmd2.CommandType, Is.EqualTo(CommandType.Text)); Assert.That(cmd2.Transaction, Is.Null); Assert.That(cmd2.Parameters, Is.Empty); + Assert.That(cmd2.AllResultTypesAreUnknown, Is.False); + Assert.That(cmd2.UnknownResultTypeList, Is.Null); // TODO: Leaving this for now, since it'll be replaced by the new batching API // Assert.That(cmd2.Statements, Is.Empty); } @@ -982,9 +1176,6 @@ public void Command_recycled_resets_CommandType() [IssueLink("https://github.com/npgsql/npgsql/issues/2795")] public async Task Many_parameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "some_column INT"); using var cmd = new NpgsqlCommand { Connection = conn }; @@ -1012,9 +1203,6 @@ public async Task Many_parameters([Values(PrepareOrNot.NotPrepared, PrepareOrNot [IssueLink("https://github.com/npgsql/npgsql/issues/2703")] public async 
Task Too_many_parameters_throws([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); using var cmd = new NpgsqlCommand { Connection = conn }; var sb = new StringBuilder("SOME RANDOM SQL "); @@ -1027,6 +1215,7 @@ public async Task Too_many_parameters_throws([Values(PrepareOrNot.NotPrepared, P sb.Append('@'); sb.Append(paramName); } + cmd.CommandText = sb.ToString(); if (prepare == PrepareOrNot.Prepared) @@ -1098,9 +1287,6 @@ public async Task Batched_big_statements_do_not_deadlock() [Test] public void Batched_small_then_big_statements_do_not_deadlock_in_sync_io() { - if (IsMultiplexing) - return; // Multiplexing, sync I/O - // This makes sure we switch to async writing for batches, starting from the 2nd statement at the latest. // Otherwise, a small first first statement followed by a huge big one could cause us to deadlock, as we're stuck // synchronously sending the 2nd statement while PG is stuck sending the results of the 1st. 
@@ -1139,9 +1325,6 @@ public async Task Same_command_different_param_instances() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3509"), Ignore("Flaky")] public async Task Bug3509() { - if (IsMultiplexing) - return; - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { KeepAlive = 1, @@ -1188,9 +1371,6 @@ public async Task Cached_command_double_dispose() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4330")] public async Task Prepare_with_positional_placeholders_after_named() { - if (IsMultiplexing) - return; // Explicit preparation - await using var conn = await OpenConnectionAsync(); await using var command = new NpgsqlCommand("SELECT @p", conn); @@ -1208,9 +1388,6 @@ public async Task Prepare_with_positional_placeholders_after_named() [Description("Most of 08* errors are coming whenever there was an error while connecting to a remote server from a cluster, so the connection to the cluster is still OK")] public async Task Postgres_connection_errors_not_break_connection() { - if (IsMultiplexing) - return; - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -1234,9 +1411,6 @@ await server [Description("Concurrent write and read failure can lead to deadlocks while cleaning up the connector.")] public async Task Concurrent_read_write_failure_deadlock() { - if (IsMultiplexing) - return; - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -1258,9 +1432,6 @@ public async Task Concurrent_read_write_failure_deadlock() [Explicit("Flaky due to #5033")] public async Task Not_cancel_prepended_query([Values] bool failPrependedQuery) { - if (IsMultiplexing) - return; - await using var postmasterMock = 
PgPostmasterMock.Start(ConnectionString); var csb = new NpgsqlConnectionStringBuilder(postmasterMock.ConnectionString) { @@ -1283,8 +1454,8 @@ public async Task Not_cancel_prepended_query([Values] bool failPrependedQuery) var cancellationRequestTask = postmasterMock.WaitForCancellationRequest().AsTask(); // Give 1 second to make sure we didn't send cancellation request await Task.Delay(1000); - Assert.IsFalse(cancelTask.IsCompleted); - Assert.IsFalse(cancellationRequestTask.IsCompleted); + Assert.That(cancelTask.IsCompleted, Is.False); + Assert.That(cancellationRequestTask.IsCompleted, Is.False); if (failPrependedQuery) { @@ -1331,9 +1502,6 @@ await server [Test] public async Task Cancel_while_reading_from_long_running_query() { - if (IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); await using var cmd = conn.CreateCommand(); @@ -1368,9 +1536,6 @@ SELECT generate_series(1, 1000000) AS "i" [Description("Make sure we do not lose unread messages after resetting oversize buffer")] public async Task Oversize_buffer_lost_messages() { - if (IsMultiplexing) - return; - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { NoResetOnClose = true @@ -1389,7 +1554,7 @@ await server // Just to make sure we have enough space await server.FlushAsync(); await server - .WriteDataRow(Encoding.ASCII.GetBytes("abc")) + .WriteDataRow("abc"u8.ToArray()) .WriteCommandComplete() .WriteReadyForQuery() .WriteParameterStatus("SomeKey", "SomeValue") @@ -1402,7 +1567,7 @@ await server await connection.CloseAsync(); await connection.OpenAsync(); - Assert.AreSame(connector, connection.Connector); + Assert.That(connection.Connector, Is.SameAs(connector)); // We'll get new value after the next query reads ParameterStatus from the buffer Assert.That(connection.PostgresParameters, Does.Not.ContainKey("SomeKey").WithValue("SomeValue")); @@ -1410,7 +1575,7 @@ await server .WriteParseComplete() .WriteBindComplete() .WriteRowDescription(new 
FieldDescription(TextOid)) - .WriteDataRow(Encoding.ASCII.GetBytes("abc")) + .WriteDataRow("abc"u8.ToArray()) .WriteCommandComplete() .WriteReadyForQuery() .FlushAsync(); @@ -1420,180 +1585,32 @@ await server Assert.That(connection.PostgresParameters, Contains.Key("SomeKey").WithValue("SomeValue")); } - #region Logging - [Test] - public async Task Log_ExecuteScalar_single_statement_without_parameters() + public async Task Completed_transaction_throws([Values] bool commit) { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT 1", conn); - - using (listLoggerProvider.Record()) - { - await cmd.ExecuteScalarAsync(); - } - - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains("SELECT 1")); - AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT 1"); - AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); - - if (!IsMultiplexing) - AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); - } - - [Test] - public async Task Log_ExecuteScalar_single_statement_with_positional_parameters() - { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1, $2", conn); - cmd.Parameters.Add(new() { Value = 8 }); - cmd.Parameters.Add(new() { NpgsqlDbType = NpgsqlDbType.Integer, Value = DBNull.Value }); - - using (listLoggerProvider.Record()) - { - await cmd.ExecuteScalarAsync(); - } - - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - Assert.That(executingCommandEvent.Message, Does.Contain("Command execution 
completed") - .And.Contains("SELECT $1, $2") - .And.Contains("Parameters: [8, NULL]")); - AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); - AssertLoggingStateContains(executingCommandEvent, "Parameters", new object[] { 8, "NULL" }); - - if (!IsMultiplexing) - AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); - } - - [Test] - public async Task Log_ExecuteScalar_single_statement_with_named_parameters() - { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); - cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); - cmd.Parameters.Add(new() { ParameterName = "p2", NpgsqlDbType = NpgsqlDbType.Integer, Value = DBNull.Value }); - - using (listLoggerProvider.Record()) - { - await cmd.ExecuteScalarAsync(); - } - - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed") - .And.Contains("SELECT $1, $2") - .And.Contains("Parameters: [8, NULL]")); - AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); - AssertLoggingStateContains(executingCommandEvent, "Parameters", new object[] { 8, "NULL" }); - - if (!IsMultiplexing) - AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); - } - - [Test] - public async Task Log_ExecuteScalar_single_statement_with_parameter_logging_off() - { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, sensitiveDataLoggingEnabled: false); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT $1, $2", conn); - cmd.Parameters.Add(new() { Value = 8 }); - cmd.Parameters.Add(new() { Value = 9 }); - - using 
(listLoggerProvider.Record()) - { - await cmd.ExecuteScalarAsync(); - } - - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains($"SELECT $1, $2")); - AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); - AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); - } - - [Test] - public async Task Log_ExecuteScalar_multiple_statement_without_parameters() - { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); - - using (listLoggerProvider.Record()) - { - await cmd.ExecuteScalarAsync(); - } - - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT 1, System.Object[]), (SELECT 2, System.Object[])]")); - var batchCommands = (IList<(string CommandText, object[] Parameters)>)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); - Assert.That(batchCommands.Count, Is.EqualTo(2)); - Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT 1")); - Assert.That(batchCommands[0].Parameters, Is.Empty); - Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT 2")); - Assert.That(batchCommands[1].Parameters, Is.Empty); - AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); - - if (!IsMultiplexing) - AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); - } - - [Test] - public async Task Log_ExecuteScalar_multiple_statement_with_parameters() - { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); - await using var conn = await 
dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT @p1; SELECT @p2", conn); - cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); - cmd.Parameters.Add(new() { ParameterName = "p2", Value = 9 }); - - using (listLoggerProvider.Record()) - { - await cmd.ExecuteScalarAsync(); - } + await using var conn = await OpenConnectionAsync(); + await using var tx = await conn.BeginTransactionAsync(); + await using var cmd = conn.CreateCommand(); - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT $1, System.Object[]), (SELECT $1, System.Object[])]")); - var batchCommands = (IList<(string CommandText, object[] Parameters)>)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); - Assert.That(batchCommands.Count, Is.EqualTo(2)); - Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT $1")); - Assert.That(batchCommands[0].Parameters[0], Is.EqualTo(8)); - Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT $1")); - Assert.That(batchCommands[1].Parameters[0], Is.EqualTo(9)); - AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + if (commit) + await tx.CommitAsync(); + else + await tx.RollbackAsync(); - if (!IsMultiplexing) - AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + Assert.Throws(() => cmd.Transaction = tx); } - [Test] - public async Task Log_ExecuteScalar_multiple_statement_with_parameter_logging_off() + [Test, Description("Writing to properties of a disposed command raises ObjectDisposedException.")] + public async Task Disposed_command_throws_on_assignment() { - await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, sensitiveDataLoggingEnabled: false); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = new 
NpgsqlCommand("SELECT @p1; SELECT @p2", conn); - cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); - cmd.Parameters.Add(new() { ParameterName = "p2", Value = 9 }); - - using (listLoggerProvider.Record()) - { - await cmd.ExecuteScalarAsync(); - } + await using var conn = await OpenConnectionAsync(); + var command = new NpgsqlCommand("SELECT 1"); + command.Dispose(); - var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); - Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[SELECT $1, SELECT $1]")); - var batchCommands = (IList)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); - Assert.That(batchCommands.Count, Is.EqualTo(2)); - Assert.That(batchCommands[0], Is.EqualTo("SELECT $1")); - Assert.That(batchCommands[1], Is.EqualTo("SELECT $1")); - AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + Assert.Throws(() => command.Connection = conn); + Assert.Throws(() => command.CommandText = "SELECT 2"); - if (!IsMultiplexing) - AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + Assert.That(command.Connection, Is.Null); + Assert.That(command.CommandText, Is.EqualTo("SELECT 1")); } - - #endregion Logging - - public CommandTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/ConnectionTests.cs b/test/Npgsql.Tests/ConnectionTests.cs index d0fb7f827b..d33441488c 100644 --- a/test/Npgsql.Tests/ConnectionTests.cs +++ b/test/Npgsql.Tests/ConnectionTests.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Net; using System.Net.Security; +using System.Net.Sockets; using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; using System.Text; @@ -13,6 +14,7 @@ using System.Threading.Tasks; using Npgsql.Internal; using Npgsql.PostgresTypes; +using Npgsql.Tests.Support; using Npgsql.Util; using NpgsqlTypes; using 
NUnit.Framework; @@ -20,7 +22,7 @@ namespace Npgsql.Tests; -public class ConnectionTests : MultiplexingTestBase +public class ConnectionTests : TestBase { [Test, Description("Makes sure the connection goes through the proper state lifecycle")] public async Task Basic_lifecycle() @@ -32,12 +34,10 @@ public async Task Basic_lifecycle() conn.StateChange += (s, e) => { - if (e.OriginalState == ConnectionState.Closed && - e.CurrentState == ConnectionState.Open) + if (e is { OriginalState: ConnectionState.Closed, CurrentState: ConnectionState.Open }) eventOpen = true; - if (e.OriginalState == ConnectionState.Open && - e.CurrentState == ConnectionState.Closed) + if (e is { OriginalState: ConnectionState.Open, CurrentState: ConnectionState.Closed }) eventClosed = true; }; @@ -72,9 +72,6 @@ public async Task Basic_lifecycle() [Test, Description("Makes sure the connection goes through the proper state lifecycle")] public async Task Broken_lifecycle([Values] bool openFromClose) { - if (IsMultiplexing) - return; - await using var dataSource = CreateDataSource(); await using var conn = dataSource.CreateConnection(); @@ -83,12 +80,10 @@ public async Task Broken_lifecycle([Values] bool openFromClose) conn.StateChange += (s, e) => { - if (e.OriginalState == ConnectionState.Closed && - e.CurrentState == ConnectionState.Open) + if (e is { OriginalState: ConnectionState.Closed, CurrentState: ConnectionState.Open }) eventOpen = true; - if (e.OriginalState == ConnectionState.Open && - e.CurrentState == ConnectionState.Closed) + if (e is { OriginalState: ConnectionState.Open, CurrentState: ConnectionState.Closed }) eventClosed = true; }; @@ -115,7 +110,7 @@ public async Task Broken_lifecycle([Values] bool openFromClose) Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); Assert.That(eventClosed, Is.True); Assert.That(conn.Connector is null); - Assert.AreEqual(0, conn.NpgsqlDataSource.Statistics.Total); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(0)); if 
(openFromClose) { @@ -127,18 +122,14 @@ public async Task Broken_lifecycle([Values] bool openFromClose) } Assert.DoesNotThrowAsync(conn.OpenAsync); - Assert.AreEqual(1, await conn.ExecuteScalarAsync("SELECT 1")); - Assert.AreEqual(1, conn.NpgsqlDataSource.Statistics.Total); + Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); + Assert.That(conn.NpgsqlDataSource.Statistics.Total, Is.EqualTo(1)); Assert.DoesNotThrowAsync(conn.CloseAsync); } [Test] - [Platform(Exclude = "MacOsX", Reason = "Flaky on MacOS")] public async Task Break_while_open() { - if (IsMultiplexing) - return; - await using var dataSource = CreateDataSource(); await using var conn = await dataSource.OpenConnectionAsync(); @@ -192,7 +183,6 @@ public async Task Connection_refused_async(bool pooled) #endif [Test] - [Ignore("Fails in a non-determinstic manner and only on the build server... investigate...")] public void Invalid_Username() { var connString = new NpgsqlConnectionStringBuilder(ConnectionString) @@ -221,15 +211,12 @@ public void Bad_database() [Test, Description("Tests that mandatory connection string parameters are indeed mandatory")] public void Mandatory_connection_string_params() - => Assert.Throws(() => + => Assert.Throws(() => new NpgsqlConnection("User ID=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests")); [Test, Description("Reuses the same connection instance for a failed connection, then a successful one")] public async Task Fail_connect_then_succeed([Values] bool pooling) { - if (IsMultiplexing && !pooling) // Multiplexing doesn't work without pooling - return; - var dbName = GetUniqueIdentifier(nameof(Fail_connect_then_succeed)); await using var conn1 = await OpenConnectionAsync(); await conn1.ExecuteNonQueryAsync($"DROP DATABASE IF EXISTS \"{dbName}\""); @@ -317,6 +304,38 @@ public void Connect_timeout_cancel() Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); } + [Test] + public void Bad_hostname() + { + using var dataSource = 
CreateDataSource(csb => csb.Host = "hostname.that.does.not.exist"); + using var conn = dataSource.CreateConnection(); + + Assert.That( + () => conn.Open(), + Throws.Exception + .TypeOf() + .With + .Property(nameof(NpgsqlException.InnerException)) + .TypeOf() + ); + } + + [Test] + public void Bad_hostname_async() + { + using var dataSource = CreateDataSource(csb => csb.Host = "hostname.that.does.not.exist"); + using var conn = dataSource.CreateConnection(); + + Assert.That( + async () => await conn.OpenAsync(), + Throws.Exception + .TypeOf() + .With + .Property(nameof(NpgsqlException.InnerException)) + .TypeOf() + ); + } + #endregion #region Client Encoding @@ -395,6 +414,47 @@ public async Task Timezone_connection_param() #endregion Timezone + #region Application Name + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6133")] + [NonParallelizable] // Sets environment variable + public async Task Application_name_env_var() + { + const string testAppName = "MyTestApp"; + + // Note that the pool is unaware of the environment variable, so if a connection is + // returned from the pool it may contain the wrong application name + using var _ = SetEnvironmentVariable("PGAPPNAME", testAppName); + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(conn.PostgresParameters["application_name"], Is.EqualTo(testAppName)); + } + + [Test] + public async Task Application_name_connection_param() + { + const string testAppName = "MyTestApp2"; + + await using var dataSource = CreateDataSource(csb => csb.ApplicationName = testAppName); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(conn.PostgresParameters["application_name"], Is.EqualTo(testAppName)); + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Application_name_connection_param_overrides_env_var() + { + const string envAppName = "EnvApp"; + const string connAppName = 
"ConnApp"; + + using var _ = SetEnvironmentVariable("PGAPPNAME", envAppName); + await using var dataSource = CreateDataSource(csb => csb.ApplicationName = connAppName); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(conn.PostgresParameters["application_name"], Is.EqualTo(connAppName)); + } + + #endregion Application Name + #region ConnectionString - Host [TestCase("127.0.0.1", ExpectedResult = new [] { "127.0.0.1:5432" })] @@ -509,10 +569,6 @@ public void DataSource_property() conn.ConnectionString = csb.ConnectionString; Assert.That(conn.DataSource, Is.EqualTo($"tcp://{csb.Host}:{csb.Port}")); - // Multiplexing isn't supported with multiple hosts - if (IsMultiplexing) - return; - csb.Host = "127.0.0.1, 127.0.0.2"; conn.ConnectionString = csb.ConnectionString; Assert.That(conn.DataSource, Is.EqualTo(string.Empty)); @@ -656,6 +712,83 @@ public void Set_connection_string_to_empty() Assert.That(() => conn.Open(), Throws.Exception.TypeOf()); } + [Test] + [TestCase("test_schema_1", "public", true)] + [TestCase("test_schema_1", "test_schema_2", true)] + [TestCase("test_schema_2", "test_schema_3", true)] + [TestCase("test_schema_1", "public", false)] + [TestCase("test_schema_1", "test_schema_2", false)] + [TestCase("test_schema_2", "test_schema_3", false)] + [TestCase("'DROP TABLE X", "'COMMIT; ", false)] + [Parallelizable(ParallelScope.None)] + public async Task Set_Schemas_And_Load_Relevant_Types(string testSchema, string otherSchema, bool enabled) + { + await using var conn1 = await OpenConnectionAsync(); + try + { + await conn1.ExecuteNonQueryAsync("DROP TYPE IF EXISTS public.test_type_1"); + await conn1.ExecuteNonQueryAsync("DROP TYPE IF EXISTS public.test_type_2"); + await conn1.ExecuteNonQueryAsync("DROP TYPE IF EXISTS public.test_type_3"); + await conn1.ExecuteNonQueryAsync("CREATE TYPE public.test_type_3 AS (id int, name text)"); + + if (testSchema != "public") + { + await conn1.ExecuteNonQueryAsync($"DROP SCHEMA IF EXISTS 
\"{testSchema}\" CASCADE"); + await conn1.ExecuteNonQueryAsync($"CREATE SCHEMA \"{testSchema}\""); + } + + if (otherSchema != "public") + { + await conn1.ExecuteNonQueryAsync($"DROP SCHEMA IF EXISTS \"{otherSchema}\" CASCADE"); + await conn1.ExecuteNonQueryAsync($"CREATE SCHEMA \"{otherSchema}\""); + } + + await conn1.ExecuteNonQueryAsync($"DROP TYPE IF EXISTS \"{testSchema}\".test_type_1"); + await conn1.ExecuteNonQueryAsync($"CREATE TYPE \"{testSchema}\".test_type_1 AS (id int)"); + await conn1.ExecuteNonQueryAsync($"DROP TYPE IF EXISTS \"{otherSchema}\".test_type_2"); + await conn1.ExecuteNonQueryAsync($"CREATE TYPE \"{otherSchema}\".test_type_2 AS (id int, name text)"); + + using var dataSource = CreateDataSource(builder => + { + builder.ConfigureTypeLoading(builder => + { + if (enabled) + builder.SetTypeLoadingSchemas(testSchema, otherSchema); + }); + }); + using var conn = await dataSource.OpenConnectionAsync(); + var databaseInfo = dataSource.CurrentReloadableState.DatabaseInfo; + if (enabled) + { + Assert.That(databaseInfo.CompositeTypes.Any(x => x.Name == "test_type_1")); + if (testSchema == "public" || otherSchema == "public") + { + Assert.That(databaseInfo.CompositeTypes.Any(x => x.Name == "test_type_2")); + Assert.That(databaseInfo.CompositeTypes.Any(x => x.Name == "test_type_3")); + } + else + { + Assert.That(databaseInfo.CompositeTypes.Any(x => x.Name == "test_type_2")); + Assert.That(databaseInfo.CompositeTypes.Any(x => x.Name == "test_type_3"), Is.False); + } + } + else + { + Assert.That(databaseInfo.CompositeTypes.Any(x => x.Name == "test_type_1")); + Assert.That(databaseInfo.CompositeTypes.Any(x => x.Name == "test_type_2")); + Assert.That(databaseInfo.CompositeTypes.Any(x => x.Name == "test_type_3")); + } + } + finally + { + if (testSchema != "public") + await conn1.ExecuteNonQueryAsync($"DROP SCHEMA IF EXISTS \"{testSchema}\" CASCADE"); + if (otherSchema != "public") + await conn1.ExecuteNonQueryAsync($"DROP SCHEMA IF EXISTS \"{otherSchema}\" 
CASCADE"); + } + + } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/703")] public async Task No_database_defaults_to_username() { @@ -668,14 +801,10 @@ public async Task No_database_defaults_to_username() } [Test, Description("Breaks a connector while it's in the pool, with a keepalive and without")] - [Platform(Exclude = "MacOsX", Reason = "Fails only on mac, needs to be investigated")] [TestCase(false, TestName = nameof(Break_connector_in_pool) + "_without_keep_alive")] [TestCase(true, TestName = nameof(Break_connector_in_pool) + "_with_keep_alive")] public async Task Break_connector_in_pool(bool keepAlive) { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, hanging"); - var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.ConnectionStringBuilder.MaxPoolSize = 1; if (keepAlive) @@ -712,9 +841,6 @@ public async Task Break_connector_in_pool(bool keepAlive) [IssueLink("https://github.com/npgsql/npgsql/issues/4603")] public async Task Reload_types_keepalive_concurrent() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing doesn't support keepalive"); - await using var dataSource = CreateDataSource(csb => csb.KeepAlive = 1); await using var conn = await dataSource.OpenConnectionAsync(); @@ -764,9 +890,6 @@ public void ChangeDatabase_connection_on_closed_connection_throws() [Test, Description("Tests closing a connector while a reader is open")] public async Task Close_during_read([Values(PooledOrNot.Pooled, PooledOrNot.Unpooled)] PooledOrNot pooled) { - if (IsMultiplexing && pooled == PooledOrNot.Unpooled) - return; // Multiplexing requires pooling - await using var dataSource = CreateDataSource(csb => csb.Pooling = pooled == PooledOrNot.Pooled); await using var conn = await dataSource.OpenConnectionAsync(); await using (var cmd = new NpgsqlCommand("SELECT 1", conn)) @@ -829,7 +952,7 @@ public void Bug1011001() var cs1 = csb1.ToString(); var csb2 = new NpgsqlConnectionStringBuilder(cs1); var cs2 = csb2.ToString(); - Assert.IsTrue(cs1 == 
cs2); + Assert.That(cs1 == cs2); } [Test, IssueLink("https://github.com/npgsql/npgsql/pull/164")] @@ -837,7 +960,7 @@ public void Connection_State_is_Closed_when_disposed() { var c = new NpgsqlConnection(); c.Dispose(); - Assert.AreEqual(ConnectionState.Closed, c.State); + Assert.That(c.State, Is.EqualTo(ConnectionState.Closed)); } [Test] @@ -886,8 +1009,6 @@ await conn.ExecuteNonQueryAsync($@" [Test, Description("Makes sure that concurrent use of the connection throws an exception")] public async Task Concurrent_use_throws() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); using (var cmd = new NpgsqlCommand("SELECT 1", conn)) using (await cmd.ExecuteReaderAsync()) @@ -910,9 +1031,6 @@ public async Task Concurrent_use_throws() [IssueLink("https://github.com/npgsql/npgsql/issues/783")] public void PersistSecurityInfo_is_true([Values(true, false)] bool pooling) { - if (IsMultiplexing && !pooling) - return; - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { PersistSecurityInfo = true, @@ -929,9 +1047,6 @@ public void PersistSecurityInfo_is_true([Values(true, false)] bool pooling) [IssueLink("https://github.com/npgsql/npgsql/issues/783")] public void No_password_without_PersistSecurityInfo([Values(true, false)] bool pooling) { - if (IsMultiplexing && !pooling) - return; - var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { Pooling = pooling @@ -945,7 +1060,7 @@ public void No_password_without_PersistSecurityInfo([Values(true, false)] bool p } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2725")] - public void Clone_with_PersistSecurityInfo() + public async Task Clone_with_PersistSecurityInfo([Values] bool async) { var builder = new NpgsqlConnectionStringBuilder(ConnectionString) { @@ -958,20 +1073,24 @@ public void Clone_with_PersistSecurityInfo() // First un-persist, should work builder.PersistSecurityInfo = false; var connStringWithoutPersist = 
builder.ToString(); - using var clonedWithoutPersist = connWithPersist.CloneWith(connStringWithoutPersist); + using var clonedWithoutPersist = async + ? await connWithPersist.CloneWithAsync(connStringWithoutPersist) + : connWithPersist.CloneWith(connStringWithoutPersist); clonedWithoutPersist.Open(); Assert.That(clonedWithoutPersist.ConnectionString, Does.Not.Contain("Password=")); // Then attempt to re-persist, should not work - using var clonedConn = clonedWithoutPersist.CloneWith(connStringWithPersist); + using var clonedConn = async + ? await clonedWithoutPersist.CloneWithAsync(connStringWithPersist) + : clonedWithoutPersist.CloneWith(connStringWithPersist); clonedConn.Open(); Assert.That(clonedConn.ConnectionString, Does.Not.Contain("Password=")); } [Test] - public async Task CloneWith_and_data_source_with_password() + public async Task CloneWith_and_data_source_with_password([Values] bool async) { var dataSourceBuilder = new NpgsqlDataSourceBuilder(ConnectionString); // Set the password via the data source property later to make sure that's picked up by CloneWith @@ -984,33 +1103,41 @@ public async Task CloneWith_and_data_source_with_password() // Test that the up-to-date password gets copied to the clone, as if we opened the original connection instead of cloning it using var _ = CreateTempPool(new NpgsqlConnectionStringBuilder(ConnectionString) { Password = null }, out var tempConnectionString); - await using var clonedConnection = connection.CloneWith(tempConnectionString); + await using var clonedConnection = async + ? 
await connection.CloneWithAsync(tempConnectionString) + : connection.CloneWith(tempConnectionString); await clonedConnection.OpenAsync(); } [Test] - public async Task CloneWith_and_data_source_with_auth_callbacks() + public async Task CloneWith_and_data_source_with_auth_callbacks([Values] bool async) { var (userCertificateValidationCallbackCalled, clientCertificatesCallbackCalled) = (false, false); var dataSourceBuilder = CreateDataSourceBuilder(); - dataSourceBuilder.UseUserCertificateValidationCallback(UserCertificateValidationCallback); - dataSourceBuilder.UseClientCertificatesCallback(ClientCertificatesCallback); + dataSourceBuilder.UseSslClientAuthenticationOptionsCallback(options => + { + ClientCertificatesCallback(options.ClientCertificates); + options.RemoteCertificateValidationCallback = UserCertificateValidationCallback; + }); await using var dataSource = dataSourceBuilder.Build(); await using var connection = dataSource.CreateConnection(); using var _ = CreateTempPool(ConnectionString, out var tempConnectionString); - await using var clonedConnection = connection.CloneWith(tempConnectionString); + await using var clonedConnection = async + ? await connection.CloneWithAsync(tempConnectionString) + : connection.CloneWith(tempConnectionString); - clonedConnection.UserCertificateValidationCallback!(null!, null, null, SslPolicyErrors.None); - Assert.True(userCertificateValidationCallbackCalled); - clonedConnection.ProvideClientCertificatesCallback!(null!); - Assert.True(clientCertificatesCallbackCalled); + var sslClientAuthenticationOptions = new SslClientAuthenticationOptions(); + clonedConnection.SslClientAuthenticationOptionsCallback!(sslClientAuthenticationOptions); + Assert.That(clientCertificatesCallbackCalled); + sslClientAuthenticationOptions.RemoteCertificateValidationCallback!(null!, null, null, SslPolicyErrors.None); + Assert.That(userCertificateValidationCallbackCalled); bool UserCertificateValidationCallback(object sender, X509Certificate? 
certificate, X509Chain? chain, SslPolicyErrors errors) => userCertificateValidationCallbackCalled = true; - void ClientCertificatesCallback(X509CertificateCollection certs) + void ClientCertificatesCallback(X509CertificateCollection? certs) => clientCertificatesCallbackCalled = true; } @@ -1023,18 +1150,15 @@ public void Clone() { using var pool = CreateTempPool(ConnectionString, out var connectionString); using var conn = new NpgsqlConnection(connectionString); - ProvideClientCertificatesCallback callback1 = certificates => { }; - conn.ProvideClientCertificatesCallback = callback1; - RemoteCertificateValidationCallback callback2 = (sender, certificate, chain, errors) => true; - conn.UserCertificateValidationCallback = callback2; + Action callback = _ => { }; + conn.SslClientAuthenticationOptionsCallback = callback; conn.Open(); Assert.That(async () => await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); using var conn2 = (NpgsqlConnection)((ICloneable)conn).Clone(); Assert.That(conn2.ConnectionString, Is.EqualTo(conn.ConnectionString)); - Assert.That(conn2.ProvideClientCertificatesCallback, Is.SameAs(callback1)); - Assert.That(conn2.UserCertificateValidationCallback, Is.SameAs(callback2)); + Assert.That(conn2.SslClientAuthenticationOptionsCallback, Is.SameAs(callback)); conn2.Open(); Assert.That(async () => await conn2.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } @@ -1052,8 +1176,6 @@ public async Task Clone_with_data_source() [Test] public async Task DatabaseInfo_is_shared() { - if (IsMultiplexing) - return; // Create a temp pool to make sure the second connection will be new and not idle await using var dataSource = CreateDataSource(); await using var conn1 = await dataSource.OpenConnectionAsync(); @@ -1100,7 +1222,6 @@ public async Task Many_open_close_with_transaction() [Test] [IssueLink("https://github.com/npgsql/npgsql/issues/927")] [IssueLink("https://github.com/npgsql/npgsql/issues/736")] - [Ignore("Fails when running the entire test suite but 
not on its own...")] public async Task Rollback_on_close() { // Npgsql 3.0.0 to 3.0.4 prepended a rollback for the next time the connector is used, as an optimization. @@ -1128,10 +1249,6 @@ public async Task Rollback_on_close() [IssueLink("https://github.com/npgsql/npgsql/issues/777")] public async Task Exception_during_close() { - // Pooling must be on to use multiplexing - if (IsMultiplexing) - return; - await using var dataSource = CreateDataSource(csb => csb.Pooling = false); await using var conn = await dataSource.OpenConnectionAsync(); var connectorId = conn.ProcessID; @@ -1145,7 +1262,7 @@ public async Task Exception_during_close() [Test, Description("Some pseudo-PG database don't support pg_type loading, we have a minimal DatabaseInfo for this")] public async Task NoTypeLoading() { - await using var dataSource = CreateDataSource(csb => csb.ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading); + await using var dataSource = CreateDataSource(builder => builder.ConfigureTypeLoading(builder => builder.EnableTypeLoading())); await using var conn = await dataSource.OpenConnectionAsync(); Assert.That(await conn.ExecuteScalarAsync("SELECT 8"), Is.EqualTo(8)); @@ -1179,9 +1296,6 @@ public async Task NoTypeLoading() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1158")] public async Task Table_named_record() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - using var conn = await OpenConnectionAsync(); await conn.ExecuteNonQueryAsync(@" @@ -1199,8 +1313,7 @@ await conn.ExecuteNonQueryAsync(@" } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/392")] - [NonParallelizable] - [Platform(Exclude = "MacOsX", Reason = "Flaky in CI on Mac")] + [NonParallelizable] // Drops and creates same database across modes public async Task Non_UTF8_Encoding() { Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); @@ -1234,7 +1347,7 @@ await adminConn.ExecuteNonQueryAsync( await using var cmd = conn.CreateCommand(); 
cmd.CommandText = "SELECT * FROM foo"; await using var reader = await cmd.ExecuteReaderAsync(); - Assert.IsTrue(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); using (var textReader = await reader.GetTextReaderAsync(0)) Assert.That(textReader.ReadToEnd(), Is.EqualTo(value)); @@ -1261,9 +1374,6 @@ await adminConn.ExecuteNonQueryAsync( [Test] public async Task Oversize_buffer() { - if (IsMultiplexing) - return; - await using var dataSource = CreateDataSource(); await using var conn = await dataSource.OpenConnectionAsync(); var csb = new NpgsqlConnectionStringBuilder(ConnectionString); @@ -1319,9 +1429,6 @@ public async Task TcpKeepalive() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3511")] public async Task Keepalive_with_failed_transaction() { - if (IsMultiplexing) - return; - await using var dataSource = CreateDataSource(csb => csb.KeepAlive = 1); await using var conn = await dataSource.OpenConnectionAsync(); await using var tx = await conn.BeginTransactionAsync(); @@ -1341,9 +1448,6 @@ public async Task Keepalive_with_failed_transaction() [Test] public async Task Change_parameter() { - if (IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); var defaultApplicationName = conn.PostgresParameters["application_name"]; await conn.ExecuteNonQueryAsync("SET application_name = 'some_test_value'"); @@ -1386,7 +1490,7 @@ public async Task NoResetOnClose(bool noResetOnClose) await conn.CloseAsync(); await conn.OpenAsync(); Assert.That(await conn.ExecuteScalarAsync("SHOW application_name"), Is.EqualTo( - noResetOnClose || IsMultiplexing + noResetOnClose ? 
"modified" : originalApplicationName)); } @@ -1395,9 +1499,6 @@ public async Task NoResetOnClose(bool noResetOnClose) [Description("Test whether the internal NpgsqlConnection.Open method stays on the same thread with async=false")] public async Task Sync_open_blocked_same_thread() { - if (IsMultiplexing) - return; - await using var dataSource = CreateDataSource(csb => { csb.MaxPoolSize = 1; @@ -1430,7 +1531,32 @@ public async Task Sync_open_blocked_same_thread() foreach (var sameThreadTask in sameThreadTasks) { - Assert.IsTrue(await sameThreadTask, "Synchronous open completed on different thread"); + Assert.That(await sameThreadTask, "Synchronous open completed on different thread"); + } + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6427")] + public async Task Gss_encryption_retry_does_not_clear_pool() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + GssEncryptionMode = GssEncryptionMode.Prefer + }; + // Break connection on gss encryption request to force the client to create a new connection and retry again + // This emulates the behavior of older versions of PostgreSQL or its forks, like Supabase + await using var postmaster = PgPostmasterMock.Start(csb.ConnectionString, breakOnGssEncryptionRequest: true); + await using var dataSource = CreateDataSource(postmaster.ConnectionString); + + int processID; + await using (var conn = await dataSource.OpenConnectionAsync()) + { + processID = conn.ProcessID; + } + + // The second time we get a connection from the pool we should ge the exact same connection + await using (var conn = await dataSource.OpenConnectionAsync()) + { + Assert.That(conn.ProcessID, Is.EqualTo(processID)); } } @@ -1439,9 +1565,6 @@ public async Task Sync_open_blocked_same_thread() [Test] public async Task PhysicalConnectionInitializer_sync() { - if (IsMultiplexing) // Sync I/O - return; - await using var adminConn = await OpenConnectionAsync(); var table = await CreateTempTable(adminConn, "ID INTEGER"); @@ 
-1466,11 +1589,6 @@ public async Task PhysicalConnectionInitializer_sync() [Test] public async Task PhysicalConnectionInitializer_async() { - // With multiplexing the connector might become idle at undetermined point after the query is executed. - // Which is why we ignore it. - if (IsMultiplexing) - return; - await using var adminConn = await OpenConnectionAsync(); var table = await CreateTempTable(adminConn, "ID INTEGER"); @@ -1495,9 +1613,6 @@ public async Task PhysicalConnectionInitializer_async() [Test] public async Task PhysicalConnectionInitializer_sync_with_break() { - if (IsMultiplexing) // Sync I/O - return; - var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.UsePhysicalConnectionInitializer( conn => @@ -1538,9 +1653,7 @@ public async Task PhysicalConnectionInitializer_async_with_break() [Test] public async Task PhysicalConnectionInitializer_async_throws_on_second_open() { - // With multiplexing a physical connection might open on NpgsqlConnection.OpenAsync (if there was no completed bootstrap beforehand) - // or on NpgsqlCommand.ExecuteReaderAsync. - // We've already tested the first case in PhysicalConnectionInitializer_async_throws above, testing the second one below. + // We've already tested a simpler case in PhysicalConnectionInitializer_async_throws above, testing a second one below. 
var count = 0; var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.UsePhysicalConnectionInitializer( @@ -1556,18 +1669,10 @@ public async Task PhysicalConnectionInitializer_async_throws_on_second_open() await using var conn1 = dataSource.CreateConnection(); Assert.DoesNotThrowAsync(async () => await conn1.OpenAsync()); - // We start a transaction specifically for multiplexing (to bind a connector to the connection) await using var tx = await conn1.BeginTransactionAsync(); await using var conn2 = dataSource.CreateConnection(); - Exception exception; - if (IsMultiplexing) - { - await conn2.OpenAsync(); - exception = Assert.ThrowsAsync(async () => await conn2.BeginTransactionAsync())!; - } - else - exception = Assert.ThrowsAsync(async () => await conn2.OpenAsync())!; + var exception = Assert.ThrowsAsync(async () => await conn2.OpenAsync())!; Assert.That(exception.Message, Is.EqualTo("INTENTIONAL FAILURE")); } @@ -1595,14 +1700,144 @@ public async Task PhysicalConnectionInitializer_disposes_connection() #endregion Physical connection initialization + #region Require auth + + [Test] + public async Task Connect_with_any_auth() + { + await using var dataSource = CreateDataSource(csb => + { + csb.RequireAuth = $"{RequireAuthMode.Password},{RequireAuthMode.MD5},{RequireAuthMode.GSS},{RequireAuthMode.SSPI},{RequireAuthMode.ScramSHA256},{RequireAuthMode.None}"; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Connect_with_any_auth_env() + { + using var _ = SetEnvironmentVariable("PGREQUIREAUTH", $"{RequireAuthMode.Password},{RequireAuthMode.MD5},{RequireAuthMode.GSS},{RequireAuthMode.SSPI},{RequireAuthMode.ScramSHA256},{RequireAuthMode.None}"); + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + } + + [Test] + public async Task Connect_with_any_except_none_auth() + { + await using var 
dataSource = CreateDataSource(csb => + { + csb.RequireAuth = $"!{RequireAuthMode.None}"; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Connect_with_any_except_none_auth_env() + { + using var _ = SetEnvironmentVariable("PGREQUIREAUTH", $"!{RequireAuthMode.None}"); + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + } + + [Test] + public async Task Fail_connect_with_none_auth() + { + await using var dataSource = CreateDataSource(csb => + { + csb.RequireAuth = $"{RequireAuthMode.None}"; + }); + var ex = Assert.ThrowsAsync(async () => await dataSource.OpenConnectionAsync())!; + Assert.That(ex.Message, Does.Contain("authentication method is not allowed")); + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Fail_connect_with_none_auth_env() + { + using var _ = SetEnvironmentVariable("PGREQUIREAUTH", $"{RequireAuthMode.None}"); + await using var dataSource = CreateDataSource(); + var ex = Assert.ThrowsAsync(async () => await dataSource.OpenConnectionAsync())!; + Assert.That(ex.Message, Does.Contain("authentication method is not allowed")); + } + + [Test] + public async Task Connect_with_md5_auth() + { + await using var dataSource = CreateDataSource(csb => + { + csb.RequireAuth = $"{RequireAuthMode.MD5}"; + }); + try + { + await using var conn = await dataSource.OpenConnectionAsync(); + } + catch (Exception e) when (!IsOnBuildServer) + { + Console.WriteLine(e); + Assert.Ignore("MD5 authentication doesn't seem to be set up"); + } + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Connect_with_md5_auth_env() + { + using var _ = SetEnvironmentVariable("PGREQUIREAUTH", $"{RequireAuthMode.MD5}"); + await using var dataSource = CreateDataSource(); + try + { + await using var conn = await dataSource.OpenConnectionAsync(); + } + catch 
(Exception e) when (!IsOnBuildServer) + { + Console.WriteLine(e); + Assert.Ignore("MD5 authentication doesn't seem to be set up"); + } + } + + [Test] + public void Mixed_auth_methods_not_supported([Values( + $"{nameof(RequireAuthMode.ScramSHA256)},!{nameof(RequireAuthMode.None)}", + $"!{nameof(RequireAuthMode.ScramSHA256)},{nameof(RequireAuthMode.None)}")] + string authMethods) + { + var csb = new NpgsqlConnectionStringBuilder(); + Assert.Throws(() => csb.RequireAuth = authMethods); + } + + [Test] + public void Remove_all_auth_methods_throws() + { + var csb = new NpgsqlConnectionStringBuilder(); + Assert.Throws(() => + csb.RequireAuth = $"!{RequireAuthMode.Password},!{RequireAuthMode.MD5},!{RequireAuthMode.GSS},!{RequireAuthMode.SSPI},!{RequireAuthMode.ScramSHA256},!{RequireAuthMode.None}"); + } + + [Test] + public void Unknown_auth_method_throws() + { + var csb = new NpgsqlConnectionStringBuilder(); + Assert.Throws(() => csb.RequireAuth = "SuperSecure"); + } + + [Test] + public void Auth_methods_are_trimmed() + { + var csb = new NpgsqlConnectionStringBuilder + { + RequireAuth = $"{RequireAuthMode.Password} , {RequireAuthMode.MD5}" + }; + Assert.That(csb.RequireAuthModes, Is.EqualTo(RequireAuthMode.Password | RequireAuthMode.MD5)); + } + + #endregion Require auth + [Test] [NonParallelizable] // Modifies global database info factories [IssueLink("https://github.com/npgsql/npgsql/issues/4425")] public async Task Breaking_connection_while_loading_database_info() { - if (IsMultiplexing) - return; - await using var dataSource = CreateDataSource(); await using var firstConn = dataSource.CreateConnection(); @@ -1689,12 +1924,9 @@ public async Task Log_Open_Close_pooled() AssertLoggingStateContains(closedConnectionEvent, "Port", port); AssertLoggingStateContains(closedConnectionEvent, "Database", database); - if (!IsMultiplexing) - { - AssertLoggingStateContains(openedConnectionEvent, "ConnectorId", processId); - AssertLoggingStateContains(closingConnectionEvent, 
"ConnectorId", processId); - AssertLoggingStateContains(closedConnectionEvent, "ConnectorId", processId); - } + AssertLoggingStateContains(openedConnectionEvent, "ConnectorId", processId); + AssertLoggingStateContains(closingConnectionEvent, "ConnectorId", processId); + AssertLoggingStateContains(closedConnectionEvent, "ConnectorId", processId); var ids = new[] { @@ -1711,9 +1943,6 @@ public async Task Log_Open_Close_pooled() [Test] public async Task Log_Open_Close_physical() { - if (IsMultiplexing) - return; - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) { Pooling = false }; await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, csb.ToString()); await using var conn = dataSource.CreateConnection(); @@ -1765,6 +1994,4 @@ void AssertLoggingConnectionString(NpgsqlConnection connection, object? logState } #endregion Logging tests - - public ConnectionTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/CopyTests.cs b/test/Npgsql.Tests/CopyTests.cs index d84ad7b53a..6d1421fe82 100644 --- a/test/Npgsql.Tests/CopyTests.cs +++ b/test/Npgsql.Tests/CopyTests.cs @@ -17,7 +17,7 @@ namespace Npgsql.Tests; -public class CopyTests : MultiplexingTestBase +public class CopyTests : TestBase { #region Issue 2257 @@ -227,8 +227,6 @@ public async Task Wrong_table_definition_raw_binary_copy() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] public async Task Wrong_format_raw_binary_copy() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using (var conn = await OpenConnectionAsync()) { var table = await CreateTempTable(conn, "blob BYTEA"); @@ -362,9 +360,9 @@ public async Task Import_numeric() await using var cmd = conn.CreateCommand(); cmd.CommandText = $"SELECT field FROM {table}"; await using var reader = await cmd.ExecuteReaderAsync(); - Assert.IsTrue(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); Assert.That(reader.GetValue(0), 
Is.EqualTo(1234m)); - Assert.IsTrue(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); Assert.That(reader.GetValue(0), Is.EqualTo(5678m)); } @@ -386,6 +384,26 @@ public async Task Import_string_array() Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table}"), Is.EqualTo(data)); } + [Test] + public async Task Import_DBNull_then_other_object() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "field TEXT"); + + object data = "foo"; + using (var writer = conn.BeginBinaryImport($"COPY {table} (field) FROM STDIN BINARY")) + { + writer.StartRow(); + writer.Write((object?)DBNull.Value); + writer.StartRow(); + writer.Write(data); + var rowsWritten = writer.Complete(); + Assert.That(rowsWritten, Is.EqualTo(2)); + } + + Assert.That(await conn.ExecuteScalarAsync($"SELECT field FROM {table} OFFSET 1"), Is.EqualTo(data)); + } + [Test] public async Task Import_reused_instance_mapping_info_identical_or_throws() { @@ -457,10 +475,10 @@ public async Task Import_object_null() } static readonly TestCaseData[] DBNullValues = - { + [ new TestCaseData(DBNull.Value).SetName("DBNull.Value"), new TestCaseData(null).SetName("null") - }; + ]; [Test, TestCaseSource(nameof(DBNullValues))] public async Task Import_dbnull(DBNull? 
value) @@ -492,8 +510,6 @@ public async Task Wrong_table_definition_binary_import() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] public async Task Wrong_format_binary_import() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "blob BYTEA"); Assert.Throws(() => conn.BeginBinaryImport($"COPY {table} (blob) FROM STDIN")); @@ -513,8 +529,6 @@ public async Task Wrong_table_definition_binary_export() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5457")] public async Task MixedOperations() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); using var reader = conn.BeginBinaryExport(""" @@ -538,8 +552,6 @@ public async Task MixedOperations() [Test] public async Task ReadMoreColumnsThanExist() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); using var reader = conn.BeginBinaryExport(""" @@ -565,8 +577,6 @@ public async Task ReadMoreColumnsThanExist() [Test] public async Task ReadZeroSizedColumns() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); using var reader = conn.BeginBinaryExport(""" @@ -597,8 +607,6 @@ public async Task ReadZeroSizedColumns() [Test] public async Task ReadConverterResolverType() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); using (var reader = conn.BeginBinaryExport(""" @@ -633,8 +641,6 @@ public async Task ReadConverterResolverType() [Test] public async Task StreamingRead() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); var str = new string('a', PgReader.MaxPreparedTextReaderSize + 1); @@ -648,8 +654,6 @@ public async Task StreamingRead() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] public async 
Task Wrong_format_binary_export() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "blob BYTEA"); Assert.Throws(() => conn.BeginBinaryExport($"COPY {table} (blob) TO STDOUT")); @@ -657,12 +661,8 @@ public async Task Wrong_format_binary_export() } [Test, NonParallelizable, IssueLink("https://github.com/npgsql/npgsql/issues/661")] - [Ignore("Unreliable")] public async Task Unexpected_exception_binary_import() { - if (IsMultiplexing) - return; - // Use a private data source since we terminate the connection below (affects database state) await using var dataSource = CreateDataSource(); await using var conn = await dataSource.OpenConnectionAsync(); @@ -681,7 +681,7 @@ public async Task Unexpected_exception_binary_import() writer.StartRow(); writer.Write(data); writer.Dispose(); - }, Throws.Exception.TypeOf()); + }, Throws.Exception.InstanceOf()); Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); } @@ -733,7 +733,7 @@ public async Task Export_long_string() { var str = reader.Read(); Assert.That(str.Length, Is.EqualTo(len)); - Assert.True(str.AsSpan().IndexOfAnyExcept('x') is -1); + Assert.That(str.AsSpan().IndexOfAnyExcept('x') is -1); } } Assert.That(row, Is.EqualTo(100)); @@ -753,12 +753,13 @@ await conn.ExecuteNonQueryAsync($@" using var reader = conn.BeginBinaryExport($"COPY {table} (bits, bitvector, bitarray) TO STDIN BINARY"); reader.StartRow(); - Assert.That(reader.Read(), Is.EqualTo(new BitArray(new[] { false, false, false, false, false, false, false, true, true, false, true }))); + Assert.That(reader.Read(), Is.EqualTo(new BitArray([false, false, false, false, false, false, false, true, true, false, true + ]))); Assert.That(reader.Read(), Is.EqualTo(new BitVector32(0b00000001101000000000000000000000))); Assert.That(reader.Read(), Is.EqualTo(new[] { - new BitArray(new[] { true, false, true }), - new BitArray(new[] { true, true, true }) + new 
BitArray([true, false, true]), + new BitArray([true, true, true]) })); } @@ -957,7 +958,7 @@ public async Task Cancel_raw_binary_export_when_not_consumed_and_then_Dispose() // This must be large enough to cause Postgres to queue up CopyData messages. var stream = conn.BeginRawBinaryCopy("COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT BINARY"); var buffer = new byte[32]; - await stream.ReadAsync(buffer, 0, buffer.Length); + await stream.ReadExactlyAsync(buffer, 0, buffer.Length); stream.Cancel(); Assert.DoesNotThrowAsync(async () => await stream.DisposeAsync()); } @@ -1102,8 +1103,6 @@ await conn.ExecuteNonQueryAsync($@" [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] public async Task Wrong_table_definition_text_import() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); Assert.Throws(() => conn.BeginTextImport("COPY table_is_not_exist (blob) FROM STDIN")); Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); @@ -1113,8 +1112,6 @@ public async Task Wrong_table_definition_text_import() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] public async Task Wrong_format_text_import() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "blob BYTEA"); Assert.Throws(() => conn.BeginTextImport($"COPY {table} (blob) FROM STDIN BINARY")); @@ -1124,8 +1121,6 @@ public async Task Wrong_format_text_import() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] public async Task Wrong_table_definition_text_export() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); Assert.Throws(() => conn.BeginTextExport("COPY table_is_not_exist (blob) TO STDOUT")); Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Open)); @@ -1135,8 +1130,6 @@ public async Task 
Wrong_table_definition_text_export() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2330")] public async Task Wrong_format_text_export() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "blob BYTEA"); Assert.Throws(() => conn.BeginTextExport($"COPY {table} (blob) TO STDOUT BINARY")); @@ -1270,7 +1263,7 @@ public async Task Write_different_types() Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(2)); } - [Test, Description("Tests nested binding scopes in multiplexing")] + [Test] public async Task Within_transaction() { using var conn = await OpenConnectionAsync(); @@ -1313,8 +1306,7 @@ public async Task Within_transaction() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4199")] public async Task Copy_from_is_not_supported_in_regular_command_execution() { - // Run in a separate pool to protect other queries in multiplexing - // because we're going to break the connection on CopyInResponse + // Run in a separate pool because we're going to break the connection on CopyInResponse await using var dataSource = CreateDataSource(); await using var conn = await dataSource.OpenConnectionAsync(); var table = await CreateTempTable(conn, "foo INT"); @@ -1325,8 +1317,7 @@ public async Task Copy_from_is_not_supported_in_regular_command_execution() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4974")] public async Task Copy_to_is_not_supported_in_regular_command_execution() { - // Run in a separate pool to protect other queries in multiplexing - // because we're going to break the connection on CopyInResponse + // Run in a separate pool because we're going to break the connection on CopyInResponse await using var dataSource = CreateDataSource(); await using var conn = await dataSource.OpenConnectionAsync(); var table = await CreateTempTable(conn, "foo INT"); @@ -1373,6 +1364,4 @@ void 
StateAssertions(NpgsqlConnection conn) } #endregion - - public CopyTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/DataAdapterTests.cs b/test/Npgsql.Tests/DataAdapterTests.cs index 4b413409d7..91de36e734 100644 --- a/test/Npgsql.Tests/DataAdapterTests.cs +++ b/test/Npgsql.Tests/DataAdapterTests.cs @@ -92,8 +92,8 @@ public async Task Insert_with_DataSet() var dr2 = new NpgsqlCommand($"SELECT field_int2, field_numeric, field_timestamp FROM {table}", conn).ExecuteReader(); dr2.Read(); - Assert.AreEqual(4, dr2[0]); - Assert.AreEqual(7.3000000M, dr2[1]); + Assert.That(dr2[0], Is.EqualTo(4)); + Assert.That(dr2[1], Is.EqualTo(7.3000000M)); dr2.Close(); } @@ -137,11 +137,10 @@ public async Task DataAdapter_update_return_value() var ds2 = ds.GetChanges()!; var daupdate = da.Update(ds2); - Assert.AreEqual(2, daupdate); + Assert.That(daupdate, Is.EqualTo(2)); } [Test] - [Ignore("")] public async Task DataAdapter_update_return_value2() { using var conn = await OpenConnectionAsync(); @@ -158,15 +157,15 @@ public async Task DataAdapter_update_return_value2() da.Update(ds); //## change id from 1 to 2 - cmd.CommandText = $"update {table} set field_float4 = 0.8"; + cmd.CommandText = $"update {table} set field_numeric = 0.8"; cmd.ExecuteNonQuery(); //## change value to newvalue ds.Tables[0].Rows[0][1] = 0.7; //## update should fail, and make a DBConcurrencyException var count = da.Update(ds); - //## count is 1, even if the isn't updated in the database - Assert.AreEqual(0, count); + //## count is 1, even if the row isn't updated in the database + Assert.That(count, Is.EqualTo(1)); } [Test] @@ -180,16 +179,15 @@ public async Task Fill_with_empty_resultset() da.Fill(ds); - Assert.AreEqual(1, ds.Tables.Count); - Assert.AreEqual(4, ds.Tables[0].Columns.Count); - Assert.AreEqual("field_serial", ds.Tables[0].Columns[0].ColumnName); - Assert.AreEqual("field_int2", ds.Tables[0].Columns[1].ColumnName); - Assert.AreEqual("field_timestamp", 
ds.Tables[0].Columns[2].ColumnName); - Assert.AreEqual("field_numeric", ds.Tables[0].Columns[3].ColumnName); + Assert.That(ds.Tables.Count, Is.EqualTo(1)); + Assert.That(ds.Tables[0].Columns.Count, Is.EqualTo(4)); + Assert.That(ds.Tables[0].Columns[0].ColumnName, Is.EqualTo("field_serial")); + Assert.That(ds.Tables[0].Columns[1].ColumnName, Is.EqualTo("field_int2")); + Assert.That(ds.Tables[0].Columns[2].ColumnName, Is.EqualTo("field_timestamp")); + Assert.That(ds.Tables[0].Columns[3].ColumnName, Is.EqualTo("field_numeric")); } [Test] - [Ignore("")] public async Task Fill_add_with_key() { using var conn = await OpenConnectionAsync(); @@ -206,33 +204,33 @@ public async Task Fill_add_with_key() var field_timestamp = ds.Tables[0].Columns[2]; var field_numeric = ds.Tables[0].Columns[3]; - Assert.IsFalse(field_serial.AllowDBNull); - Assert.IsTrue(field_serial.AutoIncrement); - Assert.AreEqual("field_serial", field_serial.ColumnName); - Assert.AreEqual(typeof(int), field_serial.DataType); - Assert.AreEqual(0, field_serial.Ordinal); - Assert.IsTrue(field_serial.Unique); - - Assert.IsTrue(field_int2.AllowDBNull); - Assert.IsFalse(field_int2.AutoIncrement); - Assert.AreEqual("field_int2", field_int2.ColumnName); - Assert.AreEqual(typeof(short), field_int2.DataType); - Assert.AreEqual(1, field_int2.Ordinal); - Assert.IsFalse(field_int2.Unique); - - Assert.IsTrue(field_timestamp.AllowDBNull); - Assert.IsFalse(field_timestamp.AutoIncrement); - Assert.AreEqual("field_timestamp", field_timestamp.ColumnName); - Assert.AreEqual(typeof(DateTime), field_timestamp.DataType); - Assert.AreEqual(2, field_timestamp.Ordinal); - Assert.IsFalse(field_timestamp.Unique); - - Assert.IsTrue(field_numeric.AllowDBNull); - Assert.IsFalse(field_numeric.AutoIncrement); - Assert.AreEqual("field_numeric", field_numeric.ColumnName); - Assert.AreEqual(typeof(decimal), field_numeric.DataType); - Assert.AreEqual(3, field_numeric.Ordinal); - Assert.IsFalse(field_numeric.Unique); + 
Assert.That(field_serial.AllowDBNull, Is.False); + Assert.That(field_serial.AutoIncrement); + Assert.That(field_serial.ColumnName, Is.EqualTo("field_serial")); + Assert.That(field_serial.DataType, Is.EqualTo(typeof(int))); + Assert.That(field_serial.Ordinal, Is.EqualTo(0)); + Assert.That(field_serial.Unique, Is.False); + + Assert.That(field_int2.AllowDBNull); + Assert.That(field_int2.AutoIncrement, Is.False); + Assert.That(field_int2.ColumnName, Is.EqualTo("field_int2")); + Assert.That(field_int2.DataType, Is.EqualTo(typeof(short))); + Assert.That(field_int2.Ordinal, Is.EqualTo(1)); + Assert.That(field_int2.Unique, Is.False); + + Assert.That(field_timestamp.AllowDBNull); + Assert.That(field_timestamp.AutoIncrement, Is.False); + Assert.That(field_timestamp.ColumnName, Is.EqualTo("field_timestamp")); + Assert.That(field_timestamp.DataType, Is.EqualTo(typeof(DateTime))); + Assert.That(field_timestamp.Ordinal, Is.EqualTo(2)); + Assert.That(field_timestamp.Unique, Is.False); + + Assert.That(field_numeric.AllowDBNull); + Assert.That(field_numeric.AutoIncrement, Is.False); + Assert.That(field_numeric.ColumnName, Is.EqualTo("field_numeric")); + Assert.That(field_numeric.DataType, Is.EqualTo(typeof(decimal))); + Assert.That(field_numeric.Ordinal, Is.EqualTo(3)); + Assert.That(field_numeric.Unique, Is.False); } [Test] @@ -252,21 +250,21 @@ public async Task Fill_add_columns() var field_timestamp = ds.Tables[0].Columns[2]; var field_numeric = ds.Tables[0].Columns[3]; - Assert.AreEqual("field_serial", field_serial.ColumnName); - Assert.AreEqual(typeof(int), field_serial.DataType); - Assert.AreEqual(0, field_serial.Ordinal); + Assert.That(field_serial.ColumnName, Is.EqualTo("field_serial")); + Assert.That(field_serial.DataType, Is.EqualTo(typeof(int))); + Assert.That(field_serial.Ordinal, Is.EqualTo(0)); - Assert.AreEqual("field_int2", field_int2.ColumnName); - Assert.AreEqual(typeof(short), field_int2.DataType); - Assert.AreEqual(1, field_int2.Ordinal); + 
Assert.That(field_int2.ColumnName, Is.EqualTo("field_int2")); + Assert.That(field_int2.DataType, Is.EqualTo(typeof(short))); + Assert.That(field_int2.Ordinal, Is.EqualTo(1)); - Assert.AreEqual("field_timestamp", field_timestamp.ColumnName); - Assert.AreEqual(typeof(DateTime), field_timestamp.DataType); - Assert.AreEqual(2, field_timestamp.Ordinal); + Assert.That(field_timestamp.ColumnName, Is.EqualTo("field_timestamp")); + Assert.That(field_timestamp.DataType, Is.EqualTo(typeof(DateTime))); + Assert.That(field_timestamp.Ordinal, Is.EqualTo(2)); - Assert.AreEqual("field_numeric", field_numeric.ColumnName); - Assert.AreEqual(typeof(decimal), field_numeric.DataType); - Assert.AreEqual(3, field_numeric.Ordinal); + Assert.That(field_numeric.ColumnName, Is.EqualTo("field_numeric")); + Assert.That(field_numeric.DataType, Is.EqualTo(typeof(decimal))); + Assert.That(field_numeric.Ordinal, Is.EqualTo(3)); } [Test] @@ -302,9 +300,9 @@ public async Task Update_letting_null_field_falue() da.Fill(ds); var dt = ds.Tables[0]; - Assert.IsNotNull(dt); + Assert.That(dt, Is.Not.Null); - var dr = ds.Tables[0].Rows[ds.Tables[0].Rows.Count - 1]; + var dr = ds.Tables[0].Rows[^1]; dr["field_int2"] = 4; var ds2 = ds.GetChanges()!; @@ -314,7 +312,7 @@ public async Task Update_letting_null_field_falue() using var dr2 = new NpgsqlCommand($"SELECT field_int2 FROM {table}", conn).ExecuteReader(); dr2.Read(); - Assert.AreEqual(4, dr2["field_int2"]); + Assert.That(dr2["field_int2"], Is.EqualTo(4)); } [Test] @@ -329,7 +327,6 @@ public async Task Fill_with_duplicate_column_name() } [Test] - [Ignore("")] public Task Update_with_DataSet() => DoUpdateWithDataSet(); public async Task DoUpdateWithDataSet() @@ -343,14 +340,14 @@ public async Task DoUpdateWithDataSet() var ds = new DataSet(); var da = new NpgsqlDataAdapter($"select * from {table}", conn); var cb = new NpgsqlCommandBuilder(da); - Assert.IsNotNull(cb); + Assert.That(cb, Is.Not.Null); da.Fill(ds); var dt = ds.Tables[0]; - 
Assert.IsNotNull(dt); + Assert.That(dt, Is.Not.Null); - var dr = ds.Tables[0].Rows[ds.Tables[0].Rows.Count - 1]; + var dr = ds.Tables[0].Rows[^1]; dr["field_int2"] = 4; @@ -361,11 +358,10 @@ public async Task DoUpdateWithDataSet() using var dr2 = new NpgsqlCommand($"select * from {table}", conn).ExecuteReader(); dr2.Read(); - Assert.AreEqual(4, dr2["field_int2"]); + Assert.That(dr2["field_int2"], Is.EqualTo(4)); } [Test] - [Ignore("")] public async Task Insert_with_CommandBuilder_case_sensitive() { using var conn = await OpenConnectionAsync(); @@ -374,13 +370,13 @@ public async Task Insert_with_CommandBuilder_case_sensitive() var ds = new DataSet(); var da = new NpgsqlDataAdapter($"select * from {table}", conn); var builder = new NpgsqlCommandBuilder(da); - Assert.IsNotNull(builder); + Assert.That(builder, Is.Not.Null); da.Fill(ds); var dt = ds.Tables[0]; var dr = dt.NewRow(); - dr["Field_Case_Sensitive"] = 4; + dr["Field_int4"] = 4; dt.Rows.Add(dr); var ds2 = ds.GetChanges()!; @@ -390,7 +386,7 @@ public async Task Insert_with_CommandBuilder_case_sensitive() using var dr2 = new NpgsqlCommand($"select * from {table}", conn).ExecuteReader(); dr2.Read(); - Assert.AreEqual(4, dr2[1]); + Assert.That(dr2["field_int4"], Is.EqualTo(4)); } [Test] @@ -449,12 +445,11 @@ public async Task DataAdapter_command_access() var da = new NpgsqlDataAdapter(); da.SelectCommand = command; System.Data.Common.DbDataAdapter common = da; - Assert.IsNotNull(common.SelectCommand); + Assert.That(common.SelectCommand, Is.Not.Null); } [Test, Description("Makes sure that the INSERT/UPDATE/DELETE commands are auto-populated on NpgsqlDataAdapter")] [IssueLink("https://github.com/npgsql/npgsql/issues/179")] - [Ignore("Somehow related to us using a temporary table???")] public async Task Auto_populate_adapter_commands() { using var conn = await OpenConnectionAsync(); @@ -494,7 +489,6 @@ public void Command_builder_quoting() [Test, Description("Makes sure a correct SQL string is built with 
GetUpdateCommand(true) using correct parameter names and placeholders")] [IssueLink("https://github.com/npgsql/npgsql/issues/397")] - [Ignore("Somehow related to us using a temporary table???")] public async Task Get_UpdateCommand() { using var conn = await OpenConnectionAsync(); @@ -532,13 +526,13 @@ public async Task Load_DataTable() dt.Load(dr); dr.Close(); - Assert.AreEqual(5, dt.Columns[0].MaxLength); - Assert.AreEqual(5, dt.Columns[1].MaxLength); + Assert.That(dt.Columns[0].MaxLength, Is.EqualTo(5)); + Assert.That(dt.Columns[1].MaxLength, Is.EqualTo(5)); } public Task SetupTempTable(NpgsqlConnection conn) => CreateTempTable(conn, @" -field_pk SERIAL PRIMARY KEY, +field_pk INTEGER PRIMARY KEY GENERATED ALWAYS AS IDENTITY, field_serial SERIAL, field_int2 SMALLINT, field_int4 INTEGER, diff --git a/test/Npgsql.Tests/DataSourceTests.cs b/test/Npgsql.Tests/DataSourceTests.cs index 639e83a795..c2ef7bc9cb 100644 --- a/test/Npgsql.Tests/DataSourceTests.cs +++ b/test/Npgsql.Tests/DataSourceTests.cs @@ -5,6 +5,7 @@ using System.Text.Json.Serialization; using System.Threading.Tasks; using NUnit.Framework; +using static Npgsql.Tests.TestUtil; // ReSharper disable MethodHasAsyncOverload @@ -75,7 +76,7 @@ public async Task ExecuteReader_on_connectionless_command([Values] bool async) await using (var reader = async ? await command.ExecuteReaderAsync() : command.ExecuteReader()) { - Assert.True(reader.Read()); + Assert.That(reader.Read()); Assert.That(reader.GetInt32(0), Is.EqualTo(1)); } @@ -124,16 +125,40 @@ public async Task ExecuteReader_on_connectionless_batch([Values] bool async) using (var reader = async ? 
await batch.ExecuteReaderAsync() : batch.ExecuteReader()) { - Assert.True(reader.Read()); + Assert.That(reader.Read()); Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - Assert.True(reader.NextResult()); - Assert.True(reader.Read()); + Assert.That(reader.NextResult()); + Assert.That(reader.Read()); Assert.That(reader.GetInt32(0), Is.EqualTo(2)); } Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 1, Busy: 0))); } + [Test] + public void Clear() + { + using var dataSource = NpgsqlDataSource.Create(ConnectionString); + var connection1 = dataSource.OpenConnection(); + var connection2 = dataSource.OpenConnection(); + connection1.Close(); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 2, Idle: 1, Busy: 1))); + + dataSource.Clear(); + + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 0, Busy: 1))); + + var connection3 = dataSource.OpenConnection(); + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 2, Idle: 0, Busy: 2))); + + connection2.Close(); + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 0, Busy: 1))); + + connection3.Close(); + Assert.That(dataSource.Statistics, Is.EqualTo((Total: 1, Idle: 1, Busy: 0))); + } + [Test] public void Dispose() { @@ -264,39 +289,15 @@ public async Task As_DbDataSource([Values] bool async) } [Test] - public async Task Executing_command_on_disposed_datasource([Values] bool multiplexing) + public async Task Executing_command_on_disposed_datasource() { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Multiplexing = multiplexing - }; - DbDataSource dataSource = NpgsqlDataSource.Create(csb.ConnectionString); + DbDataSource dataSource = NpgsqlDataSource.Create(ConnectionString); await using (var _ = await dataSource.OpenConnectionAsync()) {} await dataSource.DisposeAsync(); await using var command = dataSource.CreateCommand("SELECT 1"); Assert.ThrowsAsync(command.ExecuteNonQueryAsync); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4840")] - public 
async Task Multiplexing_connectionless_command_open_connection() - { - var csb = new NpgsqlConnectionStringBuilder(ConnectionString) - { - Multiplexing = true - }; - await using var dataSource = NpgsqlDataSource.Create(csb.ConnectionString); - - await using var conn = await dataSource.OpenConnectionAsync(); - await using var _ = await conn.BeginTransactionAsync(); - - await using var command = dataSource.CreateCommand(); - command.CommandText = "SELECT 1"; - - await using var reader = await command.ExecuteReaderAsync(); - Assert.True(reader.Read()); - Assert.That(reader.GetInt32(0), Is.EqualTo(1)); - } - [Test] public async Task Connection_string_builder_settings_are_frozen_on_Build() { @@ -356,4 +357,72 @@ public async Task ConfigureJsonOptions_is_order_independent() Assert.That(reader.GetFieldValue(0).Id, Is.EqualTo(1)); } } + + [Test] + public async Task ReloadTypes([Values] bool async) + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + await using var dataSource = dataSourceBuilder.Build(); + + await using var connection = await dataSource.OpenConnectionAsync(); + await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + + if (async) + await dataSource.ReloadTypesAsync(); + else + dataSource.ReloadTypes(); + + Assert.ThrowsAsync(async () => await connection.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + + // Close connection and reopen to make sure it picks up the new type and mapping from the data source + await connection.CloseAsync(); + await connection.OpenAsync(); + + Assert.DoesNotThrowAsync(async () => await connection.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + } + + [Test] + public async Task ReloadTypes_across_data_sources([Values] bool async) + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await 
GetTempTypeName(adminConnection); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapEnum(type); + await using var dataSource1 = dataSourceBuilder.Build(); + await using var connection1 = await dataSource1.OpenConnectionAsync(); + + await using var dataSource2 = dataSourceBuilder.Build(); + await using var connection2 = await dataSource2.OpenConnectionAsync(); + + await connection1.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); + + if (async) + await dataSource1.ReloadTypesAsync(); + else + dataSource1.ReloadTypes(); + + Assert.ThrowsAsync(async () => await connection1.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + Assert.ThrowsAsync(async () => await connection2.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + + // Close connection and reopen to check that the new type and mapping is not available in dataSource2 + await connection2.CloseAsync(); + await connection2.OpenAsync(); + + Assert.ThrowsAsync(async () => await connection2.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + + await dataSource2.ReloadTypesAsync(); + + // Close connection2 and reopen to make sure it picks up the new type and mapping from dataSource2 + await connection2.CloseAsync(); + await connection2.OpenAsync(); + + Assert.DoesNotThrowAsync(async () => await connection2.ExecuteScalarAsync($"SELECT 'happy'::{type}")); + } + + enum Mood { Sad, Ok, Happy } } diff --git a/test/Npgsql.Tests/DataTypeNameTests.cs b/test/Npgsql.Tests/DataTypeNameTests.cs index fd366d8258..067eb217c4 100644 --- a/test/Npgsql.Tests/DataTypeNameTests.cs +++ b/test/Npgsql.Tests/DataTypeNameTests.cs @@ -12,7 +12,7 @@ public void MaxLengthDataTypeName() var name = new string('a', DataTypeName.NAMEDATALEN); var fullyQualifiedDataTypeName= $"public.{name}"; Assert.DoesNotThrow(() => new DataTypeName(fullyQualifiedDataTypeName)); - Assert.AreEqual(new DataTypeName(fullyQualifiedDataTypeName).Value, fullyQualifiedDataTypeName); + 
Assert.That(fullyQualifiedDataTypeName, Is.EqualTo(new DataTypeName(fullyQualifiedDataTypeName).Value)); } [Test] @@ -23,4 +23,46 @@ public void TooLongDataTypeName() var exception = Assert.Throws(() => new DataTypeName(fullyQualifiedDataTypeName)); Assert.That(exception!.Message, Does.EndWith($": public.{new string('a', DataTypeName.NAMEDATALEN)}")); } + + [TestCase("public.name", ExpectedResult = "public._name")] + [TestCase("public._name", ExpectedResult = "public._name")] + [TestCase("public.zzzaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa123", ExpectedResult = "public._zzzaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa12")] + public string ToArrayName(string name) + => new DataTypeName(name).ToArrayName(); + + [TestCase("public.multirange", ExpectedResult = "public.multirange")] + [TestCase("public.abcmultirange123", ExpectedResult = "public.abcmultirange123")] + [TestCase("public.multiRANGE", ExpectedResult = "public.multiRANGE_multirange")] + public string ToDefaultMultirangeNameHasMultiRange(string name) + => new DataTypeName(name).ToDefaultMultirangeName(); + + [TestCase("public.range", ExpectedResult = "public.multirange")] + [TestCase("public.abcrange123", ExpectedResult = "public.abcmultirange123")] + [TestCase("public.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaarange", ExpectedResult = "public.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaamultirange")] // Replace goes to max length + [TestCase("public.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaarange1", ExpectedResult = "public.aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaamultir")] // Replace goes over max length + [TestCase("public.RANGE", ExpectedResult = "public.RANGE_multirange")] + public string ToDefaultMultirangeNameHasRange(string name) + => new DataTypeName(name).ToDefaultMultirangeName(); + + [TestCase("public.name", null, ExpectedResult = "public.name")] + [TestCase("public._name", null, ExpectedResult = "public._name")] + 
[TestCase("public.name[]", null, ExpectedResult = "public._name")] + [TestCase("public.integer", null, ExpectedResult = "public.integer")] + [TestCase("name", null, ExpectedResult = "pg_catalog.name")] + [TestCase("_name", null, ExpectedResult = "pg_catalog._name")] + [TestCase("name[]", null, ExpectedResult = "pg_catalog._name")] + [TestCase("mytype", null, ExpectedResult = "-.mytype")] + [TestCase("_mytype", null, ExpectedResult = "-._mytype")] + [TestCase("mytype[]", null, ExpectedResult = "-._mytype")] + [TestCase("character varying", null, ExpectedResult = "pg_catalog.varchar")] + [TestCase("decimal(facet_name)", null, ExpectedResult = "pg_catalog.numeric")] + [TestCase("name", "public", ExpectedResult = "public.name")] + [TestCase("name ", "public", ExpectedResult = "public.name")] + [TestCase("_name", "public", ExpectedResult = "public._name")] + [TestCase("name[]", "public", ExpectedResult = "public._name")] + [TestCase("timestamp with time zone", "public", ExpectedResult = "public.timestamptz")] + [TestCase("boolean(facet_name)", "public", ExpectedResult = "public.bool")] + [TestCase(" public.name ", null, ExpectedResult = "public.name")] + public string FromDisplayName(string name, string? schema) + => DataTypeName.FromDisplayName(name, schema).Value; } diff --git a/test/Npgsql.Tests/DistributedTransactionTests.cs b/test/Npgsql.Tests/DistributedTransactionTests.cs index e55d6e7bd9..aab4447ff2 100644 --- a/test/Npgsql.Tests/DistributedTransactionTests.cs +++ b/test/Npgsql.Tests/DistributedTransactionTests.cs @@ -1,5 +1,3 @@ -#if NET7_0_OR_GREATER - using System; using System.Collections.Concurrent; using System.Collections.Generic; @@ -181,12 +179,12 @@ public void Transaction_race([Values(false, true)] bool distributed) } catch (Exception ex) { - Assert.Fail( - @"Failed at iteration {0}. -Events: -{1} -Exception {2}", - i, FormatEventQueue(eventQueue), ex); + Assert.Fail($""" + Failed at iteration {i}. 
+ Events: + {FormatEventQueue(eventQueue)} + Exception {ex} + """); } } } @@ -235,12 +233,12 @@ public void Connection_reuse_race_after_transaction([Values(false, true)] bool d } catch (Exception ex) { - Assert.Fail( - @"Failed at iteration {0}. -Events: -{1} -Exception {2}", - i, FormatEventQueue(eventQueue), ex); + Assert.Fail($""" + Failed at iteration {i}. + Events: + {FormatEventQueue(eventQueue)} + Exception {ex} + """); } } } @@ -289,12 +287,12 @@ public void Connection_reuse_race_after_rollback([Values(false, true)] bool dist } catch (Exception ex) { - Assert.Fail( - @"Failed at iteration {0}. -Events: -{1} -Exception {2}", - i, FormatEventQueue(eventQueue), ex); + Assert.Fail($""" + Failed at iteration {i}. + Events: + {FormatEventQueue(eventQueue)} + Exception {ex} + """); } } } @@ -367,12 +365,12 @@ public void Connection_reuse_race_chaining_transaction([Values(false, true)] boo } catch (Exception ex) { - Assert.Fail( - @"Failed at iteration {0}. -Events: -{1} -Exception {2}", - i, FormatEventQueue(eventQueue), ex); + Assert.Fail($""" + Failed at iteration {i}. 
+ Events: + {FormatEventQueue(eventQueue)} + Exception {ex} + """); } } } @@ -564,11 +562,9 @@ void Current_TransactionCompleted(object sender, TransactionEventArgs e) } } - public class TransactionEvent + public class TransactionEvent(string message) { - public TransactionEvent(string message) - => Message = $"{message} (TId {Thread.CurrentThread.ManagedThreadId})"; - public string Message { get; } + public string Message { get; } = $"{message} (TId {Thread.CurrentThread.ManagedThreadId})"; } #endregion Utilities @@ -635,5 +631,3 @@ internal static string CreateTempTable(NpgsqlConnection conn, string columns) #endregion } - -#endif diff --git a/test/Npgsql.Tests/ExceptionTests.cs b/test/Npgsql.Tests/ExceptionTests.cs index ac87ef2b0e..ec7e7f18db 100644 --- a/test/Npgsql.Tests/ExceptionTests.cs +++ b/test/Npgsql.Tests/ExceptionTests.cs @@ -203,107 +203,10 @@ public void NpgsqlException_with_async() [Test] public void NpgsqlException_IsTransient() { - Assert.True(new NpgsqlException("", new IOException()).IsTransient); - Assert.True(new NpgsqlException("", new SocketException()).IsTransient); - Assert.True(new NpgsqlException("", new TimeoutException()).IsTransient); - Assert.False(new NpgsqlException().IsTransient); - Assert.False(new NpgsqlException("", new Exception("Inner Exception")).IsTransient); + Assert.That(new NpgsqlException("", new IOException()).IsTransient); + Assert.That(new NpgsqlException("", new SocketException()).IsTransient); + Assert.That(new NpgsqlException("", new TimeoutException()).IsTransient); + Assert.That(new NpgsqlException().IsTransient, Is.False); + Assert.That(new NpgsqlException("", new Exception("Inner Exception")).IsTransient, Is.False); } - -#pragma warning disable SYSLIB0051 -#pragma warning disable 618 - [Test] - public void PostgresException_IsTransient() - { - Assert.True(CreateWithSqlState("53300").IsTransient); - Assert.False(CreateWithSqlState("0").IsTransient); - - PostgresException CreateWithSqlState(string sqlState) - { 
- var info = CreateSerializationInfo(); - new Exception().GetObjectData(info, default); - - info.AddValue(nameof(PostgresException.Severity), null); - info.AddValue(nameof(PostgresException.InvariantSeverity), null); - info.AddValue(nameof(PostgresException.SqlState), sqlState); - info.AddValue(nameof(PostgresException.MessageText), null); - info.AddValue(nameof(PostgresException.Detail), null); - info.AddValue(nameof(PostgresException.Hint), null); - info.AddValue(nameof(PostgresException.Position), 0); - info.AddValue(nameof(PostgresException.InternalPosition), 0); - info.AddValue(nameof(PostgresException.InternalQuery), null); - info.AddValue(nameof(PostgresException.Where), null); - info.AddValue(nameof(PostgresException.SchemaName), null); - info.AddValue(nameof(PostgresException.TableName), null); - info.AddValue(nameof(PostgresException.ColumnName), null); - info.AddValue(nameof(PostgresException.DataTypeName), null); - info.AddValue(nameof(PostgresException.ConstraintName), null); - info.AddValue(nameof(PostgresException.File), null); - info.AddValue(nameof(PostgresException.Line), null); - info.AddValue(nameof(PostgresException.Routine), null); - - return new PostgresException(info, default); - } - } -#pragma warning restore SYSLIB0051 -#pragma warning restore 618 - -#pragma warning disable SYSLIB0011 -#pragma warning disable SYSLIB0050 - [Test] - public void Serialization() - { - var actual = new PostgresException("message text", "high", "high2", "53300", "detail", "hint", 18, 42, "internal query", - "where", "schema", "table", "column", "data type", "constraint", "file", "line", "routine"); - - var formatter = new BinaryFormatter(); - var stream = new MemoryStream(); - - formatter.Serialize(stream, actual); - stream.Seek(0, SeekOrigin.Begin); - - var expected = (PostgresException)formatter.Deserialize(stream); - - Assert.That(expected.Severity, Is.EqualTo(actual.Severity)); - Assert.That(expected.InvariantSeverity, Is.EqualTo(actual.InvariantSeverity)); 
- Assert.That(expected.SqlState, Is.EqualTo(actual.SqlState)); - Assert.That(expected.MessageText, Is.EqualTo(actual.MessageText)); - Assert.That(expected.Detail, Is.EqualTo(actual.Detail)); - Assert.That(expected.Hint, Is.EqualTo(actual.Hint)); - Assert.That(expected.Position, Is.EqualTo(actual.Position)); - Assert.That(expected.InternalPosition, Is.EqualTo(actual.InternalPosition)); - Assert.That(expected.InternalQuery, Is.EqualTo(actual.InternalQuery)); - Assert.That(expected.Where, Is.EqualTo(actual.Where)); - Assert.That(expected.SchemaName, Is.EqualTo(actual.SchemaName)); - Assert.That(expected.TableName, Is.EqualTo(actual.TableName)); - Assert.That(expected.ColumnName, Is.EqualTo(actual.ColumnName)); - Assert.That(expected.DataTypeName, Is.EqualTo(actual.DataTypeName)); - Assert.That(expected.ConstraintName, Is.EqualTo(actual.ConstraintName)); - Assert.That(expected.File, Is.EqualTo(actual.File)); - Assert.That(expected.Line, Is.EqualTo(actual.Line)); - Assert.That(expected.Routine, Is.EqualTo(actual.Routine)); - } - - SerializationInfo CreateSerializationInfo() => new(typeof(PostgresException), new FormatterConverter()); -#pragma warning restore SYSLIB0011 - -#pragma warning disable SYSLIB0051 - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/3204")] - public void Base_exception_property_serialization() - { - var ex = new PostgresException("the message", "low", "low2", "XX123"); - - var info = CreateSerializationInfo(); - ex.GetObjectData(info, default); - - // Check virtual base properties, which can be incorrectly deserialized if overridden, because the base - // Exception.GetObjectData() method writes the fields, not the properties (e.g. "_message" instead of "Message"). 
- Assert.That(ex.Data, Is.EquivalentTo((IDictionary?)info.GetValue("Data", typeof(IDictionary)))); - Assert.That(ex.HelpLink, Is.EqualTo(info.GetValue("HelpURL", typeof(string)))); - Assert.That(ex.Message, Is.EqualTo(info.GetValue("Message", typeof(string)))); - Assert.That(ex.Source, Is.EqualTo(info.GetValue("Source", typeof(string)))); - Assert.That(ex.StackTrace, Is.EqualTo(info.GetValue("StackTraceString", typeof(string)))); - } -#pragma warning restore SYSLIB0051 } diff --git a/test/Npgsql.Tests/FunctionTests.cs b/test/Npgsql.Tests/FunctionTests.cs index 37f203b812..4c3b1e10aa 100644 --- a/test/Npgsql.Tests/FunctionTests.cs +++ b/test/Npgsql.Tests/FunctionTests.cs @@ -107,12 +107,12 @@ public async Task Named_parameters() command.Parameters.AddWithValue("sec", 4); var dt = (DateTime)(await command.ExecuteScalarAsync())!; - Assert.AreEqual(new DateTime(2015, 8, 1, 2, 3, 4), dt); + Assert.That(dt, Is.EqualTo(new DateTime(2015, 8, 1, 2, 3, 4))); command.Parameters[0].Value = 2014; command.Parameters[0].ParameterName = ""; // 2014 will be sent as a positional parameter dt = (DateTime)(await command.ExecuteScalarAsync())!; - Assert.AreEqual(new DateTime(2014, 8, 1, 2, 3, 4), dt); + Assert.That(dt, Is.EqualTo(new DateTime(2014, 8, 1, 2, 3, 4))); } [Test] @@ -143,6 +143,25 @@ public async Task Too_many_output_params() Assert.That(command.Parameters["c"].Value, Is.EqualTo(-1)); } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5793")] + public async Task ReturnValue_parameter_ignored() + { + await using var conn = await OpenConnectionAsync(); + var funcName = await GetTempFunctionName(conn); + await conn.ExecuteNonQueryAsync(@$"CREATE FUNCTION {funcName}() RETURNS integer AS 'SELECT 8;' LANGUAGE 'sql'"); + await using var cmd = new NpgsqlCommand(funcName, conn) { CommandType = CommandType.StoredProcedure }; + var param = new NpgsqlParameter + { + ParameterName = "@ReturnValue", + NpgsqlDbType = NpgsqlDbType.Integer, + Direction = 
ParameterDirection.ReturnValue, + Value = 0 + }; + cmd.Parameters.Add(param); + Assert.That(cmd.ExecuteScalar(), Is.EqualTo(8)); + Assert.That(param.Value, Is.EqualTo(0)); + } + [Test] public async Task CommandBehavior_SchemaOnly_support_function_call() { @@ -155,7 +174,33 @@ public async Task CommandBehavior_SchemaOnly_support_function_call() var i = 0; while (dr.Read()) i++; - Assert.AreEqual(0, i); + Assert.That(i, Is.EqualTo(0)); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5820")] + public async Task Output_param_cast_error() + { + await using var conn = await OpenConnectionAsync(); + var function = await GetTempFunctionName(conn); + await conn.ExecuteNonQueryAsync(@$" +CREATE FUNCTION {function} (INOUT param_in int4, OUT param_out interval) AS $$ +BEGIN + param_out = interval '5 years'; +END +$$ LANGUAGE plpgsql"); + await using var cmd = new NpgsqlCommand(function, conn); + cmd.CommandType = CommandType.StoredProcedure; + cmd.Parameters.Add(new NpgsqlParameter("param_in", DbType.Int32) + { + Direction = ParameterDirection.InputOutput, + Value = 1 + }); + cmd.Parameters.Add(new NpgsqlParameter("param_out", NpgsqlDbType.Interval) + { + Direction = ParameterDirection.Output + }); + Assert.ThrowsAsync(cmd.ExecuteNonQueryAsync); + Assert.DoesNotThrowAsync(async () => await conn.ExecuteNonQueryAsync("SELECT 1")); } #region DeriveParameters @@ -245,8 +290,8 @@ await conn.ExecuteNonQueryAsync( { await using var command = new NpgsqlCommand(@"""FunctionCaseSensitive""", conn) { CommandType = CommandType.StoredProcedure }; NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + Assert.That(command.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(command.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); } finally { @@ -265,8 +310,8 @@ public async Task 
DeriveParameters_quote_characters_in_function_name() { await using var command = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + Assert.That(command.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(command.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); } finally { @@ -285,8 +330,8 @@ await conn.ExecuteNonQueryAsync( { await using var command = new NpgsqlCommand(@"""My.Dotted.Function""", conn) { CommandType = CommandType.StoredProcedure }; NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + Assert.That(command.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(command.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); } finally { @@ -304,8 +349,8 @@ await conn.ExecuteNonQueryAsync( $"CREATE FUNCTION {function}(x int, y int, out sum int, out product int) AS 'SELECT $1 + $2, $1 * $2' LANGUAGE sql"); await using var command = new NpgsqlCommand(function, conn) { CommandType = CommandType.StoredProcedure }; NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual("x", command.Parameters[0].ParameterName); - Assert.AreEqual("y", command.Parameters[1].ParameterName); + Assert.That(command.Parameters[0].ParameterName, Is.EqualTo("x")); + Assert.That(command.Parameters[1].ParameterName, Is.EqualTo("y")); } [Test] diff --git a/test/Npgsql.Tests/GlobalTypeMapperTests.cs b/test/Npgsql.Tests/GlobalTypeMapperTests.cs index a5c75e41bf..51f950045e 100644 --- a/test/Npgsql.Tests/GlobalTypeMapperTests.cs +++ b/test/Npgsql.Tests/GlobalTypeMapperTests.cs @@ -1,4 +1,5 @@ using System; +using System.Data; using 
System.Threading.Tasks; using Npgsql.Internal; using Npgsql.Internal.Postgres; @@ -26,17 +27,22 @@ public async Task MapEnum() await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); await connection.ReloadTypesAsync(); - await AssertType(connection, Mood.Happy, "happy", type, npgsqlDbType: null); + await AssertType(connection, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing); } NpgsqlConnection.GlobalTypeMapper.UnmapEnum(type); // Global mapping changes have no effect on already-built data sources - await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); + await AssertType(dataSource1, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing); + await AssertType(dataSource1, "happy", "happy", + type, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), valueTypeEqualsFieldType: false); // But they do affect new data sources await using var dataSource2 = CreateDataSource(); - await AssertType(dataSource2, "happy", "happy", type, npgsqlDbType: null, isDefault: false); + Assert.ThrowsAsync(() => AssertType(dataSource2, Mood.Happy, "happy", + type, dataTypeInference: DataTypeInference.Nothing)); + await AssertType(dataSource2, "happy", "happy", "text", dbType: DbType.String); } [Test] @@ -55,17 +61,21 @@ public async Task MapEnum_NonGeneric() await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); await connection.ReloadTypesAsync(); - await AssertType(connection, Mood.Happy, "happy", type, npgsqlDbType: null); + await AssertType(connection, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing); } NpgsqlConnection.GlobalTypeMapper.UnmapEnum(typeof(Mood), type); // Global mapping changes have no effect on already-built data sources - await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); + await AssertType(dataSource1, Mood.Happy, "happy", type, 
dataTypeInference: DataTypeInference.Nothing); + await AssertType(dataSource1, "happy", "happy", + type, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), valueTypeEqualsFieldType: false); // But they do affect new data sources await using var dataSource2 = CreateDataSource(); - Assert.ThrowsAsync(() => AssertType(dataSource2, Mood.Happy, "happy", type, npgsqlDbType: null)); + Assert.ThrowsAsync(() => AssertType(dataSource2, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing)); + await AssertType(dataSource2, "happy", "happy", "text", dbType: DbType.String); } finally { @@ -86,17 +96,25 @@ public async Task Reset() { await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); await connection.ReloadTypesAsync(); + + await AssertType(connection, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing); } // A global mapping change has no effects on data sources which have already been built NpgsqlConnection.GlobalTypeMapper.Reset(); // Global mapping changes have no effect on already-built data sources - await AssertType(dataSource1, Mood.Happy, "happy", type, npgsqlDbType: null); + await AssertType(dataSource1, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing); + await AssertType(dataSource1, "happy", "happy", + type, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), valueTypeEqualsFieldType: false); // But they do affect new data sources await using var dataSource2 = CreateDataSource(); - await AssertType(dataSource2, "happy", "happy", type, npgsqlDbType: null, isDefault: false); + Assert.ThrowsAsync(() => AssertType(dataSource2, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing)); + await AssertType(dataSource2, "happy", "happy", + type, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); } [Test] diff --git 
a/test/Npgsql.Tests/LargeObjectTests.cs b/test/Npgsql.Tests/LargeObjectTests.cs index fb7179abb4..bdb1c51084 100644 --- a/test/Npgsql.Tests/LargeObjectTests.cs +++ b/test/Npgsql.Tests/LargeObjectTests.cs @@ -17,30 +17,30 @@ public void Test() var oid = manager.Create(); using (var stream = manager.OpenReadWrite(oid)) { - var buf = Encoding.UTF8.GetBytes("Hello"); + var buf = "Hello"u8.ToArray(); stream.Write(buf, 0, buf.Length); stream.Seek(0, System.IO.SeekOrigin.Begin); var buf2 = new byte[buf.Length]; - stream.Read(buf2, 0, buf2.Length); + stream.ReadExactly(buf2, 0, buf2.Length); Assert.That(buf.SequenceEqual(buf2)); - Assert.AreEqual(5, stream.Position); + Assert.That(stream.Position, Is.EqualTo(5)); - Assert.AreEqual(5, stream.Length); + Assert.That(stream.Length, Is.EqualTo(5)); stream.Seek(-1, System.IO.SeekOrigin.Current); - Assert.AreEqual((int)'o', stream.ReadByte()); + Assert.That(stream.ReadByte(), Is.EqualTo((int)'o')); manager.MaxTransferBlockSize = 3; stream.Write(buf, 0, buf.Length); stream.Seek(-5, System.IO.SeekOrigin.End); var buf3 = new byte[100]; - Assert.AreEqual(5, stream.Read(buf3, 0, 100)); + Assert.That(stream.Read(buf3, 0, 100), Is.EqualTo(5)); Assert.That(buf.SequenceEqual(buf3.Take(5))); stream.SetLength(43); - Assert.AreEqual(43, stream.Length); + Assert.That(stream.Length, Is.EqualTo(43)); } manager.Unlink(oid); diff --git a/test/Npgsql.Tests/LoggingTests.cs b/test/Npgsql.Tests/LoggingTests.cs new file mode 100644 index 0000000000..0d5d0ee10d --- /dev/null +++ b/test/Npgsql.Tests/LoggingTests.cs @@ -0,0 +1,282 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using NpgsqlTypes; +using NUnit.Framework; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; + +public class LoggingTests : TestBase +{ + [Test] + public async Task Command_ExecuteScalar_single_statement_without_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); 
+ await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains("SELECT 1")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT 1"); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Command_ExecuteScalar_single_statement_with_positional_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1, $2", conn); + cmd.Parameters.Add(new() { Value = 8 }); + cmd.Parameters.Add(new() { NpgsqlDbType = NpgsqlDbType.Integer, Value = DBNull.Value }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed") + .And.Contains("SELECT $1, $2") + .And.Contains("Parameters: [8, NULL]")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); + AssertLoggingStateContains(executingCommandEvent, "Parameters", new object[] { 8, "NULL" }); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Command_ExecuteScalar_single_statement__Should_unwrap_array_and_truncate_and_write_nulls() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn 
= await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1, $2, $3, $4, $5, $6", conn); + cmd.Parameters.Add(new NpgsqlParameter { TypedValue = 1024 }); + cmd.Parameters.Add(new NpgsqlParameter { TypedValue = [1, 2, 3], NpgsqlDbType = NpgsqlDbType.Array | NpgsqlDbType.Integer }); + cmd.Parameters.Add(new NpgsqlParameter { TypedValue = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], NpgsqlDbType = NpgsqlDbType.Array | NpgsqlDbType.Integer }); + cmd.Parameters.Add(new NpgsqlParameter { TypedValue = [1, null], NpgsqlDbType = NpgsqlDbType.Array | NpgsqlDbType.Integer }); + cmd.Parameters.Add(new NpgsqlParameter { TypedValue = null }); + cmd.Parameters.Add(new() { NpgsqlDbType = NpgsqlDbType.Integer, Value = DBNull.Value }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed") + .And.Contains("SELECT $1, $2, $3, $4, $5, $6") + .And.Contains("Parameters: [1024, [1, 2, 3], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...], [1, NULL], NULL, NULL]")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2, $3, $4, $5, $6"); + AssertLoggingStateContains(executingCommandEvent, "Parameters", new object[] { 1024, "[1, 2, 3]", "[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, ...]", "[1, NULL]", "NULL", "NULL" }); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Command_ExecuteScalar_single_statement_with_named_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1, @p2", conn); + cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); + cmd.Parameters.Add(new() { ParameterName = "p2", 
NpgsqlDbType = NpgsqlDbType.Integer, Value = DBNull.Value }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed") + .And.Contains("SELECT $1, $2") + .And.Contains("Parameters: [8, NULL]")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); + AssertLoggingStateContains(executingCommandEvent, "Parameters", new object[] { 8, "NULL" }); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Command_ExecuteScalar_single_statement_with_parameter_logging_off() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, sensitiveDataLoggingEnabled: false); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT $1, $2", conn); + cmd.Parameters.Add(new() { Value = 8 }); + cmd.Parameters.Add(new() { Value = 9 }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains($"SELECT $1, $2")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT $1, $2"); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + } + + [Test] + public async Task Command_ExecuteScalar_multiple_statement_without_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT 1; SELECT 2", conn); + + using (listLoggerProvider.Record()) + { + await 
cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT 1, []), (SELECT 2, [])]")); + var batchCommands = (IList<(string CommandText, IEnumerable Parameters)>)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT 1")); + Assert.That(batchCommands[0].Parameters, Is.Empty); + Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT 2")); + Assert.That(batchCommands[1].Parameters, Is.Empty); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Command_ExecuteScalar_multiple_statement_with_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1; SELECT @p2", conn); + cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); + cmd.Parameters.Add(new() { ParameterName = "p2", Value = 9 }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT $1, [8]), (SELECT $1, [9])]")); + var batchCommands = (IList<(string CommandText, IEnumerable Parameters)>)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[0].Parameters.First(), 
Is.EqualTo(8)); + Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[1].Parameters.First(), Is.EqualTo(9)); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Command_ExecuteScalar_multiple_statement_with_parameter_logging_off() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, sensitiveDataLoggingEnabled: false); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT @p1; SELECT @p2", conn); + cmd.Parameters.Add(new() { ParameterName = "p1", Value = 8 }); + cmd.Parameters.Add(new() { ParameterName = "p2", Value = 9 }); + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[SELECT $1, SELECT $1]")); + var batchCommands = (IList)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0], Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[1], Is.EqualTo("SELECT $1")); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Batch_ExecuteScalar_single_statement_without_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlBatch(conn) + { + BatchCommands = { new("SELECT 1") } + }; + + using (listLoggerProvider.Record()) + { + await cmd.ExecuteScalarAsync(); + } + + var 
executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + + Assert.That(executingCommandEvent.Message, Does.Contain("Command execution completed").And.Contains("SELECT 1")); + AssertLoggingStateContains(executingCommandEvent, "CommandText", "SELECT 1"); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + } + + [Test] + public async Task Batch_ExecuteScalar_multiple_statements_with_parameters() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("SELECT $1") { Parameters = { new() { Value = 8 } } }, + new("SELECT $1, 9") { Parameters = { new() { Value = 9 } } } + } + }; + + using (listLoggerProvider.Record()) + { + await batch.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[(SELECT $1, [8]), (SELECT $1, 9, [9])]")); + AssertLoggingStateDoesNotContain(executingCommandEvent, "CommandText"); + AssertLoggingStateDoesNotContain(executingCommandEvent, "Parameters"); + AssertLoggingStateContains(executingCommandEvent, "ConnectorId", conn.ProcessID); + + var batchCommands = (IList<(string CommandText, IEnumerable Parameters)>)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0].CommandText, Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[0].Parameters.First(), Is.EqualTo(8)); + Assert.That(batchCommands[1].CommandText, Is.EqualTo("SELECT $1, 9")); + Assert.That(batchCommands[1].Parameters.First(), Is.EqualTo(9)); + } + + [Test] + public async 
Task Batch_ExecuteScalar_single_statement_with_parameter_logging_off() + { + await using var dataSource = CreateLoggingDataSource(out var listLoggerProvider, sensitiveDataLoggingEnabled: false); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var batch = new NpgsqlBatch(conn) + { + BatchCommands = + { + new("SELECT $1") { Parameters = { new() { Value = 8 } } }, + new("SELECT $1, 9") { Parameters = { new() { Value = 9 } } } + } + }; + + using (listLoggerProvider.Record()) + { + await batch.ExecuteScalarAsync(); + } + + var executingCommandEvent = listLoggerProvider.Log.Single(l => l.Id == NpgsqlEventId.CommandExecutionCompleted); + Assert.That(executingCommandEvent.Message, Does.Contain("Batch execution completed").And.Contains("[SELECT $1, SELECT $1, 9]")); + var batchCommands = (IList)AssertLoggingStateContains(executingCommandEvent, "BatchCommands"); + Assert.That(batchCommands.Count, Is.EqualTo(2)); + Assert.That(batchCommands[0], Is.EqualTo("SELECT $1")); + Assert.That(batchCommands[1], Is.EqualTo("SELECT $1, 9")); + } +} diff --git a/test/Npgsql.Tests/MetricTests.cs b/test/Npgsql.Tests/MetricTests.cs new file mode 100644 index 0000000000..9a8b2757e3 --- /dev/null +++ b/test/Npgsql.Tests/MetricTests.cs @@ -0,0 +1,186 @@ +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using NUnit.Framework; +using OpenTelemetry; +using OpenTelemetry.Metrics; + +namespace Npgsql.Tests; + +public class MetricTests : TestBase +{ + [Test] + public async Task OperationDuration() + { + var exportedItems = new List(); + using var meterProvider = Sdk.CreateMeterProviderBuilder() + .AddMeter("Npgsql") + .AddInMemoryExporter(exportedItems) + .Build(); + + await using var dataSource = CreateDataSource(); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1"; + await using (var reader 
= await cmd.ExecuteReaderAsync()) + while (await reader.ReadAsync()); + + meterProvider.ForceFlush(); + + var metric = exportedItems.SingleOrDefault(m => m.Name == "db.client.operation.duration"); + Assert.That(metric, Is.Not.Null, "Metric 'db.client.operation.duration' not found."); + + var point = GetFilteredPoints(metric.GetMetricPoints(), dataSource.Name).Single(); + + Assert.That(point.GetHistogramSum(), Is.GreaterThan(0)); + Assert.That(point.GetHistogramCount(), Is.EqualTo(1)); + + var tags = ToDictionary(point.Tags); + + using (Assert.EnterMultipleScope()) + { + // TODO: Vary this for PG-like databases (e.g. CockroachDB)? + Assert.That(tags["db.system.name"], Is.EqualTo("postgresql")); + + Assert.That(tags["server.address"], Is.EqualTo(dataSource.Settings.Host)); + Assert.That(tags["server.port"], Is.EqualTo(dataSource.Settings.Port)); + Assert.That(tags["db.client.connection.pool.name"], Is.EqualTo(dataSource.Name)); + } + } + + [Test] + public async Task ConnectionCount() + { + var exportedItems = new List(); + using var meterProvider = Sdk.CreateMeterProviderBuilder() + .AddMeter("Npgsql") + .AddInMemoryExporter(exportedItems) + .Build(); + + await using var dataSource = CreateDataSource(); + + using (var _ = await dataSource.OpenConnectionAsync()) + { + meterProvider.ForceFlush(); + + var metric = exportedItems.Single(m => m.Name == "db.client.connection.count"); + var points = GetFilteredPoints(metric.GetMetricPoints(), dataSource.Name); + + var usedPoint = GetPoint(points, "used"); + Assert.That(usedPoint.GetSumLong(), Is.EqualTo(1), "Expected used connections to be 1"); + + var idlePoint = GetPoint(points, "idle"); + Assert.That(idlePoint.GetSumLong(), Is.Zero, "Expected idle connections to be 0"); + + exportedItems.Clear(); + } + + meterProvider.ForceFlush(); + + { + var metric = exportedItems.Single(m => m.Name == "db.client.connection.count"); + var points = GetFilteredPoints(metric.GetMetricPoints(), dataSource.Name); + + var usedPoint = 
GetPoint(points, "used"); + Assert.That(usedPoint.GetSumLong(), Is.Zero, "Expected used connections to be 0"); + + var idlePoint = GetPoint(points, "idle"); + Assert.That(idlePoint.GetSumLong(), Is.EqualTo(1), "Expected idle connections to be 1"); + } + + static MetricPoint GetPoint(IEnumerable points, string state) + { + foreach (var point in points) + { + foreach (var tag in point.Tags) + { + if (tag.Key == "db.client.connection.state" && (string?)tag.Value == state) + return point; + } + } + + Assert.Fail($"Point with state '{state}' not found"); + throw new UnreachableException(); + } + } + + [Test] + public async Task ConnectionMax() + { + var exportedItems = new List(); + using var meterProvider = Sdk.CreateMeterProviderBuilder() + .AddMeter("Npgsql") + .AddInMemoryExporter(exportedItems) + .Build(); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.MaxPoolSize = 134; + await using var dataSource = dataSourceBuilder.Build(); + + meterProvider.ForceFlush(); + + var metric = exportedItems.Single(m => m.Name == "db.client.connection.max"); + var point = GetFilteredPoints(metric.GetMetricPoints(), dataSource.Name).First(p => p.GetSumLong() == 134); + var tags = ToDictionary(point.Tags); + Assert.That(tags["db.client.connection.pool.name"], Is.EqualTo(dataSource.Name)); + } + + [Test] + public async Task Password_does_not_leak_via_datasource_name([Values] bool persistSecurityInfo) + { + var exportedItems = new List(); + using var meterProvider = Sdk.CreateMeterProviderBuilder() + .AddMeter("Npgsql") + .AddInMemoryExporter(exportedItems) + .Build(); + + var dataSourceBuilder = base.CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.ApplicationName = "MetricsDataSource" + Interlocked.Increment(ref _dataSourceCounter); + dataSourceBuilder.ConnectionStringBuilder.PersistSecurityInfo = persistSecurityInfo; + // Do not set the data source name - this makes it default to the connection string, but 
without + // the password (even when Persist Security Info is true) + await using var dataSource = dataSourceBuilder.Build(); + + meterProvider.ForceFlush(); + + var metric = exportedItems.Single(m => m.Name == "db.client.connection.max"); + var point = GetFilteredPoints(metric.GetMetricPoints(), dataSource.Name).First(); + var tags = ToDictionary(point.Tags); + var connectionString = new NpgsqlConnectionStringBuilder((string)tags["db.client.connection.pool.name"]!); + Assert.That(connectionString.Password, Is.Null); + } + + static Dictionary ToDictionary(ReadOnlyTagCollection tags) + { + var dict = new Dictionary(); + foreach (var tag in tags) + dict[tag.Key] = tag.Value; + return dict; + } + + protected override NpgsqlDataSourceBuilder CreateDataSourceBuilder() + { + var dataSourceBuilder = base.CreateDataSourceBuilder(); + dataSourceBuilder.Name = "MetricsDataSource" + Interlocked.Increment(ref _dataSourceCounter); + return dataSourceBuilder; + } + + protected override NpgsqlDataSource CreateDataSource() + => CreateDataSourceBuilder().Build(); + + int _dataSourceCounter; + + static IEnumerable GetFilteredPoints(MetricPointsAccessor points, string dataSourceName) + { + foreach (var point in points) + { + foreach (var tag in point.Tags) + { + if (tag.Key == "db.client.connection.pool.name" && (string?)tag.Value == dataSourceName) + yield return point; + } + } + } +} diff --git a/test/Npgsql.Tests/MultipleHostsTests.cs b/test/Npgsql.Tests/MultipleHostsTests.cs index f4f2dfffb7..94ba45dd00 100644 --- a/test/Npgsql.Tests/MultipleHostsTests.cs +++ b/test/Npgsql.Tests/MultipleHostsTests.cs @@ -8,11 +8,9 @@ using System.Linq; using System.Net; using System.Net.Sockets; -using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using System.Transactions; -using Npgsql.Properties; using static Npgsql.Tests.Support.MockState; using static Npgsql.Tests.TestUtil; using IsolationLevel = System.Transactions.IsolationLevel; @@ -20,10 +18,12 @@ 
namespace Npgsql.Tests; +#pragma warning disable CS0618 + public class MultipleHostsTests : TestBase { static readonly object[] MyCases = - { + [ new object[] { TargetSessionAttributes.Standby, new[] { Primary, Standby }, 1 }, new object[] { TargetSessionAttributes.Standby, new[] { PrimaryReadOnly, Standby }, 1 }, new object[] { TargetSessionAttributes.PreferStandby, new[] { Primary, Standby }, 1 }, @@ -41,7 +41,7 @@ public class MultipleHostsTests : TestBase new object[] { TargetSessionAttributes.ReadWrite, new[] { PrimaryReadOnly, Primary }, 1 }, new object[] { TargetSessionAttributes.ReadOnly, new[] { Primary, Standby }, 1 }, new object[] { TargetSessionAttributes.ReadOnly, new[] { PrimaryReadOnly, Standby }, 0 } - }; + ]; [Test] [TestCaseSource(nameof(MyCases))] @@ -91,6 +91,55 @@ public async Task Connect_to_correct_host_unpooled(TargetSessionAttributes targe _ = await postmasters[i].WaitForServerConnection(); } + [Test] + [TestCaseSource(nameof(MyCases))] + public async Task Connect_to_correct_host_legacy(TargetSessionAttributes targetSessionAttributes, MockState[] servers, int expectedServer) + { + var postmasters = servers.Select(s => PgPostmasterMock.Start(state: s)).ToArray(); + await using var __ = new DisposableWrapper(postmasters); + + var connectionStringBuilder = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(postmasters), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + TargetSessionAttributes = TargetSessionAttributesAsString(targetSessionAttributes) + }; + + using var pool = CreateTempPool(connectionStringBuilder, out var connectionString); + await using var conn = new NpgsqlConnection(connectionString); + await conn.OpenAsync(); + + Assert.That(conn.Port, Is.EqualTo(postmasters[expectedServer].Port)); + + for (var i = 0; i <= expectedServer; i++) + _ = await postmasters[i].WaitForServerConnection(); + } + + [Test] + [TestCaseSource(nameof(MyCases))] + public async Task 
Connect_to_correct_host_connection_string(TargetSessionAttributes targetSessionAttributes, MockState[] servers, int expectedServer) + { + var postmasters = servers.Select(s => PgPostmasterMock.Start(state: s)).ToArray(); + await using var __ = new DisposableWrapper(postmasters); + + var connectionStringBuilder = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(postmasters), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + TargetSessionAttributes = TargetSessionAttributesAsString(targetSessionAttributes) + }; + + await using var dataSource = new NpgsqlDataSourceBuilder(connectionStringBuilder.ConnectionString) + .Build(); + Assert.That(dataSource, Is.TypeOf()); + await using var conn = await dataSource.OpenConnectionAsync(); + + Assert.That(conn.Port, Is.EqualTo(postmasters[expectedServer].Port)); + + for (var i = 0; i <= expectedServer; i++) + _ = await postmasters[i].WaitForServerConnection(); + } + [Test] [TestCaseSource(nameof(MyCases))] public async Task Connect_to_correct_host_with_available_idle( @@ -131,6 +180,40 @@ public async Task Connect_to_correct_host_with_available_idle( _ = await postmasters[i].WaitForServerConnection(); } + [Test] + public async Task Legacy_connection_shares_datasource() + { + await using var primaryPostmaster = PgPostmasterMock.Start(state: Primary); + await using var standbyPostmaster = PgPostmasterMock.Start(state: Standby); + + var builder1 = new NpgsqlConnectionStringBuilder + { + Host = MultipleHosts(primaryPostmaster, standbyPostmaster), + ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, + TargetSessionAttributes = "Prefer-Primary" + }; + + // Use the exact same pool for both connections as CreateTempPool adds a unique `ApplicationName` to connection string + using var pool = CreateTempPool(builder1, out var connectionString1); + var connectionString2 = new NpgsqlConnectionStringBuilder(connectionString1) + { + TargetSessionAttributes = "Prefer-Standby" + }.ConnectionString; + + 
await using var conn1 = new NpgsqlConnection(connectionString1); + await conn1.OpenAsync(); + Assert.That(conn1.Port, Is.EqualTo(primaryPostmaster.Port)); + + await using var conn2 = new NpgsqlConnection(connectionString2); + await conn2.OpenAsync(); + Assert.That(conn2.Port, Is.EqualTo(standbyPostmaster.Port)); + + Assert.That(conn1.NpgsqlDataSource, Is.Not.SameAs(conn2.NpgsqlDataSource)); + Assert.That(conn1.NpgsqlDataSource, Is.TypeOf()); + Assert.That(conn2.NpgsqlDataSource, Is.TypeOf()); + Assert.That(((MultiHostDataSourceWrapper)conn1.NpgsqlDataSource).WrappedSource, Is.SameAs(((MultiHostDataSourceWrapper)conn2.NpgsqlDataSource).WrappedSource)); + } + [Test] [TestCase(TargetSessionAttributes.Standby, new[] { Primary, Primary })] [TestCase(TargetSessionAttributes.Primary, new[] { Standby, Standby })] @@ -213,7 +296,6 @@ public async Task All_hosts_are_unavailable( } [Test] - [Platform(Exclude = "MacOsX", Reason = "Flaky in CI on Mac")] public async Task First_host_is_down() { using var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); @@ -253,7 +335,7 @@ public async Task TargetSessionAttributes_with_single_host(string targetSessionA if (targetSessionAttributes == "any") { - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); + await using var postmasterMock = PgPostmasterMock.Start(connectionString); using var pool = CreateTempPool(postmasterMock.ConnectionString, out connectionString); await using var conn = new NpgsqlConnection(connectionString); await conn.OpenAsync(); @@ -322,7 +404,7 @@ public void HostRecheckSeconds_zero_value() [Test] public void HostRecheckSeconds_invalid_throws() - => Assert.Throws(() => + => Assert.Throws(() => new NpgsqlConnectionStringBuilder { HostRecheckSeconds = -1 @@ -358,21 +440,21 @@ public async Task Connect_with_load_balancing() secondConnector = secondConnection.Connector!; } - Assert.AreNotSame(firstConnector, secondConnector); + Assert.That(secondConnector, 
Is.Not.SameAs(firstConnector)); await using (var firstBalancedConnection = await dataSource.OpenConnectionAsync()) { - Assert.AreSame(firstConnector, firstBalancedConnection.Connector); + Assert.That(firstBalancedConnection.Connector, Is.SameAs(firstConnector)); } await using (var secondBalancedConnection = await dataSource.OpenConnectionAsync()) { - Assert.AreSame(secondConnector, secondBalancedConnection.Connector); + Assert.That(secondBalancedConnection.Connector, Is.SameAs(secondConnector)); } await using (var thirdBalancedConnection = await dataSource.OpenConnectionAsync()) { - Assert.AreSame(firstConnector, thirdBalancedConnection.Connector); + Assert.That(thirdBalancedConnection.Connector, Is.SameAs(firstConnector)); } } @@ -402,7 +484,7 @@ public async Task Connect_without_load_balancing() } await using (var secondConnection = await dataSource.OpenConnectionAsync()) { - Assert.AreSame(firstConnector, secondConnection.Connector); + Assert.That(secondConnection.Connector, Is.SameAs(firstConnector)); } await using (var firstConnection = await dataSource.OpenConnectionAsync()) await using (var secondConnection = await dataSource.OpenConnectionAsync()) @@ -410,16 +492,16 @@ public async Task Connect_without_load_balancing() secondConnector = secondConnection.Connector!; } - Assert.AreNotSame(firstConnector, secondConnector); + Assert.That(secondConnector, Is.Not.SameAs(firstConnector)); await using (var firstUnbalancedConnection = await dataSource.OpenConnectionAsync()) { - Assert.AreSame(firstConnector, firstUnbalancedConnection.Connector); + Assert.That(firstUnbalancedConnection.Connector, Is.SameAs(firstConnector)); } await using (var secondUnbalancedConnection = await dataSource.OpenConnectionAsync()) { - Assert.AreSame(firstConnector, secondUnbalancedConnection.Connector); + Assert.That(secondUnbalancedConnection.Connector, Is.SameAs(firstConnector)); } } @@ -480,7 +562,7 @@ public async Task Connect_state_changing_hosts([Values] bool alwaysCheckHostStat } 
await using var thirdConnection = await dataSource.OpenConnectionAsync(TargetSessionAttributes.PreferPrimary); - Assert.AreSame(alwaysCheckHostState ? secondConnector : firstConnector, thirdConnection.Connector); + Assert.That(thirdConnection.Connector, Is.SameAs(alwaysCheckHostState ? secondConnector : firstConnector)); await firstServerTask; await secondServerTask; @@ -493,22 +575,22 @@ public void Database_state_cache_basic() var timeStamp = DateTime.UtcNow; dataSource.UpdateDatabaseState(DatabaseState.PrimaryReadWrite, timeStamp, TimeSpan.Zero); - Assert.AreEqual(DatabaseState.PrimaryReadWrite, dataSource.GetDatabaseState()); + Assert.That(dataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.PrimaryReadWrite)); // Update with the same timestamp - shouldn't change anything dataSource.UpdateDatabaseState(DatabaseState.Standby, timeStamp, TimeSpan.Zero); - Assert.AreEqual(DatabaseState.PrimaryReadWrite, dataSource.GetDatabaseState()); + Assert.That(dataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.PrimaryReadWrite)); // Update with a new timestamp timeStamp = timeStamp.AddSeconds(1); dataSource.UpdateDatabaseState(DatabaseState.PrimaryReadOnly, timeStamp, TimeSpan.Zero); - Assert.AreEqual(DatabaseState.PrimaryReadOnly, dataSource.GetDatabaseState()); + Assert.That(dataSource.GetDatabaseState(), Is.EqualTo(DatabaseState.PrimaryReadOnly)); // Expired state returns as Unknown (depending on ignoreExpiration) timeStamp = timeStamp.AddSeconds(1); dataSource.UpdateDatabaseState(DatabaseState.PrimaryReadWrite, timeStamp, TimeSpan.FromSeconds(-1)); - Assert.AreEqual(DatabaseState.Unknown, dataSource.GetDatabaseState(ignoreExpiration: false)); - Assert.AreEqual(DatabaseState.PrimaryReadWrite, dataSource.GetDatabaseState(ignoreExpiration: true)); + Assert.That(dataSource.GetDatabaseState(ignoreExpiration: false), Is.EqualTo(DatabaseState.Unknown)); + Assert.That(dataSource.GetDatabaseState(ignoreExpiration: true), Is.EqualTo(DatabaseState.PrimaryReadWrite)); } 
[Test] @@ -615,10 +697,11 @@ public async Task Offline_state_on_query_execution_IOException() public async Task Offline_state_on_query_execution_TimeoutException() { await using var postmaster = PgPostmasterMock.Start(ConnectionString); - var dataSourceBuilder = postmaster.GetDataSourceBuilder(); - dataSourceBuilder.ConnectionStringBuilder.CommandTimeout = 1; - dataSourceBuilder.ConnectionStringBuilder.CancellationTimeout = 1; - await using var dataSource = dataSourceBuilder.Build(); + await using var dataSource = postmaster.CreateDataSource(builder => + { + builder.ConnectionStringBuilder.CommandTimeout = 1; + builder.ConnectionStringBuilder.CancellationTimeout = 1; + }); await using var conn = await dataSource.OpenConnectionAsync(); await using var anotherConn = await dataSource.OpenConnectionAsync(); @@ -641,10 +724,11 @@ public async Task Offline_state_on_query_execution_TimeoutException() public async Task Unknown_state_on_query_execution_TimeoutException_with_disabled_cancellation() { await using var postmaster = PgPostmasterMock.Start(ConnectionString); - var dataSourceBuilder = postmaster.GetDataSourceBuilder(); - dataSourceBuilder.ConnectionStringBuilder.CommandTimeout = 1; - dataSourceBuilder.ConnectionStringBuilder.CancellationTimeout = -1; - await using var dataSource = dataSourceBuilder.Build(); + await using var dataSource = postmaster.CreateDataSource(builder => + { + builder.ConnectionStringBuilder.CommandTimeout = 1; + builder.ConnectionStringBuilder.CancellationTimeout = -1; + }); await using var conn = await dataSource.OpenConnectionAsync(); await using var anotherConn = await dataSource.OpenConnectionAsync(); @@ -667,10 +751,11 @@ public async Task Unknown_state_on_query_execution_TimeoutException_with_disable public async Task Unknown_state_on_query_execution_cancellation_with_disabled_cancellation_timeout() { await using var postmaster = PgPostmasterMock.Start(ConnectionString); - var dataSourceBuilder = postmaster.GetDataSourceBuilder(); - 
dataSourceBuilder.ConnectionStringBuilder.CommandTimeout = 30; - dataSourceBuilder.ConnectionStringBuilder.CancellationTimeout = -1; - await using var dataSource = dataSourceBuilder.Build(); + await using var dataSource = postmaster.CreateDataSource(builder => + { + builder.ConnectionStringBuilder.CommandTimeout = 30; + builder.ConnectionStringBuilder.CancellationTimeout = -1; + }); await using var conn = await dataSource.OpenConnectionAsync(); await using var anotherConn = await dataSource.OpenConnectionAsync(); @@ -697,10 +782,11 @@ public async Task Unknown_state_on_query_execution_cancellation_with_disabled_ca public async Task Unknown_state_on_query_execution_TimeoutException_with_cancellation_failure() { await using var postmaster = PgPostmasterMock.Start(ConnectionString); - var dataSourceBuilder = postmaster.GetDataSourceBuilder(); - dataSourceBuilder.ConnectionStringBuilder.CommandTimeout = 1; - dataSourceBuilder.ConnectionStringBuilder.CancellationTimeout = 0; - await using var dataSource = dataSourceBuilder.Build(); + await using var dataSource = postmaster.CreateDataSource(builder => + { + builder.ConnectionStringBuilder.CommandTimeout = 1; + builder.ConnectionStringBuilder.CancellationTimeout = 0; + }); await using var conn = await dataSource.OpenConnectionAsync(); @@ -787,6 +873,8 @@ public async Task Transaction_enlist_reuses_connection(string targetSessionAttri TargetSessionAttributes = targetSessionAttributes, ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading, MaxPoolSize = 10, + // Our mock PG server doesn't know how to handle the reset messages + NoResetOnClose = true, }; using var _ = CreateTempPool(csb, out var connString); @@ -920,7 +1008,7 @@ public void IntegrationTest([Values] bool loadBalancing, [Values] bool alwaysChe Assert.DoesNotThrowAsync(() => clientsTask); Assert.ThrowsAsync(() => onlyStandbyClient); Assert.ThrowsAsync(() => readOnlyClient); - Assert.AreEqual(125, queriesDone); + Assert.That(queriesDone, 
Is.EqualTo(125)); Task Client(NpgsqlMultiHostDataSource multiHostDataSource, TargetSessionAttributes targetSessionAttributes) { @@ -1018,15 +1106,6 @@ public async Task DataSource_without_wrappers() Assert.That(standbyConnection.Port, Is.EqualTo(standbyPostmasterMock.Port)); } - [Test] - public void DataSource_with_TargetSessionAttributes_is_not_supported() - { - var builder = new NpgsqlDataSourceBuilder("Host=foo,bar;Target Session Attributes=primary"); - - Assert.That(() => builder.BuildMultiHost(), Throws.Exception.TypeOf() - .With.Message.EqualTo(NpgsqlStrings.CannotSpecifyTargetSessionAttributes)); - } - [Test] public async Task BuildMultiHost_with_single_host_is_supported() { @@ -1055,6 +1134,20 @@ public async Task Build_with_multiple_hosts_is_supported() await using var connection = await dataSource.OpenConnectionAsync(); } + [Test] + public async Task OpenConnection_when_canceled_throws_TaskCanceledException() + { + var builder = new NpgsqlDataSourceBuilder(ConnectionString); + await using var dataSource = builder.BuildMultiHost(); + using var cts = new CancellationTokenSource(); + cts.Cancel(); + var ex = Assert.ThrowsAsync(async () => + { + await using var connection = await dataSource.OpenConnectionAsync(cts.Token); + }); + Assert.That(ex.CancellationToken, Is.EqualTo(cts.Token)); + } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4181")] [Explicit("Fails until #4181 is fixed.")] public async Task LoadBalancing_is_fair_if_first_host_is_down([Values]TargetSessionAttributes targetSessionAttributes) @@ -1166,15 +1259,24 @@ public async Task LoadBalancing_is_fair_if_first_host_is_down([Values]TargetSess static string MultipleHosts(params PgPostmasterMock[] postmasters) => string.Join(",", postmasters.Select(p => $"{p.Host}:{p.Port}")); - class DisposableWrapper : IAsyncDisposable - { - readonly IEnumerable _disposables; - - public DisposableWrapper(IEnumerable disposables) => _disposables = disposables; + static string? 
TargetSessionAttributesAsString(TargetSessionAttributes targetSessionAttributes) + => targetSessionAttributes switch + { + TargetSessionAttributes.Any => "Any", + TargetSessionAttributes.Primary => "Primary", + TargetSessionAttributes.Standby => "Standby", + TargetSessionAttributes.PreferPrimary => "Prefer-Primary", + TargetSessionAttributes.PreferStandby => "Prefer-Standby", + TargetSessionAttributes.ReadOnly => "Read-Only", + TargetSessionAttributes.ReadWrite => "Read-Write", + _ => null + }; + sealed class DisposableWrapper(IEnumerable disposables) : IAsyncDisposable + { public async ValueTask DisposeAsync() { - foreach (var disposable in _disposables) + foreach (var disposable in disposables) await disposable.DisposeAsync(); } } diff --git a/test/Npgsql.Tests/NestedDataReaderTests.cs b/test/Npgsql.Tests/NestedDataReaderTests.cs index 72553a6b5e..7e157c3426 100644 --- a/test/Npgsql.Tests/NestedDataReaderTests.cs +++ b/test/Npgsql.Tests/NestedDataReaderTests.cs @@ -199,15 +199,15 @@ public void GetBytes() Assert.That(nestedReader.GetBytes(0, 0, null, 0, 4), Is.EqualTo(3)); Assert.That(nestedReader.GetBytes(0, 0, buf, 0, 3), Is.EqualTo(3)); Assert.That(nestedReader.GetBytes(0, 0, buf, 0, 4), Is.EqualTo(3)); - CollectionAssert.AreEqual(new byte[] { 1, 2, 3, 0 }, buf); + Assert.That(buf, Is.EqualTo(new byte[] { 1, 2, 3, 0 }).AsCollection); buf = new byte[2]; Assert.That(nestedReader.GetBytes(0, 0, buf, 0, 2), Is.EqualTo(2)); - CollectionAssert.AreEqual(new byte[] { 1, 2 }, buf); + Assert.That(buf, Is.EqualTo(new byte[] { 1, 2 }).AsCollection); buf = new byte[2]; Assert.That(nestedReader.GetBytes(0, 1, buf, 1, 1), Is.EqualTo(1)); - CollectionAssert.AreEqual(new byte[] { 0, 2 }, buf); + Assert.That(buf, Is.EqualTo(new byte[] { 0, 2 }).AsCollection); Assert.That(nestedReader.GetBytes(0, 2, buf, 1, 1), Is.EqualTo(1)); - CollectionAssert.AreEqual(new byte[] { 0, 3 }, buf); + Assert.That(buf, Is.EqualTo(new byte[] { 0, 3 }).AsCollection); Assert.Throws(() => 
nestedReader.GetBytes(1, 0, buf, 0, 1)); Assert.Throws(() => nestedReader.GetBytes(0, 4, buf, 0, 1)); } diff --git a/test/Npgsql.Tests/NotificationTests.cs b/test/Npgsql.Tests/NotificationTests.cs index 9df9aba44d..5f6c11efcd 100644 --- a/test/Npgsql.Tests/NotificationTests.cs +++ b/test/Npgsql.Tests/NotificationTests.cs @@ -19,7 +19,7 @@ public void Notification() conn.ExecuteNonQuery($"LISTEN {notify}"); conn.Notification += (o, e) => receivedNotification = true; conn.ExecuteNonQuery($"NOTIFY {notify}"); - Assert.IsTrue(receivedNotification); + Assert.That(receivedNotification); } [Test, Description("Generates a notification that arrives after reader data that is already being read")] @@ -53,12 +53,12 @@ public async Task Notification_after_data() // Allow some time for the notification to get delivered await Task.Delay(2000); - Assert.IsTrue(reader.Read()); - Assert.AreEqual(1, reader.GetValue(0)); + Assert.That(reader.Read()); + Assert.That(reader.GetValue(0), Is.EqualTo(1)); } Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); - Assert.IsTrue(receivedNotification); + Assert.That(receivedNotification); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1024")] @@ -73,7 +73,7 @@ public void Wait() notifyingConn.ExecuteNonQuery($"NOTIFY {notify}"); conn.Notification += (o, e) => receivedNotification = true; Assert.That(conn.Wait(0), Is.EqualTo(true)); - Assert.IsTrue(receivedNotification); + Assert.That(receivedNotification); Assert.That(conn.ExecuteScalar("SELECT 1"), Is.EqualTo(1)); } @@ -106,7 +106,7 @@ public async Task WaitAsync() await notifyingConn.ExecuteNonQueryAsync($"NOTIFY {notify}"); conn.Notification += (o, e) => receivedNotification = true; await conn.WaitAsync(0); - Assert.IsTrue(receivedNotification); + Assert.That(receivedNotification); Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); } diff --git a/test/Npgsql.Tests/Npgsql.Tests.csproj b/test/Npgsql.Tests/Npgsql.Tests.csproj index 980b51d8aa..3714b9edaa 
100644 --- a/test/Npgsql.Tests/Npgsql.Tests.csproj +++ b/test/Npgsql.Tests/Npgsql.Tests.csproj @@ -3,14 +3,27 @@ - + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + PreserveNewest + + true + $(NoWarn);NPG9001 + $(NoWarn);NPG9002 + $(NoWarn);NPG9003 diff --git a/test/Npgsql.Tests/NpgsqlEventSourceTests.cs b/test/Npgsql.Tests/NpgsqlEventSourceTests.cs index c1659e6fba..1da8c0745d 100644 --- a/test/Npgsql.Tests/NpgsqlEventSourceTests.cs +++ b/test/Npgsql.Tests/NpgsqlEventSourceTests.cs @@ -44,12 +44,10 @@ public void DisableEventSource() TestEventListener _listener = null!; - readonly List _events = new(); + readonly List _events = []; - class TestEventListener : EventListener + class TestEventListener(List events) : EventListener { - readonly List _events; - public TestEventListener(List events) => _events = events; - protected override void OnEventWritten(EventWrittenEventArgs eventData) => _events.Add(eventData); + protected override void OnEventWritten(EventWrittenEventArgs eventData) => events.Add(eventData); } } diff --git a/test/Npgsql.Tests/NpgsqlParameterCollectionTests.cs b/test/Npgsql.Tests/NpgsqlParameterCollectionTests.cs index 6c09b7b708..901e34ece9 100644 --- a/test/Npgsql.Tests/NpgsqlParameterCollectionTests.cs +++ b/test/Npgsql.Tests/NpgsqlParameterCollectionTests.cs @@ -4,6 +4,7 @@ using System.Data; using System.Data.Common; using System.Diagnostics.CodeAnalysis; +using System.Linq; namespace Npgsql.Tests; @@ -36,13 +37,13 @@ public void Clear() var c1 = new NpgsqlCommand(); var c2 = new NpgsqlCommand(); c1.Parameters.Add(p); - Assert.AreEqual(1, c1.Parameters.Count); - Assert.AreEqual(0, c2.Parameters.Count); + Assert.That(c1.Parameters.Count, Is.EqualTo(1)); + Assert.That(c2.Parameters.Count, Is.EqualTo(0)); c1.Parameters.Clear(); - Assert.AreEqual(0, c1.Parameters.Count); + Assert.That(c1.Parameters.Count, Is.EqualTo(0)); c2.Parameters.Add(p); - Assert.AreEqual(0, c1.Parameters.Count); - Assert.AreEqual(1, 
c2.Parameters.Count); + Assert.That(c1.Parameters.Count, Is.EqualTo(0)); + Assert.That(c2.Parameters.Count, Is.EqualTo(1)); } [Test] @@ -59,7 +60,7 @@ public void Hash_lookup_parameter_rename_bug() } // Make sure hash lookup is generated. - Assert.AreEqual(command.Parameters["p03"].ParameterName, "p03"); + Assert.That(command.Parameters["p03"].ParameterName, Is.EqualTo("p03")); // Rename the target parameter. command.Parameters["p03"].ParameterName = "a_new_name"; @@ -70,6 +71,34 @@ public void Hash_lookup_parameter_rename_bug() Assert.That(command.Parameters.IndexOf("a_new_name"), Is.GreaterThanOrEqualTo(0)); } + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/6067")] + public void Hash_lookup_unnamed_parameter_rename_bug() + { + if (_compatMode == CompatMode.TwoPass) + return; + + using var command = new NpgsqlCommand(); + + for (var i = 0; i < 3; i++) + { + // Put plenty of parameters in the collection to turn on hash lookup functionality. + for (var j = 0; j < LookupThreshold; j++) + { + // Create and add an unnamed parameter before renaming it + var parameter = command.CreateParameter(); + command.Parameters.Add(parameter); + parameter.ParameterName = $"{j}"; + } + + // Make sure hash lookup is generated. + Assert.That(command.Parameters["3"].ParameterName, Is.EqualTo("3")); + + // Remove all parameters to clear hash lookup + command.Parameters.Clear(); + } + } + [Test] public void Remove_duplicate_parameter([Values(LookupThreshold, LookupThreshold - 2)] int count) { @@ -85,7 +114,7 @@ public void Remove_duplicate_parameter([Values(LookupThreshold, LookupThreshold } // Make sure lookup is generated. - Assert.AreEqual(command.Parameters["p02"].ParameterName, "p02"); + Assert.That(command.Parameters["p02"].ParameterName, Is.EqualTo("p02")); // Add uppercased version causing a list to be created. 
command.Parameters.AddWithValue("P02", NpgsqlDbType.Text, "String parameter value 2"); @@ -94,10 +123,10 @@ public void Remove_duplicate_parameter([Values(LookupThreshold, LookupThreshold command.Parameters.Remove(command.Parameters["p02"]); // Test whether we can still find the last added parameter, and if its index is correctly shifted in the lookup. - Assert.IsTrue(command.Parameters.IndexOf("p02") == count - 1); - Assert.IsTrue(command.Parameters.IndexOf("P02") == count - 1); + Assert.That(command.Parameters.IndexOf("p02") == count - 1); + Assert.That(command.Parameters.IndexOf("P02") == count - 1); // And finally test whether other parameters were also correctly shifted. - Assert.IsTrue(command.Parameters.IndexOf("p03") == 1); + Assert.That(command.Parameters.IndexOf("p03") == 1); } [Test] @@ -115,8 +144,8 @@ public void Remove_parameter([Values(LookupThreshold, LookupThreshold - 2)] int command.Parameters.Remove(command.Parameters["p02"]); // Make sure we cannot find it, also not case insensitively. - Assert.IsTrue(command.Parameters.IndexOf("p02") == -1); - Assert.IsTrue(command.Parameters.IndexOf("P02") == -1); + Assert.That(command.Parameters.IndexOf("p02") == -1); + Assert.That(command.Parameters.IndexOf("P02") == -1); } [Test] @@ -155,7 +184,7 @@ public void Correct_index_returned_for_duplicate_ParameterName([Values(LookupThr } // Make sure lookup is generated. - Assert.AreEqual(command.Parameters["parameter02"].ParameterName, "parameter02"); + Assert.That(command.Parameters["parameter02"].ParameterName, Is.EqualTo("parameter02")); // Add uppercased version. command.Parameters.AddWithValue("Parameter02", NpgsqlDbType.Text, "String parameter value 2"); @@ -164,14 +193,14 @@ public void Correct_index_returned_for_duplicate_ParameterName([Values(LookupThr command.Parameters.Insert(0, new NpgsqlParameter("ParameteR02", NpgsqlDbType.Text) { Value = "String parameter value 2" }); // Try to find the exact index. 
- Assert.IsTrue(command.Parameters.IndexOf("parameter02") == 2); - Assert.IsTrue(command.Parameters.IndexOf("Parameter02") == command.Parameters.Count - 1); - Assert.IsTrue(command.Parameters.IndexOf("ParameteR02") == 0); + Assert.That(command.Parameters.IndexOf("parameter02") == 2); + Assert.That(command.Parameters.IndexOf("Parameter02") == command.Parameters.Count - 1); + Assert.That(command.Parameters.IndexOf("ParameteR02") == 0); // This name does not exist so we expect the first case insensitive match to be returned. - Assert.IsTrue(command.Parameters.IndexOf("ParaMeteR02") == 0); + Assert.That(command.Parameters.IndexOf("ParaMeteR02") == 0); // And finally test whether other parameters were also correctly shifted. - Assert.IsTrue(command.Parameters.IndexOf("parameter03") == 3); + Assert.That(command.Parameters.IndexOf("parameter03") == 3); } [Test] @@ -316,10 +345,22 @@ public void Clean_name() param.ParameterName = null; // These should not throw exceptions - Assert.AreEqual(0, command.Parameters.IndexOf(param.ParameterName)); - Assert.AreEqual(NpgsqlParameter.PositionalName, param.ParameterName); + Assert.That(command.Parameters.IndexOf(param.ParameterName), Is.EqualTo(0)); + Assert.That(param.ParameterName, Is.EqualTo(NpgsqlParameter.PositionalName)); + } + + [Test] + public void Clone_sets_correct_collection() + { + var cmd = new NpgsqlCommand(); + cmd.Parameters.Add(new NpgsqlParameter { TypedValue = 42 }); + Assert.That(cmd.Parameters.Single().Collection, Is.SameAs(cmd.Parameters)); + + cmd = cmd.Clone(); + Assert.That(cmd.Parameters.Single().Collection, Is.SameAs(cmd.Parameters)); } + public NpgsqlParameterCollectionTests(CompatMode compatMode) { _compatMode = compatMode; diff --git a/test/Npgsql.Tests/NpgsqlParameterTests.cs b/test/Npgsql.Tests/NpgsqlParameterTests.cs index 9a4610aadd..6070cc7266 100644 --- a/test/Npgsql.Tests/NpgsqlParameterTests.cs +++ b/test/Npgsql.Tests/NpgsqlParameterTests.cs @@ -4,7 +4,6 @@ using System.Data; using 
System.Data.Common; using System.Threading.Tasks; -using Npgsql.Internal.Postgres; namespace Npgsql.Tests; @@ -133,7 +132,7 @@ public void Setting_NpgsqlDbType_sets_DbType() [Test] public void Setting_value_does_not_change_DbType() { - var p = new NpgsqlParameter { DbType = DbType.String, NpgsqlDbType = NpgsqlDbType.Bytea }; + var p = new NpgsqlParameter { DbType = DbType.Binary, NpgsqlDbType = NpgsqlDbType.Bytea }; p.Value = 8; Assert.That(p.DbType, Is.EqualTo(DbType.Binary)); Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea)); @@ -147,17 +146,17 @@ public void Setting_value_does_not_change_DbType() public void Constructor1() { var p = new NpgsqlParameter(); - Assert.AreEqual(DbType.Object, p.DbType, "DbType"); - Assert.AreEqual(ParameterDirection.Input, p.Direction, "Direction"); - Assert.IsFalse(p.IsNullable, "IsNullable"); - Assert.AreEqual(string.Empty, p.ParameterName, "ParameterName"); - Assert.AreEqual(0, p.Precision, "Precision"); - Assert.AreEqual(0, p.Scale, "Scale"); - Assert.AreEqual(0, p.Size, "Size"); - Assert.AreEqual(string.Empty, p.SourceColumn, "SourceColumn"); - Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "NpgsqlDbType"); - Assert.IsNull(p.Value, "Value"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "DbType"); + Assert.That(p.Direction, Is.EqualTo(ParameterDirection.Input), "Direction"); + Assert.That(p.IsNullable, Is.False, "IsNullable"); + Assert.That(p.ParameterName, Is.Empty, "ParameterName"); + Assert.That(p.Precision, Is.EqualTo(0), "Precision"); + Assert.That(p.Scale, Is.EqualTo(0), "Scale"); + Assert.That(p.Size, Is.EqualTo(0), "Size"); + Assert.That(p.SourceColumn, Is.Empty, "SourceColumn"); + Assert.That(p.SourceVersion, Is.EqualTo(DataRowVersion.Current), "SourceVersion"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "NpgsqlDbType"); + Assert.That(p.Value, Is.Null, "Value"); } [Test] @@ -166,51 +165,51 @@ public 
void Constructor2_Value_DateTime() var value = new DateTime(2004, 8, 24); var p = new NpgsqlParameter("address", value); - Assert.AreEqual(DbType.DateTime2, p.DbType, "B:DbType"); - Assert.AreEqual(ParameterDirection.Input, p.Direction, "B:Direction"); - Assert.IsFalse(p.IsNullable, "B:IsNullable"); - Assert.AreEqual("address", p.ParameterName, "B:ParameterName"); - Assert.AreEqual(0, p.Precision, "B:Precision"); - Assert.AreEqual(0, p.Scale, "B:Scale"); + Assert.That(p.DbType, Is.EqualTo(DbType.DateTime2), "B:DbType"); + Assert.That(p.Direction, Is.EqualTo(ParameterDirection.Input), "B:Direction"); + Assert.That(p.IsNullable, Is.False, "B:IsNullable"); + Assert.That(p.ParameterName, Is.EqualTo("address"), "B:ParameterName"); + Assert.That(p.Precision, Is.EqualTo(0), "B:Precision"); + Assert.That(p.Scale, Is.EqualTo(0), "B:Scale"); //Assert.AreEqual (0, p.Size, "B:Size"); - Assert.AreEqual(string.Empty, p.SourceColumn, "B:SourceColumn"); - Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "B:SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "B:NpgsqlDbType"); - Assert.AreEqual(value, p.Value, "B:Value"); + Assert.That(p.SourceColumn, Is.Empty, "B:SourceColumn"); + Assert.That(p.SourceVersion, Is.EqualTo(DataRowVersion.Current), "B:SourceVersion"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp), "B:NpgsqlDbType"); + Assert.That(p.Value, Is.EqualTo(value), "B:Value"); } [Test] public void Constructor2_Value_DBNull() { var p = new NpgsqlParameter("address", DBNull.Value); - Assert.AreEqual(DbType.Object, p.DbType, "B:DbType"); - Assert.AreEqual(ParameterDirection.Input, p.Direction, "B:Direction"); - Assert.IsFalse(p.IsNullable, "B:IsNullable"); - Assert.AreEqual("address", p.ParameterName, "B:ParameterName"); - Assert.AreEqual(0, p.Precision, "B:Precision"); - Assert.AreEqual(0, p.Scale, "B:Scale"); - Assert.AreEqual(0, p.Size, "B:Size"); - Assert.AreEqual(string.Empty, p.SourceColumn, "B:SourceColumn"); - 
Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "B:SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "B:NpgsqlDbType"); - Assert.AreEqual(DBNull.Value, p.Value, "B:Value"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "B:DbType"); + Assert.That(p.Direction, Is.EqualTo(ParameterDirection.Input), "B:Direction"); + Assert.That(p.IsNullable, Is.False, "B:IsNullable"); + Assert.That(p.ParameterName, Is.EqualTo("address"), "B:ParameterName"); + Assert.That(p.Precision, Is.EqualTo(0), "B:Precision"); + Assert.That(p.Scale, Is.EqualTo(0), "B:Scale"); + Assert.That(p.Size, Is.EqualTo(0), "B:Size"); + Assert.That(p.SourceColumn, Is.Empty, "B:SourceColumn"); + Assert.That(p.SourceVersion, Is.EqualTo(DataRowVersion.Current), "B:SourceVersion"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "B:NpgsqlDbType"); + Assert.That(p.Value, Is.EqualTo(DBNull.Value), "B:Value"); } [Test] public void Constructor2_Value_null() { var p = new NpgsqlParameter("address", null); - Assert.AreEqual(DbType.Object, p.DbType, "A:DbType"); - Assert.AreEqual(ParameterDirection.Input, p.Direction, "A:Direction"); - Assert.IsFalse(p.IsNullable, "A:IsNullable"); - Assert.AreEqual("address", p.ParameterName, "A:ParameterName"); - Assert.AreEqual(0, p.Precision, "A:Precision"); - Assert.AreEqual(0, p.Scale, "A:Scale"); - Assert.AreEqual(0, p.Size, "A:Size"); - Assert.AreEqual(string.Empty, p.SourceColumn, "A:SourceColumn"); - Assert.AreEqual(DataRowVersion.Current, p.SourceVersion, "A:SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "A:NpgsqlDbType"); - Assert.IsNull(p.Value, "A:Value"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "A:DbType"); + Assert.That(p.Direction, Is.EqualTo(ParameterDirection.Input), "A:Direction"); + Assert.That(p.IsNullable, Is.False, "A:IsNullable"); + Assert.That(p.ParameterName, Is.EqualTo("address"), "A:ParameterName"); + Assert.That(p.Precision, Is.EqualTo(0), "A:Precision"); + 
Assert.That(p.Scale, Is.EqualTo(0), "A:Scale"); + Assert.That(p.Size, Is.EqualTo(0), "A:Size"); + Assert.That(p.SourceColumn, Is.Empty, "A:SourceColumn"); + Assert.That(p.SourceVersion, Is.EqualTo(DataRowVersion.Current), "A:SourceVersion"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "A:NpgsqlDbType"); + Assert.That(p.Value, Is.Null, "A:Value"); } [Test] @@ -220,20 +219,20 @@ public void Constructor7() var p1 = new NpgsqlParameter("p1Name", NpgsqlDbType.Varchar, 20, "srcCol", ParameterDirection.InputOutput, false, 0, 0, DataRowVersion.Original, "foo"); - Assert.AreEqual(DbType.String, p1.DbType, "DbType"); - Assert.AreEqual(ParameterDirection.InputOutput, p1.Direction, "Direction"); - Assert.AreEqual(false, p1.IsNullable, "IsNullable"); + Assert.That(p1.DbType, Is.EqualTo(DbType.String), "DbType"); + Assert.That(p1.Direction, Is.EqualTo(ParameterDirection.InputOutput), "Direction"); + Assert.That(p1.IsNullable, Is.EqualTo(false), "IsNullable"); //Assert.AreEqual (999, p1.LocaleId, "#"); - Assert.AreEqual("p1Name", p1.ParameterName, "ParameterName"); - Assert.AreEqual(0, p1.Precision, "Precision"); - Assert.AreEqual(0, p1.Scale, "Scale"); - Assert.AreEqual(20, p1.Size, "Size"); - Assert.AreEqual("srcCol", p1.SourceColumn, "SourceColumn"); - Assert.AreEqual(false, p1.SourceColumnNullMapping, "SourceColumnNullMapping"); - Assert.AreEqual(DataRowVersion.Original, p1.SourceVersion, "SourceVersion"); - Assert.AreEqual(NpgsqlDbType.Varchar, p1.NpgsqlDbType, "NpgsqlDbType"); + Assert.That(p1.ParameterName, Is.EqualTo("p1Name"), "ParameterName"); + Assert.That(p1.Precision, Is.EqualTo(0), "Precision"); + Assert.That(p1.Scale, Is.EqualTo(0), "Scale"); + Assert.That(p1.Size, Is.EqualTo(20), "Size"); + Assert.That(p1.SourceColumn, Is.EqualTo("srcCol"), "SourceColumn"); + Assert.That(p1.SourceColumnNullMapping, Is.EqualTo(false), "SourceColumnNullMapping"); + Assert.That(p1.SourceVersion, Is.EqualTo(DataRowVersion.Original), "SourceVersion"); + 
Assert.That(p1.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Varchar), "NpgsqlDbType"); //Assert.AreEqual (3210, p1.NpgsqlValue, "#"); - Assert.AreEqual("foo", p1.Value, "Value"); + Assert.That(p1.Value, Is.EqualTo("foo"), "Value"); //Assert.AreEqual ("database", p1.XmlSchemaCollectionDatabase, "XmlSchemaCollectionDatabase"); //Assert.AreEqual ("name", p1.XmlSchemaCollectionName, "XmlSchemaCollectionName"); //Assert.AreEqual ("schema", p1.XmlSchemaCollectionOwningSchema, "XmlSchemaCollectionOwningSchema"); @@ -263,22 +262,22 @@ public void Clone() }; var actual = expected.Clone(); - Assert.AreEqual(expected.Value, actual.Value); - Assert.AreEqual(expected.ParameterName, actual.ParameterName); + Assert.That(actual.Value, Is.EqualTo(expected.Value)); + Assert.That(actual.ParameterName, Is.EqualTo(expected.ParameterName)); - Assert.AreEqual(expected.DbType, actual.DbType); - Assert.AreEqual(expected.NpgsqlDbType, actual.NpgsqlDbType); - Assert.AreEqual(expected.DataTypeName, actual.DataTypeName); + Assert.That(actual.DbType, Is.EqualTo(expected.DbType)); + Assert.That(actual.NpgsqlDbType, Is.EqualTo(expected.NpgsqlDbType)); + Assert.That(actual.DataTypeName, Is.EqualTo(expected.DataTypeName)); - Assert.AreEqual(expected.Direction, actual.Direction); - Assert.AreEqual(expected.IsNullable, actual.IsNullable); - Assert.AreEqual(expected.Precision, actual.Precision); - Assert.AreEqual(expected.Scale, actual.Scale); - Assert.AreEqual(expected.Size, actual.Size); + Assert.That(actual.Direction, Is.EqualTo(expected.Direction)); + Assert.That(actual.IsNullable, Is.EqualTo(expected.IsNullable)); + Assert.That(actual.Precision, Is.EqualTo(expected.Precision)); + Assert.That(actual.Scale, Is.EqualTo(expected.Scale)); + Assert.That(actual.Size, Is.EqualTo(expected.Size)); - Assert.AreEqual(expected.SourceVersion, actual.SourceVersion); - Assert.AreEqual(expected.SourceColumn, actual.SourceColumn); - Assert.AreEqual(expected.SourceColumnNullMapping, actual.SourceColumnNullMapping); + 
Assert.That(actual.SourceVersion, Is.EqualTo(expected.SourceVersion)); + Assert.That(actual.SourceColumn, Is.EqualTo(expected.SourceColumn)); + Assert.That(actual.SourceColumnNullMapping, Is.EqualTo(expected.SourceColumnNullMapping)); } [Test] @@ -305,136 +304,97 @@ public void Clone_generic() }; var actual = (NpgsqlParameter)expected.Clone(); - Assert.AreEqual(expected.Value, actual.Value); - Assert.AreEqual(expected.TypedValue, actual.TypedValue); - Assert.AreEqual(expected.ParameterName, actual.ParameterName); + Assert.That(actual.Value, Is.EqualTo(expected.Value)); + Assert.That(actual.TypedValue, Is.EqualTo(expected.TypedValue)); + Assert.That(actual.ParameterName, Is.EqualTo(expected.ParameterName)); - Assert.AreEqual(expected.DbType, actual.DbType); - Assert.AreEqual(expected.NpgsqlDbType, actual.NpgsqlDbType); - Assert.AreEqual(expected.DataTypeName, actual.DataTypeName); + Assert.That(actual.DbType, Is.EqualTo(expected.DbType)); + Assert.That(actual.NpgsqlDbType, Is.EqualTo(expected.NpgsqlDbType)); + Assert.That(actual.DataTypeName, Is.EqualTo(expected.DataTypeName)); - Assert.AreEqual(expected.Direction, actual.Direction); - Assert.AreEqual(expected.IsNullable, actual.IsNullable); - Assert.AreEqual(expected.Precision, actual.Precision); - Assert.AreEqual(expected.Scale, actual.Scale); - Assert.AreEqual(expected.Size, actual.Size); + Assert.That(actual.Direction, Is.EqualTo(expected.Direction)); + Assert.That(actual.IsNullable, Is.EqualTo(expected.IsNullable)); + Assert.That(actual.Precision, Is.EqualTo(expected.Precision)); + Assert.That(actual.Scale, Is.EqualTo(expected.Scale)); + Assert.That(actual.Size, Is.EqualTo(expected.Size)); - Assert.AreEqual(expected.SourceVersion, actual.SourceVersion); - Assert.AreEqual(expected.SourceColumn, actual.SourceColumn); - Assert.AreEqual(expected.SourceColumnNullMapping, actual.SourceColumnNullMapping); + Assert.That(actual.SourceVersion, Is.EqualTo(expected.SourceVersion)); + Assert.That(actual.SourceColumn, 
Is.EqualTo(expected.SourceColumn)); + Assert.That(actual.SourceColumnNullMapping, Is.EqualTo(expected.SourceColumnNullMapping)); } #endregion - [Test] - [Ignore("")] - public void InferType_invalid_throws() - { - var notsupported = new object[] - { - ushort.MaxValue, - uint.MaxValue, - ulong.MaxValue, - sbyte.MaxValue, - new NpgsqlParameter() - }; - - var param = new NpgsqlParameter(); - - for (var i = 0; i < notsupported.Length; i++) - { - try - { - param.Value = notsupported[i]; - Assert.Fail("#A1:" + i); - } - catch (FormatException) - { - // appears to be bug in .NET 1.1 while - // constructing exception message - } - catch (ArgumentException ex) - { - // The parameter data type of ... is invalid - Assert.AreEqual(typeof(ArgumentException), ex.GetType(), "#A2"); - Assert.IsNull(ex.InnerException, "#A3"); - Assert.IsNotNull(ex.Message, "#A4"); - Assert.IsNull(ex.ParamName, "#A5"); - } - } - } - [Test] // bug #320196 public void Parameter_null() { var param = new NpgsqlParameter("param", NpgsqlDbType.Numeric); - Assert.AreEqual(0, param.Scale, "#A1"); + Assert.That(param.Scale, Is.EqualTo(0), "#A1"); param.Value = DBNull.Value; - Assert.AreEqual(0, param.Scale, "#A2"); + Assert.That(param.Scale, Is.EqualTo(0), "#A2"); param = new NpgsqlParameter("param", NpgsqlDbType.Integer); - Assert.AreEqual(0, param.Scale, "#B1"); + Assert.That(param.Scale, Is.EqualTo(0), "#B1"); param.Value = DBNull.Value; - Assert.AreEqual(0, param.Scale, "#B2"); + Assert.That(param.Scale, Is.EqualTo(0), "#B2"); } [Test] - [Ignore("")] public void Parameter_type() { NpgsqlParameter p; // If Type is not set, then type is inferred from the value // assigned. The Type should be inferred everytime Value is assigned - // If value is null or DBNull, then the current Type should be reset to Text. 
- p = new NpgsqlParameter(); - Assert.AreEqual(DbType.String, p.DbType, "#A1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#A2"); + // If value is null or DBNull, then the current Type should be reset to Unknown (DbType.Object and NpgsqlDbType.Unknown). + p = new NpgsqlParameter { Value = "" }; + Assert.That(p.DbType, Is.EqualTo(DbType.String), "#A1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text), "#A2"); p.Value = DBNull.Value; - Assert.AreEqual(DbType.String, p.DbType, "#B1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#B2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#B1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#B2"); p.Value = 1; - Assert.AreEqual(DbType.Int32, p.DbType, "#C1"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#C2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Int32), "#C1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer), "#C2"); p.Value = DBNull.Value; - Assert.AreEqual(DbType.String, p.DbType, "#D1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#D2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#D1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#D2"); p.Value = new byte[] { 0x0a }; - Assert.AreEqual(DbType.Binary, p.DbType, "#E1"); - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#E2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Binary), "#E1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea), "#E2"); p.Value = null; - Assert.AreEqual(DbType.String, p.DbType, "#F1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#F2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#F1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#F2"); p.Value = DateTime.Now; - Assert.AreEqual(DbType.DateTime, p.DbType, "#G1"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#G2"); + Assert.That(p.DbType, Is.EqualTo(DbType.DateTime2), "#G1"); + Assert.That(p.NpgsqlDbType, 
Is.EqualTo(NpgsqlDbType.Timestamp), "#G2"); p.Value = null; - Assert.AreEqual(DbType.String, p.DbType, "#H1"); - Assert.AreEqual(NpgsqlDbType.Text, p.NpgsqlDbType, "#H2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#H1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#H2"); // If DbType is set, then the NpgsqlDbType should not be // inferred from the value assigned. p = new NpgsqlParameter(); p.DbType = DbType.DateTime; - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz), "#I1"); p.Value = 1; - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I2"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz), "#I2"); p.Value = null; - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I3"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz), "#I3"); p.Value = DBNull.Value; - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#I4"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz), "#I4"); // If NpgsqlDbType is set, then the DbType should not be // inferred from the value assigned. 
p = new NpgsqlParameter(); p.NpgsqlDbType = NpgsqlDbType.Bytea; - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea), "#J1"); p.Value = 1; - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J2"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea), "#J2"); p.Value = null; - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J3"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea), "#J3"); p.Value = DBNull.Value; - Assert.AreEqual(NpgsqlDbType.Bytea, p.NpgsqlDbType, "#J4"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Bytea), "#J4"); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5428")] @@ -447,29 +407,28 @@ public async Task Match_param_index_case_insensitively() } [Test] - [Ignore("")] public void ParameterName() { var p = new NpgsqlParameter(); p.ParameterName = "name"; - Assert.AreEqual("name", p.ParameterName, "#A:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#A:SourceColumn"); + Assert.That(p.ParameterName, Is.EqualTo("name"), "#A:ParameterName"); + Assert.That(p.SourceColumn, Is.Empty, "#A:SourceColumn"); p.ParameterName = null; - Assert.AreEqual(string.Empty, p.ParameterName, "#B:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#B:SourceColumn"); + Assert.That(p.ParameterName, Is.Empty, "#B:ParameterName"); + Assert.That(p.SourceColumn, Is.Empty, "#B:SourceColumn"); p.ParameterName = " "; - Assert.AreEqual(" ", p.ParameterName, "#C:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#C:SourceColumn"); + Assert.That(p.ParameterName, Is.EqualTo(" "), "#C:ParameterName"); + Assert.That(p.SourceColumn, Is.Empty, "#C:SourceColumn"); p.ParameterName = " name "; - Assert.AreEqual(" name ", p.ParameterName, "#D:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#D:SourceColumn"); + Assert.That(p.ParameterName, Is.EqualTo(" name "), "#D:ParameterName"); + Assert.That(p.SourceColumn, 
Is.Empty, "#D:SourceColumn"); p.ParameterName = string.Empty; - Assert.AreEqual(string.Empty, p.ParameterName, "#E:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#E:SourceColumn"); + Assert.That(p.ParameterName, Is.Empty, "#E:ParameterName"); + Assert.That(p.SourceColumn, Is.Empty, "#E:SourceColumn"); } [Test] @@ -480,59 +439,59 @@ public void ResetDbType() //Parameter with an assigned value but no DbType specified p = new NpgsqlParameter("foo", 42); p.ResetDbType(); - Assert.AreEqual(DbType.Int32, p.DbType, "#A:DbType"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#A:NpgsqlDbType"); - Assert.AreEqual(42, p.Value, "#A:Value"); + Assert.That(p.DbType, Is.EqualTo(DbType.Int32), "#A:DbType"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer), "#A:NpgsqlDbType"); + Assert.That(p.Value, Is.EqualTo(42), "#A:Value"); p.DbType = DbType.DateTime; //assigning a DbType - Assert.AreEqual(DbType.DateTime, p.DbType, "#B:DbType1"); - Assert.AreEqual(NpgsqlDbType.TimestampTz, p.NpgsqlDbType, "#B:SqlDbType1"); + Assert.That(p.DbType, Is.EqualTo(DbType.DateTime), "#B:DbType1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz), "#B:SqlDbType1"); p.ResetDbType(); - Assert.AreEqual(DbType.Int32, p.DbType, "#B:DbType2"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#B:SqlDbtype2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Int32), "#B:DbType2"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer), "#B:SqlDbtype2"); //Parameter with an assigned NpgsqlDbType but no specified value p = new NpgsqlParameter("foo", NpgsqlDbType.Integer); p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#C:DbType"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#C:NpgsqlDbType"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#C:DbType"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#C:NpgsqlDbType"); p.NpgsqlDbType = NpgsqlDbType.TimestampTz; //assigning a NpgsqlDbType - 
Assert.AreEqual(DbType.DateTime, p.DbType, "#D:DbType1"); - Assert.AreEqual(NpgsqlDbType.TimestampTz, p.NpgsqlDbType, "#D:SqlDbType1"); + Assert.That(p.DbType, Is.EqualTo(DbType.DateTime), "#D:DbType1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz), "#D:SqlDbType1"); p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#D:DbType2"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#D:SqlDbType2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#D:DbType2"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#D:SqlDbType2"); p = new NpgsqlParameter(); p.Value = DateTime.MaxValue; - Assert.AreEqual(DbType.DateTime2, p.DbType, "#E:DbType1"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#E:SqlDbType1"); + Assert.That(p.DbType, Is.EqualTo(DbType.DateTime2), "#E:DbType1"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp), "#E:SqlDbType1"); p.Value = null; p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#E:DbType2"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#E:SqlDbType2"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#E:DbType2"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#E:SqlDbType2"); p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); p.Value = DateTime.MaxValue; p.ResetDbType(); - Assert.AreEqual(DbType.DateTime2, p.DbType, "#F:DbType"); - Assert.AreEqual(NpgsqlDbType.Timestamp, p.NpgsqlDbType, "#F:NpgsqlDbType"); - Assert.AreEqual(DateTime.MaxValue, p.Value, "#F:Value"); + Assert.That(p.DbType, Is.EqualTo(DbType.DateTime2), "#F:DbType"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp), "#F:NpgsqlDbType"); + Assert.That(p.Value, Is.EqualTo(DateTime.MaxValue), "#F:Value"); p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); p.Value = DBNull.Value; p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#G:DbType"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#G:NpgsqlDbType"); - 
Assert.AreEqual(DBNull.Value, p.Value, "#G:Value"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#G:DbType"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#G:NpgsqlDbType"); + Assert.That(p.Value, Is.EqualTo(DBNull.Value), "#G:Value"); p = new NpgsqlParameter("foo", NpgsqlDbType.Varchar); p.Value = null; p.ResetDbType(); - Assert.AreEqual(DbType.Object, p.DbType, "#G:DbType"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#G:NpgsqlDbType"); - Assert.IsNull(p.Value, "#G:Value"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#G:DbType"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#G:NpgsqlDbType"); + Assert.That(p.Value, Is.Null, "#G:Value"); } [Test] @@ -540,29 +499,28 @@ public void ParameterName_retains_prefix() => Assert.That(new NpgsqlParameter("@p", DbType.String).ParameterName, Is.EqualTo("@p")); [Test] - [Ignore("")] public void SourceColumn() { var p = new NpgsqlParameter(); p.SourceColumn = "name"; - Assert.AreEqual(string.Empty, p.ParameterName, "#A:ParameterName"); - Assert.AreEqual("name", p.SourceColumn, "#A:SourceColumn"); + Assert.That(p.ParameterName, Is.Empty, "#A:ParameterName"); + Assert.That(p.SourceColumn, Is.EqualTo("name"), "#A:SourceColumn"); p.SourceColumn = null; - Assert.AreEqual(string.Empty, p.ParameterName, "#B:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#B:SourceColumn"); + Assert.That(p.ParameterName, Is.Empty, "#B:ParameterName"); + Assert.That(p.SourceColumn, Is.Empty, "#B:SourceColumn"); p.SourceColumn = " "; - Assert.AreEqual(string.Empty, p.ParameterName, "#C:ParameterName"); - Assert.AreEqual(" ", p.SourceColumn, "#C:SourceColumn"); + Assert.That(p.ParameterName, Is.Empty, "#C:ParameterName"); + Assert.That(p.SourceColumn, Is.EqualTo(" "), "#C:SourceColumn"); p.SourceColumn = " name "; - Assert.AreEqual(string.Empty, p.ParameterName, "#D:ParameterName"); - Assert.AreEqual(" name ", p.SourceColumn, "#D:SourceColumn"); + 
Assert.That(p.ParameterName, Is.Empty, "#D:ParameterName"); + Assert.That(p.SourceColumn, Is.EqualTo(" name "), "#D:SourceColumn"); p.SourceColumn = string.Empty; - Assert.AreEqual(string.Empty, p.ParameterName, "#E:ParameterName"); - Assert.AreEqual(string.Empty, p.SourceColumn, "#E:SourceColumn"); + Assert.That(p.ParameterName, Is.Empty, "#E:ParameterName"); + Assert.That(p.SourceColumn, Is.Empty, "#E:SourceColumn"); } [Test] @@ -570,8 +528,8 @@ public void Bug1011100_NpgsqlDbType() { var p = new NpgsqlParameter(); p.Value = DBNull.Value; - Assert.AreEqual(DbType.Object, p.DbType, "#A:DbType"); - Assert.AreEqual(NpgsqlDbType.Unknown, p.NpgsqlDbType, "#A:NpgsqlDbType"); + Assert.That(p.DbType, Is.EqualTo(DbType.Object), "#A:DbType"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown), "#A:NpgsqlDbType"); // Now change parameter value. // Note that as we didn't explicitly specified a dbtype, the dbtype property should change when @@ -579,8 +537,8 @@ public void Bug1011100_NpgsqlDbType() p.Value = 8; - Assert.AreEqual(DbType.Int32, p.DbType, "#A:DbType"); - Assert.AreEqual(NpgsqlDbType.Integer, p.NpgsqlDbType, "#A:NpgsqlDbType"); + Assert.That(p.DbType, Is.EqualTo(DbType.Int32), "#A:DbType"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer), "#A:NpgsqlDbType"); //Assert.AreEqual(3510, p.Value, "#A:Value"); //p.NpgsqlDbType = NpgsqlDbType.Varchar; @@ -608,19 +566,19 @@ public void NpgsqlParameter_Clone() var newParam = param.Clone(); - Assert.AreEqual(param.Value, newParam.Value); - Assert.AreEqual(param.Precision, newParam.Precision); - Assert.AreEqual(param.Scale, newParam.Scale); - Assert.AreEqual(param.Size, newParam.Size); - Assert.AreEqual(param.Direction, newParam.Direction); - Assert.AreEqual(param.IsNullable, newParam.IsNullable); - Assert.AreEqual(param.ParameterName, newParam.ParameterName); - Assert.AreEqual(param.TrimmedName, newParam.TrimmedName); - Assert.AreEqual(param.SourceColumn, newParam.SourceColumn); - 
Assert.AreEqual(param.SourceVersion, newParam.SourceVersion); - Assert.AreEqual(param.NpgsqlValue, newParam.NpgsqlValue); - Assert.AreEqual(param.SourceColumnNullMapping, newParam.SourceColumnNullMapping); - Assert.AreEqual(param.NpgsqlValue, newParam.NpgsqlValue); + Assert.That(newParam.Value, Is.EqualTo(param.Value)); + Assert.That(newParam.Precision, Is.EqualTo(param.Precision)); + Assert.That(newParam.Scale, Is.EqualTo(param.Scale)); + Assert.That(newParam.Size, Is.EqualTo(param.Size)); + Assert.That(newParam.Direction, Is.EqualTo(param.Direction)); + Assert.That(newParam.IsNullable, Is.EqualTo(param.IsNullable)); + Assert.That(newParam.ParameterName, Is.EqualTo(param.ParameterName)); + Assert.That(newParam.TrimmedName, Is.EqualTo(param.TrimmedName)); + Assert.That(newParam.SourceColumn, Is.EqualTo(param.SourceColumn)); + Assert.That(newParam.SourceVersion, Is.EqualTo(param.SourceVersion)); + Assert.That(newParam.NpgsqlValue, Is.EqualTo(param.NpgsqlValue)); + Assert.That(newParam.SourceColumnNullMapping, Is.EqualTo(param.SourceColumnNullMapping)); + Assert.That(newParam.NpgsqlValue, Is.EqualTo(param.NpgsqlValue)); } @@ -632,7 +590,7 @@ public void Precision_via_interface() paramIface.Precision = 42; - Assert.AreEqual((byte)42, paramIface.Precision); + Assert.That(paramIface.Precision, Is.EqualTo((byte)42)); } [Test] @@ -643,7 +601,7 @@ public void Precision_via_base_class() paramBase.Precision = 42; - Assert.AreEqual((byte)42, paramBase.Precision); + Assert.That(paramBase.Precision, Is.EqualTo((byte)42)); } [Test] @@ -654,7 +612,7 @@ public void Scale_via_interface() paramIface.Scale = 42; - Assert.AreEqual((byte)42, paramIface.Scale); + Assert.That(paramIface.Scale, Is.EqualTo((byte)42)); } [Test] @@ -665,7 +623,7 @@ public void Scale_via_base_class() paramBase.Scale = 42; - Assert.AreEqual((byte)42, paramBase.Scale); + Assert.That(paramBase.Scale, Is.EqualTo((byte)42)); } [Test] @@ -698,7 +656,7 @@ public void Null_value_with_nullable_type() public void 
DBNull_reuses_type_info([Values]bool generic) { var param = generic ? new NpgsqlParameter { Value = "value" } : new NpgsqlParameter { Value = "value" }; - param.ResolveTypeInfo(DataSource.SerializerOptions); + param.ResolveTypeInfo(DataSource.CurrentReloadableState.SerializerOptions, null); param.GetResolutionInfo(out var typeInfo, out _, out _); Assert.That(typeInfo, Is.Not.Null); @@ -708,7 +666,7 @@ public void DBNull_reuses_type_info([Values]bool generic) Assert.That(secondTypeInfo, Is.SameAs(typeInfo)); // Make sure we don't resolve a different type info either. - param.ResolveTypeInfo(DataSource.SerializerOptions); + param.ResolveTypeInfo(DataSource.CurrentReloadableState.SerializerOptions, null); param.GetResolutionInfo(out var thirdTypeInfo, out _, out _); Assert.That(thirdTypeInfo, Is.SameAs(secondTypeInfo)); } @@ -717,7 +675,7 @@ public void DBNull_reuses_type_info([Values]bool generic) public void DBNull_followed_by_non_null_reresolves([Values]bool generic) { var param = generic ? new NpgsqlParameter { Value = DBNull.Value } : new NpgsqlParameter { Value = DBNull.Value }; - param.ResolveTypeInfo(DataSource.SerializerOptions); + param.ResolveTypeInfo(DataSource.CurrentReloadableState.SerializerOptions, null); param.GetResolutionInfo(out var typeInfo, out _, out var pgTypeId); Assert.That(typeInfo, Is.Not.Null); Assert.That(pgTypeId.IsUnspecified, Is.True); @@ -727,7 +685,7 @@ public void DBNull_followed_by_non_null_reresolves([Values]bool generic) Assert.That(secondTypeInfo, Is.Null); // Make sure we don't resolve the same type info either. 
- param.ResolveTypeInfo(DataSource.SerializerOptions); + param.ResolveTypeInfo(DataSource.CurrentReloadableState.SerializerOptions, null); param.GetResolutionInfo(out var thirdTypeInfo, out _, out _); Assert.That(thirdTypeInfo, Is.Not.SameAs(typeInfo)); } @@ -736,7 +694,7 @@ public void DBNull_followed_by_non_null_reresolves([Values]bool generic) public void Changing_value_type_reresolves([Values]bool generic) { var param = generic ? new NpgsqlParameter { Value = "value" } : new NpgsqlParameter { Value = "value" }; - param.ResolveTypeInfo(DataSource.SerializerOptions); + param.ResolveTypeInfo(DataSource.CurrentReloadableState.SerializerOptions, null); param.GetResolutionInfo(out var typeInfo, out _, out _); Assert.That(typeInfo, Is.Not.Null); @@ -745,11 +703,31 @@ public void Changing_value_type_reresolves([Values]bool generic) Assert.That(secondTypeInfo, Is.Null); // Make sure we don't resolve a different type info either. - param.ResolveTypeInfo(DataSource.SerializerOptions); + param.ResolveTypeInfo(DataSource.CurrentReloadableState.SerializerOptions, null); param.GetResolutionInfo(out var thirdTypeInfo, out _, out _); Assert.That(thirdTypeInfo, Is.Not.SameAs(typeInfo)); } + [Test] + public void DataTypeName_prioritized_over_NpgsqlDbType([Values]bool generic) + { + var param = generic ? 
new NpgsqlParameter + { + NpgsqlDbType = NpgsqlDbType.Integer, + DataTypeName = "text", + Value = "value" + } : new NpgsqlParameter + { + NpgsqlDbType = NpgsqlDbType.Integer, + DataTypeName = "text", + Value = "value" + }; + param.ResolveTypeInfo(DataSource.CurrentReloadableState.SerializerOptions, null); + param.GetResolutionInfo(out var typeInfo, out _, out _); + Assert.That(typeInfo, Is.Not.Null); + Assert.That(typeInfo.PgTypeId, Is.EqualTo(DataSource.CurrentReloadableState.SerializerOptions.TextPgTypeId)); + } + #if NeedsPorting [Test] [Category ("NotWorking")] diff --git a/test/Npgsql.Tests/PgPassEntryTests.cs b/test/Npgsql.Tests/PgPassEntryTests.cs index 9db518aabc..db78e893ad 100644 --- a/test/Npgsql.Tests/PgPassEntryTests.cs +++ b/test/Npgsql.Tests/PgPassEntryTests.cs @@ -13,11 +13,11 @@ public void Parses_well_formed_entry() var entry = PgPassFile.Entry.Parse(input); Assert.That(entry, Is.Not.Null); - Assert.That("test", Is.EqualTo(entry.Host)); - Assert.That(1234, Is.EqualTo(entry.Port)); - Assert.That("test2", Is.EqualTo(entry.Database)); - Assert.That("test3", Is.EqualTo(entry.Username)); - Assert.That("test4", Is.EqualTo(entry.Password)); + Assert.That(entry.Host, Is.EqualTo("test")); + Assert.That(entry.Port, Is.EqualTo(1234)); + Assert.That(entry.Database, Is.EqualTo("test2")); + Assert.That(entry.Username, Is.EqualTo("test3")); + Assert.That(entry.Password, Is.EqualTo("test4")); } [Test] @@ -36,11 +36,11 @@ public void Escaped_characters() var entry = PgPassFile.Entry.Parse(input); Assert.That(entry, Is.Not.Null); - Assert.That("t:est", Is.EqualTo(entry.Host)); - Assert.That(1234, Is.EqualTo(entry.Port)); - Assert.That("test2", Is.EqualTo(entry.Database)); - Assert.That("test3", Is.EqualTo(entry.Username)); - Assert.That("test\\4", Is.EqualTo(entry.Password)); + Assert.That(entry.Host, Is.EqualTo("t:est")); + Assert.That(entry.Port, Is.EqualTo(1234)); + Assert.That(entry.Database, Is.EqualTo("test2")); + Assert.That(entry.Username, 
Is.EqualTo("test3")); + Assert.That(entry.Password, Is.EqualTo("test\\4")); } [Test] diff --git a/test/Npgsql.Tests/PoolTests.cs b/test/Npgsql.Tests/PoolTests.cs index d9024dd0dd..4a3ecca261 100644 --- a/test/Npgsql.Tests/PoolTests.cs +++ b/test/Npgsql.Tests/PoolTests.cs @@ -314,7 +314,7 @@ public void ClearPool(int iterations) } // Now have one connection in the pool - Assert.True(PoolManager.Pools.TryGetValue(connString, out var pool)); + Assert.That(PoolManager.Pools.TryGetValue(connString, out var pool)); AssertPoolState(pool, open: 1, idle: 1); NpgsqlConnection.ClearPool(conn); @@ -346,7 +346,7 @@ public void ClearPool_with_busy() NpgsqlConnection.ClearPool(conn); // conn is still busy but should get closed when returned to the pool - Assert.True(PoolManager.Pools.TryGetValue(connString, out pool)); + Assert.That(PoolManager.Pools.TryGetValue(connString, out pool)); AssertPoolState(pool, open: 1, idle: 0); } @@ -468,29 +468,5 @@ await Task.WhenAll(Enumerable.Range(0, numParallelCommands) })); } - // When multiplexing, and the pool is totally saturated (at Max Pool Size and 0 idle connectors), we select - // the connector with the least commands in flight and execute on it. We must never select a connector with - // a pending transaction on it. - // TODO: Test not tested - [Test] - [Ignore("Multiplexing: fails")] - public async Task MultiplexedCommandDoesntGetExecutedOnTransactionedConnector() - { - await using var dataSource = CreateDataSource(csb => - { - csb.MaxPoolSize = 1; - csb.Timeout = 1; - }); - - await using var connWithTx = await dataSource.OpenConnectionAsync(); - await using var tx = await connWithTx.BeginTransactionAsync(); - // connWithTx should now be bound with the only physical connector available. 
- // Any commands execute should timeout - - await using var conn2 = await dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT 1", conn2); - Assert.ThrowsAsync(() => cmd.ExecuteScalarAsync()); - } - #endregion } diff --git a/test/Npgsql.Tests/PostgresTypeTests.cs b/test/Npgsql.Tests/PostgresTypeTests.cs index 056830cf32..7c3945fb0a 100644 --- a/test/Npgsql.Tests/PostgresTypeTests.cs +++ b/test/Npgsql.Tests/PostgresTypeTests.cs @@ -69,6 +69,6 @@ public async Task Multirange() async Task GetDatabaseInfo() { await using var conn = await OpenConnectionAsync(); - return conn.NpgsqlDataSource.DatabaseInfo; + return conn.NpgsqlDataSource.CurrentReloadableState.DatabaseInfo; } } diff --git a/test/Npgsql.Tests/PrepareTests.cs b/test/Npgsql.Tests/PrepareTests.cs index 1d9c6dde85..1957d1091e 100644 --- a/test/Npgsql.Tests/PrepareTests.cs +++ b/test/Npgsql.Tests/PrepareTests.cs @@ -462,7 +462,7 @@ public void Overloaded_sql() // SQL overloading is a pretty rare/exotic scenario. Handling it properly would involve keying // prepared statements not just by SQL but also by the parameter types, which would pointlessly - // increase allocations. Instead, the second execution simply reuns unprepared + // increase allocations. 
Instead, the second execution simply reruns unprepared AssertNumPreparedStatements(conn, 1); conn.UnprepareAll(); } @@ -659,7 +659,7 @@ public void Same_sql_different_params() using (var conn = OpenConnectionAndUnprepare()) using (var cmd = new NpgsqlCommand("SELECT @p", conn)) { - throw new NotImplementedException("Problem: currentl setting NpgsqlParameter.Value clears/invalidates..."); + throw new NotImplementedException("Problem: current setting NpgsqlParameter.Value clears/invalidates..."); cmd.Parameters.Add(new NpgsqlParameter("p", NpgsqlDbType.Integer)); cmd.Prepare(true); @@ -726,18 +726,7 @@ public void Prepare_multiple_commands_with_parameters() } [Test] - public void Multiplexing_not_supported() - { - using var dataSource = CreateDataSource(csb => csb.Multiplexing = true); - using var conn = dataSource.OpenConnection(); - using var cmd = new NpgsqlCommand("SELECT 1", conn); - - Assert.That(() => cmd.Prepare(), Throws.Exception.TypeOf()); - Assert.That(() => conn.UnprepareAll(), Throws.Exception.TypeOf()); - } - - [Test] - public async Task Explicitly_prepared_statement_invalidation() + public async Task Explicitly_prepared_statement_invalidation([Values] bool prepareAfterError, [Values] bool unprepareAfterError) { await using var dataSource = CreateDataSource(csb => { @@ -755,12 +744,30 @@ public async Task Explicitly_prepared_statement_invalidation() // Since we've changed the table schema, the next execution of the prepared statement will error with 0A000 var exception = Assert.ThrowsAsync(() => command.ExecuteNonQueryAsync())!; Assert.That(exception.SqlState, Is.EqualTo(PostgresErrorCodes.FeatureNotSupported)); // cached plan must not change result type + Assert.That(command.IsPrepared, Is.False); + + if (unprepareAfterError) + { + // Just check that calling unprepare after error doesn't break anything + await command.UnprepareAsync(); + Assert.That(command.IsPrepared, Is.False); + } + + if (prepareAfterError) + { + // If we explicitly prepare after 
error, we should replace the previous prepared statement with a new one + await command.PrepareAsync(); + Assert.That(command.IsPrepared); + } // However, Npgsql should invalidate the prepared statement in this case, so the next execution should work Assert.DoesNotThrowAsync(() => command.ExecuteNonQueryAsync()); - // The command is unprepared, though. It's the user's responsibility to re-prepare if they wish. - Assert.False(command.IsPrepared); + if (!prepareAfterError) + { + // The command is unprepared, though. It's the user's responsibility to re-prepare if they wish. + Assert.That(command.IsPrepared, Is.False); + } } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4920")] diff --git a/test/Npgsql.Tests/Properties/AssemblyInfo.cs b/test/Npgsql.Tests/Properties/AssemblyInfo.cs index f7cdcd188d..89a1bb2e0d 100644 --- a/test/Npgsql.Tests/Properties/AssemblyInfo.cs +++ b/test/Npgsql.Tests/Properties/AssemblyInfo.cs @@ -1,7 +1,7 @@ using System.Runtime.CompilerServices; using NUnit.Framework; -[assembly: Parallelizable(ParallelScope.Children), Timeout(30000)] +[assembly: Parallelizable(ParallelScope.Children)] [assembly: InternalsVisibleTo("Npgsql.PluginTests, PublicKey=" + "0024000004800000940000000602000000240000525341310004000001000100" + diff --git a/test/Npgsql.Tests/ReadBufferTests.cs b/test/Npgsql.Tests/ReadBufferTests.cs index 7d33bf68e1..3169e5366d 100644 --- a/test/Npgsql.Tests/ReadBufferTests.cs +++ b/test/Npgsql.Tests/ReadBufferTests.cs @@ -16,14 +16,14 @@ public void Skip() for (byte i = 0; i < 50; i++) Writer.WriteByte(i); - ReadBuffer.Ensure(10, async: false).GetAwaiter().GetResult(); + ReadBuffer.Ensure(10); ReadBuffer.Skip(7); Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(7)); ReadBuffer.Skip(10); - ReadBuffer.Ensure(1, async: false).GetAwaiter().GetResult(); + ReadBuffer.Ensure(1); Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(18)); ReadBuffer.Skip(20); - ReadBuffer.Ensure(1, async: false).GetAwaiter().GetResult(); + ReadBuffer.Ensure(1); 
Assert.That(ReadBuffer.ReadByte(), Is.EqualTo(39)); } @@ -35,7 +35,7 @@ public void ReadSingle() Array.Reverse(bytes); Writer.Write(bytes); - ReadBuffer.Ensure(4, async: false).GetAwaiter().GetResult(); + ReadBuffer.Ensure(4); Assert.That(ReadBuffer.ReadSingle(), Is.EqualTo(expected)); } @@ -47,7 +47,7 @@ public void ReadDouble() Array.Reverse(bytes); Writer.Write(bytes); - ReadBuffer.Ensure(8, async: false).GetAwaiter().GetResult(); + ReadBuffer.Ensure(8); Assert.That(ReadBuffer.ReadDouble(), Is.EqualTo(expected)); } @@ -60,7 +60,7 @@ public void ReadNullTerminatedString_buffered_only() .Write(NpgsqlWriteBuffer.UTF8Encoding.GetBytes(new string("bar"))) .WriteByte(0); - ReadBuffer.Ensure(1, async: false); + ReadBuffer.Ensure(1); Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("foo")); Assert.That(ReadBuffer.ReadNullTerminatedString(), Is.EqualTo("bar")); @@ -133,12 +133,8 @@ async Task Read(byte[] buffer, int offset, int count, bool async) return count; } - internal class MockStreamWriter + internal class MockStreamWriter(MockStream stream) { - readonly MockStream _stream; - - public MockStreamWriter(MockStream stream) => _stream = stream; - public MockStreamWriter WriteByte(byte b) { Span bytes = stackalloc byte[1]; @@ -149,11 +145,11 @@ public MockStreamWriter WriteByte(byte b) public MockStreamWriter Write(ReadOnlySpan bytes) { - if (_stream._filled + bytes.Length > Size) + if (stream._filled + bytes.Length > Size) throw new Exception("Mock stream overrun"); - bytes.CopyTo(new Span(_stream._data, _stream._filled, bytes.Length)); - _stream._filled += bytes.Length; - _stream._tcs.TrySetResult(new()); + bytes.CopyTo(new Span(stream._data, stream._filled, bytes.Length)); + stream._filled += bytes.Length; + stream._tcs.TrySetResult(new()); return this; } } diff --git a/test/Npgsql.Tests/ReaderNewSchemaTests.cs b/test/Npgsql.Tests/ReaderNewSchemaTests.cs index 01e46cdd06..3eafdc8a89 100644 --- a/test/Npgsql.Tests/ReaderNewSchemaTests.cs +++ 
b/test/Npgsql.Tests/ReaderNewSchemaTests.cs @@ -1,7 +1,10 @@ -using System.Collections.ObjectModel; +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Data; using System.Data.Common; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Npgsql.PostgresTypes; using NUnit.Framework; @@ -14,7 +17,7 @@ namespace Npgsql.Tests; /// Note that this API is also available on .NET Framework. /// For the old DataTable-based API, see . /// -public class ReaderNewSchemaTests : SyncOrAsyncTestBase +public class ReaderNewSchemaTests(SyncOrAsync syncOrAsync) : SyncOrAsyncTestBase(syncOrAsync) { // ReSharper disable once InconsistentNaming [Test] @@ -204,7 +207,7 @@ public async Task ColumnAttributeNumber() public async Task ColumnSize() { using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Column size is never unlimited on Redshift"); + await IgnoreOnRedshift(conn, "Column size is never unlimited on Redshift"); var table = await CreateTempTable(conn, "bounded VARCHAR(30), unbounded VARCHAR"); using var cmd = new NpgsqlCommand($"SELECT bounded,unbounded,'a'::VARCHAR(10),'b'::VARCHAR FROM {table}", conn); @@ -220,7 +223,7 @@ public async Task ColumnSize() public async Task IsAutoIncrement() { await using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Serial columns not support on Redshift"); + await IgnoreOnRedshift(conn, "Serial columns not support on Redshift"); var table = await CreateTempTable(conn, "serial SERIAL, int INT"); @@ -236,7 +239,7 @@ public async Task IsAutoIncrement() public async Task IsAutoIncrement_identity() { await using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); + await IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); MinimumPgVersion(conn, "10.0", "IDENTITY introduced in PostgreSQL 10"); var table = @@ -253,7 +256,7 @@ public async Task 
IsAutoIncrement_identity() public async Task IsIdentity() { await using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); + await IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); MinimumPgVersion(conn, "10.0", "IDENTITY introduced in PostgreSQL 10"); var table = await CreateTempTable( conn, @@ -273,7 +276,7 @@ public async Task IsIdentity() public async Task IsKey() { using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Key not supported in reader schema on Redshift"); + await IgnoreOnRedshift(conn, "Key not supported in reader schema on Redshift"); var table = await CreateTempTable(conn, "id INT PRIMARY KEY, non_id INT, uniq INT UNIQUE"); using var cmd = new NpgsqlCommand($"SELECT id,non_id,uniq,8 FROM {table}", conn); @@ -294,7 +297,7 @@ public async Task IsKey() public async Task IsKey_composite() { using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Key not supported in reader schema on Redshift"); + await IgnoreOnRedshift(conn, "Key not supported in reader schema on Redshift"); var table = await CreateTempTable(conn, "id1 INT, id2 INT, PRIMARY KEY (id1, id2)"); using var cmd = new NpgsqlCommand($"SELECT id1,id2 FROM {table}", conn); @@ -308,7 +311,7 @@ public async Task IsKey_composite() public async Task IsLong() { using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "bytea not supported on Redshift"); + await IgnoreOnRedshift(conn, "bytea not supported on Redshift"); var table = await CreateTempTable(conn, "long BYTEA, non_long INT"); using var cmd = new NpgsqlCommand($"SELECT long, non_long, 8 FROM {table}", conn); @@ -351,7 +354,7 @@ public async Task IsReadOnly_on_non_column() public async Task IsUnique() { using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Unique not supported in reader schema on Redshift"); + await IgnoreOnRedshift(conn, "Unique not supported in reader schema on Redshift"); var 
table = await GetTempTableName(conn); await conn.ExecuteNonQueryAsync($@" @@ -373,7 +376,7 @@ await conn.ExecuteNonQueryAsync($@" public async Task NumericPrecision() { using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Precision is never unlimited on Redshift"); + await IgnoreOnRedshift(conn, "Precision is never unlimited on Redshift"); var table = await CreateTempTable(conn, "a NUMERIC(8), b NUMERIC, c INTEGER"); using var cmd = new NpgsqlCommand($"SELECT a,b,c,8.3::NUMERIC(8) FROM {table}", conn); @@ -389,7 +392,7 @@ public async Task NumericPrecision() public async Task NumericScale() { using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Scale is never unlimited on Redshift"); + await IgnoreOnRedshift(conn, "Scale is never unlimited on Redshift"); var table = await CreateTempTable(conn, "a NUMERIC(8,5), b NUMERIC, c INTEGER"); using var cmd = new NpgsqlCommand($"SELECT a,b,c,8.3::NUMERIC(8,5) FROM {table}", conn); @@ -431,7 +434,7 @@ public async Task DataType_unknown_type() public async Task DataType_with_composite() { await using var adminConnection = await OpenConnectionAsync(); - IgnoreOnRedshift(adminConnection, "Composite types not support on Redshift"); + await IgnoreOnRedshift(adminConnection, "Composite types not support on Redshift"); var type = await GetTempTypeName(adminConnection); await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (foo int)"); var tableName = await CreateTempTable(adminConnection, $"comp {type}"); @@ -450,6 +453,19 @@ public async Task DataType_with_composite() Assert.That(columns[1].UdtAssemblyQualifiedName, Is.EqualTo(typeof(SomeComposite).AssemblyQualifiedName)); } + [Test] + public async Task DataType_with_array() + { + using var conn = await OpenConnectionAsync(); + var table = await CreateTempTable(conn, "foo INTEGER[]"); + + using var cmd = new NpgsqlCommand($"SELECT foo, ARRAY[1::INTEGER, 2::INTEGER] FROM {table}", conn); + using var reader = await 
cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + var columns = await GetColumnSchema(reader); + Assert.That(columns[0].DataType, Is.SameAs(typeof(Array))); + Assert.That(columns[1].DataType, Is.SameAs(typeof(Array))); + } + [Test] public async Task UdtAssemblyQualifiedName() { @@ -654,10 +670,8 @@ await conn.ExecuteNonQueryAsync($@" [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1553")] public async Task Domain_type() { - // if (IsMultiplexing) - // Assert.Ignore("Multiplexing: ReloadTypes"); using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Domain types not support on Redshift"); + await IgnoreOnRedshift(conn, "Domain types not support on Redshift"); const string domainTypeName = "my_domain"; var schema = await CreateTempSchema(conn); @@ -673,6 +687,8 @@ public async Task Domain_type() var pgType = domainSchema.PostgresType; Assert.That(pgType, Is.InstanceOf()); Assert.That(((PostgresDomainType)pgType).BaseType.Name, Is.EqualTo("character varying")); + // For domains we should return the underlying type + Assert.That(domainSchema.NpgsqlDbType, Is.EqualTo(NpgsqlTypes.NpgsqlDbType.Varchar)); } [Test] @@ -757,9 +773,9 @@ public async Task GetColumnSchema_via_interface() var iface = (IDbColumnSchemaGenerator)reader; var schema = iface.GetColumnSchema(); - Assert.NotNull(schema); - Assert.AreEqual(1, schema.Count); - Assert.NotNull(schema[0]); + Assert.That(schema, Is.Not.Null); + Assert.That(schema.Count, Is.EqualTo(1)); + Assert.That(schema[0], Is.Not.Null); } #region Not supported @@ -793,8 +809,6 @@ class SomeComposite public int Foo { get; set; } } - public ReaderNewSchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } - - async Task> GetColumnSchema(NpgsqlDataReader reader) - => IsAsync ? await reader.GetColumnSchemaAsync() : reader.GetColumnSchema(); + async Task> GetColumnSchema(NpgsqlDataReader reader) + => IsAsync ? 
(await reader.GetColumnSchemaAsync(CancellationToken.None)).Cast().ToArray() : reader.GetColumnSchema(); } diff --git a/test/Npgsql.Tests/ReaderOldSchemaTests.cs b/test/Npgsql.Tests/ReaderOldSchemaTests.cs index edbeb15842..72a6401468 100644 --- a/test/Npgsql.Tests/ReaderOldSchemaTests.cs +++ b/test/Npgsql.Tests/ReaderOldSchemaTests.cs @@ -11,7 +11,7 @@ namespace Npgsql.Tests; /// This tests the .NET Framework DbDataReader schema/metadata API, which returns DataTable. /// For the new CoreCLR API, see . /// -public class ReaderOldSchemaTests : SyncOrAsyncTestBase +public class ReaderOldSchemaTests(SyncOrAsync syncOrAsync) : SyncOrAsyncTestBase(syncOrAsync) { [Test] public async Task Primary_key_composite() @@ -55,7 +55,7 @@ public async Task Primary_key() public async Task IsAutoIncrement() { await using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Serial columns not supported on Redshift"); + await IgnoreOnRedshift(conn, "Serial columns not supported on Redshift"); var table = await CreateTempTable(conn, "serial SERIAL, int INT"); @@ -72,7 +72,7 @@ public async Task IsAutoIncrement() public async Task IsAutoIncrement_identity() { await using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Serial columns not supported on Redshift"); + await IgnoreOnRedshift(conn, "Serial columns not supported on Redshift"); MinimumPgVersion(conn, "10.0", "IDENTITY introduced in PostgreSQL 10"); var table = @@ -90,7 +90,7 @@ public async Task IsAutoIncrement_identity() public async Task IsIdentity() { await using var conn = await OpenConnectionAsync(); - IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); + await IgnoreOnRedshift(conn, "Identity columns not support on Redshift"); MinimumPgVersion(conn, "10.0", "IDENTITY introduced in PostgreSQL 10"); var table = await CreateTempTable( conn, @@ -124,12 +124,12 @@ await conn.ExecuteNonQueryAsync($@" var metadata = await GetSchemaTable(dr); var idRow = 
metadata!.Rows.OfType().FirstOrDefault(x => (string)x["ColumnName"] == "id"); - Assert.IsNotNull(idRow, "Unable to find metadata for id column"); + Assert.That(idRow, Is.Not.Null, "Unable to find metadata for id column"); var int2Row = metadata.Rows.OfType().FirstOrDefault(x => (string)x["ColumnName"] == "int2"); - Assert.IsNotNull(int2Row, "Unable to find metadata for int2 column"); + Assert.That(int2Row, Is.Not.Null, "Unable to find metadata for int2 column"); - Assert.IsFalse((bool)idRow!["IsReadonly"]); - Assert.IsTrue((bool)int2Row!["IsReadonly"]); + Assert.That((bool)idRow!["IsReadonly"], Is.False); + Assert.That((bool)int2Row!["IsReadonly"]); } // ReSharper disable once InconsistentNaming @@ -144,12 +144,12 @@ public async Task AllowDBNull() using var metadata = await GetSchemaTable(reader); var nullableRow = metadata!.Rows.OfType().FirstOrDefault(x => (string)x["ColumnName"] == "nullable"); - Assert.IsNotNull(nullableRow, "Unable to find metadata for nullable column"); + Assert.That(nullableRow, Is.Not.Null, "Unable to find metadata for nullable column"); var nonNullableRow = metadata.Rows.OfType().FirstOrDefault(x => (string)x["ColumnName"] == "non_nullable"); - Assert.IsNotNull(nonNullableRow, "Unable to find metadata for non_nullable column"); + Assert.That(nonNullableRow, Is.Not.Null, "Unable to find metadata for non_nullable column"); - Assert.IsTrue((bool)nullableRow!["AllowDBNull"]); - Assert.IsFalse((bool)nonNullableRow!["AllowDBNull"]); + Assert.That((bool)nullableRow!["AllowDBNull"]); + Assert.That((bool)nonNullableRow!["AllowDBNull"], Is.False); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1027")] @@ -181,9 +181,6 @@ public async Task Precision_and_scale() [Test] public async Task SchemaOnly([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - // if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - // return; - using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, 
"name TEXT"); @@ -240,7 +237,5 @@ CONSTRAINT PK_test_Cod PRIMARY KEY (Cod) Assert.That(dt.Rows[2]["ColumnName"].ToString(), Is.EqualTo("date")); } - public ReaderOldSchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } - async Task GetSchemaTable(NpgsqlDataReader dr) => IsAsync ? await dr.GetSchemaTableAsync() : dr.GetSchemaTable(); } diff --git a/test/Npgsql.Tests/ReaderTests.cs b/test/Npgsql.Tests/ReaderTests.cs index e37bea6b31..8a4484ef9b 100644 --- a/test/Npgsql.Tests/ReaderTests.cs +++ b/test/Npgsql.Tests/ReaderTests.cs @@ -20,11 +20,9 @@ namespace Npgsql.Tests; -[TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.Default)] -[TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.Default)] -[TestFixture(MultiplexingMode.NonMultiplexing, CommandBehavior.SequentialAccess)] -[TestFixture(MultiplexingMode.Multiplexing, CommandBehavior.SequentialAccess)] -public class ReaderTests : MultiplexingTestBase +[TestFixture(CommandBehavior.Default)] +[TestFixture(CommandBehavior.SequentialAccess)] +public class ReaderTests : TestBase { static uint Int4Oid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Int4).Value; static uint ByteaOid => PostgresMinimalDatabaseInfo.DefaultTypeCatalog.GetOid(DataTypeNames.Bytea).Value; @@ -237,7 +235,7 @@ public async Task Get_string_with_parameter() using var dr = await command.ExecuteReaderAsync(Behavior); dr.Read(); var result = dr.GetString(0); - Assert.AreEqual(text, result); + Assert.That(result, Is.EqualTo(text)); } [Test] @@ -263,7 +261,7 @@ await conn.ExecuteNonQueryAsync($@" using var dr = await command.ExecuteReaderAsync(Behavior); dr.Read(); var result = dr.GetString(0); - Assert.AreEqual(test, result); + Assert.That(result, Is.EqualTo(test)); } [Test] @@ -304,7 +302,7 @@ public async Task GetFieldType_SchemaOnly() { await using var conn = await OpenConnectionAsync(); await using var cmd = new NpgsqlCommand(@"SELECT 1::INT4 AS some_column", conn); - await using var reader = await 
cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); + await using var reader = await cmd.ExecuteReaderAsync(Behavior | CommandBehavior.SchemaOnly); reader.Read(); Assert.That(reader.GetFieldType(0), Is.SameAs(typeof(int))); } @@ -312,9 +310,6 @@ public async Task GetFieldType_SchemaOnly() [Test] public async Task GetPostgresType() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: Fails"); - using var conn = await OpenConnectionAsync(); PostgresType intType; using (var cmd = new NpgsqlCommand(@"SELECT 1::INTEGER AS some_column", conn)) @@ -386,7 +381,6 @@ public async Task GetDataTypeName_enum() await using var conn = await dataSource.OpenConnectionAsync(); var typeName = await GetTempTypeName(conn); await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS ENUM ('one')"); - await Task.Yield(); // TODO: fix multiplexing deadlock bug conn.ReloadTypes(); await using var cmd = new NpgsqlCommand($"SELECT 'one'::{typeName}", conn); await using var reader = await cmd.ExecuteReaderAsync(Behavior); @@ -401,7 +395,6 @@ public async Task GetDataTypeName_domain() await using var conn = await dataSource.OpenConnectionAsync(); var typeName = await GetTempTypeName(conn); await conn.ExecuteNonQueryAsync($"CREATE DOMAIN {typeName} AS VARCHAR(10)"); - await Task.Yield(); // TODO: fix multiplexing deadlock bug conn.ReloadTypes(); await using var cmd = new NpgsqlCommand($"SELECT 'one'::{typeName}", conn); await using var reader = await cmd.ExecuteReaderAsync(Behavior); @@ -475,7 +468,7 @@ public async Task GetValues() dr.Read(); var values = new object[4]; Assert.That(dr.GetValues(values), Is.EqualTo(3)); - Assert.That(values, Is.EqualTo(new object?[] { "hello", 1, new DateTime(2014, 1, 1), null })); + Assert.That(values, Is.EqualTo(new object?[] { "hello", 1, new DateOnly(2014, 1, 1), null })); } using (var dr = await command.ExecuteReaderAsync(Behavior)) { @@ -496,7 +489,7 @@ public async Task ExecuteReader_getting_empty_resultset_with_output_parameter() param.Direction = 
ParameterDirection.Output; command.Parameters.Add(param); using var dr = await command.ExecuteReaderAsync(Behavior); - Assert.IsFalse(dr.NextResult()); + Assert.That(dr.NextResult(), Is.False); } [Test] @@ -535,7 +528,7 @@ public async Task Read_past_reader_end() [Test] public async Task Reader_dispose_state_does_not_leak() { - if (IsMultiplexing || Behavior != CommandBehavior.Default) + if (Behavior != CommandBehavior.Default) return; var startReaderClosedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -582,9 +575,6 @@ public async Task SingleResult() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/400")] public async Task Exception_thrown_from_ExecuteReaderAsync([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); var function = await GetTempFunctionName(conn); @@ -603,9 +593,6 @@ await conn.ExecuteNonQueryAsync($@" [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1032")] public async Task Exception_thrown_from_NextResult([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); var function = await GetTempFunctionName(conn); @@ -623,9 +610,10 @@ await conn.ExecuteNonQueryAsync($@" } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/967")] - public async Task NpgsqlException_references_BatchCommand_with_single_command() + public async Task NpgsqlException_references_BatchCommand_with_single_command([Values] bool includeFailedBatchedCommand) { - await using var conn = await OpenConnectionAsync(); + await using var dataSource = CreateDataSource(x => x.IncludeFailedBatchedCommand = includeFailedBatchedCommand); + await using var conn = await dataSource.OpenConnectionAsync(); var function = await GetTempFunctionName(conn); 
await conn.ExecuteNonQueryAsync($@" @@ -638,19 +626,23 @@ await conn.ExecuteNonQueryAsync($@" cmd.CommandText = $"SELECT {function}()"; var exception = Assert.ThrowsAsync(() => cmd.ExecuteReaderAsync(Behavior))!; - Assert.That(exception.BatchCommand, Is.SameAs(cmd.InternalBatchCommands[0])); + if (includeFailedBatchedCommand) + Assert.That(exception.BatchCommand, Is.SameAs(cmd.InternalBatchCommands[0])); + else + Assert.That(exception.BatchCommand, Is.Null); // Make sure the command isn't recycled by the connection when it's disposed - this is important since internal command // resources are referenced by the exception above, which is very likely to escape the using statement of the command. cmd.Dispose(); var cmd2 = conn.CreateCommand(); - Assert.AreNotSame(cmd2, cmd); + Assert.That(cmd, Is.Not.SameAs(cmd2)); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/967")] - public async Task NpgsqlException_references_BatchCommand_with_multiple_commands() + public async Task NpgsqlException_references_BatchCommand_with_multiple_commands([Values] bool includeFailedBatchedCommand) { - await using var conn = await OpenConnectionAsync(); + await using var dataSource = CreateDataSource(x => x.IncludeFailedBatchedCommand = includeFailedBatchedCommand); + await using var conn = await dataSource.OpenConnectionAsync(); var function = await GetTempFunctionName(conn); await conn.ExecuteNonQueryAsync($@" @@ -665,14 +657,17 @@ await conn.ExecuteNonQueryAsync($@" await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { var exception = Assert.ThrowsAsync(() => reader.NextResultAsync())!; - Assert.That(exception.BatchCommand, Is.SameAs(cmd.InternalBatchCommands[1])); + if (includeFailedBatchedCommand) + Assert.That(exception.BatchCommand, Is.SameAs(cmd.InternalBatchCommands[1])); + else + Assert.That(exception.BatchCommand, Is.Null); } // Make sure the command isn't recycled by the connection when it's disposed - this is important since internal command // 
resources are referenced by the exception above, which is very likely to escape the using statement of the command. cmd.Dispose(); var cmd2 = conn.CreateCommand(); - Assert.AreNotSame(cmd2, cmd); + Assert.That(cmd, Is.Not.SameAs(cmd2)); } #region SchemaOnly @@ -694,8 +689,8 @@ public async Task SchemaOnly_next_result_beyond_end() using var cmd = new NpgsqlCommand($"SELECT * FROM {table}", conn); using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SchemaOnly); - Assert.False(reader.NextResult()); - Assert.False(reader.NextResult()); + Assert.That(reader.NextResult(), Is.False); + Assert.That(reader.NextResult(), Is.False); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4124")] @@ -759,8 +754,7 @@ public async Task Reader_is_still_open() { await using var conn = await OpenConnectionAsync(); // We might get the connection, on which the second command was already prepared, so prepare wouldn't start the UserAction - if (!IsMultiplexing) - conn.UnprepareAll(); + conn.UnprepareAll(); using var cmd1 = new NpgsqlCommand("SELECT 1", conn); await using var reader1 = await cmd1.ExecuteReaderAsync(Behavior); Assert.That(() => conn.ExecuteNonQuery("SELECT 1"), Throws.Exception.TypeOf()); @@ -768,16 +762,12 @@ public async Task Reader_is_still_open() using var cmd2 = new NpgsqlCommand("SELECT 2", conn); Assert.That(() => cmd2.ExecuteReader(Behavior), Throws.Exception.TypeOf()); - if (!IsMultiplexing) - Assert.That(() => cmd2.Prepare(), Throws.Exception.TypeOf()); + Assert.That(() => cmd2.Prepare(), Throws.Exception.TypeOf()); } [Test] public async Task Cleans_up_ok_with_dispose_calls([Values(PrepareOrNot.Prepared, PrepareOrNot.NotPrepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); using var command = new NpgsqlCommand("SELECT 1", conn); using var dr = await command.ExecuteReaderAsync(Behavior); @@ -809,6 +799,7 @@ public async Task Null() 
Assert.That(reader.GetFieldValue(i), Is.EqualTo(DBNull.Value)); Assert.That(reader.GetProviderSpecificValue(i), Is.EqualTo(DBNull.Value)); Assert.That(() => reader.GetString(i), Throws.Exception.TypeOf()); + Assert.That(() => reader.GetStream(i), Throws.Exception.TypeOf()); } } @@ -819,9 +810,6 @@ public async Task Null() [IssueLink("https://github.com/npgsql/npgsql/issues/1898")] public async Task HasRows([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "name TEXT"); @@ -879,7 +867,7 @@ public async Task HasRows_without_resultset() var table = await CreateTempTable(conn, "name TEXT"); using var command = new NpgsqlCommand($"DELETE FROM {table} WHERE name = 'unknown'", conn); using var reader = await command.ExecuteReaderAsync(Behavior); - Assert.IsFalse(reader.HasRows); + Assert.That(reader.HasRows, Is.False); } [Test] @@ -888,9 +876,9 @@ public async Task Interval_as_TimeSpan() using var conn = await OpenConnectionAsync(); using var command = new NpgsqlCommand("SELECT CAST('1 hour' AS interval) AS dauer", conn); using var dr = await command.ExecuteReaderAsync(Behavior); - Assert.IsTrue(dr.HasRows); - Assert.IsTrue(dr.Read()); - Assert.IsTrue(dr.HasRows); + Assert.That(dr.HasRows); + Assert.That(dr.Read()); + Assert.That(dr.HasRows); var ts = dr.GetTimeSpan(0); } @@ -946,7 +934,7 @@ public async Task SequentialBufferedSeekReread() //_ = rdr[5]; // uncomment lines for successful execution _ = rdr.IsDBNull(6); _ = rdr[6]; - Assert.True(rdr.IsDBNull(6)); + Assert.That(rdr.IsDBNull(6)); } } @@ -979,7 +967,7 @@ await pgMock .WriteCommandComplete() .WriteReadyForQuery() .FlushAsync(); - Assert.AreEqual(expected, await task); + Assert.That(await task, Is.EqualTo(expected)); } } @@ -1096,10 +1084,6 @@ public async Task Nullable_scalar() [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/2913")] public async Task Bug2913_reading_previous_query_messages() { - // No point in testing for multiplexing, as every query may use another connection - if (IsMultiplexing) - return; - var firstMrs = new ManualResetEventSlim(false); var secondMrs = new ManualResetEventSlim(false); @@ -1292,8 +1276,8 @@ public async Task Bug3772() reader.GetInt32(0); - Assert.Zero(reader.Connector.ReadBuffer.ReadBytesLeft); - Assert.NotZero(reader.Connector.ReadBuffer.ReadPosition); + Assert.That(reader.Connector.ReadBuffer.ReadBytesLeft, Is.Zero); + Assert.That(reader.Connector.ReadBuffer.ReadPosition, Is.Not.Zero); writeBuffer.WriteInt32(byteValue.Length); writeBuffer.WriteBytes(byteValue); @@ -1314,14 +1298,8 @@ public async Task Dispose_does_not_swallow_exceptions([Values(true, false)] bool await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); - await using var tx = IsMultiplexing ? 
await conn.BeginTransactionAsync() : null; var pgMock = await postmasterMock.WaitForServerConnection(); - if (IsMultiplexing) - pgMock - .WriteEmptyQueryResponse() - .WriteReadyForQuery(TransactionStatus.InTransactionBlock); - // Write responses for the query, but break the connection before sending CommandComplete/ReadyForQuery await pgMock .WriteParseComplete() @@ -1351,7 +1329,7 @@ public async Task Read_string_as_char() cmd.CommandText = "SELECT 'abcdefgh', 'ijklmnop'"; await using var reader = await cmd.ExecuteReaderAsync(Behavior); - Assert.IsTrue(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); Assert.That(reader.GetChar(0), Is.EqualTo('a')); if (Behavior == CommandBehavior.SequentialAccess) Assert.Throws(() => reader.GetChar(0)); @@ -1369,7 +1347,7 @@ public async Task GetBytes() var table = await CreateTempTable(conn, "bytes BYTEA"); // TODO: This is too small to actually test any interesting sequential behavior - byte[] expected = { 1, 2, 3, 4, 5 }; + byte[] expected = [1, 2, 3, 4, 5]; var actual = new byte[expected.Length]; await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (bytes) VALUES ({EncodeByteaHex(expected)})"); @@ -1400,7 +1378,7 @@ public async Task GetBytes() reader.GetBytes(4, 0, actual, 0, 2); Assert.That(reader.GetBytes(4, expected.Length - 1, actual, 0, 2), Is.EqualTo(1), "Length greater than data length"); - Assert.That(actual[0], Is.EqualTo(expected[expected.Length - 1]), "Length greater than data length"); + Assert.That(actual[0], Is.EqualTo(expected[^1]), "Length greater than data length"); Assert.That(() => reader.GetBytes(4, 0, actual, 0, actual.Length + 1), Throws.Exception.TypeOf(), "Length great than output buffer length"); // Close in the middle of a column @@ -1428,6 +1406,25 @@ public async Task GetStream_second_time_throws([Values(true, false)] bool isAsyn Throws.Exception.TypeOf()); } + [Test] + public async Task GetBytes_before_getstream([Values(true, false)] bool isAsync) + { + var expected = new 
byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }; + var streamGetter = BuildStreamGetter(isAsync); + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand($"SELECT {EncodeByteaHex(expected)}::bytea", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + + await reader.ReadAsync(); + + // GetBytes with null buffer won't consume column in any way + Assert.That(reader.GetBytes(0, 0, null, 0, 0), Is.EqualTo(expected.Length), "Bad column length"); + + using var stream = await streamGetter(reader, 0); + Assert.That(stream.Length, Is.EqualTo(expected.Length)); + } + public static IEnumerable GetStreamCases() { var binary = MemoryMarshal @@ -1590,7 +1587,7 @@ public async Task GetStream_seek() var buffer = new byte[4]; await using var stream = reader.GetStream(0); - Assert.IsTrue(stream.CanSeek); + Assert.That(stream.CanSeek); var seekPosition = stream.Seek(-1, SeekOrigin.End); Assert.That(seekPosition, Is.EqualTo(stream.Length - 1)); @@ -1662,12 +1659,36 @@ public async Task GetChars() // Jump to another column from the middle of the column reader.GetChars(5, 0, actual, 0, 2); Assert.That(reader.GetChars(5, expected.Length - 1, actual, 0, 2), Is.EqualTo(1), "Length greater than data length"); - Assert.That(actual[0], Is.EqualTo(expected[expected.Length - 1]), "Length greater than data length"); + Assert.That(actual[0], Is.EqualTo(expected[^1]), "Length greater than data length"); Assert.That(() => reader.GetChars(5, 0, actual, 0, actual.Length + 1), Throws.Exception.TypeOf(), "Length great than output buffer length"); // Close in the middle of a column reader.GetChars(6, 0, actual, 0, 2); } + [Test] + public async Task GetChars_AdvanceConsumed() + { + const string value = "01234567"; + + using var conn = await OpenConnectionAsync(); + using var cmd = new NpgsqlCommand($"SELECT '{value}'", conn); + using var reader = await cmd.ExecuteReaderAsync(Behavior); + reader.Read(); + + var buffer = new char[2]; + // Don't start at the beginning of the 
column. + reader.GetChars(0, 2, buffer, 0, 2); + reader.GetChars(0, 4, buffer, 0, 2); + reader.GetChars(0, 6, buffer, 0, 2); + + // Ask for data past the start and the previous point, exercising restart logic. + if (!IsSequential) + { + reader.GetChars(0, 4, buffer, 0, 2); + reader.GetChars(0, 6, buffer, 0, 2); + } + } + [Test] public async Task GetTextReader([Values(true, false)] bool isAsync) { @@ -1719,7 +1740,7 @@ public async Task TextReader_zero_length_column() cmd.CommandText = "SELECT ''"; await using var reader = await cmd.ExecuteReaderAsync(Behavior); - Assert.IsTrue(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); using var textReader = reader.GetTextReader(0); Assert.That(textReader.Peek(), Is.EqualTo(-1)); @@ -1768,9 +1789,6 @@ public async Task GetChars_when_null() [Test] public async Task Reader_is_reused() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: Fails"); - using var conn = await OpenConnectionAsync(); NpgsqlDataReader reader1; @@ -1833,9 +1851,6 @@ public async Task GetTextReader_in_middle_of_column_throws([Values] bool async) [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5450")] public async Task EndRead_StreamActive([Values]bool async) { - if (IsMultiplexing) - return; - const int columnLength = 1; await using var conn = await OpenConnectionAsync(); @@ -1856,7 +1871,7 @@ public async Task EndRead_StreamActive([Values]bool async) Assert.DoesNotThrow(() => reader.EndRead()); } - reader.Commit(resuming: false); + reader.Commit(); } [Test, Description("Tests that everything goes well when a type handler generates a NpgsqlSafeReadException")] @@ -1898,9 +1913,6 @@ public async Task Non_SafeReadException() [Test, Description("Cancels ReadAsync via the NpgsqlCommand.Cancel, with successful PG cancellation")] public async Task ReadAsync_cancel_command_soft() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using 
var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -1918,7 +1930,7 @@ await pgMock await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { // Successfully read the first row - Assert.True(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); Assert.That(reader.GetInt32(0), Is.EqualTo(1)); // Attempt to read the second row - simulate blocking and cancellation @@ -1947,9 +1959,6 @@ await pgMock [Test, Description("Cancels ReadAsync via the cancellation token, with successful PG cancellation")] public async Task ReadAsync_cancel_soft() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -1967,7 +1976,7 @@ await pgMock await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { // Successfully read the first row - Assert.True(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); Assert.That(reader.GetInt32(0), Is.EqualTo(1)); // Attempt to read the second row - simulate blocking and cancellation @@ -1998,9 +2007,6 @@ await pgMock [Test, Description("Cancels NextResultAsync via the cancellation token, with successful PG cancellation")] public async Task NextResult_cancel_soft() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -2019,7 +2025,7 @@ await pgMock await using (var reader = await cmd.ExecuteReaderAsync(Behavior)) { // Successfully read the first resultset - Assert.True(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); Assert.That(reader.GetInt32(0), 
Is.EqualTo(1)); // Attempt to advance to the second resultset - simulate blocking and cancellation @@ -2050,9 +2056,6 @@ await pgMock [Test, Description("Cancels ReadAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] public async Task ReadAsync_cancel_hard([Values(true, false)] bool passCancelledToken) { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -2070,7 +2073,7 @@ await pgMock await using var reader = await cmd.ExecuteReaderAsync(Behavior); // Successfully read the first row - Assert.True(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); Assert.That(reader.GetInt32(0), Is.EqualTo(1)); // Attempt to read the second row - simulate blocking and cancellation @@ -2094,9 +2097,6 @@ await pgMock [Test, Description("Cancels NextResultAsync via the cancellation token, with unsuccessful PG cancellation (socket break)")] public async Task NextResultAsync_cancel_hard([Values(true, false)] bool passCancelledToken) { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -2115,7 +2115,7 @@ await pgMock await using var reader = await cmd.ExecuteReaderAsync(Behavior); // Successfully read the first resultset - Assert.True(await reader.ReadAsync()); + Assert.That(await reader.ReadAsync()); Assert.That(reader.GetInt32(0), Is.EqualTo(1)); // Attempt to read the second row - simulate blocking and cancellation @@ -2139,9 +2139,6 @@ await pgMock [Test, Description("Cancels sequential ReadAsGetFieldValueAsync")] public async Task 
GetFieldValueAsync_sequential_cancel([Values(true, false)] bool passCancelledToken) { - if (IsMultiplexing) - return; // Multiplexing, cancellation - if (!IsSequential) return; @@ -2177,9 +2174,6 @@ await pgMock [Test, Description("Cancels sequential ReadAsGetFieldValueAsync")] public async Task IsDBNullAsync_sequential_cancel([Values(true, false)] bool passCancelledToken) { - if (IsMultiplexing) - return; // Multiplexing, cancellation - if (!IsSequential) return; @@ -2212,24 +2206,6 @@ await pgMock Assert.That(conn.FullState, Is.EqualTo(ConnectionState.Broken)); } - [Test, Description("Cancellation does not work with the multiplexing")] - public async Task Cancel_multiplexing_disabled() - { - if (!IsMultiplexing) - return; - - await using var dataSource = CreateDataSource(); - await using var conn = await dataSource.OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT generate_series(1, 100); SELECT generate_series(1, 100)", conn); - await using var reader = await cmd.ExecuteReaderAsync(Behavior); - var cancelledToken = new CancellationToken(canceled: true); - Assert.IsTrue(await reader.ReadAsync()); - while (await reader.ReadAsync(cancelledToken)) { } - Assert.IsTrue(await reader.NextResultAsync(cancelledToken)); - while (await reader.ReadAsync(cancelledToken)) { } - Assert.IsFalse(conn.Connector!.UserCancellationRequested); - } - #endregion Cancellation #region Timeout @@ -2237,9 +2213,6 @@ public async Task Cancel_multiplexing_disabled() [Test, Description("Timeouts sequential ReadAsGetFieldValueAsync")] public async Task GetFieldValueAsync_sequential_timeout() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - if (!IsSequential) return; @@ -2277,9 +2250,6 @@ await pgMock [Test, Description("Timeouts sequential IsDBNullAsync")] public async Task IsDBNullAsync_sequential_timeout() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - if (!IsSequential) return; @@ -2317,9 +2287,6 @@ await pgMock [Test, 
IssueLink("https://github.com/npgsql/npgsql/issues/3446")] public async Task Bug3446() { - if (IsMultiplexing) - return; // Multiplexing, cancellation - await using var postmasterMock = PgPostmasterMock.Start(ConnectionString); await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); await using var conn = await dataSource.OpenConnectionAsync(); @@ -2347,6 +2314,41 @@ await pgMock Assert.That(conn.Connector!.State, Is.EqualTo(ConnectorState.Ready)); } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6160")] + [Description("Consuming result set shouldn't go infinite in case connection is broken")] + public async Task Bug6160() + { + var csb = new NpgsqlConnectionStringBuilder(ConnectionString) + { + // Set to -1 to trigger immediate connection break on timeout + CancellationTimeout = -1, + CommandTimeout = 1 + }; + await using var postmasterMock = PgPostmasterMock.Start(csb.ConnectionString); + await using var dataSource = CreateDataSource(postmasterMock.ConnectionString); + await using var conn = await dataSource.OpenConnectionAsync(); + + var pgMock = await postmasterMock.WaitForServerConnection(); + await pgMock + .WriteParseComplete() + .WriteBindComplete() + .WriteRowDescription(new FieldDescription(Int4Oid)) + .WriteDataRow(new byte[4]) + .FlushAsync(); + + await using var cmd = new NpgsqlCommand("SELECT 1", conn); + await using (var reader = await cmd.ExecuteReaderAsync(Behavior | CommandBehavior.SingleRow)) + { + await reader.ReadAsync(); + // The second read will try to consume the whole resultset due to CommandBehavior.SingleRow + // Which will fail with timeout (and immediate connection break) since we didn't send anything else beside the first row + var ex = Assert.ThrowsAsync(async () => await reader.ReadAsync())!; + Assert.That(ex.InnerException, Is.TypeOf()); + + Assert.That(conn.State, Is.EqualTo(ConnectionState.Closed)); + } + } + #endregion #region Initialization / setup / teardown @@ -2356,7 +2358,7 @@ await 
pgMock readonly CommandBehavior Behavior; // ReSharper restore InconsistentNaming - public ReaderTests(MultiplexingMode multiplexingMode, CommandBehavior behavior) : base(multiplexingMode) + public ReaderTests(CommandBehavior behavior) { Behavior = behavior; IsSequential = (Behavior & CommandBehavior.SequentialAccess) != 0; @@ -2367,23 +2369,17 @@ public ReaderTests(MultiplexingMode multiplexingMode, CommandBehavior behavior) #region Mock Type Handlers -sealed class ExplodingTypeHandlerResolverFactory : PgTypeInfoResolverFactory +sealed class ExplodingTypeHandlerResolverFactory(bool safe) : PgTypeInfoResolverFactory { - readonly bool _safe; - public ExplodingTypeHandlerResolverFactory(bool safe) => _safe = safe; - - public override IPgTypeInfoResolver CreateResolver() => new Resolver(_safe); + public override IPgTypeInfoResolver CreateResolver() => new Resolver(safe); public override IPgTypeInfoResolver? CreateArrayResolver() => null; - sealed class Resolver : IPgTypeInfoResolver + sealed class Resolver(bool safe) : IPgTypeInfoResolver { - readonly bool _safe; - public Resolver(bool safe) => _safe = safe; - public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) { if (dataTypeName == DataTypeNames.Int4 && (type == typeof(int) || type is null)) - return new(options, new ExplodingTypeHandler(_safe), DataTypeNames.Int4); + return new(options, new ExplodingTypeHandler(safe), DataTypeNames.Int4); return null; } diff --git a/test/Npgsql.Tests/Replication/CommonLogicalReplicationTests.cs b/test/Npgsql.Tests/Replication/CommonLogicalReplicationTests.cs index a8a363a583..cb434edd8a 100644 --- a/test/Npgsql.Tests/Replication/CommonLogicalReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/CommonLogicalReplicationTests.cs @@ -16,7 +16,6 @@ namespace Npgsql.Tests.Replication; /// for the individual logical replication tests, they are in fact not, because /// the methods they test are extension points for plugin developers. 
/// -[Platform(Exclude = "MacOsX", Reason = "Replication tests are flaky in CI on Mac")] [NonParallelizable] public class CommonLogicalReplicationTests : SafeReplicationTestBase { diff --git a/test/Npgsql.Tests/Replication/CommonReplicationTests.cs b/test/Npgsql.Tests/Replication/CommonReplicationTests.cs index 36a11b434a..2be1f3faff 100644 --- a/test/Npgsql.Tests/Replication/CommonReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/CommonReplicationTests.cs @@ -14,7 +14,6 @@ namespace Npgsql.Tests.Replication; [TestFixture(typeof(LogicalReplicationConnection))] [TestFixture(typeof(PhysicalReplicationConnection))] -[Platform(Exclude = "MacOsX", Reason = "Replication tests are flaky in CI on Mac")] [NonParallelizable] public class CommonReplicationTests : SafeReplicationTestBase where TConnection : ReplicationConnection, new() @@ -432,7 +431,7 @@ async Task GetCommitLsn(string valueString) // NpgsqlLogicalReplicationConnection // Begin Transaction, Insert, Commit Transaction for (var i = 0; i < 3; i++) - Assert.True(await messages.MoveNextAsync()); + Assert.That(await messages.MoveNextAsync()); return messages.Current.Lsn; } diff --git a/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs b/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs index e3d81a63f5..a9a90842d3 100644 --- a/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/PgOutputReplicationTests.cs @@ -10,54 +10,44 @@ using Npgsql.Replication; using Npgsql.Replication.PgOutput; using Npgsql.Replication.PgOutput.Messages; +using Npgsql.Util; using TruncateOptions = Npgsql.Replication.PgOutput.Messages.TruncateMessage.TruncateOptions; using ReplicaIdentitySetting = Npgsql.Replication.PgOutput.Messages.RelationMessage.ReplicaIdentitySetting; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests.Replication; -[TestFixture(ProtocolVersion.V1, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] 
-[TestFixture(ProtocolVersion.V1, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.DefaultTransactionMode)] -[TestFixture(ProtocolVersion.V2, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.StreamingTransactionMode)] -[TestFixture(ProtocolVersion.V3, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] -[TestFixture(ProtocolVersion.V3, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.StreamingTransactionMode)] -// We currently don't execute all possible combinations of settings for efficiency reasons because they don't -// interact in the current implementation. -// Feel free to uncomment some or all of the following lines if the implementation changed or you suspect a -// problem with some combination. -// [TestFixture(ProtocolVersion.V1, ReplicationDataMode.TextReplicationDataMode, TransactionMode.NonStreamingTransactionMode)] -// [TestFixture(ProtocolVersion.V2, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] -// [TestFixture(ProtocolVersion.V2, ReplicationDataMode.TextReplicationDataMode, TransactionMode.NonStreamingTransactionMode)] -// [TestFixture(ProtocolVersion.V2, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.DefaultTransactionMode)] -// [TestFixture(ProtocolVersion.V2, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.StreamingTransactionMode)] -// [TestFixture(ProtocolVersion.V3, ReplicationDataMode.TextReplicationDataMode, TransactionMode.NonStreamingTransactionMode)] -// [TestFixture(ProtocolVersion.V3, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.DefaultTransactionMode)] -// [TestFixture(ProtocolVersion.V3, ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.StreamingTransactionMode)] +[TestFixture(PgOutputProtocolVersion.V1, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] +[TestFixture(PgOutputProtocolVersion.V1, 
ReplicationDataMode.BinaryReplicationDataMode, TransactionMode.DefaultTransactionMode)] +[TestFixture(PgOutputProtocolVersion.V2, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.StreamingTransactionMode)] +[TestFixture(PgOutputProtocolVersion.V3, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] +[TestFixture(PgOutputProtocolVersion.V3, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.StreamingTransactionMode)] +[TestFixture(PgOutputProtocolVersion.V4, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.DefaultTransactionMode)] +[TestFixture(PgOutputProtocolVersion.V4, ReplicationDataMode.DefaultReplicationDataMode, TransactionMode.ParallelStreamingTransactionMode)] [NonParallelizable] // These tests aren't designed to be parallelizable -public class PgOutputReplicationTests : SafeReplicationTestBase +public class PgOutputReplicationTests( + PgOutputProtocolVersion protocolVersion, + PgOutputReplicationTests.ReplicationDataMode dataMode, + PgOutputReplicationTests.TransactionMode transactionMode) + : SafeReplicationTestBase { - readonly ulong _protocolVersion; - readonly bool? _binary; - readonly bool? _streaming; + readonly bool? _binary = dataMode == ReplicationDataMode.BinaryReplicationDataMode + ? true + : dataMode == ReplicationDataMode.TextReplicationDataMode + ? false + : null; + readonly PgOutputStreamingMode? _streamingMode = transactionMode switch + { + TransactionMode.DefaultTransactionMode => null, + TransactionMode.NonStreamingTransactionMode => PgOutputStreamingMode.Off, + TransactionMode.StreamingTransactionMode => PgOutputStreamingMode.On, + TransactionMode.ParallelStreamingTransactionMode => PgOutputStreamingMode.Parallel, + _ => throw new ArgumentOutOfRangeException(nameof(transactionMode), transactionMode, null) + }; bool IsBinary => _binary ?? false; - bool IsStreaming => _streaming ?? 
false; - ulong Version => _protocolVersion; - - public PgOutputReplicationTests(ProtocolVersion protocolVersion, ReplicationDataMode dataMode, TransactionMode transactionMode) - { - _protocolVersion = (ulong)protocolVersion; - _binary = dataMode == ReplicationDataMode.BinaryReplicationDataMode - ? true - : dataMode == ReplicationDataMode.TextReplicationDataMode - ? false - : null; - _streaming = transactionMode == TransactionMode.StreamingTransactionMode - ? true - : transactionMode == TransactionMode.NonStreamingTransactionMode - ? false - : null; - } + bool IsStreaming => _streamingMode.HasValue && _streamingMode.Value != PgOutputStreamingMode.Off; + PgOutputProtocolVersion Version => protocolVersion; [Test] public Task CreatePgOutputReplicationSlot() @@ -125,12 +115,27 @@ public Task Insert() Assert.That(insertMsg.Relation, Is.SameAs(relationMsg)); var columnEnumerator = insertMsg.NewRow.GetAsyncEnumerator(); Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + var postgresType = columnEnumerator.Current.GetPostgresType(); + Assert.That(postgresType.FullName, Is.EqualTo("pg_catalog.integer")); + Assert.That(columnEnumerator.Current.GetDataTypeName(), Is.EqualTo("integer")); + Assert.That(columnEnumerator.Current.GetFieldName(), Is.EqualTo("id")); if (IsBinary) + { + Assert.That(columnEnumerator.Current.GetFieldType(), Is.EqualTo(typeof(int))); Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo(1)); + } else + { + Assert.That(columnEnumerator.Current.GetFieldType(), Is.EqualTo(typeof(string))); Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("1")); + } Assert.That(await columnEnumerator.MoveNextAsync(), Is.True); + postgresType = columnEnumerator.Current.GetPostgresType(); + Assert.That(postgresType.FullName, Is.EqualTo("pg_catalog.text")); + Assert.That(columnEnumerator.Current.GetDataTypeName(), Is.EqualTo("text")); + Assert.That(columnEnumerator.Current.GetFieldType(), Is.EqualTo(typeof(string))); + 
Assert.That(columnEnumerator.Current.GetFieldName(), Is.EqualTo("name")); Assert.That(columnEnumerator.Current.IsDBNull, Is.False); Assert.That(await columnEnumerator.Current.Get(), Is.EqualTo("val1")); Assert.That(await columnEnumerator.MoveNextAsync(), Is.False); @@ -640,7 +645,6 @@ await c.ExecuteNonQueryAsync(@$" await NextMessage(messages); }, nameof(Dispose_while_replicating)); - [Platform(Exclude = "MacOsX", Reason = "Test is flaky in CI on Mac, see https://github.com/npgsql/npgsql/issues/5294")] [TestCase(true, true)] [TestCase(true, false)] [TestCase(false, false)] @@ -724,74 +728,101 @@ public Task LogicalDecodingMessage(bool writeMessages, bool readMessages) } } - if (IsStreaming) + // PostgreSQL 18 skips logical decoding of already-aborted transactions + if (c.PostgreSqlVersion.IsGreaterOrEqual(18)) { - // Begin Transaction 2 - transactionXid = await AssertTransactionStart(messages); - - // Relation - await NextMessage(messages); - - // Inserts - for (var insertCount = 0; insertCount < 10; insertCount++) - await NextMessage(messages); - - // LogicalDecodingMessage 2 (transactional) + // LogicalDecodingMessage 2 (non-transactional) if (writeMessages) { var msg = await NextMessage(messages); - Assert.That(msg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); - Assert.That(msg.Flags, Is.EqualTo(1)); + Assert.That(msg.TransactionXid, Is.Null); + Assert.That(msg.Flags, Is.EqualTo(0)); Assert.That(msg.Prefix, Is.EqualTo(prefix)); - Assert.That(msg.Data.Length, Is.EqualTo(transactionalMessage.Length)); + Assert.That(msg.Data.Length, Is.EqualTo(nonTransactionalMessage.Length)); if (readMessages) { var buffer = new MemoryStream(); await msg.Data.CopyToAsync(buffer, CancellationToken.None); - Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(transactionalMessage)); + Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(nonTransactionalMessage)); } } - - // Further inserts - // We don't try to predict how many insert messages we get here - // since the streaming transaction will most likely abort before - // we reach the expected number - while (await messages.MoveNextAsync() && messages.Current is InsertMessage - || messages.Current is StreamStopMessage - && await messages.MoveNextAsync() - && messages.Current is StreamStartMessage - && await messages.MoveNextAsync() - && messages.Current is InsertMessage) + } + else + { + if (IsStreaming) { - // Ignore + // Begin Transaction 2 + transactionXid = await AssertTransactionStart(messages); + + // Relation + await NextMessage(messages); + + // Inserts + for (var insertCount = 0; insertCount < 10; insertCount++) + await NextMessage(messages); + + // LogicalDecodingMessage 2 (transactional) + if (writeMessages) + { + var msg = await NextMessage(messages); + Assert.That(msg.TransactionXid, IsStreaming ? 
Is.EqualTo(transactionXid) : Is.Null); + Assert.That(msg.Flags, Is.EqualTo(1)); + Assert.That(msg.Prefix, Is.EqualTo(prefix)); + Assert.That(msg.Data.Length, Is.EqualTo(transactionalMessage.Length)); + if (readMessages) + { + var buffer = new MemoryStream(); + await msg.Data.CopyToAsync(buffer, CancellationToken.None); + Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(transactionalMessage)); + } + } + + // Further inserts + // We don't try to predict how many insert messages we get here + // since the streaming transaction will most likely abort before + // we reach the expected number + while (await messages.MoveNextAsync() && messages.Current is InsertMessage + || messages.Current is StreamStopMessage + && await messages.MoveNextAsync() + && messages.Current is StreamStartMessage + && await messages.MoveNextAsync() + && messages.Current is InsertMessage) + { + // Ignore + } } - } - else if (writeMessages) - await messages.MoveNextAsync(); + else if (writeMessages) + await messages.MoveNextAsync(); - // LogicalDecodingMessage 3 (non-transactional) - if (writeMessages) - { - var msg = (LogicalDecodingMessage)messages.Current; - Assert.That(msg.TransactionXid, Is.Null); - Assert.That(msg.Flags, Is.EqualTo(0)); - Assert.That(msg.Prefix, Is.EqualTo(prefix)); - Assert.That(msg.Data.Length, Is.EqualTo(nonTransactionalMessage.Length)); - if (readMessages) + // LogicalDecodingMessage 3 (non-transactional) + if (writeMessages) { - var buffer = new MemoryStream(); - await msg.Data.CopyToAsync(buffer, CancellationToken.None); - Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(nonTransactionalMessage)); + var msg = (LogicalDecodingMessage)messages.Current; + Assert.That(msg.TransactionXid, Is.Null); + Assert.That(msg.Flags, Is.EqualTo(0)); + Assert.That(msg.Prefix, Is.EqualTo(prefix)); + Assert.That(msg.Data.Length, Is.EqualTo(nonTransactionalMessage.Length)); + if (readMessages) + { + var buffer = new MemoryStream(); + await 
msg.Data.CopyToAsync(buffer, CancellationToken.None); + Assert.That(rc.Encoding.GetString(buffer.ToArray()), Is.EqualTo(nonTransactionalMessage)); + } + + if (IsStreaming) + await messages.MoveNextAsync(); } + // Rollback Transaction 2 if (IsStreaming) - await messages.MoveNextAsync(); + { + Assert.That(messages.Current, + _streamingMode == PgOutputStreamingMode.On + ? Is.TypeOf() + : Is.TypeOf()); + } } - // Rollback Transaction 2 - if (IsStreaming) - Assert.That(messages.Current, Is.TypeOf()); - streamingCts.Cancel(); await AssertReplicationCancellation(messages); await rc.DropReplicationSlot(slotName, cancellationToken: CancellationToken.None); @@ -1090,7 +1121,7 @@ public Task TwoPhase([Values]bool commit) { // Streaming of prepared transaction is only supported for // logical streaming replication protocol >= 3 - if (_protocolVersion < 3UL) + if (protocolVersion < PgOutputProtocolVersion.V3) return Task.CompletedTask; return SafePgOutputReplicationTest( @@ -1170,7 +1201,7 @@ public Task TwoPhase([Values]bool commit) public Task Bug4633() { // We don't need all the various test cases here since the bug gets triggered in any case - if (IsStreaming || IsBinary || Version > 1) + if (IsStreaming || IsBinary || Version > PgOutputProtocolVersion.V1) return Task.CompletedTask; return SafePgOutputReplicationTest( @@ -1349,7 +1380,7 @@ await c.ExecuteNonQueryAsync($""" async Task AssertTransactionStart(IAsyncEnumerator messages) { - Assert.True(await messages.MoveNextAsync()); + Assert.That(await messages.MoveNextAsync()); switch (messages.Current) { @@ -1370,13 +1401,13 @@ await c.ExecuteNonQueryAsync($""" async Task AssertTransactionCommit(IAsyncEnumerator messages) { - Assert.True(await messages.MoveNextAsync()); + Assert.That(await messages.MoveNextAsync()); switch (messages.Current) { case StreamStopMessage: Assert.That(IsStreaming); - Assert.True(await messages.MoveNextAsync()); + Assert.That(await messages.MoveNextAsync()); Assert.That(messages.Current, 
Is.TypeOf()); return; case CommitMessage: @@ -1389,10 +1420,10 @@ async Task AssertTransactionCommit(IAsyncEnumerator async Task AssertPrepare(IAsyncEnumerator enumerator) { - Assert.True(await enumerator.MoveNextAsync()); + Assert.That(await enumerator.MoveNextAsync()); if (IsStreaming && enumerator.Current is StreamStopMessage) { - Assert.True(await enumerator.MoveNextAsync()); + Assert.That(await enumerator.MoveNextAsync()); Assert.That(enumerator.Current, Is.TypeOf()); return (PrepareMessageBase)enumerator.Current!; } @@ -1404,16 +1435,16 @@ async Task AssertPrepare(IAsyncEnumerator NextMessage(IAsyncEnumerator enumerator, bool expectRelationMessage = false) where TExpected : PgOutputReplicationMessage { - Assert.True(await enumerator.MoveNextAsync()); + Assert.That(await enumerator.MoveNextAsync()); if (IsStreaming && enumerator.Current is StreamStopMessage) { - Assert.True(await enumerator.MoveNextAsync()); + Assert.That(await enumerator.MoveNextAsync()); Assert.That(enumerator.Current, Is.TypeOf()); - Assert.True(await enumerator.MoveNextAsync()); + Assert.That(await enumerator.MoveNextAsync()); if (expectRelationMessage) { Assert.That(enumerator.Current, Is.TypeOf()); - Assert.True(await enumerator.MoveNextAsync()); + Assert.That(await enumerator.MoveNextAsync()); } } @@ -1453,7 +1484,7 @@ async IAsyncEnumerable SkipEmptyTransactions(IAsyncE } PgOutputReplicationOptions GetOptions(string publicationName, bool? 
messages = null) - => new(publicationName, _protocolVersion, _binary, _streaming, messages); + => new(publicationName, protocolVersion, _binary, _streamingMode, messages); Task SafePgOutputReplicationTest(Func testAction, [CallerMemberName] string memberName = "") => SafeReplicationTest(testAction, GetObjectName(memberName)); @@ -1464,11 +1495,11 @@ Task SafePgOutputReplicationTest(Func testAction string GetObjectName(string memberName) { var sb = new StringBuilder(memberName) - .Append("_v").Append(_protocolVersion); + .Append("_v").Append(protocolVersion); if (_binary.HasValue) sb.Append("_b_").Append(BoolToChar(_binary.Value)); - if (_streaming.HasValue) - sb.Append("_s_").Append(BoolToChar(_streaming.Value)); + if (_streamingMode.HasValue) + sb.Append("_s_").Append(_streamingMode.Value); return sb.ToString(); } @@ -1483,15 +1514,25 @@ public async Task SetUp() { await using var c = await OpenConnectionAsync(); TestUtil.MinimumPgVersion(c, "10.0", "The Logical Replication Protocol (via pgoutput plugin) was introduced in PostgreSQL 10"); - if (_protocolVersion > 2) + if (protocolVersion > PgOutputProtocolVersion.V3) + TestUtil.MinimumPgVersion(c, "16.0", "Logical Streaming Replication Protocol version 4 was introduced in PostgreSQL 16"); + if (protocolVersion > PgOutputProtocolVersion.V2) TestUtil.MinimumPgVersion(c, "15.0", "Logical Streaming Replication Protocol version 3 was introduced in PostgreSQL 15"); - if (_protocolVersion > 1) + if (protocolVersion > PgOutputProtocolVersion.V1) TestUtil.MinimumPgVersion(c, "14.0", "Logical Streaming Replication Protocol version 2 was introduced in PostgreSQL 14"); if (IsBinary) TestUtil.MinimumPgVersion(c, "14.0", "Sending replication values in binary representation was introduced in PostgreSQL 14"); if (IsStreaming) { - TestUtil.MinimumPgVersion(c, "14.0", "Streaming of in-progress transactions was introduced in PostgreSQL 14"); + switch (_streamingMode) + { + case PgOutputStreamingMode.On: + 
TestUtil.MinimumPgVersion(c, "14.0", "Streaming of in-progress transactions was introduced in PostgreSQL 14"); + break; + case PgOutputStreamingMode.Parallel: + TestUtil.MinimumPgVersion(c, "16.0", "Parallel streaming of in-progress transactions was introduced in PostgreSQL 16"); + break; + } var logicalDecodingWorkMem = (string)(await c.ExecuteScalarAsync("SHOW logical_decoding_work_mem"))!; if (logicalDecodingWorkMem != "64kB") { @@ -1502,12 +1543,6 @@ public async Task SetUp() } } - public enum ProtocolVersion : ulong - { - V1 = 1UL, - V2 = 2UL, - V3 = 3UL, - } public enum ReplicationDataMode { DefaultReplicationDataMode, @@ -1519,6 +1554,7 @@ public enum TransactionMode DefaultTransactionMode, NonStreamingTransactionMode, StreamingTransactionMode, + ParallelStreamingTransactionMode } #endregion Non-Test stuff (helper methods, initialization, ennums, ...) diff --git a/test/Npgsql.Tests/Replication/PhysicalReplicationTests.cs b/test/Npgsql.Tests/Replication/PhysicalReplicationTests.cs index 59698b87ac..62c19451e9 100644 --- a/test/Npgsql.Tests/Replication/PhysicalReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/PhysicalReplicationTests.cs @@ -90,7 +90,7 @@ public Task Replication_with_slot() // other transactions possibly from system processes can // interfere here, inserting additional messages, but more // likely we'll get everything in one big chunk. - Assert.True(await messages.MoveNextAsync()); + Assert.That(await messages.MoveNextAsync()); var message = messages.Current; Assert.That(message.WalStart, Is.EqualTo(info.XLogPos)); Assert.That(message.WalEnd, Is.GreaterThan(message.WalStart)); @@ -128,7 +128,7 @@ public async Task Replication_without_slot() // other transactions possibly from system processes can // interfere here, inserting additional messages, but more // likely we'll get everything in one big chunk. 
- Assert.True(await messages.MoveNextAsync()); + Assert.That(await messages.MoveNextAsync()); var message = messages.Current; Assert.That(message.WalStart, Is.EqualTo(info.XLogPos)); Assert.That(message.WalEnd, Is.GreaterThan(message.WalStart)); diff --git a/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs b/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs index 5d7c633f6c..406be7b809 100644 --- a/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs +++ b/test/Npgsql.Tests/Replication/TestDecodingReplicationTests.cs @@ -12,7 +12,6 @@ namespace Npgsql.Tests.Replication; /// implementation of logical replication was still somewhat incomplete. /// Please don't change them without confirming that they still work on those old versions. /// -[Platform(Exclude = "MacOsX", Reason = "Replication tests are flaky in CI on Mac")] [NonParallelizable] // These tests aren't designed to be parallelizable public class TestDecodingReplicationTests : SafeReplicationTestBase { @@ -327,7 +326,7 @@ await c.ExecuteNonQueryAsync(@$" static async ValueTask NextMessage(IAsyncEnumerator enumerator) { - Assert.True(await enumerator.MoveNextAsync()); + Assert.That(await enumerator.MoveNextAsync()); return enumerator.Current!; } diff --git a/test/Npgsql.Tests/SchemaTests.cs b/test/Npgsql.Tests/SchemaTests.cs index e65fc48cf2..ebe36269b0 100644 --- a/test/Npgsql.Tests/SchemaTests.cs +++ b/test/Npgsql.Tests/SchemaTests.cs @@ -10,7 +10,7 @@ namespace Npgsql.Tests; -public class SchemaTests : SyncOrAsyncTestBase +public class SchemaTests(SyncOrAsync syncOrAsync) : SyncOrAsyncTestBase(syncOrAsync) { [Test] public async Task MetaDataCollections() @@ -47,7 +47,7 @@ public async Task No_parameter() Assert.That(collections1, Is.EquivalentTo(collections2)); } - [Test, Description("Calling GetSchema(collectionName [, restrictions]) case insensive collectionName can be used")] + [Test, Description("Calling GetSchema(collectionName [, restrictions]) case insensitive 
collectionName can be used")] public async Task Case_insensitive_collection_name() { await using var conn = await OpenConnectionAsync(); @@ -257,7 +257,7 @@ public async Task ForeignKeys() { await using var conn = await OpenConnectionAsync(); var dt = await GetSchema(conn, "ForeignKeys"); - Assert.IsNotNull(dt); + Assert.That(dt, Is.Not.Null); } [Test] @@ -276,7 +276,7 @@ public async Task ParameterMarkerFormat() command.CommandText = $"SELECT * FROM {table} WHERE int=" + string.Format(parameterMarkerFormat, parameterName); command.Parameters.Add(new NpgsqlParameter(parameterName, 4)); await using var reader = await command.ExecuteReaderAsync(); - Assert.IsTrue(reader.Read()); + Assert.That(reader.Read()); } [Test] @@ -286,7 +286,7 @@ public async Task Precision_and_scale() var table = await CreateTempTable( conn, "explicit_both NUMERIC(10,2), explicit_precision NUMERIC(10), implicit_both NUMERIC, integer INTEGER, text TEXT"); - var dataTable = await GetSchema(conn, "Columns", new[] { null, null, table }); + var dataTable = await GetSchema(conn, "Columns", [null, null, table]); var rows = dataTable.Rows.Cast().ToList(); var explicitBoth = rows.Single(r => (string)r["column_name"] == "explicit_both"); @@ -339,7 +339,7 @@ public async Task GetSchema_tables_with_restrictions() await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "bar INTEGER"); - var dt = await GetSchema(conn, "Tables", new[] { null, null, table }); + var dt = await GetSchema(conn, "Tables", [null, null, table]); foreach (var row in dt.Rows.OfType()) Assert.That(row["table_name"], Is.EqualTo(table)); } @@ -352,7 +352,7 @@ public async Task GetSchema_views_with_restrictions() await conn.ExecuteNonQueryAsync($"CREATE VIEW {view} AS SELECT 8 AS foo"); - var dt = await GetSchema(conn, "Views", new[] { null, null, view }); + var dt = await GetSchema(conn, "Views", [null, null, view]); foreach (var row in dt.Rows.OfType()) Assert.That(row["table_name"], 
Is.EqualTo(view)); } @@ -365,7 +365,7 @@ public async Task GetSchema_materialized_views_with_restrictions() await conn.ExecuteNonQueryAsync($"CREATE MATERIALIZED VIEW {viewName} AS SELECT 8 AS foo"); - var dt = await GetSchema(conn, "MaterializedViews", new[] { null, viewName, null, null }); + var dt = await GetSchema(conn, "MaterializedViews", [null, viewName, null, null]); foreach (var row in dt.Rows.OfType()) Assert.That(row["table_name"], Is.EqualTo(viewName)); } @@ -376,7 +376,7 @@ public async Task Primary_key() await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "id INT PRIMARY KEY, f1 INT"); - var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", new[] { null, null, table }); + var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", [null, null, table]); var column = dataTable.Rows.Cast().Single(); Assert.That(column["table_schema"], Is.EqualTo("public")); @@ -391,7 +391,7 @@ public async Task Primary_key_composite() await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "id1 INT, id2 INT, f1 INT, PRIMARY KEY (id1, id2)"); - var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", new[] { null, null, table }); + var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", [null, null, table]); var columns = dataTable.Rows.Cast().OrderBy(r => r["ordinal_number"]).ToList(); Assert.That(columns.All(r => r["table_schema"].Equals("public"))); @@ -410,7 +410,7 @@ public async Task Unique_constraint() var database = await conn.ExecuteScalarAsync("SELECT current_database()"); - var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", new[] { null, null, table }); + var dataTable = await GetSchema(conn, "CONSTRAINTCOLUMNS", [null, null, table]); var columns = dataTable.Rows.Cast().ToList(); Assert.That(columns.All(r => r["constraint_catalog"].Equals(database))); @@ -425,11 +425,11 @@ public async Task Unique_constraint() // Columns are not necessarily in the correct order var 
firstColumn = columns.FirstOrDefault(x => (string)x["column_name"] == "f1")!; - Assert.NotNull(firstColumn); + Assert.That(firstColumn, Is.Not.Null); Assert.That(firstColumn["ordinal_number"], Is.EqualTo(1)); var secondColumn = columns.FirstOrDefault(x => (string)x["column_name"] == "f2")!; - Assert.NotNull(secondColumn); + Assert.That(secondColumn, Is.Not.Null); Assert.That(secondColumn["ordinal_number"], Is.EqualTo(2)); } @@ -448,7 +448,7 @@ await conn.ExecuteNonQueryAsync(@$" var database = await conn.ExecuteScalarAsync("SELECT current_database()"); - var dataTable = await GetSchema(conn, "INDEXES", new[] { null, null, table }); + var dataTable = await GetSchema(conn, "INDEXES", [null, null, table]); var index = dataTable.Rows.Cast().Single(); Assert.That(index["table_schema"], Is.EqualTo("public")); @@ -456,7 +456,7 @@ await conn.ExecuteNonQueryAsync(@$" Assert.That(index["index_name"], Is.EqualTo(constraint)); Assert.That(index["type_desc"], Is.EqualTo("")); - string[] indexColumnRestrictions = { null!, null!, table }; + string[] indexColumnRestrictions = [null!, null!, table]; var dataTable2 = await GetSchema(conn, "INDEXCOLUMNS", indexColumnRestrictions); var columns = dataTable2.Rows.Cast().ToList(); @@ -470,7 +470,7 @@ await conn.ExecuteNonQueryAsync(@$" Assert.That(columns[0]["column_name"], Is.EqualTo("f1")); Assert.That(columns[1]["column_name"], Is.EqualTo("f2")); - string[] indexColumnRestrictions3 = { (string) database! , "public", table, constraint, "f1" }; + string[] indexColumnRestrictions3 = [(string) database! 
, "public", table, constraint, "f1"]; var dataTable3 = await GetSchema(conn, "INDEXCOLUMNS", indexColumnRestrictions3); var columns3 = dataTable3.Rows.Cast().ToList(); Assert.That(columns3.Count, Is.EqualTo(1)); @@ -533,7 +533,7 @@ vbit bit varying(5), cid cid"; var table = await CreateTempTable(conn, columnDefinition); - var columnsSchema = await GetSchema(conn, "Columns", new[] { null, null, table }); + var columnsSchema = await GetSchema(conn, "Columns", [null, null, table]); var columns = columnsSchema.Rows.Cast().ToList(); var dataTypes = await GetSchema(conn, DbMetaDataCollectionNames.DataTypes); @@ -554,7 +554,7 @@ await conn.ExecuteNonQueryAsync($@" CREATE TYPE {enumName} AS ENUM ('red', 'yellow', 'blue'); CREATE TABLE {table} (color {enumName});"); - var dataTable = await GetSchema(conn, "Columns", new[] { null, null, table }); + var dataTable = await GetSchema(conn, "Columns", [null, null, table]); var row = dataTable.Rows.Cast().Single(); Assert.That(row["data_type"], Is.EqualTo(enumName)); } @@ -571,7 +571,7 @@ await conn.ExecuteNonQueryAsync($@" CREATE TYPE {schema}.{enumName} AS ENUM ('red', 'yellow', 'blue'); CREATE TABLE {table} (color {schema}.{enumName});"); - var dataTable = await GetSchema(conn, "Columns", new[] { null, null, table }); + var dataTable = await GetSchema(conn, "Columns", [null, null, table]); var row = dataTable.Rows.Cast().Single(); Assert.That(row["data_type"], Is.EqualTo($"{schema}.{enumName}")); } @@ -584,8 +584,6 @@ public async Task SlimBuilder_introspection_without_unsupported_type_exceptions( Assert.That(() => GetSchema(conn, DbMetaDataCollectionNames.DataTypes), Throws.Nothing); } - public SchemaTests(SyncOrAsync syncOrAsync) : base(syncOrAsync) { } - // ReSharper disable MethodHasAsyncOverload async Task GetSchema(NpgsqlConnection conn) => IsAsync ? 
await conn.GetSchemaAsync() : conn.GetSchema(); diff --git a/test/Npgsql.Tests/SecurityTests.cs b/test/Npgsql.Tests/SecurityTests.cs index 8600942969..cb591b39eb 100644 --- a/test/Npgsql.Tests/SecurityTests.cs +++ b/test/Npgsql.Tests/SecurityTests.cs @@ -2,6 +2,8 @@ using System.IO; using System.Runtime.InteropServices; using System.Security.Authentication; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; using System.Threading; using System.Threading.Tasks; using Npgsql.Properties; @@ -13,27 +15,27 @@ namespace Npgsql.Tests; public class SecurityTests : TestBase { [Test, Description("Establishes an SSL connection, assuming a self-signed server certificate")] - public void Basic_ssl() + public async Task Basic_ssl() { - using var dataSource = CreateDataSource(csb => + await using var dataSource = CreateDataSource(csb => { csb.SslMode = SslMode.Require; }); - using var conn = dataSource.OpenConnection(); - Assert.That(conn.IsSecure, Is.True); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(conn.IsSslEncrypted, Is.True); } [Test, Description("Default user must run with md5 password encryption")] - public void Default_user_uses_md5_password() + public async Task Default_user_uses_md5_password() { if (!IsOnBuildServer) Assert.Ignore("Only executed in CI"); - using var dataSource = CreateDataSource(csb => + await using var dataSource = CreateDataSource(csb => { csb.SslMode = SslMode.Require; }); - using var conn = dataSource.OpenConnection(); + await using var conn = await dataSource.OpenConnectionAsync(); Assert.That(conn.IsScram, Is.False); Assert.That(conn.IsScramPlus, Is.False); } @@ -72,7 +74,7 @@ public void IsSecure_without_ssl() { using var dataSource = CreateDataSource(csb => csb.SslMode = SslMode.Disable); using var conn = dataSource.OpenConnection(); - Assert.That(conn.IsSecure, Is.False); + Assert.That(conn.IsSslEncrypted, Is.False); } [Test, Explicit("Needs to be set up (and run with 
with Kerberos credentials on Linux)")] @@ -230,13 +232,8 @@ public void ScramPlus_channel_binding([Values] ChannelBinding channelBinding) } [Test] - public async Task Connect_with_only_ssl_allowed_user([Values] bool multiplexing, [Values] bool keepAlive) + public async Task Connect_with_only_ssl_allowed_user([Values] bool keepAlive) { - if (multiplexing && keepAlive) - { - Assert.Ignore("Multiplexing doesn't support keepalive"); - } - try { await using var dataSource = CreateDataSource(csb => @@ -244,11 +241,10 @@ public async Task Connect_with_only_ssl_allowed_user([Values] bool multiplexing, csb.SslMode = SslMode.Allow; csb.Username = "npgsql_tests_ssl"; csb.Password = "npgsql_tests_ssl"; - csb.Multiplexing = multiplexing; csb.KeepAlive = keepAlive ? 10 : 0; }); await using var conn = await dataSource.OpenConnectionAsync(); - Assert.IsTrue(conn.IsSecure); + Assert.That(conn.IsSslEncrypted); } catch (Exception e) when (!IsOnBuildServer) { @@ -259,13 +255,8 @@ public async Task Connect_with_only_ssl_allowed_user([Values] bool multiplexing, [Test] [Platform(Exclude = "Win", Reason = "Postgresql doesn't close connection correctly on windows which might result in missing error message")] - public async Task Connect_with_only_non_ssl_allowed_user([Values] bool multiplexing, [Values] bool keepAlive) + public async Task Connect_with_only_non_ssl_allowed_user([Values] bool keepAlive) { - if (multiplexing && keepAlive) - { - Assert.Ignore("Multiplexing doesn't support keepalive"); - } - try { await using var dataSource = CreateDataSource(csb => @@ -273,11 +264,10 @@ public async Task Connect_with_only_non_ssl_allowed_user([Values] bool multiplex csb.SslMode = SslMode.Prefer; csb.Username = "npgsql_tests_nossl"; csb.Password = "npgsql_tests_nossl"; - csb.Multiplexing = multiplexing; csb.KeepAlive = keepAlive ? 
10 : 0; }); await using var conn = await dataSource.OpenConnectionAsync(); - Assert.IsFalse(conn.IsSecure); + Assert.That(conn.IsSslEncrypted, Is.False); } catch (NpgsqlException ex) when (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && ex.InnerException is IOException) { @@ -294,16 +284,19 @@ public async Task Connect_with_only_non_ssl_allowed_user([Values] bool multiplex } [Test] - public async Task DataSource_UserCertificateValidationCallback_is_invoked([Values] bool acceptCertificate) + public async Task DataSource_SslClientAuthenticationOptionsCallback_is_invoked([Values] bool acceptCertificate) { var callbackWasInvoked = false; var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.ConnectionStringBuilder.SslMode = SslMode.Require; - dataSourceBuilder.UseUserCertificateValidationCallback((_, _, _, _) => + dataSourceBuilder.UseSslClientAuthenticationOptionsCallback(options => { - callbackWasInvoked = true; - return acceptCertificate; + options.RemoteCertificateValidationCallback = (_, _, _, _) => + { + callbackWasInvoked = true; + return acceptCertificate; + }; }); await using var dataSource = dataSourceBuilder.Build(); await using var connection = dataSource.CreateConnection(); @@ -320,7 +313,7 @@ public async Task DataSource_UserCertificateValidationCallback_is_invoked([Value } [Test] - public async Task Connection_UserCertificateValidationCallback_is_invoked([Values] bool acceptCertificate) + public async Task Connection_SslClientAuthenticationOptionsCallback_is_invoked([Values] bool acceptCertificate) { var callbackWasInvoked = false; @@ -328,10 +321,13 @@ public async Task Connection_UserCertificateValidationCallback_is_invoked([Value dataSourceBuilder.ConnectionStringBuilder.SslMode = SslMode.Require; await using var dataSource = dataSourceBuilder.Build(); await using var connection = dataSource.CreateConnection(); - connection.UserCertificateValidationCallback = (_, _, _, _) => + connection.SslClientAuthenticationOptionsCallback = 
options => { - callbackWasInvoked = true; - return acceptCertificate; + options.RemoteCertificateValidationCallback = (_, _, _, _) => + { + callbackWasInvoked = true; + return acceptCertificate; + }; }; if (acceptCertificate) @@ -350,10 +346,13 @@ public void Connect_with_Verify_and_callback_throws([Values(SslMode.VerifyCA, Ss { using var dataSource = CreateDataSource(csb => csb.SslMode = sslMode); using var connection = dataSource.CreateConnection(); - connection.UserCertificateValidationCallback = (_, _, _, _) => true; + connection.SslClientAuthenticationOptionsCallback = options => + { + options.RemoteCertificateValidationCallback = (_, _, _, _) => true; + }; var ex = Assert.ThrowsAsync(async () => await connection.OpenAsync())!; - Assert.That(ex.Message, Is.EqualTo(string.Format(NpgsqlStrings.CannotUseSslVerifyWithUserCallback, sslMode))); + Assert.That(ex.Message, Is.EqualTo(string.Format(NpgsqlStrings.CannotUseSslVerifyWithCustomValidationCallback, sslMode))); } [Test] @@ -365,10 +364,13 @@ public void Connect_with_RootCertificate_and_callback_throws() csb.RootCertificate = "foo"; }); using var connection = dataSource.CreateConnection(); - connection.UserCertificateValidationCallback = (_, _, _, _) => true; + connection.SslClientAuthenticationOptionsCallback = options => + { + options.RemoteCertificateValidationCallback = (_, _, _, _) => true; + }; var ex = Assert.ThrowsAsync(async () => await connection.OpenAsync())!; - Assert.That(ex.Message, Is.EqualTo(string.Format(NpgsqlStrings.CannotUseSslRootCertificateWithUserCallback))); + Assert.That(ex.Message, Is.EqualTo(string.Format(NpgsqlStrings.CannotUseSslRootCertificateWithCustomValidationCallback))); } [Test] @@ -388,7 +390,7 @@ public async Task Bug4305_Secure([Values] bool async) try { conn = await dataSource.OpenConnectionAsync(); - Assert.IsTrue(conn.IsSecure); + Assert.That(conn.IsSslEncrypted); } catch (Exception e) when (!IsOnBuildServer) { @@ -412,7 +414,7 @@ public async Task 
Bug4305_Secure([Values] bool async) await conn.CloseAsync(); await conn.OpenAsync(); - Assert.AreSame(originalConnector, conn.Connector); + Assert.That(conn.Connector, Is.SameAs(originalConnector)); } cmd.CommandText = "SELECT 1"; @@ -439,7 +441,7 @@ public async Task Bug4305_not_Secure([Values] bool async) try { conn = await dataSource.OpenConnectionAsync(); - Assert.IsFalse(conn.IsSecure); + Assert.That(conn.IsSslEncrypted, Is.False); } catch (Exception e) when (!IsOnBuildServer) { @@ -461,7 +463,7 @@ public async Task Bug4305_not_Secure([Values] bool async) await conn.CloseAsync(); await conn.OpenAsync(); - Assert.AreSame(originalConnector, conn.Connector); + Assert.That(conn.Connector, Is.SameAs(originalConnector)); cmd.CommandText = "SELECT 1"; if (async) @@ -470,6 +472,149 @@ public async Task Bug4305_not_Secure([Values] bool async) Assert.DoesNotThrow(() => cmd.ExecuteNonQuery()); } + [Test] + public async Task Direct_ssl_negotiation() + { + await using var adminConn = await OpenConnectionAsync(); + MinimumPgVersion(adminConn, "17.0"); + + await using var dataSource = CreateDataSource(csb => + { + csb.SslMode = SslMode.Require; + csb.SslNegotiation = SslNegotiation.Direct; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + Assert.That(conn.IsSslEncrypted); + } + + [Test] + public void Direct_ssl_requires_correct_sslmode([Values] SslMode sslMode) + { + if (sslMode is SslMode.Disable or SslMode.Allow or SslMode.Prefer) + { + var ex = Assert.Throws(() => + { + using var dataSource = CreateDataSource(csb => + { + csb.SslMode = sslMode; + csb.SslNegotiation = SslNegotiation.Direct; + }); + })!; + Assert.That(ex.Message, Is.EqualTo("SSL Mode has to be Require or higher to be used with direct SSL Negotiation")); + } + else + { + using var dataSource = CreateDataSource(csb => + { + csb.SslMode = sslMode; + csb.SslNegotiation = SslNegotiation.Direct; + }); + } + } + + [Test] + [Platform(Exclude = "MacOsX", Reason = "Mac requires explicit opt-in 
to receive CA certificate in TLS handshake")] + public async Task Connect_with_verify_and_ca_cert([Values(SslMode.VerifyCA, SslMode.VerifyFull)] SslMode sslMode) + { + if (!IsOnBuildServer) + Assert.Ignore("Only executed in CI"); + + await using var dataSource = CreateDataSource(csb => + { + csb.SslMode = sslMode; + csb.RootCertificate = "ca.crt"; + }); + + await using var _ = await dataSource.OpenConnectionAsync(); + } + + [Test] + [Platform(Exclude = "MacOsX", Reason = "Mac requires explicit opt-in to receive CA certificate in TLS handshake")] + public async Task Connect_with_verify_check_host([Values(SslMode.VerifyCA, SslMode.VerifyFull)] SslMode sslMode) + { + if (!IsOnBuildServer) + Assert.Ignore("Only executed in CI"); + + await using var dataSource = CreateDataSource(csb => + { + csb.Host = "127.0.0.1"; + csb.SslMode = sslMode; + csb.RootCertificate = "ca.crt"; + }); + + if (sslMode == SslMode.VerifyCA) + { + await using var _ = await dataSource.OpenConnectionAsync(); + } + else + { + var ex = Assert.ThrowsAsync(async () => await dataSource.OpenConnectionAsync())!; + Assert.That(ex.InnerException, Is.TypeOf()); + } + } + + [Test] + [Platform(Exclude = "MacOsX", Reason = "Mac requires explicit opt-in to receive CA certificate in TLS handshake")] + public async Task Connect_with_verify_and_multiple_ca_cert([Values(SslMode.VerifyCA, SslMode.VerifyFull)] SslMode sslMode, [Values] bool realCaFirst) + { + if (!IsOnBuildServer) + Assert.Ignore("Only executed in CI"); + + var certificates = new X509Certificate2Collection(); + + using var realCaCert = X509CertificateLoader.LoadCertificateFromFile("ca.crt"); + + using var ecdsa = ECDsa.Create(); + var req = new CertificateRequest("cn=localhost", ecdsa, HashAlgorithmName.SHA256); + using var unrelatedCaCert = req.CreateSelfSigned(DateTimeOffset.UtcNow.AddDays(-1), DateTimeOffset.UtcNow.AddDays(1)); + + if (realCaFirst) + { + certificates.Add(realCaCert); + certificates.Add(unrelatedCaCert); + } + else + { + 
certificates.Add(unrelatedCaCert); + certificates.Add(realCaCert); + } + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.SslMode = sslMode; + dataSourceBuilder.UseRootCertificates(certificates); + + await using var dataSource = dataSourceBuilder.Build(); + + await using var _ = await dataSource.OpenConnectionAsync(); + } + + [Test] + [NonParallelizable] // Sets environment variable + public async Task Direct_ssl_via_env_requires_correct_sslmode() + { + await using var adminConn = await OpenConnectionAsync(); + MinimumPgVersion(adminConn, "17.0"); + + // NonParallelizable attribute doesn't work with parameters that well + foreach (var sslMode in new[] { SslMode.Disable, SslMode.Allow, SslMode.Prefer, SslMode.Require }) + { + using var _ = SetEnvironmentVariable("PGSSLNEGOTIATION", nameof(SslNegotiation.Direct)); + await using var dataSource = CreateDataSource(csb => + { + csb.SslMode = sslMode; + }); + if (sslMode is SslMode.Disable or SslMode.Allow or SslMode.Prefer) + { + var ex = Assert.ThrowsAsync(async () => await dataSource.OpenConnectionAsync())!; + Assert.That(ex.Message, Is.EqualTo("SSL Mode has to be Require or higher to be used with direct SSL Negotiation")); + } + else + { + await using var conn = await dataSource.OpenConnectionAsync(); + } + } + } + #region Setup / Teardown / Utils [OneTimeSetUp] diff --git a/test/Npgsql.Tests/SizeTests.cs b/test/Npgsql.Tests/SizeTests.cs new file mode 100644 index 0000000000..93bd3b8d29 --- /dev/null +++ b/test/Npgsql.Tests/SizeTests.cs @@ -0,0 +1,59 @@ +using System; +using NUnit.Framework; +using Npgsql.Internal; + +namespace Npgsql.Tests; + +public class SizeTests +{ + [Test] + public void UnknownKind() => Assert.That(Size.Unknown.Kind, Is.EqualTo(SizeKind.Unknown)); + + [Test] + public void UnknownThrowsOnValue() => Assert.Throws(() => _ = Size.Unknown.Value); + + [Test] + public void Exact() + { + Assert.That(Size.Create(1).Value, Is.EqualTo(1)); + 
Assert.That(Size.Create(1).Kind, Is.EqualTo(SizeKind.Exact)); + } + + [Test] + public void ZeroIsExactKind() => Assert.That(Size.Zero.Kind, Is.EqualTo(SizeKind.Exact)); + + [Test] + public void UpperBound() + { + Assert.That(Size.CreateUpperBound(1).Value, Is.EqualTo(1)); + Assert.That(Size.CreateUpperBound(1).Kind, Is.EqualTo(SizeKind.UpperBound)); + } + + [Test] + public void CombineThrowsOnOverflow() => Assert.Throws(() => Size.Create(1).Combine(int.MaxValue)); + + [Test] + public void CombineExactWorks() => Assert.That(Size.Create(1).Combine(1), Is.EqualTo(Size.Create(2))); + + [Test] + public void CombineUpperBoundWorks() => Assert.That(Size.CreateUpperBound(1).Combine(1), Is.EqualTo(Size.CreateUpperBound(2))); + + [Test] + public void CombineUnknownWithAnyGivesUnknown() + { + Assert.That(Size.Unknown.Combine(Size.Unknown), Is.EqualTo(Size.Unknown)); + + Assert.That(Size.Create(1).Combine(Size.Unknown), Is.EqualTo(Size.Unknown)); + Assert.That(Size.Unknown.Combine(Size.Create(1)), Is.EqualTo(Size.Unknown)); + + Assert.That(Size.Unknown.Combine(Size.CreateUpperBound(1)), Is.EqualTo(Size.Unknown)); + Assert.That(Size.CreateUpperBound(1).Combine(Size.Unknown), Is.EqualTo(Size.Unknown)); + } + + [Test] + public void CombineUpperBoundWithExactGivesUpperBound() + { + Assert.That(Size.Create(1).Combine(Size.CreateUpperBound(1)), Is.EqualTo(Size.CreateUpperBound(2))); + Assert.That(Size.CreateUpperBound(1).Combine(Size.Create(1)), Is.EqualTo(Size.CreateUpperBound(2))); + } +} diff --git a/test/Npgsql.Tests/SnakeCaseNameTranslatorTests.cs b/test/Npgsql.Tests/SnakeCaseNameTranslatorTests.cs index 52de32bccf..9a64ccdaa2 100644 --- a/test/Npgsql.Tests/SnakeCaseNameTranslatorTests.cs +++ b/test/Npgsql.Tests/SnakeCaseNameTranslatorTests.cs @@ -66,9 +66,9 @@ public void TurkeyTest() const string clrName = "IPhone"; const string expected = "i_phone"; - Assert.AreEqual(expected, translator.TranslateMemberName(clrName)); - Assert.AreEqual(expected, 
translator.TranslateTypeName(clrName)); - Assert.AreEqual(expected, legacyTranslator.TranslateMemberName(clrName)); - Assert.AreEqual(expected, legacyTranslator.TranslateTypeName(clrName)); + Assert.That(translator.TranslateMemberName(clrName), Is.EqualTo(expected)); + Assert.That(translator.TranslateTypeName(clrName), Is.EqualTo(expected)); + Assert.That(legacyTranslator.TranslateMemberName(clrName), Is.EqualTo(expected)); + Assert.That(legacyTranslator.TranslateTypeName(clrName), Is.EqualTo(expected)); } } diff --git a/test/Npgsql.Tests/StoredProcedureTests.cs b/test/Npgsql.Tests/StoredProcedureTests.cs index 8666740f74..ae13fa015c 100644 --- a/test/Npgsql.Tests/StoredProcedureTests.cs +++ b/test/Npgsql.Tests/StoredProcedureTests.cs @@ -129,6 +129,88 @@ LANGUAGE plpgsql Assert.That(reader[1], Is.EqualTo(11)); } + [Test] + public async Task Batch_positional_parameters_works() + { + var tempname = await GetTempProcedureName(DataSource); + await using var connection = await DataSource.OpenConnectionAsync(); + await using var transaction = await connection.BeginTransactionAsync(IsolationLevel.Serializable); + await using var batch = new NpgsqlBatch(connection, transaction) + { + BatchCommands = + { + new(tempname) + { + CommandType = CommandType.StoredProcedure, + Parameters = + { + new() { Value = "" }, + new() { DbType = DbType.Int64, Direction = ParameterDirection.Output } + } + }, + new ("COMMIT") + } + }; + + Assert.ThrowsAsync(() => batch.ExecuteNonQueryAsync()); + } + + [Test] + public async Task Batch_StoredProcedure_output_parameters_works() + { + // Proper OUT params were introduced in PostgreSQL 14 + MinimumPgVersion(DataSource, "14.0", "Stored procedure OUT parameters are only support starting with version 14"); + var sproc = await GetTempProcedureName(DataSource); + + await using var connection = await DataSource.OpenConnectionAsync(); + await using var transaction = await connection.BeginTransactionAsync(IsolationLevel.Serializable); + var c = 
connection.CreateCommand(); + c.CommandText = $""" + CREATE OR REPLACE PROCEDURE {sproc} + ( + p_username TEXT, + OUT p_user_id BIGINT + ) + LANGUAGE plpgsql + AS $$ + BEGIN + p_user_id = 1; + return; + END; + $$; + """; + await c.ExecuteNonQueryAsync(); + + await using var batch = new NpgsqlBatch(connection, transaction) + { + BatchCommands = + { + new(sproc) + { + CommandType = CommandType.StoredProcedure, + Parameters = + { + new() { Value = "" }, + new() { NpgsqlDbType = NpgsqlDbType.Bigint, Direction = ParameterDirection.Output } + } + }, + new(sproc) + { + CommandType = CommandType.StoredProcedure, + Parameters = + { + new() { Value = "" }, + new() { NpgsqlDbType = NpgsqlDbType.Bigint, Direction = ParameterDirection.Output } + } + } + } + }; + + await batch.ExecuteNonQueryAsync(); + Assert.That(batch.BatchCommands[0].Parameters[1].Value, Is.EqualTo(1)); + Assert.That(batch.BatchCommands[1].Parameters[1].Value, Is.EqualTo(1)); + } + #region DeriveParameters [Test, Description("Tests function parameter derivation with IN, OUT and INOUT parameters")] @@ -213,8 +295,8 @@ public async Task DeriveParameters_procedure_with_case_sensitive_name() { await using var command = new NpgsqlCommand(@"""ProcedureCaseSensitive""", conn) { CommandType = CommandType.StoredProcedure }; NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + Assert.That(command.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(command.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); } finally { @@ -233,8 +315,8 @@ public async Task DeriveParameters_quote_characters_in_function_name() { await using var command = new NpgsqlCommand(sproc, conn) { CommandType = CommandType.StoredProcedure }; NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - 
Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + Assert.That(command.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(command.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); } finally { @@ -253,8 +335,8 @@ await conn.ExecuteNonQueryAsync( { await using var command = new NpgsqlCommand(@"""My.Dotted.Procedure""", conn) { CommandType = CommandType.StoredProcedure }; NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual(NpgsqlDbType.Integer, command.Parameters[0].NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Text, command.Parameters[1].NpgsqlDbType); + Assert.That(command.Parameters[0].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Integer)); + Assert.That(command.Parameters[1].NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Text)); } finally { @@ -273,8 +355,8 @@ await conn.ExecuteNonQueryAsync( $"CREATE PROCEDURE {sproc}(x int, y int, out sum int, out product int) AS 'SELECT $1 + $2, $1 * $2' LANGUAGE sql"); await using var command = new NpgsqlCommand(sproc, conn) { CommandType = CommandType.StoredProcedure }; NpgsqlCommandBuilder.DeriveParameters(command); - Assert.AreEqual("x", command.Parameters[0].ParameterName); - Assert.AreEqual("y", command.Parameters[1].ParameterName); + Assert.That(command.Parameters[0].ParameterName, Is.EqualTo("x")); + Assert.That(command.Parameters[1].ParameterName, Is.EqualTo("y")); } [Test] diff --git a/test/Npgsql.Tests/Support/AssemblySetUp.cs b/test/Npgsql.Tests/Support/AssemblySetUp.cs index f1619ecec4..c7d16f0501 100644 --- a/test/Npgsql.Tests/Support/AssemblySetUp.cs +++ b/test/Npgsql.Tests/Support/AssemblySetUp.cs @@ -26,7 +26,6 @@ public void Setup() var builder = new NpgsqlConnectionStringBuilder(connString) { Pooling = false, - Multiplexing = false, Database = "postgres" }; diff --git a/test/Npgsql.Tests/Support/ListLoggerFactory.cs b/test/Npgsql.Tests/Support/ListLoggerFactory.cs index 2852335df8..98f94cb8fa 100644 --- a/test/Npgsql.Tests/Support/ListLoggerFactory.cs 
+++ b/test/Npgsql.Tests/Support/ListLoggerFactory.cs @@ -35,20 +35,15 @@ public void AddProvider(ILoggerProvider provider) public void Dispose() => StopRecording(); - class ListLogger : ILogger + class ListLogger(ListLoggerProvider provider) : ILogger { - readonly ListLoggerProvider _provider; - - public ListLogger(ListLoggerProvider provider) - => _provider = provider; - public List<(LogLevel, EventId, string, object?, Exception?)> LoggedEvents { get; } - = new(); + = []; public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) { - if (_provider._recording) + if (provider._recording) { lock (this) { @@ -66,7 +61,7 @@ public void Clear() } } - public bool IsEnabled(LogLevel logLevel) => _provider._recording; + public bool IsEnabled(LogLevel logLevel) => provider._recording; public IDisposable BeginScope(TState state) where TState : notnull => new Scope(); @@ -79,14 +74,9 @@ public void Dispose() } } - class RecordingDisposable : IDisposable + class RecordingDisposable(ListLoggerProvider provider) : IDisposable { - readonly ListLoggerProvider _provider; - - public RecordingDisposable(ListLoggerProvider provider) - => _provider = provider; - public void Dispose() - => _provider.StopRecording(); + => provider.StopRecording(); } } diff --git a/test/Npgsql.Tests/Support/MultiplexingTestBase.cs b/test/Npgsql.Tests/Support/MultiplexingTestBase.cs deleted file mode 100644 index 892dd79f5e..0000000000 --- a/test/Npgsql.Tests/Support/MultiplexingTestBase.cs +++ /dev/null @@ -1,37 +0,0 @@ -using System.Collections.Concurrent; -using NUnit.Framework; - -namespace Npgsql.Tests; - -[TestFixture(MultiplexingMode.NonMultiplexing)] -[TestFixture(MultiplexingMode.Multiplexing)] -public abstract class MultiplexingTestBase : TestBase -{ - protected bool IsMultiplexing => MultiplexingMode == MultiplexingMode.Multiplexing; - - protected MultiplexingMode MultiplexingMode { get; } - - readonly ConcurrentDictionary<(string ConnString, bool 
IsMultiplexing), string> _connStringCache - = new(); - - public override string ConnectionString { get; } - - protected MultiplexingTestBase(MultiplexingMode multiplexingMode) - { - MultiplexingMode = multiplexingMode; - - // If the test requires multiplexing to be on or off, use a small cache to avoid reparsing and - // regenerating the connection string every time - ConnectionString = _connStringCache.GetOrAdd((base.ConnectionString, IsMultiplexing), - tup => new NpgsqlConnectionStringBuilder(tup.ConnString) - { - Multiplexing = tup.IsMultiplexing - }.ToString()); - } -} - -public enum MultiplexingMode -{ - NonMultiplexing, - Multiplexing -} diff --git a/test/Npgsql.Tests/Support/PgCancellationRequest.cs b/test/Npgsql.Tests/Support/PgCancellationRequest.cs index c07f606bb8..6773c55dd2 100644 --- a/test/Npgsql.Tests/Support/PgCancellationRequest.cs +++ b/test/Npgsql.Tests/Support/PgCancellationRequest.cs @@ -3,35 +3,21 @@ namespace Npgsql.Tests.Support; -class PgCancellationRequest +class PgCancellationRequest(NpgsqlReadBuffer readBuffer, NpgsqlWriteBuffer writeBuffer, Stream stream, int processId, int secret) { - readonly NpgsqlReadBuffer _readBuffer; - readonly NpgsqlWriteBuffer _writeBuffer; - readonly Stream _stream; - - public int ProcessId { get; } - public int Secret { get; } + public int ProcessId { get; } = processId; + public int Secret { get; } = secret; bool completed; - public PgCancellationRequest(NpgsqlReadBuffer readBuffer, NpgsqlWriteBuffer writeBuffer, Stream stream, int processId, int secret) - { - _readBuffer = readBuffer; - _writeBuffer = writeBuffer; - _stream = stream; - - ProcessId = processId; - Secret = secret; - } - public void Complete() { if (completed) return; - _readBuffer.Dispose(); - _writeBuffer.Dispose(); - _stream.Dispose(); + readBuffer.Dispose(); + writeBuffer.Dispose(); + stream.Dispose(); completed = true; } diff --git a/test/Npgsql.Tests/Support/PgPostmasterMock.cs b/test/Npgsql.Tests/Support/PgPostmasterMock.cs index 
e45c1a7f28..d9a93531a1 100644 --- a/test/Npgsql.Tests/Support/PgPostmasterMock.cs +++ b/test/Npgsql.Tests/Support/PgPostmasterMock.cs @@ -16,18 +16,20 @@ class PgPostmasterMock : IAsyncDisposable const int WriteBufferSize = 8192; const int CancelRequestCode = 1234 << 16 | 5678; const int SslRequest = 80877103; + const int GssRequest = 80877104; static readonly Encoding Encoding = NpgsqlWriteBuffer.UTF8Encoding; static readonly Encoding RelaxedEncoding = NpgsqlWriteBuffer.RelaxedUTF8Encoding; readonly Socket _socket; - readonly List _allServers = new(); + readonly List _allServers = []; bool _acceptingClients; Task? _acceptClientsTask; int _processIdCounter; readonly bool _completeCancellationImmediately; readonly string? _startupErrorCode; + readonly bool _breakOnGssEncryptionRequest; ChannelWriter> _pendingRequestsWriter { get; } ChannelReader> _pendingRequestsReader { get; } @@ -48,9 +50,10 @@ internal static PgPostmasterMock Start( string? connectionString = null, bool completeCancellationImmediately = true, MockState state = MockState.MultipleHostsDisabled, - string? startupErrorCode = null) + string? startupErrorCode = null, + bool breakOnGssEncryptionRequest = false) { - var mock = new PgPostmasterMock(connectionString, completeCancellationImmediately, state, startupErrorCode); + var mock = new PgPostmasterMock(connectionString, completeCancellationImmediately, state, startupErrorCode, breakOnGssEncryptionRequest); mock.AcceptClients(); return mock; } @@ -59,7 +62,8 @@ internal PgPostmasterMock( string? connectionString = null, bool completeCancellationImmediately = true, MockState state = MockState.MultipleHostsDisabled, - string? startupErrorCode = null) + string? 
startupErrorCode = null, + bool breakOnGssEncryptionRequest = false) { var pendingRequestsChannel = Channel.CreateUnbounded>(); _pendingRequestsReader = pendingRequestsChannel.Reader; @@ -70,6 +74,7 @@ internal PgPostmasterMock( _completeCancellationImmediately = completeCancellationImmediately; State = state; _startupErrorCode = startupErrorCode; + _breakOnGssEncryptionRequest = breakOnGssEncryptionRequest; _socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp); var endpoint = new IPEndPoint(IPAddress.Loopback, 0); @@ -80,17 +85,20 @@ internal PgPostmasterMock( Port = localEndPoint.Port; connectionStringBuilder.Host = Host; connectionStringBuilder.Port = Port; +#pragma warning disable CS0618 // Type or member is obsolete connectionStringBuilder.ServerCompatibilityMode = ServerCompatibilityMode.NoTypeLoading; +#pragma warning restore CS0618 // Type or member is obsolete ConnectionString = connectionStringBuilder.ConnectionString; _socket.Listen(5); } - public NpgsqlDataSourceBuilder GetDataSourceBuilder() - => new(ConnectionString); - - public NpgsqlDataSource CreateDataSource() - => NpgsqlDataSource.Create(ConnectionString); + public NpgsqlDataSource CreateDataSource(Action? 
configure = null) + { + var builder = new NpgsqlDataSourceBuilder(ConnectionString); + configure?.Invoke(builder); + return builder.Build(); + } void AcceptClients() { @@ -138,12 +146,32 @@ async Task Accept(bool completeCancellationImmediat var readBuffer = new NpgsqlReadBuffer(null!, stream, clientSocket, ReadBufferSize, Encoding, RelaxedEncoding); var writeBuffer = new NpgsqlWriteBuffer(null!, stream, clientSocket, WriteBufferSize, Encoding); + writeBuffer.MessageLengthValidation = false; await readBuffer.EnsureAsync(4); var len = readBuffer.ReadInt32(); await readBuffer.EnsureAsync(len - 4); var request = readBuffer.ReadInt32(); + if (request == GssRequest) + { + if (_breakOnGssEncryptionRequest) + { + readBuffer.Dispose(); + writeBuffer.Dispose(); + await stream.DisposeAsync(); + return default; + } + + writeBuffer.WriteByte((byte)'N'); + await writeBuffer.Flush(async: true); + + await readBuffer.EnsureAsync(4); + len = readBuffer.ReadInt32(); + await readBuffer.EnsureAsync(len - 4); + request = readBuffer.ReadInt32(); + } + if (request == SslRequest) { writeBuffer.WriteByte((byte)'N'); diff --git a/test/Npgsql.Tests/Support/PgServerMock.cs b/test/Npgsql.Tests/Support/PgServerMock.cs index c34a9315c8..9f7a799649 100644 --- a/test/Npgsql.Tests/Support/PgServerMock.cs +++ b/test/Npgsql.Tests/Support/PgServerMock.cs @@ -41,6 +41,7 @@ internal PgServerMock( _stream = stream; _readBuffer = readBuffer; _writeBuffer = writeBuffer; + writeBuffer.MessageLengthValidation = false; } internal async Task Startup(MockState state) diff --git a/test/Npgsql.Tests/Support/TestBase.cs b/test/Npgsql.Tests/Support/TestBase.cs index 463b132d56..8fc0131889 100644 --- a/test/Npgsql.Tests/Support/TestBase.cs +++ b/test/Npgsql.Tests/Support/TestBase.cs @@ -7,9 +7,11 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; +using Npgsql.Internal.Postgres; using Npgsql.Tests.Support; using NpgsqlTypes; using NUnit.Framework; +using 
NUnit.Framework.Constraints; namespace Npgsql.Tests; @@ -29,307 +31,310 @@ public abstract class TestBase #region Type testing - public async Task AssertType( + public Task AssertType( T value, string sqlLiteral, - string pgTypeName, - NpgsqlDbType? npgsqlDbType, - DbType? dbType = null, - DbType? inferredDbType = null, - bool isDefaultForReading = true, - bool isDefaultForWriting = true, - bool? isDefault = null, - bool isNpgsqlDbTypeInferredFromClrType = true, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, Func? comparer = null, + bool valueTypeEqualsFieldType = true, bool skipArrayCheck = false) - { - await using var connection = await OpenConnectionAsync(); - return await AssertType( - connection, value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForReading, isDefaultForWriting, - isDefault, isNpgsqlDbTypeInferredFromClrType, comparer, skipArrayCheck); - } + => AssertTypeCore(OpenConnectionAsync(), disposeConnection: true, () => value, sqlLiteral, dataTypeName, dataTypeInference, + dbType, comparer, valueTypeEqualsFieldType, skipArrayCheck); - public async Task AssertType( + public Task AssertType( NpgsqlDataSource dataSource, T value, string sqlLiteral, - string pgTypeName, - NpgsqlDbType? npgsqlDbType, - DbType? dbType = null, - DbType? inferredDbType = null, - bool isDefaultForReading = true, - bool isDefaultForWriting = true, - bool? isDefault = null, - bool isNpgsqlDbTypeInferredFromClrType = true, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, Func? 
comparer = null, + bool valueTypeEqualsFieldType = true, bool skipArrayCheck = false) - { - await using var connection = await dataSource.OpenConnectionAsync(); + => AssertTypeCore(dataSource.OpenConnectionAsync(), disposeConnection: true, () => value, sqlLiteral, dataTypeName, dataTypeInference, + dbType, comparer, valueTypeEqualsFieldType, skipArrayCheck); - return await AssertType(connection, value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForReading, - isDefaultForWriting, isDefault, isNpgsqlDbTypeInferredFromClrType, comparer, skipArrayCheck); - } - - public async Task AssertType( + public Task AssertType( NpgsqlConnection connection, T value, string sqlLiteral, - string pgTypeName, - NpgsqlDbType? npgsqlDbType, - DbType? dbType = null, - DbType? inferredDbType = null, - bool isDefaultForReading = true, - bool isDefaultForWriting = true, - bool? isDefault = null, - bool isNpgsqlDbTypeInferredFromClrType = true, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, Func? comparer = null, + bool valueTypeEqualsFieldType = true, bool skipArrayCheck = false) - { - if (isDefault is not null) - isDefaultForReading = isDefaultForWriting = isDefault.Value; + => AssertTypeCore(new(connection), disposeConnection: false, () => value, sqlLiteral, dataTypeName, dataTypeInference, + dbType, comparer, valueTypeEqualsFieldType, skipArrayCheck); - await AssertTypeWrite(connection, () => value, sqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefaultForWriting, isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); - return await AssertTypeRead(connection, sqlLiteral, pgTypeName, value, isDefaultForReading, comparer, fieldType: null, skipArrayCheck); - } + public Task AssertType( + Func valueFactory, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, + Func? 
comparer = null, + bool valueTypeEqualsFieldType = true, + bool skipArrayCheck = false) + => AssertTypeCore(OpenConnectionAsync(), disposeConnection: true, valueFactory, sqlLiteral, dataTypeName, dataTypeInference, + dbType, comparer, valueTypeEqualsFieldType, skipArrayCheck); - public async Task AssertTypeRead(string sqlLiteral, string pgTypeName, T expected, bool isDefault = true, bool skipArrayCheck = false) - { - await using var connection = await OpenConnectionAsync(); - return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer: null, fieldType: null, skipArrayCheck); - } + public Task AssertType( + NpgsqlDataSource dataSource, + Func valueFactory, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, + Func? comparer = null, + bool valueTypeEqualsFieldType = true, + bool skipArrayCheck = false) + => AssertTypeCore(dataSource.OpenConnectionAsync(), disposeConnection: true, valueFactory, sqlLiteral, dataTypeName, dataTypeInference, + dbType, comparer, valueTypeEqualsFieldType, skipArrayCheck); - public async Task AssertTypeRead(NpgsqlDataSource dataSource, string sqlLiteral, string pgTypeName, T expected, - bool isDefault = true, Func? comparer = null, Type? fieldType = null, bool skipArrayCheck = false) - { - await using var connection = await dataSource.OpenConnectionAsync(); - return await AssertTypeRead(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer, fieldType, skipArrayCheck); - } + public Task AssertType( + NpgsqlConnection connection, + Func valueFactory, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, + Func? 
comparer = null, + bool valueTypeEqualsFieldType = true, + bool skipArrayCheck = false) + => AssertTypeCore(new(connection), disposeConnection: false, valueFactory, sqlLiteral, dataTypeName, dataTypeInference, + dbType, comparer, valueTypeEqualsFieldType, skipArrayCheck); - public async Task AssertTypeWrite( - NpgsqlDataSource dataSource, - T value, - string expectedSqlLiteral, - string pgTypeName, - NpgsqlDbType npgsqlDbType, - DbType? dbType = null, - DbType? inferredDbType = null, - bool isDefault = true, - bool isNpgsqlDbTypeInferredFromClrType = true, + static async Task AssertTypeCore( + ValueTask connectionTask, + bool disposeConnection, + Func valueFactory, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, + Func? comparer = null, + bool valueTypeEqualsFieldType = true, bool skipArrayCheck = false) { - await using var connection = await dataSource.OpenConnectionAsync(); + var connection = await connectionTask; + await using var _ = disposeConnection ? connection : null; - await AssertTypeWrite(connection, () => value, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, - isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); + await AssertTypeWriteCore(new(connection), disposeConnection: false, valueFactory, sqlLiteral, + dataTypeName, dataTypeInference, dbType, skipArrayCheck); + return await AssertTypeReadCore(new(connection), disposeConnection: false, sqlLiteral, dataTypeName, valueFactory(), + valueTypeEqualsFieldType, comparer, skipArrayCheck); } public Task AssertTypeWrite( T value, - string expectedSqlLiteral, - string pgTypeName, - NpgsqlDbType npgsqlDbType, - DbType? dbType = null, - DbType? inferredDbType = null, - bool isDefault = true, - bool isNpgsqlDbTypeInferredFromClrType = true, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? 
dbType = null, bool skipArrayCheck = false) - => AssertTypeWrite(() => value, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, - isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); + => AssertTypeWriteCore(OpenConnectionAsync(), disposeConnection: true, () => value, sqlLiteral, + dataTypeName, dataTypeInference, dbType, skipArrayCheck); - public async Task AssertTypeWrite( - Func valueFactory, - string expectedSqlLiteral, - string pgTypeName, - NpgsqlDbType npgsqlDbType, - DbType? dbType = null, - DbType? inferredDbType = null, - bool isDefault = true, - bool isNpgsqlDbTypeInferredFromClrType = true, + public Task AssertTypeWrite( + NpgsqlDataSource dataSource, + T value, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, bool skipArrayCheck = false) - { - await using var connection = await OpenConnectionAsync(); - await AssertTypeWrite(connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, isNpgsqlDbTypeInferredFromClrType, skipArrayCheck); - } + => AssertTypeWriteCore(dataSource.OpenConnectionAsync(), disposeConnection: true, () => value, sqlLiteral, + dataTypeName, dataTypeInference, dbType, skipArrayCheck); - internal static async Task AssertTypeRead( + public Task AssertTypeWrite( NpgsqlConnection connection, + T value, string sqlLiteral, - string pgTypeName, - T expected, - bool isDefault = true, - Func? comparer = null, - Type? fieldType = null, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? 
dbType = null, bool skipArrayCheck = false) - { - var result = await AssertTypeReadCore(connection, sqlLiteral, pgTypeName, expected, isDefault, comparer); - - // Check the corresponding array type as well - if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal)) - { - await AssertTypeReadCore( - connection, - ArrayLiteral(sqlLiteral), - pgTypeName + "[]", - new[] { expected, expected }, - isDefault, - comparer is null ? null : (array1, array2) => comparer(array1[0], array2[0]) && comparer(array1[1], array2[1])); - } + => AssertTypeWriteCore(new(connection), disposeConnection: false, () => value, sqlLiteral, dataTypeName, dataTypeInference, + dbType, skipArrayCheck); - return result; - } - - internal static async Task AssertTypeReadCore( - NpgsqlConnection connection, + public Task AssertTypeWrite( + Func valueFactory, string sqlLiteral, - string pgTypeName, - T expected, - bool isDefault = true, - Func? comparer = null, - Type? fieldType = null) - { - if (sqlLiteral.Contains('\'')) - sqlLiteral = sqlLiteral.Replace("'", "''"); - - await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{pgTypeName}", connection); - await using var reader = await cmd.ExecuteReaderAsync(CommandBehavior.SequentialAccess); - await reader.ReadAsync(); - - var truncatedSqlLiteral = sqlLiteral.Length > 40 ? sqlLiteral[..40] + "..." : sqlLiteral; - - var dataTypeName = reader.GetDataTypeName(0); - var dotIndex = dataTypeName.IndexOf('.'); - if (dotIndex > -1 && dataTypeName.Substring(0, dotIndex) is "pg_catalog" or "public") - dataTypeName = dataTypeName.Substring(dotIndex + 1); - - Assert.That(dataTypeName, Is.EqualTo(pgTypeName), - $"Got wrong result from GetDataTypeName when reading '{truncatedSqlLiteral}'"); - - if (isDefault) - { - // For arrays, GetFieldType always returns typeof(Array), since PG arrays can have arbitrary dimensionality - Assert.That(reader.GetFieldType(0), Is.EqualTo(dataTypeName.EndsWith("[]") ? typeof(Array) : fieldType ?? 
typeof(T)), - $"Got wrong result from GetFieldType when reading '{truncatedSqlLiteral}'"); - } - - var actual = isDefault ? (T)reader.GetValue(0) : reader.GetFieldValue(0); - - Assert.That(actual, comparer is null ? Is.EqualTo(expected) : Is.EqualTo(expected).Using(new SimpleComparer(comparer)), - $"Got wrong result from GetFieldValue value when reading '{truncatedSqlLiteral}'"); + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, + bool skipArrayCheck = false) + => AssertTypeWriteCore(OpenConnectionAsync(), disposeConnection: true, valueFactory, sqlLiteral, + dataTypeName, dataTypeInference, dbType, skipArrayCheck); - return actual; - } + public Task AssertTypeWrite( + NpgsqlDataSource dataSource, + Func valueFactory, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, + bool skipArrayCheck = false) + => AssertTypeWriteCore(dataSource.OpenConnectionAsync(), disposeConnection: true, valueFactory, sqlLiteral, + dataTypeName, dataTypeInference, dbType, skipArrayCheck); - internal static async Task AssertTypeWrite( + public Task AssertTypeWrite( NpgsqlConnection connection, Func valueFactory, - string expectedSqlLiteral, - string pgTypeName, - NpgsqlDbType? npgsqlDbType, - DbType? dbType = null, - DbType? inferredDbType = null, - bool isDefault = true, - bool isNpgsqlDbTypeInferredFromClrType = true, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference = null, + DbTypes? dbType = null, bool skipArrayCheck = false) + => AssertTypeWriteCore(new(connection), disposeConnection: false, valueFactory, sqlLiteral, + dataTypeName, dataTypeInference, dbType, skipArrayCheck); + + static async Task AssertTypeWriteCore( + ValueTask connectionTask, + bool disposeConnection, + Func valueFactory, + string sqlLiteral, + string dataTypeName, + DataTypeInference? dataTypeInference, + DbTypes? 
dbType, + bool skipArrayCheck) { + var connection = await connectionTask; + await using var _ = disposeConnection ? connection : null; + await AssertTypeWriteCore( - connection, valueFactory, expectedSqlLiteral, pgTypeName, npgsqlDbType, dbType, inferredDbType, isDefault, - isNpgsqlDbTypeInferredFromClrType); + connection, valueFactory, sqlLiteral, + dataTypeName, dataTypeInference ?? DataTypeInference.Match, + dbType); // Check the corresponding array type as well - if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal)) + if (!skipArrayCheck && !dataTypeName.EndsWith("[]", StringComparison.Ordinal)) { await AssertTypeWriteCore( connection, () => new[] { valueFactory(), valueFactory() }, - ArrayLiteral(expectedSqlLiteral), - pgTypeName + "[]", - npgsqlDbType | NpgsqlDbType.Array, - dbType: null, - inferredDbType: null, - isDefault, - isNpgsqlDbTypeInferredFromClrType); + ArrayLiteral(sqlLiteral), + dataTypeName + "[]", dataTypeInference ?? DataTypeInference.Match, + expectedDbTypes: null); } } - internal static async Task AssertTypeWriteCore( - NpgsqlConnection connection, - Func valueFactory, - string expectedSqlLiteral, - string pgTypeName, - NpgsqlDbType? npgsqlDbType, - DbType? dbType = null, - DbType? inferredDbType = null, - bool isDefault = true, - bool isNpgsqlDbTypeInferredFromClrType = true) + public enum DataTypeInference + { + /// + /// Data type is inferred from the CLR value and matches the data type under test. + /// + Match, + + /// + /// Data type is inferred from the CLR value but differs from the data type under test. + /// + /// + /// Used when we get some inferred data type (e.g. CLR strings are inferred to be 'text') but this does not match the data type (e.g. 'json') under test. + /// + Mismatch, + + /// + /// Data type can not be inferred from the CLR value. 
+ /// + /// + /// This is for CLR types that are statically unknown to Npgsql (plugin types: NodaTime/NTS, composite types, enums...), + /// or where we specifically don't want to infer a data type because there's no good option + /// (e.g. uint can be mapped to 'oid/xid/cid', but we don't want any of these as a default/inferred data type) + /// + Nothing, + } + + public readonly struct DbTypes(DbType dataTypeMappedDbType, DbType valueInferredDbType, DbType dbTypeToSet) { - if (npgsqlDbType is null) - isNpgsqlDbTypeInferredFromClrType = false; + public DbType DataTypeMappedDbType { get; } = dataTypeMappedDbType; + public DbType ValueInferredDbType { get; } = valueInferredDbType; + + // The DbType to explicitly set on the parameter. Usually same as ValueInferredDbType, + // It differs when testing DbType aliases (e.g. VarNumeric → DbType.Decimal) as we want to test those also work correctly. + public DbType DbTypeToSet { get; } = dbTypeToSet; - inferredDbType ??= isNpgsqlDbTypeInferredFromClrType ? dbType ?? DbType.Object : DbType.Object; + public DbTypes(DbType dataTypeMappedDbType, DbType valueInferredDbType) + : this(dataTypeMappedDbType, valueInferredDbType, valueInferredDbType) {} - // TODO: Interferes with both multiplexing and connection-specific mapping (used e.g. in NodaTime) - // Reset the type mapper to make sure we're resolving this type with a clean slate (for isolation, just in case) - // connection.TypeMapper.Reset(); + public static implicit operator DbTypes(DbType dbType) => new(dbType, dbType, dbType); + } + + static async Task AssertTypeWriteCore( + NpgsqlConnection connection, + Func valueFactory, + string sqlLiteral, + string dataTypeName, + DataTypeInference dataTypeInference, + DbTypes? expectedDbTypes) + { + var npgsqlDbType = DataTypeName.FromDisplayName(dataTypeName).ToNpgsqlDbType(); // Strip any facet information (length/precision/scale) - var parenIndex = pgTypeName.IndexOf('('); - // var pgTypeNameWithoutFacets = parenIndex > -1 ? 
pgTypeName[..parenIndex] : pgTypeName; - var pgTypeNameWithoutFacets = parenIndex > -1 - ? pgTypeName[..parenIndex] + pgTypeName[(pgTypeName.IndexOf(')') + 1)..] - : pgTypeName; + var parenIndex = dataTypeName.IndexOf('('); + var dataTypeNameWithoutFacets = parenIndex > -1 + ? dataTypeName[..parenIndex] + dataTypeName[(dataTypeName.IndexOf(')') + 1)..] + : dataTypeName; + + // For composite type with dots in name, Postgresql returns name with quotes - scheme."My.type.name" + // but for npgsql mapping we should use names without quotes - scheme.My.type.name + var dataTypeNameWithoutFacetsAndQuotes = dataTypeNameWithoutFacets.Replace("\"", string.Empty); // We test the following scenarios (between 2 and 5 in total): - // 1. With NpgsqlDbType explicitly set - // 2. With DataTypeName explicitly set - // 3. With DbType explicitly set (if one was provided) - // 4. With only the value set (if it's the default) - // 5. With only the value set, using generic NpgsqlParameter (if it's the default) + // 1. With value and DataTypeName set + // 2. With value and NpgsqlDbType set (when available) + // 3. With value and DbType explicitly set + // 4. With only the value set + // 5. With only the value set, using generic NpgsqlParameter + + // We only actually attempt to write to the database with a set DataTypeName, NpgsqlDbType, or when data type inference is exact. 
var errorIdentifierIndex = -1; var errorIdentifier = new Dictionary(); await using var cmd = new NpgsqlCommand { Connection = connection }; NpgsqlParameter p; - // With NpgsqlDbType - if (npgsqlDbType is not null) - { - p = new NpgsqlParameter { Value = valueFactory(), NpgsqlDbType = npgsqlDbType.Value }; - cmd.Parameters.Add(p); - errorIdentifier[++errorIdentifierIndex] = $"NpgsqlDbType={npgsqlDbType}"; - CheckInference(); - } // With data type name - p = new NpgsqlParameter { Value = valueFactory(), DataTypeName = pgTypeNameWithoutFacets }; + p = new NpgsqlParameter { Value = valueFactory(), DataTypeName = dataTypeNameWithoutFacetsAndQuotes }; + errorIdentifier[++errorIdentifierIndex] = $"Value and DataTypeName={dataTypeNameWithoutFacetsAndQuotes}"; + DataTypeAsserts(); cmd.Parameters.Add(p); - errorIdentifier[++errorIdentifierIndex] = $"DataTypeName={pgTypeNameWithoutFacets}"; - CheckInference(); - // With DbType - if (dbType is not null) + // With NpgsqlDbType + if (npgsqlDbType is not null) { - p = new NpgsqlParameter { Value = valueFactory(), DbType = dbType.Value }; + p = new NpgsqlParameter { Value = valueFactory(), NpgsqlDbType = npgsqlDbType.Value }; + errorIdentifier[++errorIdentifierIndex] = $"Value and NpgsqlDbType={npgsqlDbType}"; + DataTypeAsserts(); cmd.Parameters.Add(p); - errorIdentifier[++errorIdentifierIndex] = $"DbType={dbType}"; - CheckInference(); } - if (isDefault) - { - // With (non-generic) value only - p = new NpgsqlParameter { Value = valueFactory() }; + // With DbType, if none was supplied we verify it's DbType.Object. 
+ p = new NpgsqlParameter { Value = valueFactory() }; + errorIdentifier[++errorIdentifierIndex] = $"Value and DbType={expectedDbTypes?.DbTypeToSet}"; + if (expectedDbTypes?.DbTypeToSet is { } expectedDbType) + p.DbType = expectedDbType; + DbTypeAsserts(); + if (dataTypeInference is DataTypeInference.Match) cmd.Parameters.Add(p); - errorIdentifier[++errorIdentifierIndex] = $"Value only (type {p.Value!.GetType().Name}, non-generic)"; - CheckInference(valueOnlyInference: true); - // With (generic) value only - p = new NpgsqlParameter { TypedValue = valueFactory() }; + // With (non-generic) value only + p = new NpgsqlParameter { Value = valueFactory() }; + errorIdentifier[++errorIdentifierIndex] = $"Value (type {p.Value!.GetType().Name}, non-generic)"; + ValueAsserts(); + if (dataTypeInference is DataTypeInference.Match) cmd.Parameters.Add(p); - errorIdentifier[++errorIdentifierIndex] = $"Value only (type {p.Value!.GetType().Name}, generic)"; - CheckInference(valueOnlyInference: true); - } - Debug.Assert(cmd.Parameters.Count == errorIdentifierIndex + 1); + // With (generic) value only + p = new NpgsqlParameter { TypedValue = valueFactory() }; + errorIdentifier[++errorIdentifierIndex] = $"Value (type {p.Value!.GetType().Name}, generic)"; + ValueAsserts(); + if (dataTypeInference is DataTypeInference.Match) + cmd.Parameters.Add(p); cmd.CommandText = "SELECT " + string.Join(", ", Enumerable.Range(1, cmd.Parameters.Count).Select(i => "pg_typeof($1)::text, $1::text".Replace("$1", $"${i}"))); @@ -339,134 +344,261 @@ internal static async Task AssertTypeWriteCore( for (var i = 0; i < cmd.Parameters.Count * 2; i += 2) { - Assert.That(reader[i], Is.EqualTo(pgTypeNameWithoutFacets), $"Got wrong PG type name when writing with {errorIdentifier[i / 2]}"); - Assert.That(reader[i+1], Is.EqualTo(expectedSqlLiteral), $"Got wrong SQL literal when writing with {errorIdentifier[i / 2]}"); + var error = errorIdentifier[i / 2]; + Assert.That(reader[i], Is.EqualTo(dataTypeNameWithoutFacets), 
$"Got wrong data type name when writing with {error}"); + Assert.That(reader[i+1], Is.EqualTo(sqlLiteral), $"Got wrong SQL literal when writing with {error}"); } - void CheckInference(bool valueOnlyInference = false) + void DataTypeAsserts() { - if (isNpgsqlDbTypeInferredFromClrType && npgsqlDbType is not null) - { - Assert.That(p.NpgsqlDbType, Is.EqualTo(npgsqlDbType), - () => $"Got wrong inferred NpgsqlDbType when inferring with {errorIdentifier[errorIdentifierIndex]}"); - } + var expectedDataTypeName = dataTypeNameWithoutFacetsAndQuotes; + var expectedNpgsqlDbType = npgsqlDbType ?? NpgsqlDbType.Unknown; + + var expectedDbType = expectedDbTypes?.DataTypeMappedDbType ?? DbType.Object; + + AssertParameterProperties(expectedDataTypeName, expectedNpgsqlDbType, expectedDbType); + } + + void DbTypeAsserts() + { + // If DbType was set it overrules any value based data type inference. + // As DbType.Object never has any mapping either we check for null/Unknown when DbType.Object was set. + var (expectedDataTypeName, expectedNpgsqlDbType) = + expectedDbTypes is { DbTypeToSet: DbType.Object } + ? (null, NpgsqlDbType.Unknown) + : GetInferredDataType(); + + var expectedDbType = expectedDbTypes?.DbTypeToSet ?? DbType.Object; - Assert.That(p.DbType, Is.EqualTo(valueOnlyInference ? inferredDbType : isNpgsqlDbTypeInferredFromClrType ? inferredDbType : dbType ?? DbType.Object), - () => $"Got wrong inferred DbType when inferring with {errorIdentifier[errorIdentifierIndex]}"); + AssertParameterProperties(expectedDataTypeName, expectedNpgsqlDbType, expectedDbType); + } + + void ValueAsserts() + { + var (expectedDataTypeName, expectedNpgsqlDbType) = GetInferredDataType(); - if (isNpgsqlDbTypeInferredFromClrType) - Assert.That(p.DataTypeName, Is.EqualTo(pgTypeNameWithoutFacets), - () => $"Got wrong inferred DataTypeName when inferring with {errorIdentifier[errorIdentifierIndex]}"); + var expectedDbType = expectedDbTypes?.ValueInferredDbType ?? 
DbType.Object; + + AssertParameterProperties(expectedDataTypeName, expectedNpgsqlDbType, expectedDbType); } + + void AssertParameterProperties(string? expectedDataTypeName, NpgsqlDbType expectedNpgsqlDbType, DbType expectedDbType) + { + Assert.That(p.DataTypeName, Is.EqualTo(expectedDataTypeName), + $"Got wrong DataTypeName when checking with {errorIdentifier[errorIdentifierIndex]}"); + Assert.That(p.NpgsqlDbType, Is.EqualTo(expectedNpgsqlDbType), + $"Got wrong NpgsqlDbType when checking with {errorIdentifier[errorIdentifierIndex]}"); + Assert.That(p.DbType, Is.EqualTo(expectedDbType), + $"Got wrong DbType when checking with {errorIdentifier[errorIdentifierIndex]}"); + } + + (string? ExpectedDataTypeName, NpgsqlDbType ExpectedNpgsqlDbType) GetInferredDataType() + => dataTypeInference switch + { + DataTypeInference.Match => + (dataTypeNameWithoutFacetsAndQuotes, npgsqlDbType ?? NpgsqlDbType.Unknown), + DataTypeInference.Mismatch => + // Only respect Mismatch if the type is well known (for now that means it has an NpgsqlDbType). + // Otherwise use the exact values so we'll error with the right details. + p.NpgsqlDbType is not NpgsqlDbType.Unknown + ? (p.DataTypeName, p.NpgsqlDbType) + : (dataTypeNameWithoutFacetsAndQuotes, npgsqlDbType ?? NpgsqlDbType.Unknown), + DataTypeInference.Nothing => + (null, NpgsqlDbType.Unknown), + _ => throw new UnreachableException($"Unknown case {dataTypeInference}") + }; } - public async Task AssertTypeUnsupported(T value, string sqlLiteral, string pgTypeName, NpgsqlDataSource? dataSource = null) + public Task AssertTypeRead(string sqlLiteral, string dataTypeName, T value, + bool valueTypeEqualsFieldType = true, Func? 
comparer = null, bool skipArrayCheck = false) + => AssertTypeReadCore(OpenConnectionAsync(), disposeConnection: true, sqlLiteral, dataTypeName, + value, valueTypeEqualsFieldType, comparer, skipArrayCheck); + + public Task AssertTypeRead(NpgsqlDataSource dataSource, string sqlLiteral, string dataTypeName, T value, + bool valueTypeEqualsFieldType = true, Func? comparer = null, bool skipArrayCheck = false) + => AssertTypeReadCore(dataSource.OpenConnectionAsync(), disposeConnection: true, sqlLiteral, dataTypeName, + value, valueTypeEqualsFieldType, comparer, skipArrayCheck); + + public Task AssertTypeRead(NpgsqlConnection connection, string sqlLiteral, string dataTypeName, T value, + bool valueTypeEqualsFieldType = true, Func? comparer = null, bool skipArrayCheck = false) + => AssertTypeReadCore(new(connection), disposeConnection: false, sqlLiteral, dataTypeName, + value, valueTypeEqualsFieldType, comparer, skipArrayCheck); + + static async Task AssertTypeReadCore( + ValueTask connectionTask, + bool disposeConnection, + string sqlLiteral, + string dataTypeName, + T value, + bool valueTypeEqualsFieldType, + Func? comparer, + bool skipArrayCheck) { - await AssertTypeUnsupportedRead(sqlLiteral, pgTypeName, dataSource); - await AssertTypeUnsupportedWrite(value, pgTypeName, dataSource); + var connection = await connectionTask; + await using var _ = disposeConnection ? connection : null; + + var result = await AssertTypeReadCore(connection, sqlLiteral, dataTypeName, value, valueTypeEqualsFieldType, comparer); + + // Check the corresponding array type as well + if (!skipArrayCheck && !dataTypeName.EndsWith("[]", StringComparison.Ordinal)) + { + await AssertTypeReadCore( + connection, + ArrayLiteral(sqlLiteral), + dataTypeName + "[]", + new[] { value, value }, + valueTypeEqualsFieldType, + comparer is null ? 
null : (array1, array2) => array1.SequenceEqual(array2, CreateEqualityComparer(comparer!))); + } + return result; } - public async Task AssertTypeUnsupportedRead(string sqlLiteral, string pgTypeName, NpgsqlDataSource? dataSource = null) + static async Task AssertTypeReadCore( + NpgsqlConnection connection, + string sqlLiteral, + string dataTypeName, + T value, + bool valueTypeEqualsFieldType, + Func? comparer) { - dataSource ??= DataSource; + if (sqlLiteral.Contains('\'')) + sqlLiteral = sqlLiteral.Replace("'", "''"); - await using var conn = await dataSource.OpenConnectionAsync(); - // Make sure we don't poison the connection with a fault, potentially terminating other perfectly passing tests as well. - await using var tx = dataSource.Settings.Multiplexing ? await conn.BeginTransactionAsync() : null; - await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{pgTypeName}", conn); + await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{dataTypeName}", connection); await using var reader = await cmd.ExecuteReaderAsync(); await reader.ReadAsync(); - return Assert.Throws(() => reader.GetValue(0))!; + var truncatedSqlLiteral = sqlLiteral.Length > 40 ? sqlLiteral[..40] + "..." : sqlLiteral; + + var actualDataTypeName = reader.GetDataTypeName(0); + var dotIndex = actualDataTypeName.IndexOf('.'); + if (dotIndex > -1 && actualDataTypeName.Substring(0, dotIndex) is "pg_catalog" or "public") + actualDataTypeName = actualDataTypeName.Substring(dotIndex + 1); + + // For composite type with dots, postgres works only with quoted name - scheme."My.type.name" + // but npgsql converts it to name without quotes + var dataTypeNameWithoutQuotes = dataTypeName.Replace("\"", string.Empty); + Assert.That(actualDataTypeName, Is.EqualTo(dataTypeNameWithoutQuotes), + $"Got wrong result from GetDataTypeName when reading '{truncatedSqlLiteral}'"); + + // For arrays, GetFieldType always returns typeof(Array), since PG arrays can have arbitrary dimensionality. 
+ var isArrayTest = actualDataTypeName.EndsWith("[]", StringComparison.Ordinal) && typeof(T).IsArray; + Assert.That(reader.GetFieldType(0), + (valueTypeEqualsFieldType || isArrayTest ? new ConstraintExpression() : Is.Not) + .EqualTo(isArrayTest ? typeof(Array) : typeof(T)), + $"Got wrong result from GetFieldType when reading '{truncatedSqlLiteral}'"); + + T actual; + if (valueTypeEqualsFieldType) + { + actual = (T)reader.GetValue(0); + Assert.That(actual, comparer is null ? Is.EqualTo(value) : Is.EqualTo(value).Using(CreateEqualityComparer(comparer!)), + $"Got wrong result from GetValue() value when reading '{truncatedSqlLiteral}'"); + + actual = (T)reader.GetFieldValue(0); + Assert.That(actual, comparer is null ? Is.EqualTo(value) : Is.EqualTo(value).Using(CreateEqualityComparer(comparer)), + $"Got wrong result from GetFieldValue() value when reading '{truncatedSqlLiteral}'"); + + return actual; + } + + actual = reader.GetFieldValue(0); + + Assert.That(actual, comparer is null ? Is.EqualTo(value) : Is.EqualTo(value).Using(CreateEqualityComparer(comparer!)), + $"Got wrong result from GetFieldValue() value when reading '{truncatedSqlLiteral}'"); + + return actual; } - public Task AssertTypeUnsupportedRead(string sqlLiteral, string pgTypeName, + static EqualityComparer CreateEqualityComparer(Func comparer) + => EqualityComparer.Create((x, y) => + { + if (x is null && y is null) + return true; + if (x is null || y is null) + return false; + return comparer(x, y); + }); + + public async Task AssertTypeUnsupported(T value, string sqlLiteral, string dataTypeName, NpgsqlDataSource? dataSource = null, bool skipArrayCheck = false) + { + await AssertTypeUnsupportedRead(sqlLiteral, dataTypeName, dataSource, skipArrayCheck); + await AssertTypeUnsupportedWrite(value, dataTypeName, dataSource, skipArrayCheck); + } + + public Task AssertTypeUnsupportedRead(string sqlLiteral, string dataTypeName, NpgsqlDataSource? 
dataSource = null, bool skipArrayCheck = false) - => AssertTypeUnsupportedRead(sqlLiteral, pgTypeName, dataSource); + => AssertTypeUnsupportedRead(sqlLiteral, dataTypeName, dataSource, skipArrayCheck); - public async Task AssertTypeUnsupportedRead(string sqlLiteral, string pgTypeName, + public async Task AssertTypeUnsupportedRead(string sqlLiteral, string dataTypeName, NpgsqlDataSource? dataSource = null, bool skipArrayCheck = false) where TException : Exception { - var result = await AssertTypeUnsupportedReadCore(sqlLiteral, pgTypeName, dataSource); + var result = await AssertTypeUnsupportedReadCore(sqlLiteral, dataTypeName, dataSource); // Check the corresponding array type as well - if (!skipArrayCheck && !pgTypeName.EndsWith("[]", StringComparison.Ordinal)) + if (!skipArrayCheck && !dataTypeName.EndsWith("[]", StringComparison.Ordinal)) { - await AssertTypeUnsupportedReadCore(ArrayLiteral(sqlLiteral), pgTypeName + "[]", dataSource); + await AssertTypeUnsupportedReadCore(ArrayLiteral(sqlLiteral), dataTypeName + "[]", dataSource); } return result; } - async Task AssertTypeUnsupportedReadCore(string sqlLiteral, string pgTypeName, NpgsqlDataSource? dataSource = null) + async Task AssertTypeUnsupportedReadCore(string sqlLiteral, string dataTypeName, NpgsqlDataSource? dataSource = null) where TException : Exception { dataSource ??= DataSource; await using var conn = await dataSource.OpenConnectionAsync(); - // Make sure we don't poison the connection with a fault, potentially terminating other perfectly passing tests as well. - await using var tx = dataSource.Settings.Multiplexing ? 
await conn.BeginTransactionAsync() : null; - await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{pgTypeName}", conn); + await using var tx = await conn.BeginTransactionAsync(); + await using var cmd = new NpgsqlCommand($"SELECT '{sqlLiteral}'::{dataTypeName}", conn); await using var reader = await cmd.ExecuteReaderAsync(); await reader.ReadAsync(); - return Assert.Throws(() => reader.GetFieldValue(0))!; + return Assert.Throws(() => + { + _ = typeof(T) == typeof(object) ? reader.GetValue(0) : reader.GetFieldValue(0); + })!; } - public Task AssertTypeUnsupportedWrite(T value, string? pgTypeName = null, NpgsqlDataSource? dataSource = null, + public Task AssertTypeUnsupportedWrite(T value, string? dataTypeName = null, NpgsqlDataSource? dataSource = null, bool skipArrayCheck = false) - => AssertTypeUnsupportedWrite(value, pgTypeName, dataSource, skipArrayCheck: false); + => AssertTypeUnsupportedWrite(value, dataTypeName, dataSource, skipArrayCheck); - public async Task AssertTypeUnsupportedWrite(T value, string? pgTypeName = null, + public async Task AssertTypeUnsupportedWrite(T value, string? dataTypeName = null, NpgsqlDataSource? dataSource = null, bool skipArrayCheck = false) where TException : Exception { - var result = await AssertTypeUnsupportedWriteCore(value, pgTypeName, dataSource); + var result = await AssertTypeUnsupportedWriteCore(value, dataTypeName, dataSource); // Check the corresponding array type as well - if (!skipArrayCheck && !pgTypeName?.EndsWith("[]", StringComparison.Ordinal) == true) + if (!skipArrayCheck && !dataTypeName?.EndsWith("[]", StringComparison.Ordinal) == true) { - await AssertTypeUnsupportedWriteCore(new[] { value, value }, pgTypeName + "[]", dataSource); + await AssertTypeUnsupportedWriteCore([value, value], dataTypeName + "[]", dataSource); } return result; } - async Task AssertTypeUnsupportedWriteCore(T value, string? pgTypeName = null, NpgsqlDataSource? 
dataSource = null) + async Task AssertTypeUnsupportedWriteCore(T value, string? dataTypeName = null, NpgsqlDataSource? dataSource = null) where TException : Exception { dataSource ??= DataSource; await using var conn = await dataSource.OpenConnectionAsync(); - // Make sure we don't poison the connection with a fault, potentially terminating other perfectly passing tests as well. - await using var tx = dataSource.Settings.Multiplexing ? await conn.BeginTransactionAsync() : null; + await using var tx = await conn.BeginTransactionAsync(); await using var cmd = new NpgsqlCommand("SELECT $1", conn) { Parameters = { new() { Value = value } } }; - if (pgTypeName is not null) - cmd.Parameters[0].DataTypeName = pgTypeName; + if (dataTypeName is not null) + cmd.Parameters[0].DataTypeName = dataTypeName; return Assert.ThrowsAsync(() => cmd.ExecuteReaderAsync())!; } - class SimpleComparer : IEqualityComparer - { - readonly Func _comparerDelegate; - - public SimpleComparer(Func comparerDelegate) - => _comparerDelegate = comparerDelegate; - - public bool Equals(T? x, T? y) - => x is null - ? 
y is null - : y is not null && _comparerDelegate(x, y); - - public int GetHashCode(T obj) => throw new NotSupportedException(); - } - // For array quoting rules, see array_out in https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/arrayfuncs.c static string ArrayLiteral(string elementLiteral) { @@ -634,7 +766,6 @@ async Task OpenConnectionInternal(bool hasLock) var builder = new NpgsqlConnectionStringBuilder(TestUtil.ConnectionString) { Pooling = false, - Multiplexing = false, Database = "postgres" }; diff --git a/test/Npgsql.Tests/SyncOrAsyncTestBase.cs b/test/Npgsql.Tests/SyncOrAsyncTestBase.cs index 0eff0c7488..2a676e97cb 100644 --- a/test/Npgsql.Tests/SyncOrAsyncTestBase.cs +++ b/test/Npgsql.Tests/SyncOrAsyncTestBase.cs @@ -4,13 +4,11 @@ namespace Npgsql.Tests; [TestFixture(SyncOrAsync.Sync)] [TestFixture(SyncOrAsync.Async)] -public abstract class SyncOrAsyncTestBase : TestBase +public abstract class SyncOrAsyncTestBase(SyncOrAsync syncOrAsync) : TestBase { protected bool IsAsync => SyncOrAsync == SyncOrAsync.Async; - protected SyncOrAsync SyncOrAsync { get; } - - protected SyncOrAsyncTestBase(SyncOrAsync syncOrAsync) => SyncOrAsync = syncOrAsync; + protected SyncOrAsync SyncOrAsync { get; } = syncOrAsync; } public enum SyncOrAsync diff --git a/test/Npgsql.Tests/SystemTransactionTests.cs b/test/Npgsql.Tests/SystemTransactionTests.cs index b71c949259..2363fb170a 100644 --- a/test/Npgsql.Tests/SystemTransactionTests.cs +++ b/test/Npgsql.Tests/SystemTransactionTests.cs @@ -258,6 +258,34 @@ public void Reuse_connection_rollback() AssertNumberOfRows(0, tableName); } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3735")] + public void Reuse_connection_resets_temp_tables() + { + // When a connection is closed inside a TransactionScope and then reopened, + // temp tables should be discarded. 
+ using var dataSource = CreateDataSource(csb => csb.Enlist = true); + using (new TransactionScope()) + using (var conn = dataSource.CreateConnection()) + { + conn.Open(); + var processId = conn.ProcessID; + + // Create a temp table + conn.ExecuteNonQuery("CREATE TEMP TABLE temp_test (id INT)"); + + conn.Close(); + + // Reopen - should get the same physical connection but with reset state + conn.Open(); + Assert.That(conn.ProcessID, Is.EqualTo(processId), "Should reuse the same physical connection"); + + // The temp table should have been discarded + Assert.That(() => conn.ExecuteScalar("SELECT COUNT(*) FROM temp_test"), + Throws.Exception.TypeOf() + .With.Property(nameof(PostgresException.SqlState)).EqualTo(PostgresErrorCodes.UndefinedTable)); + } + } + [Test, Ignore("Timeout doesn't seem to fire on .NET Core / Linux")] public void Timeout_triggers_rollback_while_busy() { @@ -310,13 +338,15 @@ public void Single_unpooled_connection() scope.Complete(); } - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4963")] - public void Single_unpooled_closed_connection() + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4963"), IssueLink("https://github.com/npgsql/npgsql/issues/5783")] + public void Single_closed_connection_in_transaction_scope([Values] bool pooling, [Values] bool multipleHosts) { using var dataSource = CreateDataSource(csb => { - csb.Pooling = false; + csb.Pooling = pooling; csb.Enlist = true; + csb.Host = multipleHosts ? "localhost,127.0.0.1" : csb.Host; }); using (var scope = new TransactionScope()) @@ -325,11 +355,11 @@ public void Single_unpooled_closed_connection() { cmd.ExecuteNonQuery(); conn.Close(); - Assert.That(dataSource.Statistics.Total, Is.EqualTo(1)); + Assert.That(pooling ? dataSource.Statistics.Busy : dataSource.Statistics.Total, Is.EqualTo(1)); scope.Complete(); } - Assert.That(dataSource.Statistics.Total, Is.EqualTo(0)); + Assert.That(pooling ? 
dataSource.Statistics.Busy : dataSource.Statistics.Total, Is.EqualTo(0)); } [Test] diff --git a/test/Npgsql.Tests/TaskTimeoutAndCancellationTest.cs b/test/Npgsql.Tests/TaskTimeoutAndCancellationTest.cs deleted file mode 100644 index e3759d35e9..0000000000 --- a/test/Npgsql.Tests/TaskTimeoutAndCancellationTest.cs +++ /dev/null @@ -1,162 +0,0 @@ -using System; -using System.Threading; -using System.Threading.Tasks; -using NUnit.Framework; -using Npgsql.Util; - -namespace Npgsql.Tests; - -[NonParallelizable] // To make sure unobserved tasks from other tests do not leak -public class TaskTimeoutAndCancellationTest : TestBase -{ - const int TestResultValue = 777; - - async Task GetResultTaskAsync(int timeout, CancellationToken ct) - { - await Task.Delay(timeout, ct); - return TestResultValue; - } - - Task GetVoidTaskAsync(int timeout, CancellationToken ct) => Task.Delay(timeout, ct); - - [Test] - public async Task SuccessfulResultTaskAsync() => - Assert.AreEqual(TestResultValue, await TaskTimeoutAndCancellation.ExecuteAsync(ct => GetResultTaskAsync(10, ct), NpgsqlTimeout.Infinite, CancellationToken.None)); - - [Test] - public async Task SuccessfulVoidTaskAsync() => - await TaskTimeoutAndCancellation.ExecuteAsync(ct => GetVoidTaskAsync(10, ct), NpgsqlTimeout.Infinite, CancellationToken.None); - - [Test] - public void InfinitelyLongTaskTimeout() => - Assert.ThrowsAsync(async () => - await TaskTimeoutAndCancellation.ExecuteAsync(ct => GetVoidTaskAsync(Timeout.Infinite, ct), new NpgsqlTimeout(TimeSpan.FromMilliseconds(10)), CancellationToken.None)); - - [Test] - public void InfinitelyLongTaskCancellation() - { - using var cts = new CancellationTokenSource(10); - Assert.ThrowsAsync(async () => - await TaskTimeoutAndCancellation.ExecuteAsync(ct => GetVoidTaskAsync(Timeout.Infinite, ct), NpgsqlTimeout.Infinite, cts.Token)); - } - - /// - /// The test creates a delayed execution Task that is being fake-cancelled and fails subsequently and triggers 
'TaskScheduler.UnobservedTaskException event'. - /// - /// - /// The test is based on timing and depends on availability of thread pool threads. Therefore it could become unstable if the environment is under pressure. - /// - [Theory, IssueLink("https://github.com/npgsql/npgsql/issues/4149")] - [TestCase("CancelAndTimeout")] - [TestCase("CancelOnly")] - [TestCase("TimeoutOnly")] - [TestCase("CancelAndTimeout")] - [TestCase("CancelOnly")] - [TestCase("TimeoutOnly")] - public Task DelayedFaultedTaskCancellation(string testCase) => RunDelayedFaultedTaskTestAsync(async getUnobservedTaskException => - { - var cancel = true; - var timeout = true; - switch (testCase) - { - case "TimeoutOnly": - cancel = false; - break; - case "CancelOnly": - timeout = false; - break; - } - - var notifyDelayCompleted = new SemaphoreSlim(0, 1); - - // Invoke the method that creates a delayed execution Task that fails subsequently. - await CreateTaskAndPreemptWithCancellationAsync(500, cancel, timeout, notifyDelayCompleted); - - // Wait enough time for the non-cancelable task to notify us that an exception is thrown. - await notifyDelayCompleted.WaitAsync(); - - // And then wait some more. - var repeatCount = 2; - while (getUnobservedTaskException() is null && repeatCount-- > 0) - { - await Task.Delay(100); - - // Run the garbage collector to collect unobserved Tasks. - GC.Collect(); - GC.WaitForPendingFinalizers(); - } - }); - - static async Task RunDelayedFaultedTaskTestAsync(Func, Task> test) - { - // Run the garbage collector to collect unobserved Tasks from other tests. - GC.Collect(); - GC.WaitForPendingFinalizers(); - GC.Collect(); - - Exception? unobservedTaskException = null; - - // Subscribe to UnobservedTaskException event to store the Exception, if any. - void OnUnobservedTaskException(object? 
source, UnobservedTaskExceptionEventArgs args) - { - if (!args.Observed) - { - args.SetObserved(); - } - unobservedTaskException = args.Exception; - } - TaskScheduler.UnobservedTaskException += OnUnobservedTaskException; - - try - { - await test(() => unobservedTaskException); - - // Verify the unobserved Task exception event has not been received. - Assert.IsNull(unobservedTaskException, unobservedTaskException?.Message); - } - finally - { - TaskScheduler.UnobservedTaskException -= OnUnobservedTaskException; - } - } - - /// - /// Create a delayed execution, non-Cancellable Task that fails subsequently after the Task goes out of scope. - /// - static async Task CreateTaskAndPreemptWithCancellationAsync(int delayMs, bool cancel, bool timeout, SemaphoreSlim notifyDelayCompleted) - { - var nonCancellableTask = Task.Delay(delayMs, CancellationToken.None) - .ContinueWith( - async _ => - { - try - { - await Task.FromException(new Exception("Unobserved Task Test Exception")); - } - finally - { - notifyDelayCompleted.Release(); - } - }) - .Unwrap(); - - var timeoutMs = delayMs / 5; - using var cts = cancel ? new CancellationTokenSource(timeoutMs) : null; - try - { - await TaskTimeoutAndCancellation.ExecuteAsync( - _ => nonCancellableTask, - timeout ? new NpgsqlTimeout(TimeSpan.FromMilliseconds(timeoutMs)) : NpgsqlTimeout.Infinite, - cts?.Token ?? CancellationToken.None); - } - catch (TimeoutException) - { - // Expected due to preemptive time out. - } - catch (OperationCanceledException) when (cts?.IsCancellationRequested == true) - { - // Expected due to preemptive cancellation. - } - Assert.False(nonCancellableTask.IsCompleted); - } -} diff --git a/test/Npgsql.Tests/TestMetrics.cs b/test/Npgsql.Tests/TestMetrics.cs index 52bf2ed935..3b6c11dbda 100644 --- a/test/Npgsql.Tests/TestMetrics.cs +++ b/test/Npgsql.Tests/TestMetrics.cs @@ -41,17 +41,13 @@ private TestMetrics(TimeSpan allowedTime, bool reportOnStop) /// Report metrics to stdout when stopped. 
/// A new running TestMetrics object. public static TestMetrics Start(TimeSpan allowedTime, bool reportOnStop) - { - return new(allowedTime, reportOnStop); - } + => new(allowedTime, reportOnStop); /// - /// Incremnent the Iterations value by one. + /// Increment the Iterations value by one. /// public void IncrementIterations() - { - Iterations++; - } + => Iterations++; /// /// Stop the internal stop watch and record elapsed CPU times. @@ -81,9 +77,7 @@ public void Stop() /// Stop the internal stop watch and record elapsed CPU times. /// public void Dispose() - { - Stop(); - } + => Stop(); /// /// Report whether ElapsedClockTime has met or exceeded the maximum run time. @@ -96,19 +90,15 @@ public void Dispose() /// /// The number of iterations accumulated per the time span provided. public double IterationsPer(TimeSpan timeSpan) - { - return (double)Iterations / ((double)stopwatch.Elapsed.TotalMilliseconds / (double)timeSpan.TotalMilliseconds); - } + => (double)Iterations / ((double)stopwatch.Elapsed.TotalMilliseconds / (double)timeSpan.TotalMilliseconds); /// /// Calculate the number of iterations accumulated per second. - /// Equivelent to calling IterationsPer(new TimeSpan(0, 0, 1)). + /// Equivalent to calling IterationsPer(new TimeSpan(0, 0, 1)). /// /// The number of iterations accumulated per second. public double IterationsPerSecond() - { - return IterationsPer(new TimeSpan(0, 0, 1)); - } + => IterationsPer(new TimeSpan(0, 0, 1)); /// /// Calculate the number of iterations accumulated per the CPU time span provided. @@ -116,20 +106,16 @@ public double IterationsPerSecond() /// /// The number of iterations accumulated per the CPU time span provided. 
public double IterationsPerCPU(TimeSpan timeSpan) - { - return (double)Iterations / ((double)ElapsedTotalCPUTime.TotalMilliseconds / (double)timeSpan.TotalMilliseconds); - } + => (double)Iterations / ((double)ElapsedTotalCPUTime.TotalMilliseconds / (double)timeSpan.TotalMilliseconds); /// /// Calculate the number of iterations accumulated per CPU second. - /// Equivelent to calling IterationsPerCPU(new TimeSpan(0, 0, 1)). + /// Equivalent to calling IterationsPerCPU(new TimeSpan(0, 0, 1)). /// /// /// The number of iterations accumulated per CPU second. public double IterationsPerCPUSecond() - { - return IterationsPerCPU(new TimeSpan(0, 0, 1)); - } + => IterationsPerCPU(new TimeSpan(0, 0, 1)); /// /// Elapsed time since start. @@ -176,4 +162,4 @@ public TimeSpan ElapsedUserCPUTime /// Elapsed total (system + user) CPU time since start. /// public TimeSpan ElapsedTotalCPUTime => ElapsedSystemCPUTime + ElapsedUserCPUTime; -} \ No newline at end of file +} diff --git a/test/Npgsql.Tests/TestUtil.cs b/test/Npgsql.Tests/TestUtil.cs index 81b1140f48..0f83946ac7 100644 --- a/test/Npgsql.Tests/TestUtil.cs +++ b/test/Npgsql.Tests/TestUtil.cs @@ -19,7 +19,7 @@ public static class TestUtil /// test database. /// public const string DefaultConnectionString = - "Host=localhost;Username=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests;Timeout=0;Command Timeout=0;SSL Mode=Disable;Multiplexing=False"; + "Host=localhost;Username=npgsql_tests;Password=npgsql_tests;Database=npgsql_tests;Timeout=0;Command Timeout=0;SSL Mode=Disable"; /// /// The connection string that will be used when opening the connection to the tests database. @@ -86,9 +86,12 @@ public static void MaximumPgVersionExclusive(NpgsqlConnection conn, string maxVe static readonly Version MinCreateExtensionVersion = new(9, 1); - public static void IgnoreOnRedshift(NpgsqlConnection conn, string? ignoreText = null) + public static async Task IgnoreOnRedshift(NpgsqlConnection conn, string? 
ignoreText = null) { - if (new NpgsqlConnectionStringBuilder(conn.ConnectionString).ServerCompatibilityMode == ServerCompatibilityMode.Redshift) + await using var command = conn.CreateCommand(); + command.CommandText = "SELECT version()"; + var version = (string)(await command.ExecuteScalarAsync())!; + if (version.Contains("redshift", StringComparison.OrdinalIgnoreCase)) { var msg = "Test ignored on Redshift"; if (ignoreText != null) @@ -97,9 +100,6 @@ public static void IgnoreOnRedshift(NpgsqlConnection conn, string? ignoreText = } } - public static bool IsPgPrerelease(NpgsqlConnection conn) - => ((string)conn.ExecuteScalar("SELECT version()")!).Contains("beta"); - public static void EnsureExtension(NpgsqlConnection conn, string extension, string? minVersion = null) => EnsureExtension(conn, extension, minVersion, async: false).GetAwaiter().GetResult(); @@ -165,21 +165,19 @@ static async Task IgnoreIfFeatureNotSupported(NpgsqlConnection conn, string test public static async Task EnsurePostgis(NpgsqlConnection conn) { - var isPreRelease = IsPgPrerelease(conn); try { await EnsureExtensionAsync(conn, "postgis"); } - catch (PostgresException e) when (e.SqlState == PostgresErrorCodes.UndefinedFile) + catch (PostgresException) { - // PostGIS packages aren't available for PostgreSQL prereleases - if (isPreRelease) + if (Environment.GetEnvironmentVariable("NPGSQL_TEST_POSTGIS")?.ToLower(CultureInfo.InvariantCulture) is "1" or "true") { - Assert.Ignore($"PostGIS could not be installed, but PostgreSQL is prerelease ({conn.ServerVersion}), ignoring test suite."); + throw; } else { - throw; + Assert.Ignore($"PostGIS isn't installed, skipping tests"); } } } @@ -425,7 +423,7 @@ internal static void AssertLoggingStateContains( (LogLevel Level, EventId Id, string Message, object? State, Exception? 
Exception) log, string key, T value) - => Assert.That(log.State, Contains.Item(new KeyValuePair(key, value))); + => Assert.That(log.State as IEnumerable>, Contains.Item(new KeyValuePair(key, value))); internal static void AssertLoggingStateDoesNotContain( (LogLevel Level, EventId Id, string Message, object? State, Exception? Exception) log, @@ -518,13 +516,9 @@ public static void WaitUntilCommandIsInProgress(this NpgsqlCommand command) /// test reproduces the issue) /// [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] -public class IssueLink : Attribute +public class IssueLink(string linkAddress) : Attribute { - public string LinkAddress { get; private set; } - public IssueLink(string linkAddress) - { - LinkAddress = linkAddress; - } + public string LinkAddress { get; private set; } = linkAddress; } public enum PrepareOrNot diff --git a/test/Npgsql.Tests/TracingTests.cs b/test/Npgsql.Tests/TracingTests.cs new file mode 100644 index 0000000000..0d51e6a6f6 --- /dev/null +++ b/test/Npgsql.Tests/TracingTests.cs @@ -0,0 +1,855 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading.Tasks; +using NpgsqlTypes; +using NUnit.Framework; +using static Npgsql.Tests.TestUtil; + +namespace Npgsql.Tests; + +[NonParallelizable] +[TestFixture(true)] +[TestFixture(false)] +public class TracingTests(bool async) : TestBase +{ + #region Physical open + + [Test] + public async Task PhysicalOpen() + { + using var activityListener = StartListener(out var activities); + await using var dataSource = CreateDataSource(); + await using var connection = async + ? 
await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + + Assert.That(activities, Has.Count.EqualTo(1)); + + var activity = activities[0]; + Assert.That(activity.DisplayName, Is.EqualTo("CONNECT " + connection.Settings.Database)); + Assert.That(activity.OperationName, Is.EqualTo("CONNECT " + connection.Settings.Database)); + Assert.That(activity.Status, Is.EqualTo(ActivityStatusCode.Unset)); + + Assert.That(activity.Events.Count(), Is.EqualTo(0)); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags, Has.Count.EqualTo(connection.Settings.Port == 5432 ? 5 : 6)); + + Assert.That(tags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(tags["db.namespace"], Is.EqualTo(connection.Settings.Database)); + + Assert.That(tags, Does.Not.ContainKey("db.query.text")); + + Assert.That(tags["db.npgsql.data_source"], Is.EqualTo(connection.ConnectionString)); + Assert.That(tags["db.npgsql.connection_id"], Is.EqualTo(connection.ProcessID)); + } + + [Test] + public async Task PhysicalOpen_error() + { + using var activityListener = StartListener(out var activities); + await using var dataSource = CreateDataSource(x => x.Host = "not-existing-host"); + var exception = Assert.ThrowsAsync(async () => + { + await using var connection = async + ? 
await dataSource.OpenConnectionAsync() + : dataSource.OpenConnection(); + })!; + + var activity = GetSingleActivity(activities, "CONNECT " + dataSource.Settings.Database, "CONNECT " + dataSource.Settings.Database, ActivityStatusCode.Error, exception.Message); + + Assert.That(activity.Events.Count(), Is.EqualTo(1)); + var exceptionEvent = activity.Events.First(); + Assert.That(exceptionEvent.Name, Is.EqualTo("exception")); + + var exceptionTags = exceptionEvent.Tags.ToDictionary(t => t.Key, t => t.Value); + Assert.That(exceptionTags, Has.Count.EqualTo(3)); + + Assert.That(exceptionTags["exception.type"], Is.EqualTo(exception.GetType().FullName)); + Assert.That(exceptionTags["exception.message"], Does.Contain(exception.Message)); + Assert.That(exceptionTags["exception.stacktrace"], Does.Contain(exception.Message)); + + var activityTags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(activityTags, Has.Count.EqualTo(3)); + + Assert.That(activityTags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(activityTags["db.npgsql.data_source"], Is.EqualTo(dataSource.ConnectionString)); + + Assert.That(activityTags["error.type"], Is.EqualTo("System.Net.Sockets.SocketException")); + } + + [Test] + public async Task PhysicalOpen_disable() + { + using var activityListener = StartListener(out var activities); + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConfigureTracing(options => options.EnablePhysicalOpenTracing(enable: false)); + await using var dataSource = dataSourceBuilder.Build(); + + await using var connection = async ? 
await dataSource.OpenConnectionAsync() : dataSource.OpenConnection(); + + Assert.That(activities, Is.Empty); + } + + #endregion Physical open + + #region Command execution + + [Test] + public async Task CommandExecute([Values] bool batch) + { + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.Name = "TestTracingDataSource"; + dataSourceBuilder.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false)); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + await ExecuteScalar(connection, async, batch, "SELECT 42"); + + var activity = GetSingleActivity(activities, "postgresql", "postgresql"); + + Assert.That(activity.Events.Count(), Is.EqualTo(1)); + var firstResponseEvent = activity.Events.First(); + Assert.That(firstResponseEvent.Name, Is.EqualTo("received-first-response")); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags, Has.Count.EqualTo(connection.Settings.Port == 5432 ? 
6 : 7)); + + Assert.That(tags["db.query.text"], Is.EqualTo("SELECT 42")); + Assert.That(tags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(tags["db.namespace"], Is.EqualTo(connection.Settings.Database)); + + Assert.That(tags["db.npgsql.data_source"], Is.EqualTo("TestTracingDataSource")); + Assert.That(tags["db.npgsql.connection_id"], Is.EqualTo(connection.ProcessID)); + } + + [Test] + public async Task CommandExecute_error([Values] bool batch) + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + Assert.ThrowsAsync(async () => await ExecuteScalar(connection, async, batch, "SELECT * FROM non_existing_table")); + + var activity = GetSingleActivity(activities, "postgresql", "postgresql", ActivityStatusCode.Error, PostgresErrorCodes.UndefinedTable); + + Assert.That(activity.Events.Count(), Is.EqualTo(1)); + var exceptionEvent = activity.Events.First(); + Assert.That(exceptionEvent.Name, Is.EqualTo("exception")); + + var exceptionTags = exceptionEvent.Tags.ToDictionary(t => t.Key, t => t.Value); + Assert.That(exceptionTags, Has.Count.EqualTo(3)); + + Assert.That(exceptionTags["exception.type"], Is.EqualTo("Npgsql.PostgresException")); + Assert.That(exceptionTags["exception.message"], Does.Contain("relation \"non_existing_table\" does not exist")); + Assert.That(exceptionTags["exception.stacktrace"], Does.Contain("relation \"non_existing_table\" does not exist")); + + var activityTags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(activityTags, Has.Count.EqualTo(connection.Settings.Port == 5432 ? 
8 : 9)); + + Assert.That(activityTags["db.query.text"], Is.EqualTo("SELECT * FROM non_existing_table")); + Assert.That(activityTags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(activityTags["db.namespace"], Is.EqualTo(connection.Settings.Database)); + + Assert.That(activityTags["db.response.status_code"], Is.EqualTo(PostgresErrorCodes.UndefinedTable)); + Assert.That(activityTags["error.type"], Is.EqualTo(PostgresErrorCodes.UndefinedTable)); + + Assert.That(activityTags["db.npgsql.data_source"], Is.EqualTo(connection.ConnectionString)); + Assert.That(activityTags["db.npgsql.connection_id"], Is.EqualTo(connection.ProcessID)); + } + + [Test] + public async Task CommandExecute_explicit_prepare([Values] bool batch) + { + await using var dataSource = CreateDataSource(o => o.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + await ExecuteScalar(connection, async, batch, "SELECT 42", prepare: false); + var activity = GetSingleActivity(activities, "postgresql", "postgresql"); + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags, Does.Not.ContainKey("db.npgsql.prepared")); + + activities.Clear(); + await ExecuteScalar(connection, async, batch, "SELECT 42", prepare: true); + activity = GetSingleActivity(activities, "postgresql", "postgresql"); + tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags["db.npgsql.prepared"], Is.True); + } + + [Test] + public async Task CommandExecute_auto_prepare([Values] bool batch) + { + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.MaxPoolSize = 1; + dataSourceBuilder.ConnectionStringBuilder.MaxAutoPrepare = 10; + dataSourceBuilder.ConnectionStringBuilder.AutoPrepareMinUsages = 2; + dataSourceBuilder.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false)); + 
await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + await ExecuteScalar(connection, async, batch, "SELECT 42"); + var activity = GetSingleActivity(activities, "postgresql", "postgresql"); + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags, Does.Not.ContainKey("db.npgsql.prepared")); + + activities.Clear(); + await ExecuteScalar(connection, async, batch, "SELECT 42"); + activity = GetSingleActivity(activities, "postgresql", "postgresql"); + tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags["db.npgsql.prepared"], Is.True); + } + + [Test] + public async Task CommandExecute_ConfigureTracing([Values] bool batch) + { + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConfigureTracing(options => + { + options + .EnablePhysicalOpenTracing(false) + .EnableFirstResponseEvent(enable: false) + .ConfigureCommandFilter(cmd => cmd.CommandText.Contains('2')) + .ConfigureBatchFilter(batch => batch.BatchCommands[0].CommandText.Contains('2')) + .ConfigureCommandSpanNameProvider(_ => "unknown_query") + .ConfigureBatchSpanNameProvider(_ => "unknown_query") + .ConfigureCommandEnrichmentCallback((activity, _) => activity.AddTag("custom_tag", "custom_value")) + .ConfigureBatchEnrichmentCallback((activity, _) => activity.AddTag("custom_tag", "custom_value")); + }); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + await ExecuteScalar(connection, async, batch, "SELECT 1"); + + Assert.That(activities, Is.Empty); + + await ExecuteScalar(connection, async, batch, "SELECT 2"); + + var activity = GetSingleActivity(activities, "unknown_query", "unknown_query"); + + Assert.That(activity.Events.Count(), 
Is.EqualTo(0)); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags["custom_tag"], Is.EqualTo("custom_value")); + } + + #endregion Command execution + + #region Binary import + + [Test] + public async Task BinaryImport() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(connection, "field_text TEXT, field_int2 SMALLINT"); + + using var activityListener = StartListener(out var activities); + + var copyFromCommand = $"COPY {table} (field_text, field_int2) FROM STDIN BINARY"; + + if (async) + { + await using var writer = await connection.BeginBinaryImportAsync(copyFromCommand); + + await writer.StartRowAsync(); + await writer.WriteAsync("Hello"); + await writer.WriteAsync((short)8, NpgsqlDbType.Smallint); + + await writer.CompleteAsync(); + } + else + { + using var writer = connection.BeginBinaryImport(copyFromCommand); + + writer.StartRow(); + writer.Write("Hello"); + writer.Write((short)8, NpgsqlDbType.Smallint); + + writer.Complete(); + } + + var activity = GetSingleActivity(activities, "COPY FROM"); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags, Has.Count.EqualTo(connection.Settings.Port == 5432 ? 
8 : 9)); + + Assert.That(tags["db.query.text"], Is.EqualTo(copyFromCommand)); + Assert.That(tags["db.operation.name"], Is.EqualTo("COPY FROM")); + Assert.That(tags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(tags["db.namespace"], Is.EqualTo(connection.Settings.Database)); + + Assert.That(tags["db.npgsql.data_source"], Is.EqualTo(connection.ConnectionString)); + Assert.That(tags["db.npgsql.rows"], Is.EqualTo(1)); + + Assert.That(tags["db.npgsql.connection_id"], Is.EqualTo(connection.ProcessID)); + } + + [Test] + public async Task BinaryImport_cancel() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(connection, "field_text TEXT, field_int2 SMALLINT"); + + using var activityListener = StartListener(out var activities); + + var copyFromCommand = $"COPY {table} (field_text, field_int2) FROM STDIN BINARY"; + + if (async) + { + await using var writer = await connection.BeginBinaryImportAsync(copyFromCommand); + await writer.StartRowAsync(); + await writer.WriteAsync("Hello"); + await writer.WriteAsync((short)8, NpgsqlDbType.Smallint); + // No Complete() call - disposing cancels + } + else + { + using var writer = connection.BeginBinaryImport(copyFromCommand); + writer.StartRow(); + writer.Write("Hello"); + writer.Write((short)8, NpgsqlDbType.Smallint); + // No Complete() call - disposing cancels + } + + _ = GetSingleActivity(activities, "COPY FROM"); + } + + [Test] + public async Task BinaryImport_error() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + var copyFromCommand = $"COPY non_existing_table (field_text, field_int2) FROM STDIN BINARY"; + + var ex = 
Assert.ThrowsAsync(async () => + { + await using var writer = async + ? await connection.BeginBinaryImportAsync(copyFromCommand) + : connection.BeginBinaryImport(copyFromCommand); + }); + + var activity = GetSingleActivity(activities, "COPY FROM", "COPY FROM", ActivityStatusCode.Error, PostgresErrorCodes.UndefinedTable); + + Assert.That(activity.Events.Count(), Is.EqualTo(1)); + var exceptionEvent = activity.Events.First(); + Assert.That(exceptionEvent.Name, Is.EqualTo("exception")); + + var exceptionTags = exceptionEvent.Tags.ToDictionary(t => t.Key, t => t.Value); + Assert.That(exceptionTags, Has.Count.EqualTo(3)); + + Assert.That(exceptionTags["exception.type"], Is.EqualTo("Npgsql.PostgresException")); + Assert.That(exceptionTags["exception.message"], Does.Contain("relation \"non_existing_table\" does not exist")); + Assert.That(exceptionTags["exception.stacktrace"], Does.Contain("relation \"non_existing_table\" does not exist")); + } + + #endregion Binary import + + #region Binary export + + [Test] + public async Task BinaryExport() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(connection, "field_text TEXT, field_int2 SMALLINT"); + await connection.ExecuteNonQueryAsync($"INSERT INTO {table} (field_text, field_int2) VALUES ('Hello', 8)"); + + using var activityListener = StartListener(out var activities); + + var copyToCommand = $"COPY {table} (field_text, field_int2) TO STDOUT BINARY"; + + if (async) + { + await using var reader = await connection.BeginBinaryExportAsync(copyToCommand); + while (await reader.StartRowAsync() != -1) + { + _ = await reader.ReadAsync(); + _ = await reader.ReadAsync(); + } + } + else + { + using var reader = connection.BeginBinaryExport(copyToCommand); + while (reader.StartRow() != -1) + { + _ = reader.Read(); + _ = reader.Read(); + } + } + + var 
activity = GetSingleActivity(activities, "COPY TO"); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags, Has.Count.EqualTo(connection.Settings.Port == 5432 ? 8 : 9)); + + Assert.That(tags["db.query.text"], Is.EqualTo(copyToCommand)); + Assert.That(tags["db.operation.name"], Is.EqualTo("COPY TO")); + Assert.That(tags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(tags["db.namespace"], Is.EqualTo(connection.Settings.Database)); + + Assert.That(tags["db.npgsql.data_source"], Is.EqualTo(connection.ConnectionString)); + Assert.That(tags["db.npgsql.rows"], Is.EqualTo(1)); + + Assert.That(tags["db.npgsql.connection_id"], Is.EqualTo(connection.ProcessID)); + } + + [Test] + public async Task BinaryExport_cancel() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var conn = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + // This must be large enough to cause Postgres to queue up CopyData messages. 
+ const string copyToCommand = "COPY (select md5(random()::text) as id from generate_series(1, 100000)) TO STDOUT BINARY"; + + if (async) + { + await using var exporter = await conn.BeginBinaryExportAsync(copyToCommand); + await exporter.StartRowAsync(); + await exporter.ReadAsync(); + await exporter.CancelAsync(); + } + else + { + using var exporter = await conn.BeginBinaryExportAsync(copyToCommand); + exporter.StartRow(); + exporter.Read(); + exporter.Cancel(); + } + + _ = GetSingleActivity(activities, "COPY TO"); + } + + [Test] + public async Task BinaryExport_error() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + var copyToCommand = $"COPY non_existing_table (field_text, field_int2) TO STDOUT BINARY"; + var ex = Assert.ThrowsAsync(async () => + { + await using var reader = async + ? 
await connection.BeginBinaryExportAsync(copyToCommand) + : connection.BeginBinaryExport(copyToCommand); + }); + + var activity = GetSingleActivity(activities, "COPY TO", "COPY TO", ActivityStatusCode.Error, PostgresErrorCodes.UndefinedTable); + + Assert.That(activity.Events.Count(), Is.EqualTo(1)); + var exceptionEvent = activity.Events.First(); + Assert.That(exceptionEvent.Name, Is.EqualTo("exception")); + + var exceptionTags = exceptionEvent.Tags.ToDictionary(t => t.Key, t => t.Value); + Assert.That(exceptionTags, Has.Count.EqualTo(3)); + + Assert.That(exceptionTags["exception.type"], Is.EqualTo("Npgsql.PostgresException")); + Assert.That(exceptionTags["exception.message"], Does.Contain("relation \"non_existing_table\" does not exist")); + Assert.That(exceptionTags["exception.stacktrace"], Does.Contain("relation \"non_existing_table\" does not exist")); + } + + #endregion Binary export + + #region Raw binary + + [Test] + public async Task RawBinaryExport() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(connection, "field_text TEXT, field_int2 SMALLINT"); + await connection.ExecuteNonQueryAsync($"INSERT INTO {table} (field_text, field_int2) VALUES ('Hello', 8)"); + + using var activityListener = StartListener(out var activities); + + // Raw binary export + var copyToCommand = $"COPY {table} (field_text, field_int2) TO STDIN BINARY"; + var buffer = new byte[1024]; + if (async) + { + await using var stream = await connection.BeginRawBinaryCopyAsync(copyToCommand); + while (await stream.ReadAsync(buffer, 0, buffer.Length) > 0) { } + } + else + { + using var stream = connection.BeginRawBinaryCopy(copyToCommand); + while (stream.Read(buffer, 0, buffer.Length) > 0) { } + } + + var activity = GetSingleActivity(activities, "COPY"); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t 
=> t.Value); + + Assert.That(tags, Has.Count.EqualTo(connection.Settings.Port == 5432 ? 7 : 8)); + + Assert.That(tags["db.query.text"], Is.EqualTo(copyToCommand)); + Assert.That(tags["db.operation.name"], Is.EqualTo("COPY TO")); + Assert.That(tags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(tags["db.namespace"], Is.EqualTo(connection.Settings.Database)); + + Assert.That(tags["db.npgsql.data_source"], Is.EqualTo(connection.ConnectionString)); + + Assert.That(tags["db.npgsql.connection_id"], Is.EqualTo(connection.ProcessID)); + + Assert.That(tags, Does.Not.ContainKey("db.npgsql.rows")); + } + + [Test] + public async Task RawBinaryExport_cancel() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(connection, "field_text TEXT, field_int2 SMALLINT"); + await connection.ExecuteNonQueryAsync($"INSERT INTO {table} (field_text, field_int2) VALUES ('Hello', 8)"); + + using var activityListener = StartListener(out var activities); + + var copyToCommand = $"COPY {table} (field_text, field_int2) TO STDIN BINARY"; + var buffer = new byte[1024]; + if (async) + { + await using var stream = await connection.BeginRawBinaryCopyAsync(copyToCommand); + var _ = await stream.ReadAsync(buffer, 0, buffer.Length); + await stream.CancelAsync(); + } + else + { + using var stream = connection.BeginRawBinaryCopy(copyToCommand); + var _ = stream.Read(buffer, 0, buffer.Length); + stream.Cancel(); + } + + _ = GetSingleActivity(activities, "COPY"); + } + + [Test] + public async Task RawBinaryImport_cancel() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(connection, "field_text TEXT, field_int2 SMALLINT"); + + using var 
activityListener = StartListener(out var activities); + + var copyToCommand = $"COPY {table} (field_text, field_int2) FROM STDIN BINARY"; + byte[] garbage = [1, 2, 3, 4]; + if (async) + { + await using var stream = await connection.BeginRawBinaryCopyAsync(copyToCommand); + await stream.WriteAsync(garbage); + await stream.FlushAsync(); + await stream.CancelAsync(); + } + else + { + using var stream = connection.BeginRawBinaryCopy(copyToCommand); + stream.Write(garbage); + stream.Flush(); + stream.Cancel(); + } + + _ = GetSingleActivity(activities, "COPY"); + } + + [Test] + public async Task RawBinaryImport_error() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + var copyFromCommand = $"COPY non_existing_table (field_text, field_int2) FROM STDIN BINARY"; + var ex = Assert.ThrowsAsync(async () => + { + await using var stream = async + ? 
await connection.BeginRawBinaryCopyAsync(copyFromCommand) + : connection.BeginRawBinaryCopy(copyFromCommand); + }); + + var activity = GetSingleActivity(activities, "COPY", "COPY", ActivityStatusCode.Error, PostgresErrorCodes.UndefinedTable); + + Assert.That(activity.Events.Count(), Is.EqualTo(1)); + var exceptionEvent = activity.Events.First(); + Assert.That(exceptionEvent.Name, Is.EqualTo("exception")); + + var exceptionTags = exceptionEvent.Tags.ToDictionary(t => t.Key, t => t.Value); + Assert.That(exceptionTags, Has.Count.EqualTo(3)); + + Assert.That(exceptionTags["exception.type"], Is.EqualTo("Npgsql.PostgresException")); + Assert.That(exceptionTags["exception.message"], Does.Contain("relation \"non_existing_table\" does not exist")); + Assert.That(exceptionTags["exception.stacktrace"], Does.Contain("relation \"non_existing_table\" does not exist")); + } + + #endregion Raw binary + + #region Text COPY + + [Test] + public async Task TextImport() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(connection, "field_text TEXT, field_int2 SMALLINT"); + + using var activityListener = StartListener(out var activities); + + var copyFromCommand = $"COPY {table} (field_text, field_int2) FROM STDIN"; + + if (async) + { + await using var writer = await connection.BeginTextImportAsync(copyFromCommand); + await writer.WriteAsync("Hello\t8\n"); + } + else + { + using var writer = connection.BeginTextImport(copyFromCommand); + writer.Write("Hello\t8\n"); + } + + var activity = GetSingleActivity(activities, "COPY FROM"); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + + Assert.That(tags, Has.Count.EqualTo(connection.Settings.Port == 5432 ? 
7 : 8)); + + Assert.That(tags["db.query.text"], Is.EqualTo(copyFromCommand)); + Assert.That(tags["db.operation.name"], Is.EqualTo("COPY FROM")); + Assert.That(tags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(tags["db.namespace"], Is.EqualTo(connection.Settings.Database)); + + Assert.That(tags["db.npgsql.data_source"], Is.EqualTo(connection.ConnectionString)); + + Assert.That(tags["db.npgsql.connection_id"], Is.EqualTo(connection.ProcessID)); + + Assert.That(tags, Does.Not.ContainKey("db.npgsql.rows")); + } + + [Test] + public async Task TextExport() + { + await using var dataSource = CreateDataSource(ds => ds.ConfigureTracing(o => o.EnablePhysicalOpenTracing(false))); + await using var connection = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(connection, "field_text TEXT, field_int2 SMALLINT"); + + var insertCmd = $"INSERT INTO {table} (field_text, field_int2) VALUES ('Hello', 8)"; + await connection.ExecuteNonQueryAsync(insertCmd); + + using var activityListener = StartListener(out var activities); + + var copyFromCommand = $"COPY {table} (field_text, field_int2) TO STDIN"; + + var chars = new char[30]; + if (async) + { + await using var reader = await connection.BeginTextExportAsync(copyFromCommand); + _ = await reader.ReadAsync(chars); + } + else + { + using var reader = connection.BeginTextExport(copyFromCommand); + _ = reader.Read(chars); + } + + var activity = GetSingleActivity(activities, "COPY TO"); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + + Assert.That(tags, Has.Count.EqualTo(connection.Settings.Port == 5432 ? 
7 : 8)); + + Assert.That(tags["db.query.text"], Is.EqualTo(copyFromCommand)); + Assert.That(tags["db.operation.name"], Is.EqualTo("COPY TO")); + Assert.That(tags["db.system.name"], Is.EqualTo("postgresql")); + Assert.That(tags["db.namespace"], Is.EqualTo(connection.Settings.Database)); + + Assert.That(tags["db.npgsql.data_source"], Is.EqualTo(connection.ConnectionString)); + + Assert.That(tags["db.npgsql.connection_id"], Is.EqualTo(connection.ProcessID)); + + Assert.That(tags, Does.Not.ContainKey("db.npgsql.rows")); + } + + // Text COPY is implemented over NpgsqlRawCopyStream internally, without any additional tracing-related logic. + // So we do only basic direct coverage and depend on the general raw tests for the rest. + + #endregion Text COPY + + // All ConfigureTracing() aspects of COPY are implemented in a single code path for all COPY paths, so we test just one. + + [Test] + public async Task Copy_ConfigureTracing() + { + await using var dataSource = CreateDataSource(builder => builder.ConfigureTracing(options => + options + .EnablePhysicalOpenTracing(false) + .ConfigureCopyOperationFilter(command => command.Contains("filter_in")) + .ConfigureCopyOperationSpanNameProvider(_ => "custom_binary_import") + .ConfigureCopyOperationEnrichmentCallback((activity, _) => activity.AddTag("custom_tag", "custom_value")))); + + await using var conn = await dataSource.OpenConnectionAsync(); + + var table = await CreateTempTable(conn, "field_text TEXT, field_int_filter_in SMALLINT"); + var copyCommand = $"COPY {table} (field_text, field_int_filter_in) FROM STDIN BINARY"; + + var filteredOutTable = await CreateTempTable(conn, "field_text TEXT, field_int_filter_out SMALLINT"); + var filteredOutCopyCommand = $"COPY {filteredOutTable} (field_text, field_int_filter_out) FROM STDIN BINARY"; + + using var activityListener = StartListener(out var activities); + + + if (async) + { + await using (var writer = await conn.BeginBinaryImportAsync(copyCommand)) + { + await 
writer.CompleteAsync(); + } + + await using (var writer = await conn.BeginBinaryImportAsync(filteredOutCopyCommand)) + { + await writer.CompleteAsync(); + } + } + else + { + using (var writer = conn.BeginBinaryImport(copyCommand)) + { + writer.Complete(); + } + + using (var writer = conn.BeginBinaryImport(filteredOutCopyCommand)) + { + writer.Complete(); + } + } + + // There should be just one activity since one of the two COPY commands is filtered out + var activity = GetSingleActivity(activities, "custom_binary_import", "custom_binary_import"); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + Assert.That(tags["custom_tag"], Is.EqualTo("custom_value")); + } + + [Test] + public async Task Password_does_not_leak_via_datasource_name([Values] bool persistSecurityInfo) + { + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.ConnectionStringBuilder.PersistSecurityInfo = persistSecurityInfo; + // Do not set the data source name - this makes it default to the connection string, but without + // the password (even when Persist Security Info is true) + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + using var activityListener = StartListener(out var activities); + + await ExecuteScalar(connection, async, isBatch: false, query: "SELECT 42"); + + var activity = GetSingleActivity(activities, "postgresql", "postgresql"); + + var tags = activity.TagObjects.ToDictionary(t => t.Key, t => t.Value); + var connectionString = new NpgsqlConnectionStringBuilder((string)tags["db.npgsql.data_source"]!); + Assert.That(connectionString.Password, Is.Null); + } + + static ActivityListener StartListener(out List activities) + { + var a = new List(); + + var activityListener = new ActivityListener + { + ShouldListenTo = source => source.Name == "Npgsql", + Sample = (ref _) => ActivitySamplingResult.AllDataAndRecorded, + ActivityStopped = activity => a.Add(activity) 
+ }; + ActivitySource.AddActivityListener(activityListener); + + activities = a; + return activityListener; + } + + static Activity GetSingleActivity( + List activities, + string? expectedDisplayName, + string? expectedOperationName = null, + ActivityStatusCode? expectedStatusCode = null, + string? expectedStatusDescription = null) + { + Assert.That(activities, Has.Count.EqualTo(1)); + var activity = activities[0]; + Assert.That(activity.DisplayName, Is.EqualTo(expectedDisplayName)); + Assert.That(activity.OperationName, Is.EqualTo(expectedOperationName ?? expectedDisplayName)); + Assert.That(activity.Status, Is.EqualTo(expectedStatusCode ?? ActivityStatusCode.Unset)); + Assert.That(activity.StatusDescription, Is.EqualTo(expectedStatusDescription)); + + return activity; + } + + static async Task ExecuteScalar(NpgsqlConnection connection, bool async, bool isBatch, string query, bool prepare = false) + { + if (isBatch) + { + await using var batch = connection.CreateBatch(); + var batchCommand = batch.CreateBatchCommand(); + batchCommand.CommandText = query; + batch.BatchCommands.Add(batchCommand); + + if (prepare) + { + if (async) + await batch.PrepareAsync(); + else + batch.Prepare(); + } + + if (async) + return await batch.ExecuteScalarAsync(); + else + return batch.ExecuteScalar(); + } + else + { + await using var command = connection.CreateCommand(); + command.CommandText = query; + + if (prepare) + { + if (async) + await command.PrepareAsync(); + else + command.Prepare(); + } + + if (async) + return await command.ExecuteScalarAsync(); + else + return command.ExecuteScalar(); + } + } +} diff --git a/test/Npgsql.Tests/TransactionTests.cs b/test/Npgsql.Tests/TransactionTests.cs index e0e61f95b4..2832ed7fa1 100644 --- a/test/Npgsql.Tests/TransactionTests.cs +++ b/test/Npgsql.Tests/TransactionTests.cs @@ -12,14 +12,11 @@ namespace Npgsql.Tests; -public class TransactionTests : MultiplexingTestBase +public class TransactionTests : TestBase { - [Test, 
Description("Basic insert within a commited transaction")] + [Test, Description("Basic insert within a committed transaction")] public async Task Commit([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "name TEXT"); @@ -37,18 +34,12 @@ public async Task Commit([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); } - // With multiplexing we can't assume that disposed NpgsqlTransaction will throw ObjectDisposedException - // Because disposed NpgsqlTransaction might be reused by another thread - if (!IsMultiplexing) - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); + Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); } - [Test, Description("Basic insert within a commited transaction")] + [Test, Description("Basic insert within a committed transaction")] public async Task CommitAsync([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "name TEXT"); @@ -66,18 +57,12 @@ public async Task CommitAsync([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Pre Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(1)); } - // With multiplexing we can't assume that disposed NpgsqlTransaction will throw ObjectDisposedException - // Because disposed NpgsqlTransaction might be reused by another thread - if (!IsMultiplexing) - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); + Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); } [Test, Description("Basic insert within a rolled back transaction")] public async Task 
Rollback([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "name TEXT"); @@ -95,18 +80,12 @@ public async Task Rollback([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepar Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); } - // With multiplexing we can't assume that disposed NpgsqlTransaction will throw ObjectDisposedException - // Because disposed NpgsqlTransaction might be reused by another thread - if (!IsMultiplexing) - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); + Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); } [Test, Description("Basic insert within a rolled back transaction")] public async Task RollbackAsync([Values(PrepareOrNot.NotPrepared, PrepareOrNot.Prepared)] PrepareOrNot prepare) { - if (prepare == PrepareOrNot.Prepared && IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "name TEXT"); @@ -124,10 +103,7 @@ public async Task RollbackAsync([Values(PrepareOrNot.NotPrepared, PrepareOrNot.P Assert.That(await conn.ExecuteScalarAsync($"SELECT COUNT(*) FROM {table}"), Is.EqualTo(0)); } - // With multiplexing we can't assume that disposed NpgsqlTransaction will throw ObjectDisposedException - // Because disposed NpgsqlTransaction might be reused by another thread - if (!IsMultiplexing) - Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); + Assert.That(() => tx.Connection, Throws.Exception.TypeOf()); } [Test, Description("Dispose a transaction in progress, should roll back")] @@ -249,21 +225,12 @@ public async Task Default_IsolationLevel() tx.Rollback(); } - [Test, Description("Makes sure that transactions started in SQL work, except in multiplexing")] + [Test, Description("Makes sure that transactions started in SQL 
work")] public async Task Via_sql() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: not implemented"); - await using var conn = await OpenConnectionAsync(); var table = await CreateTempTable(conn, "name TEXT"); - if (IsMultiplexing) - { - Assert.That(async () => await conn.ExecuteNonQueryAsync("BEGIN"), Throws.Exception.TypeOf()); - return; - } - await conn.ExecuteNonQueryAsync("BEGIN"); await conn.ExecuteNonQueryAsync($"INSERT INTO {table} (name) VALUES ('X')"); await conn.ExecuteNonQueryAsync("ROLLBACK"); @@ -356,9 +323,6 @@ public async Task Failed_transaction_on_close_with_custom_timeout() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/555")] public async Task Transaction_on_recycled_connection() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing: fails"); - // Use application name to make sure we have our very own private connection pool await using var conn = new NpgsqlConnection(ConnectionString + $";Application Name={GetUniqueIdentifier(nameof(Transaction_on_recycled_connection))}"); conn.Open(); @@ -507,12 +471,10 @@ public async Task IsCompleted_rollback_failed() public async Task Transaction_not_supported() { // TODO: rewrite to DataSource - if (IsMultiplexing) - Assert.Ignore("Need to rethink/redo dummy transaction mode"); var connString = new NpgsqlConnectionStringBuilder(ConnectionString) { - ApplicationName = nameof(Transaction_not_supported) + IsMultiplexing + ApplicationName = nameof(Transaction_not_supported) }.ToString(); NpgsqlDatabaseInfo.RegisterFactory(new NoTransactionDatabaseInfoFactory()); @@ -555,45 +517,6 @@ public async Task Transaction_not_supported() } } - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/3248")] - // More at #3254 - public async Task Bug3248_Dispose_transaction_Rollback() - { - if (!IsMultiplexing) - return; - - using var conn = await OpenConnectionAsync(); - await using (var tx = await conn.BeginTransactionAsync()) - { - Assert.That(conn.Connector, Is.Not.Null); - Assert.That(async () 
=> await conn.ExecuteScalarAsync("SELECT * FROM \"unknown_table\"", tx: tx), - Throws.Exception.TypeOf()); - Assert.That(conn.Connector, Is.Not.Null); - } - - Assert.That(conn.Connector, Is.Null); - } - - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/3248")] - // More at #3254 - public async Task Bug3248_Dispose_connection_Rollback() - { - if (!IsMultiplexing) - return; - - var conn = await OpenConnectionAsync(); - var tx = conn.BeginTransaction(); - Assert.That(conn.Connector, Is.Not.Null); - Assert.That(async () => await conn.ExecuteScalarAsync("SELECT * FROM \"unknown_table\"", tx: tx), - Throws.Exception.TypeOf()); - Assert.That(conn.Connector, Is.Not.Null); - - await conn.DisposeAsync(); - Assert.That(conn.Connector, Is.Null); - } - [Test] [IssueLink("https://github.com/npgsql/npgsql/issues/3306")] [TestCase(true)] @@ -696,9 +619,6 @@ public async Task Unbound_transaction_reuse() [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3686")] public async Task Bug3686() { - if (IsMultiplexing) - return; - await using var dataSource = CreateDataSource(csb => csb.Pooling = false); await using var conn = await dataSource.OpenConnectionAsync(); await using var tx = await conn.BeginTransactionAsync(); @@ -746,6 +666,4 @@ public void Bug184_Rollback_fails_on_aborted_transaction() t.Rollback(); } } - - public TransactionTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/TypeMapperTests.cs b/test/Npgsql.Tests/TypeMapperTests.cs index d0d1e36587..469a57be01 100644 --- a/test/Npgsql.Tests/TypeMapperTests.cs +++ b/test/Npgsql.Tests/TypeMapperTests.cs @@ -1,9 +1,12 @@ using Npgsql.Internal; using NUnit.Framework; using System; +using System.Data; using System.Threading.Tasks; using Npgsql.Internal.Converters; using Npgsql.Internal.Postgres; +using Npgsql.TypeMapping; +using NpgsqlTypes; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests; @@ -42,7 +45,6 @@ public async Task 
ReloadTypes_across_connections_in_data_source() } [Test] - [NonParallelizable] // Depends on citext which could be dropped concurrently public async Task String_to_citext() { await using var adminConnection = await OpenConnectionAsync(); @@ -58,8 +60,97 @@ public async Task String_to_citext() Assert.That(command.ExecuteScalar(), Is.True); } + [Test] + public async Task String_to_citext_with_db_type_string() + { + await using var adminConnection = await OpenConnectionAsync(); + await EnsureExtensionAsync(adminConnection, "citext"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + ((INpgsqlTypeMapper)dataSourceBuilder).AddDbTypeResolverFactory(new ForceStringToCitextResolverFactory()); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await using var command = new NpgsqlCommand("SELECT @p = 'hello'::citext", connection); + var parameter = new NpgsqlParameter("p", DbType.String) + { + Value = "HeLLo" + }; + command.Parameters.Add(parameter); + + Assert.That(command.ExecuteScalar(), Is.True); + Assert.That(parameter.DbType, Is.EqualTo(DbType.String)); + Assert.That(parameter.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Citext)); + Assert.That(parameter.DataTypeName, Is.EqualTo("citext")); + } + + [Test] + public async Task Guid_to_custom_type() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.AddTypeInfoResolverFactory(new GuidTextConverterFactory(type)); + ((INpgsqlTypeMapper)dataSourceBuilder).AddDbTypeResolverFactory(new GuidTextDbTypeResolverFactory(type)); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await connection.ExecuteNonQueryAsync($"CREATE TYPE {type}"); + await connection.ExecuteNonQueryAsync($""" + -- Input: cstring -> Custom type + CREATE 
FUNCTION {type}_in(cstring) + RETURNS {type} + AS 'textin' + LANGUAGE internal IMMUTABLE STRICT; + + -- Output: Custom type -> cstring + CREATE FUNCTION {type}_out({type}) + RETURNS cstring + AS 'textout' + LANGUAGE internal IMMUTABLE STRICT; + + -- 3️⃣ Create wrappers for binary I/O + CREATE FUNCTION {type}_recv(internal) + RETURNS {type} + AS 'textrecv' + LANGUAGE internal IMMUTABLE STRICT; + + CREATE FUNCTION {type}_send({type}) + RETURNS bytea + AS 'textsend' + LANGUAGE internal IMMUTABLE STRICT; + """); + + await connection.ExecuteNonQueryAsync($""" + CREATE TYPE {type} ( + internallength = variable, + input = {type}_in, + output = {type}_out, + receive = {type}_recv, + send = {type}_send, + alignment = int4 + ); + CREATE CAST ({type} AS text) WITH INOUT AS IMPLICIT; + """); + await connection.ReloadTypesAsync(); + + var guid = Guid.NewGuid(); + await using var command = new NpgsqlCommand($"SELECT @p::text = '{guid}'", connection); + var parameter = new NpgsqlParameter("p", DbType.Guid) + { + Value = guid + }; + command.Parameters.Add(parameter); + + Assert.That(command.ExecuteScalar(), Is.True); + Assert.That(parameter.DbType, Is.EqualTo(DbType.Guid)); + Assert.That(parameter.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Unknown)); + Assert.That(parameter.DataTypeName, Is.EqualTo(type)); + } + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4582")] - [NonParallelizable] // Drops extension + [NonParallelizable] // Drops global citext extension. public async Task Type_in_non_default_schema() { await using var conn = await OpenConnectionAsync(); @@ -109,6 +200,81 @@ sealed class Resolver : IPgTypeInfoResolver } + class ForceStringToCitextResolverFactory : DbTypeResolverFactory + { + public override IDbTypeResolver CreateDbTypeResolver(NpgsqlDatabaseInfo databaseInfo) => new DbTypeResolver(); + + sealed class DbTypeResolver : IDbTypeResolver + { + public string? GetDataTypeName(DbType dbType, Type? 
type) + { + if (dbType == DbType.String) + return "citext"; + + return null; + } + + public DbType? GetDbType(DataTypeName dataTypeName) + { + if (dataTypeName.UnqualifiedName == "citext") + return DbType.String; + + return null; + } + } + } + + class GuidTextConverterFactory(string typeName) : PgTypeInfoResolverFactory + { + public override IPgTypeInfoResolver? CreateArrayResolver() => null; + public override IPgTypeInfoResolver CreateResolver() => new GuidTextTypeInfoResolver(typeName); + + sealed class GuidTextTypeInfoResolver(string typeName) : IPgTypeInfoResolver + { + public PgTypeInfo? GetTypeInfo(Type? type, DataTypeName? dataTypeName, PgSerializerOptions options) + { + if (type == typeof(Guid) || dataTypeName?.UnqualifiedName == typeName) + if (options.DatabaseInfo.TryGetPostgresTypeByName(typeName, out var pgType)) + return new(options, new GuidTextConverter(options.TextEncoding), options.ToCanonicalTypeId(pgType)); + + return null; + } + } + + sealed class GuidTextConverter(System.Text.Encoding encoding) : StringBasedTextConverter(encoding) + { + public override bool CanConvert(DataFormat format, out BufferRequirements bufferRequirements) + { + bufferRequirements = BufferRequirements.None; + return format is DataFormat.Text; + } + protected override Guid ConvertFrom(string value) => Guid.Parse(value); + protected override ReadOnlyMemory ConvertTo(Guid value) => value.ToString().AsMemory(); + } + } + + class GuidTextDbTypeResolverFactory(string typeName) : DbTypeResolverFactory + { + public override IDbTypeResolver CreateDbTypeResolver(NpgsqlDatabaseInfo databaseInfo) => new DbTypeResolver(typeName); + + sealed class DbTypeResolver(string typeName) : IDbTypeResolver + { + public string? GetDataTypeName(DbType dbType, Type? type) + { + if (dbType == DbType.Guid) + return typeName; + return null; + } + + public DbType? 
GetDbType(DataTypeName dataTypeName) + { + if (dataTypeName == typeName) + return DbType.Guid; + return null; + } + } + } + enum Mood { Sad, Ok, Happy } #endregion Support diff --git a/test/Npgsql.Tests/Types/ArrayTests.cs b/test/Npgsql.Tests/Types/ArrayTests.cs index a567e4891e..cf6f27e038 100644 --- a/test/Npgsql.Tests/Types/ArrayTests.cs +++ b/test/Npgsql.Tests/Types/ArrayTests.cs @@ -3,7 +3,7 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Data; -using System.Linq; +using System.Diagnostics; using System.Text; using System.Threading.Tasks; using Npgsql.Internal.Converters; @@ -21,23 +21,23 @@ namespace Npgsql.Tests.Types; /// /// https://www.postgresql.org/docs/current/static/arrays.html /// -public class ArrayTests : MultiplexingTestBase +public class ArrayTests : TestBase { static readonly TestCaseData[] ArrayTestCases = - { - new TestCaseData(new[] { 1, 2, 3 }, "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array) + [ + new TestCaseData(new[] { 1, 2, 3 }, "{1,2,3}", "integer[]") .SetName("Integer_array"), - new TestCaseData(Array.Empty(), "{}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array) + new TestCaseData(Array.Empty(), "{}", "integer[]") .SetName("Empty_array"), - new TestCaseData(new[,] { { 1, 2, 3 }, { 7, 8, 9 } }, "{{1,2,3},{7,8,9}}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array) + new TestCaseData(new[,] { { 1, 2, 3 }, { 7, 8, 9 } }, "{{1,2,3},{7,8,9}}", "integer[]") .SetName("Two_dimensional_array"), - new TestCaseData(new[] { new byte[] { 1, 2 }, new byte[] { 3, 4 } }, """{"\\x0102","\\x0304"}""", "bytea[]", NpgsqlDbType.Bytea | NpgsqlDbType.Array) + new TestCaseData(new[] { [1, 2], new byte[] { 3, 4 } }, """{"\\x0102","\\x0304"}""", "bytea[]") .SetName("Bytea_array") - }; + ]; [Test, TestCaseSource(nameof(ArrayTestCases))] - public Task Arrays(T array, string sqlLiteral, string pgTypeName, NpgsqlDbType? 
npgsqlDbType) - => AssertType(array, sqlLiteral, pgTypeName, npgsqlDbType); + public Task Arrays(T array, string sqlLiteral, string dataTypeName) + => AssertType(array, sqlLiteral, dataTypeName); [Test] public async Task NullableInts() @@ -49,7 +49,7 @@ public async Task NullableInts() var dataSourceBuilder = new NpgsqlDataSourceBuilder(connectionStringBuilder.ToString()); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, new int?[] { 1, 2, null, 3 }, "{1,2,NULL,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array); + await AssertType(dataSource, new int?[] { 1, 2, null, 3 }, "{1,2,NULL,3}", "integer[]"); } [Test, Description("Checks that PG arrays containing nulls can't be read as CLR arrays of non-nullable value types (the default).")] @@ -65,7 +65,7 @@ public async Task Throws_too_many_dimensions() cmd.Parameters.AddWithValue("p", new int[1, 1, 1, 1, 1, 1, 1, 1, 1]); // 9 dimensions Assert.That( () => cmd.ExecuteScalarAsync(), - Throws.Exception.TypeOf().With.Message.EqualTo("values (Parameter 'Postgres arrays can have at most 8 dimensions.')")); + Throws.Exception.TypeOf().With.Message.EqualTo("Postgres arrays can have at most 8 dimensions. 
(Parameter 'values')")); } [Test, Description("Checks that PG arrays containing nulls are returned as set via ValueTypeArrayMode.")] @@ -142,31 +142,132 @@ public async Task Value_type_array_nullabilities(ArrayNullabilityMode mode) Assert.That(value, Is.EqualTo(new int?[,]{{5, null},{6, 7}})); break; default: - throw new ArgumentOutOfRangeException(nameof(mode), mode, null); + throw new UnreachableException($"Unknown case {mode}"); + } + } + + [Test, Description("Checks that PG arrays containing nulls are returned as set via ValueTypeArrayMode.")] + [TestCase(ArrayNullabilityMode.Always)] + [TestCase(ArrayNullabilityMode.Never)] + [TestCase(ArrayNullabilityMode.PerInstance)] + public async Task Value_type_array_nullabilities_converter_resolver(ArrayNullabilityMode mode) + { + await using var dataSource = CreateDataSource(csb => + { + csb.ArrayNullabilityMode = mode; + csb.Timezone = "Europe/Berlin"; + }); + await using var conn = await dataSource.OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand( +""" +SELECT onedim, twodim FROM (VALUES +('{"1998-04-12 15:26:38+02"}'::timestamptz[],'{{"1998-04-12 15:26:38+02"},{"1998-04-13 15:26:38+02"}}'::timestamptz[][]), +('{"1998-04-14 15:26:38+02", NULL}'::timestamptz[],'{{"1998-04-14 15:26:38+02", NULL},{"1998-04-15 15:26:38+02", "1998-04-16 15:26:38+02"}}'::timestamptz[][])) AS x(onedim,twodim) +""", conn); + await using var reader = await cmd.ExecuteReaderAsync(); + + switch (mode) + { + case ArrayNullabilityMode.Never: + reader.Read(); + var value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime[]))); + Assert.That(value, Is.EqualTo(new []{new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc)})); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime[,]))); + Assert.That(value, Is.EqualTo(new [,] + { + { new 
DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc) }, + { new DateTime(1998, 4, 13, 13, 26, 38, DateTimeKind.Utc) } + })); + reader.Read(); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(() => reader.GetValue(0), Throws.Exception.TypeOf()); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(() => reader.GetValue(1), Throws.Exception.TypeOf()); + break; + case ArrayNullabilityMode.Always: + reader.Read(); + value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime?[]))); + Assert.That(value, Is.EqualTo(new DateTime?[]{new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc)})); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime?[,]))); + Assert.That(value, Is.EqualTo(new DateTime?[,] + { + { new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc) }, + { new DateTime(1998, 4, 13, 13, 26, 38, DateTimeKind.Utc) } + })); + reader.Read(); + value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime?[]))); + Assert.That(value, Is.EqualTo(new DateTime?[]{ new DateTime(1998, 4, 14, 13, 26, 38, DateTimeKind.Utc), null })); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime?[,]))); + Assert.That(value, Is.EqualTo(new DateTime?[,] + { + { new DateTime(1998, 4, 14, 13, 26, 38, DateTimeKind.Utc), null }, + { new DateTime(1998, 4, 15, 13, 26, 38, DateTimeKind.Utc), new DateTime(1998, 4, 16, 13, 26, 38, DateTimeKind.Utc) } + })); + break; + case ArrayNullabilityMode.PerInstance: + reader.Read(); + value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), 
Is.EqualTo(typeof(DateTime[]))); + Assert.That(value, Is.EqualTo(new []{new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc)})); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime[,]))); + Assert.That(value, Is.EqualTo(new [,] + { + { new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc) }, + { new DateTime(1998, 4, 13, 13, 26, 38, DateTimeKind.Utc) } + })); + reader.Read(); + value = reader.GetValue(0); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime?[]))); + Assert.That(value, Is.EqualTo(new DateTime?[]{ new DateTime(1998, 4, 14, 13, 26, 38, DateTimeKind.Utc), null })); + value = reader.GetValue(1); + Assert.That(reader.GetFieldType(1), Is.EqualTo(typeof(Array))); + Assert.That(value.GetType(), Is.EqualTo(typeof(DateTime?[,]))); + Assert.That(value, Is.EqualTo(new DateTime?[,] + { + { new DateTime(1998, 4, 14, 13, 26, 38, DateTimeKind.Utc), null }, + { new DateTime(1998, 4, 15, 13, 26, 38, DateTimeKind.Utc), new DateTime(1998, 4, 16, 13, 26, 38, DateTimeKind.Utc) } + })); + break; + default: + throw new UnreachableException($"Unknown case {mode}"); } } // Note that PG normalizes empty multidimensional arrays to single-dimensional, e.g. ARRAY[[], []]::integer[] returns {}. 
[Test] public async Task Write_empty_multidimensional_array() - => await AssertTypeWrite(new int[0, 0], "{}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array); + => await AssertTypeWrite(new int[0, 0], "{}", "integer[]"); [Test] public async Task Generic_List() => await AssertType( - new List { 1, 2, 3 }, "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array, isDefaultForReading: false); + new List { 1, 2, 3 }, "{1,2,3}", "integer[]", valueTypeEqualsFieldType: false); [Test] public async Task Write_IList_implementation() => await AssertTypeWrite( - ImmutableArray.Create(1, 2, 3), "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array); + ImmutableArray.Create(1, 2, 3), "{1,2,3}", "integer[]"); [Test] public void Read_IList_implementation_throws() - { - Assert.ThrowsAsync(() => - AssertTypeRead("{1,2,3}", "integer[]", ImmutableArray.Create(1, 2, 3), isDefault: false)); - } + => Assert.ThrowsAsync(() => + AssertTypeRead("{1,2,3}", "integer[]", ImmutableArray.Create(1, 2, 3), valueTypeEqualsFieldType: false)); [Test] public async Task Generic_IList() @@ -179,7 +280,7 @@ public async Task Generic_IList() var reader = await cmd.ExecuteReaderAsync(); reader.Read(); - Assert.AreEqual(expected, reader.GetFieldValue(0)); + Assert.That(reader.GetFieldValue(0), Is.EqualTo(expected)); } [Test, Description("Verifies that an InvalidOperationException is thrown when the returned array has a different number of dimensions from what was requested.")] @@ -311,7 +412,7 @@ public async Task Jagged_arrays_not_supported() { await using var conn = await OpenConnectionAsync(); await using var cmd = new NpgsqlCommand("SELECT @p1", conn); - cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, new[] { new[] { 8 }, new[] { 8, 10 } }); + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Array | NpgsqlDbType.Integer, new[] { [8], new[] { 8, 10 } }); Assert.That(async () => await cmd.ExecuteNonQueryAsync(), Throws.Exception .TypeOf() 
.With.Property("InnerException").Message.Contains("jagged")); @@ -320,9 +421,6 @@ public async Task Jagged_arrays_not_supported() [Test, Description("Roundtrips one-dimensional and two-dimensional arrays of a PostgreSQL domain.")] public async Task Array_of_domain() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - await using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "11.0", "Arrays of domains were introduced in PostgreSQL 11"); await conn.ExecuteNonQueryAsync("CREATE DOMAIN pg_temp.posint AS integer CHECK (VALUE > 0);"); @@ -352,9 +450,6 @@ public async Task Array_of_domain() [Test, Description("Roundtrips a PostgreSQL domain over a one-dimensional and a two-dimensional array.")] public async Task Domain_of_array() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - await using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "11.0", "Domains over arrays were introduced in PostgreSQL 11"); await conn.ExecuteNonQueryAsync( @@ -391,9 +486,9 @@ public async Task Read_two_empty_arrays() await using var cmd = new NpgsqlCommand("SELECT '{}'::INT[], '{}'::INT[]", conn); await using var reader = await cmd.ExecuteReaderAsync(); await reader.ReadAsync(); - Assert.AreSame(reader.GetFieldValue(0), reader.GetFieldValue(1)); + Assert.That(reader.GetFieldValue(1), Is.SameAs(reader.GetFieldValue(0))); // Unlike T[], List is mutable so we should not return the same instance - Assert.AreNotSame(reader.GetFieldValue>(0), reader.GetFieldValue>(1)); + Assert.That(reader.GetFieldValue>(1), Is.Not.SameAs(reader.GetFieldValue>(0))); } [Test] @@ -403,7 +498,7 @@ public async Task Arrays_not_supported_by_default_on_NpgsqlSlimSourceBuilder() await using var dataSource = dataSourceBuilder.Build(); await AssertTypeUnsupportedRead("{1,2,3}", "integer[]", dataSource); - await AssertTypeUnsupportedWrite(new[] { 1, 2, 3 }, "integer[]", dataSource); + await AssertTypeUnsupportedWrite([1, 2, 3], "integer[]", dataSource); 
} [Test] @@ -413,8 +508,6 @@ public async Task NpgsqlSlimSourceBuilder_EnableArrays() dataSourceBuilder.EnableArrays(); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, new[] { 1, 2, 3 }, "{1,2,3}", "integer[]", NpgsqlDbType.Integer | NpgsqlDbType.Array); + await AssertType(dataSource, new[] { 1, 2, 3 }, "{1,2,3}", "integer[]"); } - - public ArrayTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/BitStringTests.cs b/test/Npgsql.Tests/Types/BitStringTests.cs index 95c81ffb41..e41ae0cc8f 100644 --- a/test/Npgsql.Tests/Types/BitStringTests.cs +++ b/test/Npgsql.Tests/Types/BitStringTests.cs @@ -1,6 +1,7 @@ using System; using System.Collections; using System.Collections.Specialized; +using System.Data; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; @@ -13,7 +14,7 @@ namespace Npgsql.Tests.Types; /// /// https://www.postgresql.org/docs/current/static/datatype-bit.html /// -public class BitStringTests : MultiplexingTestBase +public class BitStringTests : TestBase { [Test] [TestCase("10110110", TestName = "BitArray")] @@ -27,10 +28,10 @@ public async Task BitArray(string sqlLiteral) for (var i = 0; i < sqlLiteral.Length; i++) bitArray[i] = sqlLiteral[i] == '1'; - await AssertType(bitArray, sqlLiteral, "bit varying", NpgsqlDbType.Varbit); + await AssertType(bitArray, sqlLiteral, "bit varying"); if (len > 0) - await AssertType(bitArray, sqlLiteral, $"bit({len})", NpgsqlDbType.Bit, isDefaultForWriting: false); + await AssertType(bitArray, sqlLiteral, $"bit({len})", dataTypeInference: DataTypeInference.Mismatch); } [Test] @@ -47,7 +48,7 @@ public async Task BitArray_long() [Test] public Task BitVector32() => AssertType( - new BitVector32(4), "00000000000000000000000000000100", "bit varying", NpgsqlDbType.Varbit, isDefaultForReading: false); + new BitVector32(4), "00000000000000000000000000000100", "bit varying", valueTypeEqualsFieldType: false); [Test] 
public Task BitVector32_too_long() @@ -55,7 +56,7 @@ public Task BitVector32_too_long() [Test] public Task Bool() - => AssertType(true, "1", "bit(1)", NpgsqlDbType.Bit, isDefault: false); + => AssertType(true, "1", "bit(1)", dataTypeInference: DataTypeInference.Mismatch, dbType: new(DbType.Object, DbType.Boolean)); [Test] public async Task Bitstring_with_multiple_bits_as_bool_throws() @@ -69,7 +70,7 @@ public async Task Array() { using var conn = await OpenConnectionAsync(); using var cmd = new NpgsqlCommand("SELECT @p", conn); - var expected = new[] { new BitArray(new[] { true, false, true }), new BitArray(new[] { false }) }; + var expected = new[] { new BitArray([true, false, true]), new BitArray([false]) }; var p = new NpgsqlParameter("p", NpgsqlDbType.Array | NpgsqlDbType.Varbit) { Value = expected }; cmd.Parameters.Add(p); p.Value = expected; @@ -118,11 +119,11 @@ public async Task Array_of_single_bits_and_null() [Test] public Task As_string() - => AssertType("010101", "010101", "bit varying", NpgsqlDbType.Varbit, isDefault: false); + => AssertType("010101", "010101", + "bit varying", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), valueTypeEqualsFieldType: false); [Test] public Task Write_as_string_validation() => AssertTypeUnsupportedWrite("001q0", "bit varying"); - - public BitStringTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/ByteaTests.cs b/test/Npgsql.Tests/Types/ByteaTests.cs index c34bce04ff..d87ed48216 100644 --- a/test/Npgsql.Tests/Types/ByteaTests.cs +++ b/test/Npgsql.Tests/Types/ByteaTests.cs @@ -2,6 +2,7 @@ using System.Collections.Generic; using System.Data; using System.IO; +using System.Net.Sockets; using System.Threading.Tasks; using NpgsqlTypes; using NUnit.Framework; @@ -14,13 +15,13 @@ namespace Npgsql.Tests.Types; /// /// https://www.postgresql.org/docs/current/static/datatype-binary.html /// -public class ByteaTests : 
MultiplexingTestBase +public class ByteaTests : TestBase { [Test] [TestCase(new byte[] { 1, 2, 3, 4, 5 }, "\\x0102030405", TestName = "Bytea")] [TestCase(new byte[] { }, "\\x", TestName = "Bytea_empty")] public Task Bytea(byte[] byteArray, string sqlLiteral) - => AssertType(byteArray, sqlLiteral, "bytea", NpgsqlDbType.Bytea, DbType.Binary); + => AssertType(byteArray, sqlLiteral, "bytea", dbType: DbType.Binary); [Test] public async Task Bytea_long() @@ -37,37 +38,38 @@ public async Task Bytea_long() [Test] public Task AsMemory() => AssertType( - new Memory(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, - comparer: (left, right) => left.Span.SequenceEqual(right.Span)); + new Memory([1, 2, 3]), "\\x010203", "bytea", dbType: DbType.Binary, + comparer: (left, right) => left.Span.SequenceEqual(right.Span), + valueTypeEqualsFieldType: false); [Test] public Task AsReadOnlyMemory() => AssertType( - new ReadOnlyMemory(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false, - comparer: (left, right) => left.Span.SequenceEqual(right.Span)); + new ReadOnlyMemory([1, 2, 3]), "\\x010203", "bytea", dbType: DbType.Binary, + comparer: (left, right) => left.Span.SequenceEqual(right.Span), + valueTypeEqualsFieldType: false); [Test] public Task AsArraySegment() - => AssertType( - new ArraySegment(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + => AssertType(new ArraySegment([1, 2, 3]), "\\x010203", + "bytea", dbType: DbType.Binary, valueTypeEqualsFieldType: false); [Test] public Task Write_as_MemoryStream() => AssertTypeWrite( - () => new MemoryStream(new byte[] { 1, 2, 3 }), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + () => new MemoryStream([1, 2, 3]), "\\x010203", "bytea", dbType: DbType.Binary); [Test] public Task Write_as_MemoryStream_truncated() { var msFactory = () => { - var ms = new 
MemoryStream(new byte[] { 1, 2, 3, 4 }); + var ms = new MemoryStream([1, 2, 3, 4]); ms.ReadByte(); return ms; }; - return AssertTypeWrite( - msFactory, "\\x020304", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + return AssertTypeWrite(valueFactory: msFactory, "\\x020304", "bytea", dbType: DbType.Binary); } [Test] @@ -84,8 +86,7 @@ public Task Write_as_MemoryStream_exposableArray() return ms; }; - return AssertTypeWrite( - msFactory, "\\x020304", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + return AssertTypeWrite(valueFactory: msFactory, "\\x020304", "bytea", dbType: DbType.Binary); } [Test] @@ -96,8 +97,7 @@ public async Task Write_as_MemoryStream_long() rnd.NextBytes(bytes); var expectedSql = "\\x" + ToHex(bytes); - await AssertTypeWrite( - () => new MemoryStream(bytes), expectedSql, "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + await AssertTypeWrite(() => new MemoryStream(bytes), expectedSql, "bytea", dbType: DbType.Binary); } [Test] @@ -107,10 +107,9 @@ public async Task Write_as_FileStream() var fsList = new List(); try { - await File.WriteAllBytesAsync(filePath, new byte[] { 1, 2, 3 }); + await File.WriteAllBytesAsync(filePath, [1, 2, 3]); - await AssertTypeWrite( - () => FileStreamFactory(filePath, fsList), "\\x010203", "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + await AssertTypeWrite(() => FileStreamFactory(filePath, fsList), "\\x010203", "bytea", dbType: DbType.Binary); } finally { @@ -145,8 +144,7 @@ public async Task Write_as_FileStream_long() await File.WriteAllBytesAsync(filePath, bytes); var expectedSql = "\\x" + ToHex(bytes); - await AssertTypeWrite( - () => FileStreamFactory(filePath, fsList), expectedSql, "bytea", NpgsqlDbType.Bytea, DbType.Binary, isDefault: false); + await AssertTypeWrite(() => FileStreamFactory(filePath, fsList), expectedSql, "bytea", dbType: DbType.Binary); } finally { @@ -191,14 +189,14 @@ public async Task Truncate_array() { await using var conn = 
await OpenConnectionAsync(); await using var cmd = new NpgsqlCommand("SELECT @p", conn); - byte[] data = { 1, 2, 3, 4, 5, 6 }; + byte[] data = [1, 2, 3, 4, 5, 6]; var p = new NpgsqlParameter("p", data) { Size = 4 }; cmd.Parameters.Add(p); Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 1, 2, 3, 4 })); Assert.That(p.Value, Is.EqualTo(new byte[] { 1, 2, 3, 4 }), "Truncated parameter value should be persisted on the parameter per DbParameter.Size docs"); // NpgsqlParameter.Size needs to persist when value is changed - byte[] data2 = { 11, 12, 13, 14, 15, 16 }; + byte[] data2 = [11, 12, 13, 14, 15, 16]; p.Value = data2; Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 11, 12, 13, 14 })); @@ -219,13 +217,13 @@ public async Task Truncate_stream() { await using var conn = await OpenConnectionAsync(); await using var cmd = new NpgsqlCommand("SELECT @p", conn); - byte[] data = { 1, 2, 3, 4, 5, 6 }; + byte[] data = [1, 2, 3, 4, 5, 6]; var p = new NpgsqlParameter("p", new MemoryStream(data)) { Size = 4 }; cmd.Parameters.Add(p); Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 1, 2, 3, 4 })); // NpgsqlParameter.Size needs to persist when value is changed - byte[] data2 = { 11, 12, 13, 14, 15, 16 }; + byte[] data2 = [11, 12, 13, 14, 15, 16]; p.Value = new MemoryStream(data2); Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 11, 12, 13, 14 })); @@ -254,7 +252,7 @@ public async Task Write_as_NonSeekable_stream() { await using var conn = await OpenConnectionAsync(); await using var cmd = new NpgsqlCommand("SELECT @p", conn); - byte[] data = { 1, 2, 3, 4, 5, 6 }; + byte[] data = [1, 2, 3, 4, 5, 6]; var p = new NpgsqlParameter("p", new NonSeekableStream(data)) { Size = 4 }; cmd.Parameters.Add(p); Assert.That(await cmd.ExecuteScalarAsync(), Is.EqualTo(new byte[] { 1, 2, 3, 4 })); @@ -279,19 +277,26 @@ public async Task Array_of_bytea() var inVal = new[] { bytes, bytes }; cmd.Parameters.AddWithValue("p1", 
NpgsqlDbType.Bytea | NpgsqlDbType.Array, inVal); var retVal = (byte[][]?)await cmd.ExecuteScalarAsync(); - Assert.AreEqual(inVal.Length, retVal!.Length); - Assert.AreEqual(inVal[0], retVal[0]); - Assert.AreEqual(inVal[1], retVal[1]); + Assert.That(retVal!.Length, Is.EqualTo(inVal.Length)); + Assert.That(retVal[0], Is.EqualTo(inVal[0])); + Assert.That(retVal[1], Is.EqualTo(inVal[1])); } - sealed class NonSeekableStream : MemoryStream + [Test] + public async Task InvalidCastException_unknown_stream_read() { - public override bool CanSeek => false; - - public NonSeekableStream(byte[] data) : base(data) + await using var conn = await OpenConnectionAsync(); + await using var cmd = new NpgsqlCommand("SELECT :p1", conn); + cmd.Parameters.AddWithValue("p1", NpgsqlDbType.Bytea, new byte[] { 1 }); + await using var reader = await cmd.ExecuteReaderAsync(); + while (await reader.ReadAsync()) { + Assert.Throws(() => reader.GetFieldValue(0)); } } - public ByteaTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + sealed class NonSeekableStream(byte[] data) : MemoryStream(data) + { + public override bool CanSeek => false; + } } diff --git a/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs b/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs index 2188569a49..732dbf83e1 100644 --- a/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs +++ b/test/Npgsql.Tests/Types/CompositeHandlerTests.Read.cs @@ -32,27 +32,27 @@ async Task Read(T composite, Action, T> assert, string? 
schema = null [Test] public Task Read_class_with_property() => - Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); + Read((execute, expected) => Assert.That(execute().Value, Is.EqualTo(expected.Value))); [Test] public Task Read_class_with_field() => - Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); + Read((execute, expected) => Assert.That(execute().Value, Is.EqualTo(expected.Value))); [Test] public Task Read_struct_with_property() => - Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); + Read((execute, expected) => Assert.That(execute().Value, Is.EqualTo(expected.Value))); [Test] public Task Read_struct_with_field() => - Read((execute, expected) => Assert.AreEqual(expected.Value, execute().Value)); + Read((execute, expected) => Assert.That(execute().Value, Is.EqualTo(expected.Value))); [Test] public Task Read_type_with_two_properties() => Read((execute, expected) => { var actual = execute(); - Assert.AreEqual(expected.IntValue, actual.IntValue); - Assert.AreEqual(expected.StringValue, actual.StringValue); + Assert.That(actual.IntValue, Is.EqualTo(expected.IntValue)); + Assert.That(actual.StringValue, Is.EqualTo(expected.StringValue)); }); [Test] @@ -60,8 +60,8 @@ public Task Read_type_with_two_properties_inverted() => Read((execute, expected) => { var actual = execute(); - Assert.AreEqual(expected.IntValue, actual.IntValue); - Assert.AreEqual(expected.StringValue, actual.StringValue); + Assert.That(actual.IntValue, Is.EqualTo(expected.IntValue)); + Assert.That(actual.StringValue, Is.EqualTo(expected.StringValue)); }); [Test] @@ -98,7 +98,7 @@ public Task Read_type_with_more_properties_than_attributes() => Read(new TypeWithMorePropertiesThanAttributes(), (execute, expected) => { var actual = execute(); - Assert.That(actual.IntValue, Is.Not.Null); + Assert.That((int?)actual.IntValue, Is.Not.Null); Assert.That(actual.StringValue, Is.Null); }); diff --git 
a/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs b/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs index 160b037a97..800270f7c3 100644 --- a/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs +++ b/test/Npgsql.Tests/Types/CompositeHandlerTests.Write.cs @@ -45,34 +45,34 @@ async Task Write(T composite, Action? assert = null, str [Test] public Task Write_class_with_property() - => Write((reader, expected) => Assert.AreEqual(expected.Value, reader.GetString(0))); + => Write((reader, expected) => Assert.That(reader.GetString(0), Is.EqualTo(expected.Value))); [Test] public Task Write_class_with_field() - => Write((reader, expected) => Assert.AreEqual(expected.Value, reader.GetString(0))); + => Write((reader, expected) => Assert.That(reader.GetString(0), Is.EqualTo(expected.Value))); [Test] public Task Write_struct_with_property() - => Write((reader, expected) => Assert.AreEqual(expected.Value, reader.GetString(0))); + => Write((reader, expected) => Assert.That(reader.GetString(0), Is.EqualTo(expected.Value))); [Test] public Task Write_struct_with_field() - => Write((reader, expected) => Assert.AreEqual(expected.Value, reader.GetString(0))); + => Write((reader, expected) => Assert.That(reader.GetString(0), Is.EqualTo(expected.Value))); [Test] public Task Write_type_with_two_properties() => Write((reader, expected) => { - Assert.AreEqual(expected.IntValue, reader.GetInt32(0)); - Assert.AreEqual(expected.StringValue, reader.GetString(1)); + Assert.That(reader.GetInt32(0), Is.EqualTo(expected.IntValue)); + Assert.That(reader.GetString(1), Is.EqualTo(expected.StringValue)); }); [Test] public Task Write_type_with_two_properties_inverted() => Write((reader, expected) => { - Assert.AreEqual(expected.IntValue, reader.GetInt32(1)); - Assert.AreEqual(expected.StringValue, reader.GetString(0)); + Assert.That(reader.GetInt32(1), Is.EqualTo(expected.IntValue)); + Assert.That(reader.GetString(0), Is.EqualTo(expected.StringValue)); }); [Test] diff --git 
a/test/Npgsql.Tests/Types/CompositeHandlerTests.cs b/test/Npgsql.Tests/Types/CompositeHandlerTests.cs index 1df95980a3..cc84efd094 100644 --- a/test/Npgsql.Tests/Types/CompositeHandlerTests.cs +++ b/test/Npgsql.Tests/Types/CompositeHandlerTests.cs @@ -155,10 +155,9 @@ public class TypeWithExplicitPropertyName : SimpleComposite protected override string GetValue() => MyValue; } - public class TypeWithExplicitParameterName : SimpleComposite + public class TypeWithExplicitParameterName([PgName("value")] string myValue) : SimpleComposite { - public TypeWithExplicitParameterName([PgName("value")] string myValue) => Value = myValue; - public string Value { get; } + public string Value { get; } = myValue; protected override string GetValue() => Value; } @@ -178,81 +177,72 @@ public class TypeWithLessPropertiesThanAttributes : IComposite public int IntValue { get; set; } } - public class TypeWithMoreParametersThanAttributes : IComposite + public class TypeWithMoreParametersThanAttributes(int intValue, string? stringValue) : IComposite { public string GetAttributes() => "int_value integer"; public string GetValues() => $"{IntValue}"; - public TypeWithMoreParametersThanAttributes(int intValue, string? stringValue) - { - IntValue = intValue; - StringValue = stringValue; - } - - public int IntValue { get; set; } - public string? StringValue { get; set; } + public int IntValue { get; set; } = intValue; + public string? 
StringValue { get; set; } = stringValue; } - public class TypeWithLessParametersThanAttributes : IComposite + public class TypeWithLessParametersThanAttributes(int intValue) : IComposite { public string GetAttributes() => "int_value integer, string_value text"; public string GetValues() => $"{IntValue}, NULL"; - public TypeWithLessParametersThanAttributes(int intValue) => - IntValue = intValue; - - public int IntValue { get; } + public int IntValue { get; } = intValue; } - public class TypeWithOneParameter : IComposite + public class TypeWithOneParameter(int value1) : IComposite { public string GetAttributes() => "value1 integer"; public string GetValues() => $"{Value1}"; - public TypeWithOneParameter(int value1) => Value1 = value1; - public int Value1 { get; } + public int Value1 { get; } = value1; } - public class TypeWithTwoParameters : IComposite + public class TypeWithTwoParameters(int intValue, string stringValue) : IComposite { public string GetAttributes() => "int_value integer, string_value text"; public string GetValues() => $"{IntValue}, '{StringValue}'"; - public TypeWithTwoParameters(int intValue, string stringValue) => - (IntValue, StringValue) = (intValue, stringValue); - - public int IntValue { get; } - public string? StringValue { get; } + public int IntValue { get; } = intValue; + public string? StringValue { get; } = stringValue; } - public class TypeWithTwoParametersReversed : IComposite + public class TypeWithTwoParametersReversed(string stringValue, int intValue) : IComposite { public string GetAttributes() => "int_value integer, string_value text"; public string GetValues() => $"{IntValue}, '{StringValue}'"; - public TypeWithTwoParametersReversed(string stringValue, int intValue) => - (StringValue, IntValue) = (stringValue, intValue); - - public int IntValue { get; } - public string? StringValue { get; } + public int IntValue { get; } = intValue; + public string? 
StringValue { get; } = stringValue; } - public class TypeWithNineParameters : IComposite + public class TypeWithNineParameters( + int value1, + int value2, + int value3, + int value4, + int value5, + int value6, + int value7, + int value8, + int value9) + : IComposite { public string GetAttributes() => "value1 integer, value2 integer, value3 integer, value4 integer, value5 integer, value6 integer, value7 integer, value8 integer, value9 integer"; public string GetValues() => $"{Value1}, {Value2}, {Value3}, {Value4}, {Value5}, {Value6}, {Value7}, {Value8}, {Value9}"; - public TypeWithNineParameters(int value1, int value2, int value3, int value4, int value5, int value6, int value7, int value8, int value9) - => (Value1, Value2, Value3, Value4, Value5, Value6, Value7, Value8, Value9) = (value1, value2, value3, value4, value5, value6, value7, value8, value9); - - public int Value1 { get; } - public int Value2 { get; } - public int Value3 { get; } - public int Value4 { get; } - public int Value5 { get; } - public int Value6 { get; } - public int Value7 { get; } - public int Value8 { get; } - public int Value9 { get; } + public int Value1 { get; } = value1; + public int Value2 { get; } = value2; + public int Value3 { get; } = value3; + public int Value4 { get; } = value4; + public int Value5 { get; } = value5; + public int Value6 { get; } = value6; + public int Value7 { get; } = value7; + public int Value8 { get; } = value8; + public int Value9 { get; } = value9; } } diff --git a/test/Npgsql.Tests/Types/CompositeTests.cs b/test/Npgsql.Tests/Types/CompositeTests.cs index 36257e126a..d9c1db253d 100644 --- a/test/Npgsql.Tests/Types/CompositeTests.cs +++ b/test/Npgsql.Tests/Types/CompositeTests.cs @@ -1,5 +1,6 @@ using System; using System.Linq; +using System.Net; using System.Reflection; using System.Threading.Tasks; using Npgsql.PostgresTypes; @@ -9,7 +10,7 @@ namespace Npgsql.Tests.Types; -public class CompositeTests : MultiplexingTestBase +public class CompositeTests : 
TestBase { [Test] public async Task Basic() @@ -29,7 +30,7 @@ await AssertType( new SomeComposite { SomeText = "foo", X = 8 }, "(8,foo)", type, - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -51,7 +52,7 @@ await AssertType( new SomeComposite { SomeText = "foo", X = 8 }, "(8,foo)", type, - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -72,7 +73,7 @@ await AssertType( new SomeComposite { SomeText = "foo", X = 8 }, "(8,foo)", type, - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } class CustomTranslator : INpgsqlNameTranslator @@ -105,7 +106,7 @@ await AssertType( new SomeComposite { SomeText = "foo", X = 8 }, "(8,foo)", type, - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } finally { @@ -138,7 +139,7 @@ await AssertType( new SomeCompositeContainer { A = 8, Containee = new() { SomeText = "foo", X = 9 } }, @"(8,""(9,foo)"")", containerType, - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1168")] @@ -159,7 +160,7 @@ await AssertType( new SomeComposite { SomeText = "foo", X = 8 }, "(8,foo)", $"{schema}.some_composite", - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/4365")] @@ -189,16 +190,36 @@ await AssertType( new SomeCompositeContainer { A = 8, Containee = new() { SomeText = "foo", X = 9 } }, @"(8,""(9,foo)"")", $"{secondSchemaName}.container", - npgsqlDbType: null, - isDefaultForWriting: false); + dataTypeInference: DataTypeInference.Nothing); await AssertType( connection, new SomeCompositeContainer { A = 8, Containee = new() { SomeText = "foo", X = 9 } }, @"(8,""(9,foo)"")", $"{firstSchemaName}.container", - npgsqlDbType: null, - isDefaultForWriting: true); + dataTypeInference: DataTypeInference.Nothing); + } + + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5972")] + 
public async Task With_schema_and_dots_in_type_name() + { + await using var adminConnection = await OpenConnectionAsync(); + var schema = await CreateTempSchema(adminConnection); + var typename = "Some.Composite.with.dots"; + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {schema}.\"{typename}\" AS (x int, some_text text)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite($"{schema}.{typename}"); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeComposite { SomeText = "foobar", X = 10 }, + "(10,foobar)", + $"{schema}.\"{typename}\"", + dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -219,7 +240,7 @@ await AssertType( new SomeCompositeStruct { SomeText = "foo", X = 8 }, "(8,foo)", type, - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -240,7 +261,7 @@ await AssertType( new SomeComposite[] { new() { SomeText = "foo", X = 8 }, new() { SomeText = "bar", X = 9 }}, @"{""(8,foo)"",""(9,bar)""}", type + "[]", - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] @@ -262,7 +283,7 @@ await AssertType( new NameTranslationComposite { Simple = 2, TwoWords = 3, SomeClrName = 4 }, "(2,3,4)", type, - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/856")] @@ -286,7 +307,7 @@ await AssertType( new Address { PostalCode = "12345", Street = "Main St." 
}, @"(""Main St."",12345)", compositeType, - npgsqlDbType: null); + dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -305,13 +326,62 @@ await adminConnection.ExecuteNonQueryAsync($@" await AssertType( connection, - new SomeCompositeWithArray { Ints = new[] { 1, 2, 3, 4 } }, + new SomeCompositeWithArray { Ints = [1, 2, 3, 4] }, @"(""{1,2,3,4}"")", compositeType, - npgsqlDbType: null, + dataTypeInference: DataTypeInference.Nothing, comparer: (actual, expected) => actual.Ints!.SequenceEqual(expected.Ints!)); } + [Test] + public async Task Composite_containing_enum_type() + { + await using var adminConnection = await OpenConnectionAsync(); + var enumType = await GetTempTypeName(adminConnection); + var compositeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {enumType} AS enum ('value1', 'value2', 'value3'); +CREATE TYPE {compositeType} AS (enum_value {enumType});"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(compositeType); + dataSourceBuilder.MapEnum(enumType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeCompositeWithEnum { EnumValue = SomeCompositeWithEnum.TestEnum.Value2 }, + @"(value2)", + compositeType, + dataTypeInference: DataTypeInference.Nothing, + comparer: (actual, expected) => actual.EnumValue == expected.EnumValue); + } + + [Test] + public async Task Composite_containing_IPAddress() + { + await using var adminConnection = await OpenConnectionAsync(); + var compositeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($@" +CREATE TYPE {compositeType} AS (address inet)"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(compositeType); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await 
dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new SomeCompositeWithIPAddress { Address = IPAddress.Loopback }, + @"(127.0.0.1)", + compositeType, + dataTypeInference: DataTypeInference.Nothing, + comparer: (actual, expected) => actual.Address!.Equals(expected.Address)); + } + [Test] public async Task Composite_containing_converter_resolver_type() { @@ -329,10 +399,12 @@ await adminConnection.ExecuteNonQueryAsync($@" await AssertType( connection, - new SomeCompositeWithConverterResolverType { DateTimes = new [] { new DateTime(DateTime.UnixEpoch.Ticks, DateTimeKind.Unspecified), new DateTime(DateTime.UnixEpoch.Ticks, DateTimeKind.Unspecified).AddDays(1) } }, + new SomeCompositeWithConverterResolverType { DateTimes = [new DateTime(DateTime.UnixEpoch.Ticks, DateTimeKind.Unspecified), new DateTime(DateTime.UnixEpoch.Ticks, DateTimeKind.Unspecified).AddDays(1) + ] + }, """("{""1970-01-01 00:00:00"",""1970-01-02 00:00:00""}")""", compositeType, - npgsqlDbType: null, + dataTypeInference: DataTypeInference.Nothing, comparer: (actual, expected) => actual.DateTimes!.SequenceEqual(expected.DateTimes!)); } @@ -353,10 +425,10 @@ await adminConnection.ExecuteNonQueryAsync($@" Assert.ThrowsAsync(() => AssertType( connection, - new SomeCompositeWithConverterResolverType { DateTimes = new[] { DateTime.UnixEpoch } }, // UTC DateTime + new SomeCompositeWithConverterResolverType { DateTimes = [DateTime.UnixEpoch] }, // UTC DateTime """("{""1970-01-01 01:00:00"",""1970-01-02 01:00:00""}")""", compositeType, - npgsqlDbType: null, + dataTypeInference: DataTypeInference.Nothing, comparer: (actual, expected) => actual.DateTimes!.SequenceEqual(expected.DateTimes!))); } @@ -368,8 +440,7 @@ public async Task Table_as_composite([Values] bool enabled) var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.MapComposite(table); - if (enabled) - dataSourceBuilder.ConnectionStringBuilder.LoadTableComposites = true; + dataSourceBuilder.ConfigureTypeLoading(b => 
b.EnableTableCompositesLoading(enabled)); await using var dataSource = dataSourceBuilder.Build(); await using var connection = await dataSource.OpenConnectionAsync(); @@ -378,20 +449,16 @@ public async Task Table_as_composite([Values] bool enabled) else { Assert.ThrowsAsync(DoAssertion); - // Start a transaction specifically for multiplexing (to bind a connector to the connection) - await using var tx = await connection.BeginTransactionAsync(); - Assert.Null(connection.Connector!.DatabaseInfo.CompositeTypes.SingleOrDefault(c => c.Name.Contains(table))); - Assert.Null(connection.Connector!.DatabaseInfo.ArrayTypes.SingleOrDefault(c => c.Name.Contains(table))); + Assert.That(connection.Connector!.DatabaseInfo.CompositeTypes.SingleOrDefault(c => c.Name.Contains(table)), Is.Null); + Assert.That(connection.Connector!.DatabaseInfo.ArrayTypes.SingleOrDefault(c => c.Name.Contains(table)), Is.Null); } Task DoAssertion() => AssertType( - connection, - new SomeComposite { SomeText = "foo", X = 8 }, - "(8,foo)", - table, - npgsqlDbType: null); + dataSource, + new SomeComposite { SomeText = "foo", X = 8 }, "(8,foo)", + table, dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1267")] @@ -402,17 +469,14 @@ public async Task Table_as_composite_with_deleted_columns() await adminConnection.ExecuteNonQueryAsync($"ALTER TABLE {table} DROP COLUMN bar;"); var dataSourceBuilder = CreateDataSourceBuilder(); - dataSourceBuilder.ConnectionStringBuilder.LoadTableComposites = true; + dataSourceBuilder.ConfigureTypeLoading(b => b.EnableTableCompositesLoading()); dataSourceBuilder.MapComposite(table); await using var dataSource = dataSourceBuilder.Build(); - await using var connection = await dataSource.OpenConnectionAsync(); await AssertType( - connection, - new SomeComposite { SomeText = "foo", X = 8 }, - "(8,foo)", - table, - npgsqlDbType: null); + dataSource, + new SomeComposite { SomeText = "foo", X = 8 }, "(8,foo)", + table, 
dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1125")] @@ -426,21 +490,16 @@ public async Task Nullable_property_in_class_composite() var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.MapComposite(type); await using var dataSource = dataSourceBuilder.Build(); - await using var connection = await dataSource.OpenConnectionAsync(); await AssertType( - connection, - new ClassWithNullableProperty { Foo = 8 }, - "(8)", - type, - npgsqlDbType: null); + dataSource, + new ClassWithNullableProperty { Foo = 8 }, "(8)", + type, dataTypeInference: DataTypeInference.Nothing); await AssertType( - connection, - new ClassWithNullableProperty { Foo = null }, - "()", - type, - npgsqlDbType: null); + dataSource, + new ClassWithNullableProperty { Foo = null }, "()", + type, dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1125")] @@ -454,27 +513,24 @@ public async Task Nullable_property_in_struct_composite() var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.MapComposite(type); await using var dataSource = dataSourceBuilder.Build(); - await using var connection = await dataSource.OpenConnectionAsync(); await AssertType( - connection, - new StructWithNullableProperty { Foo = 8 }, - "(8)", - type, - npgsqlDbType: null); + dataSource, + new StructWithNullableProperty { Foo = 8 }, "(8)", + type, dataTypeInference: DataTypeInference.Nothing); await AssertType( - connection, - new StructWithNullableProperty { Foo = null }, - "()", - type, - npgsqlDbType: null); + dataSource, + new StructWithNullableProperty { Foo = null }, "()", + type, dataTypeInference: DataTypeInference.Nothing); } [Test] public async Task PostgresType() { - await using var connection = await OpenConnectionAsync(); + // Set max pool size to 1 to ensure we execute queries on the connection which has the new types + await using var dataSource = 
CreateDataSource(connectionStringBuilderAction: csb => csb.MaxPoolSize = 1); + await using var connection = await dataSource.OpenConnectionAsync(); var type1 = await GetTempTypeName(connection); var type2 = await GetTempTypeName(connection); @@ -516,14 +572,12 @@ public async Task DuplicateConstructorParameters() var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.MapComposite(type); await using var dataSource = dataSourceBuilder.Build(); - await using var connection = await dataSource.OpenConnectionAsync(); var ex = Assert.ThrowsAsync(async () => await AssertType( - connection, + dataSource, new DuplicateOneLongOneBool(true, 1), "(1,t)", - type, - npgsqlDbType: null)); + type, dataTypeInference: DataTypeInference.Nothing)); Assert.That(ex!.InnerException, Is.TypeOf()); } @@ -538,10 +592,9 @@ public async Task PartialConstructorMissingSetter() var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.MapComposite(type); await using var dataSource = dataSourceBuilder.Build(); - await using var connection = await dataSource.OpenConnectionAsync(); var ex = Assert.ThrowsAsync(async () => await AssertTypeRead( - connection, + dataSource, "(1,t)", type, new MissingSetterOneLongOneBool(true, 1))); @@ -559,37 +612,63 @@ public async Task PartialConstructorWorks() var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.MapComposite(type); await using var dataSource = dataSourceBuilder.Build(); - await using var connection = await dataSource.OpenConnectionAsync(); await AssertType( - connection, - new OneLongOneBool(1) { BooleanValue = true }, - "(1,t)", - type, - npgsqlDbType: null); + dataSource, + new OneLongOneBool(1) { BooleanValue = true }, "(1,t)", + type, dataTypeInference: DataTypeInference.Nothing); } - #region Test Types - - readonly struct DuplicateOneLongOneBool + [Test] + public async Task CompositeOverRange() { - public DuplicateOneLongOneBool(bool boolean, [PgName("boolean")]int @bool) + await using var adminConnection 
= await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + var rangeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS (x int, some_text text); CREATE TYPE {rangeType} AS RANGE(subtype={type})"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.MapComposite(type); + dataSourceBuilder.EnableUnmappedTypes(); + await using var dataSource = dataSourceBuilder.Build(); + + var composite1 = new SomeComposite { - } + SomeText = "foo", + X = 8 + }; + var composite2 = new SomeComposite + { + SomeText = "bar", + X = 42 + }; + + await AssertType( + dataSource, + new NpgsqlRange(composite1, composite2), + "[\"(8,foo)\",\"(42,bar)\"]", + rangeType, dataTypeInference: DataTypeInference.Nothing); + } + + #region Test Types + +#pragma warning disable CS9113 + readonly struct DuplicateOneLongOneBool(bool boolean, [PgName("boolean")] int @bool) + { [PgName("long")] public long LongValue { get; } [PgName("boolean")] public bool BooleanValue { get; } } +#pragma warning restore CS9113 readonly struct MissingSetterOneLongOneBool { public MissingSetterOneLongOneBool(long @long) - { - LongValue = @long; - } + => LongValue = @long; public MissingSetterOneLongOneBool(bool boolean, [PgName("boolean")]int @bool) { @@ -609,9 +688,7 @@ public OneLongOneBool(bool boolean, [PgName("boolean")]int @bool) } public OneLongOneBool(long @long) - { - LongValue = @long; - } + => LongValue = @long; public OneLongOneBool(double other) { @@ -652,6 +729,23 @@ class SomeCompositeWithArray public int[]? Ints { get; set; } } + class SomeCompositeWithEnum + { + public enum TestEnum + { + Value1, + Value2, + Value3 + } + + public TestEnum EnumValue { get; set; } + } + + class SomeCompositeWithIPAddress + { + public IPAddress? Address { get; set; } + } + class SomeCompositeWithConverterResolverType { public DateTime[]? 
DateTimes { get; set; } @@ -681,7 +775,5 @@ struct StructWithNullableProperty public int? Foo { get; set; } } - public CompositeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} - #endregion } diff --git a/test/Npgsql.Tests/Types/CubeTests.cs b/test/Npgsql.Tests/Types/CubeTests.cs new file mode 100644 index 0000000000..9c98438ab7 --- /dev/null +++ b/test/Npgsql.Tests/Types/CubeTests.cs @@ -0,0 +1,267 @@ +using System; +using System.Threading.Tasks; +using Npgsql.Properties; +using NpgsqlTypes; +using NUnit.Framework; + +namespace Npgsql.Tests.Types; + +public class CubeTests : TestBase +{ + static readonly TestCaseData[] CubeValues = + { + new TestCaseData(new NpgsqlCube(new[] { 1.0, 2.0, 3.0 }, new[] { 4.0, 5.0, 6.0 }), "(1, 2, 3),(4, 5, 6)") + .SetName("Cube_MultiDimensional"), + new TestCaseData(new NpgsqlCube(new[] { 1.0, 2.0, 3.0 }), "(1, 2, 3)") + .SetName("Cube_MultiDimensionalPoint"), + new TestCaseData(new NpgsqlCube(1.0), "(1)") + .SetName("Cube_SingleDimensionalPoint"), + new TestCaseData(new NpgsqlCube(1.0, 2.0), "(1),(2)") + .SetName("Cube_SingleDimensional") + }; + + [Test, TestCaseSource(nameof(CubeValues))] + public Task Cube(NpgsqlCube cube, string sqlLiteral) + => AssertType(cube, sqlLiteral, "cube", dataTypeInference: DataTypeInference.Nothing); + + [Test] + public void Cube_Constructor_SingleValue() + { + var cube = new NpgsqlCube(1.0); + Assert.That(cube.IsPoint, Is.True); + Assert.That(cube.Dimensions, Is.EqualTo(1)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 1.0 })); + } + + [Test] + public void Cube_Constructor_SingleCoord_Point() + { + var cube = new NpgsqlCube(1.0, 1.0); + Assert.That(cube.IsPoint, Is.True); + Assert.That(cube.Dimensions, Is.EqualTo(1)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 1.0 })); + } + + [Test] + public void 
Cube_Constructor_SingleCoord_NotPoint() + { + var cube = new NpgsqlCube(1.0, 2.0); + Assert.That(cube.IsPoint, Is.False); + Assert.That(cube.Dimensions, Is.EqualTo(1)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 2.0 })); + } + + [Test] + public void Cube_Constructor_LowerLeft_UpperRight_NotPoint() + { + var cube = new NpgsqlCube(new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 }); + Assert.That(cube.IsPoint, Is.False); + Assert.That(cube.Dimensions, Is.EqualTo(2)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0, 2.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 3.0, 4.0 })); + } + + [Test] + public void Cube_Constructor_LowerLeft_UpperRight_Point() + { + var cube = new NpgsqlCube(new[] { 1.0, 2.0 }, new[] { 1.0, 2.0 }); + Assert.That(cube.IsPoint, Is.True); + Assert.That(cube.Dimensions, Is.EqualTo(2)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0, 2.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 1.0, 2.0 })); + } + + [Test] + public void Cube_Constructor_AddDimension_Single_Point() + { + var existingCube = new NpgsqlCube(new[] { 1.0, 2.0, 3.0 }); + var cube = new NpgsqlCube(existingCube, 4.0); + Assert.That(cube.IsPoint, Is.True); + Assert.That(cube.Dimensions, Is.EqualTo(4)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0, 2.0, 3.0, 4.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 1.0, 2.0, 3.0, 4.0 })); + } + + [Test] + public void Cube_Constructor_AddDimension_Single_NotPoint() + { + var existingCube = new NpgsqlCube(new [] { 1.0, 2.0 }, new [] { 3.0, 4.0 }); + var cube = new NpgsqlCube(existingCube, 3.0); + Assert.That(cube.IsPoint, Is.False); + Assert.That(cube.Dimensions, Is.EqualTo(3)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0, 2.0, 3.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 3.0, 4.0, 3.0 })); + } + + [Test] + public void 
Cube_Constructor_AddDimension_LowerLeft_UpperRight_Point() + { + var existingCube = new NpgsqlCube(new[] { 1.0, 2.0, 3.0 }); + var cube = new NpgsqlCube(existingCube, 4.0, 4.0); + Assert.That(cube.IsPoint, Is.True); + Assert.That(cube.Dimensions, Is.EqualTo(4)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0, 2.0, 3.0, 4.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 1.0, 2.0, 3.0, 4.0 })); + } + + [Test] + public void Cube_Constructor_AddDimension_LowerLeft_UpperRight_NotPoint() + { + var existingCube = new NpgsqlCube(new [] { 1.0, 2.0 }, new [] { 3.0, 4.0 }); + var cube = new NpgsqlCube(existingCube, 4.0, 5.0); + Assert.That(cube.IsPoint, Is.False); + Assert.That(cube.Dimensions, Is.EqualTo(3)); + Assert.That(cube.LowerLeft, Is.EquivalentTo(new [] { 1.0, 2.0, 4.0 })); + Assert.That(cube.UpperRight, Is.EquivalentTo(new [] { 3.0, 4.0, 5.0 })); + } + + [Test] + public void Cube_Subset() + { + var cube = new NpgsqlCube(new [] { 1.0, 2.0, 3.0 }, new [] { 4.0, 5.0, 6.0 }); + Assert.That(cube.ToSubset(0, 2, 1, 1), Is.EqualTo(new NpgsqlCube(new [] { 1.0, 3.0, 2.0, 2.0 }, new [] { 4.0, 6.0, 5.0, 5.0 }))); + } + + [Test] + public void Cube_ToString_NotPoint() + { + var cube = new NpgsqlCube(new[] { 1.0, 2.0, 3.0 }, new[] { 4.0, 5.0, 6.0 }); + Assert.That(cube.ToString(), Is.EqualTo("(1, 2, 3),(4, 5, 6)")); + } + + [Test] + public void Cube_ToString_Point() + { + var cube = new NpgsqlCube(new[] { 1.0, 2.0, 3.0 }); + Assert.That(cube.ToString(), Is.EqualTo("(1, 2, 3)")); + } + + [Test] + public async Task Cube_Array() + { + var data = new[] + { + new NpgsqlCube(new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 }), + new NpgsqlCube(new[] { 5.0, 6.0 }), + new NpgsqlCube(1.0, 2.0) + }; + + await AssertType( + data, + @"{""(1, 2),(3, 4)"",""(5, 6)"",""(1),(2)""}", + "cube[]", + dataTypeInference: DataTypeInference.Nothing); + } + + [Test] + public void Cube_DimensionMismatch_ThrowsArgumentException() + { + var ex = Assert.Throws(() => new NpgsqlCube(new[] { 1.0, 
2.0 }, new[] { 3.0 })); + Assert.That(ex!.Message, Does.Contain("Different point dimensions")); + } + + [Test] + public Task Cube_NegativeValues() + => AssertType( + new NpgsqlCube(new[] { -1.0, -2.0, -3.0 }, new[] { -4.0, -5.0, -6.0 }), + "(-1, -2, -3),(-4, -5, -6)", + "cube", + dataTypeInference: DataTypeInference.Nothing); + + [Test] + public void Cube_Equality_HashCode() + { + var cube1 = new NpgsqlCube(new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 }); + var cube2 = new NpgsqlCube(new[] { 1.0, 2.0 }, new[] { 3.0, 4.0 }); + var cube3 = new NpgsqlCube(new[] { 1.0, 2.0 }, new[] { 3.0, 5.0 }); + + // Test equality + Assert.That(cube1, Is.EqualTo(cube2)); + Assert.That(cube1 == cube2, Is.True); + Assert.That(cube1 != cube3, Is.True); + Assert.That(cube1.Equals(cube2), Is.True); + Assert.That(cube1.Equals(cube3), Is.False); + + // Test hash code consistency + Assert.That(cube1.GetHashCode(), Is.EqualTo(cube2.GetHashCode())); + Assert.That(cube1.GetHashCode(), Is.Not.EqualTo(cube3.GetHashCode())); + } + + [Test] + public Task Cube_ZeroValues() + => AssertType( + new NpgsqlCube(0.0, 0.0), + "(0)", + "cube", + dataTypeInference: DataTypeInference.Nothing); + + [Test] + public Task Cube_MaxDimensions() + { + var lowerLeft = new double[100]; + var upperRight = new double[100]; + for (var i = 0; i < 100; i++) + { + lowerLeft[i] = i; + upperRight[i] = i + 100; + } + + var expectedLower = string.Join(", ", lowerLeft); + var expectedUpper = string.Join(", ", upperRight); + var expected = $"({expectedLower}),({expectedUpper})"; + + return AssertType( + new NpgsqlCube(lowerLeft, upperRight), + expected, + "cube", + dataTypeInference: DataTypeInference.Nothing); + } + + [Test] + public async Task Cube_not_supported_by_default_on_NpgsqlSlimSourceBuilder() + { + var errorMessage = string.Format( + NpgsqlStrings.CubeNotEnabled, nameof(NpgsqlSlimDataSourceBuilder.EnableCube), nameof(NpgsqlSlimDataSourceBuilder)); + + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); 
+ await using var dataSource = dataSourceBuilder.Build(); + + var exception = + await AssertTypeUnsupportedRead("(1),(2)", "cube", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + exception = await AssertTypeUnsupportedWrite(new NpgsqlCube(1.0, 2.0), "cube", dataSource); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); + } + + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableCube() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableCube(); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, new NpgsqlCube(1.0, 2.0), "(1),(2)", "cube", dataTypeInference: DataTypeInference.Nothing, skipArrayCheck: true); + } + + [Test] + public async Task NpgsqlSlimSourceBuilder_EnableArrays() + { + var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); + dataSourceBuilder.EnableCube(); + dataSourceBuilder.EnableArrays(); + await using var dataSource = dataSourceBuilder.Build(); + + await AssertType(dataSource, new NpgsqlCube(1.0, 2.0), "(1),(2)", "cube", dataTypeInference: DataTypeInference.Nothing); + } + + [OneTimeSetUp] + public async Task SetUp() + { + await using var conn = await OpenConnectionAsync(); + TestUtil.MinimumPgVersion(conn, "13.0"); + await TestUtil.EnsureExtensionAsync(conn, "cube"); + } +} diff --git a/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs b/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs index e1ccad4445..834ad346e9 100644 --- a/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs +++ b/test/Npgsql.Tests/Types/DateTimeInfinityTests.cs @@ -1,58 +1,57 @@ using System; using System.Data; using System.Threading.Tasks; -using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Util.Statics; namespace Npgsql.Tests.Types; -[TestFixture(true)] -#if DEBUG [TestFixture(false)] +#if DEBUG +[TestFixture(true)] [NonParallelizable] #endif public sealed class 
DateTimeInfinityTests : TestBase, IDisposable { static readonly TestCaseData[] TimestampDateTimeValues = - { + [ new TestCaseData(DateTime.MinValue.AddYears(1), "0002-01-01 00:00:00", "0002-01-01 00:00:00") .SetName("MinValue_AddYear"), new TestCaseData(DateTime.MinValue, "0001-01-01 00:00:00", "-infinity") .SetName("MinValue"), new TestCaseData(DateTime.MaxValue, "9999-12-31 23:59:59.999999", "infinity") - .SetName("MaxValue"), - }; + .SetName("MaxValue") + ]; static readonly TestCaseData[] TimestampTzDateTimeValues = - { + [ new TestCaseData(DateTime.MinValue.AddYears(1), "0002-01-01 00:00:00+00", "0002-01-01 00:00:00+00") .SetName("MinValue_AddYear"), new TestCaseData(DateTime.MinValue, "0001-01-01 00:00:00+00", "-infinity") .SetName("MinValue"), new TestCaseData(DateTime.MaxValue, "9999-12-31 23:59:59.999999+00", "infinity") - .SetName("MaxValue"), - }; + .SetName("MaxValue") + ]; static readonly TestCaseData[] TimestampTzDateTimeOffsetValues = - { + [ new TestCaseData(DateTimeOffset.MinValue.ToUniversalTime().AddYears(1), "0002-01-01 00:00:00+00", "0002-01-01 00:00:00+00") .SetName("MinValue_AddYear"), new TestCaseData(DateTimeOffset.MinValue, "0001-01-01 00:00:00+00", "-infinity") .SetName("MinValue"), new TestCaseData(DateTimeOffset.MaxValue, "9999-12-31 23:59:59.999999+00", "infinity") - .SetName("MaxValue"), - }; + .SetName("MaxValue") + ]; static readonly TestCaseData[] DateDateTimeValues = - { + [ new TestCaseData(DateTime.MinValue.AddYears(1), "0002-01-01", "0002-01-01") .SetName("MinValue_AddYear"), new TestCaseData(DateTime.MinValue, "0001-01-01", "-infinity") .SetName("MinValue"), new TestCaseData(DateTime.MaxValue, "9999-12-31", "infinity") - .SetName("MaxValue"), - }; + .SetName("MaxValue") + ]; // As we can't roundtrip DateTime.MaxValue due to precision differences with postgres we are lenient with equality for this particular value. 
static readonly Func MaxValuePrecisionLenientComparer = @@ -61,45 +60,46 @@ public sealed class DateTimeInfinityTests : TestBase, IDisposable [Test, TestCaseSource(nameof(TimestampDateTimeValues))] public Task Timestamp_DateTime(DateTime dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) => AssertType(dateTime, DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, - "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2, - comparer: MaxValuePrecisionLenientComparer, - isDefault: true); + "timestamp without time zone", + dbType: DbType.DateTime2, + comparer: MaxValuePrecisionLenientComparer); [Test, TestCaseSource(nameof(TimestampTzDateTimeValues))] public Task TimestampTz_DateTime(DateTime dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) - => AssertType(new(dateTime.Ticks, DateTimeKind.Utc), DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, - "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, DbType.DateTime, - comparer: MaxValuePrecisionLenientComparer, - isDefault: true, isNpgsqlDbTypeInferredFromClrType: false); + => AssertType(new DateTime(dateTime.Ticks, DateTimeKind.Utc), DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, + "timestamp with time zone", + dbType: DbType.DateTime, + comparer: MaxValuePrecisionLenientComparer); [Test, TestCaseSource(nameof(TimestampTzDateTimeOffsetValues))] public Task TimestampTz_DateTimeOffset(DateTimeOffset dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) => AssertType(dateTime, DisableDateTimeInfinityConversions ? 
sqlLiteral : infinityConvertedSqlLiteral, - "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, DbType.DateTime, + "timestamp with time zone", + dbType: DbType.DateTime, comparer: (expected, actual) => MaxValuePrecisionLenientComparer(expected.DateTime, actual.DateTime), - isDefault: false); + valueTypeEqualsFieldType: false); [Test, TestCaseSource(nameof(DateDateTimeValues))] public Task Date_DateTime(DateTime dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) => AssertType(DisableDateTimeInfinityConversions ? dateTime.Date : dateTime, DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, - "date", NpgsqlDbType.Date, DbType.Date, - isDefault: false); + "date", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Date, DbType.DateTime2), valueTypeEqualsFieldType: false); static readonly TestCaseData[] DateOnlyDateTimeValues = - { + [ new TestCaseData(DateOnly.MinValue.AddYears(1), "0002-01-01", "0002-01-01") .SetName("MinValue_AddYear"), new TestCaseData(DateOnly.MinValue, "0001-01-01", "-infinity") .SetName("MinValue"), new TestCaseData(DateOnly.MaxValue, "9999-12-31", "infinity") - .SetName("MaxValue"), - }; + .SetName("MaxValue") + ]; [Test, TestCaseSource(nameof(DateOnlyDateTimeValues))] public Task Date_DateOnly(DateOnly dateTime, string sqlLiteral, string infinityConvertedSqlLiteral) - => AssertType(dateTime, - DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, "date", NpgsqlDbType.Date, DbType.Date, - isDefault: false); + => AssertType(dateTime, DisableDateTimeInfinityConversions ? sqlLiteral : infinityConvertedSqlLiteral, + "date", + dbType: DbType.Date); NpgsqlDataSource? 
_dataSource; protected override NpgsqlDataSource DataSource => _dataSource ??= CreateDataSource(csb => csb.Timezone = "UTC"); diff --git a/test/Npgsql.Tests/Types/DateTimeTests.cs b/test/Npgsql.Tests/Types/DateTimeTests.cs index ec91148f2c..c698514e10 100644 --- a/test/Npgsql.Tests/Types/DateTimeTests.cs +++ b/test/Npgsql.Tests/Types/DateTimeTests.cs @@ -8,54 +8,32 @@ namespace Npgsql.Tests.Types; -// Since this test suite manipulates TimeZone, it is incompatible with multiplexing public class DateTimeTests : TestBase { #region Date + [Test] + public Task Date_as_DateOnly() + => AssertType(new DateOnly(2020, 10, 1), "2020-10-01", "date", dbType: DbType.Date); + [Test] public Task Date_as_DateTime() - => AssertType(new DateTime(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefaultForWriting: false); + => AssertType(new DateTime(2020, 10, 1), "2020-10-01", + "date", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Date, DbType.DateTime2), valueTypeEqualsFieldType: false); [Test] public Task Date_as_DateTime_with_date_and_time_before_2000() - => AssertTypeWrite(new DateTime(1980, 10, 1, 11, 0, 0), "1980-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); + => AssertTypeWrite(new DateTime(1980, 10, 1, 11, 0, 0), "1980-10-01", + "date", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Date, DbType.DateTime2)); // Internal PostgreSQL representation (days since 2020-01-01), for out-of-range values. 
[Test] public Task Date_as_int() - => AssertType(7579, "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefault: false); - - [Test] - public Task Daterange_as_NpgsqlRange_of_DateTime() - => AssertType( - new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), - "[2002-03-04,2002-03-06)", - "daterange", - NpgsqlDbType.DateRange, - isDefaultForWriting: false); - - [Test] - public async Task Datemultirange_as_array_of_NpgsqlRange_of_DateTime() - { - await using var conn = await OpenConnectionAsync(); - MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - - await AssertType( - new[] - { - new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), - new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) - }, - "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", - "datemultirange", - NpgsqlDbType.DateMultirange, - isDefaultForWriting: false); - } - - [Test] - public Task Date_as_DateOnly() - => AssertType(new DateOnly(2020, 10, 1), "2020-10-01", "date", NpgsqlDbType.Date, DbType.Date, isDefaultForReading: false); + => AssertType(7579, "2020-10-01", + "date", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Date, DbType.Int32), valueTypeEqualsFieldType: false); [Test] public Task Daterange_as_NpgsqlRange_of_DateOnly() @@ -63,8 +41,6 @@ public Task Daterange_as_NpgsqlRange_of_DateOnly() new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), "[2002-03-04,2002-03-06)", "daterange", - NpgsqlDbType.DateRange, - isDefaultForReading: false, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately [Test] @@ -76,9 +52,16 @@ public Task Daterange_array_as_NpgsqlRange_of_DateOnly_array() new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 9), false) }, """{"[2002-03-04,2002-03-06)","[2002-03-08,2002-03-09)"}""", - "daterange[]", - NpgsqlDbType.DateRange | NpgsqlDbType.Array, - isDefault: false); + "daterange[]", dataTypeInference: 
DataTypeInference.Mismatch, + valueTypeEqualsFieldType: false); + + [Test] + public Task Daterange_as_NpgsqlRange_of_DateTime() + => AssertType( + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + "[2002-03-04,2002-03-06)", + "daterange", dataTypeInference: DataTypeInference.Mismatch, + valueTypeEqualsFieldType: false); [Test] public async Task Datemultirange_as_array_of_NpgsqlRange_of_DateOnly() @@ -93,9 +76,24 @@ await AssertType( new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) }, "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", - "datemultirange", - NpgsqlDbType.DateMultirange, - isDefaultForReading: false); + "datemultirange"); + } + + [Test] + public async Task Datemultirange_as_array_of_NpgsqlRange_of_DateTime() + { + await using var conn = await OpenConnectionAsync(); + MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); + + await AssertType( + new[] + { + new NpgsqlRange(new(2002, 3, 4), true, new(2002, 3, 6), false), + new NpgsqlRange(new(2002, 3, 8), true, new(2002, 3, 11), false) + }, + "{[2002-03-04,2002-03-06),[2002-03-08,2002-03-11)}", + "datemultirange", dataTypeInference: DataTypeInference.Mismatch, + valueTypeEqualsFieldType: false); } #endregion @@ -103,31 +101,23 @@ await AssertType( #region Time [Test] - public Task Time_as_TimeSpan() - => AssertType( - new TimeSpan(0, 10, 45, 34, 500), - "10:45:34.5", + public Task Time_as_TimeOnly() + => AssertType(new TimeOnly(10, 45, 34, 500), "10:45:34.5", "time without time zone", - NpgsqlDbType.Time, - DbType.Time, - isDefaultForWriting: false); + dbType: DbType.Time); [Test] - public Task Time_as_TimeOnly() - => AssertType( - new TimeOnly(10, 45, 34, 500), - "10:45:34.5", - "time without time zone", - NpgsqlDbType.Time, - DbType.Time, - isDefaultForReading: false); + public Task Time_as_TimeSpan() + => AssertType(new TimeSpan(0, 10, 45, 34, 500), "10:45:34.5", + "time without time zone", dataTypeInference: DataTypeInference.Mismatch, + 
dbType: new(DbType.Time, DbType.Object), valueTypeEqualsFieldType: false); #endregion #region Time with timezone static readonly TestCaseData[] TimeTzValues = - { + [ new TestCaseData(new DateTimeOffset(1, 1, 2, 13, 3, 45, 510, TimeSpan.FromHours(2)), "13:03:45.51+02") .SetName("Timezone"), new TestCaseData(new DateTimeOffset(1, 1, 2, 1, 0, 45, 510, TimeSpan.FromHours(-3)), "01:00:45.51-03") @@ -135,37 +125,39 @@ public Task Time_as_TimeOnly() new TestCaseData(new DateTimeOffset(1212720130000, TimeSpan.Zero), "09:41:12.013+00") .SetName("Utc"), new TestCaseData(new DateTimeOffset(1, 1, 2, 1, 0, 0, new TimeSpan(0, 2, 0, 0)), "01:00:00+02") - .SetName("Before_utc_zero"), - }; + .SetName("Before_utc_zero") + ]; [Test, TestCaseSource(nameof(TimeTzValues))] public Task TimeTz_as_DateTimeOffset(DateTimeOffset time, string sqlLiteral) - => AssertType(time, sqlLiteral, "time with time zone", NpgsqlDbType.TimeTz, isDefault: false); + => AssertType(time, sqlLiteral, + "time with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.DateTime)); #endregion #region Timestamp static readonly TestCaseData[] TimestampValues = - { + [ new TestCaseData(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Unspecified), "1998-04-12 13:26:38") .SetName("Timestamp_pre2000"), new TestCaseData(new DateTime(2015, 1, 27, 8, 45, 12, 345, DateTimeKind.Unspecified), "2015-01-27 08:45:12.345") .SetName("Timestamp_post2000"), new TestCaseData(new DateTime(2013, 7, 25, 0, 0, 0, DateTimeKind.Unspecified), "2013-07-25 00:00:00") .SetName("Timestamp_date_only") - }; + ]; [Test, TestCaseSource(nameof(TimestampValues))] public async Task Timestamp_as_DateTime(DateTime dateTime, string sqlLiteral) { - await AssertType(dateTime, sqlLiteral, "timestamp without time zone", NpgsqlDbType.Timestamp, DbType.DateTime2, + await AssertType(dateTime, sqlLiteral, "timestamp without time zone", dbType: DbType.DateTime2, // Explicitly check kind as well. 
comparer: (actual, expected) => actual.Kind == expected.Kind && actual.Equals(expected)); await AssertType( - new List { dateTime, dateTime }, $$"""{"{{sqlLiteral}}","{{sqlLiteral}}"}""", "timestamp without time zone[]", NpgsqlDbType.Timestamp | NpgsqlDbType.Array, - isDefaultForReading: false); + new List { dateTime, dateTime }, $$"""{"{{sqlLiteral}}","{{sqlLiteral}}"}""", "timestamp without time zone[]", + valueTypeEqualsFieldType: false); } [Test] @@ -174,13 +166,9 @@ public Task Timestamp_cannot_write_utc_DateTime() [Test] public Task Timestamp_as_long() - => AssertType( - -54297202000000, - "1998-04-12 13:26:38", - "timestamp without time zone", - NpgsqlDbType.Timestamp, - DbType.DateTime2, - isDefault: false); + => AssertType(-54297202000000, "1998-04-12 13:26:38", + "timestamp without time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTime2, DbType.Int64), valueTypeEqualsFieldType: false); [Test] public Task Timestamp_cannot_use_as_DateTimeOffset() @@ -197,7 +185,6 @@ public Task Tsrange_as_NpgsqlRange_of_DateTime() new(1998, 4, 12, 15, 26, 38, DateTimeKind.Local)), @"[""1998-04-12 13:26:38"",""1998-04-12 15:26:38""]", "tsrange", - NpgsqlDbType.TimestampRange, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately [Test] @@ -214,8 +201,7 @@ public Task Tsrange_array_as_NpgsqlRange_of_DateTime_array() }, """{"[\"1998-04-12 13:26:38\",\"1998-04-12 15:26:38\"]","[\"1998-04-13 13:26:38\",\"1998-04-13 15:26:38\"]"}""", "tsrange[]", - NpgsqlDbType.TimestampRange | NpgsqlDbType.Array, - isDefault: false); + dataTypeInference: DataTypeInference.Mismatch); [Test] public async Task Tsmultirange_as_array_of_NpgsqlRange_of_DateTime() @@ -234,8 +220,7 @@ await AssertType( new(1998, 4, 13, 15, 26, 38, DateTimeKind.Local)), }, @"{[""1998-04-12 13:26:38"",""1998-04-12 15:26:38""],[""1998-04-13 13:26:38"",""1998-04-13 15:26:38""]}", - "tsmultirange", - NpgsqlDbType.TimestampMultirange); + 
"tsmultirange"); } #endregion @@ -245,35 +230,37 @@ await AssertType( // Note that the below text representations are local (according to TimeZone, which is set to Europe/Berlin in this test class), // because that's how PG does timestamptz *text* representation. static readonly TestCaseData[] TimestampTzWriteValues = - { + [ new TestCaseData(new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), "1998-04-12 15:26:38+02") .SetName("Timestamptz_write_pre2000"), new TestCaseData(new DateTime(2015, 1, 27, 8, 45, 12, 345, DateTimeKind.Utc), "2015-01-27 09:45:12.345+01") .SetName("Timestamptz_write_post2000"), new TestCaseData(new DateTime(2013, 7, 25, 0, 0, 0, DateTimeKind.Utc), "2013-07-25 02:00:00+02") .SetName("Timestamptz_write_date_only") - }; + ]; [Test, TestCaseSource(nameof(TimestampTzWriteValues))] public async Task Timestamptz_as_DateTime(DateTime dateTime, string sqlLiteral) { - await AssertType(dateTime, sqlLiteral, "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, + await AssertType(dateTime, sqlLiteral, "timestamp with time zone", dbType: DbType.DateTime, // Explicitly check kind as well. 
comparer: (actual, expected) => actual.Kind == expected.Kind && actual.Equals(expected)); await AssertType( - new List { dateTime, dateTime }, $$"""{"{{sqlLiteral}}","{{sqlLiteral}}"}""", "timestamp with time zone[]", NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, - isDefaultForReading: false); + new List { dateTime, dateTime }, $$"""{"{{sqlLiteral}}","{{sqlLiteral}}"}""", "timestamp with time zone[]", + valueTypeEqualsFieldType: false); } [Test] public async Task Timestamptz_infinity_as_DateTime() { - await AssertType(DateTime.MinValue, "-infinity", "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, - isDefault: false); - await AssertType(DateTime.MaxValue, "infinity", "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTime, - isDefault: false); + await AssertType(DateTime.MinValue, "-infinity", + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTime, DbType.DateTime2)); + await AssertType(DateTime.MaxValue, "infinity", + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTime, DbType.DateTime2)); } [Test] @@ -290,9 +277,8 @@ public async Task Timestamptz_as_DateTimeOffset_utc() new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), "1998-04-12 15:26:38+02", "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTime, - isDefaultForReading: false); + dbType: DbType.DateTime, + valueTypeEqualsFieldType: false); Assert.That(dateTimeOffset.Offset, Is.EqualTo(TimeSpan.Zero)); } @@ -302,11 +288,8 @@ public Task Timestamptz_as_DateTimeOffset_utc_with_DbType_DateTimeOffset() => AssertTypeWrite( new DateTimeOffset(1998, 4, 12, 13, 26, 38, TimeSpan.Zero), "1998-04-12 15:26:38+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTimeOffset, - inferredDbType: DbType.DateTime, - isDefault: false); + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTime, 
DbType.DateTime, DbType.DateTimeOffset)); [Test] public Task Timestamptz_cannot_write_non_utc_DateTimeOffset() @@ -314,13 +297,9 @@ public Task Timestamptz_cannot_write_non_utc_DateTimeOffset() [Test] public Task Timestamptz_as_long() - => AssertType( - -54297202000000, - "1998-04-12 15:26:38+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTime, - isDefault: false); + => AssertType(-54297202000000, "1998-04-12 15:26:38+02", + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTime, DbType.Int64), valueTypeEqualsFieldType: false); [Test] public async Task Timestamptz_array_as_DateTimeOffset_array() @@ -333,8 +312,7 @@ public async Task Timestamptz_array_as_DateTimeOffset_array() }, """{"1998-04-12 15:26:38+02","1999-04-12 15:26:38+02"}""", "timestamp with time zone[]", - NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, - isDefaultForReading: false); + valueTypeEqualsFieldType: false); Assert.That(dateTimeOffsets[0].Offset, Is.EqualTo(TimeSpan.Zero)); Assert.That(dateTimeOffsets[1].Offset, Is.EqualTo(TimeSpan.Zero)); @@ -348,7 +326,6 @@ public Task Tstzrange_as_NpgsqlRange_of_DateTime() new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), @"[""1998-04-12 15:26:38+02"",""1998-04-12 17:26:38+02""]", "tstzrange", - NpgsqlDbType.TimestampTzRange, skipArrayCheck: true); // NpgsqlRange[] is mapped to multirange by default, not array; test separately [Test] @@ -365,8 +342,7 @@ public Task Tstzrange_array_as_NpgsqlRange_of_DateTime_array() }, """{"[\"1998-04-12 15:26:38+02\",\"1998-04-12 17:26:38+02\"]","[\"1998-04-13 15:26:38+02\",\"1998-04-13 17:26:38+02\"]"}""", "tstzrange[]", - NpgsqlDbType.TimestampTzRange | NpgsqlDbType.Array, - isDefault: false); + dataTypeInference: DataTypeInference.Mismatch); [Test] public async Task Tstzmultirange_as_array_of_NpgsqlRange_of_DateTime() @@ -385,17 +361,15 @@ await AssertType( new DateTime(1998, 4, 13, 15, 26, 38, DateTimeKind.Utc)), }, @"{[""1998-04-12 
15:26:38+02"",""1998-04-12 17:26:38+02""],[""1998-04-13 15:26:38+02"",""1998-04-13 17:26:38+02""]}", - "tstzmultirange", - NpgsqlDbType.TimestampTzMultirange); + "tstzmultirange"); } [Test] public Task Cannot_mix_DateTime_Kinds_in_array() - => AssertTypeUnsupportedWrite(new[] - { + => AssertTypeUnsupportedWrite([ new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), - new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local), - }); + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Local) + ]); [Test] @@ -410,8 +384,7 @@ public async Task Cannot_mix_DateTime_Kinds_in_multirange() await using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "14.0", "Multirange types were introduced in PostgreSQL 14"); - await AssertTypeUnsupportedWrite[], ArgumentException>(new[] - { + await AssertTypeUnsupportedWrite[], ArgumentException>([ new NpgsqlRange( new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), @@ -438,8 +411,8 @@ await AssertTypeUnsupportedWrite[], ArgumentException>(new new DateTime(1998, 4, 12, 15, 26, 38, DateTimeKind.Utc)), new NpgsqlRange( new DateTime(1998, 4, 13, 13, 26, 38, DateTimeKind.Local), - new DateTime(1998, 4, 13, 15, 26, 38, DateTimeKind.Local)), - }); + new DateTime(1998, 4, 13, 15, 26, 38, DateTimeKind.Local)) + ]); } [Test] @@ -447,8 +420,8 @@ public void NpgsqlParameterDbType_is_value_dependent_datetime_or_datetime2() { var localtimestamp = new NpgsqlParameter { Value = DateTime.Now }; var unspecifiedtimestamp = new NpgsqlParameter { Value = new DateTime() }; - Assert.AreEqual(DbType.DateTime2, localtimestamp.DbType); - Assert.AreEqual(DbType.DateTime2, unspecifiedtimestamp.DbType); + Assert.That(localtimestamp.DbType, Is.EqualTo(DbType.DateTime2)); + Assert.That(unspecifiedtimestamp.DbType, Is.EqualTo(DbType.DateTime2)); // We don't support any DateTimeOffset other than offset 0 which maps to timestamptz, // we might add an exception for offset == 
DateTimeOffset.Now.Offset (local offset) mapping to timestamp at some point. @@ -457,8 +430,8 @@ public void NpgsqlParameterDbType_is_value_dependent_datetime_or_datetime2() var timestamptz = new NpgsqlParameter { Value = DateTime.UtcNow }; var dtotimestamptz = new NpgsqlParameter { Value = DateTimeOffset.UtcNow }; - Assert.AreEqual(DbType.DateTime, timestamptz.DbType); - Assert.AreEqual(DbType.DateTime, dtotimestamptz.DbType); + Assert.That(timestamptz.DbType, Is.EqualTo(DbType.DateTime)); + Assert.That(dtotimestamptz.DbType, Is.EqualTo(DbType.DateTime)); } [Test] @@ -466,34 +439,47 @@ public void NpgsqlParameterNpgsqlDbType_is_value_dependent_timestamp_or_timestam { var localtimestamp = new NpgsqlParameter { Value = DateTime.Now }; var unspecifiedtimestamp = new NpgsqlParameter { Value = new DateTime() }; - Assert.AreEqual(NpgsqlDbType.Timestamp, localtimestamp.NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.Timestamp, unspecifiedtimestamp.NpgsqlDbType); + Assert.That(localtimestamp.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp)); + Assert.That(unspecifiedtimestamp.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.Timestamp)); var timestamptz = new NpgsqlParameter { Value = DateTime.UtcNow }; var dtotimestamptz = new NpgsqlParameter { Value = DateTimeOffset.UtcNow }; - Assert.AreEqual(NpgsqlDbType.TimestampTz, timestamptz.NpgsqlDbType); - Assert.AreEqual(NpgsqlDbType.TimestampTz, dtotimestamptz.NpgsqlDbType); + Assert.That(timestamptz.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz)); + Assert.That(dtotimestamptz.NpgsqlDbType, Is.EqualTo(NpgsqlDbType.TimestampTz)); } [Test] public async Task Array_of_nullable_timestamptz() - => await AssertType( + { + await using var datasource = CreateDataSource(csb => + { + csb.ArrayNullabilityMode = ArrayNullabilityMode.PerInstance; + csb.Timezone = "Europe/Berlin"; + }); + await AssertType(datasource, new DateTime?[] { new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), null }, @"{""1998-04-12 15:26:38+02"",NULL}", - "timestamp 
with time zone[]", - NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, - isDefault: false); + "timestamp with time zone[]"); + + await AssertType(datasource, + new DateTime?[] + { + new DateTime(1998, 4, 12, 13, 26, 38, DateTimeKind.Utc), + }, + @"{""1998-04-12 15:26:38+02""}", + "timestamp with time zone[]", valueTypeEqualsFieldType: false); // we write DateTime?[], but will read DateTime[] from GetValue + } #endregion #region Interval static readonly TestCaseData[] IntervalValues = - { + [ new TestCaseData(new TimeSpan(0, 2, 3, 4, 5), "02:03:04.005") .SetName("Interval_time_only"), new TestCaseData(new TimeSpan(1, 2, 3, 4, 5), "1 day 02:03:04.005") @@ -502,27 +488,24 @@ public async Task Array_of_nullable_timestamptz() .SetName("Interval_with_many_days"), new TestCaseData(new TimeSpan(new TimeSpan(2, 3, 4).Ticks + 10), "02:03:04.000001") .SetName("Interval_with_microsecond") - }; + ]; [Test, TestCaseSource(nameof(IntervalValues))] public Task Interval_as_TimeSpan(TimeSpan timeSpan, string sqlLiteral) - => AssertType(timeSpan, sqlLiteral, "interval", NpgsqlDbType.Interval); + => AssertType(timeSpan, sqlLiteral, "interval"); [Test] public Task Interval_write_as_TimeSpan_truncates_ticks() => AssertTypeWrite( new TimeSpan(new TimeSpan(2, 3, 4).Ticks + 1), "02:03:04", - "interval", - NpgsqlDbType.Interval); + "interval"); [Test] public Task Interval_as_NpgsqlInterval() => AssertType( new NpgsqlInterval(2, 15, 7384005000), - "2 mons 15 days 02:03:04.005", "interval", - NpgsqlDbType.Interval, - isDefaultForReading: false); + "2 mons 15 days 02:03:04.005", "interval", valueTypeEqualsFieldType: false); [Test] public Task Interval_with_months_cannot_read_as_TimeSpan() diff --git a/test/Npgsql.Tests/Types/DomainTests.cs b/test/Npgsql.Tests/Types/DomainTests.cs index 4faaceb212..0905b7b4d6 100644 --- a/test/Npgsql.Tests/Types/DomainTests.cs +++ b/test/Npgsql.Tests/Types/DomainTests.cs @@ -1,18 +1,16 @@ using System; using System.Threading.Tasks; +using NpgsqlTypes; using 
NUnit.Framework; using static Npgsql.Tests.TestUtil; namespace Npgsql.Tests.Types; -public class DomainTests : MultiplexingTestBase +public class DomainTests : TestBase { [Test, Description("Resolves a domain type handler via the different pathways")] public async Task Domain_resolution() { - if (IsMultiplexing) - Assert.Ignore("Multiplexing, ReloadTypes"); - await using var dataSource = CreateDataSource(csb => csb.Pooling = false); await using var conn = await dataSource.OpenConnectionAsync(); var type = await GetTempTypeName(conn); @@ -75,5 +73,25 @@ class SomeComposite public string? Value { get; set; } } - public DomainTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + [Test] + public async Task Domain_over_range() + { + await using var adminConnection = await OpenConnectionAsync(); + var type = await GetTempTypeName(adminConnection); + var rangeType = await GetTempTypeName(adminConnection); + + await adminConnection.ExecuteNonQueryAsync($"CREATE DOMAIN {type} AS integer; CREATE TYPE {rangeType} AS RANGE(subtype={type})"); + + var dataSourceBuilder = CreateDataSourceBuilder(); + dataSourceBuilder.EnableUnmappedTypes(); + await using var dataSource = dataSourceBuilder.Build(); + await using var connection = await dataSource.OpenConnectionAsync(); + + await AssertType( + connection, + new NpgsqlRange(1, 2), + "[1,2]", + rangeType, + dataTypeInference: DataTypeInference.Mismatch); + } } diff --git a/test/Npgsql.Tests/Types/EnumTests.cs b/test/Npgsql.Tests/Types/EnumTests.cs index c36514d6d3..52f512c944 100644 --- a/test/Npgsql.Tests/Types/EnumTests.cs +++ b/test/Npgsql.Tests/Types/EnumTests.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Data; using System.Threading.Tasks; using Npgsql.NameTranslation; using Npgsql.PostgresTypes; @@ -10,7 +11,7 @@ namespace Npgsql.Tests.Types; -public class EnumTests : MultiplexingTestBase +public class EnumTests : TestBase { enum Mood { Sad, Ok, Happy } enum AnotherEnum { 
Value1, Value2 } @@ -26,7 +27,7 @@ public async Task Data_source_mapping() dataSourceBuilder.MapEnum(type); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, Mood.Happy, "happy", type, npgsqlDbType: null); + await AssertType(dataSource, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -42,8 +43,8 @@ public async Task Data_source_unmap() var isUnmapSuccessful = dataSourceBuilder.UnmapEnum(type); await using var dataSource = dataSourceBuilder.Build(); - Assert.IsTrue(isUnmapSuccessful); - Assert.ThrowsAsync(() => AssertType(dataSource, Mood.Happy, "happy", type, npgsqlDbType: null)); + Assert.That(isUnmapSuccessful); + Assert.ThrowsAsync(() => AssertType(dataSource, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing)); } [Test] @@ -56,7 +57,7 @@ public async Task Data_source_mapping_non_generic() var dataSourceBuilder = CreateDataSourceBuilder(); dataSourceBuilder.MapEnum(typeof(Mood), type); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, Mood.Happy, "happy", type, npgsqlDbType: null); + await AssertType(dataSource, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -72,8 +73,8 @@ public async Task Data_source_unmap_non_generic() var isUnmapSuccessful = dataSourceBuilder.UnmapEnum(typeof(Mood), type); await using var dataSource = dataSourceBuilder.Build(); - Assert.IsTrue(isUnmapSuccessful); - Assert.ThrowsAsync(() => AssertType(dataSource, Mood.Happy, "happy", type, npgsqlDbType: null)); + Assert.That(isUnmapSuccessful); + Assert.ThrowsAsync(() => AssertType(dataSource, Mood.Happy, "happy", type, dataTypeInference: DataTypeInference.Nothing)); } [Test] @@ -91,7 +92,7 @@ await adminConnection.ExecuteNonQueryAsync($@" dataSourceBuilder.MapEnum(type2); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, new[] { Mood.Ok, Mood.Sad }, "{ok,sad}", type1 + "[]", 
npgsqlDbType: null); + await AssertType(dataSource, new[] { Mood.Ok, Mood.Sad }, "{ok,sad}", type1 + "[]", dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -105,7 +106,7 @@ public async Task Array() dataSourceBuilder.MapEnum(type); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, new[] { Mood.Ok, Mood.Happy }, "{ok,happy}", type + "[]", npgsqlDbType: null); + await AssertType(dataSource, new[] { Mood.Ok, Mood.Happy }, "{ok,happy}", type + "[]", dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] @@ -119,9 +120,9 @@ public async Task Name_translation_default_snake_case() dataSourceBuilder.MapEnum(enumName1); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, NameTranslationEnum.Simple, "simple", enumName1, npgsqlDbType: null); - await AssertType(dataSource, NameTranslationEnum.TwoWords, "two_words", enumName1, npgsqlDbType: null); - await AssertType(dataSource, NameTranslationEnum.SomeClrName, "some_database_name", enumName1, npgsqlDbType: null); + await AssertType(dataSource, NameTranslationEnum.Simple, "simple", enumName1, dataTypeInference: DataTypeInference.Nothing); + await AssertType(dataSource, NameTranslationEnum.TwoWords, "two_words", enumName1, dataTypeInference: DataTypeInference.Nothing); + await AssertType(dataSource, NameTranslationEnum.SomeClrName, "some_database_name", enumName1, dataTypeInference: DataTypeInference.Nothing); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/859")] @@ -135,9 +136,9 @@ public async Task Name_translation_null() dataSourceBuilder.MapEnum(type, nameTranslator: new NpgsqlNullNameTranslator()); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, NameTranslationEnum.Simple, "Simple", type, npgsqlDbType: null); - await AssertType(dataSource, NameTranslationEnum.TwoWords, "TwoWords", type, npgsqlDbType: null); - await 
AssertType(dataSource, NameTranslationEnum.SomeClrName, "some_database_name", type, npgsqlDbType: null); + await AssertType(dataSource, NameTranslationEnum.Simple, "Simple", type, dataTypeInference: DataTypeInference.Nothing); + await AssertType(dataSource, NameTranslationEnum.TwoWords, "TwoWords", type, dataTypeInference: DataTypeInference.Nothing); + await AssertType(dataSource, NameTranslationEnum.SomeClrName, "some_database_name", type, dataTypeInference: DataTypeInference.Nothing); } [Test] @@ -152,8 +153,8 @@ await connection.ExecuteNonQueryAsync(@$" CREATE TYPE {type2} AS ENUM ('value1', 'value2');"); await connection.ReloadTypesAsync(); - await AssertType(connection, Mood.Happy, "happy", type1, npgsqlDbType: null, isDefault: false); - await AssertType(connection, AnotherEnum.Value2, "value2", type2, npgsqlDbType: null, isDefault: false); + await AssertType(connection, Mood.Happy, "happy", type1, dataTypeInference: DataTypeInference.Nothing, valueTypeEqualsFieldType: false); + await AssertType(connection, AnotherEnum.Value2, "value2", type2, dataTypeInference: DataTypeInference.Nothing, valueTypeEqualsFieldType: false); } [Test] @@ -170,11 +171,11 @@ public async Task Unmapped_enum_as_clr_enum_supported_only_with_EnableUnmappedTy nameof(NpgsqlDataSourceBuilder)); var exception = await AssertTypeUnsupportedWrite(Mood.Happy, enumType); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); exception = await AssertTypeUnsupportedRead("happy", enumType); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); } @@ -186,7 +187,9 @@ public async Task Unmapped_enum_as_string() await connection.ExecuteNonQueryAsync($"CREATE TYPE {type} AS ENUM ('sad', 'ok', 'happy')"); await connection.ReloadTypesAsync(); - await 
AssertType(connection, "happy", "happy", type, npgsqlDbType: null, isDefaultForWriting: false); + await AssertType(connection, "happy", "happy", type, + dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); } enum NameTranslationEnum @@ -212,8 +215,8 @@ await adminConnection.ExecuteNonQueryAsync($@" dataSourceBuilder.MapEnum($"{schema2}.my_enum"); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, Enum1.One, "one", $"{schema1}.my_enum", npgsqlDbType: null); - await AssertType(dataSource, Enum2.Alpha, "alpha", $"{schema2}.my_enum", npgsqlDbType: null); + await AssertType(dataSource, Enum1.One, "one", $"{schema1}.my_enum", dataTypeInference: DataTypeInference.Nothing); + await AssertType(dataSource, Enum2.Alpha, "alpha", $"{schema2}.my_enum", dataTypeInference: DataTypeInference.Nothing); } enum Enum1 { One } @@ -243,6 +246,4 @@ enum TestEnum [PgName("label3")] Label3 } - - public EnumTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/FullTextSearchTests.cs b/test/Npgsql.Tests/Types/FullTextSearchTests.cs index eda874b12a..079bb7dec5 100644 --- a/test/Npgsql.Tests/Types/FullTextSearchTests.cs +++ b/test/Npgsql.Tests/Types/FullTextSearchTests.cs @@ -9,46 +9,38 @@ namespace Npgsql.Tests.Types; -public class FullTextSearchTests : MultiplexingTestBase +public class FullTextSearchTests : TestBase { - public FullTextSearchTests(MultiplexingMode multiplexingMode) - : base(multiplexingMode) { } - [Test] public Task TsVector() => AssertType( NpgsqlTsVector.Parse("'1' '2' 'a':24,25A,26B,27,28,12345C 'b' 'c' 'd'"), "'1' '2' 'a':24,25A,26B,27,28,12345C 'b' 'c' 'd'", - "tsvector", - NpgsqlDbType.TsVector); + "tsvector"); public static IEnumerable TsQueryTestCases() => new[] { - new object[] - { + [ "'a'", new NpgsqlTsQueryLexeme("a") - }, - new object[] - { + ], + [ "!'a'", new NpgsqlTsQueryNot( new NpgsqlTsQueryLexeme("a")) - }, - new object[] - { + 
], + [ "'a' | 'b'", new NpgsqlTsQueryOr( new NpgsqlTsQueryLexeme("a"), new NpgsqlTsQueryLexeme("b")) - }, - new object[] - { + ], + [ "'a' & 'b'", new NpgsqlTsQueryAnd( new NpgsqlTsQueryLexeme("a"), new NpgsqlTsQueryLexeme("b")) - }, + ], new object[] { "'a' <-> 'b'", @@ -60,7 +52,7 @@ public static IEnumerable TsQueryTestCases() => new[] [Test] [TestCaseSource(nameof(TsQueryTestCases))] public Task TsQuery(string sqlLiteral, NpgsqlTsQuery query) - => AssertType(query, sqlLiteral, "tsquery", NpgsqlDbType.TsQuery); + => AssertType(query, sqlLiteral, "tsquery"); [Test] public async Task Full_text_search_not_supported_by_default_on_NpgsqlSlimSourceBuilder() @@ -74,20 +66,20 @@ public async Task Full_text_search_not_supported_by_default_on_NpgsqlSlimSourceB await using var dataSource = dataSourceBuilder.Build(); var exception = await AssertTypeUnsupportedRead("a", "tsquery", dataSource); - Assert.IsInstanceOf(exception.InnerException); - Assert.AreEqual(errorMessage, exception.InnerException!.Message); + Assert.That(exception.InnerException, Is.InstanceOf()); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); - exception = await AssertTypeUnsupportedWrite(new NpgsqlTsQueryLexeme("a"), pgTypeName: null, dataSource); - Assert.IsInstanceOf(exception.InnerException); - Assert.AreEqual(errorMessage, exception.InnerException!.Message); + exception = await AssertTypeUnsupportedWrite(new NpgsqlTsQueryLexeme("a"), dataTypeName: null, dataSource); + Assert.That(exception.InnerException, Is.InstanceOf()); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); exception = await AssertTypeUnsupportedRead("1", "tsvector", dataSource); - Assert.IsInstanceOf(exception.InnerException); - Assert.AreEqual(errorMessage, exception.InnerException!.Message); + Assert.That(exception.InnerException, Is.InstanceOf()); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); - exception = await 
AssertTypeUnsupportedWrite(NpgsqlTsVector.Parse("'1'"), pgTypeName: null, dataSource); - Assert.IsInstanceOf(exception.InnerException); - Assert.AreEqual(errorMessage, exception.InnerException!.Message); + exception = await AssertTypeUnsupportedWrite(NpgsqlTsVector.Parse("'1'"), dataTypeName: null, dataSource); + Assert.That(exception.InnerException, Is.InstanceOf()); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); } [Test] @@ -97,7 +89,7 @@ public async Task NpgsqlSlimSourceBuilder_EnableFullTextSearch() dataSourceBuilder.EnableFullTextSearch(); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(new NpgsqlTsQueryLexeme("a"), "'a'", "tsquery", NpgsqlDbType.TsQuery); - await AssertType(NpgsqlTsVector.Parse("'1'"), "'1'", "tsvector", NpgsqlDbType.TsVector); + await AssertType(new NpgsqlTsQueryLexeme("a"), "'a'", "tsquery"); + await AssertType(NpgsqlTsVector.Parse("'1'"), "'1'", "tsvector"); } } diff --git a/test/Npgsql.Tests/Types/GeometricTypeTests.cs b/test/Npgsql.Tests/Types/GeometricTypeTests.cs index c4d8d53b0e..20a7606e04 100644 --- a/test/Npgsql.Tests/Types/GeometricTypeTests.cs +++ b/test/Npgsql.Tests/Types/GeometricTypeTests.cs @@ -10,19 +10,19 @@ namespace Npgsql.Tests.Types; /// /// https://www.postgresql.org/docs/current/static/datatype-geometric.html /// -class GeometricTypeTests : MultiplexingTestBase +class GeometricTypeTests : TestBase { [Test] public Task Point() - => AssertType(new NpgsqlPoint(1.2, 3.4), "(1.2,3.4)", "point", NpgsqlDbType.Point); + => AssertType(new NpgsqlPoint(1.2, 3.4), "(1.2,3.4)", "point"); [Test] public Task Line() - => AssertType(new NpgsqlLine(1, 2, 3), "{1,2,3}", "line", NpgsqlDbType.Line); + => AssertType(new NpgsqlLine(1, 2, 3), "{1,2,3}", "line"); [Test] public Task LineSegment() - => AssertType(new NpgsqlLSeg(1, 2, 3, 4), "[(1,2),(3,4)]", "lseg", NpgsqlDbType.LSeg); + => AssertType(new NpgsqlLSeg(1, 2, 3, 4), "[(1,2),(3,4)]", "lseg"); [Test] public async Task Box() @@ 
-31,21 +31,18 @@ await AssertType( new NpgsqlBox(top: 3, right: 4, bottom: 1, left: 2), "(4,3),(2,1)", "box", - NpgsqlDbType.Box, skipArrayCheck: true); // Uses semicolon instead of comma as separator await AssertType( new NpgsqlBox(top: -10, right: 0, bottom: -20, left: -10), "(0,-10),(-10,-20)", "box", - NpgsqlDbType.Box, skipArrayCheck: true); // Uses semicolon instead of comma as separator await AssertType( new NpgsqlBox(top: 1, right: 2, bottom: 3, left: 4), "(4,3),(2,1)", "box", - NpgsqlDbType.Box, skipArrayCheck: true); // Uses semicolon instead of comma as separator var swapped = new NpgsqlBox(top: -20, right: -10, bottom: -10, left: 0); @@ -54,21 +51,18 @@ await AssertType( swapped, "(0,-10),(-10,-20)", "box", - NpgsqlDbType.Box, skipArrayCheck: true); // Uses semicolon instead of comma as separator await AssertType( swapped with { UpperRight = new NpgsqlPoint(-20,-10) }, "(-10,-10),(-20,-20)", "box", - NpgsqlDbType.Box, skipArrayCheck: true); // Uses semicolon instead of comma as separator await AssertType( swapped with { LowerLeft = new NpgsqlPoint(10, 10) }, "(10,10),(0,-10)", "box", - NpgsqlDbType.Box, skipArrayCheck: true); // Uses semicolon instead of comma as separator } @@ -85,9 +79,7 @@ public async Task Box_array() await AssertType( data, "{(4,3),(2,1);(6,5),(4,3);(0,-10),(-10,-20)}", - "box[]", - NpgsqlDbType.Box | NpgsqlDbType.Array - ); + "box[]"); var swappedData = new[] { @@ -99,42 +91,34 @@ await AssertType( await AssertType( swappedData, "{(4,3),(2,1);(6,5),(4,3);(0,-10),(-10,-20)}", - "box[]", - NpgsqlDbType.Box | NpgsqlDbType.Array - ); + "box[]"); } [Test] public Task Path_closed() => AssertType( - new NpgsqlPath(new[] { new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4) }, false), + new NpgsqlPath([new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4)], false), "((1,2),(3,4))", - "path", - NpgsqlDbType.Path); + "path"); [Test] public Task Path_open() => AssertType( - new NpgsqlPath(new[] { new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4) }, true), + new 
NpgsqlPath([new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4)], true), "[(1,2),(3,4)]", - "path", - NpgsqlDbType.Path); + "path"); [Test] public Task Polygon() => AssertType( new NpgsqlPolygon(new NpgsqlPoint(1, 2), new NpgsqlPoint(3, 4)), "((1,2),(3,4))", - "polygon", - NpgsqlDbType.Polygon); + "polygon"); [Test] public Task Circle() => AssertType( new NpgsqlCircle(1, 2, 0.5), "<(1,2),0.5>", - "circle", - NpgsqlDbType.Circle); - - public GeometricTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + "circle"); } diff --git a/test/Npgsql.Tests/Types/HstoreTests.cs b/test/Npgsql.Tests/Types/HstoreTests.cs index 5696cad98b..2d42be4448 100644 --- a/test/Npgsql.Tests/Types/HstoreTests.cs +++ b/test/Npgsql.Tests/Types/HstoreTests.cs @@ -1,12 +1,11 @@ using System.Collections.Generic; using System.Collections.Immutable; using System.Threading.Tasks; -using NpgsqlTypes; using NUnit.Framework; namespace Npgsql.Tests.Types; -public class HstoreTests : MultiplexingTestBase +public class HstoreTests : TestBase { [Test] public Task Hstore() @@ -18,12 +17,11 @@ public Task Hstore() {"cd", "hello"} }, @"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", - "hstore", - NpgsqlDbType.Hstore, isNpgsqlDbTypeInferredFromClrType: false); + "hstore", dataTypeInference: DataTypeInference.Nothing); [Test] public Task Hstore_empty() - => AssertType(new Dictionary(), @"", "hstore", NpgsqlDbType.Hstore, isNpgsqlDbTypeInferredFromClrType: false); + => AssertType(new Dictionary(), @"", "hstore", dataTypeInference: DataTypeInference.Nothing); [Test] public Task Hstore_as_ImmutableDictionary() @@ -38,8 +36,7 @@ public Task Hstore_as_ImmutableDictionary() immutableDictionary, @"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", "hstore", - NpgsqlDbType.Hstore, - isDefaultForReading: false, isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: DataTypeInference.Nothing, valueTypeEqualsFieldType: false); } [Test] @@ -53,8 +50,7 @@ public Task Hstore_as_IDictionary() }, 
@"""a""=>""3"", ""b""=>NULL, ""cd""=>""hello""", "hstore", - NpgsqlDbType.Hstore, - isDefaultForReading: false, isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: DataTypeInference.Nothing, valueTypeEqualsFieldType: false); [OneTimeSetUp] public async Task SetUp() @@ -63,6 +59,4 @@ public async Task SetUp() TestUtil.MinimumPgVersion(conn, "9.1", "Hstore introduced in PostgreSQL 9.1"); await TestUtil.EnsureExtensionAsync(conn, "hstore", "9.1"); } - - public HstoreTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/InternalTypeTests.cs b/test/Npgsql.Tests/Types/InternalTypeTests.cs index a5d69664a4..21b9e8c24f 100644 --- a/test/Npgsql.Tests/Types/InternalTypeTests.cs +++ b/test/Npgsql.Tests/Types/InternalTypeTests.cs @@ -4,7 +4,7 @@ namespace Npgsql.Tests.Types; -public class InternalTypeTests : MultiplexingTestBase +public class InternalTypeTests : TestBase { [Test] public async Task Read_internal_char() @@ -52,21 +52,22 @@ public async Task Tid() cmd.Parameters.AddWithValue("p", NpgsqlDbType.Tid, expected); using var reader = await cmd.ExecuteReaderAsync(); reader.Read(); - Assert.AreEqual(1234, reader.GetFieldValue(0).BlockNumber); - Assert.AreEqual(40000, reader.GetFieldValue(0).OffsetNumber); - Assert.AreEqual(expected.BlockNumber, reader.GetFieldValue(1).BlockNumber); - Assert.AreEqual(expected.OffsetNumber, reader.GetFieldValue(1).OffsetNumber); + Assert.That(reader.GetFieldValue(0).BlockNumber, Is.EqualTo(1234)); + Assert.That(reader.GetFieldValue(0).OffsetNumber, Is.EqualTo(40000)); + Assert.That(reader.GetFieldValue(1).BlockNumber, Is.EqualTo(expected.BlockNumber)); + Assert.That(reader.GetFieldValue(1).OffsetNumber, Is.EqualTo(expected.OffsetNumber)); } #region NpgsqlLogSequenceNumber / PgLsn - static readonly TestCaseData[] EqualsObjectCases = { + static readonly TestCaseData[] EqualsObjectCases = + [ new TestCaseData(new NpgsqlLogSequenceNumber(1ul), null).Returns(false), new 
TestCaseData(new NpgsqlLogSequenceNumber(1ul), new object()).Returns(false), new TestCaseData(new NpgsqlLogSequenceNumber(1ul), 1ul).Returns(false), // no implicit cast new TestCaseData(new NpgsqlLogSequenceNumber(1ul), "0/0").Returns(false), // no implicit cast/parsing - new TestCaseData(new NpgsqlLogSequenceNumber(1ul), new NpgsqlLogSequenceNumber(1ul)).Returns(true), - }; + new TestCaseData(new NpgsqlLogSequenceNumber(1ul), new NpgsqlLogSequenceNumber(1ul)).Returns(true) + ]; [Test, TestCaseSource(nameof(EqualsObjectCases))] public bool NpgsqlLogSequenceNumber_equals(NpgsqlLogSequenceNumber lsn, object? obj) @@ -77,7 +78,7 @@ public bool NpgsqlLogSequenceNumber_equals(NpgsqlLogSequenceNumber lsn, object? public async Task NpgsqlLogSequenceNumber() { var expected1 = new NpgsqlLogSequenceNumber(42949672971ul); - Assert.AreEqual(expected1, NpgsqlTypes.NpgsqlLogSequenceNumber.Parse("A/B")); + Assert.That(NpgsqlTypes.NpgsqlLogSequenceNumber.Parse("A/B"), Is.EqualTo(expected1)); await using var conn = await OpenConnectionAsync(); using var cmd = conn.CreateCommand(); cmd.CommandText = "SELECT 'A/B'::pg_lsn, @p::pg_lsn"; @@ -86,15 +87,13 @@ public async Task NpgsqlLogSequenceNumber() reader.Read(); var result1 = reader.GetFieldValue(0); var result2 = reader.GetFieldValue(1); - Assert.AreEqual(expected1, result1); - Assert.AreEqual(42949672971ul, (ulong)result1); - Assert.AreEqual("A/B", result1.ToString()); - Assert.AreEqual(expected1, result2); - Assert.AreEqual(42949672971ul, (ulong)result2); - Assert.AreEqual("A/B", result2.ToString()); + Assert.That(result1, Is.EqualTo(expected1)); + Assert.That((ulong)result1, Is.EqualTo(42949672971ul)); + Assert.That(result1.ToString(), Is.EqualTo("A/B")); + Assert.That(result2, Is.EqualTo(expected1)); + Assert.That((ulong)result2, Is.EqualTo(42949672971ul)); + Assert.That(result2.ToString(), Is.EqualTo("A/B")); } #endregion NpgsqlLogSequenceNumber / PgLsn - - public InternalTypeTests(MultiplexingMode multiplexingMode) : 
base(multiplexingMode) {} -} \ No newline at end of file +} diff --git a/test/Npgsql.Tests/Types/JsonDynamicTests.cs b/test/Npgsql.Tests/Types/JsonDynamicTests.cs index 3c948a4816..a3e68838ac 100644 --- a/test/Npgsql.Tests/Types/JsonDynamicTests.cs +++ b/test/Npgsql.Tests/Types/JsonDynamicTests.cs @@ -1,78 +1,16 @@ using System; using System.Text.Json; -using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Threading.Tasks; using Npgsql.Properties; -using NpgsqlTypes; using NUnit.Framework; namespace Npgsql.Tests.Types; -[TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Json)] -[TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Jsonb)] -[TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Json)] -[TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Jsonb)] -public class JsonDynamicTests : MultiplexingTestBase +[TestFixture("json")] +[TestFixture("jsonb")] +public class JsonDynamicTests : TestBase { - [Test] - public Task Roundtrip_JsonObject() - => AssertType( - new JsonObject { ["Bar"] = 8 }, - IsJsonb ? """{"Bar": 8}""" : """{"Bar":8}""", - PostgresType, - NpgsqlDbType, - // By default we map JsonObject to jsonb - isDefaultForWriting: IsJsonb, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false, - comparer: (x, y) => x.ToString() == y.ToString()); - - [Test] - public Task Roundtrip_JsonArray() - => AssertType( - new JsonArray { 1, 2, 3 }, - IsJsonb ? 
"[1, 2, 3]" : "[1,2,3]", - PostgresType, - NpgsqlDbType, - // By default we map JsonArray to jsonb - isDefaultForWriting: IsJsonb, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false, - comparer: (x, y) => x.ToString() == y.ToString()); - - [Test] - [IssueLink("https://github.com/npgsql/npgsql/issues/4537")] - public async Task Write_jsonobject_array_without_npgsqldbtype() - { - // By default we map JsonObject to jsonb - if (!IsJsonb) - return; - - await using var conn = await OpenConnectionAsync(); - var tableName = await TestUtil.CreateTempTable(conn, "key SERIAL PRIMARY KEY, ingredients json[]"); - - await using var cmd = new NpgsqlCommand { Connection = conn }; - - var jsonObject1 = new JsonObject - { - { "name", "value1" }, - { "amount", 1 }, - { "unit", "ml" } - }; - - var jsonObject2 = new JsonObject - { - { "name", "value2" }, - { "amount", 2 }, - { "unit", "g" } - }; - - cmd.CommandText = $"INSERT INTO {tableName} (ingredients) VALUES (@p)"; - cmd.Parameters.Add(new("p", new[] { jsonObject1, jsonObject2 })); - await cmd.ExecuteNonQueryAsync(); - } - [Test] public async Task As_poco() => await AssertType( @@ -86,8 +24,8 @@ public async Task As_poco() ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10}""" : """{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", PostgresType, - NpgsqlDbType, - isDefault: false); + dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); [Test] public async Task As_poco_long() @@ -107,8 +45,8 @@ await AssertType( ? 
$$"""{"Date": "2019-09-01T00:00:00", "Summary": "{{bigString}}", "TemperatureC": 10}""" : $$"""{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"{{bigString}}"}""", PostgresType, - NpgsqlDbType, - isDefault: false); + dataTypeInference: DataTypeInference.Nothing, + valueTypeEqualsFieldType: false); } [Test] @@ -132,7 +70,7 @@ public async Task As_poco_supported_only_with_EnableDynamicJson() PostgresType, base.DataSource); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); exception = await AssertTypeUnsupportedRead( @@ -142,7 +80,7 @@ public async Task As_poco_supported_only_with_EnableDynamicJson() PostgresType, base.DataSource); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); } @@ -150,7 +88,7 @@ public async Task As_poco_supported_only_with_EnableDynamicJson() public async Task Poco_does_not_stomp_GetValue_string() { var dataSource = CreateDataSourceBuilder() - .EnableDynamicJson(new[] {typeof(WeatherForecast)}, new[] {typeof(WeatherForecast)}) + .EnableDynamicJson([typeof(WeatherForecast)], [typeof(WeatherForecast)]) .Build(); var sqlLiteral = IsJsonb @@ -184,8 +122,7 @@ await AssertTypeWrite( ? 
"""{"date": "2019-09-01T00:00:00", "summary": "Partly cloudy", "temperatureC": 10}""" : """{"date":"2019-09-01T00:00:00","temperatureC":10,"summary":"Partly cloudy"}""", PostgresType, - NpgsqlDbType, - isDefault: false); + dataTypeInference: DataTypeInference.Nothing); } [Test, Ignore("TODO We should not change the default type for json/jsonb, it makes little sense.")] @@ -193,9 +130,9 @@ public async Task Poco_default_mapping() { var dataSourceBuilder = CreateDataSourceBuilder(); if (IsJsonb) - dataSourceBuilder.EnableDynamicJson(jsonbClrTypes: new[] { typeof(WeatherForecast) }); + dataSourceBuilder.EnableDynamicJson(jsonbClrTypes: [typeof(WeatherForecast)]); else - dataSourceBuilder.EnableDynamicJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); + dataSourceBuilder.EnableDynamicJson(jsonClrTypes: [typeof(WeatherForecast)]); await using var dataSource = dataSourceBuilder.Build(); await AssertType( @@ -210,115 +147,210 @@ await AssertType( ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10}""" : """{"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", PostgresType, - NpgsqlDbType, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: DataTypeInference.Nothing); } + #region Polymorphic + [Test] public async Task Poco_polymorphic_mapping() { - // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. - // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. - if (IsJsonb) - return; + await using var dataSource = CreateDataSource(builder => + { + var types = new[] {typeof(WeatherForecast)}; + builder + .ConfigureJsonOptions(new() { AllowOutOfOrderMetadataProperties = true }) + .EnableDynamicJson(jsonClrTypes: IsJsonb ? [] : types, jsonbClrTypes: !IsJsonb ? 
[] : types); + }); - var dataSourceBuilder = CreateDataSourceBuilder(); - dataSourceBuilder.EnableDynamicJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); - await using var dataSource = dataSourceBuilder.Build(); + var value = new ExtendedDerivedWeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }; - await AssertType( - dataSource, - new ExtendedDerivedWeatherForecast() + // Note: we assert a specific string representation, though jsonb doesn't guarantee the property ordering; so the assert may break + // for jsonb if PostgreSQL changes its implementation. + var sql = + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "$type": "extended", "Summary": "Partly cloudy", "TemperatureC": 10, "TemperatureF": 49}""" + : """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + + await AssertTypeWrite(dataSource, value, sql, PostgresType, dataTypeInference: DataTypeInference.Nothing); + await AssertTypeRead(dataSource, sql, PostgresType, value, valueTypeEqualsFieldType: false); + } + + [Test] + public async Task Poco_polymorphic_mapping_read_parents() + { + await using var dataSource = CreateDataSource(builder => + { + var types = new[] {typeof(WeatherForecast)}; + builder + .ConfigureJsonOptions(new() { AllowOutOfOrderMetadataProperties = true }) + .EnableDynamicJson(jsonClrTypes: IsJsonb ? [] : types, jsonbClrTypes: !IsJsonb ? [] : types); + }); + + var value = new ExtendedDerivedWeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }; + + // Note: we assert a specific string representation, though jsonb doesn't guarantee the property ordering; so the assert may break + // for jsonb if PostgreSQL changes its implementation. + var sql = + IsJsonb + ? 
"""{"Date": "2019-09-01T00:00:00", "$type": "extended", "Summary": "Partly cloudy", "TemperatureC": 10, "TemperatureF": 49}""" + : """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + + await AssertTypeWrite(dataSource, value, sql, PostgresType, + dataTypeInference: DataTypeInference.Nothing); + + await AssertTypeRead(dataSource, sql, PostgresType, value, valueTypeEqualsFieldType: false); + await AssertTypeRead(dataSource, sql, PostgresType, + new DerivedWeatherForecast { Date = new DateTime(2019, 9, 1), Summary = "Partly cloudy", TemperatureC = 10 }, - """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", - PostgresType, - NpgsqlDbType, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + valueTypeEqualsFieldType: false); + await AssertTypeRead(dataSource, sql, PostgresType, value, valueTypeEqualsFieldType: false); } [Test] - public async Task Poco_polymorphic_mapping_read_parents() + public async Task Poco_exact_polymorphic_mapping() { - // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. - // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. - if (IsJsonb) - return; - - var dataSourceBuilder = CreateDataSourceBuilder(); - dataSourceBuilder.EnableDynamicJson(jsonClrTypes: new[] { typeof(WeatherForecast) }); - await using var dataSource = dataSourceBuilder.Build(); + await using var dataSource = CreateDataSource(builder => + { + var types = new[] {typeof(ExtendedDerivedWeatherForecast)}; + builder + .ConfigureJsonOptions(new() { AllowOutOfOrderMetadataProperties = true }) + .EnableDynamicJson(jsonClrTypes: IsJsonb ? [] : types, jsonbClrTypes: !IsJsonb ? 
[] : types); + }); - var value = new ExtendedDerivedWeatherForecast() + var value = new ExtendedDerivedWeatherForecast { Date = new DateTime(2019, 9, 1), Summary = "Partly cloudy", TemperatureC = 10 }; - var sql = """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + // Note: we assert a specific string representation, though jsonb doesn't guarantee the property ordering; so the assert may break + // for jsonb if PostgreSQL changes its implementation. + var sql = + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10, "TemperatureF": 49}""" + : """{"TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; - await AssertTypeWrite( - dataSource, - value, - sql, - PostgresType, - NpgsqlDbType, - isNpgsqlDbTypeInferredFromClrType: false); + await AssertTypeWrite(dataSource, value, sql, PostgresType, dataTypeInference: DataTypeInference.Nothing); + await AssertTypeRead(dataSource, sql, PostgresType, value, valueTypeEqualsFieldType: false); + } + + [Test] + public async Task Poco_unspecified_polymorphic_mapping() + { + await using var dataSource = CreateDataSource(builder => + { + builder + .ConfigureJsonOptions(new() { AllowOutOfOrderMetadataProperties = true }) + .EnableDynamicJson(); + }); - // GetFieldValue - await AssertTypeRead(dataSource, sql, PostgresType, value, - comparer: (_, actual) => actual.GetType() == typeof(ExtendedDerivedWeatherForecast), - isDefault: false); + var value = new ExtendedDerivedWeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }; - await AssertTypeRead(dataSource, sql, PostgresType, value, - comparer: (_, actual) => actual.GetType() == typeof(DerivedWeatherForecast), isDefault: false); + // Note: we assert a specific string representation, though jsonb doesn't guarantee the property ordering; so the assert may break + // for jsonb if 
PostgreSQL changes its implementation. + var sql = + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "$type": "extended", "Summary": "Partly cloudy", "TemperatureC": 10, "TemperatureF": 49}""" + : """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; - await AssertTypeRead(dataSource, sql, PostgresType, value, isDefault: false); - } + await AssertTypeWrite(dataSource, value, sql, PostgresType, dataTypeInference: DataTypeInference.Nothing); + // Reading as DerivedWeatherForecast should not cause us to get an instance of ExtendedDerivedWeatherForecast (as it doesn't define JsonDerivedType) + await AssertTypeRead(dataSource, sql, PostgresType, + new DerivedWeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + valueTypeEqualsFieldType: false); + await AssertTypeRead(dataSource, sql, PostgresType, value, valueTypeEqualsFieldType: false); + } [Test] - public async Task Poco_exact_polymorphic_mapping() + public async Task Poco_polymorphic_mapping_without_AllowOutOfOrderMetadataProperties() { - // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. - // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. - if (IsJsonb) - return; + await using var dataSource = CreateDataSource(builder => + { + var types = new[] {typeof(WeatherForecast)}; + builder + .ConfigureJsonOptions(new() { AllowOutOfOrderMetadataProperties = false }) + .EnableDynamicJson(jsonClrTypes: IsJsonb ? [] : types, jsonbClrTypes: !IsJsonb ? 
[] : types); + }); - var dataSourceBuilder = CreateDataSourceBuilder(); - dataSourceBuilder.EnableDynamicJson(jsonClrTypes: new[] { typeof(ExtendedDerivedWeatherForecast) }); - await using var dataSource = dataSourceBuilder.Build(); + var value = new ExtendedDerivedWeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }; - await AssertType( - dataSource, - new ExtendedDerivedWeatherForecast() + // Note: we assert a specific string representation, though jsonb doesn't guarantee the property ordering; so the assert may break + // for jsonb if PostgreSQL changes its implementation. + var sql = + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10, "TemperatureF": 49}""" + : """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + + await AssertTypeWrite(dataSource, value, sql, PostgresType, dataTypeInference: DataTypeInference.Nothing); + + // As we have disabled polymorphism for jsonb when AllowOutOfOrderMetadataProperties = false we should be able to read it as equalt to a WeatherForecast instance. 
+ if (IsJsonb) + await AssertTypeRead(dataSource, sql, PostgresType, + new WeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + valueTypeEqualsFieldType: false); + + // Reading as DerivedWeatherForecast should not cause us to get an instance of ExtendedDerivedWeatherForecast (as it doesn't define JsonDerivedType) + await AssertTypeRead(dataSource, sql, PostgresType, + new DerivedWeatherForecast { Date = new DateTime(2019, 9, 1), Summary = "Partly cloudy", TemperatureC = 10 }, - """{"TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}""", - PostgresType, - NpgsqlDbType, - isDefaultForReading: false, - isNpgsqlDbTypeInferredFromClrType: false); + valueTypeEqualsFieldType: false); + + // We won't get the original value back for jsonb as we can't support polymorphism without also enforcing AllowOutOfOrderMetadataProperties is true. + // If we output $type, jsonb won't have that at the start and STJ will throw due to it appearing later in the object. So it's disabled entirely. + if (!IsJsonb) + await AssertTypeRead(dataSource, sql, PostgresType, value, valueTypeEqualsFieldType: false); } [Test] - public async Task Poco_unspecified_polymorphic_mapping() + public async Task Poco_unspecified_polymorphic_mapping_without_AllowOutOfOrderMetadataProperties() { - // We don't yet support polymorphic deserialization for jsonb as $type does not come back as the first property. - // This could be fixed by detecting PolymorphicOptions types, always buffering their values and modifying the text. - // In this case we don't have any statically mapped base type to check its PolymorphicOptions on. - // Detecting whether the type could be polymorphic would require us to duplicate STJ's nearest polymorphic ancestor search. 
- if (IsJsonb) - return; + await using var dataSource = CreateDataSource(builder => + { + builder + .ConfigureJsonOptions(new() { AllowOutOfOrderMetadataProperties = false }) + .EnableDynamicJson(); + }); var value = new ExtendedDerivedWeatherForecast { @@ -327,22 +359,44 @@ public async Task Poco_unspecified_polymorphic_mapping() TemperatureC = 10 }; - var sql = """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; + // Note: we assert a specific string representation, though jsonb doesn't guarantee the property ordering; so the assert may break + // for jsonb if PostgreSQL changes its implementation. + var sql = + IsJsonb + ? """{"Date": "2019-09-01T00:00:00", "Summary": "Partly cloudy", "TemperatureC": 10, "TemperatureF": 49}""" + : """{"$type":"extended","TemperatureF":49,"Date":"2019-09-01T00:00:00","TemperatureC":10,"Summary":"Partly cloudy"}"""; - await AssertType( - value, - sql, - PostgresType, - NpgsqlDbType, - isDefault: false); + await AssertTypeWrite(dataSource, value, sql, PostgresType, dataTypeInference: DataTypeInference.Nothing); + + // As we have disabled polymorphism for jsonb when AllowOutOfOrderMetadataProperties = false we should be able to read it as equalt to a WeatherForecast instance. 
+ if (IsJsonb) + await AssertTypeRead(dataSource, sql, PostgresType, + new WeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + valueTypeEqualsFieldType: false); - await AssertTypeRead(DataSource, sql, PostgresType, value, - comparer: (_, actual) => actual.GetType() == typeof(DerivedWeatherForecast), isDefault: false); + // Reading as DerivedWeatherForecast should not cause us to get an instance of ExtendedDerivedWeatherForecast (as it doesn't define JsonDerivedType) + await AssertTypeRead(dataSource, sql, PostgresType, + new DerivedWeatherForecast + { + Date = new DateTime(2019, 9, 1), + Summary = "Partly cloudy", + TemperatureC = 10 + }, + valueTypeEqualsFieldType: false); - await AssertTypeRead(DataSource, sql, PostgresType, value, - comparer: (_, actual) => actual.GetType() == typeof(ExtendedDerivedWeatherForecast), isDefault: false); + // We won't get the original value back for jsonb as we can't support polymorphism without also enforcing AllowOutOfOrderMetadataProperties is true. + // If we output $type, jsonb won't have that at the start and STJ will throw due to it appearing later in the object. So it's disabled entirely. 
+ if (!IsJsonb) + await AssertTypeRead(dataSource, sql, PostgresType, value, valueTypeEqualsFieldType: false); } + // ReSharper disable UnusedAutoPropertyAccessor.Local + // ReSharper disable UnusedMember.Local [JsonDerivedType(typeof(ExtendedDerivedWeatherForecast), typeDiscriminator: "extended")] record WeatherForecast { @@ -351,30 +405,36 @@ record WeatherForecast public string Summary { get; set; } = ""; } - record DerivedWeatherForecast : WeatherForecast - { - } + record DerivedWeatherForecast : WeatherForecast; record ExtendedDerivedWeatherForecast : DerivedWeatherForecast { public int TemperatureF => 32 + (int)(TemperatureC / 0.5556); } + // ReSharper restore UnusedMember.Local + // ReSharper restore UnusedAutoPropertyAccessor.Local - public JsonDynamicTests(MultiplexingMode multiplexingMode, NpgsqlDbType npgsqlDbType) - : base(multiplexingMode) + #endregion Polymorphic + + public JsonDynamicTests(string dataTypeName) { DataSource = CreateDataSource(b => b.EnableDynamicJson()); - if (npgsqlDbType == NpgsqlDbType.Jsonb) + if (dataTypeName == "jsonb") using (var conn = OpenConnection()) TestUtil.MinimumPgVersion(conn, "9.4.0", "JSONB data type not yet introduced"); - NpgsqlDbType = npgsqlDbType; + PostgresType = dataTypeName; } protected override NpgsqlDataSource DataSource { get; } - bool IsJsonb => NpgsqlDbType == NpgsqlDbType.Jsonb; - string PostgresType => IsJsonb ? 
"jsonb" : "json"; - readonly NpgsqlDbType NpgsqlDbType; + [OneTimeTearDown] + protected void CleanUpDataSource() + { + DataSource.Dispose(); + } + + bool IsJsonb => PostgresType == "jsonb"; + string PostgresType { get; } } diff --git a/test/Npgsql.Tests/Types/JsonPathTests.cs b/test/Npgsql.Tests/Types/JsonPathTests.cs index de49a631e0..62db50032b 100644 --- a/test/Npgsql.Tests/Types/JsonPathTests.cs +++ b/test/Npgsql.Tests/Types/JsonPathTests.cs @@ -6,16 +6,13 @@ namespace Npgsql.Tests.Types; -public class JsonPathTests : MultiplexingTestBase +public class JsonPathTests : TestBase { - public JsonPathTests(MultiplexingMode multiplexingMode) - : base(multiplexingMode) { } - - static readonly object[] ReadWriteCases = new[] - { + static readonly object[] ReadWriteCases = + [ new object[] { "'$'", "$" }, - new object[] { "'$\"varname\"'", "$\"varname\"" }, - }; + new object[] { "'$\"varname\"'", "$\"varname\"" } + ]; [Test] [TestCase("$")] @@ -25,8 +22,8 @@ public async Task JsonPath(string jsonPath) using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "12.0", "The jsonpath type was introduced in PostgreSQL 12"); await AssertType( - jsonPath, jsonPath, "jsonpath", NpgsqlDbType.JsonPath, isDefaultForWriting: false, isNpgsqlDbTypeInferredFromClrType: false, - inferredDbType: DbType.Object); + jsonPath, jsonPath, "jsonpath", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); } [Test] @@ -54,6 +51,6 @@ public async Task Write(string query, string expected) using var cmd = new NpgsqlCommand($"SELECT 'Passed' WHERE @p::text = {query}::text", conn) { Parameters = { new NpgsqlParameter("p", NpgsqlDbType.JsonPath) { Value = expected } } }; using var rdr = await cmd.ExecuteReaderAsync(); - Assert.True(rdr.Read()); + Assert.That(rdr.Read()); } } diff --git a/test/Npgsql.Tests/Types/JsonTests.cs b/test/Npgsql.Tests/Types/JsonTests.cs index e7a9b4576e..5cf8504ac4 100644 --- a/test/Npgsql.Tests/Types/JsonTests.cs +++ 
b/test/Npgsql.Tests/Types/JsonTests.cs @@ -4,22 +4,20 @@ using System.Text; using System.Text.Json; using System.Text.Json.Nodes; -using System.Text.Json.Serialization; using System.Threading.Tasks; -using NpgsqlTypes; using NUnit.Framework; namespace Npgsql.Tests.Types; -[TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Json)] -[TestFixture(MultiplexingMode.NonMultiplexing, NpgsqlDbType.Jsonb)] -[TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Json)] -[TestFixture(MultiplexingMode.Multiplexing, NpgsqlDbType.Jsonb)] -public class JsonTests : MultiplexingTestBase +[TestFixture("json")] +[TestFixture("jsonb")] +public class JsonTests : TestBase { [Test] public async Task As_string() - => await AssertType("""{"K": "V"}""", """{"K": "V"}""", PostgresType, NpgsqlDbType, isDefaultForWriting: false); + => await AssertType("""{"K": "V"}""", """{"K": "V"}""", + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); [Test] public async Task As_string_long() @@ -32,7 +30,9 @@ public async Task As_string_long() .Append(@"""}") .ToString(); - await AssertType(value, value, PostgresType, NpgsqlDbType, isDefaultForWriting: false); + await AssertType(value, value, + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); } [Test] @@ -48,25 +48,33 @@ public async Task As_string_with_GetTextReader() [Test] public async Task As_char_array() - => await AssertType("""{"K": "V"}""".ToCharArray(), """{"K": "V"}""", PostgresType, NpgsqlDbType, isDefault: false); + => await AssertType("""{"K": "V"}""".ToCharArray(), """{"K": "V"}""", + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), valueTypeEqualsFieldType: false); [Test] public async Task As_bytes() - => await AssertType("""{"K": "V"}"""u8.ToArray(), """{"K": "V"}""", PostgresType, NpgsqlDbType, isDefault: false); + => await AssertType("""{"K": 
"V"}"""u8.ToArray(), """{"K": "V"}""", + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.Binary), valueTypeEqualsFieldType: false); [Test] public async Task Write_as_ReadOnlyMemory_of_byte() - => await AssertTypeWrite(new ReadOnlyMemory("""{"K": "V"}"""u8.ToArray()), """{"K": "V"}""", PostgresType, NpgsqlDbType, - isDefault: false); + => await AssertTypeWrite(new ReadOnlyMemory("""{"K": "V"}"""u8.ToArray()), """{"K": "V"}""", + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.Binary)); [Test] public async Task Write_as_ArraySegment_of_char() - => await AssertTypeWrite(new ArraySegment("""{"K": "V"}""".ToCharArray()), """{"K": "V"}""", PostgresType, NpgsqlDbType, - isDefault: false); + => await AssertTypeWrite(new ArraySegment("""{"K": "V"}""".ToCharArray()), """{"K": "V"}""", + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); [Test] public Task As_MemoryStream() - => AssertTypeWrite(() => new MemoryStream("""{"K": "V"}"""u8.ToArray()), """{"K": "V"}""", PostgresType, NpgsqlDbType, isDefault: false); + => AssertTypeWrite(() => new MemoryStream("""{"K": "V"}"""u8.ToArray()), """{"K": "V"}""", + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.Binary)); [Test] public async Task As_JsonDocument() @@ -74,9 +82,9 @@ public async Task As_JsonDocument() JsonDocument.Parse("""{"K": "V"}"""), IsJsonb ? 
"""{"K": "V"}""" : """{"K":"V"}""", PostgresType, - NpgsqlDbType, - isDefault: false, - comparer: (x, y) => x.RootElement.GetProperty("K").GetString() == y.RootElement.GetProperty("K").GetString()); + dataTypeInference: DataTypeInference.Mismatch, + comparer: (x, y) => x.RootElement.GetProperty("K").GetString() == y.RootElement.GetProperty("K").GetString(), + valueTypeEqualsFieldType: false); [Test, IssueLink("https://github.com/npgsql/npgsql/issues/5540")] public async Task As_JsonDocument_with_null_root() @@ -84,9 +92,9 @@ public async Task As_JsonDocument_with_null_root() JsonDocument.Parse("null"), "null", PostgresType, - NpgsqlDbType, - isDefault: false, + dataTypeInference: DataTypeInference.Mismatch, comparer: (x, y) => x.RootElement.ValueKind == y.RootElement.ValueKind, + valueTypeEqualsFieldType: false, skipArrayCheck: true); [Test] @@ -95,9 +103,9 @@ public async Task As_JsonElement_with_null_root() JsonDocument.Parse("null").RootElement, "null", PostgresType, - NpgsqlDbType, - isDefault: false, + dataTypeInference: DataTypeInference.Mismatch, comparer: (x, y) => x.ValueKind == y.ValueKind, + valueTypeEqualsFieldType: false, skipArrayCheck: true); [Test] @@ -117,30 +125,24 @@ public Task Roundtrip_string() => AssertType( @"{""p"": 1}", @"{""p"": 1}", - PostgresType, - NpgsqlDbType, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: false); + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), valueTypeEqualsFieldType: true); [Test] public Task Roundtrip_char_array() => AssertType( @"{""p"": 1}".ToCharArray(), @"{""p"": 1}", - PostgresType, - NpgsqlDbType, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: false); + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), valueTypeEqualsFieldType: false); [Test] public Task Roundtrip_byte_array() => AssertType( - Encoding.ASCII.GetBytes(@"{""p"": 1}"), + @"{""p"": 1}"u8.ToArray(), @"{""p"": 
1}", - PostgresType, - NpgsqlDbType, - isDefault: false, - isNpgsqlDbTypeInferredFromClrType: false); + PostgresType, dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.Binary), valueTypeEqualsFieldType: false); [Test] [IssueLink("https://github.com/npgsql/npgsql/issues/2811")] @@ -168,17 +170,69 @@ public async Task Can_read_two_json_documents() Assert.That(car.RootElement.GetProperty("key").GetString(), Is.EqualTo("foo")); } - public JsonTests(MultiplexingMode multiplexingMode, NpgsqlDbType npgsqlDbType) - : base(multiplexingMode) + [Test] + public Task Roundtrip_JsonObject() + => AssertType( + new JsonObject { ["Bar"] = 8 }, + IsJsonb ? """{"Bar": 8}""" : """{"Bar":8}""", + PostgresType, + // By default we map JsonObject to jsonb + dataTypeInference: IsJsonb ? DataTypeInference.Match : DataTypeInference.Mismatch, + valueTypeEqualsFieldType: false, + comparer: (x, y) => x.ToString() == y.ToString()); + + [Test] + public Task Roundtrip_JsonArray() + => AssertType( + new JsonArray { 1, 2, 3 }, + IsJsonb ? "[1, 2, 3]" : "[1,2,3]", + PostgresType, + // By default we map JsonArray to jsonb + dataTypeInference: IsJsonb ? 
DataTypeInference.Match : DataTypeInference.Mismatch, + valueTypeEqualsFieldType: false, + comparer: (x, y) => x.ToString() == y.ToString()); + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/4537")] + public async Task Write_jsonobject_array_without_npgsqldbtype() + { + // By default we map JsonObject to jsonb + if (!IsJsonb) + return; + + await using var conn = await OpenConnectionAsync(); + var tableName = await TestUtil.CreateTempTable(conn, "key SERIAL PRIMARY KEY, ingredients json[]"); + + await using var cmd = new NpgsqlCommand { Connection = conn }; + + var jsonObject1 = new JsonObject + { + { "name", "value1" }, + { "amount", 1 }, + { "unit", "ml" } + }; + + var jsonObject2 = new JsonObject + { + { "name", "value2" }, + { "amount", 2 }, + { "unit", "g" } + }; + + cmd.CommandText = $"INSERT INTO {tableName} (ingredients) VALUES (@p)"; + cmd.Parameters.Add(new("p", new[] { jsonObject1, jsonObject2 })); + await cmd.ExecuteNonQueryAsync(); + } + + public JsonTests(string dataTypeName) { - if (npgsqlDbType == NpgsqlDbType.Jsonb) + if (dataTypeName == "jsonb") using (var conn = OpenConnection()) TestUtil.MinimumPgVersion(conn, "9.4.0", "JSONB data type not yet introduced"); - NpgsqlDbType = npgsqlDbType; + PostgresType = dataTypeName; } - bool IsJsonb => NpgsqlDbType == NpgsqlDbType.Jsonb; - string PostgresType => IsJsonb ? 
"jsonb" : "json"; - readonly NpgsqlDbType NpgsqlDbType; + bool IsJsonb => PostgresType == "jsonb"; + string PostgresType { get; } } diff --git a/test/Npgsql.Tests/Types/LTreeTests.cs b/test/Npgsql.Tests/Types/LTreeTests.cs index f836b49ca0..c7498adf83 100644 --- a/test/Npgsql.Tests/Types/LTreeTests.cs +++ b/test/Npgsql.Tests/Types/LTreeTests.cs @@ -1,23 +1,30 @@ -using System.Threading.Tasks; +using System.Data; +using System.Threading.Tasks; using Npgsql.Properties; using NpgsqlTypes; using NUnit.Framework; namespace Npgsql.Tests.Types; -public class LTreeTests : MultiplexingTestBase +public class LTreeTests : TestBase { [Test] public Task LQuery() - => AssertType("Top.Science.*", "Top.Science.*", "lquery", NpgsqlDbType.LQuery, isDefaultForWriting: false); + => AssertType("Top.Science.*", "Top.Science.*", + "lquery", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); [Test] public Task LTree() - => AssertType("Top.Science.Astronomy", "Top.Science.Astronomy", "ltree", NpgsqlDbType.LTree, isDefaultForWriting: false); + => AssertType("Top.Science.Astronomy", "Top.Science.Astronomy", + "ltree", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); [Test] public Task LTxtQuery() - => AssertType("Science & Astronomy", "Science & Astronomy", "ltxtquery", NpgsqlDbType.LTxtQuery, isDefaultForWriting: false); + => AssertType("Science & Astronomy", "Science & Astronomy", + "ltxtquery", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String)); [Test] public async Task LTree_not_supported_by_default_on_NpgsqlSlimSourceBuilder() @@ -36,24 +43,18 @@ public async Task LTree_not_supported_by_default_on_NpgsqlSlimSourceBuilder() } [Test] - public async Task NpgsqlSlimSourceBuilder_EnableLTree() + public async Task NpgsqlSlimSourceBuilder_EnableLTree([Values] bool withArrays) { var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); 
dataSourceBuilder.EnableLTree(); + if (withArrays) + dataSourceBuilder.EnableArrays(); await using var dataSource = dataSourceBuilder.Build(); - await AssertType(dataSource, "Top.Science.Astronomy", "Top.Science.Astronomy", "ltree", NpgsqlDbType.LTree, isDefaultForWriting: false, skipArrayCheck: true); - } - - [Test] - public async Task NpgsqlSlimSourceBuilder_EnableArrays() - { - var dataSourceBuilder = new NpgsqlSlimDataSourceBuilder(ConnectionString); - dataSourceBuilder.EnableLTree(); - dataSourceBuilder.EnableArrays(); - await using var dataSource = dataSourceBuilder.Build(); - - await AssertType(dataSource, "Top.Science.Astronomy", "Top.Science.Astronomy", "ltree", NpgsqlDbType.LTree, isDefaultForWriting: false); + await AssertType(dataSource, "Top.Science.Astronomy", "Top.Science.Astronomy", + "ltree", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Object, DbType.String), + skipArrayCheck: !withArrays); } [OneTimeSetUp] @@ -63,6 +64,4 @@ public async Task SetUp() TestUtil.MinimumPgVersion(conn, "13.0"); await TestUtil.EnsureExtensionAsync(conn, "ltree"); } - - public LTreeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs b/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs index c500324986..730188330e 100644 --- a/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs +++ b/test/Npgsql.Tests/Types/LegacyDateTimeTests.cs @@ -2,13 +2,11 @@ using System.Data; using System.Threading.Tasks; using Npgsql.Internal.ResolverFactories; -using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Util.Statics; namespace Npgsql.Tests.Types; -// Since this test suite manipulates TimeZone, it is incompatible with multiplexing [NonParallelizable] public class LegacyDateTimeTests : TestBase { @@ -18,8 +16,7 @@ public Task Timestamp_with_all_DateTime_kinds([Values] DateTimeKind kind) new DateTime(1998, 4, 12, 13, 26, 38, 789, kind), "1998-04-12 13:26:38.789", "timestamp without time 
zone", - NpgsqlDbType.Timestamp, - DbType.DateTime); + dbType: DbType.DateTime); [Test] public async Task Timestamp_read_as_Unspecified_DateTime() @@ -32,8 +29,8 @@ public async Task Timestamp_read_as_Unspecified_DateTime() [Test] public async Task Timestamptz_negative_infinity() { - var dto = await AssertType(DateTimeOffset.MinValue, "-infinity", "timestamp with time zone", NpgsqlDbType.TimestampTz, - DbType.DateTimeOffset, isDefaultForReading: false); + var dto = await AssertType(DateTimeOffset.MinValue, "-infinity", "timestamp with time zone", + dbType: DbType.DateTimeOffset, valueTypeEqualsFieldType: false); Assert.That(dto.Offset, Is.EqualTo(TimeSpan.Zero)); } @@ -41,8 +38,8 @@ public async Task Timestamptz_negative_infinity() public async Task Timestamptz_infinity() { var dto = await AssertType( - DateTimeOffset.MaxValue, "infinity", "timestamp with time zone", NpgsqlDbType.TimestampTz, DbType.DateTimeOffset, - isDefaultForReading: false); + DateTimeOffset.MaxValue, "infinity", "timestamp with time zone", dbType: DbType.DateTimeOffset, + valueTypeEqualsFieldType: false); Assert.That(dto.Offset, Is.EqualTo(TimeSpan.Zero)); } @@ -51,12 +48,9 @@ public async Task Timestamptz_infinity() [TestCase(DateTimeKind.Unspecified, TestName = "Timestamptz_write_unspecified_DateTime_does_not_convert")] public Task Timestamptz_write_utc_DateTime_does_not_convert(DateTimeKind kind) => AssertTypeWrite( - new DateTime(1998, 4, 12, 13, 26, 38, 789, kind), - "1998-04-12 15:26:38.789+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTimeOffset, - isDefault: false); + new DateTime(1998, 4, 12, 13, 26, 38, 789, kind), "1998-04-12 15:26:38.789+02", + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTimeOffset, DbType.DateTime)); [Test] public Task Timestamptz_local_DateTime_converts() @@ -66,12 +60,9 @@ public Task Timestamptz_local_DateTime_converts() var dateTime = new DateTime(1998, 4, 12, 13, 26, 38, 
789, DateTimeKind.Utc).ToLocalTime(); return AssertType( - dateTime, - "1998-04-12 15:26:38.789+02", - "timestamp with time zone", - NpgsqlDbType.TimestampTz, - DbType.DateTimeOffset, - isDefaultForWriting: false); + dateTime, "1998-04-12 15:26:38.789+02", + "timestamp with time zone", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.DateTimeOffset, DbType.DateTime)); } NpgsqlDataSource _dataSource = null!; diff --git a/test/Npgsql.Tests/Types/MiscTypeTests.cs b/test/Npgsql.Tests/Types/MiscTypeTests.cs index d689a268ef..e2bd17cf29 100644 --- a/test/Npgsql.Tests/Types/MiscTypeTests.cs +++ b/test/Npgsql.Tests/Types/MiscTypeTests.cs @@ -9,16 +9,16 @@ namespace Npgsql.Tests.Types; /// /// Tests on PostgreSQL types which don't fit elsewhere /// -class MiscTypeTests : MultiplexingTestBase +class MiscTypeTests : TestBase { [Test] public async Task Boolean() { - await AssertType(true, "true", "boolean", NpgsqlDbType.Boolean, DbType.Boolean, skipArrayCheck: true); - await AssertType(false, "false", "boolean", NpgsqlDbType.Boolean, DbType.Boolean, skipArrayCheck: true); + await AssertType(true, "true", "boolean", dbType: DbType.Boolean, skipArrayCheck: true); + await AssertType(false, "false", "boolean", dbType: DbType.Boolean, skipArrayCheck: true); // The literal representations for bools inside array are different ({t,f} instead of true/false, so we check separately. 
- await AssertType(new[] { true, false }, "{t,f}", "boolean[]", NpgsqlDbType.Boolean | NpgsqlDbType.Array); + await AssertType(new[] { true, false }, "{t,f}", "boolean[]"); } [Test] @@ -26,7 +26,7 @@ public Task Uuid() => AssertType( new Guid("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", - "uuid", NpgsqlDbType.Uuid, DbType.Guid); + "uuid", dbType: DbType.Guid); [Test, Description("Makes sure that the PostgreSQL 'unknown' type (OID 705) is read properly")] public async Task Read_unknown() @@ -103,16 +103,28 @@ public async Task AllResultTypesAreUnknown() [Test, Description("Mixes and matches an unknown type with a known type")] public async Task UnknownResultTypeList() { - if (IsMultiplexing) - return; - await using var conn = await OpenConnectionAsync(); await using var cmd = new NpgsqlCommand("SELECT TRUE, 8", conn); - cmd.UnknownResultTypeList = new[] { true, false }; + cmd.UnknownResultTypeList = [true, false]; await using var reader = await cmd.ExecuteReaderAsync(); reader.Read(); + Assert.That(reader.GetFieldType(0), Is.EqualTo(typeof(string))); Assert.That(reader.GetString(0), Is.EqualTo("t")); + Assert.That(reader.GetValue(0), Is.EqualTo("t")); + Assert.That(reader.GetFieldValue(0), Is.EqualTo("t")); + + // Try some alternative text types + Assert.That(reader.GetFieldValue(0), Is.EqualTo("t")); + Assert.That(reader.GetFieldValue(0), Is.EqualTo("t")); + + // Try as async + Assert.That(await reader.GetFieldValueAsync(0), Is.EqualTo("t")); + Assert.That(await reader.GetFieldValueAsync(0), Is.EqualTo("t")); + Assert.That(await reader.GetFieldValueAsync(0), Is.EqualTo("t")); + Assert.That(await reader.GetFieldValueAsync(0), Is.EqualTo("t")); + + // Normal binary column Assert.That(reader.GetInt32(1), Is.EqualTo(8)); } @@ -157,25 +169,33 @@ public async Task Send_unknown() [Test] public async Task ObjectArray() { - await AssertTypeWrite(new object?[] { (short)4, null, (long)5, 6 }, "{4,NULL,5,6}", "integer[]", 
NpgsqlDbType.Integer | NpgsqlDbType.Array, isDefault: false); - await AssertTypeWrite(new object?[] { "text", null, DBNull.Value, "chars".ToCharArray(), 'c' }, "{text,NULL,NULL,chars,c}", "text[]", NpgsqlDbType.Text | NpgsqlDbType.Array, isDefault: false); + await AssertTypeWrite(new object?[] { (short)4, null, (long)5, 6 }, "{4,NULL,5,6}", + "integer[]", dataTypeInference: DataTypeInference.Nothing); + await AssertTypeWrite(new object?[] { "text", null, DBNull.Value, "chars".ToCharArray(), 'c' }, "{text,NULL,NULL,chars,c}", + "text[]", dataTypeInference: DataTypeInference.Nothing); await using var dataSource = CreateDataSource(b => b.ConnectionStringBuilder.Timezone = "Europe/Berlin"); - await AssertTypeWrite(dataSource, new object?[] { DateTime.UnixEpoch, null, DBNull.Value, DateTime.UnixEpoch.AddDays(1) }, "{\"1970-01-01 01:00:00+01\",NULL,NULL,\"1970-01-02 01:00:00+01\"}", "timestamp with time zone[]", NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, isDefault: false); + await AssertTypeWrite(dataSource, new object?[] { DateTime.UnixEpoch, null, DBNull.Value, DateTime.UnixEpoch.AddDays(1) }, + "{\"1970-01-01 01:00:00+01\",NULL,NULL,\"1970-01-02 01:00:00+01\"}", + "timestamp with time zone[]", dataTypeInference: DataTypeInference.Nothing); Assert.ThrowsAsync(() => AssertTypeWrite(dataSource, new object?[] { DateTime.Now, null, DBNull.Value, DateTime.UnixEpoch.AddDays(1) }, "{\"1970-01-01 01:00:00+01\",NULL,NULL,\"1970-01-02 01:00:00+01\"}", "timestamp with time zone[]", - NpgsqlDbType.TimestampTz | NpgsqlDbType.Array, isDefault: false)); + dataTypeInference: DataTypeInference.Nothing)); } [Test] public Task Int2Vector() - => AssertType(new short[] { 4, 5, 6 }, "4 5 6", "int2vector", NpgsqlDbType.Int2Vector, isDefault: false); + => AssertType(new short[] { 4, 5, 6 }, "4 5 6", + "int2vector", dataTypeInference: DataTypeInference.Mismatch, + // int2vector mappings require a data type name, so passing a value of type short[][] will result in no mapping. 
+ skipArrayCheck: true); [Test] public Task Oidvector() - => AssertType(new uint[] { 4, 5, 6 }, "4 5 6", "oidvector", NpgsqlDbType.Oidvector, isDefault: false); + => AssertType(new uint[] { 4, 5, 6 }, "4 5 6", + "oidvector", dataTypeInference: DataTypeInference.Nothing); [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1138")] public async Task Void() @@ -183,15 +203,4 @@ public async Task Void() await using var conn = await OpenConnectionAsync(); Assert.That(await conn.ExecuteScalarAsync("SELECT pg_sleep(0)"), Is.SameAs(null)); } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1364")] - public async Task Unsupported_DbType() - { - await using var conn = await OpenConnectionAsync(); - await using var cmd = new NpgsqlCommand("SELECT @p", conn); - Assert.That(() => cmd.Parameters.Add(new NpgsqlParameter("p", DbType.UInt32) { Value = 8u }), - Throws.Exception.TypeOf()); - } - - public MiscTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/MoneyTests.cs b/test/Npgsql.Tests/Types/MoneyTests.cs index 4c38f3d111..8f277a6e34 100644 --- a/test/Npgsql.Tests/Types/MoneyTests.cs +++ b/test/Npgsql.Tests/Types/MoneyTests.cs @@ -7,8 +7,8 @@ namespace Npgsql.Tests.Types; public class MoneyTests : TestBase { - static readonly object[] MoneyValues = new[] - { + static readonly object[] MoneyValues = + [ new object[] { "$1.22", 1.22M }, new object[] { "$1,000.22", 1000.22M }, new object[] { "$1,000,000.22", 1000000.22M }, @@ -18,8 +18,8 @@ public class MoneyTests : TestBase new object[] { "$92,233,720,368,547,758.07", +92233720368547758.07M }, new object[] { "-$92,233,720,368,547,758.08", -92233720368547758.08M }, - new object[] { "-$92,233,720,368,547,758.08", -92233720368547758.08M }, - }; + new object[] { "-$92,233,720,368,547,758.08", -92233720368547758.08M } + ]; [Test] [TestCaseSource(nameof(MoneyValues))] @@ -27,7 +27,9 @@ public async Task Money(string sqlLiteral, decimal money) { using var 
conn = await OpenConnectionAsync(); await conn.ExecuteNonQueryAsync("SET lc_monetary='C'"); - await AssertType(conn, money, sqlLiteral, "money", NpgsqlDbType.Money, DbType.Currency, isDefault: false); + await AssertType(conn, money, sqlLiteral, + "money", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Currency, DbType.Decimal)); } [Test] @@ -41,11 +43,11 @@ public async Task Non_decimal_types_are_not_supported() await AssertTypeUnsupportedRead("8", "money"); } - static readonly object[] WriteWithLargeScaleCases = new[] - { + static readonly object[] WriteWithLargeScaleCases = + [ new object[] { "0.004::money", 0.004M, 0.00M }, - new object[] { "0.005::money", 0.005M, 0.01M }, - }; + new object[] { "0.005::money", 0.005M, 0.01M } + ]; [Test] [TestCaseSource(nameof(WriteWithLargeScaleCases))] @@ -59,4 +61,4 @@ public async Task Write_with_large_scale(string query, decimal parameter, decima Assert.That(decimal.GetBits(rdr.GetFieldValue(0)), Is.EqualTo(decimal.GetBits(expected))); Assert.That(rdr.GetFieldValue(1)); } -} \ No newline at end of file +} diff --git a/test/Npgsql.Tests/Types/MultirangeTests.cs b/test/Npgsql.Tests/Types/MultirangeTests.cs index d01dc6e408..9bf53bf528 100644 --- a/test/Npgsql.Tests/Types/MultirangeTests.cs +++ b/test/Npgsql.Tests/Types/MultirangeTests.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Data; -using System.Linq; using System.Threading.Tasks; using Npgsql.Properties; using NpgsqlTypes; @@ -13,7 +12,7 @@ namespace Npgsql.Tests.Types; public class MultirangeTests : TestBase { static readonly TestCaseData[] MultirangeTestCases = - { + [ // int4multirange new TestCaseData( new NpgsqlRange[] @@ -21,7 +20,7 @@ public class MultirangeTests : TestBase new(3, true, false, 7, false, false), new(9, true, false, 0, false, true) }, - "{[3,7),[9,)}", "int4multirange", NpgsqlDbType.IntegerMultirange, true, true, default(NpgsqlRange)) + "{[3,7),[9,)}", "int4multirange", DataTypeInference.Match, 
true, default(NpgsqlRange)) .SetName("Int"), // int8multirange @@ -31,7 +30,7 @@ public class MultirangeTests : TestBase new(3, true, false, 7, false, false), new(9, true, false, 0, false, true) }, - "{[3,7),[9,)}", "int8multirange", NpgsqlDbType.BigIntMultirange, true, true, default(NpgsqlRange)) + "{[3,7),[9,)}", "int8multirange", DataTypeInference.Match, true, default(NpgsqlRange)) .SetName("Long"), // nummultirange @@ -42,17 +41,17 @@ public class MultirangeTests : TestBase new(3, true, false, 7, true, false), new(9, false, false, 0, false, true) }, - "{[3,7],(9,)}", "nummultirange", NpgsqlDbType.NumericMultirange, true, true, default(NpgsqlRange)) + "{[3,7],(9,)}", "nummultirange", DataTypeInference.Match, true, default(NpgsqlRange)) .SetName("Decimal"), // daterange new TestCaseData( - new NpgsqlRange[] + new NpgsqlRange[] { new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), new(new(2020, 1, 10), true, false, default, false, true) }, - "{[2020-01-01,2020-01-05),[2020-01-10,)}", "datemultirange", NpgsqlDbType.DateMultirange, true, false, default(NpgsqlRange)) + "{[2020-01-01,2020-01-05),[2020-01-10,)}", "datemultirange", DataTypeInference.Match, true, default(NpgsqlRange)) .SetName("DateTime DateMultirange"), // tsmultirange @@ -62,7 +61,7 @@ public class MultirangeTests : TestBase new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), new(new(2020, 1, 10), true, false, default, false, true) }, - """{["2020-01-01 00:00:00","2020-01-05 00:00:00"),["2020-01-10 00:00:00",)}""", "tsmultirange", NpgsqlDbType.TimestampMultirange, true, true, default(NpgsqlRange)) + """{["2020-01-01 00:00:00","2020-01-05 00:00:00"),["2020-01-10 00:00:00",)}""", "tsmultirange", DataTypeInference.Match, true, default(NpgsqlRange)) .SetName("DateTime TimestampMultirange"), // tstzmultirange @@ -72,7 +71,7 @@ public class MultirangeTests : TestBase new(new(2020, 1, 1, 0, 0, 0, kind: DateTimeKind.Utc), true, false, new(2020, 1, 5, 0, 0, 0, kind: 
DateTimeKind.Utc), false, false), new(new(2020, 1, 10, 0, 0, 0, kind: DateTimeKind.Utc), true, false, default, false, true) }, - """{["2020-01-01 01:00:00+01","2020-01-05 01:00:00+01"),["2020-01-10 01:00:00+01",)}""", "tstzmultirange", NpgsqlDbType.TimestampTzMultirange, true, true, default(NpgsqlRange)) + """{["2020-01-01 01:00:00+01","2020-01-05 01:00:00+01"),["2020-01-10 01:00:00+01",)}""", "tstzmultirange", DataTypeInference.Match, true, default(NpgsqlRange)) .SetName("DateTime TimestampTzMultirange"), new TestCaseData( @@ -81,26 +80,25 @@ public class MultirangeTests : TestBase new(new(2020, 1, 1), true, false, new(2020, 1, 5), false, false), new(new(2020, 1, 10), true, false, default, false, true) }, - "{[2020-01-01,2020-01-05),[2020-01-10,)}", "datemultirange", NpgsqlDbType.DateMultirange, false, false, default(NpgsqlRange)) - .SetName("DateOnly"), - }; + "{[2020-01-01,2020-01-05),[2020-01-10,)}", "datemultirange", DataTypeInference.Mismatch, true, default(NpgsqlRange)) + .SetName("DateOnly") + ]; [Test, TestCaseSource(nameof(MultirangeTestCases))] public Task Multirange_as_array( - T multirangeAsArray, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType, bool isDefaultForReading, bool isDefaultForWriting, TRange _) - => AssertType(multirangeAsArray, sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForReading: isDefaultForReading, - isDefaultForWriting: isDefaultForWriting); + T multirangeAsArray, string sqlLiteral, string dataTypeName, DataTypeInference datatypeDataTypeInference, bool valueTypeEqualsFieldType, TRange _) + => AssertType(multirangeAsArray, sqlLiteral, dataTypeName, + dataTypeInference: datatypeDataTypeInference, valueTypeEqualsFieldType: valueTypeEqualsFieldType); [Test, TestCaseSource(nameof(MultirangeTestCases))] public Task Multirange_as_list( - T multirangeAsArray, string sqlLiteral, string pgTypeName, NpgsqlDbType? 
npgsqlDbType, bool isDefaultForReading, bool isDefaultForWriting, TRange _) + T multirangeAsArray, string sqlLiteral, string dataTypeName, DataTypeInference datatypeDataTypeInference, bool valueTypeEqualsFieldType, TRange _) where T : IList => AssertType( - new List(multirangeAsArray), - sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForReading: false, isDefaultForWriting: isDefaultForWriting); + new List(multirangeAsArray), sqlLiteral, dataTypeName, + dataTypeInference: datatypeDataTypeInference, valueTypeEqualsFieldType: false); [Test] - [NonParallelizable] public async Task Unmapped_multirange_with_mapped_subtype() { await using var dataSource = CreateDataSource(b => b.EnableUnmappedTypes().ConnectionStringBuilder.MaxPoolSize = 1); @@ -108,7 +106,6 @@ public async Task Unmapped_multirange_with_mapped_subtype() var typeName = await GetTempTypeName(conn); await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS RANGE(subtype=text)"); - await Task.Yield(); // TODO: fix multiplexing deadlock bug conn.ReloadTypes(); Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); @@ -135,7 +132,6 @@ public async Task Unmapped_multirange_supported_only_with_EnableUnmappedTypes() var rangeType = await GetTempTypeName(connection); var multirangeTypeName = rangeType + "_multirange"; await connection.ExecuteNonQueryAsync($"CREATE TYPE {rangeType} AS RANGE(subtype=text)"); - await Task.Yield(); // TODO: fix multiplexing deadlock bug await connection.ReloadTypesAsync(); var errorMessage = string.Format( @@ -150,18 +146,18 @@ public async Task Unmapped_multirange_supported_only_with_EnableUnmappedTypes() new("moo", "zoo"), }, multirangeTypeName); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); - exception = await AssertTypeUnsupportedRead("""{["bar","foo"],["moo","zoo"]}""", + exception = await 
AssertTypeUnsupportedRead("""{["bar","foo"],["moo","zoo"]}""", multirangeTypeName); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); exception = await AssertTypeUnsupportedRead>( """{["bar","foo"],["moo","zoo"]}""", multirangeTypeName); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); } diff --git a/test/Npgsql.Tests/Types/NetworkTypeTests.cs b/test/Npgsql.Tests/Types/NetworkTypeTests.cs index f164b57d75..3ddc78e87c 100644 --- a/test/Npgsql.Tests/Types/NetworkTypeTests.cs +++ b/test/Npgsql.Tests/Types/NetworkTypeTests.cs @@ -13,11 +13,11 @@ namespace Npgsql.Tests.Types; /// /// https://www.postgresql.org/docs/current/static/datatype-net-types.html /// -class NetworkTypeTests : MultiplexingTestBase +class NetworkTypeTests : TestBase { [Test] public Task Inet_v4_as_IPAddress() - => AssertType(IPAddress.Parse("192.168.1.1"), "192.168.1.1/32", "inet", NpgsqlDbType.Inet, skipArrayCheck: true); + => AssertType(IPAddress.Parse("192.168.1.1"), "192.168.1.1/32", "inet", skipArrayCheck: true); [Test] public Task Inet_v4_array_as_IPAddress_array() @@ -27,7 +27,7 @@ public Task Inet_v4_array_as_IPAddress_array() IPAddress.Parse("192.168.1.1"), IPAddress.Parse("192.168.1.2") }, - "{192.168.1.1,192.168.1.2}", "inet[]", NpgsqlDbType.Inet | NpgsqlDbType.Array); + "{192.168.1.1,192.168.1.2}", "inet[]"); [Test] public Task Inet_v6_as_IPAddress() @@ -35,7 +35,6 @@ public Task Inet_v6_as_IPAddress() IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), "2001:1db8:85a3:1142:1000:8a2e:1370:7334/128", "inet", - NpgsqlDbType.Inet, skipArrayCheck: true); [Test] @@ -46,20 +45,28 @@ public Task Inet_v6_array_as_IPAddress_array() IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 
IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7335") }, - "{2001:1db8:85a3:1142:1000:8a2e:1370:7334,2001:1db8:85a3:1142:1000:8a2e:1370:7335}", "inet[]", NpgsqlDbType.Inet | NpgsqlDbType.Array); + "{2001:1db8:85a3:1142:1000:8a2e:1370:7334,2001:1db8:85a3:1142:1000:8a2e:1370:7335}", "inet[]"); [Test, IssueLink("https://github.com/dotnet/corefx/issues/33373")] public Task IPAddress_Any() - => AssertTypeWrite(IPAddress.Any, "0.0.0.0/32", "inet", NpgsqlDbType.Inet, skipArrayCheck: true); + => AssertTypeWrite(IPAddress.Any, "0.0.0.0/32", "inet", skipArrayCheck: true); [Test] - public Task Cidr() + public Task IPNetwork_as_cidr() + => AssertType( + new IPNetwork(IPAddress.Parse("192.168.1.0"), 24), + "192.168.1.0/24", + "cidr"); + +#pragma warning disable CS0618 // NpgsqlCidr is obsolete + [Test] + public Task NpgsqlCidr_as_Cidr() => AssertType( new NpgsqlCidr(IPAddress.Parse("192.168.1.0"), netmask: 24), "192.168.1.0/24", "cidr", - NpgsqlDbType.Cidr, - isDefaultForWriting: false); + valueTypeEqualsFieldType: false); +#pragma warning restore CS0618 [Test] public Task Inet_v4_as_NpgsqlInet() @@ -67,8 +74,7 @@ public Task Inet_v4_as_NpgsqlInet() new NpgsqlInet(IPAddress.Parse("192.168.1.1"), 24), "192.168.1.1/24", "inet", - NpgsqlDbType.Inet, - isDefaultForReading: false); + valueTypeEqualsFieldType: false); [Test] public Task Inet_v6_as_NpgsqlInet() @@ -76,12 +82,11 @@ public Task Inet_v6_as_NpgsqlInet() new NpgsqlInet(IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 24), "2001:1db8:85a3:1142:1000:8a2e:1370:7334/24", "inet", - NpgsqlDbType.Inet, - isDefaultForReading: false); + valueTypeEqualsFieldType: false); [Test] public Task Macaddr() - => AssertType(PhysicalAddress.Parse("08-00-2B-01-02-03"), "08:00:2b:01:02:03", "macaddr", NpgsqlDbType.MacAddr); + => AssertType(PhysicalAddress.Parse("08-00-2B-01-02-03"), "08:00:2b:01:02:03", "macaddr"); [Test] public async Task Macaddr8() @@ -90,8 +95,8 @@ public async Task Macaddr8() if (conn.PostgreSqlVersion < new 
Version(10, 0)) Assert.Ignore("macaddr8 only supported on PostgreSQL 10 and above"); - await AssertType(PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"), "08:00:2b:01:02:03:04:05", "macaddr8", NpgsqlDbType.MacAddr8, - isDefaultForWriting: false); + await AssertType(PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"), "08:00:2b:01:02:03:04:05", + "macaddr8", dataTypeInference: DataTypeInference.Mismatch); } [Test] @@ -101,8 +106,8 @@ public async Task Macaddr8_write_with_6_bytes() if (conn.PostgreSqlVersion < new Version(10, 0)) Assert.Ignore("macaddr8 only supported on PostgreSQL 10 and above"); - await AssertTypeWrite(PhysicalAddress.Parse("08-00-2B-01-02-03"), "08:00:2b:ff:fe:01:02:03", "macaddr8", NpgsqlDbType.MacAddr8, - isDefault: false); + await AssertTypeWrite(PhysicalAddress.Parse("08-00-2B-01-02-03"), "08:00:2b:ff:fe:01:02:03", + "macaddr8", dataTypeInference: DataTypeInference.Mismatch); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/835")] @@ -128,6 +133,4 @@ public async Task Macaddr_write_validation() await AssertTypeUnsupportedWrite(PhysicalAddress.Parse("08-00-2B-01-02-03-04-05"), "macaddr"); } - - public NetworkTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/NumericTests.cs b/test/Npgsql.Tests/Types/NumericTests.cs index 43dd846a8c..38b95cfc0e 100644 --- a/test/Npgsql.Tests/Types/NumericTests.cs +++ b/test/Npgsql.Tests/Types/NumericTests.cs @@ -3,15 +3,14 @@ using System.Linq; using System.Numerics; using System.Threading.Tasks; -using NpgsqlTypes; using NUnit.Framework; namespace Npgsql.Tests.Types; -public class NumericTests : MultiplexingTestBase +public class NumericTests : TestBase { - static readonly object[] ReadWriteCases = new[] - { + static readonly object[] ReadWriteCases = + [ new object[] { "0.0000000000000000000000000001::numeric", 0.0000000000000000000000000001M }, new object[] { "0.000000000000000000000001::numeric", 0.000000000000000000000001M }, new object[] 
{ "0.00000000000000000001::numeric", 0.00000000000000000001M }, @@ -76,14 +75,16 @@ public class NumericTests : MultiplexingTestBase // Bug 2033 new object[] { "0.0036882500000000000000000000", 0.0036882500000000000000000000M }, + // Bug 5848 + new object[] { "10836968.715000000000000000000000", 10836968.715000000000000000000000M }, new object[] { "936490726837837729197", 936490726837837729197M }, new object[] { "9364907268378377291970000", 9364907268378377291970000M }, new object[] { "3649072683783772919700000000", 3649072683783772919700000000M }, new object[] { "1234567844445555.000000000", 1234567844445555.000000000M }, new object[] { "11112222000000000000", 11112222000000000000M }, - new object[] { "0::numeric", 0M }, - }; + new object[] { "0::numeric", 0M } + ]; [Test] [TestCaseSource(nameof(ReadWriteCases))] @@ -112,15 +113,21 @@ public async Task Write(string query, decimal expected) [Test] public async Task Numeric() { - await AssertType(5.5m, "5.5", "numeric", NpgsqlDbType.Numeric, DbType.Decimal); - await AssertTypeWrite(5.5m, "5.5", "numeric", NpgsqlDbType.Numeric, DbType.VarNumeric, inferredDbType: DbType.Decimal); - - await AssertType((short)8, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); - await AssertType(8, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); - await AssertType((byte)8, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); - await AssertType(8F, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); - await AssertType(8D, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); - await AssertType(8M, "8", "numeric", NpgsqlDbType.Numeric, DbType.Decimal, isDefault: false); + await AssertType(5.5m, "5.5", "numeric", dbType: DbType.Decimal); + await AssertTypeWrite(5.5m, "5.5", "numeric", dbType: new(DbType.Decimal, DbType.Decimal, DbType.VarNumeric)); + + await AssertType((short)8, "8", "numeric", dataTypeInference: 
DataTypeInference.Mismatch, + dbType: new(DbType.Decimal, DbType.Int16), valueTypeEqualsFieldType: false); + await AssertType(8, "8", "numeric", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Decimal, DbType.Int32), valueTypeEqualsFieldType: false); + await AssertType(8L, "8", "numeric", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Decimal, DbType.Int64), valueTypeEqualsFieldType: false); + await AssertType((byte)8, "8", "numeric", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Decimal, DbType.Int16), valueTypeEqualsFieldType: false, skipArrayCheck: true); + await AssertType(8F, "8", "numeric", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Decimal, DbType.Single), valueTypeEqualsFieldType: false); + await AssertType(8D, "8", "numeric", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Decimal, DbType.Double), valueTypeEqualsFieldType: false); } [Test, Description("Tests that when Numeric value does not fit in a System.Decimal and reader is in ReaderState.InResult, the value was read wholly and it is safe to continue reading")] @@ -211,5 +218,17 @@ public async Task NumericZero_WithScale() Assert.That(value.Scale, Is.EqualTo(2)); } - public NumericTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + [Test, IssueLink("https://github.com/npgsql/npgsql/issues/6383")] + public async Task Read_Many_Numerics_As_BigInteger([Values(CommandBehavior.Default, CommandBehavior.SequentialAccess)] CommandBehavior behavior) + { + await using var conn = await OpenConnectionAsync(); + await using var cmd = conn.CreateCommand(); + cmd.CommandText = "SELECT 1234567890::numeric FROM generate_series(1, 8000)"; + + await using var reader = await cmd.ExecuteReaderAsync(behavior); + while (await reader.ReadAsync()) + { + Assert.DoesNotThrowAsync(async () => await reader.GetFieldValueAsync(0)); + } + } } diff --git a/test/Npgsql.Tests/Types/NumericTypeTests.cs 
b/test/Npgsql.Tests/Types/NumericTypeTests.cs index 9fcd5b695b..dc41a387c8 100644 --- a/test/Npgsql.Tests/Types/NumericTypeTests.cs +++ b/test/Npgsql.Tests/Types/NumericTypeTests.cs @@ -2,7 +2,6 @@ using System.Data; using System.Globalization; using System.Threading.Tasks; -using NpgsqlTypes; using NUnit.Framework; using static Npgsql.Tests.TestUtil; @@ -14,64 +13,87 @@ namespace Npgsql.Tests.Types; /// /// https://www.postgresql.org/docs/current/static/datatype-numeric.html /// -public class NumericTypeTests : MultiplexingTestBase +public class NumericTypeTests : TestBase { [Test] public async Task Int16() { - await AssertType((short)8, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16); + await AssertType((short)8, "8", "smallint", dbType: DbType.Int16); // Clr byte/sbyte maps to 'int2' as there is no byte type in PostgreSQL, byte[] maps to bytea however. - await AssertType((byte)8, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefaultForReading: false, skipArrayCheck: true); - await AssertType((sbyte)8, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefaultForReading: false); - - await AssertType(8, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); - await AssertType(8L, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); - await AssertType(8F, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); - await AssertType(8D, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); - await AssertType(8M, "8", "smallint", NpgsqlDbType.Smallint, DbType.Int16, isDefault: false); + await AssertType((byte)8, "8", "smallint", dataTypeInference: DataTypeInference.Mismatch, + dbType: DbType.Int16, valueTypeEqualsFieldType: false, skipArrayCheck: true); + await AssertType((sbyte)8, "8", "smallint", dataTypeInference: DataTypeInference.Mismatch, + dbType: DbType.Int16, valueTypeEqualsFieldType: false); + + await AssertType(8, "8", "smallint", dataTypeInference: 
DataTypeInference.Mismatch, + dbType: new(DbType.Int16, DbType.Int32), valueTypeEqualsFieldType: false); + await AssertType(8L, "8", "smallint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int16, DbType.Int64), valueTypeEqualsFieldType: false); + await AssertType(8F, "8", "smallint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int16, DbType.Single), valueTypeEqualsFieldType: false); + await AssertType(8D, "8", "smallint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int16, DbType.Double), valueTypeEqualsFieldType: false); + await AssertType(8M, "8", "smallint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int16, DbType.Decimal), valueTypeEqualsFieldType: false); } [Test] public async Task Int32() { - await AssertType(8, "8", "integer", NpgsqlDbType.Integer, DbType.Int32); - - await AssertType((short)8, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); - await AssertType(8L, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); - await AssertType((byte)8, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); - await AssertType(8F, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); - await AssertType(8D, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); - await AssertType(8M, "8", "integer", NpgsqlDbType.Integer, DbType.Int32, isDefault: false); + await AssertType(8, "8", "integer", dbType: DbType.Int32); + + await AssertType((short)8, "8", "integer", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int32, DbType.Int16), valueTypeEqualsFieldType: false); + await AssertType(8L, "8", "integer", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int32, DbType.Int64), valueTypeEqualsFieldType: false); + await AssertType((byte)8, "8", "integer", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int32, DbType.Int16), 
valueTypeEqualsFieldType: false, skipArrayCheck: true); // byte[] maps to bytea + await AssertType((sbyte)8, "8", "integer", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int32, DbType.Int16), valueTypeEqualsFieldType: false); + await AssertType(8F, "8", "integer", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int32, DbType.Single), valueTypeEqualsFieldType: false); + await AssertType(8D, "8", "integer", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int32, DbType.Double), valueTypeEqualsFieldType: false); + await AssertType(8M, "8", "integer", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int32, DbType.Decimal), valueTypeEqualsFieldType: false); } [Test, Description("Tests some types which are aliased to UInt32")] - [TestCase("oid", NpgsqlDbType.Oid, TestName="OID")] - [TestCase("xid", NpgsqlDbType.Xid, TestName="XID")] - [TestCase("cid", NpgsqlDbType.Cid, TestName="CID")] - public Task UInt32(string pgTypeName, NpgsqlDbType npgsqlDbType) - => AssertType(8u, "8", pgTypeName, npgsqlDbType, isDefaultForWriting: false); + [TestCase("oid", TestName="OID")] + [TestCase("xid", TestName="XID")] + [TestCase("cid", TestName="CID")] + public Task UInt32(string dataTypeName) + => AssertType(8u, "8", dataTypeName, dataTypeInference: DataTypeInference.Nothing); [Test] - [TestCase("xid8", NpgsqlDbType.Xid8, TestName="XID8")] - public async Task UInt64(string pgTypeName, NpgsqlDbType npgsqlDbType) + [TestCase("xid8", TestName="XID8")] + public async Task UInt64(string dataTypeName) { await using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "13.0", "The xid8 type was introduced in PostgreSQL 13"); - await AssertType(8ul, "8", pgTypeName, npgsqlDbType, isDefaultForWriting: false); + await AssertType(8ul, "8", dataTypeName, dataTypeInference: DataTypeInference.Nothing); } [Test] public async Task Int64() { - await AssertType(8L, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64); 
- - await AssertType((short)8, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); - await AssertType(8, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); - await AssertType((byte)8, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); - await AssertType(8F, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); - await AssertType(8D, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); - await AssertType(8M, "8", "bigint", NpgsqlDbType.Bigint, DbType.Int64, isDefault: false); + await AssertType(8L, "8", "bigint", dbType: DbType.Int64); + + await AssertType((short)8, "8", "bigint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int64, DbType.Int16), valueTypeEqualsFieldType: false); + await AssertType(8, "8", "bigint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int64, DbType.Int32), valueTypeEqualsFieldType: false); + await AssertType((byte)8, "8", "bigint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int64, DbType.Int16), valueTypeEqualsFieldType: false, skipArrayCheck: true); // byte[] maps to bytea + await AssertType((sbyte)8, "8", "bigint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int64, DbType.Int16), valueTypeEqualsFieldType: false); + await AssertType(8F, "8", "bigint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int64, DbType.Single), valueTypeEqualsFieldType: false); + await AssertType(8D, "8", "bigint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int64, DbType.Double), valueTypeEqualsFieldType: false); + await AssertType(8M, "8", "bigint", dataTypeInference: DataTypeInference.Mismatch, + dbType: new(DbType.Int64, DbType.Decimal), valueTypeEqualsFieldType: false); } [Test] @@ -84,7 +106,7 @@ public async Task Double(double value, string sqlLiteral) await using var conn = await OpenConnectionAsync(); MinimumPgVersion(conn, "12.0"); - 
await AssertType(value, sqlLiteral, "double precision", NpgsqlDbType.Double, DbType.Double); + await AssertType(value, sqlLiteral, "double precision", dbType: DbType.Double); } [Test] @@ -93,21 +115,19 @@ public async Task Double(double value, string sqlLiteral) [TestCase(float.PositiveInfinity, "Infinity", TestName = "Float_PositiveInfinity")] [TestCase(float.NegativeInfinity, "-Infinity", TestName = "Float_NegativeInfinity")] public Task Float(float value, string sqlLiteral) - => AssertType(value, sqlLiteral, "real", NpgsqlDbType.Real, DbType.Single); + => AssertType(value, sqlLiteral, "real", dbType: DbType.Single); [Test] [TestCase(short.MaxValue + 1, "smallint")] [TestCase(int.MaxValue + 1L, "integer")] [TestCase(long.MaxValue + 1D, "bigint")] - public Task Write_overflow(T value, string pgTypeName) - => AssertTypeUnsupportedWrite(value, pgTypeName); + public Task Write_overflow(T value, string dataTypeName) + => AssertTypeUnsupportedWrite(value, dataTypeName); [Test] [TestCase((short)0, short.MaxValue + 1D, "int")] [TestCase(0, int.MaxValue + 1D, "bigint")] [TestCase(0L, long.MaxValue + 1D, "decimal")] - public Task Read_overflow(T _, double value, string pgTypeName) - => AssertTypeUnsupportedRead(value.ToString(CultureInfo.InvariantCulture), pgTypeName); - - public NumericTypeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} + public Task Read_overflow(T _, double value, string dataTypeName) + => AssertTypeUnsupportedRead(value.ToString(CultureInfo.InvariantCulture), dataTypeName); } diff --git a/test/Npgsql.Tests/Types/RangeTests.cs b/test/Npgsql.Tests/Types/RangeTests.cs index 38449d30a2..23974f5583 100644 --- a/test/Npgsql.Tests/Types/RangeTests.cs +++ b/test/Npgsql.Tests/Types/RangeTests.cs @@ -2,7 +2,6 @@ using System.ComponentModel; using System.Data; using System.Globalization; -using System.Linq; using System.Threading.Tasks; using Npgsql.Properties; using Npgsql.Util; @@ -12,50 +11,50 @@ namespace Npgsql.Tests.Types; -class 
RangeTests : MultiplexingTestBase +class RangeTests : TestBase { static readonly TestCaseData[] RangeTestCases = - { - new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "int4range", NpgsqlDbType.IntegerRange) + [ + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "int4range") .SetName("IntegerRange"), - new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "int8range", NpgsqlDbType.BigIntRange) + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "int8range") .SetName("BigIntRange"), - new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "numrange", NpgsqlDbType.NumericRange) + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "numrange") .SetName("NumericRange"), new TestCaseData(new NpgsqlRange( new DateTime(2020, 1, 1, 12, 0, 0), true, new DateTime(2020, 1, 3, 13, 0, 0), false), - """["2020-01-01 12:00:00","2020-01-03 13:00:00")""", "tsrange", NpgsqlDbType.TimestampRange) + """["2020-01-01 12:00:00","2020-01-03 13:00:00")""", "tsrange") .SetName("TimestampRange"), // Note that the below text representations are local (according to TimeZone, which is set to Europe/Berlin in this test class), // because that's how PG does timestamptz *text* representation. 
new TestCaseData(new NpgsqlRange( new DateTime(2020, 1, 1, 12, 0, 0, DateTimeKind.Utc), true, new DateTime(2020, 1, 3, 13, 0, 0, DateTimeKind.Utc), false), - """["2020-01-01 13:00:00+01","2020-01-03 14:00:00+01")""", "tstzrange", NpgsqlDbType.TimestampTzRange) + """["2020-01-01 13:00:00+01","2020-01-03 14:00:00+01")""", "tstzrange") .SetName("TimestampTzRange"), // Note that numrange is a non-discrete range, and therefore doesn't undergo normalization to inclusive/exclusive in PG - new TestCaseData(NpgsqlRange.Empty, "empty", "numrange", NpgsqlDbType.NumericRange) + new TestCaseData(NpgsqlRange.Empty, "empty", "numrange") .SetName("EmptyRange"), - new TestCaseData(new NpgsqlRange(1, true, 10, true), "[1,10]", "numrange", NpgsqlDbType.NumericRange) + new TestCaseData(new NpgsqlRange(1, true, 10, true), "[1,10]", "numrange") .SetName("Inclusive"), - new TestCaseData(new NpgsqlRange(1, false, 10, false), "(1,10)", "numrange", NpgsqlDbType.NumericRange) + new TestCaseData(new NpgsqlRange(1, false, 10, false), "(1,10)", "numrange") .SetName("Exclusive"), - new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "numrange", NpgsqlDbType.NumericRange) + new TestCaseData(new NpgsqlRange(1, true, 10, false), "[1,10)", "numrange") .SetName("InclusiveExclusive"), - new TestCaseData(new NpgsqlRange(1, false, 10, true), "(1,10]", "numrange", NpgsqlDbType.NumericRange) + new TestCaseData(new NpgsqlRange(1, false, 10, true), "(1,10]", "numrange") .SetName("ExclusiveInclusive"), - new TestCaseData(new NpgsqlRange(1, false, true, 10, false, false), "(,10)", "numrange", NpgsqlDbType.NumericRange) + new TestCaseData(new NpgsqlRange(1, false, true, 10, false, false), "(,10)", "numrange") .SetName("InfiniteLowerBound"), - new TestCaseData(new NpgsqlRange(1, true, false, 10, false, true), "[1,)", "numrange", NpgsqlDbType.NumericRange) + new TestCaseData(new NpgsqlRange(1, true, false, 10, false, true), "[1,)", "numrange") .SetName("InfiniteUpperBound") - }; + ]; // See more test 
cases in DateTimeTests [Test, TestCaseSource(nameof(RangeTestCases))] - public Task Range(T range, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType) - => AssertType(range, sqlLiteral, pgTypeName, npgsqlDbType, + public Task Range(T range, string sqlLiteral, string dataTypeName) + => AssertType(range, sqlLiteral, dataTypeName, // NpgsqlRange[] is mapped to multirange by default, not array, so the built-in AssertType testing for arrays fails // (see below) skipArrayCheck: true); @@ -63,8 +62,8 @@ public Task Range(T range, string sqlLiteral, string pgTypeName, NpgsqlDbType // This re-executes the same scenario as above, but with isDefaultForWriting: false and without skipArrayCheck: true. // This tests coverage of range arrays (as opposed to multiranges). [Test, TestCaseSource(nameof(RangeTestCases))] - public Task Range_array(T range, string sqlLiteral, string pgTypeName, NpgsqlDbType? npgsqlDbType) - => AssertType(range, sqlLiteral, pgTypeName, npgsqlDbType, isDefaultForWriting: false); + public Task Range_array(T range, string sqlLiteral, string dataTypeName) + => AssertType(range, sqlLiteral, dataTypeName, dataTypeInference: DataTypeInference.Mismatch); [Test] public void Equality_finite() @@ -73,23 +72,23 @@ public void Equality_finite() //different bounds var r2 = new NpgsqlRange(1, true, false, 2, false, false); - Assert.IsFalse(r1 == r2); + Assert.That(r1 == r2, Is.False); //lower bound is not inclusive var r3 = new NpgsqlRange(0, false, false, 1, false, false); - Assert.IsFalse(r1 == r3); + Assert.That(r1 == r3, Is.False); //upper bound is inclusive var r4 = new NpgsqlRange(0, true, false, 1, true, false); - Assert.IsFalse(r1 == r4); + Assert.That(r1 == r4, Is.False); var r5 = new NpgsqlRange(0, true, false, 1, false, false); - Assert.IsTrue(r1 == r5); + Assert.That(r1 == r5); //check some other combinations while we are here - Assert.IsFalse(r2 == r3); - Assert.IsFalse(r2 == r4); - Assert.IsFalse(r3 == r4); + Assert.That(r2 == r3, Is.False); 
+ Assert.That(r2 == r4, Is.False); + Assert.That(r3 == r4, Is.False); } [Test] @@ -97,22 +96,22 @@ public void Equality_infinite() { var r1 = new NpgsqlRange(0, false, true, 1, false, false); - //different upper bound (lower bound shoulnd't matter since it is infinite) + //different upper bound (lower bound shouldn't matter since it is infinite) var r2 = new NpgsqlRange(1, false, true, 2, false, false); - Assert.IsFalse(r1 == r2); + Assert.That(r1 == r2, Is.False); //upper bound is inclusive var r3 = new NpgsqlRange(0, false, true, 1, true, false); - Assert.IsFalse(r1 == r3); + Assert.That(r1 == r3, Is.False); //value of lower bound shouldn't matter since it is infinite var r4 = new NpgsqlRange(10, false, true, 1, false, false); - Assert.IsTrue(r1 == r4); + Assert.That(r1 == r4); //check some other combinations while we are here - Assert.IsFalse(r2 == r3); - Assert.IsFalse(r2 == r4); - Assert.IsFalse(r3 == r4); + Assert.That(r2 == r3, Is.False); + Assert.That(r2 == r4, Is.False); + Assert.That(r3 == r4, Is.False); } [Test] @@ -122,12 +121,12 @@ public void GetHashCode_value_types() NpgsqlRange b = NpgsqlRange.Empty; NpgsqlRange c = NpgsqlRange.Parse("(,)"); - Assert.IsFalse(a.Equals(b)); - Assert.IsFalse(a.Equals(c)); - Assert.IsFalse(b.Equals(c)); - Assert.AreNotEqual(a.GetHashCode(), b.GetHashCode()); - Assert.AreNotEqual(a.GetHashCode(), c.GetHashCode()); - Assert.AreNotEqual(b.GetHashCode(), c.GetHashCode()); + Assert.That(a.Equals(b), Is.False); + Assert.That(a.Equals(c), Is.False); + Assert.That(b.Equals(c), Is.False); + Assert.That(b.GetHashCode(), Is.Not.EqualTo(a.GetHashCode())); + Assert.That(c.GetHashCode(), Is.Not.EqualTo(a.GetHashCode())); + Assert.That(c.GetHashCode(), Is.Not.EqualTo(b.GetHashCode())); } [Test] @@ -137,12 +136,12 @@ public void GetHashCode_reference_types() NpgsqlRange b = NpgsqlRange.Empty; NpgsqlRange c = NpgsqlRange.Parse("(,)"); - Assert.IsFalse(a.Equals(b)); - Assert.IsFalse(a.Equals(c)); - Assert.IsFalse(b.Equals(c)); - 
Assert.AreNotEqual(a.GetHashCode(), b.GetHashCode()); - Assert.AreNotEqual(a.GetHashCode(), c.GetHashCode()); - Assert.AreNotEqual(b.GetHashCode(), c.GetHashCode()); + Assert.That(a.Equals(b), Is.False); + Assert.That(a.Equals(c), Is.False); + Assert.That(b.Equals(c), Is.False); + Assert.That(b.GetHashCode(), Is.Not.EqualTo(a.GetHashCode())); + Assert.That(c.GetHashCode(), Is.Not.EqualTo(a.GetHashCode())); + Assert.That(c.GetHashCode(), Is.Not.EqualTo(b.GetHashCode())); } [Test] @@ -165,7 +164,6 @@ public async Task TimestampTz_range_with_DateTimeOffset() } [Test] - [NonParallelizable] public async Task Unmapped_range_with_mapped_subtype() { await using var dataSource = CreateDataSource(b => b.EnableUnmappedTypes().ConnectionStringBuilder.MaxPoolSize = 1); @@ -173,7 +171,6 @@ public async Task Unmapped_range_with_mapped_subtype() var typeName = await GetTempTypeName(conn); await conn.ExecuteNonQueryAsync($"CREATE TYPE {typeName} AS RANGE(subtype=text)"); - await Task.Yield(); // TODO: fix multiplexing deadlock bug conn.ReloadTypes(); Assert.That(await conn.ExecuteScalarAsync("SELECT 1"), Is.EqualTo(1)); @@ -199,7 +196,6 @@ public async Task Unmapped_range_supported_only_with_EnableUnmappedTypes() await using var connection = await DataSource.OpenConnectionAsync(); var rangeType = await GetTempTypeName(connection); await connection.ExecuteNonQueryAsync($"CREATE TYPE {rangeType} AS RANGE(subtype=text)"); - await Task.Yield(); // TODO: fix multiplexing deadlock bug await connection.ReloadTypesAsync(); var errorMessage = string.Format( @@ -208,15 +204,15 @@ public async Task Unmapped_range_supported_only_with_EnableUnmappedTypes() nameof(NpgsqlDataSourceBuilder)); var exception = await AssertTypeUnsupportedWrite(new NpgsqlRange("bar", "foo"), rangeType); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); - exception = await 
AssertTypeUnsupportedRead("""["bar","foo"]""", rangeType); - Assert.IsInstanceOf(exception.InnerException); + exception = await AssertTypeUnsupportedRead("""["bar","foo"]""", rangeType); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); exception = await AssertTypeUnsupportedRead>("""["bar","foo"]""", rangeType); - Assert.IsInstanceOf(exception.InnerException); + Assert.That(exception.InnerException, Is.InstanceOf()); Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); } @@ -245,9 +241,7 @@ await AssertType( }, """{"[3,4)","[5,6)"}""", "int4range[]", - NpgsqlDbType.IntegerRange | NpgsqlDbType.Array, - isDefaultForWriting: !supportsMultirange, - isNpgsqlDbTypeInferredFromClrType: false); + dataTypeInference: supportsMultirange ? DataTypeInference.Mismatch : DataTypeInference.Match); } [Test] @@ -274,7 +268,7 @@ public async Task NpgsqlSlimSourceBuilder_EnableRanges() await AssertType( dataSource, - new NpgsqlRange(1, true, 10, false), "[1,10)", "int4range", NpgsqlDbType.IntegerRange, skipArrayCheck: true); + new NpgsqlRange(1, true, 10, false), "[1,10)", "int4range", skipArrayCheck: true); } protected override NpgsqlConnection OpenConnection() @@ -288,7 +282,7 @@ public void Roundtrip_DateTime_ranges_through_ToString_and_Parse(NpgsqlRange.Parse(wellKnownText); - Assert.AreEqual(input, result); + Assert.That(result, Is.EqualTo(input)); } [Theory] @@ -298,7 +292,7 @@ public void Roundtrip_DateTime_ranges_through_ToString_and_Parse(NpgsqlRange.Parse(value); - Assert.AreEqual(NpgsqlRange.Empty, result); + Assert.That(result, Is.EqualTo(NpgsqlRange.Empty)); } [Theory] @@ -310,7 +304,7 @@ public void Parse_empty(string value) public void Roundtrip_int_ranges_through_ToString_and_Parse(string input) { var result = NpgsqlRange.Parse(input); - Assert.AreEqual(input.Replace(" ", null), result.ToString()); + Assert.That(result.ToString(), Is.EqualTo(input.Replace(" ", 
null))); } [Theory] @@ -330,7 +324,7 @@ public void Roundtrip_int_ranges_through_ToString_and_Parse(string input) public void Int_range_Parse_ToString_returns_normalized_representations(string input, string normalized) { var result = NpgsqlRange.Parse(input); - Assert.AreEqual(normalized, result.ToString()); + Assert.That(result.ToString(), Is.EqualTo(normalized)); } [Theory] @@ -350,7 +344,7 @@ public void Int_range_Parse_ToString_returns_normalized_representations(string i public void Nullable_int_range_Parse_ToString_returns_normalized_representations(string input, string normalized) { var result = NpgsqlRange.Parse(input); - Assert.AreEqual(normalized, result.ToString()); + Assert.That(result.ToString(), Is.EqualTo(normalized)); } [Theory] @@ -361,7 +355,7 @@ public void Nullable_int_range_Parse_ToString_returns_normalized_representations public void String_range_Parse_ToString_returns_normalized_representations(string input, string normalized) { var result = NpgsqlRange.Parse(input); - Assert.AreEqual(normalized, result.ToString()); + Assert.That(result.ToString(), Is.EqualTo(normalized)); } [Theory] @@ -369,7 +363,7 @@ public void String_range_Parse_ToString_returns_normalized_representations(strin public void Roundtrip_string_ranges_through_ToString_and_Parse2(string input) { var result = NpgsqlRange.Parse(input); - Assert.AreEqual(input, result.ToString()); + Assert.That(result.ToString(), Is.EqualTo(input)); } [Theory] @@ -388,12 +382,12 @@ public void TypeConverter() var converter = TypeDescriptor.GetConverter(typeof(NpgsqlRange)); // Act - Assert.IsInstanceOf.RangeTypeConverter>(converter); - Assert.IsTrue(converter.CanConvertFrom(typeof(string))); + Assert.That(converter, Is.InstanceOf.RangeTypeConverter>()); + Assert.That(converter.CanConvertFrom(typeof(string))); var result = converter.ConvertFromString("empty"); // Assert - Assert.AreEqual(NpgsqlRange.Empty, result); + Assert.That(result, Is.Empty); } #endregion @@ -406,14 +400,10 @@ class SimpleType 
string? Value { get; } SimpleType(string? value) - { - Value = value; - } + => Value = value; public override string? ToString() - { - return Value; - } + => Value; class SimpleTypeConverter : TypeConverter { @@ -438,35 +428,35 @@ public override object ConvertFrom(ITypeDescriptorContext? context, CultureInfo? new object[][] { // (2018-05-17, 2018-05-18) - new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, false, false) }, + [new NpgsqlRange(May_17_2018, false, false, May_18_2018, false, false)], // [2018-05-17, 2018-05-18] - new object[] { new NpgsqlRange(May_17_2018, true, false, May_18_2018, true, false) }, + [new NpgsqlRange(May_17_2018, true, false, May_18_2018, true, false)], // [2018-05-17, 2018-05-18) - new object[] { new NpgsqlRange(May_17_2018, true, false, May_18_2018, false, false) }, + [new NpgsqlRange(May_17_2018, true, false, May_18_2018, false, false)], // (2018-05-17, 2018-05-18] - new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, true, false) }, + [new NpgsqlRange(May_17_2018, false, false, May_18_2018, true, false)], // (,) - new object[] { new NpgsqlRange(default, false, true, default, false, true) }, - new object[] { new NpgsqlRange(May_17_2018, false, true, May_18_2018, false, true) }, + [new NpgsqlRange(default, false, true, default, false, true)], + [new NpgsqlRange(May_17_2018, false, true, May_18_2018, false, true)], // (2018-05-17,) - new object[] { new NpgsqlRange(May_17_2018, false, false, default, false, true) }, - new object[] { new NpgsqlRange(May_17_2018, false, false, May_18_2018, false, true) }, + [new NpgsqlRange(May_17_2018, false, false, default, false, true)], + [new NpgsqlRange(May_17_2018, false, false, May_18_2018, false, true)], // (,2018-05-18) - new object[] { new NpgsqlRange(default, false, true, May_18_2018, false, false) }, - new object[] { new NpgsqlRange(May_17_2018, false, true, May_18_2018, false, false) } + [new NpgsqlRange(default, false, true, May_18_2018, false, 
false)], + [new NpgsqlRange(May_17_2018, false, true, May_18_2018, false, false)] }; #endregion protected override NpgsqlDataSource DataSource { get; } - public RangeTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) + public RangeTests() => DataSource = CreateDataSource(builder => { builder.ConnectionStringBuilder.Timezone = "Europe/Berlin"; diff --git a/test/Npgsql.Tests/Types/RecordTests.cs b/test/Npgsql.Tests/Types/RecordTests.cs index 7aefe1e98d..2fd330badf 100644 --- a/test/Npgsql.Tests/Types/RecordTests.cs +++ b/test/Npgsql.Tests/Types/RecordTests.cs @@ -7,7 +7,7 @@ namespace Npgsql.Tests.Types; -public class RecordTests : MultiplexingTestBase +public class RecordTests : TestBase { [Test] [IssueLink("https://github.com/npgsql/npgsql/issues/724")] @@ -103,8 +103,8 @@ public async Task As_ValueTuple_supported_only_with_EnableRecordsAsTuples() nameof(NpgsqlSlimDataSourceBuilder.EnableRecords)); var exception = Assert.Throws(() => reader.GetFieldValue<(int, string)>(0))!; - Assert.IsInstanceOf(exception.InnerException); - Assert.AreEqual(errorMessage, exception.InnerException!.Message); + Assert.That(exception.InnerException, Is.InstanceOf()); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); } [Test] @@ -127,12 +127,12 @@ public async Task Records_not_supported_by_default_on_NpgsqlSlimSourceBuilder() nameof(NpgsqlSlimDataSourceBuilder.EnableRecords)); var exception = Assert.Throws(() => reader.GetValue(0))!; - Assert.IsInstanceOf(exception.InnerException); - Assert.AreEqual(errorMessage, exception.InnerException!.Message); + Assert.That(exception.InnerException, Is.InstanceOf()); + Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); exception = Assert.Throws(() => reader.GetFieldValue(0))!; - Assert.IsInstanceOf(exception.InnerException); - Assert.AreEqual(errorMessage, exception.InnerException!.Message); + Assert.That(exception.InnerException, Is.InstanceOf()); + 
Assert.That(exception.InnerException!.Message, Is.EqualTo(errorMessage)); } [Test] @@ -152,6 +152,4 @@ public async Task NpgsqlSlimSourceBuilder_EnableRecords() Assert.That(() => reader.GetValue(0), Throws.Nothing); Assert.That(() => reader.GetFieldValue(0), Throws.Nothing); } - - public RecordTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/Types/TextTests.cs b/test/Npgsql.Tests/Types/TextTests.cs index 7e86fb131b..27b9566009 100644 --- a/test/Npgsql.Tests/Types/TextTests.cs +++ b/test/Npgsql.Tests/Types/TextTests.cs @@ -15,47 +15,54 @@ namespace Npgsql.Tests.Types; /// /// https://www.postgresql.org/docs/current/static/datatype-character.html /// -public class TextTests : MultiplexingTestBase +public class TextTests : TestBase { [Test] public Task Text_as_string() - => AssertType("foo", "foo", "text", NpgsqlDbType.Text, DbType.String); + => AssertType("foo", "foo", "text", dbType: DbType.String); [Test] public Task Text_as_array_of_chars() - => AssertType("foo".ToCharArray(), "foo", "text", NpgsqlDbType.Text, DbType.String, isDefaultForReading: false); + => AssertType("foo".ToCharArray(), "foo", "text", dataTypeInference: DataTypeInference.Mismatch, + dbType: DbType.String, valueTypeEqualsFieldType: false); [Test] public Task Text_as_ArraySegment_of_chars() - => AssertTypeWrite(new ArraySegment("foo".ToCharArray()), "foo", "text", NpgsqlDbType.Text, DbType.String, - isDefault: false); + => AssertTypeWrite(new ArraySegment("foo".ToCharArray()), "foo", "text", dbType: DbType.String); [Test] public Task Text_as_array_of_bytes() - => AssertType(Encoding.UTF8.GetBytes("foo"), "foo", "text", NpgsqlDbType.Text, DbType.String, isDefault: false); + => AssertType("foo"u8.ToArray(), "foo", "text", dataTypeInference: DataTypeInference.Mismatch, + new(DbType.String, DbType.Binary), valueTypeEqualsFieldType: false); [Test] public Task Text_as_ReadOnlyMemory_of_bytes() - => AssertTypeWrite(new 
ReadOnlyMemory(Encoding.UTF8.GetBytes("foo")), "foo", "text", NpgsqlDbType.Text, DbType.String, - isDefault: false); + => AssertTypeWrite(new ReadOnlyMemory("foo"u8.ToArray()), "foo", + "text", dataTypeInference: DataTypeInference.Mismatch, + new(DbType.String, DbType.Binary)); [Test] public Task Char_as_char() - => AssertType('f', "f", "character", NpgsqlDbType.Char, inferredDbType: DbType.String, isDefault: false); + => AssertType('f', "f", + "character", dataTypeInference: DataTypeInference.Mismatch, + dbType: DbType.String, valueTypeEqualsFieldType: false, skipArrayCheck: true); // char[] maps to text [Test] - [NonParallelizable] public async Task Citext_as_string() { await using var conn = await OpenConnectionAsync(); await EnsureExtensionAsync(conn, "citext"); - await AssertType("foo", "foo", "citext", NpgsqlDbType.Citext, inferredDbType: DbType.String, isDefaultForWriting: false); + await AssertType("foo", "foo", + "citext", dataTypeInference: DataTypeInference.Mismatch, + dbType: DbType.String); } [Test] public Task Text_as_MemoryStream() - => AssertTypeWrite(() => new MemoryStream("foo"u8.ToArray()), "foo", "text", NpgsqlDbType.Text, DbType.String, isDefault: false); + => AssertTypeWrite(() => new MemoryStream("foo"u8.ToArray()), "foo", + "text", dataTypeInference: DataTypeInference.Mismatch, + new(DbType.String, DbType.Binary)); [Test] public async Task Text_long() @@ -65,7 +72,7 @@ public async Task Text_long() builder.Append('X', conn.Settings.WriteBufferSize); var value = builder.ToString(); - await AssertType(value, value, "text", NpgsqlDbType.Text, DbType.String); + await AssertType(value, value, "text", dbType: DbType.String); } [Test, Description("Tests that strings are truncated when the NpgsqlParameter's Size is set")] @@ -103,10 +110,10 @@ public async Task Null_character() } [Test, Description("Tests some types which are aliased to strings")] - [TestCase("character varying", NpgsqlDbType.Varchar)] - [TestCase("name", NpgsqlDbType.Name)] - 
public Task Aliased_postgres_types(string pgTypeName, NpgsqlDbType npgsqlDbType) - => AssertType("foo", "foo", pgTypeName, npgsqlDbType, inferredDbType: DbType.String, isDefaultForWriting: false); + [TestCase("character varying")] + [TestCase("name")] + public Task Aliased_postgres_types(string dataTypeName) + => AssertType("foo", "foo", dataTypeName, dataTypeInference: DataTypeInference.Mismatch, dbType: DbType.String); [Test] [TestCase(DbType.AnsiString)] @@ -138,17 +145,15 @@ public async Task Internal_char() var expected = new char[] { 'a', (char)(256 - 3), 'b', (char)66, (char)230 }; for (var i = 0; i < expected.Length; i++) { - Assert.AreEqual(expected[i], reader.GetChar(i)); + Assert.That(reader.GetChar(i), Is.EqualTo(expected[i])); } var arr = (char[])reader.GetValue(5); var arr2 = (char[])reader.GetValue(6); - Assert.AreEqual(testArr.Length, arr.Length); + Assert.That(arr.Length, Is.EqualTo(testArr.Length)); for (var i = 0; i < arr.Length; i++) { - Assert.AreEqual(testArr[i], arr[i]); - Assert.AreEqual(testArr2[i], arr2[i]); + Assert.That(arr[i], Is.EqualTo(testArr[i])); + Assert.That(arr2[i], Is.EqualTo(testArr2[i])); } } - - public TextTests(MultiplexingMode multiplexingMode) : base(multiplexingMode) {} } diff --git a/test/Npgsql.Tests/TypesTests.cs b/test/Npgsql.Tests/TypesTests.cs index 610d640a02..4110a0856f 100644 --- a/test/Npgsql.Tests/TypesTests.cs +++ b/test/Npgsql.Tests/TypesTests.cs @@ -17,22 +17,22 @@ public void TsVector() NpgsqlTsVector vec; vec = NpgsqlTsVector.Parse("a"); - Assert.AreEqual("'a'", vec.ToString()); + Assert.That(vec.ToString(), Is.EqualTo("'a'")); vec = NpgsqlTsVector.Parse("a "); - Assert.AreEqual("'a'", vec.ToString()); + Assert.That(vec.ToString(), Is.EqualTo("'a'")); vec = NpgsqlTsVector.Parse("a:1A"); - Assert.AreEqual("'a':1A", vec.ToString()); + Assert.That(vec.ToString(), Is.EqualTo("'a':1A")); vec = NpgsqlTsVector.Parse(@"\abc\def:1a "); - Assert.AreEqual("'abcdef':1A", vec.ToString()); + Assert.That(vec.ToString(), 
Is.EqualTo("'abcdef':1A")); vec = NpgsqlTsVector.Parse(@"abc:3A 'abc' abc:4B 'hello''yo' 'meh\'\\':5"); - Assert.AreEqual(@"'abc':3A,4B 'hello''yo' 'meh''\\':5", vec.ToString()); + Assert.That(vec.ToString(), Is.EqualTo(@"'abc':3A,4B 'hello''yo' 'meh''\\':5")); vec = NpgsqlTsVector.Parse(" a:12345C a:24D a:25B b c d 1 2 a:25A,26B,27,28"); - Assert.AreEqual("'1' '2' 'a':24,25A,26B,27,28,12345C 'b' 'c' 'd'", vec.ToString()); + Assert.That(vec.ToString(), Is.EqualTo("'1' '2' 'a':24,25A,26B,27,28,12345C 'b' 'c' 'd'")); } [Test] @@ -47,27 +47,27 @@ public void TsQuery() var str = query.ToString(); query = NpgsqlTsQuery.Parse("a & b | c"); - Assert.AreEqual("'a' & 'b' | 'c'", query.ToString()); + Assert.That(query.ToString(), Is.EqualTo("'a' & 'b' | 'c'")); query = NpgsqlTsQuery.Parse("'a''':*ab&d:d&!c"); - Assert.AreEqual("'a''':*AB & 'd':D & !'c'", query.ToString()); + Assert.That(query.ToString(), Is.EqualTo("'a''':*AB & 'd':D & !'c'")); query = NpgsqlTsQuery.Parse("(a & !(c | d)) & (!!a&b) | c | d | e"); - Assert.AreEqual("( ( 'a' & !( 'c' | 'd' ) & !( !'a' ) & 'b' | 'c' ) | 'd' ) | 'e'", query.ToString()); - Assert.AreEqual(query.ToString(), NpgsqlTsQuery.Parse(query.ToString()).ToString()); + Assert.That(query.ToString(), Is.EqualTo("( ( 'a' & !( 'c' | 'd' ) & !( !'a' ) & 'b' | 'c' ) | 'd' ) | 'e'")); + Assert.That(NpgsqlTsQuery.Parse(query.ToString()).ToString(), Is.EqualTo(query.ToString())); query = NpgsqlTsQuery.Parse("(((a:*)))"); - Assert.AreEqual("'a':*", query.ToString()); + Assert.That(query.ToString(), Is.EqualTo("'a':*")); query = NpgsqlTsQuery.Parse(@"'a\\b''cde'"); - Assert.AreEqual(@"a\b'cde", ((NpgsqlTsQueryLexeme)query).Text); - Assert.AreEqual(@"'a\\b''cde'", query.ToString()); + Assert.That(((NpgsqlTsQueryLexeme)query).Text, Is.EqualTo(@"a\b'cde")); + Assert.That(query.ToString(), Is.EqualTo(@"'a\\b''cde'")); query = NpgsqlTsQuery.Parse(@"a <-> b"); - Assert.AreEqual("'a' <-> 'b'", query.ToString()); + Assert.That(query.ToString(), Is.EqualTo("'a' 
<-> 'b'")); query = NpgsqlTsQuery.Parse("((a & b) <5> c) <-> !d <0> e"); - Assert.AreEqual("( ( 'a' & 'b' <5> 'c' ) <-> !'d' ) <0> 'e'", query.ToString()); + Assert.That(query.ToString(), Is.EqualTo("( ( 'a' & 'b' <5> 'c' ) <-> !'d' ) <0> 'e'")); Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("a b c & &")); Assert.Throws(typeof(FormatException), () => NpgsqlTsQuery.Parse("&")); @@ -86,6 +86,13 @@ public void TsQuery() } #pragma warning restore CS0618 // {NpgsqlTsVector,NpgsqlTsQuery}.Parse are obsolete + [Test] + public void TsVector_empty() + { + Assert.That(NpgsqlTsVector.Empty, Is.Empty); + Assert.That(NpgsqlTsVector.Empty.ToString(), Is.Empty); + } + [Test] public void TsQueryEquatibility() { @@ -160,18 +167,18 @@ public void TsQueryEquatibility() void AreEqual(NpgsqlTsQuery left, NpgsqlTsQuery right) { - Assert.True(left == right); - Assert.False(left != right); - Assert.AreEqual(left, right); - Assert.AreEqual(left.GetHashCode(), right.GetHashCode()); + Assert.That(left == right); + Assert.That(left != right, Is.False); + Assert.That(right, Is.EqualTo(left)); + Assert.That(right.GetHashCode(), Is.EqualTo(left.GetHashCode())); } void AreNotEqual(NpgsqlTsQuery left, NpgsqlTsQuery right) { - Assert.False(left == right); - Assert.True(left != right); - Assert.AreNotEqual(left, right); - Assert.AreNotEqual(left.GetHashCode(), right.GetHashCode()); + Assert.That(left == right, Is.False); + Assert.That(left != right); + Assert.That(right, Is.Not.EqualTo(left)); + Assert.That(right.GetHashCode(), Is.Not.EqualTo(left.GetHashCode())); } } @@ -181,7 +188,7 @@ public void TsQueryOperatorPrecedence() { var query = NpgsqlTsQuery.Parse("!a <-> b & c | d & e"); var expectedGrouping = NpgsqlTsQuery.Parse("((!(a) <-> b) & c) | (d & e)"); - Assert.AreEqual(expectedGrouping.ToString(), query.ToString()); + Assert.That(query.ToString(), Is.EqualTo(expectedGrouping.ToString())); } #pragma warning restore CS0618 // {NpgsqlTsVector,NpgsqlTsQuery}.Parse are obsolete 
@@ -193,6 +200,20 @@ public void NpgsqlPath_empty() public void NpgsqlPolygon_empty() => Assert.That(new NpgsqlPolygon { new(1, 2) }, Is.EqualTo(new NpgsqlPolygon(new NpgsqlPoint(1, 2)))); + [Test] + public void NpgsqlPath_default() + { + NpgsqlPath defaultPath = default; + Assert.That(defaultPath.Equals([new(1, 2)]), Is.False); + } + + [Test] + public void NpgsqlPolygon_default() + { + NpgsqlPolygon defaultPolygon = default; + Assert.That(defaultPolygon.Equals([new(1, 2)]), Is.False); + } + [Test] public void Bug1011018() { @@ -209,4 +230,43 @@ public void NpgsqlInet() var v = new NpgsqlInet(IPAddress.Parse("2001:1db8:85a3:1142:1000:8a2e:1370:7334"), 32); Assert.That(v.ToString(), Is.EqualTo("2001:1db8:85a3:1142:1000:8a2e:1370:7334/32")); } + + [Test] + public void NpgsqlInet_parse_ipv4() + { + var ipv4 = new NpgsqlInet("192.168.1.1/8"); + Assert.That(ipv4.Address, Is.EqualTo(IPAddress.Parse("192.168.1.1"))); + Assert.That(ipv4.Netmask, Is.EqualTo(8)); + + ipv4 = new NpgsqlInet("192.168.1.1/32"); + Assert.That(ipv4.Address, Is.EqualTo(IPAddress.Parse("192.168.1.1"))); + Assert.That(ipv4.Netmask, Is.EqualTo(32)); + } + + [Test] + [IssueLink("https://github.com/npgsql/npgsql/issues/5638")] + public void NpgsqlInet_parse_ipv6() + { + var ipv6 = new NpgsqlInet("2001:0000:130F:0000:0000:09C0:876A:130B/32"); + Assert.That(ipv6.Address, Is.EqualTo(IPAddress.Parse("2001:0000:130F:0000:0000:09C0:876A:130B"))); + Assert.That(ipv6.Netmask, Is.EqualTo(32)); + + ipv6 = new NpgsqlInet("2001:0000:130F:0000:0000:09C0:876A:130B"); + Assert.That(ipv6.Address, Is.EqualTo(IPAddress.Parse("2001:0000:130F:0000:0000:09C0:876A:130B"))); + Assert.That(ipv6.Netmask, Is.EqualTo(128)); + } + + [Test] + public void NpgsqlInet_ToString_ipv4() + { + Assert.That(new NpgsqlInet("192.168.1.1/8").ToString(), Is.EqualTo("192.168.1.1/8")); + Assert.That(new NpgsqlInet("192.168.1.1/32").ToString(), Is.EqualTo("192.168.1.1")); + } + + [Test] + public void NpgsqlInet_ToString_ipv6() + { + 
Assert.That(new NpgsqlInet("2001:0:130f::9c0:876a:130b/32").ToString(), Is.EqualTo("2001:0:130f::9c0:876a:130b/32")); + Assert.That(new NpgsqlInet("2001:0:130f::9c0:876a:130b/128").ToString(), Is.EqualTo("2001:0:130f::9c0:876a:130b")); + } } diff --git a/test/Npgsql.Tests/WriteBufferTests.cs b/test/Npgsql.Tests/WriteBufferTests.cs index 99e5626b75..53bf753dd6 100644 --- a/test/Npgsql.Tests/WriteBufferTests.cs +++ b/test/Npgsql.Tests/WriteBufferTests.cs @@ -33,20 +33,17 @@ public void GetWriter_Full_Buffer() } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/1275")] - public void Write_zero_characters() + public void Chunked_string_with_full_buffer() { // Fill up the buffer entirely WriteBuffer.WriteBytes(new byte[WriteBuffer.Size], 0, WriteBuffer.Size); Assert.That(WriteBuffer.WriteSpaceLeft, Is.Zero); - int charsUsed; - bool completed; - WriteBuffer.WriteStringChunked("hello", 0, 5, true, out charsUsed, out completed); - Assert.That(charsUsed, Is.Zero); - Assert.That(completed, Is.False); - WriteBuffer.WriteStringChunked("hello".ToCharArray(), 0, 5, true, out charsUsed, out completed); - Assert.That(charsUsed, Is.Zero); - Assert.That(completed, Is.False); + var data = new string('a', WriteBuffer.Size) + "hello"; + var byteLength = WriteBuffer.TextEncoding.GetByteCount(data); + WriteBuffer.WriteString(data, byteLength, false); + Assert.That(WriteBuffer.WritePosition, Is.EqualTo(5)); + Assert.That(WriteBuffer.Buffer.AsSpan(0, 5).ToArray(), Is.EqualTo(new byte[] { (byte)'h', (byte)'e', (byte)'l', (byte)'l', (byte)'o' })); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2849")] @@ -55,26 +52,11 @@ public void Chunked_string_encoding_fits() WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1], 0, WriteBuffer.Size - 1); Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); - var charsUsed = 1; - var completed = true; // This unicode character is three bytes when encoded in UTF8 - Assert.That(() => WriteBuffer.WriteStringChunked("\uD55C", 0, 1, 
true, out charsUsed, out completed), Throws.Nothing); - Assert.That(charsUsed, Is.EqualTo(0)); - Assert.That(completed, Is.False); - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/2849")] - public void Chunked_byte_array_encoding_fits() - { - WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1], 0, WriteBuffer.Size - 1); - Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); - - var charsUsed = 1; - var completed = true; - // This unicode character is three bytes when encoded in UTF8 - Assert.That(() => WriteBuffer.WriteStringChunked("\uD55C".ToCharArray(), 0, 1, true, out charsUsed, out completed), Throws.Nothing); - Assert.That(charsUsed, Is.EqualTo(0)); - Assert.That(completed, Is.False); + var data = "\uD55C" + new string('a', WriteBuffer.Size); + var byteLength = WriteBuffer.TextEncoding.GetByteCount(data); + WriteBuffer.WriteString(data, byteLength, false); + Assert.That(WriteBuffer.WritePosition, Is.EqualTo(3)); } [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3733")] @@ -83,28 +65,10 @@ public void Chunked_string_encoding_fits_with_surrogates() WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1]); Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); - var charsUsed = 1; - var completed = true; - var cyclone = "🌀"; - - Assert.That(() => WriteBuffer.WriteStringChunked(cyclone, 0, cyclone.Length, true, out charsUsed, out completed), Throws.Nothing); - Assert.That(charsUsed, Is.EqualTo(0)); - Assert.That(completed, Is.False); - } - - [Test, IssueLink("https://github.com/npgsql/npgsql/issues/3733")] - public void Chunked_char_array_encoding_fits_with_surrogates() - { - WriteBuffer.WriteBytes(new byte[WriteBuffer.Size - 1]); - Assert.That(WriteBuffer.WriteSpaceLeft, Is.EqualTo(1)); - - var charsUsed = 1; - var completed = true; - var cyclone = "🌀"; - - Assert.That(() => WriteBuffer.WriteStringChunked(cyclone.ToCharArray(), 0, cyclone.Length, true, out charsUsed, out completed), Throws.Nothing); - Assert.That(charsUsed, 
Is.EqualTo(0)); - Assert.That(completed, Is.False); + var cyclone = "🌀" + new string('a', WriteBuffer.Size); + var byteLength = WriteBuffer.TextEncoding.GetByteCount(cyclone); + WriteBuffer.WriteString(cyclone, byteLength, false); + Assert.That(WriteBuffer.WritePosition, Is.EqualTo(4)); } [SetUp] @@ -112,6 +76,7 @@ public void SetUp() { Underlying = new MemoryStream(); WriteBuffer = new NpgsqlWriteBuffer(null, Underlying, null, NpgsqlReadBuffer.DefaultSize, NpgsqlWriteBuffer.UTF8Encoding); + WriteBuffer.MessageLengthValidation = false; } // ReSharper disable once InconsistentNaming