use std::{
pin::Pin,
task::{ready, Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream, UnixListener, UnixStream},
};
pub enum SocketAddr {
Unix(tokio::net::unix::SocketAddr),
Net(std::net::SocketAddr),
}
impl From<tokio::net::unix::SocketAddr> for SocketAddr {
fn from(value: tokio::net::unix::SocketAddr) -> Self {
Self::Unix(value)
}
}
impl From<std::net::SocketAddr> for SocketAddr {
fn from(value: std::net::SocketAddr) -> Self {
Self::Net(value)
}
}
impl std::fmt::Debug for SocketAddr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Unix(l) => std::fmt::Debug::fmt(l, f),
Self::Net(l) => std::fmt::Debug::fmt(l, f),
}
}
}
impl SocketAddr {
#[must_use]
pub fn into_net(self) -> Option<std::net::SocketAddr> {
match self {
Self::Net(socket) => Some(socket),
Self::Unix(_) => None,
}
}
#[must_use]
pub fn into_unix(self) -> Option<tokio::net::unix::SocketAddr> {
match self {
Self::Net(_) => None,
Self::Unix(socket) => Some(socket),
}
}
#[must_use]
pub const fn as_net(&self) -> Option<&std::net::SocketAddr> {
match self {
Self::Net(socket) => Some(socket),
Self::Unix(_) => None,
}
}
#[must_use]
pub const fn as_unix(&self) -> Option<&tokio::net::unix::SocketAddr> {
match self {
Self::Net(_) => None,
Self::Unix(socket) => Some(socket),
}
}
}
pub enum UnixOrTcpListener {
Unix(UnixListener),
Tcp(TcpListener),
}
impl From<UnixListener> for UnixOrTcpListener {
fn from(listener: UnixListener) -> Self {
Self::Unix(listener)
}
}
impl From<TcpListener> for UnixOrTcpListener {
fn from(listener: TcpListener) -> Self {
Self::Tcp(listener)
}
}
impl TryFrom<std::os::unix::net::UnixListener> for UnixOrTcpListener {
type Error = std::io::Error;
fn try_from(listener: std::os::unix::net::UnixListener) -> Result<Self, Self::Error> {
listener.set_nonblocking(true)?;
Ok(Self::Unix(UnixListener::from_std(listener)?))
}
}
impl TryFrom<std::net::TcpListener> for UnixOrTcpListener {
type Error = std::io::Error;
fn try_from(listener: std::net::TcpListener) -> Result<Self, Self::Error> {
listener.set_nonblocking(true)?;
Ok(Self::Tcp(TcpListener::from_std(listener)?))
}
}
impl UnixOrTcpListener {
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix(listener) => listener.local_addr().map(SocketAddr::from),
Self::Tcp(listener) => listener.local_addr().map(SocketAddr::from),
}
}
pub const fn is_unix(&self) -> bool {
matches!(self, Self::Unix(_))
}
pub const fn is_tcp(&self) -> bool {
matches!(self, Self::Tcp(_))
}
pub async fn accept(&self) -> Result<(SocketAddr, UnixOrTcpConnection), std::io::Error> {
match self {
Self::Unix(listener) => {
let (stream, remote_addr) = listener.accept().await?;
let socket = socket2::SockRef::from(&stream);
socket.set_keepalive(true)?;
socket.set_nodelay(true)?;
Ok((remote_addr.into(), UnixOrTcpConnection::Unix { stream }))
}
Self::Tcp(listener) => {
let (stream, remote_addr) = listener.accept().await?;
let socket = socket2::SockRef::from(&stream);
socket.set_keepalive(true)?;
socket.set_nodelay(true)?;
Ok((remote_addr.into(), UnixOrTcpConnection::Tcp { stream }))
}
}
}
pub fn poll_accept(
&self,
cx: &mut Context<'_>,
) -> Poll<Result<(SocketAddr, UnixOrTcpConnection), std::io::Error>> {
match self {
Self::Unix(listener) => {
let (stream, remote_addr) = ready!(listener.poll_accept(cx)?);
let socket = socket2::SockRef::from(&stream);
socket.set_keepalive(true)?;
socket.set_nodelay(true)?;
Poll::Ready(Ok((
remote_addr.into(),
UnixOrTcpConnection::Unix { stream },
)))
}
Self::Tcp(listener) => {
let (stream, remote_addr) = ready!(listener.poll_accept(cx)?);
let socket = socket2::SockRef::from(&stream);
socket.set_keepalive(true)?;
socket.set_nodelay(true)?;
Poll::Ready(Ok((
remote_addr.into(),
UnixOrTcpConnection::Tcp { stream },
)))
}
}
}
}
pin_project_lite::pin_project! {
#[project = UnixOrTcpConnectionProj]
pub enum UnixOrTcpConnection {
Unix {
#[pin]
stream: UnixStream,
},
Tcp {
#[pin]
stream: TcpStream,
},
}
}
impl From<TcpStream> for UnixOrTcpConnection {
fn from(stream: TcpStream) -> Self {
Self::Tcp { stream }
}
}
impl UnixOrTcpConnection {
pub fn local_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix { stream } => stream.local_addr().map(SocketAddr::from),
Self::Tcp { stream } => stream.local_addr().map(SocketAddr::from),
}
}
pub fn peer_addr(&self) -> Result<SocketAddr, std::io::Error> {
match self {
Self::Unix { stream } => stream.peer_addr().map(SocketAddr::from),
Self::Tcp { stream } => stream.peer_addr().map(SocketAddr::from),
}
}
}
impl AsyncRead for UnixOrTcpConnection {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_read(cx, buf),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_read(cx, buf),
}
}
}
impl AsyncWrite for UnixOrTcpConnection {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write(cx, buf),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write(cx, buf),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_write_vectored(cx, bufs),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_write_vectored(cx, bufs),
}
}
fn is_write_vectored(&self) -> bool {
match self {
UnixOrTcpConnection::Unix { stream } => stream.is_write_vectored(),
UnixOrTcpConnection::Tcp { stream } => stream.is_write_vectored(),
}
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_flush(cx),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_flush(cx),
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
match self.project() {
UnixOrTcpConnectionProj::Unix { stream } => stream.poll_shutdown(cx),
UnixOrTcpConnectionProj::Tcp { stream } => stream.poll_shutdown(cx),
}
}
}