Skip to content

Commit

Permalink
ipc: implement Request trait with statically-guaranteed Response type
Browse files Browse the repository at this point in the history
  • Loading branch information
sodiboo committed Apr 17, 2024
1 parent 50b38d4 commit f1bfef9
Show file tree
Hide file tree
Showing 4 changed files with 710 additions and 256 deletions.
221 changes: 192 additions & 29 deletions niri-ipc/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Types for communicating with niri via IPC.
#![warn(missing_docs)]

use std::collections::HashMap;
use std::collections::BTreeMap;
use std::str::FromStr;

use serde::{Deserialize, Serialize};
Expand All @@ -10,21 +10,199 @@ mod socket;

pub use socket::{NiriSocket, SOCKET_PATH_ENV};

/// Request from client to niri.
mod private {
pub trait Sealed {}
}

// TODO: remove ResponseDecoder and AnyRequest?

#[allow(missing_docs)]
pub trait ResponseDecoder {
type Output: for<'de> Deserialize<'de>;
fn decode(&self, value: serde_json::Value) -> serde_json::Result<Self::Output>;
}

#[derive(Debug, Clone, Copy)]
#[allow(missing_docs)]
pub struct TrivialDecoder<T: for<'de> Deserialize<'de>>(std::marker::PhantomData<T>);

impl<T: for<'de> Deserialize<'de>> Default for TrivialDecoder<T> {
fn default() -> Self {
Self(std::marker::PhantomData)
}
}

impl<T: for<'de> Deserialize<'de>> ResponseDecoder for TrivialDecoder<T> {
type Output = T;

fn decode(&self, value: serde_json::Value) -> serde_json::Result<Self::Output> {
serde_json::from_value(value)
}
}

/// A request that can be sent to niri.
pub trait Request:
Serialize + for<'de> Deserialize<'de> + private::Sealed + Into<RequestMessage>
{
/// The type of the response that niri sends for this request.
type Response: Serialize + for<'de> Deserialize<'de>;

#[allow(missing_docs)]
fn decoder(&self) -> impl ResponseDecoder<Output = Reply<Self::Response>> + 'static;

/// Convert the request into a RequestMessage (for serialization).
fn into_message(self) -> RequestMessage;
}

macro_rules! requests {
(@$item:item$(;)?) => { $item };
($($(#[$m:meta])*$variant:ident($v:vis struct $request:ident$($p:tt)?) -> $response:ty;)*) => {
#[derive(Debug, Serialize, Deserialize, Clone)]
/// A plain tag for each request type.
pub enum RequestType {
$(
$(#[$m])*
$variant,
)*
}

#[derive(Debug, Serialize, Deserialize, Clone)]
enum AnyRequest {
$(
$(#[$m])*
$variant($request),
)*
}

#[derive(Debug, Serialize, Deserialize, Clone)]
enum AnyResponse {
$(
$(#[$m])*
$variant($response),
)*
}

impl private::Sealed for AnyRequest {}

struct AnyResponseDecoder(RequestType);

impl ResponseDecoder for AnyResponseDecoder {
type Output = Reply<AnyResponse>;

fn decode(&self, value: serde_json::Value) -> serde_json::Result<Self::Output> {
match self.0 {
$(
RequestType::$variant => TrivialDecoder::<Reply<$response>>::default().decode(value).map(|r| r.map(AnyResponse::$variant)),
)*
}
}
}

impl TryFrom<RequestMessage> for AnyRequest {
type Error = serde_json::Error;

fn try_from(message: RequestMessage) -> serde_json::Result<Self> {
match message.request_type {
$(
RequestType::$variant => serde_json::from_value(message.request_body).map(AnyRequest::$variant),
)*
}
}
}

impl Request for AnyRequest {
type Response = AnyResponse;

fn decoder(&self) -> impl ResponseDecoder<Output = Reply<Self::Response>> + 'static {
match self {
$(
AnyRequest::$variant(_) => AnyResponseDecoder(RequestType::$variant),
)*
}
}

fn into_message(self) -> RequestMessage {
match self {
$(
AnyRequest::$variant(request) => request.into_message(),
)*
}
}
}


$(
requests!(@
$(#[$m])*
#[derive(Debug, Serialize, Deserialize, Clone)]
$v struct $request $($p)?;
);

impl From<$request> for AnyRequest {
fn from(request: $request) -> Self {
AnyRequest::$variant(request)
}
}

impl crate::private::Sealed for $request {}

impl crate::Request for $request {
type Response = $response;

fn decoder(&self) -> impl crate::ResponseDecoder<Output = crate::Reply<Self::Response>> + 'static {
TrivialDecoder::<Reply<$response>>::default()
}

fn into_message(self) -> RequestMessage {
RequestMessage {
request_type: RequestType::$variant,
request_body: serde_json::to_value(self).unwrap(),
}
}
}
)*
}
}

/// The message format for IPC communication.
///
/// This is mainly to avoid using sum types in IPC communication, which are more annoying to use
/// with non-Rust tooling.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum Request {
/// Always responds with an error. (For testing error handling)
ReturnError,
/// Request the version string for the running niri instance.
Version,
/// Request information about connected outputs.
Outputs,
/// Request information about the focused window.
FocusedWindow,
/// Perform an action.
Action(Action),
pub struct RequestMessage {
/// The type of the request.
pub request_type: RequestType,
/// The raw JSON body of the request.
pub request_body: serde_json::Value,
}

impl<R: Request> From<R> for RequestMessage {
fn from(value: R) -> Self {
value.into_message()
}
}

/// Uninstantiable
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum Never {}

requests!(
/// Always responds with an error (for testing error handling).
ReturnError(pub struct ErrorRequest) -> Never;

/// Requests the version string for the running niri instance.
Version(pub struct VersionRequest) -> String;

/// Requests information about connected outputs.
Outputs(pub struct OutputRequest) -> BTreeMap<String, Output>;

/// Requests information about the focused window.
FocusedWindow(pub struct FocusedWindowRequest) -> Option<Window>;

/// Requests that the compositor perform an action.
Action(pub struct ActionRequest(pub Action)) -> ();
);

/// Reply from niri to client.
///
/// Every request gets one reply.
Expand All @@ -33,22 +211,7 @@ pub enum Request {
/// * If the request does not need any particular response, it will be
/// `Reply::Ok(Response::Handled)`. Kind of like an `Ok(())`.
/// * Otherwise, it will be `Reply::Ok(response)` with one of the other [`Response`] variants.
pub type Reply = Result<Response, String>;

/// Successful response from niri to client.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub enum Response {
/// A request that does not need a response was handled successfully.
Handled,
/// The version string for the running niri instance.
Version(String),
/// Information about connected outputs.
///
/// Map from connector name to output info.
Outputs(HashMap<String, Output>),
/// Information about the focused window.
FocusedWindow(Option<Window>),
}
pub type Reply<T> = Result<T, String>;

/// Actions that niri can perform.
// Variants in this enum should match the spelling of the ones in niri-config. Most, but not all,
Expand Down
12 changes: 7 additions & 5 deletions niri-ipc/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::path::Path;
use serde_json::de::IoRead;
use serde_json::StreamDeserializer;

use crate::{Reply, Request};
use crate::{Reply, Request, ResponseDecoder};

/// Name of the environment variable containing the niri IPC socket path.
pub const SOCKET_PATH_ENV: &str = "NIRI_SOCKET";
Expand All @@ -16,7 +16,7 @@ pub const SOCKET_PATH_ENV: &str = "NIRI_SOCKET";
/// and serialization/deserialization of messages.
pub struct NiriSocket {
stream: UnixStream,
responses: StreamDeserializer<'static, IoRead<UnixStream>, Reply>,
responses: StreamDeserializer<'static, IoRead<UnixStream>, serde_json::Value>,
}

impl TryFrom<UnixStream> for NiriSocket {
Expand Down Expand Up @@ -55,14 +55,16 @@ impl NiriSocket {
/// Ok(Ok([Response](crate::Response))) corresponds to a successful response from the running
/// niri instance. Ok(Err([String])) corresponds to an error received from the running niri
/// instance. Err([std::io::Error]) corresponds to an error in the IPC communication.
pub fn send(mut self, request: Request) -> io::Result<Reply> {
let mut buf = serde_json::to_vec(&request).unwrap();
pub fn send_request<R: Request>(mut self, request: R) -> io::Result<Reply<R::Response>> {
let decoder = request.decoder();
let mut buf = serde_json::to_vec(&request.into_message()).unwrap();
writeln!(buf).unwrap();
self.stream.write_all(&buf)?; // .context("error writing IPC request")?;
self.stream.flush()?;

if let Some(next) = self.responses.next() {
next.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
next.and_then(|v| decoder.decode(v))
.map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))
} else {
Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
Expand Down
Loading

0 comments on commit f1bfef9

Please sign in to comment.