Skip to content

Commit

Permalink
Refactor usage of SocketAddr (#2088)
Browse files Browse the repository at this point in the history
* Remove unnecessary reference for parameter

Not needed because `u16` implements `Copy`.

* Replace `.to_socket_addrs` with `lookup_host`

The `ToSocketAddrs::to_socket_addrs` function performs a blocking DNS
lookup, so it should be avoided in asynchronous tasks. It can be
replaced by `tokio::net::lookup_host` which receives a
`tokio::net::ToSocketAddrs` implementation.

* Change address parameter to be generic

Don't force a string to be built, which then needs to be parsed into a
`SocketAddr`.

* Use `(String, u16)` instead of a raw `String`

Avoid having to format the string to include the port, which will then
get parsed out later.

* Create a `SocketAddr` directly from integers

Avoid parsing the IP address and port.

* Avoid parsing `SocketAddr` if not needed

Let the IP address be parsed by Tokio, and avoid formatting a string
just to include the port which will get parsed out again later.
  • Loading branch information
jvff authored May 31, 2024
1 parent 058d953 commit a79bd1c
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 31 deletions.
4 changes: 2 additions & 2 deletions linera-rpc/src/simple/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ where
"Listening to {:?} traffic on {}:{}",
self.network.protocol, self.host, self.port
);
let address = format!("{}:{}", self.host, self.port);
let address = (self.host.clone(), self.port);

let (cross_chain_sender, cross_chain_receiver) =
mpsc::channel(self.cross_chain_config.queue_size);
Expand All @@ -161,7 +161,7 @@ where
cross_chain_sender,
};
// Launch server for the appropriate protocol.
protocol.spawn_server(&address, state).await
protocol.spawn_server(address, state).await
}
}

Expand Down
17 changes: 10 additions & 7 deletions linera-rpc/src/simple/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use std::{collections::HashMap, io, net::ToSocketAddrs, sync::Arc};
use std::{collections::HashMap, io, sync::Arc};

use async_trait::async_trait;
use futures::{
Expand All @@ -12,7 +12,7 @@ use futures::{
};
use serde::{Deserialize, Serialize};
use tokio::{
net::{TcpListener, TcpStream, UdpSocket},
net::{lookup_host, TcpListener, TcpStream, ToSocketAddrs, UdpSocket},
sync::Mutex,
};
use tokio_util::{codec::Framed, udp::UdpFramed};
Expand Down Expand Up @@ -111,9 +111,12 @@ impl<T> Transport for T where

impl TransportProtocol {
/// Creates a transport for this protocol.
pub async fn connect(self, address: String) -> Result<impl Transport, std::io::Error> {
let mut addresses = address
.to_socket_addrs()
pub async fn connect(
self,
address: impl ToSocketAddrs,
) -> Result<impl Transport, std::io::Error> {
let mut addresses = lookup_host(address)
.await
.expect("Invalid address to connect to");
let address = addresses
.next()
Expand Down Expand Up @@ -152,7 +155,7 @@ impl TransportProtocol {
/// Runs a server for this protocol and the given message handler.
pub async fn spawn_server<S>(
self,
address: &str,
address: impl ToSocketAddrs,
state: S,
) -> Result<ServerHandle, std::io::Error>
where
Expand All @@ -161,7 +164,7 @@ impl TransportProtocol {
let (abort, registration) = AbortHandle::new_pair();
let handle = match self {
Self::Udp => {
let socket = UdpSocket::bind(&address).await?;
let socket = UdpSocket::bind(address).await?;
tokio::spawn(Self::run_udp_server(socket, state, registration))
}
Self::Tcp => {
Expand Down
5 changes: 3 additions & 2 deletions linera-service/src/prometheus_server.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use std::net::SocketAddr;
use std::fmt::Debug;

use axum::{http::StatusCode, response::IntoResponse, routing::get, Router};
use tokio::net::ToSocketAddrs;
use tracing::info;

pub fn start_metrics(address: SocketAddr) {
pub fn start_metrics(address: impl ToSocketAddrs + Debug + Send + 'static) {
info!("Starting to serve metrics on {:?}", address);
let prometheus_router = Router::new().route("/metrics", get(serve_metrics));

Expand Down
19 changes: 8 additions & 11 deletions linera-service/src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Zefchain Labs, Inc.
// SPDX-License-Identifier: Apache-2.0

use std::{path::PathBuf, time::Duration};
use std::{net::SocketAddr, path::PathBuf, time::Duration};

use anyhow::{bail, Result};
use async_trait::async_trait;
Expand All @@ -15,6 +15,8 @@ use linera_rpc::{
simple::{MessageHandler, TransportProtocol},
RpcMessage,
};
#[cfg(with_metrics)]
use linera_service::prometheus_server;
use linera_service::{
config::{GenesisConfig, Import, ValidatorServerConfig},
grpc_proxy::GrpcProxy,
Expand All @@ -24,8 +26,6 @@ use linera_service::{
use linera_storage::Storage;
use linera_views::{common::CommonStoreConfig, views::ViewError};
use tracing::{error, info, instrument};
#[cfg(with_metrics)]
use {linera_service::prometheus_server, std::net::SocketAddr};

/// Options for running the proxy.
#[derive(clap::Parser, Debug, Clone)]
Expand Down Expand Up @@ -232,7 +232,7 @@ where
let address = self.get_listen_address(self.public_config.port);

#[cfg(with_metrics)]
Self::start_metrics(&self.get_listen_address(self.internal_config.metrics_port));
Self::start_metrics(self.get_listen_address(self.internal_config.metrics_port));

self.public_config
.protocol
Expand All @@ -244,15 +244,12 @@ where
}

#[cfg(with_metrics)]
pub fn start_metrics(address: &String) {
match address.parse::<SocketAddr>() {
Err(err) => panic!("Invalid metrics address for {address}: {err}"),
Ok(address) => prometheus_server::start_metrics(address),
}
pub fn start_metrics(address: SocketAddr) {
prometheus_server::start_metrics(address)
}

fn get_listen_address(&self, port: u16) -> String {
format!("0.0.0.0:{}", port)
fn get_listen_address(&self, port: u16) -> SocketAddr {
SocketAddr::from(([0, 0, 0, 0], port))
}

async fn try_proxy_message(
Expand Down
15 changes: 6 additions & 9 deletions linera-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ use linera_rpc::{
},
grpc, simple,
};
#[cfg(with_metrics)]
use linera_service::prometheus_server;
use linera_service::{
config::{
CommitteeConfig, Export, GenesisConfig, Import, ValidatorConfig, ValidatorServerConfig,
Expand All @@ -28,8 +30,6 @@ use linera_storage::Storage;
use linera_views::{common::CommonStoreConfig, views::ViewError};
use serde::Deserialize;
use tracing::{error, info};
#[cfg(with_metrics)]
use {linera_service::prometheus_server, std::net::SocketAddr};

struct ServerContext {
server_config: ValidatorServerConfig,
Expand Down Expand Up @@ -84,7 +84,7 @@ impl ServerContext {
handles.push(async move {
#[cfg(with_metrics)]
if let Some(port) = shard.metrics_port {
Self::start_metrics(listen_address, &port);
Self::start_metrics(listen_address, port);
}
let server = simple::Server::new(
internal_network,
Expand Down Expand Up @@ -128,7 +128,7 @@ impl ServerContext {
handles.push(async move {
#[cfg(with_metrics)]
if let Some(port) = shard.metrics_port {
Self::start_metrics(listen_address, &port);
Self::start_metrics(listen_address, port);
}
let spawned_server = match grpc::GrpcServer::spawn(
listen_address.to_string(),
Expand Down Expand Up @@ -159,11 +159,8 @@ impl ServerContext {
}

#[cfg(with_metrics)]
fn start_metrics(host: &str, port: &u16) {
match format!("{}:{}", host, port).parse::<SocketAddr>() {
Err(err) => panic!("Invalid metrics address for {host}:{port}: {err}"),
Ok(address) => prometheus_server::start_metrics(address),
}
fn start_metrics(host: &str, port: u16) {
prometheus_server::start_metrics((host.to_owned(), port));
}

fn get_listen_address(&self) -> String {
Expand Down

0 comments on commit a79bd1c

Please sign in to comment.